/*
 * Decompiled with CFR 0.152.
 */
package io.lucenia.ml.common.search.pipelines.generative;

import com.fasterxml.jackson.databind.node.ArrayNode;
import io.lucenia.ml.common.search.pipelines.generative.client.ConversationalMemoryClient;
import io.lucenia.ml.common.search.pipelines.generative.llm.ModelLocator;
import io.skylite.common.action.ActionListener;
import io.skylite.common.processor.AbstractProcessor;
import io.skylite.core.action.ActionListenerHelper;
import io.skylite.core.action.search.SearchResponse;
import io.skylite.core.client.Client;
import io.skylite.core.common.Strings;
import io.skylite.core.ingest.ConfigurationUtils;
import io.skylite.core.search.SearchHit;
import io.skylite.core.search.SearchRequest;
import io.skylite.core.search.pipeline.SearchProcessor;
import io.skylite.core.search.pipeline.SearchResponseProcessor;
import io.skylite.ml.common.conversation.Interaction;
import io.skylite.ml.common.exception.MLException;
import io.skylite.ml.common.search.pipelines.generative.GenerativeQAProcessorConstants;
import io.skylite.ml.common.search.pipelines.generative.GenerativeSearchResponse;
import io.skylite.ml.common.search.pipelines.generative.ext.GenerativeQAParamUtil;
import io.skylite.ml.common.search.pipelines.generative.ext.GenerativeQAParameters;
import io.skylite.ml.common.search.pipelines.generative.llm.ChatCompletionInput;
import io.skylite.ml.common.search.pipelines.generative.llm.ChatCompletionOutput;
import io.skylite.ml.common.search.pipelines.generative.llm.Llm;
import io.skylite.ml.common.search.pipelines.generative.llm.LlmIOUtil;
import io.skylite.ml.common.search.pipelines.generative.prompt.PromptUtil;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.BooleanSupplier;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Search response processor implementing retrieval augmented generation (RAG).
 *
 * <p>For each search response it extracts the configured context fields from the top hits. It then
 * sends them, together with the user's question and (when a conversation id is supplied) prior
 * conversation interactions, to an {@link Llm}. It returns a {@link GenerativeSearchResponse} that
 * carries the LLM's answer or error. When a conversation id is supplied, the question and answer
 * are also recorded as a new interaction through the {@link ConversationalMemoryClient}.
 */
