package org.tribuo.clustering.hdbscan;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.IOException;
import java.io.InvalidObjectException;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.hdbscan.HdbscanTrainer;
import org.tribuo.clustering.hdbscan.protos.ClusterExemplarProto;
import org.tribuo.clustering.hdbscan.protos.HdbscanModelProto;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.math.distance.Distance;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.math.protos.DistanceProto;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;

/* loaded from: input_file:org/tribuo/clustering/hdbscan/HdbscanModel.class */
public final class HdbscanModel extends Model<ClusterID> {
    private static final long serialVersionUID = 1;
    public static final int CURRENT_VERSION = 0;
    private final List<Integer> clusterLabels;
    private final DenseVector outlierScoresVector;

    @Deprecated
    private HdbscanTrainer.Distance distanceType;
    private Distance dist;
    private final List<HdbscanTrainer.ClusterExemplar> clusterExemplars;
    private final double noisePointsOutlierScore;

    /* JADX INFO: Access modifiers changed from: package-private */
    public HdbscanModel(String str, ModelProvenance modelProvenance, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo<ClusterID> immutableOutputInfo, List<Integer> list, DenseVector denseVector, List<HdbscanTrainer.ClusterExemplar> list2, Distance distance, double d) {
        super(str, modelProvenance, immutableFeatureMap, immutableOutputInfo, false);
        this.clusterLabels = Collections.unmodifiableList(list);
        this.outlierScoresVector = denseVector;
        this.clusterExemplars = Collections.unmodifiableList(list2);
        this.dist = distance;
        this.noisePointsOutlierScore = d;
    }

    public static HdbscanModel deserializeFromProto(int i, String str, Any any) throws InvalidProtocolBufferException {
        if (i < 0 || i > 0) {
            throw new IllegalArgumentException("Unknown version " + i + ", this class supports at most version 0");
        }
        HdbscanModelProto unpack = any.unpack(HdbscanModelProto.class);
        ModelDataCarrier deserialize = ModelDataCarrier.deserialize(unpack.getMetadata());
        if (!deserialize.outputDomain().getOutput(0).getClass().equals(ClusterID.class)) {
            throw new IllegalStateException("Invalid protobuf, output domain is not a clustering domain, found " + deserialize.outputDomain().getClass());
        }
        ImmutableOutputInfo outputDomain = deserialize.outputDomain();
        DenseVector deserialize2 = Tensor.deserialize(unpack.getOutlierScoresVector());
        if (!(deserialize2 instanceof DenseVector)) {
            throw new IllegalStateException("Invalid protobuf, outlier scores must be a dense vector, found " + deserialize2.getClass());
        }
        DenseVector denseVector = deserialize2;
        ArrayList<Integer> arrayList = new ArrayList(unpack.getClusterLabelsList());
        for (Integer num : arrayList) {
            if (outputDomain.getOutput(num.intValue()) == null && num.intValue() != -1) {
                throw new IllegalStateException("Invalid protobuf, found cluster id " + num + " which is not present in the domain " + outputDomain);
            }
        }
        if (arrayList.size() != denseVector.size()) {
            throw new IllegalStateException("Invalid protobuf, expected the same number of outlier scores as cluster labels, found " + denseVector.size() + " scores and " + arrayList.size() + " labels");
        }
        ArrayList arrayList2 = new ArrayList();
        Iterator<ClusterExemplarProto> it = unpack.getClusterExemplarsList().iterator();
        while (it.hasNext()) {
            arrayList2.add(HdbscanTrainer.ClusterExemplar.deserialize(it.next()));
        }
        return new HdbscanModel(deserialize.name(), deserialize.provenance(), deserialize.featureDomain(), outputDomain, arrayList, denseVector, arrayList2, ProtoUtil.deserialize(unpack.getDistance()), unpack.getNoisePointsOutlierScore());
    }

    public List<Integer> getClusterLabels() {
        return this.clusterLabels;
    }

    public List<Double> getOutlierScores() {
        ArrayList arrayList = new ArrayList(this.outlierScoresVector.size());
        for (double d : this.outlierScoresVector.toArray()) {
            arrayList.add(Double.valueOf(d));
        }
        return arrayList;
    }

    public List<HdbscanTrainer.ClusterExemplar> getClusterExemplars() {
        ArrayList arrayList = new ArrayList(this.clusterExemplars.size());
        Iterator<HdbscanTrainer.ClusterExemplar> it = this.clusterExemplars.iterator();
        while (it.hasNext()) {
            arrayList.add(it.next().copy());
        }
        return arrayList;
    }

