package org.tribuo.classification.ensemble;

import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.ensemble.EnsembleCombiner;
import org.tribuo.protos.core.EnsembleCombinerProto;
import org.tribuo.util.onnx.ONNXNode;
import org.tribuo.util.onnx.ONNXOperators;
import org.tribuo.util.onnx.ONNXRef;

/* loaded from: input_file:org/tribuo/classification/ensemble/FullyWeightedVotingCombiner.class */
/**
 * An {@link EnsembleCombiner} for {@link Label} predictions which sums the
 * (optionally weighted) score of every label across all ensemble members and
 * then divides each label's sum by the grand total of all weighted scores.
 * <p>
 * The class holds no state, so every instance is equal to every other instance.
 */
public final class FullyWeightedVotingCombiner implements EnsembleCombiner<Label> {
    private static final long serialVersionUID = 1L;

    /** The protobuf serialization version written and accepted by this class. */
    public static final int CURRENT_VERSION = 0;

    /**
     * Deserialization factory.
     *
     * @param version   The serialized object version.
     * @param className The class name (not used by this method).
     * @param message   The serialized payload, which must be empty as this class has no state.
     * @return A new combiner.
     * @throws IllegalArgumentException If the version is unsupported or the payload is not empty.
     */
    public static FullyWeightedVotingCombiner deserializeFromProto(int version, String className, Any message) {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        // Compare by content: an empty payload is not guaranteed to be the ByteString.EMPTY singleton.
        if (!message.getValue().isEmpty()) {
            throw new IllegalArgumentException("Invalid proto");
        }
        return new FullyWeightedVotingCombiner();
    }

    /**
     * Serializes this combiner, recording only its class name and version.
     * <p>
     * NOTE(review): the decompiler renamed this method from {@code serialize}
     * ("merged with bridge method"); confirm the intended name against the interface.
     *
     * @return The protobuf representation of this combiner.
     */
    public EnsembleCombinerProto m24serialize() {
        EnsembleCombinerProto.Builder builder = EnsembleCombinerProto.newBuilder();
        builder.setClassName(getClass().getName());
        builder.setVersion(CURRENT_VERSION);
        return builder.build();
    }

    /**
     * Combines the predictions giving each ensemble member the same weight,
     * {@code 1 / predictions.size()}.
     *
     * @param outputInfo  The output domain, used to map labels to array indices and back.
     * @param predictions The predictions from each ensemble member.
     * @return A prediction whose label scores are the normalised sums of the member scores.
     * @throws IndexOutOfBoundsException If {@code predictions} is empty, as the example is
     *                                   read from its first element.
     */
    public Prediction<Label> combine(ImmutableOutputInfo<Label> outputInfo, List<Prediction<Label>> predictions) {
        double weight = 1.0 / predictions.size();
        int numUsed = 0;
        double total = 0.0;
        double[] scores = new double[outputInfo.size()];
        for (Prediction<Label> prediction : predictions) {
            numUsed = Math.max(numUsed, prediction.getNumActiveFeatures());
            for (Label label : prediction.getOutputScores().values()) {
                double weighted = weight * label.getScore();
                total += weighted;
                scores[outputInfo.getID(label)] += weighted;
            }
        }
        return buildPrediction(outputInfo, predictions, scores, total, numUsed);
    }

    /**
     * Combines the predictions, scaling every score of member {@code j} by {@code weights[j]}.
     *
     * @param outputInfo  The output domain, used to map labels to array indices and back.
     * @param predictions The predictions from each ensemble member.
     * @param weights     One weight per ensemble member.
     * @return A prediction whose label scores are the normalised weighted sums of the member scores.
     * @throws IllegalArgumentException  If {@code predictions} and {@code weights} differ in length.
     * @throws IndexOutOfBoundsException If {@code predictions} is empty, as the example is
     *                                   read from its first element.
     */
    public Prediction<Label> combine(ImmutableOutputInfo<Label> outputInfo, List<Prediction<Label>> predictions, float[] weights) {
        if (predictions.size() != weights.length) {
            throw new IllegalArgumentException("predictions and weights must be the same length. predictions.size()=" + predictions.size() + ", weights.length=" + weights.length);
        }
        int numUsed = 0;
        double total = 0.0;
        double[] scores = new double[outputInfo.size()];
        for (int j = 0; j < weights.length; j++) {
            Prediction<Label> prediction = predictions.get(j);
            numUsed = Math.max(numUsed, prediction.getNumActiveFeatures());
            for (Label label : prediction.getOutputScores().values()) {
                double weighted = weights[j] * label.getScore();
                total += weighted;
                scores[outputInfo.getID(label)] += weighted;
            }
        }
        return buildPrediction(outputInfo, predictions, scores, total, numUsed);
    }

