/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.sgd.crf;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.logging.Logger;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.classification.sequence.ConfidencePredictingSequenceModel;
import org.tribuo.classification.sgd.crf.CRFParameters;
import org.tribuo.classification.sgd.crf.Chunk;
import org.tribuo.classification.sgd.protos.CRFModelProto;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.math.Parameters;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.protos.ParametersProto;
import org.tribuo.protos.core.ModelDataProto;
import org.tribuo.protos.core.SequenceModelProto;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.sequence.SequenceExample;

/**
 * A sequence classification model over {@link Label} outputs, backed by {@link CRFParameters}
 * (a conditional random field, per the class and package names).
 * <p>
 * How per-token label scores and subsequence confidences are produced is controlled by the
 * {@link ConfidenceType}. It starts as {@link ConfidenceType#NONE} and can be changed with
 * {@link #setConfidenceType(ConfidenceType)}.
 */
public class CRFModel extends ConfidencePredictingSequenceModel {
    private static final Logger logger = Logger.getLogger(CRFModel.class.getName());
    private static final long serialVersionUID = 2L;

    /**
     * Protobuf serialization version written by {@link #serialize()} and accepted by
     * {@link #deserializeFromProto(int, String, Any)}.
     */
    public static final int CURRENT_VERSION = 0;

    private final CRFParameters parameters;

    private ConfidenceType confidenceType;

