package asr;

import bn.ctmc.SubstModel;
import bn.ctmc.matrix.JC;
import bn.prob.EnumDistrib;
import bn.prob.MixtureDistrib;
import dat.Enumerable;
import dat.phylo.IdxTree;
import dat.phylo.PhyloBN;
import dat.phylo.Tree;
import dat.phylo.TreeInstance;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Random;
import json.JSONObject;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import util.MilliTimer;

/**
 * Tests for {@link MaxLhoodMarginal}.
 *
 * <p>The getDecoration* tests decorate the internal nodes of a fixed 19-node tree (two clades under
 * the root, rooted at nodes 1 and 8) after training a {@link PhyloBN} with EM, either on real-valued
 * leaf data (GDT nodes) or on categorical leaf data (CPT nodes). {@link #marginal1()} compares the
 * root marginal computed by {@link MaxLhoodMarginal} with a brute-force enumeration over all
 * ancestral assignments on random trees.
 */
class MaxLhoodMarginalTest {

    /**
     * Index of the root's second child in {@link #tree}. Per the tree's "Parents" array, nodes 1..7
     * descend from node 1 and nodes 8..18 from node 8, so a non-root index below this value lies in
     * the first clade.
     */
    private static final int SECOND_CLADE_ROOT = 8;

    /** Number of samples drawn from each node's mixture decoration in getDecoration1b. */
    private static final int SAMPLES_PER_NODE = 30;

    PhyloBN pbn_gdt;
    PhyloBN pbn_cpt;
    PhyloBN pbn_gdt_ext;
    PhyloBN pbn_cpt_ext;

    // Leaves S001..S004 sit under node 1; leaves S005..S010 sit under node 8.
    IdxTree tree = IdxTree.fromJSON(new JSONObject("{\"Parents\":[-1,0,1,2,2,4,4,1,0,8,9,9,11,11,8,14,14,16,16],\"Labels\":[\"0\",\"1\",\"2\",\"S001\",\"3\",\"S002\",\"S003\",\"S004\",\"4\",\"5\",\"S005\",\"6\",\"S006\",\"S007\",\"7\",\"S008\",\"8\",\"S009\",\"S010\"],\"Distances\":[0,0.14,0.03,0.14,0.08,0.16,0.10,0.12,0.06,0.06,0.28,0.13,0.12,0.14,0.11,0.20,0.07,0.12,0.19],\"Branchpoints\":19}\n"));

    /** Leaf labels; the values in every data row below are listed in this order. */
    String[] headers = {"S009", "S005", "S002", "S006", "S003", "S001", "S008", "S010", "S004", "S007"};

    /** Real-valued training data: lower values at S001..S004 (first clade), higher elsewhere. */
    Double[][] rows1 = {{3.63, 3.81, 2.89, 3.81, 2.54, 2.76, 3.79, 3.7, 1.94, 3.97}};

    /** rows1 rounded to one decimal; used as the observed leaf values when decorating. */
    Double[][] rows1tst = {{3.6, 3.8, 2.9, 3.8, 2.5, 2.8, 3.8, 3.7, 1.9, 4.0}};

    /** Categorical training data: "a" at S001..S004 (first clade), "b" or "c" elsewhere. */
    String[][] rows2 = {{"b", "c", "a", "c", "a", "a", "b", "b", "a", "c"}};

    /** Observed categorical leaf values used when decorating (same as rows2). */
    String[][] rows2tst = {{"b", "c", "a", "c", "a", "a", "b", "b", "a", "c"}};

    String[] states = {"A", "B"};
    String[] states3 = {"A", "B", "C"};
    double GAMMA = 1.0d;
    SubstModel model2 = new JC(this.GAMMA, this.states);
    SubstModel model3 = new JC(this.GAMMA, this.states3);

    /** Builds pbn_gdt over model2 and trains it with EM on rows1 (last argument 1L, presumably a seed). */
    void setup_gdt() {
        pbn_gdt = PhyloBN.withGDTs(tree, model2, 1.0d);
        pbn_gdt.trainEM(headers, rows1, 1L);
    }

    /** Like setup_gdt, but via the extended withGDTs overload (extra args false, 1L), into pbn_gdt_ext. */
    void setup_gdt_ext() {
        pbn_gdt_ext = PhyloBN.withGDTs(tree, model2, 1.0d, false, 1L);
        pbn_gdt_ext.trainEM(headers, rows1, 1L);
    }

    /** Builds pbn_cpt over model2 with the distinct values of rows2 as leaf values, then trains it on rows2. */
    void setup_cpt() {
        pbn_cpt = PhyloBN.withCPTs(tree, model2, distinctValues(rows2[0]), 1.0d);
        pbn_cpt.trainEM(headers, rows2, 1L);
    }

    /** Like setup_cpt, but over the three-state model3 and the extended withCPTs overload, into pbn_cpt_ext. */
    void setup_cpt_ext() {
        pbn_cpt_ext = PhyloBN.withCPTs(tree, model3, distinctValues(rows2[0]), 1.0d, false, 1L);
        pbn_cpt_ext.trainEM(headers, rows2, 1L);
    }

