package reconstruction;

import bn.BNet;
import bn.BNode;
import bn.alg.CGTable;
import bn.alg.VarElim;
import bn.ctmc.PhyloBNet;
import bn.ctmc.SubstNode;
import bn.ctmc.matrix.Gap;
import bn.ctmc.matrix.JTT;
import bn.prob.EnumDistrib;
import bn.prob.GammaDistrib;
import dat.EnumSeq;
import dat.EnumVariable;
import dat.Enumerable;
import dat.PhyloTree;
import dat.Variable;
import dat.file.FastaWriter;
import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
import java.io.Writer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import json.JSONObject;

/* loaded from: input_file:reconstruction/ASR.class */
public class ASR {
    /** Reference alignment compared against by the accuracy-check constructor (hard-coded test data). */
    private static final String REFERENCE_ALIGNMENT = "bnkit/data/simprot/cyp2f_r3.aln";

    // Network built over the whole tree in loadData; only exposed through getPbn().
    private PhyloBNet pbn;
    // Sequences read from the alignment file.
    private List<EnumSeq.Gappy<Enumerable>> seqs;
    // Input alignment; replaced by the reconstructed alignment in getSequences().
    private EnumSeq.Alignment<Enumerable> aln;
    // Per-column rate taken from each column's network after joint inference.
    private double[] R;
    // Per-column marginal distribution of the queried node (marginal inference only).
    private EnumDistrib[] margin_distribs;
    // Node labels in breadth-first order; position i corresponds to row i of asr_matrix.
    private List<String> indexForNodes;
    // Joint (MPE) reconstruction: asr_matrix[node index][alignment column].
    private Object[][] asr_matrix;
    private PhyloTree tree = new PhyloTree();
    // When true, networks are built with sampled_rate; nothing in this class sets it to true.
    private boolean use_sampled_rate = false;
    private double sampled_rate = 0.15599004226404184d;

    /**
     * Loads a tree and an alignment, builds one network per alignment column and runs inference.
     * Marginal inference targets the root node.
     *
     * @param str  path to the Newick tree file
     * @param str2 path to the alignment (FASTA or Clustal)
     * @param str3 inference type, either "Joint" or "Marginal"
     * @throws RuntimeException if {@code str3} is neither "Joint" nor "Marginal"
     */
    public ASR(String str, String str2, String str3) {
        loadData(str, str2);
        runInference(createNetworks(), str3);
    }

    /**
     * Same as {@link #ASR(String, String, String)}, then prints the fraction of positions at which
     * the sequences attached to the tree agree with the hard-coded reference alignment
     * {@value #REFERENCE_ALIGNMENT}.
     *
     * @param str   path to the Newick tree file
     * @param str2  path to the alignment (FASTA or Clustal)
     * @param str3  inference type, either "Joint" or "Marginal"
     * @param dArr  unused
     * @param dArr2 unused
     */
    public ASR(String str, String str2, String str3, double[][] dArr, double[] dArr2) {
        loadData(str, str2);
        runInference(createNetworks(), str3);
        List<EnumSeq.Gappy> reference;
        try {
            reference = EnumSeq.Gappy.loadFasta(REFERENCE_ALIGNMENT, Enumerable.aacid, Character.valueOf('-'));
        } catch (IOException e) {
            // Without the reference there is nothing to compare against.
            e.printStackTrace();
            return;
        }
        double mismatches = 0.0d;
        double compared = 0.0d;
        for (EnumSeq.Gappy gappy : reference) {
            String name = gappy.getName();
            PhyloTree.Node treeNode = this.tree.find(name);
            if (treeNode == null) {
                continue; // reference sequence not present in this tree
            }
            EnumSeq sequence = treeNode.getSequence();
            // Sequences are only attached to the tree by joint inference (getSequences()).
            if (sequence == null || !name.equals(sequence.getName())) {
                continue;
            }
            int length = Math.min(sequence.length(), gappy.length());
            for (int i = 0; i < length; i++) {
                compared += 1.0d;
                // Symbols are boxed objects: compare by value, not by reference.
                if (!Objects.equals(gappy.get()[i], sequence.get()[i])) {
                    mismatches += 1.0d;
                }
            }
        }
        System.out.println(1.0d - (mismatches / compared));
    }

