/*
 * Decompiled with CFR 0.152.
 */
package io.lucenia.ml.common.action.mcpserver;

import io.lucenia.ml.common.action.mcpserver.McpToolsHelper;
import io.lucenia.ml.common.transport.mcpserver.action.MLMcpToolsUpdateOnNodesAction;
import io.lucenia.ml.common.transport.mcpserver.requests.McpToolBaseInput;
import io.lucenia.ml.common.transport.mcpserver.requests.register.McpToolRegisterInput;
import io.lucenia.ml.common.transport.mcpserver.requests.update.MLMcpToolsUpdateNodesRequest;
import io.lucenia.ml.common.transport.mcpserver.requests.update.McpToolUpdateInput;
import io.lucenia.ml.common.transport.mcpserver.responses.update.MLMcpToolsUpdateNodesResponse;
import io.skylite.SkyliteException;
import io.skylite.common.action.ActionListener;
import io.skylite.common.xcontent.json.JsonXContent;
import io.skylite.core.action.ActionFilters;
import io.skylite.core.action.ActionRequest;
import io.skylite.core.action.ActionType;
import io.skylite.core.action.WriteRequest;
import io.skylite.core.action.bulk.BulkItemResponse;
import io.skylite.core.action.bulk.BulkRequest;
import io.skylite.core.action.bulk.BulkResponse;
import io.skylite.core.action.search.SearchResponse;
import io.skylite.core.action.support.HandledTransportAction;
import io.skylite.core.action.update.UpdateRequest;
import io.skylite.core.client.Client;
import io.skylite.core.cluster.service.ClusterService;
import io.skylite.core.common.concurrent.ThreadContext;
import io.skylite.core.common.inject.Inject;
import io.skylite.core.tasks.Task;
import io.skylite.core.threadpool.ThreadPool;
import io.skylite.core.transport.TransportService;
import io.skylite.core.xcontent.DeprecationHandler;
import io.skylite.core.xcontent.LoggingDeprecationHandler;
import io.skylite.core.xcontent.NamedXContentRegistry;
import io.skylite.core.xcontent.XContentParser;
import io.skylite.core.xcontent.XContentParserUtils;
import io.skylite.ml.common.cluster.DiscoveryNodeHelper;
import io.skylite.ml.common.settings.MLCommonsSettings;
import io.skylite.ml.common.settings.MLFeatureEnabledSetting;
import io.skylite.ml.common.systemindices.MLIndex;
import java.io.IOException;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Transport action behind {@code cluster:admin/lucenia/ml/mcp_tools/update}.
 *
 * <p>Updating MCP tools happens in three asynchronous stages. Each stage runs with a stashed thread
 * context that is restored before the caller's listener is notified:
 * <ol>
 *   <li>search the MCP tools system index for every requested tool, capturing each document's
 *       sequence number and primary term;</li>
 *   <li>bulk-update the matching documents, guarded by those sequence numbers / primary terms
 *       (optimistic concurrency control);</li>
 *   <li>send the merged tool definitions to the nodes through {@link MLMcpToolsUpdateOnNodesAction}.</li>
 * </ol>
 *
 * <p>Partial index failures are tolerated. Tools that were written successfully are still sent to the
 * nodes, and the index failure messages for the others are reported to the caller as a failure.
 */
