package org.tribuo.classification.liblinear;

import ai.onnx.proto.OnnxMl;
import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.util.Pair;
import de.bwaldvogel.liblinear.FeatureNode;
import de.bwaldvogel.liblinear.Linear;
import de.bwaldvogel.liblinear.Model;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.logging.Logger;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.ONNXExportable;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.common.liblinear.LibLinearModel;
import org.tribuo.common.liblinear.LibLinearTrainer;
import org.tribuo.common.liblinear.protos.LibLinearModelProto;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.util.onnx.ONNXContext;
import org.tribuo.util.onnx.ONNXNode;
import org.tribuo.util.onnx.ONNXOperators;
import org.tribuo.util.onnx.ONNXPlaceholder;
import org.tribuo.util.onnx.ONNXRef;

/* loaded from: input_file:org/tribuo/classification/liblinear/LibLinearClassificationModel.class */
/**
 * A {@link LibLinearModel} for classification, backed by a single liblinear {@link Model}.
 * <p>
 * Labels that exist in the output domain but were not seen by the trained liblinear model
 * ("unobserved" labels) are added to every prediction's score map with a score of 0.0.
 * <p>
 * The model can be exported to ONNX as a Gemm (weights plus biases), followed by a Softmax
 * when the liblinear model produces probabilities.
 */
public class LibLinearClassificationModel extends LibLinearModel<Label> implements ONNXExportable {
    private static final long serialVersionUID = 3;
    private static final Logger logger = Logger.getLogger(LibLinearClassificationModel.class.getName());

