package bn.prob;

import bn.Distrib;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Random;

/* JADX WARN: Classes with same name are omitted:
  input_file:target/classes/bn/prob/MixtureDistrib.class
 */
/* loaded from: input_file:bn/prob/MixtureDistrib.class */
/**
 * A weighted finite mixture of {@link Distrib} components.
 *
 * <p>Each distinct component (as judged by {@code equals(Object)} through
 * {@link ArrayList#indexOf}) is stored once; adding it again increases its existing weight.
 * Adding another {@code MixtureDistrib} flattens it: each of its components is added with its
 * weight multiplied by the supplied factor. Weights are not required to sum to one.
 *
 * <p>Weights are held in two views that every mutator keeps in sync: the index-aligned
 * {@link #distribs}/{@link #weights} lists and the {@link #mixture} map. {@code density} caches
 * the total of all weights and is the range {@link #componentSample()} draws from.
 */
public class MixtureDistrib implements Distrib {
    /** Weight of each component keyed by component; mirrors {@link #weights}. */
    final Map<Distrib, Double> mixture;
    /** Components in insertion order; index-aligned with {@link #weights}. */
    protected ArrayList<Distrib> distribs;
    /** Unnormalised weight of each component; index-aligned with {@link #distribs}. */
    protected ArrayList<Double> weights;
    /** Cached sum of {@link #weights}; the sampling range used by {@link #componentSample()}. */
    private double density;
    /** Random source for sampling; seeded with 1 by default so sampling is reproducible. */
    Random rand;

    /** Creates an empty mixture. */
    public MixtureDistrib() {
        this.rand = new Random(1L);
        this.mixture = new HashMap<>();
        this.distribs = new ArrayList<>();
        this.weights = new ArrayList<>();
        this.density = 0.0d;
    }

    /**
     * Creates a mixture holding {@code distrib} with weight {@code d}. If {@code distrib} is itself
     * a {@code MixtureDistrib}, its components are copied with their weights multiplied by
     * {@code d}.
     *
     * @param distrib the initial component, or a mixture to flatten
     * @param d the weight (or weight multiplier) to apply
     */
    public MixtureDistrib(Distrib distrib, double d) {
        this();
        // Private helper: avoids calling the overridable addDistrib() from a constructor.
        addFlattened(distrib, d);
    }

    /**
     * Replaces the random source with one seeded by {@code j}.
     *
     * @param j the seed
     */
    public void setSeed(long j) {
        this.rand = new Random(j);
    }

    /**
     * Draws an int in {@code [0, i)} from this mixture's random source.
     *
     * @param i the exclusive upper bound
     * @return the drawn value
     */
    public int nextInt(int i) {
        return this.rand.nextInt(i);
    }

    /**
     * Draws a double in {@code [0, 1)} from this mixture's random source.
     *
     * @return the drawn value
     */
    public double nextDouble() {
        return this.rand.nextDouble();
    }

    /**
     * Adds weight {@code d} to {@code distrib} without flattening, appending the component if it
     * is not yet present.
     *
     * @return the updated total weight
     */
    private double addDistribForced(Distrib distrib, double d) {
        Double previous = this.mixture.get(distrib);
        this.mixture.put(distrib, Double.valueOf(previous == null ? d : previous.doubleValue() + d));
        // Single scan instead of contains() followed by indexOf().
        int index = this.distribs.indexOf(distrib);
        if (index >= 0) {
            this.weights.set(index, Double.valueOf(this.weights.get(index).doubleValue() + d));
        } else {
            this.distribs.add(distrib);
            this.weights.add(Double.valueOf(d));
        }
        this.density += d;
        return this.density;
    }

    /** Adds {@code distrib} with weight {@code d}, flattening it first if it is a mixture. */
    private void addFlattened(Distrib distrib, double d) {
        if (distrib instanceof MixtureDistrib) {
            MixtureDistrib other = (MixtureDistrib) distrib;
            // Size is captured once, so adding a mixture to itself terminates.
            int size = other.getMixtureSize();
            for (int i = 0; i < size; i++) {
                addDistribForced(other.getDistrib(i), d * other.getWeights(i).doubleValue());
            }
        } else {
            addDistribForced(distrib, d);
        }
    }

    /**
     * Adds {@code distrib} with weight {@code d}. A {@code MixtureDistrib} argument is flattened:
     * each of its components is added with its weight multiplied by {@code d}.
     *
     * @param distrib the component or mixture to add
     * @param d the weight (or weight multiplier)
     * @return the updated total weight of this mixture
     */
    public double addDistrib(Distrib distrib, double d) {
        addFlattened(distrib, d);
        return this.density;
    }

    /**
     * Tells whether {@code distrib} is already a component (by {@code equals(Object)}).
     *
     * @param distrib the component to look for
     * @return {@code true} if present
     */
    public boolean hasDistrib(Distrib distrib) {
        return this.distribs.contains(distrib);
    }

