package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.ContinuousColumn;
import edu.cmu.tetrad.data.ContinuousDataSet;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.CovarianceMatrix;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.Variable;
import edu.cmu.tetrad.graph.GraphNode;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.graph.ProtoSemGraph;
import edu.cmu.tetrad.graph.SemGraph;
import edu.cmu.tetrad.ind.IndependenceTest;
import edu.cmu.tetrad.ind.Knowledge;
import edu.cmu.tetrad.sem.MimBuildEstimator;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemPm;
import edu.cmu.tetrad.util.MatrixUtils;
import edu.cmu.tetrad.util.ObjectPair;
import edu.cmu.tetrad.util.ProbUtils;
import edu.cmu.tetrad.util.RandomUtil;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

/**
 * Tests conditional independence among latent variables by fitting structural equation models
 * over their measured indicators, in the style of MIMBuild.
 *
 * <p>The latent variables and their indicators come from the {@link Knowledge} object given to
 * the constructor (see {@link #setMeasurementsSource(Knowledge)}): every measured variable is
 * assigned to one or more clusters, and cluster {@code i} is represented by a latent variable
 * named {@code MimBuild.LATENT_PREFIX + i}.
 *
 * <p>Two ways of deciding {@code X _||_ Y | Z} are implemented:
 *
 * <ul>
 *   <li>{@link #MIMBUILD_MLE} (the test type both constructors select): a model over the
 *       indicators of {@code X}, {@code Y} and {@code Z} is estimated by maximum likelihood without
 *       and with an {@code X -> Y} edge; the chi-square difference, on one degree of freedom,
 *       gives a p-value that is compared with the significance level.
 *   <li>{@link #MIMBUILD_BOOTSTRAP}: the partial correlation of {@code X} and {@code Y} given
 *       {@code Z} is evaluated on the latent covariance matrices produced by {@link #bootstrap()}
 *       (see {@link #isIndependentBootstrap(Object, Object, List)}).
 * </ul>
 *
 * <p>{@link #MIMBUILD_2SLS} is declared but not supported. This class has no setter for the test
 * type. Progress is reported on {@code System.out}.
 */
public class IndTestMimBuild implements IndependenceTest {
    static final long serialVersionUID = 23;

    /** Test type: compare maximum-likelihood fits with and without the tested edge. */
    public static final int MIMBUILD_MLE = 0;
    /** Test type: two-stage least squares (not currently supported). */
    public static final int MIMBUILD_2SLS = 1;
    /** Test type: bootstrap estimate of the latent partial correlation. */
    public static final int MIMBUILD_BOOTSTRAP = 2;
    /** Algorithm type code accepted by {@link #setAlgorithmType(int)}. */
    public static final int MIMBUILD_GES_ABIC = 0;
    /** Algorithm type code accepted by {@link #setAlgorithmType(int)}. */
    public static final int MIMBUILD_GES_SBIC = -1;
    /** Algorithm type code accepted by {@link #setAlgorithmType(int)}. */
    public static final int MIMBUILD_PC = 1;

    /** Raw data; {@code null} when this test was constructed from a covariance matrix. */
    private ContinuousDataSet data;
    /** Covariance matrix over the measured variables; submatrices of it are fitted. */
    private CovarianceMatrix covMatrix;
    /** Measured variables of the data set or covariance matrix. */
    private List<?> vars;
    /** Required edges "latent -> indicator" derived from the clustering knowledge. */
    private Knowledge measurements;
    /** Latent names; after {@link #bootstrap()} ordered like the bootstrap covariance matrices. */
    private List<String> latents;
    /** The most recently fitted graph (MLE test). */
    private ProtoSemGraph graph;
    private double sig = Double.NaN;
    /** Latent name -> names of its measured indicators. */
    private Hashtable<String, List<String>> measure_table;
    private int testType;
    private int algorithmType;
    private int numBootstrapSamples;
    /** Latent covariance matrices, one per bootstrap resample; {@code null} until bootstrapped. */
    private double[][][] bootstrapSamples;
    /** p-value of the most recent MLE test. */
    private double pValue;

