package bn.prior;

import bn.Distrib;
import bn.JPT;
import bn.node.CPT;
import bn.prob.EnumDistrib;
import dat.EnumTable;
import dat.EnumVariable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:bn/prior/CPTPrior.class */
/**
 * A {@link CPT} that attaches Dirichlet priors to its conditional distributions.
 *
 * <p>A node without parents holds a single {@code rootPrior}. A node with parents holds one
 * {@link DirichletDistribPrior} per parent configuration in {@code PriorTable}. Entries are
 * created lazily with a default concentration of {@code defaultValue} when counts are first seen
 * for that configuration. {@link #maximizeInstance()} combines the accumulated counts with these
 * priors through {@link DirichletDistribPrior#learn} to re-estimate the CPT entries.
 */
public class CPTPrior extends CPT {
    private static final long serialVersionUID = 1;

    // Field names are kept as-is to preserve the serialized form of existing instances.
    /** Default Dirichlet concentration, also used as the fill value for unobserved counts. */
    private final Double defaultValue = Double.valueOf(1.0d);
    /** Priors keyed by parent configuration; null for a root (parent-less) node. */
    private EnumTable<DirichletDistribPrior> PriorTable;
    /** Prior for a root node; null when the node has parents. */
    private DirichletDistribPrior rootPrior;

    /**
     * Creates a CPT with priors for {@code enumVariable} conditioned on {@code list}.
     *
     * @param enumVariable the variable this CPT describes
     * @param list parent variables; null or empty makes this a root node
     */
    public CPTPrior(EnumVariable enumVariable, List<EnumVariable> list) {
        super(enumVariable, list);
        if (list == null || list.isEmpty()) {
            this.PriorTable = null;
            this.rootPrior = new DirichletDistribPrior(enumVariable.getDomain(), this.defaultValue.doubleValue());
        } else {
            this.PriorTable = new EnumTable<>(list);
            this.rootPrior = null;
        }
    }

    /**
     * Creates a CPT with priors for {@code enumVariable} conditioned on the given parents.
     *
     * @param enumVariable the variable this CPT describes
     * @param enumVariableArr parent variables; null or empty makes this a root node
     */
    public CPTPrior(EnumVariable enumVariable, EnumVariable... enumVariableArr) {
        super(enumVariable, enumVariableArr);
        if (enumVariableArr == null || enumVariableArr.length == 0) {
            this.PriorTable = null;
            this.rootPrior = new DirichletDistribPrior(enumVariable.getDomain(), this.defaultValue.doubleValue());
        } else {
            this.PriorTable = new EnumTable<>(enumVariableArr);
            this.rootPrior = null;
        }
    }

    /**
     * Creates a root (parent-less) CPT with a default Dirichlet prior.
     *
     * @param enumVariable the variable this CPT describes
     */
    public CPTPrior(EnumVariable enumVariable) {
        super(enumVariable);
        this.PriorTable = null;
        this.rootPrior = new DirichletDistribPrior(enumVariable.getDomain(), this.defaultValue.doubleValue());
    }

    /**
     * Creates a CPT with priors from a joint probability table. Every JPT variable other than
     * {@code enumVariable} becomes a parent.
     *
     * @param jpt the joint table the CPT is derived from
     * @param enumVariable the variable this CPT describes (assumed to be one of the JPT's variables)
     */
    public CPTPrior(JPT jpt, EnumVariable enumVariable) {
        super(jpt, enumVariable);
        List<EnumVariable> parents = new ArrayList<>();
        for (EnumVariable candidate : jpt.getParents()) {
            if (candidate != enumVariable) {
                parents.add(candidate);
            }
        }
        // Root-ness depends on the remaining parents, not on the JPT size. The JPT always
        // includes the target variable itself, so its size alone can never signal a root node.
        if (parents.isEmpty()) {
            this.PriorTable = null;
            this.rootPrior = new DirichletDistribPrior(enumVariable.getDomain(), this.defaultValue.doubleValue());
        } else {
            this.PriorTable = new EnumTable<>(parents);
            this.rootPrior = null;
        }
    }

    /**
     * Replaces the prior of a root node. Best-effort: prints to stderr and does nothing if this
     * node has parents or {@code dirichletDistribPrior} is null.
     *
     * @param dirichletDistribPrior the new root prior
     */
    public void setPrior(DirichletDistribPrior dirichletDistribPrior) {
        if (this.rootPrior == null || dirichletDistribPrior == null) {
            System.err.println("This CPT is conditioned on other nodes, need parents specified or prior is null");
        } else {
            this.rootPrior = dirichletDistribPrior;
        }
    }

