From: <ste...@us...> - 2009-11-10 14:19:22
|
Revision: 4826 http://nhibernate.svn.sourceforge.net/nhibernate/?rev=4826&view=rev Author: steverstrong Date: 2009-11-10 14:19:09 +0000 (Tue, 10 Nov 2009) Log Message: ----------- Added more tests and support (in a semi-extensible way) for constructs such as string.ToUpper and datetime.Year Modified Paths: -------------- trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeNode.cs trunk/nhibernate/src/NHibernate/Linq/EnumerableHelper.cs trunk/nhibernate/src/NHibernate/Linq/LinqExtensionMethods.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionTreeVisitor.cs trunk/nhibernate/src/NHibernate/NHibernate.csproj trunk/nhibernate/src/NHibernate.Test/Linq/AggregateTests.cs Added Paths: ----------- trunk/nhibernate/src/NHibernate/Linq/Visitors/BaseHqlGeneratorForProperty.cs Modified: trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs 2009-11-10 13:32:28 UTC (rev 4825) +++ trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs 2009-11-10 14:19:09 UTC (rev 4826) @@ -398,5 +398,10 @@ { return new HqlExpressionList(_factory); } + + public HqlMethodCall MethodCall() + { + return new HqlMethodCall(_factory); + } } } \ No newline at end of file Modified: trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeNode.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeNode.cs 2009-11-10 13:32:28 UTC (rev 4825) +++ trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeNode.cs 2009-11-10 14:19:09 UTC (rev 4826) @@ -147,6 +147,7 @@ return type.GetGenericArguments()[0]; } + // TODO - code duplicated in LinqExtensionMethods private static bool IsNullableType(System.Type type) { return (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>)); @@ -646,4 +647,11 @@ { } } + + public class HqlMethodCall : HqlTreeNode + { + public HqlMethodCall(IASTFactory factory) : base(HqlSqlWalker.METHOD_CALL, "method", factory) + { + } + } } \ No newline at end of file Modified: trunk/nhibernate/src/NHibernate/Linq/EnumerableHelper.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/EnumerableHelper.cs 2009-11-10 13:32:28 UTC (rev 4825) +++ trunk/nhibernate/src/NHibernate/Linq/EnumerableHelper.cs 2009-11-10 14:19:09 UTC (rev 4826) @@ -20,6 +20,11 @@ var methodInfo = ((MethodCallExpression)method.Body).Method; return methodInfo.IsGenericMethod ? methodInfo.GetGenericMethodDefinition() : methodInfo; } + + public static MemberInfo GetProperty<TSource, TResult>(Expression<Func<TSource, TResult>> property) + { + return ((MemberExpression) property.Body).Member; + } } // TODO rename / remove - reflection helper above is better Modified: trunk/nhibernate/src/NHibernate/Linq/LinqExtensionMethods.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/LinqExtensionMethods.cs 2009-11-10 13:32:28 UTC (rev 4825) +++ trunk/nhibernate/src/NHibernate/Linq/LinqExtensionMethods.cs 2009-11-10 14:19:09 UTC (rev 4826) @@ -1,4 +1,5 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; using System.Linq; namespace NHibernate.Linq @@ -17,5 +18,11 @@ method(item); } } + + public static bool IsNullable(this System.Type type) + { + return (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>)); + } + } } \ No newline at end of file Added: trunk/nhibernate/src/NHibernate/Linq/Visitors/BaseHqlGeneratorForProperty.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/Visitors/BaseHqlGeneratorForProperty.cs (rev 0) +++ trunk/nhibernate/src/NHibernate/Linq/Visitors/BaseHqlGeneratorForProperty.cs 2009-11-10 14:19:09 UTC (rev 4826) @@ -0,0 +1,12 @@ +using System.Collections.Generic; +using System.Linq.Expressions; +using System.Reflection; + +namespace NHibernate.Linq.Visitors +{ + public abstract class BaseHqlGeneratorForProperty : IHqlGeneratorForProperty + { + public IEnumerable<MemberInfo> SupportedProperties { get; protected set; } + public abstract void BuildHql(MemberInfo member, Expression expression, HqlGeneratorExpressionTreeVisitor hqlGeneratorExpressionTreeVisitor); + } +} \ No newline at end of file Modified: trunk/nhibernate/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionTreeVisitor.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionTreeVisitor.cs 2009-11-10 13:32:28 UTC (rev 4825) +++ trunk/nhibernate/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionTreeVisitor.cs 2009-11-10 14:19:09 UTC (rev 4826) @@ -7,7 +7,6 @@ using NHibernate.Engine.Query; using NHibernate.Hql.Ast; using NHibernate.Linq.Expressions; -using Remotion.Data.Linq; using Remotion.Data.Linq.Clauses.Expressions; using Remotion.Data.Linq.Clauses.ExpressionTreeVisitors; @@ -206,16 +205,28 @@ protected override Expression VisitMemberExpression(MemberExpression expression) { + // Strip out the .Value property of a nullable type, HQL doesn't need that + if (expression.Member.Name == "Value" && expression.Expression.Type.IsNullable()) + { + VisitExpression(expression.Expression); + return expression; + } + + // Look for "special" properties (DateTime.Month etc) + var generator = _methodGeneratorRegistry.GetPropertyGenerator(expression.Expression.Type, expression.Member); + + if (generator != null) + { + generator.BuildHql(expression.Member, expression.Expression, this); + return expression; + } + + // Else just emit standard HQL for a property reference using (_stack.PushNode(_hqlTreeBuilder.Dot())) { - Expression newExpression = VisitExpression(expression.Expression); + VisitExpression(expression.Expression); _stack.PushLeaf(_hqlTreeBuilder.Ident(expression.Member.Name)); - - if (newExpression != expression.Expression) - { - return Expression.MakeMemberAccess(newExpression, expression.Member); - } } return expression; @@ -253,7 +264,7 @@ { var generator = _methodGeneratorRegistry.GetMethodGenerator(expression.Method); - generator.BuildHql(expression.Object, expression.Arguments, this); + generator.BuildHql(expression.Method, expression.Object, expression.Arguments, this); return expression; } @@ -320,6 +331,12 @@ } } + public interface IHqlGeneratorForProperty + { + IEnumerable<MemberInfo> SupportedProperties { get; } + void BuildHql(MemberInfo member, Expression expression, HqlGeneratorExpressionTreeVisitor hqlGeneratorExpressionTreeVisitor); + } + public class MethodGeneratorRegistry { public static MethodGeneratorRegistry Initialise() @@ -329,11 +346,13 @@ // TODO - could use reflection here registry.Register(new QueryableMethodsGenerator()); registry.Register(new StringMethodsGenerator()); + registry.Register(new DateTimePropertyGenerator()); return registry; } private readonly Dictionary<MethodInfo, IHqlGeneratorForMethod> _registeredMethods = new Dictionary<MethodInfo, IHqlGeneratorForMethod>(); + private readonly Dictionary<MemberInfo, IHqlGeneratorForProperty> _registeredProperties = new Dictionary<MemberInfo, IHqlGeneratorForProperty>(); public IHqlGeneratorForMethod GetMethodGenerator(MethodInfo method) { @@ -349,37 +368,55 @@ return methodGenerator; } - throw new NotSupportedException(); + throw new NotSupportedException(method.ToString()); } + public IHqlGeneratorForProperty GetPropertyGenerator(System.Type type, MemberInfo member) + { + IHqlGeneratorForProperty propertyGenerator; + + if (_registeredProperties.TryGetValue(member, out propertyGenerator)) + { + return propertyGenerator; + } + + // TODO - different usage pattern to method generator + return null; + } + public void RegisterMethodGenerator(MethodInfo method, IHqlGeneratorForMethod generator) { _registeredMethods.Add(method, generator); } + public void RegisterPropertyGenerator(MemberInfo property, IHqlGeneratorForProperty generator) + { + _registeredProperties.Add(property, generator); + } + private void Register(IHqlGeneratorForType typeMethodGenerator) { - typeMethodGenerator.RegisterMethods(this); + typeMethodGenerator.Register(this); } - } public interface IHqlGeneratorForMethod { IEnumerable<MethodInfo> SupportedMethods { get; } - void BuildHql(Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor); + void BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor); } public interface IHqlGeneratorForType { - void RegisterMethods(MethodGeneratorRegistry methodGeneratorRegistry); + void Register(MethodGeneratorRegistry methodGeneratorRegistry); } abstract public class BaseHqlGeneratorForType : IHqlGeneratorForType { protected readonly List<IHqlGeneratorForMethod> MethodRegistry = new List<IHqlGeneratorForMethod>(); + protected readonly List<IHqlGeneratorForProperty> PropertyRegistry = new List<IHqlGeneratorForProperty>(); - public void RegisterMethods(MethodGeneratorRegistry methodGeneratorRegistry) + public void Register(MethodGeneratorRegistry methodGeneratorRegistry) { foreach (var generator in MethodRegistry) { @@ -388,6 +425,14 @@ methodGeneratorRegistry.RegisterMethodGenerator(method, generator); } } + + foreach (var generator in PropertyRegistry) + { + foreach (var property in generator.SupportedProperties) + { + methodGeneratorRegistry.RegisterPropertyGenerator(property, generator); + } + } } } @@ -395,9 +440,46 @@ { public IEnumerable<MethodInfo> SupportedMethods { get; protected set; } - public abstract void BuildHql(Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor); + public abstract void BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor); } + public class DateTimePropertyGenerator : BaseHqlGeneratorForType + { + public DateTimePropertyGenerator() + { + PropertyRegistry.Add(new DatePartGenerator()); + } + + public class DatePartGenerator : BaseHqlGeneratorForProperty + { + public DatePartGenerator() + { + SupportedProperties = new[] + { + ReflectionHelper.GetProperty((DateTime x) => x.Year), + ReflectionHelper.GetProperty((DateTime x) => x.Month), + ReflectionHelper.GetProperty((DateTime x) => x.Day), + ReflectionHelper.GetProperty((DateTime x) => x.Hour), + ReflectionHelper.GetProperty((DateTime x) => x.Minute), + ReflectionHelper.GetProperty((DateTime x) => x.Second), + }; + } + + public override void BuildHql(MemberInfo member, Expression expression, HqlGeneratorExpressionTreeVisitor hqlVisitor) + { + using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.MethodCall())) + { + hqlVisitor.Stack.PushLeaf(hqlVisitor.TreeBuilder.Ident(member.Name.ToLowerInvariant())); + + using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.ExpressionList())) + { + hqlVisitor.Visit(expression); + } + } + } + } + } + public class StringMethodsGenerator : BaseHqlGeneratorForType { public StringMethodsGenerator() @@ -407,8 +489,32 @@ MethodRegistry.Add(new EndsWithGenerator()); MethodRegistry.Add(new ContainsGenerator()); MethodRegistry.Add(new EqualsGenerator()); + MethodRegistry.Add(new ToUpperLowerGenerator()); + + PropertyRegistry.Add(new LengthGenerator()); } + public class LengthGenerator : BaseHqlGeneratorForProperty + { + public LengthGenerator() + { + SupportedProperties = new[] {ReflectionHelper.GetProperty((string x) => x.Length)}; + } + + public override void BuildHql(MemberInfo member, Expression expression, HqlGeneratorExpressionTreeVisitor hqlVisitor) + { + using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.MethodCall())) + { + hqlVisitor.Stack.PushLeaf(hqlVisitor.TreeBuilder.Ident("length")); + + using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.ExpressionList())) + { + hqlVisitor.Visit(expression); + } + } + } + } + class StartsWithGenerator : BaseHqlGeneratorForMethod { public StartsWithGenerator() @@ -416,7 +522,7 @@ SupportedMethods = new[] { ReflectionHelper.GetMethod<string>(x => x.StartsWith(null)) }; } - public override void BuildHql(Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor) + public override void BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor) { using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Like())) { @@ -446,7 +552,7 @@ SupportedMethods = new[] { ReflectionHelper.GetMethod<string>(x => x.EndsWith(null)) }; } - public override void BuildHql(Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor) + public override void BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor) { using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Like())) { @@ -476,7 +582,7 @@ SupportedMethods = new[] { ReflectionHelper.GetMethod<string>(x => x.Contains(null)) }; } - public override void BuildHql(Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor) + public override void BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor) { using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Like())) { @@ -508,7 +614,7 @@ SupportedMethods = new[] { ReflectionHelper.GetMethod<string>(x => x.Equals((string)null)) }; } - public override void BuildHql(Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor) + public override void BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor) { using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Equality())) { @@ -518,6 +624,40 @@ } } } + + class ToUpperLowerGenerator : BaseHqlGeneratorForMethod + { + public ToUpperLowerGenerator() + { + SupportedMethods = new[] + { + ReflectionHelper.GetMethod<string>(x => x.ToUpper()), + ReflectionHelper.GetMethod<string>(x => x.ToUpperInvariant()), + ReflectionHelper.GetMethod<string>(x => x.ToLower()), + ReflectionHelper.GetMethod<string>(x => x.ToLowerInvariant()) + }; + } + + public override void BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor) + { + using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.MethodCall())) + { + if (((method.Name == "ToUpper") || (method.Name == "ToUpperInvariant"))) + { + hqlVisitor.Stack.PushLeaf(hqlVisitor.TreeBuilder.Ident("lower")); + } + else + { + hqlVisitor.Stack.PushLeaf(hqlVisitor.TreeBuilder.Ident("upper")); + } + + using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.ExpressionList())) + { + hqlVisitor.Visit(targetObject); + } + } + } + } } public class QueryableMethodsGenerator : BaseHqlGeneratorForType @@ -544,7 +684,7 @@ }; } - public override void BuildHql(Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor) + public override void BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor) { // Any has one or two arguments. Arg 1 is the source and arg 2 is the optional predicate using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Exists())) @@ -590,7 +730,7 @@ }; } - public override void BuildHql(Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor) + public override void BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor) { // All has two arguments. Arg 1 is the source and arg 2 is the predicate using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Not())) @@ -638,7 +778,7 @@ }; } - public override void BuildHql(Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor) + public override void BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor) { using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Min())) { @@ -658,7 +798,7 @@ }; } - public override void BuildHql(Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor) + public override void BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlGeneratorExpressionTreeVisitor hqlVisitor) { using (hqlVisitor.Stack.PushNode(hqlVisitor.TreeBuilder.Max())) { Modified: trunk/nhibernate/src/NHibernate/NHibernate.csproj =================================================================== --- trunk/nhibernate/src/NHibernate/NHibernate.csproj 2009-11-10 13:32:28 UTC (rev 4825) +++ trunk/nhibernate/src/NHibernate/NHibernate.csproj 2009-11-10 14:19:09 UTC (rev 4826) @@ -585,6 +585,7 @@ <Compile Include="Linq\GroupJoin\GroupJoinSelectClauseRewriter.cs" /> <Compile Include="Linq\GroupJoin\LocateGroupJoinQuerySource.cs" /> <Compile Include="Linq\GroupJoin\NonAggregatingGroupJoinRewriter.cs" /> + <Compile Include="Linq\Visitors\BaseHqlGeneratorForProperty.cs" /> <Compile Include="Linq\Visitors\SwapQuerySourceVisitor.cs" /> <Compile Include="Linq\Visitors\EqualityHqlGenerator.cs" /> <Compile Include="Linq\Visitors\ExpressionParameterVisitor.cs" /> Modified: trunk/nhibernate/src/NHibernate.Test/Linq/AggregateTests.cs =================================================================== --- trunk/nhibernate/src/NHibernate.Test/Linq/AggregateTests.cs 2009-11-10 13:32:28 UTC (rev 4825) +++ trunk/nhibernate/src/NHibernate.Test/Linq/AggregateTests.cs 2009-11-10 14:19:09 UTC (rev 4826) @@ -61,15 +61,14 @@ Console.WriteLine(query); Assert.AreEqual("ALFKI,AROUT,", query.ToString()); } - /* - [Test] - [Ignore("TODO")] + + [Test] public void AggregateWithMonthFunction() { var date = new DateTime(2007, 1, 1); var query = (from e in db.Employees - where db.Methods.Month(e.BirthDate) == date.Month + where e.BirthDate.Value.Month == date.Month select e.FirstName) .Aggregate(new StringBuilder(), (sb, name) => sb.Length > 0 ? sb.Append(", ").Append(name) : sb.Append(name)); @@ -77,15 +76,14 @@ Console.WriteLine(query); } - [Test] - [Ignore("TODO")] + [Test] public void AggregateWithBeforeYearFunction() { var date = new DateTime(1960, 1, 1); var query = (from e in db.Employees - where db.Methods.Year(e.BirthDate) < date.Year - select db.Methods.Upper(e.FirstName)) + where e.BirthDate.Value.Year < date.Year + select e.FirstName.ToUpper()) .Aggregate(new StringBuilder(), (sb, name) => sb.Length > 0 ? sb.Append(", ").Append(name) : sb.Append(name)); Console.WriteLine("Birthdays before {0}:", date.ToString("yyyy")); @@ -93,13 +91,12 @@ } [Test] - [Ignore("TODO")] public void AggregateWithOnOrAfterYearFunction() { var date = new DateTime(1960, 1, 1); var query = (from e in db.Employees - where db.Methods.Year(e.BirthDate) >= date.Year && db.Methods.Len(e.FirstName) > 4 + where e.BirthDate.Value.Year >= date.Year && e.FirstName.Length > 4 select e.FirstName) .Aggregate(new StringBuilder(), (sb, name) => sb.Length > 0 ? sb.Append(", ").Append(name) : sb.Append(name)); @@ -108,14 +105,13 @@ } [Test] - [Ignore("TODO")] public void AggregateWithUpperAndLowerFunctions() { var date = new DateTime(2007, 1, 1); var query = (from e in db.Employees - where db.Methods.Month(e.BirthDate) == date.Month - select new { First = e.FirstName.ToUpper(), Last = db.Methods.Lower(e.LastName) }) + where e.BirthDate.Value.Month == date.Month + select new { First = e.FirstName.ToUpper(), Last = e.LastName.ToLower() }) .Aggregate(new StringBuilder(), (sb, name) => sb.Length > 0 ? sb.Append(", ").Append(name) : sb.Append(name)); Console.WriteLine("{0} Birthdays:", date.ToString("MMMM")); @@ -123,17 +119,19 @@ } [Test] - [Ignore("TODO")] + [Ignore("TODO: Custom functions")] public void AggregateWithCustomFunction() { + /* var date = new DateTime(1960, 1, 1); var query = (from e in db.Employees - where db.Methods.Year(e.BirthDate) < date.Year + where e.BirthDate.Value.Year < date.Year select db.Methods.fnEncrypt(e.FirstName)) .Aggregate(new StringBuilder(), (sb, name) => sb.AppendLine(BitConverter.ToString(name))); Console.WriteLine(query); - }*/ + */ + } } } \ No newline at end of file This was sent by the SourceForge.net collaborative development platform, the world's largest Open Source development site. |