package org.tribuo.classification.sequence;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.classification.evaluation.LabelMetric;
import org.tribuo.classification.evaluation.LabelMetrics;
import org.tribuo.evaluation.metrics.MetricContext;
import org.tribuo.evaluation.metrics.MetricID;
import org.tribuo.evaluation.metrics.MetricTarget;
import org.tribuo.provenance.EvaluationProvenance;
import org.tribuo.sequence.AbstractSequenceEvaluator;
import org.tribuo.sequence.SequenceEvaluation;
import org.tribuo.sequence.SequenceModel;

/* loaded from: input_file:org/tribuo/classification/sequence/LabelSequenceEvaluator.class */
/**
 * A sequence evaluator for {@link Label} outputs.
 * <p>
 * Sequence predictions are flattened into a single list of per-element predictions,
 * which are then scored with the standard {@link LabelMetrics}.
 */
public class LabelSequenceEvaluator extends AbstractSequenceEvaluator<Label, LabelMetric.Context, LabelSequenceEvaluation, LabelMetric> {

    /**
     * Metrics computed for every label in the model's domain, and for both the
     * micro-averaged and macro-averaged targets.
     */
    private static final LabelMetrics[] PER_TARGET_METRICS = {
        LabelMetrics.TP,
        LabelMetrics.FP,
        LabelMetrics.TN,
        LabelMetrics.FN,
        LabelMetrics.PRECISION,
        LabelMetrics.RECALL,
        LabelMetrics.F1,
        LabelMetrics.ACCURACY
    };

    /**
     * Builds the set of metrics to compute for the supplied model.
     * <p>
     * For each label in the model's output domain, and for the micro- and macro-average
     * targets, this adds TP, FP, TN, FN, precision, recall, F1 and accuracy. Balanced error
     * rate is added for the macro-average target only.
     *
     * @param sequenceModel The model whose output domain determines the per-label targets.
     * @return The set of metrics to compute.
     */
    @Override
    protected Set<LabelMetric> createMetrics(SequenceModel<Label> sequenceModel) {
        Set<LabelMetric> metrics = new HashSet<>();
        for (Label label : sequenceModel.getOutputIDInfo().getDomain()) {
            MetricTarget<Label> labelTarget = new MetricTarget<>(label);
            addPerTargetMetrics(metrics, labelTarget);
        }

        MetricTarget<Label> microAverageTarget = MetricTarget.microAverageTarget();
        addPerTargetMetrics(metrics, microAverageTarget);

        MetricTarget<Label> macroAverageTarget = MetricTarget.macroAverageTarget();
        addPerTargetMetrics(metrics, macroAverageTarget);
        // Balanced error rate is only reported as a macro average.
        metrics.add(LabelMetrics.BALANCED_ERROR_RATE.forTarget(macroAverageTarget));
        return metrics;
    }

    /**
     * Adds every metric in {@link #PER_TARGET_METRICS}, bound to {@code target}, to {@code metrics}.
     *
     * @param metrics The set to add to.
     * @param target The target the metrics are bound to.
     */
    private static void addPerTargetMetrics(Set<LabelMetric> metrics, MetricTarget<Label> target) {
        for (LabelMetrics metric : PER_TARGET_METRICS) {
            metrics.add(metric.forTarget(target));
        }
    }

    /**
     * Creates the metric context by flattening all sequence predictions into one list.
     *
     * @param sequenceModel The model that produced the predictions.
     * @param list The predictions, one inner list per sequence.
     * @return A context over the flattened predictions.
     */
    @Override
    protected LabelMetric.Context createContext(SequenceModel<Label> sequenceModel, List<List<Prediction<Label>>> list) {
        return new LabelMetric.Context(sequenceModel, flattenList(list));
    }

    /**
     * Wraps the computed metric values in a {@link LabelSequenceEvaluation}.
     *
     * @param context The context the metrics were computed from.
     * @param map The computed metric values.
     * @param evaluationProvenance The provenance of this evaluation.
     * @return The evaluation.
     */
    @Override
    protected LabelSequenceEvaluation createEvaluation(LabelMetric.Context context, Map<MetricID<Label>, Double> map, EvaluationProvenance evaluationProvenance) {
        return new LabelSequenceEvaluation(map, context, evaluationProvenance);
    }

    /**
     * Concatenates the per-sequence prediction lists into a single list.
     * <p>
     * Sequence order is preserved, as is element order within each sequence.
     *
     * @param list The predictions, one inner list per sequence.
     * @return A new mutable list containing every prediction.
     */
    private static List<Prediction<Label>> flattenList(List<List<Prediction<Label>>> list) {
        List<Prediction<Label>> flattened = new ArrayList<>();
        for (List<Prediction<Label>> sequence : list) {
            flattened.addAll(sequence);
        }
        return flattened;
    }
}
