From: <fab...@us...> - 2009-07-09 17:05:28
|
Revision: 4600 http://nhibernate.svn.sourceforge.net/nhibernate/?rev=4600&view=rev Author: fabiomaulo Date: 2009-07-09 17:05:23 +0000 (Thu, 09 Jul 2009) Log Message: ----------- Merge r4598 and r4599 (fix NH-1876) Modified Paths: -------------- trunk/nhibernate/src/NHibernate/Id/TableGenerator.cs Added Paths: ----------- trunk/nhibernate/src/NHibernate.Test/IdTest/TableGeneratorFixture.cs Modified: trunk/nhibernate/src/NHibernate/Id/TableGenerator.cs =================================================================== --- trunk/nhibernate/src/NHibernate/Id/TableGenerator.cs 2009-07-09 16:45:16 UTC (rev 4599) +++ trunk/nhibernate/src/NHibernate/Id/TableGenerator.cs 2009-07-09 17:05:23 UTC (rev 4600) @@ -4,9 +4,9 @@ using System.Data; using System.Runtime.CompilerServices; using log4net; +using NHibernate.AdoNet.Util; using NHibernate.Dialect; using NHibernate.Engine; -using NHibernate.Engine.Transaction; using NHibernate.SqlCommand; using NHibernate.SqlTypes; using NHibernate.Type; @@ -14,9 +14,6 @@ namespace NHibernate.Id { - using System.Transactions; - using NHibernate.AdoNet.Util; - /// <summary> /// An <see cref="IIdentifierGenerator" /> that uses a database table to store the last /// generated value. @@ -38,7 +35,8 @@ /// </remarks> public class TableGenerator : TransactionHelper, IPersistentIdentifierGenerator, IConfigurable { - private static readonly ILog log = LogManager.GetLogger(typeof(TableGenerator)); + private static readonly ILog log = LogManager.GetLogger(typeof (TableGenerator)); + /// <summary> /// An additional where clause that is added to /// the queries against the table. @@ -83,7 +81,6 @@ /// <param name="dialect">The <see cref="Dialect"/> to help with Configuration.</param> public virtual void Configure(IType type, IDictionary<string, string> parms, Dialect.Dialect dialect) { - tableName = PropertiesHelper.GetString(TableParamName, parms, DefaultTableName); columnName = PropertiesHelper.GetString(ColumnParamName, parms, DefaultColumnName); whereClause = PropertiesHelper.GetString(Where, parms, ""); @@ -95,9 +92,17 @@ tableName = dialect.Qualify(catalogName, schemaName, tableName); } - query = "select " + columnName + " from " + dialect.AppendLockHint(LockMode.Upgrade, tableName) - + dialect.ForUpdateString; + var selectBuilder = new SqlStringBuilder(100); + selectBuilder.Add("select " + columnName) + .Add(" from " + dialect.AppendLockHint(LockMode.Upgrade, tableName)); + if (string.IsNullOrEmpty(whereClause) == false) + { + selectBuilder.Add(" where ").Add(whereClause); + } + selectBuilder.Add(dialect.ForUpdateString); + query = selectBuilder.ToString(); + columnType = type as PrimitiveType; if (columnType == null) { @@ -119,21 +124,16 @@ columnSqlType = SqlTypeFactory.Int32; } - parameterTypes = new SqlType[2] {columnSqlType, columnSqlType}; + parameterTypes = new[] {columnSqlType, columnSqlType}; - SqlStringBuilder builder = new SqlStringBuilder(); + var builder = new SqlStringBuilder(100); builder.Add("update " + tableName + " set ") - .Add(columnName) - .Add(" = ") - .Add(Parameter.Placeholder) + .Add(columnName).Add(" = ").Add(Parameter.Placeholder) .Add(" where ") - .Add(columnName) - .Add(" = ") - .Add(Parameter.Placeholder); + .Add(columnName).Add(" = ").Add(Parameter.Placeholder); if (string.IsNullOrEmpty(whereClause) == false) { - builder.Add(" and ") - .Add(whereClause); + builder.Add(" and ").Add(whereClause); } updateSql = builder.ToSqlString(); @@ -172,16 +172,16 @@ /// create the necessary database objects and to create the first value as <c>1</c> /// for the TableGenerator. /// </returns> - public string[] SqlCreateStrings(Dialect.Dialect dialect) + public virtual string[] SqlCreateStrings(Dialect.Dialect dialect) { // changed the first value to be "1" by default since an uninitialized Int32 is 0 - leaving // it at 0 would cause problems with an unsaved-value="0" which is what most people are // defaulting <id>'s with Int32 types at. - return new string[] - { - "create table " + tableName + " ( " + columnName + " " + dialect.GetTypeName(columnSqlType) + " )", - "insert into " + tableName + " values ( 1 )" - }; + return new[] + { + "create table " + tableName + " ( " + columnName + " " + dialect.GetTypeName(columnSqlType) + " )", + "insert into " + tableName + " values ( 1 )" + }; } /// <summary> @@ -191,9 +191,9 @@ /// <returns> /// A <see cref="string"/> that will drop the database objects for the TableGenerator. /// </returns> - public string[] SqlDropString(Dialect.Dialect dialect) + public virtual string[] SqlDropString(Dialect.Dialect dialect) { - return new string[] { dialect.GetDropTableString(tableName) }; + return new[] {dialect.GetDropTableString(tableName)}; } /// <summary> @@ -209,7 +209,8 @@ #endregion - public override object DoWorkInCurrentTransaction(ISessionImplementor session, IDbConnection conn, IDbTransaction transaction) + public override object DoWorkInCurrentTransaction(ISessionImplementor session, IDbConnection conn, + IDbTransaction transaction) { long result; int rows; @@ -243,15 +244,18 @@ } finally { - if (rs != null) rs.Close(); + if (rs != null) + { + rs.Close(); + } qps.Dispose(); } - IDbCommand ups = - session.Factory.ConnectionProvider.Driver.GenerateCommand(CommandType.Text, updateSql, parameterTypes); + IDbCommand ups = session.Factory.ConnectionProvider.Driver.GenerateCommand(CommandType.Text, updateSql, + parameterTypes); ups.Connection = conn; ups.Transaction = transaction; - + try { columnType.Set(ups, result + 1, 0); @@ -270,9 +274,10 @@ { ups.Dispose(); } - } while (rows == 0); + } + while (rows == 0); return result; } } -} +} \ No newline at end of file Added: trunk/nhibernate/src/NHibernate.Test/IdTest/TableGeneratorFixture.cs =================================================================== --- trunk/nhibernate/src/NHibernate.Test/IdTest/TableGeneratorFixture.cs (rev 0) +++ trunk/nhibernate/src/NHibernate.Test/IdTest/TableGeneratorFixture.cs 2009-07-09 17:05:23 UTC (rev 4600) @@ -0,0 +1,29 @@ +using System.Collections.Generic; +using System.Reflection; +using NHibernate.Dialect; +using NHibernate.Id; +using NUnit.Framework; + +namespace NHibernate.Test.IdTest +{ + [TestFixture] + public class TableGeneratorFixture + { + private const BindingFlags Flags = + BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.DeclaredOnly; + + private readonly FieldInfo updateSql = typeof (TableGenerator).GetField("updateSql", Flags); + private readonly FieldInfo selectSql = typeof (TableGenerator).GetField("query", Flags); + + [Test] + public void SelectAndUpdateStringContainCustomWhere() + { + const string customWhere = "table_name='second'"; + var dialect = new MsSql2005Dialect(); + var tg = new TableGenerator(); + tg.Configure(NHibernateUtil.Int64, new Dictionary<string, string> {{"where", customWhere}}, dialect); + Assert.That(selectSql.GetValue(tg).ToString(), Text.Contains(customWhere)); + Assert.That(updateSql.GetValue(tg).ToString(), Text.Contains(customWhere)); + } + } +} This was sent by the SourceForge.net collaborative development platform, the world's largest Open Source development site. |