/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.evaluation;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.function.ToDoubleFunction;
import java.util.stream.Collectors;
import org.tribuo.Dataset;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.evaluation.DescriptiveStats;
import org.tribuo.evaluation.Evaluation;
import org.tribuo.evaluation.Evaluator;
import org.tribuo.evaluation.metrics.EvaluationMetric;
import org.tribuo.evaluation.metrics.MetricContext;
import org.tribuo.evaluation.metrics.MetricID;
import org.tribuo.util.Util;

/**
 * Static helpers that aggregate metric values computed across several models, datasets,
 * metrics or pre-computed {@link Evaluation}s into {@link DescriptiveStats} summaries, and
 * that locate the best-scoring entry via {@link Util#argmax(List)}.
 */
public final class EvaluationAggregator {

    /** Utility class: every member is static, so instantiation is disallowed. */
    private EvaluationAggregator() {
    }

    /**
     * Computes {@code metric} for each model on the same dataset and summarizes the values.
     *
     * @param metric  the metric to compute
     * @param models  the models to evaluate
     * @param dataset the dataset every model is evaluated on
     * @param <T>     the output type
     * @param <C>     the metric context type
     * @return summary statistics with one value per model (no values are added if
     *         {@code models} is empty)
     */
    public static <T extends Output<T>, C extends MetricContext<T>> DescriptiveStats summarize(EvaluationMetric<T, C> metric, List<? extends Model<T>> models, Dataset<T> dataset) {
        DescriptiveStats summary = new DescriptiveStats();
        for (Model<T> model : models) {
            summary.addValue(metric.compute(metric.createContext(model, dataset)));
        }
        return summary;
    }

    /**
     * Evaluates each model on the same dataset with {@code evaluator} and summarizes every
     * metric reported by the resulting evaluations.
     *
     * @param evaluator the evaluator to apply
     * @param models    the models to evaluate
     * @param dataset   the dataset every model is evaluated on
     * @param <T>       the output type
     * @param <R>       the evaluation type produced by {@code evaluator}
     * @return a map from metric id to summary statistics across the models
     */
    public static <T extends Output<T>, R extends Evaluation<T>> Map<MetricID<T>, DescriptiveStats> summarize(Evaluator<T, R> evaluator, List<? extends Model<T>> models, Dataset<T> dataset) {
        List<R> evals = models.stream()
                .map(model -> evaluator.evaluate(model, dataset))
                .collect(Collectors.toList());
        return summarize(evals);
    }

    /**
     * Computes {@code metric} for a single model on each dataset and summarizes the values.
     *
     * @param metric   the metric to compute
     * @param model    the model to evaluate
     * @param datasets the datasets to evaluate the model on
     * @param <T>      the output type
     * @param <C>      the metric context type
     * @return summary statistics with one value per dataset
     */
    public static <T extends Output<T>, C extends MetricContext<T>> DescriptiveStats summarize(EvaluationMetric<T, C> metric, Model<T> model, List<? extends Dataset<T>> datasets) {
        DescriptiveStats summary = new DescriptiveStats();
        for (Dataset<T> dataset : datasets) {
            summary.addValue(metric.compute(metric.createContext(model, dataset)));
        }
        return summary;
    }

    /**
     * Predicts {@code dataset} once with {@code model}, then computes every metric against
     * those shared predictions and summarizes the resulting values.
     *
     * @param metrics the metrics to compute
     * @param model   the model used to produce predictions
     * @param dataset the dataset to predict
     * @param <T>     the output type
     * @param <C>     the metric context type
     * @return summary statistics with one value per metric
     */
    public static <T extends Output<T>, C extends MetricContext<T>> DescriptiveStats summarize(List<? extends EvaluationMetric<T, C>> metrics, Model<T> model, Dataset<T> dataset) {
        // Predict once up front so every metric reuses the same predictions.
        List<Prediction<T>> predictions = model.predict(dataset);
        return summarizeMetrics(metrics, model, predictions);
    }

    /**
     * Computes every metric against a pre-computed list of predictions and summarizes the
     * resulting values.
     *
     * @param metrics     the metrics to compute
     * @param model       the model that produced {@code predictions}
     * @param predictions the predictions to score
     * @param <T>         the output type
     * @param <C>         the metric context type
     * @return summary statistics with one value per metric
     */
    public static <T extends Output<T>, C extends MetricContext<T>> DescriptiveStats summarize(List<? extends EvaluationMetric<T, C>> metrics, Model<T> model, List<Prediction<T>> predictions) {
        return summarizeMetrics(metrics, model, predictions);
    }

    /**
     * Evaluates a single model on each dataset with {@code evaluator} and summarizes every
     * metric reported by the resulting evaluations.
     *
     * @param evaluator the evaluator to apply
     * @param model     the model to evaluate
     * @param datasets  the datasets to evaluate the model on
     * @param <T>       the output type
     * @param <R>       the evaluation type produced by {@code evaluator}
     * @return a map from metric id to summary statistics across the datasets
     */
    public static <T extends Output<T>, R extends Evaluation<T>> Map<MetricID<T>, DescriptiveStats> summarize(Evaluator<T, R> evaluator, Model<T> model, List<? extends Dataset<T>> datasets) {
        List<R> evals = datasets.stream()
                .map(data -> evaluator.evaluate(model, data))
                .collect(Collectors.toList());
        return summarize(evals);
    }

