package org.tribuo.math.optimisers;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import java.util.function.DoubleUnaryOperator;
import java.util.logging.Logger;
import org.tribuo.math.Parameters;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.Tensor;

/**
 * An implementation of the AdaGrad gradient optimiser.
 * <p>
 * For every parameter tensor it keeps an accumulator of squared (weighted) gradients. Each
 * incoming gradient is scaled by {@code weight * initialLearningRate / (epsilon + sqrt(accumulator))},
 * so parameters that have seen large gradients take progressively smaller steps.
 * <p>
 * Allocates one extra copy of the parameters (via {@link Parameters#getEmptyCopy()}) to hold
 * the accumulators.
 * <p>
 * See:
 * <pre>
 * Duchi, J., Hazan, E., and Singer, Y.
 * "Adaptive Subgradient Methods for Online Learning and Stochastic Optimization"
 * Journal of Machine Learning Research 12 (2011), 2121-2159.
 * </pre>
 */
public class AdaGrad implements StochasticGradientOptimiser {
    private static final Logger logger = Logger.getLogger(AdaGrad.class.getName());

    /** Default value of {@link #epsilon} when not supplied. */
    private static final double DEFAULT_EPSILON = 1.0E-6d;

    /** Default value of {@link #initialValue} when not supplied. */
    private static final double DEFAULT_INITIAL_VALUE = 0.0d;

    @Config(mandatory = true, description = "Initial learning rate used to scale the gradients.")
    private double initialLearningRate;

    @Config(description = "Epsilon for numerical stability around zero.")
    private double epsilon = DEFAULT_EPSILON;

    @Config(description = "Initial value for the gradient accumulator.")
    private double initialValue = DEFAULT_INITIAL_VALUE;

    /**
     * Per-parameter squared-gradient accumulators, index-aligned with the tensors passed to
     * {@link #step}. {@code null} until {@link #initialise} is called and again after {@link #reset}.
     */
    private Tensor[] gradsSquared;

    /**
     * Creates an AdaGrad optimiser.
     *
     * @param initialLearningRate The learning rate used to scale the gradients.
     * @param epsilon Added to the denominator for numerical stability around zero.
     * @param initialValue The value every squared-gradient accumulator starts at.
     */
    public AdaGrad(double initialLearningRate, double epsilon, double initialValue) {
        this.initialLearningRate = initialLearningRate;
        this.epsilon = epsilon;
        this.initialValue = initialValue;
    }

    /**
     * Creates an AdaGrad optimiser whose accumulators start at zero.
     *
     * @param initialLearningRate The learning rate used to scale the gradients.
     * @param epsilon Added to the denominator for numerical stability around zero.
     */
    public AdaGrad(double initialLearningRate, double epsilon) {
        this(initialLearningRate, epsilon, DEFAULT_INITIAL_VALUE);
    }

    /**
     * Creates an AdaGrad optimiser with epsilon = 1e-6 and accumulators starting at zero.
     *
     * @param initialLearningRate The learning rate used to scale the gradients.
     */
    public AdaGrad(double initialLearningRate) {
        this(initialLearningRate, DEFAULT_EPSILON, DEFAULT_INITIAL_VALUE);
    }

    /**
     * No-arg constructor used by the configuration system.
     * <p>
     * Defaults for {@code epsilon} and {@code initialValue} come from the field initialisers;
     * {@code initialLearningRate} is expected to be supplied by configuration (it is mandatory).
     */
    private AdaGrad() { }

    /**
     * Allocates the squared-gradient accumulators for the supplied parameters.
     * <p>
     * If {@code initialValue} is non-zero, it is added to every element of every accumulator.
     * NOTE(review): this assumes {@code getEmptyCopy()} returns zero-filled tensors shaped like
     * the parameters. Confirm this against {@link Parameters}.
     *
     * @param parameters The parameters that will be optimised.
     */
    @Override
    public void initialise(Parameters parameters) {
        gradsSquared = parameters.getEmptyCopy();
        if (initialValue != 0.0d) {
            for (Tensor accumulator : gradsSquared) {
                accumulator.scalarAddInPlace(initialValue);
            }
        }
    }

    /**
     * Applies one AdaGrad step to the supplied gradients.
     * <p>
     * For each tensor {@code i}, the following happens:
     * <ol>
     *   <li>{@code (weight * g)^2} is folded into accumulator {@code i}.</li>
     *   <li>The gradient is multiplied element-wise by
     *       {@code weight * initialLearningRate / (epsilon + sqrt(accumulator))}.</li>
     * </ol>
     * Both the accumulators and the supplied tensors are mutated in place.
     *
     * @param updates The gradients; modified in place and returned.
     * @param weight The example weight, which scales both the accumulated and the applied gradient.
     * @return {@code updates}, now holding the scaled gradient step.
     * @throws IllegalStateException If {@link #initialise} has not been called (or {@link #reset}
     *     was called afterwards).
     */
    @Override
    public Tensor[] step(Tensor[] updates, double weight) {
        if (gradsSquared == null) {
            throw new IllegalStateException("AdaGrad.step called before initialise (or after reset).");
        }
        // Build the operators once per call rather than once per tensor.
        DoubleUnaryOperator square = grad -> weight * weight * grad * grad;
        DoubleUnaryOperator scale = accumulated -> (weight * initialLearningRate) / (epsilon + Math.sqrt(accumulated));
        for (int i = 0; i < updates.length; i++) {
            Tensor curGradsSquared = gradsSquared[i];
            Tensor curGrad = updates[i];
            curGradsSquared.intersectAndAddInPlace(curGrad, square);
            curGrad.hadamardProductInPlace(curGradsSquared, scale);
        }
        return updates;
    }

    @Override
    public String toString() {
        return "AdaGrad(initialLearningRate=" + this.initialLearningRate + ",epsilon=" + this.epsilon + ",initialValue=" + this.initialValue + ")";
    }

    /**
     * Discards the accumulated state. {@link #initialise} must be called again before the next
     * {@link #step}.
     */
    @Override
    public void reset() {
        gradsSquared = null;
    }

    /**
     * Copies the optimiser's configuration (but not its accumulated state).
     * <p>
     * All three hyperparameters are carried over, including {@code initialValue}; the previous
     * implementation dropped it, so copies silently reverted to a zero initial accumulator.
     *
     * @return An uninitialised AdaGrad with the same hyperparameters.
     */
    @Override
    public AdaGrad copy() {
        return new AdaGrad(initialLearningRate, epsilon, initialValue);
    }

    /**
     * Returns the configuration-based provenance of this optimiser.
     *
     * @return The provenance.
     */
    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "StochasticGradientOptimiser");
    }

    /**
     * Alias for {@link #getProvenance()}, retained for compatibility with the decompiler-generated name.
     *
     * @return The provenance.
     * @deprecated Use {@link #getProvenance()}.
     */
    @Deprecated
    public ConfiguredObjectProvenance m29getProvenance() {
        return getProvenance();
    }
}
