package edu.cmu.tetrad.search;

import edu.cmu.tetrad.bayes.Evidence;
import edu.cmu.tetrad.bayes.IBayesIm;
import edu.cmu.tetrad.bayes.RowSummingExactUpdater;
import edu.cmu.tetrad.data.Column;
import edu.cmu.tetrad.data.DiscreteDataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.util.LogUtils;
import java.io.Serializable;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.logging.Logger;

/* loaded from: input_file:edu/cmu/tetrad/search/BayesUpdaterClassifier.class */
/**
 * Classifies the cases of a discrete test data set by exact Bayes updating on a
 * Bayes instantiated model ({@link IBayesIm}).
 *
 * <p>For each case, every non-target variable of the Bayes IM that has a column in
 * the test data is clamped to the case's observed value. Values equal to -99 are
 * treated as missing and are not clamped. The marginal distribution of the target
 * variable is then computed with a {@link RowSummingExactUpdater}.
 *
 * <ul>
 *   <li>For a binary target, the case is assigned to the chosen target category when
 *       that category's marginal exceeds {@link #getBinaryCutoff()}, and to the
 *       other category otherwise.</li>
 *   <li>For any other number of categories, the category with the largest marginal
 *       is chosen; ties go to the highest category index.</li>
 * </ul>
 */
public class BayesUpdaterClassifier implements DiscreteClassifier, Serializable {
    static final long serialVersionUID = 23;

    /** Raw-data code for a missing value; also used in the classifications for cases that could not be classified. */
    private static final int MISSING_VALUE = -99;

    private static final Logger LOGGER = LogUtils.getLogger(BayesUpdaterClassifier.class);

    /** The Bayes IM used for updating. */
    private IBayesIm bayesIM;
    /** The data whose cases are classified. */
    private DiscreteDataSet testData;
    /** Raw column data per Bayes-IM variable (same order as bayesImVars); null where the test data lacks that variable. */
    private int[][] rawData;
    /** The variable being predicted; null until setTarget is called. */
    private DiscreteVariable targetVariable;
    /** True if any case seen by the most recent classify() call had a missing (-99) value in a clamped variable. */
    private boolean missingValueCaseFound;
    /** The variables of the Bayes IM, as returned by IBayesIm.getVariables(). */
    private List bayesImVars;
    /** Result of the most recent classify() call. */
    private int[] classifications;
    /** Marginals from the most recent classify() call, indexed [category][case]. */
    private double[][] marginals;
    /** Variables that may be chosen as the target. */
    private List availableTargets;
    /** Number of cases with both a valid prediction and a non-negative actual value in the last crossTabulate(). */
    private int totalUsableCases;
    /** Category of a binary target compared against binaryCutoff. */
    private int targetCategory;
    private double binaryCutoff = 0.5d;
    /** Number of cases classified by the last classify() call, or -1 if classify() has not run. */
    private int numCases = -1;
    /** Cached result of crossTabulate(); NaN means "not yet computed". */
    private double percentCorrect = Double.NaN;

    /**
     * Creates a classifier for the given model and test data.
     *
     * @param iBayesIm        the Bayes IM used for updating
     * @param discreteDataSet the data whose cases are to be classified
     */
    public BayesUpdaterClassifier(IBayesIm iBayesIm, DiscreteDataSet discreteDataSet) {
        this.bayesIM = iBayesIm;
        this.testData = discreteDataSet;
        this.bayesImVars = iBayesIm.getVariables();
        this.availableTargets = new LinkedList(this.bayesImVars);
    }

    /**
     * Selects the target variable by name.
     *
     * @param targetName name of one of the {@link #getAvailableTargets() available targets}
     * @param category   for a binary target, the category whose marginal is compared
     *                   against the binary cutoff; not used for non-binary targets
     * @throws IllegalArgumentException if no available target has that name
     */
    public void setTarget(String targetName, int category) {
        DiscreteVariable target = null;
        for (Object candidate : this.availableTargets) {
            DiscreteVariable variable = (DiscreteVariable) candidate;
            if (variable.getName().equals(targetName)) {
                target = variable;
                break;
            }
        }
        if (target == null) {
            throw new IllegalArgumentException("Not an available target: " + targetName);
        }
        this.targetVariable = target;
        this.targetCategory = category;
        // Any cached accuracy refers to the previous target.
        this.percentCorrect = Double.NaN;
    }

