/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.regression;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.oracle.labs.mlrg.olcut.util.MutableDouble;
import com.oracle.labs.mlrg.olcut.util.MutableLong;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;
import java.util.TreeMap;
import org.tribuo.MutableOutputInfo;
import org.tribuo.protos.core.OutputDomainProto;
import org.tribuo.regression.RegressionFactory;
import org.tribuo.regression.RegressionInfo;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.protos.MutableRegressionInfoProto;

public class MutableRegressionInfo
extends RegressionInfo
implements MutableOutputInfo<Regressor> {
    private static final long serialVersionUID = 2L;

    MutableRegressionInfo() {
    }

    public MutableRegressionInfo(RegressionInfo info) {
        super(info);
    }

    private MutableRegressionInfo(Map<String, MutableLong> countMap, Map<String, MutableDouble> maxMap, Map<String, MutableDouble> minMap, Map<String, MutableDouble> meanMap, Map<String, MutableDouble> sumSquaresMap, int unknownCount, long overallCount) {
        super(countMap, maxMap, minMap, meanMap, sumSquaresMap, unknownCount, overallCount);
    }

    public static MutableRegressionInfo deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > 0) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + 0);
        }
        MutableRegressionInfoProto proto = (MutableRegressionInfoProto)message.unpack(MutableRegressionInfoProto.class);
        if (proto.getLabelCount() != proto.getMaxCount() || proto.getLabelCount() != proto.getMinCount() || proto.getLabelCount() != proto.getMeanCount() || proto.getLabelCount() != proto.getSumSquaresCount() || proto.getLabelCount() != proto.getCountCount()) {
            throw new IllegalArgumentException("Invalid protobuf, expected the same number of dimension names, maxes, mins, means, sumSquares and counts, found " + proto.getLabelCount() + " names, " + proto.getMaxCount() + " maxes, " + proto.getMinCount() + " mins, " + proto.getMeanCount() + " means, " + proto.getSumSquaresCount() + " sumSquares, and " + proto.getCountCount() + " counts.");
        }
        LinkedHashMap<String, MutableDouble> maxMap = new LinkedHashMap<String, MutableDouble>();
        LinkedHashMap<String, MutableDouble> minMap = new LinkedHashMap<String, MutableDouble>();
        LinkedHashMap<String, MutableDouble> meanMap = new LinkedHashMap<String, MutableDouble>();
        LinkedHashMap<String, MutableDouble> sumSquaresMap = new LinkedHashMap<String, MutableDouble>();
        TreeMap<String, MutableLong> countMap = new TreeMap<String, MutableLong>();
        for (int i = 0; i < proto.getLabelCount(); ++i) {
            long cnt;
            String lbl = proto.getLabel(i);
            MutableLong old = countMap.put(lbl, new MutableLong(cnt = proto.getCount(i)));
            if (old != null) {
                throw new IllegalArgumentException("Invalid protobuf, two mappings for " + lbl);
            }
            maxMap.put(lbl, new MutableDouble(proto.getMax(i)));
            minMap.put(lbl, new MutableDouble(proto.getMin(i)));
            meanMap.put(lbl, new MutableDouble(proto.getMean(i)));
            sumSquaresMap.put(lbl, new MutableDouble(proto.getSumSquares(i)));
        }
        return new MutableRegressionInfo(countMap, maxMap, minMap, meanMap, sumSquaresMap, proto.getUnknownCount(), proto.getOverallCount());
    }

    public OutputDomainProto serialize() {
        OutputDomainProto.Builder outputBuilder = OutputDomainProto.newBuilder();
        outputBuilder.setClassName(MutableRegressionInfo.class.getName());
        outputBuilder.setVersion(0);
        MutableRegressionInfoProto.Builder data = MutableRegressionInfoProto.newBuilder();
        for (Map.Entry e : this.countMap.entrySet()) {
            data.addLabel((String)e.getKey());
            data.addCount(((MutableLong)e.getValue()).longValue());
            data.addMax(((MutableDouble)this.maxMap.get(e.getKey())).doubleValue());
            data.addMin(((MutableDouble)this.minMap.get(e.getKey())).doubleValue());
            data.addMean(((MutableDouble)this.meanMap.get(e.getKey())).doubleValue());
            data.addSumSquares(((MutableDouble)this.sumSquaresMap.get(e.getKey())).doubleValue());
        }
        data.setUnknownCount(this.unknownCount);
        data.setOverallCount(this.overallCount);
        outputBuilder.setSerializedData(Any.pack((Message)data.build()));
        return outputBuilder.build();
    }

    public void observe(Regressor output) {
        if (output == RegressionFactory.UNKNOWN_REGRESSOR) {
            ++this.unknownCount;
        } else {
            if (this.overallCount != 0L) {
                String[] names = output.getNames();
                if (names.length != this.countMap.size()) {
                    throw new IllegalArgumentException("Expected this Regressor to contain " + this.countMap.size() + " dimensions, found " + names.length);
                }
                for (String name : names) {
                    if (this.countMap.containsKey(name)) continue;
                    throw new IllegalArgumentException("Regressor contains unexpected dimension named '" + name + "'");
                }
            }
            for (Regressor.DimensionTuple r : output) {
                String name = r.getName();
                double value = r.getValue();
                this.minMap.merge(name, new MutableDouble(value), (a, b) -> a.doubleValue() < b.doubleValue() ? a : b);
                this.maxMap.merge(name, new MutableDouble(value), (a, b) -> a.doubleValue() > b.doubleValue() ? a : b);
                MutableLong countValue = this.countMap.computeIfAbsent(name, k -> new MutableLong());
                countValue.increment();
                MutableDouble meanValue = this.meanMap.computeIfAbsent(name, k -> new MutableDouble());
                double delta = value - meanValue.doubleValue();
                meanValue.increment(delta / (double)countValue.longValue());
                double delta2 = value - meanValue.doubleValue();
                MutableDouble sumSquaresValue = this.sumSquaresMap.computeIfAbsent(name, k -> new MutableDouble());
                sumSquaresValue.increment(delta * delta2);
            }
            ++this.overallCount;
        }
    }

    public void clear() {
        this.maxMap.clear();
        this.minMap.clear();
        this.meanMap.clear();
        this.sumSquaresMap.clear();
        this.countMap.clear();
    }

    @Override
    public MutableRegressionInfo copy() {
        return new MutableRegressionInfo(this);
    }

    public String toString() {
        StringBuilder builder = new StringBuilder();
        builder.append("RegressionInfo(");
        for (Map.Entry e : this.countMap.entrySet()) {
            String name = (String)e.getKey();
            long count = ((MutableLong)e.getValue()).longValue();
            builder.append(String.format("{name=%s,count=%d,max=%f,min=%f,mean=%f,variance=%f},", name, count, ((MutableDouble)this.maxMap.get(name)).doubleValue(), ((MutableDouble)this.minMap.get(name)).doubleValue(), ((MutableDouble)this.meanMap.get(name)).doubleValue(), ((MutableDouble)this.sumSquaresMap.get(name)).doubleValue() / (double)(count - 1L)));
        }
        builder.deleteCharAt(builder.length() - 1);
        builder.append(")");
        return builder.toString();
    }

    public String toReadableString() {
        return this.toString();
    }

    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        MutableRegressionInfo that = (MutableRegressionInfo)o;
        if (this.unknownCount == that.unknownCount && this.overallCount == that.overallCount) {
            for (Map.Entry e : this.countMap.entrySet()) {
                MutableLong other = (MutableLong)that.countMap.get(e.getKey());
                if (other == null || other.longValue() != ((MutableLong)e.getValue()).longValue()) {
                    return false;
                }
                if (!MutableRegressionInfo.checkMutableDouble((MutableDouble)this.maxMap.get(e.getKey()), (MutableDouble)that.maxMap.get(e.getKey()))) {
                    return false;
                }
                if (!MutableRegressionInfo.checkMutableDouble((MutableDouble)this.minMap.get(e.getKey()), (MutableDouble)that.minMap.get(e.getKey()))) {
                    return false;
                }
                if (!MutableRegressionInfo.checkMutableDouble((MutableDouble)this.meanMap.get(e.getKey()), (MutableDouble)that.meanMap.get(e.getKey()))) {
                    return false;
                }
                if (MutableRegressionInfo.checkMutableDouble((MutableDouble)this.sumSquaresMap.get(e.getKey()), (MutableDouble)that.sumSquaresMap.get(e.getKey()))) continue;
                return false;
            }
            return true;
        }
        return false;
    }

    public int hashCode() {
        return Objects.hash(this.countMap, this.maxMap, this.minMap, this.meanMap, this.sumSquaresMap, this.unknownCount, this.overallCount);
    }
}

