package org.tribuo.common.liblinear;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import de.bwaldvogel.liblinear.FeatureNode;
import de.bwaldvogel.liblinear.Linear;
import de.bwaldvogel.liblinear.Model;
import de.bwaldvogel.liblinear.Parameter;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.SplittableRandom;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.Trainer;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.util.Util;

/* loaded from: input_file:org/tribuo/common/liblinear/LibLinearTrainer.class */
/**
 * Base class for Tribuo trainers backed by a LibLinear solver.
 * <p>
 * This class owns the solver parameters, the seeded RNG used to make each training run
 * reproducible, and the provenance bookkeeping. Subclasses supply the output-specific pieces:
 * converting a dataset into LibLinear's sparse format ({@link #extractData}), running the solver
 * ({@link #trainModels}) and wrapping the result ({@link #createModel}).
 *
 * @param <T> the output type.
 */
public abstract class LibLinearTrainer<T extends Output<T>> implements Trainer<T> {
    private static final Logger logger = Logger.getLogger(LibLinearTrainer.class.getName());

    /** Default cost penalty. */
    private static final double DEFAULT_COST = 1.0;
    /** Default iteration cap. */
    private static final int DEFAULT_MAX_ITERATIONS = 1000;
    /** Default termination criterion. */
    private static final double DEFAULT_TERMINATION_CRITERION = 0.1;
    /** Default regression epsilon. */
    private static final double DEFAULT_EPSILON = 0.1;
    /** Default RNG seed. */
    private static final long DEFAULT_SEED = 12345L;

    /** LibLinear solver parameters, rebuilt from the configured fields by {@link #postConfig()}. */
    protected Parameter libLinearParams;

    @Config(description = "Algorithm to use.")
    protected LibLinearType<T> trainerType;

    @Config(description = "Cost penalty for misclassifications.")
    protected double cost = DEFAULT_COST;

    @Config(description = "Maximum number of iterations before terminating.")
    protected int maxIterations = DEFAULT_MAX_ITERATIONS;

    @Config(description = "Stop iterating when the loss score decreases by less than this value.")
    protected double terminationCriterion = DEFAULT_TERMINATION_CRITERION;

    @Config(description = "Epsilon insensitivity in the regression cost function.")
    protected double epsilon = DEFAULT_EPSILON;

    @Config(description = "RNG seed.")
    protected long seed = DEFAULT_SEED;

    // Source of per-run randomness. It is split once per training call while holding the
    // instance lock, so each invocation number maps to a reproducible stream.
    private SplittableRandom rng;
    // Number of completed calls to train; mutated only while holding the instance lock.
    private int trainInvocationCount = 0;

    /**
     * No-argument constructor that leaves every parameter at its field default.
     * <p>
     * {@link #postConfig()} must run before training, because it builds {@code libLinearParams}
     * and the RNG.
     */
    protected LibLinearTrainer() {}

    /**
     * Creates a trainer using the default regression epsilon and the default seed.
     *
     * @param trainerType          the LibLinear algorithm.
     * @param cost                 the cost penalty.
     * @param maxIterations        the iteration cap.
     * @param terminationCriterion the convergence tolerance.
     */
    protected LibLinearTrainer(LibLinearType<T> trainerType, double cost, int maxIterations, double terminationCriterion) {
        this(trainerType, cost, maxIterations, terminationCriterion, DEFAULT_EPSILON);
    }

    /**
     * Creates a trainer using the default regression epsilon.
     *
     * @param trainerType          the LibLinear algorithm.
     * @param cost                 the cost penalty.
     * @param maxIterations        the iteration cap.
     * @param terminationCriterion the convergence tolerance.
     * @param seed                 the RNG seed.
     */
    protected LibLinearTrainer(LibLinearType<T> trainerType, double cost, int maxIterations, double terminationCriterion, long seed) {
        this(trainerType, cost, maxIterations, terminationCriterion, DEFAULT_EPSILON, seed);
    }

    /**
     * Creates a trainer using the default seed.
     *
     * @param trainerType          the LibLinear algorithm.
     * @param cost                 the cost penalty.
     * @param maxIterations        the iteration cap.
     * @param terminationCriterion the convergence tolerance.
     * @param epsilon              the regression epsilon-insensitivity.
     */
    protected LibLinearTrainer(LibLinearType<T> trainerType, double cost, int maxIterations, double terminationCriterion, double epsilon) {
        this(trainerType, cost, maxIterations, terminationCriterion, epsilon, DEFAULT_SEED);
    }

