From: <ste...@us...> - 2010-02-24 15:52:33
|
Revision: 4945 http://nhibernate.svn.sourceforge.net/nhibernate/?rev=4945&view=rev Author: steverstrong Date: 2010-02-24 15:52:13 +0000 (Wed, 24 Feb 2010) Log Message: ----------- Linq provider now adds correct left join when accessing entity references in projections. Also some refactoring to the ResultOperator processing Modified Paths: -------------- trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeNode.cs trunk/nhibernate/src/NHibernate/Linq/GroupBy/NonAggregatingGroupByRewriter.cs trunk/nhibernate/src/NHibernate/Linq/LinqExtensionMethods.cs trunk/nhibernate/src/NHibernate/Linq/NhLinqExpression.cs trunk/nhibernate/src/NHibernate/Linq/ReWriters/MergeAggregatingResultsRewriter.cs trunk/nhibernate/src/NHibernate/Linq/ResultTransformer.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/EqualityHqlGenerator.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionTreeVisitor.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/SelectClauseVisitor.cs trunk/nhibernate/src/NHibernate/NHibernate.csproj trunk/nhibernate/src/NHibernate.Test/Linq/FunctionTests.cs trunk/nhibernate/src/NHibernate.Test/Linq/LinqTestCase.cs trunk/nhibernate/src/NHibernate.Test/Linq/MethodCallTests.cs trunk/nhibernate/src/NHibernate.Test/Linq/MiscellaneousTextFixture.cs trunk/nhibernate/src/NHibernate.Test/Linq/QueryReuseTests.cs trunk/nhibernate/src/NHibernate.Test/Linq/RegresstionTests.cs trunk/nhibernate/src/NHibernate.Test/Linq/SelectionTests.cs trunk/nhibernate/src/NHibernate.Test/Linq/WhereTests.cs trunk/nhibernate/src/NHibernate.Test/ProjectionFixtures/Fixture.cs Added Paths: ----------- trunk/nhibernate/src/NHibernate/Linq/Clauses/ trunk/nhibernate/src/NHibernate/Linq/Clauses/LeftJoinClause.cs trunk/nhibernate/src/NHibernate/Linq/ReWriters/AddLeftJoinsReWriter.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/LeftJoinDetector.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/NameGenerator.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ trunk/nhibernate/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/IResultOperatorProcessor.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAggregate.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAll.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAny.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessClientSideSelect.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessContains.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessFirst.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessFirstOrSingleBase.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessGroupBy.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessNonAggregatingGroupBy.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessResultOperatorReturn.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessSingle.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessSkip.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessTake.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ResultOperatorMap.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ResultOperatorProcessor.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ResultOperatorProcessorBase.cs trunk/nhibernate/src/NHibernate/Linq/Visitors/VisitorParameters.cs Property Changed: ---------------- trunk/nhibernate/src/ Property changes on: trunk/nhibernate/src ___________________________________________________________________ Modified: svn:ignore - *.suo CloverSrc _ReSharper* *.resharperoptions *.resharper.user CloverBuild Ankh.Load *.resharper ConsoleTest _UpgradeReport_Files NHibernate.userprefs NHibernate.usertasks UpgradeLog.XML UpgradeLog2.XML UpgradeLog3.XML UpgradeLog4.XML UpgradeLog5.XML UpgradeLog6.XML UpgradeLog7.XML UpgradeLog8.XML UpgradeLog9.XML NHibernate.sln.proj NHibernate.sln.AssemblySurfaceCache.user NHibernate.sln.cache + *.suo CloverSrc _ReSharper* *.resharperoptions *.resharper.user CloverBuild Ankh.Load *.resharper ConsoleTest _UpgradeReport_Files NHibernate.userprefs NHibernate.usertasks UpgradeLog.XML UpgradeLog2.XML UpgradeLog3.XML UpgradeLog4.XML UpgradeLog5.XML UpgradeLog6.XML UpgradeLog7.XML UpgradeLog8.XML UpgradeLog9.XML NHibernate.sln.proj NHibernate.sln.AssemblySurfaceCache.user NHibernate.sln.cache .git Modified: trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs 2010-02-16 22:50:19 UTC (rev 4944) +++ trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs 2010-02-24 15:52:13 UTC (rev 4945) @@ -390,5 +390,10 @@ { return new HqlIn(_factory, itemExpression, source); } + + public HqlTreeNode LeftJoin(HqlExpression expression, HqlAlias @alias) + { + return new HqlLeftJoin(_factory, expression, @alias); + } } } \ No newline at end of file Modified: trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeNode.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeNode.cs 2010-02-16 22:50:19 UTC (rev 4944) +++ trunk/nhibernate/src/NHibernate/Hql/Ast/HqlTreeNode.cs 2010-02-24 15:52:13 UTC (rev 4945) @@ -667,6 +667,28 @@ } } + public class HqlLeftJoin : HqlTreeNode + { + public HqlLeftJoin(IASTFactory factory, HqlExpression expression, HqlAlias @alias) : base(HqlSqlWalker.JOIN, "join", factory, new HqlLeft(factory), expression, @alias) + { + } + } + + public class HqlFetch : HqlTreeNode + { + public HqlFetch(IASTFactory factory) : base(HqlSqlWalker.FETCH, "fetch", factory) + { + } + } + + public class HqlLeft : HqlTreeNode + { + public HqlLeft(IASTFactory factory) + : base(HqlSqlWalker.LEFT, "left", factory) + { + } + } + public class HqlAny : HqlBooleanExpression { public HqlAny(IASTFactory factory) : base(HqlSqlWalker.ANY, "any", factory) Added: trunk/nhibernate/src/NHibernate/Linq/Clauses/LeftJoinClause.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/Clauses/LeftJoinClause.cs (rev 0) +++ trunk/nhibernate/src/NHibernate/Linq/Clauses/LeftJoinClause.cs 2010-02-24 15:52:13 UTC (rev 4945) @@ -0,0 +1,12 @@ +using System.Linq.Expressions; +using Remotion.Data.Linq.Clauses; + +namespace NHibernate.Linq.Visitors +{ + public class LeftJoinClause : AdditionalFromClause + { + public LeftJoinClause(string itemName, System.Type itemType, Expression fromExpression) : base(itemName, itemType, fromExpression) + { + } + } +} \ No newline at end of file Modified: trunk/nhibernate/src/NHibernate/Linq/GroupBy/NonAggregatingGroupByRewriter.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/GroupBy/NonAggregatingGroupByRewriter.cs 2010-02-16 22:50:19 UTC (rev 4944) +++ trunk/nhibernate/src/NHibernate/Linq/GroupBy/NonAggregatingGroupByRewriter.cs 2010-02-24 15:52:13 UTC (rev 4945) @@ -96,7 +96,7 @@ } } - internal class ClientSideSelect : ClientSideTransformOperator + public class ClientSideSelect : ClientSideTransformOperator { public LambdaExpression SelectClause { get; private set; } Modified: trunk/nhibernate/src/NHibernate/Linq/LinqExtensionMethods.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/LinqExtensionMethods.cs 2010-02-16 22:50:19 UTC (rev 4944) +++ trunk/nhibernate/src/NHibernate/Linq/LinqExtensionMethods.cs 2010-02-24 15:52:13 UTC (rev 4945) @@ -11,7 +11,7 @@ return new NhQueryable<T>(session); } - public static void ForEach<T>(this IEnumerable<T> query, System.Action<T> method) + public static void ForEach<T>(this IEnumerable<T> query, Action<T> method) { foreach (T item in query) { @@ -26,7 +26,7 @@ public static bool IsNullableOrReference(this System.Type type) { - return !type.IsValueType || (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>)); + return !type.IsValueType || type.IsNullable(); } public static System.Type NullableOf(this System.Type type) @@ -34,6 +34,16 @@ return type.GetGenericArguments()[0]; } + public static bool IsPrimitive(this System.Type type) + { + return (type.IsValueType || type.IsNullable() || type == typeof (string)); + } + + public static bool IsNonPrimitive(this System.Type type) + { + return !type.IsPrimitive(); + } + public static T As<T>(this object source) { return (T) source; Modified: trunk/nhibernate/src/NHibernate/Linq/NhLinqExpression.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/NhLinqExpression.cs 2010-02-16 22:50:19 UTC (rev 4944) +++ trunk/nhibernate/src/NHibernate/Linq/NhLinqExpression.cs 2010-02-24 15:52:13 UTC (rev 4945) @@ -3,7 +3,6 @@ using System.Linq; using System.Linq.Expressions; using NHibernate.Engine.Query; -using NHibernate.Hql.Ast; using NHibernate.Hql.Ast.ANTLR.Tree; using NHibernate.Linq.ResultOperators; using NHibernate.Linq.Visitors; @@ -71,8 +70,11 @@ var queryModel = NhRelinqQueryParser.Parse(_expression); ExpressionToHqlTranslationResults = QueryModelVisitor.GenerateHqlQuery(queryModel, - _constantToParameterMap, - requiredHqlParameters, true); + new VisitorParameters( + sessionFactory, + _constantToParameterMap, + requiredHqlParameters), + true); ParameterDescriptors = requiredHqlParameters.AsReadOnly(); _astNode = ExpressionToHqlTranslationResults.Statement.AstNode; Added: trunk/nhibernate/src/NHibernate/Linq/ReWriters/AddLeftJoinsReWriter.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/ReWriters/AddLeftJoinsReWriter.cs (rev 0) +++ trunk/nhibernate/src/NHibernate/Linq/ReWriters/AddLeftJoinsReWriter.cs 2010-02-24 15:52:13 UTC (rev 4945) @@ -0,0 +1,78 @@ +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using Remotion.Data.Linq; +using Remotion.Data.Linq.Clauses; + +namespace NHibernate.Linq.Visitors +{ + public class AddLeftJoinsReWriter : QueryModelVisitorBase + { + private readonly ISessionFactory _sessionFactory; + + private AddLeftJoinsReWriter(ISessionFactory sessionFactory) + { + _sessionFactory = sessionFactory; + } + + public static void ReWrite(QueryModel queryModel, ISessionFactory sessionFactory) + { + var rewriter = new AddLeftJoinsReWriter(sessionFactory); + + rewriter.VisitQueryModel(queryModel); + } + + public override void VisitSelectClause(SelectClause selectClause, QueryModel queryModel) + { + var joins = LeftJoinDetector.Detect(selectClause.Selector, new NameGenerator(queryModel), _sessionFactory); + + if (joins.Joins.Count > 0) + { + selectClause.Selector = joins.Selector; + + queryModel.TransformExpressions(e => ExpressionSwapper.Swap(e, joins.ExpressionMap)); + + foreach (var join in joins.Joins) + { + queryModel.BodyClauses.Add(join); + } + } + } + } + + public class ExpressionSwapper : NhExpressionTreeVisitor + { + private readonly Dictionary<Expression, Expression> _expressionMap; + + private ExpressionSwapper(Dictionary<Expression, Expression> expressionMap) + { + _expressionMap = expressionMap; + } + + public static Expression Swap(Expression expression, Dictionary<Expression, Expression> expressionMap) + { + var swapper = new ExpressionSwapper(expressionMap); + + return swapper.VisitExpression(expression); + } + + protected override Expression VisitExpression(Expression expression) + { + if (expression == null) + { + return null; + } + + Expression replacement; + + if (_expressionMap.TryGetValue(expression, out replacement)) + { + return replacement; + } + else + { + return base.VisitExpression(expression); + } + } + } +} \ No newline at end of file Modified: trunk/nhibernate/src/NHibernate/Linq/ReWriters/MergeAggregatingResultsRewriter.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/ReWriters/MergeAggregatingResultsRewriter.cs 2010-02-16 22:50:19 UTC (rev 4944) +++ trunk/nhibernate/src/NHibernate/Linq/ReWriters/MergeAggregatingResultsRewriter.cs 2010-02-24 15:52:13 UTC (rev 4945) @@ -67,24 +67,27 @@ public override void VisitSelectClause(SelectClause selectClause, QueryModel queryModel) { - selectClause.TransformExpressions(MergeAggregatingResultsInExpressionRewriter.Rewrite); + selectClause.TransformExpressions(e => MergeAggregatingResultsInExpressionRewriter.Rewrite(e, new NameGenerator(queryModel))); } public override void VisitWhereClause(WhereClause whereClause, QueryModel queryModel, int index) { - whereClause.TransformExpressions(MergeAggregatingResultsInExpressionRewriter.Rewrite); + whereClause.TransformExpressions(e => MergeAggregatingResultsInExpressionRewriter.Rewrite(e, new NameGenerator(queryModel))); } } internal class MergeAggregatingResultsInExpressionRewriter : NhExpressionTreeVisitor { - private MergeAggregatingResultsInExpressionRewriter() + private readonly NameGenerator _nameGenerator; + + private MergeAggregatingResultsInExpressionRewriter(NameGenerator nameGenerator) { + _nameGenerator = nameGenerator; } - public static Expression Rewrite(Expression expression) + public static Expression Rewrite(Expression expression, NameGenerator nameGenerator) { - var visitor = new MergeAggregatingResultsInExpressionRewriter(); + var visitor = new MergeAggregatingResultsInExpressionRewriter(nameGenerator); return visitor.VisitExpression(expression); } @@ -131,8 +134,7 @@ private Expression CreateAggregate(Expression fromClauseExpression, LambdaExpression body, Func<Expression,Expression> aggregateFactory, Func<ResultOperatorBase> resultOperatorFactory) { - // TODO - need generated name here - var fromClause = new MainFromClause("x2", body.Parameters[0].Type, fromClauseExpression); + var fromClause = new MainFromClause(_nameGenerator.GetNewName(), body.Parameters[0].Type, fromClauseExpression); var selectClause = body.Body; selectClause = ReplacingExpressionTreeVisitor.Replace(body.Parameters[0], new QuerySourceReferenceExpression( Modified: trunk/nhibernate/src/NHibernate/Linq/ResultTransformer.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/ResultTransformer.cs 2010-02-16 22:50:19 UTC (rev 4944) +++ trunk/nhibernate/src/NHibernate/Linq/ResultTransformer.cs 2010-02-24 15:52:13 UTC (rev 4945) @@ -1,7 +1,6 @@ using System; using System.Collections; using System.Linq; -using System.Linq.Expressions; using NHibernate.Transform; namespace NHibernate.Linq Modified: trunk/nhibernate/src/NHibernate/Linq/Visitors/EqualityHqlGenerator.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/Visitors/EqualityHqlGenerator.cs 2010-02-16 22:50:19 UTC (rev 4944) +++ trunk/nhibernate/src/NHibernate/Linq/Visitors/EqualityHqlGenerator.cs 2010-02-24 15:52:13 UTC (rev 4945) @@ -1,7 +1,5 @@ using System; -using System.Collections.Generic; using System.Linq.Expressions; -using NHibernate.Engine.Query; using NHibernate.Hql.Ast; namespace NHibernate.Linq.Visitors @@ -12,29 +10,22 @@ public class EqualityHqlGenerator { private readonly HqlTreeBuilder _hqlTreeBuilder; - private readonly IDictionary<ConstantExpression, NamedParameter> _parameters; - private readonly IList<NamedParameterDescriptor> _requiredHqlParameters; + private readonly VisitorParameters _parameters; - public EqualityHqlGenerator(IDictionary<ConstantExpression, NamedParameter> parameters, IList<NamedParameterDescriptor> requiredHqlParameters) + public EqualityHqlGenerator(VisitorParameters parameters) { _parameters = parameters; - _requiredHqlParameters = requiredHqlParameters; _hqlTreeBuilder = new HqlTreeBuilder(); } public HqlBooleanExpression Visit(Expression innerKeySelector, Expression outerKeySelector) { - if (innerKeySelector is NewExpression && outerKeySelector is NewExpression) - { - return VisitNew((NewExpression)innerKeySelector, (NewExpression)outerKeySelector); - } - else - { - return GenerateEqualityNode(innerKeySelector, outerKeySelector); - } + return innerKeySelector is NewExpression && outerKeySelector is NewExpression + ? VisitNew((NewExpression) innerKeySelector, (NewExpression) outerKeySelector) + : GenerateEqualityNode(innerKeySelector, outerKeySelector); } - private HqlBooleanExpression VisitNew(NewExpression innerKeySelector, NewExpression outerKeySelector) + private HqlBooleanExpression VisitNew(NewExpression innerKeySelector, NewExpression outerKeySelector) { if (innerKeySelector.Arguments.Count != outerKeySelector.Arguments.Count) { @@ -59,8 +50,8 @@ private HqlEquality GenerateEqualityNode(Expression leftExpr, Expression rightExpr) { // TODO - why two visitors? Can't we just reuse? - var left = new HqlGeneratorExpressionTreeVisitor(_parameters, _requiredHqlParameters); - var right = new HqlGeneratorExpressionTreeVisitor(_parameters, _requiredHqlParameters); + var left = new HqlGeneratorExpressionTreeVisitor(_parameters); + var right = new HqlGeneratorExpressionTreeVisitor(_parameters); return _hqlTreeBuilder.Equality(left.Visit(leftExpr).AsExpression(), right.Visit(rightExpr).AsExpression()); } Modified: trunk/nhibernate/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionTreeVisitor.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionTreeVisitor.cs 2010-02-16 22:50:19 UTC (rev 4944) +++ trunk/nhibernate/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionTreeVisitor.cs 2010-02-24 15:52:13 UTC (rev 4945) @@ -1,6 +1,4 @@ using System; -using System.Collections.Generic; -using System.Linq; using System.Linq.Expressions; using NHibernate.Engine.Query; using NHibernate.Hql.Ast; @@ -13,21 +11,19 @@ public class HqlGeneratorExpressionTreeVisitor : IHqlExpressionVisitor { private readonly HqlTreeBuilder _hqlTreeBuilder; - private readonly IDictionary<ConstantExpression, NamedParameter> _parameters; - private readonly IList<NamedParameterDescriptor> _requiredHqlParameters; + private readonly VisitorParameters _parameters; static private readonly FunctionRegistry FunctionRegistry = FunctionRegistry.Initialise(); - public static HqlTreeNode Visit(Expression expression, IDictionary<ConstantExpression, NamedParameter> parameters, IList<NamedParameterDescriptor> requiredHqlParameters) + public static HqlTreeNode Visit(Expression expression, VisitorParameters parameters) { - var visitor = new HqlGeneratorExpressionTreeVisitor(parameters, requiredHqlParameters); + var visitor = new HqlGeneratorExpressionTreeVisitor(parameters); return visitor.VisitExpression(expression); } - public HqlGeneratorExpressionTreeVisitor(IDictionary<ConstantExpression, NamedParameter> parameters, IList<NamedParameterDescriptor> requiredHqlParameters) + public HqlGeneratorExpressionTreeVisitor(VisitorParameters parameters) { - _parameters = parameters; - _requiredHqlParameters = requiredHqlParameters; + _parameters = parameters; _hqlTreeBuilder = new HqlTreeBuilder(); } @@ -184,7 +180,7 @@ protected HqlTreeNode VisitNhDistinct(NhDistinctExpression expression) { - var visitor = new HqlGeneratorExpressionTreeVisitor(_parameters, _requiredHqlParameters); + var visitor = new HqlGeneratorExpressionTreeVisitor(_parameters); return _hqlTreeBuilder.DistinctHolder( _hqlTreeBuilder.Distinct(), @@ -367,9 +363,9 @@ NamedParameter namedParameter; - if (_parameters.TryGetValue(expression, out namedParameter)) + if (_parameters.ConstantToParameterMap.TryGetValue(expression, out namedParameter)) { - _requiredHqlParameters.Add(new NamedParameterDescriptor(namedParameter.Name, null, new[] { _requiredHqlParameters.Count + 1 }, false)); + _parameters.RequiredHqlParameters.Add(new NamedParameterDescriptor(namedParameter.Name, null, new[] { _parameters.RequiredHqlParameters.Count + 1 }, false)); if (namedParameter.Value is bool) { @@ -415,7 +411,7 @@ protected HqlTreeNode VisitSubQueryExpression(SubQueryExpression expression) { - ExpressionToHqlTranslationResults query = QueryModelVisitor.GenerateHqlQuery(expression.QueryModel, _parameters, _requiredHqlParameters, false); + ExpressionToHqlTranslationResults query = QueryModelVisitor.GenerateHqlQuery(expression.QueryModel, _parameters, false); return query.Statement; } Added: trunk/nhibernate/src/NHibernate/Linq/Visitors/LeftJoinDetector.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/Visitors/LeftJoinDetector.cs (rev 0) +++ trunk/nhibernate/src/NHibernate/Linq/Visitors/LeftJoinDetector.cs 2010-02-24 15:52:13 UTC (rev 4945) @@ -0,0 +1,74 @@ +using System.Collections.Generic; +using System.Linq.Expressions; +using Remotion.Data.Linq.Clauses.Expressions; + +namespace NHibernate.Linq.Visitors +{ + public class LeftJoinDetector : NhExpressionTreeVisitor + { + private readonly NameGenerator _nameGenerator; + private readonly ISessionFactory _sessionFactory; + private readonly Dictionary<string, LeftJoinClause> _joins = new Dictionary<string, LeftJoinClause>(); + private readonly Dictionary<Expression, Expression> _expressionMap = new Dictionary<Expression, Expression>(); + + private LeftJoinDetector(NameGenerator nameGenerator, ISessionFactory sessionFactory) + { + _nameGenerator = nameGenerator; + _sessionFactory = sessionFactory; + } + + public static Results Detect(Expression selector, NameGenerator nameGenerator, ISessionFactory sessionFactory) + { + var detector = new LeftJoinDetector(nameGenerator, sessionFactory); + + var newSelector = detector.VisitExpression(selector); + + return new Results(newSelector, detector._joins.Values, detector._expressionMap); + } + + protected override Expression VisitMemberExpression(MemberExpression expression) + { + if (expression.Type.IsNonPrimitive() && IsEntity(expression.Type)) + { + var newExpr = AddJoin(expression); + _expressionMap.Add(expression, newExpr); + return newExpr; + } + + return base.VisitMemberExpression(expression); + } + + private bool IsEntity(System.Type type) + { + return _sessionFactory.GetClassMetadata(type) != null; + } + + private Expression AddJoin(MemberExpression expression) + { + string key = ExpressionKeyVisitor.Visit(expression, null); + LeftJoinClause join; + + if (!_joins.TryGetValue(key, out join)) + { + join = new LeftJoinClause(_nameGenerator.GetNewName(), expression.Type, expression); + _joins.Add(key, join); + } + + return new QuerySourceReferenceExpression(join); + } + + public class Results + { + public Expression Selector { get; private set; } + public ICollection<LeftJoinClause> Joins { get; private set; } + public Dictionary<Expression, Expression> ExpressionMap { get; private set; } + + public Results(Expression selector, ICollection<LeftJoinClause> joins, Dictionary<Expression, Expression> expressionMap) + { + Selector = selector; + Joins = joins; + ExpressionMap = expressionMap; + } + } + } +} \ No newline at end of file Added: trunk/nhibernate/src/NHibernate/Linq/Visitors/NameGenerator.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/Visitors/NameGenerator.cs (rev 0) +++ trunk/nhibernate/src/NHibernate/Linq/Visitors/NameGenerator.cs 2010-02-24 15:52:13 UTC (rev 4945) @@ -0,0 +1,19 @@ +using Remotion.Data.Linq; + +namespace NHibernate.Linq.Visitors +{ + public class NameGenerator + { + private readonly QueryModel _model; + + public NameGenerator(QueryModel model) + { + _model = model; + } + + public string GetNewName() + { + return _model.GetNewName("_"); + } + } +} \ No newline at end of file Modified: trunk/nhibernate/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs 2010-02-16 22:50:19 UTC (rev 4944) +++ trunk/nhibernate/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs 2010-02-24 15:52:13 UTC (rev 4945) @@ -1,15 +1,13 @@ using System; -using System.Collections; using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; -using System.Reflection; -using NHibernate.Engine.Query; using NHibernate.Hql.Ast; using NHibernate.Linq.GroupBy; using NHibernate.Linq.GroupJoin; using NHibernate.Linq.ResultOperators; using NHibernate.Linq.ReWriters; +using NHibernate.Linq.Visitors.ResultOperatorProcessors; using NHibernate.Type; using Remotion.Data.Linq; using Remotion.Data.Linq.Clauses; @@ -21,7 +19,7 @@ { public class QueryModelVisitor : QueryModelVisitorBase { - public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel queryModel, IDictionary<ConstantExpression, NamedParameter> parameters, IList<NamedParameterDescriptor> requiredHqlParameters, bool root) + public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel queryModel, VisitorParameters parameters, bool root) { // Remove unnecessary body operators RemoveUnnecessaryBodyOperators.ReWrite(queryModel); @@ -44,47 +42,71 @@ // Flatten pointless subqueries QueryReferenceExpressionFlattener.ReWrite(queryModel); - var visitor = new QueryModelVisitor(parameters, requiredHqlParameters, root); - visitor.VisitQueryModel(queryModel); + // Add left joins for references + AddLeftJoinsReWriter.ReWrite(queryModel, parameters.SessionFactory); + var visitor = new QueryModelVisitor(parameters, root, queryModel); + visitor.Visit(); + return visitor.GetTranslation(); } - private readonly HqlTreeBuilder _hqlTreeBuilder; + private static readonly ResultOperatorMap ResultOperatorMap; private readonly List<Action<IQuery, IDictionary<string, Tuple<object, IType>>>> _additionalCriteria = new List<Action<IQuery, IDictionary<string, Tuple<object, IType>>>>(); private readonly List<LambdaExpression> _listTransformers = new List<LambdaExpression>(); private readonly List<LambdaExpression> _itemTransformers = new List<LambdaExpression>(); private readonly List<LambdaExpression> _postExecuteTransformers = new List<LambdaExpression>(); - - private IStreamedDataInfo _previousEvaluationType; - private IStreamedDataInfo _currentEvaluationType; - - private readonly IDictionary<ConstantExpression, NamedParameter> _parameters; - private readonly IList<NamedParameterDescriptor> _requiredHqlParameters; private readonly bool _root; private bool _serverSide = true; - private HqlTreeNode _treeNode; - private System.Type _resultType; + public HqlTreeNode Root { get; private set; } + public VisitorParameters VisitorParameters { get; private set; } + public IStreamedDataInfo CurrentEvaluationType { get; private set; } + public IStreamedDataInfo PreviousEvaluationType { get; private set; } + public HqlTreeBuilder TreeBuilder { get; private set; } + public QueryModel Model { get; private set; } - private QueryModelVisitor(IDictionary<ConstantExpression, NamedParameter> parameters, IList<NamedParameterDescriptor> requiredHqlParameters, bool root) + static QueryModelVisitor() + { + ResultOperatorMap = new ResultOperatorMap(); + + ResultOperatorMap.Add<AggregateResultOperator, ProcessAggregate>(); + ResultOperatorMap.Add<FirstResultOperator, ProcessFirst>(); + ResultOperatorMap.Add<TakeResultOperator, ProcessTake>(); + ResultOperatorMap.Add<SkipResultOperator, ProcessSkip>(); + ResultOperatorMap.Add<GroupResultOperator, ProcessGroupBy>(); + ResultOperatorMap.Add<SingleResultOperator, ProcessSingle>(); + ResultOperatorMap.Add<ContainsResultOperator, ProcessContains>(); + ResultOperatorMap.Add<NonAggregatingGroupBy, ProcessNonAggregatingGroupBy>(); + ResultOperatorMap.Add<ClientSideSelect, ProcessClientSideSelect>(); + ResultOperatorMap.Add<AnyResultOperator, ProcessAny>(); + ResultOperatorMap.Add<AllResultOperator, ProcessAll>(); + } + + private QueryModelVisitor(VisitorParameters visitorParameters, bool root, QueryModel queryModel) { - _parameters = parameters; - _requiredHqlParameters = requiredHqlParameters; + VisitorParameters = visitorParameters; + Model = queryModel; _root = root; - _hqlTreeBuilder = new HqlTreeBuilder(); - _treeNode = _hqlTreeBuilder.Query(_hqlTreeBuilder.SelectFrom(_hqlTreeBuilder.From())); + TreeBuilder = new HqlTreeBuilder(); + Root = TreeBuilder.Query(TreeBuilder.SelectFrom(TreeBuilder.From())); } - public ExpressionToHqlTranslationResults GetTranslation() + private void Visit() + { + VisitQueryModel(Model); + } + + + private ExpressionToHqlTranslationResults GetTranslation() { if (_root) { DetectOuterExists(); } - return new ExpressionToHqlTranslationResults(_treeNode, + return new ExpressionToHqlTranslationResults(Root, _itemTransformers, _listTransformers, _postExecuteTransformers, @@ -93,9 +115,9 @@ private void DetectOuterExists() { - if (_treeNode is HqlExists) + if (Root is HqlExists) { - _treeNode = _treeNode.Children.First(); + Root = Root.Children.First(); _additionalCriteria.Add((q, p) => q.SetMaxResults(1)); @@ -107,9 +129,9 @@ public override void VisitMainFromClause(MainFromClause fromClause, QueryModel queryModel) { - AddFromClause(_hqlTreeBuilder.Range( - HqlGeneratorExpressionTreeVisitor.Visit(fromClause.FromExpression, _parameters, _requiredHqlParameters), - _hqlTreeBuilder.Alias(fromClause.ItemName))); + AddFromClause(TreeBuilder.Range( + HqlGeneratorExpressionTreeVisitor.Visit(fromClause.FromExpression, VisitorParameters), + TreeBuilder.Alias(fromClause.ItemName))); base.VisitMainFromClause(fromClause, queryModel); } @@ -117,56 +139,55 @@ private void AddWhereClause(HqlBooleanExpression where) { - var currentWhere = _treeNode.NodesPreOrder.Where(n => n is HqlWhere).FirstOrDefault(); + var currentWhere = Root.NodesPreOrder.Where(n => n is HqlWhere).FirstOrDefault(); if (currentWhere == null) { - currentWhere = _hqlTreeBuilder.Where(where); - _treeNode.As<HqlQuery>().AddChild(currentWhere); + currentWhere = TreeBuilder.Where(where); + Root.As<HqlQuery>().AddChild(currentWhere); } else { var currentClause = (HqlBooleanExpression)currentWhere.Children.Single(); currentWhere.ClearChildren(); - currentWhere.AddChild(_hqlTreeBuilder.BooleanAnd(currentClause, where)); + currentWhere.AddChild(TreeBuilder.BooleanAnd(currentClause, where)); } } private void AddFromClause(HqlTreeNode from) { - _treeNode.NodesPreOrder.Where(n => n is HqlFrom).First().AddChild(from); + Root.NodesPreOrder.Where(n => n is HqlFrom).First().AddChild(from); } private void AddSelectClause(HqlTreeNode select) { - _treeNode.NodesPreOrder.Where(n => n is HqlSelectFrom).First().AddChild(select); + Root.NodesPreOrder.Where(n => n is HqlSelectFrom).First().AddChild(select); } private void AddGroupByClause(HqlGroupBy groupBy) { - _treeNode.As<HqlQuery>().AddChild(groupBy); + Root.As<HqlQuery>().AddChild(groupBy); } private void AddOrderByClause(HqlExpression orderBy, HqlDirectionStatement direction) { - var orderByRoot = _treeNode.NodesPreOrder.Where(n => n is HqlOrderBy).FirstOrDefault(); + var orderByRoot = Root.NodesPreOrder.Where(n => n is HqlOrderBy).FirstOrDefault(); if (orderByRoot == null) { - orderByRoot = _hqlTreeBuilder.OrderBy(); - _treeNode.As<HqlQuery>().AddChild(orderByRoot); + orderByRoot = TreeBuilder.OrderBy(); + Root.As<HqlQuery>().AddChild(orderByRoot); } orderByRoot.AddChild(orderBy); orderByRoot.AddChild(direction); } - public override void VisitResultOperator(ResultOperatorBase resultOperator, QueryModel queryModel, int index) { - _previousEvaluationType = _currentEvaluationType; - _currentEvaluationType = resultOperator.GetOutputDataInfo(_previousEvaluationType); + PreviousEvaluationType = CurrentEvaluationType; + CurrentEvaluationType = resultOperator.GetOutputDataInfo(PreviousEvaluationType); if (resultOperator is ClientSideTransformOperator) { @@ -180,270 +201,34 @@ } } - if (resultOperator is FirstResultOperator) - { - ProcessFirstOperator((FirstResultOperator) resultOperator); - } - else if (resultOperator is TakeResultOperator) - { - ProcessTakeOperator((TakeResultOperator)resultOperator); - } - else if (resultOperator is SkipResultOperator) - { - ProcessSkipOperator((SkipResultOperator)resultOperator); - } - else if (resultOperator is GroupResultOperator) - { - ProcessGroupByOperator((GroupResultOperator)resultOperator); - } - else if (resultOperator is SingleResultOperator) - { - ProcessSingleOperator((SingleResultOperator) resultOperator); - } - else if (resultOperator is ContainsResultOperator) - { - ProcessContainsOperator((ContainsResultOperator) resultOperator); - } - else if (resultOperator is NonAggregatingGroupBy) - { - ProcessNonAggregatingGroupBy((NonAggregatingGroupBy)resultOperator, queryModel); - } - else if (resultOperator is ClientSideSelect) - { - ProcessClientSideSelect((ClientSideSelect)resultOperator); - } - else if (resultOperator is AggregateResultOperator) - { - ProcessAggregateOperator((AggregateResultOperator)resultOperator); - } - else if (resultOperator is AnyResultOperator) - { - ProcessAnyOperator((AnyResultOperator) resultOperator); - } - else if (resultOperator is AllResultOperator) - { - ProcessAllOperator((AllResultOperator) resultOperator); - } - else - { - throw new NotSupportedException(string.Format("The {0} result operator is not current supported", - resultOperator.GetType().Name)); - } - } + var results = ResultOperatorMap.Process(resultOperator, this); - private void ProcessAllOperator(AllResultOperator resultOperator) - { - AddWhereClause(_hqlTreeBuilder.BooleanNot( - HqlGeneratorExpressionTreeVisitor.Visit(resultOperator.Predicate, _parameters, - _requiredHqlParameters).AsBooleanExpression())); - - _treeNode = _hqlTreeBuilder.BooleanNot(_hqlTreeBuilder.Exists((HqlQuery)_treeNode)); - } - - private void ProcessAnyOperator(AnyResultOperator anyOperator) - { - _treeNode = _hqlTreeBuilder.Exists((HqlQuery) _treeNode); - } - - private void ProcessContainsOperator(ContainsResultOperator resultOperator) - { - var itemExpression = - HqlGeneratorExpressionTreeVisitor.Visit(resultOperator.Item, _parameters, _requiredHqlParameters) - .AsExpression(); - - var from = GetFromRangeClause(); - var source = from.Children.First(); - - if (source is HqlParameter) + if (results.AdditionalCriteria != null) { - // This is an "in" style statement - _treeNode = _hqlTreeBuilder.In(itemExpression, source); - + _additionalCriteria.Add(results.AdditionalCriteria); } - else + if (results.GroupBy != null) { - // This is an "exists" style statement - AddWhereClause(_hqlTreeBuilder.Equality( - _hqlTreeBuilder.Ident(GetFromAlias().AstNode.Text), - itemExpression)); - - _treeNode = _hqlTreeBuilder.Exists((HqlQuery)_treeNode); + AddGroupByClause(results.GroupBy); } - } - - private HqlAlias GetFromAlias() - { - return _treeNode.NodesPreOrder.Single(n => n is HqlRange).Children.Single(n => n is HqlAlias) as HqlAlias; - } - - private HqlRange GetFromRangeClause() - { - return _treeNode.NodesPreOrder.Single(n => n is HqlRange).As<HqlRange>(); - } - - private void ProcessAggregateOperator(AggregateResultOperator resultOperator) - { - var inputType = resultOperator.Accumulator.Parameters[1].Type; - var accumulatorType = resultOperator.Accumulator.Parameters[0].Type; - var inputList = Expression.Parameter(typeof(IEnumerable<>).MakeGenericType(typeof(object)), "inputList"); - - var castToItem = EnumerableHelper.GetMethod("Cast", new[] { typeof(IEnumerable) }, new[] { inputType }); - var castToItemExpr = Expression.Call(castToItem, inputList); - - MethodCallExpression call; - - if (resultOperator.ParseInfo.ParsedExpression.Arguments.Count == 2) + if (results.ListTransformer != null) { - var aggregate = ReflectionHelper.GetMethod(() => Enumerable.Aggregate<object>(null, null)); - aggregate = aggregate.GetGenericMethodDefinition().MakeGenericMethod(inputType); - - call = Expression.Call( - aggregate, - castToItemExpr, - resultOperator.Accumulator - ); - + _listTransformers.Add(results.ListTransformer); } - else if (resultOperator.ParseInfo.ParsedExpression.Arguments.Count == 3) + if (results.PostExecuteTransformer != null) { - var aggregate = ReflectionHelper.GetMethod(() => Enumerable.Aggregate<object, object>(null, null, null)); - aggregate = aggregate.GetGenericMethodDefinition().MakeGenericMethod(inputType, accumulatorType); - - call = Expression.Call( - aggregate, - castToItemExpr, - resultOperator.OptionalSeed, - resultOperator.Accumulator - ); + _postExecuteTransformers.Add(results.PostExecuteTransformer); } - else + if (results.WhereClause != null) { - var selectorType = resultOperator.OptionalSelector.Type.GetGenericArguments()[2]; - var aggregate = ReflectionHelper.GetMethod(() => Enumerable.Aggregate<object, object, object>(null, null, null, null)); - aggregate = aggregate.GetGenericMethodDefinition().MakeGenericMethod(inputType, accumulatorType, selectorType); - - call = Expression.Call( - aggregate, - castToItemExpr, - resultOperator.OptionalSeed, - resultOperator.Accumulator, - resultOperator.OptionalSelector - ); + AddWhereClause(results.WhereClause); } - - _listTransformers.Add(Expression.Lambda(call, inputList)); - } - - private void ProcessClientSideSelect(ClientSideSelect resultOperator) - { - var inputType = resultOperator.SelectClause.Parameters[0].Type; - var outputType = resultOperator.SelectClause.Type.GetGenericArguments()[1]; - - var inputList = Expression.Parameter(typeof (IEnumerable<>).MakeGenericType(inputType), "inputList"); - - var selectMethod = EnumerableHelper.GetMethod("Select", new[] { typeof(IEnumerable<>), typeof(Func<,>) }, new[] { inputType, outputType }); - var toListMethod = EnumerableHelper.GetMethod("ToList", new[] {typeof (IEnumerable<>)}, new[] {outputType}); - - var lambda = Expression.Lambda( - Expression.Call(toListMethod, - Expression.Call(selectMethod, inputList, resultOperator.SelectClause)), - inputList); - - _listTransformers.Add(lambda); - } - - private void ProcessTakeOperator(TakeResultOperator resultOperator) - { - NamedParameter parameterName; - - // TODO - very similar to ProcessSkip, plus want to investigate the scenario in the "else" - // clause to see if it is valid - if (_parameters.TryGetValue(resultOperator.Count as ConstantExpression, out parameterName)) + if (results.TreeNode != null) { - _additionalCriteria.Add((q, p) => q.SetMaxResults((int) p[parameterName.Name].First)); + Root = results.TreeNode; } - else - { - _additionalCriteria.Add((q, p) => q.SetMaxResults(resultOperator.GetConstantCount())); - } - } - - private void ProcessSkipOperator(SkipResultOperator resultOperator) - { - NamedParameter parameterName; - - if (_parameters.TryGetValue(resultOperator.Count as ConstantExpression, out parameterName)) - { - _additionalCriteria.Add((q, p) => q.SetFirstResult((int)p[parameterName.Name].First)); - } - else - { - _additionalCriteria.Add((q, p) => q.SetFirstResult(resultOperator.GetConstantCount())); - } } - private void ProcessNonAggregatingGroupBy(NonAggregatingGroupBy resultOperator, QueryModel model) - { - var tSource = model.SelectClause.Selector.Type; - var tKey = resultOperator.GroupBy.KeySelector.Type; - var tElement = resultOperator.GroupBy.ElementSelector.Type; - - // Stuff in the group by that doesn't map to HQL. Run it client-side - var listParameter = Expression.Parameter(typeof (IEnumerable<object>), "list"); - - ParameterExpression itemParam = Expression.Parameter(tSource, "item"); - Expression keySelectorSource = itemParam; - - if (tSource != SourceOf(resultOperator.GroupBy.KeySelector)) - { - keySelectorSource = Expression.MakeMemberAccess(itemParam, - tSource.GetMember( - ((QuerySourceReferenceExpression) - resultOperator.GroupBy.KeySelector).ReferencedQuerySource. - ItemName)[0]); - } - - - Expression keySelector = new GroupByKeySelectorVisitor(keySelectorSource).Visit(resultOperator.GroupBy.KeySelector); - - Expression elementSelectorSource = itemParam; - - if (tSource != SourceOf(resultOperator.GroupBy.ElementSelector)) - { - elementSelectorSource = Expression.MakeMemberAccess(itemParam, - tSource.GetMember( - ((QuerySourceReferenceExpression) - resultOperator.GroupBy.ElementSelector).ReferencedQuerySource. - ItemName)[0]); - } - - Expression elementSelector = new GroupByKeySelectorVisitor(elementSelectorSource).Visit(resultOperator.GroupBy.ElementSelector); - - var groupByMethod = EnumerableHelper.GetMethod("GroupBy", - new[] { typeof(IEnumerable<>), typeof(Func<,>), typeof(Func<,>) }, - new[] { tSource, tKey, tElement }); - - var castToItem = EnumerableHelper.GetMethod("Cast", new[] { typeof(IEnumerable) }, new[] { tSource }); - - var toList = EnumerableHelper.GetMethod("ToList", new [] { typeof(IEnumerable<>)}, new [] {resultOperator.GroupBy.ItemType}); - - LambdaExpression keySelectorExpr = Expression.Lambda(keySelector, itemParam); - - LambdaExpression elementSelectorExpr = Expression.Lambda(elementSelector, itemParam); - - Expression castToItemExpr = Expression.Call(castToItem, listParameter); - - var groupByExpr = Expression.Call(groupByMethod, castToItemExpr, keySelectorExpr, elementSelectorExpr); - - var toListExpr = Expression.Call(toList, groupByExpr); - - var lambdaExpr = Expression.Lambda(toListExpr, listParameter); - - _listTransformers.Add(lambdaExpr); - - return; - } - private void GroupBy<TSource, TKey, TResult>(Expression<Func<TSource, TKey>> keySelector, Expression<Func<TSource, TResult>> elementSelector) { IQueryable<object> list = null; @@ -451,55 +236,11 @@ var x = list.Cast<TSource>().GroupBy(keySelector, elementSelector); } - private static System.Type SourceOf(Expression keySelector) - { - return new GroupByKeySourceFinder().Visit(keySelector).Type; - } - - private void ProcessGroupByOperator(GroupResultOperator resultOperator) - { - AddGroupByClause(_hqlTreeBuilder.GroupBy(HqlGeneratorExpressionTreeVisitor.Visit(resultOperator.KeySelector, _parameters, _requiredHqlParameters).AsExpression())); - } - - private void ProcessFirstOperator(FirstResultOperator resultOperator) - { - var firstMethod = resultOperator.ReturnDefaultWhenEmpty - ? ReflectionHelper.GetMethod(() => Queryable.FirstOrDefault<object>(null)) - : ReflectionHelper.GetMethod(() => Queryable.First<object>(null)); - - ProcessFirstOrSingle(firstMethod); - } - - private void ProcessSingleOperator(SingleResultOperator resultOperator) - { - var firstMethod = resultOperator.ReturnDefaultWhenEmpty - ? ReflectionHelper.GetMethod(() => Queryable.SingleOrDefault<object>(null)) - : ReflectionHelper.GetMethod(() => Queryable.Single<object>(null)); - - ProcessFirstOrSingle(firstMethod); - } - - private void ProcessFirstOrSingle(MethodInfo target) - { - target = target.MakeGenericMethod(_currentEvaluationType.DataType); - - var parameter = Expression.Parameter(_previousEvaluationType.DataType, null); - - var lambda = Expression.Lambda( - Expression.Call( - target, - parameter), - parameter); - - _additionalCriteria.Add((q, p) => q.SetMaxResults(1)); - _postExecuteTransformers.Add(lambda); - } - public override void VisitSelectClause(SelectClause selectClause, QueryModel queryModel) { - _currentEvaluationType = selectClause.GetOutputDataInfo(); + CurrentEvaluationType = selectClause.GetOutputDataInfo(); - var visitor = new SelectClauseVisitor(typeof(object[]), _parameters, _requiredHqlParameters); + var visitor = new SelectClauseVisitor(typeof(object[]), VisitorParameters); visitor.Visit(selectClause.Selector); @@ -508,7 +249,7 @@ _itemTransformers.Add(visitor.ProjectionExpression); } - AddSelectClause(_hqlTreeBuilder.Select(visitor.GetHqlNodes())); + AddSelectClause(TreeBuilder.Select(visitor.GetHqlNodes())); base.VisitSelectClause(selectClause, queryModel); } @@ -516,43 +257,50 @@ public override void VisitWhereClause(WhereClause whereClause, QueryModel queryModel, int index) { // Visit the predicate to build the query - AddWhereClause(HqlGeneratorExpressionTreeVisitor.Visit(whereClause.Predicate, _parameters, _requiredHqlParameters).AsBooleanExpression()); + AddWhereClause(HqlGeneratorExpressionTreeVisitor.Visit(whereClause.Predicate, VisitorParameters).AsBooleanExpression()); } public override void VisitOrderByClause(OrderByClause orderByClause, QueryModel queryModel, int index) { foreach (Ordering clause in orderByClause.Orderings) { - AddOrderByClause(HqlGeneratorExpressionTreeVisitor.Visit(clause.Expression, _parameters, _requiredHqlParameters).AsExpression(), + AddOrderByClause(HqlGeneratorExpressionTreeVisitor.Visit(clause.Expression, VisitorParameters).AsExpression(), clause.OrderingDirection == OrderingDirection.Asc - ? _hqlTreeBuilder.Ascending() - : (HqlDirectionStatement) _hqlTreeBuilder.Descending()); + ? TreeBuilder.Ascending() + : (HqlDirectionStatement) TreeBuilder.Descending()); } } public override void VisitJoinClause(JoinClause joinClause, QueryModel queryModel, int index) { - var equalityVisitor = new EqualityHqlGenerator(_parameters, _requiredHqlParameters); + var equalityVisitor = new EqualityHqlGenerator(VisitorParameters); var whereClause = equalityVisitor.Visit(joinClause.InnerKeySelector, joinClause.OuterKeySelector); AddWhereClause(whereClause); - AddFromClause(_hqlTreeBuilder.Range(HqlGeneratorExpressionTreeVisitor.Visit(joinClause.InnerSequence, _parameters, _requiredHqlParameters), - _hqlTreeBuilder.Alias(joinClause.ItemName))); + AddFromClause(TreeBuilder.Range(HqlGeneratorExpressionTreeVisitor.Visit(joinClause.InnerSequence, VisitorParameters), + TreeBuilder.Alias(joinClause.ItemName))); } public override void VisitAdditionalFromClause(AdditionalFromClause fromClause, QueryModel queryModel, int index) { - if (fromClause.FromExpression is MemberExpression) + if (fromClause is LeftJoinClause) + { + // It's a left join + AddFromClause(TreeBuilder.LeftJoin( + HqlGeneratorExpressionTreeVisitor.Visit(fromClause.FromExpression, VisitorParameters).AsExpression(), + TreeBuilder.Alias(fromClause.ItemName))); + } + else if (fromClause.FromExpression is MemberExpression) { var member = (MemberExpression) fromClause.FromExpression; if (member.Expression is QuerySourceReferenceExpression) { // It's a join - AddFromClause(_hqlTreeBuilder.Join( - HqlGeneratorExpressionTreeVisitor.Visit(fromClause.FromExpression, _parameters, _requiredHqlParameters).AsExpression(), - _hqlTreeBuilder.Alias(fromClause.ItemName))); + AddFromClause(TreeBuilder.Join( + HqlGeneratorExpressionTreeVisitor.Visit(fromClause.FromExpression, VisitorParameters).AsExpression(), + TreeBuilder.Alias(fromClause.ItemName))); } else { @@ -563,9 +311,9 @@ else { // TODO - exact same code as in MainFromClause; refactor this out - AddFromClause(_hqlTreeBuilder.Range( - HqlGeneratorExpressionTreeVisitor.Visit(fromClause.FromExpression, _parameters, _requiredHqlParameters), - _hqlTreeBuilder.Alias(fromClause.ItemName))); + AddFromClause(TreeBuilder.Range( + HqlGeneratorExpressionTreeVisitor.Visit(fromClause.FromExpression, VisitorParameters), + TreeBuilder.Alias(fromClause.ItemName))); } @@ -576,11 +324,5 @@ { throw new NotImplementedException(); } - - internal enum ResultOperatorProcessingMode - { - ProcessServerSide, - ProcessClientSide - } } } \ No newline at end of file Property changes on: trunk/nhibernate/src/NHibernate/Linq/Visitors/ResultOperatorProcessors ___________________________________________________________________ Added: bugtraq:url + http://jira.nhibernate.org/browse/%BUGID% Added: bugtraq:logregex + NH-\d+ Added: trunk/nhibernate/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/IResultOperatorProcessor.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/IResultOperatorProcessor.cs (rev 0) +++ trunk/nhibernate/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/IResultOperatorProcessor.cs 2010-02-24 15:52:13 UTC (rev 4945) @@ -0,0 +1,7 @@ +namespace NHibernate.Linq.Visitors.ResultOperatorProcessors +{ + public interface IResultOperatorProcessor<T> + { + ProcessResultOperatorReturn Process(T resultOperator, QueryModelVisitor queryModelVisitor); + } +} \ No newline at end of file Added: trunk/nhibernate/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAggregate.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAggregate.cs (rev 0) +++ trunk/nhibernate/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAggregate.cs 2010-02-24 15:52:13 UTC (rev 4945) @@ -0,0 +1,63 @@ +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; + +namespace NHibernate.Linq.Visitors.ResultOperatorProcessors +{ + public class ProcessAggregate : IResultOperatorProcessor<AggregateResultOperator> + { + public ProcessResultOperatorReturn Process(AggregateResultOperator resultOperator, QueryModelVisitor queryModelVisitor) + { + var inputType = resultOperator.Accumulator.Parameters[1].Type; + var accumulatorType = resultOperator.Accumulator.Parameters[0].Type; + var inputList = Expression.Parameter(typeof(IEnumerable<>).MakeGenericType(typeof(object)), "inputList"); + + var castToItem = EnumerableHelper.GetMethod("Cast", new[] { typeof(IEnumerable) }, new[] { inputType }); + var castToItemExpr = Expression.Call(castToItem, inputList); + + MethodCallExpression call; + + if (resultOperator.ParseInfo.ParsedExpression.Arguments.Count == 2) + { + var aggregate = ReflectionHelper.GetMethod(() => Enumerable.Aggregate<object>(null, null)); + aggregate = aggregate.GetGenericMethodDefinition().MakeGenericMethod(inputType); + + call = Expression.Call( + aggregate, + castToItemExpr, + resultOperator.Accumulator + ); + + } + else if (resultOperator.ParseInfo.ParsedExpression.Arguments.Count == 3) + { + var aggregate = ReflectionHelper.GetMethod(() => Enumerable.Aggregate<object, object>(null, null, null)); + aggregate = aggregate.GetGenericMethodDefinition().MakeGenericMethod(inputType, accumulatorType); + + call = Expression.Call( + aggregate, + castToItemExpr, + resultOperator.OptionalSeed, + resultOperator.Accumulator + ); + } + else + { + var selectorType = resultOperator.OptionalSelector.Type.GetGenericArguments()[2]; + var aggregate = ReflectionHelper.GetMethod(() => Enumerable.Aggregate<object, object, object>(null, null, null, null)); + aggregate = aggregate.GetGenericMethodDefinition().MakeGenericMethod(inputType, accumulatorType, selectorType); + + call = Expression.Call( + aggregate, + castToItemExpr, + resultOperator.OptionalSeed, + resultOperator.Accumulator, + resultOperator.OptionalSelector + ... [truncated message content] |