package org.tribuo.classification.evaluation;

import java.util.Iterator;
import java.util.logging.Logger;
import org.tribuo.classification.Classifiable;
import org.tribuo.evaluation.metrics.EvaluationMetric;
import org.tribuo.evaluation.metrics.MetricTarget;

/**
 * Static functions for computing classification metrics from a {@link ConfusionMatrix}.
 * <p>
 * Metrics can be evaluated for a single label or as a micro/macro average; the choice is
 * carried by a {@link MetricTarget}. Ill-defined cases (empty domain, no predictions, no
 * support for a label) are logged as warnings and reported as {@link Double#NaN} where
 * noted; precision/recall/F-score instead return {@code 0.0} on a zero denominator.
 */
public final class ConfusionMetrics {
    private static final Logger logger = Logger.getLogger(ConfusionMetrics.class.getName());

    /**
     * A metric expressed as a function of the four confusion counts
     * (true positives, false positives, true negatives, false negatives).
     * <p>
     * The default methods lift that function to a single label, a micro average, or a
     * macro average over a {@link ConfusionMatrix}.
     *
     * @param <T> the classification output type.
     */
    @FunctionalInterface
    public interface ConfusionFunction<T extends Classifiable<T>> {
        /**
         * Computes the metric from raw confusion counts.
         *
         * @param tp the true positive count.
         * @param fp the false positive count.
         * @param tn the true negative count.
         * @param fn the false negative count.
         * @return the metric value.
         */
        double compute(double tp, double fp, double tn, double fn);

        /**
         * Computes the metric for whichever target the {@link MetricTarget} carries:
         * a single label if present, otherwise an average.
         *
         * @param tgt the metric target.
         * @param cm the confusion matrix.
         * @return the metric value.
         * @throws IllegalStateException if the target has neither a label nor an average.
         */
        default double compute(MetricTarget<T> tgt, ConfusionMatrix<T> cm) {
            if (tgt.getOutputTarget().isPresent()) {
                return compute(tgt.getOutputTarget().get(), cm);
            }
            if (tgt.getAverageTarget().isPresent()) {
                return compute(tgt.getAverageTarget().get(), cm);
            }
            throw new IllegalStateException("MetricTarget with no actual target");
        }

        /**
         * Computes the metric for a single label using that label's confusion counts.
         *
         * @param label the label.
         * @param cm the confusion matrix.
         * @return the metric value for {@code label}.
         */
        default double compute(T label, ConfusionMatrix<T> cm) {
            return compute(cm.tp(label), cm.fp(label), cm.tn(label), cm.fn(label));
        }

        /**
         * Computes the metric averaged over the matrix.
         * <p>
         * MACRO averages the per-label metric over every label in the domain and returns
         * NaN (with a warning) when the domain is empty. MICRO applies the metric once to
         * the aggregated counts and returns NaN (with a warning) when total support is zero.
         *
         * @param average the averaging scheme.
         * @param cm the confusion matrix.
         * @return the averaged metric value.
         * @throws IllegalArgumentException for an averaging scheme other than MACRO/MICRO.
         */
        default double compute(EvaluationMetric.Average average, ConfusionMatrix<T> cm) {
            switch (average) {
                case MACRO: {
                    if (cm.getDomain().size() == 0) {
                        logger.warning("Empty domain: macro-average ill-defined.");
                        return Double.NaN;
                    }
                    double total = 0.0d;
                    for (T label : cm.getDomain().getDomain()) {
                        total += compute(label, cm);
                    }
                    return total / cm.getDomain().size();
                }
                case MICRO: {
                    if (cm.support() == 0.0d) {
                        logger.warning("No predictions: micro-average ill-defined.");
                        return Double.NaN;
                    }
                    return compute(cm.tp(), cm.fp(), cm.tn(), cm.fn());
                }
                default:
                    throw new IllegalArgumentException("Unsupported average type: " + average.name());
            }
        }
    }

    /** Utility class; not instantiable. */
    private ConfusionMetrics() {
    }

