/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.clustering;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.util.MutableLong;
import com.oracle.labs.mlrg.olcut.util.MutableNumber;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.ClusteringInfo;
import org.tribuo.clustering.protos.ClusteringInfoProto;
import org.tribuo.protos.ProtoSerializableClass;

@ProtoSerializableClass(serializedDataClass=ClusteringInfoProto.class, version=0)
public class ImmutableClusteringInfo
extends ClusteringInfo
implements ImmutableOutputInfo<ClusterID> {
    private static final long serialVersionUID = 1L;

    /**
     * The cluster outputs whose ids are keys of {@code clusterCounts}, captured at
     * construction time and exposed as an unmodifiable set.
     */
    private final Set<ClusterID> domain;

    /**
     * Builds an immutable clustering info from a map of cluster id to observation count.
     * <p>
     * The map is copied via {@link MutableNumber#copyMap} before being stored.
     * This constructor does not set an unknown count itself; it relies on the no-arg
     * {@link ClusteringInfo} constructor (presumably zero — confirm in ClusteringInfo).
     * @param counts The per-cluster observation counts, keyed by cluster id.
     */
    public ImmutableClusteringInfo(Map<Integer, MutableLong> counts) {
        this.clusterCounts.putAll(MutableNumber.copyMap(counts));
        this.domain = buildDomain(this.clusterCounts.keySet());
    }

    /**
     * Builds an immutable view of the supplied clustering info.
     * @param other The info to copy (state copying is delegated to {@link ClusteringInfo}).
     */
    public ImmutableClusteringInfo(ClusteringInfo other) {
        super(other);
        this.domain = buildDomain(this.clusterCounts.keySet());
    }

    /**
     * Builds an immutable info from counts plus an unknown-output count; used by
     * {@link #deserializeFromProto}.
     * @param counts The per-cluster observation counts, keyed by cluster id.
     * @param unknownCount The number of unknown outputs observed.
     */
    private ImmutableClusteringInfo(Map<Integer, MutableLong> counts, int unknownCount) {
        super(counts, unknownCount);
        this.domain = buildDomain(this.clusterCounts.keySet());
    }

    /**
     * Creates the unmodifiable domain set, one {@link ClusterID} per supplied id.
     * Shared by every constructor so the domain is always derived the same way.
     * @param ids The cluster ids to wrap.
     * @return An unmodifiable set of cluster outputs.
     */
    private static Set<ClusterID> buildDomain(Set<Integer> ids) {
        Set<ClusterID> outputs = new HashSet<>();
        for (Integer id : ids) {
            outputs.add(new ClusterID(id));
        }
        return Collections.unmodifiableSet(outputs);
    }

    /**
     * Deserialization factory for a {@link ClusteringInfoProto} wrapped in an {@link Any}.
     * @param version The serialized version; only version 0 is accepted.
     * @param className The class name (not used by this method).
     * @param message The serialized data.
     * @return The deserialized clustering info.
     * @throws InvalidProtocolBufferException If {@code message} does not contain a {@link ClusteringInfoProto}.
     * @throws IllegalArgumentException If the version is unknown, the id and count lists differ
     *                                  in length, or a cluster id appears more than once.
     */
    public static ImmutableClusteringInfo deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version != 0) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version 0");
        }
        ClusteringInfoProto proto = message.unpack(ClusteringInfoProto.class);
        int numIds = proto.getIdCount();
        if (numIds != proto.getCountCount()) {
            throw new IllegalArgumentException("Invalid protobuf, different numbers of ids and counts, labels " + numIds + ", counts " + proto.getCountCount());
        }
        Map<Integer, MutableLong> counts = new HashMap<>();
        for (int i = 0; i < numIds; i++) {
            int id = proto.getId(i);
            MutableLong previous = counts.put(id, new MutableLong(proto.getCount(i)));
            // A repeated id would silently overwrite an earlier count, so reject it.
            if (previous != null) {
                throw new IllegalArgumentException("Invalid protobuf, two mappings for " + id);
            }
        }
        return new ImmutableClusteringInfo(counts, proto.getUnknownCount());
    }

    /**
     * Returns the unmodifiable set of cluster outputs known to this info.
     * @return The domain.
     */
    @Override
    public Set<ClusterID> getDomain() {
        return this.domain;
    }

    /**
     * Returns the id of the supplied output, which is simply its cluster id.
     * The output is not checked against the domain.
     * @param output The cluster output.
     * @return {@code output.getID()}.
     */
    @Override
    public int getID(ClusterID output) {
        return output.getID();
    }

    /**
     * Creates a {@link ClusterID} for the supplied id. The id is not checked against
     * the domain, so any int produces an output.
     * @param id The cluster id.
     * @return A new cluster output with that id.
     */
    @Override
    public ClusterID getOutput(int id) {
        return new ClusterID(id);
    }

    /**
     * Sums the per-cluster counts. Only {@code clusterCounts} is summed; the unknown
     * count passed to the constructor is not added here.
     * @return The total number of observations across all known clusters.
     */
    @Override
    public long getTotalObservations() {
        long total = 0L;
        for (MutableLong count : this.clusterCounts.values()) {
            total += count.longValue();
        }
        return total;
    }

    /**
     * Returns a new immutable copy of this info.
     * @return A copy.
     */
    @Override
    public ClusteringInfo copy() {
        return new ImmutableClusteringInfo(this);
    }

    /**
     * Iterates over (id, output) pairs for every cluster id in {@code clusterCounts}.
     * @return An iterator of id/output pairs.
     */
    @Override
    public Iterator<Pair<Integer, ClusterID>> iterator() {
        return new ImmutableInfoIterator(this.clusterCounts.keySet());
    }

    /**
     * Compares domains. Because {@link #getID} returns the cluster id itself, equal
     * domains imply equal id mappings.
     * @param other The info to compare against.
     * @return True if the domains are equal.
     */
    @Override
    public boolean domainAndIDEquals(ImmutableOutputInfo<ClusterID> other) {
        return this.getDomain().equals(other.getDomain());
    }

    /**
     * Iterator producing a (id, ClusterID) pair for each id in a set.
     */
    private static final class ImmutableInfoIterator
    implements Iterator<Pair<Integer, ClusterID>> {
        private final Iterator<Integer> itr;

        ImmutableInfoIterator(Set<Integer> ids) {
            this.itr = ids.iterator();
        }

        @Override
        public boolean hasNext() {
            return this.itr.hasNext();
        }

        @Override
        public Pair<Integer, ClusterID> next() {
            int id = this.itr.next();
            return new Pair<>(id, new ClusterID(id));
        }
    }
}

