package org.tribuo.classification.libsvm;

import ai.onnx.proto.OnnxMl;
import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.stream.Collectors;
import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.ONNXExportable;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.classification.libsvm.protos.LibSVMClassificationModelProto;
import org.tribuo.common.libsvm.KernelType;
import org.tribuo.common.libsvm.LibSVMModel;
import org.tribuo.common.libsvm.LibSVMTrainer;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.util.Util;
import org.tribuo.util.onnx.ONNXContext;
import org.tribuo.util.onnx.ONNXInitializer;
import org.tribuo.util.onnx.ONNXNode;
import org.tribuo.util.onnx.ONNXOperators;
import org.tribuo.util.onnx.ONNXPlaceholder;
import org.tribuo.util.onnx.ONNXRef;

/**
 * A classification model backed by a single libsvm {@code svm_model}.
 * <p>
 * Prediction scores are libsvm probability estimates when the wrapped model was trained with
 * {@code param.probability == 1}, and one-vs-one vote counts otherwise. Labels that exist in the
 * output domain but do not appear in the libsvm model's {@code label} array are reported with a
 * score of zero.
 * <p>
 * The model supports ONNX export and protobuf serialization.
 */
public class LibSVMClassificationModel extends LibSVMModel<Label> implements ONNXExportable {
    private static final long serialVersionUID = 3;

    /** Protobuf serialization version written by this class and accepted by {@link #deserializeFromProto}. */
    public static final int CURRENT_VERSION = 0;

    /** Zero-scored copies of output-domain labels absent from the libsvm model; empty when every label was observed. */
    private final Set<Label> unobservedLabels;

