package org.tribuo.interop.tensorflow;

import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.logging.Logger;
import org.tensorflow.Operand;
import org.tensorflow.Tensor;
import org.tensorflow.framework.op.FrameworkOps;
import org.tensorflow.ndarray.FloatNdArray;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.index.Index;
import org.tensorflow.ndarray.index.Indices;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.math.Mean;
import org.tensorflow.types.TFloat16;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.family.TNumber;
import org.tribuo.Example;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.interop.tensorflow.protos.OutputConverterProto;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorIterator;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.multilabel.MultiLabel;
import org.tribuo.protos.ProtoSerializableClass;
import org.tribuo.protos.ProtoUtil;

/**
 * An {@link OutputConverter} for {@link MultiLabel} outputs.
 * <p>
 * Training targets are written as a dense {@code [batch, numLabels]} {@link TFloat32} tensor built from each
 * label's sparse vector encoding. The output transform is a sigmoid, the loss is sigmoid cross entropy, and when
 * decoding a {@link TFloat16} or {@link TFloat32} tensor a label is included in the predicted {@link MultiLabel}
 * when its score is strictly greater than {@link #THRESHOLD}.
 */
@ProtoSerializableClass(version = MultiLabelConverter.CURRENT_VERSION)
public class MultiLabelConverter implements OutputConverter<MultiLabel> {
    private static final long serialVersionUID = 1L;

    private static final Logger logger = Logger.getLogger(MultiLabelConverter.class.getName());

    /**
     * Protobuf serialization version.
     */
    public static final int CURRENT_VERSION = 0;

    /**
     * Scores strictly above this value mark a label as present in the predicted {@link MultiLabel}.
     */
    public static final double THRESHOLD = 0.5;

    /**
     * Constructs a MultiLabelConverter. The converter is stateless.
     */
    public MultiLabelConverter() {}

    /**
     * Deserialization factory.
     *
     * @param version   The serialized object version.
     * @param className The class name.
     * @param message   The serialized data. Its payload must be empty, as this class carries no state.
     * @return The deserialized object.
     * @throws IllegalArgumentException If the version is unsupported or the payload is not empty.
     */
    public static MultiLabelConverter deserializeFromProto(int version, String className, Any message) {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        // Compare by content: reference equality with ByteString.EMPTY only matches the shared default instance,
        // so an empty payload held in any other ByteString object would be wrongly rejected.
        if (!message.getValue().isEmpty()) {
            throw new IllegalArgumentException("Invalid proto");
        }
        return new MultiLabelConverter();
    }

    /**
     * Serializes this converter to a protobuf using the reflective {@link ProtoUtil} machinery.
     *
     * @return The protobuf representation.
     */
    @Override
    public OutputConverterProto serialize() {
        return ProtoUtil.serialize(this);
    }

    /**
     * Returns a loss function computing sigmoid cross entropy between the label placeholder and the logits,
     * followed by a mean over axis 0.
     *
     * @return The loss function.
     */
    @Override
    public BiFunction<Ops, Pair<Placeholder<? extends TNumber>, Operand<TNumber>>, Operand<TNumber>> loss() {
        return (ops, pair) -> {
            FrameworkOps fwOps = FrameworkOps.create(ops);
            // The placeholder's element type is only known as "some TNumber", while the cross entropy op needs
            // labels and logits of the same type parameter, so the placeholder is viewed as TNumber here.
            @SuppressWarnings("unchecked")
            Placeholder<TNumber> labels = (Placeholder<TNumber>) pair.getA();
            return ops.math.mean(fwOps.nn.sigmoidCrossEntropyWithLogits(labels, pair.getB()), ops.constant(0));
        };
    }

    /**
     * Returns a function that applies an element-wise sigmoid to the model output.
     *
     * @param <V> The numeric tensor type.
     * @return The output transformation.
     */
    @Override
    public <V extends TNumber> BiFunction<Ops, Operand<V>, Op> outputTransformFunction() {
        return (ops, logits) -> ops.math.sigmoid(logits);
    }

    /**
     * Converts a single-row score tensor into a {@link Prediction}.
     *
     * @param tensor           A {@code [1, numLabels]} {@link TFloat16} or {@link TFloat32} tensor.
     * @param outputIDInfo     The output id mapping.
     * @param numValidFeatures The number of features used to produce the prediction.
     * @param example          The example that was predicted.
     * @return The prediction.
     * @throws IllegalArgumentException If the tensor has the wrong shape, type or batch size.
     */
    @Override
    public Prediction<MultiLabel> convertToPrediction(Tensor tensor, ImmutableOutputInfo<MultiLabel> outputIDInfo, int numValidFeatures, Example<MultiLabel> example) {
        FloatNdArray row = getSingleRow(getBatchPredictions(tensor, outputIDInfo));
        return generatePrediction(row, outputIDInfo, numValidFeatures, example);
    }

