package bn.prob;

import bn.Distrib;
import dat.Enumerable;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;

/* JADX WARN: Classes with same name are omitted:
  input_file:target/classes/bn/prob/MixDirichletDistrib.class
 */
/* loaded from: input_file:bn/prob/MixDirichletDistrib.class */
public class MixDirichletDistrib extends MixtureDistrib implements Serializable {
    /** Cap on learning rounds; the loops in learnParameters and learnParametersFlip are bounded by the same value (200). */
    private final int ROUND_LIMITATION = 200;
    /** Number of consecutive rounds without a lower DL after which those loops stop (20). */
    private final int NO_UPDATE = 20;
    /** Domain of this mixture; assigned by every constructor and compared against new components in addDistrib. */
    private Enumerable domain;
    /**
     * Component count handed to the (Enumerable, int, ...) constructors. The (DirichletDistrib, double)
     * constructor does not assign it, so it stays 0 there. Read by getComplexity, saveAlphas and saveClusters.
     */
    private double components;
    /** enumerable.size() as seen by the (Enumerable, int, ...) constructors; read by getComplexity. */
    private double letters;
    /** Value last passed to setDLBest; the learning methods call it whenever a lower DL is found. */
    private double DL_best;
    /** Value last passed to setMBest; nothing else in this class writes it. */
    private EnumDistrib m_best;

    /**
     * Creates a mixture from one Dirichlet component and its weight, delegating to the
     * matching {@code MixtureDistrib} constructor, and adopts that component's domain.
     * <p>
     * NOTE(review): {@code components} and {@code letters} are not assigned here, so
     * getComplexity/saveAlphas/saveClusters see 0 for both on instances built this way.
     *
     * @param dirichletDistrib the initial component; its domain must be an {@code Enumerable}
     * @param d                the weight passed on to the superclass
     */
    public MixDirichletDistrib(DirichletDistrib dirichletDistrib, double d) {
        super(dirichletDistrib, d);
        // ROUND_LIMITATION and NO_UPDATE are final fields with initialisers; assigning
        // them again in a constructor does not compile, so no assignment is made here.
        this.domain = (Enumerable) dirichletDistrib.getDomain();
    }

    /**
     * Creates a mixture of {@code i} Dirichlet components, each initialised from a row of
     * {@code iArr} chosen at random (the generator is seeded with the current time).
     * <p>
     * For every zero found in the row picked first, another row is drawn; the row from the
     * last such draw is used. NOTE(review): only the first pick is scanned, so the row that
     * is finally used can still contain zeros - confirm whether a retry-until-clean loop
     * was intended.
     *
     * @param enumerable the domain of every component
     * @param i          number of components to create
     * @param iArr       count vectors to pick starting points from; must be non-empty when {@code i > 0}
     */
    public MixDirichletDistrib(Enumerable enumerable, int i, int[][] iArr) {
        // ROUND_LIMITATION and NO_UPDATE are final fields with initialisers; assigning
        // them again in a constructor does not compile, so no assignment is made here.
        Random rng = new Random(System.currentTimeMillis());
        for (int c = 0; c < i; c++) {
            int[] firstPick = iArr[rng.nextInt(iArr.length)];
            int[] seedRow = firstPick;
            for (int count : firstPick) {
                if (count == 0) {
                    seedRow = iArr[rng.nextInt(iArr.length)];
                }
            }
            // super.addDistrib is used on purpose: the override consults getDomain(),
            // and this.domain is only assigned after the loop.
            super.addDistrib(new DirichletDistrib(new EnumDistrib(enumerable, seedRow), rng.nextInt(90) + 10), rng.nextDouble());
        }
        this.domain = enumerable;
        this.components = i;
        this.letters = enumerable.size();
    }