    /**
     * Summarizes each metric reported by {@link Evaluation#asMap()} across the given
     * evaluations. A metric id that appears in only some evaluations is summarized over
     * just those evaluations.
     *
     * @param evaluations the evaluations to aggregate
     * @param <T>         the output type
     * @param <R>         the evaluation type
     * @return a mutable map from metric id to summary statistics
     */
    public static <T extends Output<T>, R extends Evaluation<T>> Map<MetricID<T>, DescriptiveStats> summarize(List<R> evaluations) {
        Map<MetricID<T>, DescriptiveStats> results = new HashMap<>();
        for (R evaluation : evaluations) {
            accumulate(results, evaluation);
        }
        return results;
    }

    /**
     * Summarizes each metric across the evaluations of a cross-validation run. Only the
     * evaluation half ({@link Pair#getA()}) of each pair is read; the models are ignored.
     *
     * @param evaluations (evaluation, model) pairs, one per fold
     * @param <T>         the output type
     * @param <R>         the evaluation type
     * @return a mutable map from metric id to summary statistics across the folds
     */
    public static <T extends Output<T>, R extends Evaluation<T>> Map<MetricID<T>, DescriptiveStats> summarizeCrossValidation(List<Pair<R, Model<T>>> evaluations) {
        Map<MetricID<T>, DescriptiveStats> results = new HashMap<>();
        for (Pair<R, Model<T>> pair : evaluations) {
            accumulate(results, pair.getA());
        }
        return results;
    }

    /**
     * Summarizes a single scalar extracted from each evaluation.
     *
     * @param evaluations the evaluations to read from
     * @param fieldGetter extracts the value to summarize from an evaluation
     * @param <T>         the output type
     * @param <R>         the evaluation type
     * @return summary statistics with one value per evaluation
     */
    public static <T extends Output<T>, R extends Evaluation<T>> DescriptiveStats summarize(List<R> evaluations, ToDoubleFunction<R> fieldGetter) {
        DescriptiveStats summary = new DescriptiveStats();
        for (R evaluation : evaluations) {
            summary.addValue(fieldGetter.applyAsDouble(evaluation));
        }
        return summary;
    }

    /**
     * Finds the model that scores highest on {@code metric} for the given dataset.
     *
     * @param metric  the metric to maximize
     * @param models  the candidate models
     * @param dataset the dataset every model is scored on
     * @param <T>     the output type
     * @param <C>     the metric context type
     * @return the result of {@link Util#argmax(List)} over the per-model metric values
     *         (index into {@code models}, value)
     */
    public static <T extends Output<T>, C extends MetricContext<T>> Pair<Integer, Double> argmax(EvaluationMetric<T, C> metric, List<? extends Model<T>> models, Dataset<T> dataset) {
        List<Double> values = models.stream()
                .map(model -> metric.compute(metric.createContext(model, dataset)))
                .collect(Collectors.toList());
        return Util.argmax(values);
    }

    /**
     * Finds the dataset on which {@code model} scores highest for {@code metric}.
     *
     * @param metric   the metric to maximize
     * @param model    the model to score
     * @param datasets the candidate datasets
     * @param <T>      the output type
     * @param <C>      the metric context type
     * @return the result of {@link Util#argmax(List)} over the per-dataset metric values
     *         (index into {@code datasets}, value)
     */
    public static <T extends Output<T>, C extends MetricContext<T>> Pair<Integer, Double> argmax(EvaluationMetric<T, C> metric, Model<T> model, List<? extends Dataset<T>> datasets) {
        List<Double> values = datasets.stream()
                .map(dataset -> metric.compute(metric.createContext(model, dataset)))
                .collect(Collectors.toList());
        return Util.argmax(values);
    }

    /**
     * Finds the evaluation with the largest value extracted by {@code getter}.
     *
     * @param evaluations the candidate evaluations
     * @param getter      extracts the value to maximize from an evaluation
     * @param <T>         the output type
     * @param <R>         the evaluation type
     * @return the result of {@link Util#argmax(List)} over the extracted values
     *         (index into {@code evaluations}, value)
     */
    public static <T extends Output<T>, R extends Evaluation<T>> Pair<Integer, Double> argmax(List<R> evaluations, Function<R, Double> getter) {
        List<Double> values = evaluations.stream().map(getter).collect(Collectors.toList());
        return Util.argmax(values);
    }

    /**
     * Adds every (metric id, value) entry of {@code evaluation} to the running per-metric
     * summaries, creating a summary the first time a metric id is seen.
     */
    private static <T extends Output<T>> void accumulate(Map<MetricID<T>, DescriptiveStats> results, Evaluation<T> evaluation) {
        for (Map.Entry<MetricID<T>, Double> kv : evaluation.asMap().entrySet()) {
            // computeIfAbsent only allocates a new DescriptiveStats for unseen keys.
            results.computeIfAbsent(kv.getKey(), k -> new DescriptiveStats()).addValue(kv.getValue());
        }
    }

    /** Computes each metric against one shared prediction list and summarizes the values. */
    private static <T extends Output<T>, C extends MetricContext<T>> DescriptiveStats summarizeMetrics(List<? extends EvaluationMetric<T, C>> metrics, Model<T> model, List<Prediction<T>> predictions) {
        DescriptiveStats summary = new DescriptiveStats();
        for (EvaluationMetric<T, C> metric : metrics) {
            summary.addValue(metric.compute(metric.createContext(model, predictions)));
        }
        return summary;
    }
}

