package bn.prob;

import bn.Distrib;
import java.io.Serializable;
import java.util.Random;

/* JADX WARN: Classes with same name are omitted:
  input_file:target/classes/bn/prob/GammaDistrib.class
 */
/* loaded from: input_file:bn/prob/GammaDistrib.class */
/**
 * Gamma distribution parameterised by a shape {@code k} (also exposed as alpha) and a rate
 * {@code lambda}; the scale beta is {@code 1 / lambda}.
 *
 * <p>Besides density evaluation and sampling, the class offers static helpers for the
 * log-gamma, digamma and trigamma functions and for maximum-likelihood estimation of the
 * shape and scale from a sample. Instances are mutable through their setters.
 */
public class GammaDistrib implements Distrib, Serializable {
    private static final long serialVersionUID = 1;

    /** The Euler-Mascheroni constant. */
    public static final double GAMMA = 0.5772156649015329d;

    /** digamma/trigamma shift their argument up to at least this value before using the asymptotic series. */
    private static final double C_LIMIT = 49.0d;

    /** At or below this positive argument digamma/trigamma use their leading small-argument terms. */
    private static final double S_LIMIT = 1.0E-5d;

    /** Series coefficients of the Numerical Recipes {@code gammln} approximation used by {@link #lgamma}. */
    private static final double[] LGAMMA_COEFFICIENTS = {
        76.18009172947146d, -86.50532032941678d, 24.01409824083091d,
        -1.231739572450155d, 0.001208650973866179d, -5.395239384953E-6d
    };

    /** sqrt(2 * pi), as used by {@link #lgamma}. */
    private static final double SQRT_TWO_PI = 2.5066282746310007d;

    /** Default relative convergence threshold for {@link #getAlpha(double[])}. */
    private static final double DEFAULT_ALPHA_TOLERANCE = 1.0E-10d;

    /** Default cap on Newton updates for {@link #getAlpha(double[])}. */
    private static final int DEFAULT_ALPHA_MAX_ITERATIONS = 100;

    private double lambda;
    private double k;
    private final Random rand = new Random();

    /**
     * Creates a gamma distribution.
     *
     * @param k shape parameter
     * @param lambda rate parameter (the scale is {@code 1 / lambda})
     */
    public GammaDistrib(double k, double lambda) {
        this.k = k;
        this.lambda = lambda;
    }

    /**
     * Evaluates the probability density at a point.
     *
     * <p>Computed in log space so that large shapes do not overflow through {@code gamma(k)}.
     *
     * @param obj the point, any {@link Number} (typically a {@link Double})
     * @return the density; {@code 0} for negative or infinite points, and the limiting value at 0
     * @throws ClassCastException if {@code obj} is not a {@link Number}
     */
    @Override // bn.Distrib
    public double get(Object obj) {
        double x = ((Number) obj).doubleValue();
        if (x < 0.0d || x == Double.POSITIVE_INFINITY) {
            return 0.0d; // outside the support
        }
        if (x == 0.0d) {
            // (k - 1) * log(0) would be 0 * -inf = NaN for k == 1, so take the limit explicitly.
            if (this.k < 1.0d) {
                return Double.POSITIVE_INFINITY;
            }
            return this.k == 1.0d ? this.lambda : 0.0d;
        }
        double logDensity = this.k * Math.log(this.lambda)
                + (this.k - 1.0d) * Math.log(x)
                - this.lambda * x
                - lgamma(this.k);
        return Math.exp(logDensity);
    }

    /**
     * Draws one variate from this distribution.
     *
     * <p>A unit-rate variate is drawn (Weibull-envelope rejection for {@code k < 1}, Cheng's
     * algorithm GB for {@code k >= 1}) and divided by {@code lambda}.
     *
     * @return a sample
     * @throws IllegalStateException if {@code k} or {@code lambda} is not finite and positive;
     *     with NaN parameters every rejection test fails and the loop would never terminate
     */
    @Override // bn.Distrib
    public Double sample() {
        if (!isFinitePositive(this.k) || !isFinitePositive(this.lambda)) {
            throw new IllegalStateException(
                    "Gamma parameters must be finite and positive: k=" + this.k + ", lambda=" + this.lambda);
        }
        double unitRate = this.k < 1.0d ? sampleSmallShape() : sampleLargeShape();
        return Double.valueOf(unitRate / this.lambda);
    }

