/*
 * Decompiled with CFR 0.152.
 */
package io.skylite.ml.common.model;

import io.skylite.common.TokenBucket;
import io.skylite.common.util.FastMath;
import io.skylite.ml.common.FunctionName;
import io.skylite.ml.common.engine.MLExecutable;
import io.skylite.ml.common.engine.Predictable;
import io.skylite.ml.common.model.MLGuard;
import io.skylite.ml.common.model.MLModel;
import io.skylite.ml.common.model.MLModelState;
import io.skylite.ml.common.stats.MLPredictRequestStats;
import java.time.Instant;
import java.util.ArrayDeque;
import java.util.DoubleSummaryStatistics;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.stream.DoubleStream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Per-model cache entry holding the loaded predictor/executor, rate limiters, guard and
 * interface metadata, worker-node bookkeeping, and bounded queues of recent latency samples.
 *
 * <p>The worker-node sets and the duration queues are backed by concurrent collections; all
 * other fields are plain (non-volatile, unsynchronized) fields.
 */
public class MLModelCache {
    private static final Logger log = LogManager.getLogger(MLModelCache.class);
    private MLModelState modelState;
    private FunctionName functionName;
    private Predictable predictor;
    private MLExecutable executor;
    private TokenBucket rateLimiter;
    private Map<String, TokenBucket> userRateLimiterMap;
    private Boolean isModelEnabled;
    private final Set<String> targetWorkerNodes = ConcurrentHashMap.newKeySet();
    private final Set<String> workerNodes = ConcurrentHashMap.newKeySet();
    private MLModel modelInfo;
    private final Queue<Double> modelInferenceDurationQueue = new ConcurrentLinkedQueue<>();
    private final Queue<Double> predictRequestDurationQueue = new ConcurrentLinkedQueue<>();
    private Long memSizeEstimationCPU;
    private Long memSizeEstimationGPU;
    private MLGuard mlGuard;
    private Map<String, String> modelInterface;
    private Boolean deployToAllNodes;
    private Instant lastAccessTime;
    private Boolean isAutoDeploying;

    /**
     * Replaces the set of target worker node ids.
     *
     * @param targetWorkerNodes node ids; must be non-null and non-empty
     * @throws IllegalArgumentException if {@code targetWorkerNodes} is null or empty
     */
    public void setTargetWorkerNodes(List<String> targetWorkerNodes) {
        if (targetWorkerNodes == null || targetWorkerNodes.isEmpty()) {
            throw new IllegalArgumentException("Null or empty target worker nodes");
        }
        this.targetWorkerNodes.clear();
        this.targetWorkerNodes.addAll(targetWorkerNodes);
    }

    /** @return a copy of the target worker node ids as an array */
    public String[] getTargetWorkerNodes() {
        return this.targetWorkerNodes.toArray(new String[0]);
    }

    /**
     * Removes one node from the worker-node set.
     *
     * <p>The node is also removed from the target set when the model is deployed to all nodes or
     * when the removal comes from an undeploy; an undeploy additionally turns off deploy-to-all.
     * The cached model info is discarded once either node set is empty.
     *
     * @param nodeId id of the node to remove
     * @param isFromUndeploy whether the removal is caused by an undeploy
     */
    public void removeWorkerNode(String nodeId, boolean isFromUndeploy) {
        if (this.isDeployToAllNodes() || isFromUndeploy) {
            this.targetWorkerNodes.remove(nodeId);
        }
        if (isFromUndeploy) {
            this.deployToAllNodes = false;
        }
        this.workerNodes.remove(nodeId);
        if (this.targetWorkerNodes.isEmpty() || this.workerNodes.isEmpty()) {
            this.modelInfo = null;
        }
    }

    /**
     * Bulk variant of {@link #removeWorkerNode(String, boolean)} with the same target-set,
     * deploy-to-all and model-info rules.
     *
     * @param removedNodes ids of the nodes to remove
     * @param isFromUndeploy whether the removal is caused by an undeploy
     */
    public void removeWorkerNodes(Set<String> removedNodes, boolean isFromUndeploy) {
        if (this.isDeployToAllNodes() || isFromUndeploy) {
            this.targetWorkerNodes.removeAll(removedNodes);
        }
        if (isFromUndeploy) {
            this.deployToAllNodes = false;
        }
        this.workerNodes.removeAll(removedNodes);
        if (this.targetWorkerNodes.isEmpty() || this.workerNodes.isEmpty()) {
            this.modelInfo = null;
        }
    }

    /**
     * Adds a node to the worker-node set; when deploying to all nodes it is also added to the
     * target set.
     *
     * @param nodeId id of the node to add
     */
    public void addWorkerNode(String nodeId) {
        if (this.isDeployToAllNodes()) {
            this.targetWorkerNodes.add(nodeId);
        }
        this.workerNodes.add(nodeId);
    }

