/*
 * Decompiled with CFR 0.152.
 */
package io.lucenia.ml.common.rest;

import io.lucenia.ml.common.action.profile.MLProfileAction;
import io.lucenia.ml.common.action.profile.MLProfileModelResponse;
import io.lucenia.ml.common.action.profile.MLProfileNodeResponse;
import io.lucenia.ml.common.action.profile.MLProfileRequest;
import io.lucenia.ml.common.rest.MLRestHandler;
import io.lucenia.ml.common.rest.RestActionUtils;
import io.lucenia.ml.common.utils.IndexUtils;
import io.skylite.common.action.ActionListener;
import io.skylite.core.action.ActionListenerHelper;
import io.skylite.core.action.ActionRequest;
import io.skylite.core.action.ActionType;
import io.skylite.core.action.search.SearchResponse;
import io.skylite.core.client.node.NodeClient;
import io.skylite.core.cluster.service.ClusterService;
import io.skylite.core.common.concurrent.ThreadContext;
import io.skylite.core.rest.RestChannel;
import io.skylite.core.rest.RestHandler;
import io.skylite.core.rest.RestRequest;
import io.skylite.core.rest.RestResponse;
import io.skylite.core.rest.RestStatus;
import io.skylite.core.search.SearchHit;
import io.skylite.core.search.SearchRequest;
import io.skylite.core.xcontent.ToXContent;
import io.skylite.core.xcontent.XContentBuilder;
import io.skylite.core.xcontent.XContentParser;
import io.skylite.core.xcontent.XContentParserUtils;
import io.skylite.ml.common.profile.MLModelProfile;
import io.skylite.ml.common.profile.MLProfileInput;
import io.skylite.ml.common.task.MLTask;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;

/**
 * REST handler for the ML profile API: {@code GET /_plugins/_ml/profile},
 * {@code /_plugins/_ml/profile/models[/{model_id}]} and {@code /_plugins/_ml/profile/tasks[/{task_id}]}.
 *
 * <p>Each request first runs the search built by {@code IndexUtils.buildHiddenModelSearchRequest()}
 * under a stashed thread context, collects the ids of the hits into the {@link MLProfileRequest}
 * via {@code setHiddenModelIds}, and then executes {@link MLProfileAction}. The result is rendered
 * either node-centric ({@code view=node}, the default) or grouped by model id ({@code view=model}).
 */
public class RestMLProfileAction extends MLRestHandler {
    private static final Logger log = LogManager.getLogger(RestMLProfileAction.class);
    private static final String PROFILE_ML_ACTION = "profile_ml";
    private static final String VIEW = "view";
    private static final String MODEL_VIEW = "model";
    private static final String NODE_VIEW = "node";
    /** Error message shared by every failure path that reports a profile failure to the client. */
    private static final String PROFILE_FAILURE_MESSAGE = "Failed to get ML node level profile";

    private final ClusterService clusterService;

    public RestMLProfileAction(ClusterService clusterService) {
        this.clusterService = clusterService;
    }

    public String getName() {
        return PROFILE_ML_ACTION;
    }

    public List<RestHandler.Route> routes() {
        return List.of(
            new RestHandler.Route(RestRequest.Method.GET, "/_plugins/_ml/profile/models/{model_id}"),
            new RestHandler.Route(RestRequest.Method.GET, "/_plugins/_ml/profile/models"),
            new RestHandler.Route(RestRequest.Method.GET, "/_plugins/_ml/profile/tasks/{task_id}"),
            new RestHandler.Route(RestRequest.Method.GET, "/_plugins/_ml/profile/tasks"),
            new RestHandler.Route(RestRequest.Method.GET, "/_plugins/_ml/profile")
        );
    }

