package rbm;

import bn.prob.EnumDistrib;
import dat.EnumVariable;
import dat.Enumerable;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/* JADX WARN: Classes with same name are omitted:
  input_file:target/classes/rbm/BooleanRBM.class
 */
/* loaded from: input_file:rbm/BooleanRBM.class */
/**
 * Restricted Boltzmann machine in which every visible and every hidden unit is Boolean.
 *
 * <p>Parameters held by this class: {@code w[j][i]} couples hidden unit {@code j} with visible
 * unit {@code i}, {@code a[i]} is the bias of visible unit {@code i} and {@code b[j]} the bias of
 * hidden unit {@code j}. The inherited {@code lnk[j][i]} flag records whether that coupling
 * exists; couplings that are not linked are skipped by inference and learning.
 *
 * <p>A model can be written with {@link #save(String)} and read back with
 * {@link #BooleanRBM(File)}. The tab-separated file has three sections:
 * <pre>
 * *w     vis_0  vis_1  ...   header naming the visible units
 * hid_0  w00    w01    ...   one row per hidden unit; a non-numeric cell means "no link"
 * *a     a0     a1     ...   visible biases, on the section header line itself
 * *b                         header only
 * hid_0  b0                  one row per hidden unit, in the same order as the *w rows
 * </pre>
 *
 * <p>Each unit's {@link EnumDistrib} stores the "on" probability at index 0. NOTE(review): this
 * assumes index 0 of {@code Enumerable.bool} is {@code true} - confirm against dat.Enumerable.
 */
public class BooleanRBM extends AbstractRBM {
    /** Cell written by {@link #save(String)} for a coupling that is not linked. */
    private static final String UNLINKED_CELL = "-";

    /** Coupling weights, indexed {@code w[hidden][visible]}. */
    private double[][] w;
    /** Visible-unit biases. */
    private double[] a;
    /** Hidden-unit biases. */
    private double[] b;
    /** Standard deviation of the zero-mean Gaussian used to initialise weights and biases. */
    public static double WEIGHT_STDEV = 0.01d;

    /**
     * Creates a fully connected RBM with randomly initialised parameters.
     *
     * <p>Units are named {@code vis_<i>} and {@code hid_<j>}. Weights and biases are drawn from a
     * zero-mean Gaussian with standard deviation {@link #WEIGHT_STDEV}, using a {@link Random}
     * seeded with {@code RANDOM_SEED}; the same generator seeds every unit's distribution.
     *
     * @param nVisible number of visible units
     * @param nHidden number of hidden units
     */
    public BooleanRBM(int nVisible, int nHidden) {
        this.rand = new Random(RANDOM_SEED);
        this.v = new EnumVariable[nVisible];
        this.Pv = new EnumDistrib[nVisible];
        this.h = new EnumVariable[nHidden];
        this.Ph = new EnumDistrib[nHidden];
        this.a = new double[nVisible];
        this.b = new double[nHidden];
        this.w = new double[nHidden][nVisible];
        this.lnk = new boolean[nHidden][nVisible];
        // The order of draws from rand is kept fixed so that models built with the same
        // RANDOM_SEED always receive identical parameters.
        for (int j = 0; j < nHidden; j++) {
            h[j] = new EnumVariable(Enumerable.bool, "hid_" + j);
            h[j].setPredef("Boolean");
            Ph[j] = new EnumDistrib(Enumerable.bool);
            Ph[j].setSeed(rand.nextLong());
            for (int i = 0; i < nVisible; i++) {
                w[j][i] = rand.nextGaussian() * WEIGHT_STDEV;
                lnk[j][i] = true;
                if (j == 0) {
                    // Visible units are created while visiting the first hidden row.
                    initVisibleUnit(i);
                }
                if (i == 0) {
                    b[j] = rand.nextGaussian() * WEIGHT_STDEV;
                }
            }
        }
        if (nHidden == 0) {
            // With no hidden rows the loop above never runs, so create the visible units here.
            for (int i = 0; i < nVisible; i++) {
                initVisibleUnit(i);
            }
        }
    }

