package asr;

import asr.MaxLhoodJoint;
import bn.ctmc.SubstModel;
import bn.ctmc.matrix.JC;
import bn.prob.EnumDistrib;
import bn.prob.GaussianDistrib;
import dat.Enumerable;
import dat.file.Newick;
import dat.phylo.BranchPoint;
import dat.phylo.IdxTree;
import dat.phylo.PhyloBN;
import dat.phylo.Tree;
import dat.phylo.TreeInstance;
import java.util.Iterator;
import java.util.Random;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import util.MilliTimer;

/**
 * Tests for {@link MaxLhoodJoint}, joint maximum-likelihood reconstruction of ancestral states.
 *
 * <p>Small fixed trees are checked with symbolic leaf states (plain substitution model) and with real-valued
 * leaf observations emitted through Gaussian GDTs. {@link #joint1()} additionally cross-checks the inference
 * against an exhaustive search over all ancestral assignments on random trees.
 */
class MaxLhoodJointTest {
    /** Four-leaf tree with labelled ancestors Root, Anc_left and Anc_right. */
    final Tree tree = Newick.parse("((Leaf1:0.02,Leaf2:0.12)Anc_left:0.19,(Leaf3:0.18,Leaf4:0.15)Anc_right:0.17)Root;");
    /** Three-leaf tree; same topology as {@link #mini2} but different branch lengths. */
    final Tree mini1 = Newick.parse("((Leaf1:0.09,Leaf2:0.11)N1:0.07,Leaf3:0.12)N0");
    /** Three-leaf tree; same topology as {@link #mini1} but different branch lengths. */
    final Tree mini2 = Newick.parse("((Leaf1:0.03,Leaf2:0.05)N1:0.13,Leaf3:0.12)N0");
    final Object A = 'A';
    final Object C = 'C';
    final Object G = 'G';
    final Object T = 'T';

    /**
     * Builds an {@link IdxTree} by hand with the same topology and labels as {@link #tree}.
     * Note that several branch lengths differ from {@link #tree} (e.g. Leaf1 is 0.22 here vs 0.02 there).
     *
     * @return the tree, with branch points ordered Root, Anc_left, Leaf1, Leaf2, Anc_right, Leaf3, Leaf4
     */
    IdxTree createTree() {
        BranchPoint root = new BranchPoint("Root");
        BranchPoint ancLeft = new BranchPoint("Anc_left", root, Double.valueOf(0.21d));
        root.addChild(ancLeft);
        BranchPoint ancRight = new BranchPoint("Anc_right", root, Double.valueOf(0.17d));
        root.addChild(ancRight);
        BranchPoint leaf1 = new BranchPoint("Leaf1", ancLeft, Double.valueOf(0.22d));
        BranchPoint leaf2 = new BranchPoint("Leaf2", ancLeft, Double.valueOf(0.12d));
        ancLeft.addChild(leaf1);
        ancLeft.addChild(leaf2);
        BranchPoint leaf3 = new BranchPoint("Leaf3", ancRight, Double.valueOf(0.09d));
        BranchPoint leaf4 = new BranchPoint("Leaf4", ancRight, Double.valueOf(0.15d));
        ancRight.addChild(leaf3);
        ancRight.addChild(leaf4);
        return new IdxTree(new BranchPoint[]{root, ancLeft, leaf1, leaf2, ancRight, leaf3, leaf4});
    }

    /**
     * Builds a two-state (A/C) Jukes-Cantor PhyloBN on {@code t} whose nodes carry Gaussian GDTs,
     * parameterised (1.45, 0.1) for state A and (2.55, 0.1) for state C.
     */
    private PhyloBN twoStateGdtModel(Tree t) {
        PhyloBN pbn = PhyloBN.withGDTs(t, new JC(1.0d, new Object[]{this.A, this.C}), 1.0d, true, 1L);
        pbn.setMasterGDT(new Object[]{this.A, this.C},
                new GaussianDistrib[]{new GaussianDistrib(1.45d, 0.1d), new GaussianDistrib(2.55d, 0.1d)});
        return pbn;
    }

