/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.baseline;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import java.util.stream.Collectors;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.classification.ImmutableLabelInfo;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.classification.baseline.DummyClassifierTrainer;
import org.tribuo.classification.protos.DummyClassifierModelProto;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.protos.core.ModelDataProto;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.protos.core.OutputProto;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.util.Util;

/**
 * A classification model which ignores the input features.
 * <p>
 * Depending on its {@link DummyClassifierTrainer.DummyType} it either returns a fixed label
 * ({@code CONSTANT}, {@code MOST_FREQUENT}) or samples a label id from a cumulative distribution
 * over the output domain ({@code UNIFORM}, {@code STRATIFIED}).
 * <p>
 * Sampling models draw from a single {@link Random} seeded at construction. The sequence of
 * predicted labels therefore depends on the order in which {@link #predict(Example)} is called.
 */
public class DummyClassifierModel extends Model<Label> {
    private static final long serialVersionUID = 1L;

    /** The protobuf serialization version written by {@link #serialize()}. */
    public static final int CURRENT_VERSION = 0;

    /** Seed recorded for the non-sampling model types; it is never used to draw samples. */
    private static final long PLACEHOLDER_SEED = 12345L;

    private final DummyClassifierTrainer.DummyType dummyType;
    /** The label predicted by CONSTANT / MOST_FREQUENT models; the unknown label for sampling models. */
    private final Label constantLabel;
    /** Cumulative distribution over output ids for sampling models; null for CONSTANT / MOST_FREQUENT. */
    private final double[] cdf;
    /** Sampling source for UNIFORM / STRATIFIED models; null for the other types. */
    private final Random rng;
    private final long seed;

    /**
     * Creates a MOST_FREQUENT model which always predicts the label with the largest count in
     * {@code outputIDInfo}.
     *
     * @param description The model provenance.
     * @param featureIDMap The feature domain.
     * @param outputIDInfo The label domain (must be an {@link ImmutableLabelInfo}).
     */
    DummyClassifierModel(ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> outputIDInfo) {
        super("dummy-MOST_FREQUENT-classifier", description, featureIDMap, outputIDInfo, false);
        this.dummyType = DummyClassifierTrainer.DummyType.MOST_FREQUENT;
        this.constantLabel = findMostFrequentLabel(outputIDInfo);
        this.cdf = null;
        this.seed = PLACEHOLDER_SEED;
        this.rng = null;
    }

    /**
     * Creates a sampling model. {@code UNIFORM} gets a uniform CDF over the domain; any other
     * type gets a CDF proportional to the label counts in {@code outputIDInfo}.
     *
     * @param description The model provenance.
     * @param featureIDMap The feature domain.
     * @param outputIDInfo The label domain (must be an {@link ImmutableLabelInfo} for non-uniform types).
     * @param dummyType The sampling type, expected to be UNIFORM or STRATIFIED.
     * @param seed The seed for the sampling RNG.
     */
    DummyClassifierModel(ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> outputIDInfo, DummyClassifierTrainer.DummyType dummyType, long seed) {
        super("dummy-" + dummyType + "-classifier", description, featureIDMap, outputIDInfo, false);
        this.dummyType = dummyType;
        this.constantLabel = LabelFactory.UNKNOWN_LABEL;
        this.cdf = dummyType == DummyClassifierTrainer.DummyType.UNIFORM ? generateUniformCDF(outputIDInfo) : generateStratifiedCDF(outputIDInfo);
        this.seed = seed;
        this.rng = new Random(seed);
    }

    /**
     * Creates a CONSTANT model which always predicts {@code constantLabel}.
     *
     * @param description The model provenance.
     * @param featureIDMap The feature domain.
     * @param outputIDInfo The label domain.
     * @param constantLabel The label to predict for every example.
     */
    DummyClassifierModel(ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> outputIDInfo, Label constantLabel) {
        super("dummy-CONSTANT-classifier", description, featureIDMap, outputIDInfo, false);
        this.dummyType = DummyClassifierTrainer.DummyType.CONSTANT;
        this.constantLabel = constantLabel;
        this.cdf = null;
        this.seed = PLACEHOLDER_SEED;
        this.rng = null;
    }

