/*
 * Decompiled with CFR 0.152.
 */
package io.skylite.ml.common.input;

import io.skylite.core.ParseField;
import io.skylite.core.common.io.stream.StreamInput;
import io.skylite.core.common.io.stream.StreamOutput;
import io.skylite.core.xcontent.ObjectParser;
import io.skylite.core.xcontent.ToXContent;
import io.skylite.core.xcontent.XContentBuilder;
import io.skylite.core.xcontent.XContentParser;
import io.skylite.ml.common.FunctionName;
import io.skylite.ml.common.MLCommonsClassLoader;
import io.skylite.ml.common.connector.ConnectorAction;
import io.skylite.ml.common.dataframe.DataFrame;
import io.skylite.ml.common.dataframe.DefaultDataFrame;
import io.skylite.ml.common.dataset.DataFrameInputDataset;
import io.skylite.ml.common.dataset.MLInputDataset;
import io.skylite.ml.common.dataset.QuestionAnsweringInputDataSet;
import io.skylite.ml.common.dataset.TextDocsInputDataSet;
import io.skylite.ml.common.dataset.TextSimilarityInputDataSet;
import io.skylite.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import io.skylite.ml.common.input.Input;
import io.skylite.ml.common.input.parameter.MLAlgoParams;
import io.skylite.ml.common.output.model.ModelResultFilter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;

/**
 * Input to an ML function: the {@link FunctionName algorithm} to run, optional algorithm
 * {@link MLAlgoParams parameters}, and an optional {@link MLInputDataset input dataset}.
 *
 * <p>Instances can be written to / read from a transport stream ({@link #writeTo(StreamOutput)},
 * {@link #MLInput(StreamInput)}), rendered as XContent ({@link #toXContent}), and parsed from
 * XContent ({@link #parse(XContentParser, String)}).
 */
public class MLInput implements Input {
    public static final String ALGORITHM_FIELD = "algorithm";
    public static final String ML_PARAMETERS_FIELD = "parameters";
    public static final String INPUT_INDEX_FIELD = "input_index";
    public static final String INPUT_QUERY_FIELD = "input_query";
    public static final String INPUT_DATA_FIELD = "input_data";
    public static final String RETURN_BYTES_FIELD = "return_bytes";
    public static final String RETURN_NUMBER_FIELD = "return_number";
    public static final String TARGET_RESPONSE_FIELD = "target_response";
    public static final String TARGET_RESPONSE_POSITIONS_FIELD = "target_response_positions";
    public static final String TEXT_DOCS_FIELD = "text_docs";
    public static final String QUERY_TEXT_FIELD = "query_text";
    public static final String PARAMETERS_FIELD = "parameters";
    public static final String QUESTION_FIELD = "question";
    public static final String CONTEXT_FIELD = "context";

    /** Version assigned to newly created inputs (and to {@link Builder}s) unless set explicitly. */
    private static final int DEFAULT_VERSION = 1;

    protected FunctionName algorithm;
    protected MLAlgoParams parameters;
    protected MLInputDataset inputDataset;
    /** Version number carried with this input; written by {@link #writeTo} and compared by {@link #equals}. */
    private int version = DEFAULT_VERSION;

    /** Generic XContent parser used when no algorithm-specific {@code MLInput} subclass is registered. */
    public static final ObjectParser<? extends MLInputBuilder, String> PARSER =
        MLInput.createParser(MLInputBuilder.class.getSimpleName(), algorithm -> new MLInputBuilder(algorithm));

    public MLInput() {
    }

    /**
     * Creates an input for the given algorithm.
     *
     * @param algorithm    algorithm to run; must not be {@code null}
     * @param parameters   optional algorithm parameters (may be {@code null})
     * @param inputDataset optional input dataset (may be {@code null})
     * @throws IllegalArgumentException if {@code algorithm} is {@code null}
     */
    public MLInput(FunctionName algorithm, MLAlgoParams parameters, MLInputDataset inputDataset) {
        this.validate(algorithm);
        this.algorithm = algorithm;
        this.parameters = parameters;
        this.inputDataset = inputDataset;
    }

    private void validate(FunctionName algorithm) {
        if (algorithm == null) {
            throw new IllegalArgumentException("algorithm can't be null");
        }
    }