    /**
     * Returns the number of distinct components.
     *
     * @return the component count
     */
    public int getMixtureSize() {
        return this.distribs.size();
    }

    /**
     * Returns the component at index {@code i}.
     *
     * @param i the component index
     * @return the component, or {@code null} if {@code i} is out of range
     */
    public Distrib getDistrib(int i) {
        if (i >= 0 && i < getMixtureSize()) {
            return this.distribs.get(i);
        }
        return null;
    }

    /**
     * Returns the unnormalised weight at index {@code i}.
     *
     * @param i the component index
     * @return the weight, or {@code null} if {@code i} is out of range
     */
    public Double getWeights(int i) {
        if (i >= 0 && i < getMixtureSize()) {
            return this.weights.get(i);
        }
        return null;
    }

    /**
     * Overwrites the weight of the component at index {@code i} and updates the cached total.
     *
     * @param i the component index
     * @param d the new weight
     * @throws IndexOutOfBoundsException if {@code i} is out of range
     */
    public void setWeight(int i, double d) {
        double previous = this.weights.get(i).doubleValue();
        this.weights.set(i, Double.valueOf(d));
        this.mixture.put(this.distribs.get(i), Double.valueOf(d));
        // Keep the sampling range consistent with the new weights.
        this.density += d - previous;
    }

    /**
     * Overwrites all weights, in component order.
     *
     * @param dArr the new weights; must have one entry per component
     * @throws RuntimeException if the array length differs from the component count
     */
    public void setWeights(double[] dArr) {
        if (dArr.length != this.weights.size()) {
            throw new RuntimeException("number of weights invalid");
        }
        for (int i = 0; i < dArr.length; i++) {
            setWeight(i, dArr[i]);
        }
        // Recompute once to drop any drift from the incremental updates above.
        this.density = sumOfWeights();
    }

    /**
     * Returns the weights normalised to sum to one, in component order. An all-zero weight vector
     * yields NaN entries.
     *
     * @return a new array of normalised weights
     */
    public double[] getAllWeights() {
        double[] normalised = new double[this.weights.size()];
        double total = 0.0d;
        for (int i = 0; i < normalised.length; i++) {
            normalised[i] = this.weights.get(i).doubleValue();
            total += normalised[i];
        }
        for (int i = 0; i < normalised.length; i++) {
            normalised[i] /= total;
        }
        return normalised;
    }

    /**
     * Returns the weight recorded for {@code distrib}.
     *
     * @param distrib the component to look up
     * @return its weight, or {@code null} if it is not a component
     */
    public Double getWeightsByDistrib(Distrib distrib) {
        // Stored values are never null, so a plain get() is equivalent to containsKey()+get().
        return this.mixture.get(distrib);
    }

    /**
     * Builds a new mixture with the same components and weights divided by their total.
     *
     * @return the normalised copy, or {@code null} if this mixture has no components
     */
    public MixtureDistrib getNormalizedClone() {
        double total = sumOfWeights();
        MixtureDistrib clone = null;
        for (int i = 0; i < this.distribs.size(); i++) {
            double share = this.weights.get(i).doubleValue() / total;
            if (clone == null) {
                clone = new MixtureDistrib(this.distribs.get(i), share);
            } else {
                clone.addDistrib(this.distribs.get(i), share);
            }
        }
        return clone;
    }

    /**
     * Normalises the weights of this mixture in place so that they sum to one, and updates the
     * cached total used for sampling.
     */
    public void getNormalized() {
        double total = sumOfWeights();
        for (int i = 0; i < this.weights.size(); i++) {
            Double share = Double.valueOf(this.weights.get(i).doubleValue() / total);
            this.weights.set(i, share);
            this.mixture.put(this.distribs.get(i), share);
        }
        // Previously left stale, which made componentSample() favour the last component.
        this.density = sumOfWeights();
    }

    /** Returns the sum of the entries in {@link #weights}. */
    private double sumOfWeights() {
        double total = 0.0d;
        for (Double w : this.weights) {
            total += w.doubleValue();
        }
        return total;
    }

    /**
     * Returns the weighted sum of the components' {@code get(obj)} values, using the unnormalised
     * weights.
     *
     * @param obj the value to evaluate
     * @return the weighted sum
     */
    @Override // bn.Distrib
    public double get(Object obj) {
        double result = 0.0d;
        for (int i = 0; i < this.distribs.size(); i++) {
            result += this.distribs.get(i).get(obj) * this.weights.get(i).doubleValue();
        }
        return result;
    }

    /**
     * Picks a component at random with probability proportional to its weight.
     *
     * @return the chosen component, or {@code null} if the mixture is empty. If rounding leaves
     *     the cumulative weight below the draw, the last component is returned.
     */
    public Distrib componentSample() {
        double target = this.rand.nextDouble() * this.density;
        Distrib chosen = null;
        double cumulative = 0.0d;
        for (int i = 0; i < this.distribs.size(); i++) {
            chosen = this.distribs.get(i);
            cumulative += this.weights.get(i).doubleValue();
            if (cumulative >= target) {
                break;
            }
        }
        return chosen;
    }

