package org.tribuo.regression.ensemble;

import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.ensemble.EnsembleCombiner;
import org.tribuo.protos.core.EnsembleCombinerProto;
import org.tribuo.regression.Regressor;
import org.tribuo.util.onnx.ONNXInitializer;
import org.tribuo.util.onnx.ONNXNode;
import org.tribuo.util.onnx.ONNXOperators;
import org.tribuo.util.onnx.ONNXRef;

/* loaded from: input_file:org/tribuo/regression/ensemble/AveragingCombiner.class */
/**
 * An {@link EnsembleCombiner} for {@link Regressor} outputs which combines the ensemble members'
 * predictions by taking their (optionally weighted) mean in each output dimension.
 * <p>
 * The combined prediction also carries the per-dimension variance of the members' predictions,
 * computed in a single pass using Welford's online update (or its weighted variant).
 */
public class AveragingCombiner implements EnsembleCombiner<Regressor> {
    private static final long serialVersionUID = 1L;

    /** Protobuf serialization version. */
    public static final int CURRENT_VERSION = 0;

    /**
     * Deserialization factory.
     *
     * @param version   The serialized object version.
     * @param className The class name (not used by this implementation).
     * @param message   The serialized payload, which must be empty for this class.
     * @return A new {@code AveragingCombiner}.
     * @throws IllegalArgumentException If the version is unsupported or the payload is not empty.
     */
    public static AveragingCombiner deserializeFromProto(int version, String className, Any message) {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        // Compare by content rather than identity: a parsed empty ByteString is not guaranteed
        // to be the ByteString.EMPTY singleton.
        if (!message.getValue().isEmpty()) {
            throw new IllegalArgumentException("Invalid proto");
        }
        return new AveragingCombiner();
    }

    /**
     * Serializes this combiner to a protobuf. Only the class name and version are written, as the
     * combiner has no fields.
     *
     * @return The protobuf representation of this combiner.
     */
    public EnsembleCombinerProto serialize() {
        EnsembleCombinerProto.Builder builder = EnsembleCombinerProto.newBuilder();
        builder.setClassName(getClass().getName());
        builder.setVersion(CURRENT_VERSION);
        return builder.build();
    }

    /**
     * Combines the ensemble members' predictions by taking the unweighted mean in each output
     * dimension.
     * <p>
     * The returned {@link Regressor} has:
     * <ul>
     *   <li>the dimension names of the first prediction;</li>
     *   <li>the per-dimension mean as its values;</li>
     *   <li>the per-dimension sample variance (denominator {@code n - 1}) as its variances, or
     *       zeros when there is a single prediction.</li>
     * </ul>
     * The number of active features is the maximum across the inputs. The example is taken from
     * the first prediction.
     *
     * @param outputInfo  The output domain; its size gives the number of output dimensions.
     * @param predictions The ensemble members' predictions.
     * @return The combined prediction.
     * @throws IllegalArgumentException If {@code predictions} is empty.
     */
    public Prediction<Regressor> combine(ImmutableOutputInfo<Regressor> outputInfo, List<Prediction<Regressor>> predictions) {
        if (predictions.isEmpty()) {
            throw new IllegalArgumentException("Cannot combine an empty list of predictions");
        }
        int dimensions = outputInfo.size();
        int numUsed = 0;
        double[] mean = new double[dimensions];
        // Running sum of squared deviations from the mean; divided by (n - 1) at the end.
        double[] variance = new double[dimensions];
        int count = 0;
        for (Prediction<Regressor> prediction : predictions) {
            numUsed = Math.max(numUsed, prediction.getNumActiveFeatures());
            count++;
            double[] values = prediction.getOutput().getValues();
            for (int i = 0; i < dimensions; i++) {
                // Welford's update: the mean moves by delta / count, where count includes this value.
                double delta = values[i] - mean[i];
                mean[i] += delta / count;
                variance[i] += delta * (values[i] - mean[i]);
            }
        }
        String[] names = predictions.get(0).getOutput().getNames();
        if (count > 1) {
            for (int i = 0; i < dimensions; i++) {
                variance[i] /= (count - 1);
            }
        } else {
            Arrays.fill(variance, 0.0);
        }
        return new Prediction<>(new Regressor(names, mean, variance), numUsed, predictions.get(0).getExample());
    }

