|
From: <ste...@us...> - 2009-09-25 21:05:16
|
Revision: 4725
http://nhibernate.svn.sourceforge.net/nhibernate/?rev=4725&view=rev
Author: steverstrong
Date: 2009-09-25 21:04:56 +0000 (Fri, 25 Sep 2009)
Log Message:
-----------
Further updates to the Linq provider
Modified Paths:
--------------
trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs
trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeNode.cs
trunk/nhibernate/src/NHibernate/Linq/AggregatingGroupByRewriter.cs
trunk/nhibernate/src/NHibernate/Linq/HqlGeneratorExpressionTreeVisitor.cs
trunk/nhibernate/src/NHibernate/Linq/NhExpressionTreeVisitor.cs
trunk/nhibernate/src/NHibernate/Linq/NhThrowingExpressionTreeVisitor.cs
trunk/nhibernate/src/NHibernate/Linq/NonAggregatingGroupByRewriter.cs
trunk/nhibernate/src/NHibernate/Linq/QueryModelVisitor.cs
trunk/nhibernate/src/NHibernate/NHibernate.csproj
trunk/nhibernate/src/NHibernate.Test/Linq/LinqQuerySamples.cs
trunk/nhibernate/src/NHibernate.Test/Linq/ReadonlyTestCase.cs
Added Paths:
-----------
trunk/nhibernate/src/NHibernate/Linq/AggregatingGroupJoinRewriter.cs
trunk/nhibernate/src/NHibernate/Linq/GroupByAggregateDetectionVisitor.cs
trunk/nhibernate/src/NHibernate/Linq/GroupByKeySelectorVisitor.cs
trunk/nhibernate/src/NHibernate/Linq/GroupBySelectClauseRewriter.cs
trunk/nhibernate/src/NHibernate/Linq/MergeAggregatingResultsRewriter.cs
trunk/nhibernate/src/NHibernate/Linq/NhNewExpression.cs
Removed Paths:
-------------
trunk/nhibernate/src/NHibernate/Linq/AggregateDetectionVisitor.cs
trunk/nhibernate/src/NHibernate/Linq/GroupBySelectClauseVisitor.cs
trunk/nhibernate/src/NHibernate/Linq/GroupBySelectorVisitor.cs
Modified: trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs 2009-09-24 11:22:59 UTC (rev 4724)
+++ trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs 2009-09-25 21:04:56 UTC (rev 4725)
@@ -100,6 +100,11 @@
return new HqlEquality(_factory);
}
+ public HqlEquality Equality(HqlTreeNode lhs, HqlTreeNode rhs)
+ {
+ return new HqlEquality(_factory, lhs, rhs);
+ }
+
public HqlBooleanAnd BooleanAnd()
{
return new HqlBooleanAnd(_factory);
Modified: trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeNode.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeNode.cs 2009-09-24 11:22:59 UTC (rev 4724)
+++ trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeNode.cs 2009-09-25 21:04:56 UTC (rev 4725)
@@ -229,6 +229,11 @@
: base(HqlSqlWalker.EQ, "==", factory)
{
}
+
+ public HqlEquality(IASTFactory factory, HqlTreeNode lhs, HqlTreeNode rhs)
+ : base(HqlSqlWalker.EQ, "==", factory, lhs, rhs)
+ {
+ }
}
public class HqlParameter : HqlTreeNode
Deleted: trunk/nhibernate/src/NHibernate/Linq/AggregateDetectionVisitor.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/AggregateDetectionVisitor.cs 2009-09-24 11:22:59 UTC (rev 4724)
+++ trunk/nhibernate/src/NHibernate/Linq/AggregateDetectionVisitor.cs 2009-09-25 21:04:56 UTC (rev 4725)
@@ -1,54 +0,0 @@
-using System.Linq;
-using System.Linq.Expressions;
-using Remotion.Data.Linq.Clauses.Expressions;
-using Remotion.Data.Linq.Clauses.ResultOperators;
-using Remotion.Data.Linq.Parsing;
-
-namespace NHibernate.Linq
-{
- // TODO: This needs strengthening. For example, it doesn't recurse into SubQueries at present
- internal class AggregateDetectionVisitor : ExpressionTreeVisitor
- {
- public bool ContainsAggregateMethods { get; private set; }
-
- public bool Visit(Expression expression)
- {
- ContainsAggregateMethods = false;
-
- VisitExpression(expression);
-
- return ContainsAggregateMethods;
- }
-
- protected override Expression VisitMethodCallExpression(MethodCallExpression m)
- {
- if (m.Method.DeclaringType == typeof (Queryable) ||
- m.Method.DeclaringType == typeof (Enumerable))
- {
- switch (m.Method.Name)
- {
- case "Count":
- case "Min":
- case "Max":
- case "Sum":
- case "Average":
- ContainsAggregateMethods = true;
- break;
- }
- }
-
- return base.VisitMethodCallExpression(m);
- }
-
- protected override Expression VisitSubQueryExpression(SubQueryExpression expression)
- {
- if (expression.QueryModel.ResultOperators.Count == 1
- && typeof(ValueFromSequenceResultOperatorBase).IsAssignableFrom(expression.QueryModel.ResultOperators[0].GetType()))
- {
- ContainsAggregateMethods = true;
- }
-
- return base.VisitSubQueryExpression(expression);
- }
- }
-}
\ No newline at end of file
Modified: trunk/nhibernate/src/NHibernate/Linq/AggregatingGroupByRewriter.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/AggregatingGroupByRewriter.cs 2009-09-24 11:22:59 UTC (rev 4724)
+++ trunk/nhibernate/src/NHibernate/Linq/AggregatingGroupByRewriter.cs 2009-09-25 21:04:56 UTC (rev 4725)
@@ -7,44 +7,55 @@
namespace NHibernate.Linq
{
- public class AggregatingGroupByRewriter : QueryModelVisitorBase
+ public class AggregatingGroupByRewriter
{
- public override void VisitMainFromClause(MainFromClause fromClause, QueryModel queryModel)
+ public void ReWrite(QueryModel queryModel)
{
- var subQueryExpression = fromClause.FromExpression as SubQueryExpression;
+ var subQueryExpression = queryModel.MainFromClause.FromExpression as SubQueryExpression;
if ((subQueryExpression != null) &&
(subQueryExpression.QueryModel.ResultOperators.Count() == 1) &&
(subQueryExpression.QueryModel.ResultOperators[0] is GroupResultOperator) &&
(IsAggregatingGroupBy(queryModel)))
{
- FlattenSubQuery(subQueryExpression, fromClause, queryModel);
+ FlattenSubQuery(subQueryExpression, queryModel.MainFromClause, queryModel);
}
-
- base.VisitMainFromClause(fromClause, queryModel);
}
private static bool IsAggregatingGroupBy(QueryModel queryModel)
{
- return new AggregateDetectionVisitor().Visit(queryModel.SelectClause.Selector);
+ return new GroupByAggregateDetectionVisitor().Visit(queryModel.SelectClause.Selector);
}
private void FlattenSubQuery(SubQueryExpression subQueryExpression, FromClauseBase fromClause,
QueryModel queryModel)
{
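+ // The group-by result operator is moved up to the outer query at the end; combining it with existing outer result operators is not supported yet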
+ if (queryModel.ResultOperators.Count != 0)
+ {
+ throw new NotImplementedException();
+ }
+
+ var groupBy = (GroupResultOperator) subQueryExpression.QueryModel.ResultOperators[0];
+
// Replace the outer select clause...
- queryModel.SelectClause.TransformExpressions(GroupBySelectClauseVisitor.Visit);
+ queryModel.SelectClause.TransformExpressions(s => GroupBySelectClauseRewriter.ReWrite(s, groupBy, subQueryExpression.QueryModel));
+ queryModel.SelectClause.TransformExpressions(
+ s =>
+ new SwapQuerySourceVisitor(queryModel.MainFromClause, subQueryExpression.QueryModel.MainFromClause).Swap
+ (s));
+
+
MainFromClause innerMainFromClause = subQueryExpression.QueryModel.MainFromClause;
CopyFromClauseData(innerMainFromClause, fromClause);
- // Move the result operator up
- if (queryModel.ResultOperators.Count != 0)
+ foreach (var bodyClause in subQueryExpression.QueryModel.BodyClauses)
{
- throw new NotImplementedException();
+ queryModel.BodyClauses.Add(bodyClause);
}
- queryModel.ResultOperators.Add(subQueryExpression.QueryModel.ResultOperators[0]);
+ queryModel.ResultOperators.Add(groupBy);
}
protected void CopyFromClauseData(FromClauseBase source, FromClauseBase destination)
Added: trunk/nhibernate/src/NHibernate/Linq/AggregatingGroupJoinRewriter.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/AggregatingGroupJoinRewriter.cs (rev 0)
+++ trunk/nhibernate/src/NHibernate/Linq/AggregatingGroupJoinRewriter.cs 2009-09-25 21:04:56 UTC (rev 4725)
@@ -0,0 +1,321 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Linq.Expressions;
+using Remotion.Data.Linq;
+using Remotion.Data.Linq.Clauses;
+using Remotion.Data.Linq.Clauses.Expressions;
+
+namespace NHibernate.Linq
+{
+ public class AggregatingGroupJoinRewriter
+ {
+ public void ReWrite(QueryModel model)
+ {
+ // We want to take queries like this:
+
+ //var q =
+ // from c in db.Customers
+ // join o in db.Orders on c.CustomerId equals o.Customer.CustomerId into ords
+ // join e in db.Employees on c.Address.City equals e.Address.City into emps
+ // select new { c.ContactName, ords = ords.Count(), emps = emps.Count() };
+
+ // and turn them into this:
+
+ //var q =
+ // from c in db.Customers
+ // select new
+ // {
+ // c.ContactName,
+ // ords = (from o2 in db.Orders where o2.Customer.CustomerId == c.CustomerId select o2).Count(),
+ // emps = (from e2 in db.Employees where e2.Address.City == c.Address.City select e2).Count()
+ // };
+
+ // so spot a group join where every use of the grouping in the selector is an aggregate
+
+ // firstly, get the group join clauses
+ var groupJoin = model.BodyClauses.Where(bc => bc is GroupJoinClause).Cast<GroupJoinClause>();
+
+ if (groupJoin.Count() == 0)
+ {
+ // No group joins here.
+ return;
+ }
+
+ // Now walk the tree to decide which groupings are fully aggregated (and can hence be done in HQL)
+ var aggregateDetectorResults = IsAggregatingGroupJoin(model, groupJoin);
+
+ if (aggregateDetectorResults.AggregatingClauses.Count > 0)
+ {
+ // Re-write the select expression
+ model.SelectClause.TransformExpressions(s => GroupJoinSelectClauseRewriter.ReWrite(s, aggregateDetectorResults));
+
+ // Remove the aggregating group joins
+ foreach (GroupJoinClause aggregatingGroupJoin in aggregateDetectorResults.AggregatingClauses)
+ {
+ model.BodyClauses.Remove(aggregatingGroupJoin);
+ }
+ }
+ }
+
+ private static IsAggregatingResults IsAggregatingGroupJoin(QueryModel model, IEnumerable<GroupJoinClause> clause)
+ {
+ return new GroupJoinAggregateDetectionVisitor(clause).Visit(model.SelectClause.Selector);
+ }
+ }
+
+ public class GroupJoinSelectClauseRewriter : NhExpressionTreeVisitor
+ {
+ private readonly IsAggregatingResults _results;
+
+ public static Expression ReWrite(Expression expression, IsAggregatingResults results)
+ {
+ return new GroupJoinSelectClauseRewriter(results).VisitExpression(expression);
+ }
+
+ private GroupJoinSelectClauseRewriter(IsAggregatingResults results)
+ {
+ _results = results;
+ }
+
+ protected override Expression VisitSubQueryExpression(SubQueryExpression expression)
+ {
+ // If the sub-query's main (and only) from clause is one of our aggregating group joins, then swap it
+ GroupJoinClause groupJoin = LocateGroupJoinQuerySource(expression.QueryModel);
+
+ if (groupJoin != null)
+ {
+ Expression innerSelector = new SwapQuerySourceVisitor(groupJoin.JoinClause, expression.QueryModel.MainFromClause).
+ Swap(groupJoin.JoinClause.InnerKeySelector);
+
+ expression.QueryModel.MainFromClause.FromExpression = groupJoin.JoinClause.InnerSequence;
+
+
+ // TODO - this only works if the key selectors are not composite. Needs improvement...
+ expression.QueryModel.BodyClauses.Add(new WhereClause(Expression.Equal(innerSelector, groupJoin.JoinClause.OuterKeySelector)));
+ }
+
+ return expression;
+ }
+
+ private GroupJoinClause LocateGroupJoinQuerySource(QueryModel model)
+ {
+ if (model.BodyClauses.Count > 0)
+ {
+ return null;
+ }
+ return new LocateGroupJoinQuerySource(_results).Detect(model.MainFromClause.FromExpression);
+ }
+ }
+
+ public class SwapQuerySourceVisitor : NhExpressionTreeVisitor
+ {
+ private readonly IQuerySource _oldClause;
+ private readonly IQuerySource _newClause;
+
+ public SwapQuerySourceVisitor(IQuerySource oldClause, IQuerySource newClause)
+ {
+ _oldClause = oldClause;
+ _newClause = newClause;
+ }
+
+ public Expression Swap(Expression expression)
+ {
+ return VisitExpression(expression);
+ }
+
+ protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression)
+ {
+ if (expression.ReferencedQuerySource == _oldClause)
+ {
+ return new QuerySourceReferenceExpression(_newClause);
+ }
+
+ // TODO - really don't like this drill down approach. Feels fragile
+ var mainFromClause = expression.ReferencedQuerySource as MainFromClause;
+
+ if (mainFromClause != null)
+ {
+ mainFromClause.FromExpression = VisitExpression(mainFromClause.FromExpression);
+ }
+
+ return expression;
+ }
+
+ protected override Expression VisitSubQueryExpression(SubQueryExpression expression)
+ {
+ expression.QueryModel.TransformExpressions(VisitExpression);
+ return base.VisitSubQueryExpression(expression);
+ }
+ }
+
+ public class LocateGroupJoinQuerySource : NhExpressionTreeVisitor
+ {
+ private readonly IsAggregatingResults _results;
+ private GroupJoinClause _groupJoin;
+
+ public LocateGroupJoinQuerySource(IsAggregatingResults results)
+ {
+ _results = results;
+ }
+
+ public GroupJoinClause Detect(Expression expression)
+ {
+ VisitExpression(expression);
+ return _groupJoin;
+ }
+
+ protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression)
+ {
+ if (_results.AggregatingClauses.Contains(expression.ReferencedQuerySource as GroupJoinClause))
+ {
+ _groupJoin = expression.ReferencedQuerySource as GroupJoinClause;
+ }
+
+ return base.VisitQuerySourceReferenceExpression(expression);
+ }
+ }
+
+ public class IsAggregatingResults
+ {
+ public List<GroupJoinClause> NonAggregatingClauses { get; set; }
+ public List<GroupJoinClause> AggregatingClauses { get; set; }
+ public List<Expression> NonAggregatingExpressions { get; set; }
+ }
+
+ internal class GroupJoinAggregateDetectionVisitor : NhExpressionTreeVisitor
+ {
+ private readonly HashSet<GroupJoinClause> _groupJoinClauses;
+ private readonly StackFlag _inAggregate = new StackFlag();
+ private readonly StackFlag _parentExpressionProcessed = new StackFlag();
+
+ private readonly List<Expression> _nonAggregatingExpressions = new List<Expression>();
+ private readonly List<GroupJoinClause> _nonAggregatingGroupJoins = new List<GroupJoinClause>();
+ private readonly List<GroupJoinClause> _aggregatingGroupJoins = new List<GroupJoinClause>();
+
+ public GroupJoinAggregateDetectionVisitor(IEnumerable<GroupJoinClause> groupJoinClause)
+ {
+ _groupJoinClauses = new HashSet<GroupJoinClause>(groupJoinClause);
+ }
+
+ public IsAggregatingResults Visit(Expression expression)
+ {
+ VisitExpression(expression);
+
+ return new IsAggregatingResults { NonAggregatingClauses = _nonAggregatingGroupJoins, AggregatingClauses = _aggregatingGroupJoins, NonAggregatingExpressions = _nonAggregatingExpressions };
+ }
+
+ protected override Expression VisitSubQueryExpression(SubQueryExpression expression)
+ {
+ VisitExpression(expression.QueryModel.SelectClause.Selector);
+ return expression;
+ }
+
+ protected override Expression VisitNhAverage(NhAverageExpression expression)
+ {
+ using (_inAggregate.SetFlag())
+ {
+ return base.VisitNhAverage(expression);
+ }
+ }
+
+ protected override Expression VisitNhCount(NhCountExpression expression)
+ {
+ using (_inAggregate.SetFlag())
+ {
+ return base.VisitNhCount(expression);
+ }
+ }
+
+ protected override Expression VisitNhMax(NhMaxExpression expression)
+ {
+ using (_inAggregate.SetFlag())
+ {
+ return base.VisitNhMax(expression);
+ }
+ }
+
+ protected override Expression VisitNhMin(NhMinExpression expression)
+ {
+ using (_inAggregate.SetFlag())
+ {
+ return base.VisitNhMin(expression);
+ }
+ }
+
+ protected override Expression VisitNhSum(NhSumExpression expression)
+ {
+ using (_inAggregate.SetFlag())
+ {
+ return base.VisitNhSum(expression);
+ }
+ }
+
+ protected override Expression VisitMemberExpression(MemberExpression expression)
+ {
+ if (_inAggregate.FlagIsFalse && _parentExpressionProcessed.FlagIsFalse)
+ {
+ _nonAggregatingExpressions.Add(expression);
+ }
+
+ using (_parentExpressionProcessed.SetFlag())
+ {
+ return base.VisitMemberExpression(expression);
+ }
+ }
+
+ protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression)
+ {
+ var fromClause = (FromClauseBase) expression.ReferencedQuerySource;
+
+ if (fromClause.FromExpression is QuerySourceReferenceExpression)
+ {
+ var querySourceReference = (QuerySourceReferenceExpression) fromClause.FromExpression;
+
+ if (_groupJoinClauses.Contains(querySourceReference.ReferencedQuerySource as GroupJoinClause))
+ {
+ if (_inAggregate.FlagIsFalse)
+ {
+ _nonAggregatingGroupJoins.Add((GroupJoinClause) querySourceReference.ReferencedQuerySource);
+ }
+ else
+ {
+ _aggregatingGroupJoins.Add((GroupJoinClause) querySourceReference.ReferencedQuerySource);
+ }
+ }
+ }
+
+ return base.VisitQuerySourceReferenceExpression(expression);
+ }
+
+ internal class StackFlag
+ {
+ public bool FlagIsTrue { get; private set; }
+
+ public bool FlagIsFalse { get { return !FlagIsTrue; } }
+
+ public IDisposable SetFlag()
+ {
+ return new StackFlagDisposable(this);
+ }
+
+ internal class StackFlagDisposable : IDisposable
+ {
+ private readonly StackFlag _parent;
+ private readonly bool _old;
+
+ public StackFlagDisposable(StackFlag parent)
+ {
+ _parent = parent;
+ _old = parent.FlagIsTrue;
+ parent.FlagIsTrue = true;
+ }
+
+ public void Dispose()
+ {
+ _parent.FlagIsTrue = _old;
+ }
+ }
+ }
+ }
+}
\ No newline at end of file
Added: trunk/nhibernate/src/NHibernate/Linq/GroupByAggregateDetectionVisitor.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/GroupByAggregateDetectionVisitor.cs (rev 0)
+++ trunk/nhibernate/src/NHibernate/Linq/GroupByAggregateDetectionVisitor.cs 2009-09-25 21:04:56 UTC (rev 4725)
@@ -0,0 +1,83 @@
+using System.Linq;
+using System.Linq.Expressions;
+using Remotion.Data.Linq.Clauses.Expressions;
+using Remotion.Data.Linq.Clauses.ResultOperators;
+using Remotion.Data.Linq.Parsing;
+
+namespace NHibernate.Linq
+{
+ // TODO: This needs strengthening. It possibly has a lot in common with the GroupJoinAggregateDetectionVisitor class, which does many more checks
+ internal class GroupByAggregateDetectionVisitor : NhExpressionTreeVisitor
+ {
+ public bool ContainsAggregateMethods { get; private set; }
+
+ public bool Visit(Expression expression)
+ {
+ ContainsAggregateMethods = false;
+
+ VisitExpression(expression);
+
+ return ContainsAggregateMethods;
+ }
+
+ // TODO - this should not exist, since it should be handled either by re-linq or by the MergeAggregatingResultsRewriter
+ protected override Expression VisitMethodCallExpression(MethodCallExpression m)
+ {
+ if (m.Method.DeclaringType == typeof (Queryable) ||
+ m.Method.DeclaringType == typeof (Enumerable))
+ {
+ switch (m.Method.Name)
+ {
+ case "Count":
+ case "Min":
+ case "Max":
+ case "Sum":
+ case "Average":
+ ContainsAggregateMethods = true;
+ break;
+ }
+ }
+
+ return m;
+ }
+
+ // TODO - having a VisitNhAggregation method or something in the base class would remove this duplication...
+ protected override Expression VisitNhAverage(NhAverageExpression expression)
+ {
+ ContainsAggregateMethods = true;
+ return expression;
+ }
+
+ protected override Expression VisitNhCount(NhCountExpression expression)
+ {
+ ContainsAggregateMethods = true;
+ return expression;
+ }
+
+ protected override Expression VisitNhMax(NhMaxExpression expression)
+ {
+ ContainsAggregateMethods = true;
+ return expression;
+ }
+
+ protected override Expression VisitNhMin(NhMinExpression expression)
+ {
+ ContainsAggregateMethods = true;
+ return expression;
+ }
+
+ protected override Expression VisitNhSum(NhSumExpression expression)
+ {
+ ContainsAggregateMethods = true;
+ return expression;
+ }
+
+ protected override Expression VisitSubQueryExpression(SubQueryExpression expression)
+ {
+ ContainsAggregateMethods =
+ new GroupByAggregateDetectionVisitor().Visit(expression.QueryModel.SelectClause.Selector);
+
+ return expression;
+ }
+ }
+}
\ No newline at end of file
Added: trunk/nhibernate/src/NHibernate/Linq/GroupByKeySelectorVisitor.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/GroupByKeySelectorVisitor.cs (rev 0)
+++ trunk/nhibernate/src/NHibernate/Linq/GroupByKeySelectorVisitor.cs 2009-09-25 21:04:56 UTC (rev 4725)
@@ -0,0 +1,26 @@
+using System.Linq.Expressions;
+using Remotion.Data.Linq.Clauses.Expressions;
+using Remotion.Data.Linq.Parsing;
+
+namespace NHibernate.Linq
+{
+ internal class GroupByKeySelectorVisitor : ExpressionTreeVisitor
+ {
+ private readonly ParameterExpression _parameter;
+
+ public GroupByKeySelectorVisitor(ParameterExpression parameter)
+ {
+ _parameter = parameter;
+ }
+
+ public Expression Visit(Expression expression)
+ {
+ return VisitExpression(expression);
+ }
+
+ protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression)
+ {
+ return _parameter;
+ }
+ }
+}
\ No newline at end of file
Added: trunk/nhibernate/src/NHibernate/Linq/GroupBySelectClauseRewriter.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/GroupBySelectClauseRewriter.cs (rev 0)
+++ trunk/nhibernate/src/NHibernate/Linq/GroupBySelectClauseRewriter.cs 2009-09-25 21:04:56 UTC (rev 4725)
@@ -0,0 +1,202 @@
+using System;
+using System.Linq.Expressions;
+using Remotion.Data.Linq;
+using Remotion.Data.Linq.Clauses;
+using Remotion.Data.Linq.Clauses.Expressions;
+using Remotion.Data.Linq.Clauses.ResultOperators;
+
+namespace NHibernate.Linq
+{
+ internal class GroupBySelectClauseRewriter : NhExpressionTreeVisitor
+ {
+ public static Expression ReWrite(Expression expression, GroupResultOperator groupBy, QueryModel model)
+ {
+ var visitor = new GroupBySelectClauseRewriter(groupBy, model);
+ return visitor.VisitExpression(expression);
+ }
+
+ private readonly GroupResultOperator _groupBy;
+ private readonly QueryModel _model;
+
+ public GroupBySelectClauseRewriter(GroupResultOperator groupBy, QueryModel model)
+ {
+ _groupBy = groupBy;
+ _model = model;
+ }
+
+ protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression)
+ {
+ if (expression.ReferencedQuerySource == _groupBy)
+ {
+ return _groupBy.ElementSelector;
+ }
+
+ return base.VisitQuerySourceReferenceExpression(expression);
+ }
+
+ protected override Expression VisitMemberExpression(MemberExpression expression)
+ {
+ if (IsMemberOfModel(expression))
+ {
+ if (expression.Member.Name == "Key")
+ {
+ return _groupBy.KeySelector;
+ }
+ else
+ {
+ Expression elementSelector = _groupBy.ElementSelector;
+
+ if ((elementSelector is MemberExpression) || (elementSelector is QuerySourceReferenceExpression))
+ {
+ // If ElementSelector is a MemberExpression or QuerySourceReferenceExpression, visit the member expression normally
+ return base.VisitMemberExpression(expression);
+ }
+ else if (elementSelector is NewExpression)
+ {
+ // If ElementSelector is a NewExpression, search its members for one named "get_" + expression.Member.Name and return the matching constructor argument
+ // TODO - this wouldn't handle nested initialisers. Should do a tree walk to find the correct member
+ var nex = elementSelector as NewExpression;
+
+ int i = 0;
+ foreach (var member in nex.Members)
+ {
+ if (member.Name == "get_" + expression.Member.Name)
+ {
+ return nex.Arguments[i];
+ }
+ i++;
+ }
+
+ throw new NotImplementedException();
+ }
+ else
+ {
+ throw new NotImplementedException();
+ }
+ }
+ }
+ else
+ {
+ return base.VisitMemberExpression(expression);
+ }
+ }
+
+ // TODO - dislike this code intensely. Should probably be a tree-walk in its own right
+ private bool IsMemberOfModel(MemberExpression expression)
+ {
+ var querySourceRef = expression.Expression as QuerySourceReferenceExpression;
+
+ if (querySourceRef == null)
+ {
+ return false;
+ }
+
+ var fromClause = querySourceRef.ReferencedQuerySource as FromClauseBase;
+
+ if (fromClause == null)
+ {
+ return false;
+ }
+
+ var subQuery = fromClause.FromExpression as SubQueryExpression;
+
+ if (subQuery != null)
+ {
+ return subQuery.QueryModel == _model;
+ }
+
+ var referencedQuery = fromClause.FromExpression as QuerySourceReferenceExpression;
+
+ if (referencedQuery == null)
+ {
+ return false;
+ }
+
+ var querySource = referencedQuery.ReferencedQuerySource as FromClauseBase;
+
+ var subQuery2 = querySource.FromExpression as SubQueryExpression;
+
+ return (subQuery2.QueryModel == _model);
+ }
+
+ protected override Expression VisitSubQueryExpression(SubQueryExpression expression)
+ {
+ // TODO - is this safe? All we are extracting is the select clause from the sub-query. Assumes that everything
+ // else in the subquery has been removed. If there were two subqueries, one aggregating & one not, this may not be a
+ // valid assumption. Should probably be passed a list of aggregating subqueries that we are flattening so that we can check...
+ return GroupBySelectClauseRewriter.ReWrite(expression.QueryModel.SelectClause.Selector, _groupBy, _model);
+ }
+ }
+
+ public enum NhExpressionType
+ {
+ Average = 10000,
+ Min,
+ Max,
+ Sum,
+ Count,
+ Distinct,
+ New
+ }
+
+ public class NhAggregatedExpression : Expression
+ {
+ public Expression Expression { get; set; }
+
+ public NhAggregatedExpression(Expression expression, NhExpressionType type)
+ : base((ExpressionType)type, expression.Type)
+ {
+ Expression = expression;
+ }
+ }
+
+ public class NhAverageExpression : NhAggregatedExpression
+ {
+ public NhAverageExpression(Expression expression) : base(expression, NhExpressionType.Average)
+ {
+ }
+ }
+
+ public class NhMinExpression : NhAggregatedExpression
+ {
+ public NhMinExpression(Expression expression)
+ : base(expression, NhExpressionType.Min)
+ {
+ }
+ }
+
+ public class NhMaxExpression : NhAggregatedExpression
+ {
+ public NhMaxExpression(Expression expression)
+ : base(expression, NhExpressionType.Max)
+ {
+ }
+ }
+
+ public class NhSumExpression : NhAggregatedExpression
+ {
+ public NhSumExpression(Expression expression)
+ : base(expression, NhExpressionType.Sum)
+ {
+ }
+ }
+
+ public class NhDistinctExpression : NhAggregatedExpression
+ {
+ public NhDistinctExpression(Expression expression)
+ : base(expression, NhExpressionType.Distinct)
+ {
+ }
+ }
+
+ public class NhCountExpression : Expression
+ {
+ public NhCountExpression(Expression expression)
+ : base((ExpressionType)NhExpressionType.Count, typeof(int))
+ {
+ Expression = expression;
+ }
+
+ public Expression Expression { get; private set; }
+ }
+}
Deleted: trunk/nhibernate/src/NHibernate/Linq/GroupBySelectClauseVisitor.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/GroupBySelectClauseVisitor.cs 2009-09-24 11:22:59 UTC (rev 4724)
+++ trunk/nhibernate/src/NHibernate/Linq/GroupBySelectClauseVisitor.cs 2009-09-25 21:04:56 UTC (rev 4725)
@@ -1,148 +0,0 @@
-using System;
-using System.Collections.Generic;
-using System.Linq;
-using System.Linq.Expressions;
-using Remotion.Data.Linq.Clauses;
-using Remotion.Data.Linq.Clauses.Expressions;
-using Remotion.Data.Linq.Clauses.ResultOperators;
-using Remotion.Data.Linq.Parsing;
-
-namespace NHibernate.Linq
-{
- internal class GroupBySelectClauseVisitor : ExpressionTreeVisitor
- {
- public static Expression Visit(Expression expression)
- {
- var visitor = new GroupBySelectClauseVisitor();
- return visitor.VisitExpression(expression);
- }
-
- protected override Expression VisitMemberExpression(MemberExpression expression)
- {
- if (expression.Member.Name == "Key" &&
- expression.Member.DeclaringType.GetGenericTypeDefinition() == typeof (IGrouping<,>))
- {
- var querySourceRef = expression.Expression as QuerySourceReferenceExpression;
-
- var fromClause = querySourceRef.ReferencedQuerySource as FromClauseBase;
-
- var subQuery = fromClause.FromExpression as SubQueryExpression;
-
- var groupBy =
- subQuery.QueryModel.ResultOperators.Where(r => r is GroupResultOperator).Single() as
- GroupResultOperator;
-
- return groupBy.KeySelector;
- }
- else
- {
- return base.VisitMemberExpression(expression);
- }
- }
-
- protected override Expression VisitSubQueryExpression(SubQueryExpression expression)
- {
- if (expression.QueryModel.ResultOperators.Count == 1)
- {
- ResultOperatorBase resultOperator = expression.QueryModel.ResultOperators[0];
-
- if (resultOperator is AverageResultOperator)
- {
- return new AverageExpression(expression.QueryModel.SelectClause.Selector);
- }
- else if (resultOperator is MinResultOperator)
- {
- return new MinExpression(expression.QueryModel.SelectClause.Selector);
- }
- else if (resultOperator is MaxResultOperator)
- {
- return new MaxExpression(expression.QueryModel.SelectClause.Selector);
- }
- else if (resultOperator is CountResultOperator)
- {
- return new CountExpression();
- }
- else if (resultOperator is SumResultOperator)
- {
- return new SumExpression(expression.QueryModel.SelectClause.Selector);
- }
- else
- {
- throw new NotImplementedException();
- }
- }
- else
- {
- return base.VisitSubQueryExpression(expression);
- }
- }
- }
-
- public enum NhExpressionType
- {
- Average = 10000,
- Min,
- Max,
- Sum,
- Count,
- Distinct
- }
-
- public class NhAggregatedExpression : Expression
- {
- public Expression Expression { get; set; }
-
- public NhAggregatedExpression(Expression expression, NhExpressionType type)
- : base((ExpressionType)type, expression.Type)
- {
- Expression = expression;
- }
- }
-
- public class AverageExpression : NhAggregatedExpression
- {
- public AverageExpression(Expression expression) : base(expression, NhExpressionType.Average)
- {
- }
- }
-
- public class MinExpression : NhAggregatedExpression
- {
- public MinExpression(Expression expression)
- : base(expression, NhExpressionType.Min)
- {
- }
- }
-
- public class MaxExpression : NhAggregatedExpression
- {
- public MaxExpression(Expression expression)
- : base(expression, NhExpressionType.Max)
- {
- }
- }
-
- public class SumExpression : NhAggregatedExpression
- {
- public SumExpression(Expression expression)
- : base(expression, NhExpressionType.Sum)
- {
- }
- }
-
- public class DistinctExpression : NhAggregatedExpression
- {
- public DistinctExpression(Expression expression)
- : base(expression, NhExpressionType.Distinct)
- {
- }
- }
-
- public class CountExpression : Expression
- {
- public CountExpression()
- : base((ExpressionType)NhExpressionType.Count, typeof(int))
- {
- }
- }
-}
Deleted: trunk/nhibernate/src/NHibernate/Linq/GroupBySelectorVisitor.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/GroupBySelectorVisitor.cs 2009-09-24 11:22:59 UTC (rev 4724)
+++ trunk/nhibernate/src/NHibernate/Linq/GroupBySelectorVisitor.cs 2009-09-25 21:04:56 UTC (rev 4725)
@@ -1,26 +0,0 @@
-using System.Linq.Expressions;
-using Remotion.Data.Linq.Clauses.Expressions;
-using Remotion.Data.Linq.Parsing;
-
-namespace NHibernate.Linq
-{
- internal class GroupBySelectorVisitor : ExpressionTreeVisitor
- {
- private readonly ParameterExpression _parameter;
-
- public GroupBySelectorVisitor(ParameterExpression parameter)
- {
- _parameter = parameter;
- }
-
- public Expression Visit(Expression expression)
- {
- return VisitExpression(expression);
- }
-
- protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression)
- {
- return _parameter;
- }
- }
-}
\ No newline at end of file
Modified: trunk/nhibernate/src/NHibernate/Linq/HqlGeneratorExpressionTreeVisitor.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/HqlGeneratorExpressionTreeVisitor.cs 2009-09-24 11:22:59 UTC (rev 4724)
+++ trunk/nhibernate/src/NHibernate/Linq/HqlGeneratorExpressionTreeVisitor.cs 2009-09-25 21:04:56 UTC (rev 4725)
@@ -31,7 +31,7 @@
VisitExpression(expression);
}
- protected override Expression VisitNhAverage(AverageExpression expression)
+ protected override Expression VisitNhAverage(NhAverageExpression expression)
{
var visitor = new HqlGeneratorExpressionTreeVisitor(_parameterAggregator);
visitor.Visit(expression.Expression);
@@ -41,14 +41,17 @@
return expression;
}
- protected override Expression VisitNhCount(CountExpression expression)
+ protected override Expression VisitNhCount(NhCountExpression expression)
{
- _stack.PushLeaf(_hqlTreeBuilder.Cast(_hqlTreeBuilder.Count(_hqlTreeBuilder.RowStar()), expression.Type));
+ var visitor = new HqlGeneratorExpressionTreeVisitor(_parameterAggregator);
+ visitor.Visit(expression.Expression);
+ _stack.PushLeaf(_hqlTreeBuilder.Cast(_hqlTreeBuilder.Count(visitor.GetHqlTreeNodes().Single()), expression.Type));
+
return expression;
}
- protected override Expression VisitNhMin(MinExpression expression)
+ protected override Expression VisitNhMin(NhMinExpression expression)
{
var visitor = new HqlGeneratorExpressionTreeVisitor(_parameterAggregator);
visitor.Visit(expression.Expression);
@@ -58,7 +61,7 @@
return expression;
}
- protected override Expression VisitNhMax(MaxExpression expression)
+ protected override Expression VisitNhMax(NhMaxExpression expression)
{
var visitor = new HqlGeneratorExpressionTreeVisitor(_parameterAggregator);
visitor.Visit(expression.Expression);
@@ -68,7 +71,7 @@
return expression;
}
- protected override Expression VisitNhSum(SumExpression expression)
+ protected override Expression VisitNhSum(NhSumExpression expression)
{
var visitor = new HqlGeneratorExpressionTreeVisitor(_parameterAggregator);
visitor.Visit(expression.Expression);
@@ -78,7 +81,7 @@
return expression;
}
- protected override Expression VisitNhDistinct(DistinctExpression expression)
+ protected override Expression VisitNhDistinct(NhDistinctExpression expression)
{
var visitor = new HqlGeneratorExpressionTreeVisitor(_parameterAggregator);
visitor.Visit(expression.Expression);
Added: trunk/nhibernate/src/NHibernate/Linq/MergeAggregatingResultsRewriter.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/MergeAggregatingResultsRewriter.cs (rev 0)
+++ trunk/nhibernate/src/NHibernate/Linq/MergeAggregatingResultsRewriter.cs 2009-09-25 21:04:56 UTC (rev 4725)
@@ -0,0 +1,231 @@
+using System;
+using System.Linq;
+using System.Linq.Expressions;
+using Remotion.Data.Linq;
+using Remotion.Data.Linq.Clauses;
+using Remotion.Data.Linq.Clauses.Expressions;
+using Remotion.Data.Linq.Clauses.ResultOperators;
+using Remotion.Data.Linq.Parsing;
+using Remotion.Data.Linq.Parsing.ExpressionTreeVisitors;
+
+namespace NHibernate.Linq
+{
+    /// <summary>
+    /// Folds the Sum, Average, Min, Max, Distinct and Count result operators of a
+    /// <see cref="QueryModel"/> into its select clause as the matching Nh* expressions,
+    /// and applies the same rewrite to sub-queries and aggregate method calls found
+    /// inside the select clause.
+    /// </summary>
+    public class MergeAggregatingResultsRewriter : QueryModelVisitorBase
+    {
+        /// <summary>
+        /// Rewrites <paramref name="model"/> in place.
+        /// </summary>
+        public void ReWrite(QueryModel model)
+        {
+            VisitQueryModel(model);
+        }
+
+        public override void VisitResultOperator(ResultOperatorBase resultOperator, QueryModel queryModel, int index)
+        {
+            Func<Expression, Expression> wrapSelector = GetSelectorWrapper(resultOperator);
+
+            if (wrapSelector != null)
+            {
+                // The operator is now carried by the selector, so it must not be applied again.
+                // NOTE(review): this removes from the collection the base visitor is iterating;
+                // assumes that iteration tolerates removal - confirm.
+                queryModel.SelectClause.Selector = wrapSelector(queryModel.SelectClause.Selector);
+                queryModel.ResultOperators.Remove(resultOperator);
+            }
+
+            base.VisitResultOperator(resultOperator, queryModel, index);
+        }
+
+        public override void VisitSelectClause(SelectClause selectClause, QueryModel queryModel)
+        {
+            selectClause.TransformExpressions(s => new MergeAggregatingResultsInExpressionRewriter().Visit(s));
+        }
+
+        /// <summary>
+        /// Returns a factory that wraps a selector in the Nh* expression matching
+        /// <paramref name="resultOperator"/>, or null when the operator is not merged.
+        /// </summary>
+        private static Func<Expression, Expression> GetSelectorWrapper(ResultOperatorBase resultOperator)
+        {
+            if (resultOperator is SumResultOperator)
+            {
+                return e => new NhSumExpression(e);
+            }
+            if (resultOperator is AverageResultOperator)
+            {
+                return e => new NhAverageExpression(e);
+            }
+            if (resultOperator is MinResultOperator)
+            {
+                return e => new NhMinExpression(e);
+            }
+            if (resultOperator is MaxResultOperator)
+            {
+                return e => new NhMaxExpression(e);
+            }
+            if (resultOperator is DistinctResultOperator)
+            {
+                return e => new NhDistinctExpression(e);
+            }
+            if (resultOperator is CountResultOperator)
+            {
+                return e => new NhCountExpression(e);
+            }
+            return null;
+        }
+    }
+
+    /// <summary>
+    /// Rewrites Count, Min, Max, Sum and Average calls on <see cref="Queryable"/> or
+    /// <see cref="Enumerable"/> into sub-queries whose selector is the matching Nh*
+    /// aggregate expression, and runs <see cref="MergeAggregatingResultsRewriter"/> over
+    /// nested sub-queries.
+    /// </summary>
+    internal class MergeAggregatingResultsInExpressionRewriter : NhExpressionTreeVisitor
+    {
+        public Expression Visit(Expression expression)
+        {
+            return VisitExpression(expression);
+        }
+
+        protected override Expression VisitSubQueryExpression(SubQueryExpression expression)
+        {
+            new MergeAggregatingResultsRewriter().ReWrite(expression.QueryModel);
+            return expression;
+        }
+
+        protected override Expression VisitMethodCallExpression(MethodCallExpression m)
+        {
+            if (m.Method.DeclaringType == typeof(Queryable) ||
+                m.Method.DeclaringType == typeof(Enumerable))
+            {
+                Expression aggregate = null;
+
+                // TODO - dynamic name generation needed here
+                switch (m.Method.Name)
+                {
+                    case "Count":
+                        aggregate = CreateCount(m);
+                        break;
+                    case "Min":
+                        aggregate = CreateSelectorAggregate(m, e => new NhMinExpression(e));
+                        break;
+                    case "Max":
+                        aggregate = CreateSelectorAggregate(m, e => new NhMaxExpression(e));
+                        break;
+                    case "Sum":
+                        aggregate = CreateSelectorAggregate(m, e => new NhSumExpression(e));
+                        break;
+                    case "Average":
+                        aggregate = CreateSelectorAggregate(m, e => new NhAverageExpression(e));
+                        break;
+                }
+
+                if (aggregate != null)
+                {
+                    return aggregate;
+                }
+            }
+
+            // Overloads that cannot be rewritten (e.g. Sum() without a selector) are left to the base visitor.
+            return base.VisitMethodCallExpression(m);
+        }
+
+        /// <summary>
+        /// Rewrites <c>source.Count()</c> and <c>source.Count(predicate)</c>. The predicate
+        /// becomes a where clause of the sub-query instead of the counted expression, because
+        /// it filters the items rather than selecting a value. Returns null when the call
+        /// shape is not recognised.
+        /// </summary>
+        private static Expression CreateCount(MethodCallExpression m)
+        {
+            if (m.Arguments.Count == 1 && m.Method.IsGenericMethod)
+            {
+                // Count<TSource>(source): the item type is the method's only type argument.
+                return CreateSubQuery(m.Arguments[0], m.Method.GetGenericArguments()[0], null,
+                                      item => new NhCountExpression(item));
+            }
+
+            LambdaExpression predicate = m.Arguments.Count == 2 ? GetLambda(m.Arguments[1]) : null;
+            if (predicate == null)
+            {
+                return null;
+            }
+
+            return CreateSubQuery(m.Arguments[0], predicate.Parameters[0].Type, predicate,
+                                  item => new NhCountExpression(item));
+        }
+
+        /// <summary>
+        /// Rewrites a Min, Max, Sum or Average call that has a selector lambda into a sub-query
+        /// selecting <paramref name="factory"/> applied to the selector body. Returns null for
+        /// overloads without a selector.
+        /// </summary>
+        private static Expression CreateSelectorAggregate(MethodCallExpression m, Func<Expression, Expression> factory)
+        {
+            LambdaExpression selector = m.Arguments.Count == 2 ? GetLambda(m.Arguments[1]) : null;
+            if (selector == null)
+            {
+                return null;
+            }
+
+            return CreateSubQuery(m.Arguments[0], selector.Parameters[0].Type, null,
+                                  item => factory(ReplaceParameter(selector, item)));
+        }
+
+        /// <summary>
+        /// Builds <c>from x2 in source [where predicate] select buildSelector(x2)</c> as a
+        /// <see cref="SubQueryExpression"/>.
+        /// </summary>
+        private static Expression CreateSubQuery(Expression source, Type itemType, LambdaExpression predicate,
+                                                 Func<Expression, Expression> buildSelector)
+        {
+            var fromClause = new MainFromClause("x2", itemType, source);
+            var item = new QuerySourceReferenceExpression(fromClause);
+            var queryModel = new QueryModel(fromClause, new SelectClause(buildSelector(item)));
+
+            if (predicate != null)
+            {
+                queryModel.BodyClauses.Add(new WhereClause(ReplaceParameter(predicate, item)));
+            }
+
+            // The result operator is attached only while the SubQueryExpression is built and is
+            // removed straight afterwards, since the aggregate itself is carried by the Nh* selector.
+            // NOTE(review): presumably this gives the sub-query a scalar rather than a sequence type;
+            // that type is derived from Average, not from the rewritten call - confirm it is acceptable.
+            queryModel.ResultOperators.Add(new AverageResultOperator());
+            var subQuery = new SubQueryExpression(queryModel);
+            queryModel.ResultOperators.Clear();
+
+            return subQuery;
+        }
+
+        /// <summary>
+        /// Unwraps a lambda argument; <see cref="Queryable"/> methods receive their lambdas wrapped in a Quote node.
+        /// </summary>
+        private static LambdaExpression GetLambda(Expression argument)
+        {
+            var quote = argument as UnaryExpression;
+            if (quote != null && quote.NodeType == ExpressionType.Quote)
+            {
+                argument = quote.Operand;
+            }
+
+            return argument as LambdaExpression;
+        }
+
+        /// <summary>
+        /// Returns the body of <paramref name="lambda"/> with its parameter replaced by <paramref name="item"/>.
+        /// </summary>
+        private static Expression ReplaceParameter(LambdaExpression lambda, Expression item)
+        {
+            return ReplacingExpressionTreeVisitor.Replace(lambda.Parameters[0], item, lambda.Body);
+        }
+    }
+}
\ No newline at end of file
Modified: trunk/nhibernate/src/NHibernate/Linq/NhExpressionTreeVisitor.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/NhExpressionTreeVisitor.cs 2009-09-24 11:22:59 UTC (rev 4724)
+++ trunk/nhibernate/src/NHibernate/Linq/NhExpressionTreeVisitor.cs 2009-09-25 21:04:56 UTC (rev 4725)
@@ -1,3 +1,4 @@
+using System;
using System.Linq.Expressions;
using Remotion.Data.Linq.Parsing;
@@ -7,63 +8,81 @@
 {
 protected override Expression VisitExpression(Expression expression)
 {
+ if (expression == null)
+ {
+ return null;
+ }
+
+ // Nh* node types are dispatched here; the VisitNh* methods below recurse through this method so that nested Nh* nodes are dispatched as well.
 switch ((NhExpressionType) expression.NodeType)
 {
 case NhExpressionType.Average:
- return VisitNhAverage((AverageExpression) expression);
+ return VisitNhAverage((NhAverageExpression) expression);
 case NhExpressionType.Min:
- return VisitNhMin((MinExpression)expression);
+ return VisitNhMin((NhMinExpression)expression);
 case NhExpressionType.Max:
- return VisitNhMax((MaxExpression)expression);
+ return VisitNhMax((NhMaxExpression)expression);
 case NhExpressionType.Sum:
- return VisitNhSum((SumExpression)expression);
+ return VisitNhSum((NhSumExpression)expression);
 case NhExpressionType.Count:
- return VisitNhCount((CountExpression)expression);
+ return VisitNhCount((NhCountExpression)expression);
 case NhExpressionType.Distinct:
- return VisitNhDistinct((DistinctExpression) expression);
+ return VisitNhDistinct((NhDistinctExpression) expression);
+ case NhExpressionType.New:
+ return VisitNhNew((NhNewExpression) expression);
 }
 return base.VisitExpression(expression);
 }
- protected virtual Expression VisitNhDistinct(DistinctExpression expression)
+ /// <summary>Visits the arguments of an <see cref="NhNewExpression"/>; returns a new node only when VisitExpressionList returns a different list.</summary>
+ protected virtual Expression VisitNhNew(NhNewExpression expression)
 {
+ var arguments = VisitExpressionList(expression.Arguments);
+
+ return arguments != expression.Arguments ? new NhNewExpression(expression.Members, arguments) : expression;
+ }
+
+ protected virtual Expression VisitNhDistinct(NhDistinctExpression expression)
+ {
- Expression nx = base.VisitExpression(expression.Expression);
- return nx != expression.Expression ? new DistinctExpression(nx) : expression;
+ Expression nx = VisitExpression(expression.Expression);
+ return nx != expression.Expression ? new NhDistinctExpression(nx) : expression;
 }
- protected virtual Expression VisitNhCount(CountExpression expression)
+ protected virtual Expression VisitNhCount(NhCountExpression expression)
 {
- return expression;
+ Expression nx = VisitExpression(expression.Expression);
+
+ return nx != expression.Expression ? new NhCountExpression(nx) : expression;
 }
- protected virtual Expression VisitNhSum(SumExpression expression)
+ protected virtual Expression VisitNhSum(NhSumExpression expression)
 {
- Expression nx = base.VisitExpression(expression.Expression);
- return nx != expression.Expression ? new SumExpression(nx) : expression;
+ Expression nx = VisitExpression(expression.Expression);
+ return nx != expression.Expression ? new NhSumExpression(nx) : expression;
 }
- protected virtual Expression VisitNhMax(MaxExpression expression)
+ protected virtual Expression VisitNhMax(NhMaxExpression expression)
 {
- Expression nx = base.VisitExpression(expression.Expression);
- return nx != expression.Expression ? new MaxExpression(nx) : expression;
+ Expression nx = VisitExpression(expression.Expression);
+ return nx != expression.Expression ? new NhMaxExpression(nx) : expression;
 }
- protected virtual Expression VisitNhMin(MinExpression expression)
+ protected virtual Expression VisitNhMin(NhMinExpression expression)
 {
- Expression nx = base.VisitExpression(expression.Expression);
- return nx != expression.Expression ? new MinExpression(nx) : expression;
+ Expression nx = VisitExpression(expression.Expression);
+ return nx != expression.Expression ? new NhMinExpression(nx) : expression;
 }
- protected virtual Expression VisitNhAverage(AverageExpression expression)
+ protected virtual Expression VisitNhAverage(NhAverageExpression expression)
 {
- Expression nx = base.VisitExpression(expression.Expression);
- return nx != expression.Expression ? new AverageExpression(nx) : expression;
+ Expression nx = VisitExpression(expression.Expression);
+ return nx != expression.Expression ? new NhAverageExpression(nx) : expression;
 }
 }
 }
\ No newline at end of file
Added: trunk/nhibernate/src/NHibernate/Linq/NhNewExpression.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/NhNewExpression.cs (rev 0)
+++ trunk/nhibernate/src/NHibernate/Linq/NhNewExpression.cs 2009-09-25 21:04:56 UTC (rev 4725)
@@ -0,0 +1,54 @@
+using System;
+using System.Collections.Generic;
+using System.Collections.ObjectModel;
+using System.Linq;
+using System.Linq.Expressions;
+
+namespace NHibernate.Linq
+{
+    /// <summary>
+    /// NHibernate-specific expression node of type <c>NhExpressionType.New</c> carrying a list of
+    /// member names and a list of argument expressions. Its static type is <see cref="object"/>.
+    /// </summary>
+    public class NhNewExpression : Expression
+    {
+        private readonly ReadOnlyCollection<string> _members;
+        private readonly ReadOnlyCollection<Expression> _arguments;
+
+        /// <summary>
+        /// Creates the node from copies of <paramref name="members"/> and <paramref name="arguments"/>.
+        /// </summary>
+        /// <param name="members">Member names; copied so later changes to the caller's list do not leak into the node.</param>
+        /// <param name="arguments">Argument expressions; copied for the same reason.</param>
+        /// <exception cref="ArgumentNullException"><paramref name="members"/> or <paramref name="arguments"/> is null.</exception>
+        public NhNewExpression(IList<string> members, IList<Expression> arguments)
+            : base((ExpressionType)NhExpressionType.New, typeof(object))
+        {
+            if (members == null)
+            {
+                throw new ArgumentNullException("members");
+            }
+            if (arguments == null)
+            {
+                throw new ArgumentNullException("arguments");
+            }
+
+            // ReadOnlyCollection only wraps the list it is given, so copy first; otherwise the
+            // caller could still mutate the node after construction.
+            _members = new ReadOnlyCollection<string>(new List<string>(members));
+            _arguments = new ReadOnlyCollection<Expression>(new List<Expression>(arguments));
+        }
+
+        /// <summary>The argument expressions (read-only snapshot).</summary>
+        public ReadOnlyCollection<Expression> Arguments
+        {
+            get { return _arguments; }
+        }
+
+        /// <summary>The member names (read-only snapshot).</summary>
+        public ReadOnlyCollection<string> Members
+        {
+            get { return _members; }
+        }
+    }
+}
Modified: trunk/nhibernate/src/NHibernate/Linq/NhThrowingExpressionTreeVisitor.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/NhThrowingExpressionTreeVisitor.cs 2009-09-24 11:22:59 UTC (rev 4724)
+++ trunk/nhibernate/src/NHibernate/Linq/NhThrowingExpressionTreeVisitor.cs 2009-09-25 21:04:56 UTC (rev 4725)
@@ -1,4 +1,3 @@
-using System;
using System.Linq.Expressions;
using Remotion.Data.Linq.Parsing;
@@ -11,91 +10,106 @@
switch ((NhExpressionType)expression.NodeType)
{
case NhExpressionType.Average:
- return VisitNhAverage((AverageExpression)expression);
+ return VisitNhAverage((NhAverageExpression)expression);
case NhExpressionType.Min:
- return VisitNhMin((MinExpression)expression);
+ return VisitNhMin((NhMinExpression)expression);
case NhExpressionType.Max:
- return VisitNhMax((MaxExpression)expression);
+ return VisitNhMax((NhMaxExpression)expression);
case NhExpressionType.Sum:
- return VisitNhSum((SumExpression)expression);
+ return VisitNhSum((NhSumExpression)expression);
case NhExpressionType.Count:
- return VisitNhCount((CountExpression)expression);
+ return VisitNhCount((NhCountExpression)expression);
case NhExpressionType.Distinct:
- return VisitNhDistinct((DistinctExpression) expression);
+ return VisitNhDistinct((NhDistinctExpression) expression);
+ case NhExpressionType.New:
+ return VisitNhNew((NhNewExpression) expression);
}
return base.VisitExpression(expression);
}
- protected virtual Expression VisitNhDistinct(DistinctExpression expression)
+ protected virtual Expression VisitNhNew(NhNewExpression expression)
{
- return VisitUnhandledItem<DistinctExpression, Expression>(expression, "VisitNhDistinct", BaseVisitNhDistinct);
+ return VisitUnhandledItem<NhNewExpression, Expression>(expression, "VisitNhNew", BaseVisitNhNew);
}
- protected virtual Expression VisitNhAverage(AverageExpression expression)
+ protected virtual Expression VisitNhDistinct(NhDistinctExpression expression)
{
- return VisitUnhandledItem<AverageExpression, Expression>(expression, "VisitNhAverage", BaseVisitNhAverage);
+ return VisitUnhandledItem<NhDistinctExpression, Expression>(expression, "VisitNhDistinct", BaseVisitNhDistinct);
}
- protected virtual Expression VisitNhMin(MinExpression expression)
+ protected virtual Expression VisitNhAverage(NhAverageExpression expression)
{
- return VisitUnhandledItem<MinExpression, Expression>(expression, "VisitNhMin", BaseVisitNhMin);
+ return VisitUnhandledItem<NhAverageExpression, Expression>(expression, "VisitNhAverage", BaseVisitNhAverage);
}
- protected virtual Expression VisitNhMax(MaxExpression expression)
+ protected virtual Expression VisitNhMin(NhMinExpression expression)
{
- return VisitUnhandledItem<MaxExpression, Expression>(expression, "VisitNhMax", BaseVisitNhMax);
+ return VisitUnhandledItem<NhMinExpression, Expression>(expression, "VisitNhMin", BaseVisitNhMin);
}
- protected virtual Expression VisitNhSum(SumExpression expression)
+ protected virtual Expression VisitNhMax(NhMaxExpression expression)
{
- return VisitUnhandledItem<SumExpression, Expression>(expression, "VisitNhSum", BaseVisitNhSum);
+ return VisitUnhandledItem<NhMaxExpression, Expression>(expression, "VisitNhMax", BaseVisitNhMax);
}
- protected virtual Expression VisitNhCount(CountExpression expression)
+ protected virtual Expression VisitNhSum(NhSumExpression expression)
{
- return VisitUnhandledItem<CountExpression, Expression>(expression, "VisitNhCount", BaseVisitNhCount);
+ return VisitUnhandledItem<NhSumExpression, Expression>(expression, "VisitNhSum", BaseVisitNhSum);
}
- protected virtual Expression BaseVisitNhCount(CountExpression expression)
+ protected virtual Expression VisitNhCount(NhCountExpression expression)
{
- return expression;
+ return VisitUnhandledItem<NhCountExpression, Expression>(expression, "VisitNhCount", BaseVisitNhCount);
}
- protected virtual Expression BaseVisitNhSum(SumExpression expression)
+ protected virtual Expression BaseVisitNhCount(NhCountExpression expression)
{
Expression nx = base.VisitExpression(expression.Expression);
- return nx != expression.Expression ? new SumExpression(nx) : expression;
+ return nx != expression.Expression ? new NhCountExpression(nx) : expression;
}
- protected virtual Expression BaseVisitNhMax(MaxExpression expression)
+ protected virtual Expression BaseVisitNhSum(NhSumExpression expression)
{
Expression nx = base.VisitExpression(expression.Expression);
- return nx != expression.Expression ? new MaxExpression(nx) : expression;
+ return nx != expression.Expression ? new NhSumExpression(nx) : expression;
}
- protected virtual Expression BaseVisitNhMin(MinExpression expression)
+ protected virtual Expression BaseVisitNhMax(NhMaxExpression expression)
{
Expression nx = base.VisitExpression(expression.Expression);
- return nx != expression.Expression ? new MinExpression(nx) : expression;
+ return nx != expression.Expression ? new NhMaxExpression(nx) : expression;
}
- protected virtual Expression BaseVisitNhAverage(AverageExpression expression)
+ protected virtual Expression BaseVisitNhMin(NhMinExpression expression)
{
Expression nx = base.VisitExpression(expression.Expression);
- return nx != expression.Expression ? new AverageExpression(nx) : expression;
+ return nx != expression.Expression ? new NhMinExpression(nx) : expression;
}
- private Expression BaseVisitNhDistinct(DistinctExpression expression)
+ protected virtual Expression BaseVisitNhAverage(NhAverageExpression expression)
{
Expression nx = base.VisitExpression(expression.Expression);
- return nx != expression.Expression ? new DistinctExpression(nx) : expression;
+ ...
[truncated message content] |