/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.sgd.crf;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.tribuo.classification.sgd.crf.ChainHelper;
import org.tribuo.classification.sgd.crf.Chunk;
import org.tribuo.classification.sgd.protos.CRFParametersProto;
import org.tribuo.math.Parameters;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseSparseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.Matrix;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.protos.ParametersProto;
import org.tribuo.math.util.HeapMerger;
import org.tribuo.math.util.Merger;
import org.tribuo.protos.ProtoUtil;

/**
 * Parameters of a linear-chain CRF: a per-label bias vector, a label-by-feature weight
 * matrix and a label-by-label transition matrix.
 * <p>
 * The three tensors are exposed through {@link #get()} in the fixed order
 * {@code [biases, featureLabelWeights, labelLabelWeights]}, and every gradient or update
 * array produced or consumed by this class uses the same order.
 */
public class CRFParameters implements Parameters, Serializable {
    private static final long serialVersionUID = 1L;

    /** Protobuf serialization version written by {@link #serialize()} and accepted on deserialization. */
    public static final int CURRENT_VERSION = 0;

    /** Sums arrays of sparse gradient matrices. */
    private static final Merger merger = new HeapMerger();

    private final int numLabels;
    private final int numFeatures;

    /** Array returned by {@link #get()}; its entries are the same objects as the three fields below. */
    private Tensor[] weights;
    /** Per-label bias, length {@code numLabels}. */
    private DenseVector biases;
    /** Weight matrix of shape {@code [numLabels, numFeatures]}. */
    private DenseMatrix featureLabelWeights;
    /** Transition weight matrix of shape {@code [numLabels, numLabels]}, indexed {@code (previous, current)}. */
    private DenseMatrix labelLabelWeights;

    /**
     * Creates zero-initialised parameters.
     * @param numFeatures The number of features.
     * @param numLabels The number of labels.
     */
    CRFParameters(int numFeatures, int numLabels) {
        this.biases = new DenseVector(numLabels);
        this.featureLabelWeights = new DenseMatrix(numLabels, numFeatures);
        this.labelLabelWeights = new DenseMatrix(numLabels, numLabels);
        this.weights = new Tensor[]{biases, featureLabelWeights, labelLabelWeights};
        this.numLabels = numLabels;
        this.numFeatures = numFeatures;
    }

    /**
     * Wraps already-validated tensors (used by deserialization).
     * @param biases The bias vector.
     * @param featureLabelWeights The label-by-feature weight matrix.
     * @param labelLabelWeights The label-by-label transition matrix.
     */
    private CRFParameters(DenseVector biases, DenseMatrix featureLabelWeights, DenseMatrix labelLabelWeights) {
        this.weights = new Tensor[]{biases, featureLabelWeights, labelLabelWeights};
        this.numLabels = biases.size();
        this.numFeatures = featureLabelWeights.getDimension2Size();
        this.biases = biases;
        this.featureLabelWeights = featureLabelWeights;
        this.labelLabelWeights = labelLabelWeights;
    }

    /**
     * Deserializes the parameters from a protobuf.
     * @param version The serialized version.
     * @param className The class name (unused).
     * @param message The serialized data, expected to contain a {@link CRFParametersProto}.
     * @return The deserialized parameters.
     * @throws InvalidProtocolBufferException If {@code message} does not contain a {@link CRFParametersProto}.
     * @throws IllegalArgumentException If the version is unknown, or the stored tensors have the wrong type or shape.
     */
    public static CRFParameters deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        CRFParametersProto proto = message.unpack(CRFParametersProto.class);
        int numLabels = proto.getNumLabels();
        int numFeatures = proto.getNumFeatures();
        Tensor biasTensor = (Tensor) ProtoUtil.deserialize(proto.getBiases());
        Tensor featureLabelTensor = (Tensor) ProtoUtil.deserialize(proto.getFeatureLabelWeights());
        Tensor labelLabelTensor = (Tensor) ProtoUtil.deserialize(proto.getLabelLabelWeights());

        if (!(biasTensor instanceof DenseVector)) {
            throw new IllegalArgumentException("Invalid protobuf, expected bias vector, found " + biasTensor.getClass().getSimpleName());
        }
        DenseVector biasVector = (DenseVector) biasTensor;
        if (biasVector.size() != numLabels) {
            throw new IllegalArgumentException("Invalid protobuf, expected bias vector with " + numLabels + " elements, but found " + biasVector.size());
        }

