package org.tribuo.util.infotheory;

import com.oracle.labs.mlrg.olcut.util.MutableLong;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.DoubleStream;
import java.util.stream.Stream;
import org.tribuo.util.infotheory.impl.CachedPair;
import org.tribuo.util.infotheory.impl.CachedTriple;
import org.tribuo.util.infotheory.impl.PairDistribution;
import org.tribuo.util.infotheory.impl.RowList;
import org.tribuo.util.infotheory.impl.TripleDistribution;

/* loaded from: input_file:org/tribuo/util/infotheory/InformationTheory.class */
public final class InformationTheory {
    /**
     * Samples-per-state ratio constant. The estimators in this class log an INFO message
     * when their observed samples/state ratio is below the same value (they compare against
     * the literal 5.0) - presumably that literal is this constant inlined; confirm.
     */
    public static final double SAMPLES_RATIO = 5.0d;
    /**
     * Default map capacity constant. It has the same value as the initial capacity passed to
     * the HashMap built in calculateCountDist.
     */
    public static final int DEFAULT_MAP_SIZE = 20;
    /** Logger used for the low sample ratio and NaN diagnostics emitted by this class. */
    private static final Logger logger = Logger.getLogger(InformationTheory.class.getName());
    /** Natural logarithm of 2. */
    public static final double LOG_2 = Math.log(2.0d);
    /** Natural logarithm of the double literal for e, i.e. approximately 1. */
    public static final double LOG_E = Math.log(2.718281828459045d);
    /**
     * Divisor applied to natural-log sums by the estimators in this class; initialised to
     * {@link #LOG_2}. The field is public and not final, so reassigning it changes the value
     * every subsequent calculation divides by.
     */
    public static double LOG_BASE = LOG_2;

    /* loaded from: input_file:org/tribuo/util/infotheory/InformationTheory$GTestStatistics.class */
    /**
     * Immutable value holder for the three numbers produced by a G-test:
     * the statistic itself, the number of states and the associated probability.
     */
    public static final class GTestStatistics {
        /** The G statistic. */
        public final double gStatistic;
        /** The number of states reported alongside the statistic. */
        public final int numStates;
        /** The probability computed for the statistic. */
        public final double probability;

        /**
         * Stores the supplied values without modification.
         *
         * @param d  the G statistic.
         * @param i  the number of states.
         * @param d2 the probability.
         */
        public GTestStatistics(double d, int i, double d2) {
            this.gStatistic = d;
            this.numStates = i;
            this.probability = d2;
        }