    /** JC model on {@link #tree} with A and C observed at array positions 2 and 5. */
    @Test
    void infer1() {
        MaxLhoodJoint joint = new MaxLhoodJoint(PhyloBN.create(this.tree, SubstModel.createModel("JC")));
        TreeInstance observed = new TreeInstance(this.tree, new Object[]{null, null, this.A, null, null, this.C, null});
        MaxLhoodJoint.Inference inference = joint.infer(observed);
        System.out.println("Input: \t" + observed);
        System.out.println("Output:\t" + inference);
        TreeInstance result = inference.getTreeInstance();
        Assertions.assertEquals(this.A, result.getInstance(this.tree.getIndex("Leaf1")));
        Assertions.assertEquals(this.A, result.getInstance(this.tree.getIndex(1)));
        Assertions.assertEquals(this.C, result.getInstance(this.tree.getIndex(2)));
        Assertions.assertEquals(this.A, result.getInstance(this.tree.getIndex(0)));
    }

    /**
     * Two-state JC on {@link #mini1} and {@link #mini2}, with the same leaf observations on both.
     * The different branch lengths should flip the root's joint-ML state from A (mini1) to C (mini2).
     * Marginal reconstructions are printed for comparison only.
     */
    @Test
    void infer1b() {
        JC jc = new JC(1.0d, new Object[]{this.A, this.C});

        PhyloBN pbn1 = PhyloBN.create(this.mini1, jc);
        MaxLhoodJoint mlJoint1 = new MaxLhoodJoint(pbn1);
        MaxLhoodMarginal marginalAt0 = new MaxLhoodMarginal(0, pbn1);
        MaxLhoodMarginal marginalAt1 = new MaxLhoodMarginal(1, pbn1);
        TreeInstance observed1 = new TreeInstance(this.mini1, new Object[]{null, null, this.A, null, this.C});
        MaxLhoodJoint.Inference inference1 = mlJoint1.infer(observed1);
        marginalAt0.decorate(observed1);
        marginalAt1.decorate(observed1);
        System.out.println(marginalAt0.getDecoration(0));
        System.out.println(marginalAt1.getDecoration(1));
        System.out.println("Input: \t" + observed1);
        System.out.println("Output:\t" + inference1);

        PhyloBN pbn2 = PhyloBN.create(this.mini2, jc);
        MaxLhoodJoint mlJoint2 = new MaxLhoodJoint(pbn2);
        MaxLhoodMarginal marginal2At0 = new MaxLhoodMarginal(0, pbn2);
        TreeInstance observed2 = new TreeInstance(this.mini2, new Object[]{null, null, this.A, null, this.C});
        MaxLhoodJoint.Inference inference2 = mlJoint2.infer(observed2);
        marginal2At0.decorate(observed2);
        System.out.println(marginal2At0.getDecoration(0));
        System.out.println("Input: \t" + observed2);
        System.out.println("Output:\t" + inference2);

        // Resolve indices through the tree each instance was built on (previously used this.tree by mistake).
        Assertions.assertEquals(this.A, inference1.getTreeInstance().getInstance(this.mini1.getIndex(0)));
        Assertions.assertEquals(this.C, inference2.getTreeInstance().getInstance(this.mini2.getIndex(0)));
        Assertions.assertEquals(this.A, inference1.getTreeInstance().getInstance(this.mini1.getIndex(1)));
        Assertions.assertEquals(this.A, inference2.getTreeInstance().getInstance(this.mini2.getIndex(1)));
    }

