/*
 * Decompiled with CFR 0.152.
 */
package io.skylite.ml.common.input.parameter.regression;

import io.skylite.core.ParseField;
import io.skylite.core.common.io.stream.StreamInput;
import io.skylite.core.common.io.stream.StreamOutput;
import io.skylite.core.xcontent.NamedXContentRegistry;
import io.skylite.core.xcontent.ToXContent;
import io.skylite.core.xcontent.XContentBuilder;
import io.skylite.core.xcontent.XContentParser;
import io.skylite.core.xcontent.XContentParserUtils;
import io.skylite.ml.common.FunctionName;
import io.skylite.ml.common.annotation.MLAlgoParameter;
import io.skylite.ml.common.input.parameter.MLAlgoParams;
import java.io.IOException;
import java.util.Locale;
import java.util.Objects;

/**
 * Hyper-parameters for the {@link FunctionName#LOGISTIC_REGRESSION} algorithm.
 *
 * <p>Every parameter is optional. A {@code null} field means "not set":
 * <ul>
 *   <li>it is omitted from the XContent representation, and</li>
 *   <li>it is written as an absent optional value on the transport stream.</li>
 * </ul>
 * NOTE(review): whatever default applies to an unset parameter is decided by the consuming
 * training code, not by this class — confirm there.
 *
 * <p>Instances are mutable (setters are public) and perform no synchronization.
 */
