|
From: <ste...@us...> - 2009-11-06 22:50:31
|
Revision: 4822
http://nhibernate.svn.sourceforge.net/nhibernate/?rev=4822&view=rev
Author: steverstrong
Date: 2009-11-06 22:50:19 +0000 (Fri, 06 Nov 2009)
Log Message:
-----------
Further LINQ work: support for Aggregate (Queryable.Aggregate, evaluated client-side via Enumerable.Aggregate) and for the string methods StartsWith, EndsWith, Contains and Equals.
Modified Paths:
--------------
trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs
trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeNode.cs
trunk/nhibernate/src/NHibernate/Linq/EnumerableHelper.cs
trunk/nhibernate/src/NHibernate/Linq/HqlNodeStack.cs
trunk/nhibernate/src/NHibernate/Linq/NhLinqExpression.cs
trunk/nhibernate/src/NHibernate/Linq/ResultTransformer.cs
trunk/nhibernate/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionTreeVisitor.cs
trunk/nhibernate/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs
trunk/nhibernate/src/NHibernate.Test/NHibernate.Test.csproj
Added Paths:
-----------
trunk/nhibernate/src/NHibernate.Test/Linq/AggregateTests.cs
Modified: trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs 2009-11-05 17:46:33 UTC (rev 4821)
+++ trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs 2009-11-06 22:50:19 UTC (rev 4822)
@@ -383,5 +383,20 @@
{
return new HqlAll(_factory);
}
+
+ /// <summary>Creates a new HqlLike node using this builder's factory.</summary>
+ public HqlLike Like()
+ {
+ return new HqlLike(_factory);
+ }
+
+ /// <summary>
+ /// Creates a new HqlConcat (method-call) node. The "concat" identifier and the argument list
+ /// still have to be added as children by the caller.
+ /// </summary>
+ public HqlConcat Concat()
+ {
+ return new HqlConcat(_factory);
+ }
+
+ /// <summary>Creates a new HqlExpressionList node using this builder's factory.</summary>
+ public HqlExpressionList ExpressionList()
+ {
+ return new HqlExpressionList(_factory);
+ }
}
}
\ No newline at end of file
Modified: trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeNode.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeNode.cs 2009-11-05 17:46:33 UTC (rev 4821)
+++ trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeNode.cs 2009-11-06 22:50:19 UTC (rev 4822)
@@ -632,4 +632,18 @@
{
}
}
+
+ /// <summary>HQL <c>like</c> node (token HqlSqlWalker.LIKE, text "like").</summary>
+ public class HqlLike : HqlTreeNode
+ {
+ public HqlLike(IASTFactory factory) : base(HqlSqlWalker.LIKE, "like", factory)
+ {
+ }
+ }
+
+ /// <summary>
+ /// Method-call node (token HqlSqlWalker.METHOD_CALL, text "method") used to express
+ /// <c>concat(...)</c>. It adds no children itself: the "concat" identifier and the argument
+ /// list must be pushed beneath it by the caller.
+ /// </summary>
+ public class HqlConcat : HqlTreeNode
+ {
+ public HqlConcat(IASTFactory factory) : base(HqlSqlWalker.METHOD_CALL, "method", factory)
+ {
+ }
+ }
}
\ No newline at end of file
Modified: trunk/nhibernate/src/NHibernate/Linq/EnumerableHelper.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/EnumerableHelper.cs 2009-11-05 17:46:33 UTC (rev 4821)
+++ trunk/nhibernate/src/NHibernate/Linq/EnumerableHelper.cs 2009-11-06 22:50:19 UTC (rev 4822)
@@ -1,8 +1,28 @@
+using System;
using System.Linq;
+using System.Linq.Expressions;
using System.Reflection;
namespace NHibernate.Linq
{
+ /// <summary>
+ /// Extracts <see cref="MethodInfo"/> instances from strongly-typed lambdas such as
+ /// <c>() =&gt; Queryable.Any&lt;object&gt;(null)</c>, avoiding string-based reflection lookups.
+ /// Generic methods are always returned as their generic method definition.
+ /// </summary>
+ public static class ReflectionHelper
+ {
+ /// <summary>Parameterless delegate used as the lambda type of the non-generic GetMethod overload.</summary>
+ public delegate void Action();
+
+ /// <summary>Returns the instance method called in the lambda body, e.g. <c>x =&gt; x.StartsWith(null)</c>.</summary>
+ /// <exception cref="ArgumentNullException"><paramref name="method"/> is null.</exception>
+ /// <exception cref="ArgumentException">The lambda body is not a method call.</exception>
+ public static MethodInfo GetMethod<TSource>(Expression<Action<TSource>> method)
+ {
+ return GetCalledMethod(method);
+ }
+
+ /// <summary>Returns the (typically static) method called in the body of a parameterless lambda.</summary>
+ /// <exception cref="ArgumentNullException"><paramref name="method"/> is null.</exception>
+ /// <exception cref="ArgumentException">The lambda body is not a method call.</exception>
+ public static MethodInfo GetMethod(Expression<Action> method)
+ {
+ return GetCalledMethod(method);
+ }
+
+ // Shared implementation. Validates the lambda shape up front instead of failing later with an
+ // opaque InvalidCastException / NullReferenceException.
+ private static MethodInfo GetCalledMethod(LambdaExpression method)
+ {
+ if (method == null)
+ {
+ throw new ArgumentNullException("method");
+ }
+
+ var call = method.Body as MethodCallExpression;
+ if (call == null)
+ {
+ throw new ArgumentException("The lambda body must be a single method call expression.", "method");
+ }
+
+ var methodInfo = call.Method;
+ return methodInfo.IsGenericMethod ? methodInfo.GetGenericMethodDefinition() : methodInfo;
+ }
+ }
+
+ // TODO rename / remove - reflection helper above is better
public static class EnumerableHelper
{
public static MethodInfo GetMethod(string name, System.Type[] parameterTypes)
Modified: trunk/nhibernate/src/NHibernate/Linq/HqlNodeStack.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/HqlNodeStack.cs 2009-11-05 17:46:33 UTC (rev 4821)
+++ trunk/nhibernate/src/NHibernate/Linq/HqlNodeStack.cs 2009-11-06 22:50:19 UTC (rev 4822)
@@ -11,6 +11,7 @@
 public HqlNodeStack(HqlTreeBuilder builder)
 {
+ // Creates the root (holder) node via the builder and pushes it onto the stack.
+ // TODO - the builder is only needed here to create the root node. Change this.
 _root = builder.Holder();
 _stack.Push(_root);
 }
Modified: trunk/nhibernate/src/NHibernate/Linq/NhLinqExpression.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/NhLinqExpression.cs 2009-11-05 17:46:33 UTC (rev 4821)
+++ trunk/nhibernate/src/NHibernate/Linq/NhLinqExpression.cs 2009-11-06 22:50:19 UTC (rev 4822)
@@ -1,16 +1,36 @@
-using System.Collections.Generic;
+using System;
+using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using NHibernate.Engine.Query;
using NHibernate.Hql.Ast.ANTLR.Tree;
+using NHibernate.Linq.ResultOperators;
using NHibernate.Linq.Visitors;
+using Remotion.Data.Linq.Clauses;
+using Remotion.Data.Linq.Clauses.StreamedData;
using Remotion.Data.Linq.Parsing.ExpressionTreeVisitors;
using Remotion.Data.Linq.Parsing.Structure;
+using Remotion.Data.Linq.Parsing.Structure.IntermediateModel;
namespace NHibernate.Linq
{
public class NhLinqExpression : IQueryExpression
{
+ // re-linq's default method-call registry, extended once by the static constructor below and
+ // shared by every Translate call.
+ private static readonly MethodCallExpressionNodeTypeRegistry MethodCallRegistry =
+ MethodCallExpressionNodeTypeRegistry.CreateDefault();
+
+ // Registers Queryable.Aggregate without a seed and with a seed, so both are parsed into an
+ // AggregateExpressionNode.
+ // NOTE(review): the result-selector overload Aggregate<TSource, TAccumulate, TResult> is not
+ // registered, so queries using it are not routed to AggregateExpressionNode.
+ static NhLinqExpression()
+ {
+ MethodCallRegistry.Register(
+ new[]
+ {
+ MethodCallExpressionNodeTypeRegistry.GetRegisterableMethodDefinition(ReflectionHelper.GetMethod(() => Queryable.Aggregate<object>(null, null))),
+ MethodCallExpressionNodeTypeRegistry.GetRegisterableMethodDefinition(ReflectionHelper.GetMethod(() => Queryable.Aggregate<object, object>(null, null, null)))
+ },
+ typeof (AggregateExpressionNode));
+ }
+
+
private readonly Expression _expression;
private CommandData _commandData;
private readonly IDictionary<ConstantExpression, NamedParameter> _queryParameters;
@@ -40,7 +60,8 @@
public IASTNode Translate(ISessionFactory sessionFactory)
{
var requiredHqlParameters = new List<NamedParameterDescriptor>();
- var queryModel = new QueryParser(new ExpressionTreeParser(MethodCallExpressionNodeTypeRegistry.CreateDefault())).GetParsedQuery(_expression);
+ // TODO - can we cache any of this?
+ var queryModel = new QueryParser(new ExpressionTreeParser(MethodCallRegistry)).GetParsedQuery(_expression);
_commandData = QueryModelVisitor.GenerateHqlQuery(queryModel, _queryParameters, requiredHqlParameters);
@@ -67,4 +88,65 @@
_commandData.AddAdditionalCriteria(impl);
}
}
+
+ /// <summary>
+ /// re-linq expression node for Queryable.Aggregate. It becomes an AggregateResultOperator, which
+ /// is evaluated on the client against the query results.
+ /// </summary>
+ public class AggregateExpressionNode : ResultOperatorExpressionNodeBase
+ {
+ public MethodCallExpressionParseInfo ParseInfo { get; set; }
+ // The seed; null for the Aggregate(func) overload.
+ public Expression OptionalSeed { get; set; }
+ // The accumulator, Func<TAccumulate, TSource, TAccumulate>.
+ public LambdaExpression Accumulator { get; set; }
+ // Result selector applied to the final accumulator value; null when not supplied.
+ public LambdaExpression OptionalSelector { get; set; }
+
+ // arg1/arg2 are the call arguments after the source: (func) or (seed, func).
+ // The result selector is deliberately NOT passed to the base class. It applies to the final
+ // accumulator value, not to the source items, so it is kept in OptionalSelector and applied
+ // client-side instead.
+ // NOTE(review): assumes the base class treats its optionalSelector as a per-item projection;
+ // confirm against re-linq.
+ public AggregateExpressionNode(MethodCallExpressionParseInfo parseInfo, Expression arg1, Expression arg2, LambdaExpression optionalSelector) : base(parseInfo, null, null)
+ {
+ ParseInfo = parseInfo;
+
+ if (arg2 != null)
+ {
+ // Aggregate(seed, func)
+ OptionalSeed = arg1;
+ Accumulator = (LambdaExpression) arg2;
+ }
+ else
+ {
+ // Aggregate(func)
+ Accumulator = (LambdaExpression) arg1;
+ }
+
+ OptionalSelector = optionalSelector;
+ }
+
+ // Not implemented yet.
+ public override Expression Resolve(ParameterExpression inputParameter, Expression expressionToBeResolved, ClauseGenerationContext clauseGenerationContext)
+ {
+ throw new NotImplementedException();
+ }
+
+ protected override ResultOperatorBase CreateResultOperator(ClauseGenerationContext clauseGenerationContext)
+ {
+ return new AggregateResultOperator(ParseInfo, OptionalSeed, Accumulator, OptionalSelector);
+ }
+ }
+
+ /// <summary>
+ /// Result operator carrying the pieces of an Aggregate call (seed, accumulator, optional result
+ /// selector). The query model visitor turns it into a client-side Enumerable.Aggregate over the
+ /// query results.
+ /// </summary>
+ public class AggregateResultOperator : ClientSideTransformOperator
+ {
+ public MethodCallExpressionParseInfo ParseInfo { get; set; }
+ // Null when the overload without a seed was used.
+ public Expression OptionalSeed { get; set; }
+ public LambdaExpression Accumulator { get; set; }
+ // Null unless a result selector was supplied.
+ public LambdaExpression OptionalSelector { get; set; }
+
+ public AggregateResultOperator(MethodCallExpressionParseInfo parseInfo, Expression optionalSeed, LambdaExpression accumulator, LambdaExpression optionalSelector)
+ {
+ ParseInfo = parseInfo;
+ OptionalSeed = optionalSeed;
+ Accumulator = accumulator;
+ OptionalSelector = optionalSelector;
+ }
+
+ // Not implemented yet: calling this throws.
+ public override IStreamedDataInfo GetOutputDataInfo(IStreamedDataInfo inputInfo)
+ {
+ throw new NotImplementedException();
+ }
+
+ // Not implemented yet: cloning a query model containing this operator throws.
+ public override ResultOperatorBase Clone(CloneContext cloneContext)
+ {
+ throw new NotImplementedException();
+ }
+ }
}
\ No newline at end of file
Modified: trunk/nhibernate/src/NHibernate/Linq/ResultTransformer.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/ResultTransformer.cs 2009-11-05 17:46:33 UTC (rev 4821)
+++ trunk/nhibernate/src/NHibernate/Linq/ResultTransformer.cs 2009-11-06 22:50:19 UTC (rev 4822)
@@ -1,7 +1,9 @@
using System;
using System.Collections;
+using System.Linq;
using System.Linq.Expressions;
using NHibernate.Transform;
+using Remotion.Collections;
namespace NHibernate.Linq
{
@@ -30,7 +32,38 @@
public IList TransformList(IList collection)
{
- return _listTransformation == null ? collection : (IList) _listTransformation.DynamicInvoke(collection);
+ if (_listTransformation == null)
+ {
+ return collection;
+ }
+
+ object transformResult = collection;
+
+ if (collection.Count > 0)
+ {
+ if (collection[0] is object[])
+ {
+ if ( ((object[])collection[0]).Length != 1)
+ {
+ // We only expect single items
+ throw new NotSupportedException();
+ }
+
+ transformResult = _listTransformation.DynamicInvoke(collection.Cast<object[]>().Select(o => o[0]));
+ }
+ else
+ {
+ transformResult = _listTransformation.DynamicInvoke(collection);
+ }
+ }
+
+ if (transformResult is IList)
+ {
+ return (IList) transformResult;
+ }
+
+ var list = new ArrayList {transformResult};
+ return list;
}
}
}
\ No newline at end of file
Modified: trunk/nhibernate/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionTreeVisitor.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionTreeVisitor.cs 2009-11-05 17:46:33 UTC (rev 4821)
+++ trunk/nhibernate/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionTreeVisitor.cs 2009-11-06 22:50:19 UTC (rev 4822)
@@ -1,11 +1,13 @@
using System;
using System.Collections.Generic;
+using System.Collections.ObjectModel;
using System.Linq;
using System.Linq.Expressions;
+using System.Reflection;
using NHibernate.Engine.Query;
using NHibernate.Hql.Ast;
using NHibernate.Linq.Expressions;
-using NHibernate.Linq.ReWriters;
+using Remotion.Data.Linq;
using Remotion.Data.Linq.Clauses.Expressions;
using Remotion.Data.Linq.Clauses.ExpressionTreeVisitors;
@@ -17,8 +19,9 @@
protected readonly HqlNodeStack _stack;
private readonly IDictionary<ConstantExpression, NamedParameter> _parameters;
private readonly IList<NamedParameterDescriptor> _requiredHqlParameters;
+ static private readonly MethodGeneratorRegistry _methodGeneratorRegistry = MethodGeneratorRegistry.Initialise();
- public HqlGeneratorExpressionTreeVisitor(IDictionary<ConstantExpression, NamedParameter> parameters, IList<NamedParameterDescriptor> requiredHqlParameters)
+ public HqlGeneratorExpressionTreeVisitor(IDictionary<ConstantExpression, NamedParameter> parameters, IList<NamedParameterDescriptor> requiredHqlParameters)
{
_parameters = parameters;
_requiredHqlParameters = requiredHqlParameters;
@@ -36,6 +39,16 @@
VisitExpression(expression);
}
+ /// <summary>The node stack HQL nodes are pushed onto; exposed for IHqlGeneratorForMethod implementations.</summary>
+ public HqlNodeStack Stack
+ {
+ get { return _stack; }
+ }
+
+ /// <summary>The builder used to create HQL nodes; exposed for IHqlGeneratorForMethod implementations.</summary>
+ public HqlTreeBuilder TreeBuilder
+ {
+ get { return _hqlTreeBuilder; }
+ }
+
protected override Expression VisitNhAverage(NhAverageExpression expression)
{
var visitor = new HqlGeneratorExpressionTreeVisitor(_parameters, _requiredHqlParameters);
@@ -238,106 +251,11 @@
 protected override Expression VisitMethodCallExpression(MethodCallExpression expression)
 {
+ // Translation is delegated to the generator registered for the called method. The registry
+ // throws NotSupportedException for methods without a generator.
- if (expression.Method.DeclaringType == typeof(Enumerable) ||
- expression.Method.DeclaringType == typeof(Queryable))
- {
- switch (expression.Method.Name)
- {
- case "Any":
- // Any has one or two arguments. Arg 1 is the source and arg 2 is the optional predicate
- using (_stack.PushNode(_hqlTreeBuilder.Exists()))
- {
- using (_stack.PushNode(_hqlTreeBuilder.Query()))
- {
- using (_stack.PushNode(_hqlTreeBuilder.SelectFrom()))
- {
- using (_stack.PushNode(_hqlTreeBuilder.From()))
- {
- using (_stack.PushNode(_hqlTreeBuilder.Range()))
- {
- VisitExpression(expression.Arguments[0]);
+ var generator = _methodGeneratorRegistry.GetMethodGenerator(expression.Method);
- if (expression.Arguments.Count > 1)
- {
- var expr = (LambdaExpression) expression.Arguments[1];
- _stack.PushLeaf(_hqlTreeBuilder.Alias(expr.Parameters[0].Name));
- }
- }
- }
- }
- if (expression.Arguments.Count > 1)
- {
- using (_stack.PushNode(_hqlTreeBuilder.Where()))
- {
- VisitExpression(expression.Arguments[1]);
- }
- }
- }
- }
- break;
+ generator.BuildHql(expression.Object, expression.Arguments, this);
- case "All":
- // All has one or two arguments. Arg 1 is the source and arg 2 is the optional predicate
- using (_stack.PushNode(_hqlTreeBuilder.Not()))
- {
- using (_stack.PushNode(_hqlTreeBuilder.Exists()))
- {
- using (_stack.PushNode(_hqlTreeBuilder.Query()))
- {
- using (_stack.PushNode(_hqlTreeBuilder.SelectFrom()))
- {
- using (_stack.PushNode(_hqlTreeBuilder.From()))
- {
- using (_stack.PushNode(_hqlTreeBuilder.Range()))
- {
- VisitExpression(expression.Arguments[0]);
-
- if (expression.Arguments.Count > 1)
- {
- var expr = (LambdaExpression) expression.Arguments[1];
-
- _stack.PushLeaf(_hqlTreeBuilder.Alias(expr.Parameters[0].Name));
- }
- }
- }
- }
- if (expression.Arguments.Count > 1)
- {
- using (_stack.PushNode(_hqlTreeBuilder.Where()))
- {
- using (_stack.PushNode(_hqlTreeBuilder.Not()))
- {
- VisitExpression(expression.Arguments[1]);
- }
- }
- }
- }
- }
- }
- break;
-
- case "Min":
- using (_stack.PushNode(_hqlTreeBuilder.Min()))
- {
- VisitExpression(expression.Arguments[1]);
- }
- break;
- case "Max":
- using (_stack.PushNode(_hqlTreeBuilder.Max()))
- {
- VisitExpression(expression.Arguments[1]);
- }
- break;
- default:
- throw new NotSupportedException(string.Format("The Enumerable method {0} is not supported", expression.Method.Name));
- }
-
- return expression;
- }
- else
- {
- return base.VisitMethodCallExpression(expression); // throws
- }
+ // The generator has pushed the HQL nodes; the expression itself is returned unchanged.
+ return expression;
 }
protected override Expression VisitLambdaExpression(LambdaExpression expression)
@@ -401,4 +319,352 @@
return itemAsExpression != null ? FormattingExpressionTreeVisitor.Format(itemAsExpression) : unhandledItem.ToString();
}
}
+
+ /// <summary>
+ /// Maps CLR methods to the IHqlGeneratorForMethod that translates calls to them into HQL.
+ /// </summary>
+ public class MethodGeneratorRegistry
+ {
+ /// <summary>Creates a registry populated with the Queryable/Enumerable and string method generators.</summary>
+ public static MethodGeneratorRegistry Initialise()
+ {
+ var registry = new MethodGeneratorRegistry();
+
+ // TODO - could use reflection here
+ registry.Register(new QueryableMethodsGenerator());
+ registry.Register(new StringMethodsGenerator());
+
+ return registry;
+ }
+
+ private readonly Dictionary<MethodInfo, IHqlGeneratorForMethod> _registeredMethods = new Dictionary<MethodInfo, IHqlGeneratorForMethod>();
+
+ /// <summary>
+ /// Returns the generator registered for <paramref name="method"/>. Closed generic methods are
+ /// looked up by their generic method definition.
+ /// </summary>
+ /// <exception cref="NotSupportedException">No generator is registered for the method.</exception>
+ public IHqlGeneratorForMethod GetMethodGenerator(MethodInfo method)
+ {
+ if (method.IsGenericMethod)
+ {
+ method = method.GetGenericMethodDefinition();
+ }
+
+ IHqlGeneratorForMethod methodGenerator;
+ if (_registeredMethods.TryGetValue(method, out methodGenerator))
+ {
+ return methodGenerator;
+ }
+
+ // Name the offending method so users can tell which part of their query is unsupported.
+ throw new NotSupportedException(string.Format("The method {0}.{1} is not supported", method.DeclaringType, method.Name));
+ }
+
+ /// <summary>Registers <paramref name="generator"/> for <paramref name="method"/>. Registering the same method twice throws ArgumentException.</summary>
+ public void RegisterMethodGenerator(MethodInfo method, IHqlGeneratorForMethod generator)
+ {
+ _registeredMethods.Add(method, generator);
+ }
+
+ // Lets a per-type generator register each of its method generators with this registry.
+ private void Register(IHqlGeneratorForType typeMethodGenerator)
+ {
+ typeMethodGenerator.RegisterMethods(this);
+ }
+
+ }
+
+ /// <summary>Emits HQL for calls to one or more CLR methods.</summary>
+ public interface IHqlGeneratorForMethod
+ {
+ // Methods handled by this generator; generic methods are registered and looked up by their
+ // generic method definition.
+ IEnumerable<MethodInfo> SupportedMethods { get; }
+ // targetObject is the instance the method is called on (MethodCallExpression.Object, null for
+ // static methods); arguments are the call arguments. Nodes are pushed via hqlVisitor.Stack.
+ void BuildHql(Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor);
+ }
+
+ /// <summary>Registers a group of method generators (e.g. all string methods) with a registry.</summary>
+ public interface IHqlGeneratorForType
+ {
+ void RegisterMethods(MethodGeneratorRegistry methodGeneratorRegistry);
+ }
+
+ /// <summary>
+ /// Base for per-type generator groups. Subclasses add their method generators to MethodRegistry;
+ /// RegisterMethods then registers each generator under every method it supports.
+ /// </summary>
+ public abstract class BaseHqlGeneratorForType : IHqlGeneratorForType
+ {
+ // Populated by subclass constructors.
+ protected readonly List<IHqlGeneratorForMethod> MethodRegistry = new List<IHqlGeneratorForMethod>();
+
+ public void RegisterMethods(MethodGeneratorRegistry methodGeneratorRegistry)
+ {
+ foreach (var generator in MethodRegistry)
+ {
+ foreach (var method in generator.SupportedMethods)
+ {
+ methodGeneratorRegistry.RegisterMethodGenerator(method, generator);
+ }
+ }
+ }
+ }
+
+ /// <summary>
+ /// Base for single-method generators; subclasses assign SupportedMethods in their constructor
+ /// and implement BuildHql.
+ /// </summary>
+ public abstract class BaseHqlGeneratorForMethod : IHqlGeneratorForMethod
+ {
+ public IEnumerable<MethodInfo> SupportedMethods { get; protected set; }
+
+ public abstract void BuildHql(Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor);
+ }
+
+ /// <summary>
+ /// HQL generators for System.String instance methods. StartsWith, EndsWith and Contains become
+ /// LIKE against a concatenated pattern; Equals becomes an equality comparison.
+ /// </summary>
+ public class StringMethodsGenerator : BaseHqlGeneratorForType
+ {
+ public StringMethodsGenerator()
+ {
+ // TODO - could use reflection
+ MethodRegistry.Add(new StartsWithGenerator());
+ MethodRegistry.Add(new EndsWithGenerator());
+ MethodRegistry.Add(new ContainsGenerator());
+ MethodRegistry.Add(new EqualsGenerator());
+ }
+
+ // Shared by the StartsWith/EndsWith/Contains generators. Emits
+ // target like concat([ '%', ] pattern [, '%' ]), with the wildcards controlled by the flags.
+ // NOTE(review): '%' and '_' inside the pattern are not escaped, so they act as LIKE wildcards
+ // rather than literal characters. Confirm whether escaping is needed.
+ private static void BuildLike(Expression targetObject, Expression pattern, HqlGeneratorExpressionTreeVisitor hqlVisitor, bool leadingWildcard, bool trailingWildcard)
+ {
+ using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Like()))
+ {
+ hqlVisitor.Visit(targetObject);
+
+ // TODO - Concat() only pushes a method node, so the "concat" identifier and the argument
+ // list have to be added here. Sort out the tree stuff so it works properly.
+ using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Concat()))
+ {
+ hqlVisitor.Stack.PushLeaf(hqlVisitor.TreeBuilder.Ident("concat"));
+
+ using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.ExpressionList()))
+ {
+ if (leadingWildcard)
+ {
+ hqlVisitor.Stack.PushLeaf(hqlVisitor.TreeBuilder.Constant("%"));
+ }
+
+ hqlVisitor.Visit(pattern);
+
+ if (trailingWildcard)
+ {
+ hqlVisitor.Stack.PushLeaf(hqlVisitor.TreeBuilder.Constant("%"));
+ }
+ }
+ }
+ }
+ }
+
+ // x.StartsWith(value)  =>  x like concat(value, '%')
+ class StartsWithGenerator : BaseHqlGeneratorForMethod
+ {
+ public StartsWithGenerator()
+ {
+ SupportedMethods = new[] { ReflectionHelper.GetMethod<string>(x => x.StartsWith(null)) };
+ }
+
+ public override void BuildHql(Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor)
+ {
+ BuildLike(targetObject, arguments[0], hqlVisitor, false, true);
+ }
+ }
+
+ // x.EndsWith(value)  =>  x like concat('%', value)
+ class EndsWithGenerator : BaseHqlGeneratorForMethod
+ {
+ public EndsWithGenerator()
+ {
+ SupportedMethods = new[] { ReflectionHelper.GetMethod<string>(x => x.EndsWith(null)) };
+ }
+
+ public override void BuildHql(Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor)
+ {
+ BuildLike(targetObject, arguments[0], hqlVisitor, true, false);
+ }
+ }
+
+ // x.Contains(value)  =>  x like concat('%', value, '%')
+ class ContainsGenerator : BaseHqlGeneratorForMethod
+ {
+ public ContainsGenerator()
+ {
+ SupportedMethods = new[] { ReflectionHelper.GetMethod<string>(x => x.Contains(null)) };
+ }
+
+ public override void BuildHql(Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor)
+ {
+ BuildLike(targetObject, arguments[0], hqlVisitor, true, true);
+ }
+ }
+
+ // x.Equals(value)  =>  x = value
+ class EqualsGenerator : BaseHqlGeneratorForMethod
+ {
+ public EqualsGenerator()
+ {
+ SupportedMethods = new[] { ReflectionHelper.GetMethod<string>(x => x.Equals((string)null)) };
+ }
+
+ public override void BuildHql(Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor)
+ {
+ using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Equality()))
+ {
+ hqlVisitor.Visit(targetObject);
+
+ hqlVisitor.Visit(arguments[0]);
+ }
+ }
+ }
+ }
+
+ /// <summary>
+ /// HQL generators for Queryable/Enumerable methods used inside query expressions. Any/All become
+ /// (not) exists sub-queries; Min/Max become HQL aggregate functions over the selector.
+ /// </summary>
+ public class QueryableMethodsGenerator : BaseHqlGeneratorForType
+ {
+ public QueryableMethodsGenerator()
+ {
+ // TODO - could use reflection
+ MethodRegistry.Add(new AnyGenerator());
+ MethodRegistry.Add(new AllGenerator());
+ MethodRegistry.Add(new MinGenerator());
+ MethodRegistry.Add(new MaxGenerator());
+ }
+
+ // Every Queryable/Enumerable overload of the named method with a (source, selector) signature.
+ // This includes the typed-selector overloads such as Enumerable.Min(source, Func<TSource, int>).
+ private static MethodInfo[] SelectorOverloads(string methodName)
+ {
+ return typeof (Queryable).GetMethods()
+ .Concat(typeof (Enumerable).GetMethods())
+ .Where(m => m.Name == methodName && IsSourceAndSelector(m.GetParameters()))
+ .ToArray();
+ }
+
+ // True for two-parameter signatures whose second parameter is a delegate (Enumerable) or a lambda
+ // expression (Queryable). This excludes e.g. comparer overloads.
+ private static bool IsSourceAndSelector(ParameterInfo[] parameters)
+ {
+ if (parameters.Length != 2)
+ {
+ return false;
+ }
+
+ var selectorType = parameters[1].ParameterType;
+ return typeof (Delegate).IsAssignableFrom(selectorType) || typeof (LambdaExpression).IsAssignableFrom(selectorType);
+ }
+
+ // source.Any([predicate])  =>  exists (from source [alias] [where predicate])
+ class AnyGenerator : BaseHqlGeneratorForMethod
+ {
+ public AnyGenerator()
+ {
+ SupportedMethods = new[]
+ {
+ ReflectionHelper.GetMethod(() => Queryable.Any<object>(null)),
+ ReflectionHelper.GetMethod(() => Queryable.Any<object>(null, null)),
+ ReflectionHelper.GetMethod(() => Enumerable.Any<object>(null)),
+ ReflectionHelper.GetMethod(() => Enumerable.Any<object>(null, null))
+ };
+ }
+
+ public override void BuildHql(Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor)
+ {
+ // Any has one or two arguments. Arg 1 is the source and arg 2 is the optional predicate
+ using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Exists()))
+ {
+ using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Query()))
+ {
+ using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.SelectFrom()))
+ {
+ using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.From()))
+ {
+ using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Range()))
+ {
+ hqlVisitor.Visit(arguments[0]);
+
+ if (arguments.Count > 1)
+ {
+ // The predicate's parameter name becomes the range alias.
+ var expr = (LambdaExpression)arguments[1];
+ hqlVisitor.Stack.PushLeaf(hqlVisitor.TreeBuilder.Alias(expr.Parameters[0].Name));
+ }
+ }
+ }
+ }
+ if (arguments.Count > 1)
+ {
+ using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Where()))
+ {
+ hqlVisitor.Visit(arguments[1]);
+ }
+ }
+ }
+ }
+ }
+ }
+
+ // source.All(predicate)  =>  not exists (from source alias where not predicate)
+ class AllGenerator : BaseHqlGeneratorForMethod
+ {
+ public AllGenerator()
+ {
+ SupportedMethods = new[]
+ {
+ ReflectionHelper.GetMethod(() => Queryable.All<object>(null, null)),
+ ReflectionHelper.GetMethod(() => Enumerable.All<object>(null, null))
+ };
+ }
+
+ public override void BuildHql(Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor)
+ {
+ // All has two arguments. Arg 1 is the source and arg 2 is the predicate
+ using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Not()))
+ {
+ using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Exists()))
+ {
+ using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Query()))
+ {
+ using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.SelectFrom()))
+ {
+ using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.From()))
+ {
+ using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Range()))
+ {
+ hqlVisitor.Visit(arguments[0]);
+
+ var expr = (LambdaExpression)arguments[1];
+
+ hqlVisitor.Stack.PushLeaf(hqlVisitor.TreeBuilder.Alias(expr.Parameters[0].Name));
+ }
+ }
+ }
+
+ using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Where()))
+ {
+ using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Not()))
+ {
+ hqlVisitor.Visit(arguments[1]);
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ // source.Min(selector)  =>  min(selector)
+ class MinGenerator : BaseHqlGeneratorForMethod
+ {
+ public MinGenerator()
+ {
+ // Only (source, selector) overloads are registered. BuildHql translates arguments[1], which
+ // the source-only Min(source) overload does not have.
+ SupportedMethods = SelectorOverloads("Min");
+ }
+
+ public override void BuildHql(Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor)
+ {
+ using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Min()))
+ {
+ hqlVisitor.Visit(arguments[1]);
+ }
+ }
+ }
+
+ // source.Max(selector)  =>  max(selector)
+ class MaxGenerator : BaseHqlGeneratorForMethod
+ {
+ public MaxGenerator()
+ {
+ // Only (source, selector) overloads are registered. BuildHql translates arguments[1], which
+ // the source-only Max(source) overload does not have.
+ SupportedMethods = SelectorOverloads("Max");
+ }
+
+ public override void BuildHql(Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor)
+ {
+ using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Max()))
+ {
+ hqlVisitor.Visit(arguments[1]);
+ }
+ }
+ }
+ }
}
\ No newline at end of file
Modified: trunk/nhibernate/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs 2009-11-05 17:46:33 UTC (rev 4821)
+++ trunk/nhibernate/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs 2009-11-06 22:50:19 UTC (rev 4822)
@@ -214,13 +214,70 @@
{
ProcessClientSideSelect((ClientSideSelect) resultOperator);
}
- else
+ else if (resultOperator is AggregateResultOperator)
+ {
+ ProcessAggregateOperator((AggregateResultOperator)resultOperator);
+ }
+ else
{
throw new NotSupportedException(string.Format("The {0} result operator is not current supported",
resultOperator.GetType().Name));
}
}
+ // Adds a client-side list transformation that runs Enumerable.Aggregate over the query results,
+ // mirroring the Queryable.Aggregate overload used in the query.
+ private void ProcessAggregateOperator(AggregateResultOperator resultOperator)
+ {
+ // The accumulator is Func<TAccumulate, TSource, TAccumulate>.
+ var inputType = resultOperator.Accumulator.Parameters[1].Type;
+ var accumulatorType = resultOperator.Accumulator.Parameters[0].Type;
+ var inputList = Expression.Parameter(typeof(IEnumerable<object>), "inputList");
+
+ // inputList.Cast<TSource>()
+ var castToItem = EnumerableHelper.GetMethod("Cast", new[] { typeof(IEnumerable) }, new[] { inputType });
+ var castToItemExpr = Expression.Call(castToItem, inputList);
+
+ MethodCallExpression call;
+
+ if (resultOperator.OptionalSeed == null)
+ {
+ // source.Aggregate(func)
+ var aggregate = ReflectionHelper.GetMethod(() => Enumerable.Aggregate<object>(null, null))
+ .MakeGenericMethod(inputType);
+
+ call = Expression.Call(
+ aggregate,
+ castToItemExpr,
+ resultOperator.Accumulator
+ );
+ }
+ else if (resultOperator.OptionalSelector == null)
+ {
+ // source.Aggregate(seed, func)
+ var aggregate = ReflectionHelper.GetMethod(() => Enumerable.Aggregate<object, object>(null, null, null))
+ .MakeGenericMethod(inputType, accumulatorType);
+
+ call = Expression.Call(
+ aggregate,
+ castToItemExpr,
+ resultOperator.OptionalSeed,
+ resultOperator.Accumulator
+ );
+ }
+ else
+ {
+ // source.Aggregate(seed, func, resultSelector). The selector is Func<TAccumulate, TResult>,
+ // so TResult is the type of its body. The delegate type has only two generic arguments,
+ // so it must not be indexed at [2].
+ var selectorType = resultOperator.OptionalSelector.Body.Type;
+ var aggregate = ReflectionHelper.GetMethod(() => Enumerable.Aggregate<object, object, object>(null, null, null, null))
+ .MakeGenericMethod(inputType, accumulatorType, selectorType);
+
+ call = Expression.Call(
+ aggregate,
+ castToItemExpr,
+ resultOperator.OptionalSeed,
+ resultOperator.Accumulator,
+ resultOperator.OptionalSelector
+ );
+ }
+
+ _listTransformers.Add(Expression.Lambda(call, inputList));
+ }
+
private void ProcessClientSideSelect(ClientSideSelect resultOperator)
{
var inputType = resultOperator.SelectClause.Parameters[0].Type;
@@ -251,13 +308,6 @@
private void ProcessNonAggregatingGroupBy(NonAggregatingGroupBy resultOperator, QueryModel model)
{
- /*
- public static IQueryable<IGrouping<TKey, TElement>> GroupBy<TSource, TKey, TElement>(
- this IQueryable<TSource> source,
- Expression<Func<TSource, TKey>> keySelector,
- Expression<Func<TSource, TElement>> elementSelector);
- */
-
var tSource = model.SelectClause.Selector.Type;
var tKey = resultOperator.GroupBy.KeySelector.Type;
var tElement = resultOperator.GroupBy.ElementSelector.Type;
@@ -305,32 +355,8 @@
LambdaExpression elementSelectorExpr = Expression.Lambda(elementSelector, itemParam);
- ParameterExpression objectArrayParam = Expression.Parameter(typeof (object[]), "array");
+ Expression castToItemExpr = Expression.Call(castToItem, listParameter);
- Expression castToItemExpr;
- if (_itemTransformers.Count > 0)
- {
- // The item transformer will already have removed the object[]. We need to do this:
- // list.Cast<tSource>().GroupBy(keySelectorExpr, elementSelectorExpr).ToList();
- castToItemExpr = Expression.Call(castToItem, listParameter);
- }
- else
- {
- // We have an object[] in the ResultTransformer
- // list.Cast<object[]>().Select(o => o[0]).Cast<tSource>().GroupBy(keySelectorExpr, elementSelectorExpr).ToList();
- var castToObjectArray = EnumerableHelper.GetMethod("Cast", new[] { typeof(IEnumerable) }, new[] { typeof(object[]) });
- var selectObject = EnumerableHelper.GetMethod("Select", new[] { typeof(IEnumerable<>), typeof(Func<,>) }, new[] { typeof(object[]), typeof(object) });
-
- var castToObjectArrayExpr = Expression.Call(castToObjectArray, listParameter);
-
- LambdaExpression index = Expression.Lambda(Expression.ArrayIndex(objectArrayParam, Expression.Constant(0)),
- objectArrayParam);
-
- var selectObjectExpr = Expression.Call(selectObject, castToObjectArrayExpr, index);
-
- castToItemExpr = Expression.Call(castToItem, selectObjectExpr);
- }
-
var groupByExpr = Expression.Call(groupByMethod, castToItemExpr, keySelectorExpr, elementSelectorExpr);
var toListExpr = Expression.Call(toList, groupByExpr);
Added: trunk/nhibernate/src/NHibernate.Test/Linq/AggregateTests.cs
===================================================================
--- trunk/nhibernate/src/NHibernate.Test/Linq/AggregateTests.cs (rev 0)
+++ trunk/nhibernate/src/NHibernate.Test/Linq/AggregateTests.cs 2009-11-06 22:50:19 UTC (rev 4822)
@@ -0,0 +1,139 @@
+using System;
+using System.Linq;
+using System.Text;
+using NUnit.Framework;
+
+namespace NHibernate.Test.Linq
+{
+ [TestFixture]
+ public class AggregateTests : LinqTestCase
+ {
+ // The queries below have no ORDER BY, so the database may return rows in any order. Compare
+ // the aggregated ids as a set instead of relying on a particular concatenation order.
+ private static void AssertIds(string aggregated, params string[] expectedIds)
+ {
+ var actualIds = aggregated.Split(new[] {','}, StringSplitOptions.RemoveEmptyEntries);
+ CollectionAssert.AreEquivalent(expectedIds, actualIds);
+ }
+
+ [Test]
+ public void AggregateWithStartsWith()
+ {
+ var query = (from c in db.Customers where c.CustomerId.StartsWith("A") select c.CustomerId)
+ .Aggregate(new StringBuilder(), (sb, id) => sb.Append(id).Append(","));
+
+ Console.WriteLine(query);
+ AssertIds(query.ToString(), "ALFKI", "ANATR", "ANTON", "AROUT");
+ }
+
+ [Test]
+ public void AggregateWithEndsWith()
+ {
+ var query = (from c in db.Customers where c.CustomerId.EndsWith("TH") select c.CustomerId)
+ .Aggregate(new StringBuilder(), (sb, id) => sb.Append(id).Append(","));
+
+ Console.WriteLine(query);
+ AssertIds(query.ToString(), "WARTH");
+ }
+
+ [Test]
+ public void AggregateWithContains()
+ {
+ var query = (from c in db.Customers where c.CustomerId.Contains("CH") select c.CustomerId)
+ .Aggregate(new StringBuilder(), (sb, id) => sb.Append(id).Append(","));
+
+ Console.WriteLine(query);
+ AssertIds(query.ToString(), "CHOPS", "RANCH");
+ }
+
+ [Test]
+ public void AggregateWithEquals()
+ {
+ var query = (from c in db.Customers
+ where c.CustomerId.Equals("ALFKI") || c.CustomerId.Equals("ANATR") || c.CustomerId.Equals("ANTON")
+ select c.CustomerId)
+ .Aggregate((prev, next) => (prev + "," + next));
+
+ Console.WriteLine(query);
+ AssertIds(query, "ALFKI", "ANATR", "ANTON");
+ }
+
+ [Test]
+ public void AggregateWithNotStartsWith()
+ {
+ var query = (from c in db.Customers
+ where c.CustomerId.StartsWith("A") && !c.CustomerId.StartsWith("AN")
+ select c.CustomerId)
+ .Aggregate(new StringBuilder(), (sb, id) => sb.Append(id).Append(","));
+
+ Console.WriteLine(query);
+ AssertIds(query.ToString(), "ALFKI", "AROUT");
+ }
+ // Disabled (marked Ignore("TODO")): these depend on the db.Methods SQL-function helpers.
+ /*
+ [Test]
+ [Ignore("TODO")]
+ public void AggregateWithMonthFunction()
+ {
+ var date = new DateTime(2007, 1, 1);
+
+ var query = (from e in db.Employees
+ where db.Methods.Month(e.BirthDate) == date.Month
+ select e.FirstName)
+ .Aggregate(new StringBuilder(), (sb, name) => sb.Length > 0 ? sb.Append(", ").Append(name) : sb.Append(name));
+
+ Console.WriteLine("{0} Birthdays:", date.ToString("MMMM"));
+ Console.WriteLine(query);
+ }
+
+ [Test]
+ [Ignore("TODO")]
+ public void AggregateWithBeforeYearFunction()
+ {
+ var date = new DateTime(1960, 1, 1);
+
+ var query = (from e in db.Employees
+ where db.Methods.Year(e.BirthDate) < date.Year
+ select db.Methods.Upper(e.FirstName))
+ .Aggregate(new StringBuilder(), (sb, name) => sb.Length > 0 ? sb.Append(", ").Append(name) : sb.Append(name));
+
+ Console.WriteLine("Birthdays before {0}:", date.ToString("yyyy"));
+ Console.WriteLine(query);
+ }
+
+ [Test]
+ [Ignore("TODO")]
+ public void AggregateWithOnOrAfterYearFunction()
+ {
+ var date = new DateTime(1960, 1, 1);
+
+ var query = (from e in db.Employees
+ where db.Methods.Year(e.BirthDate) >= date.Year && db.Methods.Len(e.FirstName) > 4
+ select e.FirstName)
+ .Aggregate(new StringBuilder(), (sb, name) => sb.Length > 0 ? sb.Append(", ").Append(name) : sb.Append(name));
+
+ Console.WriteLine("Birthdays after {0}:", date.ToString("yyyy"));
+ Console.WriteLine(query);
+ }
+
+ [Test]
+ [Ignore("TODO")]
+ public void AggregateWithUpperAndLowerFunctions()
+ {
+ var date = new DateTime(2007, 1, 1);
+
+ var query = (from e in db.Employees
+ where db.Methods.Month(e.BirthDate) == date.Month
+ select new { First = e.FirstName.ToUpper(), Last = db.Methods.Lower(e.LastName) })
+ .Aggregate(new StringBuilder(), (sb, name) => sb.Length > 0 ? sb.Append(", ").Append(name) : sb.Append(name));
+
+ Console.WriteLine("{0} Birthdays:", date.ToString("MMMM"));
+ Console.WriteLine(query);
+ }
+
+ [Test]
+ [Ignore("TODO")]
+ public void AggregateWithCustomFunction()
+ {
+ var date = new DateTime(1960, 1, 1);
+
+ var query = (from e in db.Employees
+ where db.Methods.Year(e.BirthDate) < date.Year
+ select db.Methods.fnEncrypt(e.FirstName))
+ .Aggregate(new StringBuilder(), (sb, name) => sb.AppendLine(BitConverter.ToString(name)));
+
+ Console.WriteLine(query);
+ }*/
+ }
+}
\ No newline at end of file
Modified: trunk/nhibernate/src/NHibernate.Test/NHibernate.Test.csproj
===================================================================
--- trunk/nhibernate/src/NHibernate.Test/NHibernate.Test.csproj 2009-11-05 17:46:33 UTC (rev 4821)
+++ trunk/nhibernate/src/NHibernate.Test/NHibernate.Test.csproj 2009-11-06 22:50:19 UTC (rev 4822)
@@ -375,6 +375,7 @@
<Compile Include="LazyOneToOne\Person.cs" />
<Compile Include="LazyProperty\Book.cs" />
<Compile Include="LazyProperty\LazyPropertyFixture.cs" />
+ <Compile Include="Linq\AggregateTests.cs" />
<Compile Include="Linq\Entities\Address.cs" />
<Compile Include="Linq\Entities\Customer.cs" />
<Compile Include="Linq\Entities\Employee.cs" />
This was sent by the SourceForge.net collaborative development platform, the world's largest Open Source development site.
|