/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.math.optimisers.util;

import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.util.function.DoubleUnaryOperator;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.Matrix;
import org.tribuo.math.la.MatrixIterator;
import org.tribuo.math.la.MatrixTuple;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.math.optimisers.util.ShrinkingTensor;
import org.tribuo.math.protos.DenseTensorProto;
import org.tribuo.math.protos.ShrinkingDenseTensorProto;
import org.tribuo.math.protos.TensorProto;

/**
 * A {@link DenseMatrix} whose logical elements are the stored values multiplied by a
 * shared scalar multiplier.
 * <p>
 * Scaling the whole matrix only changes the multiplier, so it does not touch the
 * individual elements. When the multiplier's magnitude drops below a small tolerance
 * it is folded back into the stored values and reset to one.
 * <p>
 * The squared two-norm of the logical elements is cached and maintained incrementally
 * by {@link #scaleInPlace(double)} and
 * {@link #intersectAndAddInPlace(Tensor, DoubleUnaryOperator)}.
 */
public class ShrinkingMatrix extends DenseMatrix implements ShrinkingTensor {
    /** Serialization version written by {@link #serialize()} and the only one accepted on read. */
    private static final int CURRENT_VERSION = 0;

    /** Multiplier magnitudes below this are folded back into the stored values. */
    private static final double TOLERANCE = 1.0E-6;

    private final double baseRate;
    private final double lambdaSqrt;
    private final boolean scaleShrinking;
    private final boolean reproject;
    private double squaredTwoNorm;
    private int iteration;
    private double multiplier;

    /**
     * Creates a shrinking copy of the supplied matrix which does not reproject after updates.
     *
     * @param v the matrix to copy
     * @param baseRate the rate subtracted from one to form the per-update shrinkage factor
     * @param scaleShrinking if true, the rate is divided by the iteration count before use
     */
    public ShrinkingMatrix(DenseMatrix v, double baseRate, boolean scaleShrinking) {
        super(v);
        this.baseRate = baseRate;
        this.scaleShrinking = scaleShrinking;
        this.lambdaSqrt = 0.0;
        this.reproject = false;
        // The cached norm must describe the copied contents, which need not be all zeros.
        this.squaredTwoNorm = sumOfSquares(this.values);
        this.iteration = 1;
        this.multiplier = 1.0;
    }

    /**
     * Creates a shrinking copy of the supplied matrix which scales the shrinkage by the
     * iteration count and reprojects after each update using {@code sqrt(lambda)}.
     *
     * @param v the matrix to copy
     * @param baseRate the rate divided by the iteration count to form the per-update shrinkage
     * @param lambda the value whose square root bounds the norm during reprojection
     */
    public ShrinkingMatrix(DenseMatrix v, double baseRate, double lambda) {
        super(v);
        this.baseRate = baseRate;
        this.scaleShrinking = true;
        this.lambdaSqrt = Math.sqrt(lambda);
        this.reproject = true;
        // The cached norm must describe the copied contents, which need not be all zeros.
        this.squaredTwoNorm = sumOfSquares(this.values);
        this.iteration = 1;
        this.multiplier = 1.0;
    }

    /**
     * Rebuilds a matrix from its complete serialized state.
     *
     * @throws IllegalStateException if {@code reproject} is false but {@code lambdaSqrt} is non-zero
     * @throws IllegalArgumentException if {@code iteration} is negative
     */
    private ShrinkingMatrix(DenseMatrix v, double baseRate, boolean scaleShrinking, double lambdaSqrt, boolean reproject, double squaredTwoNorm, int iteration, double multiplier) {
        super(v);
        if (!reproject && lambdaSqrt != 0.0) {
            throw new IllegalStateException("Invalid ShrinkingMatrix, when reproject is false lambda must be zero");
        }
        if (iteration < 0) {
            throw new IllegalArgumentException("Invalid ShrinkingMatrix, iteration must be non-negative");
        }
        this.baseRate = baseRate;
        this.scaleShrinking = scaleShrinking;
        this.lambdaSqrt = lambdaSqrt;
        this.reproject = reproject;
        this.squaredTwoNorm = squaredTwoNorm;
        this.iteration = iteration;
        this.multiplier = multiplier;
    }

    /**
     * Returns the sum of the squares of every element in the supplied array.
     */
    private static double sumOfSquares(double[][] values) {
        double sum = 0.0;
        for (double[] row : values) {
            for (double element : row) {
                sum += element * element;
            }
        }
        return sum;
    }

