package quipu.maxent;

import java.io.BufferedWriter;
import java.io.DataOutputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.zip.GZIPOutputStream;

/* loaded from: input_file:quipu/maxent/GIS.class */
/**
 * Trains a maximum-entropy model by iterative scaling (the class and the model file header
 * are both named "GIS") and writes the resulting parameters to disk.
 *
 * <p>All working state lives in static fields, so only one training run may be in progress
 * at a time; the class performs no synchronization.
 *
 * <p>Two gzip-compressed files are produced per training run, both named after the model
 * path handed to {@link #trainModel(String, String, DataIndexer, int)}:
 * <ul>
 *   <li>{@code <path>.mei.gz} - text: the literal line {@code GIS}, the correction constant,
 *       the correction parameter, the outcome count and labels, the predicate groups, and
 *       the predicate names in sorted order;</li>
 *   <li>{@code <path>.mep.gz} - binary: every stored parameter as a {@code double}, in the
 *       same sorted predicate order.</li>
 * </ul>
 */
public class GIS {
    /** Upper bound on the number of rows the debugging {@code printTable} helpers dump. */
    private static final int MAX_PRINTED_ROWS = 20;

    /** {@code contexts[t]} holds the predicate indices attached to training token {@code t}. */
    private static int[][] contexts;
    /** {@code cfvals[t][o]}: {@link #constant} minus the number of active (predicate, o) pairs of token t. */
    private static int[][] cfvals;
    /** Natural log of how often each (predicate, outcome) pair occurs in the training data. */
    private static double[][] observedExpects;
    /** Current weight per (predicate, outcome); {@code NaN} marks a pair never seen in training. */
    private static double[][] params;
    /** Scratch table: per-iteration model expectation, then the update applied to {@link #params}. */
    private static double[][] modifiers;
    private static String[] outcomeLabels;
    private static String[] predLabels;
    /** {@code 1.0 / constant}; scales every parameter update. */
    private static double constantInverse;
    /** Natural log of the summed {@code constant - contexts[t].length} over all tokens. */
    private static double cfObservedExpect;
    private static int numOutcomes;
    private static int numPreds;
    private static int numTokens;
    /** {@code pabiTable[t][o]}: normalised model score of outcome {@code o} for token {@code t}. */
    private static double[][] pabiTable;
    /** Largest context length seen in the training data (never below 1). */
    private static int constant = 1;
    /** Weight of the correction term, multiplied with {@link #cfvals} when scoring. */
    private static double correctionParam = 0.0d;

    /**
     * Trains a model and writes it next to the given path, with no directory prefix.
     *
     * @param str         base path of the model files (the suffixes {@code .mei.gz} and
     *                    {@code .mep.gz} are appended)
     * @param dataIndexer indexed training data
     * @param i           iteration bound; see {@link #trainModel(String, String, DataIndexer, int)}
     */
    public static void trainModel(String str, DataIndexer dataIndexer, int i) {
        trainModel("", str, dataIndexer, i);
    }