public class TransportMcpToolsUpdateAction
extends HandledTransportAction<ActionRequest, MLMcpToolsUpdateNodesResponse> {
    private static final Logger log = LogManager.getLogger(TransportMcpToolsUpdateAction.class);
    TransportService transportService;
    ClusterService clusterService;
    ThreadPool threadPool;
    Client client;
    NamedXContentRegistry xContentRegistry;
    DiscoveryNodeHelper nodeFilter;
    private final MLFeatureEnabledSetting mlFeatureEnabledSetting;
    private final McpToolsHelper mcpToolsHelper;

    @Inject
    public TransportMcpToolsUpdateAction(TransportService transportService, ActionFilters actionFilters, ClusterService clusterService, ThreadPool threadPool, Client client, NamedXContentRegistry xContentRegistry, DiscoveryNodeHelper nodeFilter, McpToolsHelper mcpToolsHelper, MLFeatureEnabledSetting mlFeatureEnabledSetting) {
        super("cluster:admin/lucenia/ml/mcp_tools/update", transportService, actionFilters, MLMcpToolsUpdateNodesRequest::new);
        this.transportService = transportService;
        this.clusterService = clusterService;
        this.threadPool = threadPool;
        this.client = client;
        this.xContentRegistry = xContentRegistry;
        this.nodeFilter = nodeFilter;
        this.mcpToolsHelper = mcpToolsHelper;
        this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
    }

    /**
     * Validates preconditions and then looks up the requested tools in the MCP tools index.
     *
     * <p>Fails fast when the MCP server feature is disabled or the MCP tools index does not exist.
     *
     * @param task     the task executing this action (unused here)
     * @param request  cast to {@link MLMcpToolsUpdateNodesRequest}
     * @param listener receives the node-level update response, or a failure
     */
    @Override
    protected void doExecute(Task task, ActionRequest request, ActionListener<MLMcpToolsUpdateNodesResponse> listener) {
        if (!this.mlFeatureEnabledSetting.isMcpServerEnabled()) {
            listener.onFailure(new SkyliteException(MLCommonsSettings.ML_COMMONS_MCP_SERVER_DISABLED_MESSAGE, new Object[0]));
            return;
        }
        if (!this.clusterService.state().metadata().hasIndex(MLIndex.MCP_TOOLS.getIndexName())) {
            listener.onFailure(new SkyliteException("MCP tools index doesn't exist", new Object[0]));
            return;
        }
        MLMcpToolsUpdateNodesRequest updateNodesRequest = (MLMcpToolsUpdateNodesRequest) request;
        List<String> requestedToolNames = updateNodesRequest.getMcpTools().stream().map(McpToolBaseInput::getName).toList();
        try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
            // Restore the caller's thread context before handing any result back.
            ActionListener<MLMcpToolsUpdateNodesResponse> restoreListener = ActionListener.runBefore(listener, context::restore);
            ActionListener<SearchResponse> searchResultListener = ActionListener.wrap(
                searchResult -> this.handleSearchResult(updateNodesRequest, requestedToolNames, searchResult, restoreListener),
                searchFailure -> {
                    log.error("Failed to search mcp tools index", searchFailure);
                    restoreListener.onFailure(searchFailure);
                }
            );
            this.mcpToolsHelper.searchToolsWithPrimaryTermAndSeqNo(requestedToolNames, searchResultListener);
        }
        catch (Exception e) {
            log.error("Failed to update mcp tools", e);
            listener.onFailure(e);
        }
    }

    /**
     * Parses the searched tool documents and, if every requested tool was found, proceeds to the index update.
     *
     * <p>Each failure path notifies {@code listener} once and then stops. The original implementation
     * kept going after a parse failure and could notify the listener a second time.
     *
     * @param updateNodesRequest the incoming update request
     * @param requestedToolNames names of the tools named in the request
     * @param searchResult       hits from the MCP tools index, carrying seqNo / primaryTerm
     * @param listener           context-restoring listener of the caller
     */
    private void handleSearchResult(MLMcpToolsUpdateNodesRequest updateNodesRequest, List<String> requestedToolNames, SearchResponse searchResult, ActionListener<MLMcpToolsUpdateNodesResponse> listener) {
        var hits = Objects.requireNonNull(searchResult.getHits().getHits());
        if (hits.length == 0) {
            listener.onFailure(new SkyliteException("Failed to update tools as none of them is found in index", new Object[0]));
            return;
        }
        Set<String> missingTools = new HashSet<>(requestedToolNames);
        List<SearchedMcpToolWrapper> searchedMcpToolWrappers = new ArrayList<>(hits.length);
        for (var hit : hits) {
            try (XContentParser parser = JsonXContent.jsonXContent.createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, hit.getSourceAsString())) {
                XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
                McpToolRegisterInput registeredTool = McpToolRegisterInput.parse(parser);
                missingTools.remove(registeredTool.getName());
                searchedMcpToolWrappers.add(new SearchedMcpToolWrapper.SearchedMcpToolWrapperBuilder()
                    .seqNo(hit.getSeqNo())
                    .primaryTerm(hit.getPrimaryTerm())
                    .mcpTool(registeredTool)
                    .build());
            }
            catch (IOException e) {
                log.error("Failed to parse mcp tools configuration", e);
                listener.onFailure(e);
                // Stop here: continuing would notify the listener again.
                return;
            }
        }
        if (!missingTools.isEmpty()) {
            String errMsg = String.format(Locale.ROOT, "Failed to find tools: %s in system index", missingTools);
            log.warn(errMsg);
            listener.onFailure(new SkyliteException(errMsg, new Object[0]));
            return;
        }
        this.updateMcpTools(updateNodesRequest, searchedMcpToolWrappers, listener);
    }

    /**
     * Copies fields the update request left unset from the stored tool definitions.
     *
     * <p>It also stamps each tool with the document version returned by the bulk update. The request is
     * mutated in place and returned.
     *
     * @param updateNodesRequest    request whose tools are completed in place
     * @param updateMcpToolWrappers stored definitions of the tools being updated
     * @param bulkResponse          bulk update response; only non-failed items contribute versions
     * @return {@code updateNodesRequest}, after mutation
     */
    private MLMcpToolsUpdateNodesRequest mergeDocFields(MLMcpToolsUpdateNodesRequest updateNodesRequest, List<SearchedMcpToolWrapper> updateMcpToolWrappers, BulkResponse bulkResponse) {
        Map<String, McpToolRegisterInput> mcpToolsMap = updateMcpToolWrappers.stream().collect(Collectors.toMap(x -> x.getMcpTool().getName(), SearchedMcpToolWrapper::getMcpTool));
        Map<String, Long> versions = Arrays.stream(bulkResponse.getItems()).filter(x -> !x.isFailed()).collect(Collectors.toMap(BulkItemResponse::getId, x -> x.getResponse().getVersion()));
        updateNodesRequest.getMcpTools().forEach(x -> {
            McpToolRegisterInput registerMcpTool = mcpToolsMap.get(x.getName());
            x.setType(registerMcpTool.getType());
            if (x.getAttributes() == null) {
                x.setAttributes(registerMcpTool.getAttributes());
            }
            if (x.getParameters() == null) {
                x.setParameters(registerMcpTool.getParameters());
            }
            if (x.getDescription() == null) {
                x.setDescription(registerMcpTool.getDescription());
            }
            x.setVersion(versions.get(x.getName()));
        });
        return updateNodesRequest;
    }

    /**
     * Writes the requested changes to the MCP tools index as one bulk request.
     *
     * <p>On completion it continues with the node-level update through {@link #handleBulkResponse}.
     *
     * @param updateNodesRequest      the update request
     * @param searchedMcpToolWrappers stored definitions with their seqNo / primaryTerm
     * @param listener                listener to complete
     */
    private void updateMcpTools(MLMcpToolsUpdateNodesRequest updateNodesRequest, List<SearchedMcpToolWrapper> searchedMcpToolWrappers, ActionListener<MLMcpToolsUpdateNodesResponse> listener) {
        try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
            ActionListener<MLMcpToolsUpdateNodesResponse> restoreListener = ActionListener.runBefore(listener, context::restore);
            ActionListener<BulkResponse> updateResultListener = ActionListener.wrap(
                bulkResponse -> this.handleBulkResponse(updateNodesRequest, searchedMcpToolWrappers, bulkResponse, restoreListener),
                bulkFailure -> {
                    // Pass the throwable as the last argument so the stack trace is logged too.
                    log.error("Failed to update mcp tools in system index because exception: {}", bulkFailure.getMessage(), bulkFailure);
                    restoreListener.onFailure(bulkFailure);
                }
            );
            this.client.bulk(this.buildBulkUpdateRequest(updateNodesRequest, searchedMcpToolWrappers), updateResultListener);
        }
        catch (Exception e) {
            log.error("Failed to update mcp tools", e);
            listener.onFailure(e);
        }
    }

    /**
     * Builds one partial-document {@link UpdateRequest} per requested tool.
     *
     * <p>Each request is guarded by the seqNo / primaryTerm observed during the search. The bulk is
     * refreshed immediately.
     */
    private BulkRequest buildBulkUpdateRequest(MLMcpToolsUpdateNodesRequest updateNodesRequest, List<SearchedMcpToolWrapper> searchedMcpToolWrappers) {
        Map<String, SearchedMcpToolWrapper> searchedByToolId = searchedMcpToolWrappers.stream()
            .collect(Collectors.toMap(x -> resolveToolId(x.getMcpTool().getName(), x.getMcpTool().getType()), x -> x));
        BulkRequest bulkRequest = new BulkRequest();
        for (McpToolUpdateInput mcpTool : updateNodesRequest.getMcpTools()) {
            String toolId = resolveToolId(mcpTool.getName(), mcpTool.getType());
            SearchedMcpToolWrapper searched = searchedByToolId.get(toolId);
            UpdateRequest updateRequest = new UpdateRequest(MLIndex.MCP_TOOLS.getIndexName(), toolId);
            updateRequest.setIfSeqNo(searched.getSeqNo().longValue());
            updateRequest.setIfPrimaryTerm(searched.getPrimaryTerm().longValue());
            // Only fields supplied by the caller are overwritten; the update timestamp always changes.
            Map<String, Object> source = new HashMap<>();
            if (mcpTool.getDescription() != null) {
                source.put("description", mcpTool.getDescription());
            }
            if (mcpTool.getParameters() != null) {
                source.put("parameters", mcpTool.getParameters());
            }
            if (mcpTool.getAttributes() != null) {
                source.put("attributes", mcpTool.getAttributes());
            }
            source.put("last_update_time", Instant.now().toEpochMilli());
            updateRequest.doc(source);
            bulkRequest.add(updateRequest);
        }
        bulkRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
        return bulkRequest;
    }

    /**
     * Handles the bulk index update result.
     *
     * <p>If every item succeeded, all tools are sent to the nodes. Otherwise failed tools are dropped from
     * the request, their errors are collected, and only the succeeded tools (if any) are sent to the nodes.
     */
    private void handleBulkResponse(MLMcpToolsUpdateNodesRequest updateNodesRequest, List<SearchedMcpToolWrapper> searchedMcpToolWrappers, BulkResponse bulkResponse, ActionListener<MLMcpToolsUpdateNodesResponse> listener) {
        if (!bulkResponse.hasFailures()) {
            MLMcpToolsUpdateNodesRequest merged = this.mergeDocFields(updateNodesRequest, searchedMcpToolWrappers, bulkResponse);
            Set<String> updatedTools = updateNodesRequest.getMcpTools().stream().map(McpToolBaseInput::getName).collect(Collectors.toUnmodifiableSet());
            this.updateMcpToolsOnNodes(new StringBuilder(), merged, updatedTools, listener);
            return;
        }
        Set<String> updateSucceedTools = new HashSet<>();
        Map<String, String> updateFailedTools = new HashMap<>();
        for (BulkItemResponse item : bulkResponse.getItems()) {
            String itemId = item.getId();
            if (item.isFailed()) {
                updateFailedTools.put(itemId, item.getFailure().getMessage());
                // Bulk ids were derived with resolveToolId, so match on the same key.
                updateNodesRequest.getMcpTools().removeIf(x -> itemId.equals(resolveToolId(x.getName(), x.getType())));
                searchedMcpToolWrappers.removeIf(x -> itemId.equals(resolveToolId(x.getMcpTool().getName(), x.getMcpTool().getType())));
            } else {
                updateSucceedTools.add(itemId);
            }
        }
        StringBuilder errMsgBuilder = new StringBuilder();
        for (Map.Entry<String, String> failedTool : updateFailedTools.entrySet()) {
            errMsgBuilder.append(String.format(Locale.ROOT, "Failed to update mcp tool: %s in system index with error: %s", failedTool.getKey(), failedTool.getValue()));
            errMsgBuilder.append("\n");
        }
        log.error(errMsgBuilder.toString());
        if (updateSucceedTools.isEmpty()) {
            // Drop the trailing newline before surfacing the aggregated message.
            listener.onFailure(new SkyliteException(errMsgBuilder.deleteCharAt(errMsgBuilder.length() - 1).toString(), new Object[0]));
        } else {
            this.updateMcpToolsOnNodes(errMsgBuilder, this.mergeDocFields(updateNodesRequest, searchedMcpToolWrappers, bulkResponse), updateSucceedTools, listener);
        }
    }

    /**
     * Sends the updated tool definitions to the nodes through {@link MLMcpToolsUpdateOnNodesAction}.
     *
     * <p>The listener succeeds only if there were no earlier index errors (empty {@code errMsgBuilder}) and no
     * node failures. Otherwise it fails with the accumulated messages.
     *
     * @param errMsgBuilder           index-stage errors collected so far, newline-terminated; appended to here
     * @param toolsUpdateNodesRequest merged request to broadcast
     * @param indexSucceedTools       ids of tools written successfully to the index, used in error messages
     * @param listener                listener to complete
     */
    private void updateMcpToolsOnNodes(StringBuilder errMsgBuilder, MLMcpToolsUpdateNodesRequest toolsUpdateNodesRequest, Set<String> indexSucceedTools, ActionListener<MLMcpToolsUpdateNodesResponse> listener) {
        try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
            ActionListener<MLMcpToolsUpdateNodesResponse> restoreListener = ActionListener.runBefore(listener, context::restore);
            ActionListener<MLMcpToolsUpdateNodesResponse> addToMemoryResultListener = ActionListener.wrap(r -> {
                if (r.failures() != null && !r.failures().isEmpty()) {
                    r.failures().forEach(x -> {
                        errMsgBuilder.append(String.format(Locale.ROOT, "Tools: %s are updated successfully but failed to update to mcp server memory with error: %s", indexSucceedTools, x.getRootCause().getMessage()));
                        errMsgBuilder.append("\n");
                    });
                    errMsgBuilder.deleteCharAt(errMsgBuilder.length() - 1);
                    log.error(errMsgBuilder.toString());
                    restoreListener.onFailure(new SkyliteException(errMsgBuilder.toString(), new Object[0]));
                } else if (errMsgBuilder.isEmpty()) {
                    restoreListener.onResponse(r);
                } else {
                    // Nodes succeeded, but some tools had already failed in the index stage.
                    restoreListener.onFailure(new SkyliteException(errMsgBuilder.deleteCharAt(errMsgBuilder.length() - 1).toString(), new Object[0]));
                }
            }, nodesFailure -> {
                errMsgBuilder.append(String.format(Locale.ROOT, "Tools are updated successfully but failed to update to mcp server memory with error: %s", nodesFailure.getMessage()));
                log.error(errMsgBuilder.toString(), nodesFailure);
                restoreListener.onFailure(new SkyliteException(errMsgBuilder.toString(), new Object[0]));
            });
            this.client.execute(MLMcpToolsUpdateOnNodesAction.INSTANCE, toolsUpdateNodesRequest, addToMemoryResultListener);
        }
        catch (Exception e) {
            log.error("Failed to update mcp tools on nodes", e);
            listener.onFailure(e);
        }
    }

    /**
     * Returns the MCP tools index document id for a tool: its name, or its type when the name is null.
     */
    private static String resolveToolId(String name, String type) {
        return name != null ? name : type;
    }

    /**
     * A stored tool definition paired with the seqNo / primaryTerm of the document it was read from.
     *
     * <p>NOTE(review): {@code version} is never populated by this class; confirm whether it is still needed.
     */
    private static class SearchedMcpToolWrapper {
        private final McpToolRegisterInput mcpTool;
        private final Long primaryTerm;
        private final Long seqNo;
        private final Long version;

        private SearchedMcpToolWrapper(SearchedMcpToolWrapperBuilder builder) {
            this.mcpTool = builder.mcpTool;
            this.primaryTerm = builder.primaryTerm;
            this.seqNo = builder.seqNo;
            this.version = builder.version;
        }

        public McpToolRegisterInput getMcpTool() {
            return this.mcpTool;
        }

        public Long getPrimaryTerm() {
            return this.primaryTerm;
        }

        public Long getSeqNo() {
            return this.seqNo;
        }

        public Long getVersion() {
            return this.version;
        }

        /** Fluent builder for {@link SearchedMcpToolWrapper}; unset fields stay {@code null}. */
        static class SearchedMcpToolWrapperBuilder {
            private McpToolRegisterInput mcpTool;
            private Long primaryTerm;
            private Long seqNo;
            private Long version;

            SearchedMcpToolWrapperBuilder() {
            }

            SearchedMcpToolWrapperBuilder mcpTool(McpToolRegisterInput mcpTool) {
                this.mcpTool = mcpTool;
                return this;
            }

            SearchedMcpToolWrapperBuilder primaryTerm(Long primaryTerm) {
                this.primaryTerm = primaryTerm;
                return this;
            }

            SearchedMcpToolWrapperBuilder seqNo(Long seqNo) {
                this.seqNo = seqNo;
                return this;
            }

            SearchedMcpToolWrapperBuilder version(Long version) {
                this.version = version;
                return this;
            }

            SearchedMcpToolWrapper build() {
                return new SearchedMcpToolWrapper(this);
            }
        }
    }
}

