package org.tribuo.multilabel.sgd.linear;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.ONNXExportable;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.common.sgd.AbstractLinearSGDModel;
import org.tribuo.common.sgd.AbstractSGDModel;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.math.LinearParameters;
import org.tribuo.math.Parameters;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.protos.NormalizerProto;
import org.tribuo.math.protos.ParametersProto;
import org.tribuo.math.util.VectorNormalizer;
import org.tribuo.multilabel.MultiLabel;
import org.tribuo.multilabel.sgd.protos.MultiLabelLinearSGDProto;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.util.onnx.ONNXNode;

/* loaded from: input_file:org/tribuo/multilabel/sgd/linear/LinearSGDModel.class */
/**
 * A linear multi-label model whose scores are produced by {@link AbstractLinearSGDModel}.
 * <p>
 * At prediction time the raw score vector is normalized in place with the stored
 * {@link VectorNormalizer}. Every label whose normalized score is strictly greater than
 * {@code threshold} is placed in the predicted {@link MultiLabel}. All labels, with their
 * scores, are reported in the per-label score map.
 */
public class LinearSGDModel extends AbstractLinearSGDModel<MultiLabel> implements ONNXExportable {
    private static final long serialVersionUID = 2;

    /**
     * Protobuf serialization version written by {@link #serialize()} and the highest version
     * accepted by {@link #deserializeFromProto(int, String, Any)}.
     */
    public static final int CURRENT_VERSION = 0;

    /** Normalizer applied in place to the raw score vector before thresholding. */
    private final VectorNormalizer normalizer;

    /** A label is predicted when its normalized score is strictly greater than this value. */
    private final double threshold;

