package bn.factor;

import bn.Predef;
import bn.factor.AbstractFactor;
import bn.factor.Factorize;
import bn.prob.GaussianDistrib;
import dat.Continuous;
import dat.EnumVariable;
import dat.Variable;
import java.io.PrintStream;
import java.util.HashSet;
import java.util.Random;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:bn/factor/FactorizeTest.class */
/**
 * Randomised tests and micro-benchmarks for factor products computed by {@link Factorize}.
 *
 * <p>The helpers build seeded pools of variables and dense factors, multiply the factors
 * either left-to-right or via {@link Factorize#getProductTree}, spot-check every pairwise
 * product, and print one tab-separated report line per product to {@code System.out}.
 */
class FactorizeTest {
    /** Option for {@link #productPool}: multiply the factors one after another, in array order. */
    protected static final int POOL_OPTION_LINEAR = 0;
    /** Option for {@link #productPool}: multiply the factors following {@link Factorize#getProductTree}. */
    protected static final int POOL_OPTION_TREE = 1;

    /** Number of seeded spot checks run against every pairwise product. */
    private static final int INTEGRITY_CHECKS = 20;

    FactorizeTest() {
    }

    /**
     * Creates a pool of randomly typed variables.
     *
     * @param j seed of the random generator, so the same seed yields the same sequence of types
     * @param i number of variables to create
     * @return an array of {@code i} freshly created variables
     */
    protected static Variable[] getVariablePool(long j, int i) {
        Random random = new Random(j);
        Variable[] pool = new Variable[i];
        for (int slot = 0; slot < pool.length; slot++) {
            pool[slot] = randomVariable(random);
        }
        return pool;
    }

    /** Draws one variable of a random kind (Boolean, nominal, nucleic acid, real or number). */
    private static Variable randomVariable(Random random) {
        // Raw Variable on purpose: the five branches produce different variable types.
        switch (random.nextInt(5)) {
            case 0:
                return Predef.Boolean();
            case 1:
                return Predef.Nominal("a", "b", "c");
            case 2:
                return Predef.NucleicAcid();
            case 3:
                return Predef.Real();
            case 4:
                return Predef.Number(random.nextInt(8) + 2);
            default:
                // nextInt(5) only yields 0..4; kept so a future change of the bound fails loudly.
                throw new IllegalStateException("Unexpected variable kind");
        }
    }

    /**
     * Picks a random subset of distinct variables from a pool.
     *
     * @param j seed of the random generator
     * @param variableArr pool to sample from
     * @param i requested subset size; capped by the number of distinct variables in the pool
     *     and floored at zero
     * @return the selected variables, in the iteration order of the backing hash set
     */
    protected static Variable[] getSubset(long j, Variable[] variableArr, int i) {
        Random random = new Random(j);
        // Count distinct entries first: asking for more distinct variables than the pool
        // holds would otherwise keep the sampling loop below spinning forever.
        HashSet<Variable> distinct = new HashSet<>();
        for (Variable candidate : variableArr) {
            distinct.add(candidate);
        }
        int target = Math.max(0, Math.min(distinct.size(), i));
        HashSet<Variable> chosen = new HashSet<>();
        while (chosen.size() < target) {
            chosen.add(variableArr[random.nextInt(variableArr.length)]);
        }
        return chosen.toArray(new Variable[target]);
    }

    /**
     * Creates a pool of dense factors over random subsets of the given variables and
     * populates them with random values (and Gaussian densities where the factor reports
     * {@code isJDF()}).
     *
     * @param j seed of the random generator
     * @param variableArr variables the factors may be defined over
     * @param i scale applied to the Gaussian draws that decide the pool size and the
     *     upper bound on variables per factor
     * @return the generated factors; the array may be empty
     */
    protected static AbstractFactor[] getFactorPool(long j, Variable[] variableArr, int i) {
        Random random = new Random(j);
        // Two independent draws, in this order: variable-count bound first, pool size second.
        int maxVars = Math.abs(((int) (random.nextGaussian() * i)) + 1);
        DenseFactor[] pool = new DenseFactor[Math.abs(((int) (random.nextGaussian() * i)) + 1)];
        for (int f = 0; f < pool.length; f++) {
            int nVars = random.nextInt(Math.max(1, maxVars));
            if (nVars > 0) {
                pool[f] = new DenseFactor(getSubset(random.nextInt(), variableArr, nVars));
            } else {
                pool[f] = new DenseFactor();
            }
            DenseFactor factor = pool[f];
            int size = factor.getSize();
            AbstractFactor.FactorFiller filler = factor.getFiller();
            for (int attempt = 0; attempt < size; attempt++) {
                if (factor.getSize() == 1) {
                    // Single-entry factor: set the value directly, without an index.
                    factor.setValue(Math.abs(random.nextGaussian()) / size);
                    if (factor.isJDF()) {
                        for (Variable nonEnum : factor.getNonEnumVars()) {
                            factor.setDistrib(nonEnum, randomGaussian(random));
                        }
                    }
                } else {
                    // Entries are picked at random, so some may be hit twice and others never.
                    int index = random.nextInt(size);
                    filler.setValue(index, Math.abs(random.nextGaussian()) / size);
                    if (factor.isJDF()) {
                        for (Variable nonEnum : factor.getNonEnumVars()) {
                            factor.setDistrib(index, nonEnum, randomGaussian(random));
                        }
                    }
                }
            }
            factor.setValuesByFiller(filler);
        }
        return pool;
    }

