package org.tribuo.regression.liblinear;

import com.oracle.labs.mlrg.olcut.util.Pair;
import de.bwaldvogel.liblinear.FeatureNode;
import de.bwaldvogel.liblinear.Linear;
import de.bwaldvogel.liblinear.Model;
import de.bwaldvogel.liblinear.Parameter;
import de.bwaldvogel.liblinear.Problem;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.common.liblinear.LibLinearModel;
import org.tribuo.common.liblinear.LibLinearTrainer;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.regression.ImmutableRegressionInfo;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.liblinear.LinearRegressionType;

/* loaded from: input_file:org/tribuo/regression/liblinear/LibLinearRegressionTrainer.class */
/**
 * A {@link LibLinearTrainer} for {@link Regressor} outputs.
 * <p>
 * Regression is handled by training one independent LibLinear model per output dimension.
 * {@link #createModel} rejects a model list whose size differs from the number of output
 * dimensions.
 */
public class LibLinearRegressionTrainer extends LibLinearTrainer<Regressor> {
    private static final Logger logger = Logger.getLogger(LibLinearRegressionTrainer.class.getName());

    /**
     * When {@code true}, {@link #trainModels} reseeds the LibLinear {@link Parameter} with a fresh
     * {@code new Random(0L)} before training each output dimension. Initialised to {@code false}
     * by the constructors; package-private so it can be toggled from within the package.
     */
    boolean forceZero;

    /**
     * Creates a trainer using {@link LinearRegressionType.LinearType#L2R_L2LOSS_SVR} with the
     * default hyperparameters of {@link #LibLinearRegressionTrainer(LinearRegressionType)}.
     */
    public LibLinearRegressionTrainer() {
        this(new LinearRegressionType(LinearRegressionType.LinearType.L2R_L2LOSS_SVR));
    }

    /**
     * Creates a trainer of the given type with cost {@code 1.0}, {@code 1000} max iterations,
     * termination criterion {@code 0.1} and epsilon {@code 0.1}.
     *
     * @param trainerType The LibLinear regression algorithm to use.
     */
    public LibLinearRegressionTrainer(LinearRegressionType trainerType) {
        this(trainerType, 1.0d, 1000, 0.1d, 0.1d);
    }

    /**
     * Creates a trainer with the given hyperparameters and the default seed {@code 12345L}.
     *
     * @param trainerType          The LibLinear regression algorithm to use.
     * @param cost                 Cost value forwarded to {@link LibLinearTrainer}.
     * @param maxIterations        Iteration limit forwarded to {@link LibLinearTrainer}.
     * @param terminationCriterion Termination criterion forwarded to {@link LibLinearTrainer}.
     * @param epsilon              Epsilon value forwarded to {@link LibLinearTrainer}.
     */
    public LibLinearRegressionTrainer(LinearRegressionType trainerType, double cost, int maxIterations, double terminationCriterion, double epsilon) {
        this(trainerType, cost, maxIterations, terminationCriterion, epsilon, 12345L);
    }

    /**
     * Creates a trainer with the given hyperparameters and RNG seed.
     *
     * @param trainerType          The LibLinear regression algorithm to use.
     * @param cost                 Cost value forwarded to {@link LibLinearTrainer}.
     * @param maxIterations        Iteration limit forwarded to {@link LibLinearTrainer}.
     * @param terminationCriterion Termination criterion forwarded to {@link LibLinearTrainer}.
     * @param epsilon              Epsilon value forwarded to {@link LibLinearTrainer}.
     * @param seed                 RNG seed forwarded to {@link LibLinearTrainer}.
     */
    public LibLinearRegressionTrainer(LinearRegressionType trainerType, double cost, int maxIterations, double terminationCriterion, double epsilon, long seed) {
        super(trainerType, cost, maxIterations, terminationCriterion, epsilon, seed);
        this.forceZero = false;
    }

