package org.encog.neural.networks.training.propagation.scg;

import org.encog.mathutil.BoundNumbers;
import org.encog.ml.data.MLDataSet;
import org.encog.neural.flat.FlatNetwork;
import org.encog.neural.networks.ContainsFlat;
import org.encog.neural.networks.training.propagation.Propagation;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
import org.encog.util.EngineArray;

/* loaded from: classes.dex */
/**
 * Trains a flat network with the scaled conjugate gradient (SCG) method.
 *
 * <p>Each iteration estimates curvature along the current search direction
 * {@code p} from a finite difference of gradients taken a small distance
 * ({@code sigma}) along {@code p}. It then scales that estimate with
 * {@code lambda}, takes a step of size {@code mu / delta}, and accepts or
 * rejects the step using a comparison parameter. Search directions are
 * conjugate-gradient updates of the residual {@code r} (the negative
 * gradient), restarted every {@code weights.length} accepted steps.
 *
 * <p>The trainer keeps its own working copy of the weights. It copies them
 * into the flat network before each gradient evaluation and at the end of
 * every iteration. Pausing and resuming are not supported.
 */
public class ScaledConjugateGradient extends Propagation {
    /** Initial value of the scaling parameter lambda; also restored on restart. */
    protected static final double FIRST_LAMBDA = 1.0E-6d;
    /** Numerator of the finite-difference probe distance (sigma = FIRST_SIGMA / |p|). */
    protected static final double FIRST_SIGMA = 1.0E-4d;

    /** Curvature estimate along p, including the lambda scaling term. */
    private double delta;
    /** Number of iterations since the last restart. */
    private int k;
    /** Current scaling parameter. */
    private double lambda;
    /** Scaling value carried over from the previous iteration (reset to 0 on success). */
    private double lambda2;
    /** Squared magnitude of the search direction p. */
    private double magP;
    /** True until the first gradient evaluation has initialised p and r. */
    private boolean mustInit;
    /** Error at the point the current step started from. */
    private double oldError;
    /** Gradient at the point the current step started from. */
    private final double[] oldGradient;
    /** Weights at the point the current step started from. */
    private final double[] oldWeights;
    /** Current search direction. */
    private final double[] p;
    /** Current residual: the negative of the most recently accepted gradient. */
    private final double[] r;
    /** Whether the next iteration must reset the SCG parameters. */
    private boolean restart;
    /** Whether the last trial step was accepted (comparison parameter >= 0). */
    private boolean success;
    /** Working copy of the network weights. */
    private final double[] weights;

    /**
     * Creates a trainer for the given network and training data.
     *
     * @param containsFlat the network to train; its flat weights are copied into
     *                     this trainer's working array
     * @param mLDataSet    the training data
     */
    public ScaledConjugateGradient(ContainsFlat containsFlat, MLDataSet mLDataSet) {
        super(containsFlat, mLDataSet);
        this.success = true;
        this.delta = 0.0;
        this.lambda2 = 0.0;
        this.lambda = FIRST_LAMBDA;
        this.oldError = 0.0;
        this.magP = 0.0;
        this.restart = false;
        this.weights = EngineArray.arrayCopy(containsFlat.getFlat().getWeights());
        final int length = this.weights.length;
        this.oldWeights = new double[length];
        this.oldGradient = new double[length];
        this.p = new double[length];
        this.r = new double[length];
        // Gradients are evaluated lazily on the first call to iteration().
        this.mustInit = true;
    }

    /**
     * Evaluates the initial gradient.
     *
     * <p>Sets both the search direction and the residual to its negative, and
     * starts the iteration counter at 1.
     */
    private void init() {
        final int length = this.weights.length;
        calculateGradients();
        this.k = 1;
        for (int i = 0; i < length; i++) {
            this.p[i] = this.r[i] = -this.gradients[i];
        }
        this.mustInit = false;
    }

    /**
     * Computes gradients via the superclass, then rescales them.
     *
     * <p>Every component is multiplied by
     * {@code -2 / gradients.length / outputCount}.
     */
    @Override
    public void calculateGradients() {
        final int outputCount = this.network.getFlat().getOutputCount();
        super.calculateGradients();
        // NOTE(review): normalising by the number of weights (gradients.length)
        // rather than the number of training records looks unusual; preserved as-is.
        final double factor = -2.0d / this.gradients.length / outputCount;
        for (int i = 0; i < this.gradients.length; i++) {
            this.gradients[i] *= factor;
        }
    }

    /**
     * Reports that this trainer cannot be paused and continued.
     *
     * @return always {@code false}
     */
    @Override
    public final boolean canContinue() {
        return false;
    }

    /** No additional initialisation is needed; intentionally empty. */
    @Override
    public void initOthers() {
    }

