/*
 * Decompiled with CFR 0.152.
 */
package io.lucenia.ml.common.engine.algorithms.clustering;

import io.lucenia.ml.common.engine.algorithms.clustering.SampleSummary;
import io.lucenia.ml.common.engine.algorithms.clustering.SerializableSummary;
import io.lucenia.ml.common.engine.algorithms.clustering.Summarizer;
import io.lucenia.ml.common.engine.algorithms.util.TribuoUtil;
import io.skylite.common.Randomness;
import io.skylite.common.collect.Tuple;
import io.skylite.ml.common.FunctionName;
import io.skylite.ml.common.algorithms.util.ModelSerDeSer;
import io.skylite.ml.common.annotation.Function;
import io.skylite.ml.common.dataframe.DataFrame;
import io.skylite.ml.common.dataframe.DataFrameBuilder;
import io.skylite.ml.common.dataset.DataFrameInputDataset;
import io.skylite.ml.common.engine.Encryptor;
import io.skylite.ml.common.engine.TrainAndPredictable;
import io.skylite.ml.common.input.MLInput;
import io.skylite.ml.common.input.parameter.MLAlgoParams;
import io.skylite.ml.common.input.parameter.clustering.RCFSummarizeParams;
import io.skylite.ml.common.model.MLModel;
import io.skylite.ml.common.model.MLModelState;
import io.skylite.ml.common.output.MLOutput;
import io.skylite.ml.common.output.MLPredictionOutput;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import java.util.function.BiFunction;

