package hex.genmodel.algos.gam;

import hex.genmodel.ConverterFactoryProvidingModel;
import hex.genmodel.GenModel;
import hex.genmodel.MojoModel;
import hex.genmodel.easy.CategoricalEncoder;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.RowToRawDataConverter;
import hex.genmodel.utils.ArrayUtils;
import hex.genmodel.utils.DistributionFamily;
import hex.genmodel.utils.LinkFunctionType;
import java.util.Map;

/* loaded from: input_file:www/3/h2o-genmodel.jar:hex/genmodel/algos/gam/GamMojoModelBase.class */
/**
 * Common base for GAM MOJO models.
 *
 * <p>Holds the model parameters (package-private fields, presumably populated by the MOJO reader)
 * and provides scoring helpers for subclasses: optional NaN imputation, inverse-link evaluation,
 * linear-predictor (eta) computation and expansion of raw GAM predictor values into spline basis
 * columns. Concrete subclasses implement {@link #gamScore0(double[], double[])}.
 */
public abstract class GamMojoModelBase extends MojoModel implements ConverterFactoryProvidingModel {
    // Link function whose inverse is applied by evalLink().
    public LinkFunctionType _link_function;
    // When false, a categorical value is shifted down by one before lookup, so level 0 maps to no coefficient.
    boolean _useAllFactorLevels;
    // Number of categorical columns; they occupy the first _cats slots of a data row.
    int _cats;
    // Per categorical column: value substituted for NaN when mean imputation is enabled.
    int[] _catNAFills;
    // Per categorical column: start index of its coefficients; _catOffsets[i + 1] bounds column i,
    // and _catOffsets[_cats] is the index of the first numeric coefficient.
    int[] _catOffsets;
    // Numeric column count / NaN fill values used by imputation when the row length differs from nfeatures().
    int _nums;
    // Numeric column count / NaN fill values used by imputation when the row length equals nfeatures().
    int _numsCenter;
    double[] _numNAFills;
    double[] _numNAFillsCenter;
    // When true, score0() replaces NaNs with the stored fill values before scoring.
    boolean _meanImputation;
    // The coefficient arrays and several parameters below are not read in this class;
    // presumably they are consumed by subclasses.
    double[] _beta;
    double[] _beta_no_center;
    double[] _beta_center;
    double[][] _beta_multinomial;
    double[][] _beta_multinomial_no_center;
    double[][] _beta_multinomial_center;
    DistributionFamily _family;
    // Per GAM column: predictor name looked up in RowData by addExpandGamCols().
    String[] _gam_columns;
    int _num_gam_columns;
    // Per GAM column: spline type (only 0 is handled by addExpandGamCols()), knot count and knots.
    int[] _bs;
    int[] _num_knots;
    double[][] _knots;
    // Per GAM column: matrix passed to GamUtilsCubicRegression.expandOneGamCol().
    double[][][] _binvD;
    double[][][] _zTranspose;
    String[][] _gamColNames;
    String[][] _gamColNamesCenter;
    String[] _names_no_centering;
    // Row length after GAM expansion; expanded columns occupy the last _numExpandedGamCols slots.
    int _totFeatureSize;
    int _betaSizePerClass;
    int _betaCenterSizePerClass;
    // Power passed to GLM_tweedieInv() when the link is tweedie.
    double _tweedieLinkPower;
    // Per GAM column scratch buffers allocated by init(): basis values and knot differences.
    double[][] _basisVals;
    double[][] _hj;
    int _numExpandedGamCols;
    // Set by init() to _nclasses - 1.
    int _lastClass;

    /**
     * Creates the model. All arguments are forwarded unchanged to the {@link MojoModel} constructor
     * (presumably column names, categorical domains and the response column name).
     */
    public GamMojoModelBase(String[] columns, String[][] domains, String responseColumn) {
        super(columns, domains, responseColumn);
    }

