package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.ContinuousDataSet;
import edu.cmu.tetrad.data.CovarianceMatrix;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.EndpointMatrixGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphNode;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.graph.ProtoSemGraph;
import edu.cmu.tetrad.graph.SemGraph;
import edu.cmu.tetrad.ind.Knowledge;
import edu.cmu.tetrad.sem.MimBuildEstimator;
import edu.cmu.tetrad.sem.ParamType;
import edu.cmu.tetrad.sem.Parameter;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemPm;
import edu.cmu.tetrad.util.MatrixUtils;
import edu.cmu.tetrad.util.ObjectPair;
import java.beans.PropertyVetoException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;

/* loaded from: input_file:edu/cmu/tetrad/search/MimBuildScoreSearch.class */
public class MimBuildScoreSearch extends PcStub {
    // Serialization version identifier.
    static final long serialVersionUID = 23;
    // Selector values for the optimizer field: 0 selects the EM-based path, 1 the gradient-based path
    // (search(), ges() and scoreGraph() compare the optimizer field against these literal values).
    public final int OPTIMIZER_EM = 0;
    public final int OPTIMIZER_GRADIENT = 1;
    // Latent nodes of the graph currently being searched; filled by startMeasurementModel() and ges().
    protected List latents;
    // Covariance matrix obtained from the independence test in search().
    private CovarianceMatrix covMatrix;
    // Which optimizer is used for scoring (see the selector values above).
    private int optimizer;
    // Flag exposed through setAdjustedScore()/getAdjustedScore().
    private boolean adjustedScore;
    // NOTE(review): the working fields from here down to latentImpliedCovar/deltaB* are not referenced
    // in the part of the file reviewed; presumably they hold state for the EM maximization step - confirm.
    Node[] indicatorsArray;
    Node[] latentsArray;
    int numIndicators;
    int numLatents;
    int indicatorErrorsIndex;
    int latentErrorsIndex;
    double[] theta;
    double[][] bigLambda;
    double[][] beta;
    double[][] fi;
    double[][] iBeta;
    double[][] iMinusB;
    double[][] iMinusBT;
    double[][] J;
    List latentParents;
    int[] indicatorParents;
    int[] lambdaIndex;
    int[] betaIndex;
    double[][] latentImpliedCovar;
    double[][] deltaB0;
    double[][] deltaB1;
    double[][] deltaB2;
    // Covariance of the observed variables, re-indexed by observableNames in structuralEmInitialization().
    double[][] Cyy;
    // NOTE(review): presumably the observed/latent and latent/latent expected moments - confirm.
    double[][] Cyz;
    double[][] Czz;
    // Number of non-latent and latent nodes counted by structuralEmInitialization().
    int numObserved;
    int numLatent;
    // Node name -> Integer index, for non-latent and latent nodes respectively.
    Hashtable observableNames;
    Hashtable latentNames;
    // Measured node name -> Boolean; true for the first measured child (by name) of a latent.
    Hashtable fixedMeasures;

    /**
     * Creates a score-based MimBuild search that starts with the EM optimizer selected and
     * the adjusted-score flag switched on.
     *
     * @param indTestMimBuild the MimBuild independence test handed to the superclass; it is later
     *                        queried for variable names, the covariance matrix and the measurement model
     * @param knowledge       background knowledge handed to the superclass
     */
    public MimBuildScoreSearch(IndTestMimBuild indTestMimBuild, Knowledge knowledge) {
        super(indTestMimBuild, knowledge);
        // OPTIMIZER_EM and OPTIMIZER_GRADIENT are final fields initialised at their declarations;
        // assigning them again here is rejected by the compiler, so they are left alone.
        this.latents = new ArrayList();
        this.optimizer = this.OPTIMIZER_EM;
        this.adjustedScore = true;
    }

    /**
     * Builds an initial graph over all variables known to the independence test, installs the
     * measurement model, prepares the EM bookkeeping when the EM optimizer is selected, and then
     * runs the greedy search.
     *
     * @return the graph produced by {@code ges}, or {@code null} when a
     *         {@code PropertyVetoException} is raised during the search
     * @throws NullPointerException if no independence test is available
     */
    @Override // edu.cmu.tetrad.search.PcStub, edu.cmu.tetrad.search.SearchAlgorithm
    public Graph search() {
        if (getIndependenceTest() == null) {
            throw new NullPointerException();
        }

        // One fresh graph node per variable name reported by the test.
        List variableNames = ((IndTestMimBuild) getIndependenceTest()).getAllVariablesStrings();
        LinkedList nodes = new LinkedList();
        for (Object variableName : variableNames) {
            nodes.add(new GraphNode((String) variableName));
        }
        EndpointMatrixGraph startGraph = new EndpointMatrixGraph(nodes);

        this.covMatrix = ((IndTestMimBuild) getIndependenceTest()).getCovMatrix();
        startMeasurementModel(startGraph);
        if (this.optimizer == 0) {
            structuralEmInitialization(startGraph);
        }

        Graph result = null;
        try {
            result = ges(startGraph);
        } catch (PropertyVetoException e) {
            System.out.println("Fatal error during MimBuildScoreSearch." + e.toString());
        }
        return result;
    }

    /**
     * Installs the measurement model in {@code graph}: for every required edge (A, B) reported by
     * the independence test's measurements, orients the edge as A --> B and marks A as a latent
     * node, recording it in {@link #latents} if it is not there yet.
     *
     * <p>The graph is printed to standard output afterwards. The method returns normally; it must
     * not terminate the JVM, because both {@code search()} and {@code ges()} continue to use the
     * graph after this call.
     *
     * @param graph the graph to receive the measurement edges; it must contain a node for every
     *              name mentioned in the required edges
     */
    protected void startMeasurementModel(Graph graph) {
        Iterator requiredEdges = ((IndTestMimBuild) getIndependenceTest()).getMeasurements().requiredEdgesIterator();
        while (requiredEdges.hasNext()) {
            ObjectPair pair = (ObjectPair) requiredEdges.next();
            String latentName = (String) pair.getA();
            String indicatorName = (String) pair.getB();
            Node latent = graph.getNode(latentName);
            Node indicator = graph.getNode(indicatorName);

            // Arrow head at the indicator, plain segment at the latent: latent --> indicator.
            graph.setEndpoint(latent, indicator, Endpoint.ARROW);
            graph.setEndpoint(indicator, latent, Endpoint.SEGMENT);

            if (!this.latents.contains(latent)) {
                this.latents.add(latent);
                latent.setNodeType(NodeType.LATENT);
            }
        }
        System.out.println(graph.toString());
    }

    /**
     * Stores the adjusted-score flag.
     *
     * @param z the new value of the flag
     */
    public void setAdjustedScore(boolean z) {
        adjustedScore = z;
    }

    /**
     * @return the current value of the adjusted-score flag
     */
    public boolean getAdjustedScore() {
        return adjustedScore;
    }

