/*
 * Decompiled with CFR 0.152.
 */
package io.lucenia.ml.common.action.undeploy;

import io.lucenia.action.support.nodes.TransportNodesAction;
import io.lucenia.ml.common.action.undeploy.TransportUndeployModelsAction;
import io.lucenia.ml.common.model.MLModelManager;
import io.skylite.common.action.ActionListener;
import io.skylite.core.action.ActionFilters;
import io.skylite.core.action.ActionListenerHelper;
import io.skylite.core.action.ActionRequest;
import io.skylite.core.action.ActionType;
import io.skylite.core.action.FailedNodeException;
import io.skylite.core.action.WriteRequest;
import io.skylite.core.action.bulk.BulkRequest;
import io.skylite.core.action.support.nodes.BaseNodesRequest;
import io.skylite.core.action.update.UpdateRequest;
import io.skylite.core.client.Client;
import io.skylite.core.cluster.service.ClusterService;
import io.skylite.core.common.concurrent.ThreadContext;
import io.skylite.core.common.inject.Inject;
import io.skylite.core.common.io.stream.StreamInput;
import io.skylite.core.tasks.Task;
import io.skylite.core.threadpool.ThreadPool;
import io.skylite.core.transport.TransportService;
import io.skylite.ml.common.FunctionName;
import io.skylite.ml.common.cluster.DiscoveryNodeHelper;
import io.skylite.ml.common.model.MLModelState;
import io.skylite.ml.common.stats.MLNodeLevelStat;
import io.skylite.ml.common.stats.MLStats;
import io.skylite.ml.common.transport.sync.MLSyncUpAction;
import io.skylite.ml.common.transport.sync.MLSyncUpInput;
import io.skylite.ml.common.transport.sync.MLSyncUpNodesRequest;
import io.skylite.ml.common.transport.undeploy.MLUndeployModelNodeRequest;
import io.skylite.ml.common.transport.undeploy.MLUndeployModelNodeResponse;
import io.skylite.ml.common.transport.undeploy.MLUndeployModelNodesRequest;
import io.skylite.ml.common.transport.undeploy.MLUndeployModelNodesResponse;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Nodes-level transport action that undeploys ML models.
 *
 * <p>{@link #nodeOperation} runs the undeploy on a node and reports two things per model: the undeploy status, and the
 * worker nodes that node knew about before removal. Once the aggregated response is available,
 * {@link #processUndeployModelResponseAndUpdate} does the following:
 * <ul>
 *   <li>merges the per-node reports;</li>
 *   <li>writes the new worker-node plan or model state to the {@code .plugins-ml-model} index;</li>
 *   <li>sends an {@code MLSyncUpAction} to all nodes so they drop the removed workers.</li>
 * </ul>
 */
public class TransportUndeployModelAction
extends TransportNodesAction<MLUndeployModelNodesRequest, MLUndeployModelNodesResponse, MLUndeployModelNodeRequest, MLUndeployModelNodeResponse> {
    /** Status string a node reports for a model it actually undeployed. */
    private static final String UNDEPLOYED_STATUS = "undeployed";
    /** Index holding the model documents updated after an undeploy. */
    private static final String ML_MODEL_INDEX = ".plugins-ml-model";

    // Logger named after this class so log configuration targeting this action applies to it.
    private static final Logger log = LogManager.getLogger(TransportUndeployModelAction.class);
    private final MLModelManager mlModelManager;
    private final ClusterService clusterService;
    private final Client client;
    private final DiscoveryNodeHelper nodeFilter;
    private final MLStats mlStats;

    /**
     * Registers the action as {@code cluster:admin/lucenia/ml/undeploy_model}, executed on the
     * {@code "management"} thread pool.
     */
    @Inject
    public TransportUndeployModelAction(TransportService transportService, ActionFilters actionFilters, MLModelManager mlModelManager, ClusterService clusterService, ThreadPool threadPool, Client client, DiscoveryNodeHelper nodeFilter, MLStats mlStats) {
        super("cluster:admin/lucenia/ml/undeploy_model", threadPool, clusterService, transportService, actionFilters, MLUndeployModelNodesRequest::new, MLUndeployModelNodeRequest::new, "management", MLUndeployModelNodeResponse.class);
        this.mlModelManager = mlModelManager;
        this.clusterService = clusterService;
        this.client = client;
        this.nodeFilter = nodeFilter;
        this.mlStats = mlStats;
    }

    /**
     * Runs the nodes-level undeploy, then post-processes the aggregated response before answering.
     *
     * @param task     the task for this request
     * @param request  which models to undeploy
     * @param listener receives the aggregated response after the index update and sync-up are triggered,
     *                 or the failure of the nodes-level execution
     */
    @Override
    public void doExecute(Task task, MLUndeployModelNodesRequest request, ActionListener<MLUndeployModelNodesResponse> listener) {
        ActionListener<MLUndeployModelNodesResponse> wrappedListener = ActionListenerHelper.wrap(
            undeployModelNodesResponse -> processUndeployModelResponseAndUpdate(undeployModelNodesResponse, listener),
            listener::onFailure
        );
        super.doExecute(task, request, wrappedListener);
    }

    /**
     * Merges the per-node undeploy results and persists the resulting model state.
     *
     * <p>For every model that at least one node reported as undeployed, the model document is updated:
     * <ul>
     *   <li>If no known worker node remains, the model is marked {@link MLModelState#UNDEPLOYED} and its
     *   planning worker nodes and counts are cleared.</li>
     *   <li>Otherwise {@code deploy_to_all_nodes} is set to false and the planning worker nodes and counts are
     *   set to the remaining nodes.</li>
     * </ul>
     * A sync-up carrying the removed nodes is then sent to all nodes.
     *
     * @param undeployModelNodesResponse aggregated per-node responses
     * @param listener                   answered with {@code undeployModelNodesResponse}
     */
    void processUndeployModelResponseAndUpdate(MLUndeployModelNodesResponse undeployModelNodesResponse, ActionListener<MLUndeployModelNodesResponse> listener) {
        List<MLUndeployModelNodeResponse> responses = undeployModelNodesResponse.getNodes();
        if (responses == null || responses.isEmpty()) {
            listener.onResponse(undeployModelNodesResponse);
            return;
        }

        // model id -> ids of the nodes that reported the model as undeployed
        Map<String, List<String>> actualRemovedNodesMap = new HashMap<>();
        // model id -> longest worker-node list any node reported before removal
        Map<String, String[]> modelWorkNodesBeforeRemoval = new HashMap<>();
        for (MLUndeployModelNodeResponse nodeResponse : responses) {
            mergeWorkerNodesBeforeRemoval(modelWorkNodesBeforeRemoval, nodeResponse.getModelWorkerNodeBeforeRemoval());
            Map<String, String> modelUndeployStatus = nodeResponse.getModelUndeployStatus();
            if (modelUndeployStatus == null) {
                continue;
            }
            for (Map.Entry<String, String> entry : modelUndeployStatus.entrySet()) {
                if (UNDEPLOYED_STATUS.equals(entry.getValue())) {
                    actualRemovedNodesMap.computeIfAbsent(entry.getKey(), ignored -> new ArrayList<>()).add(nodeResponse.getNode().getId());
                }
            }
        }

        MLSyncUpInput syncUpInput = MLSyncUpInput.builder().removedWorkerNodes(covertRemoveNodesMapForSyncUp(actualRemovedNodesMap)).build();
        MLSyncUpNodesRequest syncUpRequest = new MLSyncUpNodesRequest(nodeFilter.getAllNodes(), syncUpInput);
        // Index writes and sync-up are issued under a stashed thread context.
        // NOTE(review): presumably this runs them with system rather than caller privileges; confirm.
        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
            if (actualRemovedNodesMap.isEmpty()) {
                syncUpUndeployedModels(syncUpRequest);
                listener.onResponse(undeployModelNodesResponse);
                return;
            }

            BulkRequest bulkRequest = new BulkRequest();
            bulkRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
            Map<String, Boolean> deployToAllNodes = new HashMap<>();
            for (Map.Entry<String, List<String>> entry : actualRemovedNodesMap.entrySet()) {
                String modelId = entry.getKey();
                Map<String, Object> updateDocument = buildModelUpdateDocument(modelId, modelWorkNodesBeforeRemoval.get(modelId), entry.getValue(), deployToAllNodes);
                UpdateRequest updateRequest = new UpdateRequest();
                updateRequest.index(ML_MODEL_INDEX);
                updateRequest.id(modelId);
                updateRequest.doc(updateDocument);
                bulkRequest.add(updateRequest);
            }
            // The sync-up request already references syncUpInput, so this is visible to it.
            syncUpInput.setDeployToAllNodes(deployToAllNodes);

            // A bulk failure is only logged. The runAfter callback still triggers the sync-up and answers the caller.
            client.bulk(bulkRequest, ActionListenerHelper.runAfter(
                ActionListenerHelper.wrap(
                    r -> log.debug("updated model state as undeployed for : {}", Arrays.toString(actualRemovedNodesMap.keySet().toArray(new String[0]))),
                    e -> log.error("Failed to update model state as undeployed", e)
                ),
                () -> {
                    syncUpUndeployedModels(syncUpRequest);
                    listener.onResponse(undeployModelNodesResponse);
                }
            ));
        }
    }

    /**
     * Folds one node's pre-removal worker-node snapshot into {@code merged}.
     *
     * <p>For each model it keeps the longest list reported by any node. A {@code null} snapshot, or a {@code null}
     * list (reported for models that were not deployed on that node), is ignored.
     */
    private static void mergeWorkerNodesBeforeRemoval(Map<String, String[]> merged, Map<String, String[]> nodeSnapshot) {
        if (nodeSnapshot == null) {
            return;
        }
        for (Map.Entry<String, String[]> entry : nodeSnapshot.entrySet()) {
            String[] workerNodes = entry.getValue();
            if (workerNodes == null) {
                continue;
            }
            String[] known = merged.get(entry.getKey());
            if (known == null || known.length < workerNodes.length) {
                merged.put(entry.getKey(), workerNodes);
            }
        }
    }

    /**
     * Builds the partial model document to write after {@code removedNodes} dropped {@code modelId}.
     *
     * <p>The decision uses the set of worker nodes that remain, not a count comparison. This means the model is
     * never left deployed with an empty plan, and a missing snapshot no longer throws.
     *
     * @param modelId                  the model being updated
     * @param workerNodesBeforeRemoval merged pre-removal worker nodes. May be {@code null} when no node reported
     *                                 a snapshot; it is then treated as "no known remaining workers".
     * @param removedNodes             ids of the nodes that reported the model as undeployed
     * @param deployToAllNodes         receives {@code modelId -> false} for partial undeploys
     * @return field name to new value
     */
    private static Map<String, Object> buildModelUpdateDocument(String modelId, String[] workerNodesBeforeRemoval, List<String> removedNodes, Map<String, Boolean> deployToAllNodes) {
        List<String> remainingWorkerNodes = workerNodesBeforeRemoval == null
            ? List.of()
            : Arrays.stream(workerNodesBeforeRemoval).filter(node -> !removedNodes.contains(node)).collect(Collectors.toList());
        Map<String, Object> updateDocument = new HashMap<>();
        if (remainingWorkerNodes.isEmpty()) {
            // Undeployed from every known worker node.
            updateDocument.put("planning_worker_nodes", List.of());
            updateDocument.put("planning_worker_node_count", 0);
            updateDocument.put("current_worker_node_count", 0);
            updateDocument.put("model_state", MLModelState.UNDEPLOYED);
        } else {
            // Partial undeploy: the remaining nodes become the new deployment plan.
            updateDocument.put("deploy_to_all_nodes", false);
            updateDocument.put("planning_worker_nodes", remainingWorkerNodes);
            updateDocument.put("planning_worker_node_count", remainingWorkerNodes.size());
            updateDocument.put("current_worker_node_count", remainingWorkerNodes.size());
            deployToAllNodes.put(modelId, false);
        }
        return updateDocument;
    }

    /** Wraps the per-node responses and failures into the aggregated response for this cluster. */
    @Override
    protected MLUndeployModelNodesResponse newResponse(MLUndeployModelNodesRequest nodesRequest, List<MLUndeployModelNodeResponse> responses, List<FailedNodeException> failures) {
        return new MLUndeployModelNodesResponse(clusterService.getClusterName(), responses, failures);
    }

    /**
     * Converts {@code model id -> list of removed node ids} into the array form expected by
     * {@code MLSyncUpInput.removedWorkerNodes}, logging each entry at debug level.
     */
    private Map<String, String[]> covertRemoveNodesMapForSyncUp(Map<String, List<String>> actualRemovedNodesMap) {
        Map<String, String[]> removedNodesMap = new HashMap<>();
        actualRemovedNodesMap.forEach((modelId, nodeIds) -> {
            String[] nodeIdArray = nodeIds.toArray(new String[0]);
            removedNodesMap.put(modelId, nodeIdArray);
            log.debug("removed node for model: {}, {}", modelId, Arrays.toString(nodeIdArray));
        });
        return removedNodesMap;
    }

    /** Fires the sync-up request asynchronously. The outcome is only logged. */
    private void syncUpUndeployedModels(MLSyncUpNodesRequest syncUpRequest) {
        client.execute(MLSyncUpAction.INSTANCE, syncUpRequest, ActionListenerHelper.wrap(
            r -> log.debug("sync up removed nodes successfully"),
            e -> log.error("failed to sync up removed node", e)
        ));
    }

    /** Creates the per-node request, which carries the whole nodes-level request. */
    @Override
    protected MLUndeployModelNodeRequest newNodeRequest(MLUndeployModelNodesRequest request) {
        return new MLUndeployModelNodeRequest(request);
    }

    /** Deserializes a per-node response from the transport stream. */
    @Override
    protected MLUndeployModelNodeResponse newNodeResponse(StreamInput in) throws IOException {
        return new MLUndeployModelNodeResponse(in);
    }

    /** Per-node operation: undeploys the requested models on this node. */
    @Override
    protected MLUndeployModelNodeResponse nodeOperation(MLUndeployModelNodeRequest request) {
        return createUndeployModelNodeResponse(request.getMlUndeployModelNodesRequest());
    }

    /**
     * Undeploys models on the local node and reports the result.
     *
     * <p>Before undeploying, it records each target model's worker nodes. The targets are the requested ids, or
     * every id from {@code getAllModelIds()} when none are given. The executing-task stat is decremented in a
     * {@code finally} block, so an exception cannot leave it inflated.
     *
     * @param nodesRequest the nodes-level request carrying the model ids (may be null or empty)
     * @return local node, per-model undeploy status, and per-model worker nodes before removal
     */
    private MLUndeployModelNodeResponse createUndeployModelNodeResponse(MLUndeployModelNodesRequest nodesRequest) {
        mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment();
        try {
            String[] modelIds = nodesRequest.getModelIds();
            boolean specifiedModelIds = modelIds != null && modelIds.length > 0;
            String[] removedModelIds = specifiedModelIds ? modelIds : mlModelManager.getAllModelIds();

            Map<String, String[]> modelWorkerNodesMap = new HashMap<>();
            if (removedModelIds != null) {
                for (String modelId : removedModelIds) {
                    FunctionName functionName = mlModelManager.getModelFunctionName(modelId);
                    modelWorkerNodesMap.put(modelId, mlModelManager.getWorkerNodes(modelId, functionName));
                }
            }

            // The request's ids are passed through unchanged, which may be null or empty.
            // NOTE(review): presumably undeployModel treats that as "all models"; confirm.
            Map<String, String> modelUndeployStatus = mlModelManager.undeployModel(modelIds);
            return new MLUndeployModelNodeResponse(clusterService.localNode(), modelUndeployStatus, modelWorkerNodesMap);
        } finally {
            mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).decrement();
        }
    }
}

