package org.tribuo.interop.onnx;

import ai.onnxruntime.OnnxJavaType;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtException;
import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import java.util.Arrays;
import java.util.List;
import java.util.logging.Logger;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.classification.Label;
import org.tribuo.protos.ProtoSerializableClass;

/**
 * A {@link LabelTransformer} which converts a tensor of pairwise decision scores
 * into per-label vote counts.
 * <p>
 * Each row of the score tensor holds one score per unordered label pair
 * {@code (a, b)} with {@code a < b}, enumerated with {@code a} as the outer
 * index and {@code b} as the inner index. A strictly positive score adds one
 * vote to label {@code a}, any other score adds one vote to label {@code b}.
 * <p>
 * This class never produces probabilities; {@link #postConfig()} rejects a
 * configuration which sets {@code generatesProbabilities} to true.
 */
@ProtoSerializableClass(version = 0)
/* loaded from: input_file:org/tribuo/interop/onnx/LabelOneVOneTransformer.class */
public final class LabelOneVOneTransformer extends LabelTransformer {
    private static final long serialVersionUID = 1L;
    private static final Logger logger = Logger.getLogger(LabelTransformer.class.getName());

    /**
     * Protobuf serialization version supported by this class.
     */
    public static final int CURRENT_VERSION = 0;

    /**
     * Constructs a transformer with {@code generatesProbabilities} set to false.
     */
    public LabelOneVOneTransformer() {
        super(false);
    }

    /**
     * Validates the configured state.
     *
     * @throws PropertyException if {@code generatesProbabilities} is true.
     */
    public void postConfig() {
        if (this.generatesProbabilities) {
            throw new PropertyException("", "generatesProbabilities", "generatesProbabilities must not be set to true for this class.");
        }
    }

    /**
     * Deserialization factory.
     *
     * @param version   The serialized object version, must equal {@link #CURRENT_VERSION}.
     * @param className The class name (not inspected by this method).
     * @param message   The serialized data, which must carry an empty payload.
     * @return A new transformer.
     * @throws IllegalArgumentException if the version is unsupported or the payload is not empty.
     */
    public static LabelOneVOneTransformer deserializeFromProto(int version, String className, Any message) {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        // Compare by content: an empty payload is not guaranteed to be the
        // ByteString.EMPTY singleton, so reference equality is too strict.
        if (!message.getValue().isEmpty()) {
            throw new IllegalArgumentException("Invalid proto");
        }
        return new LabelOneVOneTransformer();
    }

