package com.mockturtlesolutions.snifflib.util;

import com.mockturtlesolutions.snifflib.datatypes.DblMatrix;
import com.mockturtlesolutions.snifflib.datatypes.Subscript;
import com.mockturtlesolutions.snifflib.invprobs.StatisticalModel;

/* loaded from: input_file:com/mockturtlesolutions/snifflib/util/Gradient.class */
public class Gradient {
    public static final int GRADIENT_ASSUME_ASYMMETRIC = 0;
    public static final int GRADIENT_ASSUME_SYMMETRIC = 1;
    public static final int GRADIENT_RETURN_FULL = 2;
    public static final int GRADIENT_RETURN_VECTOR = 3;
    public static final int GRADIENT_WRT_PARAMETERS = 4;
    public static final int GRADIENT_WRT_X = 5;
    public static final int GRADIENT_FINITE_DIFFERENCE = 6;

    /** Step used for finite differences taken with respect to the X inputs of a model. */
    private static final double X_STEP = 0.01d;
    /** Step used for a model parameter whose current value tests as zero. */
    private static final double ZERO_PARAM_STEP = 1.0E-6d;
    /** Relative step (fraction of |parameter|) used for non-zero model parameters. */
    private static final double RELATIVE_PARAM_STEP = 0.05d;

    /** Coordinate arrays (one per dimension / model input). */
    private DblMatrix[] X;
    /** Unused in this class; the only field assignment to it is in the single-argument constructor. */
    private DblMatrix[] XX;
    /** Sampled data; when non-null, getGradient differentiates this grid instead of FUN. */
    private DblMatrix Y;
    /** Model to differentiate when Y is null. */
    private StatisticalModel FUN;
    private int gradient_method = GRADIENT_FINITE_DIFFERENCE;
    // NOTE(review): symmetry_type and return_type are stored but not read by getGradient;
    // callers pass the equivalent values to full(...) explicitly.
    private int symmetry_type = GRADIENT_ASSUME_SYMMETRIC;
    private int return_type = GRADIENT_RETURN_FULL;
    private int calculate_wrt = GRADIENT_WRT_X;
    private int gradient_order = 1;

    /**
     * No-op constructor.
     * NOTE(review): the arguments are ignored and no fields are set, so getGradient() will
     * throw "Function is null." on an instance built this way. TODO confirm intended X/Y roles.
     */
    public Gradient(DblMatrix dblMatrix, DblMatrix dblMatrix2) {
    }

    /**
     * No-op constructor.
     * NOTE(review): the arguments are ignored and no fields are set, so getGradient() will
     * throw "Function is null." on an instance built this way. TODO confirm intended X/Y roles.
     */
    public Gradient(DblMatrix[] dblMatrixArr, DblMatrix dblMatrix) {
    }

    /**
     * Prepares to differentiate sampled data {@code Y} on a unit-spaced index grid.
     *
     * <p>The number of grid dimensions is {@code Y.Size.length} when Y has more than two
     * dimensions, 1 when Y is a column ({@code Size[1] == 1}), and 2 otherwise. Each dimension
     * k gets coordinates {@code 0 .. Size[k]-1}, expanded with {@link DblMatrix#grid}.
     *
     * @param dblMatrix the sampled values to differentiate
     */
    public Gradient(DblMatrix dblMatrix) {
        this.Y = dblMatrix;
        int[] size = this.Y.Size;
        int dims = size.length > 2 ? size.length : (size[1] == 1 ? 1 : 2);
        DblMatrix[] axes = new DblMatrix[dims];
        for (int k = 0; k < dims; k++) {
            axes[k] = DblMatrix.span(0, size[k] - 1, size[k]);
        }
        this.X = DblMatrix.grid(axes);
    }

    /**
     * Prepares to differentiate a statistical model.
     *
     * @param statisticalModel the model evaluated via {@code getValueAt}
     * @param dblMatrixArr     the X inputs; used when differentiating with respect to X
     */
    public Gradient(StatisticalModel statisticalModel, DblMatrix[] dblMatrixArr) {
        this.FUN = statisticalModel;
        this.X = dblMatrixArr;
    }

    /** Stores the symmetry flag (not read by getGradient). */
    public void setSymmetry(int symmetry) {
        this.symmetry_type = symmetry;
    }

