package rbm;

import bn.prob.EnumDistrib;
import dat.EnumVariable;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Random;

/* JADX WARN: Classes with same name are omitted:
  input_file:target/classes/rbm/AbstractRBM.class
 */
/* loaded from: input_file:rbm/AbstractRBM.class */
public abstract class AbstractRBM {
    /** Seed value exposed for random-number generation; public and mutable. Not read anywhere in this class. */
    public static int RANDOM_SEED = 1;
    /** Random number generator for use by subclasses; not initialised in this class. */
    protected Random rand;
    /** Hidden-layer variables; {@code h.length} is the number of hidden nodes. */
    protected EnumVariable[] h;
    /** Visible-layer variables; {@code v.length} is the number of visible nodes. */
    protected EnumVariable[] v;
    /** Distributions associated with the hidden nodes. */
    protected EnumDistrib[] Ph;
    /** Distributions associated with the visible nodes (see {@link #assignVisible}). */
    protected EnumDistrib[] Pv;
    /** Connectivity mask, indexed as {@code lnk[hiddenIndex][visibleIndex]}. */
    protected boolean[][] lnk;
    /** Lazily built lookup from visible variable to its index in {@link #v}. */
    private HashMap<EnumVariable, Integer> visvarmap = null;
    /** Lazily built lookup from hidden variable to its index in {@link #h}. */
    private HashMap<EnumVariable, Integer> hidvarmap = null;
    /**
     * Error value, initialised to the most negative finite double. Not updated in this class —
     * presumably written by subclasses during training (TODO confirm).
     */
    public double err = -Double.MAX_VALUE;

    /**
     * Tells whether a visible node is connected to a hidden node.
     *
     * @param visibleIdx index of the visible node
     * @param hiddenIdx  index of the hidden node
     * @return the value of {@code lnk[hiddenIdx][visibleIdx]}
     */
    public boolean isLinked(int visibleIdx, int hiddenIdx) {
        return this.lnk[hiddenIdx][visibleIdx];
    }

    /** @return the number of visible nodes */
    public int getNVisible() {
        return this.v.length;
    }

    /** @return the visible variables (the internal array, not a copy) */
    public EnumVariable[] getVisibleVars() {
        return this.v;
    }

    /** @return the visible-node distributions (the internal array, not a copy) */
    public EnumDistrib[] getVisibleDistribs() {
        return this.Pv;
    }

    /** @return the number of hidden nodes */
    public int getNHidden() {
        return this.h.length;
    }

    /** @return the hidden variables (the internal array, not a copy) */
    public EnumVariable[] getHiddenVars() {
        return this.h;
    }

    /** @return the hidden-node distributions (the internal array, not a copy) */
    public EnumDistrib[] getHiddenDistribs() {
        return this.Ph;
    }

    /**
     * Builds a map from each variable to its position in {@code vars}.
     * If a variable occurs more than once, its last position wins.
     */
    private static HashMap<EnumVariable, Integer> mapVarsToIndices(EnumVariable[] vars) {
        HashMap<EnumVariable, Integer> indexOf = new HashMap<>();
        for (int k = 0; k < vars.length; k++) {
            indexOf.put(vars[k], k);
        }
        return indexOf;
    }

    /**
     * Looks up the index of a visible variable.
     *
     * @param var the variable to find
     * @return its index in the visible layer, or -1 if it is not visible or no visible layer is set
     */
    public int getVisibleIndex(EnumVariable var) {
        if (this.v == null) {
            return -1;
        }
        if (this.visvarmap == null) {
            this.visvarmap = mapVarsToIndices(this.v);
        }
        Integer idx = this.visvarmap.get(var);
        return idx != null ? idx : -1;
    }

