package org.tribuo.classification.ensemble;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ListProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.ImmutableDataset;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.Trainer;
import org.tribuo.WeightedExamples;
import org.tribuo.classification.Label;
import org.tribuo.dataset.DatasetView;
import org.tribuo.ensemble.WeightedEnsembleModel;
import org.tribuo.provenance.EnsembleModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.util.Util;

/* loaded from: input_file:org/tribuo/classification/ensemble/AdaBoostTrainer.class */
/**
 * Trains a boosted ensemble of classifiers using AdaBoost.
 * <p>
 * Each round trains a weak learner with {@link #innerTrainer}, scores it by weighted accuracy on the
 * training data, and multiplies the weight of every misclassified example by {@code exp(alpha)}
 * before renormalising. A learner with weighted accuracy {@code acc} on {@code K} classes receives
 * the ensemble weight {@code alpha = log(acc / (1 - acc)) + log(K - 1)}, which is the multi-class
 * (SAMME) form of the AdaBoost learner weight.
 * <p>
 * If the inner trainer implements {@link WeightedExamples}, the example weights are written directly
 * onto a copy of the dataset. Otherwise each round trains on a weighted bootstrap sample drawn
 * according to the current example weights.
 * <p>
 * The RNG and the invocation counter are only mutated while holding this trainer's monitor.
 */
public class AdaBoostTrainer implements Trainer<Label> {
    private static final Logger logger = Logger.getLogger(AdaBoostTrainer.class.getName());

    /** Value of {@code invocationCount} meaning "use the current RNG state and bump the counter". */
    private static final int AUTO_INCREMENT_INVOCATION = -1;

    /** Seed used by the two-argument constructor. */
    private static final long DEFAULT_RNG_SEED = 12345L;

    @Config(mandatory = true, description = "The trainer to use to build each weak learner.")
    protected Trainer<Label> innerTrainer;

    @Config(mandatory = true, description = "The number of ensemble members to train.")
    protected int numMembers;

    @Config(mandatory = true, description = "The seed for the RNG.")
    protected long seed;

    /** Source of per-training-run RNGs; rebuilt from {@link #seed} in {@link #postConfig()}. */
    protected SplittableRandom rng;

    /** Number of times {@code train} has consumed a split of {@link #rng}. */
    protected int trainInvocationCounter;

    /**
     * No-arg constructor, presumably for reflective instantiation by the OLCUT configuration system.
     * It leaves {@link #rng} unset until {@link #postConfig()} runs.
     */
    private AdaBoostTrainer() {
    }

    /**
     * Constructs an AdaBoost trainer seeded with {@code 12345L}.
     *
     * @param trainer    the trainer used to build each weak learner
     * @param numMembers the maximum number of ensemble members to train
     */
    public AdaBoostTrainer(Trainer<Label> trainer, int numMembers) {
        this(trainer, numMembers, DEFAULT_RNG_SEED);
    }

    /**
     * Constructs an AdaBoost trainer.
     *
     * @param trainer    the trainer used to build each weak learner
     * @param numMembers the maximum number of ensemble members to train
     * @param seed       seed for the RNG that drives bootstrap sampling
     */
    public AdaBoostTrainer(Trainer<Label> trainer, int numMembers, long seed) {
        this.innerTrainer = trainer;
        this.numMembers = numMembers;
        this.seed = seed;
        postConfig();
    }

    /** (Re)creates the RNG from {@link #seed}. */
    @Override
    public synchronized void postConfig() {
        this.rng = new SplittableRandom(this.seed);
    }

    @Override
    public String toString() {
        return "AdaBoostTrainer(innerTrainer=" + this.innerTrainer.toString() + ",numMembers=" + this.numMembers + ",seed=" + this.seed + ")";
    }

    /**
     * Trains a boosted ensemble using the current RNG state and increments the invocation count.
     *
     * @param dataset the training data
     * @param map     extra run provenance recorded in the returned model
     * @return the boosted ensemble
     */
    @Override
    public Model<Label> train(Dataset<Label> dataset, Map<String, Provenance> map) {
        return train(dataset, map, AUTO_INCREMENT_INVOCATION);
    }