    /**
     * Converts a single-row score tensor into a {@link MultiLabel} containing the labels above {@link #THRESHOLD}.
     *
     * @param tensor       A {@code [1, numLabels]} {@link TFloat16} or {@link TFloat32} tensor.
     * @param outputIDInfo The output id mapping.
     * @return The predicted multi-label.
     * @throws IllegalArgumentException If the tensor has the wrong shape, type or batch size.
     */
    @Override
    public MultiLabel convertToOutput(Tensor tensor, ImmutableOutputInfo<MultiLabel> outputIDInfo) {
        return generateMultiLabel(getSingleRow(getBatchPredictions(tensor, outputIDInfo)), outputIDInfo);
    }

    /**
     * Converts a batch of score rows into one {@link Prediction} per row.
     *
     * @param tensor           A {@code [batch, numLabels]} {@link TFloat16} or {@link TFloat32} tensor.
     * @param outputIDInfo     The output id mapping.
     * @param numValidFeatures The number of features used for each example, one entry per row.
     * @param examples         The examples that were predicted, one per row.
     * @return The predictions in row order.
     * @throws IllegalArgumentException If the tensor is malformed or the row, feature-count and example counts disagree.
     */
    @Override
    public List<Prediction<MultiLabel>> convertToBatchPrediction(Tensor tensor, ImmutableOutputInfo<MultiLabel> outputIDInfo, int[] numValidFeatures, List<Example<MultiLabel>> examples) {
        FloatNdArray batchPredictions = getBatchPredictions(tensor, outputIDInfo);
        if (examples.size() != numValidFeatures.length) {
            throw new IllegalArgumentException("Number of examples (" + examples.size() + ") does not match the number of feature counts (" + numValidFeatures.length + ")");
        }
        // Compare as long so an oversized batch dimension cannot wrap around to a matching int.
        long batchSize = batchPredictions.shape().asArray()[0];
        if (batchSize != numValidFeatures.length) {
            throw new IllegalArgumentException("Invalid number of predictions received from Tensorflow, expected " + numValidFeatures.length + ", received " + batchSize);
        }
        List<Prediction<MultiLabel>> predictions = new ArrayList<>(numValidFeatures.length);
        for (int row = 0; row < numValidFeatures.length; row++) {
            FloatNdArray scores = batchPredictions.slice(Indices.at(row), Indices.all());
            predictions.add(generatePrediction(scores, outputIDInfo, numValidFeatures[row], examples.get(row)));
        }
        return predictions;
    }

    /**
     * Converts a batch of score rows into one {@link MultiLabel} per row.
     *
     * @param tensor       A {@code [batch, numLabels]} {@link TFloat16} or {@link TFloat32} tensor.
     * @param outputIDInfo The output id mapping.
     * @return The predicted multi-labels in row order.
     * @throws IllegalArgumentException If the tensor has the wrong shape or type.
     */
    @Override
    public List<MultiLabel> convertToBatchOutput(Tensor tensor, ImmutableOutputInfo<MultiLabel> outputIDInfo) {
        FloatNdArray batchPredictions = getBatchPredictions(tensor, outputIDInfo);
        long batchSize = batchPredictions.shape().asArray()[0];
        List<MultiLabel> outputs = new ArrayList<>();
        for (long row = 0; row < batchSize; row++) {
            outputs.add(generateMultiLabel(batchPredictions.slice(Indices.at(row), Indices.all()), outputIDInfo));
        }
        return outputs;
    }

    /**
     * Encodes a single {@link MultiLabel} as a {@code [1, numLabels]} {@link TFloat32} tensor.
     *
     * @param multiLabel   The output to encode.
     * @param outputIDInfo The output id mapping.
     * @return A new tensor, which the caller must close.
     */
    @Override
    public Tensor convertToTensor(MultiLabel multiLabel, ImmutableOutputInfo<MultiLabel> outputIDInfo) {
        SparseVector encoded = multiLabel.convertToSparseVector(outputIDInfo);
        int numLabels = outputIDInfo.size();
        TFloat32 tensor = TFloat32.tensorOf(Shape.of(1, numLabels));
        try {
            writeRow(tensor, 0, encoded, numLabels);
        } catch (RuntimeException e) {
            // Release the native buffer; the caller never receives the tensor.
            tensor.close();
            throw e;
        }
        return tensor;
    }

    /**
     * Encodes the outputs of a list of examples as a {@code [examples.size(), numLabels]} {@link TFloat32} tensor.
     *
     * @param examples     The examples whose outputs are encoded, one per row.
     * @param outputIDInfo The output id mapping.
     * @return A new tensor, which the caller must close.
     */
    @Override
    public Tensor convertToTensor(List<Example<MultiLabel>> examples, ImmutableOutputInfo<MultiLabel> outputIDInfo) {
        int numLabels = outputIDInfo.size();
        TFloat32 tensor = TFloat32.tensorOf(Shape.of(examples.size(), numLabels));
        try {
            long row = 0;
            for (Example<MultiLabel> example : examples) {
                writeRow(tensor, row, example.getOutput().convertToSparseVector(outputIDInfo), numLabels);
                row++;
            }
        } catch (RuntimeException e) {
            // Release the native buffer; the caller never receives the tensor.
            tensor.close();
            throw e;
        }
        return tensor;
    }