    /** Same scenario as {@link #infer1b()}, but the leaves are real values emitted through Gaussian GDTs. */
    @Test
    void infer2b() {
        PhyloBN pbn1 = twoStateGdtModel(this.mini1);
        MaxLhoodJoint mlJoint1 = new MaxLhoodJoint(pbn1);
        MaxLhoodMarginal marginalAt0 = new MaxLhoodMarginal(0, pbn1);
        TreeInstance observed1 = new TreeInstance(this.mini1, new Object[]{null, null, Double.valueOf(1.45d), null, Double.valueOf(2.55d)});
        MaxLhoodJoint.Inference inference1 = mlJoint1.infer(observed1);
        marginalAt0.decorate(observed1);
        System.out.println(marginalAt0.getDecoration(0));

        PhyloBN pbn2 = twoStateGdtModel(this.mini2);
        MaxLhoodJoint mlJoint2 = new MaxLhoodJoint(pbn2);
        TreeInstance observed2 = new TreeInstance(this.mini2, new Object[]{null, null, Double.valueOf(1.45d), null, Double.valueOf(2.55d)});
        MaxLhoodJoint.Inference inference2 = mlJoint2.infer(observed2);

        System.out.println("Input: \t" + observed1);
        System.out.println("Output:\t" + inference1);
        System.out.println("Input: \t" + observed2);
        System.out.println("Output:\t" + inference2);

        // Resolve indices through the tree each instance was built on (previously used this.tree by mistake).
        Assertions.assertEquals(this.A, inference1.getTreeInstance().getInstance(this.mini1.getIndex(0)));
        Assertions.assertEquals(this.C, inference2.getTreeInstance().getInstance(this.mini2.getIndex(0)));
        Assertions.assertEquals(this.A, inference1.getTreeInstance().getInstance(this.mini1.getIndex(1)));
        Assertions.assertEquals(this.A, inference2.getTreeInstance().getInstance(this.mini2.getIndex(1)));
    }

    /** GDT model on {@link #tree} with all four leaves observed as real values. */
    @Test
    void infer3a() {
        PhyloBN pbn = twoStateGdtModel(this.tree);
        MaxLhoodJoint joint = new MaxLhoodJoint(pbn);
        MaxLhoodMarginal marginalAt0 = new MaxLhoodMarginal(0, pbn);
        TreeInstance observed = new TreeInstance(this.tree, new Object[]{null, null, Double.valueOf(1.5d), Double.valueOf(1.4d), null, Double.valueOf(2.5d), Double.valueOf(2.6d)});
        MaxLhoodJoint.Inference inference = joint.infer(observed);
        marginalAt0.decorate(observed);
        System.out.println(marginalAt0.getDecoration(0));
        System.out.println("Input: \t" + observed);
        System.out.println("Output:\t" + inference);
        TreeInstance result = inference.getTreeInstance();
        // JUnit convention: expected value first, actual second.
        Assertions.assertEquals(this.A, result.getInstance(this.tree.getIndex(1)));
        Assertions.assertEquals(this.C, result.getInstance(this.tree.getIndex(2)));
        Assertions.assertEquals(this.C, result.getInstance(this.tree.getIndex(0)));
    }

    /** GDT model on {@link #tree} with only two leaves (positions 2 and 6) observed. */
    @Test
    void infer3b() {
        PhyloBN pbn = twoStateGdtModel(this.tree);
        MaxLhoodJoint joint = new MaxLhoodJoint(pbn);
        MaxLhoodMarginal marginalAt0 = new MaxLhoodMarginal(0, pbn);
        TreeInstance observed = new TreeInstance(this.tree, new Object[]{null, null, Double.valueOf(1.5d), null, null, null, Double.valueOf(2.6d)});
        MaxLhoodJoint.Inference inference = joint.infer(observed);
        marginalAt0.decorate(observed);
        System.out.println(marginalAt0.getDecoration(0));
        System.out.println("Input: \t" + observed);
        System.out.println("Output:\t" + inference);
        TreeInstance result = inference.getTreeInstance();
        // JUnit convention: expected value first, actual second.
        Assertions.assertEquals(this.A, result.getInstance(this.tree.getIndex(1)));
        Assertions.assertEquals(this.C, result.getInstance(this.tree.getIndex(2)));
        Assertions.assertEquals(this.C, result.getInstance(this.tree.getIndex(0)));
    }

