package org.tribuo.util.infotheory.impl;

import com.oracle.labs.mlrg.olcut.util.MutableLong;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:org/tribuo/util/infotheory/impl/PairDistribution.class */
/**
 * Holds the joint occurrence counts of pairs of values together with the
 * marginal counts of the first and second elements and the total number of
 * observations.
 *
 * @param <T1> the type of the first element of each pair
 * @param <T2> the type of the second element of each pair
 */
public class PairDistribution<T1, T2> {
    /** Initial capacity used for the count maps when the caller does not supply one. */
    private static final int DEFAULT_MAP_SIZE = 20;

    /** The total number of observations. */
    public final long count;
    /** The number of times each (first, second) pair was observed. */
    public final Map<CachedPair<T1, T2>, MutableLong> jointCounts;
    /** The number of times each first element was observed. */
    public final Map<T1, MutableLong> firstCount;
    /** The number of times each second element was observed. */
    public final Map<T2, MutableLong> secondCount;

    /**
     * Creates a distribution from the supplied counts. Each map is copied
     * into a new {@link LinkedHashMap}; the copy is shallow, so the
     * {@code MutableLong} values are the same instances as in the arguments.
     *
     * @param j    the total observation count
     * @param map  the joint counts
     * @param map2 the counts of the first element
     * @param map3 the counts of the second element
     */
    public PairDistribution(long j, Map<CachedPair<T1, T2>, MutableLong> map, Map<T1, MutableLong> map2, Map<T2, MutableLong> map3) {
        this.count = j;
        this.jointCounts = new LinkedHashMap<>(map);
        this.firstCount = new LinkedHashMap<>(map2);
        this.secondCount = new LinkedHashMap<>(map3);
    }

    /**
     * Creates a distribution from the supplied counts. Unlike the
     * {@code Map} overload, the maps are stored directly without copying.
     *
     * @param j              the total observation count
     * @param linkedHashMap  the joint counts
     * @param linkedHashMap2 the counts of the first element
     * @param linkedHashMap3 the counts of the second element
     */
    public PairDistribution(long j, LinkedHashMap<CachedPair<T1, T2>, MutableLong> linkedHashMap, LinkedHashMap<T1, MutableLong> linkedHashMap2, LinkedHashMap<T2, MutableLong> linkedHashMap3) {
        this.count = j;
        this.jointCounts = linkedHashMap;
        this.firstCount = linkedHashMap2;
        this.secondCount = linkedHashMap3;
    }

    /**
     * Builds a distribution by counting the element-wise pairs of two lists.
     *
     * @param list  the first elements
     * @param list2 the second elements
     * @param <T1>  the type of the first element
     * @param <T2>  the type of the second element
     * @return the distribution of the paired elements
     * @throws IllegalArgumentException if the lists differ in size
     */
    public static <T1, T2> PairDistribution<T1, T2> constructFromLists(List<T1> list, List<T2> list2) {
        // Validate before allocating anything.
        if (list.size() != list2.size()) {
            throw new IllegalArgumentException("Counting requires arrays of the same length. first.size() = " + list.size() + ", second.size() = " + list2.size());
        }

        // Declared as LinkedHashMap (not Map) on purpose: this selects the
        // non-copying constructor overload below.
        LinkedHashMap<CachedPair<T1, T2>, MutableLong> joint = new LinkedHashMap<>(DEFAULT_MAP_SIZE);
        LinkedHashMap<T1, MutableLong> firstMarginal = new LinkedHashMap<>(DEFAULT_MAP_SIZE);
        LinkedHashMap<T2, MutableLong> secondMarginal = new LinkedHashMap<>(DEFAULT_MAP_SIZE);

        long total = 0;
        for (int i = 0; i < list.size(); i++) {
            T1 first = list.get(i);
            T2 second = list2.get(i);
            joint.computeIfAbsent(new CachedPair<>(first, second), k -> new MutableLong()).increment();
            firstMarginal.computeIfAbsent(first, k -> new MutableLong()).increment();
            secondMarginal.computeIfAbsent(second, k -> new MutableLong()).increment();
            total++;
        }
        return new PairDistribution<>(total, joint, firstMarginal, secondMarginal);
    }

    /**
     * Builds a distribution from joint counts, accumulating the marginals
     * into fresh maps with a default initial capacity.
     *
     * @param map  the joint counts
     * @param <T1> the type of the first element
     * @param <T2> the type of the second element
     * @return the distribution described by the joint counts
     */
    public static <T1, T2> PairDistribution<T1, T2> constructFromMap(Map<CachedPair<T1, T2>, MutableLong> map) {
        Map<T1, MutableLong> firstMarginal = new HashMap<>(DEFAULT_MAP_SIZE);
        Map<T2, MutableLong> secondMarginal = new HashMap<>(DEFAULT_MAP_SIZE);
        return constructFromMap(map, firstMarginal, secondMarginal);
    }

    /**
     * Builds a distribution from joint counts, accumulating the marginals
     * into fresh maps with the given initial capacities.
     *
     * @param map  the joint counts
     * @param i    the initial capacity of the first-element count map
     * @param i2   the initial capacity of the second-element count map
     * @param <T1> the type of the first element
     * @param <T2> the type of the second element
     * @return the distribution described by the joint counts
     */
    public static <T1, T2> PairDistribution<T1, T2> constructFromMap(Map<CachedPair<T1, T2>, MutableLong> map, int i, int i2) {
        Map<T1, MutableLong> firstMarginal = new HashMap<>(i);
        Map<T2, MutableLong> secondMarginal = new HashMap<>(i2);
        return constructFromMap(map, firstMarginal, secondMarginal);
    }

    /**
     * Builds a distribution from joint counts, adding each pair's count to
     * the supplied marginal maps. The supplied maps are mutated in place;
     * any entries they already contain are kept and incremented. The
     * returned distribution holds copies of all three maps.
     *
     * @param map  the joint counts
     * @param map2 the map that receives the first-element counts
     * @param map3 the map that receives the second-element counts
     * @param <T1> the type of the first element
     * @param <T2> the type of the second element
     * @return the distribution described by the joint counts
     */
    public static <T1, T2> PairDistribution<T1, T2> constructFromMap(Map<CachedPair<T1, T2>, MutableLong> map, Map<T1, MutableLong> map2, Map<T2, MutableLong> map3) {
        long total = 0;
        for (Map.Entry<CachedPair<T1, T2>, MutableLong> entry : map.entrySet()) {
            CachedPair<T1, T2> pair = entry.getKey();
            long pairCount = entry.getValue().longValue();
            T1 first = pair.getA();
            T2 second = pair.getB();
            map2.computeIfAbsent(first, k -> new MutableLong()).increment(pairCount);
            map3.computeIfAbsent(second, k -> new MutableLong()).increment(pairCount);
            total += pairCount;
        }
        return new PairDistribution<>(total, map, map2, map3);
    }
}