        if (!(featureLabelTensor instanceof DenseMatrix)) {
            throw new IllegalArgumentException("Invalid protobuf, expected feature/label matrix, found " + featureLabelTensor.getClass().getSimpleName());
        }
        DenseMatrix featureLabelMatrix = (DenseMatrix) featureLabelTensor;
        if (featureLabelMatrix.getDimension1Size() != numLabels || featureLabelMatrix.getDimension2Size() != numFeatures) {
            throw new IllegalArgumentException("Invalid protobuf, expected feature/label matrix of size [" + numLabels + ", " + numFeatures + "], found " + Arrays.toString(featureLabelMatrix.getShape()));
        }

        if (!(labelLabelTensor instanceof DenseMatrix)) {
            throw new IllegalArgumentException("Invalid protobuf, expected label/label matrix, found " + labelLabelTensor.getClass().getSimpleName());
        }
        DenseMatrix labelLabelMatrix = (DenseMatrix) labelLabelTensor;
        if (labelLabelMatrix.getDimension1Size() != numLabels || labelLabelMatrix.getDimension2Size() != numLabels) {
            throw new IllegalArgumentException("Invalid protobuf, expected label/label matrix of size [" + numLabels + ", " + numLabels + "], found " + Arrays.toString(labelLabelMatrix.getShape()));
        }