    /**
     * Combines the ensemble members' predictions by taking the weighted mean in each output
     * dimension.
     * <p>
     * The returned {@link Regressor} has:
     * <ul>
     *   <li>the dimension names of the first prediction;</li>
     *   <li>the per-dimension weighted mean as its values;</li>
     *   <li>the weighted sum of squared deviations divided by {@code sum(weights) - 1} as its
     *       variances, or zeros when there is a single prediction.</li>
     * </ul>
     * NOTE(review): the variance denominator is zero or negative when the weights sum to at most
     * 1, and a leading zero weight makes the first mean update 0/0 — confirm the intended weight
     * convention.
     * <p>
     * The number of active features is the maximum across the inputs. The example is taken from
     * the first prediction.
     *
     * @param outputInfo  The output domain; its size gives the number of output dimensions.
     * @param predictions The ensemble members' predictions.
     * @param weights     One weight per prediction, in the same order.
     * @return The combined prediction.
     * @throws IllegalArgumentException If the lengths differ or {@code predictions} is empty.
     */
    public Prediction<Regressor> combine(ImmutableOutputInfo<Regressor> outputInfo, List<Prediction<Regressor>> predictions, float[] weights) {
        if (predictions.size() != weights.length) {
            throw new IllegalArgumentException("predictions and weights must be the same length. predictions.size()=" + predictions.size() + ", weights.length=" + weights.length);
        }
        if (predictions.isEmpty()) {
            throw new IllegalArgumentException("Cannot combine an empty list of predictions");
        }
        int dimensions = outputInfo.size();
        int numUsed = 0;
        double[] mean = new double[dimensions];
        // Running weighted sum of squared deviations; divided by (weightSum - 1) at the end.
        double[] variance = new double[dimensions];
        double weightSum = 0.0;
        for (int j = 0; j < weights.length; j++) {
            Prediction<Regressor> prediction = predictions.get(j);
            numUsed = Math.max(numUsed, prediction.getNumActiveFeatures());
            double[] values = prediction.getOutput().getValues();
            float weight = weights[j];
            weightSum += weight;
            for (int i = 0; i < dimensions; i++) {
                // Weighted incremental update: the mean moves by (weight / weightSum) of the deviation.
                double delta = values[i] - mean[i];
                mean[i] += (weight / weightSum) * delta;
                variance[i] += weight * delta * (values[i] - mean[i]);
            }
        }
        String[] names = predictions.get(0).getOutput().getNames();
        if (weights.length > 1) {
            for (int i = 0; i < dimensions; i++) {
                variance[i] /= (weightSum - 1.0);
            }
        } else {
            Arrays.fill(variance, 0.0);
        }
        return new Prediction<>(new Regressor(names, mean, variance), numUsed, predictions.get(0).getExample());
    }

    /**
     * Returns a fixed description of this combiner.
     * NOTE(review): the string does not match the class name; it is left unchanged in case
     * something depends on it.
     */
    @Override
    public String toString() {
        return "MultipleOutputAveragingCombiner()";
    }

    /**
     * Returns the provenance of this combiner, recorded under the {@code "EnsembleCombiner"} key.
     *
     * @return The configured-object provenance.
     */
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "EnsembleCombiner");
    }

    /**
     * Returns the output type this combiner works on.
     *
     * @return {@code Regressor.class}.
     */
    public Class<Regressor> getTypeWitness() {
        return Regressor.class;
    }

    /**
     * Exports the unweighted combination as an ONNX {@code ReduceMean} over axis 2, which is
     * dropped from the output ({@code keepdims = 0}).
     *
     * @param input The node holding the stacked ensemble member outputs.
     * @return The node holding the averaged output.
     */
    public ONNXNode exportCombiner(ONNXNode input) {
        HashMap<String, Object> attributes = new HashMap<>();
        attributes.put("axes", new int[]{2});
        attributes.put("keepdims", 0);
        return input.apply(ONNXOperators.REDUCE_MEAN, attributes);
    }

    /**
     * Exports the weighted combination as ONNX operations:
     * {@code ReduceSum(input * Unsqueeze(weight, [0, 1]), axis 2) / ReduceSum(weight)}.
     *
     * @param input  The node holding the stacked ensemble member outputs.
     * @param weight The ensemble weights; presumably a 1-D tensor with one entry per member (TODO
     *               confirm), so that unsqueezing axes 0 and 1 lines it up with axis 2.
     * @param <T>    The type of the weight reference.
     * @return The node holding the weighted average.
     */
    public <T extends ONNXRef<?>> ONNXNode exportCombiner(ONNXNode input, T weight) {
        ONNXInitializer unsqueezeAxes = input.onnxContext().array("unsqueeze_ensemble_output", new long[]{0, 1});
        ONNXInitializer sumAxes = input.onnxContext().array("sum_across_ensemble_axes", new long[]{2});
        // Nodes are created in the same order as before, so generated node names are unchanged.
        ONNXNode unsqueezedWeights = weight.apply(ONNXOperators.UNSQUEEZE, unsqueezeAxes);
        ONNXNode weightedInputs = input.apply(ONNXOperators.MUL, unsqueezedWeights);
        ONNXNode weightedSum = weightedInputs.apply(ONNXOperators.REDUCE_SUM, sumAxes, Collections.singletonMap("keepdims", 0));
        ONNXNode weightTotal = weight.apply(ONNXOperators.REDUCE_SUM);
        return weightedSum.apply(ONNXOperators.DIV, weightTotal);
    }

    /**
     * All instances are equal, as the combiner has no fields.
     */
    @Override
    public boolean equals(Object obj) {
        return obj instanceof AveragingCombiner;
    }

    /**
     * Constant hash code, consistent with {@link #equals(Object)}.
     */
    @Override
    public int hashCode() {
        return 31;
    }
}
