/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.math.neighbour.bruteforce;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import org.tribuo.math.distance.Distance;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.neighbour.NeighboursQuery;

/**
 * A brute-force k-nearest-neighbour query: every query point is compared against every
 * stored vector using the supplied {@link Distance}.
 * <p>
 * The batch query methods run on the calling thread when {@code numThreads == 1}, and
 * otherwise spread the individual queries over a fixed-size thread pool of that size.
 * NOTE(review): the parallel path calls {@code distance.computeDistance} concurrently, so it
 * presumably requires the {@link Distance} implementation to be thread-safe — confirm.
 */
public final class NeighboursBruteForce
implements NeighboursQuery {
    private final SGDVector[] data;
    private final Distance distance;
    private final int numThreads;

    /**
     * Constructs a brute-force neighbour query over the supplied vectors.
     *
     * @param data The vectors to search; must be non-empty and all the same size.
     *             The array is stored by reference, not copied.
     * @param distance The distance function used to rank neighbours.
     * @param numThreads The number of threads used by the batch query methods;
     *                   {@code 1} runs them on the calling thread.
     * @throws IllegalArgumentException If {@code data} is empty or the vectors differ in size.
     */
    NeighboursBruteForce(SGDVector[] data, Distance distance, int numThreads) {
        if (data.length == 0) {
            throw new IllegalArgumentException("data must contain at least one SGDVector.");
        }
        int numFeatures = data[0].size();
        for (SGDVector vector : data) {
            if (vector.size() != numFeatures) {
                throw new IllegalArgumentException("All the SGDVectors must be the same size.");
            }
        }
        this.data = data;
        this.distance = distance;
        this.numThreads = numThreads;
    }

    /**
     * Finds the {@code k} stored vectors closest to {@code point}.
     *
     * @param point The query vector.
     * @param k The number of neighbours to return; must be positive.
     * @return (index into the stored data, distance) pairs sorted by ascending distance.
     *         The list has {@code min(k, number of stored vectors)} entries and never
     *         contains {@code null}.
     * @throws IllegalArgumentException If {@code k < 1}.
     */
    @Override
    public List<Pair<Integer, Double>> query(SGDVector point, int k) {
        checkK(k);
        // Max-heap on distance (see MutablePair#compareTo): the head is the farthest of the
        // current best candidates, i.e. the one a closer vector should evict.
        // Capacity is capped by the data size so a huge k does not pre-allocate a huge array.
        PriorityQueue<MutablePair> queue = new PriorityQueue<>(Math.min(k, this.data.length));
        for (int i = 0; i < this.data.length; i++) {
            double dist = this.distance.computeDistance(point, this.data[i]);
            if (queue.size() < k) {
                queue.offer(new MutablePair(i, dist));
            } else if (Double.compare(dist, queue.peek().value) < 0) {
                // Reuse the evicted node rather than allocating a new one.
                MutablePair farthest = queue.poll();
                farthest.index = i;
                farthest.value = dist;
                queue.offer(farthest);
            }
        }
        // Polling a max-heap yields the farthest first, so fill from the back to produce
        // ascending-distance order. Sized to the queue, not k, so no null slots remain.
        @SuppressWarnings("unchecked")
        Pair<Integer, Double>[] sorted = (Pair<Integer, Double>[]) new Pair[queue.size()];
        for (int i = sorted.length - 1; i >= 0; i--) {
            MutablePair next = queue.poll();
            sorted[i] = new Pair<>(next.index, next.value);
        }
        return Arrays.asList(sorted);
    }

    /**
     * Finds the {@code k} nearest stored vectors for each of the supplied points.
     *
     * @param points The query vectors.
     * @param k The number of neighbours to return per query; must be positive.
     * @return One result list per query point, in the same order as {@code points};
     *         each is as described by {@link #query(SGDVector, int)}.
     * @throws IllegalArgumentException If {@code k < 1}.
     * @throws RuntimeException If the parallel execution is interrupted or a worker fails
     *         with a checked exception; unchecked worker failures are rethrown as-is.
     */
    @Override
    public List<List<Pair<Integer, Double>>> query(SGDVector[] points, int k) {
        checkK(k);
        List<List<Pair<Integer, Double>>> results = new ArrayList<>(points.length);
        if (this.numThreads == 1) {
            for (SGDVector point : points) {
                results.add(this.query(point, k));
            }
            return results;
        }
        ExecutorService executorService = Executors.newFixedThreadPool(this.numThreads);
        try {
            List<Future<List<Pair<Integer, Double>>>> futures = new ArrayList<>(points.length);
            for (SGDVector point : points) {
                futures.add(executorService.submit(() -> this.query(point, k)));
            }
            // Future.get blocks until each task completes and rethrows its failure, so a
            // failed query can no longer leave a silent null in the results.
            for (Future<List<Pair<Integer, Double>>> future : futures) {
                results.add(future.get());
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("Parallel execution failed", e);
        } catch (ExecutionException e) {
            Throwable cause = e.getCause();
            if (cause instanceof RuntimeException) {
                throw (RuntimeException) cause;
            }
            if (cause instanceof Error) {
                throw (Error) cause;
            }
            throw new RuntimeException("Parallel execution failed", cause);
        } finally {
            // Normal exit: every task is already done. Early exit: cancel what is left.
            executorService.shutdownNow();
        }
        return results;
    }

    /**
     * Finds the {@code k} nearest stored vectors for every stored vector (each vector is
     * its own query point, so it will typically appear in its own result).
     *
     * @param k The number of neighbours to return per query; must be positive.
     * @return One result list per stored vector, in storage order.
     */
    @Override
    public List<List<Pair<Integer, Double>>> queryAll(int k) {
        return this.query(this.data, k);
    }

    /** Rejects non-positive neighbour counts with a descriptive message. */
    private static void checkK(int k) {
        if (k < 1) {
            throw new IllegalArgumentException("k must be positive, found " + k);
        }
    }

    /**
     * A mutable (index, distance) node for the candidate heap; mutable so evicted nodes
     * can be reused instead of reallocated.
     */
    private static final class MutablePair
    implements Comparable<MutablePair> {
        int index;
        double value;

        MutablePair(int index, double value) {
            this.index = index;
            this.value = value;
        }

        /**
         * Orders by descending {@code value}, which turns {@link PriorityQueue} (a min-heap)
         * into a max-heap on distance.
         */
        @Override
        public int compareTo(MutablePair o) {
            return Double.compare(o.value, this.value);
        }
    }
}