    /**
     * Trains a model from the indexed data and writes it to {@code str + str2 + ".mei.gz"} and
     * {@code str + str2 + ".mep.gz"}.
     *
     * <p>An {@link IOException} raised while writing is reported with a stack trace and not
     * rethrown, so callers cannot observe a failed write through this method.
     *
     * @param str         prefix concatenated in front of {@code str2}
     * @param str2        remainder of the model path
     * @param dataIndexer supplies {@code contexts}, {@code outcomeList}, {@code outcomeLabels}
     *                    and {@code predLabels}
     * @param i           iteration bound; the scaling step runs {@code i + 1} times
     */
    public static void trainModel(String str, String str2, DataIndexer dataIndexer, int i) {
        // The training state is static: without this reset a second call in the same JVM would
        // inherit the previous run's maximum context length and correction weight.
        constant = 1;
        correctionParam = 0.0d;

        contexts = dataIndexer.contexts;
        numTokens = contexts.length;
        for (int t = 0; t < numTokens; t++) {
            constant = Math.max(constant, contexts[t].length);
        }
        constantInverse = 1.0d / constant;

        outcomeLabels = dataIndexer.outcomeLabels;
        numOutcomes = outcomeLabels.length;
        predLabels = dataIndexer.predLabels;
        numPreds = predLabels.length;

        // Count how often each predicate co-occurs with each outcome.
        int[][] pairCounts = new int[numPreds][numOutcomes];
        for (int t = 0; t < numTokens; t++) {
            for (int c = 0; c < contexts[t].length; c++) {
                pairCounts[contexts[t][c]][dataIndexer.outcomeList[t]]++;
            }
        }

        // Observed expectations are kept in log space. Pairs that never occur get NaN as their
        // parameter, which every later loop uses as the "inactive" marker, so the -Infinity
        // stored in observedExpects for those pairs is never read.
        observedExpects = new double[numPreds][numOutcomes];
        params = new double[numPreds][numOutcomes];
        for (int p = 0; p < numPreds; p++) {
            for (int o = 0; o < numOutcomes; o++) {
                observedExpects[p][o] = Math.log(pairCounts[p][o]);
                params[p][o] = pairCounts[p][o] > 0 ? 0.0d : Double.NaN;
            }
        }

        modifiers = new double[numPreds][numOutcomes];
        pabiTable = new double[numTokens][numOutcomes];

        // cfvals[t][o] = constant - (number of token t's predicates that are active for o).
        cfvals = new int[numTokens][numOutcomes];
        for (int t = 0; t < numTokens; t++) {
            int[] row = cfvals[t];
            for (int c = 0; c < contexts[t].length; c++) {
                double[] predParams = params[contexts[t][c]];
                for (int o = 0; o < numOutcomes; o++) {
                    if (!Double.isNaN(predParams[o])) {
                        row[o]++;
                    }
                }
            }
            for (int o = 0; o < numOutcomes; o++) {
                row[o] = constant - row[o];
            }
        }

        // Accumulated in a long so very large corpora cannot overflow the total.
        long correctionTotal = 0L;
        for (int t = 0; t < numTokens; t++) {
            correctionTotal += constant - contexts[t].length;
        }
        // -Infinity when every context already has the maximum length; updateCorrectionParam()
        // detects that case and leaves the correction weight untouched.
        cfObservedExpect = Math.log(correctionTotal);

        findParameters(i);
        try {
            writeModel(str + str2);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Runs the scaling step for iteration numbers {@code 0..i} inclusive ({@code i + 1} passes),
     * printing the iteration number after each pass.
     */
    private static void findParameters(int i) {
        for (int iteration = 0; iteration <= i; iteration++) {
            nextIteration();
            System.out.println("Iteration: " + iteration);
        }
    }

    /** Performs one scaling pass: score every token, then update all weights. */
    private static void nextIteration() {
        computeOutcomeDistributions();
        updatePredicateParams();
        updateCorrectionParam();
    }

    /** Fills {@link #pabiTable} with the normalised outcome scores of every token. */
    private static void computeOutcomeDistributions() {
        for (int t = 0; t < numTokens; t++) {
            double[] dist = pabiTable[t];
            // NaN means "no active predicate has contributed to this outcome yet".
            Arrays.fill(dist, Double.NaN);
            for (int c = 0; c < contexts[t].length; c++) {
                double[] predParams = params[contexts[t][c]];
                for (int o = 0; o < numOutcomes; o++) {
                    if (Double.isNaN(predParams[o])) {
                        continue;
                    }
                    dist[o] = Double.isNaN(dist[o]) ? predParams[o] : dist[o] + predParams[o];
                }
            }
            for (int o = 0; o < numOutcomes; o++) {
                // NaN survives the addition, so outcomes without any active predicate score 0.
                double score = dist[o] + correctionParam * cfvals[t][o];
                dist[o] = Double.isNaN(score) ? 0.0d : Math.exp(score);
            }
            double total = sum(dist);
            if (total > 0.0d) {
                for (int o = 0; o < numOutcomes; o++) {
                    dist[o] /= total;
                }
            }
        }
    }

    /** Accumulates the model expectation of every active pair and applies the weight update. */
    private static void updatePredicateParams() {
        for (int p = 0; p < numPreds; p++) {
            Arrays.fill(modifiers[p], 0.0d);
        }
        for (int t = 0; t < numTokens; t++) {
            double[] dist = pabiTable[t];
            for (int c = 0; c < contexts[t].length; c++) {
                int pred = contexts[t][c];
                for (int o = 0; o < numOutcomes; o++) {
                    if (!Double.isNaN(params[pred][o])) {
                        modifiers[pred][o] += dist[o];
                    }
                }
            }
        }
        for (int p = 0; p < numPreds; p++) {
            for (int o = 0; o < numOutcomes; o++) {
                if (!Double.isNaN(params[p][o])) {
                    modifiers[p][o] = constantInverse * (observedExpects[p][o] - Math.log(modifiers[p][o]));
                    params[p][o] += modifiers[p][o];
                }
            }
        }
    }

    /**
     * Updates {@link #correctionParam} from the observed and model expectations of the
     * correction term.
     *
     * <p>The update is skipped when either expectation is zero: {@code Math.log(0)} is
     * {@code -Infinity}, and feeding that into the weight turns it into {@code -Infinity} or
     * {@code NaN}, after which every later score becomes unusable.
     */
    private static void updateCorrectionParam() {
        double modelExpect = 0.0d;
        for (int t = 0; t < numTokens; t++) {
            for (int o = 0; o < numOutcomes; o++) {
                modelExpect += pabiTable[t][o] * cfvals[t][o];
            }
        }
        if (modelExpect > 0.0d && !Double.isInfinite(cfObservedExpect)) {
            correctionParam += constantInverse * (cfObservedExpect - Math.log(modelExpect));
        }
    }

    /**
     * Writes the trained model to {@code str + ".mei.gz"} (text) and {@code str + ".mep.gz"}
     * (binary parameters).
     *
     * @param str base path of the two model files
     * @throws IOException if either file cannot be created or written; the file being written
     *                     is closed before the exception propagates
     */
    private static void writeModel(String str) throws IOException {
        // Keep only the active (non-NaN) outcomes of each predicate.
        ComparablePredicate[] sortedPreds = new ComparablePredicate[numPreds];
        for (int p = 0; p < params.length; p++) {
            int active = 0;
            for (int o = 0; o < numOutcomes; o++) {
                if (!Double.isNaN(params[p][o])) {
                    active++;
                }
            }
            int[] activeOutcomes = new int[active];
            double[] activeParams = new double[active];
            int next = 0;
            for (int o = 0; o < numOutcomes; o++) {
                if (!Double.isNaN(params[p][o])) {
                    activeOutcomes[next] = o;
                    activeParams[next] = params[p][o];
                    next++;
                }
            }
            sortedPreds[p] = new ComparablePredicate(predLabels[p], activeOutcomes, activeParams);
        }
        Arrays.sort(sortedPreds);

        // Group consecutive predicates that compare equal. With no predicates at all there is
        // nothing to group and zero groups are written.
        ArrayList<ArrayList<ComparablePredicate>> groups = new ArrayList<ArrayList<ComparablePredicate>>();
        if (sortedPreds.length > 0) {
            ComparablePredicate groupHead = sortedPreds[0];
            ArrayList<ComparablePredicate> group = new ArrayList<ComparablePredicate>();
            for (int p = 0; p < sortedPreds.length; p++) {
                if (groupHead.compareTo(sortedPreds[p]) != 0) {
                    groups.add(group);
                    group = new ArrayList<ComparablePredicate>();
                    groupHead = sortedPreds[p];
                }
                group.add(sortedPreds[p]);
            }
            groups.add(group);
        }

        FileOutputStream infoFile = new FileOutputStream(str + ".mei.gz");
        try {
            // NOTE(review): the text is encoded with the platform default charset; whatever
            // reads this file has to decode it the same way - confirm before changing.
            BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(new GZIPOutputStream(infoFile)));
            writer.write("GIS");
            writer.newLine();
            writer.write(Integer.toString(constant));
            writer.newLine();
            writer.write(Double.toString(correctionParam));
            writer.newLine();
            writer.write(Integer.toString(numOutcomes));
            writer.newLine();
            for (int o = 0; o < numOutcomes; o++) {
                writer.write(outcomeLabels[o]);
                writer.newLine();
            }
            writer.write(Integer.toString(groups.size()));
            writer.newLine();
            for (int g = 0; g < groups.size(); g++) {
                ArrayList<ComparablePredicate> members = groups.get(g);
                // Group size immediately followed by the first member's toString(), with no
                // separator added here (presumably toString() supplies its own - verify).
                writer.write(members.size() + members.get(0).toString());
                writer.newLine();
            }
            writer.write(Integer.toString(numPreds));
            writer.newLine();
            for (int p = 0; p < numPreds; p++) {
                writer.write(sortedPreds[p].name);
                writer.newLine();
            }
            // Flushes the text, finishes the gzip stream and closes infoFile.
            writer.close();
        } finally {
            // Releases the file handle when a write failed; harmless after a normal close.
            infoFile.close();
        }

        FileOutputStream paramFile = new FileOutputStream(str + ".mep.gz");
        try {
            DataOutputStream out = new DataOutputStream(new GZIPOutputStream(paramFile));
            for (int p = 0; p < numPreds; p++) {
                double[] predParams = sortedPreds[p].params;
                for (int k = 0; k < predParams.length; k++) {
                    out.writeDouble(predParams[k]);
                }
            }
            out.close();
        } finally {
            paramFile.close();
        }
    }

    /** Returns the sum of all elements of {@code dArr} (0 for an empty array). */
    private static double sum(double[] dArr) {
        double total = 0.0d;
        for (double value : dArr) {
            total += value;
        }
        return total;
    }

    /**
     * Debug helper: prints at most the first {@value #MAX_PRINTED_ROWS} rows of an int table,
     * one row per line with tab-separated {@code column: value} cells.
     */
    private static void printTable(int[][] iArr) {
        StringBuilder out = new StringBuilder();
        // Clamp to the table size so short tables no longer index past the end.
        int rows = Math.min(MAX_PRINTED_ROWS, iArr.length);
        for (int r = 0; r < rows; r++) {
            out.append("\n").append(r).append(".");
            for (int c = 0; c < iArr[r].length; c++) {
                out.append("\t").append(c).append(": ").append(iArr[r][c]);
            }
        }
        System.out.println(out.toString());
    }

    /**
     * Debug helper: prints at most the first {@value #MAX_PRINTED_ROWS} rows of a double table,
     * one row per line with tab-separated {@code column: value} cells.
     */
    private static void printTable(double[][] dArr) {
        StringBuilder out = new StringBuilder();
        // Clamp to the table size so short tables no longer index past the end.
        int rows = Math.min(MAX_PRINTED_ROWS, dArr.length);
        for (int r = 0; r < rows; r++) {
            out.append("\n").append(r).append(".");
            for (int c = 0; c < dArr[r].length; c++) {
                out.append("\t").append(c).append(": ").append(dArr[r][c]);
            }
        }
        System.out.println(out.toString());
    }
}
