package bn.ctmc;

import bn.BNode;
import bn.Distrib;
import bn.alg.CGTable;
import bn.alg.Query;
import bn.alg.VarElim;
import bn.ctmc.matrix.JC;
import bn.factor.FactorCache;
import bn.prob.EnumDistrib;
import dat.EnumSeq;
import dat.EnumVariable;
import dat.Enumerable;
import dat.Variable;
import dat.file.FastaWriter;
import dat.file.Newick;
import dat.phylo.PhyloBN;
import dat.phylo.Tree;
import dat.phylo.TreeInstance;
import java.io.File;
import java.io.IOException;
import java.io.PrintStream;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import util.MilliTimer;

/* loaded from: input_file:bn/ctmc/SubstNodeTest.class */
/**
 * Tests for {@link SubstNode}: the conditional probabilities of substitution nodes, and marginal and
 * joint (MPE) inference by variable elimination over chains and random trees built from them.
 *
 * <p>Some tests write trees and alignments for manual inspection. Files go to the directory named
 * by the system property {@code bnkit.test.outdir}, defaulting to {@code java.io.tmpdir}.
 */
class SubstNodeTest {
    SubstModel model = SubstModel.createModel("JTT");
    FactorCache cache = new FactorCache();
    boolean CACHE_FACTORS = true;

    /** Directory that file-writing tests save into; override with -Dbnkit.test.outdir=DIR. */
    private static final String OUTPUT_DIR =
            System.getProperty("bnkit.test.outdir", System.getProperty("java.io.tmpdir"));

    SubstNodeTest() {
    }

    /** Resolves {@code name} against {@link #OUTPUT_DIR}. */
    private static String outputFile(String name) {
        return new File(OUTPUT_DIR, name).getPath();
    }

    /**
     * Saves {@code tree} in {@code Newick.MODE_ANCESTOR} mode to {@code name} under the output
     * directory, failing the test (rather than exiting the JVM) if it cannot be written.
     */
    private static void saveTreeOrFail(Tree tree, String name) {
        String path = outputFile(name);
        try {
            Newick.save(tree, path, Newick.MODE_ANCESTOR);
        } catch (IOException e) {
            Assertions.fail("Failed to save tree to " + path, e);
        }
    }