    /**
     * Divides each accumulated score by the total, picks the highest scoring label
     * (the lowest id wins ties) and wraps the result in a {@link Prediction}.
     * <p>
     * If {@code total} is zero the division produces non-finite scores; when every
     * score is NaN no label beats the initial maximum and the predicted label is {@code null}.
     */
    private static Prediction<Label> buildPrediction(ImmutableOutputInfo<Label> outputInfo,
                                                     List<Prediction<Label>> predictions,
                                                     double[] scores, double total, int numUsed) {
        double maxScore = Double.NEGATIVE_INFINITY;
        Label maxLabel = null;
        // Insertion ordered, so the score map iterates in output id order.
        LinkedHashMap<String, Label> labelScores = new LinkedHashMap<>();
        for (int id = 0; id < scores.length; id++) {
            String name = outputInfo.getOutput(id).getLabel();
            Label scored = new Label(name, scores[id] / total);
            labelScores.put(name, scored);
            if (scored.getScore() > maxScore) {
                maxScore = scored.getScore();
                maxLabel = scored;
            }
        }
        // The example is taken from the first member's prediction.
        // NOTE(review): the trailing 'true' presumably flags the scores as probabilities - confirm.
        return new Prediction<>(maxLabel, labelScores, numUsed, predictions.get(0).getExample(), true);
    }

    @Override
    public String toString() {
        return "FullyWeightedVotingCombiner()";
    }

    /**
     * Builds the provenance for this combiner.
     * <p>
     * NOTE(review): the decompiler renamed this method from {@code getProvenance}
     * ("merged with bridge method"); confirm the intended name against the interface.
     *
     * @return The provenance, tagged with the component type "EnsembleCombiner".
     */
    public ConfiguredObjectProvenance m25getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "EnsembleCombiner");
    }

    /**
     * Returns the output type this combiner operates on.
     *
     * @return {@code Label.class}.
     */
    public Class<Label> getTypeWitness() {
        return Label.class;
    }

    /**
     * Exports the unweighted combiner as a REDUCE_MEAN over axis 2 with
     * {@code keepdims} set to 0.
     *
     * @param input The node holding the ensemble member outputs
     *              (presumably axis 2 indexes the ensemble members - TODO confirm).
     * @return The node produced by applying the mean.
     */
    public ONNXNode exportCombiner(ONNXNode input) {
        HashMap<String, Object> attributes = new HashMap<>();
        attributes.put("axes", new int[]{2});
        attributes.put("keepdims", 0);
        return input.apply(ONNXOperators.REDUCE_MEAN, attributes);
    }

    /**
     * Exports the weighted combiner: the weights are unsqueezed on axes 0 and 1,
     * multiplied into the input, summed over axis 2 and finally divided by the sum
     * of the weights.
     *
     * @param input  The node holding the ensemble member outputs.
     * @param weight The reference holding the ensemble member weights.
     * @param <T>    The type of the weight reference.
     * @return The node producing the weighted combination.
     */
    public <T extends ONNXRef<?>> ONNXNode exportCombiner(ONNXNode input, T weight) {
        // Axes are the literal indices 0 and 1; the decompiled code had substituted
        // serialVersionUID for the 1 only because the two values coincide.
        return input
                .apply(ONNXOperators.MUL,
                        weight.apply(ONNXOperators.UNSQUEEZE,
                                input.onnxContext().array("unsqueeze_ensemble_output", new long[]{0, 1})))
                .apply(ONNXOperators.REDUCE_SUM,
                        input.onnxContext().array("sum_across_ensemble_axes", new long[]{2}),
                        Collections.singletonMap("keepdims", 0))
                .apply(ONNXOperators.DIV, weight.apply(ONNXOperators.REDUCE_SUM));
    }

    /**
     * All instances are interchangeable, so equality only checks the type.
     */
    @Override
    public boolean equals(Object obj) {
        return obj instanceof FullyWeightedVotingCombiner;
    }

    /**
     * Constant hash, consistent with {@link #equals(Object)} treating all instances as equal.
     */
    @Override
    public int hashCode() {
        return 31;
    }
}
