package org.tribuo.regression.slm;

import ai.onnx.proto.OnnxMl;
import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.ONNXExportable;
import org.tribuo.Prediction;
import org.tribuo.VariableIDInfo;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.la.VectorIterator;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.regression.ImmutableRegressionInfo;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.impl.SkeletalIndependentRegressionSparseModel;
import org.tribuo.regression.slm.protos.SparseLinearModelProto;
import org.tribuo.util.Util;
import org.tribuo.util.onnx.ONNXContext;
import org.tribuo.util.onnx.ONNXNode;
import org.tribuo.util.onnx.ONNXOperators;
import org.tribuo.util.onnx.ONNXPlaceholder;
import org.tribuo.util.onnx.ONNXRef;

/* loaded from: input_file:org/tribuo/regression/slm/SparseLinearModel.class */
public class SparseLinearModel extends SkeletalIndependentRegressionSparseModel implements ONNXExportable {
    private static final long serialVersionUID = 3;
    private static final Logger logger = Logger.getLogger(SparseLinearModel.class.getName());

    /**
     * Protobuf serialization version written by {@link #m13serialize()} and the highest version accepted by
     * {@link #deserializeFromProto(int, String, Any)}.
     */
    public static final int CURRENT_VERSION = 0;

    /** Name reported for the bias/intercept feature, which has no entry in the feature map. */
    private static final String BIAS_FEATURE_NAME = "BIAS";

    /** Trainer whose older serialized models get their outputs remapped in {@link #readObject}. */
    private static final String ENET_TRAINER_CLASS_NAME = "org.tribuo.regression.slm.ElasticNetCDTrainer";

    /**
     * One weight vector per output dimension, indexed like {@code dimensions}. When present, the bias
     * coefficient lives at index {@code featureIDMap.size()}. Not final because {@link #readObject} may reorder it.
     */
    private SparseVector[] weights;
    /** Per-feature values subtracted from active features in {@link #createFeatures}. */
    private final DenseVector featureMeans;
    /**
     * Per-feature values that active features are divided by in {@link #createFeatures}.
     * Despite the name it is serialized as the "feature norms" field of the protobuf.
     */
    private final DenseVector featureVariance;
    /** True if the feature vectors (and weights) include a trailing bias element. */
    private final boolean bias;
    /** Per-dimension offset added to the rescaled prediction. */
    private double[] yMean;
    /** Per-dimension scale applied to the raw prediction; serialized as the "y norm" field of the protobuf. */
    private double[] yVariance;
    /**
     * True once the output ordering has been checked/fixed. Always true for freshly constructed models;
     * streams that predate this field deserialize it as {@code false} (Java serialization default), which
     * lets {@link #readObject} decide whether a remapping is needed.
     */
    private boolean enet41MappingFix;