    /**
     * Loads a tree and an alignment, builds one network per column and runs inference. Marginal
     * inference targets the node labelled {@code str4}; joint inference ignores it.
     *
     * @param str  path to the Newick tree file
     * @param str2 path to the alignment (FASTA or Clustal)
     * @param str3 inference type, either "Joint" or "Marginal"
     * @param str4 label of the node to query in marginal mode
     * @throws RuntimeException if {@code str3} is invalid, or in marginal mode {@code str4} is not
     *     a node of the tree
     */
    public ASR(String str, String str2, String str3, String str4) {
        loadData(str, str2);
        PhyloBNet[] networks = createNetworks();
        if (str3.equals("Joint")) {
            System.out.println("*Information*\nUsing joint probability so node specification will be ignored");
            queryNetsJoint(networks);
            getSequences();
        } else if (str3.equals("Marginal")) {
            PhyloTree.Node target = this.tree.find(str4);
            if (target == null) {
                throw new RuntimeException("Error: Invalid node label " + str4 + " - exiting");
            }
            queryNetsMarg(networks, target);
        } else {
            throw new RuntimeException("Error: Inference must be either 'Joint' or 'Marginal' - exiting");
        }
    }

    /** Runs joint or root-marginal inference on {@code networks}, as selected by {@code inference}. */
    private void runInference(PhyloBNet[] networks, String inference) {
        if (inference.equals("Joint")) {
            queryNetsJoint(networks);
            getSequences();
        } else if (inference.equals("Marginal")) {
            System.out.println("*Information*\nNo node specification: returning marginal distribution of root node");
            queryNetsMarg(networks);
        } else {
            throw new RuntimeException("Error: Inference must be either 'Joint' or 'Marginal'");
        }
    }

    /**
     * Reads the tree and the alignment and checks that they describe the same set of extant names.
     *
     * @param str  path to the Newick tree file
     * @param str2 path to the alignment; its first line decides the format (CLUSTAL header or '>')
     * @throws RuntimeException on an unreadable file, an unknown format, duplicated names, or a
     *     mismatch between alignment names and tree leaf names
     */
    public void loadData(String str, String str2) {
        try {
            this.tree = this.tree.loadNewick(str);
            PhyloTree.Node[] breadthFirst = this.tree.toNodesBreadthFirst();
            this.indexForNodes = new ArrayList<>(breadthFirst.length);
            for (PhyloTree.Node node : breadthFirst) {
                this.indexForNodes.add(node.getLabel().toString());
            }
            // Only the first line is needed to sniff the format; close the file immediately.
            String firstLine;
            try (BufferedReader reader = new BufferedReader(new FileReader(str2))) {
                firstLine = reader.readLine();
            }
            if (firstLine != null && firstLine.startsWith("CLUSTAL")) {
                this.seqs = EnumSeq.Gappy.loadClustal(str2, Enumerable.aacid);
            } else if (firstLine != null && firstLine.startsWith(">")) {
                this.seqs = EnumSeq.Gappy.loadFasta(str2, Enumerable.aacid, Character.valueOf('-'));
            } else {
                throw new RuntimeException("Incorrect sequence or alignment format (requires FASTA or Clustal format .aln, .fa or .fasta)");
            }
            HashSet<String> extantNames = new HashSet<>();
            for (PhyloTree.Node leaf : getExtantNodes()) {
                String label = leaf.getLabel().toString();
                if (!extantNames.add(label)) {
                    throw new RuntimeException("Extant node names must be unique - " + label + " is duplicated");
                }
            }
            HashSet<String> sequenceNames = new HashSet<>();
            for (EnumSeq.Gappy<Enumerable> seq : this.seqs) {
                if (!sequenceNames.add(seq.getName())) {
                    throw new RuntimeException("Sequence names must be unique - " + seq.getName() + " is duplicated");
                }
            }
            if (!extantNames.equals(sequenceNames)) {
                throw new RuntimeException("The sequence names in the provided alignment must all have a match in the provided tree");
            }
            this.aln = new EnumSeq.Alignment<>(this.seqs);
            this.pbn = PhyloBNet.create(this.tree, new JTT());
        } catch (IOException e) {
            // Continuing would only fail later with a NullPointerException on the missing alignment.
            throw new RuntimeException("Could not read tree " + str + " or alignment " + str2, e);
        }
    }

