package org.tribuo.regression.baseline;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import java.util.stream.Collectors;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.baseline.DummyRegressionTrainer;
import org.tribuo.regression.protos.DummyRegressionModelProto;
import org.tribuo.util.Util;

/* loaded from: input_file:org/tribuo/regression/baseline/DummyRegressionModel.class */
/**
 * A baseline regression model that ignores the input features.
 * <p>
 * For the {@code CONSTANT}, {@code MEAN}, {@code MEDIAN} and {@code QUARTILE} dummy types every
 * example receives the same precomputed {@link Regressor}. For {@code GAUSSIAN} a fresh value is
 * drawn per output dimension from a seeded {@link Random}, using stored per-dimension means and
 * scale values.
 */
public class DummyRegressionModel extends Model<Regressor> {
    private static final long serialVersionUID = 2L;

    /** The protobuf serialization version written by serialize and accepted by {@link #deserializeFromProto}. */
    public static final int CURRENT_VERSION = 0;

    /** Seed recorded by the fixed-output constructor; those models never draw random numbers. */
    private static final long PLACEHOLDER_SEED = 12345L;

    private final DummyRegressionTrainer.DummyType dummyType;
    /** The fixed prediction for non-Gaussian types; {@code null} for {@code GAUSSIAN} models. */
    private final Regressor output;
    /** Seed used to build {@link #rng}; also written out on serialization. */
    private final long seed;
    /** Source of Gaussian samples; {@code null} when built by the fixed-output constructor. */
    private final Random rng;
    /** Per-dimension means, parallel to {@link #dimensionNames}; empty for fixed-output types. */
    private final double[] means;
    /** Per-dimension scale values, parallel to {@link #dimensionNames}; empty for fixed-output types. */
    private final double[] variances;
    /** Output dimension names emitted by Gaussian predictions. */
    private final String[] dimensionNames;

    /**
     * Creates a {@code GAUSSIAN} dummy model. All arrays are defensively copied.
     * (Package-private in the original source; widened by the decompiler.)
     *
     * @param modelProvenance The provenance of this model.
     * @param immutableFeatureMap The feature domain.
     * @param immutableOutputInfo The output domain.
     * @param seed The RNG seed used for sampling.
     * @param means Per-dimension means.
     * @param variances Per-dimension scale values.
     * @param dimensionNames Output dimension names.
     * @throws IllegalArgumentException If the three arrays do not have the same length.
     */
    public DummyRegressionModel(ModelProvenance modelProvenance, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo<Regressor> immutableOutputInfo, long seed, double[] means, double[] variances, String[] dimensionNames) {
        super("dummy-GAUSSIAN-regression", modelProvenance, immutableFeatureMap, immutableOutputInfo, false);
        // predict indexes all three arrays with the same index, so reject mismatches up front
        // rather than failing with an ArrayIndexOutOfBoundsException at prediction time.
        if (means.length != dimensionNames.length || variances.length != dimensionNames.length) {
            throw new IllegalArgumentException("Expected one mean and one variance per dimension, found "
                    + dimensionNames.length + " dimension names, " + means.length + " means and "
                    + variances.length + " variances");
        }
        this.dummyType = DummyRegressionTrainer.DummyType.GAUSSIAN;
        this.output = null;
        this.seed = seed;
        this.rng = new Random(seed);
        this.means = Arrays.copyOf(means, means.length);
        this.variances = Arrays.copyOf(variances, variances.length);
        this.dimensionNames = Arrays.copyOf(dimensionNames, dimensionNames.length);
    }

    /**
     * Creates a fixed-output dummy model which always predicts {@code regressor}.
     * (Package-private in the original source; widened by the decompiler.)
     *
     * @param modelProvenance The provenance of this model.
     * @param immutableFeatureMap The feature domain.
     * @param immutableOutputInfo The output domain.
     * @param dummyType The dummy type (expected to be one of the fixed-output types).
     * @param regressor The output returned for every example.
     */
    public DummyRegressionModel(ModelProvenance modelProvenance, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo<Regressor> immutableOutputInfo, DummyRegressionTrainer.DummyType dummyType, Regressor regressor) {
        super("dummy-" + dummyType + "-regression", modelProvenance, immutableFeatureMap, immutableOutputInfo, false);
        this.dummyType = dummyType;
        this.output = regressor;
        this.seed = PLACEHOLDER_SEED;
        this.rng = null;
        this.means = new double[0];
        this.variances = new double[0];
        this.dimensionNames = new String[0];
    }

