package org.tribuo.regression.evaluation;

import com.oracle.labs.mlrg.olcut.util.MutableDouble;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.regression.RegressionFactory;
import org.tribuo.regression.Regressor;

/* loaded from: input_file:org/tribuo/regression/evaluation/RegressionSufficientStatistics.class */
/**
 * Accumulates, per output dimension, the quantities that regression metrics are
 * derived from: weighted sums of absolute and squared error, plus the raw
 * ground-truth and predicted values for every prediction.
 * <p>
 * All state is computed once in the constructor. Per-dimension maps are keyed by
 * dimension name and iterate in the order the dimensions are returned by
 * {@code domain.getDomain()}.
 */
public final class RegressionSufficientStatistics {
    /** Number of predictions supplied to the constructor. */
    final int n;
    /** The output domain supplied to the constructor. */
    final ImmutableOutputInfo<Regressor> domain;
    /** Per-dimension sum of {@code weight * |truth - prediction|}. */
    final Map<String, MutableDouble> sumAbsoluteError = new LinkedHashMap<>();
    /** Per-dimension sum of {@code weight * (truth - prediction)^2}. */
    final Map<String, MutableDouble> sumSquaredError = new LinkedHashMap<>();
    /** Per-dimension predicted values, indexed by prediction number. */
    final Map<String, double[]> predictedValues = new LinkedHashMap<>();
    /** Per-dimension ground-truth values, indexed by prediction number. */
    final Map<String, double[]> trueValues = new LinkedHashMap<>();
    /** Weight applied to each prediction; all {@code 1.0f} when example weights are not used. */
    final float[] weights;
    /** Sum of all entries in {@link #weights}. */
    final float weightSum;

    /**
     * Builds the sufficient statistics for the supplied predictions.
     *
     * @param domain            the output domain; one entry is created in each per-dimension map
     *                          for every element of {@code domain.getDomain()}, keyed by that
     *                          element's first name
     * @param predictions       the predictions to tabulate; each must carry its ground truth in
     *                          {@code getExample().getOutput()}
     * @param useExampleWeights if {@code true} each prediction is weighted by its example's
     *                          weight, otherwise every prediction has weight {@code 1.0f}
     * @throws IllegalArgumentException if a ground truth or a prediction is the unknown-regressor
     *                                  sentinel, if a ground-truth dimension is not part of the
     *                                  domain, or if a prediction has fewer dimensions than its
     *                                  ground truth
     */
    public RegressionSufficientStatistics(ImmutableOutputInfo<Regressor> domain, List<Prediction<Regressor>> predictions, boolean useExampleWeights) {
        this.domain = domain;
        this.n = predictions.size();
        this.weights = initWeights(predictions, useExampleWeights);
        for (Regressor dimension : domain.getDomain()) {
            String name = dimension.getNames()[0];
            this.sumAbsoluteError.put(name, new MutableDouble());
            this.sumSquaredError.put(name, new MutableDouble());
            this.predictedValues.put(name, new double[this.n]);
            this.trueValues.put(name, new double[this.n]);
        }
        // Must run after the maps are populated: tabulate writes into them.
        this.weightSum = tabulate(predictions);
    }

    /**
     * Walks the predictions once, filling the per-dimension accumulators and value arrays.
     * Predicted values are paired with ground-truth values by position within the regressor.
     *
     * @param predictions the predictions to tabulate
     * @return the sum of the weights of all predictions
     * @throws IllegalArgumentException if a ground truth or a prediction is the unknown-regressor
     *                                  sentinel, if a ground-truth dimension is not part of the
     *                                  domain, or if a prediction has fewer dimensions than its
     *                                  ground truth
     */
    private float tabulate(List<Prediction<Regressor>> predictions) {
        float totalWeight = 0.0f;
        for (int i = 0; i < this.n; i++) {
            Prediction<Regressor> prediction = predictions.get(i);
            float weight = this.weights[i];
            totalWeight += weight;

            Regressor predicted = prediction.getOutput();
            Regressor truth = prediction.getExample().getOutput();
            if (truth.equals(RegressionFactory.UNKNOWN_REGRESSOR)) {
                throw new IllegalArgumentException("The sentinel Unknown Regressor was used as a ground truth output at prediction number " + i);
            }
            if (predicted.equals(RegressionFactory.UNKNOWN_REGRESSOR)) {
                throw new IllegalArgumentException("The sentinel Unknown Regressor was predicted by the model at prediction number " + i);
            }

            // Fetch the backing arrays once per prediction rather than once per dimension.
            int numDimensions = truth.size();
            String[] names = truth.getNames();
            double[] truthValues = truth.getValues();
            double[] predictedVals = predicted.getValues();
            if (predictedVals.length < numDimensions) {
                // Fail with a clear message instead of an ArrayIndexOutOfBoundsException below.
                throw new IllegalArgumentException("The prediction has " + predictedVals.length
                        + " dimensions but the ground truth has " + numDimensions + " at prediction number " + i);
            }

            for (int j = 0; j < numDimensions; j++) {
                String name = names[j];
                MutableDouble absoluteSum = this.sumAbsoluteError.get(name);
                if (absoluteSum == null) {
                    // The accumulator maps share one key set, so a single lookup is a sufficient check.
                    throw new IllegalArgumentException("The ground truth dimension '" + name
                            + "' is not present in the output domain at prediction number " + i);
                }
                double trueValue = truthValues[j];
                double predictedValue = predictedVals[j];
                double error = trueValue - predictedValue;
                absoluteSum.increment(weight * Math.abs(error));
                this.sumSquaredError.get(name).increment(weight * error * error);
                this.trueValues.get(name)[i] = trueValue;
                this.predictedValues.get(name)[i] = predictedValue;
            }
        }
        return totalWeight;
    }

    /**
     * Creates the per-prediction weight array.
     *
     * @param predictions       the predictions being evaluated
     * @param useExampleWeights if {@code true} copy each example's weight, otherwise use
     *                          {@code 1.0f} for every prediction
     * @return an array with one weight per prediction, in list order
     */
    private static float[] initWeights(List<Prediction<Regressor>> predictions, boolean useExampleWeights) {
        int count = predictions.size();
        float[] result = new float[count];
        if (useExampleWeights) {
            for (int i = 0; i < count; i++) {
                result[i] = predictions.get(i).getExample().getWeight();
            }
        } else {
            Arrays.fill(result, 1.0f);
        }
        return result;
    }
}
