package org.tribuo.util.infotheory.impl;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.tribuo.util.infotheory.WeightedInformationTheory;

/* loaded from: input_file:org/tribuo/util/infotheory/impl/WeightedTripleDistribution.class */
/**
 * Holds weighted occurrence counts for three aligned variables: the joint
 * (A, B, C) counts, the three pairwise counts and the three single-variable counts.
 *
 * <p>Instances are normally built through {@link #constructFromLists} or
 * {@link #constructFromMap}. The getters return the internal maps directly
 * (no defensive copy is made), so changes made through them are visible to
 * this object.
 *
 * @param <T1> type of the first variable
 * @param <T2> type of the second variable
 * @param <T3> type of the third variable
 */
public class WeightedTripleDistribution<T1, T2, T3> {
    /** Initial capacity given to every count map created by the factory methods. */
    public static final int DEFAULT_MAP_SIZE = 20;

    /**
     * Total number of samples: the list length in {@link #constructFromLists},
     * or the sum of the joint counts in {@link #constructFromMap}.
     */
    public final long count;

    private final Map<CachedTriple<T1, T2, T3>, WeightCountTuple> jointCount;
    private final Map<CachedPair<T1, T2>, WeightCountTuple> abCount;
    private final Map<CachedPair<T1, T3>, WeightCountTuple> acCount;
    private final Map<CachedPair<T2, T3>, WeightCountTuple> bcCount;
    private final Map<T1, WeightCountTuple> aCount;
    private final Map<T2, WeightCountTuple> bCount;
    private final Map<T3, WeightCountTuple> cCount;

    /**
     * Stores the supplied count and maps as-is; no copying or validation is performed.
     *
     * @param count total number of samples
     * @param jointCount counts keyed by (A, B, C)
     * @param abCount counts keyed by (A, B)
     * @param acCount counts keyed by (A, C)
     * @param bcCount counts keyed by (B, C)
     * @param aCount counts keyed by A
     * @param bCount counts keyed by B
     * @param cCount counts keyed by C
     */
    public WeightedTripleDistribution(long count,
                                      Map<CachedTriple<T1, T2, T3>, WeightCountTuple> jointCount,
                                      Map<CachedPair<T1, T2>, WeightCountTuple> abCount,
                                      Map<CachedPair<T1, T3>, WeightCountTuple> acCount,
                                      Map<CachedPair<T2, T3>, WeightCountTuple> bcCount,
                                      Map<T1, WeightCountTuple> aCount,
                                      Map<T2, WeightCountTuple> bCount,
                                      Map<T3, WeightCountTuple> cCount) {
        this.count = count;
        this.jointCount = jointCount;
        this.abCount = abCount;
        this.acCount = acCount;
        this.bcCount = bcCount;
        this.aCount = aCount;
        this.bCount = bCount;
        this.cCount = cCount;
    }

    /** @return the (A, B, C) count map held by this object */
    public Map<CachedTriple<T1, T2, T3>, WeightCountTuple> getJointCount() {
        return this.jointCount;
    }

    /** @return the (A, B) count map held by this object */
    public Map<CachedPair<T1, T2>, WeightCountTuple> getABCount() {
        return this.abCount;
    }

    /** @return the (A, C) count map held by this object */
    public Map<CachedPair<T1, T3>, WeightCountTuple> getACCount() {
        return this.acCount;
    }

    /** @return the (B, C) count map held by this object */
    public Map<CachedPair<T2, T3>, WeightCountTuple> getBCCount() {
        return this.bcCount;
    }

    /** @return the A count map held by this object */
    public Map<T1, WeightCountTuple> getACount() {
        return this.aCount;
    }

    /** @return the B count map held by this object */
    public Map<T2, WeightCountTuple> getBCount() {
        return this.bCount;
    }

    /** @return the C count map held by this object */
    public Map<T3, WeightCountTuple> getCCount() {
        return this.cCount;
    }

