package org.tribuo.classification.mnb;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.time.OffsetDateTime;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Trainer;
import org.tribuo.WeightedExamples;
import org.tribuo.classification.Label;
import org.tribuo.math.la.DenseSparseMatrix;
import org.tribuo.math.la.SparseVector;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;

/* loaded from: input_file:org/tribuo/classification/mnb/MultinomialNaiveBayesTrainer.class */
/**
 * Trains a {@link MultinomialNaiveBayesModel} from a labelled dataset.
 *
 * <p>For every label the trainer sums the (example weight * feature value) contributions of each
 * feature, then converts each stored sum {@code c} into
 * {@code log((c + alpha) / (Z + numFeatures * alpha))}, where {@code Z} is the one-norm of that
 * label's sums and {@code alpha} is the smoothing parameter.
 *
 * <p>Feature values must be non-negative. The invocation counter is a plain field updated without
 * synchronization.
 */
public class MultinomialNaiveBayesTrainer implements Trainer<Label>, WeightedExamples {

    /**
     * Sentinel invocation count meaning "do not overwrite the trainer's current counter".
     * This is the value the two-argument {@code train} overload passes through.
     */
    private static final int KEEP_CURRENT_INVOCATION_COUNT = -1;

    // Not final: carries a @Config annotation, so it is presumably populated reflectively by the
    // configuration system as well as by the constructor — TODO confirm.
    @Config(description = "Smoothing parameter.")
    private double alpha = 1.0d;

    /** Number of completed {@code train} calls, unless overridden via {@link #setInvocationCount(int)}. */
    private int trainInvocationCount = 0;

    /**
     * Creates a trainer with a smoothing parameter of 1.0.
     */
    public MultinomialNaiveBayesTrainer() {
        this(1.0d);
    }

    /**
     * Creates a trainer with the supplied smoothing parameter.
     *
     * @param alpha the smoothing parameter; must be finite and strictly positive.
     * @throws IllegalArgumentException if {@code alpha} is NaN, infinite, zero or negative.
     */
    public MultinomialNaiveBayesTrainer(double alpha) {
        // "!(alpha > 0)" rather than "alpha <= 0" so that NaN is rejected as well; an infinite
        // alpha is rejected because it turns every smoothed ratio into NaN.
        if (!(alpha > 0.0d) || Double.isInfinite(alpha)) {
            throw new IllegalArgumentException("alpha parameter must be > 0 and finite, found: " + alpha);
        }
        this.alpha = alpha;
    }

    /**
     * Trains a model, leaving the invocation counter at its current value before incrementing it.
     *
     * @param dataset the training data.
     * @param provenanceMap provenance entries passed through to the model's provenance.
     * @return the trained model.
     * @throws IllegalArgumentException if the dataset contains unknown outputs.
     * @throws IllegalStateException if any feature value is negative.
     */
    public Model<Label> train(Dataset<Label> dataset, Map<String, Provenance> provenanceMap) {
        return train(dataset, provenanceMap, KEEP_CURRENT_INVOCATION_COUNT);
    }

