|
From: <fab...@us...> - 2010-12-18 15:59:08
|
Revision: 5325
http://nhibernate.svn.sourceforge.net/nhibernate/?rev=5325&view=rev
Author: fabiomaulo
Date: 2010-12-18 15:59:02 +0000 (Sat, 18 Dec 2010)
Log Message:
-----------
Fix for NH-2375 and NH-2381 (thanks to Dean Ward)
Modified Paths:
--------------
trunk/nhibernate/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs
trunk/nhibernate/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessOfType.cs
trunk/nhibernate/src/NHibernate/NHibernate.csproj
trunk/nhibernate/src/NHibernate.Test/Linq/EagerLoadTests.cs
trunk/nhibernate/src/NHibernate.Test/Linq/WhereTests.cs
Added Paths:
-----------
trunk/nhibernate/src/NHibernate/Linq/ReWriters/ResultOperatorRewriter.cs
trunk/nhibernate/src/NHibernate/Linq/ReWriters/ResultOperatorRewriterResult.cs
Added: trunk/nhibernate/src/NHibernate/Linq/ReWriters/ResultOperatorRewriter.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/ReWriters/ResultOperatorRewriter.cs (rev 0)
+++ trunk/nhibernate/src/NHibernate/Linq/ReWriters/ResultOperatorRewriter.cs 2010-12-18 15:59:02 UTC (rev 5325)
@@ -0,0 +1,112 @@
+namespace NHibernate.Linq.ReWriters
+{
+ using System.Collections.Generic;
+ using System.Linq;
+ using System.Linq.Expressions;
+
+ using NHibernate.Linq.Visitors;
+
+ using Remotion.Data.Linq;
+ using Remotion.Data.Linq.Clauses;
+ using Remotion.Data.Linq.Clauses.Expressions;
+ using Remotion.Data.Linq.Clauses.ResultOperators;
+ using Remotion.Data.Linq.Clauses.StreamedData;
+ using Remotion.Data.Linq.EagerFetching;
+
+ /// <summary>
+ /// Removes various result operators from a query so that they can be processed at the same
+ /// tree level as the query itself.
+ /// </summary>
+ public class ResultOperatorRewriter : QueryModelVisitorBase
+ {
+ private readonly List<ResultOperatorBase> resultOperators = new List<ResultOperatorBase>();
+ private IStreamedDataInfo evaluationType;
+
+ private ResultOperatorRewriter()
+ {
+ }
+
+ public static ResultOperatorRewriterResult Rewrite(QueryModel queryModel)
+ {
+ ResultOperatorRewriter rewriter = new ResultOperatorRewriter();
+
+ rewriter.VisitQueryModel(queryModel);
+
+ return new ResultOperatorRewriterResult(rewriter.resultOperators, rewriter.evaluationType);
+ }
+
+ public override void VisitMainFromClause(MainFromClause fromClause, QueryModel queryModel)
+ {
+ base.VisitMainFromClause(fromClause, queryModel);
+
+ ResultOperatorExpressionRewriter rewriter = new ResultOperatorExpressionRewriter();
+ fromClause.TransformExpressions(rewriter.Rewrite);
+ if (fromClause.FromExpression.NodeType == ExpressionType.Constant)
+ {
+ System.Type expressionType = queryModel.MainFromClause.FromExpression.Type;
+ if (expressionType.IsGenericType && expressionType.GetGenericTypeDefinition() == typeof(NhQueryable<>))
+ {
+ queryModel.MainFromClause.ItemType = expressionType.GetGenericArguments()[0];
+ }
+ }
+
+ this.resultOperators.AddRange(rewriter.ResultOperators);
+ this.evaluationType = rewriter.EvaluationType;
+ }
+
+ /// <summary>
+ /// Rewrites expressions so that they sit in the outermost portion of the query.
+ /// </summary>
+ private class ResultOperatorExpressionRewriter : NhExpressionTreeVisitor
+ {
+ private static readonly System.Type[] rewrittenTypes = new[]
+ {
+ typeof(FetchRequestBase),
+ typeof(OfTypeResultOperator),
+ };
+
+ private readonly List<ResultOperatorBase> resultOperators = new List<ResultOperatorBase>();
+ private IStreamedDataInfo evaluationType;
+
+ /// <summary>
+ /// Gets an <see cref="IEnumerable{T}" /> of <see cref="ResultOperatorBase" /> that were rewritten.
+ /// </summary>
+ public IEnumerable<ResultOperatorBase> ResultOperators
+ {
+ get { return this.resultOperators; }
+ }
+
+ /// <summary>
+ /// Gets the <see cref="IStreamedDataInfo" /> representing the type of data that the operator works upon.
+ /// </summary>
+ public IStreamedDataInfo EvaluationType
+ {
+ get { return this.evaluationType; }
+ }
+
+ public Expression Rewrite(Expression expression)
+ {
+ Expression rewrittenExpression = this.VisitExpression(expression);
+
+ return rewrittenExpression;
+ }
+
+ protected override Expression VisitSubQueryExpression(SubQueryExpression expression)
+ {
+ this.resultOperators.AddRange(
+ expression.QueryModel.ResultOperators
+ .Where(r => rewrittenTypes.Any(t => t.IsAssignableFrom(r.GetType()))));
+
+ this.resultOperators.ForEach(f => expression.QueryModel.ResultOperators.Remove(f));
+ this.evaluationType = expression.QueryModel.SelectClause.GetOutputDataInfo();
+
+ if (expression.QueryModel.ResultOperators.Count == 0)
+ {
+ return expression.QueryModel.MainFromClause.FromExpression;
+ }
+
+ return base.VisitSubQueryExpression(expression);
+ }
+ }
+ }
+}
Added: trunk/nhibernate/src/NHibernate/Linq/ReWriters/ResultOperatorRewriterResult.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/ReWriters/ResultOperatorRewriterResult.cs (rev 0)
+++ trunk/nhibernate/src/NHibernate/Linq/ReWriters/ResultOperatorRewriterResult.cs 2010-12-18 15:59:02 UTC (rev 5325)
@@ -0,0 +1,30 @@
+namespace NHibernate.Linq.ReWriters
+{
+ using System.Collections.Generic;
+
+ using Remotion.Data.Linq.Clauses;
+ using Remotion.Data.Linq.Clauses.StreamedData;
+
+ /// <summary>
+ /// Result of <see cref="ResultOperatorRewriter.Rewrite" />.
+ /// </summary>
+ public class ResultOperatorRewriterResult
+ {
+ public ResultOperatorRewriterResult(IEnumerable<ResultOperatorBase> rewrittenOperators, IStreamedDataInfo evaluationType)
+ {
+ this.RewrittenOperators = rewrittenOperators;
+ this.EvaluationType = evaluationType;
+ }
+
+ /// <summary>
+ /// Gets an <see cref="IEnumerable{T}" /> of <see cref="ResultOperatorBase" /> implementations that were
+ /// rewritten.
+ /// </summary>
+ public IEnumerable<ResultOperatorBase> RewrittenOperators { get; private set; }
+
+ /// <summary>
+ /// Gets the <see cref="IStreamedDataInfo" /> representing the type of data that the operator works upon.
+ /// </summary>
+ public IStreamedDataInfo EvaluationType { get; private set; }
+ }
+}
Modified: trunk/nhibernate/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs 2010-12-18 15:40:50 UTC (rev 5324)
+++ trunk/nhibernate/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs 2010-12-18 15:59:02 UTC (rev 5325)
@@ -45,10 +45,18 @@
// Add left joins for references
AddLeftJoinsReWriter.ReWrite(queryModel, parameters.SessionFactory);
- var visitor = new QueryModelVisitor(parameters, root, queryModel);
- visitor.Visit();
+ // rewrite any operators that should be applied on the outer query
+ // by flattening out the sub-queries that they are located in
+ ResultOperatorRewriterResult result = ResultOperatorRewriter.Rewrite(queryModel);
- return visitor._hqlTree.GetTranslation();
+ QueryModelVisitor visitor = new QueryModelVisitor(parameters, root, queryModel)
+ {
+ RewrittenOperatorResult = result
+ };
+
+ visitor.Visit();
+
+ return visitor._hqlTree.GetTranslation();
}
private readonly IntermediateHqlTree _hqlTree;
@@ -59,6 +67,7 @@
public IStreamedDataInfo CurrentEvaluationType { get; private set; }
public IStreamedDataInfo PreviousEvaluationType { get; private set; }
public QueryModel Model { get; private set; }
+ public ResultOperatorRewriterResult RewrittenOperatorResult { get; private set; }
static QueryModelVisitor()
{
@@ -102,6 +111,16 @@
HqlGeneratorExpressionTreeVisitor.Visit(fromClause.FromExpression, VisitorParameters),
_hqlTree.TreeBuilder.Alias(fromClause.ItemName)));
+ // apply any result operators that were rewritten
+ if (RewrittenOperatorResult != null)
+ {
+ CurrentEvaluationType = RewrittenOperatorResult.EvaluationType;
+ foreach (ResultOperatorBase rewrittenOperator in RewrittenOperatorResult.RewrittenOperators)
+ {
+ this.VisitResultOperator(rewrittenOperator, queryModel, -1);
+ }
+ }
+
base.VisitMainFromClause(fromClause, queryModel);
}
Modified: trunk/nhibernate/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessOfType.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessOfType.cs 2010-12-18 15:40:50 UTC (rev 5324)
+++ trunk/nhibernate/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessOfType.cs 2010-12-18 15:59:02 UTC (rev 5325)
@@ -1,7 +1,6 @@
using System.Linq.Expressions;
using NHibernate.Hql.Ast;
using Remotion.Data.Linq.Clauses.ResultOperators;
-using Remotion.Data.Linq.Clauses.StreamedData;
namespace NHibernate.Linq.Visitors.ResultOperatorProcessors
{
@@ -11,8 +10,7 @@
public void Process(OfTypeResultOperator resultOperator, QueryModelVisitor queryModelVisitor, IntermediateHqlTree tree)
{
- Expression source =
- queryModelVisitor.CurrentEvaluationType.As<StreamedSequenceInfo>().ItemExpression;
+ Expression source = queryModelVisitor.Model.SelectClause.GetOutputDataInfo().ItemExpression;
tree.AddWhereClause(tree.TreeBuilder.Equality(
tree.TreeBuilder.Dot(
Modified: trunk/nhibernate/src/NHibernate/NHibernate.csproj
===================================================================
--- trunk/nhibernate/src/NHibernate/NHibernate.csproj 2010-12-18 15:40:50 UTC (rev 5324)
+++ trunk/nhibernate/src/NHibernate/NHibernate.csproj 2010-12-18 15:59:02 UTC (rev 5325)
@@ -257,6 +257,8 @@
<Compile Include="ITransaction.cs" />
<Compile Include="LazyInitializationException.cs" />
<Compile Include="Linq\Functions\DictionaryGenerator.cs" />
+ <Compile Include="Linq\ReWriters\ResultOperatorRewriter.cs" />
+ <Compile Include="Linq\ReWriters\ResultOperatorRewriterResult.cs" />
<Compile Include="Loader\Loader.cs" />
<Compile Include="Loader\OuterJoinLoader.cs" />
<Compile Include="LockMode.cs" />
Modified: trunk/nhibernate/src/NHibernate.Test/Linq/EagerLoadTests.cs
===================================================================
--- trunk/nhibernate/src/NHibernate.Test/Linq/EagerLoadTests.cs 2010-12-18 15:40:50 UTC (rev 5324)
+++ trunk/nhibernate/src/NHibernate.Test/Linq/EagerLoadTests.cs 2010-12-18 15:59:02 UTC (rev 5325)
@@ -56,12 +56,32 @@
Assert.IsTrue(NHibernateUtil.IsInitialized(x[0].Orders.First().OrderLines));
}
- [Test]
- public void WhenFetchSuperclassCollectionThenNotThrows()
- {
- // NH-2277
- session.Executing(s => s.Query<Lizard>().Fetch(x => x.Children).ToList()).NotThrows();
- session.Close();
- }
+ [Test]
+ public void WhenFetchSuperclassCollectionThenNotThrows()
+ {
+ // NH-2277
+ session.Executing(s => s.Query<Lizard>().Fetch(x => x.Children).ToList()).NotThrows();
+ session.Close();
+ }
+
+ [Test]
+ public void FetchWithWhere()
+ {
+ // NH-2381
+ (from p
+ in session.Query<Product>().Fetch(a => a.Supplier)
+ where p.ProductId == 1
+ select p).ToList();
+ }
+
+ [Test]
+ public void FetchManyWithWhere()
+ {
+ // NH-2381
+ (from s
+ in session.Query<Supplier>().FetchMany(a => a.Products)
+ where s.SupplierId == 1
+ select s).ToList();
+ }
}
}
Modified: trunk/nhibernate/src/NHibernate.Test/Linq/WhereTests.cs
===================================================================
--- trunk/nhibernate/src/NHibernate.Test/Linq/WhereTests.cs 2010-12-18 15:40:50 UTC (rev 5324)
+++ trunk/nhibernate/src/NHibernate.Test/Linq/WhereTests.cs 2010-12-18 15:59:02 UTC (rev 5325)
@@ -491,5 +491,25 @@
Assert.AreEqual(3, query.Count);
}
+ [Test]
+ public void OfTypeWithWhereAndProjection()
+ {
+ // NH-2375
+ (from a
+ in session.Query<Animal>().OfType<Cat>()
+ where a.Pregnant
+ select a.Id).FirstOrDefault();
+ }
+
+ [Test]
+ public void OfTypeWithWhere()
+ {
+ // NH-2375
+ (from a
+ in session.Query<Animal>().OfType<Cat>()
+ where a.Pregnant
+ select a).FirstOrDefault();
+ }
+
}
}
This was sent by the SourceForge.net collaborative development platform, the world's largest Open Source development site.
|