package org.tribuo.interop.oci;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.tribuo.Example;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.interop.oci.protos.OCIMultiLabelConverterProto;
import org.tribuo.interop.oci.protos.OCIOutputConverterProto;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.multilabel.MultiLabel;
import org.tribuo.protos.ProtoSerializableClass;
import org.tribuo.protos.ProtoUtil;

/**
 * Converts dense score vectors and matrices into {@link MultiLabel} predictions.
 * <p>
 * Every output dimension is turned into a scored {@link Label}. A label is placed in the
 * predicted set when its score is strictly greater than the configured threshold, and the
 * score of every label (predicted or not) is recorded in the prediction's score map.
 */
@ProtoSerializableClass(serializedDataClass = OCIMultiLabelConverterProto.class, version = 0)
public final class OCIMultiLabelConverter implements OCIOutputConverter<MultiLabel> {
    private static final long serialVersionUID = 1;

    /** Protobuf serialization version of this class. */
    public static final int CURRENT_VERSION = 0;

    /** The threshold used when none is supplied. */
    public static final double DEFAULT_THRESHOLD = 0.5d;

    @Config(mandatory = true, description = "Does this converter produce probabilistic outputs.")
    private boolean generatesProbabilities;

    @Config(description = "Threshold for generating a label.")
    private double threshold = DEFAULT_THRESHOLD;

    /**
     * No-argument constructor; leaves the threshold at {@link #DEFAULT_THRESHOLD}.
     * Presumably used by the configuration system, which then populates the
     * {@code @Config} fields and calls {@link #postConfig()} - confirm against OLCUT.
     */
    private OCIMultiLabelConverter() {
    }

    /**
     * Constructs a converter with the supplied threshold.
     *
     * @param threshold The score a label must strictly exceed to be emitted.
     * @param generatesProbabilities Does this converter produce probabilistic outputs.
     * @throws IllegalArgumentException if probabilities are generated and the threshold
     *         is below 0 or above 1.
     */
    public OCIMultiLabelConverter(double threshold, boolean generatesProbabilities) {
        this.threshold = threshold;
        this.generatesProbabilities = generatesProbabilities;
        if (hasInvalidProbabilityThreshold()) {
            throw new IllegalArgumentException("Threshold must be between 0 and 1 to generate probabilities, found " + threshold);
        }
    }

    /**
     * Validates the configured fields, applying the same range check as the public constructor.
     *
     * @throws PropertyException if probabilities are generated and the threshold
     *         is below 0 or above 1.
     */
    public void postConfig() {
        if (hasInvalidProbabilityThreshold()) {
            throw new PropertyException("", "threshold", "Threshold must be between 0 and 1 to generate probabilities, found " + this.threshold);
        }
    }

