/*
 * Decompiled with CFR 0.152.
 */
package io.lucenia.ml.common.action.mcpserver;

import io.lucenia.ml.common.action.mcpserver.McpToolsHelper;
import io.lucenia.ml.common.engine.systemindices.MLIndicesHandler;
import io.lucenia.ml.common.transport.mcpserver.action.MLMcpToolsRegisterOnNodesAction;
import io.lucenia.ml.common.transport.mcpserver.requests.McpToolBaseInput;
import io.lucenia.ml.common.transport.mcpserver.requests.register.MLMcpToolsRegisterNodesRequest;
import io.lucenia.ml.common.transport.mcpserver.requests.register.McpToolRegisterInput;
import io.lucenia.ml.common.transport.mcpserver.responses.register.MLMcpToolsRegisterNodesResponse;
import io.skylite.SkyliteException;
import io.skylite.common.action.ActionListener;
import io.skylite.core.action.ActionFilters;
import io.skylite.core.action.ActionRequest;
import io.skylite.core.action.ActionType;
import io.skylite.core.action.DocWriteRequest;
import io.skylite.core.action.WriteRequest;
import io.skylite.core.action.bulk.BulkItemResponse;
import io.skylite.core.action.bulk.BulkRequest;
import io.skylite.core.action.bulk.BulkResponse;
import io.skylite.core.action.index.IndexRequest;
import io.skylite.core.action.support.HandledTransportAction;
import io.skylite.core.client.Client;
import io.skylite.core.cluster.service.ClusterService;
import io.skylite.core.common.Strings;
import io.skylite.core.common.concurrent.ThreadContext;
import io.skylite.core.common.inject.Inject;
import io.skylite.core.tasks.Task;
import io.skylite.core.threadpool.ThreadPool;
import io.skylite.core.transport.TransportService;
import io.skylite.core.xcontent.NamedXContentRegistry;
import io.skylite.ml.common.cluster.DiscoveryNodeHelper;
import io.skylite.ml.common.settings.MLCommonsSettings;
import io.skylite.ml.common.settings.MLFeatureEnabledSetting;
import io.skylite.ml.common.systemindices.MLIndex;
import java.time.Instant;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Transport action that registers MCP tools.
 *
 * <p>Flow:
 * <ol>
 *   <li>Reject the request when the MCP server feature is disabled.</li>
 *   <li>Make sure the MCP tools system index exists.</li>
 *   <li>Reject the request if any requested tool name is already persisted.</li>
 *   <li>Persist every tool with a CREATE bulk request (immediate refresh).</li>
 *   <li>Dispatch {@code MLMcpToolsRegisterOnNodesAction} for the persisted tools so they are
 *       registered into MCP server memory.</li>
 * </ol>
 * Partial persistence failures are accumulated into one error message that is reported together
 * with the outcome of the in-memory registration.
 */
