/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.sequence.viterbi;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.classification.protos.LabelFeatureExtractorProto;
import org.tribuo.classification.protos.ViterbiModelProto;
import org.tribuo.classification.sequence.viterbi.LabelFeatureExtractor;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.protos.core.ModelDataProto;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.protos.core.SequenceModelProto;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.sequence.SequenceDataset;
import org.tribuo.sequence.SequenceExample;
import org.tribuo.sequence.SequenceModel;

/**
 * A sequence model for {@link Label}s which wraps a per-element classification {@link Model}.
 * <p>
 * At each position the wrapped model scores the example after it has been augmented with
 * features produced by a {@link LabelFeatureExtractor} from the labels chosen earlier in the
 * sequence. When the stack size is 1 the single best label is taken greedily at each position;
 * otherwise the top {@code stackSize} labels are kept per path and the best scoring path over
 * the whole sequence is returned.
 */
public class ViterbiModel
extends SequenceModel<Label> {
    private static final long serialVersionUID = 1L;

    /**
     * Protobuf serialization version written by {@link #serialize()} and the highest version
     * accepted by {@link #deserializeFromProto(int, String, Any)}.
     */
    public static final int CURRENT_VERSION = 0;

    private final Model<Label> model;
    private final LabelFeatureExtractor labelFeatureExtractor;
    private final int stackSize;
    private final ScoreAggregation scoreAggregation;

    /**
     * Builds a Viterbi model around an existing classification model. The feature and output
     * domains passed to the superclass are taken from the wrapped model.
     *
     * @param name The model name.
     * @param description The model provenance.
     * @param model The wrapped per-element classification model.
     * @param labelFeatureExtractor Produces features from the previously chosen labels.
     * @param stackSize The number of top labels kept per path at each position; 1 selects the greedy decoder.
     * @param scoreAggregation How a label score is combined with the score of the path it extends.
     */
    ViterbiModel(String name, ModelProvenance description, Model<Label> model, LabelFeatureExtractor labelFeatureExtractor, int stackSize, ScoreAggregation scoreAggregation) {
        super(name, description, model.getFeatureIDMap(), model.getOutputIDInfo());
        this.model = model;
        this.labelFeatureExtractor = labelFeatureExtractor;
        this.stackSize = stackSize;
        this.scoreAggregation = scoreAggregation;
    }

    /**
     * Rebuilds a model from its protobuf form.
     *
     * @param version The serialized version number.
     * @param className The class name (not used by this method).
     * @param message The packed {@link ViterbiModelProto}.
     * @return The deserialized model.
     * @throws InvalidProtocolBufferException If the message cannot be unpacked as a {@link ViterbiModelProto}.
     * @throws IllegalArgumentException If the version is negative or greater than {@link #CURRENT_VERSION}.
     * @throws IllegalStateException If the output domain or the inner model is not a {@link Label} one.
     */
    public static ViterbiModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        ViterbiModelProto proto = message.unpack(ViterbiModelProto.class);
        ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());
        if (!carrier.outputDomain().getOutput(0).getClass().equals(Label.class)) {
            throw new IllegalStateException("Invalid protobuf, output domain is not a label domain, found " + carrier.outputDomain().getClass());
        }
        Model<?> model = Model.deserialize(proto.getModel());
        if (!model.validate(Label.class)) {
            throw new IllegalStateException("Invalid protobuf, expected a classification model, found " + model);
        }
        Model<Label> labelModel = model.castModel(Label.class);
        LabelFeatureExtractor labelFeatureExtractor = LabelFeatureExtractor.deserialize(proto.getLabelFeatureExtractor());
        int stackSize = proto.getStackSize();
        ScoreAggregation scoreAggregation = ScoreAggregation.valueOf(proto.getScoreAggregation());
        return new ViterbiModel(carrier.name(), carrier.provenance(), labelModel, labelFeatureExtractor, stackSize, scoreAggregation);
    }

    /**
     * Predicts every sequence in the dataset, in iteration order.
     *
     * @param examples The sequences to label.
     * @return One list of predictions per sequence.
     */
    @Override
    public List<List<Prediction<Label>>> predict(SequenceDataset<Label> examples) {
        List<List<Prediction<Label>>> predictions = new ArrayList<>();
        for (SequenceExample<Label> e : examples) {
            predictions.add(predict(e));
        }
        return predictions;
    }

    /**
     * Predicts a label for each element of the sequence.
     * <p>
     * With a stack size of 1 the labels are chosen greedily left to right; each prediction is
     * made on a copy of the element augmented with label features, so the supplied sequence is
     * left unmodified. For any other stack size the sequence is decoded by
     * {@link #viterbi(SequenceExample)}.
     *
     * @param examples The sequence to label.
     * @return One prediction per element, in sequence order.
     */
    @Override
    public List<Prediction<Label>> predict(SequenceExample<Label> examples) {
        if (stackSize != 1) {
            return viterbi(examples);
        }
        List<Label> labels = new ArrayList<>();
        List<Prediction<Label>> returnValues = new ArrayList<>();
        for (Example<Label> example : examples) {
            // Work on a copy (as the Viterbi path does) so repeated calls do not pile up
            // duplicate label features on the caller's examples.
            Example<Label> augmented = example.copy();
            augmented.addAll(extractFeatures(labels));
            Prediction<Label> prediction = model.predict(augmented);
            labels.add(prediction.getOutput());
            returnValues.add(prediction);
        }
        return returnValues;
    }

    /**
     * Returns the wrapped classification model.
     *
     * @return The inner model.
     */
    public Model<Label> getInnerModel() {
        return this.model;
    }

    /**
     * Extracts label features for the supplied label history, keeping only those whose name
     * has a non-negative id in this model's feature domain.
     *
     * @param labels The labels chosen so far.
     * @return The label features known to the feature domain.
     */
    private List<Feature> extractFeatures(List<Label> labels) {
        List<Feature> labelFeatures = new ArrayList<>();
        for (Feature labelFeature : labelFeatureExtractor.extractFeatures(labels, 1.0)) {
            if (featureIDMap.getID(labelFeature.getName()) > -1) {
                labelFeatures.add(labelFeature);
            }
        }
        return labelFeatures;
    }

    /**
     * Decodes the sequence by extending candidate paths one element at a time and keeping,
     * for each label, the highest scoring path ending in that label.
     *
     * @param examples The sequence to label.
     * @return One prediction per element, built from the labels of the best scoring path;
     *         an empty list if the sequence is empty.
     */
    private List<Prediction<Label>> viterbi(SequenceExample<Label> examples) {
        if (examples.size() == 0) {
            // No element means no path to maximise over; mirror the greedy decoder's empty result.
            return new ArrayList<>();
        }
        Collection<Path> paths = null;
        // numUsed[i] ends up holding the active feature count of the last path scored at position i.
        int[] numUsed = new int[examples.size()];
        int i = 0;
        for (Example<Label> example : examples) {
            if (paths == null) {
                // First element: no label history, so score it as-is and seed one path per top label.
                paths = new ArrayList<>();
                Prediction<Label> prediction = model.predict(example);
                numUsed[i] = prediction.getNumActiveFeatures();
                for (Label label : getTopLabels(prediction.getOutputScores())) {
                    paths.add(new Path(label, label.getScore(), null));
                }
            } else {
                Map<Label, Path> maxPaths = new HashMap<>();
                for (Path path : paths) {
                    // Each path has a different label history, so each needs its own augmented copy.
                    Example<Label> clonedExample = example.copy();
                    List<Label> previousLabels = new ArrayList<>(path.labels);
                    clonedExample.addAll(extractFeatures(previousLabels));
                    Prediction<Label> prediction = model.predict(clonedExample);
                    numUsed[i] = prediction.getNumActiveFeatures();
                    for (Label label : getTopLabels(prediction.getOutputScores())) {
                        double labelScore = label.getScore();
                        double score = scoreAggregation == ScoreAggregation.ADD ? path.score + labelScore : path.score * labelScore;
                        Path maxPath = maxPaths.get(label);
                        // Strictly greater: on a tie the path recorded first is kept.
                        if (maxPath == null || score > maxPath.score) {
                            maxPaths.put(label, new Path(label, score, path));
                        }
                    }
                }
                paths = maxPaths.values();
            }
            ++i;
        }
        Path maxPath = Collections.max(paths);
        List<Prediction<Label>> output = new ArrayList<>();
        for (int j = 0; j < examples.size(); ++j) {
            Example<Label> e = examples.get(j);
            output.add(new Prediction<>(maxPath.labels.get(j), numUsed[j], e));
        }
        return output;
    }

    /**
     * Returns the highest scoring labels from the distribution, limited to this model's stack size.
     *
     * @param distribution The per-label scores.
     * @return At most {@code stackSize} labels, highest score first.
     */
    protected List<Label> getTopLabels(Map<String, Label> distribution) {
        return ViterbiModel.getTopLabels(distribution, this.stackSize);
    }

    /**
     * Returns the highest scoring labels from the distribution.
     *
     * @param distribution The per-label scores.
     * @param stackSize The maximum number of labels to return.
     * @return At most {@code stackSize} labels sorted by descending score.
     */
    protected static List<Label> getTopLabels(Map<String, Label> distribution, int stackSize) {
        return distribution.values().stream()
                .sorted(Comparator.comparingDouble(Label::getScore).reversed())
                .limit(stackSize)
                .collect(Collectors.toList());
    }

    /**
     * Returns the number of top labels kept per path at each position.
     *
     * @return The stack size.
     */
    public int getStackSize() {
        return this.stackSize;
    }

    /**
     * Returns how label scores are combined with path scores.
     *
     * @return The score aggregation.
     */
    public ScoreAggregation getScoreAggregation() {
        return this.scoreAggregation;
    }

    /**
     * Delegates to the wrapped model's {@code getTopFeatures}.
     *
     * @param n The number of features requested, passed through unchanged.
     * @return Whatever the wrapped model returns for {@code n}.
     */
    @Override
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        return this.model.getTopFeatures(n);
    }

    /**
     * Serializes this model to a {@link SequenceModelProto} wrapping a {@link ViterbiModelProto},
     * tagged with this class name and {@link #CURRENT_VERSION}.
     *
     * @return The protobuf form of this model.
     */
    @Override
    public SequenceModelProto serialize() {
        ModelDataCarrier<Label> carrier = createDataCarrier();
        ViterbiModelProto.Builder modelBuilder = ViterbiModelProto.newBuilder();
        modelBuilder.setMetadata(carrier.serialize());
        modelBuilder.setModel(this.model.serialize());
        modelBuilder.setLabelFeatureExtractor(this.labelFeatureExtractor.serialize());
        modelBuilder.setStackSize(this.stackSize);
        modelBuilder.setScoreAggregation(this.scoreAggregation.name());
        SequenceModelProto.Builder builder = SequenceModelProto.newBuilder();
        builder.setSerializedData(Any.pack(modelBuilder.build()));
        builder.setClassName(ViterbiModel.class.getName());
        builder.setVersion(CURRENT_VERSION);
        return builder.build();
    }

    /**
     * How the score of a label is combined with the score of the path it extends.
     */
    public static enum ScoreAggregation {
        /** The label score is added to the path score. */
        ADD,
        /** The label score is multiplied with the path score. */
        MULTIPLY;

    }

    /**
     * A candidate label sequence with its aggregated score. Paths are ordered by score.
     */
    private static class Path
    implements Comparable<Path> {
        public final double score;
        public final Path parent;
        public final List<Label> labels;

        /**
         * Creates a path ending in {@code label}; its label list is the parent's labels
         * (if there is a parent) followed by {@code label}.
         *
         * @param label The label at the end of this path.
         * @param score The aggregated score of the whole path.
         * @param parent The path being extended, or null for a path of length one.
         */
        public Path(Label label, double score, Path parent) {
            this.score = score;
            this.parent = parent;
            this.labels = new ArrayList<>();
            if (parent != null) {
                this.labels.addAll(parent.labels);
            }
            this.labels.add(label);
        }

        @Override
        public int compareTo(Path that) {
            return Double.compare(this.score, that.score);
        }
    }
}