    /**
     * Classifies every case of the test data.
     *
     * <p>Side effects: stores the classifications, the marginals (indexed
     * [category][case]), the case count and the missing-value flag in this object.
     *
     * @return one predicted category per case, or -99 for a case whose marginals
     *         yielded no valid category
     * @throws NullPointerException if no target has been set
     */
    @Override // edu.cmu.tetrad.search.DiscreteClassifier
    public int[] classify() {
        if (this.targetVariable == null) {
            throw new NullPointerException("Target not set.");
        }
        RowSummingExactUpdater updater = new RowSummingExactUpdater(getBayesIm());
        int numRows = this.testData.getMaxRowCount();
        this.rawData = loadRawData();
        this.numCases = numRows;

        int numCategories = this.targetVariable.getNumCategories();
        int[] predicted = new int[numRows];
        double[][] probs = new double[numCategories][numRows];

        // Loop invariants: the target's position among the IM variables and its node index.
        int targetVarIndex = this.bayesImVars.indexOf(this.targetVariable);
        int targetNodeIndex = getBayesIm().getNodeIndex(getBayesIm().getNode(this.targetVariable.getName()));

        // Reset once per run so the flag reports whether ANY case had a missing value.
        this.missingValueCaseFound = false;

        for (int row = 0; row < numRows; row++) {
            updater.setEvidence(buildEvidence(row, targetVarIndex));
            int category = numCategories == 2
                    ? classifyBinary(updater, targetNodeIndex, probs, row)
                    : classifyByMaxMarginal(updater, targetNodeIndex, numCategories, probs, row);
            if (category < 0) {
                logInvalidCase(row);
                predicted[row] = MISSING_VALUE;
            } else {
                predicted[row] = category;
            }
        }
        this.classifications = predicted;
        this.marginals = probs;
        return predicted;
    }

    /** Fetches the raw test-data column for each Bayes-IM variable; the entry is null when the test data has no such column. */
    private int[][] loadRawData() {
        int[][] data = new int[this.bayesImVars.size()][];
        for (int v = 0; v < data.length; v++) {
            Column column = this.testData.get(((DiscreteVariable) this.bayesImVars.get(v)).getName());
            if (column != null) {
                data[v] = (int[]) column.getRawData();
            }
        }
        return data;
    }

    /**
     * Builds the evidence for one case.
     *
     * <p>The target variable is passed to {@code setVariable(..., true)} (presumably
     * leaving it unconstrained), and every other variable with test data and a
     * non-missing value is clamped to its observed category. Records in
     * {@code missingValueCaseFound} when a missing value is skipped.
     */
    private Evidence buildEvidence(int row, int targetVarIndex) {
        Evidence evidence = new Evidence(getBayesIm());
        evidence.getProposition().setVariable(evidence.getNodeIndex(this.targetVariable.getName()), true);
        for (int v = 0; v < this.bayesImVars.size(); v++) {
            if (v == targetVarIndex || this.rawData[v] == null) {
                continue;
            }
            int value = this.rawData[v][row];
            if (value == MISSING_VALUE) {
                this.missingValueCaseFound = true;
            } else {
                String name = ((DiscreteVariable) this.bayesImVars.get(v)).getName();
                evidence.getProposition().setCategory(evidence.getNodeIndex(name), value);
            }
        }
        return evidence;
    }

    /**
     * Classifies one case of a binary target using the cutoff rule.
     *
     * <p>Writes both marginals of the case into {@code probs}.
     *
     * @return the target category if its marginal exceeds the cutoff, the other
     *         category otherwise, or -1 if the target category is neither 0 nor 1
     */
    private int classifyBinary(RowSummingExactUpdater updater, int targetNodeIndex,
                               double[][] probs, int row) {
        for (int category = 0; category < 2; category++) {
            double marginal = updater.getMarginal(targetNodeIndex, category);
            probs[category][row] = marginal;
            probs[1 - category][row] = 1.0d - marginal;
            if (this.targetCategory == category) {
                // Stop here: the original loop never advanced past this point (infinite loop).
                return marginal > this.binaryCutoff ? category : 1 - category;
            }
        }
        return -1;
    }

    /**
     * Chooses the category with the largest marginal; ties go to the highest index.
     *
     * <p>Writes every marginal of the case into {@code probs}.
     *
     * @return the chosen category, or -1 if no marginal is at least -0.1 (for example, all NaN)
     */
    private int classifyByMaxMarginal(RowSummingExactUpdater updater, int targetNodeIndex,
                                      int numCategories, double[][] probs, int row) {
        double best = -0.1d;
        int bestCategory = -1;
        for (int category = 0; category < numCategories; category++) {
            double marginal = updater.getMarginal(targetNodeIndex, category);
            probs[category][row] = marginal;
            if (marginal >= best) {
                best = marginal;
                bestCategory = category;
            }
        }
        return bestCategory;
    }

