/*
 * Decompiled with CFR 0.152.
 */
package io.lucenia.ml.common.task;

import io.lucenia.ml.common.action.stats.MLStatsNodeResponse;
import io.lucenia.ml.common.action.stats.MLStatsNodesAction;
import io.lucenia.ml.common.action.stats.MLStatsNodesRequest;
import io.skylite.common.action.ActionListener;
import io.skylite.core.action.ActionListenerHelper;
import io.skylite.core.action.ActionRequest;
import io.skylite.core.action.ActionType;
import io.skylite.core.client.Client;
import io.skylite.core.cluster.node.DiscoveryNode;
import io.skylite.core.cluster.service.ClusterService;
import io.skylite.core.settings.Settings;
import io.skylite.ml.common.FunctionName;
import io.skylite.ml.common.cluster.DiscoveryNodeHelper;
import io.skylite.ml.common.settings.MLCommonsSettings;
import io.skylite.ml.common.stats.MLNodeLevelStat;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import javax.naming.LimitExceededException;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Picks the node an ML task should run on.
 *
 * <p>The strategy is selected by the dynamic cluster setting
 * {@code MLCommonsSettings.ML_COMMONS_TASK_DISPATCH_POLICY}:
 * <ul>
 *   <li>{@code "round_robin"}: rotate through the candidate nodes using a cursor shared by all
 *       dispatch calls on this instance;</li>
 *   <li>{@code "least_load"}: request ML stats from the candidate nodes and choose the node with
 *       the fewest executing ML tasks, breaking ties by lower JVM heap usage. Nodes whose reported
 *       heap usage is at or above {@link #DEFAULT_JVM_HEAP_USAGE_THRESHOLD}, or whose executing task
 *       count has reached {@code ML_COMMONS_MAX_ML_TASK_PER_NODE}, are skipped.</li>
 * </ul>
 * Any other policy value makes the dispatch methods throw {@link IllegalArgumentException}.
 */
public class MLTaskDispatcher {
    private static final Logger log = LogManager.getLogger(MLTaskDispatcher.class);
    /** Nodes reporting JVM heap usage at or above this value are not dispatched to (presumably a percentage). */
    private static final short DEFAULT_JVM_HEAP_USAGE_THRESHOLD = 85;
    private static final String ROUND_ROBIN = "round_robin";
    private static final String LEAST_LOAD = "least_load";
    private static final String NO_ELIGIBLE_NODE_MESSAGE = "No eligible node found to execute this request. It's best practice to provision ML nodes to serve your models. You can disable this setting to serve the model on your data node for development purposes by disabling the \"plugins.ml_commons.only_run_on_ml_node\" configuration using the _cluster/setting api";
    /** Orders node stats by executing task count, then by JVM heap usage (both ascending). */
    private static final Comparator<MLStatsNodeResponse> BY_LOAD = Comparator
        .comparingLong((MLStatsNodeResponse r) -> statValue(r, MLNodeLevelStat.ML_EXECUTING_TASK_COUNT))
        .thenComparingLong(r -> statValue(r, MLNodeLevelStat.ML_JVM_HEAP_USAGE));

    private final ClusterService clusterService;
    private final Client client;
    private final DiscoveryNodeHelper nodeHelper;
    /** Shared round-robin cursor; only ever advanced atomically. */
    private final AtomicInteger nextNode = new AtomicInteger(0);
    // Both settings are dynamic: the update consumers registered in the constructor replace them.
    private volatile Integer maxMLBatchTaskPerNode;
    private volatile String dispatchPolicy;

    /**
     * Creates a dispatcher and subscribes to updates of the dispatch-policy and max-tasks-per-node
     * cluster settings.
     *
     * @param clusterService used to register the dynamic settings update consumers
     * @param client         used to fetch ML node stats for the least-load policy
     * @param settings       initial values for the dispatch policy and max tasks per node
     * @param nodeHelper     resolves node ids and eligible nodes
     */
    public MLTaskDispatcher(ClusterService clusterService, Client client, Settings settings, DiscoveryNodeHelper nodeHelper) {
        this.clusterService = clusterService;
        this.client = client;
        this.nodeHelper = nodeHelper;
        this.maxMLBatchTaskPerNode = MLCommonsSettings.ML_COMMONS_MAX_ML_TASK_PER_NODE.get(settings);
        this.dispatchPolicy = MLCommonsSettings.ML_COMMONS_TASK_DISPATCH_POLICY.get(settings);
        clusterService.getClusterSettings().addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_TASK_DISPATCH_POLICY, it -> {
            this.dispatchPolicy = it;
        });
        clusterService.getClusterSettings().addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_MAX_ML_TASK_PER_NODE, it -> {
            this.maxMLBatchTaskPerNode = it;
        });
    }

    /**
     * Selects a node eligible to run {@code functionName} using the current dispatch policy.
     *
     * @param functionName   the ML function whose eligible nodes are considered
     * @param actionListener receives the selected node, or a failure from the least-load path
     * @throws IllegalArgumentException if no eligible node exists or the policy is unknown
     */
    public void dispatch(FunctionName functionName, ActionListener<DiscoveryNode> actionListener) {
        // Read the volatile setting once so a concurrent update cannot make both checks miss.
        String policy = this.dispatchPolicy;
        if (ROUND_ROBIN.equals(policy)) {
            dispatchTaskWithRoundRobin(functionName, actionListener);
        } else if (LEAST_LOAD.equals(policy)) {
            dispatchTaskWithLeastLoad(functionName, actionListener);
        } else {
            throw new IllegalArgumentException("Unknown policy");
        }
    }

    /**
     * Selects one of the given nodes to run a predict request using the current dispatch policy.
     *
     * @param nodeIds        candidate node ids; must be non-empty
     * @param actionListener receives the selected node (resolved via the node helper), or a failure
     * @throws IllegalArgumentException if {@code nodeIds} is null/empty or the policy is unknown
     */
    public void dispatchPredictTask(String[] nodeIds, ActionListener<DiscoveryNode> actionListener) {
        if (nodeIds == null || nodeIds.length == 0) {
            throw new IllegalArgumentException("no eligible node to run predict request");
        }
        String policy = this.dispatchPolicy;
        if (ROUND_ROBIN.equals(policy)) {
            dispatchTaskWithRoundRobin(
                nodeIds,
                ActionListenerHelper.wrap(nodeId -> actionListener.onResponse(this.nodeHelper.getNode(nodeId)), e -> actionListener.onFailure(e))
            );
        } else if (LEAST_LOAD.equals(policy)) {
            dispatchTaskWithLeastLoad(nodeIds, actionListener);
        } else {
            throw new IllegalArgumentException("Unknown policy");
        }
    }

    /**
     * Hands the next element of {@code nodes} in rotation to {@code listener}.
     *
     * <p>The index is derived from a single atomic increment. This keeps concurrent callers from
     * racing on a separate reset step, and {@code floorMod} keeps the index in range even when the
     * shared cursor was advanced against a longer array or has wrapped past {@code Integer.MAX_VALUE}.
     * Callers guarantee {@code nodes} is non-empty.
     */
    private <T> void dispatchTaskWithRoundRobin(T[] nodes, ActionListener<T> listener) {
        int index = Math.floorMod(this.nextNode.getAndIncrement(), nodes.length);
        listener.onResponse(nodes[index]);
    }

    /** Resolves {@code nodeIds} through the node helper and dispatches by least load. */
    private void dispatchTaskWithLeastLoad(String[] nodeIds, ActionListener<DiscoveryNode> listener) {
        dispatchTaskWithLeastLoad(this.nodeHelper.getNodes(nodeIds), listener);
    }

    /**
     * Requests executing-task-count and JVM-heap-usage stats from {@code nodes}, then notifies
     * {@code listener} with the least-loaded eligible node. If the stats call fails, the failure
     * is logged and forwarded to {@code listener}.
     */
    private void dispatchTaskWithLeastLoad(DiscoveryNode[] nodes, ActionListener<DiscoveryNode> listener) {
        MLStatsNodesRequest statsRequest = new MLStatsNodesRequest(nodes);
        statsRequest.addNodeLevelStats(Set.of(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, MLNodeLevelStat.ML_JVM_HEAP_USAGE));
        this.client.execute(MLStatsNodesAction.INSTANCE, statsRequest, ActionListenerHelper.wrap(
            mlStatsResponse -> selectLeastLoadedNode(mlStatsResponse.getNodes(), listener),
            exception -> {
                log.error("Failed to get node's task stats", exception);
                listener.onFailure(exception);
            }
        ));
    }

    /**
     * Filters out nodes over the heap threshold or at the per-node task limit, and picks the
     * remaining node with the lowest {@link #BY_LOAD} ordering. The first such node wins ties.
     * Fails {@code listener} with {@link LimitExceededException} when no node survives a filter.
     */
    private void selectLeastLoadedNode(List<MLStatsNodeResponse> nodeStats, ActionListener<DiscoveryNode> listener) {
        List<MLStatsNodeResponse> candidates = nodeStats.stream()
            .filter(stat -> statValue(stat, MLNodeLevelStat.ML_JVM_HEAP_USAGE) < DEFAULT_JVM_HEAP_USAGE_THRESHOLD)
            .collect(Collectors.toList());
        if (candidates.isEmpty()) {
            String errorMessage = "All nodes' memory usage exceeds limitation "
                + DEFAULT_JVM_HEAP_USAGE_THRESHOLD
                + ". No eligible node available to run ml jobs ";
            log.warn(errorMessage);
            listener.onFailure(new LimitExceededException(errorMessage));
            return;
        }
        // Snapshot the dynamic setting so every node is judged against the same limit.
        long maxTasksPerNode = this.maxMLBatchTaskPerNode;
        candidates = candidates.stream()
            .filter(stat -> statValue(stat, MLNodeLevelStat.ML_EXECUTING_TASK_COUNT) < maxTasksPerNode)
            .collect(Collectors.toList());
        if (candidates.isEmpty()) {
            String errorMessage = "All nodes' executing ML task count reach limitation.";
            log.warn(errorMessage);
            listener.onFailure(new LimitExceededException(errorMessage));
            return;
        }
        // min() keeps the first minimal element, matching the former stable sort + findFirst.
        MLStatsNodeResponse target = candidates.stream()
            .min(BY_LOAD)
            .orElseThrow(IllegalStateException::new); // unreachable: candidates is non-empty
        listener.onResponse(target.getNode());
    }

    /** Reads a requested node-level stat, which this class treats as a {@code Long}. */
    private static long statValue(MLStatsNodeResponse response, MLNodeLevelStat stat) {
        return (Long) response.getNodeLevelStat(stat);
    }

    /** Dispatches by least load among the nodes eligible for {@code functionName}. */
    private void dispatchTaskWithLeastLoad(FunctionName functionName, ActionListener<DiscoveryNode> listener) {
        dispatchTaskWithLeastLoad(requireEligibleNodes(functionName), listener);
    }

    /** Dispatches round-robin among the nodes eligible for {@code functionName}. */
    private void dispatchTaskWithRoundRobin(FunctionName functionName, ActionListener<DiscoveryNode> listener) {
        dispatchTaskWithRoundRobin(requireEligibleNodes(functionName), listener);
    }

    /**
     * Returns the nodes eligible for {@code functionName}.
     *
     * @throws IllegalArgumentException if there are none. Both dispatch policies share this check,
     *     so an empty cluster is reported the same way regardless of policy.
     */
    private DiscoveryNode[] requireEligibleNodes(FunctionName functionName) {
        DiscoveryNode[] eligibleNodes = this.nodeHelper.getEligibleNodes(functionName);
        if (eligibleNodes == null || eligibleNodes.length == 0) {
            throw new IllegalArgumentException(NO_ELIGIBLE_NODE_MESSAGE);
        }
        return eligibleNodes;
    }
}

