|
From: <ste...@us...> - 2009-09-15 14:26:12
|
Revision: 4712
http://nhibernate.svn.sourceforge.net/nhibernate/?rev=4712&view=rev
Author: steverstrong
Date: 2009-09-15 14:26:02 +0000 (Tue, 15 Sep 2009)
Log Message:
-----------
Modified Paths:
--------------
trunk/nhibernate/src/NHibernate/Hql/Ast/ANTLR/Tree/DotNode.cs
trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs
trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeNode.cs
trunk/nhibernate/src/NHibernate/Linq/CommandData.cs
trunk/nhibernate/src/NHibernate/Linq/HqlNodeStack.cs
trunk/nhibernate/src/NHibernate/Linq/NhExpressionTreeVisitor.cs
trunk/nhibernate/src/NHibernate/Linq/Nominator.cs
trunk/nhibernate/src/NHibernate/Linq/ProjectionEvaluator.cs
trunk/nhibernate/src/NHibernate/Linq/QueryModelVisitor.cs
trunk/nhibernate/src/NHibernate/Linq/ResultTransformer.cs
trunk/nhibernate/src/NHibernate/NHibernate.csproj
trunk/nhibernate/src/NHibernate/Type/TypeFactory.cs
trunk/nhibernate/src/NHibernate.Test/Linq/LinqQuerySamples.cs
trunk/nhibernate/src/NHibernate.Test/Linq/LinqTestCase.cs
trunk/nhibernate/src/NHibernate.Test/Linq/Mappings/Product.hbm.xml
trunk/nhibernate/src/NHibernate.Test/Linq/Mappings/ProductCategory.hbm.xml
trunk/nhibernate/src/NHibernate.Test/Linq/ReadonlyTestCase.cs
Added Paths:
-----------
trunk/nhibernate/src/NHibernate/Linq/AggregateDetectionVisitor.cs
trunk/nhibernate/src/NHibernate/Linq/AggregatingGroupByRewriter.cs
trunk/nhibernate/src/NHibernate/Linq/ClientSideTransformOperator.cs
trunk/nhibernate/src/NHibernate/Linq/EnumerableHelper.cs
trunk/nhibernate/src/NHibernate/Linq/GroupBySelectClauseVisitor.cs
trunk/nhibernate/src/NHibernate/Linq/GroupBySelectorVisitor.cs
trunk/nhibernate/src/NHibernate/Linq/HqlGeneratorExpressionTreeVisitor.cs
trunk/nhibernate/src/NHibernate/Linq/NhThrowingExpressionTreeVisitor.cs
trunk/nhibernate/src/NHibernate/Linq/NonAggregatingGroupBy.cs
trunk/nhibernate/src/NHibernate/Linq/NonAggregatingGroupByRewriter.cs
Property Changed:
----------------
trunk/nhibernate/src/NHibernate/Linq/
Modified: trunk/nhibernate/src/NHibernate/Hql/Ast/ANTLR/Tree/DotNode.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Hql/Ast/ANTLR/Tree/DotNode.cs 2009-09-10 18:24:03 UTC (rev 4711)
+++ trunk/nhibernate/src/NHibernate/Hql/Ast/ANTLR/Tree/DotNode.cs 2009-09-15 14:26:02 UTC (rev 4712)
@@ -496,11 +496,11 @@
//
///////////////////////////////////////////////////////////////////////////////
- bool found = elem != null;
- // even though we might find a pre-existing element by join path, for FromElements originating in a from-clause
- // we should only ever use the found element if the aliases match (null != null here). Implied joins are
- // always (?) ok to reuse.
- bool useFoundFromElement = found && ( elem.IsImplied || ( AreSame(classAlias, elem.ClassAlias ) ) );
+ bool found = elem != null;
+ // even though we might find a pre-existing element by join path, for FromElements originating in a from-clause
+ // we should only ever use the found element if the aliases match (null != null here). Implied joins are
+ // always (?) ok to reuse.
+ bool useFoundFromElement = found && ( elem.IsImplied || ( AreSame(classAlias, elem.ClassAlias ) ) );
if ( ! useFoundFromElement )
{
@@ -537,10 +537,10 @@
FromElement = elem; // This 'dot' expression now refers to the resulting from element.
}
- private bool AreSame(String alias1, String alias2) {
- // again, null != null here
- return !StringHelper.IsEmpty( alias1 ) && !StringHelper.IsEmpty( alias2 ) && alias1.Equals( alias2 );
- }
+ private bool AreSame(String alias1, String alias2) {
+ // again, null != null here
+ return !StringHelper.IsEmpty( alias1 ) && !StringHelper.IsEmpty( alias2 ) && alias1.Equals( alias2 );
+ }
private void SetImpliedJoin(FromElement elem)
{
Modified: trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs 2009-09-10 18:24:03 UTC (rev 4711)
+++ trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs 2009-09-15 14:26:02 UTC (rev 4712)
@@ -8,7 +8,7 @@
{
public class HqlTreeBuilder
{
- private readonly ASTFactory _factory;
+ private readonly IASTFactory _factory;
public HqlTreeBuilder()
{
@@ -85,6 +85,11 @@
return new HqlIdent(_factory, ident);
}
+ public HqlIdent Ident(System.Type type)
+ {
+ return new HqlIdent(_factory, type);
+ }
+
public HqlAlias Alias(string alias)
{
return new HqlAlias(_factory, alias);
@@ -179,6 +184,11 @@
}
}
+ public HqlOrderBy OrderBy()
+ {
+ return new HqlOrderBy(_factory);
+ }
+
public HqlOrderBy OrderBy(HqlTreeNode expression, HqlDirection hqlDirection)
{
return new HqlOrderBy(_factory, expression, hqlDirection);
@@ -249,6 +259,11 @@
return new HqlGreaterThanOrEqual(_factory);
}
+ public HqlCount Count()
+ {
+ return new HqlCount(_factory);
+ }
+
public HqlCount Count(HqlTreeNode child)
{
return new HqlCount(_factory, child);
@@ -299,11 +314,21 @@
return new HqlMin(_factory);
}
+ public HqlMin Min(HqlTreeNode expression)
+ {
+ return new HqlMin(_factory, expression);
+ }
+
public HqlMax Max()
{
return new HqlMax(_factory);
}
+ public HqlMax Max(HqlTreeNode expression)
+ {
+ return new HqlMax(_factory, expression);
+ }
+
public HqlAnd And(HqlTreeNode left, HqlTreeNode right)
{
return new HqlAnd(_factory, left, right);
@@ -328,5 +353,26 @@
{
return new HqlElements(_factory);
}
+
+ public HqlDistinct Distinct()
+ {
+ return new HqlDistinct(_factory);
+ }
+
+ public HqlDirectionAscending Ascending()
+ {
+ return new HqlDirectionAscending(_factory);
+ }
+
+ public HqlDirectionDescending Descending()
+ {
+ return new HqlDirectionDescending(_factory);
+ }
+
+ public HqlGroupBy GroupBy()
+ {
+ return new HqlGroupBy(_factory);
+ }
}
+
}
\ No newline at end of file
Modified: trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeNode.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeNode.cs 2009-09-10 18:24:03 UTC (rev 4711)
+++ trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeNode.cs 2009-09-15 14:26:02 UTC (rev 4712)
@@ -82,6 +82,13 @@
_children.Add(child);
_node.AddChild(child.AstNode);
}
+
+ public void AddChild(int index, HqlTreeNode node)
+ {
+ _children.Insert(index, node);
+ _node.InsertChild(index, node.AstNode);
+ }
+
}
public class HqlQuery : HqlTreeNode
@@ -99,6 +106,41 @@
: base(HqlSqlWalker.IDENT, ident, factory)
{
}
+
+ internal HqlIdent(IASTFactory factory, System.Type type)
+ : base(HqlSqlWalker.IDENT, "", factory)
+ {
+ if (IsNullableType(type))
+ {
+ type = ExtractUnderlyingTypeFromNullable(type);
+ }
+
+ switch (System.Type.GetTypeCode(type))
+ {
+ case TypeCode.Int32:
+ _node.Text = "integer";
+ break;
+ case TypeCode.Decimal:
+ _node.Text = "decimal";
+ break;
+ case TypeCode.DateTime:
+ _node.Text = "datetime";
+ break;
+ default:
+ throw new NotSupportedException(string.Format("Don't currently support idents of type {0}", type.Name));
+ }
+ }
+
+ private static System.Type ExtractUnderlyingTypeFromNullable(System.Type type)
+ {
+ return type.GetGenericArguments()[0];
+ }
+
+ private static bool IsNullableType(System.Type type)
+ {
+ return (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>));
+ }
+
}
public class HqlRange : HqlTreeNode
@@ -276,6 +318,11 @@
public class HqlOrderBy : HqlTreeNode
{
+ public HqlOrderBy(IASTFactory factory)
+ : base(HqlSqlWalker.ORDER, "", factory)
+ {
+ }
+
public HqlOrderBy(IASTFactory factory, HqlTreeNode expression, HqlDirection hqlDirection)
: base(HqlSqlWalker.ORDER, "", factory, expression,
hqlDirection == HqlDirection.Ascending ?
@@ -405,6 +452,12 @@
public class HqlCount : HqlTreeNode
{
+
+ public HqlCount(IASTFactory factory)
+ : base(HqlSqlWalker.COUNT, "count", factory)
+ {
+ }
+
public HqlCount(IASTFactory factory, HqlTreeNode child)
: base(HqlSqlWalker.COUNT, "count", factory, child)
{
@@ -430,41 +483,9 @@
{
public HqlCast(IASTFactory factory, HqlTreeNode expression, System.Type type) : base(HqlSqlWalker.METHOD_CALL, "method", factory)
{
- HqlIdent typeIdent;
-
- if (IsNullableType(type))
- {
- type = ExtractUnderlyingTypeFromNullable(type);
- }
-
- switch (System.Type.GetTypeCode(type))
- {
- case TypeCode.Int32:
- typeIdent = new HqlIdent(factory, "integer");
- break;
- case TypeCode.Decimal:
- typeIdent = new HqlIdent(factory, "decimal");
- break;
- case TypeCode.DateTime:
- typeIdent = new HqlIdent(factory, "datetime");
- break;
- default:
- throw new NotSupportedException(string.Format("Don't currently support casts to {0}", type.Name));
- }
-
AddChild(new HqlIdent(factory, "cast"));
- AddChild(new HqlExpressionList(factory, expression, typeIdent));
+ AddChild(new HqlExpressionList(factory, expression, new HqlIdent(factory, type)));
}
-
- private static System.Type ExtractUnderlyingTypeFromNullable(System.Type type)
- {
- return type.GetGenericArguments()[0];
- }
-
- private static bool IsNullableType(System.Type type)
- {
- return (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>));
- }
}
public class HqlExpressionList : HqlTreeNode
@@ -518,14 +539,24 @@
public HqlMax(IASTFactory factory) : base(HqlSqlWalker.AGGREGATE, "max", factory)
{
}
- }
+ public HqlMax(IASTFactory factory, HqlTreeNode expression)
+ : base(HqlSqlWalker.AGGREGATE, "max", factory, expression)
+ {
+ }
+}
+
public class HqlMin : HqlTreeNode
{
public HqlMin(IASTFactory factory)
: base(HqlSqlWalker.AGGREGATE, "min", factory)
{
}
+
+ public HqlMin(IASTFactory factory, HqlTreeNode expression)
+ : base(HqlSqlWalker.AGGREGATE, "min", factory, expression)
+ {
+ }
}
public class HqlAnd : HqlTreeNode
@@ -563,4 +594,18 @@
}
}
+ public class HqlDistinct : HqlTreeNode
+ {
+ public HqlDistinct(IASTFactory factory) : base(HqlSqlWalker.DISTINCT, "distinct", factory)
+ {
+ }
+ }
+
+ public class HqlGroupBy : HqlTreeNode
+ {
+ public HqlGroupBy(IASTFactory factory) : base(HqlSqlWalker.GROUP, "group by", factory)
+ {
+ }
+ }
+
}
\ No newline at end of file
Property changes on: trunk/nhibernate/src/NHibernate/Linq
___________________________________________________________________
Added: svn:ignore
+ _ReSharper.TempSolution
Added: trunk/nhibernate/src/NHibernate/Linq/AggregateDetectionVisitor.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/AggregateDetectionVisitor.cs (rev 0)
+++ trunk/nhibernate/src/NHibernate/Linq/AggregateDetectionVisitor.cs 2009-09-15 14:26:02 UTC (rev 4712)
@@ -0,0 +1,54 @@
+using System.Linq;
+using System.Linq.Expressions;
+using Remotion.Data.Linq.Clauses.Expressions;
+using Remotion.Data.Linq.Clauses.ResultOperators;
+using Remotion.Data.Linq.Parsing;
+
+namespace NHibernate.Linq
+{
+ // TODO: This needs strengthening. For example, it doesn't recurse into SubQueries at present
+ /// <summary>
+ /// Walks an expression tree and records whether it contains an aggregate: a Count/Min/Max/Sum/Average
+ /// call declared on Queryable or Enumerable, or a sub-query whose only result operator derives from
+ /// ValueFromSequenceResultOperatorBase.
+ /// </summary>
+ internal class AggregateDetectionVisitor : ExpressionTreeVisitor
+ {
+ public bool ContainsAggregateMethods { get; private set; }
+
+ /// <summary>
+ /// Clears the flag, visits <paramref name="expression"/>, and reports whether an aggregate was seen.
+ /// </summary>
+ public bool Visit(Expression expression)
+ {
+ ContainsAggregateMethods = false;
+ VisitExpression(expression);
+ return ContainsAggregateMethods;
+ }
+
+ protected override Expression VisitMethodCallExpression(MethodCallExpression m)
+ {
+ if (IsAggregateCall(m))
+ {
+ ContainsAggregateMethods = true;
+ }
+
+ return base.VisitMethodCallExpression(m);
+ }
+
+ protected override Expression VisitSubQueryExpression(SubQueryExpression expression)
+ {
+ var resultOperators = expression.QueryModel.ResultOperators;
+ bool producesSingleValue = resultOperators.Count == 1
+ && typeof(ValueFromSequenceResultOperatorBase).IsAssignableFrom(resultOperators[0].GetType());
+
+ if (producesSingleValue)
+ {
+ ContainsAggregateMethods = true;
+ }
+
+ return base.VisitSubQueryExpression(expression);
+ }
+
+ // True when the call is one of the aggregate operators declared on Queryable or Enumerable.
+ private static bool IsAggregateCall(MethodCallExpression call)
+ {
+ var declaringType = call.Method.DeclaringType;
+ if (declaringType != typeof (Queryable) && declaringType != typeof (Enumerable))
+ {
+ return false;
+ }
+
+ string methodName = call.Method.Name;
+ return methodName == "Count"
+ || methodName == "Min"
+ || methodName == "Max"
+ || methodName == "Sum"
+ || methodName == "Average";
+ }
+ }
+}
\ No newline at end of file
Added: trunk/nhibernate/src/NHibernate/Linq/AggregatingGroupByRewriter.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/AggregatingGroupByRewriter.cs (rev 0)
+++ trunk/nhibernate/src/NHibernate/Linq/AggregatingGroupByRewriter.cs 2009-09-15 14:26:02 UTC (rev 4712)
@@ -0,0 +1,57 @@
+using System;
+using System.Linq;
+using Remotion.Data.Linq;
+using Remotion.Data.Linq.Clauses;
+using Remotion.Data.Linq.Clauses.Expressions;
+using Remotion.Data.Linq.Clauses.ResultOperators;
+
+namespace NHibernate.Linq
+{
+ /// <summary>
+ /// Query-model rewriter for a query whose main from clause is a group-by sub-query
+ /// (a sub-query whose single result operator is a GroupResultOperator) and whose select clause
+ /// contains aggregates (as detected by AggregateDetectionVisitor). Such a query is flattened:
+ /// the outer select is rewritten by GroupBySelectClauseVisitor, the outer from clause takes over
+ /// the inner query's source, and the group operator moves up to the outer query.
+ /// </summary>
+ public class AggregatingGroupByRewriter : QueryModelVisitorBase
+ {
+ public override void VisitMainFromClause(MainFromClause fromClause, QueryModel queryModel)
+ {
+ var subQueryExpression = fromClause.FromExpression as SubQueryExpression;
+
+ if ((subQueryExpression != null) &&
+ (subQueryExpression.QueryModel.ResultOperators.Count == 1) &&
+ (subQueryExpression.QueryModel.ResultOperators[0] is GroupResultOperator) &&
+ (IsAggregatingGroupBy(queryModel)))
+ {
+ FlattenSubQuery(subQueryExpression, fromClause, queryModel);
+ }
+
+ base.VisitMainFromClause(fromClause, queryModel);
+ }
+
+ // True when the outer select clause contains an aggregate.
+ private static bool IsAggregatingGroupBy(QueryModel queryModel)
+ {
+ return new AggregateDetectionVisitor().Visit(queryModel.SelectClause.Selector);
+ }
+
+ /// <summary>
+ /// Merges the group-by sub-query into <paramref name="queryModel"/>.
+ /// </summary>
+ /// <exception cref="NotImplementedException">
+ /// The outer query already has result operators. This is checked before anything is modified,
+ /// so the query model is left untouched when the exception is thrown.
+ /// </exception>
+ private void FlattenSubQuery(SubQueryExpression subQueryExpression, FromClauseBase fromClause,
+ QueryModel queryModel)
+ {
+ // Validate first: previously the select and from clauses were rewritten before this check,
+ // leaving a half-flattened query model behind when it failed.
+ if (queryModel.ResultOperators.Count != 0)
+ {
+ throw new NotImplementedException("Flattening an aggregating group-by sub-query is not supported when the outer query already has result operators.");
+ }
+
+ // Replace the outer select clause. This must run before CopyFromClauseData: the visitor
+ // resolves "g.Key" through the outer from clause's current (sub-query) FromExpression.
+ queryModel.SelectClause.TransformExpressions(GroupBySelectClauseVisitor.Visit);
+
+ MainFromClause innerMainFromClause = subQueryExpression.QueryModel.MainFromClause;
+ CopyFromClauseData(innerMainFromClause, fromClause);
+
+ // Move the group result operator up to the outer query.
+ queryModel.ResultOperators.Add(subQueryExpression.QueryModel.ResultOperators[0]);
+ }
+
+ // Copies the source expression, item name and item type from one from clause to another.
+ protected void CopyFromClauseData(FromClauseBase source, FromClauseBase destination)
+ {
+ destination.FromExpression = source.FromExpression;
+ destination.ItemName = source.ItemName;
+ destination.ItemType = source.ItemType;
+ }
+ }
+}
\ No newline at end of file
Added: trunk/nhibernate/src/NHibernate/Linq/ClientSideTransformOperator.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/ClientSideTransformOperator.cs (rev 0)
+++ trunk/nhibernate/src/NHibernate/Linq/ClientSideTransformOperator.cs 2009-09-15 14:26:02 UTC (rev 4712)
@@ -0,0 +1,24 @@
+using System;
+using Remotion.Data.Linq.Clauses;
+using Remotion.Data.Linq.Clauses.StreamedData;
+
+namespace NHibernate.Linq
+{
+ /// <summary>
+ /// Placeholder result operator. None of its members are implemented yet; every one of them
+ /// throws <see cref="NotImplementedException"/>.
+ /// </summary>
+ public class ClientSideTransformOperator : ResultOperatorBase
+ {
+ public override IStreamedData ExecuteInMemory(IStreamedData input)
+ {
+ throw NotYetImplemented();
+ }
+
+ public override IStreamedDataInfo GetOutputDataInfo(IStreamedDataInfo inputInfo)
+ {
+ throw NotYetImplemented();
+ }
+
+ public override ResultOperatorBase Clone(CloneContext cloneContext)
+ {
+ throw NotYetImplemented();
+ }
+
+ // Single place that builds the exception raised by each unimplemented member.
+ private static Exception NotYetImplemented()
+ {
+ return new NotImplementedException();
+ }
+ }
+}
\ No newline at end of file
Modified: trunk/nhibernate/src/NHibernate/Linq/CommandData.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/CommandData.cs 2009-09-10 18:24:03 UTC (rev 4711)
+++ trunk/nhibernate/src/NHibernate/Linq/CommandData.cs 2009-09-15 14:26:02 UTC (rev 4712)
@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
+using System.Linq;
using System.Linq.Expressions;
using NHibernate.Hql.Ast;
@@ -7,37 +8,83 @@
{
public class CommandData
{
- public CommandData(HqlQuery statement, NamedParameter[] namedParameters, LambdaExpression projectionExpression, List<Action<IQuery>> additionalCriteria)
+ private readonly List<LambdaExpression> _itemTransformers;
+ private readonly List<LambdaExpression> _listTransformers;
+
+ public CommandData(HqlQuery statement, NamedParameter[] namedParameters, List<LambdaExpression> itemTransformers, List<LambdaExpression> listTransformers, List<Action<IQuery>> additionalCriteria)
{
+ _itemTransformers = itemTransformers;
+ _listTransformers = listTransformers;
+
Statement = statement;
NamedParameters = namedParameters;
- ProjectionExpression = projectionExpression;
AdditionalCriteria = additionalCriteria;
}
public HqlQuery Statement { get; private set; }
public NamedParameter[] NamedParameters { get; private set; }
- public LambdaExpression ProjectionExpression { get; set; }
+
public List<Action<IQuery>> AdditionalCriteria { get; set; }
+ public System.Type QueryResultType { get; set; }
+
public IQuery CreateQuery(ISession session, System.Type type)
{
var query = session.CreateQuery(new HqlExpression(Statement, type));
+ SetParameters(query);
+
+ SetResultTransformer(query);
+
+ AddAdditionalCriteria(query);
+
+ return query;
+ }
+
+ private void SetParameters(IQuery query)
+ {
foreach (var parameter in NamedParameters)
+ {
query.SetParameter(parameter.Name, parameter.Value);
-
- if (ProjectionExpression != null)
- {
- query.SetResultTransformer(new ResultTransformer(ProjectionExpression));
}
+ }
+ private void AddAdditionalCriteria(IQuery query)
+ {
foreach (var criteria in AdditionalCriteria)
{
criteria(query);
}
+ }
- return query;
+ private void SetResultTransformer(IQuery query)
+ {
+ var itemTransformer = MergeLambdas(_itemTransformers);
+ var listTransformer = MergeLambdas(_listTransformers);
+
+ if (itemTransformer != null || listTransformer != null)
+ {
+ query.SetResultTransformer(new ResultTransformer(itemTransformer, listTransformer));
+ }
}
+
+ private static LambdaExpression MergeLambdas(IList<LambdaExpression> transformations)
+ {
+ if (transformations == null || transformations.Count == 0)
+ {
+ return null;
+ }
+
+ var listTransformLambda = transformations[0];
+
+ for (int i = 1; i < transformations.Count; i++)
+ {
+ var invoked = Expression.Invoke(transformations[i], listTransformLambda.Body);
+
+ listTransformLambda = Expression.Lambda(invoked, listTransformLambda.Parameters.ToArray());
+ }
+
+ return listTransformLambda;
+ }
}
}
\ No newline at end of file
Added: trunk/nhibernate/src/NHibernate/Linq/EnumerableHelper.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/EnumerableHelper.cs (rev 0)
+++ trunk/nhibernate/src/NHibernate/Linq/EnumerableHelper.cs 2009-09-15 14:26:02 UTC (rev 4712)
@@ -0,0 +1,53 @@
+using System.Linq;
+using System.Reflection;
+
+namespace NHibernate.Linq
+{
+ /// <summary>
+ /// Reflection helpers that locate public static methods declared on <see cref="Enumerable"/>.
+ /// </summary>
+ public static class EnumerableHelper
+ {
+ /// <summary>
+ /// Returns the single public static Enumerable method named <paramref name="name"/> whose
+ /// parameters match <paramref name="parameterTypes"/> under ParameterTypesMatch.
+ /// Throws InvalidOperationException (from Single) when zero or several methods match.
+ /// </summary>
+ public static MethodInfo GetMethod(string name, System.Type[] parameterTypes)
+ {
+ var candidates = typeof(Enumerable).GetMethods(BindingFlags.Static | BindingFlags.Public)
+ .Where(candidate => candidate.Name == name
+ && ParameterTypesMatch(candidate.GetParameters(), parameterTypes));
+
+ return candidates.Single();
+ }
+
+ /// <summary>
+ /// Like the two-argument overload, but only considers generic method definitions with
+ /// exactly <c>genericTypeParameters.Length</c> type parameters, and closes the single match
+ /// over <paramref name="genericTypeParameters"/>.
+ /// </summary>
+ public static MethodInfo GetMethod(string name, System.Type[] parameterTypes, System.Type[] genericTypeParameters)
+ {
+ var candidates = typeof(Enumerable).GetMethods(BindingFlags.Static | BindingFlags.Public)
+ .Where(candidate => candidate.Name == name
+ && candidate.ContainsGenericParameters
+ && candidate.GetGenericArguments().Length == genericTypeParameters.Length
+ && ParameterTypesMatch(candidate.GetParameters(), parameterTypes));
+
+ MethodInfo openDefinition = candidates.Single();
+ return openDefinition.MakeGenericMethod(genericTypeParameters);
+ }
+
+ // Positional comparison. Identical types match; two open generic types also match when they
+ // have the same number of generic arguments (their generic definitions are NOT compared).
+ private static bool ParameterTypesMatch(ParameterInfo[] parameters, System.Type[] types)
+ {
+ if (parameters.Length != types.Length)
+ {
+ return false;
+ }
+
+ for (int index = 0; index < parameters.Length; index++)
+ {
+ System.Type declared = parameters[index].ParameterType;
+ System.Type requested = types[index];
+
+ bool matches = declared == requested
+ || (declared.ContainsGenericParameters
+ && requested.ContainsGenericParameters
+ && declared.GetGenericArguments().Length == requested.GetGenericArguments().Length);
+
+ if (!matches)
+ {
+ return false;
+ }
+ }
+
+ return true;
+ }
+ }
+}
\ No newline at end of file
Added: trunk/nhibernate/src/NHibernate/Linq/GroupBySelectClauseVisitor.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/GroupBySelectClauseVisitor.cs (rev 0)
+++ trunk/nhibernate/src/NHibernate/Linq/GroupBySelectClauseVisitor.cs 2009-09-15 14:26:02 UTC (rev 4712)
@@ -0,0 +1,148 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Linq.Expressions;
+using Remotion.Data.Linq.Clauses;
+using Remotion.Data.Linq.Clauses.Expressions;
+using Remotion.Data.Linq.Clauses.ResultOperators;
+using Remotion.Data.Linq.Parsing;
+
+namespace NHibernate.Linq
+{
+ /// <summary>
+ /// Rewrites the select clause of a query over an IGrouping so the group-by can be flattened:
+ /// <c>g.Key</c> is replaced by the grouping's key selector, and a sub-query with a single
+ /// Average/Min/Max/Count/Sum result operator is replaced by the matching NH aggregate node.
+ /// </summary>
+ internal class GroupBySelectClauseVisitor : ExpressionTreeVisitor
+ {
+ public static Expression Visit(Expression expression)
+ {
+ var visitor = new GroupBySelectClauseVisitor();
+ return visitor.VisitExpression(expression);
+ }
+
+ protected override Expression VisitMemberExpression(MemberExpression expression)
+ {
+ if (!IsGroupingKeyAccess(expression))
+ {
+ return base.VisitMemberExpression(expression);
+ }
+
+ // Follow g -> the from clause that introduced g -> the group-by sub-query it iterates.
+ var querySourceRef = expression.Expression as QuerySourceReferenceExpression;
+
+ var fromClause = querySourceRef.ReferencedQuerySource as FromClauseBase;
+
+ var subQuery = fromClause.FromExpression as SubQueryExpression;
+
+ var groupBy = subQuery.QueryModel.ResultOperators.OfType<GroupResultOperator>().Single();
+
+ return groupBy.KeySelector;
+ }
+
+ // True for a member named "Key" declared on a constructed IGrouping<,>. Non-generic declaring
+ // types are rejected first: GetGenericTypeDefinition() throws InvalidOperationException for
+ // them, which previously broke any query touching an ordinary "Key" property.
+ private static bool IsGroupingKeyAccess(MemberExpression expression)
+ {
+ if (expression.Member.Name != "Key")
+ {
+ return false;
+ }
+
+ var declaringType = expression.Member.DeclaringType;
+ return declaringType != null
+ && declaringType.IsGenericType
+ && declaringType.GetGenericTypeDefinition() == typeof (IGrouping<,>);
+ }
+
+ // A sub-query with exactly one result operator is mapped to the corresponding aggregate node
+ // over the sub-query's selector; any other single operator is not implemented yet.
+ protected override Expression VisitSubQueryExpression(SubQueryExpression expression)
+ {
+ if (expression.QueryModel.ResultOperators.Count == 1)
+ {
+ ResultOperatorBase resultOperator = expression.QueryModel.ResultOperators[0];
+
+ if (resultOperator is AverageResultOperator)
+ {
+ return new AverageExpression(expression.QueryModel.SelectClause.Selector);
+ }
+ else if (resultOperator is MinResultOperator)
+ {
+ return new MinExpression(expression.QueryModel.SelectClause.Selector);
+ }
+ else if (resultOperator is MaxResultOperator)
+ {
+ return new MaxExpression(expression.QueryModel.SelectClause.Selector);
+ }
+ else if (resultOperator is CountResultOperator)
+ {
+ return new CountExpression();
+ }
+ else if (resultOperator is SumResultOperator)
+ {
+ return new SumExpression(expression.QueryModel.SelectClause.Selector);
+ }
+ else
+ {
+ throw new NotImplementedException();
+ }
+ }
+ else
+ {
+ return base.VisitSubQueryExpression(expression);
+ }
+ }
+ }
+
+ /// <summary>
+ /// Custom ExpressionType values used by NHibernate's own aggregate expression nodes.
+ /// They start at 10000, presumably to stay clear of the built-in ExpressionType values.
+ /// </summary>
+ public enum NhExpressionType
+ {
+ Average = 10000,
+ Min = 10001,
+ Max = 10002,
+ Sum = 10003,
+ Count = 10004,
+ Distinct = 10005
+ }
+
+ /// <summary>
+ /// Base node for an aggregate over a wrapped operand. Its node type is the given NhExpressionType
+ /// (cast to ExpressionType), and its result type is the operand's type.
+ /// </summary>
+ public class NhAggregatedExpression : Expression
+ {
+ public NhAggregatedExpression(Expression expression, NhExpressionType type)
+ : base((ExpressionType)type, expression.Type)
+ {
+ Expression = expression;
+ }
+
+ /// <summary>The operand being aggregated.</summary>
+ public Expression Expression { get; set; }
+ }
+
+ /// <summary>Average(operand) node.</summary>
+ public class AverageExpression : NhAggregatedExpression
+ {
+ public AverageExpression(Expression expression)
+ : base(expression, NhExpressionType.Average)
+ {
+ }
+ }
+
+ /// <summary>Min(operand) node.</summary>
+ public class MinExpression : NhAggregatedExpression
+ {
+ public MinExpression(Expression expression) : base(expression, NhExpressionType.Min)
+ {
+ }
+ }
+
+ /// <summary>Max(operand) node.</summary>
+ public class MaxExpression : NhAggregatedExpression
+ {
+ public MaxExpression(Expression expression) : base(expression, NhExpressionType.Max)
+ {
+ }
+ }
+
+ /// <summary>Sum(operand) node.</summary>
+ public class SumExpression : NhAggregatedExpression
+ {
+ public SumExpression(Expression expression) : base(expression, NhExpressionType.Sum)
+ {
+ }
+ }
+
+ /// <summary>Distinct(operand) node.</summary>
+ public class DistinctExpression : NhAggregatedExpression
+ {
+ public DistinctExpression(Expression expression) : base(expression, NhExpressionType.Distinct)
+ {
+ }
+ }
+
+ /// <summary>Count node. It has no operand and is typed as <see cref="int"/>.</summary>
+ public class CountExpression : Expression
+ {
+ public CountExpression() : base((ExpressionType)NhExpressionType.Count, typeof(int))
+ {
+ }
+ }
+}
Added: trunk/nhibernate/src/NHibernate/Linq/GroupBySelectorVisitor.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/GroupBySelectorVisitor.cs (rev 0)
+++ trunk/nhibernate/src/NHibernate/Linq/GroupBySelectorVisitor.cs 2009-09-15 14:26:02 UTC (rev 4712)
@@ -0,0 +1,26 @@
+using System.Linq.Expressions;
+using Remotion.Data.Linq.Clauses.Expressions;
+using Remotion.Data.Linq.Parsing;
+
+namespace NHibernate.Linq
+{
+ /// <summary>
+ /// Expression visitor that substitutes one fixed ParameterExpression for every
+ /// QuerySourceReferenceExpression it encounters.
+ /// </summary>
+ internal class GroupBySelectorVisitor : ExpressionTreeVisitor
+ {
+ private readonly ParameterExpression _replacement;
+
+ public GroupBySelectorVisitor(ParameterExpression parameter)
+ {
+ _replacement = parameter;
+ }
+
+ /// <summary>Visits <paramref name="expression"/> and returns the visitor's result.</summary>
+ public Expression Visit(Expression expression)
+ {
+ return VisitExpression(expression);
+ }
+
+ // Every query-source reference collapses to the same parameter, whatever source it named.
+ protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression)
+ {
+ return _replacement;
+ }
+ }
+}
\ No newline at end of file
Added: trunk/nhibernate/src/NHibernate/Linq/HqlGeneratorExpressionTreeVisitor.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/HqlGeneratorExpressionTreeVisitor.cs (rev 0)
+++ trunk/nhibernate/src/NHibernate/Linq/HqlGeneratorExpressionTreeVisitor.cs 2009-09-15 14:26:02 UTC (rev 4712)
@@ -0,0 +1,353 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Linq.Expressions;
+using NHibernate.Hql.Ast;
+using Remotion.Data.Linq.Clauses.Expressions;
+using Remotion.Data.Linq.Clauses.ExpressionTreeVisitors;
+
+namespace NHibernate.Linq
+{
+ /// <summary>
+ /// Translates a LINQ expression tree into HQL AST nodes. Nodes are accumulated on an HqlNodeStack:
+ /// PushNode opens a parent whose children are pushed until the returned IDisposable is disposed,
+ /// and PushLeaf adds a node without opening it. GetHqlTreeNodes returns the finished top-level nodes.
+ /// Expression kinds not overridden here fall through to NhThrowingExpressionTreeVisitor, and are
+ /// reported through CreateUnhandledItemException.
+ /// </summary>
+ public class HqlGeneratorExpressionTreeVisitor : NhThrowingExpressionTreeVisitor
+ {
+ protected readonly HqlTreeBuilder _hqlTreeBuilder;
+ protected readonly HqlNodeStack _stack;
+ private readonly ParameterAggregator _parameterAggregator;
+
+ public HqlGeneratorExpressionTreeVisitor(ParameterAggregator parameterAggregator)
+ {
+ _parameterAggregator = parameterAggregator;
+ _hqlTreeBuilder = new HqlTreeBuilder();
+ _stack = new HqlNodeStack(_hqlTreeBuilder);
+ }
+
+ /// <summary>Returns the HQL nodes produced so far (via HqlNodeStack.Finish).</summary>
+ public IEnumerable<HqlTreeNode> GetHqlTreeNodes()
+ {
+ return _stack.Finish();
+ }
+
+ /// <summary>Entry point: visits <paramref name="expression"/>, pushing the generated HQL onto the stack.</summary>
+ public virtual void Visit(Expression expression)
+ {
+ VisitExpression(expression);
+ }
+
+ // The aggregate visitors below share one shape: translate the operand with a fresh visitor
+ // (sharing the parameter aggregator), require exactly one resulting node (Single throws
+ // otherwise), wrap it in the HQL aggregate, and cast the result to the LINQ expression's type.
+ // NOTE(review): assuming HqlTreeBuilder.Cast builds an HqlCast, the cast target goes through
+ // HqlIdent(IASTFactory, Type), which only accepts int, decimal and DateTime (or their nullable
+ // forms) and throws NotSupportedException for anything else. Confirm for other result types.
+ protected override Expression VisitNhAverage(AverageExpression expression)
+ {
+ var visitor = new HqlGeneratorExpressionTreeVisitor(_parameterAggregator);
+ visitor.Visit(expression.Expression);
+
+ _stack.PushLeaf(_hqlTreeBuilder.Cast(_hqlTreeBuilder.Average(visitor.GetHqlTreeNodes().Single()), expression.Type));
+
+ return expression;
+ }
+
+ // Count has no operand: emits cast(count(*)) using RowStar as the argument.
+ protected override Expression VisitNhCount(CountExpression expression)
+ {
+ _stack.PushLeaf(_hqlTreeBuilder.Cast(_hqlTreeBuilder.Count(_hqlTreeBuilder.RowStar()), expression.Type));
+
+ return expression;
+ }
+
+ protected override Expression VisitNhMin(MinExpression expression)
+ {
+ var visitor = new HqlGeneratorExpressionTreeVisitor(_parameterAggregator);
+ visitor.Visit(expression.Expression);
+
+ _stack.PushLeaf(_hqlTreeBuilder.Cast(_hqlTreeBuilder.Min(visitor.GetHqlTreeNodes().Single()), expression.Type));
+
+ return expression;
+ }
+
+ protected override Expression VisitNhMax(MaxExpression expression)
+ {
+ var visitor = new HqlGeneratorExpressionTreeVisitor(_parameterAggregator);
+ visitor.Visit(expression.Expression);
+
+ _stack.PushLeaf(_hqlTreeBuilder.Cast(_hqlTreeBuilder.Max(visitor.GetHqlTreeNodes().Single()), expression.Type));
+
+ return expression;
+ }
+
+ protected override Expression VisitNhSum(SumExpression expression)
+ {
+ var visitor = new HqlGeneratorExpressionTreeVisitor(_parameterAggregator);
+ visitor.Visit(expression.Expression);
+
+ _stack.PushLeaf(_hqlTreeBuilder.Cast(_hqlTreeBuilder.Sum(visitor.GetHqlTreeNodes().Single()), expression.Type));
+
+ return expression;
+ }
+
+ // Distinct is emitted as a DISTINCT leaf followed by the operand's nodes as siblings
+ // (not as children of the DISTINCT node). Multiple operand nodes are allowed here.
+ protected override Expression VisitNhDistinct(DistinctExpression expression)
+ {
+ var visitor = new HqlGeneratorExpressionTreeVisitor(_parameterAggregator);
+ visitor.Visit(expression.Expression);
+
+ _stack.PushLeaf(_hqlTreeBuilder.Distinct());
+
+ foreach (var node in visitor.GetHqlTreeNodes())
+ {
+ _stack.PushLeaf(node);
+ }
+
+ return expression;
+ }
+
+ // A reference to a query source becomes an identifier using the source's item name (its alias).
+ protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression)
+ {
+ _stack.PushLeaf(_hqlTreeBuilder.Ident(expression.ReferencedQuerySource.ItemName));
+
+ return expression;
+ }
+
+ // Binary operator node with the left operand as first child and the right operand as second.
+ protected override Expression VisitBinaryExpression(BinaryExpression expression)
+ {
+ HqlTreeNode operatorNode = GetHqlOperatorNodeForBinaryOperator(expression);
+
+ using (_stack.PushNode(operatorNode))
+ {
+ VisitExpression(expression.Left);
+
+ VisitExpression(expression.Right);
+ }
+
+ return expression;
+ }
+
+ // Maps a LINQ binary node type to an HQL operator node. Note that both the bitwise And/Or and
+ // the short-circuit AndAlso/OrElse map to the boolean and/or operators. Any other node type
+ // throws InvalidOperationException.
+ private HqlTreeNode GetHqlOperatorNodeForBinaryOperator(BinaryExpression expression)
+ {
+ switch (expression.NodeType)
+ {
+ case ExpressionType.Equal:
+ return _hqlTreeBuilder.Equality();
+
+ case ExpressionType.NotEqual:
+ return _hqlTreeBuilder.Inequality();
+
+ case ExpressionType.And:
+ case ExpressionType.AndAlso:
+ return _hqlTreeBuilder.BooleanAnd();
+
+ case ExpressionType.Or:
+ case ExpressionType.OrElse:
+ return _hqlTreeBuilder.BooleanOr();
+
+ case ExpressionType.Add:
+ return _hqlTreeBuilder.Add();
+
+ case ExpressionType.Subtract:
+ return _hqlTreeBuilder.Subtract();
+
+ case ExpressionType.Multiply:
+ return _hqlTreeBuilder.Multiply();
+
+ case ExpressionType.Divide:
+ return _hqlTreeBuilder.Divide();
+
+ case ExpressionType.LessThan:
+ return _hqlTreeBuilder.LessThan();
+
+ case ExpressionType.LessThanOrEqual:
+ return _hqlTreeBuilder.LessThanOrEqual();
+
+ case ExpressionType.GreaterThan:
+ return _hqlTreeBuilder.GreaterThan();
+
+ case ExpressionType.GreaterThanOrEqual:
+ return _hqlTreeBuilder.GreaterThanOrEqual();
+ }
+
+ throw new InvalidOperationException();
+ }
+
+ // Unary operator node with the operand as its single child.
+ protected override Expression VisitUnaryExpression(UnaryExpression expression)
+ {
+ HqlTreeNode operatorNode = GetHqlOperatorNodeforUnaryOperator(expression);
+
+ using (_stack.PushNode(operatorNode))
+ {
+ VisitExpression(expression.Operand);
+ }
+
+ return expression;
+ }
+
+ // Only logical Not is supported. Any other unary node type (e.g. Convert, Negate) throws
+ // InvalidOperationException.
+ private HqlTreeNode GetHqlOperatorNodeforUnaryOperator(UnaryExpression expression)
+ {
+ switch (expression.NodeType)
+ {
+ case ExpressionType.Not:
+ return _hqlTreeBuilder.Not();
+ }
+
+ throw new InvalidOperationException();
+ }
+
+ // Member access becomes a DOT node: the owning expression first, then the member name as an identifier.
+ protected override Expression VisitMemberExpression(MemberExpression expression)
+ {
+ using (_stack.PushNode(_hqlTreeBuilder.Dot()))
+ {
+ Expression newExpression = VisitExpression(expression.Expression);
+
+ _stack.PushLeaf(_hqlTreeBuilder.Ident(expression.Member.Name));
+
+ if (newExpression != expression.Expression)
+ {
+ return Expression.MakeMemberAccess(newExpression, expression.Member);
+ }
+ }
+
+ return expression;
+ }
+
+ // A constant holding an NhQueryable<T> becomes an identifier naming T (its short type name).
+ // Every other constant is inlined as an HQL constant, not a named parameter. See the TODO below.
+ protected override Expression VisitConstantExpression(ConstantExpression expression)
+ {
+ if (expression.Value != null)
+ {
+ System.Type t = expression.Value.GetType();
+
+ if (t.IsGenericType && t.GetGenericTypeDefinition() == typeof (NhQueryable<>))
+ {
+ _stack.PushLeaf(_hqlTreeBuilder.Ident(t.GetGenericArguments()[0].Name));
+ return expression;
+ }
+ }
+
+ /*
+ var namedParameter = _parameterAggregator.AddParameter(expression.Value);
+
+ _expression = _hqlTreeBuilder.Parameter(namedParameter.Name);
+
+ return expression;
+ */
+ // TODO - get parameter support in place in the HQLQueryPlan
+ _stack.PushLeaf(_hqlTreeBuilder.Constant(expression.Value));
+
+ return expression;
+ }
+
+ // Supports Any, Min and Max declared on Enumerable/Queryable. Other methods on those types throw
+ // NotSupportedException; methods on any other type are left to the base visitor, which throws.
+ protected override Expression VisitMethodCallExpression(MethodCallExpression expression)
+ {
+ if (expression.Method.DeclaringType == typeof(Enumerable) ||
+ expression.Method.DeclaringType == typeof(Queryable))
+ {
+ switch (expression.Method.Name)
+ {
+ case "Any":
+ // Any has one or two arguments. Arg 1 is the source and arg 2 is the optional predicate
+ // Emitted shape: exists(query(select_from(from(range(source [alias]))) [where(predicate)])).
+ // The alias is the predicate lambda's parameter name.
+ using (_stack.PushNode(_hqlTreeBuilder.Exists()))
+ {
+ using (_stack.PushNode(_hqlTreeBuilder.Query()))
+ {
+ using (_stack.PushNode(_hqlTreeBuilder.SelectFrom()))
+ {
+ using (_stack.PushNode(_hqlTreeBuilder.From()))
+ {
+ using (_stack.PushNode(_hqlTreeBuilder.Range()))
+ {
+ VisitExpression(expression.Arguments[0]);
+
+ if (expression.Arguments.Count > 1)
+ {
+ var expr = (LambdaExpression) expression.Arguments[1];
+ _stack.PushLeaf(_hqlTreeBuilder.Alias(expr.Parameters[0].Name));
+ }
+ }
+ }
+ }
+ if (expression.Arguments.Count > 1)
+ {
+ using (_stack.PushNode(_hqlTreeBuilder.Where()))
+ {
+ VisitExpression(expression.Arguments[1]);
+ }
+ }
+ }
+ }
+ break;
+ // NOTE(review): Min and Max visit only Arguments[1] (the selector). The source in
+ // Arguments[0] is never visited, and a selector-less overload (a single argument)
+ // would throw ArgumentOutOfRangeException here. Confirm that this is intended.
+ case "Min":
+ using (_stack.PushNode(_hqlTreeBuilder.Min()))
+ {
+ VisitExpression(expression.Arguments[1]);
+ }
+ break;
+ case "Max":
+ using (_stack.PushNode(_hqlTreeBuilder.Max()))
+ {
+ VisitExpression(expression.Arguments[1]);
+ }
+ break;
+ default:
+ throw new NotSupportedException(string.Format("The Enumerable method {0} is not supported", expression.Method.Name));
+ }
+
+ return expression;
+ }
+ else
+ {
+ return base.VisitMethodCallExpression(expression); // throws
+ }
+ }
+
+ // Lambdas contribute only their body; parameters are emitted where they are referenced.
+ protected override Expression VisitLambdaExpression(LambdaExpression expression)
+ {
+ VisitExpression(expression.Body);
+
+ return expression;
+ }
+
+ // A lambda parameter becomes an identifier carrying the parameter's name.
+ protected override Expression VisitParameterExpression(ParameterExpression expression)
+ {
+ _stack.PushLeaf(_hqlTreeBuilder.Ident(expression.Name));
+
+ return expression;
+ }
+
+ // test ? a : b becomes case(when(test, a), else(b)).
+ protected override Expression VisitConditionalExpression(ConditionalExpression expression)
+ {
+ using (_stack.PushNode(_hqlTreeBuilder.Case()))
+ {
+ using (_stack.PushNode(_hqlTreeBuilder.When()))
+ {
+ VisitExpression(expression.Test);
+
+ VisitExpression(expression.IfTrue);
+ }
+
+ if (expression.IfFalse != null)
+ {
+ using (_stack.PushNode(_hqlTreeBuilder.Else()))
+ {
+ VisitExpression(expression.IfFalse);
+ }
+ }
+ }
+
+ return expression;
+ }
+
+ // A nested query model is translated into a complete HQL query (sharing this visitor's parameter
+ // aggregator), and its statement is pushed as a leaf.
+ protected override Expression VisitSubQueryExpression(SubQueryExpression expression)
+ {
+ CommandData query = QueryModelVisitor.GenerateHqlQuery(expression.QueryModel, _parameterAggregator);
+
+ _stack.PushLeaf(query.Statement);
+
+ return expression;
+ }
+
+
+ // Called when a LINQ expression type is not handled above.
+ // Reports the item as a NotSupportedException, with a formatted rendering of the expression when possible.
+ protected override Exception CreateUnhandledItemException<T>(T unhandledItem, string visitMethod)
+ {
+ string itemText = FormatUnhandledItem(unhandledItem);
+ var message = string.Format("The expression '{0}' (type: {1}) is not supported by this LINQ provider.", itemText, typeof(T));
+ return new NotSupportedException(message);
+ }
+
+ // Expressions are rendered with FormattingExpressionTreeVisitor; anything else falls back to ToString().
+ private string FormatUnhandledItem<T>(T unhandledItem)
+ {
+ var itemAsExpression = unhandledItem as Expression;
+ return itemAsExpression != null ? FormattingExpressionTreeVisitor.Format(itemAsExpression) : unhandledItem.ToString();
+ }
+ }
+}
\ No newline at end of file
Modified: trunk/nhibernate/src/NHibernate/Linq/HqlNodeStack.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/HqlNodeStack.cs 2009-09-10 18:24:03 UTC (rev 4711)
+++ trunk/nhibernate/src/NHibernate/Linq/HqlNodeStack.cs 2009-09-15 14:26:02 UTC (rev 4712)
@@ -27,12 +27,12 @@
return holder.Children;
}
- public void PushAndPop(HqlTreeNode query)
+ public void PushLeaf(HqlTreeNode query)
{
- Push(query).Dispose();
+ PushNode(query).Dispose();
}
- public IDisposable Push(HqlTreeNode query)
+ public IDisposable PushNode(HqlTreeNode query)
{
_stack.Peek().AddChild(query);
Modified: trunk/nhibernate/src/NHibernate/Linq/NhExpressionTreeVisitor.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/NhExpressionTreeVisitor.cs 2009-09-10 18:24:03 UTC (rev 4711)
+++ trunk/nhibernate/src/NHibernate/Linq/NhExpressionTreeVisitor.cs 2009-09-15 14:26:02 UTC (rev 4712)
@@ -1,286 +1,69 @@
-using System;
-using System.Collections.Generic;
-using System.Linq;
using System.Linq.Expressions;
-using NHibernate.Hql.Ast;
-using Remotion.Data.Linq.Clauses.Expressions;
-using Remotion.Data.Linq.Clauses.ExpressionTreeVisitors;
using Remotion.Data.Linq.Parsing;
namespace NHibernate.Linq
{
- public class NhExpressionTreeVisitor : ThrowingExpressionTreeVisitor
+ public class NhExpressionTreeVisitor : ExpressionTreeVisitor
{
- protected readonly HqlTreeBuilder _hqlTreeBuilder;
- protected readonly HqlNodeStack _stack;
- private readonly ParameterAggregator _parameterAggregator;
-
- public NhExpressionTreeVisitor(ParameterAggregator parameterAggregator)
+ protected override Expression VisitExpression(Expression expression)
{
- _parameterAggregator = parameterAggregator;
- _hqlTreeBuilder = new HqlTreeBuilder();
- _stack = new HqlNodeStack(_hqlTreeBuilder);
- }
-
- public IEnumerable<HqlTreeNode> GetAstBuilderNode()
- {
- return _stack.Finish();
- }
-
- public virtual void Visit(Expression expression)
- {
- VisitExpression(expression);
- }
-
- protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression)
- {
- _stack.PushAndPop(_hqlTreeBuilder.Ident(expression.ReferencedQuerySource.ItemName));
-
- return expression;
- }
-
- protected override Expression VisitBinaryExpression(BinaryExpression expression)
- {
- HqlTreeNode operatorNode = GetHqlOperatorNodeForBinaryOperator(expression);
-
- using (_stack.Push(operatorNode))
+ switch ((NhExpressionType) expression.NodeType)
{
- VisitExpression(expression.Left);
-
- VisitExpression(expression.Right);
+ case NhExpressionType.Average:
+ return VisitNhAverage((AverageExpression) expression);
+ case NhExpressionType.Min:
+ return VisitNhMin((MinExpression)expression);
+ case NhExpressionType.Max:
+ return VisitNhMax((MaxExpression)expression);
+ case NhExpressionType.Sum:
+ return VisitNhSum((SumExpression)expression);
+ case NhExpressionType.Count:
+ return VisitNhCount((CountExpression)expression);
+ case NhExpressionType.Distinct:
+ return VisitNhDistinct((DistinctExpression) expression);
}
- return expression;
+ return base.VisitExpression(expression);
}
- private HqlTreeNode GetHqlOperatorNodeForBinaryOperator(BinaryExpression expression)
+ protected virtual Expression VisitNhDistinct(DistinctExpression expression)
{
- switch (expression.NodeType)
- {
- case ExpressionType.Equal:
- return _hqlTreeBuilder.Equality();
+ Expression nx = base.VisitExpression(expression.Expression);
- case ExpressionType.NotEqual:
- return _hqlTreeBuilder.Inequality();
-
- case ExpressionType.And:
- case ExpressionType.AndAlso:
- return _hqlTreeBuilder.BooleanAnd();
-
- case ExpressionType.Or:
- case ExpressionType.OrElse:
- return _hqlTreeBuilder.BooleanOr();
-
- case ExpressionType.Add:
- return _hqlTreeBuilder.Add();
-
- case ExpressionType.Subtract:
- return _hqlTreeBuilder.Subtract();
-
- case ExpressionType.Multiply:
- return _hqlTreeBuilder.Multiply();
-
- case ExpressionType.Divide:
- return _hqlTreeBuilder.Divide();
-
- case ExpressionType.LessThan:
- return _hqlTreeBuilder.LessThan();
-
- case ExpressionType.LessThanOrEqual:
- return _hqlTreeBuilder.LessThanOrEqual();
-
- case ExpressionType.GreaterThan:
- return _hqlTreeBuilder.GreaterThan();
-
- case ExpressionType.GreaterThanOrEqual:
- return _hqlTreeBuilder.GreaterThanOrEqual();
- }
-
- throw new InvalidOperationException();
+ return nx != expression.Expression ? new DistinctExpression(nx) : expression;
}
- protected override Expression VisitUnaryExpression(UnaryExpression expression)
+ protected virtual Expression VisitNhCount(CountExpression expression)
{
- HqlTreeNode operatorNode = GetHqlOperatorNodeforUnaryOperator(expression);
-
- using (_stack.Push(operatorNode))
- {
- VisitExpression(expression.Operand);
- }
-
return expression;
}
- private HqlTreeNode GetHqlOperatorNodeforUnaryOperator(UnaryExpression expression)
+ protected virtual Expression VisitNhSum(SumExpression expression)
{
- switch (expression.NodeType)
- {
- case ExpressionType.Not:
- return _hqlTreeBuilder.Not();
- }
-
- throw new InvalidOperationException();
- }
+ Expression nx = base.VisitExpression(expression.Expression);
- protected override Expression VisitMemberExpression(MemberExpression expression)
- {
- using (_stack.Push(_hqlTreeBuilder.Dot()))
- {
- Expression newExpression = VisitExpression(expression.Expression);
-
- _stack.PushAndPop(_hqlTreeBuilder.Ident(expression.Member.Name));
-
- if (newExpression != expression.Expression)
- {
- return Expression.MakeMemberAccess(newExpression, expression.Member);
- }
- }
-
- return expression;
+ return nx != expression.Expression ? new SumExpression(nx) : expression;
}
- protected override Expression VisitConstantExpression(ConstantExpression expression)
+ protected virtual Expression VisitNhMax(MaxExpression expression)
{
- if (expression.Value != null)
- {
- System.Type t = expression.Value.GetType();
+ Expression nx = base.VisitExpression(expression.Expression);
- if (t.IsGenericType && t.GetGenericTypeDefinition() == typeof (NhQueryable<>))
- {
- _stack.PushAndPop(_hqlTreeBuilder.Ident(t.GetGenericArguments()[0].Name));
- return expression;
- }
- }
-
- /*
- var namedParameter = _parameterAggregator.AddParameter(expression.Value);
-
- _expression = _hqlTreeBuilder.Parameter(namedParameter.Name);
-
- return expression;
- */
- // TODO - get parameter support in place in the HQLQueryPlan
- _stack.PushAndPop(_hqlTreeBuilder.Constant(expression.Value));
-
- return expression;
+ return nx != expression.Expression ? new MaxExpression(nx) : expression;
}
- protected override Expression VisitMethodCallExpression(MethodCallExpression expression)
+ protected virtual Expression VisitNhMin(MinExpression expression)
{
- if (expression.Method.DeclaringType == typeof(Enumerable))
- {
- switch (expression.Method.Name)
- {
- case "Any":
- // Any has one or two arguments. Arg 1 is the source and arg 2 is the optional predicate
- using (_stack.Push(_hqlTreeBuilder.Exists()))
- {
- using (_stack.Push(_hqlTreeBuilder.Query()))
- {
- using (_stack.Push(_hqlTreeBuilder.SelectFrom()))
- {
- using (_stack.Push(_hqlTreeBuilder.From()))
- {
- using (_stack.Push(_hqlTreeBuilder.Range()))
- {
- VisitExpression(expression.Arguments[0]);
+ Expression nx = base.VisitExpression(expression.Expression);
- if (expression.Arguments.Count > 1)
- {
- var expr = (LambdaExpression) expression.Arguments[1];
- _stack.PushAndPop(_hqlTreeBuilder.Alias(expr.Parameters[0].Name));
- }
- }
- }
- }
- if (expression.Arguments.Count > 1)
- {
- using (_stack.Push(_hqlTreeBuilder.Where()))
- {
- VisitExpression(expression.Arguments[1]);
- }
- }
- }
- }
- break;
- default:
- throw new NotSupportedException(string.Format("The Enumerable method {0} is not supported", expression.Method.Name));
- }
-
- return expression;
- }
- else
- {
- return base.VisitMethodCallExpression(expression); // throws
- }
+ return nx != expression.Expression ? new MinExpression(nx) : expression;
}
- protected override Expression VisitLambdaExpression(LambdaExpression expression)
+ protected virtual Expression VisitNhAverage(AverageExpression expression)
{
- VisitExpression(expression.Body);
+ Expression nx = base.VisitExpression(expression.Expression);
- return expression;
+ return nx != expression.Expression ? new AverageExpression(nx) : expression;
}
-
- protected override Expression VisitParameterExpression(ParameterExpression expression)
- {
- _stack.PushAndPop(_hqlTreeBuilder.Ident(expression.Name));
-
- return expression;
- }
-
- protected override Expression VisitConditionalExpression(ConditionalExpression expression)
- {
- using (_stack.Push(_hqlTreeBuilder.Case()))
- {
- using (_stack.Push(_hqlTreeBuilder.When()))
- {
- VisitExpression(expression.Test);
-
- VisitExpression(expression.IfTrue);
- }
-
- if (expression.IfFalse != null)
- {
- using (_stack.Push(_hqlTreeBuilder.Else()))
- {
- VisitExpression(expression.IfFalse);
- }
- }
- }
-
- return expression;
- }
-
- protected override Expression VisitSubQueryExpression(SubQueryExpression expression)
- {
- CommandData query = QueryModelVisitor.GenerateHqlQuery(expression.QueryModel, _parameterAggregator);
-
- if (query.ProjectionExpression != null)
- {
- throw new InvalidOperationException();
- }
-
- // TODO - what if there was a projection expression?
-
- _stack.PushAndPop(query.Statement);
-
- return expression;
- }
-
-
- // Called when a LINQ expression type is not handled above.
- protected override Exception CreateUnhandledItemException<T>(T unhandledItem, string visitMethod)
- {
- string itemText = FormatUnhandledItem(unhandledItem);
- var message = string.Format("The expression '{0}' (type: {1}) is not supported by this LINQ provider.", itemText, typeof(T));
- return new NotSupportedException(message);
- }
-
- private string FormatUnhandledItem<T>(T unhandledItem)
- {
- var itemAsExpression = unhandledItem as Expression;
- return itemAsExpression != null ? FormattingExpressionTreeVisitor.Format(itemAsExpression) : unhandledItem.ToString();
- }
}
}
\ No newline at end of file
Added: trunk/nhibernate/src/NHibernate/Linq/NhThrowingExpressionTreeVisitor.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/NhThrowingExpressionTreeVisitor.cs (rev 0)
+++ trunk/nhibernate/src/NHibernate/Linq/NhThrowingExpressionTreeVisitor.cs 2009-09-15 14:26:02 UTC (rev 4712)
@@ -0,0 +1,101 @@
+using System;
+using System.Linq.Expressions;
+using Remotion.Data.Linq.Parsing;
+
+namespace NHibernate.Linq
+{
+ public abstract class NhThrowingExpressionTreeVisitor : ThrowingExpressionTreeVisitor
+ {
+ /// <summary>
+ /// Dispatches NHibernate-specific aggregate/distinct nodes to their dedicated visit
+ /// methods and hands every other node to the base (throwing) visitor.
+ /// </summary>
+ protected override Expression VisitExpression(Expression expression)
+ {
+ // A null node (e.g. the missing instance of a static member access) is left to the
+ // base visitor instead of failing here on the NodeType dereference.
+ if (expression == null)
+ {
+ return base.VisitExpression(expression);
+ }
+
+ // The custom NHibernate nodes report their kind through NodeType, cast to NhExpressionType.
+ switch ((NhExpressionType)expression.NodeType)
+ {
+ case NhExpressionType.Average:
+ return VisitNhAverage((AverageExpression)expression);
+ case NhExpressionType.Min:
+ return VisitNhMin((MinExpression)expression);
+ case NhExpressionType.Max:
+ return VisitNhMax((MaxExpression)expression);
+ case NhExpressionType.Sum:
+ return VisitNhSum((SumExpression)expression);
+ case NhExpressionType.Count:
+ return VisitNhCount((CountExpression)expression);
+ case NhExpressionType.Distinct:
+ return VisitNhDistinct((DistinctExpression)expression);
+ }
+
+ return base.VisitExpression(expression);
+ }
+
+ // Default handling for DistinctExpression: delegates to VisitUnhandledItem with
+ // BaseVisitNhDistinct as the base behaviour. Presumably this reports the node as
+ // unsupported unless a subclass overrides this method - confirm against re-linq.
+ protected virtual Expression VisitNhDistinct(DistinctExpression expression)
+ {
+ return VisitUnhandledItem<DistinctExpression, Expression>(expression, "VisitNhDistinct", BaseVisitNhDistinct);
+ }
+
+ // Default handling for AverageExpression: delegates to VisitUnhandledItem with
+ // BaseVisitNhAverage as the base behaviour. Presumably this reports the node as
+ // unsupported unless a subclass overrides this method - confirm against re-linq.
+ protected virtual Expression VisitNhAverage(AverageExpression expression)
+ {
+ return VisitUnhandledItem<AverageExpression, Expression>(expression, "VisitNhAverage", BaseVisitNhAverage);
+ }
+
+ // Default handling for MinExpression: delegates to VisitUnhandledItem with
+ // BaseVisitNhMin as the base behaviour. Presumably this reports the node as
+ // unsupported unless a subclass overrides this method - confirm against re-linq.
+ protected virtual Expression VisitNhMin(MinExpression expression)
+ {
+ return VisitUnhandledItem<MinExpression, Expression>(expression, "VisitNhMin", BaseVisitNhMin);
+ }
+
+ // Default handling for MaxExpression: delegates to VisitUnhandledItem with
+ // BaseVisitNhMax as the base behaviour. Presumably this reports the node as
+ // unsupported unless a subclass overrides this method - confirm against re-linq.
+ protected virtual Expression VisitNhMax(MaxExpression expression)
+ {
+ return VisitUnhandledItem<MaxExpression, Expression>(expression, "VisitNhMax", BaseVisitNhMax);
+ }
+
+ // Default handling for SumExpression: delegates to VisitUnhandledItem with
+ // BaseVisitNhSum as the base behaviour. Presumably this reports the node as
+ // unsupported unless a subclass overrides this method - confirm against re-linq.
+ protected virtual Expression VisitNhSum(SumExpression expression)
+ {
+ return VisitUnhandledItem<SumExpression, Expression>(expression, "VisitNhSum", BaseVisitNhSum);
+ }
+
+ protected virtual Expression VisitNhCount(CountEx...
[truncated message content] |