/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.regression.baseline;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.DoubleProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.EnumProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.LongProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.time.OffsetDateTime;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.Trainer;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.baseline.DummyRegressionModel;
import org.tribuo.util.Util;

public final class DummyRegressionTrainer
implements Trainer<Regressor> {
    @Config(mandatory=true, description="Type of dummy regressor.")
    private DummyType dummyType;
    @Config(description="Constant value to use for the constant regressor.")
    private double constantValue = Double.NaN;
    @Config(description="Quartile to use.")
    private double quartile = Double.NaN;
    @Config(description="The seed for the RNG.")
    private long seed = 1L;
    private int trainInvocationCounter = 0;

    private DummyRegressionTrainer() {
    }

    public void postConfig() {
        if (this.dummyType == DummyType.CONSTANT && Double.isNaN(this.constantValue)) {
            throw new PropertyException("", "constantValue", "Please supply a constant value when using the type CONSTANT.");
        }
        if (this.dummyType == DummyType.QUARTILE && (this.quartile < 0.0 || this.quartile > 1.0)) {
            throw new PropertyException("", "quartile", "Please supply a quartile between zero and one when using the type QUARTILE.");
        }
    }

    public DummyRegressionModel train(Dataset<Regressor> examples, Map<String, Provenance> instanceProvenance) {
        return this.train((Dataset)examples, (Map)instanceProvenance, -1);
    }

    public DummyRegressionModel train(Dataset<Regressor> examples, Map<String, Provenance> instanceProvenance, int invocationCount) {
        int id;
        if (invocationCount != -1) {
            this.setInvocationCount(invocationCount);
        }
        ModelProvenance provenance = new ModelProvenance(DummyRegressionModel.class.getName(), OffsetDateTime.now(), (DatasetProvenance)examples.getProvenance(), this.getProvenance(), instanceProvenance);
        ++this.trainInvocationCounter;
        ImmutableOutputInfo outputInfo = examples.getOutputIDInfo();
        Set domain = outputInfo.getDomain();
        double[][] outputs = new double[outputInfo.size()][examples.size()];
        int i = 0;
        for (Example e : examples) {
            for (Regressor r : (Regressor)e.getOutput()) {
                id = outputInfo.getID((Output)r);
                outputs[id][i] = ((Regressor.DimensionTuple)r).getValue();
            }
            ++i;
        }
        switch (this.dummyType) {
            case CONSTANT: {
                Regressor.DimensionTuple[] output = new Regressor.DimensionTuple[outputs.length];
                for (Regressor r : domain) {
                    id = outputInfo.getID((Output)r);
                    output[id] = new Regressor.DimensionTuple(r.getNames()[0], this.constantValue);
                }
                Regressor regressor = new Regressor(output);
                return new DummyRegressionModel(provenance, examples.getFeatureIDMap(), (ImmutableOutputInfo<Regressor>)outputInfo, this.dummyType, regressor);
            }
            case MEAN: {
                Regressor.DimensionTuple[] output = new Regressor.DimensionTuple[outputs.length];
                for (Regressor r : domain) {
                    id = outputInfo.getID((Output)r);
                    output[id] = new Regressor.DimensionTuple(r.getNames()[0], Util.mean((double[])outputs[id]));
                }
                Regressor regressor = new Regressor(output);
                return new DummyRegressionModel(provenance, examples.getFeatureIDMap(), (ImmutableOutputInfo<Regressor>)outputInfo, this.dummyType, regressor);
            }
            case MEDIAN: {
                Regressor.DimensionTuple[] output = new Regressor.DimensionTuple[outputs.length];
                for (Regressor r : domain) {
                    id = outputInfo.getID((Output)r);
                    Arrays.sort(outputs[id]);
                    output[id] = new Regressor.DimensionTuple(r.getNames()[0], outputs[id][outputs[id].length / 2]);
                }
                Regressor regressor = new Regressor(output);
                return new DummyRegressionModel(provenance, examples.getFeatureIDMap(), (ImmutableOutputInfo<Regressor>)outputInfo, this.dummyType, regressor);
            }
            case QUARTILE: {
                Regressor.DimensionTuple[] output = new Regressor.DimensionTuple[outputs.length];
                for (Regressor r : domain) {
                    id = outputInfo.getID((Output)r);
                    Arrays.sort(outputs[id]);
                    output[id] = new Regressor.DimensionTuple(r.getNames()[0], outputs[id][(int)(this.quartile * (double)outputs[id].length)]);
                }
                Regressor regressor = new Regressor(output);
                return new DummyRegressionModel(provenance, examples.getFeatureIDMap(), (ImmutableOutputInfo<Regressor>)outputInfo, this.dummyType, regressor);
            }
            case GAUSSIAN: {
                double[] means = new double[outputs.length];
                double[] variances = new double[outputs.length];
                String[] names = new String[outputs.length];
                for (Regressor r : domain) {
                    int id2 = outputInfo.getID((Output)r);
                    names[id2] = r.getNames()[0];
                    Pair meanVariance = Util.meanAndVariance((double[])outputs[id2]);
                    means[id2] = (Double)meanVariance.getA();
                    variances[id2] = (Double)meanVariance.getB();
                }
                return new DummyRegressionModel(provenance, examples.getFeatureIDMap(), (ImmutableOutputInfo<Regressor>)outputInfo, this.seed, means, variances, names);
            }
        }
        throw new IllegalStateException("Unknown dummyType " + (Object)((Object)this.dummyType));
    }

    public String toString() {
        switch (this.dummyType) {
            case CONSTANT: {
                return "DummyRegressionTrainer(dummyType=CONSTANT,constantValue=" + this.constantValue + ")";
            }
            case MEAN: {
                return "DummyRegressionTrainer(dummyType=MEAN)";
            }
            case MEDIAN: {
                return "DummyRegressionTrainer(dummyType=MEDIAN)";
            }
            case QUARTILE: {
                return "DummyRegressionTrainer(dummyType=QUARTILE,quartile=" + this.quartile + ")";
            }
            case GAUSSIAN: {
                return "DummyRegressionTrainer(dummyType=GAUSSIAN,seed=" + this.seed + ")";
            }
        }
        return "DummyRegressionTrainer(dummyType=" + (Object)((Object)this.dummyType) + ")";
    }

    public int getInvocationCount() {
        return this.trainInvocationCounter;
    }

    public synchronized void setInvocationCount(int invocationCount) {
        if (invocationCount < 0) {
            throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
        }
        this.trainInvocationCounter = invocationCount;
    }

    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl((Trainer)this);
    }

    public static DummyRegressionTrainer createConstantTrainer(double value) {
        DummyRegressionTrainer trainer = new DummyRegressionTrainer();
        trainer.dummyType = DummyType.CONSTANT;
        trainer.constantValue = value;
        return trainer;
    }

    public static DummyRegressionTrainer createGaussianTrainer(long seed) {
        DummyRegressionTrainer trainer = new DummyRegressionTrainer();
        trainer.dummyType = DummyType.GAUSSIAN;
        trainer.seed = seed;
        return trainer;
    }

    public static DummyRegressionTrainer createMeanTrainer() {
        DummyRegressionTrainer trainer = new DummyRegressionTrainer();
        trainer.dummyType = DummyType.MEAN;
        return trainer;
    }

    public static DummyRegressionTrainer createMedianTrainer() {
        DummyRegressionTrainer trainer = new DummyRegressionTrainer();
        trainer.dummyType = DummyType.MEDIAN;
        return trainer;
    }

    public static DummyRegressionTrainer createQuartileTrainer(double value) {
        if (Double.isNaN(value) || value < 0.0 || value > 1.0) {
            throw new IllegalArgumentException("Please provide an appropriate value between 0.0 and 1.0, found " + value);
        }
        DummyRegressionTrainer trainer = new DummyRegressionTrainer();
        trainer.dummyType = DummyType.QUARTILE;
        trainer.quartile = value;
        return trainer;
    }

    public static enum DummyType {
        MEAN,
        MEDIAN,
        QUARTILE,
        CONSTANT,
        GAUSSIAN;

    }

    @Deprecated
    public static final class DummyRegressionTrainerProvenance
    implements TrainerProvenance {
        private static final long serialVersionUID = 1L;
        private final String className;
        private final DummyType dummyType;
        private final long seed;
        private final double constantValue;
        private final double quartile;

        public DummyRegressionTrainerProvenance(DummyRegressionTrainer host) {
            this.className = host.getClass().getName();
            this.dummyType = host.dummyType;
            this.seed = host.seed;
            this.constantValue = host.constantValue;
            this.quartile = host.quartile;
        }

        public DummyRegressionTrainerProvenance(Map<String, Provenance> map) {
            this.className = ((StringProvenance)ObjectProvenance.checkAndExtractProvenance(map, (String)"class-name", StringProvenance.class, (String)DummyRegressionTrainerProvenance.class.getSimpleName())).getValue();
            this.dummyType = (DummyType)((EnumProvenance)ObjectProvenance.checkAndExtractProvenance(map, (String)"dummyType", EnumProvenance.class, (String)DummyRegressionTrainerProvenance.class.getSimpleName())).getValue();
            this.seed = ((LongProvenance)ObjectProvenance.checkAndExtractProvenance(map, (String)"seed", LongProvenance.class, (String)DummyRegressionTrainerProvenance.class.getSimpleName())).getValue();
            this.constantValue = ((DoubleProvenance)ObjectProvenance.checkAndExtractProvenance(map, (String)"constantValue", DoubleProvenance.class, (String)DummyRegressionTrainerProvenance.class.getSimpleName())).getValue();
            this.quartile = ((DoubleProvenance)ObjectProvenance.checkAndExtractProvenance(map, (String)"quartile", DoubleProvenance.class, (String)DummyRegressionTrainerProvenance.class.getSimpleName())).getValue();
        }

        public Map<String, Provenance> getConfiguredParameters() {
            HashMap<String, Provenance> map = new HashMap<String, Provenance>();
            map.put("dummyType", (Provenance)new EnumProvenance("dummyType", (Enum)this.dummyType));
            map.put("constantValue", (Provenance)new DoubleProvenance("constantValue", this.constantValue));
            map.put("quartile", (Provenance)new DoubleProvenance("quartile", this.quartile));
            map.put("seed", (Provenance)new LongProvenance("seed", this.seed));
            return map;
        }

        public String getClassName() {
            return this.className;
        }

        public String toString() {
            return this.generateString("Trainer");
        }

        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            DummyRegressionTrainerProvenance pairs = (DummyRegressionTrainerProvenance)o;
            return this.seed == pairs.seed && Double.compare(pairs.constantValue, this.constantValue) == 0 && Double.compare(pairs.quartile, this.quartile) == 0 && this.className.equals(pairs.className) && this.dummyType == pairs.dummyType;
        }

        public int hashCode() {
            return Objects.hash(new Object[]{this.className, this.dummyType, this.seed, this.constantValue, this.quartile});
        }
    }
}

