package com.mockturtlesolutions.snifflib.stats;

import com.mockturtlesolutions.snifflib.datatypes.DblMatrix;

/* loaded from: input_file:com/mockturtlesolutions/snifflib/stats/BernoulliDistribution.class */
public class BernoulliDistribution extends ProbabilityDensity {
    /**
     * Creates a Bernoulli distribution with success probability {@code p = 0.5}.
     */
    public BernoulliDistribution() {
        setParam("p", new DblMatrix(Double.valueOf(0.5d)));
    }

    /**
     * Creates a Bernoulli distribution with the given success probability.
     *
     * @param p the success probability; every element must lie in {@code [0, 1]}
     * @throws NullPointerException if {@code p} is {@code null}
     * @throws IllegalArgumentException if any element of {@code p} is outside {@code [0, 1]}
     */
    public BernoulliDistribution(DblMatrix p) {
        this();
        // Validate with a private static helper rather than calling the
        // overridable setP(...) from a constructor.
        checkP(p);
        setParam("p", p);
    }

    /**
     * Validates a candidate success probability.
     *
     * <p>Uses {@code DblMatrix.Any} (as {@link #criticalValue} does) so that a single
     * out-of-range element is enough to reject the argument.
     *
     * @param p the candidate success probability
     * @throws NullPointerException if {@code p} is {@code null}
     * @throws IllegalArgumentException if any element of {@code p} is below 0 or above 1
     */
    private static void checkP(DblMatrix p) {
        if (p == null) {
            throw new NullPointerException("P can not be null.");
        }
        // p == 0 and p == 1 are valid (degenerate) Bernoulli parameters, so only
        // strictly-outside values are rejected.
        if (DblMatrix.test(DblMatrix.Any(p.lt(0.0d)))) {
            throw new IllegalArgumentException("P must be non-negative.");
        }
        if (DblMatrix.test(DblMatrix.Any(p.gt(1.0d)))) {
            throw new IllegalArgumentException("P can not be greater than 1.");
        }
    }

    /**
     * Sets the success probability.
     *
     * @param dblMatrix the success probability; every element must lie in {@code [0, 1]}
     * @throws NullPointerException if {@code dblMatrix} is {@code null}
     * @throws IllegalArgumentException if any element is outside {@code [0, 1]}
     */
    public void setP(DblMatrix dblMatrix) {
        checkP(dblMatrix);
        setParam("p", dblMatrix);
    }

    /**
     * Sets the success probability from a scalar.
     *
     * @param d the success probability, in {@code [0, 1]}
     * @throws IllegalArgumentException if {@code d} is outside {@code [0, 1]}
     */
    public void setP(double d) {
        setP(new DblMatrix(d));
    }

    /**
     * Draws {@code i} independent Bernoulli samples.
     *
     * <p>Each sample is 1 when a fresh {@code DblMatrix.random(1)} draw is strictly less
     * than {@code p}, and 0 otherwise.
     *
     * @param i the number of samples to draw
     * @return a matrix of {@code i} samples, each {@code DblMatrix.ONE} or {@code DblMatrix.ZERO}
     */
    @Override // com.mockturtlesolutions.snifflib.stats.ProbabilityDensity
    public DblMatrix random(int i) {
        DblMatrix samples = new DblMatrix(i);
        DblMatrix p = getParam("p");
        for (int k = 0; k < i; k++) {
            boolean success = DblMatrix.test(DblMatrix.random(1).lt(p));
            samples.setDblAt(success ? DblMatrix.ONE : DblMatrix.ZERO, k);
        }
        return samples;
    }

    /**
     * Returns the variance {@code p * (1 - p)}.
     *
     * @return the variance of the distribution
     */
    public DblMatrix variance() {
        DblMatrix p = getParam("p");
        return p.times(DblMatrix.ONE.minus(p));
    }

    /**
     * Returns the mean, which for a Bernoulli distribution equals {@code p}.
     *
     * @return the mean of the distribution
     */
    public DblMatrix mean() {
        return getParam("p");
    }

    /**
     * Returns the median.
     *
     * <p>The median is 1 when {@code p > 0.5} and {@code DblMatrix.HALF} when
     * {@code p == 0.5}. Otherwise it is 0.
     *
     * @return the median of the distribution
     */
    public DblMatrix median() {
        DblMatrix p = getParam("p");
        if (DblMatrix.test(p.gt(0.5d))) {
            return DblMatrix.ONE;
        }
        if (DblMatrix.test(p.eq(0.5d))) {
            return DblMatrix.HALF;
        }
        return DblMatrix.ZERO;
    }