    /**
     * Labels in the output domain which the liblinear model was not trained on, each with score 0.0.
     * Empty when liblinear reports as many labels as the output domain contains.
     */
    private final Set<Label> unobservedLabels;

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Wraps a trained liblinear model.
     *
     * @param name         the model name.
     * @param provenance   the model provenance.
     * @param featureMap   the feature domain.
     * @param labelIDMap   the output domain, mapping Tribuo output IDs to labels.
     * @param modelList    the liblinear models; only element 0 is consulted here.
     */
    public LibLinearClassificationModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureMap, ImmutableOutputInfo<Label> labelIDMap, List<Model> modelList) {
        super(name, provenance, featureMap, labelIDMap, modelList.get(0).isProbabilityModel(), modelList);
        // liblinear reports the output IDs it saw during training (they are looked up via
        // outputIDInfo.getOutput(...) in predict), not positions 0..n-1.
        int[] observedIds = modelList.get(0).getLabels();
        if (observedIds.length == labelIDMap.size()) {
            this.unobservedLabels = Collections.emptySet();
        } else {
            Map<Integer, Label> unseen = new HashMap<>();
            for (Pair<Integer, Label> idAndLabel : labelIDMap) {
                unseen.put(idAndLabel.getA(), idAndLabel.getB());
            }
            // Remove the IDs that were actually observed. Removing the loop index instead would
            // mark the wrong labels as unobserved whenever the observed IDs are not 0..n-1.
            for (int observedId : observedIds) {
                unseen.remove(observedId);
            }
            Set<Label> zeroScored = new HashSet<>(unseen.size());
            for (Label missing : unseen.values()) {
                zeroScored.add(new Label(missing.getLabel(), 0.0d));
            }
            this.unobservedLabels = Collections.unmodifiableSet(zeroScored);
        }
    }

    /**
     * Deserializes a model from its protobuf representation.
     *
     * @param version   the serialization version; only version 0 is supported.
     * @param className the class name recorded in the protobuf; must be this class.
     * @param message   the packed {@link LibLinearModelProto}.
     * @return the deserialized model.
     * @throws InvalidProtocolBufferException if the message is not a {@link LibLinearModelProto}.
     * @throws IllegalArgumentException if the version is unknown.
     * @throws IllegalStateException if the class name, output domain or model payload is invalid.
     */
    public static LibLinearClassificationModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > 0) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version 0");
        }
        if (!"org.tribuo.classification.liblinear.LibLinearClassificationModel".equals(className)) {
            throw new IllegalStateException("Invalid protobuf, this class can only deserialize LibLinearClassificationModel");
        }
        LibLinearModelProto proto = message.unpack(LibLinearModelProto.class);
        ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());
        if (!carrier.outputDomain().getOutput(0).getClass().equals(Label.class)) {
            throw new IllegalStateException("Invalid protobuf, output domain is not a label domain, found " + carrier.outputDomain().getClass());
        }
        @SuppressWarnings("unchecked") // guarded by the Label class check above
        ImmutableOutputInfo<Label> outputDomain = (ImmutableOutputInfo<Label>) carrier.outputDomain();
        if (proto.getModelsCount() != 1) {
            throw new IllegalStateException("Invalid protobuf, expected 1 model, found " + proto.getModelsCount());
        }
        // SECURITY NOTE(review): this runs Java-native deserialization on bytes taken from the
        // protobuf. Only load protobufs from trusted sources, or consider an ObjectInputFilter.
        // try-with-resources ensures the stream is closed even when readObject fails.
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(proto.getModels(0).toByteArray()))) {
            Model model = (Model) in.readObject();
            return new LibLinearClassificationModel(carrier.name(), carrier.provenance(), carrier.featureDomain(), outputDomain, Collections.singletonList(model));
        } catch (IOException | ClassNotFoundException e) {
            throw new IllegalStateException("Invalid protobuf, failed to deserialize liblinear model", e);
        }
    }

    /**
     * Predicts the label for an example.
     * <p>
     * Scores are liblinear probabilities for probability models, and decision values otherwise.
     * For binary non-probabilistic models, a second decision value of 0.0 is replaced by the
     * negation of the first. Unobserved labels are appended with score 0.0.
     *
     * @param example the example to predict.
     * @return the prediction, whose output is the highest-scoring observed label.
     * @throws IllegalArgumentException if the example contains no features known to this model.
     */
    @Override
    public Prediction<Label> predict(Example<Label> example) {
        FeatureNode[] nodes = LibLinearTrainer.exampleToNodes(example, featureIDMap, (List) null);
        // One node means no known features (numUsed below is nodes.length - 1).
        if (nodes.length == 1) {
            throw new IllegalArgumentException("No features found in Example " + example.toString());
        }
        Model model = (Model) models.get(0);
        int[] labels = model.getLabels();
        double[] scores = new double[labels.length];
        if (model.isProbabilityModel()) {
            Linear.predictProbability(model, nodes, scores);
        } else {
            Linear.predictValues(model, nodes, scores);
            if (model.getNrClass() == 2 && scores[1] == 0.0d) {
                scores[1] = -scores[0];
            }
        }
        double bestScore = Double.NEGATIVE_INFINITY;
        Label best = null;
        Map<String, Label> scoreMap = new LinkedHashMap<>();
        for (int k = 0; k < scores.length; k++) {
            String name = outputIDInfo.getOutput(labels[k]).getLabel();
            Label scored = new Label(name, scores[k]);
            scoreMap.put(name, scored);
            if (scored.getScore() > bestScore) {
                bestScore = scored.getScore();
                best = scored;
            }
        }
        for (Label unobserved : unobservedLabels) {
            scoreMap.put(unobserved.getLabel(), unobserved);
        }
        return new Prediction<>(best, scoreMap, nodes.length - 1, example, generatesProbabilities);
    }

    /**
     * Returns, for each label, the features with the largest absolute weight.
     * <p>
     * For binary models liblinear stores one weight vector. It is reported for {@code labels[0]},
     * and its negation is reported for {@code labels[1]}.
     *
     * @param n the number of features per label; a negative value returns all features, and 0
     *          returns empty lists.
     * @return a map from label name to (feature name, weight) pairs sorted by descending |weight|.
     */
    @Override
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        int limit = n < 0 ? featureIDMap.size() : n;
        Model model = (Model) models.get(0);
        int[] labels = model.getLabels();
        double[] weights = model.getFeatureWeights();
        int nrClass = model.getNrClass();
        int nrFeature = model.getNrFeature();
        Map<String, List<Pair<String, Double>>> result = new HashMap<>();
        if (nrClass == 2) {
            List<Pair<String, Double>> first = topByMagnitude(weights, nrFeature, 1, 0, limit);
            result.put(outputIDInfo.getOutput(labels[0]).getLabel(), first);
            List<Pair<String, Double>> second = new ArrayList<>(first.size());
            for (Pair<String, Double> entry : first) {
                second.add(new Pair<>(entry.getA(), -entry.getB()));
            }
            result.put(outputIDInfo.getOutput(labels[1]).getLabel(), second);
        } else {
            for (int k = 0; k < labels.length; k++) {
                result.put(outputIDInfo.getOutput(labels[k]).getLabel(), topByMagnitude(weights, nrFeature, nrClass, k, limit));
            }
        }
        return result;
    }

    /**
     * Selects up to {@code limit} features with the largest absolute weight in one weight column.
     *
     * @param weights     liblinear's flattened weight array, read as {@code weights[feature * stride + column]}.
     * @param numFeatures the number of features to scan.
     * @param stride      the number of weight columns (1 for binary models, nrClass otherwise).
     * @param column      the column to read.
     * @param limit       the maximum number of features to return.
     * @return a mutable list sorted by descending absolute weight; empty when {@code limit <= 0}.
     */
    private List<Pair<String, Double>> topByMagnitude(double[] weights, int numFeatures, int stride, int column, int limit) {
        if (limit <= 0) {
            // PriorityQueue rejects an initial capacity below 1, so handle this case up front.
            return new ArrayList<>();
        }
        Comparator<Pair<String, Double>> byMagnitude = Comparator.comparingDouble(p -> Math.abs(p.getB()));
        // Min-heap of the current best `limit` entries; the head is the weakest one kept.
        PriorityQueue<Pair<String, Double>> heap = new PriorityQueue<>(limit, byMagnitude);
        for (int f = 0; f < numFeatures; f++) {
            Pair<String, Double> candidate = new Pair<>(featureIDMap.get(f).getName(), weights[(f * stride) + column]);
            if (heap.size() < limit) {
                heap.offer(candidate);
            } else if (byMagnitude.compare(candidate, heap.peek()) > 0) {
                heap.poll();
                heap.offer(candidate);
            }
        }
        List<Pair<String, Double>> ordered = new ArrayList<>(heap.size());
        while (!heap.isEmpty()) {
            ordered.add(heap.poll());
        }
        Collections.reverse(ordered);
        return ordered;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* renamed from: copy, reason: merged with bridge method [inline-methods] */
    /**
     * Copies this model with a new name and provenance, deep-copying the liblinear model.
     *
     * @param newName       the name of the copy.
     * @param newProvenance the provenance of the copy.
     * @return the copy.
     */
    public LibLinearClassificationModel m1copy(String newName, ModelProvenance newProvenance) {
        return new LibLinearClassificationModel(newName, newProvenance, featureIDMap, outputIDInfo, Collections.singletonList(copyModel((Model) models.get(0))));
    }

    /**
     * Returns liblinear's flattened weight array as the single row of a 2-D array.
     *
     * @return a one-row array holding the model's feature weights.
     */
    @Override
    protected double[][] getFeatureWeights() {
        return new double[][] {((Model) models.get(0)).getFeatureWeights()};
    }

    /**
     * Explains a prediction by computing each feature's contribution (weight * value) per label.
     *
     * @param example        the example to explain.
     * @param featureWeights the weights; only row 0 is read, indexed like liblinear's flat array.
     * @return an excuse mapping each label name to contributions sorted in descending order.
     */
    @Override
    protected Excuse<Label> innerGetExcuse(Example<Label> example, double[][] featureWeights) {
        Model model = (Model) models.get(0);
        double[] weights = featureWeights[0];
        int[] labels = model.getLabels();
        int nrClass = model.getNrClass();
        Prediction<Label> prediction = predict(example);
        Comparator<Pair<String, Double>> descending = (a, b) -> b.getB().compareTo(a.getB());
        Map<String, List<Pair<String, Double>>> explanations = new HashMap<>();
        if (nrClass == 2) {
            // One weight vector: labels[0] gets each contribution, labels[1] its negation.
            List<Pair<String, Double>> first = new ArrayList<>();
            List<Pair<String, Double>> second = new ArrayList<>();
            for (Feature feature : example) {
                int id = featureIDMap.getID(feature.getName());
                if (id > -1) {
                    double contribution = weights[id] * feature.getValue();
                    first.add(new Pair<>(feature.getName(), contribution));
                    second.add(new Pair<>(feature.getName(), -contribution));
                }
            }
            first.sort(descending);
            second.sort(descending);
            explanations.put(outputIDInfo.getOutput(labels[0]).getLabel(), first);
            explanations.put(outputIDInfo.getOutput(labels[1]).getLabel(), second);
        } else {
            for (int k = 0; k < labels.length; k++) {
                List<Pair<String, Double>> contributions = new ArrayList<>();
                for (Feature feature : example) {
                    int id = featureIDMap.getID(feature.getName());
                    if (id > -1) {
                        contributions.add(new Pair<>(feature.getName(), weights[(id * nrClass) + k] * feature.getValue()));
                    }
                }
                contributions.sort(descending);
                explanations.put(outputIDInfo.getOutput(labels[k]).getLabel(), contributions);
            }
        }
        return new Excuse<>(example, prediction, explanations);
    }

    /**
     * Exports this model as a standalone ONNX model.
     *
     * @param domain       the ONNX domain.
     * @param modelVersion the model version.
     * @return the ONNX model proto.
     */
    @Override
    public OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion) {
        ONNXContext context = new ONNXContext();
        context.setName("Classification-LibLinear");
        ONNXPlaceholder input = context.floatInput(featureIDMap.size());
        writeONNXGraph(input).assignTo(context.floatOutput(outputIDInfo.size()));
        return ONNXExportable.buildModel(context, domain, modelVersion, this);
    }

    /**
     * Writes this model into an ONNX graph as Gemm(input, weights, biases), followed by a Softmax
     * over axis 1 when the liblinear model produces probabilities.
     * <p>
     * Weight columns are reordered from liblinear's label order into Tribuo output-ID order.
     * NOTE(review): the layout assumes liblinear's weight array has one trailing bias row
     * ((features + 1) rows); confirm the trainer always sets a bias.
     *
     * @param input the node feeding the feature vector.
     * @return the output node.
     * @throws IllegalStateException if liblinear's label count differs from the output domain size.
     */
    @Override
    public ONNXNode writeONNXGraph(ONNXRef<?> input) {
        ONNXContext onnxContext = input.onnxContext();
        Model model = (Model) models.get(0);
        double[] featureWeights = model.getFeatureWeights();
        int[] labels = model.getLabels();
        int size = featureIDMap.size();
        int length = labels.length;
        if (length != outputIDInfo.size()) {
            throw new IllegalStateException("Unexpected number of labels, output domain = " + outputIDInfo.size() + ", LibLinear's internal count = " + length);
        }
        double[] ordered = model.getNrClass() == 2
            ? expandBinaryWeights(featureWeights, labels[0])
            : reorderColumnsByOutputId(featureWeights, labels, size + 1);
        ONNXNode gemm = input.apply(ONNXOperators.GEMM, Arrays.asList(
            onnxContext.floatTensor("liblinear_weights", Arrays.asList(size, length), buf -> {
                // Every row except the trailing bias row.
                for (int k = 0; k < ordered.length - length; k++) {
                    buf.put((float) ordered[k]);
                }
            }),
            onnxContext.floatTensor("liblinear_biases", Collections.singletonList(length), buf -> {
                for (int k = size * length; k < ordered.length; k++) {
                    buf.put((float) ordered[k]);
                }
            })));
        return model.isProbabilityModel() ? gemm.apply(ONNXOperators.SOFTMAX, Collections.singletonMap("axis", 1)) : gemm;
    }

    /**
     * Expands a binary model's single weight vector into two interleaved columns in output-ID order.
     * The output ID equal to {@code firstLabel} receives the weights, and the other ID their negation.
     *
     * @param weights    liblinear's single-column weight array.
     * @param firstLabel liblinear's {@code labels[0]} (0 or 1).
     * @return an array of twice the length, laid out as [row][outputId].
     */
    private static double[] expandBinaryWeights(double[] weights, int firstLabel) {
        int positiveColumn = firstLabel == 0 ? 0 : 1;
        int negativeColumn = 1 - positiveColumn;
        double[] expanded = new double[weights.length * 2];
        for (int row = 0; row < weights.length; row++) {
            expanded[(row * 2) + positiveColumn] = weights[row];
            expanded[(row * 2) + negativeColumn] = -weights[row];
        }
        return expanded;
    }

    /**
     * Permutes the columns of a row-major weight array from liblinear's label order into output-ID order.
     *
     * @param weights liblinear's flattened weights with {@code labels.length} columns.
     * @param labels  liblinear's label order (output IDs).
     * @param numRows the number of rows to permute.
     * @return the permuted weights.
     */
    private static double[] reorderColumnsByOutputId(double[] weights, int[] labels, int numRows) {
        int numCols = labels.length;
        double[] reordered = new double[weights.length];
        for (int row = 0; row < numRows; row++) {
            for (int col = 0; col < numCols; col++) {
                reordered[(row * numCols) + labels[col]] = weights[(row * numCols) + col];
            }
        }
        return reordered;
    }
}
