package org.encog.neural.networks.training.pso;

import java.lang.reflect.Array;
import java.util.Arrays;

import org.encog.mathutil.VectorAlgebra;
import org.encog.mathutil.randomize.NguyenWidrowRandomizer;
import org.encog.mathutil.randomize.Randomizer;
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.train.BasicTraining;
import org.encog.neural.flat.FlatNetwork;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.structure.NetworkCODEC;
import org.encog.neural.networks.training.CalculateScore;
import org.encog.neural.networks.training.TrainingSetScore;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
import org.encog.util.concurrency.EngineConcurrency;
import org.encog.util.concurrency.TaskGroup;

/* loaded from: classes.dex */
/**
 * Trains a {@link BasicNetwork} with particle swarm optimisation (PSO).
 *
 * <p>Each particle is a copy of the network. Its position is the flat weight
 * vector produced by {@link NetworkCODEC#networkToArray}. On every iteration
 * the velocity of particle {@code i} is updated as
 * {@code v = w*v + c1*r1*(pbest_i - x) + c2*r2*(gbest - x)}. The random
 * factors {@code r1}/{@code r2} come from {@link VectorAlgebra#mulRand}. The
 * velocity is then clamped and added to the position. Each particle is scored
 * with the supplied {@link CalculateScore}. The best position found so far is
 * written back into the network returned by {@link #getMethod()}.
 */
public class NeuralPSO extends BasicTraining {
    /** Personal-best score of each particle; {@code NaN} means "not scored yet". */
    protected double[] m_bestErrors;
    /** Network that receives the global-best weights; returned by {@link #getMethod()}. */
    BasicNetwork m_bestNetwork;
    /** Global-best position (flat weight vector). */
    private double[] m_bestVector;
    /** Index of the particle holding the global best, or -1 before any scoring. */
    protected int m_bestVectorIndex;
    /** Personal-best position of each particle. */
    protected double[][] m_bestVectors;
    /** Cognitive (personal-best) acceleration coefficient. */
    protected double m_c1;
    /** Social (global-best) acceleration coefficient. */
    protected double m_c2;
    protected CalculateScore m_calculateScore;
    protected double m_inertiaWeight;
    /** Per-component position clamp; -1 by default (presumably "no clamp" in VectorAlgebra — confirm). */
    protected double m_maxPosition;
    /** Per-component velocity clamp, also the range for initial random velocities. */
    protected double m_maxVelocity;
    protected boolean m_multiThreaded;
    /** The particles themselves; {@code null} slots are created during initialisation. */
    protected BasicNetwork[] m_networks;
    protected int m_populationSize;
    /** When true, particles steer towards the live personal best of the leader instead of the snapshot. */
    private boolean m_pseudoAsynchronousUpdate;
    protected Randomizer m_randomizer;
    protected VectorAlgebra m_va;
    /** Particle velocities; {@code null} until {@link #initPopulation()} has run. */
    protected double[][] m_velocities;

    /**
     * Creates a PSO trainer.
     *
     * @param network        the network to train; particle 0 starts from its current weights
     * @param randomizer     used to scatter the initial weights of particles 1..n-1
     * @param calculateScore scores a particle (its {@code shouldMinimize()} decides the direction)
     * @param populationSize number of particles
     */
    public NeuralPSO(BasicNetwork network, Randomizer randomizer, CalculateScore calculateScore, int populationSize) {
        super(TrainingImplementationType.Iterative);
        this.m_multiThreaded = true;
        this.m_maxPosition = -1.0d;
        this.m_maxVelocity = 2.0d;
        this.m_c1 = 2.0d;
        this.m_c2 = 2.0d;
        this.m_inertiaWeight = 0.4d;
        this.m_pseudoAsynchronousUpdate = false;
        this.m_populationSize = populationSize;
        this.m_randomizer = randomizer;
        this.m_calculateScore = calculateScore;
        this.m_bestNetwork = network;
        this.m_networks = new BasicNetwork[populationSize];
        this.m_velocities = null;
        // Rows are assigned by updateParticle() during the initialisation sweep.
        this.m_bestVectors = new double[populationSize][];
        this.m_bestErrors = new double[populationSize];
        // 0.0 is a legitimate score, so use NaN to mark "no personal best yet".
        Arrays.fill(this.m_bestErrors, Double.NaN);
        this.m_bestVectorIndex = -1;
        this.m_bestVector = NetworkCODEC.networkToArray(network);
        this.m_va = new VectorAlgebra();
    }

    /**
     * Creates a trainer with a Nguyen-Widrow randomizer, a training-set score
     * and a population of 20 particles.
     *
     * @param network  the network to train
     * @param trainingSet data used to score each particle
     */
    public NeuralPSO(BasicNetwork network, MLDataSet trainingSet) {
        this(network, new NguyenWidrowRandomizer(), new TrainingSetScore(trainingSet), 20);
    }