    /**
     * Builds a chain ("barebone") BN of amino-acid substitution nodes N0 -> N1 -> ... where N0 is
     * the root and every later node is conditioned on its predecessor over a perturbed distance,
     * then sanity-checks every node's conditional probabilities.
     *
     * <p>Each distance is {@code mean * (1 +/- k / modulus)} where {@code mean = totalDistance / nNodes},
     * {@code modulus = nDistinct / 2 + 1} and {@code k} is drawn from {@code 1 .. modulus - 1}. The
     * random generator is seeded with {@code nNodes}, so the chain is reproducible.
     *
     * @param nNodes        number of nodes in the chain
     * @param nDistinct     controls how many distinct distances can occur; must be at least 2
     * @param totalDistance expected total distance spanned by the chain
     * @return the validated chain
     * @throws IllegalArgumentException if {@code nDistinct < 2} (no perturbation could be drawn)
     */
    PhyloBN getBarebone(int nNodes, int nDistinct, double totalDistance) {
        if (nDistinct < 2) {
            throw new IllegalArgumentException("nDistinct must be at least 2, got " + nDistinct);
        }
        SubstNode[] nodes = new SubstNode[nNodes];
        double[] dists = new double[nNodes];
        EnumVariable[] vars = new EnumVariable[nNodes];
        long startTime = System.currentTimeMillis();
        Random rand = new Random(nNodes);
        double meanDist = totalDistance / nNodes;
        int modulus = nDistinct / 2 + 1;
        double sumDist = 0.0d;
        Map<Double, Integer> distCounts = new HashMap<>();
        for (int k = 0; k < nodes.length; k++) {
            // Floating-point division: the numerator is always < modulus, so an integer quotient would be 0
            double fraction = (rand.nextInt(modulus - 1) + 1) / (double) modulus;
            double t = meanDist + fraction * meanDist * (rand.nextBoolean() ? -1 : 1);
            dists[k] = t;
            distCounts.merge(t, 1, Integer::sum);
            vars[k] = new EnumVariable(Enumerable.aacid, "N" + k);
            // N0 is the root: it has no parent, so the distance drawn for it is not used by the node
            nodes[k] = (k == 0)
                    ? new SubstNode(vars[k], this.model)
                    : new SubstNode(vars[k], vars[k - 1], this.model, t);
            sumDist += t;
        }
        PhyloBN bn = PhyloBN.createBarebone(nodes);
        if (this.CACHE_FACTORS) {
            System.out.println("Creating BN stretching " + sumDist + " time units, over " + nNodes
                    + " nodes, of which " + bn.setCache(this.cache) + " are cache-enabled");
        }
        System.out.println("Number of different time distances: " + distCounts.size() + " \t");
        for (Map.Entry<Double, Integer> entry : distCounts.entrySet()) {
            System.out.print(entry.getKey() + ": " + entry.getValue() + ", ");
        }
        System.out.println();
        long elapsed = System.currentTimeMillis() - startTime;
        System.out.println(String.format("Creating %d nodes in %d min, %d s or %d ms", nNodes,
                TimeUnit.MILLISECONDS.toMinutes(elapsed),
                TimeUnit.MILLISECONDS.toSeconds(elapsed) % 60,
                elapsed));
        Assertions.assertTrue(bn.isValid());

        Object[] values = this.model.getDomain().getValues();
        // P(child == parent) at the previous node, per state; starts at 1.0
        Map<Object, Double> prevSame = new HashMap<>();
        for (Object v : values) {
            prevSame.put(v, 1.0d);
        }
        int nBNNodes = bn.getBN().getNodes().size();
        for (int k = 0; k < nBNNodes; k++) {
            BNode node = bn.getBNode(k);
            for (Object parentVal : values) {
                double pSame = node.isRoot()
                        ? node.get(parentVal).doubleValue()
                        : node.get(parentVal, parentVal).doubleValue();
                // From N2 on, compare with the previous non-root node: a shorter (or equal) distance
                // must not make staying in the same state less likely, and a longer one must
                if (k >= 2) {
                    if (dists[k] <= dists[k - 1]) {
                        Assertions.assertTrue(prevSame.get(parentVal) <= pSame);
                    } else {
                        Assertions.assertTrue(prevSame.get(parentVal) > pSame);
                    }
                }
                double pOther = 0.0d;
                for (Object childVal : values) {
                    if (!childVal.equals(parentVal)) {
                        double p = node.isRoot()
                                ? node.get(childVal).doubleValue()
                                : node.get(childVal, parentVal).doubleValue();
                        Assertions.assertTrue(p <= 1.0d && p >= 0.0d);
                        pOther += p;
                    }
                }
                // The conditional distribution over all child states must sum to ~1
                Assertions.assertTrue(pOther + pSame > 0.99d && pOther + pSame < 1.01d);
                prevSame.put(parentVal, pSame);
            }
        }
        return bn;
    }

    /**
     * Creates a random tree via {@code Tree.Random} with leaves labelled S000, S001, ...
     *
     * @param nLeaves number of leaves
     * @param seed    random seed passed to {@code Tree.Random}
     */
    Tree getTree(int nLeaves, long seed) {
        String[] leafNames = new String[nLeaves];
        for (int k = 0; k < nLeaves; k++) {
            leafNames[k] = String.format("S%03d", k);
        }
        return Tree.Random(leafNames, seed, 1.1d, 5.0d, 2, 2);
    }

    @Test
    void getProb() {
        getBarebone(200, 10, 5.0d).getBNode(200 - 1).print();
    }

