package org.tribuo.regression.sgd;

import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.config.UsageException;
import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.IOException;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.Model;
import org.tribuo.data.DataOptions;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.optimisers.GradientOptimiserOptions;
import org.tribuo.regression.RegressionFactory;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.evaluation.RegressionEvaluation;
import org.tribuo.regression.sgd.linear.LinearSGDTrainer;
import org.tribuo.regression.sgd.objectives.AbsoluteLoss;
import org.tribuo.regression.sgd.objectives.Huber;
import org.tribuo.regression.sgd.objectives.SquaredLoss;

/* loaded from: input_file:org/tribuo/regression/sgd/TrainTest.class */
public class TrainTest {
    private static final Logger logger = Logger.getLogger(TrainTest.class.getName());

    /* loaded from: input_file:org/tribuo/regression/sgd/TrainTest$LossEnum.class */
    public enum LossEnum {
        ABSOLUTE,
        SQUARED,
        HUBER
    }

    /* loaded from: input_file:org/tribuo/regression/sgd/TrainTest$SGDOptions.class */
    public static class SGDOptions implements Options {
        public DataOptions general;
        public GradientOptimiserOptions gradientOptions;

        @Option(charName = 'i', longName = "epochs", usage = "Number of SGD epochs.")
        public int epochs = 5;

        @Option(charName = 'o', longName = "objective", usage = "Loss function.")
        public LossEnum loss = LossEnum.SQUARED;

        @Option(charName = 'p', longName = "logging-interval", usage = "Log the objective after <int> examples.")
        public int loggingInterval = 100;

        @Option(charName = 'z', longName = "minibatch-size", usage = "Minibatch size.")
        public int minibatchSize = 1;

        public String getOptionsDescription() {
            return "Trains and tests a linear SGD regression model on the specified datasets.";
        }
    }

    public static void main(String[] strArr) throws IOException {
        RegressionObjective huber;
        LabsLogFormatter.setAllLogFormatters();
        SGDOptions sGDOptions = new SGDOptions();
        try {
            ConfigurationManager configurationManager = new ConfigurationManager(strArr, sGDOptions);
            if (sGDOptions.general.trainingPath == null || sGDOptions.general.testingPath == null) {
                logger.info(configurationManager.usage());
                return;
            }
            logger.info("Configuring gradient optimiser");
            switch (sGDOptions.loss) {
                case ABSOLUTE:
                    huber = new AbsoluteLoss();
                    break;
                case SQUARED:
                    huber = new SquaredLoss();
                    break;
                case HUBER:
                    huber = new Huber();
                    break;
                default:
                    logger.warning("Unknown objective function " + sGDOptions.loss);
                    logger.info(configurationManager.usage());
                    return;
            }
            StochasticGradientOptimiser optimiser = sGDOptions.gradientOptions.getOptimiser();
            logger.info(String.format("Set logging interval to %d", Integer.valueOf(sGDOptions.loggingInterval)));
            RegressionFactory regressionFactory = new RegressionFactory();
            Pair load = sGDOptions.general.load(regressionFactory);
            Dataset dataset = (Dataset) load.getA();
            Dataset dataset2 = (Dataset) load.getB();
            LinearSGDTrainer linearSGDTrainer = new LinearSGDTrainer(huber, optimiser, sGDOptions.epochs, sGDOptions.loggingInterval, sGDOptions.minibatchSize, sGDOptions.general.seed);
            logger.info("Training using " + linearSGDTrainer.toString());
            long currentTimeMillis = System.currentTimeMillis();
            Model train = linearSGDTrainer.train(dataset);
            logger.info("Finished training regressor " + org.tribuo.util.Util.formatDuration(currentTimeMillis, System.currentTimeMillis()));
            long currentTimeMillis2 = System.currentTimeMillis();
            RegressionEvaluation evaluate = regressionFactory.getEvaluator().evaluate(train, dataset2);
            logger.info("Finished evaluating model " + org.tribuo.util.Util.formatDuration(currentTimeMillis2, System.currentTimeMillis()));
            System.out.println(evaluate.toString());
            if (sGDOptions.general.outputPath != null) {
                sGDOptions.general.saveModel(train);
            }
        } catch (UsageException e) {
            logger.info(e.getMessage());
        }
    }
}
