package be.ac.vub.ir.statistics;

import be.ac.vub.ir.data.distribution.DiscretizedDistribution;
import be.ac.vub.ir.data.distribution.MultiVariateDistribution;
import be.ac.vub.ir.data.distribution.UniVariateDistribution;
import be.ac.vub.ir.statistics.estimators.KDE1D;
import be.ac.vub.ir.statistics.estimators.KdeEntropyEstimator;
import be.ac.vub.ir.statistics.estimators.KdeLimitedMemoryEntropyEstimator;
import be.ac.vub.ir.statistics.estimators.KdeOptEntropyEstimator;
import be.ac.vub.ir.statistics.estimators.KdeParams;
import be.ac.vub.ir.util.StatUtils;
import edu.cmu.tetrad.data.Column;
import edu.cmu.tetrad.data.ColumnExt;
import edu.cmu.tetrad.data.ContinuousColumn;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DataUtils;
import edu.cmu.tetrad.data.DiscreteColumn;
import edu.cmu.tetrad.data.DoubleColumn;
import edu.cmu.tetrad.data.FloatColumn;
import edu.cmu.tetrad.data.MixedDataSet;
import java.util.Random;

/* loaded from: input_file:be/ac/vub/ir/statistics/StatUtilsExt.class */
/**
 * Statistical helpers layered on top of {@link StatUtils}.
 *
 * <p>Provides:
 * <ul>
 *   <li>a Kullback-Leibler style distance between a discretized and a multivariate distribution;</li>
 *   <li>random sampling from distributions;</li>
 *   <li>log conversion of data sets;</li>
 *   <li>get/set access to the KDE parameters of an entropy estimator;</li>
 *   <li>column-level mean and standard deviation.</li>
 * </ul>
 */
public class StatUtilsExt extends StatUtils {
    /**
     * Shared generator used by the {@code randomPoint} methods, created lazily on first use.
     *
     * <p>The lazy initialisation is not synchronized. The field is kept package-visible and
     * non-final as in the original. NOTE(review): presumably other package code may replace it,
     * e.g. with a seeded instance; confirm before tightening.
     */
    static Random mRandom = null;

    /** Value substituted for a zero probability inside the logarithm so the terms stay finite. */
    private static final double ZERO_PROBABILITY_SUBSTITUTE = 1.0E-9d;

    /**
     * Computes a Kullback-Leibler style divergence of {@code multiVariateDistribution} from
     * {@code discretizedDistribution}, summed over the bins of the discretized distribution.
     *
     * <p>Each bin contributes as follows, where {@code p} is the discretized probability scaled
     * by its {@code norm()} and {@code q} is the other distribution's probability for that bin:
     * <ul>
     *   <li>{@code p != 0}: {@code p * log(p / q)}, using {@code 1e-9} in place of a zero
     *       {@code q};</li>
     *   <li>{@code p == 0} and {@code q != 0}: {@code q * log(q / 1e-9)}. Unlike textbook KL,
     *       mass present only in the second distribution is therefore penalised.</li>
     * </ul>
     * If both distributions are discretized on equal discretization properties, {@code q} is
     * read by bin index. Otherwise {@code q} is evaluated at the point of each bin.
     *
     * <p>The sums of both probability columns are printed to stdout. A NaN result is reported on
     * stderr once.
     *
     * @param discretizedDistribution reference distribution whose bins are iterated
     * @param multiVariateDistribution distribution compared against the reference
     * @return the accumulated divergence
     * @throws IllegalArgumentException if the two distributions differ in dimension count
     */
    public static float kullbackLeiblerDistance(DiscretizedDistribution discretizedDistribution, MultiVariateDistribution multiVariateDistribution) {
        if (discretizedDistribution.dimCount() != multiVariateDistribution.dimCount()) {
            throw new IllegalArgumentException("Dimensions of both distributions should be the same: " + discretizedDistribution.dimCount() + " <> " + multiVariateDistribution.dimCount());
        }
        // Loop-invariant: decide once whether the second distribution can be read by bin index.
        final DiscretizedDistribution sameGrid;
        if (multiVariateDistribution instanceof DiscretizedDistribution
                && discretizedDistribution.discretizationProps().equals(
                        ((DiscretizedDistribution) multiVariateDistribution).discretizationProps())) {
            sameGrid = (DiscretizedDistribution) multiVariateDistribution;
        } else {
            sameGrid = null;
        }

        // Accumulate in double to avoid the rounding of a float running sum.
        double distance = 0.0d;
        float sumP = 0.0f;
        float sumQ = 0.0f;
        boolean nanReported = false;
        for (int i = 0; i < discretizedDistribution.nbrProbabilities(); i++) {
            float p = discretizedDistribution.probability(i) * discretizedDistribution.norm();
            float q = (sameGrid != null)
                    ? sameGrid.probability(i)
                    : multiVariateDistribution.probability(discretizedDistribution.indexToPoint(discretizedDistribution.convertToCoordinates(i)));
            sumP += p;
            sumQ += q;
            if (p != 0.0f) {
                if (q == 0.0f) {
                    distance += p * Math.log(p / ZERO_PROBABILITY_SUBSTITUTE);
                } else {
                    distance += p * Math.log(p / q);
                }
            } else if (q != 0.0f) {
                distance += q * Math.log(q / ZERO_PROBABILITY_SUBSTITUTE);
            }
            // Once NaN the sum stays NaN, so reporting more than once is just noise.
            if (!nanReported && Double.isNaN(distance)) {
                System.err.println("dist got NaN!!");
                nanReported = true;
            }
        }
        System.out.println("KL distance, sum p1 = " + sumP + ", sum p2 = " + sumQ);
        return (float) distance;
    }

