package bn;

import bn.prob.EnumDistrib;
import dat.EnumVariable;
import dat.Enumerable;
import dat.Variable;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

/* JADX WARN: Classes with same name are omitted:
  input_file:target/classes/bn/BNet.class
 */
/* loaded from: input_file:bn/BNet.class */
/**
 * A Bayesian network: a set of {@link BNode}s whose parent lists (as returned by
 * {@code BNode.getParents()}) define a directed graph over the nodes' variables.
 *
 * <p>Nodes are indexed both by name and by {@link Variable}. Structural queries (parents,
 * children, ancestors, descendants, topological order) are answered from adjacency maps that
 * {@link #compile()} rebuilds lazily after nodes have been added or removed. Nodes can also be
 * grouped under free-form string tags.
 *
 * <p>NOTE(review): no synchronization is used anywhere in this class; treat instances as not
 * thread-safe.
 */
public class BNet implements Serializable {
    private static final long serialVersionUID = 1L;

    /** Direction label: the search moved to a node from one of its children. */
    private static final String UP = "up";
    /** Direction label: the search moved to a node from one of its parents. */
    private static final String DOWN = "down";

    /** False whenever nodes were added/removed since the adjacency maps were last rebuilt. */
    private boolean compiled = true;
    private final Map<String, BNode> nodesByName = new HashMap<>();
    private final Map<Variable, BNode> nodesByVar = new HashMap<>();
    /** child -> set of its parent nodes; every member node has an entry after compile(). */
    private Map<BNode, Set<BNode>> ch2par = new HashMap<>();
    /** parent -> set of its child nodes; only nodes with at least one child have an entry. */
    private Map<BNode, Set<BNode>> par2ch = new HashMap<>();
    /** All nodes in topological order (parents before children), rebuilt by compile(). */
    private List<BNode> ordered = new ArrayList<>();
    /** List returned by getOrderedVariables(); refilled on every call. */
    private List<Variable> orderedVars = new ArrayList<>();
    /** tag name -> nodes carrying that tag. */
    private HashMap<String, Set<BNode>> tagged = new HashMap<>();

    /**
     * A (node, direction) pair used by the d-connection search in {@link #getDconnected}.
     * The direction is {@code "up"} when the node was reached from one of its children and
     * {@code "down"} when it was reached from one of its parents.
     */
    public class NodeDirection {
        private BNode node;
        private String direction;

        /**
         * @param bNode the node being visited
         * @param str   the direction label ({@code "up"} or {@code "down"})
         */
        public NodeDirection(BNode bNode, String str) {
            this.node = bNode;
            this.direction = str;
        }

        /** @return the visited node */
        public BNode getNode() {
            return this.node;
        }

        /** @return the direction label */
        public String getDirection() {
            return this.direction;
        }

        /**
         * Tests whether an element of {@code set} has an equal node and an equal direction.
         *
         * @param set pairs to search (linear scan)
         * @return true if a matching pair is present
         */
        public boolean within(Set<NodeDirection> set) {
            for (NodeDirection other : set) {
                if (this.node.equals(other.getNode()) && this.direction.equals(other.getDirection())) {
                    return true;
                }
            }
            return false;
        }
    }

    /**
     * Adds a node and marks the network as needing recompilation.
     *
     * @param bNode node to add
     * @throws BNetRuntimeException if a node with the same name or the same variable is present
     */
    public void add(BNode bNode) {
        if (this.nodesByName.containsKey(bNode.getName()) || this.nodesByVar.containsKey(bNode.getVariable())) {
            throw new BNetRuntimeException("Duplicate node names in BNet: " + bNode.getName());
        }
        this.compiled = false;
        this.nodesByName.put(bNode.getName(), bNode);
        this.nodesByVar.put(bNode.getVariable(), bNode);
    }

    /**
     * Adds several nodes in order; see {@link #add(BNode)}.
     *
     * @param bNodeArr nodes to add
     */
    public void add(BNode... bNodeArr) {
        for (BNode bNode : bNodeArr) {
            add(bNode);
        }
    }

