package opennlp.tools.ml.perceptron;

import java.io.IOException;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

import opennlp.tools.ml.AbstractEventModelSequenceTrainer;
import opennlp.tools.ml.model.AbstractDataIndexer;
import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.Event;
import opennlp.tools.ml.model.MaxentModel;
import opennlp.tools.ml.model.MutableContext;
import opennlp.tools.ml.model.OnePassDataIndexer;
import opennlp.tools.ml.model.Sequence;
import opennlp.tools.ml.model.SequenceStream;
import opennlp.tools.ml.model.SequenceStreamEventStream;

/* loaded from: classes2.dex */
/**
 * Trains a {@link PerceptronModel} over sequences of {@link Event}s with a sequence-level
 * perceptron. Each training sequence is tagged with the current weights; if any tag is wrong,
 * the weights of the gold (feature, outcome) pairs are raised and those of the predicted pairs
 * lowered. When averaging is enabled ({@code "UseAverage"}, default {@code true} in
 * {@link #doTrain}), the returned model uses weights averaged over all training steps.
 */
public class SimplePerceptronSequenceTrainer extends AbstractEventModelSequenceTrainer {

    /** Value that {@code getAlgorithm()} must have, if set, for this trainer to validate. */
    public static final String PERCEPTRON_SEQUENCE_VALUE = "PERCEPTRON_SEQUENCE";

    // Slots of each updates[pred][outcome] record used for lazy weight averaging.
    /** Weight value right after the last update of this (predicate, outcome) pair. */
    private static final int VALUE = 0;
    /** 0-based iteration in which the last update happened. */
    private static final int ITER = 1;
    /** Index of the sequence (within its iteration) at which the last update happened. */
    private static final int EVENT = 2;

    private int iterations;
    private SequenceStream<Event> sequenceStream;
    /** Number of events reported by the data indexer. */
    private int numEvents;
    private int numPreds;
    private int numOutcomes;
    private int numSequences;
    /** Outcome index of each indexed event, as reported by the data indexer. */
    private int[] outcomeList;
    private String[] outcomeLabels;
    private String[] predLabels;
    /** Outcome label to outcome index. */
    private Map<String, Integer> omap;
    /** Predicate label to predicate index. */
    private Map<String, Integer> pmap;
    /** Current (non-averaged) weights, one context per predicate. */
    private MutableContext[] params;
    /** Accumulated/averaged weights; only allocated when averaging is enabled. */
    private MutableContext[] averageParams;
    private boolean useAverage;
    /**
     * Per (predicate, outcome) lazy-averaging record indexed by VALUE/ITER/EVENT. Kept as
     * double so that non-integral weights (real-valued features) are not truncated.
     */
    private double[][][] updates;

    /**
     * Trains a model using the configured iteration count, cutoff and the {@code "UseAverage"}
     * training parameter (default {@code true}).
     *
     * @throws IOException if reading the sequence stream fails
     */
    @Override
    public AbstractModel doTrain(SequenceStream<Event> sequenceStream) throws IOException {
        boolean average = this.trainingParameters.getBooleanParameter("UseAverage", true);
        return trainModel(getIterations(), sequenceStream, getCutoff(), average);
    }

    @Override
    @Deprecated
    public boolean isValid() {
        try {
            validate();
            return true;
        } catch (IllegalArgumentException ignored) {
            // validate() signals an invalid configuration by throwing; report it as false.
            return false;
        }
    }

