/*
 * Decompiled with CFR 0.152.
 */
package io.skylite.ml.common.engine;

import ai.djl.training.util.DownloadUtils;
import ai.djl.training.util.ProgressBar;
import ai.djl.util.Progress;
import io.skylite.common.SuppressForbidden;
import io.skylite.common.action.ActionListener;
import io.skylite.common.io.PathUtils;
import io.skylite.core.common.Strings;
import io.skylite.ml.common.FunctionName;
import io.skylite.ml.common.engine.MLEngine;
import io.skylite.ml.common.model.MLDeploySetting;
import io.skylite.ml.common.model.MLModelFormat;
import io.skylite.ml.common.model.QuestionAnsweringModelConfig;
import io.skylite.ml.common.model.TextEmbeddingModelConfig;
import io.skylite.ml.common.transport.register.MLRegisterModelInput;
import io.skylite.ml.common.utils.FileUtils;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

public final class ModelDownloader {
    private static final Logger log = LogManager.getLogger(ModelDownloader.class);
    public static final String CHUNK_FILES = "chunk_files";
    public static final String MODEL_SIZE_IN_BYTES = "model_size_in_bytes";
    public static final String MODEL_FILE_HASH = "model_file_hash";
    public static final int CHUNK_SIZE = 10000000;
    public static final String PYTORCH_FILE_EXTENSION = ".pt";
    public static final String ONNX_FILE_EXTENSION = ".onnx";
    public static final String TOKENIZER_FILE_NAME = "tokenizer.json";
    public static final String PYTORCH_ENGINE = "PyTorch";
    public static final String ONNX_ENGINE = "OnnxRuntime";

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public void downloadPrebuiltModelConfig(MLEngine mlEngine, String taskId, MLRegisterModelInput registerModelInput, ActionListener<MLRegisterModelInput> listener) {
        String modelName = registerModelInput.getModelName();
        FunctionName algorithm = registerModelInput.getFunctionName();
        String version = registerModelInput.getVersion();
        MLModelFormat modelFormat = registerModelInput.getModelFormat();
        Boolean isHidden = registerModelInput.getIsHidden();
        boolean deployModel = registerModelInput.isDeployModel();
        String[] modelNodeIds = registerModelInput.getModelNodeIds();
        String modelGroupId = registerModelInput.getModelGroupId();
        MLDeploySetting mlDeploySetting = registerModelInput.getDeploySetting();
        try {
            AccessController.doPrivileged(() -> {
                Path registerModelPath = mlEngine.getRegisterModelPath(taskId, modelName, version);
                String configCacheFilePath = registerModelPath.resolve("config.json").toString();
                String configFileUrl = mlEngine.getPrebuiltModelConfigPath(modelName, version, modelFormat);
                String modelZipFileUrl = mlEngine.getPrebuiltModelPath(modelName, version, modelFormat);
                DownloadUtils.download((String)configFileUrl, (String)configCacheFilePath, (Progress)new ProgressBar());
                Map config = Strings.readLargeJsonFile((Path)Path.of(configCacheFilePath, new String[0]));
                if (config == null) {
                    listener.onFailure((Exception)new IllegalArgumentException("model config not found"));
                    return null;
                }
                MLRegisterModelInput.MLRegisterModelInputBuilder builder = MLRegisterModelInput.builder();
                String functionName = config.containsKey("function_name") ? (String)config.get("function_name") : (String)config.get("model_task_type");
                builder.modelName(modelName).version(version).url(modelZipFileUrl).deployModel(deployModel).modelNodeIds(modelNodeIds).isHidden(isHidden).modelGroupId(modelGroupId).functionName(FunctionName.from(functionName)).deploySetting(mlDeploySetting);
                config.entrySet().forEach(entry -> {
                    switch (entry.getKey().toString()) {
                        case "model_format": {
                            builder.modelFormat(MLModelFormat.from(entry.getValue().toString()));
                            break;
                        }
                        case "model_config": {
                            if (FunctionName.QUESTION_ANSWERING.equals((Object)algorithm)) {
                                QuestionAnsweringModelConfig.QuestionAnsweringModelConfigBuilder configBuilder = QuestionAnsweringModelConfig.builder();
                                Map configMap = (Map)entry.getValue();
                                for (Map.Entry configEntry : configMap.entrySet()) {
                                    switch (configEntry.getKey().toString()) {
                                        case "model_type": {
                                            configBuilder.modelType(configEntry.getValue().toString());
                                            break;
                                        }
                                        case "all_config": {
                                            configBuilder.allConfig(configEntry.getValue().toString());
                                            break;
                                        }
                                        case "framework_type": {
                                            configBuilder.frameworkType(QuestionAnsweringModelConfig.FrameworkType.from(configEntry.getValue().toString()));
                                            break;
                                        }
                                    }
                                }
                                builder.modelConfig(configBuilder.build());
                                break;
                            }
                            TextEmbeddingModelConfig.Builder configBuilder = TextEmbeddingModelConfig.builder();
                            Map configMap = (Map)entry.getValue();
                            for (Map.Entry configEntry : configMap.entrySet()) {
                                switch (configEntry.getKey().toString()) {
                                    case "model_type": {
                                        configBuilder.modelType(configEntry.getValue().toString());
                                        break;
                                    }
                                    case "all_config": {
                                        configBuilder.allConfig(configEntry.getValue().toString());
                                        break;
                                    }
                                    case "embedding_dimension": {
                                        configBuilder.embeddingDimension(((Number)configEntry.getValue()).intValue());
                                        break;
                                    }
                                    case "framework_type": {
                                        configBuilder.frameworkType(TextEmbeddingModelConfig.FrameworkType.from(configEntry.getValue().toString()));
                                        break;
                                    }
                                    case "pooling_mode": {
                                        configBuilder.poolingMode(TextEmbeddingModelConfig.PoolingMode.from(configEntry.getValue().toString().toUpperCase(Locale.ROOT)));
                                        break;
                                    }
                                    case "normalize_result": {
                                        configBuilder.normalizeResult(Boolean.parseBoolean(configEntry.getValue().toString()));
                                        break;
                                    }
                                    case "model_max_length": {
                                        configBuilder.modelMaxLength(((Number)configEntry.getValue()).intValue());
                                        break;
                                    }
                                    case "query_prefix": {
                                        configBuilder.queryPrefix(configEntry.getValue().toString());
                                        break;
                                    }
                                    case "passage_prefix": {
                                        configBuilder.passagePrefix(configEntry.getValue().toString());
                                        break;
                                    }
                                }
                            }
                            builder.modelConfig(configBuilder.build());
                            break;
                        }
                        case "model_content_hash_value": {
                            builder.hashValue(entry.getValue().toString());
                            break;
                        }
                    }
                });
                listener.onResponse((Object)builder.build());
                return null;
            });
        }
        catch (Exception e) {
            listener.onFailure(e);
        }
        finally {
            FileUtils.deleteFileQuietly(mlEngine.getRegisterModelPath(taskId));
        }
    }

