/*
 * Decompiled with CFR 0.152.
 */
package io.lucenia.ml.common.engine.algorithms.clustering;

import io.lucenia.ml.common.engine.algorithms.util.TribuoOutputType;
import io.lucenia.ml.common.engine.algorithms.util.TribuoUtil;
import io.skylite.ml.common.FunctionName;
import io.skylite.ml.common.algorithms.util.ModelSerDeSer;
import io.skylite.ml.common.annotation.Function;
import io.skylite.ml.common.dataframe.DataFrame;
import io.skylite.ml.common.dataframe.DataFrameBuilder;
import io.skylite.ml.common.dataset.DataFrameInputDataset;
import io.skylite.ml.common.engine.Encryptor;
import io.skylite.ml.common.engine.TrainAndPredictable;
import io.skylite.ml.common.input.MLInput;
import io.skylite.ml.common.input.parameter.MLAlgoParams;
import io.skylite.ml.common.input.parameter.clustering.KMeansParams;
import io.skylite.ml.common.model.MLModel;
import io.skylite.ml.common.model.MLModelState;
import io.skylite.ml.common.output.MLOutput;
import io.skylite.ml.common.output.MLPredictionOutput;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.tribuo.MutableDataset;
import org.tribuo.Prediction;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.ClusteringFactory;
import org.tribuo.clustering.kmeans.KMeansModel;
import org.tribuo.clustering.kmeans.KMeansTrainer;

/**
 * K-means clustering backed by Tribuo's {@link KMeansTrainer}.
 *
 * <p>{@link #train(MLInput)} fits a model and returns it serialized inside an {@link MLModel}.
 * {@link #initModel} or {@link #predict(MLInput, MLModel)} deserialize such a model and cache it.
 * {@link #predict(MLInput)} then assigns a cluster id to every row of the input data frame.
 *
 * <p>NOTE(review): the cached model is a plain mutable field that
 * {@link #predict(MLInput, MLModel)} overwrites. Do not share one instance across threads without
 * external synchronization.
 */
