/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.index.memory;

import io.skylite.common.action.ActionListener;
import io.skylite.core.action.ActionListenerHelper;
import io.skylite.core.action.search.SearchResponse;
import java.io.Closeable;
import java.io.IOException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.opensearch.knn.index.codec.util.NativeMemoryCacheKeyHelper;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.memory.NativeMemoryAllocation;
import org.opensearch.knn.index.memory.NativeMemoryEntryContext;
import org.opensearch.knn.index.memory.SharedIndexState;
import org.opensearch.knn.index.memory.SharedIndexStateManager;
import org.opensearch.knn.index.store.IndexInputWithBuffer;
import org.opensearch.knn.index.util.IndexUtil;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.training.TrainingDataConsumer;
import org.opensearch.knn.training.VectorReader;

/**
 * Strategy for materializing a {@link NativeMemoryAllocation} from the {@link NativeMemoryEntryContext}
 * that describes it.
 *
 * @param <T> the allocation type produced by this strategy
 * @param <U> the entry-context type that describes what to load
 */
public interface NativeMemoryLoadStrategy<T extends NativeMemoryAllocation, U extends NativeMemoryEntryContext<T>> {

    /**
     * Loads the allocation described by the given entry context.
     *
     * @param nativeMemoryEntryContext context describing the allocation to load
     * @return the loaded allocation
     * @throws IOException if reading the backing data fails
     */
    T load(U nativeMemoryEntryContext) throws IOException;

    /**
     * Load strategy for {@link NativeMemoryAllocation.AnonymousAllocation}s.
     *
     * <p>Lazily created shared instance; it owns a single-thread executor that is handed to every
     * allocation it creates.
     */
    public static class AnonymousLoadStrategy
        implements
            NativeMemoryLoadStrategy<NativeMemoryAllocation.AnonymousAllocation, NativeMemoryEntryContext.AnonymousEntryContext>,
            Closeable {

        // Only read/written inside the synchronized getInstance().
        private static AnonymousLoadStrategy INSTANCE;

        private final ExecutorService executor = Executors.newSingleThreadExecutor();

        /**
         * Returns the shared instance, creating it on first use.
         *
         * @return the shared {@code AnonymousLoadStrategy}
         */
        public static synchronized AnonymousLoadStrategy getInstance() {
            if (INSTANCE == null) {
                INSTANCE = new AnonymousLoadStrategy();
            }
            return INSTANCE;
        }

        private AnonymousLoadStrategy() {}

        /**
         * Creates an anonymous allocation sized from the context.
         *
         * @param nativeMemoryEntryContext context providing the size in KB
         * @return a new allocation bound to this strategy's executor
         */
        @Override
        public NativeMemoryAllocation.AnonymousAllocation load(NativeMemoryEntryContext.AnonymousEntryContext nativeMemoryEntryContext) {
            return new NativeMemoryAllocation.AnonymousAllocation(this.executor, nativeMemoryEntryContext.calculateSizeInKB());
        }

        /**
         * Initiates an orderly shutdown of the executor (does not wait for queued tasks).
         *
         * <p>NOTE(review): {@link #getInstance()} keeps returning this instance after close(), so
         * allocations created afterwards would receive a shut-down executor — confirm callers never
         * reuse the strategy once closed.
         */
        @Override
        public void close() {
            this.executor.shutdown();
        }
    }