    /**
     * Sets the prior for one parent configuration, identified by its parent values. Best-effort:
     * prints to stderr and does nothing for a root node or a null prior.
     *
     * @param objArr parent values identifying the configuration
     * @param dirichletDistribPrior the prior to associate with that configuration
     */
    public void setPrior(Object[] objArr, DirichletDistribPrior dirichletDistribPrior) {
        if (this.PriorTable == null || dirichletDistribPrior == null) {
            System.err.println("This CPT is root node, no keys are needed or prior is null");
        } else {
            this.PriorTable.setValue(objArr, dirichletDistribPrior);
        }
    }

    /**
     * Sets the prior for one parent configuration, identified by its table index. Best-effort:
     * prints to stderr and does nothing for a root node.
     *
     * @param i index of the parent configuration in the prior table
     * @param dirichletDistribPrior the prior to associate with that configuration
     */
    public void setPrior(int i, DirichletDistribPrior dirichletDistribPrior) {
        if (this.PriorTable != null) {
            this.PriorTable.setValue(i, dirichletDistribPrior);
        } else {
            System.err.println("This CPT is root node, no keys are needed");
        }
    }

    /**
     * Re-estimates the CPT from the accumulated counts and the Dirichlet priors, then clears the
     * counts. Does nothing if no counts have been accumulated.
     */
    @Override // bn.node.CPT, bn.BNode
    public void maximizeInstance() {
        if (this.count.table.isEmpty()) {
            return;
        }
        EnumVariable variable = getVariable();
        if (this.table != null) {
            maximizeConditionalWithPrior(variable);
        } else {
            maximizeRootWithPrior(variable);
        }
        this.count.table.setEmpty();
    }

    /**
     * Conditional path. Every parent configuration with counts gets a new distribution learned
     * from its prior. Configurations without counts are removed from the table.
     */
    private void maximizeConditionalWithPrior(EnumVariable variable) {
        // Mark everything stale first; entries rewritten below replace the stale objects, and any
        // still-invalid entries are purged at the end.
        for (EnumDistrib distrib : this.table.getValues()) {
            distrib.setValid(false);
        }
        Map<Integer, double[]> countsByConfig = gatherParentConfigCounts(variable);
        Object[] values = variable.getDomain().getValues();
        for (Map.Entry<Integer, double[]> entry : countsByConfig.entrySet()) {
            int configIndex = entry.getKey().intValue();
            DirichletDistribPrior prior = this.PriorTable.getValue(configIndex);
            if (prior == null) {
                prior = new DirichletDistribPrior(variable.getDomain(), this.defaultValue.doubleValue());
                this.PriorTable.setValue(configIndex, prior);
            }
            prior.setEstimatedDistrib(new EnumDistrib(variable.getDomain()));
            prior.learn(values, entry.getValue());
            Distrib estimated = prior.getEstimatedDistrib();
            prior.resetParameters();
            this.table.setValue(configIndex, (EnumDistrib) estimated);
        }
        Iterator<Map.Entry<Integer, EnumDistrib>> entries = this.table.getMapEntries().iterator();
        while (entries.hasNext()) {
            if (!entries.next().getValue().isValid()) {
                entries.remove();
            }
        }
    }

    /**
     * Groups the accumulated counts by parent configuration (CPT table index). Each array is
     * indexed by the value index of {@code variable}.
     */
    private Map<Integer, double[]> gatherParentConfigCounts(EnumVariable variable) {
        Map<Integer, double[]> countsByConfig = new HashMap<>();
        for (Map.Entry<Integer, Double> entry : this.count.table.getMapEntries()) {
            Object[] countKey = this.count.table.getKey(entry.getKey().intValue());
            // Count keys are laid out as [variable value, parent values...]. The CPT table is
            // keyed by the parent values alone.
            Object[] parentKey = Arrays.copyOfRange(countKey, 1, countKey.length);
            Integer configIndex = Integer.valueOf(this.table.getIndex(parentKey));
            double[] counts = countsByConfig.get(configIndex);
            if (counts == null) {
                counts = new double[variable.size()];
                // NOTE(review): unobserved values get defaultValue (1.0) as a count here, while
                // the root path uses raw counts (likely 0). Confirm this extra pseudo-count is intended.
                Arrays.fill(counts, this.defaultValue.doubleValue());
                countsByConfig.put(configIndex, counts);
            }
            counts[variable.getIndex(countKey[0])] = entry.getValue().doubleValue();
        }
        return countsByConfig;
    }

    /** Root path: learns the single distribution from the counts and the root prior. */
    private void maximizeRootWithPrior(EnumVariable variable) {
        Object[] key = new Object[1];
        double[] counts = new double[variable.size()];
        for (int i = 0; i < counts.length; i++) {
            key[0] = variable.getDomain().get(i);
            counts[i] = this.count.get(key);
        }
        this.rootPrior.setEstimatedDistrib(new EnumDistrib(variable.getDomain()));
        this.rootPrior.learn(variable.getDomain().getValues(), counts);
        put((EnumDistrib) this.rootPrior.getEstimatedDistrib());
        this.rootPrior.resetParameters();
    }
}