    public boolean isModelAllowed(MLRegisterModelInput registerModelInput, List modelMetaList) {
        String modelName = registerModelInput.getModelName();
        String version = registerModelInput.getVersion();
        MLModelFormat modelFormat = registerModelInput.getModelFormat();
        for (Object meta : modelMetaList) {
            String name = (String)((Map)meta).get("name");
            List versions = (List)((Map)meta).get("version");
            List formats = (List)((Map)meta).get("format");
            if (!name.equals(modelName) || !versions.contains(version.toLowerCase(Locale.ROOT)) || !formats.contains(modelFormat.toString().toLowerCase(Locale.ROOT))) continue;
            return true;
        }
        return false;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public List downloadPrebuiltModelMetaList(MLEngine mlEngine, String taskId, MLRegisterModelInput registerModelInput) throws PrivilegedActionException {
        String modelName = registerModelInput.getModelName();
        String version = registerModelInput.getVersion();
        try {
            List list = AccessController.doPrivileged(() -> {
                Path registerModelPath = mlEngine.getRegisterModelPath(taskId, modelName, version);
                String cacheFilePath = registerModelPath.resolve("model_meta_list.json").toString();
                String modelMetaListUrl = mlEngine.getPrebuiltModelMetaListPath();
                DownloadUtils.download((String)modelMetaListUrl, (String)cacheFilePath, (Progress)new ProgressBar());
                List config = Strings.readLargeJsonListFile((Path)Path.of(cacheFilePath, new String[0]));
                return config;
            });
            return list;
        }
        finally {
            FileUtils.deleteFileQuietly(mlEngine.getRegisterModelPath(taskId));
        }
    }

    @SuppressForbidden(reason="PathUtils needed to avoid dependency injection")
    public void downloadAndSplit(MLEngine mlEngine, MLModelFormat modelFormat, String taskId, String modelName, String version, String url, String modelContentHash, FunctionName functionName, ActionListener<Map<String, Object>> listener) {
        try {
            AccessController.doPrivileged(() -> {
                Path registerModelPath = mlEngine.getRegisterModelPath(taskId, modelName, version);
                Path modelZipPath = PathUtils.get((String)(String.valueOf(registerModelPath) + ".zip"), (String[])new String[0]);
                Path modelPartsPath = registerModelPath.resolve("chunks");
                log.info("download model to file {}", (Object)modelZipPath.toAbsolutePath());
                DownloadUtils.download((String)url, (String)modelZipPath.toString(), (Progress)new ProgressBar());
                this.verifyModelZipFile(modelFormat, modelZipPath.toString(), modelName, functionName);
                String hash = FileUtils.calculateFileHash(modelZipPath);
                if (modelContentHash == null) {
                    log.error("Hash code need to be provided when register via url.");
                    throw new IllegalArgumentException("Model content Hash code need to be provided when register via url. Please calculate sha 256 Hash code.");
                }
                if (hash.equals(modelContentHash)) {
                    List<String> chunkFiles = FileUtils.splitFileIntoChunks(modelZipPath, modelPartsPath, 10000000);
                    HashMap<String, Object> result = new HashMap<String, Object>();
                    result.put(CHUNK_FILES, chunkFiles);
                    result.put(MODEL_SIZE_IN_BYTES, Files.size(modelZipPath));
                    result.put(MODEL_FILE_HASH, FileUtils.calculateFileHash(modelZipPath));
                    FileUtils.deleteFileQuietly(modelZipPath);
                    listener.onResponse(result);
                    return null;
                }
                log.error("Model content hash can't match original hash value when registering");
                throw new IllegalArgumentException("model content changed");
            });
        }
        catch (Exception e) {
            listener.onFailure(e);
        }
    }

    @SuppressForbidden(reason="pure java zip")
    public void verifyModelZipFile(MLModelFormat modelFormat, String modelZipFilePath, String modelName, FunctionName functionName) throws IOException {
        boolean hasPtFile = false;
        boolean hasOnnxFile = false;
        boolean hasTokenizerFile = false;
        try (ZipFile zipFile = new ZipFile(modelZipFilePath);){
            Enumeration<? extends ZipEntry> zipEntries = zipFile.entries();
            while (zipEntries.hasMoreElements()) {
                String fileName = zipEntries.nextElement().getName();
                hasPtFile = ModelDownloader.hasModelFile(modelFormat, MLModelFormat.TORCH_SCRIPT, PYTORCH_FILE_EXTENSION, hasPtFile, fileName);
                hasOnnxFile = ModelDownloader.hasModelFile(modelFormat, MLModelFormat.ONNX, ONNX_FILE_EXTENSION, hasOnnxFile, fileName);
                if (!fileName.equals(TOKENIZER_FILE_NAME)) continue;
                hasTokenizerFile = true;
            }
        }
        if (!hasPtFile && !hasOnnxFile && functionName != FunctionName.SPARSE_TOKENIZE) {
            throw new IllegalArgumentException("Can't find model file");
        }
        if (!hasTokenizerFile && modelName != FunctionName.METRICS_CORRELATION.toString()) {
            throw new IllegalArgumentException("No tokenizer file");
        }
    }

    private static boolean hasModelFile(MLModelFormat modelFormat, MLModelFormat targetModelFormat, String fileExtension, boolean hasModelFile, String fileName) {
        if (fileName.endsWith(fileExtension)) {
            if (modelFormat != targetModelFormat) {
                throw new IllegalArgumentException("Model format is " + String.valueOf((Object)modelFormat) + ", but find " + fileExtension + " file");
            }
            if (hasModelFile) {
                throw new IllegalArgumentException("Find multiple model files, but expected only one");
            }
            return true;
        }
        return hasModelFile;
    }

    public void deleteFileCache(MLEngine mlEngine, String modelId) {
        FileUtils.deleteFileQuietly(mlEngine.getModelCachePath(modelId));
        FileUtils.deleteFileQuietly(mlEngine.getDeployModelPath(modelId));
        FileUtils.deleteFileQuietly(mlEngine.getRegisterModelPath(modelId));
    }
}

