package bn.alg;

import bn.BNet;
import bn.BNode;
import bn.SampleTrace;
import dat.Variable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;

/* loaded from: input_file:bn/alg/ApproxInference.class */
/**
 * Sampling-based implementation of {@link Inference} over a {@link BNet}.
 *
 * <p>Usage follows the {@code Inference} contract visible here: call
 * {@link #instantiate(BNet)} first, build a query with
 * {@link #makeQuery(Variable...)}, then evaluate it with {@link #infer(Query)}.
 *
 * <p>NOTE(review): the per-node resampling via {@code BNet.getMBProb} looks like
 * a Markov-blanket (Gibbs-style) sampler, but that helper is not visible in this
 * file - confirm against {@code BNet}.
 */
public class ApproxInference implements Inference {

    /* renamed from: bn, reason: collision with root package name */
    /** The network this engine operates on; assigned by {@link #instantiate(BNet)}. */
    public BNet f0bn;

    /** Value reported by {@link #getLogLikelihood()}; never updated by code in this class. */
    private double logLikelihood = 1.0d;

    /** Not referenced by any code in this class; retained as-is. */
    private Random randomGenerator = new Random();

    /**
     * Number of samples counted per {@link #infer(Query)} call. This is a static
     * field, so it is shared by every {@code ApproxInference} instance.
     */
    public static int iterations = 500;

    /* loaded from: input_file:bn/alg/ApproxInference$AQuery.class */
    /**
     * Query descriptor produced by {@link ApproxInference#makeQuery(Variable...)}
     * and consumed by {@link ApproxInference#infer(Query)}.
     */
    public class AQuery implements Query {
        /** Nodes of the variables being queried. */
        final List<BNode> X;
        /** Relevant nodes that had an instance set when the query was built. */
        final List<BNode> E;
        /** Relevant nodes that had no instance set when the query was built. */
        final List<BNode> Z;
        /** The list returned by {@code BNet.getDconnected} for the query variables. */
        final List<BNode> rnl;

        AQuery(List<BNode> list, List<BNode> list2, List<BNode> list3, List<BNode> list4) {
            this.X = list;
            this.E = list2;
            this.Z = list3;
            this.rnl = list4;
        }
    }

    /* loaded from: input_file:bn/alg/ApproxInference$ApproxInferRuntimeException.class */
    /** Unchecked exception type declared for this engine; not thrown by code in this class. */
    public class ApproxInferRuntimeException extends RuntimeException {
        private static final long serialVersionUID = 1;

        public ApproxInferRuntimeException(String str) {
            super(str);
        }
    }

    /**
     * Binds this engine to {@code bNet} and compiles the network.
     *
     * @param bNet the network to run inference on
     */
    @Override // bn.alg.Inference
    public void instantiate(BNet bNet) {
        this.f0bn = bNet;
        this.f0bn.compile();
    }

    /**
     * Builds a query for the given variables.
     *
     * <p>The nodes returned by {@code BNet.getDconnected} are partitioned into
     * those that currently have an instance (stored as {@code E}) and those that
     * do not (stored as {@code Z}).
     *
     * @param variableArr the variables to query
     * @return an {@link AQuery} describing the query
     * @throws RuntimeException if node lookup or partitioning fails; the
     *         original failure is attached as the cause
     */
    @Override // bn.alg.Inference
    public Query makeQuery(Variable... variableArr) {
        List<BNode> queryNodes = new ArrayList<BNode>();
        List<BNode> evidenceNodes = new ArrayList<BNode>();
        List<BNode> latentNodes = new ArrayList<BNode>();
        List<BNode> relevant = this.f0bn.getDconnected(variableArr);
        try {
            for (Variable variable : variableArr) {
                queryNodes.add(this.f0bn.getNode(variable));
            }
            for (BNode candidate : relevant) {
                BNode resolved = this.f0bn.getNode(candidate.getVariable());
                if (candidate.getInstance() != null) {
                    evidenceNodes.add(resolved);
                } else {
                    latentNodes.add(resolved);
                }
            }
            return new AQuery(queryNodes, evidenceNodes, latentNodes, relevant);
        } catch (RuntimeException e) {
            // Same type and message as before, but keep the cause so the real
            // failure is not lost to callers and logs.
            throw new RuntimeException("makeQuery, ApproxInfer didn't work", e);
        }
    }

    /**
     * Evaluates a query by repeated sampling.
     *
     * <p>The network is given an initial sample via {@code BNet.sampleInstance()}
     * and counted once. Then, {@code iterations - 1} times, a new value is
     * requested for every {@code Z} node via {@code BNet.getMBProb}, all non-null
     * values are applied, and the state is counted again. The {@code Z} nodes are
     * always reset before returning, including when sampling fails, so that
     * sampled values are not left behind on the network.
     *
     * @param query a query created by {@link #makeQuery(Variable...)}
     * @return a table built from the factor accumulated by the sample trace
     * @throws ClassCastException if {@code query} is not an {@link AQuery}
     */
    @Override // bn.alg.Inference
    public CGTable infer(Query query) {
        AQuery aQuery = (AQuery) query;
        // Read the shared static once so the trace size and loop bound agree.
        int total = iterations;
        try {
            this.f0bn.sampleInstance();
            SampleTrace sampleTrace = new SampleTrace(aQuery.X, total);
            sampleTrace.count();
            // Kept across rounds on purpose: a node whose getMBProb result is
            // null in a later round retains (and re-applies) its previous value,
            // exactly as before.
            Map<BNode, Object> proposed = new HashMap<BNode, Object>();
            for (int round = 0; round < total - 1; round++) {
                // Gather all new values first, then apply them together.
                for (BNode node : aQuery.Z) {
                    Object value = this.f0bn.getMBProb(node);
                    if (value != null) {
                        proposed.put(node, value);
                    }
                }
                for (Map.Entry<BNode, Object> entry : proposed.entrySet()) {
                    entry.getKey().setInstance(entry.getValue());
                }
                sampleTrace.count();
            }
            return new CGTable(sampleTrace.getFactor());
        } finally {
            // Restore the nodes that were uninstantiated at query time.
            for (BNode node : aQuery.Z) {
                node.resetInstance();
            }
        }
    }

    /**
     * @return the number of samples counted per {@link #infer(Query)} call
     */
    public static int getIterations() {
        return iterations;
    }

    /**
     * Sets the number of samples counted per {@link #infer(Query)} call. Writes
     * the static {@link #iterations} field, so the change applies to all
     * instances.
     *
     * @param i the new iteration count
     */
    public void setIterations(int i) {
        iterations = i;
    }

    /**
     * @return the stored log-likelihood value; this class initialises it to
     *         {@code 1.0} and never changes it
     */
    public double getLogLikelihood() {
        return this.logLikelihood;
    }
}
