package org.tribuo.interop.tensorflow;

import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.function.BiFunction;
import org.tensorflow.Operand;
import org.tensorflow.Tensor;
import org.tensorflow.framework.losses.MeanSquaredError;
import org.tensorflow.framework.losses.Reduction;
import org.tensorflow.ndarray.FloatNdArray;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.index.Index;
import org.tensorflow.ndarray.index.Indices;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.family.TNumber;
import org.tribuo.Example;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.interop.tensorflow.protos.OutputConverterProto;
import org.tribuo.protos.ProtoSerializableClass;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.regression.ImmutableRegressionInfo;
import org.tribuo.regression.Regressor;

/**
 * Converts between {@link Regressor} outputs and TensorFlow tensors.
 * <p>
 * Tensors produced by this class are 32-bit float tensors of shape
 * {@code [batchSize, numOutputDimensions]}. Tensors consumed by this class must be
 * {@link TFloat32} instances of rank 1 or 2 (see {@link #getBatchPredictions}).
 * The loss supplied to the TensorFlow graph is {@link MeanSquaredError}, and the
 * output transformation is the identity op.
 * <p>
 * The class holds no state.
 */
@ProtoSerializableClass(version = RegressorConverter.CURRENT_VERSION)
public class RegressorConverter implements OutputConverter<Regressor> {
    private static final long serialVersionUID = 1L;

    /** Protobuf serialization version. */
    public static final int CURRENT_VERSION = 0;

    /**
     * Deserialization factory.
     *
     * @param version   The serialized object version.
     * @param className The class name (not inspected by this method).
     * @param message   The serialized data; its value must be empty as this class has no state.
     * @return A new {@code RegressorConverter}.
     * @throws IllegalArgumentException If the version is unsupported or the message carries a payload.
     */
    public static RegressorConverter deserializeFromProto(int version, String className, Any message) {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        // Compare by content: an empty payload is not guaranteed to be the ByteString.EMPTY singleton.
        if (!message.getValue().isEmpty()) {
            throw new IllegalArgumentException("Invalid proto");
        }
        return new RegressorConverter();
    }

    /**
     * Serializes this converter by delegating to {@link ProtoUtil#serialize}.
     * <p>
     * NOTE(review): the method name was assigned by the decompiler (renamed from
     * {@code serialize}, merged with its bridge method); it is kept so existing
     * references continue to resolve.
     *
     * @return The protobuf representation of this converter.
     */
    public OutputConverterProto m13serialize() {
        return ProtoUtil.serialize(this);
    }

    /**
     * Returns the loss function, a {@link MeanSquaredError} named {@code "tribuo-mse"}
     * using {@link Reduction#SUM_OVER_BATCH_SIZE}.
     *
     * @return A function that applies the loss to the pair's first element (passed as the
     * labels argument) and second element (passed as the predictions argument).
     */
    @Override // org.tribuo.interop.tensorflow.OutputConverter
    public BiFunction<Ops, Pair<Placeholder<? extends TNumber>, Operand<TNumber>>, Operand<TNumber>> loss() {
        return (tf, pair) -> new MeanSquaredError("tribuo-mse", Reduction.SUM_OVER_BATCH_SIZE).call(tf, pair.getA(), pair.getB());
    }

    /**
     * Returns the output transformation, which is the identity op.
     *
     * @param <U> The numeric type of the operand.
     * @return A function wrapping its input operand in {@code Ops.identity}.
     */
    @Override // org.tribuo.interop.tensorflow.OutputConverter
    public <U extends TNumber> BiFunction<Ops, Operand<U>, Op> outputTransformFunction() {
        return Ops::identity;
    }

    /**
     * Converts a tensor holding a single result into a {@link Prediction}.
     *
     * @param tensor           The tensor to convert; must satisfy {@link #convertToOutput}.
     * @param outputIDInfo     The output domain.
     * @param numValidFeatures Forwarded unchanged to the {@link Prediction} constructor.
     * @param example          The example the prediction was made for.
     * @return The prediction.
     * @throws IllegalArgumentException If the tensor is rejected by {@link #convertToOutput}.
     */
    @Override // org.tribuo.interop.tensorflow.OutputConverter
    public Prediction<Regressor> convertToPrediction(Tensor tensor, ImmutableOutputInfo<Regressor> outputIDInfo, int numValidFeatures, Example<Regressor> example) {
        return new Prediction<>(convertToOutput(tensor, outputIDInfo), numValidFeatures, example);
    }

