package edu.cmu.tetradapp.editor;

import edu.cmu.tetrad.data.ContinuousColumn;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteColumn;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.data.Variable;
import edu.cmu.tetrad.search.BayesUpdaterClassifier;
import edu.cmu.tetrad.util.RocCalculator;
import edu.cmu.tetradapp.model.BayesUpdaterClassifierWrapper;
import edu.cmu.tetradapp.util.DoubleTextField;
import edu.cmu.tetradapp.workbench.GraphWorkbench;
import java.awt.BorderLayout;
import java.awt.Color;
import java.awt.Component;
import java.awt.Dimension;
import java.awt.Font;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.ItemEvent;
import java.awt.event.ItemListener;
import java.text.DecimalFormat;
import javax.swing.Box;
import javax.swing.DefaultComboBoxModel;
import javax.swing.JButton;
import javax.swing.JComboBox;
import javax.swing.JComponent;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import javax.swing.JTabbedPane;
import javax.swing.border.EmptyBorder;

/* loaded from: input_file:edu/cmu/tetradapp/editor/BayesUpdaterClassifierEditor.class */
/**
 * Editor panel for a {@link BayesUpdaterClassifier}.
 *
 * <p>The toolbar lets the user pick a target variable, the category treated as the positive
 * class for the ROC plot, and (for two-category targets) the binary cutoff, and then run the
 * classifier. The tabbed display always shows the model graph and the test data; after a run it
 * also shows "Classification", "ROC Plot" and "Confusion Matrix" tabs, each of which is replaced
 * in place (same tab position) whenever it is recomputed.
 */
public class BayesUpdaterClassifierEditor extends JComponent {

    /** Title of the tab holding per-case classifications and marginal probabilities. */
    private static final String CLASSIFICATION_TAB = "Classification";

    /** Title of the tab holding the ROC plot for the selected category. */
    private static final String ROC_TAB = "ROC Plot";

    /** Title of the tab holding the observed-vs-estimated cross tabulation. */
    private static final String CONFUSION_MATRIX_TAB = "Confusion Matrix";

    /** The classifier being edited; never null. */
    private final BayesUpdaterClassifier classifier;

    /** Dropdown of the classifier's available target variables. */
    private JComboBox variableDropdown;

    /** Holds the Graph and Test Data tabs plus any result tabs. */
    private JTabbedPane tabbedPane;

    /** Categories of the selected target; the selection is the positive class for the ROC plot. */
    private JComboBox categoryDropdown;

    /** Current binary cutoff; pushed to the classifier by {@link #setBinaryCutoff(double)}. */
    private double binaryCutoff = 0.5d;

    /** Field editing the binary cutoff; enabled only when the target has exactly two categories. */
    private DoubleTextField binaryCutoffField;

    /**
     * Creates an editor for the given classifier.
     *
     * @param bayesUpdaterClassifier the classifier to display and run; must not be null
     * @throws NullPointerException if {@code bayesUpdaterClassifier} is null
     */
    public BayesUpdaterClassifierEditor(BayesUpdaterClassifier bayesUpdaterClassifier) {
        if (bayesUpdaterClassifier == null) {
            throw new NullPointerException("Classifier must not be null.");
        }
        this.classifier = bayesUpdaterClassifier;
        setLayout(new BorderLayout());
        add(getToolbar(), BorderLayout.NORTH);
        add(getDisplayPanel(), BorderLayout.CENTER);
    }

    /**
     * Creates an editor for the classifier held by the given wrapper.
     *
     * @param bayesUpdaterClassifierWrapper wrapper whose classifier is edited
     */
    public BayesUpdaterClassifierEditor(BayesUpdaterClassifierWrapper bayesUpdaterClassifierWrapper) {
        this(bayesUpdaterClassifierWrapper.getClassifier());
    }

    /** Builds the central panel: a tabbed pane with the "Graph" and "Test Data" tabs. */
    private Component getDisplayPanel() {
        JPanel panel = new JPanel();
        panel.setLayout(new BorderLayout());
        this.tabbedPane = new JTabbedPane();
        getTabbedPane().add("Graph", getGraphPanel());
        getTabbedPane().add("Test Data", getDataPanel());
        panel.add(getTabbedPane(), BorderLayout.CENTER);
        return panel;
    }

    /** Scrollable view of the classifier's test data. */
    private Component getDataPanel() {
        return new JScrollPane(new DataDisplay(getClassifier().getTestData()));
    }

