/*
 * Decompiled with CFR 0.152.
 */
package io.lucenia.ml.common.action.prediction;

import io.lucenia.ml.common.model.MLModelManager;
import io.lucenia.ml.common.model.ModelAccessControlHelper;
import io.lucenia.ml.common.rest.RestActionUtils;
import io.lucenia.ml.common.task.MLPredictTaskRunner;
import io.lucenia.ml.common.task.MLTaskRunner;
import io.skylite.SkyliteStatusException;
import io.skylite.common.action.ActionListener;
import io.skylite.core.action.ActionFilters;
import io.skylite.core.action.ActionListenerHelper;
import io.skylite.core.action.ActionRequest;
import io.skylite.core.action.support.HandledTransportAction;
import io.skylite.core.client.Client;
import io.skylite.core.client.metadata.MetadataClient;
import io.skylite.core.cluster.service.ClusterService;
import io.skylite.core.common.breaker.CircuitBreakingException;
import io.skylite.core.common.concurrent.ThreadContext;
import io.skylite.core.common.inject.Inject;
import io.skylite.core.rest.RestStatus;
import io.skylite.core.security.auth.User;
import io.skylite.core.settings.Settings;
import io.skylite.core.tasks.Task;
import io.skylite.core.transport.TransportService;
import io.skylite.core.xcontent.MediaTypeRegistry;
import io.skylite.core.xcontent.NamedXContentRegistry;
import io.skylite.core.xcontent.ToXContent;
import io.skylite.ml.common.FunctionName;
import io.skylite.ml.common.cluster.MLNodeUtils;
import io.skylite.ml.common.cluster.TenantAwareHelper;
import io.skylite.ml.common.exception.MLResourceNotFoundException;
import io.skylite.ml.common.input.MLInput;
import io.skylite.ml.common.model.MLModel;
import io.skylite.ml.common.model.MLModelCacheHelper;
import io.skylite.ml.common.settings.MLCommonsSettings;
import io.skylite.ml.common.settings.MLFeatureEnabledSetting;
import io.skylite.ml.common.transport.MLTaskResponse;
import io.skylite.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Transport action that serves ML prediction requests.
 *
 * <p>For each request the action resolves the target model (from {@link MLModelCacheHelper} when it
 * is cached, otherwise via {@link MLModelManager#getModel}), rejects local (deep-learning) models when
 * local models are disabled, checks that the calling user may access the model's group, enforces the
 * cached "model disabled" flag and - for deep-learning models only - the model-level and user-level
 * rate limiters, validates the input against the model's cached input schema (if any) and finally
 * hands the request to the predict task runner.
 */
public class TransportPredictionTaskAction
extends HandledTransportAction<ActionRequest, MLTaskResponse> {
    private static final Logger log = LogManager.getLogger(TransportPredictionTaskAction.class);
    private final MLTaskRunner<MLPredictionTaskRequest, MLTaskResponse> mlPredictTaskRunner;
    private final TransportService transportService;
    private final MLModelCacheHelper modelCacheHelper;
    private final Client client;
    private final MetadataClient metadataClient;
    private final ClusterService clusterService;
    private final NamedXContentRegistry xContentRegistry;
    private final MLModelManager mlModelManager;
    private final ModelAccessControlHelper modelAccessControlHelper;
    // Mirrors MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE; kept current by the settings-update
    // consumer registered in the constructor. NOTE(review): not read anywhere in this class.
    private volatile boolean enableAutomaticDeployment;
    private final MLFeatureEnabledSetting mlFeatureEnabledSetting;

    /**
     * Creates the action and registers it under {@code cluster:admin/lucenia/ml/predict}.
     *
     * <p>Also reads the initial auto-deploy setting from {@code settings} and subscribes to dynamic
     * updates of it through the cluster settings.
     */
    @Inject
    public TransportPredictionTaskAction(
        TransportService transportService,
        ActionFilters actionFilters,
        MLModelCacheHelper modelCacheHelper,
        MLPredictTaskRunner mlPredictTaskRunner,
        ClusterService clusterService,
        Client client,
        MetadataClient metadataClient,
        NamedXContentRegistry xContentRegistry,
        MLModelManager mlModelManager,
        ModelAccessControlHelper modelAccessControlHelper,
        MLFeatureEnabledSetting mlFeatureEnabledSetting,
        Settings settings
    ) {
        super("cluster:admin/lucenia/ml/predict", transportService, actionFilters, MLPredictionTaskRequest::new);
        this.mlPredictTaskRunner = mlPredictTaskRunner;
        this.transportService = transportService;
        this.modelCacheHelper = modelCacheHelper;
        this.clusterService = clusterService;
        this.client = client;
        this.metadataClient = metadataClient;
        this.xContentRegistry = xContentRegistry;
        this.mlModelManager = mlModelManager;
        this.modelAccessControlHelper = modelAccessControlHelper;
        this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
        this.enableAutomaticDeployment = (Boolean) MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE.get(settings);
        clusterService
            .getClusterSettings()
            .addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE, it -> this.enableAutomaticDeployment = it);
    }

    /**
     * Entry point: resolves the model for the request and continues the prediction pipeline.
     *
     * <p>The caller's thread context is stashed for the duration of this call. It is restored when
     * the model has been resolved, and again (via {@code runBefore}) right before {@code listener} is
     * notified.
     *
     * @param task the task associated with this request
     * @param request the incoming request, converted with {@link MLPredictionTaskRequest#fromActionRequest}
     * @param listener receives the prediction response or the failure that stopped the pipeline
     */
    protected void doExecute(Task task, ActionRequest request, ActionListener<MLTaskResponse> listener) {
        final MLPredictionTaskRequest mlPredictionTaskRequest = MLPredictionTaskRequest.fromActionRequest(request);
        final String modelId = mlPredictionTaskRequest.getModelId();
        final String tenantId = mlPredictionTaskRequest.getTenantId();
        if (!TenantAwareHelper.validateTenantId(this.mlFeatureEnabledSetting, tenantId, listener)) {
            // validateTenantId was handed the listener; presumably it has already reported the failure.
            return;
        }
        User user = mlPredictionTaskRequest.getUser();
        if (user == null) {
            // No user on the request: take it from the client's current context and record it on the request.
            user = RestActionUtils.getUserContext(this.client);
            mlPredictionTaskRequest.setUser(user);
        }
        final User userInfo = user;
        try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
            final ActionListener<MLTaskResponse> wrappedListener = ActionListenerHelper.runBefore(listener, () -> context.restore());
            final MLModel cachedMlModel = this.modelCacheHelper.getModelInfo(modelId);
            final ActionListener<MLModel> modelActionListener = new ActionListener<MLModel>() {

                public void onResponse(MLModel mlModel) {
                    context.restore();
                    TransportPredictionTaskAction.this
                        .onModelResolved(mlModel, mlPredictionTaskRequest, modelId, tenantId, userInfo, wrappedListener);
                }

                public void onFailure(Exception e) {
                    // Trailing Throwable argument is logged as the exception (stack trace kept).
                    log.error("Failed to find model {}", modelId, e);
                    wrappedListener.onFailure(e);
                }
            };
            if (cachedMlModel != null) {
                modelActionListener.onResponse(cachedMlModel);
            } else {
                this.mlModelManager.getModel(modelId, tenantId, modelActionListener);
            }
        }
    }

    /**
     * Continues the pipeline once the model is known: caches it, rejects local (DL) models when they
     * are disabled, stamps the algorithm on the input and starts the model-group access check.
     */
    private void onModelResolved(
        MLModel mlModel,
        MLPredictionTaskRequest mlPredictionTaskRequest,
        String modelId,
        String tenantId,
        User userInfo,
        ActionListener<MLTaskResponse> wrappedListener
    ) {
        this.modelCacheHelper.setModelInfo(modelId, mlModel);
        final FunctionName functionName = mlModel.getAlgorithm();
        if (FunctionName.isDLModel(functionName) && !this.mlFeatureEnabledSetting.isLocalModelEnabled()) {
            // Report through the listener instead of throwing: when the model was loaded asynchronously a
            // throw from this callback may never reach the caller.
            wrappedListener
                .onFailure(
                    new IllegalStateException(
                        "Local Model is currently disabled. To enable it, update the setting \"plugins.ml_commons.local_model.enabled\" to true."
                    )
                );
            return;
        }
        mlPredictionTaskRequest.getMlInput().setAlgorithm(functionName);
        // wrap(...) is passed directly (no cast) so its type argument is inferred as Boolean from the
        // parameter type; a cast here would leave `access` typed as Object.
        this.modelAccessControlHelper
            .validateModelGroupAccess(
                userInfo,
                this.mlFeatureEnabledSetting,
                tenantId,
                mlModel.getModelGroupId(),
                this.client,
                this.metadataClient,
                ActionListenerHelper
                    .wrap(
                        access -> onModelAccessChecked(access, functionName, mlPredictionTaskRequest, modelId, userInfo, wrappedListener),
                        e -> onModelAccessFailure(e, modelId, wrappedListener)
                    )
            );
    }

    /**
     * Runs the checks that follow the access check, in order, failing fast on the first one that does
     * not pass:
     * <ol>
     *   <li>access must be granted (otherwise FORBIDDEN);</li>
     *   <li>the cached enabled flag must not be {@code false} (otherwise FORBIDDEN);</li>
     *   <li>for DL models only, the model-level and then the user-level rate limiter, when present,
     *       must admit the request (otherwise TOO_MANY_REQUESTS).</li>
     * </ol>
     * If all pass, the input is validated against the model's schema and the prediction is executed.
     * Anything thrown here (e.g. by {@link #validateInputSchema}) propagates out of the access-check
     * success callback.
     */
    private void onModelAccessChecked(
        boolean access,
        FunctionName functionName,
        MLPredictionTaskRequest mlPredictionTaskRequest,
        String modelId,
        User userInfo,
        ActionListener<MLTaskResponse> wrappedListener
    ) {
        if (!access) {
            wrappedListener
                .onFailure(
                    new SkyliteStatusException(
                        "User Doesn't have privilege to perform this operation on this model",
                        RestStatus.FORBIDDEN,
                        new Object[0]
                    )
                );
            return;
        }
        // Looked up once: a null flag means "unknown" and is treated as enabled.
        final Boolean isModelEnabled = this.modelCacheHelper.getIsModelEnabled(modelId);
        if (Boolean.FALSE.equals(isModelEnabled)) {
            wrappedListener.onFailure(new SkyliteStatusException("Model is disabled.", RestStatus.FORBIDDEN, new Object[0]));
            return;
        }
        if (FunctionName.isDLModel(functionName)) {
            // Throttling is only applied to DL models.
            if (this.modelCacheHelper.getRateLimiter(modelId) != null && !this.modelCacheHelper.getRateLimiter(modelId).request()) {
                wrappedListener
                    .onFailure(new SkyliteStatusException("Request is throttled at model level.", RestStatus.TOO_MANY_REQUESTS, new Object[0]));
                return;
            }
            if (userInfo != null
                && this.modelCacheHelper.getUserRateLimiter(modelId, userInfo.getName()) != null
                && !this.modelCacheHelper.getUserRateLimiter(modelId, userInfo.getName()).request()) {
                wrappedListener
                    .onFailure(
                        new SkyliteStatusException(
                            "Request is throttled at user level. If you think there's an issue, please contact your cluster admin.",
                            RestStatus.TOO_MANY_REQUESTS,
                            new Object[0]
                        )
                    );
                return;
            }
        }
        this.validateInputSchema(modelId, mlPredictionTaskRequest.getMlInput());
        this.executePredict(mlPredictionTaskRequest, wrappedListener, modelId);
    }

    /**
     * Translates a failure delivered to the access-check listener into the error reported to the caller.
     * <ul>
     *   <li>Status exceptions keep their message and status.</li>
     *   <li>{@link MLResourceNotFoundException} becomes NOT_FOUND.</li>
     *   <li>{@link CircuitBreakingException} is passed through unchanged.</li>
     *   <li>Anything else becomes a generic FORBIDDEN error.</li>
     * </ul>
     * NOTE(review): if {@code ActionListenerHelper.wrap} routes exceptions thrown by the success callback
     * here (as similar helpers do), errors from input validation and prediction also pass through this
     * mapping - confirm.
     */
    private void onModelAccessFailure(Exception e, String modelId, ActionListener<MLTaskResponse> wrappedListener) {
        // Trailing Throwable argument is logged as the exception (stack trace kept).
        log.error("Failed to Validate Access for ModelId {}", modelId, e);
        if (e instanceof SkyliteStatusException) {
            RestStatus status = RestStatus.fromCode(((SkyliteStatusException) e).status().getStatus());
            wrappedListener.onFailure(new SkyliteStatusException(e.getMessage(), status, new Object[0]));
        } else if (e instanceof MLResourceNotFoundException) {
            wrappedListener.onFailure(new SkyliteStatusException(e.getMessage(), RestStatus.NOT_FOUND, new Object[0]));
        } else if (e instanceof CircuitBreakingException) {
            wrappedListener.onFailure(e);
        } else {
            wrappedListener
                .onFailure(new SkyliteStatusException("Failed to Validate Access for ModelId " + modelId, RestStatus.FORBIDDEN, new Object[0]));
        }
    }

    /**
     * Hands the request to the predict task runner and, once the runner's listener has been notified,
     * records the request duration (in milliseconds) and refreshes the model's last-access time.
     *
     * <p>The function name comes from the model cache when available, otherwise from the request input.
     */
    private void executePredict(MLPredictionTaskRequest mlPredictionTaskRequest, ActionListener<MLTaskResponse> wrappedListener, String modelId) {
        final String requestId = mlPredictionTaskRequest.getRequestID();
        log.debug("receive predict request {} for model {}", requestId, mlPredictionTaskRequest.getModelId());
        final long startTime = System.nanoTime();
        final FunctionName functionName = this.modelCacheHelper
            .getOptionalFunctionName(modelId)
            .orElse(mlPredictionTaskRequest.getMlInput().getAlgorithm());
        this.mlPredictTaskRunner.run(functionName, mlPredictionTaskRequest, this.transportService, ActionListenerHelper.runAfter(wrappedListener, () -> {
            final long endTime = System.nanoTime();
            final double durationInMs = (double) (endTime - startTime) / 1000000.0;
            this.modelCacheHelper.addPredictRequestDuration(modelId, durationInMs);
            this.modelCacheHelper.refreshLastAccessTime(modelId);
            log.debug("completed predict request {} for model {}", requestId, modelId);
        }));
    }

    /**
     * Validates {@code mlInput} against the model's cached "input" interface schema; does nothing when
     * the model has no cached interface or no "input" entry.
     *
     * @param modelId the model whose cached interface supplies the schema
     * @param mlInput the input to serialize to JSON and validate
     * @throws SkyliteStatusException with BAD_REQUEST if serialization, preprocessing or schema
     *     validation fails
     */
    public void validateInputSchema(String modelId, MLInput mlInput) {
        if (this.modelCacheHelper.getModelInterface(modelId) != null && this.modelCacheHelper.getModelInterface(modelId).get("input") != null) {
            final String inputSchemaString = (String) this.modelCacheHelper.getModelInterface(modelId).get("input");
            try {
                final String inputString = mlInput.toXContent(MediaTypeRegistry.JSON.contentBuilder(), ToXContent.EMPTY_PARAMS).toString();
                final String processedInputString = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(inputString);
                MLNodeUtils.validateSchema(inputSchemaString, processedInputString);
            } catch (Exception e) {
                throw new SkyliteStatusException("Error validating input schema: " + e.getMessage(), RestStatus.BAD_REQUEST, new Object[0]);
            }
        }
    }
}