    /** Pausing/resuming is not supported by this trainer. */
    @Override // org.encog.ml.train.MLTrain
    public boolean canContinue() {
        return false;
    }

    public double getC1() {
        return this.m_c1;
    }

    public double getC2() {
        return this.m_c2;
    }

    /** @return a one-line summary of the swarm parameters */
    public String getDescription() {
        return String.format("pop = %d, w = %.2f, c1 = %.2f, c2 = %.2f, Xmax = %.2f, Vmax = %.2f", Integer.valueOf(this.m_populationSize), Double.valueOf(this.m_inertiaWeight), Double.valueOf(this.m_c1), Double.valueOf(this.m_c2), Double.valueOf(this.m_maxPosition), Double.valueOf(this.m_maxVelocity));
    }

    public double getInertiaWeight() {
        return this.m_inertiaWeight;
    }

    public double getMaxPosition() {
        return this.m_maxPosition;
    }

    public double getMaxVelocity() {
        return this.m_maxVelocity;
    }

    /** @return the network holding the best weights found so far */
    @Override // org.encog.ml.train.MLTrain
    public MLMethod getMethod() {
        return this.m_bestNetwork;
    }

    /** @return a fresh copy of particle {@code particleIndex}'s weight vector */
    protected double[] getNetworkState(int particleIndex) {
        return NetworkCODEC.networkToArray(this.m_networks[particleIndex]);
    }

    public int getPopulationSize() {
        return this.m_populationSize;
    }

    /** Allocates velocities and runs the initialisation sweep, once only. */
    void initPopulation() {
        if (this.m_velocities == null) {
            this.m_velocities = new double[this.m_populationSize][this.m_bestVector.length];
            iterationPSO(true);
        }
    }

    public boolean isMultiThreaded() {
        return this.m_multiThreaded;
    }

    /**
     * @return true if {@code candidate} is strictly better than {@code reference}
     *         in the direction given by {@link CalculateScore#shouldMinimize()}
     */
    boolean isScoreBetter(double candidate, double reference) {
        return this.m_calculateScore.shouldMinimize() ? candidate < reference : candidate > reference;
    }

    /** Performs one PSO iteration, initialising the swarm on the first call. */
    @Override // org.encog.ml.train.MLTrain
    public void iteration() {
        initPopulation();
        preIteration();
        iterationPSO(false);
        postIteration();
    }

    /**
     * Updates every particle, then refreshes the global best.
     *
     * <p>The initialisation sweep ({@code init == true}) always runs on the
     * calling thread. Later sweeps are handed to the Encog task pool when
     * multi-threading is enabled.
     *
     * @param init true for the one-off initialisation sweep
     */
    protected void iterationPSO(boolean init) {
        TaskGroup group = EngineConcurrency.getInstance().createTaskGroup();
        for (int particle = 0; particle < this.m_populationSize; particle++) {
            NeuralPSOWorker worker = new NeuralPSOWorker(this, particle, init);
            if (init || !isMultiThreaded()) {
                worker.run();
            } else {
                EngineConcurrency.getInstance().processTask(worker, group);
            }
        }
        if (isMultiThreaded()) {
            group.waitForComplete();
        }
        updateGlobalBestPosition();
    }

    /** Not supported; always returns {@code null}. */
    @Override // org.encog.ml.train.MLTrain
    public TrainingContinuation pause() {
        return null;
    }

    /** Not supported; does nothing. */
    @Override // org.encog.ml.train.MLTrain
    public void resume(TrainingContinuation trainingContinuation) {
    }

    public void setC1(double c1) {
        this.m_c1 = c1;
    }

    public void setC2(double c2) {
        this.m_c2 = c2;
    }

    public void setInertiaWeight(double inertiaWeight) {
        this.m_inertiaWeight = inertiaWeight;
    }

    /**
     * Supplies the particles to start from. Slots left {@code null} are filled
     * with (randomised) clones of the trained network during initialisation.
     * The array is used directly, not copied. It must have at least
     * {@link #getPopulationSize()} entries.
     */
    public void setInitialPopulation(BasicNetwork[] networks) {
        this.m_networks = networks;
    }

    public void setMaxPosition(double maxPosition) {
        this.m_maxPosition = maxPosition;
    }

    public void setMaxVelocity(double maxVelocity) {
        this.m_maxVelocity = maxVelocity;
    }

    /** Writes {@code state} into particle {@code particleIndex}'s network. */
    protected void setNetworkState(int particleIndex, double[] state) {
        NetworkCODEC.arrayToNetwork(state, this.m_networks[particleIndex]);
    }

    /**
     * Sets the number of particles. Before the first iteration the per-particle
     * storage is grown as needed. Once the swarm has been initialised, growing
     * the population is not supported.
     */
    public void setPopulationSize(int populationSize) {
        if (this.m_velocities == null) {
            ensureParticleCapacity(populationSize);
        }
        this.m_populationSize = populationSize;
    }