    /**
     * Constructs a sparse linear model.
     *
     * @param name            the model name
     * @param dimensionNames  the output dimension names, in the same order as {@code weights}
     * @param provenance      the model provenance
     * @param featureIDMap    the feature domain
     * @param outputIDInfo    the output domain
     * @param weights         one weight vector per output dimension
     * @param featureMeans    per-feature centering values
     * @param featureVariance per-feature scaling values
     * @param yMean           per-dimension output offsets
     * @param yVariance       per-dimension output scales
     * @param bias            whether the weights include a bias element at index {@code featureIDMap.size()}
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public SparseLinearModel(String name, String[] dimensionNames, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> outputIDInfo, SparseVector[] weights, DenseVector featureMeans, DenseVector featureVariance, double[] yMean, double[] yVariance, boolean bias) {
        super(name, dimensionNames, provenance, featureIDMap, outputIDInfo, generateActiveFeatures(dimensionNames, featureIDMap, weights));
        this.weights = weights;
        this.featureMeans = featureMeans;
        this.featureVariance = featureVariance;
        this.bias = bias;
        this.yVariance = yVariance;
        this.yMean = yMean;
        this.enet41MappingFix = true;
    }

    /**
     * Deserializes a model from its protobuf form, validating every array against the domains.
     *
     * @param version   the serialized version, must be in {@code [0, CURRENT_VERSION]}
     * @param className the class name recorded in the protobuf (unused here)
     * @param message   the packed {@link SparseLinearModelProto}
     * @return the deserialized model
     * @throws InvalidProtocolBufferException if the message is not a {@link SparseLinearModelProto}
     * @throws IllegalArgumentException       if the version is unsupported
     * @throws IllegalStateException          if the protobuf content is inconsistent
     */
    public static SparseLinearModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        SparseLinearModelProto proto = message.unpack(SparseLinearModelProto.class);
        ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());
        // Guard against an empty output domain, where getOutput(0) yields no element to inspect.
        Object firstOutput = carrier.outputDomain().getOutput(0);
        if (firstOutput == null || !firstOutput.getClass().equals(Regressor.class)) {
            throw new IllegalStateException("Invalid protobuf, output domain is not a regression domain, found " + carrier.outputDomain().getClass());
        }
        @SuppressWarnings("unchecked") // guarded by the Regressor class check above
        ImmutableOutputInfo<Regressor> outputDomain = (ImmutableOutputInfo<Regressor>) carrier.outputDomain();

        String[] dimensionNames = new String[proto.getDimensionsCount()];
        if (dimensionNames.length != outputDomain.size()) {
            throw new IllegalStateException("Invalid protobuf, found insufficient dimension names, expected " + outputDomain.size() + ", found " + dimensionNames.length);
        }
        for (int i = 0; i < dimensionNames.length; i++) {
            dimensionNames[i] = proto.getDimensions(i);
        }

        SparseVector[] weights = new SparseVector[outputDomain.size()];
        if (weights.length != proto.getWeightsCount()) {
            throw new IllegalStateException("Invalid protobuf, expected same weight dimension as output domain size, found " + proto.getWeightsCount() + " weights and " + outputDomain.size() + " output dimensions");
        }
        // With a bias the vectors carry one extra trailing element.
        int featureCount = proto.getBias() ? carrier.featureDomain().size() + 1 : carrier.featureDomain().size();
        for (int i = 0; i < weights.length; i++) {
            Tensor tensor = Tensor.deserialize(proto.getWeights(i));
            if (!(tensor instanceof SparseVector)) {
                throw new IllegalStateException("Invalid protobuf, expected a SparseVector, found " + tensor.getClass());
            }
            SparseVector weight = (SparseVector) tensor;
            if (weight.size() != featureCount) {
                throw new IllegalStateException("Invalid protobuf, weights size and feature domain do not match, expected " + featureCount + ", found " + weight.size());
            }
            weights[i] = weight;
        }

        DenseVector featureMeans = checkDenseVector(Tensor.deserialize(proto.getFeatureMeans()), "feature means", featureCount);
        DenseVector featureNorms = checkDenseVector(Tensor.deserialize(proto.getFeatureNorms()), "feature norms", featureCount);

        double[] yMean = Util.toPrimitiveDouble(proto.getYMeanList());
        if (yMean.length != outputDomain.size()) {
            throw new IllegalStateException("Invalid protobuf, y means not the right size, expected " + outputDomain.size() + " found " + yMean.length);
        }
        double[] yNorm = Util.toPrimitiveDouble(proto.getYNormList());
        if (yNorm.length != outputDomain.size()) {
            throw new IllegalStateException("Invalid protobuf, y norms not the right size, expected " + outputDomain.size() + " found " + yNorm.length);
        }
        return new SparseLinearModel(carrier.name(), dimensionNames, carrier.provenance(), carrier.featureDomain(), outputDomain, weights, featureMeans, featureNorms, yMean, yNorm, proto.getBias());
    }

    /**
     * Checks that a deserialized tensor is a {@link DenseVector} of the expected size.
     *
     * @param tensor       the deserialized tensor
     * @param name         human readable name used in error messages
     * @param expectedSize the required vector size
     * @return the tensor cast to a dense vector
     * @throws IllegalStateException if the type or size is wrong
     */
    private static DenseVector checkDenseVector(Tensor tensor, String name, int expectedSize) {
        if (!(tensor instanceof DenseVector)) {
            throw new IllegalStateException("Invalid protobuf, " + name + " must be a dense vector, found " + tensor.getClass());
        }
        DenseVector vector = (DenseVector) tensor;
        if (vector.size() != expectedSize) {
            throw new IllegalStateException("Invalid protobuf, " + name + " not the right size, expected " + expectedSize + ", found " + vector.size());
        }
        return vector;
    }

    /**
     * Returns the display name of a feature index, using {@code "BIAS"} for indices with no
     * feature-map entry (the bias element sits at {@code featureMap.size()}).
     */
    private static String featureName(ImmutableFeatureMap featureMap, int index) {
        VariableIDInfo info = featureMap.get(index);
        return info == null ? BIAS_FEATURE_NAME : info.getName();
    }

    /**
     * Lists, per output dimension, the names of the features with an active (stored) weight.
     *
     * @param dimensionNames the output dimension names, parallel to {@code weights}
     * @param featureMap     the feature domain
     * @param weights        one weight vector per dimension
     * @return map from dimension name to active feature names
     */
    private static Map<String, List<String>> generateActiveFeatures(String[] dimensionNames, ImmutableFeatureMap featureMap, SparseVector[] weights) {
        Map<String, List<String>> activeFeatures = new HashMap<>();
        for (int i = 0; i < dimensionNames.length; i++) {
            List<String> names = new ArrayList<>();
            for (VectorTuple tuple : weights[i]) {
                names.add(featureName(featureMap, tuple.index));
            }
            activeFeatures.put(dimensionNames[i], names);
        }
        return activeFeatures;
    }

    /**
     * Builds the standardized feature vector: active features are shifted by {@code -featureMeans}
     * and divided by {@code featureVariance}.
     */
    @Override
    protected SparseVector createFeatures(Example<Regressor> example) {
        SparseVector features = SparseVector.createSparseVector(example, featureIDMap, bias);
        features.intersectAndAddInPlace(featureMeans, a -> -a);
        features.hadamardProductInPlace(featureVariance, a -> 1.0 / a);
        return features;
    }

    /**
     * Scores one output dimension: the dot product with the weights is rescaled by
     * {@code yVariance} and shifted by {@code yMean}.
     */
    @Override
    protected Regressor.DimensionTuple scoreDimension(int dimensionIdx, SparseVector features) {
        SparseVector weight = weights[dimensionIdx];
        // A dimension with no active weights contributes nothing, so it predicts exactly yMean.
        // (Previously 1.0 was used here, which predicted yMean + yVariance and disagreed with the ONNX graph.)
        double prediction = weight.numActiveElements() > 0 ? weight.dot(features) : 0.0;
        prediction = prediction * yVariance[dimensionIdx] + yMean[dimensionIdx];
        return new Regressor.DimensionTuple(dimensions[dimensionIdx], prediction);
    }

    /**
     * Returns the {@code n} weights of largest magnitude per dimension, in descending magnitude order.
     *
     * @param n the number of features to return; negative returns all of them, zero returns empty lists
     * @return map from dimension name to (feature name, weight) pairs
     */
    @Override
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        int maxFeatures = n < 0 ? featureIDMap.size() + 1 : n;
        Map<String, List<Pair<String, Double>>> topFeatures = new HashMap<>();
        if (maxFeatures == 0) {
            // PriorityQueue rejects a zero capacity, and there is nothing to rank anyway.
            for (String dimension : dimensions) {
                topFeatures.put(dimension, new ArrayList<>());
            }
            return topFeatures;
        }
        Comparator<Pair<String, Double>> byMagnitude = Comparator.comparingDouble(p -> Math.abs(p.getB()));
        // Cap the initial capacity so a very large n does not pre-allocate a huge backing array.
        int capacity = Math.min(maxFeatures, featureIDMap.size() + 1);
        PriorityQueue<Pair<String, Double>> queue = new PriorityQueue<>(capacity, byMagnitude);
        for (int d = 0; d < dimensions.length; d++) {
            queue.clear();
            for (VectorTuple tuple : weights[d]) {
                Pair<String, Double> candidate = new Pair<>(featureName(featureIDMap, tuple.index), tuple.value);
                if (queue.size() < maxFeatures) {
                    queue.offer(candidate);
                } else if (byMagnitude.compare(candidate, queue.peek()) > 0) {
                    // Min-heap on magnitude: evict the smallest kept weight.
                    queue.poll();
                    queue.offer(candidate);
                }
            }
            List<Pair<String, Double>> ranked = new ArrayList<>(queue.size());
            while (!queue.isEmpty()) {
                ranked.add(queue.poll());
            }
            Collections.reverse(ranked);
            topFeatures.put(dimensions[d], ranked);
        }
        return topFeatures;
    }

    /**
     * Explains a prediction as per-feature contributions ({@code weight * standardized value}),
     * sorted in descending order for each dimension. The bias element, if present, is reported as {@code "BIAS"}.
     */
    @Override
    public Optional<Excuse<Regressor>> getExcuse(Example<Regressor> example) {
        Prediction<Regressor> prediction = predict(example);
        Map<String, List<Pair<String, Double>>> contributions = new HashMap<>();
        SparseVector features = createFeatures(example);
        for (int i = 0; i < dimensions.length; i++) {
            List<Pair<String, Double>> scores = new ArrayList<>();
            for (VectorTuple tuple : features) {
                double score = weights[i].get(tuple.index) * tuple.value;
                // The bias index has no feature-map entry; looking it up directly used to throw an NPE.
                scores.add(new Pair<>(featureName(featureIDMap, tuple.index), score));
            }
            scores.sort((a, b) -> b.getB().compareTo(a.getB()));
            contributions.put(dimensions[i], scores);
        }
        return Optional.of(new Excuse<>(example, prediction, contributions));
    }

    /** Deep-copies all parameters into a new model with the supplied name and provenance. */
    @Override
    protected Model<Regressor> copy(String newName, ModelProvenance newProvenance) {
        return new SparseLinearModel(newName, Arrays.copyOf(dimensions, dimensions.length), newProvenance, featureIDMap, outputIDInfo, copyWeights(), featureMeans.copy(), featureVariance.copy(), Arrays.copyOf(yMean, yMean.length), Arrays.copyOf(yVariance, yVariance.length), bias);
    }

    /** Returns a deep copy of the per-dimension weight vectors. */
    private SparseVector[] copyWeights() {
        SparseVector[] copies = new SparseVector[weights.length];
        for (int i = 0; i < weights.length; i++) {
            copies[i] = weights[i].copy();
        }
        return copies;
    }

    /**
     * Returns copies of the weight vectors keyed by dimension name.
     *
     * @return map from dimension name to a copy of its weights
     */
    public Map<String, SparseVector> getWeights() {
        SparseVector[] copies = copyWeights();
        Map<String, SparseVector> weightMap = new HashMap<>();
        for (int i = 0; i < dimensions.length; i++) {
            weightMap.put(dimensions[i], copies[i]);
        }
        return weightMap;
    }

    /**
     * Serializes this model to a {@link ModelProto} wrapping a {@link SparseLinearModelProto}.
     * The decompiler renamed this method from {@code serialize} (merged bridge method).
     *
     * @return the protobuf form of this model
     */
    /* renamed from: serialize, reason: merged with bridge method [inline-methods] */
    public ModelProto m13serialize() {
        ModelDataCarrier<Regressor> carrier = createDataCarrier();
        SparseLinearModelProto.Builder modelBuilder = SparseLinearModelProto.newBuilder();
        modelBuilder.setMetadata(carrier.serialize());
        modelBuilder.addAllDimensions(Arrays.asList(dimensions));
        for (SparseVector weight : weights) {
            modelBuilder.addWeights(weight.serialize());
        }
        modelBuilder.setFeatureMeans(featureMeans.serialize());
        modelBuilder.setFeatureNorms(featureVariance.serialize());
        modelBuilder.setBias(bias);
        modelBuilder.addAllYMean(Arrays.stream(yMean).boxed().collect(Collectors.toList()));
        modelBuilder.addAllYNorm(Arrays.stream(yVariance).boxed().collect(Collectors.toList()));

        ModelProto.Builder wrapper = ModelProto.newBuilder();
        wrapper.setSerializedData(Any.pack(modelBuilder.m58build()));
        wrapper.setClassName(SparseLinearModel.class.getName());
        wrapper.setVersion(CURRENT_VERSION);
        return wrapper.build();
    }

    /** Exports this model as a standalone ONNX model with a float input of size {@code featureIDMap.size()}. */
    @Override
    public OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion) {
        ONNXContext context = new ONNXContext();
        ONNXPlaceholder input = context.floatInput(featureIDMap.size());
        ONNXPlaceholder output = context.floatOutput(outputIDInfo.size());
        context.setName("Regression-SparseLinearModel");
        return ONNXExportable.buildModel(writeONNXGraph(input).assignTo(output).onnxContext(), domain, modelVersion, this);
    }

    /**
     * Writes the scoring graph: {@code ((x - featureMean) / featureVariance) * W + b}, then
     * multiplied by {@code yVariance} and offset by {@code yMean}.
     *
     * @param input the feature input reference
     * @return the final node of the graph
     */
    public ONNXNode writeONNXGraph(ONNXRef<?> input) {
        ONNXContext context = input.onnxContext();
        final int numFeatures = featureIDMap.size();
        final int numOutputs = outputIDInfo.size();
        // Row-major [numFeatures, numOutputs] matrix for GEMM.
        ONNXRef<?> weightTensor = context.floatTensor("slm_weights", Arrays.asList(numFeatures, numOutputs), buffer -> {
            for (int f = 0; f < numFeatures; f++) {
                for (SparseVector weight : weights) {
                    buffer.put((float) weight.get(f));
                }
            }
        });
        // Without a bias there is no element at index numFeatures, so emit explicit zeros.
        ONNXRef<?> biasTensor = context.floatTensor("slm_biases", Collections.singletonList(numOutputs), buffer -> {
            for (SparseVector weight : weights) {
                buffer.put(bias ? (float) weight.get(numFeatures) : 0.0f);
            }
        });
        // The bias element of the mean/variance vectors is handled by the GEMM bias term, so it is trimmed here.
        return input.apply(ONNXOperators.SUB, context.array("feature_mean", bias ? Arrays.copyOf(featureMeans.toArray(), numFeatures) : featureMeans.toArray()))
                .apply(ONNXOperators.DIV, context.array("feature_variance", bias ? Arrays.copyOf(featureVariance.toArray(), numFeatures) : featureVariance.toArray()))
                .apply(ONNXOperators.GEMM, Arrays.asList(weightTensor, biasTensor))
                .apply(ONNXOperators.MUL, context.array("y_variance", yVariance))
                .apply(ONNXOperators.ADD, context.array("y_mean", yMean));
    }

    /**
     * Returns true if a trainer version string matches the releases handled by the output remapping in
     * {@link #readObject}.
     */
    private static boolean needsEnetOutputRemapping(String tribuoVersion) {
        return tribuoVersion != null
                && (tribuoVersion.startsWith("4.0.0") || tribuoVersion.startsWith("4.0.1") || tribuoVersion.startsWith("4.0.2")
                || tribuoVersion.startsWith("4.1.0") || tribuoVersion.equals("4.1.1-SNAPSHOT"));
    }

    /**
     * Java deserialization hook. For models produced by the ElasticNet trainer in the listed older versions
     * that have not yet been fixed, reorders weights, yMean and yVariance using the output domain's
     * ID-to-natural-order mapping. NOTE(review): presumably compensates for a historical output-ordering bug
     * in those versions — confirm against the project history.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        // Cheap checks first, so models from other trainers never touch the version provenance.
        if (enet41MappingFix || !ENET_TRAINER_CLASS_NAME.equals(provenance.getTrainerProvenance().getClassName())) {
            return;
        }
        PrimitiveProvenance<?> versionProvenance = (PrimitiveProvenance<?>) provenance.getTrainerProvenance().getInstanceValues().get("tribuo-version");
        if (versionProvenance == null) {
            return;
        }
        String tribuoVersion = (String) versionProvenance.getValue();
        if (!needsEnetOutputRemapping(tribuoVersion)) {
            return;
        }
        enet41MappingFix = true;
        int[] mapping = ((ImmutableRegressionInfo) outputIDInfo).getIDtoNaturalOrderMapping();
        SparseVector[] remappedWeights = new SparseVector[weights.length];
        double[] remappedMean = new double[weights.length];
        double[] remappedVariance = new double[weights.length];
        for (int i = 0; i < mapping.length; i++) {
            remappedWeights[i] = weights[mapping[i]];
            remappedMean[i] = yMean[mapping[i]];
            remappedVariance[i] = yVariance[mapping[i]];
        }
        yMean = remappedMean;
        yVariance = remappedVariance;
        weights = remappedWeights;
    }
}
