/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.anomaly.evaluation;

import java.util.List;
import java.util.function.ToDoubleBiFunction;
import org.tribuo.Example;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.anomaly.Event;
import org.tribuo.evaluation.metrics.EvaluationMetric;
import org.tribuo.evaluation.metrics.MetricContext;
import org.tribuo.evaluation.metrics.MetricTarget;

/**
 * A named metric for anomaly detection evaluations.
 * <p>
 * Each instance couples a {@link MetricTarget}, a display name and the function
 * that turns a {@link Context} into the metric's numeric value.
 */
public class AnomalyMetric implements EvaluationMetric<Event, AnomalyMetric.Context> {
    private final MetricTarget<Event> target;
    private final String name;
    private final ToDoubleBiFunction<MetricTarget<Event>, Context> impl;

    /**
     * Creates an anomaly metric.
     *
     * @param target the target this metric is computed for
     * @param name   the name reported by {@link #getName()}
     * @param impl   the function invoked by {@link #compute(Context)}; it receives the
     *               target and the context and returns the metric value
     */
    public AnomalyMetric(MetricTarget<Event> target, String name, ToDoubleBiFunction<MetricTarget<Event>, Context> impl) {
        this.target = target;
        this.name = name;
        this.impl = impl;
    }

    /**
     * Computes the metric value by applying the implementation function to this
     * metric's target and the supplied context.
     *
     * @param context the context holding the tabulated prediction counts
     * @return the value produced by the implementation function
     */
    public double compute(Context context) {
        return this.impl.applyAsDouble(this.target, context);
    }

    /**
     * Returns the target supplied at construction.
     *
     * @return the metric target
     */
    public MetricTarget<Event> getTarget() {
        return this.target;
    }

    /**
     * Returns the name supplied at construction.
     *
     * @return the metric name
     */
    public String getName() {
        return this.name;
    }

    /**
     * Builds a fresh {@link Context} for the given model and predictions.
     *
     * @param model       the model that produced the predictions
     * @param predictions the predictions to tabulate
     * @return a new context containing the confusion counts
     * @throws IllegalArgumentException if a ground truth output is neither
     *                                  {@code ANOMALOUS} nor {@code EXPECTED}
     */
    public Context createContext(Model<Event> model, List<Prediction<Event>> predictions) {
        return AnomalyMetric.buildContext(model, predictions);
    }

    /**
     * Static counterpart of {@link #createContext(Model, List)}.
     *
     * @param model       the model that produced the predictions
     * @param predictions the predictions to tabulate
     * @return a new context containing the confusion counts
     * @throws IllegalArgumentException if a ground truth output is neither
     *                                  {@code ANOMALOUS} nor {@code EXPECTED}
     */
    static Context buildContext(Model<Event> model, List<Prediction<Event>> predictions) {
        return new Context(model, predictions);
    }

    /**
     * Metric context holding the four confusion counts for a list of predictions,
     * with {@code ANOMALOUS} treated as the positive class.
     * <p>
     * The counts are tabulated once, in the constructor, and are immutable afterwards.
     */
    static final class Context extends MetricContext<Event> {
        private final long truePositive;
        private final long falsePositive;
        private final long trueNegative;
        private final long falseNegative;

        /**
         * Tabulates the predictions and stores the resulting counts.
         *
         * @param model       the model that produced the predictions
         * @param predictions the predictions to tabulate
         * @throws IllegalArgumentException if a ground truth output is neither
         *                                  {@code ANOMALOUS} nor {@code EXPECTED}
         */
        Context(Model<Event> model, List<Prediction<Event>> predictions) {
            super(model, predictions);
            PredictionStatistics tab = Context.tabulate(predictions);
            this.truePositive = tab.truePositive;
            this.falsePositive = tab.falsePositive;
            this.trueNegative = tab.trueNegative;
            this.falseNegative = tab.falseNegative;
        }

        /**
         * @return the number of examples with truth {@code ANOMALOUS} predicted {@code ANOMALOUS}
         */
        long getTruePositive() {
            return this.truePositive;
        }

        /**
         * @return the number of examples with truth {@code EXPECTED} predicted {@code ANOMALOUS}
         */
        long getFalsePositive() {
            return this.falsePositive;
        }

        /**
         * @return the number of examples with truth {@code EXPECTED} predicted {@code EXPECTED}
         */
        long getTrueNegative() {
            return this.trueNegative;
        }

        /**
         * @return the number of examples with truth {@code ANOMALOUS} predicted {@code EXPECTED}
         */
        long getFalseNegative() {
            return this.falseNegative;
        }

        /**
         * Counts how often each (truth, predicted) combination occurs.
         * <p>
         * A prediction whose predicted type is neither {@code ANOMALOUS} nor
         * {@code EXPECTED} is not added to any count. A ground truth that is neither of
         * those two types aborts tabulation with an exception.
         *
         * @param predictions the predictions to tabulate
         * @return the four confusion counts
         * @throws IllegalArgumentException if a ground truth output is neither
         *                                  {@code ANOMALOUS} nor {@code EXPECTED}
         */
        private static PredictionStatistics tabulate(List<Prediction<Event>> predictions) {
            long truePositive = 0L;
            long falsePositive = 0L;
            long trueNegative = 0L;
            long falseNegative = 0L;
            for (Prediction<Event> prediction : predictions) {
                Example<Event> example = prediction.getExample();
                Event.EventType truth = example.getOutput().getType();
                Event.EventType predicted = prediction.getOutput().getType();
                // An if/else ladder (rather than a switch) keeps a null truth on the
                // IllegalArgumentException path instead of raising a NullPointerException.
                if (truth == Event.EventType.ANOMALOUS) {
                    if (predicted == Event.EventType.ANOMALOUS) {
                        truePositive++;
                    } else if (predicted == Event.EventType.EXPECTED) {
                        falseNegative++;
                    }
                    // Any other predicted type is intentionally left uncounted.
                } else if (truth == Event.EventType.EXPECTED) {
                    if (predicted == Event.EventType.ANOMALOUS) {
                        falsePositive++;
                    } else if (predicted == Event.EventType.EXPECTED) {
                        trueNegative++;
                    }
                    // Any other predicted type is intentionally left uncounted.
                } else {
                    throw new IllegalArgumentException("Evaluation data contained EventType.UNKNOWN as the ground truth output.");
                }
            }
            return new PredictionStatistics(truePositive, falsePositive, trueNegative, falseNegative);
        }
    }

    /**
     * Immutable carrier for the four confusion counts produced by
     * {@code Context.tabulate}, used to hand them to the {@link Context} constructor.
     */
    private static final class PredictionStatistics {
        private final long truePositive;
        private final long falsePositive;
        private final long trueNegative;
        private final long falseNegative;

        /**
         * Stores the supplied counts.
         *
         * @param truePositive  truth {@code ANOMALOUS}, predicted {@code ANOMALOUS}
         * @param falsePositive truth {@code EXPECTED}, predicted {@code ANOMALOUS}
         * @param trueNegative  truth {@code EXPECTED}, predicted {@code EXPECTED}
         * @param falseNegative truth {@code ANOMALOUS}, predicted {@code EXPECTED}
         */
        PredictionStatistics(long truePositive, long falsePositive, long trueNegative, long falseNegative) {
            this.truePositive = truePositive;
            this.falsePositive = falsePositive;
            this.trueNegative = trueNegative;
            this.falseNegative = falseNegative;
        }
    }
}

