From: <jen...@us...> - 2009-10-05 11:09:30
Revision: 1879
          http://dl-learner.svn.sourceforge.net/dl-learner/?rev=1879&view=rev
Author:   jenslehmann
Date:     2009-10-05 11:09:22 +0000 (Mon, 05 Oct 2009)

Log Message:
-----------
working nested cross validation script

Modified Paths:
--------------
    trunk/src/dl-learner/org/dllearner/scripts/NestedCrossValidation.java
    trunk/src/dl-learner/org/dllearner/utilities/datastructures/Datastructures.java

Modified: trunk/src/dl-learner/org/dllearner/scripts/NestedCrossValidation.java
===================================================================
--- trunk/src/dl-learner/org/dllearner/scripts/NestedCrossValidation.java	2009-10-05 10:20:13 UTC (rev 1878)
+++ trunk/src/dl-learner/org/dllearner/scripts/NestedCrossValidation.java	2009-10-05 11:09:22 UTC (rev 1879)
@@ -24,11 +24,14 @@
 import java.io.IOException;
 import java.text.DecimalFormat;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.LinkedList;
 import java.util.List;
+import java.util.Map;
 import java.util.Random;
 import java.util.Set;
 import java.util.TreeSet;
+import java.util.Map.Entry;
 
 import org.apache.log4j.ConsoleAppender;
 import org.apache.log4j.Level;
@@ -45,7 +48,9 @@
 import org.dllearner.learningproblems.PosNegLP;
 import org.dllearner.parser.ParseException;
 import org.dllearner.utilities.Helper;
+import org.dllearner.utilities.datastructures.Datastructures;
 import org.dllearner.utilities.datastructures.TrainTestList;
+import org.dllearner.utilities.statistics.Stat;
 
 import joptsimple.OptionParser;
 import joptsimple.OptionSet;
@@ -80,6 +85,10 @@
  * 
  * Currently, only learning from positive and negative examples is supported.
  * 
+ * Currently, the script can only optimise towards classification accuracy.
+ * (Can be extended to handle optimising F measure or other combinations of
+ * precision, recall, accuracy.)
+ * 
  * @author Jens Lehmann
  * 
 */
@@ -129,6 +138,7 @@
 			String[] rangeSplit = range.split("-");
 			int rangeStart = new Integer(rangeSplit[0]);
 			int rangeEnd = new Integer(rangeSplit[1]);
+			boolean verbose = options.has("v");
 
 			// create logger (a simple logger which outputs
 			// its messages to the console)
@@ -141,7 +151,8 @@
 			// disable OWL API info output
 			java.util.logging.Logger.getLogger("").setLevel(java.util.logging.Level.WARNING);
 
-			new NestedCrossValidation(confFile, outerFolds, innerFolds, parameter, rangeStart, rangeEnd);
+			System.out.println("Warning: The script is not well tested yet. (No known bugs, but needs more testing.)");
+			new NestedCrossValidation(confFile, outerFolds, innerFolds, parameter, rangeStart, rangeEnd, verbose);
 
 		// an option is missing => print help screen and message
 		} else {
@@ -151,7 +162,7 @@
 
 	}
 
-	public NestedCrossValidation(File confFile, int outerFolds, int innerFolds, String parameter, int startValue, int endValue) throws FileNotFoundException, ComponentInitException, ParseException {
+	public NestedCrossValidation(File confFile, int outerFolds, int innerFolds, String parameter, int startValue, int endValue, boolean verbose) throws FileNotFoundException, ComponentInitException, ParseException {
 
 		DecimalFormat df = new DecimalFormat();
 		ComponentManager cm = ComponentManager.getInstance();
@@ -170,18 +181,30 @@
 		LinkedList<Individual> negExamples = new LinkedList<Individual>(((PosNegLP)lp).getNegativeExamples());
 		Collections.shuffle(negExamples, new Random(2));
 
+		ReasonerComponent rc = start.getReasonerComponent();
+		String baseURI = rc.getBaseURI();
+
 		List<TrainTestList> posLists = getFolds(posExamples, outerFolds);
 		List<TrainTestList> negLists = getFolds(negExamples, outerFolds);
 
+		// overall statistics
+		Stat accOverall = new Stat();
+		Stat fOverall = new Stat();
+		Stat recallOverall = new Stat();
+		Stat precisionOverall = new Stat();
+
 		for(int currOuterFold=0; currOuterFold<outerFolds; currOuterFold++) {
 
-			System.out.println("Start processing outer fold " + currOuterFold);
+			System.out.println("Outer fold " + currOuterFold);
 			TrainTestList posList = posLists.get(currOuterFold);
 			TrainTestList negList = negLists.get(currOuterFold);
 
+			// measure relevant criterion (accuracy, F-measure) over different parameter values
+			Map<Integer,Stat> paraStats = new HashMap<Integer,Stat>();
+
 			for(int currParaValue=startValue; currParaValue<=endValue; currParaValue++) {
 
-				System.out.println("  Start Processing parameter value " + currParaValue);
+				System.out.println("  Parameter value " + currParaValue + ":");
 				// split train folds again (computation of inner folds for each parameter
 				// value is redundant, but not a big problem)
 				List<Individual> trainPosList = posList.getTrainList();
@@ -189,101 +212,156 @@
 				List<Individual> trainNegList = negList.getTrainList();
 				List<TrainTestList> innerNegLists = getFolds(trainNegList, innerFolds);
 
+				// measure relevant criterion for parameter (by default accuracy,
+				// can also be F measure)
+				Stat paraCriterionStat = new Stat();
+
 				for(int currInnerFold=0; currInnerFold<innerFolds; currInnerFold++) {
 
-					System.out.println("   Inner fold " + currInnerFold + " ... ");
+					System.out.println("   Inner fold " + currInnerFold + ":");
 					// get positive & negative examples for training run
-					List<Individual> posEx = innerPosLists.get(currInnerFold).getTrainList();
-					List<Individual> negEx = innerNegLists.get(currInnerFold).getTrainList();
+					Set<Individual> posEx = new TreeSet<Individual>(innerPosLists.get(currInnerFold).getTrainList());
+					Set<Individual> negEx = new TreeSet<Individual>(innerNegLists.get(currInnerFold).getTrainList());
 
 					// read conf file and exchange options for pos/neg examples
 					// and parameter to optimise
 					start = new Start(confFile);
-					lp = start.getLearningProblem();
-					cm.applyConfigEntry(lp, "positiveExamples", posEx);
-					cm.applyConfigEntry(lp, "negativeExamples", negEx);
-					LearningAlgorithm la = start.getLearningAlgorithm();
-					cm.applyConfigEntry(la, parameter, (double)currParaValue);
+					LearningProblem lpIn = start.getLearningProblem();
+					cm.applyConfigEntry(lpIn, "positiveExamples", Datastructures.individualSetToStringSet(posEx));
+					cm.applyConfigEntry(lpIn, "negativeExamples", Datastructures.individualSetToStringSet(negEx));
+					LearningAlgorithm laIn = start.getLearningAlgorithm();
+					cm.applyConfigEntry(laIn, parameter, (double)currParaValue);
 
-					lp.init();
-					la.init();
-					la.start();
+					lpIn.init();
+					laIn.init();
+					laIn.start();
 
 					// evaluate learned expression
-					Description concept = la.getCurrentlyBestDescription();
+					Description concept = laIn.getCurrentlyBestDescription();
 
 					TreeSet<Individual> posTest = new TreeSet<Individual>(innerPosLists.get(currInnerFold).getTestList());
 					TreeSet<Individual> negTest = new TreeSet<Individual>(innerNegLists.get(currInnerFold).getTestList());
 
 					ReasonerComponent rs = start.getReasonerComponent();
+					// true positive
 					Set<Individual> posCorrect = rs.hasType(concept, posTest);
+					// false negative
 					Set<Individual> posError = Helper.difference(posTest, posCorrect);
+					// false positive
 					Set<Individual> negError = rs.hasType(concept, negTest);
+					// true negative
 					Set<Individual> negCorrect = Helper.difference(negTest, negError);
+
+//					double posErrorRate = 100*(posError.size()/posTest.size());
+//					double negErrorRate = 100*(negError.size()/posTest.size());
 
-//					System.out.println("test set errors pos: " + tmp2);
-//					System.out.println("test set errors neg: " + tmp3);
 					double accuracy = 100*((double)(posCorrect.size()+negCorrect.size())/(posTest.size()+negTest.size()));
+					double precision = 100 * (double) posCorrect.size() / (posCorrect.size() + negError.size());
+					double recall = 100 * (double) posCorrect.size() / (posCorrect.size() + posError.size());
+					double fmeasure = 2 * (precision * recall) / (precision + recall);
 
-					System.out.println("   accuracy: " + df.format(accuracy));
+					paraCriterionStat.addNumber(accuracy);
 
+					System.out.println("    hypothesis: " + concept.toManchesterSyntaxString(baseURI, null));
+					System.out.println("    accuracy: " + df.format(accuracy) + "%");
+					System.out.println("    precision: " + df.format(precision) + "%");
+					System.out.println("    recall: " + df.format(recall) + "%");
+					System.out.println("    F measure: " + df.format(fmeasure) + "%");
+
+					if(verbose) {
+						System.out.println("    false positives (neg. examples classified as pos.): " + formatIndividualSet(negError, baseURI));
+						System.out.println("    false negatives (pos. examples classified as neg.): " + formatIndividualSet(posError, baseURI));
+					}
+
 					// free memory
 					rs.releaseKB();
 					cm.freeAllComponents();
 				}
+
+				paraStats.put(currParaValue, paraCriterionStat);
+
 			}
-		}
-
-		/*
-
-		// calculate splits using CV class
-		int[] splitsPos = CrossValidation.calculateSplits(posExamples.size(), outerFolds);
-		int[] splitsNeg = CrossValidation.calculateSplits(negExamples.size(), outerFolds);
-
-		// the training and test sets used later on
-//		List<List<Individual>> trainingSetsPos = new LinkedList<List<Individual>>();
-//		List<List<Individual>> trainingSetsNeg = new LinkedList<List<Individual>>();
-//		List<List<Individual>> testSetsPos = new LinkedList<List<Individual>>();
-//		List<List<Individual>> testSetsNeg = new LinkedList<List<Individual>>();
-
-		// calculating training and test sets for outer folds
-		for(int i=0; i<outerFolds; i++) {
+			// decide for the best parameter
+			System.out.println("  Summary over parameter values:");
+			int bestPara = startValue;
+			double bestValue = Double.NEGATIVE_INFINITY;
+			for(Entry<Integer,Stat> entry : paraStats.entrySet()) {
+				int para = entry.getKey();
+				Stat stat = entry.getValue();
+				System.out.println("    value " + para + ": " + stat.prettyPrint("%"));
+				if(stat.getMean() > bestValue) {
+					bestPara = para;
+					bestValue = stat.getMean();
+				}
+			}
+			System.out.println("  selected " + bestPara + " as best parameter value (criterion value " + df.format(bestValue) + "%)");
+			System.out.println("  Learn on Outer fold:");
+			// start a learning process with this parameter and evaluate it on the outer fold
+			start = new Start(confFile);
+			LearningProblem lpOut = start.getLearningProblem();
+			cm.applyConfigEntry(lpOut, "positiveExamples", Datastructures.individualListToStringSet(posLists.get(currOuterFold).getTrainList()));
+			cm.applyConfigEntry(lpOut, "negativeExamples", Datastructures.individualListToStringSet(negLists.get(currOuterFold).getTrainList()));
+			LearningAlgorithm laOut = start.getLearningAlgorithm();
+			cm.applyConfigEntry(laOut, parameter, (double)bestPara);
+			lpOut.init();
+			laOut.init();
+			laOut.start();
 
-			// sets for positive examples
-			int posFromIndex = (i==0) ? 0 : splitsPos[i-1];
-			int posToIndex = splitsPos[i];
-			List<Individual> testPos = posExamples.subList(posFromIndex, posToIndex);
-			List<Individual> trainPos = new LinkedList<Individual>(posExamples);
-			trainPos.removeAll(testPos);
-
-			// sets for negative examples
-			int negFromIndex = (i==0) ? 0 : splitsNeg[i-1];
-			int negToIndex = splitsNeg[i];
-			List<Individual> testNeg = posExamples.subList(negFromIndex, negToIndex);
-			List<Individual> trainNeg = new LinkedList<Individual>(negExamples);
-			trainNeg.removeAll(testNeg);
-
-			// split train folds
-			int[] innerSplitPos = CrossValidation.calculateSplits(trainPos.size(), innerFolds);
-			int[] innerSplitNeg = CrossValidation.calculateSplits(trainNeg.size(), innerFolds);
+			// evaluate learned expression
+			Description concept = laOut.getCurrentlyBestDescription();
 
-			for(int j=0; j<innerFolds; j++) {
-
-			}
+			TreeSet<Individual> posTest = new TreeSet<Individual>(posLists.get(currOuterFold).getTestList());
+			TreeSet<Individual> negTest = new TreeSet<Individual>(negLists.get(currOuterFold).getTestList());
 
-			// add to list of folds
-//			trainingSetsPos.add(trainPos);
-//			trainingSetsNeg.add(trainNeg);
-//			testSetsPos.add(testPos);
-//			testSetsNeg.add(testNeg);
-		}
+			ReasonerComponent rs = start.getReasonerComponent();
+			// true positive
+			Set<Individual> posCorrect = rs.hasType(concept, posTest);
+			// false negative
+			Set<Individual> posError = Helper.difference(posTest, posCorrect);
+			// false positive
+			Set<Individual> negError = rs.hasType(concept, negTest);
+			// true negative
+			Set<Individual> negCorrect = Helper.difference(negTest, negError);
+
+			double accuracy = 100*((double)(posCorrect.size()+negCorrect.size())/(posTest.size()+negTest.size()));
+			double precision = 100 * (double) posCorrect.size() / (posCorrect.size() + negError.size());
+			double recall = 100 * (double) posCorrect.size() / (posCorrect.size() + posError.size());
+			double fmeasure = 2 * (precision * recall) / (precision + recall);
+
+			System.out.println("  hypothesis: " + concept.toManchesterSyntaxString(baseURI, null));
+			System.out.println("  accuracy: " + df.format(accuracy) + "%");
+			System.out.println("  precision: " + df.format(precision) + "%");
+			System.out.println("  recall: " + df.format(recall) + "%");
+			System.out.println("  F measure: " + df.format(fmeasure) + "%");
+
+			if(verbose) {
+				System.out.println("  false positives (neg. examples classified as pos.): " + formatIndividualSet(negError, baseURI));
+				System.out.println("  false negatives (pos. examples classified as neg.): " + formatIndividualSet(posError, baseURI));
+			}
+
+			// update overall statistics
+			accOverall.addNumber(accuracy);
+			fOverall.addNumber(fmeasure);
+			recallOverall.addNumber(recall);
+			precisionOverall.addNumber(precision);
+
+			// free memory
+			rs.releaseKB();
+			cm.freeAllComponents();
+
+		}
 
-		*/
+		// overall statistics
+		System.out.println();
+		System.out.println("*******************");
+		System.out.println("* Overall Results *");
+		System.out.println("*******************");
+		System.out.println("accuracy: " + accOverall.prettyPrint("%"));
+		System.out.println("F measure: " + fOverall.prettyPrint("%"));
+		System.out.println("precision: " + precisionOverall.prettyPrint("%"));
+		System.out.println("recall: " + recallOverall.prettyPrint("%"));
 	}
 
@@ -302,4 +380,17 @@
 		}
 		return ret;
 	}
+
+	private static String formatIndividualSet(Set<Individual> inds, String baseURI) {
+		String ret = "";
+		int i=0;
+		for(Individual ind : inds) {
+			ret += ind.toManchesterSyntaxString(baseURI, null) + " ";
+			i++;
+			if(i==20) {
+				break;
+			}
+		}
+		return ret;
+	}
 }

Modified: trunk/src/dl-learner/org/dllearner/utilities/datastructures/Datastructures.java
===================================================================
--- trunk/src/dl-learner/org/dllearner/utilities/datastructures/Datastructures.java	2009-10-05 10:20:13 UTC (rev 1878)
+++ trunk/src/dl-learner/org/dllearner/utilities/datastructures/Datastructures.java	2009-10-05 11:09:22 UTC (rev 1879)
@@ -21,6 +21,7 @@
 
 import java.util.Arrays;
 import java.util.Iterator;
+import java.util.List;
 import java.util.Set;
 import java.util.TreeSet;
 
@@ -107,4 +108,12 @@
 		}
 		return ret;
 	}
+
+	public static Set<String> individualListToStringSet(List<Individual> individuals) {
+		Set<String> ret = new TreeSet<String>();
+		for(Individual ind : individuals) {
+			ret.add(ind.toString());
+		}
+		return ret;
+	}
 }
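As context for the evaluation step in the patch: all four reported metrics follow from the two reasoner calls per test fold (the covered positives and the covered negatives); everything else is set arithmetic. Below is a minimal, self-contained sketch of that computation. The class name, the set() helper, and the example data are illustrative only and are not part of DL-Learner; plain strings stand in for Individual objects.

import java.util.Set;
import java.util.TreeSet;

// Illustrative only: mirrors the confusion-matrix arithmetic used in
// NestedCrossValidation, with strings standing in for Individual objects.
public class EvaluationSketch {

    public static void main(String[] args) {
        // hypothetical classification outcome on one test fold
        Set<String> posTest = set("p1", "p2", "p3", "p4"); // positive test examples
        Set<String> negTest = set("n1", "n2", "n3", "n4"); // negative test examples
        Set<String> posCorrect = set("p1", "p2", "p3");    // covered positives (true positives)
        Set<String> negError = set("n1");                  // covered negatives (false positives)

        int tp = posCorrect.size();
        int fp = negError.size();
        int fn = posTest.size() - tp; // uncovered positives (false negatives)
        int tn = negTest.size() - fp; // uncovered negatives (true negatives)

        double accuracy  = 100.0 * (tp + tn) / (posTest.size() + negTest.size());
        double precision = 100.0 * tp / (tp + fp);
        double recall    = 100.0 * tp / (tp + fn);
        double fmeasure  = 2 * precision * recall / (precision + recall);

        // with the numbers above, all four metrics come out at 75%
        System.out.println("accuracy: " + accuracy + "%, precision: " + precision
                + "%, recall: " + recall + "%, F measure: " + fmeasure + "%");
    }

    private static Set<String> set(String... elements) {
        Set<String> s = new TreeSet<String>();
        for (String e : elements) {
            s.add(e);
        }
        return s;
    }
}

One caveat visible in both the sketch and the patch: precision and F measure are undefined (NaN in double arithmetic) when the learned concept covers nothing, i.e. tp + fp = 0; the script does not currently guard against that case.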
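For readers new to the technique itself, the script follows the standard nested cross-validation pattern: inner folds are used only to select a parameter value, outer folds only to estimate how well that selection procedure performs. A schematic skeleton of this control flow, independent of the DL-Learner API (trainAndEval is a dummy placeholder for a complete learning run, and all fold counts and numbers are made up):

// Schematic nested cross-validation: the inner loop picks the parameter
// value with the best mean criterion; the outer loop estimates the
// performance of the whole selection procedure.
public class NestedCVOutline {

    // dummy stand-in for a complete learning run; returns a fake accuracy in %
    static double trainAndEval(int fold, int para) {
        return 60 + Math.abs((31 * fold + 17 * para) % 40);
    }

    public static void main(String[] args) {
        int outerFolds = 3, innerFolds = 5, paraStart = 1, paraEnd = 4;
        double accSum = 0;
        for (int outer = 0; outer < outerFolds; outer++) {
            // model selection: mean criterion over inner folds per parameter value
            int bestPara = paraStart;
            double bestMean = Double.NEGATIVE_INFINITY;
            for (int para = paraStart; para <= paraEnd; para++) {
                double sum = 0;
                for (int inner = 0; inner < innerFolds; inner++) {
                    sum += trainAndEval(outer * innerFolds + inner, para);
                }
                double mean = sum / innerFolds;
                if (mean > bestMean) {
                    bestMean = mean;
                    bestPara = para;
                }
            }
            // error estimation: train with the selected parameter, test on the outer fold
            accSum += trainAndEval(-1 - outer, bestPara);
        }
        System.out.println("estimated accuracy: " + (accSum / outerFolds) + "%");
    }
}

The point of the nesting is that the outer test folds never influence the parameter choice, so the averaged outer-fold result estimates the performance of the complete pipeline (learning plus parameter tuning) rather than of one fixed parameter value.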