    /** Returns true when {@code v} is a finite, strictly positive number (false for NaN). */
    private static boolean isFinitePositive(double v) {
        return v > 0.0d && v < Double.POSITIVE_INFINITY;
    }

    /** Rejection sampler for Gamma(k, 1) with 0 < k < 1, using a Weibull(k) envelope. */
    private double sampleSmallShape() {
        double invK = 1.0d / this.k;
        // Maximum over x of (x^k - x), attained at x = k^(1/(1-k)); bounds the density ratio.
        double envelope = (1.0d - this.k) * Math.pow(this.k, this.k / (1.0d - this.k));
        while (true) {
            // nextDouble() is in [0, 1), so log1p(-u) is finite: these exponentials never become +inf.
            double e1 = -Math.log1p(-this.rand.nextDouble());
            double e2 = -Math.log1p(-this.rand.nextDouble());
            double candidate = Math.pow(e1, invK); // Weibull(k) variate
            if (e1 + e2 >= envelope + candidate) {
                return candidate;
            }
        }
    }

    /** Cheng's (1977) algorithm GB: rejection sampler for Gamma(k, 1) with k >= 1. */
    private double sampleLargeShape() {
        double b = this.k - Math.log(4.0d);
        double root = Math.sqrt((2.0d * this.k) - 1.0d);
        double a = 1.0d / root;
        double c = this.k + root;
        double d = 1.0d + Math.log(4.5d);
        while (true) {
            double u1 = this.rand.nextDouble();
            double u2 = this.rand.nextDouble();
            double v = a * Math.log(u2 / (1.0d - u2));
            double y = this.k * Math.exp(v);
            double z = u1 * u2 * u2;
            double w = (b + (c * v)) - y;
            // Cheap squeeze test first, then the exact logarithmic acceptance test.
            if (w >= (4.5d * z) - d || w >= Math.log(z)) {
                return y;
            }
        }
    }

    /**
     * Gamma function, computed as {@code exp(lgamma(d))}.
     *
     * @param d argument; intended for {@code d > 0}
     * @return Gamma(d); overflows to infinity for large {@code d}
     */
    public static double gamma(double d) {
        return Math.exp(lgamma(d));
    }

    /**
     * Natural logarithm of the gamma function, using the six-term series of Numerical Recipes'
     * {@code gammln}.
     *
     * @param d argument; the approximation is intended for {@code d > 0}
     * @return ln Gamma(d)
     */
    public static double lgamma(double d) {
        double tmp = d + 5.5d;
        tmp -= (d + 0.5d) * Math.log(tmp);
        double series = 1.000000000190015d;
        double y = d;
        for (double coefficient : LGAMMA_COEFFICIENTS) {
            y += 1.0d;
            series += coefficient / y;
        }
        return -tmp + Math.log((SQRT_TWO_PI * series) / d);
    }

    /** @return the shape parameter {@code k} */
    public double getK() {
        return this.k;
    }

    /** @param k new shape parameter */
    public void setK(double k) {
        this.k = k;
    }

    /** @return the rate parameter {@code lambda} */
    public double getLambda() {
        return this.lambda;
    }

    /** @param lambda new rate parameter */
    public void setLambda(double lambda) {
        this.lambda = lambda;
    }

    /** @return the shape parameter (alias of {@link #getK()}) */
    public double getAlpha() {
        return this.k;
    }

    /** @return the scale parameter, {@code 1 / lambda} */
    public double getBeta() {
        return 1.0d / this.lambda;
    }

    /** @param beta new scale parameter; stored as rate {@code 1 / beta} */
    public void setBeta(double beta) {
        this.lambda = 1.0d / beta;
    }

    /** @param alpha new shape parameter (alias of {@link #setK(double)}) */
    public void setAlpha(double alpha) {
        this.k = alpha;
    }

