/*
 * Decompiled with CFR 0.152.
 */
package io.skylite.ml.common.engine;

import io.skylite.common.action.ActionListener;
import io.skylite.ml.common.FunctionName;
import io.skylite.ml.common.connector.Connector;
import io.skylite.ml.common.connector.ConnectorAction;
import io.skylite.ml.common.dataframe.DataFrame;
import io.skylite.ml.common.dataset.DataFrameInputDataset;
import io.skylite.ml.common.dataset.MLInputDataset;
import io.skylite.ml.common.engine.Encryptor;
import io.skylite.ml.common.engine.Executable;
import io.skylite.ml.common.engine.MLEngineClassLoader;
import io.skylite.ml.common.engine.MLExecutable;
import io.skylite.ml.common.engine.Predictable;
import io.skylite.ml.common.engine.TrainAndPredictable;
import io.skylite.ml.common.engine.Trainable;
import io.skylite.ml.common.input.Input;
import io.skylite.ml.common.input.MLInput;
import io.skylite.ml.common.input.parameter.MLAlgoParams;
import io.skylite.ml.common.model.MLModel;
import io.skylite.ml.common.model.MLModelFormat;
import io.skylite.ml.common.output.MLOutput;
import io.skylite.ml.common.output.Output;
import java.nio.file.Path;
import java.util.Locale;
import java.util.Map;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

public class MLEngine {
    private static final Logger log = LogManager.getLogger(MLEngine.class);
    public static final String REGISTER_MODEL_FOLDER = "register";
    public static final String DEPLOY_MODEL_FOLDER = "deploy";
    private final String MODEL_REPO = "https://artifacts.opensearch.org/models/ml-models";
    private final Path mlConfigPath;
    private final Path mlCachePath;
    private final Path mlModelsCachePath;
    private Encryptor encryptor;

    public MLEngine(Path luceniaDataFolder, Encryptor encryptor) {
        this.mlCachePath = luceniaDataFolder.resolve("ml_cache");
        this.mlModelsCachePath = this.mlCachePath.resolve("models_cache");
        this.mlConfigPath = this.mlCachePath.resolve("config");
        this.encryptor = encryptor;
    }

    public Path getMlConfigPath() {
        return this.mlConfigPath;
    }

    public Path getMlCachePath() {
        return this.mlCachePath;
    }

    public String getPrebuiltModelMetaListPath() {
        return "https://artifacts.opensearch.org/models/ml-models/model_listing/pre_trained_models.json";
    }

    public String getPrebuiltModelConfigPath(String modelName, String version, MLModelFormat modelFormat) {
        String format = modelFormat.name().toLowerCase(Locale.ROOT);
        return String.format(Locale.ROOT, "%s/%s/%s/%s/config.json", "https://artifacts.opensearch.org/models/ml-models", modelName, version, format);
    }

    public String getPrebuiltModelPath(String modelName, String version, MLModelFormat modelFormat) {
        int index = modelName.indexOf("/") + 1;
        String format = modelFormat.name().toLowerCase(Locale.ROOT);
        String modelZipFileName = modelName.substring(index).replace("/", "_") + "-" + version + "-" + format;
        return String.format(Locale.ROOT, "%s/%s/%s/%s/%s.zip", "https://artifacts.opensearch.org/models/ml-models", modelName, version, format, modelZipFileName);
    }

    public Path getRegisterModelPath(String modelId, String modelName, String version) {
        return this.getRegisterModelPath(modelId).resolve(version).resolve(modelName);
    }

    public Path getRegisterModelPath(String modelId) {
        return this.getRegisterModelRootPath().resolve(modelId);
    }

    public Path getRegisterModelRootPath() {
        return this.mlModelsCachePath.resolve(REGISTER_MODEL_FOLDER);
    }

    public Path getDeployModelPath(String modelId) {
        return this.getDeployModelRootPath().resolve(modelId);
    }

    public Path getDeployModelZipPath(String modelId, String modelName) {
        return this.mlModelsCachePath.resolve(DEPLOY_MODEL_FOLDER).resolve(modelId).resolve(modelName).resolveSibling(modelName + ".zip");
    }

    public Path getDeployModelRootPath() {
        return this.mlModelsCachePath.resolve(DEPLOY_MODEL_FOLDER);
    }

    public Path getDeployModelChunkPath(String modelId, Integer chunkNumber) {
        return this.mlModelsCachePath.resolve(DEPLOY_MODEL_FOLDER).resolve(modelId).resolve("chunks").resolve("" + chunkNumber);
    }

    public Path getModelCachePath(String modelId, String modelName, String version) {
        return this.getModelCachePath(modelId).resolve(version).resolve(modelName);
    }

    public Path getModelCachePath(String modelId) {
        return this.getModelCacheRootPath().resolve(modelId);
    }