    /** Creates visible unit {@code i}: draws its bias first, then its distribution's seed. */
    private void initVisibleUnit(int i) {
        a[i] = rand.nextGaussian() * WEIGHT_STDEV;
        v[i] = new EnumVariable(Enumerable.bool, "vis_" + i);
        v[i].setPredef("Boolean");
        Pv[i] = new EnumDistrib(Enumerable.bool);
        Pv[i].setSeed(rand.nextLong());
    }

    /**
     * Allocates fresh parameter arrays sized by the current units and fills the weights.
     *
     * @param rows one array per hidden unit; a null cell leaves that coupling unlinked
     */
    private void setWeights(List<Double[]> rows) {
        a = new double[v.length];
        b = new double[h.length];
        w = new double[h.length][v.length];
        lnk = new boolean[h.length][v.length];
        for (int j = 0; j < rows.size(); j++) {
            Double[] row = rows.get(j);
            for (int i = 0; i < row.length; i++) {
                // lnk defaults to false, so only present cells need to be recorded.
                if (row[i] != null) {
                    w[j][i] = row[i].doubleValue();
                    lnk[j][i] = true;
                }
            }
        }
    }

    /**
     * Loads an RBM previously written by {@link #save(String)}.
     *
     * <p>The file is read as UTF-8, matching {@code save}. Blank lines are ignored. A *w cell that
     * is not a number, or is missing, leaves that coupling unlinked. *b rows are matched in
     * hidden-unit order; a row whose name is not the next expected hidden unit is skipped.
     * A file that ends inside the *w section yields zero biases.
     *
     * @param file the model file
     * @throws IOException if the file cannot be read
     * @throws NumberFormatException if a bias value is not a number
     */
    public BooleanRBM(File file) throws IOException {
        List<Double[]> weightRows = new ArrayList<>();
        List<String> hiddenNames = new ArrayList<>();
        char section = '?'; // '?' until the first section header has been seen
        int row = 0;        // 0 on a section's header line, then counts the section's rows
        int nVisible = 0;
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(new FileInputStream(file), StandardCharsets.UTF_8))) {
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                if (line.trim().isEmpty()) {
                    continue;
                }
                String[] fields = line.split("\t");
                String first = fields[0];
                if (first.startsWith("*w")) {
                    section = 'w';
                    row = 0;
                } else if (first.startsWith("*a") || first.startsWith("*b")) {
                    if (section == 'w') {
                        finishWeightSection(hiddenNames, weightRows);
                    }
                    section = first.startsWith("*a") ? 'a' : 'b';
                    row = 0;
                }
                // A section header line is also processed by its section (row == 0).
                if (section == 'w') {
                    if (row == 0) {
                        nVisible = fields.length - 1;
                        createVisibleUnits(fields, nVisible);
                    } else {
                        hiddenNames.add(first);
                        weightRows.add(parseWeightRow(fields, nVisible));
                    }
                    row++;
                } else if (section == 'a') {
                    // Visible biases are stored on the *a header line itself.
                    if (row == 0 && a != null) {
                        for (int i = 0; i < nVisible; i++) {
                            a[i] = Double.parseDouble(fields[1 + i]);
                        }
                        row = 1;
                    }
                } else if (section == 'b') {
                    if (row == 0) {
                        // Hidden biases can only be stored once the parameter arrays exist.
                        if (b != null) {
                            row = 1;
                        }
                    } else if (row <= h.length && fields.length > 1
                            && first.equals(h[row - 1].getName())) {
                        b[row - 1] = Double.parseDouble(fields[1]);
                        row++;
                    }
                }
            }
        }
        if (section == 'w') {
            // No *a / *b section followed: still build the hidden units and weights.
            finishWeightSection(hiddenNames, weightRows);
        }
        this.rand = new Random(RANDOM_SEED);
    }

    /** Creates {@code n} visible units named by {@code header[1..n]} (distributions unseeded). */
    private void createVisibleUnits(String[] header, int n) {
        v = new EnumVariable[n];
        Pv = new EnumDistrib[n];
        for (int i = 0; i < n; i++) {
            v[i] = new EnumVariable(Enumerable.bool, header[1 + i]);
            v[i].setPredef("Boolean");
            Pv[i] = new EnumDistrib(Enumerable.bool);
        }
    }

    /** Parses one *w row; non-numeric or missing cells become null (no link). */
    private static Double[] parseWeightRow(String[] fields, int nVisible) {
        Double[] row = new Double[nVisible];
        for (int i = 0; i < nVisible; i++) {
            int column = i + 1;
            // String.split drops trailing empty cells, so a short row means missing values.
            if (column < fields.length) {
                try {
                    row[i] = Double.valueOf(fields[column]);
                } catch (NumberFormatException notANumber) {
                    row[i] = null; // deliberate: a non-numeric cell marks an absent coupling
                }
            }
        }
        return row;
    }

    /** Creates the hidden units named in the *w rows, then allocates and fills the parameters. */
    private void finishWeightSection(List<String> hiddenNames, List<Double[]> weightRows) {
        int n = hiddenNames.size();
        h = new EnumVariable[n];
        Ph = new EnumDistrib[n];
        for (int j = 0; j < n; j++) {
            h[j] = new EnumVariable(Enumerable.bool, hiddenNames.get(j));
            h[j].setPredef("Boolean");
            Ph[j] = new EnumDistrib(Enumerable.bool);
        }
        setWeights(weightRows);
    }

    /**
     * Loads an RBM previously written by {@link #save(String)}.
     *
     * @param str path of the model file
     * @throws IOException if the file cannot be read
     */
    public BooleanRBM(String str) throws IOException {
        this(new File(str));
    }

    /**
     * Writes this model to {@code str} (UTF-8) in the format read by {@link #BooleanRBM(File)}.
     *
     * <p>Numbers are written with {@link Double#toString(double)}. That output is
     * locale-independent and round-trips exactly through {@link Double#parseDouble(String)}.
     * Couplings that are not linked are written as {@code "-"}, so they stay unlinked when the
     * file is loaded again.
     *
     * @param str path of the file to create or overwrite
     * @throws IOException if the file cannot be created or an error occurs while writing
     */
    public void save(String str) throws IOException {
        try (PrintWriter out = new PrintWriter(str, "UTF-8")) {
            out.print("*w\t");
            for (int i = 0; i < getNVisible(); i++) {
                out.print(v[i].getName());
                out.print('\t');
            }
            out.print('\n');
            for (int j = 0; j < getNHidden(); j++) {
                out.print(h[j].getName());
                out.print('\t');
                for (int i = 0; i < getNVisible(); i++) {
                    out.print(lnk[j][i] ? formatNumber(w[j][i]) : UNLINKED_CELL);
                    out.print('\t');
                }
                out.print('\n');
            }
            out.print("*a\t");
            for (int i = 0; i < getNVisible(); i++) {
                out.print(formatNumber(a[i]));
                out.print('\t');
            }
            out.print('\n');
            out.print("*b\n");
            for (int j = 0; j < getNHidden(); j++) {
                out.print(h[j].getName());
                out.print('\t');
                out.print(formatNumber(b[j]));
                out.print('\n');
            }
            // PrintWriter never throws on write failure; surface it instead of a silent short file.
            if (out.checkError()) {
                throw new IOException("Error while writing RBM to " + str);
            }
        }
    }

    /** Formats a parameter with full precision, independent of the default locale. */
    private static String formatNumber(double value) {
        return Double.toString(value);
    }

    /**
     * Assigns every non-null entry of {@code objArr} to the visible unit at the same index.
     *
     * @param objArr one value per visible unit; null entries are left unassigned
     */
    public void setVisible(Object[] objArr) {
        for (int i = 0; i < v.length; i++) {
            Object value = objArr[i];
            if (value != null) {
                assignVisible(i, value);
            }
        }
    }

    /**
     * Samples every hidden unit given a visible assignment.
     *
     * <p>The "on" probability of hidden unit {@code j} is {@code logistic(b[j] + sum of w[j][i])},
     * summed over linked visible units {@code i} whose value is {@code true}; null visible entries
     * contribute nothing. Each hidden distribution {@code Ph[j]} is updated as a side effect.
     *
     * @param objArr visible values (Booleans or null), indexed like the visible units
     * @param objArr2 output array receiving one sample per hidden unit
     * @return {@code objArr2}
     */
    @Override // rbm.AbstractRBM
    public Object[] encode(Object[] objArr, Object[] objArr2) {
        for (int j = 0; j < h.length; j++) {
            double activation = b[j];
            for (int i = 0; i < v.length; i++) {
                if (lnk[j][i] && isOn(objArr[i])) {
                    activation += w[j][i];
                }
            }
            objArr2[j] = sampleUnit(Ph[j], activation);
        }
        return objArr2;
    }

    /**
     * Samples every visible unit given a hidden assignment.
     *
     * <p>The "on" probability of visible unit {@code i} is {@code logistic(a[i] + sum of w[j][i])},
     * summed over linked hidden units {@code j} whose value is {@code true}; null hidden entries
     * contribute nothing. Each visible distribution {@code Pv[i]} is updated as a side effect.
     *
     * @param objArr hidden values (Booleans or null), indexed like the hidden units
     * @param objArr2 output array receiving one sample per visible unit
     * @return {@code objArr2}
     */
    @Override // rbm.AbstractRBM
    public Object[] decode(Object[] objArr, Object[] objArr2) {
        for (int i = 0; i < v.length; i++) {
            double activation = a[i];
            for (int j = 0; j < h.length; j++) {
                if (lnk[j][i] && isOn(objArr[j])) {
                    activation += w[j][i];
                }
            }
            objArr2[i] = sampleUnit(Pv[i], activation);
        }
        return objArr2;
    }

    /** Returns true when a unit value is present and equal to {@code Boolean.TRUE}. */
    private static boolean isOn(Object value) {
        return value != null && ((Boolean) value).booleanValue();
    }

    /** Stores P(on) = logistic(activation) at index 0 of {@code dist} and samples from it. */
    private Object sampleUnit(EnumDistrib dist, double activation) {
        double pOn = logistic(activation);
        dist.set(new double[] {pOn, 1.0d - pOn});
        return dist.sample();
    }

    /**
     * Estimates the contrastive-divergence (CD-k) gradient of all parameters from a batch.
     *
     * <p>The result has {@code getNHidden() + 1} rows and {@code getNVisible() + 1} columns.
     * Entry {@code [j][i]} is the gradient for weight {@code w[j][i]}, the last column holds the
     * hidden biases and the last row holds the visible biases. Each entry is the mean of
     * {@code <data> - <model>} over the samples that informed it. It is {@code null} when no
     * sample did: for an unlinked weight, for a visible unit missing from every sample, and
     * always for the bottom-right corner. Null sample entries mark unobserved visible values.
     *
     * <p>As a side effect, {@code err} is reset. It then accumulates the absolute difference
     * between each observed visible value and its reconstructed "on" probability, summed over all
     * samples and Gibbs steps. When {@code i} is 0 the negative phase reuses whatever the
     * distributions currently hold.
     *
     * @param objArr batch of visible assignments, one array of Booleans (or nulls) per sample
     * @param i number of Gibbs steps (k) in the negative phase
     * @return the averaged gradient described above
     */
    @Override // rbm.AbstractRBM
    public Double[][] getCDGradient(Object[][] objArr, int i) {
        final int nVis = getNVisible();
        final int nHid = getNHidden();
        // Running sums and contribution counts across the whole batch; averaged at the end.
        double[][] sum = new double[nHid + 1][nVis + 1];
        int[][] count = new int[nHid + 1][nVis + 1];
        this.err = 0.0d;
        for (Object[] sample : objArr) {
            // Positive phase: statistics under the data, captured right after encoding and
            // before Gibbs sampling overwrites the unit distributions.
            Object[] hidden = encode(sample);
            double[][] positive = new double[nHid + 1][nVis + 1];
            boolean anyObserved = false;
            for (int vi = 0; vi < nVis; vi++) {
                if (sample[vi] == null) {
                    continue;
                }
                anyObserved = true;
                boolean on = ((Boolean) sample[vi]).booleanValue();
                for (int hj = 0; hj < nHid; hj++) {
                    if (lnk[hj][vi]) {
                        positive[hj][vi] = on ? Ph[hj].get(0) : 0.0d;
                    }
                }
                positive[nHid][vi] = on ? 1.0d : 0.0d;
            }
            // Hidden biases get one contribution per sample that observes any visible unit.
            if (anyObserved) {
                for (int hj = 0; hj < nHid; hj++) {
                    positive[hj][nVis] = Ph[hj].get(0);
                }
            }

            // Negative phase: each new hidden sample feeds the next Gibbs step.
            for (int step = 0; step < i; step++) {
                Object[] reconstruction = decode_restricted(hidden, sample);
                for (int vi = 0; vi < sample.length; vi++) {
                    if (sample[vi] != null) {
                        double target = ((Boolean) sample[vi]).booleanValue() ? 1.0d : 0.0d;
                        this.err += Math.abs(target - Pv[vi].get(0));
                    }
                }
                hidden = encode(reconstruction);
            }

            // Accumulate <data> - <model> for every parameter this sample informs.
            for (int vi = 0; vi < nVis; vi++) {
                if (sample[vi] == null) {
                    continue;
                }
                double pv = Pv[vi].get(0);
                for (int hj = 0; hj < nHid; hj++) {
                    if (lnk[hj][vi]) {
                        sum[hj][vi] += positive[hj][vi] - pv * Ph[hj].get(0);
                        count[hj][vi]++;
                    }
                }
                sum[nHid][vi] += positive[nHid][vi] - pv;
                count[nHid][vi]++;
            }
            if (anyObserved) {
                for (int hj = 0; hj < nHid; hj++) {
                    sum[hj][nVis] += positive[hj][nVis] - Ph[hj].get(0);
                    count[hj][nVis]++;
                }
            }
        }

        Double[][] gradient = new Double[nHid + 1][nVis + 1];
        for (int row = 0; row <= nHid; row++) {
            for (int col = 0; col <= nVis; col++) {
                if (count[row][col] > 0) {
                    gradient[row][col] = Double.valueOf(sum[row][col] / count[row][col]);
                }
            }
        }
        return gradient;
    }

    /**
     * Applies a gradient laid out as by {@link #getCDGradient(Object[][], int)}.
     *
     * <p>Each non-null entry is added as-is to the matching parameter. The last row updates the
     * visible biases, the last column the hidden biases, and the rest the weights. No learning
     * rate is applied here.
     *
     * @param dArr gradient matrix of size {@code (getNHidden() + 1) x (getNVisible() + 1)}
     */
    @Override // rbm.AbstractRBM
    public void setCDGradient(Double[][] dArr) {
        final int nHid = getNHidden();
        final int nVis = getNVisible();
        for (int row = 0; row <= nHid; row++) {
            for (int col = 0; col <= nVis; col++) {
                Double delta = dArr[row][col];
                if (delta == null) {
                    continue;
                }
                if (row == nHid) {
                    a[col] += delta.doubleValue();
                } else if (col == nVis) {
                    b[row] += delta.doubleValue();
                } else {
                    w[row][col] += delta.doubleValue();
                }
            }
        }
    }
}