        return new CRFParameters(biasVector, featureLabelMatrix, labelLabelMatrix);
    }

    /**
     * Serializes these parameters (sizes plus the three tensors) into a {@link ParametersProto}.
     * @return The protobuf.
     */
    @Override
    public ParametersProto serialize() {
        CRFParametersProto.Builder crfParamsBuilder = CRFParametersProto.newBuilder();
        crfParamsBuilder.setNumFeatures(numFeatures);
        crfParamsBuilder.setNumLabels(numLabels);
        crfParamsBuilder.setBiases(biases.serialize());
        crfParamsBuilder.setFeatureLabelWeights(featureLabelWeights.serialize());
        crfParamsBuilder.setLabelLabelWeights(labelLabelWeights.serialize());

        ParametersProto.Builder builder = ParametersProto.newBuilder();
        builder.setVersion(CURRENT_VERSION);
        builder.setClassName(CRFParameters.class.getName());
        builder.setSerializedData(Any.pack(crfParamsBuilder.build()));
        return builder.build();
    }

    /**
     * Gets column {@code id} of the feature/label weight matrix, i.e. the per-label weights of one feature.
     * @param id The feature id.
     * @return The label weights for that feature.
     */
    public DenseVector getFeatureWeights(int id) {
        return featureLabelWeights.getColumn(id);
    }

    /**
     * Gets the bias for a label.
     * @param id The label id.
     * @return The bias value.
     */
    public double getBias(int id) {
        return biases.get(id);
    }

    /**
     * Gets the weight connecting a label and a feature.
     * @param labelID The label id.
     * @param featureID The feature id.
     * @return The weight value.
     */
    public double getWeight(int labelID, int featureID) {
        return featureLabelWeights.get(labelID, featureID);
    }

    /**
     * Computes the per-position label scores: for each position,
     * {@code featureLabelWeights.leftMultiply(features[i])} plus the bias vector.
     * @param features The feature vector for each position in the sequence.
     * @return A score vector per position.
     */
    public DenseVector[] getLocalScores(SGDVector[] features) {
        DenseVector[] localScores = new DenseVector[features.length];
        for (int i = 0; i < features.length; i++) {
            DenseVector scores = featureLabelWeights.leftMultiply(features[i]);
            scores.intersectAndAddInPlace(biases);
            localScores[i] = scores;
        }
        return localScores;
    }

    /**
     * Packages the local scores and the transition matrix for the {@link ChainHelper} inference routines.
     * @param features The feature vector for each position in the sequence.
     * @return The clique values.
     */
    public ChainHelper.ChainCliqueValues getCliqueValues(SGDVector[] features) {
        DenseVector[] localScores = getLocalScores(features);
        return new ChainHelper.ChainCliqueValues(localScores, labelLabelWeights);
    }

    /**
     * Predicts a label index per position using {@link ChainHelper#viterbi}.
     * @param features The feature vector for each position in the sequence.
     * @return The decoded label indices.
     */
    public int[] predict(SGDVector[] features) {
        ChainHelper.ChainViterbiResults result = ChainHelper.viterbi(getCliqueValues(features));
        return result.mapValues;
    }

    /**
     * Computes per-position label marginals from belief propagation:
     * {@code exp(alpha_i + beta_i - logZ)}, via {@code expNormalize}.
     * @param features The feature vector for each position in the sequence.
     * @return A marginal vector per position.
     */
    public DenseVector[] predictMarginals(SGDVector[] features) {
        ChainHelper.ChainBPResults result = ChainHelper.beliefPropagation(getCliqueValues(features));
        DenseVector[] marginals = new DenseVector[features.length];
        for (int i = 0; i < features.length; i++) {
            marginals[i] = result.alphas[i].add(result.betas[i]);
            marginals[i].expNormalize(result.logZ);
        }
        return marginals;
    }

    /**
     * Computes a confidence for each chunk as {@code exp(constrainedScore - logZ)}. Here
     * {@code constrainedScore} comes from {@link ChainHelper#constrainedBeliefPropagation}
     * with the chunk's labels fixed (all other positions unconstrained, {@code -1}).
     * @param features The feature vector for each position in the sequence.
     * @param chunks The chunks to score.
     * @return One confidence value per chunk, in the same order as {@code chunks}.
     */
    public List<Double> predictConfidenceUsingCBP(SGDVector[] features, List<Chunk> chunks) {
        ChainHelper.ChainCliqueValues cliqueValues = getCliqueValues(features);
        double logZ = ChainHelper.beliefPropagation(cliqueValues).logZ;
        int[] constraints = new int[features.length];
        List<Double> output = new ArrayList<>(chunks.size());
        for (Chunk chunk : chunks) {
            // Reset to "unconstrained" before applying this chunk's labels.
            Arrays.fill(constraints, -1);
            chunk.unpack(constraints);
            double chunkScore = ChainHelper.constrainedBeliefPropagation(cliqueValues, constraints);
            output.add(Math.exp(chunkScore - logZ));
        }
        return output;
    }

    /**
     * Computes the log-likelihood of {@code labels} under the model and its gradient.
     * <p>
     * The value is the sum of the gold local scores and gold transition weights, minus
     * {@code logZ}. The gradient is {@code [biases, featureLabelWeights, labelLabelWeights]};
     * each entry is the observed indicators minus the model marginals/expectations.
     * @param features The feature vector for each position in the sequence.
     * @param labels The gold label index for each position.
     * @return A pair of (log-likelihood, gradient array).
     * @throws IllegalStateException If the sequence mixes dense and sparse feature vectors.
     */
    public Pair<Double, Tensor[]> valueAndGradient(SGDVector[] features, int[] labels) {
        ChainHelper.ChainCliqueValues scores = getCliqueValues(features);
        ChainHelper.ChainBPResults bpResults = ChainHelper.beliefPropagation(scores);
        double logZ = bpResults.logZ;
        DenseVector[] alphas = bpResults.alphas;
        DenseVector[] betas = bpResults.betas;

        DenseVector biasGradient = new DenseVector(biases.size());
        DenseMatrix transGradient = new DenseMatrix(numLabels, numLabels);
        // Whether a position's outer product is sparse or dense depends on its feature vector type.
        DenseSparseMatrix[] featureGradients = new DenseSparseMatrix[features.length];
        DenseMatrix denseFeatureGradients = null;
        boolean sparseFeatures = false;

        double score = -logZ;
        for (int i = 0; i < features.length; i++) {
            int curLabel = labels[i];
            DenseVector curLocalScores = scores.localValues[i];
            DenseVector curBeta = betas[i];
            score += curLocalScores.get(curLabel);

            // Gold indicator minus the model marginal at position i.
            DenseVector localMarginal = alphas[i].add(curBeta);
            localMarginal.expNormalize(logZ);
            localMarginal.scaleInPlace(-1.0);
            localMarginal.add(curLabel, 1.0);
            biasGradient.intersectAndAddInPlace(localMarginal);

            Matrix tmpFeatureGradient = localMarginal.outer(features[i]);
            if (tmpFeatureGradient instanceof DenseSparseMatrix) {
                featureGradients[i] = (DenseSparseMatrix) tmpFeatureGradient;
                sparseFeatures = true;
            } else if (denseFeatureGradients == null) {
                denseFeatureGradients = (DenseMatrix) tmpFeatureGradient;
            } else {
                denseFeatureGradients.intersectAndAddInPlace(tmpFeatureGradient);
            }

            if (i > 0) {
                // Subtract the expected pairwise (prev, cur) transition counts.
                DenseVector prevAlpha = alphas[i - 1];
                for (int prev = 0; prev < numLabels; prev++) {
                    double prevAlphaVal = prevAlpha.get(prev);
                    for (int cur = 0; cur < numLabels; cur++) {
                        double update = -Math.exp(prevAlphaVal + labelLabelWeights.get(prev, cur) + curBeta.get(cur) + curLocalScores.get(cur) - logZ);
                        transGradient.add(prev, cur, update);
                    }
                }
                // Add the observed transition.
                int prevLabel = labels[i - 1];
                score += labelLabelWeights.get(prevLabel, curLabel);
                transGradient.add(prevLabel, curLabel, 1.0);
            }
        }

        Matrix featureGradient;
        if (sparseFeatures) {
            // Check before merging: with mixed inputs featureGradients contains null slots.
            if (denseFeatureGradients != null) {
                throw new IllegalStateException("Mixture of dense and sparse features found.");
            }
            featureGradient = merger.merge(featureGradients);
        } else {
            featureGradient = denseFeatureGradients;
        }

        Tensor[] gradient = new Tensor[]{biasGradient, featureGradient, transGradient};
        return new Pair<>(score, gradient);
    }

    /**
     * Creates zero-valued tensors with the same shapes as the parameters.
     * @return {@code [biases, featureLabelWeights, labelLabelWeights]}-shaped zero tensors.
     */
    @Override
    public Tensor[] getEmptyCopy() {
        return new Tensor[]{
            new DenseVector(biases.size()),
            new DenseMatrix(featureLabelWeights.getDimension1Size(), featureLabelWeights.getDimension2Size()),
            new DenseMatrix(labelLabelWeights.getDimension1Size(), labelLabelWeights.getDimension2Size())
        };
    }

    /**
     * Returns the live parameter array (not a copy); mutating its tensors mutates these parameters.
     * @return The parameter tensors.
     */
    @Override
    public Tensor[] get() {
        return weights;
    }

    /**
     * Replaces the parameter tensors.
     * <p>
     * If {@code newWeights} does not have exactly three entries the call is silently ignored.
     * The entries are cast to {@link DenseVector}/{@link DenseMatrix}, so wrong types throw
     * {@link ClassCastException}.
     * @param newWeights The new {@code [biases, featureLabelWeights, labelLabelWeights]}.
     */
    @Override
    public void set(Tensor[] newWeights) {
        if (newWeights.length == weights.length) {
            weights = newWeights;
            biases = (DenseVector) weights[0];
            featureLabelWeights = (DenseMatrix) weights[1];
            labelLabelWeights = (DenseMatrix) weights[2];
        }
    }

    /**
     * Adds each gradient tensor in place to the corresponding parameter tensor.
     * @param gradients The updates, in {@link #get()} order.
     */
    @Override
    public void update(Tensor[] gradients) {
        for (int i = 0; i < gradients.length; i++) {
            weights[i].intersectAndAddInPlace(gradients[i]);
        }
    }

    /**
     * Sums a batch of gradients into a single update.
     * <p>
     * Sparse feature gradients are combined with the {@link HeapMerger}. Dense ones are
     * accumulated in place into the first dense gradient encountered, so that input tensor is
     * mutated and returned. If both kinds are present, the merged sparse sum is added into the
     * dense accumulator. If {@code gradients} is empty, the feature/label slot of the result is
     * {@code null}.
     * @param gradients The gradients to merge, each in {@link #get()} order.
     * @param size The expected number of gradients (used as a capacity hint).
     * @return The merged update.
     */
    @Override
    public Tensor[] merge(Tensor[][] gradients, int size) {
        DenseVector biasUpdate = new DenseVector(biases.size());
        List<DenseSparseMatrix> sparseUpdates = new ArrayList<>(size);
        DenseMatrix denseUpdates = null;
        DenseMatrix labelLabelUpdate = new DenseMatrix(labelLabelWeights.getDimension1Size(), labelLabelWeights.getDimension2Size());

        for (Tensor[] gradient : gradients) {
            biasUpdate.intersectAndAddInPlace(gradient[0]);
            Matrix tmpUpdate = (Matrix) gradient[1];
            if (tmpUpdate instanceof DenseSparseMatrix) {
                sparseUpdates.add((DenseSparseMatrix) tmpUpdate);
            } else if (denseUpdates == null) {
                denseUpdates = (DenseMatrix) tmpUpdate;
            } else {
                denseUpdates.intersectAndAddInPlace(tmpUpdate);
            }
            labelLabelUpdate.intersectAndAddInPlace(gradient[2]);
        }

        Matrix featureLabelUpdate;
        if (!sparseUpdates.isEmpty()) {
            DenseSparseMatrix sparseSum = merger.merge(sparseUpdates.toArray(new DenseSparseMatrix[0]));
            if (denseUpdates != null) {
                denseUpdates.intersectAndAddInPlace(sparseSum);
                featureLabelUpdate = denseUpdates;
            } else {
                featureLabelUpdate = sparseSum;
            }
        } else {
            featureLabelUpdate = denseUpdates;
        }
        return new Tensor[]{biasUpdate, featureLabelUpdate, labelLabelUpdate};
    }
}

