package edu.cmu.tetrad.search;

import be.ac.vub.ir.statistics.InformationWEntropy;
import edu.cmu.tetrad.data.Column;
import edu.cmu.tetrad.data.ContinuousColumn;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DataUtils;
import edu.cmu.tetrad.data.StatUtils;
import edu.cmu.tetrad.data.Variable;
import edu.cmu.tetrad.data.VariableRelations;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.info.EdgeInfo;
import edu.cmu.tetrad.ind.IndTestLog;
import edu.cmu.tetrad.ind.IndependenceTest;
import edu.cmu.tetrad.ind.Knowledge;
import edu.cmu.tetrad.ind.SearchLogUtils;
import edu.cmu.tetrad.util.ChoiceGenerator;
import edu.cmu.tetrad.util.LogUtils;
import java.io.Serializable;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Vector;
import java.util.logging.Logger;

/* loaded from: input_file:edu/cmu/tetrad/search/FastAdjacencySearch.class */
public class FastAdjacencySearch implements Serializable {
    /** Serialization version identifier for this class. */
    static final long serialVersionUID = 23;
    /** Logger for this class, obtained from {@link LogUtils}. */
    public static Logger LOGGER = LogUtils.getLogger(FastAdjacencySearch.class);
    /** Graph whose edges are removed in place as independencies are found. */
    protected Graph graph;
    /** Independence test supplied to the constructor. */
    protected IndependenceTest indTest;
    /** Background knowledge; has no initializer here, so it is null until {@link #setKnowledge} is called. */
    protected Knowledge knowledge;
    /** Run flag: set to true when a search starts, cleared by {@link #stopSearching()}, polled by the search loops. */
    private boolean mRun;
    /** Data set taken from the independence test; only assigned when the test is an {@link InformationWEntropy}. */
    private DataSet data;
    /** Separating sets recorded for every edge removed by the search. */
    protected SepsetMatrix sepset;
    /** First conditioning-set size tried by the search loops. */
    private int mStartDepth = 0;
    /** Last conditioning-set size tried; {@code Integer.MAX_VALUE} is reported as "Unlimited". */
    private int depth = Integer.MAX_VALUE;
    /** Log that receives one entry per independence found (see {@code condIndepFound}). */
    protected IndTestLog mLog = new IndTestLog();
    /** Pairwise deterministic relations; stays null unless {@code searchDeterministicRelations} runs with data. */
    private VariableRelations deterministicRelations = null;
    /** {@code DeterministicChoice} entries queued by {@code checkDeterministicRelations} and read by {@code chooseNaturalEdges}. */
    private List deterministicChoiceList = new Vector();

    /**
     * Creates a search that prunes the given graph using the given independence test.
     *
     * <p>The data set is captured only when the test is an {@link InformationWEntropy};
     * for every other test it stays {@code null}, which makes the deterministic-relation
     * steps of this class return immediately.
     *
     * @param graph            graph to prune; its nodes also size the sepset matrix
     * @param independenceTest test consulted for every candidate separation
     */
    public FastAdjacencySearch(Graph graph, IndependenceTest independenceTest) {
        this.graph = graph;
        this.indTest = independenceTest;
        this.data = (independenceTest instanceof InformationWEntropy) ? independenceTest.getData() : null;
        this.sepset = new SepsetMatrixImpl(graph.getNodes());
    }

    /**
     * Runs the full adjacency search over every node of the graph.
     *
     * <p>Steps, in order: look for deterministic relations, drop edges that the
     * knowledge forbids in both directions, run {@code adjStep} for each depth from
     * the start depth up to the configured depth (stopping early when a step reports
     * nothing left to do or when {@link #stopSearching()} was called), and finally
     * let {@code chooseNaturalEdges} adjust the sepsets.
     *
     * <p>The knowledge object is dereferenced as soon as the graph has an edge, so
     * {@link #setKnowledge} must have been called for any non-empty graph.
     *
     * @return the sepset matrix filled in by the search
     */
    public SepsetMatrix search() {
        // Arm the run flag once, before any work starts. Re-arming it after the
        // deterministic phase (as was done previously) silently discarded a stop
        // request issued while that phase was running.
        this.mRun = true;
        searchDeterministicRelations();

        // Walk a private snapshot of the edge list: edges are removed from the graph
        // inside the loop, and an index walk over a live view would skip entries.
        for (Iterator it = new LinkedList(this.graph.getEdges()).iterator(); it.hasNext();) {
            Edge edge = (Edge) it.next();
            String name = edge.getNode1().getName();
            String name2 = edge.getNode2().getName();
            if (this.knowledge.isEdgeForbidden(name, name2) && this.knowledge.isEdgeForbidden(name2, name)) {
                this.graph.removeEdge(edge);
            }
        }

        LOGGER.fine("Entering fast adjacency search method, Depth = " + (getDepth() == Integer.MAX_VALUE ? "Unlimited" : Integer.toString(getDepth())));
        for (int currentDepth = this.mStartDepth; currentDepth <= getDepth() && this.mRun; currentDepth++) {
            System.out.println("adjStep " + currentDepth);
            if (!adjStep(this.graph, this.indTest, getKnowledge(), this.sepset, currentDepth)) {
                break;
            }
        }
        chooseNaturalEdges(this.sepset);
        return this.sepset;
    }

