package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.CovarianceMatrix;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.EndpointMatrixGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.ind.Knowledge;
import edu.cmu.tetrad.sem.SemEstimator;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemOptimizerEm;
import edu.cmu.tetrad.sem.SemPm;
import edu.cmu.tetrad.util.ObjectPair;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:edu/cmu/tetrad/search/MimBuild.class */
/**
 * Builds a structural model among latent variables on top of a measurement model.
 *
 * <p>The measurement model is read from the {@link IndTestMimBuild} supplied at construction
 * time. The first node of each required pair is treated as a latent. The structure among the
 * latents is then searched with one of two strategies:
 * <ul>
 *   <li>a GES-based search, for algorithm type {@code 0} or {@code -1};</li>
 *   <li>a PC-based search, for any other algorithm type.</li>
 * </ul>
 */
public class MimBuild extends PcStub {
    static final long serialVersionUID = 23;

    /** Prefix used by {@link #generateLatentNames(int)} when naming latent variables. */
    public static final String LATENT_PREFIX = "_L";

    /**
     * Latent nodes found by {@link #startMeasurementModel(Graph)}, in the order in which they
     * first appeared as the first element of a required measurement pair.
     */
    protected List<Node> latents;

    /**
     * Creates a MIM build search.
     *
     * @param indTestMimBuild the test supplying variables, measurements, covariances and the
     *                        algorithm type
     * @param knowledge       background knowledge forwarded to the underlying searches
     */
    public MimBuild(IndTestMimBuild indTestMimBuild, Knowledge knowledge) {
        super(indTestMimBuild, knowledge);
        this.latents = new ArrayList<Node>();
    }

    /**
     * Runs the search selected by the test's algorithm type.
     *
     * @return the full graph (measurement edges plus latent structure)
     * @throws NullPointerException if no independence test has been set
     */
    @Override
    public Graph search() {
        if (getIndependenceTest() == null) {
            throw new NullPointerException("Independence test must not be null.");
        }
        int algorithmType = mimBuildTest().getAlgorithmType();
        // Algorithm types 0 and -1 select the GES-based search; anything else selects PC.
        if (algorithmType == 0 || algorithmType == -1) {
            return mimBuildGesSearch();
        }
        return mimBuildPcSearch();
    }

    /** Returns the configured independence test, narrowed to the MIM-build test type. */
    private IndTestMimBuild mimBuildTest() {
        return (IndTestMimBuild) getIndependenceTest();
    }

    /**
     * PC-based variant of the search.
     *
     * <p>It runs a MIM adjacency search on the full graph. It then copies the latent-latent
     * adjacencies that remain in the full graph onto a latent-only graph and orients that graph
     * with {@code pcOrient}. Finally, it transfers arrowheads that point into latents back onto
     * the full graph.
     */
    private Graph mimBuildPcSearch() {
        EndpointMatrixGraph fullGraph = new EndpointMatrixGraph(mimBuildTest().getVariableList());
        startMeasurementModel(fullGraph);
        this.nodes = fullGraph.getNodes();

        SepsetMatrix sepsets = new MimAdjacencySearch(
                fullGraph, getIndependenceTest(), getKnowledge(), this.latents).adjSearch();

        // Copy latent-latent adjacencies still present in the full graph onto a latent-only graph.
        EndpointMatrixGraph latentGraph = new EndpointMatrixGraph(this.latents);
        for (Node latent : this.latents) {
            for (Object adjacent : fullGraph.getAdjacentNodes(latent)) {
                Node other = (Node) adjacent;
                if (this.latents.contains(other)) {
                    latentGraph.setEndpoint(other, latent, Endpoint.SEGMENT);
                }
            }
        }

        // NOTE(review): 3 is written to PcStub's `stop` field before orienting; its meaning is
        // defined in PcStub — confirm.
        this.stop = 3;
        pcOrient(sepsets, getKnowledge(), latentGraph);

        // Transfer arrowheads into each latent from the latent-only graph back to the full graph.
        for (Node latent : this.latents) {
            for (Object parent : latentGraph.nodesInTo(latent, Endpoint.ARROW)) {
                fullGraph.setEndpoint((Node) parent, latent, Endpoint.ARROW);
            }
        }
        return fullGraph;
    }

    /**
     * GES-based variant of the search.
     *
     * <p>It fits a SEM to the initial graph with an EM optimizer. Each iteration then:
     * <ol>
     *   <li>takes the optimizer's expected covariance matrix and restricts it to the latents;</li>
     *   <li>runs GES on that matrix;</li>
     *   <li>re-fits the model with a DAG version of the GES result substituted for the
     *       latent-latent edges.</li>
     * </ol>
     * Iteration continues while {@link #scoreModel(SemIm)} strictly improves.
     */
    private Graph mimBuildGesSearch() {
        Graph graph = new EndpointMatrixGraph(mimBuildTest().getVariableList());
        startMeasurementModel(graph);

        String[] allNames = namesOf(graph.getNodes());
        String[] latentNames = namesOf(this.latents);

        CovarianceMatrix covMatrix = mimBuildTest().getCovMatrix();
        GraphUtils.pdagToDag(graph);
        SemOptimizerEm optimizer = new SemOptimizerEm();
        SemEstimator initialEstimator =
                SemEstimator.newInstance(covMatrix, new SemPm(graph), optimizer);
        initialEstimator.estimate();
        double score = scoreModel(initialEstimator.getEstimatedSem());

        double previousScore;
        do {
            // The optimizer is shared across fits, so this reflects the most recent estimate.
            CovarianceMatrix latentCov = new CovarianceMatrix(
                    allNames, optimizer.getExpectedCovarianceMatrix(), covMatrix.getSampleSize())
                    .getSubmatrix(latentNames);
            previousScore = score;
            Graph latentPdag = new GesSearch(latentCov, getKnowledge()).search();

            // Score a DAG built from the GES output against the observed covariances.
            EdgeListGraph latentDag = new EdgeListGraph(latentPdag);
            GraphUtils.pdagToDag(latentDag);
            SemEstimator candidate = SemEstimator.newInstance(
                    covMatrix, new SemPm(getUpdatedGraph(graph, latentDag)), optimizer);
            candidate.estimate();
            score = scoreModel(candidate.getEstimatedSem());

            if (score > previousScore) {
                // The GES output itself (not the DAG used for scoring) is kept in the result.
                graph = getUpdatedGraph(graph, latentPdag);
            }
        } while (score > previousScore);
        return graph;
    }

