package be.ac.vub.ir.statistics.estimators;

import be.ac.vub.ir.statistics.InformationWEntropy;
import be.ac.vub.ir.statistics.bandwidthselectors.ConstantBandwidthSelector;
import be.ac.vub.ir.statistics.bandwidthselectors.FixingBWSelectorAdapter;
import be.ac.vub.ir.statistics.bandwidthselectors.JansNearestNeighboursBandwidthSelector;
import be.ac.vub.ir.statistics.estimators.KDEtest;
import edu.cmu.tetrad.data.ContinuousDataSet;
import edu.cmu.tetrad.data.DataLoaders;
import edu.cmu.tetrad.data.DataUtils;
import edu.cmu.tetrad.data.Variable;
import edu.cmu.tetrad.ind.IndTestCorrAndFitMatrix;
import edu.cmu.tetrad.ind.IndTestCorrMatrix;
import edu.cmu.tetrad.ind.IndTestParams;
import edu.cmu.tetrad.ind.IndependenceTest;
import edu.cmu.tetradapp.util.ExecutableProgressMonitor;
import java.text.NumberFormat;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.TreeMap;

/* loaded from: input_file:be/ac/vub/ir/statistics/estimators/KDEreliabilityTests.class */
/**
 * Runs the independence tests listed in {@code KDEtest.kdeTests} under one or more
 * {@link IndTestParams} configurations and reports how often each outcome occurs.
 *
 * <p>Independence tests are cached per data file name. The cached tests are built with the
 * configuration that was active when they were created, so {@link #reset()} (or
 * {@link #resetCache()}) has to be called after {@code indTestParams} is changed.
 */
public class KDEreliabilityTests {
    static final String DATABASE = "KDE";
    static final int MIN_NBR_SAMPLES = 2;
    static final int DELTA_NBR_SAMPLES = 15;

    /** Length of the array returned by {@link #performTests()}; {@link #printStats} reads indices 1..4. */
    private static final int NBR_OUTCOME_BUCKETS = 5;

    /**
     * Configurations exercised by {@link #testIndependenceTests(boolean)}: two KDE based ones
     * (constant bandwidth, fixed nearest-neighbours bandwidth) and one with test type 0, which
     * {@link #createTest(KDEtest)} maps to the correlation-matrix tests.
     */
    static IndTestParams[] independenceTests = {new IndTestParams(new KdeParams(new ConstantBandwidthSelector())), new IndTestParams(new KdeParams(new FixingBWSelectorAdapter(new JansNearestNeighboursBandwidthSelector()))), new IndTestParams(0)};

    /** Formatter (at most two fraction digits) used when reporting unexpected test results. */
    static NumberFormat sF = NumberFormat.getNumberInstance();

    /** When set, unexpected results and data set headers are printed while testing. */
    public boolean printFalseTests = true;

    /** Configuration used to build new independence tests. */
    IndTestParams indTestParams = independenceTests[0];

    /** Cache of independence tests, keyed by the data file name of the test. */
    TreeMap<String, IndependenceTest> independentTests = new TreeMap<>();

    /** Index range {@code [start, end)} into {@code KDEtest.kdeTests} processed by {@link #performTests()}. */
    int start = 0;
    int end = KDEtest.kdeTests.length;

    /** First bandwidth factor tried by {@link #testBWF(boolean)}. */
    float startbwf = 0.25f;

    static {
        sF.setMaximumFractionDigits(2);
    }

    /** Clears the independence-test cache and resets every test in {@code KDEtest.kdeTests}. */
    void reset() {
        resetCache();
        resetTests();
    }

    /** Drops all cached independence tests. */
    void resetCache() {
        this.independentTests = new TreeMap<>();
    }

    /**
     * Returns the independence test for the data file of the given test, creating and caching
     * it on first use. On a cache hit the data of the cached test is attached to
     * {@code kDEtest}.
     *
     * @return the independence test, or {@code null} when it could not be created
     */
    IndependenceTest getTest(KDEtest kDEtest) {
        // Only non-null tests are ever stored, so a single lookup replaces containsKey + get.
        IndependenceTest cached = this.independentTests.get(kDEtest.fileName);
        if (cached != null) {
            kDEtest.setData(cached.getData());
            return cached;
        }
        IndependenceTest created = loadDataAndTest(kDEtest);
        if (created != null) {
            this.independentTests.put(kDEtest.fileName, created);
        }
        return created;
    }