    /** Scrollable workbench showing the graph of the classifier's Bayes IM. */
    private Component getGraphPanel() {
        return new JScrollPane(new GraphWorkbench(getClassifier().getBayesIm().getGraph()));
    }

    /**
     * Builds the toolbar: target dropdown, ROC-category dropdown and Classify button on the first
     * row, the binary-cutoff field on the second.
     */
    private Component getToolbar() {
        JButton classifyButton = new JButton("Classify");
        classifyButton.addActionListener(new ActionListener() {
            @Override
            public void actionPerformed(ActionEvent actionEvent) {
                // Nothing to classify when there is no target (e.g. no available targets) or
                // no category; do nothing rather than throw.
                if (getVariableDropdown().getSelectedItem() == null
                        || getCategoryDropdown().getSelectedItem() == null) {
                    return;
                }
                doClassify();
                showClassification();
                showRocCurve();
                showConfusionMatrix();
            }
        });

        this.variableDropdown = new JComboBox(
                (Variable[]) getClassifier().getAvailableTargets().toArray(new Variable[0]));
        getVariableDropdown().setBackground(Color.WHITE);
        getVariableDropdown().setMaximumSize(new Dimension(200, 50));

        String[] initialCategories = categoriesOf(getVariableDropdown().getSelectedItem());
        this.categoryDropdown = new JComboBox(initialCategories);
        getCategoryDropdown().setBackground(Color.WHITE);
        getCategoryDropdown().setMaximumSize(new Dimension(200, 50));

        this.variableDropdown.addItemListener(new ItemListener() {
            @Override
            public void itemStateChanged(ItemEvent itemEvent) {
                // A selection change fires DESELECTED (old item) and SELECTED (new item);
                // rebuild the category list only once.
                if (itemEvent.getStateChange() != ItemEvent.SELECTED) {
                    return;
                }
                String[] categories = categoriesOf(getVariableDropdown().getSelectedItem());
                getCategoryDropdown().setModel(new DefaultComboBoxModel(categories));
                updateBinaryCutoffField(categories.length);
            }
        });

        this.categoryDropdown.addItemListener(new ItemListener() {
            @Override
            public void itemStateChanged(ItemEvent itemEvent) {
                // Recompute the ROC plot once per selection change, not also on DESELECTED.
                if (itemEvent.getStateChange() == ItemEvent.SELECTED) {
                    showRocCurve();
                }
            }
        });

        this.binaryCutoffField = new DoubleTextField(getBinaryCutoff(), 5, new DecimalFormat("0.0###")) {
            /**
             * Accepts only cutoffs in [0, 1]; any other value is ignored and the field is reset
             * to the current cutoff.
             */
            @Override
            public void setValue(double d) {
                if (d >= 0.0d && d <= 1.0d) {
                    BayesUpdaterClassifierEditor.this.setBinaryCutoff(d);
                }
                super.setValue(BayesUpdaterClassifierEditor.this.getBinaryCutoff());
            }
        };
        updateBinaryCutoffField(initialCategories.length);

        Box toolbar = Box.createVerticalBox();

        Box targetRow = Box.createHorizontalBox();
        targetRow.add(Box.createHorizontalStrut(5));
        targetRow.add(new JLabel("Target = "));
        targetRow.add(getVariableDropdown());
        targetRow.add(Box.createHorizontalStrut(5));
        targetRow.add(new JLabel("Category for ROC ="));
        targetRow.add(getCategoryDropdown());
        targetRow.add(Box.createHorizontalStrut(10));
        targetRow.add(classifyButton);
        targetRow.add(Box.createHorizontalGlue());
        toolbar.add(targetRow);
        toolbar.add(Box.createVerticalStrut(5));

        Box cutoffRow = Box.createHorizontalBox();
        cutoffRow.add(Box.createHorizontalStrut(5));
        cutoffRow.add(new JLabel("(Cutoff for binary target = "));
        cutoffRow.add(getBinaryCutoffField());
        cutoffRow.add(new JLabel(" )"));
        cutoffRow.add(Box.createHorizontalGlue());
        toolbar.add(cutoffRow);

        toolbar.setBorder(new EmptyBorder(2, 2, 2, 2));
        return toolbar;
    }

    /**
     * Returns the categories of a dropdown selection, or an empty array when nothing is selected
     * or the selection is not a {@link DiscreteVariable}.
     */
    private static String[] categoriesOf(Object selection) {
        if (selection instanceof DiscreteVariable) {
            return ((DiscreteVariable) selection).getCategories();
        }
        return new String[0];
    }

