package bn.prob;

import bn.Distrib;
import dat.Domain;
import dat.Enumerable;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;

/* JADX WARN: Classes with same name are omitted:
  input_file:target/classes/bn/prob/DirichletDistrib.class
 */
/* loaded from: input_file:bn/prob/DirichletDistrib.class */
public class DirichletDistrib implements Distrib, Serializable {
    // Serialization version; the main* methods in this file also use it as a fixed RNG seed.
    private static final long serialVersionUID = 1;
    // Concentration parameters, one per element of the domain (updated in place by setPrior).
    private double[] alpha;
    // One GammaDistrib(alpha[k], 1.0) per parameter; sample() draws from these and normalizes.
    private GammaDistrib[] gammas;
    // Categorical domain the Dirichlet is defined over.
    private final Enumerable domain;
    // Generating component of each point from the most recent generateData call; read by main0.
    // NOTE(review): mutable static state shared by all callers.
    static int[] component = null;
    // Example DNA sequences. NOTE(review): not referenced anywhere in this file.
    static String[] dna2 = {"AAAAAAAAAACCCCCCCCCCGGGGGGGGGGTTTTTTTTTTAAAAAAACCCCCCCGGGGGGGTTTTTTTAC", "TTCGGCACGAGTCTCGGGCGGGAGAAGAAGAAGAATTAGTAAAGTGTGATCATAATGTCTGCTAGCGGCG", "GCACCGGAGATGAAGATAAGAAGCCTAATGATCAGATGGTTCATATCAATCTCAAGGTTAAGGGTCAGGA", "TGGGAATGAAGTTTTTTTCAGGATCAAACGTAGCACACAGATGCGCAAGCTCATGAATGCTTATTGTGAC", "CGGCAGTCAGTGGACATGAACTCAATTGCATTCTTATTTGATGGGCGCAGGCTTAGGGCAGAGCAAACTC", "CTGATGAGCTGGAGATGGAGGAGGGTGATGAAATCGATGCAATGCTACATCAAACTGGAGGCAGTTGCTG", "CACTTGTTTCTCTAATTTTTAACTTGGTTTATGTTAGTAGATTGTTTAGGGTAATACTTTCAACTCCCTC", "ATCTGCTCTAAGATGGGTAAATTTATGAATGTTTAGTTTTCAGTATTAGATGATGACACTACTAAATGGT", "TCAATTTTCATGGCATTTGTAAAAGTTTACTCTTAATATGGTTAAAAAGATGATGACACTACTAAATGGT"};
    // Short example DNA sequences. NOTE(review): not referenced anywhere in this file.
    static String[] dna = {"ACGTACGTACGTACGTACGTACGTACGTA", "AAGTAAGTAAGTAAGTAAGTAAGTAAGTC", "AACTAACTAACTAACTAACTAACTAACTG", "AACGAACGAACGAACGAACGAACGAACGT", "CCGTCCGTCCGTCCGTCCGTCCGTCCGTA"};

    public DirichletDistrib(Enumerable enumerable, double d) {
        this.domain = enumerable;
        this.alpha = new double[enumerable.size()];
        Arrays.fill(this.alpha, d);
        this.gammas = new GammaDistrib[this.alpha.length];
        for (int i = 0; i < this.gammas.length; i++) {
            this.gammas[i] = new GammaDistrib(this.alpha[i], 1.0d);
        }
    }

    public DirichletDistrib(EnumDistrib enumDistrib, double d) {
        this.domain = enumDistrib.getDomain();
        this.alpha = new double[this.domain.size()];
        for (int i = 0; i < this.alpha.length; i++) {
            this.alpha[i] = enumDistrib.get(i) * d;
        }
        this.gammas = new GammaDistrib[this.alpha.length];
        for (int i2 = 0; i2 < this.gammas.length; i2++) {
            this.gammas[i2] = new GammaDistrib(this.alpha[i2], 1.0d);
        }
    }

    /**
     * Creates a Dirichlet with explicit concentration parameters.
     *
     * <p>The parameters are copied. Previously the caller's array was stored by reference. A later
     * {@link #setPrior(double[])} then overwrote the caller's data. Conversely, a caller mutating
     * its array changed alpha without updating the per-component Gamma samplers.
     *
     * @param enumerable the categorical domain
     * @param dArr one concentration parameter per domain element
     * @throws RuntimeException if the number of parameters does not match the domain size
     */
    public DirichletDistrib(Enumerable enumerable, double... dArr) {
        if (enumerable.size() != dArr.length) {
            throw new RuntimeException("Invalid distribution");
        }
        this.domain = enumerable;
        this.alpha = dArr.clone();
        this.gammas = new GammaDistrib[this.alpha.length];
        for (int i = 0; i < this.gammas.length; i++) {
            this.gammas[i] = new GammaDistrib(this.alpha[i], 1.0d);
        }
    }

    public DirichletDistrib(Enumerable enumerable, double[] dArr, double d) {
        if (enumerable.size() != dArr.length) {
            throw new RuntimeException("Invalid distribution");
        }
        this.domain = enumerable;
        this.alpha = new double[dArr.length];
        for (int i = 0; i < dArr.length; i++) {
            this.alpha[i] = dArr[i] * d;
        }
        this.gammas = new GammaDistrib[this.alpha.length];
        for (int i2 = 0; i2 < this.gammas.length; i2++) {
            this.gammas[i2] = new GammaDistrib(this.alpha[i2], 1.0d);
        }
    }

    /**
     * Returns the total concentration, the sum of all alpha parameters.
     *
     * @return sum of alpha
     */
    public double getSum() {
        double total = 0.0d;
        for (int k = 0; k < alpha.length; k++) {
            total += alpha[k];
        }
        return total;
    }

    /**
     * Replaces the concentration parameters (in place) and rebuilds the Gamma samplers.
     *
     * <p>Input is validated before anything is modified. Previously a null or too-short array
     * failed part-way through the loop, leaving alpha and the samplers half-updated.
     * Extra trailing values are ignored, as before.
     *
     * @param dArr new parameters; at least as many values as this distribution has parameters
     * @throws IllegalArgumentException if {@code dArr} is null or too short
     */
    public void setPrior(double[] dArr) {
        if (dArr == null) {
            throw new IllegalArgumentException("Dirichlet prior must not be null");
        }
        if (dArr.length < this.alpha.length) {
            throw new IllegalArgumentException(
                    "Dirichlet prior has " + dArr.length + " values, expected " + this.alpha.length);
        }
        for (int i = 0; i < this.alpha.length; i++) {
            this.alpha[i] = dArr[i];
            this.gammas[i] = new GammaDistrib(this.alpha[i], 1.0d);
        }
    }

    /**
     * Returns the categorical domain this Dirichlet is defined over.
     *
     * @return the domain supplied at construction
     */
    public Domain getDomain() {
        final Domain result = domain;
        return result;
    }