@MLAlgoParameter(algorithms = {FunctionName.LOGISTIC_REGRESSION})
public class LogisticRegressionParams
implements MLAlgoParams {
    /** Name under which these params are registered, both for XContent parsing and as the writeable name. */
    public static final String PARSE_FIELD_NAME = FunctionName.LOGISTIC_REGRESSION.name();

    /** Registry entry that maps {@link #PARSE_FIELD_NAME} to {@link #parse(XContentParser)}. */
    public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry(
            MLAlgoParams.class, new ParseField(PARSE_FIELD_NAME), LogisticRegressionParams::parse);

    // XContent keys. These are part of the external format and are shared by parse() and toXContent().
    public static final String OBJECTIVE_FIELD = "objective";
    /** Note the British spelling: the key is "optimiser" while the Java property is optimizerType. */
    public static final String OPTIMISER_FIELD = "optimiser";
    public static final String MOMENTUM_TYPE_FIELD = "momentum_type";
    public static final String LEARNING_RATE_FIELD = "learning_rate";
    public static final String EPSILON_FIELD = "epsilon";
    public static final String MOMENTUM_FACTOR_FIELD = "momentum_factor";
    public static final String BETA1_FIELD = "beta1";
    public static final String BETA2_FIELD = "beta2";
    public static final String DECAY_RATE_FIELD = "decay_rate";
    public static final String EPOCHS_FIELD = "epochs";
    public static final String BATCH_SIZE_FIELD = "batch_size";
    public static final String LOGGING_INTERVAL_FIELD = "logging_interval";
    public static final String SEED_FIELD = "seed";
    public static final String TARGET_FIELD = "target";

    // All fields are nullable; null means "not set".
    private ObjectiveType objectiveType;
    private OptimizerType optimizerType;
    private MomentumType momentumType;
    private Double learningRate;
    private Double epsilon;
    private Double momentumFactor;
    private Double beta1;
    private Double beta2;
    private Double decayRate;
    private Integer epochs;
    private Integer batchSize;
    private Integer loggingInterval;
    private Long seed;
    private String target;

    /**
     * Creates a parameter set from explicit values.
     * Any argument may be {@code null} to leave that parameter unset.
     */
    public LogisticRegressionParams(ObjectiveType objectiveType, OptimizerType optimizerType, MomentumType momentumType, Double learningRate, Double epsilon, Double momentumFactor, Double beta1, Double beta2, Double decayRate, Integer epochs, Integer batchSize, Integer loggingInterval, Long seed, String target) {
        this.objectiveType = objectiveType;
        this.optimizerType = optimizerType;
        this.momentumType = momentumType;
        this.learningRate = learningRate;
        this.epsilon = epsilon;
        this.momentumFactor = momentumFactor;
        this.beta1 = beta1;
        this.beta2 = beta2;
        this.decayRate = decayRate;
        this.epochs = epochs;
        this.batchSize = batchSize;
        this.loggingInterval = loggingInterval;
        this.seed = seed;
        this.target = target;
    }

    /**
     * Deserializes a parameter set from the transport stream.
     * The read order must mirror {@link #writeTo(StreamOutput)} exactly.
     *
     * @param in stream positioned at the start of a serialized instance
     * @throws IOException if the stream cannot be read
     */
    public LogisticRegressionParams(StreamInput in) throws IOException {
        this.objectiveType = readOptionalEnumValue(in, ObjectiveType.class);
        this.optimizerType = readOptionalEnumValue(in, OptimizerType.class);
        this.momentumType = readOptionalEnumValue(in, MomentumType.class);
        this.learningRate = in.readOptionalDouble();
        this.epsilon = in.readOptionalDouble();
        this.momentumFactor = in.readOptionalDouble();
        this.beta1 = in.readOptionalDouble();
        this.beta2 = in.readOptionalDouble();
        this.decayRate = in.readOptionalDouble();
        this.epochs = in.readOptionalInt();
        this.batchSize = in.readOptionalInt();
        this.loggingInterval = in.readOptionalInt();
        this.seed = in.readOptionalLong();
        this.target = in.readOptionalString();
    }

    // ---- Accessors: plain getters/setters. A null value means the parameter is unset. ----

    public ObjectiveType getObjectiveType() {
        return this.objectiveType;
    }

    public void setObjectiveType(ObjectiveType objectiveType) {
        this.objectiveType = objectiveType;
    }

    public OptimizerType getOptimizerType() {
        return this.optimizerType;
    }

    public void setOptimizerType(OptimizerType optimizerType) {
        this.optimizerType = optimizerType;
    }

    public MomentumType getMomentumType() {
        return this.momentumType;
    }

    public void setMomentumType(MomentumType momentumType) {
        this.momentumType = momentumType;
    }

    public Double getLearningRate() {
        return this.learningRate;
    }

    public void setLearningRate(Double learningRate) {
        this.learningRate = learningRate;
    }

    public Double getEpsilon() {
        return this.epsilon;
    }

    public void setEpsilon(Double epsilon) {
        this.epsilon = epsilon;
    }

    public Double getMomentumFactor() {
        return this.momentumFactor;
    }

    public void setMomentumFactor(Double momentumFactor) {
        this.momentumFactor = momentumFactor;
    }

    public Double getBeta1() {
        return this.beta1;
    }

    public void setBeta1(Double beta1) {
        this.beta1 = beta1;
    }

    public Double getBeta2() {
        return this.beta2;
    }

    public void setBeta2(Double beta2) {
        this.beta2 = beta2;
    }

    public Double getDecayRate() {
        return this.decayRate;
    }

    public void setDecayRate(Double decayRate) {
        this.decayRate = decayRate;
    }

    public Integer getEpochs() {
        return this.epochs;
    }

    public void setEpochs(Integer epochs) {
        this.epochs = epochs;
    }

    public Integer getBatchSize() {
        return this.batchSize;
    }

    public void setBatchSize(Integer batchSize) {
        this.batchSize = batchSize;
    }

    public Integer getLoggingInterval() {
        return this.loggingInterval;
    }

    public void setLoggingInterval(Integer loggingInterval) {
        this.loggingInterval = loggingInterval;
    }

    public Long getSeed() {
        return this.seed;
    }

    public void setSeed(Long seed) {
        this.seed = seed;
    }

    public String getTarget() {
        return this.target;
    }

    public void setTarget(String target) {
        this.target = target;
    }

    /** Returns a new, empty {@link Builder}. */
    public static Builder builder() {
        return new Builder();
    }

    /** Returns a {@link Builder} pre-populated with every value of this instance. */
    public Builder toBuilder() {
        return new Builder()
                .objectiveType(this.objectiveType)
                .optimizerType(this.optimizerType)
                .momentumType(this.momentumType)
                .learningRate(this.learningRate)
                .epsilon(this.epsilon)
                .momentumFactor(this.momentumFactor)
                .beta1(this.beta1)
                .beta2(this.beta2)
                .decayRate(this.decayRate)
                .epochs(this.epochs)
                .batchSize(this.batchSize)
                .loggingInterval(this.loggingInterval)
                .seed(this.seed)
                .target(this.target);
    }

    /**
     * Parses a parameter set from an XContent object.
     *
     * <p>The parser must be positioned on the object's {@code START_OBJECT} token. Enum values
     * are matched case-insensitively (upper-cased with {@link Locale#ROOT}). Unknown fields are
     * skipped, including any nested content.
     *
     * @param parser parser positioned on {@code START_OBJECT}
     * @return the parsed parameters; fields absent from the input are {@code null}
     * @throws IOException if reading from the parser fails
     * @throws IllegalArgumentException if an enum-valued field holds an unknown constant
     */
    public static MLAlgoParams parse(XContentParser parser) throws IOException {
        ObjectiveType objective = null;
        OptimizerType optimizerType = null;
        MomentumType momentumType = null;
        Double learningRate = null;
        Double epsilon = null;
        Double momentumFactor = null;
        Double beta1 = null;
        Double beta2 = null;
        Double decayRate = null;
        Integer epochs = null;
        Integer batchSize = null;
        Integer loggingInterval = null;
        Long seed = null;
        String target = null;
        XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
        while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
            String fieldName = parser.currentName();
            // Advance from the field name to its value.
            parser.nextToken();
            switch (fieldName) {
                case OBJECTIVE_FIELD:
                    // Route through from() so bad input gets the enum's own "Wrong ..." message.
                    objective = ObjectiveType.from(parser.text().toUpperCase(Locale.ROOT));
                    break;
                case OPTIMISER_FIELD:
                    optimizerType = OptimizerType.from(parser.text().toUpperCase(Locale.ROOT));
                    break;
                case MOMENTUM_TYPE_FIELD:
                    momentumType = MomentumType.from(parser.text().toUpperCase(Locale.ROOT));
                    break;
                case LEARNING_RATE_FIELD:
                    learningRate = parser.doubleValue(false);
                    break;
                case EPSILON_FIELD:
                    epsilon = parser.doubleValue(false);
                    break;
                case MOMENTUM_FACTOR_FIELD:
                    momentumFactor = parser.doubleValue(false);
                    break;
                case BETA1_FIELD:
                    beta1 = parser.doubleValue(false);
                    break;
                case BETA2_FIELD:
                    beta2 = parser.doubleValue(false);
                    break;
                case DECAY_RATE_FIELD:
                    decayRate = parser.doubleValue(false);
                    break;
                case EPOCHS_FIELD:
                    epochs = parser.intValue(false);
                    break;
                case BATCH_SIZE_FIELD:
                    batchSize = parser.intValue(false);
                    break;
                case LOGGING_INTERVAL_FIELD:
                    loggingInterval = parser.intValue(false);
                    break;
                case SEED_FIELD:
                    seed = parser.longValue(false);
                    break;
                case TARGET_FIELD:
                    target = parser.text();
                    break;
                default:
                    // Unknown field: skip its value (and any nested object/array) entirely.
                    parser.skipChildren();
                    break;
            }
        }
        return new LogisticRegressionParams(objective, optimizerType, momentumType, learningRate, epsilon, momentumFactor, beta1, beta2, decayRate, epochs, batchSize, loggingInterval, seed, target);
    }

    /**
     * Serializes this instance to the transport stream.
     * The write order must stay in sync with {@link #LogisticRegressionParams(StreamInput)}.
     *
     * @param out destination stream
     * @throws IOException if writing fails
     */
    public void writeTo(StreamOutput out) throws IOException {
        writeOptionalEnumValue(out, this.objectiveType);
        writeOptionalEnumValue(out, this.optimizerType);
        writeOptionalEnumValue(out, this.momentumType);
        out.writeOptionalDouble(this.learningRate);
        out.writeOptionalDouble(this.epsilon);
        out.writeOptionalDouble(this.momentumFactor);
        out.writeOptionalDouble(this.beta1);
        out.writeOptionalDouble(this.beta2);
        out.writeOptionalDouble(this.decayRate);
        out.writeOptionalInt(this.epochs);
        out.writeOptionalInt(this.batchSize);
        out.writeOptionalInt(this.loggingInterval);
        out.writeOptionalLong(this.seed);
        out.writeOptionalString(this.target);
    }

    /** Writes a presence flag followed, when present, by the enum via {@code writeEnum}. */
    private static <E extends Enum<E>> void writeOptionalEnumValue(StreamOutput out, E value) throws IOException {
        if (value == null) {
            out.writeBoolean(false);
        } else {
            out.writeBoolean(true);
            out.writeEnum(value);
        }
    }

    /** Reads what {@link #writeOptionalEnumValue} wrote: {@code null} when the presence flag is false. */
    private static <E extends Enum<E>> E readOptionalEnumValue(StreamInput in, Class<E> enumClass) throws IOException {
        return in.readBoolean() ? in.readEnum(enumClass) : null;
    }

    /**
     * Renders this instance as an XContent object.
     * Only parameters that are set (non-null) are emitted.
     */
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        if (this.objectiveType != null) {
            builder.field(OBJECTIVE_FIELD, this.objectiveType);
        }
        if (this.optimizerType != null) {
            builder.field(OPTIMISER_FIELD, this.optimizerType);
        }
        if (this.momentumType != null) {
            builder.field(MOMENTUM_TYPE_FIELD, this.momentumType);
        }
        if (this.learningRate != null) {
            builder.field(LEARNING_RATE_FIELD, this.learningRate);
        }
        if (this.epsilon != null) {
            builder.field(EPSILON_FIELD, this.epsilon);
        }
        if (this.momentumFactor != null) {
            builder.field(MOMENTUM_FACTOR_FIELD, this.momentumFactor);
        }
        if (this.beta1 != null) {
            builder.field(BETA1_FIELD, this.beta1);
        }
        if (this.beta2 != null) {
            builder.field(BETA2_FIELD, this.beta2);
        }
        if (this.decayRate != null) {
            builder.field(DECAY_RATE_FIELD, this.decayRate);
        }
        if (this.epochs != null) {
            builder.field(EPOCHS_FIELD, this.epochs);
        }
        if (this.batchSize != null) {
            builder.field(BATCH_SIZE_FIELD, this.batchSize);
        }
        if (this.loggingInterval != null) {
            builder.field(LOGGING_INTERVAL_FIELD, this.loggingInterval);
        }
        if (this.seed != null) {
            builder.field(SEED_FIELD, this.seed);
        }
        if (this.target != null) {
            builder.field(TARGET_FIELD, this.target);
        }
        builder.endObject();
        return builder;
    }

    /** Returns {@link #PARSE_FIELD_NAME}. */
    public String getWriteableName() {
        return PARSE_FIELD_NAME;
    }

    /** Version of this parameter format; hard-coded to 1. */
    @Override
    public int getVersion() {
        return 1;
    }

    @Override
    public String toString() {
        return "LogisticRegressionParams{objectiveType=" + this.objectiveType
                + ", optimizerType=" + this.optimizerType
                + ", momentumType=" + this.momentumType
                + ", learningRate=" + this.learningRate
                + ", epsilon=" + this.epsilon
                + ", momentumFactor=" + this.momentumFactor
                + ", beta1=" + this.beta1
                + ", beta2=" + this.beta2
                + ", decayRate=" + this.decayRate
                + ", epochs=" + this.epochs
                + ", batchSize=" + this.batchSize
                + ", loggingInterval=" + this.loggingInterval
                + ", seed=" + this.seed
                + ", target='" + this.target + "'}";
    }

    /** Value equality over all fourteen parameters; requires the exact same runtime class. */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        LogisticRegressionParams that = (LogisticRegressionParams) o;
        return this.objectiveType == that.objectiveType
                && this.optimizerType == that.optimizerType
                && this.momentumType == that.momentumType
                && Objects.equals(this.learningRate, that.learningRate)
                && Objects.equals(this.epsilon, that.epsilon)
                && Objects.equals(this.momentumFactor, that.momentumFactor)
                && Objects.equals(this.beta1, that.beta1)
                && Objects.equals(this.beta2, that.beta2)
                && Objects.equals(this.decayRate, that.decayRate)
                && Objects.equals(this.epochs, that.epochs)
                && Objects.equals(this.batchSize, that.batchSize)
                && Objects.equals(this.loggingInterval, that.loggingInterval)
                && Objects.equals(this.seed, that.seed)
                && Objects.equals(this.target, that.target);
    }

    @Override
    public int hashCode() {
        return Objects.hash(this.objectiveType, this.optimizerType, this.momentumType, this.learningRate, this.epsilon, this.momentumFactor, this.beta1, this.beta2, this.decayRate, this.epochs, this.batchSize, this.loggingInterval, this.seed, this.target);
    }

    /** Loss objective selectable for training. */
    public static enum ObjectiveType {
        HINGE,
        LOGMULTICLASS;

        /**
         * Looks up a constant by its exact (case-sensitive) name.
         *
         * @throws IllegalArgumentException if {@code value} is null or names no constant;
         *         the original lookup failure is kept as the cause
         */
        public static ObjectiveType from(String value) {
            try {
                return ObjectiveType.valueOf(value);
            } catch (IllegalArgumentException | NullPointerException e) {
                throw new IllegalArgumentException("Wrong objective type: " + value, e);
            }
        }
    }

    /** Optimizer algorithm selectable for training. */
    public static enum OptimizerType {
        SIMPLE_SGD,
        LINEAR_DECAY_SGD,
        SQRT_DECAY_SGD,
        ADA_GRAD,
        ADA_DELTA,
        ADAM,
        RMS_PROP;

        /**
         * Looks up a constant by its exact (case-sensitive) name.
         *
         * @throws IllegalArgumentException if {@code value} is null or names no constant;
         *         the original lookup failure is kept as the cause
         */
        public static OptimizerType from(String value) {
            try {
                return OptimizerType.valueOf(value);
            } catch (IllegalArgumentException | NullPointerException e) {
                throw new IllegalArgumentException("Wrong optimizer type: " + value, e);
            }
        }
    }

    /** Momentum variant selectable for training. */
    public static enum MomentumType {
        STANDARD,
        NESTEROV;

        /**
         * Looks up a constant by its exact (case-sensitive) name.
         *
         * @throws IllegalArgumentException if {@code value} is null or names no constant;
         *         the original lookup failure is kept as the cause
         */
        public static MomentumType from(String value) {
            try {
                return MomentumType.valueOf(value);
            } catch (IllegalArgumentException | NullPointerException e) {
                throw new IllegalArgumentException("Wrong momentum type: " + value, e);
            }
        }
    }

    /**
     * Fluent builder for {@link LogisticRegressionParams}.
     * Any value that is never set stays {@code null} (unset).
     */
    public static class Builder {
        private ObjectiveType objectiveType;
        private OptimizerType optimizerType;
        private MomentumType momentumType;
        private Double learningRate;
        private Double epsilon;
        private Double momentumFactor;
        private Double beta1;
        private Double beta2;
        private Double decayRate;
        private Integer epochs;
        private Integer batchSize;
        private Integer loggingInterval;
        private Long seed;
        private String target;

        /** Creates a new instance from the values collected so far. */
        public LogisticRegressionParams build() {
            return new LogisticRegressionParams(this.objectiveType, this.optimizerType, this.momentumType, this.learningRate, this.epsilon, this.momentumFactor, this.beta1, this.beta2, this.decayRate, this.epochs, this.batchSize, this.loggingInterval, this.seed, this.target);
        }

        public Builder objectiveType(ObjectiveType objectiveType) {
            this.objectiveType = objectiveType;
            return this;
        }

        public Builder optimizerType(OptimizerType optimizerType) {
            this.optimizerType = optimizerType;
            return this;
        }

        public Builder momentumType(MomentumType momentumType) {
            this.momentumType = momentumType;
            return this;
        }

        public Builder learningRate(Double learningRate) {
            this.learningRate = learningRate;
            return this;
        }

        public Builder epsilon(Double epsilon) {
            this.epsilon = epsilon;
            return this;
        }

        public Builder momentumFactor(Double momentumFactor) {
            this.momentumFactor = momentumFactor;
            return this;
        }

        public Builder beta1(Double beta1) {
            this.beta1 = beta1;
            return this;
        }

        public Builder beta2(Double beta2) {
            this.beta2 = beta2;
            return this;
        }

        public Builder decayRate(Double decayRate) {
            this.decayRate = decayRate;
            return this;
        }

        public Builder epochs(Integer epochs) {
            this.epochs = epochs;
            return this;
        }

        public Builder batchSize(Integer batchSize) {
            this.batchSize = batchSize;
            return this;
        }

        public Builder loggingInterval(Integer loggingInterval) {
            this.loggingInterval = loggingInterval;
            return this;
        }

        public Builder seed(Long seed) {
            this.seed = seed;
            return this;
        }

        public Builder target(String target) {
            this.target = target;
            return this;
        }
    }
}