    /**
     * Runs the adjacency search for the edges around a single variable.
     *
     * <p>Unlike the no-argument overload, this variant performs no deterministic
     * relation search, no forbidden-edge removal and no natural-edge selection; it
     * only iterates {@code adjStep} for the given variable over the depth range.
     *
     * @param variable the variable whose incident edges are tested
     * @return the sepset matrix filled in by the search
     */
    public SepsetMatrix search(Variable variable) {
        this.mRun = true;
        String depthLabel = getDepth() == Integer.MAX_VALUE ? "Unlimited" : Integer.toString(getDepth());
        LOGGER.fine("Entering fast adjacency search method, Depth = " + depthLabel);
        int currentDepth = this.mStartDepth;
        while (currentDepth <= getDepth() && this.mRun) {
            System.out.println("adjStep " + currentDepth);
            boolean moreToDo = adjStep(this.graph, variable, this.indTest, getKnowledge(), this.sepset, currentDepth, null);
            if (!moreToDo) {
                break;
            }
            currentDepth++;
        }
        return this.sepset;
    }

    /**
     * @return the largest conditioning-set size the search will try;
     *         {@code Integer.MAX_VALUE} stands for "unlimited"
     */
    public int getDepth() {
        return depth;
    }

    /**
     * Requests that a running search stop: clears the run flag that the depth loops
     * check before each step.
     *
     * <p>NOTE(review): the flag is a plain (non-volatile) field; if this is meant to
     * be called from another thread, visibility should be confirmed.
     */
    public void stopSearching() {
        mRun = false;
    }

    /**
     * Sets the largest conditioning-set size the search will try.
     *
     * @param i a non-negative depth, or {@code -1} for unlimited (stored as
     *          {@code Integer.MAX_VALUE})
     * @throws IllegalArgumentException if {@code i} is smaller than -1
     */
    public void setDepth(int i) {
        if (i < -1) {
            throw new IllegalArgumentException("Depth must be -1 (unlimited) or >= 0.");
        }
        boolean unlimited = (i == -1);
        LOGGER.info(unlimited ? "Setting depth of search to unlimited." : "Setting depth of search to " + i);
        this.depth = unlimited ? Integer.MAX_VALUE : i;
    }

    /**
     * @return the conditioning-set size at which the search loops begin
     */
    public int getStartDepth() {
        return mStartDepth;
    }

    /**
     * Sets the conditioning-set size at which the search loops begin.
     *
     * @param i start depth; must lie between 0 and the currently configured depth
     * @throws IllegalArgumentException if {@code i} is negative or exceeds the depth
     */
    public void setStartDepth(int i) {
        if (i < 0) {
            throw new IllegalArgumentException("Start depth must be >= 0.");
        } else if (i > depth) {
            throw new IllegalArgumentException("Start depth must be <= depth ");
        }
        mStartDepth = i;
    }

    /**
     * @return the log that collects one entry per independence found by this search
     */
    public IndTestLog log() {
        return mLog;
    }

