From: <lor...@us...> - 2014-05-08 14:40:17
|
Revision: 4265 http://sourceforge.net/p/dl-learner/code/4265 Author: lorenz_b Date: 2014-05-08 14:40:15 +0000 (Thu, 08 May 2014) Log Message: ----------- Extended CV script. Modified Paths: -------------- trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/QTL2Disjunctive.java trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/QueryTreeHeuristic.java trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/datastructures/impl/QueryTreeImpl.java trunk/components-core/src/main/java/org/dllearner/utilities/statistics/Stat.java trunk/scripts/src/main/java/org/dllearner/scripts/NestedCrossValidation.java trunk/scripts/src/main/java/org/dllearner/scripts/evaluation/QTLEvaluation.java Modified: trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/QTL2Disjunctive.java =================================================================== --- trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/QTL2Disjunctive.java 2014-05-08 14:39:43 UTC (rev 4264) +++ trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/QTL2Disjunctive.java 2014-05-08 14:40:15 UTC (rev 4265) @@ -17,6 +17,7 @@ import java.util.SortedSet; import java.util.TreeSet; +import org.apache.commons.collections.ListUtils; import org.apache.log4j.Logger; import org.dllearner.algorithms.qtl.cache.QueryTreeCache; import org.dllearner.algorithms.qtl.datastructures.QueryTree; @@ -69,7 +70,7 @@ private Queue<EvaluatedQueryTree<String>> todoList; private SortedSet<EvaluatedQueryTree<String>> currentPartialSolutions; - private double currentlyBestScore = 0d; + private double bestCurrentScore = 0d; private List<QueryTree<String>> currentPosExampleTrees; private List<QueryTree<String>> currentNegExampleTrees; @@ -114,11 +115,17 @@ private double posWeight = 2; // minimum score a query tree must have to be part of the solution private double minimumTreeScore = 0.2; + //If yes, then the algorithm tries to cover all positive examples. Note that while this improves accuracy on the testing set, + //it may lead to overfitting + private boolean tryFullCoverage; + //algorithm will terminate immediately when a correct definition is found + private boolean stopOnFirstDefinition; private long startTime; - private long partialSolutionStartTime; + private double startPosExamplesSize; + public QTL2Disjunctive() {} public QTL2Disjunctive(PosNegLP learningProblem, AbstractReasonerComponent reasoner) throws LearningProblemUnsupportedException{ @@ -156,7 +163,7 @@ if(heuristic == null){ heuristic = new QueryTreeHeuristic(); - heuristic.setPosExamplesWeight(2); + heuristic.setPosExamplesWeight(posWeight); } logger.info("Initializing..."); @@ -168,6 +175,8 @@ currentPosExamples = new TreeSet<Individual>(lp.getPositiveExamples()); currentNegExamples = new TreeSet<Individual>(lp.getNegativeExamples()); + startPosExamplesSize = currentPosExamples.size(); + //get the query trees generateTrees(); @@ -178,6 +187,15 @@ //console rendering of class expressions ToStringRenderer.getInstance().setRenderer(new ManchesterOWLSyntaxOWLObjectRendererImpl()); ToStringRenderer.getInstance().setShortFormProvider(new SimpleShortFormProvider()); + + //compute the LGG for all examples + //this allows us to prune all other trees because we can omit paths in trees which are contained in all positive + //as well as negative examples +// List<QueryTree<String>> allExamplesTrees = new ArrayList<QueryTree<String>>(); +// allExamplesTrees.addAll(currentPosExampleTrees); +// allExamplesTrees.addAll(currentNegExampleTrees); +// QueryTree<String> lgg = lggGenerator.getLGG(allExamplesTrees); +// lgg.dump(); } private void generateTrees(){ @@ -204,7 +222,7 @@ String setup = "Setup:"; setup += "\n#Pos. examples:" + currentPosExamples.size(); setup += "\n#Neg. examples:" + currentNegExamples.size(); - setup += "\nCoverage beta:" + coverageBeta; + setup += "\nPos. weight(beta):" + posWeight; logger.info(setup); logger.info("Running..."); startTime = System.currentTimeMillis(); @@ -268,7 +286,7 @@ private void computeNextPartialSolution(){ logger.info("Computing best partial solution..."); - currentlyBestScore = 0d; + bestCurrentScore = Double.NEGATIVE_INFINITY; partialSolutionStartTime = System.currentTimeMillis(); initTodoList(currentPosExampleTrees, currentNegExampleTrees); @@ -291,14 +309,14 @@ double score = solution.getScore(); double mas = heuristic.getMaximumAchievableScore(solution); - if(score >= currentlyBestScore){ + if(score >= bestCurrentScore){ //add to todo list, if not already contained in todo list or solution list todo(solution); - if(solution.getScore() > currentlyBestScore){ + if(solution.getScore() > bestCurrentScore){ logger.info("Got better solution:" + solution.getTreeScore()); } - currentlyBestScore = solution.getScore(); - } else if(mas < currentlyBestScore){ + bestCurrentScore = solution.getScore(); + } else if(mas < bestCurrentScore){ todo(solution); } else { System.out.println("Too general"); @@ -431,6 +449,8 @@ subMon.reset(); lggMon.reset(); + + bestCurrentScore = minimumTreeScore; } @@ -578,8 +598,31 @@ return tree1.isSubsumedBy(tree2) && tree2.isSubsumedBy(tree1); } - private boolean terminationCriteriaSatisfied(){ - return stop || isTimeExpired() || currentPosExampleTrees.isEmpty(); + private boolean terminationCriteriaSatisfied() { + //stop was called or time expired + if(stop || isTimeExpired()){ + return true; + } + + // stop if there are no more positive examples to cover + if (stopOnFirstDefinition && currentPosExamples.isEmpty()) { + return true; + } + + // we stop when the score of the last tree added is too low + // (indicating that the algorithm could not find anything appropriate + // in the timeframe set) + if (bestCurrentScore < minimumTreeScore) { + return true; + } + + // stop when almost all positive examples have been covered + if (tryFullCoverage) { + return false; + } else { + int maxPosRemaining = (int) Math.ceil(startPosExamplesSize * 0.05d); + return (currentPosExamples.size() <= maxPosRemaining); + } } private boolean partialSolutionTerminationCriteriaSatisfied(){ @@ -635,6 +678,13 @@ this.coverageBeta = coverageBeta; } + /** + * @param posWeight the posWeight to set + */ + public void setPosWeight(double posWeight) { + this.posWeight = posWeight; + } + /* (non-Javadoc) * @see java.lang.Object#clone() */ Modified: trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/QueryTreeHeuristic.java =================================================================== --- trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/QueryTreeHeuristic.java 2014-05-08 14:39:43 UTC (rev 4264) +++ trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/QueryTreeHeuristic.java 2014-05-08 14:40:15 UTC (rev 4265) @@ -63,7 +63,7 @@ case FMEASURE : score = Heuristics.getFScore(tp/(tp+fn), tp/(tp+fp), posExamplesWeight);break; case PRED_ACC : - score = (posExamplesWeight * tp + tn) / (posExamplesWeight * (tp + fn) + tn + fp);break; + score = (tp + posExamplesWeight * tn) / ((tp + fn) + posExamplesWeight * (tn + fp));break; case ENTROPY :{ double total = tp + fn; double pp = tp / total; Modified: trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/datastructures/impl/QueryTreeImpl.java =================================================================== --- trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/datastructures/impl/QueryTreeImpl.java 2014-05-08 14:39:43 UTC (rev 4264) +++ trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/datastructures/impl/QueryTreeImpl.java 2014-05-08 14:40:15 UTC (rev 4265) @@ -161,7 +161,7 @@ label += "Values: " + object.getLiterals(); } } - label += object.isResourceNode() + "," + object.isLiteralNode(); +// label += object.isResourceNode() + "," + object.isLiteralNode(); return label; } }; @@ -801,13 +801,16 @@ writer.println(ren); for (QueryTree<N> child : getChildren()) { Object edge = getEdge(child); - if (edge != null) { + boolean meaningful = !edge.equals(RDF.type.getURI()) || meaningful(child); + if (edge != null && meaningful) { writer.print(sb.toString()); writer.print("--- "); writer.print(edge); writer.print(" ---\n"); } - child.dump(writer, indent); + if(meaningful){ + child.dump(writer, indent); + } } writer.flush(); // int depth = getPathToRoot().size(); @@ -832,6 +835,23 @@ // writer.flush(); } + private boolean meaningful(QueryTree<N> tree){ + if(tree.isResourceNode() || tree.isLiteralNode()){ + return true; + } else { + for (QueryTree<N> child : tree.getChildren()) { + Object edge = tree.getEdge(child); + if(!edge.equals(RDFS.subClassOf.getURI())){ + return true; + } else if(child.isResourceNode()){ + return true; + } else if(meaningful(child)){ + return true; + } + } + } + return false; + } public List<N> fillDepthFirst() { List<N> results = new ArrayList<N>(); Modified: trunk/components-core/src/main/java/org/dllearner/utilities/statistics/Stat.java =================================================================== --- trunk/components-core/src/main/java/org/dllearner/utilities/statistics/Stat.java 2014-05-08 14:39:43 UTC (rev 4264) +++ trunk/components-core/src/main/java/org/dllearner/utilities/statistics/Stat.java 2014-05-08 14:40:15 UTC (rev 4265) @@ -75,6 +75,14 @@ } } + public void add(Stat stat){ + count += stat.count; + sum += stat.sum; + squareSum += stat.squareSum; + min = Math.min(min, stat.min); + max = Math.max(max, stat.max); + } + /** * Add a number to this object. * Modified: trunk/scripts/src/main/java/org/dllearner/scripts/NestedCrossValidation.java =================================================================== --- trunk/scripts/src/main/java/org/dllearner/scripts/NestedCrossValidation.java 2014-05-08 14:39:43 UTC (rev 4264) +++ trunk/scripts/src/main/java/org/dllearner/scripts/NestedCrossValidation.java 2014-05-08 14:40:15 UTC (rev 4265) @@ -26,6 +26,7 @@ import java.io.IOException; import java.lang.reflect.InvocationTargetException; import java.text.DecimalFormat; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.LinkedList; @@ -60,6 +61,7 @@ import org.dllearner.utilities.statistics.Stat; import com.google.common.base.Charsets; +import com.google.common.collect.Lists; import com.google.common.io.Files; /** @@ -100,6 +102,15 @@ public class NestedCrossValidation { private File outputFile = new File("log/nested-cv.log"); + DecimalFormat df = new DecimalFormat(); + + // overall statistics + Stat globalAcc = new Stat(); + Stat globalF = new Stat(); + Stat globalRecall = new Stat(); + Stat globalPrecision = new Stat(); + + Map<Double,Stat> globalParaStats = new HashMap<Double,Stat>(); /** * Entry method, which uses JOptSimple to parse parameters. @@ -115,8 +126,7 @@ OptionParser parser = new OptionParser(); parser.acceptsAll(asList("h", "?", "help"), "Show help."); - parser.acceptsAll(asList("c", "conf"), "Conf file to use.").withRequiredArg().ofType( - File.class); + parser.acceptsAll(asList("c", "conf"), "The comma separated list of conffiles to be used.").withRequiredArg().describedAs("file1, file2, ..."); parser.acceptsAll(asList( "v", "verbose"), "Be more verbose."); parser.acceptsAll(asList( "o", "outerfolds"), "Number of outer folds.").withRequiredArg().ofType(Integer.class).describedAs("#folds"); parser.acceptsAll(asList( "i", "innerfolds"), "Number of inner folds.").withRequiredArg().ofType(Integer.class).describedAs("#folds"); @@ -139,7 +149,12 @@ // all options present => start nested cross validation } else if(options.has("c") && options.has("o") && options.has("i") && options.has("p") && options.has("r")) { // read all options in variables and parse option values - File confFile = (File) options.valueOf("c"); + String confFilesString = (String) options.valueOf("c"); + List<File> confFiles = new ArrayList<File>(); + for (String fileString : confFilesString.split(",")) { + confFiles.add(new File(fileString.trim())); + } + int outerFolds = (Integer) options.valueOf("o"); int innerFolds = (Integer) options.valueOf("i"); String parameter = (String) options.valueOf("p"); @@ -164,7 +179,7 @@ java.util.logging.Logger.getLogger("").setLevel(java.util.logging.Level.WARNING); System.out.println("Warning: The script is not well tested yet. (No known bugs, but needs more testing.)"); - new NestedCrossValidation(confFile, outerFolds, innerFolds, parameter, rangeStart, rangeEnd, stepsize, verbose); + new NestedCrossValidation(confFiles, outerFolds, innerFolds, parameter, rangeStart, rangeEnd, stepsize, verbose); // an option is missing => print help screen and message } else { @@ -182,16 +197,52 @@ } System.out.println(s); } + + public NestedCrossValidation(File confFile, int outerFolds, int innerFolds, String parameter, double startValue, double endValue, double stepsize, boolean verbose) throws ComponentInitException, ParseException, org.dllearner.confparser.ParseException, IOException { + this(Lists.newArrayList(confFile), outerFolds, innerFolds, parameter, startValue, endValue, stepsize, verbose); + } - public NestedCrossValidation(File confFile, int outerFolds, int innerFolds, String parameter, double startValue, double endValue, double stepsize, boolean verbose) throws ComponentInitException, ParseException, org.dllearner.confparser.ParseException, IOException { + public NestedCrossValidation(List<File> confFiles, int outerFolds, int innerFolds, String parameter, double startValue, double endValue, double stepsize, boolean verbose) throws ComponentInitException, ParseException, org.dllearner.confparser.ParseException, IOException { - DecimalFormat df = new DecimalFormat(); - ComponentManager cm = ComponentManager.getInstance(); + for (File confFile : confFiles) { + print(confFile.getPath()); + validate(confFile, outerFolds, innerFolds, parameter, startValue, endValue, stepsize, verbose); + } + print("********************************************"); + print("********************************************"); + print("********************************************"); + + // decide for the best parameter + print(" Summary over parameter values:"); + double bestPara = startValue; + double bestValue = Double.NEGATIVE_INFINITY; + for (Entry<Double, Stat> entry : globalParaStats.entrySet()) { + double para = entry.getKey(); + Stat stat = entry.getValue(); + print(" value " + para + ": " + stat.prettyPrint("%")); + if (stat.getMean() > bestValue) { + bestPara = para; + bestValue = stat.getMean(); + } + } + print(" selected " + bestPara + " as best parameter value (criterion value " + df.format(bestValue) + "%)"); + + // overall statistics + print("*******************"); + print("* Overall Results *"); + print("*******************"); + print("accuracy: " + globalAcc.prettyPrint("%")); + print("F measure: " + globalF.prettyPrint("%")); + print("precision: " + globalPrecision.prettyPrint("%")); + print("recall: " + globalRecall.prettyPrint("%")); + + } + + private void validate(File confFile, int outerFolds, int innerFolds, String parameter, double startValue, double endValue, double stepsize, boolean verbose) throws IOException, ComponentInitException{ CLI start = new CLI(confFile); start.init(); AbstractLearningProblem lp = start.getLearningProblem(); - System.out.println(lp); if(!(lp instanceof PosNegLP)) { System.out.println("Positive only learning not supported yet."); System.exit(0); @@ -213,7 +264,7 @@ Stat accOverall = new Stat(); Stat fOverall = new Stat(); Stat recallOverall = new Stat(); - Stat precisionOverall = new Stat(); + Stat precisionOverall = new Stat(); for(int currOuterFold=0; currOuterFold<outerFolds; currOuterFold++) { @@ -302,11 +353,15 @@ // free memory rs.releaseKB(); - cm.freeAllComponents(); } paraStats.put(currParaValue, paraCriterionStat); - + Stat globalParaStat = globalParaStats.get(currParaValue); + if(globalParaStat == null){ + globalParaStat = new Stat(); + globalParaStats.put(currParaValue, globalParaStat); + } + globalParaStat.add(paraCriterionStat); } // decide for the best parameter @@ -382,9 +437,13 @@ // free memory rs.releaseKB(); - cm.freeAllComponents(); } + globalAcc.add(accOverall); + globalF.add(fOverall); + globalPrecision.add(precisionOverall); + globalRecall.add(recallOverall); + // overall statistics print("*******************"); print("* Overall Results *"); @@ -393,7 +452,6 @@ print("F measure: " + fOverall.prettyPrint("%")); print("precision: " + precisionOverall.prettyPrint("%")); print("recall: " + recallOverall.prettyPrint("%")); - } // convenience methods, which takes a list of examples and divides them in Modified: trunk/scripts/src/main/java/org/dllearner/scripts/evaluation/QTLEvaluation.java =================================================================== --- trunk/scripts/src/main/java/org/dllearner/scripts/evaluation/QTLEvaluation.java 2014-05-08 14:39:43 UTC (rev 4264) +++ trunk/scripts/src/main/java/org/dllearner/scripts/evaluation/QTLEvaluation.java 2014-05-08 14:40:15 UTC (rev 4265) @@ -496,13 +496,13 @@ lp.setReasoner(reasoner); lp.init(); QTL2Disjunctive la = new QTL2Disjunctive(lp, reasoner); - la.init(); - la.start(); +// la.init(); +// la.start(); CrossValidation.outputFile = new File("log/qtl-cv.log"); CrossValidation.writeToFile = true; CrossValidation.multiThreaded = multiThreaded; -// CrossValidation cv = new CrossValidation(la, lp, reasoner, nrOfFolds, false); + CrossValidation cv = new CrossValidation(la, lp, reasoner, nrOfFolds, false); long endTime = System.currentTimeMillis(); System.err.println((endTime - startTime) + "ms"); } This was sent by the SourceForge.net collaborative development platform, the world's largest Open Source development site. |