    /** Marginal of the first node of a 200-node chain, given evidence on the last node. */
    @Test
    void getProb_Marginal1() {
        MilliTimer timer = new MilliTimer();
        PhyloBN chain = getBarebone(200, 10, 5.0d);
        BNode first = chain.getBNode(0);
        BNode last = chain.getBNode(200 - 1);
        for (Object evidence : new Object[]{'A', 'C', 'D', 'R', 'P'}) {
            timer.start("Reset");
            chain.getBN().resetNodes();
            timer.stopStart("Reset", "setInst");
            last.setInstance(evidence);
            timer.stopStart("setInst", "VE inst");
            VarElim ve = new VarElim();
            ve.instantiate(chain.getBN());
            timer.stopStart("VE inst", "makeQuery");
            Query query = ve.makeQuery(first.getVariable());
            timer.stopStart("makeQuery", "Infer");
            CGTable result = (CGTable) ve.infer(query);
            timer.stopStart("Infer", "Distrib");
            Distrib marginal = result.query(first.getVariable());
            timer.stop("Distrib");
            System.out.println(first + " is " + marginal);
            System.out.println("Cache size: " + this.cache.size());
        }
        timer.report();
        this.cache.reportCache();
    }

    /** MPE assignment of a 200-node chain given evidence on the last node; prints the N0 value. */
    @Test
    void getProb_Joint1() {
        MilliTimer timer = new MilliTimer();
        PhyloBN chain = getBarebone(200, 10, 5.0d);
        BNode last = chain.getBNode(200 - 1);
        for (Object evidence : new Object[]{'W', 'C', 'D', 'R', 'P'}) {
            timer.start("Reset");
            chain.getBN().resetNodes();
            timer.stopStart("Reset", "setInst");
            last.setInstance(evidence);
            timer.stopStart("setInst", "VE inst");
            VarElim ve = new VarElim();
            ve.instantiate(chain.getBN());
            timer.stopStart("VE inst", "makeMPE");
            Query mpeQuery = ve.makeMPE(new Variable[0]);
            timer.stopStart("makeMPE", "Infer");
            CGTable result = (CGTable) ve.infer(mpeQuery);
            ve.timer.report();
            timer.stopStart("Infer", "Assign");
            Variable.Assignment[] mpe = result.getMPE();
            timer.stop("Assign");
            for (Variable.Assignment assignment : mpe) {
                if (assignment.var.getName().equals("N0")) {
                    System.out.println("\t" + assignment);
                }
            }
            System.out.println("Cache size: " + this.cache.size());
        }
        timer.report();
        this.cache.reportCache();
    }

    /** Repeated marginal queries of BN node 0 on a 20-leaf tree with random leaf states. */
    @Test
    void getProb_Marginal2() {
        MilliTimer timer = new MilliTimer();
        Random rand = new Random(1L);
        Object[] values = this.model.getDomain().getValues();
        Tree tree = getTree(20, 1L);
        saveTreeOrFail(tree, "marg20m.nwk");
        PhyloBN pbn = PhyloBN.create(tree, this.model);
        System.out.println("Cached nodes: " + pbn.setCache(this.cache) + " of " + pbn.getBN().getNodes().size());
        for (int rep = 0; rep < 100; rep++) {
            timer.start("Reset");
            pbn.getBN().resetNodes();
            timer.stopStart("Reset", "setInst");
            for (int leaf : tree.getLeaves()) {
                pbn.getBN().getNode(tree.getLabel(leaf).toString()).setInstance(values[rand.nextInt(values.length)]);
            }
            timer.stopStart("setInst", "VE inst");
            VarElim ve = new VarElim();
            ve.instantiate(pbn.getBN());
            timer.stopStart("VE inst", "makeQuery");
            BNode target = pbn.getBNode(0);
            Query query = ve.makeQuery(target.getVariable());
            timer.stopStart("makeQuery", "Infer");
            CGTable result = (CGTable) ve.infer(query);
            timer.stopStart("Infer", "Distrib");
            Distrib marginal = result.query(target.getVariable());
            timer.stop("Distrib");
            System.out.println(target + " is " + marginal);
            System.out.println("Cache size: " + this.cache.size());
        }
        timer.report();
        this.cache.reportCache();
    }

