package bn.node;

import java.util.Arrays;
import java.util.Random;

/* loaded from: input_file:bn/node/GMM.class */
/**
 * A Gaussian mixture model with diagonal covariances, fitted by expectation-maximization (EM).
 *
 * <p>Component {@code k} has a mixing weight, a mean vector and a per-dimension variance
 * vector; dimensions are treated as independent within a component. Densities are handled in
 * log space so that far-away samples cannot underflow the E-step into NaN.
 *
 * <p>Instances are mutable and not thread-safe: {@link #train} overwrites the parameters.
 */
public class GMM {
    /** Lower bound on every variance so a component cannot collapse onto a single point. */
    private static final double MIN_VARIANCE = 1.0e-6d;

    /** {@code log(2 * pi)}, the constant term of the Gaussian log-density. */
    private static final double LOG_TWO_PI = Math.log(2.0d * Math.PI);

    private final int numComponents;
    private final double[] weights;
    private final double[][] means;
    private final double[][] variances;

    /**
     * Creates an untrained mixture.
     *
     * @param numComponents number of Gaussian components; must be positive
     * @throws IllegalArgumentException if {@code numComponents <= 0}
     */
    public GMM(int numComponents) {
        if (numComponents <= 0) {
            throw new IllegalArgumentException("numComponents must be positive: " + numComponents);
        }
        this.numComponents = numComponents;
        this.weights = new double[numComponents];
        // Rows are allocated in train() once the data dimension is known.
        this.means = new double[numComponents][];
        this.variances = new double[numComponents][];
    }

    /**
     * Fits the mixture to {@code data} with EM, using an unseeded random generator for
     * initialisation.
     *
     * @param data          samples, one row per sample; must be non-empty and rectangular
     * @param maxIterations maximum number of EM iterations
     * @param tolerance     stop once the log-likelihood changes by less than this amount
     * @throws IllegalArgumentException if {@code data} is null, empty or ragged
     */
    public void train(double[][] data, int maxIterations, double tolerance) {
        train(data, maxIterations, tolerance, new Random());
    }

    /**
     * Fits the mixture to {@code data} with EM, drawing initial centres from {@code random}
     * (k-means++ seeding) so results are reproducible with a seeded generator.
     *
     * @param data          samples, one row per sample; must be non-empty and rectangular
     * @param maxIterations maximum number of EM iterations
     * @param tolerance     stop once the log-likelihood changes by less than this amount
     * @param random        source of randomness for initialisation; must not be null
     * @throws IllegalArgumentException if {@code data} is null, empty or ragged
     * @throws NullPointerException     if {@code random} is null
     */
    public void train(double[][] data, int maxIterations, double tolerance, Random random) {
        int dimension = validateData(data);
        if (random == null) {
            throw new NullPointerException("random must not be null");
        }
        initialize(data, dimension, random);

        double[][] responsibilities = new double[data.length][numComponents];
        double previousLogLikelihood = Double.NEGATIVE_INFINITY;
        for (int iteration = 0; iteration < maxIterations; iteration++) {
            expectation(data, responsibilities);
            maximization(data, responsibilities, dimension);
            double logLikelihood = logLikelihood(data);
            if (Math.abs(logLikelihood - previousLogLikelihood) < tolerance) {
                return;
            }
            previousLogLikelihood = logLikelihood;
        }
    }

    /** Returns the number of mixture components. */
    public int getNumComponents() {
        return numComponents;
    }

    /** Returns a copy of the mixing weights (all zero before training). */
    public double[] getWeights() {
        return weights.clone();
    }

    /** Returns a deep copy of the component means; rows are null before training. */
    public double[][] getMeans() {
        return deepCopy(means);
    }

    /** Returns a deep copy of the per-dimension variances; rows are null before training. */
    public double[][] getVariances() {
        return deepCopy(variances);
    }

