package org.apache.ignite.ml.selection.scoring.evaluator.aggregator;

import java.io.Serializable;
import java.util.Objects;
import org.apache.ignite.internal.util.typedef.internal.A;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.selection.scoring.evaluator.context.BinaryClassificationEvaluationContext;
import org.apache.ignite.ml.structures.LabeledVector;

/**
 * Collects the confusion-matrix counters (true/false positives and negatives) of a binary classifier by
 * comparing, sample by sample, the model's prediction with the sample's own label.
 *
 * <p>When the no-arg constructor is used the two class labels are unset until {@link #initByContext} copies
 * them from a {@link BinaryClassificationEvaluationContext}. Until then {@link #aggregate} matches nothing
 * and therefore counts nothing.
 *
 * <p>The counters are plain {@code int}s and are incremented without overflow checks.
 *
 * @param <L> Type of a label.
 */
public class BinaryClassificationPointwiseMetricStatsAggregator<L extends Serializable> implements MetricStatsAggregator<L, BinaryClassificationEvaluationContext<L>, BinaryClassificationPointwiseMetricStatsAggregator<L>> {
    /** Serial version uid. */
    private static final long serialVersionUID = -7677193556950322385L;

    /** Label treated as the negative ("false") class. */
    private L falseLabel;

    /** Label treated as the positive ("truth") class. */
    private L truthLabel;

    /** Number of samples predicted as {@link #truthLabel} whose label is {@link #truthLabel}. */
    private int truePositive;

    /** Number of samples predicted as {@link #truthLabel} whose label is {@link #falseLabel}. */
    int falsePositive;

    /** Number of samples predicted as {@link #falseLabel} whose label is {@link #falseLabel}. */
    int trueNegative;

    /** Number of samples predicted as {@link #falseLabel} whose label is {@link #truthLabel}. */
    int falseNegative;

    /**
     * Creates an aggregator with zeroed counters and unset labels; the labels are expected to be supplied
     * later through {@link #initByContext}.
     */
    public BinaryClassificationPointwiseMetricStatsAggregator() {
    }

    /**
     * Creates an aggregator with explicit labels and pre-filled counters.
     *
     * @param falseLabel Label of the negative class.
     * @param truthLabel Label of the positive class.
     * @param truePositive Initial true-positive count.
     * @param falsePositive Initial false-positive count.
     * @param trueNegative Initial true-negative count.
     * @param falseNegative Initial false-negative count.
     */
    public BinaryClassificationPointwiseMetricStatsAggregator(L falseLabel, L truthLabel, int truePositive,
        int falsePositive, int trueNegative, int falseNegative) {
        this.falseLabel = falseLabel;
        this.truthLabel = truthLabel;
        this.truePositive = truePositive;
        this.falsePositive = falsePositive;
        this.trueNegative = trueNegative;
        this.falseNegative = falseNegative;
    }

    /**
     * Runs the model on the sample's features and bumps the counter that corresponds to the
     * (prediction, label) pair. A sample whose prediction or label equals neither configured label leaves
     * every counter untouched.
     *
     * @param model Model under evaluation.
     * @param vector Labeled sample.
     * @throws NullPointerException If the model returns {@code null}, or if the sample's label is
     *      {@code null} while the prediction equals one of the configured labels.
     */
    @Override
    public void aggregate(IgniteModel<Vector, L> model, LabeledVector<L> vector) {
        L predicted = model.predict(vector.features());
        L actual = vector.label();

        if (predicted.equals(this.falseLabel)) {
            if (actual.equals(this.falseLabel)) {
                this.trueNegative++;
                return;
            }
            if (actual.equals(this.truthLabel)) {
                this.falseNegative++;
                return;
            }
        }

        // Deliberately not an "else": mirrors the original check order, where a negative prediction with an
        // unmatched label still falls through to the positive-prediction checks.
        if (predicted.equals(this.truthLabel)) {
            if (actual.equals(this.truthLabel)) {
                this.truePositive++;
            } else if (actual.equals(this.falseLabel)) {
                this.falsePositive++;
            }
        }
    }