    /** Grows the per-particle arrays (keeping existing entries) to hold {@code size} particles. */
    private void ensureParticleCapacity(int size) {
        if (this.m_networks != null && this.m_networks.length < size) {
            this.m_networks = Arrays.copyOf(this.m_networks, size);
        }
        if (this.m_bestVectors.length < size) {
            this.m_bestVectors = Arrays.copyOf(this.m_bestVectors, size);
        }
        if (this.m_bestErrors.length < size) {
            int oldLength = this.m_bestErrors.length;
            this.m_bestErrors = Arrays.copyOf(this.m_bestErrors, size);
            Arrays.fill(this.m_bestErrors, oldLength, size, Double.NaN);
        }
    }

    /**
     * Selects the particle with the best personal-best score. Copies its
     * position into the global best and into {@link #getMethod()}'s network,
     * and reports its score via {@code setError}.
     */
    protected void updateGlobalBestPosition() {
        for (int particle = 0; particle < this.m_populationSize; particle++) {
            if (this.m_bestVectorIndex == -1 || isScoreBetter(this.m_bestErrors[particle], this.m_bestErrors[this.m_bestVectorIndex])) {
                this.m_bestVectorIndex = particle;
            }
        }
        if (this.m_bestVectorIndex != -1) {
            // Refresh even when the leader did not change: the leading particle
            // may have improved its own personal best this iteration.
            this.m_va.copy(this.m_bestVector, this.m_bestVectors[this.m_bestVectorIndex]);
            this.m_bestNetwork.decodeFromArray(this.m_bestVector);
            setError(this.m_bestErrors[this.m_bestVectorIndex]);
        }
    }

    /**
     * Initialises or moves one particle, then updates its personal best.
     * Called by {@code NeuralPSOWorker}, possibly from a pool thread.
     *
     * @param particleIndex the particle to update
     * @param init          true during the initialisation sweep
     */
    public void updateParticle(int particleIndex, boolean init) {
        double[] position;
        if (init) {
            if (this.m_networks[particleIndex] == null) {
                this.m_networks[particleIndex] = (BasicNetwork) this.m_bestNetwork.clone();
                // Particle 0 keeps the caller's starting weights; the rest are scattered.
                if (particleIndex > 0) {
                    this.m_randomizer.randomize(this.m_networks[particleIndex]);
                }
            }
            position = getNetworkState(particleIndex);
            this.m_bestVectors[particleIndex] = position;
            this.m_va.randomise(this.m_velocities[particleIndex], this.m_maxVelocity);
        } else {
            position = getNetworkState(particleIndex);
            updateVelocity(particleIndex, position);
            this.m_va.clampComponents(this.m_velocities[particleIndex], this.m_maxVelocity);
            this.m_va.add(position, this.m_velocities[particleIndex]);
            this.m_va.clampComponents(position, this.m_maxPosition);
            setNetworkState(particleIndex, position);
        }
        updatePersonalBestPosition(particleIndex, position);
    }

    /**
     * Scores particle {@code particleIndex}. Records {@code position} as its
     * personal best if no best exists yet or the new score is strictly better.
     */
    protected void updatePersonalBestPosition(int particleIndex, double[] position) {
        double score = this.m_calculateScore.calculateScore(this.m_networks[particleIndex]);
        // NaN is the "unset" marker; a real score of 0.0 must not be treated as unset.
        if (Double.isNaN(this.m_bestErrors[particleIndex]) || isScoreBetter(score, this.m_bestErrors[particleIndex])) {
            this.m_bestErrors[particleIndex] = score;
            this.m_va.copy(this.m_bestVectors[particleIndex], position);
        }
    }

    /**
     * Applies inertia, the cognitive pull toward the personal best, and (for
     * non-leaders) the social pull toward the global best to the velocity of
     * particle {@code particleIndex}.
     */
    protected void updateVelocity(int particleIndex, double[] position) {
        double[] pull = new double[position.length];
        this.m_va.mul(this.m_velocities[particleIndex], this.m_inertiaWeight);

        this.m_va.copy(pull, this.m_bestVectors[particleIndex]);
        this.m_va.sub(pull, position);
        this.m_va.mulRand(pull, this.m_c1);
        this.m_va.add(this.m_velocities[particleIndex], pull);

        // The leader's global best is its own personal best, so skip the social term.
        if (particleIndex != this.m_bestVectorIndex) {
            this.m_va.copy(pull, this.m_pseudoAsynchronousUpdate ? this.m_bestVectors[this.m_bestVectorIndex] : this.m_bestVector);
            this.m_va.sub(pull, position);
            this.m_va.mulRand(pull, this.m_c2);
            this.m_va.add(this.m_velocities[particleIndex], pull);
        }
    }
}