    public Path getModelCacheRootPath() {
        return this.mlModelsCachePath.resolve("models");
    }

    public MLModel train(Input input) {
        this.validateMLInput(input);
        MLInput mlInput = (MLInput)input;
        Trainable trainable = (Trainable)MLEngineClassLoader.initInstance(mlInput.getAlgorithm(), mlInput.getParameters(), MLAlgoParams.class);
        if (trainable == null) {
            throw new IllegalArgumentException("Unsupported algorithm: " + String.valueOf((Object)mlInput.getAlgorithm()));
        }
        return trainable.train(mlInput);
    }

    public Map<String, String> getConnectorCredential(Connector connector) {
        connector.decrypt(ConnectorAction.ActionType.PREDICT.name(), (credential, tenantId) -> this.encryptor.decrypt((String)credential, connector.getTenantId()), connector.getTenantId());
        Map<String, String> decryptedCredential = connector.getDecryptedCredential();
        String region = connector.getParameters().get("region");
        if (region != null) {
            decryptedCredential.putIfAbsent("region", region);
        }
        return decryptedCredential;
    }

    public Predictable deploy(MLModel mlModel, Map<String, Object> params) {
        Predictable predictable = (Predictable)MLEngineClassLoader.initInstance(mlModel.getAlgorithm(), null, MLAlgoParams.class);
        predictable.initModel(mlModel, params, this.encryptor);
        return predictable;
    }

    public MLExecutable deployExecute(MLModel mlModel, Map<String, Object> params) {
        MLExecutable executable = (MLExecutable)MLEngineClassLoader.initInstance(mlModel.getAlgorithm(), null, MLAlgoParams.class);
        executable.initModel(mlModel, params);
        return executable;
    }

    public MLOutput predict(Input input, MLModel model) {
        this.validateMLInput(input);
        MLInput mlInput = (MLInput)input;
        Predictable predictable = (Predictable)MLEngineClassLoader.initInstance(mlInput.getAlgorithm(), mlInput.getParameters(), MLAlgoParams.class);
        if (predictable == null) {
            throw new IllegalArgumentException("Unsupported algorithm: " + String.valueOf((Object)mlInput.getAlgorithm()));
        }
        return predictable.predict(mlInput, model);
    }

    public MLOutput trainAndPredict(Input input) {
        this.validateMLInput(input);
        MLInput mlInput = (MLInput)input;
        TrainAndPredictable trainAndPredictable = (TrainAndPredictable)MLEngineClassLoader.initInstance(mlInput.getAlgorithm(), mlInput.getParameters(), MLAlgoParams.class);
        if (trainAndPredictable == null) {
            throw new IllegalArgumentException("Unsupported algorithm: " + String.valueOf((Object)mlInput.getAlgorithm()));
        }
        return trainAndPredictable.trainAndPredict(mlInput);
    }

    public void execute(Input input, ActionListener<Output> listener) throws Exception {
        this.validateInput(input);
        if (input.getFunctionName() == FunctionName.METRICS_CORRELATION) {
            MLExecutable executable = (MLExecutable)MLEngineClassLoader.initInstance(input.getFunctionName(), input, Input.class);
            if (executable == null) {
                throw new IllegalArgumentException("Unsupported executable function: " + String.valueOf((Object)input.getFunctionName()));
            }
            executable.execute(input, listener);
        } else {
            Executable executable = (Executable)MLEngineClassLoader.initInstance(input.getFunctionName(), input, Input.class);
            if (executable == null) {
                throw new IllegalArgumentException("Unsupported executable function: " + String.valueOf((Object)input.getFunctionName()));
            }
            executable.execute(input, listener);
        }
    }

    private void validateMLInput(Input input) {
        DataFrame dataFrame;
        this.validateInput(input);
        if (!(input instanceof MLInput)) {
            throw new IllegalArgumentException("Input should be MLInput");
        }
        MLInput mlInput = (MLInput)input;
        MLInputDataset inputDataset = mlInput.getInputDataset();
        if (inputDataset == null) {
            throw new IllegalArgumentException("Input data set should not be null");
        }
        if (inputDataset instanceof DataFrameInputDataset && ((dataFrame = ((DataFrameInputDataset)inputDataset).getDataFrame()) == null || dataFrame.size() == 0)) {
            throw new IllegalArgumentException("Input data frame should not be null or empty");
        }
    }

    private void validateInput(Input input) {
        if (input == null) {
            throw new IllegalArgumentException("Input should not be null");
        }
        if (input.getFunctionName() == null) {
            throw new IllegalArgumentException("Function name should not be null");
        }
    }

    public String encrypt(String credential, String tenantId) {
        return this.encryptor.encrypt(credential, tenantId);
    }
}

