/*
 * Decompiled with CFR 0.152.
 */
package io.lucenia.ml.common.action.register;

import io.lucenia.ml.common.engine.systemindices.MLIndicesHandler;
import io.lucenia.ml.common.helpers.ConnectorAccessControlHelper;
import io.lucenia.ml.common.model.MLModelGroupManager;
import io.lucenia.ml.common.model.MLModelManager;
import io.lucenia.ml.common.model.ModelAccessControlHelper;
import io.lucenia.ml.common.rest.RestActionUtils;
import io.lucenia.ml.common.task.MLTaskDispatcher;
import io.lucenia.ml.common.task.MLTaskManager;
import io.skylite.common.action.ActionListener;
import io.skylite.core.action.ActionFilters;
import io.skylite.core.action.ActionListenerHelper;
import io.skylite.core.action.ActionListenerResponseHandler;
import io.skylite.core.action.ActionRequest;
import io.skylite.core.action.ActionType;
import io.skylite.core.action.index.IndexResponse;
import io.skylite.core.action.search.SearchResponse;
import io.skylite.core.action.support.HandledTransportAction;
import io.skylite.core.client.Client;
import io.skylite.core.client.metadata.MetadataClient;
import io.skylite.core.cluster.node.DiscoveryNode;
import io.skylite.core.cluster.service.ClusterService;
import io.skylite.core.common.concurrent.ThreadContext;
import io.skylite.core.common.inject.Inject;
import io.skylite.core.security.auth.User;
import io.skylite.core.settings.Settings;
import io.skylite.core.tasks.Task;
import io.skylite.core.threadpool.ThreadPool;
import io.skylite.core.transport.TransportRequest;
import io.skylite.core.transport.TransportResponseHandler;
import io.skylite.core.transport.TransportService;
import io.skylite.ml.common.FunctionName;
import io.skylite.ml.common.cluster.DiscoveryNodeHelper;
import io.skylite.ml.common.cluster.TenantAwareHelper;
import io.skylite.ml.common.connector.Connector;
import io.skylite.ml.common.connector.ConnectorAction;
import io.skylite.ml.common.engine.ModelDownloader;
import io.skylite.ml.common.exception.MLExceptionUtils;
import io.skylite.ml.common.settings.MLCommonsSettings;
import io.skylite.ml.common.settings.MLFeatureEnabledSetting;
import io.skylite.ml.common.stats.MLStats;
import io.skylite.ml.common.task.MLTask;
import io.skylite.ml.common.task.MLTaskState;
import io.skylite.ml.common.task.MLTaskType;
import io.skylite.ml.common.transport.connector.MLCreateConnectorAction;
import io.skylite.ml.common.transport.connector.MLCreateConnectorInput;
import io.skylite.ml.common.transport.connector.MLCreateConnectorRequest;
import io.skylite.ml.common.transport.forward.MLForwardInput;
import io.skylite.ml.common.transport.forward.MLForwardRequest;
import io.skylite.ml.common.transport.forward.MLForwardRequestType;
import io.skylite.ml.common.transport.forward.MLForwardResponse;
import io.skylite.ml.common.transport.model_group.MLRegisterModelGroupInput;
import io.skylite.ml.common.transport.register.MLRegisterModelInput;
import io.skylite.ml.common.transport.register.MLRegisterModelRequest;
import io.skylite.ml.common.transport.register.MLRegisterModelResponse;
import io.skylite.ml.common.utils.ModelInterfaceUtils;
import java.time.Instant;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.util.Strings;

/**
 * Transport action that handles "register model" requests.
 *
 * <p>The visible flow is:
 * <ol>
 *   <li>{@link #doExecute} validates the tenant and feature flags, then resolves the model group
 *       (by explicit id, or by looking up a group named after the model).</li>
 *   <li>{@link #checkUserAccess} verifies the caller may use an existing model group.</li>
 *   <li>{@link #doRegister} performs the extra connector checks required for
 *       {@link FunctionName#REMOTE} models.</li>
 *   <li>{@link #createModelGroup} creates a model group when none was supplied or found.</li>
 *   <li>{@link #registerModel} creates the ML task and either registers the remote model directly
 *       or forwards the registration to the node picked by the task dispatcher.</li>
 * </ol>
 *
 * <p>The three settings-backed fields are reassigned from cluster-settings update callbacks and are
 * therefore declared {@code volatile} so request-handling threads observe the latest value.
 */
