package org.tribuo.anomaly.evaluation;

import java.util.List;
import java.util.function.ToDoubleBiFunction;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.anomaly.Event;
import org.tribuo.evaluation.metrics.EvaluationMetric;
import org.tribuo.evaluation.metrics.MetricContext;
import org.tribuo.evaluation.metrics.MetricTarget;

/* loaded from: input_file:org/tribuo/anomaly/evaluation/AnomalyMetric.class */
/**
 * An evaluation metric for anomaly detection, defined by a target, a name and
 * a function which computes the metric value from a {@link Context}.
 */
public class AnomalyMetric implements EvaluationMetric<Event, Context> {
    private final MetricTarget<Event> target;
    private final String name;
    private final ToDoubleBiFunction<MetricTarget<Event>, Context> impl;

    /**
     * Metric context holding the four confusion-matrix counts tabulated from
     * a list of predictions at construction time.
     * <p>
     * Decompiler note: the accessors below were package-private before
     * decompilation; they are left public here so existing callers compile.
     */
    public static final class Context extends MetricContext<Event> {
        private final long truePositive;
        private final long falsePositive;
        private final long trueNegative;
        private final long falseNegative;

        /**
         * Builds the context and eagerly counts the prediction outcomes.
         *
         * @param model the model handed to the superclass constructor
         * @param predictions the predictions to tabulate
         * @throws IllegalArgumentException if a ground truth type is neither
         *         ANOMALOUS nor EXPECTED
         */
        Context(Model<Event> model, List<Prediction<Event>> predictions) {
            super(model, predictions);
            PredictionStatistics stats = tabulate(predictions);
            truePositive = stats.truePositive;
            falsePositive = stats.falsePositive;
            trueNegative = stats.trueNegative;
            falseNegative = stats.falseNegative;
        }

        /** @return count of ANOMALOUS ground truths predicted as ANOMALOUS */
        public long getTruePositive() {
            return truePositive;
        }

        /** @return count of EXPECTED ground truths predicted as ANOMALOUS */
        public long getFalsePositive() {
            return falsePositive;
        }

        /** @return count of EXPECTED ground truths predicted as EXPECTED */
        public long getTrueNegative() {
            return trueNegative;
        }

        /** @return count of ANOMALOUS ground truths predicted as EXPECTED */
        public long getFalseNegative() {
            return falseNegative;
        }

        /**
         * Counts true/false positives and negatives over the predictions,
         * treating ANOMALOUS as the positive class.
         * <p>
         * A prediction whose predicted type is neither ANOMALOUS nor EXPECTED
         * is not added to any of the four counts.
         *
         * @param predictions the predictions to count
         * @return the four counts
         * @throws IllegalArgumentException if a ground truth type is neither
         *         ANOMALOUS nor EXPECTED
         */
        private static PredictionStatistics tabulate(List<Prediction<Event>> predictions) {
            long truePositives = 0;
            long falsePositives = 0;
            long trueNegatives = 0;
            long falseNegatives = 0;

            for (Prediction<Event> prediction : predictions) {
                // Both labels are fetched up front, before the ground truth is validated.
                Event.EventType actual = prediction.getExample().getOutput().getType();
                Event.EventType predicted = prediction.getOutput().getType();

                boolean actualAnomalous = actual == Event.EventType.ANOMALOUS;
                if (!actualAnomalous && actual != Event.EventType.EXPECTED) {
                    throw new IllegalArgumentException("Evaluation data contained EventType.UNKNOWN as the ground truth output.");
                }

                if (predicted == Event.EventType.ANOMALOUS) {
                    if (actualAnomalous) {
                        truePositives++;
                    } else {
                        falsePositives++;
                    }
                } else if (predicted == Event.EventType.EXPECTED) {
                    if (actualAnomalous) {
                        falseNegatives++;
                    } else {
                        trueNegatives++;
                    }
                }
                // Any other predicted type falls through without being counted.
            }

            return new PredictionStatistics(truePositives, falsePositives, trueNegatives, falseNegatives);
        }
    }

    /**
     * Immutable holder for the four counts produced while tabulating
     * predictions.
     * <p>
     * Decompiler note: this class was private before decompilation.
     */
    public static final class PredictionStatistics {
        private final long truePositive;
        private final long falsePositive;
        private final long trueNegative;
        private final long falseNegative;

        /**
         * @param truePositive true positive count
         * @param falsePositive false positive count
         * @param trueNegative true negative count
         * @param falseNegative false negative count
         */
        PredictionStatistics(long truePositive, long falsePositive, long trueNegative, long falseNegative) {
            this.truePositive = truePositive;
            this.falsePositive = falsePositive;
            this.trueNegative = trueNegative;
            this.falseNegative = falseNegative;
        }
    }

    /**
     * Creates a metric.
     *
     * @param metricTarget the target passed to the implementation function
     * @param str the metric name
     * @param toDoubleBiFunction the function invoked by {@link #compute(Context)}
     */
    public AnomalyMetric(MetricTarget<Event> metricTarget, String str, ToDoubleBiFunction<MetricTarget<Event>, Context> toDoubleBiFunction) {
        target = metricTarget;
        name = str;
        impl = toDoubleBiFunction;
    }

    /**
     * Applies the implementation function to this metric's target and the
     * supplied context.
     *
     * @param context the context to evaluate
     * @return the value returned by the implementation function
     */
    public double compute(Context context) {
        double value = impl.applyAsDouble(target, context);
        return value;
    }

    /** @return the target supplied at construction */
    public MetricTarget<Event> getTarget() {
        return target;
    }

    /** @return the name supplied at construction */
    public String getName() {
        return name;
    }

    /**
     * Creates a new context for the supplied model and predictions.
     *
     * @param model the model
     * @param list the predictions to tabulate
     * @return a freshly constructed context
     */
    public Context createContext(Model<Event> model, List<Prediction<Event>> list) {
        return new Context(model, list);
    }

    /**
     * Static factory for a {@link Context}.
     * <p>
     * Decompiler note: this method was package-private before decompilation.
     *
     * @param model the model
     * @param list the predictions to tabulate
     * @return a freshly constructed context
     */
    public static Context buildContext(Model<Event> model, List<Prediction<Event>> list) {
        return new Context(model, list);
    }

    /**
     * Raw-typed bridge to {@link #createContext(Model, List)}; the decompiler
     * renamed it from {@code createContext} to avoid a name collision.
     */
    public /* bridge */ /* synthetic */ MetricContext m17createContext(Model model, List list) {
        Model<Event> typedModel = (Model<Event>) model;
        List<Prediction<Event>> typedPredictions = (List<Prediction<Event>>) list;
        return createContext(typedModel, typedPredictions);
    }
}