    /**
     * Constructs a CRF model whose confidence type is {@link ConfidenceType#NONE}.
     *
     * @param name         The model name.
     * @param description  The model provenance.
     * @param featureIDMap The feature domain.
     * @param labelIDMap   The label domain.
     * @param parameters   The CRF parameters used for prediction.
     */
    CRFModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap,
             ImmutableOutputInfo<Label> labelIDMap, CRFParameters parameters) {
        super(name, description, featureIDMap, labelIDMap);
        this.parameters = parameters;
        this.confidenceType = ConfidenceType.NONE;
    }

    /**
     * Deserialization factory.
     *
     * @param version   The serialized object version.
     * @param className The class name (not consulted by this method).
     * @param message   The serialized data, expected to wrap a {@link CRFModelProto}.
     * @return The deserialized model.
     * @throws InvalidProtocolBufferException If the message does not unpack to a {@link CRFModelProto}.
     * @throws IllegalArgumentException       If the version is unsupported, or the stored confidence
     *                                        type is not a {@link ConfidenceType} constant name.
     * @throws IllegalStateException          If the output domain is not a (non-empty) label domain,
     *                                        or the stored parameters are not {@link CRFParameters}.
     */
    public static CRFModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        CRFModelProto proto = message.unpack(CRFModelProto.class);

        ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());
        // Check for null so that a domain with no output at id 0 is reported as invalid
        // rather than failing with a NullPointerException.
        Output<?> firstOutput = carrier.outputDomain().getOutput(0);
        if (firstOutput == null || !firstOutput.getClass().equals(Label.class)) {
            throw new IllegalStateException("Invalid protobuf, output domain is not a label domain, found " + carrier.outputDomain().getClass());
        }
        @SuppressWarnings("unchecked") // element type checked to be exactly Label above
        ImmutableOutputInfo<Label> outputDomain = (ImmutableOutputInfo<Label>) carrier.outputDomain();

        Parameters params = Parameters.deserialize(proto.getParams());
        if (!(params instanceof CRFParameters)) {
            throw new IllegalStateException("Invalid protobuf, parameters must be CRFParameters, found " + params.getClass());
        }

        ConfidenceType confidenceType = ConfidenceType.valueOf(proto.getConfidenceType());

        CRFModel model = new CRFModel(carrier.name(), carrier.provenance(), carrier.featureDomain(), outputDomain, (CRFParameters) params);
        model.confidenceType = confidenceType;
        return model;
    }

    /**
     * Sets the strategy used by {@link #predict(SequenceExample)} to score labels and by
     * {@link #scoreSubsequences(SequenceExample, List, List)} to score subsequences.
     *
     * @param type The confidence type to use.
     */
    public void setConfidenceType(ConfidenceType type) {
        this.confidenceType = type;
    }

    /**
     * Returns the weight vector that {@link CRFParameters#getFeatureWeights(int)} associates with a
     * feature (presumably one weight per label; confirm against CRFParameters).
     *
     * @param featureID The feature id.
     * @return The feature's weights. If the id is outside the feature domain, a warning is logged
     *         and an empty vector is returned.
     */
    public DenseVector getFeatureWeights(int featureID) {
        if (featureID < 0 || featureID >= featureIDMap.size()) {
            logger.warning("Unknown feature");
            return new DenseVector(0);
        }
        return parameters.getFeatureWeights(featureID);
    }

    /**
     * Returns the weight vector associated with the named feature.
     *
     * @param featureName The feature name.
     * @return The feature's weights. If the name is not in the feature domain, a warning is logged
     *         and an empty vector is returned.
     */
    public DenseVector getFeatureWeights(String featureName) {
        int id = featureIDMap.getID(featureName);
        if (id < 0) {
            logger.warning("Unknown feature");
            return new DenseVector(0);
        }
        return getFeatureWeights(id);
    }

    /**
     * Predicts a label for every element of the sequence.
     * <p>
     * With {@link ConfidenceType#MULTIPLY}, each prediction carries the per-token marginals as
     * label scores and uses the highest-scoring label. Otherwise, each prediction carries only the
     * label chosen by {@link CRFParameters#predict}.
     *
     * @param example The sequence to label.
     * @return One prediction per sequence element, in order.
     * @throws IllegalArgumentException If the sequence is empty or an element has no known features.
     */
    @Override
    public List<Prediction<Label>> predict(SequenceExample<Label> example) {
        SGDVector[] features = convertToVector(example, featureIDMap);
        List<Prediction<Label>> output = new ArrayList<>(features.length);
        if (confidenceType == ConfidenceType.MULTIPLY) {
            DenseVector[] marginals = parameters.predictMarginals(features);
            for (int i = 0; i < marginals.length; i++) {
                double maxScore = Double.NEGATIVE_INFINITY;
                Label maxLabel = null;
                Map<String, Label> predMap = new LinkedHashMap<>();
                for (int j = 0; j < marginals[i].size(); j++) {
                    String labelName = outputIDMap.getOutput(j).getLabel();
                    Label label = new Label(labelName, marginals[i].get(j));
                    predMap.put(labelName, label);
                    if (label.getScore() > maxScore) {
                        maxScore = label.getScore();
                        maxLabel = label;
                    }
                }
                output.add(new Prediction<>(maxLabel, predMap, features[i].numActiveElements(), example.get(i), true));
            }
        } else {
            int[] predLabels = parameters.predict(features);
            for (int i = 0; i < predLabels.length; i++) {
                output.add(new Prediction<>(outputIDMap.getOutput(predLabels[i]), features[i].numActiveElements(), example.get(i)));
            }
        }
        return output;
    }

    /**
     * Returns, for each label, up to {@code n} (feature name, weight) pairs with the largest
     * weights. Each label's list also competes a {@code "BIAS"} entry holding that label's bias.
     * Entries are ranked by signed weight (not magnitude) and sorted in descending order.
     *
     * @param n The number of entries per label. A negative value returns every feature plus the
     *          bias. Zero returns an empty list for each label.
     * @return A map from label name to its top entries.
     */
    @Override
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        int numClasses = outputIDMap.size();
        int numFeatures = featureIDMap.size();
        int maxFeatures = n < 0 ? numFeatures + 1 : n;
        // Bound the initial capacity by the number of candidates (features + bias).
        // A huge n must not allocate a huge backing array, and PriorityQueue rejects capacities < 1.
        int queueCapacity = Math.max(1, Math.min(maxFeatures, numFeatures + 1));
        Comparator<Pair<String, Double>> comparator = Comparator.comparing(Pair::getB);

        Map<String, List<Pair<String, Double>>> map = new HashMap<>();
        for (int i = 0; i < numClasses; i++) {
            String labelName = outputIDMap.getOutput(i).getLabel();
            if (maxFeatures == 0) {
                map.put(labelName, new ArrayList<>());
                continue;
            }
            // Queue ordered by ascending weight: its head is the weakest of the retained entries.
            PriorityQueue<Pair<String, Double>> q = new PriorityQueue<>(queueCapacity, comparator);
            for (int j = 0; j < numFeatures; j++) {
                Pair<String, Double> curr = new Pair<>(featureIDMap.get(j).getName(), parameters.getWeight(i, j));
                offerBounded(q, curr, maxFeatures, comparator);
            }
            Pair<String, Double> bias = new Pair<>("BIAS", parameters.getBias(i));
            offerBounded(q, bias, maxFeatures, comparator);

            List<Pair<String, Double>> b = new ArrayList<>(q.size());
            while (!q.isEmpty()) {
                b.add(q.poll());
            }
            // Polling yields ascending order, so reverse to put the largest weight first.
            Collections.reverse(b);
            map.put(labelName, b);
        }
        return map;
    }

    /**
     * Offers {@code candidate} to {@code queue} while keeping at most {@code limit} entries.
     * When the queue is full, the head is evicted only if the candidate ranks strictly above it.
     * Requires {@code limit >= 1}.
     */
    private static <E> void offerBounded(PriorityQueue<E> queue, E candidate, int limit, Comparator<? super E> comparator) {
        if (queue.size() < limit) {
            queue.offer(candidate);
        } else if (comparator.compare(candidate, queue.peek()) > 0) {
            queue.poll();
            queue.offer(candidate);
        }
    }

    /**
     * Scores each subsequence of the predicted labelling.
     * <p>
     * With {@link ConfidenceType#CONSTRAINED_BP}, each subsequence's predicted label ids are
     * converted to a {@link Chunk} and scored via {@link #scoreChunks}. Otherwise, scoring is
     * delegated to {@link ConfidencePredictingSequenceModel#multiplyWeights}.
     *
     * @param example       The input sequence.
     * @param predictions   The predictions for {@code example}.
     * @param subsequences  The subsequences to score.
     * @param <SUB>         The subsequence type.
     * @return One score per subsequence, in order.
     */
    @Override
    public <SUB extends ConfidencePredictingSequenceModel.Subsequence> List<Double> scoreSubsequences(SequenceExample<Label> example, List<Prediction<Label>> predictions, List<SUB> subsequences) {
        if (confidenceType == ConfidenceType.CONSTRAINED_BP) {
            List<Chunk> chunks = new ArrayList<>(subsequences.size());
            for (SUB subsequence : subsequences) {
                int[] ids = new int[subsequence.length()];
                for (int i = 0; i < ids.length; i++) {
                    ids[i] = outputIDMap.getID(predictions.get(i + subsequence.begin).getOutput());
                }
                chunks.add(new Chunk(subsequence.begin, ids));
            }
            return scoreChunks(example, chunks);
        }
        return ConfidencePredictingSequenceModel.multiplyWeights(predictions, subsequences);
    }

    /**
     * Scores each chunk via {@link CRFParameters#predictConfidenceUsingCBP}
     * (constrained belief propagation, per the method name).
     *
     * @param example The input sequence.
     * @param chunks  The chunks to score.
     * @return One score per chunk, as returned by the parameters.
     * @throws IllegalArgumentException If the sequence is empty or an element has no known features.
     */
    public List<Double> scoreChunks(SequenceExample<Label> example, List<Chunk> chunks) {
        SGDVector[] features = convertToVector(example, featureIDMap);
        return parameters.predictConfidenceUsingCBP(features, chunks);
    }

    /**
     * Renders the first three tensors of {@link CRFParameters#get()} as text. They are labelled as
     * biases, feature-label weights and label-label weights respectively.
     *
     * @return A human-readable dump of the weights.
     */
    public String generateWeightsString() {
        Tensor[] weights = parameters.get();
        StringBuilder buffer = new StringBuilder();
        buffer.append("Biases = ").append(weights[0].toString()).append('\n');
        buffer.append("Feature-Label weights = \n").append(weights[1].toString()).append('\n');
        buffer.append("Label-Label weights = \n").append(weights[2].toString()).append('\n');
        return buffer.toString();
    }

    /**
     * Serializes this model to protobuf, tagged with {@link #CURRENT_VERSION}.
     *
     * @return The serialized model.
     */
    @Override
    public SequenceModelProto serialize() {
        ModelDataCarrier<?> carrier = createDataCarrier();

        CRFModelProto.Builder modelBuilder = CRFModelProto.newBuilder();
        modelBuilder.setConfidenceType(confidenceType.name());
        modelBuilder.setMetadata(carrier.serialize());
        modelBuilder.setParams(parameters.serialize());

        SequenceModelProto.Builder builder = SequenceModelProto.newBuilder();
        builder.setVersion(CURRENT_VERSION);
        builder.setClassName(CRFModel.class.getName());
        builder.setSerializedData(Any.pack(modelBuilder.build()));
        return builder.build();
    }

    /**
     * Converts a {@link SequenceExample} into an array of {@link SparseVector}s, one per element.
     *
     * @param example      The sequence to convert.
     * @param featureIDMap The feature domain.
     * @param <T>          The output type.
     * @return One sparse vector per sequence element.
     * @throws IllegalArgumentException If the sequence is empty or an element has no known features.
     * @deprecated Use {@link #convertToVector(SequenceExample, ImmutableFeatureMap)}, which also
     *             produces dense vectors where appropriate.
     */
    @Deprecated
    public static <T extends Output<T>> SparseVector[] convert(SequenceExample<T> example, ImmutableFeatureMap featureIDMap) {
        requireNonEmpty(example);
        SparseVector[] features = new SparseVector[example.size()];
        int i = 0;
        for (Example<T> e : example) {
            features[i] = SparseVector.createSparseVector(e, featureIDMap, false);
            if (features[i].numActiveElements() == 0) {
                throw new IllegalArgumentException("No features found in Example " + e.toString());
            }
            i++;
        }
        return features;
    }

    /**
     * Converts a labelled {@link SequenceExample} into label ids and {@link SparseVector}s.
     *
     * @param example      The sequence to convert.
     * @param featureIDMap The feature domain.
     * @param labelIDMap   The label domain.
     * @return A pair of (label ids, sparse feature vectors), one entry per sequence element.
     * @throws IllegalArgumentException If the sequence is empty or an element has no known features.
     * @deprecated Use {@link #convertToVector(SequenceExample, ImmutableFeatureMap, ImmutableOutputInfo)},
     *             which also produces dense vectors where appropriate.
     */
    @Deprecated
    public static Pair<int[], SparseVector[]> convert(SequenceExample<Label> example, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> labelIDMap) {
        requireNonEmpty(example);
        int length = example.size();
        int[] labels = new int[length];
        SparseVector[] features = new SparseVector[length];
        int i = 0;
        for (Example<Label> e : example) {
            labels[i] = labelIDMap.getID(e.getOutput());
            features[i] = SparseVector.createSparseVector(e, featureIDMap, false);
            if (features[i].numActiveElements() == 0) {
                throw new IllegalArgumentException("No features found in Example " + e.toString());
            }
            i++;
        }
        return new Pair<>(labels, features);
    }

    /**
     * Converts a {@link SequenceExample} into one {@link SGDVector} per element.
     * <p>
     * An element whose feature count equals the size of the feature domain becomes a
     * {@link DenseVector}. Every other element becomes a {@link SparseVector}.
     *
     * @param example      The sequence to convert.
     * @param featureIDMap The feature domain.
     * @param <T>          The output type.
     * @return One vector per sequence element.
     * @throws IllegalArgumentException If the sequence is empty or an element has no known features.
     */
    public static <T extends Output<T>> SGDVector[] convertToVector(SequenceExample<T> example, ImmutableFeatureMap featureIDMap) {
        requireNonEmpty(example);
        int featureSpaceSize = featureIDMap.size();
        SGDVector[] features = new SGDVector[example.size()];
        int i = 0;
        for (Example<T> e : example) {
            features[i] = toVector(e, featureIDMap, featureSpaceSize);
            i++;
        }
        return features;
    }

    /**
     * Converts a labelled {@link SequenceExample} into label ids and one {@link SGDVector} per
     * element. Dense and sparse vectors are chosen as in
     * {@link #convertToVector(SequenceExample, ImmutableFeatureMap)}.
     *
     * @param example      The sequence to convert.
     * @param featureIDMap The feature domain.
     * @param labelIDMap   The label domain.
     * @return A pair of (label ids, feature vectors), one entry per sequence element.
     * @throws IllegalArgumentException If the sequence is empty or an element has no known features.
     */
    public static Pair<int[], SGDVector[]> convertToVector(SequenceExample<Label> example, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> labelIDMap) {
        requireNonEmpty(example);
        int length = example.size();
        int featureSpaceSize = featureIDMap.size();
        int[] labels = new int[length];
        SGDVector[] features = new SGDVector[length];
        int i = 0;
        for (Example<Label> e : example) {
            labels[i] = labelIDMap.getID(e.getOutput());
            features[i] = toVector(e, featureIDMap, featureSpaceSize);
            i++;
        }
        return new Pair<>(labels, features);
    }

    /** Throws if the sequence has no elements. */
    private static void requireNonEmpty(SequenceExample<?> example) {
        if (example.size() == 0) {
            throw new IllegalArgumentException("SequenceExample is empty, " + example.toString());
        }
    }

    /**
     * Converts one example to a dense vector when its feature count equals {@code featureSpaceSize},
     * and to a sparse vector otherwise. Throws if no known features are active.
     */
    private static <T extends Output<T>> SGDVector toVector(Example<T> e, ImmutableFeatureMap featureIDMap, int featureSpaceSize) {
        SGDVector vector;
        if (e.size() == featureSpaceSize) {
            vector = DenseVector.createDenseVector(e, featureIDMap, false);
        } else {
            vector = SparseVector.createSparseVector(e, featureIDMap, false);
        }
        if (vector.numActiveElements() == 0) {
            throw new IllegalArgumentException("No features found in Example " + e.toString());
        }
        return vector;
    }

    /**
     * Strategy used to attach confidences to predictions.
     */
    public enum ConfidenceType {
        /**
         * {@link #predict} returns the labels chosen by {@link CRFParameters#predict} without a
         * per-label score map. Subsequences are scored with {@code multiplyWeights}.
         */
        NONE,
        /**
         * {@link #predict} uses the per-token marginals as label scores. Subsequences are scored
         * with {@code multiplyWeights}.
         */
        MULTIPLY,
        /**
         * {@link #predict} behaves as for {@link #NONE}. Subsequences are scored with
         * {@link CRFModel#scoreChunks}.
         */
        CONSTRAINED_BP
    }
}

