/*
 * Decompiled with CFR 0.152.
 */
package io.lucenia.ml.common.task;

import io.lucenia.ml.common.engine.systemindices.MLIndicesHandler;
import io.lucenia.ml.common.jobs.MLBatchTaskUpdateJobParameter;
import io.skylite.SkyliteExceptionsHelper;
import io.skylite.common.Randomness;
import io.skylite.common.action.ActionListener;
import io.skylite.common.unit.TimeValue;
import io.skylite.core.action.ActionListenerHelper;
import io.skylite.core.action.DocWriteResponse;
import io.skylite.core.action.WriteRequest;
import io.skylite.core.action.admin.indices.refresh.RefreshRequest;
import io.skylite.core.action.index.IndexRequest;
import io.skylite.core.action.index.IndexResponse;
import io.skylite.core.action.update.UpdateRequest;
import io.skylite.core.action.update.UpdateResponse;
import io.skylite.core.client.Client;
import io.skylite.core.client.metadata.MetadataClient;
import io.skylite.core.client.metadata.PutDataObjectRequest;
import io.skylite.core.client.metadata.UpdateDataObjectRequest;
import io.skylite.core.client.metadata.UpdateDataObjectResponse;
import io.skylite.core.common.Strings;
import io.skylite.core.common.concurrent.ThreadContext;
import io.skylite.core.index.IndexNotFoundException;
import io.skylite.core.index.engine.DocumentMissingException;
import io.skylite.core.index.query.BoolQueryBuilder;
import io.skylite.core.index.query.QueryBuilder;
import io.skylite.core.jobs.schedule.Schedule;
import io.skylite.core.rest.RestStatus;
import io.skylite.core.search.SearchRequest;
import io.skylite.core.search.builder.SearchSourceBuilder;
import io.skylite.core.threadpool.ThreadPool;
import io.skylite.core.xcontent.MediaTypeRegistry;
import io.skylite.core.xcontent.ToXContentObject;
import io.skylite.jobs.schedule.IntervalSchedule;
import io.skylite.ml.common.exception.MLException;
import io.skylite.ml.common.exception.MLExceptionUtils;
import io.skylite.ml.common.exception.MLLimitExceededException;
import io.skylite.ml.common.exception.MLResourceNotFoundException;
import io.skylite.ml.common.task.MLTask;
import io.skylite.ml.common.task.MLTaskCache;
import io.skylite.ml.common.task.MLTaskState;
import io.skylite.ml.common.task.MLTaskType;
import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.index.query.QueryBuilders;

/**
 * Keeps an in-memory cache of the ML tasks known to this node and persists task changes to the ML task index.
 *
 * <p>Besides the task cache, the manager keeps a per-{@link MLTaskType} counter of running tasks. The counter is
 * enforced by {@link #checkLimitAndAddRunningTask(MLTask, Integer)} and decremented by {@link #remove(String)}.
 * Index updates issued through {@code updateMLTask} are serialized per task through the cache entry's update
 * semaphore, when the entry has one.
 */
public class MLTaskManager {
    private static final Logger log = LogManager.getLogger(MLTaskManager.class);

    /** Semaphore wait, in milliseconds, used by {@link #updateTaskStateAsRunning(String, String, boolean)}. */
    public static int TASK_SEMAPHORE_TIMEOUT = 5000;

    /** Retry budget for task updates, applied separately to semaphore contention and to retryable index errors. */
    private static final int MAX_UPDATE_RETRY_TIMES = 3;

    /** {@code retryOnConflict} value applied to updates that move a task into one of {@link #TASK_DONE_STATES}. */
    private static final int TERMINAL_STATE_RETRY_ON_CONFLICT = 3;

    /**
     * Upper bound for the search page size in {@link #checkMaxBatchJobTask}.
     * NOTE(review): assumes the task index uses the default {@code index.max_result_window} of 10000; confirm.
     */
    private static final int MAX_RESULT_WINDOW = 10000;

    private static final String ML_TASK_INDEX = ".plugins-ml-task";
    private static final String TASK_POLLING_JOB_INDEX = ".ml_commons_task_polling_job";
    private static final String GENERAL_THREAD_POOL = "lucenia_ml_general";
    private static final String TASK_TYPE_FIELD = "task_type";
    private static final String STATE_FIELD = "state";
    private static final String ERROR_FIELD = "error";
    private static final String LAST_UPDATE_TIME_FIELD = "last_update_time";

