package org.encog.ml.hmm.alog;

import java.lang.reflect.Array;
import java.util.EnumSet;
import java.util.Iterator;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.hmm.HiddenMarkovModel;
import org.encog.neural.flat.FlatNetwork;

/* loaded from: classes.dex */
/**
 * Computes the forward ({@code alpha}) and/or backward ({@code beta}) variables of a hidden
 * Markov model for an observation sequence, and from them the probability of the sequence.
 *
 * <p>Notation used below: {@code pi(i) = hmm.getPi(i)},
 * {@code a(i, j) = hmm.getTransitionProbability(i, j)} and
 * {@code b_j(o) = hmm.getStateDistribution(j).probability(o)}.
 *
 * <p>No scaling is applied: every entry is a product of probabilities, so for long sequences
 * the values can underflow to {@code 0.0}.
 */
public class ForwardBackwardCalculator {

    /** Forward variables indexed {@code [time][state]}; {@code null} until computed. */
    protected double[][] alpha;

    /** Backward variables indexed {@code [time][state]}; {@code null} until computed. */
    protected double[][] beta;

    /** Probability of the observation sequence; {@code 0.0} until computed. */
    protected double probability;

    /** Selects which pass(es) the calculator performs. */
    public enum Computation {
        /** Compute the forward variables. */
        ALPHA,
        /** Compute the backward variables. */
        BETA
    }

    /**
     * Creates a calculator with nothing computed, for subclasses that fill the arrays
     * themselves. (It was protected in the original source; it stays public for
     * compatibility.)
     */
    public ForwardBackwardCalculator() {
        // alpha and beta default to null, probability to 0.0.
    }

    /**
     * Computes the forward variables and the sequence probability.
     *
     * @param mLDataSet         the observation sequence; must contain at least one element
     * @param hiddenMarkovModel the model
     * @throws IllegalArgumentException if the sequence is empty
     */
    public ForwardBackwardCalculator(final MLDataSet mLDataSet,
            final HiddenMarkovModel hiddenMarkovModel) {
        this(mLDataSet, hiddenMarkovModel, EnumSet.of(Computation.ALPHA));
    }

    /**
     * Computes the requested passes and the sequence probability.
     *
     * <p>The probability is taken from the alpha pass when {@link Computation#ALPHA} is
     * requested, otherwise from the beta pass.
     *
     * <p>NOTE: the overridable hooks {@link #computeAlpha} and {@link #computeBeta} are
     * invoked from this constructor. A subclass override therefore runs before that
     * subclass's own field initializers.
     *
     * @param mLDataSet         the observation sequence; must contain at least one element
     * @param hiddenMarkovModel the model
     * @param enumSet           which passes to compute; must contain ALPHA, BETA or both
     * @throws IllegalArgumentException if the sequence is empty or {@code enumSet} requests
     *                                  neither pass
     */
    public ForwardBackwardCalculator(final MLDataSet mLDataSet,
            final HiddenMarkovModel hiddenMarkovModel, final EnumSet<Computation> enumSet) {
        if (mLDataSet.size() < 1) {
            throw new IllegalArgumentException("Empty sequence");
        }
        final boolean wantAlpha = enumSet.contains(Computation.ALPHA);
        final boolean wantBeta = enumSet.contains(Computation.BETA);
        if (!wantAlpha && !wantBeta) {
            // The probability can only be derived from alpha or beta. Without either,
            // computeProbability would dereference a null beta array.
            throw new IllegalArgumentException(
                    "At least one of ALPHA or BETA must be requested");
        }
        if (wantAlpha) {
            computeAlpha(hiddenMarkovModel, mLDataSet);
        }
        if (wantBeta) {
            computeBeta(hiddenMarkovModel, mLDataSet);
        }
        computeProbability(mLDataSet, hiddenMarkovModel, enumSet);
    }

    /**
     * Sets {@link #probability} to the sequence probability.
     *
     * <ul>
     *   <li>With alpha: {@code sum_i alpha[T-1][i]}.</li>
     *   <li>Otherwise: {@code sum_i pi(i) * b_i(o_0) * beta[0][i]}.</li>
     * </ul>
     */
    private void computeProbability(final MLDataSet sequence, final HiddenMarkovModel hmm,
            final EnumSet<Computation> flags) {
        final int states = hmm.getStateCount();
        double total = 0.0;
        if (flags.contains(Computation.ALPHA)) {
            final double[] last = this.alpha[sequence.size() - 1];
            for (int i = 0; i < states; i++) {
                total += last[i];
            }
        } else {
            final MLDataPair first = sequence.get(0);
            for (int i = 0; i < states; i++) {
                total += hmm.getPi(i) * hmm.getStateDistribution(i).probability(first)
                        * this.beta[0][i];
            }
        }
        this.probability = total;
    }

    /**
     * Returns {@code alpha[t][state]}.
     *
     * @param i  time index
     * @param i2 state index
     * @return the forward variable
     * @throws UnsupportedOperationException if the alpha pass was not computed
     */
    public double alphaElement(final int i, final int i2) {
        if (this.alpha == null) {
            throw new UnsupportedOperationException("Alpha array has not been computed");
        }
        return this.alpha[i][i2];
    }

