package edu.cmu.tetrad.ind;

import be.ac.vub.ir.statistics.EntropyEstimatorCache;
import be.ac.vub.ir.statistics.InformationWEntropy;
import be.ac.vub.ir.statistics.StatUtilsExt;
import be.ac.vub.ir.statistics.estimators.KdeOptEntropyEstimator;
import edu.cmu.tetrad.data.ContinuousDataSet;
import edu.cmu.tetrad.data.CorrAndFittMatrix;
import edu.cmu.tetrad.data.CorrelationMatrix;
import edu.cmu.tetrad.data.CovarianceMatrix;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DataUtils;
import edu.cmu.tetrad.data.DiscreteDataSet;
import edu.cmu.tetrad.data.MixedDataSet;
import edu.cmu.tetrad.data.TimeSeriesData;
import edu.cmu.tetrad.data.Variable;
import edu.cmu.tetrad.graph.Graph;
import java.io.Serializable;
import java.util.Arrays;
import java.util.List;

/* loaded from: input_file:edu/cmu/tetrad/ind/IndTestFactory.class */
/**
 * Static factory that selects and builds an {@link IndependenceTest} for a given data source.
 *
 * <p>The concrete test is chosen from the runtime type of the data source and from the test code
 * returned by {@link IndTestParams#getTest()}. Tests built for continuous and mixed data sets are
 * wrapped in an {@link IndependenceMap} that is kept in the static field
 * {@link #lastUsedIndependenceMap} and reused by later calls whose variables it covers.
 *
 * <p>The static cache is read and written without any synchronization.
 */
public class IndTestFactory implements Serializable {
    static final long serialVersionUID = 23;

    /**
     * Independence map returned by the most recent call that handled a continuous or mixed data
     * set; {@code null} until the first such call. It is reused (with a fresh underlying test)
     * when its variables cover those of the next data set, otherwise it is replaced.
     */
    public static IndependenceMap lastUsedIndependenceMap;

    /**
     * Builds an independence test for {@code obj} using the test parameters held by
     * {@code searchParams}. If the search parameters carry no test parameters yet, a default
     * {@link IndTestParams} is created and stored in them.
     *
     * @param obj the data source; see {@link #getTest(Object, IndTestParams)} for supported types
     * @param searchParams the search parameters that hold the test parameters
     * @return the independence test for the data source
     * @throws NullPointerException if {@code searchParams} or {@code obj} is {@code null}
     */
    public static IndependenceTest getTest(Object obj, SearchParams searchParams) {
        if (searchParams == null) {
            throw new NullPointerException("SearchParams must not be null.");
        }
        return getTest(obj, getBasicIndTestParams(searchParams));
    }