    /**
     * Creates a mixture of {@code i} Dirichlet components with random starting points.
     * The inherited generator is seeded with the current time; each component is built from
     * {@code EnumDistrib.random(...)}, a second argument drawn from 10..99, and a weight
     * from {@code nextDouble()}.
     *
     * @param enumerable the domain of every component
     * @param i          number of components to create
     */
    public MixDirichletDistrib(Enumerable enumerable, int i) {
        // ROUND_LIMITATION and NO_UPDATE are final fields with initialisers; assigning
        // them again in a constructor does not compile, so no assignment is made here.
        super.setSeed(System.currentTimeMillis());
        for (int c = 0; c < i; c++) {
            // super.addDistrib is used on purpose: the override consults getDomain(),
            // and this.domain is only assigned after the loop.
            super.addDistrib(new DirichletDistrib(EnumDistrib.random(enumerable, nextInt(Integer.MAX_VALUE)), nextInt(90) + 10), nextDouble());
        }
        this.domain = enumerable;
        this.components = i;
        this.letters = enumerable.size();
    }

    /**
     * Adds a component after checking its type and domain.
     *
     * @param distrib a {@code DirichletDistrib} or {@code MixDirichletDistrib} over the same domain as this mixture
     * @param d       the weight, handed to the superclass unchanged
     * @return whatever the superclass {@code addDistrib} returns
     * @throws RuntimeException if {@code distrib} is of another type or its domain differs
     */
    @Override // bn.prob.MixtureDistrib
    public double addDistrib(Distrib distrib, double d) {
        final boolean sameDomain;
        if (distrib instanceof DirichletDistrib) {
            sameDomain = getDomain().equals(((DirichletDistrib) distrib).getDomain());
        } else if (distrib instanceof MixDirichletDistrib) {
            sameDomain = getDomain().equals(((MixDirichletDistrib) distrib).getDomain());
        } else {
            throw new RuntimeException("only accept DirichletDistrib or MixDirichletDistrib");
        }
        if (!sameDomain) {
            throw new RuntimeException("Domain should be the same");
        }
        return super.addDistrib(distrib, d);
    }

    /**
     * @return the domain assigned by the constructor
     */
    public Enumerable getDomain() {
        return domain;
    }

    /**
     * Picks the component that best explains a count vector.
     * <p>
     * The score of component {@code c} is {@code log(weight[c]) + logLikelihood(iArr)}.
     * The lowest index wins ties, and 0 is returned when there are no components.
     *
     * @param iArr the count vector to classify
     * @return index of the highest-scoring component
     */
    public int getLabel(int[] iArr) {
        int componentCount = this.distribs.size();
        double[] weights = getAllWeights();
        int best = 0;
        double bestScore = 0.0d;
        for (int c = 0; c < componentCount; c++) {
            double logLik = ((DirichletDistrib) this.distribs.get(c)).logLikelihood(iArr);
            double score = Math.log(weights[c]) + logLik;
            // A strict '>' keeps the earlier index on ties and never replaces a NaN leader,
            // matching the comparison the method has always used.
            if (c == 0 || score > bestScore) {
                best = c;
                bestScore = score;
            }
        }
        return best;
    }

    /**
     * Negative log-likelihood of the data under the current mixture.
     * <p>
     * For each row the weighted component likelihoods are summed, starting from
     * {@code Double.MIN_VALUE} so the logarithm never receives exactly zero.
     *
     * @param iArr count vectors, one row per data point
     * @return minus the sum of the per-row log mixture likelihoods
     */
    private double DL(int[][] iArr) {
        final double[] weights = getAllWeights();
        double logLikelihoodSum = 0.0d;
        for (int row = 0; row < iArr.length; row++) {
            double mixtureLikelihood = Double.MIN_VALUE;
            int c = 0;
            while (c < this.distribs.size()) {
                DirichletDistrib component = (DirichletDistrib) this.distribs.get(c);
                mixtureLikelihood += weights[c] * Math.exp(component.logLikelihood(iArr[row]));
                c++;
            }
            logLikelihoodSum += Math.log(mixtureLikelihood);
        }
        return -logLikelihoodSum;
    }