    /**
     * Creates a test over raw continuous data.
     *
     * @param continuousDataSet the measured data; its covariance matrix is computed here
     * @param d significance level in [0, 1]
     * @param knowledge clustering of measured variables into latent constructs
     * @throws IllegalArgumentException if {@code d} is outside [0, 1] or NaN
     */
    public IndTestMimBuild(ContinuousDataSet continuousDataSet, double d, Knowledge knowledge) {
        setData(continuousDataSet);
        initialize(continuousDataSet.getVariables(), knowledge, d);
    }

    /**
     * Creates a test from a covariance matrix only; {@link #bootstrap()} is unavailable.
     *
     * @param covarianceMatrix covariance matrix over the measured variables
     * @param d significance level in [0, 1]
     * @param knowledge clustering of measured variables into latent constructs
     * @throws IllegalArgumentException if {@code d} is outside [0, 1] or NaN
     */
    public IndTestMimBuild(CovarianceMatrix covarianceMatrix, double d, Knowledge knowledge) {
        setCovMatrix(covarianceMatrix);
        initialize(covarianceMatrix.getVariables(), knowledge, d);
    }

    /** Shared constructor tail: measurement model, significance, and default settings. */
    private void initialize(List<?> variables, Knowledge knowledge, double significance) {
        this.vars = variables;
        this.latents = new ArrayList<String>();
        this.measure_table = new Hashtable<String, List<String>>();
        setMeasurementsSource(knowledge);
        setSignificance(significance);
        this.testType = MIMBUILD_MLE;
        this.algorithmType = MIMBUILD_GES_ABIC;
        this.numBootstrapSamples = 100;
    }

    /** Not supported by this test. */
    @Override // edu.cmu.tetrad.ind.IndependenceTest
    public IndependenceTest indTestSubset(List list) {
        throw new UnsupportedOperationException();
    }

    /**
     * Returns the names of all latent and measured variables in the measurement model, each once,
     * in order of first appearance.
     */
    public List<String> getAllVariablesStrings() {
        return collectVariableNames();
    }

    /**
     * Returns a new {@link ContinuousVariable} for every distinct latent and measured variable
     * name in the measurement model, in order of first appearance.
     */
    public List<ContinuousVariable> getVariableList() {
        List<ContinuousVariable> variables = new ArrayList<ContinuousVariable>();
        for (String name : collectVariableNames()) {
            variables.add(new ContinuousVariable(name));
        }
        return variables;
    }

    /** Collects the distinct endpoint names of the required measurement edges. */
    private List<String> collectVariableNames() {
        List<String> names = new ArrayList<String>();
        Iterator<?> edges = this.measurements.requiredEdgesIterator();
        while (edges.hasNext()) {
            ObjectPair edge = (ObjectPair) edges.next();
            String latent = (String) edge.getA();
            String indicator = (String) edge.getB();
            if (!names.contains(latent)) {
                names.add(latent);
            }
            // A measured variable may belong to several clusters; list it only once.
            if (!names.contains(indicator)) {
                names.add(indicator);
            }
        }
        return names;
    }

    /** Sets the raw data and recomputes the covariance matrix from it. */
    public void setData(ContinuousDataSet continuousDataSet) {
        this.data = continuousDataSet;
        this.covMatrix = new CovarianceMatrix(continuousDataSet);
    }

    /** Returns the raw data, or {@code null} if this test was built from a covariance matrix. */
    @Override // edu.cmu.tetrad.ind.IndependenceTest
    public DataSet getData() {
        return this.data;
    }

    public void setCovMatrix(CovarianceMatrix covarianceMatrix) {
        this.covMatrix = covarianceMatrix;
    }

    public CovarianceMatrix getCovMatrix() {
        return this.covMatrix;
    }

    /**
     * Sets how many resamples {@link #bootstrap()} draws.
     *
     * @throws IllegalArgumentException if {@code i < 1}
     */
    public void setNumBootstrapSamples(int i) {
        if (i < 1) {
            throw new IllegalArgumentException("Number of bootstrap samples must be positive: " + i);
        }
        this.numBootstrapSamples = i;
    }

    public int getNumBootstrapSamples() {
        return this.numBootstrapSamples;
    }

