package bn.example;

import bn.BNet;
import bn.Distrib;
import bn.Predef;
import bn.alg.CGTable;
import bn.alg.EM;
import bn.alg.Query;
import bn.alg.VarElim;
import bn.file.BNBuf;
import bn.node.CPT;
import bn.node.DirDT;
import bn.prob.DirichletDistrib;
import bn.prob.EnumDistrib;
import dat.Continuous;
import dat.EnumVariable;
import dat.Enumerable;
import dat.IntegerSeq;
import dat.Variable;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.apache.maven.surefire.shared.compress.archivers.tar.TarArchiveEntry;

/**
 * Examples of using {@link DirDT} nodes in a {@link BNet}.
 *
 * <p>A {@link DirDT} node holds a distribution-valued variable (created with
 * {@link Predef#Distrib}) conditioned on an enumerable parent. Each parent value
 * maps to a {@link DirichletDistrib}. Observations are given either as an
 * {@link EnumDistrib} or as an {@link IntegerSeq} of counts.
 *
 * <p>{@link #main} clusters count vectors read from a tab-separated file.
 * {@link #main0} is a self-contained demo on synthetic data.
 */
public class DirDTExample {

    /** Data file used when no path is supplied as the second command-line argument. */
    private static final String DEFAULT_DATA_FILE =
            "/Users/uqbbalde/Desktop/Uni_Studies/myJava/bnkit/bnkit/data/mm10_Mixed_NfiX_segmented20_100.out";

    /** Number of clusters used when a fresh model is trained (no network given). */
    private static final int DEFAULT_CLUSTERS = 9;

    private static final Logger LOG = Logger.getLogger(DirDTExample.class.getName());

