package org.tribuo.common.nearest;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.util.Pair;
import com.oracle.labs.mlrg.olcut.util.StreamUtil;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinWorkerThread;
import java.util.concurrent.Future;
import java.util.function.Function;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.common.nearest.KNNTrainer;
import org.tribuo.common.nearest.protos.KNNModelProto;
import org.tribuo.ensemble.EnsembleCombiner;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.math.distance.Distance;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.neighbour.NeighboursQuery;
import org.tribuo.math.neighbour.NeighboursQueryFactory;
import org.tribuo.math.neighbour.bruteforce.NeighboursBruteForceFactory;
import org.tribuo.math.protos.DistanceProto;
import org.tribuo.math.protos.NeighbourFactoryProto;
import org.tribuo.math.protos.TensorProto;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.protos.core.EnsembleCombinerProto;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.protos.core.OutputProto;
import org.tribuo.provenance.ModelProvenance;

/* loaded from: input_file:org/tribuo/common/nearest/KNNModel.class */
public class KNNModel<T extends Output<T>> extends Model<T> {
    /** Java serialization version identifier. */
    private static final long serialVersionUID = 1;
    /** Protobuf serialization version of this class. */
    public static final int CURRENT_VERSION = 0;
    /** The stored (vector, output) pairs that are searched for neighbours at prediction time. */
    private final Pair<SGDVector, T>[] vectors;
    /** The number of neighbours requested for each prediction. */
    private final int k;