    /**
     * Selects what to differentiate with respect to:
     * {@link #GRADIENT_WRT_PARAMETERS} or {@link #GRADIENT_WRT_X}.
     */
    public void calculateWRT(int wrt) {
        this.calculate_wrt = wrt;
    }

    /** Stores the requested return layout (not read by getGradient). */
    public void setReturnType(int returnType) {
        // Previously this wrote calculate_wrt, silently changing the derivative target.
        this.return_type = returnType;
    }

    /** Sets the derivative order; getGradient supports 1 and 2. */
    public void setGradientOrder(int order) {
        this.gradient_order = order;
    }

    /**
     * Assembles separate derivative components into one 3-D matrix.
     *
     * <ul>
     *   <li>{@code order == 1}: the n components are stacked into a {@code 1 x n x N} matrix,
     *       component k at column k.</li>
     *   <li>{@code order == 2}: the components are read as the upper triangle (i <= j, row-major)
     *       of an {@code m x m} matrix, i.e. {@code m(m+1)/2} components, producing
     *       {@code m x m x N}. When {@code symmetry == GRADIENT_ASSUME_SYMMETRIC} the lower
     *       triangle is mirrored; otherwise it is left as constructed.</li>
     * </ul>
     * N is {@code getN()} of the first component; every component is reshaped to {@code 1x1xN}.
     *
     * @param components derivative components, e.g. from {@link #getGradient()}
     * @param order      layout selector (1 or 2)
     * @param symmetry   {@link #GRADIENT_ASSUME_SYMMETRIC} to mirror the upper triangle
     * @return the assembled matrix, or {@code null} when {@code order} is neither 1 nor 2
     * @throws IllegalArgumentException if components is null/empty, or (order 2) its length is
     *                                  not a triangular number
     */
    public static DblMatrix full(DblMatrix[] components, int order, int symmetry) {
        if (components == null || components.length == 0) {
            throw new IllegalArgumentException("No gradient components supplied.");
        }
        int depth = components[0].getN();
        int[] cellShape = {1, 1, depth};

        if (order == 1) {
            int[] size = {1, components.length, depth};
            DblMatrix out = new DblMatrix(size);
            Subscript[] sub = Subscript.spanningSet(size);
            for (int k = 0; k < components.length; k++) {
                sub[1].setStart(k);
                sub[1].setStop(k);
                out.setSubMatrix(components[k].reshape(cellShape), sub);
            }
            return out;
        }

        if (order == 2) {
            // Solve m(m+1)/2 == length for m, then verify exactly to reject malformed input.
            int dim = (int) Math.round((-1.0d + Math.sqrt(1.0d + 8.0d * components.length)) / 2.0d);
            if (dim * (dim + 1) / 2 != components.length) {
                throw new IllegalArgumentException("Number of second-order components ("
                        + components.length + ") is not a triangular number.");
            }
            int[] size = {dim, dim, depth};
            DblMatrix out = new DblMatrix(size);
            Subscript[] sub = Subscript.spanningSet(size);
            int k = 0;
            for (int r = 0; r < dim; r++) {
                for (int c = r; c < dim; c++) {
                    DblMatrix cell = components[k].reshape(cellShape);
                    sub[0].setStart(r);
                    sub[0].setStop(r);
                    sub[1].setStart(c);
                    sub[1].setStop(c);
                    out.setSubMatrix(cell, sub);
                    if (r != c && symmetry == GRADIENT_ASSUME_SYMMETRIC) {
                        sub[0].setStart(c);
                        sub[0].setStop(c);
                        sub[1].setStart(r);
                        sub[1].setStop(r);
                        out.setSubMatrix(cell, sub);
                    }
                    k++;
                }
            }
            return out;
        }
        return null;
    }

    /**
     * Computes first- or second-order finite-difference derivatives.
     *
     * <p>If sampled data Y was supplied, differences are taken along each grid dimension.
     * Otherwise the model is differentiated with respect to its parameters
     * ({@link #GRADIENT_WRT_PARAMETERS}) or its X inputs. Second-order results are returned
     * as the upper triangle (i <= j, row-major) of the Hessian; see {@link #full}.
     * The X inputs and model parameters are left unchanged.
     *
     * @return one matrix per derivative component; {@code null} in the model case when the
     *         order is neither 1 nor 2 (an array of nulls in the data case)
     * @throws IllegalArgumentException for an unknown method, missing model, or missing X
     */
    public DblMatrix[] getGradient() {
        if (this.gradient_method != GRADIENT_FINITE_DIFFERENCE) {
            throw new IllegalArgumentException("Unknown gradient method," + this.gradient_method + ".");
        }
        if (this.Y != null) {
            return dataGradient();
        }
        if (this.FUN == null) {
            throw new IllegalArgumentException("Function is null.");
        }
        if (this.gradient_order != 1 && this.gradient_order != 2) {
            return null;
        }
        if (this.calculate_wrt == GRADIENT_WRT_PARAMETERS) {
            return parameterGradient();
        }
        if (this.X == null) {
            throw new IllegalArgumentException("No X data has been specified to calculate gradient.");
        }
        return xGradient();
    }