    /**
     * Computes accuracy for the label or average carried by the target.
     *
     * @param tgt the metric target.
     * @param cm the confusion matrix.
     * @param <T> the output type.
     * @return the accuracy.
     * @throws IllegalStateException if the target has neither a label nor an average.
     */
    public static <T extends Classifiable<T>> double accuracy(MetricTarget<T> tgt, ConfusionMatrix<T> cm) {
        if (tgt.getOutputTarget().isPresent()) {
            return accuracy(tgt.getOutputTarget().get(), cm);
        }
        if (tgt.getAverageTarget().isPresent()) {
            return accuracy(tgt.getAverageTarget().get(), cm);
        }
        throw new IllegalStateException("MetricTarget with no actual target");
    }

    /**
     * Computes the per-label accuracy {@code tp(label) / support(label)}.
     *
     * @param label the label.
     * @param cm the confusion matrix.
     * @param <T> the output type.
     * @return the accuracy, or NaN (with a warning) when the label has zero support.
     */
    public static <T extends Classifiable<T>> double accuracy(T label, ConfusionMatrix<T> cm) {
        double support = cm.support(label);
        if (support == 0.0d) {
            logger.warning("No predictions for " + label + ": accuracy ill-defined");
            return Double.NaN;
        }
        return cm.tp(label) / support;
    }

    /**
     * Computes averaged accuracy.
     * <p>
     * MICRO returns {@code tp() / support()}. Any other value is treated as a macro
     * average of the per-label accuracies. A label with zero support contributes NaN
     * to that average.
     *
     * @param average the averaging scheme.
     * @param cm the confusion matrix.
     * @param <T> the output type.
     * @return the averaged accuracy, or NaN (with a warning) when ill-defined.
     */
    public static <T extends Classifiable<T>> double accuracy(EvaluationMetric.Average average, ConfusionMatrix<T> cm) {
        if (average.equals(EvaluationMetric.Average.MICRO)) {
            if (cm.support() == 0.0d) {
                logger.warning("No predictions: accuracy ill-defined");
                return Double.NaN;
            }
            return cm.tp() / cm.support();
        }
        if (cm.getDomain().size() == 0) {
            logger.warning("Empty domain: accuracy ill-defined");
            return Double.NaN;
        }
        double total = 0.0d;
        for (T label : cm.getDomain().getDomain()) {
            total += accuracy(label, cm);
        }
        return total / cm.getDomain().size();
    }

    /**
     * Computes the balanced error rate, {@code 1 - mean per-label recall}.
     *
     * @param cm the confusion matrix.
     * @param <T> the output type.
     * @return the balanced error rate, or NaN (with a warning) for an empty domain.
     */
    public static <T extends Classifiable<T>> double balancedErrorRate(ConfusionMatrix<T> cm) {
        if (cm.getDomain().size() == 0) {
            logger.warning("Empty domain: balanced error rate ill-defined");
            return Double.NaN;
        }
        double sumRecall = 0.0d;
        for (T label : cm.getDomain().getDomain()) {
            // Same per-label value as recall(new MetricTarget<>(label), cm), without the allocation.
            sumRecall += recall(cm.tp(label), cm.fp(label), cm.tn(label), cm.fn(label));
        }
        return 1.0d - (sumRecall / cm.getDomain().size());
    }

    /** Evaluates {@code fn} against the target, dispatching to label or average handling. */
    private static <T extends Classifiable<T>> double compute(ConfusionFunction<T> fn, MetricTarget<T> tgt, ConfusionMatrix<T> cm) {
        return fn.compute(tgt, cm);
    }

    /** Returns the true positive count for the target. */
    public static <T extends Classifiable<T>> double tp(MetricTarget<T> tgt, ConfusionMatrix<T> cm) {
        return compute(ConfusionMetrics::tp, tgt, cm);
    }

    /** Returns the false positive count for the target. */
    public static <T extends Classifiable<T>> double fp(MetricTarget<T> tgt, ConfusionMatrix<T> cm) {
        return compute(ConfusionMetrics::fp, tgt, cm);
    }