    /**
     * Removes a node (by its name and variable) and marks the network as needing recompilation.
     * Tags referring to the node are left untouched.
     *
     * @param bNode node to remove
     */
    public void remove(BNode bNode) {
        this.compiled = false;
        this.nodesByName.remove(bNode.getName());
        this.nodesByVar.remove(bNode.getVariable());
    }

    /**
     * Rebuilds the parent/child adjacency maps and the topological node order, if the network
     * changed since the last compilation.
     *
     * @throws BNetRuntimeException if a node lists a parent variable that is not a member of the
     *                              network, or if the parent relation contains a cycle
     */
    public void compile() {
        if (this.compiled) {
            return;
        }
        this.ch2par.clear();
        this.par2ch.clear();
        for (BNode child : this.nodesByVar.values()) {
            Set<BNode> parentNodes = new HashSet<>();
            List<EnumVariable> parents = child.getParents();
            if (parents != null) {
                for (EnumVariable parentVar : parents) {
                    BNode parent = this.nodesByVar.get(parentVar);
                    if (parent == null) {
                        String message = "Invalid Bayesian network: node " + parentVar.getName()
                                + " is not a member but referenced by " + child.getName();
                        System.err.println(message);
                        throw new BNetRuntimeException(message);
                    }
                    parentNodes.add(parent);
                    Set<BNode> siblings = this.par2ch.get(parent);
                    if (siblings == null) {
                        siblings = new HashSet<>();
                        this.par2ch.put(parent, siblings);
                    }
                    siblings.add(child);
                }
            }
            this.ch2par.put(child, parentNodes);
        }
        List<BNode> order = topologicalOrder();
        // Refill the existing list rather than replacing it, in case a caller holds a reference.
        this.ordered.clear();
        this.ordered.addAll(order);
        this.compiled = true;
    }

    /**
     * Orders all nodes so that each comes after all of its parents (Kahn's algorithm over the
     * freshly built ch2par/par2ch maps).
     *
     * @return the nodes in topological order
     * @throws BNetRuntimeException if some nodes cannot be ordered because of a cycle
     */
    private List<BNode> topologicalOrder() {
        Map<BNode, Integer> unresolvedParents = new HashMap<>();
        List<BNode> queue = new ArrayList<>();
        for (Map.Entry<BNode, Set<BNode>> entry : this.ch2par.entrySet()) {
            int count = entry.getValue().size();
            unresolvedParents.put(entry.getKey(), count);
            if (count == 0) {
                queue.add(entry.getKey());
            }
        }
        // The queue doubles as the output: a node is appended once all its parents precede it.
        for (int head = 0; head < queue.size(); head++) {
            Set<BNode> children = this.par2ch.get(queue.get(head));
            if (children == null) {
                continue;
            }
            for (BNode child : children) {
                int remaining = unresolvedParents.get(child) - 1;
                unresolvedParents.put(child, remaining);
                if (remaining == 0) {
                    queue.add(child);
                }
            }
        }
        if (queue.size() != this.ch2par.size()) {
            List<String> stuck = new ArrayList<>();
            for (Map.Entry<BNode, Integer> entry : unresolvedParents.entrySet()) {
                if (entry.getValue() > 0) {
                    stuck.add(entry.getKey().getName());
                }
            }
            throw new BNetRuntimeException("Invalid Bayesian network: parent relation contains a cycle; nodes that cannot be ordered: " + stuck);
        }
        return queue;
    }

    /**
     * Overrides the compiled flag. Setting it to false forces the next query to recompile.
     *
     * @param bool new value of the flag (must not be null)
     */
    public void setCompiled(Boolean bool) {
        this.compiled = bool.booleanValue();
    }

    /** @return a new set of the nodes for which {@code BNode.isRoot()} is true */
    public Set<BNode> getRoots() {
        Set<BNode> roots = new HashSet<>();
        for (BNode bNode : this.nodesByName.values()) {
            if (bNode.isRoot()) {
                roots.add(bNode);
            }
        }
        return roots;
    }