    /** Logs, at FINE level, a case that yielded no valid category, together with its raw values. */
    private void logInvalidCase(int row) {
        LOGGER.fine("Case " + row + " does not return valid marginal.");
        for (int v = 0; v < this.bayesImVars.size(); v++) {
            String name = ((DiscreteVariable) this.bayesImVars.get(v)).getName();
            int[] column = this.rawData[v];
            LOGGER.fine("  " + name + " = "
                    + (column == null ? "(not in test data)" : String.valueOf(column[row])));
        }
    }

    /**
     * Classifies the test data and cross-tabulates actual against predicted categories.
     *
     * <p>Side effects: updates the percent-correct value and the count of usable cases.
     *
     * @return a table indexed [actual][predicted], or null if the test data has no
     *         column for the target variable. Cases with a negative actual or
     *         predicted value are excluded from the table.
     */
    @Override // edu.cmu.tetrad.search.DiscreteClassifier
    public int[][] crossTabulate() {
        int[] predicted = classify();
        Column column = this.testData.getColumn(this.targetVariable.getName());
        if (column == null) {
            return null;
        }
        int[] actual = (int[]) column.getRawData();
        int numRows = column.size();
        int numCategories = this.targetVariable.getNumCategories();
        int[][] table = new int[numCategories][numCategories];  // Java zero-initialises
        int correct = 0;
        int usable = 0;
        for (int row = 0; row < numRows; row++) {
            int predictedCategory = predicted[row];
            int actualCategory = actual[row];
            if (predictedCategory >= 0 && actualCategory >= 0) {
                usable++;
                table[actualCategory][predictedCategory]++;
                if (actualCategory == predictedCategory) {
                    correct++;
                }
            }
        }
        // The denominator is all cases in the column, so unusable cases count as incorrect.
        this.percentCorrect = (100.0d * correct) / numRows;
        this.totalUsableCases = usable;
        return table;
    }

    /**
     * Returns the percentage of correctly classified cases.
     *
     * <p>If no cached value exists, this runs {@link #crossTabulate()} first.
     *
     * @return the percentage, or NaN if the test data has no column for the target
     */
    @Override // edu.cmu.tetrad.search.DiscreteClassifier
    public double getPercentCorrect() {
        if (Double.isNaN(this.percentCorrect)) {
            crossTabulate();
        }
        return this.percentCorrect;
    }

    public DiscreteVariable getTargetVariable() {
        return this.targetVariable;
    }

    public IBayesIm getBayesIm() {
        return this.bayesIM;
    }

    public DiscreteDataSet getTestData() {
        return this.testData;
    }

    /** Returns the classifications from the last {@link #classify()} call, or null if it has not run. */
    public int[] getClassifications() {
        return this.classifications;
    }

    /** Returns the marginals, indexed [category][case], from the last {@link #classify()} call, or null if it has not run. */
    public double[][] getMarginals() {
        return this.marginals;
    }

    /** Returns the number of cases classified by the last {@link #classify()} call, or -1 if it has not run. */
    public int getNumCases() {
        return this.numCases;
    }

    /** Returns the (mutable) list of variables that may be passed to {@link #setTarget}. */
    public List getAvailableTargets() {
        return this.availableTargets;
    }

    /** Returns the number of cases used in the last {@link #crossTabulate()} table. */
    public int getTotalUsableCases() {
        return this.totalUsableCases;
    }

    /** Returns true if any case in the last {@link #classify()} call had a missing value in a clamped variable. */
    public boolean isMissingValueCaseFound() {
        return this.missingValueCaseFound;
    }

    public double getBinaryCutoff() {
        return this.binaryCutoff;
    }

    /**
     * Sets the cutoff used for binary targets.
     *
     * @param d the cutoff, in the closed interval [0, 1]
     * @throws IllegalArgumentException if {@code d} is outside [0, 1] or NaN
     */
    public void setBinaryCutoff(double d) {
        // Written as !(in range) so that NaN, which fails every comparison, is rejected too.
        if (!(d >= 0.0d && d <= 1.0d)) {
            throw new IllegalArgumentException("Binary cutoff must be in [0, 1]: " + d);
        }
        this.binaryCutoff = d;
        // Any cached accuracy was computed with the previous cutoff.
        this.percentCorrect = Double.NaN;
    }
}
