/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.util.MutableLong;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import org.tribuo.MutableOutputInfo;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.classification.LabelInfo;
import org.tribuo.classification.protos.MutableLabelInfoProto;
import org.tribuo.protos.ProtoSerializable;
import org.tribuo.protos.ProtoSerializableClass;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.protos.core.OutputDomainProto;

/**
 * A mutable {@link LabelInfo} which accumulates per-label observation counts
 * and a count of unknown labels as {@link Label}s are observed.
 * <p>
 * Instances can be rebuilt from a {@link MutableLabelInfoProto} via
 * {@link #deserializeFromProto(int, String, Any)} and written out via
 * {@link #serialize()}.
 */
@ProtoSerializableClass(serializedDataClass=MutableLabelInfoProto.class, version=0)
public class MutableLabelInfo
extends LabelInfo
implements MutableOutputInfo<Label> {
    private static final long serialVersionUID = 1L;

    /**
     * Highest protobuf version understood by {@link #deserializeFromProto}.
     * Must be kept equal to the {@code version} declared in the
     * {@code @ProtoSerializableClass} annotation on this class.
     */
    private static final int MAX_SUPPORTED_VERSION = 0;

    /**
     * Package-private constructor; delegates to the superclass no-argument constructor.
     */
    MutableLabelInfo() {
    }

    /**
     * Constructs a mutable info from another {@link LabelInfo} by delegating
     * to the superclass copy constructor.
     *
     * @param info The info to copy.
     */
    public MutableLabelInfo(LabelInfo info) {
        super(info);
    }

    /**
     * Deserialization constructor; hands the supplied state to the superclass.
     *
     * @param counts The per-label counts.
     * @param unknownCount The number of unknown labels observed.
     */
    private MutableLabelInfo(Map<String, MutableLong> counts, int unknownCount) {
        super(counts, unknownCount);
    }

    /**
     * Deserialization factory.
     *
     * @param version The serialized object version.
     * @param className The class name (not inspected by this method).
     * @param message The serialized data, expected to wrap a {@link MutableLabelInfoProto}.
     * @return The deserialized object.
     * @throws InvalidProtocolBufferException If {@code message} cannot be unpacked as a
     *         {@link MutableLabelInfoProto}.
     * @throws IllegalArgumentException If the version is unsupported, the label and count
     *         lists have different lengths, or a label appears more than once.
     */
    public static MutableLabelInfo deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > MAX_SUPPORTED_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + MAX_SUPPORTED_VERSION);
        }
        MutableLabelInfoProto proto = message.unpack(MutableLabelInfoProto.class);
        // Labels and counts are parallel lists, so they must be the same length.
        if (proto.getLabelCount() != proto.getCountCount()) {
            throw new IllegalArgumentException("Invalid protobuf, different numbers of labels and counts, labels " + proto.getLabelCount() + ", counts " + proto.getCountCount());
        }
        Map<String, MutableLong> counts = new HashMap<>();
        for (int i = 0; i < proto.getLabelCount(); i++) {
            String lbl = proto.getLabel(i);
            long cnt = proto.getCount(i);
            // put returns the previous mapping, which is non-null only for a repeated label.
            MutableLong old = counts.put(lbl, new MutableLong(cnt));
            if (old != null) {
                throw new IllegalArgumentException("Invalid protobuf, two mappings for " + lbl);
            }
        }
        return new MutableLabelInfo(counts, proto.getUnknownCount());
    }

    /**
     * Serializes this object by delegating to {@link ProtoUtil#serialize}.
     *
     * @return The serialized form of this info.
     */
    @Override
    public OutputDomainProto serialize() {
        return ProtoUtil.serialize(this);
    }

    /**
     * Records a single observation of the supplied label.
     * <p>
     * The unknown label is detected by reference identity with
     * {@link LabelFactory#UNKNOWN_LABEL}; any other instance is counted under its
     * label string, and a {@link Label} for that string is created on first sight.
     *
     * @param output The observed label.
     */
    @Override
    public void observe(Label output) {
        if (output == LabelFactory.UNKNOWN_LABEL) {
            unknownCount++;
        } else {
            String label = output.getLabel();
            MutableLong value = labelCounts.computeIfAbsent(label, k -> new MutableLong());
            labels.computeIfAbsent(label, Label::new);
            value.increment();
        }
    }

    /**
     * Resets this info to the unobserved state.
     * <p>
     * All three pieces of state written by {@link #observe(Label)} are reset: the
     * per-label counts, the label objects, and the unknown count. Clearing only
     * the counts would leave stale labels and a stale unknown count behind.
     */
    @Override
    public void clear() {
        labelCounts.clear();
        labels.clear();
        unknownCount = 0;
    }

    /**
     * Copies this info using the copy constructor.
     *
     * @return A new {@code MutableLabelInfo} built from this one.
     */
    @Override
    public MutableLabelInfo copy() {
        return new MutableLabelInfo(this);
    }

    /**
     * Formats the label counts as {@code (label,count)} pairs separated by {@code ", "},
     * in the iteration order of the underlying map. The unknown count is not included.
     *
     * @return The formatted counts, or the empty string if there are none.
     */
    @Override
    public String toReadableString() {
        StringBuilder builder = new StringBuilder();
        for (Map.Entry<String, MutableLong> e : labelCounts.entrySet()) {
            if (builder.length() > 0) {
                builder.append(", ");
            }
            builder.append('(');
            builder.append(e.getKey());
            builder.append(',');
            builder.append(e.getValue().longValue());
            builder.append(')');
        }
        return builder.toString();
    }

    /**
     * Returns the same text as {@link #toReadableString()}.
     */
    @Override
    public String toString() {
        return toReadableString();
    }

    /**
     * Two instances are equal when they are of exactly the same class, have the same
     * unknown count, and map the same labels to the same count values.
     *
     * @param o The object to compare against.
     * @return True if the objects are equal as described above.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        MutableLabelInfo that = (MutableLabelInfo) o;
        if (unknownCount != that.unknownCount || labelCounts.size() != that.labelCounts.size()) {
            return false;
        }
        // Sizes match, so it is enough to check every entry of this map against the other.
        for (Map.Entry<String, MutableLong> e : labelCounts.entrySet()) {
            MutableLong other = that.labelCounts.get(e.getKey());
            if (other == null || other.longValue() != e.getValue().longValue()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Hashes the same (label, count value) pairs and unknown count that
     * {@link #equals(Object)} compares, so equal instances always produce equal
     * hashes without depending on how {@link MutableLong} implements {@code hashCode}.
     * The per-entry hashes are summed so the result is independent of map iteration order.
     *
     * @return The hash code.
     */
    @Override
    public int hashCode() {
        int countsHash = 0;
        for (Map.Entry<String, MutableLong> e : labelCounts.entrySet()) {
            countsHash += Objects.hashCode(e.getKey()) ^ Long.hashCode(e.getValue().longValue());
        }
        return Objects.hash(countsHash, unknownCount);
    }
}