    /**
     * Loads the data set of the test from its file (unless already loaded) and builds an
     * independence test on it.
     *
     * @return the independence test, or {@code null} when the data could not be loaded or
     *         {@link #createTest(KDEtest)} declined to build a test
     */
    IndependenceTest loadDataAndTest(KDEtest kDEtest) {
        if (kDEtest.dataSet == null) {
            kDEtest.setData(DataLoaders.loadDataFromGivenFile(kDEtest.fileName));
        }
        if (kDEtest.dataSet == null) {
            System.err.println("File of test not found: " + kDEtest.fileName + ", test is skipped.");
            return null;
        }
        return createTest(kDEtest);
    }

    /**
     * Builds an independence test on the (already loaded) data of {@code kDEtest} according to
     * the current {@code indTestParams}.
     *
     * <p>Test type 0 needs continuous data: other data is converted when the test allows it
     * ({@code doContinuousTest}), otherwise no test is created.
     *
     * @return the new test, or {@code null} when type 0 is selected for non-continuous data
     *         that may not be converted
     * @throws IllegalArgumentException when the configured test type is neither 0 nor 1
     */
    IndependenceTest createTest(KDEtest kDEtest) {
        if (this.indTestParams.getTest() == 0) {
            if (!(kDEtest.dataSet instanceof ContinuousDataSet)) {
                if (!kDEtest.doContinuousTest) {
                    return null;
                }
                kDEtest.setData(DataUtils.toContinuousDataSet(kDEtest.dataSet));
            }
            ContinuousDataSet continuous = (ContinuousDataSet) kDEtest.dataSet;
            if (this.indTestParams.getCorrOrder() <= 1) {
                return new IndTestCorrMatrix(continuous, this.indTestParams.getAlpha());
            }
            return new IndTestCorrAndFitMatrix(continuous, this.indTestParams.getAlpha(), this.indTestParams.getCorrOrder());
        }
        if (this.indTestParams.getTest() == 1) {
            return new InformationWEntropy(kDEtest.dataSet, (float) this.indTestParams.getAlpha(), this.indTestParams.getKdeParams());
        }
        // The original message stopped after "for dataset "; name the data file it refers to.
        throw new IllegalArgumentException("Invalid independence test parameter (" + IndTestParams.test2String(this.indTestParams.getTest()) + ") for dataset " + kDEtest.fileName);
    }

    /** Performs the test with the cached (or newly cached) independence test. */
    KDEtest.TestOutcome performTest(KDEtest kDEtest) {
        return performTest(kDEtest, getTest(kDEtest));
    }

    /**
     * Performs the test.
     *
     * @param z when {@code true} a fresh independence test is built on the current data of
     *          {@code kDEtest}, bypassing the cache; otherwise the cached test is used
     */
    KDEtest.TestOutcome performTest(KDEtest kDEtest, boolean z) {
        IndependenceTest independenceTest = z ? createTest(kDEtest) : getTest(kDEtest);
        return performTest(kDEtest, independenceTest);
    }

    /**
     * Performs the test with the given independence test and stores the result in
     * {@code kDEtest}.
     *
     * @return the outcome recorded by {@code kDEtest}, or {@code UNKNOWN} when
     *         {@code independenceTest} is {@code null} or one of the variables named by the
     *         test is missing from its data set
     */
    KDEtest.TestOutcome performTest(KDEtest kDEtest, IndependenceTest independenceTest) {
        if (independenceTest == null) {
            return KDEtest.TestOutcome.UNKNOWN;
        }
        Variable x = kDEtest.dataSet.getVariable(kDEtest.xName);
        if (x == null) {
            System.err.println("Variable " + kDEtest.xName + " is not present in dataset " + kDEtest.dataSet);
            System.err.println("Test is skipped");
            return KDEtest.TestOutcome.UNKNOWN;
        }
        Variable y = kDEtest.dataSet.getVariable(kDEtest.yName);
        if (y == null) {
            System.err.println("Variable " + kDEtest.yName + " is not present in dataset " + kDEtest.dataSet);
            System.err.println("Test is skipped");
            return KDEtest.TestOutcome.UNKNOWN;
        }
        // No conditioning names means an unconditional test; the independence test gets null.
        List conditioning = null;
        if (kDEtest.condNames != null) {
            conditioning = DataUtils.getVariables(kDEtest.dataSet.getVariables(), kDEtest.condNames);
            if (conditioning.size() != kDEtest.condNames.length) {
                System.err.println("Not all variables [" + arrayToString(kDEtest.condNames) + "] are present in dataset " + kDEtest.dataSet);
                System.err.println("Test is skipped");
                return KDEtest.TestOutcome.UNKNOWN;
            }
        }
        boolean independent = independenceTest.isIndependent(x, y, conditioning);
        kDEtest.setResult(independent, independenceTest.getDependencyStrength(), independenceTest.getCutoff(), kDEtest.dataSet.getMaxRowCount(), this.indTestParams);
        if (this.printFalseTests && !kDEtest.testOK) {
            String strength = sF.format(independenceTest.getDependencyStrength());
            String cutoff = sF.format(independenceTest.getCutoff());
            if (independent) {
                System.out.println("  Test " + kDEtest + " not dependent:   " + strength + " < " + cutoff);
            } else {
                System.out.println("  Test " + kDEtest + " not independent: " + strength + " > " + cutoff);
            }
        }
        return kDEtest.testOutcome;
    }

