/*
 * Decompiled with CFR 0.152.
 */
package io.skylite.ml.common.model;

import io.skylite.core.ParseField;
import io.skylite.core.common.io.stream.StreamInput;
import io.skylite.core.common.io.stream.StreamOutput;
import io.skylite.core.xcontent.NamedXContentRegistry;
import io.skylite.core.xcontent.ToXContent;
import io.skylite.core.xcontent.XContentBuilder;
import io.skylite.core.xcontent.XContentParser;
import io.skylite.core.xcontent.XContentParserUtils;
import io.skylite.ml.common.FunctionName;
import io.skylite.ml.common.model.MLModelConfig;
import java.io.IOException;
import java.util.Locale;

/**
 * Model configuration for text-embedding models.
 *
 * <p>In addition to the {@code model_type} / {@code all_config} values held by
 * {@link MLModelConfig}, this config carries:
 * <ul>
 *   <li>a required embedding dimension and framework type;</li>
 *   <li>an optional pooling mode;</li>
 *   <li>a result-normalization flag;</li>
 *   <li>an optional maximum model length;</li>
 *   <li>optional query and passage prefix strings.</li>
 * </ul>
 *
 * <p>It round-trips through XContent ({@link #parse(XContentParser)} /
 * {@link #toXContent(XContentBuilder, ToXContent.Params)}) and through the stream format
 * ({@link #TextEmbeddingModelConfig(StreamInput)} / {@link #writeTo(StreamOutput)}).
 */
public class TextEmbeddingModelConfig extends MLModelConfig {
    /** Writeable / XContent registry name of this config type. */
    public static final String PARSE_FIELD_NAME = FunctionName.TEXT_EMBEDDING.name();

    /** Registry entry that parses this config from XContent under {@link #PARSE_FIELD_NAME}. */
    public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry(
        TextEmbeddingModelConfig.class,
        new ParseField(PARSE_FIELD_NAME),
        TextEmbeddingModelConfig::parse
    );

    public static final String EMBEDDING_DIMENSION_FIELD = "embedding_dimension";
    public static final String FRAMEWORK_TYPE_FIELD = "framework_type";
    public static final String POOLING_MODE_FIELD = "pooling_mode";
    public static final String NORMALIZE_RESULT_FIELD = "normalize_result";
    public static final String MODEL_MAX_LENGTH_FIELD = "model_max_length";
    public static final String QUERY_PREFIX = "query_prefix";
    public static final String PASSAGE_PREFIX = "passage_prefix";

    /** Embedding dimension; never null (checked in the constructor). */
    private final Integer embeddingDimension;
    /** Framework type; never null (checked in the constructor). */
    private final FrameworkType frameworkType;
    /** Optional pooling mode; may be null. */
    private final PoolingMode poolingMode;
    private final boolean normalizeResult;
    /** Optional maximum model length; may be null. */
    private final Integer modelMaxLength;
    /** Optional query prefix; may be null. */
    private final String queryPrefix;
    /** Optional passage prefix; may be null. */
    private final String passagePrefix;

    /**
     * Creates a config without query/passage prefixes.
     *
     * @throws IllegalArgumentException if {@code embeddingDimension} or {@code frameworkType} is null
     */
    public TextEmbeddingModelConfig(String modelType, Integer embeddingDimension, FrameworkType frameworkType, String allConfig, PoolingMode poolingMode, boolean normalizeResult, Integer modelMaxLength) {
        this(modelType, embeddingDimension, frameworkType, allConfig, poolingMode, normalizeResult, modelMaxLength, null, null);
    }

