/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.plugin.transport;

import io.skylite.Version;
import io.skylite.common.action.ActionListener;
import io.skylite.core.action.ActionFilters;
import io.skylite.core.action.ActionListenerHelper;
import io.skylite.core.action.index.IndexResponse;
import io.skylite.core.action.support.HandledTransportAction;
import io.skylite.core.cluster.service.ClusterService;
import io.skylite.core.common.inject.Inject;
import io.skylite.core.tasks.Task;
import io.skylite.core.transport.TransportService;
import java.io.IOException;
import java.util.concurrent.ExecutionException;
import org.opensearch.knn.index.engine.KNNLibraryIndexingContext;
import org.opensearch.knn.index.engine.KNNMethodConfigContext;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
import org.opensearch.knn.index.memory.NativeMemoryEntryContext;
import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy;
import org.opensearch.knn.plugin.stats.KNNCounter;
import org.opensearch.knn.plugin.transport.TrainingModelRequest;
import org.opensearch.knn.plugin.transport.TrainingModelResponse;
import org.opensearch.knn.training.TrainingJob;
import org.opensearch.knn.training.TrainingJobRunner;

public class TrainingModelTransportAction
extends HandledTransportAction<TrainingModelRequest, TrainingModelResponse> {
    private final ClusterService clusterService;

    @Inject
    public TrainingModelTransportAction(TransportService transportService, ActionFilters actionFilters, ClusterService clusterService) {
        super("cluster:admin/knn_training_model_action", transportService, actionFilters, TrainingModelRequest::new);
        this.clusterService = clusterService;
    }

    protected void doExecute(Task task, TrainingModelRequest request, ActionListener<TrainingModelResponse> listener) {
        KNNMethodContext knnMethodContext = request.getKnnMethodContext();
        KNNMethodConfigContext knnMethodConfigContext = request.getKnnMethodConfigContext();
        QuantizationConfig quantizationConfig = QuantizationConfig.EMPTY;
        if (knnMethodContext != null && request.getKnnMethodConfigContext() != null) {
            KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine().getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext);
            quantizationConfig = knnLibraryIndexingContext.getQuantizationConfig();
        }
        NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext = new NativeMemoryEntryContext.TrainingDataEntryContext(request.getTrainingDataSizeInKB(), request.getTrainingIndex(), request.getTrainingField(), NativeMemoryLoadStrategy.TrainingLoadStrategy.getInstance(), this.clusterService, request.getMaximumVectorCount(), request.getSearchSize(), request.getVectorDataType(), quantizationConfig);
        NativeMemoryEntryContext.AnonymousEntryContext modelAnonymousEntryContext = new NativeMemoryEntryContext.AnonymousEntryContext(request.getKnnMethodContext().estimateOverheadInKB(KNNMethodConfigContext.builder().dimension(request.getDimension()).vectorDataType(request.getVectorDataType()).versionCreated(Version.CURRENT).build()), NativeMemoryLoadStrategy.AnonymousLoadStrategy.getInstance());
        TrainingJob trainingJob = new TrainingJob(request.getModelId(), request.getKnnMethodContext(), NativeMemoryCacheManager.getInstance(), trainingDataEntryContext, modelAnonymousEntryContext, request.getKnnMethodConfigContext(), request.getDescription(), this.clusterService.localNode().getEphemeralId(), request.getMode(), request.getCompressionLevel());
        KNNCounter.TRAINING_REQUESTS.increment();
        ActionListener wrappedListener = ActionListenerHelper.wrap(arg_0 -> listener.onResponse(arg_0), ex -> {
            KNNCounter.TRAINING_ERRORS.increment();
            listener.onFailure(ex);
        });
        try {
            TrainingJobRunner.getInstance().execute(trainingJob, (ActionListener<IndexResponse>)ActionListenerHelper.wrap(indexResponse -> wrappedListener.onResponse((Object)new TrainingModelResponse(indexResponse.getId())), arg_0 -> ((ActionListener)wrappedListener).onFailure(arg_0)));
        }
        catch (IOException | InterruptedException | ExecutionException e) {
            wrappedListener.onFailure(e);
        }
    }
}

