/*
 * Decompiled with CFR 0.152.
 */
package io.lucenia.ml.common.rest;

import io.lucenia.ml.common.action.stats.MLStatsNodesAction;
import io.lucenia.ml.common.action.stats.MLStatsNodesRequest;
import io.lucenia.ml.common.rest.MLRestHandler;
import io.lucenia.ml.common.rest.RestActionUtils;
import io.lucenia.ml.common.utils.IndexUtils;
import io.skylite.common.action.ActionListener;
import io.skylite.core.action.ActionListenerHelper;
import io.skylite.core.action.ActionRequest;
import io.skylite.core.action.ActionType;
import io.skylite.core.action.search.SearchResponse;
import io.skylite.core.client.node.NodeClient;
import io.skylite.core.cluster.node.DiscoveryNode;
import io.skylite.core.cluster.service.ClusterService;
import io.skylite.core.common.concurrent.ThreadContext;
import io.skylite.core.rest.RestChannel;
import io.skylite.core.rest.RestHandler;
import io.skylite.core.rest.RestRequest;
import io.skylite.core.rest.RestResponse;
import io.skylite.core.rest.RestStatus;
import io.skylite.core.search.SearchHit;
import io.skylite.core.search.SearchRequest;
import io.skylite.core.xcontent.NamedXContentRegistry;
import io.skylite.core.xcontent.ToXContent;
import io.skylite.core.xcontent.XContentBuilder;
import io.skylite.core.xcontent.XContentParser;
import io.skylite.core.xcontent.XContentParserUtils;
import io.skylite.ml.common.stats.MLClusterLevelStat;
import io.skylite.ml.common.stats.MLNodeLevelStat;
import io.skylite.ml.common.stats.MLStat;
import io.skylite.ml.common.stats.MLStatLevel;
import io.skylite.ml.common.stats.MLStats;
import io.skylite.ml.common.stats.MLStatsInput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;

/**
 * REST handler that serves ML statistics on {@code /_plugins/_ml/stats} and the node-scoped
 * {@code /_plugins/_ml/{nodeId}/stats} routes.
 *
 * <p>Cluster-level stats are taken from {@link MLStats}. When requested, they are augmented with
 * document counts of the ML model and connector indices. Node-level stats are gathered by sending an
 * {@link MLStatsNodesRequest} to the target nodes. Before that, a search for hidden models is run
 * and the resulting ids are attached to the node request.
 */
public class RestMLStatsAction extends MLRestHandler {
    private static final Logger log = LogManager.getLogger(RestMLStatsAction.class);
    private static final String STATS_ML_ACTION = "stats_ml";
    private static final String ML_MODEL_INDEX = ".plugins-ml-model";
    private static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector";
    /** Query matching documents without a {@code chunk_number} field (i.e. not model chunk docs). */
    private static final String QUERY_ALL_MODEL_META_DOC = "{\"query\":{\"bool\":{\"must_not\":{\"exists\":{\"field\":\"chunk_number\"}}}}}";
    /** Upper-case names of every node-level stat; used to route a requested stat name to its level. */
    private static final Set<String> ML_NODE_STAT_NAMES = EnumSet.allOf(MLNodeLevelStat.class).stream().map(MLNodeLevelStat::name).collect(Collectors.toSet());

    private final MLStats mlStats;
    private final ClusterService clusterService;
    private final IndexUtils indexUtils;
    private final NamedXContentRegistry xContentRegistry;

    public RestMLStatsAction(MLStats mlStats, ClusterService clusterService, IndexUtils indexUtils, NamedXContentRegistry xContentRegistry) {
        this.mlStats = mlStats;
        this.clusterService = clusterService;
        this.indexUtils = indexUtils;
        this.xContentRegistry = xContentRegistry;
    }

    public String getName() {
        return STATS_ML_ACTION;
    }

