package org.tribuo.interop.oci;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.IntToDoubleFunction;
import org.tribuo.Example;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.interop.oci.protos.OCILabelConverterProto;
import org.tribuo.interop.oci.protos.OCIOutputConverterProto;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.protos.ProtoSerializableClass;
import org.tribuo.protos.ProtoSerializableField;
import org.tribuo.protos.ProtoUtil;

@ProtoSerializableClass(serializedDataClass = OCILabelConverterProto.class, version = 0)
/**
 * Converts numeric model outputs ({@link DenseVector} or {@link DenseMatrix}) into
 * {@link Label} predictions.
 * <p>
 * Two output layouts are supported:
 * <ul>
 *     <li>A single value per example, which must be an integral class index into the
 *     supplied {@link ImmutableOutputInfo}.</li>
 *     <li>One score per output label. The highest scoring label (first one wins on ties)
 *     becomes the predicted label, and all scores are recorded on the prediction.</li>
 * </ul>
 */
public final class OCILabelConverter implements OCIOutputConverter<Label> {
    private static final long serialVersionUID = 1L;

    /**
     * The protobuf serialization version of this class.
     */
    public static final int CURRENT_VERSION = 0;

    @ProtoSerializableField
    @Config(mandatory = true, description = "Does this converter produce probabilistic outputs.")
    private boolean generatesProbabilities;

    /**
     * No-arg constructor. Presumably used by the configuration system, which injects
     * {@code generatesProbabilities} reflectively via the {@link Config} annotation.
     */
    private OCILabelConverter() {
    }

    /**
     * Constructs a label converter.
     *
     * @param generatesProbabilities Whether the score-per-label outputs are probabilities.
     */
    public OCILabelConverter(boolean generatesProbabilities) {
        this.generatesProbabilities = generatesProbabilities;
    }

