package org.tribuo.regression.sgd.fm;

import com.oracle.labs.mlrg.olcut.config.Config;
import java.util.Iterator;
import java.util.logging.Logger;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.common.sgd.AbstractFMTrainer;
import org.tribuo.common.sgd.FMParameters;
import org.tribuo.common.sgd.SGDObjective;
import org.tribuo.math.FeedForwardParameters;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.DenseVector;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.regression.ImmutableRegressionInfo;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.sgd.RegressionObjective;

/* loaded from: input_file:org/tribuo/regression/sgd/fm/FMRegressionTrainer.class */
/**
 * Trainer that builds an {@link FMRegressionModel} on top of {@link AbstractFMTrainer}.
 * <p>
 * This class supplies the regression-specific pieces the superclass asks for: the
 * {@link RegressionObjective}, the conversion of a {@link Regressor} into a dense
 * target vector, and the construction of the final model. Optionally the regression
 * targets are rescaled per dimension before fitting (see {@link #getTarget}).
 * <p>
 * NOTE(review): the constructor parameter names below assume the superclass
 * constructor order is (optimiser, epochs, loggingInterval, minibatchSize, seed,
 * factorizedDimSize, variance) - confirm against {@code AbstractFMTrainer}.
 */
public class FMRegressionTrainer extends AbstractFMTrainer<Regressor, DenseVector, FMRegressionModel> {
    private static final Logger logger = Logger.getLogger(FMRegressionTrainer.class.getName());

    @Config(mandatory = true, description = "The regression objective to use.")
    private RegressionObjective objective;

    @Config(mandatory = true, description = "Standardise the output variables before fitting the model.")
    private boolean standardise;

    /**
     * Constructs a trainer with every setting supplied explicitly.
     *
     * @param objective         The regression objective to use.
     * @param optimiser         The gradient optimiser forwarded to the superclass.
     * @param epochs            Forwarded to the superclass (presumably the number of training epochs).
     * @param loggingInterval   Forwarded to the superclass (presumably the logging interval).
     * @param minibatchSize     Forwarded to the superclass (presumably the minibatch size).
     * @param seed              Forwarded to the superclass (presumably the RNG seed).
     * @param factorizedDimSize Forwarded to the superclass (presumably the factor dimension size).
     * @param variance          Forwarded to the superclass (presumably the factor initialisation variance).
     * @param standardise       Standardise the output variables before fitting the model.
     */
    public FMRegressionTrainer(RegressionObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed, int factorizedDimSize, double variance, boolean standardise) {
        super(optimiser, epochs, loggingInterval, minibatchSize, seed, factorizedDimSize, variance);
        this.objective = objective;
        this.standardise = standardise;
    }

    /**
     * Constructs a trainer, passing {@code 1} for the third integer argument of the
     * full constructor (presumably the minibatch size).
     *
     * @param objective         The regression objective to use.
     * @param optimiser         The gradient optimiser forwarded to the superclass.
     * @param epochs            See the full constructor.
     * @param loggingInterval   See the full constructor.
     * @param seed              See the full constructor.
     * @param factorizedDimSize See the full constructor.
     * @param variance          See the full constructor.
     * @param standardise       Standardise the output variables before fitting the model.
     */
    public FMRegressionTrainer(RegressionObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, long seed, int factorizedDimSize, double variance, boolean standardise) {
        this(objective, optimiser, epochs, loggingInterval, 1, seed, factorizedDimSize, variance, standardise);
    }

    /**
     * Constructs a trainer, passing {@code 1000} and {@code 1} for the second and third
     * integer arguments of the full constructor (presumably the logging interval and
     * the minibatch size).
     *
     * @param objective         The regression objective to use.
     * @param optimiser         The gradient optimiser forwarded to the superclass.
     * @param epochs            See the full constructor.
     * @param seed              See the full constructor.
     * @param factorizedDimSize See the full constructor.
     * @param variance          See the full constructor.
     * @param standardise       Standardise the output variables before fitting the model.
     */
    public FMRegressionTrainer(RegressionObjective objective, StochasticGradientOptimiser optimiser, int epochs, long seed, int factorizedDimSize, double variance, boolean standardise) {
        this(objective, optimiser, epochs, 1000, 1, seed, factorizedDimSize, variance, standardise);
    }