    /**
     * Load strategy for {@link NativeMemoryAllocation.TrainingDataAllocation}s.
     *
     * <p>Lazily created shared instance. {@link #initialize(VectorReader)} must be called before
     * {@link #load} so the strategy has a {@link VectorReader} to pull training vectors with.
     */
    public static class TrainingLoadStrategy
        implements
            NativeMemoryLoadStrategy<NativeMemoryAllocation.TrainingDataAllocation, NativeMemoryEntryContext.TrainingDataEntryContext>,
            Closeable {

        private static volatile TrainingLoadStrategy INSTANCE;

        private final ExecutorService executor = Executors.newSingleThreadExecutor();

        // volatile: initialize() writes this outside of any lock, while load() may run on a different
        // thread; without it there is no guarantee the reader set by initialize() is visible to load().
        private volatile VectorReader vectorReader;

        /**
         * Returns the shared instance, creating it on first use.
         *
         * @return the shared {@code TrainingLoadStrategy}
         */
        public static synchronized TrainingLoadStrategy getInstance() {
            if (INSTANCE == null) {
                INSTANCE = new TrainingLoadStrategy();
            }
            return INSTANCE;
        }

        /**
         * Installs the {@link VectorReader} used by {@link #load} on the shared instance.
         *
         * @param vectorReader reader used to fetch training vectors
         */
        public static void initialize(VectorReader vectorReader) {
            TrainingLoadStrategy.getInstance().vectorReader = vectorReader;
        }

        private TrainingLoadStrategy() {}

        /**
         * Creates an empty training-data allocation and starts reading training vectors into it.
         *
         * <p>The allocation's write lock is taken before the read starts and is released by the
         * read's listener: {@code writeUnlock()} on success, {@code closeUnsafe()} on failure. If the
         * read completes asynchronously, the returned allocation is still write-locked when this
         * method returns.
         *
         * @param nativeMemoryEntryContext context describing the training index, field, sizes and data type
         * @return the allocation being filled with training data
         * @throws IllegalStateException if {@link #initialize(VectorReader)} has not been called
         */
        @Override
        public NativeMemoryAllocation.TrainingDataAllocation load(
            NativeMemoryEntryContext.TrainingDataEntryContext nativeMemoryEntryContext
        ) {
            // Read the volatile field once and fail fast before creating the allocation or taking its
            // write lock: a missing reader would otherwise throw while the lock is held, and nothing
            // on that path would release it.
            final VectorReader reader = this.vectorReader;
            if (reader == null) {
                throw new IllegalStateException(
                    "TrainingLoadStrategy has not been initialized: call TrainingLoadStrategy.initialize(VectorReader) first"
                );
            }

            // Start from an empty allocation (address 0) sized from the context.
            final NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation = new NativeMemoryAllocation.TrainingDataAllocation(
                this.executor,
                0L,
                nativeMemoryEntryContext.calculateSizeInKB(),
                nativeMemoryEntryContext.getVectorDataType()
            );
            final QuantizationConfig quantizationConfig = nativeMemoryEntryContext.getQuantizationConfig();
            trainingDataAllocation.setQuantizationConfig(quantizationConfig);
            final TrainingDataConsumer vectorDataConsumer = nativeMemoryEntryContext.getVectorDataType()
                .getTrainingDataConsumer(trainingDataAllocation);

            // Held until the listener below runs.
            trainingDataAllocation.writeLock();
            reader.read(
                nativeMemoryEntryContext.getClusterService(),
                nativeMemoryEntryContext.getTrainIndexName(),
                nativeMemoryEntryContext.getTrainFieldName(),
                nativeMemoryEntryContext.getMaxVectorCount(),
                nativeMemoryEntryContext.getSearchSize(),
                vectorDataConsumer,
                (ActionListener<SearchResponse>) ActionListenerHelper.wrap(response -> trainingDataAllocation.writeUnlock(), ex -> {
                    // NOTE(review): closeUnsafe() is presumably responsible for releasing the write
                    // lock taken above (there is no writeUnlock() on this path) — confirm.
                    trainingDataAllocation.closeUnsafe();
                    throw new RuntimeException(ex);
                })
            );
            return trainingDataAllocation;
        }

        /**
         * Initiates an orderly shutdown of the executor (does not wait for queued tasks).
         *
         * <p>NOTE(review): the shared instance is not cleared here, so later {@link #getInstance()}
         * calls return a strategy whose executor is already shut down.
         */
        @Override
        public void close() throws IOException {
            this.executor.shutdown();
        }
    }