    /** @return a live view of all member nodes */
    public Collection<BNode> getNodes() {
        return this.nodesByName.values();
    }

    /**
     * @param str name of a member node (resolved via {@link #getNode(String)})
     * @return the {@code toString()} of each parent variable, or null if the node has no parent list
     * @throws BNetRuntimeException if no such node exists
     */
    public Set<String> getParentsNames(String str) {
        BNode node = getNode(str);
        if (node == null) {
            throw new BNetRuntimeException("Node " + str + " does not exist in network");
        }
        return getParentsNames(node);
    }

    /**
     * @param bNode a node
     * @return the {@code toString()} of each parent variable, or null if the node has no parent list
     */
    public Set<String> getParentsNames(BNode bNode) {
        List<EnumVariable> parents = bNode.getParents();
        if (parents == null) {
            return null;
        }
        Set<String> names = new HashSet<>();
        for (EnumVariable parent : parents) {
            names.add(parent.toString());
        }
        return names;
    }

    /**
     * @param bNode a node
     * @return the internal (mutable) set of parent nodes, or null if the node is not a member
     */
    public Set<BNode> getParents(BNode bNode) {
        if (!this.compiled) {
            compile();
        }
        return this.ch2par.get(bNode);
    }

    /**
     * @param bNode a node
     * @return a new set of the nodes sharing at least one parent with {@code bNode} (excluding it),
     *         or null if the node is not a member
     */
    public Set<BNode> getSiblings(BNode bNode) {
        if (!this.compiled) {
            compile();
        }
        Set<BNode> parents = this.ch2par.get(bNode);
        if (parents == null) {
            return null;
        }
        Set<BNode> siblings = new HashSet<>();
        for (BNode parent : parents) {
            siblings.addAll(this.par2ch.get(parent));
        }
        siblings.remove(bNode);
        return siblings;
    }

    /**
     * @param str name of a member node (resolved via {@link #getNode(String)})
     * @return see {@link #getAncestors(BNode)}
     * @throws BNetRuntimeException if no such node exists
     */
    public List<BNode> getAncestors(String str) {
        BNode node = getNode(str);
        if (node == null) {
            throw new BNetRuntimeException("Node " + str + " does not exist in network");
        }
        return getAncestors(node);
    }

    /**
     * Collects all ancestors without duplicates: direct parents first, then each parent's
     * ancestors.
     *
     * @param bNode a node
     * @return a new list of ancestors (empty for a root), or null if the node is not a member
     */
    public List<BNode> getAncestors(BNode bNode) {
        if (!this.compiled) {
            compile();
        }
        Set<BNode> parents = this.ch2par.get(bNode);
        if (parents == null) {
            return null;
        }
        List<BNode> ancestors = new ArrayList<>(parents);
        for (BNode parent : parents) {
            List<BNode> further = getAncestors(parent);
            if (further != null) {
                for (BNode ancestor : further) {
                    if (!ancestors.contains(ancestor)) {
                        ancestors.add(ancestor);
                    }
                }
            }
        }
        return ancestors;
    }

    /**
     * @param bNode a node
     * @return a new set of the children's names, or null if the node has no children
     */
    public Set<String> getChildrenNames(BNode bNode) {
        if (!this.compiled) {
            compile();
        }
        Set<BNode> children = this.par2ch.get(bNode);
        if (children == null) {
            return null;
        }
        Set<String> names = new HashSet<>();
        for (BNode child : children) {
            names.add(child.getName());
        }
        return names;
    }

    /**
     * @param bNode a node
     * @return a new set of the node's children, or null if it has none
     */
    public Set<BNode> getChildren(BNode bNode) {
        if (!this.compiled) {
            compile();
        }
        Set<BNode> children = this.par2ch.get(bNode);
        return children == null ? null : new HashSet<>(children);
    }