    /**
     * Builds one substitution network per alignment column. For each column the tree is annotated
     * via {@code setContentByParsimony} from the column's gap pattern, a substitution (JTT) network
     * and a gap network are built, leaves are instantiated from the column, and the joint
     * assignment of the gap network decides which substitution nodes are flagged as gaps.
     *
     * @return the substitution networks, indexed by alignment column
     * @throws IllegalStateException if the gap network has a variable with no counterpart in the
     *     substitution network
     */
    public PhyloBNet[] createNetworks() {
        int width = this.aln.getWidth();
        PhyloBNet[] networks = new PhyloBNet[width];
        String[] names = this.aln.getNames();
        List<String> labels = getLabels(names);
        for (int col = 0; col < width; col++) {
            Object[] gapColumn = this.aln.getGapColumn(col);
            Object[] column = this.aln.getColumn(col);
            this.tree.setContentByParsimony(names, gapColumn);
            PhyloBNet substNet;
            PhyloBNet gapNet;
            if (this.use_sampled_rate) {
                substNet = PhyloBNet.create(this.tree, new JTT(), this.sampled_rate);
                gapNet = PhyloBNet.createGap(this.tree, new Gap(), this.sampled_rate);
            } else {
                substNet = PhyloBNet.create(this.tree, new JTT());
                gapNet = PhyloBNet.createGap(this.tree, new Gap());
            }
            networks[col] = substNet;
            for (int row = 0; row < labels.size(); row++) {
                BNode substNode = substNet.getBN().getNode(labels.get(row));
                BNode gapNode = gapNet.getBN().getNode(labels.get(row));
                substNode.setInstance(column[row]);
                if (column[row] == null) {
                    // A null entry is a gap in this column: 'G' = gap, 'C' = character.
                    ((SubstNode) substNode).setGap(true);
                    gapNode.setInstance('G');
                } else {
                    gapNode.setInstance('C');
                }
            }
            for (Variable.Assignment assignment : queryGapNetJoint(gapNet)) {
                String name = ((EnumVariable) assignment.var).getName();
                SubstNode target = (SubstNode) substNet.getBN().getNode(name);
                if (target == null) {
                    throw new IllegalStateException("Gap network variable " + name
                            + " has no matching node in the substitution network for column " + col);
                }
                target.setGap(Character.valueOf('G').equals(assignment.val));
            }
        }
        return networks;
    }

    /**
     * Computes, for every column, the marginal distribution of the root node and stores it in
     * {@code margin_distribs}.
     */
    public void queryNetsMarg(PhyloBNet[] phyloBNetArr) {
        this.margin_distribs = new EnumDistrib[this.aln.getWidth()];
        for (int col = 0; col < this.margin_distribs.length; col++) {
            this.margin_distribs[col] = queryNetMarg(phyloBNetArr, col);
        }
    }

    /**
     * Computes, for every column, the marginal distribution of the network node whose name is the
     * label of {@code node}, and stores it in {@code margin_distribs}.
     */
    public void queryNetsMarg(PhyloBNet[] phyloBNetArr, PhyloTree.Node node) {
        int width = this.aln.getWidth();
        this.margin_distribs = new EnumDistrib[width];
        String label = node.getLabel().toString();
        for (int col = 0; col < width; col++) {
            PhyloBNet net = phyloBNetArr[col];
            // Look the node up before inference prunes the network, as the original code did.
            BNode target = net.getBN().getNode(label);
            this.margin_distribs[col] = inferMarginal(net, target);
        }
    }