    /** Returns the knowledge holding the required "latent -> indicator" edges. */
    public Knowledge getMeasurements() {
        return this.measurements;
    }

    /**
     * Rebuilds the latent list and the latent-to-indicators table from the required edges of the
     * measurement knowledge. Latents are listed in order of first appearance.
     */
    public void initMeasurements() {
        this.latents.clear();
        this.measure_table.clear();
        Iterator<?> edges = this.measurements.requiredEdgesIterator();
        while (edges.hasNext()) {
            ObjectPair edge = (ObjectPair) edges.next();
            String latent = (String) edge.getA();
            String indicator = (String) edge.getB();
            List<String> indicators = this.measure_table.get(latent);
            if (indicators == null) {
                indicators = new ArrayList<String>();
                this.latents.add(latent);
                this.measure_table.put(latent, indicators);
            }
            indicators.add(indicator);
        }
    }

    /**
     * Derives the measurement model from a clustering.
     *
     * <p>{@code knowledge.getClusters()} maps each measured variable name either to one
     * {@link Integer} cluster index or to a {@link List} of them. For every index {@code i}, an
     * edge {@code MimBuild.LATENT_PREFIX + i -> variable} is marked as required.
     *
     * @param knowledge the clustering to read
     * @throws ClassCastException if a cluster assignment is not an {@code Integer}
     */
    public void setMeasurementsSource(Knowledge knowledge) {
        this.measurements = new Knowledge();
        for (Object key : knowledge.getClusters().keySet()) {
            String variableName = (String) key;
            Object assignment = knowledge.getClusters().get(variableName);
            List<?> clusterIndices;
            if (assignment instanceof Integer) {
                clusterIndices = Collections.singletonList(assignment);
            } else {
                clusterIndices = (List<?>) assignment;
            }
            for (Object index : clusterIndices) {
                int clusterIndex = ((Integer) index).intValue();
                String latentName = MimBuild.LATENT_PREFIX + clusterIndex;
                this.measurements.setEdgeRequired(latentName, variableName, true);
            }
        }
        initMeasurements();
    }

    /**
     * Sets the significance level used by the MLE test.
     *
     * @throws IllegalArgumentException if {@code d} is outside [0, 1] or NaN
     */
    public void setSignificance(double d) {
        // Written so that NaN fails the check as well.
        if (!(d >= 0.0d && d <= 1.0d)) {
            throw new IllegalArgumentException("Significance out of range.");
        }
        this.sig = d;
    }

    public double getSignificance() {
        return this.sig;
    }

    /**
     * Stores an algorithm type code; it is not otherwise used by this class.
     *
     * @throws IllegalArgumentException unless {@code i} is {@link #MIMBUILD_GES_SBIC},
     *     {@link #MIMBUILD_GES_ABIC} or {@link #MIMBUILD_PC}
     */
    public void setAlgorithmType(int i) {
        if (i != MIMBUILD_GES_SBIC && i != MIMBUILD_GES_ABIC && i != MIMBUILD_PC) {
            throw new IllegalArgumentException("Invalid algorithm test.");
        }
        this.algorithmType = i;
    }

    public int getAlgorithmType() {
        return this.algorithmType;
    }

    /** Returns an unmodifiable view of the measured variables. */
    @Override // edu.cmu.tetrad.ind.IndependenceTest
    public List getVariables() {
        return Collections.unmodifiableList(this.vars);
    }