    /**
     * Load strategy for {@link NativeMemoryAllocation.IndexAllocation}s: reads a vector index file
     * from a Lucene {@link Directory} and loads it through {@link JNIService}.
     *
     * <p>Lazily created shared instance that owns a single-thread executor handed to every
     * allocation it creates.
     */
    public static class IndexLoadStrategy
        implements
            NativeMemoryLoadStrategy<NativeMemoryAllocation.IndexAllocation, NativeMemoryEntryContext.IndexEntryContext>,
            Closeable {

        private static final Logger log = LogManager.getLogger(IndexLoadStrategy.class);

        // Only read/written inside the synchronized getInstance().
        private static IndexLoadStrategy INSTANCE;

        private final ExecutorService executor = Executors.newSingleThreadExecutor();

        /**
         * Returns the shared instance, creating it on first use.
         *
         * @return the shared {@code IndexLoadStrategy}
         */
        public static synchronized IndexLoadStrategy getInstance() {
            if (INSTANCE == null) {
                INSTANCE = new IndexLoadStrategy();
            }
            return INSTANCE;
        }

        private IndexLoadStrategy() {}

        /**
         * Loads the vector index file named by the context's cache key into native memory.
         *
         * @param indexEntryContext context providing the cache key, directory, parameters and model id
         * @return the allocation wrapping the loaded native index
         * @throws IllegalStateException if the cache key does not contain a vector file name
         * @throws ArithmeticException if the file size in KB does not fit in an {@code int}
         * @throws IOException if the file length cannot be read or the file cannot be opened
         */
        @Override
        public NativeMemoryAllocation.IndexAllocation load(NativeMemoryEntryContext.IndexEntryContext indexEntryContext) throws IOException {
            final String cacheKey = indexEntryContext.getKey();
            final String vectorFileName = NativeMemoryCacheKeyHelper.extractVectorIndexFileName(cacheKey);
            if (vectorFileName == null) {
                throw new IllegalStateException(
                    "Invalid cache key was given. The key [" + cacheKey + "] does not contain the corresponding vector file name."
                );
            }

            final KNNEngine knnEngine = KNNEngine.getEngineNameFromPath(vectorFileName);
            final Directory directory = indexEntryContext.getDirectory();
            // Integer division: files smaller than 1 KB are accounted as 0 KB.
            final int indexSizeKb = Math.toIntExact(directory.fileLength(vectorFileName) / 1024L);

            // The input is only needed while the native engine reads it, so it is closed right after loading.
            try (IndexInput readStream = directory.openInput(vectorFileName, IOContext.READONCE)) {
                final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(readStream);
                final long indexAddress = JNIService.loadIndex(indexInputWithBuffer, indexEntryContext.getParameters(), knnEngine);
                // NOTE(review): if createIndexAllocation throws, nothing here frees the native index at
                // indexAddress — consider releasing it on that failure path.
                return this.createIndexAllocation(indexEntryContext, knnEngine, indexAddress, indexSizeKb, vectorFileName);
            }
        }

        /**
         * Wraps a loaded native index in an {@link NativeMemoryAllocation.IndexAllocation}, attaching
         * shared index state first when {@link IndexUtil#isSharedIndexStateRequired} says it is needed.
         */
        private NativeMemoryAllocation.IndexAllocation createIndexAllocation(
            NativeMemoryEntryContext.IndexEntryContext indexEntryContext,
            KNNEngine knnEngine,
            long indexAddress,
            int indexSizeKb,
            String vectorFileName
        ) {
            SharedIndexState sharedIndexState = null;
            final String modelId = indexEntryContext.getModelId();
            if (IndexUtil.isSharedIndexStateRequired(knnEngine, modelId, indexAddress)) {
                log.info("Index with model: \"{}\" requires shared state. Retrieving shared state.", modelId);
                sharedIndexState = SharedIndexStateManager.getInstance().get(indexAddress, modelId, knnEngine);
                JNIService.setSharedIndexState(indexAddress, sharedIndexState.getSharedIndexStateAddress(), knnEngine);
            }
            return new NativeMemoryAllocation.IndexAllocation(
                this.executor,
                indexAddress,
                indexSizeKb,
                knnEngine,
                vectorFileName,
                indexEntryContext.getOpenSearchIndexName(),
                sharedIndexState,
                IndexUtil.isBinaryIndex(knnEngine, indexEntryContext.getParameters())
            );
        }

        /**
         * Initiates an orderly shutdown of the executor (does not wait for queued tasks).
         *
         * <p>NOTE(review): the shared instance is not cleared here, so later {@link #getInstance()}
         * calls return a strategy whose executor is already shut down.
         */
        @Override
        public void close() {
            this.executor.shutdown();
        }
    }
}