    /**
     * Runs joint (MPE) inference on every column network. Inferred values are written to
     * {@code asr_matrix} for every node label in {@code indexForNodes} and instantiated on the
     * network. Each network's rate is stored in {@code R}.
     */
    public void queryNetsJoint(PhyloBNet[] phyloBNetArr) {
        int width = this.aln.getWidth();
        this.asr_matrix = new Object[this.indexForNodes.size()][width];
        this.R = new double[width];
        for (int col = 0; col < width; col++) {
            PhyloBNet net = phyloBNetArr[col];
            BNet bn = net.getBN();
            VarElim varElim = new VarElim();
            varElim.instantiate(bn);
            net.purgeGaps();
            net.collapseSingles();
            for (Variable.Assignment assignment : getJointAssignment(varElim)) {
                EnumVariable var = (EnumVariable) assignment.var;
                Object value = assignment.val;
                int row = this.indexForNodes.indexOf(var.getName());
                if (row >= 0) {
                    this.asr_matrix[row][col] = value;
                }
                bn.getNode(var).setInstance(value);
            }
            this.R[col] = net.getRate();
        }
    }

    /** Returns the marginal distribution of the root of the network for column {@code i}. */
    public EnumDistrib queryNetMarg(PhyloBNet[] phyloBNetArr, int i) {
        PhyloBNet net = phyloBNetArr[i];
        return inferMarginal(net, net.getRoot());
    }

    /** Returns the joint (MPE) assignment of all variables of a gap network. */
    public Variable.Assignment[] queryGapNetJoint(PhyloBNet phyloBNet) {
        VarElim varElim = new VarElim();
        varElim.instantiate(phyloBNet.getBN());
        return getJointAssignment(varElim);
    }

    /**
     * Instantiates variable elimination on {@code net}, then purges gaps and collapses single
     * nodes (in that order, matching the original code), and queries {@code target}.
     */
    private EnumDistrib inferMarginal(PhyloBNet net, BNode target) {
        VarElim varElim = new VarElim();
        varElim.instantiate(net.getBN());
        net.purgeGaps();
        net.collapseSingles();
        return getMarginalDistrib(varElim, target.getVariable());
    }

    /**
     * Queries the marginal distribution of {@code variable}.
     *
     * @return the distribution; an all-zero amino-acid distribution when the query fails with an
     *     "Invalid query" NullPointerException; {@code null} (after printing the trace) on any
     *     other NullPointerException
     */
    private EnumDistrib getMarginalDistrib(VarElim varElim, Variable variable) {
        try {
            return (EnumDistrib) ((CGTable) varElim.infer(varElim.makeQuery(variable))).query(variable);
        } catch (NullPointerException e) {
            // NOTE(review): relies on the inference library reporting invalid queries as an NPE
            // whose text contains "Invalid query" — confirm against VarElim.
            if (e.toString().contains("Invalid query")) {
                // A fresh double[] is already zero-filled.
                return new EnumDistrib(Enumerable.aacid, new double[Enumerable.aacid.size()]);
            }
            e.printStackTrace();
            return null;
        }
    }

    /** Returns the MPE assignment over all variables. */
    private Variable.Assignment[] getJointAssignment(VarElim varElim) {
        return getJointAssignment(varElim, new Variable[0]);
    }

    /** Returns the MPE assignment for a query built from {@code variableArr}. */
    private Variable.Assignment[] getJointAssignment(VarElim varElim, Variable[] variableArr) {
        return ((CGTable) varElim.infer(varElim.makeMPE(variableArr))).getMPE();
    }

    /**
     * Attaches the input sequences to their leaves, builds ancestral sequences from
     * {@code asr_matrix}, and replaces {@code aln} with the combined alignment. Rows follow
     * {@code indexForNodes} (breadth-first order), and each row is attached to the breadth-first
     * node at the same index when that node has children or its Newick string equals the name.
     */
    @SuppressWarnings({"rawtypes", "unchecked"}) // rows mix extant and ancestral sequences
    public void getSequences() {
        for (EnumSeq.Gappy<Enumerable> seq : this.seqs) {
            this.tree.find(seq.getName()).setSequence(seq);
        }
        List<PhyloTree.Node> extant = Arrays.asList(getExtantNodes());
        ArrayList rows = new ArrayList();
        for (int row = 0; row < this.asr_matrix.length; row++) {
            String label = this.indexForNodes.get(row);
            PhyloTree.Node treeNode = this.tree.find(label);
            if (extant.contains(treeNode)) {
                rows.add((EnumSeq.Gappy) treeNode.getSequence());
            } else {
                EnumSeq.Gappy ancestor = new EnumSeq.Gappy(Enumerable.aacid_alt);
                ancestor.set(this.asr_matrix[row]);
                ancestor.setName(label);
                rows.add(ancestor);
            }
        }
        String rootNewick = this.tree.getRoot().toString();
        PhyloTree.Node[] breadthFirst = this.tree.toNodesBreadthFirst();
        EnumSeq.Alignment<Enumerable> reconstructed = new EnumSeq.Alignment<>(rows);
        for (int row = 0; row < reconstructed.getHeight(); row++) {
            EnumSeq.Gappy<Enumerable> seq = reconstructed.getEnumSeq(row);
            boolean hasChildren = breadthFirst[row].getChildren().toArray().length > 0;
            if (rootNewick.equals(seq.getName()) || hasChildren) {
                breadthFirst[row].setSequence(seq);
            }
        }
        this.aln = reconstructed;
    }