    /**
     * Tests whether latent {@code variable} is independent of latent {@code variable2} given the
     * latents in {@code list}.
     *
     * <p>For the MLE test, the fitted graph has every conditioning latent as a parent of both
     * tested latents, and the conditioning latents are linked to each other in list order. Every
     * latent points to its measured indicators. The model is fitted on the corresponding
     * covariance submatrix with and without an edge between the tested latents. The result is
     * "independent" when the chi-square-difference p-value exceeds the significance level.
     *
     * @return {@code true} if judged independent
     * @throws IllegalArgumentException if a latent has no known indicators
     * @throws RuntimeException if the configured test type is {@link #MIMBUILD_2SLS}
     */
    @Override // edu.cmu.tetrad.ind.IndependenceTest
    public boolean isIndependent(Variable variable, Variable variable2, List list) {
        printTestHeader(variable, variable2, list);
        if (this.testType == MIMBUILD_BOOTSTRAP) {
            return isIndependentBootstrap(variable, variable2, list);
        }
        if (this.testType == MIMBUILD_2SLS) {
            throw new RuntimeException("Not currently supported!");
        }
        if (this.testType != MIMBUILD_MLE) {
            // Any other test type is reported as independent (kept for backward compatibility).
            return true;
        }

        this.graph = new ProtoSemGraph();
        GraphNode x = newLatentNode(variable.getName());
        GraphNode y = newLatentNode(variable2.getName());
        this.graph.addNode(x);
        this.graph.addNode(y);

        int numConditioning = list.size();
        String[] conditioningNames = new String[numConditioning];
        Node[] conditioningNodes = new Node[numConditioning];
        int k = 0;
        for (Object z : list) {
            conditioningNames[k] = z.toString();
            conditioningNodes[k] = newLatentNode(conditioningNames[k]);
            this.graph.addNode(conditioningNodes[k]);
            // Every conditioning latent is a parent of both tested latents.
            this.graph.addDirectedEdge(conditioningNodes[k], x);
            this.graph.addDirectedEdge(conditioningNodes[k], y);
            k++;
        }
        // Link the conditioning latents to each other, following the list order.
        for (int a = 0; a < numConditioning - 1; a++) {
            for (int b = a + 1; b < numConditioning; b++) {
                this.graph.addDirectedEdge(conditioningNodes[a], conditioningNodes[b]);
            }
        }

        List<String> indicators = new ArrayList<String>();
        indicators.addAll(addIndicators(this.graph, x, variable.toString()));
        indicators.addAll(addIndicators(this.graph, y, variable2.toString()));
        for (int a = 0; a < numConditioning; a++) {
            indicators.addAll(addIndicators(this.graph, conditioningNodes[a], conditioningNames[a]));
        }
        CovarianceMatrix submatrix =
                this.covMatrix.getSubmatrix(indicators.toArray(new String[indicators.size()]));

        System.out.println("\nEvaluating model without edge, MLE...");
        SemIm withoutEdge = estimateModel(submatrix);
        this.graph.addDirectedEdge(x, y);
        System.out.println("Evaluating model with edge, MLE...");
        SemIm withEdge = estimateModel(submatrix);

        // The model with the extra edge nests the one without it, so a negative difference
        // can only be optimizer noise; clamp it to zero.
        double chiSquareDifference =
                Math.max(0.0d, withoutEdge.getModelChiSquare() - withEdge.getModelChiSquare());
        this.pValue = 1.0d - ProbUtils.chisqCdf(chiSquareDifference, 1.0d);
        boolean independent = this.pValue > this.sig;
        System.out.println(independent ? "Independent!" : "NOT independent!");
        return independent;
    }

    /** Prints a banner describing the independence fact about to be tested. */
    private static void printTestHeader(Variable x, Variable y, List<?> conditioningSet) {
        System.out.println("\n\n************************************************");
        System.out.println(" Testing " + x + " against " + y);
        System.out.print(" Conditional on ");
        if (conditioningSet.isEmpty()) {
            System.out.print("empty set");
        } else {
            StringBuilder names = new StringBuilder();
            for (Object z : conditioningSet) {
                if (names.length() > 0) {
                    names.append(", ");
                }
                names.append(z);
            }
            System.out.print(names);
        }
        System.out.println();
        System.out.println("************************************************");
    }

    /** Creates a graph node of type {@link NodeType#LATENT}. */
    private static GraphNode newLatentNode(String name) {
        GraphNode node = new GraphNode(name);
        node.setNodeType(NodeType.LATENT);
        return node;
    }

    /** Returns the indicator names of a latent, failing clearly if the latent is unknown. */
    private List<String> indicatorsOf(String latentName) {
        List<String> indicators = this.measure_table.get(latentName);
        if (indicators == null) {
            throw new IllegalArgumentException(
                    "No measured indicators are known for latent variable '" + latentName + "'.");
        }
        return indicators;
    }

