/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification;

import com.oracle.labs.mlrg.olcut.util.MutableLong;
import com.oracle.labs.mlrg.olcut.util.MutableNumber;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.MutableOutputInfo;
import org.tribuo.OutputInfo;
import org.tribuo.classification.ImmutableLabelInfo;
import org.tribuo.classification.Label;
import org.tribuo.classification.MutableLabelInfo;
import org.tribuo.protos.ProtoSerializableField;
import org.tribuo.protos.ProtoSerializableKeysValuesField;

/**
 * Base class for per-label occurrence statistics.
 * <p>
 * Holds a map from label name to the number of times that label has been
 * counted, a separate counter for unknown labels, and a cache of the
 * {@link Label} object for each known label name. The cache is
 * {@code transient} and is rebuilt from the count map in
 * {@link #readObject(ObjectInputStream)}.
 */
public abstract class LabelInfo
implements OutputInfo<Label> {
    private static final long serialVersionUID = 1L;

    /**
     * The occurrence count for each label name.
     */
    @ProtoSerializableKeysValuesField(keysName="label", valuesName="count")
    protected final Map<String, MutableLong> labelCounts;

    /**
     * The number of unknown labels counted, zero unless set by a constructor or subclass.
     */
    @ProtoSerializableField
    protected int unknownCount = 0;

    /**
     * Cache of {@link Label} instances keyed by label name; not serialized.
     */
    protected transient Map<String, Label> labels;

    /**
     * Creates an empty label info with no counts and no cached labels.
     */
    LabelInfo() {
        this.labelCounts = new HashMap<>();
        this.labels = new HashMap<>();
    }

    /**
     * Copies the counts and the label cache from another label info.
     * <p>
     * The counts are copied via {@code MutableNumber.copyMap}; the label cache
     * is a new map holding the same {@link Label} references. Note that
     * {@code unknownCount} is not copied here, so the new instance keeps the
     * field's initial value of zero.
     *
     * @param other The label info to copy.
     */
    LabelInfo(LabelInfo other) {
        this.labelCounts = MutableNumber.copyMap(other.labelCounts);
        this.labels = new HashMap<>();
        this.labels.putAll(other.labels);
    }

    /**
     * Creates a label info from the supplied counts, validating them first.
     * <p>
     * Each count is copied, and a fresh {@link Label} is created for each key.
     *
     * @param counts The count for each label name, each count must be at least one.
     * @param unknownCount The number of unknown labels, must be non-negative.
     * @throws IllegalArgumentException If {@code unknownCount} is negative or
     *         any entry in {@code counts} is less than one.
     */
    LabelInfo(Map<String, MutableLong> counts, int unknownCount) {
        if (unknownCount < 0) {
            throw new IllegalArgumentException("Unknown count must be non-negative, found " + unknownCount);
        }
        this.unknownCount = unknownCount;
        this.labelCounts = new HashMap<>();
        this.labels = new HashMap<>();
        for (Map.Entry<String, MutableLong> e : counts.entrySet()) {
            String name = e.getKey();
            long count = e.getValue().longValue();
            if (count < 1L) {
                throw new IllegalArgumentException("Count for " + name + " must be positive but found " + count);
            }
            this.labelCounts.put(name, e.getValue().copy());
            this.labels.put(name, new Label(name));
        }
    }

    /**
     * Returns the number of unknown labels recorded.
     *
     * @return The unknown label count.
     */
    public int getUnknownCount() {
        return unknownCount;
    }

    /**
     * Returns the cached {@link Label} objects as a newly allocated set.
     *
     * @return A new set containing the values of the label cache.
     */
    public Set<Label> getDomain() {
        return new HashSet<>(labels.values());
    }

    /**
     * Returns the recorded count for the supplied label, looked up by its name.
     *
     * @param label The label to look up.
     * @return The count for that label's name, or zero if it has no entry.
     */
    public long getLabelCount(Label label) {
        MutableLong count = labelCounts.get(label.getLabel());
        return count != null ? count.longValue() : 0L;
    }

    /**
     * Returns the recorded count for the supplied label name.
     *
     * @param label The label name to look up.
     * @return The count for that name, or zero if it has no entry.
     */
    public long getLabelCount(String label) {
        MutableLong count = labelCounts.get(label);
        return count != null ? count.longValue() : 0L;
    }

    /**
     * Returns an iterable over (label name, count) pairs.
     * <p>
     * Each call to {@code iterator()} walks the current entries of the count
     * map; the count in each pair is the value read when {@code next()} is called.
     *
     * @return An iterable of label name and count pairs.
     */
    public Iterable<Pair<String, Long>> outputCountsIterable() {
        return () -> new Iterator<Pair<String, Long>>() {
            private final Iterator<Map.Entry<String, MutableLong>> itr = labelCounts.entrySet().iterator();

            @Override
            public boolean hasNext() {
                return itr.hasNext();
            }

            @Override
            public Pair<String, Long> next() {
                Map.Entry<String, MutableLong> e = itr.next();
                // Parameterized construction; the raw Pair here previously caused an unchecked conversion.
                return new Pair<>(e.getKey(), e.getValue().longValue());
            }
        };
    }

    /**
     * Returns the number of entries in the count map.
     *
     * @return The number of label names with a recorded count.
     */
    public int size() {
        return labelCounts.size();
    }

    /**
     * Creates an {@link ImmutableLabelInfo} constructed from this object.
     *
     * @return A new {@link ImmutableLabelInfo}.
     */
    public ImmutableOutputInfo<Label> generateImmutableOutputInfo() {
        return new ImmutableLabelInfo(this);
    }

    /**
     * Creates a {@link MutableLabelInfo} constructed from this object.
     *
     * @return A new {@link MutableLabelInfo}.
     */
    public MutableOutputInfo<Label> generateMutableOutputInfo() {
        return new MutableLabelInfo(this);
    }

    /**
     * Copies this label info.
     *
     * @return A copy, as defined by the implementing subclass.
     */
    public abstract LabelInfo copy();

    /**
     * Java deserialization hook.
     * <p>
     * After the default field read, rebuilds the transient label cache with a
     * new {@link Label} for every key of the count map.
     *
     * @param in The stream to read from.
     * @throws IOException If the default read fails with an I/O error.
     * @throws ClassNotFoundException If the default read cannot resolve a class.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        this.labels = new HashMap<>();
        for (String name : labelCounts.keySet()) {
            labels.put(name, new Label(name));
        }
    }
}

