/*
 * Decompiled with CFR 0.152.
 */
package io.lucenia.ml.common.task;

import io.lucenia.ml.common.engine.indices.MLInputDatasetHandler;
import io.lucenia.ml.common.engine.systemindices.MLIndicesHandler;
import io.lucenia.ml.common.task.MLTaskDispatcher;
import io.lucenia.ml.common.task.MLTaskManager;
import io.lucenia.ml.common.task.MLTaskRunner;
import io.skylite.common.action.ActionListener;
import io.skylite.core.action.ActionListenerHelper;
import io.skylite.core.action.ActionListenerResponseHandler;
import io.skylite.core.action.WriteRequest;
import io.skylite.core.action.index.IndexRequest;
import io.skylite.core.action.index.IndexResponse;
import io.skylite.core.action.support.ThreadedActionListener;
import io.skylite.core.client.Client;
import io.skylite.core.cluster.service.ClusterService;
import io.skylite.core.common.concurrent.ThreadContext;
import io.skylite.core.threadpool.ThreadPool;
import io.skylite.core.transport.TransportResponseHandler;
import io.skylite.core.xcontent.MediaTypeRegistry;
import io.skylite.core.xcontent.ToXContent;
import io.skylite.ml.common.breaker.MLCircuitBreakerService;
import io.skylite.ml.common.cluster.DiscoveryNodeHelper;
import io.skylite.ml.common.dataset.MLInputDataType;
import io.skylite.ml.common.dataset.MLInputDataset;
import io.skylite.ml.common.engine.MLEngine;
import io.skylite.ml.common.input.Input;
import io.skylite.ml.common.input.MLInput;
import io.skylite.ml.common.model.MLModel;
import io.skylite.ml.common.output.MLOutput;
import io.skylite.ml.common.output.MLTrainingOutput;
import io.skylite.ml.common.stats.ActionName;
import io.skylite.ml.common.stats.MLActionLevelStat;
import io.skylite.ml.common.stats.MLNodeLevelStat;
import io.skylite.ml.common.stats.MLStats;
import io.skylite.ml.common.task.MLTask;
import io.skylite.ml.common.task.MLTaskState;
import io.skylite.ml.common.task.MLTaskType;
import io.skylite.ml.common.transport.MLTaskResponse;
import io.skylite.ml.common.transport.training.MLTrainingTaskRequest;
import java.time.Instant;
import java.util.List;
import java.util.UUID;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Task runner that executes ML model training requests on the local node.
 *
 * <p>Synchronous requests are given a random task id and trained directly; the caller receives
 * the id of the saved model. Asynchronous requests first create an {@link MLTask} through
 * {@code mlTaskManager.createMLTask}, immediately answer the caller with the created task id,
 * and then train in the background, reporting the outcome through
 * {@code handleAsyncMLTaskComplete} / {@code handleAsyncMLTaskFailure}.
 */
public class MLTrainingTaskRunner extends MLTaskRunner<MLTrainingTaskRequest, MLTaskResponse> {
    private static final Logger log = LogManager.getLogger(MLTrainingTaskRunner.class);

    /** Name of the thread-pool executor that training work is dispatched to. */
    private static final String TRAIN_THREAD_POOL = "lucenia_ml_train";

    /** Index that trained models are written into. */
    private static final String ML_MODEL_INDEX = ".plugins-ml-model";

    private final ThreadPool threadPool;
    private final ClusterService clusterService;
    private final Client client;
    private final MLIndicesHandler mlIndicesHandler;
    private final MLInputDatasetHandler mlInputDatasetHandler;
    protected final DiscoveryNodeHelper nodeHelper;
    private final MLEngine mlEngine;

    public MLTrainingTaskRunner(ThreadPool threadPool, ClusterService clusterService, Client client, MLTaskManager mlTaskManager, MLStats mlStats, MLIndicesHandler mlIndicesHandler, MLInputDatasetHandler mlInputDatasetHandler, MLTaskDispatcher mlTaskDispatcher, MLCircuitBreakerService mlCircuitBreakerService, DiscoveryNodeHelper nodeHelper, MLEngine mlEngine) {
        super(mlTaskManager, mlStats, nodeHelper, mlTaskDispatcher, mlCircuitBreakerService, clusterService);
        this.threadPool = threadPool;
        this.clusterService = clusterService;
        this.client = client;
        this.mlIndicesHandler = mlIndicesHandler;
        this.mlInputDatasetHandler = mlInputDatasetHandler;
        this.nodeHelper = nodeHelper;
        this.mlEngine = mlEngine;
    }

