/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.plugin.transport;

import io.skylite.common.ValidationException;
import io.skylite.common.action.ActionListener;
import io.skylite.common.transport.TransportRequestOptions;
import io.skylite.core.action.ActionFilters;
import io.skylite.core.action.ActionListenerHelper;
import io.skylite.core.action.ActionListenerResponseHandler;
import io.skylite.core.action.ActionRequest;
import io.skylite.core.action.ActionType;
import io.skylite.core.action.support.HandledTransportAction;
import io.skylite.core.client.Client;
import io.skylite.core.cluster.node.DiscoveryNode;
import io.skylite.core.cluster.service.ClusterService;
import io.skylite.core.common.inject.Inject;
import io.skylite.core.search.SearchRequest;
import io.skylite.core.search.builder.SearchSourceBuilder;
import io.skylite.core.tasks.Task;
import io.skylite.core.transport.TransportRequest;
import io.skylite.core.transport.TransportResponseHandler;
import io.skylite.core.transport.TransportService;
import java.util.Map;
import org.apache.commons.lang.StringUtils;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.plugin.transport.TrainingJobRouteDecisionInfoAction;
import org.opensearch.knn.plugin.transport.TrainingJobRouteDecisionInfoNodeResponse;
import org.opensearch.knn.plugin.transport.TrainingJobRouteDecisionInfoRequest;
import org.opensearch.knn.plugin.transport.TrainingJobRouteDecisionInfoResponse;
import org.opensearch.knn.plugin.transport.TrainingModelRequest;
import org.opensearch.knn.plugin.transport.TrainingModelResponse;

public class TrainingJobRouterTransportAction
extends HandledTransportAction<TrainingModelRequest, TrainingModelResponse> {
    private final TransportService transportService;
    private final ClusterService clusterService;
    private final Client client;

    @Inject
    public TrainingJobRouterTransportAction(TransportService transportService, ActionFilters actionFilters, ClusterService clusterService, Client client) {
        super("cluster:admin/knn_training_job_router_action", transportService, actionFilters, TrainingModelRequest::new);
        this.clusterService = clusterService;
        this.client = client;
        this.transportService = transportService;
    }

    protected void doExecute(Task task, TrainingModelRequest request, ActionListener<TrainingModelResponse> listener) {
        this.getTrainingIndexSizeInKB(request, (ActionListener<Integer>)ActionListenerHelper.wrap(size -> {
            request.setTrainingDataSizeInKB((int)size);
            this.routeRequest(request, listener);
        }, arg_0 -> listener.onFailure(arg_0)));
    }

    protected void routeRequest(TrainingModelRequest request, ActionListener<TrainingModelResponse> listener) {
        this.client.execute((ActionType)TrainingJobRouteDecisionInfoAction.INSTANCE, (ActionRequest)new TrainingJobRouteDecisionInfoRequest(new String[0]), ActionListenerHelper.wrap(response -> {
            DiscoveryNode node = this.selectNode(request.getPreferredNodeId(), (TrainingJobRouteDecisionInfoResponse)((Object)response));
            if (node == null) {
                ValidationException exception = new ValidationException();
                exception.addValidationError("Cluster does not have capacity to train");
                listener.onFailure((Exception)exception);
                return;
            }
            this.transportService.sendRequest(node, "cluster:admin/knn_training_model_action", (TransportRequest)request, TransportRequestOptions.EMPTY, (TransportResponseHandler)new ActionListenerResponseHandler(listener, TrainingModelResponse::new));
        }, arg_0 -> listener.onFailure(arg_0)));
    }

    protected DiscoveryNode selectNode(String preferredNode, TrainingJobRouteDecisionInfoResponse jobInfo) {
        DiscoveryNode selectedNode = null;
        Map eligibleNodes = this.clusterService.state().nodes().getDataNodes();
        for (TrainingJobRouteDecisionInfoNodeResponse response : jobInfo.getNodes()) {
            DiscoveryNode currentNode = response.getNode();
            if (!eligibleNodes.containsKey(currentNode.getId()) || response.getTrainingJobCount() >= 1) continue;
            selectedNode = currentNode;
            if (!StringUtils.isEmpty((String)preferredNode) && !selectedNode.getId().equals(preferredNode)) continue;
            return selectedNode;
        }
        return selectedNode;
    }

    protected void getTrainingIndexSizeInKB(TrainingModelRequest trainingModelRequest, ActionListener<Integer> listener) {
        SearchRequest countRequest = new SearchRequest(new String[]{trainingModelRequest.getTrainingIndex()});
        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).trackTotalHits(true);
        countRequest.source(searchSourceBuilder);
        searchSourceBuilder.terminateAfter(0);
        this.client.search(countRequest, ActionListenerHelper.wrap(searchResponse -> {
            long trainingVectors = searchResponse.getHits().getTotalHits().value();
            if ((long)trainingModelRequest.getMaximumVectorCount() < trainingVectors) {
                trainingVectors = trainingModelRequest.getMaximumVectorCount();
            }
            listener.onResponse((Object)TrainingJobRouterTransportAction.estimateVectorSetSizeInKB(trainingVectors, trainingModelRequest.getDimension(), trainingModelRequest.getVectorDataType()));
        }, arg_0 -> listener.onFailure(arg_0)));
    }

    public static int estimateVectorSetSizeInKB(long vectorCount, int dimension, VectorDataType vectorDataType) {
        switch (vectorDataType) {
            case BINARY: {
                return Math.toIntExact((long)(1 * (dimension / 8)) * vectorCount / (long)KNNConstants.BYTES_PER_KILOBYTES.intValue() + 1L);
            }
            case BYTE: {
                return Math.toIntExact((long)(1 * dimension) * vectorCount / (long)KNNConstants.BYTES_PER_KILOBYTES.intValue() + 1L);
            }
        }
        return Math.toIntExact((long)(4 * dimension) * vectorCount / (long)KNNConstants.BYTES_PER_KILOBYTES.intValue() + 1L);
    }
}