    /**
     * Constructs a multi-label linear SGD model.
     * <p>
     * NOTE(review): the decompiler reports this constructor was package-private in the
     * original source. It is left public here so existing callers keep compiling.
     *
     * @param name                   the model name
     * @param provenance             the model provenance
     * @param featureIDMap           the feature domain
     * @param outputIDInfo           the multi-label output domain
     * @param parameters             the linear model parameters
     * @param normalizer             normalizer applied to the score vector at prediction time
     * @param generatesProbabilities whether this model's scores are reported as probabilities
     * @param threshold              score strictly above which a label is predicted
     */
    public LinearSGDModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap,
                          ImmutableOutputInfo<MultiLabel> outputIDInfo, LinearParameters parameters,
                          VectorNormalizer normalizer, boolean generatesProbabilities, double threshold) {
        super(name, provenance, featureIDMap, outputIDInfo, parameters, generatesProbabilities);
        this.normalizer = normalizer;
        this.threshold = threshold;
    }

    /**
     * Deserializes a model from the protobuf payload produced by {@link #serialize()}.
     *
     * @param version   the serialized version number
     * @param className the class name recorded in the envelope (not used here)
     * @param message   the packed {@link MultiLabelLinearSGDProto}
     * @return the deserialized model
     * @throws InvalidProtocolBufferException if {@code message} does not contain a
     *                                        {@link MultiLabelLinearSGDProto}
     * @throws IllegalArgumentException       if {@code version} is outside
     *                                        {@code [0, CURRENT_VERSION]}
     * @throws IllegalStateException          if the output domain is empty or not multi-label,
     *                                        or if the parameters are not {@link LinearParameters}
     */
    public static LinearSGDModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        MultiLabelLinearSGDProto proto = message.unpack(MultiLabelLinearSGDProto.class);
        ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());

        ImmutableOutputInfo<?> rawDomain = carrier.outputDomain();
        // Check the size first so an empty domain gives a descriptive error
        // instead of whatever getOutput(0) does for a missing id.
        if (rawDomain.size() == 0 || !rawDomain.getOutput(0).getClass().equals(MultiLabel.class)) {
            throw new IllegalStateException("Invalid protobuf, output domain is not a multi-label domain, found " + rawDomain.getClass());
        }
        @SuppressWarnings("unchecked") // guarded by the element-class check above
        ImmutableOutputInfo<MultiLabel> outputDomain = (ImmutableOutputInfo<MultiLabel>) rawDomain;

        // Deserialize as the general type, then check before casting.
        // Assigning straight to LinearParameters does not compile.
        Parameters params = Parameters.deserialize(proto.getParams());
        if (!(params instanceof LinearParameters)) {
            throw new IllegalStateException("Invalid protobuf, parameters must be LinearParameters, found " + params.getClass());
        }
        VectorNormalizer vectorNormalizer = VectorNormalizer.deserialize(proto.getNormalizer());
        return new LinearSGDModel(carrier.name(), carrier.provenance(), carrier.featureDomain(), outputDomain,
                (LinearParameters) params, vectorNormalizer, carrier.generatesProbabilities(), proto.getThreshold());
    }

    /**
     * Predicts the label set for an example.
     *
     * @param example the example to score
     * @return a prediction whose output contains every label scoring strictly above
     *         {@code threshold}, plus a map from label name to a single-label {@link MultiLabel}
     *         carrying that label's normalized score
     */
    @Override
    public Prediction<MultiLabel> predict(Example<MultiLabel> example) {
        AbstractSGDModel.PredAndActive predTuple = predictSingle(example);
        DenseVector scores = predTuple.prediction;
        // normalize mutates the vector in place
        scores.normalize(normalizer);

        Map<String, MultiLabel> fullLabels = new HashMap<>();
        Set<Label> predictedLabels = new HashSet<>();
        for (int i = 0; i < scores.size(); i++) {
            String labelName = outputIDInfo.getOutput(i).getLabelString();
            double labelScore = scores.get(i);
            Label scored = new Label(labelName, labelScore);
            if (labelScore > threshold) {
                predictedLabels.add(scored);
            }
            fullLabels.put(labelName, new MultiLabel(scored));
        }
        // NOTE(review): the -1 presumably discounts a bias feature counted by predictSingle — confirm.
        return new Prediction<>(new MultiLabel(predictedLabels), fullLabels, predTuple.numActiveFeatures - 1, example, generatesProbabilities);
    }

    /**
     * Serializes this model into a {@link ModelProto}.
     * <p>
     * The envelope records version {@link #CURRENT_VERSION} and this class's name. Its payload
     * is a packed {@link MultiLabelLinearSGDProto} holding the metadata, parameters, normalizer
     * and threshold.
     *
     * @return the serialized model
     */
    @Override
    public ModelProto serialize() {
        ModelDataCarrier<MultiLabel> carrier = createDataCarrier();

        MultiLabelLinearSGDProto.Builder modelBuilder = MultiLabelLinearSGDProto.newBuilder();
        modelBuilder.setMetadata(carrier.serialize());
        modelBuilder.setParams((ParametersProto) modelParameters.serialize());
        modelBuilder.setNormalizer((NormalizerProto) normalizer.serialize());
        modelBuilder.setThreshold(threshold);

        ModelProto.Builder builder = ModelProto.newBuilder();
        builder.setVersion(CURRENT_VERSION);
        builder.setClassName(LinearSGDModel.class.getName());
        builder.setSerializedData(Any.pack(modelBuilder.build()));
        return builder.build();
    }

    /**
     * Alias of {@link #serialize()}, kept for callers compiled against the decompiled method name.
     *
     * @return the serialized model
     * @deprecated use {@link #serialize()}.
     */
    @Deprecated
    public ModelProto m7serialize() {
        return serialize();
    }

    /**
     * Returns the label string for output dimension {@code i}.
     *
     * @param i the output dimension id
     * @return the label string of that output
     */
    @Override
    protected String getDimensionName(int i) {
        return outputIDInfo.getOutput(i).getLabelString();
    }

    /**
     * Copies this model with a new name and provenance.
     * <p>
     * The parameters are copied. The normalizer instance is shared with this model.
     *
     * @param newName       the name of the copy
     * @param newProvenance the provenance of the copy
     * @return the copied model
     */
    @Override
    protected LinearSGDModel copy(String newName, ModelProvenance newProvenance) {
        return new LinearSGDModel(newName, newProvenance, featureIDMap, outputIDInfo,
                (LinearParameters) modelParameters.copy(), normalizer, generatesProbabilities, threshold);
    }

    /**
     * Alias of {@link #copy(String, ModelProvenance)}, kept for callers compiled against the
     * decompiled method name.
     *
     * @param str             the name of the copy
     * @param modelProvenance the provenance of the copy
     * @return the copied model
     * @deprecated use {@code copy(String, ModelProvenance)}.
     */
    @Deprecated
    public LinearSGDModel m6copy(String str, ModelProvenance modelProvenance) {
        return copy(str, modelProvenance);
    }

    /**
     * Appends this model's normalizer to the ONNX graph.
     *
     * @param oNNXNode the node holding the raw linear scores
     * @return the node produced by {@link VectorNormalizer#exportNormalizer}
     */
    @Override
    protected ONNXNode onnxOutput(ONNXNode oNNXNode) {
        return normalizer.exportNormalizer(oNNXNode);
    }

    /**
     * Returns the name used for the exported ONNX model.
     *
     * @return {@code "MultiLabel-LinearSGDModel"}
     */
    @Override
    protected String onnxModelName() {
        return "MultiLabel-LinearSGDModel";
    }
}