    /* JADX WARN: Type inference failed for: r0v71, types: [int[], int[][]] */
    public double learnParameters(int[][] iArr) {
        int length = iArr[0].length;
        int mixtureSize = getMixtureSize();
        int length2 = iArr.length;
        double DL = DL(iArr);
        double[][] dArr = new double[mixtureSize][length];
        EnumDistrib enumDistrib = new EnumDistrib(new Enumerable(mixtureSize), getAllWeights());
        ArrayList[] arrayListArr = new ArrayList[mixtureSize];
        getNormalized();
        for (int i = 0; i < mixtureSize; i++) {
            System.arraycopy(((DirichletDistrib) getDistrib(i)).getAlpha(), 0, dArr[i], 0, length);
        }
        EnumDistrib random = EnumDistrib.random(new Enumerable(mixtureSize), this.rand.nextInt());
        random.setSeed(this.rand.nextInt());
        int i2 = 0;
        boolean z = false;
        boolean z2 = false;
        int i3 = 0;
        while (i3 < 200 && i2 < 20) {
            for (int i4 = 0; i4 < mixtureSize; i4++) {
                arrayListArr[i4] = new ArrayList();
            }
            double[] allWeights = getAllWeights();
            HashMap hashMap = new HashMap();
            for (int i5 = 0; i5 < length2; i5++) {
                try {
                    double[] dArr2 = new double[mixtureSize];
                    double d = -1000000.0d;
                    for (int i6 = 0; i6 < mixtureSize; i6++) {
                        double log = Math.log(allWeights[i6]) + ((DirichletDistrib) getDistrib(i6)).logLikelihood(iArr[i5]);
                        dArr2[i6] = log;
                        if (log > d) {
                            d = log;
                        }
                    }
                    random.set(EnumDistrib.log2Prob(dArr2));
                    arrayListArr[((Integer) random.sample()).intValue()].add(iArr[i5]);
                    hashMap.put(iArr[i5], Double.valueOf(d));
                } catch (RuntimeException e) {
                    System.err.println("Problem with data point k = " + i5);
                    throw new RuntimeException("Problem with data point k = " + i5);
                }
            }
            for (int i7 = 0; i7 < mixtureSize; i7++) {
                if (arrayListArr[i7].size() == 0) {
                    System.err.println("Empty Bin : " + i7);
                    double doubleValue = ((Double) Collections.max(hashMap.values())).doubleValue();
                    Iterator it = hashMap.entrySet().iterator();
                    while (true) {
                        if (!it.hasNext()) {
                            break;
                        }
                        Map.Entry entry = (Map.Entry) it.next();
                        if (((Double) entry.getValue()).doubleValue() == doubleValue) {
                            for (int i8 = 0; i8 < mixtureSize; i8++) {
                                for (int i9 = 0; i9 < arrayListArr[i8].size(); i9++) {
                                    if (arrayListArr[i8].get(i9) == entry.getKey()) {
                                        arrayListArr[i8].remove(i9);
                                    }
                                }
                            }
                            arrayListArr[i7].add(entry.getKey());
                            hashMap.remove(entry.getKey());
                        }
                    }
                    z = true;
                }
            }
            if (z) {
                System.out.println("Empty bins were re-initialised");
                z = false;
                z2 = true;
                i3 = 0;
            } else {
                z2 = true;
            }
            for (int i10 = 0; i10 < mixtureSize; i10++) {
                setWeight(i10, arrayListArr[i10].size());
                ?? r0 = new int[arrayListArr[i10].size()];
                for (int i11 = 0; i11 < arrayListArr[i10].size(); i11++) {
                    r0[i11] = (int[]) arrayListArr[i10].get(i11);
                }
                try {
                    ((DirichletDistrib) getDistrib(i10)).setPrior(DirichletDistrib.getAlpha(r0));
                } catch (NullPointerException e2) {
                    System.out.println(i10);
                }
            }
            getNormalized();
            double DL2 = DL(iArr);
            System.out.println("DL_cur = " + DL2 + "\tDL_best = " + DL);
            if (DL2 < DL) {
                setDLBest(DL2);
                DL = DL2;
                enumDistrib.set(getAllWeights());
                for (int i12 = 0; i12 < mixtureSize; i12++) {
                    System.arraycopy(((DirichletDistrib) getDistrib(i12)).getAlpha(), 0, dArr[i12], 0, length);
                }
                i2 = 0;
            } else {
                i2++;
            }
            i3++;
        }
        setWeights(enumDistrib.get());
        for (int i13 = 0; i13 < mixtureSize; i13++) {
            ((DirichletDistrib) getDistrib(i13)).setPrior(dArr[i13]);
        }
        if (!z2) {
            System.err.print("Failed to remove empty bins\n");
        }
        return DL;
    }

