/*
 * Decompiled with CFR 0.152.
 */
package io.lucenia.ml.common.rest.mcpserver;

import com.fasterxml.jackson.databind.ObjectMapper;
import io.lucenia.ml.common.transport.mcpserver.action.MLMcpServerAction;
import io.lucenia.ml.common.transport.mcpserver.requests.server.MLMcpServerRequest;
import io.lucenia.ml.common.transport.mcpserver.requests.server.MLMcpServerResponse;
import io.skylite.SkyliteException;
import io.skylite.common.action.ActionListener;
import io.skylite.core.action.ActionRequest;
import io.skylite.core.action.ActionType;
import io.skylite.core.client.node.NodeClient;
import io.skylite.core.common.bytes.BytesArray;
import io.skylite.core.common.bytes.BytesReference;
import io.skylite.core.rest.RestChannel;
import io.skylite.core.rest.RestHandler;
import io.skylite.core.rest.RestRequest;
import io.skylite.core.rest.RestResponse;
import io.skylite.core.rest.RestStatus;
import io.skylite.ml.common.settings.MLCommonsSettings;
import io.skylite.ml.common.settings.MLFeatureEnabledSetting;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;

public class RestMcpServerAction
extends BaseRestHandler {
    private static final Logger log = LogManager.getLogger(RestMcpServerAction.class);
    private static final String ML_MCP_SERVER_ACTION = "ml_mcp_server_action";
    public static final String MCP_SERVER_ENDPOINT = "/_plugins/_ml/mcp";
    private final MLFeatureEnabledSetting mlFeatureEnabledSetting;
    private final ObjectMapper objectMapper;

    public RestMcpServerAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) {
        this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
        this.objectMapper = new ObjectMapper();
    }

    public List<RestHandler.Route> routes() {
        return List.of(new RestHandler.Route(RestRequest.Method.POST, MCP_SERVER_ENDPOINT), new RestHandler.Route(RestRequest.Method.GET, MCP_SERVER_ENDPOINT));
    }

    public String getName() {
        return ML_MCP_SERVER_ACTION;
    }

    protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
        if (!this.mlFeatureEnabledSetting.isMcpServerEnabled()) {
            throw new SkyliteException(MLCommonsSettings.ML_COMMONS_MCP_SERVER_DISABLED_MESSAGE, new Object[0]);
        }
        if (request.method() == RestRequest.Method.GET) {
            return channel -> channel.sendResponse((RestResponse)new BytesRestResponse(RestStatus.METHOD_NOT_ALLOWED, "", (BytesReference)BytesArray.EMPTY));
        }
        return channel -> {
            try {
                if (request.content() == null) {
                    this.sendErrorResponse((RestChannel)channel, null, -32700, "Parse error: empty body");
                    return;
                }
                String requestBody = request.content().utf8ToString();
                if (requestBody == null || requestBody.isBlank()) {
                    this.sendErrorResponse((RestChannel)channel, null, -32700, "Parse error: empty body");
                    return;
                }
                MLMcpServerRequest mcpRequest = new MLMcpServerRequest(requestBody);
                client.execute((ActionType)MLMcpServerAction.INSTANCE, (ActionRequest)mcpRequest, (ActionListener)new ActionListener<MLMcpServerResponse>(){

                    public void onResponse(MLMcpServerResponse response) {
                        try {
                            if (response.getError() != null) {
                                Map<String, Object> errorMap = response.getError();
                                Object id = errorMap.get("id");
                                int code = (Integer)errorMap.get("error_code");
                                String message = (String)errorMap.get("message");
                                RestMcpServerAction.this.sendErrorResponse(channel, id, code, message);
                            } else if (response.getMcpResponse() != null) {
                                String resp = response.getMcpResponse();
                                channel.sendResponse((RestResponse)new BytesRestResponse(RestStatus.OK, "application/json", resp));
                            } else {
                                channel.sendResponse((RestResponse)new BytesRestResponse(RestStatus.ACCEPTED, "", (BytesReference)BytesArray.EMPTY));
                            }
                        }
                        catch (Exception e) {
                            log.error("Failed to send response", (Throwable)e);
                            RestMcpServerAction.this.sendErrorResponse(channel, null, -32603, "Failed to send response");
                        }
                    }

                    public void onFailure(Exception e) {
                        log.error("Failed to handle MCP request", (Throwable)e);
                        RestMcpServerAction.this.sendErrorResponse(channel, null, -32603, "Internal server error: " + e.getMessage());
                    }
                });
            }
            catch (Exception e) {
                log.error("Failed to handle MCP request", (Throwable)e);
                this.sendErrorResponse((RestChannel)channel, null, -32603, "Internal server error");
            }
        };
    }

    private void sendErrorResponse(RestChannel channel, Object id, int code, String message) {
        try {
            HashMap<String, Object> errorResponse = new HashMap<String, Object>();
            errorResponse.put("jsonrpc", "2.0");
            errorResponse.put("id", id);
            errorResponse.put("error", Map.of("code", code, "message", message));
            String responseJson = this.objectMapper.writeValueAsString(errorResponse);
            channel.sendResponse((RestResponse)new BytesRestResponse(RestStatus.OK, "application/json", responseJson));
        }
        catch (Exception e) {
            log.error("Failed to send error response", (Throwable)e);
            try {
                channel.sendResponse((RestResponse)new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, "application/json", "{\"jsonrpc\":\"2.0\",\"id\":null,\"error\":{\"code\":-32603,\"message\":\"Failed to send error response\"}}"));
            }
            catch (Exception inner) {
                log.error("Even fallback error response failed", (Throwable)inner);
            }
        }
    }
}