    /**
     * Returns the distinct values of {@code row}, in {@link HashSet} iteration order (the order the
     * original inline code produced, which the trained CPTs depend on).
     */
    private static String[] distinctValues(String[] row) {
        HashSet<String> distinct = new HashSet<>();
        for (String value : row) {
            distinct.add(value);
        }
        return distinct.toArray(new String[distinct.size()]);
    }

    /**
     * GDT model: the most probable state at each first-clade internal node must differ from the
     * root's, and at each second-clade internal node must equal it.
     */
    @Test
    void getDecoration1a() {
        setup_gdt();
        int rootMax = 0;
        for (Iterator<Integer> it = tree.iterator(); it.hasNext(); ) {
            int idx = it.next();
            if (!tree.isParent(idx)) {
                continue;
            }
            Object label = tree.getLabel(idx);
            MaxLhoodMarginal inference = new MaxLhoodMarginal(idx, pbn_gdt);
            inference.decorate(tree.getInstance(headers, rows1tst[0]));
            EnumDistrib marginal = (EnumDistrib) inference.getDecoration(idx);
            System.out.println(idx + "\t" + label + "\t" + marginal);
            int maxIdx = marginal.getMaxIndex();
            if (idx == 0) {
                rootMax = maxIdx; // assumes the iterator visits the root (node 0) first, as the original did
            } else if (idx < SECOND_CLADE_ROOT) {
                Assertions.assertNotEquals(rootMax, maxIdx, "node " + idx + " (first clade) should differ from root");
            } else {
                Assertions.assertEquals(rootMax, maxIdx, "node " + idx + " (second clade) should match root");
            }
        }
    }

    /**
     * Extended GDT model: averages samples from each internal node's mixture decoration and expects
     * mean(first clade) &lt; root &lt; mean(second clade). NOTE(review): the randomness source of
     * {@code MixtureDistrib.sample()} is not visible here, so this check may be stochastic.
     */
    @Test
    void getDecoration1b() {
        setup_gdt_ext();
        double rootMean = 0.0d;
        double clade1Sum = 0.0d;
        int clade1Nodes = 0;
        double clade2Sum = 0.0d;
        int clade2Nodes = 0;
        for (Iterator<Integer> it = tree.iterator(); it.hasNext(); ) {
            int idx = it.next();
            if (!tree.isParent(idx)) {
                continue;
            }
            Object label = tree.getLabel(idx);
            MaxLhoodMarginal inference = new MaxLhoodMarginal(idx, pbn_gdt_ext);
            inference.decorate(tree.getInstance(headers, rows1tst[0]));
            MixtureDistrib mixture = (MixtureDistrib) inference.getDecoration(idx);
            System.out.println(idx + "\t" + label + "\t" + mixture);
            if (idx == 0) {
                rootMean = addSampleMean(rootMean, mixture);
            } else if (idx < SECOND_CLADE_ROOT) {
                clade1Nodes++;
                clade1Sum = addSampleMean(clade1Sum, mixture);
            } else {
                clade2Nodes++;
                clade2Sum = addSampleMean(clade2Sum, mixture);
            }
        }
        double clade1Mean = clade1Sum / clade1Nodes;
        double clade2Mean = clade2Sum / clade2Nodes;
        Assertions.assertTrue(clade1Mean < rootMean && rootMean < clade2Mean,
                "expected clade1 " + clade1Mean + " < root " + rootMean + " < clade2 " + clade2Mean);
    }

    /**
     * Adds the mean of {@link #SAMPLES_PER_NODE} samples of {@code distrib} to {@code acc}, adding
     * each sample divided by the sample count in turn, and returns the new total.
     */
    private static double addSampleMean(double acc, MixtureDistrib distrib) {
        for (int s = 0; s < SAMPLES_PER_NODE; s++) {
            acc += ((Double) distrib.sample()) / SAMPLES_PER_NODE;
        }
        return acc;
    }

    /** CPT model over two latent states: same clade/root expectation as getDecoration1a. */
    @Test
    void getDecoration2a() {
        setup_cpt();
        int rootMax = 0;
        for (Iterator<Integer> it = tree.iterator(); it.hasNext(); ) {
            int idx = it.next();
            if (!tree.isParent(idx)) {
                continue;
            }
            Object label = tree.getLabel(idx);
            MaxLhoodMarginal inference = new MaxLhoodMarginal(idx, pbn_cpt);
            inference.decorate(tree.getInstance(headers, rows2tst[0]));
            EnumDistrib marginal = (EnumDistrib) inference.getDecoration(idx);
            System.out.println(idx + "\t" + label + "\t" + marginal);
            int maxIdx = marginal.getMaxIndex();
            if (idx == 0) {
                rootMax = maxIdx; // assumes the iterator visits the root (node 0) first, as the original did
            } else if (idx < SECOND_CLADE_ROOT) {
                Assertions.assertNotEquals(rootMax, maxIdx, "node " + idx + " (first clade) should differ from root");
            } else {
                Assertions.assertEquals(rootMax, maxIdx, "node " + idx + " (second clade) should match root");
            }
        }
    }