    /** Returns the true negative count for the target. */
    public static <T extends Classifiable<T>> double tn(MetricTarget<T> tgt, ConfusionMatrix<T> cm) {
        return compute(ConfusionMetrics::tn, tgt, cm);
    }

    /** Returns the false negative count for the target. */
    public static <T extends Classifiable<T>> double fn(MetricTarget<T> tgt, ConfusionMatrix<T> cm) {
        return compute(ConfusionMetrics::fn, tgt, cm);
    }

    /** Projects out the true positive count. */
    private static double tp(double tp, double fp, double tn, double fn) {
        return tp;
    }

    /** Projects out the false positive count. */
    private static double fp(double tp, double fp, double tn, double fn) {
        return fp;
    }

    /** Projects out the true negative count. */
    private static double tn(double tp, double fp, double tn, double fn) {
        return tn;
    }

    /** Projects out the false negative count. */
    private static double fn(double tp, double fp, double tn, double fn) {
        return fn;
    }

    /** Computes precision for the label or average carried by the target. */
    public static <T extends Classifiable<T>> double precision(MetricTarget<T> tgt, ConfusionMatrix<T> cm) {
        return compute(ConfusionMetrics::precision, tgt, cm);
    }

    /**
     * Computes precision, {@code tp / (tp + fp)}.
     *
     * @return the precision, or {@code 0.0} when {@code tp + fp == 0}.
     */
    public static double precision(double tp, double fp, double tn, double fn) {
        double predictedPositive = tp + fp;
        if (predictedPositive == 0.0d) {
            return 0.0d;
        }
        return tp / predictedPositive;
    }

    /** Computes recall for the label or average carried by the target. */
    public static <T extends Classifiable<T>> double recall(MetricTarget<T> tgt, ConfusionMatrix<T> cm) {
        return compute(ConfusionMetrics::recall, tgt, cm);
    }

    /**
     * Computes recall, {@code tp / (tp + fn)}.
     *
     * @return the recall, or {@code 0.0} when {@code tp + fn == 0}.
     */
    public static double recall(double tp, double fp, double tn, double fn) {
        double actualPositive = tp + fn;
        if (actualPositive == 0.0d) {
            return 0.0d;
        }
        return tp / actualPositive;
    }

    /** Computes F1 for the label or average carried by the target. */
    public static <T extends Classifiable<T>> double f1(MetricTarget<T> tgt, ConfusionMatrix<T> cm) {
        return compute(ConfusionMetrics::f1, tgt, cm);
    }

    /** Computes F1, i.e. the F-score with {@code beta = 1}. */
    public static double f1(double tp, double fp, double tn, double fn) {
        return fscore(1.0d, tp, fp, tn, fn);
    }

    /**
     * Computes the F-beta score,
     * {@code (1 + beta^2) * P * R / (beta^2 * P + R)}, where P is precision and R is recall.
     *
     * @param beta the weighting of recall relative to precision.
     * @return the F-beta score, or {@code 0.0} when the denominator is zero.
     */
    public static double fscore(double beta, double tp, double fp, double tn, double fn) {
        double betaSquared = beta * beta;
        double precision = precision(tp, fp, tn, fn);
        double recall = recall(tp, fp, tn, fn);
        double denominator = (betaSquared * precision) + recall;
        if (denominator == 0.0d) {
            return 0.0d;
        }
        return ((1.0d + betaSquared) * precision * recall) / denominator;
    }

    /**
     * Computes the F-beta score for the label or average carried by the target.
     *
     * @param tgt the metric target.
     * @param cm the confusion matrix.
     * @param beta the weighting of recall relative to precision.
     * @param <T> the output type.
     * @return the F-beta score.
     */
    public static <T extends Classifiable<T>> double fscore(MetricTarget<T> tgt, ConfusionMatrix<T> cm, double beta) {
        return compute((tp, fp, tn, fn) -> fscore(beta, tp, fp, tn, fn), tgt, cm);
    }
}