    /**
     * Converts a tensor holding exactly one result into a {@link Regressor}.
     *
     * @param tensor       The tensor to convert.
     * @param outputIDInfo The output domain.
     * @return A regressor with one value per output dimension, stored at the dimension's id.
     * @throws IllegalArgumentException If the tensor is not a {@link TFloat32}, has the wrong rank,
     *                                  holds more than one result, or has the wrong number of values.
     */
    @Override // org.tribuo.interop.tensorflow.OutputConverter
    public Regressor convertToOutput(Tensor tensor, ImmutableOutputInfo<Regressor> outputIDInfo) {
        int numDimensions = outputIDInfo.size();
        FloatNdArray predictions = getBatchPredictions(tensor, numDimensions);
        long[] shape = predictions.shape().asArray();
        if (shape[0] != 1) {
            throw new IllegalArgumentException("Supplied tensor has too many results, found " + shape[0]);
        }
        if (shape[1] != numDimensions) {
            throw new IllegalArgumentException("Supplied tensor has an incorrect number of dimensions, shape[1] = " + shape[1] + ", expected " + numDimensions);
        }
        String[] names = new String[numDimensions];
        double[] values = new double[numDimensions];
        for (Pair<Integer, Regressor> idAndDimension : outputIDInfo) {
            int id = idAndDimension.getA();
            names[id] = idAndDimension.getB().getNames()[0];
            values[id] = predictions.getFloat(0, id);
        }
        return new Regressor(names, values);
    }

    /**
     * Validates the tensor and exposes it as a two dimensional float array.
     * <p>
     * A rank 1 tensor is returned as a view with an additional trailing axis; in that
     * case the size of the new axis is not compared against {@code numDimensions} here.
     *
     * @param tensor        The tensor to validate.
     * @param numDimensions The expected size of the second axis of a rank 2 tensor.
     * @return The predictions as a rank 2 array.
     * @throws IllegalArgumentException If the tensor is not a {@link TFloat32}, is not rank 1 or 2,
     *                                  or is rank 2 with a second axis of the wrong size.
     */
    private FloatNdArray getBatchPredictions(Tensor tensor, int numDimensions) {
        if (!(tensor instanceof TFloat32)) {
            throw new IllegalArgumentException("Tensor is not a 32-bit float. Found type " + tensor.getClass().getName());
        }
        TFloat32 floatTensor = (TFloat32) tensor;
        long[] shape = floatTensor.shape().asArray();
        if (shape.length == 1) {
            return floatTensor.slice(Indices.all(), Indices.newAxis());
        }
        if (shape.length != 2) {
            throw new IllegalArgumentException("Supplied tensor has the wrong number of dimensions, shape = " + Arrays.toString(shape));
        }
        if (shape[1] != numDimensions) {
            throw new IllegalArgumentException("Supplied tensor has incorrect number of elements, tensor value dimension: " + Arrays.toString(shape) + ", output dimension: " + numDimensions);
        }
        return floatTensor;
    }

    /**
     * Converts a tensor holding a batch of results into a list of {@link Prediction}s.
     *
     * @param tensor           The tensor to convert.
     * @param outputIDInfo     The output domain.
     * @param numValidFeatures One entry per example, forwarded to the matching {@link Prediction}.
     * @param examples         The examples the predictions were made for.
     * @return One prediction per example, in batch order.
     * @throws IllegalArgumentException If the tensor is rejected by {@link #convertToBatchOutput}, or
     *                                  the number of results differs from the number of examples or
     *                                  from the length of {@code numValidFeatures}.
     */
    @Override // org.tribuo.interop.tensorflow.OutputConverter
    public List<Prediction<Regressor>> convertToBatchPrediction(Tensor tensor, ImmutableOutputInfo<Regressor> outputIDInfo, int[] numValidFeatures, List<Example<Regressor>> examples) {
        List<Regressor> outputs = convertToBatchOutput(tensor, outputIDInfo);
        if (outputs.size() != examples.size() || outputs.size() != numValidFeatures.length) {
            throw new IllegalArgumentException("Invalid number of predictions received from Tensorflow, expected " + numValidFeatures.length + ", received " + outputs.size());
        }
        List<Prediction<Regressor>> predictions = new ArrayList<>(outputs.size());
        for (int i = 0; i < outputs.size(); i++) {
            predictions.add(new Prediction<>(outputs.get(i), numValidFeatures[i], examples.get(i)));
        }
        return predictions;
    }