    /**
     * Constructs a classification model wrapping the first libsvm model in {@code models}.
     * <p>
     * NOTE(review): the decompiler reports this constructor was originally package-private; it is
     * left public here so existing callers keep compiling.
     *
     * @param name        the model name.
     * @param provenance  the model provenance.
     * @param featureIDMap the feature domain.
     * @param labelIDMap  the label domain.
     * @param models      the libsvm models; only element 0 is used. Must be non-empty.
     */
    public LibSVMClassificationModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap,
                                     ImmutableOutputInfo<Label> labelIDMap, List<svm_model> models) {
        super(name, provenance, featureIDMap, labelIDMap, models.get(0).param.probability == 1, models);
        // The libsvm label array holds the output ids of the classes present in the wrapped model
        // (predict() resolves each entry through outputIDInfo.getOutput).
        int[] observedIds = models.get(0).label;
        if (observedIds.length == labelIDMap.size()) {
            this.unobservedLabels = Collections.emptySet();
        } else {
            Map<Integer, Label> unobserved = new HashMap<>();
            for (Pair<Integer, Label> pair : labelIDMap) {
                unobserved.put(pair.getA(), pair.getB());
            }
            // Remove the ids libsvm actually saw. Removing the array positions 0..n-1 instead
            // would misidentify which labels are missing whenever the observed ids are not a prefix.
            for (int id : observedIds) {
                unobserved.remove(id);
            }
            Set<Label> zeroScored = new HashSet<>(unobserved.size());
            for (Label label : unobserved.values()) {
                zeroScored.add(new Label(label.getLabel(), 0.0));
            }
            this.unobservedLabels = Collections.unmodifiableSet(zeroScored);
        }
    }

    /**
     * Deserializes a model from its protobuf representation.
     *
     * @param version   the serialized version; must be between 0 and {@link #CURRENT_VERSION}.
     * @param className the class name stored in the protobuf (unused here).
     * @param message   the packed {@code LibSVMClassificationModelProto}.
     * @return the deserialized model.
     * @throws InvalidProtocolBufferException if the message cannot be unpacked.
     * @throws IllegalArgumentException       if the version is unsupported.
     * @throws IllegalStateException          if the stored output domain is empty or not a label domain.
     */
    public static LibSVMClassificationModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        LibSVMClassificationModelProto proto = message.unpack(LibSVMClassificationModelProto.class);
        ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());
        // An empty domain returns null here; treat it as invalid rather than failing with an NPE.
        Object firstOutput = carrier.outputDomain().getOutput(0);
        if (firstOutput == null || !firstOutput.getClass().equals(Label.class)) {
            throw new IllegalStateException("Invalid protobuf, output domain is not a label domain, found " + carrier.outputDomain().getClass());
        }
        @SuppressWarnings("unchecked") // guarded by the Label class check above
        ImmutableOutputInfo<Label> outputDomain = (ImmutableOutputInfo<Label>) carrier.outputDomain();
        return new LibSVMClassificationModel(carrier.name(), carrier.provenance(), carrier.featureDomain(),
                outputDomain, Collections.singletonList(deserializeModel(proto.getModel())));
    }

    /**
     * Returns the number of support vectors in the wrapped libsvm model.
     *
     * @return the length of the model's {@code SV} array.
     */
    public int getNumberOfSupportVectors() {
        return ((svm_model) this.models.get(0)).SV.length;
    }

    /**
     * Predicts the label for an example.
     *
     * @param example the example to score.
     * @return a prediction whose scores are probabilities or vote counts, plus zero-scored unobserved labels.
     * @throws IllegalArgumentException if the example contains no features known to this model.
     */
    public Prediction<Label> predict(Example<Label> example) {
        svm_model model = (svm_model) this.models.get(0);
        svm_node[] features = LibSVMTrainer.exampleToNodes(example, this.featureIDMap, (List<svm_node>) null);
        if (features.length == 0) {
            throw new IllegalArgumentException("No features found in Example " + example.toString());
        }
        int[] labelIds = model.label;
        double[] scores = new double[labelIds.length];
        if (this.generatesProbabilities) {
            svm.svm_predict_probability(model, features, scores);
        } else {
            // One decision value per class pair (i, j), i < j. A positive value is a vote for class i;
            // anything else is a vote for class j.
            double[] decisionValues = new double[(labelIds.length * (labelIds.length - 1)) / 2];
            svm.svm_predict_values(model, features, decisionValues);
            int pair = 0;
            for (int i = 0; i < labelIds.length; i++) {
                for (int j = i + 1; j < labelIds.length; j++) {
                    if (decisionValues[pair] > 0.0) {
                        scores[i] += 1.0;
                    } else {
                        scores[j] += 1.0;
                    }
                    pair++;
                }
            }
        }
        Label best = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        Map<String, Label> labelScores = new LinkedHashMap<>();
        for (int k = 0; k < scores.length; k++) {
            String labelName = this.outputIDInfo.getOutput(labelIds[k]).getLabel();
            Label scored = new Label(labelName, scores[k]);
            labelScores.put(labelName, scored);
            // Strict '>' keeps the first label on ties.
            if (scored.getScore() > bestScore) {
                bestScore = scored.getScore();
                best = scored;
            }
        }
        for (Label unobserved : this.unobservedLabels) {
            labelScores.put(unobserved.getLabel(), unobserved);
        }
        return new Prediction<>(best, labelScores, features.length, example, this.generatesProbabilities);
    }

    /**
     * Creates a copy of this model with a new name and provenance, deep-copying the libsvm model.
     * <p>
     * NOTE(review): the decompiler renamed this from {@code copy} when merging a bridge method.
     *
     * @param newName       the name of the copy.
     * @param newProvenance the provenance of the copy.
     * @return the copied model.
     */
    public LibSVMClassificationModel m0copy(String newName, ModelProvenance newProvenance) {
        return new LibSVMClassificationModel(newName, newProvenance, this.featureIDMap, this.outputIDInfo,
                Collections.singletonList(LibSVMModel.copyModel((svm_model) this.models.get(0))));
    }

    /**
     * Exports this model as an ONNX model with one float input per feature and one float output per label.
     *
     * @param domain       the domain string passed to {@code ONNXExportable.buildModel}.
     * @param modelVersion the model version passed to {@code ONNXExportable.buildModel}.
     * @return the ONNX model proto.
     */
    public OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion) {
        ONNXContext context = new ONNXContext();
        ONNXPlaceholder input = context.floatInput(this.featureIDMap.size());
        ONNXPlaceholder output = context.floatOutput(this.outputIDInfo.size());
        context.setName("Classification-LibSVM");
        writeONNXGraph(input).assignTo(output);
        return ONNXExportable.buildModel(context, domain, modelVersion, this);
    }

    /**
     * Writes this model's computation into the ONNX graph attached to {@code input}.
     * <p>
     * The output has one column per label, ordered by output id. It holds probabilities, or vote
     * counts when the model does not generate probabilities.
     *
     * @param input the node holding the dense feature vector.
     * @return the node holding the per-label scores.
     * @throws IllegalStateException if some labels in the output domain are absent from the libsvm
     *                               model, because the exported graph could not produce a column for them.
     */
    public ONNXNode writeONNXGraph(ONNXRef<?> input) {
        ONNXContext onnx = input.onnxContext();
        svm_model model = (svm_model) this.models.get(0);
        int numClasses = model.label.length;
        // The exporter declares outputIDInfo.size() output columns but can only emit one per libsvm class.
        if (numClasses != this.outputIDInfo.size()) {
            throw new IllegalStateException("ONNX export requires every label to be present in the libsvm model, found "
                    + numClasses + " of " + this.outputIDInfo.size() + " labels");
        }
        int numPairs = (numClasses * (numClasses - 1)) / 2;
        int numFeatures = this.featureIDMap.size();

        Map<String, Object> attributes = new HashMap<>();
        attributes.put("classlabels_ints", model.label);

        // sv_coef flattened row-major: (nr_class - 1) rows of l coefficients each.
        float[] coefficients = new float[model.l * (model.nr_class - 1)];
        for (int row = 0; row < model.nr_class - 1; row++) {
            for (int sv = 0; sv < model.l; sv++) {
                coefficients[(row * model.l) + sv] = (float) model.sv_coef[row][sv];
            }
        }
        attributes.put("coefficients", coefficients);
        attributes.put("kernel_params", new float[]{(float) model.param.gamma, (float) model.param.coef0, model.param.degree});
        attributes.put("kernel_type", KernelType.getKernelType(model.param.kernel_type).name());

        // The ONNX "rho" attribute receives libsvm's rho values negated.
        float[] rho = new float[model.rho.length];
        for (int k = 0; k < rho.length; k++) {
            rho[k] = (float) (-model.rho[k]);
        }
        attributes.put("rho", rho);

        // Densify the sparse support vectors into an (l x numFeatures) row-major matrix.
        float[] supportVectors = new float[model.l * numFeatures];
        for (int sv = 0; sv < model.l; sv++) {
            for (svm_node node : model.SV[sv]) {
                supportVectors[(sv * numFeatures) + node.index] = (float) node.value;
            }
        }
        attributes.put("support_vectors", supportVectors);
        attributes.put("vectors_per_class", Arrays.copyOf(model.nSV, numClasses));
        if (this.generatesProbabilities) {
            attributes.put("prob_a", Arrays.copyOf(Util.toFloatArray(model.probA), numPairs));
            attributes.put("prob_b", Arrays.copyOf(Util.toFloatArray(model.probB), numPairs));
        }

        List<ONNXNode> svmOutputs = input.apply(ONNXOperators.SVM_CLASSIFIER, Arrays.asList("pred_label", "svm_output"), attributes);
        ONNXNode scores = svmOutputs.get(1);
        if (!this.generatesProbabilities) {
            if (model.nr_class == 2) {
                // Binary case: flip the sign before converting decision values into votes.
                scores = writeDecisionFunction(scores.apply(ONNXOperators.MUL, onnx.constant("minus_one", -1.0f)));
            } else {
                scores = writeDecisionFunction(scores);
            }
        }

        // Column c corresponds to libsvm class label[c]; gather so that column k holds output id k.
        int[] columnForId = new int[numClasses];
        for (int col = 0; col < numClasses; col++) {
            columnForId[model.label[col]] = col;
        }
        return scores.apply(ONNXOperators.GATHER, onnx.array("label_indices", columnForId), Collections.singletonMap("axis", 1));
    }

    /**
     * Converts pairwise decision values into per-class vote counts inside the ONNX graph.
     *
     * @param decisionValues the node holding one decision value per class pair (i, j), i &lt; j.
     * @return the node holding the concatenated per-class vote sums, ordered by libsvm class index.
     */
    private ONNXNode writeDecisionFunction(ONNXNode decisionValues) {
        ONNXContext onnx = decisionValues.onnxContext();
        ONNXInitializer one = onnx.constant("one", 1.0f);
        // 1.0 where the pairwise value is negative (a vote for the second class j), else 0.0.
        ONNXNode secondWins = decisionValues.apply(ONNXOperators.LESS, onnx.constant("zero", 0.0f)).cast(float.class);
        svm_model model = (svm_model) this.models.get(0);
        // TreeMap keeps the per-class vote lists ordered by class index for the final concat.
        Map<Integer, List<ONNXNode>> votesPerClass = new TreeMap<>();
        int pair = 0;
        for (int i = 0; i < model.nr_class; i++) {
            for (int j = i + 1; j < model.nr_class; j++) {
                ONNXNode voteForJ = secondWins.apply(ONNXOperators.ARRAY_FEATURE_EXTRACTOR, onnx.constant("Vind_" + pair, pair), "Vsvcv_" + pair);
                votesPerClass.computeIfAbsent(j, k -> new ArrayList<>()).add(voteForJ);
                // Class i receives the complementary vote: 1 - voteForJ.
                votesPerClass.computeIfAbsent(i, k -> new ArrayList<>())
                        .add(voteForJ.apply(ONNXOperators.NEG, "Vnegv_" + pair).apply(ONNXOperators.ADD, one, "Vnegv1_" + pair));
                pair++;
            }
        }
        List<ONNXNode> classVotes = votesPerClass.values().stream()
                .map(votes -> onnx.operation(ONNXOperators.SUM, votes, "svm_votes"))
                .collect(Collectors.toList());
        return onnx.operation(ONNXOperators.CONCAT, classVotes, "svm_output", Collections.singletonMap("axis", 1));
    }

    /**
     * Serializes this model to a protobuf {@code ModelProto} at {@link #CURRENT_VERSION}.
     * <p>
     * NOTE(review): the decompiler renamed this from {@code serialize} when merging a bridge method.
     *
     * @return the serialized model.
     */
    public ModelProto m1serialize() {
        LibSVMClassificationModelProto.Builder modelBuilder = LibSVMClassificationModelProto.newBuilder();
        modelBuilder.setMetadata(createDataCarrier().serialize());
        modelBuilder.setModel(serializeModel((svm_model) this.models.get(0)));
        ModelProto.Builder protoBuilder = ModelProto.newBuilder();
        protoBuilder.setSerializedData(Any.pack(modelBuilder.m47build()));
        protoBuilder.setClassName(LibSVMClassificationModel.class.getName());
        protoBuilder.setVersion(CURRENT_VERSION);
        return protoBuilder.build();
    }
}