        /**
         * Renders the fields in the order statistic, probability, numStates.
         *
         * @return a single-line description of this result.
         */
        @Override
        public String toString() {
            StringBuilder text = new StringBuilder("GTest(statistic=");
            text.append(gStatistic);
            text.append(",probability=").append(probability);
            text.append(",numStates=").append(numStates);
            text.append(")");
            return text.toString();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/tribuo/util/infotheory/InformationTheory$ScoreStateCountTuple.class */
    /**
     * Pairs an information theoretic score with the number of states it was computed over.
     */
    public static class ScoreStateCountTuple {
        /** The computed score. */
        public final double score;
        /** The number of states that contributed to the score. */
        public final int stateCount;

        /**
         * Stores the supplied values without modification.
         *
         * @param d the score.
         * @param i the state count.
         */
        ScoreStateCountTuple(double d, int i) {
            this.score = d;
            this.stateCount = i;
        }

        /**
         * Renders the fields in the order score, stateCount.
         *
         * @return a single-line description of this tuple.
         */
        @Override
        public String toString() {
            StringBuilder text = new StringBuilder("ScoreStateCount(score=");
            text.append(score);
            text.append(",stateCount=").append(stateCount);
            text.append(")");
            return text.toString();
        }
    }

    /**
     * Private constructor: every member visible in this class is static, so instances are
     * never created from outside it.
     */
    private InformationTheory() {
    }

    /**
     * Calculates the mutual information between two sets of variables.
     * <p>
     * Each set of lists is wrapped in a {@link RowList} (presumably so that the lists are
     * read as a single list of joint rows - confirm against RowList) and the result is
     * passed to the list based {@code mi} overload.
     *
     * @param set  the first set of variables.
     * @param set2 the second set of variables.
     * @param <T1> the element type of the first set's lists.
     * @param <T2> the element type of the second set's lists.
     * @return the mutual information, divided by {@code LOG_BASE}.
     */
    public static <T1, T2> double mi(Set<List<T1>> set, Set<List<T2>> set2) {
        // Typed locals instead of raw RowList construction, so no unchecked conversions are needed.
        RowList<T1> firstRows = new RowList<>(set);
        RowList<T2> secondRows = new RowList<>(set2);
        return mi(firstRows, secondRows);
    }

    /**
     * Calculates the conditional mutual information between two variables given a set of
     * conditioning variables.
     * <p>
     * When the condition set is {@code null} or empty there is nothing to condition on, so
     * the unconditional mutual information is returned. This matches how {@code gTest}
     * treats a missing condition set.
     *
     * @param list  the first variable.
     * @param list2 the second variable.
     * @param set   the conditioning variables; may be {@code null} or empty.
     * @param <T1>  the element type of the first variable.
     * @param <T2>  the element type of the second variable.
     * @param <T3>  the element type of the conditioning lists.
     * @return the (conditional) mutual information, divided by {@code LOG_BASE}.
     */
    public static <T1, T2, T3> double cmi(List<T1> list, List<T2> list2, Set<List<T3>> set) {
        if (set == null || set.isEmpty()) {
            return mi(list, list2);
        }
        // Typed local instead of a raw RowList, so no unchecked conversion is needed.
        RowList<T3> conditionRows = new RowList<>(set);
        return conditionalMI(list, list2, conditionRows);
    }

    /**
     * Runs a G-test between two variables, optionally conditioned on a set of variables.
     * <p>
     * The statistic is {@code 2 * N * MI}, where {@code N} is {@code list2.size()} and
     * {@code MI} is the mutual information (conditional when a non-empty condition set is
     * supplied). The state count of that estimate is passed on as the degrees of freedom of
     * the chi-squared probability.
     *
     * @param list  the first variable.
     * @param list2 the second variable.
     * @param set   the conditioning variables; {@code null} or empty means unconditional.
     * @param <T1>  the element type of the first variable.
     * @param <T2>  the element type of the second variable.
     * @param <T3>  the element type of the conditioning lists.
     * @return the statistic, the state count and the chi-squared probability.
     * @throws IllegalArgumentException if the supplied lists differ in length.
     */
    public static <T1, T2, T3> GTestStatistics gTest(List<T1> list, List<T2> list2, Set<List<T3>> set) {
        ScoreStateCountTuple miResult;
        if (set == null || set.isEmpty()) {
            miResult = innerMI(list, list2);
        } else {
            RowList<T3> conditionRows = new RowList<>(set);
            miResult = innerConditionalMI(list, list2, conditionRows);
        }
        // Multiply in double from the start: 2 * size() in int arithmetic overflows for very long lists.
        double gStatistic = 2.0 * list2.size() * miResult.score;
        double probability = computeChiSquaredProbability(miResult.stateCount, gStatistic);
        return new GTestStatistics(gStatistic, miResult.stateCount, probability);
    }

    /**
     * Computes the cumulative probability of a value under a chi-squared distribution,
     * using the regularized lower incomplete gamma function {@code P(k / 2, x / 2)}.
     *
     * @param i the degrees of freedom {@code k}.
     * @param d the observed value {@code x}.
     * @return 0 when {@code d} is not positive, otherwise the value of the regularized gamma
     *         function.
     */
    private static double computeChiSquaredProbability(int i, double d) {
        if (d <= 0.0d) {
            return 0.0d;
        }
        // The shape must be k / 2 as a real number. Integer division truncated odd degrees of
        // freedom and turned k = 1 into a shape of zero.
        double shape = i / 2.0d;
        double scaledValue = d / 2.0d;
        return Gamma.regularizedGammaP(shape, scaledValue, 1.0E-14d, Integer.MAX_VALUE);
    }

    /**
     * Calculates the mutual information between the joint variable (first, second) and the
     * target, after checking that all three lists have the same length.
     *
     * @param list  the first variable.
     * @param list2 the second variable.
     * @param list3 the target variable.
     * @param <T1>  the element type of the first variable.
     * @param <T2>  the element type of the second variable.
     * @param <T3>  the element type of the target variable.
     * @return the joint mutual information, divided by {@code LOG_BASE}.
     * @throws IllegalArgumentException if the lists differ in length.
     */
    public static <T1, T2, T3> double jointMI(List<T1> list, List<T2> list2, List<T3> list3) {
        if (list.size() != list2.size() || list.size() != list3.size()) {
            throw new IllegalArgumentException("Joint Mutual Information requires three vectors the same length. first.size() = " + list.size() + ", second.size() = " + list2.size() + ", target.size() = " + list3.size());
        }
        TripleDistribution<T1, T2, T3> distribution = TripleDistribution.constructFromLists(list, list2, list3);
        return jointMI(distribution);
    }

    /**
     * Calculates the mutual information between the joint variable (A, B) and C from a
     * pre-computed triple distribution.
     * <p>
     * Each joint state contributes {@code p(a,b,c) * ln(N * n(a,b,c) / (n(a,b) * n(c)))};
     * the sum is divided by {@code LOG_BASE}. An INFO message is logged when the number of
     * samples per joint state is below 5.
     *
     * @param tripleDistribution the counts to read.
     * @param <T1>               the type of the A variable.
     * @param <T2>               the type of the B variable.
     * @param <T3>               the type of the C variable.
     * @return the joint mutual information.
     */
    public static <T1, T2, T3> double jointMI(TripleDistribution<T1, T2, T3> tripleDistribution) {
        Map<CachedTriple<T1, T2, T3>, MutableLong> jointCounts = tripleDistribution.getJointCount();
        Map<CachedPair<T1, T2>, MutableLong> abCounts = tripleDistribution.getABCount();
        Map<T3, MutableLong> cCounts = tripleDistribution.getCCount();
        double total = tripleDistribution.count;

        double sum = 0.0d;
        for (Map.Entry<CachedTriple<T1, T2, T3>, MutableLong> state : jointCounts.entrySet()) {
            CachedTriple<T1, T2, T3> key = state.getKey();
            double jointCount = state.getValue().doubleValue();
            double probability = jointCount / total;
            double numerator = total * jointCount;
            double denominator = abCounts.get(key.getAB()).doubleValue() * cCounts.get(key.getC()).doubleValue();
            sum += probability * Math.log(numerator / denominator);
        }
        double score = sum / LOG_BASE;

        int stateCount = jointCounts.size();
        double samplesPerState = total / stateCount;
        if (samplesPerState < 5.0d) {
            Object[] logArgs = {score, samplesPerState, total, stateCount};
            logger.log(Level.INFO, "Joint MI estimate of {0} had samples/state ratio of {1}, with {2} observations and {3} states", logArgs);
        }
        return score;
    }

    /**
     * Calculates conditional mutual information from a pre-computed triple distribution and
     * reports the number of joint states alongside the score.
     * <p>
     * When {@code z} is false each joint state contributes
     * {@code p(a,b,c) * ln(n(c) * n(a,b,c) / (n(a,c) * n(b,c)))}, i.e. the conditioning
     * variable is C. When {@code z} is true the contribution is
     * {@code p(a,b,c) * ln(n(b) * n(a,b,c) / (n(a,b) * n(b,c)))}, i.e. the conditioning
     * variable is B. The sum is divided by {@code LOG_BASE}, and an INFO message is logged
     * when the number of samples per joint state is below 5.
     *
     * @param tripleDistribution the counts to read.
     * @param z                  true to condition on B, false to condition on C.
     * @param <T1>               the type of the A variable.
     * @param <T2>               the type of the B variable.
     * @param <T3>               the type of the C variable.
     * @return the score and the number of joint states.
     */
    private static <T1, T2, T3> ScoreStateCountTuple innerConditionalMI(TripleDistribution<T1, T2, T3> tripleDistribution, boolean z) {
        Map<CachedTriple<T1, T2, T3>, MutableLong> jointCounts = tripleDistribution.getJointCount();
        Map<CachedPair<T1, T2>, MutableLong> abCounts = tripleDistribution.getABCount();
        Map<CachedPair<T1, T3>, MutableLong> acCounts = tripleDistribution.getACCount();
        Map<CachedPair<T2, T3>, MutableLong> bcCounts = tripleDistribution.getBCCount();
        Map<T2, MutableLong> bCounts = tripleDistribution.getBCount();
        Map<T3, MutableLong> cCounts = tripleDistribution.getCCount();
        double total = tripleDistribution.count;

        double sum = 0.0d;
        for (Map.Entry<CachedTriple<T1, T2, T3>, MutableLong> state : jointCounts.entrySet()) {
            CachedTriple<T1, T2, T3> key = state.getKey();
            double jointCount = state.getValue().doubleValue();
            double numerator;
            double denominator;
            if (z) {
                // Conditioning on B.
                numerator = bCounts.get(key.getB()).doubleValue() * jointCount;
                denominator = abCounts.get(key.getAB()).doubleValue() * bcCounts.get(key.getBC()).doubleValue();
            } else {
                // Conditioning on C.
                numerator = cCounts.get(key.getC()).doubleValue() * jointCount;
                denominator = acCounts.get(key.getAC()).doubleValue() * bcCounts.get(key.getBC()).doubleValue();
            }
            sum += (jointCount / total) * Math.log(numerator / denominator);
        }
        double score = sum / LOG_BASE;

        int stateCount = jointCounts.size();
        double samplesPerState = total / stateCount;
        if (samplesPerState < 5.0d) {
            Object[] logArgs = {score, samplesPerState};
            logger.log(Level.INFO, "Conditional MI estimate of {0} had samples/state ratio of {1}", logArgs);
        }
        return new ScoreStateCountTuple(score, stateCount);
    }

    /**
     * Checks that the three lists have the same length, builds their triple distribution
     * and computes the conditional mutual information conditioned on the third list.
     *
     * @param list  the first variable.
     * @param list2 the second variable.
     * @param list3 the conditioning variable.
     * @param <T1>  the element type of the first variable.
     * @param <T2>  the element type of the second variable.
     * @param <T3>  the element type of the conditioning variable.
     * @return the score and the number of joint states.
     * @throws IllegalArgumentException if the lists differ in length.
     */
    private static <T1, T2, T3> ScoreStateCountTuple innerConditionalMI(List<T1> list, List<T2> list2, List<T3> list3) {
        if (list.size() != list2.size() || list.size() != list3.size()) {
            throw new IllegalArgumentException("Conditional Mutual Information requires three vectors the same length. first.size() = " + list.size() + ", second.size() = " + list2.size() + ", condition.size() = " + list3.size());
        }
        TripleDistribution<T1, T2, T3> distribution = TripleDistribution.constructFromLists(list, list2, list3);
        return innerConditionalMI(distribution, false);
    }

    /**
     * Calculates the conditional mutual information between the first two lists,
     * conditioned on the third.
     *
     * @param list  the first variable.
     * @param list2 the second variable.
     * @param list3 the conditioning variable.
     * @param <T1>  the element type of the first variable.
     * @param <T2>  the element type of the second variable.
     * @param <T3>  the element type of the conditioning variable.
     * @return the conditional mutual information, divided by {@code LOG_BASE}.
     * @throws IllegalArgumentException if the lists differ in length.
     */
    public static <T1, T2, T3> double conditionalMI(List<T1> list, List<T2> list2, List<T3> list3) {
        ScoreStateCountTuple result = innerConditionalMI(list, list2, list3);
        return result.score;
    }

    /**
     * Calculates the conditional mutual information from a pre-computed triple
     * distribution, conditioning on its C variable.
     *
     * @param tripleDistribution the counts to read.
     * @param <T1>               the type of the A variable.
     * @param <T2>               the type of the B variable.
     * @param <T3>               the type of the C variable.
     * @return the conditional mutual information, divided by {@code LOG_BASE}.
     */
    public static <T1, T2, T3> double conditionalMI(TripleDistribution<T1, T2, T3> tripleDistribution) {
        ScoreStateCountTuple result = innerConditionalMI(tripleDistribution, false);
        return result.score;
    }

    /**
     * Calculates the conditional mutual information from a pre-computed triple
     * distribution, conditioning on its B variable rather than its C variable.
     *
     * @param tripleDistribution the counts to read.
     * @param <T1>               the type of the A variable.
     * @param <T2>               the type of the B variable.
     * @param <T3>               the type of the C variable.
     * @return the conditional mutual information, divided by {@code LOG_BASE}.
     */
    public static <T1, T2, T3> double conditionalMIFlipped(TripleDistribution<T1, T2, T3> tripleDistribution) {
        ScoreStateCountTuple result = innerConditionalMI(tripleDistribution, true);
        return result.score;
    }

    /**
     * Calculates the mutual information from a pre-computed pair distribution and reports
     * the number of joint states alongside the score.
     * <p>
     * Each joint state contributes {@code p(a,b) * ln(N * n(a,b) / (n(a) * n(b)))}; the sum
     * is divided by {@code LOG_BASE}. If the log ratio, the probability or the running sum
     * is NaN, the offending state is logged at WARNING and a SEVERE record carrying an
     * {@link IllegalStateException} is logged once the loop has finished; the NaN value is
     * still accumulated. An INFO message is logged when the number of samples per joint
     * state is below 5.
     *
     * @param pairDistribution the counts to read.
     * @param <T1>             the type of the first variable.
     * @param <T2>             the type of the second variable.
     * @return the score and the number of joint states.
     */
    private static <T1, T2> ScoreStateCountTuple innerMI(PairDistribution<T1, T2> pairDistribution) {
        Map<CachedPair<T1, T2>, MutableLong> jointCounts = pairDistribution.jointCounts;
        Map<T1, MutableLong> firstCounts = pairDistribution.firstCount;
        Map<T2, MutableLong> secondCounts = pairDistribution.secondCount;
        double total = pairDistribution.count;

        double sum = 0.0d;
        boolean nanSeen = false;
        for (Map.Entry<CachedPair<T1, T2>, MutableLong> state : jointCounts.entrySet()) {
            CachedPair<T1, T2> key = state.getKey();
            double jointCount = state.getValue().doubleValue();
            double prob = jointCount / total;
            double firstCount = firstCounts.get(key.getA()).doubleValue();
            double secondCount = secondCounts.get(key.getB()).doubleValue();
            double top = total * jointCount;
            double bottom = firstCount * secondCount;
            double ratio = top / bottom;
            double logRatio = Math.log(ratio);
            if (Double.isNaN(logRatio) || Double.isNaN(prob) || Double.isNaN(sum)) {
                logger.log(Level.WARNING, "State = " + state.getKey().toString());
                logger.log(Level.WARNING, "mi = " + sum + " prob = " + prob + " top = " + top + " bottom = " + bottom + " ratio = " + ratio + " logRatio = " + logRatio);
                nanSeen = true;
            }
            sum += prob * logRatio;
        }
        double score = sum / LOG_BASE;

        int stateCount = jointCounts.size();
        double samplesPerState = total / stateCount;
        if (samplesPerState < 5.0d) {
            Object[] logArgs = {score, samplesPerState};
            logger.log(Level.INFO, "MI estimate of {0} had samples/state ratio of {1}", logArgs);
        }
        if (nanSeen) {
            Throwable marker = new IllegalStateException("NaN found");
            logger.log(Level.SEVERE, "NanFound ", marker);
        }
        return new ScoreStateCountTuple(score, stateCount);
    }

    /**
     * Checks that the two lists have the same length, builds their pair distribution and
     * computes the mutual information.
     *
     * @param list  the first variable.
     * @param list2 the second variable.
     * @param <T1>  the element type of the first variable.
     * @param <T2>  the element type of the second variable.
     * @return the score and the number of joint states.
     * @throws IllegalArgumentException if the lists differ in length.
     */
    private static <T1, T2> ScoreStateCountTuple innerMI(List<T1> list, List<T2> list2) {
        if (list.size() != list2.size()) {
            throw new IllegalArgumentException("Mutual Information requires two vectors the same length. first.size() = " + list.size() + ", second.size() = " + list2.size());
        }
        PairDistribution<T1, T2> distribution = PairDistribution.constructFromLists(list, list2);
        return innerMI(distribution);
    }

    /**
     * Calculates the mutual information between two lists of equal length.
     *
     * @param list  the first variable.
     * @param list2 the second variable.
     * @param <T1>  the element type of the first variable.
     * @param <T2>  the element type of the second variable.
     * @return the mutual information, divided by {@code LOG_BASE}.
     * @throws IllegalArgumentException if the lists differ in length.
     */
    public static <T1, T2> double mi(List<T1> list, List<T2> list2) {
        ScoreStateCountTuple result = innerMI(list, list2);
        return result.score;
    }

    /**
     * Calculates the mutual information from a pre-computed pair distribution.
     *
     * @param pairDistribution the counts to read.
     * @param <T1>             the type of the first variable.
     * @param <T2>             the type of the second variable.
     * @return the mutual information, divided by {@code LOG_BASE}.
     */
    public static <T1, T2> double mi(PairDistribution<T1, T2> pairDistribution) {
        ScoreStateCountTuple result = innerMI(pairDistribution);
        return result.score;
    }

    /**
     * Calculates the Shannon entropy of the joint variable formed from two lists of equal
     * length.
     * <p>
     * Each joint state contributes {@code -p(a,b) * ln(p(a,b))}; the sum is divided by
     * {@code LOG_BASE}. An INFO message is logged when the number of samples per joint
     * state is below 5.
     *
     * @param list  the first variable.
     * @param list2 the second variable.
     * @param <T1>  the element type of the first variable.
     * @param <T2>  the element type of the second variable.
     * @return the joint entropy.
     * @throws IllegalArgumentException if the lists differ in length.
     */
    public static <T1, T2> double jointEntropy(List<T1> list, List<T2> list2) {
        if (list.size() != list2.size()) {
            throw new IllegalArgumentException("Joint Entropy requires two vectors the same length. first.size() = " + list.size() + ", second.size() = " + list2.size());
        }
        double total = list.size();
        PairDistribution<T1, T2> distribution = PairDistribution.constructFromLists(list, list2);
        // Keep a reference to the joint counts: the state count is needed again after the loop.
        Map<CachedPair<T1, T2>, MutableLong> jointCounts = distribution.jointCounts;

        double sum = 0.0d;
        for (MutableLong count : jointCounts.values()) {
            double prob = count.doubleValue() / total;
            sum -= prob * Math.log(prob);
        }
        double entropy = sum / LOG_BASE;

        double samplesPerState = total / jointCounts.size();
        if (samplesPerState < 5.0d) {
            logger.log(Level.INFO, "Joint Entropy estimate of {0} had samples/state ratio of {1}", new Object[]{entropy, samplesPerState});
        }
        return entropy;
    }

    /**
     * Calculates the conditional entropy of the first list given the second, for two lists
     * of equal length.
     * <p>
     * Each joint state contributes {@code -p(a,b) * ln(p(a,b) / p(b))}; the sum is divided
     * by {@code LOG_BASE}. An INFO message is logged when the number of samples per joint
     * state is below 5.
     *
     * @param list  the variable whose entropy is measured.
     * @param list2 the conditioning variable.
     * @param <T1>  the element type of the measured variable.
     * @param <T2>  the element type of the conditioning variable.
     * @return the conditional entropy.
     * @throws IllegalArgumentException if the lists differ in length.
     */
    public static <T1, T2> double conditionalEntropy(List<T1> list, List<T2> list2) {
        if (list.size() != list2.size()) {
            throw new IllegalArgumentException("Conditional Entropy requires two vectors the same length. vector.size() = " + list.size() + ", condition.size() = " + list2.size());
        }
        double total = list.size();
        // Parameterised rather than raw, so the map reads below need no unchecked conversion.
        PairDistribution<T1, T2> distribution = PairDistribution.constructFromLists(list, list2);
        Map<CachedPair<T1, T2>, MutableLong> jointCounts = distribution.jointCounts;
        Map<T2, MutableLong> conditionCounts = distribution.secondCount;

        double sum = 0.0d;
        for (Map.Entry<CachedPair<T1, T2>, MutableLong> state : jointCounts.entrySet()) {
            double jointProb = state.getValue().doubleValue() / total;
            double conditionProb = conditionCounts.get(state.getKey().getB()).doubleValue() / total;
            sum -= jointProb * Math.log(jointProb / conditionProb);
        }
        double entropy = sum / LOG_BASE;

        double samplesPerState = total / jointCounts.size();
        if (samplesPerState < 5.0d) {
            logger.log(Level.INFO, "Conditional Entropy estimate of {0} had samples/state ratio of {1}", new Object[]{entropy, samplesPerState});
        }
        return entropy;
    }

    /**
     * Calculates the Shannon entropy of a list of observations.
     * <p>
     * Each distinct value contributes {@code -p * ln(p)}, where {@code p} is its relative
     * frequency; the sum is divided by {@code LOG_BASE}. An INFO message is logged when the
     * number of samples per distinct value is below 5.
     *
     * @param list the observations.
     * @param <T>  the element type.
     * @return the entropy.
     */
    public static <T> double entropy(List<T> list) {
        double total = list.size();
        // Keep a reference to the counts: the state count is needed again after the loop.
        Map<T, Long> counts = calculateCountDist(list);

        double sum = 0.0d;
        for (Long count : counts.values()) {
            double prob = count / total;
            sum -= prob * Math.log(prob);
        }
        double entropy = sum / LOG_BASE;

        double samplesPerState = total / counts.size();
        if (samplesPerState < 5.0d) {
            logger.log(Level.INFO, "Entropy estimate of {0} had samples/state ratio of {1}", new Object[]{entropy, samplesPerState});
        }
        return entropy;
    }

    /**
     * Counts how many times each distinct element occurs in the list.
     *
     * @param list the observations.
     * @param <T>  the element type.
     * @return a new mutable map from each distinct element to its occurrence count.
     */
    public static <T> Map<T, Long> calculateCountDist(List<T> list) {
        // 20 is the initial capacity; it has the same value as DEFAULT_MAP_SIZE.
        Map<T, Long> counts = new HashMap<>(20);
        for (T element : list) {
            // merge inserts 1 for a new key and adds 1 to an existing count.
            counts.merge(element, 1L, Long::sum);
        }
        return counts;
    }

    /**
     * Sums {@code -p * ln(p) / LOG_BASE} over every element of the stream, starting from
     * zero.
     *
     * @param stream the values to fold; each is used as {@code p} in the formula.
     * @return the accumulated sum.
     */
    public static double calculateEntropy(Stream<Double> stream) {
        Stream<Double> terms = stream.map(p -> ((-p) * Math.log(p)) / LOG_BASE);
        return terms.reduce(0.0d, Double::sum);
    }

    /**
     * Sums {@code -p * ln(p) / LOG_BASE} over every element of the stream.
     *
     * @param doubleStream the values to fold; each is used as {@code p} in the formula.
     * @return the accumulated sum.
     */
    public static double calculateEntropy(DoubleStream doubleStream) {
        DoubleStream terms = doubleStream.map(p -> ((-p) * Math.log(p)) / LOG_BASE);
        return terms.sum();
    }

    /**
     * Calculates the expected mutual information between two lists, summing over every
     * pair of marginal states and every joint count those marginals allow.
     * <p>
     * For marginal counts {@code a} and {@code b} out of {@code N} observations, the joint
     * count {@code nij} ranges from {@code max(1, a + b - N)} to {@code min(a, b)}. Each
     * term is {@code (nij / N) * ln(N * nij / (a * b))} multiplied by the exponential of a
     * sum of log-gamma terms. The result uses natural logarithms; {@code LOG_BASE} is not
     * applied here.
     *
     * @param list  the first variable.
     * @param list2 the second variable.
     * @param <T>   the element type shared by both lists.
     * @return the expected mutual information.
     */
    public static <T> double expectedMI(List<T> list, List<T> list2) {
        PairDistribution<T, T> distribution = PairDistribution.constructFromLists(list, list2);
        Map<T, MutableLong> firstCounts = distribution.firstCount;
        Map<T, MutableLong> secondCounts = distribution.secondCount;
        long total = distribution.count;
        double logGammaTotal = Gamma.logGamma(total + 1);

        double sum = 0.0d;
        for (Map.Entry<T, MutableLong> firstEntry : firstCounts.entrySet()) {
            long firstCount = firstEntry.getValue().longValue();
            double logGammaFirst = Gamma.logGamma(firstCount + 1);
            double logGammaNotFirst = Gamma.logGamma((total - firstCount) + 1);
            for (Map.Entry<T, MutableLong> secondEntry : secondCounts.entrySet()) {
                long secondCount = secondEntry.getValue().longValue();
                long maxJoint = Math.min(firstCount, secondCount);
                long threshold = (firstCount + secondCount) - total;
                long minJoint = threshold > 1 ? threshold : 1L;

                // The terms that do not depend on nij, combined in the same left-to-right order
                // as the full expression.
                double invariant = (((logGammaFirst + Gamma.logGamma(secondCount + 1)) + logGammaNotFirst)
                        + Gamma.logGamma((total - secondCount) + 1)) - logGammaTotal;

                // A bounded loop; the previous while (true) form had no exit.
                for (long nij = minJoint; nij <= maxJoint; nij++) {
                    // Cast to double before dividing: long division truncated these ratios.
                    double prob = ((double) nij) / total;
                    double logRatio = Math.log(((double) total * nij) / ((double) firstCount * secondCount));
                    double logWeight = (((invariant - Gamma.logGamma(nij + 1))
                            - Gamma.logGamma((firstCount - nij) + 1))
                            - Gamma.logGamma((secondCount - nij) + 1))
                            - Gamma.logGamma((((total - firstCount) - secondCount) + nij) + 1);
                    sum += prob * logRatio * Math.exp(logWeight);
                }
            }
        }
        return sum;
    }
}