    /**
     * Adds a measured node and a {@code latent -> indicator} edge to {@code target} for every
     * indicator of {@code latentName}.
     *
     * @return the indicator names that were added, in table order
     */
    private List<String> addIndicators(ProtoSemGraph target, Node latent, String latentName) {
        List<String> indicators = indicatorsOf(latentName);
        for (String indicator : indicators) {
            GraphNode measured = new GraphNode(indicator);
            measured.setNodeType(NodeType.MEASURED);
            target.addNode(measured);
            target.addDirectedEdge(latent, measured);
        }
        return indicators;
    }

    /** Fits the current {@link #graph} to {@code submatrix} by MLE and reports the fit. */
    private SemIm estimateModel(CovarianceMatrix submatrix) {
        MimBuildEstimator estimator =
                MimBuildEstimator.newInstance(submatrix, new SemPm(new SemGraph(this.graph)));
        estimator.estimate();
        SemIm estimated = estimator.getEstimatedSem();
        System.out.println("Prob significance = " + estimated.getModelPValue());
        return estimated;
    }

    /** Returns the p-value of the most recent MLE test. */
    @Override // edu.cmu.tetrad.ind.IndependenceTest
    public double getDependencyStrength() {
        return this.pValue;
    }

    /** Not supported by this test. */
    @Override // edu.cmu.tetrad.ind.IndependenceTest
    public double getCutoff() {
        throw new UnsupportedOperationException();
    }

    /**
     * Bootstrap test of {@code obj _||_ obj2 | list} over latent variables.
     *
     * <p>For each covariance matrix stored by {@link #bootstrap()}, the partial correlation
     * {@code -P[0][1] / sqrt(P[0][0] * P[1][1])} is computed from the inverse {@code P} of the
     * covariance submatrix over the involved latents. The latents are identified by
     * {@code toString()}. Independence is accepted when the mean of these partial correlations is
     * within 1.96 bootstrap standard deviations of zero.
     *
     * @throws IllegalStateException if {@link #bootstrap()} has not been run
     * @throws IllegalArgumentException if a variable is not a known latent
     * @throws RuntimeException if a covariance submatrix cannot be inverted
     */
    public boolean isIndependentBootstrap(Object obj, Object obj2, List list) {
        if (this.bootstrapSamples == null) {
            throw new IllegalStateException("No bootstrap samples available; call bootstrap() first.");
        }
        int size = list.size() + 2;
        int[] indices = new int[size];
        indices[0] = latentIndex(obj);
        indices[1] = latentIndex(obj2);
        for (int i = 0; i < list.size(); i++) {
            indices[i + 2] = latentIndex(list.get(i));
        }

        // Use the number of samples actually stored; the setter may have changed since bootstrap().
        int numSamples = this.bootstrapSamples.length;
        double sum = 0.0d;
        double sumOfSquares = 0.0d;
        for (int sample = 0; sample < numSamples; sample++) {
            double[][] submatrix = new double[size][size];
            for (int r = 0; r < size; r++) {
                for (int c = 0; c < size; c++) {
                    submatrix[r][c] = this.bootstrapSamples[sample][indices[r]][indices[c]];
                }
            }
            double partialCorrelation;
            try {
                MatrixUtils.inverseGjNr(submatrix, true);
                partialCorrelation =
                        -submatrix[0][1] / Math.sqrt(submatrix[0][0] * submatrix[1][1]);
            } catch (Exception e) {
                throw new RuntimeException("Matrix singularity detected while using correlations \nto check for independence; probably due to collinearity \nin the data. The independence fact being checked was \n" + obj + " _||_ " + obj2 + " | " + list + ".", e);
            }
            sum += partialCorrelation;
            sumOfSquares += partialCorrelation * partialCorrelation;
        }
        double mean = sum / numSamples;
        // E[x^2] - E[x]^2 can dip slightly below zero through rounding.
        double variance = Math.max(0.0d, sumOfSquares / numSamples - mean * mean);
        System.out.println("Statistic: " + (mean / Math.sqrt(variance)));
        return isZeroBootstrap(mean, variance, this.sig);
    }