    /**
     * Parses the profile request and returns a consumer that runs the hidden-model search followed by
     * the profile action, writing the result to the channel.
     *
     * @param request the incoming REST request (body or path/query params describe the profile input)
     * @param client  node client used for the search and for executing {@link MLProfileAction}
     * @return the channel consumer that performs the work and sends the response
     * @throws IOException if the request body cannot be parsed
     */
    protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request, final NodeClient client) throws IOException {
        MLProfileInput mlProfileInput = this.parseMLProfileInput(request);
        final String view = RestActionUtils.getStringParam(request, VIEW).orElse(NODE_VIEW);
        String[] nodeIds = mlProfileInput.retrieveProfileOnAllNodes()
            ? RestActionUtils.getAllNodes(this.clusterService)
            : mlProfileInput.getNodeIds().toArray(new String[0]);
        final MLProfileRequest mlProfileRequest = new MLProfileRequest(nodeIds, mlProfileInput);
        SearchRequest searchRequest = IndexUtils.buildHiddenModelSearchRequest();
        return channel -> {
            final XContentBuilder builder = channel.newBuilder();
            // The hidden-model search is issued under a stashed thread context; threadContext::restore is
            // handed to runBefore, presumably so the caller's context is back in place before the callbacks.
            try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
                client.search(searchRequest, ActionListenerHelper.runBefore(new ActionListener<SearchResponse>() {

                    @Override
                    public void onResponse(SearchResponse searchResponse) {
                        HashSet<String> hiddenModelIds = new HashSet<>(searchResponse.getHits().getHits().length);
                        for (SearchHit hit : searchResponse.getHits()) {
                            hiddenModelIds.add(hit.getId());
                        }
                        mlProfileRequest.setHiddenModelIds(hiddenModelIds);
                        client.execute(MLProfileAction.INSTANCE, mlProfileRequest, ActionListenerHelper.wrap(r -> {
                            builder.startObject();
                            // Drop nodes that reported nothing so they don't count towards the profile.
                            List<MLProfileNodeResponse> nodeProfiles = r.getNodes()
                                .stream()
                                .filter(s -> !s.isEmpty())
                                .collect(Collectors.toList());
                            log.debug("Build MLProfileNodeResponse for size of {}", nodeProfiles.size());
                            if (!nodeProfiles.isEmpty()) {
                                if (NODE_VIEW.equals(view)) {
                                    r.toXContent(builder, ToXContent.EMPTY_PARAMS);
                                } else if (MODEL_VIEW.equals(view)) {
                                    RestMLProfileAction.this.writeModelView(builder, nodeProfiles);
                                }
                                // Any other view value yields an empty object.
                            }
                            builder.endObject();
                            channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder));
                        }, e -> {
                            log.error(PROFILE_FAILURE_MESSAGE, e);
                            RestMLProfileAction.this.onFailed(channel, PROFILE_FAILURE_MESSAGE, e);
                        }));
                    }

                    @Override
                    public void onFailure(Exception e) {
                        // The profile action is skipped entirely and an empty object is returned with 200.
                        // NOTE(review): presumably deliberate so profiles are never served without the
                        // hidden-model ids; confirm. Previously this failure was swallowed without a trace.
                        log.warn("Failed to search hidden models, returning empty ML profile", e);
                        try {
                            builder.startObject();
                            builder.endObject();
                            channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder));
                        } catch (IOException ex) {
                            // Keep the serialization failure attached to the original cause.
                            e.addSuppressed(ex);
                            log.error(PROFILE_FAILURE_MESSAGE, e);
                            RestMLProfileAction.this.onFailed(channel, PROFILE_FAILURE_MESSAGE, e);
                        }
                    }
                }, threadContext::restore));
            }
        };
    }

    /**
     * Builds the profile input from the request body when one is present, otherwise from the
     * path and query parameters. The body parser is closed once parsing is finished.
     */
    private MLProfileInput parseMLProfileInput(RestRequest request) throws IOException {
        if (!request.hasContent()) {
            return this.createMLProfileInputFromRequestParams(request);
        }
        try (XContentParser parser = request.contentParser()) {
            XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
            return MLProfileInput.parse(parser);
        }
    }

    /**
     * Writes the model-centric view as a {@code "models"} object keyed by model id.
     */
    private void writeModelView(XContentBuilder builder, List<MLProfileNodeResponse> nodeProfiles) throws IOException {
        Map<String, MLProfileModelResponse> modelCentricProfileMap = this.buildModelCentricResult(nodeProfiles);
        builder.startObject("models");
        for (Map.Entry<String, MLProfileModelResponse> entry : modelCentricProfileMap.entrySet()) {
            builder.field(entry.getKey(), (ToXContent) entry.getValue());
        }
        builder.endObject();
    }

    /**
     * Regroups per-node profile results by model id.
     *
     * <p>For every model reported by a node, a per-node {@link MLModelProfile} copy is stored under
     * that node's id. The model's target/worker node lists are taken from the first node entry that
     * provides them. Every task is attached to the model returned by {@link MLTask#getModelId()}.
     *
     * @param nodeResponses non-empty node responses from the profile action
     * @return map from model id to its aggregated profile (iteration order unspecified)
     */
    private Map<String, MLProfileModelResponse> buildModelCentricResult(List<MLProfileNodeResponse> nodeResponses) {
        Map<String, MLProfileModelResponse> modelCentricMap = new HashMap<>();
        for (MLProfileNodeResponse nodeResponse : nodeResponses) {
            String nodeId = nodeResponse.getNode().getId();
            for (Map.Entry<String, MLModelProfile> entry : nodeResponse.getMlNodeModels().entrySet()) {
                MLModelProfile nodeModelProfile = entry.getValue();
                MLProfileModelResponse modelResponse = modelCentricMap.computeIfAbsent(
                    entry.getKey(),
                    modelId -> new MLProfileModelResponse(nodeModelProfile.getTargetWorkerNodes(), nodeModelProfile.getWorkerNodes())
                );
                if (modelResponse.getTargetWorkerNodes() == null || modelResponse.getWorkerNodes() == null) {
                    modelResponse.setTargetWorkerNodes(nodeModelProfile.getTargetWorkerNodes());
                    modelResponse.setWorkerNodes(nodeModelProfile.getWorkerNodes());
                }
                // The 3rd and 4th constructor arguments are deliberately null; presumably they are the
                // target/worker node lists already reported once at model level - confirm against MLModelProfile.
                MLModelProfile perNodeProfile = new MLModelProfile(
                    nodeModelProfile.getModelState(),
                    nodeModelProfile.getPredictor(),
                    null,
                    null,
                    nodeModelProfile.getModelInferenceStats(),
                    nodeModelProfile.getPredictRequestStats(),
                    nodeModelProfile.getMemSizeEstimationCPU(),
                    nodeModelProfile.getMemSizeEstimationGPU()
                );
                modelResponse.getMlModelProfileMap().put(nodeId, perNodeProfile);
            }
            for (Map.Entry<String, MLTask> entry : nodeResponse.getMlNodeTasks().entrySet()) {
                // NOTE(review): a task whose getModelId() is null is grouped under a null key, which
                // writeModelView would then emit as a field name - confirm tasks always carry a model id.
                String modelId = entry.getValue().getModelId();
                modelCentricMap.computeIfAbsent(modelId, id -> new MLProfileModelResponse())
                    .getMlTaskMap()
                    .put(entry.getKey(), entry.getValue());
            }
        }
        return modelCentricMap;
    }

    /**
     * Builds the profile input from path and query parameters.
     *
     * <p>Explicit {@code model_id} / {@code task_id} values select specific models/tasks. The
     * {@code /profile/models} and {@code /profile/tasks} paths without an id select all of that kind.
     * The bare {@code /profile} path selects all models and all tasks.
     *
     * @param request the REST request without a body
     * @return the populated profile input
     */
    MLProfileInput createMLProfileInputFromRequestParams(RestRequest request) {
        MLProfileInput mlProfileInput = new MLProfileInput();
        Optional<String[]> modelIds = RestActionUtils.splitCommaSeparatedParam(request, "model_id");
        String uri = request.getHttpRequest().uri();
        // Match whole path segments only: a substring check would fire on ids such as "my_tasks_model"
        // or on query-string values.
        boolean profileModel = targetsProfileResource(uri, "models");
        boolean profileTask = targetsProfileResource(uri, "tasks");
        if (modelIds.isPresent()) {
            mlProfileInput.getModelIds().addAll(Arrays.asList(modelIds.get()));
        } else if (profileModel) {
            mlProfileInput.setReturnAllModels(true);
        }
        Optional<String[]> taskIds = RestActionUtils.splitCommaSeparatedParam(request, "task_id");
        if (taskIds.isPresent()) {
            mlProfileInput.getTaskIds().addAll(Arrays.asList(taskIds.get()));
        } else if (profileTask) {
            mlProfileInput.setReturnAllTasks(true);
        }
        if (!profileModel && !profileTask) {
            mlProfileInput.setReturnAllTasks(true);
            mlProfileInput.setReturnAllModels(true);
        }
        return mlProfileInput;
    }

    /**
     * Returns whether the path part of {@code uri} addresses the given profile sub-resource, i.e. contains
     * the whole segments {@code /profile/<resource>} (optionally followed by further segments).
     * The query string is ignored.
     */
    private static boolean targetsProfileResource(String uri, String resource) {
        int queryStart = uri.indexOf('?');
        String path = queryStart >= 0 ? uri.substring(0, queryStart) : uri;
        // Appending "/" lets the final segment match the same way as an inner one.
        return (path + "/").contains("/profile/" + resource + "/");
    }

    /**
     * Sends a 500 response with {@code error} (the given message) and {@code exception}
     * ({@code e.getMessage()}) fields. If even that cannot be serialized, the problem is only logged.
     */
    private void onFailed(RestChannel channel, String message, Exception e) {
        try {
            XContentBuilder builder = channel.newBuilder();
            builder.startObject();
            builder.field("error", message);
            builder.field("exception", e.getMessage());
            builder.endObject();
            channel.sendResponse(new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, builder));
        } catch (IOException ioException) {
            log.error("Failed to send failure response", ioException);
        }
    }
}

