package edu.cmu.tetrad.util;

import java.io.Serializable;
import java.util.Arrays;
import java.util.LinkedList;

/**
 * Computes a receiver operating characteristic (ROC) curve and the area under it (AUC) for a
 * set of scored cases, each of which either is or is not in the category of interest.
 *
 * <p>Cases are sorted by score according to the direction ({@link #ASCENDING} or
 * {@link #DESCENDING}) and then visited from the end of that ordering. Each in-category case
 * moves the curve up one unit; each out-of-category case moves it right one unit. Tied scores
 * are not averaged: their visiting order follows from a stable sort of the input order.
 */
public class RocCalculator implements Serializable {
    static final long serialVersionUID = 23;

    /** Sort scores increasingly; the curve is traced starting from the highest score. */
    public static final int ASCENDING = 0;

    /** Sort scores decreasingly; the curve is traced starting from the lowest score. */
    public static final int DESCENDING = 1;

    private final double[] scores;
    private final boolean[] inCategory;
    private final int direction;

    /** Most recently computed unscaled ROC points; {@code null} until first computed. */
    private int[][] points;

    /**
     * Score/category pairs in the caller's input order. This array is never sorted in place, so
     * every plot computation starts from the same ordering and produces the same curve.
     */
    private final ScoreCategoryPair[] scoreCatPairs;

    /** A mutable pair of ints. Not used by the ROC computation in this class. */
    public static class OrderedPairInt {
        static final long serialVersionUID = 23;
        private int first;
        private int second;

        public OrderedPairInt(int first, int second) {
            this.first = first;
            this.second = second;
        }

        public int getFirst() {
            return this.first;
        }

        public void setFirst(int first) {
            this.first = first;
        }

        public int getSecond() {
            return this.second;
        }

        public void setSecond(int second) {
            this.second = second;
        }
    }

    /** An immutable (score, in-category flag) pair, ordered by score. */
    public static class ScoreCategoryPair implements Comparable<ScoreCategoryPair>, Serializable {
        static final long serialVersionUID = 23;
        private final double score;
        private final boolean hasProperty;

        public ScoreCategoryPair(double score, boolean hasProperty) {
            this.score = score;
            this.hasProperty = hasProperty;
        }

        public double getScore() {
            return this.score;
        }

        public boolean getHasProperty() {
            return this.hasProperty;
        }

        /**
         * Orders pairs by score using {@link Double#compare}, which is a total order: NaN sorts
         * after every other value, so sorting never violates the {@code Comparable} contract.
         * This ordering is not consistent with {@code equals}.
         */
        @Override
        public int compareTo(ScoreCategoryPair other) {
            return Double.compare(this.score, other.score);
        }
    }

    private RocCalculator(double[] scores, boolean[] inCategory, int direction) {
        if (scores == null) {
            throw new NullPointerException("scores");
        }
        if (inCategory == null) {
            throw new NullPointerException("inCategory");
        }
        if (scores.length != inCategory.length) {
            throw new IllegalArgumentException("Scores array must have same number of items as inCategory array.");
        }
        if (direction != ASCENDING && direction != DESCENDING) {
            throw new IllegalArgumentException("Direction must be ASCENDING or DESCENDING.");
        }
        this.scores = scores;
        this.inCategory = inCategory;
        this.direction = direction;
        this.scoreCatPairs = new ScoreCategoryPair[scores.length];
        for (int k = 0; k < scores.length; k++) {
            this.scoreCatPairs[k] = new ScoreCategoryPair(scores[k], inCategory[k]);
        }
    }

    /**
     * Creates a calculator for the given cases.
     *
     * @param scores     one score per case
     * @param inCategory whether each case is in the category of interest; parallel to
     *                   {@code scores}
     * @param direction  {@link #ASCENDING} or {@link #DESCENDING}
     * @return a new calculator
     * @throws NullPointerException     if either array is {@code null}
     * @throws IllegalArgumentException if the arrays differ in length or the direction is invalid
     */
    public static RocCalculator newCalculator(double[] scores, boolean[] inCategory, int direction) {
        return new RocCalculator(scores, inCategory, direction);
    }

