package org.tribuo.regression.sgd.objectives;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.function.DoubleUnaryOperator;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.regression.sgd.RegressionObjective;

/* loaded from: input_file:org/tribuo/regression/sgd/objectives/Huber.class */
/**
 * Huber loss for regression.
 *
 * <p>For each element's absolute residual {@code a}, the loss is {@code 0.5 * a * a} when
 * {@code a <= cost} and {@code cost * a - 0.5 * cost * cost} otherwise. The two pieces agree in
 * value and slope at {@code a == cost}. The gradient returned by
 * {@link #lossAndGradient(DenseVector, SGDVector)} is the residual (first argument minus second
 * argument) clipped elementwise to {@code [-cost, cost]}.
 */
public class Huber implements RegressionObjective {
    /** Cost used by the no-argument constructor. */
    public static final double DEFAULT_COST = 5.0d;

    @Config(description = "Cost beyond which the loss function is linear.")
    private double cost = DEFAULT_COST;

    /** Per-element loss applied to an absolute residual; (re)built by {@link #postConfig()}. */
    private DoubleUnaryOperator lossFunc;

    /**
     * Constructs a Huber objective using {@link #DEFAULT_COST}.
     */
    public Huber() {
        this(DEFAULT_COST);
    }

    /**
     * Constructs a Huber objective with the supplied cost.
     *
     * @param d The absolute-residual threshold beyond which the loss becomes linear.
     * @throws PropertyException If {@code d} is not a positive number (including NaN).
     */
    public Huber(double d) {
        this.cost = d;
        postConfig();
    }

    /**
     * Validates {@code cost} and builds the per-element loss function.
     *
     * <p>Called by both constructors. NOTE(review): presumably also invoked by the OLCUT
     * configuration system after {@code @Config} fields are injected; confirm.
     *
     * @throws PropertyException If {@code cost} is not a positive number (including NaN).
     */
    public void postConfig() {
        // Written as !(cost > 0) rather than (cost <= 0) so that NaN is rejected as well.
        if (!(this.cost > 0.0d)) {
            throw new PropertyException("", "cost", "Cost must be a positive value, found " + this.cost);
        }
        this.lossFunc = this::huberLoss;
    }

    /**
     * Huber loss of a single non-negative (absolute) residual.
     *
     * @param absResidual The absolute residual.
     * @return Quadratic loss up to {@code cost}, linear loss beyond it.
     */
    private double huberLoss(double absResidual) {
        if (absResidual > this.cost) {
            return (this.cost * absResidual) - (0.5d * this.cost * this.cost);
        }
        return 0.5d * absResidual * absResidual;
    }

    /**
     * Clips a signed residual to {@code [-cost, cost]}. NaN passes through unchanged.
     *
     * @param residual The signed residual.
     * @return The clipped residual.
     */
    private double clipResidual(double residual) {
        // copySign states the intent exactly. Double.compare only promises the sign of its
        // result, not a magnitude of 1.
        return Math.abs(residual) > this.cost ? Math.copySign(this.cost, residual) : residual;
    }

    /**
     * Computes the loss and gradient.
     *
     * @deprecated Use {@link #lossAndGradient(DenseVector, SGDVector)}; this method delegates to it.
     * @param denseVector The first operand of the residual (presumably the ground truth; confirm).
     * @param sGDVector The second operand of the residual (presumably the prediction; confirm).
     * @return The summed loss and the clipped residual.
     */
    @Override // org.tribuo.regression.sgd.RegressionObjective
    @Deprecated
    public Pair<Double, SGDVector> loss(DenseVector denseVector, SGDVector sGDVector) {
        return lossAndGradient(denseVector, sGDVector);
    }

    /**
     * Computes the summed Huber loss and the clipped residual.
     *
     * @param denseVector The first operand of the residual (presumably the ground truth; confirm).
     * @param sGDVector The second operand of the residual (presumably the prediction; confirm).
     * @return A pair of the total loss and {@code denseVector - sGDVector} clipped elementwise
     *     to {@code [-cost, cost]}.
     */
    @Override // org.tribuo.regression.sgd.RegressionObjective
    public Pair<Double, SGDVector> lossAndGradient(DenseVector denseVector, SGDVector sGDVector) {
        DenseVector residual = denseVector.subtract(sGDVector);
        // Take |r| inside the reduction instead of materialising an absolute-value copy.
        DoubleUnaryOperator absLoss = this.lossFunc.compose(Math::abs);
        // The loss must be summed before the residual is clipped in place below.
        double totalLoss = residual.reduce(0.0d, absLoss, Double::sum);
        residual.foreachInPlace(this::clipResidual);
        return new Pair<>(totalLoss, residual);
    }

    @Override
    public String toString() {
        return "Huber(cost=" + this.cost + ")";
    }

    /**
     * Builds a {@link ConfiguredObjectProvenanceImpl} for this object, tagged with the type name
     * {@code "RegressionObjective"}.
     *
     * @return The provenance of this objective.
     */
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "RegressionObjective");
    }

    /**
     * Decompiler-renamed alias of {@link #getProvenance()}, kept for backward compatibility.
     *
     * @deprecated Use {@link #getProvenance()}.
     * @return The provenance of this objective.
     */
    @Deprecated
    public ConfiguredObjectProvenance m13getProvenance() {
        return getProvenance();
    }
}
