/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.transform.transformations;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.DoubleProvenance;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.logging.Logger;
import org.tribuo.protos.ProtoSerializableClass;
import org.tribuo.protos.ProtoSerializableField;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.protos.core.MeanStdDevTransformerProto;
import org.tribuo.protos.core.TransformerProto;
import org.tribuo.transform.TransformStatistics;
import org.tribuo.transform.Transformation;
import org.tribuo.transform.TransformationProvenance;
import org.tribuo.transform.Transformer;

/**
 * A {@link Transformation} which rescales values so that, after transformation, they have
 * the configured target mean and target standard deviation.
 * <p>
 * The observed mean and sample standard deviation are accumulated by
 * {@link MeanStdDevStatistics}. The generated {@link Transformer} then applies
 * {@code (x - observedMean) / observedStdDev * targetStdDev + targetMean}.
 */
public final class MeanStdDevTransformation
implements Transformation {
    private static final String TARGET_MEAN = "targetMean";
    private static final String TARGET_STDDEV = "targetStdDev";

    @Config(mandatory=true, description="Mean value after transformation.")
    private double targetMean = 0.0;

    @Config(mandatory=true, description="Standard deviation after transformation.")
    private double targetStdDev = 1.0;

    // Built lazily on the first call to getProvenance().
    private MeanStdDevTransformationProvenance provenance;

    /**
     * Constructs a transformation with the default target mean (0.0) and target standard
     * deviation (1.0). The {@code @Config} fields may be overwritten by the configuration
     * system before {@link #postConfig()} is invoked.
     */
    public MeanStdDevTransformation() {
    }

    /**
     * Constructs a transformation with the supplied targets.
     *
     * @param targetMean   The mean of the transformed values.
     * @param targetStdDev The standard deviation of the transformed values.
     * @throws IllegalArgumentException if {@code targetStdDev} is not finite or is below 1e-12.
     */
    public MeanStdDevTransformation(double targetMean, double targetStdDev) {
        this.targetMean = targetMean;
        this.targetStdDev = targetStdDev;
        this.postConfig();
    }

    /**
     * Validates the configured parameters.
     *
     * @throws IllegalArgumentException if the target standard deviation is not finite or is below 1e-12.
     */
    public void postConfig() {
        // A bare "< 1e-12" check accepts NaN and +Infinity, both of which make every
        // transformed value non-finite, so reject them explicitly.
        if (!Double.isFinite(this.targetStdDev) || this.targetStdDev < 1.0E-12) {
            throw new IllegalArgumentException("Target standard deviation must be positive and finite, found " + this.targetStdDev);
        }
    }

    @Override
    public TransformStatistics createStats() {
        return new MeanStdDevStatistics(this.targetMean, this.targetStdDev);
    }

    /**
     * Returns the provenance of this transformation, creating and caching it on first use.
     *
     * @return The provenance.
     */
    public TransformationProvenance getProvenance() {
        if (this.provenance == null) {
            this.provenance = new MeanStdDevTransformationProvenance(this);
        }
        return this.provenance;
    }

    @Override
    public String toString() {
        return "MeanStdDevTransformation(targetMean=" + this.targetMean + ",targetStdDev=" + this.targetStdDev + ")";
    }

    /**
     * Accumulates the running mean and the sum of squared deviations from the mean of the
     * observed values, using Welford's online update.
     */
    private static class MeanStdDevStatistics
    implements TransformStatistics {
        private static final Logger logger = Logger.getLogger(MeanStdDevStatistics.class.getName());
        private final double targetMean;
        private final double targetStdDev;
        // Running mean of all values observed so far (explicit and implicit zeros).
        private double mean = 0.0;
        // Running sum of squared deviations from the mean (Welford's M2).
        private double sumSquares = 0.0;
        private long count = 0L;

        public MeanStdDevStatistics(double targetMean, double targetStdDev) {
            this.targetMean = targetMean;
            this.targetStdDev = targetStdDev;
        }

        @Override
        public void observeValue(double value) {
            ++this.count;
            double delta = value - this.mean;
            this.mean += delta / (double)this.count;
            // Deviation from the *updated* mean; delta * delta2 is the exact M2 increment.
            double delta2 = value - this.mean;
            this.sumSquares += delta * delta2;
        }

        @Override
        @Deprecated
        public void observeSparse() {
            this.observeValue(0.0);
        }

        /**
         * Observes {@code sparseCount} implicit zero values in a single step.
         * <p>
         * In exact arithmetic this is equivalent to calling {@code observeValue(0.0)}
         * {@code sparseCount} times.
         *
         * @param sparseCount The number of zeros to observe.
         * @throws IllegalArgumentException if {@code sparseCount} is negative.
         */
        @Override
        public void observeSparse(int sparseCount) {
            if (sparseCount < 0) {
                throw new IllegalArgumentException("sparseCount must be non-negative, found " + sparseCount);
            }
            if (sparseCount == 0) {
                // Nothing to merge. This also avoids a 0/0 below when nothing has been observed yet.
                return;
            }
            this.count += (long)sparseCount;
            // Merge a batch of sparseCount zeros (batch mean 0, no internal spread).
            // The mean moves towards zero by the batch's share of the new total count.
            double delta = -this.mean;
            this.mean += delta * ((double)sparseCount / (double)this.count);
            double delta2 = -this.mean;
            // sparseCount * delta * delta2 == delta^2 * oldCount * sparseCount / newCount,
            // which is the M2 increment for merging the two batches.
            this.sumSquares += (double)sparseCount * (delta * delta2);
        }

        /**
         * Builds a transformer from the observed mean and sample (n - 1) standard deviation.
         * <p>
         * With a single observation the standard deviation is 0/0 (NaN). When all observed
         * values are identical it is 0. In both cases the resulting transformer produces
         * non-finite outputs.
         *
         * @return The transformer.
         */
        @Override
        public Transformer generateTransformer() {
            if (this.sumSquares == 0.0) {
                logger.info("Only observed a single value (" + this.mean + ") when building a MeanStdDevTransformation");
            }
            return new MeanStdDevTransformer(this.mean, Math.sqrt(this.sumSquares / (double)(this.count - 1L)), this.targetMean, this.targetStdDev);
        }

        @Override
        public String toString() {
            return "MeanStdDevStatistics(mean=" + this.mean + ",sumSquares=" + this.sumSquares + ",count=" + this.count + ",targetMean=" + this.targetMean + ",targetStdDev=" + this.targetStdDev + ")";
        }
    }

    /**
     * Provenance for {@link MeanStdDevTransformation}, recording the target mean and target
     * standard deviation.
     */
    public static final class MeanStdDevTransformationProvenance
    implements TransformationProvenance {
        private static final long serialVersionUID = 1L;
        private final DoubleProvenance targetMean;
        private final DoubleProvenance targetStdDev;

        MeanStdDevTransformationProvenance(MeanStdDevTransformation host) {
            this.targetMean = new DoubleProvenance(TARGET_MEAN, host.targetMean);
            this.targetStdDev = new DoubleProvenance(TARGET_STDDEV, host.targetStdDev);
        }

        /**
         * Reconstructs the provenance from its map form.
         *
         * @param map The provenance map, which must contain the target mean and target std dev entries.
         */
        public MeanStdDevTransformationProvenance(Map<String, Provenance> map) {
            String provClassName = MeanStdDevTransformationProvenance.class.getSimpleName();
            this.targetMean = (DoubleProvenance)ObjectProvenance.checkAndExtractProvenance(map, TARGET_MEAN, DoubleProvenance.class, provClassName);
            this.targetStdDev = (DoubleProvenance)ObjectProvenance.checkAndExtractProvenance(map, TARGET_STDDEV, DoubleProvenance.class, provClassName);
        }

        public String getClassName() {
            return MeanStdDevTransformation.class.getName();
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof MeanStdDevTransformationProvenance)) {
                return false;
            }
            MeanStdDevTransformationProvenance other = (MeanStdDevTransformationProvenance)o;
            return this.targetMean.equals(other.targetMean) && this.targetStdDev.equals(other.targetStdDev);
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.targetMean, this.targetStdDev);
        }

        public Map<String, Provenance> getConfiguredParameters() {
            Map<String, Provenance> map = new HashMap<>();
            map.put(TARGET_MEAN, this.targetMean);
            map.put(TARGET_STDDEV, this.targetStdDev);
            return Collections.unmodifiableMap(map);
        }
    }

    /**
     * Applies {@code (x - observedMean) / observedStdDev * targetStdDev + targetMean}.
     * <p>
     * Field names are part of the protobuf serialization (via {@code @ProtoSerializableField})
     * and must not be renamed.
     */
    @ProtoSerializableClass(version=0, serializedDataClass=MeanStdDevTransformerProto.class)
    static final class MeanStdDevTransformer
    implements Transformer {
        private static final long serialVersionUID = 1L;
        public static final int CURRENT_VERSION = 0;
        @ProtoSerializableField
        private final double observedMean;
        @ProtoSerializableField
        private final double observedStdDev;
        @ProtoSerializableField
        private final double targetMean;
        @ProtoSerializableField
        private final double targetStdDev;

        /**
         * @throws IllegalArgumentException if either standard deviation is negative.
         */
        MeanStdDevTransformer(double observedMean, double observedStdDev, double targetMean, double targetStdDev) {
            if (observedStdDev < 0.0 || targetStdDev < 0.0) {
                throw new IllegalArgumentException("Standard deviations must be non-negative.");
            }
            this.observedMean = observedMean;
            this.observedStdDev = observedStdDev;
            this.targetMean = targetMean;
            this.targetStdDev = targetStdDev;
        }

        /**
         * Deserializes a transformer from its protobuf form.
         *
         * @param version   The serialized version; only {@code 0} is supported.
         * @param className The class name (unused).
         * @param message   The packed {@code MeanStdDevTransformerProto}.
         * @return The deserialized transformer.
         * @throws InvalidProtocolBufferException if the message is not a {@code MeanStdDevTransformerProto}.
         * @throws IllegalArgumentException       if the version is unknown.
         */
        static MeanStdDevTransformer deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
            MeanStdDevTransformerProto proto = message.unpack(MeanStdDevTransformerProto.class);
            if (version == 0) {
                return new MeanStdDevTransformer(proto.getObservedMean(), proto.getObservedStdDev(), proto.getTargetMean(), proto.getTargetStdDev());
            }
            throw new IllegalArgumentException("Unknown version " + version + " expected {0}");
        }

        @Override
        public TransformerProto serialize() {
            return (TransformerProto)ProtoUtil.serialize(this);
        }

        /**
         * Standardises the input and rescales it to the target distribution. When
         * {@code observedStdDev} is 0 or NaN the result is non-finite.
         */
        @Override
        public double transform(double input) {
            return (input - this.observedMean) / this.observedStdDev * this.targetStdDev + this.targetMean;
        }

        @Override
        public String toString() {
            return "MeanStdDevTransformer(observedMean=" + this.observedMean + ",observedStdDev=" + this.observedStdDev + ",targetMean=" + this.targetMean + ",targetStdDev=" + this.targetStdDev + ")";
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            MeanStdDevTransformer that = (MeanStdDevTransformer)o;
            return Double.compare(that.observedMean, this.observedMean) == 0
                    && Double.compare(that.observedStdDev, this.observedStdDev) == 0
                    && Double.compare(that.targetMean, this.targetMean) == 0
                    && Double.compare(that.targetStdDev, this.targetStdDev) == 0;
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.observedMean, this.observedStdDev, this.targetMean, this.targetStdDev);
        }
    }
}