    /**
     * Builds a distribution from four index-aligned lists: element {@code i} of
     * each value list forms one (A, B, C) sample whose weight is element
     * {@code i} of {@code weights}.
     *
     * <p>Each sample adds its weight and one occurrence to the joint map, the
     * three pair maps and the three single-variable maps. Every map is then
     * passed to {@code WeightedInformationTheory.normaliseWeights}.
     *
     * @param first values of the first variable
     * @param second values of the second variable
     * @param third values of the third variable
     * @param weights per-sample weights
     * @param <T1> type of the first variable
     * @param <T2> type of the second variable
     * @param <T3> type of the third variable
     * @return the populated distribution, with {@code count} equal to the list length
     * @throws IllegalArgumentException if the four lists are not all the same size
     */
    public static <T1, T2, T3> WeightedTripleDistribution<T1, T2, T3> constructFromLists(List<T1> first, List<T2> second, List<T3> third, List<Double> weights) {
        int numSamples = first.size();
        if (numSamples != second.size() || numSamples != third.size() || numSamples != weights.size()) {
            throw new IllegalArgumentException("Counting requires lists of the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", third.size() = " + third.size() + ", weights.size() = " + weights.size());
        }

        Map<CachedTriple<T1, T2, T3>, WeightCountTuple> abcCounts = new HashMap<>(DEFAULT_MAP_SIZE);
        Map<CachedPair<T1, T2>, WeightCountTuple> abCounts = new HashMap<>(DEFAULT_MAP_SIZE);
        Map<CachedPair<T1, T3>, WeightCountTuple> acCounts = new HashMap<>(DEFAULT_MAP_SIZE);
        Map<CachedPair<T2, T3>, WeightCountTuple> bcCounts = new HashMap<>(DEFAULT_MAP_SIZE);
        Map<T1, WeightCountTuple> aCounts = new HashMap<>(DEFAULT_MAP_SIZE);
        Map<T2, WeightCountTuple> bCounts = new HashMap<>(DEFAULT_MAP_SIZE);
        Map<T3, WeightCountTuple> cCounts = new HashMap<>(DEFAULT_MAP_SIZE);

        for (int i = 0; i < numSamples; i++) {
            double weight = weights.get(i);
            T1 a = first.get(i);
            T2 b = second.get(i);
            T3 c = third.get(i);
            CachedTriple<T1, T2, T3> abc = new CachedTriple<>(a, b, c);

            accumulate(abcCounts, abc, weight, 1L);
            accumulate(abCounts, abc.getAB(), weight, 1L);
            accumulate(acCounts, abc.getAC(), weight, 1L);
            accumulate(bcCounts, abc.getBC(), weight, 1L);
            accumulate(aCounts, a, weight, 1L);
            accumulate(bCounts, b, weight, 1L);
            accumulate(cCounts, c, weight, 1L);
        }

        WeightedInformationTheory.normaliseWeights(abcCounts);
        WeightedInformationTheory.normaliseWeights(abCounts);
        WeightedInformationTheory.normaliseWeights(acCounts);
        WeightedInformationTheory.normaliseWeights(bcCounts);
        WeightedInformationTheory.normaliseWeights(aCounts);
        WeightedInformationTheory.normaliseWeights(bCounts);
        WeightedInformationTheory.normaliseWeights(cCounts);

        return new WeightedTripleDistribution<>(numSamples, abcCounts, abCounts, acCounts, bcCounts, aCounts, bCounts, cCounts);
    }

    /**
     * Builds a distribution from an existing joint (A, B, C) count map by
     * summing its entries into the pair and single-variable maps.
     *
     * <p>The supplied map is stored as the joint count without being copied or
     * modified. Only the six derived maps are passed to
     * {@code WeightedInformationTheory.normaliseWeights}.
     *
     * @param jointCount counts keyed by (A, B, C)
     * @param <T1> type of the first variable
     * @param <T2> type of the second variable
     * @param <T3> type of the third variable
     * @return the populated distribution, with {@code count} equal to the sum of the joint counts
     */
    public static <T1, T2, T3> WeightedTripleDistribution<T1, T2, T3> constructFromMap(Map<CachedTriple<T1, T2, T3>, WeightCountTuple> jointCount) {
        Map<CachedPair<T1, T2>, WeightCountTuple> abCounts = new HashMap<>(DEFAULT_MAP_SIZE);
        Map<CachedPair<T1, T3>, WeightCountTuple> acCounts = new HashMap<>(DEFAULT_MAP_SIZE);
        Map<CachedPair<T2, T3>, WeightCountTuple> bcCounts = new HashMap<>(DEFAULT_MAP_SIZE);
        Map<T1, WeightCountTuple> aCounts = new HashMap<>(DEFAULT_MAP_SIZE);
        Map<T2, WeightCountTuple> bCounts = new HashMap<>(DEFAULT_MAP_SIZE);
        Map<T3, WeightCountTuple> cCounts = new HashMap<>(DEFAULT_MAP_SIZE);

        long total = 0;
        for (Map.Entry<CachedTriple<T1, T2, T3>, WeightCountTuple> entry : jointCount.entrySet()) {
            CachedTriple<T1, T2, T3> abc = entry.getKey();
            WeightCountTuple joint = entry.getValue();
            total += joint.count;

            // Scale the stored weight by its count before summing.
            // NOTE(review): presumably the incoming weights are per-count averages and
            // normaliseWeights divides by count again - confirm against WeightedInformationTheory.
            double scaledWeight = joint.weight * joint.count;

            accumulate(abCounts, abc.getAB(), scaledWeight, joint.count);
            accumulate(acCounts, abc.getAC(), scaledWeight, joint.count);
            accumulate(bcCounts, abc.getBC(), scaledWeight, joint.count);
            accumulate(aCounts, abc.getA(), scaledWeight, joint.count);
            accumulate(bCounts, abc.getB(), scaledWeight, joint.count);
            accumulate(cCounts, abc.getC(), scaledWeight, joint.count);
        }

        WeightedInformationTheory.normaliseWeights(abCounts);
        WeightedInformationTheory.normaliseWeights(acCounts);
        WeightedInformationTheory.normaliseWeights(bcCounts);
        WeightedInformationTheory.normaliseWeights(aCounts);
        WeightedInformationTheory.normaliseWeights(bCounts);
        WeightedInformationTheory.normaliseWeights(cCounts);

        return new WeightedTripleDistribution<>(total, jointCount, abCounts, acCounts, bcCounts, aCounts, bCounts, cCounts);
    }

    /**
     * Adds {@code weight} and {@code occurrences} to the tuple stored under
     * {@code key}, creating a fresh tuple first if the key is absent.
     */
    private static <K> void accumulate(Map<K, WeightCountTuple> counts, K key, double weight, long occurrences) {
        WeightCountTuple tuple = counts.computeIfAbsent(key, k -> new WeightCountTuple());
        tuple.weight += weight;
        tuple.count += occurrences;
    }
}