    /** Makes the binary-cutoff field editable only when the target has exactly two categories. */
    private void updateBinaryCutoffField(int numCategories) {
        boolean binary = numCategories == 2;
        getBinaryCutoffField().setEnabled(binary);
        getBinaryCutoffField().setEditable(binary);
    }

    /**
     * Removes the first tab with the given title.
     *
     * @return the index the tab occupied, or -1 if there was no such tab
     */
    private int removeTab(String title) {
        int index = getTabbedPane().indexOfTab(title);
        if (index != -1) {
            getTabbedPane().remove(index);
        }
        return index;
    }

    /**
     * Inserts a tab at {@code index}, or appends it when {@code index} is negative (i.e. no
     * previous tab with that title existed).
     */
    private void putTab(String title, Component component, int index) {
        int position = index < 0 ? getTabbedPane().getTabCount() : index;
        getTabbedPane().insertTab(title, null, component, null, position);
    }

    /**
     * Points the classifier at the selected target variable and category, then runs it.
     * Originally private (widened by the decompiler); not intended for external callers.
     */
    public void doClassify() {
        DiscreteVariable target = (DiscreteVariable) getVariableDropdown().getSelectedItem();
        String category = (String) getCategoryDropdown().getSelectedItem();
        getClassifier().setTarget(target.getName(), target.getIndex(category));
        getClassifier().classify();
    }

    /**
     * Shows (or replaces) the "Classification" tab: a "Result" column with the classifier's
     * output for each case plus one column per target category index holding the marginals.
     * Originally private (widened by the decompiler).
     */
    public void showClassification() {
        int tabIndex = removeTab(CLASSIFICATION_TAB);

        int[] classifications = getClassifier().getClassifications();
        double[][] marginals = getClassifier().getMarginals();
        DiscreteVariable targetVariable = getClassifier().getTargetVariable();

        // The "Result" column reuses a copy of the target variable.
        DiscreteVariable resultVariable = new DiscreteVariable(targetVariable);
        resultVariable.setName("Result");
        DiscreteColumn resultColumn = new DiscreteColumn(resultVariable);
        // Loop-invariant; previously re-applied once per marginal column.
        resultVariable.setNewCategoriesAccomodated(true);

        // marginals is indexed [category][case]; one continuous column per category.
        ContinuousColumn[] marginalColumns = new ContinuousColumn[marginals.length];
        for (int category = 0; category < marginalColumns.length; category++) {
            ContinuousVariable variable =
                    new ContinuousVariable("P(" + targetVariable + "=" + category + ")");
            marginalColumns[category] = new ContinuousColumn(variable);
        }

        int numCases = getClassifier().getNumCases();
        for (int row = 0; row < numCases; row++) {
            resultColumn.add(Integer.valueOf(classifications[row]));
            for (int category = 0; category < marginalColumns.length; category++) {
                marginalColumns[category].add(Double.valueOf(marginals[category][row]));
            }
        }

        DataSet dataSet = new DataSet();
        dataSet.addColumn(resultColumn);
        for (ContinuousColumn column : marginalColumns) {
            dataSet.addColumn(column);
        }

        putTab(CLASSIFICATION_TAB, new JScrollPane(new DataDisplay(dataSet)), tabIndex);
    }

    /**
     * Shows (or replaces) the "ROC Plot" tab for the category selected in the category dropdown,
     * using the observed target column of the test data and the matching marginals.
     * Does nothing beyond removing the old tab when no classification result is available.
     * Originally private (widened by the decompiler).
     */
    public void showRocCurve() {
        int tabIndex = removeTab(ROC_TAB);

        // The category dropdown can trigger this before any classification has run.
        DiscreteVariable targetVariable = getClassifier().getTargetVariable();
        double[][] marginals = getClassifier().getMarginals();
        if (targetVariable == null || marginals == null) {
            return;
        }

        DiscreteColumn observedColumn =
                (DiscreteColumn) getClassifier().getTestData().get(targetVariable.getName());
        if (observedColumn == null) {
            return;
        }

        int[] observed = (int[]) observedColumn.getRawData();
        String category = (String) getCategoryDropdown().getSelectedItem();
        int categoryIndex = ((DiscreteVariable) observedColumn.getVariable()).getIndex(category);
        // After the target dropdown changes without re-classifying, the selected category may
        // not belong to the classified target. NOTE(review): assumes getIndex yields a negative
        // value for unknown categories; the upper bound also protects marginals[categoryIndex].
        if (categoryIndex < 0 || categoryIndex >= marginals.length) {
            return;
        }

        boolean[] inCategory = new boolean[getClassifier().getNumCases()];
        for (int row = 0; row < inCategory.length; row++) {
            inCategory[row] = observed[row] == categoryIndex;
        }

        RocCalculator calculator = RocCalculator.newCalculator(marginals[categoryIndex], inCategory, 0);
        RocPlot rocPlot = new RocPlot(calculator.getScaledRocPlot(),
                "ROC Plot, " + targetVariable + " = " + category,
                "AUC = " + new DecimalFormat("0.0000").format(calculator.getAuc()));
        putTab(ROC_TAB, rocPlot, tabIndex);
    }