public class TransportMcpToolsRegisterAction
extends HandledTransportAction<ActionRequest, MLMcpToolsRegisterNodesResponse> {
    /** Attribute key under which a tool may carry its input JSON schema. */
    private static final String INPUT_SCHEMA_KEY = "input_schema";
    private static final Logger log = LogManager.getLogger(TransportMcpToolsRegisterAction.class);

    TransportService transportService;
    ClusterService clusterService;
    ThreadPool threadPool;
    Client client;
    NamedXContentRegistry xContentRegistry;
    DiscoveryNodeHelper nodeFilter;
    private final MLFeatureEnabledSetting mlFeatureEnabledSetting;
    private final MLIndicesHandler mlIndicesHandler;
    private final McpToolsHelper mcpToolsHelper;

    @Inject
    public TransportMcpToolsRegisterAction(TransportService transportService, ActionFilters actionFilters, ClusterService clusterService, ThreadPool threadPool, Client client, NamedXContentRegistry xContentRegistry, DiscoveryNodeHelper nodeFilter, MLIndicesHandler mlIndicesHandler, McpToolsHelper mcpToolsHelper, MLFeatureEnabledSetting mlFeatureEnabledSetting) {
        super("cluster:admin/lucenia/ml/mcp_tools/register", transportService, actionFilters, MLMcpToolsRegisterNodesRequest::new);
        this.transportService = transportService;
        this.clusterService = clusterService;
        this.threadPool = threadPool;
        this.client = client;
        this.xContentRegistry = xContentRegistry;
        this.nodeFilter = nodeFilter;
        this.mlIndicesHandler = mlIndicesHandler;
        this.mcpToolsHelper = mcpToolsHelper;
        this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
    }

    /**
     * Entry point: validates that the MCP server is enabled, ensures the MCP tools index exists,
     * rejects tools whose names are already persisted, and otherwise persists and registers them.
     *
     * @param task     the task wrapping this request (unused)
     * @param request  expected to be an {@link MLMcpToolsRegisterNodesRequest}
     * @param listener notified exactly once with the node-registration response or a failure
     */
    @Override
    protected void doExecute(Task task, ActionRequest request, ActionListener<MLMcpToolsRegisterNodesResponse> listener) {
        if (!this.mlFeatureEnabledSetting.isMcpServerEnabled()) {
            listener.onFailure(new SkyliteException(MLCommonsSettings.ML_COMMONS_MCP_SERVER_DISABLED_MESSAGE, new Object[0]));
            return;
        }
        try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
            // Cast inside the try so a malformed request is reported through the listener.
            MLMcpToolsRegisterNodesRequest registerNodesRequest = (MLMcpToolsRegisterNodesRequest) request;
            // Restore the caller's thread context before the caller's listener runs.
            ActionListener<MLMcpToolsRegisterNodesResponse> restoreListener = ActionListener.runBefore(listener, context::restore);

            ActionListener<List<McpToolRegisterInput>> searchResultListener = ActionListener.wrap(existing -> {
                if (existing.isEmpty()) {
                    this.indexMcpTools(registerNodesRequest, restoreListener);
                    return;
                }
                Set<String> requestedNames = registerNodesRequest.getMcpTools().stream().map(McpToolBaseInput::getName).collect(Collectors.toSet());
                List<String> existingTools = existing.stream().map(McpToolBaseInput::getName).filter(requestedNames::contains).toList();
                String exceptionMessage = String.format(Locale.ROOT, "Unable to register tools: %s as they already exist", existingTools);
                log.warn(exceptionMessage);
                restoreListener.onFailure(new IllegalArgumentException(exceptionMessage));
            }, e -> {
                log.error("Failed to search mcp tools index", e);
                restoreListener.onFailure(e);
            });

            ActionListener<Boolean> initIndexListener = ActionListener.wrap(
                created -> this.mcpToolsHelper.searchToolsWithVersion(
                    registerNodesRequest.getMcpTools().stream().map(McpToolBaseInput::getName).toList(), searchResultListener),
                e -> {
                    log.error("Failed to create .plugins-ml-mcp-tools index", e);
                    restoreListener.onFailure(e);
                });
            this.mlIndicesHandler.initMLMcpToolsIndex(initIndexListener);
        } catch (Exception e) {
            log.error("Failed to register mcp tools", e);
            listener.onFailure(e);
        }
    }

    /**
     * Persists every tool in the request into the MCP tools system index with one bulk request,
     * then continues with {@link #handleBulkResponse}.
     *
     * <p>The caller-submitted input schemas are collected into a map that lives only for this
     * request. They are used solely to enrich "Invalid schema" error messages.
     */
    private void indexMcpTools(MLMcpToolsRegisterNodesRequest registerNodesRequest, ActionListener<MLMcpToolsRegisterNodesResponse> listener) {
        try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
            ActionListener<MLMcpToolsRegisterNodesResponse> restoreListener = ActionListener.runBefore(listener, context::restore);
            // Request-scoped (not an instance field): the action is shared across concurrent
            // requests, so a field would race and would grow without bound.
            Map<String, Object> submittedSchemas = new HashMap<>();
            BulkRequest bulkRequest = new BulkRequest();
            bulkRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
            for (McpToolRegisterInput mcpTool : registerNodesRequest.getMcpTools()) {
                bulkRequest.add(buildIndexRequest(mcpTool, submittedSchemas));
            }
            ActionListener<BulkResponse> indexResultListener = ActionListener.wrap(
                bulkResponse -> this.handleBulkResponse(registerNodesRequest, bulkResponse, submittedSchemas, restoreListener),
                e -> {
                    // Trailing throwable argument keeps the stack trace in the log.
                    log.error("Failed to persist mcp tools into system index because exception: {}", e.getMessage(), e);
                    restoreListener.onFailure(e);
                });
            this.client.bulk(bulkRequest, indexResultListener);
        } catch (Exception e) {
            log.error("Failed to register mcp tools", e);
            listener.onFailure(e);
        }
    }

    /**
     * Builds the CREATE index request for one tool.
     *
     * <p>If the tool's attributes contain an input schema, the schema as submitted is recorded in
     * {@code submittedSchemas}. A non-string schema is serialized to JSON text for indexing, on a
     * copy of the attributes map so the request object is left untouched.
     */
    private static IndexRequest buildIndexRequest(McpToolRegisterInput mcpTool, Map<String, Object> submittedSchemas) {
        String toolId = toolIdOf(mcpTool);
        IndexRequest indexRequest = new IndexRequest(MLIndex.MCP_TOOLS.getIndexName());
        indexRequest.opType(DocWriteRequest.OpType.CREATE);
        indexRequest.id(toolId);

        Map<String, Object> attributes = mcpTool.getAttributes();
        if (attributes != null && attributes.containsKey(INPUT_SCHEMA_KEY)) {
            Object originalSchema = attributes.get(INPUT_SCHEMA_KEY);
            submittedSchemas.put(toolId, originalSchema);
            if (!(originalSchema instanceof String)) {
                Map<String, Object> serialized = new HashMap<>(attributes);
                serialized.put(INPUT_SCHEMA_KEY, Strings.toJson((Object) originalSchema));
                attributes = serialized;
            }
        }

        Map<String, Object> source = new HashMap<>();
        source.put("name", mcpTool.getName());
        source.put("type", mcpTool.getType());
        source.put("parameters", mcpTool.getParameters() != null ? mcpTool.getParameters() : Map.of());
        source.put("attributes", attributes);
        source.put("description", mcpTool.getDescription());
        source.put("create_time", Instant.now().toEpochMilli());
        indexRequest.source(source);
        return indexRequest;
    }

    /**
     * Handles the bulk persistence result.
     *
     * <ul>
     *   <li>No failures: register all tools on the nodes.</li>
     *   <li>Some failures: drop the failed tools from the request, then register the persisted ones.
     *       The persistence errors are carried along in the error message.</li>
     *   <li>Everything failed: fail the listener with the persistence errors.</li>
     * </ul>
     */
    private void handleBulkResponse(MLMcpToolsRegisterNodesRequest registerNodesRequest, BulkResponse bulkResponse, Map<String, Object> submittedSchemas, ActionListener<MLMcpToolsRegisterNodesResponse> listener) {
        if (!bulkResponse.hasFailures()) {
            Set<String> persistedTools = registerNodesRequest.getMcpTools().stream()
                .map(TransportMcpToolsRegisterAction::toolIdOf)
                .filter(id -> id != null)
                .collect(Collectors.toSet());
            this.registerMcpToolsOnNodes(new StringBuilder(), this.updateVersion(registerNodesRequest, bulkResponse), persistedTools, submittedSchemas, listener);
            return;
        }

        Set<String> indexSucceedTools = new HashSet<>();
        Map<String, String> indexFailedTools = new HashMap<>();
        for (BulkItemResponse item : bulkResponse.getItems()) {
            if (item.isFailed()) {
                indexFailedTools.put(item.getId(), item.getFailure().getMessage());
            } else {
                indexSucceedTools.add(item.getId());
            }
        }
        // Tools that could not be persisted must not be registered in memory either. Match on the
        // same id used for indexing, so unnamed tools (indexed under their type) are handled safely.
        registerNodesRequest.getMcpTools().removeIf(tool -> indexFailedTools.containsKey(toolIdOf(tool)));

        StringBuilder errMsgBuilder = new StringBuilder();
        for (Map.Entry<String, String> failed : indexFailedTools.entrySet()) {
            errMsgBuilder.append(String.format(Locale.ROOT, "Failed to persist mcp tool: %s into system index with error: %s", failed.getKey(), failed.getValue()));
            errMsgBuilder.append("\n");
        }
        log.error(errMsgBuilder.toString());

        if (indexSucceedTools.isEmpty()) {
            // hasFailures() guarantees at least one entry, so there is a trailing newline to drop.
            listener.onFailure(new SkyliteException(errMsgBuilder.deleteCharAt(errMsgBuilder.length() - 1).toString(), new Object[0]));
            return;
        }
        this.registerMcpToolsOnNodes(errMsgBuilder, this.updateVersion(registerNodesRequest, bulkResponse), indexSucceedTools, submittedSchemas, listener);
    }

    /**
     * Copies the document version assigned by the bulk response onto each tool in the request.
     * Tools are matched by the same id used for indexing.
     *
     * @return the same (mutated) request, for call chaining
     */
    private MLMcpToolsRegisterNodesRequest updateVersion(MLMcpToolsRegisterNodesRequest registerNodesRequest, BulkResponse bulkResponse) {
        Map<String, Long> versions = new HashMap<>();
        for (BulkItemResponse item : bulkResponse.getItems()) {
            if (!item.isFailed()) {
                versions.put(item.getId(), item.getResponse().getVersion());
            }
        }
        registerNodesRequest.getMcpTools().forEach(tool -> tool.setVersion(versions.get(toolIdOf(tool))));
        return registerNodesRequest;
    }

    /**
     * Executes {@code MLMcpToolsRegisterOnNodesAction} for the persisted tools and completes the
     * listener.
     *
     * <p>The listener fails if any node reported a failure, or if {@code errMsgBuilder} already
     * holds persistence errors from an earlier step. Otherwise it receives the node response.
     *
     * @param errMsgBuilder      accumulated persistence errors, each newline-terminated; may be empty
     * @param indexSucceedTools  ids of the tools that were persisted successfully
     * @param submittedSchemas   caller-submitted input schemas of this request, keyed by tool id
     */
    private void registerMcpToolsOnNodes(StringBuilder errMsgBuilder, MLMcpToolsRegisterNodesRequest registerNodesRequest, Set<String> indexSucceedTools, Map<String, Object> submittedSchemas, ActionListener<MLMcpToolsRegisterNodesResponse> listener) {
        try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
            ActionListener<MLMcpToolsRegisterNodesResponse> restoreListener = ActionListener.runBefore(listener, context::restore);
            ActionListener<MLMcpToolsRegisterNodesResponse> addToMemoryResultListener = ActionListener.wrap(r -> {
                if (r.failures() != null && !r.failures().isEmpty()) {
                    String schemaInfo = describeSubmittedSchemas(indexSucceedTools, submittedSchemas);
                    r.failures().forEach(x -> {
                        String errorMsg = x.getRootCause().getMessage();
                        String enhancedError = errorMsg;
                        if (errorMsg != null && errorMsg.startsWith("Invalid schema:")) {
                            enhancedError = errorMsg + "\nPossible issues: \n  - Schema must be a valid JSON Schema (draft-07 or later)\n  - Property types must be strings (e.g., \"type\":\"string\", not \"type\":123)\n  - Check for missing quotes around property values\n  - Ensure 'required' is an array of property names" + schemaInfo;
                        }
                        errMsgBuilder.append(String.format(Locale.ROOT, "Tools: %s are persisted successfully but failed to register to mcp server memory with error: %s", indexSucceedTools, enhancedError));
                        errMsgBuilder.append("\n");
                    });
                    errMsgBuilder.deleteCharAt(errMsgBuilder.length() - 1);
                    log.error(errMsgBuilder.toString());
                    restoreListener.onFailure(new SkyliteException(errMsgBuilder.toString(), new Object[0]));
                } else if (errMsgBuilder.length() == 0) {
                    restoreListener.onResponse(r);
                } else {
                    restoreListener.onFailure(new SkyliteException(errMsgBuilder.deleteCharAt(errMsgBuilder.length() - 1).toString(), new Object[0]));
                }
            }, e -> {
                errMsgBuilder.append(String.format(Locale.ROOT, "Tools are persisted successfully but failed to register to mcp server memory with error: %s", e.getMessage()));
                log.error(errMsgBuilder.toString(), e);
                restoreListener.onFailure(new SkyliteException(errMsgBuilder.toString(), new Object[0]));
            });
            this.client.execute(MLMcpToolsRegisterOnNodesAction.INSTANCE, registerNodesRequest, addToMemoryResultListener);
        } catch (Exception e) {
            log.error("Failed to register mcp tools on nodes", e);
            listener.onFailure(e);
        }
    }

    /**
     * Renders the submitted input schemas of the given tools as a debugging suffix for an
     * "Invalid schema" error.
     *
     * <p>Returns an empty string when none of the tools submitted a schema. A single schema is shown
     * as-is. Several schemas are shown keyed by tool id, so the reader can tell them apart.
     */
    private static String describeSubmittedSchemas(Set<String> toolIds, Map<String, Object> submittedSchemas) {
        Map<String, Object> relevant = new HashMap<>();
        for (String toolId : toolIds) {
            Object schema = submittedSchemas.get(toolId);
            if (schema != null) {
                relevant.put(toolId, schema);
            }
        }
        if (relevant.isEmpty()) {
            return "";
        }
        Object shown = relevant.size() == 1 ? relevant.values().iterator().next() : relevant;
        return String.format(Locale.ROOT, " [Original schema submitted: %s]", shown);
    }

    /**
     * Returns the document id used for a tool in the MCP tools index: its name, or its type when
     * the tool is unnamed.
     */
    private static String toolIdOf(McpToolRegisterInput tool) {
        return tool.getName() != null ? tool.getName() : tool.getType();
    }
}

