From: <jen...@us...> - 2008-02-06 14:41:48
|
Revision: 498 http://dl-learner.svn.sourceforge.net/dl-learner/?rev=498&view=rev Author: jenslehmann Date: 2008-02-06 06:41:42 -0800 (Wed, 06 Feb 2008) Log Message: ----------- finished cross validator Modified Paths: -------------- trunk/src/dl-learner/org/dllearner/algorithms/refinement/ROLearner.java trunk/src/dl-learner/org/dllearner/cli/Start.java trunk/src/dl-learner/org/dllearner/learningproblems/PosOnlyDefinitionLP.java trunk/src/dl-learner/org/dllearner/learningproblems/PosOnlyLP.java trunk/src/dl-learner/org/dllearner/utilities/CrossValidation.java trunk/src/dl-learner/org/dllearner/utilities/Datastructures.java trunk/src/dl-learner/org/dllearner/utilities/Stat.java Modified: trunk/src/dl-learner/org/dllearner/algorithms/refinement/ROLearner.java =================================================================== --- trunk/src/dl-learner/org/dllearner/algorithms/refinement/ROLearner.java 2008-02-05 15:59:48 UTC (rev 497) +++ trunk/src/dl-learner/org/dllearner/algorithms/refinement/ROLearner.java 2008-02-06 14:41:42 UTC (rev 498) @@ -11,6 +11,7 @@ import java.util.SortedSet; import java.util.TreeSet; +import org.apache.log4j.Logger; import org.dllearner.core.LearningAlgorithm; import org.dllearner.core.LearningProblem; import org.dllearner.core.ReasoningService; @@ -38,6 +39,9 @@ public class ROLearner extends LearningAlgorithm { + private static Logger logger = Logger + .getLogger(LearningAlgorithm.class); + public enum Heuristic { LEXICOGRAPHIC, FLEXIBLE } // configuration options @@ -461,7 +465,7 @@ loop++; if(!quiet) - System.out.println("--- loop " + loop + " finished ---"); + logger.debug("--- loop " + loop + " finished ---"); } @@ -864,15 +868,15 @@ // Refinementoperator auf Konzept anwenden String bestNodeString = "currently best node: " + candidatesStable.last(); // searchTree += bestNodeString + "\n"; - System.out.println(bestNodeString); + logger.info(bestNodeString); String expandedNodeString = "next expanded node: " + candidates.last(); // 
searchTree += expandedNodeString + "\n"; - System.out.println(expandedNodeString); - System.out.println("algorithm runtime " + Helper.prettyPrintNanoSeconds(algorithmRuntime)); + logger.debug(expandedNodeString); + logger.debug("algorithm runtime " + Helper.prettyPrintNanoSeconds(algorithmRuntime)); String expansionString = "horizontal expansion: " + minimumHorizontalExpansion + " to " + maximumHorizontalExpansion; // searchTree += expansionString + "\n"; - System.out.println(expansionString); - System.out.println("size of candidate set: " + candidates.size()); + logger.debug(expansionString); + logger.debug("size of candidate set: " + candidates.size()); // System.out.println("properness max recursion depth: " + maxRecDepth); // System.out.println("max. number of one-step refinements: " + maxNrOfRefinements); // System.out.println("max. number of children of a node: " + maxNrOfChildren); @@ -916,9 +920,9 @@ System.out.println("onnf time percentage: " + df.format(onnfTimePercentage) + "%"); System.out.println("shortening time percentage: " + df.format(shorteningTimePercentage) + "%"); } - System.out.println("properness tests (reasoner/short concept/too weak list): " + propernessTestsReasoner + "/" + propernessTestsAvoidedByShortConceptConstruction + logger.debug("properness tests (reasoner/short concept/too weak list): " + propernessTestsReasoner + "/" + propernessTestsAvoidedByShortConceptConstruction + "/" + propernessTestsAvoidedByTooWeakList); - System.out.println("concept tests (reasoner/too weak list/overly general list/redundant concepts): " + conceptTestsReasoner + "/" + logger.debug("concept tests (reasoner/too weak list/overly general list/redundant concepts): " + conceptTestsReasoner + "/" + conceptTestsTooWeakList + "/" + conceptTestsOverlyGeneralList + "/" + redundantConcepts); } Modified: trunk/src/dl-learner/org/dllearner/cli/Start.java =================================================================== --- 
trunk/src/dl-learner/org/dllearner/cli/Start.java 2008-02-05 15:59:48 UTC (rev 497) +++ trunk/src/dl-learner/org/dllearner/cli/Start.java 2008-02-06 14:41:42 UTC (rev 498) @@ -710,4 +710,12 @@ return la; } + public LearningProblem getLearningProblem() { + return lp; + } + + public ReasoningService getReasoningService() { + return rs; + } + } Modified: trunk/src/dl-learner/org/dllearner/learningproblems/PosOnlyDefinitionLP.java =================================================================== --- trunk/src/dl-learner/org/dllearner/learningproblems/PosOnlyDefinitionLP.java 2008-02-05 15:59:48 UTC (rev 497) +++ trunk/src/dl-learner/org/dllearner/learningproblems/PosOnlyDefinitionLP.java 2008-02-06 14:41:42 UTC (rev 498) @@ -19,18 +19,10 @@ */ package org.dllearner.learningproblems; -import java.util.Collection; -import java.util.LinkedList; -import java.util.Set; import java.util.SortedSet; import org.dllearner.core.ReasoningService; import org.dllearner.core.Score; -import org.dllearner.core.config.CommonConfigMappings; -import org.dllearner.core.config.ConfigEntry; -import org.dllearner.core.config.ConfigOption; -import org.dllearner.core.config.InvalidConfigOptionValueException; -import org.dllearner.core.config.StringSetConfigOption; import org.dllearner.core.dl.Concept; import org.dllearner.core.dl.Individual; import org.dllearner.utilities.Helper; @@ -55,47 +47,12 @@ /* * (non-Javadoc) * - * @see org.dllearner.core.Component#applyConfigEntry(org.dllearner.core.ConfigEntry) - */ - @Override - @SuppressWarnings( { "unchecked" }) - public <T> void applyConfigEntry(ConfigEntry<T> entry) throws InvalidConfigOptionValueException { - String name = entry.getOptionName(); - if (name.equals("positiveExamples")) - positiveExamples = CommonConfigMappings - .getIndividualSet((Set<String>) entry.getValue()); - } - - public static Collection<ConfigOption<?>> createConfigOptions() { - Collection<ConfigOption<?>> options = new LinkedList<ConfigOption<?>>(); - options.add(new 
StringSetConfigOption("positiveExamples", - "positive examples")); - return options; - } - - /* - * (non-Javadoc) - * * @see org.dllearner.core.Component#getName() */ public static String getName() { return "positive only definition learning problem"; } - - /** - * @return the positiveExamples - */ - public SortedSet<Individual> getPositiveExamples() { - return positiveExamples; - } - /** - * @return the pseudoNegatives - */ - public SortedSet<Individual> getPseudoNegatives() { - return pseudoNegatives; - } - /* (non-Javadoc) * @see org.dllearner.core.Component#init() */ Modified: trunk/src/dl-learner/org/dllearner/learningproblems/PosOnlyLP.java =================================================================== --- trunk/src/dl-learner/org/dllearner/learningproblems/PosOnlyLP.java 2008-02-05 15:59:48 UTC (rev 497) +++ trunk/src/dl-learner/org/dllearner/learningproblems/PosOnlyLP.java 2008-02-06 14:41:42 UTC (rev 498) @@ -19,17 +19,64 @@ */ package org.dllearner.learningproblems; +import java.util.Collection; +import java.util.LinkedList; +import java.util.Set; +import java.util.SortedSet; + import org.dllearner.core.LearningProblem; import org.dllearner.core.ReasoningService; +import org.dllearner.core.config.CommonConfigMappings; +import org.dllearner.core.config.ConfigEntry; +import org.dllearner.core.config.ConfigOption; +import org.dllearner.core.config.InvalidConfigOptionValueException; +import org.dllearner.core.config.StringSetConfigOption; +import org.dllearner.core.dl.Individual; /** + * A learning problem, where we learn from positive examples only. 
+ * * @author Jens Lehmann * */ public abstract class PosOnlyLP extends LearningProblem { + protected SortedSet<Individual> positiveExamples; + protected SortedSet<Individual> pseudoNegatives; + public PosOnlyLP(ReasoningService reasoningService) { super(reasoningService); } + /* + * (non-Javadoc) + * + * @see org.dllearner.core.Component#applyConfigEntry(org.dllearner.core.ConfigEntry) + */ + @Override + @SuppressWarnings( { "unchecked" }) + public <T> void applyConfigEntry(ConfigEntry<T> entry) throws InvalidConfigOptionValueException { + String name = entry.getOptionName(); + if (name.equals("positiveExamples")) + positiveExamples = CommonConfigMappings + .getIndividualSet((Set<String>) entry.getValue()); + } + + public static Collection<ConfigOption<?>> createConfigOptions() { + Collection<ConfigOption<?>> options = new LinkedList<ConfigOption<?>>(); + options.add(new StringSetConfigOption("positiveExamples", + "positive examples")); + return options; + } + + public SortedSet<Individual> getPositiveExamples() { + return positiveExamples; + } + + /** + * @return the pseudoNegatives + */ + public SortedSet<Individual> getPseudoNegatives() { + return pseudoNegatives; + } } Modified: trunk/src/dl-learner/org/dllearner/utilities/CrossValidation.java =================================================================== --- trunk/src/dl-learner/org/dllearner/utilities/CrossValidation.java 2008-02-05 15:59:48 UTC (rev 497) +++ trunk/src/dl-learner/org/dllearner/utilities/CrossValidation.java 2008-02-06 14:41:42 UTC (rev 498) @@ -20,16 +20,30 @@ package org.dllearner.utilities; import java.io.File; +import java.text.DecimalFormat; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Set; +import java.util.TreeSet; import org.apache.log4j.ConsoleAppender; import org.apache.log4j.Level; import org.apache.log4j.Logger; import org.apache.log4j.SimpleLayout; import org.dllearner.cli.Start; +import 
org.dllearner.core.ComponentManager; import org.dllearner.core.LearningAlgorithm; +import org.dllearner.core.LearningProblem; +import org.dllearner.core.ReasoningService; +import org.dllearner.core.dl.Concept; +import org.dllearner.core.dl.Individual; +import org.dllearner.learningproblems.PosNegLP; +import org.dllearner.learningproblems.PosOnlyLP; /** - * Performs cross validation for the given problem. + * Performs cross validation for the given problem. Supports + * k-fold cross-validation and leave-one-out cross-validation. * * @author Jens Lehmann * @@ -41,9 +55,15 @@ public static void main(String[] args) { File file = new File(args[0]); + boolean leaveOneOut = false; int folds = 10; + + // use second argument as number of folds; if not specified + // leave one out cross validation is used if(args.length > 1) folds = Integer.parseInt(args[1]); + else + leaveOneOut = true; // create logger (a simple logger which outputs // its messages to the console) @@ -51,19 +71,181 @@ ConsoleAppender consoleAppender = new ConsoleAppender(layout); logger.removeAllAppenders(); logger.addAppender(consoleAppender); - logger.setLevel(Level.INFO); + logger.setLevel(Level.WARN); - new CrossValidation(file, folds); + new CrossValidation(file, folds, leaveOneOut); } - public CrossValidation(File file, int folds) { + public CrossValidation(File file, int folds, boolean leaveOneOut) { + + DecimalFormat df = new DecimalFormat(); + ComponentManager cm = ComponentManager.getInstance(); + + // the first read of the file is used to detect the examples + // and set up the splits correctly according to our validation + // method Start start = new Start(file); - LearningAlgorithm la = start.getLearningAlgorithm(); + LearningProblem lp = start.getLearningProblem(); + ReasoningService rs = start.getReasoningService(); + + // the training and test sets used later on + List<Set<Individual>> trainingSetsPos = new LinkedList<Set<Individual>>(); + List<Set<Individual>> trainingSetsNeg = new 
LinkedList<Set<Individual>>(); + List<Set<Individual>> testSetsPos = new LinkedList<Set<Individual>>(); + List<Set<Individual>> testSetsNeg = new LinkedList<Set<Individual>>(); + + if(lp instanceof PosNegLP) { + + Set<Individual> posExamples = ((PosNegLP)lp).getPositiveExamples(); + List<Individual> posExamplesList = new LinkedList<Individual>(posExamples); + Set<Individual> negExamples = ((PosNegLP)lp).getNegativeExamples(); + List<Individual> negExamplesList = new LinkedList<Individual>(negExamples); + + // sanity check whether nr. of folds makes sense for this benchmark + if(!leaveOneOut && (posExamples.size()<folds || negExamples.size()<folds)) { + System.out.println("The number of folds is higher than the number of " + + "positive/negative examples. This can result in empty test sets. Exiting."); + System.exit(0); + } + + if(leaveOneOut) { + // note that leave-one-out is not identical to k-fold with + // k = nr. of examples in the current implementation, because + // with n folds and n examples there is no guarantee that a fold + // is never empty (this is an implementation issue) + int nrOfExamples = posExamples.size() + negExamples.size(); + for(int i = 0; i < nrOfExamples; i++) { + // ... + } + System.out.println("Leave-one-out not supported yet."); + System.exit(1); + } else { + // calculating where to split the sets, ; note that we split + // positive and negative examples separately such that the + // distribution of positive and negative examples remains similar + // (note that there better but more complex ways to implement this, + // which guarantee that the sum of the elements of a fold for pos + // and neg differs by at most 1 - it can differ by 2 in our implementation, + // e.g. with 3 folds, 4 pos. examples, 4 neg. 
examples) + int[] splitsPos = calculateSplits(posExamples.size(),folds); + int[] splitsNeg = calculateSplits(negExamples.size(),folds); + + // calculating training and test sets + for(int i=0; i<folds; i++) { + Set<Individual> testPos = getTestingSet(posExamplesList, splitsPos, i); + Set<Individual> testNeg = getTestingSet(negExamplesList, splitsNeg, i); + testSetsPos.add(i, testPos); + testSetsNeg.add(i, testNeg); + trainingSetsPos.add(i, getTrainingSet(posExamples, testPos)); + trainingSetsNeg.add(i, getTrainingSet(negExamples, testNeg)); + } + + } + + } else if(lp instanceof PosOnlyLP) { + System.out.println("Cross validation for positive only learning not supported yet."); + System.exit(0); + // Set<Individual> posExamples = ((PosOnlyLP)lp).getPositiveExamples(); + // int[] splits = calculateSplits(posExamples.size(),folds); + } else { + System.out.println("Cross validation for learning problem " + lp + " not supported."); + System.exit(0); + } + + // statistical values + Stat runtime = new Stat(); + Stat accuracy = new Stat(); + Stat length = new Stat(); + + // run the algorithm for(int currFold=0; currFold<folds; currFold++) { + // we always perform a full initialisation to make sure that + // no objects are reused + start = new Start(file); + lp = start.getLearningProblem(); + Set<String> pos = Datastructures.individualSetToStringSet(trainingSetsPos.get(currFold)); + Set<String> neg = Datastructures.individualSetToStringSet(trainingSetsNeg.get(currFold)); + cm.applyConfigEntry(lp, "positiveExamples", pos); + cm.applyConfigEntry(lp, "negativeExamples", neg); + + LearningAlgorithm la = start.getLearningAlgorithm(); + long algorithmStartTime = System.nanoTime(); la.start(); - } + long algorithmDuration = System.nanoTime() - algorithmStartTime; + runtime.addNumber(algorithmDuration/(double)1000000000); + + Concept concept = la.getBestSolution(); + int correctExamples = getCorrectPosClassified(rs, concept, testSetsPos.get(currFold)) + + 
getCorrectNegClassified(rs, concept, testSetsNeg.get(currFold)); + double currAccuracy = 100*((double)correctExamples/(testSetsPos.get(currFold).size()+ + testSetsNeg.get(currFold).size())); + accuracy.addNumber(currAccuracy); + + length.addNumber(concept.getLength()); + + System.out.println("fold " + currFold + " (" + file + "):"); + System.out.println(" concept: " + concept); + System.out.println(" accuracy: " + df.format(currAccuracy) + "%"); + System.out.println(" length: " + df.format(concept.getLength())); + System.out.println(" runtime: " + df.format(algorithmDuration/(double)1000000000) + "s"); + } + + System.out.println(); + System.out.println("Finished " + folds + "-folds cross-validation on " + file + "."); + System.out.println("runtime: " + statOutput(df, runtime, "s")); + System.out.println("length: " + statOutput(df, length, "")); + System.out.println("accuracy: " + statOutput(df, accuracy, "%")); + } + private int getCorrectPosClassified(ReasoningService rs, Concept concept, Set<Individual> posClassified) { + return rs.instanceCheck(concept, posClassified).size(); + } + + private int getCorrectNegClassified(ReasoningService rs, Concept concept, Set<Individual> negClassified) { + return negClassified.size() - rs.instanceCheck(concept, negClassified).size(); + } + + private Set<Individual> getTestingSet(List<Individual> examples, int[] splits, int fold) { + int fromIndex; + // we either start from 0 or after the last fold ended + if(fold == 0) + fromIndex = 0; + else + fromIndex = splits[fold-1]; + // the split corresponds to the ends of the folds + int toIndex = splits[fold]; + + Set<Individual> testingSet = new HashSet<Individual>(); + // +1 because 2nd element is exclusive in subList method + testingSet.addAll(examples.subList(fromIndex, toIndex)); + return testingSet; + } + + private Set<Individual> getTrainingSet(Set<Individual> examples, Set<Individual> testingSet) { + return Helper.difference(examples, testingSet); + } + + // takes nr. 
of examples and the nr. of folds for these examples; + // returns an array which says where each fold ends, i.e. + // splits[i] is the exclusive end index of fold i in the examples (one past its last element) + private int[] calculateSplits(int nrOfExamples, int folds) { + int[] splits = new int[folds]; + for(int i=1; i<=folds; i++) { + // we always round up to the next integer + splits[i-1] = (int)Math.ceil(i*nrOfExamples/(double)folds); + } + return splits; + } + + private String statOutput(DecimalFormat df, Stat stat, String unit) { + String str = "av. " + df.format(stat.getMean()) + unit; + str += " (deviation " + df.format(stat.getStandardDeviation()) + unit + "; "; + str += "min " + df.format(stat.getMin()) + unit + "; "; + str += "max " + df.format(stat.getMax()) + unit + ")"; + return str; + } + } Modified: trunk/src/dl-learner/org/dllearner/utilities/Datastructures.java =================================================================== --- trunk/src/dl-learner/org/dllearner/utilities/Datastructures.java 2008-02-05 15:59:48 UTC (rev 497) +++ trunk/src/dl-learner/org/dllearner/utilities/Datastructures.java 2008-02-06 14:41:42 UTC (rev 498) @@ -22,6 +22,7 @@ import java.util.Arrays; import java.util.Iterator; import java.util.Set; +import java.util.TreeSet; import org.dllearner.core.dl.AtomicConcept; import org.dllearner.core.dl.AtomicRole; @@ -99,4 +100,12 @@ Arrays.sort(ret); return ret; } + + public static Set<String> individualSetToStringSet(Set<Individual> individuals) { + Set<String> ret = new TreeSet<String>(); + for(Individual ind : individuals) { + ret.add(ind.toString()); + } + return ret; + } } Modified: trunk/src/dl-learner/org/dllearner/utilities/Stat.java =================================================================== --- trunk/src/dl-learner/org/dllearner/utilities/Stat.java 2008-02-05 15:59:48 UTC (rev 497) +++ trunk/src/dl-learner/org/dllearner/utilities/Stat.java 2008-02-06 14:41:42 UTC (rev 498) @@ -32,6 +32,8 @@ private int count = 0; private double sum = 
0; private double squareSum = 0; + private double min = Double.MAX_VALUE; + private double max = Double.MIN_NORMAL; /** * Add a number to this object. @@ -43,6 +45,10 @@ count++; sum += number; squareSum += number * number; + if(number<min) + min=number; + if(number>max) + max=number; } /** @@ -94,4 +100,18 @@ return root; } + /** + * @return the min + */ + public double getMin() { + return min; + } + + /** + * @return the max + */ + public double getMax() { + return max; + } + } This was sent by the SourceForge.net collaborative development platform, the world's largest Open Source development site. |