From: <ste...@us...> - 2009-11-05 16:43:54
|
Revision: 4819 http://nhibernate.svn.sourceforge.net/nhibernate/?rev=4819&view=rev Author: steverstrong Date: 2009-11-05 16:43:43 +0000 (Thu, 05 Nov 2009) Log Message: ----------- Further Linq updates, plus a change to the Cast dialect function Modified Paths: -------------- trunk/nhibernate/src/NHibernate/Dialect/Function/CastFunction.cs trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeNode.cs trunk/nhibernate/src/NHibernate/Linq/NhQueryProvider.cs trunk/nhibernate/src/NHibernate/Linq/ReWriters/MergeAggregatingResultsRewriter.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/GroupByKeySelectorVisitor.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionTreeVisitor.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/NhExpressionTreeVisitor.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/NhThrowingExpressionTreeVisitor.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs trunk/nhibernate/src/NHibernate/NHibernate.csproj trunk/nhibernate/src/NHibernate.Test/Linq/Entities/Shipper.cs trunk/nhibernate/src/NHibernate.Test/Linq/LinqQuerySamples.cs trunk/nhibernate/src/NHibernate.Test/Linq/LinqTestCase.cs trunk/nhibernate/src/NHibernate.Test/Linq/Mappings/Shipper.hbm.xml trunk/nhibernate/src/NHibernate.Test/NHibernate.Test.csproj Added Paths: ----------- trunk/nhibernate/src/NHibernate/Linq/Expressions/ trunk/nhibernate/src/NHibernate/Linq/Expressions/NhAggregatedExpression.cs trunk/nhibernate/src/NHibernate/Linq/Expressions/NhAverageExpression.cs trunk/nhibernate/src/NHibernate/Linq/Expressions/NhCountExpression.cs trunk/nhibernate/src/NHibernate/Linq/Expressions/NhDistinctExpression.cs trunk/nhibernate/src/NHibernate/Linq/Expressions/NhExpressionType.cs trunk/nhibernate/src/NHibernate/Linq/Expressions/NhMaxExpression.cs trunk/nhibernate/src/NHibernate/Linq/Expressions/NhMinExpression.cs 
trunk/nhibernate/src/NHibernate/Linq/Expressions/NhNewExpression.cs trunk/nhibernate/src/NHibernate/Linq/Expressions/NhSumExpression.cs trunk/nhibernate/src/NHibernate/Linq/GroupBy/ trunk/nhibernate/src/NHibernate/Linq/GroupBy/AggregatingGroupByRewriter.cs trunk/nhibernate/src/NHibernate/Linq/GroupBy/GroupByAggregateDetectionVisitor.cs trunk/nhibernate/src/NHibernate/Linq/GroupBy/GroupBySelectClauseRewriter.cs trunk/nhibernate/src/NHibernate/Linq/GroupBy/NonAggregatingGroupByRewriter.cs trunk/nhibernate/src/NHibernate/Linq/GroupJoin/ trunk/nhibernate/src/NHibernate/Linq/GroupJoin/AggregatingGroupJoinRewriter.cs trunk/nhibernate/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionVisitor.cs trunk/nhibernate/src/NHibernate/Linq/GroupJoin/GroupJoinSelectClauseRewriter.cs trunk/nhibernate/src/NHibernate/Linq/GroupJoin/IsAggregatingResults.cs trunk/nhibernate/src/NHibernate/Linq/GroupJoin/LocateGroupJoinQuerySource.cs trunk/nhibernate/src/NHibernate/Linq/GroupJoin/NonAggregatingGroupJoinRewriter.cs trunk/nhibernate/src/NHibernate/Linq/ReWriters/QueryReferenceExpressionFlattener.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/EqualityHqlGenerator.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/SwapQuerySourceVisitor.cs trunk/nhibernate/src/NHibernate.Test/Linq/MiscellaneousTextFixture.cs trunk/nhibernate/src/NHibernate.Test/Linq/ObjectDumper.cs Removed Paths: ------------- trunk/nhibernate/src/NHibernate/Linq/NhNewExpression.cs trunk/nhibernate/src/NHibernate/Linq/ReWriters/AggregatingGroupByRewriter.cs trunk/nhibernate/src/NHibernate/Linq/ReWriters/AggregatingGroupJoinRewriter.cs trunk/nhibernate/src/NHibernate/Linq/ReWriters/GroupBySelectClauseRewriter.cs trunk/nhibernate/src/NHibernate/Linq/ReWriters/NonAggregatingGroupByRewriter.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/GroupByAggregateDetectionVisitor.cs Modified: trunk/nhibernate/src/NHibernate/Dialect/Function/CastFunction.cs =================================================================== --- 
trunk/nhibernate/src/NHibernate/Dialect/Function/CastFunction.cs 2009-11-04 10:55:10 UTC (rev 4818) +++ trunk/nhibernate/src/NHibernate/Dialect/Function/CastFunction.cs 2009-11-05 16:43:43 UTC (rev 4819) @@ -15,16 +15,17 @@ [Serializable] public class CastFunction : ISQLFunction, IFunctionGrammar { - private LazyType returnType; + //private LazyType returnType; #region ISQLFunction Members public IType ReturnType(IType columnType, IMapping mapping) { //note there is a weird implementation in the client side //TODO: cast that use only costant are not supported in SELECT. Ex: cast(5 as string) - return SetLazyType(columnType); + //return SetLazyType(columnType); + return columnType; } - + /* private LazyType SetLazyType(IType columnType) { if(returnType == null) @@ -34,6 +35,7 @@ returnType.RealType = columnType; return returnType; } + */ public bool HasArguments { get { return true; } @@ -53,7 +55,7 @@ string typeName = args[1].ToString(); string sqlType; IType hqlType = TypeFactory.HeuristicType(typeName); - SetLazyType(hqlType); + //SetLazyType(hqlType); if (hqlType != null) { SqlType[] sqlTypeCodes = hqlType.SqlTypes(factory); Modified: trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeNode.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeNode.cs 2009-11-04 10:55:10 UTC (rev 4818) +++ trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeNode.cs 2009-11-05 16:43:43 UTC (rev 4819) @@ -133,6 +133,11 @@ _node.Text = "string"; break; default: + if (type == typeof(Guid)) + { + _node.Text = "guid"; + break; + } throw new NotSupportedException(string.Format("Don't currently support idents of type {0}", type.Name)); } } Added: trunk/nhibernate/src/NHibernate/Linq/Expressions/NhAggregatedExpression.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/Expressions/NhAggregatedExpression.cs (rev 0) +++ 
trunk/nhibernate/src/NHibernate/Linq/Expressions/NhAggregatedExpression.cs 2009-11-05 16:43:43 UTC (rev 4819) @@ -0,0 +1,21 @@ +using System.Linq.Expressions; + +namespace NHibernate.Linq.Expressions +{ + public class NhAggregatedExpression : Expression + { + public Expression Expression { get; set; } + + public NhAggregatedExpression(Expression expression, NhExpressionType type) + : base((ExpressionType)type, expression.Type) + { + Expression = expression; + } + + public NhAggregatedExpression(Expression expression, System.Type expressionType, NhExpressionType type) + : base((ExpressionType)type, expressionType) + { + Expression = expression; + } + } +} \ No newline at end of file Added: trunk/nhibernate/src/NHibernate/Linq/Expressions/NhAverageExpression.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/Expressions/NhAverageExpression.cs (rev 0) +++ trunk/nhibernate/src/NHibernate/Linq/Expressions/NhAverageExpression.cs 2009-11-05 16:43:43 UTC (rev 4819) @@ -0,0 +1,11 @@ +using System.Linq.Expressions; + +namespace NHibernate.Linq.Expressions +{ + public class NhAverageExpression : NhAggregatedExpression + { + public NhAverageExpression(Expression expression) : base(expression, NhExpressionType.Average) + { + } + } +} \ No newline at end of file Added: trunk/nhibernate/src/NHibernate/Linq/Expressions/NhCountExpression.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/Expressions/NhCountExpression.cs (rev 0) +++ trunk/nhibernate/src/NHibernate/Linq/Expressions/NhCountExpression.cs 2009-11-05 16:43:43 UTC (rev 4819) @@ -0,0 +1,12 @@ +using System.Linq.Expressions; + +namespace NHibernate.Linq.Expressions +{ + public class NhCountExpression : NhAggregatedExpression + { + public NhCountExpression(Expression expression) + : base(expression, typeof(int), NhExpressionType.Count) + { + } + } +} \ No newline at end of file Added: 
trunk/nhibernate/src/NHibernate/Linq/Expressions/NhDistinctExpression.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/Expressions/NhDistinctExpression.cs (rev 0) +++ trunk/nhibernate/src/NHibernate/Linq/Expressions/NhDistinctExpression.cs 2009-11-05 16:43:43 UTC (rev 4819) @@ -0,0 +1,12 @@ +using System.Linq.Expressions; + +namespace NHibernate.Linq.Expressions +{ + public class NhDistinctExpression : NhAggregatedExpression + { + public NhDistinctExpression(Expression expression) + : base(expression, NhExpressionType.Distinct) + { + } + } +} \ No newline at end of file Added: trunk/nhibernate/src/NHibernate/Linq/Expressions/NhExpressionType.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/Expressions/NhExpressionType.cs (rev 0) +++ trunk/nhibernate/src/NHibernate/Linq/Expressions/NhExpressionType.cs 2009-11-05 16:43:43 UTC (rev 4819) @@ -0,0 +1,13 @@ +namespace NHibernate.Linq.Expressions +{ + public enum NhExpressionType + { + Average = 10000, + Min, + Max, + Sum, + Count, + Distinct, + New + } +} \ No newline at end of file Added: trunk/nhibernate/src/NHibernate/Linq/Expressions/NhMaxExpression.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/Expressions/NhMaxExpression.cs (rev 0) +++ trunk/nhibernate/src/NHibernate/Linq/Expressions/NhMaxExpression.cs 2009-11-05 16:43:43 UTC (rev 4819) @@ -0,0 +1,12 @@ +using System.Linq.Expressions; + +namespace NHibernate.Linq.Expressions +{ + public class NhMaxExpression : NhAggregatedExpression + { + public NhMaxExpression(Expression expression) + : base(expression, NhExpressionType.Max) + { + } + } +} \ No newline at end of file Added: trunk/nhibernate/src/NHibernate/Linq/Expressions/NhMinExpression.cs =================================================================== --- 
trunk/nhibernate/src/NHibernate/Linq/Expressions/NhMinExpression.cs (rev 0) +++ trunk/nhibernate/src/NHibernate/Linq/Expressions/NhMinExpression.cs 2009-11-05 16:43:43 UTC (rev 4819) @@ -0,0 +1,12 @@ +using System.Linq.Expressions; + +namespace NHibernate.Linq.Expressions +{ + public class NhMinExpression : NhAggregatedExpression + { + public NhMinExpression(Expression expression) + : base(expression, NhExpressionType.Min) + { + } + } +} \ No newline at end of file Added: trunk/nhibernate/src/NHibernate/Linq/Expressions/NhNewExpression.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/Expressions/NhNewExpression.cs (rev 0) +++ trunk/nhibernate/src/NHibernate/Linq/Expressions/NhNewExpression.cs 2009-11-05 16:43:43 UTC (rev 4819) @@ -0,0 +1,29 @@ +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Linq.Expressions; + +namespace NHibernate.Linq.Expressions +{ + public class NhNewExpression : Expression + { + private readonly ReadOnlyCollection<string> _members; + private readonly ReadOnlyCollection<Expression> _arguments; + + public NhNewExpression(IList<string> members, IList<Expression> arguments) + : base((ExpressionType)NhExpressionType.New, typeof(object)) + { + _members = new ReadOnlyCollection<string>(members); + _arguments = new ReadOnlyCollection<Expression>(arguments); + } + + public ReadOnlyCollection<Expression> Arguments + { + get { return _arguments; } + } + + public ReadOnlyCollection<string> Members + { + get { return _members; } + } + } +} \ No newline at end of file Added: trunk/nhibernate/src/NHibernate/Linq/Expressions/NhSumExpression.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/Expressions/NhSumExpression.cs (rev 0) +++ trunk/nhibernate/src/NHibernate/Linq/Expressions/NhSumExpression.cs 2009-11-05 16:43:43 UTC (rev 4819) @@ -0,0 +1,12 @@ +using System.Linq.Expressions; + +namespace 
NHibernate.Linq.Expressions +{ + public class NhSumExpression : NhAggregatedExpression + { + public NhSumExpression(Expression expression) + : base(expression, NhExpressionType.Sum) + { + } + } +} \ No newline at end of file Added: trunk/nhibernate/src/NHibernate/Linq/GroupBy/AggregatingGroupByRewriter.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/GroupBy/AggregatingGroupByRewriter.cs (rev 0) +++ trunk/nhibernate/src/NHibernate/Linq/GroupBy/AggregatingGroupByRewriter.cs 2009-11-05 16:43:43 UTC (rev 4819) @@ -0,0 +1,88 @@ +using System; +using System.Linq; +using NHibernate.Linq.Visitors; +using Remotion.Data.Linq; +using Remotion.Data.Linq.Clauses; +using Remotion.Data.Linq.Clauses.Expressions; +using Remotion.Data.Linq.Clauses.ResultOperators; + +namespace NHibernate.Linq.GroupBy +{ + /// <summary> + /// An AggregatingGroupBy is a query such as: + /// + /// from p in db.Products + /// group p by p.Category.CategoryId + /// into g + /// select new + /// { + /// g.Key, + /// MaxPrice = g.Max(p => p.UnitPrice) + /// }; + /// + /// Where the grouping operation is being fully aggregated and hence does not create any form of heirarchy. 
+ /// This class takes such queries, flattens out the re-linq sub-query and re-writes the outer select + /// </summary> + public class AggregatingGroupByRewriter + { + private AggregatingGroupByRewriter() + { + } + + public static void ReWrite(QueryModel queryModel) + { + var subQueryExpression = queryModel.MainFromClause.FromExpression as SubQueryExpression; + + if ((subQueryExpression != null) && + (subQueryExpression.QueryModel.ResultOperators.Count() == 1) && + (subQueryExpression.QueryModel.ResultOperators[0] is GroupResultOperator) && + (IsAggregatingGroupBy(queryModel))) + { + var rewriter = new AggregatingGroupByRewriter(); + rewriter.FlattenSubQuery(subQueryExpression, queryModel.MainFromClause, queryModel); + } + } + + private static bool IsAggregatingGroupBy(QueryModel queryModel) + { + return new GroupByAggregateDetectionVisitor().Visit(queryModel.SelectClause.Selector); + } + + private void FlattenSubQuery(SubQueryExpression subQueryExpression, FromClauseBase fromClause, + QueryModel queryModel) + { + // Move the result operator up + if (queryModel.ResultOperators.Count != 0) + { + throw new NotImplementedException(); + } + + var groupBy = (GroupResultOperator) subQueryExpression.QueryModel.ResultOperators[0]; + + // Replace the outer select clause... 
+ queryModel.SelectClause.TransformExpressions(s => GroupBySelectClauseRewriter.ReWrite(s, groupBy, subQueryExpression.QueryModel)); + + queryModel.SelectClause.TransformExpressions( + s => + new SwapQuerySourceVisitor(queryModel.MainFromClause, subQueryExpression.QueryModel.MainFromClause).Swap + (s)); + + MainFromClause innerMainFromClause = subQueryExpression.QueryModel.MainFromClause; + CopyFromClauseData(innerMainFromClause, fromClause); + + foreach (var bodyClause in subQueryExpression.QueryModel.BodyClauses) + { + queryModel.BodyClauses.Add(bodyClause); + } + + queryModel.ResultOperators.Add(groupBy); + } + + protected void CopyFromClauseData(FromClauseBase source, FromClauseBase destination) + { + destination.FromExpression = source.FromExpression; + destination.ItemName = source.ItemName; + destination.ItemType = source.ItemType; + } + } +} \ No newline at end of file Added: trunk/nhibernate/src/NHibernate/Linq/GroupBy/GroupByAggregateDetectionVisitor.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/GroupBy/GroupByAggregateDetectionVisitor.cs (rev 0) +++ trunk/nhibernate/src/NHibernate/Linq/GroupBy/GroupByAggregateDetectionVisitor.cs 2009-11-05 16:43:43 UTC (rev 4819) @@ -0,0 +1,61 @@ +using System.Linq; +using System.Linq.Expressions; +using NHibernate.Linq.Expressions; +using NHibernate.Linq.Visitors; +using Remotion.Data.Linq.Clauses.Expressions; + +namespace NHibernate.Linq.GroupBy +{ + // TODO: This needs strengthening. 
Possibly a lot in common with the GroupJoinAggregateDetectionVisitor class, which does many more checks + /// <summary> + /// Detects if an expression tree contains aggregate functions + /// </summary> + internal class GroupByAggregateDetectionVisitor : NhExpressionTreeVisitor + { + public bool ContainsAggregateMethods { get; private set; } + + public bool Visit(Expression expression) + { + ContainsAggregateMethods = false; + + VisitExpression(expression); + + return ContainsAggregateMethods; + } + + // TODO - this should not exist, since it should be handled either by re-linq or by the MergeAggregatingResultsRewriter + protected override Expression VisitMethodCallExpression(MethodCallExpression m) + { + if (m.Method.DeclaringType == typeof (Queryable) || + m.Method.DeclaringType == typeof (Enumerable)) + { + switch (m.Method.Name) + { + case "Count": + case "Min": + case "Max": + case "Sum": + case "Average": + ContainsAggregateMethods = true; + break; + } + } + + return m; + } + + protected override Expression VisitNhAggregate(NhAggregatedExpression expression) + { + ContainsAggregateMethods = true; + return expression; + } + + protected override Expression VisitSubQueryExpression(SubQueryExpression expression) + { + ContainsAggregateMethods = + new GroupByAggregateDetectionVisitor().Visit(expression.QueryModel.SelectClause.Selector); + + return expression; + } + } +} \ No newline at end of file Added: trunk/nhibernate/src/NHibernate/Linq/GroupBy/GroupBySelectClauseRewriter.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/GroupBy/GroupBySelectClauseRewriter.cs (rev 0) +++ trunk/nhibernate/src/NHibernate/Linq/GroupBy/GroupBySelectClauseRewriter.cs 2009-11-05 16:43:43 UTC (rev 4819) @@ -0,0 +1,131 @@ +using System; +using System.Linq.Expressions; +using NHibernate.Linq.Visitors; +using Remotion.Data.Linq; +using Remotion.Data.Linq.Clauses; +using Remotion.Data.Linq.Clauses.Expressions; +using 
Remotion.Data.Linq.Clauses.ResultOperators; + +namespace NHibernate.Linq.GroupBy +{ + internal class GroupBySelectClauseRewriter : NhExpressionTreeVisitor + { + public static Expression ReWrite(Expression expression, GroupResultOperator groupBy, QueryModel model) + { + var visitor = new GroupBySelectClauseRewriter(groupBy, model); + return visitor.VisitExpression(expression); + } + + private readonly GroupResultOperator _groupBy; + private readonly QueryModel _model; + + public GroupBySelectClauseRewriter(GroupResultOperator groupBy, QueryModel model) + { + _groupBy = groupBy; + _model = model; + } + + protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression) + { + if (expression.ReferencedQuerySource == _groupBy) + { + return _groupBy.ElementSelector; + } + + return base.VisitQuerySourceReferenceExpression(expression); + } + + protected override Expression VisitMemberExpression(MemberExpression expression) + { + if (IsMemberOfModel(expression)) + { + if (expression.Member.Name == "Key") + { + return _groupBy.KeySelector; + } + else + { + Expression elementSelector = _groupBy.ElementSelector; + + if ((elementSelector is MemberExpression) || (elementSelector is QuerySourceReferenceExpression)) + { + // If ElementSelector is MemberExpression, just return + return base.VisitMemberExpression(expression); + } + else if (elementSelector is NewExpression) + { + // If ElementSelector is NewExpression, then search for member of name "get_" + originalMemberExpression.Member.Name + // TODO - this wouldn't handle nested initialisers. 
Should do a tree walk to find the correct member + var nex = elementSelector as NewExpression; + + int i = 0; + foreach (var member in nex.Members) + { + if (member.Name == "get_" + expression.Member.Name) + { + return nex.Arguments[i]; + } + i++; + } + + throw new NotImplementedException(); + } + else + { + throw new NotImplementedException(); + } + } + } + else + { + return base.VisitMemberExpression(expression); + } + } + + // TODO - dislike this code intensly. Should probably be a tree-walk in its own right + private bool IsMemberOfModel(MemberExpression expression) + { + var querySourceRef = expression.Expression as QuerySourceReferenceExpression; + + if (querySourceRef == null) + { + return false; + } + + var fromClause = querySourceRef.ReferencedQuerySource as FromClauseBase; + + if (fromClause == null) + { + return false; + } + + var subQuery = fromClause.FromExpression as SubQueryExpression; + + if (subQuery != null) + { + return subQuery.QueryModel == _model; + } + + var referencedQuery = fromClause.FromExpression as QuerySourceReferenceExpression; + + if (referencedQuery == null) + { + return false; + } + + var querySource = referencedQuery.ReferencedQuerySource as FromClauseBase; + + var subQuery2 = querySource.FromExpression as SubQueryExpression; + + return (subQuery2.QueryModel == _model); + } + + protected override Expression VisitSubQueryExpression(SubQueryExpression expression) + { + // TODO - is this safe? All we are extracting is the select clause from the sub-query. Assumes that everything + // else in the subquery has been removed. If there were two subqueries, one aggregating & one not, this may not be a + // valid assumption. Should probably be passed a list of aggregating subqueries that we are flattening so that we can check... 
+ return GroupBySelectClauseRewriter.ReWrite(expression.QueryModel.SelectClause.Selector, _groupBy, _model); + } + } +} \ No newline at end of file Added: trunk/nhibernate/src/NHibernate/Linq/GroupBy/NonAggregatingGroupByRewriter.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/GroupBy/NonAggregatingGroupByRewriter.cs (rev 0) +++ trunk/nhibernate/src/NHibernate/Linq/GroupBy/NonAggregatingGroupByRewriter.cs 2009-11-05 16:43:43 UTC (rev 4819) @@ -0,0 +1,108 @@ +using System; +using System.Linq; +using System.Linq.Expressions; +using NHibernate.Linq.ResultOperators; +using NHibernate.Linq.Visitors; +using Remotion.Data.Linq; +using Remotion.Data.Linq.Clauses; +using Remotion.Data.Linq.Clauses.Expressions; +using Remotion.Data.Linq.Clauses.ResultOperators; + +namespace NHibernate.Linq.GroupBy +{ + public class NonAggregatingGroupByRewriter + { + private NonAggregatingGroupByRewriter() + { + } + + public static void ReWrite(QueryModel queryModel) + { + var subQueryExpression = queryModel.MainFromClause.FromExpression as SubQueryExpression; + + if ((subQueryExpression != null) && + (subQueryExpression.QueryModel.ResultOperators.Count() == 1) && + (subQueryExpression.QueryModel.ResultOperators[0] is GroupResultOperator) && + (IsNonAggregatingGroupBy(queryModel))) + { + var rewriter = new NonAggregatingGroupByRewriter(); + rewriter.FlattenSubQuery(subQueryExpression, queryModel.MainFromClause, queryModel); + } + } + + private void FlattenSubQuery(SubQueryExpression subQueryExpression, MainFromClause fromClause, + QueryModel queryModel) + { + // Create a new client-side select for the outer + // TODO - don't like calling GetGenericArguments here... + var clientSideSelect = new ClientSideSelect(new NonAggregatingGroupBySelectRewriter().Visit(queryModel.SelectClause.Selector, subQueryExpression.Type.GetGenericArguments()[0], queryModel.MainFromClause)); + + // Replace the outer select clause... 
+ queryModel.SelectClause = subQueryExpression.QueryModel.SelectClause; + + MainFromClause innerMainFromClause = subQueryExpression.QueryModel.MainFromClause; + + CopyFromClauseData(innerMainFromClause, fromClause); + + foreach (var bodyClause in subQueryExpression.QueryModel.BodyClauses) + { + queryModel.BodyClauses.Add(bodyClause); + } + + // Move the result operator up + if (queryModel.ResultOperators.Count != 0) + { + throw new NotImplementedException(); + } + + queryModel.ResultOperators.Add(new NonAggregatingGroupBy((GroupResultOperator) subQueryExpression.QueryModel.ResultOperators[0])); + queryModel.ResultOperators.Add(clientSideSelect); + } + + protected void CopyFromClauseData(FromClauseBase source, FromClauseBase destination) + { + destination.FromExpression = source.FromExpression; + destination.ItemName = source.ItemName; + destination.ItemType = source.ItemType; + } + + private static bool IsNonAggregatingGroupBy(QueryModel queryModel) + { + return new GroupByAggregateDetectionVisitor().Visit(queryModel.SelectClause.Selector) == false; + } + } + + internal class NonAggregatingGroupBySelectRewriter : NhExpressionTreeVisitor + { + private ParameterExpression _inputParameter; + private IQuerySource _querySource; + + public LambdaExpression Visit(Expression clause, System.Type resultType, IQuerySource querySource) + { + _inputParameter = Expression.Parameter(resultType, "inputParameter"); + _querySource = querySource; + + return Expression.Lambda(VisitExpression(clause), _inputParameter); + } + + protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression) + { + if (expression.ReferencedQuerySource == _querySource) + { + return _inputParameter; + } + + return expression; + } + } + + internal class ClientSideSelect : ClientSideTransformOperator + { + public LambdaExpression SelectClause { get; private set; } + + public ClientSideSelect(LambdaExpression selectClause) + { + SelectClause = selectClause; + } + } 
+} \ No newline at end of file Added: trunk/nhibernate/src/NHibernate/Linq/GroupJoin/AggregatingGroupJoinRewriter.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/GroupJoin/AggregatingGroupJoinRewriter.cs (rev 0) +++ trunk/nhibernate/src/NHibernate/Linq/GroupJoin/AggregatingGroupJoinRewriter.cs 2009-11-05 16:43:43 UTC (rev 4819) @@ -0,0 +1,66 @@ +using System.Collections.Generic; +using System.Linq; +using Remotion.Data.Linq; +using Remotion.Data.Linq.Clauses; + +namespace NHibernate.Linq.GroupJoin +{ + /// <summary> + /// An AggregatingGroupJoin is a query such as: + /// + /// from c in db.Customers + /// join o in db.Orders on c.CustomerId equals o.Customer.CustomerId into ords + /// join e in db.Employees on c.Address.City equals e.Address.City into emps + /// select new { c.ContactName, ords = ords.Count(), emps = emps.Count() }; + /// + /// where the results of the joins are being fully aggregated and hence do not create any form of hierarchy. + /// This class takes such expressions and turns them into this form: + /// + /// from c in db.Customers + /// select new + /// { + /// c.ContactName, + /// ords = (from o2 in db.Orders where o2.Customer.CustomerId == c.CustomerId select o2).Count(), + /// emps = (from e2 in db.Employees where e2.Address.City == c.Address.City select e2).Count() + /// }; + /// + /// </summary> + public class AggregatingGroupJoinRewriter + { + private AggregatingGroupJoinRewriter() + { + } + + public static void ReWrite(QueryModel model) + { + // firstly, get the group join clauses + var groupJoin = model.BodyClauses.Where(bc => bc is GroupJoinClause).Cast<GroupJoinClause>(); + + if (groupJoin.Count() == 0) + { + // No group join here.. 
+ return; + } + + // Now walk the tree to decide which groupings are fully aggregated (and can hence be done in hql) + var aggregateDetectorResults = IsAggregatingGroupJoin(model, groupJoin); + + if (aggregateDetectorResults.AggregatingClauses.Count > 0) + { + // Re-write the select expression + model.SelectClause.TransformExpressions(s => GroupJoinSelectClauseRewriter.ReWrite(s, aggregateDetectorResults)); + + // Remove the aggregating group joins + foreach (GroupJoinClause aggregatingGroupJoin in aggregateDetectorResults.AggregatingClauses) + { + model.BodyClauses.Remove(aggregatingGroupJoin); + } + } + } + + private static IsAggregatingResults IsAggregatingGroupJoin(QueryModel model, IEnumerable<GroupJoinClause> clause) + { + return GroupJoinAggregateDetectionVisitor.Visit(clause, model.SelectClause.Selector); + } + } +} \ No newline at end of file Added: trunk/nhibernate/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionVisitor.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionVisitor.cs (rev 0) +++ trunk/nhibernate/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionVisitor.cs 2009-11-05 16:43:43 UTC (rev 4819) @@ -0,0 +1,116 @@ +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using NHibernate.Linq.Expressions; +using NHibernate.Linq.Visitors; +using Remotion.Data.Linq.Clauses; +using Remotion.Data.Linq.Clauses.Expressions; + +namespace NHibernate.Linq.GroupJoin +{ + internal class GroupJoinAggregateDetectionVisitor : NhExpressionTreeVisitor + { + private readonly HashSet<GroupJoinClause> _groupJoinClauses; + private readonly StackFlag _inAggregate = new StackFlag(); + private readonly StackFlag _parentExpressionProcessed = new StackFlag(); + + private readonly List<Expression> _nonAggregatingExpressions = new List<Expression>(); + private readonly List<GroupJoinClause> _nonAggregatingGroupJoins = new 
List<GroupJoinClause>(); + private readonly List<GroupJoinClause> _aggregatingGroupJoins = new List<GroupJoinClause>(); + + private GroupJoinAggregateDetectionVisitor(IEnumerable<GroupJoinClause> groupJoinClause) + { + _groupJoinClauses = new HashSet<GroupJoinClause>(groupJoinClause); + } + + public static IsAggregatingResults Visit(IEnumerable<GroupJoinClause> groupJoinClause, Expression selectExpression) + { + var visitor = new GroupJoinAggregateDetectionVisitor(groupJoinClause); + + visitor.VisitExpression(selectExpression); + + return new IsAggregatingResults { NonAggregatingClauses = visitor._nonAggregatingGroupJoins, AggregatingClauses = visitor._aggregatingGroupJoins, NonAggregatingExpressions = visitor._nonAggregatingExpressions }; + } + + protected override Expression VisitSubQueryExpression(SubQueryExpression expression) + { + VisitExpression(expression.QueryModel.SelectClause.Selector); + return expression; + } + + protected override Expression VisitNhAggregate(NhAggregatedExpression expression) + { + using (_inAggregate.SetFlag()) + { + return base.VisitNhAggregate(expression); + } + } + + protected override Expression VisitMemberExpression(MemberExpression expression) + { + if (_inAggregate.FlagIsFalse && _parentExpressionProcessed.FlagIsFalse) + { + _nonAggregatingExpressions.Add(expression); + } + + using (_parentExpressionProcessed.SetFlag()) + { + return base.VisitMemberExpression(expression); + } + } + + protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression) + { + var fromClause = (FromClauseBase) expression.ReferencedQuerySource; + + if (fromClause.FromExpression is QuerySourceReferenceExpression) + { + var querySourceReference = (QuerySourceReferenceExpression) fromClause.FromExpression; + + if (_groupJoinClauses.Contains(querySourceReference.ReferencedQuerySource as GroupJoinClause)) + { + if (_inAggregate.FlagIsFalse) + { + _nonAggregatingGroupJoins.Add((GroupJoinClause) 
querySourceReference.ReferencedQuerySource); + } + else + { + _aggregatingGroupJoins.Add((GroupJoinClause) querySourceReference.ReferencedQuerySource); + } + } + } + + return base.VisitQuerySourceReferenceExpression(expression); + } + + internal class StackFlag + { + public bool FlagIsTrue { get; private set; } + + public bool FlagIsFalse { get { return !FlagIsTrue; } } + + public IDisposable SetFlag() + { + return new StackFlagDisposable(this); + } + + internal class StackFlagDisposable : IDisposable + { + private readonly StackFlag _parent; + private readonly bool _old; + + public StackFlagDisposable(StackFlag parent) + { + _parent = parent; + _old = parent.FlagIsTrue; + parent.FlagIsTrue = true; + } + + public void Dispose() + { + _parent.FlagIsTrue = _old; + } + } + } + } +} \ No newline at end of file Added: trunk/nhibernate/src/NHibernate/Linq/GroupJoin/GroupJoinSelectClauseRewriter.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/GroupJoin/GroupJoinSelectClauseRewriter.cs (rev 0) +++ trunk/nhibernate/src/NHibernate/Linq/GroupJoin/GroupJoinSelectClauseRewriter.cs 2009-11-05 16:43:43 UTC (rev 4819) @@ -0,0 +1,52 @@ +using System.Linq.Expressions; +using NHibernate.Linq.Visitors; +using Remotion.Data.Linq; +using Remotion.Data.Linq.Clauses; +using Remotion.Data.Linq.Clauses.Expressions; + +namespace NHibernate.Linq.GroupJoin +{ + public class GroupJoinSelectClauseRewriter : NhExpressionTreeVisitor + { + private readonly IsAggregatingResults _results; + + public static Expression ReWrite(Expression expression, IsAggregatingResults results) + { + return new GroupJoinSelectClauseRewriter(results).VisitExpression(expression); + } + + private GroupJoinSelectClauseRewriter(IsAggregatingResults results) + { + _results = results; + } + + protected override Expression VisitSubQueryExpression(SubQueryExpression expression) + { + // If the sub query's main (and only) from clause is one of our aggregating group 
bys, then swap it + GroupJoinClause groupJoin = LocateGroupJoinQuerySource(expression.QueryModel); + + if (groupJoin != null) + { + Expression innerSelector = new SwapQuerySourceVisitor(groupJoin.JoinClause, expression.QueryModel.MainFromClause). + Swap(groupJoin.JoinClause.InnerKeySelector); + + expression.QueryModel.MainFromClause.FromExpression = groupJoin.JoinClause.InnerSequence; + + + // TODO - this only works if the key selectors are not composite. Needs improvement... + expression.QueryModel.BodyClauses.Add(new WhereClause(Expression.Equal(innerSelector, groupJoin.JoinClause.OuterKeySelector))); + } + + return expression; + } + + private GroupJoinClause LocateGroupJoinQuerySource(QueryModel model) + { + if (model.BodyClauses.Count > 0) + { + return null; + } + return new LocateGroupJoinQuerySource(_results).Detect(model.MainFromClause.FromExpression); + } + } +} \ No newline at end of file Added: trunk/nhibernate/src/NHibernate/Linq/GroupJoin/IsAggregatingResults.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/GroupJoin/IsAggregatingResults.cs (rev 0) +++ trunk/nhibernate/src/NHibernate/Linq/GroupJoin/IsAggregatingResults.cs 2009-11-05 16:43:43 UTC (rev 4819) @@ -0,0 +1,13 @@ +using System.Collections.Generic; +using System.Linq.Expressions; +using Remotion.Data.Linq.Clauses; + +namespace NHibernate.Linq.GroupJoin +{ + public class IsAggregatingResults + { + public List<GroupJoinClause> NonAggregatingClauses { get; set; } + public List<GroupJoinClause> AggregatingClauses { get; set; } + public List<Expression> NonAggregatingExpressions { get; set; } + } +} \ No newline at end of file Added: trunk/nhibernate/src/NHibernate/Linq/GroupJoin/LocateGroupJoinQuerySource.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/GroupJoin/LocateGroupJoinQuerySource.cs (rev 0) +++ 
trunk/nhibernate/src/NHibernate/Linq/GroupJoin/LocateGroupJoinQuerySource.cs 2009-11-05 16:43:43 UTC (rev 4819) @@ -0,0 +1,34 @@ +using System.Linq.Expressions; +using NHibernate.Linq.Visitors; +using Remotion.Data.Linq.Clauses; +using Remotion.Data.Linq.Clauses.Expressions; + +namespace NHibernate.Linq.GroupJoin +{ + public class LocateGroupJoinQuerySource : NhExpressionTreeVisitor + { + private readonly IsAggregatingResults _results; + private GroupJoinClause _groupJoin; + + public LocateGroupJoinQuerySource(IsAggregatingResults results) + { + _results = results; + } + + public GroupJoinClause Detect(Expression expression) + { + VisitExpression(expression); + return _groupJoin; + } + + protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression) + { + if (_results.AggregatingClauses.Contains(expression.ReferencedQuerySource as GroupJoinClause)) + { + _groupJoin = expression.ReferencedQuerySource as GroupJoinClause; + } + + return base.VisitQuerySourceReferenceExpression(expression); + } + } +} \ No newline at end of file Added: trunk/nhibernate/src/NHibernate/Linq/GroupJoin/NonAggregatingGroupJoinRewriter.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/GroupJoin/NonAggregatingGroupJoinRewriter.cs (rev 0) +++ trunk/nhibernate/src/NHibernate/Linq/GroupJoin/NonAggregatingGroupJoinRewriter.cs 2009-11-05 16:43:43 UTC (rev 4819) @@ -0,0 +1,203 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using NHibernate.Linq.GroupJoin; +using Remotion.Data.Linq; +using Remotion.Data.Linq.Clauses; +using Remotion.Data.Linq.Clauses.Expressions; + +namespace NHibernate.Linq.Visitors +{ + public class NonAggregatingGroupJoinRewriter + { + private readonly QueryModel _model; + private readonly IEnumerable<GroupJoinClause> _groupJoinClauses; + private QuerySourceUsageLocator _locator; + + private 
NonAggregatingGroupJoinRewriter(QueryModel model, IEnumerable<GroupJoinClause> groupJoinClauses) + { + _model = model; + _groupJoinClauses = groupJoinClauses; + } + + public static void ReWrite(QueryModel model) + { + // firstly, get the group join clauses + var groupJoinClauses = model.BodyClauses.Where(bc => bc is GroupJoinClause).Cast<GroupJoinClause>(); + + if (groupJoinClauses.Count() == 0) + { + // No group join here.. + return; + } + + var rewriter = new NonAggregatingGroupJoinRewriter(model, groupJoinClauses); + + rewriter.ReWrite(); + } + + private void ReWrite() + { + var aggregateDetectorResults = GetGroupJoinInformation(_groupJoinClauses); + + foreach (var nonAggregatingJoin in aggregateDetectorResults.NonAggregatingClauses) + { + // Group joins get processed (currently) in one of three ways. + // Option 1: results of group join are not further referenced outside of the final projection. + // In this case, replace the group join with a join, and add a client-side grouping operator + // to build the correct hierarchy + // + // Option 2: Results of group join are only used in a plain "from" expression, such as: + // from o in db.Orders + // from p in db.Products + // join d in db.OrderLines + // on new {o.OrderId, p.ProductId} equals new {d.Order.OrderId, d.Product.ProductId} + // into details + // from d in details + // select new {o.OrderId, p.ProductId, d.UnitPrice}; + // In this case, simply change the group join to a join; the results of the grouping are being + // removed by the subsequent "from" + // + // Option 3: Results of group join are only used in a "from ... 
DefaultIfEmpty()" construction, such as: + // from o in dc.Orders + // join v in dc.Vendors on o.VendorId equals v.Id into ov + // from x in ov.DefaultIfEmpty() + // join s in dc.Status on o.StatusId equals s.Id into os + // from y in os.DefaultIfEmpty() + // select new { o.OrderNumber, x.VendorName, y.StatusName } + // This is used to represent an outer join, and again the "from" is removing the hierarchy. So + // simply change the group join to an outer join + + _locator = new QuerySourceUsageLocator(nonAggregatingJoin); + + foreach (var bodyClause in _model.BodyClauses) + { + _locator.Search(bodyClause); + } + + if (IsHierarchicalJoin(nonAggregatingJoin)) + { + } + else if (IsFlattenedJoin(nonAggregatingJoin)) + { + ProcessFlattenedJoin(nonAggregatingJoin); + } + else if (IsOuterJoin(nonAggregatingJoin)) + { + + } + else + { + // Wonder what this is? + throw new NotSupportedException(); + } + } + } + + private void ProcessFlattenedJoin(GroupJoinClause nonAggregatingJoin) + { + // Need to: + // 1. Remove the group join and replace it with a join + // 2. Remove the corresponding "from" clause (the thing that was doing the flattening) + // 3. Rewrite the selector to reference the "join" rather than the "from" clause + SwapClause(nonAggregatingJoin, nonAggregatingJoin.JoinClause); + + // TODO - don't like use of _locator here; would rather we got this passed in. Ditto on next line (esp.
the cast) + _model.BodyClauses.Remove(_locator.Clauses[0]); + + var querySourceSwapper = new SwapQuerySourceVisitor((IQuerySource) _locator.Clauses[0], nonAggregatingJoin.JoinClause); + _model.SelectClause.TransformExpressions(querySourceSwapper.Swap); + } + + // TODO - store the indexes of the join clauses when we find them, then can remove this loop + private void SwapClause(IBodyClause oldClause, IBodyClause newClause) + { + for (int i = 0; i < _model.BodyClauses.Count; i++) + { + if (_model.BodyClauses[i] == oldClause) + { + _model.BodyClauses.RemoveAt(i); + _model.BodyClauses.Insert(i, newClause); + } + } + } + + private bool IsOuterJoin(GroupJoinClause nonAggregatingJoin) + { + return false; + } + + private bool IsFlattenedJoin(GroupJoinClause nonAggregatingJoin) + { + if (_locator.Clauses.Count == 1) + { + var from = _locator.Clauses[0] as AdditionalFromClause; + + if (from != null) + { + return true; + } + } + + return false; + } + + private bool IsHierarchicalJoin(GroupJoinClause nonAggregatingJoin) + { + return _locator.Clauses.Count == 0; + } + + // TODO - rename this and share with the AggregatingGroupJoinRewriter + private IsAggregatingResults GetGroupJoinInformation(IEnumerable<GroupJoinClause> clause) + { + return GroupJoinAggregateDetectionVisitor.Visit(clause, _model.SelectClause.Selector); + } + + } + + internal class QuerySourceUsageLocator : NhExpressionTreeVisitor + { + private readonly IQuerySource _querySource; + private bool _references; + private readonly List<IBodyClause> _clauses = new List<IBodyClause>(); + + public QuerySourceUsageLocator(IQuerySource querySource) + { + _querySource = querySource; + } + + public IList<IBodyClause> Clauses + { + get { return _clauses.AsReadOnly(); } + } + + public void Search(IBodyClause clause) + { + _references = false; + + clause.TransformExpressions(ExpressionSearcher); + + if (_references) + { + _clauses.Add(clause); + } + } + + private Expression ExpressionSearcher(Expression arg) + { + 
VisitExpression(arg); + return arg; + } + + protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression) + { + if (expression.ReferencedQuerySource == _querySource) + { + _references = true; + } + + return expression; + } + } +} \ No newline at end of file Deleted: trunk/nhibernate/src/NHibernate/Linq/NhNewExpression.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/NhNewExpression.cs 2009-11-04 10:55:10 UTC (rev 4818) +++ trunk/nhibernate/src/NHibernate/Linq/NhNewExpression.cs 2009-11-05 16:43:43 UTC (rev 4819) @@ -1,30 +0,0 @@ -using System.Collections.Generic; -using System.Collections.ObjectModel; -using System.Linq.Expressions; -using NHibernate.Linq.ReWriters; - -namespace NHibernate.Linq -{ - public class NhNewExpression : Expression - { - private readonly ReadOnlyCollection<string> _members; - private readonly ReadOnlyCollection<Expression> _arguments; - - public NhNewExpression(IList<string> members, IList<Expression> arguments) - : base((ExpressionType)NhExpressionType.New, typeof(object)) - { - _members = new ReadOnlyCollection<string>(members); - _arguments = new ReadOnlyCollection<Expression>(arguments); - } - - public ReadOnlyCollection<Expression> Arguments - { - get { return _arguments; } - } - - public ReadOnlyCollection<string> Members - { - get { return _members; } - } - } -} Modified: trunk/nhibernate/src/NHibernate/Linq/NhQueryProvider.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/NhQueryProvider.cs 2009-11-04 10:55:10 UTC (rev 4818) +++ trunk/nhibernate/src/NHibernate/Linq/NhQueryProvider.cs 2009-11-05 16:43:43 UTC (rev 4819) @@ -47,7 +47,7 @@ return new NhQueryable<T>(this, expression); } - void SetParameters(IQuery query, IDictionary<string, object> parameters) + static void SetParameters(IQuery query, IDictionary<string, object> parameters) { foreach (var 
parameterName in query.NamedParameters) { Deleted: trunk/nhibernate/src/NHibernate/Linq/ReWriters/AggregatingGroupByRewriter.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/ReWriters/AggregatingGroupByRewriter.cs 2009-11-04 10:55:10 UTC (rev 4818) +++ trunk/nhibernate/src/NHibernate/Linq/ReWriters/AggregatingGroupByRewriter.cs 2009-11-05 16:43:43 UTC (rev 4819) @@ -1,69 +0,0 @@ -using System; -using System.Linq; -using NHibernate.Linq.Visitors; -using Remotion.Data.Linq; -using Remotion.Data.Linq.Clauses; -using Remotion.Data.Linq.Clauses.Expressions; -using Remotion.Data.Linq.Clauses.ResultOperators; - -namespace NHibernate.Linq.ReWriters -{ - public class AggregatingGroupByRewriter - { - public void ReWrite(QueryModel queryModel) - { - var subQueryExpression = queryModel.MainFromClause.FromExpression as SubQueryExpression; - - if ((subQueryExpression != null) && - (subQueryExpression.QueryModel.ResultOperators.Count() == 1) && - (subQueryExpression.QueryModel.ResultOperators[0] is GroupResultOperator) && - (IsAggregatingGroupBy(queryModel))) - { - FlattenSubQuery(subQueryExpression, queryModel.MainFromClause, queryModel); - } - } - - private static bool IsAggregatingGroupBy(QueryModel queryModel) - { - return new GroupByAggregateDetectionVisitor().Visit(queryModel.SelectClause.Selector); - } - - private void FlattenSubQuery(SubQueryExpression subQueryExpression, FromClauseBase fromClause, - QueryModel queryModel) - { - // Move the result operator up - if (queryModel.ResultOperators.Count != 0) - { - throw new NotImplementedException(); - } - - var groupBy = (GroupResultOperator) subQueryExpression.QueryModel.ResultOperators[0]; - - // Replace the outer select clause... 
- queryModel.SelectClause.TransformExpressions(s => GroupBySelectClauseRewriter.ReWrite(s, groupBy, subQueryExpression.QueryModel)); - - queryModel.SelectClause.TransformExpressions( - s => - new SwapQuerySourceVisitor(queryModel.MainFromClause, subQueryExpression.QueryModel.MainFromClause).Swap - (s)); - - - MainFromClause innerMainFromClause = subQueryExpression.QueryModel.MainFromClause; - CopyFromClauseData(innerMainFromClause, fromClause); - - foreach (var bodyClause in subQueryExpression.QueryModel.BodyClauses) - { - queryModel.BodyClauses.Add(bodyClause); - } - - queryModel.ResultOperators.Add(groupBy); - } - - protected void CopyFromClauseData(FromClauseBase source, FromClauseBase destination) - { - destination.FromExpression = source.FromExpression; - destination.ItemName = source.ItemName; - destination.ItemType = source.ItemType; - } - } -} \ No newline at end of file Deleted: trunk/nhibernate/src/NHibernate/Linq/ReWriters/AggregatingGroupJoinRewriter.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/ReWriters/AggregatingGroupJoinRewriter.cs 2009-11-04 10:55:10 UTC (rev 4818) +++ trunk/nhibernate/src/NHibernate/Linq/ReWriters/AggregatingGroupJoinRewriter.cs 2009-11-05 16:43:43 UTC (rev 4819) @@ -1,322 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Linq.Expressions; -using NHibernate.Linq.Visitors; -using Remotion.Data.Linq; -using Remotion.Data.Linq.Clauses; -using Remotion.Data.Linq.Clauses.Expressions; - -namespace NHibernate.Linq.ReWriters -{ - public class AggregatingGroupJoinRewriter - { - public void ReWrite(QueryModel model) - { - // We want to take queries like this: - - //var q = - // from c in db.Customers - // join o in db.Orders on c.CustomerId equals o.Customer.CustomerId into ords - // join e in db.Employees on c.Address.City equals e.Address.City into emps - // select new { c.ContactName, ords = ords.Count(), emps = emps.Count() }; - 
- // and turn them into this: - - //var q = - // from c in db.Customers - // select new - // { - // c.ContactName, - // ords = (from o2 in db.Orders where o2.Customer.CustomerId == c.CustomerId select o2).Count(), - // emps = (from e2 in db.Employees where e2.Address.City == c.Address.City select e2).Count() - // }; - - // so spot a group join where every use of the grouping in the selector is an aggregate - - // firstly, get the group join clauses - var groupJoin = model.BodyClauses.Where(bc => bc is GroupJoinClause).Cast<GroupJoinClause>(); - - if (groupJoin.Count() == 0) - { - // No group join here.. - return; - } - - // Now walk the tree to decide which groupings are fully aggregated (and can hence be done in hql) - var aggregateDetectorResults = IsAggregatingGroupJoin(model, groupJoin); - - if (aggregateDetectorResults.AggregatingClauses.Count > 0) - { - // Re-write the select expression - model.SelectClause.TransformExpressions(s => GroupJoinSelectClauseRewriter.ReWrite(s, aggregateDetectorResults)); - - // Remove the aggregating group joins - foreach (GroupJoinClause aggregatingGroupJoin in aggregateDetectorResults.AggregatingClauses) - { - model.BodyClauses.Remove(aggregatingGroupJoin); - } - } - } - - private static IsAggregatingResults IsAggregatingGroupJoin(QueryModel model, IEnumerable<GroupJoinClause> clause) - { - return new GroupJoinAggregateDetectionVisitor(clause).Visit(model.SelectClause.Selector); - } - } - - public class GroupJoinSelectClauseRewriter : NhExpressionTreeVisitor - { - private readonly IsAggregatingResults _results; - - public static Expression ReWrite(Expression expression, IsAggregatingResults results) - { - return new GroupJoinSelectClauseRewriter(results).VisitExpression(expression); - } - - private GroupJoinSelectClauseRewriter(IsAggregatingResults results) - { - _results = results; - } - - protected override Expression VisitSubQueryExpression(SubQueryExpression expression) - { - // If the sub queries main (and only) from 
clause is one of our aggregating group bys, then swap it - GroupJoinClause groupJoin = LocateGroupJoinQuerySource(expression.QueryModel); - - if (groupJoin != null) - { - Expression innerSelector = new SwapQuerySourceVisitor(groupJoin.JoinClause, expression.QueryModel.MainFromClause). - Swap(groupJoin.JoinClause.InnerKeySelector); - - expression.QueryModel.MainFromClause.FromExpression = groupJoin.JoinClause.InnerSequence; - - - // TODO - this only works if the key selectors are not composite. Needs improvement... - expression.QueryModel.BodyClauses.Add(new WhereClause(Expression.Equal(innerSelector, groupJoin.JoinClause.OuterKeySelector))); - } - - return expression; - } - - private GroupJoinClause LocateGroupJoinQuerySource(QueryModel model) - { - if (model.BodyClauses.Count > 0) - { - return null; - } - return new LocateGroupJoinQuerySource(_results).Detect(model.MainFromClause.FromExpression); - } - } - - public class SwapQuerySourceVisitor : NhExpressionTreeVisitor - { - private readonly IQuerySource _oldClause; - private readonly IQuerySource _newClause; - - public SwapQuerySourceVisitor(IQuerySource oldClause, IQuerySource newClause) - { - _oldClause = oldClause; - _newClause = newClause; - } - - public Expression Swap(Expression expression) - { - return VisitExpression(expression); - } - - protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression) - { - if (expression.ReferencedQuerySource == _oldClause) - { - return new QuerySourceReferenceExpression(_newClause); - } - - // TODO - really don't like this drill down approach. 
Feels fragile - var mainFromClause = expression.ReferencedQuerySource as MainFromClause; - - if (mainFromClause != null) - { - mainFromClause.FromExpression = VisitExpression(mainFromClause.FromExpression); - } - - return expression; - } - - protected override Expression VisitSubQueryExpression(SubQueryExpression expression) - { - expression.QueryModel.TransformExpressions(VisitExpression); - return base.VisitSubQueryExpression(expression); - } - } - - public class LocateGroupJoinQuerySource : NhExpressionTreeVisitor - { - private readonly IsAggregatingResults _results; - private GroupJoinClause _groupJoin; - - public LocateGroupJoinQuerySource(IsAggregatingResults results) - { - _results = results; - } - - public GroupJoinClause Detect(Expression expression) - { - VisitExpression(expression); - return _groupJoin; - } - - protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression) - { - if (_results.AggregatingClauses.Contains(expression.ReferencedQuerySource as GroupJoinClause)) - { - _groupJoin = expression.ReferencedQuerySource as GroupJoinClause; - } - - return base.VisitQuerySourceReferenceExpression(expression); - } - } - - public class IsAggregatingResults - { - public List<GroupJoinClause> NonAggregatingClauses { get; set; } - public List<GroupJoinClause> AggregatingClauses { get; set; } - public List<Expression> NonAggregatingExpressions { get; set; } - } - - internal class GroupJoinAggregateDetectionVisitor : NhExpressionTreeVisitor - { - private readonly HashSet<GroupJoinClause> _groupJoinClauses; - private readonly StackFlag _inAggregate = new StackFlag(); - private readonly StackFlag _parentExpressionProcessed = new StackFlag(); - - private readonly List<Expression> _nonAggregatingExpressions = new List<Expression>(); - private readonly List<GroupJoinClause> _nonAggregatingGroupJoins = new List<GroupJoinClause>(); - private readonly List<GroupJoinClause> _aggregatingGroupJoins = new 
List<GroupJoinClause>(); - - public GroupJoinAggregateDetectionVisitor(IEnumerable<GroupJoinClause> groupJoinClause) - { - _groupJoinClauses = new HashSet<GroupJoinClause>(groupJoinClause); - } - - public IsAggregatingResults Visit(Expression expression) - { - VisitExpression(expression); - - return new IsAggregatingResults { NonAggregatingClauses = _nonAggregatingGroupJoins, AggregatingClauses = _aggregatingGroupJoins, NonAggregatingExpressions = _nonAggregatingExpressions }; - } - - protected override Expression VisitSubQueryExpression(SubQueryExpression expression) - { - VisitExpression(expression.QueryModel.SelectClause.Selector); - return expression; - } - - protected override Expression VisitNhAverage(NhAverageExpression expression) - { - using (_inAggregate.SetFlag()) - { - return base.VisitNhAverage(expression); - } - } - - protected override Expression VisitNhCount(NhCountExpression expression) - { - using (_inAggregate.SetFlag()) - { - return base.VisitNhCount(expression); - } - } - - protected override Expression VisitNhMax(NhMaxExpression expression) - { - using (_inAggregate.SetFlag()) - { - return base.VisitNhMax(expression); - } - } - - protected override Expression VisitNhMin(NhMinExpression expression) - { - using (_inAggregate.SetFlag()) - { - return base.VisitNhMin(expression); - } - } - - protected override Expression VisitNhSum(NhSumExpression expression) - { - using (_inAggregate.SetFlag()) - { - return base.VisitNhSum(expression); - } - } - - protected override Expression VisitMemberExpression(MemberExpression expression) - { - if (_inAggregate.FlagIsFalse && _parentExpressionProcessed.FlagIsFalse) - { - _nonAggregatingExpressions.Add(expression); - } - - using (_parentExpressionProcessed.SetFlag()) - { - return base.VisitMemberExpression(expression); - } - } - - protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression) - { - var fromClause = (FromClauseBase) expression.ReferencedQuerySource; 
- - if (fromClause.FromExpression is QuerySourceReferenceExpression) - { - var querySourceReference = (QuerySourceReferenceExpression) fromClause.FromExpression; - - if (_groupJoinClauses.Contains(querySourceReference.ReferencedQuerySource as GroupJoinClause)) - { - if (_inAggregate.FlagIsFalse) - { - _nonAggregatingGroupJoins.Add((GroupJoinClause) querySourceReference.ReferencedQuerySource); - } - else - { - _aggregatingGroupJoins.Add((GroupJoinClause) querySourceReference.ReferencedQuerySource); - } - } - } - - return base.VisitQuerySourceReferenceExpression(expression); - } - - internal class StackFlag - { - public bool FlagIsTrue { get; private set; } - - public bool FlagIsFalse { get { return !FlagIsTrue; } } - - public IDisposable SetFlag() - { - return new StackFlagDisposable(this); - } - - internal class StackFlagDisposable : IDisposable - { - private readonly StackFlag _parent; - private readonly bool _old; - - public StackFlagDisposable(StackFlag parent) - { - _parent = parent; - _old = parent.FlagIsTrue; - parent.FlagIsTrue = true; - } - - public void Dispose() - { - _parent.FlagIsTrue = _old; - } - } - } - } -} \ No newline at end of file Deleted: trunk/nhibernate/src/NHibernate/Linq/ReWriters/GroupBySelectClauseRewriter.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/ReWriters/GroupBySelectClauseRewriter.cs 2009-11-04 10:55:10 UTC (rev 4818) +++ trunk/nhibernate/src/NHibernate/Linq/ReWriters/GroupBySelectClauseRewriter.cs 2009-11-05 16:43:43 UTC (rev 4819) @@ -1,203 +0,0 @@ -using System; -using System.Linq.Expressions; -using NHibernate.Linq.Visitors; -using Remotion.Data.Linq; -using Remotion.Data.Linq.Clauses; -using Remotion.Data.Linq.Clauses.Expressions; -using Remotion.Data.Linq.Clauses.ResultOperators; - -namespace NHibernate.Linq.ReWriters -{ - internal class GroupBySelectClauseRewriter : NhExpressionTreeVisitor - { - public static Expression ReWrite(Expression expression, ... 
[truncated message content] |