    /**
     * Estimates the differential entropy of a single column with a 1-D kernel density estimator.
     *
     * @param columnExt the sample column
     * @return the value of {@code new KDE1D(columnExt).differentialEntropy()}
     */
    public static float entropy1D(ColumnExt columnExt) {
        return new KDE1D(columnExt).differentialEntropy();
    }

    /** Returns {@link #mRandom}, creating it on first use (unsynchronized). */
    private static Random sharedRandom() {
        if (mRandom == null) {
            mRandom = new Random();
        }
        return mRandom;
    }

    /**
     * Draws one value from a univariate distribution.
     *
     * <p>A uniform float in [0, 1) is passed to {@code accumulatedProbability1D}.
     * NOTE(review): this presumably acts as an inverse-CDF lookup; confirm in
     * {@code UniVariateDistribution}.
     *
     * @param uniVariateDistribution the distribution to sample from
     * @return the sampled value
     */
    public static float randomPoint(UniVariateDistribution uniVariateDistribution) {
        return uniVariateDistribution.accumulatedProbability1D(sharedRandom().nextFloat());
    }

    /**
     * Draws {@code i} independent values from a univariate distribution.
     *
     * @param uniVariateDistribution the distribution to sample from
     * @param i number of samples
     * @return a {@link FloatColumn} with empty name strings holding the samples
     */
    public static ColumnExt sample(UniVariateDistribution uniVariateDistribution, int i) {
        float[] values = new float[i];
        for (int k = 0; k < i; k++) {
            values[k] = randomPoint(uniVariateDistribution);
        }
        return new FloatColumn("", "", values);
    }

    /**
     * Draws one point from a multivariate distribution.
     *
     * <p>A uniform float in [0, 1) is passed to {@code accProbabilityPoint}.
     *
     * @param multiVariateDistribution the distribution to sample from
     * @return the sampled point, one coordinate per dimension
     */
    public static float[] randomPoint(MultiVariateDistribution multiVariateDistribution) {
        return multiVariateDistribution.accProbabilityPoint(sharedRandom().nextFloat());
    }

    /**
     * Draws {@code i} independent points from a multivariate distribution.
     *
     * @param multiVariateDistribution the distribution to sample from
     * @param i number of samples
     * @return a data set with one {@link FloatColumn} per dimension, each holding {@code i} values
     */
    public static DataSet sample(MultiVariateDistribution multiVariateDistribution, int i) {
        DataSet dataSet = new DataSet();
        int dimCount = multiVariateDistribution.dimCount();
        // Column-major buffer: columns[d][s] is coordinate d of sample s.
        float[][] columns = new float[dimCount][i];
        for (int s = 0; s < i; s++) {
            float[] point = randomPoint(multiVariateDistribution);
            for (int d = 0; d < dimCount; d++) {
                columns[d][s] = point[d];
            }
        }
        for (int d = 0; d < dimCount; d++) {
            dataSet.addColumn(new FloatColumn("", "", columns[d]));
        }
        return dataSet;
    }

