From: <lor...@us...> - 2014-05-08 11:45:49
Revision: 4263
http://sourceforge.net/p/dl-learner/code/4263
Author: lorenz_b
Date: 2014-05-08 11:45:45 +0000 (Thu, 08 May 2014)
Log Message:
-----------
Added more termination criteria for QTL algorithm. Added heuristics for tree score.
Modified Paths:
--------------
trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/QTL2Disjunctive.java
trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/QueryTreeHeuristic.java
trunk/components-core/src/main/java/org/dllearner/learningproblems/Heuristics.java
trunk/components-core/src/main/java/org/dllearner/learningproblems/QueryTreeScore.java
trunk/scripts/src/main/java/org/dllearner/scripts/NestedCrossValidation.java
trunk/scripts/src/main/java/org/dllearner/scripts/evaluation/QTLEvaluation.java
Modified: trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/QTL2Disjunctive.java
===================================================================
--- trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/QTL2Disjunctive.java 2014-05-07 13:43:54 UTC (rev 4262)
+++ trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/QTL2Disjunctive.java 2014-05-08 11:45:45 UTC (rev 4263)
@@ -2,6 +2,7 @@
import java.io.ByteArrayInputStream;
import java.io.IOException;
+import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
@@ -61,6 +62,7 @@
private static final Logger logger = Logger.getLogger(QTL2Disjunctive.class.getName());
+ private final DecimalFormat df = new DecimalFormat("0.00");
private LGGGenerator<String> lggGenerator;
@@ -92,21 +94,31 @@
private EvaluatedDescription currentBestSolution;
+ private QueryTreeHeuristic heuristic;
//Parameters
@ConfigOption(name = "noisePercentage", defaultValue="0.0", description="the (approximated) percentage of noise within the examples")
private double noisePercentage = 0.0;
@ConfigOption(defaultValue = "10", name = "maxExecutionTimeInSeconds", description = "maximum execution of the algorithm in seconds")
- private int maxExecutionTimeInSeconds = 10;
- private double minimumTreeScore = 0.2;
+ private int maxExecutionTimeInSeconds = -1;
+
private double coverageWeight = 0.8;
private double specifityWeight = 0.1;
- private double noise = 0.3;
- private double coverageBeta = 0.7;
+ private double coverageBeta = 0.5;
- private double posExampleWeight = 1;
+ private double minCoveredPosExamplesFraction = 0.2;
+ // maximum execution time to compute a part of the solution
+ private double maxTreeComputationTimeInSeconds = 60;
+ // weight of the positive examples in the score (how strongly covering them is rewarded)
+ private double posWeight = 2;
+ // minimum score a query tree must have to be part of the solution
+ private double minimumTreeScore = 0.2;
+ private long startTime;
+
+ private long partialSolutionStartTime;
+
public QTL2Disjunctive() {}
public QTL2Disjunctive(PosNegLP learningProblem, AbstractReasonerComponent reasoner) throws LearningProblemUnsupportedException{
@@ -142,6 +154,11 @@
lggGenerator = new LGGGeneratorImpl<String>();
+ if(heuristic == null){
+ heuristic = new QueryTreeHeuristic();
+ heuristic.setPosExamplesWeight(2);
+ }
+
logger.info("Initializing...");
treeCache = new QueryTreeCache(model);
tree2Individual = new HashMap<QueryTree<String>, Individual>(lp.getPositiveExamples().size()+lp.getNegativeExamples().size());
@@ -164,6 +181,7 @@
}
private void generateTrees(){
+ logger.info("Generating trees...");
QueryTree<String> queryTree;
for (Individual ind : lp.getPositiveExamples()) {
queryTree = treeCache.getQueryTree(ind.getName());
@@ -175,6 +193,7 @@
tree2Individual.put(queryTree, ind);
currentNegExampleTrees.add(queryTree);
}
+ logger.info("...done.");
}
/* (non-Javadoc)
@@ -183,46 +202,59 @@
@Override
public void start() {
String setup = "Setup:";
- setup += "#Pos. examples:" + currentPosExamples.size();
- setup += "#Neg. examples:" + currentNegExamples.size();
- setup += "Coverage beta:" + coverageBeta;
+ setup += "\n#Pos. examples:" + currentPosExamples.size();
+ setup += "\n#Neg. examples:" + currentNegExamples.size();
+ setup += "\nCoverage beta:" + coverageBeta;
logger.info(setup);
logger.info("Running...");
- long startTime = System.currentTimeMillis();
+ startTime = System.currentTimeMillis();
reset();
int i = 1;
- do {
+ while(!terminationCriteriaSatisfied()){
logger.info(i++ + ". iteration...");
logger.info("#Remaining pos. examples:" + currentPosExampleTrees.size());
logger.info("#Remaining neg. examples:" + currentNegExampleTrees.size());
- //compute LGG
- computeLGG();
+ //compute a (partial) solution
+ computeNextPartialSolution();
//pick best (partial) solution computed so far
EvaluatedQueryTree<String> bestPartialSolution = currentPartialSolutions.first();
- partialSolutions.add(bestPartialSolution);
- //remove all covered examples
- QueryTree<String> tree;
- for (Iterator<QueryTree<String>> iterator = currentPosExampleTrees.iterator(); iterator.hasNext();) {
- tree = iterator.next();
- if(tree.isSubsumedBy(bestPartialSolution.getTree())){
- iterator.remove();
- currentPosExamples.remove(tree2Individual.get(tree));
+ //add the partial solution only if it satisfies the minimum score criterion
+ if(bestPartialSolution.getScore() >= minimumTreeScore){
+
+ partialSolutions.add(bestPartialSolution);
+
+ //remove all covered examples
+ QueryTree<String> tree;
+ for (Iterator<QueryTree<String>> iterator = currentPosExampleTrees.iterator(); iterator.hasNext();) {
+ tree = iterator.next();
+ if(tree.isSubsumedBy(bestPartialSolution.getTree())){
+ iterator.remove();
+ currentPosExamples.remove(tree2Individual.get(tree));
+ }
}
- }
- for (Iterator<QueryTree<String>> iterator = currentNegExampleTrees.iterator(); iterator.hasNext();) {
- tree = iterator.next();
- if(tree.isSubsumedBy(bestPartialSolution.getTree())){
- iterator.remove();
- currentNegExamples.remove(tree2Individual.get(tree));
+ for (Iterator<QueryTree<String>> iterator = currentNegExampleTrees.iterator(); iterator.hasNext();) {
+ tree = iterator.next();
+ if(tree.isSubsumedBy(bestPartialSolution.getTree())){
+ iterator.remove();
+ currentNegExamples.remove(tree2Individual.get(tree));
+ }
}
+ //build the current combined solution
+ currentBestSolution = buildCombinedSolution();
+
+ logger.info("combined accuracy: " + df.format(currentBestSolution.getAccuracy()));
+ } else {
+ logger.info("No tree found that satisfies the minimum criteria - the best was: "
+ + currentBestSolution.getDescription().toManchesterSyntaxString(baseURI, prefixes)
+ + " with score " + currentBestSolution.getScore());
}
- currentBestSolution = buildCombinedSolution();
- } while (!(stop || currentPosExampleTrees.isEmpty()));
+
+ };
isRunning = false;
@@ -234,6 +266,119 @@
}
+ private void computeNextPartialSolution(){
+ logger.info("Computing best partial solution...");
+ currentlyBestScore = 0d;
+ partialSolutionStartTime = System.currentTimeMillis();
+ initTodoList(currentPosExampleTrees, currentNegExampleTrees);
+
+ EvaluatedQueryTree<String> currentElement;
+ while(!partialSolutionTerminationCriteriaSatisfied()){
+ logger.trace("TODO list size: " + todoList.size());
+ //pick best element from todo list
+ currentElement = todoList.poll();
+ //generate the LGG between the chosen tree and each uncovered positive example
+ for (QueryTree<String> example : currentElement.getFalseNegatives()) {
+ QueryTree<String> tree = currentElement.getTree();
+
+ //compute the LGG
+ lggMon.start();
+ QueryTree<String> lgg = lggGenerator.getLGG(tree, example);
+ lggMon.stop();
+
+ //evaluate the LGG
+ EvaluatedQueryTree<String> solution = evaluate(lgg, true);
+ double score = solution.getScore();
+ double mas = heuristic.getMaximumAchievableScore(solution);
+
+ if(score >= currentlyBestScore){
+ //add to todo list, if not already contained in todo list or solution list
+ todo(solution);
+ if(solution.getScore() > currentlyBestScore){
+ logger.info("Got better solution:" + solution.getTreeScore());
+ }
+ currentlyBestScore = solution.getScore();
+ } else if(mas >= currentlyBestScore){
+ todo(solution);
+ } else {
+ System.out.println("Too general");
+ }
+ currentPartialSolutions.add(currentElement);
+
+ }
+ currentPartialSolutions.add(currentElement);
+ }
+ long endTime = System.currentTimeMillis();
+ logger.info("...finished in " + (endTime-partialSolutionStartTime) + "ms.");
+ EvaluatedDescription bestPartialSolution = currentPartialSolutions.first().asEvaluatedDescription();
+
+ logger.info("Best partial solution:\n" + OWLAPIConverter.getOWLAPIDescription(bestPartialSolution.getDescription()) + "\n(" + bestPartialSolution.getScore() + ")");
+
+ logger.trace("LGG time: " + lggMon.getTotal() + "ms");
+ logger.trace("Avg. LGG time: " + lggMon.getAvg() + "ms");
+ logger.info("#LGG computations: " + lggMon.getHits());
+ logger.trace("Subsumption test time: " + subMon.getTotal() + "ms");
+ logger.trace("Avg. subsumption test time: " + subMon.getAvg() + "ms");
+ logger.trace("#Subsumption tests: " + subMon.getHits());
+ }
+
+ private EvaluatedQueryTree<String> evaluate(QueryTree<String> tree, boolean useSpecifity){
+ //1. get a score for the coverage = recall oriented
+ //compute positive examples which are not covered by LGG
+ List<QueryTree<String>> uncoveredPositiveExampleTrees = getUncoveredTrees(tree, currentPosExampleTrees);
+ Set<Individual> uncoveredPosExamples = new TreeSet<Individual>();
+ for (QueryTree<String> queryTree : uncoveredPositiveExampleTrees) {
+ uncoveredPosExamples.add(tree2Individual.get(queryTree));
+ }
+ //compute negative examples which are covered by LGG
+ Collection<QueryTree<String>> coveredNegativeExampleTrees = getCoveredTrees(tree, currentNegExampleTrees);
+ Set<Individual> coveredNegExamples = new TreeSet<Individual>();
+ for (QueryTree<String> queryTree : coveredNegativeExampleTrees) {
+ coveredNegExamples.add(tree2Individual.get(queryTree));
+ }
+ //compute score
+ int coveredPositiveExamples = currentPosExampleTrees.size() - uncoveredPositiveExampleTrees.size();
+ double recall = coveredPositiveExamples / (double)currentPosExampleTrees.size();
+ double precision = (coveredNegativeExampleTrees.size() + coveredPositiveExamples == 0)
+ ? 0
+ : coveredPositiveExamples / (double)(coveredPositiveExamples + coveredNegativeExampleTrees.size());
+
+ double coverageScore = Heuristics.getFScore(recall, precision, coverageBeta);
+
+ //2. get a score for the specifity of the query, i.e. how many edges/nodes = precision oriented
+ int nrOfSpecificNodes = 0;
+ for (QueryTree<String> childNode : tree.getChildrenClosure()) {
+ if(!childNode.getUserObject().equals("?")){
+ nrOfSpecificNodes++;
+ }
+ }
+ double specifityScore = 0d;
+ if(useSpecifity){
+ specifityScore = Math.log(nrOfSpecificNodes);
+ }
+
+ //3.compute the total score
+ double score = coverageWeight * coverageScore + specifityWeight * specifityScore;
+
+ QueryTreeScore queryTreeScore = new QueryTreeScore(score, coverageScore,
+ new TreeSet<Individual>(Sets.difference(currentPosExamples, uncoveredPosExamples)), uncoveredPosExamples,
+ coveredNegExamples, new TreeSet<Individual>(Sets.difference(currentNegExamples, coveredNegExamples)),
+ specifityScore, nrOfSpecificNodes);
+
+// QueryTreeScore queryTreeScore = new QueryTreeScore(score, coverageScore,
+// null,null,null,null,
+// specifityScore, nrOfSpecificNodes);
+
+ EvaluatedQueryTree<String> evaluatedTree = new EvaluatedQueryTree<String>(tree, uncoveredPositiveExampleTrees, coveredNegativeExampleTrees, queryTreeScore);
+
+ //TODO use only the heuristic to compute the score
+ score = heuristic.getScore(evaluatedTree);
+ queryTreeScore.setScore(score);
+ queryTreeScore.setAccuracy(score);
+
+ return evaluatedTree;
+ }
+
private EvaluatedDescription buildCombinedSolution(){
if(partialSolutions.size() == 1){
EvaluatedDescription combinedSolution = partialSolutions.get(0).asEvaluatedDescription();
@@ -288,58 +433,8 @@
lggMon.reset();
}
- private void computeLGG(){
- logger.info("Computing best partial solution...");
- currentlyBestScore = 0d;
-
- initTodoList(currentPosExampleTrees, currentNegExampleTrees);
-
- long startTime = System.currentTimeMillis();
- EvaluatedQueryTree<String> currentElement;
- do{
- logger.trace("TODO list size: " + todoList.size());
- //pick best element from todo list
- currentElement = todoList.poll();
- //generate the LGG between the chosen tree and each uncovered positive example
- for (QueryTree<String> example : currentElement.getFalseNegatives()) {
- QueryTree<String> tree = currentElement.getTree();
-
- //compute the LGG
- lggMon.start();
- QueryTree<String> lgg = lggGenerator.getLGG(tree, example);
- lggMon.stop();
-
- //evaluate the LGG
- EvaluatedQueryTree<String> solution = evaluate(lgg, true);
-
- if(solution.getScore() >= currentlyBestScore){
- //add to todo list, if not already contained in todo list or solution list
- todo(solution);
- if(solution.getScore() > currentlyBestScore){
- logger.info("Got better solution:" + solution.getTreeScore());
- }
- currentlyBestScore = solution.getScore();
- }
- currentPartialSolutions.add(currentElement);
-
- }
- currentPartialSolutions.add(currentElement);
-// todoList.remove(currentElement);
- } while(!terminationCriteriaSatisfied());
- long endTime = System.currentTimeMillis();
- logger.info("...finished in " + (endTime-startTime) + "ms.");
- EvaluatedDescription bestPartialSolution = currentPartialSolutions.first().asEvaluatedDescription();
-
- logger.info("Best partial solution:\n" + OWLAPIConverter.getOWLAPIDescription(bestPartialSolution.getDescription()) + "\n(" + bestPartialSolution.getScore() + ")");
-
- logger.trace("LGG time: " + lggMon.getTotal() + "ms");
- logger.trace("Avg. LGG time: " + lggMon.getAvg() + "ms");
- logger.trace("#LGG computations: " + lggMon.getHits());
- logger.trace("Subsumption test time: " + subMon.getTotal() + "ms");
- logger.trace("Avg. subsumption test time: " + subMon.getAvg() + "ms");
- logger.trace("#Subsumption tests: " + subMon.getHits());
- }
+
/* (non-Javadoc)
* @see org.dllearner.core.StoppableLearningAlgorithm#stop()
*/
@@ -411,58 +506,7 @@
return treeCache;
}
- private EvaluatedQueryTree<String> evaluate(QueryTree<String> tree, boolean useSpecifity){
- //1. get a score for the coverage = recall oriented
- //compute positive examples which are not covered by LGG
- List<QueryTree<String>> uncoveredPositiveExampleTrees = getUncoveredTrees(tree, currentPosExampleTrees);
- Set<Individual> uncoveredPosExamples = new TreeSet<Individual>();
- for (QueryTree<String> queryTree : uncoveredPositiveExampleTrees) {
- uncoveredPosExamples.add(tree2Individual.get(queryTree));
- }
- //compute negative examples which are covered by LGG
- Collection<QueryTree<String>> coveredNegativeExampleTrees = getCoveredTrees(tree, currentNegExampleTrees);
- Set<Individual> coveredNegExamples = new TreeSet<Individual>();
- for (QueryTree<String> queryTree : coveredNegativeExampleTrees) {
- coveredNegExamples.add(tree2Individual.get(queryTree));
- }
- //compute score
- int coveredPositiveExamples = currentPosExampleTrees.size() - uncoveredPositiveExampleTrees.size();
- double recall = coveredPositiveExamples / (double)currentPosExampleTrees.size();
- double precision = (coveredNegativeExampleTrees.size() + coveredPositiveExamples == 0)
- ? 0
- : coveredPositiveExamples / (double)(coveredPositiveExamples + coveredNegativeExampleTrees.size());
-
- double beta = 0.5;
- double coverageScore = Heuristics.getFScore(recall, precision, beta);
-
- //2. get a score for the specifity of the query, i.e. how many edges/nodes = precision oriented
- int nrOfSpecificNodes = 0;
- for (QueryTree<String> childNode : tree.getChildrenClosure()) {
- if(!childNode.getUserObject().equals("?")){
- nrOfSpecificNodes++;
- }
- }
- double specifityScore = 0d;
- if(useSpecifity){
- specifityScore = Math.log(nrOfSpecificNodes);
- }
-
- //3.compute the total score
- double score = coverageWeight * coverageScore + specifityWeight * specifityScore;
-
- QueryTreeScore queryTreeScore = new QueryTreeScore(score, coverageScore,
- new TreeSet<Individual>(Sets.difference(currentPosExamples, uncoveredPosExamples)), uncoveredPosExamples,
- coveredNegExamples, new TreeSet<Individual>(Sets.difference(currentNegExamples, coveredNegExamples)),
- specifityScore, nrOfSpecificNodes);
-
-// QueryTreeScore queryTreeScore = new QueryTreeScore(score, coverageScore,
-// null,null,null,null,
-// specifityScore, nrOfSpecificNodes);
-
- EvaluatedQueryTree<String> evaluatedTree = new EvaluatedQueryTree<String>(tree, uncoveredPositiveExampleTrees, coveredNegativeExampleTrees, queryTreeScore);
-
- return evaluatedTree;
- }
+
/**
* Return all trees from the given list {@code allTrees} which are not already subsumed by {@code tree}.
@@ -535,9 +579,21 @@
}
private boolean terminationCriteriaSatisfied(){
- return stop || todoList.isEmpty() || currentPosExampleTrees.isEmpty();
+ return stop || isTimeExpired() || currentPosExampleTrees.isEmpty();
}
+ private boolean partialSolutionTerminationCriteriaSatisfied(){
+ return stop || todoList.isEmpty() || currentPosExampleTrees.isEmpty() || isPartialSolutionTimeExpired() || isTimeExpired();
+ }
+
+ private boolean isTimeExpired(){
+ return maxExecutionTimeInSeconds <= 0 ? false : (System.currentTimeMillis() - startTime)/1000d >= maxExecutionTimeInSeconds;
+ }
+
+ private boolean isPartialSolutionTimeExpired(){
+ return maxTreeComputationTimeInSeconds <= 0 ? false : (System.currentTimeMillis() - partialSolutionStartTime)/1000d >= maxTreeComputationTimeInSeconds;
+ }
+
/**
* Add tree to todo list if not already contained in that list or the solutions.
* @param solution
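
Note on the new pruning hook: getMaximumAchievableScore(...) gives an upper bound on the score any further generalisation of a tree can still reach. Below is a minimal, self-contained sketch of how such a bound is typically used to prune the search; Candidate, score(), refinements() are placeholders, not DL-Learner API, and this is not the exact control flow of computeNextPartialSolution above.

    import java.util.ArrayDeque;
    import java.util.Deque;

    class UpperBoundPruningSketch {

        interface Candidate {
            double score();                     // current heuristic score
            double maxAchievableScore();        // upper bound for all refinements
            Iterable<Candidate> refinements();  // e.g. the LGG with one more uncovered positive
        }

        static Candidate search(Candidate start) {
            double best = start.score();
            Candidate bestCandidate = start;
            Deque<Candidate> todo = new ArrayDeque<>();
            todo.add(start);
            while (!todo.isEmpty()) {
                Candidate current = todo.poll();
                for (Candidate refined : current.refinements()) {
                    if (refined.score() > best) {       // new best partial solution
                        best = refined.score();
                        bestCandidate = refined;
                    }
                    // keep only candidates that can still beat the current best;
                    // everything else is pruned ("too general")
                    if (refined.maxAchievableScore() >= best) {
                        todo.add(refined);
                    }
                }
            }
            return bestCandidate;
        }
    }
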
Modified: trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/QueryTreeHeuristic.java
===================================================================
--- trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/QueryTreeHeuristic.java 2014-05-07 13:43:54 UTC (rev 4262)
+++ trunk/components-core/src/main/java/org/dllearner/algorithms/qtl/QueryTreeHeuristic.java 2014-05-08 11:45:45 UTC (rev 4263)
@@ -13,6 +13,7 @@
import org.dllearner.core.Heuristic;
import org.dllearner.core.owl.Individual;
import org.dllearner.learningproblems.Heuristics;
+import org.dllearner.learningproblems.Heuristics.HeuristicType;
import org.dllearner.learningproblems.QueryTreeScore;
import org.dllearner.utilities.owl.ConceptComparator;
@@ -23,6 +24,8 @@
@ComponentAnn(name = "QueryTreeHeuristic", shortName = "qtree_heuristic", version = 0.1)
public class QueryTreeHeuristic extends AbstractComponent implements Heuristic, Comparator<EvaluatedQueryTree<String>>{
+ private HeuristicType heuristicType = HeuristicType.PRED_ACC;
+
// F score beta value
private double coverageBeta = 1;
@@ -30,6 +33,8 @@
private double specifityWeight = 0.1;
+ private double posExamplesWeight = 1;
+
// syntactic comparison as final comparison criterion
private ConceptComparator conceptComparator = new ConceptComparator();
@@ -43,23 +48,88 @@
public double getScore(EvaluatedQueryTree<String> tree){
QueryTreeScore treeScore = tree.getTreeScore();
- //TODO
- double score = treeScore.getScore();
+ Set<Individual> truePositives = treeScore.getCoveredPositives();
+ Set<Individual> trueNegatives = treeScore.getNotCoveredNegatives();
+ Set<Individual> falsePositives = treeScore.getCoveredNegatives();
+ Set<Individual> falseNegatives = treeScore.getNotCoveredPositives();
+ double tp = truePositives.size();
+ double tn = trueNegatives.size();
+ double fp = falsePositives.size();
+ double fn = falseNegatives.size();
+
+ double score = 0;
+ switch(heuristicType){
+ case FMEASURE :
+ score = Heuristics.getFScore(tp/(tp+fn), tp/(tp+fp), posExamplesWeight);break;
+ case PRED_ACC :
+ score = (posExamplesWeight * tp + tn) / (posExamplesWeight * (tp + fn) + tn + fp);break;
+ case ENTROPY :{
+ double total = tp + fn;
+ double pp = tp / total;
+ double pn = fn / total;
+ score = pp * Math.log(pp) + pn * Math.log(pn);
+ break;}
+ case MATTHEWS_CORRELATION :
+ score = (tp * tn - fp * fn) / Math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn));break;
+ case YOUDEN_INDEX : score = tp / (tp + fn) + tn / (fp + tn) - 1;break;
+ default:
+ break;
+
+ }
+
return score;
}
- private double getPredictedAccuracy(EvaluatedQueryTree<String> tree){
+ /**
+ * Returns the maximum achievable score according to the used score function.
+ * @param tree the evaluated query tree
+ * @return the maximum score still achievable under the configured heuristic type
+ */
+ public double getMaximumAchievableScore(EvaluatedQueryTree<String> tree) {
QueryTreeScore treeScore = tree.getTreeScore();
Set<Individual> truePositives = treeScore.getCoveredPositives();
Set<Individual> trueNegatives = treeScore.getNotCoveredNegatives();
Set<Individual> falsePositives = treeScore.getCoveredNegatives();
Set<Individual> falseNegatives = treeScore.getNotCoveredPositives();
- return 0;
+ double tp = truePositives.size();
+ double tn = trueNegatives.size();
+ double fp = falsePositives.size();
+ double fn = falseNegatives.size();
+
+ return getMaximumAchievableScore(tp, tn, fp, fn);
}
+
+ /**
+ * Returns the maximum achievable score according to the used score function.
+ * @param tp number of true positives
+ * @param tn number of true negatives
+ * @param fp number of false positives
+ * @param fn number of false negatives
+ * @return the maximum score still achievable under the configured heuristic type
+ */
+ private double getMaximumAchievableScore(double tp, double tn, double fp, double fn) {
+ double mas = 0d;
+ switch (heuristicType) {
+ case FMEASURE:
+ break;
+ case PRED_ACC:
+ mas = (posExamplesWeight * tp + tn - fp) / (posExamplesWeight * (tp + fn) + tn + fp);
+ break;
+ case ENTROPY:
+ break;
+ case MATTHEWS_CORRELATION:
+ break;
+ case YOUDEN_INDEX:
+ break;
+ default:
+ break;
+ }
+ return mas;
+ }
+
/* (non-Javadoc)
* @see java.util.Comparator#compare(java.lang.Object, java.lang.Object)
*/
@@ -75,5 +145,19 @@
return conceptComparator.compare(tree1.asEvaluatedDescription().getDescription(), tree2.asEvaluatedDescription().getDescription());
}
}
+
+ /**
+ * @param heuristicType the heuristicType to set
+ */
+ public void setHeuristicType(HeuristicType heuristicType) {
+ this.heuristicType = heuristicType;
+ }
+
+ /**
+ * @param posExamplesWeight the posExamplesWeight to set
+ */
+ public void setPosExamplesWeight(double posExamplesWeight) {
+ this.posExamplesWeight = posExamplesWeight;
+ }
}
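
For completeness, the heuristic is now configurable independently of the algorithm. A minimal sketch using only the setters added above; how the instance is handed to QTL2Disjunctive is not shown in this revision (init() currently creates a default one when none is set):

    import org.dllearner.algorithms.qtl.QueryTreeHeuristic;
    import org.dllearner.learningproblems.Heuristics.HeuristicType;

    public class HeuristicConfigSketch {
        public static QueryTreeHeuristic configure() {
            QueryTreeHeuristic heuristic = new QueryTreeHeuristic();
            // switch from the default predictive accuracy to F-measure
            heuristic.setHeuristicType(HeuristicType.FMEASURE);
            // posExamplesWeight biases the score towards covering the positives;
            // it is the beta/weight term in the formulas of getScore() above
            heuristic.setPosExamplesWeight(2);
            return heuristic;
        }
    }
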
Modified: trunk/components-core/src/main/java/org/dllearner/learningproblems/Heuristics.java
===================================================================
--- trunk/components-core/src/main/java/org/dllearner/learningproblems/Heuristics.java 2014-05-07 13:43:54 UTC (rev 4262)
+++ trunk/components-core/src/main/java/org/dllearner/learningproblems/Heuristics.java 2014-05-08 11:45:45 UTC (rev 4263)
@@ -29,7 +29,7 @@
*/
public class Heuristics {
- public static enum HeuristicType { PRED_ACC, AMEASURE, JACCARD, FMEASURE, GEN_FMEASURE };
+ public static enum HeuristicType { PRED_ACC, AMEASURE, JACCARD, FMEASURE, GEN_FMEASURE, ENTROPY, MATTHEWS_CORRELATION, YOUDEN_INDEX };
/**
* Computes F1-Score.
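
A quick sanity check of the new measures with made-up counts (tp=8, fn=2, tn=7, fp=3), using the same formulas QueryTreeHeuristic.getScore applies; purely illustrative, not part of the code base:

    public class HeuristicTypesExample {
        public static void main(String[] args) {
            double tp = 8, fn = 2, tn = 7, fp = 3; // made-up confusion counts
            // predictive accuracy (posExamplesWeight = 1): 15/20 = 0.75
            double predAcc = (tp + tn) / (tp + fn + tn + fp);
            // Youden index: 0.8 + 0.7 - 1 = 0.5
            double youden = tp / (tp + fn) + tn / (fp + tn) - 1;
            // Matthews correlation: 50 / sqrt(9900), roughly 0.50
            double mcc = (tp * tn - fp * fn)
                    / Math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn));
            System.out.printf("PRED_ACC=%.2f YOUDEN=%.2f MCC=%.2f%n", predAcc, youden, mcc);
        }
    }
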
Modified: trunk/components-core/src/main/java/org/dllearner/learningproblems/QueryTreeScore.java
===================================================================
--- trunk/components-core/src/main/java/org/dllearner/learningproblems/QueryTreeScore.java 2014-05-07 13:43:54 UTC (rev 4262)
+++ trunk/components-core/src/main/java/org/dllearner/learningproblems/QueryTreeScore.java 2014-05-08 11:45:45 UTC (rev 4263)
@@ -16,7 +16,7 @@
private double score;
- private double coverageScore;
+ private double accuracy;
private double specifityScore;
private int nrOfSpecificNodes;
@@ -26,12 +26,12 @@
private Set<Individual> negAsPos;
private Set<Individual> negAsNeg;
- public QueryTreeScore(double score, double coverageScore,
+ public QueryTreeScore(double score, double accuracy,
Set<Individual> posAsPos, Set<Individual> posAsNeg, Set<Individual> negAsPos, Set<Individual> negAsNeg,
double specifityScore, int nrOfSpecificNodes) {
super();
this.score = score;
- this.coverageScore = coverageScore;
+ this.accuracy = accuracy;
this.posAsPos = posAsPos;
this.posAsNeg = posAsNeg;
this.negAsPos = negAsPos;
@@ -46,15 +46,29 @@
public double getScore() {
return score;
}
+
+ /**
+ * @param score the score to set
+ */
+ public void setScore(double score) {
+ this.score = score;
+ }
/* (non-Javadoc)
* @see org.dllearner.core.Score#getAccuracy()
*/
@Override
public double getAccuracy() {
- return score;
+ return accuracy;
}
+ /**
+ * @param accuracy the accuracy to set
+ */
+ public void setAccuracy(double accuracy) {
+ this.accuracy = accuracy;
+ }
+
public Set<Individual> getCoveredNegatives() {
return negAsPos;
}
@@ -77,7 +91,7 @@
@Override
public String toString() {
return score
- + "(coverage=" + coverageScore
+ + "(accuracy=" + accuracy
+ "(+" + posAsPos.size() + "/" + (posAsPos.size() + posAsNeg.size())
+ "|-" + negAsPos.size() + "/" + (negAsPos.size() + negAsNeg.size()) + ")|"
+ "specifity=" + specifityScore + "(" + nrOfSpecificNodes + "))";
Modified: trunk/scripts/src/main/java/org/dllearner/scripts/NestedCrossValidation.java
===================================================================
--- trunk/scripts/src/main/java/org/dllearner/scripts/NestedCrossValidation.java 2014-05-07 13:43:54 UTC (rev 4262)
+++ trunk/scripts/src/main/java/org/dllearner/scripts/NestedCrossValidation.java 2014-05-08 11:45:45 UTC (rev 4263)
@@ -19,44 +19,49 @@
*/
package org.dllearner.scripts;
+import static java.util.Arrays.asList;
+
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
+import java.lang.reflect.InvocationTargetException;
import java.text.DecimalFormat;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
+import java.util.Map.Entry;
import java.util.Random;
import java.util.Set;
import java.util.TreeSet;
-import java.util.Map.Entry;
+import joptsimple.OptionParser;
+import joptsimple.OptionSet;
+
+import org.apache.commons.beanutils.PropertyUtils;
import org.apache.log4j.ConsoleAppender;
+import org.apache.log4j.FileAppender;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.log4j.SimpleLayout;
-import org.dllearner.cli.Start;
-import org.dllearner.core.ComponentInitException;
-import org.dllearner.core.ComponentManager;
+import org.dllearner.cli.CLI;
import org.dllearner.core.AbstractCELA;
import org.dllearner.core.AbstractLearningProblem;
import org.dllearner.core.AbstractReasonerComponent;
+import org.dllearner.core.ComponentInitException;
+import org.dllearner.core.ComponentManager;
import org.dllearner.core.owl.Description;
import org.dllearner.core.owl.Individual;
import org.dllearner.learningproblems.PosNegLP;
import org.dllearner.parser.ParseException;
import org.dllearner.utilities.Helper;
-import org.dllearner.utilities.datastructures.Datastructures;
import org.dllearner.utilities.datastructures.TrainTestList;
import org.dllearner.utilities.statistics.Stat;
-import joptsimple.OptionParser;
-import joptsimple.OptionSet;
+import com.google.common.base.Charsets;
+import com.google.common.io.Files;
-import static java.util.Arrays.*;
-
/**
* Performs nested cross validation for the given problem. A k fold outer and l
* fold inner cross validation is used. Parameters:
@@ -93,6 +98,8 @@
*
*/
public class NestedCrossValidation {
+
+ private File outputFile = new File("log/nested-cv.log");
/**
* Entry method, which uses JOptSimple to parse parameters.
@@ -115,6 +122,7 @@
parser.acceptsAll(asList( "i", "innerfolds"), "Number of inner folds.").withRequiredArg().ofType(Integer.class).describedAs("#folds");
parser.acceptsAll(asList( "p", "parameter"), "Parameter to vary.").withRequiredArg();
parser.acceptsAll(asList( "r", "pvalues", "range"), "Values of parameter. $x-$y can be used for integer ranges.").withRequiredArg();
+ parser.acceptsAll(asList( "s", "stepsize", "steps"), "Step size of range.").withOptionalArg().ofType(Double.class).defaultsTo(1d);
// parse options and display a message for the user in case of problems
OptionSet options = null;
@@ -137,8 +145,9 @@
String parameter = (String) options.valueOf("p");
String range = (String) options.valueOf("r");
String[] rangeSplit = range.split("-");
- int rangeStart = new Integer(rangeSplit[0]);
- int rangeEnd = new Integer(rangeSplit[1]);
+ double rangeStart = Double.valueOf(rangeSplit[0]);
+ double rangeEnd = Double.valueOf(rangeSplit[1]);
+ double stepsize = (Double) options.valueOf("s");
boolean verbose = options.has("v");
// create logger (a simple logger which outputs
@@ -149,11 +158,13 @@
logger.removeAllAppenders();
logger.addAppender(consoleAppender);
logger.setLevel(Level.WARN);
+ Logger.getLogger("org.dllearner.algorithms").setLevel(Level.INFO);
+// logger.addAppender(new FileAppender(layout, "nested-cv.log", false));
// disable OWL API info output
java.util.logging.Logger.getLogger("").setLevel(java.util.logging.Level.WARNING);
System.out.println("Warning: The script is not well tested yet. (No known bugs, but needs more testing.)");
- new NestedCrossValidation(confFile, outerFolds, innerFolds, parameter, rangeStart, rangeEnd, verbose);
+ new NestedCrossValidation(confFile, outerFolds, innerFolds, parameter, rangeStart, rangeEnd, stepsize, verbose);
// an option is missing => print help screen and message
} else {
@@ -163,14 +174,24 @@
}
- public NestedCrossValidation(File confFile, int outerFolds, int innerFolds, String parameter, int startValue, int endValue, boolean verbose) throws FileNotFoundException, ComponentInitException, ParseException, org.dllearner.confparser.ParseException {
+ private void print(String s){
+ try {
+ Files.append(s + "\n", outputFile , Charsets.UTF_8);
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ System.out.println(s);
+ }
+
+ public NestedCrossValidation(File confFile, int outerFolds, int innerFolds, String parameter, double startValue, double endValue, double stepsize, boolean verbose) throws ComponentInitException, ParseException, org.dllearner.confparser.ParseException, IOException {
DecimalFormat df = new DecimalFormat();
ComponentManager cm = ComponentManager.getInstance();
- Start start = new Start(confFile);
+ CLI start = new CLI(confFile);
+ start.init();
AbstractLearningProblem lp = start.getLearningProblem();
-
+ System.out.println(lp);
if(!(lp instanceof PosNegLP)) {
System.out.println("Positive only learning not supported yet.");
System.exit(0);
@@ -196,16 +217,16 @@
for(int currOuterFold=0; currOuterFold<outerFolds; currOuterFold++) {
- System.out.println("Outer fold " + currOuterFold);
+ print("Outer fold " + currOuterFold);
TrainTestList posList = posLists.get(currOuterFold);
TrainTestList negList = negLists.get(currOuterFold);
// measure relevant criterion (accuracy, F-measure) over different parameter values
- Map<Integer,Stat> paraStats = new HashMap<Integer,Stat>();
+ Map<Double,Stat> paraStats = new HashMap<Double,Stat>();
- for(int currParaValue=startValue; currParaValue<=endValue; currParaValue++) {
+ for(double currParaValue=startValue; currParaValue<=endValue; currParaValue+=stepsize) {
- System.out.println(" Parameter value " + currParaValue + ":");
+ print(" Parameter value " + currParaValue + ":");
// split train folds again (computation of inner folds for each parameter
// value is redundant, but not a big problem)
List<Individual> trainPosList = posList.getTrainList();
@@ -219,19 +240,24 @@
for(int currInnerFold=0; currInnerFold<innerFolds; currInnerFold++) {
- System.out.println(" Inner fold " + currInnerFold + ":");
+ print(" Inner fold " + currInnerFold + ":");
// get positive & negative examples for training run
Set<Individual> posEx = new TreeSet<Individual>(innerPosLists.get(currInnerFold).getTrainList());
Set<Individual> negEx = new TreeSet<Individual>(innerNegLists.get(currInnerFold).getTrainList());
// read conf file and exchange options for pos/neg examples
// and parameter to optimise
- start = new Start(confFile);
+ start = new CLI(confFile);
+ start.init();
AbstractLearningProblem lpIn = start.getLearningProblem();
- cm.applyConfigEntry(lpIn, "positiveExamples", Datastructures.individualSetToStringSet(posEx));
- cm.applyConfigEntry(lpIn, "negativeExamples", Datastructures.individualSetToStringSet(negEx));
+ ((PosNegLP)lpIn).setPositiveExamples(posEx);
+ ((PosNegLP)lpIn).setNegativeExamples(negEx);
AbstractCELA laIn = start.getLearningAlgorithm();
- cm.applyConfigEntry(laIn, parameter, (double)currParaValue);
+ try {
+ PropertyUtils.setSimpleProperty(laIn, parameter, currParaValue);
+ } catch (IllegalAccessException | InvocationTargetException | NoSuchMethodException e) {
+ e.printStackTrace();
+ }
lpIn.init();
laIn.init();
@@ -263,15 +289,15 @@
paraCriterionStat.addNumber(accuracy);
- System.out.println(" hypothesis: " + concept.toManchesterSyntaxString(baseURI, null));
- System.out.println(" accuracy: " + df.format(accuracy) + "%");
- System.out.println(" precision: " + df.format(precision) + "%");
- System.out.println(" recall: " + df.format(recall) + "%");
- System.out.println(" F measure: " + df.format(fmeasure) + "%");
+ print(" hypothesis: " + concept.toManchesterSyntaxString(baseURI, null));
+ print(" accuracy: " + df.format(accuracy) + "%");
+ print(" precision: " + df.format(precision) + "%");
+ print(" recall: " + df.format(recall) + "%");
+ print(" F measure: " + df.format(fmeasure) + "%");
if(verbose) {
- System.out.println(" false positives (neg. examples classified as pos.): " + formatIndividualSet(posError, baseURI));
- System.out.println(" false negatives (pos. examples classified as neg.): " + formatIndividualSet(negError, baseURI));
+ print(" false positives (neg. examples classified as pos.): " + formatIndividualSet(posError, baseURI));
+ print(" false negatives (pos. examples classified as neg.): " + formatIndividualSet(negError, baseURI));
}
// free memory
@@ -284,28 +310,33 @@
}
// decide for the best parameter
- System.out.println(" Summary over parameter values:");
- int bestPara = startValue;
+ print(" Summary over parameter values:");
+ double bestPara = startValue;
double bestValue = Double.NEGATIVE_INFINITY;
- for(Entry<Integer,Stat> entry : paraStats.entrySet()) {
- int para = entry.getKey();
+ for(Entry<Double,Stat> entry : paraStats.entrySet()) {
+ double para = entry.getKey();
Stat stat = entry.getValue();
- System.out.println(" value " + para + ": " + stat.prettyPrint("%"));
+ print(" value " + para + ": " + stat.prettyPrint("%"));
if(stat.getMean() > bestValue) {
bestPara = para;
bestValue = stat.getMean();
}
}
- System.out.println(" selected " + bestPara + " as best parameter value (criterion value " + df.format(bestValue) + "%)");
- System.out.println(" Learn on Outer fold:");
+ print(" selected " + bestPara + " as best parameter value (criterion value " + df.format(bestValue) + "%)");
+ print(" Learn on Outer fold:");
// start a learning process with this parameter and evaluate it on the outer fold
- start = new Start(confFile);
+ start = new CLI(confFile);
+ start.init();
AbstractLearningProblem lpOut = start.getLearningProblem();
- cm.applyConfigEntry(lpOut, "positiveExamples", Datastructures.individualListToStringSet(posLists.get(currOuterFold).getTrainList()));
- cm.applyConfigEntry(lpOut, "negativeExamples", Datastructures.individualListToStringSet(negLists.get(currOuterFold).getTrainList()));
+ ((PosNegLP)lpOut).setPositiveExamples(new TreeSet<Individual>(posLists.get(currOuterFold).getTrainList()));
+ ((PosNegLP)lpOut).setNegativeExamples(new TreeSet<Individual>(negLists.get(currOuterFold).getTrainList()));
AbstractCELA laOut = start.getLearningAlgorithm();
- cm.applyConfigEntry(laOut, parameter, (double)bestPara);
+ try {
+ PropertyUtils.setSimpleProperty(laOut, parameter, bestPara);
+ } catch (IllegalAccessException | InvocationTargetException | NoSuchMethodException e) {
+ e.printStackTrace();
+ }
lpOut.init();
laOut.init();
@@ -332,15 +363,15 @@
double recall = 100 * (double) posCorrect.size() / (posCorrect.size() + posError.size());
double fmeasure = 2 * (precision * recall) / (precision + recall);
- System.out.println(" hypothesis: " + concept.toManchesterSyntaxString(baseURI, null));
- System.out.println(" accuracy: " + df.format(accuracy) + "%");
- System.out.println(" precision: " + df.format(precision) + "%");
- System.out.println(" recall: " + df.format(recall) + "%");
- System.out.println(" F measure: " + df.format(fmeasure) + "%");
+ print(" hypothesis: " + concept.toManchesterSyntaxString(baseURI, null));
+ print(" accuracy: " + df.format(accuracy) + "%");
+ print(" precision: " + df.format(precision) + "%");
+ print(" recall: " + df.format(recall) + "%");
+ print(" F measure: " + df.format(fmeasure) + "%");
if(verbose) {
- System.out.println(" false positives (neg. examples classified as pos.): " + formatIndividualSet(posError, baseURI));
- System.out.println(" false negatives (pos. examples classified as neg.): " + formatIndividualSet(negError, baseURI));
+ print(" false positives (neg. examples classified as pos.): " + formatIndividualSet(posError, baseURI));
+ print(" false negatives (pos. examples classified as neg.): " + formatIndividualSet(negError, baseURI));
}
// update overall statistics
@@ -355,14 +386,13 @@
}
// overall statistics
- System.out.println();
- System.out.println("*******************");
- System.out.println("* Overall Results *");
- System.out.println("*******************");
- System.out.println("accuracy: " + accOverall.prettyPrint("%"));
- System.out.println("F measure: " + fOverall.prettyPrint("%"));
- System.out.println("precision: " + precisionOverall.prettyPrint("%"));
- System.out.println("recall: " + recallOverall.prettyPrint("%"));
+ print("*******************");
+ print("* Overall Results *");
+ print("*******************");
+ print("accuracy: " + accOverall.prettyPrint("%"));
+ print("F measure: " + fOverall.prettyPrint("%"));
+ print("precision: " + precisionOverall.prettyPrint("%"));
+ print("recall: " + recallOverall.prettyPrint("%"));
}
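
Since the parameter value is now a double and a step size can be given, fractional parameter sweeps become possible. A hypothetical invocation of the changed constructor; the conf file path and the parameter name are placeholders, any double-valued bean property of the learning algorithm that PropertyUtils can set would do:

    import java.io.File;
    import org.dllearner.scripts.NestedCrossValidation;

    public class RunNestedCV {
        public static void main(String[] args) throws Exception {
            new NestedCrossValidation(
                    new File("examples/carcinogenesis/train.conf"), // placeholder conf file
                    5,                 // outer folds
                    3,                 // inner folds
                    "noisePercentage", // placeholder: double-valued parameter to vary
                    0.0, 0.4,          // range start / end
                    0.05,              // step size (new in this revision)
                    true);             // verbose
        }
    }
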
Modified: trunk/scripts/src/main/java/org/dllearner/scripts/evaluation/QTLEvaluation.java
===================================================================
--- trunk/scripts/src/main/java/org/dllearner/scripts/evaluation/QTLEvaluation.java 2014-05-07 13:43:54 UTC (rev 4262)
+++ trunk/scripts/src/main/java/org/dllearner/scripts/evaluation/QTLEvaluation.java 2014-05-08 11:45:45 UTC (rev 4263)
@@ -53,8 +53,8 @@
public class QTLEvaluation {
int nrOfFolds = 10;
- private int nrOfPosExamples = 100;
- private int nrOfNegExamples = 100;
+ private int nrOfPosExamples = 300;
+ private int nrOfNegExamples = 300;
List<String> posExamples = Lists.newArrayList(
"http://dl-learner.org/carcinogenesis#d1",
@@ -496,13 +496,13 @@
lp.setReasoner(reasoner);
lp.init();
QTL2Disjunctive la = new QTL2Disjunctive(lp, reasoner);
-// la.init();
-// la.start();
+ la.init();
+ la.start();
CrossValidation.outputFile = new File("log/qtl-cv.log");
CrossValidation.writeToFile = true;
CrossValidation.multiThreaded = multiThreaded;
- CrossValidation cv = new CrossValidation(la, lp, reasoner, nrOfFolds, false);
+// CrossValidation cv = new CrossValidation(la, lp, reasoner, nrOfFolds, false);
long endTime = System.currentTimeMillis();
System.err.println((endTime - startTime) + "ms");
}