    /**
     * Runs the greedy search (a forward phase {@code fes} followed by a backward phase
     * {@code bes}) starting from {@code graph} and returns the result as a new
     * {@code EndpointMatrixGraph} that also contains the measurement model.
     *
     * <p>With optimizer 1 a single forward/backward pass is made. With optimizer 0 the pass is
     * repeated, each time preceded by an {@code expectation} step, for as long as the score of the
     * latest pass improves on the best score so far.
     *
     * @param graph the starting graph; it is copied, not modified directly
     * @return a new graph holding the measurement model plus the edges found by the search
     * @throws PropertyVetoException declared for callers; {@code search()} catches it
     */
    protected Graph ges(Graph graph) throws PropertyVetoException {
        EdgeListGraph edgeListGraph;
        System.out.println("******************************************");
        System.out.println("* MIM BUILD SCORE (GES ALGORITHM)");
        System.out.println();
        EdgeListGraph edgeListGraph2 = new EdgeListGraph(graph);
        // Re-derive the latent list from the copied graph: a node is latent when its name carries
        // the MimBuild latent prefix.
        this.latents.clear();
        for (Node node : edgeListGraph2.getNodes()) {
            if (node.getName().startsWith(MimBuild.LATENT_PREFIX)) {
                node.setNodeType(NodeType.LATENT);
                this.latents.add(node);
            }
        }
        // Best score so far; starts at -Double.MAX_VALUE so that any finite score improves on it.
        double d = -1.7976931348623157E308d;
        if (this.optimizer == 1) {
            bes(edgeListGraph2, fes(edgeListGraph2));
        } else if (this.optimizer == 0) {
            // edgeListGraph3 is the best graph found so far (null until a round improves the score).
            EdgeListGraph edgeListGraph3 = null;
            expectation(edgeListGraph2);
            do {
                // Every round restarts from a copy of the initial graph, not from the previous best.
                edgeListGraph = new EdgeListGraph(edgeListGraph2);
                double bes = bes(edgeListGraph, fes(edgeListGraph));
                if (bes > d) {
                    pdagToDag(edgeListGraph);
                    edgeListGraph3 = edgeListGraph;
                    d = bes;
                    System.out.println("Starting next round... (Current score = " + d + ")");
                    expectation(edgeListGraph);
                }
                // The loop continues only while the round just finished became the new best graph.
            } while (edgeListGraph3 == edgeListGraph);
            // NOTE(review): if the very first round does not beat the initial bound (e.g. a NaN
            // score), edgeListGraph3 is still null here and dagToPdag will fail - confirm intent.
            edgeListGraph2 = edgeListGraph3;
            dagToPdag(edgeListGraph2);
        }
        // Build the result graph from scratch: one node per variable name, latents tagged by prefix.
        List allVariablesStrings = ((IndTestMimBuild) getIndependenceTest()).getAllVariablesStrings();
        LinkedList linkedList = new LinkedList();
        for (int i = 0; i < allVariablesStrings.size(); i++) {
            String str = (String) allVariablesStrings.get(i);
            GraphNode graphNode = new GraphNode(str);
            if (str.startsWith(MimBuild.LATENT_PREFIX)) {
                graphNode.setNodeType(NodeType.LATENT);
            }
            linkedList.add(graphNode);
        }
        EndpointMatrixGraph endpointMatrixGraph = new EndpointMatrixGraph(linkedList);
        startMeasurementModel(endpointMatrixGraph);
        // Copy every edge of the searched graph into the result, endpoint by endpoint, looking the
        // nodes up by name because the result graph has its own node objects.
        for (Edge edge : edgeListGraph2.getEdges()) {
            System.out.println(edge.toString());
            String name = edge.getNode1().getName();
            String name2 = edge.getNode2().getName();
            Node node2 = endpointMatrixGraph.getNode(name);
            Node node3 = endpointMatrixGraph.getNode(name2);
            // Endpoint at the first node of the edge.
            if (edge.getEndpoint1() == Endpoint.ARROW) {
                endpointMatrixGraph.setEndpoint(node3, node2, Endpoint.ARROW);
            } else {
                endpointMatrixGraph.setEndpoint(node3, node2, Endpoint.SEGMENT);
            }
            // Endpoint at the second node of the edge.
            if (edge.getEndpoint2() == Endpoint.ARROW) {
                endpointMatrixGraph.setEndpoint(node2, node3, Endpoint.ARROW);
            } else {
                endpointMatrixGraph.setEndpoint(node2, node3, Endpoint.SEGMENT);
            }
        }
        return endpointMatrixGraph;
    }

    /**
     * Scores a candidate graph. The graph is copied and the copy is turned into a DAG before a
     * SEM is fitted to it, so the argument itself is left untouched.
     *
     * @param graph the candidate (possibly partially directed) graph
     * @return the score of the fitted model as computed by {@code scoreSemIm}
     * @throws RuntimeException if the optimizer field is neither 0 nor 1
     */
    private double scoreGraph(Graph graph) {
        EdgeListGraph dag = new EdgeListGraph(graph);
        pdagToDag(dag);

        switch (this.optimizer) {
            case 1: {
                // Gradient path: estimate the SEM directly from the continuous data set.
                ContinuousDataSet data = (ContinuousDataSet) ((IndTestMimBuild) getIndependenceTest()).getData();
                SemPm semPm = new SemPm(new SemGraph(new ProtoSemGraph(dag)));
                MimBuildEstimator estimator = MimBuildEstimator.newInstance(data, semPm);
                estimator.estimate();
                return scoreSemIm(estimator.getEstimatedSem());
            }
            case 0:
                // EM path: the maximization step supplies the fitted model.
                return scoreSemIm(maximization(dag));
            default:
                throw new RuntimeException("No valid optimizer chosen for MimBuildScoreSearch!");
        }
    }

    /**
     * Computes the penalised score of a fitted SEM:
     * {@code -0.5 * N * (logDetSigma + traceSSigmaInv) - 0.5 * k * ln(N)}, where {@code N} is the
     * sample size and {@code k} the number of free parameters of the model.
     *
     * @param semIm the fitted model
     * @return the score; larger values are treated as better by the search
     */
    private double scoreSemIm(SemIm semIm) {
        // Result unused; the call is kept because the original makes it before scoring.
        semIm.getImplCovar();
        double fitTerm = ((-0.5d) * semIm.getSampleSize()) * (logDetSigma(semIm) + traceSSigmaInv(semIm));
        double penaltyTerm = (0.5d * semIm.getNumFreeParams()) * Math.log(semIm.getSampleSize());
        return fitTerm - penaltyTerm;
    }

    /**
     * For every latent node of the model, fixes the parameter of the edge to its first measured
     * child, where "first" means first in alphabetical order of the child names. Latents without a
     * measured child are left unchanged.
     *
     * @param semPm the parametric model whose parameters are fixed in place
     */
    private void fixParameters(SemPm semPm) {
        SemGraph semGraph = semPm.getGraph();
        Comparator byName = new Comparator() {
            @Override // java.util.Comparator
            public int compare(Object left, Object right) {
                return ((Node) left).getName().compareTo(((Node) right).getName());
            }
        };

        for (Object latentObj : semPm.getLatentNodes()) {
            Node latent = (Node) latentObj;
            // Sort a copy so the list owned by the graph is not reordered.
            ArrayList children = new ArrayList(semGraph.getChildren(latent));
            Collections.sort(children, byName);
            for (Object childObj : children) {
                Node child = (Node) childObj;
                if (child.getNodeType() == NodeType.MEASURED) {
                    semPm.getParameter(latent, child).setFixed(true);
                    break;
                }
            }
        }
    }

    /**
     * Forward phase of the greedy search. Repeatedly evaluates every valid insertion of an edge
     * between two non-adjacent latents (together with every subset of the T-neighbours), applies
     * the single best-scoring insertion if it improves on the current score, rebuilds the pattern,
     * and stops when no insertion improves the score. Progress is traced on standard output.
     *
     * @param graph the graph to grow; modified in place
     * @return the best score reached
     */
    private double fes(Graph graph) {
        System.out.println("** FORWARD EQUIVALENCE SEARCH");
        System.out.println("###################################################");
        double bestScore = scoreGraph(graph);
        System.out.println("Initial Score = " + bestScore);
        System.out.println("###################################################");

        // Subset that accompanies the best insertion found in the current sweep.
        HashSet bestSubset = new HashSet();
        boolean improved;
        do {
            System.out.println("###################################################");
            Node bestFrom = null;
            Node bestTo = null;

            for (Object fromObj : this.latents) {
                Node from = (Node) fromObj;
                for (Object toObj : this.latents) {
                    Node to = (Node) toObj;
                    if (from == to || graph.isAdjacentTo(from, to)) {
                        continue;
                    }
                    for (Object subsetObj : powerSet(getTNeighbors(from, to, graph))) {
                        Set subset = (Set) subsetObj;
                        if (!validInsert(from, to, subset, graph)) {
                            continue;
                        }
                        System.out.println("Search operator: add " + from.getName() + " -> " + to.getName() + "(" + subset.toString() + ")");
                        double candidateScore = insertEval(from, to, subset, graph);
                        System.out.println(" - Score = " + candidateScore);
                        System.out.println();
                        if (candidateScore > bestScore) {
                            bestScore = candidateScore;
                            bestFrom = from;
                            bestTo = to;
                            bestSubset.clear();
                            bestSubset.addAll(subset);
                        }
                    }
                }
            }

            improved = bestFrom != null;
            if (improved) {
                System.out.println();
                System.out.println("--- Adding edge " + bestFrom.getName() + " -> " + bestTo.getName());
                insert(bestFrom, bestTo, bestSubset, graph);
                rebuildPattern(graph);
            }
            System.out.println("###################################################");
            System.out.println();
        } while (improved);

        return bestScore;
    }

    /**
     * Collects the nodes that are connected to {@code node2} by an edge that is directed in
     * neither direction and that are not adjacent to {@code node}.
     *
     * @param node  the prospective tail of the inserted edge
     * @param node2 the prospective head of the inserted edge
     * @param graph the graph to inspect
     * @return a new set with the matching nodes
     */
    private Set getTNeighbors(Node node, Node node2, Graph graph) {
        HashSet result = new HashSet();
        for (Object candidateObj : graph.getNodes()) {
            Node candidate = (Node) candidateObj;
            if (candidate == node2 || candidate == node) {
                continue;
            }
            if (!graph.isAdjacentTo(node2, candidate)) {
                continue;
            }
            if (graph.isParentOf(node2, candidate) || graph.isParentOf(candidate, node2)) {
                continue;
            }
            if (graph.isAdjacentTo(node, candidate)) {
                continue;
            }
            result.add(candidate);
        }
        return result;
    }

