package bn.alg;

import bn.BNet;
import bn.BNode;
import bn.factor.AbstractFactor;
import bn.factor.Factor;
import bn.factor.Factorize;
import dat.EnumVariable;
import dat.Variable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/* JADX WARN: Classes with same name are omitted:
  input_file:target/classes/bn/alg/VarElim.class
 */
/* loaded from: input_file:bn/alg/VarElim.class */
public class VarElim implements Inference {

    /* renamed from: bn, reason: collision with root package name */
    /** The network inference runs on; assigned and compiled by {@link #instantiate(BNet)}. */
    public BNet f4bn;
    /** Query status for belief queries: {@code infer} eliminates latent buckets with {@code Factorize.getMargin}. */
    static final int STATUS_BEL = 0;
    /** Query status for MPE queries: {@code infer} eliminates latent buckets with {@code Factorize.getMaxMargin}. */
    static final int STATUS_MPE = 1;

    /* JADX WARN: Classes with same name are omitted:
      input_file:target/classes/bn/alg/VarElim$Bucket.class
     */
    /* loaded from: input_file:bn/alg/VarElim$Bucket.class */
    public class Bucket {
        /** Factors assigned to this bucket; multiplied together when the bucket is processed. */
        List<AbstractFactor> factors;
        /**
         * Variables owned by this bucket. {@link #match} places factors by them and, for every
         * bucket except the first, they are the variables eliminated when the bucket is processed.
         */
        List<Variable> vars;

        /**
         * Creates a bucket owning a single variable.
         *
         * @param variable the variable to own; {@code null} is accepted (used as a placeholder)
         */
        Bucket(Variable variable) {
            this.vars = new ArrayList<>();
            this.vars.add(variable);
            this.factors = new ArrayList<>();
        }

        /**
         * Creates a bucket owning the given variables.
         *
         * @param list the variables to own; copied so that later changes to either list are not
         *     shared with the caller (e.g. a query's variable list)
         */
        Bucket(List<Variable> list) {
            this.vars = new ArrayList<>(list);
            this.factors = new ArrayList<>();
        }

        /**
         * Tells whether a factor belongs here.
         *
         * @param abstractFactor the candidate factor
         * @return true if at least one of the factor's enumerable variables is owned by this bucket
         */
        boolean match(AbstractFactor abstractFactor) {
            for (EnumVariable enumVariable : abstractFactor.getEnumVars()) {
                if (this.vars.contains(enumVariable)) {
                    return true;
                }
            }
            return false;
        }

        /**
         * Tells whether any factor already in this bucket mentions a variable.
         *
         * @param variable the variable to look for
         * @return true if some stored factor reports {@code hasVariable(variable)}
         */
        boolean hasFactorWith(Variable variable) {
            for (AbstractFactor factor : this.factors) {
                if (factor.hasVariable(variable)) {
                    return true;
                }
            }
            return false;
        }

        /** Adds a factor to this bucket. */
        void put(AbstractFactor abstractFactor) {
            this.factors.add(abstractFactor);
        }