    /**
     * Draws a random Gaussian density. The four random draws happen in a fixed order so
     * that seeded runs stay reproducible.
     */
    private static GaussianDistrib randomGaussian(Random random) {
        // NOTE(review): presumably (mean, variance) - confirm against GaussianDistrib.
        double center = random.nextGaussian() * random.nextInt(100);
        double spread = Math.abs(random.nextGaussian() * (random.nextInt(10) + 1));
        return new GaussianDistrib(center, spread);
    }

    /**
     * Resolves the product of a product tree recursively, checking and reporting every
     * pairwise product on the way. Nodes that already hold a factor are returned as-is;
     * computed products are stored back on their node.
     *
     * @param factorProductTree root of the (sub)tree to resolve
     * @return the factor held by, or computed for, the given node
     */
    protected static AbstractFactor getProductBenchmarked(Factorize.FactorProductTree factorProductTree) {
        AbstractFactor known = factorProductTree.getFactor();
        if (known != null) {
            return known;
        }
        AbstractFactor left = getProductBenchmarked(factorProductTree.x);
        AbstractFactor right = getProductBenchmarked(factorProductTree.y);
        AbstractFactor product = multiplyCheckedAndReported(left, right);
        factorProductTree.setFactor(product);
        return product;
    }

    /**
     * Multiplies the factors left-to-right in array order, checking and reporting every
     * pairwise product.
     *
     * @param abstractFactorArr factors to multiply
     * @return the overall product, the single factor if there is only one, or {@code null}
     *     for an empty array
     */
    protected static AbstractFactor getProductBenchmarked(AbstractFactor[] abstractFactorArr) {
        if (abstractFactorArr.length == 0) {
            return null;
        }
        AbstractFactor running = abstractFactorArr[0];
        for (int next = 1; next < abstractFactorArr.length; next++) {
            running = multiplyCheckedAndReported(running, abstractFactorArr[next]);
        }
        return running;
    }

    /**
     * Times a single product, spot-checks it with {@link #testProductIntegrity} and prints
     * one tab-separated report line: larger and smaller enumerable-variable count, overlap,
     * overlap relative to the smaller count, elapsed time, and the two complexity figures.
     */
    private static AbstractFactor multiplyCheckedAndReported(AbstractFactor x, AbstractFactor y) {
        long started = System.nanoTime();
        AbstractFactor product = Factorize.getProduct(x, y);
        long finished = System.nanoTime();
        for (int seed = 0; seed < INTEGRITY_CHECKS; seed++) {
            Assertions.assertTrue(testProductIntegrity(seed, x, y, product));
        }
        int overlap = Factorize.getOverlap(x, y);
        int fewerEVars = Math.min(x.nEVars, y.nEVars);
        int moreEVars = Math.max(x.nEVars, y.nEVars);
        // Floating-point division: an int division would collapse the ratio to 0 or 1.
        double containment = fewerEVars == 0 ? 0.0d : (double) overlap / fewerEVars;
        // Elapsed nanoseconds divided by 100000, as in the other timing output of this class.
        double elapsed = (finished - started) / 100000.0d;
        System.out.println(moreEVars + "\t" + fewerEVars + "\t" + overlap + "\t" + containment + "\t" + elapsed + "\t" + Factorize.getComplexity(x, y, false) + "\t" + Factorize.getComplexity(x, y, true));
        return product;
    }

