From: <lor...@us...> - 2011-11-02 14:37:45
|
Revision: 3362 http://dl-learner.svn.sourceforge.net/dl-learner/?rev=3362&view=rev Author: lorenz_b Date: 2011-11-02 14:37:35 +0000 (Wed, 02 Nov 2011) Log Message: ----------- Changed handling of count results with value 0. Modified Paths: -------------- trunk/components-ext/src/main/java/org/dllearner/algorithm/tbsl/learning/SPARQLTemplateBasedLearner.java trunk/components-ext/src/main/java/org/dllearner/algorithm/tbsl/sparql/Query.java trunk/components-ext/src/main/java/org/dllearner/algorithm/tbsl/sparql/WeightedQuery.java Modified: trunk/components-ext/src/main/java/org/dllearner/algorithm/tbsl/learning/SPARQLTemplateBasedLearner.java =================================================================== --- trunk/components-ext/src/main/java/org/dllearner/algorithm/tbsl/learning/SPARQLTemplateBasedLearner.java 2011-11-02 14:07:30 UTC (rev 3361) +++ trunk/components-ext/src/main/java/org/dllearner/algorithm/tbsl/learning/SPARQLTemplateBasedLearner.java 2011-11-02 14:37:35 UTC (rev 3362) @@ -7,6 +7,7 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Collections; +import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -16,6 +17,7 @@ import java.util.SortedSet; import java.util.TreeSet; +import org.apache.commons.collections.SetUtils; import org.apache.log4j.Logger; import org.dllearner.algorithm.qtl.util.ModelGenerator; import org.dllearner.algorithm.qtl.util.ModelGenerator.Strategy; @@ -31,6 +33,7 @@ import org.dllearner.algorithm.tbsl.sparql.RatedQuery; import org.dllearner.algorithm.tbsl.sparql.SPARQL_Prefix; import org.dllearner.algorithm.tbsl.sparql.SPARQL_QueryType; +import org.dllearner.algorithm.tbsl.sparql.SPARQL_Triple; import org.dllearner.algorithm.tbsl.sparql.Slot; import org.dllearner.algorithm.tbsl.sparql.SlotType; import org.dllearner.algorithm.tbsl.sparql.Template; @@ -43,9 +46,13 @@ import org.dllearner.core.ComponentInitException; import org.dllearner.core.Oracle; import org.dllearner.core.SparqlQueryLearningAlgorithm; +import org.dllearner.core.owl.Description; +import org.dllearner.core.owl.NamedClass; +import org.dllearner.kb.SparqlEndpointKS; import org.dllearner.kb.sparql.ExtractionDBCache; import org.dllearner.kb.sparql.SparqlEndpoint; import org.dllearner.kb.sparql.SparqlQuery; +import org.dllearner.reasoning.SPARQLReasoner; import org.ini4j.InvalidFileFormatException; import org.ini4j.Options; @@ -56,6 +63,7 @@ import com.hp.hpl.jena.rdf.model.Model; import com.hp.hpl.jena.rdf.model.ModelFactory; import com.hp.hpl.jena.sparql.engine.http.QueryEngineHTTP; +import com.hp.hpl.jena.vocabulary.RDFS; import com.jamonapi.Monitor; import com.jamonapi.MonitorFactory; @@ -109,6 +117,8 @@ private Lemmatizer lemmatizer = new LingPipeLemmatizer();// StanfordLemmatizer(); + private SPARQLReasoner reasoner; + public SPARQLTemplateBasedLearner() throws InvalidFileFormatException, FileNotFoundException, IOException{ this(OPTIONS_FILE); } @@ -188,6 +198,9 @@ predicateFilters.add("http://dbpedia.org/ontology/wikiPageWikiLink"); predicateFilters.add("http://dbpedia.org/property/wikiPageUsesTemplate"); modelGenenerator = new ModelGenerator(endpoint, predicateFilters); + + reasoner = new SPARQLReasoner(new SparqlEndpointKS(endpoint)); + reasoner.prepareSubsumptionHierarchy(); } public void setQuestion(String question){ @@ -251,7 +264,7 @@ Set<WeightedQuery> weightedQueries = getWeightedSPARQLQueries(templates); sparqlQueryCandidates = new ArrayList<Query>(); int i = 0; - for(WeightedQuery wQ : weightedQueries){ + for(WeightedQuery wQ : weightedQueries){System.out.println(wQ); sparqlQueryCandidates.add(wQ.getQuery()); if(i == maxTestedQueries){ break; @@ -346,7 +359,7 @@ } private Set<WeightedQuery> getWeightedSPARQLQueries(Set<Template> templates){ - double alpha = 0.7; + double alpha = 0.8; double beta = 1 - alpha; Map<Slot, Set<Allocation>> slot2Allocations = new HashMap<Slot, Set<Allocation>>(); @@ -355,6 +368,7 @@ Set<Allocation> allAllocations; for(Template t : templates){ allAllocations = new HashSet<Allocation>(); + for(Slot slot : t.getSlots()){ Set<Allocation> allocations = computeAllocation(slot); allAllocations.addAll(allocations); @@ -386,25 +400,71 @@ queries.add(new WeightedQuery(cleanQuery)); Set<WeightedQuery> tmp = new HashSet<WeightedQuery>(); + List<Slot> sortedSlots = new ArrayList<Slot>(); + Set<Slot> classSlots = new HashSet<Slot>(); for(Slot slot : t.getSlots()){ + if(slot.getSlotType() == SlotType.CLASS){ + sortedSlots.add(slot); + classSlots.add(slot); + } + } + for(Slot slot : t.getSlots()){ + if(!sortedSlots.contains(slot)){ + sortedSlots.add(slot); + } + } + for(Slot slot : sortedSlots){ if(!slot2Allocations.get(slot).isEmpty()){ for(Allocation a : slot2Allocations.get(slot)){ for(WeightedQuery query : queries){ - if(slot.getSlotType() == SlotType.SYMPROPERTY){ Query reversedQuery = new Query(query.getQuery()); reversedQuery.getTriplesWithVar(slot.getAnchor()).iterator().next().reverse(); - reversedQuery.replaceVarWithURI(slot.getAnchor(), a.getUri()); - WeightedQuery w = new WeightedQuery(reversedQuery); + //check if the query is possible + if(slot.getSlotType() == SlotType.SYMPROPERTY){ + boolean drop = false; + for(SPARQL_Triple triple : query.getQuery().getTriplesWithVar(slot.getAnchor())){ + System.out.println(triple); + for(SPARQL_Triple typeTriple : query.getQuery().getRDFTypeTriples(triple.getValue().getName())){ + System.out.println(typeTriple); + Set<String> ranges = getRanges(a.getUri()); + System.out.println(a); + if(!ranges.isEmpty()){ + Set<String> allRanges = new HashSet<String>(); + for(String range : ranges){ + allRanges.addAll(getSuperClasses(range)); + } + String typeURI = typeTriple.getValue().getName().substring(1,typeTriple.getValue().getName().length()-1); + Set<String> allTypes = getSuperClasses(typeURI); + allTypes.add(typeTriple.getValue().getName()); + System.out.println("RANGES: " + ranges); + System.out.println("TYPES: " + allTypes); + + if(!org.mindswap.pellet.utils.SetUtils.intersects(allRanges, allTypes)){ + drop = true; + } + } + } + } + + if(!drop){ + reversedQuery.replaceVarWithURI(slot.getAnchor(), a.getUri()); + WeightedQuery w = new WeightedQuery(reversedQuery); + double newScore = query.getScore() + a.getScore(); + w.setScore(newScore); + tmp.add(w); + } + + + + + } + Query q = new Query(query.getQuery()); + q.replaceVarWithURI(slot.getAnchor(), a.getUri()); + WeightedQuery w = new WeightedQuery(q); double newScore = query.getScore() + a.getScore(); w.setScore(newScore); tmp.add(w); - } - Query q = new Query(query.getQuery()); - q.replaceVarWithURI(slot.getAnchor(), a.getUri()); - WeightedQuery w = new WeightedQuery(q); - double newScore = query.getScore() + a.getScore(); - w.setScore(newScore); - tmp.add(w); + } } queries.clear(); @@ -425,6 +485,29 @@ } return allQueries; } + +/* + * for(SPARQL_Triple triple : t.getQuery().getTriplesWithVar(slot.getAnchor())){System.out.println(triple); + for(SPARQL_Triple typeTriple : t.getQuery().getRDFTypeTriples(triple.getVariable().getName())){ + System.out.println(typeTriple); + for(Allocation a : allocations){ + Set<String> domains = getDomains(a.getUri()); + System.out.println(a); + System.out.println(domains); + for(Slot s : classSlots){ + if(s.getAnchor().equals(triple.getVariable().getName())){ + for(Allocation all : slot2Allocations.get(s)){ + if(!domains.contains(all.getUri())){ + System.out.println("DROP " + a); + } + } + } + } + } + + + } + */ private Set<Allocation> computeAllocation(Slot slot){ Set<Allocation> allocations = new HashSet<Allocation>(); @@ -829,10 +912,21 @@ logger.info("Testing query:\n" + query); List<String> results = getResultFromRemoteEndpoint(query); if(!results.isEmpty()){ - learnedSPARQLQueries.put(query, results); - if(stopIfQueryResultNotEmpty){ - return; + try{ + int cnt = Integer.parseInt(results.get(0)); + if(cnt > 0){ + learnedSPARQLQueries.put(query, results); + if(stopIfQueryResultNotEmpty){ + return; + } + } + } catch (NumberFormatException e){ + learnedSPARQLQueries.put(query, results); + if(stopIfQueryResultNotEmpty){ + return; + } } + } logger.info("Result: " + results); } @@ -884,7 +978,7 @@ logger.info("Done in " + mon.getLastValue() + "ms."); } - private List<String> getResultFromRemoteEndpoint(String query){System.out.println(query); + private List<String> getResultFromRemoteEndpoint(String query){ List<String> resources = new ArrayList<String>(); try { String queryString = query; @@ -897,7 +991,12 @@ while(rs.hasNext()){ qs = rs.next(); projectionVar = qs.varNames().next(); - resources.add(qs.get(projectionVar).toString()); + if(qs.get(projectionVar).isLiteral()){ + resources.add(qs.get(projectionVar).asLiteral().getLexicalForm()); + } else if(qs.get(projectionVar).isURIResource()){ + resources.add(qs.get(projectionVar).asResource().getURI()); + } + } } catch (Exception e) { logger.error("Query execution failed.", e); @@ -917,8 +1016,42 @@ return resources; } + private Set<String> getDomains(String property){ + Set<String> domains = new HashSet<String>(); + String query = String.format("SELECT ?domain WHERE {<%s> <%s> ?domain}", property, RDFS.domain.getURI()); + ResultSet rs = SparqlQuery.convertJSONtoResultSet(cache.executeSelectQuery(endpoint, query)); + QuerySolution qs; + while(rs.hasNext()){ + qs = rs.next(); + domains.add(qs.getResource("domain").getURI()); + } + + return domains; + } + private Set<String> getRanges(String property){ + Set<String> domains = new HashSet<String>(); + String query = String.format("SELECT ?range WHERE {<%s> <%s> ?range}", property, RDFS.range.getURI()); + ResultSet rs = SparqlQuery.convertJSONtoResultSet(cache.executeSelectQuery(endpoint, query)); + QuerySolution qs; + while(rs.hasNext()){ + qs = rs.next(); + domains.add(qs.getResource("range").getURI()); + } + + return domains; + } + private Set<String> getSuperClasses(String cls){ + Set<String> superClasses = new HashSet<String>(); + for(Description d : reasoner.getClassHierarchy().getSuperClasses(new NamedClass(cls))){ + superClasses.add(((NamedClass)d).getName()); + } + return superClasses; + } + + + /** * @param args @@ -932,7 +1065,9 @@ // Logger.getLogger(HttpClient.class).setLevel(Level.OFF); // Logger.getLogger(HttpMethodBase.class).setLevel(Level.OFF); // String question = "In which programming language is GIMP written?"; - String question = "Who/WP are/VBP the/DT presidents/NNS of/IN the/DT United/NNP States/NNPS"; +// String question = "Who/WP was/VBD the/DT wife/NN of/IN president/NN Lincoln/NNP"; + String question = "Who/WP produced/VBD the/DT most/JJS films/NNS"; +// String question = "Give/VB me/PRP all/DT soccer/NN clubs/NNS in/IN the/DT Premier/NNP League/NNP"; // String question = "Give me all books written by authors influenced by Ernest Hemingway."; SPARQLTemplateBasedLearner learner = new SPARQLTemplateBasedLearner();learner.setUseIdealTagger(true); Modified: trunk/components-ext/src/main/java/org/dllearner/algorithm/tbsl/sparql/Query.java =================================================================== --- trunk/components-ext/src/main/java/org/dllearner/algorithm/tbsl/sparql/Query.java 2011-11-02 14:07:30 UTC (rev 3361) +++ trunk/components-ext/src/main/java/org/dllearner/algorithm/tbsl/sparql/Query.java 2011-11-02 14:37:35 UTC (rev 3362) @@ -423,6 +423,28 @@ } return triples; } + + public List<SPARQL_Triple> getRDFTypeTriples(){ + List<SPARQL_Triple> triples = new ArrayList<SPARQL_Triple>(); + + for(SPARQL_Triple triple : conditions){ + if(triple.getProperty().equals("rdf:type")){ + triples.add(triple); + } + } + return triples; + } + + public List<SPARQL_Triple> getRDFTypeTriples(String var){ + List<SPARQL_Triple> triples = new ArrayList<SPARQL_Triple>(); + + for(SPARQL_Triple triple : conditions){ + if(triple.getProperty().toString().equals("rdf:type") && triple.getVariable().getName().equals(var)){ + triples.add(triple); + } + } + return triples; + } @Override public int hashCode() { Modified: trunk/components-ext/src/main/java/org/dllearner/algorithm/tbsl/sparql/WeightedQuery.java =================================================================== --- trunk/components-ext/src/main/java/org/dllearner/algorithm/tbsl/sparql/WeightedQuery.java 2011-11-02 14:07:30 UTC (rev 3361) +++ trunk/components-ext/src/main/java/org/dllearner/algorithm/tbsl/sparql/WeightedQuery.java 2011-11-02 14:37:35 UTC (rev 3362) @@ -33,7 +33,16 @@ return -1; } else if(o.getScore() > this.score){ return 1; - } else return query.toString().compareTo(o.getQuery().toString()); + } else { + int filter = Boolean.valueOf(query.getFilters().isEmpty()).compareTo(Boolean.valueOf(o.getQuery().getFilters().isEmpty())); + if(filter == 0){ + return query.toString().compareTo(o.getQuery().toString()); + } else { + return filter; + } + } + + } @Override This was sent by the SourceForge.net collaborative development platform, the world's largest Open Source development site. |