package org.tribuo.math.neighbour.bruteforce;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import org.tribuo.math.distance.Distance;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.neighbour.NeighboursQuery;

/* loaded from: input_file:org/tribuo/math/neighbour/bruteforce/NeighboursBruteForce.class */
/**
 * A brute-force nearest neighbour query over a fixed array of {@link SGDVector}s.
 * <p>
 * Each query computes the distance from the query point to every stored vector. It keeps the
 * {@code k} smallest distances in a bounded max-heap. Batch queries may be spread over a
 * fixed-size thread pool when {@code numThreads > 1}.
 * <p>
 * The data array is stored by reference, not copied.
 */
public final class NeighboursBruteForce implements NeighboursQuery {
    private final SGDVector[] data;
    private final Distance distance;
    private final int numThreads;

    /**
     * A mutable (index, distance) pair held in the bounded heap used by
     * {@link NeighboursBruteForce#query(SGDVector, int)}.
     * <p>
     * The natural ordering is <em>descending</em> by distance. A {@link PriorityQueue} of these
     * pairs is therefore a max-heap whose head is the farthest current candidate. This ordering
     * is not consistent with {@code equals}.
     */
    public static final class MutablePair implements Comparable<MutablePair> {
        int index;
        double value;

        /**
         * @param index the index of the data vector.
         * @param value the distance from the query point to that vector.
         */
        public MutablePair(int index, double value) {
            this.index = index;
            this.value = value;
        }

        @Override
        public int compareTo(MutablePair other) {
            // Reversed on purpose: larger distances sort first, giving a max-heap.
            return Double.compare(other.value, this.value);
        }
    }

    /**
     * Runs a single-point query and stores the result into a shared results array at a fixed slot.
     * Each task writes only its own slot.
     */
    public final class SingleQueryRunnable implements Runnable {
        private final SGDVector point;
        private final int k;
        private final int index;
        final List<Pair<Integer, Double>>[] indexDistancePairListArray;

        /**
         * @param index   the slot in {@code results} to write.
         * @param point   the query point.
         * @param k       the number of neighbours to return.
         * @param results the shared results array.
         */
        SingleQueryRunnable(int index, SGDVector point, int k, List<Pair<Integer, Double>>[] results) {
            this.point = point;
            this.k = k;
            this.index = index;
            this.indexDistancePairListArray = results;
        }

        @Override
        public void run() {
            indexDistancePairListArray[index] = NeighboursBruteForce.this.query(point, k);
        }
    }

    /**
     * Constructs a brute-force neighbour query over the supplied vectors.
     *
     * @param data       the vectors to search; must be non-empty and all the same size.
     * @param distance   the distance function used to compare vectors.
     * @param numThreads the number of threads used for batch queries; must be at least 1.
     * @throws IllegalArgumentException if {@code data} is empty, {@code numThreads < 1},
     *                                  or the vectors differ in size.
     */
    public NeighboursBruteForce(SGDVector[] data, Distance distance, int numThreads) {
        if (data.length == 0) {
            throw new IllegalArgumentException("At least one SGDVector is required.");
        }
        if (numThreads < 1) {
            throw new IllegalArgumentException("numThreads must be at least 1, found " + numThreads);
        }
        int size = data[0].size();
        for (SGDVector vector : data) {
            if (vector.size() != size) {
                throw new IllegalArgumentException("All the SGDVectors must be the same size.");
            }
        }
        this.data = data;
        this.distance = distance;
        this.numThreads = numThreads;
    }

    /**
     * Finds the {@code k} stored vectors closest to {@code point}.
     *
     * @param point the query point.
     * @param k     the number of neighbours requested; must be at least 1.
     * @return (index, distance) pairs in ascending distance order. The list holds
     *         {@code min(k, data.length)} entries. On equal distances the earlier index is kept.
     * @throws IllegalArgumentException if {@code k < 1}.
     */
    @Override
    public List<Pair<Integer, Double>> query(SGDVector point, int k) {
        checkK(k);
        // Never allocate (or return) more slots than there are data points.
        int numNeighbours = Math.min(k, data.length);
        PriorityQueue<MutablePair> heap = new PriorityQueue<>(numNeighbours);
        for (int i = 0; i < numNeighbours; i++) {
            heap.offer(new MutablePair(i, distance.computeDistance(point, data[i])));
        }
        for (int i = numNeighbours; i < data.length; i++) {
            double dist = distance.computeDistance(point, data[i]);
            // Strict comparison keeps the earlier index on ties.
            if (Double.compare(dist, heap.peek().value) < 0) {
                // Recycle the evicted pair to avoid an allocation per replacement.
                MutablePair evicted = heap.poll();
                evicted.index = i;
                evicted.value = dist;
                heap.offer(evicted);
            }
        }
        // The heap drains farthest-first, so fill from the back to get ascending order.
        @SuppressWarnings("unchecked")
        Pair<Integer, Double>[] neighbours = (Pair<Integer, Double>[]) new Pair[heap.size()];
        for (int pos = neighbours.length - 1; pos >= 0; pos--) {
            MutablePair next = heap.poll();
            neighbours[pos] = new Pair<>(next.index, next.value);
        }
        return Arrays.asList(neighbours);
    }

    /**
     * Finds the {@code k} nearest stored vectors for each of the supplied points.
     *
     * @param points the query points.
     * @param k      the number of neighbours requested per point; must be at least 1.
     * @return one neighbour list per query point, in the same order as {@code points}.
     * @throws IllegalArgumentException if {@code k < 1}.
     * @throws RuntimeException         if a parallel query fails or the calling thread is
     *                                  interrupted. The thread's interrupt flag is restored.
     */
    @Override
    public List<List<Pair<Integer, Double>>> query(SGDVector[] points, int k) {
        checkK(k);
        @SuppressWarnings("unchecked")
        List<Pair<Integer, Double>>[] results = (List<Pair<Integer, Double>>[]) new List[points.length];
        if (numThreads == 1) {
            for (int i = 0; i < points.length; i++) {
                results[i] = query(points[i], k);
            }
        } else {
            runInParallel(points, k, results);
        }
        return new ArrayList<>(Arrays.asList(results));
    }

    /**
     * Finds the {@code k} nearest stored vectors for every stored vector.
     *
     * @param k the number of neighbours requested per point; must be at least 1.
     * @return one neighbour list per stored vector.
     */
    @Override
    public List<List<Pair<Integer, Double>>> queryAll(int k) {
        return query(data, k);
    }

    /** Rejects non-positive neighbour counts with a descriptive message. */
    private static void checkK(int k) {
        if (k < 1) {
            throw new IllegalArgumentException("k must be at least 1, found " + k);
        }
    }

    /**
     * Runs one {@link SingleQueryRunnable} per point on a fixed-size pool and waits for all of them.
     * The first task failure is propagated.
     */
    private void runInParallel(SGDVector[] points, int k, List<Pair<Integer, Double>>[] results) {
        ExecutorService pool = Executors.newFixedThreadPool(numThreads);
        try {
            List<Future<?>> futures = new ArrayList<>(points.length);
            for (int i = 0; i < points.length; i++) {
                futures.add(pool.submit(new SingleQueryRunnable(i, points[i], k, results)));
            }
            // Future.get() surfaces task exceptions. It also makes each task's write into
            // results visible to this thread.
            for (Future<?> future : futures) {
                future.get();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("Parallel execution failed", e);
        } catch (ExecutionException e) {
            throw new RuntimeException("Parallel execution failed", e.getCause());
        } finally {
            // No-op on success (all tasks are done); on failure, cancels outstanding tasks.
            pool.shutdownNow();
        }
    }
}
