package org.tribuo.clustering.kmeans;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.IOException;
import java.io.InvalidObjectException;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.kmeans.KMeansTrainer;
import org.tribuo.clustering.kmeans.protos.KMeansModelProto;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.math.distance.Distance;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.la.VectorIterator;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.math.protos.DistanceProto;
import org.tribuo.math.protos.TensorProto;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;

/* loaded from: input_file:org/tribuo/clustering/kmeans/KMeansModel.class */
/**
 * A K-Means clustering model: a fixed set of centroid vectors plus the {@link Distance}
 * used to assign an example to its nearest centroid.
 * <p>
 * The cluster id predicted for an example is the index of the closest centroid in the
 * centroid array.
 * <p>
 * NOTE(review): this is decompiled code. {@code m1serialize}, {@code m0copy} and the builder's
 * {@code m52build} are decompiler renames of bridge-merged methods. They are kept unchanged so
 * that any code referring to them keeps compiling.
 */
public class KMeansModel extends Model<ClusterID> {
    private static final long serialVersionUID = 1L;

    /** Protobuf format version written by {@code m1serialize} and accepted by {@link #deserializeFromProto}. */
    public static final int CURRENT_VERSION = 0;

    /** One centroid per cluster; the array index is the cluster id produced by {@link #predict}. */
    private final DenseVector[] centroidVectors;

    /**
     * Legacy distance selector. It is only consulted by {@code readObject} when a
     * Java-serialized model does not carry a {@link #dist} value.
     */
    @Deprecated
    private KMeansTrainer.Distance distanceType;

    /** Distance function used to compare examples against centroids. */
    private Distance dist;

