package org.tribuo.util.infotheory.impl;

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.tribuo.util.infotheory.WeightedInformationTheory;

/* loaded from: input_file:org/tribuo/util/infotheory/impl/WeightedPairDistribution.class */
/**
 * A weighted joint distribution over pairs of values, together with the two marginal
 * distributions, each stored as a map from value to {@link WeightCountTuple}.
 *
 * @param <T1> the type of the first element of each pair
 * @param <T2> the type of the second element of each pair
 */
public class WeightedPairDistribution<T1, T2> {
    /** Initial capacity for the maps built by the factory methods. */
    private static final int DEFAULT_MAP_SIZE = 20;

    /** Total number of observations summarised by this distribution. */
    public final long count;
    private final Map<CachedPair<T1, T2>, WeightCountTuple> jointCounts;
    private final Map<T1, WeightCountTuple> firstCount;
    private final Map<T2, WeightCountTuple> secondCount;

    /**
     * Builds a distribution from copies of the supplied maps.
     * <p>
     * The copies are shallow: the {@link WeightCountTuple} values are shared with the arguments.
     *
     * @param count the total number of observations
     * @param jointCounts the per-pair weight/count tuples
     * @param firstCount the per-value tuples for the first element
     * @param secondCount the per-value tuples for the second element
     */
    public WeightedPairDistribution(long count, Map<CachedPair<T1, T2>, WeightCountTuple> jointCounts, Map<T1, WeightCountTuple> firstCount, Map<T2, WeightCountTuple> secondCount) {
        this.count = count;
        this.jointCounts = new LinkedHashMap<>(jointCounts);
        this.firstCount = new LinkedHashMap<>(firstCount);
        this.secondCount = new LinkedHashMap<>(secondCount);
    }

    /**
     * Builds a distribution that stores the supplied maps directly, without copying them.
     * Later mutations of the arguments are therefore visible through this distribution.
     *
     * @param count the total number of observations
     * @param jointCounts the per-pair weight/count tuples
     * @param firstCount the per-value tuples for the first element
     * @param secondCount the per-value tuples for the second element
     */
    public WeightedPairDistribution(long count, LinkedHashMap<CachedPair<T1, T2>, WeightCountTuple> jointCounts, LinkedHashMap<T1, WeightCountTuple> firstCount, LinkedHashMap<T2, WeightCountTuple> secondCount) {
        this.count = count;
        this.jointCounts = jointCounts;
        this.firstCount = firstCount;
        this.secondCount = secondCount;
    }

    /**
     * Returns the joint (per-pair) tuples.
     *
     * @return the internal map itself, not a copy
     */
    public Map<CachedPair<T1, T2>, WeightCountTuple> getJointCounts() {
        return this.jointCounts;
    }

    /**
     * Returns the tuples for the first element of the pairs.
     *
     * @return the internal map itself, not a copy
     */
    public Map<T1, WeightCountTuple> getFirstCount() {
        return this.firstCount;
    }

    /**
     * Returns the tuples for the second element of the pairs.
     *
     * @return the internal map itself, not a copy
     */
    public Map<T2, WeightCountTuple> getSecondCount() {
        return this.secondCount;
    }