    /**
     * Clusters integer count vectors with a Cluster -> Segments mixture and writes one
     * output file per cluster.
     *
     * <p>Every row of the data file must have the same number of counts; that width
     * becomes the size of the "Segments" domain.
     *
     * <p>If {@code strArr[0]} is given, the network is loaded from that file with
     * {@link BNBuf#load}; it must contain a CPT node "Cluster" and a DirDT node
     * "Segments". Otherwise a network with {@value #DEFAULT_CLUSTERS} clusters is trained
     * with EM and saved next to the data file with an ".xml" suffix.
     *
     * <p>Each row is assigned to its most probable cluster. Cluster {@code c} is written
     * to {@code <dataFile>_bin_<c>.out}. Each output line holds the row index, the raw
     * counts, the row index again, and the counts converted to an {@link EnumDistrib},
     * all tab-separated.
     *
     * @param strArr optional {@code [networkFile [dataFile]]}; {@code dataFile} defaults
     *     to {@link #DEFAULT_DATA_FILE}
     * @throws RuntimeException if the data rows do not all have the same length
     */
    public static void main(String[] strArr) {
        String dataFile = strArr.length > 1 ? strArr[1] : DEFAULT_DATA_FILE;
        int[][] rows = loadData(dataFile);
        if (rows == null || rows.length == 0) {
            System.err.println("No data loaded from " + dataFile);
            return;
        }

        int width = rows[0].length;
        Object[][] data = new Object[rows.length][1];
        for (int r = 0; r < rows.length; r++) {
            if (rows[r].length != width) {
                throw new RuntimeException("Error in data: invalid item at data point " + (r + 1));
            }
            data[r][0] = new IntegerSeq(rows[r]);
        }

        BNet bNet;
        EnumVariable clusterVar;
        DirDT segmentsNode;
        if (strArr.length > 0) {
            bNet = BNBuf.load(strArr[0]);
            clusterVar = ((CPT) bNet.getNode("Cluster")).getVariable();
            segmentsNode = (DirDT) bNet.getNode("Segments");
        } else {
            System.out.println("xxxxxxxxxxxxx");
            bNet = new BNet();
            clusterVar = Predef.Number(DEFAULT_CLUSTERS, "Cluster");
            Variable<EnumDistrib> segmentsVar = Predef.Distrib(new Enumerable(width), "Segments");
            CPT clusterCpt = new CPT(clusterVar);
            segmentsNode = new DirDT(segmentsVar, clusterVar);
            bNet.add(clusterCpt, segmentsNode);
            new EM(bNet).train(data, new Variable[]{segmentsVar}, 1L);
            clusterCpt.print();
            segmentsNode.print();
            BNBuf.save(bNet, dataFile + ".xml");
        }
        System.out.println("yyyyyyyyyyyy");

        // Size the bins from the network's own Cluster variable. A loaded network
        // may not have exactly DEFAULT_CLUSTERS clusters.
        int nClusters = clusterVar.getDomain().size();
        List<List<Integer>> bins = new ArrayList<List<Integer>>(nClusters);
        for (int c = 0; c < nClusters; c++) {
            bins.add(new ArrayList<Integer>());
        }

        VarElim varElim = new VarElim();
        varElim.instantiate(bNet);
        for (int r = 0; r < data.length; r++) {
            segmentsNode.setInstance(data[r][0]);
            CGTable posterior = (CGTable) varElim.infer(varElim.makeQuery(clusterVar));
            int best = ((EnumDistrib) posterior.query(clusterVar)).getMaxIndex();
            bins.get(best).add(r);
        }

        for (int c = 0; c < nClusters; c++) {
            String binFile = dataFile + "_bin_" + c + ".out";
            // try-with-resources so the file is closed even if a write fails.
            try (BufferedWriter out = new BufferedWriter(new FileWriter(binFile))) {
                for (int r : bins.get(c)) {
                    IntegerSeq seq = (IntegerSeq) data[r][0];
                    out.write(r + "\t");
                    for (Object count : seq.get()) {
                        out.write(String.valueOf(count) + "\t");
                    }
                    out.write(r + "\t");
                    EnumDistrib proportions =
                            new EnumDistrib(new Enumerable(width), IntegerSeq.intArray(seq.get()));
                    for (int k = 0; k < proportions.getDomain().size(); k++) {
                        out.write(proportions.get(k) + "\t");
                    }
                    out.newLine();
                }
            } catch (IOException e) {
                LOG.log(Level.SEVERE, "Could not write " + binFile, e);
            }
        }
    }

    /**
     * Reads a tab-separated file of integers, one row per line.
     *
     * <p>A line containing any field that is not an integer (for example a header or
     * a blank line) is reported once on stderr as {@code "Ignored: <line>"} and skipped.
     * Rows are not required to have equal length.
     *
     * @param str path of the file to read
     * @return the parsed rows in file order, or {@code null} if the file could not be
     *     read (the error is logged)
     */
    public static int[][] loadData(String str) {
        int[][] result = null;
        try (BufferedReader reader = new BufferedReader(new FileReader(str))) {
            List<int[]> rows = new ArrayList<int[]>();
            String line;
            while ((line = reader.readLine()) != null) {
                int[] row = parseRow(line);
                if (row == null) {
                    System.err.println("Ignored: " + line);
                } else {
                    rows.add(row);
                }
            }
            result = rows.toArray(new int[0][]);
        } catch (IOException e) {
            LOG.log(Level.SEVERE, "Could not read data file " + str, e);
        }
        return result;
    }

    /** Parses one tab-separated line of integers; returns null if any field is not an integer. */
    private static int[] parseRow(String line) {
        String[] fields = line.split("\t");
        int[] row = new int[fields.length];
        try {
            for (int k = 0; k < fields.length; k++) {
                row[k] = Integer.parseInt(fields[k].trim());
            }
        } catch (NumberFormatException e) {
            return null;
        }
        return row;
    }

