package org.tribuo.regression.libsvm;

import ai.onnx.proto.OnnxMl;
import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.ONNXExportable;
import org.tribuo.Prediction;
import org.tribuo.common.libsvm.KernelType;
import org.tribuo.common.libsvm.LibSVMModel;
import org.tribuo.common.libsvm.LibSVMTrainer;
import org.tribuo.common.libsvm.protos.SVMModelProto;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.regression.ImmutableRegressionInfo;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.libsvm.protos.LibSVMRegressionModelProto;
import org.tribuo.util.Util;
import org.tribuo.util.onnx.ONNXContext;
import org.tribuo.util.onnx.ONNXNode;
import org.tribuo.util.onnx.ONNXOperators;
import org.tribuo.util.onnx.ONNXPlaceholder;
import org.tribuo.util.onnx.ONNXRef;

/* loaded from: input_file:org/tribuo/regression/libsvm/LibSVMRegressionModel.class */
/**
 * A regression model backed by libsvm, holding one {@code svm_model} per output dimension.
 * <p>
 * Index layout: {@code models} and, when the model is standardized, {@code means} and
 * {@code variances} share the same index (the ONNX graph applies them element-wise in
 * model-list order). {@code mapping[i]} is the position of model {@code i}'s dimension inside
 * {@code dimensionNames}, which is also the order of the values in predicted {@link Regressor}s.
 */
public class LibSVMRegressionModel extends LibSVMModel<Regressor> implements ONNXExportable {
    private static final long serialVersionUID = 2;

    /** Protobuf serialization version written by {@link #m1serialize()}. */
    public static final int CURRENT_VERSION = 0;

    /** Output dimension names, in the order returned by {@code Regressor.extractNames}. */
    private final String[] dimensionNames;

    /** Per-model offsets added back to standardized predictions, or {@code null}. */
    private double[] means;

    /** Per-model multipliers applied to standardized predictions before adding {@code means}, or {@code null}. */
    private double[] variances;

    /** True if the models were trained on standardized outputs and predictions must be rescaled. */
    private final boolean standardized;

    /**
     * Maps a model index to its position in {@code dimensionNames}. Not final so that
     * {@code readObject} can fill it in for serialized forms that do not contain it.
     */
    private int[] mapping;

    /**
     * Builds a model whose outputs were not standardized.
     * <p>
     * JADX INFO: access modifier was package-private in the original bytecode.
     *
     * @param str                 the model name
     * @param modelProvenance     the model provenance
     * @param immutableFeatureMap the feature domain
     * @param immutableOutputInfo the output domain (must be an {@link ImmutableRegressionInfo})
     * @param list                one libsvm model per output dimension
     */
    public LibSVMRegressionModel(String str, ModelProvenance modelProvenance, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo<Regressor> immutableOutputInfo, List<svm_model> list) {
        this(str, modelProvenance, immutableFeatureMap, immutableOutputInfo, list, null, null, false);
    }

    /**
     * Builds a model whose outputs were standardized during training.
     * <p>
     * JADX INFO: access modifier was package-private in the original bytecode.
     *
     * @param str                 the model name
     * @param modelProvenance     the model provenance
     * @param immutableFeatureMap the feature domain
     * @param immutableOutputInfo the output domain (must be an {@link ImmutableRegressionInfo})
     * @param list                one libsvm model per output dimension
     * @param dArr                per-model means, aligned with {@code list}
     * @param dArr2               per-model multipliers, aligned with {@code list}
     */
    public LibSVMRegressionModel(String str, ModelProvenance modelProvenance, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo<Regressor> immutableOutputInfo, List<svm_model> list, double[] dArr, double[] dArr2) {
        this(str, modelProvenance, immutableFeatureMap, immutableOutputInfo, list, dArr, dArr2, true);
    }

    /**
     * Canonical constructor used by the other constructors, {@link #m0copy} and protobuf deserialization.
     */
    private LibSVMRegressionModel(String str, ModelProvenance modelProvenance, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo<Regressor> immutableOutputInfo, List<svm_model> list, double[] dArr, double[] dArr2, boolean z) {
        super(str, modelProvenance, immutableFeatureMap, immutableOutputInfo, false, list);
        this.dimensionNames = Regressor.extractNames(immutableOutputInfo);
        this.means = dArr;
        this.variances = dArr2;
        this.standardized = z;
        this.mapping = ((ImmutableRegressionInfo) immutableOutputInfo).getIDtoNaturalOrderMapping();
    }