    /**
     * Builds a distribution from three parallel lists, where element {@code i} of each list
     * describes one weighted observation.
     * <p>
     * Every map is passed through {@link WeightedInformationTheory#normaliseWeights} after
     * accumulation. That presumably converts each summed weight into a per-observation mean
     * (confirm against that method).
     *
     * @param first the first element of each observation
     * @param second the second element of each observation
     * @param weights the weight of each observation
     * @param <T1> the type of the first element
     * @param <T2> the type of the second element
     * @return the joint and marginal distribution; {@code count} is the number of observations
     * @throws IllegalArgumentException if the three lists differ in length
     */
    public static <T1, T2> WeightedPairDistribution<T1, T2> constructFromLists(List<T1> first, List<T2> second, List<Double> weights) {
        // Validate before allocating anything.
        if (first.size() != second.size() || first.size() != weights.size()) {
            throw new IllegalArgumentException("Counting requires lists of the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", weights.size() = " + weights.size());
        }
        LinkedHashMap<CachedPair<T1, T2>, WeightCountTuple> jointMap = new LinkedHashMap<>(DEFAULT_MAP_SIZE);
        LinkedHashMap<T1, WeightCountTuple> firstMap = new LinkedHashMap<>(DEFAULT_MAP_SIZE);
        LinkedHashMap<T2, WeightCountTuple> secondMap = new LinkedHashMap<>(DEFAULT_MAP_SIZE);

        long total = 0;
        for (int i = 0; i < first.size(); i++) {
            T1 a = first.get(i);
            T2 b = second.get(i);
            double weight = weights.get(i); // unboxing: a null weight throws NullPointerException
            accumulate(jointMap, new CachedPair<>(a, b), weight, 1);
            accumulate(firstMap, a, weight, 1);
            accumulate(secondMap, b, weight, 1);
            total++;
        }

        WeightedInformationTheory.normaliseWeights(jointMap);
        WeightedInformationTheory.normaliseWeights(firstMap);
        WeightedInformationTheory.normaliseWeights(secondMap);
        return new WeightedPairDistribution<>(total, jointMap, firstMap, secondMap);
    }

    /**
     * Builds a distribution from an existing joint map, deriving both marginals from it.
     * <p>
     * The joint map is shallow-copied: its {@link WeightCountTuple} values are shared with the
     * argument and are not re-normalised. Each joint weight is multiplied by its count before it is
     * summed into the marginals. This assumes the input weights are already per-observation means
     * (NOTE(review): confirm against callers). The marginals are then passed through
     * {@link WeightedInformationTheory#normaliseWeights}.
     *
     * @param jointCount the per-pair weight/count tuples
     * @param <T1> the type of the first element
     * @param <T2> the type of the second element
     * @return the joint and marginal distribution; {@code count} is the sum of all joint counts
     */
    public static <T1, T2> WeightedPairDistribution<T1, T2> constructFromMap(Map<CachedPair<T1, T2>, WeightCountTuple> jointCount) {
        LinkedHashMap<CachedPair<T1, T2>, WeightCountTuple> jointMap = new LinkedHashMap<>(jointCount);
        LinkedHashMap<T1, WeightCountTuple> firstMap = new LinkedHashMap<>(DEFAULT_MAP_SIZE);
        LinkedHashMap<T2, WeightCountTuple> secondMap = new LinkedHashMap<>(DEFAULT_MAP_SIZE);

        long total = 0;
        for (Map.Entry<CachedPair<T1, T2>, WeightCountTuple> entry : jointMap.entrySet()) {
            CachedPair<T1, T2> pair = entry.getKey();
            WeightCountTuple tuple = entry.getValue();
            // Scale the stored weight back up to a total before summing into the marginals.
            double totalWeight = tuple.weight * tuple.count;
            accumulate(firstMap, pair.getA(), totalWeight, tuple.count);
            accumulate(secondMap, pair.getB(), totalWeight, tuple.count);
            total += tuple.count;
        }

        WeightedInformationTheory.normaliseWeights(firstMap);
        WeightedInformationTheory.normaliseWeights(secondMap);
        return new WeightedPairDistribution<>(total, jointMap, firstMap, secondMap);
    }

    /**
     * Adds {@code weight} and {@code count} to the tuple stored under {@code key}. If the key is
     * absent, a {@code new WeightCountTuple()} is inserted first.
     */
    private static <K> void accumulate(Map<K, WeightCountTuple> map, K key, double weight, long count) {
        WeightCountTuple tuple = map.computeIfAbsent(key, k -> new WeightCountTuple());
        tuple.weight += weight;
        tuple.count += count;
    }
}