    /**
     * @param bNode a node
     * @return true if at least one member node lists {@code bNode} as a parent
     */
    public boolean hasChildren(BNode bNode) {
        if (!this.compiled) {
            compile();
        }
        return this.par2ch.get(bNode) != null;
    }

    /**
     * @param str name of a member node (resolved via {@link #getNode(String)})
     * @return see {@link #getDescendants(BNode)}
     * @throws BNetRuntimeException if no such node exists
     */
    public List<BNode> getDescendants(String str) {
        BNode node = getNode(str);
        if (node == null) {
            throw new BNetRuntimeException("Node " + str + " does not exist in network");
        }
        return getDescendants(node);
    }

    /**
     * Collects all descendants without duplicates: direct children first, then each child's
     * descendants. Note that this order is NOT topological; use {@link #getOrdered()} for that.
     *
     * @param bNode a node
     * @return a new list of descendants, or an immutable empty list if the node has no children
     */
    public List<BNode> getDescendants(BNode bNode) {
        if (!this.compiled) {
            compile();
        }
        Set<BNode> children = this.par2ch.get(bNode);
        if (children == null) {
            return Collections.emptyList();
        }
        List<BNode> descendants = new ArrayList<>(children);
        for (BNode child : children) {
            for (BNode descendant : getDescendants(child)) {
                if (!descendants.contains(descendant)) {
                    descendants.add(descendant);
                }
            }
        }
        return descendants;
    }

    /** @return the internal list of all nodes in topological order (parents before children) */
    public List<BNode> getOrdered() {
        if (!this.compiled) {
            compile();
        }
        return this.ordered;
    }

    /**
     * @return the variables of {@link #getOrdered()}, in the same order. The returned list is
     *         reused and refilled on every call.
     */
    public List<Variable> getOrderedVariables() {
        List<BNode> nodes = getOrdered();
        // Clear first: previously every call appended another full copy of the variables.
        this.orderedVars.clear();
        for (BNode node : nodes) {
            this.orderedVars.add(node.getVariable());
        }
        return this.orderedVars;
    }

    /** @return a new list of all nodes sorted by name */
    public List<BNode> getAlphabetical() {
        List<BNode> sorted = new ArrayList<>();
        for (String name : new TreeSet<>(getNames())) {
            sorted.add(getNode(name));
        }
        return sorted;
    }

    /** @return a live view of all node names */
    public Set<String> getNames() {
        return this.nodesByName.keySet();
    }

    /**
     * Looks up a node by exact name. If none matches, returns a node whose name, with everything
     * from its last {@code '.'} removed, equals {@code str} (first match in map iteration order).
     *
     * @param str node name
     * @return the node, or null if nothing matches
     */
    public BNode getNode(String str) {
        BNode exact = this.nodesByName.get(str);
        if (exact != null) {
            return exact;
        }
        for (String name : this.nodesByName.keySet()) {
            int dot = name.lastIndexOf(".");
            String stem = dot < 0 ? name : name.substring(0, dot);
            if (stem.equals(str)) {
                return this.nodesByName.get(name);
            }
        }
        return null;
    }

    /**
     * @param variable a variable
     * @return the node for that variable, or null if none
     */
    public BNode getNode(Variable variable) {
        return this.nodesByVar.get(variable);
    }

    /**
     * Builds the key used to look up a node's conditional distribution: the current instance of
     * each parent, in the node's parent-list order.
     *
     * @param bNode a member node whose parents are members
     * @return the parent instances, or null for a root node
     */
    public Object[] getEvidenceKey(BNode bNode) {
        if (bNode.isRoot()) {
            return null;
        }
        List<EnumVariable> parents = bNode.getParents();
        Object[] key = new Object[parents.size()];
        for (int i = 0; i < key.length; i++) {
            key[i] = getNode(parents.get(i)).getInstance();
        }
        return key;
    }

    /**
     * Assigns a sampled value to every node that has no instance yet, visiting nodes in
     * topological order so parent instances are available when a child is sampled.
     */
    public void sampleInstance() {
        for (BNode bNode : getOrdered()) {
            if (bNode.getInstance() == null) {
                bNode.setInstance(bNode.getDistrib(getEvidenceKey(bNode)).sample());
            }
        }
    }