    /**
     * Deserialization factory.
     *
     * @param version The serialized object version.
     * @param className The class name.
     * @param message The serialized data.
     * @return The deserialized converter.
     * @throws InvalidProtocolBufferException If the message is not an {@link OCILabelConverterProto}.
     * @throws IllegalArgumentException If the version is not supported.
     */
    public static OCILabelConverter deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        OCILabelConverterProto proto = message.unpack(OCILabelConverterProto.class);
        return new OCILabelConverter(proto.getGeneratesProbabilities());
    }

    /**
     * Converts a single example's model output into a prediction.
     *
     * @param scores Either a length-one vector holding a class index, or one score per output.
     * @param numValidFeature The number of valid features in the example.
     * @param example The example.
     * @param outputIDInfo The output domain mapping ids to labels.
     * @return The prediction.
     * @throws IllegalStateException If the output is not a valid class index, or the number
     *                               of scores does not match the number of outputs.
     */
    @Override
    public Prediction<Label> convertOutput(DenseVector scores, int numValidFeature, Example<Label> example, ImmutableOutputInfo<Label> outputIDInfo) {
        if (scores.size() == 1) {
            Label output = indexToLabel(scores.get(0), -1, outputIDInfo);
            return new Prediction<>(output, numValidFeature, example);
        }
        if (scores.size() != outputIDInfo.size()) {
            throw new IllegalStateException("Expected scores for each output, received " + scores.size() + " when there are " + outputIDInfo.size() + " outputs");
        }
        return scoresToPrediction(i -> scores.get(i), outputIDInfo, numValidFeature, example);
    }

    /**
     * Converts a batch of model outputs into predictions, one row per example.
     *
     * @param scores Either a single column of class indices, or one column per output.
     * @param numValidFeatures The number of valid features in each example.
     * @param examples The examples, in row order.
     * @param outputIDInfo The output domain mapping ids to labels.
     * @return The predictions, in row order.
     * @throws IllegalStateException If the row count differs from the example count, an entry
     *                               is not a valid class index, or the column count does not
     *                               match the number of outputs.
     */
    @Override
    public List<Prediction<Label>> convertOutput(DenseMatrix scores, int[] numValidFeatures, List<Example<Label>> examples, ImmutableOutputInfo<Label> outputIDInfo) {
        int numRows = scores.getDimension1Size();
        if (numRows != examples.size()) {
            throw new IllegalStateException("Expected one prediction per example, received " + numRows + " predictions when there are " + examples.size() + " examples.");
        }
        int numColumns = scores.getDimension2Size();
        List<Prediction<Label>> predictions = new ArrayList<>(numRows);
        if (numColumns == 1) {
            for (int i = 0; i < numRows; i++) {
                Label output = indexToLabel(scores.get(i, 0), i, outputIDInfo);
                predictions.add(new Prediction<>(output, numValidFeatures[i], examples.get(i)));
            }
        } else {
            if (numColumns != outputIDInfo.size()) {
                throw new IllegalStateException("Expected scores for each output, received " + numColumns + " when there are " + outputIDInfo.size() + " outputs");
            }
            for (int i = 0; i < numRows; i++) {
                final int row = i;
                predictions.add(scoresToPrediction(j -> scores.get(row, j), outputIDInfo, numValidFeatures[i], examples.get(i)));
            }
        }
        return predictions;
    }

    /**
     * Maps a predicted class index onto its {@link Label}.
     *
     * @param value The predicted value, which must be integral.
     * @param position The row position reported in error messages, or a negative value for none.
     * @param outputIDInfo The output domain mapping ids to labels.
     * @return The label with that id.
     * @throws IllegalStateException If the value is not integral or is not a known id.
     */
    private static Label indexToLabel(double value, int position, ImmutableOutputInfo<Label> outputIDInfo) {
        int index = (int) value;
        // Rejects fractional values and NaN; out-of-int-range values saturate and also fail here.
        if (value != index) {
            throw new IllegalStateException("Expected a class index" + describePosition(position) + ", found " + value);
        }
        Label output = outputIDInfo.getOutput(index);
        if (output == null) {
            throw new IllegalStateException("Expected a class index" + describePosition(position) + " in the range 0 - " + (outputIDInfo.size() - 1) + " received " + index);
        }
        return output;
    }

    /**
     * Builds the position fragment of an error message.
     *
     * @param position The row position, or a negative value for none.
     * @return {@code " at position N"}, or the empty string when there is no position.
     */
    private static String describePosition(int position) {
        return position < 0 ? "" : " at position " + position;
    }

    /**
     * Builds a prediction from one score per output id, choosing the highest score.
     *
     * @param scores Supplies the score for each output id in {@code [0, outputIDInfo.size())}.
     * @param outputIDInfo The output domain mapping ids to labels.
     * @param numValidFeature The number of valid features in the example.
     * @param example The example.
     * @return A prediction carrying the argmax label and the full score map.
     */
    private Prediction<Label> scoresToPrediction(IntToDoubleFunction scores, ImmutableOutputInfo<Label> outputIDInfo, int numValidFeature, Example<Label> example) {
        Label maxLabel = null;
        Map<String, Label> scoreMap = new LinkedHashMap<>();
        for (int i = 0; i < outputIDInfo.size(); i++) {
            String name = outputIDInfo.getOutput(i).getLabel();
            Label scored = new Label(name, scores.applyAsDouble(i));
            scoreMap.put(name, scored);
            // Strict '>' keeps the first label on ties.
            if (maxLabel == null || scored.getScore() > maxLabel.getScore()) {
                maxLabel = scored;
            }
        }
        return new Prediction<>(maxLabel, scoreMap, numValidFeature, example, generatesProbabilities);
    }

    @Override
    public boolean generatesProbabilities() {
        return generatesProbabilities;
    }

    @Override
    public Class<Label> getTypeWitness() {
        return Label.class;
    }

    @Override
    public String toString() {
        return "OCILabelConverter(generatesProbabilities=" + generatesProbabilities + ")";
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        return generatesProbabilities == ((OCILabelConverter) obj).generatesProbabilities;
    }

    @Override
    public int hashCode() {
        return Objects.hash(generatesProbabilities);
    }

    /**
     * Serializes this converter to a protobuf.
     *
     * @return The protobuf representation.
     */
    @Override
    public OCIOutputConverterProto serialize() {
        return ProtoUtil.serialize(this);
    }

    /**
     * Alias retained for callers of the decompiler-generated name.
     *
     * @return The protobuf representation.
     * @deprecated Use {@link #serialize()}.
     */
    @Deprecated
    public OCIOutputConverterProto m0serialize() {
        return serialize();
    }

    /**
     * Gets the provenance of this converter.
     *
     * @return The provenance.
     */
    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "OCIOutputConverter");
    }

    /**
     * Alias retained for callers of the decompiler-generated name.
     *
     * @return The provenance.
     * @deprecated Use {@link #getProvenance()}.
     */
    @Deprecated
    public ConfiguredObjectProvenance m1getProvenance() {
        return getProvenance();
    }
}
