/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.evaluation;

import com.oracle.labs.mlrg.olcut.util.SortUtil;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.tribuo.util.Util;

/**
 * Static helpers for threshold-based binary classification metrics: precision-recall curves,
 * averaged precision, ROC curves and the area under the ROC curve.
 * <p>
 * Every method takes a per-example ground-truth array ({@code yPos[i]} is {@code true} when
 * example {@code i} is positive) and a parallel array of scores where a higher score means
 * "more likely positive".
 */
public final class LabelEvaluationUtil {

    // Static utility class; never instantiated.
    private LabelEvaluationUtil() {
    }

    /**
     * Computes the averaged precision of the scores: the step-wise integral of the
     * precision-recall curve, i.e. the sum over curve points of
     * {@code (recall drop to the next point) * precision}.
     * <p>
     * If there are no positive examples every recall value is 0 and the result is {@code 0.0}.
     *
     * @param yPos   whether each example is positive
     * @param yScore the score of each example, higher meaning more likely positive
     * @return the averaged precision
     * @throws IllegalArgumentException if the arrays differ in length or are empty
     */
    public static double averagedPrecision(boolean[] yPos, double[] yScore) {
        PRCurve prc = generatePRCurve(yPos, yScore);
        // The curve is ordered by decreasing recall, so recall[i] - recall[i + 1] >= 0.
        // Summing these non-negative steps directly equals the negated sum of the
        // non-positive forward differences, but never produces -0.0.
        double score = 0.0;
        for (int i = 0; i < prc.precision.length - 1; i++) {
            score += (prc.recall[i] - prc.recall[i + 1]) * prc.precision[i];
        }
        return score;
    }

    /**
     * Computes the precision-recall curve for the supplied scores.
     * <p>
     * The points are ordered by decreasing recall. {@code precision[i]} and {@code recall[i]}
     * correspond to {@code thresholds[i]} for every {@code i < thresholds.length}. A final extra
     * point (precision {@code 1.0}, recall {@code 0.0}) with no threshold is appended, so
     * {@code thresholds.length == precision.length - 1}.
     * <p>
     * Points after the first threshold that reaches full recall are dropped. At thresholds
     * where no true positives have been observed, precision and recall are both reported as
     * {@code 0.0}.
     *
     * @param yPos   whether each example is positive
     * @param yScore the score of each example, higher meaning more likely positive
     * @return the precision-recall curve
     * @throws IllegalArgumentException if the arrays differ in length or are empty
     */
    public static PRCurve generatePRCurve(boolean[] yPos, double[] yScore) {
        TPFP tpfp = generateTPFPs(yPos, yScore);
        int size = tpfp.falsePos.size();
        List<Double> precisions = new ArrayList<>(size);
        List<Double> recalls = new ArrayList<>(size);
        List<Double> thresholds = new ArrayList<>(size);
        for (int i = 0; i < size; i++) {
            double curFalsePos = tpfp.falsePos.get(i);
            double curTruePos = tpfp.truePos.get(i);
            double precision = 0.0;
            double recall = 0.0;
            // Guard avoids 0/0 when nothing positive has been seen yet (or there are no positives).
            if (curTruePos != 0.0) {
                precision = curTruePos / (curTruePos + curFalsePos);
                recall = curTruePos / tpfp.totalPos;
            }
            precisions.add(precision);
            recalls.add(recall);
            thresholds.add(tpfp.thresholds.get(i));
            // Stop at the first threshold achieving full recall; later points add nothing.
            if (curTruePos == tpfp.totalPos) {
                break;
            }
        }
        // Reverse so recall decreases along the curve, then add the (precision 1, recall 0) end point.
        Collections.reverse(precisions);
        Collections.reverse(recalls);
        Collections.reverse(thresholds);
        precisions.add(1.0);
        recalls.add(0.0);
        return new PRCurve(Util.toPrimitiveDouble(precisions), Util.toPrimitiveDouble(recalls), Util.toPrimitiveDouble(thresholds));
    }

    /**
     * Computes the area under the ROC curve for the supplied scores.
     * <p>
     * If all examples belong to one class, one of the ROC rate arrays is 0/0 = NaN, and the
     * result is not meaningful.
     *
     * @param yPos   whether each example is positive
     * @param yScore the score of each example, higher meaning more likely positive
     * @return the ROC AUC
     * @throws IllegalArgumentException if the arrays differ in length or are empty
     */
    public static double binaryAUCROC(boolean[] yPos, double[] yScore) {
        ROC roc = generateROCCurve(yPos, yScore);
        return Util.auc(roc.fpr, roc.tpr);
    }