    /** Differences of the sampled data Y along each grid dimension (orders 1 and 2). */
    private DblMatrix[] dataGradient() {
        DblMatrix[] result = new DblMatrix[this.X.length];
        if (this.gradient_order == 1) {
            Subscript[] upper = Subscript.spanningSet(this.X[0].Size);
            Subscript[] lower = Subscript.spanningSet(this.X[0].Size);
            for (int d = 0; d < this.X.length; d++) {
                // Capture the full extent before mutating the subscript.
                int n = upper[d].Value.getN();
                upper[d].setStart(1);    // indices 1 .. n-1
                lower[d].setStop(n - 2); // indices 0 .. n-2
                DblMatrix dx = this.X[d].getSubMatrix(upper).minus(this.X[d].getSubMatrix(lower));
                result[d] = this.Y.getSubMatrix(upper).minus(this.Y.getSubMatrix(lower)).divideBy(dx);
                upper[d].setStart(0);
                lower[d].setStop(n - 1);
            }
        } else if (this.gradient_order == 2) {
            Subscript[] upper = Subscript.spanningSet(this.X[0].Size);
            Subscript[] middle = Subscript.spanningSet(this.X[0].Size);
            Subscript[] lower = Subscript.spanningSet(this.X[0].Size);
            for (int d = 0; d < this.X.length; d++) {
                int n = upper[d].Value.getN();
                upper[d].setStart(2);     // indices 2 .. n-1
                middle[d].setStart(1);    // indices 1 .. n-2
                middle[d].setStop(n - 2);
                lower[d].setStop(n - 3);  // indices 0 .. n-3
                DblMatrix dx = this.X[d].getSubMatrix(upper).minus(this.X[d].getSubMatrix(middle));
                // Three-point stencil (y[k+1] - 2 y[k] + y[k-1]) / h^2 (uniform-spacing form).
                result[d] = this.Y.getSubMatrix(upper)
                        .minus(this.Y.getSubMatrix(middle).times(2.0d))
                        .plus(this.Y.getSubMatrix(lower))
                        .divideBy(dx.times(dx));
                upper[d].setStart(0);
                middle[d].setStart(0);
                middle[d].setStop(n - 1);
                lower[d].setStop(n - 1);
            }
        }
        return result;
    }

    /** Central differences of the model output with respect to each named parameter. */
    private DblMatrix[] parameterGradient() {
        String[] names = this.FUN.parameterSet();
        if (this.gradient_order == 1) {
            DblMatrix[] result = new DblMatrix[names.length];
            for (int k = 0; k < names.length; k++) {
                String name = names[k];
                DblMatrix p = this.FUN.getParam(name);
                DblMatrix h = parameterStep(p);
                try {
                    this.FUN.setParam(name, p.plus(h));
                    DblMatrix fPlus = evaluateModel();
                    this.FUN.setParam(name, p.minus(h));
                    DblMatrix fMinus = evaluateModel();
                    result[k] = fPlus.minus(fMinus).divideBy(h.times(2));
                } finally {
                    this.FUN.setParam(name, p);
                }
            }
            return result;
        }

        DblMatrix[] result = new DblMatrix[(names.length * (names.length + 1)) / 2];
        int k = 0;
        for (int a = 0; a < names.length; a++) {
            for (int b = a; b < names.length; b++) {
                result[k++] = (a == b)
                        ? parameterSecondDerivative(names[a])
                        : parameterMixedDerivative(names[a], names[b]);
            }
        }
        return result;
    }