    public List<Pair<Integer, List<Feature>>> getClusters() {
        ArrayList arrayList = new ArrayList(this.clusterExemplars.size());
        for (HdbscanTrainer.ClusterExemplar clusterExemplar : this.clusterExemplars) {
            ArrayList arrayList2 = new ArrayList(clusterExemplar.getFeatures().numActiveElements());
            for (VectorTuple vectorTuple : clusterExemplar.getFeatures()) {
                arrayList2.add(new Feature(this.featureIDMap.get(vectorTuple.index).getName(), vectorTuple.value));
            }
            arrayList.add(new Pair(clusterExemplar.getLabel(), arrayList2));
        }
        return arrayList;
    }

    public Prediction<ClusterID> predict(Example<ClusterID> example) {
        DenseVector createDenseVector = example.size() == this.featureIDMap.size() ? DenseVector.createDenseVector(example, this.featureIDMap, false) : SparseVector.createSparseVector(example, this.featureIDMap, false);
        if (createDenseVector.numActiveElements() == 0) {
            throw new IllegalArgumentException("No features found in Example " + example);
        }
        double d = Double.POSITIVE_INFINITY;
        int i = 0;
        double d2 = 0.0d;
        if (Double.compare(this.noisePointsOutlierScore, 0.0d) > 0) {
            boolean z = true;
            for (HdbscanTrainer.ClusterExemplar clusterExemplar : this.clusterExemplars) {
                double computeDistance = this.dist.computeDistance(clusterExemplar.getFeatures(), createDenseVector);
                if (z && computeDistance <= clusterExemplar.getMaxDistToEdge().doubleValue()) {
                    z = false;
                }
                if (computeDistance < d) {
                    d = computeDistance;
                    i = clusterExemplar.getLabel().intValue();
                    d2 = clusterExemplar.getOutlierScore().doubleValue();
                }
            }
            if (z) {
                i = 0;
                d2 = this.noisePointsOutlierScore;
            }
        } else {
            for (HdbscanTrainer.ClusterExemplar clusterExemplar2 : this.clusterExemplars) {
                double computeDistance2 = this.dist.computeDistance(clusterExemplar2.getFeatures(), createDenseVector);
                if (computeDistance2 < d) {
                    d = computeDistance2;
                    i = clusterExemplar2.getLabel().intValue();
                    d2 = clusterExemplar2.getOutlierScore().doubleValue();
                }
            }
        }
        return new Prediction<>(new ClusterID(i, d2), createDenseVector.size(), example);
    }

    public Map<String, List<Pair<String, Double>>> getTopFeatures(int i) {
        return Collections.emptyMap();
    }

    public Optional<Excuse<ClusterID>> getExcuse(Example<ClusterID> example) {
        return Optional.empty();
    }

    /* renamed from: serialize, reason: merged with bridge method [inline-methods] */
    public ModelProto m2serialize() {
        ModelDataCarrier createDataCarrier = createDataCarrier();
        HdbscanModelProto.Builder newBuilder = HdbscanModelProto.newBuilder();
        newBuilder.setMetadata(createDataCarrier.serialize());
        newBuilder.addAllClusterLabels(this.clusterLabels);
        newBuilder.setOutlierScoresVector(this.outlierScoresVector.serialize());
        newBuilder.setDistance((DistanceProto) this.dist.serialize());
        Iterator<HdbscanTrainer.ClusterExemplar> it = this.clusterExemplars.iterator();
        while (it.hasNext()) {
            newBuilder.addClusterExemplars(it.next().serialize());
        }
        newBuilder.setNoisePointsOutlierScore(this.noisePointsOutlierScore);
        ModelProto.Builder newBuilder2 = ModelProto.newBuilder();
        newBuilder2.setSerializedData(Any.pack(newBuilder.m97build()));
        newBuilder2.setClassName(HdbscanModel.class.getName());
        newBuilder2.setVersion(0);
        return newBuilder2.build();
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* renamed from: copy, reason: merged with bridge method [inline-methods] */
    public HdbscanModel m1copy(String str, ModelProvenance modelProvenance) {
        DenseVector copy = this.outlierScoresVector.copy();
        return new HdbscanModel(str, modelProvenance, this.featureIDMap, this.outputIDInfo, new ArrayList(this.clusterLabels), copy, new ArrayList(this.clusterExemplars), this.dist, this.noisePointsOutlierScore);
    }

    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        objectInputStream.defaultReadObject();
        if (this.dist == null) {
            this.dist = this.distanceType.getDistanceType().getDistance();
        }
    }
}