    /**
     * Scores the graph that would result from inserting {@code node --> node2} and directing
     * every member of {@code set} into {@code node2}, then restores the graph: the members of
     * {@code set} are reconnected to {@code node2} with undirected edges and the inserted edge is
     * removed.
     *
     * @param node  tail of the trial edge
     * @param node2 head of the trial edge
     * @param set   nodes whose edge to {@code node2} is temporarily directed into {@code node2}
     * @param graph the graph, modified temporarily
     * @return the score of the trial graph
     */
    private double insertEval(Node node, Node node2, Set set, Graph graph) {
        // Apply the trial change.
        graph.addDirectedEdge(node, node2);
        for (Object memberObj : set) {
            Node member = (Node) memberObj;
            graph.removeEdges(member, node2);
            graph.addDirectedEdge(member, node2);
        }

        double trialScore = scoreGraph(graph);

        // Undo it.
        for (Object memberObj : set) {
            Node member = (Node) memberObj;
            graph.removeEdges(member, node2);
            graph.addUndirectedEdge(member, node2);
        }
        graph.removeEdges(node, node2);
        return trialScore;
    }

    /**
     * Applies an insertion permanently: adds {@code node --> node2} and replaces the edge between
     * each member of {@code set} and {@code node2} by a directed edge into {@code node2}.
     *
     * @param node  tail of the new edge
     * @param node2 head of the new edge
     * @param set   nodes to be directed into {@code node2}
     * @param graph the graph to modify
     */
    private void insert(Node node, Node node2, Set set, Graph graph) {
        graph.addDirectedEdge(node, node2);
        for (Object memberObj : set) {
            Node member = (Node) memberObj;
            graph.removeEdges(member, node2);
            graph.addDirectedEdge(member, node2);
        }
    }

    /**
     * Checks whether an insertion is admissible: the union of {@code set} with the nodes returned
     * by {@code findNaYX} must be a clique, and {@code semiDirectedBlocked} must hold for that
     * union.
     *
     * @param node  tail of the prospective edge
     * @param node2 head of the prospective edge
     * @param set   the subset accompanying the insertion; not modified
     * @param graph the graph to inspect
     * @return {@code true} if both conditions hold
     */
    private boolean validInsert(Node node, Node node2, Set set, Graph graph) {
        HashSet blockers = new HashSet(set);
        blockers.addAll(findNaYX(node, node2, graph));
        if (!isClique(blockers, graph)) {
            return false;
        }
        return semiDirectedBlocked(node, node2, blockers, graph, new HashSet());
    }

    /**
     * Depth-first check that every path leaving {@code node2} along edges that do not point into
     * the current node either meets a member of {@code set} or never reaches {@code node}.
     *
     * @param node  the node that must not be reachable
     * @param node2 the node the walk is currently at
     * @param set   nodes that block a path
     * @param graph the graph to walk
     * @param set2  nodes on the current path; updated during the recursion
     * @return {@code false} as soon as an unblocked path to {@code node} is found, otherwise
     *         {@code true}
     */
    private boolean semiDirectedBlocked(Node node, Node node2, Set set, Graph graph, Set set2) {
        if (set.contains(node2)) {
            return true;
        }
        if (node2 == node) {
            return false;
        }
        for (Object nextObj : graph.getNodes()) {
            Node next = (Node) nextObj;
            boolean traversable = next != node2
                    && !set2.contains(next)
                    && graph.isAdjacentTo(node2, next)
                    && !graph.isParentOf(next, node2);
            if (!traversable) {
                continue;
            }
            set2.add(next);
            if (!semiDirectedBlocked(node, next, set, graph, set2)) {
                // The node is deliberately left in set2 on this early exit.
                return false;
            }
            set2.remove(next);
        }
        return true;
    }

    /**
     * Backward phase of the greedy search. Repeatedly evaluates every valid deletion of an edge
     * between two adjacent latents (together with every subset of the H-neighbours), applies the
     * single best-scoring deletion if it improves on the current score, rebuilds the pattern, and
     * stops when no deletion improves the score. Progress is traced on standard output.
     *
     * @param graph the graph to prune; modified in place
     * @param d     the score of {@code graph} on entry
     * @return the best score reached
     */
    private double bes(Graph graph, double d) {
        System.out.println("** BACKWARD ELIMINATION SEARCH");
        double bestScore = d;

        // Subset that accompanies the best deletion found in the current sweep.
        HashSet bestSubset = new HashSet();
        boolean improved;
        do {
            Node bestFirst = null;
            Node bestSecond = null;

            for (Object firstObj : this.latents) {
                Node first = (Node) firstObj;
                for (Object secondObj : this.latents) {
                    Node second = (Node) secondObj;
                    if (first == second || !graph.isAdjacentTo(first, second)) {
                        continue;
                    }
                    for (Object subsetObj : powerSet(getHNeighbors(first, second, graph))) {
                        Set subset = (Set) subsetObj;
                        if (!validDelete(first, second, subset, graph)) {
                            continue;
                        }
                        System.out.println("Search operator: delete " + first.getName() + " - " + second.getName() + "(" + subset.toString() + ")");
                        double candidateScore = deleteEval(first, second, subset, graph);
                        System.out.println(" - Score = " + candidateScore);
                        if (candidateScore > bestScore) {
                            bestScore = candidateScore;
                            bestFirst = first;
                            bestSecond = second;
                            bestSubset.clear();
                            bestSubset.addAll(subset);
                        }
                    }
                }
            }

            improved = bestFirst != null;
            if (improved) {
                System.out.println("Deleting edge " + bestFirst.getName() + " - " + bestSecond.getName());
                delete(bestFirst, bestSecond, bestSubset, graph);
                rebuildPattern(graph);
            }
        } while (improved);

        return bestScore;
    }

    /**
     * Collects the nodes that are connected to {@code node2} by an edge that is directed in
     * neither direction and that are also adjacent to {@code node}.
     *
     * @param node  one end of the edge considered for deletion
     * @param node2 the other end of the edge considered for deletion
     * @param graph the graph to inspect
     * @return a new set with the matching nodes
     */
    private Set getHNeighbors(Node node, Node node2, Graph graph) {
        HashSet result = new HashSet();
        for (Object candidateObj : graph.getNodes()) {
            Node candidate = (Node) candidateObj;
            if (candidate == node2 || candidate == node) {
                continue;
            }
            if (!graph.isAdjacentTo(node2, candidate)) {
                continue;
            }
            if (graph.isParentOf(node2, candidate) || graph.isParentOf(candidate, node2)) {
                continue;
            }
            if (!graph.isAdjacentTo(node, candidate)) {
                continue;
            }
            result.add(candidate);
        }
        return result;
    }

    /**
     * Scores the graph that would result from deleting the edge between {@code node} and
     * {@code node2} and directing both nodes into every member of {@code set}, then restores the
     * graph to its previous state.
     *
     * @param node  one end of the edge to delete
     * @param node2 the other end of the edge to delete
     * @param set   nodes that temporarily receive directed edges from {@code node} and {@code node2}
     * @param graph the graph, modified temporarily
     * @return the score of the trial graph
     */
    private double deleteEval(Node node, Node node2, Set set, Graph graph) {
        // Remember how the edge was oriented so it can be put back afterwards.
        final boolean forward = graph.isParentOf(node, node2);
        final boolean directed = forward || graph.isParentOf(node2, node);
        final Node tail = forward ? node : node2;
        final Node head = forward ? node2 : node;

        graph.removeEdges(node, node2);

        // Members whose edge to `node` was undirected and has been redirected for the trial.
        HashSet redirected = new HashSet();
        for (Object memberObj : set) {
            Node member = (Node) memberObj;
            if (!graph.isParentOf(member, node) && !graph.isParentOf(node, member)) {
                graph.removeEdges(node, member);
                graph.addDirectedEdge(node, member);
                redirected.add(member);
            }
            graph.removeEdges(node2, member);
            graph.addDirectedEdge(node2, member);
        }

        double trialScore = scoreGraph(graph);

        // Undo the redirections.
        for (Object memberObj : set) {
            Node member = (Node) memberObj;
            if (redirected.contains(member)) {
                graph.removeEdges(node, member);
                graph.addUndirectedEdge(node, member);
            }
            graph.removeEdges(node2, member);
            graph.addUndirectedEdge(node2, member);
        }

        // Put the deleted edge back with its original orientation.
        graph.removeEdges(node, node2);
        if (directed) {
            graph.addDirectedEdge(tail, head);
        } else {
            graph.addUndirectedEdge(node, node2);
        }
        return trialScore;
    }

    /**
     * Applies a deletion permanently: removes the edge between {@code node} and {@code node2} and
     * directs both nodes into every member of {@code set}. An edge between {@code node} and a
     * member is only replaced when it is not already directed in either direction.
     *
     * @param node  one end of the deleted edge
     * @param node2 the other end of the deleted edge
     * @param set   nodes that receive directed edges from {@code node} and {@code node2}
     * @param graph the graph to modify
     */
    private void delete(Node node, Node node2, Set set, Graph graph) {
        graph.removeEdges(node, node2);
        for (Object memberObj : set) {
            Node member = (Node) memberObj;
            boolean alreadyDirected = graph.isParentOf(member, node) || graph.isParentOf(node, member);
            if (!alreadyDirected) {
                graph.removeEdges(node, member);
                graph.addDirectedEdge(node, member);
            }
            graph.removeEdges(node2, member);
            graph.addDirectedEdge(node2, member);
        }
    }

