From: <jen...@us...> - 2011-11-02 19:33:00
|
Revision: 3363 http://dl-learner.svn.sourceforge.net/dl-learner/?rev=3363&view=rev Author: jenslehmann Date: 2011-11-02 19:32:53 +0000 (Wed, 02 Nov 2011) Log Message: ----------- cross validation for the new command line interface Modified Paths: -------------- trunk/interfaces/src/main/java/org/dllearner/cli/CLI.java Added Paths: ----------- trunk/examples/cross-validation/ trunk/examples/cross-validation/father.conf trunk/examples/cross-validation/father.owl trunk/examples/cross-validation/father.xml trunk/interfaces/src/main/java/org/dllearner/cli/CrossValidation.java Added: trunk/examples/cross-validation/father.conf =================================================================== --- trunk/examples/cross-validation/father.conf (rev 0) +++ trunk/examples/cross-validation/father.conf 2011-11-02 19:32:53 UTC (rev 3363) @@ -0,0 +1,33 @@ +/** + * Father Example + * + * possible solution: + * male AND EXISTS hasChild.TOP + * + * Copyright (C) 2007, Jens Lehmann + */ + +// perform cross validation +cli.type = "org.dllearner.cli.CLI" +cli.writeSpringConfiguration = true +cli.performCrossValidation = true +cli.nrOfFolds = 3 + +// declare some prefixes to use as abbreviations +prefixes = [ ("ex","http://example.com/father#") ] + +// knowledge source definition +ks.type = "OWL File" +ks.fileName = "father.owl" + +// reasoner +reasoner.type = "fast instance checker" +reasoner.sources = { ks } + +// learning problem +lp.type = "posNegStandard" +lp.positiveExamples = { "ex:stefan", "ex:markus", "ex:martin" } +lp.negativeExamples = { "ex:heinz", "ex:anna", "ex:michelle" } + +// create learning algorithm to run +alg.type = "ocel" Added: trunk/examples/cross-validation/father.owl =================================================================== --- trunk/examples/cross-validation/father.owl (rev 0) +++ trunk/examples/cross-validation/father.owl 2011-11-02 19:32:53 UTC (rev 3363) @@ -0,0 +1,35 @@ +<?xml version="1.0"?> +<rdf:RDF + xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#" + xmlns:xsd="http://www.w3.org/2001/XMLSchema#" + xmlns="http://example.com/father#" + xmlns:rdfs="http://www.w3.org/2000/01/rdf-schema#" + xmlns:owl="http://www.w3.org/2002/07/owl#" + xml:base="http://example.com/father"> + <owl:Ontology rdf:about=""/> + <owl:Class rdf:ID="female"/> + <owl:Class rdf:ID="male"> + <owl:equivalentClass> + <owl:Class> + <owl:complementOf rdf:resource="#female"/> + </owl:Class> + </owl:equivalentClass> + </owl:Class> + <owl:ObjectProperty rdf:ID="hasChild"/> + <male rdf:ID="markus"> + <hasChild> + <female rdf:ID="anna"> + <hasChild> + <male rdf:ID="heinz"/> + </hasChild> + </female> + </hasChild> + </male> + <male rdf:ID="stefan"> + <hasChild rdf:resource="#markus"/> + </male> + <female rdf:ID="michelle"/> + <male rdf:ID="martin"> + <hasChild rdf:resource="#heinz"/> + </male> +</rdf:RDF> Added: trunk/examples/cross-validation/father.xml =================================================================== --- trunk/examples/cross-validation/father.xml (rev 0) +++ trunk/examples/cross-validation/father.xml 2011-11-02 19:32:53 UTC (rev 3363) @@ -0,0 +1,32 @@ +<beans xmlns="http://www.springframework.org/schema/beans"> + <bean class="org.dllearner.algorithms.ocel.OCEL" name="alg"/> + <bean class="org.dllearner.kb.OWLFile" name="ks"> + <property name="fileName" value="father.owl"/> + </bean> + <bean class="org.dllearner.learningproblems.PosNegLPStandard" name="lp"> + <property name="positiveExamples"> + <set> + <value>http://example.com/father#markus</value> + <value>http://example.com/father#stefan</value> + <value>http://example.com/father#martin</value> + </set> + </property> + <property name="negativeExamples"> + <set> + <value>http://example.com/father#heinz</value> + <value>http://example.com/father#anna</value> + <value>http://example.com/father#michelle</value> + </set> + </property> + </bean> + <bean class="org.dllearner.cli.CLI" name="cli"> + <property name="writeSpringConfiguration"/> + </bean> + <bean class="org.dllearner.reasoning.FastInstanceChecker" name="reasoner"> + <property name="sources"> + <set> + <ref bean="ks"/> + </set> + </property> + </bean> +</beans> \ No newline at end of file Modified: trunk/interfaces/src/main/java/org/dllearner/cli/CLI.java =================================================================== --- trunk/interfaces/src/main/java/org/dllearner/cli/CLI.java 2011-11-02 14:37:35 UTC (rev 3362) +++ trunk/interfaces/src/main/java/org/dllearner/cli/CLI.java 2011-11-02 19:32:53 UTC (rev 3363) @@ -38,8 +38,10 @@ import org.dllearner.confparser3.ConfParserConfiguration; import org.dllearner.confparser3.ParseException; import org.dllearner.core.AbstractCELA; +import org.dllearner.core.AbstractReasonerComponent; import org.dllearner.core.LearningAlgorithm; import org.dllearner.core.ReasoningMethodUnsupportedException; +import org.dllearner.learningproblems.PosNegLP; import org.dllearner.utilities.Files; import org.springframework.context.ApplicationContext; import org.springframework.core.io.FileSystemResource; @@ -57,59 +59,65 @@ private static Logger logger = Logger.getLogger(CLI.class); private static Logger rootLogger = Logger.getRootLogger(); - private boolean writeSpringConfiguration = false; private ApplicationContext context; + private File confFile; - public CLI(File file) throws IOException{ - Resource confFile = new FileSystemResource(file); + // some CLI options + private boolean writeSpringConfiguration = false; + private boolean performCrossValidation = false; + private int nrOfFolds = 10; + + public CLI() { - List<Resource> springConfigResources = new ArrayList<Resource>(); + } + + public CLI(File confFile) { + this(); + this.confFile = confFile; + } + + public void run() throws IOException { // ApplicationContext context, String algorithmBeanName){ + + IConfiguration configuration = null; + + if(context == null) { + Resource confFileR = new FileSystemResource(confFile); + List<Resource> springConfigResources = new ArrayList<Resource>(); + configuration = new ConfParserConfiguration(confFileR); - //DL-Learner Configuration Object - IConfiguration configuration = new ConfParserConfiguration(confFile); - - ApplicationContextBuilder builder = new DefaultApplicationContextBuilder(); - context = builder.buildApplicationContext(configuration,springConfigResources); - - // a lot of debugging stuff -// FastInstanceChecker fi = context.getBean("reasoner", FastInstanceChecker.class); -// System.out.println(fi.getClassHierarchy()); -// NamedClass male = new NamedClass("http://localhost/foo#male"); -// System.out.println(fi.getIndividuals(new NamedClass("http://localhost/foo#male"))); -// System.out.println(fi.getIndividuals().size()); -// System.out.println("has type: " + fi.hasTypeImpl(male, new Individual("http://localhost/foo#bernd"))); -// -// PosNegLPStandard lp = context.getBean("lp", PosNegLPStandard.class); -// System.out.println(lp.getPositiveExamples()); -// System.out.println(lp.getNegativeExamples()); -// System.out.println(lp.getAccuracy(new NamedClass("http://localhost/foo#male"))); - - // get a CLI bean if it exists - CLI cli = null; - if(context.getBeansOfType(CLI.class).size()>0) { - System.out.println(); - cli = context.getBean(CLI.class); + ApplicationContextBuilder builder = new DefaultApplicationContextBuilder(); + ApplicationContext context = builder.buildApplicationContext(configuration,springConfigResources); + } + + if(writeSpringConfiguration) { SpringConfigurationXMLBeanConverter converter = new SpringConfigurationXMLBeanConverter(); - XmlObject xml = converter.convert(configuration); - String springFilename = file.getCanonicalPath().replace(".conf", ".xml"); + XmlObject xml; + if(configuration == null) { + Resource confFileR = new FileSystemResource(confFile); + configuration = new ConfParserConfiguration(confFileR); + xml = converter.convert(configuration); + } else { + xml = converter.convert(configuration); + } + String springFilename = confFile.getCanonicalPath().replace(".conf", ".xml"); File springFile = new File(springFilename); if(springFile.exists()) { logger.warn("Cannot write Spring configuration, because " + springFilename + " already exists."); } else { Files.createFile(springFile, xml.toString()); - } -// SpringConfigurationXMLBeanConverter converter; - } - - // start algorithm in conf file -// LearningAlgorithm algorithm = context.getBean("alg",LearningAlgorithm.class); -// algorithm.start(); - } + } + } + + if(performCrossValidation) { + AbstractReasonerComponent rs = context.getBean(AbstractReasonerComponent.class); + PosNegLP lp = context.getBean(PosNegLP.class); + AbstractCELA la = context.getBean(AbstractCELA.class); + new CrossValidation(la,lp,rs,nrOfFolds,false); + } else { + LearningAlgorithm algorithm = context.getBean(LearningAlgorithm.class); + algorithm.start(); + } - public void run() { // ApplicationContext context, String algorithmBeanName){ - LearningAlgorithm algorithm = context.getBean(LearningAlgorithm.class); -// LearningAlgorithm algorithm = context.getBean(algorithmBeanName, LearningAlgorithm.class); - algorithm.start(); } public boolean isWriteSpringConfiguration() { @@ -150,12 +158,61 @@ System.exit(0); } - CLI cli = new CLI(file); + Resource confFile = new FileSystemResource(file); + + List<Resource> springConfigResources = new ArrayList<Resource>(); + + //DL-Learner Configuration Object + IConfiguration configuration = new ConfParserConfiguration(confFile); + + ApplicationContextBuilder builder = new DefaultApplicationContextBuilder(); + ApplicationContext context = builder.buildApplicationContext(configuration,springConfigResources); + + // TODO: later we could check which command line interface is specified in the conf file + // for now we just use the default one + + CLI cli; + if(context.containsBean("cli")) { + cli = (CLI) context.getBean("cli"); + } else { + cli = new CLI(); + } + cli.setContext(context); + cli.setConfFile(file); cli.run(); + } + public void setContext(ApplicationContext context) { + this.context = context; + } + public ApplicationContext getContext() { return context; } + public File getConfFile() { + return confFile; + } + + public void setConfFile(File confFile) { + this.confFile = confFile; + } + + public boolean isPerformCrossValidation() { + return performCrossValidation; + } + + public void setPerformCrossValidation(boolean performCrossValiation) { + this.performCrossValidation = performCrossValiation; + } + + public int getNrOfFolds() { + return nrOfFolds; + } + + public void setNrOfFolds(int nrOfFolds) { + this.nrOfFolds = nrOfFolds; + } + } Added: trunk/interfaces/src/main/java/org/dllearner/cli/CrossValidation.java =================================================================== --- trunk/interfaces/src/main/java/org/dllearner/cli/CrossValidation.java (rev 0) +++ trunk/interfaces/src/main/java/org/dllearner/cli/CrossValidation.java 2011-11-02 19:32:53 UTC (rev 3363) @@ -0,0 +1,286 @@ +/** + * Copyright (C) 2007-2008, Jens Lehmann + * + * This file is part of DL-Learner. + * + * DL-Learner is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * DL-Learner is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + */ +package org.dllearner.cli; + +import java.io.File; +import java.io.FileNotFoundException; +import java.text.DecimalFormat; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Random; +import java.util.Set; + +import org.apache.log4j.ConsoleAppender; +import org.apache.log4j.Level; +import org.apache.log4j.Logger; +import org.apache.log4j.SimpleLayout; +import org.dllearner.core.ComponentInitException; +import org.dllearner.core.ComponentManager; +import org.dllearner.core.AbstractCELA; +import org.dllearner.core.AbstractLearningProblem; +import org.dllearner.core.AbstractReasonerComponent; +import org.dllearner.core.owl.Description; +import org.dllearner.core.owl.Individual; +import org.dllearner.learningproblems.Heuristics; +import org.dllearner.learningproblems.PosNegLP; +import org.dllearner.learningproblems.PosOnlyLP; +import org.dllearner.utilities.Helper; +import org.dllearner.utilities.datastructures.Datastructures; +import org.dllearner.utilities.statistics.Stat; +import org.dllearner.utilities.Files; + +/** + * Performs cross validation for the given problem. Supports + * k-fold cross-validation and leave-one-out cross-validation. + * + * @author Jens Lehmann + * + */ +public class CrossValidation { + + private static Logger logger = Logger.getRootLogger(); + + // statistical values + private Stat runtime = new Stat(); + private Stat accuracy = new Stat(); + private Stat length = new Stat(); + private Stat accuracyTraining = new Stat(); + private Stat fMeasure = new Stat(); + private Stat fMeasureTraining = new Stat(); + private static boolean writeToFile = false; + private static File outputFile; + + public CrossValidation(AbstractCELA la, PosNegLP lp, AbstractReasonerComponent rs, int folds, boolean leaveOneOut) { + + DecimalFormat df = new DecimalFormat(); + + // the training and test sets used later on + List<Set<Individual>> trainingSetsPos = new LinkedList<Set<Individual>>(); + List<Set<Individual>> trainingSetsNeg = new LinkedList<Set<Individual>>(); + List<Set<Individual>> testSetsPos = new LinkedList<Set<Individual>>(); + List<Set<Individual>> testSetsNeg = new LinkedList<Set<Individual>>(); + + // get examples and shuffle them too + Set<Individual> posExamples = ((PosNegLP)lp).getPositiveExamples(); + List<Individual> posExamplesList = new LinkedList<Individual>(posExamples); + Collections.shuffle(posExamplesList, new Random(1)); + Set<Individual> negExamples = ((PosNegLP)lp).getNegativeExamples(); + List<Individual> negExamplesList = new LinkedList<Individual>(negExamples); + Collections.shuffle(negExamplesList, new Random(2)); + + // sanity check whether nr. of folds makes sense for this benchmark + if(!leaveOneOut && (posExamples.size()<folds && negExamples.size()<folds)) { + System.out.println("The number of folds is higher than the number of " + + "positive/negative examples. This can result in empty test sets. Exiting."); + System.exit(0); + } + + if(leaveOneOut) { + // note that leave-one-out is not identical to k-fold with + // k = nr. of examples in the current implementation, because + // with n folds and n examples there is no guarantee that a fold + // is never empty (this is an implementation issue) + int nrOfExamples = posExamples.size() + negExamples.size(); + for(int i = 0; i < nrOfExamples; i++) { + // ... + } + System.out.println("Leave-one-out not supported yet."); + System.exit(1); + } else { + // calculating where to split the sets, ; note that we split + // positive and negative examples separately such that the + // distribution of positive and negative examples remains similar + // (note that there are better but more complex ways to implement this, + // which guarantee that the sum of the elements of a fold for pos + // and neg differs by at most 1 - it can differ by 2 in our implementation, + // e.g. with 3 folds, 4 pos. examples, 4 neg. examples) + int[] splitsPos = calculateSplits(posExamples.size(),folds); + int[] splitsNeg = calculateSplits(negExamples.size(),folds); + +// System.out.println(splitsPos[0]); +// System.out.println(splitsNeg[0]); + + // calculating training and test sets + for(int i=0; i<folds; i++) { + Set<Individual> testPos = getTestingSet(posExamplesList, splitsPos, i); + Set<Individual> testNeg = getTestingSet(negExamplesList, splitsNeg, i); + testSetsPos.add(i, testPos); + testSetsNeg.add(i, testNeg); + trainingSetsPos.add(i, getTrainingSet(posExamples, testPos)); + trainingSetsNeg.add(i, getTrainingSet(negExamples, testNeg)); + } + + } + + // run the algorithm + for(int currFold=0; currFold<folds; currFold++) { + + Set<String> pos = Datastructures.individualSetToStringSet(trainingSetsPos.get(currFold)); + Set<String> neg = Datastructures.individualSetToStringSet(trainingSetsNeg.get(currFold)); + lp.setPositiveExamples(trainingSetsPos.get(currFold)); + lp.setNegativeExamples(trainingSetsNeg.get(currFold)); + + try { + lp.init(); + la.init(); + } catch (ComponentInitException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } + + long algorithmStartTime = System.nanoTime(); + la.start(); + long algorithmDuration = System.nanoTime() - algorithmStartTime; + runtime.addNumber(algorithmDuration/(double)1000000000); + + Description concept = la.getCurrentlyBestDescription(); + + Set<Individual> tmp = rs.hasType(concept, testSetsPos.get(currFold)); + Set<Individual> tmp2 = Helper.difference(testSetsPos.get(currFold), tmp); + Set<Individual> tmp3 = rs.hasType(concept, testSetsNeg.get(currFold)); + + outputWriter("test set errors pos: " + tmp2); + outputWriter("test set errors neg: " + tmp3); + + // calculate training accuracies + int trainingCorrectPosClassified = getCorrectPosClassified(rs, concept, trainingSetsPos.get(currFold)); + int trainingCorrectNegClassified = getCorrectNegClassified(rs, concept, trainingSetsNeg.get(currFold)); + int trainingCorrectExamples = trainingCorrectPosClassified + trainingCorrectNegClassified; + double trainingAccuracy = 100*((double)trainingCorrectExamples/(trainingSetsPos.get(currFold).size()+ + trainingSetsNeg.get(currFold).size())); + accuracyTraining.addNumber(trainingAccuracy); + // calculate test accuracies + int correctPosClassified = getCorrectPosClassified(rs, concept, testSetsPos.get(currFold)); + int correctNegClassified = getCorrectNegClassified(rs, concept, testSetsNeg.get(currFold)); + int correctExamples = correctPosClassified + correctNegClassified; + double currAccuracy = 100*((double)correctExamples/(testSetsPos.get(currFold).size()+ + testSetsNeg.get(currFold).size())); + accuracy.addNumber(currAccuracy); + // calculate training F-Score + int negAsPosTraining = rs.hasType(concept, trainingSetsNeg.get(currFold)).size(); + double precisionTraining = trainingCorrectPosClassified + negAsPosTraining == 0 ? 0 : trainingCorrectPosClassified / (double) (trainingCorrectPosClassified + negAsPosTraining); + double recallTraining = trainingCorrectPosClassified / (double) trainingSetsPos.get(currFold).size(); + fMeasureTraining.addNumber(100*Heuristics.getFScore(recallTraining, precisionTraining)); + // calculate test F-Score + int negAsPos = rs.hasType(concept, testSetsNeg.get(currFold)).size(); + double precision = correctPosClassified + negAsPos == 0 ? 0 : correctPosClassified / (double) (correctPosClassified + negAsPos); + double recall = correctPosClassified / (double) testSetsPos.get(currFold).size(); +// System.out.println(precision);System.out.println(recall); + fMeasure.addNumber(100*Heuristics.getFScore(recall, precision)); + + length.addNumber(concept.getLength()); + + outputWriter("fold " + currFold + ":"); + outputWriter(" training: " + pos.size() + " positive and " + neg.size() + " negative examples"); + outputWriter(" testing: " + correctPosClassified + "/" + testSetsPos.get(currFold).size() + " correct positives, " + + correctNegClassified + "/" + testSetsNeg.get(currFold).size() + " correct negatives"); + outputWriter(" concept: " + concept); + outputWriter(" accuracy: " + df.format(currAccuracy) + "% (" + df.format(trainingAccuracy) + "% on training set)"); + outputWriter(" length: " + df.format(concept.getLength())); + outputWriter(" runtime: " + df.format(algorithmDuration/(double)1000000000) + "s"); + + } + + outputWriter(""); + outputWriter("Finished " + folds + "-folds cross-validation."); + outputWriter("runtime: " + statOutput(df, runtime, "s")); + outputWriter("length: " + statOutput(df, length, "")); + outputWriter("F-Measure on training set: " + statOutput(df, fMeasureTraining, "%")); + outputWriter("F-Measure: " + statOutput(df, fMeasure, "%")); + outputWriter("predictive accuracy on training set: " + statOutput(df, accuracyTraining, "%")); + outputWriter("predictive accuracy: " + statOutput(df, accuracy, "%")); + + } + + private int getCorrectPosClassified(AbstractReasonerComponent rs, Description concept, Set<Individual> testSetPos) { + return rs.hasType(concept, testSetPos).size(); + } + + private int getCorrectNegClassified(AbstractReasonerComponent rs, Description concept, Set<Individual> testSetNeg) { + return testSetNeg.size() - rs.hasType(concept, testSetNeg).size(); + } + + public static Set<Individual> getTestingSet(List<Individual> examples, int[] splits, int fold) { + int fromIndex; + // we either start from 0 or after the last fold ended + if(fold == 0) + fromIndex = 0; + else + fromIndex = splits[fold-1]; + // the split corresponds to the ends of the folds + int toIndex = splits[fold]; + +// System.out.println("from " + fromIndex + " to " + toIndex); + + Set<Individual> testingSet = new HashSet<Individual>(); + // +1 because 2nd element is exclusive in subList method + testingSet.addAll(examples.subList(fromIndex, toIndex)); + return testingSet; + } + + public static Set<Individual> getTrainingSet(Set<Individual> examples, Set<Individual> testingSet) { + return Helper.difference(examples, testingSet); + } + + // takes nr. of examples and the nr. of folds for this examples; + // returns an array which says where each fold ends, i.e. + // splits[i] is the index of the last element of fold i in the examples + public static int[] calculateSplits(int nrOfExamples, int folds) { + int[] splits = new int[folds]; + for(int i=1; i<=folds; i++) { + // we always round up to the next integer + splits[i-1] = (int)Math.ceil(i*nrOfExamples/(double)folds); + } + return splits; + } + + public static String statOutput(DecimalFormat df, Stat stat, String unit) { + String str = "av. " + df.format(stat.getMean()) + unit; + str += " (deviation " + df.format(stat.getStandardDeviation()) + unit + "; "; + str += "min " + df.format(stat.getMin()) + unit + "; "; + str += "max " + df.format(stat.getMax()) + unit + ")"; + return str; + } + + public Stat getAccuracy() { + return accuracy; + } + + public Stat getLength() { + return length; + } + + public Stat getRuntime() { + return runtime; + } + + private void outputWriter(String output) { + if(writeToFile) { + Files.appendFile(outputFile, output +"\n"); + System.out.println(output); + } else { + System.out.println(output); + } + + } + +} This was sent by the SourceForge.net collaborative development platform, the world's largest Open Source development site. |