From: <lor...@us...> - 2014-05-08 11:45:49
Revision: 4263 http://sourceforge.net/p/dl-learner/code/4263 Author: lorenz_b Date: 2014-05-08 11:45:45 +0000 (Thu, 08 May 2014) Log Message: ----------- Added more termination criteria for QTL algorithm. Added heuristics for tree score. Modified Paths: -------------- trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/QTL2Disjunctive.java trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/QueryTreeHeuristic.java trunk/components-core/src/main/java/org/dllearner/learningproblems/Heuristics.java trunk/components-core/src/main/java/org/dllearner/learningproblems/QueryTreeScore.java trunk/scripts/src/main/java/org/dllearner/scripts/NestedCrossValidation.java trunk/scripts/src/main/java/org/dllearner/scripts/evaluation/QTLEvaluation.java Modified: trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/QTL2Disjunctive.java =================================================================== --- trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/QTL2Disjunctive.java 2014-05-07 13:43:54 UTC (rev 4262) +++ trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/QTL2Disjunctive.java 2014-05-08 11:45:45 UTC (rev 4263) @@ -2,6 +2,7 @@ import java.io.ByteArrayInputStream; import java.io.IOException; +import java.text.DecimalFormat; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; @@ -61,6 +62,7 @@ private static final Logger logger = Logger.getLogger(QTL2Disjunctive.class.getName()); + private final DecimalFormat df = new DecimalFormat("0.00"); private LGGGenerator<String> lggGenerator; @@ -92,21 +94,31 @@ private EvaluatedDescription currentBestSolution; + private QueryTreeHeuristic heuristic; //Parameters @ConfigOption(name = "noisePercentage", defaultValue="0.0", description="the (approximated) percentage of noise within the examples") private double noisePercentage = 0.0; @ConfigOption(defaultValue = "10", name = "maxExecutionTimeInSeconds", description = "maximum execution of the algorithm in seconds") - private int maxExecutionTimeInSeconds = 10; - private double minimumTreeScore = 0.2; + private int maxExecutionTimeInSeconds = -1; + private double coverageWeight = 0.8; private double specifityWeight = 0.1; - private double noise = 0.3; - private double coverageBeta = 0.7; + private double coverageBeta = 0.5; - private double posExampleWeight = 1; + private double minCoveredPosExamplesFraction = 0.2; + // maximum execution time to compute a part of the solution + private double maxTreeComputationTimeInSeconds = 60; + // how important not to cover negatives + private double posWeight = 2; + // minimum score a query tree must have to be part of the solution + private double minimumTreeScore = 0.2; + private long startTime; + + private long partialSolutionStartTime; + public QTL2Disjunctive() {} public QTL2Disjunctive(PosNegLP learningProblem, AbstractReasonerComponent reasoner) throws LearningProblemUnsupportedException{ @@ -142,6 +154,11 @@ lggGenerator = new LGGGeneratorImpl<String>(); + if(heuristic == null){ + heuristic = new QueryTreeHeuristic(); + heuristic.setPosExamplesWeight(2); + } + logger.info("Initializing..."); treeCache = new QueryTreeCache(model); tree2Individual = new HashMap<QueryTree<String>, Individual>(lp.getPositiveExamples().size()+lp.getNegativeExamples().size()); @@ -164,6 +181,7 @@ } private void generateTrees(){ + logger.info("Generating trees..."); QueryTree<String> queryTree; for (Individual ind : lp.getPositiveExamples()) { queryTree = treeCache.getQueryTree(ind.getName()); @@ 
-175,6 +193,7 @@ tree2Individual.put(queryTree, ind); currentNegExampleTrees.add(queryTree); } + logger.info("...done."); } /* (non-Javadoc) @@ -183,46 +202,59 @@ @Override public void start() { String setup = "Setup:"; - setup += "#Pos. examples:" + currentPosExamples.size(); - setup += "#Neg. examples:" + currentNegExamples.size(); - setup += "Coverage beta:" + coverageBeta; + setup += "\n#Pos. examples:" + currentPosExamples.size(); + setup += "\n#Neg. examples:" + currentNegExamples.size(); + setup += "\nCoverage beta:" + coverageBeta; logger.info(setup); logger.info("Running..."); - long startTime = System.currentTimeMillis(); + startTime = System.currentTimeMillis(); reset(); int i = 1; - do { + while(!terminationCriteriaSatisfied()){ logger.info(i++ + ". iteration..."); logger.info("#Remaining pos. examples:" + currentPosExampleTrees.size()); logger.info("#Remaining neg. examples:" + currentNegExampleTrees.size()); - //compute LGG - computeLGG(); + //compute a (partial) solution + computeNextPartialSolution(); //pick best (partial) solution computed so far EvaluatedQueryTree<String> bestPartialSolution = currentPartialSolutions.first(); - partialSolutions.add(bestPartialSolution); - //remove all covered examples - QueryTree<String> tree; - for (Iterator<QueryTree<String>> iterator = currentPosExampleTrees.iterator(); iterator.hasNext();) { - tree = iterator.next(); - if(tree.isSubsumedBy(bestPartialSolution.getTree())){ - iterator.remove(); - currentPosExamples.remove(tree2Individual.get(tree)); + //add if some criteria are satisfied + if(bestPartialSolution.getScore() >= minimumTreeScore){ + + partialSolutions.add(bestPartialSolution); + + //remove all covered examples + QueryTree<String> tree; + for (Iterator<QueryTree<String>> iterator = currentPosExampleTrees.iterator(); iterator.hasNext();) { + tree = iterator.next(); + if(tree.isSubsumedBy(bestPartialSolution.getTree())){ + iterator.remove(); + currentPosExamples.remove(tree2Individual.get(tree)); + } } - } - for (Iterator<QueryTree<String>> iterator = currentNegExampleTrees.iterator(); iterator.hasNext();) { - tree = iterator.next(); - if(tree.isSubsumedBy(bestPartialSolution.getTree())){ - iterator.remove(); - currentNegExamples.remove(tree2Individual.get(tree)); + for (Iterator<QueryTree<String>> iterator = currentNegExampleTrees.iterator(); iterator.hasNext();) { + tree = iterator.next(); + if(tree.isSubsumedBy(bestPartialSolution.getTree())){ + iterator.remove(); + currentNegExamples.remove(tree2Individual.get(tree)); + } } + //build the current combined solution + currentBestSolution = buildCombinedSolution(); + + logger.info("combined accuracy: " + df.format(currentBestSolution.getAccuracy())); + } else { + logger.info("no tree found, which satisfies the minimum criteria - the best was: " + + currentBestSolution.getDescription().toManchesterSyntaxString(baseURI, prefixes) + + " with score " + currentBestSolution.getScore()); } - currentBestSolution = buildCombinedSolution(); - } while (!(stop || currentPosExampleTrees.isEmpty())); + + }; isRunning = false; @@ -234,6 +266,119 @@ } + private void computeNextPartialSolution(){ + logger.info("Computing best partial solution..."); + currentlyBestScore = 0d; + partialSolutionStartTime = System.currentTimeMillis(); + initTodoList(currentPosExampleTrees, currentNegExampleTrees); + + EvaluatedQueryTree<String> currentElement; + while(!partialSolutionTerminationCriteriaSatisfied()){ + logger.trace("TODO list size: " + todoList.size()); + //pick best element from todo list + 
currentElement = todoList.poll(); + //generate the LGG between the chosen tree and each uncovered positive example + for (QueryTree<String> example : currentElement.getFalseNegatives()) { + QueryTree<String> tree = currentElement.getTree(); + + //compute the LGG + lggMon.start(); + QueryTree<String> lgg = lggGenerator.getLGG(tree, example); + lggMon.stop(); + + //evaluate the LGG + EvaluatedQueryTree<String> solution = evaluate(lgg, true); + double score = solution.getScore(); + double mas = heuristic.getMaximumAchievableScore(solution); + + if(score >= currentlyBestScore){ + //add to todo list, if not already contained in todo list or solution list + todo(solution); + if(solution.getScore() > currentlyBestScore){ + logger.info("Got better solution:" + solution.getTreeScore()); + } + currentlyBestScore = solution.getScore(); + } else if(mas < currentlyBestScore){ + todo(solution); + } else { + System.out.println("Too general"); + } + currentPartialSolutions.add(currentElement); + + } + currentPartialSolutions.add(currentElement); + } + long endTime = System.currentTimeMillis(); + logger.info("...finished in " + (endTime-partialSolutionStartTime) + "ms."); + EvaluatedDescription bestPartialSolution = currentPartialSolutions.first().asEvaluatedDescription(); + + logger.info("Best partial solution:\n" + OWLAPIConverter.getOWLAPIDescription(bestPartialSolution.getDescription()) + "\n(" + bestPartialSolution.getScore() + ")"); + + logger.trace("LGG time: " + lggMon.getTotal() + "ms"); + logger.trace("Avg. LGG time: " + lggMon.getAvg() + "ms"); + logger.info("#LGG computations: " + lggMon.getHits()); + logger.trace("Subsumption test time: " + subMon.getTotal() + "ms"); + logger.trace("Avg. subsumption test time: " + subMon.getAvg() + "ms"); + logger.trace("#Subsumption tests: " + subMon.getHits()); + } + + private EvaluatedQueryTree<String> evaluate(QueryTree<String> tree, boolean useSpecifity){ + //1. get a score for the coverage = recall oriented + //compute positive examples which are not covered by LGG + List<QueryTree<String>> uncoveredPositiveExampleTrees = getUncoveredTrees(tree, currentPosExampleTrees); + Set<Individual> uncoveredPosExamples = new TreeSet<Individual>(); + for (QueryTree<String> queryTree : uncoveredPositiveExampleTrees) { + uncoveredPosExamples.add(tree2Individual.get(queryTree)); + } + //compute negative examples which are covered by LGG + Collection<QueryTree<String>> coveredNegativeExampleTrees = getCoveredTrees(tree, currentNegExampleTrees); + Set<Individual> coveredNegExamples = new TreeSet<Individual>(); + for (QueryTree<String> queryTree : coveredNegativeExampleTrees) { + coveredNegExamples.add(tree2Individual.get(queryTree)); + } + //compute score + int coveredPositiveExamples = currentPosExampleTrees.size() - uncoveredPositiveExampleTrees.size(); + double recall = coveredPositiveExamples / (double)currentPosExampleTrees.size(); + double precision = (coveredNegativeExampleTrees.size() + coveredPositiveExamples == 0) + ? 0 + : coveredPositiveExamples / (double)(coveredPositiveExamples + coveredNegativeExampleTrees.size()); + + double coverageScore = Heuristics.getFScore(recall, precision, coverageBeta); + + //2. get a score for the specifity of the query, i.e. 
how many edges/nodes = precision oriented + int nrOfSpecificNodes = 0; + for (QueryTree<String> childNode : tree.getChildrenClosure()) { + if(!childNode.getUserObject().equals("?")){ + nrOfSpecificNodes++; + } + } + double specifityScore = 0d; + if(useSpecifity){ + specifityScore = Math.log(nrOfSpecificNodes); + } + + //3.compute the total score + double score = coverageWeight * coverageScore + specifityWeight * specifityScore; + + QueryTreeScore queryTreeScore = new QueryTreeScore(score, coverageScore, + new TreeSet<Individual>(Sets.difference(currentPosExamples, uncoveredPosExamples)), uncoveredPosExamples, + coveredNegExamples, new TreeSet<Individual>(Sets.difference(currentNegExamples, coveredNegExamples)), + specifityScore, nrOfSpecificNodes); + +// QueryTreeScore queryTreeScore = new QueryTreeScore(score, coverageScore, +// null,null,null,null, +// specifityScore, nrOfSpecificNodes); + + EvaluatedQueryTree<String> evaluatedTree = new EvaluatedQueryTree<String>(tree, uncoveredPositiveExampleTrees, coveredNegativeExampleTrees, queryTreeScore); + + //TODO use only the heuristic to compute the score + score = heuristic.getScore(evaluatedTree); + queryTreeScore.setScore(score); + queryTreeScore.setAccuracy(score); + + return evaluatedTree; + } + private EvaluatedDescription buildCombinedSolution(){ if(partialSolutions.size() == 1){ EvaluatedDescription combinedSolution = partialSolutions.get(0).asEvaluatedDescription(); @@ -288,58 +433,8 @@ lggMon.reset(); } - private void computeLGG(){ - logger.info("Computing best partial solution..."); - currentlyBestScore = 0d; - - initTodoList(currentPosExampleTrees, currentNegExampleTrees); - - long startTime = System.currentTimeMillis(); - EvaluatedQueryTree<String> currentElement; - do{ - logger.trace("TODO list size: " + todoList.size()); - //pick best element from todo list - currentElement = todoList.poll(); - //generate the LGG between the chosen tree and each uncovered positive example - for (QueryTree<String> example : currentElement.getFalseNegatives()) { - QueryTree<String> tree = currentElement.getTree(); - - //compute the LGG - lggMon.start(); - QueryTree<String> lgg = lggGenerator.getLGG(tree, example); - lggMon.stop(); - - //evaluate the LGG - EvaluatedQueryTree<String> solution = evaluate(lgg, true); - - if(solution.getScore() >= currentlyBestScore){ - //add to todo list, if not already contained in todo list or solution list - todo(solution); - if(solution.getScore() > currentlyBestScore){ - logger.info("Got better solution:" + solution.getTreeScore()); - } - currentlyBestScore = solution.getScore(); - } - currentPartialSolutions.add(currentElement); - - } - currentPartialSolutions.add(currentElement); -// todoList.remove(currentElement); - } while(!terminationCriteriaSatisfied()); - long endTime = System.currentTimeMillis(); - logger.info("...finished in " + (endTime-startTime) + "ms."); - EvaluatedDescription bestPartialSolution = currentPartialSolutions.first().asEvaluatedDescription(); - - logger.info("Best partial solution:\n" + OWLAPIConverter.getOWLAPIDescription(bestPartialSolution.getDescription()) + "\n(" + bestPartialSolution.getScore() + ")"); - - logger.trace("LGG time: " + lggMon.getTotal() + "ms"); - logger.trace("Avg. LGG time: " + lggMon.getAvg() + "ms"); - logger.trace("#LGG computations: " + lggMon.getHits()); - logger.trace("Subsumption test time: " + subMon.getTotal() + "ms"); - logger.trace("Avg. 
subsumption test time: " + subMon.getAvg() + "ms"); - logger.trace("#Subsumption tests: " + subMon.getHits()); - } + /* (non-Javadoc) * @see org.dllearner.core.StoppableLearningAlgorithm#stop() */ @@ -411,58 +506,7 @@ return treeCache; } - private EvaluatedQueryTree<String> evaluate(QueryTree<String> tree, boolean useSpecifity){ - //1. get a score for the coverage = recall oriented - //compute positive examples which are not covered by LGG - List<QueryTree<String>> uncoveredPositiveExampleTrees = getUncoveredTrees(tree, currentPosExampleTrees); - Set<Individual> uncoveredPosExamples = new TreeSet<Individual>(); - for (QueryTree<String> queryTree : uncoveredPositiveExampleTrees) { - uncoveredPosExamples.add(tree2Individual.get(queryTree)); - } - //compute negative examples which are covered by LGG - Collection<QueryTree<String>> coveredNegativeExampleTrees = getCoveredTrees(tree, currentNegExampleTrees); - Set<Individual> coveredNegExamples = new TreeSet<Individual>(); - for (QueryTree<String> queryTree : coveredNegativeExampleTrees) { - coveredNegExamples.add(tree2Individual.get(queryTree)); - } - //compute score - int coveredPositiveExamples = currentPosExampleTrees.size() - uncoveredPositiveExampleTrees.size(); - double recall = coveredPositiveExamples / (double)currentPosExampleTrees.size(); - double precision = (coveredNegativeExampleTrees.size() + coveredPositiveExamples == 0) - ? 0 - : coveredPositiveExamples / (double)(coveredPositiveExamples + coveredNegativeExampleTrees.size()); - - double beta = 0.5; - double coverageScore = Heuristics.getFScore(recall, precision, beta); - - //2. get a score for the specifity of the query, i.e. how many edges/nodes = precision oriented - int nrOfSpecificNodes = 0; - for (QueryTree<String> childNode : tree.getChildrenClosure()) { - if(!childNode.getUserObject().equals("?")){ - nrOfSpecificNodes++; - } - } - double specifityScore = 0d; - if(useSpecifity){ - specifityScore = Math.log(nrOfSpecificNodes); - } - - //3.compute the total score - double score = coverageWeight * coverageScore + specifityWeight * specifityScore; - - QueryTreeScore queryTreeScore = new QueryTreeScore(score, coverageScore, - new TreeSet<Individual>(Sets.difference(currentPosExamples, uncoveredPosExamples)), uncoveredPosExamples, - coveredNegExamples, new TreeSet<Individual>(Sets.difference(currentNegExamples, coveredNegExamples)), - specifityScore, nrOfSpecificNodes); - -// QueryTreeScore queryTreeScore = new QueryTreeScore(score, coverageScore, -// null,null,null,null, -// specifityScore, nrOfSpecificNodes); - - EvaluatedQueryTree<String> evaluatedTree = new EvaluatedQueryTree<String>(tree, uncoveredPositiveExampleTrees, coveredNegativeExampleTrees, queryTreeScore); - - return evaluatedTree; - } + /** * Return all trees from the given list {@code allTrees} which are not already subsumed by {@code tree}. @@ -535,9 +579,21 @@ } private boolean terminationCriteriaSatisfied(){ - return stop || todoList.isEmpty() || currentPosExampleTrees.isEmpty(); + return stop || isTimeExpired() || currentPosExampleTrees.isEmpty(); } + private boolean partialSolutionTerminationCriteriaSatisfied(){ + return stop || todoList.isEmpty() || currentPosExampleTrees.isEmpty() || isPartialSolutionTimeExpired() || isTimeExpired(); + } + + private boolean isTimeExpired(){ + return maxExecutionTimeInSeconds <= 0 ? false : (System.currentTimeMillis() - startTime)/1000d >= maxExecutionTimeInSeconds; + } + + private boolean isPartialSolutionTimeExpired(){ + return maxTreeComputationTimeInSeconds <= 0 ? 
false : (System.currentTimeMillis() - partialSolutionStartTime)/1000d >= maxTreeComputationTimeInSeconds; + } + /** * Add tree to todo list if not already contained in that list or the solutions. * @param solution Modified: trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/QueryTreeHeuristic.java =================================================================== --- trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/QueryTreeHeuristic.java 2014-05-07 13:43:54 UTC (rev 4262) +++ trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/QueryTreeHeuristic.java 2014-05-08 11:45:45 UTC (rev 4263) @@ -13,6 +13,7 @@ import org.dllearner.core.Heuristic; import org.dllearner.core.owl.Individual; import org.dllearner.learningproblems.Heuristics; +import org.dllearner.learningproblems.Heuristics.HeuristicType; import org.dllearner.learningproblems.QueryTreeScore; import org.dllearner.utilities.owl.ConceptComparator; @@ -23,6 +24,8 @@ @ComponentAnn(name = "QueryTreeHeuristic", shortName = "qtree_heuristic", version = 0.1) public class QueryTreeHeuristic extends AbstractComponent implements Heuristic, Comparator<EvaluatedQueryTree<String>>{ + private HeuristicType heuristicType = HeuristicType.PRED_ACC; + // F score beta value private double coverageBeta = 1; @@ -30,6 +33,8 @@ private double specifityWeight = 0.1; + private double posExamplesWeight = 1; + // syntactic comparison as final comparison criterion private ConceptComparator conceptComparator = new ConceptComparator(); @@ -43,23 +48,88 @@ public double getScore(EvaluatedQueryTree<String> tree){ QueryTreeScore treeScore = tree.getTreeScore(); - //TODO - double score = treeScore.getScore(); + Set<Individual> truePositives = treeScore.getCoveredPositives(); + Set<Individual> trueNegatives = treeScore.getNotCoveredNegatives(); + Set<Individual> falsePositives = treeScore.getNotCoveredPositives(); + Set<Individual> falseNegatives = treeScore.getCoveredNegatives(); + double tp = truePositives.size(); + double tn = trueNegatives.size(); + double fp = falsePositives.size(); + double fn = falseNegatives.size(); + + double score = 0; + switch(heuristicType){ + case FMEASURE : + score = Heuristics.getFScore(tp/(tp+fn), tp/(tp+fp), posExamplesWeight);break; + case PRED_ACC : + score = (posExamplesWeight * tp + tn) / (posExamplesWeight * (tp + fn) + tn + fp);break; + case ENTROPY :{ + double total = tp + fn; + double pp = tp / total; + double pn = fn / total; + score = pp * Math.log(pp) + pn * Math.log(pn); + break;} + case MATTHEWS_CORRELATION : + score = (tp * tn - fp * fn) / Math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn));break; + case YOUDEN_INDEX : score = tp / (tp + fn) + tn / (fp + tn) - 1;break; + default: + break; + + } + return score; } - private double getPredictedAccuracy(EvaluatedQueryTree<String> tree){ + /** + * Returns the maximum achievable score according to the used score function. 
+ * @return + */ + public double getMaximumAchievableScore(EvaluatedQueryTree<String> tree) { QueryTreeScore treeScore = tree.getTreeScore(); Set<Individual> truePositives = treeScore.getCoveredPositives(); Set<Individual> trueNegatives = treeScore.getNotCoveredNegatives(); Set<Individual> falsePositives = treeScore.getNotCoveredPositives(); Set<Individual> falseNegatives = treeScore.getCoveredNegatives(); - return 0; + double tp = truePositives.size(); + double tn = trueNegatives.size(); + double fp = falsePositives.size(); + double fn = falseNegatives.size(); + + return getMaximumAchievableScore(tp, tn, fp, fn); } + + /** + * Returns the maximum achievable score according to the used score function. + * @param tp + * @param tn + * @param fp + * @param fn + * @return + */ + private double getMaximumAchievableScore(double tp, double tn, double fp, double fn) { + double mas = 0d; + switch (heuristicType) { + case FMEASURE: + break; + case PRED_ACC: + mas = (posExamplesWeight * tp + tn - fp) / (posExamplesWeight * (tp + fn) + tn + fp); + break; + case ENTROPY: + break; + case MATTHEWS_CORRELATION: + break; + case YOUDEN_INDEX: + break; + default: + break; + } + return mas; + } + /* (non-Javadoc) * @see java.util.Comparator#compare(java.lang.Object, java.lang.Object) */ @@ -75,5 +145,19 @@ return conceptComparator.compare(tree1.asEvaluatedDescription().getDescription(), tree2.asEvaluatedDescription().getDescription()); } } + + /** + * @param heuristicType the heuristicType to set + */ + public void setHeuristicType(HeuristicType heuristicType) { + this.heuristicType = heuristicType; + } + + /** + * @param posExamplesWeight the posExamplesWeight to set + */ + public void setPosExamplesWeight(double posExamplesWeight) { + this.posExamplesWeight = posExamplesWeight; + } } Modified: trunk/components-core/src/main/java/org/dllearner/learningproblems/Heuristics.java =================================================================== --- trunk/components-core/src/main/java/org/dllearner/learningproblems/Heuristics.java 2014-05-07 13:43:54 UTC (rev 4262) +++ trunk/components-core/src/main/java/org/dllearner/learningproblems/Heuristics.java 2014-05-08 11:45:45 UTC (rev 4263) @@ -29,7 +29,7 @@ */ public class Heuristics { - public static enum HeuristicType { PRED_ACC, AMEASURE, JACCARD, FMEASURE, GEN_FMEASURE }; + public static enum HeuristicType { PRED_ACC, AMEASURE, JACCARD, FMEASURE, GEN_FMEASURE, ENTROPY, MATTHEWS_CORRELATION, YOUDEN_INDEX }; /** * Computes F1-Score. 
Modified: trunk/components-core/src/main/java/org/dllearner/learningproblems/QueryTreeScore.java =================================================================== --- trunk/components-core/src/main/java/org/dllearner/learningproblems/QueryTreeScore.java 2014-05-07 13:43:54 UTC (rev 4262) +++ trunk/components-core/src/main/java/org/dllearner/learningproblems/QueryTreeScore.java 2014-05-08 11:45:45 UTC (rev 4263) @@ -16,7 +16,7 @@ private double score; - private double coverageScore; + private double accuracy; private double specifityScore; private int nrOfSpecificNodes; @@ -26,12 +26,12 @@ private Set<Individual> negAsPos; private Set<Individual> negAsNeg; - public QueryTreeScore(double score, double coverageScore, + public QueryTreeScore(double score, double accuracy, Set<Individual> posAsPos, Set<Individual> posAsNeg, Set<Individual> negAsPos, Set<Individual> negAsNeg, double specifityScore, int nrOfSpecificNodes) { super(); this.score = score; - this.coverageScore = coverageScore; + this.accuracy = accuracy; this.posAsPos = posAsPos; this.posAsNeg = posAsNeg; this.negAsPos = negAsPos; @@ -46,15 +46,29 @@ public double getScore() { return score; } + + /** + * @param score the score to set + */ + public void setScore(double score) { + this.score = score; + } /* (non-Javadoc) * @see org.dllearner.core.Score#getAccuracy() */ @Override public double getAccuracy() { - return score; + return accuracy; } + /** + * @param accuracy the accuracy to set + */ + public void setAccuracy(double accuracy) { + this.accuracy = accuracy; + } + public Set<Individual> getCoveredNegatives() { return negAsPos; } @@ -77,7 +91,7 @@ @Override public String toString() { return score - + "(coverage=" + coverageScore + + "(accuracy=" + accuracy + "(+" + posAsPos.size() + "/" + (posAsPos.size() + posAsNeg.size()) + "|-" + negAsPos.size() + "/" + (negAsPos.size() + negAsNeg.size()) + ")|" + "specifity=" + specifityScore + "(" + nrOfSpecificNodes + "))"; Modified: trunk/scripts/src/main/java/org/dllearner/scripts/NestedCrossValidation.java =================================================================== --- trunk/scripts/src/main/java/org/dllearner/scripts/NestedCrossValidation.java 2014-05-07 13:43:54 UTC (rev 4262) +++ trunk/scripts/src/main/java/org/dllearner/scripts/NestedCrossValidation.java 2014-05-08 11:45:45 UTC (rev 4263) @@ -19,44 +19,49 @@ */ package org.dllearner.scripts; +import static java.util.Arrays.asList; + import java.io.File; import java.io.FileNotFoundException; import java.io.IOException; +import java.lang.reflect.InvocationTargetException; import java.text.DecimalFormat; import java.util.Collections; import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.Random; import java.util.Set; import java.util.TreeSet; -import java.util.Map.Entry; +import joptsimple.OptionParser; +import joptsimple.OptionSet; + +import org.apache.commons.beanutils.PropertyUtils; import org.apache.log4j.ConsoleAppender; +import org.apache.log4j.FileAppender; import org.apache.log4j.Level; import org.apache.log4j.Logger; import org.apache.log4j.SimpleLayout; -import org.dllearner.cli.Start; -import org.dllearner.core.ComponentInitException; -import org.dllearner.core.ComponentManager; +import org.dllearner.cli.CLI; import org.dllearner.core.AbstractCELA; import org.dllearner.core.AbstractLearningProblem; import org.dllearner.core.AbstractReasonerComponent; +import org.dllearner.core.ComponentInitException; +import 
org.dllearner.core.ComponentManager; import org.dllearner.core.owl.Description; import org.dllearner.core.owl.Individual; import org.dllearner.learningproblems.PosNegLP; import org.dllearner.parser.ParseException; import org.dllearner.utilities.Helper; -import org.dllearner.utilities.datastructures.Datastructures; import org.dllearner.utilities.datastructures.TrainTestList; import org.dllearner.utilities.statistics.Stat; -import joptsimple.OptionParser; -import joptsimple.OptionSet; +import com.google.common.base.Charsets; +import com.google.common.io.Files; -import static java.util.Arrays.*; - /** * Performs nested cross validation for the given problem. A k fold outer and l * fold inner cross validation is used. Parameters: @@ -93,6 +98,8 @@ * */ public class NestedCrossValidation { + + private File outputFile = new File("log/nested-cv.log"); /** * Entry method, which uses JOptSimple to parse parameters. @@ -115,6 +122,7 @@ parser.acceptsAll(asList( "i", "innerfolds"), "Number of inner folds.").withRequiredArg().ofType(Integer.class).describedAs("#folds"); parser.acceptsAll(asList( "p", "parameter"), "Parameter to vary.").withRequiredArg(); parser.acceptsAll(asList( "r", "pvalues", "range"), "Values of parameter. $x-$y can be used for integer ranges.").withRequiredArg(); + parser.acceptsAll(asList( "s", "stepsize", "steps"), "Step size of range.").withOptionalArg().ofType(Double.class).defaultsTo(1d); // parse options and display a message for the user in case of problems OptionSet options = null; @@ -137,8 +145,9 @@ String parameter = (String) options.valueOf("p"); String range = (String) options.valueOf("r"); String[] rangeSplit = range.split("-"); - int rangeStart = new Integer(rangeSplit[0]); - int rangeEnd = new Integer(rangeSplit[1]); + double rangeStart = Double.valueOf(rangeSplit[0]); + double rangeEnd = Double.valueOf(rangeSplit[1]); + double stepsize = (Double) options.valueOf("s"); boolean verbose = options.has("v"); // create logger (a simple logger which outputs @@ -149,11 +158,13 @@ logger.removeAllAppenders(); logger.addAppender(consoleAppender); logger.setLevel(Level.WARN); + Logger.getLogger("org.dllearner.algorithms").setLevel(Level.INFO); +// logger.addAppender(new FileAppender(layout, "nested-cv.log", false)); // disable OWL API info output java.util.logging.Logger.getLogger("").setLevel(java.util.logging.Level.WARNING); System.out.println("Warning: The script is not well tested yet. 
(No known bugs, but needs more testing.)"); - new NestedCrossValidation(confFile, outerFolds, innerFolds, parameter, rangeStart, rangeEnd, verbose); + new NestedCrossValidation(confFile, outerFolds, innerFolds, parameter, rangeStart, rangeEnd, stepsize, verbose); // an option is missing => print help screen and message } else { @@ -163,14 +174,24 @@ } - public NestedCrossValidation(File confFile, int outerFolds, int innerFolds, String parameter, int startValue, int endValue, boolean verbose) throws FileNotFoundException, ComponentInitException, ParseException, org.dllearner.confparser.ParseException { + private void print(String s){ + try { + Files.append(s + "\n", outputFile , Charsets.UTF_8); + } catch (IOException e) { + e.printStackTrace(); + } + System.out.println(s); + } + + public NestedCrossValidation(File confFile, int outerFolds, int innerFolds, String parameter, double startValue, double endValue, double stepsize, boolean verbose) throws ComponentInitException, ParseException, org.dllearner.confparser.ParseException, IOException { DecimalFormat df = new DecimalFormat(); ComponentManager cm = ComponentManager.getInstance(); - Start start = new Start(confFile); + CLI start = new CLI(confFile); + start.init(); AbstractLearningProblem lp = start.getLearningProblem(); - + System.out.println(lp); if(!(lp instanceof PosNegLP)) { System.out.println("Positive only learning not supported yet."); System.exit(0); @@ -196,16 +217,16 @@ for(int currOuterFold=0; currOuterFold<outerFolds; currOuterFold++) { - System.out.println("Outer fold " + currOuterFold); + print("Outer fold " + currOuterFold); TrainTestList posList = posLists.get(currOuterFold); TrainTestList negList = negLists.get(currOuterFold); // measure relevant criterion (accuracy, F-measure) over different parameter values - Map<Integer,Stat> paraStats = new HashMap<Integer,Stat>(); + Map<Double,Stat> paraStats = new HashMap<Double,Stat>(); - for(int currParaValue=startValue; currParaValue<=endValue; currParaValue++) { + for(double currParaValue=startValue; currParaValue<=endValue; currParaValue+=stepsize) { - System.out.println(" Parameter value " + currParaValue + ":"); + print(" Parameter value " + currParaValue + ":"); // split train folds again (computation of inner folds for each parameter // value is redundant, but not a big problem) List<Individual> trainPosList = posList.getTrainList(); @@ -219,19 +240,24 @@ for(int currInnerFold=0; currInnerFold<innerFolds; currInnerFold++) { - System.out.println(" Inner fold " + currInnerFold + ":"); + print(" Inner fold " + currInnerFold + ":"); // get positive & negative examples for training run Set<Individual> posEx = new TreeSet<Individual>(innerPosLists.get(currInnerFold).getTrainList()); Set<Individual> negEx = new TreeSet<Individual>(innerNegLists.get(currInnerFold).getTrainList()); // read conf file and exchange options for pos/neg examples // and parameter to optimise - start = new Start(confFile); + start = new CLI(confFile); + start.init(); AbstractLearningProblem lpIn = start.getLearningProblem(); - cm.applyConfigEntry(lpIn, "positiveExamples", Datastructures.individualSetToStringSet(posEx)); - cm.applyConfigEntry(lpIn, "negativeExamples", Datastructures.individualSetToStringSet(negEx)); + ((PosNegLP)lpIn).setPositiveExamples(posEx); + ((PosNegLP)lpIn).setNegativeExamples(negEx); AbstractCELA laIn = start.getLearningAlgorithm(); - cm.applyConfigEntry(laIn, parameter, (double)currParaValue); + try { + PropertyUtils.setSimpleProperty(laIn, parameter, currParaValue); + } catch 
(IllegalAccessException | InvocationTargetException | NoSuchMethodException e) { + e.printStackTrace(); + } lpIn.init(); laIn.init(); @@ -263,15 +289,15 @@ paraCriterionStat.addNumber(accuracy); - System.out.println(" hypothesis: " + concept.toManchesterSyntaxString(baseURI, null)); - System.out.println(" accuracy: " + df.format(accuracy) + "%"); - System.out.println(" precision: " + df.format(precision) + "%"); - System.out.println(" recall: " + df.format(recall) + "%"); - System.out.println(" F measure: " + df.format(fmeasure) + "%"); + print(" hypothesis: " + concept.toManchesterSyntaxString(baseURI, null)); + print(" accuracy: " + df.format(accuracy) + "%"); + print(" precision: " + df.format(precision) + "%"); + print(" recall: " + df.format(recall) + "%"); + print(" F measure: " + df.format(fmeasure) + "%"); if(verbose) { - System.out.println(" false positives (neg. examples classified as pos.): " + formatIndividualSet(posError, baseURI)); - System.out.println(" false negatives (pos. examples classified as neg.): " + formatIndividualSet(negError, baseURI)); + print(" false positives (neg. examples classified as pos.): " + formatIndividualSet(posError, baseURI)); + print(" false negatives (pos. examples classified as neg.): " + formatIndividualSet(negError, baseURI)); } // free memory @@ -284,28 +310,33 @@ } // decide for the best parameter - System.out.println(" Summary over parameter values:"); - int bestPara = startValue; + print(" Summary over parameter values:"); + double bestPara = startValue; double bestValue = Double.NEGATIVE_INFINITY; - for(Entry<Integer,Stat> entry : paraStats.entrySet()) { - int para = entry.getKey(); + for(Entry<Double,Stat> entry : paraStats.entrySet()) { + double para = entry.getKey(); Stat stat = entry.getValue(); - System.out.println(" value " + para + ": " + stat.prettyPrint("%")); + print(" value " + para + ": " + stat.prettyPrint("%")); if(stat.getMean() > bestValue) { bestPara = para; bestValue = stat.getMean(); } } - System.out.println(" selected " + bestPara + " as best parameter value (criterion value " + df.format(bestValue) + "%)"); - System.out.println(" Learn on Outer fold:"); + print(" selected " + bestPara + " as best parameter value (criterion value " + df.format(bestValue) + "%)"); + print(" Learn on Outer fold:"); // start a learning process with this parameter and evaluate it on the outer fold - start = new Start(confFile); + start = new CLI(confFile); + start.init(); AbstractLearningProblem lpOut = start.getLearningProblem(); - cm.applyConfigEntry(lpOut, "positiveExamples", Datastructures.individualListToStringSet(posLists.get(currOuterFold).getTrainList())); - cm.applyConfigEntry(lpOut, "negativeExamples", Datastructures.individualListToStringSet(negLists.get(currOuterFold).getTrainList())); + ((PosNegLP)lpOut).setPositiveExamples(new TreeSet<Individual>(posLists.get(currOuterFold).getTrainList())); + ((PosNegLP)lpOut).setNegativeExamples(new TreeSet<Individual>(negLists.get(currOuterFold).getTrainList())); AbstractCELA laOut = start.getLearningAlgorithm(); - cm.applyConfigEntry(laOut, parameter, (double)bestPara); + try { + PropertyUtils.setSimpleProperty(laOut, parameter, bestPara); + } catch (IllegalAccessException | InvocationTargetException | NoSuchMethodException e) { + e.printStackTrace(); + } lpOut.init(); laOut.init(); @@ -332,15 +363,15 @@ double recall = 100 * (double) posCorrect.size() / (posCorrect.size() + posError.size()); double fmeasure = 2 * (precision * recall) / (precision + recall); - System.out.println(" 
hypothesis: " + concept.toManchesterSyntaxString(baseURI, null)); - System.out.println(" accuracy: " + df.format(accuracy) + "%"); - System.out.println(" precision: " + df.format(precision) + "%"); - System.out.println(" recall: " + df.format(recall) + "%"); - System.out.println(" F measure: " + df.format(fmeasure) + "%"); + print(" hypothesis: " + concept.toManchesterSyntaxString(baseURI, null)); + print(" accuracy: " + df.format(accuracy) + "%"); + print(" precision: " + df.format(precision) + "%"); + print(" recall: " + df.format(recall) + "%"); + print(" F measure: " + df.format(fmeasure) + "%"); if(verbose) { - System.out.println(" false positives (neg. examples classified as pos.): " + formatIndividualSet(posError, baseURI)); - System.out.println(" false negatives (pos. examples classified as neg.): " + formatIndividualSet(negError, baseURI)); + print(" false positives (neg. examples classified as pos.): " + formatIndividualSet(posError, baseURI)); + print(" false negatives (pos. examples classified as neg.): " + formatIndividualSet(negError, baseURI)); } // update overall statistics @@ -355,14 +386,13 @@ } // overall statistics - System.out.println(); - System.out.println("*******************"); - System.out.println("* Overall Results *"); - System.out.println("*******************"); - System.out.println("accuracy: " + accOverall.prettyPrint("%")); - System.out.println("F measure: " + fOverall.prettyPrint("%")); - System.out.println("precision: " + precisionOverall.prettyPrint("%")); - System.out.println("recall: " + recallOverall.prettyPrint("%")); + print("*******************"); + print("* Overall Results *"); + print("*******************"); + print("accuracy: " + accOverall.prettyPrint("%")); + print("F measure: " + fOverall.prettyPrint("%")); + print("precision: " + precisionOverall.prettyPrint("%")); + print("recall: " + recallOverall.prettyPrint("%")); } Modified: trunk/scripts/src/main/java/org/dllearner/scripts/evaluation/QTLEvaluation.java =================================================================== --- trunk/scripts/src/main/java/org/dllearner/scripts/evaluation/QTLEvaluation.java 2014-05-07 13:43:54 UTC (rev 4262) +++ trunk/scripts/src/main/java/org/dllearner/scripts/evaluation/QTLEvaluation.java 2014-05-08 11:45:45 UTC (rev 4263) @@ -53,8 +53,8 @@ public class QTLEvaluation { int nrOfFolds = 10; - private int nrOfPosExamples = 100; - private int nrOfNegExamples = 100; + private int nrOfPosExamples = 300; + private int nrOfNegExamples = 300; List<String> posExamples = Lists.newArrayList( "http://dl-learner.org/carcinogenesis#d1", @@ -496,13 +496,13 @@ lp.setReasoner(reasoner); lp.init(); QTL2Disjunctive la = new QTL2Disjunctive(lp, reasoner); -// la.init(); -// la.start(); + la.init(); + la.start(); CrossValidation.outputFile = new File("log/qtl-cv.log"); CrossValidation.writeToFile = true; CrossValidation.multiThreaded = multiThreaded; - CrossValidation cv = new CrossValidation(la, lp, reasoner, nrOfFolds, false); +// CrossValidation cv = new CrossValidation(la, lp, reasoner, nrOfFolds, false); long endTime = System.currentTimeMillis(); System.err.println((endTime - startTime) + "ms"); } This was sent by the SourceForge.net collaborative development platform, the world's largest Open Source development site. |
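One headline change to QTL2Disjunctive is a two-level time budget: a global limit (maxExecutionTimeInSeconds) checked by the new terminationCriteriaSatisfied(), and a per-partial-solution limit (maxTreeComputationTimeInSeconds) checked while a single tree is being refined. The sketch below isolates just that logic; the field and method names follow the diff, while the enclosing class and the main method are illustrative only.

// Minimal sketch of the two-level time budget added in r4263.
// Field/method names follow the diff; the enclosing class is hypothetical.
public class TimeBudgetSketch {

    // a value <= 0 disables the corresponding limit, as in the diff
    private int maxExecutionTimeInSeconds = -1;           // whole algorithm
    private double maxTreeComputationTimeInSeconds = 60;  // one partial solution

    private long startTime;                // set once when the algorithm starts
    private long partialSolutionStartTime; // reset for each partial solution

    private boolean isTimeExpired() {
        return maxExecutionTimeInSeconds > 0
                && (System.currentTimeMillis() - startTime) / 1000d >= maxExecutionTimeInSeconds;
    }

    private boolean isPartialSolutionTimeExpired() {
        return maxTreeComputationTimeInSeconds > 0
                && (System.currentTimeMillis() - partialSolutionStartTime) / 1000d >= maxTreeComputationTimeInSeconds;
    }

    public static void main(String[] args) throws InterruptedException {
        TimeBudgetSketch s = new TimeBudgetSketch();
        s.maxExecutionTimeInSeconds = 1;
        s.startTime = s.partialSolutionStartTime = System.currentTimeMillis();
        Thread.sleep(1100);
        System.out.println(s.isTimeExpired());                // true after ~1.1 s
        System.out.println(s.isPartialSolutionTimeExpired()); // false, 60 s budget
    }
}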
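evaluate() now combines an F-beta coverage score (recall and precision over the remaining example trees, weighted by coverageBeta) with a log-scaled specificity score, mixed via coverageWeight and specifityWeight. Below is a minimal sketch of that combination, assuming the textbook F-beta definition as a stand-in for Heuristics.getFScore and with an extra guard against log(0) that the diff does not contain; note that the new code afterwards overwrites the score with heuristic.getScore(evaluatedTree), which the sketch does not cover.

// Sketch of the combined coverage/specificity score computed in evaluate().
// fBeta() only stands in for Heuristics.getFScore; all names are illustrative.
public final class TreeScoreSketch {

    static double fBeta(double recall, double precision, double beta) {
        if (precision + recall == 0) return 0;
        double b2 = beta * beta;
        return (1 + b2) * precision * recall / (b2 * precision + recall);
    }

    /**
     * @param coveredPos        positive example trees covered by the query tree
     * @param totalPos          all remaining positive example trees
     * @param coveredNeg        negative example trees covered by the query tree
     * @param nrOfSpecificNodes nodes of the tree that are not the wildcard "?"
     */
    static double score(int coveredPos, int totalPos, int coveredNeg, int nrOfSpecificNodes,
            double coverageBeta, double coverageWeight, double specifityWeight) {
        double recall = coveredPos / (double) totalPos;
        double precision = (coveredPos + coveredNeg == 0)
                ? 0
                : coveredPos / (double) (coveredPos + coveredNeg);
        double coverageScore = fBeta(recall, precision, coverageBeta);
        // guard against log(0); the diff simply takes Math.log(nrOfSpecificNodes)
        double specifityScore = nrOfSpecificNodes > 0 ? Math.log(nrOfSpecificNodes) : 0;
        return coverageWeight * coverageScore + specifityWeight * specifityScore;
    }

    public static void main(String[] args) {
        // a tree covering 8 of 10 positives and 1 negative, with 5 specific nodes
        System.out.println(score(8, 10, 1, 5, 0.5, 0.8, 0.1));
    }
}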
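QueryTreeHeuristic now derives its score from a confusion matrix (tp, tn, fp, fn) and supports the new HeuristicType values: predictive accuracy with a positive-example weight, Matthews correlation and Youden's index, besides F-measure and entropy. The helper below reproduces three of those formulas so they can be checked in isolation; the class and method names are made up for the sketch.

// Standalone reimplementation of the confusion-matrix measures used by the
// extended QueryTreeHeuristic (class/method names are illustrative).
public final class ConfusionMeasures {

    /** Predictive accuracy with positives weighted by w (posExamplesWeight). */
    public static double weightedPredAcc(double tp, double tn, double fp, double fn, double w) {
        return (w * tp + tn) / (w * (tp + fn) + tn + fp);
    }

    /** Matthews correlation coefficient. */
    public static double matthews(double tp, double tn, double fp, double fn) {
        return (tp * tn - fp * fn)
                / Math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn));
    }

    /** Youden's index = sensitivity + specificity - 1. */
    public static double youden(double tp, double tn, double fp, double fn) {
        return tp / (tp + fn) + tn / (fp + tn) - 1;
    }

    public static void main(String[] args) {
        // e.g. 80 of 100 positives covered, 10 of 100 negatives covered
        System.out.println(weightedPredAcc(80, 90, 10, 20, 2)); // ~0.833
        System.out.println(matthews(80, 90, 10, 20));           // ~0.704
        System.out.println(youden(80, 90, 10, 20));             // 0.7
    }
}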
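getMaximumAchievableScore() provides an upper bound on the score that refinements of a candidate tree can still reach. The sketch below shows the usual way such a bound is used to prune a score-ordered to-do list in best-first search; it illustrates only the general pattern and is not a restatement of the exact branching in computeNextPartialSolution() above.

// General best-first pruning pattern enabled by an upper bound on the score
// ("maximum achievable score"); all names in this sketch are illustrative.
import java.util.Comparator;
import java.util.PriorityQueue;

public class BestFirstPruningSketch {

    static class Candidate {
        final double score;              // actual score of this candidate
        final double maxAchievableScore; // bound for every refinement of it
        Candidate(double score, double maxAchievableScore) {
            this.score = score;
            this.maxAchievableScore = maxAchievableScore;
        }
    }

    public static void main(String[] args) {
        PriorityQueue<Candidate> todo = new PriorityQueue<>(
                Comparator.comparingDouble((Candidate c) -> c.score).reversed());
        todo.add(new Candidate(0.4, 0.9));
        todo.add(new Candidate(0.6, 0.65));
        double best = 0.7; // best complete candidate found so far

        while (!todo.isEmpty()) {
            Candidate c = todo.poll();
            if (c.maxAchievableScore < best) {
                // no refinement of c can ever beat the current best -> prune
                System.out.println("pruned candidate with score " + c.score);
                continue;
            }
            // ... otherwise refine c, score the refinements, push them back ...
        }
    }
}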
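NestedCrossValidation now accepts double-valued parameter ranges with an -s/--stepsize option and writes each value onto the learning algorithm via commons-beanutils (PropertyUtils.setSimpleProperty) instead of ComponentManager.applyConfigEntry. A self-contained sketch of that sweep follows; the Bean class is hypothetical and noisePercentage is only an example property name.

// Sketch of the double-valued parameter sweep: iterate start..end in steps of
// stepsize and set each value on a bean via commons-beanutils, as the script
// now does for the learning algorithm.
import java.lang.reflect.InvocationTargetException;

import org.apache.commons.beanutils.PropertyUtils;

public class ParameterSweepSketch {

    public static class Bean { // stands in for the learning algorithm instance
        private double noisePercentage;
        public double getNoisePercentage() { return noisePercentage; }
        public void setNoisePercentage(double noisePercentage) { this.noisePercentage = noisePercentage; }
    }

    public static void main(String[] args) {
        double start = 0.0, end = 0.5, stepsize = 0.1;
        Bean la = new Bean();
        // accumulating doubles can skip the last value due to rounding,
        // which is inherent to a loop of this form
        for (double value = start; value <= end; value += stepsize) {
            try {
                PropertyUtils.setSimpleProperty(la, "noisePercentage", value);
            } catch (IllegalAccessException | InvocationTargetException | NoSuchMethodException e) {
                e.printStackTrace();
            }
            System.out.println("noisePercentage = " + la.getNoisePercentage());
            // ... run the inner folds with this parameter value ...
        }
    }
}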