    @Override
    protected String getTransportActionName() {
        return "cluster:admin/lucenia/ml/train";
    }

    @Override
    protected TransportResponseHandler<MLTaskResponse> getResponseHandler(ActionListener<MLTaskResponse> listener) {
        return new ActionListenerResponseHandler<>(listener, MLTaskResponse::new);
    }

    /**
     * Creates an {@link MLTask} for the request (state CREATED, worker = local node) and starts
     * training it.
     *
     * @param request  training request; its {@code isAsync()} flag selects the async or sync path
     * @param listener for async requests, receives the task id as soon as the task is created;
     *                 for sync requests, receives the final training output
     */
    @Override
    protected void executeTask(MLTrainingTaskRequest request, ActionListener<MLTaskResponse> listener) {
        MLInput mlInput = request.getMlInput();
        MLInputDataType inputDataType = mlInput.getInputDataset().getInputDataType();
        Instant now = Instant.now();
        MLTask mlTask = MLTask.builder()
            .taskType(MLTaskType.TRAINING)
            .inputType(inputDataType)
            .functionName(mlInput.getFunctionName())
            .state(MLTaskState.CREATED)
            .workerNodes(List.of(clusterService.localNode().getId()))
            .createTime(now)
            .lastUpdateTime(now)
            .async(request.isAsync())
            .build();

        if (request.isAsync()) {
            ActionListener<IndexResponse> createTaskListener = ActionListenerHelper.wrap(r -> {
                String taskId = r.getId();
                mlTask.setTaskId(taskId);
                // Answer the caller right away with the task id; training continues in the background.
                listener.onResponse(new MLTaskResponse((MLOutput) new MLTrainingOutput(null, taskId, mlTask.getState().name())));
                startTrainingTask(mlTask, mlInput, asyncCompletionListener(mlTask, taskId));
            }, e -> {
                log.error("Failed to create ML task", e);
                listener.onFailure(e);
            });
            mlTaskManager.createMLTask(mlTask, createTaskListener);
        } else {
            mlTask.setTaskId(UUID.randomUUID().toString());
            startTrainingTask(mlTask, mlInput, listener);
        }
    }

    /**
     * Builds the listener that records the outcome of a background (async) training run.
     * On success it bumps the per-model request counter, stores the model id on the task and
     * hands it to {@code handleAsyncMLTaskComplete}; on failure it logs the cause and hands the
     * task to {@code handleAsyncMLTaskFailure}.
     */
    private ActionListener<MLTaskResponse> asyncCompletionListener(MLTask mlTask, String taskId) {
        return ActionListenerHelper.wrap(res -> {
            String modelId = ((MLTrainingOutput) res.getOutput()).getModelId();
            mlStats.createModelCounterStatIfAbsent(modelId, ActionName.TRAIN, MLActionLevelStat.ML_ACTION_REQUEST_COUNT).increment();
            log.info("ML model trained successfully, task id: {}, model id: {}", taskId, modelId);
            mlTask.setModelId(modelId);
            handleAsyncMLTaskComplete(mlTask);
        }, ex -> {
            // Pass the exception so the cause and stack trace are not lost.
            log.error("Failed to train ML model for task {}", taskId, ex);
            handleAsyncMLTaskFailure(mlTask, ex);
        });
    }