    /**
     * Performs the tests with index in {@code [start, end)} and tallies their outcomes.
     *
     * <p>Tests whose data cannot be loaded are counted through the {@code UNKNOWN} outcome
     * that {@link #performTest(KDEtest, IndependenceTest)} returns for a missing test.
     *
     * @return for each {@code TestOutcome} ordinal the fraction of the executed tests that
     *         ended with that outcome (all zero when no test was executed)
     */
    float[] performTests() {
        float[] fractions = new float[NBR_OUTCOME_BUCKETS];
        String currentDataSetName = "";
        System.out.println(" for tijdelijk een beetje anders...");
        int testsRun = 0;
        for (int i = this.start; i < this.end; i++) {
            KDEtest kDEtest = KDEtest.kdeTests[i];
            // Resolved once and reused below, so a missing file is reported a single time.
            IndependenceTest independenceTest = getTest(kDEtest);
            // dataSet stays null when the file could not be loaded; skip the header then.
            if (this.printFalseTests && kDEtest.dataSet != null && !currentDataSetName.equals(kDEtest.dataSet.getName())) {
                System.out.println("Data from <" + kDEtest.dataSet.getName() + "> (#" + kDEtest.dataSet.getMaxRowCount() + ")");
                currentDataSetName = kDEtest.dataSet.getName();
            }
            fractions[performTest(kDEtest, independenceTest).ordinal()] += 1.0f;
            testsRun++;
        }
        // Normalise by the number of tests actually run, not by the size of the full test list.
        if (testsRun > 0) {
            for (int i = 0; i < fractions.length; i++) {
                fractions[i] /= testsRun;
            }
        }
        return fractions;
    }

    /** Performs every test in {@code KDEtest.kdeTests} and prints its outcome. */
    void performTestsPrint() {
        for (KDEtest kDEtest : KDEtest.kdeTests) {
            System.out.println(kDEtest + "    => " + performTest(kDEtest));
        }
    }

    /**
     * Writes every test in {@code KDEtest.kdeTests} as an experiment to {@link #DATABASE}.
     *
     * @param str second argument handed to {@code KDEtest.writeExperiment}
     */
    void writeTestsToDB(String str) {
        for (KDEtest kDEtest : KDEtest.kdeTests) {
            kDEtest.writeExperiment(DATABASE, str);
        }
    }

    /** Resets every test in {@code KDEtest.kdeTests}. */
    void resetTests() {
        for (KDEtest kDEtest : KDEtest.kdeTests) {
            kDEtest.reset();
        }
    }

    /** Runs {@code check()} on every test and prints how many of them failed it. */
    void checkTests() {
        int errors = 0;
        for (KDEtest kDEtest : KDEtest.kdeTests) {
            if (!kDEtest.check()) {
                errors++;
            }
        }
        System.out.println("All tests checked: " + errors + " errors on " + KDEtest.kdeTests.length + " tests");
    }

