/*
 * Decompiled with CFR 0.152.
 */
package io.skylite.ml.common.algorithms;

import ai.djl.Application;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import io.skylite.common.SuppressForbidden;
import io.skylite.ml.common.FunctionName;
import io.skylite.ml.common.engine.Encryptor;
import io.skylite.ml.common.engine.MLEngine;
import io.skylite.ml.common.engine.ModelDownloader;
import io.skylite.ml.common.engine.Predictable;
import io.skylite.ml.common.exception.MLException;
import io.skylite.ml.common.input.MLInput;
import io.skylite.ml.common.model.MLModel;
import io.skylite.ml.common.model.MLModelConfig;
import io.skylite.ml.common.model.MLModelFormat;
import io.skylite.ml.common.output.MLOutput;
import io.skylite.ml.common.output.model.ModelResultFilter;
import io.skylite.ml.common.output.model.ModelTensorOutput;
import io.skylite.ml.common.output.model.ModelTensors;
import io.skylite.ml.common.utils.FileUtils;
import io.skylite.ml.common.utils.ZipUtils;
import java.io.IOException;
import java.nio.file.DirectoryStream;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Base class for deep-learning models executed through DJL.
 *
 * <p>{@link #initModel} unpacks the model ZIP into the engine's model cache directory, then loads
 * one {@link ZooModel} and one {@link Predictor} per device returned by the selected DJL engine.
 * It calls {@link #warmUp} on each predictor. {@link #getPredictor()} then hands the predictors
 * out round-robin.
 *
 * <p>Subclasses provide the translator or translator factory, optional criteria arguments, an
 * optional warm-up routine, and the actual inference in {@link #predict(String, MLInput)}.
 */
public abstract class DLModel
implements Predictable {
    private static final Logger log = LogManager.getLogger(DLModel.class);
    /** Key in the {@code params} map of {@link #initModel} holding the model ZIP {@link Path}. */
    public static final String MODEL_ZIP_FILE = "model_zip_file";
    /** Key in the {@code params} map of {@link #initModel} holding the {@link ModelDownloader}. */
    public static final String MODEL_HELPER = "model_helper";
    /** Key in the {@code params} map of {@link #initModel} holding the {@link MLEngine}. */
    public static final String ML_ENGINE = "ml_engine";
    protected ModelDownloader modelDownloader;
    protected MLEngine mlEngine;
    protected String modelId;
    /** Predictors created by {@link #doLoadModel}; {@code null} until deployed and after {@link #close()}. */
    protected Predictor<Input, Output>[] predictors;
    /** Zoo models backing {@link #predictors}; {@code null} until deployed and after {@link #close()}. */
    protected ZooModel[] models;
    /** Devices reported by the DJL engine at load time. */
    protected Device[] devices;
    /** Round-robin cursor used by {@link #getPredictor()}. */
    protected AtomicInteger nextDevice = new AtomicInteger(0);
    protected MLModelConfig modelConfig;

    /**
     * Not supported by DL models.
     *
     * <p>Predictions must go through {@link #predict(MLInput)} after {@link #initModel}.
     *
     * @throws IllegalArgumentException always
     */
    @Override
    public MLOutput predict(MLInput mlInput, MLModel model) {
        throw new IllegalArgumentException("model not deployed");
    }

    /**
     * Runs inference on the deployed model.
     *
     * <p>The call runs inside a privileged block. It temporarily switches the thread's context
     * class loader to this class's loader and restores the previous loader afterwards.
     *
     * @param mlInput the inference input
     * @return the result of {@link #predict(String, MLInput)}
     * @throws IllegalArgumentException if {@link #initModel} has not set a downloader and model id
     * @throws MLException if the model is not ready or inference fails for any reason
     */
    @Override
    public MLOutput predict(MLInput mlInput) {
        if (this.modelDownloader == null || this.modelId == null) {
            throw new IllegalArgumentException("model not deployed");
        }
        try {
            return AccessController.doPrivileged((PrivilegedExceptionAction<ModelTensorOutput>) () -> {
                Thread currentThread = Thread.currentThread();
                ClassLoader previousClassLoader = currentThread.getContextClassLoader();
                currentThread.setContextClassLoader(this.getClass().getClassLoader());
                try {
                    if (!this.isModelReady()) {
                        throw new MLException("model not deployed.");
                    }
                    return this.predict(this.modelId, mlInput);
                }
                finally {
                    // Do not leak this plugin's class loader onto the caller's (possibly pooled) thread.
                    currentThread.setContextClassLoader(previousClassLoader);
                }
            });
        }
        catch (Throwable e) {
            // doPrivileged wraps checked exceptions (e.g. TranslateException); report the real cause.
            Throwable cause = e instanceof PrivilegedActionException privileged ? privileged.getException() : e;
            Object algorithm = mlInput == null ? null : mlInput.getAlgorithm();
            String errorMsg = "Failed to inference " + algorithm + " model: " + this.modelId;
            log.error(errorMsg, cause);
            throw new MLException(errorMsg, cause);
        }
    }

    /**
     * Returns the next predictor in round-robin order across the loaded devices.
     *
     * @return a predictor from {@link #predictors}
     * @throws IllegalStateException if no predictors are loaded
     */
    protected Predictor<Input, Output> getPredictor() {
        Predictor<Input, Output>[] available = this.predictors;
        if (available == null || available.length == 0) {
            throw new IllegalStateException("model not deployed");
        }
        // Atomically advance the cursor modulo the pool size so it never grows unbounded.
        // floorMod guards against a cursor that was set out of range externally.
        int index = this.nextDevice.getAndUpdate(i -> (i + 1) % available.length);
        return available[Math.floorMod(index, available.length)];
    }

    /**
     * Performs model-specific inference.
     *
     * @param var1 the model id
     * @param var2 the inference input
     * @return the model output
     * @throws TranslateException if DJL translation fails
     */
    public abstract ModelTensorOutput predict(String var1, MLInput var2) throws TranslateException;

    /**
     * Validates deployment parameters and loads the model onto every available device.
     *
     * @param model model metadata; its format selects the DJL engine
     * @param params must contain {@link #MODEL_ZIP_FILE}, {@link #MODEL_HELPER} and {@link #ML_ENGINE}
     * @param encryptor unused by this base implementation
     * @throws IllegalArgumentException if the format is unsupported, a required value is missing,
     *     or the model's algorithm is not a DL model
     */
    @Override
    public void initModel(MLModel model, Map<String, Object> params, Encryptor encryptor) {
        String engine = switch (model.getModelFormat()) {
            case MLModelFormat.TORCH_SCRIPT -> "PyTorch";
            case MLModelFormat.ONNX -> "OnnxRuntime";
            default -> throw new IllegalArgumentException("unsupported engine");
        };
        Path modelZipPath = (Path)params.get(MODEL_ZIP_FILE);
        this.modelDownloader = (ModelDownloader)params.get(MODEL_HELPER);
        this.mlEngine = (MLEngine)params.get(ML_ENGINE);
        if (modelZipPath == null) {
            throw new IllegalArgumentException("model file is null");
        }
        if (this.modelDownloader == null) {
            throw new IllegalArgumentException("model helper is null");
        }
        if (this.mlEngine == null) {
            throw new IllegalArgumentException("ML engine is null");
        }
        this.modelId = model.getModelId();
        if (this.modelId == null) {
            throw new IllegalArgumentException("model id is null");
        }
        if (!FunctionName.isDLModel(model.getAlgorithm())) {
            throw new IllegalArgumentException("wrong function name");
        }
        this.loadModel(modelZipPath, this.modelId, model.getName(), model.getVersion(), model.getModelConfig(), engine);
    }

    /**
     * Releases the predictors and zoo models, then deletes the model's file cache.
     *
     * <p>Engine resources are released before their backing files are removed. The cache is
     * deleted even if closing a resource fails. This does nothing unless a downloader and model
     * id are set.
     */
    @Override
    public void close() {
        if (this.modelDownloader == null || this.modelId == null) {
            return;
        }
        Predictor<Input, Output>[] predictorsToClose = this.predictors;
        ZooModel[] modelsToClose = this.models;
        // Clear first so isModelReady() reports false while resources are being released.
        this.predictors = null;
        this.models = null;
        try {
            if (predictorsToClose != null) {
                this.closePredictors(predictorsToClose);
            }
        }
        finally {
            try {
                if (modelsToClose != null) {
                    this.closeModels(modelsToClose);
                }
            }
            finally {
                this.modelDownloader.deleteFileCache(this.mlEngine, this.modelId);
            }
        }
    }

    /** Returns {@code true} once predictors are loaded and a downloader and model id are set. */
    @Override
    public boolean isModelReady() {
        return this.predictors != null && this.modelDownloader != null && this.modelId != null;
    }

    /**
     * Returns the translator for the given engine, or {@code null} for none.
     *
     * <p>This is ignored when {@link #getTranslatorFactory} returns non-null.
     */
    public abstract Translator<Input, Output> getTranslator(String var1, MLModelConfig var2);

    /**
     * Returns the translator factory for the given engine, or {@code null} to fall back to
     * {@link #getTranslator}.
     */
    public abstract TranslatorFactory getTranslatorFactory(String var1, MLModelConfig var2);

    /**
     * Returns extra DJL criteria arguments, applied through {@code Criteria.Builder.optArgument}.
     *
     * @return the arguments, or {@code null} (the default) when there are none
     */
    public Map<String, Object> getArguments(MLModelConfig modelConfig) {
        return null;
    }

    /** Hook invoked once per freshly created predictor; does nothing by default. */
    public void warmUp(Predictor predictor, String modelId, MLModelConfig modelConfig) throws TranslateException {
    }

    /**
     * Loads one zoo model and predictor per engine device and publishes them to
     * {@link #predictors} and {@link #models}.
     *
     * <p>Each resource is appended to {@code predictorList} or {@code modelList} as soon as it is
     * created. If loading fails partway, the caller can therefore release whatever was created.
     * On success the lists are cleared.
     *
     * @throws ModelNotFoundException if DJL cannot locate the model
     * @throws MalformedModelException if the model files are invalid
     * @throws IOException on I/O failure while loading
     * @throws TranslateException if {@link #warmUp} fails
     */
    protected void doLoadModel(List<Predictor<Input, Output>> predictorList, List<ZooModel<Input, Output>> modelList, String engine, Path modelPath, MLModelConfig modelConfig) throws ModelNotFoundException, MalformedModelException, IOException, TranslateException {
        try {
            AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
                this.devices = Engine.getEngine(engine).getDevices();
                this.modelConfig = modelConfig;
                for (int i = 0; i < this.devices.length; ++i) {
                    log.debug("load model {} to device {}: {}", this.modelId, i, this.devices[i]);
                    Criteria.Builder<Input, Output> criteriaBuilder = Criteria.builder()
                        .setTypes(Input.class, Output.class)
                        .optApplication(Application.UNDEFINED)
                        .optEngine(engine)
                        .optDevice(this.devices[i])
                        .optModelPath(modelPath);
                    Translator<Input, Output> translator = this.getTranslator(engine, modelConfig);
                    TranslatorFactory translatorFactory = this.getTranslatorFactory(engine, modelConfig);
                    if (translatorFactory != null) {
                        criteriaBuilder.optTranslatorFactory(translatorFactory);
                    } else if (translator != null) {
                        criteriaBuilder.optTranslator(translator);
                    }
                    Map<String, Object> arguments = this.getArguments(modelConfig);
                    if (arguments != null) {
                        for (Map.Entry<String, Object> entry : arguments.entrySet()) {
                            criteriaBuilder.optArgument(entry.getKey(), entry.getValue());
                        }
                    }
                    Criteria<Input, Output> criteria = criteriaBuilder.build();
                    ZooModel<Input, Output> model = criteria.loadModel();
                    // Track the model before creating its predictor so it is released if newPredictor() fails.
                    modelList.add(model);
                    Predictor<Input, Output> predictor = model.newPredictor();
                    predictorList.add(predictor);
                    this.warmUp(predictor, this.modelId, modelConfig);
                }
                if (!predictorList.isEmpty()) {
                    this.predictors = predictorList.toArray(new Predictor[0]);
                    predictorList.clear();
                }
                if (!modelList.isEmpty()) {
                    this.models = modelList.toArray(new ZooModel[0]);
                    modelList.clear();
                }
                log.info("Model {} is successfully deployed on {} devices", this.modelId, this.devices.length);
                return null;
            });
        }
        catch (PrivilegedActionException e) {
            // Rethrow the declared checked exception that doPrivileged wrapped.
            Exception cause = e.getException();
            if (cause instanceof ModelNotFoundException notFound) {
                throw notFound;
            }
            if (cause instanceof MalformedModelException malformed) {
                throw malformed;
            }
            if (cause instanceof IOException io) {
                throw io;
            }
            if (cause instanceof TranslateException translate) {
                throw translate;
            }
            if (cause instanceof RuntimeException runtime) {
                throw runtime;
            }
            throw new RuntimeException("Failed to load model", cause);
        }
    }

    /**
     * Unzips the model archive into the model cache directory and normalizes the model file name.
     * Then it loads the model through {@link #doLoadModel}.
     *
     * <p>The single {@code .pt}/{@code .onnx} file in the archive is renamed to match the cache
     * directory's name. On any failure the model is closed, partially created resources are
     * released, and an {@link MLException} is thrown. The deploy path is always deleted, and the
     * thread's context class loader is always restored.
     *
     * @throws IllegalArgumentException if {@code engine} is not PyTorch or OnnxRuntime
     * @throws MLException if deployment fails
     */
    @SuppressForbidden(reason="Fix mutable system props")
    protected void loadModel(Path modelZipPath, String modelId, String modelName, String version, MLModelConfig modelConfig, String engine) {
        if (!"PyTorch".equals(engine) && !"OnnxRuntime".equals(engine)) {
            throw new IllegalArgumentException("unsupported engine");
        }
        // Resources created before doLoadModel publishes them, so they can be released on failure.
        List<Predictor<Input, Output>> predictorList = new ArrayList<>();
        List<ZooModel<Input, Output>> modelList = new ArrayList<>();
        try {
            AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
                ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader();
                try {
                    // NOTE(review): presumably consumed by DJL's native engine bootstrap; confirm which
                    // of these are still honored by the DJL version in use.
                    System.setProperty("PYTORCH_PRECXX11", "true");
                    System.setProperty("PYTORCH_VERSION", "1.13.1");
                    String cachePath = this.mlEngine.getMlCachePath().toAbsolutePath().toString();
                    System.setProperty("DJL_CACHE_DIR", cachePath);
                    System.setProperty("java.library.path", cachePath);
                    System.setProperty("ai.djl.pytorch.num_interop_threads", "1");
                    System.setProperty("ai.djl.pytorch.num_threads", "1");
                    Thread.currentThread().setContextClassLoader(Model.class.getClassLoader());
                    Path modelPath = this.mlEngine.getModelCachePath(modelId, modelName, version);
                    if (Files.exists(modelPath)) {
                        FileUtils.deleteDirectory(modelPath);
                    }
                    ZipUtils.unzip(modelZipPath, modelPath);
                    // Find the model file first and rename it only after the stream is closed.
                    // DirectoryStream iteration is weakly consistent, so a file moved during iteration
                    // could be visited again and be mistaken for a second model.
                    Path modelFile = null;
                    try (DirectoryStream<Path> stream = Files.newDirectoryStream(modelPath)) {
                        for (Path file : stream) {
                            String name = file.getFileName().toString();
                            if (!name.endsWith(".pt") && !name.endsWith(".onnx")) {
                                continue;
                            }
                            if (modelFile != null) {
                                throw new IllegalArgumentException("found multiple models in the ZIP file.");
                            }
                            modelFile = file;
                        }
                    }
                    if (modelFile != null) {
                        String name = modelFile.getFileName().toString();
                        int dotIndex = name.lastIndexOf('.');
                        String suffix = name.substring(dotIndex);
                        String targetModelFileName = modelPath.getFileName().toString();
                        if (!targetModelFileName.equals(name.substring(0, dotIndex))) {
                            Files.move(modelFile, modelPath.resolve(targetModelFileName + suffix), StandardCopyOption.REPLACE_EXISTING);
                        }
                    }
                    this.doLoadModel(predictorList, modelList, engine, modelPath, modelConfig);
                    return null;
                }
                catch (Throwable e) {
                    String errorMessage = "Failed to deploy model " + modelId;
                    log.error(errorMessage, e);
                    MLException failure = new MLException(errorMessage, e);
                    // Cleanup failures must not mask the deployment failure or skip later cleanup steps.
                    runSuppressing(this::close, failure);
                    if (!predictorList.isEmpty()) {
                        Predictor[] leftoverPredictors = predictorList.toArray(new Predictor[0]);
                        predictorList.clear();
                        runSuppressing(() -> this.closePredictors(leftoverPredictors), failure);
                    }
                    if (!modelList.isEmpty()) {
                        ZooModel[] leftoverModels = modelList.toArray(new ZooModel[0]);
                        modelList.clear();
                        runSuppressing(() -> this.closeModels(leftoverModels), failure);
                    }
                    throw failure;
                }
                finally {
                    FileUtils.deleteFileQuietly(this.mlEngine.getDeployModelPath(modelId));
                    Thread.currentThread().setContextClassLoader(contextClassLoader);
                }
            });
        }
        catch (PrivilegedActionException e) {
            String errorMsg = "Failed to deploy model " + modelId;
            log.error(errorMsg, e.getException());
            throw new MLException(errorMsg, e.getException());
        }
    }

    /**
     * Closes every predictor, continuing past failures.
     *
     * <p>The first failure is rethrown with the later ones attached as suppressed exceptions.
     */
    protected void closePredictors(Predictor[] predictors) {
        log.debug("will close {} predictor for model {}", predictors.length, this.modelId);
        closeAll(predictors);
    }

    /**
     * Closes every zoo model, continuing past failures.
     *
     * <p>The first failure is rethrown with the later ones attached as suppressed exceptions.
     */
    protected void closeModels(ZooModel[] models) {
        log.debug("will close {} zoo model for model {}", models.length, this.modelId);
        closeAll(models);
    }

    /**
     * Deserializes a DJL {@link Output} into {@link ModelTensors}, optionally filtering the result.
     *
     * @param output the raw model output
     * @param resultFilter optional filter; {@code null} keeps everything
     * @return the parsed tensors
     * @throws MLException if {@code output} or its data is {@code null}
     */
    public ModelTensors parseModelTensorOutput(Output output, ModelResultFilter resultFilter) {
        if (output == null || output.getData() == null) {
            throw new MLException("No output generated");
        }
        byte[] bytes = output.getData().getAsBytes();
        ModelTensors tensorOutput = ModelTensors.fromBytes(bytes);
        if (resultFilter != null) {
            tensorOutput.filter(resultFilter);
        }
        return tensorOutput;
    }

    /**
     * Closes each non-null resource.
     *
     * <p>Every resource is attempted. The first failure is thrown after all closes are attempted,
     * with later failures attached as suppressed. A checked exception is wrapped in a
     * {@link RuntimeException}.
     */
    private static void closeAll(AutoCloseable[] resources) {
        RuntimeException failure = null;
        for (AutoCloseable resource : resources) {
            if (resource == null) {
                continue;
            }
            try {
                resource.close();
            }
            catch (Exception e) {
                RuntimeException unchecked = e instanceof RuntimeException runtime ? runtime : new RuntimeException(e);
                if (failure == null) {
                    failure = unchecked;
                } else {
                    failure.addSuppressed(unchecked);
                }
            }
        }
        if (failure != null) {
            throw failure;
        }
    }

    /** Runs {@code action}; if it throws a runtime exception, records it as suppressed on {@code primary}. */
    private static void runSuppressing(Runnable action, Throwable primary) {
        try {
            action.run();
        }
        catch (RuntimeException e) {
            primary.addSuppressed(e);
        }
    }
}