    /**
     * Looks up the visible index of each variable.
     *
     * @param vars the variables to find
     * @return an array parallel to {@code vars}, holding each index or -1 when not found
     *         (including when no visible layer is set)
     */
    public int[] getVisibleIndex(EnumVariable[] vars) {
        int[] result = new int[vars.length];
        for (int k = 0; k < result.length; k++) {
            // the single-variable overload guards a null layer and builds the cache lazily
            result[k] = getVisibleIndex(vars[k]);
        }
        return result;
    }

    /**
     * Looks up the index of a hidden variable.
     *
     * @param var the variable to find
     * @return its index in the hidden layer, or -1 if it is not hidden or no hidden layer is set
     */
    public int getHiddenIndex(EnumVariable var) {
        if (this.h == null) {
            return -1;
        }
        if (this.hidvarmap == null) {
            this.hidvarmap = mapVarsToIndices(this.h);
        }
        Integer idx = this.hidvarmap.get(var);
        return idx != null ? idx : -1;
    }

    /**
     * Looks up the hidden index of each variable.
     *
     * @param vars the variables to find
     * @return an array parallel to {@code vars}, holding each index or -1 when not found
     *         (including when no hidden layer is set)
     */
    public int[] getHiddenIndex(EnumVariable[] vars) {
        int[] result = new int[vars.length];
        for (int k = 0; k < result.length; k++) {
            // the single-variable overload guards a null layer and builds the cache lazily
            result[k] = getHiddenIndex(vars[k]);
        }
        return result;
    }

    /**
     * Sets every hidden-to-visible connection to the same value.
     *
     * @param linked {@code true} to connect all pairs, {@code false} to disconnect them
     */
    public void setLinked(boolean linked) {
        for (int hi = 0; hi < this.h.length; hi++) {
            Arrays.fill(this.lnk[hi], 0, this.v.length, linked);
        }
    }

    /**
     * Sets the connection for every (visible, hidden) pair drawn from the two index lists.
     *
     * @param visibleIdxs indices of visible nodes
     * @param hiddenIdxs  indices of hidden nodes
     * @param linked      value to store in each affected cell
     */
    public void setLinked(int[] visibleIdxs, int[] hiddenIdxs, boolean linked) {
        for (int hi : hiddenIdxs) {
            for (int vi : visibleIdxs) {
                this.lnk[hi][vi] = linked;
            }
        }
    }

    /**
     * Sets the connections between one hidden node and a contiguous range of visible nodes.
     *
     * @param fromVisible first visible index (inclusive)
     * @param toVisible   last visible index (exclusive); an empty range changes nothing
     * @param hiddenIdx   the hidden node
     * @param linked      value to store in each affected cell
     */
    public void setLinked(int fromVisible, int toVisible, int hiddenIdx, boolean linked) {
        for (int vi = fromVisible; vi < toVisible; vi++) {
            this.lnk[hiddenIdx][vi] = linked;
        }
    }

    /**
     * Gives each hidden node a contiguous window of {@code window} visible nodes.
     * Window start positions are spaced by integer division so the first starts at visible
     * index 0 and the last ends at or before the last visible node.
     *
     * @param window number of visible nodes per hidden node
     * @param linked value to store for each node inside a window
     * @throws IllegalArgumentException if {@code window} exceeds the number of visible nodes
     */
    public void setLinkedWindow(int window, boolean linked) {
        int nVisible = getNVisible();
        int nHidden = getNHidden();
        if (window > nVisible) {
            throw new IllegalArgumentException(
                    "window " + window + " exceeds number of visible nodes " + nVisible);
        }
        // A single hidden node has nothing to spread over; avoid dividing by zero.
        int stride = nHidden > 1 ? (nVisible - window) / (nHidden - 1) : 0;
        for (int hi = 0; hi < nHidden; hi++) {
            int start = hi * stride;
            setLinked(start, start + window, hi, linked);
        }
    }