    /** @return a copy of the current worker node ids as an array */
    public String[] getWorkerNodes() {
        return this.workerNodes.toArray(new String[0]);
    }

    public void setModelInfo(MLModel modelInfo) {
        this.modelInfo = modelInfo;
    }

    /**
     * @return the cached model metadata, or {@code null} if never set or discarded by a node
     *     removal or {@link #clear()}
     */
    public MLModel getCachedModelInfo() {
        return this.modelInfo;
    }

    /**
     * Replaces the worker-node set with the given ids.
     *
     * @param workerNodes the new set of worker node ids
     */
    public void syncWorkerNode(Set<String> workerNodes) {
        this.workerNodes.clear();
        this.workerNodes.addAll(workerNodes);
    }

    /** @return {@code true} only when the deploy-to-all flag is explicitly {@code TRUE} */
    public boolean isDeployToAllNodes() {
        return Boolean.TRUE.equals(this.deployToAllNodes);
    }

    /** Removes all ids from the worker-node set (the target set is left as is). */
    public void clearWorkerNodes() {
        this.workerNodes.clear();
    }

    /**
     * Releases the model's runtime resources and resets most per-model state.
     *
     * <p>Closes the predictor and executor if present and drops the references, so a closed
     * instance is never handed out again and a repeated call does not close it twice. Clears the
     * worker-node set, the latency queues and the model info, and zeroes the memory estimates.
     * Target worker nodes, the deploy-to-all flag, the last access time and the auto-deploy flag
     * are not touched.
     */
    public void clear() {
        this.modelState = null;
        this.functionName = null;
        this.workerNodes.clear();
        this.modelInfo = null;
        this.modelInferenceDurationQueue.clear();
        this.predictRequestDurationQueue.clear();
        if (this.predictor != null) {
            this.predictor.close();
            this.predictor = null;
        }
        this.memSizeEstimationCPU = 0L;
        this.memSizeEstimationGPU = 0L;
        if (this.executor != null) {
            this.executor.close();
            this.executor = null;
        }
        this.isModelEnabled = null;
        this.rateLimiter = null;
        this.userRateLimiterMap = null;
        this.mlGuard = null;
        this.modelInterface = null;
    }

    /**
     * Records a model-inference duration, keeping at most {@code maxRequestCount} samples.
     *
     * @param duration the measured duration
     * @param maxRequestCount capacity of the sample window; {@code <= 0} clears the window
     */
    public void addModelInferenceDuration(double duration, long maxRequestCount) {
        this.addInferenceDuration(duration, maxRequestCount, this.modelInferenceDurationQueue);
    }

    /**
     * Records a predict-request duration, keeping at most {@code maxRequestCount} samples.
     *
     * @param duration the measured duration
     * @param maxRequestCount capacity of the sample window; {@code <= 0} clears the window
     */
    public void addPredictRequestDuration(double duration, long maxRequestCount) {
        this.addInferenceDuration(duration, maxRequestCount, this.predictRequestDurationQueue);
    }

    // Appends one sample, evicting the oldest ones first so the queue ends at <= maxRequestCount.
    private void addInferenceDuration(double duration, long maxRequestCount, Queue<Double> queue) {
        if (maxRequestCount <= 0L) {
            queue.clear();
            return;
        }
        // Leave room for the sample about to be added.
        trimQueue(queue, maxRequestCount - 1L);
        queue.add(duration);
    }

    /**
     * Shrinks both latency windows to at most {@code maxRequestCount} samples, evicting the
     * oldest first; {@code <= 0} clears them.
     *
     * @param maxRequestCount new capacity of the sample windows
     */
    public void resizeMonitoringQueue(long maxRequestCount) {
        log.debug("resize inference duration monitoring queue with size {}", maxRequestCount);
        this.resizeInferenceQueue(maxRequestCount, this.predictRequestDurationQueue);
        this.resizeInferenceQueue(maxRequestCount, this.modelInferenceDurationQueue);
    }

    // Keeps at most maxRequestCount samples (previously trimmed to maxRequestCount - 1).
    private void resizeInferenceQueue(long maxRequestCount, Queue<Double> queue) {
        if (maxRequestCount <= 0L) {
            queue.clear();
        } else {
            trimQueue(queue, maxRequestCount);
        }
    }

    // Polls from the head until the queue holds at most `capacity` elements.
    private static void trimQueue(Queue<Double> queue, long capacity) {
        while ((long) queue.size() > capacity) {
            if (queue.poll() == null) {
                // Drained concurrently; nothing left to evict.
                break;
            }
        }
    }