    /**
     * For random trees with 2..99 leaves, compares each ancestor's value in the joint (MPE)
     * reconstruction with its marginal distribution. The rank of the joint value is the number of
     * states whose marginal probability is at least that of the joint value. Trees whose average
     * rank exceeds 1.5 are saved for inspection; the mean of the per-tree averages must stay below 2.
     */
    @Test
    void compare_Joint_Marg() {
        final int minLeaves = 2;
        final int maxLeaves = 100;
        Random rand = new Random(1L);
        Object[] values = this.model.getDomain().getValues();
        double sumAvgRank = 0.0d;
        for (int nLeaves = minLeaves; nLeaves < maxLeaves; nLeaves++) {
            int nAncestors = 0;
            int nSame = 0;
            int sumRank = 0;
            Tree tree = getTree(nLeaves, 1 + nLeaves);
            // Two BNs over the same tree: one for marginal queries, one for the joint (MPE) query
            PhyloBN margBN = PhyloBN.create(tree, this.model);
            margBN.setCache(this.cache);
            PhyloBN jointBN = PhyloBN.create(tree, this.model);
            jointBN.setCache(this.cache);
            Object[] nodeLabels = new Object[tree.getSize()];
            Object[][] leafSeqs = new Object[tree.getNLeaves()][1];
            String[] leafNames = new String[tree.getNLeaves()];
            int leafPos = 0;
            for (int idx : tree.getLeaves()) {
                Object val = values[rand.nextInt(values.length)];
                nodeLabels[idx] = val;
                margBN.getBNode(idx).setInstance(val);
                jointBN.getBNode(idx).setInstance(val);
                leafSeqs[leafPos][0] = val;
                leafNames[leafPos] = tree.getLabel(idx).toString();
                leafPos++;
            }
            VarElim margVE = new VarElim();
            margVE.instantiate(margBN.getBN());
            VarElim jointVE = new VarElim();
            jointVE.instantiate(jointBN.getBN());
            Map<Variable, Object> mpe = Variable.Assignment.toMap(
                    ((CGTable) jointVE.infer(jointVE.makeMPE(new Variable[0]))).getMPE());
            for (int idx : tree.getAncestors()) {
                BNode jointNode = jointBN.getBNode(idx);
                tree.getBranchPoint(idx).setLabel(jointNode.getName());
                Object jointVal = mpe.get(jointNode.getVariable());
                nodeLabels[idx] = jointVal;
                BNode margNode = margBN.getBNode(idx);
                EnumDistrib marginal = (EnumDistrib) ((CGTable) margVE.infer(margVE.makeQuery(margNode.getVariable())))
                        .query(margNode.getVariable());
                double pJoint = marginal.get(jointVal);
                Object margMax = values[marginal.getMaxIndex()];
                nAncestors++;
                for (int k = 0; k < marginal.getDomain().size(); k++) {
                    if (marginal.get(k) >= pJoint) {
                        sumRank++;
                    }
                }
                if (jointVal.equals(margMax)) {
                    nSame++;
                } else {
                    // Label disagreeing ancestors as "joint_marginal" in the saved tree instance
                    nodeLabels[idx] = jointVal.toString() + "_" + margMax.toString();
                }
            }
            // Floating-point division: integer division would truncate the average rank
            double avgRank = (double) sumRank / nAncestors;
            sumAvgRank += avgRank;
            System.out.println("With " + nLeaves + ":\t" + nSame + " are same from " + nAncestors
                    + " with average rank at " + avgRank);
            if (avgRank > 1.5d) {
                TreeInstance treeInstance = new TreeInstance(tree, nodeLabels);
                try {
                    Newick.save(tree, outputFile("tst_" + nLeaves + ".nwk"), Newick.MODE_ANCESTOR);
                    FastaWriter fastaWriter = new FastaWriter(outputFile("tst_" + nLeaves + ".fa"));
                    try {
                        fastaWriter.save(leafNames, leafSeqs);
                    } finally {
                        fastaWriter.close();
                    }
                    Newick.save(treeInstance, outputFile("tstjoint_" + nLeaves + ".nwk"));
                } catch (IOException e) {
                    // Best effort: the files are only for manual inspection
                    System.err.println("Failed to save tree instance: " + e);
                }
            }
        }
        double meanAvgRank = sumAvgRank / (maxLeaves - minLeaves);
        Assertions.assertTrue(meanAvgRank < 2.0d, "Mean average rank too high: " + meanAvgRank);
    }