    /**
     * Returns the mode.
     *
     * <p>The mode is 1 when {@code p > 0.5} and 0 when {@code p < 0.5}. When
     * {@code p == 0.5} both outcomes are modes. A two-element matrix is returned whose
     * element 1 is set to 1.0. This assumes {@code new DblMatrix(2)} is zero-filled, so
     * the result reads {0, 1}. TODO confirm.
     *
     * @return the mode(s) of the distribution
     */
    public DblMatrix mode() {
        DblMatrix p = getParam("p");
        if (DblMatrix.test(p.gt(0.5d))) {
            return DblMatrix.ONE;
        }
        if (DblMatrix.test(p.eq(0.5d))) {
            DblMatrix bothModes = new DblMatrix(2);
            bothModes.setDoubleAt(Double.valueOf(1.0d), 1);
            return bothModes;
        }
        return DblMatrix.ZERO;
    }

    /**
     * Tests which values lie in the support {0, 1}.
     *
     * @param dblMatrix the values to test
     * @return the element-wise result of {@code x == 1 || x == 0}
     */
    @Override // com.mockturtlesolutions.snifflib.stats.ProbabilityDensity
    public DblMatrix hasSupport(DblMatrix dblMatrix) {
        return dblMatrix.eq(1.0d).or(dblMatrix.eq(0.0d));
    }

    /**
     * Evaluates the inverse CDF (quantile function).
     *
     * <p>Each probability {@code q} maps to 1 when {@code q > 1 - p} and to 0 otherwise.
     * The result is the element-wise {@code gt} comparison, presumably 1.0/0.0
     * indicators. Verify against {@code DblMatrix.gt}.
     *
     * @param dblMatrix cumulative probabilities, each in {@code [0, 1]}
     * @return the corresponding quantiles
     * @throws IllegalArgumentException if any element is below 0 or above 1
     */
    @Override // com.mockturtlesolutions.snifflib.stats.ProbabilityDensity, com.mockturtlesolutions.snifflib.stats.InvCDF
    public DblMatrix criticalValue(DblMatrix dblMatrix) {
        if (DblMatrix.test(DblMatrix.Any(dblMatrix.lt(0.0d))) || DblMatrix.test(DblMatrix.Any(dblMatrix.gt(1.0d)))) {
            throw new IllegalArgumentException("Invalid cdf values.");
        }
        DblMatrix failureProb = DblMatrix.ONE.minus(getParam("p"));
        return dblMatrix.gt(failureProb);
    }

    /**
     * Evaluates the cumulative distribution function element-wise.
     *
     * <ul>
     *   <li>{@code x < 0} (or NaN) leaves the entry at its initial value, presumably 0.</li>
     *   <li>{@code 0 <= x < 1} yields {@code 1 - p}.</li>
     *   <li>{@code x >= 1} yields 1.</li>
     * </ul>
     *
     * @param dblMatrix the points at which to evaluate the CDF
     * @return a matrix of the same size holding {@code P(X <= x)}
     */
    @Override // com.mockturtlesolutions.snifflib.stats.ProbabilityDensity, com.mockturtlesolutions.snifflib.stats.CDF
    public DblMatrix cdf(DblMatrix dblMatrix) {
        DblMatrix failureProb = DblMatrix.ONE.minus(getParam("p"));
        DblMatrix result = new DblMatrix(dblMatrix.Size);
        for (int i = 0; i < dblMatrix.getN(); i++) {
            double x = dblMatrix.getDoubleAt(i).doubleValue();
            if (x >= 1.0d) {
                result.setDblAt(DblMatrix.ONE, i);
            } else if (x >= 0.0d) {
                result.setDblAt(failureProb, i);
            }
        }
        return result;
    }

    /**
     * Evaluates the probability mass function element-wise.
     *
     * <p>{@code x == 0} yields {@code 1 - p} and {@code x == 1} yields {@code p}. Any
     * other value leaves the entry at its initial value, presumably 0.
     *
     * @param dblMatrix the points at which to evaluate the mass function
     * @return a matrix of the same size holding {@code P(X = x)}
     */
    @Override // com.mockturtlesolutions.snifflib.stats.ProbabilityDensity
    public DblMatrix pdf(DblMatrix dblMatrix) {
        DblMatrix result = new DblMatrix(dblMatrix.Size);
        DblMatrix p = getParam("p");
        DblMatrix failureProb = DblMatrix.ONE.minus(p);
        for (int i = 0; i < dblMatrix.getN(); i++) {
            double x = dblMatrix.getDoubleAt(i).doubleValue();
            if (x == 0.0d) {
                result.setDblAt(failureProb, i);
            } else if (x == 1.0d) {
                result.setDblAt(p, i);
            }
        }
        return result;
    }
}