    /* JADX WARN: Type inference failed for: r0v70, types: [int[], int[][]] */
    public void learnParametersFlip(int[][] iArr) {
        int length = iArr[0].length;
        int mixtureSize = getMixtureSize();
        int length2 = iArr.length;
        double DL = DL(iArr);
        double[][] dArr = new double[mixtureSize][length];
        EnumDistrib enumDistrib = new EnumDistrib(new Enumerable(mixtureSize), getAllWeights());
        ArrayList[] arrayListArr = new ArrayList[mixtureSize];
        getNormalized();
        for (int i = 0; i < mixtureSize; i++) {
            System.arraycopy(((DirichletDistrib) getDistrib(i)).getAlpha(), 0, dArr[i], 0, length);
        }
        EnumDistrib random = EnumDistrib.random(new Enumerable(mixtureSize), this.rand.nextInt());
        random.setSeed(this.rand.nextInt());
        int i2 = 0;
        boolean z = false;
        boolean z2 = false;
        int i3 = 0;
        while (i3 < 200 && i2 < 20) {
            for (int i4 = 0; i4 < mixtureSize; i4++) {
                arrayListArr[i4] = new ArrayList();
            }
            double[] allWeights = getAllWeights();
            HashMap hashMap = new HashMap();
            for (int i5 = 0; i5 < length2; i5++) {
                int[] iArr2 = new int[iArr[i5].length];
                int[] iArr3 = iArr[i5];
                int length3 = iArr[i5].length / 2;
                if (iArr[i5].length % 2 == 1) {
                    length3 = (iArr[i5].length / 2) + 1;
                }
                for (int i6 = 0; i6 < length3; i6++) {
                    int i7 = iArr[i5][i6];
                    iArr2[i6] = iArr[i5][(iArr[i5].length - i6) - 1];
                    iArr2[(iArr[i5].length - i6) - 1] = i7;
                }
                boolean z3 = false;
                try {
                    double[] dArr2 = new double[mixtureSize];
                    double d = -1000000.0d;
                    for (int i8 = 0; i8 < mixtureSize; i8++) {
                        DirichletDistrib dirichletDistrib = (DirichletDistrib) getDistrib(i8);
                        double log = Math.log(allWeights[i8]) + dirichletDistrib.logLikelihood(iArr[i5]);
                        double log2 = Math.log(allWeights[i8]) + dirichletDistrib.logLikelihood(iArr2);
                        double d2 = log;
                        if (log < log2) {
                            z3 = true;
                            d2 = log2;
                        }
                        dArr2[i8] = d2;
                        if (d2 > d) {
                            d = d2;
                        }
                    }
                    random.set(EnumDistrib.log2Prob(dArr2));
                    Integer num = (Integer) random.sample();
                    if (z3) {
                        arrayListArr[num.intValue()].add(iArr2);
                        iArr[i5] = iArr2;
                    } else {
                        arrayListArr[num.intValue()].add(iArr[i5]);
                    }
                    if (z3) {
                        hashMap.put(iArr2, Double.valueOf(d));
                    } else {
                        hashMap.put(iArr[i5], Double.valueOf(d));
                    }
                } catch (RuntimeException e) {
                    System.err.println("Problem with data point k = " + i5);
                    throw new RuntimeException("Problem with data point k = " + i5);
                }
            }
            for (int i9 = 0; i9 < mixtureSize; i9++) {
                if (arrayListArr[i9].size() == 0) {
                    System.err.println("Empty Bin : " + i9);
                    double doubleValue = ((Double) Collections.max(hashMap.values())).doubleValue();
                    Iterator it = hashMap.entrySet().iterator();
                    while (true) {
                        if (!it.hasNext()) {
                            break;
                        }
                        Map.Entry entry = (Map.Entry) it.next();
                        if (((Double) entry.getValue()).doubleValue() == doubleValue) {
                            for (int i10 = 0; i10 < mixtureSize; i10++) {
                                for (int i11 = 0; i11 < arrayListArr[i10].size(); i11++) {
                                    if (arrayListArr[i10].get(i11) == entry.getKey()) {
                                        arrayListArr[i10].remove(i11);
                                    }
                                }
                            }
                            arrayListArr[i9].add(entry.getKey());
                            hashMap.remove(entry.getKey());
                        }
                    }
                    z = true;
                }
            }
            if (z) {
                System.out.println("Empty bins were re-initialised");
                z = false;
                z2 = true;
                i3 = 0;
            } else {
                z2 = true;
            }
            for (int i12 = 0; i12 < mixtureSize; i12++) {
                setWeight(i12, arrayListArr[i12].size());
                ?? r0 = new int[arrayListArr[i12].size()];
                for (int i13 = 0; i13 < arrayListArr[i12].size(); i13++) {
                    r0[i13] = (int[]) arrayListArr[i12].get(i13);
                }
                try {
                    ((DirichletDistrib) getDistrib(i12)).setPrior(DirichletDistrib.getAlpha(r0));
                } catch (NullPointerException e2) {
                    System.out.println(i12);
                }
            }
            getNormalized();
            double DL2 = DL(iArr);
            System.out.println("DL_cur = " + DL2 + "\tDL_best = " + DL);
            if (DL2 < DL) {
                setDLBest(DL2);
                DL = DL2;
                enumDistrib.set(getAllWeights());
                for (int i14 = 0; i14 < mixtureSize; i14++) {
                    System.arraycopy(((DirichletDistrib) getDistrib(i14)).getAlpha(), 0, dArr[i14], 0, length);
                }
                i2 = 0;
            } else {
                i2++;
            }
            i3++;
        }
        setWeights(enumDistrib.get());
        for (int i15 = 0; i15 < mixtureSize; i15++) {
            ((DirichletDistrib) getDistrib(i15)).setPrior(dArr[i15]);
        }
        if (z2) {
            return;
        }
        System.err.print("Failed to remove empty bins\n");
    }