    /**
     * Samples {@code i} points into a data set.
     *
     * <p>If the distribution is a {@link UniVariateDistribution}, the univariate sampler is used.
     * Its column is added with {@code DataSet.add}. Otherwise this delegates to
     * {@link #sample(MultiVariateDistribution, int)}.
     *
     * @param multiVariateDistribution the distribution to sample from
     * @param i number of samples
     * @return the sampled data set
     */
    public static DataSet createSampleData(MultiVariateDistribution multiVariateDistribution, int i) {
        if (multiVariateDistribution instanceof UniVariateDistribution) {
            DataSet sample = new DataSet();
            sample.add(sample((UniVariateDistribution) multiVariateDistribution, i));
            return sample;
        }
        return sample(multiVariateDistribution, i);
    }

    /**
     * Builds a {@link MixedDataSet} in which every non-discrete column is replaced by its natural
     * logarithm.
     *
     * <p>Behaviour per column:
     * <ul>
     *   <li>{@link DiscreteColumn}s are added unchanged.</li>
     *   <li>Other columns are converted into a new {@link DoubleColumn} for the same variable.
     *       Zero values map to 0 instead of -Infinity; when any occur, a warning is printed to
     *       stderr.</li>
     *   <li>Negative values are passed to {@link Math#log}, which yields NaN.</li>
     * </ul>
     *
     * @param dataSet the source data set; it is not modified
     * @return the log-converted data set
     */
    public static DataSet logConvert(DataSet dataSet) {
        MixedDataSet mixedDataSet = new MixedDataSet();
        for (int i = 0; i < dataSet.getNumColumns(); i++) {
            ColumnExt columnExt = DataUtils.toColumnExt(dataSet.getColumn(i));
            if (columnExt instanceof DiscreteColumn) {
                mixedDataSet.addColumn(columnExt);
                continue;
            }
            DoubleColumn doubleColumn = new DoubleColumn((ContinuousVariable) columnExt.getVariable());
            int zeroCount = 0;
            for (int row = 0; row < columnExt.size(); row++) {
                if (columnExt.at(row) == 0.0d) {
                    zeroCount++;
                    doubleColumn.add(0.0d);
                } else {
                    doubleColumn.add(Math.log(columnExt.at(row)));
                }
            }
            if (zeroCount > 0) {
                System.err.println("While logconverting column '" + columnExt.getVariable() + "' of dataset '" + dataSet + "', got " + zeroCount + " zero values! Guess -Infinity is not ok, so took 0");
            }
            mixedDataSet.addColumn(doubleColumn);
        }
        return mixedDataSet;
    }

    /**
     * Returns the estimator of {@code informationWEntropy}, unwrapping one level of
     * {@link EntropyEstimatorCache} if present.
     */
    private static EntropyEstimator underlyingEstimator(InformationWEntropy informationWEntropy) {
        EntropyEstimator entropyEstimator = informationWEntropy.entropyEstimator();
        if (entropyEstimator instanceof EntropyEstimatorCache) {
            entropyEstimator = ((EntropyEstimatorCache) entropyEstimator).entropyEstimator();
        }
        return entropyEstimator;
    }

    /**
     * Returns the KDE parameters of the estimator used by {@code informationWEntropy}.
     *
     * <p>One level of cache is unwrapped first.
     *
     * @param informationWEntropy holder of the entropy estimator
     * @return the estimator's {@link KdeParams}, or {@code null} (with a message on stderr) if the
     *     estimator is not one of the supported KDE estimators
     */
    public static KdeParams getKDEParams(InformationWEntropy informationWEntropy) {
        EntropyEstimator entropyEstimator = underlyingEstimator(informationWEntropy);
        if (entropyEstimator instanceof KdeOptEntropyEstimator) {
            return ((KdeOptEntropyEstimator) entropyEstimator).kdeParams();
        }
        if (entropyEstimator instanceof KdeEntropyEstimator) {
            return ((KdeEntropyEstimator) entropyEstimator).kdeParams();
        }
        if (entropyEstimator instanceof KdeLimitedMemoryEntropyEstimator) {
            return ((KdeLimitedMemoryEntropyEstimator) entropyEstimator).kdeParams();
        }
        System.err.println("Unknown EntropyEstimator (" + entropyEstimator + "), getting KDEParams failed");
        return null;
    }

