/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.clustering;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.util.MutableLong;
import java.util.HashMap;
import java.util.Map;
import org.tribuo.MutableOutputInfo;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.ClusteringFactory;
import org.tribuo.clustering.ClusteringInfo;
import org.tribuo.clustering.protos.ClusteringInfoProto;
import org.tribuo.protos.ProtoSerializableClass;

@ProtoSerializableClass(serializedDataClass=ClusteringInfoProto.class, version=0)
public class MutableClusteringInfo
extends ClusteringInfo
implements MutableOutputInfo<ClusterID> {
    private static final long serialVersionUID = 1L;

    MutableClusteringInfo() {
    }

    MutableClusteringInfo(ClusteringInfo info) {
        super(info);
    }

    private MutableClusteringInfo(Map<Integer, MutableLong> counts, int unknownCount) {
        super(counts, unknownCount);
    }

    public static MutableClusteringInfo deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > 0) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + 0);
        }
        ClusteringInfoProto proto = (ClusteringInfoProto)message.unpack(ClusteringInfoProto.class);
        if (proto.getIdCount() != proto.getCountCount()) {
            throw new IllegalArgumentException("Invalid protobuf, different numbers of ids and counts, labels " + proto.getIdCount() + ", counts " + proto.getCountCount());
        }
        HashMap<Integer, MutableLong> labelCounts = new HashMap<Integer, MutableLong>();
        for (int i = 0; i < proto.getIdCount(); ++i) {
            long cnt;
            Integer lbl = proto.getId(i);
            MutableLong old = labelCounts.put(lbl, new MutableLong(cnt = proto.getCount(i)));
            if (old == null) continue;
            throw new IllegalArgumentException("Invalid protobuf, two mappings for " + lbl);
        }
        return new MutableClusteringInfo(labelCounts, proto.getUnknownCount());
    }

    public void observe(ClusterID output) {
        if (output == ClusteringFactory.UNASSIGNED_CLUSTER_ID) {
            ++this.unknownCount;
        } else {
            int id = output.getID();
            MutableLong value = this.clusterCounts.computeIfAbsent(id, k -> new MutableLong());
            value.increment();
        }
    }

    public void clear() {
        this.clusterCounts.clear();
    }

    @Override
    public MutableClusteringInfo copy() {
        return new MutableClusteringInfo(this);
    }
}