    /** In-memory task cache, keyed by task id. */
    private final Map<String, MLTaskCache> taskCaches;
    /** Per-task count of retries caused by retryable index errors; cleared when an update chain finishes. */
    private final Map<String, AtomicInteger> taskUpdateRetryCount = new ConcurrentHashMap<>();
    private final Client client;
    private final MetadataClient sdkClient;
    private final ThreadPool threadPool;
    private final MLIndicesHandler mlIndicesHandler;
    /** Running-task counters per task type, used to enforce concurrency limits. */
    private final Map<MLTaskType, AtomicInteger> runningTasksCount;

    /** Task states treated as terminal ("done"); state updates into these states get conflict retries. */
    public static final Set<MLTaskState> TASK_DONE_STATES = Set.of(MLTaskState.COMPLETED, MLTaskState.COMPLETED_WITH_ERROR, MLTaskState.FAILED, MLTaskState.CANCELLED);

    /**
     * Creates a task manager with empty task and running-count caches.
     *
     * @param client cluster client used for index refresh, search and direct updates
     * @param sdkClient metadata client used for tenant-aware put/update of task documents
     * @param threadPool thread pool used to run and schedule task updates
     * @param mlIndicesHandler handler used to initialize the ML task index
     */
    public MLTaskManager(Client client, MetadataClient sdkClient, ThreadPool threadPool, MLIndicesHandler mlIndicesHandler) {
        this.client = client;
        this.sdkClient = sdkClient;
        this.threadPool = threadPool;
        this.mlIndicesHandler = mlIndicesHandler;
        this.taskCaches = new ConcurrentHashMap<>();
        this.runningTasksCount = new ConcurrentHashMap<>();
    }

    /**
     * Registers {@code mlTask} as running, provided the running-task limit for its type has not been reached.
     *
     * <p>If a task with the same id is already cached, the cached instance is marked {@link MLTaskState#RUNNING}.
     * Otherwise the given task is marked running and added to the cache. In both cases the running counter for the
     * task type is incremented.
     *
     * @param mlTask task to register
     * @param limit maximum number of concurrently running tasks of this type; must not be null
     * @throws MLLimitExceededException if the running count for this type is already at or above {@code limit}
     */
    public synchronized void checkLimitAndAddRunningTask(MLTask mlTask, Integer limit) {
        AtomicInteger runningTaskCount = this.runningTasksCount.computeIfAbsent(mlTask.getTaskType(), type -> new AtomicInteger(0));
        // remove() decrements without a floor, so clamp any drift below zero before comparing.
        if (runningTaskCount.get() < 0) {
            runningTaskCount.set(0);
        }
        log.debug("Task id: {}, current running task {}: {}", mlTask.getTaskId(), mlTask.getTaskType(), runningTaskCount.get());
        if (runningTaskCount.get() >= limit) {
            String error = "exceed max running task limit";
            log.warn("{} for task {}", error, mlTask.getTaskId());
            throw new MLLimitExceededException(error);
        }
        // Single lookup: remove() is not synchronized, so contains()+get() could observe a vanished entry.
        MLTask cachedTask = this.getMLTask(mlTask.getTaskId());
        if (cachedTask != null) {
            cachedTask.setState(MLTaskState.RUNNING);
        } else {
            mlTask.setState(MLTaskState.RUNNING);
            this.add(mlTask);
        }
        runningTaskCount.incrementAndGet();
    }