    /**
     * Reads an input previously written by {@link #writeTo(StreamOutput)}. The parameters object is
     * instantiated through {@link MLCommonsClassLoader} for the algorithm read from the stream.
     *
     * @param in stream to read from
     * @throws IOException if reading from the stream fails
     */
    public MLInput(StreamInput in) throws IOException {
        this.algorithm = in.readEnum(FunctionName.class);
        if (in.readBoolean()) {
            this.parameters = (MLAlgoParams)MLCommonsClassLoader.initMLInstance(this.algorithm, in, StreamInput.class);
        }
        if (in.readBoolean()) {
            this.inputDataset = MLInputDataset.fromStream(in);
        }
        this.version = in.readInt();
    }

    /**
     * Writes this input to the stream: algorithm, a presence flag plus the optional parameters, a
     * presence flag plus the optional dataset, and finally the version. Must stay in sync with
     * {@link #MLInput(StreamInput)}.
     *
     * @param out stream to write to
     * @throws IOException if writing fails
     */
    public void writeTo(StreamOutput out) throws IOException {
        out.writeEnum(this.algorithm);
        if (this.parameters != null) {
            out.writeBoolean(true);
            this.parameters.writeTo(out);
        } else {
            out.writeBoolean(false);
        }
        if (this.inputDataset != null) {
            out.writeBoolean(true);
            this.inputDataset.writeTo(out);
        } else {
            out.writeBoolean(false);
        }
        out.writeInt(this.version);
    }

