/*
 * Decompiled with CFR 0.152.
 */
package io.skylite.ml.common.algorithms.sample;

import io.skylite.ml.common.FunctionName;
import io.skylite.ml.common.algorithms.util.ModelSerDeSer;
import io.skylite.ml.common.annotation.Function;
import io.skylite.ml.common.dataframe.DataFrame;
import io.skylite.ml.common.dataset.DataFrameInputDataset;
import io.skylite.ml.common.engine.Encryptor;
import io.skylite.ml.common.engine.Predictable;
import io.skylite.ml.common.engine.Trainable;
import io.skylite.ml.common.exception.MLException;
import io.skylite.ml.common.input.MLInput;
import io.skylite.ml.common.input.parameter.MLAlgoParams;
import io.skylite.ml.common.input.parameter.sample.SampleAlgoParams;
import io.skylite.ml.common.model.MLModel;
import io.skylite.ml.common.model.MLModelState;
import io.skylite.ml.common.output.MLOutput;
import io.skylite.ml.common.output.sample.SampleAlgoOutput;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;

/**
 * Minimal reference algorithm implementing both {@link Trainable} and {@link Predictable}.
 *
 * <p>{@link #train(MLInput)} ignores its input and returns an {@link MLModel} whose content is a
 * fixed descriptive string (embedding the configured sample parameter) serialized with
 * {@link ModelSerDeSer#serializeToBase64}. {@link #predict(MLInput)} returns the sum of every
 * value in the input data frame.
 */
@Function(value = FunctionName.SAMPLE_ALGO)
public class SampleAlgo implements Trainable, Predictable {
    /** Version string stamped onto every model produced by {@link #train(MLInput)}. */
    public static final String VERSION = "1.0.0";
    /** Value used when no sample parameter is configured, and after {@link #close()}. */
    private static final int DEFAULT_SAMPLE_PARAM = -1;

    /** Configured sample parameter; only echoed into the trained model's content. */
    private int sampleParam;

    /** Creates an instance whose sample parameter is {@link #DEFAULT_SAMPLE_PARAM}. */
    public SampleAlgo() {
        // Explicit so the no-arg path agrees with the other defaulting paths (it would
        // otherwise be the implicit field default 0).
        this.sampleParam = DEFAULT_SAMPLE_PARAM;
    }

    /**
     * Creates an instance configured from the given parameters.
     *
     * @param parameters algorithm parameters, expected to be {@link SampleAlgoParams}; when
     *     {@code null}, or when its sample parameter is {@code null}, the default is used
     * @throws ClassCastException if {@code parameters} is not a {@link SampleAlgoParams}
     */
    public SampleAlgo(MLAlgoParams parameters) {
        this.sampleParam = Optional.ofNullable(parameters)
                .map(SampleAlgoParams.class::cast)
                .map(SampleAlgoParams::getSampleParam)
                .orElse(DEFAULT_SAMPLE_PARAM);
    }

    /**
     * Not supported by this algorithm.
     *
     * @throws MLException always
     */
    @Override
    public void initModel(MLModel model, Map<String, Object> params, Encryptor encryptor) {
        throw new MLException("Sample Algo doesn't support init model");
    }

    /** Resets the sample parameter to {@link #DEFAULT_SAMPLE_PARAM}. */
    @Override
    public void close() {
        this.sampleParam = DEFAULT_SAMPLE_PARAM;
    }

    /**
     * Reports readiness.
     *
     * @return always {@code true}; this algorithm needs no loaded model to predict
     */
    @Override
    public boolean isModelReady() {
        return true;
    }

    /**
     * Sums every value of the input data frame.
     *
     * @param mlInput input whose dataset must be a {@link DataFrameInputDataset}
     * @return a {@link SampleAlgoOutput} carrying the sum ({@code 0.0} for an empty frame)
     * @throws ClassCastException if the input dataset is not a {@link DataFrameInputDataset}
     */
    @Override
    public MLOutput predict(MLInput mlInput) {
        DataFrame dataFrame = ((DataFrameInputDataset) mlInput.getInputDataset()).getDataFrame();
        // Single-element array: a mutable, unboxed accumulator the lambdas below can capture.
        // NOTE(review): assumes DataFrame/row forEach invoke the callback sequentially on the
        // calling thread (as Iterable.forEach does), so no atomic wrapper is needed — confirm.
        double[] total = {0.0};
        dataFrame.forEach(row -> row.forEach(item -> total[0] += item.doubleValue()));
        return SampleAlgoOutput.builder().sampleResult(total[0]).build();
    }

    /**
     * Same as {@link #predict(MLInput)}, but first requires a model to be supplied.
     *
     * @param mlInput input whose dataset must be a {@link DataFrameInputDataset}
     * @param model model to predict with; must be non-null but is otherwise not consulted
     * @return a {@link SampleAlgoOutput} carrying the sum of all input values
     * @throws IllegalArgumentException if {@code model} is {@code null}
     */
    @Override
    public MLOutput predict(MLInput mlInput, MLModel model) {
        if (model == null) {
            throw new IllegalArgumentException("No model found for sample algo.");
        }
        return this.predict(mlInput);
    }

    /**
     * Produces a trained model; {@code mlInput} is not read.
     *
     * @param mlInput training input (unused by this sample algorithm)
     * @return a model in state {@link MLModelState#TRAINED}, named after
     *     {@link FunctionName#SAMPLE_ALGO}, with version {@link #VERSION} and content describing
     *     the configured sample parameter
     */
    @Override
    public MLModel train(MLInput mlInput) {
        String description = "This is a sample testing model with parameter: " + this.sampleParam;
        return MLModel.builder()
                .name(FunctionName.SAMPLE_ALGO.name())
                .algorithm(FunctionName.SAMPLE_ALGO)
                .version(VERSION)
                .content(ModelSerDeSer.serializeToBase64(description))
                .modelState(MLModelState.TRAINED)
                .build();
    }
}