    /**
     * Checks whether a deletion is admissible: the nodes returned by {@code findNaYX}, minus the
     * members of {@code set}, must form a clique.
     *
     * @param node  one end of the edge considered for deletion
     * @param node2 the other end of the edge considered for deletion
     * @param set   the subset accompanying the deletion; not modified
     * @param graph the graph to inspect
     * @return {@code true} if the remaining nodes are pairwise adjacent
     */
    private boolean validDelete(Node node, Node node2, Set set, Graph graph) {
        Set remaining = findNaYX(node, node2, graph);
        remaining.removeAll(set);
        boolean clique = isClique(remaining, graph);
        return clique;
    }

    /**
     * Looks a node up by name with a linear scan over the graph's nodes.
     *
     * @param graph the graph to search
     * @param str   the node name to match exactly
     * @return the first node with that name, or {@code null} if there is none
     */
    private Node getNode(Graph graph, String str) {
        Node match = null;
        Iterator nodes = graph.getNodes().iterator();
        while (match == null && nodes.hasNext()) {
            Node candidate = (Node) nodes.next();
            if (candidate.getName().equals(str)) {
                match = candidate;
            }
        }
        return match;
    }

    /**
     * Enumerates all subsets of {@code set}. Subset number {@code i} contains the element at
     * position {@code p} (in the iteration order of {@code set}) exactly when bit {@code p} of
     * {@code i} is set, so the empty set comes first.
     *
     * @param set the set to enumerate; not modified
     * @return a new list of new {@code HashSet}s, one per subset
     */
    private List powerSet(Set set) {
        ArrayList elements = new ArrayList(set);
        ArrayList subsets = new ArrayList();
        int subsetCount = (int) Math.pow(2.0d, set.size());
        for (int mask = 0; mask < subsetCount; mask++) {
            HashSet subset = new HashSet();
            // Walk the set bits of the mask from the least significant one upwards.
            int position = 0;
            for (int bits = mask; bits != 0; bits >>>= 1) {
                if ((bits & 1) != 0) {
                    subset.add(elements.get(position));
                }
                position++;
            }
            subsets.add(subset);
        }
        return subsets;
    }