    /** Repeated MPE inference on a 20-leaf tree with random leaf states; prints all reconstructions. */
    @Test
    void getProb_Joint2() {
        MilliTimer timer = new MilliTimer();
        Random rand = new Random(1L);
        Object[] values = this.model.getDomain().getValues();
        Tree tree = getTree(20, 1L);
        saveTreeOrFail(tree, "marg20j.nwk");
        PhyloBN pbn = PhyloBN.create(tree, this.model);
        System.out.println("Cached nodes: " + pbn.setCache(this.cache) + " of " + pbn.getBN().getNodes().size());
        final int nReps = 100;
        // states[node index][repetition]
        Object[][] states = new Object[tree.getAncestors().length + tree.getNLeaves()][nReps];
        for (int rep = 0; rep < nReps; rep++) {
            timer.start("Reset");
            pbn.getBN().resetNodes();
            timer.stopStart("Reset", "setInst");
            for (int leaf : tree.getLeaves()) {
                BNode node = pbn.getBN().getNode(tree.getLabel(leaf).toString());
                Object val = values[rand.nextInt(values.length)];
                node.setInstance(val);
                states[leaf][rep] = val;
            }
            timer.stopStart("setInst", "VE inst");
            VarElim ve = new VarElim();
            ve.instantiate(pbn.getBN());
            timer.stopStart("VE inst", "makeMPE");
            Query mpeQuery = ve.makeMPE(new Variable[0]);
            timer.stopStart("makeMPE", "Infer");
            CGTable result = (CGTable) ve.infer(mpeQuery);
            timer.stopStart("Infer", "Assign");
            Variable.Assignment[] mpe = result.getMPE();
            timer.stop("Assign");
            for (Variable.Assignment assignment : mpe) {
                states[tree.getIndex(assignment.var.getName())][rep] = assignment.val;
            }
        }
        timer.report();
        this.cache.reportCache();
        for (int idx = 0; idx < states.length; idx++) {
            System.out.print(((tree.isLeaf(idx) ? "" : "N") + tree.getLabel(idx).toString()) + "  \t[" + idx + "] \t");
            for (int rep = 0; rep < nReps; rep++) {
                System.out.print(states[idx][rep]);
            }
            System.out.println();
        }
    }

    /** A SubstNode serialised with toJSON and restored with fromJSON must serialise identically. */
    @Test
    void fromJSON() {
        Enumerable enumerable = new Enumerable(new Object[]{1, 2, 3});
        EnumVariable child = new EnumVariable(enumerable, "TestChild");
        EnumVariable parent = new EnumVariable(enumerable, "TestParent");
        EnumVariable child2 = new EnumVariable(Enumerable.aacid, "TestChild2");
        EnumVariable parent2 = new EnumVariable(Enumerable.aacid, "TestParent2");
        SubstNode jcNode = new SubstNode(child, parent, new JC(1.0d, enumerable.getValues()), 1.0d);
        SubstNode jttNode = new SubstNode(child2, parent2, SubstModel.createModel("JTT"), 1.0d);
        System.out.println(jcNode.getStateAsText());
        System.out.println(jcNode.toJSON());
        System.out.println(jttNode.toJSON());
        Assertions.assertEquals(jcNode.toJSON().toString(),
                SubstNode.fromJSON(jcNode.toJSON(), child, parent).toJSON().toString());
        Assertions.assertEquals(jttNode.toJSON().toString(),
                SubstNode.fromJSON(jttNode.toJSON(), child2, parent2).toJSON().toString());
    }
}