    /**
     * Deserialization factory.
     *
     * @param version the serialized object version
     * @param className the class name (not used by this method)
     * @param message the serialized data, expected to wrap a {@code ShrinkingDenseTensorProto}
     * @return the deserialized matrix
     * @throws InvalidProtocolBufferException if the message cannot be unpacked
     * @throws IllegalArgumentException if the version is not supported
     */
    public static ShrinkingMatrix deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        ShrinkingDenseTensorProto proto = message.unpack(ShrinkingDenseTensorProto.class);
        DenseMatrix data = DenseMatrix.unpackProto(proto.getData());
        return new ShrinkingMatrix(data, proto.getBaseRate(), proto.getScaleShrinking(), proto.getLambdaSqrt(), proto.getReproject(), proto.getSquaredTwoNorm(), proto.getIteration(), proto.getMultiplier());
    }

    /**
     * Serializes this matrix. The stored values are written row by row as little-endian
     * doubles without the multiplier applied; the multiplier and the remaining state are
     * written as separate fields.
     *
     * @return the serialized form
     */
    @Override
    public TensorProto serialize() {
        DenseTensorProto.Builder dataBuilder = DenseTensorProto.newBuilder();
        dataBuilder.addDimensions(this.dim1);
        dataBuilder.addDimensions(this.dim2);
        ByteBuffer buffer = ByteBuffer.allocate(this.dim1 * this.dim2 * Double.BYTES).order(ByteOrder.LITTLE_ENDIAN);
        // The double view has its own position, so filling it leaves the byte buffer ready to copy in full.
        DoubleBuffer doubleBuffer = buffer.asDoubleBuffer();
        for (double[] row : this.values) {
            doubleBuffer.put(row);
        }
        doubleBuffer.rewind();
        dataBuilder.setValues(ByteString.copyFrom(buffer));

        ShrinkingDenseTensorProto.Builder shrinkingBuilder = ShrinkingDenseTensorProto.newBuilder();
        shrinkingBuilder.setData(dataBuilder.build());
        shrinkingBuilder.setBaseRate(this.baseRate);
        shrinkingBuilder.setLambdaSqrt(this.lambdaSqrt);
        shrinkingBuilder.setScaleShrinking(this.scaleShrinking);
        shrinkingBuilder.setReproject(this.reproject);
        shrinkingBuilder.setSquaredTwoNorm(this.squaredTwoNorm);
        shrinkingBuilder.setIteration(this.iteration);
        shrinkingBuilder.setMultiplier(this.multiplier);

        TensorProto.Builder builder = TensorProto.newBuilder();
        builder.setVersion(CURRENT_VERSION);
        builder.setClassName(ShrinkingMatrix.class.getName());
        builder.setSerializedData(Any.pack(shrinkingBuilder.build()));
        return builder.build();
    }

    /**
     * Returns a plain {@link DenseMatrix} built by passing this matrix to the
     * {@code DenseMatrix} copy constructor.
     *
     * @return a dense copy
     */
    @Override
    public DenseMatrix convertToDense() {
        return new DenseMatrix(this);
    }

    /**
     * Multiplies this matrix by the supplied vector, reading elements through
     * {@link #get(int, int)} so the multiplier is applied.
     *
     * @param input the vector to multiply, whose size must equal the second dimension
     * @return a dense vector with one entry per row of this matrix
     * @throws IllegalArgumentException if the input size does not match the second dimension
     */
    @Override
    public DenseVector leftMultiply(SGDVector input) {
        if (input.size() != this.dim2) {
            throw new IllegalArgumentException("input.size() != dim2");
        }
        double[] output = new double[this.dim1];
        for (VectorTuple tuple : input) {
            for (int i = 0; i < output.length; ++i) {
                output[i] += this.get(i, tuple.index) * tuple.value;
            }
        }
        return DenseVector.createDenseVector(output);
    }

    /**
     * Shrinks this matrix, then adds {@code f} applied to each element of {@code other}
     * to the matching element of this matrix, optionally rescales so the two-norm does
     * not exceed {@code 1 / lambdaSqrt}, and finally increments the iteration counter.
     *
     * @param other the tensor to add, which must be a {@link Matrix} of the same dimensions
     * @param f the function applied to each element of {@code other} before it is added
     * @throws IllegalStateException if {@code other} is not a matrix or the dimensions differ
     */
    @Override
    public void intersectAndAddInPlace(Tensor other, DoubleUnaryOperator f) {
        if (!(other instanceof Matrix)) {
            throw new IllegalStateException("Adding a non-Matrix to a Matrix");
        }
        Matrix otherMat = (Matrix) other;
        if (this.dim1 != otherMat.getDimension1Size() || this.dim2 != otherMat.getDimension2Size()) {
            throw new IllegalStateException("Matrices are not the same size, this(" + this.dim1 + "," + this.dim2 + "), other(" + otherMat.getDimension1Size() + "," + otherMat.getDimension2Size() + ")");
        }

        double shrinkage = this.scaleShrinking ? 1.0 - this.baseRate / (double) this.iteration : 1.0 - this.baseRate;
        this.scaleInPlace(shrinkage);

        for (MatrixTuple tuple : otherMat) {
            double update = f.applyAsDouble(tuple.value);
            double oldValue = this.values[tuple.i][tuple.j] * this.multiplier;
            double newValue = oldValue + update;
            // Swap this element's contribution to the cached squared norm.
            this.squaredTwoNorm -= oldValue * oldValue;
            this.squaredTwoNorm += newValue * newValue;
            // Store the value so that multiplying by the multiplier recovers newValue.
            this.values[tuple.i][tuple.j] = newValue / this.multiplier;
        }

        if (this.reproject) {
            double projectionNormaliser = 1.0 / this.lambdaSqrt / this.twoNorm();
            // A zero norm yields an infinite normaliser, which fails this test and leaves the matrix alone.
            if (projectionNormaliser < 1.0) {
                this.scaleInPlace(projectionNormaliser);
            }
        }
        ++this.iteration;
    }

    /**
     * Returns the logical element at the given position, i.e. the stored value times the multiplier.
     *
     * @param i the row index
     * @param j the column index
     * @return the element value
     */
    @Override
    public double get(int i, int j) {
        return this.values[i][j] * this.multiplier;
    }

    /**
     * Scales every logical element by {@code value} by updating the shared multiplier,
     * and keeps the cached squared two-norm in step with that scaling.
     *
     * @param value the scaling factor
     */
    @Override
    public void scaleInPlace(double value) {
        this.multiplier *= value;
        // Scaling every element by value scales the squared norm by value squared.
        this.squaredTwoNorm *= value * value;
        if (Math.abs(this.multiplier) < TOLERANCE) {
            this.reifyMultiplier();
        }
    }

    /**
     * Folds the multiplier into the stored values and resets it to one.
     * The logical elements, and therefore the cached norm, are unchanged.
     */
    private void reifyMultiplier() {
        for (int i = 0; i < this.dim1; ++i) {
            double[] row = this.values[i];
            for (int j = 0; j < this.dim2; ++j) {
                row[j] *= this.multiplier;
            }
        }
        this.multiplier = 1.0;
    }

    /**
     * Returns the square root of the cached squared two-norm.
     *
     * @return the two-norm
     */
    @Override
    public double twoNorm() {
        return Math.sqrt(this.squaredTwoNorm);
    }

    /**
     * Returns an iterator over every element in row-major order, with the multiplier applied.
     *
     * @return a new iterator
     */
    @Override
    public MatrixIterator iterator() {
        return new ShrinkingMatrixIterator(this);
    }

    /**
     * Row-major iterator over a {@link ShrinkingMatrix}. A single {@link MatrixTuple}
     * instance is reused and overwritten on every call to {@link #next()}.
     */
    private static final class ShrinkingMatrixIterator implements MatrixIterator {
        private final ShrinkingMatrix matrix;
        private final MatrixTuple tuple;
        private int i;
        private int j;

        ShrinkingMatrixIterator(ShrinkingMatrix matrix) {
            this.matrix = matrix;
            this.tuple = new MatrixTuple();
            this.i = 0;
            this.j = 0;
        }

        @Override
        public MatrixTuple getReference() {
            return this.tuple;
        }

        @Override
        public boolean hasNext() {
            return this.i < this.matrix.dim1 && this.j < this.matrix.dim2;
        }

        @Override
        public MatrixTuple next() {
            this.tuple.i = this.i;
            this.tuple.j = this.j;
            this.tuple.value = this.matrix.get(this.i, this.j);
            // Advance along the row, wrapping to the start of the next row at the end.
            if (this.j < this.matrix.dim2 - 1) {
                ++this.j;
            } else {
                ++this.i;
                this.j = 0;
            }
            return this.tuple;
        }
    }
}

