package be.ac.vub.ir.statistics;

import be.ac.vub.ir.util.ArrayList2;
import be.ac.vub.ir.util.OneObjectList;
import be.ac.vub.ir.util.StatUtils;
import edu.cmu.tetrad.data.Column;
import edu.cmu.tetrad.data.DataLoaders;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteDataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.data.IntColumn;
import edu.cmu.tetrad.data.Variable;
import edu.cmu.tetrad.ind.IndependenceTest;
import flanagan.math.Fmath;
import java.io.File;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:be/ac/vub/ir/statistics/IndependenceDiscrete.class */
/**
 * Independence test for discrete data based on (conditional) mutual information.
 *
 * <p>Probabilities are estimated from the empirical frequencies of the columns of
 * {@link #mDataSet}; logarithms are taken with {@code StatUtils.log2}. Two variables are
 * reported independent when their (conditional) mutual information is below
 * {@link #mDependencyThreshold} (0.2 by default).
 *
 * <p>Sets of variables (multi-member target lists and non-empty conditioning sets) are
 * encoded as bit patterns in which a variable contributes a 1-bit when its value is
 * {@code > 0}. Therefore every variable in such a set must have at most two categories;
 * otherwise an {@link UnsupportedOperationException} is thrown.
 */
public class IndependenceDiscrete implements IndependenceTest, MutualInformation {
    private static final long serialVersionUID = 1;

    /** Message used when a variable set cannot be bit-encoded because a member is not binary. */
    private static final String BINARY_ONLY_MESSAGE = "Test implmeneted for binary variables only. Sorry.";

    /** Data the test runs on; its columns are cast to {@link IntColumn}. */
    protected DiscreteDataSet mDataSet;
    /** (Conditional) mutual information below this value counts as independence. */
    protected float mDependencyThreshold;
    /** Dependence value computed by the most recent {@link #isIndependent} call. */
    protected float mDependencyValue;

    /**
     * Creates a test with the default dependency threshold of 0.2.
     *
     * @param discreteDataSet data to test on
     */
    public IndependenceDiscrete(DiscreteDataSet discreteDataSet) {
        this(discreteDataSet, 0.2f);
    }

    /**
     * Creates a test with an explicit dependency threshold.
     *
     * @param discreteDataSet data to test on
     * @param f               mutual-information threshold below which variables are independent
     */
    public IndependenceDiscrete(DiscreteDataSet discreteDataSet, float f) {
        this.mDataSet = discreteDataSet;
        this.mDependencyThreshold = f;
    }

    /** Returns I(variable; variable2) estimated from the data set. */
    @Override // be.ac.vub.ir.statistics.MutualInformation
    public float mutualInfo(Variable variable, Variable variable2) {
        return mutualInformation(columnOf(variable), columnOf(variable2));
    }

    /** Returns I(variable; list), treating {@code list} as one joint (bit-encoded) variable. */
    @Override // be.ac.vub.ir.statistics.MutualInformation
    public float mutualInfo(Variable variable, List<Variable> list) {
        return mutualInformation(columnOf(variable), columnsOf(list));
    }

    /** Returns I(variable; variable2 | list); an empty or null {@code list} gives plain MI. */
    @Override // be.ac.vub.ir.statistics.MutualInformation
    public float condMutualInfo(Variable variable, Variable variable2, List<Variable> list) {
        return condMutualInformation(columnOf(variable), columnOf(variable2), columnsOf(list));
    }

    /** Returns I(variable; list | list2), treating both lists as joint (bit-encoded) variables. */
    @Override // be.ac.vub.ir.statistics.MutualInformation
    public float condMutualInfo(Variable variable, List<Variable> list, List<Variable> list2) {
        return condMutualInformation(columnOf(variable), columnsOf(list), columnsOf(list2));
    }

    /** Returns I(variable; variable2 | variable3). */
    @Override // be.ac.vub.ir.statistics.MutualInformation
    public float condMutualInfo(Variable variable, Variable variable2, Variable variable3) {
        return condMutualInfo(variable, variable2, new OneObjectList(variable3));
    }

    /** Returns I(variable; list | variable2). */
    @Override // be.ac.vub.ir.statistics.MutualInformation
    public float condMutualInfo(Variable variable, List<Variable> list, Variable variable2) {
        return condMutualInfo(variable, list, new OneObjectList(variable2));
    }