    public List<RestHandler.Route> routes() {
        return List.of(
            new RestHandler.Route(RestRequest.Method.GET, "/_plugins/_ml/{nodeId}/stats/"),
            new RestHandler.Route(RestRequest.Method.GET, "/_plugins/_ml/{nodeId}/stats/{stat}"),
            new RestHandler.Route(RestRequest.Method.GET, "/_plugins/_ml/stats/"),
            new RestHandler.Route(RestRequest.Method.GET, "/_plugins/_ml/stats/{stat}")
        );
    }

    /**
     * Builds the stats input from the request body (if any) or from the path parameters, and
     * returns a consumer that gathers the stats and writes the response.
     *
     * @param request incoming REST request
     * @param client  node client used for the hidden-model search and the node stats action
     * @return consumer that sends the stats response on the channel
     * @throws IOException if the request body cannot be parsed
     */
    protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request, final NodeClient client) throws IOException {
        final MLStatsInput mlStatsInput = request.hasContent()
            ? parseMlStatsInput(request)
            : this.createMlStatsInputFromRequestParams(request);
        String[] nodeIds = mlStatsInput.retrieveStatsOnAllNodes()
            ? this.getAllNodes()
            : mlStatsInput.getNodeIds().toArray(new String[0]);
        final MLStatsNodesRequest mlStatsNodesRequest = new MLStatsNodesRequest(nodeIds, mlStatsInput);
        // Mutated later by the async model/connector count callbacks before the response is written.
        final Map<MLClusterLevelStat, Object> clusterStatsMap = new HashMap<>();
        if (mlStatsInput.getTargetStatLevels().contains(MLStatLevel.CLUSTER)) {
            clusterStatsMap.putAll(this.getClusterStatsMap(mlStatsInput));
        }
        return channel -> {
            try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
                SearchRequest searchRequest = IndexUtils.buildHiddenModelSearchRequest();
                ActionListener<SearchResponse> hiddenModelListener = new ActionListener<SearchResponse>() {
                    @Override
                    public void onResponse(SearchResponse searchResponse) {
                        Set<String> hiddenModelIds = new HashSet<>(searchResponse.getHits().getHits().length);
                        for (SearchHit hit : searchResponse.getHits()) {
                            hiddenModelIds.add(hit.getId());
                        }
                        mlStatsNodesRequest.setHiddenModelIds(hiddenModelIds);
                        if (shouldCountModelsAndConnectors(mlStatsInput)) {
                            countModelsAndConnectorsThenRespond(mlStatsInput, clusterStatsMap, client, mlStatsNodesRequest, channel);
                        } else {
                            respondWithNodeStats(mlStatsInput, clusterStatsMap, client, mlStatsNodesRequest, channel);
                        }
                    }

                    @Override
                    public void onFailure(Exception e) {
                        // The hidden-model lookup is best effort: continue without hidden model ids.
                        // The model/connector index counts are also skipped on this path.
                        log.warn("Failed to search hidden models; continuing without hidden model ids", e);
                        respondWithNodeStats(mlStatsInput, clusterStatsMap, client, mlStatsNodesRequest, channel);
                    }
                };
                // Restore the caller's thread context before the listener runs.
                client.search(searchRequest, ActionListenerHelper.runBefore(hiddenModelListener, threadContext::restore));
            }
        };
    }

    /** Parses an {@link MLStatsInput} from a JSON object request body. */
    private static MLStatsInput parseMlStatsInput(RestRequest request) throws IOException {
        XContentParser parser = request.contentParser();
        XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
        return MLStatsInput.parse(parser);
    }

    /** True when cluster stats are targeted and the model count (or all cluster stats) is requested. */
    private static boolean shouldCountModelsAndConnectors(MLStatsInput input) {
        return input.getTargetStatLevels().contains(MLStatLevel.CLUSTER)
            && (input.retrieveAllClusterLevelStats() || input.getClusterLevelStats().contains(MLClusterLevelStat.ML_MODEL_COUNT));
    }

    /**
     * Counts model meta documents, then connector documents, stores both in {@code clusterStatsMap},
     * and finally writes the response.
     *
     * <p>NOTE(review): the connector count is only fetched when the model count is requested, so a
     * request for {@code ml_connector_count} alone does not fetch it. Confirm whether it should be
     * gated on its own stat.
     */
    private void countModelsAndConnectorsThenRespond(MLStatsInput mlStatsInput, Map<MLClusterLevelStat, Object> clusterStatsMap, NodeClient client, MLStatsNodesRequest mlStatsNodesRequest, RestChannel channel) {
        ActionListener<Long> onConnectorCount = ActionListenerHelper.wrap(connectorCount -> {
            clusterStatsMap.put(MLClusterLevelStat.ML_CONNECTOR_COUNT, connectorCount);
            this.getNodeStats(mlStatsInput, clusterStatsMap, client, mlStatsNodesRequest, channel);
        }, e -> this.logAndFail(channel, "Failed to get ML connector count", e));
        ActionListener<Long> onModelCount = ActionListenerHelper.wrap(modelCount -> {
            clusterStatsMap.put(MLClusterLevelStat.ML_MODEL_COUNT, modelCount);
            this.indexUtils.getNumberOfDocumentsInIndex(ML_CONNECTOR_INDEX, onConnectorCount);
        }, e -> this.logAndFail(channel, "Failed to get ML model count", e));
        this.indexUtils.getNumberOfDocumentsInIndex(ML_MODEL_INDEX, QUERY_ALL_MODEL_META_DOC, this.xContentRegistry, onModelCount);
    }

    /** Calls {@link #getNodeStats} and turns an {@link IOException} it throws into a 500 response. */
    private void respondWithNodeStats(MLStatsInput mlStatsInput, Map<MLClusterLevelStat, Object> clusterStatsMap, NodeClient client, MLStatsNodesRequest mlStatsNodesRequest, RestChannel channel) {
        try {
            this.getNodeStats(mlStatsInput, clusterStatsMap, client, mlStatsNodesRequest, channel);
        } catch (IOException e) {
            // Report the IOException that actually aborted the response.
            this.logAndFail(channel, "Failed to retrieve Cluster level metrics", e);
        }
    }

    /**
     * Builds an {@link MLStatsInput} from the {@code nodeId} and {@code stat} path parameters.
     *
     * <p>Each requested stat name is upper-cased and classified as node-level if it matches a
     * {@link MLNodeLevelStat} name, otherwise as cluster-level. The matching stat levels are then
     * targeted. With no {@code stat} parameter, all stat levels are targeted.
     * NOTE(review): presumably {@code MLClusterLevelStat.from} rejects unknown names. Confirm.
     */
    MLStatsInput createMlStatsInputFromRequestParams(RestRequest request) {
        MLStatsInput mlStatsInput = new MLStatsInput();
        RestActionUtils.splitCommaSeparatedParam(request, "nodeId")
            .ifPresent(ids -> mlStatsInput.getNodeIds().addAll(Arrays.asList(ids)));
        Optional<String[]> stats = RestActionUtils.splitCommaSeparatedParam(request, "stat");
        if (!stats.isPresent()) {
            mlStatsInput.getTargetStatLevels().addAll(EnumSet.allOf(MLStatLevel.class));
            return mlStatsInput;
        }
        for (String rawStat : stats.get()) {
            String stat = rawStat.toUpperCase(Locale.ROOT);
            if (ML_NODE_STAT_NAMES.contains(stat)) {
                mlStatsInput.getNodeLevelStats().add(MLNodeLevelStat.from(stat));
            } else {
                mlStatsInput.getClusterLevelStats().add(MLClusterLevelStat.from(stat));
            }
        }
        if (!mlStatsInput.getClusterLevelStats().isEmpty()) {
            mlStatsInput.getTargetStatLevels().add(MLStatLevel.CLUSTER);
        }
        if (!mlStatsInput.getNodeLevelStats().isEmpty()) {
            mlStatsInput.getTargetStatLevels().add(MLStatLevel.NODE);
        }
        return mlStatsInput;
    }

    /**
     * Writes the stats response.
     *
     * <p>If only cluster-level stats are requested, the response is written directly from
     * {@code clusterStatsMap}. Otherwise the node stats action is executed, and its result is merged
     * with the cluster stats. The node section is emitted only if at least one node returned
     * non-empty stats.
     *
     * @throws IOException if building the cluster-only response fails
     */
    void getNodeStats(MLStatsInput mlStatsInput, Map<MLClusterLevelStat, Object> clusterStatsMap, NodeClient client, MLStatsNodesRequest mlStatsNodesRequest, RestChannel channel) throws IOException {
        XContentBuilder builder = channel.newBuilder();
        if (mlStatsInput.onlyRetrieveClusterLevelStats()) {
            builder.startObject();
            writeClusterStats(builder, clusterStatsMap);
            builder.endObject();
            channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder));
            return;
        }
        client.execute(MLStatsNodesAction.INSTANCE, mlStatsNodesRequest, ActionListenerHelper.wrap(r -> {
            builder.startObject();
            writeClusterStats(builder, clusterStatsMap);
            if (r.getNodes().stream().anyMatch(s -> !s.isEmpty())) {
                r.toXContent(builder, ToXContent.EMPTY_PARAMS);
            }
            builder.endObject();
            channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder));
        }, e -> this.logAndFail(channel, "Failed to get ML node level stats", e)));
    }

    /** Writes each cluster stat as a lower-cased field; a null or empty map writes nothing. */
    private static void writeClusterStats(XContentBuilder builder, Map<MLClusterLevelStat, Object> clusterStatsMap) throws IOException {
        if (clusterStatsMap == null) {
            return;
        }
        for (Map.Entry<MLClusterLevelStat, Object> entry : clusterStatsMap.entrySet()) {
            builder.field(entry.getKey().name().toLowerCase(Locale.ROOT), entry.getValue());
        }
    }

    /** Returns the ids of every node in the current cluster state. */
    private String[] getAllNodes() {
        List<String> nodeIds = new ArrayList<>();
        for (DiscoveryNode node : this.clusterService.state().nodes()) {
            nodeIds.add(node.getId());
        }
        return nodeIds.toArray(new String[0]);
    }

    /** Logs {@code errorMessage} with the cause and sends a 500 error response. */
    private void logAndFail(RestChannel channel, String errorMessage, Exception e) {
        log.error(errorMessage, e);
        this.onFailed(channel, RestStatus.INTERNAL_SERVER_ERROR, errorMessage, e);
    }

    /**
     * Sends an error response derived from {@code exception}. If that response cannot be built,
     * it falls back to a plain response with {@code status} and {@code errorMessage}.
     */
    private void onFailed(RestChannel channel, RestStatus status, String errorMessage, Exception exception) {
        BytesRestResponse bytesRestResponse;
        try {
            bytesRestResponse = new BytesRestResponse(channel, exception);
        } catch (Exception e) {
            bytesRestResponse = new BytesRestResponse(status, errorMessage);
        }
        channel.sendResponse(bytesRestResponse);
    }

    /** Collects the values of the cluster-level stats selected by {@code mlStatsInput}. */
    private Map<MLClusterLevelStat, Object> getClusterStatsMap(MLStatsInput mlStatsInput) {
        Map<MLClusterLevelStat, Object> clusterStats = new HashMap<>();
        this.mlStats.getClusterStats().forEach((stat, metric) -> {
            if (mlStatsInput.retrieveStat(stat)) {
                clusterStats.put((MLClusterLevelStat) stat, metric.getValue());
            }
        });
        return clusterStats;
    }
}