    /**
     * Performs one pass of perceptron training over the whole sequence stream.
     *
     * <p>Each sequence is tagged with the current weights. If any predicted outcome differs from
     * its gold outcome, every feature value of the gold events is added to its (feature, gold
     * outcome) pair and every feature value of the predicted events is subtracted from its
     * (feature, predicted outcome) pair; the netted amounts are then applied to the weights.
     * With averaging enabled, averaged weights are accumulated lazily and normalised at the end
     * of the last iteration.
     *
     * @param iteration 1-based iteration number
     * @throws IOException if reading the sequence stream fails
     */
    public void nextIteration(int iteration) throws IOException {
        int iter = iteration - 1; // the averaging bookkeeping uses a 0-based iteration index
        List<Map<String, Float>> featureCounts = new ArrayList<>(numOutcomes);
        for (int oi = 0; oi < numOutcomes; oi++) {
            featureCounts.add(new HashMap<>());
        }

        PerceptronModel model = new PerceptronModel(params, predLabels, outcomeLabels);
        sequenceStream.reset();
        int si = 0; // index of the current sequence within this pass
        Sequence<Event> sequence;
        while ((sequence = sequenceStream.read()) != null) {
            Event[] predicted = sequenceStream.updateContext(sequence, model);
            Event[] gold = sequence.getEvents();
            if (hasMistake(gold, predicted)) {
                for (Map<String, Float> counts : featureCounts) {
                    counts.clear();
                }
                // +value for each feature of each gold event.
                for (Event event : gold) {
                    Integer oi = omap.get(event.getOutcome());
                    if (oi == null) {
                        continue; // outcome has no weight column in the indexed data
                    }
                    Map<String, Float> counts = featureCounts.get(oi);
                    String[] context = event.getContext();
                    float[] values = event.getValues();
                    for (int ci = 0; ci < context.length; ci++) {
                        counts.merge(context[ci], featureValue(values, ci), Float::sum);
                    }
                }
                // -value for each feature of each predicted event; pairs that net out to
                // zero are dropped so they cause no update.
                for (Event event : predicted) {
                    Map<String, Float> counts = featureCounts.get(omap.get(event.getOutcome()));
                    String[] context = event.getContext();
                    float[] values = event.getValues();
                    for (int ci = 0; ci < context.length; ci++) {
                        Float previous = counts.get(context[ci]);
                        float value = featureValue(values, ci);
                        float net = previous == null ? -value : previous - value;
                        if (net == 0f) {
                            counts.remove(context[ci]);
                        } else {
                            counts.put(context[ci], net);
                        }
                    }
                }
                applyUpdates(featureCounts, iter, si);
                model = new PerceptronModel(params, predLabels, outcomeLabels);
            }
            si++;
        }

        if (useAverage && iter == iterations - 1) {
            finishAveraging(si);
        }
    }

    /** Adds the netted feature counts to the weights and maintains the averaging records. */
    private void applyUpdates(List<Map<String, Float>> featureCounts, int iter, int si) {
        for (int oi = 0; oi < numOutcomes; oi++) {
            for (Map.Entry<String, Float> entry : featureCounts.get(oi).entrySet()) {
                int pi = pmap.getOrDefault(entry.getKey(), -1);
                if (pi == -1) {
                    continue; // feature is not among the indexer's predicate labels
                }
                params[pi].updateParameter(oi, entry.getValue());
                if (useAverage) {
                    double[] record = updates[pi][oi];
                    if (record[VALUE] != 0) {
                        // Credit the previous weight for the number of sequence steps it was
                        // in effect (computed in double to avoid int overflow).
                        double steps = (iter - record[ITER]) * numSequences + (si - record[EVENT]);
                        averageParams[pi].updateParameter(oi, record[VALUE] * steps);
                    }
                    record[VALUE] = params[pi].getParameters()[oi];
                    record[ITER] = iter;
                    record[EVENT] = si;
                }
            }
        }
    }

    /**
     * Credits every weight's last value up to the end of training and divides the accumulated
     * sums by the total number of sequence steps.
     *
     * @param sequencesPerIteration number of sequences read in the final pass
     */
    private void finishAveraging(int sequencesPerIteration) {
        double totalSteps = (double) iterations * sequencesPerIteration;
        for (int pi = 0; pi < numPreds; pi++) {
            double[] averaged = averageParams[pi].getParameters();
            for (int oi = 0; oi < numOutcomes; oi++) {
                double[] record = updates[pi][oi];
                if (record[VALUE] != 0) {
                    double steps = (iterations - record[ITER]) * numSequences - record[EVENT];
                    averaged[oi] += record[VALUE] * steps;
                }
                if (averaged[oi] != 0) {
                    averaged[oi] /= totalSteps;
                    averageParams[pi].setParameter(oi, averaged[oi]);
                }
            }
        }
    }

    /** Returns true if any predicted outcome differs from the corresponding gold outcome. */
    private static boolean hasMistake(Event[] gold, Event[] predicted) {
        for (int i = 0; i < gold.length; i++) {
            if (!predicted[i].getOutcome().equals(gold[i].getOutcome())) {
                return true;
            }
        }
        return false;
    }

    /** Value of the {@code index}-th context feature; a null value array means every value is 1. */
    private static float featureValue(float[] values, int index) {
        return values != null ? values[index] : 1.0f;
    }

