package org.tribuo.interop.onnx;

import ai.onnxruntime.OnnxJavaType;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtException;
import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.tribuo.Example;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.interop.onnx.protos.OutputTransformerProto;
import org.tribuo.protos.ProtoSerializableClass;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.regression.Regressor;

/**
 * An {@link OutputTransformer} that converts the output of an ONNX model into {@link Regressor}s.
 * <p>
 * The model must produce exactly one output, which must be a rank 2 {@code float} tensor of shape
 * {@code [batchSize][numDimensions]}. Column {@code j} of each row is mapped to the regression
 * dimension whose id is {@code j} in the supplied {@link ImmutableOutputInfo}.
 * <p>
 * This class has no state: all instances are equal and it serializes to an empty proto payload.
 */
@ProtoSerializableClass(version = RegressorTransformer.CURRENT_VERSION)
public class RegressorTransformer implements OutputTransformer<Regressor> {
    private static final long serialVersionUID = 1L;

    /**
     * Protobuf serialization version.
     */
    public static final int CURRENT_VERSION = 0;

    /**
     * Deserialization factory.
     * @param version The serialized object version.
     * @param className The class name (unused by this stateless class).
     * @param message The serialized data; its payload must be empty.
     * @return A new {@code RegressorTransformer}.
     * @throws IllegalArgumentException If the version is unsupported or the payload is non-empty.
     */
    public static RegressorTransformer deserializeFromProto(int version, String className, Any message) {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        // Compare by content, not by reference: an empty ByteString is not guaranteed to be
        // the ByteString.EMPTY singleton (e.g. one built via ByteString.copyFrom(new byte[0])).
        if (!message.getValue().isEmpty()) {
            throw new IllegalArgumentException("Invalid proto");
        }
        return new RegressorTransformer();
    }

    /**
     * Converts a single-example model output into a {@link Prediction}.
     * @param tensor The model outputs.
     * @param outputIDInfo The id to regression dimension mapping.
     * @param numValidFeatures The number of valid features in the example.
     * @param example The example which produced this output.
     * @return The prediction.
     * @throws IllegalArgumentException If the output is malformed (see {@link #transformToOutput}).
     */
    @Override
    public Prediction<Regressor> transformToPrediction(List<OnnxValue> tensor, ImmutableOutputInfo<Regressor> outputIDInfo, int numValidFeatures, Example<Regressor> example) {
        return new Prediction<>(transformToOutput(tensor, outputIDInfo), numValidFeatures, example);
    }

    /**
     * Converts a single-example model output into a {@link Regressor}.
     * @param tensor The model outputs; must contain exactly one float tensor of shape [1][numDimensions].
     * @param outputIDInfo The id to regression dimension mapping.
     * @return The regressor.
     * @throws IllegalArgumentException If the tensor does not hold exactly one row, or the row width
     *                                  differs from {@code outputIDInfo.size()}.
     */
    @Override
    public Regressor transformToOutput(List<OnnxValue> tensor, ImmutableOutputInfo<Regressor> outputIDInfo) {
        float[][] predictions = getBatchPredictions(tensor);
        if (predictions.length != 1) {
            throw new IllegalArgumentException("Supplied tensor has too many results, predictions.length = " + predictions.length);
        }
        if (predictions[0].length != outputIDInfo.size()) {
            throw new IllegalArgumentException("Supplied tensor has an incorrect number of dimensions, predictions[0].length = " + predictions[0].length + ", expected " + outputIDInfo.size());
        }
        return new Regressor(dimensionNames(outputIDInfo), toDoubles(predictions[0]));
    }

    /**
     * Legacy alias for {@link #transformToOutput(List, ImmutableOutputInfo)}. Kept so that code
     * written against the decompiled method name still compiles.
     * @param tensor The model outputs.
     * @param outputIDInfo The id to regression dimension mapping.
     * @return The regressor.
     * @deprecated Use {@link #transformToOutput(List, ImmutableOutputInfo)}.
     */
    @Deprecated
    public Regressor transformToOutput2(List<OnnxValue> tensor, ImmutableOutputInfo<Regressor> outputIDInfo) {
        return transformToOutput(tensor, outputIDInfo);
    }

    /**
     * Validates the model output and extracts its single float tensor.
     * @param tensor The model outputs.
     * @return The tensor contents as {@code [batchSize][numDimensions]}.
     * @throws IllegalArgumentException If there is not exactly one output, it is not a rank 2
     *                                  {@link OnnxTensor}, or its element type is not float.
     * @throws IllegalStateException If ONNX Runtime fails to read the tensor value.
     */
    private static float[][] getBatchPredictions(List<OnnxValue> tensor) {
        if (tensor.size() != 1) {
            throw new IllegalArgumentException("Supplied output has incorrect number of elements, expected 1, found " + tensor.size());
        }
        OnnxValue firstOutput = tensor.get(0);
        if (!(firstOutput instanceof OnnxTensor)) {
            String found = firstOutput == null ? "null" : firstOutput.getClass().toString();
            throw new IllegalArgumentException("Supplied output was not an OnnxTensor, found " + found);
        }
        OnnxTensor predictionTensor = (OnnxTensor) firstOutput;
        long[] shape = predictionTensor.getInfo().getShape();
        if (shape.length != 2) {
            throw new IllegalArgumentException("Expected shape [batchSize][numDimensions], found " + Arrays.toString(shape));
        }
        if (predictionTensor.getInfo().type != OnnxJavaType.FLOAT) {
            throw new IllegalArgumentException("Supplied output was an invalid tensor type, expected float, found " + predictionTensor.getInfo().type);
        }
        try {
            // A rank 2 FLOAT tensor is materialised by ONNX Runtime as a float[][].
            return (float[][]) predictionTensor.getValue();
        } catch (OrtException e) {
            throw new IllegalStateException("Failed to read tensor value", e);
        }
    }