    /**
     * Builds a new network (sharing the same BNode objects) containing the instantiated nodes,
     * the nodes for the given variables (matched by {@code toString()}), and all their ancestors.
     *
     * @param variableArr query variables
     * @return the compiled sub-network
     */
    public BNet getRelevant(Variable... variableArr) {
        BNet relevant = new BNet();
        Set<String> queryNames = new HashSet<>();
        for (Variable variable : variableArr) {
            queryNames.add(variable.toString());
        }
        for (BNode bNode : this.nodesByName.values()) {
            if (bNode.getInstance() != null || queryNames.contains(bNode.getVariable().toString())) {
                relevant.add(bNode);
            }
        }
        Set<BNode> ancestors = new HashSet<>();
        for (String name : relevant.getNames()) {
            List<BNode> nodeAncestors = getAncestors(name);
            if (nodeAncestors != null) {
                ancestors.addAll(nodeAncestors);
            }
        }
        for (BNode ancestor : ancestors) {
            if (!relevant.getNames().contains(ancestor.getName())) {
                relevant.add(ancestor);
            }
        }
        relevant.compile();
        return relevant;
    }

    /**
     * Finds the nodes reachable from the query variables given the current evidence (every node
     * with a non-null instance). The search is in the style of the Bayes-ball algorithm.
     * <ul>
     *   <li>A non-evidence node entered "up" passes to its parents and children.</li>
     *   <li>A node entered "down" passes to its children unless it is evidence.</li>
     *   <li>A node entered "down" passes to its parents if it is evidence or an ancestor of
     *       evidence.</li>
     * </ul>
     *
     * @param variableArr query variables
     * @return every visited node (query nodes included), in topological order
     * @throws NullPointerException if a query variable has no node in this network
     */
    public List<BNode> getDconnected(Variable... variableArr) {
        if (!this.compiled) {
            compile();
        }
        Set<BNode> evidence = new HashSet<>();
        for (BNode bNode : getNodes()) {
            if (bNode.getInstance() != null) {
                evidence.add(bNode);
            }
        }
        Set<BNode> evidenceOrAncestor = new HashSet<>();
        for (BNode observed : evidence) {
            evidenceOrAncestor.addAll(getAncestors(observed));
            evidenceOrAncestor.add(observed);
        }
        List<NodeDirection> queue = new ArrayList<>();
        for (Variable variable : variableArr) {
            BNode node = getNode(variable);
            if (node == null) {
                throw new NullPointerException("Invalid query: node " + variable.toString() + " does not exist in this network");
            }
            queue.add(new NodeDirection(node, UP));
        }
        // A (node, direction) pair is expanded at most once; track each direction separately.
        Set<BNode> visitedUp = new HashSet<>();
        Set<BNode> visitedDown = new HashSet<>();
        Set<BNode> reached = new HashSet<>();
        for (int head = 0; head < queue.size(); head++) {
            NodeDirection current = queue.get(head);
            BNode node = current.getNode();
            String direction = current.getDirection();
            boolean firstVisit = UP.equals(direction) ? visitedUp.add(node) : visitedDown.add(node);
            if (!firstVisit) {
                continue;
            }
            reached.add(node);
            if (UP.equals(direction)) {
                if (!evidence.contains(node)) {
                    enqueueParents(node, queue);
                    enqueueChildren(node, queue);
                }
            } else if (DOWN.equals(direction)) {
                if (!evidence.contains(node)) {
                    enqueueChildren(node, queue);
                }
                if (evidenceOrAncestor.contains(node)) {
                    enqueueParents(node, queue);
                }
            }
        }
        List<BNode> result = new ArrayList<>();
        for (BNode node : getOrdered()) {
            if (reached.contains(node)) {
                result.add(node);
            }
        }
        return result;
    }

