/*
 * Decompiled with CFR 0.152.
 */
package io.lucenia.ml.common.engine.tools;

import io.skylite.common.action.ActionListener;
import io.skylite.core.action.ActionListenerHelper;
import io.skylite.core.action.ActionRequest;
import io.skylite.core.action.ActionType;
import io.skylite.core.client.Client;
import io.skylite.core.common.Strings;
import io.skylite.ml.common.FunctionName;
import io.skylite.ml.common.dataset.MLInputDataset;
import io.skylite.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import io.skylite.ml.common.engine.tools.Parser;
import io.skylite.ml.common.engine.tools.ToolAnnotation;
import io.skylite.ml.common.engine.tools.WithModelTool;
import io.skylite.ml.common.input.MLInput;
import io.skylite.ml.common.output.model.ModelTensor;
import io.skylite.ml.common.output.model.ModelTensorOutput;
import io.skylite.ml.common.output.model.ModelTensors;
import io.skylite.ml.common.transport.prediction.MLPredictionTaskAction;
import io.skylite.ml.common.transport.prediction.MLPredictionTaskRequest;
import java.util.List;
import java.util.Map;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

@ToolAnnotation(value="MLModelTool")
public class MLModelTool
implements WithModelTool {
    private static final Logger log = LogManager.getLogger(MLModelTool.class);
    public static final String TYPE = "MLModelTool";
    public static final String RESPONSE_FIELD = "response_field";
    public static final String MODEL_ID_FIELD = "model_id";
    public static final String DEFAULT_RESPONSE_FIELD = "response";
    private String name = "MLModelTool";
    static String DEFAULT_DESCRIPTION = "Use this tool to run any model.";
    private String description = DEFAULT_DESCRIPTION;
    private Client client;
    private String modelId;
    private Parser inputParser;
    private Parser outputParser;
    private String responseField;
    private Map<String, Object> attributes;

    public MLModelTool(Client client, String modelId, String responseField) {
        this.client = client;
        this.modelId = modelId;
        this.responseField = responseField;
        this.outputParser = o -> {
            try {
                List mlModelOutputs = (List)o;
                Map dataAsMap = ((ModelTensor)((ModelTensors)mlModelOutputs.get(0)).getMlModelTensors().get(0)).getDataAsMap();
                if (dataAsMap.containsKey(responseField)) {
                    return dataAsMap.get(responseField);
                }
                return Strings.toJson((Object)dataAsMap);
            }
            catch (Exception e) {
                throw new IllegalStateException("LLM returns wrong or empty tensors", e);
            }
        };
    }

    public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
        RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build();
        String tenantId = null;
        if (parameters != null) {
            tenantId = parameters.get("tenant_id");
        }
        MLPredictionTaskRequest request = MLPredictionTaskRequest.builder().modelId(this.modelId).mlInput(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset((MLInputDataset)inputDataSet).build()).tenantId(tenantId).build();
        this.client.execute((ActionType)MLPredictionTaskAction.INSTANCE, (ActionRequest)request, ActionListenerHelper.wrap(r -> {
            ModelTensorOutput modelTensorOutput = (ModelTensorOutput)r.getOutput();
            modelTensorOutput.getMlModelOutputs();
            listener.onResponse(this.outputParser.parse((Object)modelTensorOutput.getMlModelOutputs()));
        }, e -> {
            log.error("Failed to run model {}", (Object)this.modelId);
            listener.onFailure(e);
        }));
    }

    public String getType() {
        return TYPE;
    }

    public String getVersion() {
        return null;
    }

    public String getName() {
        return this.name;
    }

    public void setName(String s) {
        this.name = s;
    }

    public boolean validate(Map<String, String> parameters) {
        return parameters != null && parameters.size() != 0;
    }

    public String getDescription() {
        return this.description;
    }

    public void setDescription(String description) {
        this.description = description;
    }

    public Map<String, Object> getAttributes() {
        return this.attributes;
    }

    public void setAttributes(Map<String, Object> attributes) {
        this.attributes = attributes;
    }

    public Client getClient() {
        return this.client;
    }

    public String getModelId() {
        return this.modelId;
    }

    public void setInputParser(Parser inputParser) {
        this.inputParser = inputParser;
    }

    public Parser getOutputParser() {
        return this.outputParser;
    }

    public void setOutputParser(Parser outputParser) {
        this.outputParser = outputParser;
    }

    public String getResponseField() {
        return this.responseField;
    }

    public void setResponseField(String responseField) {
        this.responseField = responseField;
    }

    public static class Factory
    implements WithModelTool.Factory<MLModelTool> {
        private Client client;
        private static Factory INSTANCE;

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        public static Factory getInstance() {
            if (INSTANCE != null) {
                return INSTANCE;
            }
            Class<MLModelTool> clazz = MLModelTool.class;
            synchronized (MLModelTool.class) {
                if (INSTANCE != null) {
                    // ** MonitorExit[var0] (shouldn't be in output)
                    return INSTANCE;
                }
                INSTANCE = new Factory();
                // ** MonitorExit[var0] (shouldn't be in output)
                return INSTANCE;
            }
        }

        public void init(Client client) {
            this.client = client;
        }

        public MLModelTool create(Map<String, Object> map) {
            return new MLModelTool(this.client, (String)map.get(MLModelTool.MODEL_ID_FIELD), (String)map.getOrDefault(MLModelTool.RESPONSE_FIELD, MLModelTool.DEFAULT_RESPONSE_FIELD));
        }

        public String getDefaultDescription() {
            return DEFAULT_DESCRIPTION;
        }

        public String getDefaultType() {
            return MLModelTool.TYPE;
        }

        public String getDefaultVersion() {
            return null;
        }

        public List<String> getAllModelKeys() {
            return List.of(MLModelTool.MODEL_ID_FIELD);
        }
    }
}

