package opennlp.tools.ml.perceptron;

import opennlp.tools.ml.AbstractEventTrainer;
import opennlp.tools.ml.ArrayMath;
import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.DataIndexer;
import opennlp.tools.ml.model.EvalParameters;
import opennlp.tools.ml.model.MutableContext;
import opennlp.tools.util.TrainingParameters;

/* loaded from: classes2.dex */
/**
 * Trains a multi-class perceptron model from the events exposed by a {@link DataIndexer}.
 *
 * <p>Each iteration replays every event (once per time it was seen), predicts the highest-scoring
 * outcome with the current weights and, on a mistake, adds the step size (scaled by the feature
 * value when values are present) to the correct outcome's weight of every active predicate and
 * subtracts it from the predicted outcome's weight. Training stops after the requested number of
 * iterations, or earlier once the training accuracy is within the tolerance of each of the three
 * previous iterations' accuracies. Optionally the returned weights are the average of the weights
 * seen after each (selected) iteration.
 */
public class PerceptronTrainer extends AbstractEventTrainer {
    public static final String PERCEPTRON_VALUE = "PERCEPTRON";
    public static final double TOLERANCE_DEFAULT = 1.0E-5d;

    /** With skipped averaging, every iteration below this number is still added to the average. */
    private static final int SKIPPED_AVERAGING_WARMUP = 20;

    private int[][] contexts;
    private int numEvents;
    private int numOutcomes;
    private int numPreds;
    private int[] numTimesEventsSeen;
    private int numUniqueEvents;
    private String[] outcomeLabels;
    private int[] outcomeList;
    private String[] predLabels;
    /** Per-iteration step size decrease; {@code null} keeps the step size constant at 1. */
    private Double stepSizeDecrease;
    private double tolerance = TOLERANCE_DEFAULT;
    private boolean useSkippedAveraging;
    private float[][] values;

    public PerceptronTrainer() {
    }

    public PerceptronTrainer(TrainingParameters trainingParameters) {
        super(trainingParameters);
    }

    /**
     * Runs the perceptron training loop over the data loaded by {@link #trainModel}.
     *
     * @param iterations maximum number of passes over the training events
     * @param useAverage whether to return the averaged weights instead of the final ones
     * @return one context per predicate holding a weight for every outcome
     */
    private MutableContext[] findParameters(int iterations, boolean useAverage) {
        // Every predicate carries a weight for every outcome, so all contexts share one pattern.
        int[] allOutcomesPattern = new int[this.numOutcomes];
        for (int oi = 0; oi < this.numOutcomes; oi++) {
            allOutcomesPattern[oi] = oi;
        }

        // Current weights, updated in place; evalParams wraps these same context objects.
        MutableContext[] params = newZeroedContexts(allOutcomesPattern);
        EvalParameters evalParams = new EvalParameters(params, this.numOutcomes);

        // Running sum of the weights after each averaged iteration (only needed when averaging).
        MutableContext[] summedParams = useAverage ? newZeroedContexts(allOutcomesPattern) : null;

        double stepSize = 1.0d;
        // Accuracies of the three preceding iterations, oldest first.
        double prevAccuracy1 = 0.0d;
        double prevAccuracy2 = 0.0d;
        double prevAccuracy3 = 0.0d;
        int numTimesSummed = 0;

        for (int iteration = 1; iteration <= iterations; iteration++) {
            if (this.stepSizeDecrease != null) {
                stepSize *= 1.0d - this.stepSizeDecrease;
            }

            int numCorrect = 0;
            for (int ei = 0; ei < this.numUniqueEvents; ei++) {
                int targetOutcome = this.outcomeList[ei];
                // Merged duplicate events are replayed once per occurrence.
                for (int ni = 0; ni < this.numTimesEventsSeen[ei]; ni++) {
                    int predictedOutcome = predict(ei, evalParams);
                    if (predictedOutcome == targetOutcome) {
                        numCorrect++;
                    } else {
                        updateWeights(params, ei, targetOutcome, predictedOutcome, stepSize);
                    }
                }
            }

            // Cast before dividing: int / int would truncate the accuracy to 0 (or 1), which made
            // the convergence test below fire right after the first iteration.
            double trainingAccuracy = (double) numCorrect / this.numEvents;

            // Skipped averaging only sums the early iterations and the perfect-square ones;
            // plain averaging sums every iteration.
            boolean sumThisIteration = useAverage
                    && (!this.useSkippedAveraging
                        || iteration < SKIPPED_AVERAGING_WARMUP
                        || isPerfectSquare(iteration));
            if (sumThisIteration) {
                numTimesSummed++;
                for (int pi = 0; pi < this.numPreds; pi++) {
                    double[] current = params[pi].getParameters();
                    for (int oi = 0; oi < this.numOutcomes; oi++) {
                        summedParams[pi].updateParameter(oi, current[oi]);
                    }
                }
            }

            // Converged once the accuracy barely moves relative to each of the last three iterations.
            if (StrictMath.abs(prevAccuracy1 - trainingAccuracy) < this.tolerance
                    && StrictMath.abs(prevAccuracy2 - trainingAccuracy) < this.tolerance
                    && StrictMath.abs(prevAccuracy3 - trainingAccuracy) < this.tolerance) {
                break;
            }
            prevAccuracy1 = prevAccuracy2;
            prevAccuracy2 = prevAccuracy3;
            prevAccuracy3 = trainingAccuracy;
        }

        // Final pass with the last (non-averaged) weights; the accuracy it returns is not used here.
        trainingStats(evalParams);

        if (!useAverage) {
            return params;
        }
        // With no summed iteration (iterations < 1) the sums are all zero; skip the 0/0 division
        // that would otherwise fill the model with NaN.
        if (numTimesSummed > 0) {
            for (int pi = 0; pi < this.numPreds; pi++) {
                MutableContext summed = summedParams[pi];
                for (int oi = 0; oi < this.numOutcomes; oi++) {
                    summed.setParameter(oi, summed.getParameters()[oi] / numTimesSummed);
                }
            }
        }
        return summedParams;
    }