    /**
     * Cross-checks joint ML inference against exhaustive search on random trees with 3 to 7 leaves
     * (100 trees per size).
     *
     * <p>Random nucleotides are placed at the leaves, and every ancestral assignment is scored by
     * {@link #jointProb}. The best assignment must agree with {@link MaxLhoodJoint} in three settings:
     * (1) the plain substitution model, (2) a PhyloBN whose leaves are observed as lower-case symbols
     * through CPTs, and (3) a PhyloBN whose leaves are observed as real values through GDTs.
     * Timings are reported at the end.
     */
    @Test
    void joint1() {
        SubstModel model = SubstModel.createModel("JC");
        Enumerable domain = model.getDomain();
        Enumerable alphabet = new Enumerable(domain.getValues());
        MilliTimer timer = new MilliTimer();
        for (int nLeaves = 3; nLeaves < 8; nLeaves++) {
            timer.start("Leaves=" + nLeaves);
            for (int seed = 0; seed < 100; seed++) {
                Tree rtree = Tree.Random(nLeaves, seed, 1.1d, 1.0d / 0.2d, 2, 2);
                rtree.adjustDistances(1.0d);
                int[] ancestors = rtree.getAncestors();
                int[] leaves = rtree.getLeaves();
                Random rng = new Random(seed);

                // Random leaf states; reused (re-encoded) for the CPT and GDT variants below.
                Object[] leafStates = new Object[leaves.length];
                Object[] observed = new Object[rtree.getSize()];
                for (int k = 0; k < leaves.length; k++) {
                    leafStates[k] = alphabet.get(rng.nextInt(alphabet.size()));
                    observed[leaves[k]] = leafStates[k];
                }

                // (1) Plain substitution model.
                timer.start("BNKit_ML_Leaves=" + nLeaves);
                MaxLhoodJoint.Inference plain = new MaxLhoodJoint(rtree, model).infer(new TreeInstance(rtree, observed));
                timer.stop("BNKit_ML_Leaves=" + nLeaves);
                TreeInstance plainResult = plain.getTreeInstance();

                // Exhaustive search: enumerate every word over the alphabet with one symbol per ancestor.
                timer.start("Naive_ML_Leaves=" + nLeaves);
                int nWords = (int) Math.pow(alphabet.size(), ancestors.length);
                Object[][] words = new Object[nWords][];
                for (int key = 0; key < nWords; key++) {
                    words[key] = alphabet.getWord4Key(key, ancestors.length);
                }
                // Work on a copy so the array already wrapped by the inference input is not mutated.
                Object[] assignment = observed.clone();
                double[] jointProbs = new double[nWords];
                int bestKey = 0;
                for (int key = 0; key < nWords; key++) {
                    for (int k = 0; k < words[key].length; k++) {
                        assignment[ancestors[k]] = words[key][k];
                    }
                    jointProbs[key] = jointProb(rtree, model, assignment);
                    if (jointProbs[key] > jointProbs[bestKey]) {
                        bestKey = key;
                    }
                }
                timer.stop("Naive_ML_Leaves=" + nLeaves);
                for (int k = 0; k < words[bestKey].length; k++) {
                    assignment[ancestors[k]] = words[bestKey][k];
                }
                TreeInstance naive = new TreeInstance(rtree, assignment);
                assertAgreesWithNaive(plainResult, naive, ancestors, alphabet, jointProbs, bestKey);

                // (2) Leaves observed as lower-case symbols. Row k of the master CPT puts 0.997 on the k-th
                // lower-case symbol (assumes the model domain is ordered A, C, G, T).
                String[] lowerSymbols = {"a", "c", "g", "t"};
                Enumerable lowerAlphabet = new Enumerable(lowerSymbols);
                PhyloBN cptModel = PhyloBN.withCPTs(rtree, model, lowerSymbols, 1.0d, true, seed + 1);
                cptModel.setMasterCPT(domain.getValues(), new EnumDistrib[]{
                        new EnumDistrib(lowerAlphabet, 0.997d, 0.001d, 0.001d, 0.001d),
                        new EnumDistrib(lowerAlphabet, 0.001d, 0.997d, 0.001d, 0.001d),
                        new EnumDistrib(lowerAlphabet, 0.001d, 0.001d, 0.997d, 0.001d),
                        new EnumDistrib(lowerAlphabet, 0.001d, 0.001d, 0.001d, 0.997d)});
                Object[] observedLower = new Object[rtree.getSize()];
                for (int k = 0; k < leaves.length; k++) {
                    // Locale.ROOT: case mapping must not depend on the JVM's default locale.
                    observedLower[leaves[k]] = leafStates[k].toString().toLowerCase(java.util.Locale.ROOT);
                }
                TreeInstance cptInput = new TreeInstance(rtree, observedLower);
                timer.start("BNKit_ML_CPT_Leaves=" + nLeaves);
                MaxLhoodJoint.Inference viaCpt = new MaxLhoodJoint(cptModel).infer(cptInput);
                timer.stop("BNKit_ML_CPT_Leaves=" + nLeaves);
                assertAgreesWithNaive(viaCpt.getTreeInstance(), naive, ancestors, alphabet, jointProbs, bestKey);

                // (3) Leaves observed as real values sampled from the Gaussian of their true state.
                PhyloBN gdtModel = PhyloBN.withGDTs(rtree, model, 1.0d, true, seed + 1);
                GaussianDistrib[] emissions = {new GaussianDistrib(0.0d, 0.01d), new GaussianDistrib(1.0d, 0.01d), new GaussianDistrib(2.0d, 0.01d), new GaussianDistrib(3.0d, 0.01d)};
                gdtModel.setMasterGDT(domain.getValues(), emissions);
                Object[] observedReal = new Object[rtree.getSize()];
                for (int k = 0; k < leaves.length; k++) {
                    observedReal[leaves[k]] = emissions[domain.getIndex(leafStates[k])].sample();
                }
                TreeInstance gdtInput = new TreeInstance(rtree, observedReal);
                timer.start("BNKit_ML_GDT_Leaves=" + nLeaves);
                MaxLhoodJoint.Inference viaGdt = new MaxLhoodJoint(gdtModel).infer(gdtInput);
                timer.stop("BNKit_ML_GDT_Leaves=" + nLeaves);
                assertAgreesWithNaive(viaGdt.getTreeInstance(), naive, ancestors, alphabet, jointProbs, bestKey);
            }
            timer.stop("Leaves=" + nLeaves);
        }
        timer.report(true);
    }