    /** Queues every parent of {@code node} to be visited in direction "up". */
    private void enqueueParents(BNode node, List<NodeDirection> queue) {
        List<EnumVariable> parents = node.getParents();
        if (parents != null) {
            for (EnumVariable parent : parents) {
                queue.add(new NodeDirection(getNode(parent), UP));
            }
        }
    }

    /** Queues every child of {@code node} to be visited in direction "down" (requires compiled maps). */
    private void enqueueChildren(BNode node, List<NodeDirection> queue) {
        Set<BNode> children = this.par2ch.get(node);
        if (children != null) {
            for (BNode child : children) {
                queue.add(new NodeDirection(child, DOWN));
            }
        }
    }

    /** Calls {@code resetInstance()} on every member node. */
    public void resetNodes() {
        for (BNode bNode : getNodes()) {
            bNode.resetInstance();
        }
    }

    /**
     * Computes the Markov blanket: parents, children, and the children's other parents.
     *
     * @param bNode a node
     * @return a new set with the blanket (never null; excludes {@code bNode} itself)
     */
    public Set<BNode> getMB(BNode bNode) {
        if (!this.compiled) {
            compile();
        }
        Set<BNode> blanket = new HashSet<>();
        Set<BNode> children = this.par2ch.get(bNode);
        if (children != null) {
            blanket.addAll(children);
            for (BNode child : children) {
                Set<BNode> coParents = this.ch2par.get(child);
                if (coParents != null) {
                    blanket.addAll(coParents);
                }
            }
            blanket.remove(bNode);
        }
        Set<BNode> parents = this.ch2par.get(bNode);
        if (parents != null) {
            blanket.addAll(parents);
        }
        return blanket;
    }

    /**
     * Samples a value for {@code bNode} given its Markov blanket. For each domain value v, the
     * weight is {@code bNode.get(parentKey, v)} multiplied by {@code child.get(childKey, childInstance)}
     * for every child, evaluated with bNode set to v. If the node's variable is not an
     * {@link EnumVariable}, which is detected via the ClassCastException, the method samples from
     * {@code bNode.getDistrib(parentKey)} instead. The node's previous instance is always restored.
     *
     * @param bNode node to sample
     * @return the sampled value
     */
    public Object getMBProb(BNode bNode) {
        if (!this.compiled) {
            compile();
        }
        Object savedInstance = bNode.getInstance();
        bNode.resetInstance();
        try {
            try {
                Enumerable domain = ((EnumVariable) bNode.getVariable()).getDomain();
                Set<BNode> children = this.par2ch.get(bNode);
                double[] weights = new double[domain.size()];
                Object[] values = domain.getValues();
                for (int i = 0; i < values.length; i++) {
                    weights[i] = bNode.get(getEvidenceKey(bNode), values[i]).doubleValue();
                    bNode.setInstance(values[i]);
                    if (children != null) {
                        for (BNode child : children) {
                            weights[i] *= child.get(getEvidenceKey(child), child.getInstance()).doubleValue();
                        }
                    }
                }
                return new EnumDistrib(domain, weights).sample();
            } catch (ClassCastException e) {
                return bNode.getDistrib(getEvidenceKey(bNode)).sample();
            }
        } finally {
            // Restore even if an unexpected exception escapes the weighting loop.
            bNode.setInstance(savedInstance);
        }
    }

