package edu.cmu.tetrad.search;

import edu.cmu.tetrad.bayes.BayesPm;
import edu.cmu.tetrad.bayes.Evidence;
import edu.cmu.tetrad.bayes.MLBayesEstimator;
import edu.cmu.tetrad.bayes.MLBayesIm;
import edu.cmu.tetrad.bayes.RowSummingExactUpdater;
import edu.cmu.tetrad.data.DiscreteDataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.graph.DirectedAcyclicGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.ind.IndTestXSquare2;
import edu.cmu.tetrad.util.LogUtils;
import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import java.util.logging.Logger;

/* loaded from: input_file:edu/cmu/tetrad/search/PcxROCScorer.class */
/**
 * Discrete classifier that, besides predicting a category for each case, records a
 * per-case score (the estimated marginal probability of a chosen target category) together
 * with whether the case actually falls in that category.
 *
 * <p>Workflow of {@link #classify()}: run {@code PcxSearch} on the training data for the
 * target, estimate a Bayes IM over the returned graph from the training data, then for every
 * case in the classification data set the non-target variables as evidence and read the
 * marginals of the target.
 *
 * <p>NOTE(review): the score / in-category arrays are presumably consumed to draw an ROC
 * curve (per the class name) — confirm against callers.
 */
public class PcxROCScorer implements DiscreteClassifier, Serializable {
    static final long serialVersionUID = 23;

    // NOTE(review): the logger is keyed to PcxClassifier, not this class. This looks like a
    // copy/paste leftover; it is kept so existing logging configuration keeps applying.
    private static final Logger LOGGER = LogUtils.getLogger(PcxClassifier.class);

    private DiscreteDataSet ddsTrain;
    private DiscreteDataSet ddsClassify;
    /** Classification data restricted to the search result's nodes, indexed [variable][case]. */
    private int[][] rawData;
    private String target;
    private int targetValue;
    /** NaN until {@link #crossTabulate()} has run. */
    private double percentCorrect = Double.NaN;
    private double alpha;
    private int depth;
    private DiscreteVariable targetVariable;
    private List trainVars;
    private List classifyVars;
    private List markovBlanketNodes;
    /** Score per valid case; only the first {@code numValidCases} entries are populated. */
    private double[] scores;
    /** Category membership per valid case; only the first {@code numValidCases} entries are populated. */
    private boolean[] inCategory;
    private int numValidCases;

    /**
     * Builds a scorer for the given training and classification data.
     *
     * @param discreteDataSet  training data
     * @param discreteDataSet2 data whose cases are to be classified; must hold the same
     *                         variables, in the same order, as the training data
     * @param str              name of the target variable
     * @param i                index of the target category whose marginal is recorded as the score
     * @param d                alpha passed to the independence test used by the search
     * @param i2               depth passed to the search
     * @throws IllegalArgumentException if the two data sets do not contain the same variables
     *                                  or the target variable is not present in the training data
     */
    public PcxROCScorer(DiscreteDataSet discreteDataSet, DiscreteDataSet discreteDataSet2, String str, int i, double d, int i2) {
        this.ddsTrain = discreteDataSet;
        this.ddsClassify = discreteDataSet2;
        this.target = str;
        this.targetValue = i;
        this.alpha = d;
        this.depth = i2;
        this.trainVars = discreteDataSet.getVariables();
        this.classifyVars = discreteDataSet2.getVariables();

        // Compare sizes first: the element-wise loop below only walks the training variables,
        // so a shorter classify list would fail with a raw index error and a longer one would
        // slip through unnoticed.
        if (this.trainVars.size() != this.classifyVars.size()) {
            throw new IllegalArgumentException("Datasets must contain same vars.");
        }
        for (int v = 0; v < this.trainVars.size(); v++) {
            if (!this.trainVars.get(v).equals(this.classifyVars.get(v))) {
                throw new IllegalArgumentException("Datasets must contain same vars.");
            }
        }

        this.targetVariable = findTrainVariable(str);
        if (this.targetVariable == null) {
            throw new IllegalArgumentException("Target variable not in data.");
        }
    }