    /**
     * Estimates the mutual information between two discrete columns.
     *
     * @param intColumn  first column; its values index the categories of its {@link DiscreteVariable}
     * @param intColumn2 second column, read over the same rows as {@code intColumn}
     * @return the estimated mutual information, or 0 for an empty column
     */
    public static float mutualInformation(IntColumn intColumn, IntColumn intColumn2) {
        int categories = ((DiscreteVariable) intColumn.getVariable()).getNumCategories();
        int categories2 = ((DiscreteVariable) intColumn2.getVariable()).getNumCategories();
        int rows = intColumn.size();
        int[][] counts = new int[categories][categories2];
        for (int row = 0; row < rows; row++) {
            counts[intColumn.atI(row)][intColumn2.atI(row)]++;
        }
        return mutualInformationFromCounts(counts, rows);
    }

    /**
     * Estimates the mutual information between a column and a set of columns taken jointly.
     *
     * @param intColumn target column
     * @param list      columns forming the joint variable; with more than one member they must all be binary
     * @return the estimated mutual information; 0 for a null or empty {@code list}
     * @throws UnsupportedOperationException if {@code list} has several members and one is not binary
     */
    public static float mutualInformation(IntColumn intColumn, List<IntColumn> list) {
        if (list == null || list.isEmpty()) {
            return 0.0f; // I(X; {}) = 0
        }
        if (list.size() == 1) {
            return mutualInformation(intColumn, list.get(0));
        }
        requireBinary(list);
        int categories = ((DiscreteVariable) intColumn.getVariable()).getNumCategories();
        int rows = intColumn.size();
        int[][] counts = new int[categories][binaryStateCount(list.size())];
        for (int row = 0; row < rows; row++) {
            counts[intColumn.atI(row)][binaryIndex(list, row)]++;
        }
        return mutualInformationFromCounts(counts, rows);
    }

    /**
     * Estimates I(X; Y | Z) where Z is the (bit-encoded) joint of the conditioning columns.
     *
     * @param intColumn  column X
     * @param intColumn2 column Y
     * @param list       conditioning columns Z; null or empty reduces to plain mutual information
     * @return the estimated conditional mutual information
     * @throws UnsupportedOperationException if a conditioning column is not binary
     */
    public static float condMutualInformation(IntColumn intColumn, IntColumn intColumn2, List<IntColumn> list) {
        if (list == null || list.isEmpty()) {
            return mutualInformation(intColumn, intColumn2);
        }
        requireBinary(list);
        int categories = ((DiscreteVariable) intColumn.getVariable()).getNumCategories();
        int categories2 = ((DiscreteVariable) intColumn2.getVariable()).getNumCategories();
        int states = binaryStateCount(list.size());
        int rows = intColumn.size();
        int[][][] counts = new int[categories][categories2][states];
        int[] condCounts = new int[states];
        for (int row = 0; row < rows; row++) {
            int z = binaryIndex(list, row);
            counts[intColumn.atI(row)][intColumn2.atI(row)][z]++;
            condCounts[z]++;
        }
        return condMutualInformationFromCounts(counts, condCounts, rows);
    }

    /**
     * Estimates I(X; Y | Z) where Y and Z are (bit-encoded) joints of sets of columns.
     *
     * @param intColumn column X
     * @param list      columns forming Y; with more than one member they must all be binary
     * @param list2     conditioning columns Z; null or empty reduces to I(X; Y)
     * @return the estimated conditional mutual information; 0 for a null or empty {@code list}
     * @throws UnsupportedOperationException if a column of a bit-encoded set is not binary
     */
    public static float condMutualInformation(IntColumn intColumn, List<IntColumn> list, List<IntColumn> list2) {
        if (list == null || list.isEmpty()) {
            return 0.0f; // I(X; {} | Z) = 0
        }
        if (list.size() == 1) {
            return condMutualInformation(intColumn, list.get(0), list2);
        }
        if (list2 == null || list2.isEmpty()) {
            return mutualInformation(intColumn, list);
        }
        requireBinary(list);
        // The conditioning set is bit-encoded too, so it must be binary as well.
        requireBinary(list2);
        int categories = ((DiscreteVariable) intColumn.getVariable()).getNumCategories();
        int statesY = binaryStateCount(list.size());
        int statesZ = binaryStateCount(list2.size());
        int rows = intColumn.size();
        int[][][] counts = new int[categories][statesY][statesZ];
        int[] condCounts = new int[statesZ];
        for (int row = 0; row < rows; row++) {
            int z = binaryIndex(list2, row);
            counts[intColumn.atI(row)][binaryIndex(list, row)][z]++;
            condCounts[z]++;
        }
        return condMutualInformationFromCounts(counts, condCounts, rows);
    }