    /**
     * Records a description length as the best one known.
     *
     * @param d the value to store
     */
    public void setDLBest(double d) {
        DL_best = d;
    }

    /**
     * @return the value last stored through {@code setDLBest}, or 0 if it was never called
     */
    public double getDLBest() {
        return DL_best;
    }

    /**
     * Stores a distribution in the {@code m_best} field.
     *
     * @param enumDistrib the distribution to keep
     */
    public void setMBest(EnumDistrib enumDistrib) {
        m_best = enumDistrib;
    }

    /**
     * @return the distribution last stored through {@code setMBest}, or {@code null} if none was
     */
    public EnumDistrib getMBest() {
        return m_best;
    }

    /**
     * Entropy of a probability vector, with logarithms taken to base {@code dArr.length}.
     * <p>
     * Each entry is increased by 1e-4 before use so that zero entries do not produce
     * {@code log(0)}. Vectors with fewer than two entries return 0: with one entry the
     * base would be 1, whose logarithm is 0, and the division would yield an infinite result.
     *
     * @param dArr the probabilities
     * @return the smoothed, length-normalised entropy; 0 for vectors of length 0 or 1
     */
    public static double getEntropy(double[] dArr) {
        if (dArr.length < 2) {
            return 0.0d;
        }
        double logBase = Math.log(dArr.length);
        double entropy = 0.0d;
        for (double p : dArr) {
            double smoothed = p + 1.0E-4d;
            entropy -= smoothed * (Math.log(smoothed) / logBase);
        }
        return entropy;
    }