    /**
     * Scores one row.
     *
     * <p>If mean imputation is enabled, NaNs in {@code data} are replaced <em>in place</em> (the
     * caller's array is modified) before delegating to {@link #gamScore0(double[], double[])}.
     *
     * @param data  input row; may be modified
     * @param preds output buffer passed through to gamScore0
     * @return whatever gamScore0 returns
     */
    @Override // hex.genmodel.GenModel
    public double[] score0(double[] data, double[] preds) {
        if (this._meanImputation) {
            imputeMissingWithMeans(data);
        }
        return gamScore0(data, preds);
    }

    /**
     * Allocates the per-GAM-column scratch buffers used by {@link #addExpandGamCols} and caches
     * {@code ArrayUtils.eleDiff} of each knot vector (presumably consecutive knot differences).
     * Also sets {@code _lastClass}. Must run before addExpandGamCols() is called.
     */
    public void init() {
        // Jagged arrays: each row is sized per GAM column in the loop below.
        this._basisVals = new double[this._gam_columns.length][];
        this._hj = new double[this._gam_columns.length][];
        for (int col = 0; col < this._num_gam_columns; col++) {
            this._basisVals[col] = new double[this._num_knots[col]];
            this._hj[col] = ArrayUtils.eleDiff(this._knots[col]);
        }
        this._lastClass = this._nclasses - 1;
    }

    /**
     * Model-specific scoring, implemented by subclasses.
     *
     * @param data  input row (already imputed if mean imputation is enabled)
     * @param preds output buffer
     * @return the prediction array
     */
    abstract double[] gamScore0(double[] data, double[] preds);

    /**
     * Replaces NaNs in {@code data} in place with the stored fill values.
     *
     * <p>Categorical columns (the first {@code _cats} slots) always use {@code _catNAFills}. The
     * numeric columns that follow use the "center" fill values when the row has exactly
     * {@code nfeatures()} entries, and the non-centered ones otherwise.
     */
    private void imputeMissingWithMeans(double[] data) {
        for (int c = 0; c < this._cats; c++) {
            if (Double.isNaN(data[c])) {
                data[c] = this._catNAFills[c];
            }
        }
        boolean centered = data.length == nfeatures();
        int numCount = centered ? this._numsCenter : this._nums;
        double[] fills = centered ? this._numNAFillsCenter : this._numNAFills;
        for (int n = 0; n < numCount; n++) {
            int pos = n + this._cats;
            if (Double.isNaN(data[pos])) {
                data[pos] = fills[n];
            }
        }
    }

    /**
     * Applies the inverse of the configured link function to a linear-predictor value.
     *
     * @param eta linear predictor
     * @return the inverse-link value
     * @throws UnsupportedOperationException for link functions other than identity, logit, log,
     *     inverse and tweedie
     */
    public double evalLink(double eta) {
        switch (this._link_function) {
            case identity:
                return GenModel.GLM_identityInv(eta);
            case logit:
                return GenModel.GLM_logitInv(eta);
            case log:
                return GenModel.GLM_logInv(eta);
            case inverse:
                return GenModel.GLM_inverseInv(eta);
            case tweedie:
                return GenModel.GLM_tweedieInv(eta, this._tweedieLinkPower);
            default:
                throw new UnsupportedOperationException("Unexpected link function " + this._link_function);
        }
    }

    /**
     * Maps a categorical cell to its coefficient index.
     *
     * <p>The value is truncated to {@code int} (a NaN becomes 0 under Java's narrowing rules). It is
     * shifted down by one unless {@code _useAllFactorLevels}, then offset by the column's start in
     * {@code _catOffsets}.
     *
     * @param value    raw categorical value
     * @param colIndex categorical column index
     * @return the coefficient index, or -1 if the shifted level is negative
     */
    int readCatVal(double value, int colIndex) {
        int level = (int) value;
        if (!this._useAllFactorLevels) {
            level -= 1;
        }
        return level < 0 ? -1 : level + this._catOffsets[colIndex];
    }