    /**
     * Same as {@link #getMBProb(BNode)}, but multiplies in the factors of the given nodes instead
     * of the node's compiled children. Does not trigger compilation. The node's previous instance
     * is restored afterwards (previously it was left at the last domain value).
     *
     * @param set   nodes whose factors are multiplied in (presumably the node's children; verify at call sites)
     * @param bNode node to sample
     * @return the sampled value
     */
    public Object getMBProb(Set<BNode> set, BNode bNode) {
        Object savedInstance = bNode.getInstance();
        bNode.resetInstance();
        try {
            try {
                Enumerable domain = ((EnumVariable) bNode.getVariable()).getDomain();
                double[] weights = new double[domain.size()];
                Object[] values = domain.getValues();
                for (int i = 0; i < values.length; i++) {
                    weights[i] = bNode.get(getEvidenceKey(bNode), values[i]).doubleValue();
                    bNode.setInstance(values[i]);
                    for (BNode factorNode : set) {
                        weights[i] *= factorNode.get(getEvidenceKey(factorNode), factorNode.getInstance()).doubleValue();
                    }
                }
                return new EnumDistrib(domain, weights).sample();
            } catch (ClassCastException e) {
                return bNode.getDistrib(getEvidenceKey(bNode)).sample();
            }
        } finally {
            bNode.setInstance(savedInstance);
        }
    }

    /** Copies a set of strings into an array, in the set's iteration order. */
    private String[] makeArray(Set<String> set) {
        return set.toArray(new String[0]);
    }

    /** Copies an array of strings into a new hash set. */
    private Set<String> makeSet(String[] strArr) {
        return new HashSet<>(Arrays.asList(strArr));
    }

    /** @return a live view of all tag names */
    public Set<String> getTagNames() {
        return this.tagged.keySet();
    }

    /** @return a new list of every node carrying at least one tag (no duplicates) */
    public List<BNode> getTagged() {
        Set<BNode> all = new HashSet<>();
        for (Set<BNode> members : this.tagged.values()) {
            all.addAll(members);
        }
        return new ArrayList<>(all);
    }

    /** Removes every tag. */
    public void removeAllTags() {
        this.tagged.clear();
    }

    /**
     * @param strArr tag names
     * @return a new list of member nodes carrying ALL the given tags; empty if no tags are given
     * @throws IllegalArgumentException if any of the tags does not exist
     */
    public List<BNode> getTagged(String... strArr) {
        // Validate every tag up front so an unknown tag is always reported, not only
        // when earlier tags happen to match some node.
        for (String tag : strArr) {
            if (!this.tagged.containsKey(tag)) {
                throw new IllegalArgumentException("Tag " + tag + " does not exist");
            }
        }
        Set<BNode> matches = new HashSet<>();
        if (strArr.length == 0) {
            return new ArrayList<>(matches);
        }
        for (BNode bNode : this.nodesByName.values()) {
            boolean hasAll = true;
            for (String tag : strArr) {
                if (!this.tagged.get(tag).contains(bNode)) {
                    hasAll = false;
                    break;
                }
            }
            if (hasAll) {
                matches.add(bNode);
            }
        }
        return new ArrayList<>(matches);
    }

    /**
     * Removes one node from a tag; does nothing if the tag does not exist.
     *
     * @param str   tag name
     * @param bNode node to untag
     */
    public void removeTag(String str, BNode bNode) {
        Set<BNode> members = this.tagged.get(str);
        if (members != null) {
            members.remove(bNode);
        }
    }

    /**
     * Adds nodes to a tag, creating the tag if needed.
     *
     * @param str      tag name
     * @param bNodeArr nodes to tag
     */
    public void setTags(String str, BNode... bNodeArr) {
        Set<BNode> members = this.tagged.get(str);
        if (members == null) {
            members = new HashSet<>();
            this.tagged.put(str, members);
        }
        Collections.addAll(members, bNodeArr);
    }

    /**
     * Adds nodes to each of several tags; see {@link #setTags(String, BNode...)}.
     *
     * @param strArr   tag names
     * @param bNodeArr nodes to tag
     */
    public void setTags(String[] strArr, BNode... bNodeArr) {
        for (String str : strArr) {
            setTags(str, bNodeArr);
        }
    }

    /**
     * @param bNode a node
     * @return a new list of the names of all tags containing the node
     */
    public List<String> getTags(BNode bNode) {
        List<String> names = new ArrayList<>();
        for (Map.Entry<String, Set<BNode>> entry : this.tagged.entrySet()) {
            if (entry.getValue().contains(bNode)) {
                names.add(entry.getKey());
            }
        }
        return names;
    }
}