    /**
     * Builds the dimension name array, indexed by output id.
     * @param outputIDInfo The id to regression dimension mapping.
     * @return The names, where {@code names[id]} is the dimension name for that id.
     */
    private static String[] dimensionNames(ImmutableOutputInfo<Regressor> outputIDInfo) {
        String[] names = new String[outputIDInfo.size()];
        for (Pair<Integer, Regressor> pair : outputIDInfo) {
            names[pair.getA()] = pair.getB().getNames()[0];
        }
        return names;
    }

    /**
     * Widens a float row into a new double array.
     * @param row The float values.
     * @return A freshly allocated double array with the same values.
     */
    private static double[] toDoubles(float[] row) {
        double[] values = new double[row.length];
        for (int i = 0; i < row.length; i++) {
            values[i] = row[i];
        }
        return values;
    }

    /**
     * Converts a batched model output into a list of {@link Prediction}s.
     * @param tensor The model outputs.
     * @param outputIDInfo The id to regression dimension mapping.
     * @param numValidFeatures The number of valid features in each example.
     * @param examples The examples which produced the outputs.
     * @return One prediction per example, in order.
     * @throws IllegalArgumentException If the number of outputs does not match both the number of
     *                                  examples and {@code numValidFeatures.length}, or the output is malformed.
     */
    @Override
    public List<Prediction<Regressor>> transformToBatchPrediction(List<OnnxValue> tensor, ImmutableOutputInfo<Regressor> outputIDInfo, int[] numValidFeatures, List<Example<Regressor>> examples) {
        List<Regressor> outputs = transformToBatchOutput(tensor, outputIDInfo);
        if (outputs.size() != examples.size() || outputs.size() != numValidFeatures.length) {
            throw new IllegalArgumentException("Invalid number of predictions received from the ONNXExternalModel, expected " + numValidFeatures.length + ", received " + outputs.size());
        }
        List<Prediction<Regressor>> predictions = new ArrayList<>(outputs.size());
        for (int i = 0; i < outputs.size(); i++) {
            predictions.add(new Prediction<>(outputs.get(i), numValidFeatures[i], examples.get(i)));
        }
        return predictions;
    }

    /**
     * Converts a batched model output into a list of {@link Regressor}s, one per tensor row.
     * @param tensor The model outputs; must contain exactly one float tensor of shape [batchSize][numDimensions].
     * @param outputIDInfo The id to regression dimension mapping.
     * @return One regressor per row, in row order.
     * @throws IllegalArgumentException If any row's width differs from {@code outputIDInfo.size()},
     *                                  or the output is malformed.
     */
    @Override
    public List<Regressor> transformToBatchOutput(List<OnnxValue> tensor, ImmutableOutputInfo<Regressor> outputIDInfo) {
        float[][] predictions = getBatchPredictions(tensor);
        String[] names = dimensionNames(outputIDInfo);
        List<Regressor> output = new ArrayList<>(predictions.length);
        for (int i = 0; i < predictions.length; i++) {
            float[] row = predictions[i];
            // Validate width like transformToOutput does. Otherwise a short row would raise a bare
            // ArrayIndexOutOfBoundsException and a long row would have columns silently dropped.
            if (row.length != names.length) {
                throw new IllegalArgumentException("Supplied tensor has an incorrect number of dimensions, predictions[" + i + "].length = " + row.length + ", expected " + names.length);
            }
            output.add(new Regressor(names, toDoubles(row)));
        }
        return output;
    }

    @Override
    public boolean generatesProbabilities() {
        return false;
    }

    @Override
    public String toString() {
        return "RegressorTransformer()";
    }

    @Override
    public Class<Regressor> getTypeWitness() {
        return Regressor.class;
    }

    /**
     * All instances are equal because this class has no state.
     * @param obj The object to compare.
     * @return True if {@code obj} is a {@code RegressorTransformer} of exactly this class.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        return obj != null && getClass() == obj.getClass();
    }

    @Override
    public int hashCode() {
        return 31;
    }

    @Override
    public OutputTransformerProto serialize() {
        return ProtoUtil.serialize(this);
    }

    /**
     * Legacy alias for {@link #serialize()}. Kept so that code written against the decompiled
     * method name still compiles.
     * @return The serialized form.
     * @deprecated Use {@link #serialize()}.
     */
    @Deprecated
    public OutputTransformerProto m16serialize() {
        return serialize();
    }

    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "OutputTransformer");
    }

    /**
     * Legacy alias for {@link #getProvenance()}. Kept so that code written against the decompiled
     * method name still compiles.
     * @return The provenance.
     * @deprecated Use {@link #getProvenance()}.
     */
    @Deprecated
    public ConfiguredObjectProvenance m17getProvenance() {
        return getProvenance();
    }
}