    /**
     * Tests whether every pair of distinct nodes in {@code set} is adjacent in {@code graph}.
     * Sets with fewer than two members are cliques trivially.
     *
     * @param set   the nodes to test; not modified
     * @param graph the graph providing adjacency
     * @return {@code true} if all pairs are adjacent
     */
    private boolean isClique(Set set, Graph graph) {
        Object[] members = set.toArray();
        for (int first = 0; first < members.length - 1; first++) {
            Node a = (Node) members[first];
            for (int second = first + 1; second < members.length; second++) {
                Node b = (Node) members[second];
                if (!graph.isAdjacentTo(a, b)) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Collects the nodes that are connected to {@code node2} by an edge that is directed in
     * neither direction and that are also adjacent to {@code node}.
     *
     * @param node  the first node of the pair
     * @param node2 the second node of the pair
     * @param graph the graph to inspect
     * @return a new, modifiable set with the matching nodes
     */
    private Set findNaYX(Node node, Node node2, Graph graph) {
        HashSet result = new HashSet();
        for (Object candidateObj : graph.getNodes()) {
            Node candidate = (Node) candidateObj;
            boolean isEndpoint = candidate == node2 || candidate == node;
            if (isEndpoint || !graph.isAdjacentTo(node2, candidate)) {
                continue;
            }
            boolean directed = graph.isParentOf(node2, candidate) || graph.isParentOf(candidate, node2);
            if (!directed && graph.isAdjacentTo(node, candidate)) {
                result.add(candidate);
            }
        }
        return result;
    }

    /**
     * Normalises {@code graph} after a search move by first extending it to a DAG
     * ({@code pdagToDag}) and then converting that DAG back to a partially directed graph
     * ({@code dagToPdag}). The graph is modified in place.
     *
     * @param graph the graph to normalise
     */
    private void rebuildPattern(Graph graph) {
        pdagToDag(graph);
        dagToPdag(graph);
    }

    /**
     * Extends the latent part of a partially directed graph to a DAG, in place.
     *
     * <p>All undirected edges are first removed from {@code graph}. A working copy that still has
     * them is then reduced one latent node per pass: a latent is eligible when it has no latent
     * child in the working copy and, if it has undirected neighbours, those neighbours together
     * with its parents form a clique. For an eligible latent every undirected neighbour is
     * directed into it in {@code graph} and the latent is removed from the working copy. If a pass
     * finds no eligible latent, the last latent examined is dropped from further consideration so
     * that the procedure always terminates.
     *
     * @param graph the graph to orient; its undirected edges are replaced by directed ones where
     *              the procedure succeeds
     */
    private void pdagToDag(Graph graph) {
        // Working copy taken before the undirected edges are stripped from the target graph.
        EdgeListGraph workGraph = new EdgeListGraph(graph);

        HashSet undirectedEdges = new HashSet();
        for (Object edgeObj : graph.getEdges()) {
            Edge edge = (Edge) edgeObj;
            if (edge.getEndpoint1() == Endpoint.SEGMENT && edge.getEndpoint2() == Endpoint.SEGMENT) {
                undirectedEdges.add(edge);
            }
        }
        graph.removeEdges(new ArrayList(undirectedEdges));

        ArrayList remaining = new ArrayList();
        for (Object nodeObj : workGraph.getNodes()) {
            Node node = (Node) nodeObj;
            if (node.getNodeType() == NodeType.LATENT) {
                remaining.add(node);
            }
        }

        while (!remaining.isEmpty()) {
            // Tracks the node examined last; it is removed after the scan, never during it,
            // because the list must not change while it is being iterated.
            Node selected = null;
            for (Object candidateObj : remaining) {
                Node candidate = (Node) candidateObj;
                selected = candidate;

                // Eligibility 1: no latent child left in the working copy.
                boolean hasLatentChild = false;
                for (Object childObj : workGraph.getChildren(candidate)) {
                    if (((Node) childObj).getNodeType() == NodeType.LATENT) {
                        hasLatentChild = true;
                        break;
                    }
                }
                if (hasLatentChild) {
                    continue;
                }

                // Nodes joined to the candidate by an undirected edge in the working copy.
                HashSet neighbors = new HashSet();
                for (Object edgeObj : workGraph.getEdges()) {
                    Edge edge = (Edge) edgeObj;
                    boolean touchesCandidate = edge.getNode1() == candidate || edge.getNode2() == candidate;
                    boolean undirected = edge.getEndpoint1() == Endpoint.SEGMENT
                            && edge.getEndpoint2() == Endpoint.SEGMENT;
                    if (touchesCandidate && undirected) {
                        neighbors.add(edge.getNode1() == candidate ? edge.getNode2() : edge.getNode1());
                    }
                }

                // Eligibility 2: neighbours plus parents must be pairwise adjacent.
                if (neighbors.size() > 0) {
                    HashSet neighborsAndParents = new HashSet(neighbors);
                    neighborsAndParents.addAll(workGraph.getParents(candidate));
                    if (!isClique(neighborsAndParents, workGraph)) {
                        continue;
                    }
                }

                // Orient neighbour --> candidate in the target graph; nodes are matched by name
                // because the working copy may hold different node objects.
                for (Object neighborObj : neighbors) {
                    Node neighbor = (Node) neighborObj;
                    graph.addDirectedEdge(getNode(graph, neighbor.getName()), getNode(graph, candidate.getName()));
                }
                workGraph.removeNode(candidate);
                break;
            }
            remaining.remove(selected);
        }
    }

    /**
     * Converts the latent part of a DAG into a partially directed graph, in place: directed edges
     * between latents that the labelling below marks as reversible get a SEGMENT endpoint at both
     * ends; all other edges are left as they are. The graph is printed to standard output first.
     *
     * <p>The steps are: (1) build a helper graph containing only the latents and the directed
     * edges between them, (2) order the latents topologically, (3) order the edges, (4) label each
     * edge as direction-keeping or reversible, (5) make the reversible edges undirected.
     * NOTE(review): the edge ordering and labelling appear to follow Chickering's
     * order-edges / label-edges procedure - confirm against the reference.
     *
     * @param graph a DAG over latents (plus other nodes, which are ignored); modified in place
     */
    private void dagToPdag(Graph graph) {
        System.out.println(graph.toString());
        // Step 1: helper graph with a fresh node per latent of `graph`.
        EdgeListGraph edgeListGraph = new EdgeListGraph();
        Iterator it = graph.getNodes().iterator();
        Node[] nodeArr = new Node[graph.getNumNodes()];
        for (int i = 0; i < graph.getNumNodes(); i++) {
            nodeArr[i] = (Node) it.next();
            if (nodeArr[i].getNodeType() == NodeType.LATENT) {
                edgeListGraph.addNode(new GraphNode(nodeArr[i].getName()));
            }
        }
        // Copy each latent-latent edge into the helper graph. The edge's first node is used as the
        // tail - presumably directed edges store the parent as node1; verify against Edge.
        for (int i2 = 0; i2 < graph.getNumNodes(); i2++) {
            Node node = nodeArr[i2];
            if (node.getNodeType() == NodeType.LATENT) {
                for (int i3 = i2 + 1; i3 < graph.getNumNodes(); i3++) {
                    Node node2 = nodeArr[i3];
                    if (node2.getNodeType() == NodeType.LATENT && graph.isAdjacentTo(node, node2)) {
                        if (graph.getEdge(node, node2).getNode1() == node) {
                            edgeListGraph.addDirectedEdge(getNode(edgeListGraph, node.getName()), getNode(edgeListGraph, node2.getName()));
                        } else {
                            edgeListGraph.addDirectedEdge(getNode(edgeListGraph, node2.getName()), getNode(edgeListGraph, node.getName()));
                        }
                    }
                }
            }
        }
        // Step 2: topological order of the latents, obtained by repeatedly peeling off the
        // exogenous nodes of a scratch copy. nodeArr2 holds the corresponding nodes of `graph`.
        EdgeListGraph edgeListGraph2 = new EdgeListGraph(edgeListGraph);
        Node[] nodeArr2 = new Node[edgeListGraph2.getNodes().size()];
        int i4 = 0;
        while (edgeListGraph2.getNodes().size() > 0) {
            HashSet hashSet = new HashSet();
            for (Node node3 : edgeListGraph2.getNodes()) {
                if (edgeListGraph2.isExogenous(node3)) {
                    hashSet.add(node3);
                    int i5 = i4;
                    i4++;
                    nodeArr2[i5] = graph.getNode(node3.getName());
                }
            }
            edgeListGraph2.removeNodes(new ArrayList(hashSet));
        }
        // Step 3: edge ordering. edgeArr = edges of `graph` matching the helper edges,
        // zArr = "already ordered" flags, edgeArr2 = the edges in their final order.
        // i6 counts the edges that still have to be placed.
        int i6 = 0;
        Edge[] edgeArr = new Edge[edgeListGraph.getNumEdges()];
        boolean[] zArr = new boolean[edgeListGraph.getNumEdges()];
        Edge[] edgeArr2 = new Edge[edgeListGraph.getNumEdges()];
        for (Edge edge : edgeListGraph.getEdges()) {
            int i7 = i6;
            i6++;
            edgeArr[i7] = graph.getEdge(graph.getNode(edge.getNode1().getName()), graph.getNode(edge.getNode2().getName()));
        }
        for (int i8 = 0; i8 < edgeArr.length; i8++) {
            zArr[i8] = false;
        }
        while (i6 > 0) {
            // Heads are visited in topological order; for each head its unplaced incoming edges
            // are placed with tails taken in reverse topological order.
            for (int i9 = 0; i9 < nodeArr2.length; i9++) {
                for (int i10 = 0; i10 < edgeArr.length; i10++) {
                    if (!zArr[i10] && edgeArr[i10].getNode2() == nodeArr2[i9]) {
                        for (int length = nodeArr2.length - 1; length >= 0; length--) {
                            for (int i11 = 0; i11 < edgeArr.length; i11++) {
                                if (!zArr[i11] && edgeArr[i11].getNode1() == nodeArr2[length] && edgeArr[i11].getNode2() == nodeArr2[i9]) {
                                    zArr[i11] = true;
                                    edgeArr2[edgeArr2.length - i6] = edgeArr[i11];
                                    i6--;
                                }
                            }
                        }
                    }
                }
            }
        }
        // Step 4: labelling. zArr2 marks edges that keep their direction, zArr3 marks edges that
        // are made undirected in step 5; an edge with neither flag is still unlabelled.
        boolean[] zArr2 = new boolean[edgeListGraph.getNumEdges()];
        boolean[] zArr3 = new boolean[edgeListGraph.getNumEdges()];
        for (int i12 = 0; i12 < edgeListGraph.getNumEdges(); i12++) {
            zArr2[i12] = false;
            zArr3[i12] = false;
        }
        for (int i13 = 0; i13 < edgeListGraph.getNumEdges(); i13++) {
            if (!zArr2[i13] && !zArr3[i13]) {
                // Current unlabelled edge node1 --> node22.
                Node node1 = edgeArr2[i13].getNode1();
                Node node22 = edgeArr2[i13].getNode2();
                // Look at direction-keeping edges node12 --> node1.
                for (int i14 = 0; i14 < edgeArr2.length; i14++) {
                    if (edgeArr2[i14].getNode2() == node1 && zArr2[i14]) {
                        Node node12 = edgeArr2[i14].getNode1();
                        if (edgeListGraph.isParentOf(node12, node22)) {
                            for (int i15 = 0; i15 < edgeArr2.length; i15++) {
                                if (edgeArr2[i15].getNode1() == node12 && edgeArr2[i15].getNode2() == node22) {
                                    zArr2[i15] = true;
                                }
                            }
                        } else {
                            // node12 is not a parent of node22: every edge into node22 keeps its direction.
                            for (int i16 = 0; i16 < edgeArr2.length; i16++) {
                                if (edgeArr2[i16].getNode2() == node22) {
                                    zArr2[i16] = true;
                                }
                            }
                        }
                    }
                    if (zArr2[i13]) {
                        break;
                    }
                }
                if (!zArr2[i13]) {
                    // Search for another edge node13 --> node22 whose tail is not a parent of node1.
                    boolean z = false;
                    int i17 = 0;
                    while (true) {
                        if (i17 >= edgeArr2.length) {
                            break;
                        }
                        Node node13 = edgeArr2[i17].getNode1();
                        if (node13 == node1 || edgeArr2[i17].getNode2() != node22 || edgeListGraph.isParentOf(node13, node1)) {
                            i17++;
                        } else {
                            zArr2[i13] = true;
                            for (int i18 = i13 + 1; i18 < edgeListGraph.getNumEdges(); i18++) {
                                if (edgeArr2[i18].getNode2() == node22 && !zArr3[i18]) {
                                    zArr2[i18] = true;
                                }
                            }
                            z = true;
                            // One such edge is enough. Without leaving the loop here the index is
                            // never advanced and the scan would spin forever on the same edge.
                            break;
                        }
                    }
                    if (!z) {
                        // No such edge: this edge and every still-unlabelled edge into node22 are reversible.
                        zArr3[i13] = true;
                        for (int i19 = 0; i19 < edgeArr2.length; i19++) {
                            if (!zArr2[i19] && edgeArr2[i19].getNode2() == node22) {
                                zArr3[i19] = true;
                            }
                        }
                    }
                }
            }
        }
        // Step 5: reversible edges become undirected (SEGMENT at both ends).
        for (int i20 = 0; i20 < zArr3.length; i20++) {
            if (zArr3[i20]) {
                graph.setEndpoint(edgeArr2[i20].getNode1(), edgeArr2[i20].getNode2(), Endpoint.SEGMENT);
                graph.setEndpoint(edgeArr2[i20].getNode2(), edgeArr2[i20].getNode1(), Endpoint.SEGMENT);
            }
        }
    }

    /**
     * Prepares the bookkeeping used by the EM optimizer:
     * <ul>
     *   <li>assigns consecutive indices to latent and to non-latent node names
     *       ({@code latentNames}, {@code observableNames}) and counts both groups;</li>
     *   <li>copies the covariance matrix into {@code Cyy}, re-indexed by {@code observableNames};</li>
     *   <li>fills {@code fixedMeasures}: for each latent, its first non-latent child in
     *       alphabetical order is mapped to {@code true}, the others to {@code false}.</li>
     * </ul>
     * Progress is traced on standard output.
     *
     * <p>NOTE(review): the {@code Cyy} copy assumes the covariance matrix has exactly one variable
     * per non-latent node of {@code graph}, under the same names - confirm with the caller.
     *
     * @param graph the graph carrying the measurement model; not modified (a DAG copy is used)
     */
    private void structuralEmInitialization(Graph graph) {
        this.observableNames = new Hashtable();
        this.latentNames = new Hashtable();
        this.numObserved = 0;
        this.numLatent = 0;
        for (Object nodeObj : graph.getNodes()) {
            Node node = (Node) nodeObj;
            System.out.println("Next node: " + node.toString());
            if (node.getNodeType() == NodeType.LATENT) {
                this.latentNames.put(node.getName(), Integer.valueOf(this.numLatent));
                this.numLatent++;
            } else {
                this.observableNames.put(node.getName(), Integer.valueOf(this.numObserved));
                this.numObserved++;
            }
        }

        this.Cyy = new double[this.numObserved][this.numObserved];
        System.out.println("#Observed = " + this.numObserved);
        String[] variableNames = this.covMatrix.getVariableNames();
        double[][] matrix = this.covMatrix.getMatrix();
        System.out.println("#Cov = " + matrix.length);
        for (int row = 0; row < this.numObserved; row++) {
            int targetRow = ((Integer) this.observableNames.get(variableNames[row])).intValue();
            for (int col = 0; col < this.numObserved; col++) {
                int targetCol = ((Integer) this.observableNames.get(variableNames[col])).intValue();
                this.Cyy[targetRow][targetCol] = matrix[row][col];
            }
        }

        this.fixedMeasures = new Hashtable();
        EdgeListGraph dag = new EdgeListGraph(graph);
        pdagToDag(dag);
        System.out.println(dag);
        Comparator byName = new Comparator() {
            @Override // java.util.Comparator
            public int compare(Object left, Object right) {
                return ((Node) left).getName().compareTo(((Node) right).getName());
            }
        };
        for (Object nodeObj : dag.getNodes()) {
            Node latent = (Node) nodeObj;
            if (latent.getNodeType() != NodeType.LATENT) {
                continue;
            }
            // Sort a copy, as fixParameters does: the list returned by the graph may be its own
            // internal list (or unmodifiable) and must not be reordered in place.
            List children = new ArrayList(dag.getChildren(latent));
            Collections.sort(children, byName);
            boolean firstMeasureSeen = false;
            for (Object childObj : children) {
                Node child = (Node) childObj;
                if (child.getNodeType() == NodeType.LATENT) {
                    continue;
                }
                // Only the first non-latent child of each latent is flagged as fixed.
                this.fixedMeasures.put(child.getName(), Boolean.valueOf(!firstMeasureSeen));
                firstMeasureSeen = true;
            }
        }
    }

    /**
     * Runs the expectation step for a candidate graph: wraps the graph in a
     * parametric model, estimates it against the stored covariance matrix and
     * hands the estimated model to {@code expectation(SemIm)}.
     *
     * @param graph the graph to build the parametric model from; it is also
     *              printed to standard output
     */
    private void expectation(Graph graph) {
        SemGraph semGraph = new SemGraph(new ProtoSemGraph(graph));
        SemPm parametricModel = new SemPm(semGraph);
        System.out.println(graph.toString());

        MimBuildEstimator estimator = MimBuildEstimator.newInstance(this.covMatrix, parametricModel);
        estimator.estimate();

        SemIm estimatedModel = estimator.getEstimatedSem();
        expectation(estimatedModel);
    }

    /**
     * Recomputes {@code Cyz} and {@code Czz} from the parameter values of the
     * given model and the stored observed covariance {@code Cyy}.
     *
     * <p>Four matrices are read off the free parameters:
     * <ul>
     *   <li>coefficients whose target is a latent (indexed [target][source]),</li>
     *   <li>coefficients whose target is a measured node (indexed [latent][measured]),</li>
     *   <li>variances attached to latents (diagonal),</li>
     *   <li>all other variances (diagonal, indexed by measured node).</li>
     * </ul>
     * A variance parameter attached to an error node is attributed to that
     * error node's first child. Measured nodes flagged in {@code fixedMeasures}
     * get a loading of 1 on their first latent parent.
     *
     * @param semIm the estimated model whose free parameters are read
     */
    private void expectation(SemIm semIm) {
        // Freshly allocated Java arrays are already zero-filled, so no explicit
        // clearing pass is required before the entries are assigned below.
        double[][] latentCoefs = new double[this.numLatent][this.numLatent];
        double[][] loadings = new double[this.numLatent][this.numObserved];
        double[][] measuredVariances = new double[this.numObserved][this.numObserved];
        double[][] latentVariances = new double[this.numLatent][this.numLatent];

        List freeParams = semIm.getFreeParameters();
        double[] values = semIm.getParamValues();
        for (int p = 0; p < freeParams.size(); p++) {
            Parameter param = (Parameter) freeParams.get(p);
            if (param.getType() == ParamType.COEF) {
                Node source = param.getNodeA();
                Node target = param.getNodeB();
                if (target.getNodeType() == NodeType.MEASURED) {
                    int row = ((Integer) this.latentNames.get(source.getName())).intValue();
                    int col = ((Integer) this.observableNames.get(target.getName())).intValue();
                    loadings[row][col] = values[p];
                } else if (target.getNodeType() == NodeType.LATENT) {
                    int row = ((Integer) this.latentNames.get(target.getName())).intValue();
                    int col = ((Integer) this.latentNames.get(source.getName())).intValue();
                    latentCoefs[row][col] = values[p];
                }
            } else if (param.getType() == ParamType.VAR) {
                Node owner = param.getNodeA();
                if (owner.getNodeType() == NodeType.ERROR) {
                    // Attribute an error-term variance to the node the error points at.
                    owner = (Node) semIm.getSemPm().getGraph().getChildren(owner).iterator().next();
                }
                if (owner.getNodeType() == NodeType.LATENT) {
                    int index = ((Integer) this.latentNames.get(owner.getName())).intValue();
                    latentVariances[index][index] = values[p];
                } else {
                    int index = ((Integer) this.observableNames.get(owner.getName())).intValue();
                    measuredVariances[index][index] = values[p];
                }
            }
        }

        // Fixed measures carry a unit loading on the first latent parent found.
        for (Node measured : semIm.getSemPm().getGraph().getNodes()) {
            if (measured.getNodeType() == NodeType.MEASURED && ((Boolean) this.fixedMeasures.get(measured.getName())).booleanValue()) {
                Iterator parentIt = semIm.getSemPm().getGraph().getParents(measured).iterator();
                while (parentIt.hasNext()) {
                    Node parent = (Node) parentIt.next();
                    if (parent.getNodeType() == NodeType.LATENT) {
                        int row = ((Integer) this.latentNames.get(parent.getName())).intValue();
                        int col = ((Integer) this.observableNames.get(measured.getName())).intValue();
                        loadings[row][col] = 1.0d;
                        break;
                    }
                }
            }
        }

        double[][] identity = new double[this.numLatent][this.numLatent];
        for (int d = 0; d < this.numLatent; d++) {
            identity[d][d] = 1.0d;
        }

        // (I - B)^-1 and the covariance it implies among the latents.
        double[][] inverseIMinusB = MatrixUtils.inverseGj(MatrixUtils.difference(identity, latentCoefs), this.numLatent);
        double[][] latentCov = MatrixUtils.product(inverseIMinusB, MatrixUtils.product(latentVariances, MatrixUtils.transpose(inverseIMinusB)));

        // Inverse of the covariance implied among the measured variables.
        double[][] impliedMeasuredCov = MatrixUtils.sum(measuredVariances, MatrixUtils.product(MatrixUtils.product(MatrixUtils.transpose(loadings), latentCov), loadings));
        double[][] impliedMeasuredInv = MatrixUtils.inverseGj(impliedMeasuredCov, this.numObserved);

        double[][] measuredLatentCov = MatrixUtils.product(MatrixUtils.transpose(loadings), latentCov);
        double[][] weights = MatrixUtils.product(impliedMeasuredInv, measuredLatentCov);
        double[][] residualLatentCov = MatrixUtils.difference(latentCov, MatrixUtils.product(MatrixUtils.transpose(measuredLatentCov), weights));

        this.Cyz = MatrixUtils.product(this.Cyy, weights);
        this.Czz = MatrixUtils.sum(MatrixUtils.product(MatrixUtils.product(MatrixUtils.transpose(weights), this.Cyy), weights), residualLatentCov);
    }

    /**
     * Builds a model for the given graph and sets its parameters in closed form
     * from the stored {@code Cyy}, {@code Cyz} and {@code Czz} matrices.
     *
     * <p>Every fixed parameter is set to 1. For each measured node the loading
     * on its latent parent (unless the node is flagged in {@code fixedMeasures})
     * and the variance of its non-latent parent are set. For each latent node
     * either its own variance (exogenous case) or the coefficients from its
     * latent parents plus the variance of its non-latent parent are set.
     *
     * <p>NOTE(review): the code assumes every measured node has one latent and
     * one non-latent parent, and every non-exogenous latent has exactly one
     * non-latent parent; nothing here checks that.
     *
     * @param graph the graph to parameterise
     * @return the model instance with its parameter values assigned
     */
    private SemIm maximization(Graph graph) {
        SemPm parametricModel = new SemPm(new SemGraph(new ProtoSemGraph(graph)));
        fixParameters(parametricModel);
        SemIm model = SemIm.newInstance(parametricModel);
        model.setCovMatrix(this.covMatrix);

        for (Iterator fixedIt = model.getFixedParameters().iterator(); fixedIt.hasNext(); ) {
            model.setFixedParamValue((Parameter) fixedIt.next(), 1.0d);
        }

        for (Node current : model.getSemPm().getGraph().getNodes()) {
            if (current.getNodeType() == NodeType.MEASURED) {
                // The last latent / last non-latent parent seen wins.
                Node otherParent = null;
                Node latentParent = null;
                for (Node parent : model.getSemPm().getGraph().getParents(current)) {
                    if (parent.getNodeType() == NodeType.LATENT) {
                        latentParent = parent;
                    } else {
                        otherParent = parent;
                    }
                }
                int obs = ((Integer) this.observableNames.get(current.getName())).intValue();
                int lat = ((Integer) this.latentNames.get(latentParent.getName())).intValue();

                boolean loadingIsFixed = ((Boolean) this.fixedMeasures.get(current.getName())).booleanValue();
                if (!loadingIsFixed) {
                    model.setParamValue(latentParent, current, this.Cyz[obs][lat] / this.Czz[lat][lat]);
                }
                double residualVariance = this.Cyy[obs][obs] - ((this.Cyz[obs][lat] * this.Cyz[obs][lat]) / this.Czz[lat][lat]);
                model.setParamValue(otherParent, otherParent, residualVariance);
            } else if (current.getNodeType() == NodeType.LATENT) {
                if (model.getSemPm().getGraph().isExogenous(current)) {
                    int self = ((Integer) this.latentNames.get(current.getName())).intValue();
                    model.setParamValue(current, current, this.Czz[self][self]);
                } else {
                    List<Node> parents = model.getSemPm().getGraph().getParents(current);
                    // One slot fewer than the parent count: a single parent is expected to be non-latent.
                    Node[] parentNodes = new Node[parents.size() - 1];
                    Node otherParent = null;
                    int self = ((Integer) this.latentNames.get(current.getName())).intValue();
                    int[] parentIndices = new int[parents.size() - 1];
                    int filled = 0;
                    for (Node parent : parents) {
                        if (parent.getNodeType() == NodeType.LATENT) {
                            parentIndices[filled] = ((Integer) this.latentNames.get(parent.getName())).intValue();
                            parentNodes[filled] = parent;
                            filled++;
                        } else {
                            otherParent = parent;
                        }
                    }

                    int size = parentIndices.length;
                    double[] crossCov = new double[size];
                    double[][] parentCov = new double[size][size];
                    for (int r = 0; r < size; r++) {
                        crossCov[r] = this.Czz[self][parentIndices[r]];
                        for (int c = 0; c < size; c++) {
                            parentCov[r][c] = this.Czz[parentIndices[r]][parentIndices[c]];
                        }
                    }

                    // Coefficients = inverse(parent covariance) * covariance with the parents.
                    double[] coefficients = MatrixUtils.product(MatrixUtils.inverseGj(parentCov, size), crossCov);
                    for (int k = 0; k < parentNodes.length; k++) {
                        model.setParamValue(parentNodes[k], current, coefficients[k]);
                    }
                    model.setParamValue(otherParent, otherParent, this.Czz[self][self] - MatrixUtils.innerProduct(crossCov, coefficients));
                }
            }
        }
        return model;
    }

    /**
     * Returns the natural logarithm of the determinant of the covariance
     * matrix the model implies over its measured variables.
     *
     * @param semIm the model to evaluate
     * @return {@code log(det(implied measured covariance))}
     */
    private double logDetSigma(SemIm semIm) {
        double determinant = MatrixUtils.determinant(semIm.getImplCovarMeas());
        return Math.log(determinant);
    }

    /**
     * Returns the trace of the model's sample covariance multiplied by the
     * inverse of its implied measured covariance.
     *
     * @param semIm the model to evaluate
     * @return {@code trace(S * inverse(Sigma))}
     */
    private double traceSSigmaInv(SemIm semIm) {
        double[][] sampleCovariance = semIm.getSampleCovar();
        double[][] impliedInverse = MatrixUtils.inverseGj(semIm.getImplCovarMeas(), semIm.getImplCovarMeas().length);
        double[][] sampleTimesInverse = MatrixUtils.product(sampleCovariance, impliedInverse);
        return MatrixUtils.trace(sampleTimesInverse);
    }

    /**
     * Builds a matrix with one row per unordered pair of indicators (a, b),
     * a &lt;= b, and one column per free parameter of the model.
     *
     * <p>Within a row the columns are filled in this order:
     * <ol>
     *   <li>one column per indicator whose loading is free
     *       ({@code lambdaIndex >= 0}), non-zero only when that indicator is
     *       a or b;</li>
     *   <li>{@code numIndicators} columns that are zero except for a 1 in
     *       position a when a == b;</li>
     *   <li>one column per (latent, latent parent) pair;</li>
     *   <li>one column per latent.</li>
     * </ol>
     * Presumably each entry is the derivative of the implied covariance of
     * (a, b) with respect to the corresponding parameter; confirm against the
     * caller before relying on that reading.
     *
     * @param semIm the model the index structures and parameter vector are rebuilt from
     * @return the filled matrix
     */
    private double[][] getJacobian(SemIm semIm) {
        buildJacobianInformation(semIm);
        int pairCount = (this.numIndicators * (this.numIndicators + 1)) / 2;
        double[][] jacobian = new double[pairCount][semIm.getNumFreeParams()];

        int row = 0;
        for (int a = 0; a < this.numIndicators; a++) {
            for (int b = a; b < this.numIndicators; b++) {
                double[] rowValues = jacobian[row];
                int col = 0;

                // Free-loading columns.
                for (int k = 0; k < this.numIndicators; k++) {
                    if (this.lambdaIndex[k] < 0) {
                        continue;
                    }
                    if (k != a && k != b) {
                        rowValues[col++] = 0.0d;
                        continue;
                    }
                    // The entry uses the loading of the other member of the pair (1 when that loading is fixed).
                    int partner = (k == a) ? b : a;
                    double partnerLoading = 1.0d;
                    if (this.lambdaIndex[partner] >= 0) {
                        partnerLoading = this.theta[this.lambdaIndex[partner]];
                    }
                    double entry = partnerLoading * this.latentImpliedCovar[this.indicatorParents[a]][this.indicatorParents[b]];
                    if (a == b) {
                        entry = entry * 2.0d;
                    }
                    rowValues[col++] = entry;
                }

                // Indicator-error columns: only the diagonal pair contributes.
                for (int e = 0; e < this.numIndicators; e++) {
                    rowValues[col + e] = (a == b && e == a) ? 1.0d : 0.0d;
                }
                col += this.numIndicators;

                // One column per edge between latents.
                for (int latent = 0; latent < this.numLatents; latent++) {
                    int[] parents = (int[]) this.latentParents.get(latent);
                    for (int p = 0; p < parents.length; p++) {
                        rowValues[col++] = (this.deltaB0[a][latent] * this.deltaB1[parents[p]][b]) + (this.deltaB0[b][latent] * this.deltaB1[parents[p]][a]);
                    }
                }

                // One column per latent.
                for (int latent = 0; latent < this.numLatents; latent++) {
                    rowValues[col++] = this.deltaB0[a][latent] * this.deltaB2[latent][b];
                }
                row++;
            }
        }
        return jacobian;
    }

    /**
     * Rebuilds every index structure that depends only on the shape of the graph.
     *
     * <p>After this call:
     * <ul>
     *   <li>{@code indicatorsArray} / {@code latentsArray} hold the MEASURED and
     *       LATENT nodes in the order {@code graph.getNodes()} returns them, and
     *       {@code numIndicators} / {@code numLatents} hold their counts;</li>
     *   <li>{@code indicatorParents[i]} is the position of the first latent that
     *       is a parent of indicator {@code i};</li>
     *   <li>{@code lambdaIndex[i]} is -1 for the indicator with the
     *       lexicographically smallest name under each latent, and otherwise a
     *       running slot number starting at 0;</li>
     *   <li>{@code latentParents} holds, per latent, an {@code int[]} with the
     *       positions of its latent parents;</li>
     *   <li>{@code betaIndex}, {@code indicatorErrorsIndex} and
     *       {@code latentErrorsIndex} are consecutive offsets: the first
     *       {@code betaIndex} entry is {@code numIndicators - numLatents}, each
     *       following one is advanced by the previous latent's parent count,
     *       then come {@code numIndicators} slots and finally the latent-error
     *       block;</li>
     *   <li>{@code beta}, {@code fi}, {@code bigLambda} and {@code J} are fresh
     *       zero matrices and {@code iBeta} is an identity matrix.</li>
     * </ul>
     *
     * @param graph the graph to index
     * @throws IllegalArgumentException if a measured node has no latent parent
     *         (the previous search loop never terminated in that case) or a
     *         latent node has no indicator
     */
    private void buildIndices(Graph graph) {
        // First pass: count, so the arrays can be sized exactly.
        int indicatorCount = 0;
        int latentCount = 0;
        for (Node node : graph.getNodes()) {
            if (node.getNodeType() == NodeType.MEASURED) {
                indicatorCount++;
            } else if (node.getNodeType() == NodeType.LATENT) {
                latentCount++;
            }
        }

        // Second pass: collect the nodes in graph order.
        this.latentsArray = new Node[latentCount];
        this.indicatorsArray = new Node[indicatorCount];
        int nextIndicator = 0;
        int nextLatent = 0;
        for (Node node : graph.getNodes()) {
            if (node.getNodeType() == NodeType.MEASURED) {
                this.indicatorsArray[nextIndicator++] = node;
            } else if (node.getNodeType() == NodeType.LATENT) {
                this.latentsArray[nextLatent++] = node;
            }
        }
        this.numIndicators = nextIndicator;
        this.numLatents = nextLatent;

        // Locate the first latent parent of every indicator with a bounded search.
        this.indicatorParents = new int[this.numIndicators];
        for (int ind = 0; ind < this.numIndicators; ind++) {
            int parent = -1;
            for (int lat = 0; lat < this.numLatents; lat++) {
                if (graph.isParentOf(this.latentsArray[lat], this.indicatorsArray[ind])) {
                    parent = lat;
                    break;
                }
            }
            if (parent < 0) {
                throw new IllegalArgumentException("Indicator " + this.indicatorsArray[ind].getName() + " has no latent parent.");
            }
            this.indicatorParents[ind] = parent;
        }

        // Under each latent, the indicator with the smallest name gets a fixed loading.
        boolean[] loadingFixed = new boolean[this.numIndicators];
        for (int lat = 0; lat < this.numLatents; lat++) {
            int smallest = -1;
            for (int ind = 0; ind < this.numIndicators; ind++) {
                if (this.indicatorParents[ind] == lat && (smallest == -1 || this.indicatorsArray[ind].getName().compareTo(this.indicatorsArray[smallest].getName()) < 0)) {
                    smallest = ind;
                }
            }
            if (smallest < 0) {
                throw new IllegalArgumentException("Latent " + this.latentsArray[lat].getName() + " has no indicators.");
            }
            loadingFixed[smallest] = true;
        }

        // Number the free loadings consecutively; fixed ones are marked with -1.
        this.lambdaIndex = new int[this.numIndicators];
        int nextLambdaSlot = 0;
        for (int ind = 0; ind < this.numIndicators; ind++) {
            this.lambdaIndex[ind] = loadingFixed[ind] ? -1 : nextLambdaSlot++;
        }

        // New arrays are zero-filled already; only the identity diagonal needs writing.
        this.beta = new double[this.numLatents][this.numLatents];
        this.fi = new double[this.numLatents][this.numLatents];
        this.iBeta = new double[this.numLatents][this.numLatents];

        // Array-backed list because the entries are read by position elsewhere.
        List<int[]> parentLists = new ArrayList<int[]>(this.numLatents);
        int[] buffer = new int[this.numLatents];
        for (int child = 0; child < this.numLatents; child++) {
            int found = 0;
            for (int candidate = 0; candidate < this.numLatents; candidate++) {
                if (graph.isParentOf(this.latentsArray[candidate], this.latentsArray[child])) {
                    buffer[found++] = candidate;
                }
            }
            int[] parents = new int[found];
            System.arraycopy(buffer, 0, parents, 0, found);
            parentLists.add(parents);
            this.iBeta[child][child] = 1.0d;
        }
        this.latentParents = parentLists;

        // Lay the parameter blocks out one after another with a running offset.
        this.betaIndex = new int[this.numLatents];
        int offset = this.numIndicators - this.numLatents;
        for (int lat = 0; lat < this.numLatents; lat++) {
            this.betaIndex[lat] = offset;
            offset += parentLists.get(lat).length;
        }
        this.indicatorErrorsIndex = offset;
        this.latentErrorsIndex = this.indicatorErrorsIndex + this.numIndicators;

        this.bigLambda = new double[this.numIndicators][this.numLatents];
        this.J = new double[this.numLatents][this.numLatents];
    }

    /**
     * Allocates {@code theta} and copies the model's free parameter values into
     * it, using the offsets stored in {@code lambdaIndex}, {@code betaIndex}
     * and {@code latentErrorsIndex}.
     *
     * <ul>
     *   <li>A coefficient whose target is an indicator goes to
     *       {@code theta[lambdaIndex[target]]} when that index is non-negative;
     *       otherwise it is ignored.</li>
     *   <li>A coefficient whose target is a latent goes to
     *       {@code theta[betaIndex[target] + k]}, where {@code k} is the
     *       position of the source in the target's parent list. If the source
     *       is not in that list nothing is written.</li>
     *   <li>Any other parameter is resolved to a node (the parameter's node A
     *       if it is a latent, otherwise node A's first child); if that node is
     *       a latent the value goes to {@code theta[latentErrorsIndex + i]}.
     *       Parameters that do not resolve to a latent are left out; this
     *       method writes nothing to the indicator-error block of
     *       {@code theta}.</li>
     * </ul>
     *
     * <p>All searches are bounded. Previously a failed search in the second or
     * third case spun forever.
     *
     * @param semIm the model whose free parameters are read
     * @throws IllegalArgumentException if a coefficient targets a node that is
     *         neither a known indicator nor a known latent, or targets a latent
     *         from a node that is not a known latent
     */
    private void getTheta(SemIm semIm) {
        this.theta = new double[this.latentErrorsIndex + this.numLatents];
        for (Parameter parameter : semIm.getFreeParameters()) {
            if (parameter.getType() == ParamType.COEF) {
                Node source = parameter.getNodeA();
                Node target = parameter.getNodeB();

                int sourceLatent = -1;
                for (int lat = 0; lat < this.numLatents; lat++) {
                    if (this.latentsArray[lat].getName().equals(source.getName())) {
                        sourceLatent = lat;
                        break;
                    }
                }

                int targetIndicator = -1;
                for (int ind = 0; ind < this.numIndicators; ind++) {
                    if (this.indicatorsArray[ind].getName().equals(target.getName())) {
                        targetIndicator = ind;
                        break;
                    }
                }

                if (targetIndicator >= 0) {
                    // Loading: stored only when it is a free one.
                    if (this.lambdaIndex[targetIndicator] >= 0) {
                        this.theta[this.lambdaIndex[targetIndicator]] = semIm.getParamValue(parameter);
                    }
                    continue;
                }

                int targetLatent = -1;
                for (int lat = 0; lat < this.numLatents; lat++) {
                    if (this.latentsArray[lat].getName().equals(target.getName())) {
                        targetLatent = lat;
                        break;
                    }
                }
                if (targetLatent < 0) {
                    throw new IllegalArgumentException("Coefficient target " + target.getName() + " is neither an indicator nor a latent.");
                }
                if (sourceLatent < 0) {
                    throw new IllegalArgumentException("Coefficient into latent " + target.getName() + " comes from " + source.getName() + ", which is not a latent.");
                }

                int[] parents = (int[]) this.latentParents.get(targetLatent);
                for (int slot = 0; slot < parents.length; slot++) {
                    if (this.latentsArray[parents[slot]].getName().equals(this.latentsArray[sourceLatent].getName())) {
                        this.theta[this.betaIndex[targetLatent] + slot] = semIm.getParamValue(parameter);
                        break;
                    }
                }
            } else {
                Node nodeA = parameter.getNodeA();
                Node owner = nodeA.getNodeType() == NodeType.LATENT ? nodeA : (Node) semIm.getSemPm().getGraph().getChildren(nodeA).iterator().next();
                // An owner that is not a latent simply finds no match and is skipped.
                for (int lat = 0; lat < this.numLatents; lat++) {
                    if (this.latentsArray[lat].getName().equals(owner.getName())) {
                        this.theta[this.latentErrorsIndex + lat] = semIm.getParamValue(parameter);
                        break;
                    }
                }
            }
        }
    }

    /**
     * Refreshes every intermediate matrix that {@code getJacobian} reads.
     *
     * <p>Steps, in order: rebuild the graph indices, reload {@code theta},
     * copy the latent coefficients and latent variances out of {@code theta}
     * into {@code beta} and the diagonal of {@code fi}, invert
     * {@code iBeta - beta}, derive {@code latentImpliedCovar}, fill
     * {@code bigLambda} (1 for a fixed loading, the {@code theta} entry
     * otherwise) and finally compute {@code deltaB0}, {@code deltaB1} and
     * {@code deltaB2}.
     *
     * @param semIm the model to read the graph and parameter values from
     * @throws IllegalArgumentException declared by the original signature;
     *         nothing in this method throws it directly
     */
    private void buildJacobianInformation(SemIm semIm) throws IllegalArgumentException {
        buildIndices(semIm.getSemPm().getGraph());
        getTheta(semIm);

        for (int latent = 0; latent < this.numLatents; latent++) {
            int[] parents = (int[]) this.latentParents.get(latent);
            int slot = 0;
            while (slot < parents.length) {
                this.beta[latent][parents[slot]] = this.theta[this.betaIndex[latent] + slot];
                slot++;
            }
            this.fi[latent][latent] = this.theta[this.latentErrorsIndex + latent];
        }

        double[][] identityMinusBeta = MatrixUtils.difference(this.iBeta, this.beta);
        this.iMinusB = MatrixUtils.inverseGj(identityMinusBeta, this.iBeta.length);
        this.iMinusBT = MatrixUtils.transpose(this.iMinusB);
        this.latentImpliedCovar = MatrixUtils.product(this.iMinusB, MatrixUtils.product(this.fi, this.iMinusBT));

        for (int indicator = 0; indicator < this.numIndicators; indicator++) {
            int slot = this.lambdaIndex[indicator];
            // A negative slot marks a loading that is fixed at 1.
            this.bigLambda[indicator][this.indicatorParents[indicator]] = slot < 0 ? 1.0d : this.theta[slot];
        }

        this.deltaB0 = MatrixUtils.product(this.bigLambda, this.iMinusB);
        double[][] deltaB0TimesFi = MatrixUtils.product(this.deltaB0, this.fi);
        this.deltaB1 = MatrixUtils.transpose(MatrixUtils.product(deltaB0TimesFi, this.iMinusBT));
        this.deltaB2 = MatrixUtils.transpose(this.deltaB0);
    }
}
