package bn.alg;

import bn.BNet;
import bn.BNode;
import bn.JDF;
import dat.EnumTable;
import dat.EnumVariable;
import dat.Variable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

/* JADX WARN: Classes with same name are omitted:
  input_file:target/classes/bn/alg/EM.class
 */
/* loaded from: input_file:bn/alg/EM.class */
public class EM extends LearningAlg {
    private Inference inf;
    public double EM_CONVERGENCE_CRITERION;
    public int EM_MAX_ROUNDS;
    public boolean EM_PRINT_STATUS;
    public int EM_THREAD_COUNT;
    public int EM_OPTION;

    /* JADX WARN: Classes with same name are omitted:
      input_file:target/classes/bn/alg/EM$EMRuntimeException.class
     */
    /* loaded from: input_file:bn/alg/EM$EMRuntimeException.class */
    public class EMRuntimeException extends RuntimeException {
        private static final long serialVersionUID = 1;

        public EMRuntimeException(String str) {
            super(str);
        }
    }

    /* JADX WARN: Classes with same name are omitted:
      input_file:target/classes/bn/alg/EM$EMc1.class
     */
    /* loaded from: input_file:bn/alg/EM$EMc1.class */
    public class EMc1 implements Runnable {
        private boolean init;
        private BNode node;
        private Variable.Assignment[] evidence;
        private int i;

        public EMc1(BNode bNode, Variable.Assignment[] assignmentArr, int i) {
            this.init = false;
            this.init = true;
            this.node = bNode;
            this.evidence = assignmentArr;
            this.i = i;
        }

        @Override // java.lang.Runnable
        public void run() {
            case3(this.node, this.evidence, this.i);
        }

        public void case3(BNode bNode, Variable.Assignment[] assignmentArr, int i) {
            if (bNode.isTrainable()) {
                ArrayList arrayList = new ArrayList();
                Object[] objArr = null;
                if (!bNode.isRoot()) {
                    objArr = EnumTable.getKey(bNode.getParents(), assignmentArr);
                    for (int i2 = 0; i2 < objArr.length; i2++) {
                        if (objArr[i2] == null) {
                            arrayList.add(bNode.getParents().get(i2));
                        }
                    }
                }
                Variable variable = bNode.getVariable();
                Object bNode2 = bNode.getInstance();
                if (bNode2 == null) {
                    arrayList.add(variable);
                }
                if (arrayList.size() <= 0) {
                    bNode.countInstance(objArr, bNode2);
                    return;
                }
                try {
                    Variable[] variableArr = new Variable[arrayList.size()];
                    arrayList.toArray(variableArr);
                    CGTable cGTable = (CGTable) EM.this.inf.infer(EM.this.inf.makeQuery(variableArr));
                    for (int i3 : cGTable.getIndices()) {
                        Object[] key = cGTable.getKey(i3);
                        double doubleValue = cGTable.getFactor(i3).doubleValue();
                        JDF jdf = cGTable.hasNonEnumVariables() ? cGTable.getJDF(i3) : null;
                        if (!bNode.isRoot()) {
                            Variable.Assignment[] array = Variable.Assignment.array(cGTable.getEnumVariables(), key);
                            EnumTable.overlay(objArr, EnumTable.getKey(bNode.getParents(), array));
                            if (bNode2 != null) {
                                bNode.countInstance(objArr, bNode2, Double.valueOf(doubleValue));
                            } else {
                                try {
                                    EnumVariable enumVariable = (EnumVariable) variable;
                                    int length = array.length;
                                    int i4 = 0;
                                    while (true) {
                                        if (i4 >= length) {
                                            break;
                                        }
                                        Variable.Assignment assignment = array[i4];
                                        if (assignment.var.equals(enumVariable)) {
                                            bNode.countInstance(objArr, assignment.val, Double.valueOf(doubleValue));
                                            break;
                                        }
                                        i4++;
                                    }
                                } catch (ClassCastException e) {
                                    bNode.countInstance(objArr, jdf.getDistrib(variable), Double.valueOf(doubleValue));
                                }
                            }
                        } else if (bNode2 != null) {
                            bNode.countInstance(null, bNode2, Double.valueOf(doubleValue));
                        } else {
                            try {
                                EnumVariable enumVariable2 = (EnumVariable) variable;
                                Variable.Assignment[] array2 = Variable.Assignment.array(cGTable.getEnumVariables(), key);
                                int length2 = array2.length;
                                int i5 = 0;
                                while (true) {
                                    if (i5 >= length2) {
                                        break;
                                    }
                                    Variable.Assignment assignment2 = array2[i5];
                                    if (assignment2.var.equals(enumVariable2)) {
                                        bNode.countInstance(null, assignment2.val, Double.valueOf(doubleValue));
                                        break;
                                    }
                                    i5++;
                                }
                            } catch (ClassCastException e2) {
                                throw new EMRuntimeException("Failed query for sample #" + (i + 1) + ": " + variable.getName() + " is a non-enumerable root node");
                            }
                        }
                    }
                } catch (RuntimeException e3) {
                    throw new EMRuntimeException("Failed query for sample #" + (i + 1) + " and node " + bNode.getName() + ": " + e3.getMessage());
                }
            }
        }
    }