    /** Maps a latent (by {@code toString()}) to its index in {@link #latents}. */
    private int latentIndex(Object latent) {
        int index = this.latents.indexOf(latent.toString());
        if (index < 0) {
            throw new IllegalArgumentException("Unknown latent variable: " + latent);
        }
        return index;
    }

    /**
     * Decides whether a bootstrap mean is indistinguishable from zero.
     *
     * @param mean bootstrap mean of the statistic
     * @param variance bootstrap variance of the statistic
     * @param alpha significance level; currently ignored, the 1.96 cutoff is fixed
     * @return {@code true} if {@code |mean| / sqrt(variance) < 1.96}; with zero variance, if
     *     {@code mean == 0}
     */
    boolean isZeroBootstrap(double mean, double variance, double alpha) {
        double standardDeviation = Math.sqrt(variance);
        if (standardDeviation == 0.0d) {
            return mean == 0.0d;
        }
        return Math.abs(mean / standardDeviation) < 1.96d;
    }

    private double getRelativeStrength() {
        return 0.0d;
    }

    /** Runs {@link #isIndependent} for its side effects and returns 0. */
    @Override // edu.cmu.tetrad.ind.IndependenceTest
    public double getRelativeStrength(Variable variable, Variable variable2, List list) {
        isIndependent(variable, variable2, list);
        return getRelativeStrength();
    }

    /** Not implemented. */
    @Override // edu.cmu.tetrad.ind.IndependenceTest
    public double getPValue() {
        throw new UnsupportedOperationException("Method not implemented.");
    }

    /** Returns the names of {@link #getVariables()}, in order. */
    @Override // edu.cmu.tetrad.ind.IndependenceTest
    public List getVariableNames() {
        List<String> names = new ArrayList<String>();
        for (Object variable : getVariables()) {
            names.add(((Variable) variable).getName());
        }
        return names;
    }

    /** Returns the measured variable with the given name, or {@code null} if there is none. */
    @Override // edu.cmu.tetrad.ind.IndependenceTest
    public Variable getVariable(String str) {
        for (Object candidate : getVariables()) {
            Variable variable = (Variable) candidate;
            if (variable.getName().equals(str)) {
                return variable;
            }
        }
        return null;
    }

    /**
     * For every latent of {@code semPm}, fixes the parameter of the edge to its alphabetically
     * first measured child. Not called from within this class.
     */
    private void fixParameters(SemPm semPm) {
        SemGraph semGraph = semPm.getGraph();
        for (Object latentObject : semPm.getLatentNodes()) {
            Node latent = (Node) latentObject;
            List<Node> children = new ArrayList<Node>();
            for (Object child : semGraph.getChildren(latent)) {
                children.add((Node) child);
            }
            Collections.sort(children, new Comparator<Node>() {
                @Override // java.util.Comparator
                public int compare(Node first, Node second) {
                    return first.getName().compareTo(second.getName());
                }
            });
            for (Node child : children) {
                if (child.getNodeType() == NodeType.MEASURED) {
                    semPm.getParameter(latent, child).setFixed(true);
                    break;
                }
            }
        }
    }

    /**
     * Estimates {@link #getNumBootstrapSamples()} latent covariance matrices from resamples of
     * the data for later use by {@link #isIndependentBootstrap(Object, Object, List)}.
     *
     * @throws IllegalStateException if this test was built from a covariance matrix only
     */
    public void bootstrap() {
        this.bootstrapSamples = getBootstrapSamples(this.numBootstrapSamples);
    }