    /**
     * Full constructor used by deserialization and {@link #copy}. Arrays are stored without copying,
     * so callers must pass arrays they do not share.
     */
    private DummyRegressionModel(String name, ModelProvenance modelProvenance, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo<Regressor> immutableOutputInfo, DummyRegressionTrainer.DummyType dummyType, Regressor regressor, long seed, double[] means, double[] variances, String[] dimensionNames) {
        super(name, modelProvenance, immutableFeatureMap, immutableOutputInfo, false);
        this.dummyType = dummyType;
        this.output = regressor;
        this.seed = seed;
        this.rng = new Random(seed);
        this.means = means;
        this.variances = variances;
        this.dimensionNames = dimensionNames;
    }

    /**
     * Deserialization factory.
     *
     * @param i The serialized object version.
     * @param str The class name (unused).
     * @param any The serialized {@link DummyRegressionModelProto}.
     * @return The deserialized model.
     * @throws InvalidProtocolBufferException If the payload is not a {@link DummyRegressionModelProto}.
     * @throws IllegalArgumentException If the version is unsupported.
     * @throws IllegalStateException If the payload is not a valid dummy regression model.
     */
    public static DummyRegressionModel deserializeFromProto(int i, String str, Any any) throws InvalidProtocolBufferException {
        if (i < 0 || i > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + i + ", this class supports at most version " + CURRENT_VERSION);
        }
        DummyRegressionModelProto proto = any.unpack(DummyRegressionModelProto.class);
        ModelDataCarrier carrier = ModelDataCarrier.deserialize(proto.getMetadata());

        // Check the runtime output type on the raw domain before the unchecked cast below.
        // An empty domain returns null here, which is rejected rather than dereferenced.
        ImmutableOutputInfo rawDomain = carrier.outputDomain();
        Object firstOutput = rawDomain.getOutput(0);
        if (firstOutput == null || !firstOutput.getClass().equals(Regressor.class)) {
            throw new IllegalStateException("Invalid protobuf, output domain is not a regression domain, found " + rawDomain.getClass());
        }
        @SuppressWarnings("unchecked") // element type verified above
        ImmutableOutputInfo<Regressor> outputDomain = (ImmutableOutputInfo<Regressor>) rawDomain;

        DummyRegressionTrainer.DummyType dummyType = DummyRegressionTrainer.DummyType.valueOf(proto.getDummyType());
        Regressor regressor = null;
        if (dummyType != DummyRegressionTrainer.DummyType.GAUSSIAN) {
            Output deserializedOutput = Output.deserialize(proto.getOutput());
            if (!(deserializedOutput instanceof Regressor)) {
                throw new IllegalStateException("Invalid protobuf, expected a Regressor, found " + deserializedOutput.getClass());
            }
            regressor = (Regressor) deserializedOutput;
        }

        double[] means = Util.toPrimitiveDouble(proto.getMeansList());
        double[] variances = Util.toPrimitiveDouble(proto.getVariancesList());
        String[] dimensionNames = (String[]) proto.mo95getDimensionNamesList().toArray(new String[0]);
        // Gaussian prediction indexes all three arrays in lockstep; reject inconsistent payloads now
        // instead of failing with an ArrayIndexOutOfBoundsException on the first predict call.
        if (dummyType == DummyRegressionTrainer.DummyType.GAUSSIAN
                && (means.length != dimensionNames.length || variances.length != dimensionNames.length)) {
            throw new IllegalStateException("Invalid protobuf, found " + dimensionNames.length
                    + " dimension names, " + means.length + " means and " + variances.length + " variances");
        }