    /**
     * Removes, in place, every node from {@code list} that the knowledge rules out
     * with respect to both named variables.
     *
     * <p>A node is dropped when, for {@code str} AND for {@code str2}, either the
     * edge node-&gt;variable is forbidden or the edge variable-&gt;node is required.
     *
     * @param list      candidate nodes (elements are cast to {@link Node}); modified in place
     * @param str       name of the first variable
     * @param str2      name of the second variable
     * @param knowledge background knowledge to consult
     */
    private static void removeForbidden(List list, String str, String str2, Knowledge knowledge) {
        for (Iterator it = list.iterator(); it.hasNext();) {
            String candidate = ((Node) it.next()).getName();
            boolean excludedForFirst = knowledge.isEdgeForbidden(candidate, str) || knowledge.isEdgeRequired(str, candidate);
            if (!excludedForFirst) {
                continue;
            }
            boolean excludedForSecond = knowledge.isEdgeForbidden(candidate, str2) || knowledge.isEdgeRequired(str2, candidate);
            if (excludedForSecond) {
                it.remove();
            }
        }
    }

    /**
     * Performs one depth step for every node of the graph.
     *
     * <p>Nodes are taken from a snapshot of {@code graph.getNodes()}. A shared
     * "already visited" list is handed to the per-node step so that each pair is
     * only examined from the node that comes first.
     *
     * @param graph            graph being pruned
     * @param independenceTest test used for the separations
     * @param knowledge        background knowledge
     * @param sepsetMatrix     receives the separating sets
     * @param i                conditioning-set size for this step
     * @return {@code true} if at least one per-node step reported that a deeper step may still be useful
     */
    private boolean adjStep(Graph graph, IndependenceTest independenceTest, Knowledge knowledge, SepsetMatrix sepsetMatrix, int i) {
        List visited = new LinkedList();
        boolean anyNodeHasMore = false;
        for (Object node : new LinkedList(graph.getNodes())) {
            // Non-short-circuit OR: the per-node step must run for every node.
            anyNodeHasMore |= adjStep(graph, (Variable) node, independenceTest, knowledge, sepsetMatrix, i, visited);
        }
        return anyNodeHasMore;
    }

    /**
     * Performs one depth step for the edges incident to {@code variable}.
     *
     * <p>For every neighbour not listed in {@code list}, conditioning sets of size
     * {@code i} are drawn first from the remaining neighbours of {@code variable}
     * and then (for {@code i > 0}) from the remaining neighbours of the other
     * endpoint. The first set that makes the pair independent is passed to
     * {@link #condIndepFound}.
     *
     * <p>The second pass is now skipped once the pair is no longer adjacent.
     * Previously it kept testing the already separated pair and could report the
     * same separation again, repeating the log entry and the edge removal and
     * overwriting the sepset recorded by the first pass.
     *
     * @param graph            graph being pruned
     * @param variable         node whose incident edges are examined
     * @param independenceTest test used for the separations
     * @param knowledge        background knowledge
     * @param sepsetMatrix     receives the separating sets
     * @param i                conditioning-set size for this step
     * @param list             nodes already handled in this step (skipped as neighbours);
     *                         {@code variable} is appended to it; may be {@code null}
     * @return {@code true} if {@code variable} still has more than {@code i + 1} adjacent nodes
     */
    private boolean adjStep(Graph graph, Variable variable, IndependenceTest independenceTest, Knowledge knowledge, SepsetMatrix sepsetMatrix, int i, List list) {
        LinkedList<Variable> neighbours = new LinkedList(graph.getAdjacentNodes(variable));
        if (list != null) {
            neighbours.removeAll(list);
            list.add(variable);
        }
        for (Variable other : neighbours) {
            Edge edge = graph.getEdge(variable, other);
            // Skip edges whose bookkeeping says they were already tested at this size.
            if (edge.getObject() instanceof EdgeInfo && ((EdgeInfo) edge.getObject()).mNbrConditionedOn >= i) {
                continue;
            }

            LinkedList adjacentToVariable = new LinkedList(graph.getAdjacentNodes(variable));
            LinkedList adjacentToOther = new LinkedList(graph.getAdjacentNodes(other));
            adjacentToVariable.remove(other);
            adjacentToOther.remove(variable);
            removeForbidden(adjacentToVariable, variable.getName(), other.getName(), knowledge);
            removeForbidden(adjacentToOther, variable.getName(), other.getName(), knowledge);
            boolean isNoEdgeRequired = knowledge.isNoEdgeRequired(variable.getName(), other.getName());

            separateWithSubsets(graph, independenceTest, sepsetMatrix, variable, other, adjacentToVariable, i, isNoEdgeRequired);
            // Size 0 has only the empty set, which the first pass already covered.
            // Only continue while the pair is still adjacent.
            if (i > 0 && graph.getAdjacentNodes(variable).contains(other)) {
                separateWithSubsets(graph, independenceTest, sepsetMatrix, variable, other, adjacentToOther, i, isNoEdgeRequired);
            }

            if (edge.getObject() instanceof EdgeInfo) {
                ((EdgeInfo) edge.getObject()).mNbrConditionedOn = i;
            }
        }
        return graph.getAdjacentNodes(variable).size() - 1 > i;
    }