    /**
     * Multiplies all factors of a pool with the chosen strategy, prints the total elapsed
     * time and then exercises a randomly chosen margin operation on the result.
     *
     * @param abstractFactorArr factors to multiply
     * @param i {@link #POOL_OPTION_LINEAR} or {@link #POOL_OPTION_TREE}
     * @return the product, the single factor if the pool has only one, or {@code null} for
     *     an empty pool
     * @throws IllegalArgumentException if {@code i} is not one of the two pool options
     */
    protected static AbstractFactor productPool(AbstractFactor[] abstractFactorArr, int i) {
        if (abstractFactorArr.length == 0) {
            return null;
        }
        if (abstractFactorArr.length == 1) {
            return abstractFactorArr[0];
        }
        AbstractFactor product;
        long started = System.nanoTime();
        switch (i) {
            case POOL_OPTION_LINEAR:
                product = getProductBenchmarked(abstractFactorArr);
                break;
            case POOL_OPTION_TREE:
                product = getProductBenchmarked(Factorize.getProductTree(abstractFactorArr));
                break;
            default:
                throw new IllegalArgumentException("Unknown pool option: " + i);
        }
        long finished = System.nanoTime();
        System.out.println("\t\t\t\t\t\t\t" + ((finished - started) / 100000.0d));
        if (product.nEVars <= 0) {
            return product;
        }
        // NOTE(review): seeded from the clock, so the margin exercised differs between runs.
        Random random = new Random(finished);
        int sampleSize = random.nextInt(product.nEVars);
        Variable[] sample = new Variable[sampleSize];
        for (int s = 0; s < sampleSize; s++) {
            // NOTE(review): draws only among the first sampleSize variables and may repeat
            // one - confirm whether the bound was meant to be nEVars.
            sample[s] = product.evars[random.nextInt(sampleSize)];
        }
        // The margin results are discarded; the calls only have to complete without throwing.
        if (random.nextBoolean()) {
            Factorize.getMargin(product, sample);
        } else {
            Factorize.getMaxMargin(product, sample);
        }
        return product;
    }

    /**
     * Spot-checks one randomly selected entry of a product: the entry must lie within 1%
     * of the product of the corresponding entries of the two operands.
     *
     * @param j seed selecting the entry of the product to check
     * @param abstractFactor first operand
     * @param abstractFactor2 second operand
     * @param abstractFactor3 product of the two operands
     * @return {@code true} if the selected entry is consistent; {@code false} if it is
     *     not, or if the product key holds a value that is invalid for its variable's domain
     */
    protected static boolean testProductIntegrity(long j, AbstractFactor abstractFactor, AbstractFactor abstractFactor2, AbstractFactor abstractFactor3) {
        int productIndex = new Random(j).nextInt(abstractFactor3.getSize());
        Object[] productKey = abstractFactor3.getSize() == 1 ? new Object[0] : abstractFactor3.getKey(productIndex);
        EnumVariable[] productVars = abstractFactor3.getEnumVars();
        EnumVariable[] firstVars = abstractFactor.getEnumVars();
        EnumVariable[] secondVars = abstractFactor2.getEnumVars();
        Object[] firstKey = new Object[abstractFactor.nEVars];
        Object[] secondKey = new Object[abstractFactor2.nEVars];
        // Project the product key onto the variables of each operand.
        for (int k = 0; k < productKey.length; k++) {
            if (!productVars[k].getDomain().isValid(productKey[k])) {
                return false;
            }
            int inFirst = lastIndexOf(firstVars, abstractFactor.nEVars, productVars[k]);
            int inSecond = lastIndexOf(secondVars, abstractFactor2.nEVars, productVars[k]);
            if (inSecond != -1) {
                secondKey[inSecond] = productKey[k];
            }
            if (inFirst != -1) {
                firstKey[inFirst] = productKey[k];
            }
        }
        double firstValue = firstKey.length != 0 ? abstractFactor.getValue(abstractFactor.getIndex(firstKey)) : abstractFactor.getValue();
        double secondValue = secondKey.length != 0 ? abstractFactor2.getValue(abstractFactor2.getIndex(secondKey)) : abstractFactor2.getValue();
        double productValue = productKey.length == 0 ? abstractFactor3.getValue() : abstractFactor3.getValue(productIndex);
        double expected = firstValue * secondValue;
        return expected <= productValue * 1.01d && expected >= productValue * 0.99d;
    }

    /** Returns the last position among the first {@code count} entries equal to {@code target}, or -1. */
    private static int lastIndexOf(EnumVariable[] vars, int count, EnumVariable target) {
        int found = -1;
        for (int pos = 0; pos < count; pos++) {
            if (vars[pos].equals(target)) {
                found = pos;
            }
        }
        return found;
    }