    /**
     * Log-likelihood of a sample under the current shape (alpha) and scale (beta).
     *
     * @param dArr observations; expected positive
     * @return sum of log densities; {@code 0} for an empty sample
     */
    public double logLikelihood(double[] dArr) {
        double alpha = getAlpha();
        double beta = getBeta();
        int n = dArr.length;
        double sum = 0.0d;
        double sumLog = 0.0d;
        for (double x : dArr) {
            sum += x;
            sumLog += Math.log(x);
        }
        return ((alpha - 1.0d) * sumLog)
                - (n * lgamma(alpha))
                - (n * alpha * Math.log(beta))
                - (sum / beta);
    }

    /**
     * Maximum-likelihood estimate of the shape (alpha) using default convergence settings.
     *
     * @param dArr observations, each finite and strictly positive
     * @return the estimated shape; {@code +inf} when all observations are equal
     * @throws IllegalArgumentException if {@code dArr} is empty or holds a non-positive or
     *     non-finite value
     * @see #getAlpha(double[], double, int)
     */
    public static double getAlpha(double[] dArr) {
        return getAlpha(dArr, DEFAULT_ALPHA_TOLERANCE, DEFAULT_ALPHA_MAX_ITERATIONS);
    }

    /**
     * Maximum-likelihood estimate of the shape (alpha) via Minka's generalized Newton iteration.
     *
     * <p>Starts at {@code 0.5 / (log(mean) - mean(log))} and repeatedly applies
     * {@code 1/a' = 1/a + (mean(log) - log(mean) + log a - digamma(a)) / (a^2 (1/a - trigamma(a)))}.
     *
     * @param dArr observations, each finite and strictly positive
     * @param tolerance stop once successive estimates differ by at most {@code tolerance} times
     *     the new estimate
     * @param maxIterations upper bound on Newton updates; {@code 0} returns the starting value
     * @return the estimated shape; {@code +inf} when all observations are equal
     * @throws IllegalArgumentException if {@code dArr} is empty or holds a non-positive or
     *     non-finite value
     */
    public static double getAlpha(double[] dArr, double tolerance, int maxIterations) {
        if (dArr.length == 0) {
            throw new IllegalArgumentException("cannot estimate shape from an empty sample");
        }
        double sum = 0.0d;
        double sumLog = 0.0d;
        for (double x : dArr) {
            if (!isFinitePositive(x)) {
                throw new IllegalArgumentException("gamma observations must be finite and positive: " + x);
            }
            sum += x;
            sumLog += Math.log(x);
        }
        double logMean = Math.log(sum / dArr.length);
        double meanLog = sumLog / dArr.length;
        // Non-negative by Jensen's inequality; zero (or rounding-negative) only for constant data.
        double gap = logMean - meanLog;
        if (!(gap > 0.0d)) {
            return Double.POSITIVE_INFINITY;
        }
        double alpha = 0.5d / gap;
        for (int i = 0; i < maxIterations; i++) {
            double numerator = (meanLog - logMean) + Math.log(alpha) - digamma(alpha);
            double denominator = (alpha * alpha) * ((1.0d / alpha) - trigamma(alpha));
            // The reciprocal of the CURRENT estimate must be used on every iteration.
            double next = 1.0d / ((1.0d / alpha) + (numerator / denominator));
            boolean converged = Math.abs(next - alpha) <= tolerance * Math.abs(next);
            alpha = next;
            if (converged) {
                break;
            }
        }
        return alpha;
    }

    /**
     * Scale (beta) estimate for a given shape: {@code mean(dArr) / alpha}.
     *
     * @param dArr observations
     * @param alpha shape parameter
     * @return the estimated scale
     * @throws IllegalArgumentException if {@code dArr} is empty
     */
    public static double getBeta(double[] dArr, double alpha) {
        if (dArr.length == 0) {
            throw new IllegalArgumentException("cannot estimate scale from an empty sample");
        }
        double sum = 0.0d;
        for (double x : dArr) {
            sum += x;
        }
        return (sum / dArr.length) / alpha;
    }