    /**
     * Creates a fully specified config.
     *
     * @param modelType        model type, passed to {@link MLModelConfig}; may be null
     * @param embeddingDimension embedding dimension; required
     * @param frameworkType    framework type; required
     * @param allConfig        raw configuration string, passed to {@link MLModelConfig}; may be null
     * @param poolingMode      pooling mode; may be null
     * @param normalizeResult  normalization flag
     * @param modelMaxLength   maximum model length; may be null
     * @param queryPrefix      query prefix; may be null
     * @param passagePrefix    passage prefix; may be null
     * @throws IllegalArgumentException if {@code embeddingDimension} or {@code frameworkType} is null
     */
    public TextEmbeddingModelConfig(String modelType, Integer embeddingDimension, FrameworkType frameworkType, String allConfig, PoolingMode poolingMode, boolean normalizeResult, Integer modelMaxLength, String queryPrefix, String passagePrefix) {
        super(modelType, allConfig);
        if (embeddingDimension == null) {
            throw new IllegalArgumentException("embedding dimension is null");
        }
        if (frameworkType == null) {
            throw new IllegalArgumentException("framework type is null");
        }
        this.embeddingDimension = embeddingDimension;
        this.frameworkType = frameworkType;
        this.poolingMode = poolingMode;
        this.normalizeResult = normalizeResult;
        this.modelMaxLength = modelMaxLength;
        this.queryPrefix = queryPrefix;
        this.passagePrefix = passagePrefix;
    }