    /**
     * Runs the superclass post-configuration step and then checks that the configured trainer
     * type is a regression type.
     *
     * @throws IllegalArgumentException if the trainer type is not a regression type.
     */
    @Override
    public void postConfig() {
        super.postConfig();
        if (!this.trainerType.isRegression()) {
            throw new IllegalArgumentException("Supplied classification or anomaly detection parameters to a regression linear model.");
        }
    }

    /**
     * Trains one LibLinear model per output dimension, all sharing the same feature matrix.
     *
     * @param curParams   The LibLinear parameters used for every dimension.
     * @param numFeatures The number of features (assigned to {@link Problem#n}).
     * @param features    The feature rows, one per example.
     * @param outputs     The regression targets, indexed {@code [dimension][example]}.
     * @return The trained models, in the same order as the rows of {@code outputs}.
     */
    @Override
    protected List<Model> trainModels(Parameter curParams, int numFeatures, FeatureNode[][] features, double[][] outputs) {
        List<Model> models = new ArrayList<>(outputs.length);
        for (double[] targets : outputs) {
            Problem problem = new Problem();
            problem.l = features.length;
            problem.y = targets;
            problem.x = features;
            problem.n = numFeatures;
            problem.bias = 1.0d;
            if (this.forceZero) {
                // Reset inside the loop so each dimension is trained from the same RNG state.
                curParams.setRandom(new Random(0L));
            }
            models.add(Linear.train(problem, curParams));
        }
        return models;
    }

    /**
     * Wraps the per-dimension LibLinear models into a {@link LibLinearRegressionModel}.
     *
     * @param provenance   The model provenance.
     * @param featureIDMap The feature domain.
     * @param outputIDInfo The output domain.
     * @param models       One LibLinear model per output dimension.
     * @return The regression model.
     * @throws IllegalArgumentException if the number of models differs from the output dimension count.
     */
    @Override
    protected LibLinearModel<Regressor> createModel(ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> outputIDInfo, List<Model> models) {
        if (models.size() != outputIDInfo.size()) {
            throw new IllegalArgumentException("Regression uses one model per dimension. Found " + models.size() + " models, and " + outputIDInfo.size() + " dimensions.");
        }
        return new LibLinearRegressionModel("liblinear-regression-model", provenance, featureIDMap, outputIDInfo, models);
    }

    /**
     * Converts a dataset into LibLinear feature rows and a per-dimension target matrix.
     * <p>
     * Target values are read from each {@link Regressor} in natural order. They are stored at
     * the row given by the output info's natural-order-to-id mapping, so row {@code k} of the
     * returned targets corresponds to output id {@code k}.
     *
     * @param dataset    The training data.
     * @param outputInfo The output domain; must be an {@link ImmutableRegressionInfo}.
     * @param featureMap The feature domain.
     * @return A pair of (feature rows indexed {@code [example]}, targets indexed {@code [dimension][example]}).
     */
    @Override
    protected Pair<FeatureNode[][], double[][]> extractData(Dataset<Regressor> dataset, ImmutableOutputInfo<Regressor> outputInfo, ImmutableFeatureMap featureMap) {
        int numDimensions = outputInfo.size();
        int numExamples = dataset.size();
        int[] naturalOrderToIDMapping = ((ImmutableRegressionInfo) outputInfo).getNaturalOrderToIDMapping();
        // Scratch list handed to exampleToNodes on every call (presumably reused as a buffer).
        ArrayList<FeatureNode> buffer = new ArrayList<>();
        FeatureNode[][] features = new FeatureNode[numExamples][];
        double[][] outputs = new double[numDimensions][numExamples];
        int row = 0;
        for (Example<Regressor> example : dataset) {
            double[] values = example.getOutput().getValues();
            for (int dim = 0; dim < values.length; dim++) {
                outputs[naturalOrderToIDMapping[dim]][row] = values[dim];
            }
            features[row] = exampleToNodes(example, featureMap, buffer);
            row++;
        }
        return new Pair<>(features, outputs);
    }
}