    /**
     * Sets the KDE parameters of the estimator used by {@code informationWEntropy}.
     *
     * <p>One level of cache is unwrapped first. For an unsupported estimator, a message is
     * printed to stderr and nothing is changed.
     *
     * @param informationWEntropy holder of the entropy estimator
     * @param kdeParams the parameters to install
     */
    public static void setKDEParams(InformationWEntropy informationWEntropy, KdeParams kdeParams) {
        EntropyEstimator entropyEstimator = underlyingEstimator(informationWEntropy);
        if (entropyEstimator instanceof KdeOptEntropyEstimator) {
            ((KdeOptEntropyEstimator) entropyEstimator).setKdeParams(kdeParams);
        } else if (entropyEstimator instanceof KdeEntropyEstimator) {
            ((KdeEntropyEstimator) entropyEstimator).setKdeParams(kdeParams);
        } else if (entropyEstimator instanceof KdeLimitedMemoryEntropyEstimator) {
            ((KdeLimitedMemoryEntropyEstimator) entropyEstimator).setKdeParams(kdeParams);
        } else {
            System.err.println("Unknown EntropyEstimator (" + entropyEstimator + "), setting KDEParams failed");
        }
    }

    /**
     * Computes the arithmetic mean of a column.
     *
     * <p>Discrete, continuous and float columns use their raw arrays. Any other {@link ColumnExt},
     * such as the {@link DoubleColumn} produced by {@link #logConvert}, is read element-wise.
     * This mirrors {@link #standardDeviation(Column)}.
     *
     * @param column the column
     * @return the mean
     * @throws IllegalArgumentException if the column type is not supported
     */
    public static double mean(Column column) {
        if (column instanceof DiscreteColumn) {
            return mean((int[]) column.getRawData(), column.size());
        }
        if (column instanceof ContinuousColumn) {
            return mean((double[]) column.getRawData(), column.size());
        }
        if (column instanceof FloatColumn) {
            return mean((float[]) column.getRawData(), column.size());
        }
        if (column instanceof ColumnExt) {
            return columnExtMean((ColumnExt) column);
        }
        throw new IllegalArgumentException("Column type " + column.getClass() + " not supported for calculating mean.");
    }

    /**
     * Computes the standard deviation of a column.
     *
     * <p>Discrete, continuous and float columns are delegated to the {@link StatUtils} array
     * overloads. Any other {@link ColumnExt} gets the population standard deviation (divided by
     * n).
     *
     * @param column the column
     * @return the standard deviation
     * @throws IllegalArgumentException if the column type is not supported
     */
    public static double standardDeviation(Column column) {
        if (column instanceof DiscreteColumn) {
            return standardDeviation((int[]) column.getRawData(), column.size());
        }
        if (column instanceof ContinuousColumn) {
            return standardDeviation((double[]) column.getRawData(), column.size());
        }
        if (column instanceof FloatColumn) {
            return standardDeviation((float[]) column.getRawData(), column.size());
        }
        if (column instanceof ColumnExt) {
            return columnExtStandardDeviation((ColumnExt) column);
        }
        throw new IllegalArgumentException("Column type " + column.getClass() + " not supported for calculating standard deviation.");
    }

    /** Arithmetic mean of a generic {@link ColumnExt}, accumulated in double; NaN when empty. */
    private static double columnExtMean(ColumnExt columnExt) {
        double sum = 0.0d;
        for (int i = 0; i < columnExt.size(); i++) {
            sum += columnExt.at(i);
        }
        return sum / columnExt.size();
    }

    /** Population standard deviation (divides by n) of a generic {@link ColumnExt}; NaN when empty. */
    private static double columnExtStandardDeviation(ColumnExt columnExt) {
        double mean = columnExtMean(columnExt);
        double sumOfSquares = 0.0d;
        for (int i = 0; i < columnExt.size(); i++) {
            double deviation = columnExt.at(i) - mean;
            sumOfSquares += deviation * deviation;
        }
        return Math.sqrt(sumOfSquares / columnExt.size());
    }

    /**
     * Computes {@code rXY} of two columns after converting both to double arrays.
     *
     * <p>The result is that of {@code StatUtils.rXY(double[], double[])}.
     * NOTE(review): presumably Pearson correlation; confirm in {@link StatUtils}.
     *
     * @param columnExt first column
     * @param columnExt2 second column
     * @return the value of {@code rXY} on the two converted arrays
     */
    public static double rXY(ColumnExt columnExt, ColumnExt columnExt2) {
        return rXY(DataUtils.toDoubleArray(columnExt), DataUtils.toDoubleArray(columnExt2));
    }
}
