/*
 * Decompiled with CFR 0.152.
 */
package io.lucenia.ml.common.engine.algorithms.rcf;

import com.amazon.randomcutforest.RandomCutForest;
import com.amazon.randomcutforest.state.RandomCutForestMapper;
import com.amazon.randomcutforest.state.RandomCutForestState;
import io.lucenia.ml.common.engine.algorithms.rcf.RCFModelSerDeSer;
import io.skylite.ml.common.FunctionName;
import io.skylite.ml.common.algorithms.util.ModelSerDeSer;
import io.skylite.ml.common.annotation.Function;
import io.skylite.ml.common.dataframe.ColumnMeta;
import io.skylite.ml.common.dataframe.ColumnValue;
import io.skylite.ml.common.dataframe.DataFrame;
import io.skylite.ml.common.dataframe.DataFrameBuilder;
import io.skylite.ml.common.dataframe.Row;
import io.skylite.ml.common.dataset.DataFrameInputDataset;
import io.skylite.ml.common.engine.Encryptor;
import io.skylite.ml.common.engine.TrainAndPredictable;
import io.skylite.ml.common.input.MLInput;
import io.skylite.ml.common.input.parameter.MLAlgoParams;
import io.skylite.ml.common.input.parameter.rcf.BatchRCFParams;
import io.skylite.ml.common.model.MLModel;
import io.skylite.ml.common.model.MLModelState;
import io.skylite.ml.common.output.MLOutput;
import io.skylite.ml.common.output.MLPredictionOutput;
import java.lang.constant.Constable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Batch anomaly detection backed by a {@link RandomCutForest}.
 *
 * <p>Each row of the input {@link DataFrame} is turned into a point with one dimension per
 * column. Every point is scored against the forest before it is (optionally) inserted into it.
 * Each output row carries the raw {@code "score"} plus an {@code "anomalous"} flag that is
 * {@code true} when the score is strictly greater than the configured threshold.
 *
 * <p>NOTE(review): {@link #initModel} and {@link #predict(MLInput, MLModel)} replace the forest
 * held by this instance without synchronization, so sharing one instance across concurrent
 * callers presumably needs external coordination — confirm against how the engine uses it.
 */
@Function(value=FunctionName.BATCH_RCF)
public class BatchRandomCutForest implements TrainAndPredictable {
    private static final Logger log = LogManager.getLogger(BatchRandomCutForest.class);
    public static final String VERSION = "1.0.0";
    private static final int DEFAULT_NUMBER_OF_TREES = 30;
    private static final int DEFAULT_OUTPUT_AFTER = 32;
    private static final int DEFAULT_SAMPLES_SIZE = 256;
    private static final double DEFAULT_ANOMALY_SCORE_THRESHOLD = 1.0;

    /**
     * Shared (de)serializer. It is configured exactly once here; configuring it from an
     * instance constructor would mutate shared static state on every construction.
     */
    private static final RandomCutForestMapper rcfMapper = createMapper();

    private final int numberOfTrees;
    private final int sampleSize;
    private final int outputAfter;
    private final double anomalyScoreThreshold;
    /** Number of leading rows inserted into the forest during training; {@code null} means all rows. */
    private final Integer trainingDataSize;
    /** Forest used by {@link #predict(MLInput)}; {@code null} until a model is loaded. */
    private RandomCutForest forest;

    /** Creates an instance that uses the default hyper-parameters. */
    public BatchRandomCutForest() {
        this((MLAlgoParams) null);
    }

    /**
     * Creates an instance configured from {@code parameters}.
     *
     * @param parameters expected to be a {@link BatchRCFParams}, or {@code null} for defaults;
     *     any individual {@code null} setting falls back to its default
     * @throws ClassCastException if {@code parameters} is not a {@link BatchRCFParams}
     */
    public BatchRandomCutForest(MLAlgoParams parameters) {
        Optional<BatchRCFParams> rcfParams = Optional.ofNullable((BatchRCFParams) parameters);
        this.numberOfTrees = rcfParams.map(BatchRCFParams::getNumberOfTrees).orElse(DEFAULT_NUMBER_OF_TREES);
        this.sampleSize = rcfParams.map(BatchRCFParams::getSampleSize).orElse(DEFAULT_SAMPLES_SIZE);
        this.outputAfter = rcfParams.map(BatchRCFParams::getOutputAfter).orElse(DEFAULT_OUTPUT_AFTER);
        this.anomalyScoreThreshold = rcfParams.map(BatchRCFParams::getAnomalyScoreThreshold)
                .orElse(DEFAULT_ANOMALY_SCORE_THRESHOLD);
        this.trainingDataSize = rcfParams.map(BatchRCFParams::getTrainingDataSize).orElse(null);
    }

    /** Builds the shared mapper with executor-context saving enabled. */
    private static RandomCutForestMapper createMapper() {
        RandomCutForestMapper mapper = new RandomCutForestMapper();
        mapper.setSaveExecutorContextEnabled(true);
        return mapper;
    }

    /**
     * Loads the forest stored in {@code model} into this instance.
     *
     * @param model serialized batch RCF model
     * @param params unused by this algorithm
     * @param encryptor unused by this algorithm
     */
    @Override
    public void initModel(MLModel model, Map<String, Object> params, Encryptor encryptor) {
        this.forest = loadForest(model);
    }

    /** Releases the loaded forest; {@link #isModelReady()} returns {@code false} afterwards. */
    @Override
    public void close() {
        this.forest = null;
    }

    /** Returns {@code true} when a forest has been loaded into this instance. */
    @Override
    public boolean isModelReady() {
        return this.forest != null;
    }

    /**
     * Scores every row of the input with the loaded forest, without updating the forest.
     *
     * @param mlInput input whose dataset is a {@link DataFrameInputDataset}
     * @return one prediction row per input row with {@code "score"} and {@code "anomalous"}
     * @throws IllegalStateException if no forest has been loaded
     */
    @Override
    public MLOutput predict(MLInput mlInput) {
        if (this.forest == null) {
            throw new IllegalStateException("Batch RCF model is not initialized; load a model before predicting.");
        }
        DataFrame dataFrame = extractDataFrame(mlInput);
        // A training size of 0 means no row is inserted: prediction must not mutate the model.
        List<Map<String, Object>> predictResult = this.process(dataFrame, this.forest, 0);
        return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(predictResult)).build();
    }

    /**
     * Loads {@code model} into this instance (replacing any existing forest) and then predicts.
     *
     * @throws IllegalArgumentException if {@code model} is {@code null}
     */
    @Override
    public MLOutput predict(MLInput mlInput, MLModel model) {
        if (model == null) {
            throw new IllegalArgumentException("No model found for batch RCF prediction.");
        }
        this.forest = loadForest(model);
        return this.predict(mlInput);
    }

    /**
     * Trains a fresh forest on the input and returns it as a serialized, base64-encoded model.
     * Only the first {@code trainingDataSize} rows (all rows when unset) are inserted.
     * The trained forest is not kept on this instance.
     */
    @Override
    public MLModel train(MLInput mlInput) {
        DataFrame dataFrame = extractDataFrame(mlInput);
        RandomCutForest trainedForest = this.createRandomCutForest(dataFrame);
        this.process(dataFrame, trainedForest, this.resolveTrainingDataSize(dataFrame));
        RandomCutForestState state = rcfMapper.toState(trainedForest);
        return MLModel.builder()
                .name(FunctionName.BATCH_RCF.name())
                .algorithm(FunctionName.BATCH_RCF)
                .version(VERSION)
                .content(ModelSerDeSer.encodeBase64((byte[]) RCFModelSerDeSer.serializeRCF(state)))
                .modelState(MLModelState.TRAINED)
                .build();
    }

    /**
     * Trains a fresh forest and returns the per-row scores produced in the same pass.
     * Each row is scored before it is (optionally) inserted into the forest.
     */
    @Override
    public MLOutput trainAndPredict(MLInput mlInput) {
        DataFrame dataFrame = extractDataFrame(mlInput);
        RandomCutForest trainedForest = this.createRandomCutForest(dataFrame);
        List<Map<String, Object>> predictResult =
                this.process(dataFrame, trainedForest, this.resolveTrainingDataSize(dataFrame));
        return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(predictResult)).build();
    }

    /** Returns the data frame wrapped by the input's {@link DataFrameInputDataset}. */
    private static DataFrame extractDataFrame(MLInput mlInput) {
        return ((DataFrameInputDataset) mlInput.getInputDataset()).getDataFrame();
    }

    /** Deserializes the forest stored in {@code model}. */
    private static RandomCutForest loadForest(MLModel model) {
        RandomCutForestState state = RCFModelSerDeSer.deserializeRCF(model);
        return rcfMapper.toModel(state);
    }

    /** Returns the configured training size, or the full row count when none was configured. */
    private int resolveTrainingDataSize(DataFrame dataFrame) {
        return this.trainingDataSize == null ? dataFrame.size() : this.trainingDataSize;
    }

    /**
     * Scores each row of {@code dataFrame} and, for the leading rows, inserts it into {@code forest}.
     *
     * @param dataFrame rows to score; every column is read as a double
     * @param forest forest to score against (and update)
     * @param actualTrainingDataSize rows with index below this value are inserted after scoring;
     *     {@code null} inserts every row
     * @return one map per row with keys {@code "score"} (Double) and {@code "anomalous"} (Boolean)
     */
    private List<Map<String, Object>> process(DataFrame dataFrame, RandomCutForest forest, Integer actualTrainingDataSize) {
        ColumnMeta[] columnMetas = dataFrame.columnMetas();
        int dimensions = columnMetas.length;
        int rowCount = dataFrame.size();
        List<Map<String, Object>> predictResult = new ArrayList<>(rowCount);
        for (int rowNum = 0; rowNum < rowCount; ++rowNum) {
            Row row = dataFrame.getRow(rowNum);
            double[] point = new double[dimensions];
            for (int i = 0; i < dimensions; ++i) {
                ColumnValue value = row.getValue(i);
                point[i] = value.doubleValue();
            }
            // Score before inserting so a point is never compared against itself.
            double anomalyScore = forest.getAnomalyScore(point);
            if (actualTrainingDataSize == null || rowNum < actualTrainingDataSize) {
                forest.update(point);
            }
            Map<String, Object> result = new HashMap<>();
            result.put("score", anomalyScore);
            result.put("anomalous", anomalyScore > this.anomalyScoreThreshold);
            predictResult.add(result);
        }
        return predictResult;
    }

    /** Builds an empty single-threaded forest with one dimension per data-frame column. */
    private RandomCutForest createRandomCutForest(DataFrame dataFrame) {
        return RandomCutForest.builder()
                .dimensions(dataFrame.columnMetas().length)
                .numberOfTrees(this.numberOfTrees)
                .sampleSize(this.sampleSize)
                .outputAfter(this.outputAfter)
                .parallelExecutionEnabled(false)
                .build();
    }
}

