/*
 * Decompiled with CFR 0.152.
 */
package io.lucenia.ml.common.engine.algorithms.agent;

import io.lucenia.ml.common.action.memory.conversation.GetInteractionAction;
import io.lucenia.ml.common.action.memory.conversation.GetInteractionRequest;
import io.lucenia.ml.common.engine.algorithms.agent.MLChatAgentRunner;
import io.lucenia.ml.common.engine.algorithms.agent.MLConversationalFlowAgentRunner;
import io.lucenia.ml.common.engine.algorithms.agent.MLFlowAgentRunner;
import io.lucenia.ml.common.engine.memory.ConversationIndexMemory;
import io.skylite.ResourceNotFoundException;
import io.skylite.SkyliteExceptionsHelper;
import io.skylite.SkyliteStatusException;
import io.skylite.common.action.ActionListener;
import io.skylite.core.action.ActionListenerHelper;
import io.skylite.core.action.ActionRequest;
import io.skylite.core.action.ActionType;
import io.skylite.core.action.get.GetResponse;
import io.skylite.core.client.Client;
import io.skylite.core.client.ReleasableSkyliteClient;
import io.skylite.core.client.metadata.GetDataObjectRequest;
import io.skylite.core.client.metadata.MetadataClient;
import io.skylite.core.cluster.service.ClusterService;
import io.skylite.core.common.Strings;
import io.skylite.core.common.concurrent.ThreadContext;
import io.skylite.core.index.IndexNotFoundException;
import io.skylite.core.rest.RestStatus;
import io.skylite.core.search.fetch.subphase.FetchSourceContext;
import io.skylite.core.settings.Settings;
import io.skylite.core.xcontent.DeprecationHandler;
import io.skylite.core.xcontent.LoggingDeprecationHandler;
import io.skylite.core.xcontent.MediaTypeRegistry;
import io.skylite.core.xcontent.NamedXContentRegistry;
import io.skylite.core.xcontent.XContentParser;
import io.skylite.core.xcontent.XContentParserUtils;
import io.skylite.ml.common.FunctionName;
import io.skylite.ml.common.MLAgentType;
import io.skylite.ml.common.agent.MLAgent;
import io.skylite.ml.common.agent.MLMemorySpec;
import io.skylite.ml.common.algorithms.agent.MLAgentRunner;
import io.skylite.ml.common.annotation.Function;
import io.skylite.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import io.skylite.ml.common.engine.Executable;
import io.skylite.ml.common.engine.memory.ConversationIndexMessage;
import io.skylite.ml.common.engine.memory.Memory;
import io.skylite.ml.common.engine.memory.Message;
import io.skylite.ml.common.engine.tools.Tool;
import io.skylite.ml.common.input.Input;
import io.skylite.ml.common.input.execute.agent.AgentMLInput;
import io.skylite.ml.common.output.Output;
import io.skylite.ml.common.output.model.ModelTensor;
import io.skylite.ml.common.output.model.ModelTensorOutput;
import io.skylite.ml.common.output.model.ModelTensors;
import io.skylite.ml.common.settings.SettingsChangeListener;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Executor;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * {@link Executable} for {@link FunctionName#AGENT}: loads an agent definition from the
 * {@code .plugins-ml-agent} index and runs it with the runner that matches the agent's type
 * (flow, conversational flow or conversational).
 *
 * <p>If the agent declares a memory type present in {@code memoryFactoryMap} and the request does
 * not already carry both a memory id and a parent interaction id, a conversation memory is
 * obtained from the factory and a root interaction is saved before the agent is run.
 */
@Function(value = FunctionName.AGENT)
public class MLAgentExecutor implements Executable, SettingsChangeListener {
    private static final Logger log = LogManager.getLogger(MLAgentExecutor.class);

    /** Input parameter key: id of the conversation memory to use. */
    public static final String MEMORY_ID = "memory_id";
    /** Input parameter key: the user question. */
    public static final String QUESTION = "question";
    /** Input parameter key: id of the root (parent) interaction in the memory. */
    public static final String PARENT_INTERACTION_ID = "parent_interaction_id";
    /** Input parameter key: id of an existing interaction to regenerate. */
    public static final String REGENERATE_INTERACTION_ID = "regenerate_interaction_id";
    /** Input parameter key: message history limit (not read by this class). */
    public static final String MESSAGE_HISTORY_LIMIT = "message_history_limit";

    private static final String ML_AGENT_INDEX = ".plugins-ml-agent";
    private static final String ML_GENERAL_THREAD_POOL = "lucenia_ml_general";
    private static final String NO_PERMISSION_MESSAGE = "You don't have permission to access this resource";

    private Client client;
    private MetadataClient metadataClient;
    private Settings settings;
    private ClusterService clusterService;
    private NamedXContentRegistry xContentRegistry;
    private Map<String, Tool.Factory<?>> toolFactories;
    private Map<String, Memory.Factory> memoryFactoryMap;
    // Updated through onMultiTenancyEnabledChanged / setMultiTenancyEnabled; volatile so readers see the latest value.
    private volatile Boolean isMultiTenancyEnabled;

    public MLAgentExecutor() {
    }

    public MLAgentExecutor(Client client, MetadataClient metadataClient, Settings settings, ClusterService clusterService, NamedXContentRegistry xContentRegistry, Map<String, Tool.Factory<?>> toolFactories, Map<String, Memory.Factory> memoryFactoryMap, Boolean isMultiTenancyEnabled) {
        this.client = client;
        this.metadataClient = metadataClient;
        this.settings = settings;
        this.clusterService = clusterService;
        this.xContentRegistry = xContentRegistry;
        this.toolFactories = toolFactories;
        this.memoryFactoryMap = memoryFactoryMap;
        this.isMultiTenancyEnabled = isMultiTenancyEnabled;
    }

    /** Records the new value of the multi-tenancy setting. */
    public void onMultiTenancyEnabledChanged(boolean isEnabled) {
        this.isMultiTenancyEnabled = isEnabled;
    }

    /**
     * Fetches the agent named by {@code input} and runs it, reporting the result to {@code listener}.
     *
     * @param input must be an {@link AgentMLInput} whose dataset is a {@link RemoteInferenceInputDataSet}
     * @param listener receives a {@link ModelTensorOutput} (or {@code null} if the agent produced no output)
     * @throws IllegalArgumentException if {@code input} is not an {@link AgentMLInput} or has no parameters
     * @throws SkyliteStatusException (FORBIDDEN) if multi-tenancy is enabled and no tenant id was supplied
     */
    public void execute(Input input, ActionListener<Output> listener) {
        if (!(input instanceof AgentMLInput)) {
            throw new IllegalArgumentException("wrong input");
        }
        AgentMLInput agentMLInput = (AgentMLInput) input;
        String agentId = agentMLInput.getAgentId();
        String tenantId = agentMLInput.getTenantId();
        RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) agentMLInput.getInputDataset();
        if (inputDataSet == null || inputDataSet.getParameters() == null) {
            throw new IllegalArgumentException("Agent input data can not be empty.");
        }
        // Unboxing throws NullPointerException if the flag was never set (e.g. no-arg constructor).
        if (this.isMultiTenancyEnabled.booleanValue() && tenantId == null) {
            throw new SkyliteStatusException(NO_PERMISSION_MESSAGE, RestStatus.FORBIDDEN, new Object[0]);
        }
        if (!this.clusterService.state().metadata().hasIndex(ML_AGENT_INDEX)) {
            listener.onFailure(new ResourceNotFoundException("Agent index not found", new Object[0]));
            return;
        }
        // The runner's tensors are appended to modelTensors, which is the (single) entry of outputs.
        List<ModelTensor> modelTensors = new ArrayList<>();
        List<ModelTensors> outputs = new ArrayList<>();
        outputs.add(ModelTensors.builder().mlModelTensors(modelTensors).build());
        FetchSourceContext fetchSourceContext = new FetchSourceContext(true, Strings.EMPTY_ARRAY, Strings.EMPTY_ARRAY);
        GetDataObjectRequest getDataObjectRequest = GetDataObjectRequest.builder()
            .index(ML_AGENT_INDEX)
            .id(agentId)
            .tenantId(tenantId)
            .fetchSourceContext(fetchSourceContext)
            .build();
        try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
            Executor executor = (Executor) this.client.threadPool().executor(ML_GENERAL_THREAD_POOL);
            this.metadataClient.getDataObjectAsync(getDataObjectRequest, executor).whenComplete((response, throwable) -> {
                context.restore();
                log.debug("Completed Get Agent Request, Agent id:{}", agentId);
                if (throwable != null) {
                    handleGetAgentFailure(agentId, throwable, listener);
                    return;
                }
                try {
                    XContentParser responseParser = response.parser();
                    GetResponse getAgentResponse = responseParser == null ? null : GetResponse.fromXContent(responseParser);
                    if (getAgentResponse == null || !getAgentResponse.isExists()) {
                        listener.onFailure(new SkyliteStatusException("Failed to find agent with the provided agent id: " + agentId, RestStatus.NOT_FOUND, new Object[0]));
                        return;
                    }
                    this.parseAndRunAgent(getAgentResponse, agentId, tenantId, inputDataSet, outputs, modelTensors, listener);
                } catch (Exception e) {
                    log.error("Failed to get agent", e);
                    listener.onFailure(e);
                }
            });
        }
    }

    /** Maps a failed agent fetch to the listener: a missing index becomes NOT_FOUND, anything else is forwarded. */
    private static void handleGetAgentFailure(String agentId, Throwable throwable, ActionListener<Output> listener) {
        Exception cause = SkyliteExceptionsHelper.unwrapAndConvertToException(throwable, new Class[0]);
        if (SkyliteExceptionsHelper.unwrap(cause, new Class[]{IndexNotFoundException.class}) != null) {
            log.error("Failed to get Agent index", cause);
            listener.onFailure(new SkyliteStatusException("Failed to get agent index", RestStatus.NOT_FOUND, new Object[0]));
        } else {
            log.error("Failed to get ML Agent {}", agentId, cause);
            listener.onFailure(cause);
        }
    }

    /**
     * Parses the stored agent, enforces tenant ownership and runs it. Any exception raised while
     * parsing or dispatching is reported to {@code listener}.
     */
    private void parseAndRunAgent(GetResponse getAgentResponse, String agentId, String tenantId, RemoteInferenceInputDataSet inputDataSet, List<ModelTensors> outputs, List<ModelTensor> modelTensors, ActionListener<Output> listener) {
        try (XContentParser parser = MediaTypeRegistry.JSON.xContent().createParser(this.xContentRegistry, (DeprecationHandler) LoggingDeprecationHandler.INSTANCE, getAgentResponse.getSourceAsString())) {
            XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
            MLAgent mlAgent = MLAgent.parse(parser);
            if (this.isMultiTenancyEnabled.booleanValue() && !Objects.equals(tenantId, mlAgent.getTenantId())) {
                // Must stop here: previously execution fell through and ran another tenant's agent
                // after the listener had already been failed.
                listener.onFailure(new SkyliteStatusException(NO_PERMISSION_MESSAGE, RestStatus.FORBIDDEN, new Object[0]));
                return;
            }
            this.runAgent(mlAgent, inputDataSet, outputs, modelTensors, listener);
        } catch (Exception e) {
            log.error("Failed to parse ml agent {}", agentId, e);
            listener.onFailure(e);
        }
    }

    /**
     * Runs the agent, first setting up conversation memory (and regenerating an earlier interaction
     * if {@link #REGENERATE_INTERACTION_ID} is present) when the agent's memory spec requires it.
     *
     * @throws IllegalArgumentException if regeneration is requested without a memory id
     */
    private void runAgent(MLAgent mlAgent, RemoteInferenceInputDataSet inputDataSet, List<ModelTensors> outputs, List<ModelTensor> modelTensors, ActionListener<Output> listener) {
        MLMemorySpec memorySpec = mlAgent.getMemory();
        String memoryId = (String) inputDataSet.getParameters().get(MEMORY_ID);
        String parentInteractionId = (String) inputDataSet.getParameters().get(PARENT_INTERACTION_ID);
        String regenerateInteractionId = (String) inputDataSet.getParameters().get(REGENERATE_INTERACTION_ID);
        String appType = mlAgent.getAppType();
        String question = (String) inputDataSet.getParameters().get(QUESTION);
        if (memoryId == null && regenerateInteractionId != null) {
            throw new IllegalArgumentException("A memory ID must be provided to regenerate.");
        }
        boolean needsConversationMemory = memorySpec != null
            && memorySpec.getType() != null
            && this.memoryFactoryMap.containsKey(memorySpec.getType())
            && (memoryId == null || parentInteractionId == null);
        if (!needsConversationMemory) {
            this.executeAgent(inputDataSet, mlAgent, this.createAgentActionListener(listener, outputs, modelTensors, mlAgent.getType()));
            return;
        }
        ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) this.memoryFactoryMap.get(memorySpec.getType());
        conversationIndexMemoryFactory.create(question, memoryId, appType, ActionListenerHelper.wrap(memory -> {
            inputDataSet.getParameters().put(MEMORY_ID, memory.getConversationId());
            ActionListener<Object> agentActionListener = this.createAgentActionListener(listener, outputs, modelTensors, mlAgent.getType());
            if (regenerateInteractionId == null) {
                this.saveRootInteractionAndExecute(agentActionListener, memory, inputDataSet, mlAgent);
                return;
            }
            log.info("Regenerate for existing interaction {}", regenerateInteractionId);
            this.client.execute(GetInteractionAction.INSTANCE, new GetInteractionRequest(regenerateInteractionId), ActionListenerHelper.wrap(interactionRes -> {
                // Reuse the original question unless the caller supplied a new one.
                inputDataSet.getParameters().putIfAbsent(QUESTION, interactionRes.getInteraction().getInput());
                this.saveRootInteractionAndExecute(agentActionListener, memory, inputDataSet, mlAgent);
            }, e -> {
                log.error("Failed to get existing interaction for regeneration", e);
                listener.onFailure(e);
            }));
        }, ex -> {
            log.error("Failed to read conversation memory", ex);
            listener.onFailure(ex);
        }));
    }

    /**
     * Saves a root interaction for the question and records its id as {@link #PARENT_INTERACTION_ID};
     * when regenerating, deletes the old interaction (and its trace) before running the agent.
     */
    private void saveRootInteractionAndExecute(ActionListener<Object> listener, ConversationIndexMemory memory, RemoteInferenceInputDataSet inputDataSet, MLAgent mlAgent) {
        String appType = mlAgent.getAppType();
        String question = (String) inputDataSet.getParameters().get(QUESTION);
        String regenerateInteractionId = (String) inputDataSet.getParameters().get(REGENERATE_INTERACTION_ID);
        ConversationIndexMessage msg = ConversationIndexMessage.conversationIndexMessageBuilder()
            .type(appType)
            .question(question)
            .response("")
            .finalAnswer(Boolean.TRUE)
            .sessionId(memory.getConversationId())
            .build();
        memory.save((Message) msg, null, null, null, ActionListenerHelper.wrap(interaction -> {
            log.info("Created parent interaction ID: {}", interaction.getId());
            inputDataSet.getParameters().put(PARENT_INTERACTION_ID, interaction.getId());
            if (regenerateInteractionId == null) {
                this.executeAgent(inputDataSet, mlAgent, listener);
                return;
            }
            memory.getMemoryManager().deleteInteractionAndTrace(regenerateInteractionId, ActionListenerHelper.wrap(deleted -> this.executeAgent(inputDataSet, mlAgent, listener), e -> {
                log.error("Failed to regenerate for interaction {}", regenerateInteractionId, e);
                listener.onFailure(e);
            }));
        }, ex -> {
            log.error("Failed to create parent interaction", ex);
            listener.onFailure(ex);
        }));
    }

    /** Runs the agent with the runner selected by {@link #getAgentRunner(MLAgent)}. */
    private void executeAgent(RemoteInferenceInputDataSet inputDataSet, MLAgent mlAgent, ActionListener<Object> agentActionListener) {
        MLAgentRunner mlAgentRunner = this.getAgentRunner(mlAgent);
        mlAgentRunner.run(mlAgent, inputDataSet.getParameters(), agentActionListener);
    }

    /**
     * Builds the listener handed to the agent runner. A non-null runner output is converted into
     * {@link ModelTensor}s appended to {@code modelTensors}, and {@code outputs} is then wrapped in a
     * {@link ModelTensorOutput}. A null output is forwarded as {@code null}.
     */
    private ActionListener<Object> createAgentActionListener(ActionListener<Output> listener, List<ModelTensors> outputs, List<ModelTensor> modelTensors, String agentType) {
        return ActionListenerHelper.wrap(output -> {
            if (output == null) {
                listener.onResponse(null);
                return;
            }
            collectModelTensors(output, modelTensors);
            listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(outputs).build());
        }, ex -> {
            log.error("Failed to run {} agent", agentType, ex);
            listener.onFailure(ex);
        });
    }

    /**
     * Appends the tensors represented by a runner output to {@code modelTensors}. Lists are
     * classified by their first element. Empty lists, and any other value, are serialized to JSON
     * (Strings are used as-is) and stored under the name {@code "response"}.
     */
    private static void collectModelTensors(Object output, List<ModelTensor> modelTensors) {
        if (output instanceof ModelTensorOutput) {
            ((ModelTensorOutput) output).getMlModelOutputs().forEach(outs -> modelTensors.addAll(outs.getMlModelTensors()));
        } else if (output instanceof ModelTensor) {
            modelTensors.add((ModelTensor) output);
        } else if (output instanceof List) {
            List<?> items = (List<?>) output;
            // Previously get(0) threw IndexOutOfBoundsException on an empty list.
            Object first = items.isEmpty() ? null : items.get(0);
            if (first instanceof ModelTensor) {
                for (Object item : items) {
                    modelTensors.add((ModelTensor) item);
                }
            } else if (first instanceof ModelTensors) {
                for (Object item : items) {
                    modelTensors.addAll(((ModelTensors) item).getMlModelTensors());
                }
            } else {
                modelTensors.add(ModelTensor.builder().name("response").result(toJson(output)).build());
            }
        } else {
            String result = output instanceof String ? (String) output : toJson(output);
            modelTensors.add(ModelTensor.builder().name("response").result(result).build());
        }
    }

    /** Serializes {@code value} to JSON inside a privileged block. */
    private static String toJson(Object value) {
        // Explicit PrivilegedAction cast: a bare lambda is ambiguous between the
        // PrivilegedAction and PrivilegedExceptionAction overloads of doPrivileged.
        return AccessController.doPrivileged((PrivilegedAction<String>) () -> Strings.toJson(value));
    }

    /**
     * Returns a runner for the agent's type (case-insensitive).
     *
     * @throws IllegalArgumentException if the type has no runner here
     */
    protected MLAgentRunner getAgentRunner(MLAgent mlAgent) {
        MLAgentType agentType = MLAgentType.from(mlAgent.getType().toUpperCase(Locale.ROOT));
        switch (agentType) {
            case FLOW: {
                return new MLFlowAgentRunner((ReleasableSkyliteClient) this.client, this.settings, this.clusterService, this.xContentRegistry, this.toolFactories, this.memoryFactoryMap);
            }
            case CONVERSATIONAL_FLOW: {
                return new MLConversationalFlowAgentRunner(this.client, this.settings, this.clusterService, this.xContentRegistry, this.toolFactories, this.memoryFactoryMap);
            }
            case CONVERSATIONAL: {
                return new MLChatAgentRunner(this.client, this.settings, this.clusterService, this.xContentRegistry, this.toolFactories, this.memoryFactoryMap);
            }
        }
        throw new IllegalArgumentException("Unsupported agent type: " + mlAgent.getType());
    }

    // Plain accessors for the executor's dependencies.

    public Client getClient() {
        return this.client;
    }

    public void setClient(Client client) {
        this.client = client;
    }

    public MetadataClient getMetadataClient() {
        return this.metadataClient;
    }

    public void setMetadataClient(MetadataClient metadataClient) {
        this.metadataClient = metadataClient;
    }

    public Settings getSettings() {
        return this.settings;
    }

    public void setSettings(Settings settings) {
        this.settings = settings;
    }

    public ClusterService getClusterService() {
        return this.clusterService;
    }

    public void setClusterService(ClusterService clusterService) {
        this.clusterService = clusterService;
    }

    public NamedXContentRegistry getxContentRegistry() {
        return this.xContentRegistry;
    }

    public void setxContentRegistry(NamedXContentRegistry xContentRegistry) {
        this.xContentRegistry = xContentRegistry;
    }

    public Map<String, Tool.Factory<?>> getToolFactories() {
        return this.toolFactories;
    }

    public void setToolFactories(Map<String, Tool.Factory<?>> toolFactories) {
        this.toolFactories = toolFactories;
    }

    public Map<String, Memory.Factory> getMemoryFactoryMap() {
        return this.memoryFactoryMap;
    }

    public void setMemoryFactoryMap(Map<String, Memory.Factory> memoryFactoryMap) {
        this.memoryFactoryMap = memoryFactoryMap;
    }

    public Boolean getMultiTenancyEnabled() {
        return this.isMultiTenancyEnabled;
    }

    public void setMultiTenancyEnabled(Boolean multiTenancyEnabled) {
        this.isMultiTenancyEnabled = multiTenancyEnabled;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        MLAgentExecutor that = (MLAgentExecutor) o;
        return Objects.equals(this.client, that.client)
            && Objects.equals(this.metadataClient, that.metadataClient)
            && Objects.equals(this.settings, that.settings)
            && Objects.equals(this.clusterService, that.clusterService)
            && Objects.equals(this.xContentRegistry, that.xContentRegistry)
            && Objects.equals(this.toolFactories, that.toolFactories)
            && Objects.equals(this.memoryFactoryMap, that.memoryFactoryMap)
            && Objects.equals(this.isMultiTenancyEnabled, that.isMultiTenancyEnabled);
    }

    @Override
    public int hashCode() {
        return Objects.hash(this.client, this.metadataClient, this.settings, this.clusterService, this.xContentRegistry, this.toolFactories, this.memoryFactoryMap, this.isMultiTenancyEnabled);
    }

    @Override
    public String toString() {
        return "MLAgentExecutor{client=" + String.valueOf(this.client) + ", metadataClient=" + String.valueOf(this.metadataClient) + ", settings=" + String.valueOf(this.settings) + ", clusterService=" + String.valueOf(this.clusterService) + ", xContentRegistry=" + String.valueOf(this.xContentRegistry) + ", toolFactories=" + String.valueOf(this.toolFactories) + ", memoryFactoryMap=" + String.valueOf(this.memoryFactoryMap) + ", isMultiTenancyEnabled=" + this.isMultiTenancyEnabled + "}";
    }
}

