package org.tribuo.common.sgd;

import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.math.FeedForwardParameters;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.provenance.ModelProvenance;

/* loaded from: input_file:org/tribuo/common/sgd/AbstractSGDModel.class */
/**
 * Base class for models whose predictions are computed by a {@link FeedForwardParameters}
 * instance applied to a vectorised {@link Example}.
 *
 * @param <T> the output type produced by this model
 */
public abstract class AbstractSGDModel<T extends Output<T>> extends Model<T> {
    private static final long serialVersionUID = 1L;

    /** The parameters used to compute predictions; {@link #getModelParameters()} hands out copies. */
    protected FeedForwardParameters modelParameters;

    /** Whether a bias element is added when an {@link Example} is converted into a feature vector. */
    protected boolean addBias;

    /**
     * A prediction vector paired with the number of active elements in the feature vector
     * it was computed from.
     */
    protected static final class PredAndActive {
        /** The vector returned by {@link FeedForwardParameters#predict}. */
        public final DenseVector prediction;
        /** Active element count of the input feature vector (counts the bias element if one was added). */
        public final int numActiveFeatures;

        PredAndActive(DenseVector prediction, int numActiveFeatures) {
            this.prediction = prediction;
            this.numActiveFeatures = numActiveFeatures;
        }
    }

    /**
     * Constructs a model backed by the supplied parameters.
     *
     * @param name                   the model name, forwarded to {@link Model}
     * @param provenance             the model provenance, forwarded to {@link Model}
     * @param featureIDMap           the feature domain, forwarded to {@link Model}
     * @param outputIDInfo           the output domain, forwarded to {@link Model}
     * @param modelParameters        the parameters used to produce predictions
     * @param generatesProbabilities forwarded to {@link Model}'s constructor (presumably whether
     *                               this model emits probabilities; confirm against {@code Model})
     * @param addBias                whether a bias element is added to each feature vector
     */
    public AbstractSGDModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap,
                            ImmutableOutputInfo<T> outputIDInfo, FeedForwardParameters modelParameters,
                            boolean generatesProbabilities, boolean addBias) {
        super(name, provenance, featureIDMap, outputIDInfo, generatesProbabilities);
        this.modelParameters = modelParameters;
        this.addBias = addBias;
    }

    /**
     * Converts {@code example} into a feature vector and runs it through the model parameters.
     * <p>
     * A dense vector is built when the example has as many features as the feature map.
     * Otherwise a sparse vector is built.
     *
     * @param example the example to predict
     * @return the prediction vector and the number of active elements in the input vector
     * @throws IllegalArgumentException if the feature vector has no active elements other than
     *                                  the bias element (when a bias is added)
     */
    protected PredAndActive predictSingle(Example<T> example) {
        // Must be a type able to hold both a DenseVector and a SparseVector; a DenseVector-typed
        // local cannot accept the sparse branch.
        SGDVector features;
        if (example.size() == featureIDMap.size()) {
            features = DenseVector.createDenseVector(example, featureIDMap, addBias);
        } else {
            features = SparseVector.createSparseVector(example, featureIDMap, addBias);
        }
        int numActive = features.numActiveElements();
        // When a bias is added it occupies one active element, so only that one means "no real features".
        int biasElements = addBias ? 1 : 0;
        if (numActive == biasElements) {
            throw new IllegalArgumentException("No features found in Example " + example.toString());
        }
        return new PredAndActive(modelParameters.predict(features), numActive);
    }

    /**
     * Returns a copy of the model parameters (via {@link FeedForwardParameters#copy()}) rather
     * than the internal instance.
     *
     * @return a copy of the parameters backing this model
     */
    public FeedForwardParameters getModelParameters() {
        return modelParameters.copy();
    }
}