    /**
     * Runs {@link #nbrSamplesNeeded(KDEtest, boolean)} for every test and aggregates the
     * results per sample size.
     *
     * @param z  when {@code true} the per-test result map is printed
     * @param z2 forwarded to {@code nbrSamplesNeeded}: write each run to the database
     * @return array indexed by sample size holding the fraction of runs at that size that were
     *         OK, or {@code -1} for sizes that were never run. Its length is at least
     *         {@code ExecutableProgressMonitor.ONE_SECOND} and grows when a larger sample size
     *         occurs.
     */
    float[] performTestsNbrSamples(boolean z, boolean z2) {
        // sample size -> { number of OK runs, number of runs }
        TreeMap<Integer, int[]> tally = new TreeMap<>();
        for (KDEtest kDEtest : KDEtest.kdeTests) {
            Map<Integer, Boolean> outcomes = nbrSamplesNeeded(kDEtest, z2);
            if (z) {
                System.out.println("Test <" + kDEtest + ">: " + outcomes);
            }
            for (Map.Entry<Integer, Boolean> outcome : outcomes.entrySet()) {
                int[] counts = tally.get(outcome.getKey());
                if (counts == null) {
                    counts = new int[2];
                    tally.put(outcome.getKey(), counts);
                }
                if (outcome.getValue().booleanValue()) {
                    counts[0]++;
                }
                counts[1]++;
            }
        }
        // Keep the historical array length, but make room for data sets with more rows than
        // that so the sample size can always be used as index.
        int length = ExecutableProgressMonitor.ONE_SECOND;
        if (!tally.isEmpty()) {
            length = Math.max(length, tally.lastKey().intValue() + 1);
        }
        float[] fractions = new float[length];
        for (int i = 0; i < fractions.length; i++) {
            fractions[i] = -1.0f;
        }
        for (Map.Entry<Integer, int[]> entry : tally.entrySet()) {
            int[] counts = entry.getValue();
            fractions[entry.getKey().intValue()] = (float) counts[0] / counts[1];
        }
        return fractions;
    }

    /**
     * Repeats the test on ever smaller random subsets of its data set.
     *
     * <p>After each run a random number of rows (a step between 8 and 22, halved once when
     * fewer than half of the initial rows remain) is removed. <b>The rows are removed from
     * {@code kDEtest.dataSet} itself and are not restored.</b>
     *
     * @param z when {@code true} each run is written to {@link #DATABASE} as a
     *          {@code "sampleSize"} experiment
     * @return map from the number of rows used in a run to whether that run was OK, sorted by
     *         row count; empty when the data of the test could not be loaded
     */
    public Map<Integer, Boolean> nbrSamplesNeeded(KDEtest kDEtest, boolean z) {
        TreeMap<Integer, Boolean> results = new TreeMap<>();
        loadDataAndTest(kDEtest);
        if (kDEtest.dataSet == null) {
            // loadDataAndTest already reported the missing file.
            return results;
        }
        int initialRows = kDEtest.dataSet.getMaxRowCount();
        Random random = new Random();
        int step = 22 - random.nextInt(DELTA_NBR_SAMPLES);
        boolean stepHalved = false;
        for (int remaining = initialRows; remaining > MIN_NBR_SAMPLES; remaining -= step) {
            // The data changed since the previous run, so a fresh test is built every time.
            performTest(kDEtest, true);
            results.put(Integer.valueOf(kDEtest.dataSet.getMaxRowCount()), Boolean.valueOf(kDEtest.testOK));
            if (z) {
                kDEtest.writeExperiment(DATABASE, "sampleSize");
            }
            for (int removed = 0; removed < step; removed++) {
                if (kDEtest.dataSet.getMaxRowCount() > MIN_NBR_SAMPLES) {
                    kDEtest.dataSet.removeRow(random.nextInt(kDEtest.dataSet.getMaxRowCount()));
                }
            }
            if (remaining < initialRows / 2 && !stepHalved) {
                step /= 2;
                stepHalved = true;
            }
        }
        return results;
    }

    /**
     * Runs all tests for a series of bandwidth factors, starting at {@code startbwf} and
     * staying below 15, and prints the statistics of each run. The increment is 0.25, then 0.5
     * once the factor has reached 1.5, then 1 once it exceeds 3.
     *
     * <p>Uses the KDE parameters of the current configuration (sample size is set to 200);
     * NOTE(review): presumably requires a KDE based {@code indTestParams} - confirm.
     *
     * @param z when {@code true} the tests are written to the database after each run
     */
    public void testBWF(boolean z) {
        float step = 0.25f;
        for (float bwf = this.startbwf; bwf < 15.0f; bwf += step) {
            this.indTestParams.getKdeParams().setBandwidthFactor(bwf);
            this.indTestParams.getKdeParams().setSampleSize(200);
            reset();
            System.out.print("BWF " + bwf + " :  ");
            printStats(performTests());
            if (z) {
                writeTestsToDB("BWF");
            }
            if (bwf >= 1.5d) {
                step = 0.5f;
            }
            if (bwf > 3.0f) {
                step = 1.0f;
            }
        }
    }