    /**
     * No-argument constructor; leaves the {@code @Config} fields unset.
     * Presumably present so the configuration system can instantiate the class
     * reflectively and populate the annotated fields afterwards.
     */
    private FMRegressionTrainer() {
    }

    /**
     * Converts a regressor into a dense target vector indexed by output id.
     * <p>
     * Each dimension of {@code output} is written to the slot given by
     * {@code outputInfo.getID(dimension)}. When standardisation is enabled the value is
     * first shifted by the dimension's mean and divided by the dimension's variance.
     *
     * @param outputInfo The output domain; must be an {@link ImmutableRegressionInfo}.
     * @param output     The regressor to convert.
     * @return A dense vector with one element per output dimension in {@code outputInfo}.
     */
    @Override
    protected DenseVector getTarget(ImmutableOutputInfo<Regressor> outputInfo, Regressor output) {
        ImmutableRegressionInfo regressionInfo = (ImmutableRegressionInfo) outputInfo;
        double[] targets = new double[outputInfo.size()];
        for (Regressor.DimensionTuple dimension : output) {
            int id = outputInfo.getID(dimension);
            double value = dimension.getValue();
            if (this.standardise) {
                // NOTE(review): this scales by the variance rather than the standard deviation.
                // The inverse transform lives in FMRegressionModel (not visible here) and must
                // use the same quantity, so do not change one side without the other.
                value = (value - regressionInfo.getMean(id)) / regressionInfo.getVariance(id);
            }
            targets[id] = value;
        }
        return DenseVector.createDenseVector(targets);
    }

    /**
     * Returns the configured regression objective.
     *
     * @return The objective supplied at construction or via configuration.
     */
    @Override
    protected SGDObjective<DenseVector> getObjective() {
        return this.objective;
    }

    /**
     * Builds the regression model from the trained parameters.
     * <p>
     * The dimension name array is ordered by output id, taking the first name of each
     * regressor in the output domain.
     *
     * @param name       The model name.
     * @param provenance The model provenance.
     * @param featureMap The feature domain.
     * @param outputInfo The output domain.
     * @param parameters The trained parameters.
     * @return A new {@link FMRegressionModel} carrying this trainer's standardise flag.
     */
    @Override
    protected FMRegressionModel createModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureMap, ImmutableOutputInfo<Regressor> outputInfo, FMParameters parameters) {
        String[] dimensionNames = new String[outputInfo.size()];
        for (Regressor dimension : outputInfo.getDomain()) {
            dimensionNames[outputInfo.getID(dimension)] = dimension.getNames()[0];
        }
        return new FMRegressionModel(name, dimensionNames, provenance, featureMap, outputInfo, parameters, this.standardise);
    }

    /**
     * Returns the fully qualified class name of the model this trainer produces.
     *
     * @return The class name of {@link FMRegressionModel}.
     */
    @Override
    protected String getModelClassName() {
        return FMRegressionModel.class.getName();
    }

    /**
     * Describes this trainer and its settings.
     * <p>
     * Uses implicit string conversion so an instance whose objective or optimiser has
     * not been set yet prints {@code null} instead of throwing.
     *
     * @return A string of the form {@code FMRegressionTrainer(objective=...,...)}.
     */
    @Override
    public String toString() {
        return "FMRegressionTrainer(objective=" + this.objective + ",optimiser=" + this.optimiser + ",epochs=" + this.epochs + ",minibatchSize=" + this.minibatchSize + ",seed=" + this.seed + ",factorizedDimSize=" + this.factorizedDimSize + ",variance=" + this.variance + ",standardise=" + this.standardise + ")";
    }

    // The erased-signature bridge methods for createModel and getTarget are generated by
    // the compiler from the typed overrides above; declaring them by hand would clash
    // with those generated bridges, so they are intentionally not written out here.
}