    /**
     * Checks whether the number of CREATED or RUNNING tasks of {@code mlTaskType} in the ML task index has reached
     * {@code maxTaskLimit}.
     *
     * <p>The task index is refreshed first so that recently written tasks are visible to the search.
     *
     * @param mlTaskType task type to count
     * @param maxTaskLimit limit to compare against
     * @param listener receives {@code true} if the matched count is at or above the limit and {@code false}
     *     otherwise; receives the failure if the refresh or search fails
     */
    public synchronized void checkMaxBatchJobTask(MLTaskType mlTaskType, int maxTaskLimit, ActionListener<Boolean> listener) {
        try {
            BoolQueryBuilder boolQuery = QueryBuilders.boolQuery()
                .must((QueryBuilder) QueryBuilders.termQuery(TASK_TYPE_FIELD, mlTaskType.name()))
                .must((QueryBuilder) QueryBuilders.boolQuery()
                    .should((QueryBuilder) QueryBuilders.termQuery(STATE_FIELD, (Object) MLTaskState.CREATED))
                    .should((QueryBuilder) QueryBuilders.termQuery(STATE_FIELD, (Object) MLTaskState.RUNNING)));
            // Hits are counted from the returned page, so request enough of them to reach the limit.
            // Without an explicit size only the engine's default page is returned (10 in OpenSearch; confirm for
            // this engine), and limits above it could never trip. The size is clamped to the result window.
            SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder()
                .query((QueryBuilder) boolQuery)
                .size(Math.max(0, Math.min(maxTaskLimit, MAX_RESULT_WINDOW)));
            SearchRequest searchRequest = new SearchRequest(new String[] { ML_TASK_INDEX });
            searchRequest.source(searchSourceBuilder);
            try (ThreadContext.StoredContext threadContext = this.client.threadPool().getThreadContext().stashContext()) {
                this.client.admin().indices().refresh(
                    new RefreshRequest(new String[] { ML_TASK_INDEX }),
                    ActionListenerHelper.wrap(
                        refreshResponse -> this.client.search(
                            searchRequest,
                            // Restore the caller's thread context before the caller's listener runs.
                            ActionListenerHelper.runBefore(
                                ActionListenerHelper.wrap(searchResponse -> {
                                    long matchedCount = searchResponse.getHits().getHits().length;
                                    listener.onResponse(matchedCount >= maxTaskLimit);
                                }, listener::onFailure),
                                threadContext::restore
                            )
                        ),
                        e -> {
                            log.error("Failed to refresh Task index during search MLTaskType for {}", mlTaskType, e);
                            threadContext.restore();
                            listener.onFailure(e);
                        }
                    )
                );
            } catch (Exception e) {
                listener.onFailure(e);
            }
        } catch (Exception e) {
            log.error("Failed to search ML task for {}", mlTaskType, e);
            listener.onFailure(e);
        }
    }

    /**
     * Adds {@code mlTask} to the cache without worker nodes.
     *
     * @throws IllegalArgumentException if a task with the same id is already cached
     */
    public synchronized void add(MLTask mlTask) {
        this.add(mlTask, null);
    }

    /**
     * Adds {@code mlTask} to the cache together with the nodes working on it.
     *
     * @param mlTask task to cache
     * @param workerNodes worker node ids, may be null
     * @throws IllegalArgumentException if a task with the same id is already cached
     */
    public synchronized void add(MLTask mlTask, List<String> workerNodes) {
        String taskId = mlTask.getTaskId();
        if (this.contains(taskId)) {
            throw new IllegalArgumentException("Duplicate taskId");
        }
        this.taskCaches.put(taskId, new MLTaskCache(mlTask, workerNodes));
        log.debug("add ML task to cache, taskId: {}, taskType: {} ", taskId, mlTask.getTaskType());
    }

    /** Returns whether a task with {@code taskId} is cached. */
    public boolean contains(String taskId) {
        return this.taskCaches.containsKey(taskId);
    }

    /**
     * Removes the task from the cache. If the removed task was not in {@link MLTaskState#CREATED}, the running
     * counter of its type is also decremented. Does nothing if the task is not cached.
     */
    public void remove(String taskId) {
        // Single atomic remove: contains()+remove() could NPE if another thread removed the entry in between.
        MLTaskCache taskCache = this.taskCaches.remove(taskId);
        if (taskCache == null) {
            return;
        }
        MLTask mlTask = taskCache.getMlTask();
        if (mlTask.getState() != MLTaskState.CREATED) {
            AtomicInteger runningTaskCount = this.runningTasksCount.get(mlTask.getTaskType());
            if (runningTaskCount != null) {
                runningTaskCount.decrementAndGet();
            }
        }
        log.debug("remove ML task from cache {}", taskId);
    }

    /** Returns the cached task for {@code taskId}, or null if it is not cached. */
    public MLTask getMLTask(String taskId) {
        MLTaskCache taskCache = this.taskCaches.get(taskId);
        return taskCache == null ? null : taskCache.getMlTask();
    }

    /** Returns the cache entry for {@code taskId}, or null if it is not cached. */
    public MLTaskCache getMLTaskCache(String taskId) {
        return this.taskCaches.get(taskId);
    }

