/*
 * Decompiled with CFR 0.152.
 */
package io.skylite.ml.common.remote;

import io.skylite.core.common.Strings;
import io.skylite.core.http.HttpRequest;
import io.skylite.core.http.HttpRequestSigner;
import io.skylite.core.http.HttpSignerRegistry;
import io.skylite.core.script.Script;
import io.skylite.core.script.ScriptService;
import io.skylite.core.script.ScriptType;
import io.skylite.core.script.TemplateScript;
import io.skylite.ml.common.connector.Connector;
import io.skylite.ml.common.connector.ConnectorAction;
import io.skylite.ml.common.connector.MLPostProcessFunction;
import io.skylite.ml.common.connector.MLPreProcessFunction;
import io.skylite.ml.common.connector.functions.preprocess.DefaultPreProcessFunction;
import io.skylite.ml.common.connector.functions.preprocess.RemoteInferencePreProcessFunction;
import io.skylite.ml.common.dataset.TextDocsInputDataSet;
import io.skylite.ml.common.dataset.TextSimilarityInputDataSet;
import io.skylite.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import io.skylite.ml.common.input.MLInput;
import io.skylite.ml.common.model.MLGuard;
import io.skylite.ml.common.output.model.ModelTensor;
import io.skylite.ml.common.output.model.ModelTensors;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import org.apache.commons.text.StringEscapeUtils;
import org.apache.commons.text.StringSubstitutor;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

public class ConnectorUtils {
    private static final Logger log = LogManager.getLogger(ConnectorUtils.class);
    private static final HttpRequestSigner signer;
    public static final String SKIP_VALIDATE_MISSING_PARAMETERS = "skip_validating_missing_parameters";
    public static final List<String> SUPPORTED_REMOTE_SERVERS_FOR_DEFAULT_ACTION_TYPES;

    public static HttpRequest signRequest(HttpRequest request, Map<String, String> signingParams) {
        return signer.sign(request, signingParams);
    }

    public static HttpRequest signRequest(String signerApiName, HttpRequest request, Map<String, String> signingParams) {
        HttpRequestSigner requestSigner = HttpSignerRegistry.getSigner((String)signerApiName);
        return requestSigner.sign(request, signingParams);
    }

    public static HttpRequestSigner getSigner(String signerApiName) {
        return HttpSignerRegistry.getSigner((String)signerApiName);
    }

    public static RemoteInferenceInputDataSet processInput(String action, MLInput mlInput, Connector connector, Map<String, String> parameters, ScriptService scriptService) {
        if (mlInput == null) {
            throw new IllegalArgumentException("Input is null");
        }
        Optional<ConnectorAction> connectorAction = connector.findAction(action);
        if (connectorAction.isEmpty()) {
            throw new IllegalArgumentException("no " + action + " action found");
        }
        RemoteInferenceInputDataSet inputData = ConnectorUtils.processMLInput(action, mlInput, connector, parameters, scriptService);
        ConnectorUtils.escapeRemoteInferenceInputData(inputData);
        return inputData;
    }