    /**
     * Prints, for 20 random pairs of variable lists drawn from the pool, the
     * cross-reference computed by {@link Factorize#getCrossref}: each variable followed by
     * its counterpart in the other list, or {@code -} when it has none.
     *
     * @param j seed of the random generator
     * @param variableArr pool to sample from; nothing is printed for an empty pool
     */
    protected static void testCrossRef(long j, Variable[] variableArr) {
        if (variableArr.length == 0) {
            return; // nothing to sample from
        }
        Random random = new Random(j);
        // Random.nextInt needs a positive bound; equals length - 1 whenever length >= 2.
        int sizeBound = Math.max(1, variableArr.length - 1);
        for (int round = 0; round < 20; round++) {
            Variable[] firstDraw = new Variable[random.nextInt(sizeBound) + 1];
            Variable[] secondDraw = new Variable[random.nextInt(sizeBound) + 1];
            for (int a = 0; a < firstDraw.length; a++) {
                firstDraw[a] = variableArr[random.nextInt(variableArr.length)];
            }
            for (int b = 0; b < secondDraw.length; b++) {
                secondDraw[b] = variableArr[random.nextInt(variableArr.length)];
            }
            Variable[] first = Factorize.getNonredundant(firstDraw);
            Variable[] second = Factorize.getNonredundant(secondDraw);
            int[] firstToSecond = new int[first.length];
            int[] secondToFirst = new int[second.length];
            Factorize.getCrossref(first, firstToSecond, second, secondToFirst);
            // Invert both index maps so each list can be printed next to its counterparts.
            Variable[] partnerOfFirst = new Variable[first.length];
            Variable[] partnerOfSecond = new Variable[second.length];
            for (int s = 0; s < second.length; s++) {
                if (secondToFirst[s] != -1) {
                    partnerOfFirst[secondToFirst[s]] = second[s];
                }
            }
            for (int f = 0; f < first.length; f++) {
                if (firstToSecond[f] != -1) {
                    partnerOfSecond[firstToSecond[f]] = first[f];
                }
            }
            printCrossRefRow(first, partnerOfFirst);
            printCrossRefRow(second, partnerOfSecond);
            System.out.println();
        }
    }

    /** Prints one {@code variable:partner} row, using {@code -} where there is no partner. */
    private static void printCrossRefRow(Variable[] vars, Variable[] partners) {
        for (int v = 0; v < vars.length; v++) {
            System.out.print(vars[v].toString() + ":" + (partners[v] == null ? "-\t" : partners[v].toString() + "\t"));
        }
        System.out.println();
    }

    /**
     * Verifies for 50 seeded factor pools that the linear and the tree-ordered product
     * have the same size and the same values, entry by entry, within 0.1%.
     */
    @Test
    void getProduct() {
        System.out.println("maxEV\tminEV\tOverlap\tContain\tProduct\tPJoin\tTime (ms)");
        for (long seed = 0; seed < 50; seed++) {
            AbstractFactor[] factors = getFactorPool(seed, getVariablePool(seed, 10), 8);
            AbstractFactor linear = productPool(factors, POOL_OPTION_LINEAR);
            AbstractFactor tree = productPool(factors, POOL_OPTION_TREE);
            if (linear == null && tree == null) {
                continue; // empty pool: there is nothing to compare
            }
            Assertions.assertNotNull(linear, "Linear product is missing");
            Assertions.assertNotNull(tree, "Tree product is missing");
            Assertions.assertEquals(linear.getSize(), tree.getSize(), "Invalid product size");
            if (linear.getSize() == 1) {
                double linearValue = linear.getValue();
                double treeValue = tree.getValue();
                if (differs(linearValue, treeValue)) {
                    Assertions.fail("Invalid atomic product: " + linearValue + " v " + treeValue);
                }
            } else {
                for (int index = 0; index < linear.getSize(); index++) {
                    double linearValue = linear.getValue(index);
                    double treeValue = tree.getValue(index);
                    if (differs(linearValue, treeValue)) {
                        Assertions.fail("Invalid product: " + linearValue + " v " + treeValue);
                    }
                }
            }
        }
    }

    /** Tells whether {@code actual} lies outside the band from 99.9% to 100.1% of {@code reference}. */
    private static boolean differs(double actual, double reference) {
        return actual < reference * 0.999d || actual > reference * 1.001d;
    }
}
