/*
 * Decompiled with CFR 0.152.
 */
package com.googlecode.clearnlp.experiment;

import com.googlecode.clearnlp.engine.EngineProcess;
import com.googlecode.clearnlp.feature.xml.POSFtrXml;
import com.googlecode.clearnlp.pos.POSLib;
import com.googlecode.clearnlp.pos.POSNode;
import com.googlecode.clearnlp.pos.POSTagger;
import com.googlecode.clearnlp.reader.POSReader;
import com.googlecode.clearnlp.run.POSTrain;
import com.googlecode.clearnlp.util.UTFile;
import com.googlecode.clearnlp.util.UTInput;
import com.googlecode.clearnlp.util.UTXml;
import java.io.FileInputStream;
import java.util.Arrays;
import org.kohsuke.args4j.Option;
import org.w3c.dom.Element;

/**
 * Command-line experiment that trains part-of-speech taggers on a training
 * directory and reports their accuracy on a directory of development files.
 *
 * <p>Two evaluations are printed: first the accuracy of each of the two
 * trained taggers (plus the per-sequence best of the two), then the accuracy
 * obtained when a cosine-similarity threshold decides which tagger is used,
 * swept over a fixed range of thresholds.
 */
public class POSDevelop
extends POSTrain {
    /** First threshold of the sweep, expressed in thousandths (10 == 0.010). */
    private static final int THRESHOLD_FIRST_MILLIS = 10;
    /** Last threshold of the sweep (inclusive), expressed in thousandths (30 == 0.030). */
    private static final int THRESHOLD_LAST_MILLIS = 30;

    @Option(name="-d", usage="the directory containing development files (input; required)", required=true, metaVar="<directory>")
    private String s_devDir;

    /** Creates an instance without parsing arguments or running anything. */
    public POSDevelop() {
    }

    /**
     * Parses the command-line arguments and immediately runs the experiment.
     *
     * @param args command-line arguments, handed to {@code initArgs}
     */
    public POSDevelop(String[] args) {
        this.initArgs(args);
        try {
            this.run(this.s_configXml, this.s_featureXml, this.s_trainDir, this.s_devDir);
        }
        catch (Exception e) {
            // Command-line boundary: report the failure instead of propagating it.
            e.printStackTrace();
        }
    }

    /**
     * Trains the taggers and prints accuracies over all development files.
     *
     * @param configXml  path of the configuration XML file
     * @param featureXml path of the feature template XML file
     * @param trnDir     directory whose (sorted) files are used for training
     * @param devDir     directory whose (sorted) files are used for evaluation
     * @throws Exception if a file cannot be opened or a downstream call fails
     */
    public void run(String configXml, String featureXml, String trnDir, String devDir) throws Exception {
        Element eConfig;
        FileInputStream configIn = new FileInputStream(configXml);
        try {
            eConfig = UTXml.getDocumentElement(configIn);
        }
        finally {
            configIn.close();
        }
        POSReader reader = (POSReader)this.getReader((Element)eConfig).o1;

        POSFtrXml xml;
        FileInputStream featureIn = new FileInputStream(featureXml);
        try {
            // NOTE(review): assumes POSFtrXml reads the whole stream inside its
            // constructor, so closing the stream afterwards is safe - confirm.
            xml = new POSFtrXml(featureIn);
        }
        finally {
            featureIn.close();
        }

        String[] trnFiles = UTFile.getSortedFileList(trnDir);
        String[] devFiles = UTFile.getSortedFileList(devDir);
        POSTagger[] taggers = this.getTrainedTaggers(eConfig, reader, xml, trnFiles, -1);

        // Pass 1: both taggers evaluated on every file.
        int[] taggerTotals = new int[4];
        for (String devFile : devFiles) {
            POSDevelop.addInto(taggerTotals, this.predict(devFile, reader, taggers));
        }
        System.out.println("Overall");
        this.printAccuracy(taggerTotals);

        // Pass 2: threshold sweep. An integer counter is used so every threshold
        // is the closest double to a clean thousandth and the inclusive end point
        // is always reached; repeatedly adding 0.001 to a double accumulates
        // rounding error.
        int[] thresholdTotals = new int[2];
        for (int millis = THRESHOLD_FIRST_MILLIS; millis <= THRESHOLD_LAST_MILLIS; ++millis) {
            double th = millis / 1000.0;
            System.out.println("Threshold: " + th);
            Arrays.fill(thresholdTotals, 0);
            for (String devFile : devFiles) {
                POSDevelop.addInto(thresholdTotals, this.predict(devFile, reader, taggers, th));
            }
            System.out.println("Overall");
            this.printAccuracy(thresholdTotals);
        }
    }

    /**
     * Tags one development file, choosing a tagger per node sequence by
     * comparing {@code threshold} with the cosine similarity reported by
     * {@code taggers[0]}: {@code taggers[0]} is used when the threshold is
     * strictly smaller than the similarity, {@code taggers[1]} otherwise.
     *
     * @param devFile   path of the development file
     * @param reader    reader used to iterate over the file; opened and closed here
     * @param taggers   at least two taggers (indices 0 and 1 are used)
     * @param threshold similarity threshold for selecting the tagger
     * @return {@code {correct tags, total tags}} for this file
     */
    protected int[] predict(String devFile, POSReader reader, POSTagger[] taggers, double threshold) {
        int[] counts = new int[]{0, 0};
        System.out.println("Predicting: " + devFile);
        reader.open(UTInput.createBufferedFileReader(devFile));
        try {
            POSNode[] nodes;
            while ((nodes = reader.next()) != null) {
                // Gold labels must be captured before tagging overwrites them.
                String[] gold = POSLib.getLabels(nodes);
                EngineProcess.normalizeForms(nodes);
                if (threshold < taggers[0].getCosineSimilarity(nodes)) {
                    taggers[0].tag(nodes);
                } else {
                    taggers[1].tag(nodes);
                }
                counts[0] += this.countCorrect(nodes, gold);
                counts[1] += gold.length;
            }
        }
        finally {
            // Close the reader even when reading or tagging fails.
            reader.close();
        }
        this.printAccuracy(counts);
        return counts;
    }

    /**
     * Tags one development file with both taggers and counts correct tags for
     * each of them.
     *
     * @param devFile path of the development file
     * @param reader  reader used to iterate over the file; opened and closed here
     * @param taggers at least two taggers (indices 0 and 1 are used)
     * @return {@code {correct by taggers[0], correct by taggers[1], sum over node
     *         sequences of the better of the two, total tags}} for this file
     */
    protected int[] predict(String devFile, POSReader reader, POSTagger[] taggers) {
        int[] counts = new int[]{0, 0, 0, 0};
        int[] correct = new int[]{0, 0};
        System.out.println("Predicting: " + devFile);
        reader.open(UTInput.createBufferedFileReader(devFile));
        try {
            POSNode[] nodes;
            while ((nodes = reader.next()) != null) {
                // Gold labels must be captured before tagging overwrites them.
                String[] gold = POSLib.getLabels(nodes);
                EngineProcess.normalizeForms(nodes);
                // Both slots of 'correct' are overwritten for every sequence.
                for (int i = 0; i < 2; ++i) {
                    taggers[i].tag(nodes);
                    correct[i] = this.countCorrect(nodes, gold);
                    counts[i] += correct[i];
                }
                counts[2] += Math.max(correct[0], correct[1]);
                counts[3] += gold.length;
            }
        }
        finally {
            // Close the reader even when reading or tagging fails.
            reader.close();
        }
        this.printAccuracy(counts);
        return counts;
    }

    /**
     * Prints one accuracy line per entry of {@code counts} except the last,
     * using the last entry as the denominator.
     *
     * @param counts correct counts followed by the total count
     */
    private void printAccuracy(int[] counts) {
        int last = counts.length - 1;
        for (int i = 0; i < last; ++i) {
            double accuracy = 100.0 * (double)counts[i] / (double)counts[last];
            System.out.printf("- accuracy %d: %7.5f (%d/%d)\n", i, accuracy, counts[i], counts[last]);
        }
    }

    /**
     * Counts the positions at which the predicted tag equals the gold tag.
     *
     * @param nodes tagged nodes; their {@code pos} field is compared
     * @param gold  gold labels, one per node
     * @return number of matching positions
     */
    private int countCorrect(POSNode[] nodes, String[] gold) {
        int correct = 0;
        for (int i = 0; i < nodes.length; ++i) {
            if (gold[i].equals(nodes[i].pos)) {
                ++correct;
            }
        }
        return correct;
    }

    /** Adds {@code part[i]} to {@code totals[i]} for every index of {@code totals}. */
    private static void addInto(int[] totals, int[] part) {
        for (int i = 0; i < totals.length; ++i) {
            totals[i] += part[i];
        }
    }

    /**
     * Program entry point; constructing the instance runs the experiment.
     *
     * @param args command-line arguments
     */
    public static void main(String[] args) {
        new POSDevelop(args);
    }
}

