/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.math.distributions;

import java.util.Arrays;
import java.util.Optional;
import java.util.Random;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseSparseMatrix;
import org.tribuo.math.la.DenseVector;

public final class MultivariateNormalDistribution {
    private final long seed;
    private final Random rng;
    private final DenseVector means;
    private final DenseMatrix covariance;
    private final DenseMatrix samplingCovariance;
    private final boolean eigenDecomposition;

    public MultivariateNormalDistribution(double[] means, double[][] covariance, long seed) {
        this(DenseVector.createDenseVector(means), DenseMatrix.createDenseMatrix(covariance), seed);
    }

    public MultivariateNormalDistribution(double[] means, double[][] covariance, long seed, boolean eigenDecomposition) {
        this(DenseVector.createDenseVector(means), DenseMatrix.createDenseMatrix(covariance), seed, eigenDecomposition);
    }

    public MultivariateNormalDistribution(DenseVector means, DenseMatrix covariance, long seed) {
        this(means, covariance, seed, false);
    }

    /*
     * Enabled force condition propagation
     * Lifted jumps to return sites
     */
    public MultivariateNormalDistribution(DenseVector means, DenseMatrix covariance, long seed, boolean eigenDecomposition) {
        this.seed = seed;
        this.rng = new Random(seed);
        this.means = means.copy();
        this.covariance = covariance.copy();
        if (this.covariance.getDimension1Size() != this.means.size() || this.covariance.getDimension2Size() != this.means.size()) {
            throw new IllegalArgumentException("Covariance matrix must be square and the same dimension as the mean vector. Mean vector size = " + means.size() + ", covariance size = " + Arrays.toString(this.covariance.getShape()));
        }
        this.eigenDecomposition = eigenDecomposition;
        if (eigenDecomposition) {
            Optional<DenseMatrix.EigenDecomposition> factorization = this.covariance.eigenDecomposition();
            if (!factorization.isPresent() || !factorization.get().positiveEigenvalues()) throw new IllegalArgumentException("Covariance matrix is not positive definite.");
            DenseVector eigenvalues = factorization.get().eigenvalues();
            DenseMatrix eigenvectors = new DenseMatrix(factorization.get().eigenvectors());
            eigenvalues.foreachInPlace(Math::sqrt);
            DenseSparseMatrix diagonal = DenseSparseMatrix.createDiagonal(eigenvalues);
            this.samplingCovariance = eigenvectors.matrixMultiply(diagonal).matrixMultiply(eigenvectors, false, true);
            return;
        } else {
            Optional<DenseMatrix.CholeskyFactorization> factorization = this.covariance.choleskyFactorization();
            if (!factorization.isPresent()) throw new IllegalArgumentException("Covariance matrix is not positive definite.");
            this.samplingCovariance = factorization.get().lMatrix();
        }
    }

    public DenseVector sampleVector() {
        DenseVector sampled = new DenseVector(this.means.size());
        for (int i = 0; i < this.means.size(); ++i) {
            sampled.set(i, this.rng.nextGaussian());
        }
        sampled = this.samplingCovariance.leftMultiply(sampled);
        return this.means.add(sampled);
    }

    public double[] sampleArray() {
        return this.sampleVector().toArray();
    }

    public String toString() {
        return "MultivariateNormal(mean=" + this.means + ",covariance=" + this.covariance + ",seed=" + this.seed + ",useEigenDecomposition=" + this.eigenDecomposition + ")";
    }
}