    /**
     * Multiplies, over every node of {@code t}, the model probability of that node's state in {@code assignment}.
     * Nodes without a parent contribute {@code model.getProb(state)}; every other node contributes the
     * probability of its state given its parent's state, using the transition probabilities for its branch
     * distance.
     */
    private static double jointProb(Tree t, SubstModel model, Object[] assignment) {
        double p = 1.0d;
        Iterator<Integer> nodes = t.iterator();
        while (nodes.hasNext()) {
            int node = nodes.next();
            int parent = t.getParent(node);
            if (parent < 0) {
                p *= model.getProb(assignment[node]);
            } else {
                p *= model.getProb(assignment[node], assignment[parent], model.getProbs(t.getDistance(node)));
            }
        }
        return p;
    }

    /** Returns the states {@code ti} holds at each of the given node indices, in the same order. */
    private static Object[] statesAt(TreeInstance ti, int[] nodes) {
        Object[] states = new Object[nodes.length];
        for (int k = 0; k < nodes.length; k++) {
            states[k] = ti.getInstance(nodes[k]);
        }
        return states;
    }

    /**
     * Asserts that {@code inferred} matches the exhaustive-search optimum {@code naive} at every ancestor.
     * On disagreement, both reconstructions and their negative log joint probabilities are printed first,
     * to aid diagnosis.
     */
    private static void assertAgreesWithNaive(TreeInstance inferred, TreeInstance naive, int[] ancestors,
                                              Enumerable alphabet, double[] jointProbs, int bestKey) {
        int inferredKey = alphabet.getKey4Word(statesAt(inferred, ancestors));
        if (inferredKey != bestKey) {
            System.out.println(inferred);
            System.out.println(" [" + inferredKey + "] -LogL = " + (-Math.log(jointProbs[inferredKey])));
            System.out.println(naive);
            System.out.println(" [" + bestKey + "] -LogL = " + (-Math.log(jointProbs[bestKey])));
        }
        for (int anc : ancestors) {
            Assertions.assertEquals(naive.getInstance(anc), inferred.getInstance(anc));
        }
    }
}