    /**
     * Builds an independence test for {@code obj}.
     *
     * <p>Supported data sources are {@code ContinuousDataSet}, {@code MixedDataSet},
     * {@code DiscreteDataSet}, {@code Graph}, {@code CovarianceMatrix}, {@code CorrelationMatrix},
     * {@code CorrAndFittMatrix} and {@code TimeSeriesData}. When the test code is {@code 0}, any
     * {@code DataSet} is first converted with {@code DataUtils.toContinuousDataSet}.
     *
     * <p>Side effects: continuous and mixed data sets update {@link #lastUsedIndependenceMap};
     * when an entropy-based test is created and {@code indTestParams} has no KDE parameters, the
     * parameters obtained from the new test are stored back into {@code indTestParams}.
     *
     * @param obj the data source
     * @param indTestParams the test parameters; a default instance is used when {@code null}
     * @return the independence test for the data source
     * @throws NullPointerException if {@code obj} is {@code null}
     * @throws IllegalArgumentException if the test code is not valid for the kind of data set
     * @throws IllegalStateException if the type of {@code obj} is not supported
     */
    public static IndependenceTest getTest(Object obj, IndTestParams indTestParams) {
        if (obj == null) {
            throw new NullPointerException("Data source must not be null.");
        }
        if (indTestParams == null) {
            indTestParams = new IndTestParams();
        }
        // Test code 0 selects the correlation-matrix tests, which are only built from a
        // continuous data set, so every other kind of data set is converted up front.
        if (indTestParams.getTest() == 0 && (obj instanceof DataSet)) {
            obj = DataUtils.toContinuousDataSet((DataSet) obj);
        }
        if (obj instanceof ContinuousDataSet) {
            return getContinuousTest((ContinuousDataSet) obj, indTestParams);
        }
        if (obj instanceof MixedDataSet) {
            return getMixedTest((MixedDataSet) obj, indTestParams);
        }
        if (obj instanceof DiscreteDataSet) {
            return getDiscreteTest((DiscreteDataSet) obj, indTestParams);
        }
        if (obj instanceof Graph) {
            return new IndTestGraph((Graph) obj);
        }
        if (obj instanceof CovarianceMatrix) {
            return new IndTestCorrMatrix(new CorrelationMatrix((CovarianceMatrix) obj), indTestParams.getAlpha());
        }
        if (obj instanceof CorrelationMatrix) {
            return new IndTestCorrMatrix((CorrelationMatrix) obj, indTestParams.getAlpha());
        }
        if (obj instanceof CorrAndFittMatrix) {
            return new IndTestCorrAndFitMatrix((CorrAndFittMatrix) obj, indTestParams.getAlpha(), indTestParams.getCorrOrder());
        }
        if (obj instanceof TimeSeriesData) {
            return getTimeSeriesTest((TimeSeriesData) obj, indTestParams);
        }
        throw new IllegalStateException("Data source must be a ContinuousDataSet, a MixedDataSet, a DiscreteDataSet, a Graph, a CovarianceMatrix, a CorrelationMatrix, a CorrAndFittMatrix, or a TimeSeriesData: " + obj.getClass());
    }

    /**
     * Builds the test for a continuous data set and wraps it in the cached independence map.
     * Test code 0 yields a correlation-based test (with fitting when the correlation order is
     * above 1); test code 1 yields an {@link InformationWEntropy} test.
     */
    private static IndependenceTest getContinuousTest(ContinuousDataSet dataSet, IndTestParams indTestParams) {
        IndependenceTest test;
        if (indTestParams.getTest() == 0) {
            if (indTestParams.getCorrOrder() <= 1) {
                test = new IndTestCorrMatrix(dataSet, indTestParams.getAlpha());
            } else {
                test = new IndTestCorrAndFitMatrix(dataSet, indTestParams.getAlpha(), indTestParams.getCorrOrder());
            }
        } else if (indTestParams.getTest() == 1) {
            if (indTestParams.getKdeParams() == null) {
                InformationWEntropy entropyTest = new InformationWEntropy(dataSet, (float) indTestParams.getAlpha());
                // Remember the KDE parameters of the new test so later calls can reuse them.
                indTestParams.setKdeParams(StatUtilsExt.getKDEParams(entropyTest));
                test = entropyTest;
            } else {
                test = new InformationWEntropy(dataSet, (float) indTestParams.getAlpha(), indTestParams.getKdeParams());
            }
        } else {
            throw new IllegalArgumentException("Invalid independence test parameter (" + IndTestParams.test2String(indTestParams.getTest()) + ") for continuous data");
        }

        List<?> dataVariables = dataSet.getVariables();
        IndependenceMap cachedMap = lastUsedIndependenceMap;
        if (cachedMap != null && variablesListMatch(cachedMap.getVariables(), dataVariables)) {
            cachedMap.setIndependenceTest(test);
        } else {
            if (cachedMap != null) {
                reportMapMismatch(cachedMap.getVariables(), dataVariables);
            }
            lastUsedIndependenceMap = new IndependenceMap((DataSet) dataSet, test);
        }
        return lastUsedIndependenceMap;
    }

    /**
     * Writes to {@code System.err} that the cached independence map does not cover the data
     * set, followed by the data set variables that the map lacks.
     */
    private static void reportMapMismatch(List<Variable> mapVariables, List<?> dataVariables) {
        System.err.println("Independence map is not conform dataset (" + mapVariables + " vs " + dataVariables + ").\nIt is not used.");
        System.err.print("Variables missing in independence map: ");
        for (Object variable : dataVariables) {
            if (!mapVariables.contains(variable)) {
                System.err.print(variable + ", ");
            }
        }
        System.err.println();
    }

