/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.common.sgd;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.SplittableRandom;
import org.tribuo.common.sgd.protos.FMParametersProto;
import org.tribuo.math.FeedForwardParameters;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseSparseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.math.protos.ParametersProto;
import org.tribuo.math.protos.TensorProto;
import org.tribuo.math.util.HeapMerger;
import org.tribuo.math.util.Merger;
import org.tribuo.protos.ProtoUtil;

/**
 * {@link FeedForwardParameters} for a factorization machine.
 * <p>
 * The weights are held in a single {@link Tensor} array laid out as:
 * <ul>
 *     <li>index 0: the bias {@link DenseVector}, one element per output;</li>
 *     <li>index 1: the linear weight {@link DenseMatrix} of shape [numOutputs, numFeatures];</li>
 *     <li>index 2 onwards: one factor {@link DenseMatrix} per output, each of shape
 *     [numFactors, numFeatures].</li>
 * </ul>
 * {@link #biasVector} and {@link #weightMatrix} are typed aliases of {@code weights[0]} and
 * {@code weights[1]} and are refreshed whenever the array is replaced.
 */
public final class FMParameters
implements FeedForwardParameters {
    private static final long serialVersionUID = 1L;

    /** The protobuf serialization version written and accepted by this class. */
    public static final int CURRENT_VERSION = 0;

    /** Merger used to combine sparse factor gradients across a minibatch. */
    private static final Merger merger = new HeapMerger();

    private Tensor[] weights;
    private DenseVector biasVector;
    private DenseMatrix weightMatrix;
    private final int numFactors;

    /**
     * Constructs factorization machine parameters with randomly initialised factor matrices.
     * <p>
     * The bias vector and linear weight matrix are freshly allocated and not randomly initialised.
     * Every factor-matrix entry is a standard normal sample multiplied by {@code variance}, so
     * {@code variance} effectively acts as the standard deviation of the initial factors.
     *
     * @param rng The RNG used to seed each factor matrix's initialisation.
     * @param numFeatures The number of input features.
     * @param numLabels The number of outputs.
     * @param numFactors The number of latent factors per output.
     * @param variance The multiplier applied to standard normal samples for the factor matrices.
     */
    public FMParameters(SplittableRandom rng, int numFeatures, int numLabels, int numFactors, double variance) {
        this.numFactors = numFactors;
        this.biasVector = new DenseVector(numLabels);
        this.weightMatrix = new DenseMatrix(numLabels, numFeatures);
        this.weights = new Tensor[numLabels + 2];
        this.weights[0] = biasVector;
        this.weights[1] = weightMatrix;
        for (int label = 0; label < numLabels; label++) {
            DenseMatrix factors = new DenseMatrix(numFactors, numFeatures);
            initializeMatrix(rng, variance, factors);
            this.weights[label + 2] = factors;
        }
    }

    /**
     * Wraps an already validated weight array (see {@link #deserializeFromProto}) without copying it.
     *
     * @param weights The weights, laid out as described in the class documentation.
     * @param numFactors The number of latent factors per output.
     */
    private FMParameters(Tensor[] weights, int numFactors) {
        this.weights = weights;
        this.biasVector = (DenseVector) weights[0];
        this.weightMatrix = (DenseMatrix) weights[1];
        this.numFactors = numFactors;
    }

    /**
     * Deserializes the parameters from the supplied protobuf.
     *
     * @param version The serialized object version.
     * @param className The class name (unused).
     * @param message The serialized data, expected to be an {@link FMParametersProto}.
     * @return The deserialized parameters.
     * @throws InvalidProtocolBufferException If the message does not contain an {@link FMParametersProto}.
     * @throws IllegalArgumentException If the version is unsupported or the tensors are missing,
     *                                  of the wrong type, or of inconsistent shapes.
     */
    public static FMParameters deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        FMParametersProto proto = message.unpack(FMParametersProto.class);
        int numFactors = proto.getNumFactors();
        List<TensorProto> tensorProtoList = proto.getWeightsList();
        if (tensorProtoList.isEmpty()) {
            // Without this guard tensors[0] below would throw ArrayIndexOutOfBoundsException.
            throw new IllegalArgumentException("Invalid protobuf, expected bias vector, found no weight tensors");
        }
        Tensor[] tensors = new Tensor[tensorProtoList.size()];
        for (int i = 0; i < tensors.length; i++) {
            tensors[i] = ProtoUtil.deserialize(tensorProtoList.get(i));
        }

        if (!(tensors[0] instanceof DenseVector)) {
            throw new IllegalArgumentException("Invalid protobuf, expected bias vector found " + tensors[0].getClass());
        }
        int numOutputs = ((DenseVector) tensors[0]).size();
        if (numOutputs + 2 != tensors.length) {
            throw new IllegalArgumentException("Invalid protobuf, expected " + (numOutputs + 2) + " weight tensors, found " + tensors.length);
        }
        if (!(tensors[1] instanceof DenseMatrix)) {
            throw new IllegalArgumentException("Invalid protobuf, expected DenseMatrix, found " + tensors[1].getClass());
        }
        DenseMatrix weightMatrix = (DenseMatrix) tensors[1];
        int numFeatures = weightMatrix.getDimension2Size();
        if (weightMatrix.getDimension1Size() != numOutputs) {
            throw new IllegalArgumentException("Invalid protobuf, expected weight matrix of shape [" + numOutputs + "," + numFeatures + "], found " + Arrays.toString(weightMatrix.getShape()));
        }
        for (int i = 2; i < tensors.length; i++) {
            // predict/gradients cast every factor tensor to DenseMatrix, so reject anything else here
            // rather than failing later with a ClassCastException.
            if (!(tensors[i] instanceof DenseMatrix)) {
                throw new IllegalArgumentException("Invalid protobuf, expected DenseMatrix factor matrix, found " + tensors[i].getClass());
            }
            DenseMatrix factors = (DenseMatrix) tensors[i];
            if (factors.getDimension1Size() != numFactors || factors.getDimension2Size() != numFeatures) {
                throw new IllegalArgumentException("Invalid protobuf, expected factor matrix of shape [" + numFactors + ", " + numFeatures + "], found " + Arrays.toString(factors.getShape()));
            }
        }
        return new FMParameters(tensors, numFactors);
    }

    /**
     * Serializes these parameters into a {@link ParametersProto} wrapping an {@link FMParametersProto}.
     *
     * @return The serialized parameters.
     */
    @Override
    public ParametersProto serialize() {
        FMParametersProto.Builder fmParamsBuilder = FMParametersProto.newBuilder();
        fmParamsBuilder.setNumFactors(numFactors);
        for (Tensor weight : weights) {
            fmParamsBuilder.addWeights((TensorProto) weight.serialize());
        }
        ParametersProto.Builder builder = ParametersProto.newBuilder();
        builder.setVersion(CURRENT_VERSION);
        builder.setClassName(FMParameters.class.getName());
        builder.setSerializedData(Any.pack(fmParamsBuilder.build()));
        return builder.build();
    }

    /**
     * Fills {@code matrix} in place with standard normal samples multiplied by {@code variance}.
     * A fresh {@link Random} seeded from {@code rng} is used for each matrix.
     *
     * @param rng The source of the seed.
     * @param variance The multiplier applied to each standard normal sample.
     * @param matrix The matrix to overwrite.
     */
    private static void initializeMatrix(SplittableRandom rng, double variance, DenseMatrix matrix) {
        Random innerRNG = new Random(rng.nextLong());
        int dim1 = matrix.getDimension1Size();
        int dim2 = matrix.getDimension2Size();
        for (int i = 0; i < dim1; i++) {
            for (int j = 0; j < dim2; j++) {
                matrix.set(i, j, innerRNG.nextGaussian() * variance);
            }
        }
    }

    /**
     * Computes the factorization machine prediction for each output.
     * <p>
     * For output o this is {@code bias[o] + W[o]·x + 0.5 * sum_f((sum_j v_fj x_j)^2 - sum_j (v_fj x_j)^2)},
     * where v is that output's factor matrix.
     *
     * @param example The input features.
     * @return A new vector with one prediction per output.
     */
    @Override
    public DenseVector predict(SGDVector example) {
        DenseVector pred = weightMatrix.leftMultiply(example);
        pred.intersectAndAddInPlace(biasVector);
        DenseVector factorizedPred = new DenseVector(biasVector.size());
        for (int i = 2; i < weights.length; i++) {
            DenseMatrix curMatrix = (DenseMatrix) weights[i];
            double curValue = 0.0;
            for (int k = 0; k < numFactors; k++) {
                double sum = 0.0;
                double sumOfSquares = 0.0;
                for (VectorTuple v : example) {
                    double value = curMatrix.get(k, v.index) * v.value;
                    sum += value;
                    sumOfSquares += value * value;
                }
                curValue += sum * sum - sumOfSquares;
            }
            factorizedPred.set(i - 2, curValue / 2.0);
        }
        pred.intersectAndAddInPlace(factorizedPred);
        return pred;
    }

    /**
     * Computes the per-tensor gradients for one example.
     * <p>
     * The returned array matches the weight layout. Its entries are the densified output gradient,
     * the outer product of the output gradient with the features, and one factor gradient per output.
     * An output whose gradient is exactly zero gets an empty {@link DenseSparseMatrix}.
     *
     * @param score The loss (first element) and the gradient with respect to each output (second element).
     * @param features The input features.
     * @return The gradients, aligned with {@link #get()}.
     */
    @Override
    public Tensor[] gradients(Pair<Double, SGDVector> score, SGDVector features) {
        Tensor[] gradients = new Tensor[weights.length];
        SGDVector outputGradient = score.getB();
        gradients[0] = outputGradient instanceof SparseVector ? ((SparseVector) outputGradient).densify() : outputGradient.copy();
        gradients[1] = outputGradient.outer(features);
        for (int i = 2; i < weights.length; i++) {
            double curOutputGradient = outputGradient.get(i - 2);
            DenseMatrix curFactors = (DenseMatrix) weights[i];
            if (curOutputGradient != 0.0) {
                gradients[i] = factorGradient(curFactors, features, curOutputGradient);
            } else {
                gradients[i] = new DenseSparseMatrix(numFactors, features.size());
            }
        }
        return gradients;
    }

    /**
     * Computes one output's factor-matrix gradient.
     * <p>
     * Entry (f, j) is {@code (x_j * sum_l(v_fl x_l) - v_fj * x_j^2) * outputGradient}. Sparse features
     * yield a {@link DenseSparseMatrix} touching only the active features; dense features yield a
     * {@link DenseMatrix}.
     *
     * @param curFactors The factor matrix for this output.
     * @param features The input features.
     * @param outputGradient The (non-zero) gradient of the loss with respect to this output.
     * @return The factor gradient.
     */
    private Tensor factorGradient(DenseMatrix curFactors, SGDVector features, double outputGradient) {
        // factorSum[f] = sum_l v_fl * x_l, shared by every feature's gradient for factor f.
        DenseVector factorSum = curFactors.leftMultiply(features);
        if (features instanceof SparseVector) {
            SparseVector sparseFeatures = (SparseVector) features;
            List<SparseVector> rows = new ArrayList<>(numFactors);
            for (int f = 0; f < numFactors; f++) {
                // Transform a private copy of the features before building the matrix, so this does
                // not depend on whether DenseSparseMatrix copies rows or hands out views of them.
                SparseVector row = sparseFeatures.copy();
                double curFactorSum = factorSum.get(f);
                int factor = f;
                row.foreachIndexedInPlace((idx, value) -> value * curFactorSum - curFactors.get(factor, idx) * value * value);
                row.scaleInPlace(outputGradient);
                rows.add(row);
            }
            return new DenseSparseMatrix(rows);
        }
        int numFeatures = features.size();
        DenseMatrix gradient = new DenseMatrix(numFactors, numFeatures);
        for (int f = 0; f < numFactors; f++) {
            double curFactorSum = factorSum.get(f);
            for (int k = 0; k < numFeatures; k++) {
                double value = features.get(k);
                gradient.set(f, k, (value * curFactorSum - curFactors.get(f, k) * value * value) * outputGradient);
            }
        }
        return gradient;
    }

    /**
     * Builds a freshly allocated tensor array with the same shapes as the weights.
     *
     * @return A new tensor array shaped like {@link #get()}.
     */
    @Override
    public Tensor[] getEmptyCopy() {
        Tensor[] output = new Tensor[weights.length];
        output[0] = new DenseVector(biasVector.size());
        output[1] = new DenseMatrix(weightMatrix.getDimension1Size(), weightMatrix.getDimension2Size());
        for (int i = 2; i < weights.length; i++) {
            DenseMatrix curMatrix = (DenseMatrix) weights[i];
            output[i] = new DenseMatrix(curMatrix.getDimension1Size(), curMatrix.getDimension2Size());
        }
        return output;
    }

    /**
     * Returns the live weight array (not a copy).
     *
     * @return The weights.
     */
    @Override
    public Tensor[] get() {
        return weights;
    }

    /**
     * Replaces the weight array and refreshes the typed bias/weight aliases.
     * The call is silently ignored if {@code newWeights} has a different length from the current array.
     *
     * @param newWeights The replacement weights.
     */
    @Override
    public void set(Tensor[] newWeights) {
        if (newWeights.length == weights.length) {
            weights = newWeights;
            biasVector = (DenseVector) weights[0];
            weightMatrix = (DenseMatrix) weights[1];
        }
    }

    /**
     * Adds each gradient into the corresponding weight tensor in place.
     *
     * @param gradients The gradients to add, aligned with {@link #get()}.
     */
    @Override
    public void update(Tensor[] gradients) {
        for (int i = 0; i < gradients.length; i++) {
            weights[i].intersectAndAddInPlace(gradients[i]);
        }
    }

    /**
     * Sums the first {@code size} gradient arrays.
     * <p>
     * Dense gradients are accumulated into {@code gradients[0][i]}, which is mutated in place.
     * Sparse factor gradients are combined with the {@link Merger}. If a sparse factor gradient
     * shares a slot with a dense one (dense features where some outputs had zero gradient), that
     * slot is accumulated into a new {@link DenseMatrix} instead.
     *
     * @param gradients The per-example gradients.
     * @param size The number of leading entries of {@code gradients} to merge.
     * @return The merged gradients.
     * @throws IllegalStateException If a gradient has an unexpected type.
     */
    @Override
    public Tensor[] merge(Tensor[][] gradients, int size) {
        Tensor[] output = new Tensor[weights.length];
        for (int i = 0; i < weights.length; i++) {
            Tensor first = gradients[0][i];
            if (first instanceof DenseVector || first instanceof DenseMatrix) {
                for (int j = 1; j < size; j++) {
                    first.intersectAndAddInPlace(gradients[j][i]);
                }
                output[i] = first;
            } else if (first instanceof DenseSparseMatrix) {
                output[i] = mergeSparseSlot(gradients, size, i, (DenseSparseMatrix) first);
            } else {
                throw new IllegalStateException("Unexpected gradient type, expected DenseVector, DenseMatrix or DenseSparseMatrix, received " + gradients[0][i].getClass().getName());
            }
        }
        return output;
    }

    /**
     * Merges gradient slot {@code index} when its first entry is a {@link DenseSparseMatrix}.
     *
     * @param gradients The per-example gradients.
     * @param size The number of leading entries to merge.
     * @param index The weight slot being merged.
     * @param first The first gradient in this slot, used for its shape.
     * @return The merged gradient: sparse if every entry is sparse, otherwise a dense accumulation.
     */
    private static Tensor mergeSparseSlot(Tensor[][] gradients, int size, int index, DenseSparseMatrix first) {
        DenseSparseMatrix[] updates = new DenseSparseMatrix[size];
        for (int j = 0; j < size; j++) {
            Tensor cur = gradients[j][index];
            if (!(cur instanceof DenseSparseMatrix)) {
                // Mixed sparse/dense slot: accumulate everything densely rather than casting.
                DenseMatrix accumulator = new DenseMatrix(first.getDimension1Size(), first.getDimension2Size());
                for (int k = 0; k < size; k++) {
                    accumulator.intersectAndAddInPlace(gradients[k][index]);
                }
                return accumulator;
            }
            updates[j] = (DenseSparseMatrix) cur;
        }
        return merger.merge(updates);
    }

    /**
     * Deep copies these parameters; each weight tensor is copied.
     *
     * @return An independent copy.
     */
    @Override
    public FMParameters copy() {
        Tensor[] weightCopy = new Tensor[weights.length];
        for (int i = 0; i < weights.length; i++) {
            weightCopy[i] = weights[i].copy();
        }
        return new FMParameters(weightCopy, numFactors);
    }
}