    /**
     * Hook for subclasses to append extra XContent; this base implementation writes nothing.
     *
     * @return the same builder
     */
    public XContentBuilder extendedXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        return builder;
    }

    /**
     * Renders this input as a single XContent object. The dataset is flattened into the top-level
     * object using field names that the generic {@link #PARSER} understands.
     *
     * @param builder builder to write into
     * @param params  rendering params, forwarded to datasets rendered via the default branch
     * @return the same builder
     * @throws IOException if writing fails
     */
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        builder.field(ALGORITHM_FIELD, this.algorithm.name());
        if (this.parameters != null) {
            builder.field(ML_PARAMETERS_FIELD, (ToXContent)this.parameters);
        }
        if (this.inputDataset != null) {
            switch (this.inputDataset.getInputDataType()) {
                case DATA_FRAME:
                    builder.startObject(INPUT_DATA_FIELD);
                    ((DataFrameInputDataset)this.inputDataset).getDataFrame().toXContent(builder, EMPTY_PARAMS);
                    builder.endObject();
                    break;
                case TEXT_DOCS:
                    writeTextDocs(builder, (TextDocsInputDataSet)this.inputDataset);
                    break;
                case TEXT_SIMILARITY:
                    writeTextSimilarity(builder, (TextSimilarityInputDataSet)this.inputDataset);
                    break;
                case QUESTION_ANSWERING:
                    writeQuestionAnswering(builder, (QuestionAnsweringInputDataSet)this.inputDataset);
                    break;
                case REMOTE:
                    // NOTE(review): PARAMETERS_FIELD equals ML_PARAMETERS_FIELD ("parameters"); if
                    // this.parameters is also non-null the key is emitted twice in this object — confirm
                    // that REMOTE inputs never carry algorithm parameters.
                    writeRemote(builder, (RemoteInferenceInputDataSet)this.inputDataset);
                    break;
                default:
                    this.inputDataset.toXContent(builder, params);
                    break;
            }
        }
        builder.endObject();
        return builder;
    }

    /** Writes text docs and, when present, the result-filter settings of a text-docs dataset. */
    private static void writeTextDocs(XContentBuilder builder, TextDocsInputDataSet dataSet) throws IOException {
        List<String> docs = dataSet.getDocs();
        if (docs != null && !docs.isEmpty()) {
            builder.field(TEXT_DOCS_FIELD, (Object)docs.toArray(new String[0]));
        }
        ModelResultFilter resultFilter = dataSet.getResultFilter();
        if (resultFilter == null) {
            return;
        }
        builder.field(RETURN_BYTES_FIELD, resultFilter.isReturnBytes());
        builder.field(RETURN_NUMBER_FIELD, resultFilter.isReturnNumber());
        List<String> targetResponse = resultFilter.getTargetResponse();
        if (targetResponse != null && !targetResponse.isEmpty()) {
            builder.field(TARGET_RESPONSE_FIELD, (Object)targetResponse.toArray(new String[0]));
        }
        List<Integer> targetPositions = resultFilter.getTargetResponsePositions();
        if (targetPositions != null && !targetPositions.isEmpty()) {
            builder.field(TARGET_RESPONSE_POSITIONS_FIELD, (Object)targetPositions.toArray(new Integer[0]));
        }
    }

    /** Writes the query text (always) and the text docs (when non-empty) of a text-similarity dataset. */
    private static void writeTextSimilarity(XContentBuilder builder, TextSimilarityInputDataSet dataSet) throws IOException {
        builder.field(QUERY_TEXT_FIELD, dataSet.getQueryText());
        List<String> documents = dataSet.getTextDocs();
        if (documents == null || documents.isEmpty()) {
            return;
        }
        builder.startArray(TEXT_DOCS_FIELD);
        for (String document : documents) {
            builder.value(document);
        }
        builder.endArray();
    }

    /** Writes the question and context of a question-answering dataset. */
    private static void writeQuestionAnswering(XContentBuilder builder, QuestionAnsweringInputDataSet dataSet) throws IOException {
        builder.field(QUESTION_FIELD, dataSet.getQuestion());
        builder.field(CONTEXT_FIELD, dataSet.getContext());
    }

    /** Writes the parameter map and action type of a remote-inference dataset. */
    private static void writeRemote(XContentBuilder builder, RemoteInferenceInputDataSet dataSet) throws IOException {
        Map<String, String> remoteParameters = dataSet.getParameters();
        builder.field(PARAMETERS_FIELD, remoteParameters);
        builder.field("action_type", (Object)dataSet.getActionType());
    }

    /**
     * Creates an {@link ObjectParser} that fills an {@link MLInputBuilder} (or subclass) from the
     * top-level fields produced by {@link #toXContent}. Unknown fields are accepted (the
     * {@code true} flag passed to {@code fromBuilder}).
     *
     * @param name parser name
     * @param ctor creates the builder from the algorithm name passed as parse context
     * @param <T>  concrete builder type
     * @return the configured parser
     */
    protected static <T extends MLInputBuilder> ObjectParser<T, String> createParser(String name, Function<String, T> ctor) {
        ObjectParser<T, String> parser = ObjectParser.fromBuilder(name, true, ctor);
        parser.declareObject(MLInputBuilder::mlParameters, (p, c) -> (MLAlgoParams)p.namedObject(MLAlgoParams.class, c, null), new ParseField(ML_PARAMETERS_FIELD));
        parser.declareStringArray(MLInputBuilder::sourceIndices, new ParseField(INPUT_INDEX_FIELD));
        parser.declareObject(MLInputBuilder::dataFrame, (p, c) -> DefaultDataFrame.parse(p), new ParseField(INPUT_DATA_FIELD));
        parser.declareBoolean(MLInputBuilder::returnBytes, new ParseField(RETURN_BYTES_FIELD));
        parser.declareBoolean(MLInputBuilder::returnNumber, new ParseField(RETURN_NUMBER_FIELD));
        parser.declareStringArray(MLInputBuilder::targetResponse, new ParseField(TARGET_RESPONSE_FIELD));
        parser.declareIntArray(MLInputBuilder::targetResponsePositions, new ParseField(TARGET_RESPONSE_POSITIONS_FIELD));
        parser.declareStringArray(MLInputBuilder::textDocs, new ParseField(TEXT_DOCS_FIELD));
        parser.declareString(MLInputBuilder::queryText, new ParseField(QUERY_TEXT_FIELD));
        parser.declareString(MLInputBuilder::question, new ParseField(QUESTION_FIELD));
        parser.declareString(MLInputBuilder::context, new ParseField(CONTEXT_FIELD));
        return parser;
    }

    /**
     * Parses an input and, if its dataset is a {@link RemoteInferenceInputDataSet} without an action
     * type, assigns {@code actionType} to it.
     *
     * @param parser        XContent positioned at the input object
     * @param inputAlgoName algorithm name (case-insensitive)
     * @param actionType    action type to apply when the parsed remote dataset has none
     * @return the parsed input
     * @throws IOException if parsing fails
     */
    public static MLInput parse(XContentParser parser, String inputAlgoName, ConnectorAction.ActionType actionType) throws IOException {
        MLInput mlInput = MLInput.parse(parser, inputAlgoName);
        MLInputDataset dataset = mlInput.getInputDataset();
        if (dataset instanceof RemoteInferenceInputDataSet) {
            RemoteInferenceInputDataSet remoteDataSet = (RemoteInferenceInputDataSet)dataset;
            if (remoteDataSet.getActionType() == null) {
                remoteDataSet.setActionType(actionType);
            }
        }
        return mlInput;
    }

    /**
     * Parses an input for the named algorithm. When {@link MLCommonsClassLoader} can create an
     * algorithm-specific {@code MLInput}, parsing is delegated to it; otherwise the generic
     * {@link #PARSER} is used.
     *
     * @param parser        XContent positioned at the input object
     * @param inputAlgoName algorithm name (case-insensitive)
     * @return the parsed input
     * @throws NullPointerException if {@code inputAlgoName} is {@code null}
     * @throws IOException          if parsing fails
     */
    public static MLInput parse(XContentParser parser, String inputAlgoName) throws IOException {
        Objects.requireNonNull(inputAlgoName, "inputAlgoName");
        FunctionName algorithm = FunctionName.from(inputAlgoName.toUpperCase(Locale.ROOT));
        if (MLCommonsClassLoader.canInitMLInput(algorithm)) {
            MLInput mlInput = (MLInput)MLCommonsClassLoader.initMLInput(algorithm, new Object[]{parser, algorithm}, XContentParser.class, FunctionName.class);
            mlInput.setAlgorithm(algorithm);
            return mlInput;
        }
        // The raw name is passed as context; MLInputBuilder upper-cases it itself.
        MLInputBuilder builder = PARSER.parse(parser, inputAlgoName);
        return builder.build();
    }

    public static Builder builder() {
        return new Builder();
    }

    /** Returns a {@link Builder} pre-populated with all of this input's state, including its version. */
    public Builder toBuilder() {
        return new Builder(this);
    }

    @Override
    public FunctionName getFunctionName() {
        return this.algorithm;
    }

    public FunctionName getAlgorithm() {
        return this.algorithm;
    }

    public void setAlgorithm(FunctionName algorithm) {
        this.algorithm = algorithm;
    }

    public MLAlgoParams getParameters() {
        return this.parameters;
    }

    public void setParameters(MLAlgoParams parameters) {
        this.parameters = parameters;
    }

    public MLInputDataset getInputDataset() {
        return this.inputDataset;
    }

    public int getVersion() {
        return this.version;
    }

    public void setVersion(int version) {
        this.version = version;
    }

    public void setInputDataset(MLInputDataset inputDataset) {
        this.inputDataset = inputDataset;
    }

    /** Two inputs are equal when they have the same runtime class, version, algorithm, parameters and dataset. */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        MLInput other = (MLInput)o;
        return this.version == other.version
            && this.algorithm == other.algorithm
            && Objects.equals(this.parameters, other.parameters)
            && Objects.equals(this.inputDataset, other.inputDataset);
    }

    @Override
    public int hashCode() {
        return Objects.hash(this.algorithm, this.parameters, this.inputDataset, this.version);
    }

    @Override
    public String toString() {
        return "MLInput{algorithm=" + String.valueOf((Object)this.algorithm) + ", parameters=" + String.valueOf(this.parameters) + ", inputDataset=" + String.valueOf(this.inputDataset) + ", version=" + this.version + "}";
    }

    /**
     * Mutable accumulator populated by {@link #PARSER}; {@link #build()} turns the collected fields
     * into an {@link MLInput} whose dataset type depends on the algorithm.
     */
    public static class MLInputBuilder {
        protected MLAlgoParams mlParameters = null;
        // NOTE(review): sourceIndices is populated by the parser but not consumed by build() or
        // createInputDataset() in this class; presumably subclasses use it — confirm.
        protected List<String> sourceIndices = new ArrayList<String>();
        protected DataFrame dataFrame = null;
        protected boolean returnBytes = false;
        protected boolean returnNumber = true;
        protected List<String> targetResponse = new ArrayList<String>();
        protected List<Integer> targetResponsePositions = new ArrayList<Integer>();
        protected List<String> textDocs = new ArrayList<String>();
        protected String queryText = null;
        protected String question = null;
        protected String context = null;
        protected final FunctionName algorithm;

        /**
         * @param algorithm algorithm name (case-insensitive), resolved via {@link FunctionName#from}
         */
        public MLInputBuilder(String algorithm) {
            this.algorithm = FunctionName.from(algorithm.toUpperCase(Locale.ROOT));
        }

        protected ObjectParser<? extends MLInputBuilder, String> parser() {
            return PARSER;
        }

        public MLInputBuilder mlParameters(MLAlgoParams mlParameters) {
            this.mlParameters = mlParameters;
            return this;
        }

        public MLInputBuilder sourceIndices(List<String> sourceIndices) {
            this.sourceIndices = sourceIndices;
            return this;
        }

        public MLInputBuilder dataFrame(DataFrame dataFrame) {
            this.dataFrame = dataFrame;
            return this;
        }

        public MLInputBuilder returnBytes(boolean returnBytes) {
            this.returnBytes = returnBytes;
            return this;
        }

        public MLInputBuilder returnNumber(boolean returnNumber) {
            this.returnNumber = returnNumber;
            return this;
        }

        public MLInputBuilder targetResponse(List<String> targetResponse) {
            this.targetResponse = targetResponse;
            return this;
        }

        public MLInputBuilder targetResponsePositions(List<Integer> targetResponsePositions) {
            this.targetResponsePositions = targetResponsePositions;
            return this;
        }

        public MLInputBuilder textDocs(List<String> textDocs) {
            this.textDocs = textDocs;
            return this;
        }

        public MLInputBuilder queryText(String queryText) {
            this.queryText = queryText;
            return this;
        }

        public MLInputBuilder question(String question) {
            this.question = question;
            return this;
        }

        public MLInputBuilder context(String context) {
            this.context = context;
            return this;
        }

        /**
         * Chooses the dataset implementation for {@link #algorithm}: text docs (with a result filter)
         * for text-embedding / sparse algorithms, text similarity, question answering, otherwise a
         * data-frame dataset when a data frame was parsed.
         *
         * @return the dataset, or {@code null} if none applies
         */
        protected MLInputDataset createInputDataset() {
            MLInputDataset inputDataSet = null;
            if (this.algorithm == FunctionName.TEXT_EMBEDDING || this.algorithm == FunctionName.SPARSE_ENCODING || this.algorithm == FunctionName.SPARSE_TOKENIZE) {
                ModelResultFilter filter = new ModelResultFilter(this.returnBytes, this.returnNumber, this.targetResponse, this.targetResponsePositions);
                inputDataSet = new TextDocsInputDataSet(this.textDocs, filter);
            } else if (this.algorithm == FunctionName.TEXT_SIMILARITY) {
                inputDataSet = new TextSimilarityInputDataSet(this.queryText, this.textDocs);
            } else if (this.algorithm == FunctionName.QUESTION_ANSWERING) {
                inputDataSet = new QuestionAnsweringInputDataSet(this.question, this.context);
            } else if (this.dataFrame != null) {
                inputDataSet = new DataFrameInputDataset(this.dataFrame);
            }
            return inputDataSet;
        }

        /**
         * Builds the input from the collected fields.
         *
         * @return a new {@link MLInput}
         * @throws IllegalArgumentException if the algorithm is {@code null}
         */
        public MLInput build() {
            MLInputDataset inputDataSet = this.createInputDataset();
            // Subclasses may override createInputDataset() and return null; fall back to the parsed data frame.
            if (inputDataSet == null && this.dataFrame != null) {
                inputDataSet = new DataFrameInputDataset(this.dataFrame);
            }
            return new MLInput(this.algorithm, this.mlParameters, inputDataSet);
        }
    }

    /** Programmatic builder for {@link MLInput}; {@link #build()} requires a non-null algorithm. */
    public static class Builder {
        protected FunctionName algorithm;
        protected MLAlgoParams parameters;
        protected MLInputDataset inputDataset;
        protected int version = DEFAULT_VERSION;

        public Builder() {
        }

        /** Copies all state from {@code mlInput}, including its version, so a rebuilt copy is equal to it. */
        public Builder(MLInput mlInput) {
            this.algorithm = mlInput.algorithm;
            this.parameters = mlInput.parameters;
            this.inputDataset = mlInput.inputDataset;
            this.version = mlInput.version;
        }

        public Builder algorithm(FunctionName algorithm) {
            this.algorithm = algorithm;
            return this;
        }

        public Builder parameters(MLAlgoParams parameters) {
            this.parameters = parameters;
            return this;
        }

        public Builder inputDataset(MLInputDataset inputDataset) {
            this.inputDataset = inputDataset;
            return this;
        }

        public Builder version(int version) {
            this.version = version;
            return this;
        }

        /**
         * @return a new {@link MLInput} with the configured state
         * @throws IllegalArgumentException if no algorithm was set
         */
        public MLInput build() {
            MLInput mlInput = new MLInput(this.algorithm, this.parameters, this.inputDataset);
            mlInput.setVersion(this.version);
            return mlInput;
        }
    }
}