    /** Creates one context per predicate, each with an explicitly zeroed weight per outcome. */
    private MutableContext[] newZeroedContexts(int[] outcomePattern) {
        MutableContext[] created = new MutableContext[this.numPreds];
        for (int pi = 0; pi < this.numPreds; pi++) {
            created[pi] = new MutableContext(outcomePattern, new double[this.numOutcomes]);
            for (int oi = 0; oi < this.numOutcomes; oi++) {
                created[pi].setParameter(oi, 0.0d);
            }
        }
        return created;
    }

    /** Returns the index of the best-scoring outcome for unique event {@code ei}. */
    private int predict(int ei, EvalParameters evalParams) {
        double[] scores = new double[this.numOutcomes];
        float[] eventValues = this.values == null ? null : this.values[ei];
        PerceptronModel.eval(this.contexts[ei], eventValues, scores, evalParams, false);
        return ArrayMath.argmax(scores);
    }

    /**
     * Applies one perceptron correction for a misclassified event: every active predicate gains
     * {@code stepSize} (times its feature value, if any) on the target outcome and loses the same
     * amount on the wrongly predicted outcome.
     */
    private void updateWeights(MutableContext[] params, int ei, int targetOutcome,
                               int predictedOutcome, double stepSize) {
        int[] context = this.contexts[ei];
        for (int ci = 0; ci < context.length; ci++) {
            MutableContext predParams = params[context[ci]];
            if (this.values == null) {
                predParams.updateParameter(targetOutcome, stepSize);
                predParams.updateParameter(predictedOutcome, -stepSize);
            } else {
                float value = this.values[ei][ci];
                predParams.updateParameter(targetOutcome, value * stepSize);
                predParams.updateParameter(predictedOutcome, -stepSize * value);
            }
        }
    }

    private static boolean isPerfectSquare(int n) {
        int root = (int) StrictMath.sqrt(n);
        return root * root == n;
    }

    /**
     * Computes the accuracy of {@code evalParams} on the training events, counting each event once
     * per time it was seen.
     *
     * @return fraction of correctly classified events
     */
    private double trainingStats(EvalParameters evalParams) {
        int numCorrect = 0;
        for (int ei = 0; ei < this.numUniqueEvents; ei++) {
            for (int ni = 0; ni < this.numTimesEventsSeen[ei]; ni++) {
                if (predict(ei, evalParams) == this.outcomeList[ei]) {
                    numCorrect++;
                }
            }
        }
        // Floating-point division; int / int would truncate to 0 or 1.
        return (double) numCorrect / this.numEvents;
    }

