package org.tribuo.interop.oci;

import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.tribuo.Example;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.interop.oci.protos.OCIOutputConverterProto;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.protos.ProtoSerializableClass;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.regression.Regressor;

/**
 * An {@link OCIOutputConverter} which turns the score vectors/matrices it is handed into
 * {@link Regressor} predictions.
 *
 * <p>Score index {@code j} is interpreted as the value of the regression dimension whose output id
 * is {@code j} in the supplied {@link ImmutableOutputInfo}; the dimension name is taken from the
 * first name of the corresponding {@link Regressor} in that output info.
 *
 * <p>This converter has no state, so all instances are equal and it serializes to an empty proto.
 */
@ProtoSerializableClass(version = OCIRegressorConverter.CURRENT_VERSION)
public final class OCIRegressorConverter implements OCIOutputConverter<Regressor> {
    private static final long serialVersionUID = 1L;

    /** Protobuf serialization version. */
    public static final int CURRENT_VERSION = 0;

    /** Constructs an OCIRegressorConverter. It has no configurable state. */
    public OCIRegressorConverter() {}

    /**
     * Deserialization factory.
     *
     * @param version The serialized object version.
     * @param className The class name (unused, this factory always builds this class).
     * @param message The serialized data; must be empty as this class has no state.
     * @return A new converter.
     * @throws IllegalArgumentException If the version is unsupported or the message is not empty.
     */
    public static OCIRegressorConverter deserializeFromProto(int version, String className, Any message) {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        // Compare by content rather than identity: an empty payload is not guaranteed to be the
        // ByteString.EMPTY instance, so a reference check could reject a valid message.
        if (!message.getValue().isEmpty()) {
            throw new IllegalArgumentException("Invalid proto");
        }
        return new OCIRegressorConverter();
    }

    /**
     * Converts a single example's scores into a regression prediction.
     *
     * @param scores One score per output dimension, indexed by output id.
     * @param numValidFeature The number of valid features, passed through to the {@link Prediction}.
     * @param example The example the scores were produced for.
     * @param outputIDInfo The output domain, mapping ids to {@link Regressor} dimensions.
     * @return A prediction holding a {@link Regressor} with every dimension filled in.
     * @throws IllegalStateException If the number of scores differs from the number of outputs.
     */
    @Override
    public Prediction<Regressor> convertOutput(DenseVector scores, int numValidFeature, Example<Regressor> example, ImmutableOutputInfo<Regressor> outputIDInfo) {
        int numOutputs = outputIDInfo.size();
        if (scores.size() != numOutputs) {
            throw new IllegalStateException("Expected scores for each output, received " + scores.size() + " when there are " + numOutputs + " outputs");
        }
        String[] names = extractNames(outputIDInfo);
        double[] values = new double[numOutputs];
        for (int id = 0; id < numOutputs; id++) {
            values[id] = scores.get(id);
        }
        return new Prediction<>(new Regressor(names, values), numValidFeature, example);
    }

    /**
     * Converts a batch of scores (one row per example) into regression predictions.
     *
     * @param scores Score matrix with one row per example and one column per output id.
     * @param numValidFeatures The number of valid features for each example, indexed by row.
     * @param examples The examples the scores were produced for, in row order.
     * @param outputIDInfo The output domain, mapping ids to {@link Regressor} dimensions.
     * @return One prediction per row, in row order.
     * @throws IllegalStateException If the row count differs from the number of examples, or the
     *     column count differs from the number of outputs.
     */
    @Override
    public List<Prediction<Regressor>> convertOutput(DenseMatrix scores, int[] numValidFeatures, List<Example<Regressor>> examples, ImmutableOutputInfo<Regressor> outputIDInfo) {
        int numRows = scores.getDimension1Size();
        if (numRows != examples.size()) {
            throw new IllegalStateException("Expected one prediction per example, received " + numRows + " predictions when there are " + examples.size() + " examples.");
        }
        int numOutputs = outputIDInfo.size();
        if (scores.getDimension2Size() != numOutputs) {
            throw new IllegalStateException("Expected scores for each output, received " + scores.getDimension2Size() + " when there are " + numOutputs + " outputs");
        }
        // The names depend only on the output domain, so compute them once for the whole batch.
        String[] names = extractNames(outputIDInfo);
        List<Prediction<Regressor>> predictions = new ArrayList<>(numRows);
        for (int row = 0; row < numRows; row++) {
            double[] values = new double[numOutputs];
            for (int id = 0; id < numOutputs; id++) {
                values[id] = scores.get(row, id);
            }
            predictions.add(new Prediction<>(new Regressor(names, values), numValidFeatures[row], examples.get(row)));
        }
        return predictions;
    }

    /**
     * Builds the dimension-name array, placing each regressor's name at its output id.
     *
     * @param outputIDInfo The output domain.
     * @return An array of length {@code outputIDInfo.size()} with the name for id {@code j} at index {@code j}.
     */
    private static String[] extractNames(ImmutableOutputInfo<Regressor> outputIDInfo) {
        String[] names = new String[outputIDInfo.size()];
        for (Pair<Integer, Regressor> pair : outputIDInfo) {
            // Only the first name is used; presumably each domain element is a single-dimension
            // Regressor - confirm against how the output info is built.
            names[pair.getA()] = pair.getB().getNames()[0];
        }
        return names;
    }

    /**
     * This converter emits regression values, not probabilities.
     *
     * @return {@code false}.
     */
    @Override
    public boolean generatesProbabilities() {
        return false;
    }

    /**
     * Returns the output type this converter produces.
     *
     * @return {@code Regressor.class}.
     */
    @Override
    public Class<Regressor> getTypeWitness() {
        return Regressor.class;
    }

    @Override
    public String toString() {
        return "OCIRegressorConverter()";
    }

    /** All instances are equal as the converter has no state. */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        return obj != null && getClass() == obj.getClass();
    }

    /** Constant hash code, consistent with {@link #equals(Object)} for a stateless class. */
    @Override
    public int hashCode() {
        return 31;
    }

    /**
     * Serializes this converter to its proto form.
     *
     * @return The serialized converter.
     */
    @Override
    public OCIOutputConverterProto serialize() {
        return ProtoUtil.serialize(this);
    }

    /**
     * Returns the provenance of this converter.
     *
     * @return A configured-object provenance tagged with the "OCIOutputConverter" host type.
     */
    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "OCIOutputConverter");
    }

    /**
     * Decompiler-generated alias of {@link #serialize()}, kept only for source compatibility.
     *
     * @return The serialized converter.
     * @deprecated Use {@link #serialize()}.
     */
    @Deprecated
    public OCIOutputConverterProto m11serialize() {
        return serialize();
    }

    /**
     * Decompiler-generated alias of {@link #getProvenance()}, kept only for source compatibility.
     *
     * @return The provenance of this converter.
     * @deprecated Use {@link #getProvenance()}.
     */
    @Deprecated
    public ConfiguredObjectProvenance m12getProvenance() {
        return getProvenance();
    }
}