    /** (f(p+h) - 2 f(p) + f(p-h)) / h^2 for one parameter, restoring it afterwards. */
    private DblMatrix parameterSecondDerivative(String name) {
        DblMatrix p = this.FUN.getParam(name);
        DblMatrix h = parameterStep(p);
        DblMatrix fCenter = evaluateModel();
        try {
            this.FUN.setParam(name, p.plus(h));
            DblMatrix fPlus = evaluateModel();
            this.FUN.setParam(name, p.minus(h));
            DblMatrix fMinus = evaluateModel();
            return fPlus.minus(fCenter.times(2)).plus(fMinus).divideBy(h.times(h));
        } finally {
            this.FUN.setParam(name, p);
        }
    }

    /** (f(++) - f(+-) - f(-+) + f(--)) / (4 ha hb) for two parameters, restoring both afterwards. */
    private DblMatrix parameterMixedDerivative(String nameA, String nameB) {
        DblMatrix pa = this.FUN.getParam(nameA);
        DblMatrix pb = this.FUN.getParam(nameB);
        DblMatrix ha = parameterStep(pa);
        DblMatrix hb = parameterStep(pb);
        try {
            this.FUN.setParam(nameA, pa.plus(ha));
            this.FUN.setParam(nameB, pb.plus(hb));
            DblMatrix fPP = evaluateModel();
            this.FUN.setParam(nameB, pb.minus(hb));
            DblMatrix fPM = evaluateModel();
            this.FUN.setParam(nameA, pa.minus(ha));
            DblMatrix fMM = evaluateModel();
            this.FUN.setParam(nameB, pb.plus(hb));
            DblMatrix fMP = evaluateModel();
            return fPP.minus(fPM).minus(fMP).plus(fMM).divideBy(ha.times(hb).times(4));
        } finally {
            this.FUN.setParam(nameA, pa);
            this.FUN.setParam(nameB, pb);
        }
    }

    /** Central differences of the model output with respect to each X input (orders 1 and 2). */
    private DblMatrix[] xGradient() {
        DblMatrix h = new DblMatrix(Double.valueOf(X_STEP));
        int n = this.X.length;
        if (this.gradient_order == 1) {
            DblMatrix twoH = h.plus(h);
            DblMatrix[] result = new DblMatrix[n];
            for (int i = 0; i < n; i++) {
                // Perturb a copy so that this.X is never modified.
                DblMatrix[] point = this.X.clone();
                point[i] = this.X[i].plus(h);
                DblMatrix fPlus = this.FUN.getValueAt(point);
                point[i] = this.X[i].minus(h);
                DblMatrix fMinus = this.FUN.getValueAt(point);
                result[i] = fPlus.minus(fMinus).divideBy(twoH);
            }
            return result;
        }

        DblMatrix hSquared = h.times(h);
        DblMatrix[] result = new DblMatrix[(n * (n + 1)) / 2];
        int k = 0;
        for (int i = 0; i < n; i++) {
            for (int j = i; j < n; j++) {
                DblMatrix[] point = this.X.clone();
                if (i == j) {
                    DblMatrix fCenter = this.FUN.getValueAt(this.X);
                    point[i] = this.X[i].plus(h);
                    DblMatrix fPlus = this.FUN.getValueAt(point);
                    point[i] = this.X[i].minus(h);
                    DblMatrix fMinus = this.FUN.getValueAt(point);
                    result[k] = fPlus.minus(fCenter.times(2)).plus(fMinus).divideBy(hSquared);
                } else {
                    point[i] = this.X[i].plus(h);
                    point[j] = this.X[j].plus(h);
                    DblMatrix fPP = this.FUN.getValueAt(point);
                    point[j] = this.X[j].minus(h);
                    DblMatrix fPM = this.FUN.getValueAt(point);
                    point[i] = this.X[i].minus(h);
                    DblMatrix fMM = this.FUN.getValueAt(point);
                    point[j] = this.X[j].plus(h);
                    DblMatrix fMP = this.FUN.getValueAt(point);
                    result[k] = fPP.minus(fPM).minus(fMP).plus(fMM).divideBy(hSquared.times(4));
                }
                k++;
            }
        }
        return result;
    }

    /** Evaluates the model at its own stored X. */
    private DblMatrix evaluateModel() {
        return this.FUN.getValueAt(this.FUN.getX());
    }

    /** Step for a parameter: a fixed small step when it tests as zero, else 5% of |p|. */
    private static DblMatrix parameterStep(DblMatrix p) {
        return DblMatrix.test(p.eq(0.0d))
                ? new DblMatrix(ZERO_PARAM_STEP)
                : DblMatrix.abs(p.times(RELATIVE_PARAM_STEP));
    }
}