    /**
     * Builds an {@link InformationWEntropy} test for a mixed data set and wraps it in the cached
     * independence map. The test code is not consulted here. Unlike the continuous case, a map
     * mismatch is not reported.
     */
    private static IndependenceTest getMixedTest(MixedDataSet mixedDataSet, IndTestParams indTestParams) {
        InformationWEntropy entropyTest;
        if (indTestParams.getKdeParams() == null) {
            entropyTest = new InformationWEntropy(mixedDataSet, (float) indTestParams.getAlpha());
            // Remember the KDE parameters of the new test so later calls can reuse them.
            indTestParams.setKdeParams(StatUtilsExt.getKDEParams(entropyTest));
        } else {
            entropyTest = new InformationWEntropy(mixedDataSet, (float) indTestParams.getAlpha(), indTestParams.getKdeParams());
        }
        if (lastUsedIndependenceMap == null || !variablesListMatch(lastUsedIndependenceMap.getVariables(), mixedDataSet.getVariables())) {
            lastUsedIndependenceMap = new IndependenceMap((DataSet) mixedDataSet, (IndependenceTest) entropyTest);
        } else {
            lastUsedIndependenceMap.setIndependenceTest(entropyTest);
        }
        return lastUsedIndependenceMap;
    }

    /**
     * Builds the test for a discrete data set according to the test code (1 to 4) and wraps it
     * in an {@link IndTestWithEquivalence}. The static independence-map cache is not touched.
     *
     * @throws IllegalArgumentException if the test code is not one of 1, 2, 3 or 4
     */
    private static IndependenceTest getDiscreteTest(DiscreteDataSet discreteDataSet, IndTestParams indTestParams) {
        IndependenceTest baseTest;
        switch (indTestParams.getTest()) {
            case 1:
                baseTest = new IndependenceMap((DataSet) discreteDataSet, (IndependenceTest) new InformationWEntropy(discreteDataSet, new EntropyEstimatorCache(new KdeOptEntropyEstimator(discreteDataSet)), (float) indTestParams.getAlpha()));
                break;
            case 2:
                baseTest = new IndTestXSquare2(discreteDataSet, indTestParams.getAlpha());
                break;
            case 3:
                baseTest = new IndTestGSquare(discreteDataSet, indTestParams.getAlpha());
                break;
            case 4:
                baseTest = new IndTestGSquare2(discreteDataSet, indTestParams.getAlpha());
                break;
            default:
                throw new IllegalArgumentException("Invalid independence test parameter (" + IndTestParams.test2String(indTestParams.getTest()) + ") for discrete data");
        }
        return new IndTestWithEquivalence(baseTest);
    }

    /**
     * Builds the time-series test, taking alpha and the number of lags from time-series
     * parameters that match the data (created fresh when {@code indTestParams} does not match).
     */
    private static IndependenceTest getTimeSeriesTest(TimeSeriesData timeSeriesData, IndTestParams indTestParams) {
        TimeSeriesIndTestParams timeSeriesParams = getTimeSeriesIndTestParams(timeSeriesData, indTestParams);
        IndTestTimeSeries timeSeriesTest = new IndTestTimeSeries(timeSeriesData.getData(), Arrays.asList(timeSeriesData.getVariableNames()));
        timeSeriesTest.setAlpha(timeSeriesParams.getAlpha());
        timeSeriesTest.setNumLags(timeSeriesParams.getNumLags());
        return timeSeriesTest;
    }