    /**
     * Digamma function psi(d), the logarithmic derivative of Gamma.
     *
     * <p>Negative non-integer arguments use the reflection formula; positive arguments are shifted
     * upward with {@code psi(x) = psi(x + 1) - 1/x} and finished with the asymptotic series.
     *
     * @param d argument
     * @return psi(d); NaN at the poles {@code 0, -1, -2, ...}, for NaN and for {@code -inf}
     */
    public static double digamma(double d) {
        if (Double.isNaN(d) || d == Double.NEGATIVE_INFINITY) {
            return Double.NaN;
        }
        if (d <= 0.0d) {
            if (d == Math.rint(d)) {
                return Double.NaN; // pole
            }
            // Reflection: psi(d) = psi(1 - d) - pi * cot(pi * d), with 1 - d > 1.
            return digamma(1.0d - d) - (Math.PI / Math.tan(Math.PI * d));
        }
        if (d <= S_LIMIT) {
            return -GAMMA - (1.0d / d);
        }
        double x = d;
        double shift = 0.0d;
        while (x < C_LIMIT) { // bounded: x > 0 here, so at most ~49 steps
            shift += 1.0d / x;
            x += 1.0d;
        }
        double invX2 = 1.0d / (x * x);
        double asymptotic = (Math.log(x) - (0.5d / x))
                - (invX2 * ((1.0d / 12.0d) + (invX2 * ((1.0d / 120.0d) - (invX2 / 252.0d)))));
        return asymptotic - shift;
    }

    /**
     * Trigamma function psi'(d), the derivative of {@link #digamma}.
     *
     * <p>Negative non-integer arguments use the reflection formula; positive arguments are shifted
     * upward with {@code psi'(x) = psi'(x + 1) + 1/x^2} and finished with the asymptotic series
     * {@code 1/x + 1/(2x^2) + 1/(6x^3) - 1/(30x^5) + 1/(42x^7)}.
     *
     * @param d argument
     * @return psi'(d); {@code +inf} at the poles {@code 0, -1, -2, ...}, NaN for NaN and {@code -inf}
     */
    public static double trigamma(double d) {
        if (Double.isNaN(d) || d == Double.NEGATIVE_INFINITY) {
            return Double.NaN;
        }
        if (d <= 0.0d) {
            if (d == Math.rint(d)) {
                return Double.POSITIVE_INFINITY; // pole, approached from above on both sides
            }
            // Reflection: psi'(d) = pi^2 / sin^2(pi * d) - psi'(1 - d), with 1 - d > 1.
            double s = Math.sin(Math.PI * d);
            return ((Math.PI * Math.PI) / (s * s)) - trigamma(1.0d - d);
        }
        if (d <= S_LIMIT) {
            return 1.0d / (d * d);
        }
        double x = d;
        double shift = 0.0d;
        while (x < C_LIMIT) { // bounded: x > 0 here, so at most ~49 steps
            shift += 1.0d / (x * x);
            x += 1.0d;
        }
        double invX2 = 1.0d / (x * x);
        double asymptotic = (1.0d / x) + (invX2 / 2.0d)
                + ((invX2 / x) * ((1.0d / 6.0d) - (invX2 * ((1.0d / 30.0d) - (invX2 / 42.0d)))));
        return asymptotic + shift;
    }

    /** Demo: fits a gamma distribution to a fixed sample and prints the mean of 2000 draws. */
    public static void main(String[] strArr) {
        double[] data = {1.0E-6d, 11.2d, 8.3d, 13.1d, 15.9d, 11.5d, 11.4d, 12.3d, 11.9d, 5.5d};
        double alpha = getAlpha(data);
        double beta = getBeta(data, alpha);
        System.out.println("Setting Gamma distrib with alpha = " + alpha + " beta = " + beta);
        GammaDistrib gammaDistrib = new GammaDistrib(alpha, 1.0d / beta);
        final int draws = 2000;
        double total = 0.0d;
        System.out.println("Sample");
        for (int i = 0; i < draws; i++) {
            total += gammaDistrib.sample().doubleValue();
        }
        System.out.println("Mean\t" + (total / draws));
    }
}