    /**
     * Constructs a K-Means model. The centroid array is stored as-is (not copied).
     * <p>
     * NOTE(review): JADX reports this constructor was package-private in the original build.
     * It is left public here to avoid breaking existing callers.
     *
     * @param name            The model name.
     * @param provenance      The model provenance.
     * @param featureMap      The feature domain.
     * @param outputInfo      The cluster output domain.
     * @param centroidVectors The centroids, one per cluster.
     * @param distance        The distance function used at prediction time.
     */
    public KMeansModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureMap,
                       ImmutableOutputInfo<ClusterID> outputInfo, DenseVector[] centroidVectors, Distance distance) {
        // NOTE(review): the final 'false' is presumably Model's generatesProbabilities flag — confirm.
        super(name, provenance, featureMap, outputInfo, false);
        this.centroidVectors = centroidVectors;
        this.dist = distance;
    }

    /**
     * Deserializes a K-Means model from its protobuf representation.
     *
     * @param version   The serialized format version; must be in {@code [0, CURRENT_VERSION]}.
     * @param className The class name recorded in the proto (unused here).
     * @param message   The packed {@link KMeansModelProto}.
     * @return The deserialized model.
     * @throws InvalidProtocolBufferException If {@code message} does not contain a {@link KMeansModelProto}.
     * @throws IllegalArgumentException       If the version is unsupported.
     * @throws IllegalStateException          If the proto is structurally invalid (non-clustering
     *                                        output domain, no centroids, non-dense centroids, or
     *                                        centroids whose size differs from the feature domain).
     */
    public static KMeansModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        KMeansModelProto proto = message.unpack(KMeansModelProto.class);
        ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());

        ImmutableOutputInfo<?> rawOutputDomain = carrier.outputDomain();
        Object firstOutput = rawOutputDomain.getOutput(0);
        // A missing output 0 is reported as an invalid domain rather than surfacing as an NPE.
        if (firstOutput == null || !firstOutput.getClass().equals(ClusterID.class)) {
            throw new IllegalStateException("Invalid protobuf, output domain is not a clustering domain, found " + rawOutputDomain.getClass());
        }
        @SuppressWarnings("unchecked") // guarded by the ClusterID class check above
        ImmutableOutputInfo<ClusterID> outputDomain = (ImmutableOutputInfo<ClusterID>) rawOutputDomain;
        ImmutableFeatureMap featureDomain = carrier.featureDomain();

        int numCentroids = proto.getCentroidVectorsCount();
        if (numCentroids == 0) {
            throw new IllegalStateException("Invalid protobuf, no centroids were found");
        }
        List<TensorProto> centroidProtos = proto.getCentroidVectorsList();
        DenseVector[] centroids = new DenseVector[numCentroids];
        for (int i = 0; i < numCentroids; i++) {
            Tensor tensor = Tensor.deserialize(centroidProtos.get(i));
            if (!(tensor instanceof DenseVector)) {
                throw new IllegalStateException("Invalid protobuf, expected centroid to be a dense vector, found " + tensor.getClass());
            }
            DenseVector centroid = (DenseVector) tensor;
            if (centroid.size() != featureDomain.size()) {
                throw new IllegalStateException("Invalid protobuf, centroid did not contain all the features, found " + centroid.size() + " expected " + featureDomain.size());
            }
            centroids[i] = centroid;
        }
        Distance distance = ProtoUtil.deserialize(proto.getDistance());
        return new KMeansModel(carrier.name(), carrier.provenance(), featureDomain, outputDomain, centroids, distance);
    }

    /**
     * Returns deep copies of the centroid vectors, so callers cannot mutate the model.
     *
     * @return A fresh array of copied centroids, indexed by cluster id.
     */
    public DenseVector[] getCentroidVectors() {
        return copyCentroids(centroidVectors);
    }

    /**
     * Returns the centroids as lists of named features, one list per cluster.
     * Feature names are looked up from the feature map by vector index.
     *
     * @return The centroids as feature lists, indexed by cluster id.
     */
    public List<List<Feature>> getCentroids() {
        List<List<Feature>> centroids = new ArrayList<>(centroidVectors.length);
        for (DenseVector centroid : centroidVectors) {
            List<Feature> features = new ArrayList<>(featureIDMap.size());
            for (VectorTuple tuple : centroid) {
                features.add(new Feature(featureIDMap.get(tuple.index).getName(), tuple.value));
            }
            centroids.add(features);
        }
        return centroids;
    }

    /**
     * Assigns the example to the cluster whose centroid is closest under {@link #dist}.
     *
     * @param example The example to cluster.
     * @return A prediction holding the closest cluster id.
     * @throws IllegalArgumentException If none of the example's features are in the feature map.
     */
    @Override
    public Prediction<ClusterID> predict(Example<ClusterID> example) {
        // Dense and sparse vectors share the SGDVector supertype. A dense vector is used only
        // when the example's feature count matches the feature map size.
        SGDVector vector;
        if (example.size() == featureIDMap.size()) {
            vector = DenseVector.createDenseVector(example, featureIDMap, false);
        } else {
            vector = SparseVector.createSparseVector(example, featureIDMap, false);
        }
        int numUsed = vector.numActiveElements();
        if (numUsed == 0) {
            throw new IllegalArgumentException("No features found in Example " + example.toString());
        }
        double minDistance = Double.POSITIVE_INFINITY;
        // Stays -1 only if no centroid distance is below +infinity (e.g. every distance is NaN).
        int closest = -1;
        for (int i = 0; i < centroidVectors.length; i++) {
            double distance = dist.computeDistance(centroidVectors[i], vector);
            if (distance < minDistance) {
                minDistance = distance;
                closest = i;
            }
        }
        // Report the number of active features used, not the feature-space dimension.
        return new Prediction<>(new ClusterID(closest), numUsed, example);
    }

    /**
     * K-Means exposes no feature importances; always returns an empty map.
     *
     * @param n Ignored.
     * @return An empty map.
     */
    @Override
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        return Collections.emptyMap();
    }

    /**
     * K-Means does not produce excuses; always returns an empty optional.
     *
     * @param example Ignored.
     * @return {@link Optional#empty()}.
     */
    @Override
    public Optional<Excuse<ClusterID>> getExcuse(Example<ClusterID> example) {
        return Optional.empty();
    }

    /**
     * Serializes this model into a {@link ModelProto} wrapping a {@link KMeansModelProto}.
     * <p>
     * NOTE(review): JADX renamed this from {@code serialize} (merged with a bridge method).
     *
     * @return The serialized model.
     */
    public ModelProto m1serialize() {
        KMeansModelProto.Builder kmeansBuilder = KMeansModelProto.newBuilder();
        kmeansBuilder.setMetadata(createDataCarrier().serialize());
        kmeansBuilder.setDistance((DistanceProto) dist.serialize());
        for (DenseVector centroid : centroidVectors) {
            kmeansBuilder.addCentroidVectors(centroid.serialize());
        }
        ModelProto.Builder modelBuilder = ModelProto.newBuilder();
        modelBuilder.setSerializedData(Any.pack(kmeansBuilder.m52build()));
        modelBuilder.setClassName(KMeansModel.class.getName());
        modelBuilder.setVersion(CURRENT_VERSION);
        return modelBuilder.build();
    }

    /**
     * Copies this model under a new name and provenance. Centroids are deep-copied; the
     * distance function instance is shared with this model.
     * <p>
     * NOTE(review): JADX reports the original was the protected {@code copy} method,
     * merged with a bridge method.
     *
     * @param newName       The name of the copy.
     * @param newProvenance The provenance of the copy.
     * @return The copied model.
     */
    public KMeansModel m0copy(String newName, ModelProvenance newProvenance) {
        return new KMeansModel(newName, newProvenance, featureIDMap, outputIDInfo, copyCentroids(centroidVectors), dist);
    }

    /** Deep-copies each centroid into a new array of the same length. */
    private static DenseVector[] copyCentroids(DenseVector[] source) {
        DenseVector[] copies = new DenseVector[source.length];
        for (int i = 0; i < source.length; i++) {
            copies[i] = source[i].copy();
        }
        return copies;
    }

    /**
     * Restores the distance function after Java deserialization. When {@link #dist} is absent,
     * it is rebuilt from the deprecated {@link #distanceType}. Presumably only older serialized
     * forms lack {@code dist} — confirm.
     *
     * @throws InvalidObjectException If neither a distance nor a legacy distance type is present.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        if (dist == null) {
            if (distanceType == null) {
                throw new InvalidObjectException("Invalid serialized KMeansModel, no distance function was found");
            }
            dist = distanceType.getDistanceType().getDistance();
        }
    }
}