    /**
     * Converts a tensor holding a batch of results into a list of {@link Regressor}s.
     *
     * @param tensor       The tensor to convert.
     * @param outputIDInfo The output domain.
     * @return One regressor per row of the tensor, in row order. All regressors are built
     * from the same dimension name array.
     * @throws IllegalArgumentException If the tensor is rejected by {@link #getBatchPredictions}.
     */
    @Override // org.tribuo.interop.tensorflow.OutputConverter
    public List<Regressor> convertToBatchOutput(Tensor tensor, ImmutableOutputInfo<Regressor> outputIDInfo) {
        FloatNdArray predictions = getBatchPredictions(tensor, outputIDInfo.size());
        int batchSize = (int) predictions.shape().asArray()[0];

        String[] names = new String[outputIDInfo.size()];
        for (Pair<Integer, Regressor> idAndDimension : outputIDInfo) {
            names[idAndDimension.getA()] = idAndDimension.getB().getNames()[0];
        }

        List<Regressor> outputs = new ArrayList<>(batchSize);
        for (int row = 0; row < batchSize; row++) {
            double[] values = new double[names.length];
            for (int col = 0; col < names.length; col++) {
                values[col] = predictions.getFloat(row, col);
            }
            outputs.add(new Regressor(names, values));
        }
        return outputs;
    }

    /**
     * Converts a single {@link Regressor} into a float tensor of shape {@code [1, outputIDInfo.size()]}.
     * <p>
     * Column {@code id} receives {@code values[mapping[id]]}, where {@code mapping} is the array
     * returned by {@link ImmutableRegressionInfo#getIDtoNaturalOrderMapping()}.
     *
     * @param example      The regressor to convert.
     * @param outputIDInfo The output domain; cast to {@link ImmutableRegressionInfo}.
     * @return The tensor.
     */
    @Override // org.tribuo.interop.tensorflow.OutputConverter
    public Tensor convertToTensor(Regressor example, ImmutableOutputInfo<Regressor> outputIDInfo) {
        TFloat32 output = TFloat32.tensorOf(Shape.of(1, outputIDInfo.size()));
        int[] naturalOrder = ((ImmutableRegressionInfo) outputIDInfo).getIDtoNaturalOrderMapping();
        double[] values = example.getValues();
        for (Pair<Integer, Regressor> idAndDimension : outputIDInfo) {
            int id = idAndDimension.getA();
            output.setFloat((float) values[naturalOrder[id]], 0, id);
        }
        return output;
    }

    /**
     * Converts the outputs of a list of examples into a float tensor of shape
     * {@code [examples.size(), outputIDInfo.size()]}, one row per example in list order.
     * <p>
     * Column {@code id} of each row receives {@code values[mapping[id]]}, where {@code mapping} is
     * the array returned by {@link ImmutableRegressionInfo#getIDtoNaturalOrderMapping()}.
     *
     * @param examples     The examples whose outputs are converted.
     * @param outputIDInfo The output domain; cast to {@link ImmutableRegressionInfo}.
     * @return The tensor.
     */
    @Override // org.tribuo.interop.tensorflow.OutputConverter
    public Tensor convertToTensor(List<Example<Regressor>> examples, ImmutableOutputInfo<Regressor> outputIDInfo) {
        int numDimensions = outputIDInfo.size();
        TFloat32 output = TFloat32.tensorOf(Shape.of(examples.size(), numDimensions));
        int[] naturalOrder = ((ImmutableRegressionInfo) outputIDInfo).getIDtoNaturalOrderMapping();
        int row = 0;
        for (Example<Regressor> example : examples) {
            double[] values = example.getOutput().getValues();
            for (int id = 0; id < numDimensions; id++) {
                output.setFloat((float) values[naturalOrder[id]], row, id);
            }
            row++;
        }
        return output;
    }

    /**
     * Reports whether this converter produces probabilities.
     *
     * @return Always {@code false}.
     */
    @Override // org.tribuo.interop.tensorflow.OutputConverter
    public boolean generatesProbabilities() {
        return false;
    }

    @Override
    public String toString() {
        return "RegressorConverter()";
    }

    /**
     * Builds the provenance for this converter.
     * <p>
     * NOTE(review): the method name was assigned by the decompiler (renamed from
     * {@code getProvenance}, merged with its bridge method); it is kept so existing
     * references continue to resolve.
     *
     * @return A {@link ConfiguredObjectProvenanceImpl} for this object, constructed with
     * the string {@code "OutputConverter"}.
     */
    public ConfiguredObjectProvenance m14getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "OutputConverter");
    }

    /**
     * Returns the output class handled by this converter.
     *
     * @return {@code Regressor.class}.
     */
    @Override // org.tribuo.interop.tensorflow.OutputConverter
    public Class<Regressor> getTypeWitness() {
        return Regressor.class;
    }
}
