|
From: <ste...@us...> - 2009-12-16 21:36:52
|
Revision: 4895
http://nhibernate.svn.sourceforge.net/nhibernate/?rev=4895&view=rev
Author: steverstrong
Date: 2009-12-16 21:36:34 +0000 (Wed, 16 Dec 2009)
Log Message:
-----------
More Linq test cases and supporting fixes.
Modified Paths:
--------------
trunk/nhibernate/src/NHibernate/Hql/Ast/ANTLR/PolymorphicQuerySourceDetector.cs
trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs
trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeNode.cs
trunk/nhibernate/src/NHibernate/Impl/AbstractQueryImpl.cs
trunk/nhibernate/src/NHibernate/Impl/ExpressionQueryImpl.cs
trunk/nhibernate/src/NHibernate/Linq/ExpressionToHqlTranslationResults.cs
trunk/nhibernate/src/NHibernate/Linq/Expressions/NhExpressionType.cs
trunk/nhibernate/src/NHibernate/Linq/Expressions/NhNewExpression.cs
trunk/nhibernate/src/NHibernate/Linq/Functions/BaseHqlGeneratorForType.cs
trunk/nhibernate/src/NHibernate/Linq/Functions/FunctionRegistry.cs
trunk/nhibernate/src/NHibernate/Linq/Functions/IHqlGeneratorForType.cs
trunk/nhibernate/src/NHibernate/Linq/Functions/QueryableGenerator.cs
trunk/nhibernate/src/NHibernate/Linq/LinqExtensionMethods.cs
trunk/nhibernate/src/NHibernate/Linq/NhLinqExpression.cs
trunk/nhibernate/src/NHibernate/Linq/NhQueryProvider.cs
trunk/nhibernate/src/NHibernate/Linq/ReWriters/MergeAggregatingResultsRewriter.cs
trunk/nhibernate/src/NHibernate/Linq/ReWriters/RemoveUnnecessaryBodyOperators.cs
trunk/nhibernate/src/NHibernate/Linq/ResultOperators/ClientSideTransformOperator.cs
trunk/nhibernate/src/NHibernate/Linq/ResultTransformer.cs
trunk/nhibernate/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionTreeVisitor.cs
trunk/nhibernate/src/NHibernate/Linq/Visitors/NhExpressionTreeVisitor.cs
trunk/nhibernate/src/NHibernate/Linq/Visitors/Nominator.cs
trunk/nhibernate/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs
trunk/nhibernate/src/NHibernate/NHibernate.csproj
trunk/nhibernate/src/NHibernate.ByteCode.LinFu/NHibernate.ByteCode.LinFu.csproj
trunk/nhibernate/src/NHibernate.DomainModel/NHibernate.DomainModel.csproj
trunk/nhibernate/src/NHibernate.Test/App.config
trunk/nhibernate/src/NHibernate.Test/Linq/Entities/Northwind.cs
trunk/nhibernate/src/NHibernate.Test/Linq/LinqQuerySamples.cs
trunk/nhibernate/src/NHibernate.Test/Linq/ParameterisedQueries.cs
trunk/nhibernate/src/NHibernate.Test/NHibernate.Test.csproj
Added Paths:
-----------
trunk/nhibernate/src/NHibernate/Linq/Visitors/SelectClauseVisitor.cs
trunk/nhibernate/src/NHibernate.Test/Linq/CollectionAssert.cs
trunk/nhibernate/src/NHibernate.Test/Linq/Entities/UserDto.cs
trunk/nhibernate/src/NHibernate.Test/Linq/RegresstionTests.cs
trunk/nhibernate/src/NHibernate.Test/Linq/SelectionTests.cs
trunk/nhibernate/src/NHibernate.Test/Linq/WhereSubqueryTests.cs
trunk/nhibernate/src/NHibernate.Test/Linq/WhereTests.cs
Modified: trunk/nhibernate/src/NHibernate/Hql/Ast/ANTLR/PolymorphicQuerySourceDetector.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Hql/Ast/ANTLR/PolymorphicQuerySourceDetector.cs 2009-12-05 22:14:08 UTC (rev 4894)
+++ trunk/nhibernate/src/NHibernate/Hql/Ast/ANTLR/PolymorphicQuerySourceDetector.cs 2009-12-16 21:36:34 UTC (rev 4895)
@@ -27,20 +27,17 @@
var className = GetClassName(querySource);
var classType = _sessionFactoryHelper.GetImportedClass(className);
- if (classType != null)
- {
- AddImplementorsToMap(querySource, classType);
- }
+ AddImplementorsToMap(querySource, classType == null ? className : classType.FullName);
}
return _map;
}
- private void AddImplementorsToMap(IASTNode querySource, System.Type classType)
+ private void AddImplementorsToMap(IASTNode querySource, string className)
{
- var implementors = _sfi.GetImplementors(classType.FullName);
+ var implementors = _sfi.GetImplementors(className);
- if (implementors.Length == 1 && implementors[0] == classType.FullName)
+ if (implementors.Length == 1 && implementors[0] == className)
{
// No need to change things
return;
Modified: trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs 2009-12-05 22:14:08 UTC (rev 4894)
+++ trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs 2009-12-16 21:36:34 UTC (rev 4895)
@@ -341,11 +341,6 @@
return new HqlConcat(_factory, args);
}
- public HqlExpressionList ExpressionList()
- {
- return new HqlExpressionList(_factory);
- }
-
public HqlMethodCall MethodCall(string methodName, IEnumerable<HqlExpression> parameters)
{
return new HqlMethodCall(_factory, methodName, parameters);
@@ -370,6 +365,30 @@
{
return new HqlIsNotNull(_factory, lhs);
}
+
+ public HqlTreeNode ExpressionList(IEnumerable<HqlExpression> expressions)
+ {
+ return new HqlExpressionList(_factory, expressions);
+ }
+
+ public HqlStar Star()
+ {
+ return new HqlStar(_factory);
+ }
+
+ public HqlTrue True()
+ {
+ return new HqlTrue(_factory);
+ }
+
+ public HqlFalse False()
+ {
+ return new HqlFalse(_factory);
+ }
+
+ public HqlIn In(HqlExpression itemExpression, HqlTreeNode source)
+ {
+ return new HqlIn(_factory, itemExpression, source);
+ }
}
-
}
\ No newline at end of file
Modified: trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeNode.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeNode.cs 2009-12-05 22:14:08 UTC (rev 4894)
+++ trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeNode.cs 2009-12-16 21:36:34 UTC (rev 4895)
@@ -8,11 +8,13 @@
{
public class HqlTreeNode
{
+ public IASTFactory Factory { get; private set; }
private readonly IASTNode _node;
private readonly List<HqlTreeNode> _children;
protected HqlTreeNode(int type, string text, IASTFactory factory, IEnumerable<HqlTreeNode> children)
{
+ Factory = factory;
_node = factory.CreateNode(type, text);
_children = new List<HqlTreeNode>();
@@ -107,18 +109,27 @@
_node.AddChild(child.AstNode);
}
}
+ }
- public HqlExpression AsExpression()
+ public static class HqlTreeNodeExtensions
+ {
+ public static HqlExpression AsExpression(this HqlTreeNode node)
{
// TODO - nice error handling if cast fails
- return (HqlExpression) this;
+ return (HqlExpression)node;
}
- public virtual HqlBooleanExpression AsBooleanExpression()
+ public static HqlBooleanExpression AsBooleanExpression(this HqlTreeNode node)
{
+ if (node is HqlDot)
+ {
+ return new HqlBooleanDot(node.Factory, (HqlDot) node);
+ }
+
// TODO - nice error handling if cast fails
- return (HqlBooleanExpression)this;
+ return (HqlBooleanExpression)node;
}
+
}
public abstract class HqlStatement : HqlTreeNode
@@ -329,20 +340,10 @@
public class HqlDot : HqlExpression
{
- private readonly IASTFactory _factory;
-
public HqlDot(IASTFactory factory, HqlExpression lhs, HqlExpression rhs)
: base(HqlSqlWalker.DOT, ".", factory, lhs, rhs)
{
- _factory = factory;
}
-
- public override HqlBooleanExpression AsBooleanExpression()
- {
- // If we are of boolean type, then we can acts as boolean expression
- // TODO - implement type check
- return new HqlBooleanDot(_factory, this);
- }
}
public class HqlBooleanDot : HqlBooleanExpression
@@ -408,6 +409,23 @@
}
}
+ public class HqlFalse : HqlConstant
+ {
+ public HqlFalse(IASTFactory factory)
+ : base(factory, HqlSqlWalker.FALSE, "false")
+ {
+ }
+ }
+
+ public class HqlTrue : HqlConstant
+ {
+ public HqlTrue(IASTFactory factory)
+ : base(factory, HqlSqlWalker.TRUE, "true")
+ {
+ }
+ }
+
+
public class HqlNull : HqlConstant
{
public HqlNull(IASTFactory factory)
@@ -430,15 +448,23 @@
Descending
}
- public class HqlDirectionAscending : HqlStatement
+ public class HqlDirectionStatement : HqlStatement
{
+ public HqlDirectionStatement(int type, string text, IASTFactory factory)
+ : base(type, text, factory)
+ {
+ }
+ }
+
+ public class HqlDirectionAscending : HqlDirectionStatement
+ {
public HqlDirectionAscending(IASTFactory factory)
: base(HqlSqlWalker.ASCENDING, "asc", factory)
{
}
}
- public class HqlDirectionDescending : HqlStatement
+ public class HqlDirectionDescending : HqlDirectionStatement
{
public HqlDirectionDescending(IASTFactory factory)
: base(HqlSqlWalker.DESCENDING, "desc", factory)
@@ -733,4 +759,28 @@
{
}
}
+
+ public class HqlStar : HqlExpression
+ {
+ public HqlStar(IASTFactory factory) : base(HqlSqlWalker.ROW_STAR, "*", factory)
+ {
+ }
+ }
+
+ public class HqlIn : HqlBooleanExpression
+ {
+ public HqlIn(IASTFactory factory, HqlExpression itemExpression, HqlTreeNode source)
+ : base(HqlSqlWalker.IN, "in", factory, itemExpression)
+ {
+ AddChild(new HqlInList(factory, source));
+ }
+ }
+
+ public class HqlInList : HqlTreeNode
+ {
+ public HqlInList(IASTFactory factory, HqlTreeNode source)
+ : base(HqlSqlWalker.IN_LIST, "inlist", factory, source)
+ {
+ }
+ }
}
\ No newline at end of file
Modified: trunk/nhibernate/src/NHibernate/Impl/AbstractQueryImpl.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Impl/AbstractQueryImpl.cs 2009-12-05 22:14:08 UTC (rev 4894)
+++ trunk/nhibernate/src/NHibernate/Impl/AbstractQueryImpl.cs 2009-12-16 21:36:34 UTC (rev 4895)
@@ -27,7 +27,7 @@
private readonly ArrayList values = new ArrayList(4);
private readonly List<IType> types = new List<IType>(4);
private readonly Dictionary<string, TypedValue> namedParameters = new Dictionary<string, TypedValue>(4);
- private readonly Dictionary<string, TypedValue> namedParameterLists = new Dictionary<string, TypedValue>(4);
+ protected readonly Dictionary<string, TypedValue> namedParameterLists = new Dictionary<string, TypedValue>(4);
private bool cacheable;
private string cacheRegion;
private bool readOnly;
Modified: trunk/nhibernate/src/NHibernate/Impl/ExpressionQueryImpl.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Impl/ExpressionQueryImpl.cs 2009-12-05 22:14:08 UTC (rev 4894)
+++ trunk/nhibernate/src/NHibernate/Impl/ExpressionQueryImpl.cs 2009-12-16 21:36:34 UTC (rev 4895)
@@ -1,8 +1,17 @@
-using System;
+using System;
using System.Collections;
using System.Collections.Generic;
+using System.Linq;
+using System.Text;
using NHibernate.Engine;
using NHibernate.Engine.Query;
+using NHibernate.Engine.Query.Sql;
+using NHibernate.Hql.Ast.ANTLR;
+using NHibernate.Hql.Ast.ANTLR.Tree;
+using NHibernate.Hql.Ast.ANTLR.Util;
+using NHibernate.Hql.Classic;
+using NHibernate.Type;
+using NHibernate.Util;
namespace NHibernate.Impl
{
@@ -12,12 +21,12 @@
public IQueryExpression QueryExpression { get; private set; }
- public ExpressionQueryImpl(IQueryExpression queryExpression, ISessionImplementor session, ParameterMetadata parameterMetadata)
+ public ExpressionQueryImpl(IQueryExpression queryExpression, ISessionImplementor session, ParameterMetadata parameterMetadata)
: base(queryExpression.Key, FlushMode.Unspecified, session, parameterMetadata)
{
QueryExpression = queryExpression;
}
-
+
public override IQuery SetLockMode(string alias, LockMode lockMode)
{
_lockModes[alias] = lockMode;
@@ -51,7 +60,7 @@
Before();
try
{
- return Session.List(QueryExpression, GetQueryParameters(namedParams));
+ return Session.List(ExpandParameters(namedParams), GetQueryParameters(namedParams));
}
finally
{
@@ -59,6 +68,65 @@
}
}
+ /// <summary>
+ /// Warning: adds new parameters to the argument by side-effect, as well as mutating the query expression tree!
+ /// </summary>
+ protected IQueryExpression ExpandParameters(IDictionary<string, TypedValue> namedParamsCopy)
+ {
+ if (namedParameterLists.Count == 0)
+ {
+ // Short circuit straight out
+ return QueryExpression;
+ }
+
+ // Build a map from single parameters to lists
+ var map = new Dictionary<string, List<string>>();
+
+ foreach (var me in namedParameterLists)
+ {
+ string name = me.Key;
+ var vals = (ICollection) me.Value.Value;
+ var type = me.Value.Type;
+
+ if (vals.Count == 1)
+ {
+ // No expansion needed here
+ var iter = vals.GetEnumerator();
+ iter.MoveNext();
+ namedParamsCopy[name] = new TypedValue(type, iter.Current, Session.EntityMode);
+ continue;
+ }
+
+ var aliases = new List<string>();
+ var i = 0;
+ var isJpaPositionalParam = parameterMetadata.GetNamedParameterDescriptor(name).JpaStyle;
+
+ foreach (var obj in vals)
+ {
+ var alias = (isJpaPositionalParam ? 'x' + name : name + StringHelper.Underscore) + i++ + StringHelper.Underscore;
+ namedParamsCopy[alias] = new TypedValue(type, obj, Session.EntityMode);
+ aliases.Add(alias);
+ }
+
+ map.Add(name, aliases);
+
+ }
+
+ var newTree = ParameterExpander.Expand(QueryExpression.Translate(Session.Factory), map);
+ var key = new StringBuilder(QueryExpression.Key);
+
+ map.Aggregate(key, (sb, kvp) =>
+ {
+ sb.Append(' ');
+ sb.Append(kvp.Key);
+ sb.Append(':');
+ kvp.Value.Aggregate(sb, (sb2, str) => sb2.Append(str));
+ return sb;
+ });
+
+ return new ExpandedQueryExpression(QueryExpression, newTree, key.ToString());
+ }
+
public override void List(IList results)
{
throw new NotImplementedException();
@@ -69,4 +137,141 @@
throw new NotImplementedException();
}
}
+
+ internal class ExpandedQueryExpression : IQueryExpression
+ {
+ private readonly IASTNode _tree;
+
+ public ExpandedQueryExpression(IQueryExpression queryExpression, IASTNode tree, string key)
+ {
+ _tree = tree;
+ Key = key;
+ Type = queryExpression.Type;
+ ParameterDescriptors = queryExpression.ParameterDescriptors;
+ }
+
+ public IASTNode Translate(ISessionFactory sessionFactory)
+ {
+ return _tree;
+ }
+
+ public string Key { get; private set; }
+
+ public System.Type Type { get; private set; }
+
+ public IList<NamedParameterDescriptor> ParameterDescriptors { get; private set; }
+ }
+
+ internal class ParameterExpander
+ {
+ private readonly IASTNode _tree;
+ private readonly Dictionary<string, List<string>> _map;
+
+ private ParameterExpander(IASTNode tree, Dictionary<string, List<string>> map)
+ {
+ _tree = tree;
+ _map = map;
+ }
+
+ public static IASTNode Expand(IASTNode tree, Dictionary<string, List<string>> map)
+ {
+ var expander = new ParameterExpander(tree, map);
+
+ return expander.Expand();
+ }
+
+ private IASTNode Expand()
+ {
+ var parameters = ParameterDetector.LocateParameters(_tree, new HashSet<string>(_map.Keys));
+ var nodeMapping = new Dictionary<IASTNode, IEnumerable<IASTNode>>();
+
+ foreach (var param in parameters)
+ {
+ var paramName = param.GetChild(0);
+ var aliases = _map[paramName.Text];
+ var astAliases = new List<IASTNode>();
+
+ foreach (var alias in aliases)
+ {
+ var astAlias = param.DupNode();
+ var astAliasName = paramName.DupNode();
+ astAliasName.Text = alias;
+ astAlias.AddChild(astAliasName);
+
+ astAliases.Add(astAlias);
+ }
+
+ nodeMapping.Add(param, astAliases);
+ }
+
+ return DuplicateTree(_tree, nodeMapping);
+ }
+
+ private static IASTNode DuplicateTree(IASTNode ast, IDictionary<IASTNode, IEnumerable<IASTNode>> nodeMapping)
+ {
+ var thisNode = ast.DupNode();
+
+ foreach (var child in ast)
+ {
+ IEnumerable<IASTNode> candidate;
+
+ if (nodeMapping.TryGetValue(child, out candidate))
+ {
+ foreach (var replacement in candidate)
+ {
+ thisNode.AddChild(replacement);
+ }
+ }
+ else
+ {
+ thisNode.AddChild(DuplicateTree(child, nodeMapping));
+ }
+ }
+
+ return thisNode;
+ }
+ }
+
+ internal class ParameterDetector : IVisitationStrategy
+ {
+ private readonly IASTNode _tree;
+ private readonly HashSet<string> _parameterNames;
+ private readonly List<IASTNode> _nodes;
+
+ private ParameterDetector(IASTNode tree, HashSet<string> parameterNames)
+ {
+ _tree = tree;
+ _parameterNames = parameterNames;
+ _nodes = new List<IASTNode>();
+ }
+
+ public static IList<IASTNode> LocateParameters(IASTNode tree, HashSet<string> parameterNames)
+ {
+ var detector = new ParameterDetector(tree, parameterNames);
+
+ return detector.LocateParameters();
+ }
+
+ private IList<IASTNode> LocateParameters()
+ {
+ var nodeTraverser = new NodeTraverser(this);
+ nodeTraverser.TraverseDepthFirst(_tree);
+
+ return _nodes;
+ }
+
+ public void Visit(IASTNode node)
+ {
+ if ((node.Type == HqlSqlWalker.PARAM) || (node.Type == HqlSqlWalker.COLON))
+ {
+ var name = node.GetChild(0).Text;
+
+ if (_parameterNames.Contains(name))
+ {
+ _nodes.Add(node);
+ }
+ }
+ }
+
+ }
}
\ No newline at end of file
Modified: trunk/nhibernate/src/NHibernate/Linq/ExpressionToHqlTranslationResults.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/ExpressionToHqlTranslationResults.cs 2009-12-05 22:14:08 UTC (rev 4894)
+++ trunk/nhibernate/src/NHibernate/Linq/ExpressionToHqlTranslationResults.cs 2009-12-16 21:36:34 UTC (rev 4895)
@@ -9,17 +9,24 @@
{
public class ExpressionToHqlTranslationResults
{
- public HqlQuery Statement { get; private set; }
+ public HqlTreeNode Statement { get; private set; }
public ResultTransformer ResultTransformer { get; private set; }
- public List<Action<IQuery, IDictionary<string, Pair<object, IType>>>> AdditionalCriteria { get; private set; }
+ public Delegate PostExecuteTransformer { get; private set; }
+ public List<Action<IQuery, IDictionary<string, Tuple<object, IType>>>> AdditionalCriteria { get; private set; }
- public ExpressionToHqlTranslationResults(HqlQuery statement, IList<LambdaExpression> itemTransformers, IList<LambdaExpression> listTransformers, List<Action<IQuery, IDictionary<string, Pair<object, IType>>>> additionalCriteria)
+ public ExpressionToHqlTranslationResults(HqlTreeNode statement,
+ IList<LambdaExpression> itemTransformers,
+ IList<LambdaExpression> listTransformers,
+ IList<LambdaExpression> postExecuteTransformers,
+ List<Action<IQuery, IDictionary<string, Tuple<object, IType>>>> additionalCriteria)
{
Statement = statement;
- var itemTransformer = MergeLambdas(itemTransformers);
- var listTransformer = MergeLambdas(listTransformers);
+ PostExecuteTransformer = MergeLambdasAndCompile(postExecuteTransformers);
+ var itemTransformer = MergeLambdasAndCompile(itemTransformers);
+ var listTransformer = MergeLambdasAndCompile(listTransformers);
+
if (itemTransformer != null || listTransformer != null)
{
ResultTransformer = new ResultTransformer(itemTransformer, listTransformer);
@@ -28,7 +35,7 @@
AdditionalCriteria = additionalCriteria;
}
- private static LambdaExpression MergeLambdas(IList<LambdaExpression> transformations)
+ private static Delegate MergeLambdasAndCompile(IList<LambdaExpression> transformations)
{
if (transformations == null || transformations.Count == 0)
{
@@ -44,7 +51,7 @@
listTransformLambda = Expression.Lambda(invoked, listTransformLambda.Parameters.ToArray());
}
- return listTransformLambda;
+ return listTransformLambda.Compile();
}
}
}
\ No newline at end of file
Modified: trunk/nhibernate/src/NHibernate/Linq/Expressions/NhExpressionType.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/Expressions/NhExpressionType.cs 2009-12-05 22:14:08 UTC (rev 4894)
+++ trunk/nhibernate/src/NHibernate/Linq/Expressions/NhExpressionType.cs 2009-12-16 21:36:34 UTC (rev 4895)
@@ -8,6 +8,7 @@
Sum,
Count,
Distinct,
- New
+ New,
+ Star
}
}
\ No newline at end of file
Modified: trunk/nhibernate/src/NHibernate/Linq/Expressions/NhNewExpression.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/Expressions/NhNewExpression.cs 2009-12-05 22:14:08 UTC (rev 4894)
+++ trunk/nhibernate/src/NHibernate/Linq/Expressions/NhNewExpression.cs 2009-12-16 21:36:34 UTC (rev 4895)
@@ -1,4 +1,5 @@
-using System.Collections.Generic;
+using System;
+using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq.Expressions;
@@ -26,4 +27,17 @@
get { return _members; }
}
}
+
+ public class NhStarExpression : Expression
+ {
+ public NhStarExpression(Expression expression) : base((ExpressionType) NhExpressionType.Star, expression.Type)
+ {
+ Expression = expression;
+ }
+
+ public Expression Expression
+ {
+ get; private set;
+ }
+ }
}
\ No newline at end of file
Modified: trunk/nhibernate/src/NHibernate/Linq/Functions/BaseHqlGeneratorForType.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/Functions/BaseHqlGeneratorForType.cs 2009-12-05 22:14:08 UTC (rev 4894)
+++ trunk/nhibernate/src/NHibernate/Linq/Functions/BaseHqlGeneratorForType.cs 2009-12-16 21:36:34 UTC (rev 4895)
@@ -1,4 +1,6 @@
-using System.Collections.Generic;
+using System;
+using System.Collections.Generic;
+using System.Reflection;
using NHibernate.Linq.Visitors;
namespace NHibernate.Linq.Functions
@@ -26,5 +28,15 @@
}
}
}
+
+ public virtual bool SupportsMethod(MethodInfo method)
+ {
+ return false;
+ }
+
+ public virtual IHqlGeneratorForMethod GetMethodGenerator(MethodInfo method)
+ {
+ throw new NotSupportedException();
+ }
}
}
\ No newline at end of file
Modified: trunk/nhibernate/src/NHibernate/Linq/Functions/FunctionRegistry.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/Functions/FunctionRegistry.cs 2009-12-05 22:14:08 UTC (rev 4894)
+++ trunk/nhibernate/src/NHibernate/Linq/Functions/FunctionRegistry.cs 2009-12-16 21:36:34 UTC (rev 4895)
@@ -36,12 +36,14 @@
registry.Register(new QueryableGenerator());
registry.Register(new StringGenerator());
registry.Register(new DateTimeGenerator());
+ registry.Register(new ICollectionGenerator());
return registry;
}
private readonly Dictionary<MethodInfo, IHqlGeneratorForMethod> _registeredMethods = new Dictionary<MethodInfo, IHqlGeneratorForMethod>();
private readonly Dictionary<MemberInfo, IHqlGeneratorForProperty> _registeredProperties = new Dictionary<MemberInfo, IHqlGeneratorForProperty>();
+ private readonly List<IHqlGeneratorForType> _typeGenerators = new List<IHqlGeneratorForType>();
public IHqlGeneratorForMethod GetMethodGenerator(MethodInfo method)
{
@@ -58,14 +60,23 @@
}
// No method generator registered. Look to see if it's a standard LinqExtensionMethod
- var attr = (LinqExtensionMethodAttribute) method.GetCustomAttributes(typeof (LinqExtensionMethodAttribute), false)[0];
- if (attr != null)
+ var attr = method.GetCustomAttributes(typeof (LinqExtensionMethodAttribute), false);
+ if (attr.Length == 1)
{
// It is
// TODO - cache this? Is it worth it?
- return new HqlGeneratorForExtensionMethod(attr, method);
+ return new HqlGeneratorForExtensionMethod((LinqExtensionMethodAttribute) attr[0], method);
}
+ // Not that either. Let's query each type generator to see if it can handle it
+ foreach (var typeGenerator in _typeGenerators)
+ {
+ if (typeGenerator.SupportsMethod(method))
+ {
+ return typeGenerator.GetMethodGenerator(method);
+ }
+ }
+
throw new NotSupportedException(method.ToString());
}
@@ -94,6 +105,7 @@
private void Register(IHqlGeneratorForType typeMethodGenerator)
{
+ _typeGenerators.Add(typeMethodGenerator);
typeMethodGenerator.Register(this);
}
}
Modified: trunk/nhibernate/src/NHibernate/Linq/Functions/IHqlGeneratorForType.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/Functions/IHqlGeneratorForType.cs 2009-12-05 22:14:08 UTC (rev 4894)
+++ trunk/nhibernate/src/NHibernate/Linq/Functions/IHqlGeneratorForType.cs 2009-12-16 21:36:34 UTC (rev 4895)
@@ -1,7 +1,11 @@
-namespace NHibernate.Linq.Functions
+using System.Reflection;
+
+namespace NHibernate.Linq.Functions
{
public interface IHqlGeneratorForType
{
void Register(FunctionRegistry functionRegistry);
+ bool SupportsMethod(MethodInfo method);
+ IHqlGeneratorForMethod GetMethodGenerator(MethodInfo method);
}
}
\ No newline at end of file
Modified: trunk/nhibernate/src/NHibernate/Linq/Functions/QueryableGenerator.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/Functions/QueryableGenerator.cs 2009-12-05 22:14:08 UTC (rev 4894)
+++ trunk/nhibernate/src/NHibernate/Linq/Functions/QueryableGenerator.cs 2009-12-16 21:36:34 UTC (rev 4895)
@@ -1,4 +1,6 @@
-using System.Collections.ObjectModel;
+using System;
+using System.Collections.Generic;
+using System.Collections.ObjectModel;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
@@ -126,4 +128,74 @@
}
}
}
+
+ public class ICollectionGenerator : BaseHqlGeneratorForType
+ {
+ public ICollectionGenerator()
+ {
+ // TODO - could use reflection
+ MethodRegistry.Add(new ContainsGenerator());
+ }
+
+ public override bool SupportsMethod(MethodInfo method)
+ {
+ var declaringType = method.DeclaringType;
+
+ if (declaringType.IsGenericType)
+ {
+ if (declaringType.GetGenericTypeDefinition() == typeof(ICollection<>) ||
+ declaringType.GetGenericTypeDefinition() == typeof(IEnumerable<>))
+ {
+ if (method.Name == "Contains")
+ {
+ return true;
+ }
+ }
+ }
+
+ return false;
+ }
+
+ public override IHqlGeneratorForMethod GetMethodGenerator(MethodInfo method)
+ {
+ // TODO - ick
+ if (method.Name == "Contains")
+ {
+ return new ContainsGenerator();
+ }
+
+ throw new NotSupportedException(method.Name);
+ }
+
+ class ContainsGenerator : BaseHqlGeneratorForMethod
+ {
+ public ContainsGenerator()
+ {
+ SupportedMethods = new MethodInfo[0];
+ }
+
+ public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection<Expression> arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor)
+ {
+ // TODO - alias generator
+ var alias = treeBuilder.Alias("x");
+
+ var param = Expression.Parameter(targetObject.Type, "x");
+ var where = treeBuilder.Where(visitor.Visit(Expression.Lambda(
+ Expression.Equal(param, arguments[0]), param))
+ .AsExpression());
+
+ return treeBuilder.Exists(
+ treeBuilder.Query(
+ treeBuilder.SelectFrom(
+ treeBuilder.From(
+ treeBuilder.Range(
+ visitor.Visit(targetObject),
+ alias)
+ )
+ ),
+ where));
+ }
+ }
+
+ }
}
\ No newline at end of file
Modified: trunk/nhibernate/src/NHibernate/Linq/LinqExtensionMethods.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/LinqExtensionMethods.cs 2009-12-05 22:14:08 UTC (rev 4894)
+++ trunk/nhibernate/src/NHibernate/Linq/LinqExtensionMethods.cs 2009-12-16 21:36:34 UTC (rev 4895)
@@ -24,10 +24,19 @@
return (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>));
}
+ public static bool IsNullableOrReference(this System.Type type)
+ {
+ return !type.IsValueType || (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>));
+ }
+
public static System.Type NullableOf(this System.Type type)
{
return type.GetGenericArguments()[0];
}
+ public static T As<T>(this object source)
+ {
+ return (T) source;
+ }
}
}
\ No newline at end of file
Modified: trunk/nhibernate/src/NHibernate/Linq/NhLinqExpression.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/NhLinqExpression.cs 2009-12-05 22:14:08 UTC (rev 4894)
+++ trunk/nhibernate/src/NHibernate/Linq/NhLinqExpression.cs 2009-12-16 21:36:34 UTC (rev 4895)
@@ -9,7 +9,6 @@
using NHibernate.Type;
using Remotion.Data.Linq;
using Remotion.Data.Linq.Clauses;
-using Remotion.Data.Linq.Clauses.StreamedData;
using Remotion.Data.Linq.Parsing.ExpressionTreeVisitors;
using Remotion.Data.Linq.Parsing.Structure;
using Remotion.Data.Linq.Parsing.Structure.IntermediateModel;
@@ -26,14 +25,15 @@
public NhLinqExpressionReturnType ReturnType { get; private set; }
- public IDictionary<string, Pair<object, IType>> ParameterValuesByName { get; private set; }
+ public IDictionary<string, Tuple<object, IType>> ParameterValuesByName { get; private set; }
public ExpressionToHqlTranslationResults ExpressionToHqlTranslationResults { get; private set; }
private readonly Expression _expression;
private readonly IDictionary<ConstantExpression, NamedParameter> _constantToParameterMap;
+ private IASTNode _astNode;
- public NhLinqExpression(Expression expression)
+ public NhLinqExpression(Expression expression)
{
_expression = PartialEvaluatingExpressionTreeVisitor.EvaluateIndependentSubtrees(expression);
@@ -43,8 +43,8 @@
ParameterValuesByName = _constantToParameterMap.Values.ToDictionary(p => p.Name,
p =>
- new Pair<object, IType>
- {Left = p.Value, Right = p.Type});
+ new Tuple<object, IType>
+ {First = p.Value, Second = p.Type});
Key = ExpressionKeyVisitor.Visit(_expression, _constantToParameterMap);
@@ -62,16 +62,22 @@
public IASTNode Translate(ISessionFactory sessionFactory)
{
- var requiredHqlParameters = new List<NamedParameterDescriptor>();
+ //if (_astNode == null)
+ {
+ var requiredHqlParameters = new List<NamedParameterDescriptor>();
- // TODO - can we cache any of this?
- var queryModel = NhRelinqQueryParser.Parse(_expression);
+ // TODO - can we cache any of this?
+ var queryModel = NhRelinqQueryParser.Parse(_expression);
- ExpressionToHqlTranslationResults = QueryModelVisitor.GenerateHqlQuery(queryModel, _constantToParameterMap, requiredHqlParameters);
+ ExpressionToHqlTranslationResults = QueryModelVisitor.GenerateHqlQuery(queryModel,
+ _constantToParameterMap,
+ requiredHqlParameters);
- ParameterDescriptors = requiredHqlParameters.AsReadOnly();
+ ParameterDescriptors = requiredHqlParameters.AsReadOnly();
+ _astNode = ExpressionToHqlTranslationResults.Statement.AstNode;
+ }
- return ExpressionToHqlTranslationResults.Statement.AstNode;
+ return _astNode;
}
}
@@ -89,6 +95,14 @@
MethodCallExpressionNodeTypeRegistry.GetRegisterableMethodDefinition(ReflectionHelper.GetMethod(() => Queryable.Aggregate<object, object>(null, null, null)))
},
typeof (AggregateExpressionNode));
+
+ MethodCallRegistry.Register(
+ new []
+ {
+ MethodCallExpressionNodeTypeRegistry.GetRegisterableMethodDefinition(ReflectionHelper.GetMethod((List<object> l) => l.Contains(null))),
+
+ },
+ typeof(ContainsExpressionNode));
}
public static QueryModel Parse(Expression expression)
@@ -179,15 +193,5 @@
Accumulator = accumulator;
OptionalSelector = optionalSelector;
}
-
- public override IStreamedDataInfo GetOutputDataInfo(IStreamedDataInfo inputInfo)
- {
- throw new NotImplementedException();
- }
-
- public override ResultOperatorBase Clone(CloneContext cloneContext)
- {
- throw new NotImplementedException();
- }
}
}
\ No newline at end of file
Modified: trunk/nhibernate/src/NHibernate/Linq/NhQueryProvider.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/NhQueryProvider.cs 2009-12-05 22:14:08 UTC (rev 4894)
+++ trunk/nhibernate/src/NHibernate/Linq/NhQueryProvider.cs 2009-12-16 21:36:34 UTC (rev 4895)
@@ -1,7 +1,9 @@
using System;
+using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
+using System.Reflection;
using NHibernate.Impl;
using NHibernate.Type;
@@ -22,12 +24,26 @@
var query = _session.CreateQuery(nhLinqExpression);
+ var nhQuery = query.As<ExpressionQueryImpl>().QueryExpression.As<NhLinqExpression>();
+
SetParameters(query, nhLinqExpression.ParameterValuesByName);
- SetResultTransformerAndAdditionalCriteria(query, nhLinqExpression.ParameterValuesByName);
+ SetResultTransformerAndAdditionalCriteria(query, nhQuery, nhLinqExpression.ParameterValuesByName);
var results = query.List();
- if (nhLinqExpression.ReturnType == NhLinqExpressionReturnType.Sequence)
+ if (nhQuery.ExpressionToHqlTranslationResults.PostExecuteTransformer != null)
+ {
+ try
+ {
+ return nhQuery.ExpressionToHqlTranslationResults.PostExecuteTransformer.DynamicInvoke(results.AsQueryable());
+ }
+ catch (TargetInvocationException e)
+ {
+ throw e.InnerException;
+ }
+ }
+
+ if (nhLinqExpression.ReturnType == NhLinqExpressionReturnType.Sequence)
{
return results.AsQueryable();
}
@@ -52,28 +68,39 @@
return new NhQueryable<T>(this, expression);
}
- static void SetParameters(IQuery query, IDictionary<string, Pair<object, IType>> parameters)
+ static void SetParameters(IQuery query, IDictionary<string, Tuple<object, IType>> parameters)
{
foreach (var parameterName in query.NamedParameters)
{
var param = parameters[parameterName];
- if (param.Left == null)
+
+ if (param.First == null)
{
- query.SetParameter(parameterName, param.Left, param.Right);
+ if (typeof(ICollection).IsAssignableFrom(param.Second.ReturnedClass))
+ {
+ query.SetParameterList(parameterName, null, param.Second);
+ }
+ else
+ {
+ query.SetParameter(parameterName, null, param.Second);
+ }
}
else
{
- query.SetParameter(parameterName, param.Left);
+ if (param.First is ICollection)
+ {
+ query.SetParameterList(parameterName, (ICollection) param.First);
+ }
+ else
+ {
+ query.SetParameter(parameterName, param.First);
+ }
}
}
}
- public void SetResultTransformerAndAdditionalCriteria(IQuery query, IDictionary<string, Pair<object, IType>> parameters)
+ public void SetResultTransformerAndAdditionalCriteria(IQuery query, NhLinqExpression nhExpression, IDictionary<string, Tuple<object, IType>> parameters)
{
- var queryImpl = (ExpressionQueryImpl) query;
-
- var nhExpression = (NhLinqExpression) queryImpl.QueryExpression;
-
query.SetResultTransformer(nhExpression.ExpressionToHqlTranslationResults.ResultTransformer);
foreach (var criteria in nhExpression.ExpressionToHqlTranslationResults.AdditionalCriteria)
@@ -83,9 +110,17 @@
}
}
- public class Pair<TLeft, TRight>
+ public class Tuple<T1, T2>
{
- public TLeft Left { get; set; }
- public TRight Right { get; set; }
+ public T1 First { get; set; }
+ public T2 Second { get; set; }
+
}
+
+ public class Tuple<T1, T2, T3>
+ {
+ public T1 First { get; set; }
+ public T2 Second { get; set; }
+ public T3 Third { get; set; }
+ }
}
\ No newline at end of file
Modified: trunk/nhibernate/src/NHibernate/Linq/ReWriters/MergeAggregatingResultsRewriter.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/ReWriters/MergeAggregatingResultsRewriter.cs 2009-12-05 22:14:08 UTC (rev 4894)
+++ trunk/nhibernate/src/NHibernate/Linq/ReWriters/MergeAggregatingResultsRewriter.cs 2009-12-16 21:36:34 UTC (rev 4895)
@@ -53,29 +53,40 @@
}
else if (resultOperator is CountResultOperator)
{
- queryModel.SelectClause.Selector = new NhShortCountExpression(queryModel.SelectClause.Selector);
+ queryModel.SelectClause.Selector = new NhShortCountExpression(new NhStarExpression(queryModel.SelectClause.Selector));
queryModel.ResultOperators.Remove(resultOperator);
}
else if (resultOperator is LongCountResultOperator)
{
- queryModel.SelectClause.Selector = new NhLongCountExpression(queryModel.SelectClause.Selector);
+ queryModel.SelectClause.Selector = new NhLongCountExpression(new NhStarExpression(queryModel.SelectClause.Selector));
queryModel.ResultOperators.Remove(resultOperator);
}
base.VisitResultOperator(resultOperator, queryModel, index);
}
-
+
public override void VisitSelectClause(SelectClause selectClause, QueryModel queryModel)
{
- selectClause.TransformExpressions(s => new MergeAggregatingResultsInExpressionRewriter().Visit(s));
+ selectClause.TransformExpressions(MergeAggregatingResultsInExpressionRewriter.Rewrite);
}
+
+ public override void VisitWhereClause(WhereClause whereClause, QueryModel queryModel, int index)
+ {
+ whereClause.TransformExpressions(MergeAggregatingResultsInExpressionRewriter.Rewrite);
+ }
}
-
+
internal class MergeAggregatingResultsInExpressionRewriter : NhExpressionTreeVisitor
{
- public Expression Visit(Expression expression)
+ private MergeAggregatingResultsInExpressionRewriter()
+ {
+ }
+
+ public static Expression Rewrite(Expression expression)
{
- return VisitExpression(expression);
+ var visitor = new MergeAggregatingResultsInExpressionRewriter();
+
+ return visitor.VisitExpression(expression);
}
protected override Expression VisitSubQueryExpression(SubQueryExpression expression)
@@ -83,7 +94,7 @@
MergeAggregatingResultsRewriter.ReWrite(expression.QueryModel);
return expression;
}
-
+
protected override Expression VisitMethodCallExpression(MethodCallExpression m)
{
if (m.Method.DeclaringType == typeof(Queryable) ||
@@ -94,36 +105,43 @@
{
case "Count":
return CreateAggregate(m.Arguments[0], (LambdaExpression)m.Arguments[1],
- e => new NhShortCountExpression(e));
+ e => new NhShortCountExpression(e),
+ () => new CountResultOperator());
case "Min":
return CreateAggregate(m.Arguments[0], (LambdaExpression) m.Arguments[1],
- e => new NhMinExpression(e));
+ e => new NhMinExpression(e),
+ () => new MinResultOperator());
case "Max":
return CreateAggregate(m.Arguments[0], (LambdaExpression)m.Arguments[1],
- e => new NhMaxExpression(e));
+ e => new NhMaxExpression(e),
+ () => new MaxResultOperator());
case "Sum":
return CreateAggregate(m.Arguments[0], (LambdaExpression)m.Arguments[1],
- e => new NhSumExpression(e));
+ e => new NhSumExpression(e),
+ () => new SumResultOperator());
case "Average":
return CreateAggregate(m.Arguments[0], (LambdaExpression)m.Arguments[1],
- e => new NhAverageExpression(e));
+ e => new NhAverageExpression(e),
+ () => new AverageResultOperator());
}
}
return base.VisitMethodCallExpression(m);
}
- private Expression CreateAggregate(Expression fromClauseExpression, LambdaExpression body, Func<Expression,Expression> factory)
+ private Expression CreateAggregate(Expression fromClauseExpression, LambdaExpression body, Func<Expression,Expression> aggregateFactory, Func<ResultOperatorBase> resultOperatorFactory)
{
+ // TODO - need generated name here
var fromClause = new MainFromClause("x2", body.Parameters[0].Type, fromClauseExpression);
var selectClause = body.Body;
selectClause = ReplacingExpressionTreeVisitor.Replace(body.Parameters[0],
new QuerySourceReferenceExpression(
fromClause), selectClause);
var queryModel = new QueryModel(fromClause,
- new SelectClause(factory(selectClause)));
+ new SelectClause(aggregateFactory(selectClause)));
- queryModel.ResultOperators.Add(new AverageResultOperator());
+ // TODO - this sucks, but we use it to get the Type of the SubQueryExpression correct
+ queryModel.ResultOperators.Add(resultOperatorFactory());
var subQuery = new SubQueryExpression(queryModel);
Modified: trunk/nhibernate/src/NHibernate/Linq/ReWriters/RemoveUnnecessaryBodyOperators.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/ReWriters/RemoveUnnecessaryBodyOperators.cs 2009-12-05 22:14:08 UTC (rev 4894)
+++ trunk/nhibernate/src/NHibernate/Linq/ReWriters/RemoveUnnecessaryBodyOperators.cs 2009-12-16 21:36:34 UTC (rev 4895)
@@ -1,8 +1,11 @@
using System;
using System.Linq;
+using System.Linq.Expressions;
+using NHibernate.Linq.Expressions;
using Remotion.Data.Linq;
using Remotion.Data.Linq.Clauses;
using Remotion.Data.Linq.Clauses.ResultOperators;
+using Remotion.Data.Linq.Parsing;
namespace NHibernate.Linq.ReWriters
{
Modified: trunk/nhibernate/src/NHibernate/Linq/ResultOperators/ClientSideTransformOperator.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/ResultOperators/ClientSideTransformOperator.cs 2009-12-05 22:14:08 UTC (rev 4894)
+++ trunk/nhibernate/src/NHibernate/Linq/ResultOperators/ClientSideTransformOperator.cs 2009-12-16 21:36:34 UTC (rev 4895)
@@ -13,7 +13,7 @@
public override IStreamedDataInfo GetOutputDataInfo(IStreamedDataInfo inputInfo)
{
- throw new NotImplementedException();
+ return null;
}
public override ResultOperatorBase Clone(CloneContext cloneContext)
Modified: trunk/nhibernate/src/NHibernate/Linq/ResultTransformer.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/ResultTransformer.cs 2009-12-05 22:14:08 UTC (rev 4894)
+++ trunk/nhibernate/src/NHibernate/Linq/ResultTransformer.cs 2009-12-16 21:36:34 UTC (rev 4895)
@@ -13,16 +13,10 @@
private readonly Delegate _listTransformation;
private readonly Delegate _itemTransformation;
- public ResultTransformer(LambdaExpression itemTransformation, LambdaExpression listTransformation)
+ public ResultTransformer(Delegate itemTransformation, Delegate listTransformation)
{
- if (itemTransformation != null)
- {
- _itemTransformation = itemTransformation.Compile();
- }
- if (listTransformation != null)
- {
- _listTransformation = listTransformation.Compile();
- }
+ _itemTransformation = itemTransformation;
+ _listTransformation = listTransformation;
}
public object TransformTuple(object[] tuple, string[] aliases)
@@ -39,9 +33,9 @@
object transformResult = collection;
- if (collection.Count > 0)
+ //if (collection.Count > 0)
{
- if (collection[0] is object[])
+ if (collection.Count > 0 && collection[0] is object[])
{
if ( ((object[])collection[0]).Length != 1)
{
Modified: trunk/nhibernate/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionTreeVisitor.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionTreeVisitor.cs 2009-12-05 22:14:08 UTC (rev 4894)
+++ trunk/nhibernate/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionTreeVisitor.cs 2009-12-16 21:36:34 UTC (rev 4895)
@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
+using System.Linq;
using System.Linq.Expressions;
using NHibernate.Engine.Query;
using NHibernate.Hql.Ast;
@@ -16,6 +17,13 @@
private readonly IList<NamedParameterDescriptor> _requiredHqlParameters;
static private readonly FunctionRegistry FunctionRegistry = FunctionRegistry.Initialise();
+ public static HqlTreeNode Visit(Expression expression, IDictionary<ConstantExpression, NamedParameter> parameters, IList<NamedParameterDescriptor> requiredHqlParameters)
+ {
+ var visitor = new HqlGeneratorExpressionTreeVisitor(parameters, requiredHqlParameters);
+
+ return visitor.VisitExpression(expression);
+ }
+
public HqlGeneratorExpressionTreeVisitor(IDictionary<ConstantExpression, NamedParameter> parameters, IList<NamedParameterDescriptor> requiredHqlParameters)
{
_parameters = parameters;
@@ -23,7 +31,7 @@
_hqlTreeBuilder = new HqlTreeBuilder();
}
- public virtual HqlTreeNode Visit(Expression expression)
+ public HqlTreeNode Visit(Expression expression)
{
return VisitExpression(expression);
}
@@ -122,6 +130,8 @@
return VisitNhCount((NhCountExpression)expression);
case NhExpressionType.Distinct:
return VisitNhDistinct((NhDistinctExpression)expression);
+ case NhExpressionType.Star:
+ return VisitNhStar((NhStarExpression) expression);
//case NhExpressionType.New:
// return VisitNhNew((NhNewExpression)expression);
}
@@ -130,6 +140,11 @@
}
}
+ protected HqlTreeNode VisitNhStar(NhStarExpression expression)
+ {
+ return _hqlTreeBuilder.Star();
+ }
+
private HqlTreeNode VisitInvocationExpression(InvocationExpression expression)
{
return VisitExpression(expression.Expression);
@@ -206,7 +221,7 @@
}
// Also check for nullability
- if (expression.Left.Type.IsNullable() || expression.Right.Type.IsNullable())
+ if (expression.Left.Type.IsNullableOrReference() || expression.Right.Type.IsNullableOrReference())
{
// TODO - yuck. This clone is needed because the AST tree nodes are not immutable,
// and sharing nodes between multiple branches will cause issues in the hqlSqlWalker phase -
@@ -247,7 +262,7 @@
}
// Also check for nullability
- if (expression.Left.Type.IsNullable() || expression.Right.Type.IsNullable())
+ if (expression.Left.Type.IsNullableOrReference() || expression.Right.Type.IsNullableOrReference())
{
var lhs2 = VisitExpression(expression.Left).AsExpression();
var rhs2 = VisitExpression(expression.Right).AsExpression();
Modified: trunk/nhibernate/src/NHibernate/Linq/Visitors/NhExpressionTreeVisitor.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/Visitors/NhExpressionTreeVisitor.cs 2009-12-05 22:14:08 UTC (rev 4894)
+++ trunk/nhibernate/src/NHibernate/Linq/Visitors/NhExpressionTreeVisitor.cs 2009-12-16 21:36:34 UTC (rev 4895)
@@ -25,13 +25,23 @@
return VisitNhAggregate((NhAggregatedExpression)expression);
case NhExpressionType.New:
return VisitNhNew((NhNewExpression) expression);
+ case NhExpressionType.Star:
+ return VisitNhStar((NhStarExpression) expression);
+
}
return base.VisitExpression(expression);
}
- protected virtual Expression VisitNhNew(NhNewExpression expression)
+ protected virtual Expression VisitNhStar(NhStarExpression expression)
{
+ var newExpression = VisitExpression(expression.Expression);
+
+ return newExpression != expression.Expression ? new NhStarExpression(newExpression) : expression;
+ }
+
+ protected virtual Expression VisitNhNew(NhNewExpression expression)
+ {
var arguments = VisitExpressionList(expression.Arguments);
return arguments != expression.Arguments ? new NhNewExpression(expression.Members, arguments) : expression;
@@ -60,42 +70,42 @@
protected virtual Expression VisitNhDistinct(NhDistinctExpression expression)
{
- Expression nx = base.VisitExpression(expression.Expression);
+ Expression nx = VisitExpression(expression.Expression);
return nx != expression.Expression ? new NhDistinctExpression(nx) : expression;
}
protected virtual Expression VisitNhCount(NhCountExpression expression)
{
- Expression nx = base.VisitExpression(expression.Expression);
+ Expression nx = VisitExpression(expression.Expression);
return nx != expression.Expression ? new NhShortCountExpression(nx) : expression;
}
protected virtual Expression VisitNhSum(NhSumExpression expression)
{
- Expression nx = base.VisitExpression(expression.Expression);
+ Expression nx = VisitExpression(expression.Expression);
return nx != expression.Expression ? new NhSumExpression(nx) : expression;
}
protected virtual Expression VisitNhMax(NhMaxExpression expression)
{
- Expression nx = base.VisitExpression(expression.Expression);
+ Expression nx = VisitExpression(expression.Expression);
return nx != expression.Expression ? new NhMaxExpression(nx) : expression;
}
protected virtual Expression VisitNhMin(NhMinExpression expression)
{
- Expression nx = base.VisitExpression(expression.Expression);
+ Expression nx = VisitExpression(expression.Expression);
return nx != expression.Expression ? new NhMinExpression(nx) : expression;
}
protected virtual Expression VisitNhAverage(NhAverageExpression expression)
{
- Expression nx = base.VisitExpression(expression.Expression);
+ Expression nx = VisitExpression(expression.Expression);
return nx != expression.Expression ? new NhAverageExpression(nx) : expression;
}
Modified: trunk/nhibernate/src/NHibernate/Linq/Visitors/Nominator.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/Visitors/Nominator.cs 2009-12-05 22:14:08 UTC (rev 4894)
+++ trunk/nhibernate/src/NHibernate/Linq/Visitors/Nominator.cs 2009-12-16 21:36:34 UTC (rev 4895)
@@ -9,18 +9,21 @@
/// </summary>
class Nominator : NhExpressionTreeVisitor
{
- readonly Func<Expression, bool> _fnIsCandidate;
- HashSet<Expression> _candidates;
- bool _cannotBeCandidate;
+ private readonly Func<Expression, bool> _fnIsCandidate;
+ private readonly Func<Expression, bool> _fnIsCandidateShortcut;
+ private HashSet<Expression> _candidates;
+ private bool _canBeCandidate;
- internal Nominator(Func<Expression, bool> fnIsCandidate)
+ internal Nominator(Func<Expression, bool> fnIsCandidate, Func<Expression, bool> fnIsCandidateShortcut)
{
_fnIsCandidate = fnIsCandidate;
+ _fnIsCandidateShortcut = fnIsCandidateShortcut;
}
internal HashSet<Expression> Nominate(Expression expression)
{
_candidates = new HashSet<Expression>();
+ _canBeCandidate = true;
VisitExpression(expression);
return _candidates;
}
@@ -29,12 +32,18 @@
{
if (expression != null)
{
- bool saveCannotBeEvaluated = _cannotBeCandidate;
- _cannotBeCandidate = false;
+ bool saveCanBeCandidate = _canBeCandidate;
+ _canBeCandidate = true;
+ if (_fnIsCandidateShortcut(expression))
+ {
+ _candidates.Add(expression);
+ return expression;
+ }
+
base.VisitExpression(expression);
- if (!_cannotBeCandidate)
+ if (_canBeCandidate)
{
if (_fnIsCandidate(expression))
{
@@ -42,11 +51,11 @@
}
else
{
- _cannotBeCandidate = true;
+ _canBeCandidate = false;
}
}
- _cannotBeCandidate |= saveCannotBeEvaluated;
+ _canBeCandidate = _canBeCandidate & saveCanBeCandidate;
}
return expression;
Modified: trunk/nhibernate/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs
===================================================================
--- trunk/nhibernate/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs 2009-12-05 22:14:08 UTC (rev 4894)
+++ trunk/nhibernate/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs 2009-12-16 21:36:34 UTC (rev 4895)
@@ -3,6 +3,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
+using System.Reflection;
using NHibernate.Engine.Query;
using NHibernate.Hql.Ast;
using NHibernate.Linq.GroupBy;
@@ -14,6 +15,7 @@
using Remotion.Data.Linq.Clauses;
using Remotion.Data.Linq.Clauses.Expressions;
using Remotion.Data.Linq.Clauses.ResultOperators;
+using Remotion.Data.Linq.Clauses.StreamedData;
namespace NHibernate.Linq.Visitors
{
@@ -50,144 +52,115 @@
private readonly HqlTreeBuilder _hqlTreeBuilder;
- private readonly List<Action<IQuery, IDictionary<string, Pair<object, IType>>>> _additionalCriteria = new List<Action<IQuery, IDictionary<string, Pair<object, IType>>>>();
+ private readonly List<Action<IQuery, IDictionary<string, Tuple<object, IType>>>> _additionalCriteria = new List<Action<IQuery, IDictionary<string, Tuple<object, IType>>>>();
private readonly List<LambdaExpression> _listTransformers = new List<LambdaExpression>();
...
[truncated message content] |