/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.regression.baseline;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import java.util.stream.Collectors;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.protos.core.ModelDataProto;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.protos.core.OutputProto;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.baseline.DummyRegressionTrainer;
import org.tribuo.regression.protos.DummyRegressionModelProto;
import org.tribuo.util.Util;

/**
 * A baseline regression model whose predictions do not depend on the example's features.
 * <p>
 * For the {@code CONSTANT}, {@code MEAN}, {@code MEDIAN} and {@code QUARTILE} dummy types every
 * prediction is the single stored {@link Regressor}. For {@code GAUSSIAN} each output dimension is
 * sampled independently from the stored per-dimension means and variances, using a {@link Random}
 * seeded with the stored seed.
 */
public class DummyRegressionModel extends Model<Regressor> {
    private static final long serialVersionUID = 2L;

    /** Protobuf format version written by {@link #serialize()} and accepted by {@link #deserializeFromProto}. */
    public static final int CURRENT_VERSION = 0;

    /** Which baseline strategy this model implements. */
    private final DummyRegressionTrainer.DummyType dummyType;
    /** The fixed prediction for non-GAUSSIAN types; {@code null} for GAUSSIAN models. */
    private final Regressor output;
    /** Seed used to initialise {@link #rng}; also written to the protobuf. */
    private final long seed;
    /** Sampler for GAUSSIAN predictions; {@code null} when built by the constant-type constructor. */
    private final Random rng;
    /** Per-dimension means for GAUSSIAN sampling (empty for constant types). */
    private final double[] means;
    /** Per-dimension values used as the scale of GAUSSIAN sampling (empty for constant types). */
    private final double[] variances;
    /** Output dimension names, parallel to {@link #means} and {@link #variances}. */
    private final String[] dimensionNames;

    /**
     * Builds a GAUSSIAN dummy model named {@code "dummy-GAUSSIAN-regression"}.
     *
     * @param description  The model provenance.
     * @param featureIDMap The feature domain.
     * @param outputIDInfo The output domain.
     * @param seed         The RNG seed for sampling.
     * @param means        Per-dimension means (copied).
     * @param variances    Per-dimension variances (copied).
     * @param names        Per-dimension names (copied).
     */
    DummyRegressionModel(ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> outputIDInfo, long seed, double[] means, double[] variances, String[] names) {
        super("dummy-GAUSSIAN-regression", description, featureIDMap, outputIDInfo, false);
        this.dummyType = DummyRegressionTrainer.DummyType.GAUSSIAN;
        this.output = null;
        this.seed = seed;
        this.rng = new Random(seed);
        this.means = Arrays.copyOf(means, means.length);
        this.variances = Arrays.copyOf(variances, variances.length);
        this.dimensionNames = Arrays.copyOf(names, names.length);
    }

    /**
     * Builds a dummy model which always predicts {@code regressor}. It is named
     * {@code "dummy-<dummyType>-regression"}.
     *
     * @param description  The model provenance.
     * @param featureIDMap The feature domain.
     * @param outputIDInfo The output domain.
     * @param dummyType    The (non-GAUSSIAN) dummy type.
     * @param regressor    The constant prediction.
     */
    DummyRegressionModel(ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> outputIDInfo, DummyRegressionTrainer.DummyType dummyType, Regressor regressor) {
        super("dummy-" + dummyType + "-regression", description, featureIDMap, outputIDInfo, false);
        this.dummyType = dummyType;
        this.output = regressor;
        this.seed = 12345L;
        this.rng = null;
        this.means = new double[0];
        this.variances = new double[0];
        this.dimensionNames = new String[0];
    }