    public EM(BNet bNet, Inference inference) {
        super(bNet);
        this.EM_CONVERGENCE_CRITERION = 5.0E-5d;
        this.EM_MAX_ROUNDS = 1000;
        this.EM_PRINT_STATUS = true;
        this.EM_THREAD_COUNT = 1;
        this.EM_OPTION = 1;
        this.inf = inference;
    }

    public EM(BNet bNet) {
        this(bNet, new VarElim());
    }

    public void setConvergeCrit(double d) {
        this.EM_CONVERGENCE_CRITERION = d > 0.0d ? d : 5.0E-5d;
    }

    public void setPrintStatus(Boolean bool) {
        this.EM_PRINT_STATUS = bool.booleanValue();
    }

    public void setThreadCount(int i) {
        this.EM_THREAD_COUNT = i;
    }

    public void setEMOption(int i) {
        this.EM_OPTION = i == 1 ? i : i == 2 ? 2 : 3;
    }

    public void setMaxRounds(int i) {
        this.EM_MAX_ROUNDS = i >= 1 ? i : 1;
    }

    @Override // bn.alg.LearningAlg
    public void train(Object[][] objArr, Variable[] variableArr, long j) {
        int length = objArr.length;
        for (BNode bNode : this.f3bn.getNodes()) {
            if (bNode.isTrainable()) {
                bNode.randomize(j);
            }
        }
        double[] dArr = new double[5];
        dArr[0] = -999999.0d;
        dArr[1] = -999999.0d;
        dArr[2] = -999999.0d;
        dArr[3] = -999999.0d;
        dArr[4] = -999999.0d;
        int i = 0;
        boolean z = false;
        while (i < this.EM_MAX_ROUNDS && !z) {
            i++;
            double d = 0.0d;
            Map synchronizedMap = Collections.synchronizedMap(new HashMap());
            int i2 = 0;
            while (i2 < objArr.length) {
                for (int i3 = 0; i3 < variableArr.length; i3++) {
                    BNode node = this.f3bn.getNode(variableArr[i3]);
                    if (node == null) {
                        throw new EMRuntimeException("Variable \"" + variableArr[i2].getName() + "\" is not part of Bayesian network");
                    }
                    if (objArr[i2][i3] != null) {
                        node.setInstance(objArr[i2][i3]);
                        synchronizedMap.put(node, null);
                        List<EnumVariable> parents = node.getParents();
                        if (parents != null) {
                            Iterator<EnumVariable> it = parents.iterator();
                            while (it.hasNext()) {
                                synchronizedMap.put(this.f3bn.getNode(it.next()), null);
                            }
                        }
                        Set<String> childrenNames = this.f3bn.getChildrenNames(node);
                        if (childrenNames != null) {
                            Iterator<String> it2 = childrenNames.iterator();
                            while (it2.hasNext()) {
                                synchronizedMap.put(this.f3bn.getNode(it2.next()), null);
                            }
                        }
                    } else {
                        node.resetInstance();
                    }
                }
                this.inf.instantiate(this.f3bn);
                Variable.Assignment[] array = Variable.Assignment.array(variableArr, objArr[i2]);
                switch (this.EM_OPTION) {
                    case 1:
                        if (i2 == 11) {
                            i2 = 11;
                        }
                        for (BNode bNode2 : synchronizedMap.keySet()) {
                            if (bNode2.isTrainable()) {
                                ArrayList arrayList = new ArrayList();
                                Object[] objArr2 = null;
                                if (!bNode2.isRoot()) {
                                    objArr2 = EnumTable.getKey(bNode2.getParents(), array);
                                    for (int i4 = 0; i4 < objArr2.length; i4++) {
                                        if (objArr2[i4] == null) {
                                            arrayList.add(bNode2.getParents().get(i4));
                                        }
                                    }
                                }
                                Variable variable = bNode2.getVariable();
                                Object bNode3 = bNode2.getInstance();
                                if (bNode3 == null) {
                                    arrayList.add(variable);
                                }
                                if (arrayList.size() > 0) {
                                    try {
                                        Variable[] variableArr2 = new Variable[arrayList.size()];
                                        arrayList.toArray(variableArr2);
                                        CGTable cGTable = (CGTable) this.inf.infer(this.inf.makeQuery(variableArr2));
                                        for (int i5 : cGTable.getIndices()) {
                                            Object[] key = cGTable.getKey(i5);
                                            double doubleValue = cGTable.getFactor(i5).doubleValue();
                                            if (doubleValue != 0.0d && !Double.isNaN(doubleValue)) {
                                                JDF jdf = cGTable.hasNonEnumVariables() ? cGTable.getJDF(i5) : null;
                                                if (!bNode2.isRoot()) {
                                                    Variable.Assignment[] array2 = Variable.Assignment.array(cGTable.getEnumVariables(), key);
                                                    EnumTable.overlay(objArr2, EnumTable.getKey(bNode2.getParents(), array2));
                                                    if (bNode3 != null) {
                                                        bNode2.countInstance(objArr2, bNode3, Double.valueOf(doubleValue));
                                                    } else {
                                                        try {
                                                            EnumVariable enumVariable = (EnumVariable) variable;
                                                            int length2 = array2.length;
                                                            int i6 = 0;
                                                            while (true) {
                                                                if (i6 < length2) {
                                                                    Variable.Assignment assignment = array2[i6];
                                                                    if (assignment.var.equals(enumVariable)) {
                                                                        bNode2.countInstance(objArr2, assignment.val, Double.valueOf(doubleValue));
                                                                    } else {
                                                                        i6++;
                                                                    }
                                                                }
                                                            }
                                                        } catch (ClassCastException e) {
                                                            bNode2.countInstance(objArr2, jdf.getDistrib(variable), Double.valueOf(doubleValue));
                                                        }
                                                    }
                                                } else if (bNode3 != null) {
                                                    bNode2.countInstance(null, bNode3, Double.valueOf(doubleValue));
                                                } else {
                                                    try {
                                                        EnumVariable enumVariable2 = (EnumVariable) variable;
                                                        Variable.Assignment[] array3 = Variable.Assignment.array(cGTable.getEnumVariables(), key);
                                                        int length3 = array3.length;
                                                        int i7 = 0;
                                                        while (true) {
                                                            if (i7 < length3) {
                                                                Variable.Assignment assignment2 = array3[i7];
                                                                if (assignment2.var.equals(enumVariable2)) {
                                                                    bNode2.countInstance(null, assignment2.val, Double.valueOf(doubleValue));
                                                                } else {
                                                                    i7++;
                                                                }
                                                            }
                                                        }
                                                    } catch (ClassCastException e2) {
                                                        throw new EMRuntimeException("Failed query for sample #" + (i2 + 1) + ": " + variable.getName() + " is a non-enumerable root node");
                                                    }
                                                }
                                            }
                                        }
                                    } catch (RuntimeException e3) {
                                        throw new EMRuntimeException("Failed query for sample #" + (i2 + 1) + " and node " + bNode2.getName() + ": " + e3.getLocalizedMessage());
                                    }
                                } else {
                                    bNode2.countInstance(objArr2, bNode3);
                                }
                            }
                        }
                        break;
                    case 2:
                        HashSet hashSet = new HashSet();
                        for (BNode bNode4 : synchronizedMap.keySet()) {
                            if (bNode4.isTrainable()) {
                                if (!bNode4.isRoot()) {
                                    Object[] key2 = EnumTable.getKey(bNode4.getParents(), array);
                                    synchronizedMap.put(bNode4, key2);
                                    for (int i8 = 0; i8 < key2.length; i8++) {
                                        if (key2[i8] == null) {
                                            hashSet.add(bNode4.getParents().get(i8));
                                        }
                                    }
                                }
                                if (bNode4.getInstance() == null) {
                                    hashSet.add(bNode4.getVariable());
                                }
                            }
                        }
                        if (hashSet.size() <= 0) {
                            for (BNode bNode5 : synchronizedMap.keySet()) {
                                if (bNode5.isTrainable()) {
                                    bNode5.countInstance(bNode5.isRoot() ? null : (Object[]) synchronizedMap.get(bNode5), bNode5.getInstance());
                                }
                            }
                            break;
                        } else {
                            try {
                                Variable[] variableArr3 = new Variable[hashSet.size()];
                                hashSet.toArray(variableArr3);
                                CGTable cGTable2 = (CGTable) this.inf.infer(this.inf.makeQuery(variableArr3));
                                for (int i9 : cGTable2.getIndices()) {
                                    Object[] key3 = cGTable2.getKey(i9);
                                    double doubleValue2 = cGTable2.getFactor(i9).doubleValue();
                                    JDF jdf2 = cGTable2.hasNonEnumVariables() ? cGTable2.getJDF(i9) : null;
                                    Variable.Assignment[] array4 = Variable.Assignment.array(cGTable2.getEnumVariables(), key3);
                                    for (BNode bNode6 : synchronizedMap.keySet()) {
                                        if (bNode6.isTrainable()) {
                                            if (bNode6.isRoot()) {
                                                Object bNode7 = bNode6.getInstance();
                                                if (bNode7 != null) {
                                                    bNode6.countInstance(null, bNode7, Double.valueOf(doubleValue2));
                                                } else {
                                                    Variable variable2 = bNode6.getVariable();
                                                    try {
                                                        EnumVariable enumVariable3 = (EnumVariable) variable2;
                                                        int length4 = array4.length;
                                                        int i10 = 0;
                                                        while (true) {
                                                            if (i10 < length4) {
                                                                Variable.Assignment assignment3 = array4[i10];
                                                                if (assignment3.var.equals(enumVariable3)) {
                                                                    bNode6.countInstance(null, assignment3.val, Double.valueOf(doubleValue2));
                                                                } else {
                                                                    i10++;
                                                                }
                                                            }
                                                        }
                                                    } catch (ClassCastException e4) {
                                                        throw new EMRuntimeException("Failed query for sample #" + (i2 + 1) + ": " + variable2.getName() + " is a non-enumerable root node");
                                                    }
                                                }
                                            } else {
                                                Object[] objArr3 = (Object[]) synchronizedMap.get(bNode6);
                                                EnumTable.overlay(objArr3, EnumTable.getKey(bNode6.getParents(), array4));
                                                Object bNode8 = bNode6.getInstance();
                                                if (bNode8 != null) {
                                                    bNode6.countInstance(objArr3, bNode8, Double.valueOf(doubleValue2));
                                                } else {
                                                    Variable variable3 = bNode6.getVariable();
                                                    try {
                                                        EnumVariable enumVariable4 = (EnumVariable) variable3;
                                                        int length5 = array4.length;
                                                        int i11 = 0;
                                                        while (true) {
                                                            if (i11 < length5) {
                                                                Variable.Assignment assignment4 = array4[i11];
                                                                if (assignment4.var.equals(enumVariable4)) {
                                                                    bNode6.countInstance(objArr3, assignment4.val, Double.valueOf(doubleValue2));
                                                                } else {
                                                                    i11++;
                                                                }
                                                            }
                                                        }
                                                    } catch (ClassCastException e5) {
                                                        bNode6.countInstance(objArr3, jdf2.getDistrib(variable3), Double.valueOf(doubleValue2));
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                                break;
                            } catch (RuntimeException e6) {
                                throw new EMRuntimeException("Failed query for sample #" + (i2 + 1) + ": " + e6.getMessage());
                            }
                        }
                    case 3:
                        ExecutorService newFixedThreadPool = Executors.newFixedThreadPool(this.EM_THREAD_COUNT);
                        Iterator it3 = synchronizedMap.keySet().iterator();
                        while (it3.hasNext()) {
                            newFixedThreadPool.execute(new EMc1((BNode) it3.next(), array, i2));
                        }
                        newFixedThreadPool.shutdown();
                        do {
                        } while (!newFixedThreadPool.isTerminated());
                }
                i2++;
            }
            for (BNode bNode9 : synchronizedMap.keySet()) {
                if (bNode9.isTrainable()) {
                    bNode9.maximizeInstance();
                }
            }
            if (i % 10 == 0) {
                for (int i12 = 0; i12 < objArr.length; i12++) {
                    for (int i13 = 0; i13 < variableArr.length; i13++) {
                        BNode node2 = this.f3bn.getNode(variableArr[i13]);
                        if (node2 == null) {
                            throw new EMRuntimeException("Variable \"" + variableArr[i12].getName() + "\" is not part of Bayesian network");
                        }
                        if (objArr[i12][i13] != null) {
                            node2.setInstance(objArr[i12][i13]);
                        } else {
                            node2.resetInstance();
                        }
                    }
                    double logLikelihood = ((VarElim) this.inf).logLikelihood();
                    if (Double.isNaN(logLikelihood)) {
                        System.err.println("Sample " + i12 + "/" + objArr.length + " log-likelihood is " + logLikelihood);
                    }
                    d += logLikelihood;
                    if (Double.isInfinite(d)) {
                        System.err.println("Log-likelihood is infinite: " + d);
                    }
                }
            }
            if (i % 10 == 0) {
                double d2 = 0.0d;
                if (i <= 50) {
                    for (int i14 = 0; i14 < dArr.length - 1; i14++) {
                        dArr[i14] = dArr[i14 + 1];
                    }
                    dArr[dArr.length - 1] = d;
                } else {
                    double length6 = dArr[0] / dArr.length;
                    for (int i15 = 0; i15 < dArr.length - 1; i15++) {
                        dArr[i15] = dArr[i15 + 1];
                        length6 += dArr[i15] / dArr.length;
                    }
                    double[] dArr2 = new double[dArr.length];
                    for (int i16 = 0; i16 < dArr.length; i16++) {
                        dArr2[i16] = (dArr[i16] - length6) * (dArr[i16] - length6);
                    }
                    d2 = dArr2[0] / dArr2.length;
                    for (int i17 = 0; i17 < dArr.length - 1; i17++) {
                        dArr2[i17] = dArr2[i17 + 1];
                        d2 += dArr2[i17] / dArr2.length;
                    }
                    dArr[dArr.length - 1] = d;
                    if (length6 > 0.0d) {
                        if (d2 < this.EM_CONVERGENCE_CRITERION * 0.01d * length6) {
                            z = true;
                        }
                    } else if (d2 < this.EM_CONVERGENCE_CRITERION * 0.01d * (-length6)) {
                        z = true;
                    }
                }
                if (this.EM_PRINT_STATUS || z) {
                    System.err.println("Completed " + i + " round(s), L = " + d);
                    if (z) {
                        System.err.println("SD(L) = " + d2);
                    }
                }
            }
        }
        Iterator<BNode> it4 = this.f3bn.getNodes().iterator();
        while (it4.hasNext()) {
            it4.next().resetInstance();
        }
        if (this.EM_PRINT_STATUS) {
            System.err.println("Done.");
        }
    }
}