    /**
     * Replaces all connections so that every hidden node is linked to exactly the given visible
     * variables. Variables that are not in the visible layer are ignored.
     *
     * @param visibleVars visible variables to connect to every hidden node
     */
    public void setLinked(EnumVariable[] visibleVars) {
        int[] visibleIdxs = getVisibleIndex(visibleVars);
        for (int hi = 0; hi < this.h.length; hi++) {
            Arrays.fill(this.lnk[hi], false);
            for (int vi : visibleIdxs) {
                if (vi != -1) {
                    this.lnk[hi][vi] = true;
                }
            }
        }
    }

    /**
     * Adds connections between each given visible variable and each given hidden variable.
     * Existing connections are kept. Variables not found in their respective layer are ignored.
     *
     * @param visibleVars variables from the visible layer
     * @param hiddenVars  variables from the hidden layer
     */
    public void setLinked(EnumVariable[] visibleVars, EnumVariable[] hiddenVars) {
        int[] visibleIdxs = getVisibleIndex(visibleVars);
        // lnk is indexed [hidden][visible], so the row index must come from the hidden layer
        int[] hiddenIdxs = getHiddenIndex(hiddenVars);
        for (int hi : hiddenIdxs) {
            if (hi == -1) {
                continue;
            }
            for (int vi : visibleIdxs) {
                if (vi != -1) {
                    this.lnk[hi][vi] = true;
                }
            }
        }
    }

    /**
     * Sets the distribution of visible node {@code i}: probability 1 for domain values equal to
     * {@code value}, 0 for all others.
     *
     * @param i     index of the visible node
     * @param value the observed value
     * @return the updated distribution {@code Pv[i]}
     */
    public EnumDistrib assignVisible(int i, Object value) {
        Object[] domain = this.Pv[i].getDomain().getValues();
        double[] probs = new double[domain.length];
        for (int k = 0; k < probs.length; k++) {
            probs[k] = domain[k].equals(value) ? 1.0d : 0.0d;
        }
        this.Pv[i].set(probs);
        return this.Pv[i];
    }

    /**
     * Standard logistic (sigmoid) function.
     *
     * @param x input value
     * @return {@code 1 / (1 + e^(-x))}
     */
    public static double logistic(double x) {
        return 1.0d / (1.0d + Math.exp(-x));
    }

    /**
     * Softmax probability of one component.
     * Inputs are shifted by their maximum before exponentiating, so large inputs do not
     * overflow to {@code Infinity/Infinity = NaN}. The mathematical result is unchanged.
     *
     * @param x input values
     * @param i component whose probability is returned
     * @return {@code exp(x[i]) / sum_j exp(x[j])}
     */
    public static double softmax(double[] x, int i) {
        double max = Double.NEGATIVE_INFINITY;
        for (double xj : x) {
            if (xj > max) {
                max = xj;
            }
        }
        double sum = 0.0d;
        for (double xj : x) {
            sum += Math.exp(xj - max);
        }
        return Math.exp(x[i] - max) / sum;
    }

    /**
     * Maps visible values to hidden values.
     *
     * @param visible input visible values
     * @param hidden  array the implementation may fill and return
     * @return hidden values
     */
    public abstract Object[] encode(Object[] visible, Object[] hidden);

    /**
     * Maps hidden values to visible values.
     *
     * @param hidden  input hidden values
     * @param visible array the implementation may fill and return
     * @return visible values
     */
    public abstract Object[] decode(Object[] hidden, Object[] visible);

    /** Encodes {@code visible} into a freshly allocated hidden array. */
    public Object[] encode(Object[] visible) {
        return encode(visible, new Object[this.h.length]);
    }

    /** Decodes {@code hidden} into a freshly allocated visible array. */
    public Object[] decode(Object[] hidden) {
        return decode(hidden, new Object[this.v.length]);
    }

    public abstract Double[][] getCDGradient(Object[][] objArr, int i);

    public abstract void setCDGradient(Double[][] dArr);

    /** Equivalent to {@code encode_decode_clamped(observed, 0)}. */
    public Object[] encode_decode_clamped(Object[] observed) {
        return encode_decode_clamped(observed, 0);
    }