    /**
     * All-fields constructor used by deserialization and {@link #copy}.
     * It always creates a fresh {@link Random} from {@code seed} and defensively copies the arrays.
     */
    private DummyRegressionModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureMap, ImmutableOutputInfo<Regressor> outputInfo, DummyRegressionTrainer.DummyType dummyType, Regressor regressor, long seed, double[] means, double[] variances, String[] dimensionNames) {
        super(name, provenance, featureMap, outputInfo, false);
        this.dummyType = dummyType;
        this.output = regressor;
        this.seed = seed;
        this.rng = new Random(seed);
        this.means = Arrays.copyOf(means, means.length);
        this.variances = Arrays.copyOf(variances, variances.length);
        this.dimensionNames = Arrays.copyOf(dimensionNames, dimensionNames.length);
    }

    /**
     * Deserializes a {@code DummyRegressionModel} from its protobuf form.
     *
     * @param version   The serialized version; must be between 0 and {@link #CURRENT_VERSION}.
     * @param className The class name recorded in the protobuf (not inspected here).
     * @param message   The packed {@link DummyRegressionModelProto}.
     * @return The deserialized model.
     * @throws InvalidProtocolBufferException If {@code message} does not contain a {@link DummyRegressionModelProto}.
     * @throws IllegalArgumentException       If {@code version} is unsupported.
     * @throws IllegalStateException          If the protobuf is not a consistent regression model.
     */
    public static DummyRegressionModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        DummyRegressionModelProto proto = message.unpack(DummyRegressionModelProto.class);
        ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());
        // Guard against a null first output (e.g. an empty domain) so it reports as an invalid proto, not an NPE.
        Output<?> firstOutput = carrier.outputDomain().getOutput(0);
        if (firstOutput == null || !firstOutput.getClass().equals(Regressor.class)) {
            throw new IllegalStateException("Invalid protobuf, output domain is not a regression domain, found " + carrier.outputDomain().getClass());
        }
        @SuppressWarnings("unchecked") // safe: the domain's outputs were checked to be Regressors above
        ImmutableOutputInfo<Regressor> outputDomain = (ImmutableOutputInfo<Regressor>) carrier.outputDomain();

        DummyRegressionTrainer.DummyType dummyType = DummyRegressionTrainer.DummyType.valueOf(proto.getDummyType());
        Regressor constantRegressor = null;
        if (dummyType != DummyRegressionTrainer.DummyType.GAUSSIAN) {
            Output<?> output = Output.deserialize(proto.getOutput());
            if (!(output instanceof Regressor)) {
                throw new IllegalStateException("Invalid protobuf, expected a Regressor, found " + output.getClass());
            }
            constantRegressor = (Regressor) output;
        }

        long seed = proto.getSeed();
        double[] means = Util.toPrimitiveDouble(proto.getMeansList());
        double[] variances = Util.toPrimitiveDouble(proto.getVariancesList());
        String[] dimensionNames = proto.getDimensionNamesList().toArray(new String[0]);
        // predict() indexes all three arrays in lockstep, so reject inconsistent GAUSSIAN payloads now
        // rather than failing later with an ArrayIndexOutOfBoundsException.
        if (dummyType == DummyRegressionTrainer.DummyType.GAUSSIAN
                && (means.length != dimensionNames.length || variances.length != dimensionNames.length)) {
            throw new IllegalStateException("Invalid protobuf, GAUSSIAN model requires one mean and one variance per dimension, found "
                    + means.length + " means, " + variances.length + " variances and " + dimensionNames.length + " dimension names");
        }
        return new DummyRegressionModel(carrier.name(), carrier.provenance(), carrier.featureDomain(), outputDomain, dummyType, constantRegressor, seed, means, variances, dimensionNames);
    }

    /**
     * Produces a prediction without reading the example's features.
     * <p>
     * Constant types return the stored regressor. GAUSSIAN draws a fresh sample on each call, which
     * advances the shared {@link Random}.
     * Both paths pass 0 as the Prediction's second argument, presumably the number of features used.
     *
     * @param example The example to predict for (only attached to the returned prediction).
     * @return The prediction.
     * @throws IllegalStateException If the dummy type is not handled.
     */
    public Prediction<Regressor> predict(Example<Regressor> example) {
        switch (this.dummyType) {
            case CONSTANT:
            case MEAN:
            case MEDIAN:
            case QUARTILE:
                return new Prediction<>(this.output, 0, example);
            case GAUSSIAN:
                return new Prediction<>(this.sampleGaussian(), 0, example);
            default:
                throw new IllegalStateException("Unknown dummyType " + this.dummyType);
        }
    }

    /** Draws one value per output dimension using the stored means and variances. */
    private Regressor sampleGaussian() {
        Regressor.DimensionTuple[] dimensions = new Regressor.DimensionTuple[this.dimensionNames.length];
        for (int i = 0; i < this.dimensionNames.length; ++i) {
            // NOTE(review): the standard-normal draw is scaled by variances[i], not sqrt(variances[i]).
            // If this array really holds variances, the sampled spread equals the variance rather than the
            // standard deviation. Left unchanged so existing serialized models stay reproducible; confirm
            // what DummyRegressionTrainer stores here.
            double regressionValue = this.rng.nextGaussian() * this.variances[i] + this.means[i];
            dimensions[i] = new Regressor.DimensionTuple(this.dimensionNames[i], regressionValue);
        }
        return new Regressor(dimensions);
    }

    /**
     * Returns a single {@code "BIAS"} feature with weight 1.0 under {@code "ALL_OUTPUTS"}, since the model
     * uses no real features.
     *
     * @param n The number of features requested; 0 yields an empty map, any other value the bias entry.
     * @return The feature map.
     */
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        if (n == 0) {
            return Collections.emptyMap();
        }
        List<Pair<String, Double>> bias = Collections.singletonList(new Pair<>("BIAS", 1.0));
        return Collections.singletonMap("ALL_OUTPUTS", bias);
    }

    /**
     * Builds an excuse from this model's prediction and its single bias feature.
     *
     * @param example The example to explain.
     * @return An excuse which is always present.
     */
    public Optional<Excuse<Regressor>> getExcuse(Example<Regressor> example) {
        return Optional.of(new Excuse<>(example, this.predict(example), this.getTopFeatures(1)));
    }

    /**
     * Serializes this model to a {@link ModelProto} at {@link #CURRENT_VERSION}.
     *
     * @return The protobuf form of this model.
     */
    public ModelProto serialize() {
        ModelDataCarrier<Regressor> carrier = this.createDataCarrier();
        DummyRegressionModelProto.Builder modelBuilder = DummyRegressionModelProto.newBuilder();
        modelBuilder.setMetadata(carrier.serialize());
        modelBuilder.setDummyType(this.dummyType.name());
        // GAUSSIAN models have no constant output to record.
        if (this.output != null) {
            modelBuilder.setOutput(this.output.serialize());
        }
        modelBuilder.addAllMeans(Arrays.stream(this.means).boxed().collect(Collectors.toList()));
        modelBuilder.addAllVariances(Arrays.stream(this.variances).boxed().collect(Collectors.toList()));
        modelBuilder.addAllDimensionNames(Arrays.asList(this.dimensionNames));
        modelBuilder.setSeed(this.seed);

        ModelProto.Builder builder = ModelProto.newBuilder();
        builder.setSerializedData(Any.pack(modelBuilder.build()));
        builder.setClassName(DummyRegressionModel.class.getName());
        builder.setVersion(CURRENT_VERSION);
        return builder.build();
    }

    /**
     * Copies this model, giving the copy the supplied name and provenance.
     * <p>
     * The copy's RNG is freshly seeded from the stored seed, so a GAUSSIAN copy restarts its sample
     * sequence. Constant types get a copy of the stored regressor.
     *
     * @param newName       The name of the copy.
     * @param newProvenance The provenance of the copy.
     * @return The copy.
     * @throws IllegalStateException If the dummy type is not handled.
     */
    protected Model<Regressor> copy(String newName, ModelProvenance newProvenance) {
        switch (this.dummyType) {
            case GAUSSIAN:
                return new DummyRegressionModel(newName, newProvenance, this.featureIDMap, this.outputIDInfo, this.dummyType, null, this.seed, this.means, this.variances, this.dimensionNames);
            case CONSTANT:
            case MEAN:
            case MEDIAN:
            case QUARTILE:
                return new DummyRegressionModel(newName, newProvenance, this.featureIDMap, this.outputIDInfo, this.dummyType, this.output.copy(), this.seed, this.means, this.variances, this.dimensionNames);
            default:
                throw new IllegalStateException("Unknown dummyType " + this.dummyType);
        }
    }
}