public class GenerativeQAResponseProcessor
extends AbstractProcessor
implements SearchResponseProcessor {
    private static final Logger log = LogManager.getLogger(GenerativeQAResponseProcessor.class);
    public static final String IllegalArgumentMessage = "Please check the provided generative_qa_parameters are complete and non-null(https://opensearch.org/docs/latest/search-plugins/conversational-search/#rag-pipeline). Messages in the memory can not have Null value for input and response";
    /** Number of past interactions fetched when the request gives no (or -1) interaction size. */
    private static final int DEFAULT_CHAT_HISTORY_WINDOW = 10;
    /** LLM timeout in seconds used when the request gives no (or -1) timeout. */
    private static final int DEFAULT_PROCESSOR_TIME_IN_SECONDS = 30;
    private final String llmModel;
    private final List<String> contextFields;
    private final String systemPrompt;
    private final String userInstructions;
    private ConversationalMemoryClient memoryClient;
    private Llm llm;
    private final BooleanSupplier featureFlagSupplier;

    /**
     * Creates the processor.
     *
     * @param client client used to build the {@link ConversationalMemoryClient}
     * @param tag processor tag
     * @param description processor description
     * @param ignoreFailure whether pipeline failures from this processor are ignored
     * @param llm LLM used for chat completion
     * @param llmModel pipeline-level model name; may be overridden per request
     * @param contextFields source fields read from each hit to build the LLM context
     * @param systemPrompt pipeline-level system prompt; may be overridden per request
     * @param userInstructions pipeline-level user instructions; may be overridden per request
     * @param supplier feature flag; when it yields {@code false} processing is rejected
     */
    protected GenerativeQAResponseProcessor(Client client, String tag, String description, boolean ignoreFailure, Llm llm, String llmModel, List<String> contextFields, String systemPrompt, String userInstructions, BooleanSupplier supplier) {
        super(tag, description, ignoreFailure);
        this.llmModel = llmModel;
        this.contextFields = contextFields;
        this.systemPrompt = systemPrompt;
        this.userInstructions = userInstructions;
        this.llm = llm;
        this.memoryClient = new ConversationalMemoryClient(client);
        this.featureFlagSupplier = supplier;
    }

    /**
     * Synchronous processing is not supported by this processor.
     *
     * @throws UnsupportedOperationException always; use {@code processResponseAsync}
     */
    public SearchResponse processResponse(SearchRequest searchRequest, SearchResponse searchResponse) {
        throw new UnsupportedOperationException();
    }

    /**
     * Runs the RAG flow for {@code response} and completes {@code responseListener} with a
     * {@link GenerativeSearchResponse}.
     *
     * <p>Request-level {@code system_prompt} and {@code user_instructions} override the
     * pipeline-level values. The same effective values are used both for the LLM input and for the
     * prompt template stored in conversational memory.
     *
     * @throws MLException if the feature flag is disabled
     * @throws IllegalArgumentException if {@code generative_qa_parameters} is missing, no LLM model
     *     can be resolved, or the conversation id is present but blank
     */
    public void processResponseAsync(SearchRequest request, SearchResponse response, ActionListener<SearchResponse> responseListener) {
        log.debug("Entering processResponse.");
        if (!this.featureFlagSupplier.getAsBoolean()) {
            throw new MLException(GenerativeQAProcessorConstants.FEATURE_NOT_ENABLED_ERROR_MSG);
        }
        final GenerativeQAParameters params = GenerativeQAParamUtil.getGenerativeQAParameters(request);
        if (params == null) {
            throw new IllegalArgumentException("generative_qa_parameters not found. Please provide ext.generative_qa_parameters to proceed. For more info, refer: https://opensearch.org/docs/latest/search-plugins/conversational-search/#step-6-use-the-pipeline-for-rag");
        }
        final int timeout = valueOrDefault(params.getTimeout(), DEFAULT_PROCESSOR_TIME_IN_SECONDS);
        log.debug("Timeout for this request: {} seconds.", timeout);
        final String llmQuestion = params.getLlmQuestion();
        final String llmModel = params.getLlmModel() == null ? this.llmModel : params.getLlmModel();
        if (llmModel == null) {
            throw new IllegalArgumentException("llm_model cannot be null.");
        }
        final String conversationId = params.getConversationId();
        if (conversationId != null && !Strings.hasText(conversationId)) {
            throw new IllegalArgumentException("Empty conversation_id is not allowed.");
        }
        final int interactionSize = valueOrDefault(params.getInteractionSize(), DEFAULT_CHAT_HISTORY_WINDOW);
        log.debug("Using interaction size of {}", interactionSize);
        final Integer contextSize = params.getContextSize();
        // -1 means "use every hit" (see getSearchResults).
        final List<String> searchResults = this.getSearchResults(response, contextSize == null ? -1 : contextSize);
        // Request-level prompt settings take precedence over the pipeline configuration.
        final String effectiveSystemPrompt = params.getSystemPrompt() != null ? params.getSystemPrompt() : this.systemPrompt;
        final String effectiveUserInstructions = params.getUserInstructions() != null ? params.getUserInstructions() : this.userInstructions;
        final List<Interaction> chatHistory = new ArrayList<>();
        if (conversationId == null) {
            this.doChatCompletion(LlmIOUtil.createChatCompletionInput(effectiveSystemPrompt, effectiveUserInstructions, llmModel, llmQuestion, chatHistory, searchResults, timeout, params.getLlmResponseField(), params.getLlmMessages()), null, llmQuestion, searchResults, effectiveSystemPrompt, effectiveUserInstructions, response, responseListener);
        } else {
            final Instant memoryStart = Instant.now();
            this.memoryClient.getInteractions(conversationId, interactionSize, (ActionListener<List<Interaction>>) ActionListenerHelper.wrap(interactions -> {
                log.debug("getInteractions complete. ({})", this.getDuration(memoryStart));
                chatHistory.addAll(interactions);
                this.doChatCompletion(LlmIOUtil.createChatCompletionInput(effectiveSystemPrompt, effectiveUserInstructions, llmModel, llmQuestion, chatHistory, searchResults, timeout, params.getLlmResponseField(), params.getLlmMessages()), conversationId, llmQuestion, searchResults, effectiveSystemPrompt, effectiveUserInstructions, response, responseListener);
            }, responseListener::onFailure));
        }
    }

    /** Returns {@code value}, or {@code fallback} when it is null or the -1 "unset" sentinel. */
    private static int valueOrDefault(Integer value, int fallback) {
        return value == null || value == -1 ? fallback : value;
    }

    /**
     * Sends {@code input} to the LLM. When it completes, optionally records the interaction in
     * memory (only if {@code conversationId} is non-null). Then completes
     * {@code responseListener} with the augmented response.
     */
    private void doChatCompletion(ChatCompletionInput input, final String conversationId, final String llmQuestion, final List<String> searchResults, final String effectiveSystemPrompt, final String effectiveUserInstructions, final SearchResponse response, final ActionListener<SearchResponse> responseListener) {
        final Instant chatStart = Instant.now();
        this.llm.doChatCompletion(input, new ActionListener<ChatCompletionOutput>() {

            public void onResponse(ChatCompletionOutput output) {
                log.debug("doChatCompletion complete. ({})", GenerativeQAResponseProcessor.this.getDuration(chatStart));
                final String answer = this.getAnswer(output);
                final String errorMessage = this.getError(output);
                if (conversationId == null) {
                    responseListener.onResponse(GenerativeQAResponseProcessor.this.insertAnswer(response, answer, errorMessage, null));
                    return;
                }
                final Instant memoryStart = Instant.now();
                // Record the prompt template that was actually used for this request.
                final String promptTemplate = PromptUtil.getPromptTemplate(effectiveSystemPrompt, effectiveUserInstructions);
                GenerativeQAResponseProcessor.this.memoryClient.createInteraction(conversationId, llmQuestion, promptTemplate, answer, "retrieval_augmented_generation", Collections.singletonMap("metadata", GenerativeQAResponseProcessor.jsonArrayToString(searchResults)), (ActionListener<String>) ActionListenerHelper.wrap(interactionId -> {
                    responseListener.onResponse(GenerativeQAResponseProcessor.this.insertAnswer(response, answer, errorMessage, (String) interactionId));
                    log.info("Created a new interaction: {} ({})", interactionId, GenerativeQAResponseProcessor.this.getDuration(memoryStart));
                }, responseListener::onFailure));
            }

            public void onFailure(Exception e) {
                responseListener.onFailure(e);
            }

            /** First error message when the LLM reported an error, otherwise null. */
            private String getError(ChatCompletionOutput output) {
                return output.isErrorOccurred() ? (String) output.getErrors().get(0) : null;
            }

            /** First answer when the LLM succeeded, otherwise null. */
            private String getAnswer(ChatCompletionOutput output) {
                return output.isErrorOccurred() ? null : (String) output.getAnswers().get(0);
            }
        });
    }

    public String getType() {
        return "retrieval_augmented_generation";
    }

    /** Milliseconds elapsed since {@code start}. */
    private long getDuration(Instant start) {
        return Duration.between(start, Instant.now()).toMillis();
    }

    /**
     * Wraps {@code response} in a {@link GenerativeSearchResponse} carrying the LLM answer or error
     * and the memory interaction id (null when no conversation was used).
     */
    private SearchResponse insertAnswer(SearchResponse response, String answer, String errorMessage, String interactionId) {
        // The 8th constructor argument is tookInMillis. Pass the original search latency rather than
        // the successful-shard count that was previously forwarded in this position.
        return new GenerativeSearchResponse(answer, errorMessage, response.getInternalResponse(), response.getScrollId(), response.getTotalShards(), response.getSuccessfulShards(), response.getSkippedShards(), response.getTook().millis(), response.getShardFailures(), response.getClusters(), interactionId);
    }

    /**
     * Collects the string value of every configured context field from the first {@code topN}
     * hits. A value of -1 means all hits.
     *
     * @throws RuntimeException if a hit lacks one of the context fields (or has no source map)
     */
    private List<String> getSearchResults(SearchResponse response, int topN) {
        final SearchHit[] hits = response.getHits().getHits();
        final int end = topN != -1 ? Math.min(topN, hits.length) : hits.length;
        final List<String> searchResults = new ArrayList<>();
        for (int i = 0; i < end; ++i) {
            final Map<String, Object> docSourceMap = hits[i].getSourceAsMap();
            for (String contextField : this.contextFields) {
                // A hit without a source map is reported like a missing field instead of an NPE.
                final Object context = docSourceMap == null ? null : docSourceMap.get(contextField);
                if (context == null) {
                    throw new RuntimeException("Context " + contextField + " not found in search hit " + String.valueOf(hits[i]));
                }
                searchResults.add(context.toString());
            }
        }
        return searchResults;
    }

    /** Serializes the strings as a JSON array, e.g. {@code ["a","b"]}. */
    private static String jsonArrayToString(List<String> listOfStrings) {
        final ArrayNode array = Strings.OBJECT_MAPPER.createArrayNode();
        listOfStrings.forEach(array::add);
        return array.toString();
    }

    public Llm getLlm() {
        return this.llm;
    }

    public void setLlm(Llm llm) {
        this.llm = llm;
    }

    public void setMemoryClient(ConversationalMemoryClient memoryClient) {
        this.memoryClient = memoryClient;
    }

    /** Builds {@link GenerativeQAResponseProcessor} instances from pipeline configuration. */
    public static final class Factory
    implements SearchProcessor.Factory<GenerativeQAResponseProcessor> {
        private final Client client;
        private final BooleanSupplier featureFlagSupplier;

        public Factory(Client client, BooleanSupplier supplier) {
            this.client = client;
            this.featureFlagSupplier = supplier;
        }

        /**
         * Reads {@code model_id}, {@code llm_model}, {@code context_field_list} (required, non-empty),
         * {@code system_prompt} and {@code user_instructions} from {@code config}.
         *
         * @throws MLException if the feature flag is disabled
         */
        public GenerativeQAResponseProcessor create(Map<String, SearchProcessor.Factory<GenerativeQAResponseProcessor>> processorFactories, String tag, String description, boolean ignoreFailure, Map<String, Object> config, SearchProcessor.PipelineContext pipelineContext) throws Exception {
            if (!this.featureFlagSupplier.getAsBoolean()) {
                throw new MLException(GenerativeQAProcessorConstants.FEATURE_NOT_ENABLED_ERROR_MSG);
            }
            final String type = "retrieval_augmented_generation";
            final String modelId = ConfigurationUtils.readOptionalStringProperty(type, tag, config, "model_id");
            final String llmModel = ConfigurationUtils.readOptionalStringProperty(type, tag, config, "llm_model");
            final List<String> contextFields = ConfigurationUtils.readList(type, tag, config, "context_field_list");
            if (contextFields.isEmpty()) {
                throw ConfigurationUtils.newConfigurationException(type, tag, "context_field_list", "required property can't be empty.");
            }
            final String systemPrompt = ConfigurationUtils.readOptionalStringProperty(type, tag, config, "system_prompt");
            final String userInstructions = ConfigurationUtils.readOptionalStringProperty(type, tag, config, "user_instructions");
            return new GenerativeQAResponseProcessor(this.client, tag, description, ignoreFailure, ModelLocator.getLlm(modelId, this.client), llmModel, contextFields, systemPrompt, userInstructions, this.featureFlagSupplier);
        }
    }
}

