|
From: <lor...@us...> - 2014-05-08 14:40:17
|
Revision: 4265
http://sourceforge.net/p/dl-learner/code/4265
Author: lorenz_b
Date: 2014-05-08 14:40:15 +0000 (Thu, 08 May 2014)
Log Message:
-----------
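Extended the nested cross-validation script to accept a comma-separated list of conf files and to report statistics aggregated over all of them; also reworked the QTL2Disjunctive termination criteria and made its positive-example weight configurable.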
Modified Paths:
--------------
trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/QTL2Disjunctive.java
trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/QueryTreeHeuristic.java
trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/datastructures/impl/QueryTreeImpl.java
trunk/components-core/src/main/java/org/dllearner/utilities/statistics/Stat.java
trunk/scripts/src/main/java/org/dllearner/scripts/NestedCrossValidation.java
trunk/scripts/src/main/java/org/dllearner/scripts/evaluation/QTLEvaluation.java
Modified: trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/QTL2Disjunctive.java
===================================================================
--- trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/QTL2Disjunctive.java 2014-05-08 14:39:43 UTC (rev 4264)
+++ trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/QTL2Disjunctive.java 2014-05-08 14:40:15 UTC (rev 4265)
@@ -17,6 +17,7 @@
import java.util.SortedSet;
import java.util.TreeSet;
+import org.apache.commons.collections.ListUtils;
import org.apache.log4j.Logger;
import org.dllearner.algorithms.qtl.cache.QueryTreeCache;
import org.dllearner.algorithms.qtl.datastructures.QueryTree;
@@ -69,7 +70,7 @@
private Queue<EvaluatedQueryTree<String>> todoList;
private SortedSet<EvaluatedQueryTree<String>> currentPartialSolutions;
- private double currentlyBestScore = 0d;
+ private double bestCurrentScore = 0d;
private List<QueryTree<String>> currentPosExampleTrees;
private List<QueryTree<String>> currentNegExampleTrees;
@@ -114,11 +115,17 @@
private double posWeight = 2;
// minimum score a query tree must have to be part of the solution
private double minimumTreeScore = 0.2;
+ //If yes, then the algorithm tries to cover all positive examples. Note that while this improves accuracy on the testing set,
+ //it may lead to overfitting
+ private boolean tryFullCoverage;
+ //algorithm will terminate immediately when a correct definition is found
+ private boolean stopOnFirstDefinition;
private long startTime;
-
private long partialSolutionStartTime;
+ private double startPosExamplesSize;
+
public QTL2Disjunctive() {}
public QTL2Disjunctive(PosNegLP learningProblem, AbstractReasonerComponent reasoner) throws LearningProblemUnsupportedException{
@@ -156,7 +163,7 @@
if(heuristic == null){
heuristic = new QueryTreeHeuristic();
- heuristic.setPosExamplesWeight(2);
+ heuristic.setPosExamplesWeight(posWeight);
}
logger.info("Initializing...");
@@ -168,6 +175,8 @@
currentPosExamples = new TreeSet<Individual>(lp.getPositiveExamples());
currentNegExamples = new TreeSet<Individual>(lp.getNegativeExamples());
+ startPosExamplesSize = currentPosExamples.size();
+
//get the query trees
generateTrees();
@@ -178,6 +187,15 @@
//console rendering of class expressions
ToStringRenderer.getInstance().setRenderer(new ManchesterOWLSyntaxOWLObjectRendererImpl());
ToStringRenderer.getInstance().setShortFormProvider(new SimpleShortFormProvider());
+
+ //compute the LGG for all examples
+ //this allows us to prune all other trees because we can omit paths in trees which are contained in all positive
+ //as well as negative examples
+// List<QueryTree<String>> allExamplesTrees = new ArrayList<QueryTree<String>>();
+// allExamplesTrees.addAll(currentPosExampleTrees);
+// allExamplesTrees.addAll(currentNegExampleTrees);
+// QueryTree<String> lgg = lggGenerator.getLGG(allExamplesTrees);
+// lgg.dump();
}
private void generateTrees(){
@@ -204,7 +222,7 @@
String setup = "Setup:";
setup += "\n#Pos. examples:" + currentPosExamples.size();
setup += "\n#Neg. examples:" + currentNegExamples.size();
- setup += "\nCoverage beta:" + coverageBeta;
+ setup += "\nPos. weight(beta):" + posWeight;
logger.info(setup);
logger.info("Running...");
startTime = System.currentTimeMillis();
@@ -268,7 +286,7 @@
private void computeNextPartialSolution(){
logger.info("Computing best partial solution...");
- currentlyBestScore = 0d;
+ bestCurrentScore = Double.NEGATIVE_INFINITY;
partialSolutionStartTime = System.currentTimeMillis();
initTodoList(currentPosExampleTrees, currentNegExampleTrees);
@@ -291,14 +309,14 @@
double score = solution.getScore();
double mas = heuristic.getMaximumAchievableScore(solution);
- if(score >= currentlyBestScore){
+ if(score >= bestCurrentScore){
//add to todo list, if not already contained in todo list or solution list
todo(solution);
- if(solution.getScore() > currentlyBestScore){
+ if(solution.getScore() > bestCurrentScore){
logger.info("Got better solution:" + solution.getTreeScore());
}
- currentlyBestScore = solution.getScore();
- } else if(mas < currentlyBestScore){
+ bestCurrentScore = solution.getScore();
+ } else if(mas < bestCurrentScore){
todo(solution);
} else {
System.out.println("Too general");
@@ -431,6 +449,8 @@
subMon.reset();
lggMon.reset();
+
+ bestCurrentScore = minimumTreeScore;
}
@@ -578,8 +598,31 @@
return tree1.isSubsumedBy(tree2) && tree2.isSubsumedBy(tree1);
}
- private boolean terminationCriteriaSatisfied(){
- return stop || isTimeExpired() || currentPosExampleTrees.isEmpty();
+ /**
+ * Decides whether the outer loop of the algorithm should stop.
+ *
+ * @return true if stop was requested or the time limit expired, if no positive
+ * examples are left to cover, if the best score of the last partial solution is
+ * below {@code minimumTreeScore}, or - unless {@code tryFullCoverage} is set -
+ * if at most 5% of the initial positive examples remain uncovered
+ */
+ private boolean terminationCriteriaSatisfied() {
+ //stop was called or time expired
+ if(stop || isTimeExpired()){
+ return true;
+ }
+
+ // stop if there are no more positive examples to cover: with nothing left there is
+ // nothing to search for, independent of stopOnFirstDefinition or tryFullCoverage
+ // (otherwise tryFullCoverage would keep the loop running on an empty example set)
+ if (currentPosExamples.isEmpty()) {
+ return true;
+ }
+
+ // we stop when the score of the last tree added is too low
+ // (indicating that the algorithm could not find anything appropriate
+ // in the timeframe set)
+ if (bestCurrentScore < minimumTreeScore) {
+ return true;
+ }
+
+ // full coverage requested: keep going until every positive example is covered
+ if (tryFullCoverage) {
+ return false;
+ }
+
+ // otherwise stop when almost all positive examples (all but 5% of the initial ones) are covered
+ int maxPosRemaining = (int) Math.ceil(startPosExamplesSize * 0.05d);
+ return currentPosExamples.size() <= maxPosRemaining;
+ }
private boolean partialSolutionTerminationCriteriaSatisfied(){
@@ -635,6 +678,13 @@
this.coverageBeta = coverageBeta;
}
+ /**
+ * Sets the weight given to positive examples (default 2).
+ * <p>
+ * When no heuristic has been set, this value is passed to
+ * {@code QueryTreeHeuristic#setPosExamplesWeight} on the default heuristic created
+ * during initialization, and it is logged as "Pos. weight(beta)" in the setup output.
+ * NOTE(review): presumably only effective if called before initialization; it is not
+ * forwarded to a heuristic that was set explicitly - confirm with callers.
+ *
+ * @param posWeight the posWeight to set
+ */
+ public void setPosWeight(double posWeight) {
+ this.posWeight = posWeight;
+ }
+
/* (non-Javadoc)
* @see java.lang.Object#clone()
*/
Modified: trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/QueryTreeHeuristic.java
===================================================================
--- trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/QueryTreeHeuristic.java 2014-05-08 14:39:43 UTC (rev 4264)
+++ trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/QueryTreeHeuristic.java 2014-05-08 14:40:15 UTC (rev 4265)
@@ -63,7 +63,7 @@
case FMEASURE :
score = Heuristics.getFScore(tp/(tp+fn), tp/(tp+fp), posExamplesWeight);break;
case PRED_ACC :
- score = (posExamplesWeight * tp + tn) / (posExamplesWeight * (tp + fn) + tn + fp);break;
+ score = (tp + posExamplesWeight * tn) / ((tp + fn) + posExamplesWeight * (tn + fp));break;
case ENTROPY :{
double total = tp + fn;
double pp = tp / total;
Modified: trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/datastructures/impl/QueryTreeImpl.java
===================================================================
--- trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/datastructures/impl/QueryTreeImpl.java 2014-05-08 14:39:43 UTC (rev 4264)
+++ trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/datastructures/impl/QueryTreeImpl.java 2014-05-08 14:40:15 UTC (rev 4265)
@@ -161,7 +161,7 @@
label += "Values: " + object.getLiterals();
}
}
- label += object.isResourceNode() + "," + object.isLiteralNode();
+// label += object.isResourceNode() + "," + object.isLiteralNode();
return label;
}
};
@@ -801,13 +801,16 @@
writer.println(ren);
for (QueryTree<N> child : getChildren()) {
Object edge = getEdge(child);
- if (edge != null) {
+ boolean meaningful = !edge.equals(RDF.type.getURI()) || meaningful(child);
+ if (edge != null && meaningful) {
writer.print(sb.toString());
writer.print("--- ");
writer.print(edge);
writer.print(" ---\n");
}
- child.dump(writer, indent);
+ if(meaningful){
+ child.dump(writer, indent);
+ }
}
writer.flush();
// int depth = getPathToRoot().size();
@@ -832,6 +835,23 @@
// writer.flush();
}
+ /**
+ * Checks whether a subtree carries information worth printing when dumping the tree.
+ * <p>
+ * A tree is meaningful if it is a resource or literal node, or if at least one of its
+ * children is reached via an edge other than rdfs:subClassOf, is a resource node, or
+ * is itself meaningful.
+ *
+ * @param tree the subtree to inspect
+ * @return true if the subtree should be printed
+ */
+ private boolean meaningful(QueryTree<N> tree){
+ if(tree.isResourceNode() || tree.isLiteralNode()){
+ return true;
+ }
+ for (QueryTree<N> child : tree.getChildren()) {
+ Object edge = tree.getEdge(child);
+ // compare from the constant side: dump() guards against null edges, so getEdge may return null
+ if(!RDFS.subClassOf.getURI().equals(edge)
+ || child.isResourceNode()
+ || meaningful(child)){
+ return true;
+ }
+ }
+ return false;
+ }
public List<N> fillDepthFirst() {
List<N> results = new ArrayList<N>();
Modified: trunk/components-core/src/main/java/org/dllearner/utilities/statistics/Stat.java
===================================================================
--- trunk/components-core/src/main/java/org/dllearner/utilities/statistics/Stat.java 2014-05-08 14:39:43 UTC (rev 4264)
+++ trunk/components-core/src/main/java/org/dllearner/utilities/statistics/Stat.java 2014-05-08 14:40:15 UTC (rev 4265)
@@ -75,6 +75,14 @@
}
}
+ /**
+ * Merges the values accumulated by another {@code Stat} into this one.
+ * <p>
+ * The counts, sums and sums of squares are added and the minimum and maximum
+ * are combined; {@code stat} itself is only read, not modified.
+ * NOTE(review): merging an empty {@code Stat} is only neutral if its initial
+ * min/max values are +/- infinity (or equivalent) - confirm against the field
+ * initialisation, which is not visible in this hunk.
+ *
+ * @param stat the statistics to merge into this object
+ */
+ public void add(Stat stat){
+ count += stat.count;
+ sum += stat.sum;
+ squareSum += stat.squareSum;
+ min = Math.min(min, stat.min);
+ max = Math.max(max, stat.max);
+ }
+
/**
* Add a number to this object.
*
Modified: trunk/scripts/src/main/java/org/dllearner/scripts/NestedCrossValidation.java
===================================================================
--- trunk/scripts/src/main/java/org/dllearner/scripts/NestedCrossValidation.java 2014-05-08 14:39:43 UTC (rev 4264)
+++ trunk/scripts/src/main/java/org/dllearner/scripts/NestedCrossValidation.java 2014-05-08 14:40:15 UTC (rev 4265)
@@ -26,6 +26,7 @@
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.text.DecimalFormat;
+import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
@@ -60,6 +61,7 @@
import org.dllearner.utilities.statistics.Stat;
import com.google.common.base.Charsets;
+import com.google.common.collect.Lists;
import com.google.common.io.Files;
/**
@@ -100,6 +102,15 @@
public class NestedCrossValidation {
private File outputFile = new File("log/nested-cv.log");
+ DecimalFormat df = new DecimalFormat();
+
+ // overall statistics
+ Stat globalAcc = new Stat();
+ Stat globalF = new Stat();
+ Stat globalRecall = new Stat();
+ Stat globalPrecision = new Stat();
+
+ Map<Double,Stat> globalParaStats = new HashMap<Double,Stat>();
/**
* Entry method, which uses JOptSimple to parse parameters.
@@ -115,8 +126,7 @@
OptionParser parser = new OptionParser();
parser.acceptsAll(asList("h", "?", "help"), "Show help.");
- parser.acceptsAll(asList("c", "conf"), "Conf file to use.").withRequiredArg().ofType(
- File.class);
+ parser.acceptsAll(asList("c", "conf"), "The comma separated list of conffiles to be used.").withRequiredArg().describedAs("file1, file2, ...");
parser.acceptsAll(asList( "v", "verbose"), "Be more verbose.");
parser.acceptsAll(asList( "o", "outerfolds"), "Number of outer folds.").withRequiredArg().ofType(Integer.class).describedAs("#folds");
parser.acceptsAll(asList( "i", "innerfolds"), "Number of inner folds.").withRequiredArg().ofType(Integer.class).describedAs("#folds");
@@ -139,7 +149,12 @@
// all options present => start nested cross validation
} else if(options.has("c") && options.has("o") && options.has("i") && options.has("p") && options.has("r")) {
// read all options in variables and parse option values
- File confFile = (File) options.valueOf("c");
+ String confFilesString = (String) options.valueOf("c");
+ List<File> confFiles = new ArrayList<File>();
+ for (String fileString : confFilesString.split(",")) {
+ confFiles.add(new File(fileString.trim()));
+ }
+
int outerFolds = (Integer) options.valueOf("o");
int innerFolds = (Integer) options.valueOf("i");
String parameter = (String) options.valueOf("p");
@@ -164,7 +179,7 @@
java.util.logging.Logger.getLogger("").setLevel(java.util.logging.Level.WARNING);
System.out.println("Warning: The script is not well tested yet. (No known bugs, but needs more testing.)");
- new NestedCrossValidation(confFile, outerFolds, innerFolds, parameter, rangeStart, rangeEnd, stepsize, verbose);
+ new NestedCrossValidation(confFiles, outerFolds, innerFolds, parameter, rangeStart, rangeEnd, stepsize, verbose);
// an option is missing => print help screen and message
} else {
@@ -182,16 +197,52 @@
}
System.out.println(s);
}
+
+ /**
+ * Runs nested cross validation for a single conf file; equivalent to calling the
+ * {@code List<File>} constructor with a one-element list.
+ */
+ public NestedCrossValidation(File confFile, int outerFolds, int innerFolds, String parameter, double startValue, double endValue, double stepsize, boolean verbose) throws ComponentInitException, ParseException, org.dllearner.confparser.ParseException, IOException {
+ this(Lists.newArrayList(confFile), outerFolds, innerFolds, parameter, startValue, endValue, stepsize, verbose);
+ }
- public NestedCrossValidation(File confFile, int outerFolds, int innerFolds, String parameter, double startValue, double endValue, double stepsize, boolean verbose) throws ComponentInitException, ParseException, org.dllearner.confparser.ParseException, IOException {
+ /**
+ * Runs nested cross validation for each of the given conf files (in list order) and then
+ * prints a summary over all of them: the mean criterion value per parameter value, the
+ * parameter value with the highest mean, and the overall accuracy, F measure, precision
+ * and recall accumulated across all files.
+ *
+ * @param confFiles conf files to evaluate
+ * @param outerFolds number of outer folds, forwarded to each validation run
+ * @param innerFolds number of inner folds, forwarded to each validation run
+ * @param parameter name of the parameter being tuned
+ * @param startValue first parameter value; also reported if no parameter statistics were collected
+ * @param endValue last parameter value
+ * @param stepsize step between parameter values
+ * @param verbose whether to print more details
+ */
+ public NestedCrossValidation(List<File> confFiles, int outerFolds, int innerFolds, String parameter, double startValue, double endValue, double stepsize, boolean verbose) throws ComponentInitException, ParseException, org.dllearner.confparser.ParseException, IOException {
- DecimalFormat df = new DecimalFormat();
- ComponentManager cm = ComponentManager.getInstance();
+ // each run adds its results to the global* statistics fields
+ for (File confFile : confFiles) {
+ print(confFile.getPath());
+ validate(confFile, outerFolds, innerFolds, parameter, startValue, endValue, stepsize, verbose);
+ }
+ print("********************************************");
+ print("********************************************");
+ print("********************************************");
+
+ // decide for the best parameter
+ // NOTE(review): globalParaStats is a HashMap, so values are printed in unspecified order and
+ // ties in the mean are resolved by iteration order - consider a TreeMap for deterministic output
+ print(" Summary over parameter values:");
+ double bestPara = startValue;
+ double bestValue = Double.NEGATIVE_INFINITY;
+ for (Entry<Double, Stat> entry : globalParaStats.entrySet()) {
+ double para = entry.getKey();
+ Stat stat = entry.getValue();
+ print(" value " + para + ": " + stat.prettyPrint("%"));
+ if (stat.getMean() > bestValue) {
+ bestPara = para;
+ bestValue = stat.getMean();
+ }
+ }
+ print(" selected " + bestPara + " as best parameter value (criterion value " + df.format(bestValue) + "%)");
+
+ // overall statistics
+ print("*******************");
+ print("* Overall Results *");
+ print("*******************");
+ print("accuracy: " + globalAcc.prettyPrint("%"));
+ print("F measure: " + globalF.prettyPrint("%"));
+ print("precision: " + globalPrecision.prettyPrint("%"));
+ print("recall: " + globalRecall.prettyPrint("%"));
+
+ }
+
+ private void validate(File confFile, int outerFolds, int innerFolds, String parameter, double startValue, double endValue, double stepsize, boolean verbose) throws IOException, ComponentInitException{
CLI start = new CLI(confFile);
start.init();
AbstractLearningProblem lp = start.getLearningProblem();
- System.out.println(lp);
if(!(lp instanceof PosNegLP)) {
System.out.println("Positive only learning not supported yet.");
System.exit(0);
@@ -213,7 +264,7 @@
Stat accOverall = new Stat();
Stat fOverall = new Stat();
Stat recallOverall = new Stat();
- Stat precisionOverall = new Stat();
+ Stat precisionOverall = new Stat();
for(int currOuterFold=0; currOuterFold<outerFolds; currOuterFold++) {
@@ -302,11 +353,15 @@
// free memory
rs.releaseKB();
- cm.freeAllComponents();
}
paraStats.put(currParaValue, paraCriterionStat);
-
+ Stat globalParaStat = globalParaStats.get(currParaValue);
+ if(globalParaStat == null){
+ globalParaStat = new Stat();
+ globalParaStats.put(currParaValue, globalParaStat);
+ }
+ globalParaStat.add(paraCriterionStat);
}
// decide for the best parameter
@@ -382,9 +437,13 @@
// free memory
rs.releaseKB();
- cm.freeAllComponents();
}
+ globalAcc.add(accOverall);
+ globalF.add(fOverall);
+ globalPrecision.add(precisionOverall);
+ globalRecall.add(recallOverall);
+
// overall statistics
print("*******************");
print("* Overall Results *");
@@ -393,7 +452,6 @@
print("F measure: " + fOverall.prettyPrint("%"));
print("precision: " + precisionOverall.prettyPrint("%"));
print("recall: " + recallOverall.prettyPrint("%"));
-
}
// convenience methods, which takes a list of examples and divides them in
Modified: trunk/scripts/src/main/java/org/dllearner/scripts/evaluation/QTLEvaluation.java
===================================================================
--- trunk/scripts/src/main/java/org/dllearner/scripts/evaluation/QTLEvaluation.java 2014-05-08 14:39:43 UTC (rev 4264)
+++ trunk/scripts/src/main/java/org/dllearner/scripts/evaluation/QTLEvaluation.java 2014-05-08 14:40:15 UTC (rev 4265)
@@ -496,13 +496,13 @@
lp.setReasoner(reasoner);
lp.init();
QTL2Disjunctive la = new QTL2Disjunctive(lp, reasoner);
- la.init();
- la.start();
+// la.init();
+// la.start();
CrossValidation.outputFile = new File("log/qtl-cv.log");
CrossValidation.writeToFile = true;
CrossValidation.multiThreaded = multiThreaded;
-// CrossValidation cv = new CrossValidation(la, lp, reasoner, nrOfFolds, false);
+ CrossValidation cv = new CrossValidation(la, lp, reasoner, nrOfFolds, false);
long endTime = System.currentTimeMillis();
System.err.println((endTime - startTime) + "ms");
}
This was sent by the SourceForge.net collaborative development platform, the world's largest Open Source development site.
|