    /** Throws if any column's variable has more than two categories. */
    private static void requireBinary(List<IntColumn> columns) {
        for (IntColumn column : columns) {
            if (((DiscreteVariable) column.getVariable()).getNumCategories() > 2) {
                throw new UnsupportedOperationException(BINARY_ONLY_MESSAGE);
            }
        }
    }

    /** Returns 2^numVariables, rejecting sizes whose state count would overflow an int. */
    private static int binaryStateCount(int numVariables) {
        if (numVariables >= Integer.SIZE - 1) {
            throw new IllegalArgumentException("Cannot bit-encode " + numVariables + " binary variables");
        }
        return 1 << numVariables;
    }

    /** Encodes the values of {@code columns} at {@code row} as a bit pattern (bit i set iff column i > 0). */
    private static int binaryIndex(List<IntColumn> columns, int row) {
        int index = 0;
        for (int bit = 0; bit < columns.size(); bit++) {
            if (columns.get(bit).atI(row) > 0) {
                index |= 1 << bit;
            }
        }
        return index;
    }

    /**
     * Computes mutual information from a two-way table of joint counts.
     *
     * @param counts {@code counts[a][b]} = number of rows with first value a and second value b
     * @param total  number of rows tallied into {@code counts}
     * @return the mutual information, or 0 when {@code total} is 0
     */
    private static float mutualInformationFromCounts(int[][] counts, int total) {
        if (total == 0) {
            return 0.0f;
        }
        int numA = counts.length;
        int numB = numA == 0 ? 0 : counts[0].length;
        float[][] joint = new float[numA][numB];
        float[] marginalA = new float[numA];
        float[] marginalB = new float[numB];
        for (int a = 0; a < numA; a++) {
            for (int b = 0; b < numB; b++) {
                // Cast before dividing: int/int would truncate every probability to 0 or 1.
                float p = (float) counts[a][b] / total;
                joint[a][b] = p;
                marginalA[a] += p;
                marginalB[b] += p;
            }
        }
        double sum = 0.0d;
        for (int a = 0; a < numA; a++) {
            for (int b = 0; b < numB; b++) {
                float independent = marginalA[a] * marginalB[b];
                if (independent > 0.0f && joint[a][b] > 0.0f) {
                    sum += joint[a][b] * StatUtils.log2(joint[a][b] / independent);
                }
            }
        }
        return (float) sum;
    }

    /**
     * Computes sum_z p(z) * sum_{a,b} p(a,b|z) log2(p(a,b|z) / (p(a|z) p(b|z))) from counts.
     *
     * @param counts     {@code counts[a][b][z]} = number of rows with values a, b and conditioning state z
     * @param condCounts {@code condCounts[z]} = number of rows in conditioning state z
     * @param total      number of rows tallied
     * @return the conditional mutual information, or 0 when {@code total} is 0
     */
    private static float condMutualInformationFromCounts(int[][][] counts, int[] condCounts, int total) {
        if (total == 0) {
            return 0.0f;
        }
        int numA = counts.length;
        int numB = numA == 0 ? 0 : counts[0].length;
        int numZ = condCounts.length;
        float[][][] jointGivenZ = new float[numA][numB][numZ];
        float[][] aGivenZ = new float[numA][numZ];
        float[][] bGivenZ = new float[numB][numZ];
        for (int a = 0; a < numA; a++) {
            for (int b = 0; b < numB; b++) {
                for (int z = 0; z < numZ; z++) {
                    // Unobserved conditioning states have p(z) = 0 and contribute nothing; skipping
                    // them also avoids the division by zero the counts would otherwise cause.
                    if (condCounts[z] == 0) {
                        continue;
                    }
                    float p = (float) counts[a][b][z] / condCounts[z];
                    jointGivenZ[a][b][z] = p;
                    aGivenZ[a][z] += p;
                    bGivenZ[b][z] += p;
                }
            }
        }
        float[] pZ = new float[numZ];
        for (int z = 0; z < numZ; z++) {
            pZ[z] = (float) condCounts[z] / total;
        }
        double sum = 0.0d;
        for (int a = 0; a < numA; a++) {
            for (int b = 0; b < numB; b++) {
                for (int z = 0; z < numZ; z++) {
                    float independent = aGivenZ[a][z] * bGivenZ[b][z];
                    float p = jointGivenZ[a][b][z];
                    if (independent > 0.0f && p > 0.0f) {
                        sum += pZ[z] * p * StatUtils.log2(p / independent);
                    }
                }
            }
        }
        return (float) sum;
    }