    /** Checks that {@code data} is a non-empty rectangular matrix and returns its column count. */
    private static int validateData(double[][] data) {
        if (data == null || data.length == 0) {
            throw new IllegalArgumentException("data must contain at least one sample");
        }
        if (data[0] == null || data[0].length == 0) {
            throw new IllegalArgumentException("samples must have at least one dimension");
        }
        int dimension = data[0].length;
        for (int n = 1; n < data.length; n++) {
            if (data[n] == null || data[n].length != dimension) {
                throw new IllegalArgumentException(
                        "sample " + n + " is null or does not have dimension " + dimension);
            }
        }
        return dimension;
    }

    /**
     * Seeds the model: uniform weights, every variance set to the overall data variance, and
     * means picked from the samples with k-means++ (each next centre is drawn with probability
     * proportional to its squared distance from the nearest centre already chosen).
     */
    private void initialize(double[][] data, int dimension, Random random) {
        int sampleCount = data.length;
        double[] dataVariance = columnVariances(data, dimension);
        double[] nearestSquaredDistance = new double[sampleCount];
        Arrays.fill(nearestSquaredDistance, Double.POSITIVE_INFINITY);

        int chosen = random.nextInt(sampleCount);
        for (int k = 0; k < numComponents; k++) {
            weights[k] = 1.0d / numComponents;
            means[k] = data[chosen].clone();
            variances[k] = dataVariance.clone();
            if (k + 1 == numComponents) {
                break;
            }
            double total = 0.0d;
            for (int n = 0; n < sampleCount; n++) {
                double distance = squaredDistance(data[n], means[k]);
                if (distance < nearestSquaredDistance[n]) {
                    nearestSquaredDistance[n] = distance;
                }
                total += nearestSquaredDistance[n];
            }
            // If every sample coincides with a chosen centre, fall back to a uniform draw.
            chosen = total > 0.0d
                    ? sampleProportionally(nearestSquaredDistance, total, random)
                    : random.nextInt(sampleCount);
        }
    }

    /** E-step: {@code responsibilities[n][k]} = posterior probability that sample n came from k. */
    private void expectation(double[][] data, double[][] responsibilities) {
        double[] logTerms = new double[numComponents];
        for (int n = 0; n < data.length; n++) {
            double logNormalizer = weightedLogDensities(data[n], logTerms);
            for (int k = 0; k < numComponents; k++) {
                responsibilities[n][k] = Math.exp(logTerms[k] - logNormalizer);
            }
        }
    }

    /**
     * M-step: re-estimates weights, means and variances from the responsibilities. Each
     * variance is computed around the freshly updated mean and floored at {@link #MIN_VARIANCE}.
     */
    private void maximization(double[][] data, double[][] responsibilities, int dimension) {
        int sampleCount = data.length;
        for (int k = 0; k < numComponents; k++) {
            double effectiveCount = 0.0d;
            for (int n = 0; n < sampleCount; n++) {
                effectiveCount += responsibilities[n][k];
            }
            weights[k] = effectiveCount / sampleCount;
            if (effectiveCount <= 0.0d) {
                // Component owns no data: keep its mean/variance rather than dividing by zero.
                continue;
            }
            for (int d = 0; d < dimension; d++) {
                double weightedSum = 0.0d;
                for (int n = 0; n < sampleCount; n++) {
                    weightedSum += responsibilities[n][k] * data[n][d];
                }
                double mean = weightedSum / effectiveCount;
                double weightedSquares = 0.0d;
                for (int n = 0; n < sampleCount; n++) {
                    double diff = data[n][d] - mean;
                    weightedSquares += responsibilities[n][k] * diff * diff;
                }
                means[k][d] = mean;
                variances[k][d] = Math.max(weightedSquares / effectiveCount, MIN_VARIANCE);
            }
        }
    }

    /** Total log-likelihood of {@code data} under the current parameters. */
    private double logLikelihood(double[][] data) {
        double[] logTerms = new double[numComponents];
        double total = 0.0d;
        for (double[] sample : data) {
            total += weightedLogDensities(sample, logTerms);
        }
        return total;
    }