    /**
     * CPT model over three latent states: counts how often each state is most probable at the
     * non-root internal nodes and expects |count0 - count1| and |count1 - count2| to be at most 1.
     */
    @Test
    void getDecoration2b() {
        setup_cpt_ext();
        int[] maxCounts = new int[states3.length];
        for (Iterator<Integer> it = tree.iterator(); it.hasNext(); ) {
            int idx = it.next();
            if (!tree.isParent(idx)) {
                continue;
            }
            Object label = tree.getLabel(idx);
            MaxLhoodMarginal inference = new MaxLhoodMarginal(idx, pbn_cpt_ext);
            inference.decorate(tree.getInstance(headers, rows2tst[0]));
            EnumDistrib marginal = (EnumDistrib) inference.getDecoration(idx);
            System.out.println(idx + "\t" + label + "\t" + marginal);
            if (idx != 0) {
                maxCounts[marginal.getMaxIndex()]++;
            }
        }
        Assertions.assertTrue(Math.abs(maxCounts[0] - maxCounts[1]) <= 1 && Math.abs(maxCounts[1] - maxCounts[2]) <= 1,
                "unbalanced max-state counts " + maxCounts[0] + "/" + maxCounts[1] + "/" + maxCounts[2]);
    }

    /**
     * On 100 random trees for each leaf count 3..8, with random leaf states, checks that the root
     * marginal from {@link MaxLhoodMarginal} matches a brute-force enumeration (tolerance 0.01),
     * timing both with {@link MilliTimer}.
     */
    @Test
    void marginal1() {
        SubstModel model = SubstModel.createModel("JC");
        Enumerable alphabet = new Enumerable(model.getDomain().getValues());
        MilliTimer timer = new MilliTimer();
        for (int nLeaves = 3; nLeaves < 9; nLeaves++) {
            timer.start("Leaves=" + nLeaves);
            for (int seed = 0; seed < 100; seed++) {
                Tree randTree = Tree.Random(nLeaves, seed, 1.1d, 1.0d / 0.2d, 2, 2);
                randTree.adjustDistances(1.0d);
                int[] ancestors = randTree.getAncestors();
                int[] leaves = randTree.getLeaves();
                Random rng = new Random(seed);
                // Values of all nodes, indexed by node: leaves are fixed here, ancestor slots start null.
                Object[] nodeValues = new Object[randTree.getSize()];
                for (int leaf : leaves) {
                    nodeValues[leaf] = alphabet.get(rng.nextInt(alphabet.size()));
                }

                timer.start("BNKit_ML_Leaves=" + nLeaves);
                MaxLhoodMarginal inference = new MaxLhoodMarginal(0, randTree, model);
                inference.decorate(new TreeInstance(randTree, nodeValues));
                EnumDistrib bnkitMarginal = (EnumDistrib) inference.getDecoration(0);
                timer.stop("BNKit_ML_Leaves=" + nLeaves);

                timer.start("Naive_ML_Leaves=" + nLeaves);
                EnumDistrib naiveMarginal = bruteForceRootMarginal(randTree, model, alphabet, ancestors, nodeValues);
                timer.stop("Naive_ML_Leaves=" + nLeaves);

                Assertions.assertArrayEquals(naiveMarginal.get(), bnkitMarginal.get(), 0.01d);
            }
            timer.stop("Leaves=" + nLeaves);
        }
        timer.report(true);
    }

    /**
     * Computes the root marginal by brute force. Enumerates every assignment of {@code alphabet}
     * values to {@code ancestors}, multiplies the model probabilities over all nodes (the prior
     * {@code getProb(value)} where the parent index is negative, otherwise the transition
     * probability over the branch distance), and sums the products by the value assigned to
     * {@code ancestors[0]} (assumed to be the root, node 0).
     * Overwrites the ancestor entries of {@code nodeValues}.
     */
    private static EnumDistrib bruteForceRootMarginal(Tree t, SubstModel model, Enumerable alphabet,
                                                      int[] ancestors, Object[] nodeValues) {
        int nAssignments = (int) Math.pow(alphabet.size(), ancestors.length);
        double[] rootTotals = new double[alphabet.size()];
        for (int key = 0; key < nAssignments; key++) {
            // One ancestral assignment: word[a] is the value of node ancestors[a].
            Object[] word = alphabet.getWord4Key(key, ancestors.length);
            for (int a = 0; a < word.length; a++) {
                nodeValues[ancestors[a]] = word[a];
            }
            double joint = 1.0d;
            for (Iterator<Integer> it = t.iterator(); it.hasNext(); ) {
                int node = it.next();
                int parent = t.getParent(node);
                if (parent < 0) {
                    joint *= model.getProb(nodeValues[node]);
                } else {
                    joint *= model.getProb(nodeValues[node], nodeValues[parent], model.getProbs(t.getDistance(node)));
                }
            }
            rootTotals[alphabet.getIndex(word[0])] += joint;
        }
        return new EnumDistrib(alphabet, rootTotals);
    }
}