    /** Looks up the column of {@code variable}, failing with a descriptive message if absent. */
    private IntColumn columnOf(Variable variable) {
        return requireIntColumn(this.mDataSet.getColumn(variable), variable);
    }

    /** Looks up the columns of the given variables in order; a null list yields an empty list. */
    private List<IntColumn> columnsOf(List<?> variables) {
        List<IntColumn> columns = new ArrayList<IntColumn>();
        if (variables != null) {
            for (Object element : variables) {
                DiscreteVariable variable = (DiscreteVariable) element;
                columns.add(requireIntColumn(this.mDataSet.getColumn(variable), variable));
            }
        }
        return columns;
    }

    /** Casts a looked-up column to {@link IntColumn}, rejecting a missing (null) column. */
    private IntColumn requireIntColumn(Object column, Object variable) {
        if (column == null) {
            throw new IllegalArgumentException("Variable " + variable + " is not in dataset " + this.mDataSet.getName());
        }
        return (IntColumn) column;
    }

    /**
     * Tests independence of two variables, optionally given a conditioning set, and stores the
     * computed (conditional) mutual information in {@link #mDependencyValue}.
     *
     * @param variable  first variable
     * @param variable2 second variable
     * @param list      conditioning variables; null or empty means unconditional
     * @return true when the dependence value is below the threshold
     * @throws IllegalArgumentException if a variable is not in the data set
     */
    @Override // edu.cmu.tetrad.ind.IndependenceTest
    public boolean isIndependent(Variable variable, Variable variable2, List list) {
        IntColumn column = columnOf(variable);
        IntColumn column2 = columnOf(variable2);
        // condMutualInformation falls back to plain MI for an empty conditioning set.
        this.mDependencyValue = condMutualInformation(column, column2, columnsOf(list));
        return this.mDependencyValue < this.mDependencyThreshold;
    }

    /**
     * Returns the mutual information between the named variables (used here in place of a correlation).
     *
     * @throws IllegalArgumentException if a name is not in the data set
     */
    @Override // edu.cmu.tetrad.ind.IndependenceTest
    public double getCorr(String str, String str2) {
        IntColumn column = requireIntColumn(this.mDataSet.getColumn(str), str);
        IntColumn column2 = requireIntColumn(this.mDataSet.getColumn(str2), str2);
        return mutualInformation(column, column2);
    }

    /** Returns the dependency threshold. */
    @Override // edu.cmu.tetrad.ind.IndependenceTest
    public double getCutoff() {
        return this.mDependencyThreshold;
    }

    /** Returns the data set this test runs on. */
    @Override // edu.cmu.tetrad.ind.IndependenceTest
    public DataSet getData() {
        return this.mDataSet;
    }

    /** Returns the dependence value of the last {@link #isIndependent} call. */
    @Override // edu.cmu.tetrad.ind.IndependenceTest
    public double getDependencyStrength() {
        return this.mDependencyValue;
    }

    /**
     * Returns the dependence value of the last {@link #isIndependent} call.
     * NOTE(review): this is a mutual-information value, not a statistical p-value.
     */
    @Override // edu.cmu.tetrad.ind.IndependenceTest
    public double getPValue() {
        return this.mDependencyValue;
    }

    /** Runs {@link #isIndependent} and returns the resulting dependence value. */
    @Override // edu.cmu.tetrad.ind.IndependenceTest
    public double getRelativeStrength(Variable variable, Variable variable2, List list) {
        isIndependent(variable, variable2, list);
        return this.mDependencyValue;
    }