@Function(value=FunctionName.RCF_SUMMARIZE)
public class RCFSummarize
implements TrainAndPredictable {
    public static final String VERSION = "1.0.0";
    private static final RCFSummarizeParams.DistanceType DEFAULT_DISTANCE_TYPE = RCFSummarizeParams.DistanceType.L2;
    private static int DEFAULT_MAX_K = 10;
    private static boolean DEFAULT_PHASE1_REASSIGN = true;
    private static boolean DEFAULT_PARALLEL = false;
    private final Random rnd = Randomness.get();
    private RCFSummarizeParams parameters;
    private BiFunction<float[], float[], Double> distance;
    private SampleSummary summary;

    public RCFSummarize() {
    }

    public RCFSummarize(MLAlgoParams parameters) {
        this.parameters = parameters == null ? RCFSummarizeParams.builder().maxK(Integer.valueOf(DEFAULT_MAX_K)).initialK(Integer.valueOf(DEFAULT_MAX_K)).phase1Reassign(Boolean.valueOf(DEFAULT_PHASE1_REASSIGN)).parallel(Boolean.valueOf(DEFAULT_PARALLEL)).build() : (RCFSummarizeParams)parameters;
        this.validateParametersAndRefine();
        this.createDistance();
    }

    private void validateParametersAndRefine() {
        Boolean phase1Reassign = this.parameters.getPhase1Reassign();
        Boolean parallel = this.parameters.getParallel();
        Integer maxK = this.parameters.getMaxK();
        Integer initialK = this.parameters.getInitialK();
        RCFSummarizeParams.DistanceType distType = this.parameters.getDistanceType();
        if (maxK != null && maxK <= 0) {
            throw new IllegalArgumentException("max K should be positive");
        }
        if (initialK != null && initialK <= 0) {
            throw new IllegalArgumentException("initial K should be positive");
        }
        if (maxK == null) {
            maxK = DEFAULT_MAX_K;
        }
        if (initialK == null) {
            initialK = maxK;
        }
        if (distType == null) {
            distType = DEFAULT_DISTANCE_TYPE;
        }
        if (phase1Reassign == null) {
            phase1Reassign = false;
        }
        if (parallel == null) {
            parallel = false;
        }
        this.parameters = RCFSummarizeParams.builder().maxK(maxK).initialK(initialK).phase1Reassign(phase1Reassign).parallel(parallel).distanceType(distType).build();
    }

    private void createDistance() {
        RCFSummarizeParams.DistanceType distanceType = Optional.ofNullable(this.parameters.getDistanceType()).orElse(DEFAULT_DISTANCE_TYPE);
        switch (distanceType) {
            case L1: {
                this.distance = Summarizer::L1distance;
                break;
            }
            case L2: {
                this.distance = Summarizer::L2distance;
                break;
            }
            case LInfinity: {
                this.distance = Summarizer::LInfinitydistance;
                break;
            }
            default: {
                this.distance = Summarizer::L2distance;
            }
        }
    }

    public MLModel train(MLInput mlInput) {
        DataFrame dataFrame = ((DataFrameInputDataset)mlInput.getInputDataset()).getDataFrame();
        Tuple<String[], float[][]> featureNamesValues = TribuoUtil.transformDataFrameFloat(dataFrame);
        SampleSummary summary = Summarizer.summarize((float[][])featureNamesValues.v2(), (int)this.parameters.getMaxK(), (int)this.parameters.getInitialK(), (boolean)this.parameters.getPhase1Reassign(), this.distance, this.rnd.nextLong(), this.parameters.getParallel());
        MLModel model = MLModel.builder().name(FunctionName.RCF_SUMMARIZE.name()).algorithm(FunctionName.RCF_SUMMARIZE).version(VERSION).content(ModelSerDeSer.serializeToBase64((Object)new SerializableSummary(summary))).modelState(MLModelState.TRAINED).build();
        return model;
    }

    public void initModel(MLModel model, Map<String, Object> params, Encryptor encryptor) {
        this.summary = ((SerializableSummary)ModelSerDeSer.deserialize((MLModel)model)).getSummary();
    }

    public void close() {
        this.summary = null;
    }

    public boolean isModelReady() {
        return this.summary != null;
    }

    public MLOutput predict(MLInput mlInput) {
        List centroidsLst = Arrays.asList(this.summary.summaryPoints);
        DataFrame dataFrame = ((DataFrameInputDataset)mlInput.getInputDataset()).getDataFrame();
        Tuple<String[], float[][]> featureNamesValues = TribuoUtil.transformDataFrameFloat(dataFrame);
        ArrayList predictions = new ArrayList();
        Arrays.stream((float[][])featureNamesValues.v2()).forEach(e -> predictions.add(RCFSummarize.findNearest(e, centroidsLst, this.distance)));
        ArrayList listClusterID = new ArrayList();
        predictions.forEach(e -> listClusterID.add(Collections.singletonMap("ClusterID", e)));
        return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(listClusterID)).build();
    }

    public MLOutput predict(MLInput mlInput, MLModel model) {
        if (model == null) {
            throw new IllegalArgumentException("No model found for RCFSummarize prediction.");
        }
        this.summary = ((SerializableSummary)ModelSerDeSer.deserialize((MLModel)model)).getSummary();
        return this.predict(mlInput);
    }

    public MLOutput trainAndPredict(MLInput mlInput) {
        DataFrame dataFrame = ((DataFrameInputDataset)mlInput.getInputDataset()).getDataFrame();
        Tuple<String[], float[][]> featureNamesValues = TribuoUtil.transformDataFrameFloat(dataFrame);
        SampleSummary summary = Summarizer.summarize((float[][])featureNamesValues.v2(), (int)this.parameters.getMaxK(), (int)this.parameters.getInitialK(), (boolean)this.parameters.getPhase1Reassign(), this.distance, this.rnd.nextLong(), this.parameters.getParallel());
        List centroidsLst = Arrays.asList(summary.summaryPoints);
        ArrayList predictions = new ArrayList();
        Arrays.stream((float[][])featureNamesValues.v2()).forEach(e -> predictions.add(RCFSummarize.findNearest(e, centroidsLst, this.distance)));
        ArrayList listClusterID = new ArrayList();
        predictions.forEach(e -> listClusterID.add(Collections.singletonMap("ClusterID", e)));
        return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(listClusterID)).build();
    }

    public static <T, C extends Number> int findNearest(T query, Iterable<T> base, BiFunction<T, T, C> dist) {
        int index = -1;
        double minValue = Double.MAX_VALUE;
        int i = 0;
        for (T e : base) {
            double d = ((Number)dist.apply(query, e)).doubleValue();
            if (d < minValue) {
                minValue = d;
                index = i;
            }
            ++i;
        }
        return index;
    }
}

