/*
 * Decompiled with CFR 0.152.
 */
package net.loomchild.maligna.model.translation;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.Reader;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import net.loomchild.maligna.model.ModelParseException;
import net.loomchild.maligna.model.translation.InitialTranslationModel;
import net.loomchild.maligna.model.translation.MutableSourceData;
import net.loomchild.maligna.model.translation.MutableTranslationModel;
import net.loomchild.maligna.model.translation.TranslationModel;
import net.loomchild.maligna.model.vocabulary.Vocabulary;

/**
 * Static helpers for building {@link TranslationModel} instances, either by
 * training on parallel segments of word ids or by parsing a text file.
 */
public class TranslationModelUtil {

    /** Number of training iterations used by {@link #train(List, List)}. */
    public static final int DEFAULT_TRAIN_ITERATION_COUNT = 4;

    /**
     * Word id appended to every source segment during training ("null" word);
     * translation mass that no real source word claims strongly enough is
     * credited to it.
     */
    private static final int NULL_WID = 0;

    /**
     * Trains a translation model on a parallel corpus of word-id segments.
     *
     * <p>Training starts from {@link InitialTranslationModel} and applies
     * {@code iterationCount} re-estimation passes; the final model is sorted
     * before being returned. Segments are paired by position; if the lists
     * differ in size (only detected when assertions are enabled), the extra
     * trailing segments of the longer list are ignored.
     *
     * @param iterationCount number of re-estimation passes, must be at least 1
     * @param sourceSegmentList source segments, each a list of source word ids
     * @param targetSegmentList target segments paired index-by-index with
     *        {@code sourceSegmentList}; expected to have the same size
     * @return the trained, sorted model
     * @throws IllegalArgumentException if {@code iterationCount < 1}
     */
    public static TranslationModel train(int iterationCount, List<List<Integer>> sourceSegmentList, List<List<Integer>> targetSegmentList) {
        if (iterationCount < 1) {
            // Fail fast: at least one pass is needed to produce a trained model.
            throw new IllegalArgumentException("iterationCount must be at least 1, got " + iterationCount);
        }
        assert (sourceSegmentList.size() == targetSegmentList.size());
        // Seed with the first pass so the working model is never null.
        MutableTranslationModel model = performTrainingIteration(new InitialTranslationModel(), sourceSegmentList, targetSegmentList);
        for (int iteration = 1; iteration < iterationCount; ++iteration) {
            model = performTrainingIteration(model, sourceSegmentList, targetSegmentList);
        }
        model.sort();
        return model;
    }

    /**
     * Trains a translation model using {@link #DEFAULT_TRAIN_ITERATION_COUNT}
     * iterations.
     *
     * @param sourceSegmentList source segments, each a list of source word ids
     * @param targetSegmentList target segments paired index-by-index with
     *        {@code sourceSegmentList}
     * @return the trained, sorted model
     * @see #train(int, List, List)
     */
    public static TranslationModel train(List<List<Integer>> sourceSegmentList, List<List<Integer>> targetSegmentList) {
        return TranslationModelUtil.train(DEFAULT_TRAIN_ITERATION_COUNT, sourceSegmentList, targetSegmentList);
    }