    /**
     * Reads the perceptron-specific training parameters ({@code UseAverage},
     * {@code UseSkippedAveraging}, {@code StepSizeDecrease}, {@code Tolerance}) and trains a model.
     */
    @Override
    public AbstractModel doTrain(DataIndexer dataIndexer) {
        int iterations = getIterations();
        int cutoff = getCutoff();
        boolean useAverage = this.trainingParameters.getBooleanParameter("UseAverage", true);
        boolean skippedAveraging =
                this.trainingParameters.getBooleanParameter("UseSkippedAveraging", false);
        // Skipped averaging is a variant of averaging, so it forces averaging on.
        if (skippedAveraging) {
            useAverage = true;
        }
        double decrease = this.trainingParameters.getDoubleParameter("StepSizeDecrease", 0.0d);
        double tol = this.trainingParameters.getDoubleParameter("Tolerance", TOLERANCE_DEFAULT);
        setSkippedAveraging(skippedAveraging);
        // Only a positive decrease enables step-size decay; otherwise the step stays constant.
        if (decrease > 0.0d) {
            setStepSizeDecrease(decrease);
        }
        setTolerance(tol);
        return trainModel(iterations, dataIndexer, cutoff, useAverage);
    }

    @Override
    public boolean isSortAndMerge() {
        return false;
    }

    /** Valid when the base checks pass and the algorithm, if set, is {@value #PERCEPTRON_VALUE}. */
    @Override
    @Deprecated
    public boolean isValid() {
        if (!super.isValid()) {
            return false;
        }
        String algorithm = getAlgorithm();
        return algorithm == null || PERCEPTRON_VALUE.equals(algorithm);
    }

    /**
     * Enables or disables skipped averaging: when averaging, only iterations below 20 and
     * perfect-square iterations are added to the average.
     */
    public void setSkippedAveraging(boolean skippedAveraging) {
        this.useSkippedAveraging = skippedAveraging;
    }

    /**
     * Enables step-size decrease: before every iteration the step size is multiplied by
     * {@code 1 - decrease}.
     *
     * <p>NOTE(review): the accepted range 0..100 suggests a percentage, but the value is applied as
     * a fraction, so values above 1 make the step size negative. Confirm the intended unit before
     * changing either side.
     *
     * @throws IllegalArgumentException if {@code decrease} is outside [0, 100]
     */
    public void setStepSizeDecrease(double decrease) {
        if (decrease < 0.0d || decrease > 100.0d) {
            throw new IllegalArgumentException("decrease must be between 0 and 100 but is " + decrease + "!");
        }
        this.stepSizeDecrease = Double.valueOf(decrease);
    }

    /**
     * Sets the accuracy-change threshold used to stop training early.
     *
     * @throws IllegalArgumentException if {@code tolerance} is negative
     */
    public void setTolerance(double tolerance) {
        if (tolerance < 0.0d) {
            throw new IllegalArgumentException("tolerance must be a positive number but is " + tolerance + "!");
        }
        this.tolerance = tolerance;
    }

    /** Trains with weight averaging enabled; see {@link #trainModel(int, DataIndexer, int, boolean)}. */
    public AbstractModel trainModel(int iterations, DataIndexer dataIndexer, int cutoff) {
        return trainModel(iterations, dataIndexer, cutoff, true);
    }

    /**
     * Loads the indexed training data and trains a {@link PerceptronModel}.
     *
     * @param iterations  maximum number of training passes
     * @param dataIndexer source of contexts, values, outcomes and predicate labels
     * @param cutoff      not read by this method
     * @param useAverage  whether to return averaged weights
     */
    public AbstractModel trainModel(int iterations, DataIndexer dataIndexer, int cutoff, boolean useAverage) {
        this.contexts = dataIndexer.getContexts();
        this.values = dataIndexer.getValues();
        this.numTimesEventsSeen = dataIndexer.getNumTimesEventsSeen();
        this.numEvents = dataIndexer.getNumEvents();
        this.numUniqueEvents = this.contexts.length;
        this.outcomeLabels = dataIndexer.getOutcomeLabels();
        this.outcomeList = dataIndexer.getOutcomeList();
        this.predLabels = dataIndexer.getPredLabels();
        this.numPreds = this.predLabels.length;
        this.numOutcomes = this.outcomeLabels.length;
        return new PerceptronModel(findParameters(iterations, useAverage), this.predLabels, this.outcomeLabels);
    }

    /** Runs the base validation and rejects any algorithm name other than {@value #PERCEPTRON_VALUE}. */
    @Override
    public void validate() {
        super.validate();
        String algorithm = getAlgorithm();
        if (algorithm != null && !PERCEPTRON_VALUE.equals(algorithm)) {
            throw new IllegalArgumentException("algorithmName must be PERCEPTRON");
        }
    }
}