    /**
     * Deserialization factory.
     *
     * @param version The serialized object version.
     * @param className The class name (not used by this implementation).
     * @param message The serialized data, unpacked as an {@link OCIMultiLabelConverterProto}.
     * @return The deserialized converter.
     * @throws InvalidProtocolBufferException if the message cannot be unpacked.
     * @throws IllegalArgumentException if the version is not supported, or the stored
     *         values fail the constructor's validation.
     */
    public static OCIMultiLabelConverter deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        OCIMultiLabelConverterProto proto = message.unpack(OCIMultiLabelConverterProto.class);
        return new OCIMultiLabelConverter(proto.getThreshold(), proto.getGeneratesProbabilities());
    }

    /**
     * Converts a single vector of scores into a prediction.
     *
     * @param scores One score per output dimension.
     * @param numValidFeature Passed through unchanged to the {@link Prediction}.
     * @param example The example this prediction is for.
     * @param outputIDInfo Maps each output index to its label.
     * @return The prediction holding the thresholded label set and all label scores.
     * @throws IllegalStateException if the number of scores differs from the number of outputs.
     */
    @Override // org.tribuo.interop.oci.OCIOutputConverter
    public Prediction<MultiLabel> convertOutput(DenseVector scores, int numValidFeature, Example<MultiLabel> example, ImmutableOutputInfo<MultiLabel> outputIDInfo) {
        checkOutputCount(scores.size(), outputIDInfo.size());
        Map<String, MultiLabel> labelScores = new HashMap<>(outputIDInfo.size());
        Set<Label> predicted = new HashSet<>();
        for (int j = 0; j < scores.size(); j++) {
            recordScore(outputIDInfo.getOutput(j).getLabelString(), scores.get(j), predicted, labelScores);
        }
        return new Prediction<>(new MultiLabel(predicted), labelScores, numValidFeature, example, this.generatesProbabilities);
    }

    /**
     * Converts a matrix of scores into one prediction per row.
     *
     * @param scores One row per example, one column per output dimension.
     * @param numValidFeatures Per-example values passed through unchanged to each {@link Prediction};
     *        indexed by row.
     * @param examples The examples, in row order.
     * @param outputIDInfo Maps each output index to its label.
     * @return The predictions, in row order.
     * @throws IllegalStateException if the row count differs from the number of examples, or the
     *         column count differs from the number of outputs.
     */
    @Override // org.tribuo.interop.oci.OCIOutputConverter
    public List<Prediction<MultiLabel>> convertOutput(DenseMatrix scores, int[] numValidFeatures, List<Example<MultiLabel>> examples, ImmutableOutputInfo<MultiLabel> outputIDInfo) {
        int numRows = scores.getDimension1Size();
        int numCols = scores.getDimension2Size();
        if (numRows != examples.size()) {
            throw new IllegalStateException("Expected one prediction per example, received " + numRows + " predictions when there are " + examples.size() + " examples.");
        }
        checkOutputCount(numCols, outputIDInfo.size());
        List<Prediction<MultiLabel>> predictions = new ArrayList<>(numRows);
        for (int i = 0; i < numRows; i++) {
            Map<String, MultiLabel> labelScores = new HashMap<>(outputIDInfo.size());
            Set<Label> predicted = new HashSet<>();
            for (int j = 0; j < numCols; j++) {
                recordScore(outputIDInfo.getOutput(j).getLabelString(), scores.get(i, j), predicted, labelScores);
            }
            predictions.add(new Prediction<>(new MultiLabel(predicted), labelScores, numValidFeatures[i], examples.get(i), this.generatesProbabilities));
        }
        return predictions;
    }

    @Override // org.tribuo.interop.oci.OCIOutputConverter
    public boolean generatesProbabilities() {
        return this.generatesProbabilities;
    }

    @Override // org.tribuo.interop.oci.OCIOutputConverter
    public Class<MultiLabel> getTypeWitness() {
        return MultiLabel.class;
    }

    /**
     * Returns the threshold a score must strictly exceed for its label to be emitted.
     *
     * @return The threshold.
     */
    public double getThreshold() {
        return this.threshold;
    }

    @Override
    public String toString() {
        // Report both fields that participate in equals/hashCode.
        return "OCIMultiLabelConverter(generatesProbabilities=" + this.generatesProbabilities + ",threshold=" + this.threshold + ")";
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        OCIMultiLabelConverter other = (OCIMultiLabelConverter) obj;
        return this.generatesProbabilities == other.generatesProbabilities
                && Double.compare(other.threshold, this.threshold) == 0;
    }

    @Override
    public int hashCode() {
        return Objects.hash(this.generatesProbabilities, this.threshold);
    }

    /**
     * Serializes this converter via {@link ProtoUtil#serialize}.
     * NOTE(review): the {@code m9} prefix looks like a decompiler rename of {@code serialize};
     * the name is kept as found so existing references are unaffected - confirm.
     *
     * @return The serialized form.
     */
    /* renamed from: serialize, reason: merged with bridge method [inline-methods] */
    public OCIOutputConverterProto m9serialize() {
        return ProtoUtil.serialize(this);
    }

    /**
     * Builds the provenance for this converter.
     * NOTE(review): the {@code m10} prefix looks like a decompiler rename of {@code getProvenance};
     * the name is kept as found so existing references are unaffected - confirm.
     *
     * @return The provenance.
     */
    /* renamed from: getProvenance, reason: merged with bridge method [inline-methods] */
    public ConfiguredObjectProvenance m10getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "OCIOutputConverter");
    }

    /** Returns true when probabilities are generated but the threshold lies outside [0, 1]. */
    private boolean hasInvalidProbabilityThreshold() {
        return this.generatesProbabilities && (this.threshold < 0.0d || this.threshold > 1.0d);
    }

    /** Records one label's score, adding it to the predicted set when it strictly exceeds the threshold. */
    private void recordScore(String labelName, double score, Set<Label> predicted, Map<String, MultiLabel> labelScores) {
        Label label = new Label(labelName, score);
        if (score > this.threshold) {
            predicted.add(label);
        }
        labelScores.put(labelName, new MultiLabel(label));
    }

    /** Throws if the number of received scores does not match the number of known outputs. */
    private static void checkOutputCount(int received, int expected) {
        if (received != expected) {
            throw new IllegalStateException("Expected scores for each output, received " + received + " when there are " + expected + " outputs");
        }
    }
}