public class TransportRegisterModelAction
extends HandledTransportAction<ActionRequest, MLRegisterModelResponse> {
    private static final Logger log = LogManager.getLogger(TransportRegisterModelAction.class);

    /** Name under which this action is registered with the transport layer. */
    private static final String REGISTER_MODEL_ACTION_NAME = "cluster:admin/lucenia/ml/register_model";

    /**
     * Name of the internal action used to forward the registration to the worker node.
     * NOTE(review): this uses the "opensearch" namespace while the register action uses "lucenia" -
     * confirm this matches the name the forward action is actually registered under.
     */
    private static final String FORWARD_ACTION_NAME = "cluster:admin/opensearch/mlinternal/forward";

    TransportService transportService;
    ModelDownloader modelHelper;
    MLIndicesHandler mlIndicesHandler;
    MLModelManager mlModelManager;
    MLTaskManager mlTaskManager;
    ClusterService clusterService;
    ThreadPool threadPool;
    Client client;
    private final MetadataClient sdkClient;
    Settings settings;
    DiscoveryNodeHelper nodeFilter;
    MLTaskDispatcher mlTaskDispatcher;
    MLStats mlStats;
    /** Regex a model URL must match; refreshed by a settings update consumer. */
    volatile String trustedUrlRegex;
    /** Regexes passed to connector URL validation; refreshed by a settings update consumer. */
    private volatile List<String> trustedConnectorEndpointsRegex;
    ModelAccessControlHelper modelAccessControlHelper;
    /** Whether registering a model via URL is allowed; refreshed by a settings update consumer. */
    private volatile boolean isModelUrlAllowed;
    ConnectorAccessControlHelper connectorAccessControlHelper;
    MLModelGroupManager mlModelGroupManager;
    private final MLFeatureEnabledSetting mlFeatureEnabledSetting;

    /**
     * Wires the collaborators and subscribes to updates of the trusted-URL regex, the
     * allow-model-URL flag and the trusted connector endpoint regexes.
     */
    @Inject
    public TransportRegisterModelAction(TransportService transportService, ActionFilters actionFilters, ModelDownloader modelHelper, MLIndicesHandler mlIndicesHandler, MLModelManager mlModelManager, MLTaskManager mlTaskManager, ClusterService clusterService, Settings settings, ThreadPool threadPool, Client client, MetadataClient sdkClient, DiscoveryNodeHelper nodeFilter, MLTaskDispatcher mlTaskDispatcher, MLStats mlStats, ModelAccessControlHelper modelAccessControlHelper, ConnectorAccessControlHelper connectorAccessControlHelper, MLModelGroupManager mlModelGroupManager, MLFeatureEnabledSetting mlFeatureEnabledSetting) {
        super(REGISTER_MODEL_ACTION_NAME, transportService, actionFilters, MLRegisterModelRequest::new);
        this.transportService = transportService;
        this.modelHelper = modelHelper;
        this.mlIndicesHandler = mlIndicesHandler;
        this.mlModelManager = mlModelManager;
        this.mlTaskManager = mlTaskManager;
        this.clusterService = clusterService;
        this.threadPool = threadPool;
        this.client = client;
        this.sdkClient = sdkClient;
        this.nodeFilter = nodeFilter;
        this.mlTaskDispatcher = mlTaskDispatcher;
        this.mlStats = mlStats;
        this.modelAccessControlHelper = modelAccessControlHelper;
        this.connectorAccessControlHelper = connectorAccessControlHelper;
        this.mlModelGroupManager = mlModelGroupManager;
        this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
        this.settings = settings;

        // Read each setting once, then keep the field in sync with dynamic cluster-setting updates.
        this.trustedUrlRegex = MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX.get(settings);
        clusterService.getClusterSettings().addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX, it -> {
            this.trustedUrlRegex = it;
        });
        this.isModelUrlAllowed = MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL.get(settings);
        clusterService.getClusterSettings().addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL, it -> {
            this.isModelUrlAllowed = it;
        });
        this.trustedConnectorEndpointsRegex = MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.get(settings);
        clusterService.getClusterSettings().addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX, it -> {
            this.trustedConnectorEndpointsRegex = it;
        });
    }

    /**
     * Entry point for a register-model request.
     *
     * @param task     the transport task (not used here)
     * @param request  the incoming request, converted via {@code MLRegisterModelRequest.fromActionRequest}
     * @param listener receives the registration response or the failure
     * @throws IllegalStateException    if a DL (local) model is requested while local models are disabled
     * @throws IllegalArgumentException if a model URL is supplied while registering via URL is disabled
     */
    @Override
    protected void doExecute(Task task, ActionRequest request, ActionListener<MLRegisterModelResponse> listener) {
        MLRegisterModelRequest registerModelRequest = MLRegisterModelRequest.fromActionRequest(request);
        MLRegisterModelInput registerModelInput = registerModelRequest.getRegisterModelInput();
        // The helper receives the listener; presumably it reports the failure itself when it returns false.
        if (!TenantAwareHelper.validateTenantId(this.mlFeatureEnabledSetting, registerModelInput.getTenantId(), listener)) {
            return;
        }
        if (FunctionName.isDLModel(registerModelInput.getFunctionName()) && !this.mlFeatureEnabledSetting.isLocalModelEnabled()) {
            throw new IllegalStateException("Local Model is currently disabled. To enable it, update the setting \"plugins.ml_commons.local_model.enabled\" to true.");
        }
        if (registerModelInput.getUrl() != null && !this.isModelUrlAllowed) {
            throw new IllegalArgumentException("To upload custom model user needs to enable allow_registering_model_via_url settings. Otherwise please use Lucenia pre-trained models.");
        }
        registerModelInput.setIsHidden(RestActionUtils.isSuperAdminUser(this.clusterService, this.client));

        if (!Strings.isEmpty(registerModelInput.getModelGroupId())) {
            // Caller named a model group explicitly: only the access check is needed.
            this.checkUserAccess(registerModelInput, listener, false);
            return;
        }

        // No group id given: look for an existing group that already carries the model's name.
        this.mlModelGroupManager.validateUniqueModelGroupName(registerModelInput.getModelName(), registerModelInput.getTenantId(), ActionListenerHelper.wrap(modelGroups -> {
            boolean groupWithSameNameExists = modelGroups != null
                && modelGroups.getHits().getTotalHits() != null
                && modelGroups.getHits().getTotalHits().value() != 0L;
            if (groupWithSameNameExists) {
                // Reuse the first matching group, subject to the caller having access to it.
                String modelGroupIdOfTheNameProvided = modelGroups.getHits().getAt(0).getId();
                registerModelInput.setModelGroupId(modelGroupIdOfTheNameProvided);
                this.checkUserAccess(registerModelInput, listener, true);
            } else {
                this.doRegister(registerModelInput, listener);
            }
        }, e -> {
            log.error("Failed to search model group index", e);
            listener.onFailure(e);
        }));
    }

    /**
     * Verifies that the current user may use the model group referenced by the input and, if so,
     * continues with {@link #doRegister}. Otherwise the listener is failed with an
     * {@link IllegalArgumentException} whose message depends on how the group id was obtained.
     *
     * @param registerModelInput         the registration input; its model group id must already be set
     * @param listener                   receives the failure when access is denied
     * @param isModelNameAlreadyExisting {@code true} when the group id was derived from a name lookup
     *                                   rather than supplied by the caller
     */
    private void checkUserAccess(MLRegisterModelInput registerModelInput, ActionListener<MLRegisterModelResponse> listener, boolean isModelNameAlreadyExisting) {
        User user = RestActionUtils.getUserContext(this.client);
        this.modelAccessControlHelper.validateModelGroupAccess(user, this.mlFeatureEnabledSetting, registerModelInput.getTenantId(), registerModelInput.getModelGroupId(), this.client, this.sdkClient, ActionListenerHelper.wrap(access -> {
            // Null-safe on purpose: anything other than an explicit TRUE is treated as "denied",
            // matching the connector access check in doRegister.
            if (Boolean.TRUE.equals(access)) {
                this.doRegister(registerModelInput, listener);
                return;
            }
            if (!isModelNameAlreadyExisting) {
                listener.onFailure(new IllegalArgumentException("You don't have permissions to perform this operation on this model."));
                return;
            }
            // The group was found by name; explain the name clash to the caller.
            boolean looksLikePretrainedModel = registerModelInput.getUrl() == null
                && registerModelInput.getFunctionName() != FunctionName.REMOTE
                && registerModelInput.getConnectorId() == null;
            if (looksLikePretrainedModel) {
                listener.onFailure(new IllegalArgumentException("Without a model group ID, the system will use the model name {" + registerModelInput.getModelName() + "} to create a new model group. However, this name is taken by another group with id {" + registerModelInput.getModelGroupId() + "} you can't access. To register this pre-trained model, create a new model group and use its ID in your request."));
            } else {
                listener.onFailure(new IllegalArgumentException("The name {" + registerModelInput.getModelName() + "} you provided is unavailable because it is used by another model group with id {" + registerModelInput.getModelGroupId() + "} to which you do not have access. Please provide a different name."));
            }
        }, listener::onFailure));
    }

    /**
     * Runs the connector-related checks for {@link FunctionName#REMOTE} models and then proceeds to
     * {@link #createModelGroup}. Non-remote models go straight to {@link #createModelGroup}.
     *
     * <p>For remote models with a connector id, access to that connector is validated and, when no
     * model interface was supplied, the connector is fetched so the interface fields can be filled
     * in from it. For remote models without a connector id, the embedded connector is validated and
     * a dry-run create-connector request is executed first.
     *
     * @param registerModelInput the registration input
     * @param listener           receives the final response or failure
     * @throws IllegalArgumentException if a remote model has neither a connector id nor a usable
     *                                  embedded connector (see {@link #validateInternalConnector})
     */
    private void doRegister(MLRegisterModelInput registerModelInput, ActionListener<MLRegisterModelResponse> listener) {
        if (FunctionName.REMOTE != registerModelInput.getFunctionName()) {
            this.createModelGroup(registerModelInput, listener);
            return;
        }

        if (Strings.isNotBlank(registerModelInput.getConnectorId())) {
            this.connectorAccessControlHelper.validateConnectorAccess(this.sdkClient, this.client, registerModelInput.getConnectorId(), registerModelInput.getTenantId(), this.mlFeatureEnabledSetting, ActionListenerHelper.wrap(hasConnectorAccess -> {
                if (!Boolean.TRUE.equals(hasConnectorAccess)) {
                    listener.onFailure(new IllegalArgumentException("You don't have permission to use the connector provided, connector id: " + registerModelInput.getConnectorId()));
                    return;
                }
                if (registerModelInput.getModelInterface() != null) {
                    this.createModelGroup(registerModelInput, listener);
                    return;
                }
                // No model interface supplied: derive it from the stored connector.
                this.mlModelManager.getConnector(registerModelInput.getConnectorId(), registerModelInput.getTenantId(), ActionListenerHelper.wrap(connector -> {
                    ModelInterfaceUtils.updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInput, connector);
                    this.createModelGroup(registerModelInput, listener);
                }, listener::onFailure));
            }, e -> {
                // Pass the exception to the logger so the real cause is not lost.
                log.error("You don't have permission to use the connector provided, connector id: {}", registerModelInput.getConnectorId(), e);
                listener.onFailure(e);
            }));
            return;
        }

        // Remote model with an embedded connector definition instead of a connector id.
        this.validateInternalConnector(registerModelInput);
        MLCreateConnectorRequest mlCreateConnectorRequest = this.createDryRunConnectorRequest();
        this.client.execute(MLCreateConnectorAction.INSTANCE, mlCreateConnectorRequest, ActionListenerHelper.wrap(res -> {
            log.info("Dry run create connector successfully");
            if (registerModelInput.getModelInterface() == null) {
                ModelInterfaceUtils.updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInput);
            }
            this.createModelGroup(registerModelInput, listener);
        }, e -> {
            log.error(e.getMessage(), e);
            listener.onFailure(e);
        }));
    }

    /**
     * Ensures the input has a model group id, creating a new model group named after the model
     * when it has none, and then calls {@link #registerModel}. The input's
     * {@code doesVersionCreateModelGroup} flag records whether a group was created here.
     *
     * @param registerModelInput the registration input
     * @param listener           receives the final response or failure
     */
    private void createModelGroup(MLRegisterModelInput registerModelInput, ActionListener<MLRegisterModelResponse> listener) {
        if (!Strings.isEmpty(registerModelInput.getModelGroupId())) {
            registerModelInput.setDoesVersionCreateModelGroup(false);
            this.registerModel(registerModelInput, listener);
            return;
        }
        MLRegisterModelGroupInput mlRegisterModelGroupInput = this.createRegisterModelGroupRequest(registerModelInput);
        this.mlModelGroupManager.createModelGroup(mlRegisterModelGroupInput, ActionListenerHelper.wrap(modelGroupId -> {
            registerModelInput.setModelGroupId(modelGroupId);
            registerModelInput.setDoesVersionCreateModelGroup(true);
            this.registerModel(registerModelInput, listener);
        }, e -> {
            MLExceptionUtils.logException("Failed to create Model Group", e, log);
            listener.onFailure(e);
        }));
    }

    /**
     * Builds a create-connector request whose input only has the {@code dryRun} flag set.
     *
     * @return the dry-run create-connector request
     */
    private MLCreateConnectorRequest createDryRunConnectorRequest() {
        MLCreateConnectorInput createConnectorInput = MLCreateConnectorInput.builder().dryRun(true).build();
        return new MLCreateConnectorRequest(createConnectorInput);
    }

    /**
     * Validates the connector embedded in the registration input: it must be present, must resolve
     * a PREDICT action endpoint, and its URL is checked against the trusted connector endpoint
     * regexes.
     *
     * @param registerModelInput the registration input carrying the embedded connector
     * @throws IllegalArgumentException if the connector or its PREDICT endpoint is missing
     */
    private void validateInternalConnector(MLRegisterModelInput registerModelInput) {
        Connector connector = registerModelInput.getConnector();
        if (connector == null) {
            log.error("You must provide connector content when creating a remote model without providing connector id!");
            throw new IllegalArgumentException("You must provide connector content when creating a remote model without connector id!");
        }
        if (connector.getActionEndpoint(ConnectorAction.ActionType.PREDICT.name(), connector.getParameters()) == null) {
            log.error("Connector endpoint is required when creating a remote model without connector id!");
            throw new IllegalArgumentException("Connector endpoint is required when creating a remote model without connector id!");
        }
        connector.validateConnectorURL(this.trustedConnectorEndpointsRegex);
    }

    /**
     * Creates the ML task for the registration and starts it.
     *
     * <p>Remote models are handled synchronously: the task is created and the model manager is
     * asked to register the remote model, which is handed the caller's listener. All other models
     * are asynchronous: a worker node is chosen by the dispatcher, the task is created, the caller
     * is answered immediately with the task id in state {@code CREATED}, and the registration is
     * forwarded to the worker node.
     *
     * @param registerModelInput the registration input
     * @param listener           receives the response or failure
     * @throws IllegalArgumentException if a model URL is supplied that does not match the trusted URL regex
     */
    private void registerModel(MLRegisterModelInput registerModelInput, ActionListener<MLRegisterModelResponse> listener) {
        String url = registerModelInput.getUrl();
        // Compile the regex only when there is a URL to check.
        if (url != null && !Pattern.compile(this.trustedUrlRegex).matcher(url).find()) {
            throw new IllegalArgumentException("URL can't match trusted url regex");
        }

        boolean isAsync = registerModelInput.getFunctionName() != FunctionName.REMOTE;
        // One timestamp so createTime and lastUpdateTime start out identical.
        Instant now = Instant.now();
        MLTask mlTask = MLTask.builder()
            .async(isAsync)
            .taskType(MLTaskType.REGISTER_MODEL)
            .functionName(registerModelInput.getFunctionName())
            .createTime(now)
            .lastUpdateTime(now)
            .state(MLTaskState.CREATED)
            .workerNodes(List.of(this.clusterService.localNode().getId()))
            .tenantId(registerModelInput.getTenantId())
            .build();

        if (!isAsync) {
            this.mlTaskManager.createMLTask(mlTask, ActionListenerHelper.wrap(response -> {
                mlTask.setTaskId(response.getId());
                this.mlModelManager.registerMLRemoteModel(this.sdkClient, registerModelInput, mlTask, listener);
            }, e -> {
                MLExceptionUtils.logException("Failed to register model", e, log);
                listener.onFailure(e);
            }));
            return;
        }

        this.mlTaskDispatcher.dispatch(registerModelInput.getFunctionName(), ActionListenerHelper.wrap(node -> {
            String nodeId = node.getId();
            // The dispatched node replaces the local node as the task's worker.
            mlTask.setWorkerNodes(List.of(nodeId));
            this.mlTaskManager.createMLTask(mlTask, ActionListenerHelper.wrap(response -> {
                String taskId = response.getId();
                mlTask.setTaskId(taskId);
                // Answer the caller right away; the outcome of the forwarded work is recorded on the task.
                listener.onResponse(new MLRegisterModelResponse(taskId, MLTaskState.CREATED.name()));

                ActionListener<MLForwardResponse> forwardActionListener = ActionListenerHelper.wrap(res -> {
                    log.debug("Register model response: {}", res);
                    // When another node did the work, drop the task entry added to the local task manager below.
                    if (!this.clusterService.localNode().getId().equals(nodeId)) {
                        this.mlTaskManager.remove(taskId);
                    }
                }, ex -> {
                    MLExceptionUtils.logException("Failed to register model", ex, log);
                    this.mlTaskManager.updateMLTask(taskId, registerModelInput.getTenantId(), Map.of("error", MLExceptionUtils.getRootCauseMessage(ex), "state", MLTaskState.FAILED), MLTaskManager.TASK_SEMAPHORE_TIMEOUT, true);
                });

                // The thread context is stashed for the duration of the forward call and restored on close.
                try (ThreadContext.StoredContext ignored = this.client.threadPool().getThreadContext().stashContext()) {
                    this.mlTaskManager.add(mlTask, Arrays.asList(nodeId));
                    MLForwardInput forwardInput = MLForwardInput.builder()
                        .requestType(MLForwardRequestType.REGISTER_MODEL)
                        .registerModelInput(registerModelInput)
                        .mlTask(mlTask)
                        .build();
                    MLForwardRequest forwardRequest = new MLForwardRequest(forwardInput);
                    this.transportService.sendRequest(node, FORWARD_ACTION_NAME, forwardRequest, new ActionListenerResponseHandler<>(forwardActionListener, MLForwardResponse::new));
                } catch (Exception e) {
                    // The caller has already been answered, so the failure is recorded on the task instead.
                    forwardActionListener.onFailure(e);
                }
            }, e -> {
                MLExceptionUtils.logException("Failed to register model", e, log);
                listener.onFailure(e);
            }));
        }, e -> {
            MLExceptionUtils.logException("Failed to register model", e, log);
            listener.onFailure(e);
        }));
    }

    /**
     * Builds the model-group creation input from the registration input, using the model name as
     * the group name and copying description, backend roles, access mode, the
     * add-all-backend-roles flag and tenant id.
     *
     * @param registerModelInput the registration input to copy from
     * @return the model-group creation input
     */
    private MLRegisterModelGroupInput createRegisterModelGroupRequest(MLRegisterModelInput registerModelInput) {
        return MLRegisterModelGroupInput.builder()
            .name(registerModelInput.getModelName())
            .description(registerModelInput.getDescription())
            .backendRoles(registerModelInput.getBackendRoles())
            .modelAccessMode(registerModelInput.getAccessMode())
            .isAddAllBackendRoles(registerModelInput.getAddAllBackendRoles())
            .tenantId(registerModelInput.getTenantId())
            .build();
    }
}