    /**
     * Returns the area under the ROC curve.
     *
     * <p>This is the number of (in-category, out-of-category) case pairs in which the
     * in-category case is visited first, divided by the total number of such pairs. The unscaled
     * plot is computed first if it has not been computed yet.
     *
     * @return the AUC in [0, 1], or {@code NaN} when there are no in-category cases or no
     *         out-of-category cases (the ratio is then 0/0)
     */
    public double getAuc() {
        if (this.points == null) {
            getUnscaledRocPlot();
        }
        int last = this.points.length - 1;
        long positivesSoFar = 0;
        // long: the pair count can reach P*N, which can exceed Integer.MAX_VALUE.
        long area = 0;
        for (int k = 1; k < this.points.length; k++) {
            if (this.points[k][1] > this.points[k - 1][1]) {
                positivesSoFar++;
            } else if (this.points[k][0] > this.points[k - 1][0]) {
                area += positivesSoFar;
            }
        }
        // Floating-point division: int/int would truncate the AUC to 0 or 1.
        double totalNegatives = this.points[last][0];
        double totalPositives = this.points[last][1];
        return area / (totalNegatives * totalPositives);
    }

    /**
     * Computes the ROC curve in unit steps.
     *
     * <p>Point {@code i} is {@code {outOfCategoryCount, inCategoryCount}} after visiting
     * {@code i} cases. Point 0 is {@code {0, 0}}, and the result has {@code n + 1} rows for
     * {@code n} cases. The returned array is also cached for {@link #getAuc()} and
     * {@link #getScaledRocPlot()}.
     *
     * @return the unscaled ROC points
     */
    public int[][] getUnscaledRocPlot() {
        ScoreCategoryPair[] ordered = sortedCases();
        int[][] plot = new int[ordered.length + 1][2];
        int outCount = 0;
        int inCount = 0;
        int row = 1;
        // Visit cases from the end of the sorted order, as the curve definition requires.
        for (int k = ordered.length - 1; k >= 0; k--) {
            if (ordered[k].getHasProperty()) {
                inCount++;
            } else {
                outCount++;
            }
            plot[row][0] = outCount;
            plot[row][1] = inCount;
            row++;
        }
        this.points = plot;
        return plot;
    }

    /**
     * Returns the ROC curve scaled to the unit square.
     *
     * <p>Each unscaled point is divided by the final point (total out-of-category count, total
     * in-category count). The unscaled plot is computed first if it has not been computed yet.
     *
     * @return the scaled ROC points. A coordinate is {@code NaN} when its total count is 0.
     */
    public double[][] getScaledRocPlot() {
        if (this.points == null) {
            getUnscaledRocPlot();
        }
        int length = this.points.length;
        // double divisors force floating-point division; int/int would truncate to 0 or 1.
        double xMax = this.points[length - 1][0];
        double yMax = this.points[length - 1][1];
        double[][] scaled = new double[length][2];
        for (int k = 0; k < length; k++) {
            scaled[k][0] = this.points[k][0] / xMax;
            scaled[k][1] = this.points[k][1] / yMax;
        }
        return scaled;
    }

    /**
     * Returns a sorted copy of the cases: ascending by score, then reversed for
     * {@link #DESCENDING}. The input-order array is left untouched, so tie order is the same on
     * every call.
     */
    private ScoreCategoryPair[] sortedCases() {
        ScoreCategoryPair[] sorted = this.scoreCatPairs.clone();
        // Arrays.sort(Object[]) is stable, so tied scores keep their input order here.
        Arrays.sort(sorted);
        if (this.direction == DESCENDING) {
            for (int lo = 0, hi = sorted.length - 1; lo < hi; lo++, hi--) {
                ScoreCategoryPair tmp = sorted[lo];
                sorted[lo] = sorted[hi];
                sorted[hi] = tmp;
            }
        }
        return sorted;
    }
}