    /** Returns the data set's variable with the given name. */
    @Override // edu.cmu.tetrad.ind.IndependenceTest
    public Variable getVariable(String str) {
        return this.mDataSet.getVariable(str);
    }

    /** Returns the data set's variable names. */
    @Override // edu.cmu.tetrad.ind.IndependenceTest
    public List getVariableNames() {
        return this.mDataSet.getVariableNamesList();
    }

    /** Returns the data set's variables. */
    @Override // edu.cmu.tetrad.ind.IndependenceTest
    public List getVariables() {
        return this.mDataSet.getVariables();
    }

    /**
     * Creates a test over the columns of the given subset of variables, with the same threshold.
     *
     * @param list non-empty subset of this test's variables
     * @throws IllegalArgumentException if the subset is empty or contains a foreign variable
     */
    @Override // edu.cmu.tetrad.ind.IndependenceTest
    public IndependenceTest indTestSubset(List list) {
        List variables = this.mDataSet.getVariables();
        if (list.isEmpty()) {
            throw new IllegalArgumentException("Subset may not be empty.");
        }
        for (int i = 0; i < list.size(); i++) {
            if (!variables.contains(list.get(i))) {
                throw new IllegalArgumentException("All vars must be original vars");
            }
        }
        DataSet dataSet = new DataSet();
        // Iterate over the original variables so the subset keeps the data set's column order.
        for (int i = 0; i < variables.size(); i++) {
            if (list.contains(variables.get(i))) {
                dataSet.addColumn(this.mDataSet.getColumn(i));
            }
        }
        return new IndependenceDiscrete(new DiscreteDataSet(dataSet), this.mDependencyThreshold);
    }

    /**
     * Demo on the LUCAS0 data set.
     *
     * @param strArr optional first argument: path of the data file (defaults to the original location)
     */
    public static void main(String[] strArr) {
        String path = strArr.length > 0
                ? strArr[0]
                : "D:\\research\\causalityChallenge\\data\\lucas0\\lucas0_trainingset.data";
        DiscreteDataSet dataSet = (DiscreteDataSet) DataLoaders.loadDataFromFile(new File(path));
        IndependenceDiscrete independence = new IndependenceDiscrete(dataSet);
        DiscreteVariable lungCancer = (DiscreteVariable) dataSet.getVariable("LungCancer");
        DiscreteVariable anxiety = (DiscreteVariable) dataSet.getVariable("Anxiety");
        DiscreteVariable carAccident = (DiscreteVariable) dataSet.getVariable("CarAccident");
        DiscreteVariable smoking = (DiscreteVariable) dataSet.getVariable("Smoking");
        DiscreteVariable coughing = (DiscreteVariable) dataSet.getVariable("Coughing");
        System.out.println("MI(" + lungCancer + "; " + anxiety + ")=" + Fmath.truncate(independence.mutualInfo(lungCancer, anxiety), 3));
        System.out.println("MI(" + lungCancer + "; " + carAccident + ")=" + Fmath.truncate(independence.mutualInfo(lungCancer, carAccident), 3));
        ArrayList2 anxietyAndAccident = new ArrayList2(anxiety, carAccident);
        System.out.println("MI(" + lungCancer + "; " + anxietyAndAccident + ")=" + Fmath.truncate(independence.mutualInfo(lungCancer, anxietyAndAccident), 3));
        ArrayList2 smokingAndCoughing = new ArrayList2(smoking, coughing);
        System.out.println("MI(" + lungCancer + "; " + anxietyAndAccident + "|" + smoking + ")=" + Fmath.truncate(independence.condMutualInfo(lungCancer, anxietyAndAccident, new OneObjectList(smoking)), 3));
        System.out.println("MI(" + lungCancer + "; " + anxietyAndAccident + "|" + coughing + ")=" + Fmath.truncate(independence.condMutualInfo(lungCancer, anxietyAndAccident, new OneObjectList(coughing)), 3));
        System.out.println("MI(" + lungCancer + "; " + anxietyAndAccident + "|" + smokingAndCoughing + ")=" + Fmath.truncate(independence.condMutualInfo(lungCancer, anxietyAndAccident, smokingAndCoughing), 3));
    }
}