    /**
     * Reconstructs the visible layer while keeping observed values fixed.
     * Non-null entries of {@code observed} are copied through; null entries are filled from the
     * decoded reconstruction. The clamped result is re-encoded {@code iterations} more times.
     *
     * @param observed   visible values; {@code null} marks a value to be filled in
     * @param iterations number of extra encode/decode passes after the first
     * @return a new array of visible values
     */
    public Object[] encode_decode_clamped(Object[] observed, int iterations) {
        Object[] hidden = encode(observed, new Object[getNHidden()]);
        Object[] decoded = decode(hidden, new Object[getNVisible()]);
        Object[] result = new Object[this.v.length];
        clampInto(observed, decoded, result);
        for (int it = 0; it < iterations; it++) {
            hidden = encode(result, hidden);
            decoded = decode(hidden, decoded);
            clampInto(observed, decoded, result);
        }
        return result;
    }

    /** Equivalent to {@code encode_decode_restricted(observed, 0)}. */
    public Object[] encode_decode_restricted(Object[] observed) {
        return encode_decode_restricted(observed, 0);
    }

    /**
     * Decodes {@code hidden}, keeping only the positions where {@code mask} is non-null.
     *
     * @param hidden hidden values to decode
     * @param mask   visible-length array; null entries are nulled in the output
     * @return a new array of visible values, null wherever {@code mask} is null
     */
    public Object[] decode_restricted(Object[] hidden, Object[] mask) {
        Object[] decoded = decode(hidden, new Object[getNVisible()]);
        Object[] result = new Object[this.v.length];
        restrictInto(mask, decoded, result);
        return result;
    }

    /**
     * Reconstructs only the observed visible positions.
     * Positions where {@code observed} is null stay null. Other positions receive the decoded
     * reconstruction, which is re-encoded {@code iterations} more times.
     *
     * @param observed   visible values; {@code null} marks an unobserved position
     * @param iterations number of extra encode/decode passes after the first
     * @return a new array of visible values
     */
    public Object[] encode_decode_restricted(Object[] observed, int iterations) {
        Object[] hidden = encode(observed, new Object[getNHidden()]);
        Object[] decoded = decode(hidden, new Object[getNVisible()]);
        Object[] result = new Object[this.v.length];
        restrictInto(observed, decoded, result);
        for (int it = 0; it < iterations; it++) {
            hidden = encode(result, hidden);
            decoded = decode(hidden, decoded);
            restrictInto(observed, decoded, result);
        }
        return result;
    }

    /** Equivalent to {@code encode_decode_full(visible, 0)}. */
    public Object[] encode_decode_full(Object[] visible) {
        return encode_decode_full(visible, 0);
    }

    /**
     * Encodes then decodes without clamping, repeating {@code iterations} extra times on the
     * reconstruction itself.
     *
     * @param visible    initial visible values
     * @param iterations number of extra encode/decode passes after the first
     * @return the final decoded visible values (the array returned by {@link #decode})
     */
    public Object[] encode_decode_full(Object[] visible, int iterations) {
        Object[] hidden = encode(visible, new Object[getNHidden()]);
        Object[] decoded = decode(hidden, new Object[getNVisible()]);
        for (int it = 0; it < iterations; it++) {
            hidden = encode(decoded, hidden);
            decoded = decode(hidden, decoded);
        }
        return decoded;
    }

    /** Writes {@code observed[k]} into {@code out[k]} when non-null, else {@code decoded[k]}. */
    private static void clampInto(Object[] observed, Object[] decoded, Object[] out) {
        for (int k = 0; k < decoded.length; k++) {
            out[k] = observed[k] == null ? decoded[k] : observed[k];
        }
    }

    /** Writes {@code decoded[k]} into {@code out[k]} where {@code mask[k]} is non-null, else null. */
    private static void restrictInto(Object[] mask, Object[] decoded, Object[] out) {
        for (int k = 0; k < decoded.length; k++) {
            out[k] = mask[k] == null ? null : decoded[k];
        }
    }
}