    /**
     * Returns the probability density of this Dirichlet at the probability vector held by {@code obj}.
     *
     * <p>Evaluated in log space as {@code sum_k (alpha_k - 1) * log(p_k) - log B(alpha)} and
     * exponentiated once. The previous {@code log(Math.pow(p_k, alpha_k - 1))} underflowed to
     * log(0) = -Infinity (or overflowed to +Infinity) for small p_k, even when the log-density
     * itself was finite.
     *
     * @param obj an {@link EnumDistrib}; its probabilities are indexed like alpha
     * @return the density value (not a probability; may exceed 1)
     * @throws ClassCastException if {@code obj} is not an {@link EnumDistrib}
     */
    @Override // bn.Distrib
    public double get(Object obj) {
        double[] p = ((EnumDistrib) obj).get();
        double logDensity = 0.0d;
        for (int i = 0; i < p.length; i++) {
            double exponent = this.alpha[i] - 1.0d;
            // p^0 == 1 contributes nothing; skipping it also avoids 0 * log(0) = NaN.
            if (exponent != 0.0d) {
                logDensity += exponent * Math.log(p[i]);
            }
        }
        return Math.exp(logDensity - lnormalize(this.alpha));
    }

    /**
     * Draws a probability vector from this Dirichlet.
     *
     * <p>Draws one Gamma(alpha_k, 1) variate per category and normalizes them by their sum.
     *
     * @return an {@link EnumDistrib} over this distribution's domain
     */
    @Override // bn.Distrib
    public Object sample() {
        final int size = alpha.length;
        double[] draws = new double[size];
        double total = 0.0d;
        for (int k = 0; k < size; k++) {
            draws[k] = gammas[k].sample().doubleValue();
            total += draws[k];
        }
        for (int k = 0; k < size; k++) {
            draws[k] /= total;
        }
        return new EnumDistrib(domain, draws);
    }

    /**
     * Computes the multivariate Beta function B(alpha) = prod_k Gamma(alpha_k) / Gamma(sum_k alpha_k).
     *
     * <p>It is evaluated as {@code exp(lnormalize(alpha))}. The previous direct product of
     * {@code gamma()} values overflowed for moderate parameters: {100, 100} gave Infinity/Infinity
     * = NaN, although B is finite. Parameters are expected to be positive.
     *
     * @param dArr the concentration parameters
     * @return the normalizing constant of the Dirichlet density
     */
    public static final double normalize(double[] dArr) {
        return Math.exp(lnormalize(dArr));
    }

    /**
     * Computes the logarithm of the multivariate Beta function.
     *
     * <p>The result is {@code sum_k lgamma(alpha_k) - lgamma(sum_k alpha_k)}.
     *
     * @param dArr the concentration parameters
     * @return log of the Dirichlet normalizing constant
     */
    public static final double lnormalize(double[] dArr) {
        double logGammaSum = 0.0d;
        double alphaSum = 0.0d;
        for (int k = 0; k < dArr.length; k++) {
            logGammaSum += GammaDistrib.lgamma(dArr[k]);
            alphaSum += dArr[k];
        }
        return logGammaSum - GammaDistrib.lgamma(alphaSum);
    }

    /**
     * Returns the concentration parameters.
     *
     * <p>This is the internal array itself, not a copy; writing to it changes alpha without
     * updating the Gamma samplers.
     *
     * @return the live alpha array
     */
    public double[] getAlpha() {
        return alpha;
    }

    /**
     * Formats the parameters as a comma-separated list.
     *
     * <p>Each value uses the {@code %7.5f} format, in the default locale.
     *
     * @return e.g. {@code "1.00000,2.50000"}
     */
    @Override
    public String toString() {
        StringBuilder out = new StringBuilder();
        String separator = "";
        for (double a : alpha) {
            out.append(separator).append(String.format("%7.5f", Double.valueOf(a)));
            separator = ",";
        }
        return out.toString();
    }

    /**
     * Returns the Dirichlet-multinomial log-likelihood of one count vector under this distribution.
     *
     * <p>The multinomial coefficient is omitted; it does not depend on alpha.
     *
     * @param iArr counts per category (at least as long as alpha)
     * @return lgamma(A) - lgamma(A + N) + sum_k [lgamma(alpha_k + n_k) - lgamma(alpha_k)],
     *     where A = sum alpha and N = sum n
     */
    public double logLikelihood(int[] iArr) {
        double alphaSum = getSum();
        int total = 0;
        for (int k = 0; k < iArr.length; k++) {
            total += iArr[k];
        }
        double result = GammaDistrib.lgamma(alphaSum) - GammaDistrib.lgamma(alphaSum + total);
        for (int k = 0; k < alpha.length; k++) {
            double a = alpha[k];
            result += GammaDistrib.lgamma(a + iArr[k]) - GammaDistrib.lgamma(a);
        }
        return result;
    }

    /**
     * Returns the Dirichlet-multinomial log-likelihood of one count vector, with alpha = d * m.
     *
     * <p>The multinomial coefficient is omitted.
     *
     * @param iArr counts per category
     * @param dArr mean vector m (one entry per category)
     * @param d total concentration
     * @return log-likelihood of {@code iArr}
     */
    public static double logLikelihood(int[] iArr, double[] dArr, double d) {
        int total = 0;
        for (int k = 0; k < iArr.length; k++) {
            total += iArr[k];
        }
        double result = GammaDistrib.lgamma(d) - GammaDistrib.lgamma(d + total);
        for (int k = 0; k < dArr.length; k++) {
            double a = d * dArr[k];
            result += GammaDistrib.lgamma(a + iArr[k]) - GammaDistrib.lgamma(a);
        }
        return result;
    }

    /**
     * Returns the summed log-likelihood of several count vectors under this distribution.
     *
     * @param iArr one count vector per row
     * @return sum of {@link #logLikelihood(int[])} over the rows
     */
    public double logLikelihood(int[][] iArr) {
        double total = 0.0d;
        for (int row = 0; row < iArr.length; row++) {
            total += logLikelihood(iArr[row]);
        }
        return total;
    }