    private static RemoteInferenceInputDataSet processMLInput(String action, MLInput mlInput, Connector connector, Map<String, String> parameters, ScriptService scriptService) {
        String preProcessFunction = ConnectorUtils.getPreprocessFunction(action, mlInput, connector);
        if (preProcessFunction == null) {
            if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) {
                return (RemoteInferenceInputDataSet)mlInput.getInputDataset();
            }
            throw new IllegalArgumentException("pre_process_function not defined in connector");
        }
        if (MLPreProcessFunction.contains(preProcessFunction = ConnectorUtils.fillProcessFunctionParameter(parameters, preProcessFunction))) {
            Function<MLInput, RemoteInferenceInputDataSet> function = MLPreProcessFunction.get(preProcessFunction);
            return function.apply(mlInput);
        }
        if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) {
            if (parameters.containsKey("pre_process_function.process_remote_inference_input") && Boolean.parseBoolean(parameters.get("pre_process_function.process_remote_inference_input"))) {
                HashMap<String, String> params = new HashMap<String, String>();
                params.putAll(connector.getParameters());
                params.putAll(parameters);
                RemoteInferencePreProcessFunction function = new RemoteInferencePreProcessFunction(scriptService, preProcessFunction, params);
                return function.apply(mlInput);
            }
            return (RemoteInferenceInputDataSet)mlInput.getInputDataset();
        }
        MLInput newInput = ConnectorUtils.escapeMLInput(mlInput);
        boolean convertInputToJsonString = parameters.containsKey("pre_process_function.convert_input_to_json_string") && Boolean.parseBoolean(parameters.get("pre_process_function.convert_input_to_json_string"));
        DefaultPreProcessFunction function = DefaultPreProcessFunction.builder().scriptService(scriptService).preProcessFunction(preProcessFunction).convertInputToJsonString(convertInputToJsonString).build();
        return function.apply(newInput);
    }

    private static MLInput escapeMLInput(MLInput mlInput) {
        if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) {
            List<String> docs = ((TextDocsInputDataSet)mlInput.getInputDataset()).getDocs();
            List newDocs = Strings.processTextDocs(docs);
            TextDocsInputDataSet newInputData = ((TextDocsInputDataSet)mlInput.getInputDataset()).toBuilder().docs(newDocs).build();
            return mlInput.toBuilder().inputDataset(newInputData).build();
        }
        if (mlInput.getInputDataset() instanceof TextSimilarityInputDataSet) {
            String query = ((TextSimilarityInputDataSet)mlInput.getInputDataset()).getQueryText();
            String newQuery = Strings.processTextDoc((String)query);
            List<String> docs = ((TextSimilarityInputDataSet)mlInput.getInputDataset()).getTextDocs();
            List newDocs = Strings.processTextDocs(docs);
            TextSimilarityInputDataSet newInputData = ((TextSimilarityInputDataSet)mlInput.getInputDataset()).toBuilder().queryText(newQuery).textDocs(newDocs).build();
            return mlInput.toBuilder().inputDataset(newInputData).build();
        }
        return mlInput;
    }

    public static void escapeRemoteInferenceInputData(RemoteInferenceInputDataSet inputData) {
        HashMap<String, String> newParameters = new HashMap<String, String>();
        if (inputData.getParameters() != null) {
            inputData.getParameters().forEach((key, value) -> {
                if (value == null) {
                    newParameters.put((String)key, (String)null);
                } else if (Strings.isJson((String)value)) {
                    newParameters.put((String)key, (String)value);
                } else if ("response_filter".equals(key)) {
                    newParameters.put((String)key, (String)value);
                } else {
                    newParameters.put((String)key, StringEscapeUtils.escapeJson((String)value));
                }
            });
            inputData.setParameters(newParameters);
        }
    }

    private static String getPreprocessFunction(String action, MLInput mlInput, Connector connector) {
        Optional<ConnectorAction> connectorAction = connector.findAction(action);
        String preProcessFunction = connectorAction.get().getPreProcessFunction();
        if (preProcessFunction != null) {
            return preProcessFunction;
        }
        if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) {
            return "connector.pre_process.default.embedding";
        }
        return null;
    }

    public static ModelTensors processOutput(String action, String modelResponse, Connector connector, ScriptService scriptService, Map<String, String> parameters, MLGuard mlGuard) throws IOException {
        boolean scriptReturnModelTensor;
        if (modelResponse == null) {
            throw new IllegalArgumentException("model response is null");
        }
        if (mlGuard != null && !mlGuard.validate(modelResponse, MLGuard.Type.OUTPUT, Map.of("question", Strings.processTextDoc((String)modelResponse))).booleanValue()) {
            throw new IllegalArgumentException("guardrails triggered for LLM output");
        }
        ArrayList<ModelTensor> modelTensors = new ArrayList<ModelTensor>();
        Optional<ConnectorAction> connectorAction = connector.findAction(action);
        if (connectorAction.isEmpty()) {
            throw new IllegalArgumentException("no " + action + " action found");
        }
        String postProcessFunction = connectorAction.get().getPostProcessFunction();
        postProcessFunction = ConnectorUtils.fillProcessFunctionParameter(parameters, postProcessFunction);
        String responseFilter = parameters.get("response_filter");
        if (MLPostProcessFunction.contains(postProcessFunction)) {
            if (Strings.isBlank((CharSequence)responseFilter)) {
                responseFilter = MLPostProcessFunction.getResponseFilter(postProcessFunction);
            }
            Object filteredOutput = Strings.extractJsonObjectField((String)modelResponse, (String)responseFilter);
            List<ModelTensor> processedResponse = MLPostProcessFunction.get(postProcessFunction).apply(filteredOutput);
            return ModelTensors.builder().mlModelTensors(processedResponse).build();
        }
        Optional<String> processedResponse = ConnectorUtils.executePostProcessFunction(scriptService, postProcessFunction, modelResponse);
        String response = processedResponse.orElse(modelResponse);
        boolean bl = scriptReturnModelTensor = postProcessFunction != null && processedResponse.isPresent() && Strings.isJson((String)response);
        if (responseFilter == null) {
            connector.parseResponse(response, modelTensors, scriptReturnModelTensor);
        } else {
            Object filteredResponse = Strings.extractJsonObjectField((String)response, (String)parameters.get("response_filter"));
            connector.parseResponse(filteredResponse, modelTensors, scriptReturnModelTensor);
        }
        return ModelTensors.builder().mlModelTensors(modelTensors).build();
    }

    private static String fillProcessFunctionParameter(Map<String, String> parameters, String processFunction) {
        if (processFunction != null && processFunction.contains("${parameters.")) {
            HashMap<String, String> tmpParameters = new HashMap<String, String>();
            for (String key : parameters.keySet()) {
                tmpParameters.put(key, Strings.toJson((Object)parameters.get(key)));
            }
            StringSubstitutor substitutor = new StringSubstitutor(tmpParameters, "${parameters.", "}");
            processFunction = substitutor.replace(processFunction);
        }
        return processFunction;
    }

    public static ConnectorAction createConnectorAction(Connector connector, ConnectorAction.ActionType actionType) {
        HashMap<String, String> parameters;
        Optional<ConnectorAction> batchPredictAction = connector.findAction(ConnectorAction.ActionType.BATCH_PREDICT.name());
        String predictEndpoint = batchPredictAction.get().getUrl();
        HashMap<String, String> hashMap = parameters = connector.getParameters() != null ? new HashMap<String, String>(connector.getParameters()) : Collections.emptyMap();
        if (!parameters.isEmpty()) {
            StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
            predictEndpoint = substitutor.replace(predictEndpoint);
        }
        boolean isCancelAction = actionType == ConnectorAction.ActionType.CANCEL_BATCH_PREDICT;
        String method = "POST";
        String requestBody = null;
        Object url = "";
        switch (ConnectorUtils.getRemoteServerFromURL(predictEndpoint)) {
            case "sagemaker": {
                url = isCancelAction ? predictEndpoint.replace("CreateTransformJob", "StopTransformJob") : predictEndpoint.replace("CreateTransformJob", "DescribeTransformJob");
                requestBody = "{ \"TransformJobName\" : \"${parameters.TransformJobName}\"}";
                break;
            }
            case "openai": 
            case "cohere": {
                url = isCancelAction ? predictEndpoint + "/${parameters.id}/cancel" : predictEndpoint + "/${parameters.id}";
                method = isCancelAction ? "POST" : "GET";
                break;
            }
            case "bedrock": {
                url = isCancelAction ? predictEndpoint + "/${parameters.processedJobArn}/stop" : predictEndpoint + "/${parameters.processedJobArn}";
                method = isCancelAction ? "POST" : "GET";
                break;
            }
            default: {
                String errorMessage = isCancelAction ? "Please configure the action type to cancel the batch job in the connector" : "Please configure the action type to get the batch job details in the connector";
                throw new UnsupportedOperationException(errorMessage);
            }
        }
        return ConnectorAction.builder().actionType(actionType).method(method).url((String)url).requestBody(requestBody).headers(batchPredictAction.get().getHeaders()).build();
    }

    public static String getRemoteServerFromURL(String url) {
        return SUPPORTED_REMOTE_SERVERS_FOR_DEFAULT_ACTION_TYPES.stream().filter(url::contains).findFirst().orElse("");
    }

    public static Optional<String> executePostProcessFunction(ScriptService scriptService, String postProcessFunction, String resultJson) {
        Map result = Strings.fromJson((String)resultJson, (String)"result");
        if (postProcessFunction != null) {
            return Optional.ofNullable(ConnectorUtils.executeScript(scriptService, Strings.addDefaultMethod((String)postProcessFunction), result));
        }
        return Optional.empty();
    }

    private static String executeScript(ScriptService scriptService, String painlessScript, Map<String, Object> params) {
        Script script = new Script(ScriptType.INLINE, "painless", painlessScript, Collections.emptyMap());
        TemplateScript templateScript = ((TemplateScript.Factory)scriptService.compile(script, TemplateScript.CONTEXT)).newInstance(params);
        return templateScript.execute();
    }

    static {
        SUPPORTED_REMOTE_SERVERS_FOR_DEFAULT_ACTION_TYPES = List.of("sagemaker", "openai", "bedrock", "cohere");
        signer = HttpSignerRegistry.getHttpSigner();
    }
}

