package org.tribuo.regression.slm;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Collections;
import java.util.logging.Logger;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.regression.slm.SLMTrainer;

/* loaded from: input_file:org/tribuo/regression/slm/LARSLassoTrainer.class */
/**
 * A {@link SLMTrainer} whose weight update is a least-angle regression (LARS) step combined
 * with a lasso sign-change rule: any coefficient whose sign would flip during a step is set to
 * zero and its feature is removed from the active set.
 */
public class LARSLassoTrainer extends SLMTrainer {
    // Not referenced anywhere in this class; kept for compatibility.
    private static final Logger logger = Logger.getLogger(LARSLassoTrainer.class.getName());

    /**
     * Constructs a LARS-lasso trainer.
     * <p>
     * Passes {@code true} as the first argument of the {@link SLMTrainer} constructor
     * (presumably a normalization flag — confirm in SLMTrainer).
     *
     * @param maxNumFeatures The maximum number of features to select. The no-argument
     *                       constructor passes -1 (presumably meaning "no limit" — confirm
     *                       in SLMTrainer).
     */
    public LARSLassoTrainer(int maxNumFeatures) {
        super(true, maxNumFeatures);
    }

    /**
     * Constructs a LARS-lasso trainer with {@code maxNumFeatures = -1}.
     */
    public LARSLassoTrainer() {
        this(-1);
    }

    /**
     * Computes the coefficient vector after one LARS-lasso step.
     * <p>
     * If {@code state.last} is set, this delegates to {@link SLMTrainer#newWeights}. Otherwise it:
     * <ol>
     *   <li>solves least squares of {@code state.xpi} against {@code state.r} (presumably the
     *       current residual);</li>
     *   <li>chooses the step size {@code gamma} as the smallest finite, non-negative value of
     *       {@code (C - c_j) / (A - a_j)} and {@code (C + c_j) / (A + a_j)} over the inactive
     *       features {@code j}, where {@code A} is the sum of all entries of the second
     *       least-squares result;</li>
     *   <li>scales the (unpacked) least-squares coefficients by {@code gamma};</li>
     *   <li>zeroes every coefficient whose sign would change and drops its feature from
     *       {@code state.active} and {@code state.activeSet};</li>
     *   <li>returns {@code state.beta} plus the scaled step.</li>
     * </ol>
     * The decompiler reports that this method was originally {@code protected}. It is kept
     * {@code public} so that existing callers still compile.
     *
     * @param state The solver state. Its {@code beta}, {@code active} and {@code activeSet}
     *              may be modified when a coefficient is dropped.
     * @return The updated weights, or {@code null} if the least-squares solve fails or no
     *         finite, non-negative step size exists.
     */
    @Override
    public DenseVector newWeights(SLMTrainer.SLMState state) {
        if (state.last) {
            return super.newWeights(state);
        }

        Pair<DenseVector, DenseMatrix> leastSquares =
                SLMTrainer.ordinaryLeastSquares(state.xpi, state.r);
        if (leastSquares == null) {
            return null;
        }

        // Least-squares coefficients, expanded via state.unpack to the full feature width.
        DenseVector step = state.unpack(leastSquares.getA());
        // Presumably the inverse Gram matrix of the active features — confirm in SLMTrainer.
        DenseMatrix gramInverse = leastSquares.getB();
        // The sum of its entries plays the role of A in the step-size formula.
        double aa = gramInverse.rowSum().sum();
        DenseVector a = SLMTrainer.getA(state.X, state.xpi, SLMTrainer.getWA(gramInverse, aa));

        double gamma = smallestStepSize(state, a, aa);
        if (Double.isInfinite(gamma)) {
            // No usable step exists. Previously this crashed in Collections.min on an empty list,
            // or scaled the weights by infinity. Signal failure with the same null sentinel used
            // when the least-squares solve fails.
            return null;
        }

        step.scaleInPlace(gamma);
        dropSignFlippedCoefficients(state, step);
        return state.beta.add(step);
    }

    /**
     * Returns the smallest finite, non-negative LARS step-size candidate over the inactive
     * features, or {@link Double#POSITIVE_INFINITY} if there is none.
     */
    private static double smallestStepSize(SLMTrainer.SLMState state, DenseVector a, double aa) {
        double maxCorrelation = state.C;
        double best = Double.POSITIVE_INFINITY;
        for (int i = 0; i < state.numFeatures; i++) {
            if (state.activeSet.contains(i)) {
                continue;
            }
            double c = state.corr.get(i);
            double ai = a.get(i);
            best = foldCandidate(best, (maxCorrelation - c) / (aa - ai));
            best = foldCandidate(best, (maxCorrelation + c) / (aa + ai));
        }
        return best;
    }

    /**
     * Folds a step-size candidate into the running minimum, ignoring negative, NaN and
     * infinite values (the latter arise from a zero denominator).
     */
    private static double foldCandidate(double currentMin, double candidate) {
        if (Double.isFinite(candidate) && candidate >= 0.0) {
            return Math.min(currentMin, candidate);
        }
        return currentMin;
    }

    /**
     * Lasso rule: zeroes every coefficient whose sign would flip after adding {@code step},
     * zeroes the matching step entry, and removes the feature from the active set.
     * <p>
     * NOTE(review): this zeroes the coefficient at the full step, rather than shortening the
     * step to the zero crossing as in Efron et al.'s lasso modification. This is the existing
     * behaviour and is preserved here.
     */
    private static void dropSignFlippedCoefficients(SLMTrainer.SLMState state, DenseVector step) {
        for (int i = 0; i < state.numFeatures; i++) {
            double current = state.beta.get(i);
            double updated = current + step.get(i);
            boolean signFlips = (current > 0.0 && updated < 0.0) || (current < 0.0 && updated > 0.0);
            if (signFlips) {
                state.beta.set(i, 0.0);
                step.set(i, 0.0);
                // Box explicitly so that, if 'active' is a List, remove(Object) is used
                // rather than remove(int index).
                Integer feature = Integer.valueOf(i);
                state.active.remove(feature);
                state.activeSet.remove(feature);
            }
        }
    }

    @Override
    public String toString() {
        return "LARSLassoTrainer(maxNumFeatures=" + this.maxNumFeatures + ")";
    }
}