    /**
     * Computes count, min, max, average and p50/p90/p99 over the recorded durations.
     *
     * @param modelInference {@code true} for model-inference durations, {@code false} for
     *     predict-request durations
     * @return the statistics, or {@code null} if no samples are recorded
     */
    public MLPredictRequestStats getInferenceStats(boolean modelInference) {
        Queue<Double> source = modelInference ? this.modelInferenceDurationQueue : this.predictRequestDurationQueue;
        // Copy once so every statistic describes the same samples even while other threads
        // add or evict durations, and so the emptiness check cannot go stale.
        Queue<Double> snapshot = new ArrayDeque<>(source);
        if (snapshot.isEmpty()) {
            return null;
        }
        DoubleSummaryStatistics summary = snapshot.stream().mapToDouble(Double::doubleValue).summaryStatistics();
        MLPredictRequestStats.Builder statsBuilder = MLPredictRequestStats.builder();
        statsBuilder.count(summary.getCount());
        statsBuilder.max(summary.getMax());
        statsBuilder.min(summary.getMin());
        statsBuilder.average(summary.getAverage());
        statsBuilder.p50(FastMath.computePercentile(snapshot, 50.0));
        statsBuilder.p90(FastMath.computePercentile(snapshot, 90.0));
        statsBuilder.p99(FastMath.computePercentile(snapshot, 99.0));
        return statsBuilder.build();
    }

    /** @return {@code true} if a model state is set or at least one worker node is registered */
    public boolean isValidCache() {
        return this.modelState != null || !this.workerNodes.isEmpty();
    }

    // ---- Fluent accessors: each setter below returns this instance for chaining. ----

    public MLModelState getModelState() {
        return this.modelState;
    }

    public MLModelCache setModelState(MLModelState modelState) {
        this.modelState = modelState;
        return this;
    }

    public FunctionName getFunctionName() {
        return this.functionName;
    }

    public MLModelCache setFunctionName(FunctionName functionName) {
        this.functionName = functionName;
        return this;
    }

    public Predictable getPredictor() {
        return this.predictor;
    }

    public MLModelCache setPredictor(Predictable predictor) {
        this.predictor = predictor;
        return this;
    }

    public MLExecutable getExecutor() {
        return this.executor;
    }

    public MLModelCache setExecutor(MLExecutable executor) {
        this.executor = executor;
        return this;
    }

    public TokenBucket getRateLimiter() {
        return this.rateLimiter;
    }

    public MLModelCache setRateLimiter(TokenBucket rateLimiter) {
        this.rateLimiter = rateLimiter;
        return this;
    }

    public Map<String, TokenBucket> getUserRateLimiterMap() {
        return this.userRateLimiterMap;
    }

    public MLModelCache setUserRateLimiterMap(Map<String, TokenBucket> userRateLimiterMap) {
        this.userRateLimiterMap = userRateLimiterMap;
        return this;
    }

    public Boolean getIsModelEnabled() {
        return this.isModelEnabled;
    }

    public MLModelCache setIsModelEnabled(Boolean modelEnabled) {
        this.isModelEnabled = modelEnabled;
        return this;
    }

    public Long getMemSizeEstimationCPU() {
        return this.memSizeEstimationCPU;
    }

    public MLModelCache setMemSizeEstimationCPU(Long memSizeEstimationCPU) {
        this.memSizeEstimationCPU = memSizeEstimationCPU;
        return this;
    }

    public Long getMemSizeEstimationGPU() {
        return this.memSizeEstimationGPU;
    }

    public MLModelCache setMemSizeEstimationGPU(Long memSizeEstimationGPU) {
        this.memSizeEstimationGPU = memSizeEstimationGPU;
        return this;
    }

    public MLGuard getMlGuard() {
        return this.mlGuard;
    }

    public MLModelCache setMlGuard(MLGuard mlGuard) {
        this.mlGuard = mlGuard;
        return this;
    }

    public Map<String, String> getModelInterface() {
        return this.modelInterface;
    }

    public MLModelCache setModelInterface(Map<String, String> modelInterface) {
        this.modelInterface = modelInterface;
        return this;
    }

    public Instant getLastAccessTime() {
        return this.lastAccessTime;
    }

    public MLModelCache setLastAccessTime(Instant lastAccessTime) {
        this.lastAccessTime = lastAccessTime;
        return this;
    }

    public Boolean getIsAutoDeploying() {
        return this.isAutoDeploying;
    }

    public MLModelCache setIsAutoDeploying(Boolean autoDeploying) {
        this.isAutoDeploying = autoDeploying;
        return this;
    }

    /**
     * Sets the deploy-to-all flag; {@code null} is treated as {@code false} by
     * {@link #isDeployToAllNodes()}.
     *
     * @param deployToAllNodes the new flag value
     */
    public void setDeployToAllNodes(Boolean deployToAllNodes) {
        this.deployToAllNodes = deployToAllNodes;
    }
}

