From: <fab...@us...> - 2010-12-18 15:59:08
|
Revision: 5325 http://nhibernate.svn.sourceforge.net/nhibernate/?rev=5325&view=rev Author: fabiomaulo Date: 2010-12-18 15:59:02 +0000 (Sat, 18 Dec 2010) Log Message: ----------- Fix of NH-2375 and NH-2381 (thanks to Dean Ward) Modified Paths: -------------- trunk/nhibernate/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessOfType.cs trunk/nhibernate/src/NHibernate/NHibernate.csproj trunk/nhibernate/src/NHibernate.Test/Linq/EagerLoadTests.cs trunk/nhibernate/src/NHibernate.Test/Linq/WhereTests.cs Added Paths: ----------- trunk/nhibernate/src/NHibernate/Linq/ReWriters/ResultOperatorRewriter.cs trunk/nhibernate/src/NHibernate/Linq/ReWriters/ResultOperatorRewriterResult.cs Added: trunk/nhibernate/src/NHibernate/Linq/ReWriters/ResultOperatorRewriter.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/ReWriters/ResultOperatorRewriter.cs (rev 0) +++ trunk/nhibernate/src/NHibernate/Linq/ReWriters/ResultOperatorRewriter.cs 2010-12-18 15:59:02 UTC (rev 5325) @@ -0,0 +1,112 @@ +namespace NHibernate.Linq.ReWriters +{ + using System.Collections.Generic; + using System.Linq; + using System.Linq.Expressions; + + using NHibernate.Linq.Visitors; + + using Remotion.Data.Linq; + using Remotion.Data.Linq.Clauses; + using Remotion.Data.Linq.Clauses.Expressions; + using Remotion.Data.Linq.Clauses.ResultOperators; + using Remotion.Data.Linq.Clauses.StreamedData; + using Remotion.Data.Linq.EagerFetching; + + /// <summary> + /// Removes various result operators from a query so that they can be processed at the same + /// tree level as the query itself. + /// </summary> + public class ResultOperatorRewriter : QueryModelVisitorBase + { + private readonly List<ResultOperatorBase> resultOperators = new List<ResultOperatorBase>(); + private IStreamedDataInfo evaluationType; + + private ResultOperatorRewriter() + { + } + + public static ResultOperatorRewriterResult Rewrite(QueryModel queryModel) + { + ResultOperatorRewriter rewriter = new ResultOperatorRewriter(); + + rewriter.VisitQueryModel(queryModel); + + return new ResultOperatorRewriterResult(rewriter.resultOperators, rewriter.evaluationType); + } + + public override void VisitMainFromClause(MainFromClause fromClause, QueryModel queryModel) + { + base.VisitMainFromClause(fromClause, queryModel); + + ResultOperatorExpressionRewriter rewriter = new ResultOperatorExpressionRewriter(); + fromClause.TransformExpressions(rewriter.Rewrite); + if (fromClause.FromExpression.NodeType == ExpressionType.Constant) + { + System.Type expressionType = queryModel.MainFromClause.FromExpression.Type; + if (expressionType.IsGenericType && expressionType.GetGenericTypeDefinition() == typeof(NhQueryable<>)) + { + queryModel.MainFromClause.ItemType = expressionType.GetGenericArguments()[0]; + } + } + + this.resultOperators.AddRange(rewriter.ResultOperators); + this.evaluationType = rewriter.EvaluationType; + } + + /// <summary> + /// Rewrites expressions so that they sit in the outermost portion of the query. + /// </summary> + private class ResultOperatorExpressionRewriter : NhExpressionTreeVisitor + { + private static readonly System.Type[] rewrittenTypes = new[] + { + typeof(FetchRequestBase), + typeof(OfTypeResultOperator), + }; + + private readonly List<ResultOperatorBase> resultOperators = new List<ResultOperatorBase>(); + private IStreamedDataInfo evaluationType; + + /// <summary> + /// Gets an <see cref="IEnumerable{T}" /> of <see cref="ResultOperatorBase" /> that were rewritten. + /// </summary> + public IEnumerable<ResultOperatorBase> ResultOperators + { + get { return this.resultOperators; } + } + + /// <summary> + /// Gets the <see cref="IStreamedDataInfo" /> representing the type of data that the operator works upon. + /// </summary> + public IStreamedDataInfo EvaluationType + { + get { return this.evaluationType; } + } + + public Expression Rewrite(Expression expression) + { + Expression rewrittenExpression = this.VisitExpression(expression); + + return rewrittenExpression; + } + + protected override Expression VisitSubQueryExpression(SubQueryExpression expression) + { + this.resultOperators.AddRange( + expression.QueryModel.ResultOperators + .Where(r => rewrittenTypes.Any(t => t.IsAssignableFrom(r.GetType())))); + + this.resultOperators.ForEach(f => expression.QueryModel.ResultOperators.Remove(f)); + this.evaluationType = expression.QueryModel.SelectClause.GetOutputDataInfo(); + + if (expression.QueryModel.ResultOperators.Count == 0) + { + return expression.QueryModel.MainFromClause.FromExpression; + } + + return base.VisitSubQueryExpression(expression); + } + } + } +} Added: trunk/nhibernate/src/NHibernate/Linq/ReWriters/ResultOperatorRewriterResult.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/ReWriters/ResultOperatorRewriterResult.cs (rev 0) +++ trunk/nhibernate/src/NHibernate/Linq/ReWriters/ResultOperatorRewriterResult.cs 2010-12-18 15:59:02 UTC (rev 5325) @@ -0,0 +1,30 @@ +namespace NHibernate.Linq.ReWriters +{ + using System.Collections.Generic; + + using Remotion.Data.Linq.Clauses; + using Remotion.Data.Linq.Clauses.StreamedData; + + /// <summary> + /// Result of <see cref="ResultOperatorRewriter.Rewrite" />. + /// </summary> + public class ResultOperatorRewriterResult + { + public ResultOperatorRewriterResult(IEnumerable<ResultOperatorBase> rewrittenOperators, IStreamedDataInfo evaluationType) + { + this.RewrittenOperators = rewrittenOperators; + this.EvaluationType = evaluationType; + } + + /// <summary> + /// Gets an <see cref="IEnumerable{T}" /> of <see cref="ResultOperatorBase" /> implementations that were + /// rewritten. + /// </summary> + public IEnumerable<ResultOperatorBase> RewrittenOperators { get; private set; } + + /// <summary> + /// Gets the <see cref="IStreamedDataInfo" /> representing the type of data that the operator works upon. + /// </summary> + public IStreamedDataInfo EvaluationType { get; private set; } + } +} Modified: trunk/nhibernate/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs 2010-12-18 15:40:50 UTC (rev 5324) +++ trunk/nhibernate/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs 2010-12-18 15:59:02 UTC (rev 5325) @@ -45,10 +45,18 @@ // Add left joins for references AddLeftJoinsReWriter.ReWrite(queryModel, parameters.SessionFactory); - var visitor = new QueryModelVisitor(parameters, root, queryModel); - visitor.Visit(); + // rewrite any operators that should be applied on the outer query + // by flattening out the sub-queries that they are located in + ResultOperatorRewriterResult result = ResultOperatorRewriter.Rewrite(queryModel); - return visitor._hqlTree.GetTranslation(); + QueryModelVisitor visitor = new QueryModelVisitor(parameters, root, queryModel) + { + RewrittenOperatorResult = result + }; + + visitor.Visit(); + + return visitor._hqlTree.GetTranslation(); } private readonly IntermediateHqlTree _hqlTree; @@ -59,6 +67,7 @@ public IStreamedDataInfo CurrentEvaluationType { get; private set; } public IStreamedDataInfo PreviousEvaluationType { get; private set; } public QueryModel Model { get; private set; } + public ResultOperatorRewriterResult RewrittenOperatorResult { get; private set; } static QueryModelVisitor() { @@ -102,6 +111,16 @@ HqlGeneratorExpressionTreeVisitor.Visit(fromClause.FromExpression, VisitorParameters), _hqlTree.TreeBuilder.Alias(fromClause.ItemName))); + // apply any result operators that were rewritten + if (RewrittenOperatorResult != null) + { + CurrentEvaluationType = RewrittenOperatorResult.EvaluationType; + foreach (ResultOperatorBase rewrittenOperator in RewrittenOperatorResult.RewrittenOperators) + { + this.VisitResultOperator(rewrittenOperator, queryModel, -1); + } + } + base.VisitMainFromClause(fromClause, queryModel); } Modified: trunk/nhibernate/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessOfType.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessOfType.cs 2010-12-18 15:40:50 UTC (rev 5324) +++ trunk/nhibernate/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessOfType.cs 2010-12-18 15:59:02 UTC (rev 5325) @@ -1,7 +1,6 @@ using System.Linq.Expressions; using NHibernate.Hql.Ast; using Remotion.Data.Linq.Clauses.ResultOperators; -using Remotion.Data.Linq.Clauses.StreamedData; namespace NHibernate.Linq.Visitors.ResultOperatorProcessors { @@ -11,8 +10,7 @@ public void Process(OfTypeResultOperator resultOperator, QueryModelVisitor queryModelVisitor, IntermediateHqlTree tree) { - Expression source = - queryModelVisitor.CurrentEvaluationType.As<StreamedSequenceInfo>().ItemExpression; + Expression source = queryModelVisitor.Model.SelectClause.GetOutputDataInfo().ItemExpression; tree.AddWhereClause(tree.TreeBuilder.Equality( tree.TreeBuilder.Dot( Modified: trunk/nhibernate/src/NHibernate/NHibernate.csproj =================================================================== --- trunk/nhibernate/src/NHibernate/NHibernate.csproj 2010-12-18 15:40:50 UTC (rev 5324) +++ trunk/nhibernate/src/NHibernate/NHibernate.csproj 2010-12-18 15:59:02 UTC (rev 5325) @@ -257,6 +257,8 @@ <Compile Include="ITransaction.cs" /> <Compile Include="LazyInitializationException.cs" /> <Compile Include="Linq\Functions\DictionaryGenerator.cs" /> + <Compile Include="Linq\ReWriters\ResultOperatorRewriter.cs" /> + <Compile Include="Linq\ReWriters\ResultOperatorRewriterResult.cs" /> <Compile Include="Loader\Loader.cs" /> <Compile Include="Loader\OuterJoinLoader.cs" /> <Compile Include="LockMode.cs" /> Modified: trunk/nhibernate/src/NHibernate.Test/Linq/EagerLoadTests.cs =================================================================== --- trunk/nhibernate/src/NHibernate.Test/Linq/EagerLoadTests.cs 2010-12-18 15:40:50 UTC (rev 5324) +++ trunk/nhibernate/src/NHibernate.Test/Linq/EagerLoadTests.cs 2010-12-18 15:59:02 UTC (rev 5325) @@ -56,12 +56,32 @@ Assert.IsTrue(NHibernateUtil.IsInitialized(x[0].Orders.First().OrderLines)); } - [Test] - public void WhenFetchSuperclassCollectionThenNotThrows() - { - // NH-2277 - session.Executing(s => s.Query<Lizard>().Fetch(x => x.Children).ToList()).NotThrows(); - session.Close(); - } + [Test] + public void WhenFetchSuperclassCollectionThenNotThrows() + { + // NH-2277 + session.Executing(s => s.Query<Lizard>().Fetch(x => x.Children).ToList()).NotThrows(); + session.Close(); + } + + [Test] + public void FetchWithWhere() + { + // NH-2381 + (from p + in session.Query<Product>().Fetch(a => a.Supplier) + where p.ProductId == 1 + select p).ToList(); + } + + [Test] + public void FetchManyWithWhere() + { + // NH-2381 + (from s + in session.Query<Supplier>().FetchMany(a => a.Products) + where s.SupplierId == 1 + select s).ToList(); + } } } Modified: trunk/nhibernate/src/NHibernate.Test/Linq/WhereTests.cs =================================================================== --- trunk/nhibernate/src/NHibernate.Test/Linq/WhereTests.cs 2010-12-18 15:40:50 UTC (rev 5324) +++ trunk/nhibernate/src/NHibernate.Test/Linq/WhereTests.cs 2010-12-18 15:59:02 UTC (rev 5325) @@ -491,5 +491,25 @@ Assert.AreEqual(3, query.Count); } + [Test] + public void OfTypeWithWhereAndProjection() + { + // NH-2375 + (from a + in session.Query<Animal>().OfType<Cat>() + where a.Pregnant + select a.Id).FirstOrDefault(); + } + + [Test] + public void OfTypeWithWhere() + { + // NH-2375 + (from a + in session.Query<Animal>().OfType<Cat>() + where a.Pregnant + select a).FirstOrDefault(); + } + } } This was sent by the SourceForge.net collaborative development platform, the world's largest Open Source development site. |