    /**
     * Prints, on one line, the dependency strength of a single test for a series of bandwidth
     * factors from 0.25 up to (excluding) 15. The increment is 0.25, then 0.5 once the factor
     * has reached 1.5, then 1 once it has reached 4.
     *
     * <p>Nothing is printed when the data of the test cannot be loaded.
     */
    public void testBWF(KDEtest kDEtest) {
        loadDataAndTest(kDEtest);
        if (kDEtest.dataSet == null) {
            // Missing file was already reported; no test can be built without data.
            return;
        }
        NumberFormat strengthFormat = NumberFormat.getNumberInstance();
        strengthFormat.setMaximumFractionDigits(3);
        System.out.print(kDEtest + ": ");
        float step = 0.25f;
        for (float bwf = 0.25f; bwf < 15.0f; bwf += step) {
            System.out.print(" " + bwf + " - ");
            this.indTestParams.getKdeParams().setBandwidthFactor(bwf);
            performTest(kDEtest, true);
            System.out.print(strengthFormat.format(kDEtest.dependencyStrength) + ";");
            if (bwf >= 1.5d) {
                step = 0.5f;
            }
            if (bwf >= 4.0f) {
                step = 1.0f;
            }
        }
        System.out.println();
    }

    /**
     * Runs all tests once for every configuration in {@link #independenceTests} and prints the
     * statistics per configuration. Leaves {@code indTestParams} set to the last configuration.
     *
     * @param z when {@code true} the tests are written to the database after each run
     */
    public void testIndependenceTests(boolean z) {
        for (IndTestParams params : independenceTests) {
            this.indTestParams = params;
            System.out.println("Configuration: " + params);
            System.out.println("=------------------------------------------=");
            reset();
            float[] stats = performTests();
            System.out.print(params + " =>   ");
            printStats(stats);
            if (z) {
                writeTestsToDB("Tests");
            }
            System.out.println();
        }
    }

    /** Formats a fraction as a whole percentage, truncating towards zero (0.256 gives "25%"). */
    public static String toPct(float f) {
        return (int) (f * 100.0f) + "%";
    }

    /**
     * Prints the outcome fractions produced by {@link #performTests()}.
     *
     * @param fArr fractions indexed by outcome; indices 1 to 4 are printed
     */
    public static void printStats(float[] fArr) {
        System.out.println("POSITIVE " + toPct(fArr[1]) + " vs " + toPct(fArr[2]) + " False ; NEGATIVE " + toPct(fArr[3]) + " vs " + toPct(fArr[4]) + " False (" + toPct(fArr[1] + fArr[3]) + " OK)");
    }

    /** Prints {@code index = percentage} for every non-negative entry of the array. */
    public static void printPercentages(float[] fArr) {
        for (int i = 0; i < fArr.length; i++) {
            if (fArr[i] >= 0.0f) {
                System.out.println(i + " = " + toPct(fArr[i]));
            }
        }
    }

    /** Joins the strings with {@code ", "}; an empty array yields the empty string. */
    public static String arrayToString(String[] strArr) {
        StringBuilder joined = new StringBuilder();
        for (int i = 0; i < strArr.length; i++) {
            if (i > 0) {
                joined.append(", ");
            }
            joined.append(strArr[i]);
        }
        return joined.toString();
    }

    /** Renders the flags as {@code "[ + - + ]"}: {@code +} for true, {@code -} for false. */
    public static String arrayToString(boolean[] zArr) {
        StringBuilder rendered = new StringBuilder("[ ");
        for (boolean flag : zArr) {
            rendered.append(flag ? "+ " : "- ");
        }
        return rendered.append("]").toString();
    }

    /**
     * Runs the sample-size experiment for all tests with the default configuration, writing
     * each run to the database, and prints the fraction of OK runs per sample size.
     */
    public static void main(String[] strArr) {
        KDEreliabilityTests reliabilityTests = new KDEreliabilityTests();
        System.out.println("Default configuration: " + reliabilityTests.indTestParams);
        System.out.println("=------------------------------------------=");
        reliabilityTests.printFalseTests = false;
        printPercentages(reliabilityTests.performTestsNbrSamples(true, true));
        System.out.println("");
        System.out.println("****  TESTS FINISHED   ****");
    }
}