@Function(value=FunctionName.KMEANS)
public class KMeans
implements TrainAndPredictable {
    /** Version string stamped on every {@link MLModel} produced by {@link #train(MLInput)}. */
    public static final String VERSION = "1.0.0";
    private static final KMeansParams.DistanceType DEFAULT_DISTANCE_TYPE = KMeansParams.DistanceType.EUCLIDEAN;
    /** Number of clusters used when the request does not specify one. */
    private static final int DEFAULT_CENTROIDS = 2;
    /** Number of k-means iterations used when the request does not specify one. */
    private static final int DEFAULT_ITERATIONS = 10;
    /** Column name of the single column in every prediction result. */
    private static final String CLUSTER_ID_COLUMN = "ClusterID";

    private final KMeansParams parameters;
    // Use half of the available cores, but never fewer than one thread.
    private final int numThreads = Math.max(Runtime.getRuntime().availableProcessors() / 2, 1);
    // Time-based seed: repeated training runs on the same data may yield different clusterings.
    private final long seed = System.currentTimeMillis();
    private final KMeansTrainer.Distance distance;
    private KMeansModel kMeansModel;

    /**
     * Creates an instance configured with default {@link KMeansParams}.
     *
     * <p>Previously this left the parameters and distance unset, so training such an instance
     * failed with a NullPointerException.
     */
    public KMeans() {
        this(null);
    }

    /**
     * Creates an instance configured with the given parameters.
     *
     * @param parameters k-means parameters; {@code null} means "all defaults". A non-null value must
     *     be a {@link KMeansParams}, otherwise a {@link ClassCastException} is thrown.
     * @throws IllegalArgumentException if the centroid count or the iteration count is set and not
     *     positive
     */
    public KMeans(MLAlgoParams parameters) {
        this.parameters = parameters == null ? KMeansParams.builder().build() : (KMeansParams)parameters;
        validateParameters(this.parameters);
        this.distance = toTribuoDistance(this.parameters.getDistanceType());
    }

    /** Rejects explicitly-set, non-positive centroid or iteration counts. Unset (null) values are allowed. */
    private static void validateParameters(KMeansParams params) {
        if (params.getCentroids() != null && params.getCentroids() <= 0) {
            throw new IllegalArgumentException("K should be positive.");
        }
        if (params.getIterations() != null && params.getIterations() <= 0) {
            throw new IllegalArgumentException("Iterations should be positive.");
        }
    }

    /**
     * Maps the request-level distance type onto Tribuo's enum.
     *
     * <p>A {@code null} type uses {@link #DEFAULT_DISTANCE_TYPE}. Any type other than COSINE or L1
     * maps to Euclidean.
     */
    private static KMeansTrainer.Distance toTribuoDistance(KMeansParams.DistanceType requested) {
        KMeansParams.DistanceType distanceType = Optional.ofNullable(requested).orElse(DEFAULT_DISTANCE_TYPE);
        switch (distanceType) {
            case COSINE:
                return KMeansTrainer.Distance.COSINE;
            case L1:
                return KMeansTrainer.Distance.L1;
            default:
                return KMeansTrainer.Distance.EUCLIDEAN;
        }
    }

    /**
     * Deserializes {@code model} and caches it for later {@link #predict(MLInput)} calls.
     *
     * @param model serialized k-means model (as produced by {@link #train(MLInput)})
     * @param params unused by this algorithm
     * @param encryptor unused by this algorithm
     */
    public void initModel(MLModel model, Map<String, Object> params, Encryptor encryptor) {
        this.kMeansModel = (KMeansModel)ModelSerDeSer.deserialize(model);
    }

    /** Drops the cached model. */
    public void close() {
        this.kMeansModel = null;
    }

    /** @return {@code true} once a model has been loaded via {@link #initModel} or {@link #predict(MLInput, MLModel)} */
    public boolean isModelReady() {
        return this.kMeansModel != null;
    }

    /**
     * Assigns a cluster id to each row of the input data frame using the cached model.
     *
     * @param mlInput input whose dataset is expected to be a {@link DataFrameInputDataset}
     * @return a prediction output whose data frame has a single {@code "ClusterID"} column
     * @throws IllegalStateException if no model has been loaded yet
     */
    public MLOutput predict(MLInput mlInput) {
        if (this.kMeansModel == null) {
            throw new IllegalStateException("KMeans model is not initialized; load a model before predicting.");
        }
        MutableDataset<ClusterID> predictionDataset = toClusteringDataset(mlInput, "KMeans prediction data from opensearch");
        return toPredictionOutput(this.kMeansModel.predict(predictionDataset));
    }

    /**
     * Loads {@code model}, replacing any cached model, and predicts with it.
     *
     * @throws IllegalArgumentException if {@code model} is {@code null}
     */
    public MLOutput predict(MLInput mlInput, MLModel model) {
        if (model == null) {
            throw new IllegalArgumentException("No model found for KMeans prediction.");
        }
        this.kMeansModel = (KMeansModel)ModelSerDeSer.deserialize(model);
        return this.predict(mlInput);
    }

    /**
     * Trains a k-means model on the input data frame.
     *
     * <p>The trained model is not cached on this instance.
     *
     * @return an {@link MLModel} in state {@link MLModelState#TRAINED} carrying the base64-serialized
     *     Tribuo model
     */
    public MLModel train(MLInput mlInput) {
        MutableDataset<ClusterID> trainDataset = toClusteringDataset(mlInput, "KMeans training data from opensearch");
        KMeansModel trained = fit(trainDataset);
        return MLModel.builder()
                .name(FunctionName.KMEANS.name())
                .algorithm(FunctionName.KMEANS)
                .version(VERSION)
                .content(ModelSerDeSer.serializeToBase64(trained))
                .modelState(MLModelState.TRAINED)
                .build();
    }

    /**
     * Trains a k-means model on the input and immediately clusters the same rows.
     *
     * <p>The trained model is discarded afterwards; it is neither returned nor cached.
     */
    public MLOutput trainAndPredict(MLInput mlInput) {
        MutableDataset<ClusterID> dataset = toClusteringDataset(mlInput, "KMeans training and predicting data from opensearch");
        KMeansModel trained = fit(dataset);
        return toPredictionOutput(trained.predict(dataset));
    }

    /** Converts the request's data frame into a Tribuo clustering dataset with the given provenance description. */
    private static MutableDataset<ClusterID> toClusteringDataset(MLInput mlInput, String description) {
        DataFrame dataFrame = ((DataFrameInputDataset)mlInput.getInputDataset()).getDataFrame();
        return TribuoUtil.generateDataset(dataFrame, new ClusteringFactory(), description, TribuoOutputType.CLUSTERID);
    }

    /** Runs the Tribuo k-means trainer with the configured (or default) centroid and iteration counts. */
    private KMeansModel fit(MutableDataset<ClusterID> dataset) {
        int centroids = Optional.ofNullable(this.parameters.getCentroids()).orElse(DEFAULT_CENTROIDS);
        int iterations = Optional.ofNullable(this.parameters.getIterations()).orElse(DEFAULT_ITERATIONS);
        KMeansTrainer trainer = new KMeansTrainer(centroids, iterations, this.distance, this.numThreads, this.seed);
        return trainer.train(dataset);
    }

    /** Builds a one-column ({@code "ClusterID"}) data frame from the predictions, preserving their order. */
    private static MLOutput toPredictionOutput(List<Prediction<ClusterID>> predictions) {
        List<Map<String, Object>> rows = new ArrayList<>(predictions.size());
        for (Prediction<ClusterID> prediction : predictions) {
            rows.add(Collections.<String, Object>singletonMap(CLUSTER_ID_COLUMN, prediction.getOutput().getID()));
        }
        return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(rows)).build();
    }
}