    /**
     * Computes the ROC curve (false positive rate against true positive rate) for the scores.
     * <p>
     * The curve is forced to start at (0, 0), with threshold {@link Double#POSITIVE_INFINITY}
     * for that extra point. The cumulative counts are divided by their final (largest) values
     * to turn them into rates. When there are no positives, or no negatives, that division is
     * 0/0 and the corresponding rates are NaN.
     *
     * @param yPos   whether each example is positive
     * @param yScore the score of each example, higher meaning more likely positive
     * @return the ROC curve
     * @throws IllegalArgumentException if the arrays differ in length or are empty
     */
    public static ROC generateROCCurve(boolean[] yPos, double[] yScore) {
        TPFP tpfp = generateTPFPs(yPos, yScore);
        // Prepend the origin unless the curve already starts there.
        if (tpfp.truePos.get(0) != 0 || tpfp.falsePos.get(0) != 0) {
            tpfp.truePos.add(0, 0);
            tpfp.falsePos.add(0, 0);
            tpfp.thresholds.add(0, Double.POSITIVE_INFINITY);
        }
        double[] tpr = Util.toPrimitiveDoubleFromInteger(tpfp.truePos);
        double[] fpr = Util.toPrimitiveDoubleFromInteger(tpfp.falsePos);
        double[] thresholds = Util.toPrimitiveDouble(tpfp.thresholds);
        // Cumulative counts never decrease, so the last entry is the largest; normalise by it.
        double maxTrue = tpr[tpr.length - 1];
        double maxFalse = fpr[fpr.length - 1];
        for (int i = 0; i < tpr.length; i++) {
            tpr[i] /= maxTrue;
            fpr[i] /= maxFalse;
        }
        return new ROC(fpr, tpr, thresholds);
    }

    /**
     * Sorts the examples by score and computes cumulative true/false positive counts at each
     * index returned by {@code Util.differencesIndices} on the sorted scores.
     *
     * @param yPos   whether each example is positive
     * @param yScore the score of each example
     * @return the cumulative counts, thresholds and total number of positives
     * @throws IllegalArgumentException if the arrays differ in length or are empty
     */
    private static TPFP generateTPFPs(boolean[] yPos, double[] yScore) {
        if (yPos.length != yScore.length) {
            throw new IllegalArgumentException("yPos and yScore must be the same length, yPos.length = " + yPos.length + ", yScore.length = " + yScore.length);
        }
        // Metrics over zero examples are undefined; fail clearly rather than with an index error.
        if (yScore.length == 0) {
            throw new IllegalArgumentException("yPos and yScore must contain at least one example");
        }
        // NOTE(review): argsort with 'false' is assumed to sort descending, which the
        // cumulative counts below (and the ROC normalisation) rely on.
        int[] sortedIndices = SortUtil.argsort(yScore, false);
        double[] sortedScore = new double[yScore.length];
        boolean[] sortedPos = new boolean[yPos.length];
        int totalPos = 0;
        for (int i = 0; i < yScore.length; i++) {
            sortedScore[i] = yScore[sortedIndices[i]];
            sortedPos[i] = yPos[sortedIndices[i]];
            if (sortedPos[i]) {
                totalPos++;
            }
        }
        int[] differentIndices = Util.differencesIndices(sortedScore);
        int[] truePosSum = Util.cumulativeSum(sortedPos);
        List<Integer> truePos = new ArrayList<>(differentIndices.length);
        List<Integer> falsePos = new ArrayList<>(differentIndices.length);
        List<Double> thresholds = new ArrayList<>(differentIndices.length);
        for (int idx : differentIndices) {
            int tp = truePosSum[idx];
            thresholds.add(sortedScore[idx]);
            truePos.add(tp);
            // Sorted positions 0..idx (idx + 1 examples) are counted; those not positive are false positives.
            falsePos.add(idx + 1 - tp);
        }
        return new TPFP(falsePos, truePos, thresholds, totalPos);
    }

    /**
     * A precision-recall curve. The arrays are parallel, with one extra trailing point in
     * {@code precision} and {@code recall} that has no threshold.
     */
    public static class PRCurve {
        /** Precision at each curve point. */
        public final double[] precision;
        /** Recall at each curve point. */
        public final double[] recall;
        /** Score threshold for each curve point except the final one. */
        public final double[] thresholds;

        /**
         * Constructs a precision-recall curve from parallel arrays.
         *
         * @param precision  the precision values
         * @param recall     the recall values
         * @param thresholds the thresholds
         */
        public PRCurve(double[] precision, double[] recall, double[] thresholds) {
            this.precision = precision;
            this.recall = recall;
            this.thresholds = thresholds;
        }
    }

    /**
     * Parallel lists of cumulative false/true positive counts and the score thresholds at
     * which they were taken, plus the total number of positive examples.
     */
    private static class TPFP {
        public final List<Integer> falsePos;
        public final List<Integer> truePos;
        public final List<Double> thresholds;
        public final int totalPos;

        public TPFP(List<Integer> falsePos, List<Integer> truePos, List<Double> thresholds, int totalPos) {
            this.falsePos = falsePos;
            this.truePos = truePos;
            this.thresholds = thresholds;
            this.totalPos = totalPos;
        }
    }

    /**
     * A receiver operating characteristic curve as parallel arrays of false positive rate,
     * true positive rate and score threshold.
     */
    public static class ROC {
        /** False positive rate at each curve point. */
        public final double[] fpr;
        /** True positive rate at each curve point. */
        public final double[] tpr;
        /** Score threshold at each curve point. */
        public final double[] thresholds;

        /**
         * Constructs an ROC curve from parallel arrays.
         *
         * @param fpr        the false positive rates
         * @param tpr        the true positive rates
         * @param thresholds the thresholds
         */
        public ROC(double[] fpr, double[] tpr, double[] thresholds) {
            this.fpr = fpr;
            this.tpr = tpr;
            this.thresholds = thresholds;
        }
    }
}