    /**
     * Computes the linear predictor for one row.
     *
     * <p>The result is the sum of three parts: the coefficient selected by each categorical column
     * (skipped when out of that column's range), the dot product of the numeric columns with their
     * coefficients, and the final element of {@code beta}, which is added unconditionally as the
     * intercept.
     *
     * @param beta coefficient vector
     * @param data input row (categoricals first, then numerics)
     * @return the linear predictor eta
     */
    public double generateEta(double[] beta, double[] data) {
        double eta = 0.0d;
        int catCols = this._catOffsets.length - 1;
        for (int c = 0; c < catCols; c++) {
            int idx = readCatVal(data[c], c);
            if (idx >= 0 && idx < this._catOffsets[c + 1]) {
                eta += beta[idx];
            }
        }
        // Numeric column i (i >= _cats) uses coefficient beta[numOffset + i].
        int numOffset = this._catOffsets[this._cats] - this._cats;
        int numEnd = (beta.length - 1) - numOffset;
        for (int i = this._cats; i < numEnd; i++) {
            eta += beta[numOffset + i] * data[i];
        }
        return eta + beta[beta.length - 1];
    }

    /** Returns true when every entry of {@code data} from index {@code start} onward is NaN. */
    private boolean gamificationNeeded(double[] data, int start) {
        for (int i = start; i < data.length; i++) {
            if (!Double.isNaN(data[i])) {
                return false;
            }
        }
        return true;
    }

    /**
     * Fills in the expanded GAM basis columns for a row, if they are not already present.
     *
     * <p>If any slot from {@code _totFeatureSize - _numExpandedGamCols} onward is non-NaN,
     * {@code data} is returned unchanged. Otherwise a new NaN-initialised array of length
     * {@code _totFeatureSize} is returned, with the leading columns copied from {@code data}. For
     * each GAM column whose value in {@code rowData} is non-null, the basis values are written into
     * that column's {@code _num_knots} slots. Null values leave those slots NaN.
     *
     * <p>NOTE(review): {@code _basisVals} is per-instance scratch space written on every call, so
     * concurrent calls on one model instance would share it — confirm scoring is single-threaded
     * per instance.
     *
     * @param data    raw row
     * @param rowData original row; GAM predictor values may be {@link String} or any {@link Number}
     * @return {@code data} itself, or a new expanded row
     * @throws IllegalArgumentException if a GAM column uses a spline type other than 0
     * @throws NumberFormatException    if a String value cannot be parsed as a double
     * @throws ClassCastException       if a value is neither a String nor a Number
     */
    public double[] addExpandGamCols(double[] data, RowData rowData) {
        int writePos = this._totFeatureSize - this._numExpandedGamCols;
        if (!gamificationNeeded(data, writePos)) {
            return data;
        }
        double[] expanded = ArrayUtils.nanArray(this._totFeatureSize);
        System.arraycopy(data, 0, expanded, 0, writePos);
        for (int col = 0; col < this._num_gam_columns; col++) {
            if (this._bs[col] != 0) {
                throw new IllegalArgumentException("spline type not implemented!");
            }
            int numKnots = this._num_knots[col];
            Object value = rowData.get(this._gam_columns[col]);
            if (value != null) {
                GamUtilsCubicRegression.expandOneGamCol(toDouble(value), this._binvD[col],
                        this._basisVals[col], this._hj[col], this._knots[col]);
                System.arraycopy(this._basisVals[col], 0, expanded, writePos, numKnots);
            }
            // A missing value still reserves its slots (left as NaN).
            writePos += numKnots;
        }
        return expanded;
    }

    /**
     * Converts a RowData cell to a double: Strings are parsed, and any Number (not only Double)
     * is widened. Any other type raises ClassCastException, as the former Double cast did.
     */
    private static double toDouble(Object value) {
        if (value instanceof String) {
            return Double.parseDouble((String) value);
        }
        return ((Number) value).doubleValue();
    }

    /** Returns a {@link GamRowToRawDataConverter} for this model, forwarding all arguments. */
    @Override // hex.genmodel.ConverterFactoryProvidingModel
    public RowToRawDataConverter makeConverterFactory(Map<String, Integer> modelColumnNameToIndexMap,
            Map<Integer, CategoricalEncoder> domainMap, EasyPredictModelWrapper.ErrorConsumer errorConsumer,
            EasyPredictModelWrapper.Config config) {
        return new GamRowToRawDataConverter(this, modelColumnNameToIndexMap, domainMap, errorConsumer, config);
    }
}