    /**
     * Parses a config object from XContent.
     *
     * <p>The parser's current token must be {@code START_OBJECT}; this is enforced via
     * {@link XContentParserUtils#ensureExpectedToken}. Fields are read until the matching
     * {@code END_OBJECT}. Unknown fields, including any nested object or array value, are skipped.
     *
     * @param parser parser positioned on the config object's {@code START_OBJECT} token
     * @return the parsed config
     * @throws IOException if reading from the parser fails
     * @throws IllegalArgumentException if {@code embedding_dimension} or {@code framework_type} is
     *     missing, or {@code framework_type} / {@code pooling_mode} holds an unrecognized value
     */
    public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOException {
        String modelType = null;
        Integer embeddingDimension = null;
        FrameworkType frameworkType = null;
        String allConfig = null;
        PoolingMode poolingMode = null;
        boolean normalizeResult = false;
        Integer modelMaxLength = null;
        String queryPrefix = null;
        String passagePrefix = null;

        XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
        while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
            String fieldName = parser.currentName();
            parser.nextToken(); // move from the field name onto its value
            switch (fieldName) {
                case "model_type":
                    modelType = parser.text();
                    break;
                case EMBEDDING_DIMENSION_FIELD:
                    embeddingDimension = parser.intValue();
                    break;
                case FRAMEWORK_TYPE_FIELD:
                    // from() performs the Locale.ROOT upper-casing itself.
                    frameworkType = FrameworkType.from(parser.text());
                    break;
                case "all_config":
                    allConfig = parser.text();
                    break;
                case POOLING_MODE_FIELD:
                    poolingMode = PoolingMode.from(parser.text());
                    break;
                case NORMALIZE_RESULT_FIELD:
                    normalizeResult = parser.booleanValue();
                    break;
                case MODEL_MAX_LENGTH_FIELD:
                    modelMaxLength = parser.intValue();
                    break;
                case QUERY_PREFIX:
                    queryPrefix = parser.text();
                    break;
                case PASSAGE_PREFIX:
                    passagePrefix = parser.text();
                    break;
                default:
                    // Unknown field: skip its value (and any children) so parsing can continue.
                    parser.skipChildren();
                    break;
            }
        }
        return new TextEmbeddingModelConfig(modelType, embeddingDimension, frameworkType, allConfig, poolingMode, normalizeResult, modelMaxLength, queryPrefix, passagePrefix);
    }

    /** Returns {@link #PARSE_FIELD_NAME}. */
    public String getWriteableName() {
        return PARSE_FIELD_NAME;
    }

    /**
     * Reads a config from the stream format produced by {@link #writeTo(StreamOutput)}.
     *
     * @param in stream to read from
     * @throws IOException if reading fails
     */
    public TextEmbeddingModelConfig(StreamInput in) throws IOException {
        super(in);
        this.embeddingDimension = in.readInt();
        this.frameworkType = in.readEnum(FrameworkType.class);
        // Pooling mode is optional: a presence flag precedes the enum value.
        this.poolingMode = in.readBoolean() ? in.readEnum(PoolingMode.class) : null;
        this.normalizeResult = in.readBoolean();
        this.modelMaxLength = in.readOptionalInt();
        this.queryPrefix = in.readOptionalString();
        this.passagePrefix = in.readOptionalString();
    }

    /**
     * Writes this config in the same field order that
     * {@link #TextEmbeddingModelConfig(StreamInput)} reads it.
     *
     * @param out stream to write to
     * @throws IOException if writing fails
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);
        out.writeInt(embeddingDimension);
        out.writeEnum(frameworkType);
        boolean hasPoolingMode = poolingMode != null;
        out.writeBoolean(hasPoolingMode);
        if (hasPoolingMode) {
            out.writeEnum(poolingMode);
        }
        out.writeBoolean(normalizeResult);
        out.writeOptionalInt(modelMaxLength);
        out.writeOptionalString(queryPrefix);
        out.writeOptionalString(passagePrefix);
    }

    /**
     * Serializes this config as an XContent object.
     *
     * <p>Null-valued fields are omitted. {@code normalize_result} is written only when true,
     * which matches the {@code false} default used by {@link #parse(XContentParser)}.
     */
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        if (this.modelType != null) {
            builder.field("model_type", this.modelType);
        }
        if (embeddingDimension != null) {
            builder.field(EMBEDDING_DIMENSION_FIELD, embeddingDimension);
        }
        if (frameworkType != null) {
            builder.field(FRAMEWORK_TYPE_FIELD, frameworkType);
        }
        if (this.allConfig != null) {
            builder.field("all_config", this.allConfig);
        }
        if (modelMaxLength != null) {
            builder.field(MODEL_MAX_LENGTH_FIELD, modelMaxLength);
        }
        if (poolingMode != null) {
            builder.field(POOLING_MODE_FIELD, poolingMode);
        }
        if (normalizeResult) {
            builder.field(NORMALIZE_RESULT_FIELD, normalizeResult);
        }
        if (queryPrefix != null) {
            builder.field(QUERY_PREFIX, queryPrefix);
        }
        if (passagePrefix != null) {
            builder.field(PASSAGE_PREFIX, passagePrefix);
        }
        builder.endObject();
        return builder;
    }

    /** Returns the embedding dimension (never null after construction). */
    public int getEmbeddingDimension() {
        return embeddingDimension;
    }

    /** Returns the framework type (never null). */
    public FrameworkType getFrameworkType() {
        return frameworkType;
    }

    /** Returns the pooling mode, or null if none was configured. */
    public PoolingMode getPoolingMode() {
        return poolingMode;
    }

    /** Returns the normalize-result flag. */
    public boolean isNormalizeResult() {
        return normalizeResult;
    }

    /** Returns the maximum model length, or null if none was configured. */
    public Integer getModelMaxLength() {
        return modelMaxLength;
    }

    /** Returns the query prefix, or null if none was configured. */
    public String getQueryPrefix() {
        return queryPrefix;
    }

    /** Returns the passage prefix, or null if none was configured. */
    public String getPassagePrefix() {
        return passagePrefix;
    }

    /** Returns a new, empty {@link Builder}. */
    public static Builder builder() {
        return new Builder();
    }

    /** Returns a {@link Builder} pre-populated with every value of this config. */
    public Builder toBuilder() {
        return new Builder()
            .modelType(this.modelType)
            .embeddingDimension(embeddingDimension)
            .frameworkType(frameworkType)
            .allConfig(this.allConfig)
            .poolingMode(poolingMode)
            .normalizeResult(normalizeResult)
            .modelMaxLength(modelMaxLength)
            .queryPrefix(queryPrefix)
            .passagePrefix(passagePrefix);
    }

    /** Supported model framework types. */
    public static enum FrameworkType {
        HUGGINGFACE_TRANSFORMERS,
        SENTENCE_TRANSFORMERS,
        HUGGINGFACE_TRANSFORMERS_NEURON;

        /**
         * Resolves a framework type from its constant name, case-insensitively
         * (the input is upper-cased with {@link Locale#ROOT}).
         *
         * @param value constant name, in any case
         * @return the matching framework type
         * @throws IllegalArgumentException if {@code value} is null or matches no constant;
         *     the underlying failure is attached as the cause
         */
        public static FrameworkType from(String value) {
            try {
                return FrameworkType.valueOf(value.toUpperCase(Locale.ROOT));
            } catch (IllegalArgumentException | NullPointerException e) {
                throw new IllegalArgumentException("Wrong framework type", e);
            }
        }
    }

    /** Supported pooling modes, each paired with a lower-case name (see {@link #getName()}). */
    public static enum PoolingMode {
        MEAN("mean"),
        MEAN_SQRT_LEN("mean_sqrt_len"),
        MAX("max"),
        WEIGHTED_MEAN("weightedmean"),
        CLS("cls"),
        LAST_TOKEN("lasttoken");

        private final String name;

        private PoolingMode(String name) {
            this.name = name;
        }

        /** Returns this mode's lower-case name, e.g. {@code "weightedmean"}. */
        public String getName() {
            return name;
        }

        /**
         * Resolves a pooling mode case-insensitively, using {@link Locale#ROOT} upper-casing.
         *
         * <p>Accepts either the enum constant name (e.g. {@code "WEIGHTED_MEAN"}) or the name
         * returned by {@link #getName()} (e.g. {@code "weightedmean"}). Previously only the
         * constant name was accepted, so {@code from(WEIGHTED_MEAN.getName())} and
         * {@code from(LAST_TOKEN.getName())} failed.
         *
         * @param value constant name or {@link #getName()} value, in any case
         * @return the matching pooling mode
         * @throws IllegalArgumentException if {@code value} is null or matches no mode
         */
        public static PoolingMode from(String value) {
            if (value != null) {
                String normalized = value.toUpperCase(Locale.ROOT);
                for (PoolingMode mode : values()) {
                    if (mode.name().equals(normalized)
                        || mode.getName().toUpperCase(Locale.ROOT).equals(normalized)) {
                        return mode;
                    }
                }
            }
            throw new IllegalArgumentException("Wrong pooling method");
        }
    }

    /**
     * Fluent builder for {@link TextEmbeddingModelConfig}.
     *
     * <p>{@link #build()} delegates to the full constructor, so it throws
     * {@link IllegalArgumentException} if the embedding dimension or framework type was not set.
     */
    public static class Builder {
        private String modelType;
        private Integer embeddingDimension;
        private FrameworkType frameworkType;
        private String allConfig;
        private PoolingMode poolingMode;
        private boolean normalizeResult;
        private Integer modelMaxLength;
        private String queryPrefix;
        private String passagePrefix;

        public Builder modelType(String modelType) {
            this.modelType = modelType;
            return this;
        }

        public Builder embeddingDimension(int embeddingDimension) {
            this.embeddingDimension = embeddingDimension;
            return this;
        }

        public Builder frameworkType(FrameworkType frameworkType) {
            this.frameworkType = frameworkType;
            return this;
        }

        public Builder allConfig(String allConfig) {
            this.allConfig = allConfig;
            return this;
        }

        public Builder poolingMode(PoolingMode poolingMode) {
            this.poolingMode = poolingMode;
            return this;
        }

        public Builder normalizeResult(boolean normalizeResult) {
            this.normalizeResult = normalizeResult;
            return this;
        }

        public Builder modelMaxLength(Integer modelMaxLength) {
            this.modelMaxLength = modelMaxLength;
            return this;
        }

        public Builder queryPrefix(String queryPrefix) {
            this.queryPrefix = queryPrefix;
            return this;
        }

        public Builder passagePrefix(String passagePrefix) {
            this.passagePrefix = passagePrefix;
            return this;
        }

        /** Builds the config from the values set so far. */
        public TextEmbeddingModelConfig build() {
            return new TextEmbeddingModelConfig(modelType, embeddingDimension, frameworkType, allConfig, poolingMode, normalizeResult, modelMaxLength, queryPrefix, passagePrefix);
        }
    }
}