    /**
     * Trains a model, optionally resetting the invocation counter first.
     *
     * @param dataset the training data.
     * @param provenanceMap provenance entries passed through to the model's provenance.
     * @param invocationCount the value to set the invocation counter to before recording
     *     provenance, or {@code -1} to keep the current value.
     * @return the trained model.
     * @throws IllegalArgumentException if the dataset contains unknown outputs, or if
     *     {@code invocationCount} is negative and not {@code -1}.
     * @throws IllegalStateException if any feature value is negative.
     */
    public Model<Label> train(Dataset<Label> dataset, Map<String, Provenance> provenanceMap, int invocationCount) {
        if (dataset.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        ImmutableOutputInfo<Label> labelInfo = dataset.getOutputIDInfo();
        ImmutableFeatureMap featureMap = dataset.getFeatureIDMap();

        Map<Integer, Map<Integer, Double>> countsByLabel = accumulateWeightedCounts(dataset, labelInfo, featureMap);

        // The counter is only touched once the data has been validated, so a rejected dataset
        // leaves the trainer's state unchanged.
        if (invocationCount != KEEP_CURRENT_INVOCATION_COUNT) {
            setInvocationCount(invocationCount);
        }

        // Provenance is captured before the counter is incremented.
        ModelProvenance modelProvenance = new ModelProvenance(
                MultinomialNaiveBayesModel.class.getName(),
                OffsetDateTime.now(),
                dataset.getProvenance(),
                getProvenance(),
                provenanceMap);
        this.trainInvocationCount++;

        SparseVector[] labelVectors = smoothedLogWeights(countsByLabel, labelInfo.size(), featureMap.size());

        return new MultinomialNaiveBayesModel("", modelProvenance, featureMap, labelInfo,
                DenseSparseMatrix.createFromSparseVectors(labelVectors), this.alpha);
    }

    /**
     * Sums (example weight * feature value) per feature id, grouped by label id.
     *
     * @throws IllegalStateException if any feature value is negative.
     */
    private static Map<Integer, Map<Integer, Double>> accumulateWeightedCounts(
            Dataset<Label> dataset, ImmutableOutputInfo<Label> labelInfo, ImmutableFeatureMap featureMap) {
        Map<Integer, Map<Integer, Double>> countsByLabel = new HashMap<>();
        // Every label gets an entry up front, so labels with no examples still have an (empty) map.
        for (Pair<Integer, Label> idAndLabel : labelInfo) {
            countsByLabel.put(idAndLabel.getA(), new HashMap<>());
        }

        for (Example<Label> example : dataset) {
            Map<Integer, Double> featureCounts = countsByLabel.get(labelInfo.getID(example.getOutput()));
            double weight = example.getWeight();
            for (Feature feature : example) {
                if (feature.getValue() < 0.0d) {
                    throw new IllegalStateException("Multinomial Naive Bayes requires non-negative features. Found feature " + feature.toString());
                }
                featureCounts.merge(featureMap.getID(feature.getName()), weight * feature.getValue(), Double::sum);
            }
        }
        return countsByLabel;
    }

    /**
     * Builds one sparse vector per label id in {@code [0, numLabels)} and replaces each count
     * {@code c} with {@code log((c + alpha) / (Z + numFeatures * alpha))}, where {@code Z} is the
     * one-norm of that label's unsmoothed counts.
     */
    private SparseVector[] smoothedLogWeights(
            Map<Integer, Map<Integer, Double>> countsByLabel, int numLabels, int numFeatures) {
        final double smoothing = this.alpha;
        final double smoothingMass = numFeatures * smoothing;

        SparseVector[] labelVectors = new SparseVector[numLabels];
        for (int labelId = 0; labelId < numLabels; labelId++) {
            SparseVector counts = SparseVector.createSparseVector(numFeatures, countsByLabel.get(labelId));
            // The normaliser must be read before the in-place update overwrites the raw counts.
            final double normaliser = counts.oneNorm() + smoothingMass;
            counts.foreachInPlace(count -> Math.log((count + smoothing) / normaliser));
            labelVectors[labelId] = counts;
        }
        return labelVectors;
    }

    /**
     * Returns the current value of the invocation counter.
     *
     * @return the invocation count.
     */
    public int getInvocationCount() {
        return this.trainInvocationCount;
    }

    /**
     * Overwrites the invocation counter.
     *
     * @param invocationCount the new counter value; must be non-negative.
     * @throws IllegalArgumentException if {@code invocationCount} is negative.
     */
    public void setInvocationCount(int invocationCount) {
        if (invocationCount < 0) {
            throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
        }
        this.trainInvocationCount = invocationCount;
    }

    /**
     * Returns a short description of this trainer including its smoothing parameter.
     */
    @Override
    public String toString() {
        return "MultinomialNaiveBayesTrainer(alpha=" + this.alpha + ")";
    }

    /**
     * Returns a new provenance object describing this trainer.
     *
     * @return the trainer provenance.
     */
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }

    /**
     * Alias of {@link #getProvenance()} under the name the decompiler assigned when it merged the
     * original method with its bridge method. Retained so existing callers of this name keep working.
     *
     * @return the trainer provenance.
     */
    public TrainerProvenance m4getProvenance() {
        return getProvenance();
    }
}