    /**
     * Fills {@code logTerms[k] = log(w_k) + log N(sample | mean_k, var_k)} and returns
     * {@code log(sum_k exp(logTerms[k]))}, computed with the log-sum-exp trick to avoid underflow.
     */
    private double weightedLogDensities(double[] sample, double[] logTerms) {
        double max = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < numComponents; k++) {
            logTerms[k] = Math.log(weights[k]) + logGaussian(sample, means[k], variances[k]);
            max = Math.max(max, logTerms[k]);
        }
        double sum = 0.0d;
        for (int k = 0; k < numComponents; k++) {
            sum += Math.exp(logTerms[k] - max);
        }
        return max + Math.log(sum);
    }

    /** Log-density of a diagonal-covariance Gaussian evaluated at {@code sample}. */
    private static double logGaussian(double[] sample, double[] mean, double[] variance) {
        double result = 0.0d;
        for (int d = 0; d < sample.length; d++) {
            double diff = sample[d] - mean[d];
            result -= 0.5d * (diff * diff / variance[d] + LOG_TWO_PI + Math.log(variance[d]));
        }
        return result;
    }

    /** Population variance of each column, floored at {@link #MIN_VARIANCE}. */
    private static double[] columnVariances(double[][] data, int dimension) {
        double[] result = new double[dimension];
        for (int d = 0; d < dimension; d++) {
            double sum = 0.0d;
            for (double[] sample : data) {
                sum += sample[d];
            }
            double mean = sum / data.length;
            double squares = 0.0d;
            for (double[] sample : data) {
                double diff = sample[d] - mean;
                squares += diff * diff;
            }
            result[d] = Math.max(squares / data.length, MIN_VARIANCE);
        }
        return result;
    }

    /** Squared Euclidean distance between two equally sized vectors. */
    private static double squaredDistance(double[] a, double[] b) {
        double sum = 0.0d;
        for (int d = 0; d < a.length; d++) {
            double diff = a[d] - b[d];
            sum += diff * diff;
        }
        return sum;
    }

    /** Draws an index with probability proportional to {@code scores[i]} (which sum to {@code total}). */
    private static int sampleProportionally(double[] scores, double total, Random random) {
        double target = random.nextDouble() * total;
        int lastPositive = 0;
        for (int n = 0; n < scores.length; n++) {
            if (scores[n] > 0.0d) {
                lastPositive = n;
                target -= scores[n];
                if (target < 0.0d) {
                    return n;
                }
            }
        }
        return lastPositive; // reached only through floating-point rounding
    }

    /** Copies each row; null rows (untrained components) stay null. */
    private static double[][] deepCopy(double[][] matrix) {
        double[][] copy = new double[matrix.length][];
        for (int i = 0; i < matrix.length; i++) {
            copy[i] = matrix[i] == null ? null : matrix[i].clone();
        }
        return copy;
    }

    /** Demo: fits a 3-component mixture to a small one-dimensional sample and prints it. */
    public static void main(String[] args) {
        GMM gmm = new GMM(3);
        double[][] samples = {
            {8.547969513147832d}, {11.055365021610807d}, {9.515919623038354d},
            {7.309429864631631d}, {5.155113714709337d}, {5.9934818128526d},
            {9.737441853198932d}, {5.866697379742977d}, {8.730212309448154d},
            {5.107871539173301d}, {7.091657649202554d}, {6.005496193011316d},
            {5.784554136706749d}, {7.24665515021407d}, {7.588302408954381d},
            {4.654811488291314d}
        };
        gmm.train(samples, 100, 1.0E-4d);
        System.out.println("Weights: " + Arrays.toString(gmm.getWeights()));
        System.out.println("Means: " + Arrays.deepToString(gmm.getMeans()));
        System.out.println("Variances: " + Arrays.deepToString(gmm.getVariances()));
    }
}