    /**
     * Computes a model-complexity term from the data size, the component count and the
     * domain size stored in {@code components} and {@code letters}.
     * <p>
     * The result is {@code mixing + components * perComponent - log(components!)}, where the
     * two terms are as spelled out in the body.
     * NOTE(review): two terms carry a coefficient of exactly 0.0. They look like an integer
     * division such as {@code 1/2} in the original source; they are kept unchanged here
     * because the intended coefficient cannot be recovered from this file.
     *
     * @param iArr count vectors, one row per data point
     * @return the complexity value; not finite when {@code components} or {@code letters} is 0
     */
    public double getComplexity(int[][] iArr) {
        double n = iArr.length;
        double total = 0.0d;
        for (int[] row : iArr) {
            for (int count : row) {
                total += count;
            }
        }
        double meanCount = total / (iArr.length * this.letters);

        double perComponent = (((this.letters / 2.0d) * Math.log(n / this.components))
                + (((this.letters - 1.0d) / 2.0d) * Math.log(meanCount / 2.0d)))
                - GammaDistrib.lgamma(this.letters / 2.0d)
                - (0.0d * Math.log(this.letters - 1.0d));

        double mixing = ((((this.components - 1.0d) / 2.0d) * Math.log(n / 2.0d))
                + (0.0d * Math.log(3.141592653589793d)))
                - GammaDistrib.lgamma(this.components / 2.0d);

        // log(components!) as a sum of logs: the explicit product overflows to Infinity
        // once components exceeds 170.
        double logFactorial = 0.0d;
        for (double j = 2.0d; j <= this.components; j += 1.0d) {
            logFactorial += Math.log(j);
        }

        return (mixing + (this.components * perComponent)) - logFactorial;
    }