    /**
     * Fits a gamma distribution to the per-column rates {@code R}.
     * Both constructor arguments are alpha; the original expression 1/(1/alpha) is alpha itself.
     * NOTE(review): whether the second argument is a rate or a scale depends on GammaDistrib.
     */
    public GammaDistrib calcGammaDistrib() {
        double alpha = GammaDistrib.getAlpha(this.R);
        return new GammaDistrib(alpha, alpha);
    }

    /**
     * Writes results using {@code str} as a file-name prefix.
     *
     * @param z true for joint output (alignment, tree, rates); false for marginal output (tree,
     *     distributions)
     * @return always {@code true}
     */
    public boolean save(String str, boolean z) {
        if (z) {
            saveALN(str + "_aln_full.fa");
            saveTree(str + "_new_tree.nwk");
            saveRate(str + "_rates.txt");
        } else {
            saveTree(str + "_new_tree.nwk");
            saveDistrib(str + "_distribution.txt");
        }
        return true;
    }

    /**
     * Writes the reconstruction (root sequence, rates, internal and extant sequences) as JSON.
     *
     * @return {@code true} on success, {@code false} if the file could not be opened or written
     */
    public boolean saveJSON(String str) {
        try (PrintWriter writer = new PrintWriter(str, "UTF-8")) {
            JSONObject rootJson = new JSONObject();
            rootJson.put("Sequence", this.tree.getRoot().getSequence().toString());
            rootJson.put("SeqName", this.tree.getRoot().getLabel());
            rootJson.put("NewickRep", this.tree.getRoot().toString());
            JSONObject rates = new JSONObject();
            for (int i = 0; i < this.R.length; i++) {
                rates.put(Integer.toString(i), this.R[i]);
            }
            rootJson.put("Rates", rates);
            JSONObject reconstruction = new JSONObject();
            reconstruction.put("Root", rootJson);
            for (PhyloTree.Node node : getInternalNodes()) {
                JSONObject entry = new JSONObject();
                entry.put("SeqName", node.getLabel());
                entry.put("Sequence", node.getSequence().toString());
                reconstruction.append("ReconstructedNodes", entry);
            }
            for (PhyloTree.Node node : getExtantNodes()) {
                JSONObject entry = new JSONObject();
                entry.put("SeqName", node.getLabel());
                entry.put("Sequence", node.getSequence().toString());
                reconstruction.append("ExtantNodes", entry);
            }
            JSONObject document = new JSONObject();
            document.put("Reconstruction", reconstruction);
            document.write(writer);
            // PrintWriter never throws on write failure; it only records the error.
            return !writer.checkError();
        } catch (IOException e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * Writes every node's sequence as FASTA, in breadth-first order. Each header is the sequence
     * name followed by the node's Newick string and ';'. Names are restored afterwards so the
     * in-memory sequences are left unchanged.
     */
    public void saveALN(String str) {
        PhyloTree.Node[] breadthFirst = this.tree.toNodesBreadthFirst();
        EnumSeq.Gappy[] rows = new EnumSeq.Gappy[breadthFirst.length];
        String[] originalNames = new String[breadthFirst.length];
        for (int i = 0; i < breadthFirst.length; i++) {
            EnumSeq.Gappy seq = (EnumSeq.Gappy) breadthFirst[i].getSequence();
            String name = seq.getName();
            originalNames[i] = name;
            if (name != null) {
                seq.setName(name + " " + breadthFirst[i] + ";");
            }
            rows[i] = seq;
        }
        try {
            FastaWriter writer = new FastaWriter(str);
            try {
                writer.save(rows);
            } finally {
                writer.close();
            }
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            // Restore in reverse so a sequence shared by several nodes ends with its first name.
            for (int i = rows.length - 1; i >= 0; i--) {
                if (originalNames[i] != null) {
                    rows[i].setName(originalNames[i]);
                }
            }
        }
    }

    /** Writes the tree in Newick format, terminated by ";" and a newline. */
    public void saveTree(String str) {
        try (PrintWriter writer = new PrintWriter(str, "UTF-8")) {
            writer.write(this.tree.getRoot().toString());
            writer.write(";\n");
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Writes the marginal distributions as a tab-separated table. The header is
     * "columns\t1\t...\tn". Each following row is one symbol of {@code Enumerable.aacid_ext} and
     * its probability in each column. The gap symbol '-', and columns whose distribution is
     * missing, are written as "NA".
     */
    public void saveDistrib(String str) {
        try (PrintWriter writer = new PrintWriter(str, "UTF-8")) {
            Object[] values = Enumerable.aacid_ext.getValues();
            int columns = this.margin_distribs.length;
            StringBuilder header = new StringBuilder("columns");
            for (int col = 1; col <= columns; col++) {
                header.append('\t').append(col);
            }
            writer.write(header.append('\n').toString());
            for (Object value : values) {
                StringBuilder line = new StringBuilder(String.valueOf(value));
                for (int col = 0; col < columns; col++) {
                    EnumDistrib distrib = this.margin_distribs[col];
                    line.append('\t');
                    if (value.equals('-') || distrib == null) {
                        line.append("NA");
                    } else {
                        line.append(distrib.get(value));
                    }
                }
                writer.write(line.append('\n').toString());
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /** Writes one "column&lt;TAB&gt;rate" line per alignment column (0-based, as in saveJSON). */
    public void saveRate(String str) {
        try (PrintWriter writer = new PrintWriter(str, "UTF-8")) {
            for (int col = 0; col < this.R.length; col++) {
                writer.write(col + "\t" + this.R[col] + "\n");
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /** Returns the non-root nodes that have children, in breadth-first order. */
    private PhyloTree.Node[] getInternalNodes() {
        String rootNewick = this.tree.getRoot().toString();
        List<PhyloTree.Node> internal = new ArrayList<>();
        for (PhyloTree.Node node : this.tree.toNodesBreadthFirst()) {
            if (!rootNewick.equals(node.toString()) && node.getChildren().toArray().length > 0) {
                internal.add(node);
            }
        }
        // A zero-length array lets toArray size the result exactly (no trailing nulls).
        return internal.toArray(new PhyloTree.Node[0]);
    }

    /** Returns the leaf nodes (nodes without children), in breadth-first order. */
    private PhyloTree.Node[] getExtantNodes() {
        List<PhyloTree.Node> leaves = new ArrayList<>();
        for (PhyloTree.Node node : this.tree.toNodesBreadthFirst()) {
            if (node.getChildren().toArray().length == 0) {
                leaves.add(node);
            }
        }
        // A zero-length array lets toArray size the result exactly (no trailing nulls).
        return leaves.toArray(new PhyloTree.Node[0]);
    }

    public PhyloTree getTree() {
        return this.tree;
    }

    public PhyloBNet getPbn() {
        return this.pbn;
    }

    public List<EnumSeq.Gappy<Enumerable>> getSeqs() {
        return this.seqs;
    }

    public EnumSeq.Alignment<Enumerable> getAln() {
        return this.aln;
    }

    public double[] getRates() {
        return this.R;
    }

    public EnumDistrib[] getMarginDistribs() {
        return this.margin_distribs;
    }

    /** Strips everything from the first '/' onwards (when it is not the first character) from each name. */
    private static List<String> getLabels(String[] strArr) {
        List<String> labels = new ArrayList<>(strArr.length);
        for (String name : strArr) {
            int slash = name.indexOf('/');
            labels.add(slash > 0 ? name.substring(0, slash) : name);
        }
        return labels;
    }
}