    /**
     * Creates a fully specified trainer and immediately builds the solver parameters and RNG.
     *
     * @param trainerType          the LibLinear algorithm.
     * @param cost                 the cost penalty.
     * @param maxIterations        the iteration cap.
     * @param terminationCriterion the convergence tolerance.
     * @param epsilon              the regression epsilon-insensitivity.
     * @param seed                 the RNG seed.
     */
    protected LibLinearTrainer(LibLinearType<T> trainerType, double cost, int maxIterations, double terminationCriterion, double epsilon, long seed) {
        this.trainerType = trainerType;
        this.cost = cost;
        this.maxIterations = maxIterations;
        this.terminationCriterion = terminationCriterion;
        this.epsilon = epsilon;
        this.seed = seed;
        postConfig();
    }

    /**
     * Builds the LibLinear {@link Parameter} from the configured fields and reseeds the RNG.
     * <p>
     * It also disables LibLinear's debug output.
     */
    public void postConfig() {
        this.libLinearParams = new Parameter(this.trainerType.getSolverType(), this.cost, this.terminationCriterion, this.maxIterations, this.epsilon);
        this.rng = new SplittableRandom(this.seed);
        Linear.disableDebugOutput();
    }

    /**
     * Trains a model on {@code dataset} with empty run provenance.
     * <p>
     * The name comes from the decompiled bytecode (originally {@code train}). It is kept so that
     * existing callers still compile.
     *
     * @param dataset the training data.
     * @return the trained model.
     */
    public LibLinearModel<T> m5train(Dataset<T> dataset) {
        return train(dataset, Collections.emptyMap());
    }

    /**
     * Trains a model without changing the invocation count.
     *
     * @param dataset       the training data.
     * @param runProvenance extra provenance recorded in the model.
     * @return the trained model.
     */
    public LibLinearModel<T> train(Dataset<T> dataset, Map<String, Provenance> runProvenance) {
        // -1 means "keep the current invocation count".
        return train(dataset, runProvenance, -1);
    }