    /**
     * Draws {@code i} row resamples (with replacement) of the data, fits a model that connects
     * all latents to each other and to their indicators, and returns the model-implied
     * covariance matrix among the latents for each resample. Also reorders {@link #latents} to
     * match the row/column order of the returned matrices.
     */
    private double[][][] getBootstrapSamples(int i) {
        ContinuousDataSet source = (ContinuousDataSet) getData();
        if (source == null) {
            throw new IllegalStateException(
                    "Bootstrapping needs the raw data set; this test was built from a covariance matrix.");
        }
        int numLatents = this.latents.size();
        double[][][] samples = new double[i][numLatents][numLatents];

        SemPm semPm = new SemPm(new SemGraph(buildFullLatentGraph()));
        fixLatentOrder(semPm);
        int[] latentPositions = latentPositions(semPm);

        int numRows = source.getMaxRowCount();
        int numColumns = source.getNumColumns();
        double[][] sourceColumns = new double[numColumns][];
        for (int column = 0; column < numColumns; column++) {
            sourceColumns[column] = source.getDoubleData(column);
        }

        for (int sample = 0; sample < i; sample++) {
            double[][] resampled = new double[numColumns][numRows];
            for (int row = 0; row < numRows; row++) {
                int drawn = RandomUtil.nextInt(numRows);
                for (int column = 0; column < numColumns; column++) {
                    resampled[column][row] = sourceColumns[column][drawn];
                }
            }
            // Build the data set after filling the columns, so the estimator sees this
            // resample whether or not ContinuousColumn copies its array.
            ContinuousDataSet resampledData = buildDataSet(resampled, numRows);
            System.out.println("********\n Estimating latent covariance matrix #" + sample + "...");
            MimBuildEstimator estimator = MimBuildEstimator.newInstance(resampledData, semPm);
            estimator.estimate();
            double[][] implCovar = estimator.getEstimatedSem().getImplCovar();
            for (int r = 0; r < latentPositions.length; r++) {
                for (int c = 0; c < latentPositions.length; c++) {
                    samples[sample][r][c] = implCovar[latentPositions[r]][latentPositions[c]];
                }
            }
        }
        return samples;
    }

    /** Builds a graph linking every pair of latents (list order) plus all measurement edges. */
    private ProtoSemGraph buildFullLatentGraph() {
        ProtoSemGraph fullGraph = new ProtoSemGraph();
        int n = this.latents.size();
        Node[] latentNodes = new Node[n];
        for (int k = 0; k < n; k++) {
            latentNodes[k] = newLatentNode(this.latents.get(k));
            fullGraph.addNode(latentNodes[k]);
        }
        for (int a = 0; a < n - 1; a++) {
            for (int b = a + 1; b < n; b++) {
                fullGraph.addDirectedEdge(latentNodes[a], latentNodes[b]);
            }
        }
        for (int k = 0; k < n; k++) {
            addIndicators(fullGraph, latentNodes[k], this.latents.get(k));
        }
        return fullGraph;
    }

    /** Wraps column arrays, in {@link #getVariables()} order, as a continuous data set. */
    private ContinuousDataSet buildDataSet(double[][] columns, int numRows) {
        DataSet dataSet = new DataSet();
        int column = 0;
        for (Object variable : getVariables()) {
            dataSet.addColumn(new ContinuousColumn((ContinuousVariable) variable, columns[column], numRows));
            column++;
        }
        return new ContinuousDataSet(dataSet);
    }

    /** Positions, within {@code semPm.getVariableNodes()}, of the latent nodes. */
    private static int[] latentPositions(SemPm semPm) {
        List<Integer> positions = new ArrayList<Integer>();
        int position = 0;
        for (Object node : semPm.getVariableNodes()) {
            if (((Node) node).getNodeType() == NodeType.LATENT) {
                positions.add(position);
            }
            position++;
        }
        int[] result = new int[positions.size()];
        for (int k = 0; k < result.length; k++) {
            result[k] = positions.get(k);
        }
        return result;
    }

    /**
     * Reorders {@link #latents} to follow the order of latent nodes in
     * {@code semPm.getVariableNodes()}. Names not already in the list are skipped. The new order
     * is printed.
     */
    private void fixLatentOrder(SemPm semPm) {
        List<String> ordered = new ArrayList<String>(this.latents.size());
        for (Object nodeObject : semPm.getVariableNodes()) {
            Node node = (Node) nodeObject;
            if (node.getNodeType() == NodeType.LATENT && this.latents.contains(node.getName())) {
                ordered.add(node.getName());
            }
        }
        this.latents = ordered;
        System.out.println(this.latents.toString());
    }

    /** Not computed by this test; always returns 0. */
    @Override // edu.cmu.tetrad.ind.IndependenceTest
    public double getCorr(String str, String str2) {
        return 0.0d;
    }
}
