From: <ste...@us...> - 2009-12-16 21:36:52
|
Revision: 4895 http://nhibernate.svn.sourceforge.net/nhibernate/?rev=4895&view=rev Author: steverstrong Date: 2009-12-16 21:36:34 +0000 (Wed, 16 Dec 2009) Log Message: ----------- More Linq test cases and supporting fixes. Modified Paths: -------------- trunk/nhibernate/src/NHibernate/Hql/Ast/ANTLR/PolymorphicQuerySourceDetector.cs trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeNode.cs trunk/nhibernate/src/NHibernate/Impl/AbstractQueryImpl.cs trunk/nhibernate/src/NHibernate/Impl/ExpressionQueryImpl.cs trunk/nhibernate/src/NHibernate/Linq/ExpressionToHqlTranslationResults.cs trunk/nhibernate/src/NHibernate/Linq/Expressions/NhExpressionType.cs trunk/nhibernate/src/NHibernate/Linq/Expressions/NhNewExpression.cs trunk/nhibernate/src/NHibernate/Linq/Functions/BaseHqlGeneratorForType.cs trunk/nhibernate/src/NHibernate/Linq/Functions/FunctionRegistry.cs trunk/nhibernate/src/NHibernate/Linq/Functions/IHqlGeneratorForType.cs trunk/nhibernate/src/NHibernate/Linq/Functions/QueryableGenerator.cs trunk/nhibernate/src/NHibernate/Linq/LinqExtensionMethods.cs trunk/nhibernate/src/NHibernate/Linq/NhLinqExpression.cs trunk/nhibernate/src/NHibernate/Linq/NhQueryProvider.cs trunk/nhibernate/src/NHibernate/Linq/ReWriters/MergeAggregatingResultsRewriter.cs trunk/nhibernate/src/NHibernate/Linq/ReWriters/RemoveUnnecessaryBodyOperators.cs trunk/nhibernate/src/NHibernate/Linq/ResultOperators/ClientSideTransformOperator.cs trunk/nhibernate/src/NHibernate/Linq/ResultTransformer.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionTreeVisitor.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/NhExpressionTreeVisitor.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/Nominator.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs trunk/nhibernate/src/NHibernate/NHibernate.csproj trunk/nhibernate/src/NHibernate.ByteCode.LinFu/NHibernate.ByteCode.LinFu.csproj trunk/nhibernate/src/NHibernate.DomainModel/NHibernate.DomainModel.csproj trunk/nhibernate/src/NHibernate.Test/App.config trunk/nhibernate/src/NHibernate.Test/Linq/Entities/Northwind.cs trunk/nhibernate/src/NHibernate.Test/Linq/LinqQuerySamples.cs trunk/nhibernate/src/NHibernate.Test/Linq/ParameterisedQueries.cs trunk/nhibernate/src/NHibernate.Test/NHibernate.Test.csproj Added Paths: ----------- trunk/nhibernate/src/NHibernate/Linq/Visitors/SelectClauseVisitor.cs trunk/nhibernate/src/NHibernate.Test/Linq/CollectionAssert.cs trunk/nhibernate/src/NHibernate.Test/Linq/Entities/UserDto.cs trunk/nhibernate/src/NHibernate.Test/Linq/RegresstionTests.cs trunk/nhibernate/src/NHibernate.Test/Linq/SelectionTests.cs trunk/nhibernate/src/NHibernate.Test/Linq/WhereSubqueryTests.cs trunk/nhibernate/src/NHibernate.Test/Linq/WhereTests.cs Modified: trunk/nhibernate/src/NHibernate/Hql/Ast/ANTLR/PolymorphicQuerySourceDetector.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Hql/Ast/ANTLR/PolymorphicQuerySourceDetector.cs 2009-12-05 22:14:08 UTC (rev 4894) +++ trunk/nhibernate/src/NHibernate/Hql/Ast/ANTLR/PolymorphicQuerySourceDetector.cs 2009-12-16 21:36:34 UTC (rev 4895) @@ -27,20 +27,17 @@ var className = GetClassName(querySource); var classType = _sessionFactoryHelper.GetImportedClass(className); - if (classType != null) - { - AddImplementorsToMap(querySource, classType); - } + AddImplementorsToMap(querySource, classType == null ? className : classType.FullName); } return _map; } - private void AddImplementorsToMap(IASTNode querySource, System.Type classType) + private void AddImplementorsToMap(IASTNode querySource, string className) { - var implementors = _sfi.GetImplementors(classType.FullName); + var implementors = _sfi.GetImplementors(className); - if (implementors.Length == 1 && implementors[0] == classType.FullName) + if (implementors.Length == 1 && implementors[0] == className) { // No need to change things return; Modified: trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs 2009-12-05 22:14:08 UTC (rev 4894) +++ trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs 2009-12-16 21:36:34 UTC (rev 4895) @@ -341,11 +341,6 @@ return new HqlConcat(_factory, args); } - public HqlExpressionList ExpressionList() - { - return new HqlExpressionList(_factory); - } - public HqlMethodCall MethodCall(string methodName, IEnumerable<HqlExpression> parameters) { return new HqlMethodCall(_factory, methodName, parameters); @@ -370,6 +365,30 @@ { return new HqlIsNotNull(_factory, lhs); } + + public HqlTreeNode ExpressionList(IEnumerable<HqlExpression> expressions) + { + return new HqlExpressionList(_factory, expressions); + } + + public HqlStar Star() + { + return new HqlStar(_factory); + } + + public HqlTrue True() + { + return new HqlTrue(_factory); + } + + public HqlFalse False() + { + return new HqlFalse(_factory); + } + + public HqlIn In(HqlExpression itemExpression, HqlTreeNode source) + { + return new HqlIn(_factory, itemExpression, source); + } } - } \ No newline at end of file Modified: trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeNode.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeNode.cs 2009-12-05 22:14:08 UTC (rev 4894) +++ trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeNode.cs 2009-12-16 21:36:34 UTC (rev 4895) @@ -8,11 +8,13 @@ { public class HqlTreeNode { + public IASTFactory Factory { get; private set; } private readonly IASTNode _node; private readonly List<HqlTreeNode> _children; protected HqlTreeNode(int type, string text, IASTFactory factory, IEnumerable<HqlTreeNode> children) { + Factory = factory; _node = factory.CreateNode(type, text); _children = new List<HqlTreeNode>(); @@ -107,18 +109,27 @@ _node.AddChild(child.AstNode); } } + } - public HqlExpression AsExpression() + public static class HqlTreeNodeExtensions + { + public static HqlExpression AsExpression(this HqlTreeNode node) { // TODO - nice error handling if cast fails - return (HqlExpression) this; + return (HqlExpression)node; } - public virtual HqlBooleanExpression AsBooleanExpression() + public static HqlBooleanExpression AsBooleanExpression(this HqlTreeNode node) { + if (node is HqlDot) + { + return new HqlBooleanDot(node.Factory, (HqlDot) node); + } + // TODO - nice error handling if cast fails - return (HqlBooleanExpression)this; + return (HqlBooleanExpression)node; } + } public abstract class HqlStatement : HqlTreeNode @@ -329,20 +340,10 @@ public class HqlDot : HqlExpression { - private readonly IASTFactory _factory; - public HqlDot(IASTFactory factory, HqlExpression lhs, HqlExpression rhs) : base(HqlSqlWalker.DOT, ".", factory, lhs, rhs) { - _factory = factory; } - - public override HqlBooleanExpression AsBooleanExpression() - { - // If we are of boolean type, then we can acts as boolean expression - // TODO - implement type check - return new HqlBooleanDot(_factory, this); - } } public class HqlBooleanDot : HqlBooleanExpression @@ -408,6 +409,23 @@ } } + public class HqlFalse : HqlConstant + { + public HqlFalse(IASTFactory factory) + : base(factory, HqlSqlWalker.FALSE, "false") + { + } + } + + public class HqlTrue : HqlConstant + { + public HqlTrue(IASTFactory factory) + : base(factory, HqlSqlWalker.TRUE, "true") + { + } + } + + public class HqlNull : HqlConstant { public HqlNull(IASTFactory factory) @@ -430,15 +448,23 @@ Descending } - public class HqlDirectionAscending : HqlStatement + public class HqlDirectionStatement : HqlStatement { + public HqlDirectionStatement(int type, string text, IASTFactory factory) + : base(type, text, factory) + { + } + } + + public class HqlDirectionAscending : HqlDirectionStatement + { public HqlDirectionAscending(IASTFactory factory) : base(HqlSqlWalker.ASCENDING, "asc", factory) { } } - public class HqlDirectionDescending : HqlStatement + public class HqlDirectionDescending : HqlDirectionStatement { public HqlDirectionDescending(IASTFactory factory) : base(HqlSqlWalker.DESCENDING, "desc", factory) @@ -733,4 +759,28 @@ { } } + + public class HqlStar : HqlExpression + { + public HqlStar(IASTFactory factory) : base(HqlSqlWalker.ROW_STAR, "*", factory) + { + } + } + + public class HqlIn : HqlBooleanExpression + { + public HqlIn(IASTFactory factory, HqlExpression itemExpression, HqlTreeNode source) + : base(HqlSqlWalker.IN, "in", factory, itemExpression) + { + AddChild(new HqlInList(factory, source)); + } + } + + public class HqlInList : HqlTreeNode + { + public HqlInList(IASTFactory factory, HqlTreeNode source) + : base(HqlSqlWalker.IN_LIST, "inlist", factory, source) + { + } + } } \ No newline at end of file Modified: trunk/nhibernate/src/NHibernate/Impl/AbstractQueryImpl.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Impl/AbstractQueryImpl.cs 2009-12-05 22:14:08 UTC (rev 4894) +++ trunk/nhibernate/src/NHibernate/Impl/AbstractQueryImpl.cs 2009-12-16 21:36:34 UTC (rev 4895) @@ -27,7 +27,7 @@ private readonly ArrayList values = new ArrayList(4); private readonly List<IType> types = new List<IType>(4); private readonly Dictionary<string, TypedValue> namedParameters = new Dictionary<string, TypedValue>(4); - private readonly Dictionary<string, TypedValue> namedParameterLists = new Dictionary<string, TypedValue>(4); + protected readonly Dictionary<string, TypedValue> namedParameterLists = new Dictionary<string, TypedValue>(4); private bool cacheable; private string cacheRegion; private bool readOnly; Modified: trunk/nhibernate/src/NHibernate/Impl/ExpressionQueryImpl.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Impl/ExpressionQueryImpl.cs 2009-12-05 22:14:08 UTC (rev 4894) +++ trunk/nhibernate/src/NHibernate/Impl/ExpressionQueryImpl.cs 2009-12-16 21:36:34 UTC (rev 4895) @@ -1,8 +1,17 @@ -using System; +using System; using System.Collections; using System.Collections.Generic; +using System.Linq; +using System.Text; using NHibernate.Engine; using NHibernate.Engine.Query; +using NHibernate.Engine.Query.Sql; +using NHibernate.Hql.Ast.ANTLR; +using NHibernate.Hql.Ast.ANTLR.Tree; +using NHibernate.Hql.Ast.ANTLR.Util; +using NHibernate.Hql.Classic; +using NHibernate.Type; +using NHibernate.Util; namespace NHibernate.Impl { @@ -12,12 +21,12 @@ public IQueryExpression QueryExpression { get; private set; } - public ExpressionQueryImpl(IQueryExpression queryExpression, ISessionImplementor session, ParameterMetadata parameterMetadata) + public ExpressionQueryImpl(IQueryExpression queryExpression, ISessionImplementor session, ParameterMetadata parameterMetadata) : base(queryExpression.Key, FlushMode.Unspecified, session, parameterMetadata) { QueryExpression = queryExpression; } - + public override IQuery SetLockMode(string alias, LockMode lockMode) { _lockModes[alias] = lockMode; @@ -51,7 +60,7 @@ Before(); try { - return Session.List(QueryExpression, GetQueryParameters(namedParams)); + return Session.List(ExpandParameters(namedParams), GetQueryParameters(namedParams)); } finally { @@ -59,6 +68,65 @@ } } + /// <summary> + /// Warning: adds new parameters to the argument by side-effect, as well as mutating the query expression tree! + /// </summary> + protected IQueryExpression ExpandParameters(IDictionary<string, TypedValue> namedParamsCopy) + { + if (namedParameterLists.Count == 0) + { + // Short circuit straight out + return QueryExpression; + } + + // Build a map from single parameters to lists + var map = new Dictionary<string, List<string>>(); + + foreach (var me in namedParameterLists) + { + string name = me.Key; + var vals = (ICollection) me.Value.Value; + var type = me.Value.Type; + + if (vals.Count == 1) + { + // No expansion needed here + var iter = vals.GetEnumerator(); + iter.MoveNext(); + namedParamsCopy[name] = new TypedValue(type, iter.Current, Session.EntityMode); + continue; + } + + var aliases = new List<string>(); + var i = 0; + var isJpaPositionalParam = parameterMetadata.GetNamedParameterDescriptor(name).JpaStyle; + + foreach (var obj in vals) + { + var alias = (isJpaPositionalParam ? 'x' + name : name + StringHelper.Underscore) + i++ + StringHelper.Underscore; + namedParamsCopy[alias] = new TypedValue(type, obj, Session.EntityMode); + aliases.Add(alias); + } + + map.Add(name, aliases); + + } + + var newTree = ParameterExpander.Expand(QueryExpression.Translate(Session.Factory), map); + var key = new StringBuilder(QueryExpression.Key); + + map.Aggregate(key, (sb, kvp) => + { + sb.Append(' '); + sb.Append(kvp.Key); + sb.Append(':'); + kvp.Value.Aggregate(sb, (sb2, str) => sb2.Append(str)); + return sb; + }); + + return new ExpandedQueryExpression(QueryExpression, newTree, key.ToString()); + } + public override void List(IList results) { throw new NotImplementedException(); @@ -69,4 +137,141 @@ throw new NotImplementedException(); } } + + internal class ExpandedQueryExpression : IQueryExpression + { + private readonly IASTNode _tree; + + public ExpandedQueryExpression(IQueryExpression queryExpression, IASTNode tree, string key) + { + _tree = tree; + Key = key; + Type = queryExpression.Type; + ParameterDescriptors = queryExpression.ParameterDescriptors; + } + + public IASTNode Translate(ISessionFactory sessionFactory) + { + return _tree; + } + + public string Key { get; private set; } + + public System.Type Type { get; private set; } + + public IList<NamedParameterDescriptor> ParameterDescriptors { get; private set; } + } + + internal class ParameterExpander + { + private readonly IASTNode _tree; + private readonly Dictionary<string, List<string>> _map; + + private ParameterExpander(IASTNode tree, Dictionary<string, List<string>> map) + { + _tree = tree; + _map = map; + } + + public static IASTNode Expand(IASTNode tree, Dictionary<string, List<string>> map) + { + var expander = new ParameterExpander(tree, map); + + return expander.Expand(); + } + + private IASTNode Expand() + { + var parameters = ParameterDetector.LocateParameters(_tree, new HashSet<string>(_map.Keys)); + var nodeMapping = new Dictionary<IASTNode, IEnumerable<IASTNode>>(); + + foreach (var param in parameters) + { + var paramName = param.GetChild(0); + var aliases = _map[paramName.Text]; + var astAliases = new List<IASTNode>(); + + foreach (var alias in aliases) + { + var astAlias = param.DupNode(); + var astAliasName = paramName.DupNode(); + astAliasName.Text = alias; + astAlias.AddChild(astAliasName); + + astAliases.Add(astAlias); + } + + nodeMapping.Add(param, astAliases); + } + + return DuplicateTree(_tree, nodeMapping); + } + + private static IASTNode DuplicateTree(IASTNode ast, IDictionary<IASTNode, IEnumerable<IASTNode>> nodeMapping) + { + var thisNode = ast.DupNode(); + + foreach (var child in ast) + { + IEnumerable<IASTNode> candidate; + + if (nodeMapping.TryGetValue(child, out candidate)) + { + foreach (var replacement in candidate) + { + thisNode.AddChild(replacement); + } + } + else + { + thisNode.AddChild(DuplicateTree(child, nodeMapping)); + } + } + + return thisNode; + } + } + + internal class ParameterDetector : IVisitationStrategy + { + private readonly IASTNode _tree; + private readonly HashSet<string> _parameterNames; + private readonly List<IASTNode> _nodes; + + private ParameterDetector(IASTNode tree, HashSet<string> parameterNames) + { + _tree = tree; + _parameterNames = parameterNames; + _nodes = new List<IASTNode>(); + } + + public static IList<IASTNode> LocateParameters(IASTNode tree, HashSet<string> parameterNames) + { + var detector = new ParameterDetector(tree, parameterNames); + + return detector.LocateParameters(); + } + + private IList<IASTNode> LocateParameters() + { + var nodeTraverser = new NodeTraverser(this); + nodeTraverser.TraverseDepthFirst(_tree); + + return _nodes; + } + + public void Visit(IASTNode node) + { + if ((node.Type == HqlSqlWalker.PARAM) || (node.Type == HqlSqlWalker.COLON)) + { + var name = node.GetChild(0).Text; + + if (_parameterNames.Contains(name)) + { + _nodes.Add(node); + } + } + } + + } } \ No newline at end of file Modified: trunk/nhibernate/src/NHibernate/Linq/ExpressionToHqlTranslationResults.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/ExpressionToHqlTranslationResults.cs 2009-12-05 22:14:08 UTC (rev 4894) +++ trunk/nhibernate/src/NHibernate/Linq/ExpressionToHqlTranslationResults.cs 2009-12-16 21:36:34 UTC (rev 4895) @@ -9,17 +9,24 @@ { public class ExpressionToHqlTranslationResults { - public HqlQuery Statement { get; private set; } + public HqlTreeNode Statement { get; private set; } public ResultTransformer ResultTransformer { get; private set; } - public List<Action<IQuery, IDictionary<string, Pair<object, IType>>>> AdditionalCriteria { get; private set; } + public Delegate PostExecuteTransformer { get; private set; } + public List<Action<IQuery, IDictionary<string, Tuple<object, IType>>>> AdditionalCriteria { get; private set; } - public ExpressionToHqlTranslationResults(HqlQuery statement, IList<LambdaExpression> itemTransformers, IList<LambdaExpression> listTransformers, List<Action<IQuery, IDictionary<string, Pair<object, IType>>>> additionalCriteria) + public ExpressionToHqlTranslationResults(HqlTreeNode statement, + IList<LambdaExpression> itemTransformers, + IList<LambdaExpression> listTransformers, + IList<LambdaExpression> postExecuteTransformers, + List<Action<IQuery, IDictionary<string, Tuple<object, IType>>>> additionalCriteria) { Statement = statement; - var itemTransformer = MergeLambdas(itemTransformers); - var listTransformer = MergeLambdas(listTransformers); + PostExecuteTransformer = MergeLambdasAndCompile(postExecuteTransformers); + var itemTransformer = MergeLambdasAndCompile(itemTransformers); + var listTransformer = MergeLambdasAndCompile(listTransformers); + if (itemTransformer != null || listTransformer != null) { ResultTransformer = new ResultTransformer(itemTransformer, listTransformer); @@ -28,7 +35,7 @@ AdditionalCriteria = additionalCriteria; } - private static LambdaExpression MergeLambdas(IList<LambdaExpression> transformations) + private static Delegate MergeLambdasAndCompile(IList<LambdaExpression> transformations) { if (transformations == null || transformations.Count == 0) { @@ -44,7 +51,7 @@ listTransformLambda = Expression.Lambda(invoked, listTransformLambda.Parameters.ToArray()); } - return listTransformLambda; + return listTransformLambda.Compile(); } } } \ No newline at end of file Modified: trunk/nhibernate/src/NHibernate/Linq/Expressions/NhExpressionType.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/Expressions/NhExpressionType.cs 2009-12-05 22:14:08 UTC (rev 4894) +++ trunk/nhibernate/src/NHibernate/Linq/Expressions/NhExpressionType.cs 2009-12-16 21:36:34 UTC (rev 4895) @@ -8,6 +8,7 @@ Sum, Count, Distinct, - New + New, + Star } } \ No newline at end of file Modified: trunk/nhibernate/src/NHibernate/Linq/Expressions/NhNewExpression.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/Expressions/NhNewExpression.cs 2009-12-05 22:14:08 UTC (rev 4894) +++ trunk/nhibernate/src/NHibernate/Linq/Expressions/NhNewExpression.cs 2009-12-16 21:36:34 UTC (rev 4895) @@ -1,4 +1,5 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; using System.Collections.ObjectModel; using System.Linq.Expressions; @@ -26,4 +27,17 @@ get { return _members; } } } + + public class NhStarExpression : Expression + { + public NhStarExpression(Expression expression) : base((ExpressionType) NhExpressionType.Star, expression.Type) + { + Expression = expression; + } + + public Expression Expression + { + get; private set; + } + } } \ No newline at end of file Modified: trunk/nhibernate/src/NHibernate/Linq/Functions/BaseHqlGeneratorForType.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/Functions/BaseHqlGeneratorForType.cs 2009-12-05 22:14:08 UTC (rev 4894) +++ trunk/nhibernate/src/NHibernate/Linq/Functions/BaseHqlGeneratorForType.cs 2009-12-16 21:36:34 UTC (rev 4895) @@ -1,4 +1,6 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; +using System.Reflection; using NHibernate.Linq.Visitors; namespace NHibernate.Linq.Functions @@ -26,5 +28,15 @@ } } } + + public virtual bool SupportsMethod(MethodInfo method) + { + return false; + } + + public virtual IHqlGeneratorForMethod GetMethodGenerator(MethodInfo method) + { + throw new NotSupportedException(); + } } } \ No newline at end of file Modified: trunk/nhibernate/src/NHibernate/Linq/Functions/FunctionRegistry.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/Functions/FunctionRegistry.cs 2009-12-05 22:14:08 UTC (rev 4894) +++ trunk/nhibernate/src/NHibernate/Linq/Functions/FunctionRegistry.cs 2009-12-16 21:36:34 UTC (rev 4895) @@ -36,12 +36,14 @@ registry.Register(new QueryableGenerator()); registry.Register(new StringGenerator()); registry.Register(new DateTimeGenerator()); + registry.Register(new ICollectionGenerator()); return registry; } private readonly Dictionary<MethodInfo, IHqlGeneratorForMethod> _registeredMethods = new Dictionary<MethodInfo, IHqlGeneratorForMethod>(); private readonly Dictionary<MemberInfo, IHqlGeneratorForProperty> _registeredProperties = new Dictionary<MemberInfo, IHqlGeneratorForProperty>(); + private readonly List<IHqlGeneratorForType> _typeGenerators = new List<IHqlGeneratorForType>(); public IHqlGeneratorForMethod GetMethodGenerator(MethodInfo method) { @@ -58,14 +60,23 @@ } // No method generator registered. Look to see if it's a standard LinqExtensionMethod - var attr = (LinqExtensionMethodAttribute) method.GetCustomAttributes(typeof (LinqExtensionMethodAttribute), false)[0]; - if (attr != null) + var attr = method.GetCustomAttributes(typeof (LinqExtensionMethodAttribute), false); + if (attr.Length == 1) { // It is // TODO - cache this? Is it worth it? - return new HqlGeneratorForExtensionMethod(attr, method); + return new HqlGeneratorForExtensionMethod((LinqExtensionMethodAttribute) attr[0], method); } + // Not that either. Let's query each type generator to see if it can handle it + foreach (var typeGenerator in _typeGenerators) + { + if (typeGenerator.SupportsMethod(method)) + { + return typeGenerator.GetMethodGenerator(method); + } + } + throw new NotSupportedException(method.ToString()); } @@ -94,6 +105,7 @@ private void Register(IHqlGeneratorForType typeMethodGenerator) { + _typeGenerators.Add(typeMethodGenerator); typeMethodGenerator.Register(this); } } Modified: trunk/nhibernate/src/NHibernate/Linq/Functions/IHqlGeneratorForType.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/Functions/IHqlGeneratorForType.cs 2009-12-05 22:14:08 UTC (rev 4894) +++ trunk/nhibernate/src/NHibernate/Linq/Functions/IHqlGeneratorForType.cs 2009-12-16 21:36:34 UTC (rev 4895) @@ -1,7 +1,11 @@ -namespace NHibernate.Linq.Functions +using System.Reflection; + +namespace NHibernate.Linq.Functions { public interface IHqlGeneratorForType { void Register(FunctionRegistry functionRegistry); + bool SupportsMethod(MethodInfo method); + IHqlGeneratorForMethod GetMethodGenerator(MethodInfo method); } } \ No newline at end of file Modified: trunk/nhibernate/src/NHibernate/Linq/Functions/QueryableGenerator.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/Functions/QueryableGenerator.cs 2009-12-05 22:14:08 UTC (rev 4894) +++ trunk/nhibernate/src/NHibernate/Linq/Functions/QueryableGenerator.cs 2009-12-16 21:36:34 UTC (rev 4895) @@ -1,4 +1,6 @@ -using System.Collections.ObjectModel; +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; using System.Linq; using System.Linq.Expressions; using System.Reflection; @@ -126,4 +128,74 @@ } } } + + public class ICollectionGenerator : BaseHqlGeneratorForType + { + public ICollectionGenerator() + { + // TODO - could use reflection + MethodRegistry.Add(new ContainsGenerator()); + } + + public override bool SupportsMethod(MethodInfo method) + { + var declaringType = method.DeclaringType; + + if (declaringType.IsGenericType) + { + if (declaringType.GetGenericTypeDefinition() == typeof(ICollection<>) || + declaringType.GetGenericTypeDefinition() == typeof(IEnumerable<>)) + { + if (method.Name == "Contains") + { + return true; + } + } + } + + return false; + } + + public override IHqlGeneratorForMethod GetMethodGenerator(MethodInfo method) + { + // TODO - ick + if (method.Name == "Contains") + { + return new ContainsGenerator(); + } + + throw new NotSupportedException(method.Name); + } + + class ContainsGenerator : BaseHqlGeneratorForMethod + { + public ContainsGenerator() + { + SupportedMethods = new MethodInfo[0]; + } + + public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) + { + // TODO - alias generator + var alias = treeBuilder.Alias("x"); + + var param = Expression.Parameter(targetObject.Type, "x"); + var where = treeBuilder.Where(visitor.Visit(Expression.Lambda( + Expression.Equal(param, arguments[0]), param)) + .AsExpression()); + + return treeBuilder.Exists( + treeBuilder.Query( + treeBuilder.SelectFrom( + treeBuilder.From( + treeBuilder.Range( + visitor.Visit(targetObject), + alias) + ) + ), + where)); + } + } + + } } \ No newline at end of file Modified: trunk/nhibernate/src/NHibernate/Linq/LinqExtensionMethods.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/LinqExtensionMethods.cs 2009-12-05 22:14:08 UTC (rev 4894) +++ trunk/nhibernate/src/NHibernate/Linq/LinqExtensionMethods.cs 2009-12-16 21:36:34 UTC (rev 4895) @@ -24,10 +24,19 @@ return (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>)); } + public static bool IsNullableOrReference(this System.Type type) + { + return !type.IsValueType || (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>)); + } + public static System.Type NullableOf(this System.Type type) { return type.GetGenericArguments()[0]; } + public static T As<T>(this object source) + { + return (T) source; + } } } \ No newline at end of file Modified: trunk/nhibernate/src/NHibernate/Linq/NhLinqExpression.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/NhLinqExpression.cs 2009-12-05 22:14:08 UTC (rev 4894) +++ trunk/nhibernate/src/NHibernate/Linq/NhLinqExpression.cs 2009-12-16 21:36:34 UTC (rev 4895) @@ -9,7 +9,6 @@ using NHibernate.Type; using Remotion.Data.Linq; using Remotion.Data.Linq.Clauses; -using Remotion.Data.Linq.Clauses.StreamedData; using Remotion.Data.Linq.Parsing.ExpressionTreeVisitors; using Remotion.Data.Linq.Parsing.Structure; using Remotion.Data.Linq.Parsing.Structure.IntermediateModel; @@ -26,14 +25,15 @@ public NhLinqExpressionReturnType ReturnType { get; private set; } - public IDictionary<string, Pair<object, IType>> ParameterValuesByName { get; private set; } + public IDictionary<string, Tuple<object, IType>> ParameterValuesByName { get; private set; } public ExpressionToHqlTranslationResults ExpressionToHqlTranslationResults { get; private set; } private readonly Expression _expression; private readonly IDictionary<ConstantExpression, NamedParameter> _constantToParameterMap; + private IASTNode _astNode; - public NhLinqExpression(Expression expression) + public NhLinqExpression(Expression expression) { _expression = PartialEvaluatingExpressionTreeVisitor.EvaluateIndependentSubtrees(expression); @@ -43,8 +43,8 @@ ParameterValuesByName = _constantToParameterMap.Values.ToDictionary(p => p.Name, p => - new Pair<object, IType> - {Left = p.Value, Right = p.Type}); + new Tuple<object, IType> + {First = p.Value, Second = p.Type}); Key = ExpressionKeyVisitor.Visit(_expression, _constantToParameterMap); @@ -62,16 +62,22 @@ public IASTNode Translate(ISessionFactory sessionFactory) { - var requiredHqlParameters = new List<NamedParameterDescriptor>(); + //if (_astNode == null) + { + var requiredHqlParameters = new List<NamedParameterDescriptor>(); - // TODO - can we cache any of this? - var queryModel = NhRelinqQueryParser.Parse(_expression); + // TODO - can we cache any of this? + var queryModel = NhRelinqQueryParser.Parse(_expression); - ExpressionToHqlTranslationResults = QueryModelVisitor.GenerateHqlQuery(queryModel, _constantToParameterMap, requiredHqlParameters); + ExpressionToHqlTranslationResults = QueryModelVisitor.GenerateHqlQuery(queryModel, + _constantToParameterMap, + requiredHqlParameters); - ParameterDescriptors = requiredHqlParameters.AsReadOnly(); + ParameterDescriptors = requiredHqlParameters.AsReadOnly(); + _astNode = ExpressionToHqlTranslationResults.Statement.AstNode; + } - return ExpressionToHqlTranslationResults.Statement.AstNode; + return _astNode; } } @@ -89,6 +95,14 @@ MethodCallExpressionNodeTypeRegistry.GetRegisterableMethodDefinition(ReflectionHelper.GetMethod(() => Queryable.Aggregate<object, object>(null, null, null))) }, typeof (AggregateExpressionNode)); + + MethodCallRegistry.Register( + new [] + { + MethodCallExpressionNodeTypeRegistry.GetRegisterableMethodDefinition(ReflectionHelper.GetMethod((List<object> l) => l.Contains(null))), + + }, + typeof(ContainsExpressionNode)); } public static QueryModel Parse(Expression expression) @@ -179,15 +193,5 @@ Accumulator = accumulator; OptionalSelector = optionalSelector; } - - public override IStreamedDataInfo GetOutputDataInfo(IStreamedDataInfo inputInfo) - { - throw new NotImplementedException(); - } - - public override ResultOperatorBase Clone(CloneContext cloneContext) - { - throw new NotImplementedException(); - } } } \ No newline at end of file Modified: trunk/nhibernate/src/NHibernate/Linq/NhQueryProvider.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/NhQueryProvider.cs 2009-12-05 22:14:08 UTC (rev 4894) +++ trunk/nhibernate/src/NHibernate/Linq/NhQueryProvider.cs 2009-12-16 21:36:34 UTC (rev 4895) @@ -1,7 +1,9 @@ using System; +using System.Collections; using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; +using System.Reflection; using NHibernate.Impl; using NHibernate.Type; @@ -22,12 +24,26 @@ var query = _session.CreateQuery(nhLinqExpression); + var nhQuery = query.As<ExpressionQueryImpl>().QueryExpression.As<NhLinqExpression>(); + SetParameters(query, nhLinqExpression.ParameterValuesByName); - SetResultTransformerAndAdditionalCriteria(query, nhLinqExpression.ParameterValuesByName); + SetResultTransformerAndAdditionalCriteria(query, nhQuery, nhLinqExpression.ParameterValuesByName); var results = query.List(); - if (nhLinqExpression.ReturnType == NhLinqExpressionReturnType.Sequence) + if (nhQuery.ExpressionToHqlTranslationResults.PostExecuteTransformer != null) + { + try + { + return nhQuery.ExpressionToHqlTranslationResults.PostExecuteTransformer.DynamicInvoke(results.AsQueryable()); + } + catch (TargetInvocationException e) + { + throw e.InnerException; + } + } + + if (nhLinqExpression.ReturnType == NhLinqExpressionReturnType.Sequence) { return results.AsQueryable(); } @@ -52,28 +68,39 @@ return new NhQueryable<T>(this, expression); } - static void SetParameters(IQuery query, IDictionary<string, Pair<object, IType>> parameters) + static void SetParameters(IQuery query, IDictionary<string, Tuple<object, IType>> parameters) { foreach (var parameterName in query.NamedParameters) { var param = parameters[parameterName]; - if (param.Left == null) + + if (param.First == null) { - query.SetParameter(parameterName, param.Left, param.Right); + if (typeof(ICollection).IsAssignableFrom(param.Second.ReturnedClass)) + { + query.SetParameterList(parameterName, null, param.Second); + } + else + { + query.SetParameter(parameterName, null, param.Second); + } } else { - query.SetParameter(parameterName, param.Left); + if (param.First is ICollection) + { + query.SetParameterList(parameterName, (ICollection) param.First); + } + else + { + query.SetParameter(parameterName, param.First); + } } } } - public void SetResultTransformerAndAdditionalCriteria(IQuery query, IDictionary<string, Pair<object, IType>> parameters) + public void SetResultTransformerAndAdditionalCriteria(IQuery query, NhLinqExpression nhExpression, IDictionary<string, Tuple<object, IType>> parameters) { - var queryImpl = (ExpressionQueryImpl) query; - - var nhExpression = (NhLinqExpression) queryImpl.QueryExpression; - query.SetResultTransformer(nhExpression.ExpressionToHqlTranslationResults.ResultTransformer); foreach (var criteria in nhExpression.ExpressionToHqlTranslationResults.AdditionalCriteria) @@ -83,9 +110,17 @@ } } - public class Pair<TLeft, TRight> + public class Tuple<T1, T2> { - public TLeft Left { get; set; } - public TRight Right { get; set; } + public T1 First { get; set; } + public T2 Second { get; set; } + } + + public class Tuple<T1, T2, T3> + { + public T1 First { get; set; } + public T2 Second { get; set; } + public T3 Third { get; set; } + } } \ No newline at end of file Modified: trunk/nhibernate/src/NHibernate/Linq/ReWriters/MergeAggregatingResultsRewriter.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/ReWriters/MergeAggregatingResultsRewriter.cs 2009-12-05 22:14:08 UTC (rev 4894) +++ trunk/nhibernate/src/NHibernate/Linq/ReWriters/MergeAggregatingResultsRewriter.cs 2009-12-16 21:36:34 UTC (rev 4895) @@ -53,29 +53,40 @@ } else if (resultOperator is CountResultOperator) { - queryModel.SelectClause.Selector = new NhShortCountExpression(queryModel.SelectClause.Selector); + queryModel.SelectClause.Selector = new NhShortCountExpression(new NhStarExpression(queryModel.SelectClause.Selector)); queryModel.ResultOperators.Remove(resultOperator); } else if (resultOperator is LongCountResultOperator) { - queryModel.SelectClause.Selector = new NhLongCountExpression(queryModel.SelectClause.Selector); + queryModel.SelectClause.Selector = new NhLongCountExpression(new NhStarExpression(queryModel.SelectClause.Selector)); queryModel.ResultOperators.Remove(resultOperator); } base.VisitResultOperator(resultOperator, queryModel, index); } - + public override void VisitSelectClause(SelectClause selectClause, QueryModel queryModel) { - selectClause.TransformExpressions(s => new MergeAggregatingResultsInExpressionRewriter().Visit(s)); + selectClause.TransformExpressions(MergeAggregatingResultsInExpressionRewriter.Rewrite); } + + public override void VisitWhereClause(WhereClause whereClause, QueryModel queryModel, int index) + { + whereClause.TransformExpressions(MergeAggregatingResultsInExpressionRewriter.Rewrite); + } } - + internal class MergeAggregatingResultsInExpressionRewriter : NhExpressionTreeVisitor { - public Expression Visit(Expression expression) + private MergeAggregatingResultsInExpressionRewriter() + { + } + + public static Expression Rewrite(Expression expression) { - return VisitExpression(expression); + var visitor = new MergeAggregatingResultsInExpressionRewriter(); + + return visitor.VisitExpression(expression); } protected override Expression VisitSubQueryExpression(SubQueryExpression expression) @@ -83,7 +94,7 @@ MergeAggregatingResultsRewriter.ReWrite(expression.QueryModel); return expression; } - + protected override Expression VisitMethodCallExpression(MethodCallExpression m) { if (m.Method.DeclaringType == typeof(Queryable) || @@ -94,36 +105,43 @@ { case "Count": return CreateAggregate(m.Arguments[0], (LambdaExpression)m.Arguments[1], - e => new NhShortCountExpression(e)); + e => new NhShortCountExpression(e), + () => new CountResultOperator()); case "Min": return CreateAggregate(m.Arguments[0], (LambdaExpression) m.Arguments[1], - e => new NhMinExpression(e)); + e => new NhMinExpression(e), + () => new MinResultOperator()); case "Max": return CreateAggregate(m.Arguments[0], (LambdaExpression)m.Arguments[1], - e => new NhMaxExpression(e)); + e => new NhMaxExpression(e), + () => new MaxResultOperator()); case "Sum": return CreateAggregate(m.Arguments[0], (LambdaExpression)m.Arguments[1], - e => new NhSumExpression(e)); + e => new NhSumExpression(e), + () => new SumResultOperator()); case "Average": return CreateAggregate(m.Arguments[0], (LambdaExpression)m.Arguments[1], - e => new NhAverageExpression(e)); + e => new NhAverageExpression(e), + () => new AverageResultOperator()); } } return base.VisitMethodCallExpression(m); } - private Expression CreateAggregate(Expression fromClauseExpression, LambdaExpression body, Func<Expression,Expression> factory) + private Expression CreateAggregate(Expression fromClauseExpression, LambdaExpression body, Func<Expression,Expression> aggregateFactory, Func<ResultOperatorBase> resultOperatorFactory) { + // TODO - need generated name here var fromClause = new MainFromClause("x2", body.Parameters[0].Type, fromClauseExpression); var selectClause = body.Body; selectClause = ReplacingExpressionTreeVisitor.Replace(body.Parameters[0], new QuerySourceReferenceExpression( fromClause), selectClause); var queryModel = new QueryModel(fromClause, - new SelectClause(factory(selectClause))); + new SelectClause(aggregateFactory(selectClause))); - queryModel.ResultOperators.Add(new AverageResultOperator()); + // TODO - this sucks, but we use it to get the Type of the SubQueryExpression correct + queryModel.ResultOperators.Add(resultOperatorFactory()); var subQuery = new SubQueryExpression(queryModel); Modified: trunk/nhibernate/src/NHibernate/Linq/ReWriters/RemoveUnnecessaryBodyOperators.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/ReWriters/RemoveUnnecessaryBodyOperators.cs 2009-12-05 22:14:08 UTC (rev 4894) +++ trunk/nhibernate/src/NHibernate/Linq/ReWriters/RemoveUnnecessaryBodyOperators.cs 2009-12-16 21:36:34 UTC (rev 4895) @@ -1,8 +1,11 @@ using System; using System.Linq; +using System.Linq.Expressions; +using NHibernate.Linq.Expressions; using Remotion.Data.Linq; using Remotion.Data.Linq.Clauses; using Remotion.Data.Linq.Clauses.ResultOperators; +using Remotion.Data.Linq.Parsing; namespace NHibernate.Linq.ReWriters { Modified: trunk/nhibernate/src/NHibernate/Linq/ResultOperators/ClientSideTransformOperator.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/ResultOperators/ClientSideTransformOperator.cs 2009-12-05 22:14:08 UTC (rev 4894) +++ trunk/nhibernate/src/NHibernate/Linq/ResultOperators/ClientSideTransformOperator.cs 2009-12-16 21:36:34 UTC (rev 4895) @@ -13,7 +13,7 @@ public override IStreamedDataInfo GetOutputDataInfo(IStreamedDataInfo inputInfo) { - throw new NotImplementedException(); + return null; } public override ResultOperatorBase Clone(CloneContext cloneContext) Modified: trunk/nhibernate/src/NHibernate/Linq/ResultTransformer.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/ResultTransformer.cs 2009-12-05 22:14:08 UTC (rev 4894) +++ trunk/nhibernate/src/NHibernate/Linq/ResultTransformer.cs 2009-12-16 21:36:34 UTC (rev 4895) @@ -13,16 +13,10 @@ private readonly Delegate _listTransformation; private readonly Delegate _itemTransformation; - public ResultTransformer(LambdaExpression itemTransformation, LambdaExpression listTransformation) + public ResultTransformer(Delegate itemTransformation, Delegate listTransformation) { - if (itemTransformation != null) - { - _itemTransformation = itemTransformation.Compile(); - } - if (listTransformation != null) - { - _listTransformation = listTransformation.Compile(); - } + _itemTransformation = itemTransformation; + _listTransformation = listTransformation; } public object TransformTuple(object[] tuple, string[] aliases) @@ -39,9 +33,9 @@ object transformResult = collection; - if (collection.Count > 0) + //if (collection.Count > 0) { - if (collection[0] is object[]) + if (collection.Count > 0 && collection[0] is object[]) { if ( ((object[])collection[0]).Length != 1) { Modified: trunk/nhibernate/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionTreeVisitor.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionTreeVisitor.cs 2009-12-05 22:14:08 UTC (rev 4894) +++ trunk/nhibernate/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionTreeVisitor.cs 2009-12-16 21:36:34 UTC (rev 4895) @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Linq.Expressions; using NHibernate.Engine.Query; using NHibernate.Hql.Ast; @@ -16,6 +17,13 @@ private readonly IList<NamedParameterDescriptor> _requiredHqlParameters; static private readonly FunctionRegistry FunctionRegistry = FunctionRegistry.Initialise(); + public static HqlTreeNode Visit(Expression expression, IDictionary<ConstantExpression, NamedParameter> parameters, IList<NamedParameterDescriptor> requiredHqlParameters) + { + var visitor = new HqlGeneratorExpressionTreeVisitor(parameters, requiredHqlParameters); + + return visitor.VisitExpression(expression); + } + public HqlGeneratorExpressionTreeVisitor(IDictionary<ConstantExpression, NamedParameter> parameters, IList<NamedParameterDescriptor> requiredHqlParameters) { _parameters = parameters; @@ -23,7 +31,7 @@ _hqlTreeBuilder = new HqlTreeBuilder(); } - public virtual HqlTreeNode Visit(Expression expression) + public HqlTreeNode Visit(Expression expression) { return VisitExpression(expression); } @@ -122,6 +130,8 @@ return VisitNhCount((NhCountExpression)expression); case NhExpressionType.Distinct: return VisitNhDistinct((NhDistinctExpression)expression); + case NhExpressionType.Star: + return VisitNhStar((NhStarExpression) expression); //case NhExpressionType.New: // return VisitNhNew((NhNewExpression)expression); } @@ -130,6 +140,11 @@ } } + protected HqlTreeNode VisitNhStar(NhStarExpression expression) + { + return _hqlTreeBuilder.Star(); + } + private HqlTreeNode VisitInvocationExpression(InvocationExpression expression) { return VisitExpression(expression.Expression); @@ -206,7 +221,7 @@ } // Also check for nullability - if (expression.Left.Type.IsNullable() || expression.Right.Type.IsNullable()) + if (expression.Left.Type.IsNullableOrReference() || expression.Right.Type.IsNullableOrReference()) { // TODO - yuck. This clone is needed because the AST tree nodes are not immutable, // and sharing nodes between multiple branches will cause issues in the hqlSqlWalker phase - @@ -247,7 +262,7 @@ } // Also check for nullability - if (expression.Left.Type.IsNullable() || expression.Right.Type.IsNullable()) + if (expression.Left.Type.IsNullableOrReference() || expression.Right.Type.IsNullableOrReference()) { var lhs2 = VisitExpression(expression.Left).AsExpression(); var rhs2 = VisitExpression(expression.Right).AsExpression(); Modified: trunk/nhibernate/src/NHibernate/Linq/Visitors/NhExpressionTreeVisitor.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/Visitors/NhExpressionTreeVisitor.cs 2009-12-05 22:14:08 UTC (rev 4894) +++ trunk/nhibernate/src/NHibernate/Linq/Visitors/NhExpressionTreeVisitor.cs 2009-12-16 21:36:34 UTC (rev 4895) @@ -25,13 +25,23 @@ return VisitNhAggregate((NhAggregatedExpression)expression); case NhExpressionType.New: return VisitNhNew((NhNewExpression) expression); + case NhExpressionType.Star: + return VisitNhStar((NhStarExpression) expression); + } return base.VisitExpression(expression); } - protected virtual Expression VisitNhNew(NhNewExpression expression) + protected virtual Expression VisitNhStar(NhStarExpression expression) { + var newExpression = VisitExpression(expression.Expression); + + return newExpression != expression.Expression ? new NhStarExpression(newExpression) : expression; + } + + protected virtual Expression VisitNhNew(NhNewExpression expression) + { var arguments = VisitExpressionList(expression.Arguments); return arguments != expression.Arguments ? new NhNewExpression(expression.Members, arguments) : expression; @@ -60,42 +70,42 @@ protected virtual Expression VisitNhDistinct(NhDistinctExpression expression) { - Expression nx = base.VisitExpression(expression.Expression); + Expression nx = VisitExpression(expression.Expression); return nx != expression.Expression ? new NhDistinctExpression(nx) : expression; } protected virtual Expression VisitNhCount(NhCountExpression expression) { - Expression nx = base.VisitExpression(expression.Expression); + Expression nx = VisitExpression(expression.Expression); return nx != expression.Expression ? new NhShortCountExpression(nx) : expression; } protected virtual Expression VisitNhSum(NhSumExpression expression) { - Expression nx = base.VisitExpression(expression.Expression); + Expression nx = VisitExpression(expression.Expression); return nx != expression.Expression ? new NhSumExpression(nx) : expression; } protected virtual Expression VisitNhMax(NhMaxExpression expression) { - Expression nx = base.VisitExpression(expression.Expression); + Expression nx = VisitExpression(expression.Expression); return nx != expression.Expression ? new NhMaxExpression(nx) : expression; } protected virtual Expression VisitNhMin(NhMinExpression expression) { - Expression nx = base.VisitExpression(expression.Expression); + Expression nx = VisitExpression(expression.Expression); return nx != expression.Expression ? new NhMinExpression(nx) : expression; } protected virtual Expression VisitNhAverage(NhAverageExpression expression) { - Expression nx = base.VisitExpression(expression.Expression); + Expression nx = VisitExpression(expression.Expression); return nx != expression.Expression ? new NhAverageExpression(nx) : expression; } Modified: trunk/nhibernate/src/NHibernate/Linq/Visitors/Nominator.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/Visitors/Nominator.cs 2009-12-05 22:14:08 UTC (rev 4894) +++ trunk/nhibernate/src/NHibernate/Linq/Visitors/Nominator.cs 2009-12-16 21:36:34 UTC (rev 4895) @@ -9,18 +9,21 @@ /// </summary> class Nominator : NhExpressionTreeVisitor { - readonly Func<Expression, bool> _fnIsCandidate; - HashSet<Expression> _candidates; - bool _cannotBeCandidate; + private readonly Func<Expression, bool> _fnIsCandidate; + private readonly Func<Expression, bool> _fnIsCandidateShortcut; + private HashSet<Expression> _candidates; + private bool _canBeCandidate; - internal Nominator(Func<Expression, bool> fnIsCandidate) + internal Nominator(Func<Expression, bool> fnIsCandidate, Func<Expression, bool> fnIsCandidateShortcut) { _fnIsCandidate = fnIsCandidate; + _fnIsCandidateShortcut = fnIsCandidateShortcut; } internal HashSet<Expression> Nominate(Expression expression) { _candidates = new HashSet<Expression>(); + _canBeCandidate = true; VisitExpression(expression); return _candidates; } @@ -29,12 +32,18 @@ { if (expression != null) { - bool saveCannotBeEvaluated = _cannotBeCandidate; - _cannotBeCandidate = false; + bool saveCanBeCandidate = _canBeCandidate; + _canBeCandidate = true; + if (_fnIsCandidateShortcut(expression)) + { + _candidates.Add(expression); + return expression; + } + base.VisitExpression(expression); - if (!_cannotBeCandidate) + if (_canBeCandidate) { if (_fnIsCandidate(expression)) { @@ -42,11 +51,11 @@ } else { - _cannotBeCandidate = true; + _canBeCandidate = false; } } - _cannotBeCandidate |= saveCannotBeEvaluated; + _canBeCandidate = _canBeCandidate & saveCanBeCandidate; } return expression; Modified: trunk/nhibernate/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs 2009-12-05 22:14:08 UTC (rev 4894) +++ trunk/nhibernate/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs 2009-12-16 21:36:34 UTC (rev 4895) @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; +using System.Reflection; using NHibernate.Engine.Query; using NHibernate.Hql.Ast; using NHibernate.Linq.GroupBy; @@ -14,6 +15,7 @@ using Remotion.Data.Linq.Clauses; using Remotion.Data.Linq.Clauses.Expressions; using Remotion.Data.Linq.Clauses.ResultOperators; +using Remotion.Data.Linq.Clauses.StreamedData; namespace NHibernate.Linq.Visitors { @@ -50,144 +52,115 @@ private readonly HqlTreeBuilder _hqlTreeBuilder; - private readonly List<Action<IQuery, IDictionary<string, Pair<object, IType>>>> _additionalCriteria = new List<Action<IQuery, IDictionary<string, Pair<object, IType>>>>(); + private readonly List<Action<IQuery, IDictionary<string, Tuple<object, IType>>>> _additionalCriteria = new List<Action<IQuery, IDictionary<string, Tuple<object, IType>>>>(); private readonly List<LambdaExpression> _listTransformers = new List<LambdaExpression>(); ... [truncated message content] |