    /**
     * Tells whether {@code list} covers {@code list2}: it must hold at least as many elements as
     * {@code list2} and contain every element of {@code list2}.
     *
     * @param list the candidate superset (here: the variables of an independence map)
     * @param list2 the elements that must all be present (here: the variables of a data set)
     * @return {@code true} if {@code list} is at least as long as {@code list2} and contains all
     *     of its elements
     */
    protected static boolean variablesListMatch(List<?> list, List<?> list2) {
        if (list.size() < list2.size()) {
            return false;
        }
        for (Object element : list2) {
            if (!list.contains(element)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Makes sure {@code searchParams} holds test parameters suited to the data source: default
     * {@link IndTestParams} for data sets and matrices when none are set, time-series parameters
     * for {@code TimeSeriesData}, and nothing for a {@code Graph}.
     *
     * @param obj the data source
     * @param searchParams the search parameters to update
     * @throws NullPointerException if either argument is {@code null}
     * @throws IllegalStateException if the type of {@code obj} is not supported
     */
    public static void setIndTestParams(Object obj, SearchParams searchParams) {
        if (obj == null) {
            throw new NullPointerException("Data source must not be null.");
        }
        if (searchParams == null) {
            throw new NullPointerException("SearchParams must not be null.");
        }
        if (obj instanceof ContinuousDataSet || obj instanceof DiscreteDataSet || obj instanceof MixedDataSet) {
            getBasicIndTestParams(searchParams);
            return;
        }
        if (obj instanceof Graph) {
            // A graph-based test takes no parameters.
            return;
        }
        // CorrAndFittMatrix is accepted by getTest with plain IndTestParams, so it is accepted
        // here as well.
        if (obj instanceof CovarianceMatrix || obj instanceof CorrelationMatrix || obj instanceof CorrAndFittMatrix) {
            getBasicIndTestParams(searchParams);
            return;
        }
        if (obj instanceof TimeSeriesData) {
            getTimeSeriesIndTestParams((TimeSeriesData) obj, searchParams);
            return;
        }
        throw new IllegalStateException("Data source must be a ContinuousDataSet, a DiscreteDataSet, a MixedDataSet, a Graph, a CovarianceMatrix, a CorrelationMatrix, a CorrAndFittMatrix, or a TimeSeriesData, it is a: " + obj.getClass());
    }

    /**
     * Returns the test parameters held by {@code searchParams}, first creating and storing a
     * default {@link IndTestParams} if none are set.
     */
    private static IndTestParams getBasicIndTestParams(SearchParams searchParams) {
        IndTestParams indTestParams = searchParams.getIndTestParams();
        if (indTestParams == null) {
            indTestParams = new IndTestParams();
            searchParams.setIndTestParams(indTestParams);
        }
        return indTestParams;
    }

    /**
     * Returns time-series test parameters matching {@code timeSeriesData}. The parameters held
     * by {@code searchParams} are returned if they are time-series parameters with the same
     * number of time points; otherwise new ones are created and stored in {@code searchParams}.
     */
    private static TimeSeriesIndTestParams getTimeSeriesIndTestParams(TimeSeriesData timeSeriesData, SearchParams searchParams) {
        IndTestParams indTestParams = searchParams.getIndTestParams();
        // instanceof is false for null, so no separate null check is needed.
        if (!(indTestParams instanceof TimeSeriesIndTestParams) || getOldNumTimePoints(indTestParams) != timeSeriesData.getNumTimePoints()) {
            indTestParams = new TimeSeriesIndTestParams(timeSeriesData);
            searchParams.setIndTestParams(indTestParams);
        }
        return (TimeSeriesIndTestParams) indTestParams;
    }

    /**
     * Returns {@code indTestParams} if they are time-series parameters with the same number of
     * time points as {@code timeSeriesData}; otherwise returns newly created ones.
     */
    private static TimeSeriesIndTestParams getTimeSeriesIndTestParams(TimeSeriesData timeSeriesData, IndTestParams indTestParams) {
        // instanceof is false for null, so no separate null check is needed.
        if (!(indTestParams instanceof TimeSeriesIndTestParams) || getOldNumTimePoints(indTestParams) != timeSeriesData.getNumTimePoints()) {
            indTestParams = new TimeSeriesIndTestParams(timeSeriesData);
        }
        return (TimeSeriesIndTestParams) indTestParams;
    }

    /**
     * Returns the number of time points recorded in {@code indTestParams}, which must be a
     * {@link TimeSeriesIndTestParams}.
     */
    private static int getOldNumTimePoints(IndTestParams indTestParams) {
        return ((TimeSeriesIndTestParams) indTestParams).getNumTimePoints();
    }
}