    /**
     * Writes one line per component, {@code "D<index>,<component>"}, to the file
     * {@code <str>_alpha_<components>.out}. The component is rendered through its
     * {@code toString()}.
     * <p>
     * An {@code IOException} is logged at SEVERE level and not rethrown.
     *
     * @param str prefix of the output file name
     */
    public void saveAlphas(String str) {
        String fileName = str + "_alpha_" + ((int) this.components) + ".out";
        // try-with-resources closes the writer even when a write fails part-way.
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(fileName))) {
            for (int i = 0; i < this.components; i++) {
                writer.write("D" + i + "," + ((DirichletDistrib) getDistrib(i)));
                writer.newLine();
            }
        } catch (IOException e) {
            Logger.getLogger(DirichletDistrib.class.getName()).log(Level.SEVERE, (String) null, (Throwable) e);
        }
    }

    /**
     * Labels every row with {@link #getLabel(int[])} and writes, for each component, the
     * indices of its rows (one per line) to {@code <str>_bin_<component>_<components>.out}.
     * <p>
     * An {@code IOException} on one file is logged at SEVERE level and the remaining files
     * are still attempted.
     *
     * @param iArr count vectors, one row per data point
     * @param str  prefix of the output file names
     */
    public void saveClusters(int[][] iArr, String str) {
        int nBins = (int) this.components;
        List<List<Integer>> members = new ArrayList<>(nBins);
        for (int b = 0; b < nBins; b++) {
            members.add(new ArrayList<Integer>());
        }
        for (int k = 0; k < iArr.length; k++) {
            members.get(getLabel(iArr[k])).add(Integer.valueOf(k));
        }
        for (int b = 0; b < nBins; b++) {
            String fileName = str + "_bin_" + b + "_" + nBins + ".out";
            // try-with-resources closes the writer even when a write fails part-way.
            try (BufferedWriter writer = new BufferedWriter(new FileWriter(fileName))) {
                for (Integer index : members.get(b)) {
                    writer.write("" + index.intValue());
                    writer.newLine();
                }
            } catch (IOException e) {
                Logger.getLogger(DirichletDistrib.class.getName()).log(Level.SEVERE, (String) null, (Throwable) e);
            }
        }
    }

    /**
     * Command-line driver: for every component count in {@code [strArr[1], strArr[2])} it
     * fits a mixture to the hard-coded data file with learnParametersFlip, saves alphas and
     * clusters, and finally writes {@code count<TAB>(best DL + complexity)} lines to
     * {@code <data file>_comp.out}.
     *
     * @param strArr {@code strArr[1]} is the first component count and {@code strArr[2]} the
     *               exclusive upper bound; {@code strArr[0]} is not read
     * @throws IllegalArgumentException if fewer than three arguments are given
     * @throws IllegalStateException    if the data file cannot be read or has no rows
     * @throws RuntimeException         if the rows do not all have the same length
     */
    public static void main(String[] strArr) {
        final String dataFile = "wgEncodeH1hescSrf_seg20_500_srf_hg19.out";
        if (strArr.length < 3) {
            throw new IllegalArgumentException("Expected 3 arguments (<unused> <firstBins> <endBinsExclusive>) but got " + strArr.length);
        }
        int firstBins = Integer.parseInt(strArr[1]);
        int endBins = Integer.parseInt(strArr[2]);

        int[][] data = loadData(dataFile);
        // loadData reports read failures by returning null.
        if (data == null || data.length == 0) {
            throw new IllegalStateException("No data could be loaded from " + dataFile);
        }
        // Compare every row with the first one; a 0 sentinel would let a zero-length
        // first row slip through.
        int rowLength = data[0].length;
        for (int k = 1; k < data.length; k++) {
            if (data[k].length != rowLength) {
                throw new RuntimeException("Error in data: invalid item at data point " + (k + 1));
            }
        }

        Enumerable enumerable = new Enumerable(rowLength);
        double[] bestDL = new double[endBins - firstBins];
        double[] complexity = new double[endBins - firstBins];
        for (int nBins = firstBins; nBins < endBins; nBins++) {
            System.out.println("nbins = " + nBins);
            MixDirichletDistrib mixture = new MixDirichletDistrib(enumerable, nBins, data);
            mixture.learnParametersFlip(data);
            mixture.saveAlphas(dataFile);
            bestDL[nBins - firstBins] = mixture.getDLBest();
            complexity[nBins - firstBins] = mixture.getComplexity(data);
            mixture.saveClusters(data, dataFile);
        }

        // try-with-resources closes the writer even when a write fails part-way.
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(dataFile + "_comp.out"))) {
            for (int i = 0; i < bestDL.length; i++) {
                writer.write((firstBins + i) + "\t" + (bestDL[i] + complexity[i]));
                writer.newLine();
            }
        } catch (IOException e) {
            Logger.getLogger(DirichletDistrib.class.getName()).log(Level.SEVERE, (String) null, (Throwable) e);
        }
        System.out.println("COMPLETE");
    }

    /**
     * Reads a tab-separated file of integers into a matrix, one row per line.
     * <p>
     * A line containing a field that is not an integer is reported once on stderr as
     * {@code "Ignored: <line>"} and left out of the result; it is no longer stored with
     * zeros in place of the unparsable fields.
     *
     * @param str path of the file to read
     * @return the parsed rows in file order, or {@code null} if the file could not be read
     *         (the {@code IOException} is logged at SEVERE level)
     */
    public static int[][] loadData(String str) {
        int[][] result = null;
        // try-with-resources closes the reader even when reading fails part-way.
        try (BufferedReader reader = new BufferedReader(new FileReader(str))) {
            List<int[]> rows = new ArrayList<>();
            String line;
            while ((line = reader.readLine()) != null) {
                String[] fields = line.split("\t");
                int[] row = new int[fields.length];
                boolean valid = true;
                for (int i = 0; i < fields.length && valid; i++) {
                    try {
                        row[i] = Integer.parseInt(fields[i]);
                    } catch (NumberFormatException e) {
                        System.err.println("Ignored: " + line);
                        valid = false;
                    }
                }
                if (valid) {
                    rows.add(row);
                }
            }
            // Each row is already a fresh array, so no further copying is needed.
            result = rows.toArray(new int[rows.size()][]);
        } catch (IOException e) {
            Logger.getLogger(DirichletDistrib.class.getName()).log(Level.SEVERE, (String) null, (Throwable) e);
        }
        return result;
    }
}
