/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.math.util;

import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import java.util.Arrays;
import java.util.PriorityQueue;
import java.util.logging.Logger;
import org.tribuo.math.la.DenseSparseMatrix;
import org.tribuo.math.la.MatrixIterator;
import org.tribuo.math.la.MatrixTuple;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.protos.MergerProto;
import org.tribuo.math.util.HeapMerger;
import org.tribuo.math.util.Merger;

/**
 * A {@link Merger} which merges its inputs using a heap ({@link PriorityQueue}) of iterators.
 * <p>
 * Matrices are merged row by row, summing the values that land on the same (row, column)
 * position. Sparse vectors are delegated to {@code HeapMerger.merge}.
 */
public class MatrixHeapMerger
implements Merger {
    private static final long serialVersionUID = 1L;
    private static final Logger logger = Logger.getLogger(MatrixHeapMerger.class.getName());

    /**
     * The protobuf serialization version written by {@link #serialize()}.
     */
    public static final int CURRENT_VERSION = 0;

    /**
     * Deserialization factory.
     *
     * @param version   The serialized object version.
     * @param className The class name (unused by this factory).
     * @param message   The serialized data; must carry an empty payload.
     * @return A new MatrixHeapMerger.
     * @throws IllegalArgumentException If the version is unknown or the payload is not empty.
     */
    public static MatrixHeapMerger deserializeFromProto(int version, String className, Any message) {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        // Compare by content: an empty payload is not guaranteed to be the ByteString.EMPTY singleton.
        if (!message.getValue().isEmpty()) {
            throw new IllegalArgumentException("Invalid proto");
        }
        return new MatrixHeapMerger();
    }

    /**
     * Serializes this merger as a class name plus {@link #CURRENT_VERSION}; there is no payload.
     *
     * @return The proto form of this merger.
     */
    public MergerProto serialize() {
        MergerProto.Builder mergerProto = MergerProto.newBuilder();
        mergerProto.setClassName(this.getClass().getName());
        mergerProto.setVersion(CURRENT_VERSION);
        return mergerProto.build();
    }

    /**
     * Merges the supplied matrices into one, summing the values found at the same position.
     * <p>
     * The output dimensions are taken from {@code inputs[0]}. A row with no active elements in
     * any input is emitted as a vector holding an explicit {@code 0.0} at index 0, which is the
     * representation this method has always used for an empty row.
     * <p>
     * NOTE(review): this assumes each input's iterator yields tuples in ascending (row, column)
     * order, and that {@link MatrixIterator} compares by its current tuple, so the queue head
     * always holds the smallest outstanding position. Confirm against MatrixIterator.
     *
     * @param inputs The matrices to merge; presumed to share {@code inputs[0]}'s dimensions.
     * @return The merged matrix.
     * @throws IllegalArgumentException If {@code inputs} is empty.
     */
    @Override
    public DenseSparseMatrix merge(DenseSparseMatrix[] inputs) {
        if (inputs.length == 0) {
            throw new IllegalArgumentException("Cannot merge an empty array of matrices");
        }
        int denseLength = inputs[0].getDimension1Size();
        int sparseLength = inputs[0].getDimension2Size();
        PriorityQueue<MatrixIterator> queue = new PriorityQueue<>();
        // Per-row upper bound on the number of merged entries.
        int[] totalLengths = new int[denseLength];
        for (DenseSparseMatrix input : inputs) {
            for (int row = 0; row < denseLength; ++row) {
                totalLengths[row] += input.numActiveElements(row);
            }
            MatrixIterator cur = input.iterator();
            // An input with no active elements contributes nothing, and the Iterator contract
            // does not allow next() on an exhausted iterator.
            if (cur.hasNext()) {
                cur.next();
                queue.add(cur);
            }
        }
        int maxLength = 0;
        for (int length : totalLengths) {
            maxLength = Math.max(maxLength, length);
        }
        SparseVector[] output = new SparseVector[denseLength];
        int[] curIndices = new int[maxLength];
        double[] curValues = new double[maxLength];
        int denseCounter = 0; // row currently being accumulated
        int rowSize = 0;      // number of entries buffered for that row
        while (!queue.isEmpty()) {
            MatrixIterator cur = queue.peek();
            MatrixTuple ref = cur.getReference();
            // A loop rather than a single check: rows with no elements in any input must each be
            // emitted in place, otherwise later rows' entries would shift into them.
            while (ref.i > denseCounter) {
                output[denseCounter] = buildRow(sparseLength, curIndices, curValues, rowSize);
                rowSize = 0;
                ++denseCounter;
            }
            if (rowSize > 0 && curIndices[rowSize - 1] == ref.j) {
                // Same column as the previous tuple in this row: accumulate.
                curValues[rowSize - 1] += ref.value;
            } else {
                curIndices[rowSize] = ref.j;
                curValues[rowSize] = ref.value;
                ++rowSize;
            }
            // Remove the head (cur), advance it, and re-insert it so its new key is re-ordered.
            queue.poll();
            if (cur.hasNext()) {
                cur.next();
                queue.offer(cur);
            }
        }
        if (denseCounter < denseLength) {
            output[denseCounter] = buildRow(sparseLength, curIndices, curValues, rowSize);
            // Rows after the last active element still need a vector.
            for (int row = denseCounter + 1; row < denseLength; ++row) {
                output[row] = buildRow(sparseLength, curIndices, curValues, 0);
            }
        }
        return DenseSparseMatrix.createFromSparseVectors(output);
    }

    /**
     * Builds one output row from the first {@code count} buffered entries.
     * <p>
     * An empty row is represented as a vector with an explicit {@code 0.0} at index 0, which
     * keeps the output compatible with what merge has always produced for empty rows.
     * NOTE(review): a genuinely empty SparseVector would be cleaner if
     * {@code createSparseVector} accepts zero-length arrays; confirm before changing.
     *
     * @param dimension The vector dimension.
     * @param indices   The index buffer.
     * @param values    The value buffer.
     * @param count     The number of valid leading entries in the buffers.
     * @return A new sparse vector for the row.
     */
    private static SparseVector buildRow(int dimension, int[] indices, double[] values, int count) {
        if (count == 0) {
            return SparseVector.createSparseVector(dimension, new int[]{0}, new double[]{0.0});
        }
        return SparseVector.createSparseVector(dimension, Arrays.copyOf(indices, count), Arrays.copyOf(values, count));
    }

    /**
     * Merges the supplied sparse vectors by delegating to {@code HeapMerger.merge}, using
     * buffers sized to the total number of active elements across the inputs.
     *
     * @param inputs The vectors to merge; the output size is taken from {@code inputs[0]}.
     * @return The merged vector.
     * @throws IllegalArgumentException If {@code inputs} is empty.
     */
    @Override
    public SparseVector merge(SparseVector[] inputs) {
        if (inputs.length == 0) {
            throw new IllegalArgumentException("Cannot merge an empty array of vectors");
        }
        int maxLength = 0;
        for (SparseVector input : inputs) {
            maxLength += input.numActiveElements();
        }
        return HeapMerger.merge(Arrays.asList(inputs), inputs[0].size(), new int[maxLength], new double[maxLength]);
    }

    /**
     * All instances are interchangeable (the class has no state), so equality is by class.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        return o != null && this.getClass() == o.getClass();
    }

    @Override
    public int hashCode() {
        return 31;
    }
}