    /** Returns the worker nodes of the cached task, or null if the task is not cached. */
    public Set<String> getWorkNodes(String taskId) {
        MLTaskCache taskCache = this.taskCaches.get(taskId);
        return taskCache == null ? null : taskCache.getWorkerNodes();
    }

    /** Records {@code error} reported by {@code workerNodeId} on the cached task; ignored if the task is not cached. */
    public void addNodeError(String taskId, String workerNodeId, String error) {
        log.debug("add task error: taskId: {}, workerNodeId: {}, error: {}", taskId, workerNodeId, error);
        MLTaskCache taskCache = this.taskCaches.get(taskId);
        if (taskCache != null) {
            taskCache.addError(workerNodeId, error);
        }
    }

    /** Returns the ids of all cached tasks. */
    public String[] getAllTaskIds() {
        return Strings.toStringArray(this.taskCaches.keySet());
    }

    /** Counts cached tasks whose state is exactly {@link MLTaskState#RUNNING}. */
    public int getRunningTaskCount() {
        int running = 0;
        for (MLTaskCache taskCache : this.taskCaches.values()) {
            if (taskCache.getMlTask().getState() == MLTaskState.RUNNING) {
                running++;
            }
        }
        return running;
    }

    /** Clears the task cache. The per-type running counters are not reset. */
    public void clear() {
        this.taskCaches.clear();
    }