    /**
     * Tries every subset of the given size from {@code candidates} as a conditioning
     * set for the pair and reports the first one that separates it; does nothing
     * when there are fewer candidates than {@code size}.
     */
    private void separateWithSubsets(Graph graph, IndependenceTest independenceTest, SepsetMatrix sepsetMatrix, Variable variable, Variable other, List candidates, int size, boolean isNoEdgeRequired) {
        if (candidates.size() < size) {
            return;
        }
        ChoiceGenerator generator = new ChoiceGenerator(candidates.size(), size);
        for (int[] choice = generator.next(); choice != null; choice = generator.next()) {
            List conditioningSet = asList(choice, candidates);
            // Evaluation order matters: the deterministic check records choices as a
            // side effect and must run even when the knowledge vetoes removal.
            if (checkDeterministicRelations(variable, other, conditioningSet) && isNoEdgeRequired && independenceTest.isIndependent(variable, other, conditioningSet)) {
                condIndepFound(graph, sepsetMatrix, variable, other, conditioningSet, this.indTest.getDependencyStrength(), this.indTest.getCutoff(), this.indTest.getCorr(variable.getName(), other.getName()));
                return;
            }
        }
    }

    /**
     * Records a discovered conditional independence: logs it, removes the edge
     * between the two variables and stores the conditioning set as their sepset.
     *
     * @param graph        graph from which the edge is removed
     * @param sepsetMatrix matrix that receives the separating set
     * @param variable     first variable of the pair
     * @param variable2    second variable of the pair
     * @param list         conditioning set that separated the pair
     * @param d            dependency strength reported by the test
     * @param d2           cutoff reported by the test
     * @param d3           correlation of the pair reported by the test
     */
    protected void condIndepFound(Graph graph, SepsetMatrix sepsetMatrix, Variable variable, Variable variable2, List list, double d, double d2, double d3) {
        // The internal log takes the values in the order (correlation, strength, cutoff).
        mLog.logDSeparation(variable, variable2, list, d3, d, d2);
        SearchLogUtils.logDSeparation(variable, variable2, list, LOGGER);

        graph.removeEdge(variable, variable2);

        HashSet separatingSet = new HashSet(list);
        sepsetMatrix.setSepset(variable, variable2, separatingSet);
    }

    /**
     * Picks the elements of {@code list} at the given positions.
     *
     * @param iArr indices into {@code list}, in the order the elements should appear
     * @param list source list
     * @return a new list holding {@code list.get(iArr[k])} for each {@code k}
     */
    private static List asList(int[] iArr, List list) {
        List selected = new LinkedList();
        for (int k = 0; k < iArr.length; k++) {
            Object element = list.get(iArr[k]);
            selected.add(element);
        }
        return selected;
    }

    /**
     * @return the background knowledge used by the search, or {@code null} if none has been set
     */
    public Knowledge getKnowledge() {
        return knowledge;
    }

    /**
     * Sets the background knowledge consulted during the search.
     *
     * @param knowledge the knowledge to use
     * @throws NullPointerException if {@code knowledge} is {@code null}
     */
    public void setKnowledge(Knowledge knowledge) {
        if (knowledge != null) {
            this.knowledge = knowledge;
            return;
        }
        throw new NullPointerException("Cannot set knowledge to null");
    }