    /**
     * Converts the ONNX outputs into a {@code [batchSize][numLabels]} array of vote counts.
     * <p>
     * Accepts either a single float score tensor, or a pair of tensors where the
     * first is a rank 1 tensor whose length matches the batch size and the second
     * is the float score tensor.
     *
     * @param inputs       The values produced by the ONNX model.
     * @param outputIDInfo The output domain, used for the number of labels.
     * @return The vote counts for each batch element and label index.
     * @throws IllegalArgumentException if the number, type or shape of the inputs is unexpected.
     * @throws IllegalStateException    if the tensor value cannot be read.
     */
    @Override // org.tribuo.interop.onnx.LabelTransformer
    protected float[][] getBatchPredictions(List<OnnxValue> inputs, ImmutableOutputInfo<Label> outputIDInfo) {
        int numLabels = outputIDInfo.size();
        // Computed in long so the product cannot overflow before the comparison
        // with the (long) tensor dimension.
        long numPairs = ((long) numLabels * (numLabels - 1)) / 2;
        try {
            if (inputs.size() == 1) {
                OnnxValue scoresValue = inputs.get(0);
                if (!(scoresValue instanceof OnnxTensor)) {
                    throw new IllegalArgumentException("Expected the first element to be a float OnnxTensor, found " + scoresValue);
                }
                OnnxTensor scoresTensor = (OnnxTensor) scoresValue;
                if (scoresTensor.getInfo().type != OnnxJavaType.FLOAT) {
                    throw new IllegalArgumentException("Expected the first element to be a float OnnxTensor, found " + scoresValue);
                }
                long[] shape = scoresTensor.getInfo().getShape();
                if (shape.length != 2 || shape[1] != numPairs) {
                    throw new IllegalArgumentException("Invalid shape for the score tensor, expected shape [batchSize,(numOutputs*(numOutputs-1))/2], found " + Arrays.toString(shape));
                }
                return countVotes((float[][]) scoresTensor.getValue(), (int) shape[0], numLabels);
            }
            if (inputs.size() != 2) {
                throw new IllegalArgumentException("Unexpected number of OnnxValues returned, expected 2, received " + inputs.size());
            }
            OnnxValue labelsValue = inputs.get(0);
            OnnxValue scoresValue = inputs.get(1);
            if (!(labelsValue instanceof OnnxTensor) || !(scoresValue instanceof OnnxTensor)) {
                // Report whichever element actually failed the type check.
                OnnxValue offender = (labelsValue instanceof OnnxTensor) ? scoresValue : labelsValue;
                throw new IllegalArgumentException("Expected an OnnxTensor, received a " + offender.getInfo().toString());
            }
            OnnxTensor labelsTensor = (OnnxTensor) labelsValue;
            OnnxTensor scoresTensor = (OnnxTensor) scoresValue;
            if (scoresTensor.getInfo().type != OnnxJavaType.FLOAT) {
                throw new IllegalArgumentException("Expected the second element to be a float OnnxTensor, found " + scoresValue);
            }
            long[] scoresShape = scoresTensor.getInfo().getShape();
            // NOTE(review): a width of 2 is accepted regardless of the number of
            // labels; with more than two labels that width is too small for the
            // pair enumeration below - confirm which models emit this layout.
            if (scoresShape.length != 2 || (scoresShape[1] != 2 && scoresShape[1] != numPairs)) {
                throw new IllegalArgumentException("Invalid shape for the score tensor, expected shape [batchSize,(numOutputs*(numOutputs-1))/2], found " + Arrays.toString(scoresShape));
            }
            long[] labelsShape = labelsTensor.getInfo().getShape();
            if (labelsShape.length != 1 || labelsShape[0] != scoresShape[0]) {
                throw new IllegalArgumentException("Invalid shape for labels, did not match the size of the scores, found labels.shape " + Arrays.toString(labelsShape) + ", and scores.shape " + Arrays.toString(scoresShape));
            }
            return countVotes((float[][]) scoresTensor.getValue(), (int) scoresShape[0], numLabels);
        } catch (OrtException e) {
            throw new IllegalStateException("Failed to read a value out of the onnx result.", e);
        }
    }

    /**
     * Tallies one vote per label pair for every batch element.
     *
     * @param scores    The pairwise scores, indexed {@code [batchElement][pairIndex]}.
     * @param batchSize The number of batch elements to process.
     * @param numLabels The number of labels.
     * @return A {@code [batchSize][numLabels]} array of vote counts.
     */
    private static float[][] countVotes(float[][] scores, int batchSize, int numLabels) {
        float[][] votes = new float[batchSize][numLabels];
        for (int row = 0; row < batchSize; row++) {
            float[] rowVotes = votes[row];
            // Pairs are enumerated (0,1), (0,2), ..., (1,2), ... matching the score columns.
            int pair = 0;
            for (int first = 0; first < numLabels; first++) {
                for (int second = first + 1; second < numLabels; second++) {
                    if (scores[row][pair] > 0.0f) {
                        rowVotes[first] += 1.0f;
                    } else {
                        rowVotes[second] += 1.0f;
                    }
                    pair++;
                }
            }
        }
        return votes;
    }

    @Override // org.tribuo.interop.onnx.LabelTransformer
    public String toString() {
        return "LabelOneVOneTransformer(generatesProbabilities=" + this.generatesProbabilities + ")";
    }
}
