/*
 * Decompiled with CFR 0.152.
 */
package io.lucenia.ml.common.action.batch;

import io.lucenia.ml.common.model.MLModelManager;
import io.lucenia.ml.common.task.MLTaskManager;
import io.skylite.SkyliteStatusException;
import io.skylite.common.action.ActionListener;
import io.skylite.core.action.ActionFilters;
import io.skylite.core.action.ActionListenerHelper;
import io.skylite.core.action.ActionRequest;
import io.skylite.core.action.index.IndexResponse;
import io.skylite.core.action.support.HandledTransportAction;
import io.skylite.core.client.Client;
import io.skylite.core.cluster.service.ClusterService;
import io.skylite.core.common.Strings;
import io.skylite.core.common.inject.Inject;
import io.skylite.core.rest.RestStatus;
import io.skylite.core.settings.Settings;
import io.skylite.core.tasks.Task;
import io.skylite.core.threadpool.ThreadPool;
import io.skylite.core.transport.TransportService;
import io.skylite.ml.common.engine.MLEngineClassLoader;
import io.skylite.ml.common.engine.ingest.Ingestable;
import io.skylite.ml.common.exception.MLExceptionUtils;
import io.skylite.ml.common.settings.MLCommonsSettings;
import io.skylite.ml.common.settings.MLFeatureEnabledSetting;
import io.skylite.ml.common.task.MLTask;
import io.skylite.ml.common.task.MLTaskState;
import io.skylite.ml.common.task.MLTaskType;
import io.skylite.ml.common.transport.batch.MLBatchIngestionInput;
import io.skylite.ml.common.transport.batch.MLBatchIngestionRequest;
import io.skylite.ml.common.transport.batch.MLBatchIngestionResponse;
import java.time.Instant;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

