package org.tribuo.regression.sgd.fm;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import java.util.Arrays;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.ONNXExportable;
import org.tribuo.Prediction;
import org.tribuo.common.sgd.AbstractFMModel;
import org.tribuo.common.sgd.AbstractSGDModel;
import org.tribuo.common.sgd.FMParameters;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.math.Parameters;
import org.tribuo.math.protos.ParametersProto;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.regression.ImmutableRegressionInfo;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.sgd.protos.FMRegressionModelProto;
import org.tribuo.util.onnx.ONNXNode;
import org.tribuo.util.onnx.ONNXOperators;

/* loaded from: input_file:org/tribuo/regression/sgd/fm/FMRegressionModel.class */
public class FMRegressionModel extends AbstractFMModel<Regressor> implements ONNXExportable {
    private static final long serialVersionUID = 3L;

    /**
     * Protobuf serialization version written by {@link #m4serialize()} and the highest version
     * accepted by {@link #deserializeFromProto(int, String, Any)}.
     */
    public static final int CURRENT_VERSION = 0;

    /** One name per output dimension, used to label the entries of the prediction vector. */
    private final String[] dimensionNames;

    /**
     * If true, raw model outputs are mapped back through the output domain's mean and variance
     * before being returned (see {@link #unstandardisePredictions(double[])}).
     */
    private final boolean standardise;

    /**
     * Constructs a factorization machine regression model.
     * <p>
     * NOTE: the decompiler widened this constructor from package-private to public; it is kept
     * public so existing callers continue to compile.
     *
     * @param name           The model name.
     * @param dimensionNames The output dimension names, in prediction-vector order.
     * @param provenance     The model provenance.
     * @param featureIDMap   The feature domain.
     * @param outputIDInfo   The output domain.
     * @param parameters     The trained model parameters.
     * @param standardise    Whether predictions must be unstandardised before being returned.
     */
    public FMRegressionModel(String name, String[] dimensionNames, ModelProvenance provenance,
                             ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> outputIDInfo,
                             FMParameters parameters, boolean standardise) {
        super(name, provenance, featureIDMap, outputIDInfo, parameters, false);
        this.dimensionNames = dimensionNames;
        this.standardise = standardise;
    }