    /**
     * Initializes the ML task index if needed, then writes {@code mlTask} to it through the metadata client.
     *
     * @param mlTask task document to write; its tenant id is used for the request
     * @param listener receives the index response, or a failure if index initialization returns false or the write
     *     fails
     */
    public void createMLTask(MLTask mlTask, ActionListener<IndexResponse> listener) {
        this.mlIndicesHandler.initMLTaskIndex(ActionListenerHelper.wrap(indexCreated -> {
            if (!indexCreated) {
                listener.onFailure(new RuntimeException("No response to create ML task index"));
                return;
            }
            try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
                PutDataObjectRequest request = ((PutDataObjectRequest.Builder) ((PutDataObjectRequest.Builder) PutDataObjectRequest.builder()
                    .index(ML_TASK_INDEX))
                    .tenantId(mlTask.getTenantId()))
                    .dataObject((ToXContentObject) mlTask)
                    .build();
                this.sdkClient.putDataObjectAsync(request).whenComplete((r, throwable) -> {
                    context.restore();
                    if (throwable != null) {
                        Exception cause = unwrapToException(throwable);
                        log.error("Failed to index ML task", cause);
                        listener.onFailure(cause);
                        return;
                    }
                    try {
                        IndexResponse indexResponse = r.indexResponse();
                        log.info("Task creation result: {}, Task id: {}", indexResponse.getResult(), indexResponse.getId());
                        listener.onResponse(indexResponse);
                    } catch (Exception e) {
                        listener.onFailure(e);
                    }
                });
            } catch (Exception e) {
                log.error("Failed to create ML task for {}, {}", mlTask.getFunctionName(), mlTask.getTaskType(), e);
                listener.onFailure(e);
            }
        }, e -> {
            log.error("Failed to create ML task index", e);
            listener.onFailure(e);
        }));
    }

    /**
     * Marks the cached task as RUNNING. For async tasks, the new state is also persisted to the task index.
     *
     * @throws IllegalArgumentException if the task is not cached
     */
    public void updateTaskStateAsRunning(String taskId, String tenantId, boolean isAsyncTask) {
        MLTask task = this.getMLTask(taskId);
        if (task == null) {
            throw new IllegalArgumentException("Task not found");
        }
        task.setState(MLTaskState.RUNNING);
        if (isAsyncTask) {
            this.updateMLTask(taskId, tenantId, Map.of(STATE_FIELD, MLTaskState.RUNNING), TASK_SEMAPHORE_TIMEOUT, false);
        }
    }

    /**
     * Fire-and-forget variant of
     * {@link #updateMLTask(String, String, Map, ActionListener, long, boolean)}: the outcome is only logged.
     */
    public void updateMLTask(String taskId, String tenantId, Map<String, Object> updatedFields, long timeoutInMillis, boolean removeFromCache) {
        ActionListener<UpdateResponse> internalListener = ActionListenerHelper.wrap(response -> {
            if (response.status() == RestStatus.OK) {
                log.debug("Updated ML task successfully: {}, taskId: {}, updatedFields: {}", response.status(), taskId, updatedFields);
            } else {
                log.error("Failed to update ML task {}, status: {}, updatedFields: {}", taskId, response.status(), updatedFields);
            }
        }, e -> MLExceptionUtils.logException("Failed to update ML task: " + taskId, e, log));
        this.updateMLTask(taskId, tenantId, updatedFields, internalListener, timeoutInMillis, removeFromCache);
    }

    /**
     * Updates the task document with {@code updatedFields} plus a fresh {@code last_update_time}.
     *
     * <p>The task must be cached. If its cache entry has an update semaphore, the semaphore is acquired on the
     * calling thread (waiting up to {@code timeoutInMillis}) before the update is dispatched to the general ML thread
     * pool. Contention is retried with backoff, and retryable index errors are retried after re-initializing the
     * task index, each up to {@value #MAX_UPDATE_RETRY_TIMES} times. When index-error retries are exhausted, the task
     * is marked FAILED directly and removed from the cache.
     *
     * @param taskId id of the task to update
     * @param tenantId tenant of the task document
     * @param updatedFields fields to write; a null or empty map fails the listener
     * @param listener receives the update response or failure
     * @param timeoutInMillis how long each attempt waits for the task's update semaphore
     * @param removeFromCache whether to remove the task from the cache before the update is sent
     */
    public void updateMLTask(String taskId, String tenantId, Map<String, Object> updatedFields, ActionListener<UpdateResponse> listener, long timeoutInMillis, boolean removeFromCache) {
        // Capture the cache entry once: retries must keep using it even after removeFromCache evicts it.
        MLTaskCache taskCache = this.taskCaches.get(taskId);
        if (removeFromCache) {
            this.remove(taskId);
        }
        if (taskCache == null) {
            listener.onFailure(new MLResourceNotFoundException("Can't find task in cache: " + taskId));
            return;
        }
        this.updateMLTaskInternal(taskCache, taskId, tenantId, updatedFields, listener, timeoutInMillis, 0);
    }

    /** Runs one update attempt: acquires the task's semaphore (if any), then dispatches the update to the pool. */
    private void updateMLTaskInternal(MLTaskCache taskCache, String taskId, String tenantId, Map<String, Object> updatedFields, ActionListener<UpdateResponse> listener, long timeoutInMillis, int attempt) {
        Semaphore semaphore = taskCache.getUpdateTaskIndexSemaphore();
        if (semaphore != null) {
            try {
                if (!semaphore.tryAcquire(timeoutInMillis, TimeUnit.MILLISECONDS)) {
                    if (attempt < MAX_UPDATE_RETRY_TIMES) {
                        this.scheduleSemaphoreRetry(taskCache, taskId, tenantId, updatedFields, listener, timeoutInMillis, attempt);
                        return;
                    }
                    log.error("MLTask[{}] semaphore contention \u2014 exceeded {} retries", taskId, MAX_UPDATE_RETRY_TIMES);
                    listener.onFailure(new MLException("Other updating request not finished yet"));
                    return;
                }
            } catch (InterruptedException e) {
                // Preserve the interrupt so code further up the stack can still observe it.
                Thread.currentThread().interrupt();
                log.error("Interrupted while acquiring semaphore for ML task {}", taskId);
                listener.onFailure(e);
                return;
            }
        }
        try {
            this.threadPool.executor(GENERAL_THREAD_POOL).execute(
                () -> this.sendUpdate(taskCache, taskId, tenantId, updatedFields, listener, timeoutInMillis, semaphore)
            );
        } catch (Exception e) {
            // The update was never dispatched (e.g. executor rejection), so hand the permit back here.
            if (semaphore != null) {
                semaphore.release();
            }
            log.error("Failed to update ML task {}", taskId, e);
            listener.onFailure(e);
        }
    }

    /** Schedules the next semaphore-acquisition attempt with exponential backoff plus jitter. */
    private void scheduleSemaphoreRetry(MLTaskCache taskCache, String taskId, String tenantId, Map<String, Object> updatedFields, ActionListener<UpdateResponse> listener, long timeoutInMillis, int attempt) {
        AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
            long base = Math.min(1000L, 100L << attempt);
            long jitter = Randomness.get().nextLong(0L, base / 2L);
            long delay = base + jitter;
            log.warn("MLTask[{}] semaphore busy \u2014 retry {}/{} ({}ms)", taskId, attempt + 1, MAX_UPDATE_RETRY_TIMES, delay);
            Runnable retryTask = () -> this.updateMLTaskInternal(taskCache, taskId, tenantId, updatedFields, listener, timeoutInMillis, attempt + 1);
            this.threadPool.schedule(retryTask, TimeValue.timeValueMillis(delay), GENERAL_THREAD_POOL);
            return null;
        });
    }

    /** Builds and sends the update request while holding {@code semaphore} (if any); runs on the general pool. */
    private void sendUpdate(MLTaskCache taskCache, String taskId, String tenantId, Map<String, Object> updatedFields, ActionListener<UpdateResponse> listener, long timeoutInMillis, Semaphore semaphore) {
        // Every terminal outcome of this attempt goes through actionListener, so the permit is released exactly once.
        ActionListener<UpdateResponse> actionListener = semaphore == null ? listener : ActionListenerHelper.runAfter(listener, semaphore::release);
        if (updatedFields == null || updatedFields.isEmpty()) {
            actionListener.onFailure(new IllegalArgumentException("Updated fields is null or empty"));
            return;
        }
        try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
            UpdateDataObjectRequest updateDataObjectRequest = this.buildUpdateRequest(taskId, tenantId, updatedFields);
            this.sdkClient.updateDataObjectAsync(updateDataObjectRequest).whenComplete((r, throwable) -> {
                context.restore();
                if (throwable == null) {
                    this.taskUpdateRetryCount.remove(taskId);
                    this.handleUpdateDataObjectCompletionStage(r, throwable, this.getUpdateResponseListener(taskId, actionListener));
                    return;
                }
                Exception cause = unwrapToException(throwable);
                if (this.isRetryableIndexError(cause)) {
                    this.retryAfterIndexError(taskCache, taskId, tenantId, updatedFields, listener, actionListener, timeoutInMillis, semaphore, cause);
                    return;
                }
                this.taskUpdateRetryCount.remove(taskId);
                actionListener.onFailure(cause);
            });
        } catch (Exception e) {
            log.error("Failed to update ML task {}", taskId, e);
            actionListener.onFailure(e);
        }
    }

    /**
     * Handles a retryable index error from one update attempt. Up to {@value #MAX_UPDATE_RETRY_TIMES} times, it
     * re-initializes the task index and retries the update after a linear backoff. After that, it marks the task
     * FAILED directly, removes it from the cache and fails {@code actionListener} with {@code cause}.
     */
    private void retryAfterIndexError(MLTaskCache taskCache, String taskId, String tenantId, Map<String, Object> updatedFields, ActionListener<UpdateResponse> listener, ActionListener<UpdateResponse> actionListener, long timeoutInMillis, Semaphore semaphore, Exception cause) {
        AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
            int retry = this.taskUpdateRetryCount.computeIfAbsent(taskId, k -> new AtomicInteger(0)).incrementAndGet();
            if (retry <= MAX_UPDATE_RETRY_TIMES) {
                long delay = retry * 200L;
                log.warn("MLTask[{}] update retry {} ({}ms) \u2014 likely during cluster boot / index init race", taskId, retry, delay);
                // Hand back this attempt's permit now. The retry acquires it again; if it were still held, the
                // retry would wait on itself until it timed out, and the permit would be lost for good.
                if (semaphore != null) {
                    semaphore.release();
                }
                // From here on the permit is no longer held, so complete the raw listener rather than actionListener.
                Runnable retryTask = () -> this.mlIndicesHandler.initMLTaskIndex(ActionListenerHelper.wrap(
                    ok -> this.updateMLTaskInternal(taskCache, taskId, tenantId, updatedFields, listener, timeoutInMillis, 0),
                    e -> {
                        this.taskUpdateRetryCount.remove(taskId);
                        listener.onFailure(e);
                    }
                ));
                this.threadPool.schedule(retryTask, TimeValue.timeValueMillis(delay), GENERAL_THREAD_POOL);
                return null;
            }
            log.error("Exceeded retry attempts for MLTask {} \u2014 marking FAILED", taskId);
            this.taskUpdateRetryCount.remove(taskId);
            this.updateMLTaskDirectly(
                taskId,
                Map.of(STATE_FIELD, MLTaskState.FAILED, ERROR_FIELD, "Failed to update task after 3 retries"),
                ActionListenerHelper.wrap(resp -> {
                    this.remove(taskId);
                    actionListener.onFailure(cause);
                }, e -> {
                    this.remove(taskId);
                    actionListener.onFailure(cause);
                })
            );
            return null;
        });
    }

    /**
     * Builds the metadata-client update request. It stamps {@code last_update_time} and enables conflict retries
     * when the update moves the task into a terminal state.
     */
    private UpdateDataObjectRequest buildUpdateRequest(String taskId, String tenantId, Map<String, Object> updatedFields) {
        Map<String, Object> updatedContent = new HashMap<>(updatedFields);
        updatedContent.put(LAST_UPDATE_TIME_FIELD, Instant.now().toEpochMilli());
        UpdateDataObjectRequest.Builder requestBuilder = ((UpdateDataObjectRequest.Builder) ((UpdateDataObjectRequest.Builder) ((UpdateDataObjectRequest.Builder) UpdateDataObjectRequest.builder()
            .index(ML_TASK_INDEX))
            .id(taskId))
            .tenantId(tenantId))
            .dataObject(updatedContent);
        if (isTerminalStateUpdate(updatedFields)) {
            requestBuilder.retryOnConflict(TERMINAL_STATE_RETRY_ON_CONFLICT);
        }
        return requestBuilder.build();
    }

    /**
     * Returns whether {@code updatedFields} sets the task state to one of {@link #TASK_DONE_STATES}.
     * A null state value counts as non-terminal: a {@code Set.of} set throws NPE on {@code contains(null)}.
     */
    private static boolean isTerminalStateUpdate(Map<String, Object> updatedFields) {
        Object state = updatedFields.get(STATE_FIELD);
        return state != null && TASK_DONE_STATES.contains(state);
    }

    /** Converts a failure reported by an async client call into an {@link Exception}. */
    private static Exception unwrapToException(Throwable throwable) {
        return SkyliteExceptionsHelper.unwrapAndConvertToException(throwable, new Class[0]);
    }

    /** Updates the task document directly, bypassing the cache and semaphore; the outcome is only logged. */
    public void updateMLTaskDirectly(String taskId, Map<String, Object> updatedFields) {
        this.updateMLTaskDirectly(taskId, updatedFields, ActionListenerHelper.wrap(
            r -> log.debug("updated ML task directly: {}", taskId),
            e -> log.error("Failed to update ML task {}", taskId, e)
        ));
    }

    /**
     * Updates the task document directly through the cluster client with an IMMEDIATE refresh, bypassing the cache
     * and the per-task semaphore.
     *
     * @param taskId id of the task document
     * @param updatedFields fields to write (plus {@code last_update_time}); a null or empty map fails the listener
     * @param listener receives the update response or failure
     */
    public void updateMLTaskDirectly(String taskId, Map<String, Object> updatedFields, ActionListener<UpdateResponse> listener) {
        try {
            if (updatedFields == null || updatedFields.isEmpty()) {
                listener.onFailure(new IllegalArgumentException("Updated fields is null or empty"));
                return;
            }
            UpdateRequest updateRequest = new UpdateRequest(ML_TASK_INDEX, taskId);
            Map<String, Object> updatedContent = new HashMap<>(updatedFields);
            updatedContent.put(LAST_UPDATE_TIME_FIELD, Instant.now().toEpochMilli());
            updateRequest.doc(updatedContent);
            updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
            if (isTerminalStateUpdate(updatedFields)) {
                updateRequest.retryOnConflict(TERMINAL_STATE_RETRY_ON_CONFLICT);
            }
            try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
                this.client.update(updateRequest, ActionListenerHelper.runBefore(listener, context::restore));
            } catch (Exception e) {
                listener.onFailure(e);
            }
        } catch (Exception e) {
            log.error("Failed to update ML task {}", taskId, e);
            listener.onFailure(e);
        }
    }

    /** Returns whether any cached task refers to {@code modelId}. */
    public boolean containsModel(String modelId) {
        for (MLTaskCache taskCache : this.taskCaches.values()) {
            if (modelId.equals(taskCache.getMlTask().getModelId())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the cached DEPLOY_MODEL tasks whose state is not CREATED, as a two-element list. Element 0 holds the
     * task ids and element 1 holds the model ids, index-aligned.
     */
    public List<String[]> getLocalRunningDeployModelTasks() {
        List<String> runningDeployModelTaskIds = new ArrayList<>();
        List<String> runningDeployModelIds = new ArrayList<>();
        for (Map.Entry<String, MLTaskCache> entry : this.taskCaches.entrySet()) {
            MLTask mlTask = entry.getValue().getMlTask();
            if (mlTask.getTaskType() == MLTaskType.DEPLOY_MODEL && mlTask.getState() != MLTaskState.CREATED) {
                runningDeployModelTaskIds.add(entry.getKey());
                runningDeployModelIds.add(mlTask.getModelId());
            }
        }
        return Arrays.asList(runningDeployModelTaskIds.toArray(new String[0]), runningDeployModelIds.toArray(new String[0]));
    }

    /** Completes {@code updateListener} from the result or failure of an async metadata update. */
    private void handleUpdateDataObjectCompletionStage(UpdateDataObjectResponse r, Throwable throwable, ActionListener<UpdateResponse> updateListener) {
        if (throwable != null) {
            updateListener.onFailure(unwrapToException(throwable));
            return;
        }
        try {
            updateListener.onResponse(r.updateResponse());
        } catch (Exception e) {
            updateListener.onFailure(e);
        }
    }

    /**
     * Wraps {@code actionListener} with logging. A response whose result is not UPDATED is logged as an error but
     * is still passed through as a response.
     */
    private ActionListener<UpdateResponse> getUpdateResponseListener(String taskId, ActionListener<UpdateResponse> actionListener) {
        return ActionListenerHelper.wrap(updateResponse -> {
            if (updateResponse != null && updateResponse.getResult() != DocWriteResponse.Result.UPDATED) {
                log.error("Failed to update the task with ID: {}", taskId);
                actionListener.onResponse(updateResponse);
                return;
            }
            log.info("Successfully updated the task with ID: {}", taskId);
            actionListener.onResponse(updateResponse);
        }, exception -> {
            log.error("Failed to update ML task with ID {}. Details: {}", taskId, exception);
            actionListener.onFailure(exception);
        });
    }

    /**
     * Indexes the batch-task polling job document. The job runs on a 1-minute interval schedule with a 20-second
     * lock duration, and the write uses an IMMEDIATE refresh. The outcome is only logged.
     *
     * @throws IOException if the job parameter cannot be serialized
     */
    public void startTaskPollingJob() throws IOException {
        String id = "ml_batch_task_polling_job";
        String jobName = "poll_batch_jobs";
        int intervalMinutes = 1;
        Long lockDurationSeconds = 20L;
        MLBatchTaskUpdateJobParameter jobParameter = new MLBatchTaskUpdateJobParameter(jobName, (Schedule) new IntervalSchedule(Instant.now(), intervalMinutes, ChronoUnit.MINUTES), lockDurationSeconds, null);
        IndexRequest indexRequest = (IndexRequest) ((IndexRequest) new IndexRequest().index(TASK_POLLING_JOB_INDEX))
            .id(id)
            .source(jobParameter.toXContent(MediaTypeRegistry.JSON.contentBuilder(), null))
            .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
        this.client.index(indexRequest, ActionListenerHelper.wrap(
            r -> log.info("Indexed ml task polling job successfully"),
            e -> log.error("Failed to index task polling job", e)
        ));
    }

    /**
     * Returns whether {@code t} or any exception in its cause chain is an index-not-found or document-missing error.
     * The check matches by exception type, or by a known marker in the lower-cased message.
     */
    private boolean isRetryableIndexError(Throwable t) {
        for (Throwable current = t; current != null; current = current.getCause()) {
            if (current instanceof IndexNotFoundException || current instanceof DocumentMissingException) {
                return true;
            }
            String msg = current.getMessage();
            if (msg == null) {
                continue;
            }
            String lower = msg.toLowerCase(Locale.ROOT);
            if (lower.contains("index_not_found_exception") || lower.contains("index not found") || lower.contains("no such index") || lower.contains("document_missing_exception")) {
                return true;
            }
        }
        return false;
    }
}

