/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.ie.crf;

import edu.stanford.nlp.ie.crf.CRFClassifier;
import edu.stanford.nlp.ie.crf.CRFLogConditionalObjectiveFunctionForLOP;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.DiffFunction;
import edu.stanford.nlp.optimization.Evaluator;
import edu.stanford.nlp.optimization.Minimizer;
import edu.stanford.nlp.sequences.SeqClassifierFlags;
import edu.stanford.nlp.util.ConvertByteArray;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Properties;
import java.util.Set;
import java.util.zip.GZIPInputStream;

/**
 * A {@link CRFClassifier} whose weights are obtained by combining several "expert" weight vectors,
 * each trained (or randomly initialized) on a subset of the features, with learned per-expert scales.
 *
 * <p>NOTE(review): "LOP" presumably stands for logarithmic opinion pool — confirm against
 * {@code CRFLogConditionalObjectiveFunctionForLOP}.
 */
public class CRFClassifierWithLOP<IN extends CoreMap>
extends CRFClassifier<IN> {
    private static final Redwood.RedwoodChannels log = Redwood.channels(CRFClassifierWithLOP.class);

    /** For each expert, the set of feature indices it owns (used for fast membership tests). */
    private List<Set<Integer>> featureIndicesSetArray;

    /** For each expert, the same feature indices as {@link #featureIndicesSetArray}, in ascending order. */
    private List<List<Integer>> featureIndicesListArray;

    /** Creates a classifier with default {@link SeqClassifierFlags}. */
    protected CRFClassifierWithLOP() {
        super(new SeqClassifierFlags());
    }

    /** Creates a classifier configured from {@code props}. */
    public CRFClassifierWithLOP(Properties props) {
        super(props);
    }

    /** Creates a classifier configured from {@code flags}. */
    public CRFClassifierWithLOP(SeqClassifierFlags flags) {
        super(flags);
    }

    /**
     * Returns a deep copy of {@code data} in which every innermost feature array keeps only the
     * feature indices owned by expert {@code lopIter}, preserving their original order.
     */
    private int[][][][] createPartialDataForLOP(int lopIter, int[][][][] data) {
        Set<Integer> featureIndicesSet = this.featureIndicesSetArray.get(lopIter);
        int[][][][] newData = new int[data.length][][][];
        for (int i = 0; i < data.length; ++i) {
            newData[i] = new int[data[i].length][][];
            for (int j = 0; j < data[i].length; ++j) {
                newData[i][j] = new int[data[i][j].length][];
                for (int k = 0; k < data[i][j].length; ++k) {
                    newData[i][j][k] = filterFeatures(data[i][j][k], featureIndicesSet);
                }
            }
        }
        return newData;
    }

    /** Returns a new array holding the elements of {@code features} contained in {@code keep}, in order. */
    private static int[] filterFeatures(int[] features, Set<Integer> keep) {
        // Primitive scratch buffer avoids boxing every kept index into an intermediate List<Integer>.
        int[] kept = new int[features.length];
        int count = 0;
        for (int feature : features) {
            if (keep.contains(feature)) {
                kept[count++] = feature;
            }
        }
        return count == kept.length ? kept : Arrays.copyOf(kept, count);
    }

    /**
     * Partitions feature indices {@code [0, numFeatures)} among {@code numLopExpert} experts and
     * stores the result in {@link #featureIndicesSetArray} / {@link #featureIndicesListArray}.
     *
     * <p>With {@code flags.randomLopFeatureSplit} each feature goes to an expert drawn from
     * {@code this.random}. Otherwise the range is cut into contiguous chunks of
     * {@code numFeatures / numLopExpert}, and the last expert also receives the remainder.
     *
     * @throws IllegalArgumentException if {@code numLopExpert < 1}
     */
    private void getFeatureBoundaryIndices(int numFeatures, int numLopExpert) {
        if (numLopExpert < 1) {
            throw new IllegalArgumentException("numLopExpert must be positive, but was " + numLopExpert);
        }
        int interval = numFeatures / numLopExpert;
        this.featureIndicesSetArray = new ArrayList<Set<Integer>>(numLopExpert);
        this.featureIndicesListArray = new ArrayList<List<Integer>>(numLopExpert);
        for (int i = 0; i < numLopExpert; ++i) {
            this.featureIndicesSetArray.add(Generics.newHashSet(interval));
            this.featureIndicesListArray.add(Generics.newArrayList(interval));
        }
        if (this.flags.randomLopFeatureSplit) {
            for (int fIndex = 0; fIndex < numFeatures; ++fIndex) {
                this.assignFeatureToExpert(this.random.nextInt(numLopExpert), fIndex);
            }
        } else {
            for (int lopIter = 0; lopIter < numLopExpert; ++lopIter) {
                int beginIndex = lopIter * interval;
                // The last expert absorbs the features left over by the integer division.
                int endIndex = lopIter == numLopExpert - 1 ? numFeatures : (lopIter + 1) * interval;
                for (int fIndex = beginIndex; fIndex < endIndex; ++fIndex) {
                    this.assignFeatureToExpert(lopIter, fIndex);
                }
            }
        }
        // Both branches already append in ascending order; the sort keeps that invariant explicit.
        for (List<Integer> indices : this.featureIndicesListArray) {
            Collections.sort(indices);
        }
    }

    /** Records that feature {@code fIndex} belongs to expert {@code lopIter}. */
    private void assignFeatureToExpert(int lopIter, int fIndex) {
        this.featureIndicesSetArray.get(lopIter).add(fIndex);
        this.featureIndicesListArray.get(lopIter).add(fIndex);
    }

    /**
     * Builds per-expert weights, learns per-expert scales with the configured minimizer, and
     * returns the combined, scaled weight vector.
     *
     * <p>Expert weights are read from {@code flags.initialLopWeights} when set. In that case
     * {@code flags.includeFullCRFInLOP} is not applied. Otherwise each expert is trained on its
     * feature subset, or randomly initialized when {@code flags.randomLopWeights} is set.
     * Initial scales come from {@code flags.initialLopScales} (gzipped doubles) when set, otherwise
     * from the objective's own initial point.
     *
     * <p>{@code featureVals} is not used by this implementation. Expert training passes
     * {@code null} to the superclass instead.
     *
     * @throws RuntimeException if an initial weights or scales file cannot be read
     * @throws IllegalArgumentException if the initial weights file does not contain exactly
     *     {@code flags.numLopExpert} rows
     */
    @Override
    protected double[] trainWeights(int[][][][] data, int[][] labels, Evaluator[] evaluators, int pruneFeatureItr, double[][][][] featureVals) {
        int numFeatures = this.featureIndex.size();
        this.getFeatureBoundaryIndices(numFeatures, this.flags.numLopExpert);

        double[][] lopExpertWeights;
        if (this.flags.initialLopWeights != null) {
            // NOTE(review): includeFullCRFInLOP is ignored on this path, as in the original code.
            lopExpertWeights = this.readInitialLopWeights(this.flags.initialLopWeights, this.flags.numLopExpert);
        } else {
            lopExpertWeights = this.initLopExpertWeights(data, labels, evaluators, pruneFeatureItr, numFeatures);
        }
        // One more than flags.numLopExpert when the full CRF was appended as an extra expert.
        int numLopExpert = lopExpertWeights.length;

        CRFLogConditionalObjectiveFunctionForLOP func = new CRFLogConditionalObjectiveFunctionForLOP(data, labels, lopExpertWeights, this.windowSize, this.classIndex, this.labelIndices, this.map, this.flags.backgroundSymbol, numLopExpert, this.featureIndicesSetArray, this.featureIndicesListArray, this.flags.backpropLopTraining);
        this.cliquePotentialFunctionHelper = func;
        Minimizer<DiffFunction> minimizer = this.getMinimizer(0, evaluators);
        double[] initialScales = this.flags.initialLopScales == null
                ? func.initial()
                : readInitialLopScales(this.flags.initialLopScales);

        double[] learnedParams = minimizer.minimize(func, this.flags.tolerance, initialScales);
        double[] rawScales = func.separateLopScales(learnedParams);
        double[] lopScales = ArrayMath.softmax(rawScales);
        log.info("After SoftMax Transformation, learned scales are:");
        for (int lopIter = 0; lopIter < numLopExpert; ++lopIter) {
            log.info("lopScales[" + lopIter + "] = " + lopScales[lopIter]);
        }
        // With backprop training the expert weights were optimized too, so take them from learnedParams.
        double[][] learnedLopExpertWeights = this.flags.backpropLopTraining
                ? func.separateLopExpertWeights(learnedParams)
                : lopExpertWeights;
        return CRFLogConditionalObjectiveFunctionForLOP.combineAndScaleLopWeights(numLopExpert, learnedLopExpertWeights, lopScales);
    }

    /**
     * Trains (or randomly initializes) one weight vector per expert on that expert's feature subset.
     * When {@code flags.includeFullCRFInLOP} is set, it appends one more expert over all
     * {@code numFeatures} features and registers that expert's feature set.
     */
    private double[][] initLopExpertWeights(int[][][][] data, int[][] labels, Evaluator[] evaluators, int pruneFeatureItr, int numFeatures) {
        int numLopExpert = this.flags.numLopExpert;
        boolean includeFull = this.flags.includeFullCRFInLOP;
        double[][] weights = new double[includeFull ? numLopExpert + 1 : numLopExpert][];
        for (int lopIter = 0; lopIter < numLopExpert; ++lopIter) {
            int[][][][] partialData = this.createPartialDataForLOP(lopIter, data);
            weights[lopIter] = this.initExpertWeights(partialData, labels, evaluators, pruneFeatureItr);
        }
        if (includeFull) {
            weights[numLopExpert] = this.initExpertWeights(data, labels, evaluators, pruneFeatureItr);
            Set<Integer> allSet = Generics.newHashSet(numFeatures);
            List<Integer> allList = new ArrayList<Integer>(numFeatures);
            for (int fIndex = 0; fIndex < numFeatures; ++fIndex) {
                allSet.add(fIndex);
                allList.add(fIndex);
            }
            this.featureIndicesSetArray.add(allSet);
            this.featureIndicesListArray.add(allList);
        }
        return weights;
    }

    /** Returns random initial weights (if {@code flags.randomLopWeights}) or superclass-trained weights for {@code expertData}. */
    private double[] initExpertWeights(int[][][][] expertData, int[][] labels, Evaluator[] evaluators, int pruneFeatureItr) {
        return this.flags.randomLopWeights
                ? super.getObjectiveFunction(expertData, labels).initial()
                : super.trainWeights(expertData, labels, evaluators, pruneFeatureItr, null);
    }

    /**
     * Reads one tab-separated row of doubles per expert from {@code path}. Blank lines are skipped.
     *
     * @throws RuntimeException wrapping the {@link IOException} if the file cannot be read
     * @throws IllegalArgumentException if the number of rows is not {@code numLopExpert}
     */
    private double[][] readInitialLopWeights(String path, int numLopExpert) {
        try (BufferedReader br = IOUtils.readerFromString(path)) {
            log.info("Reading initial LOP weights from file " + path + " ...");
            List<double[]> listOfWeights = new ArrayList<double[]>(numLopExpert);
            for (String line; (line = br.readLine()) != null; ) {
                line = line.trim();
                if (line.isEmpty()) {
                    continue;
                }
                String[] parts = line.split("\t");
                double[] wArr = new double[parts.length];
                for (int i = 0; i < parts.length; ++i) {
                    wArr[i] = Double.parseDouble(parts[i]);
                }
                listOfWeights.add(wArr);
            }
            if (listOfWeights.size() != numLopExpert) {
                throw new IllegalArgumentException("Initial LOP weights file " + path + " has " + listOfWeights.size() + " rows, expected numLopExpert = " + numLopExpert);
            }
            log.info("Done!");
            return listOfWeights.toArray(new double[0][]);
        } catch (IOException e) {
            throw new RuntimeException("Could not read from double initial LOP weights file " + path, e);
        }
    }

    /**
     * Reads a gzipped double array (via {@link ConvertByteArray#readDoubleArr}) from {@code path}.
     *
     * @throws RuntimeException wrapping the {@link IOException} if the file cannot be read
     */
    private static double[] readInitialLopScales(String path) {
        log.info("Reading initial LOP scales from file " + path);
        // The file stream is its own resource so it is closed even if the GZIP header is invalid.
        try (FileInputStream fis = new FileInputStream(path);
             DataInputStream dis = new DataInputStream(new BufferedInputStream(new GZIPInputStream(fis)))) {
            return ConvertByteArray.readDoubleArr(dis);
        } catch (IOException e) {
            throw new RuntimeException("Could not read from double initial LOP scales file " + path, e);
        }
    }
}