    /**
     * Self-contained demo.
     *
     * <p>It first builds a hand-specified network Gender -> {Colours, Sports} and queries
     * it given an observed colour distribution. It then samples 200 synthetic count
     * vectors from the same Dirichlet parameters, alternating Male/Female. Finally it
     * fits a 10-cluster mixture to them with EM and queries that mixture after 5, 10 and
     * 1010 total EM rounds.
     *
     * @param strArr ignored
     */
    public static void main0(String[] strArr) {
        Random random = new Random(1L);

        EnumVariable gender = Predef.Nominal(new String[]{"Male", "Female"}, "Gender");
        Enumerable colours = new Enumerable(new String[]{"Pink", "Green", "Blue"});
        Enumerable sports = new Enumerable(new String[]{"Netball", "Soccer", "Rugby"});
        Variable<EnumDistrib> colourVar = Predef.Distrib(colours, "Colours");
        Variable<EnumDistrib> sportVar = Predef.Distrib(sports, "Sports");
        CPT genderCpt = new CPT(gender);
        DirDT colourNode = new DirDT(colourVar, gender);
        DirDT sportNode = new DirDT(sportVar, gender);
        genderCpt.put(new EnumDistrib(new Enumerable(new String[]{"Male", "Female"}), 0.49d, 0.51d));
        colourNode.put(new DirichletDistrib(colours, 3.0d, 5.0d, 7.0d), "Male");
        colourNode.put(new DirichletDistrib(colours, 7.0d, 2.0d, 3.0d), "Female");
        sportNode.put(new DirichletDistrib(sports, 2.0d, 5.0d, 5.0d), "Male");
        sportNode.put(new DirichletDistrib(sports, 9.0d, 4.0d, 3.0d), "Female");
        BNet bNet = new BNet();
        bNet.add(genderCpt, colourNode, sportNode);

        // Observe a colour distribution, then query Gender and Sports jointly.
        colourNode.setInstance(new EnumDistrib(colours, 0.3d, 0.3d, 0.4d));
        VarElim varElim = new VarElim();
        varElim.instantiate(bNet);
        CGTable cGTable = (CGTable) varElim.infer(varElim.makeQuery(gender, sportVar));
        cGTable.display();
        Distrib sportPosterior = cGTable.query(sportVar);
        System.out.println("Prob of sports: " + String.valueOf(sportPosterior));

        // Draw a few samples from the Sports marginal and report their element-wise mean.
        int nSamples = 20;
        double[] mean = new double[sports.size()];
        for (int s = 0; s < nSamples; s++) {
            EnumDistrib sample = (EnumDistrib) sportPosterior.sample();
            for (int k = 0; k < sports.size(); k++) {
                mean[k] += sample.get(k) / nSamples;
            }
            System.out.println("\t" + (s + 1) + "\t" + String.valueOf(sample));
        }
        System.out.print("\tMean\t");
        for (double m : mean) {
            System.out.print(String.format("%5.2f", m));
        }
        System.out.println();
        System.out.println("Prob of gender: " + String.valueOf(cGTable.query(gender)));

        // Synthetic training data: column 0 (Cluster) is left null, columns 1 and 2 hold
        // sampled counts. Even rows use the "Male" parameters, odd rows the "Female" ones.
        DirichletDistrib maleColours = new DirichletDistrib(colours, 3.0d, 5.0d, 7.0d);
        DirichletDistrib femaleColours = new DirichletDistrib(colours, 7.0d, 2.0d, 3.0d);
        DirichletDistrib maleSports = new DirichletDistrib(sports, 2.0d, 5.0d, 5.0d);
        DirichletDistrib femaleSports = new DirichletDistrib(sports, 9.0d, 4.0d, 3.0d);
        Object[][] data = new Object[200][3];
        for (int n = 0; n < data.length; n++) {
            boolean male = n % 2 == 0;
            data[n][1] = sampleCounts(male ? maleColours : femaleColours, random);
            data[n][2] = sampleCounts(male ? maleSports : femaleSports, random);
        }

        // A 10-cluster mixture over the same two distribution-valued variables.
        EnumVariable cluster = Predef.Number(10, "Cluster");
        Variable<EnumDistrib> colourVar2 = Predef.Distrib(colours, "Colours");
        Variable<EnumDistrib> sportVar2 = Predef.Distrib(sports, "Sports");
        CPT clusterCpt = new CPT(cluster);
        DirDT colourByCluster = new DirDT(colourVar2, cluster);
        DirDT sportByCluster = new DirDT(sportVar2, cluster);
        BNet mixture = new BNet();
        mixture.add(clusterCpt, colourByCluster, sportByCluster);
        EM em = new EM(mixture);
        Variable[] trainVars = {cluster, colourVar2, sportVar2};

        System.out.println("Before EM");
        clusterCpt.randomize(1L);
        clusterCpt.print();
        colourByCluster.randomize(1L);
        colourByCluster.print();
        sportByCluster.randomize(1L);
        sportByCluster.print();

        em.EM_MAX_ROUNDS = 5;
        em.train(data, trainVars, 0L);
        System.out.println("After EM (1)");
        clusterCpt.print();
        colourByCluster.print();
        sportByCluster.print();
        colourByCluster.setInstance(new EnumDistrib(colours, 0.1d, 0.1d, 0.8d));
        VarElim ve = new VarElim();
        ve.instantiate(mixture);
        showQuery(ve, cluster, sportVar2);
        colourByCluster.setInstance(new EnumDistrib(colours, 0.8d, 0.1d, 0.1d));
        showQuery(ve, cluster, sportVar2);

        em.train(data, trainVars, 0L);
        System.out.println("After EM (2)");
        clusterCpt.print();
        colourByCluster.print();
        sportByCluster.print();
        colourByCluster.setInstance(IntegerSeq.intSeq(new int[]{1, 1, 8}));
        ve = new VarElim();
        ve.instantiate(mixture);
        showQuery(ve, cluster, sportVar2);
        colourByCluster.setInstance(IntegerSeq.intSeq(new int[]{8, 1, 1}));
        showQuery(ve, cluster, sportVar2);

        // 1000 rounds (the decompiled source referenced an unrelated library constant
        // that merely happens to equal 1000).
        em.EM_MAX_ROUNDS = 1000;
        em.train(data, trainVars, 0L);
        System.out.println("After EM (3)");
        clusterCpt.print();
        colourByCluster.print();
        sportByCluster.print();
        colourByCluster.setInstance(IntegerSeq.intSeq(new int[]{1, 1, 8}));
        ve = new VarElim();
        ve.instantiate(mixture);
        showQuery(ve, cluster, sportVar2);
        colourByCluster.setInstance(IntegerSeq.intSeq(new int[]{8, 1, 1}));
        showQuery(ve, cluster, sportVar2);
    }

    /**
     * Draws a probability vector from {@code prior}, then draws a random number (0-99)
     * of values from it. Returns how often each domain value was drawn.
     */
    private static IntegerSeq sampleCounts(DirichletDistrib prior, Random random) {
        EnumDistrib probs = (EnumDistrib) prior.sample();
        int[] counts = new int[probs.getDomain().size()];
        // Draw the sample size once. Calling nextInt in the loop condition would
        // redraw the bound on every iteration and skew the counts toward zero.
        int draws = random.nextInt(100);
        for (int d = 0; d < draws; d++) {
            counts[probs.getDomain().getIndex(probs.sample())]++;
        }
        IntegerSeq seq = new IntegerSeq(new Continuous());
        seq.set(IntegerSeq.objArray(counts));
        return seq;
    }

    /**
     * Prints the query over {@code first} and {@code second}, then displays the table
     * inferred for it.
     */
    private static void showQuery(VarElim inference, Variable first, Variable second) {
        Query query = inference.makeQuery(first, second);
        System.out.println(query);
        ((CGTable) inference.infer(query)).display();
    }
}