    /**
     * Returns the {@code toString()} of each element of {@code nodes}, in order.
     */
    private static String[] namesOf(List<?> nodes) {
        String[] names = new String[nodes.size()];
        for (int k = 0; k < names.length; k++) {
            names[k] = nodes.get(k).toString();
        }
        return names;
    }

    /**
     * Returns a copy of {@code graph} in which the latent-latent edges are replaced by the edges
     * of {@code graph2}.
     *
     * <p>An edge counts as latent-latent when both of its endpoints have type
     * {@link NodeType#LATENT}. Nodes of {@code graph2} are matched to nodes of the copy by name.
     *
     * @param graph  the full graph; it is not modified
     * @param graph2 the replacement latent structure
     * @return a new {@link EdgeListGraph}
     */
    private Graph getUpdatedGraph(Graph graph, Graph graph2) {
        EdgeListGraph updated = new EdgeListGraph(graph);

        List<Edge> latentEdges = new ArrayList<Edge>();
        for (Object o : updated.getEdges()) {
            Edge edge = (Edge) o;
            if (edge.getNode1().getNodeType() == NodeType.LATENT
                    && edge.getNode2().getNodeType() == NodeType.LATENT) {
                latentEdges.add(edge);
            }
        }
        updated.removeEdges(latentEdges);

        for (Object o : graph2.getEdges()) {
            Edge edge = (Edge) o;
            Node node1 = updated.getNode(edge.getNode1().toString());
            Node node2 = updated.getNode(edge.getNode2().toString());
            updated.setEndpoint(node2, node1, edge.getEndpoint1());
            updated.setEndpoint(node1, node2, edge.getEndpoint2());
        }
        return updated;
    }

    /**
     * Scores a fitted model; the GES loop treats larger values as better.
     *
     * <p>The score is computed as {@code -FML - (free parameters * ln(sample size))}.
     */
    private double scoreModel(SemIm semIm) {
        return (-semIm.getFml()) - (semIm.getNumFreeParams() * Math.log(semIm.getSampleSize()));
    }

    /**
     * Seeds {@code graph} with the measurement model and fully connects the latents.
     *
     * <p>For each required pair {@code (a, b)} from the test's measurements, it sets an ARROW
     * endpoint via {@code setEndpoint(a, b, ARROW)} and a SEGMENT endpoint via
     * {@code setEndpoint(b, a, SEGMENT)}. If {@code a} is not already recorded, it is appended
     * to {@link #latents} and marked {@link NodeType#LATENT}. Afterwards, every pair of latents
     * is joined with SEGMENT endpoints on both sides.
     *
     * @param graph a graph containing nodes named after every name in the measurement pairs
     */
    protected void startMeasurementModel(Graph graph) {
        Iterator<?> requiredEdges = mimBuildTest().getMeasurements().requiredEdgesIterator();
        while (requiredEdges.hasNext()) {
            ObjectPair pair = (ObjectPair) requiredEdges.next();
            Node latent = graph.getNode((String) pair.getA());
            Node target = graph.getNode((String) pair.getB());
            graph.setEndpoint(latent, target, Endpoint.ARROW);
            graph.setEndpoint(target, latent, Endpoint.SEGMENT);
            if (!this.latents.contains(latent)) {
                this.latents.add(latent);
                latent.setNodeType(NodeType.LATENT);
            }
        }

        int count = this.latents.size();
        for (int a = 0; a < count; a++) {
            Node first = this.latents.get(a);
            for (int b = a + 1; b < count; b++) {
                Node second = this.latents.get(b);
                graph.setEndpoint(first, second, Endpoint.SEGMENT);
                graph.setEndpoint(second, first, Endpoint.SEGMENT);
            }
        }
    }

    /** Human-readable names of the supported test types. */
    public static String[] getTestDescriptions() {
        return new String[]{"Gaussian maximum likelihood", "Two-stage least squares"};
    }

    /** Human-readable names of the supported search algorithms. */
    public static String[] getAlgorithmDescriptions() {
        return new String[]{"GES", "PC search"};
    }

    /**
     * Generates latent variable names of the form {@code LATENT_PREFIX + index}.
     *
     * @param i how many names to generate; values {@code <= 0} yield an empty list
     * @return the names {@code _L0 .. _L(i-1)}, in order
     */
    public static List<String> generateLatentNames(int i) {
        List<String> names = new ArrayList<String>(Math.max(i, 0));
        for (int k = 0; k < i; k++) {
            names.add(LATENT_PREFIX + k);
        }
        return names;
    }
}