        /** Returns the live (mutable) list of factors stored in this bucket. */
        List<AbstractFactor> get() {
            return this.factors;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* JADX WARN: Classes with same name are omitted:
      input_file:target/classes/bn/alg/VarElim$CGQuery.class
     */
    /* loaded from: input_file:bn/alg/VarElim$CGQuery.class */
    public class CGQuery implements Query {
        /** Query variables. */
        final List<Variable> Q;
        /** Non-evidence variables other than the query ones; each enumerable one gets its own bucket in {@code infer}. */
        final List<Variable> X;
        private int status = VarElim.STATUS_BEL;
        /**
         * Every relevant variable. Evidence variables map to their value; query and remaining
         * variables map to {@code null}.
         */
        final Map<Variable, Object> E = new HashMap<>();

        /**
         * Assembles a query.
         *
         * @param query the query variables
         * @param evidence instantiated variables and their values
         * @param latent the remaining variables
         */
        CGQuery(List<Variable> query, List<Variable.Assignment> evidence, List<Variable> latent) {
            this.Q = query;
            this.X = latent;
            for (Variable.Assignment assignment : evidence) {
                E.put(assignment.var, assignment.val);
            }
            // Query and remaining variables are relevant but unassigned. These puts run after the
            // evidence, so a query variable that was also given as evidence ends up mapped to null.
            for (Variable queryVar : query) {
                E.put(queryVar, null);
            }
            for (Variable latentVar : latent) {
                E.put(latentVar, null);
            }
        }

        /** Sets the query status ({@code STATUS_BEL} or {@code STATUS_MPE}). */
        void setStatus(int i) {
            status = i;
        }

        /** Returns the query status. */
        int getStatus() {
            return status;
        }

        /** Returns the live map of relevant variables (see {@link #E}). */
        Map<Variable, Object> getRelevant() {
            return E;
        }

        /**
         * Renders the query as
         * {@code Q:<names>,|E:<var>=<value>,|X:<names>,|Status:<n>}; only assigned entries of
         * {@code E} are listed.
         */
        @Override
        public String toString() {
            StringBuilder out = new StringBuilder("Q:");
            for (Variable queryVar : Q) {
                out.append(queryVar.getName()).append(",");
            }
            out.append("|E:");
            for (Map.Entry<Variable, Object> entry : E.entrySet()) {
                Object value = entry.getValue();
                if (value == null) {
                    continue;
                }
                out.append(entry.getKey().toString()).append("=").append(value.toString()).append(",");
            }
            out.append("|X:");
            for (Variable latentVar : X) {
                out.append(latentVar.getName()).append(",");
            }
            return out.append("|Status:").append(status).toString();
        }
    }

    /* JADX WARN: Classes with same name are omitted:
      input_file:target/classes/bn/alg/VarElim$FTCompare.class
     */
    /* loaded from: input_file:bn/alg/VarElim$FTCompare.class */
    /** Orders factors by ascending number of enumerable variables. */
    public class FTCompare implements Comparator<Factor> {
        public FTCompare() {
        }

        /**
         * Compares two factors by how many enumerable variables each has.
         *
         * @return negative, zero or positive as {@code factor} has fewer, equal or more enumerable
         *     variables than {@code factor2}
         */
        @Override // java.util.Comparator
        public int compare(Factor factor, Factor factor2) {
            // Integer.compare rather than subtraction: same sign, no overflow hazard.
            return Integer.compare(factor.getEnumVariables().size(), factor2.getEnumVariables().size());
        }
    }

    /* JADX WARN: Classes with same name are omitted:
      input_file:target/classes/bn/alg/VarElim$VarElimRuntimeException.class
     */
    /* loaded from: input_file:bn/alg/VarElim$VarElimRuntimeException.class */
    /** Unchecked failure raised when variable elimination cannot proceed. */
    public class VarElimRuntimeException extends RuntimeException {
        private static final long serialVersionUID = 1L;

        /**
         * Creates the exception.
         *
         * @param str description of the failure
         */
        public VarElimRuntimeException(String str) {
            super(str);
        }
    }

    /**
     * Binds this engine to a network and compiles it.
     *
     * @param bNet the network to run inference on; stored in {@code f4bn}
     */
    @Override // bn.alg.Inference
    public void instantiate(BNet bNet) {
        f4bn = bNet;
        bNet.compile();
    }

    /**
     * Builds a belief query for the given variables.
     *
     * <p>Among the nodes returned by {@code getDconnected(variableArr)}, instantiated ones become
     * evidence and uninstantiated ones that are not query variables become latent ({@code X}).
     *
     * @param variableArr the query variables
     * @return a {@code CGQuery} with the default (belief) status
     */
    @Override // bn.alg.Inference
    public Query makeQuery(Variable... variableArr) {
        List<Variable> query = new ArrayList<>(Arrays.asList(variableArr));
        List<Variable.Assignment> evidence = new ArrayList<>();
        List<Variable> latent = new ArrayList<>();
        for (BNode node : this.f4bn.getDconnected(variableArr)) {
            Variable variable = node.getVariable();
            Object value = node.getInstance();
            if (value != null) {
                evidence.add(new Variable.Assignment(variable, value));
            } else if (!query.contains(variable)) {
                latent.add(variable);
            }
        }
        return new CGQuery(query, evidence, latent);
    }

    /**
     * Builds a most-probable-explanation query for the given variables over the whole network.
     *
     * <p>Every node from {@code getOrdered()} is considered. Instantiated ones become evidence,
     * and uninstantiated non-query ones become latent ({@code X}).
     *
     * @param variableArr the query variables
     * @return a {@code CGQuery} with status {@code STATUS_MPE}
     */
    public Query makeMPE(Variable... variableArr) {
        List<Variable> query = new ArrayList<>(Arrays.asList(variableArr));
        List<Variable.Assignment> evidence = new ArrayList<>();
        List<Variable> latent = new ArrayList<>();
        for (BNode node : this.f4bn.getOrdered()) {
            Variable variable = node.getVariable();
            Object value = node.getInstance();
            if (value != null) {
                evidence.add(new Variable.Assignment(variable, value));
            } else if (!query.contains(variable)) {
                latent.add(variable);
            }
        }
        CGQuery mpe = new CGQuery(query, evidence, latent);
        mpe.setStatus(STATUS_MPE);
        return mpe;
    }

    /**
     * Builds an MPE query restricted to the nodes d-connected to the nominated variables.
     *
     * <p>Among the nodes from {@code getDconnected(variableArr)}, instantiated ones become
     * evidence and all others go to {@code X}. The query list {@code Q} is left empty.
     * NOTE(review): unlike {@code makeMPE}, the nominated variables are not placed in {@code Q},
     * so they are maxed out along with the rest of {@code X}. Confirm this is intended.
     *
     * @param variableArr the nominated variables
     * @return a {@code CGQuery} with status {@code STATUS_MPE}
     */
    public Query makeNominatedMPEQuery(Variable... variableArr) {
        List<Variable> query = new ArrayList<>();
        List<Variable.Assignment> evidence = new ArrayList<>();
        List<Variable> latent = new ArrayList<>();
        for (BNode node : this.f4bn.getDconnected(variableArr)) {
            Variable variable = node.getVariable();
            Object value = node.getInstance();
            if (value != null) {
                evidence.add(new Variable.Assignment(variable, value));
            } else {
                latent.add(variable);
            }
        }
        CGQuery mpe = new CGQuery(query, evidence, latent);
        mpe.setStatus(STATUS_MPE);
        return mpe;
    }

    /**
     * Runs bucket elimination for a query built by this class.
     *
     * <p>Bucket 0 holds the query variables {@code Q}. One further bucket is created per
     * enumerable variable in {@code X}, in list order; other variables in {@code X} get no bucket.
     * Each relevant node's factor goes into the highest-indexed bucket sharing an enumerable
     * variable with it, or into bucket 0 if it has no enumerable variables. Buckets are then
     * processed from the highest index down. For each, the factors are multiplied and the
     * bucket's variables are eliminated ({@code getMaxMargin} for MPE queries, {@code getMargin}
     * otherwise). The result is handed to the next lower matching bucket, or to bucket 0 when it
     * has no enumerable variables. The product left in bucket 0 is returned as a {@code CGTable}
     * over {@code Q}.
     *
     * @param query a query created by this class (must be a {@code CGQuery})
     * @return the table over the query variables
     * @throws ClassCastException if {@code query} is not a {@code CGQuery}
     * @throws VarElimRuntimeException if a factor matches no bucket, if elimination fails with a
     *     {@code ClassCastException}, or if bucket 0 ends up with no factors
     */
    @Override // bn.alg.Inference
    public QueryResult infer(Query query) {
        CGQuery cgQuery = (CGQuery) query;

        // Bucket 0: query variables. Then one bucket per enumerable remaining variable.
        List<Bucket> buckets = new ArrayList<>();
        buckets.add(new Bucket(cgQuery.Q));
        for (Variable variable : cgQuery.X) {
            if (variable instanceof EnumVariable) {
                buckets.add(new Bucket(variable));
            }
        }
        final int numBuckets = buckets.size();

        // Place every relevant node's factor in the highest-indexed bucket it matches.
        Map<Variable, Object> relevant = cgQuery.getRelevant();
        for (Variable variable : relevant.keySet()) {
            BNode node = this.f4bn.getNode(variable);
            AbstractFactor factor = node.makeDenseFactor(relevant);
            if (!factor.hasEnumVars()) {
                buckets.get(0).put(factor);
                continue;
            }
            boolean placed = false;
            for (int i = numBuckets - 1; i >= 0 && !placed; i--) {
                Bucket bucket = buckets.get(i);
                if (bucket.match(factor)) {
                    bucket.put(factor);
                    placed = true;
                }
            }
            if (!placed) {
                throw new VarElimRuntimeException("Node can not be eliminated in inference: " + node.getName());
            }
        }

        // A non-query bucket that received no factors hands each of its variables to every later
        // bucket holding a factor over it, and always to the last bucket. It is then dropped.
        // NOTE(review): a variable handed to several buckets is eliminated in each of them
        // separately - confirm this is intended.
        List<Bucket> emptied = new ArrayList<>();
        for (int i = 1; i < numBuckets; i++) {
            Bucket bucket = buckets.get(i);
            if (!bucket.factors.isEmpty()) {
                continue;
            }
            for (Variable variable : bucket.vars) {
                for (int j = i + 1; j < numBuckets; j++) {
                    Bucket later = buckets.get(j);
                    if (later.hasFactorWith(variable) || j == numBuckets - 1) {
                        later.vars.add(variable);
                    }
                }
            }
            emptied.add(bucket);
        }
        // Remove the Bucket objects themselves (by identity) after the scan, so indices stay
        // valid while scanning.
        buckets.removeAll(emptied);

        // Eliminate from the highest bucket down; whatever reaches bucket 0 is the answer.
        for (int k = buckets.size() - 1; k >= 0; k--) {
            Bucket bucket = buckets.get(k);
            List<AbstractFactor> factors = bucket.factors;
            if (factors.isEmpty()) {
                continue;
            }
            AbstractFactor product;
            if (factors.size() == 1) {
                product = factors.get(0);
            } else if (factors.size() == 2) {
                product = Factorize.getProduct(factors.get(0), factors.get(1));
            } else {
                product = Factorize.getProduct(factors.toArray(new AbstractFactor[factors.size()]));
            }
            if (k == 0) {
                return new CGTable(product, cgQuery.Q);
            }
            try {
                Variable[] eliminated = bucket.vars.toArray(new Variable[bucket.vars.size()]);
                AbstractFactor margin = cgQuery.getStatus() == STATUS_MPE
                        ? Factorize.getMaxMargin(product, eliminated)
                        : Factorize.getMargin(product, eliminated);
                if (margin.hasEnumVars()) {
                    // NOTE(review): if no lower bucket matches, this margin is discarded.
                    for (int j = k - 1; j >= 0; j--) {
                        Bucket target = buckets.get(j);
                        if (target.match(margin)) {
                            target.put(margin);
                            break;
                        }
                    }
                } else {
                    buckets.get(0).put(margin);
                }
            } catch (ClassCastException e) {
                VarElimRuntimeException failure =
                        new VarElimRuntimeException("Cannot marginalize or maximize-out continuous variables");
                failure.initCause(e);
                throw failure;
            }
        }
        throw new VarElimRuntimeException("Variable elimination failed");
    }

    /**
     * Builds a belief query for the given variables with {@code makeQuery} and runs it.
     *
     * @param variableArr the query variables
     * @return the inference result
     */
    public QueryResult infer(Variable[] variableArr) {
        Query query = makeQuery(variableArr);
        return infer(query);
    }

    /**
     * Single-variable form of {@code infer(Variable[])}.
     *
     * @param variable the query variable
     * @return the inference result
     */
    public QueryResult infer(Variable variable) {
        Variable[] single = {variable};
        return infer(single);
    }

    /**
     * Queries the variables of the given nodes with a belief query built by {@code makeQuery}.
     *
     * @param bNodeArr the nodes whose variables are queried
     * @return the inference result
     */
    public QueryResult infer(BNode[] bNodeArr) {
        Variable[] variables = new Variable[bNodeArr.length];
        int next = 0;
        for (BNode node : bNodeArr) {
            variables[next++] = node.getVariable();
        }
        return infer(makeQuery(variables));
    }

    /**
     * Single-node form of {@code infer(BNode[])}.
     *
     * @param bNode the node whose variable is queried
     * @return the inference result
     */
    public QueryResult infer(BNode bNode) {
        BNode[] single = {bNode};
        return infer(single);
    }

    /**
     * Eliminates every uninstantiated enumerable variable of the network by summation and
     * returns {@code getLogSum()} of the final product. This is presumably the log-likelihood of
     * the current instantiation; confirm against {@code AbstractFactor.getLogSum}.
     *
     * <p>One bucket is created per uninstantiated enumerable variable, in {@code getOrdered()}
     * order. If there are none, a single placeholder bucket is used. Every node's factor (built
     * with the instantiated values) goes into the highest-indexed matching bucket, or bucket 0.
     * Buckets are processed from the highest index down. Each margin is passed to the next lower
     * matching bucket, or to bucket 0 when none matches. This includes margins with no enumerable
     * variables left, which are still factors of the joint and must reach the final product.
     *
     * @return the log-sum of the product remaining in bucket 0
     * @throws VarElimRuntimeException if marginalization fails with a {@code ClassCastException},
     *     or if bucket 0 ends up with no factors
     */
    public double logLikelihood() {
        // Every node's value (null when uninstantiated), plus the list of uninstantiated variables.
        Map<Variable, Object> evidence = new HashMap<>();
        List<Variable> latent = new ArrayList<>();
        for (BNode node : this.f4bn.getOrdered()) {
            Variable variable = node.getVariable();
            Object value = node.getInstance();
            evidence.put(variable, value);
            if (value == null) {
                latent.add(variable);
            }
        }

        List<Bucket> buckets = new ArrayList<>();
        for (Variable variable : latent) {
            if (variable instanceof EnumVariable) {
                buckets.add(new Bucket(variable));
            }
        }
        final int numBuckets = buckets.size();
        if (numBuckets == 0) {
            // No enumerable variable to eliminate: a placeholder bucket collects every factor.
            buckets.add(new Bucket((EnumVariable) null));
        }

        // Place each node's factor in the highest-indexed bucket it matches, else bucket 0.
        for (BNode node : this.f4bn.getNodes()) {
            AbstractFactor factor = node.makeDenseFactor(evidence);
            Bucket home = buckets.get(0);
            if (factor.hasEnumVars()) {
                for (int i = numBuckets - 1; i >= 0; i--) {
                    if (buckets.get(i).match(factor)) {
                        home = buckets.get(i);
                        break;
                    }
                }
            }
            home.put(factor);
        }

        // A non-first bucket that received no factors hands each of its variables to every later
        // bucket holding a factor over it, and always to the last bucket. It is then dropped.
        // Removal is deferred until after the scan: removing inside the index loop shifts
        // positions and runs past the end of the list.
        List<Bucket> emptied = new ArrayList<>();
        for (int i = 1; i < numBuckets; i++) {
            Bucket bucket = buckets.get(i);
            if (!bucket.factors.isEmpty()) {
                continue;
            }
            for (Variable variable : bucket.vars) {
                for (int j = i + 1; j < numBuckets; j++) {
                    Bucket later = buckets.get(j);
                    if (later.hasFactorWith(variable) || j == numBuckets - 1) {
                        later.vars.add(variable);
                    }
                }
            }
            emptied.add(bucket);
        }
        buckets.removeAll(emptied);

        // Sum out from the highest bucket down; bucket 0 yields the result.
        for (int k = buckets.size() - 1; k >= 0; k--) {
            Bucket bucket = buckets.get(k);
            List<AbstractFactor> factors = bucket.factors;
            if (factors.isEmpty()) {
                continue;
            }
            AbstractFactor product;
            if (factors.size() == 1) {
                product = factors.get(0);
            } else if (factors.size() == 2) {
                product = Factorize.getProduct(factors.get(0), factors.get(1));
            } else {
                product = Factorize.getProduct(factors.toArray(new AbstractFactor[factors.size()]));
            }
            if (k == 0) {
                return product.getLogSum();
            }
            try {
                Variable[] eliminated = bucket.vars.toArray(new Variable[bucket.vars.size()]);
                AbstractFactor margin = Factorize.getMargin(product, eliminated);
                // Default to bucket 0 so a margin with no (matching) enumerable variables is not lost.
                Bucket target = buckets.get(0);
                if (margin.hasEnumVars()) {
                    for (int j = k - 1; j >= 0; j--) {
                        if (buckets.get(j).match(margin)) {
                            target = buckets.get(j);
                            break;
                        }
                    }
                }
                target.put(margin);
            } catch (ClassCastException e) {
                VarElimRuntimeException failure =
                        new VarElimRuntimeException("Cannot marginalize continuous variables");
                failure.initCause(e);
                throw failure;
            }
        }
        throw new VarElimRuntimeException("Exited variable elimination prematurely");
    }
}