    /**
     * Performs one SCG iteration and writes the resulting weights into the network.
     *
     * <p>If the current search direction has zero magnitude (zero gradient),
     * the iteration leaves the weights unchanged and returns.
     */
    @Override
    public void iteration() {
        if (this.mustInit) {
            init();
        }
        rollIteration();
        final int length = this.weights.length;

        if (this.restart) {
            // Restart the conjugate-direction sequence with fresh SCG parameters.
            this.lambda = FIRST_LAMBDA;
            this.lambda2 = 0.0;
            this.k = 1;
            this.success = true;
            this.restart = false;
        }

        // After an accepted step, recompute curvature information along p.
        if (this.success) {
            this.magP = EngineArray.vectorProduct(this.p, this.p);
            if (this.magP == 0.0) {
                // The search direction has vanished, so there is nowhere to move.
                // Without this guard sigma would be +Infinity and the weights
                // would be overwritten with NaN (0 * Infinity).
                return;
            }
            final double sigma = FIRST_SIGMA / Math.sqrt(this.magP);

            // Remember the starting point so it can be restored if the step is rejected.
            EngineArray.arrayCopy(this.gradients, this.oldGradient);
            EngineArray.arrayCopy(this.weights, this.oldWeights);
            this.oldError = getError();

            // Probe a small distance along p and evaluate the gradient there.
            for (int i = 0; i < length; i++) {
                this.weights[i] += this.p[i] * sigma;
            }
            syncNetworkWeights();
            calculateGradients();

            // Finite-difference estimate of p' * H * p.
            this.delta = 0.0;
            for (int i = 0; i < length; i++) {
                this.delta += this.p[i] * ((this.gradients[i] - this.oldGradient[i]) / sigma);
            }
        }

        // Apply the lambda scaling to the curvature estimate.
        this.delta += (this.lambda - this.lambda2) * this.magP;

        // A non-positive delta means the scaled Hessian is not positive definite: raise lambda.
        if (this.delta <= 0.0) {
            this.lambda2 = 2.0d * (this.lambda - (this.delta / this.magP));
            this.delta = (this.lambda * this.magP) - this.delta;
            this.lambda = this.lambda2;
        }

        // Take the trial step of size alpha = mu / delta from the starting point.
        final double mu = EngineArray.vectorProduct(this.p, this.r);
        final double alpha = mu / this.delta;
        for (int i = 0; i < length; i++) {
            this.weights[i] = this.oldWeights[i] + (this.p[i] * alpha);
        }
        syncNetworkWeights();
        calculateGradients();

        // Comparison parameter relating the actual error reduction to the step size.
        final double gdelta = ((2.0d * this.delta) * (this.oldError - getError())) / (mu * mu);

        if (gdelta >= 0.0) {
            // Step accepted: r becomes the new negative gradient; rsum = r(old) . r(new).
            double rsum = 0.0;
            for (int i = 0; i < length; i++) {
                final double next = -this.gradients[i];
                rsum += this.r[i] * next;
                this.r[i] = next;
            }
            this.lambda2 = 0.0;
            this.success = true;

            if (this.k >= length) {
                // After `length` iterations, restart from the negative gradient.
                this.restart = true;
                EngineArray.arrayCopy(this.r, this.p);
            } else {
                // Build the next conjugate direction.
                final double beta = (EngineArray.vectorProduct(this.r, this.r) - rsum) / mu;
                for (int i = 0; i < length; i++) {
                    this.p[i] = this.r[i] + (this.p[i] * beta);
                }
                this.restart = false;
            }

            if (gdelta >= 0.75d) {
                // Large comparison parameter: reduce the scaling.
                this.lambda *= 0.25d;
            }
        } else {
            // Step rejected: return to the starting point and its error.
            EngineArray.arrayCopy(this.oldWeights, this.weights);
            setError(this.oldError);
            this.lambda2 = this.lambda;
            this.success = false;
        }

        if (gdelta < 0.25d) {
            // Small comparison parameter: increase the scaling.
            this.lambda += (this.delta * (1.0d - gdelta)) / this.magP;
        }
        // BoundNumbers.bound presumably clamps lambda to a safe numeric range — confirm.
        this.lambda = BoundNumbers.bound(this.lambda);
        this.k++;
        syncNetworkWeights();
    }

    /** Copies the working weight array into the flat network. */
    private void syncNetworkWeights() {
        EngineArray.arrayCopy(this.weights, this.network.getFlat().getWeights());
    }

    /**
     * Pausing is not supported by this trainer.
     *
     * @return always {@code null}
     */
    @Override
    public final TrainingContinuation pause() {
        return null;
    }

    /**
     * Resuming is not supported by this trainer; this method does nothing.
     *
     * @param trainingContinuation ignored
     */
    @Override
    public final void resume(TrainingContinuation trainingContinuation) {
    }

    /**
     * Not used by this class's {@link #iteration()}; always returns zero.
     *
     * @param dArr  ignored
     * @param dArr2 ignored
     * @param i     ignored
     * @return always {@code 0.0}
     */
    @Override
    public double updateWeight(double[] dArr, double[] dArr2, int i) {
        return 0.0;
    }
}