    /** Returns the first training variable with the given name, or {@code null} if none matches. */
    private DiscreteVariable findTrainVariable(String name) {
        for (int v = 0; v < this.trainVars.size(); v++) {
            DiscreteVariable candidate = (DiscreteVariable) this.trainVars.get(v);
            if (candidate.getName().equals(name)) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Classifies every case of the classification data set.
     *
     * <p>Side effects: refreshes {@code markovBlanketNodes}, {@code rawData}, {@code scores},
     * {@code inCategory} and {@code numValidCases}.
     *
     * @return one entry per case: the index of the target category with the largest marginal,
     *         or -1 when no category yielded a usable marginal for that case
     */
    @Override // edu.cmu.tetrad.search.DiscreteClassifier
    public int[] classify() {
        Graph searchGraph = new PcxSearch(new IndTestXSquare2(this.ddsTrain, this.alpha), this.depth).search(this.target);
        this.markovBlanketNodes = searchGraph.getNodes();

        // Estimate a Bayes IM over the search graph from the training data.
        DiscreteDataSet trainSubset = new DiscreteDataSet(this.ddsTrain.subsetDataSet(this.markovBlanketNodes));
        List trainSubsetVars = trainSubset.getVariables();
        BayesPm bayesPm = new BayesPm(new DirectedAcyclicGraph(searchGraph));
        for (int v = 0; v < trainSubsetVars.size(); v++) {
            bayesPm.setNumCategories((Node) this.markovBlanketNodes.get(v), ((DiscreteVariable) trainSubsetVars.get(v)).getNumCategories());
        }
        MLBayesIm bayesIm = new MLBayesEstimator().estimate(bayesPm, trainSubset);
        RowSummingExactUpdater updater = new RowSummingExactUpdater(bayesIm);

        DiscreteDataSet classifySubset = new DiscreteDataSet(this.ddsClassify.subsetDataSet(this.markovBlanketNodes));
        this.rawData = classifySubset.getDataMatrixTrimmed();
        int numVars = this.rawData.length;
        int numCases = this.rawData[0].length;
        this.scores = new double[numCases];
        this.inCategory = new boolean[numCases];
        int[] estimated = new int[numCases];
        Arrays.fill(estimated, -1);

        List classifySubsetVars = classifySubset.getVariables();
        // Loop-invariant: position of the target among the subset's variables.
        int targetColumn = classifySubsetVars.indexOf(this.targetVariable);
        int numTargetCategories = this.targetVariable.getNumCategories();

        int validCases = 0;
        for (int c = 0; c < numCases; c++) {
            Evidence evidence = new Evidence(bayesIm);
            evidence.getProposition().setVariable(evidence.getNodeIndex(this.target), true);
            // Every variable except the target is fixed to its observed value for this case.
            for (int v = 0; v < classifySubsetVars.size(); v++) {
                if (v != targetColumn) {
                    String varName = ((DiscreteVariable) classifySubsetVars.get(v)).getName();
                    evidence.getProposition().setCategory(evidence.getNodeIndex(varName), this.rawData[v][c]);
                }
            }
            updater.setEvidence(evidence);
            int targetNode = updater.getBayesIm().getNodeIndex(this.targetVariable);

            // Arg-max over the target's categories. The comparison is >=, so ties go to the
            // later category; a NaN marginal never passes the comparison, which leaves
            // bestCategory at -1 when nothing usable comes back.
            double bestMarginal = -0.1d;
            int bestCategory = -1;
            for (int cat = 0; cat < numTargetCategories; cat++) {
                double marginal = updater.getMarginal(targetNode, cat);
                if (marginal >= bestMarginal) {
                    bestMarginal = marginal;
                    bestCategory = cat;
                }
            }

            double score = updater.getMarginal(targetNode, this.targetValue);
            boolean observedInCategory = this.rawData[targetColumn][c] == this.targetValue;
            if (bestCategory < 0) {
                logInvalidCase(c, classifySubsetVars, numVars);
            } else {
                // Valid cases are packed at the front of the score arrays.
                this.scores[validCases] = score;
                this.inCategory[validCases] = observedInCategory;
                validCases++;
                estimated[c] = bestCategory;
            }
        }
        this.numValidCases = validCases;
        return estimated;
    }

    /** Logs, at FINE level, the variable names and observed values of a case with no valid marginal. */
    private void logInvalidCase(int caseIndex, List variables, int numVars) {
        LOGGER.fine("Case " + caseIndex + " does not return valid marginal.");
        for (int v = 0; v < numVars; v++) {
            // Name and value go to the logger together rather than splitting across stdout and the log.
            LOGGER.fine(((DiscreteVariable) variables.get(v)).getName() + "  " + this.rawData[v][caseIndex]);
        }
    }

    /**
     * Runs {@link #classify()} and tabulates observed versus estimated target categories.
     *
     * <p>Also updates {@code percentCorrect}. The percentage is computed over all cases of the
     * classification data, including cases that received no valid classification.
     * NOTE(review): dividing by the number of usable cases may have been intended — confirm.
     *
     * @return a square table indexed {@code [observed][estimated]} over the target's categories
     */
    @Override // edu.cmu.tetrad.search.DiscreteClassifier
    public int[][] crossTabulate() {
        int[] estimated = classify();
        int targetColumn = new DiscreteDataSet(this.ddsClassify.subsetDataSet(this.markovBlanketNodes)).getVariables().indexOf(this.targetVariable);
        int numCategories = this.targetVariable.getNumCategories();
        // Java zero-initialises int arrays, so no explicit clearing is needed.
        int[][] table = new int[numCategories][numCategories];

        int usableCases = 0;
        int numCases = this.rawData[0].length;
        int correct = 0;
        for (int c = 0; c < numCases; c++) {
            int estimatedValue = estimated[c];
            int observedValue = this.rawData[targetColumn][c];
            if (estimatedValue < 0) {
                continue; // case had no valid marginal
            }
            usableCases++;
            table[observedValue][estimatedValue]++;
            if (observedValue == estimatedValue) {
                correct++;
            }
        }
        this.percentCorrect = (100.0d * correct) / numCases;
        LOGGER.fine("Total no usable cases= " + usableCases + " out of " + numCases);
        return table;
    }

    /**
     * Returns the percentage of correctly classified cases, running {@link #crossTabulate()}
     * first if no value has been computed yet (the stored value is still NaN).
     */
    @Override // edu.cmu.tetrad.search.DiscreteClassifier
    public double getPercentCorrect() {
        if (Double.isNaN(this.percentCorrect)) {
            crossTabulate();
        }
        return this.percentCorrect;
    }

    /** Returns the target variable located in the training data at construction time. */
    public DiscreteVariable getTargetVariable() {
        return this.targetVariable;
    }

    /**
     * Returns the internal score array filled by {@link #classify()} ({@code null} before the
     * first call). Only the first {@link #getNumValidCases()} entries carry scores.
     */
    public double[] getScores() {
        return this.scores;
    }

    /**
     * Returns the internal array telling, for each valid case, whether its observed target
     * value equals the target category ({@code null} before the first {@link #classify()} call).
     * Only the first {@link #getNumValidCases()} entries are populated.
     */
    public boolean[] getCategoryFacts() {
        return this.inCategory;
    }

    /** Returns how many cases received a valid classification in the last {@link #classify()} run. */
    public int getNumValidCases() {
        return this.numValidCases;
    }
}