    /**
     * Updates node/action request stats, marks the task RUNNING, registers it with the task
     * manager and dispatches training onto the training thread pool. SEARCH_QUERY inputs are
     * first resolved into a concrete dataset by {@code mlInputDatasetHandler}.
     *
     * @param listener wrapped via {@code wrappedCleanupListener} before use
     *                 (NOTE(review): presumably performs per-task cleanup — confirm in MLTaskRunner)
     */
    private void startTrainingTask(MLTask mlTask, MLInput mlInput, ActionListener<MLTaskResponse> listener) {
        ActionListener<MLTaskResponse> internalListener = wrappedCleanupListener(listener, mlTask.getTaskId());
        mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment();
        mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment();
        mlStats.createCounterStatIfAbsent(mlTask.getFunctionName(), ActionName.TRAIN, MLActionLevelStat.ML_ACTION_REQUEST_COUNT).increment();
        mlTask.setState(MLTaskState.RUNNING);
        mlTaskManager.add(mlTask);
        try {
            if (mlInput.getInputDataset().getInputDataType().equals(MLInputDataType.SEARCH_QUERY)) {
                ActionListener<MLInputDataset> dataFrameActionListener = ActionListenerHelper.wrap(
                    dataSet -> train(mlTask, mlInput.toBuilder().inputDataset(dataSet).build(), internalListener),
                    e -> {
                        log.error("Failed to generate DataFrame from search query", e);
                        internalListener.onFailure(e);
                    }
                );
                // Continue on the training pool once the search-query dataset has been resolved.
                mlInputDatasetHandler.parseSearchQueryInput(
                    mlInput.getInputDataset(),
                    new ThreadedActionListener<>(log, threadPool, TRAIN_THREAD_POOL, dataFrameActionListener, false)
                );
            } else {
                threadPool.executor(TRAIN_THREAD_POOL).execute(() -> train(mlTask, mlInput, internalListener));
            }
        } catch (Exception e) {
            log.error("Failed to train {}", mlInput.getAlgorithm(), e);
            internalListener.onFailure(e);
        }
    }

    /**
     * Trains the model synchronously via {@link MLEngine#train}, makes sure the model index
     * exists, and then saves the model. Any failure increments the action- and node-level
     * failure counters before being forwarded to {@code actionListener}.
     */
    private void train(MLTask mlTask, MLInput mlInput, ActionListener<MLTaskResponse> actionListener) {
        ActionListener<MLTaskResponse> listener = ActionListenerHelper.wrap(r -> actionListener.onResponse(r), e -> {
            mlStats.createCounterStatIfAbsent(mlTask.getFunctionName(), ActionName.TRAIN, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment();
            mlStats.getStat(MLNodeLevelStat.ML_FAILURE_COUNT).increment();
            actionListener.onFailure(e);
        });
        try {
            mlTaskManager.updateTaskStateAsRunning(mlTask.getTaskId(), mlTask.getTenantId(), mlTask.isAsync());
            MLModel mlModel = mlEngine.train((Input) mlInput);
            ActionListener<Boolean> indexInitListener = ActionListenerHelper.wrap(indexCreated -> {
                // Null-safe: a missing acknowledgement is treated the same as "not created".
                if (!Boolean.TRUE.equals(indexCreated)) {
                    listener.onFailure(new RuntimeException("No response to create ML task index"));
                    return;
                }
                saveModel(mlTask, mlModel, listener);
            }, e -> {
                log.error("Failed to init ML model index", e);
                listener.onFailure(e);
            });
            mlIndicesHandler.initModelIndexIfAbsent(indexInitListener);
        } catch (Exception e) {
            log.error("Failed to train {}", mlInput.getAlgorithm(), e);
            listener.onFailure(e);
        }
    }

    /**
     * Indexes the trained model into the model index with an IMMEDIATE refresh policy and
     * responds with a COMPLETED {@link MLTrainingOutput} carrying the new model id. The task id
     * is included only for async tasks. The caller's thread context is stashed for the index
     * call and restored via {@code runBefore} when the response arrives.
     */
    private void saveModel(MLTask mlTask, MLModel mlModel, ActionListener<MLTaskResponse> listener) {
        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
            ActionListener<IndexResponse> indexResponseListener = ActionListenerHelper.wrap(r -> {
                log.info("Model saved into index, result:{}, model id: {}", r.getResult(), r.getId());
                String returnedTaskId = mlTask.isAsync() ? mlTask.getTaskId() : null;
                MLTrainingOutput output = new MLTrainingOutput(r.getId(), returnedTaskId, MLTaskState.COMPLETED.name());
                listener.onResponse(MLTaskResponse.builder().output((MLOutput) output).build());
            }, e -> listener.onFailure(e));
            IndexRequest indexRequest = new IndexRequest(ML_MODEL_INDEX);
            indexRequest.source(mlModel.toXContent(MediaTypeRegistry.JSON.contentBuilder(), ToXContent.EMPTY_PARAMS));
            indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
            client.index(indexRequest, ActionListenerHelper.runBefore(indexResponseListener, () -> context.restore()));
        } catch (Exception e) {
            log.error("Failed to save ML model", e);
            listener.onFailure(e);
        }
    }
}