        return new DummyRegressionModel(carrier.name(), carrier.provenance(), carrier.featureDomain(), outputDomain, dummyType, regressor, proto.getSeed(), means, variances, dimensionNames);
    }

    /**
     * Produces a prediction which does not depend on the example's features.
     *
     * @param example The example (only attached to the returned prediction).
     * @return The fixed output, or a freshly sampled Gaussian output.
     */
    public Prediction<Regressor> predict(Example<Regressor> example) {
        switch (dummyType) {
            case CONSTANT:
            case MEAN:
            case MEDIAN:
            case QUARTILE:
                return new Prediction<>(output, 0, example);
            case GAUSSIAN:
                return new Prediction<>(sampleGaussian(), 0, example);
            default:
                throw new IllegalStateException("Unknown dummyType " + dummyType);
        }
    }

    /** Draws one value per output dimension from {@link #rng}, in dimension order. */
    private Regressor sampleGaussian() {
        Regressor.DimensionTuple[] samples = new Regressor.DimensionTuple[dimensionNames.length];
        for (int d = 0; d < dimensionNames.length; d++) {
            // NOTE(review): the standard-normal draw is scaled by variances[d]. If these values are
            // true variances, the scale should arguably be Math.sqrt(variances[d]). Behaviour is kept
            // as-is pending confirmation of what DummyRegressionTrainer stores.
            double value = (rng.nextGaussian() * variances[d]) + means[d];
            samples[d] = new Regressor.DimensionTuple(dimensionNames[d], value);
        }
        return new Regressor(samples);
    }

    /**
     * Reports a single constant "BIAS" feature with weight 1.0 for all outputs.
     *
     * @param i The requested number of features; 0 yields an empty map.
     * @return The bias-only feature map, or an empty map when {@code i == 0}.
     */
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int i) {
        if (i == 0) {
            return Collections.emptyMap();
        }
        List<Pair<String, Double>> bias = Collections.singletonList(new Pair<>("BIAS", 1.0d));
        return Collections.singletonMap("ALL_OUTPUTS", bias);
    }

    /**
     * Builds an excuse from this model's prediction and its single bias feature.
     *
     * @param example The example to explain.
     * @return An excuse, always present.
     */
    public Optional<Excuse<Regressor>> getExcuse(Example<Regressor> example) {
        return Optional.of(new Excuse<>(example, predict(example), getTopFeatures(1)));
    }

    /**
     * Serializes this model into a {@link ModelProto} at {@link #CURRENT_VERSION}.
     * (Renamed from {@code serialize} by the decompiler when merging the bridge method.)
     *
     * @return The serialized model.
     */
    public ModelProto m13serialize() {
        ModelDataCarrier carrier = createDataCarrier();
        DummyRegressionModelProto.Builder modelBuilder = DummyRegressionModelProto.newBuilder();
        modelBuilder.setMetadata(carrier.serialize());
        modelBuilder.setDummyType(dummyType.name());
        // Only fixed-output models carry an output; Gaussian models leave the field unset.
        if (output != null) {
            modelBuilder.setOutput(output.mo12serialize());
        }
        List<Double> boxedMeans = Arrays.stream(means).boxed().collect(Collectors.toList());
        List<Double> boxedVariances = Arrays.stream(variances).boxed().collect(Collectors.toList());
        modelBuilder.addAllMeans(boxedMeans);
        modelBuilder.addAllVariances(boxedVariances);
        modelBuilder.addAllDimensionNames(Arrays.asList(dimensionNames));
        modelBuilder.setSeed(seed);

        ModelProto.Builder wrapper = ModelProto.newBuilder();
        wrapper.setSerializedData(Any.pack(modelBuilder.m128build()));
        wrapper.setClassName(DummyRegressionModel.class.getName());
        wrapper.setVersion(CURRENT_VERSION);
        return wrapper.build();
    }

    /**
     * Copies this model, replacing its name and provenance with the supplied values.
     * <p>
     * The previous implementation discarded {@code str} and always used the default
     * {@code "dummy-<type>-regression"} name. The copy now goes through the full constructor
     * with fresh copies of the arrays and the output.
     *
     * @param str The new model name.
     * @param modelProvenance The new provenance.
     * @return A copy of this model.
     */
    protected Model<Regressor> copy(String str, ModelProvenance modelProvenance) {
        Regressor outputCopy = (output == null) ? null : output.mo11copy();
        return new DummyRegressionModel(str, modelProvenance, featureIDMap, outputIDInfo, dummyType, outputCopy, seed,
                Arrays.copyOf(means, means.length),
                Arrays.copyOf(variances, variances.length),
                Arrays.copyOf(dimensionNames, dimensionNames.length));
    }
}