    /**
     * Performs one re-estimation pass over the corpus.
     *
     * <p>For every target word of every segment pair, one unit of mass is
     * split across the source words plus the null word ({@link #NULL_WID}) in
     * proportion to their probabilities in {@code model}. A source word whose
     * share is below the uniform share {@code 1 / (sourceLength + 1)} does not
     * receive it; that share is credited to the null word instead. The
     * accumulated counts are then normalized by
     * {@link MutableTranslationModel#normalize()}.
     *
     * @param model the model from the previous pass; it is only read
     * @param sourceSegmentList source segments
     * @param targetSegmentList target segments paired with the source segments
     * @return a fresh, normalized model
     */
    private static MutableTranslationModel performTrainingIteration(TranslationModel model, List<List<Integer>> sourceSegmentList, List<List<Integer>> targetSegmentList) {
        MutableTranslationModel newModel = new MutableTranslationModel();
        Iterator<List<Integer>> sourceSegmentIterator = sourceSegmentList.iterator();
        Iterator<List<Integer>> targetSegmentIterator = targetSegmentList.iterator();
        while (sourceSegmentIterator.hasNext() && targetSegmentIterator.hasNext()) {
            List<Integer> sourceSegment = sourceSegmentIterator.next();
            List<Integer> targetSegment = targetSegmentIterator.next();

            List<Integer> sourceSegmentAndNull = new ArrayList<Integer>(sourceSegment.size() + 1);
            sourceSegmentAndNull.addAll(sourceSegment);
            sourceSegmentAndNull.add(NULL_WID);
            int sourceCount = sourceSegmentAndNull.size();

            // Uniform share; depends only on the source segment, so computed once per pair.
            double minProbabilityChange = 1.0 / (double) sourceCount;
            // Reused per target word so each old probability is looked up only once.
            double[] oldProbabilities = new double[sourceCount];

            for (int targetWid : targetSegment) {
                double probabilitySum = 0.0;
                for (int i = 0; i < sourceCount; ++i) {
                    int sourceWid = sourceSegmentAndNull.get(i);
                    double oldModelProbability = model.get(sourceWid).getTranslationProbability(targetWid);
                    oldProbabilities[i] = oldModelProbability;
                    probabilitySum += oldModelProbability;
                }
                assert (probabilitySum > 0.0);
                for (int i = 0; i < sourceCount; ++i) {
                    double probabilityChange = oldProbabilities[i] / probabilitySum;
                    // Weak (below-uniform) shares are credited to the null word instead.
                    int receivingWid = probabilityChange >= minProbabilityChange ? sourceSegmentAndNull.get(i) : NULL_WID;
                    MutableSourceData newModelData = newModel.getMutable(receivingWid);
                    double newModelProbability = newModelData.getTranslationProbability(targetWid);
                    newModelData.setTranslationProbability(targetWid, newModelProbability + probabilityChange);
                }
            }
        }
        newModel.normalize();
        return newModel;
    }

    /**
     * Parses a translation model from text.
     *
     * <p>Each non-blank line must contain exactly three whitespace-separated
     * fields: {@code sourceWord targetWord probability}. Words are registered
     * in the given vocabularies via {@code putWord}, and the resulting model is
     * normalized and sorted. Blank or whitespace-only lines are skipped, and
     * runs of whitespace between fields are treated as a single separator.
     *
     * <p>The supplied {@code reader} is not closed; the caller owns it.
     *
     * @param reader source of the model text
     * @param sourceVocabulary vocabulary that receives source words
     * @param targetVocabulary vocabulary that receives target words
     * @return the parsed, normalized and sorted model
     * @throws ModelParseException if a line has the wrong number of fields,
     *         a probability cannot be parsed, or reading fails
     */
    public static TranslationModel parse(Reader reader, Vocabulary sourceVocabulary, Vocabulary targetVocabulary) {
        BufferedReader bufferedReader = new BufferedReader(reader);
        MutableTranslationModel translationModel = new MutableTranslationModel();
        try {
            String line;
            while ((line = bufferedReader.readLine()) != null) {
                String trimmedLine = line.trim();
                // "".split(...) yields one empty part, so blank lines must be skipped explicitly.
                if (trimmedLine.isEmpty()) {
                    continue;
                }
                String[] parts = trimmedLine.split("\\s+");
                if (parts.length != 3) {
                    throw new ModelParseException("Bad number of line parts.");
                }
                String sourceWord = parts[0];
                String targetWord = parts[1];
                double probability = Double.parseDouble(parts[2]);
                int sourceWid = sourceVocabulary.putWord(sourceWord);
                int targetWid = targetVocabulary.putWord(targetWord);
                MutableSourceData sourceData = translationModel.getMutable(sourceWid);
                sourceData.setTranslationProbability(targetWid, probability);
            }
        }
        catch (NumberFormatException e) {
            throw new ModelParseException("Part format error", e);
        }
        catch (IOException e) {
            throw new ModelParseException("IO error", e);
        }
        translationModel.normalize();
        translationModel.sort();
        return translationModel;
    }
}