    /**
     * Deserializes a model from its protobuf representation.
     *
     * @param version   The serialized version; must be in {@code [0, CURRENT_VERSION]}.
     * @param className The class name recorded in the proto (not inspected here).
     * @param message   The packed {@code FMRegressionModelProto}.
     * @return The deserialized model.
     * @throws InvalidProtocolBufferException If {@code message} does not contain an
     *                                        {@code FMRegressionModelProto}.
     * @throws IllegalArgumentException       If {@code version} is unsupported.
     * @throws IllegalStateException          If the proto contents are inconsistent (non-regression
     *                                        output domain, non-FM parameters, or a dimension-name
     *                                        count that differs from the output domain size).
     */
    public static FMRegressionModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        FMRegressionModelProto proto = message.unpack(FMRegressionModelProto.class);

        ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());
        ImmutableOutputInfo<?> rawDomain = carrier.outputDomain();
        // Guard against a missing id-0 output (e.g. an empty domain) so a malformed proto
        // produces the documented IllegalStateException instead of a NullPointerException.
        Object firstOutput = rawDomain.getOutput(0);
        if (firstOutput == null || !firstOutput.getClass().equals(Regressor.class)) {
            throw new IllegalStateException("Invalid protobuf, output domain is not a regression domain, found " + rawDomain.getClass());
        }
        @SuppressWarnings("unchecked") // guarded by the Regressor class check above
        ImmutableOutputInfo<Regressor> outputDomain = (ImmutableOutputInfo<Regressor>) rawDomain;

        // Check the type before casting; assigning straight into an FMParameters local would
        // make the instanceof test meaningless.
        Parameters params = Parameters.deserialize(proto.getParams());
        if (!(params instanceof FMParameters)) {
            throw new IllegalStateException("Invalid protobuf, parameters must be FMParameters, found " + params.getClass());
        }
        FMParameters fmParameters = (FMParameters) params;

        String[] dimensionNames = (String[]) proto.mo23getDimensionNamesList().toArray(new String[0]);
        if (dimensionNames.length != outputDomain.size()) {
            throw new IllegalStateException("Invalid protobuf, found a different number of dimension names to the output dimensions, found " + dimensionNames.length + " , expected " + outputDomain.size());
        }

        return new FMRegressionModel(carrier.name(), dimensionNames, carrier.provenance(),
                carrier.featureDomain(), outputDomain, fmParameters, proto.getStandardise());
    }

    /**
     * Predicts the regression outputs for a single example.
     *
     * @param example The example to predict.
     * @return A prediction holding a {@link Regressor} over {@link #dimensionNames} and the
     *         number of active features used.
     */
    @Override
    public Prediction<Regressor> predict(Example<Regressor> example) {
        AbstractSGDModel.PredAndActive predAndActive = predictSingle(example);
        double[] values = predAndActive.prediction.toArray();
        if (standardise) {
            values = unstandardisePredictions(values);
        }
        return new Prediction<>(new Regressor(dimensionNames, values), predAndActive.numActiveFeatures, example);
    }

    /**
     * Serializes this model to a {@link ModelProto}.
     * <p>
     * NOTE: this is {@code serialize()} in the original source; the decompiler renamed it when
     * merging it with its bridge method. The name is kept for consistency with the rest of the
     * decompiled code.
     *
     * @return The model proto, tagged with {@link #CURRENT_VERSION} and this class name.
     */
    public ModelProto m4serialize() {
        ModelDataCarrier<?> carrier = createDataCarrier();

        FMRegressionModelProto.Builder modelBuilder = FMRegressionModelProto.newBuilder();
        modelBuilder.setMetadata(carrier.serialize());
        modelBuilder.setParams((ParametersProto) this.modelParameters.serialize());
        modelBuilder.addAllDimensionNames(Arrays.asList(this.dimensionNames));
        modelBuilder.setStandardise(this.standardise);

        ModelProto.Builder protoBuilder = ModelProto.newBuilder();
        protoBuilder.setVersion(CURRENT_VERSION);
        protoBuilder.setClassName(FMRegressionModel.class.getName());
        protoBuilder.setSerializedData(Any.pack(modelBuilder.m56build()));
        return protoBuilder.build();
    }

    /**
     * Maps standardised predictions back to the original output scale, in place.
     * <p>
     * NOTE(review): each value is multiplied by the dimension's <em>variance</em> (not its
     * standard deviation) and then shifted by its mean. This must mirror how the trainer
     * standardised the targets. Confirm against the trainer before changing it.
     *
     * @param predictions The standardised predictions, indexed by output id; modified in place.
     * @return {@code predictions}, after unstandardisation.
     */
    private double[] unstandardisePredictions(double[] predictions) {
        ImmutableRegressionInfo regressionInfo = (ImmutableRegressionInfo) this.outputIDInfo;
        for (int i = 0; i < predictions.length; i++) {
            predictions[i] = (predictions[i] * regressionInfo.getVariance(i)) + regressionInfo.getMean(i);
        }
        return predictions;
    }

    /**
     * Creates a copy of this model with a new name and provenance.
     * <p>
     * NOTE: this is {@code copy(String, ModelProvenance)} in the original source, where it was
     * protected. The decompiler renamed it when merging it with its bridge method and widened it
     * to public; both changes are kept for compatibility.
     *
     * @param newName       The name of the copy.
     * @param newProvenance The provenance of the copy.
     * @return A copy with its own dimension-name array and copied parameters.
     */
    public FMRegressionModel m3copy(String newName, ModelProvenance newProvenance) {
        return new FMRegressionModel(newName,
                Arrays.copyOf(this.dimensionNames, this.dimensionNames.length),
                newProvenance, this.featureIDMap, this.outputIDInfo,
                (FMParameters) this.modelParameters.copy(), this.standardise);
    }

    /**
     * Returns the name of an output dimension.
     *
     * @param index The output dimension index.
     * @return The dimension name.
     */
    @Override
    protected String getDimensionName(int index) {
        return this.dimensionNames[index];
    }

    /**
     * Returns the model name used in ONNX export.
     *
     * @return The constant {@code "FMRegressionModel"}.
     */
    @Override
    protected String onnxModelName() {
        return "FMRegressionModel";
    }

    /**
     * Appends the output transformation to the ONNX graph.
     * <p>
     * If the model is standardised, this applies the same mapping as
     * {@link #unstandardisePredictions(double[])}: multiply by the {@code y_var} initializer,
     * then add the {@code y_mean} initializer. Otherwise the input node is returned unchanged.
     *
     * @param input The node holding the raw model output.
     * @return The node holding the final regression output.
     */
    @Override
    protected ONNXNode onnxOutput(ONNXNode input) {
        if (!this.standardise) {
            return input;
        }
        ImmutableRegressionInfo regressionInfo = (ImmutableRegressionInfo) this.outputIDInfo;
        int numOutputs = regressionInfo.size();
        double[] means = new double[numOutputs];
        double[] variances = new double[numOutputs];
        for (int i = 0; i < numOutputs; i++) {
            means[i] = regressionInfo.getMean(i);
            variances[i] = regressionInfo.getVariance(i);
        }
        // "y_var" is registered before "y_mean", matching the original evaluation order.
        return input.apply(ONNXOperators.MUL, input.onnxContext().array("y_var", variances))
                .apply(ONNXOperators.ADD, input.onnxContext().array("y_mean", means));
    }
}
