package org.tribuo.multilabel.sgd.fm;

import com.oracle.labs.mlrg.olcut.config.Config;
import java.util.Objects;
import java.util.logging.Logger;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.common.sgd.AbstractFMTrainer;
import org.tribuo.common.sgd.FMParameters;
import org.tribuo.common.sgd.SGDObjective;
import org.tribuo.math.FeedForwardParameters;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.multilabel.MultiLabel;
import org.tribuo.multilabel.sgd.MultiLabelObjective;
import org.tribuo.multilabel.sgd.objectives.BinaryCrossEntropy;
import org.tribuo.provenance.ModelProvenance;

/* loaded from: input_file:org/tribuo/multilabel/sgd/fm/FMMultiLabelTrainer.class */
public class FMMultiLabelTrainer extends AbstractFMTrainer<MultiLabel, SGDVector, FMMultiLabelModel> {
    private static final Logger logger = Logger.getLogger(FMMultiLabelTrainer.class.getName());

    @Config(description = "The classification objective function to use.")
    private MultiLabelObjective objective;

    public FMMultiLabelTrainer(MultiLabelObjective multiLabelObjective, StochasticGradientOptimiser stochasticGradientOptimiser, int i, int i2, int i3, long j, int i4, double d) {
        super(stochasticGradientOptimiser, i, i2, i3, j, i4, d);
        this.objective = new BinaryCrossEntropy();
        this.objective = multiLabelObjective;
    }

    public FMMultiLabelTrainer(MultiLabelObjective multiLabelObjective, StochasticGradientOptimiser stochasticGradientOptimiser, int i, int i2, long j, int i3, double d) {
        this(multiLabelObjective, stochasticGradientOptimiser, i, i2, 1, j, i3, d);
    }

    public FMMultiLabelTrainer(MultiLabelObjective multiLabelObjective, StochasticGradientOptimiser stochasticGradientOptimiser, int i, long j, int i2, double d) {
        this(multiLabelObjective, stochasticGradientOptimiser, i, 1000, 1, j, i2, d);
    }

    private FMMultiLabelTrainer() {
        this.objective = new BinaryCrossEntropy();
    }

    protected SparseVector getTarget(ImmutableOutputInfo<MultiLabel> immutableOutputInfo, MultiLabel multiLabel) {
        return multiLabel.convertToSparseVector(immutableOutputInfo);
    }

    protected SGDObjective<SGDVector> getObjective() {
        return this.objective;
    }

    protected FMMultiLabelModel createModel(String str, ModelProvenance modelProvenance, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo<MultiLabel> immutableOutputInfo, FMParameters fMParameters) {
        return new FMMultiLabelModel(str, modelProvenance, immutableFeatureMap, immutableOutputInfo, fMParameters, this.objective.getNormalizer(), this.objective.isProbabilistic(), this.objective.threshold());
    }

    protected String getModelClassName() {
        return FMMultiLabelModel.class.getName();
    }

    public String toString() {
        return "FMMultiLabelTrainer(objective=" + this.objective.toString() + ",optimiser=" + this.optimiser.toString() + ",epochs=" + this.epochs + ",minibatchSize=" + this.minibatchSize + ",seed=" + this.seed + ",factorizedDimSize=" + this.factorizedDimSize + ",variance=" + this.variance + ")";
    }

    protected /* bridge */ /* synthetic */ Model createModel(String str, ModelProvenance modelProvenance, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo immutableOutputInfo, FeedForwardParameters feedForwardParameters) {
        return createModel(str, modelProvenance, immutableFeatureMap, (ImmutableOutputInfo<MultiLabel>) immutableOutputInfo, (FMParameters) feedForwardParameters);
    }

    protected /* bridge */ /* synthetic */ Object getTarget(ImmutableOutputInfo immutableOutputInfo, Output output) {
        return getTarget((ImmutableOutputInfo<MultiLabel>) immutableOutputInfo, (MultiLabel) output);
    }
}