    /**
     * Scans every pair of variables in the data set for deterministic relations
     * and stores them in {@code deterministicRelations}.
     *
     * <p>Does nothing when no data set is available. For each unordered pair both
     * directions are tested; each direction that is deterministic is registered as
     * an ordered two-element relation. The scan stops between outer iterations once
     * {@link #stopSearching()} has been called.
     *
     * <p>The completion message is printed exactly once, after the scan. Previously
     * it sat inside the outer loop and was printed once per column, and the loop
     * kept iterating over the remaining columns after a stop request.
     */
    protected void searchDeterministicRelations() {
        if (this.data == null) {
            return;
        }
        System.out.println("Starting search for deterministic relations...");
        this.mRun = true;
        InformationWEntropy informationWEntropy = new InformationWEntropy(this.data);
        this.deterministicRelations = new VariableRelations(this.data.getVariables());
        int numColumns = this.data.getNumColumns();
        for (int first = 0; first < numColumns && this.mRun; first++) {
            Variable var = this.data.getVar(first);
            for (int second = first + 1; second < numColumns; second++) {
                Variable var2 = this.data.getVar(second);
                boolean forward = informationWEntropy.isDeterministic(var, var2);
                boolean backward = informationWEntropy.isDeterministic(var2, var);
                if (forward) {
                    Vector pair = new Vector();
                    pair.add(var);
                    pair.add(var2);
                    this.deterministicRelations.addRelations(pair);
                }
                if (backward) {
                    Vector reversePair = new Vector();
                    reversePair.add(var2);
                    reversePair.add(var);
                    this.deterministicRelations.addRelations(reversePair);
                }
            }
        }
        if (this.mRun) {
            System.out.println("Search for deterministic relations done.");
        } else {
            System.out.println("Search for deterministic relations aborted.");
        }
    }

    /**
     * Decides whether an independence test on the pair given the conditioning set
     * may be trusted in the presence of known deterministic relations.
     *
     * <p>Returns {@code true} when no relations have been computed. Returns
     * {@code false} when the two variables are related to each other in both
     * directions, or when any conditioning variable appears in the relations of
     * either endpoint. In the latter case a {@code DeterministicChoice} is queued
     * for each hit.
     *
     * <p>The list handed to each {@code DeterministicChoice} is a private copy of
     * the relation list with the endpoint appended. Previously the endpoint was
     * appended to the list returned by {@code getRelations(..)} itself, which, if
     * that is the stored list, grows the recorded relations on every call.
     *
     * @param variable  first variable of the pair
     * @param variable2 second variable of the pair
     * @param list      candidate conditioning set
     * @return {@code true} if the test may proceed, {@code false} if a deterministic relation interferes
     */
    protected boolean checkDeterministicRelations(Variable variable, Variable variable2, List list) {
        if (this.deterministicRelations == null) {
            return true;
        }
        if (this.deterministicRelations.getRelations(variable).contains(variable2) && this.deterministicRelations.getRelations(variable2).contains(variable)) {
            return false;
        }
        boolean unaffected = true;
        for (Object conditioned : list) {
            if (this.deterministicRelations.getRelations(variable).contains(conditioned)) {
                // Copy before appending so the stored relations are left untouched.
                List related = new LinkedList(this.deterministicRelations.getRelations(variable));
                related.add(variable);
                this.deterministicChoiceList.add(new DeterministicChoice(variable2, related));
                unaffected = false;
            }
            if (this.deterministicRelations.getRelations(variable2).contains(conditioned)) {
                List related2 = new LinkedList(this.deterministicRelations.getRelations(variable2));
                related2.add(variable2);
                this.deterministicChoiceList.add(new DeterministicChoice(variable, related2));
                unaffected = false;
            }
        }
        return unaffected;
    }