    /**
     * Returns {@code beta[t][state]}.
     *
     * @param i  time index
     * @param i2 state index
     * @return the backward variable
     * @throws UnsupportedOperationException if the beta pass was not computed
     */
    public double betaElement(final int i, final int i2) {
        if (this.beta == null) {
            throw new UnsupportedOperationException("Beta array has not been computed");
        }
        return this.beta[i][i2];
    }

    /**
     * Allocates and fills {@link #alpha} with the forward recursion. Time 0 is filled via
     * {@link #computeAlphaInit} and every later step via {@link #computeAlphaStep}.
     *
     * @param hmm      the model
     * @param sequence the observation sequence
     */
    protected void computeAlpha(final HiddenMarkovModel hmm, final MLDataSet sequence) {
        final int length = sequence.size();
        final int states = hmm.getStateCount();
        this.alpha = new double[length][states];

        final MLDataPair first = sequence.get(0);
        for (int i = 0; i < states; i++) {
            computeAlphaInit(hmm, first, i);
        }

        // Walk the remaining observations with an iterator. The first one was already
        // consumed by the initialisation step, so skip it.
        final Iterator<MLDataPair> observations = sequence.iterator();
        if (observations.hasNext()) {
            observations.next();
        }
        for (int t = 1; t < length; t++) {
            final MLDataPair observation = observations.next();
            for (int j = 0; j < states; j++) {
                computeAlphaStep(hmm, observation, t, j);
            }
        }
    }

    /**
     * Sets {@code alpha[0][i] = pi(i) * b_i(o_0)}. This is a subclass hook; it was protected
     * in the original source.
     *
     * @param hiddenMarkovModel the model
     * @param mLDataPair        the first observation
     * @param i                 state index
     */
    public void computeAlphaInit(final HiddenMarkovModel hiddenMarkovModel,
            final MLDataPair mLDataPair, final int i) {
        this.alpha[0][i] = hiddenMarkovModel.getPi(i)
                * hiddenMarkovModel.getStateDistribution(i).probability(mLDataPair);
    }

    /**
     * Sets {@code alpha[t][j] = b_j(o_t) * sum_i alpha[t-1][i] * a(i, j)}. This is a subclass
     * hook; it was protected in the original source.
     *
     * @param hiddenMarkovModel the model
     * @param mLDataPair        the observation at time {@code t}
     * @param i                 time index {@code t} (must be at least 1)
     * @param i2                destination state {@code j}
     */
    public void computeAlphaStep(final HiddenMarkovModel hiddenMarkovModel,
            final MLDataPair mLDataPair, final int i, final int i2) {
        final double[] previous = this.alpha[i - 1];
        double sum = 0.0;
        for (int from = 0; from < hiddenMarkovModel.getStateCount(); from++) {
            sum += previous[from] * hiddenMarkovModel.getTransitionProbability(from, i2);
        }
        this.alpha[i][i2] = hiddenMarkovModel.getStateDistribution(i2).probability(mLDataPair)
                * sum;
    }

    /**
     * Allocates and fills {@link #beta} with the backward recursion. The last row is set to
     * 1.0, and earlier rows are filled backwards via {@link #computeBetaStep}.
     *
     * @param hmm      the model
     * @param sequence the observation sequence
     */
    protected void computeBeta(final HiddenMarkovModel hmm, final MLDataSet sequence) {
        final int length = sequence.size();
        final int states = hmm.getStateCount();
        this.beta = new double[length][states];

        for (int i = 0; i < states; i++) {
            this.beta[length - 1][i] = 1.0;
        }
        for (int t = length - 2; t >= 0; t--) {
            // beta at time t depends on the observation at time t + 1.
            final MLDataPair next = sequence.get(t + 1);
            for (int i = 0; i < states; i++) {
                computeBetaStep(hmm, next, t, i);
            }
        }
    }

    /**
     * Sets {@code beta[t][i] = sum_j beta[t+1][j] * a(i, j) * b_j(o_{t+1})}. This is a
     * subclass hook; it was protected in the original source.
     *
     * @param hiddenMarkovModel the model
     * @param mLDataPair        the observation at time {@code t + 1}
     * @param i                 time index {@code t}
     * @param i2                source state {@code i}
     */
    public void computeBetaStep(final HiddenMarkovModel hiddenMarkovModel,
            final MLDataPair mLDataPair, final int i, final int i2) {
        final double[] following = this.beta[i + 1];
        double sum = 0.0;
        for (int to = 0; to < hiddenMarkovModel.getStateCount(); to++) {
            sum += following[to] * hiddenMarkovModel.getTransitionProbability(i2, to)
                    * hiddenMarkovModel.getStateDistribution(to).probability(mLDataPair);
        }
        this.beta[i][i2] = sum;
    }

    /**
     * Returns the probability of the observation sequence computed by the constructor.
     *
     * @return the sequence probability, or {@code 0.0} if nothing was computed
     */
    public double probability() {
        return this.probability;
    }
}
