/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.math.neighbour.kdtree;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.PriorityQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import org.tribuo.math.distance.Distance;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.neighbour.NeighboursQuery;
import org.tribuo.math.neighbour.kdtree.DimensionNode;

/**
 * A k-d tree over a fixed set of {@link SGDVector}s, used to answer k-nearest-neighbour queries.
 * <p>
 * The tree is built once, at construction. Each level splits its points on the (lower) median
 * of one dimension, and the split dimension cycles through all dimensions as the tree deepens.
 * Each query first seeds a bounded heap by searching from an approximate parent node of the
 * query point (see {@link #approximateParentNode(SGDVector)}), then searches from the root.
 */
public final class KDTree implements NeighboursQuery {
    /**
     * The vectors the tree was built from, in their original order. Indices returned by queries
     * refer to positions in this array. The caller's array is stored as-is, without a copy.
     */
    private final SGDVector[] data;
    /** Thread count for the batch query; 1 means the batch is processed sequentially. */
    private final int numThreads;
    /** Root node of the tree; non-null because construction requires at least one vector. */
    private final DimensionNode root;

    /**
     * Builds a k-d tree over the supplied vectors.
     *
     * @param data The vectors to index. All must have the same size. The array itself is not
     *     reordered; the tree is built over a separate array of (index, vector) wrappers.
     * @param distance The distance function handed to every tree node.
     * @param numThreads The number of threads used by {@link #query(SGDVector[], int)}.
     * @throws IllegalArgumentException If {@code data} is empty or the vectors differ in size.
     */
    KDTree(SGDVector[] data, Distance distance, int numThreads) {
        if (data.length == 0) {
            throw new IllegalArgumentException("A KDTree requires at least one SGDVector.");
        }
        this.data = data;
        this.numThreads = numThreads;
        int numDimensions = data[0].size();
        IntAndVector[] points = new IntAndVector[data.length];
        for (int i = 0; i < data.length; ++i) {
            if (data[i].size() != numDimensions) {
                throw new IllegalArgumentException("All the SGDVectors must be the same size.");
            }
            points[i] = new IntAndVector(i, data[i]);
        }
        this.root = generateTree(0, numDimensions - 1, points, 0, data.length - 1, distance);
    }

    /**
     * Finds the {@code k} nearest indexed vectors to {@code point}.
     *
     * @param point The query vector.
     * @param k The number of neighbours to return.
     * @return A fixed-size list of length {@code k} containing (index into the construction data,
     *     distance) pairs, ordered from nearest to farthest. If fewer than {@code k} neighbours were
     *     found (e.g. the tree holds fewer than {@code k} vectors), the leading entries are
     *     {@code null} and the found neighbours occupy the tail of the list.
     * @throws IllegalArgumentException If {@code k} is less than 1.
     */
    @Override
    public List<Pair<Integer, Double>> query(SGDVector point, int k) {
        if (k < 1) {
            throw new IllegalArgumentException("k must be at least 1, found " + k);
        }
        DistanceIntAndVectorBoundedMinHeap queue = new DistanceIntAndVectorBoundedMinHeap(k);
        initializeQueue(point, queue);
        root.nearest(point, queue, false);

        @SuppressWarnings("unchecked") // generic array creation
        Pair<Integer, Double>[] indexDistanceArr = new Pair[k];
        // The heap yields the farthest retained neighbour first, so fill the array from the back
        // to end up ordered nearest first. The heap never holds more than k entries.
        int slot = k - 1;
        while (!queue.isEmpty()) {
            MutableDistIntAndVectorTuple tuple = queue.poll();
            indexDistanceArr[slot--] = new Pair<>(tuple.intAndVector.idx, tuple.dist);
        }
        return Arrays.asList(indexDistanceArr);
    }

