From: <ste...@us...> - 2009-11-06 22:50:31
|
Revision: 4822 http://nhibernate.svn.sourceforge.net/nhibernate/?rev=4822&view=rev Author: steverstrong Date: 2009-11-06 22:50:19 +0000 (Fri, 06 Nov 2009) Log Message: ----------- Further Linq work - Support for Aggregate (Queryable.Aggregate, executed client-side via Enumerable.Aggregate) and various string methods (StartsWith, EndsWith, Contains, Equals) Modified Paths: -------------- trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeNode.cs trunk/nhibernate/src/NHibernate/Linq/EnumerableHelper.cs trunk/nhibernate/src/NHibernate/Linq/HqlNodeStack.cs trunk/nhibernate/src/NHibernate/Linq/NhLinqExpression.cs trunk/nhibernate/src/NHibernate/Linq/ResultTransformer.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionTreeVisitor.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs trunk/nhibernate/src/NHibernate.Test/NHibernate.Test.csproj Added Paths: ----------- trunk/nhibernate/src/NHibernate.Test/Linq/AggregateTests.cs Modified: trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs 2009-11-05 17:46:33 UTC (rev 4821) +++ trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs 2009-11-06 22:50:19 UTC (rev 4822) @@ -383,5 +383,20 @@ { return new HqlAll(_factory); } + + public HqlLike Like() + { + return new HqlLike(_factory); + } + + public HqlConcat Concat() + { + return new HqlConcat(_factory); + } + + public HqlExpressionList ExpressionList() + { + return new HqlExpressionList(_factory); + } } } \ No newline at end of file Modified: trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeNode.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeNode.cs 2009-11-05 17:46:33 UTC (rev 4821) +++ trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeNode.cs 2009-11-06 22:50:19 UTC (rev 4822) @@ -632,4 +632,18 @@ { } } + + public class HqlLike : HqlTreeNode + { + public 
HqlLike(IASTFactory factory) : base(HqlSqlWalker.LIKE, "like", factory) + { + } + } + + public class HqlConcat : HqlTreeNode + { + public HqlConcat(IASTFactory factory) : base(HqlSqlWalker.METHOD_CALL, "method", factory) + { + } + } } \ No newline at end of file Modified: trunk/nhibernate/src/NHibernate/Linq/EnumerableHelper.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/EnumerableHelper.cs 2009-11-05 17:46:33 UTC (rev 4821) +++ trunk/nhibernate/src/NHibernate/Linq/EnumerableHelper.cs 2009-11-06 22:50:19 UTC (rev 4822) @@ -1,8 +1,28 @@ +using System; using System.Linq; +using System.Linq.Expressions; using System.Reflection; namespace NHibernate.Linq { + public static class ReflectionHelper + { + public delegate void Action(); + + public static MethodInfo GetMethod<TSource>(Expression<Action<TSource>> method) + { + var methodInfo = ((MethodCallExpression) method.Body).Method; + return methodInfo.IsGenericMethod ? methodInfo.GetGenericMethodDefinition() : methodInfo; + } + + public static MethodInfo GetMethod(Expression<Action> method) + { + var methodInfo = ((MethodCallExpression)method.Body).Method; + return methodInfo.IsGenericMethod ? methodInfo.GetGenericMethodDefinition() : methodInfo; + } + } + + // TODO rename / remove - reflection helper above is better public static class EnumerableHelper { public static MethodInfo GetMethod(string name, System.Type[] parameterTypes) Modified: trunk/nhibernate/src/NHibernate/Linq/HqlNodeStack.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/HqlNodeStack.cs 2009-11-05 17:46:33 UTC (rev 4821) +++ trunk/nhibernate/src/NHibernate/Linq/HqlNodeStack.cs 2009-11-06 22:50:19 UTC (rev 4822) @@ -11,6 +11,7 @@ public HqlNodeStack(HqlTreeBuilder builder) { + // TODO - only reason for the build is to have a root node. 
Sucks, change this _root = builder.Holder(); _stack.Push(_root); } Modified: trunk/nhibernate/src/NHibernate/Linq/NhLinqExpression.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/NhLinqExpression.cs 2009-11-05 17:46:33 UTC (rev 4821) +++ trunk/nhibernate/src/NHibernate/Linq/NhLinqExpression.cs 2009-11-06 22:50:19 UTC (rev 4822) @@ -1,16 +1,36 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; using NHibernate.Engine.Query; using NHibernate.Hql.Ast.ANTLR.Tree; +using NHibernate.Linq.ResultOperators; using NHibernate.Linq.Visitors; +using Remotion.Data.Linq.Clauses; +using Remotion.Data.Linq.Clauses.StreamedData; using Remotion.Data.Linq.Parsing.ExpressionTreeVisitors; using Remotion.Data.Linq.Parsing.Structure; +using Remotion.Data.Linq.Parsing.Structure.IntermediateModel; namespace NHibernate.Linq { public class NhLinqExpression : IQueryExpression { + private static readonly MethodCallExpressionNodeTypeRegistry MethodCallRegistry = + MethodCallExpressionNodeTypeRegistry.CreateDefault(); + + static NhLinqExpression() + { + MethodCallRegistry.Register( + new[] + { + MethodCallExpressionNodeTypeRegistry.GetRegisterableMethodDefinition(ReflectionHelper.GetMethod(() => Queryable.Aggregate<object>(null, null))), + MethodCallExpressionNodeTypeRegistry.GetRegisterableMethodDefinition(ReflectionHelper.GetMethod(() => Queryable.Aggregate<object, object>(null, null, null))) + }, + typeof (AggregateExpressionNode)); + } + + private readonly Expression _expression; private CommandData _commandData; private readonly IDictionary<ConstantExpression, NamedParameter> _queryParameters; @@ -40,7 +60,8 @@ public IASTNode Translate(ISessionFactory sessionFactory) { var requiredHqlParameters = new List<NamedParameterDescriptor>(); - var queryModel = new QueryParser(new 
ExpressionTreeParser(MethodCallExpressionNodeTypeRegistry.CreateDefault())).GetParsedQuery(_expression); + // TODO - can we cache any of this? + var queryModel = new QueryParser(new ExpressionTreeParser(MethodCallRegistry)).GetParsedQuery(_expression); _commandData = QueryModelVisitor.GenerateHqlQuery(queryModel, _queryParameters, requiredHqlParameters); @@ -67,4 +88,65 @@ _commandData.AddAdditionalCriteria(impl); } } + + public class AggregateExpressionNode : ResultOperatorExpressionNodeBase + { + public MethodCallExpressionParseInfo ParseInfo { get; set; } + public Expression OptionalSeed { get; set; } + public LambdaExpression Accumulator { get; set; } + public LambdaExpression OptionalSelector { get; set; } + + public AggregateExpressionNode(MethodCallExpressionParseInfo parseInfo, Expression arg1, Expression arg2, LambdaExpression optionalSelector) : base(parseInfo, null, optionalSelector) + { + ParseInfo = parseInfo; + + if (arg2 != null) + { + OptionalSeed = arg1; + Accumulator = (LambdaExpression) arg2; + } + else + { + Accumulator = (LambdaExpression) arg1; + } + + OptionalSelector = optionalSelector; + } + + public override Expression Resolve(ParameterExpression inputParameter, Expression expressionToBeResolved, ClauseGenerationContext clauseGenerationContext) + { + throw new NotImplementedException(); + } + + protected override ResultOperatorBase CreateResultOperator(ClauseGenerationContext clauseGenerationContext) + { + return new AggregateResultOperator(ParseInfo, OptionalSeed, Accumulator, OptionalSelector); + } + } + + public class AggregateResultOperator : ClientSideTransformOperator + { + public MethodCallExpressionParseInfo ParseInfo { get; set; } + public Expression OptionalSeed { get; set; } + public LambdaExpression Accumulator { get; set; } + public LambdaExpression OptionalSelector { get; set; } + + public AggregateResultOperator(MethodCallExpressionParseInfo parseInfo, Expression optionalSeed, LambdaExpression accumulator, LambdaExpression 
optionalSelector) + { + ParseInfo = parseInfo; + OptionalSeed = optionalSeed; + Accumulator = accumulator; + OptionalSelector = optionalSelector; + } + + public override IStreamedDataInfo GetOutputDataInfo(IStreamedDataInfo inputInfo) + { + throw new NotImplementedException(); + } + + public override ResultOperatorBase Clone(CloneContext cloneContext) + { + throw new NotImplementedException(); + } + } } \ No newline at end of file Modified: trunk/nhibernate/src/NHibernate/Linq/ResultTransformer.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/ResultTransformer.cs 2009-11-05 17:46:33 UTC (rev 4821) +++ trunk/nhibernate/src/NHibernate/Linq/ResultTransformer.cs 2009-11-06 22:50:19 UTC (rev 4822) @@ -1,7 +1,9 @@ using System; using System.Collections; +using System.Linq; using System.Linq.Expressions; using NHibernate.Transform; +using Remotion.Collections; namespace NHibernate.Linq { @@ -30,7 +32,38 @@ public IList TransformList(IList collection) { - return _listTransformation == null ? 
collection : (IList) _listTransformation.DynamicInvoke(collection); + if (_listTransformation == null) + { + return collection; + } + + object transformResult = collection; + + if (collection.Count > 0) + { + if (collection[0] is object[]) + { + if ( ((object[])collection[0]).Length != 1) + { + // We only expect single items + throw new NotSupportedException(); + } + + transformResult = _listTransformation.DynamicInvoke(collection.Cast<object[]>().Select(o => o[0])); + } + else + { + transformResult = _listTransformation.DynamicInvoke(collection); + } + } + + if (transformResult is IList) + { + return (IList) transformResult; + } + + var list = new ArrayList {transformResult}; + return list; } } } \ No newline at end of file Modified: trunk/nhibernate/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionTreeVisitor.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionTreeVisitor.cs 2009-11-05 17:46:33 UTC (rev 4821) +++ trunk/nhibernate/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionTreeVisitor.cs 2009-11-06 22:50:19 UTC (rev 4822) @@ -1,11 +1,13 @@ using System; using System.Collections.Generic; +using System.Collections.ObjectModel; using System.Linq; using System.Linq.Expressions; +using System.Reflection; using NHibernate.Engine.Query; using NHibernate.Hql.Ast; using NHibernate.Linq.Expressions; -using NHibernate.Linq.ReWriters; +using Remotion.Data.Linq; using Remotion.Data.Linq.Clauses.Expressions; using Remotion.Data.Linq.Clauses.ExpressionTreeVisitors; @@ -17,8 +19,9 @@ protected readonly HqlNodeStack _stack; private readonly IDictionary<ConstantExpression, NamedParameter> _parameters; private readonly IList<NamedParameterDescriptor> _requiredHqlParameters; + static private readonly MethodGeneratorRegistry _methodGeneratorRegistry = MethodGeneratorRegistry.Initialise(); - public HqlGeneratorExpressionTreeVisitor(IDictionary<ConstantExpression, NamedParameter> 
parameters, IList<NamedParameterDescriptor> requiredHqlParameters) + public HqlGeneratorExpressionTreeVisitor(IDictionary<ConstantExpression, NamedParameter> parameters, IList<NamedParameterDescriptor> requiredHqlParameters) { _parameters = parameters; _requiredHqlParameters = requiredHqlParameters; @@ -36,6 +39,16 @@ VisitExpression(expression); } + public HqlNodeStack Stack + { + get { return _stack; } + } + + public HqlTreeBuilder TreeBuilder + { + get { return _hqlTreeBuilder; } + } + protected override Expression VisitNhAverage(NhAverageExpression expression) { var visitor = new HqlGeneratorExpressionTreeVisitor(_parameters, _requiredHqlParameters); @@ -238,106 +251,11 @@ protected override Expression VisitMethodCallExpression(MethodCallExpression expression) { - if (expression.Method.DeclaringType == typeof(Enumerable) || - expression.Method.DeclaringType == typeof(Queryable)) - { - switch (expression.Method.Name) - { - case "Any": - // Any has one or two arguments. Arg 1 is the source and arg 2 is the optional predicate - using (_stack.PushNode(_hqlTreeBuilder.Exists())) - { - using (_stack.PushNode(_hqlTreeBuilder.Query())) - { - using (_stack.PushNode(_hqlTreeBuilder.SelectFrom())) - { - using (_stack.PushNode(_hqlTreeBuilder.From())) - { - using (_stack.PushNode(_hqlTreeBuilder.Range())) - { - VisitExpression(expression.Arguments[0]); + var generator = _methodGeneratorRegistry.GetMethodGenerator(expression.Method); - if (expression.Arguments.Count > 1) - { - var expr = (LambdaExpression) expression.Arguments[1]; - _stack.PushLeaf(_hqlTreeBuilder.Alias(expr.Parameters[0].Name)); - } - } - } - } - if (expression.Arguments.Count > 1) - { - using (_stack.PushNode(_hqlTreeBuilder.Where())) - { - VisitExpression(expression.Arguments[1]); - } - } - } - } - break; + generator.BuildHql(expression.Object, expression.Arguments, this); - case "All": - // All has one or two arguments. 
Arg 1 is the source and arg 2 is the optional predicate - using (_stack.PushNode(_hqlTreeBuilder.Not())) - { - using (_stack.PushNode(_hqlTreeBuilder.Exists())) - { - using (_stack.PushNode(_hqlTreeBuilder.Query())) - { - using (_stack.PushNode(_hqlTreeBuilder.SelectFrom())) - { - using (_stack.PushNode(_hqlTreeBuilder.From())) - { - using (_stack.PushNode(_hqlTreeBuilder.Range())) - { - VisitExpression(expression.Arguments[0]); - - if (expression.Arguments.Count > 1) - { - var expr = (LambdaExpression) expression.Arguments[1]; - - _stack.PushLeaf(_hqlTreeBuilder.Alias(expr.Parameters[0].Name)); - } - } - } - } - if (expression.Arguments.Count > 1) - { - using (_stack.PushNode(_hqlTreeBuilder.Where())) - { - using (_stack.PushNode(_hqlTreeBuilder.Not())) - { - VisitExpression(expression.Arguments[1]); - } - } - } - } - } - } - break; - - case "Min": - using (_stack.PushNode(_hqlTreeBuilder.Min())) - { - VisitExpression(expression.Arguments[1]); - } - break; - case "Max": - using (_stack.PushNode(_hqlTreeBuilder.Max())) - { - VisitExpression(expression.Arguments[1]); - } - break; - default: - throw new NotSupportedException(string.Format("The Enumerable method {0} is not supported", expression.Method.Name)); - } - - return expression; - } - else - { - return base.VisitMethodCallExpression(expression); // throws - } + return expression; } protected override Expression VisitLambdaExpression(LambdaExpression expression) @@ -401,4 +319,352 @@ return itemAsExpression != null ? 
FormattingExpressionTreeVisitor.Format(itemAsExpression) : unhandledItem.ToString(); } } + + public class MethodGeneratorRegistry + { + public static MethodGeneratorRegistry Initialise() + { + var registry = new MethodGeneratorRegistry(); + + // TODO - could use reflection here + registry.Register(new QueryableMethodsGenerator()); + registry.Register(new StringMethodsGenerator()); + + return registry; + } + + private readonly Dictionary<MethodInfo, IHqlGeneratorForMethod> _registeredMethods = new Dictionary<MethodInfo, IHqlGeneratorForMethod>(); + + public IHqlGeneratorForMethod GetMethodGenerator(MethodInfo method) + { + IHqlGeneratorForMethod methodGenerator; + + if (method.IsGenericMethod) + { + method = method.GetGenericMethodDefinition(); + } + + if (_registeredMethods.TryGetValue(method, out methodGenerator)) + { + return methodGenerator; + } + + throw new NotSupportedException(); + } + + public void RegisterMethodGenerator(MethodInfo method, IHqlGeneratorForMethod generator) + { + _registeredMethods.Add(method, generator); + } + + private void Register(IHqlGeneratorForType typeMethodGenerator) + { + typeMethodGenerator.RegisterMethods(this); + } + + } + + public interface IHqlGeneratorForMethod + { + IEnumerable<MethodInfo> SupportedMethods { get; } + void BuildHql(Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor); + } + + public interface IHqlGeneratorForType + { + void RegisterMethods(MethodGeneratorRegistry methodGeneratorRegistry); + } + + abstract public class BaseHqlGeneratorForType : IHqlGeneratorForType + { + protected readonly List<IHqlGeneratorForMethod> MethodRegistry = new List<IHqlGeneratorForMethod>(); + + public void RegisterMethods(MethodGeneratorRegistry methodGeneratorRegistry) + { + foreach (var generator in MethodRegistry) + { + foreach (var method in generator.SupportedMethods) + { + methodGeneratorRegistry.RegisterMethodGenerator(method, generator); + } + } + } + } + + 
public abstract class BaseHqlGeneratorForMethod : IHqlGeneratorForMethod + { + public IEnumerable<MethodInfo> SupportedMethods { get; protected set; } + + public abstract void BuildHql(Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor); + } + + public class StringMethodsGenerator : BaseHqlGeneratorForType + { + public StringMethodsGenerator() + { + // TODO - could use reflection + MethodRegistry.Add(new StartsWithGenerator()); + MethodRegistry.Add(new EndsWithGenerator()); + MethodRegistry.Add(new ContainsGenerator()); + MethodRegistry.Add(new EqualsGenerator()); + } + + class StartsWithGenerator : BaseHqlGeneratorForMethod + { + public StartsWithGenerator() + { + SupportedMethods = new[] { ReflectionHelper.GetMethod<string>(x => x.StartsWith(null)) }; + } + + public override void BuildHql(Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor) + { + using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Like())) + { + hqlVisitor.Visit(targetObject); + + // TODO - this sucks. Concat() just pushes a method node, and we have to do all the child stuff. 
+ // Sort out the tree stuff so it works properly + using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Concat())) + { + hqlVisitor.Stack.PushLeaf(hqlVisitor.TreeBuilder.Ident("concat")); + + using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.ExpressionList())) + { + hqlVisitor.Visit(arguments[0]); + + hqlVisitor.Stack.PushLeaf(hqlVisitor.TreeBuilder.Constant("%")); + } + } + } + } + } + + class EndsWithGenerator : BaseHqlGeneratorForMethod + { + public EndsWithGenerator() + { + SupportedMethods = new[] { ReflectionHelper.GetMethod<string>(x => x.EndsWith(null)) }; + } + + public override void BuildHql(Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor) + { + using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Like())) + { + hqlVisitor.Visit(targetObject); + + // TODO - this sucks. Concat() just pushes a method node, and we have to do all the child stuff. + // Sort out the tree stuff so it works properly + using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Concat())) + { + hqlVisitor.Stack.PushLeaf(hqlVisitor.TreeBuilder.Ident("concat")); + + using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.ExpressionList())) + { + hqlVisitor.Stack.PushLeaf(hqlVisitor.TreeBuilder.Constant("%")); + + hqlVisitor.Visit(arguments[0]); + } + } + } + } + } + + class ContainsGenerator : BaseHqlGeneratorForMethod + { + public ContainsGenerator() + { + SupportedMethods = new[] { ReflectionHelper.GetMethod<string>(x => x.Contains(null)) }; + } + + public override void BuildHql(Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor) + { + using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Like())) + { + hqlVisitor.Visit(targetObject); + + // TODO - this sucks. Concat() just pushes a method node, and we have to do all the child stuff. 
+ // Sort out the tree stuff so it works properly + using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Concat())) + { + hqlVisitor.Stack.PushLeaf(hqlVisitor.TreeBuilder.Ident("concat")); + + using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.ExpressionList())) + { + hqlVisitor.Stack.PushLeaf(hqlVisitor.TreeBuilder.Constant("%")); + + hqlVisitor.Visit(arguments[0]); + + hqlVisitor.Stack.PushLeaf(hqlVisitor.TreeBuilder.Constant("%")); + } + } + } + } + } + + class EqualsGenerator : BaseHqlGeneratorForMethod + { + public EqualsGenerator() + { + SupportedMethods = new[] { ReflectionHelper.GetMethod<string>(x => x.Equals((string)null)) }; + } + + public override void BuildHql(Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor) + { + using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Equality())) + { + hqlVisitor.Visit(targetObject); + + hqlVisitor.Visit(arguments[0]); + } + } + } + } + + public class QueryableMethodsGenerator : BaseHqlGeneratorForType + { + public QueryableMethodsGenerator() + { + // TODO - could use reflection + MethodRegistry.Add(new AnyGenerator()); + MethodRegistry.Add(new AllGenerator()); + MethodRegistry.Add(new MinGenerator()); + MethodRegistry.Add(new MaxGenerator()); + } + + class AnyGenerator : BaseHqlGeneratorForMethod + { + public AnyGenerator() + { + SupportedMethods = new[] + { + ReflectionHelper.GetMethod(() => Queryable.Any<object>(null)), + ReflectionHelper.GetMethod(() => Queryable.Any<object>(null, null)), + ReflectionHelper.GetMethod(() => Enumerable.Any<object>(null)), + ReflectionHelper.GetMethod(() => Enumerable.Any<object>(null, null)) + }; + } + + public override void BuildHql(Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor) + { + // Any has one or two arguments. 
Arg 1 is the source and arg 2 is the optional predicate + using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Exists())) + { + using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Query())) + { + using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.SelectFrom())) + { + using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.From())) + { + using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Range())) + { + hqlVisitor.Visit(arguments[0]); + + if (arguments.Count > 1) + { + var expr = (LambdaExpression)arguments[1]; + hqlVisitor.Stack.PushLeaf(hqlVisitor.TreeBuilder.Alias(expr.Parameters[0].Name)); + } + } + } + } + if (arguments.Count > 1) + { + using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Where())) + { + hqlVisitor.Visit(arguments[1]); + } + } + } + } + } + } + + class AllGenerator : BaseHqlGeneratorForMethod + { + public AllGenerator() + { + SupportedMethods = new[] + { + ReflectionHelper.GetMethod(() => Queryable.All<object>(null, null)), + ReflectionHelper.GetMethod(() => Enumerable.All<object>(null, null)) + }; + } + + public override void BuildHql(Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor) + { + // All has two arguments. 
Arg 1 is the source and arg 2 is the predicate + using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Not())) + { + using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Exists())) + { + using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Query())) + { + using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.SelectFrom())) + { + using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.From())) + { + using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Range())) + { + hqlVisitor.Visit(arguments[0]); + + var expr = (LambdaExpression)arguments[1]; + + hqlVisitor.Stack.PushLeaf(hqlVisitor.TreeBuilder.Alias(expr.Parameters[0].Name)); + } + } + } + + using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Where())) + { + using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Not())) + { + hqlVisitor.Visit(arguments[1]); + } + } + } + } + } + } + } + + class MinGenerator : BaseHqlGeneratorForMethod + { + public MinGenerator() + { + SupportedMethods = new[] + { + ReflectionHelper.GetMethod(() => Queryable.Min<object>(null)), + ReflectionHelper.GetMethod(() => Enumerable.Min<object>(null)) + }; + } + + public override void BuildHql(Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor) + { + using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Min())) + { + hqlVisitor.Visit(arguments[1]); + } + } + } + + class MaxGenerator : BaseHqlGeneratorForMethod + { + public MaxGenerator() + { + SupportedMethods = new[] + { + ReflectionHelper.GetMethod(() => Queryable.Max<object>(null)), + ReflectionHelper.GetMethod(() => Enumerable.Max<object>(null)) + }; + } + + public override void BuildHql(Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor) + { + using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Max())) + { + hqlVisitor.Visit(arguments[1]); + } + } + } + } } \ No newline at end of file Modified: 
trunk/nhibernate/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs 2009-11-05 17:46:33 UTC (rev 4821) +++ trunk/nhibernate/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs 2009-11-06 22:50:19 UTC (rev 4822) @@ -214,13 +214,70 @@ { ProcessClientSideSelect((ClientSideSelect) resultOperator); } - else + else if (resultOperator is AggregateResultOperator) + { + ProcessAggregateOperator((AggregateResultOperator)resultOperator); + } + else { throw new NotSupportedException(string.Format("The {0} result operator is not current supported", resultOperator.GetType().Name)); } } + private void ProcessAggregateOperator(AggregateResultOperator resultOperator) + { + var inputType = resultOperator.Accumulator.Parameters[1].Type; + var accumulatorType = resultOperator.Accumulator.Parameters[0].Type; + var inputList = Expression.Parameter(typeof(IEnumerable<>).MakeGenericType(typeof(object)), "inputList"); + + var castToItem = EnumerableHelper.GetMethod("Cast", new[] { typeof(IEnumerable) }, new[] { inputType }); + var castToItemExpr = Expression.Call(castToItem, inputList); + + MethodCallExpression call; + + if (resultOperator.ParseInfo.ParsedExpression.Arguments.Count == 2) + { + var aggregate = ReflectionHelper.GetMethod(() => Enumerable.Aggregate<object>(null, null)); + aggregate = aggregate.GetGenericMethodDefinition().MakeGenericMethod(inputType); + + call = Expression.Call( + aggregate, + castToItemExpr, + resultOperator.Accumulator + ); + + } + else if (resultOperator.ParseInfo.ParsedExpression.Arguments.Count == 3) + { + var aggregate = ReflectionHelper.GetMethod(() => Enumerable.Aggregate<object, object>(null, null, null)); + aggregate = aggregate.GetGenericMethodDefinition().MakeGenericMethod(inputType, accumulatorType); + + call = Expression.Call( + aggregate, + castToItemExpr, + resultOperator.OptionalSeed, + 
resultOperator.Accumulator + ); + } + else + { + var selectorType = resultOperator.OptionalSelector.Type.GetGenericArguments()[2]; + var aggregate = ReflectionHelper.GetMethod(() => Enumerable.Aggregate<object, object, object>(null, null, null, null)); + aggregate = aggregate.GetGenericMethodDefinition().MakeGenericMethod(inputType, accumulatorType, selectorType); + + call = Expression.Call( + aggregate, + castToItemExpr, + resultOperator.OptionalSeed, + resultOperator.Accumulator, + resultOperator.OptionalSelector + ); + } + + _listTransformers.Add(Expression.Lambda(call, inputList)); + } + private void ProcessClientSideSelect(ClientSideSelect resultOperator) { var inputType = resultOperator.SelectClause.Parameters[0].Type; @@ -251,13 +308,6 @@ private void ProcessNonAggregatingGroupBy(NonAggregatingGroupBy resultOperator, QueryModel model) { - /* - public static IQueryable<IGrouping<TKey, TElement>> GroupBy<TSource, TKey, TElement>( - this IQueryable<TSource> source, - Expression<Func<TSource, TKey>> keySelector, - Expression<Func<TSource, TElement>> elementSelector); - */ - var tSource = model.SelectClause.Selector.Type; var tKey = resultOperator.GroupBy.KeySelector.Type; var tElement = resultOperator.GroupBy.ElementSelector.Type; @@ -305,32 +355,8 @@ LambdaExpression elementSelectorExpr = Expression.Lambda(elementSelector, itemParam); - ParameterExpression objectArrayParam = Expression.Parameter(typeof (object[]), "array"); + Expression castToItemExpr = Expression.Call(castToItem, listParameter); - Expression castToItemExpr; - if (_itemTransformers.Count > 0) - { - // The item transformer will already have removed the object[]. 
We need to do this: - // list.Cast<tSource>().GroupBy(keySelectorExpr, elementSelectorExpr).ToList(); - castToItemExpr = Expression.Call(castToItem, listParameter); - } - else - { - // We have an object[] in the ResultTransformer - // list.Cast<object[]>().Select(o => o[0]).Cast<tSource>().GroupBy(keySelectorExpr, elementSelectorExpr).ToList(); - var castToObjectArray = EnumerableHelper.GetMethod("Cast", new[] { typeof(IEnumerable) }, new[] { typeof(object[]) }); - var selectObject = EnumerableHelper.GetMethod("Select", new[] { typeof(IEnumerable<>), typeof(Func<,>) }, new[] { typeof(object[]), typeof(object) }); - - var castToObjectArrayExpr = Expression.Call(castToObjectArray, listParameter); - - LambdaExpression index = Expression.Lambda(Expression.ArrayIndex(objectArrayParam, Expression.Constant(0)), - objectArrayParam); - - var selectObjectExpr = Expression.Call(selectObject, castToObjectArrayExpr, index); - - castToItemExpr = Expression.Call(castToItem, selectObjectExpr); - } - var groupByExpr = Expression.Call(groupByMethod, castToItemExpr, keySelectorExpr, elementSelectorExpr); var toListExpr = Expression.Call(toList, groupByExpr); Added: trunk/nhibernate/src/NHibernate.Test/Linq/AggregateTests.cs =================================================================== --- trunk/nhibernate/src/NHibernate.Test/Linq/AggregateTests.cs (rev 0) +++ trunk/nhibernate/src/NHibernate.Test/Linq/AggregateTests.cs 2009-11-06 22:50:19 UTC (rev 4822) @@ -0,0 +1,139 @@ +using System; +using System.Linq; +using System.Text; +using NUnit.Framework; + +namespace NHibernate.Test.Linq +{ + [TestFixture] + public class AggregateTests : LinqTestCase + { + [Test] + public void AggregateWithStartsWith() + { + var query = (from c in db.Customers where c.CustomerId.StartsWith("A") select c.CustomerId) + .Aggregate(new StringBuilder(), (sb, id) => sb.Append(id).Append(",")); + + Console.WriteLine(query); + Assert.AreEqual("ALFKI,ANATR,ANTON,AROUT,", query.ToString()); + } + + [Test] + 
public void AggregateWithEndsWith() + { + var query = (from c in db.Customers where c.CustomerId.EndsWith("TH") select c.CustomerId) + .Aggregate(new StringBuilder(), (sb, id) => sb.Append(id).Append(",")); + + Console.WriteLine(query); + Assert.AreEqual("WARTH,", query.ToString()); + } + + [Test] + public void AggregateWithContains() + { + var query = (from c in db.Customers where c.CustomerId.Contains("CH") select c.CustomerId) + .Aggregate(new StringBuilder(), (sb, id) => sb.Append(id).Append(",")); + + Console.WriteLine(query); + Assert.AreEqual("CHOPS,RANCH,", query.ToString()); + } + + [Test] + public void AggregateWithEquals() + { + var query = (from c in db.Customers + where c.CustomerId.Equals("ALFKI") || c.CustomerId.Equals("ANATR") || c.CustomerId.Equals("ANTON") + select c.CustomerId) + .Aggregate((prev, next) => (prev + "," + next)); + + Console.WriteLine(query); + Assert.AreEqual("ALFKI,ANATR,ANTON", query); + } + + [Test] + public void AggregateWithNotStartsWith() + { + var query = (from c in db.Customers + where c.CustomerId.StartsWith("A") && !c.CustomerId.StartsWith("AN") + select c.CustomerId) + .Aggregate(new StringBuilder(), (sb, id) => sb.Append(id).Append(",")); + + Console.WriteLine(query); + Assert.AreEqual("ALFKI,AROUT,", query.ToString()); + } + /* + [Test] + [Ignore("TODO")] + public void AggregateWithMonthFunction() + { + var date = new DateTime(2007, 1, 1); + + var query = (from e in db.Employees + where db.Methods.Month(e.BirthDate) == date.Month + select e.FirstName) + .Aggregate(new StringBuilder(), (sb, name) => sb.Length > 0 ? 
sb.Append(", ").Append(name) : sb.Append(name)); + + Console.WriteLine("{0} Birthdays:", date.ToString("MMMM")); + Console.WriteLine(query); + } + + [Test] + [Ignore("TODO")] + public void AggregateWithBeforeYearFunction() + { + var date = new DateTime(1960, 1, 1); + + var query = (from e in db.Employees + where db.Methods.Year(e.BirthDate) < date.Year + select db.Methods.Upper(e.FirstName)) + .Aggregate(new StringBuilder(), (sb, name) => sb.Length > 0 ? sb.Append(", ").Append(name) : sb.Append(name)); + + Console.WriteLine("Birthdays before {0}:", date.ToString("yyyy")); + Console.WriteLine(query); + } + + [Test] + [Ignore("TODO")] + public void AggregateWithOnOrAfterYearFunction() + { + var date = new DateTime(1960, 1, 1); + + var query = (from e in db.Employees + where db.Methods.Year(e.BirthDate) >= date.Year && db.Methods.Len(e.FirstName) > 4 + select e.FirstName) + .Aggregate(new StringBuilder(), (sb, name) => sb.Length > 0 ? sb.Append(", ").Append(name) : sb.Append(name)); + + Console.WriteLine("Birthdays after {0}:", date.ToString("yyyy")); + Console.WriteLine(query); + } + + [Test] + [Ignore("TODO")] + public void AggregateWithUpperAndLowerFunctions() + { + var date = new DateTime(2007, 1, 1); + + var query = (from e in db.Employees + where db.Methods.Month(e.BirthDate) == date.Month + select new { First = e.FirstName.ToUpper(), Last = db.Methods.Lower(e.LastName) }) + .Aggregate(new StringBuilder(), (sb, name) => sb.Length > 0 ? 
sb.Append(", ").Append(name) : sb.Append(name)); + + Console.WriteLine("{0} Birthdays:", date.ToString("MMMM")); + Console.WriteLine(query); + } + + [Test] + [Ignore("TODO")] + public void AggregateWithCustomFunction() + { + var date = new DateTime(1960, 1, 1); + + var query = (from e in db.Employees + where db.Methods.Year(e.BirthDate) < date.Year + select db.Methods.fnEncrypt(e.FirstName)) + .Aggregate(new StringBuilder(), (sb, name) => sb.AppendLine(BitConverter.ToString(name))); + + Console.WriteLine(query); + }*/ + } +} \ No newline at end of file Modified: trunk/nhibernate/src/NHibernate.Test/NHibernate.Test.csproj =================================================================== --- trunk/nhibernate/src/NHibernate.Test/NHibernate.Test.csproj 2009-11-05 17:46:33 UTC (rev 4821) +++ trunk/nhibernate/src/NHibernate.Test/NHibernate.Test.csproj 2009-11-06 22:50:19 UTC (rev 4822) @@ -375,6 +375,7 @@ <Compile Include="LazyOneToOne\Person.cs" /> <Compile Include="LazyProperty\Book.cs" /> <Compile Include="LazyProperty\LazyPropertyFixture.cs" /> + <Compile Include="Linq\AggregateTests.cs" /> <Compile Include="Linq\Entities\Address.cs" /> <Compile Include="Linq\Entities\Customer.cs" /> <Compile Include="Linq\Entities\Employee.cs" /> This was sent by the SourceForge.net collaborative development platform, the world's largest Open Source development site. |