    /**
     * Shows (or replaces) the "Confusion Matrix" tab: observed categories as rows, estimated
     * categories as columns, plus usable-case counts and the percentage correctly classified.
     * Does nothing beyond removing the old tab when the classifier returns no cross tabulation.
     * Originally private (widened by the decompiler).
     */
    public void showConfusionMatrix() {
        int tabIndex = removeTab(CONFUSION_MATRIX_TAB);

        int[][] counts = getClassifier().crossTabulate();
        if (counts == null) {
            return;
        }

        DiscreteVariable targetVariable = getClassifier().getTargetVariable();
        int numCategories = targetVariable.getNumCategories();
        int numCases = getClassifier().getNumCases();
        int totalUsableCases = getClassifier().getTotalUsableCases();

        StringBuilder html = new StringBuilder();
        html.append("<html><pre>");
        html.append("Total number of usable cases = ").append(totalUsableCases)
                .append(" out of ").append(numCases);
        html.append("<br><br>Target Variable ").append(targetVariable);
        html.append("<br>\t\t\tEstimated\t");
        html.append("<br>Observed\t");
        // Header: tab-separated category names.
        for (int col = 0; col < numCategories; col++) {
            if (col > 0) {
                html.append('\t');
            }
            html.append(targetVariable.getCategory(col));
        }
        // One row per observed category: its name, then tab-separated counts.
        for (int row = 0; row < numCategories; row++) {
            html.append("<br>").append(targetVariable.getCategory(row)).append("\t\t");
            for (int col = 0; col < numCategories; col++) {
                if (col > 0) {
                    html.append('\t');
                }
                html.append(counts[row][col]);
            }
        }
        html.append("<br><br>Percentage correctly classified:  ");
        html.append(getClassifier().getPercentCorrect());
        html.append("</pre></html>");

        JLabel label = new JLabel(html.toString());
        label.setFocusable(false);
        label.setFont(new Font("Serif", Font.PLAIN, 12));

        JPanel panel = new JPanel();
        panel.setLayout(new BorderLayout());
        panel.setBackground(Color.WHITE);

        Box column = Box.createVerticalBox();
        Box labelRow = Box.createHorizontalBox();
        labelRow.add(Box.createHorizontalStrut(5));
        labelRow.add(label);
        labelRow.add(Box.createHorizontalGlue());
        column.add(labelRow);
        // Two glues on purpose: BoxLayout shares extra height among them and labelRow, so this
        // keeps the label near the top. Do not collapse into one.
        column.add(Box.createVerticalGlue());
        column.add(Box.createVerticalGlue());
        panel.add(column, BorderLayout.CENTER);

        putTab(CONFUSION_MATRIX_TAB, new JScrollPane(panel), tabIndex);
    }

    /** The classifier being edited. */
    private BayesUpdaterClassifier getClassifier() {
        return this.classifier;
    }

    /** The target-variable dropdown. */
    private JComboBox getVariableDropdown() {
        return this.variableDropdown;
    }

    /** The tabbed pane holding all display tabs. */
    private JTabbedPane getTabbedPane() {
        return this.tabbedPane;
    }

    /** The ROC-category dropdown. Originally private (widened by the decompiler). */
    public JComboBox getCategoryDropdown() {
        return this.categoryDropdown;
    }

    /** The current binary cutoff. Originally private (widened by the decompiler). */
    public double getBinaryCutoff() {
        return this.binaryCutoff;
    }

    /**
     * Stores the binary cutoff and forwards it to the classifier. No range check is done here;
     * the cutoff field only forwards values in [0, 1].
     * Originally private (widened by the decompiler).
     */
    public void setBinaryCutoff(double d) {
        this.binaryCutoff = d;
        getClassifier().setBinaryCutoff(d);
    }

    /** The text field used to edit the binary cutoff. */
    public DoubleTextField getBinaryCutoffField() {
        return this.binaryCutoffField;
    }
}
