package asr;

import bn.BNode;
import bn.Distrib;
import bn.alg.CGTable;
import bn.alg.VarElim;
import bn.ctmc.SubstModel;
import dat.phylo.IdxTree;
import dat.phylo.PhyloBN;
import dat.phylo.TreeDecor;
import dat.phylo.TreeInstance;

/* loaded from: input_file:asr/MaxLhoodMarginal.class */
/**
 * Marginal inference of the distribution at a single branch point of a phylogenetic tree.
 * <p>
 * The tree is held as a {@link PhyloBN}. Each call to {@link #decorate(TreeInstance)} does the
 * following:
 * <ul>
 *   <li>sets the states in the given {@link TreeInstance} as evidence on the network's nodes;</li>
 *   <li>runs variable elimination ({@link VarElim}) over the network;</li>
 *   <li>queries the variable of the branch point whose index was fixed at construction.</li>
 * </ul>
 * The resulting distribution is available through {@link #getDecoration(int)}, for that one
 * index only.
 *
 * @param <E> the distribution type produced by the inference query
 */
public class MaxLhoodMarginal<E extends Distrib> implements TreeDecor<E> {

    /** Result of the most recent {@link #decorate} call; {@code null} if none succeeded. */
    private E value = null;
    /** Tree this inference was set up for. NOTE(review): not read anywhere in this class. */
    private final IdxTree tree;
    /** Index of the branch point whose marginal distribution is inferred. */
    private final int bpidx;
    /** Network on which inference is performed. */
    private final PhyloBN pbn;

    /**
     * Creates a marginal inference for one branch point. The network is built with
     * {@code PhyloBN.create(idxTree, substModel, d)}.
     *
     * @param bpidx      index of the branch point to infer
     * @param idxTree    the tree
     * @param substModel substitution model used to build the network
     * @param d          additional parameter forwarded to {@code PhyloBN.create}
     *                   (presumably a rate — TODO confirm)
     */
    public MaxLhoodMarginal(int bpidx, IdxTree idxTree, SubstModel substModel, double d) {
        this.tree = idxTree;
        this.bpidx = bpidx;
        this.pbn = PhyloBN.create(idxTree, substModel, d);
    }

    /**
     * Creates a marginal inference for one branch point. The network is built with
     * {@code PhyloBN.create(idxTree, substModel)}.
     *
     * @param bpidx      index of the branch point to infer
     * @param idxTree    the tree
     * @param substModel substitution model used to build the network
     */
    public MaxLhoodMarginal(int bpidx, IdxTree idxTree, SubstModel substModel) {
        this.tree = idxTree;
        this.bpidx = bpidx;
        this.pbn = PhyloBN.create(idxTree, substModel);
    }

    /**
     * Creates a marginal inference for one branch point over an existing network.
     *
     * @param bpidx  index of the branch point to infer
     * @param phyloBN the network to run inference on; its tree is taken from {@code getTree()}
     */
    public MaxLhoodMarginal(int bpidx, PhyloBN phyloBN) {
        this.tree = phyloBN.getTree();
        this.bpidx = bpidx;
        this.pbn = phyloBN;
    }

    /**
     * Returns the inferred distribution for the branch point this instance was built for.
     *
     * @param i branch point index; must equal the index given at construction
     * @return the distribution from the most recent {@link #decorate} call. Returns {@code null}
     *         if {@code decorate} has not run, the network was not valid, or inference failed.
     * @throws ASRRuntimeException if {@code i} is not the inferred branch point
     */
    @Override // dat.phylo.TreeDecor
    public E getDecoration(int i) {
        if (i == this.bpidx) {
            return this.value;
        }
        throw new ASRRuntimeException("Invalid ancestor index, not defined for marginal inference: " + i);
    }

    /**
     * Sets the tree instance's states as evidence and infers the branch point's distribution.
     * <p>
     * If {@code pbn.isValid()} is false, no inference is done and the decoration becomes
     * {@code null}.
     *
     * @param treeInstance states to use as evidence, indexed like the tree's branch points
     * @throws ASRRuntimeException if no network node exists for the inferred branch point
     */
    @Override // dat.phylo.TreeDecor
    public void decorate(TreeInstance treeInstance) {
        // Clear any earlier result first. This way a skipped inference (invalid network) or a
        // failed one (exception below) never leaves a distribution from a previous tree
        // instance behind.
        this.value = null;
        for (int i = 0; i < treeInstance.getSize(); i++) {
            BNode node = getEvidenceNode(i);
            if (node != null) {
                node.setInstance(treeInstance.getInstance(i));
            }
        }
        if (!this.pbn.isValid()) {
            return;
        }
        VarElim varElim = new VarElim();
        varElim.instantiate(this.pbn.getBN());
        BNode target = getQueryNode();
        if (target == null) {
            throw new ASRRuntimeException("Marginal inference of invalid branchpoint: " + this.bpidx);
        }
        CGTable table = (CGTable) varElim.infer(varElim.makeQuery(target.getVariable()));
        // The query result's runtime type cannot be checked against the type parameter E;
        // the caller chooses E to match what the network produces.
        @SuppressWarnings("unchecked")
        E result = (E) table.query(target.getVariable());
        this.value = result;
    }

    /**
     * Returns the node that receives evidence for tree index {@code idx}, or {@code null}.
     * When {@code pbn.isExt()} is true, only the node from {@code getExtNode} is used, with no
     * fallback. Otherwise the node from {@code getBNode} is used.
     */
    private BNode getEvidenceNode(int idx) {
        return this.pbn.isExt() ? this.pbn.getExtNode(idx) : this.pbn.getBNode(idx);
    }

    /**
     * Returns the node whose variable is queried for the branch point, or {@code null}.
     * When {@code pbn.isExt()} is true, the node from {@code getExtNode} is preferred if it is
     * present. Otherwise, or if it is absent, the node from {@code getBNode} is used.
     */
    private BNode getQueryNode() {
        if (this.pbn.isExt()) {
            BNode extNode = this.pbn.getExtNode(this.bpidx);
            if (extNode != null) {
                return extNode;
            }
        }
        return this.pbn.getBNode(this.bpidx);
    }
}