    /**
     * Runs {@link #query(SGDVector, int)} for every supplied point.
     * <p>
     * When {@code numThreads == 1}, the queries run sequentially on the calling thread. Otherwise
     * they run on a fixed pool of {@code numThreads} threads, which is shut down before returning.
     * A non-positive {@code numThreads} is rejected by {@link Executors#newFixedThreadPool(int)}.
     *
     * @param points The query vectors.
     * @param k The number of neighbours per query.
     * @return One neighbour list per query, in the same order as {@code points}.
     * @throws RuntimeException If any query fails, with the failure as the cause, or if the calling
     *     thread is interrupted while waiting (in which case its interrupt flag is restored).
     */
    @Override
    public List<List<Pair<Integer, Double>>> query(SGDVector[] points, int k) {
        int numQueries = points.length;
        @SuppressWarnings("unchecked") // generic array creation
        List<Pair<Integer, Double>>[] indexDistancePairListArray = new List[numQueries];
        if (numThreads == 1) {
            for (int point = 0; point < numQueries; ++point) {
                indexDistancePairListArray[point] = query(points[point], k);
            }
        } else {
            ExecutorService executorService = Executors.newFixedThreadPool(numThreads);
            try {
                List<Future<?>> futures = new ArrayList<>(numQueries);
                for (int pointInd = 0; pointInd < numQueries; ++pointInd) {
                    futures.add(executorService.submit(
                            new SingleQueryRunnable(pointInd, points[pointInd], k, indexDistancePairListArray)));
                }
                // Joining each future surfaces worker exceptions (execute() would only hand them to
                // the uncaught-exception handler, leaving null results). Future.get() also makes the
                // workers' array writes visible to this thread.
                for (Future<?> future : futures) {
                    future.get();
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException("Parallel execution failed", e);
            } catch (ExecutionException e) {
                throw new RuntimeException("Parallel execution failed", e.getCause());
            } finally {
                // On success every task has completed; on failure this cancels queued queries.
                executorService.shutdownNow();
            }
        }
        return Arrays.asList(indexDistancePairListArray);
    }

    /**
     * Queries the {@code k} nearest neighbours of every vector the tree was built from.
     *
     * @param k The number of neighbours per query.
     * @return One neighbour list per construction vector, in construction order.
     */
    @Override
    public List<List<Pair<Integer, Double>>> queryAll(int k) {
        return query(data, k);
    }

    /**
     * Recursively builds the subtree over {@code points[left..right]} (inclusive).
     * <p>
     * The range is split on dimension {@code d} at its lower median. Elements before the median
     * position form the "below" subtree and elements after it form the "above" subtree. The split
     * dimension advances to {@code d + 1}, wrapping to 0 after {@code maxD}. {@code points} is
     * reordered in place.
     *
     * @param d The dimension to split on at this level.
     * @param maxD The largest dimension index.
     * @param points The points being organised; reordered in place.
     * @param left The first index of the range (inclusive).
     * @param right The last index of the range (inclusive).
     * @param distance The distance function handed to each node.
     * @return The subtree root, or {@code null} when the range is empty.
     */
    private static DimensionNode generateTree(int d, int maxD, IntAndVector[] points, int left, int right, Distance distance) {
        if (right < left) {
            return null;
        }
        if (right == left) {
            return new DimensionNode(d, points[left], distance);
        }
        // 1-based position of the lower median within [left, right].
        int median = 1 + (right - left) / 2;
        setMedian(points, median, left, right, d);
        DimensionNode medianNode = new DimensionNode(d, points[left + median - 1], distance);
        int nextD = d + 1 > maxD ? 0 : d + 1;
        medianNode.setBelow(generateTree(nextD, maxD, points, left, left + median - 2, distance));
        medianNode.setAbove(generateTree(nextD, maxD, points, left + median, right, distance));
        return medianNode;
    }

    /**
     * Quickselect: rearranges {@code points[left..right]} so that index {@code left + median - 1}
     * holds the element that would be there if the range were sorted on {@code dimension}.
     * Elements before it compare less than or equal to it, and elements after it compare greater
     * than or equal to it.
     *
     * @param median The 1-based target position within the range.
     */
    private static void setMedian(IntAndVector[] points, int median, int left, int right, int dimension) {
        // The absolute target index is invariant across iterations; only the search window shrinks.
        int target = left + median - 1;
        int lo = left;
        int hi = right;
        while (true) {
            int pivotIndex = getPivotPointIndex(points, lo, hi, dimension);
            int newPivotIndex = partitionOnIndex(points, lo, hi, pivotIndex, dimension);
            if (newPivotIndex == target) {
                return;
            }
            if (target < newPivotIndex) {
                hi = newPivotIndex - 1;
            } else {
                lo = newPivotIndex + 1;
            }
        }
    }

    /**
     * Picks a pivot by median-of-three over the first, middle and last elements of the range.
     *
     * @return The index (one of {@code left}, the midpoint, or {@code right}) of the median of the three.
     */
    private static int getPivotPointIndex(IntAndVector[] points, int left, int right, int dimension) {
        int lowIndex = left;
        // Unsigned shift avoids int overflow of left + right.
        int midIndex = (left + right) >>> 1;
        // Order the first two candidates so points[lowIndex] <= points[midIndex] (ties swap).
        if (compareByDimension(points[lowIndex], points[midIndex], dimension) >= 0) {
            lowIndex = midIndex;
            midIndex = left;
        }
        if (compareByDimension(points[right], points[lowIndex], dimension) <= 0) {
            return lowIndex;
        }
        if (compareByDimension(points[right], points[midIndex], dimension) <= 0) {
            return right;
        }
        return midIndex;
    }

    /**
     * Lomuto partition of {@code points[left..right]} around {@code points[pivotIndex]}.
     * <p>
     * Elements comparing less than or equal to the pivot on {@code dimension} are moved before it,
     * and the rest end up after it.
     *
     * @return The final index of the pivot.
     */
    private static int partitionOnIndex(IntAndVector[] points, int left, int right, int pivotIndex, int dimension) {
        IntAndVector pivot = points[pivotIndex];
        // Park the pivot at the end while scanning.
        swap(points, right, pivotIndex);
        int store = left;
        for (int idx = left; idx < right; ++idx) {
            if (compareByDimension(points[idx], pivot, dimension) <= 0) {
                swap(points, idx, store);
                ++store;
            }
        }
        swap(points, right, store);
        return store;
    }

    /** Swaps two array entries; a no-op when the indices are equal. */
    private static void swap(IntAndVector[] points, int ind1, int ind2) {
        if (ind1 == ind2) {
            return;
        }
        IntAndVector tmpPoint = points[ind1];
        points[ind1] = points[ind2];
        points[ind2] = tmpPoint;
    }

    /**
     * Seeds {@code queue} by running the node search from an approximate parent of {@code point},
     * before the full search from the root. The meaning of the {@code true} flag is defined by
     * {@link DimensionNode#nearest}.
     */
    private void initializeQueue(SGDVector point, DistanceIntAndVectorBoundedMinHeap queue) {
        DimensionNode parentOfPoint = approximateParentNode(point);
        parentOfPoint.nearest(point, queue, true);
    }

    /**
     * Descends from the root towards {@code point}, following {@link DimensionNode#isBelow} at each
     * node, until no child exists in the chosen direction.
     *
     * @param point The query point.
     * @return The deepest node on that path that has both children, or the root if none does.
     */
    public DimensionNode approximateParentNode(SGDVector point) {
        DimensionNode bestNode = root;
        DimensionNode node = root;
        while (node != null) {
            if (node.getBelow() != null && node.getAbove() != null) {
                bestNode = node;
            }
            DimensionNode next = node.isBelow(point) ? node.getBelow() : node.getAbove();
            if (next == null) {
                break;
            }
            node = next;
        }
        return bestNode;
    }

    /** Compares two points by their value in a single dimension, using {@link Double#compare}. */
    private static int compareByDimension(IntAndVector intAndVector1, IntAndVector intAndVector2, int dimension) {
        return Double.compare(intAndVector1.vector.get(dimension), intAndVector2.vector.get(dimension));
    }

    /** A vector paired with its index in the array the tree was built from. */
    static final class IntAndVector {
        /** Index of {@link #vector} in the construction data. */
        final int idx;
        final SGDVector vector;

        public IntAndVector(int idx, SGDVector vector) {
            this.idx = idx;
            this.vector = vector;
        }
    }

    /**
     * A bounded heap of (distance, point) tuples that retains the {@code size} closest points offered.
     * <p>
     * Despite the name, {@link MutableDistIntAndVectorTuple} orders by <em>descending</em> distance,
     * so the head is the farthest retained point. That is the entry evicted when a strictly closer
     * point arrives while the heap is full. Offers are de-duplicated by point index.
     */
    static final class DistanceIntAndVectorBoundedMinHeap {
        /** Indices of the points currently held, used to reject duplicate offers. */
        private final HashSet<Integer> ids = new HashSet<>();
        /** Maximum number of tuples retained. */
        final int size;
        private final PriorityQueue<MutableDistIntAndVectorTuple> queue;

        /**
         * @param size The capacity; {@link PriorityQueue} rejects values below 1.
         */
        DistanceIntAndVectorBoundedMinHeap(int size) {
            this.size = size;
            this.queue = new PriorityQueue<>(size);
        }

        /**
         * Offers a point at the given distance. It is ignored if its index is already held, or if
         * the heap is full and the distance is not strictly smaller than the current farthest.
         */
        void boundedOffer(IntAndVector intAndVector, double distance) {
            if (ids.contains(intAndVector.idx)) {
                return;
            }
            if (queue.size() < size) {
                queue.offer(new MutableDistIntAndVectorTuple(distance, intAndVector));
                ids.add(intAndVector.idx);
            } else if (Double.compare(distance, queue.peek().dist) < 0) {
                // Evict the farthest entry and recycle its tuple object for the new point.
                MutableDistIntAndVectorTuple tuple = poll();
                tuple.dist = distance;
                tuple.intAndVector = intAndVector;
                queue.offer(tuple);
                ids.add(intAndVector.idx);
            }
        }

        /** @return The farthest retained tuple, or {@code null} if empty. */
        MutableDistIntAndVectorTuple peek() {
            return queue.peek();
        }

        /**
         * Removes and returns the farthest retained tuple. Must not be called on an empty heap.
         */
        MutableDistIntAndVectorTuple poll() {
            MutableDistIntAndVectorTuple tuple = queue.poll();
            ids.remove(tuple.intAndVector.idx);
            return tuple;
        }

        boolean isFull() {
            return queue.size() == size;
        }

        boolean isEmpty() {
            return queue.isEmpty();
        }
    }

    /**
     * A mutable (distance, point) tuple. Its natural ordering is by descending distance, so a
     * {@link PriorityQueue} of these keeps the largest distance at its head.
     */
    static final class MutableDistIntAndVectorTuple implements Comparable<MutableDistIntAndVectorTuple> {
        double dist;
        IntAndVector intAndVector;

        public MutableDistIntAndVectorTuple(double dist, IntAndVector intAndVector) {
            this.dist = dist;
            this.intAndVector = intAndVector;
        }

        @Override
        public int compareTo(MutableDistIntAndVectorTuple o) {
            // Reversed on purpose: larger distances sort first.
            return Double.compare(o.dist, this.dist);
        }
    }

    /** Runs one single-point query and stores its result in a shared slot array. */
    private final class SingleQueryRunnable implements Runnable {
        private final SGDVector point;
        private final int k;
        private final int index;
        private final List<Pair<Integer, Double>>[] indexDistancePairListArray;

        SingleQueryRunnable(int index, SGDVector point, int k, List<Pair<Integer, Double>>[] indexDistancePairListArray) {
            this.point = point;
            this.k = k;
            this.index = index;
            this.indexDistancePairListArray = indexDistancePairListArray;
        }

        @Override
        public void run() {
            // Each task writes a distinct slot, so the tasks never write the same element.
            indexDistancePairListArray[index] = KDTree.this.query(point, k);
        }
    }
}