    /**
     * Revisits every queued deterministic choice and, where exactly one of the two
     * deterministically related variables is currently separated from X, decides
     * which of the two should carry the separating set.
     *
     * <p>Does nothing when no data set is available. If any of the three columns
     * involved is not a {@link ContinuousColumn}, mutual information decides;
     * otherwise polynomial fits of degree 3 to 5 are compared. Only the sepset
     * matrix is modified here; the graph itself is not touched.
     *
     * <p>Changes from the previous version:
     * <ul>
     *   <li>the continuous/discrete decision is made per choice; previously the flag
     *       was never reset, so one discrete column forced the entropy path for all
     *       later choices;</li>
     *   <li>choices with fewer than two related variables are skipped instead of
     *       failing on {@code get(1)};</li>
     *   <li>columns are converted to arrays only on the regression path, and an
     *       unused column conversion on the entropy path was removed.</li>
     * </ul>
     *
     * @param sepsetMatrix sepset matrix to adjust in place
     */
    protected void chooseNaturalEdges(SepsetMatrix sepsetMatrix) {
        if (this.data == null) {
            return;
        }
        System.out.println("Choosing natural edges");
        InformationWEntropy informationWEntropy = new InformationWEntropy(this.data);
        for (int i = 0; i < this.deterministicChoiceList.size(); i++) {
            DeterministicChoice deterministicChoice = (DeterministicChoice) this.deterministicChoiceList.get(i);
            int relatedCount = deterministicChoice.deterministicVars.size();
            if (relatedCount > 2) {
                System.out.println("Deterministic relation with more than 2 vars. Cannot handle this.");
                continue;
            }
            if (relatedCount < 2) {
                // Nothing to choose between.
                continue;
            }
            Variable x = deterministicChoice.X;
            Variable y = (Variable) deterministicChoice.deterministicVars.get(0);
            Variable z = (Variable) deterministicChoice.deterministicVars.get(1);

            // Act only when exactly one of the pairs (X,Y), (X,Z) has a sepset.
            boolean separatedFromY = sepsetMatrix.getSepset(x, y) != null;
            boolean separatedFromZ = sepsetMatrix.getSepset(x, z) != null;
            if (separatedFromY == separatedFromZ) {
                continue;
            }

            Column columnX = this.data.getColumn(x.getName());
            Column columnY = this.data.getColumn(y.getName());
            Column columnZ = this.data.getColumn(z.getName());
            boolean anyDiscrete = !(columnX instanceof ContinuousColumn)
                    || !(columnY instanceof ContinuousColumn)
                    || !(columnZ instanceof ContinuousColumn);

            if (anyDiscrete) {
                System.out.println("Performing entropytest");
                System.out.println("Calculating mutualInfoYX, vars: " + y.getName() + ", " + x.getName());
                float mutualInfoYX = informationWEntropy.mutualInfo(y, x);
                System.out.println("Calculating mutualInfoZX, vars: " + z.getName() + ", " + x.getName());
                float mutualInfoZX = informationWEntropy.mutualInfo(z, x);
                System.out.println("Done computing mutual info");
                if (mutualInfoYX < mutualInfoZX) {
                    transferSepset(sepsetMatrix, x, y, z);
                } else {
                    transferSepset(sepsetMatrix, x, z, y);
                }
            } else {
                System.out.println("Performing regression");
                int maxRowCount = this.data.getMaxRowCount();
                double[] valuesX = copyRows(columnX, maxRowCount);
                double[] valuesY = copyRows(columnY, maxRowCount);
                double[] valuesZ = copyRows(columnZ, maxRowCount);
                double fitY = 0.0d;
                double fitZ = 0.0d;
                int degree = 2;
                // Raise the polynomial degree (3..5) until one of the fits reaches 0.94.
                while (fitY < 0.94d && fitZ < 0.94d && degree < 5) {
                    degree++;
                    fitY = StatUtils.rFitting(valuesX, valuesY, degree);
                    fitZ = StatUtils.rFitting(valuesX, valuesZ, degree);
                }
                if (fitY > fitZ) {
                    transferSepset(sepsetMatrix, x, z, y);
                } else {
                    transferSepset(sepsetMatrix, x, y, z);
                }
            }
        }
        System.out.println("Choose natural edge done");
    }

    /**
     * If (x, target) currently has no sepset, moves the sepset recorded for
     * (x, source) onto (x, target) and clears the one for (x, source).
     */
    private static void transferSepset(SepsetMatrix sepsetMatrix, Variable x, Variable source, Variable target) {
        if (sepsetMatrix.getSepset(x, target) == null) {
            sepsetMatrix.setSepset(x, target, sepsetMatrix.getSepset(x, source));
            sepsetMatrix.setSepset(x, source, null);
        }
    }

    /** Copies the first {@code rows} raw values of a continuous column into a new array. */
    private static double[] copyRows(Column column, int rows) {
        double[] copy = new double[rows];
        System.arraycopy((double[]) column.getRawData(), 0, copy, 0, rows);
        return copy;
    }

    /**
     * Widens an {@code int} array to a {@code double} array of the same length.
     *
     * @param iArr values to convert
     * @return a new array whose element {@code k} equals {@code iArr[k]}
     */
    private double[] convertIntToDoubleArray(int[] iArr) {
        final int length = iArr.length;
        double[] widened = new double[length];
        int k = 0;
        while (k < length) {
            widened[k] = (double) iArr[k];
            k++;
        }
        return widened;
    }

    /**
     * @return the deterministic relations found so far, or {@code null} if the
     *         deterministic-relation search has not run with a data set
     */
    public VariableRelations getDeterministicRelations() {
        return deterministicRelations;
    }
}