    /**
     * Indexes the training data, initialises all weights to zero and runs {@code iterations}
     * training passes over {@code sequenceStream}.
     *
     * @param iterations     number of passes over the training sequences
     * @param sequenceStream training sequences; it is reset and re-read several times
     * @param cutoff         stored as the {@code "Cutoff"} training parameter before indexing
     * @param useAverage     if true, the returned model uses the averaged weights
     * @return the trained model
     * @throws IOException if reading the sequence stream fails
     */
    public AbstractModel trainModel(int iterations, SequenceStream<Event> sequenceStream,
            int cutoff, boolean useAverage) throws IOException {
        this.iterations = iterations;
        this.sequenceStream = sequenceStream;
        this.useAverage = useAverage;

        this.trainingParameters.put("Cutoff", cutoff);
        this.trainingParameters.put(AbstractDataIndexer.SORT_PARAM, false);
        OnePassDataIndexer indexer = new OnePassDataIndexer();
        indexer.init(this.trainingParameters, this.reportMap);
        indexer.index(new SequenceStreamEventStream(sequenceStream));

        numSequences = 0;
        sequenceStream.reset();
        while (sequenceStream.read() != null) {
            numSequences++;
        }

        numEvents = indexer.getNumEvents();
        outcomeList = indexer.getOutcomeList();
        predLabels = indexer.getPredLabels();
        outcomeLabels = indexer.getOutcomeLabels();
        pmap = indexLabels(predLabels);
        omap = indexLabels(outcomeLabels);
        numPreds = predLabels.length;
        numOutcomes = outcomeLabels.length;

        updates = useAverage ? new double[numPreds][numOutcomes][3] : null;

        // Every predicate carries a weight for every outcome.
        int[] allOutcomes = new int[numOutcomes];
        for (int oi = 0; oi < numOutcomes; oi++) {
            allOutcomes[oi] = oi;
        }
        params = new MutableContext[numPreds];
        averageParams = useAverage ? new MutableContext[numPreds] : null;
        for (int pi = 0; pi < numPreds; pi++) {
            // Freshly allocated double arrays are zero-filled, so all weights start at 0.0.
            params[pi] = new MutableContext(allOutcomes, new double[numOutcomes]);
            if (useAverage) {
                averageParams[pi] = new MutableContext(allOutcomes, new double[numOutcomes]);
            }
        }

        findParameters(iterations);
        return new PerceptronModel(useAverage ? averageParams : params, predLabels, outcomeLabels);
    }

    /** Builds a label-to-index map; a repeated label maps to its last index. */
    private static Map<String, Integer> indexLabels(String[] labels) {
        Map<String, Integer> index = new HashMap<>();
        for (int i = 0; i < labels.length; i++) {
            index.put(labels[i], i);
        }
        return index;
    }

    /** Runs {@code iterations} training passes, then re-tags the data with the final weights. */
    private void findParameters(int iterations) throws IOException {
        for (int i = 1; i <= iterations; i++) {
            nextIteration(i);
        }
        // The correct-tag count is a training diagnostic; it is not currently reported.
        trainingStats(useAverage ? averageParams : params);
    }

    /**
     * Tags every sequence with {@code weights} and counts events whose predicted outcome equals
     * the gold outcome of the same sequence position.
     *
     * @return number of correctly tagged events
     * @throws IOException if reading the sequence stream fails
     */
    private int trainingStats(MutableContext[] weights) throws IOException {
        PerceptronModel model = new PerceptronModel(weights, predLabels, outcomeLabels);
        sequenceStream.reset();
        int numCorrect = 0;
        Sequence<Event> sequence;
        while ((sequence = sequenceStream.read()) != null) {
            Event[] predicted = sequenceStream.updateContext(sequence, model);
            Event[] gold = sequence.getEvents();
            int n = Math.min(predicted.length, gold.length);
            for (int ei = 0; ei < n; ei++) {
                if (predicted[ei].getOutcome().equals(gold[ei].getOutcome())) {
                    numCorrect++;
                }
            }
        }
        return numCorrect;
    }

    @Override
    public void validate() {
        super.validate();
        String algorithm = getAlgorithm();
        if (algorithm != null && !PERCEPTRON_SEQUENCE_VALUE.equals(algorithm)) {
            throw new IllegalArgumentException("algorithmName must be PERCEPTRON_SEQUENCE");
        }
    }
}