    /**
     * Tests whether another Dirichlet has an equal domain and matching parameters.
     *
     * <p>Parameters match when they are within 1e-15 of each other.
     *
     * <p>Now returns false for {@code null} and for a parameter vector of different length.
     * Previously these threw NullPointerException or ArrayIndexOutOfBoundsException.
     * NOTE(review): this overloads rather than overrides {@code Object.equals}, so hash-based
     * collections do not use it.
     *
     * @param dirichletDistrib the distribution to compare with (may be null)
     * @return true if domains are equal and all parameters agree within tolerance
     */
    public boolean equals(DirichletDistrib dirichletDistrib) {
        if (dirichletDistrib == this) {
            return true;
        }
        if (dirichletDistrib == null || !getDomain().equals(dirichletDistrib.getDomain())) {
            return false;
        }
        double[] other = dirichletDistrib.getAlpha();
        if (other.length != this.alpha.length) {
            return false;
        }
        for (int i = 0; i < this.alpha.length; i++) {
            if (Math.abs(this.alpha[i] - other[i]) > 1.0E-15d) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the summed Dirichlet-multinomial log-likelihood of several count vectors, with alpha = d * m.
     *
     * @param iArr one count vector per row
     * @param dArr mean vector m
     * @param d total concentration
     * @return sum over rows of {@code logLikelihood(row, dArr, d)}
     */
    public static double logLikelihood(int[][] iArr, double[] dArr, double d) {
        double total = 0.0d;
        for (int row = 0; row < iArr.length; row++) {
            total += logLikelihood(iArr[row], dArr, d);
        }
        return total;
    }

    /**
     * Returns the first derivative, with respect to the total concentration {@code d}, of the
     * weighted Dirichlet-multinomial log-likelihood (alpha = d * m).
     *
     * <p>Per row: digamma(d) - digamma(d + N) + sum_k m_k [digamma(d m_k + n_k) - digamma(d m_k)].
     *
     * @param iArr count vectors, one per row
     * @param dArr per-row weights
     * @param dArr2 mean vector m
     * @param d total concentration at which to evaluate
     * @return weighted sum of the per-row derivatives
     */
    public static double logLikelihood_1stDerivative(int[][] iArr, double[] dArr, double[] dArr2, double d) {
        double result = 0.0d;
        int row = 0;
        while (row < iArr.length) {
            int[] counts = iArr[row];
            int total = 0;
            for (int k = 0; k < counts.length; k++) {
                total += counts[k];
            }
            double term = GammaDistrib.digamma(d) - GammaDistrib.digamma(d + total);
            for (int k = 0; k < counts.length; k++) {
                double m = dArr2[k];
                double scaled = m * d;
                term += m * (GammaDistrib.digamma(scaled + counts[k]) - GammaDistrib.digamma(scaled));
            }
            result += term * dArr[row];
            row++;
        }
        return result;
    }

    /**
     * Returns the second derivative, with respect to {@code d}, of the weighted Dirichlet-multinomial
     * log-likelihood (alpha = d * m).
     *
     * <p>Per row: trigamma(d) - trigamma(d + N) + sum_k m_k^2 [trigamma(d m_k + n_k) - trigamma(d m_k)].
     *
     * @param iArr count vectors, one per row
     * @param dArr per-row weights
     * @param dArr2 mean vector m
     * @param d total concentration at which to evaluate
     * @return weighted sum of the per-row second derivatives
     */
    public static double logLikelihood_2ndDerivative(int[][] iArr, double[] dArr, double[] dArr2, double d) {
        double result = 0.0d;
        int row = 0;
        for (int[] counts : iArr) {
            int total = 0;
            for (int c : counts) {
                total += c;
            }
            double term = GammaDistrib.trigamma(d) - GammaDistrib.trigamma(d + total);
            for (int k = 0; k < counts.length; k++) {
                double m = dArr2[k];
                double scaled = m * d;
                term += m * m * (GammaDistrib.trigamma(scaled + counts[k]) - GammaDistrib.trigamma(scaled));
            }
            result += term * dArr[row];
            row++;
        }
        return result;
    }

    /**
     * Finds the total concentration s that maximizes the weighted Dirichlet-multinomial
     * log-likelihood for a fixed mean vector m (alpha = s * m), using Newton's method from s = 1.
     *
     * <p>Iteration continues while |f'(s)| exceeds 1e-6, for at most 1000 steps. Previously it
     * continued only while f'(s) was positive. That left s = 1 untouched whenever the optimum lay
     * below 1, and stopped at the first overshoot instead of at convergence.
     *
     * <p>A Newton step is replaced by doubling or halving s (in the ascent direction) when the
     * curvature is not negative or the step would leave s > 0.
     *
     * @param iArr count vectors, one per row
     * @param dArr per-row weights
     * @param dArr2 mean vector m (presumably normalized to sum to 1; verify at call sites)
     * @return the estimated total concentration; "Alpha* search failed." is printed if NaN or 0
     */
    public static double getAlphaSum_byNewton(int[][] iArr, double[] dArr, double[] dArr2) {
        double s = 1.0d;
        double gradient = logLikelihood_1stDerivative(iArr, dArr, dArr2, s);
        for (int iteration = 0; Math.abs(gradient) > 1.0E-6d && iteration < 1000; iteration++) {
            double curvature = logLikelihood_2ndDerivative(iArr, dArr, dArr2, s);
            double next = s - (gradient / curvature);
            if (!(curvature < 0.0d) || !(next > 0.0d) || Double.isInfinite(next)) {
                // Newton step unusable here: move geometrically in the ascent direction instead.
                next = gradient > 0.0d ? s * 2.0d : s * 0.5d;
            }
            s = next;
            gradient = logLikelihood_1stDerivative(iArr, dArr, dArr2, s);
        }
        if (Double.isNaN(s) || s == 0.0d) {
            System.err.println("Alpha* search failed.");
        }
        return s;
    }

    /**
     * Estimates Dirichlet parameters from weighted count vectors.
     *
     * <p>The mean vector m is the weighted category frequency, with a 0.001 pseudo-count per
     * category. It is now normalized by its own total, including the pseudo-counts, so that it sums
     * to 1. Previously the pseudo-counts were left out of the total, so m did not sum to 1. That
     * contradicted the lgamma(s) term the Newton search relies on, and all-zero data divided by 0.
     * The total concentration s is then found by {@link #getAlphaSum_byNewton}; the result is
     * alpha = s * m.
     *
     * @param iArr count vectors; each row must be at least as long as the first
     *     (only that many categories are used)
     * @param dArr one weight per row
     * @return the estimated alpha, or null if {@code iArr} is null or empty, or the weight count differs
     */
    public static double[] getAlpha(int[][] iArr, double[] dArr) {
        if (iArr == null || iArr.length <= 0 || iArr.length != dArr.length) {
            return null;
        }
        final double pseudoCount = 0.001d;
        int nCategories = iArr[0].length;
        double[] mean = new double[nCategories];
        Arrays.fill(mean, pseudoCount);
        // Copy every row truncated to the first row's width (a shorter row throws AIOOBE).
        int[][] counts = new int[iArr.length][nCategories];
        for (int i = 0; i < iArr.length; i++) {
            int[] row = iArr[i];
            for (int k = 0; k < nCategories; k++) {
                counts[i][k] = row[k];
                mean[k] += row[k] * dArr[i];
            }
        }
        double total = 0.0d;
        for (double v : mean) {
            total += v;
        }
        if (Double.isNaN(total)) {
            System.err.println("DirichletDistrib.getAlpha fails");
        }
        for (int k = 0; k < nCategories; k++) {
            mean[k] /= total;
        }
        double alphaSum = getAlphaSum_byNewton(counts, dArr, mean);
        for (int k = 0; k < nCategories; k++) {
            mean[k] *= alphaSum;
        }
        return mean;
    }

    /**
     * Estimates Dirichlet parameters from unweighted count vectors (every row has weight 1).
     *
     * <p>Returns null for null input, consistent with {@link #getAlpha(int[][], double[])}.
     * Previously null input threw a NullPointerException here.
     *
     * @param iArr count vectors, one per row (may be null)
     * @return the estimated alpha, or null if {@code iArr} is null or empty
     */
    public static double[] getAlpha(int[][] iArr) {
        if (iArr == null) {
            return null;
        }
        double[] weights = new double[iArr.length];
        Arrays.fill(weights, 1.0d);
        return getAlpha(iArr, weights);
    }

    /**
     * Computes the description length of the data: the negative log-likelihood under the mixture
     * sum_c w_c * exp(logLikelihood_c(counts)).
     *
     * <p>The mixture is combined with log-sum-exp. The previous version exponentiated each
     * log-likelihood directly. That underflows to 0 for realistic count vectors, so every such
     * point was clamped to log(Double.MIN_VALUE). The same floor is kept only when every component
     * has zero probability.
     *
     * @param iArr count vectors, one per row
     * @param enumDistrib mixing weights, indexed by component
     * @param dirichletDistribArr the mixture components
     * @return minus the summed log mixture likelihood (NaN if any term is NaN)
     */
    public static double DL(int[][] iArr, EnumDistrib enumDistrib, DirichletDistrib[] dirichletDistribArr) {
        int nComponents = dirichletDistribArr.length;
        double[] logJoint = new double[nComponents];
        double logLik = 0.0d;
        for (int[] counts : iArr) {
            double max = Double.NEGATIVE_INFINITY;
            for (int c = 0; c < nComponents; c++) {
                logJoint[c] = Math.log(enumDistrib.get(c)) + dirichletDistribArr[c].logLikelihood(counts);
                max = Math.max(max, logJoint[c]); // Math.max propagates NaN
            }
            double logMixture;
            if (max == Double.NEGATIVE_INFINITY) {
                logMixture = Math.log(Double.MIN_VALUE);
            } else {
                double scaled = 0.0d;
                for (int c = 0; c < nComponents; c++) {
                    scaled += Math.exp(logJoint[c] - max);
                }
                logMixture = max + Math.log(scaled);
            }
            logLik += logMixture;
        }
        return -logLik;
    }

    /**
     * Computes the average normalized entropy of the posterior component responsibilities.
     *
     * <p>The value lies in [0, 1]: 0 means every point is assigned with certainty, 1 means
     * responsibilities are uniform.
     *
     * <p>Responsibilities are computed with log-sum-exp. Previously each likelihood was
     * exponentiated directly, and all of them underflowed for realistic counts. The
     * responsibilities then looked uniform and the entropy was reported as 1. A single component
     * (log K = 0) previously gave NaN; it now yields 0. Points whose components all have zero
     * probability count as uniform (entropy 1), as before.
     *
     * @param iArr count vectors, one per row
     * @param enumDistrib mixing weights, indexed by component
     * @param dirichletDistribArr the mixture components
     * @return mean over rows of -sum_c r_c log(r_c) / log(K)
     */
    public static double mixEntropy(int[][] iArr, EnumDistrib enumDistrib, DirichletDistrib[] dirichletDistribArr) {
        int nComponents = dirichletDistribArr.length;
        if (nComponents <= 1) {
            return 0.0d;
        }
        double logK = Math.log(nComponents);
        double[] logJoint = new double[nComponents];
        double meanEntropy = 0.0d;
        for (int[] counts : iArr) {
            double max = Double.NEGATIVE_INFINITY;
            for (int c = 0; c < nComponents; c++) {
                logJoint[c] = Math.log(enumDistrib.get(c)) + dirichletDistribArr[c].logLikelihood(counts);
                max = Math.max(max, logJoint[c]); // Math.max propagates NaN
            }
            double entropy;
            if (Double.isNaN(max)) {
                entropy = Double.NaN;
            } else if (max == Double.NEGATIVE_INFINITY) {
                entropy = 1.0d;
            } else {
                double norm = 0.0d;
                for (int c = 0; c < nComponents; c++) {
                    norm += Math.exp(logJoint[c] - max);
                }
                entropy = 0.0d;
                for (int c = 0; c < nComponents; c++) {
                    double r = Math.exp(logJoint[c] - max) / norm;
                    if (r > 0.0d) { // 0 * log(0) is taken as 0
                        entropy -= r * (Math.log(r) / logK);
                    }
                }
            }
            meanEntropy += entropy / iArr.length;
        }
        return meanEntropy;
    }

    /* JADX WARN: Type inference failed for: r0v115, types: [int[], int[][]] */
    public static void main(String[] strArr) {
        Random random = new Random(serialVersionUID);
        int[][] loadData = loadData("/Users/mikael/Desktop/blocks_counts.txt");
        int length = loadData.length;
        int i = 0;
        for (int i2 = 0; i2 < length; i2++) {
            if (i == 0) {
                i = loadData[i2].length;
            } else if (i != loadData[i2].length) {
                throw new RuntimeException("Error in data: invalid item at data point " + (i2 + 1));
            }
        }
        ArrayList[] arrayListArr = new ArrayList[9];
        DirichletDistrib[] dirichletDistribArr = new DirichletDistrib[9];
        for (int i3 = 0; i3 < dirichletDistribArr.length; i3++) {
            dirichletDistribArr[i3] = new DirichletDistrib(EnumDistrib.random(new Enumerable(i), random.nextInt()), 100.0d);
            System.out.println("D" + i3 + " = " + dirichletDistribArr[i3]);
        }
        EnumDistrib random2 = EnumDistrib.random(new Enumerable(9), random.nextInt());
        random2.setSeed(random.nextInt());
        EnumDistrib random3 = EnumDistrib.random(new Enumerable(9), random.nextInt());
        random3.setSeed(random.nextInt());
        double DL = DL(loadData, random2, dirichletDistribArr);
        double[][] dArr = new double[9][i];
        for (int i4 = 0; i4 < 9; i4++) {
            System.arraycopy(dirichletDistribArr[i4].getAlpha(), 0, dArr[i4], 0, i);
        }
        EnumDistrib enumDistrib = new EnumDistrib(new Enumerable(9), random2.get());
        System.out.println("M = " + random2);
        System.out.println("P = " + random3);
        System.out.println("DL_best = " + DL);
        int i5 = 0;
        for (int i6 = 0; i6 < 100 && i5 < 10; i6++) {
            for (int i7 = 0; i7 < 9; i7++) {
                arrayListArr[i7] = new ArrayList();
            }
            for (int i8 = 0; i8 < length; i8++) {
                try {
                    double[] dArr2 = new double[9];
                    for (int i9 = 0; i9 < 9; i9++) {
                        dArr2[i9] = Math.log(random2.get(i9)) + dirichletDistribArr[i9].logLikelihood(loadData[i8]);
                    }
                    random3.set(EnumDistrib.log2Prob(dArr2));
                    arrayListArr[((Integer) random3.sample()).intValue()].add(loadData[i8]);
                } catch (RuntimeException e) {
                    System.err.println("Problem with data point k = " + i8);
                }
            }
            for (int i10 = 0; i10 < 9; i10++) {
                ?? r0 = new int[arrayListArr[i10].size()];
                for (int i11 = 0; i11 < arrayListArr[i10].size(); i11++) {
                    r0[i11] = (int[]) arrayListArr[i10].get(i11);
                }
                dirichletDistribArr[i10].setPrior(getAlpha(r0));
                random2.set(Integer.valueOf(i10), arrayListArr[i10].size() / length);
            }
            double DL2 = DL(loadData, random2, dirichletDistribArr);
            double mixEntropy = mixEntropy(loadData, random2, dirichletDistribArr);
            if (DL2 < DL) {
                DL = DL2;
                for (int i12 = 0; i12 < 9; i12++) {
                    System.arraycopy(dirichletDistribArr[i12].getAlpha(), 0, dArr[i12], 0, i);
                }
                enumDistrib = new EnumDistrib(new Enumerable(9), random2.get());
                i5 = 0;
            } else {
                i5++;
            }
            System.out.println("DL_cur = " + DL2 + "\tDL_best = " + DL + "\tEntropy = " + mixEntropy);
        }
        EnumDistrib enumDistrib2 = enumDistrib;
        System.out.println("M = " + enumDistrib2);
        for (int i13 = 0; i13 < dirichletDistribArr.length; i13++) {
            dirichletDistribArr[i13].setPrior(dArr[i13]);
            System.out.println("D" + i13 + " = " + dirichletDistribArr[i13]);
        }
        for (int i14 = 0; i14 < 9; i14++) {
            arrayListArr[i14] = new ArrayList();
        }
        for (int i15 = 0; i15 < length; i15++) {
            try {
                double[] dArr3 = new double[9];
                int i16 = 0;
                for (int i17 = 0; i17 < 9; i17++) {
                    dArr3[i17] = Math.log(enumDistrib2.get(i17)) + dirichletDistribArr[i17].logLikelihood(loadData[i15]);
                    if (i17 > 0 && dArr3[i17] > dArr3[i16]) {
                        i16 = i17;
                    }
                }
                random3.set(EnumDistrib.log2Prob(dArr3));
                arrayListArr[i16].add(Integer.valueOf(i15));
            } catch (RuntimeException e2) {
                System.err.println("Problem with data point k = " + i15);
            }
        }
        for (int i18 = 0; i18 < 9; i18++) {
            try {
                BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter("/Users/mikael/Desktop/blocks_counts.txt_bin_" + i18 + ".out"));
                for (int i19 = 0; i19 < arrayListArr[i18].size(); i19++) {
                    int intValue = ((Integer) arrayListArr[i18].get(i19)).intValue();
                    bufferedWriter.write(intValue + "\t");
                    for (int i20 = 0; i20 < loadData[intValue].length; i20++) {
                        bufferedWriter.write(loadData[intValue][i20] + "\t");
                    }
                    bufferedWriter.write(intValue + "\t");
                    EnumDistrib enumDistrib3 = new EnumDistrib(new Enumerable(i), loadData[intValue]);
                    for (int i21 = 0; i21 < enumDistrib3.getDomain().size(); i21++) {
                        bufferedWriter.write(enumDistrib3.get(i21) + "\t");
                    }
                    bufferedWriter.newLine();
                }
                bufferedWriter.close();
            } catch (IOException e3) {
                Logger.getLogger(DirichletDistrib.class.getName()).log(Level.SEVERE, (String) null, (Throwable) e3);
            }
        }
    }

    /* JADX WARN: Type inference failed for: r0v71, types: [int[], int[][]] */
    public static void main0(String[] strArr) {
        Random random = new Random(serialVersionUID);
        EnumDistrib[] enumDistribArr = new EnumDistrib[9];
        for (int i = 0; i < enumDistribArr.length; i++) {
            enumDistribArr[i] = EnumDistrib.random(new Enumerable(10), random.nextInt());
        }
        int[][] generateData = generateData(enumDistribArr, 50, serialVersionUID);
        ArrayList[] arrayListArr = new ArrayList[9];
        DirichletDistrib[] dirichletDistribArr = new DirichletDistrib[9];
        for (int i2 = 0; i2 < dirichletDistribArr.length; i2++) {
            dirichletDistribArr[i2] = new DirichletDistrib(enumDistribArr[i2], 10.0d);
            System.out.println("D" + i2 + " = " + dirichletDistribArr[i2]);
        }
        EnumDistrib random2 = EnumDistrib.random(new Enumerable(9), random.nextInt());
        random2.setSeed(random.nextInt());
        EnumDistrib random3 = EnumDistrib.random(new Enumerable(9), random.nextInt());
        random3.setSeed(random.nextInt());
        double DL = DL(generateData, random2, dirichletDistribArr);
        double[][] dArr = new double[9][10];
        for (int i3 = 0; i3 < 9; i3++) {
            System.arraycopy(dirichletDistribArr[i3].getAlpha(), 0, dArr[i3], 0, 10);
        }
        EnumDistrib enumDistrib = new EnumDistrib(new Enumerable(9), random2.get());
        System.out.println("M = " + random2);
        System.out.println("P = " + random3);
        System.out.println("DL_best = " + DL);
        int i4 = 0;
        for (int i5 = 0; i5 < 100 && i4 < 10; i5++) {
            for (int i6 = 0; i6 < 9; i6++) {
                arrayListArr[i6] = new ArrayList();
            }
            int i7 = 0;
            for (int i8 = 0; i8 < 50; i8++) {
                double[] dArr2 = new double[9];
                for (int i9 = 0; i9 < 9; i9++) {
                    try {
                        dArr2[i9] = Math.log(random2.get(i9)) + dirichletDistribArr[i9].logLikelihood(generateData[i8]);
                    } catch (RuntimeException e) {
                        System.err.println("Problem with data point k = " + i8);
                    }
                }
                random3.set(EnumDistrib.log2Prob(dArr2));
                int intValue = ((Integer) random3.sample()).intValue();
                if (intValue == component[i8]) {
                    i7++;
                }
                arrayListArr[intValue].add(generateData[i8]);
            }
            System.out.println("Correct = " + i7);
            for (int i10 = 0; i10 < 9; i10++) {
                ?? r0 = new int[arrayListArr[i10].size()];
                for (int i11 = 0; i11 < arrayListArr[i10].size(); i11++) {
                    r0[i11] = (int[]) arrayListArr[i10].get(i11);
                }
                dirichletDistribArr[i10].setPrior(getAlpha(r0));
                random2.set(Integer.valueOf(i10), arrayListArr[i10].size() / 50);
            }
            System.out.println("M = " + random2);
            double DL2 = DL(generateData, random2, dirichletDistribArr);
            double mixEntropy = mixEntropy(generateData, random2, dirichletDistribArr);
            if (DL2 < DL) {
                DL = DL2;
                for (int i12 = 0; i12 < 9; i12++) {
                    System.arraycopy(dirichletDistribArr[i12].getAlpha(), 0, dArr[i12], 0, 10);
                }
                enumDistrib = new EnumDistrib(new Enumerable(9), random2.get());
                i4 = 0;
            } else {
                i4++;
            }
            System.out.println("DL_cur = " + DL2 + "\tDL_best = " + DL + "\tEntropy = " + mixEntropy);
        }
        System.out.println("M = " + enumDistrib);
        for (int i13 = 0; i13 < dirichletDistribArr.length; i13++) {
            dirichletDistribArr[i13].setPrior(dArr[i13]);
            System.out.println("D" + i13 + " = " + dirichletDistribArr[i13]);
        }
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v2, types: [int[], int[][]] */
    public static int[][] generateData(EnumDistrib[] enumDistribArr, int i, long j) {
        Random random = new Random(j);
        ?? r0 = new int[i];
        component = new int[i];
        for (int i2 = 0; i2 < i; i2++) {
            int nextInt = random.nextInt(enumDistribArr.length);
            component[i2] = nextInt;
            int nextInt2 = 10 + random.nextInt(50);
            r0[i2] = new int[enumDistribArr[nextInt].getDomain().size()];
            for (int i3 = 0; i3 < nextInt2; i3++) {
                int[] iArr = r0[i2];
                int index = enumDistribArr[nextInt].getDomain().getIndex(enumDistribArr[nextInt].sample());
                iArr[index] = iArr[index] + 1;
            }
        }
        return r0;
    }

    /**
     * Reads tab-separated integer rows from a text file.
     *
     * <p>A line with any non-integer field (e.g. a header or blank line) is skipped and reported
     * once as {@code "Ignored: <line>"}. Previously such a line was reported once per bad field
     * yet still added as a zero-filled row. Fields are trimmed before parsing. The reader is now
     * closed even when reading fails.
     *
     * @param str path of the file to read
     * @return one int array per accepted line, or null if the file cannot be read (the error is logged)
     */
    public static int[][] loadData(String str) {
        List<int[]> rows = new ArrayList<>();
        try (BufferedReader reader = new BufferedReader(new FileReader(str))) {
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                String[] fields = line.split("\t");
                int[] row = new int[fields.length];
                boolean valid = true;
                for (int k = 0; k < fields.length && valid; k++) {
                    try {
                        row[k] = Integer.parseInt(fields[k].trim());
                    } catch (NumberFormatException e) {
                        valid = false;
                    }
                }
                if (valid) {
                    rows.add(row);
                } else {
                    System.err.println("Ignored: " + line);
                }
            }
        } catch (IOException e) {
            Logger.getLogger(DirichletDistrib.class.getName()).log(Level.SEVERE, (String) null, (Throwable) e);
            return null;
        }
        return rows.toArray(new int[0][]);
    }

    /**
     * Re-estimates this Dirichlet's parameters from observed probability vectors.
     *
     * <p>The current alpha is used as the starting point of the search.
     *
     * @param enumDistribArr observed distributions over the same domain
     * @throws RuntimeException if no observations are given
     */
    public void setPrior(EnumDistrib[] enumDistribArr) {
        if (enumDistribArr.length > 0) {
            double[] sufficient = getSufficientStatistic(enumDistribArr, alpha.length);
            double[] fitted = findPrior(alpha, sufficient);
            setPrior(fitted);
            return;
        }
        throw new RuntimeException("Invalid data for estimation of Dirichlet");
    }

    /**
     * Computes the Dirichlet sufficient statistic: the per-category mean of log-probabilities.
     *
     * @param enumDistribArr observed distributions
     * @param i number of categories to use
     * @return mean over observations of log p_k, for k = 0..i-1
     */
    public static double[] getSufficientStatistic(EnumDistrib[] enumDistribArr, int i) {
        double[] meanLog = new double[i];
        for (int n = 0; n < enumDistribArr.length; n++) {
            EnumDistrib observation = enumDistribArr[n];
            for (int k = 0; k < i; k++) {
                meanLog[k] += Math.log(observation.get(k));
            }
        }
        int count = enumDistribArr.length;
        for (int k = 0; k < i; k++) {
            meanLog[k] /= count;
        }
        return meanLog;
    }

    /**
     * Returns the average Dirichlet log-density of the observations, up to a constant.
     *
     * <p>The omitted term does not depend on alpha. It is {@code -sum_k ss_k} when {@code dArr2} is
     * the mean-log statistic from getSufficientStatistic.
     *
     * @param dArr concentration parameters alpha
     * @param dArr2 sufficient statistic ss
     * @return lgamma(sum alpha) - sum lgamma(alpha_k) + sum alpha_k * ss_k
     */
    private static double logProbForMultinomials(double[] dArr, double[] dArr2) {
        double alphaSum = 0.0d;
        double logGammaSum = 0.0d;
        double linear = 0.0d;
        int k = 0;
        for (double a : dArr) {
            alphaSum += a;
            logGammaSum += GammaDistrib.lgamma(a);
            linear += a * dArr2[k];
            k++;
        }
        return (GammaDistrib.lgamma(alphaSum) - logGammaSum) + linear;
    }

    /**
     * Returns the gradient of logProbForMultinomials with respect to alpha.
     *
     * <p>Component k is digamma(sum alpha) + ss_k - digamma(alpha_k).
     *
     * @param dArr concentration parameters alpha
     * @param dArr2 sufficient statistic ss
     * @return the gradient (ascent direction for the log-likelihood)
     */
    private static double[] getGradientForMultinomials(double[] dArr, double[] dArr2) {
        double alphaSum = 0.0d;
        for (int k = 0; k < dArr.length; k++) {
            alphaSum += dArr[k];
        }
        double psiOfSum = GammaDistrib.digamma(alphaSum);
        double[] gradient = new double[dArr.length];
        for (int k = 0; k < gradient.length; k++) {
            gradient[k] = (psiOfSum + dArr2[k]) - GammaDistrib.digamma(dArr[k]);
        }
        return gradient;
    }

    /**
     * Returns the constant (all-ones) part of the loss Hessian, -trigamma(sum alpha).
     *
     * @param dArr concentration parameters alpha
     * @return -trigamma(sum of alpha)
     */
    private static double priorHessianConst(double[] dArr) {
        double alphaSum = 0.0d;
        for (int k = 0; k < dArr.length; k++) {
            alphaSum += dArr[k];
        }
        return -GammaDistrib.trigamma(alphaSum);
    }

    /**
     * Returns the diagonal part of the loss Hessian, trigamma(alpha_k) for each k.
     *
     * @param dArr concentration parameters alpha
     * @return trigamma applied element-wise
     */
    private static double[] priorHessianDiag(double[] dArr) {
        double[] diagonal = new double[dArr.length];
        int k = 0;
        for (double a : dArr) {
            diagonal[k++] = GammaDistrib.trigamma(a);
        }
        return diagonal;
    }

    /**
     * Computes the Newton step for maximizing the Dirichlet log-likelihood in O(K).
     *
     * <p>The step exploits the diagonal-plus-constant structure of the Hessian.
     * predictStepUsingHessian passes the loss Hessian diag(h) + d * 11^T, with
     * h_k = trigamma(alpha_k) and d = -trigamma(sum alpha), and the log-likelihood gradient g. The
     * Newton update is alpha + (g_k - b) / h_k, where b = sum(g/h) / (1/d + sum(1/h)).
     *
     * <p>The previous version returned (b - g_k) / h_k. With these inputs that is a descent
     * direction for the log-likelihood, so findPrior rejected every Newton trial and fell back to
     * slower steps.
     *
     * @param dArr current alpha (only its length is used)
     * @param d constant part of the loss Hessian
     * @param dArr2 diagonal part of the loss Hessian
     * @param dArr3 gradient of the log-likelihood
     * @return the step to add to alpha
     */
    private static double[] getPredictedStep(double[] dArr, double d, double[] dArr2, double[] dArr3) {
        int length = dArr.length;
        double gradOverDiag = 0.0d;
        for (int k = 0; k < length; k++) {
            gradOverDiag += dArr3[k] / dArr2[k];
        }
        double inverseDiagSum = 0.0d;
        for (int k = 0; k < length; k++) {
            inverseDiagSum += 1.0d / dArr2[k];
        }
        double b = gradOverDiag / ((1.0d / d) + inverseDiagSum);
        double[] step = new double[length];
        for (int k = 0; k < length; k++) {
            step[k] = (dArr3[k] - b) / dArr2[k];
        }
        return step;
    }

    /**
     * Step in log space used by findPrior when the Newton step is rejected.
     *
     * <p>predictStepLogSpace calls it with d = priorHessianConst(alpha),
     * dArr2 = priorHessianDiag(alpha) and dArr3 = the log-likelihood gradient. findPrior applies
     * the result multiplicatively, as alpha_k * exp(step_k).
     *
     * <p>NOTE(review): this does not appear to match the Sherman-Morrison form of a Newton step in
     * beta = log(alpha) for these sign conventions. On a symmetric example it points against the
     * gradient. findPrior only accepts it when the loss decreases, but the derivation should be
     * re-checked.
     *
     * @param dArr current alpha
     * @param d constant part of the Hessian (as passed by predictStepLogSpace)
     * @param dArr2 diagonal part of the Hessian (as passed by predictStepLogSpace)
     * @param dArr3 gradient of the log-likelihood
     * @return per-component log-space step
     */
    private static double[] getPredictedStepAlt(double[] dArr, double d, double[] dArr2, double[] dArr3) {
        int length = dArr.length;
        double[] dArr4 = new double[length];
        // d2 = sum_k alpha_k / (g_k - alpha_k * h_k)
        double d2 = 0.0d;
        for (int i = 0; i < length; i++) {
            d2 += dArr[i] / (dArr3[i] - (dArr[i] * dArr2[i]));
        }
        double d3 = d2 * d;
        // dArr5_k = 1 / (g_k - alpha_k * h_k) / (1 + d3); d4 is their sum
        double[] dArr5 = new double[length];
        double d4 = 0.0d;
        for (int i2 = 0; i2 < length; i2++) {
            dArr5[i2] = (1.0d / (dArr3[i2] - (dArr[i2] * dArr2[i2]))) / (1.0d + d3);
            d4 += dArr5[i2];
        }
        // step_k = g_k / (g_k - alpha_k * h_k) * (1 - d * alpha_k * d4)
        for (int i3 = 0; i3 < length; i3++) {
            dArr4[i3] = (dArr3[i3] / (dArr3[i3] - (dArr[i3] * dArr2[i3]))) * (1.0d - ((d * dArr[i3]) * d4));
        }
        return dArr4;
    }

    /**
     * Returns the loss minimized by findPrior: the negated logProbForMultinomials.
     *
     * @param dArr concentration parameters alpha
     * @param dArr2 sufficient statistic
     * @return -logProbForMultinomials(alpha, ss)
     */
    private static double getTotalLoss(double[] dArr, double[] dArr2) {
        return -logProbForMultinomials(dArr, dArr2);
    }

    /**
     * Computes the Hessian-based step for the current alpha and gradient.
     *
     * <p>Delegates to getPredictedStep with the loss-Hessian parts for the current alpha.
     *
     * @param dArr current alpha
     * @param dArr2 gradient of the log-likelihood
     * @return step to add to alpha
     */
    private static double[] predictStepUsingHessian(double[] dArr, double[] dArr2) {
        double hessianConst = priorHessianConst(dArr);
        double[] hessianDiag = priorHessianDiag(dArr);
        return getPredictedStep(dArr, hessianConst, hessianDiag, dArr2);
    }

    /**
     * Computes the log-space step for the current alpha and gradient.
     *
     * <p>Delegates to getPredictedStepAlt with the loss-Hessian parts for the current alpha.
     *
     * @param dArr current alpha
     * @param dArr2 gradient of the log-likelihood
     * @return per-component step, applied as alpha_k * exp(step_k)
     */
    private static double[] predictStepLogSpace(double[] dArr, double[] dArr2) {
        double hessianConst = priorHessianConst(dArr);
        double[] hessianDiag = priorHessianDiag(dArr);
        return getPredictedStepAlt(dArr, hessianConst, hessianDiag, dArr2);
    }

    /**
     * Evaluates a candidate alpha.
     *
     * <p>Candidates with any component {@code <= 0} are rejected by returning +Infinity; NaN
     * components are not rejected here.
     *
     * @param dArr candidate alpha
     * @param dArr2 sufficient statistic
     * @return the loss, or +Infinity for an inadmissible candidate
     */
    private static double testTrialPriors(double[] dArr, double[] dArr2) {
        boolean admissible = true;
        for (int k = 0; k < dArr.length && admissible; k++) {
            admissible = !(dArr[k] <= 0.0d);
        }
        return admissible ? getTotalLoss(dArr, dArr2) : Double.POSITIVE_INFINITY;
    }

    /**
     * Returns the squared Euclidean norm of a vector.
     *
     * @param dArr the vector
     * @return sum of squared components
     */
    private static double sqVectorSize(double[] dArr) {
        double sumOfSquares = 0.0d;
        for (int k = 0; k < dArr.length; k++) {
            sumOfSquares += Math.pow(dArr[k], 2.0d);
        }
        return sumOfSquares;
    }

    /**
     * Estimates Dirichlet parameters from observed distributions, starting from this alpha.
     *
     * <p>This distribution is not modified.
     *
     * @param enumDistribArr observed distributions
     * @return the estimated parameters
     * @throws RuntimeException if no observations are given
     */
    public double[] findPrior(EnumDistrib[] enumDistribArr) {
        if (enumDistribArr.length > 0) {
            double[] stats = getSufficientStatistic(enumDistribArr, alpha.length);
            return findPrior(alpha, stats);
        }
        throw new RuntimeException("Invalid data for estimation of Dirichlet");
    }

    /**
     * Estimates Dirichlet parameters from observed distributions, starting from {@code dArr}.
     *
     * @param enumDistribArr observed distributions
     * @param dArr starting parameters (its length sets the number of categories)
     * @return the estimated parameters
     */
    public static double[] findPrior(EnumDistrib[] enumDistribArr, double[] dArr) {
        double[] stats = getSufficientStatistic(enumDistribArr, dArr.length);
        return findPrior(dArr, stats);
    }

    /**
     * Computes maximum-likelihood Dirichlet parameters for a sufficient statistic, starting from
     * {@code dArr}.
     *
     * <p>Each of at most 1000 iterations tries three candidates in order:
     * <ol>
     *   <li>a Hessian step (predictStepUsingHessian);</li>
     *   <li>a multiplicative step in log space (predictStepLogSpace);</li>
     *   <li>backtracking gradient ascent with step sizes 1, 0.9, 0.81, ...</li>
     * </ol>
     * A candidate is accepted only if it lowers getTotalLoss; candidates with a non-positive
     * component are rejected by testTrialPriors.
     *
     * <p>Stops when any of the following happens:
     * <ul>
     *   <li>the squared gradient norm drops below 2^-20;</li>
     *   <li>an accepted backtracking step had to shrink below 2^-10;</li>
     *   <li>backtracking finds no acceptable step at all;</li>
     *   <li>1000 iterations have run.</li>
     * </ul>
     * The third case is new. Previously the backtracking loop never terminated when the loss was
     * NaN or no trial was ever admissible: the step size sticks at Double.MIN_VALUE instead of
     * reaching 0.
     *
     * @param dArr  starting concentration parameters (not modified)
     * @param dArr2 sufficient statistic: per-category mean log-probability (see getSufficientStatistic)
     * @return the estimated parameters (a new array)
     */
    public static double[] findPrior(double[] dArr, double[] dArr2) {
        final int maxIterations = 1000;
        // Enough 0.9-reductions for the step to reach Double.MIN_VALUE (~7070); beyond that the
        // trial point can no longer change, so further attempts are pointless.
        final int maxLineSearchSteps = 10000;
        final double gradientTolerance = Math.pow(2.0d, -20.0d);
        final double minStepSize = Math.pow(2.0d, -10.0d);

        double[] current = new double[dArr.length];
        System.arraycopy(dArr, 0, current, 0, current.length);
        double[] trial = new double[current.length];
        double currentLoss = getTotalLoss(current, dArr2);

        for (int iteration = 0; iteration < maxIterations; iteration++) {
            double[] gradient = getGradientForMultinomials(current, dArr2);
            if (sqVectorSize(gradient) < gradientTolerance) {
                return current;
            }

            // 1) Hessian (Newton) step.
            double[] newtonStep = predictStepUsingHessian(current, gradient);
            for (int k = 0; k < current.length; k++) {
                trial[k] = current[k] + newtonStep[k];
            }
            double trialLoss = testTrialPriors(trial, dArr2);
            if (trialLoss < currentLoss) {
                currentLoss = trialLoss;
                System.arraycopy(trial, 0, current, 0, current.length);
                continue;
            }

            // 2) Multiplicative step in log space (keeps components positive).
            double[] logStep = predictStepLogSpace(current, gradient);
            for (int k = 0; k < current.length; k++) {
                trial[k] = current[k] * Math.exp(logStep[k]);
            }
            trialLoss = testTrialPriors(trial, dArr2);
            if (trialLoss < currentLoss) {
                currentLoss = trialLoss;
                System.arraycopy(trial, 0, current, 0, current.length);
                continue;
            }

            // 3) Backtracking gradient ascent on the log-likelihood.
            double stepSize = 1.0d;
            boolean improved = false;
            for (int attempt = 0; attempt < maxLineSearchSteps; attempt++) {
                for (int k = 0; k < current.length; k++) {
                    trial[k] = current[k] + (gradient[k] * stepSize);
                }
                trialLoss = testTrialPriors(trial, dArr2);
                stepSize *= 0.9d;
                if (trialLoss <= currentLoss) {
                    improved = true;
                    break;
                }
            }
            if (!improved) {
                // No admissible step exists (e.g. NaN loss): stop with the best point so far.
                return current;
            }
            if (stepSize < minStepSize) {
                return trial;
            }
            currentLoss = trialLoss;
            System.arraycopy(trial, 0, current, 0, current.length);
        }
        return current;
    }

    /**
     * Small demonstration of findPrior on six binary distributions.
     *
     * <p>Prints the fitted Dirichlet, its loss, and the loss of the reference parameters {1, 2}.
     *
     * @param strArr unused
     */
    public static void main1(String[] strArr) {
        Enumerable binary = new Enumerable(2);
        EnumDistrib[] observations = {
            new EnumDistrib(binary, 0.3d, 0.7d),
            new EnumDistrib(binary, 0.2d, 0.8d),
            new EnumDistrib(binary, 0.1d, 0.9d),
            new EnumDistrib(binary, 0.2d, 0.8d),
            new EnumDistrib(binary, 0.1d, 0.9d),
            new EnumDistrib(binary, 0.4d, 0.6d)
        };
        double[] stats = getSufficientStatistic(observations, 2);
        double[] fitted = findPrior(new double[] {0.5d, 0.5d}, stats);
        System.out.println(new DirichletDistrib(binary, fitted));
        System.out.println("Final loss = " + getTotalLoss(fitted, stats));
        System.out.println("Best loss = " + getTotalLoss(new double[] {1.0d, 2.0d}, stats));
    }
}