    /**
     * Produces a new aggregator whose counters are the element-wise sums of this aggregator's counters and
     * {@code other}'s. Neither operand is modified.
     *
     * <p>Both aggregators must agree on both labels; this is checked through {@code A.ensure}. Labels are
     * compared null-safely, so two aggregators that were never initialised (both labels {@code null}) can
     * be merged.
     *
     * @param other Aggregator to merge with.
     * @return New aggregator holding this aggregator's labels and the summed counters.
     */
    @Override
    public BinaryClassificationPointwiseMetricStatsAggregator<L> mergeWith(
        BinaryClassificationPointwiseMetricStatsAggregator<L> other) {
        A.ensure(Objects.equals(this.falseLabel, other.falseLabel), "this.falseLabel == other.falseLabel");
        A.ensure(Objects.equals(this.truthLabel, other.truthLabel), "this.truthLabel == other.truthLabel");

        return new BinaryClassificationPointwiseMetricStatsAggregator<>(
            this.falseLabel,
            this.truthLabel,
            this.truePositive + other.truePositive,
            this.falsePositive + other.falsePositive,
            this.trueNegative + other.trueNegative,
            this.falseNegative + other.falseNegative
        );
    }

    /**
     * Creates the evaluation context for this aggregator using the context's no-arg constructor.
     *
     * @return Fresh context instance.
     */
    @Override
    public BinaryClassificationEvaluationContext<L> createInitializedContext() {
        return new BinaryClassificationEvaluationContext<>();
    }

    /**
     * Copies the class labels from the context: the context's first class label becomes the false label and
     * its second class label becomes the truth label.
     *
     * @param ctx Evaluation context to read the labels from.
     */
    @Override
    public void initByContext(BinaryClassificationEvaluationContext<L> ctx) {
        this.falseLabel = ctx.getFirstClsLbl();
        this.truthLabel = ctx.getSecondClsLbl();
    }

    /**
     * @return Label of the negative class, or {@code null} if not yet initialised.
     */
    public L getFalseLabel() {
        return this.falseLabel;
    }

    /**
     * @return Label of the positive class, or {@code null} if not yet initialised.
     */
    public L getTruthLabel() {
        return this.truthLabel;
    }

    /**
     * @return Number of true positives counted so far.
     */
    public int getTruePositive() {
        return this.truePositive;
    }

    /**
     * @return Number of false positives counted so far.
     */
    public int getFalsePositive() {
        return this.falsePositive;
    }

    /**
     * @return Number of true negatives counted so far.
     */
    public int getTrueNegative() {
        return this.trueNegative;
    }

    /**
     * @return Number of false negatives counted so far.
     */
    public int getFalseNegative() {
        return this.falseNegative;
    }

    /**
     * @return Sum of the four counters, i.e. the number of samples that were actually counted.
     */
    public int getN() {
        return this.truePositive + this.falsePositive + this.trueNegative + this.falseNegative;
    }

    /**
     * Aggregator variant whose class labels are supplied by the caller up front.
     *
     * <p>Its {@link #createInitializedContext()} returns a context that is constructed with these labels
     * and reports {@code needToCompute() == false}. Presumably this lets the evaluator skip deriving the
     * labels from the data; confirm against {@code EvaluationContext}.
     *
     * <p>No bridge methods are declared here on purpose: the compiler generates them.
     *
     * @param <L> Type of a label.
     */
    public static class WithCustomLabelsAggregator<L extends Serializable> extends BinaryClassificationPointwiseMetricStatsAggregator<L> {
        /**
         * Caller-supplied positive label. This field is distinct from the enclosing class's private field of
         * the same name, which is populated only through {@code initByContext}.
         */
        private final L truthLabel;

        /**
         * Caller-supplied negative label. This field is distinct from the enclosing class's private field of
         * the same name, which is populated only through {@code initByContext}.
         */
        private final L falseLabel;

        /**
         * @param truthLabel Label of the positive class.
         * @param falseLabel Label of the negative class.
         */
        public WithCustomLabelsAggregator(L truthLabel, L falseLabel) {
            this.truthLabel = truthLabel;
            this.falseLabel = falseLabel;
        }

        /**
         * Builds a context from the custom labels, passed as (false label, truth label), whose
         * {@code needToCompute()} always returns {@code false}.
         *
         * @return Context carrying the custom labels.
         */
        @Override
        public BinaryClassificationEvaluationContext<L> createInitializedContext() {
            return new BinaryClassificationEvaluationContext<L>(this.falseLabel, this.truthLabel) {
                /** Serial version uid. */
                private static final long serialVersionUID = 4739649114414953828L;

                /** {@inheritDoc} */
                @Override
                public boolean needToCompute() {
                    return false;
                }
            };
        }
    }
}