    /**
     * Trains a model on the supplied dataset.
     *
     * @param dataset         the training data.
     * @param runProvenance   extra provenance recorded in the model.
     * @param invocationCount if not {@code -1}, the invocation count to rewind or advance to
     *                        before training.
     * @return the trained model.
     * @throws IllegalArgumentException if the dataset contains unknown outputs, or if
     *                                  {@code invocationCount} is negative and not {@code -1}.
     */
    public LibLinearModel<T> train(Dataset<T> dataset, Map<String, Provenance> runProvenance, int invocationCount) {
        if (dataset.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        SplittableRandom localRng;
        TrainerProvenance trainerProvenance;
        // Take the RNG split, the provenance snapshot and the counter increment atomically, so the
        // recorded provenance matches the random stream this run actually uses.
        synchronized (this) {
            if (invocationCount != -1) {
                setInvocationCount(invocationCount);
            }
            localRng = rng.split();
            trainerProvenance = m6getProvenance();
            trainInvocationCount++;
        }
        ImmutableFeatureMap featureIDMap = dataset.getFeatureIDMap();
        ImmutableOutputInfo<T> outputIDInfo = dataset.getOutputIDInfo();
        Parameter parameters = setupParameters(outputIDInfo);
        parameters.setRandom(new Random(localRng.nextLong()));
        ModelProvenance modelProvenance = new ModelProvenance(LibLinearModel.class.getName(), OffsetDateTime.now(), dataset.getProvenance(), trainerProvenance, runProvenance);
        Pair<FeatureNode[][], double[][]> data = extractData(dataset, outputIDInfo, featureIDMap);
        // The extra dimension matches the constant 1.0 node that exampleToNodes appends at
        // index featureIDMap.size() + 1.
        List<Model> models = trainModels(parameters, featureIDMap.size() + 1, data.getA(), data.getB());
        return createModel(modelProvenance, featureIDMap, outputIDInfo, models);
    }

    /**
     * Returns the number of training calls made so far.
     * <p>
     * The method is synchronized because the counter is only written under the instance lock.
     * Without the lock, a reader could observe a stale value.
     *
     * @return the invocation count.
     */
    public synchronized int getInvocationCount() {
        return this.trainInvocationCount;
    }

    /**
     * Resets the RNG from the seed and advances it by {@code invocationCount} splits.
     * <p>
     * This makes the next training call behave as if it were call number {@code invocationCount}.
     *
     * @param invocationCount the target invocation count.
     * @throws IllegalArgumentException if {@code invocationCount} is negative.
     */
    public synchronized void setInvocationCount(int invocationCount) {
        if (invocationCount < 0) {
            throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
        }
        this.rng = new SplittableRandom(this.seed);
        for (this.trainInvocationCount = 0; this.trainInvocationCount < invocationCount; this.trainInvocationCount++) {
            this.rng.split();
        }
    }

    @Override
    public String toString() {
        return "LibLinearTrainer(solver=" + this.libLinearParams.getSolverType()
                + ",cost=" + this.libLinearParams.getC()
                + ",terminationCriterion=" + this.libLinearParams.getEps()
                + ",maxIterations=" + this.libLinearParams.getMaxIters()
                + ",regression-epsilon=" + this.libLinearParams.getP()
                + ",seed=" + this.seed + ')';
    }

    /**
     * Runs the LibLinear solver.
     *
     * @param parameters  a per-run copy of the solver parameters, with its RNG already set.
     * @param numFeatures the feature dimension, including the constant node.
     * @param features    the sparse feature rows.
     * @param outputs     the output targets, in the layout chosen by {@link #extractData}.
     * @return the trained LibLinear models.
     */
    protected abstract List<Model> trainModels(Parameter parameters, int numFeatures, FeatureNode[][] features, double[][] outputs);

    /**
     * Wraps trained LibLinear models in a Tribuo model.
     *
     * @param provenance   the model provenance.
     * @param featureIDMap the feature domain.
     * @param outputIDInfo the output domain.
     * @param models       the models returned by {@link #trainModels}.
     * @return the Tribuo model.
     */
    protected abstract LibLinearModel<T> createModel(ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, List<Model> models);

    /**
     * Converts a dataset into LibLinear's sparse feature rows and output targets.
     *
     * @param dataset      the training data.
     * @param outputIDInfo the output domain.
     * @param featureIDMap the feature domain.
     * @return a pair of (feature rows, output targets).
     */
    protected abstract Pair<FeatureNode[][], double[][]> extractData(Dataset<T> dataset, ImmutableOutputInfo<T> outputIDInfo, ImmutableFeatureMap featureIDMap);

    /**
     * Returns a per-run copy of the solver parameters.
     * <p>
     * Subclasses may override this to adjust the copy based on the output domain.
     *
     * @param outputIDInfo the output domain (unused by this base implementation).
     * @return a clone of {@link #libLinearParams}.
     */
    protected Parameter setupParameters(ImmutableOutputInfo<T> outputIDInfo) {
        return this.libLinearParams.clone();
    }

    /**
     * Converts an example into a sorted array of LibLinear feature nodes.
     * <p>
     * LibLinear indices are the feature id plus one. Features with a negative id are skipped,
     * which presumably means names absent from {@code featureIDMap} (verify {@code getID}'s
     * contract). Repeated ids have their values summed. A constant node with value 1.0 is
     * appended at index {@code featureIDMap.size() + 1}.
     *
     * @param example      the example to convert.
     * @param featureIDMap the feature domain.
     * @param features     an optional scratch buffer, cleared and reused; may be {@code null}.
     * @param <T>          the output type.
     * @return the feature nodes, sorted by index, with the constant node last.
     */
    public static <T extends Output<T>> FeatureNode[] exampleToNodes(Example<T> example, ImmutableFeatureMap featureIDMap, List<FeatureNode> features) {
        int constantIndex = featureIDMap.size() + 1;
        List<FeatureNode> buffer = (features == null) ? new ArrayList<>() : features;
        buffer.clear();
        // Largest id appended so far. While ids arrive in increasing order, appending keeps the
        // buffer sorted; out-of-order or repeated ids fall back to a binary search.
        int maxID = -1;
        for (Feature feature : example) {
            int id = featureIDMap.getID(feature.getName());
            if (id > maxID) {
                maxID = id;
                buffer.add(new FeatureNode(id + 1, feature.getValue()));
            } else if (id > -1) {
                int position = Util.binarySearch(buffer, id + 1, FeatureNode::getIndex);
                if (position < 0) {
                    buffer.add(-(position + 1), new FeatureNode(id + 1, feature.getValue()));
                } else {
                    FeatureNode existing = buffer.get(position);
                    existing.setValue(existing.getValue() + feature.getValue());
                }
            }
        }
        buffer.add(new FeatureNode(constantIndex, 1.0));
        return buffer.toArray(new FeatureNode[0]);
    }

    /**
     * Returns the provenance of this trainer's current configuration.
     * <p>
     * The name comes from the decompiled bytecode (originally {@code getProvenance}).
     *
     * @return the trainer provenance.
     */
    public TrainerProvenance m6getProvenance() {
        return new TrainerProvenanceImpl(this);
    }

    /**
     * Raw-typed bridge that forwards to {@link #train(Dataset, Map, int)}.
     * <p>
     * The raw signature is kept for compatibility with existing callers.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public org.tribuo.Model m3train(Dataset dataset, Map map, int invocationCount) {
        return train(dataset, (Map<String, Provenance>) map, invocationCount);
    }

    /**
     * Raw-typed bridge that forwards to {@link #train(Dataset, Map)}.
     * <p>
     * The raw signature is kept for compatibility with existing callers.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public org.tribuo.Model m4train(Dataset dataset, Map map) {
        return train(dataset, (Map<String, Provenance>) map);
    }
}