    /**
     * Legacy distance setting. It is only read in {@code readObject}, to populate {@link #dist}
     * when that field is null after default deserialization.
     */
    @Deprecated
    private KNNTrainer.Distance distance;
    /** The distance function used by the brute-force prediction paths. */
    private Distance dist;
    /** Thread count; values greater than one select the parallel prediction paths. */
    private final int numThreads;
    /** Selects which multithreaded strategy batch prediction uses. */
    private final Backend parallelBackend;
    /** Combines the neighbours' predictions into the single returned prediction. */
    private final EnsembleCombiner<T> combiner;
    /** Factory used to build {@link #neighboursQuery} from the stored vectors. */
    private NeighboursQueryFactory neighboursQueryFactory;
    /** Neighbour query structure; transient, so it is rebuilt in the constructor and in {@code readObject}. */
    private transient NeighboursQuery neighboursQuery;
    private static final Logger logger = Logger.getLogger(KNNModel.class.getName());
    /** Worker thread factory passed to the fork-join pools created when a SecurityManager is installed. */
    private static final CustomForkJoinWorkerThreadFactory THREAD_FACTORY = new CustomForkJoinWorkerThreadFactory(null);

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: org.tribuo.common.nearest.KNNModel$1, reason: invalid class name */
    /* loaded from: input_file:org/tribuo/common/nearest/KNNModel$1.class */
    /**
     * Compiler-generated lookup table, recovered by the decompiler, that maps each
     * {@link Backend} ordinal to a 1-based case index. It is read by the switch in
     * {@code innerPredictMultithreaded}.
     */
    public static /* synthetic */ class AnonymousClass1 {
        // Indexed by Backend.ordinal(); an entry left at 0 matches none of the switch's case labels.
        static final /* synthetic */ int[] $SwitchMap$org$tribuo$common$nearest$KNNModel$Backend = new int[Backend.values().length];

        static {
            // Each assignment is guarded separately, so a failure on one constant
            // does not prevent the remaining constants from being mapped.
            try {
                $SwitchMap$org$tribuo$common$nearest$KNNModel$Backend[Backend.STREAMS.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
                // Deliberately ignored: the entry simply stays at 0.
            }
            try {
                $SwitchMap$org$tribuo$common$nearest$KNNModel$Backend[Backend.THREADPOOL.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
                // Deliberately ignored: the entry simply stays at 0.
            }
            try {
                $SwitchMap$org$tribuo$common$nearest$KNNModel$Backend[Backend.INNERTHREADPOOL.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
                // Deliberately ignored: the entry simply stays at 0.
            }
        }
    }

    /* loaded from: input_file:org/tribuo/common/nearest/KNNModel$Backend.class */
    /**
     * The strategies available for multithreaded batch prediction.
     * The constant names are written to and read from the protobuf form via
     * {@code name()} and {@code valueOf}.
     */
    public enum Backend {
        /** Scores each example with a parallel stream run inside a fork-join pool. */
        STREAMS,
        /** Submits one task per example to a fixed-size thread pool. */
        THREADPOOL,
        /** Splits the stored vectors of each example's search into chunks run on a fixed-size thread pool. */
        INNERTHREADPOOL
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/tribuo/common/nearest/KNNModel$CustomForkJoinWorkerThreadFactory.class */
    /**
     * Fork-join worker thread factory that constructs each worker inside a
     * {@link AccessController#doPrivileged} block. The model passes it to the
     * {@link ForkJoinPool} constructor when a SecurityManager is installed.
     */
    public static final class CustomForkJoinWorkerThreadFactory implements ForkJoinPool.ForkJoinWorkerThreadFactory {
        private CustomForkJoinWorkerThreadFactory() {
        }

        /**
         * Creates a new worker thread for the supplied pool.
         *
         * @param forkJoinPool the pool the worker will belong to
         * @return the newly constructed worker thread
         */
        @Override // java.util.concurrent.ForkJoinPool.ForkJoinWorkerThreadFactory
        public final ForkJoinWorkerThread newThread(ForkJoinPool forkJoinPool) {
            // The explicit PrivilegedAction target type is required: a bare lambda is ambiguous
            // between doPrivileged(PrivilegedAction) and doPrivileged(PrivilegedExceptionAction).
            PrivilegedAction<ForkJoinWorkerThread> createWorker = () -> new ForkJoinWorkerThread(forkJoinPool) { // from class: org.tribuo.common.nearest.KNNModel.CustomForkJoinWorkerThreadFactory.1
            };
            return AccessController.doPrivileged(createWorker);
        }

        /* synthetic */ CustomForkJoinWorkerThreadFactory(AnonymousClass1 anonymousClass1) {
            this();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/tribuo/common/nearest/KNNModel$OutputDoublePair.class */
    /**
     * Mutable pairing of an output with a distance value. The natural ordering
     * is ascending by {@code value}; the output plays no part in the ordering.
     *
     * @param <T> the output type
     */
    public static final class OutputDoublePair<T extends Output<T>> implements Comparable<OutputDoublePair<T>> {
        // Both fields are deliberately non-final: innerPredictChunk recycles instances in place.
        T output;
        double value;

        /**
         * Creates a pair.
         *
         * @param t the output
         * @param d the value associated with the output
         */
        public OutputDoublePair(T t, double d) {
            this.value = d;
            this.output = t;
        }

        /** Two pairs are equal when they have the same class, the same value and equal outputs. */
        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (obj == null) {
                return false;
            }
            if (obj.getClass() != getClass()) {
                return false;
            }
            OutputDoublePair<?> that = (OutputDoublePair<?>) obj;
            if (Double.compare(that.value, this.value) != 0) {
                return false;
            }
            return this.output.equals(that.output);
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.output, this.value);
        }

        /** Orders pairs by ascending value. */
        @Override // java.lang.Comparable
        public int compareTo(OutputDoublePair<T> outputDoublePair) {
            final double mine = this.value;
            final double theirs = outputDoublePair.value;
            return Double.compare(mine, theirs);
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public KNNModel(String str, ModelProvenance modelProvenance, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo<T> immutableOutputInfo, boolean z, int i, Distance distance, int i2, EnsembleCombiner<T> ensembleCombiner, Pair<SGDVector, T>[] pairArr, Backend backend, NeighboursQueryFactory neighboursQueryFactory) {
        super(str, modelProvenance, immutableFeatureMap, immutableOutputInfo, z);
        this.k = i;
        this.dist = distance;
        this.numThreads = i2;
        this.combiner = ensembleCombiner;
        this.parallelBackend = backend;
        this.vectors = pairArr;
        this.neighboursQueryFactory = neighboursQueryFactory;
        this.neighboursQuery = neighboursQueryFactory.createNeighboursQuery(getSGDVectorArr());
    }

    /**
     * Rebuilds a model from its protobuf form, validating the payload as it goes.
     *
     * @param i the serialized version; only version {@link #CURRENT_VERSION} is accepted
     * @param str the class name (not used by this method)
     * @param any the packed {@link KNNModelProto}
     * @return the deserialized model
     * @throws InvalidProtocolBufferException if the payload cannot be unpacked as a {@link KNNModelProto}
     * @throws IllegalArgumentException if the version is not supported
     * @throws IllegalStateException if the payload fails any of the consistency checks
     */
    @SuppressWarnings({"rawtypes", "unchecked"}) // element types are verified at runtime by the getClass() checks below
    public static KNNModel<?> deserializeFromProto(int i, String str, Any any) throws InvalidProtocolBufferException {
        if (i < 0 || i > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + i + ", this class supports at most version " + CURRENT_VERSION);
        }
        KNNModelProto proto = any.unpack(KNNModelProto.class);
        ModelDataCarrier carrier = ModelDataCarrier.deserialize(proto.getMetadata());
        ImmutableFeatureMap featureDomain = carrier.featureDomain();
        ImmutableOutputInfo outputDomain = carrier.outputDomain();
        Class<?> outputClass = outputDomain.getOutput(0).getClass();
        EnsembleCombiner combiner = EnsembleCombiner.deserialize(proto.getCombiner());
        if (!outputClass.equals(combiner.getTypeWitness())) {
            throw new IllegalStateException("Invalid protobuf, combiner and output domain have a type mismatch, expected " + outputClass + " found " + combiner.getTypeWitness());
        }
        int k = proto.getK();
        if (k < 1) {
            throw new IllegalStateException("Invalid protobuf, k must be positive, found " + k);
        }
        int numThreads = proto.getNumThreads();
        // NOTE(review): the check admits zero although the message says "positive" - confirm which is intended.
        if (numThreads < 0) {
            throw new IllegalStateException("Invalid protobuf, numThreads must be positive, found " + numThreads);
        }
        if (proto.getVectorsCount() == 0) {
            throw new IllegalStateException("Invalid protobuf, no vectors were found");
        }
        if (proto.getVectorsCount() != proto.getOutputsCount()) {
            throw new IllegalStateException("Invalid protobuf, different numbers of outputs and vectors were found, " + proto.getVectorsCount() + " vectors, " + proto.getOutputsCount() + " outputs");
        }
        Pair[] pairs = new Pair[proto.getVectorsCount()];
        List<TensorProto> vectorsList = proto.getVectorsList();
        List<OutputProto> outputsList = proto.getOutputsList();
        for (int idx = 0; idx < pairs.length; idx++) {
            // Keep the general Tensor type here so that the instanceof guard below is a real check.
            Tensor tensor = Tensor.deserialize(vectorsList.get(idx));
            Output output = Output.deserialize(outputsList.get(idx));
            if (!(tensor instanceof SGDVector)) {
                throw new IllegalStateException("Invalid protobuf, expected centroid to be a vector, found " + tensor.getClass());
            }
            SGDVector vector = (SGDVector) tensor;
            if (vector.size() != featureDomain.size()) {
                throw new IllegalStateException("Invalid protobuf, vector did not contain all the features, found " + vector.size() + " expected " + featureDomain.size());
            }
            if (!output.getClass().equals(outputClass)) {
                throw new IllegalStateException("Invalid protobuf, output type did not match, found " + output.getClass() + " expected " + outputClass);
            }
            pairs[idx] = new Pair(vector, output);
        }
        Distance distance = ProtoUtil.deserialize(proto.getDistance());
        Backend backend = Backend.valueOf(proto.getParallelBackend());
        NeighboursQueryFactory queryFactory = NeighboursQueryFactory.deserialize(proto.getNeighboursQueryFactory());
        return new KNNModel(carrier.name(), carrier.provenance(), featureDomain, outputDomain, carrier.generatesProbabilities(), k, distance, numThreads, combiner, pairs, backend, queryFactory);
    }

    /**
     * Predicts a single example by computing the distance to every stored vector,
     * keeping the {@code k} closest and combining their outputs.
     * <p>
     * When more than one thread is configured the scan runs as a parallel stream inside a
     * dedicated fork-join pool, which is shut down before this method returns.
     *
     * @param example the example to predict
     * @return the combined prediction
     * @throws IllegalArgumentException if the example has no active features in this model's feature domain
     * @throws IllegalStateException if the parallel computation fails or is interrupted
     */
    public Prediction<T> predict(Example<T> example) {
        // The input may be dense or sparse, so it is held through the common SGDVector type.
        final SGDVector input;
        if (example.size() == this.featureIDMap.size()) {
            input = DenseVector.createDenseVector(example, this.featureIDMap, false);
        } else {
            input = SparseVector.createSparseVector(example, this.featureIDMap, false);
        }
        if (input.numActiveElements() == 0) {
            throw new IllegalArgumentException("No features found in Example " + example);
        }
        Function<Pair<SGDVector, T>, OutputDoublePair<T>> distanceFunc = pair -> {
            return new OutputDoublePair<>(pair.getB(), this.dist.computeDistance(pair.getA(), input));
        };
        Function<OutputDoublePair<T>, Prediction<T>> toPrediction = neighbour -> {
            return new Prediction<>(neighbour.output, input.numActiveElements(), example);
        };
        Stream<Pair<SGDVector, T>> stream = Stream.of(this.vectors);
        List<Prediction<T>> neighbours;
        if (this.numThreads > 1) {
            ForkJoinPool pool = System.getSecurityManager() == null ? new ForkJoinPool(this.numThreads) : new ForkJoinPool(this.numThreads, THREAD_FACTORY, null, false);
            try {
                neighbours = pool.submit(() -> {
                    return StreamUtil.boundParallelism(stream.parallel()).map(distanceFunc).sorted().limit(this.k).map(toPrediction).collect(Collectors.toList());
                }).get();
            } catch (InterruptedException e) {
                // Restore the interrupt flag so that callers can still observe the interruption.
                Thread.currentThread().interrupt();
                logger.log(Level.SEVERE, "Exception when predicting in KNNModel", e);
                throw new IllegalStateException("Failed to process example in parallel", e);
            } catch (ExecutionException e) {
                logger.log(Level.SEVERE, "Exception when predicting in KNNModel", e);
                throw new IllegalStateException("Failed to process example in parallel", e);
            } finally {
                // The pool is created per call, so it must be released here or its workers leak.
                pool.shutdown();
            }
        } else {
            neighbours = stream.map(distanceFunc).sorted().limit(this.k).map(toPrediction).collect(Collectors.toList());
        }
        return this.combiner.combine(this.outputIDInfo, neighbours);
    }

    /**
     * Predicts a batch of examples. With more than one thread configured the work is handed
     * to {@code innerPredictMultithreaded}; otherwise each example is answered in turn through
     * the neighbour query structure.
     *
     * @param iterable the examples to predict
     * @return one combined prediction per example, in iteration order
     */
    protected List<Prediction<T>> innerPredict(Iterable<Example<T>> iterable) {
        if (this.numThreads > 1) {
            return innerPredictMultithreaded(iterable);
        }
        List<Prediction<T>> predictions = new ArrayList<>();
        // Scratch list reused across examples; it is cleared at the start of every iteration.
        List<Prediction<T>> neighbourPredictions = new ArrayList<>();
        for (Example<T> example : iterable) {
            neighbourPredictions.clear();
            // The input may be dense or sparse, so it is held through the common SGDVector type.
            final SGDVector input;
            if (example.size() == this.featureIDMap.size()) {
                input = DenseVector.createDenseVector(example, this.featureIDMap, false);
            } else {
                input = SparseVector.createSparseVector(example, this.featureIDMap, false);
            }
            // Each query result's first element is the index of a stored vector.
            for (Pair<Integer, Double> neighbour : this.neighboursQuery.query(input, this.k)) {
                T neighbourOutput = this.vectors[neighbour.getA()].getB();
                neighbourPredictions.add(new Prediction<>(neighbourOutput, input.numActiveElements(), example));
            }
            predictions.add(this.combiner.combine(this.outputIDInfo, neighbourPredictions));
        }
        return predictions;
    }

    /**
     * Dispatches batch prediction to the implementation selected by the configured backend.
     *
     * @param iterable the examples to predict
     * @return one combined prediction per example
     * @throws IllegalArgumentException if the backend is not one of the known constants
     */
    private List<Prediction<T>> innerPredictMultithreaded(Iterable<Example<T>> iterable) {
        // Switch on the enum itself rather than on ordinal-derived integers labelled
        // with unrelated protobuf field-number constants.
        switch (this.parallelBackend) {
            case STREAMS:
                logger.log(Level.FINE, "Parallel backend - streams");
                return innerPredictStreams(iterable);
            case THREADPOOL:
                logger.log(Level.FINE, "Parallel backend - threadpool");
                return innerPredictThreadPool(iterable);
            case INNERTHREADPOOL:
                logger.log(Level.FINE, "Parallel backend - within example threadpool");
                return innerPredictWithinExampleThreadPool(iterable);
            default:
                throw new IllegalArgumentException("Unknown backend " + this.parallelBackend);
        }
    }

    /**
     * Batch prediction where every example is scored with a parallel stream over the stored
     * vectors, all running inside one fork-join pool that is shut down before returning.
     *
     * @param iterable the examples to predict
     * @return one combined prediction per example, in iteration order
     * @throws IllegalStateException if the parallel computation for an example fails or is interrupted
     */
    private List<Prediction<T>> innerPredictStreams(Iterable<Example<T>> iterable) {
        List<Prediction<T>> predictions = new ArrayList<>();
        ForkJoinPool pool = System.getSecurityManager() == null ? new ForkJoinPool(this.numThreads) : new ForkJoinPool(this.numThreads, THREAD_FACTORY, null, false);
        try {
            for (Example<T> example : iterable) {
                // The input may be dense or sparse, so it is held through the common SGDVector type.
                final SGDVector input;
                if (example.size() == this.featureIDMap.size()) {
                    input = DenseVector.createDenseVector(example, this.featureIDMap, false);
                } else {
                    input = SparseVector.createSparseVector(example, this.featureIDMap, false);
                }
                Function<Pair<SGDVector, T>, OutputDoublePair<T>> distanceFunc = pair -> {
                    return new OutputDoublePair<>(pair.getB(), this.dist.computeDistance(pair.getA(), input));
                };
                Function<OutputDoublePair<T>, Prediction<T>> toPrediction = neighbour -> {
                    return new Prediction<>(neighbour.output, input.numActiveElements(), example);
                };
                Stream<Pair<SGDVector, T>> stream = Stream.of(this.vectors);
                List<Prediction<T>> neighbours;
                try {
                    neighbours = pool.submit(() -> {
                        return StreamUtil.boundParallelism(stream.parallel()).map(distanceFunc).sorted().limit(this.k).map(toPrediction).collect(Collectors.toList());
                    }).get();
                } catch (InterruptedException e) {
                    // Restore the interrupt flag so that callers can still observe the interruption.
                    Thread.currentThread().interrupt();
                    logger.log(Level.SEVERE, "Exception when predicting in KNNModel", e);
                    throw new IllegalStateException("Failed to process example in parallel", e);
                } catch (ExecutionException e) {
                    // Fail loudly, as predict() does; carrying on would combine the previous
                    // example's neighbours (or null) for this example.
                    logger.log(Level.SEVERE, "Exception when predicting in KNNModel", e);
                    throw new IllegalStateException("Failed to process example in parallel", e);
                }
                predictions.add(this.combiner.combine(this.outputIDInfo, neighbours));
            }
        } finally {
            // The pool is created per call, so it must be released here or its workers leak.
            pool.shutdown();
        }
        return predictions;
    }

    /**
     * Batch prediction that submits one task per example to a fixed-size thread pool and
     * collects the results in submission order. The pool is shut down on every exit path.
     *
     * @param iterable the examples to predict
     * @return one combined prediction per example, in iteration order
     * @throws IllegalStateException if a task fails or the calling thread is interrupted while waiting
     */
    private List<Prediction<T>> innerPredictThreadPool(Iterable<Example<T>> iterable) {
        ExecutorService pool = Executors.newFixedThreadPool(this.numThreads);
        try {
            List<Future<Prediction<T>>> futures = new ArrayList<>();
            for (Example<T> example : iterable) {
                futures.add(pool.submit(() -> {
                    return innerPredictOne(this.neighboursQuery, this.vectors, this.combiner, this.featureIDMap, this.outputIDInfo, this.k, example);
                }));
            }
            List<Prediction<T>> predictions = new ArrayList<>();
            for (Future<Prediction<T>> future : futures) {
                predictions.add(future.get());
            }
            return predictions;
        } catch (InterruptedException e) {
            // Restore the interrupt flag so that callers can still observe the interruption.
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Thread pool went bang", e);
        } catch (ExecutionException e) {
            throw new IllegalStateException("Thread pool went bang", e);
        } finally {
            // Release the worker threads on the failure path as well as on success.
            pool.shutdown();
        }
    }

    /**
     * Batch prediction that handles the examples one after another, parallelising the search
     * inside each example over a shared fixed-size thread pool. The pool is shut down on every
     * exit path.
     *
     * @param iterable the examples to predict
     * @return one combined prediction per example, in iteration order
     */
    private List<Prediction<T>> innerPredictWithinExampleThreadPool(Iterable<Example<T>> iterable) {
        List<Prediction<T>> predictions = new ArrayList<>();
        ExecutorService pool = Executors.newFixedThreadPool(this.numThreads);
        // One reusable queue per worker thread, ordered so that the largest value is at the head.
        ThreadLocal<PriorityQueue<OutputDoublePair<T>>> queuePool = ThreadLocal.withInitial(() -> {
            return new PriorityQueue<OutputDoublePair<T>>(this.k, (left, right) -> {
                return Double.compare(right.value, left.value);
            });
        });
        try {
            for (Example<T> example : iterable) {
                predictions.add(innerPredictThreadPool(pool, queuePool, this.dist, example));
            }
        } finally {
            // Release the worker threads even when a per-example prediction throws.
            pool.shutdown();
        }
        return predictions;
    }

    /**
     * Predicts one example by splitting the stored vectors into contiguous chunks, finding the
     * closest candidates of each chunk on the supplied executor, and merging those candidates
     * into the overall {@code k} closest.
     *
     * @param executorService the executor the chunk searches run on
     * @param threadLocal supplies the per-thread queue used by each chunk search
     * @param distance the distance function
     * @param example the example to predict
     * @return the combined prediction
     * @throws IllegalStateException if a chunk search fails or the calling thread is interrupted while waiting
     */
    private Prediction<T> innerPredictThreadPool(ExecutorService executorService, ThreadLocal<PriorityQueue<OutputDoublePair<T>>> threadLocal, Distance distance, Example<T> example) {
        SparseVector input = SparseVector.createSparseVector(example, this.featureIDMap, false);
        List<Future<List<OutputDoublePair<T>>>> futures = new ArrayList<>();
        final int total = this.vectors.length;
        for (int i = 0; i < this.numThreads; i++) {
            // Proportional bounds: the chunks tile [0, total) exactly, so the trailing
            // total % numThreads vectors are searched too (floor-sized chunks dropped them).
            // The long arithmetic guards against overflow in the multiplication.
            final int start = (int) ((long) i * total / this.numThreads);
            final int end = (int) ((long) (i + 1) * total / this.numThreads);
            if (start == end) {
                // Fewer vectors than threads: nothing to search in this chunk.
                continue;
            }
            futures.add(executorService.submit(() -> {
                return innerPredictChunk(threadLocal, this.vectors, start, end, distance, this.k, input);
            }));
        }
        // Largest value at the head, so the head is the candidate to evict.
        PriorityQueue<OutputDoublePair<T>> closest = new PriorityQueue<OutputDoublePair<T>>(this.k, (left, right) -> {
            return Double.compare(right.value, left.value);
        });
        try {
            for (Future<List<OutputDoublePair<T>>> future : futures) {
                for (OutputDoublePair<T> candidate : future.get()) {
                    if (closest.size() < this.k) {
                        closest.offer(candidate);
                    } else if (Double.compare(candidate.value, closest.peek().value) < 0) {
                        closest.poll();
                        closest.offer(candidate);
                    }
                }
            }
        } catch (InterruptedException e) {
            // Restore the interrupt flag so that callers can still observe the interruption.
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Thread pool went bang", e);
        } catch (ExecutionException e) {
            throw new IllegalStateException("Thread pool went bang", e);
        }
        List<Prediction<T>> neighbourPredictions = new ArrayList<>();
        for (OutputDoublePair<T> neighbour : closest) {
            neighbourPredictions.add(new Prediction<>(neighbour.output, input.numActiveElements(), example));
        }
        return this.combiner.combine(this.outputIDInfo, neighbourPredictions);
    }

    /**
     * Scans {@code pairArr[i .. min(i2, pairArr.length))} and gathers up to {@code i3}
     * entries with the smallest distance to {@code sGDVector}.
     * <p>
     * The calling thread's queue is cleared first and used as the working set; the
     * returned list is a separate copy of its contents.
     *
     * @param threadLocal supplies the calling thread's reusable queue
     * @param pairArr the stored (vector, output) pairs
     * @param i the first index to scan (inclusive)
     * @param i2 the last index to scan (exclusive); clamped to the array length
     * @param distance the distance function
     * @param i3 the maximum number of entries to keep
     * @param sGDVector the query vector
     * @param <T> the output type
     * @return the entries kept from this chunk, in the queue's internal order
     */
    private static <T extends Output<T>> List<OutputDoublePair<T>> innerPredictChunk(ThreadLocal<PriorityQueue<OutputDoublePair<T>>> threadLocal, Pair<SGDVector, T>[] pairArr, int i, int i2, Distance distance, int i3, SGDVector sGDVector) {
        final PriorityQueue<OutputDoublePair<T>> heap = threadLocal.get();
        heap.clear();
        final int limit = Math.min(i2, pairArr.length);
        for (int idx = i; idx < limit; idx++) {
            final Pair<SGDVector, T> candidate = pairArr[idx];
            final double candidateDistance = distance.computeDistance(candidate.getA(), sGDVector);
            if (heap.size() >= i3) {
                // Queue is full: replace its head only when this candidate is strictly closer.
                if (Double.compare(candidateDistance, heap.peek().value) < 0) {
                    // Recycle the evicted pair instead of allocating a new one.
                    final OutputDoublePair<T> recycled = heap.poll();
                    recycled.output = candidate.getB();
                    recycled.value = candidateDistance;
                    heap.offer(recycled);
                }
            } else {
                heap.offer(new OutputDoublePair<>(candidate.getB(), candidateDistance));
            }
        }
        return new ArrayList<>(heap);
    }

    /**
     * Predicts a single example through the neighbour query structure. The method is static
     * and receives all of its state as arguments.
     *
     * @param neighboursQuery the neighbour query structure
     * @param pairArr the stored (vector, output) pairs indexed by the query results
     * @param ensembleCombiner combines the neighbours' predictions
     * @param immutableFeatureMap the feature domain
     * @param immutableOutputInfo the output domain
     * @param i the number of neighbours to request
     * @param example the example to predict
     * @param <T> the output type
     * @return the combined prediction
     */
    private static <T extends Output<T>> Prediction<T> innerPredictOne(NeighboursQuery neighboursQuery, Pair<SGDVector, T>[] pairArr, EnsembleCombiner<T> ensembleCombiner, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo<T> immutableOutputInfo, int i, Example<T> example) {
        // The input may be dense or sparse, so it is held through the common SGDVector type.
        final SGDVector input;
        if (example.size() == immutableFeatureMap.size()) {
            input = DenseVector.createDenseVector(example, immutableFeatureMap, false);
        } else {
            input = SparseVector.createSparseVector(example, immutableFeatureMap, false);
        }
        List<Pair<Integer, Double>> neighbours = neighboursQuery.query(input, i);
        List<Prediction<T>> neighbourPredictions = new ArrayList<>();
        // Each query result's first element is the index of a stored vector.
        for (Pair<Integer, Double> neighbour : neighbours) {
            T neighbourOutput = pairArr[neighbour.getA()].getB();
            neighbourPredictions.add(new Prediction<>(neighbourOutput, input.numActiveElements(), example));
        }
        return ensembleCombiner.combine(immutableOutputInfo, neighbourPredictions);
    }

    /**
     * Always returns an empty map; this model does not report top features.
     *
     * @param i the number of features requested (ignored)
     * @return an empty map
     */
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int i) {
        return Collections.emptyMap();
    }

    /**
     * Always returns an empty optional; this model does not produce excuses.
     *
     * @param example the example to explain (ignored)
     * @return an empty optional
     */
    public Optional<Excuse<T>> getExcuse(Example<T> example) {
        return Optional.empty();
    }

    /* renamed from: serialize, reason: merged with bridge method [inline-methods] */
    /**
     * Serializes this model to its protobuf form: the model data is packed into a
     * {@link KNNModelProto}, which is wrapped in a {@link ModelProto} carrying this
     * class's name and serialization version.
     *
     * @return the protobuf representation of this model
     */
    public ModelProto m5serialize() {
        KNNModelProto.Builder modelBuilder = KNNModelProto.newBuilder();

        // Metadata and scalar settings.
        ModelDataCarrier carrier = createDataCarrier();
        modelBuilder.setMetadata(carrier.serialize());
        modelBuilder.setK(this.k);
        modelBuilder.setNumThreads(this.numThreads);
        modelBuilder.setParallelBackend(this.parallelBackend.name());

        // Nested serializable components.
        modelBuilder.setDistance((DistanceProto) this.dist.serialize());
        modelBuilder.setCombiner((EnsembleCombinerProto) this.combiner.serialize());
        modelBuilder.setNeighboursQueryFactory((NeighbourFactoryProto) this.neighboursQueryFactory.serialize());

        // Vectors and outputs are appended in lock step so that their indices line up.
        for (int idx = 0; idx < this.vectors.length; idx++) {
            Pair<SGDVector, T> entry = this.vectors[idx];
            modelBuilder.addVectors((TensorProto) ((SGDVector) entry.getA()).serialize());
            modelBuilder.addOutputs((OutputProto) ((Output) entry.getB()).serialize());
        }

        ModelProto.Builder wrapper = ModelProto.newBuilder();
        wrapper.setVersion(CURRENT_VERSION);
        wrapper.setClassName(KNNModel.class.getName());
        wrapper.setSerializedData(Any.pack(modelBuilder.m50build()));
        return wrapper.build();
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* renamed from: copy, reason: merged with bridge method [inline-methods] */
    /**
     * Creates a copy of this model under a new name and provenance. Every stored vector and
     * output is copied; the remaining settings are passed to the new model as they are.
     *
     * @param str the name for the copy
     * @param modelProvenance the provenance for the copy
     * @return the copied model
     */
    public KNNModel<T> m4copy(String str, ModelProvenance modelProvenance) {
        final int count = this.vectors.length;
        @SuppressWarnings("unchecked") // generic array creation; only Pair<SGDVector, T> elements are stored
        Pair<SGDVector, T>[] copies = (Pair<SGDVector, T>[]) new Pair[count];
        int idx = 0;
        while (idx < count) {
            Pair<SGDVector, T> source = this.vectors[idx];
            SGDVector vectorCopy = source.getA().copy();
            T outputCopy = source.getB().copy();
            copies[idx] = new Pair<>(vectorCopy, outputCopy);
            idx++;
        }
        return new KNNModel<>(str, modelProvenance, this.featureIDMap, this.outputIDInfo, this.generatesProbabilities, this.k, this.dist, this.numThreads, this.combiner, copies, this.parallelBackend, this.neighboursQueryFactory);
    }

    /**
     * Java deserialization hook. After the default field read it fills in any of
     * {@code dist} and {@code neighboursQueryFactory} that came back null, then rebuilds
     * the transient neighbour query structure.
     *
     * @param objectInputStream the stream being read
     * @throws IOException if the default read fails
     * @throws ClassNotFoundException if a serialized class cannot be resolved
     */
    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        objectInputStream.defaultReadObject();

        // Derive the distance function from the legacy field when it was not in the stream.
        final boolean distanceMissing = (this.dist == null);
        if (distanceMissing) {
            this.dist = this.distance.getDistanceType().getDistance();
        }

        // Fall back to a brute-force factory when no factory was in the stream.
        final boolean factoryMissing = (this.neighboursQueryFactory == null);
        if (factoryMissing) {
            this.neighboursQueryFactory = new NeighboursBruteForceFactory(this.dist, this.numThreads);
        }

        final SGDVector[] points = getSGDVectorArr();
        this.neighboursQuery = this.neighboursQueryFactory.createNeighboursQuery(points);
    }

    /**
     * Extracts the vector half of every stored pair, preserving order.
     *
     * @return a new array holding the stored vectors
     */
    private SGDVector[] getSGDVectorArr() {
        final int count = this.vectors.length;
        final SGDVector[] points = new SGDVector[count];
        for (int idx = 0; idx < count; idx++) {
            points[idx] = this.vectors[idx].getA();
        }
        return points;
    }
}
