/*
 * Decompiled with CFR 0.152.
 */
package io.lucenia.ml.common.action.deploy;

import io.lucenia.action.support.nodes.TransportNodesAction;
import io.lucenia.ml.common.model.MLModelManager;
import io.lucenia.ml.common.task.MLTaskManager;
import io.skylite.common.action.ActionListener;
import io.skylite.core.action.ActionFilters;
import io.skylite.core.action.ActionListenerHelper;
import io.skylite.core.action.ActionListenerResponseHandler;
import io.skylite.core.action.FailedNodeException;
import io.skylite.core.client.Client;
import io.skylite.core.cluster.node.DiscoveryNode;
import io.skylite.core.cluster.node.DiscoveryNodes;
import io.skylite.core.cluster.service.ClusterService;
import io.skylite.core.common.concurrent.ThreadContext;
import io.skylite.core.common.inject.Inject;
import io.skylite.core.common.io.stream.StreamInput;
import io.skylite.core.threadpool.ThreadPool;
import io.skylite.core.transport.TransportRequest;
import io.skylite.core.transport.TransportResponseHandler;
import io.skylite.core.transport.TransportService;
import io.skylite.core.xcontent.NamedXContentRegistry;
import io.skylite.ml.common.FunctionName;
import io.skylite.ml.common.breaker.MLCircuitBreakerService;
import io.skylite.ml.common.engine.ModelDownloader;
import io.skylite.ml.common.exception.MLExceptionUtils;
import io.skylite.ml.common.stats.MLStats;
import io.skylite.ml.common.task.MLTask;
import io.skylite.ml.common.transport.deploy.MLDeployModelInput;
import io.skylite.ml.common.transport.deploy.MLDeployModelNodeRequest;
import io.skylite.ml.common.transport.deploy.MLDeployModelNodeResponse;
import io.skylite.ml.common.transport.deploy.MLDeployModelNodesRequest;
import io.skylite.ml.common.transport.deploy.MLDeployModelNodesResponse;
import io.skylite.ml.common.transport.forward.MLForwardInput;
import io.skylite.ml.common.transport.forward.MLForwardRequest;
import io.skylite.ml.common.transport.forward.MLForwardRequestType;
import io.skylite.ml.common.transport.forward.MLForwardResponse;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Node-level transport action that deploys an ML model on each node targeted by an
 * {@link MLDeployModelNodesRequest}.
 *
 * <p>Each receiving node hands the deployment to {@link MLModelManager} and replies with a
 * {@code "received"} status for the model as soon as that call returns. The outcome of the
 * deployment is reported separately. When the deployment listener fires, a
 * {@link MLForwardRequestType#DEPLOY_MODEL_DONE} message is sent to the coordinating node named in
 * the request. On failure, that message carries an error string.
 */
public class TransportDeployModelOnNodeAction
extends TransportNodesAction<MLDeployModelNodesRequest, MLDeployModelNodesResponse, MLDeployModelNodeRequest, MLDeployModelNodeResponse> {
    private static final Logger log = LogManager.getLogger(TransportDeployModelOnNodeAction.class);

    /**
     * Transport action used to forward DEPLOY_MODEL_DONE messages to the coordinating node.
     * NOTE(review): this name lives in the "opensearch" namespace, while this action registers under
     * "cluster:admin/lucenia/ml/..."; confirm it matches the name the forward action is registered with.
     */
    private static final String FORWARD_ACTION_NAME = "cluster:admin/opensearch/mlinternal/forward";

    TransportService transportService;
    ModelDownloader modelDownloader;
    MLTaskManager mlTaskManager;
    MLModelManager mlModelManager;
    ClusterService clusterService;
    ThreadPool threadPool;
    Client client;
    NamedXContentRegistry xContentRegistry;
    MLCircuitBreakerService mlCircuitBreakerService;
    MLStats mlStats;

    @Inject
    public TransportDeployModelOnNodeAction(TransportService transportService, ActionFilters actionFilters, ModelDownloader modelDownloader, MLTaskManager mlTaskManager, MLModelManager mlModelManager, ClusterService clusterService, ThreadPool threadPool, Client client, NamedXContentRegistry xContentRegistry, MLCircuitBreakerService mlCircuitBreakerService, MLStats mlStats) {
        super("cluster:admin/lucenia/ml/deploy_model_on_nodes", threadPool, clusterService, transportService, actionFilters, MLDeployModelNodesRequest::new, MLDeployModelNodeRequest::new, "management", MLDeployModelNodeResponse.class);
        this.transportService = transportService;
        this.modelDownloader = modelDownloader;
        this.mlTaskManager = mlTaskManager;
        this.mlModelManager = mlModelManager;
        this.clusterService = clusterService;
        this.threadPool = threadPool;
        this.client = client;
        this.xContentRegistry = xContentRegistry;
        this.mlCircuitBreakerService = mlCircuitBreakerService;
        this.mlStats = mlStats;
    }

    /** Aggregates per-node responses and failures into a nodes response tagged with the cluster name. */
    @Override
    protected MLDeployModelNodesResponse newResponse(MLDeployModelNodesRequest nodesRequest, List<MLDeployModelNodeResponse> responses, List<FailedNodeException> failures) {
        return new MLDeployModelNodesResponse(this.clusterService.getClusterName(), responses, failures);
    }

    /** Wraps the cluster-level request into the request sent to each individual node. */
    @Override
    protected MLDeployModelNodeRequest newNodeRequest(MLDeployModelNodesRequest request) {
        return new MLDeployModelNodeRequest(request);
    }

    /** Deserializes a single node's response from the wire. */
    @Override
    protected MLDeployModelNodeResponse newNodeResponse(StreamInput in) throws IOException {
        return new MLDeployModelNodeResponse(in);
    }

    /** Runs on each target node: starts the deployment and returns the node's immediate status. */
    @Override
    protected MLDeployModelNodeResponse nodeOperation(MLDeployModelNodeRequest request) {
        return this.createDeployModelNodeResponse(request.getMLDeployModelNodesRequest());
    }

    /**
     * Starts deploying the requested model on this node and builds this node's response.
     *
     * <p>The response contains a single {@code modelId -> "received"} entry. It is returned once
     * {@link #deployModel} returns. The final result is sent to the coordinating node through the
     * deployment listener.
     *
     * @param mlDeployModelNodesRequest the original cluster-level deploy request
     * @return this node's response with the "received" status for the model
     */
    private MLDeployModelNodeResponse createDeployModelNodeResponse(MLDeployModelNodesRequest mlDeployModelNodesRequest) {
        MLDeployModelInput deployModelInput = mlDeployModelNodesRequest.getMlDeployModelInput();
        String tenantId = deployModelInput.getTenantId();
        String modelId = deployModelInput.getModelId();
        String taskId = deployModelInput.getTaskId();
        String coordinatingNodeId = deployModelInput.getCoordinatingNodeId();
        MLTask mlTask = deployModelInput.getMlTask();
        String modelContentHash = deployModelInput.getModelContentHash();
        boolean deployToAllNodes = deployModelInput.getIsDeployToAllNodes();

        HashMap<String, String> modelDeployStatus = new HashMap<>();
        modelDeployStatus.put(modelId, "received");
        String localNodeId = this.clusterService.localNode().getId();

        // Only logs whether the DEPLOY_MODEL_DONE message was acknowledged by the coordinating node.
        ActionListener<MLForwardResponse> taskDoneListener = ActionListenerHelper.wrap(
            res -> log.info("deploy model task done {}", taskId),
            ex -> MLExceptionUtils.logException("Deploy model task failed: " + taskId, ex, log)
        );
        // Reports the deployment outcome, success (no error) or failure (with error), to the coordinator.
        ActionListener<String> deployListener = ActionListenerHelper.wrap(
            r -> this.sendDeployModelDoneMessage(taskId, modelId, tenantId, coordinatingNodeId, null, taskDoneListener),
            e -> this.sendDeployModelDoneMessage(taskId, modelId, tenantId, coordinatingNodeId, describeFailure(e), taskDoneListener)
        );
        this.deployModel(modelId, tenantId, modelContentHash, mlTask.getFunctionName(), localNodeId, coordinatingNodeId, deployToAllNodes, mlTask, deployListener);
        return new MLDeployModelNodeResponse(this.clusterService.localNode(), modelDeployStatus);
    }

    /**
     * Produces a non-null error description for a failed deployment.
     *
     * <p>The success message omits the error field entirely. A failure therefore must never carry a
     * {@code null} error, or the coordinator could not tell the two apart.
     */
    private static String describeFailure(Exception e) {
        String message = MLExceptionUtils.getRootCauseMessage(e);
        return message != null ? message : e.toString();
    }

    /**
     * Sends a DEPLOY_MODEL_DONE forward request for {@code taskId} to the coordinating node.
     *
     * <p>If the coordinating node is no longer in the cluster state, the message cannot be delivered.
     * In that case {@code taskDoneListener} is failed instead of throwing a NullPointerException out
     * of the calling listener.
     *
     * @param error {@code null} on success, otherwise a description of the deployment failure
     */
    private void sendDeployModelDoneMessage(String taskId, String modelId, String tenantId, String coordinatingNodeId, String error, ActionListener<MLForwardResponse> taskDoneListener) {
        DiscoveryNode coordinatingNode = this.getNodeById(coordinatingNodeId);
        if (coordinatingNode == null) {
            taskDoneListener.onFailure(new IllegalStateException(
                "Coordinating node " + coordinatingNodeId + " not found in cluster state; cannot report deploy model result for task " + taskId));
            return;
        }
        MLForwardInput mlForwardInput = MLForwardInput.builder()
            .requestType(MLForwardRequestType.DEPLOY_MODEL_DONE)
            .taskId(taskId)
            .modelId(modelId)
            .workerNodeId(this.clusterService.localNode().getId())
            .error(error)
            .tenantId(tenantId)
            .build();
        MLForwardRequest deployModelDoneMessage = new MLForwardRequest(mlForwardInput);
        // Send with a stashed (clean) thread context; it is restored when the try block exits.
        try (ThreadContext.StoredContext ignored = this.client.threadPool().getThreadContext().stashContext()) {
            this.transportService.sendRequest(coordinatingNode, FORWARD_ACTION_NAME, deployModelDoneMessage, new ActionListenerResponseHandler<>(taskDoneListener, MLForwardResponse::new));
        }
    }

    /**
     * Looks up a node in the current cluster state by its id.
     *
     * @return the matching node, or {@code null} if no node with that id is present
     */
    private DiscoveryNode getNodeById(String nodeId) {
        DiscoveryNodes nodes = this.clusterService.state().getNodes();
        for (DiscoveryNode node : nodes) {
            if (node.getId().equals(nodeId)) {
                return node;
            }
        }
        return null;
    }

    /**
     * Delegates the deployment to {@link MLModelManager}.
     *
     * <p>When this node is not the coordinating node, the task is removed from the local
     * {@link MLTaskManager} before {@code listener} is notified. This happens on success, on
     * asynchronous failure, and when {@code MLModelManager.deployModel} throws synchronously.
     */
    private void deployModel(String modelId, String tenantId, String modelContentHash, FunctionName functionName, String localNodeId, String coordinatingNodeId, boolean deployToAllNodes, MLTask mlTask, ActionListener<String> listener) {
        ActionListener<String> cleanupListener = ActionListenerHelper.runBefore(listener, () -> {
            if (!localNodeId.equals(coordinatingNodeId)) {
                this.mlTaskManager.remove(mlTask.getTaskId());
            }
        });
        try {
            log.debug("start deploying model {}", modelId);
            this.mlModelManager.deployModel(modelId, tenantId, modelContentHash, functionName, deployToAllNodes, false, mlTask, cleanupListener);
        }
        catch (Exception e) {
            MLExceptionUtils.logException("Failed to deploy model " + modelId, e, log);
            // Use the wrapped listener so the local task is cleaned up on this path too.
            cleanupListener.onFailure(e);
        }
    }
}