    /**
     * Compares two mixtures component by component, in index order. Weights must agree to within
     * {@code 1e-15}.
     *
     * <p>NOTE(review): this overloads rather than overrides {@code Object.equals(Object)}, and
     * {@code hashCode} is not overridden. Collections therefore still use identity equality.
     *
     * @param mixtureDistrib the mixture to compare with
     * @return {@code true} if both mixtures have the same size, equal components and matching
     *     weights; {@code false} for a {@code null} argument
     * @throws ClassCastException if a component is not a {@code DirichletDistrib}
     */
    public boolean equals(MixtureDistrib mixtureDistrib) {
        // Without this check, a longer argument compared equal (asymmetric result).
        if (mixtureDistrib == null || mixtureDistrib.getMixtureSize() != getMixtureSize()) {
            return false;
        }
        for (int i = 0; i < this.distribs.size(); i++) {
            // The casts are kept deliberately. NOTE(review): DirichletDistrib may declare an
            // equals(DirichletDistrib) overload that these static types select; confirm before
            // generalising.
            DirichletDistrib mine = (DirichletDistrib) getDistrib(i);
            DirichletDistrib theirs = (DirichletDistrib) mixtureDistrib.getDistrib(i);
            if (!mine.equals(theirs)) {
                return false;
            }
            double diff = mixtureDistrib.getWeightsByDistrib(theirs).doubleValue()
                    - getWeightsByDistrib(mine).doubleValue();
            if (Math.abs(diff) > 1.0E-15d) {
                return false;
            }
        }
        return true;
    }

    /**
     * Draws a value by first choosing a component via {@link #componentSample()} and then
     * sampling from that component.
     *
     * @return the sampled value
     * @throws RuntimeException if no component could be chosen (empty mixture)
     */
    @Override // bn.Distrib
    public Object sample() {
        Distrib componentSample = componentSample();
        if (componentSample == null) {
            throw new RuntimeException("Invalid MixtureDistrib");
        }
        return componentSample.sample();
    }

    /**
     * Returns a compact text form: {@code ^<size>} followed by {@code {component*weight}}
     * for each component.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("^" + this.distribs.size());
        for (int i = 0; i < this.distribs.size(); i++) {
            sb.append("{" + this.distribs.get(i) + "*" + String.format("%4.2f", this.weights.get(i)) + "}");
        }
        return sb.toString();
    }

    /**
     * Returns an XML-like listing of the components and their weights.
     *
     * <p>NOTE(review): component {@code toString()} output is inserted without XML escaping.
     *
     * @return the XML text
     */
    public String toXMLString() {
        StringBuilder sb = new StringBuilder("<MixtureModels>\n");
        for (int i = 0; i < this.distribs.size(); i++) {
            sb.append("<model>\n<weight>" + this.weights.get(i) + "</weight>\n<distrib>" + this.distribs.get(i) + "</distrib>\n</model>\n");
        }
        sb.append("</MixtureModels>");
        return sb.toString();
    }

    /** Small demonstration of building and flattening mixtures of Gaussians. */
    public static void main(String[] strArr) {
        GaussianDistrib gaussianDistrib = new GaussianDistrib(0.0d, 1.0d);
        System.out.println("gd1 = " + gaussianDistrib);
        GaussianDistrib gaussianDistrib2 = new GaussianDistrib(1.0d, 0.5d);
        System.out.println("gd2 = " + gaussianDistrib2);
        GaussianDistrib gaussianDistrib3 = new GaussianDistrib(-2.0d, 2.5d);
        System.out.println("gd3 = " + gaussianDistrib3);
        MixtureDistrib mixtureDistrib = new MixtureDistrib(gaussianDistrib, 1.0d);
        mixtureDistrib.addDistrib(gaussianDistrib2, 2.0d);
        mixtureDistrib.addDistrib(gaussianDistrib2, 0.5d);
        System.out.println("md1 is gd1*1.0 + gd2*2.5 : \n" + mixtureDistrib);
        MixtureDistrib mixtureDistrib2 = new MixtureDistrib(mixtureDistrib, 1.0d);
        System.out.println("mds2 is md1*1.0 : \n" + mixtureDistrib2);
        mixtureDistrib2.addDistrib(gaussianDistrib, 0.5d);
        System.out.println("md2 += gd1*0.5 : \n" + mixtureDistrib2);
        mixtureDistrib2.addDistrib(gaussianDistrib3, 2.0d);
        System.out.println("md2 += gd3*2.0 : \n" + mixtureDistrib2);
        mixtureDistrib2.addDistrib(mixtureDistrib, 2.0d);
        System.out.println("md2 += md1*2.0 : \n" + mixtureDistrib2);
        mixtureDistrib2.addDistrib(gaussianDistrib, 1.5d);
        System.out.println("md2 += gd1*1.5 : \n" + mixtureDistrib2);
        System.out.println("density = " + mixtureDistrib2.density);
    }
}