    /**
     * This converter produces probabilities (sigmoid outputs).
     *
     * @return true.
     */
    @Override
    public boolean generatesProbabilities() {
        return true;
    }

    @Override
    public String toString() {
        return "MultiLabelConverter()";
    }

    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "OutputConverter");
    }

    @Override
    public Class<MultiLabel> getTypeWitness() {
        return MultiLabel.class;
    }

    /**
     * Checks that a {@code [batch, numLabels]} tensor has a batch of exactly one and returns that row.
     */
    private static FloatNdArray getSingleRow(FloatNdArray batchPredictions) {
        long batchSize = batchPredictions.shape().asArray()[0];
        if (batchSize != 1) {
            throw new IllegalArgumentException("Supplied tensor must contain exactly one result, batchSize = " + batchSize);
        }
        return batchPredictions.slice(Indices.at(0), Indices.all());
    }

    /**
     * Validates a row of scores is rank 1 and int-addressable, returning its length.
     */
    private static int checkedRowLength(FloatNdArray row) {
        long[] shape = row.shape().asArray();
        if (shape.length != 1) {
            throw new IllegalArgumentException("Failed to get scalar predictions. Found " + Arrays.toString(shape));
        }
        if (shape[0] > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("More than Integer.MAX_VALUE predictions. Found " + shape[0]);
        }
        return (int) shape[0];
    }

    /**
     * Builds a {@link Prediction} from one row of scores, recording every label's score.
     */
    private static Prediction<MultiLabel> generatePrediction(FloatNdArray row, ImmutableOutputInfo<MultiLabel> outputIDInfo, int numValidFeatures, Example<MultiLabel> example) {
        int numLabels = checkedRowLength(row);
        Map<String, MultiLabel> scores = new HashMap<>(outputIDInfo.size());
        Set<Label> predicted = new HashSet<>();
        for (int id = 0; id < numLabels; id++) {
            String name = outputIDInfo.getOutput(id).getLabelString();
            double score = row.getFloat(id);
            Label label = new Label(name, score);
            if (score > THRESHOLD) {
                predicted.add(label);
            }
            scores.put(name, new MultiLabel(label));
        }
        return new Prediction<>(new MultiLabel(predicted), scores, numValidFeatures, example, true);
    }

    /**
     * Builds a {@link MultiLabel} containing the labels whose score exceeds {@link #THRESHOLD}.
     */
    private static MultiLabel generateMultiLabel(FloatNdArray row, ImmutableOutputInfo<MultiLabel> outputIDInfo) {
        int numLabels = checkedRowLength(row);
        Set<Label> predicted = new HashSet<>();
        for (int id = 0; id < numLabels; id++) {
            double score = row.getFloat(id);
            if (score > THRESHOLD) {
                predicted.add(new Label(outputIDInfo.getOutput(id).getLabelString(), score));
            }
        }
        return new MultiLabel(predicted);
    }

    /**
     * Validates that the tensor is rank 2 with one column per output and a float element type.
     */
    private static FloatNdArray getBatchPredictions(Tensor tensor, ImmutableOutputInfo<MultiLabel> outputIDInfo) {
        long[] shape = tensor.shape().asArray();
        if (shape.length != 2) {
            throw new IllegalArgumentException("Supplied tensor has the wrong number of dimensions, shape = " + Arrays.toString(shape));
        }
        // Compare as long so a huge dimension cannot be truncated into a false match.
        long outputDim = shape[1];
        if (outputDim != outputIDInfo.size()) {
            throw new IllegalArgumentException("Supplied tensor has incorrect number of elements, tensor output dimension: " + outputDim + ", outputInfo dimension: " + outputIDInfo.size());
        }
        if (tensor instanceof TFloat16) {
            return (TFloat16) tensor;
        }
        if (tensor instanceof TFloat32) {
            return (TFloat32) tensor;
        }
        throw new IllegalArgumentException("Tensor is not a probability distribution. Found type " + tensor.getClass().getName());
    }

    /**
     * Writes one example's encoded labels into the given row, zeroing the row first.
     */
    private static void writeRow(TFloat32 tensor, long row, SparseVector labels, int numLabels) {
        // Zero explicitly rather than relying on the freshly allocated buffer being zero-initialised
        // (NOTE(review): TF allocation is not assumed to zero memory).
        for (int id = 0; id < numLabels; id++) {
            tensor.setFloat(0.0f, row, id);
        }
        VectorIterator it = labels.iterator();
        while (it.hasNext()) {
            VectorTuple tuple = it.next();
            tensor.setFloat((float) tuple.value, row, tuple.index);
        }
    }
}