    /**
     * Trains a boosted ensemble.
     *
     * @param dataset the training data; must be non-empty and contain no unknown outputs
     * @param map     extra run provenance recorded in the returned model
     * @param i       the invocation count. Unless it is {@code -1}, the RNG and invocation counter are
     *                first reset as if {@code train} had already been called that many times.
     * @return a {@link WeightedEnsembleModel} combining the weak learners with a {@link VotingCombiner}.
     *         If some learner reaches perfect weighted accuracy, training stops early. In that case
     *         the ensemble holds only the learners built so far, and the last one gets weight 1.
     * @throws IllegalArgumentException if the dataset contains unknown outputs or is empty, or if
     *                                  {@code i} is negative and not {@code -1}
     */
    public Model<Label> train(Dataset<Label> dataset, Map<String, Provenance> map, int i) {
        if (dataset.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        if (dataset.size() == 0) {
            // Checked before touching the RNG so a rejected call does not advance the invocation count.
            throw new IllegalArgumentException("The supplied Dataset was empty.");
        }

        SplittableRandom localRng;
        TrainerProvenance trainerProvenance;
        synchronized (this) {
            if (i != AUTO_INCREMENT_INVOCATION) {
                setInvocationCount(i);
            }
            localRng = this.rng.split();
            // Captured under the lock so the provenance matches the RNG state used for this run.
            trainerProvenance = getProvenance();
            this.trainInvocationCounter++;
        }

        boolean weighted = this.innerTrainer instanceof WeightedExamples;
        ImmutableFeatureMap featureIDs = dataset.getFeatureIDMap();
        ImmutableOutputInfo<Label> labelIDs = dataset.getOutputIDInfo();
        int numClasses = labelIDs.size();
        logger.log(Level.INFO, "NumClasses = " + numClasses);

        List<Model<Label>> models = new ArrayList<>();
        float[] modelWeights = new float[this.numMembers];
        float[] exampleWeights = Util.generateUniformFloatVector(dataset.size(), 1.0f / dataset.size());

        Dataset<Label> examples = dataset;
        if (weighted) {
            logger.info("Using weighted Adaboost.");
            // Weights are set on a copy; presumably this keeps the caller's examples unmodified.
            examples = ImmutableDataset.copyDataset(dataset);
            applyWeights(examples, exampleWeights);
        } else {
            logger.info("Using sampling Adaboost.");
        }

        for (int round = 0; round < this.numMembers; round++) {
            logger.info("Building model " + round);
            Model<Label> learner;
            if (weighted) {
                learner = this.innerTrainer.train(examples);
            } else {
                learner = this.innerTrainer.train(DatasetView.createWeightedBootstrapView(examples, examples.size(), localRng.nextLong(), exampleWeights, featureIDs, labelIDs));
            }

            List<Prediction<Label>> predictions = learner.predict(examples);
            float accuracy = accuracy(predictions, examples, exampleWeights);
            // log(numClasses - 1) is 0 for binary problems, reducing this to classic AdaBoost.
            float alpha = (float) (Math.log(accuracy / (1.0f - accuracy)) + Math.log(numClasses - 1));
            models.add(learner);
            modelWeights[round] = alpha;

            if (accuracy + 1.0E-10d > 1.0d) {
                // With perfect accuracy, acc / (1 - acc) is unbounded and alpha is not a usable weight.
                // Stop boosting and give the final learner weight 1 instead.
                float[] truncatedWeights = Arrays.copyOf(modelWeights, models.size());
                truncatedWeights[models.size() - 1] = 1.0f;
                logger.log(Level.FINE, "Perfect accuracy reached on iteration " + round + ", returning current model.");
                logger.log(Level.FINE, "Model weights:");
                Util.logVector(logger, Level.FINE, truncatedWeights);
                return buildEnsemble(examples, trainerProvenance, map, featureIDs, labelIDs, models, truncatedWeights);
            }

            // Up-weight the misclassified examples, then renormalise to a distribution.
            double boost = Math.exp(alpha);
            for (int j = 0; j < predictions.size(); j++) {
                if (!predictions.get(j).getOutput().equals(examples.getExample(j).getOutput())) {
                    exampleWeights[j] = (float) (exampleWeights[j] * boost);
                }
            }
            Util.inplaceNormalizeToDistribution(exampleWeights);
            if (weighted) {
                applyWeights(examples, exampleWeights);
            }
        }

        logger.log(Level.FINE, "Model weights:");
        Util.logVector(logger, Level.FINE, modelWeights);
        return buildEnsemble(examples, trainerProvenance, map, featureIDs, labelIDs, models, modelWeights);
    }

    /**
     * Returns the number of times {@code train} has been invoked, including simulated invocations
     * from {@link #setInvocationCount(int)}.
     */
    @Override
    public int getInvocationCount() {
        return this.trainInvocationCounter;
    }

    /**
     * Resets the RNG from {@link #seed} and advances it as if {@code train} had been called
     * {@code i} times.
     *
     * @param i the number of train invocations to simulate
     * @throws IllegalArgumentException if {@code i} is negative
     */
    public synchronized void setInvocationCount(int i) {
        if (i < 0) {
            throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
        }
        this.rng = new SplittableRandom(this.seed);
        for (this.trainInvocationCounter = 0; this.trainInvocationCounter < i; this.trainInvocationCounter++) {
            this.rng.split();
        }
    }

    /** Sets example {@code k}'s weight to {@code weights[k]} for every example in {@code examples}. */
    private static void applyWeights(Dataset<Label> examples, float[] weights) {
        for (int k = 0; k < examples.size(); k++) {
            examples.getExample(k).setWeight(weights[k]);
        }
    }

    /** Wraps the learners and their weights in a voting {@link WeightedEnsembleModel} with provenance. */
    private static Model<Label> buildEnsemble(Dataset<Label> examples, TrainerProvenance trainerProvenance,
                                              Map<String, Provenance> runProvenance, ImmutableFeatureMap featureIDs,
                                              ImmutableOutputInfo<Label> labelIDs, List<Model<Label>> models,
                                              float[] weights) {
        EnsembleModelProvenance provenance = new EnsembleModelProvenance(WeightedEnsembleModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance, ListProvenance.createListProvenance(models));
        return new WeightedEnsembleModel<>("boosted-ensemble", provenance, featureIDs, labelIDs, models, new VotingCombiner(), weights);
    }

    /**
     * Computes the weighted accuracy: the weight of correctly predicted examples divided by the total
     * weight. Prediction {@code k} is compared against example {@code k} of {@code examples}.
     */
    private static float accuracy(List<Prediction<Label>> predictions, Dataset<Label> examples, float[] weights) {
        float correctWeight = 0.0f;
        float totalWeight = 0.0f;
        for (int k = 0; k < predictions.size(); k++) {
            if (predictions.get(k).getOutput().equals(examples.getExample(k).getOutput())) {
                correctWeight += weights[k];
            }
            totalWeight += weights[k];
        }
        logger.log(Level.FINEST, "Correct count = " + correctWeight + " size = " + examples.size());
        return correctWeight / totalWeight;
    }

    /** Returns provenance describing this trainer's current configuration and state. */
    @Override
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }
}