    /**
     * Deserializes a model from its protobuf form.
     *
     * @param i   the serialized version
     * @param str the class name recorded in the protobuf (unused here)
     * @param any the packed {@link LibSVMRegressionModelProto}
     * @return the deserialized model
     * @throws InvalidProtocolBufferException if {@code any} does not hold a {@link LibSVMRegressionModelProto}
     * @throws IllegalArgumentException if the version is not {@link #CURRENT_VERSION}
     * @throws IllegalStateException if the protobuf contents are inconsistent
     */
    public static LibSVMRegressionModel deserializeFromProto(int i, String str, Any any) throws InvalidProtocolBufferException {
        if (i != CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + i + ", this class supports at most version 0");
        }
        LibSVMRegressionModelProto proto = any.unpack(LibSVMRegressionModelProto.class);
        ModelDataCarrier carrier = ModelDataCarrier.deserialize(proto.getMetadata());
        ImmutableOutputInfo rawDomain = carrier.outputDomain();
        // Guard the getOutput(0) probe: an empty domain cannot be a valid regression domain.
        if (rawDomain.size() == 0 || !rawDomain.getOutput(0).getClass().equals(Regressor.class)) {
            throw new IllegalStateException("Invalid protobuf, output domain is not a regression domain, found " + rawDomain.getClass());
        }
        @SuppressWarnings("unchecked") // checked above via the class of the first output
        ImmutableOutputInfo<Regressor> outputDomain = (ImmutableOutputInfo<Regressor>) rawDomain;
        if (outputDomain.size() != proto.getModelCount()) {
            throw new IllegalStateException("Invalid protobuf, did not find a model for each output dimension, expected " + outputDomain.size() + " found " + proto.getModelCount());
        }
        List<svm_model> models = new ArrayList<>(proto.getModelCount());
        for (SVMModelProto modelProto : proto.getModelList()) {
            models.add(deserializeModel(modelProto));
        }
        double[] means = proto.getMeansCount() == 0 ? null : Util.toPrimitiveDouble(proto.getMeansList());
        if (means != null && means.length != outputDomain.size()) {
            throw new IllegalStateException("Invalid protobuf, expected " + outputDomain.size() + " means, found " + means.length);
        }
        double[] variances = proto.getVariancesCount() == 0 ? null : Util.toPrimitiveDouble(proto.getVariancesList());
        if (variances != null && variances.length != outputDomain.size()) {
            throw new IllegalStateException("Invalid protobuf, expected " + outputDomain.size() + " variances, found " + variances.length);
        }
        // predict() dereferences both arrays when standardized, so reject the inconsistent case here.
        if (proto.getStandardized() && (means == null || variances == null)) {
            throw new IllegalStateException("Invalid protobuf, standardized model is missing means or variances");
        }
        return new LibSVMRegressionModel(carrier.name(), carrier.provenance(), carrier.featureDomain(), outputDomain, Collections.unmodifiableList(models), means, variances, proto.getStandardized());
    }

    /**
     * Returns true if the models were trained on standardized outputs.
     *
     * @return true if predictions are rescaled with the stored means and variances
     */
    boolean isStandardized() {
        return this.standardized;
    }

    /**
     * Returns the number of support vectors in each per-dimension model, keyed by dimension name.
     *
     * @return a map from dimension name to support vector count
     */
    public Map<String, Integer> getNumberOfSupportVectors() {
        Map<String, Integer> counts = new HashMap<>();
        for (int i = 0; i < this.dimensionNames.length; i++) {
            svm_model model = (svm_model) this.models.get(i);
            // Model i's dimension sits at mapping[i] in dimensionNames (same placement as predict).
            counts.put(this.dimensionNames[this.mapping[i]], model.SV.length);
        }
        return counts;
    }

    /**
     * Predicts every output dimension for the supplied example.
     *
     * @param example the example to predict
     * @return a prediction whose {@link Regressor} lists values in {@code dimensionNames} order
     * @throws IllegalArgumentException if the example has no features known to this model
     */
    public Prediction<Regressor> predict(Example<Regressor> example) {
        svm_node[] features = LibSVMTrainer.exampleToNodes(example, this.featureIDMap, (List) null);
        if (features.length == 0) {
            throw new IllegalArgumentException("No features found in Example " + example.toString());
        }
        // Scratch buffer required by svm_predict_values; only its return value is used.
        double[] decisionValues = new double[1];
        double[] regressedValues = new double[this.models.size()];
        for (int i = 0; i < regressedValues.length; i++) {
            double value = svm.svm_predict_values((svm_model) this.models.get(i), features, decisionValues);
            if (this.standardized) {
                // means/variances are aligned with the model list, not with the output order.
                value = (value * this.variances[i]) + this.means[i];
            }
            regressedValues[this.mapping[i]] = value;
        }
        return new Prediction<>(new Regressor(this.dimensionNames, regressedValues), features.length, example);
    }

    /**
     * Returns a copy of the means used to undo standardization.
     *
     * @return a copy of the means, or {@code null} if none are stored
     */
    protected double[] getMeans() {
        return this.means == null ? null : Arrays.copyOf(this.means, this.means.length);
    }

    /**
     * Returns a copy of the multipliers used to undo standardization.
     *
     * @return a copy of the variances, or {@code null} if none are stored
     */
    protected double[] getVariances() {
        return this.variances == null ? null : Arrays.copyOf(this.variances, this.variances.length);
    }

    /**
     * Deep-copies this model under a new name and provenance, preserving standardization state.
     * <p>
     * JADX INFO: renamed from {@code copy} (merged with bridge method); originally protected.
     *
     * @param str             the new model name
     * @param modelProvenance the new provenance
     * @return an independent copy of this model
     */
    public LibSVMRegressionModel m0copy(String str, ModelProvenance modelProvenance) {
        List<svm_model> copiedModels = new ArrayList<>(this.models.size());
        for (Object model : this.models) {
            copiedModels.add(copyModel((svm_model) model));
        }
        // Use the canonical constructor so the copy keeps means, variances and the standardized flag.
        return new LibSVMRegressionModel(str, modelProvenance, this.featureIDMap, this.outputIDInfo, Collections.unmodifiableList(copiedModels), getMeans(), getVariances(), this.standardized);
    }

    /**
     * Exports this model as an ONNX model with a float input sized to the feature domain and a
     * float output sized to the output domain.
     *
     * @param str the ONNX domain name
     * @param j   the ONNX model version
     * @return the ONNX model proto
     */
    public OnnxMl.ModelProto exportONNXModel(String str, long j) {
        ONNXContext context = new ONNXContext();
        ONNXPlaceholder input = context.floatInput(this.featureIDMap.size());
        ONNXPlaceholder output = context.floatOutput(this.outputIDInfo.size());
        context.setName("Regression-LibSVM");
        return ONNXExportable.buildModel(writeONNXGraph(input).assignTo(output).onnxContext(), str, j, this);
    }

    /**
     * Builds an ONNX {@code SVMRegressor} node for a single libsvm model.
     * <p>
     * Support vectors are written densely into a {@code l * numFeatures} array. {@code rho} is
     * negated when written out. NOTE(review): this presumably reconciles libsvm's
     * {@code sum - rho} convention with ONNX adding {@code rho} — confirm against the ONNX spec.
     * JADX INFO: originally private.
     *
     * @param numFeatures  the number of features (row width of the support vector matrix)
     * @param input        the node supplying the feature input
     * @param svmModel     the libsvm model to convert
     * @return the SVMRegressor node
     */
    public static ONNXNode buildONNXSVMRegressor(int numFeatures, ONNXRef<?> input, svm_model svmModel) {
        Map<String, Object> attributes = new HashMap<>();
        attributes.put("coefficients", Util.toFloatArray(svmModel.sv_coef[0]));
        attributes.put("kernel_params", new float[]{(float) svmModel.param.gamma, (float) svmModel.param.coef0, svmModel.param.degree});
        attributes.put("kernel_type", KernelType.getKernelType(svmModel.param.kernel_type).name());
        attributes.put("n_supports", Integer.valueOf(svmModel.l));
        attributes.put("one_class", 0);
        attributes.put("rho", new float[]{(float) (-svmModel.rho[0])});
        float[] supportVectors = new float[svmModel.l * numFeatures];
        for (int sv = 0; sv < svmModel.l; sv++) {
            for (svm_node node : svmModel.SV[sv]) {
                supportVectors[(sv * numFeatures) + node.index] = (float) node.value;
            }
        }
        attributes.put("support_vectors", supportVectors);
        return input.apply(ONNXOperators.SVM_REGRESSOR, attributes);
    }

    /**
     * Writes the ONNX graph for this model: one SVMRegressor per model, concatenated along axis 1
     * in model-list order, then rescaled with variances and means when standardized.
     *
     * @param oNNXRef the input node
     * @return the output node of the graph
     */
    public ONNXNode writeONNXGraph(ONNXRef<?> oNNXRef) {
        ONNXContext onnxContext = oNNXRef.onnxContext();
        int size = this.featureIDMap.size();
        ONNXNode operation = onnxContext.operation(ONNXOperators.CONCAT, (List) this.models.stream().map(svm_modelVar -> {
            return buildONNXSVMRegressor(size, oNNXRef, svm_modelVar);
        }).collect(Collectors.toList()), "concat_output", Collections.singletonMap("axis", 1));
        if (!this.standardized) {
            return operation;
        }
        return operation.apply(ONNXOperators.MUL, onnxContext.array("y_variances", this.variances)).apply(ONNXOperators.ADD, onnxContext.array("y_mean", this.means));
    }

    /**
     * Serializes this model to protobuf at {@link #CURRENT_VERSION}.
     * <p>
     * JADX INFO: renamed from {@code serialize} (merged with bridge method).
     *
     * @return the model proto
     */
    public ModelProto m1serialize() {
        ModelDataCarrier carrier = createDataCarrier();
        LibSVMRegressionModelProto.Builder modelBuilder = LibSVMRegressionModelProto.newBuilder();
        modelBuilder.setMetadata(carrier.serialize());
        for (Object model : this.models) {
            modelBuilder.addModel(serializeModel((svm_model) model));
        }
        if (this.means != null) {
            modelBuilder.addAllMeans((Iterable) Arrays.stream(this.means).boxed().collect(Collectors.toList()));
        }
        if (this.variances != null) {
            modelBuilder.addAllVariances((Iterable) Arrays.stream(this.variances).boxed().collect(Collectors.toList()));
        }
        modelBuilder.setStandardized(this.standardized);
        ModelProto.Builder protoBuilder = ModelProto.newBuilder();
        protoBuilder.setSerializedData(Any.pack(modelBuilder.m46build()));
        protoBuilder.setClassName(LibSVMRegressionModel.class.getName());
        protoBuilder.setVersion(CURRENT_VERSION);
        return protoBuilder.build();
    }

    /**
     * Java-serialization hook. If the serialized form lacks {@code mapping}, compute it. Then
     * reorder the models, and any stored means/variances, by that mapping so they match the
     * layout {@link #predict} expects. Absent means/variances stay {@code null}.
     */
    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        objectInputStream.defaultReadObject();
        if (this.mapping == null) {
            this.mapping = ((ImmutableRegressionInfo) this.outputIDInfo).getIDtoNaturalOrderMapping();
            List<svm_model> reordered = new ArrayList<>(this.models);
            for (int i = 0; i < this.mapping.length; i++) {
                reordered.set(i, (svm_model) this.models.get(this.mapping[i]));
            }
            this.models = Collections.unmodifiableList(reordered);
            this.means = reorderByMapping(this.means, this.mapping);
            this.variances = reorderByMapping(this.variances, this.mapping);
        }
    }

    /**
     * Returns a new array whose element {@code i} is {@code values[mapping[i]]}, or {@code null}
     * when {@code values} is {@code null}.
     */
    private static double[] reorderByMapping(double[] values, int[] mapping) {
        if (values == null) {
            return null;
        }
        double[] reordered = new double[values.length];
        for (int i = 0; i < mapping.length; i++) {
            reordered[i] = values[mapping[i]];
        }
        return reordered;
    }
}