    /**
     * Field-by-field constructor used by deserialization and {@link #copy(String, ModelProvenance)}.
     * A fresh RNG seeded with {@code seed} is created only for the sampling types, matching the
     * public constructors.
     */
    private DummyClassifierModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureMap, ImmutableOutputInfo<Label> outputInfo, DummyClassifierTrainer.DummyType type, Label constantLabel, double[] cdf, long seed) {
        super(name, provenance, featureMap, outputInfo, false);
        this.dummyType = type;
        this.constantLabel = constantLabel;
        this.cdf = cdf;
        this.seed = seed;
        this.rng = isSampling(type) ? new Random(seed) : null;
    }

    /** Returns true for the dummy types which predict by sampling from {@link #cdf}. */
    private static boolean isSampling(DummyClassifierTrainer.DummyType type) {
        return type == DummyClassifierTrainer.DummyType.UNIFORM || type == DummyClassifierTrainer.DummyType.STRATIFIED;
    }

    /**
     * Deserialization factory.
     *
     * @param version The serialized object version.
     * @param className The class name (unused here).
     * @param message The serialized data, expected to pack a {@link DummyClassifierModelProto}.
     * @return The deserialized model.
     * @throws InvalidProtocolBufferException If the message does not contain a {@link DummyClassifierModelProto}.
     * @throws IllegalArgumentException If the version is unsupported or the dummy type name is unknown.
     * @throws IllegalStateException If the output domain is not a label domain or the constant output is not a label.
     */
    public static DummyClassifierModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        DummyClassifierModelProto proto = message.unpack(DummyClassifierModelProto.class);
        ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());
        ImmutableOutputInfo<?> outputDomain = carrier.outputDomain();
        // Check the domain type directly instead of probing getOutput(0): id 0 need not be present
        // (e.g. an empty domain), and dereferencing a missing output would throw an NPE instead of
        // the intended validation error.
        if (!(outputDomain instanceof ImmutableLabelInfo)) {
            throw new IllegalStateException("Invalid protobuf, output domain is not a label domain, found " + outputDomain.getClass());
        }
        DummyClassifierTrainer.DummyType dummyType = DummyClassifierTrainer.DummyType.valueOf(proto.getDummyType());
        Output<?> output = Output.deserialize(proto.getConstantLabel());
        if (!(output instanceof Label)) {
            throw new IllegalStateException("Invalid protobuf, expected a label, found " + output.getClass());
        }
        Label constantLabel = (Label) output;
        // Only sampling models serialize a CDF; an absent list maps back to null.
        double[] cdf = proto.getCdfCount() > 0 ? Util.toPrimitiveDouble(proto.getCdfList()) : null;
        long seed = proto.getSeed();
        return new DummyClassifierModel(carrier.name(), carrier.provenance(), carrier.featureDomain(), (ImmutableLabelInfo) outputDomain, dummyType, constantLabel, cdf, seed);
    }

    /**
     * Predicts a label for {@code example} without looking at its features.
     * <p>
     * CONSTANT and MOST_FREQUENT models return the stored label. UNIFORM and STRATIFIED models
     * draw one sample from the stored CDF, which advances the shared RNG.
     *
     * @param example The example (only attached to the returned prediction).
     * @return A prediction which reports zero features used.
     * @throws IllegalStateException If the dummy type is not recognised.
     */
    @Override
    public Prediction<Label> predict(Example<Label> example) {
        switch (dummyType) {
            case CONSTANT:
            case MOST_FREQUENT:
                return new Prediction<>(constantLabel, 0, example);
            case UNIFORM:
            case STRATIFIED:
                return new Prediction<>(sampleLabel(cdf, outputIDInfo, rng), 0, example);
            default:
                throw new IllegalStateException("Unknown dummyType " + dummyType);
        }
    }

    /**
     * This model has no feature weights. Unless {@code n} is zero, it reports a single
     * {@code "BIAS"} entry with weight 1.0 under the {@code "ALL_OUTPUTS"} key.
     *
     * @param n The number of features requested; zero yields an empty map.
     * @return A map which is empty or holds the single bias entry.
     */
    @Override
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        Map<String, List<Pair<String, Double>>> map = new HashMap<>();
        if (n != 0) {
            map.put("ALL_OUTPUTS", Collections.singletonList(new Pair<>("BIAS", 1.0)));
        }
        return map;
    }

    /**
     * Explains a prediction using the single bias entry from {@link #getTopFeatures(int)}.
     * For sampling models this calls {@link #predict(Example)} and so consumes a sample.
     *
     * @param example The example to explain.
     * @return An excuse, always present.
     */
    @Override
    public Optional<Excuse<Label>> getExcuse(Example<Label> example) {
        return Optional.of(new Excuse<>(example, predict(example), getTopFeatures(1)));
    }

    /**
     * Serializes this model, including its type, constant label, CDF (if any) and seed, to a
     * {@link ModelProto} tagged with {@link #CURRENT_VERSION}.
     *
     * @return The serialized model.
     */
    @Override
    public ModelProto serialize() {
        ModelDataCarrier<Label> carrier = createDataCarrier();
        DummyClassifierModelProto.Builder modelBuilder = DummyClassifierModelProto.newBuilder();
        modelBuilder.setMetadata(carrier.serialize());
        modelBuilder.setDummyType(dummyType.name());
        modelBuilder.setConstantLabel(constantLabel.serialize());
        if (cdf != null) {
            modelBuilder.addAllCdf(Arrays.stream(cdf).boxed().collect(Collectors.toList()));
        }
        modelBuilder.setSeed(seed);

        ModelProto.Builder builder = ModelProto.newBuilder();
        builder.setSerializedData(Any.pack(modelBuilder.build()));
        builder.setClassName(DummyClassifierModel.class.getName());
        builder.setVersion(CURRENT_VERSION);
        return builder.build();
    }

    /**
     * Copies this model under a new name and provenance. A sampling model's copy gets a fresh
     * RNG seeded with the same seed, so it restarts the sample sequence.
     *
     * @param newName The name of the copy.
     * @param newProvenance The provenance of the copy.
     * @return A copy of this model.
     */
    @Override
    protected DummyClassifierModel copy(String newName, ModelProvenance newProvenance) {
        double[] cdfCopy = cdf == null ? null : cdf.clone();
        // constantLabel can be null for a MOST_FREQUENT model built over a domain with no labels.
        Label labelCopy = constantLabel == null ? null : constantLabel.copy();
        return new DummyClassifierModel(newName, newProvenance, featureIDMap, outputIDInfo, dummyType, labelCopy, cdfCopy, seed);
    }

    /** Draws an output id from {@code cdf} using {@code rng} and maps it back to its label. */
    private static Label sampleLabel(double[] cdf, ImmutableOutputInfo<Label> outputIDInfo, Random rng) {
        int sample = Util.sampleFromCDF(cdf, rng);
        return outputIDInfo.getOutput(sample);
    }

    /**
     * Finds the label with the largest observed count. On ties the first label encountered in
     * iteration order wins. Returns null if the domain has no labels.
     */
    private static Label findMostFrequentLabel(ImmutableOutputInfo<Label> outputInfo) {
        ImmutableLabelInfo labelInfo = (ImmutableLabelInfo) outputInfo;
        Label maxLabel = null;
        long maxCount = -1L;
        for (Pair<Integer, Label> p : labelInfo) {
            long curCount = labelInfo.getLabelCount(p.getA());
            // Strict '>' keeps the earliest label on a tie.
            if (curCount > maxCount) {
                maxCount = curCount;
                maxLabel = p.getB();
            }
        }
        return maxLabel;
    }

    /** Builds a CDF which assigns equal probability to every label in the domain. */
    private static double[] generateUniformCDF(ImmutableOutputInfo<Label> outputInfo) {
        int length = outputInfo.getDomain().size();
        double[] pmf = Util.generateUniformVector(length, 1.0 / length);
        return Util.generateCDF(pmf);
    }

    /**
     * Builds a CDF where each label id's probability is its count divided by the total number
     * of observations in the domain.
     */
    private static double[] generateStratifiedCDF(ImmutableOutputInfo<Label> outputInfo) {
        ImmutableLabelInfo labelInfo = (ImmutableLabelInfo) outputInfo;
        double total = labelInfo.getTotalObservations();
        double[] pmf = new double[labelInfo.getDomain().size()];
        for (Pair<Integer, Label> p : labelInfo) {
            int idx = p.getA();
            pmf[idx] = labelInfo.getLabelCount(idx) / total;
        }
        return Util.generateCDF(pmf);
    }
}