public class TransportBatchIngestionAction
extends HandledTransportAction<ActionRequest, MLBatchIngestionResponse> {
    private static final Logger log = LogManager.getLogger(TransportBatchIngestionAction.class);
    private static final String S3_URI_REGEX = "^s3://([a-zA-Z0-9.-]+)(/.*)?$";
    private static final Pattern S3_URI_PATTERN = Pattern.compile("^s3://([a-zA-Z0-9.-]+)(/.*)?$");
    public static final String TYPE = "type";
    public static final String SOURCE = "source";
    TransportService transportService;
    MLTaskManager mlTaskManager;
    MLModelManager mlModelManager;
    private final Client client;
    private ThreadPool threadPool;
    private MLFeatureEnabledSetting mlFeatureEnabledSetting;
    private volatile Integer batchIngestionBulkSize;

    @Inject
    public TransportBatchIngestionAction(ClusterService clusterService, TransportService transportService, ActionFilters actionFilters, Client client, MLTaskManager mlTaskManager, ThreadPool threadPool, MLModelManager mlModelManager, MLFeatureEnabledSetting mlFeatureEnabledSetting, Settings settings) {
        super("cluster:admin/lucenia/ml/batch_ingestion", transportService, actionFilters, MLBatchIngestionRequest::new);
        this.transportService = transportService;
        this.client = client;
        this.mlTaskManager = mlTaskManager;
        this.threadPool = threadPool;
        this.mlModelManager = mlModelManager;
        this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
        this.batchIngestionBulkSize = (Integer)MLCommonsSettings.ML_COMMONS_BATCH_INGESTION_BULK_SIZE.get(settings);
        clusterService.getClusterSettings().addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_BATCH_INGESTION_BULK_SIZE, it -> {
            this.batchIngestionBulkSize = it;
        });
    }

    protected void doExecute(Task task, ActionRequest request, ActionListener<MLBatchIngestionResponse> listener) {
        MLBatchIngestionRequest mlBatchIngestionRequest = MLBatchIngestionRequest.fromActionRequest((ActionRequest)request);
        MLBatchIngestionInput mlBatchIngestionInput = mlBatchIngestionRequest.getMlBatchIngestionInput();
        try {
            if (!this.mlFeatureEnabledSetting.isOfflineBatchIngestionEnabled().booleanValue()) {
                throw new IllegalStateException("Offline batch ingestion is currently disabled. To enable it, update the setting \"plugins.ml_commons.offline_batch_ingestion_enabled\" to true.");
            }
            this.validateBatchIngestInput(mlBatchIngestionInput);
            if (mlBatchIngestionInput.getConnectorId() != null && (mlBatchIngestionInput.getCredential() == null || mlBatchIngestionInput.getCredential().isEmpty())) {
                this.mlModelManager.getConnectorCredential(mlBatchIngestionInput.getConnectorId(), (ActionListener<Map<String, String>>)ActionListenerHelper.wrap(credentialMap -> {
                    mlBatchIngestionInput.setCredential(credentialMap);
                    this.createMLTaskandExecute(mlBatchIngestionInput, listener);
                }, e -> {
                    log.error(e.getMessage());
                    listener.onFailure((Exception)new SkyliteStatusException("Fail to fetch credentials from the connector in the batch ingestion input: " + e.getMessage(), RestStatus.BAD_REQUEST, new Object[0]));
                }));
            } else {
                this.createMLTaskandExecute(mlBatchIngestionInput, listener);
            }
        }
        catch (IllegalArgumentException e2) {
            log.error(e2.getMessage());
            listener.onFailure((Exception)new SkyliteStatusException("IllegalArgumentException in the batch ingestion input: " + e2.getMessage(), RestStatus.BAD_REQUEST, new Object[0]));
        }
        catch (Exception e3) {
            listener.onFailure(e3);
        }
    }

    protected void createMLTaskandExecute(MLBatchIngestionInput mlBatchIngestionInput, ActionListener<MLBatchIngestionResponse> listener) {
        MLTask mlTask = MLTask.builder().async(true).taskType(MLTaskType.BATCH_INGEST).createTime(Instant.now()).lastUpdateTime(Instant.now()).state(MLTaskState.CREATED).build();
        this.mlModelManager.checkMaxBatchJobTask(mlTask, (ActionListener<Boolean>)ActionListenerHelper.wrap(exceedLimits -> {
            if (exceedLimits.booleanValue()) {
                String error = "Exceeded maximum limit for BATCH_INGEST tasks. To increase the limit, update the plugins.ml_commons.max_batch_ingestion_tasks setting.";
                log.warn(error + " in task " + mlTask.getTaskId());
                listener.onFailure((Exception)new SkyliteStatusException(error, RestStatus.TOO_MANY_REQUESTS, new Object[0]));
            } else {
                this.mlTaskManager.createMLTask(mlTask, (ActionListener<IndexResponse>)ActionListenerHelper.wrap(response -> {
                    String taskId = response.getId();
                    try {
                        mlTask.setTaskId(taskId);
                        this.mlTaskManager.add(mlTask);
                        listener.onResponse((Object)new MLBatchIngestionResponse(taskId, MLTaskType.BATCH_INGEST, MLTaskState.CREATED.name()));
                        String ingestType = (String)mlBatchIngestionInput.getDataSources().get(TYPE);
                        Ingestable ingestable = (Ingestable)MLEngineClassLoader.initInstance((Object)ingestType.toLowerCase(Locale.ROOT), (Object)this.client, Client.class);
                        this.threadPool.executor("lucenia_ml_ingest").execute(() -> this.executeWithErrorHandling(() -> {
                            double successRate = ingestable.ingest(mlBatchIngestionInput, this.batchIngestionBulkSize.intValue());
                            this.handleSuccessRate(successRate, taskId);
                        }, taskId));
                    }
                    catch (Exception ex) {
                        log.error("Failed in batch ingestion", (Throwable)ex);
                        this.mlTaskManager.updateMLTask(taskId, null, Map.of("state", MLTaskState.FAILED, "error", MLExceptionUtils.getRootCauseMessage((Throwable)ex)), MLTaskManager.TASK_SEMAPHORE_TIMEOUT, true);
                        listener.onFailure(ex);
                    }
                }, exception -> {
                    log.error("Failed to create batch ingestion task", (Throwable)exception);
                    listener.onFailure(exception);
                }));
            }
        }, exception -> {
            log.error("Failed to check the maximum BATCH_INGEST Task limits", (Throwable)exception);
            listener.onFailure(exception);
        }));
    }

    protected void executeWithErrorHandling(Runnable task, String taskId) {
        try {
            task.run();
        }
        catch (Strings.JsonFieldMissingException jsonPathNotFoundException) {
            log.error("Error in jsonParse fields", (Throwable)jsonPathNotFoundException);
            this.mlTaskManager.updateMLTask(taskId, null, Map.of("state", MLTaskState.FAILED, "error", jsonPathNotFoundException.getMessage()), MLTaskManager.TASK_SEMAPHORE_TIMEOUT, true);
        }
        catch (Exception e) {
            log.error("Error in ingest, failed to produce a successRate", (Throwable)e);
            this.mlTaskManager.updateMLTask(taskId, null, Map.of("state", MLTaskState.FAILED, "error", MLExceptionUtils.getRootCauseMessage((Throwable)e)), MLTaskManager.TASK_SEMAPHORE_TIMEOUT, true);
        }
    }

    protected void handleSuccessRate(double successRate, String taskId) {
        if (successRate == 100.0) {
            this.mlTaskManager.updateMLTask(taskId, null, Map.of("state", MLTaskState.COMPLETED), 5000L, true);
        } else if (successRate > 0.0) {
            this.mlTaskManager.updateMLTask(taskId, null, Map.of("state", MLTaskState.FAILED, "error", "batch ingestion successful rate is " + successRate), MLTaskManager.TASK_SEMAPHORE_TIMEOUT, true);
        } else {
            this.mlTaskManager.updateMLTask(taskId, null, Map.of("state", MLTaskState.FAILED, "error", "batch ingestion successful rate is 0"), MLTaskManager.TASK_SEMAPHORE_TIMEOUT, true);
        }
    }

    private void validateBatchIngestInput(MLBatchIngestionInput mlBatchIngestionInput) {
        if (mlBatchIngestionInput == null || mlBatchIngestionInput.getDataSources() == null || mlBatchIngestionInput.getDataSources().isEmpty()) {
            throw new IllegalArgumentException("The batch ingest input data source cannot be null");
        }
        if (mlBatchIngestionInput.getCredential() == null && mlBatchIngestionInput.getConnectorId() == null) {
            throw new IllegalArgumentException("The batch ingest credential or connector_id cannot be null");
        }
        Map dataSources = mlBatchIngestionInput.getDataSources();
        if (dataSources.get(TYPE) == null || dataSources.get(SOURCE) == null) {
            throw new IllegalArgumentException("The batch ingest input data source is missing data type or source");
        }
        if (((String)dataSources.get(TYPE)).equalsIgnoreCase("s3")) {
            List s3Uris = (List)dataSources.get(SOURCE);
            if (s3Uris == null || s3Uris.isEmpty()) {
                throw new IllegalArgumentException("The batch ingest input s3Uris is empty");
            }
            Map<Boolean, List<String>> partitionedUris = s3Uris.stream().collect(Collectors.partitioningBy(uri -> S3_URI_PATTERN.matcher((CharSequence)uri).matches()));
            List<String> invalidUris = partitionedUris.get(false);
            if (!invalidUris.isEmpty()) {
                throw new IllegalArgumentException("The following batch ingest input S3 URIs are invalid: " + String.valueOf(invalidUris));
            }
        }
    }
}

