<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.1//EN"
"http://www.w3.org/TR/xhtml11/DTD/xhtml11.dtd">
<html xmlns="http://www.w3.org/1999/xhtml">
<head>
<meta http-equiv="content-type" content="text/html; charset=utf-8" />
<style type="text/css">
/*<![CDATA[*/
/*
 * Stylesheet for the commit-notification page.
 * The CDATA section wrapped in CSS comments keeps these rules active under
 * both HTML and XML parsing; an HTML-comment wrapper would turn the whole
 * stylesheet into a comment when the XHTML doctype is parsed as XML.
 */

/* Commit summary (#msg): metadata box, log message and changed-path lists. */
#msg dl { border: 1px #006 solid; background: #369; padding: 6px; color: #fff; }
#msg dt { float: left; width: 6em; font-weight: bold; }
#msg dt:after { content:':';}
#msg dl, #msg dt, #msg ul, #msg li, #header, #footer { font-family: verdana,arial,helvetica,sans-serif; font-size: 10pt; }

/* Links inside the metadata box. States must be declared :link, :visited,
   then :active: with equal specificity the later rule wins, so declaring
   :active before :visited meant visited links never showed the active colour. */
#msg dl a { font-weight: bold}
#msg dl a:link { color:#fc3; }
#msg dl a:visited { color:#cc6; }
#msg dl a:active { color:#ff0; }

h3 { font-family: verdana,arial,helvetica,sans-serif; font-size: 10pt; font-weight: bold; }
#msg pre { overflow: auto; background: #ffc; border: 1px #fc0 solid; padding: 6px; }
/* Scoped to #msg. This rule used to also carry a bare "pre" selector that
   applied to every pre element on the page; #msg pre (above) and
   #patch .diff (below) already set overflow themselves. */
#msg ul { overflow: auto; }
#header, #footer { color: #fff; background: #636; border: 1px #300 solid; padding: 6px; }

/* Diff area (#patch): one bordered box per changed file. */
#patch { width: 100%; }
#patch h4 {font-family: verdana,arial,helvetica,sans-serif;font-size:10pt;padding:8px;background:#369;color:#fff;margin:0;}
#patch .propset h4, #patch .binary h4 {margin:0;}
#patch pre {padding:0;line-height:1.2em;margin:0;}
#patch .diff {width:100%;background:#eee;padding: 0 0 10px 0;overflow:auto;}
#patch .propset .diff, #patch .binary .diff {padding:10px 0;}
#patch span {display:block;padding:0 10px;}
#patch .modfile, #patch .addfile, #patch .delfile, #patch .propset, #patch .binary, #patch .copfile {border:1px solid #ccc;margin:10px 0;}
/* Added / removed diff lines: full-width coloured bars instead of underline/strike. */
#patch ins {background:#dfd;text-decoration:none;display:block;padding:0 10px;}
#patch del {background:#fdd;text-decoration:none;display:block;padding:0 10px;}
/* Hunk headers (.lines) and file headers (.info); .info is scoped to
   #patch like its sibling selector. */
#patch .lines, #patch .info {color:#888;background:#fff;}
/*]]>*/
</style>
<title>[6027] sqlalchemy/branches/nosetests: the basic idea.</title>
</head>
<body>
<div id="msg">
  <!-- Commit summary: revision metadata, log message, and the lists of modified, added and removed paths. -->
  <dl>
    <dt>Revision</dt>
    <dd>6027</dd>
    <dt>Author</dt>
    <dd>zzzeek</dd>
    <dt>Date</dt>
    <dd>2009-06-07 18:09:07 -0400 (Sun, 07 Jun 2009)</dd>
  </dl>
  <h3>Log Message</h3>
  <pre>the basic idea. query.py becomes test_query.py</pre>
  <h3>Modified Paths</h3>
  <ul>
    <li><a href="#sqlalchemybranchesnosetestssetupcfg">sqlalchemy/branches/nosetests/setup.cfg</a></li>
    <li><a href="#sqlalchemybranchesnosetestssetuppy">sqlalchemy/branches/nosetests/setup.py</a></li>
  </ul>
  <h3>Added Paths</h3>
  <ul>
    <li>sqlalchemy/branches/nosetests/lib/sqlalchemy/test/</li>
    <li><a href="#sqlalchemybranchesnosetestslibsqlalchemytest__init__py">sqlalchemy/branches/nosetests/lib/sqlalchemy/test/__init__.py</a></li>
    <li><a href="#sqlalchemybranchesnosetestslibsqlalchemytestassertsqlpy">sqlalchemy/branches/nosetests/lib/sqlalchemy/test/assertsql.py</a></li>
    <li><a href="#sqlalchemybranchesnosetestslibsqlalchemytestcompatpy">sqlalchemy/branches/nosetests/lib/sqlalchemy/test/compat.py</a></li>
    <li><a href="#sqlalchemybranchesnosetestslibsqlalchemytestconfigpy">sqlalchemy/branches/nosetests/lib/sqlalchemy/test/config.py</a></li>
    <li><a href="#sqlalchemybranchesnosetestslibsqlalchemytestenginespy">sqlalchemy/branches/nosetests/lib/sqlalchemy/test/engines.py</a></li>
    <li><a href="#sqlalchemybranchesnosetestslibsqlalchemytestnosepluginpy">sqlalchemy/branches/nosetests/lib/sqlalchemy/test/noseplugin.py</a></li>
    <li><a href="#sqlalchemybranchesnosetestslibsqlalchemytestormpy">sqlalchemy/branches/nosetests/lib/sqlalchemy/test/orm.py</a></li>
    <li><a href="#sqlalchemybranchesnosetestslibsqlalchemytestprofilingpy">sqlalchemy/branches/nosetests/lib/sqlalchemy/test/profiling.py</a></li>
    <li><a href="#sqlalchemybranchesnosetestslibsqlalchemytestrequirespy">sqlalchemy/branches/nosetests/lib/sqlalchemy/test/requires.py</a></li>
    <li><a href="#sqlalchemybranchesnosetestslibsqlalchemytestschemapy">sqlalchemy/branches/nosetests/lib/sqlalchemy/test/schema.py</a></li>
    <li><a href="#sqlalchemybranchesnosetestslibsqlalchemytesttestingpy">sqlalchemy/branches/nosetests/lib/sqlalchemy/test/testing.py</a></li>
    <li><a href="#sqlalchemybranchesnoseteststestsqltest_querypy">sqlalchemy/branches/nosetests/test/sql/test_query.py</a></li>
  </ul>
  <h3>Removed Paths</h3>
  <ul>
    <li><a href="#sqlalchemybranchesnoseteststestsqlquerypy">sqlalchemy/branches/nosetests/test/sql/query.py</a></li>
    <li><a href="#sqlalchemybranchesnoseteststesttestenvpy">sqlalchemy/branches/nosetests/test/testenv.py</a></li>
    <li>sqlalchemy/branches/nosetests/test/testlib/</li>
  </ul>
</div>
<div id="patch">
<h3>Diff</h3>
<a id="sqlalchemybranchesnosetestslibsqlalchemytest__init__py"></a>
<div class="addfile"><h4>Added: sqlalchemy/branches/nosetests/lib/sqlalchemy/test/__init__.py (0 => 6027)</h4>
<pre class="diff"><span>
<span class="info">--- sqlalchemy/branches/nosetests/lib/sqlalchemy/test/__init__.py (rev 0)
+++ sqlalchemy/branches/nosetests/lib/sqlalchemy/test/__init__.py 2009-06-07 22:09:07 UTC (rev 6027)
</span><span class="lines">@@ -0,0 +1,29 @@
</span><ins>+"""Enhance unittest and instrument SQLAlchemy classes for testing.
+
+Load after sqlalchemy imports to use instrumented stand-ins like Table.
+"""
+
+import sys
+import config
+import testing, engines, requires, profiling
+from schema import Table, Column
+from testing import \
+ AssertsCompiledSQL, \
+ AssertsExecutionResults, \
+ ComparesTables, \
+ TestBase, \
+ rowset
+from orm import mapper
+from compat import _function_named
+
+
+__all__ = ('testing',
+ 'mapper',
+ 'Table', 'Column',
+ 'rowset',
+ 'TestBase', 'AssertsExecutionResults',
+ 'AssertsCompiledSQL', 'ComparesTables',
+ 'engines',
+ '_function_named')
+
+
</ins></span></pre></div>
<a id="sqlalchemybranchesnosetestslibsqlalchemytestassertsqlpy"></a>
<div class="addfile"><h4>Added: sqlalchemy/branches/nosetests/lib/sqlalchemy/test/assertsql.py (0 => 6027)</h4>
<pre class="diff"><span>
<span class="info">--- sqlalchemy/branches/nosetests/lib/sqlalchemy/test/assertsql.py (rev 0)
+++ sqlalchemy/branches/nosetests/lib/sqlalchemy/test/assertsql.py 2009-06-07 22:09:07 UTC (rev 6027)
</span><span class="lines">@@ -0,0 +1,283 @@
</span><ins>+
+from sqlalchemy.interfaces import ConnectionProxy
+from sqlalchemy.engine.default import DefaultDialect
+from sqlalchemy.engine.base import Connection
+from sqlalchemy import util
+import testing
+import re
+
+class AssertRule(object):
+ def process_execute(self, clauseelement, *multiparams, **params):
+ pass
+
+ def process_cursor_execute(self, statement, parameters, context, executemany):
+ pass
+
+ def is_consumed(self):
+ """Return True if this rule has been consumed, False if not.
+
+ Should raise an AssertionError if this rule's condition has definitely failed.
+
+ """
+ raise NotImplementedError()
+
+ def rule_passed(self):
+ """Return True if the last test of this rule passed, False if failed, None if no test was applied."""
+
+ raise NotImplementedError()
+
+ def consume_final(self):
+ """Return True if this rule has been consumed.
+
+ Should raise an AssertionError if this rule's condition has not been consumed or has failed.
+
+ """
+
+ if self._result is None:
+ assert False, "Rule has not been consumed"
+
+ return self.is_consumed()
+
+class SQLMatchRule(AssertRule):
+ def __init__(self):
+ self._result = None
+ self._errmsg = ""
+
+ def rule_passed(self):
+ return self._result
+
+ def is_consumed(self):
+ if self._result is None:
+ return False
+
+ assert self._result, self._errmsg
+
+ return True
+
+class ExactSQL(SQLMatchRule):
+ def __init__(self, sql, params=None):
+ SQLMatchRule.__init__(self)
+ self.sql = sql
+ self.params = params
+
+ def process_cursor_execute(self, statement, parameters, context, executemany):
+ if not context:
+ return
+
+ _received_statement = _process_engine_statement(statement, context)
+ _received_parameters = context.compiled_parameters
+
+ # TODO: remove this step once all unit tests
+ # are migrated, as ExactSQL should really be *exact* SQL
+ sql = _process_assertion_statement(self.sql, context)
+
+ equivalent = _received_statement == sql
+ if self.params:
+ if util.callable(self.params):
+ params = self.params(context)
+ else:
+ params = self.params
+
+ if not isinstance(params, list):
+ params = [params]
+ equivalent = equivalent and params == context.compiled_parameters
+ else:
+ params = {}
+
+
+ self._result = equivalent
+ if not self._result:
+ self._errmsg = "Testing for exact statement %r exact params %r, " \
+ "received %r with params %r" % (sql, params, _received_statement, _received_parameters)
+
+
+class RegexSQL(SQLMatchRule):
+ def __init__(self, regex, params=None):
+ SQLMatchRule.__init__(self)
+ self.regex = re.compile(regex)
+ self.orig_regex = regex
+ self.params = params
+
+ def process_cursor_execute(self, statement, parameters, context, executemany):
+ if not context:
+ return
+
+ _received_statement = _process_engine_statement(statement, context)
+ _received_parameters = context.compiled_parameters
+
+ equivalent = bool(self.regex.match(_received_statement))
+ if self.params:
+ if util.callable(self.params):
+ params = self.params(context)
+ else:
+ params = self.params
+
+ if not isinstance(params, list):
+ params = [params]
+
+ # do a positive compare only
+ for param, received in zip(params, _received_parameters):
+ for k, v in param.iteritems():
+ if k not in received or received[k] != v:
+ equivalent = False
+ break
+ else:
+ params = {}
+
+ self._result = equivalent
+ if not self._result:
+ self._errmsg = "Testing for regex %r partial params %r, "\
+ "received %r with params %r" % (self.orig_regex, params, _received_statement, _received_parameters)
+
+class CompiledSQL(SQLMatchRule):
+ def __init__(self, statement, params):
+ SQLMatchRule.__init__(self)
+ self.statement = statement
+ self.params = params
+
+ def process_cursor_execute(self, statement, parameters, context, executemany):
+ if not context:
+ return
+
+ _received_parameters = context.compiled_parameters
+
+ # recompile from the context, using the default dialect
+ compiled = context.compiled.statement.\
+ compile(dialect=DefaultDialect(), column_keys=context.compiled.column_keys)
+
+ _received_statement = re.sub(r'\n', '', str(compiled))
+
+ equivalent = self.statement == _received_statement
+ if self.params:
+ if util.callable(self.params):
+ params = self.params(context)
+ else:
+ params = self.params
+
+ if not isinstance(params, list):
+ params = [params]
+
+ # do a positive compare only
+ for param, received in zip(params, _received_parameters):
+ for k, v in param.iteritems():
+ if k not in received or received[k] != v:
+ equivalent = False
+ break
+ else:
+ params = {}
+
+ self._result = equivalent
+ if not self._result:
+ self._errmsg = "Testing for compiled statement %r partial params %r, " \
+ "received %r with params %r" % (self.statement, params, _received_statement, _received_parameters)
+
+
+class CountStatements(AssertRule):
+ def __init__(self, count):
+ self.count = count
+ self._statement_count = 0
+
+ def process_execute(self, clauseelement, *multiparams, **params):
+ self._statement_count += 1
+
+ def process_cursor_execute(self, statement, parameters, context, executemany):
+ pass
+
+ def is_consumed(self):
+ return False
+
+ def consume_final(self):
+ assert self.count == self._statement_count, "desired statement count %d does not match %d" % (self.count, self._statement_count)
+ return True
+
+class AllOf(AssertRule):
+ def __init__(self, *rules):
+ self.rules = set(rules)
+
+ def process_execute(self, clauseelement, *multiparams, **params):
+ for rule in self.rules:
+ rule.process_execute(clauseelement, *multiparams, **params)
+
+ def process_cursor_execute(self, statement, parameters, context, executemany):
+ for rule in self.rules:
+ rule.process_cursor_execute(statement, parameters, context, executemany)
+
+ def is_consumed(self):
+ if not self.rules:
+ return True
+
+ for rule in list(self.rules):
+ if rule.rule_passed(): # a rule passed, move on
+ self.rules.remove(rule)
+ return len(self.rules) == 0
+
+ assert False, "No assertion rules were satisfied for statement"
+
+ def consume_final(self):
+ return len(self.rules) == 0
+
+def _process_engine_statement(query, context):
+ if context.engine.name == 'mssql' and query.endswith('; select scope_identity()'):
+ query = query[:-25]
+
+ query = re.sub(r'\n', '', query)
+
+ return query
+
+def _process_assertion_statement(query, context):
+ paramstyle = context.dialect.paramstyle
+ if paramstyle == 'named':
+ pass
+ elif paramstyle =='pyformat':
+ query = re.sub(r':([\w_]+)', r"%(\1)s", query)
+ else:
+ # positional params
+ repl = None
+ if paramstyle=='qmark':
+ repl = "?"
+ elif paramstyle=='format':
+ repl = r"%s"
+ elif paramstyle=='numeric':
+ repl = None
+ query = re.sub(r':([\w_]+)', repl, query)
+
+ return query
+
+class SQLAssert(ConnectionProxy):
+ rules = None
+
+ def add_rules(self, rules):
+ self.rules = list(rules)
+
+ def statement_complete(self):
+ for rule in self.rules:
+ if not rule.consume_final():
+ assert False, "All statements are complete, but pending assertion rules remain"
+
+ def clear_rules(self):
+ del self.rules
+
+ def execute(self, conn, execute, clauseelement, *multiparams, **params):
+ result = execute(clauseelement, *multiparams, **params)
+
+ if self.rules is not None:
+ if not self.rules:
+ assert False, "All rules have been exhausted, but further statements remain"
+ rule = self.rules[0]
+ rule.process_execute(clauseelement, *multiparams, **params)
+ if rule.is_consumed():
+ self.rules.pop(0)
+
+ return result
+
+ def cursor_execute(self, execute, cursor, statement, parameters, context, executemany):
+ result = execute(cursor, statement, parameters, context)
+
+ if self.rules:
+ rule = self.rules[0]
+ rule.process_cursor_execute(statement, parameters, context, executemany)
+
+ return result
+
+asserter = SQLAssert()
+
</ins></span></pre></div>
<a id="sqlalchemybranchesnosetestslibsqlalchemytestcompatpy"></a>
<div class="addfile"><h4>Added: sqlalchemy/branches/nosetests/lib/sqlalchemy/test/compat.py (0 => 6027)</h4>
<pre class="diff"><span>
<span class="info">--- sqlalchemy/branches/nosetests/lib/sqlalchemy/test/compat.py (rev 0)
+++ sqlalchemy/branches/nosetests/lib/sqlalchemy/test/compat.py 2009-06-07 22:09:07 UTC (rev 6027)
</span><span class="lines">@@ -0,0 +1,19 @@
</span><ins>+import types
+import __builtin__
+
+__all__ = '_function_named', 'callable'
+
+
+def _function_named(fn, newname):
+ try:
+ fn.__name__ = newname
+ except:
+ fn = types.FunctionType(fn.func_code, fn.func_globals, newname,
+ fn.func_defaults, fn.func_closure)
+ return fn
+
+try:
+ callable = __builtin__.callable
+except NameError:
+ def callable(fn): return hasattr(fn, '__call__')
+
</ins></span></pre></div>
<a id="sqlalchemybranchesnosetestslibsqlalchemytestconfigpy"></a>
<div class="addfile"><h4>Added: sqlalchemy/branches/nosetests/lib/sqlalchemy/test/config.py (0 => 6027)</h4>
<pre class="diff"><span>
<span class="info">--- sqlalchemy/branches/nosetests/lib/sqlalchemy/test/config.py (rev 0)
+++ sqlalchemy/branches/nosetests/lib/sqlalchemy/test/config.py 2009-06-07 22:09:07 UTC (rev 6027)
</span><span class="lines">@@ -0,0 +1,181 @@
</span><ins>+import optparse, os, sys, re, ConfigParser, StringIO, time, warnings
+logging, require = None, None
+
+
+__all__ = 'parser', 'configure', 'options',
+
+db = None
+db_label, db_url, db_opts = None, None, {}
+
+options = None
+file_config = None
+coverage_enabled = False
+
+base_config = """
+[db]
+sqlite=sqlite:///:memory:
+sqlite_file=sqlite:///querytest.db
+postgres=postgres://scott:tiger@...>
+mysql=mysql://scott:tiger@...>
+oracle=oracle://scott:tiger@...>
+oracle8=oracle://scott:tiger@...>
+mssql=mssql://scott:tiger@...>
+firebird=firebird://sysdba:masterkey@...>
+maxdb=maxdb://MONA:RED@...>
+"""
+
+def _log(option, opt_str, value, parser):
+ global logging
+ if not logging:
+ import logging
+ logging.basicConfig()
+
+ if opt_str.endswith('-info'):
+ logging.getLogger(value).setLevel(logging.INFO)
+ elif opt_str.endswith('-debug'):
+ logging.getLogger(value).setLevel(logging.DEBUG)
+
+
+def _list_dbs(*args):
+ print "Available --db options (use --dburi to override)"
+ for macro in sorted(file_config.options('db')):
+ print "%20s\t%s" % (macro, file_config.get('db', macro))
+ sys.exit(0)
+
+def _server_side_cursors(options, opt_str, value, parser):
+ db_opts['server_side_cursors'] = True
+
+def _engine_strategy(options, opt_str, value, parser):
+ if value:
+ db_opts['strategy'] = value
+
+class _ordered_map(object):
+ def __init__(self):
+ self._keys = list()
+ self._data = dict()
+
+ def __setitem__(self, key, value):
+ if key not in self._keys:
+ self._keys.append(key)
+ self._data[key] = value
+
+ def __iter__(self):
+ for key in self._keys:
+ yield self._data[key]
+
+# at one point in refactoring, modules were injecting into the config
+# process. this could probably just become a list now.
+post_configure = _ordered_map()
+
+def _engine_uri(options, file_config):
+ global db_label, db_url
+ db_label = 'sqlite'
+ if options.dburi:
+ db_url = options.dburi
+ db_label = db_url[:db_url.index(':')]
+ elif options.db:
+ db_label = options.db
+ db_url = None
+
+ if db_url is None:
+ if db_label not in file_config.options('db'):
+ raise RuntimeError(
+ "Unknown engine. Specify --dbs for known engines.")
+ db_url = file_config.get('db', db_label)
+post_configure['engine_uri'] = _engine_uri
+
+def _require(options, file_config):
+ if not(options.require or
+ (file_config.has_section('require') and
+ file_config.items('require'))):
+ return
+
+ try:
+ import pkg_resources
+ except ImportError:
+ raise RuntimeError("setuptools is required for version requirements")
+
+ cmdline = []
+ for requirement in options.require:
+ pkg_resources.require(requirement)
+ cmdline.append(re.split('\s*(&lt;!&gt;=)', requirement, 1)[0])
+
+ if file_config.has_section('require'):
+ for label, requirement in file_config.items('require'):
+ if not label == db_label or label.startswith('%s.' % db_label):
+ continue
+ seen = [c for c in cmdline if requirement.startswith(c)]
+ if seen:
+ continue
+ pkg_resources.require(requirement)
+post_configure['require'] = _require
+
+def _engine_pool(options, file_config):
+ if options.mockpool:
+ from sqlalchemy import pool
+ db_opts['poolclass'] = pool.AssertionPool
+post_configure['engine_pool'] = _engine_pool
+
+def _create_testing_engine(options, file_config):
+ from sqlalchemy.test import engines, testing
+ global db
+ db = engines.testing_engine(db_url, db_opts)
+ testing.db = db
+post_configure['create_engine'] = _create_testing_engine
+
+def _prep_testing_database(options, file_config):
+ from sqlalchemy.test import engines
+ from sqlalchemy import schema
+
+ try:
+ # also create alt schemas etc. here?
+ if options.dropfirst:
+ e = engines.utf8_engine()
+ existing = e.table_names()
+ if existing:
+ if not options.quiet:
+ print "Dropping existing tables in database: " + db_url
+ try:
+ print "Tables: %s" % ', '.join(existing)
+ except:
+ pass
+ print "Abort within 5 seconds..."
+ time.sleep(5)
+ md = schema.MetaData(e, reflect=True)
+ md.drop_all()
+ e.dispose()
+ except (KeyboardInterrupt, SystemExit):
+ raise
+ except Exception, e:
+ if not options.quiet:
+ warnings.warn(RuntimeWarning(
+ "Error checking for existing tables in testing "
+ "database: %s" % e))
+post_configure['prep_db'] = _prep_testing_database
+
+def _set_table_options(options, file_config):
+ from sqlalchemy.test import schema
+
+ table_options = schema.table_options
+ for spec in options.tableopts:
+ key, value = spec.split('=')
+ table_options[key] = value
+
+ if options.mysql_engine:
+ table_options['mysql_engine'] = options.mysql_engine
+post_configure['table_options'] = _set_table_options
+
+def _reverse_topological(options, file_config):
+ if options.reversetop:
+ from sqlalchemy.orm import unitofwork
+ from sqlalchemy import topological
+ class RevQueueDepSort(topological.QueueDependencySorter):
+ def __init__(self, tuples, allitems):
+ self.tuples = list(tuples)
+ self.allitems = list(allitems)
+ self.tuples.reverse()
+ self.allitems.reverse()
+ topological.QueueDependencySorter = RevQueueDepSort
+ unitofwork.DependencySorter = RevQueueDepSort
+post_configure['topological'] = _reverse_topological
+
</ins></span></pre></div>
<a id="sqlalchemybranchesnosetestslibsqlalchemytestenginespy"></a>
<div class="addfile"><h4>Added: sqlalchemy/branches/nosetests/lib/sqlalchemy/test/engines.py (0 => 6027)</h4>
<pre class="diff"><span>
<span class="info">--- sqlalchemy/branches/nosetests/lib/sqlalchemy/test/engines.py (rev 0)
+++ sqlalchemy/branches/nosetests/lib/sqlalchemy/test/engines.py 2009-06-07 22:09:07 UTC (rev 6027)
</span><span class="lines">@@ -0,0 +1,245 @@
</span><ins>+import sys, types, weakref
+from collections import deque
+import config
+from compat import _function_named, callable
+
+class ConnectionKiller(object):
+ def __init__(self):
+ self.proxy_refs = weakref.WeakKeyDictionary()
+
+ def checkout(self, dbapi_con, con_record, con_proxy):
+ self.proxy_refs[con_proxy] = True
+
+ def _apply_all(self, methods):
+ for rec in self.proxy_refs:
+ if rec is not None and rec.is_valid:
+ try:
+ for name in methods:
+ if callable(name):
+ name(rec)
+ else:
+ getattr(rec, name)()
+ except (SystemExit, KeyboardInterrupt):
+ raise
+ except Exception, e:
+ # fixme
+ sys.stderr.write("\n" + str(e) + "\n")
+
+ def rollback_all(self):
+ self._apply_all(('rollback',))
+
+ def close_all(self):
+ self._apply_all(('rollback', 'close'))
+
+ def assert_all_closed(self):
+ for rec in self.proxy_refs:
+ if rec.is_valid:
+ assert False
+
+testing_reaper = ConnectionKiller()
+
+def assert_conns_closed(fn):
+ def decorated(*args, **kw):
+ try:
+ fn(*args, **kw)
+ finally:
+ testing_reaper.assert_all_closed()
+ return _function_named(decorated, fn.__name__)
+
+def rollback_open_connections(fn):
+ """Decorator that rolls back all open connections after fn execution."""
+
+ def decorated(*args, **kw):
+ try:
+ fn(*args, **kw)
+ finally:
+ testing_reaper.rollback_all()
+ return _function_named(decorated, fn.__name__)
+
+def close_open_connections(fn):
+ """Decorator that closes all connections after fn execution."""
+
+ def decorated(*args, **kw):
+ try:
+ fn(*args, **kw)
+ finally:
+ testing_reaper.close_all()
+ return _function_named(decorated, fn.__name__)
+
+def all_dialects():
+ import sqlalchemy.databases as d
+ for name in d.__all__:
+ mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name)
+ yield mod.dialect()
+
+class ReconnectFixture(object):
+ def __init__(self, dbapi):
+ self.dbapi = dbapi
+ self.connections = []
+
+ def __getattr__(self, key):
+ return getattr(self.dbapi, key)
+
+ def connect(self, *args, **kwargs):
+ conn = self.dbapi.connect(*args, **kwargs)
+ self.connections.append(conn)
+ return conn
+
+ def shutdown(self):
+ for c in list(self.connections):
+ c.close()
+ self.connections = []
+
+def reconnecting_engine(url=None, options=None):
+ url = url or config.db_url
+ dbapi = config.db.dialect.dbapi
+ if not options:
+ options = {}
+ options['module'] = ReconnectFixture(dbapi)
+ engine = testing_engine(url, options)
+ engine.test_shutdown = engine.dialect.dbapi.shutdown
+ return engine
+
+def testing_engine(url=None, options=None):
+ """Produce an engine configured by --options with optional overrides."""
+
+ from sqlalchemy import create_engine
+ from sqlalchemy.test.assertsql import asserter
+
+ url = url or config.db_url
+ options = options or config.db_opts
+
+ options.setdefault('proxy', asserter)
+
+ listeners = options.setdefault('listeners', [])
+ listeners.append(testing_reaper)
+
+ engine = create_engine(url, **options)
+
+ return engine
+
+def utf8_engine(url=None, options=None):
+ """Hook for dialects or drivers that don't handle utf8 by default."""
+
+ from sqlalchemy.engine import url as engine_url
+
+ if config.db.name == 'mysql':
+ dbapi_ver = config.db.dialect.dbapi.version_info
+ if (dbapi_ver &lt; (1, 2, 1) or
+ dbapi_ver in ((1, 2, 1, 'gamma', 1), (1, 2, 1, 'gamma', 2),
+ (1, 2, 1, 'gamma', 3), (1, 2, 1, 'gamma', 5))):
+ raise RuntimeError('Character set support unavailable with this '
+ 'driver version: %s' % repr(dbapi_ver))
+ else:
+ url = url or config.db_url
+ url = engine_url.make_url(url)
+ url.query['charset'] = 'utf8'
+ url.query['use_unicode'] = '0'
+ url = str(url)
+
+ return testing_engine(url, options)
+
+def mock_engine(db=None):
+ """Provides a mocking engine based on the current testing.db."""
+
+ from sqlalchemy import create_engine
+
+ dbi = db or config.db
+ buffer = []
+ def executor(sql, *a, **kw):
+ buffer.append(sql)
+ engine = create_engine(dbi.name + '://',
+ strategy='mock', executor=executor)
+ assert not hasattr(engine, 'mock')
+ engine.mock = buffer
+ return engine
+
+class ReplayableSession(object):
+ """A simple record/playback tool.
+
+ This is *not* a mock testing class. It only records a session for later
+ playback and makes no assertions on call consistency whatsoever. It's
+ unlikely to be suitable for anything other than DB-API recording.
+
+ """
+
+ Callable = object()
+ NoAttribute = object()
+ Natives = set([getattr(types, t)
+ for t in dir(types) if not t.startswith('_')]). \
+ difference([getattr(types, t)
+ for t in ('FunctionType', 'BuiltinFunctionType',
+ 'MethodType', 'BuiltinMethodType',
+ 'LambdaType', 'UnboundMethodType',)])
+ def __init__(self):
+ self.buffer = deque()
+
+ def recorder(self, base):
+ return self.Recorder(self.buffer, base)
+
+ def player(self):
+ return self.Player(self.buffer)
+
+ class Recorder(object):
+ def __init__(self, buffer, subject):
+ self._buffer = buffer
+ self._subject = subject
+
+ def __call__(self, *args, **kw):
+ subject, buffer = [object.__getattribute__(self, x)
+ for x in ('_subject', '_buffer')]
+
+ result = subject(*args, **kw)
+ if type(result) not in ReplayableSession.Natives:
+ buffer.append(ReplayableSession.Callable)
+ return type(self)(buffer, result)
+ else:
+ buffer.append(result)
+ return result
+
+ def __getattribute__(self, key):
+ try:
+ return object.__getattribute__(self, key)
+ except AttributeError:
+ pass
+
+ subject, buffer = [object.__getattribute__(self, x)
+ for x in ('_subject', '_buffer')]
+ try:
+ result = type(subject).__getattribute__(subject, key)
+ except AttributeError:
+ buffer.append(ReplayableSession.NoAttribute)
+ raise
+ else:
+ if type(result) not in ReplayableSession.Natives:
+ buffer.append(ReplayableSession.Callable)
+ return type(self)(buffer, result)
+ else:
+ buffer.append(result)
+ return result
+
+ class Player(object):
+ def __init__(self, buffer):
+ self._buffer = buffer
+
+ def __call__(self, *args, **kw):
+ buffer = object.__getattribute__(self, '_buffer')
+ result = buffer.popleft()
+ if result is ReplayableSession.Callable:
+ return self
+ else:
+ return result
+
+ def __getattribute__(self, key):
+ try:
+ return object.__getattribute__(self, key)
+ except AttributeError:
+ pass
+ buffer = object.__getattribute__(self, '_buffer')
+ result = buffer.popleft()
+ if result is ReplayableSession.Callable:
+ return self
+ elif result is ReplayableSession.NoAttribute:
+ raise AttributeError(key)
+ else:
+ return result
</ins></span></pre></div>
<a id="sqlalchemybranchesnosetestslibsqlalchemytestnosepluginpy"></a>
<div class="addfile"><h4>Added: sqlalchemy/branches/nosetests/lib/sqlalchemy/test/noseplugin.py (0 => 6027)</h4>
<pre class="diff"><span>
<span class="info">--- sqlalchemy/branches/nosetests/lib/sqlalchemy/test/noseplugin.py (rev 0)
+++ sqlalchemy/branches/nosetests/lib/sqlalchemy/test/noseplugin.py 2009-06-07 22:09:07 UTC (rev 6027)
</span><span class="lines">@@ -0,0 +1,160 @@
</span><ins>+import logging
+import os
+import re
+import sys
+import time
+import warnings
+import ConfigParser
+import StringIO
+from config import db, db_label, db_url, file_config, base_config, \
+ post_configure, \
+ _list_dbs, _server_side_cursors, _engine_strategy, \
+ _engine_uri, _require, _engine_pool, \
+ _create_testing_engine, _prep_testing_database, \
+ _set_table_options, _reverse_topological, _log
+import testing, config, requires
+from nose.plugins import Plugin
+from nose.util import tolist
+import nose.case
+
+log = logging.getLogger('nose.plugins.sqlalchemy')
+requires = None
+
+class NoseSQLAlchemy(Plugin):
+ """
+ Handles the setup and extra properties required for testing SQLAlchemy
+ """
+ enabled = True
+ name = 'sqlalchemy'
+ score = 100
+
+ def options(self, parser, env=os.environ):
+ Plugin.options(self, parser, env)
+ opt = parser.add_option
+ #opt("--verbose", action="store_true", dest="verbose",
+ #help="enable stdout echoing/printing")
+ #opt("--quiet", action="store_true", dest="quiet", help="suppress output")
+ opt("--log-info", action="callback", type="string", callback=_log,
+ help="turn on info logging for &lt;LOG&gt; (multiple OK)")
+ opt("--log-debug", action="callback", type="string", callback=_log,
+ help="turn on debug logging for &lt;LOG&gt; (multiple OK)")
+ opt("--require", action="append", dest="require", default=[],
+ help="require a particular driver or module version (multiple OK)")
+ opt("--db", action="store", dest="db", default="sqlite",
+ help="Use prefab database uri")
+ opt('--dbs', action='callback', callback=_list_dbs,
+ help="List available prefab dbs")
+ opt("--dburi", action="store", dest="dburi",
+ help="Database uri (overrides --db)")
+ opt("--dropfirst", action="store_true", dest="dropfirst",
+ help="Drop all tables in the target database first (use with caution on Oracle, MS-SQL)")
+ opt("--mockpool", action="store_true", dest="mockpool",
+ help="Use mock pool (asserts only one connection used)")
+ opt("--enginestrategy", action="callback", type="string",
+ callback=_engine_strategy,
+ help="Engine strategy (plain or threadlocal, defaults to plain)")
+ opt("--reversetop", action="store_true", dest="reversetop", default=False,
+ help="Reverse the collection ordering for topological sorts (helps "
+ "reveal dependency issues)")
+ opt("--unhashable", action="store_true", dest="unhashable", default=False,
+ help="Disallow SQLAlchemy from performing a hash() on mapped test objects.")
+ opt("--noncomparable", action="store_true", dest="noncomparable", default=False,
+ help="Disallow SQLAlchemy from performing == on mapped test objects.")
+ opt("--truthless", action="store_true", dest="truthless", default=False,
+ help="Disallow SQLAlchemy from truth-evaluating mapped test objects.")
+ opt("--serverside", action="callback", callback=_server_side_cursors,
+ help="Turn on server side cursors for PG")
+ opt("--mysql-engine", action="store", dest="mysql_engine", default=None,
+ help="Use the specified MySQL storage engine for all tables, default is "
+ "a db-default/InnoDB combo.")
+ opt("--table-option", action="append", dest="tableopts", default=[],
+ help="Add a dialect-specific table option, key=value")
+
+ def configure(self, options, conf):
+ Plugin.configure(self, options, conf)
+# sys.path.insert(0, os.path.join(os.getcwd(), '../lib'))
+
+ #conf.exclude = map(re.compile, tolist(r'^(manage\.py|.*settings\.py|apps)$'))
+
+ #global options#, config
+ global file_config#, getopts_options
+ file_config = ConfigParser.ConfigParser()
+ file_config.readfp(StringIO.StringIO(base_config))
+ file_config.read(['test.cfg', os.path.expanduser('~/.satest.cfg')])
+
+ import testing, requires
+ testing.db = db
+ testing.requires = requires
+
+ # Lazy setup of other options (post coverage)
+ for fn in post_configure:
+ fn(options, file_config)
+
+ def describeTest(self, test):
+ return ""
+
+ def wantClass(self, cls):
+ """Return true if you want the main test selector to collect
+ tests from this class, false if you don't, and None if you don't
+ care.
+
+ :Parameters:
+ cls : class
+ The class being examined by the selector
+
+ """
+
+ if re.search(r'^Test|Test$', cls.__name__) is None:
+ return True
+ else:
+ if (hasattr(cls, '__whitelist__') and
+ testing.db.name in cls.__whitelist__):
+ return True
+ else:
+ if self.__should_skip_for(cls):
+ return False
+ else:
+ return True
+
+ def __should_skip_for(self, cls):
+ if hasattr(cls, '__requires__'):
+ def test_suite(): return 'ok'
+ for requirement in cls.__requires__:
+ check = getattr(requires, requirement)
+ if check(test_suite)() != 'ok':
+ # The requirement will perform messaging.
+ return True
+ if (hasattr(cls, '__unsupported_on__') and
+ testing.db.name in cls.__unsupported_on__):
+ print "'%s' unsupported on DB implementation '%s'" % (
+ cls.__class__.__name__, testing.db.name)
+ return True
+ if (getattr(cls, '__only_on__', None) not in (None, testing.db.name)):
+ print "'%s' unsupported on DB implementation '%s'" % (
+ cls.__class__.__name__, testing.db.name)
+ return True
+ if (getattr(cls, '__skip_if__', False)):
+ for c in getattr(cls, '__skip_if__'):
+ if c():
+ print "'%s' skipped by %s" % (
+ cls.__class__.__name__, c.__name__)
+ return True
+ for rule in getattr(cls, '__excluded_on__', ()):
+ if testing._is_excluded(*rule):
+ print "'%s' unsupported on DB %s version %s" % (
+ cls.__class__.__name__, testing.db.name,
+ _server_version())
+ return True
+ return False
+
+ #def begin(self):
+ #pass
+
+ #def beforeTest(self, test):
+ #pass
+
+ #def handleError(self, test, err):
+ #pass
+
+ #def finalize(self, result=None):
+ #pass
</ins></span></pre></div>
<a id="sqlalchemybranchesnosetestslibsqlalchemytestormpy"></a>
<div class="addfile"><h4>Added: sqlalchemy/branches/nosetests/lib/sqlalchemy/test/orm.py (0 => 6027)</h4>
<pre class="diff"><span>
<span class="info">--- sqlalchemy/branches/nosetests/lib/sqlalchemy/test/orm.py (rev 0)
+++ sqlalchemy/branches/nosetests/lib/sqlalchemy/test/orm.py 2009-06-07 22:09:07 UTC (rev 6027)
</span><span class="lines">@@ -0,0 +1,117 @@
</span><ins>+import inspect, re
+import config, testing
+
+sa = None
+orm = None
+
+__all__ = 'mapper',
+
+
+_whitespace = re.compile(r'^(\s+)')
+
+def _find_pragma(lines, current):
+ m = _whitespace.match(lines[current])
+ basis = m and m.group() or ''
+
+ for line in reversed(lines[0:current]):
+ if 'testlib.pragma' in line:
+ return line
+ m = _whitespace.match(line)
+ indent = m and m.group() or ''
+
+ # simplistic detection:
+
+ # >> # testlib.pragma foo
+ # >> center_line()
+ if indent == basis:
+ break
+ # >> # testlib.pragma foo
+ # >> if fleem:
+ # >> center_line()
+ if line.endswith(':'):
+ break
+ return None
+
+def _make_blocker(method_name, fallback):
+ """Creates tripwired variant of a method, raising when called.
+
+ To excempt an invocation from blockage, there are two options.
+
+ 1) add a pragma in a comment::
+
+ # testlib.pragma exempt:methodname
+ offending_line()
+
+ 2) add a magic cookie to the function's namespace::
+ __sa_baremethodname_exempt__ = True
+ ...
+ offending_line()
+ another_offending_lines()
+
+ The second is useful for testing and development.
+ """
+
+ if method_name.startswith('__') and method_name.endswith('__'):
+ frame_marker = '__sa_%s_exempt__' % method_name[2:-2]
+ else:
+ frame_marker = '__sa_%s_exempt__' % method_name
+ pragma_marker = 'exempt:' + method_name
+
+ def method(self, *args, **kw):
+ frame_r = None
+ try:
+ frame = inspect.stack()[1][0]
+ frame_r = inspect.getframeinfo(frame, 9)
+
+ module = frame.f_globals.get('__name__', '')
+
+ type_ = type(self)
+
+ pragma = _find_pragma(*frame_r[3:5])
+
+ exempt = (
+ (not module.startswith('sqlalchemy')) or
+ (pragma and pragma_marker in pragma) or
+ (frame_marker in frame.f_locals) or
+ ('self' in frame.f_locals and
+ getattr(frame.f_locals['self'], frame_marker, False)))
+
+ if exempt:
+ supermeth = getattr(super(type_, self), method_name, None)
+ if (supermeth is None or
+ getattr(supermeth, 'im_func', None) is method):
+ return fallback(self, *args, **kw)
+ else:
+ return supermeth(*args, **kw)
+ else:
+ raise AssertionError(
+ "%s.%s called in %s, line %s in %s" % (
+ type_.__name__, method_name, module, frame_r[1], frame_r[2]))
+ finally:
+ del frame
+ method.__name__ = method_name
+ return method
+
+def mapper(type_, *args, **kw):
+ global orm
+ if orm is None:
+ from sqlalchemy import orm
+
+ forbidden = [
+ ('__hash__', 'unhashable', lambda s: id(s)),
+ ('__eq__', 'noncomparable', lambda s, o: s is o),
+ ('__ne__', 'noncomparable', lambda s, o: s is not o),
+ ('__cmp__', 'noncomparable', lambda s, o: object.__cmp__(s, o)),
+ ('__le__', 'noncomparable', lambda s, o: object.__le__(s, o)),
+ ('__lt__', 'noncomparable', lambda s, o: object.__lt__(s, o)),
+ ('__ge__', 'noncomparable', lambda s, o: object.__ge__(s, o)),
+ ('__gt__', 'noncomparable', lambda s, o: object.__gt__(s, o)),
+ ('__nonzero__', 'truthless', lambda s: 1), ]
+
+ if isinstance(type_, type) and type_.__bases__ == (object,):
+ for method_name, option, fallback in forbidden:
+ if (getattr(config.options, option, False) and
+ method_name not in type_.__dict__):
+ setattr(type_, method_name, _make_blocker(method_name, fallback))
+
+ return orm.mapper(type_, *args, **kw)
</ins></span></pre></div>
<a id="sqlalchemybranchesnosetestslibsqlalchemytestprofilingpy"></a>
<div class="addfile"><h4>Added: sqlalchemy/branches/nosetests/lib/sqlalchemy/test/profiling.py (0 => 6027)</h4>
<pre class="diff"><span>
<span class="info">--- sqlalchemy/branches/nosetests/lib/sqlalchemy/test/profiling.py (rev 0)
+++ sqlalchemy/branches/nosetests/lib/sqlalchemy/test/profiling.py 2009-06-07 22:09:07 UTC (rev 6027)
</span><span class="lines">@@ -0,0 +1,202 @@
</span><ins>+"""Profiling support for unit and performance tests."""
+
+import os, sys
+from compat import _function_named, gc_collect
+import config
+
+__all__ = 'profiled', 'function_call_count', 'conditional_call_count'
+
+all_targets = set()
+profile_config = { 'targets': set(),
+ 'report': True,
+ 'sort': ('time', 'calls'),
+ 'limit': None }
+profiler = None
+
+def profiled(target=None, **target_opts):
+ """Optional function profiling.
+
+ @profiled('label')
+ or
+ @profiled('label', report=True, sort=('calls',), limit=20)
+
+ Enables profiling for a function when 'label' is targetted for
+ profiling. Report options can be supplied, and override the global
+ configuration and command-line options.
+ """
+
+ # manual or automatic namespacing by module would remove conflict issues
+ if target is None:
+ target = 'anonymous_target'
+ elif target in all_targets:
+ print "Warning: redefining profile target '%s'" % target
+ all_targets.add(target)
+
+ filename = "%s.prof" % target
+
+ def decorator(fn):
+ def profiled(*args, **kw):
+ if (target not in profile_config['targets'] and
+ not target_opts.get('always', None)):
+ return fn(*args, **kw)
+
+ elapsed, load_stats, result = _profile(
+ filename, fn, *args, **kw)
+
+ report = target_opts.get('report', profile_config['report'])
+ if report:
+ sort_ = target_opts.get('sort', profile_config['sort'])
+ limit = target_opts.get('limit', profile_config['limit'])
+ print "Profile report for target '%s' (%s)" % (
+ target, filename)
+
+ stats = load_stats()
+ stats.sort_stats(*sort_)
+ if limit:
+ stats.print_stats(limit)
+ else:
+ stats.print_stats()
+ #stats.print_callers()
+ os.unlink(filename)
+ return result
+ return _function_named(profiled, fn.__name__)
+ return decorator
+
+def function_call_count(count=None, versions={}, variance=0.05):
+ """Assert a target for a test case's function call count.
+
+ count
+ Optional, general target function call count.
+
+ versions
+ Optional, a dictionary of Python version strings to counts,
+ for example::
+
+ { '2.5.1': 110,
+ '2.5': 100,
+ '2.4': 150 }
+
+ The best match for the current running python will be used.
+ If none match, 'count' will be used as the fallback.
+
+ variance
+ An +/- deviation percentage, defaults to 5%.
+ """
+
+ # this could easily dump the profile report if --verbose is in effect
+
+ version_info = list(sys.version_info)
+ py_version = '.'.join([str(v) for v in sys.version_info])
+
+ while version_info:
+ version = '.'.join([str(v) for v in version_info])
+ if version in versions:
+ count = versions[version]
+ break
+ version_info.pop()
+
+ if count is None:
+ return lambda fn: fn
+
+ def decorator(fn):
+ def counted(*args, **kw):
+ try:
+ filename = "%s.prof" % fn.__name__
+
+ elapsed, stat_loader, result = _profile(
+ filename, fn, *args, **kw)
+
+ stats = stat_loader()
+ calls = stats.total_calls
+
+ stats.sort_stats('calls', 'cumulative')
+ stats.print_stats()
+ #stats.print_callers()
+ deviance = int(count * variance)
+ if (calls < (count - deviance) or
+ calls > (count + deviance)):
+ raise AssertionError(
+ "Function call count %s not within %s%% "
+ "of expected %s. (Python version %s)" % (
+ calls, (variance * 100), count, py_version))
+
+ return result
+ finally:
+ if os.path.exists(filename):
+ os.unlink(filename)
+ return _function_named(counted, fn.__name__)
+ return decorator
+
+def conditional_call_count(discriminator, categories):
+ """Apply a function call count conditionally at runtime.
+
+ Takes two arguments, a callable that returns a key value, and a dict
+ mapping key values to a tuple of arguments to function_call_count.
+
+ The callable is not evaluated until the decorated function is actually
+ invoked. If the `discriminator` returns a key not present in the
+ `categories` dictionary, no call count assertion is applied.
+
+ Useful for integration tests, where running a named test in isolation may
+ have a function count penalty not seen in the full suite, due to lazy
+ initialization in the DB-API, SA, etc.
+ """
+
+ def decorator(fn):
+ def at_runtime(*args, **kw):
+ criteria = categories.get(discriminator(), None)
+ if criteria is None:
+ return fn(*args, **kw)
+
+ rewrapped = function_call_count(*criteria)(fn)
+ return rewrapped(*args, **kw)
+ return _function_named(at_runtime, fn.__name__)
+ return decorator
+
+
+def _profile(filename, fn, *args, **kw):
+ global profiler
+ if not profiler:
+ profiler = 'hotshot'
+ if sys.version_info > (2, 5):
+ try:
+ import cProfile
+ profiler = 'cProfile'
+ except ImportError:
+ pass
+
+ if profiler == 'cProfile':
+ return _profile_cProfile(filename, fn, *args, **kw)
+ else:
+ return _profile_hotshot(filename, fn, *args, **kw)
+
+def _profile_cProfile(filename, fn, *args, **kw):
+ import cProfile, gc, pstats, time
+
+ load_stats = lambda: pstats.Stats(filename)
+ gc_collect()
+
+ began = time.time()
+ cProfile.runctx('result = fn(*args, **kw)', globals(), locals(),
+ filename=filename)
+ ended = time.time()
+
+ return ended - began, load_stats, locals()['result']
+
+def _profile_hotshot(filename, fn, *args, **kw):
+ import gc, hotshot, hotshot.stats, time
+ load_stats = lambda: hotshot.stats.load(filename)
+
+ gc_collect()
+ prof = hotshot.Profile(filename)
+ began = time.time()
+ prof.start()
+ try:
+ result = fn(*args, **kw)
+ finally:
+ prof.stop()
+ ended = time.time()
+ prof.close()
+
+ return ended - began, load_stats, result
+
</ins></span></pre></div>
<a id="sqlalchemybranchesnosetestslibsqlalchemytestrequirespy"></a>
<div class="addfile"><h4>Added: sqlalchemy/branches/nosetests/lib/sqlalchemy/test/requires.py (0 => 6027)</h4>
<pre class="diff"><span>
<span class="info">--- sqlalchemy/branches/nosetests/lib/sqlalchemy/test/requires.py (rev 0)
+++ sqlalchemy/branches/nosetests/lib/sqlalchemy/test/requires.py 2009-06-07 22:09:07 UTC (rev 6027)
</span><span class="lines">@@ -0,0 +1,127 @@
</span><ins>+"""Global database feature support policy.
+
+Provides decorators to mark tests requiring specific feature support from the
+target database.
+
+"""
+
+from testing import \
+ _block_unconditionally as no_support, \
+ _chain_decorators_on, \
+ exclude, \
+ emits_warning_on
+
+
+def deferrable_constraints(fn):
+ """Target database must support derferable constraints."""
+ return _chain_decorators_on(
+ fn,
+ no_support('firebird', 'not supported by database'),
+ no_support('mysql', 'not supported by database'),
+ no_support('mssql', 'not supported by database'),
+ )
+
+def foreign_keys(fn):
+ """Target database must support foreign keys."""
+ return _chain_decorators_on(
+ fn,
+ no_support('sqlite', 'not supported by database'),
+ )
+
+def identity(fn):
+ """Target database must support GENERATED AS IDENTITY or a facsimile.
+
+ Includes GENERATED AS IDENTITY, AUTOINCREMENT, AUTO_INCREMENT, or other
+ column DDL feature that fills in a DB-generated identifier at INSERT-time
+ without requiring pre-execution of a SEQUENCE or other artifact.
+
+ """
+ return _chain_decorators_on(
+ fn,
+ no_support('firebird', 'not supported by database'),
+ no_support('oracle', 'not supported by database'),
+ no_support('postgres', 'not supported by database'),
+ no_support('sybase', 'not supported by database'),
+ )
+
+def independent_connections(fn):
+ """Target must support simultaneous, independent database connections."""
+
+ # This is also true of some configurations of UnixODBC and probably win32
+ # ODBC as well.
+ return _chain_decorators_on(
+ fn,
+ no_support('sqlite', 'no driver support')
+ )
+
+def row_triggers(fn):
+ """Target must support standard statement-running EACH ROW triggers."""
+ return _chain_decorators_on(
+ fn,
+ # no access to same table
+ no_support('mysql', 'requires SUPER priv'),
+ exclude('mysql', '<', (5, 0, 10), 'not supported by database'),
+ no_support('postgres', 'not supported by database: no statements'),
+ )
+
+def savepoints(fn):
+ """Target database must support savepoints."""
+ return _chain_decorators_on(
+ fn,
+ emits_warning_on('mssql', 'Savepoint support in mssql is experimental and may lead to data loss.'),
+ no_support('access', 'not supported by database'),
+ no_support('sqlite', 'not supported by database'),
+ no_support('sybase', 'FIXME: guessing, needs confirmation'),
+ exclude('mysql', '<', (5, 0, 3), 'not supported by database'),
+ )
+
+def sequences(fn):
+ """Target database must support SEQUENCEs."""
+ return _chain_decorators_on(
+ fn,
+ no_support('access', 'no SEQUENCE support'),
+ no_support('mssql', 'no SEQUENCE support'),
+ no_support('mysql', 'no SEQUENCE support'),
+ no_support('sqlite', 'no SEQUENCE support'),
+ no_support('sybase', 'no SEQUENCE support'),
+ )
+
+def subqueries(fn):
+ """Target database must support subqueries."""
+ return _chain_decorators_on(
+ fn,
+ exclude('mysql', '<', (4, 1, 1), 'no subquery support'),
+ )
+
+def two_phase_transactions(fn):
+ """Target database must support two-phase transactions."""
+ return _chain_decorators_on(
+ fn,
+ no_support('access', 'not supported by database'),
+ no_support('firebird', 'no SA implementation'),
+ no_support('maxdb', 'not supported by database'),
+ no_support('mssql', 'FIXME: guessing, needs confirmation'),
+ no_support('oracle', 'no SA implementation'),
+ no_support('sqlite', 'not supported by database'),
+ no_support('sybase', 'FIXME: guessing, needs confirmation'),
+ exclude('mysql', '<', (5, 0, 3), 'not supported by database'),
+ )
+
+def unicode_connections(fn):
+ """Target driver must support some encoding of Unicode across the wire."""
+ # TODO: expand to exclude MySQLdb versions w/ broken unicode
+ return _chain_decorators_on(
+ fn,
+ exclude('mysql', '<', (4, 1, 1), 'no unicode connection support'),
+ )
+
+def unicode_ddl(fn):
+ """Target driver must support some encoding of Unicode across the wire."""
+ # TODO: expand to exclude MySQLdb versions w/ broken unicode
+ return _chain_decorators_on(
+ fn,
+ no_support('maxdb', 'database support flakey'),
+ no_support('oracle', 'FIXME: no support in database?'),
+ no_support('sybase', 'FIXME: guessing, needs confirmation'),
+ exclude('mysql', '<', (4, 1, 1), 'no unicode connection support'),
+ )
</ins></span></pre></div>
<a id="sqlalchemybranchesnosetestslibsqlalchemytestschemapy"></a>
<div class="addfile"><h4>Added: sqlalchemy/branches/nosetests/lib/sqlalchemy/test/schema.py (0 => 6027)</h4>
<pre class="diff"><span>
<span class="info">--- sqlalchemy/branches/nosetests/lib/sqlalchemy/test/schema.py (rev 0)
+++ sqlalchemy/branches/nosetests/lib/sqlalchemy/test/schema.py 2009-06-07 22:09:07 UTC (rev 6027)
</span><span class="lines">@@ -0,0 +1,79 @@
</span><ins>+import testing
+
+schema = None
+
+__all__ = 'Table', 'Column',
+
+table_options = {}
+
+def Table(*args, **kw):
+ """A schema.Table wrapper/hook for dialect-specific tweaks."""
+
+ global schema
+ if schema is None:
+ from sqlalchemy import schema
+
+ test_opts = dict([(k,kw.pop(k)) for k in kw.keys()
+ if k.startswith('test_')])
+
+ kw.update(table_options)
+
+ if testing.against('mysql'):
+ if 'mysql_engine' not in kw and 'mysql_type' not in kw:
+ if 'test_needs_fk' in test_opts or 'test_needs_acid' in test_opts:
+ kw['mysql_engine'] = 'InnoDB'
+
+ # Apply some default cascading rules for self-referential foreign keys.
+ # MySQL InnoDB has some issues around seleting self-refs too.
+ if testing.against('firebird'):
+ table_name = args[0]
+ unpack = (testing.config.db.dialect.
+ identifier_preparer.unformat_identifiers)
+
+ # Only going after ForeignKeys in Columns. May need to
+ # expand to ForeignKeyConstraint too.
+ fks = [fk
+ for col in args if isinstance(col, schema.Column)
+ for fk in col.args if isinstance(fk, schema.ForeignKey)]
+
+ for fk in fks:
+ # root around in raw spec
+ ref = fk._colspec
+ if isinstance(ref, schema.Column):
+ name = ref.table.name
+ else:
+ # take just the table name: on FB there cannot be
+ # a schema, so the first element is always the
+ # table name, possibly followed by the field name
+ name = unpack(ref)[0]
+ if name == table_name:
+ if fk.ondelete is None:
+ fk.ondelete = 'CASCADE'
+ if fk.onupdate is None:
+ fk.onupdate = 'CASCADE'
+
+ if testing.against('firebird', 'oracle'):
+ pk_seqs = [col for col in args
+ if (isinstance(col, schema.Column)
+ and col.primary_key
+ and getattr(col, '_needs_autoincrement', False))]
+ for c in pk_seqs:
+ c.args.append(schema.Sequence(args[0] + '_' + c.name + '_seq', optional=True))
+ return schema.Table(*args, **kw)
+
+
+def Column(*args, **kw):
+ """A schema.Column wrapper/hook for dialect-specific tweaks."""
+
+ global schema
+ if schema is None:
+ from sqlalchemy import schema
+
+ test_opts = dict([(k,kw.pop(k)) for k in kw.keys()
+ if k.startswith('test_')])
+
+ c = schema.Column(*args, **kw)
+ if testing.against('firebird', 'oracle'):
+ if 'test_needs_autoincrement' in test_opts:
+ c._needs_autoincrement = True
+ return c
</ins></span></pre></div>
<a id="sqlalchemybranchesnosetestslibsqlalchemytesttestingpy"></a>
<div class="addfile"><h4>Added: sqlalchemy/branches/nosetests/lib/sqlalchemy/test/testing.py (0 => 6027)</h4>
<pre class="diff"><span>
<span class="info">--- sqlalchemy/branches/nosetests/lib/sqlalchemy/test/testing.py (rev 0)
+++ sqlalchemy/branches/nosetests/lib/sqlalchemy/test/testing.py 2009-06-07 22:09:07 UTC (rev 6027)
</span><span class="lines">@@ -0,0 +1,742 @@
</span><ins>+"""TestCase and TestSuite artifacts and testing decorators."""
+
+import itertools
+import operator
+import re
+import sys
+import types
+import warnings
+from cStringIO import StringIO
+
+import config
+from compat import _function_named, callable
+
+# Delayed imports
+MetaData = None
+Session = None
+clear_mappers = None
+sa_exc = None
+schema = None
+sqltypes = None
+util = None
+
+
+_ops = { '<': operator.lt,
+ '>': operator.gt,
+ '==': operator.eq,
+ '!=': operator.ne,
+ '<=': operator.le,
+ '>=': operator.ge,
+ 'in': operator.contains,
+ 'between': lambda val, pair: val >= pair[0] and val <= pair[1],
+ }
+
+# sugar ('testing.db'); set here by config() at runtime
+db = None
+
+# more sugar, installed by __init__
+requires = None
+
+def fails_if(callable_):
+ """Mark a test as expected to fail if callable_ returns True.
+
+ If the callable returns false, the test is run and reported as normal.
+ However if the callable returns true, the test is expected to fail and the
+ unit test logic is inverted: if the test fails, a success is reported. If
+ the test succeeds, a failure is reported.
+ """
+
+ docstring = getattr(callable_, '__doc__', None) or callable_.__name__
+ description = docstring.split('\n')[0]
+
+ def decorate(fn):
+ fn_name = fn.__name__
+ def maybe(*args, **kw):
+ if not callable_():
+ return fn(*args, **kw)
+ else:
+ try:
+ fn(*args, **kw)
+ except Exception, ex:
+ print ("'%s' failed as expected (condition: %s): %s " % (
+ fn_name, description, str(ex)))
+ return True
+ else:
+ raise AssertionError(
+ "Unexpected success for '%s' (condition: %s)" %
+ (fn_name, description))
+ return _function_named(maybe, fn_name)
+ return decorate
+
+
+def future(fn):
+ """Mark a test as expected to unconditionally fail.
+
+ Takes no arguments, omit parens when using as a decorator.
+ """
+
+ fn_name = fn.__name__
+ def decorated(*args, **kw):
+ try:
+ fn(*args, **kw)
+ except Exception, ex:
+ print ("Future test '%s' failed as expected: %s " % (
+ fn_name, str(ex)))
+ return True
+ else:
+ raise AssertionError(
+ "Unexpected success for future test '%s'" % fn_name)
+ return _function_named(decorated, fn_name)
+
+def fails_on(dbs, reason):
+ """Mark a test as expected to fail on the specified database
+ implementation.
+
+ Unlike ``crashes``, tests marked as ``fails_on`` will be run
+ for the named databases. The test is expected to fail and the unit test
+ logic is inverted: if the test fails, a success is reported. If the test
+ succeeds, a failure is reported.
+ """
+
+ def decorate(fn):
+ fn_name = fn.__name__
+ def maybe(*args, **kw):
+ if config.db.name != dbs:
+ return fn(*args, **kw)
+ else:
+ try:
+ fn(*args, **kw)
+ except Exception, ex:
+ print ("'%s' failed as expected on DB implementation "
+ "'%s': %s" % (
+ fn_name, config.db.name, reason))
+ return True
+ else:
+ raise AssertionError(
+ "Unexpected success for '%s' on DB implementation '%s'" %
+ (fn_name, config.db.name))
+ return _function_named(maybe, fn_name)
+ return decorate
+
+def fails_on_everything_except(*dbs):
+ """Mark a test as expected to fail on most database implementations.
+
+ Like ``fails_on``, except failure is the expected outcome on all
+ databases except those listed.
+ """
+
+ def decorate(fn):
+ fn_name = fn.__name__
+ def maybe(*args, **kw):
+ if config.db.name in dbs:
+ return fn(*args, **kw)
+ else:
+ try:
+ fn(*args, **kw)
+ except Exception, ex:
+ print ("'%s' failed as expected on DB implementation "
+ "'%s': %s" % (
+ fn_name, config.db.name, str(ex)))
+ return True
+ else:
+ raise AssertionError(
+ "Unexpected success for '%s' on DB implementation '%s'" %
+ (fn_name, config.db.name))
+ return _function_named(maybe, fn_name)
+ return decorate
+
+def crashes(db, reason):
+ """Mark a test as unsupported by a database implementation.
+
+ ``crashes`` tests will be skipped unconditionally. Use for feature tests
+ that cause deadlocks or other fatal problems.
+
+ """
+ carp = _should_carp_about_exclusion(reason)
+ def decorate(fn):
+ fn_name = fn.__name__
+ def maybe(*args, **kw):
+ if config.db.name == db:
+ msg = "'%s' unsupported on DB implementation '%s': %s" % (
+ fn_name, config.db.name, reason)
+ print msg
+ if carp:
+ print >> sys.stderr, msg
+ return True
+ else:
+ return fn(*args, **kw)
+ return _function_named(maybe, fn_name)
+ return decorate
+
+def _block_unconditionally(db, reason):
+ """Mark a test as unsupported by a database implementation.
+
+ Will never run the test against any version of the given database, ever,
+ no matter what. Use when your assumptions are infallible; past, present
+ and future.
+
+ """
+ carp = _should_carp_about_exclusion(reason)
+ def decorate(fn):
+ fn_name = fn.__name__
+ def maybe(*args, **kw):
+ if config.db.name == db:
+ msg = "'%s' unsupported on DB implementation '%s': %s" % (
+ fn_name, config.db.name, reason)
+ print msg
+ if carp:
+ print >> sys.stderr, msg
+ return True
+ else:
+ return fn(*args, **kw)
+ return _function_named(maybe, fn_name)
+ return decorate
+
+
+def exclude(db, op, spec, reason):
+ """Mark a test as unsupported by specific database server versions.
+
+ Stackable, both with other excludes and other decorators. Examples::
+
+ # Not supported by mydb versions less than 1, 0
+ @exclude('mydb', '<', (1,0))
+ # Other operators work too
+ @exclude('bigdb', '==', (9,0,9))
+ @exclude('yikesdb', 'in', ((0, 3, 'alpha2'), (0, 3, 'alpha3')))
+
+ """
+ carp = _should_carp_about_exclusion(reason)
+ def decorate(fn):
+ fn_name = fn.__name__
+ def maybe(*args, **kw):
+ if _is_excluded(db, op, spec):
+ msg = "'%s' unsupported on DB %s version '%s': %s" % (
+ fn_name, config.db.name, _server_version(), reason)
+ print msg
+ if carp:
+ print >> sys.stderr, msg
+ return True
+ else:
+ return fn(*args, **kw)
+ return _function_named(maybe, fn_name)
+ return decorate
+
+def _should_carp_about_exclusion(reason):
+ """Guard against forgotten exclusions."""
+ assert reason
+ for _ in ('todo', 'fixme', 'xxx'):
+ if _ in reason.lower():
+ return True
+ else:
+ if len(reason) < 4:
+ return True
+
+def _is_excluded(db, op, spec):
+ """Return True if the configured db matches an exclusion specification.
+
+ db:
+ A dialect name
+ op:
+ An operator or stringified operator, such as '=='
+ spec:
+ A value that will be compared to the dialect's server_version_info
+ using the supplied operator.
+
+ Examples::
+ # Not supported by mydb versions less than 1, 0
+ _is_excluded('mydb', '<', (1,0))
+ # Other operators work too
+ _is_excluded('bigdb', '==', (9,0,9))
+ _is_excluded('yikesdb', 'in', ((0, 3, 'alpha2'), (0, 3, 'alpha3')))
+ """
+
+ if config.db.name != db:
+ return False
+
+ version = _server_version()
+
+ oper = hasattr(op, '__call__') and op or _ops[op]
+ return oper(version, spec)
+
+def _server_version(bind=None):
+ """Return a server_version_info tuple."""
+
+ if bind is None:
+ bind = config.db
+ return bind.dialect.server_version_info(bind.contextual_connect())
+
+def skip_if(predicate, reason=None):
+ """Skip a test if predicate is true."""
+ reason = reason or predicate.__name__
+ def decorate(fn):
+ fn_name = fn.__name__
+ def maybe(*args, **kw):
+ if predicate():
+ msg = "'%s' skipped on DB %s version '%s': %s" % (
+ fn_name, config.db.name, _server_version(), reason)
+ print msg
+ return True
+ else:
+ return fn(*args, **kw)
+ return _function_named(maybe, fn_name)
+ return decorate
+
+def emits_warning(*messages):
+ """Mark a test as emitting a warning.
+
+ With no arguments, squelches all SAWarning failures. Or pass one or more
+ strings; these will be matched to the root of the warning description by
+ warnings.filterwarnings().
+ """
+
+ # TODO: it would be nice to assert that a named warning was
+ # emitted. should work with some monkeypatching of warnings,
+ # and may work on non-CPython if they keep to the spirit of
+ # warnings.showwarning's docstring.
+ # - update: jython looks ok, it uses cpython's module
+ def decorate(fn):
+ def safe(*args, **kw):
+ global sa_exc
+ if sa_exc is None:
+ import sqlalchemy.exc as sa_exc
+
+ # todo: should probably be strict about this, too
+ filters = [dict(action='ignore',
+ category=sa_exc.SAPendingDeprecationWarning)]
+ if not messages:
+ filters.append(dict(action='ignore',
+ category=sa_exc.SAWarning))
+ else:
+ filters.extend(dict(action='ignore',
+ message=message,
+ category=sa_exc.SAWarning)
+ for message in messages)
+ for f in filters:
+ warnings.filterwarnings(**f)
+ try:
+ return fn(*args, **kw)
+ finally:
+ resetwarnings()
+ return _function_named(safe, fn.__name__)
+ return decorate
+
+def emits_warning_on(db, *warnings):
+ """Mark a test as emitting a warning on a specific dialect.
+
+ With no arguments, squelches all SAWarning failures. Or pass one or more
+ strings; these will be matched to the root of the warning description by
+ warnings.filterwarnings().
+ """
+ def decorate(fn):
+ def maybe(*args, **kw):
+ if isinstance(db, basestring):
+ if config.db.name != db:
+ return fn(*args, **kw)
+ else:
+ wrapped = emits_warning(*warnings)(fn)
+ return wrapped(*args, **kw)
+ else:
+ if not _is_excluded(*db):
+ return fn(*args, **kw)
+ else:
+ wrapped = emits_warning(*warnings)(fn)
+ return wrapped(*args, **kw)
+ return _function_named(maybe, fn.__name__)
+ return decorate
+
+def uses_deprecated(*messages):
+ """Mark a test as immune from fatal deprecation warnings.
+
+ With no arguments, squelches all SADeprecationWarning failures.
+ Or pass one or more strings; these will be matched to the root
+ of the warning description by warnings.filterwarnings().
+
+ As a special case, you may pass a function name prefixed with //
+ and it will be re-written as needed to match the standard warning
+ verbiage emitted by the sqlalchemy.util.deprecated decorator.
+ """
+
+ def decorate(fn):
+ def safe(*args, **kw):
+ global sa_exc
+ if sa_exc is None:
+ import sqlalchemy.exc as sa_exc
+
+ # todo: should probably be strict about this, too
+ filters = [dict(action='ignore',
+ category=sa_exc.SAPendingDeprecationWarning)]
+ if not messages:
+ filters.append(dict(action='ignore',
+ category=sa_exc.SADeprecationWarning))
+ else:
+ filters.extend(
+ [dict(action='ignore',
+ message=message,
+ category=sa_exc.SADeprecationWarning)
+ for message in
+ [ (m.startswith('//') and
+ ('Call to deprecated function ' + m[2:]) or m)
+ for m in messages] ])
+
+ for f in filters:
+ warnings.filterwarnings(**f)
+ try:
+ return fn(*args, **kw)
+ finally:
+ resetwarnings()
+ return _function_named(safe, fn.__name__)
+ return decorate
+
+def resetwarnings():
+ """Reset warning behavior to testing defaults."""
+
+ global sa_exc
+ if sa_exc is None:
+ import sqlalchemy.exc as sa_exc
+
+ warnings.filterwarnings('ignore',
+ category=sa_exc.SAPendingDeprecationWarning)
+ warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning)
+ warnings.filterwarnings('error', category=sa_exc.SAWarning)
+
+# warnings.simplefilter('error')
+
+ if sys.version_info < (2, 4):
+ warnings.filterwarnings('ignore', category=FutureWarning)
+
+
+def against(*queries):
+ """Boolean predicate, compares to testing database configuration.
+
+ Given one or more dialect names, returns True if one is the configured
+ database engine.
+
+ Also supports comparison to database version when provided with one or
+ more 3-tuples of dialect name, operator, and version specification::
+
+ testing.against('mysql', 'postgres')
+ testing.against(('mysql', '>=', (5, 0, 0))
+ """
+
+ for query in queries:
+ if isinstance(query, basestring):
+ if config.db.name == query:
+ return True
+ else:
+ name, op, spec = query
+ if config.db.name != name:
+ continue
+
+ have = config.db.dialect.server_version_info(
+ config.db.contextual_connect())
+
+ oper = hasattr(op, '__call__') and op or _ops[op]
+ if oper(have, spec):
+ return True
+ return False
+
+def _chain_decorators_on(fn, *decorators):
+ """Apply a series of decorators to fn, returning a decorated function."""
+ for decorator in reversed(decorators):
+ fn = decorator(fn)
+ return fn
+
+def rowset(results):
+ """Converts the results of sql execution into a plain set of column tuples.
+
+ Useful for asserting the results of an unordered query.
+ """
+
+ return set([tuple(row) for row in results])
+
+
+def eq_(a, b, msg=None):
+ """Assert a == b, with repr messaging on failure."""
+ assert a == b, msg or "%r != %r" % (a, b)
+
+def ne_(a, b, msg=None):
+ """Assert a != b, with repr messaging on failure."""
+ assert a != b, msg or "%r == %r" % (a, b)
+
+def is_(a, b, msg=None):
+ """Assert a is b, with repr messaging on failure."""
+ assert a is b, msg or "%r is not %r" % (a, b)
+
+def is_not_(a, b, msg=None):
+ """Assert a is not b, with repr messaging on failure."""
+ assert a is not b, msg or "%r is %r" % (a, b)
+
+def startswith_(a, fragment, msg=None):
+ """Assert a.startswith(fragment), with repr messaging on failure."""
+ assert a.startswith(fragment), msg or "%r does not start with %r" % (
+ a, fragment)
+
+
+def fixture(table, columns, *rows):
+ """Insert data into table after creation."""
+ def onload(event, schema_item, connection):
+ insert = table.insert()
+ column_names = [col.key for col in columns]
+ connection.execute(insert, [dict(zip(column_names, column_values))
+ for column_values in rows])
+ table.append_ddl_listener('after-create', onload)
+
+def resolve_artifact_names(fn):
+ """Decorator, augment function globals with tables and classes.
+
+ Swaps out the function's globals at execution time. The 'global' statement
+ will not work as expected inside a decorated function.
+
+ """
+ # This could be automatically applied to framework and test_ methods in
+ # the MappedTest-derived test suites but... *some* explicitness for this
+ # magic is probably good. Especially as 'global' won't work- these
+ # rebound functions aren't regular Python..
+ #
+ # Also: it's lame that CPython accepts a dict-subclass for globals, but
+ # only calls dict methods. That would allow 'global' to pass through to
+ # the func_globals.
+ def resolved(*args, **kwargs):
+ self = args[0]
+ context = dict(fn.func_globals)
+ for source in self._artifact_registries:
+ context.update(getattr(self, source))
+ # jython bug #1034
+ rebound = types.FunctionType(
+ fn.func_code, context, fn.func_name, fn.func_defaults,
+ fn.func_closure)
+ return rebound(*args, **kwargs)
+ return _function_named(resolved, fn.func_name)
+
+class adict(dict):
+ """Dict keys available as attributes. Shadows."""
+ def __getattribute__(self, key):
+ try:
+ return self[key]
+ except KeyError:
+ return dict.__getattribute__(self, key)
+
+ def get_all(self, *keys):
+ return tuple([self[key] for key in keys])
+
+
+class TestBase(object):
+ # A sequence of database names to always run, regardless of the
+ # constraints below.
+ __whitelist__ = ()
+
+ # A sequence of requirement names matching testing.requires decorators
+ __requires__ = ()
+
+ # A sequence of dialect names to exclude from the test class.
+ __unsupported_on__ = ()
+
+ # If present, test class is only runnable for the *single* specified
+ # dialect. If you need multiple, use __unsupported_on__ and invert.
+ __only_on__ = None
+
+ # A sequence of no-arg callables. If any are True, the entire testcase is
+ # skipped.
+ __skip_if__ = None
+
+
+ _artifact_registries = ()
+
+ _sa_first_test = False
+ _sa_last_test = False
+
+ def assert_(self, val, msg=None):
+ assert val, msg
+
+ def assertEqual(self, x, y):
+ eq_(x, y)
+
+ def assertEquals(self, x, y):
+ eq_(x, y)
+
+ def assertRaises(self, except_cls, callable_, *args, **kw):
+ try:
+ callable_(*args, **kwargs)
+ assert False, "Callable did not raise an exception"
+ except except_cls, e:
+ pass
+
+ def assertRaisesMessage(self, except_cls, msg, callable_, *args, **kwargs):
+ try:
+ callable_(*args, **kwargs)
+ assert False, "Callable did not raise an exception"
+ except except_cls, e:
+ assert re.search(msg, str(e)), "%r !~ %s" % (msg, e)
+
+class AssertsCompiledSQL(object):
+ def assert_compile(self, clause, result, params=None, checkparams=None, dialect=None):
+ if dialect is None:
+ dialect = getattr(self, '__dialect__', None)
+
+ if params is None:
+ keys = None
+ else:
+ keys = params.keys()
+
+ c = clause.compile(column_keys=keys, dialect=dialect)
+
+ print "\nSQL String:\n" + str(c) + repr(c.params)
+
+ cc = re.sub(r'\n', '', str(c))
+
+ self.assertEquals(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))
+
+ if checkparams is not None:
+ self.assertEquals(c.construct_params(params), checkparams)
+
+class ComparesTables(object):
+ def assert_tables_equal(self, table, reflected_table):
+ global sqltypes, schema
+ if sqltypes is None:
+ import sqlalchemy.types as sqltypes
+ if schema is None:
+ import sqlalchemy.schema as schema
+ base_mro = sqltypes.TypeEngine.__mro__
+ assert len(table.c) == len(reflected_table.c)
+ for c, reflected_c in zip(table.c, reflected_table.c):
+ self.assertEquals(c.name, reflected_c.name)
+ assert reflected_c is reflected_table.c[c.name]
+ self.assertEquals(c.primary_key, reflected_c.primary_key)
+ self.assertEquals(c.nullable, reflected_c.nullable)
+ assert len(
+ set(type(reflected_c.type).__mro__).difference(base_mro).intersection(
+ set(type(c.type).__mro__).difference(base_mro)
+ )
+ ) > 0, "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type)
+
+ if isinstance(c.type, sqltypes.String):
+ self.assertEquals(c.type.length, reflected_c.type.length)
+
+ self.assertEquals(set([f.column.name for f in c.foreign_keys]), set([f.column.name for f in reflected_c.foreign_keys]))
+ if c.default:
+ assert isinstance(reflected_c.server_default,
+ schema.FetchedValue)
+ elif against(('mysql', '<', (5, 0))):
+ # ignore reflection of bogus db-generated DefaultClause()
+ pass
+ elif not c.primary_key or not against('postgres'):
+ print repr(c)
+ assert reflected_c.default is None, reflected_c.default
+
+ assert len(table.primary_key) == len(reflected_table.primary_key)
+ for c in table.primary_key:
+ assert reflected_table.primary_key.columns[c.name]
+
+
+class AssertsExecutionResults(object):
+ def assert_result(self, result, class_, *objects):
+ result = list(result)
+ print repr(result)
+ self.assert_list(result, class_, objects)
+
+ def assert_list(self, result, class_, list):
+ self.assert_(len(result) == len(list),
+ "result list is not the same size as test list, " +
+ "for class " + class_.__name__)
+ for i in range(0, len(list)):
+ self.assert_row(class_, result[i], list[i])
+
+ def assert_row(self, class_, rowobj, desc):
+ self.assert_(rowobj.__class__ is class_,
+ "item class is not " + repr(class_))
+ for key, value in desc.iteritems():
+ if isinstance(value, tuple):
+ if isinstance(value[1], list):
+ self.assert_list(getattr(rowobj, key), value[0], value[1])
+ else:
+ self.assert_row(value[0], getattr(rowobj, key), value[1])
+ else:
+ self.assert_(getattr(rowobj, key) == value,
+ "attribute %s value %s does not match %s" % (
+ key, getattr(rowobj, key), value))
+
+ def assert_unordered_result(self, result, cls, *expected):
+ """As assert_result, but the order of objects is not considered.
+
+ The algorithm is very expensive but not a big deal for the small
+ numbers of rows that the test suite manipulates.
+ """
+
+ global util
+ if util is None:
+ from sqlalchemy import util
+
+ class frozendict(dict):
+ def __hash__(self):
+ return id(self)
+
+ found = util.IdentitySet(result)
+ expected = set([frozendict(e) for e in expected])
+
+ for wrong in itertools.ifilterfalse(lambda o: type(o) == cls, found):
+ self.fail('Unexpected type "%s", expected "%s"' % (
+ type(wrong).__name__, cls.__name__))
+
+ if len(found) != len(expected):
+ self.fail('Unexpected object count "%s", expected "%s"' % (
+ len(found), len(expected)))
+
+ NOVALUE = object()
+ def _compare_item(obj, spec):
+ for key, value in spec.iteritems():
+ if isinstance(value, tuple):
+ try:
+ self.assert_unordered_result(
+ getattr(obj, key), value[0], *value[1])
+ except AssertionError:
+ return False
+ else:
+ if getattr(obj, key, NOVALUE) != value:
+ return False
+ return True
+
+ for expected_item in expected:
+ for found_item in found:
+ if _compare_item(found_item, expected_item):
+ found.remove(found_item)
+ break
+ else:
+ self.fail(
+ "Expected %s instance with attributes %s not found." % (
+ cls.__name__, repr(expected_item)))
+ return True
+
+ def assert_sql_execution(self, db, callable_, *rules):
+ from testlib import assertsql
+ assertsql.asserter.add_rules(rules)
+ try:
+ callable_()
+ assertsql.asserter.statement_complete()
+ finally:
+ assertsql.asserter.clear_rules()
+
+ def assert_sql(self, db, callable_, list_, with_sequences=None):
+ from testlib import assertsql
+
+ if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgres'):
+ rules = with_sequences
+ else:
+ rules = list_
+
+ newrules = []
+ for rule in rules:
+ if isinstance(rule, dict):
+ newrule = assertsql.AllOf(*[
+ assertsql.ExactSQL(k, v) for k, v in rule.iteritems()
+ ])
+ else:
+ newrule = assertsql.ExactSQL(*rule)
+ newrules.append(newrule)
+
+ self.assert_sql_execution(db, callable_, *newrules)
+
+ def assert_sql_count(self, db, callable_, count):
+ from testlib import assertsql
+ self.assert_sql_execution(db, callable_, assertsql.CountStatements(count))
+
+
</ins></span></pre></div>
<a id="sqlalchemybranchesnosetestssetupcfg"></a>
<div class="modfile"><h4>Modified: sqlalchemy/branches/nosetests/setup.cfg (6026 => 6027)</h4>
<pre class="diff"><span>
<span class="info">--- sqlalchemy/branches/nosetests/setup.cfg 2009-06-07 22:08:09 UTC (rev 6026)
+++ sqlalchemy/branches/nosetests/setup.cfg 2009-06-07 22:09:07 UTC (rev 6027)
</span><span class="lines">@@ -1,3 +1,6 @@
</span><span class="cx"> [egg_info]
</span><span class="cx"> tag_build = dev
</span><span class="cx"> tag_svn_revision = true
</span><ins>+
+[nosetests]
+with-sqlalchemy = true
</ins><span class="cx">\ No newline at end of file
</span></span></pre></div>
<a id="sqlalchemybranchesnosetestssetuppy"></a>
<div class="modfile"><h4>Modified: sqlalchemy/branches/nosetests/setup.py (6026 => 6027)</h4>
<pre class="diff"><span>
<span class="info">--- sqlalchemy/branches/nosetests/setup.py 2009-06-07 22:08:09 UTC (rev 6026)
+++ sqlalchemy/branches/nosetests/setup.py 2009-06-07 22:09:07 UTC (rev 6027)
</span><span class="lines">@@ -31,6 +31,13 @@
</span><span class="cx"> packages = find_packages('lib'),
</span><span class="cx"> package_dir = {'':'lib'},
</span><span class="cx"> license = "MIT License",
</span><ins>+
+ entry_points = {
+ 'nose.plugins.0.10': [
+ 'sqlalchemy = sqlalchemy.test.noseplugin:NoseSQLAlchemy',
+ ]
+ },
+
</ins><span class="cx"> long_description = """\
</span><span class="cx"> SQLAlchemy is:
</span><span class="cx">
</span></span></pre></div>
<a id="sqlalchemybranchesnoseteststestsqlquerypy"></a>
<div class="delfile"><h4>Deleted: sqlalchemy/branches/nosetests/test/sql/query.py (6026 => 6027)</h4>
<pre class="diff"><span>
<span class="info">--- sqlalchemy/branches/nosetests/test/sql/query.py 2009-06-07 22:08:09 UTC (rev 6026)
+++ sqlalchemy/branches/nosetests/test/sql/query.py 2009-06-07 22:09:07 UTC (rev 6027)
</span><span class="lines">@@ -1,1321 +0,0 @@
</span><del>-import testenv; testenv.configure_for_tests()
-import datetime
-from sqlalchemy import *
-from sqlalchemy import exc, sql
-from sqlalchemy.engine import default
-from testlib import *
-from testlib.testing import eq_
-
-class QueryTest(TestBase):
-
- def setUpAll(self):
- global users, users2, addresses, metadata
- metadata = MetaData(testing.db)
- users = Table('query_users', metadata,
- Column('user_id', INT, primary_key = True),
- Column('user_name', VARCHAR(20)),
- )
- addresses = Table('query_addresses', metadata,
- Column('address_id', Integer, primary_key=True),
- Column('user_id', Integer, ForeignKey('query_users.user_id')),
- Column('address', String(30)))
-
- users2 = Table('u2', metadata,
- Column('user_id', INT, primary_key = True),
- Column('user_name', VARCHAR(20)),
- )
- metadata.create_all()
-
- def tearDown(self):
- addresses.delete().execute()
- users.delete().execute()
- users2.delete().execute()
-
- def tearDownAll(self):
- metadata.drop_all()
-
- def test_insert(self):
- users.insert().execute(user_id = 7, user_name = 'jack')
- assert users.count().scalar() == 1
-
- def test_insert_heterogeneous_params(self):
- users.insert().execute(
- {'user_id':7, 'user_name':'jack'},
- {'user_id':8, 'user_name':'ed'},
- {'user_id':9}
- )
- assert users.select().execute().fetchall() == [(7, 'jack'), (8, 'ed'), (9, None)]
-
- def test_update(self):
- users.insert().execute(user_id = 7, user_name = 'jack')
- assert users.count().scalar() == 1
-
- users.update(users.c.user_id == 7).execute(user_name = 'fred')
- assert users.select(users.c.user_id==7).execute().fetchone()['user_name'] == 'fred'
-
- def test_lastrow_accessor(self):
- """Tests the last_inserted_ids() and lastrow_has_id() functions."""
-
- def insert_values(table, values):
- """
- Inserts a row into a table, returns the full list of values
- INSERTed including defaults that fired off on the DB side and
- detects rows that had defaults and post-fetches.
- """
-
- result = table.insert().execute(**values)
- ret = values.copy()
-
- for col, id in zip(table.primary_key, result.last_inserted_ids()):
- ret[col.key] = id
-
- if result.lastrow_has_defaults():
- criterion = and_(*[col==id for col, id in zip(table.primary_key, result.last_inserted_ids())])
- row = table.select(criterion).execute().fetchone()
- for c in table.c:
- ret[c.key] = row[c]
- return ret
-
- for supported, table, values, assertvalues in [
- (
- {'unsupported':['sqlite']},
- Table("t1", metadata,
- Column('id', Integer, Sequence('t1_id_seq', optional=True), primary_key=True),
- Column('foo', String(30), primary_key=True)),
- {'foo':'hi'},
- {'id':1, 'foo':'hi'}
- ),
- (
- {'unsupported':['sqlite']},
- Table("t2", metadata,
- Column('id', Integer, Sequence('t2_id_seq', optional=True), primary_key=True),
- Column('foo', String(30), primary_key=True),
- Column('bar', String(30), server_default='hi')
- ),
- {'foo':'hi'},
- {'id':1, 'foo':'hi', 'bar':'hi'}
- ),
- (
- {'unsupported':[]},
- Table("t3", metadata,
- Column("id", String(40), primary_key=True),
- Column('foo', String(30), primary_key=True),
- Column("bar", String(30))
- ),
- {'id':'hi', 'foo':'thisisfoo', 'bar':"thisisbar"},
- {'id':'hi', 'foo':'thisisfoo', 'bar':"thisisbar"}
- ),
- (
- {'unsupported':[]},
- Table("t4", metadata,
- Column('id', Integer, Sequence('t4_id_seq', optional=True), primary_key=True),
- Column('foo', String(30), primary_key=True),
- Column('bar', String(30), server_default='hi')
- ),
- {'foo':'hi', 'id':1},
- {'id':1, 'foo':'hi', 'bar':'hi'}
- ),
- (
- {'unsupported':[]},
- Table("t5", metadata,
- Column('id', String(10), primary_key=True),
- Column('bar', String(30), server_default='hi')
- ),
- {'id':'id1'},
- {'id':'id1', 'bar':'hi'},
- ),
- ]:
- if testing.db.name in supported['unsupported']:
- continue
- try:
- table.create()
- i = insert_values(table, values)
- assert i == assertvalues, repr(i) + " " + repr(assertvalues)
- finally:
- table.drop()
-
- def test_row_iteration(self):
- users.insert().execute(
- {'user_id':7, 'user_name':'jack'},
- {'user_id':8, 'user_name':'ed'},
- {'user_id':9, 'user_name':'fred'},
- )
- r = users.select().execute()
- l = []
- for row in r:
- l.append(row)
- self.assert_(len(l) == 3)
-
- @testing.fails_on('firebird', 'Data type unknown')
- @testing.requires.subqueries
- def test_anonymous_rows(self):
- users.insert().execute(
- {'user_id':7, 'user_name':'jack'},
- {'user_id':8, 'user_name':'ed'},
- {'user_id':9, 'user_name':'fred'},
- )
-
- sel = select([users.c.user_id]).where(users.c.user_name=='jack').as_scalar()
- for row in select([sel + 1, sel + 3], bind=users.bind).execute():
- assert row['anon_1'] == 8
- assert row['anon_2'] == 10
-
- def test_order_by_label(self):
- """test that a label within an ORDER BY works on each backend.
-
- simple labels in ORDER BYs now render as the actual labelname
- which not every database supports.
-
- """
- users.insert().execute(
- {'user_id':7, 'user_name':'jack'},
- {'user_id':8, 'user_name':'ed'},
- {'user_id':9, 'user_name':'fred'},
- )
-
- concat = ("test: " + users.c.user_name).label('thedata')
- self.assertEquals(
- select([concat]).order_by(concat).execute().fetchall(),
- [("test: ed",), ("test: fred",), ("test: jack",)]
- )
-
- concat = ("test: " + users.c.user_name).label('thedata')
- self.assertEquals(
- select([concat]).order_by(desc(concat)).execute().fetchall(),
- [("test: jack",), ("test: fred",), ("test: ed",)]
- )
-
- concat = ("test: " + users.c.user_name).label('thedata')
- self.assertEquals(
- select([concat]).order_by(concat + "x").execute().fetchall(),
- [("test: ed",), ("test: fred",), ("test: jack",)]
- )
-
-
- def test_row_comparison(self):
- users.insert().execute(user_id = 7, user_name = 'jack')
- rp = users.select().execute().fetchone()
-
- self.assert_(rp == rp)
- self.assert_(not(rp != rp))
-
- equal = (7, 'jack')
-
- self.assert_(rp == equal)
- self.assert_(equal == rp)
- self.assert_(not (rp != equal))
- self.assert_(not (equal != equal))
-
- @testing.fails_on('mssql', 'No support for boolean logic in column select.')
- @testing.fails_on('oracle', 'FIXME: unknown')
- def test_or_and_as_columns(self):
- true, false = literal(True), literal(False)
-
- self.assertEquals(testing.db.execute(select([and_(true, false)])).scalar(), False)
- self.assertEquals(testing.db.execute(select([and_(true, true)])).scalar(), True)
- self.assertEquals(testing.db.execute(select([or_(true, false)])).scalar(), True)
- self.assertEquals(testing.db.execute(select([or_(false, false)])).scalar(), False)
- self.assertEquals(testing.db.execute(select([not_(or_(false, false))])).scalar(), True)
-
- row = testing.db.execute(select([or_(false, false).label("x"), and_(true, false).label("y")])).fetchone()
- assert row.x == False
- assert row.y == False
-
- row = testing.db.execute(select([or_(true, false).label("x"), and_(true, false).label("y")])).fetchone()
- assert row.x == True
- assert row.y == False
-
- def test_fetchmany(self):
- users.insert().execute(user_id = 7, user_name = 'jack')
- users.insert().execute(user_id = 8, user_name = 'ed')
- users.insert().execute(user_id = 9, user_name = 'fred')
- r = users.select().execute()
- l = []
- for row in r.fetchmany(size=2):
- l.append(row)
- self.assert_(len(l) == 2, "fetchmany(size=2) got %s rows" % len(l))
-
- def test_like_ops(self):
- users.insert().execute(
- {'user_id':1, 'user_name':'apples'},
- {'user_id':2, 'user_name':'oranges'},
- {'user_id':3, 'user_name':'bananas'},
- {'user_id':4, 'user_name':'legumes'},
- {'user_id':5, 'user_name':'hi % there'},
- )
-
- for expr, result in (
- (select([users.c.user_id]).where(users.c.user_name.startswith('apple')), [(1,)]),
- (select([users.c.user_id]).where(users.c.user_name.contains('i % t')), [(5,)]),
- (select([users.c.user_id]).where(users.c.user_name.endswith('anas')), [(3,)]),
- ):
- eq_(expr.execute().fetchall(), result)
-
-
- @testing.emits_warning('.*now automatically escapes.*')
- def test_percents_in_text(self):
- for expr, result in (
- (text("select 6 % 10"), 6),
- (text("select 17 % 10"), 7),
- (text("select '%'"), '%'),
- (text("select '%%'"), '%%'),
- (text("select '%%%'"), '%%%'),
- (text("select 'hello % world'"), "hello % world")
- ):
- eq_(testing.db.scalar(expr), result)
-
- def test_ilike(self):
- users.insert().execute(
- {'user_id':1, 'user_name':'one'},
- {'user_id':2, 'user_name':'TwO'},
- {'user_id':3, 'user_name':'ONE'},
- {'user_id':4, 'user_name':'OnE'},
- )
-
- self.assertEquals(select([users.c.user_id]).where(users.c.user_name.ilike('one')).execute().fetchall(), [(1, ), (3, ), (4, )])
-
- self.assertEquals(select([users.c.user_id]).where(users.c.user_name.ilike('TWO')).execute().fetchall(), [(2, )])
-
- if testing.against('postgres'):
- self.assertEquals(select([users.c.user_id]).where(users.c.user_name.like('one')).execute().fetchall(), [(1, )])
- self.assertEquals(select([users.c.user_id]).where(users.c.user_name.like('TWO')).execute().fetchall(), [])
-
-
- def test_compiled_execute(self):
- users.insert().execute(user_id = 7, user_name = 'jack')
- s = select([users], users.c.user_id==bindparam('id')).compile()
- c = testing.db.connect()
- assert c.execute(s, id=7).fetchall()[0]['user_id'] == 7
-
- def test_compiled_insert_execute(self):
- users.insert().compile().execute(user_id = 7, user_name = 'jack')
- s = select([users], users.c.user_id==bindparam('id')).compile()
- c = testing.db.connect()
- assert c.execute(s, id=7).fetchall()[0]['user_id'] == 7
-
- def test_repeated_bindparams(self):
- """Tests that a BindParam can be used more than once.
-
- This should be run for DB-APIs with both positional and named
- paramstyles.
- """
- users.insert().execute(user_id = 7, user_name = 'jack')
- users.insert().execute(user_id = 8, user_name = 'fred')
-
- u = bindparam('userid')
- s = users.select(and_(users.c.user_name==u, users.c.user_name==u))
- r = s.execute(userid='fred').fetchall()
- assert len(r) == 1
-
- def test_bindparam_shortname(self):
- """test the 'shortname' field on BindParamClause."""
- users.insert().execute(user_id = 7, user_name = 'jack')
- users.insert().execute(user_id = 8, user_name = 'fred')
- u = bindparam('userid', shortname='someshortname')
- s = users.select(users.c.user_name==u)
- r = s.execute(someshortname='fred').fetchall()
- assert len(r) == 1
-
- def test_bindparam_detection(self):
- dialect = default.DefaultDialect(paramstyle='qmark')
- prep = lambda q: str(sql.text(q).compile(dialect=dialect))
-
- def a_eq(got, wanted):
- if got != wanted:
- print "Wanted %s" % wanted
- print "Received %s" % got
- self.assert_(got == wanted, got)
-
- a_eq(prep('select foo'), 'select foo')
- a_eq(prep("time='12:30:00'"), "time='12:30:00'")
- a_eq(prep(u"time='12:30:00'"), u"time='12:30:00'")
- a_eq(prep(":this:that"), ":this:that")
- a_eq(prep(":this :that"), "? ?")
- a_eq(prep("(:this),(:that :other)"), "(?),(? ?)")
- a_eq(prep("(:this),(:that:other)"), "(?),(:that:other)")
- a_eq(prep("(:this),(:that,:other)"), "(?),(?,?)")
- a_eq(prep("(:that_:other)"), "(:that_:other)")
- a_eq(prep("(:that_ :other)"), "(? ?)")
- a_eq(prep("(:that_other)"), "(?)")
- a_eq(prep("(:that$other)"), "(?)")
- a_eq(prep("(:that$:other)"), "(:that$:other)")
- a_eq(prep(".:that$ :other."), ".? ?.")
-
- a_eq(prep(r'select \foo'), r'select \foo')
- a_eq(prep(r"time='12\:30:00'"), r"time='12\:30:00'")
- a_eq(prep(":this \:that"), "? :that")
- a_eq(prep(r"(\:that$other)"), "(:that$other)")
- a_eq(prep(r".\:that$ :other."), ".:that$ ?.")
-
- def test_delete(self):
- users.insert().execute(user_id = 7, user_name = 'jack')
- users.insert().execute(user_id = 8, user_name = 'fred')
- print repr(users.select().execute().fetchall())
-
- users.delete(users.c.user_name == 'fred').execute()
-
- print repr(users.select().execute().fetchall())
-
-
-
- @testing.exclude('mysql', '<', (5, 0, 37), 'database bug')
- def test_scalar_select(self):
- """test that scalar subqueries with labels get their type propagated to the result set."""
- # mysql and/or mysqldb has a bug here, type isn't propagated for scalar
- # subquery.
- datetable = Table('datetable', metadata,
- Column('id', Integer, primary_key=True),
- Column('today', DateTime))
- datetable.create()
- try:
- datetable.insert().execute(id=1, today=datetime.datetime(2006, 5, 12, 12, 0, 0))
- s = select([datetable.alias('x').c.today]).as_scalar()
- s2 = select([datetable.c.id, s.label('somelabel')])
- #print s2.c.somelabel.type
- assert isinstance(s2.execute().fetchone()['somelabel'], datetime.datetime)
- finally:
- datetable.drop()
-
- def test_order_by(self):
- """Exercises ORDER BY clause generation.
-
- Tests simple, compound, aliased and DESC clauses.
- """
-
- users.insert().execute(user_id=1, user_name='c')
- users.insert().execute(user_id=2, user_name='b')
- users.insert().execute(user_id=3, user_name='a')
-
- def a_eq(executable, wanted):
- got = list(executable.execute())
- self.assertEquals(got, wanted)
-
- for labels in False, True:
- a_eq(users.select(order_by=[users.c.user_id],
- use_labels=labels),
- [(1, 'c'), (2, 'b'), (3, 'a')])
-
- a_eq(users.select(order_by=[users.c.user_name, users.c.user_id],
- use_labels=labels),
- [(3, 'a'), (2, 'b'), (1, 'c')])
-
- a_eq(select([users.c.user_id.label('foo')],
- use_labels=labels,
- order_by=[users.c.user_id]),
- [(1,), (2,), (3,)])
-
- a_eq(select([users.c.user_id.label('foo'), users.c.user_name],
- use_labels=labels,
- order_by=[users.c.user_name, users.c.user_id]),
- [(3, 'a'), (2, 'b'), (1, 'c')])
-
- a_eq(users.select(distinct=True,
- use_labels=labels,
- order_by=[users.c.user_id]),
- [(1, 'c'), (2, 'b'), (3, 'a')])
-
- a_eq(select([users.c.user_id.label('foo')],
- distinct=True,
- use_labels=labels,
- order_by=[users.c.user_id]),
- [(1,), (2,), (3,)])
-
- a_eq(select([users.c.user_id.label('a'),
- users.c.user_id.label('b'),
- users.c.user_name],
- use_labels=labels,
- order_by=[users.c.user_id]),
- [(1, 1, 'c'), (2, 2, 'b'), (3, 3, 'a')])
-
- a_eq(users.select(distinct=True,
- use_labels=labels,
- order_by=[desc(users.c.user_id)]),
- [(3, 'a'), (2, 'b'), (1, 'c')])
-
- a_eq(select([users.c.user_id.label('foo')],
- distinct=True,
- use_labels=labels,
- order_by=[users.c.user_id.desc()]),
- [(3,), (2,), (1,)])
-
- def test_column_accessor(self):
- users.insert().execute(user_id=1, user_name='john')
- users.insert().execute(user_id=2, user_name='jack')
- addresses.insert().execute(address_id=1, user_id=2, address='foo@...')
-
- r = users.select(users.c.user_id==2).execute().fetchone()
- self.assert_(r.user_id == r['user_id'] == r[users.c.user_id] == 2)
- self.assert_(r.user_name == r['user_name'] == r[users.c.user_name] == 'jack')
-
- r = text("select * from query_users where user_id=2", bind=testing.db).execute().fetchone()
- self.assert_(r.user_id == r['user_id'] == r[users.c.user_id] == 2)
- self.assert_(r.user_name == r['user_name'] == r[users.c.user_name] == 'jack')
-
- # test slices
- r = text("select * from query_addresses", bind=testing.db).execute().fetchone()
- self.assert_(r[0:1] == (1,))
- self.assert_(r[1:] == (2, 'foo@...'))
- self.assert_(r[:-1] == (1, 2))
-
- # test a little sqlite weirdness - with the UNION, cols come back as "query_users.user_id" in cursor.description
- r = text("select query_users.user_id, query_users.user_name from query_users "
- "UNION select query_users.user_id, query_users.user_name from query_users", bind=testing.db).execute().fetchone()
- self.assert_(r['user_id']) == 1
- self.assert_(r['user_name']) == "john"
-
- # test using literal tablename.colname
- r = text('select query_users.user_id AS "query_users.user_id", query_users.user_name AS "query_users.user_name" from query_users', bind=testing.db).execute().fetchone()
- self.assert_(r['query_users.user_id']) == 1
- self.assert_(r['query_users.user_name']) == "john"
-
- def test_row_as_args(self):
- users.insert().execute(user_id=1, user_name='john')
- r = users.select(users.c.user_id==1).execute().fetchone()
- users.delete().execute()
- users.insert().execute(r)
- assert users.select().execute().fetchall() == [(1, 'john')]
-
- def test_result_as_args(self):
- users.insert().execute([dict(user_id=1, user_name='john'), dict(user_id=2, user_name='ed')])
- r = users.select().execute()
- users2.insert().execute(list(r))
- assert users2.select().execute().fetchall() == [(1, 'john'), (2, 'ed')]
-
- users2.delete().execute()
- r = users.select().execute()
- users2.insert().execute(*list(r))
- assert users2.select().execute().fetchall() == [(1, 'john'), (2, 'ed')]
-
- def test_ambiguous_column(self):
- users.insert().execute(user_id=1, user_name='john')
- r = users.outerjoin(addresses).select().execute().fetchone()
- try:
- print r['user_id']
- assert False
- except exc.InvalidRequestError, e:
- assert str(e) == "Ambiguous column name 'user_id' in result set! try 'use_labels' option on select statement." or \
- str(e) == "Ambiguous column name 'USER_ID' in result set! try 'use_labels' option on select statement."
-
- @testing.requires.subqueries
- def test_column_label_targeting(self):
- users.insert().execute(user_id=7, user_name='ed')
-
- for s in (
- users.select().alias('foo'),
- users.select().alias(users.name),
- ):
- row = s.select(use_labels=True).execute().fetchone()
- assert row[s.c.user_id] == 7
- assert row[s.c.user_name] == 'ed'
-
- def test_keys(self):
- users.insert().execute(user_id=1, user_name='foo')
- r = users.select().execute().fetchone()
- self.assertEqual([x.lower() for x in r.keys()], ['user_id', 'user_name'])
-
- def test_items(self):
- users.insert().execute(user_id=1, user_name='foo')
- r = users.select().execute().fetchone()
- self.assertEqual([(x[0].lower(), x[1]) for x in r.items()], [('user_id', 1), ('user_name', 'foo')])
-
- def test_len(self):
- users.insert().execute(user_id=1, user_name='foo')
- r = users.select().execute().fetchone()
- self.assertEqual(len(r), 2)
- r.close()
- r = testing.db.execute('select user_name, user_id from query_users').fetchone()
- self.assertEqual(len(r), 2)
- r.close()
- r = testing.db.execute('select user_name from query_users').fetchone()
- self.assertEqual(len(r), 1)
- r.close()
-
- def test_cant_execute_join(self):
- try:
- users.join(addresses).execute()
- except exc.ArgumentError, e:
- assert str(e).startswith('Not an executable clause: ')
-
-
-
- def test_column_order_with_simple_query(self):
- # should return values in column definition order
- users.insert().execute(user_id=1, user_name='foo')
- r = users.select(users.c.user_id==1).execute().fetchone()
- self.assertEqual(r[0], 1)
- self.assertEqual(r[1], 'foo')
- self.assertEqual([x.lower() for x in r.keys()], ['user_id', 'user_name'])
- self.assertEqual(r.values(), [1, 'foo'])
-
- def test_column_order_with_text_query(self):
- # should return values in query order
- users.insert().execute(user_id=1, user_name='foo')
- r = testing.db.execute('select user_name, user_id from query_users').fetchone()
- self.assertEqual(r[0], 'foo')
- self.assertEqual(r[1], 1)
- self.assertEqual([x.lower() for x in r.keys()], ['user_name', 'user_id'])
- self.assertEqual(r.values(), ['foo', 1])
-
- @testing.crashes('oracle', 'FIXME: unknown, varify not fails_on()')
- @testing.crashes('firebird', 'An identifier must begin with a letter')
- @testing.crashes('maxdb', 'FIXME: unknown, verify not fails_on()')
- def test_column_accessor_shadow(self):
- meta = MetaData(testing.db)
- shadowed = Table('test_shadowed', meta,
- Column('shadow_id', INT, primary_key = True),
- Column('shadow_name', VARCHAR(20)),
- Column('parent', VARCHAR(20)),
- Column('row', VARCHAR(40)),
- Column('__parent', VARCHAR(20)),
- Column('__row', VARCHAR(20)),
- )
- shadowed.create(checkfirst=True)
- try:
- shadowed.insert().execute(shadow_id=1, shadow_name='The Shadow', parent='The Light', row='Without light there is no shadow', __parent='Hidden parent', __row='Hidden row')
- r = shadowed.select(shadowed.c.shadow_id==1).execute().fetchone()
- self.assert_(r.shadow_id == r['shadow_id'] == r[shadowed.c.shadow_id] == 1)
- self.assert_(r.shadow_name == r['shadow_name'] == r[shadowed.c.shadow_name] == 'The Shadow')
- self.assert_(r.parent == r['parent'] == r[shadowed.c.parent] == 'The Light')
- self.assert_(r.row == r['row'] == r[shadowed.c.row] == 'Without light there is no shadow')
- self.assert_(r['__parent'] == 'Hidden parent')
- self.assert_(r['__row'] == 'Hidden row')
- try:
- print r.__parent, r.__row
- self.fail('Should not allow access to private attributes')
- except AttributeError:
- pass # expected
- r.close()
- finally:
- shadowed.drop(checkfirst=True)
-
- def test_in_filtering(self):
- """test the behavior of the in_() function."""
-
- users.insert().execute(user_id = 7, user_name = 'jack')
- users.insert().execute(user_id = 8, user_name = 'fred')
- users.insert().execute(user_id = 9, user_name = None)
-
- s = users.select(users.c.user_name.in_([]))
- r = s.execute().fetchall()
- # No username is in empty set
- assert len(r) == 0
-
- s = users.select(not_(users.c.user_name.in_([])))
- r = s.execute().fetchall()
- # All usernames with a value are outside an empty set
- assert len(r) == 2
-
- s = users.select(users.c.user_name.in_(['jack','fred']))
- r = s.execute().fetchall()
- assert len(r) == 2
-
- s = users.select(not_(users.c.user_name.in_(['jack','fred'])))
- r = s.execute().fetchall()
- # Null values are not outside any set
- assert len(r) == 0
-
- u = bindparam('search_key')
-
- s = users.select(u.in_([]))
- r = s.execute(search_key='john').fetchall()
- assert len(r) == 0
- r = s.execute(search_key=None).fetchall()
- assert len(r) == 0
-
- s = users.select(not_(u.in_([])))
- r = s.execute(search_key='john').fetchall()
- assert len(r) == 3
- r = s.execute(search_key=None).fetchall()
- assert len(r) == 0
-
-    @testing.fails_on('firebird', 'FIXME: unknown')
-    @testing.fails_on('maxdb', 'FIXME: unknown')
-    @testing.fails_on('oracle', 'FIXME: unknown')
-    @testing.fails_on('mssql', 'FIXME: unknown')
-    def test_in_filtering_advanced(self):
-        """Compare the value of in_([]) itself against True, False and None.
-
-        With one NULL user name among three rows, the empty IN is expected
-        to be true for no rows, false for the two non-NULL names, and NULL
-        for the remaining row.
-        """
-
-        for uid, uname in ((7, 'jack'), (8, 'fred'), (9, None)):
-            users.insert().execute(user_id=uid, user_name=uname)
-
-        # (comparand, expected number of matching rows)
-        for comparand, expected in ((True, 0), (False, 2), (None, 1)):
-            stmt = users.select(users.c.user_name.in_([]) == comparand)
-            rows = stmt.execute().fetchall()
-            assert len(rows) == expected
-
-class PercentSchemaNamesTest(TestBase):
-    """Round-trip tests for table and column names that contain percent
-    signs and spaces.
-
-    Marked as crashing on mysql and postgres (see the decorators), whose
-    DBAPIs apply ``name % (params)`` to the statement.  This is really a
-    SQLAlchemy bug - we should be escaping out %% signs for this
-    operation the same way we do for text() and column labels.
-
-    """
-    @testing.crashes('mysql', 'mysqldb calls name % (params)')
-    @testing.crashes('postgres', 'postgres calls name % (params)')
-    def setUpAll(self):
-        global percent_table, metadata
-        metadata = MetaData(testing.db)
-        percent_table = Table('percent%table', metadata,
-            Column("percent%", Integer),
-            Column("%(oneofthese)s", Integer),
-            Column("spaces % more spaces", Integer),
-        )
-        metadata.create_all()
-
-    @testing.crashes('mysql', 'mysqldb calls name % (params)')
-    @testing.crashes('postgres', 'postgres calls name % (params)')
-    def tearDownAll(self):
-        metadata.drop_all()
-
-    @testing.crashes('mysql', 'mysqldb calls name % (params)')
-    @testing.crashes('postgres', 'postgres calls name % (params)')
-    def test_roundtrip(self):
-        """Insert, select (plain and aliased), index rows and update using
-        the percent-sign column names."""
-        percent_table.insert().execute(
-            {'percent%':5, '%(oneofthese)s':7, 'spaces % more spaces':12},
-        )
-        percent_table.insert().execute(
-            {'percent%':7, '%(oneofthese)s':8, 'spaces % more spaces':11},
-            {'percent%':9, '%(oneofthese)s':9, 'spaces % more spaces':10},
-            {'percent%':11, '%(oneofthese)s':10, 'spaces % more spaces':9},
-        )
-
-        for table in (percent_table, percent_table.alias()):
-            eq_(
-                table.select().order_by(table.c['%(oneofthese)s']).execute().fetchall(),
-                [
-                    (5, 7, 12),
-                    (7, 8, 11),
-                    (9, 9, 10),
-                    (11, 10, 9)
-                ]
-            )
-
-            eq_(
-                table.select().
-                    where(table.c['spaces % more spaces'].in_([9, 10])).
-                    order_by(table.c['%(oneofthese)s']).execute().fetchall(),
-                [
-                    (9, 9, 10),
-                    (11, 10, 9)
-                ]
-            )
-
-            result = table.select().order_by(table.c['%(oneofthese)s']).execute()
-            row = result.fetchone()
-            eq_(row[table.c['percent%']], 5)
-            eq_(row[table.c['%(oneofthese)s']], 7)
-            eq_(row[table.c['spaces % more spaces']], 12)
-            row = result.fetchone()
-            eq_(row['percent%'], 7)
-            eq_(row['%(oneofthese)s'], 8)
-            eq_(row['spaces % more spaces'], 11)
-            result.close()
-
-        percent_table.update().values({percent_table.c['%(oneofthese)s']:9, percent_table.c['spaces % more spaces']:15}).execute()
-
-        # After the update every row has %(oneofthese)s == 9, so ordering by
-        # that column would leave the row order undefined; order by the
-        # still-distinct percent% column instead.
-        eq_(
-            percent_table.select().order_by(percent_table.c['percent%']).execute().fetchall(),
-            [
-                (5, 9, 15),
-                (7, 9, 15),
-                (9, 9, 15),
-                (11, 9, 15)
-            ]
-        )
-
-
-
-class LimitTest(TestBase):
-    """Tests LIMIT and OFFSET, alone and combined with DISTINCT.
-
-    Users 1-7 are inserted once for the whole class, each with a single
-    address row; 'addr1' and 'addr5' are each used twice, so the address
-    column holds five distinct values.
-    """
-
-    def setUpAll(self):
-        global users, addresses, metadata
-        metadata = MetaData(testing.db)
-        users = Table('query_users', metadata,
-            Column('user_id', INT, primary_key = True),
-            Column('user_name', VARCHAR(20)),
-        )
-        addresses = Table('query_addresses', metadata,
-            Column('address_id', Integer, primary_key=True),
-            Column('user_id', Integer, ForeignKey('query_users.user_id')),
-            Column('address', String(30)))
-        metadata.create_all()
-        self._data()
-
-    def _data(self):
-        """Insert users 1-7 and one address row for each of them."""
-        users.insert().execute(user_id=1, user_name='john')
-        addresses.insert().execute(address_id=1, user_id=1, address='addr1')
-        users.insert().execute(user_id=2, user_name='jack')
-        addresses.insert().execute(address_id=2, user_id=2, address='addr1')
-        users.insert().execute(user_id=3, user_name='ed')
-        addresses.insert().execute(address_id=3, user_id=3, address='addr2')
-        users.insert().execute(user_id=4, user_name='wendy')
-        addresses.insert().execute(address_id=4, user_id=4, address='addr3')
-        users.insert().execute(user_id=5, user_name='laura')
-        addresses.insert().execute(address_id=5, user_id=5, address='addr4')
-        users.insert().execute(user_id=6, user_name='ralph')
-        addresses.insert().execute(address_id=6, user_id=6, address='addr5')
-        users.insert().execute(user_id=7, user_name='fido')
-        addresses.insert().execute(address_id=7, user_id=7, address='addr5')
-
-    def tearDownAll(self):
-        metadata.drop_all()
-
-    def test_select_limit(self):
-        r = users.select(limit=3, order_by=[users.c.user_id]).execute().fetchall()
-        self.assert_(r == [(1, 'john'), (2, 'jack'), (3, 'ed')], repr(r))
-
-    @testing.fails_on('maxdb', 'FIXME: unknown')
-    def test_select_limit_offset(self):
-        """Test the interaction between limit and offset"""
-
-        r = users.select(limit=3, offset=2, order_by=[users.c.user_id]).execute().fetchall()
-        self.assert_(r==[(3, 'ed'), (4, 'wendy'), (5, 'laura')], repr(r))
-        r = users.select(offset=5, order_by=[users.c.user_id]).execute().fetchall()
-        self.assert_(r==[(6, 'ralph'), (7, 'fido')], repr(r))
-
-    def test_select_distinct_limit(self):
-        """Test the interaction between limit and distinct"""
-
-        r = sorted([x[0] for x in select([addresses.c.address]).distinct().limit(3).order_by(addresses.c.address).execute().fetchall()])
-        self.assert_(len(r) == 3, repr(r))
-        self.assert_(r[0] != r[1] and r[1] != r[2], repr(r))
-
-    @testing.fails_on('mssql', 'FIXME: unknown')
-    def test_select_distinct_offset(self):
-        """Test the interaction between distinct and offset"""
-
-        r = sorted([x[0] for x in select([addresses.c.address]).distinct().offset(1).order_by(addresses.c.address).execute().fetchall()])
-        self.assert_(len(r) == 4, repr(r))
-        # every neighbouring pair of the four sorted values must differ
-        self.assert_(r[0] != r[1] and r[1] != r[2] and r[2] != r[3], repr(r))
-
-    def test_select_distinct_limit_offset(self):
-        """Test the interaction between distinct and limit/offset"""
-
-        r = select([addresses.c.address]).order_by(addresses.c.address).distinct().offset(2).limit(3).execute().fetchall()
-        self.assert_(len(r) == 3, repr(r))
-        self.assert_(r[0] != r[1] and r[1] != r[2], repr(r))
-
-class CompoundTest(TestBase):
-    """Execution tests for compound selects (UNION, UNION ALL, INTERSECT,
-    EXCEPT), focusing on how they nest inside each other and inside
-    aliases on the different databases.
-
-    t1, t2 and t3 each hold three rows whose col3 values are 'aaa', 'bbb'
-    and 'ccc'; col4 holds a different rotation of those values per table.
-    """
-    def setUpAll(self):
-        global metadata, t1, t2, t3
-        metadata = MetaData(testing.db)
-
-        def build(name):
-            # every table gets its own primary key sequence, e.g. 't1pkseq'
-            return Table(name, metadata,
-                Column('col1', Integer, Sequence(name + 'pkseq'), primary_key=True),
-                Column('col2', String(30)),
-                Column('col3', String(40)),
-                Column('col4', String(30)))
-
-        t1, t2, t3 = [build(tname) for tname in ('t1', 't2', 't3')]
-        metadata.create_all()
-
-        # (col2, col3, col4) triples for each table, inserted in this order
-        fixtures = (
-            (t1, [("t1col2r1", "aaa", "aaa"),
-                  ("t1col2r2", "bbb", "bbb"),
-                  ("t1col2r3", "ccc", "ccc")]),
-            (t2, [("t2col2r1", "aaa", "bbb"),
-                  ("t2col2r2", "bbb", "ccc"),
-                  ("t2col2r3", "ccc", "aaa")]),
-            (t3, [("t3col2r1", "aaa", "ccc"),
-                  ("t3col2r2", "bbb", "aaa"),
-                  ("t3col2r3", "ccc", "bbb")]),
-        )
-        for table, triples in fixtures:
-            table.insert().execute([
-                dict(col2=c2, col3=c3, col4=c4) for c2, c3, c4 in triples
-            ])
-
-    def tearDownAll(self):
-        metadata.drop_all()
-
-    def _fetchall_sorted(self, executed):
-        # sort the fetched rows so comparisons ignore database row order
-        rows = [tuple(row) for row in executed.fetchall()]
-        rows.sort()
-        return rows
-
-    def _labeled_pair(self):
-        """Return two labeled col3/col4 selects: two rows of t1, two of t2."""
-        from_t1 = select([t1.c.col3.label('col3'), t1.c.col4.label('col4')],
-            t1.c.col2.in_(["t1col2r1", "t1col2r2"]))
-        from_t2 = select([t2.c.col3.label('col3'), t2.c.col4.label('col4')],
-            t2.c.col2.in_(["t2col2r2", "t2col2r3"]))
-        return from_t1, from_t2
-
-    def _union_of_all_tables(self):
-        """Return a UNION of col3/col4 selected from t1, t2 and t3."""
-        return union(
-            select([t1.c.col3, t1.c.col4]),
-            select([t2.c.col3, t2.c.col4]),
-            select([t3.c.col3, t3.c.col4]),
-        )
-
-    @testing.requires.subqueries
-    def test_union(self):
-        u = union(*self._labeled_pair())
-
-        wanted = [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'),
-                  ('ccc', 'aaa')]
-        self.assertEquals(self._fetchall_sorted(u.execute()), wanted)
-        self.assertEquals(
-            self._fetchall_sorted(u.alias('bar').select().execute()), wanted)
-
-    def test_union_ordered(self):
-        first, second = self._labeled_pair()
-        u = union(first, second, order_by=['col3', 'col4'])
-
-        wanted = [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'),
-                  ('ccc', 'aaa')]
-        self.assertEquals(u.execute().fetchall(), wanted)
-
-    @testing.fails_on('maxdb', 'FIXME: unknown')
-    @testing.requires.subqueries
-    def test_union_ordered_alias(self):
-        first, second = self._labeled_pair()
-        u = union(first, second, order_by=['col3', 'col4'])
-
-        wanted = [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'),
-                  ('ccc', 'aaa')]
-        self.assertEquals(u.alias('bar').select().execute().fetchall(), wanted)
-
-    @testing.crashes('oracle', 'FIXME: unknown, verify not fails_on')
-    @testing.fails_on('mysql', 'FIXME: unknown')
-    @testing.fails_on('sqlite', 'FIXME: unknown')
-    def test_union_all(self):
-        def col3_of_t1():
-            return select([t1.c.col3])
-
-        e = union_all(col3_of_t1(), union(col3_of_t1(), col3_of_t1()))
-
-        wanted = [('aaa',),('aaa',),('bbb',), ('bbb',), ('ccc',),('ccc',)]
-        self.assertEquals(self._fetchall_sorted(e.execute()), wanted)
-        self.assertEquals(
-            self._fetchall_sorted(e.alias('foo').select().execute()), wanted)
-
-    @testing.crashes('firebird', 'Does not support intersect')
-    @testing.crashes('sybase', 'FIXME: unknown, verify not fails_on')
-    @testing.fails_on('mysql', 'FIXME: unknown')
-    def test_intersect(self):
-        everything = select([t2.c.col3, t2.c.col4])
-        matched = select([t2.c.col3, t2.c.col4], t2.c.col4==t3.c.col3)
-        i = intersect(everything, matched)
-
-        wanted = [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
-        self.assertEquals(self._fetchall_sorted(i.execute()), wanted)
-        self.assertEquals(
-            self._fetchall_sorted(i.alias('bar').select().execute()), wanted)
-
-    @testing.crashes('firebird', 'Does not support except')
-    @testing.crashes('oracle', 'FIXME: unknown, verify not fails_on')
-    @testing.crashes('sybase', 'FIXME: unknown, verify not fails_on')
-    @testing.fails_on('mysql', 'FIXME: unknown')
-    def test_except_style1(self):
-        e = except_(self._union_of_all_tables(),
-                    select([t2.c.col3, t2.c.col4]))
-
-        wanted = [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'),
-                  ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')]
-        self.assertEquals(
-            self._fetchall_sorted(e.alias('bar').select().execute()), wanted)
-
-    @testing.crashes('firebird', 'Does not support except')
-    @testing.crashes('oracle', 'FIXME: unknown, verify not fails_on')
-    @testing.crashes('sybase', 'FIXME: unknown, verify not fails_on')
-    @testing.fails_on('mysql', 'FIXME: unknown')
-    def test_except_style2(self):
-        e = except_(self._union_of_all_tables().alias('foo').select(),
-                    select([t2.c.col3, t2.c.col4]))
-
-        wanted = [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'),
-                  ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')]
-        self.assertEquals(self._fetchall_sorted(e.execute()), wanted)
-        self.assertEquals(
-            self._fetchall_sorted(e.alias('bar').select().execute()), wanted)
-
-    @testing.crashes('firebird', 'Does not support except')
-    @testing.crashes('oracle', 'FIXME: unknown, verify not fails_on')
-    @testing.crashes('sybase', 'FIXME: unknown, verify not fails_on')
-    @testing.fails_on('mysql', 'FIXME: unknown')
-    @testing.fails_on('sqlite', 'FIXME: unknown')
-    def test_except_style3(self):
-        # aaa, bbb, ccc - (aaa, bbb, ccc - (ccc)) = ccc
-        inner = except_(
-            select([t2.c.col3]), # aaa, bbb, ccc
-            select([t3.c.col3], t3.c.col3 == 'ccc'), #ccc
-        )
-        e = except_(select([t1.c.col3]), inner) # aaa, bbb, ccc minus inner
-
-        only_ccc = [('ccc',)]
-        self.assertEquals(e.execute().fetchall(), only_ccc)
-        self.assertEquals(e.alias('foo').select().execute().fetchall(),
-                          only_ccc)
-
-    @testing.crashes('firebird', 'Does not support intersect')
-    @testing.fails_on('mysql', 'FIXME: unknown')
-    def test_composite(self):
-        u = intersect(
-            select([t2.c.col3, t2.c.col4]),
-            self._union_of_all_tables().alias('foo').select()
-        )
-
-        wanted = [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
-        self.assertEquals(self._fetchall_sorted(u.execute()), wanted)
-
-    @testing.crashes('firebird', 'Does not support intersect')
-    @testing.fails_on('mysql', 'FIXME: unknown')
-    def test_composite_alias(self):
-        ua = intersect(
-            select([t2.c.col3, t2.c.col4]),
-            self._union_of_all_tables().alias('foo').select()
-        ).alias('bar')
-
-        wanted = [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
-        self.assertEquals(self._fetchall_sorted(ua.select().execute()), wanted)
-
-
-class JoinTest(TestBase):
-    """Tests join execution.
-
-    The compiled SQL emitted by the dialect might be ANSI joins or
-    theta joins ('old oracle style', with (+) for OUTER). This test
-    tries to exercise join syntax and uncover any inconsistencies in
-    `JOIN rhs ON lhs.col=rhs.col` vs `rhs.col=lhs.col`. At least one
-    database seems to be sensitive to this.
-
-    Each test therefore runs its statement once for each orientation
-    of the join criterion.
-    """
-
-    def setUpAll(self):
-        global metadata
-        global t1, t2, t3
-
-        metadata = MetaData(testing.db)
-        t1 = Table('t1', metadata,
-            Column('t1_id', Integer, primary_key=True),
-            Column('name', String(32)))
-        t2 = Table('t2', metadata,
-            Column('t2_id', Integer, primary_key=True),
-            Column('t1_id', Integer, ForeignKey('t1.t1_id')),
-            Column('name', String(32)))
-        t3 = Table('t3', metadata,
-            Column('t3_id', Integer, primary_key=True),
-            Column('t2_id', Integer, ForeignKey('t2.t2_id')),
-            Column('name', String(32)))
-        metadata.drop_all()
-        metadata.create_all()
-
-        # t1.10 -> t2.20 -> t3.30
-        # t1.11 -> t2.21
-        # t1.12
-        t1.insert().execute({'t1_id': 10, 'name': 't1 #10'},
-                            {'t1_id': 11, 'name': 't1 #11'},
-                            {'t1_id': 12, 'name': 't1 #12'})
-        t2.insert().execute({'t2_id': 20, 't1_id': 10, 'name': 't2 #20'},
-                            {'t2_id': 21, 't1_id': 11, 'name': 't2 #21'})
-        t3.insert().execute({'t3_id': 30, 't2_id': 20, 'name': 't3 #30'})
-
-    def tearDownAll(self):
-        metadata.drop_all()
-
-    def assertRows(self, statement, expected):
-        """Execute a statement and assert that rows returned equal expected."""
-
-        found = sorted([tuple(row)
-                        for row in statement.execute().fetchall()])
-
-        self.assertEquals(found, sorted(expected))
-
-    def test_join_x1(self):
-        """Joins t1->t2."""
-
-        for criteria in (t1.c.t1_id==t2.c.t1_id, t2.c.t1_id==t1.c.t1_id):
-            expr = select(
-                [t1.c.t1_id, t2.c.t2_id],
-                from_obj=[t1.join(t2, criteria)])
-            self.assertRows(expr, [(10, 20), (11, 21)])
-
-    def test_join_x2(self):
-        """Joins t1->t2->t3."""
-
-        # only the t1 #10 -> t2 #20 -> t3 #30 chain survives both inner joins
-        for criteria in (t2.c.t2_id==t3.c.t2_id, t3.c.t2_id==t2.c.t2_id):
-            expr = select(
-                [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
-                from_obj=[t1.join(t2, t1.c.t1_id==t2.c.t1_id).
-                          join(t3, criteria)])
-            self.assertRows(expr, [(10, 20, 30)])
-
-    def test_outerjoin_x1(self):
-        """Outer joins t1->t2."""
-
-        # t1 #12 has no t2 row, so it is returned paired with NULL
-        for criteria in (t1.c.t1_id==t2.c.t1_id, t2.c.t1_id==t1.c.t1_id):
-            expr = select(
-                [t1.c.t1_id, t2.c.t2_id],
-                from_obj=[t1.outerjoin(t2, criteria)])
-            self.assertRows(expr, [(10, 20), (11, 21), (12, None)])
-
-    def test_outerjoin_x2(self):
-        """Outer joins t1->t2,t3."""
-
-        for criteria in (t2.c.t2_id==t3.c.t2_id, t3.c.t2_id==t2.c.t2_id):
-            expr = select(
-                [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
-                from_obj=[t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id). \
-                          outerjoin(t3, criteria)])
-            self.assertRows(expr, [(10, 20, 30), (11, 21, None), (12, None, None)])
-
-    def test_outerjoin_where_x2_t1(self):
-        """Outer joins t1->t2,t3, where on t1."""
-
-        for criteria in (t2.c.t2_id==t3.c.t2_id, t3.c.t2_id==t2.c.t2_id):
-            expr = select(
-                [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
-                t1.c.name == 't1 #10',
-                from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id).
-                           outerjoin(t3, criteria))])
-            self.assertRows(expr, [(10, 20, 30)])
-
-            expr = select(
-                [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
-                t1.c.t1_id < 12,
-                from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id).
-                           outerjoin(t3, criteria))])
-            self.assertRows(expr, [(10, 20, 30), (11, 21, None)])
-
-    def test_outerjoin_where_x2_t2(self):
-        """Outer joins t1->t2,t3, where on t2."""
-
-        for criteria in (t2.c.t2_id==t3.c.t2_id, t3.c.t2_id==t2.c.t2_id):
-            expr = select(
-                [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
-                t2.c.name == 't2 #20',
-                from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id).
-                           outerjoin(t3, criteria))])
-            self.assertRows(expr, [(10, 20, 30)])
-
-            expr = select(
-                [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
-                t2.c.t2_id < 29,
-                from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id).
-                           outerjoin(t3, criteria))])
-            self.assertRows(expr, [(10, 20, 30), (11, 21, None)])
-
-    def test_outerjoin_where_x2_t1t2(self):
-        """Outer joins t1->t2,t3, where on t1 and t2."""
-
-        for criteria in (t2.c.t2_id==t3.c.t2_id, t3.c.t2_id==t2.c.t2_id):
-            expr = select(
-                [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
-                and_(t1.c.name == 't1 #10', t2.c.name == 't2 #20'),
-                from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id).
-                           outerjoin(t3, criteria))])
-            self.assertRows(expr, [(10, 20, 30)])
-
-            expr = select(
-                [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
-                and_(t1.c.t1_id < 19, 29 > t2.c.t2_id),
-                from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id).
-                           outerjoin(t3, criteria))])
-            self.assertRows(expr, [(10, 20, 30), (11, 21, None)])
-
-    def test_outerjoin_where_x2_t3(self):
-        """Outer joins t1->t2,t3, where on t3."""
-
-        for criteria in (t2.c.t2_id==t3.c.t2_id, t3.c.t2_id==t2.c.t2_id):
-            expr = select(
-                [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
-                t3.c.name == 't3 #30',
-                from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id).
-                           outerjoin(t3, criteria))])
-            self.assertRows(expr, [(10, 20, 30)])
-
-            expr = select(
-                [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
-                t3.c.t3_id < 39,
-                from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id).
-                           outerjoin(t3, criteria))])
-            self.assertRows(expr, [(10, 20, 30)])
-
-    def test_outerjoin_where_x2_t1t3(self):
-        """Outer joins t1->t2,t3, where on t1 and t3."""
-
-        for criteria in (t2.c.t2_id==t3.c.t2_id, t3.c.t2_id==t2.c.t2_id):
-            expr = select(
-                [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
-                and_(t1.c.name == 't1 #10', t3.c.name == 't3 #30'),
-                from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id).
-                           outerjoin(t3, criteria))])
-            self.assertRows(expr, [(10, 20, 30)])
-
-            expr = select(
-                [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
-                and_(t1.c.t1_id < 19, t3.c.t3_id < 39),
-                from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id).
-                           outerjoin(t3, criteria))])
-            self.assertRows(expr, [(10, 20, 30)])
-
-    def test_outerjoin_where_x2_t1t2_alt(self):
-        """Outer joins t1->t2,t3, where on t1 and t2 (alternate bounds).
-
-        Named distinctly from test_outerjoin_where_x2_t1t2 so that both
-        methods exist on the class instead of one replacing the other.
-        """
-
-        for criteria in (t2.c.t2_id==t3.c.t2_id, t3.c.t2_id==t2.c.t2_id):
-            expr = select(
-                [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
-                and_(t1.c.name == 't1 #10', t2.c.name == 't2 #20'),
-                from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id).
-                           outerjoin(t3, criteria))])
-            self.assertRows(expr, [(10, 20, 30)])
-
-            expr = select(
-                [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
-                and_(t1.c.t1_id < 12, t2.c.t2_id < 39),
-                from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id).
-                           outerjoin(t3, criteria))])
-            self.assertRows(expr, [(10, 20, 30), (11, 21, None)])
-
-    def test_outerjoin_where_x2_t1t2t3(self):
-        """Outer joins t1->t2,t3, where on t1, t2 and t3."""
-
-        for criteria in (t2.c.t2_id==t3.c.t2_id, t3.c.t2_id==t2.c.t2_id):
-            expr = select(
-                [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
-                and_(t1.c.name == 't1 #10',
-                     t2.c.name == 't2 #20',
-                     t3.c.name == 't3 #30'),
-                from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id).
-                           outerjoin(t3, criteria))])
-            self.assertRows(expr, [(10, 20, 30)])
-
-            expr = select(
-                [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
-                and_(t1.c.t1_id < 19,
-                     t2.c.t2_id < 29,
-                     t3.c.t3_id < 39),
-                from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id).
-                           outerjoin(t3, criteria))])
-            self.assertRows(expr, [(10, 20, 30)])
-
-    def test_mixed(self):
-        """Joins t1->t2, outer t2->t3."""
-
-        for criteria in (t2.c.t2_id==t3.c.t2_id, t3.c.t2_id==t2.c.t2_id):
-            expr = select(
-                [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
-                from_obj=[(t1.join(t2).outerjoin(t3, criteria))])
-            self.assertRows(expr, [(10, 20, 30), (11, 21, None)])
-
-    def test_mixed_where(self):
-        """Joins t1->t2, outer t2->t3, plus a where on each table in turn."""
-
-        for criteria in (t2.c.t2_id==t3.c.t2_id, t3.c.t2_id==t2.c.t2_id):
-            expr = select(
-                [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
-                t1.c.name == 't1 #10',
-                from_obj=[(t1.join(t2).outerjoin(t3, criteria))])
-            self.assertRows(expr, [(10, 20, 30)])
-
-            expr = select(
-                [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
-                t2.c.name == 't2 #20',
-                from_obj=[(t1.join(t2).outerjoin(t3, criteria))])
-            self.assertRows(expr, [(10, 20, 30)])
-
-            expr = select(
-                [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
-                t3.c.name == 't3 #30',
-                from_obj=[(t1.join(t2).outerjoin(t3, criteria))])
-            self.assertRows(expr, [(10, 20, 30)])
-
-            expr = select(
-                [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
-                and_(t1.c.name == 't1 #10', t2.c.name == 't2 #20'),
-                from_obj=[(t1.join(t2).outerjoin(t3, criteria))])
-            self.assertRows(expr, [(10, 20, 30)])
-
-            expr = select(
-                [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
-                and_(t2.c.name == 't2 #20', t3.c.name == 't3 #30'),
-                from_obj=[(t1.join(t2).outerjoin(t3, criteria))])
-            self.assertRows(expr, [(10, 20, 30)])
-
-            expr = select(
-                [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
-                and_(t1.c.name == 't1 #10',
-                     t2.c.name == 't2 #20',
-                     t3.c.name == 't3 #30'),
-                from_obj=[(t1.join(t2).outerjoin(t3, criteria))])
-            self.assertRows(expr, [(10, 20, 30)])
-
-
-class OperatorTest(TestBase):
-    """Tests SQL operators evaluated by the database against real rows."""
-
-    def setUpAll(self):
-        global metadata, flds
-        metadata = MetaData(testing.db)
-        flds = Table('flds', metadata,
-            Column('idcol', Integer, Sequence('t1pkseq'), primary_key=True),
-            Column('intcol', Integer),
-            Column('strcol', String(50)),
-        )
-        metadata.create_all()
-
-        # (intcol, strcol) pairs, inserted in this order
-        seed = [(5, 'foo'), (13, 'bar')]
-        flds.insert().execute([dict(intcol=i, strcol=s) for i, s in seed])
-
-    def tearDownAll(self):
-        metadata.drop_all()
-
-    @testing.fails_on('maxdb', 'FIXME: unknown')
-    def test_modulo(self):
-        # 5 % 3 == 2 and 13 % 3 == 1, in insertion (idcol) order
-        stmt = select([flds.c.intcol % 3], order_by=flds.c.idcol)
-        self.assertEquals(stmt.execute().fetchall(), [(2,),(1,)])
-
-
-
-if "__main__" == __name__:
-    # run this module's tests through testenv when executed directly
-    testenv.main()
</del></span></pre></div>
<a id="sqlalchemybranchesnoseteststestsqltest_querypyfromrev6010sqlalchemytrunktestsqlquerypy"></a>
<div class="copfile"><h4>Copied: sqlalchemy/branches/nosetests/test/sql/test_query.py (from rev 6010, sqlalchemy/trunk/test/sql/query.py) (0 => 6027)</h4>
<pre class="diff"><span>
<span class="info">--- sqlalchemy/branches/nosetests/test/sql/test_query.py (rev 0)
+++ sqlalchemy/branches/nosetests/test/sql/test_query.py 2009-06-07 22:09:07 UTC (rev 6027)
</span><span class="lines">@@ -0,0 +1,1325 @@
</span><ins>+import datetime
+from sqlalchemy import *
+from sqlalchemy import exc, sql
+from sqlalchemy.engine import default
+from sqlalchemy.test import *
+from sqlalchemy.test.testing import eq_
+
+class QueryTest(TestBase):
+
+    @classmethod
+    def setup_class(cls):
+        """Create the tables shared by every test in this class."""
+        global users, users2, addresses, metadata
+
+        def user_columns():
+            # a fresh (user_id, user_name) column pair for each table
+            return [Column('user_id', INT, primary_key=True),
+                    Column('user_name', VARCHAR(20))]
+
+        metadata = MetaData(testing.db)
+        users = Table('query_users', metadata, *user_columns())
+        addresses = Table('query_addresses', metadata,
+            Column('address_id', Integer, primary_key=True),
+            Column('user_id', Integer, ForeignKey('query_users.user_id')),
+            Column('address', String(30)))
+        users2 = Table('u2', metadata, *user_columns())
+        metadata.create_all()
+
+    def tearDown(self):
+        """Empty the shared tables after each test.
+
+        addresses is emptied before users because its user_id column is a
+        foreign key to query_users.
+        """
+        for table in (addresses, users, users2):
+            table.delete().execute()
+
+    @classmethod
+    def teardown_class(cls):
+        # drop the tables registered on the module-level metadata once the
+        # whole class has run
+        metadata.drop_all()
+
+    def test_insert(self):
+        """A single-row insert leaves exactly one row in the table."""
+        stmt = users.insert()
+        stmt.execute(user_id=7, user_name='jack')
+        assert users.count().scalar() == 1
+
+    def test_insert_heterogeneous_params(self):
+        """An insert given several parameter sets, where one set omits a column.
+
+        The third parameter set has no 'user_name'; that row is expected to
+        come back with NULL in that column.
+        """
+        users.insert().execute(
+            {'user_id':7, 'user_name':'jack'},
+            {'user_id':8, 'user_name':'ed'},
+            {'user_id':9}
+        )
+        # order explicitly: without ORDER BY the database may return rows
+        # in any order, which would make the list comparison flaky
+        rows = users.select().order_by(users.c.user_id).execute().fetchall()
+        assert rows == [(7, 'jack'), (8, 'ed'), (9, None)], repr(rows)
+
+    def test_update(self):
+        """update() with a WHERE clause rewrites the matching row."""
+        users.insert().execute(user_id=7, user_name='jack')
+        assert users.count().scalar() == 1
+
+        users.update(users.c.user_id == 7).execute(user_name='fred')
+        updated = users.select(users.c.user_id == 7).execute().fetchone()
+        assert updated['user_name'] == 'fred'
+
+    def test_lastrow_accessor(self):
+        """Exercise last_inserted_ids() and lastrow_has_defaults().
+
+        For tables with different primary key / server default layouts,
+        insert one row and check that the values reported back (primary
+        keys from last_inserted_ids(), plus a re-select of the row when
+        lastrow_has_defaults() is true) match the expected values.  Cases
+        whose 'unsupported' list names the current database are skipped.
+        """
+
+        def insert_values(table, values):
+            """Insert ``values`` into ``table``; return the row as stored.
+
+            The result is a copy of ``values`` overlaid with the primary key
+            values reported for the insert; when the insert reports that
+            defaults fired, every column is re-read from the database.
+            """
+            result = table.insert().execute(**values)
+            stored = dict(values)
+            for pk_col, pk_value in zip(table.primary_key, result.last_inserted_ids()):
+                stored[pk_col.key] = pk_value
+
+            if not result.lastrow_has_defaults():
+                return stored
+
+            criterion = and_(*[pk_col == pk_value for pk_col, pk_value
+                               in zip(table.primary_key, result.last_inserted_ids())])
+            row = table.select(criterion).execute().fetchone()
+            for column in table.c:
+                stored[column.key] = row[column]
+            return stored
+
+        # (skip info, table, values to insert, values expected back)
+        cases = [
+            (
+                {'unsupported':['sqlite']},
+                Table("t1", metadata,
+                    Column('id', Integer, Sequence('t1_id_seq', optional=True), primary_key=True),
+                    Column('foo', String(30), primary_key=True)),
+                {'foo':'hi'},
+                {'id':1, 'foo':'hi'}
+            ),
+            (
+                {'unsupported':['sqlite']},
+                Table("t2", metadata,
+                    Column('id', Integer, Sequence('t2_id_seq', optional=True), primary_key=True),
+                    Column('foo', String(30), primary_key=True),
+                    Column('bar', String(30), server_default='hi')
+                ),
+                {'foo':'hi'},
+                {'id':1, 'foo':'hi', 'bar':'hi'}
+            ),
+            (
+                {'unsupported':[]},
+                Table("t3", metadata,
+                    Column("id", String(40), primary_key=True),
+                    Column('foo', String(30), primary_key=True),
+                    Column("bar", String(30))
+                ),
+                {'id':'hi', 'foo':'thisisfoo', 'bar':"thisisbar"},
+                {'id':'hi', 'foo':'thisisfoo', 'bar':"thisisbar"}
+            ),
+            (
+                {'unsupported':[]},
+                Table("t4", metadata,
+                    Column('id', Integer, Sequence('t4_id_seq', optional=True), primary_key=True),
+                    Column('foo', String(30), primary_key=True),
+                    Column('bar', String(30), server_default='hi')
+                ),
+                {'foo':'hi', 'id':1},
+                {'id':1, 'foo':'hi', 'bar':'hi'}
+            ),
+            (
+                {'unsupported':[]},
+                Table("t5", metadata,
+                    Column('id', String(10), primary_key=True),
+                    Column('bar', String(30), server_default='hi')
+                ),
+                {'id':'id1'},
+                {'id':'id1', 'bar':'hi'},
+            ),
+        ]
+
+        for skip_info, table, values, expected in cases:
+            if testing.db.name in skip_info['unsupported']:
+                continue
+            try:
+                table.create()
+                stored = insert_values(table, values)
+                assert stored == expected, repr(stored) + " " + repr(expected)
+            finally:
+                table.drop()
+
+    def test_row_iteration(self):
+        """Iterating a result object yields one row per inserted record."""
+        users.insert().execute(
+            {'user_id':7, 'user_name':'jack'},
+            {'user_id':8, 'user_name':'ed'},
+            {'user_id':9, 'user_name':'fred'},
+        )
+        result = users.select().execute()
+        rows = [row for row in result]
+        self.assert_(len(rows) == 3)
+
+    @testing.fails_on('firebird', 'Data type unknown')
+    @testing.requires.subqueries
+    def test_anonymous_rows(self):
+        """Unlabeled expression columns are reachable as anon_1, anon_2, ..."""
+        users.insert().execute(
+            {'user_id':7, 'user_name':'jack'},
+            {'user_id':8, 'user_name':'ed'},
+            {'user_id':9, 'user_name':'fred'},
+        )
+
+        # scalar subquery yielding jack's user_id (7)
+        jack_id = select([users.c.user_id]).where(users.c.user_name=='jack').as_scalar()
+        stmt = select([jack_id + 1, jack_id + 3], bind=users.bind)
+        for row in stmt.execute():
+            assert row['anon_1'] == 8
+            assert row['anon_2'] == 10
+
+    def test_order_by_label(self):
+        """Ordering by a label from the columns clause works on each backend.
+
+        Simple labels in ORDER BY are rendered as the label name, which not
+        every database supports.  Descending order and ordering by an
+        expression built on the label are checked as well.
+
+        """
+        users.insert().execute(
+            {'user_id':7, 'user_name':'jack'},
+            {'user_id':8, 'user_name':'ed'},
+            {'user_id':9, 'user_name':'fred'},
+        )
+
+        def labeled():
+            # a fresh label expression for each statement
+            return ("test: " + users.c.user_name).label('thedata')
+
+        ascending = [("test: ed",), ("test: fred",), ("test: jack",)]
+        descending = list(reversed(ascending))
+        cases = (
+            (lambda lbl: lbl, ascending),
+            (lambda lbl: desc(lbl), descending),
+            (lambda lbl: lbl + "x", ascending),
+        )
+        for make_ordering, expected in cases:
+            concat = labeled()
+            self.assertEquals(
+                select([concat]).order_by(make_ordering(concat)).execute().fetchall(),
+                expected
+            )
+
+
+    def test_row_comparison(self):
+        """A result row compares equal to itself and to an equivalent tuple,
+        with the row on either side of == and !=."""
+        users.insert().execute(user_id = 7, user_name = 'jack')
+        rp = users.select().execute().fetchone()
+
+        self.assert_(rp == rp)
+        self.assert_(not(rp != rp))
+
+        equal = (7, 'jack')
+
+        self.assert_(rp == equal)
+        self.assert_(equal == rp)
+        self.assert_(not (rp != equal))
+        # tuple on the left-hand side of != (comparing the tuple with itself
+        # would not involve the row at all)
+        self.assert_(not (equal != rp))
+
+    @testing.fails_on('mssql', 'No support for boolean logic in column select.')
+    @testing.fails_on('oracle', 'FIXME: unknown')
+    def test_or_and_as_columns(self):
+        """and_(), or_() and not_() of literal booleans used as select columns."""
+        true, false = literal(True), literal(False)
+
+        def scalar_of(clause):
+            # evaluate a single boolean clause as the only column of a SELECT
+            return testing.db.execute(select([clause])).scalar()
+
+        for clause, expected in (
+            (and_(true, false), False),
+            (and_(true, true), True),
+            (or_(true, false), True),
+            (or_(false, false), False),
+            (not_(or_(false, false)), True),
+        ):
+            self.assertEquals(scalar_of(clause), expected)
+
+        # labeled boolean columns are reachable as row attributes
+        for x_clause, x_expected in ((or_(false, false), False),
+                                     (or_(true, false), True)):
+            row = testing.db.execute(select([x_clause.label("x"), and_(true, false).label("y")])).fetchone()
+            assert row.x == x_expected
+            assert row.y == False
+
+    def test_fetchmany(self):
+        """fetchmany(size=2) returns two of the three inserted rows."""
+        for uid, uname in ((7, 'jack'), (8, 'ed'), (9, 'fred')):
+            users.insert().execute(user_id=uid, user_name=uname)
+        result = users.select().execute()
+        batch = list(result.fetchmany(size=2))
+        self.assert_(len(batch) == 2, "fetchmany(size=2) got %s rows" % len(batch))
+
+    def test_like_ops(self):
+        """startswith(), contains() and endswith(), including a '%' in the pattern."""
+        users.insert().execute(
+            {'user_id':1, 'user_name':'apples'},
+            {'user_id':2, 'user_name':'oranges'},
+            {'user_id':3, 'user_name':'bananas'},
+            {'user_id':4, 'user_name':'legumes'},
+            {'user_id':5, 'user_name':'hi % there'},
+        )
+
+        name = users.c.user_name
+        cases = (
+            (name.startswith('apple'), [(1,)]),
+            (name.contains('i % t'), [(5,)]),
+            (name.endswith('anas'), [(3,)]),
+        )
+        for criterion, expected in cases:
+            stmt = select([users.c.user_id]).where(criterion)
+            eq_(stmt.execute().fetchall(), expected)
+
+
+    @testing.emits_warning('.*now automatically escapes.*')
+    def test_percents_in_text(self):
+        """Percent signs inside text() come back exactly as written."""
+        cases = (
+            ("select 6 % 10", 6),
+            ("select 17 % 10", 7),
+            ("select '%'", '%'),
+            ("select '%%'", '%%'),
+            ("select '%%%'", '%%%'),
+            ("select 'hello % world'", "hello % world"),
+        )
+        for statement_text, expected in cases:
+            eq_(testing.db.scalar(text(statement_text)), expected)
+
+    def test_ilike(self):
+        """ilike() matches case-insensitively; on postgres like() does not.
+
+        The multi-row query is ordered by user_id so the list comparison
+        does not depend on the order the database returns rows in.
+        """
+        users.insert().execute(
+            {'user_id':1, 'user_name':'one'},
+            {'user_id':2, 'user_name':'TwO'},
+            {'user_id':3, 'user_name':'ONE'},
+            {'user_id':4, 'user_name':'OnE'},
+        )
+
+        self.assertEquals(
+            select([users.c.user_id]).where(users.c.user_name.ilike('one')).
+                order_by(users.c.user_id).execute().fetchall(),
+            [(1, ), (3, ), (4, )])
+
+        self.assertEquals(select([users.c.user_id]).where(users.c.user_name.ilike('TWO')).execute().fetchall(), [(2, )])
+
+        if testing.against('postgres'):
+            self.assertEquals(select([users.c.user_id]).where(users.c.user_name.like('one')).execute().fetchall(), [(1, )])
+            self.assertEquals(select([users.c.user_id]).where(users.c.user_name.like('TWO')).execute().fetchall(), [])
+
+
+    def test_compiled_execute(self):
+        """A pre-compiled select runs on a Connection with bind values."""
+        users.insert().execute(user_id = 7, user_name = 'jack')
+        s = select([users], users.c.user_id==bindparam('id')).compile()
+        c = testing.db.connect()
+        try:
+            assert c.execute(s, id=7).fetchall()[0]['user_id'] == 7
+        finally:
+            # release the connection even when the assertion fails
+            c.close()
+
+    def test_compiled_insert_execute(self):
+        """A compiled insert executes directly; the row is then selectable."""
+        users.insert().compile().execute(user_id = 7, user_name = 'jack')
+        s = select([users], users.c.user_id==bindparam('id')).compile()
+        c = testing.db.connect()
+        try:
+            assert c.execute(s, id=7).fetchall()[0]['user_id'] == 7
+        finally:
+            # release the connection even when the assertion fails
+            c.close()
+
+    def test_repeated_bindparams(self):
+        """Tests that a BindParam can be used more than once.
+
+        This should be run for DB-APIs with both positional and named
+        paramstyles.
+        """
+        users.insert().execute(user_id = 7, user_name = 'jack')
+        users.insert().execute(user_id = 8, user_name = 'fred')
+
+        # the same bindparam object appears twice in the WHERE clause; the
+        # single 'userid' value must feed both occurrences
+        u = bindparam('userid')
+        s = users.select(and_(users.c.user_name==u, users.c.user_name==u))
+        r = s.execute(userid='fred').fetchall()
+        assert len(r) == 1
+
+    def test_bindparam_shortname(self):
+        """test the 'shortname' field on BindParamClause."""
+        users.insert().execute(user_id = 7, user_name = 'jack')
+        users.insert().execute(user_id = 8, user_name = 'fred')
+        # the parameter is named 'userid' but its value is supplied at
+        # execution time under the shortname
+        u = bindparam('userid', shortname='someshortname')
+        s = users.select(users.c.user_name==u)
+        r = s.execute(someshortname='fred').fetchall()
+        assert len(r) == 1
+
+    def test_bindparam_detection(self):
+        # Checks which ':name' tokens text() treats as bind parameters, by
+        # compiling against a 'qmark' dialect where each bind renders as '?'.
+        dialect = default.DefaultDialect(paramstyle='qmark')
+        prep = lambda q: str(sql.text(q).compile(dialect=dialect))
+
+        def a_eq(got, wanted):
+            # print both strings on a mismatch to ease debugging, then assert
+            if got != wanted:
+                print "Wanted %s" % wanted
+                print "Received %s" % got
+            self.assert_(got == wanted, got)
+
+        # A name immediately followed by another ':' (':this:that',
+        # ':that:other', ':that_:other', ':that$:other') is left untouched as a
+        # whole; names may contain '_' and '$'.
+        a_eq(prep('select foo'), 'select foo')
+        a_eq(prep("time='12:30:00'"), "time='12:30:00'")
+        a_eq(prep(u"time='12:30:00'"), u"time='12:30:00'")
+        a_eq(prep(":this:that"), ":this:that")
+        a_eq(prep(":this :that"), "? ?")
+        a_eq(prep("(:this),(:that :other)"), "(?),(? ?)")
+        a_eq(prep("(:this),(:that:other)"), "(?),(:that:other)")
+        a_eq(prep("(:this),(:that,:other)"), "(?),(?,?)")
+        a_eq(prep("(:that_:other)"), "(:that_:other)")
+        a_eq(prep("(:that_ :other)"), "(? ?)")
+        a_eq(prep("(:that_other)"), "(?)")
+        a_eq(prep("(:that$other)"), "(?)")
+        a_eq(prep("(:that$:other)"), "(:that$:other)")
+        a_eq(prep(".:that$ :other."), ".? ?.")
+
+        # Backslash-escaped colons: the expected output drops the backslash
+        # before an escaped name but keeps it in the time literal. The
+        # ":this \:that" case is a non-raw literal; Python keeps the unknown
+        # escape '\:' as backslash + colon.
+        a_eq(prep(r'select \foo'), r'select \foo')
+        a_eq(prep(r"time='12\:30:00'"), r"time='12\:30:00'")
+        a_eq(prep(":this \:that"), "? :that")
+        a_eq(prep(r"(\:that$other)"), "(:that$other)")
+        a_eq(prep(r".\:that$ :other."), ".:that$ ?.")
+
+ def test_delete(self):
+ users.insert().execute(user_id = 7, user_name = 'jack')
+ users.insert().execute(user_id = 8, user_name = 'fred')
+ print repr(users.select().execute().fetchall())
+
+ users.delete(users.c.user_name == 'fred').execute()
+
+ print repr(users.select().execute().fetchall())
+
+
+
+    @testing.exclude('mysql', '<', (5, 0, 37), 'database bug')
+    def test_scalar_select(self):
+        """test that scalar subqueries with labels get their type propagated to the result set."""
+        # mysql and/or mysqldb has a bug here, type isn't propagated for scalar
+        # subquery.
+        datetable = Table('datetable', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('today', DateTime))
+        datetable.create()
+        try:
+            datetable.insert().execute(id=1, today=datetime.datetime(2006, 5, 12, 12, 0, 0))
+            # scalar subquery over an alias of the same table, labeled in the
+            # enclosing select
+            s = select([datetable.alias('x').c.today]).as_scalar()
+            s2 = select([datetable.c.id, s.label('somelabel')])
+            #print s2.c.somelabel.type
+            # the labeled value must come back as a datetime object
+            assert isinstance(s2.execute().fetchone()['somelabel'], datetime.datetime)
+        finally:
+            # drop the ad-hoc table even if the assertion fails
+            datetable.drop()
+
+    def test_order_by(self):
+        """Exercises ORDER BY clause generation.
+
+        Tests simple, compound, aliased and DESC clauses.
+        """
+
+        # user_name order is the reverse of user_id order, so the two sort
+        # keys produce visibly different results
+        users.insert().execute(user_id=1, user_name='c')
+        users.insert().execute(user_id=2, user_name='b')
+        users.insert().execute(user_id=3, user_name='a')
+
+        def a_eq(executable, wanted):
+            # execute and compare the complete, ordered list of rows
+            got = list(executable.execute())
+            self.assertEquals(got, wanted)
+
+        # every case is run both without and with use_labels
+        for labels in False, True:
+            a_eq(users.select(order_by=[users.c.user_id],
+                              use_labels=labels),
+                 [(1, 'c'), (2, 'b'), (3, 'a')])
+
+            a_eq(users.select(order_by=[users.c.user_name, users.c.user_id],
+                              use_labels=labels),
+                 [(3, 'a'), (2, 'b'), (1, 'c')])
+
+            a_eq(select([users.c.user_id.label('foo')],
+                        use_labels=labels,
+                        order_by=[users.c.user_id]),
+                 [(1,), (2,), (3,)])
+
+            a_eq(select([users.c.user_id.label('foo'), users.c.user_name],
+                        use_labels=labels,
+                        order_by=[users.c.user_name, users.c.user_id]),
+                 [(3, 'a'), (2, 'b'), (1, 'c')])
+
+            a_eq(users.select(distinct=True,
+                              use_labels=labels,
+                              order_by=[users.c.user_id]),
+                 [(1, 'c'), (2, 'b'), (3, 'a')])
+
+            a_eq(select([users.c.user_id.label('foo')],
+                        distinct=True,
+                        use_labels=labels,
+                        order_by=[users.c.user_id]),
+                 [(1,), (2,), (3,)])
+
+            # the same column selected twice under different labels
+            a_eq(select([users.c.user_id.label('a'),
+                         users.c.user_id.label('b'),
+                         users.c.user_name],
+                        use_labels=labels,
+                        order_by=[users.c.user_id]),
+                 [(1, 1, 'c'), (2, 2, 'b'), (3, 3, 'a')])
+
+            # DESC via desc() and via the column's .desc() method
+            a_eq(users.select(distinct=True,
+                              use_labels=labels,
+                              order_by=[desc(users.c.user_id)]),
+                 [(3, 'a'), (2, 'b'), (1, 'c')])
+
+            a_eq(select([users.c.user_id.label('foo')],
+                        distinct=True,
+                        use_labels=labels,
+                        order_by=[users.c.user_id.desc()]),
+                 [(3,), (2,), (1,)])
+
+ def test_column_accessor(self):
+ users.insert().execute(user_id=1, user_name='john')
+ users.insert().execute(user_id=2, user_name='jack')
+ addresses.insert().execute(address_id=1, user_id=2, address='foo@...')
+
+ r = users.select(users.c.user_id==2).execute().fetchone()
+ self.assert_(r.user_id == r['user_id'] == r[users.c.user_id] == 2)
+ self.assert_(r.user_name == r['user_name'] == r[users.c.user_name] == 'jack')
+
+ r = text("select * from query_users where user_id=2", bind=testing.db).execute().fetchone()
+ self.assert_(r.user_id == r['user_id'] == r[users.c.user_id] == 2)
+ self.assert_(r.user_name == r['user_name'] == r[users.c.user_name] == 'jack')
+
+ # test slices
+ r = text("select * from query_addresses", bind=testing.db).execute().fetchone()
+ self.assert_(r[0:1] == (1,))
+ self.assert_(r[1:] == (2, 'foo@...'))
+ self.assert_(r[:-1] == (1, 2))
+
+ # test a little sqlite weirdness - with the UNION, cols come back as "query_users.user_id" in cursor.description
+ r = text("select query_users.user_id, query_users.user_name from query_users "
+ "UNION select query_users.user_id, query_users.user_name from query_users", bind=testing.db).execute().fetchone()
+ self.assert_(r['user_id']) == 1
+ self.assert_(r['user_name']) == "john"
+
+ # test using literal tablename.colname
+ r = text('select query_users.user_id AS "query_users.user_id", query_users.user_name AS "query_users.user_name" from query_users', bind=testing.db).execute().fetchone()
+ self.assert_(r['query_users.user_id']) == 1
+ self.assert_(r['query_users.user_name']) == "john"
+
+    def test_row_as_args(self):
+        # A fetched row can be passed straight back to insert().execute() as
+        # its parameter set.
+        users.insert().execute(user_id=1, user_name='john')
+        r = users.select(users.c.user_id==1).execute().fetchone()
+        users.delete().execute()
+        users.insert().execute(r)
+        assert users.select().execute().fetchall() == [(1, 'john')]
+
+    def test_result_as_args(self):
+        # Rows fetched from users can be inserted into users2, passed either
+        # as one list or unpacked as separate positional arguments.
+        # NOTE(review): the users2 reads have no ORDER BY, so the list
+        # comparisons assume the backend returns rows in insertion order.
+        users.insert().execute([dict(user_id=1, user_name='john'), dict(user_id=2, user_name='ed')])
+        r = users.select().execute()
+        users2.insert().execute(list(r))
+        assert users2.select().execute().fetchall() == [(1, 'john'), (2, 'ed')]
+
+        users2.delete().execute()
+        r = users.select().execute()
+        users2.insert().execute(*list(r))
+        assert users2.select().execute().fetchall() == [(1, 'john'), (2, 'ed')]
+
+    def test_ambiguous_column(self):
+        # users and addresses both carry a user_id column, so looking up
+        # r['user_id'] on their outer join must raise InvalidRequestError.
+        users.insert().execute(user_id=1, user_name='john')
+        r = users.outerjoin(addresses).select().execute().fetchone()
+        try:
+            print r['user_id']
+            # only reached if no error was raised; the AssertionError is not
+            # caught by the except clause below
+            assert False
+        except exc.InvalidRequestError, e:
+            # either case of the column name is accepted (presumably
+            # backend-dependent case folding)
+            assert str(e) == "Ambiguous column name 'user_id' in result set! try 'use_labels' option on select statement." or \
+                str(e) == "Ambiguous column name 'USER_ID' in result set! try 'use_labels' option on select statement."
+
+    @testing.requires.subqueries
+    def test_column_label_targeting(self):
+        # Columns of an aliased subquery must work as row keys when the outer
+        # select uses labels, including when the alias name equals the table
+        # name.
+        users.insert().execute(user_id=7, user_name='ed')
+
+        for s in (
+            users.select().alias('foo'),
+            users.select().alias(users.name),
+        ):
+            row = s.select(use_labels=True).execute().fetchone()
+            assert row[s.c.user_id] == 7
+            assert row[s.c.user_name] == 'ed'
+
+    def test_keys(self):
+        # keys() lists column names in table order; names are lower-cased
+        # before comparing (presumably because some backends report them
+        # upper-cased)
+        users.insert().execute(user_id=1, user_name='foo')
+        r = users.select().execute().fetchone()
+        self.assertEqual([x.lower() for x in r.keys()], ['user_id', 'user_name'])
+
+    def test_items(self):
+        # items() yields (name, value) pairs in column order; names are
+        # lower-cased before comparing
+        users.insert().execute(user_id=1, user_name='foo')
+        r = users.select().execute().fetchone()
+        self.assertEqual([(x[0].lower(), x[1]) for x in r.items()], [('user_id', 1), ('user_name', 'foo')])
+
+    def test_len(self):
+        # len(row) equals the number of selected columns, for select() and
+        # for plain string SQL; each row is closed before the next query
+        users.insert().execute(user_id=1, user_name='foo')
+        r = users.select().execute().fetchone()
+        self.assertEqual(len(r), 2)
+        r.close()
+        r = testing.db.execute('select user_name, user_id from query_users').fetchone()
+        self.assertEqual(len(r), 2)
+        r.close()
+        r = testing.db.execute('select user_name from query_users').fetchone()
+        self.assertEqual(len(r), 1)
+        r.close()
+
+ def test_cant_execute_join(self):
+ try:
+ users.join(addresses).execute()
+ except exc.ArgumentError, e:
+ assert str(e).startswith('Not an executable clause: ')
+
+
+
+    def test_column_order_with_simple_query(self):
+        # should return values in column definition order
+        users.insert().execute(user_id=1, user_name='foo')
+        r = users.select(users.c.user_id==1).execute().fetchone()
+        self.assertEqual(r[0], 1)
+        self.assertEqual(r[1], 'foo')
+        # names lower-cased before comparing
+        self.assertEqual([x.lower() for x in r.keys()], ['user_id', 'user_name'])
+        self.assertEqual(r.values(), [1, 'foo'])
+
+    def test_column_order_with_text_query(self):
+        # should return values in query order
+        # (the SQL string deliberately lists user_name before user_id)
+        users.insert().execute(user_id=1, user_name='foo')
+        r = testing.db.execute('select user_name, user_id from query_users').fetchone()
+        self.assertEqual(r[0], 'foo')
+        self.assertEqual(r[1], 1)
+        self.assertEqual([x.lower() for x in r.keys()], ['user_name', 'user_id'])
+        self.assertEqual(r.values(), ['foo', 1])
+
+    @testing.crashes('oracle', 'FIXME: unknown, varify not fails_on()')
+    @testing.crashes('firebird', 'An identifier must begin with a letter')
+    @testing.crashes('maxdb', 'FIXME: unknown, verify not fails_on()')
+    def test_column_accessor_shadow(self):
+        # Columns whose names collide with row attributes ('parent', 'row')
+        # or start with '__' must still be readable by key; the '__' ones
+        # must not be reachable as attributes.
+        meta = MetaData(testing.db)
+        shadowed = Table('test_shadowed', meta,
+            Column('shadow_id', INT, primary_key = True),
+            Column('shadow_name', VARCHAR(20)),
+            Column('parent', VARCHAR(20)),
+            Column('row', VARCHAR(40)),
+            Column('__parent', VARCHAR(20)),
+            Column('__row', VARCHAR(20)),
+        )
+        shadowed.create(checkfirst=True)
+        try:
+            shadowed.insert().execute(shadow_id=1, shadow_name='The Shadow', parent='The Light', row='Without light there is no shadow', __parent='Hidden parent', __row='Hidden row')
+            r = shadowed.select(shadowed.c.shadow_id==1).execute().fetchone()
+            self.assert_(r.shadow_id == r['shadow_id'] == r[shadowed.c.shadow_id] == 1)
+            self.assert_(r.shadow_name == r['shadow_name'] == r[shadowed.c.shadow_name] == 'The Shadow')
+            self.assert_(r.parent == r['parent'] == r[shadowed.c.parent] == 'The Light')
+            self.assert_(r.row == r['row'] == r[shadowed.c.row] == 'Without light there is no shadow')
+            self.assert_(r['__parent'] == 'Hidden parent')
+            self.assert_(r['__row'] == 'Hidden row')
+            try:
+                # inside a class body Python mangles r.__parent to
+                # r._<ClassName>__parent; either way AttributeError is expected
+                print r.__parent, r.__row
+                self.fail('Should not allow access to private attributes')
+            except AttributeError:
+                pass # expected
+            r.close()
+        finally:
+            # drop the ad-hoc table even if an assertion fails
+            shadowed.drop(checkfirst=True)
+
+    def test_in_filtering(self):
+        """test the behavior of the in_() function."""
+
+        # two named users plus one with a NULL name
+        users.insert().execute(user_id = 7, user_name = 'jack')
+        users.insert().execute(user_id = 8, user_name = 'fred')
+        users.insert().execute(user_id = 9, user_name = None)
+
+        s = users.select(users.c.user_name.in_([]))
+        r = s.execute().fetchall()
+        # No username is in empty set
+        assert len(r) == 0
+
+        s = users.select(not_(users.c.user_name.in_([])))
+        r = s.execute().fetchall()
+        # All usernames with a value are outside an empty set
+        assert len(r) == 2
+
+        s = users.select(users.c.user_name.in_(['jack','fred']))
+        r = s.execute().fetchall()
+        assert len(r) == 2
+
+        s = users.select(not_(users.c.user_name.in_(['jack','fred'])))
+        r = s.execute().fetchall()
+        # Null values are not outside any set
+        assert len(r) == 0
+
+        # in_() applied to a bind parameter instead of a column: the result
+        # depends only on the bound value, so all rows or no rows match
+        u = bindparam('search_key')
+
+        s = users.select(u.in_([]))
+        r = s.execute(search_key='john').fetchall()
+        assert len(r) == 0
+        r = s.execute(search_key=None).fetchall()
+        assert len(r) == 0
+
+        s = users.select(not_(u.in_([])))
+        r = s.execute(search_key='john').fetchall()
+        assert len(r) == 3
+        r = s.execute(search_key=None).fetchall()
+        assert len(r) == 0
+
+    @testing.fails_on('firebird', 'FIXME: unknown')
+    @testing.fails_on('maxdb', 'FIXME: unknown')
+    @testing.fails_on('oracle', 'FIXME: unknown')
+    @testing.fails_on('mssql', 'FIXME: unknown')
+    def test_in_filtering_advanced(self):
+        """test the behavior of the in_() function when comparing against an empty collection."""
+
+        users.insert().execute(user_id = 7, user_name = 'jack')
+        users.insert().execute(user_id = 8, user_name = 'fred')
+        users.insert().execute(user_id = 9, user_name = None)
+
+        # an empty in_() used as a boolean value: == True matches nothing,
+        # == False matches the two non-NULL names, == None matches the NULL one
+        s = users.select(users.c.user_name.in_([]) == True)
+        r = s.execute().fetchall()
+        assert len(r) == 0
+        s = users.select(users.c.user_name.in_([]) == False)
+        r = s.execute().fetchall()
+        assert len(r) == 2
+        s = users.select(users.c.user_name.in_([]) == None)
+        r = s.execute().fetchall()
+        assert len(r) == 1
+
+class PercentSchemaNamesTest(TestBase):
+ """tests using percent signs, spaces in table and column names.
+
+ Doesn't pass for mysql, postgres, but this is really a
+ SQLAlchemy bug - we should be escaping out %% signs for this
+ operation the same way we do for text() and column labels.
+
+ """
+ @classmethod
+ @testing.crashes('mysql', 'mysqldb calls name % (params)')
+ @testing.crashes('postgres', 'postgres calls name % (params)')
+ def setup_class(cls):
+ global percent_table, metadata
+ metadata = MetaData(testing.db)
+ percent_table = Table('percent%table', metadata,
+ Column("percent%", Integer),
+ Column("%(oneofthese)s", Integer),
+ Column("spaces % more spaces", Integer),
+ )
+ metadata.create_all()
+
+ @classmethod
+ @testing.crashes('mysql', 'mysqldb calls name % (params)')
+ @testing.crashes('postgres', 'postgres calls name % (params)')
+ def teardown_class(cls):
+ metadata.drop_all()
+
+ @testing.crashes('mysql', 'mysqldb calls name % (params)')
+ @testing.crashes('postgres', 'postgres calls name % (params)')
+ def test_roundtrip(self):
+ percent_table.insert().execute(
+ {'percent%':5, '%(oneofthese)s':7, 'spaces % more spaces':12},
+ )
+ percent_table.insert().execute(
+ {'percent%':7, '%(oneofthese)s':8, 'spaces % more spaces':11},
+ {'percent%':9, '%(oneofthese)s':9, 'spaces % more spaces':10},
+ {'percent%':11, '%(oneofthese)s':10, 'spaces % more spaces':9},
+ )
+
+ for table in (percent_table, percent_table.alias()):
+ eq_(
+ table.select().order_by(table.c['%(oneofthese)s']).execute().fetchall(),
+ [
+ (5, 7, 12),
+ (7, 8, 11),
+ (9, 9, 10),
+ (11, 10, 9)
+ ]
+ )
+
+ eq_(
+ table.select().
+ where(table.c['spaces % more spaces'].in_([9, 10])).
+ order_by(table.c['%(oneofthese)s']).execute().fetchall(),
+ [
+ (9, 9, 10),
+ (11, 10, 9)
+ ]
+ )
+
+ result = table.select().order_by(table.c['%(oneofthese)s']).execute()
+ row = result.fetchone()
+ eq_(row[table.c['percent%']], 5)
+ eq_(row[table.c['%(oneofthese)s']], 7)
+ eq_(row[table.c['spaces % more spaces']], 12)
+ row = result.fetchone()
+ eq_(row['percent%'], 7)
+ eq_(row['%(oneofthese)s'], 8)
+ eq_(row['spaces % more spaces'], 11)
+ result.close()
+
+ percent_table.update().values({percent_table.c['%(oneofthese)s']:9, percent_table.c['spaces % more spaces']:15}).execute()
+
+ eq_(
+ percent_table.select().order_by(percent_table.c['%(oneofthese)s']).execute().fetchall(),
+ [
+ (5, 9, 15),
+ (7, 9, 15),
+ (9, 9, 15),
+ (11, 9, 15)
+ ]
+ )
+
+
+
+class LimitTest(TestBase):
+
+ @classmethod
+ def setup_class(cls):
+ global users, addresses, metadata
+ metadata = MetaData(testing.db)
+ users = Table('query_users', metadata,
+ Column('user_id', INT, primary_key = True),
+ Column('user_name', VARCHAR(20)),
+ )
+ addresses = Table('query_addresses', metadata,
+ Column('address_id', Integer, primary_key=True),
+ Column('user_id', Integer, ForeignKey('query_users.user_id')),
+ Column('address', String(30)))
+ metadata.create_all()
+
+ users.insert().execute(user_id=1, user_name='john')
+ addresses.insert().execute(address_id=1, user_id=1, address='addr1')
+ users.insert().execute(user_id=2, user_name='jack')
+ addresses.insert().execute(address_id=2, user_id=2, address='addr1')
+ users.insert().execute(user_id=3, user_name='ed')
+ addresses.insert().execute(address_id=3, user_id=3, address='addr2')
+ users.insert().execute(user_id=4, user_name='wendy')
+ addresses.insert().execute(address_id=4, user_id=4, address='addr3')
+ users.insert().execute(user_id=5, user_name='laura')
+ addresses.insert().execute(address_id=5, user_id=5, address='addr4')
+ users.insert().execute(user_id=6, user_name='ralph')
+ addresses.insert().execute(address_id=6, user_id=6, address='addr5')
+ users.insert().execute(user_id=7, user_name='fido')
+ addresses.insert().execute(address_id=7, user_id=7, address='addr5')
+
+ @classmethod
+ def teardown_class(cls):
+ metadata.drop_all()
+
+ def test_select_limit(self):
+ r = users.select(limit=3, order_by=[users.c.user_id]).execute().fetchall()
+ self.assert_(r == [(1, 'john'), (2, 'jack'), (3, 'ed')], repr(r))
+
+ @testing.fails_on('maxdb', 'FIXME: unknown')
+ def test_select_limit_offset(self):
+ """Test the interaction between limit and offset"""
+
+ r = users.select(limit=3, offset=2, order_by=[users.c.user_id]).execute().fetchall()
+ self.assert_(r==[(3, 'ed'), (4, 'wendy'), (5, 'laura')])
+ r = users.select(offset=5, order_by=[users.c.user_id]).execute().fetchall()
+ self.assert_(r==[(6, 'ralph'), (7, 'fido')])
+
+ def test_select_distinct_limit(self):
+ """Test the interaction between limit and distinct"""
+
+ r = sorted([x[0] for x in select([addresses.c.address]).distinct().limit(3).order_by(addresses.c.address).execute().fetchall()])
+ self.assert_(len(r) == 3, repr(r))
+ self.assert_(r[0] != r[1] and r[1] != r[2], repr(r))
+
+ @testing.fails_on('mssql', 'FIXME: unknown')
+ def test_select_distinct_offset(self):
+ """Test the interaction between distinct and offset"""
+
+ r = sorted([x[0] for x in select([addresses.c.address]).distinct().offset(1).order_by(addresses.c.address).execute().fetchall()])
+ self.assert_(len(r) == 4, repr(r))
+ self.assert_(r[0] != r[1] and r[1] != r[2] and r[2] != [3], repr(r))
+
+ def test_select_distinct_limit_offset(self):
+ """Test the interaction between limit and limit/offset"""
+
+ r = select([addresses.c.address]).order_by(addresses.c.address).distinct().offset(2).limit(3).execute().fetchall()
+ self.assert_(len(r) == 3, repr(r))
+ self.assert_(r[0] != r[1] and r[1] != r[2], repr(r))
+
+class CompoundTest(TestBase):
+    """test compound statements like UNION, INTERSECT, particularly their ability to nest on
+    different databases."""
+    @classmethod
+    def setup_class(cls):
+        # Three identically shaped tables; col3/col4 hold 'aaa'/'bbb'/'ccc' in
+        # different pairings so each set operation yields a distinct result.
+        global metadata, t1, t2, t3
+        metadata = MetaData(testing.db)
+        t1 = Table('t1', metadata,
+            Column('col1', Integer, Sequence('t1pkseq'), primary_key=True),
+            Column('col2', String(30)),
+            Column('col3', String(40)),
+            Column('col4', String(30))
+        )
+        t2 = Table('t2', metadata,
+            Column('col1', Integer, Sequence('t2pkseq'), primary_key=True),
+            Column('col2', String(30)),
+            Column('col3', String(40)),
+            Column('col4', String(30)))
+        t3 = Table('t3', metadata,
+            Column('col1', Integer, Sequence('t3pkseq'), primary_key=True),
+            Column('col2', String(30)),
+            Column('col3', String(40)),
+            Column('col4', String(30)))
+        metadata.create_all()
+
+        t1.insert().execute([
+            dict(col2="t1col2r1", col3="aaa", col4="aaa"),
+            dict(col2="t1col2r2", col3="bbb", col4="bbb"),
+            dict(col2="t1col2r3", col3="ccc", col4="ccc"),
+        ])
+        t2.insert().execute([
+            dict(col2="t2col2r1", col3="aaa", col4="bbb"),
+            dict(col2="t2col2r2", col3="bbb", col4="ccc"),
+            dict(col2="t2col2r3", col3="ccc", col4="aaa"),
+        ])
+        t3.insert().execute([
+            dict(col2="t3col2r1", col3="aaa", col4="ccc"),
+            dict(col2="t3col2r2", col3="bbb", col4="aaa"),
+            dict(col2="t3col2r3", col3="ccc", col4="bbb"),
+        ])
+
+    @classmethod
+    def teardown_class(cls):
+        metadata.drop_all()
+
+    def _fetchall_sorted(self, executed):
+        # convert rows to plain tuples and sort them, so results compare
+        # independently of the order the database returns them in
+        return sorted([tuple(row) for row in executed.fetchall()])
+
+    @testing.requires.subqueries
+    def test_union(self):
+        (s1, s2) = (
+            select([t1.c.col3.label('col3'), t1.c.col4.label('col4')],
+                   t1.c.col2.in_(["t1col2r1", "t1col2r2"])),
+            select([t2.c.col3.label('col3'), t2.c.col4.label('col4')],
+                   t2.c.col2.in_(["t2col2r2", "t2col2r3"]))
+        )
+        u = union(s1, s2)
+
+        wanted = [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'),
+                  ('ccc', 'aaa')]
+        found1 = self._fetchall_sorted(u.execute())
+        self.assertEquals(found1, wanted)
+
+        # same UNION, executed again as an aliased subquery
+        found2 = self._fetchall_sorted(u.alias('bar').select().execute())
+        self.assertEquals(found2, wanted)
+
+    def test_union_ordered(self):
+        (s1, s2) = (
+            select([t1.c.col3.label('col3'), t1.c.col4.label('col4')],
+                   t1.c.col2.in_(["t1col2r1", "t1col2r2"])),
+            select([t2.c.col3.label('col3'), t2.c.col4.label('col4')],
+                   t2.c.col2.in_(["t2col2r2", "t2col2r3"]))
+        )
+        # ORDER BY the label names on the compound statement itself
+        u = union(s1, s2, order_by=['col3', 'col4'])
+
+        wanted = [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'),
+                  ('ccc', 'aaa')]
+        self.assertEquals(u.execute().fetchall(), wanted)
+
+    @testing.fails_on('maxdb', 'FIXME: unknown')
+    @testing.requires.subqueries
+    def test_union_ordered_alias(self):
+        (s1, s2) = (
+            select([t1.c.col3.label('col3'), t1.c.col4.label('col4')],
+                   t1.c.col2.in_(["t1col2r1", "t1col2r2"])),
+            select([t2.c.col3.label('col3'), t2.c.col4.label('col4')],
+                   t2.c.col2.in_(["t2col2r2", "t2col2r3"]))
+        )
+        u = union(s1, s2, order_by=['col3', 'col4'])
+
+        wanted = [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'),
+                  ('ccc', 'aaa')]
+        # the ordered UNION is wrapped in an alias and selected from; the
+        # unsorted comparison relies on the inner ORDER BY
+        self.assertEquals(u.alias('bar').select().execute().fetchall(), wanted)
+
+    @testing.crashes('oracle', 'FIXME: unknown, verify not fails_on')
+    @testing.fails_on('mysql', 'FIXME: unknown')
+    @testing.fails_on('sqlite', 'FIXME: unknown')
+    def test_union_all(self):
+        # UNION ALL with a nested UNION: the nested UNION collapses to the
+        # three distinct values, and UNION ALL then keeps both copies
+        e = union_all(
+            select([t1.c.col3]),
+            union(
+                select([t1.c.col3]),
+                select([t1.c.col3]),
+            )
+        )
+
+        wanted = [('aaa',),('aaa',),('bbb',), ('bbb',), ('ccc',),('ccc',)]
+        found1 = self._fetchall_sorted(e.execute())
+        self.assertEquals(found1, wanted)
+
+        found2 = self._fetchall_sorted(e.alias('foo').select().execute())
+        self.assertEquals(found2, wanted)
+
+    @testing.crashes('firebird', 'Does not support intersect')
+    @testing.crashes('sybase', 'FIXME: unknown, verify not fails_on')
+    @testing.fails_on('mysql', 'FIXME: unknown')
+    def test_intersect(self):
+        # the second select keeps t2 rows whose col4 matches some t3.col3;
+        # with this data that is every t2 row
+        i = intersect(
+            select([t2.c.col3, t2.c.col4]),
+            select([t2.c.col3, t2.c.col4], t2.c.col4==t3.c.col3)
+        )
+
+        wanted = [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
+
+        found1 = self._fetchall_sorted(i.execute())
+        self.assertEquals(found1, wanted)
+
+        found2 = self._fetchall_sorted(i.alias('bar').select().execute())
+        self.assertEquals(found2, wanted)
+
+    @testing.crashes('firebird', 'Does not support except')
+    @testing.crashes('oracle', 'FIXME: unknown, verify not fails_on')
+    @testing.crashes('sybase', 'FIXME: unknown, verify not fails_on')
+    @testing.fails_on('mysql', 'FIXME: unknown')
+    def test_except_style1(self):
+        # (t1 UNION t2 UNION t3) EXCEPT t2, with the UNION passed directly;
+        # only the aliased form is executed here
+        e = except_(union(
+            select([t1.c.col3, t1.c.col4]),
+            select([t2.c.col3, t2.c.col4]),
+            select([t3.c.col3, t3.c.col4]),
+        ), select([t2.c.col3, t2.c.col4]))
+
+        wanted = [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'),
+                  ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')]
+
+        found = self._fetchall_sorted(e.alias('bar').select().execute())
+        self.assertEquals(found, wanted)
+
+    @testing.crashes('firebird', 'Does not support except')
+    @testing.crashes('oracle', 'FIXME: unknown, verify not fails_on')
+    @testing.crashes('sybase', 'FIXME: unknown, verify not fails_on')
+    @testing.fails_on('mysql', 'FIXME: unknown')
+    def test_except_style2(self):
+        # same as style1, but the UNION is first wrapped as an aliased select
+        e = except_(union(
+            select([t1.c.col3, t1.c.col4]),
+            select([t2.c.col3, t2.c.col4]),
+            select([t3.c.col3, t3.c.col4]),
+        ).alias('foo').select(), select([t2.c.col3, t2.c.col4]))
+
+        wanted = [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'),
+                  ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')]
+
+        found1 = self._fetchall_sorted(e.execute())
+        self.assertEquals(found1, wanted)
+
+        found2 = self._fetchall_sorted(e.alias('bar').select().execute())
+        self.assertEquals(found2, wanted)
+
+    @testing.crashes('firebird', 'Does not support except')
+    @testing.crashes('oracle', 'FIXME: unknown, verify not fails_on')
+    @testing.crashes('sybase', 'FIXME: unknown, verify not fails_on')
+    @testing.fails_on('mysql', 'FIXME: unknown')
+    @testing.fails_on('sqlite', 'FIXME: unknown')
+    def test_except_style3(self):
+        # aaa, bbb, ccc - (aaa, bbb, ccc - (ccc)) = ccc
+        e = except_(
+            select([t1.c.col3]), # aaa, bbb, ccc
+            except_(
+                select([t2.c.col3]), # aaa, bbb, ccc
+                select([t3.c.col3], t3.c.col3 == 'ccc'), #ccc
+            )
+        )
+        self.assertEquals(e.execute().fetchall(), [('ccc',)])
+        self.assertEquals(e.alias('foo').select().execute().fetchall(),
+                          [('ccc',)])
+
+    @testing.crashes('firebird', 'Does not support intersect')
+    @testing.fails_on('mysql', 'FIXME: unknown')
+    def test_composite(self):
+        # t2 rows INTERSECT an aliased UNION of all three tables; every t2
+        # pair is in the UNION, so the result is exactly t2's pairs
+        u = intersect(
+            select([t2.c.col3, t2.c.col4]),
+            union(
+                select([t1.c.col3, t1.c.col4]),
+                select([t2.c.col3, t2.c.col4]),
+                select([t3.c.col3, t3.c.col4]),
+            ).alias('foo').select()
+        )
+        wanted = [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
+        found = self._fetchall_sorted(u.execute())
+
+        self.assertEquals(found, wanted)
+
+    @testing.crashes('firebird', 'Does not support intersect')
+    @testing.fails_on('mysql', 'FIXME: unknown')
+    def test_composite_alias(self):
+        # same as test_composite, with the whole INTERSECT aliased as well
+        ua = intersect(
+            select([t2.c.col3, t2.c.col4]),
+            union(
+                select([t1.c.col3, t1.c.col4]),
+                select([t2.c.col3, t2.c.col4]),
+                select([t3.c.col3, t3.c.col4]),
+            ).alias('foo').select()
+        ).alias('bar')
+
+        wanted = [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
+        found = self._fetchall_sorted(ua.select().execute())
+        self.assertEquals(found, wanted)
+
+
+class JoinTest(TestBase):
+ """Tests join execution.
+
+ The compiled SQL emitted by the dialect might be ANSI joins or
+ theta joins ('old oracle style', with (+) for OUTER). This test
+ tries to exercise join syntax and uncover any inconsistencies in
+ `JOIN rhs ON lhs.col=rhs.col` vs `rhs.col=lhs.col`. At least one
+ database seems to be sensitive to this.
+ """
+
+    @classmethod
+    def setup_class(cls):
+        # Build t1 <- t2 <- t3 (each child has a foreign key to its parent)
+        # and load the fixture rows sketched in the comment below.
+        global metadata
+        global t1, t2, t3
+
+        metadata = MetaData(testing.db)
+        t1 = Table('t1', metadata,
+            Column('t1_id', Integer, primary_key=True),
+            Column('name', String(32)))
+        t2 = Table('t2', metadata,
+            Column('t2_id', Integer, primary_key=True),
+            Column('t1_id', Integer, ForeignKey('t1.t1_id')),
+            Column('name', String(32)))
+        t3 = Table('t3', metadata,
+            Column('t3_id', Integer, primary_key=True),
+            Column('t2_id', Integer, ForeignKey('t2.t2_id')),
+            Column('name', String(32)))
+        # presumably clears any leftover t1/t2/t3 before recreating them
+        metadata.drop_all()
+        metadata.create_all()
+
+        # t1.10 -> t2.20 -> t3.30
+        # t1.11 -> t2.21
+        # t1.12
+        t1.insert().execute({'t1_id': 10, 'name': 't1 #10'},
+                            {'t1_id': 11, 'name': 't1 #11'},
+                            {'t1_id': 12, 'name': 't1 #12'})
+        t2.insert().execute({'t2_id': 20, 't1_id': 10, 'name': 't2 #20'},
+                            {'t2_id': 21, 't1_id': 11, 'name': 't2 #21'})
+        t3.insert().execute({'t3_id': 30, 't2_id': 20, 'name': 't3 #30'})
+
+    @classmethod
+    def teardown_class(cls):
+        # drop every table created in setup_class
+        metadata.drop_all()
+
+    def assertRows(self, statement, expected):
+        """Execute a statement and assert that rows returned equal expected."""
+
+        # both sides are sorted so the comparison ignores row order
+        found = sorted([tuple(row)
+                        for row in statement.execute().fetchall()])
+
+        self.assertEquals(found, sorted(expected))
+
+    def test_join_x1(self):
+        """Joins t1->t2."""
+
+        # the ON clause is tried with its operands in both orders
+        for criteria in (t1.c.t1_id==t2.c.t1_id, t2.c.t1_id==t1.c.t1_id):
+            expr = select(
+                [t1.c.t1_id, t2.c.t2_id],
+                from_obj=[t1.join(t2, criteria)])
+            self.assertRows(expr, [(10, 20), (11, 21)])
+
+    def test_join_x2(self):
+        """Joins t1->t2."""
+        # NOTE(review): this body is identical to test_join_x1 and never
+        # touches t3, although the name (and an earlier docstring claiming
+        # t1->t2->t3) suggest a three-table join was intended.
+
+        for criteria in (t1.c.t1_id==t2.c.t1_id, t2.c.t1_id==t1.c.t1_id):
+            expr = select(
+                [t1.c.t1_id, t2.c.t2_id],
+                from_obj=[t1.join(t2, criteria)])
+            self.assertRows(expr, [(10, 20), (11, 21)])
+
+    def test_outerjoin_x1(self):
+        """Inner joins t1->t2->t3."""
+        # NOTE(review): despite the name, no outerjoin() is used. t1.join(t2)
+        # has no explicit ON clause (it comes from the t2.t1_id foreign key);
+        # t2->t3 uses each criteria variant.
+
+        for criteria in (t2.c.t2_id==t3.c.t2_id, t3.c.t2_id==t2.c.t2_id):
+            expr = select(
+                [t1.c.t1_id, t2.c.t2_id],
+                from_obj=[t1.join(t2).join(t3, criteria)])
+            self.assertRows(expr, [(10, 20)])
+
+    def test_outerjoin_x2(self):
+        """Outer joins t1->t2,t3."""
+
+        # every t1 row survives; missing t2/t3 matches come back as None
+        for criteria in (t2.c.t2_id==t3.c.t2_id, t3.c.t2_id==t2.c.t2_id):
+            expr = select(
+                [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
+                from_obj=[t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id). \
+                              outerjoin(t3, criteria)])
+            self.assertRows(expr, [(10, 20, 30), (11, 21, None), (12, None, None)])
+
+    def test_outerjoin_where_x2_t1(self):
+        """Outer joins t1->t2,t3, where on t1."""
+
+        for criteria in (t2.c.t2_id==t3.c.t2_id, t3.c.t2_id==t2.c.t2_id):
+            # equality filter on t1.name
+            expr = select(
+                [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
+                t1.c.name == 't1 #10',
+                from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id).
+                           outerjoin(t3, criteria))])
+            self.assertRows(expr, [(10, 20, 30)])
+
+            # range filter on t1.t1_id excludes t1.12
+            expr = select(
+                [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
+                t1.c.t1_id < 12,
+                from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id).
+                           outerjoin(t3, criteria))])
+            self.assertRows(expr, [(10, 20, 30), (11, 21, None)])
+
+    def test_outerjoin_where_x2_t2(self):
+        """Outer joins t1->t2,t3, where on t2."""
+
+        for criteria in (t2.c.t2_id==t3.c.t2_id, t3.c.t2_id==t2.c.t2_id):
+            # equality filter on t2.name
+            expr = select(
+                [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
+                t2.c.name == 't2 #20',
+                from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id).
+                           outerjoin(t3, criteria))])
+            self.assertRows(expr, [(10, 20, 30)])
+
+            # range filter on t2.t2_id; t1.12 has no t2 row (NULL) and drops out
+            expr = select(
+                [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
+                t2.c.t2_id < 29,
+                from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id).
+                           outerjoin(t3, criteria))])
+            self.assertRows(expr, [(10, 20, 30), (11, 21, None)])
+
+    def test_outerjoin_where_x2_t1t2(self):
+        """Outer joins t1->t2,t3, where on t1 and t2."""
+
+        for criteria in (t2.c.t2_id==t3.c.t2_id, t3.c.t2_id==t2.c.t2_id):
+            # equality filters on t1.name and t2.name
+            expr = select(
+                [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
+                and_(t1.c.name == 't1 #10', t2.c.name == 't2 #20'),
+                from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id).
+                           outerjoin(t3, criteria))])
+            self.assertRows(expr, [(10, 20, 30)])
+
+            # range filters, with the t2 comparison written literal-first
+            expr = select(
+                [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
+                and_(t1.c.t1_id < 19, 29 > t2.c.t2_id),
+                from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id).
+                           outerjoin(t3, criteria))])
+            self.assertRows(expr, [(10, 20, 30), (11, 21, None)])
+
+    def test_outerjoin_where_x2_t3(self):
+        """Outer joins t1->t2,t3, where on t3."""
+
+        # any filter on t3 drops the rows where t3 is NULL
+        for criteria in (t2.c.t2_id==t3.c.t2_id, t3.c.t2_id==t2.c.t2_id):
+            expr = select(
+                [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
+                t3.c.name == 't3 #30',
+                from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id).
+                           outerjoin(t3, criteria))])
+            self.assertRows(expr, [(10, 20, 30)])
+
+            expr = select(
+                [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
+                t3.c.t3_id < 39,
+                from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id).
+                           outerjoin(t3, criteria))])
+            self.assertRows(expr, [(10, 20, 30)])
+
+    def test_outerjoin_where_x2_t1t3(self):
+        """Outer joins t1->t2,t3, where on t1 and t3."""
+
+        for criteria in (t2.c.t2_id==t3.c.t2_id, t3.c.t2_id==t2.c.t2_id):
+            # equality filters on t1.name and t3.name
+            expr = select(
+                [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
+                and_(t1.c.name == 't1 #10', t3.c.name == 't3 #30'),
+                from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id).
+                           outerjoin(t3, criteria))])
+            self.assertRows(expr, [(10, 20, 30)])
+
+            # range filters on t1 and t3; only the fully joined row remains
+            expr = select(
+                [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
+                and_(t1.c.t1_id < 19, t3.c.t3_id < 39),
+                from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id).
+                           outerjoin(t3, criteria))])
+            self.assertRows(expr, [(10, 20, 30)])
+
+ def test_outerjoin_where_x2_t1t2(self):
+ """Outer joins t1->t2,t3, where on t1 and t2."""
+
+ for criteria in (t2.c.t2_id==t3.c.t2_id, t3.c.t2_id==t2.c.t2_id):
+ expr = select(
+ [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
+ and_(t1.c.name == 't1 #10', t2.c.name == 't2 #20'),
+ from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id).
+ outerjoin(t3, criteria))])
+ self.assertRows(expr, [(10, 20, 30)])
+
+ expr = select(
+ [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
+ and_(t1.c.t1_id < 12, t2.c.t2_id < 39),
+ from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id).
+ outerjoin(t3, criteria))])
+ self.assertRows(expr, [(10, 20, 30), (11, 21, None)])
+
+ def test_outerjoin_where_x2_t1t2t3(self):
+ """Outer joins t1->t2,t3, where on t1, t2 and t3."""
+
+ for criteria in (t2.c.t2_id==t3.c.t2_id, t3.c.t2_id==t2.c.t2_id):
+ expr = select(
+ [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
+ and_(t1.c.name == 't1 #10',
+ t2.c.name == 't2 #20',
+ t3.c.name == 't3 #30'),
+ from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id).
+ outerjoin(t3, criteria))])
+ self.assertRows(expr, [(10, 20, 30)])
+
+ expr = select(
+ [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
+ and_(t1.c.t1_id < 19,
+ t2.c.t2_id < 29,
+ t3.c.t3_id < 39),
+ from_obj=[(t1.outerjoin(t2, t1.c.t1_id==t2.c.t1_id).
+ outerjoin(t3, criteria))])
+ self.assertRows(expr, [(10, 20, 30)])
+
+ def test_mixed(self):
+ """Joins t1->t2, outer t2->t3."""
+
+ for criteria in (t2.c.t2_id==t3.c.t2_id, t3.c.t2_id==t2.c.t2_id):
+ expr = select(
+ [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
+ from_obj=[(t1.join(t2).outerjoin(t3, criteria))])
+ print expr
+ self.assertRows(expr, [(10, 20, 30), (11, 21, None)])
+
+ def test_mixed_where(self):
+ """Joins t1->t2, outer t2->t3, plus a where on each table in turn."""
+
+ for criteria in (t2.c.t2_id==t3.c.t2_id, t3.c.t2_id==t2.c.t2_id):
+ expr = select(
+ [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
+ t1.c.name == 't1 #10',
+ from_obj=[(t1.join(t2).outerjoin(t3, criteria))])
+ self.assertRows(expr, [(10, 20, 30)])
+
+ expr = select(
+ [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
+ t2.c.name == 't2 #20',
+ from_obj=[(t1.join(t2).outerjoin(t3, criteria))])
+ self.assertRows(expr, [(10, 20, 30)])
+
+ expr = select(
+ [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
+ t3.c.name == 't3 #30',
+ from_obj=[(t1.join(t2).outerjoin(t3, criteria))])
+ self.assertRows(expr, [(10, 20, 30)])
+
+ expr = select(
+ [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
+ and_(t1.c.name == 't1 #10', t2.c.name == 't2 #20'),
+ from_obj=[(t1.join(t2).outerjoin(t3, criteria))])
+ self.assertRows(expr, [(10, 20, 30)])
+
+ expr = select(
+ [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
+ and_(t2.c.name == 't2 #20', t3.c.name == 't3 #30'),
+ from_obj=[(t1.join(t2).outerjoin(t3, criteria))])
+ self.assertRows(expr, [(10, 20, 30)])
+
+ expr = select(
+ [t1.c.t1_id, t2.c.t2_id, t3.c.t3_id],
+ and_(t1.c.name == 't1 #10',
+ t2.c.name == 't2 #20',
+ t3.c.name == 't3 #30'),
+ from_obj=[(t1.join(t2).outerjoin(t3, criteria))])
+ self.assertRows(expr, [(10, 20, 30)])
+
+
+class OperatorTest(TestBase):
+ @classmethod
+ def setup_class(cls):
+ global metadata, flds
+ metadata = MetaData(testing.db)
+ flds = Table('flds', metadata,
+ Column('idcol', Integer, Sequence('t1pkseq'), primary_key=True),
+ Column('intcol', Integer),
+ Column('strcol', String(50)),
+ )
+ metadata.create_all()
+
+ flds.insert().execute([
+ dict(intcol=5, strcol='foo'),
+ dict(intcol=13, strcol='bar')
+ ])
+
+ @classmethod
+ def teardown_class(cls):
+ metadata.drop_all()
+
+ @testing.fails_on('maxdb', 'FIXME: unknown')
+ def test_modulo(self):
+ self.assertEquals(
+ select([flds.c.intcol % 3],
+ order_by=flds.c.idcol).execute().fetchall(),
+ [(2,),(1,)]
+ )
</ins></span></pre></div>
<a id="sqlalchemybranchesnoseteststesttestenvpy"></a>
<div class="delfile"><h4>Deleted: sqlalchemy/branches/nosetests/test/testenv.py (6026 => 6027)</h4>
<pre class="diff"><span>
<span class="info">--- sqlalchemy/branches/nosetests/test/testenv.py 2009-06-07 22:08:09 UTC (rev 6026)
+++ sqlalchemy/branches/nosetests/test/testenv.py 2009-06-07 22:09:07 UTC (rev 6027)
</span><span class="lines">@@ -1,36 +0,0 @@
</span><del>-"""First import for all test cases, sets sys.path and loads configuration."""
-
-import sys, os, logging, warnings
-
-# On interpreters older than Python 2.4, ignore all FutureWarnings.
-if sys.version_info < (2, 4):
- warnings.filterwarnings('ignore', category=FutureWarning)
-
-
-from testlib.testing import main
-import testlib.config
-
-
-# Module-level "already configured" flag. configure_for_tests() and
-# simple_setup() each check it and set it to True after their first run,
-# so the sys.path / logging / config setup happens at most once.
-_setup = False
-
-def configure_for_tests():
- """import testenv; testenv.configure_for_tests()"""
-
- global _setup
- if not _setup:
- sys.path.insert(0, os.path.join(os.getcwd(), 'lib'))
- logging.basicConfig()
-
- testlib.config.configure()
- _setup = True
-
-def simple_setup():
- """import testenv; testenv.simple_setup()"""
-
- global _setup
- if not _setup:
- sys.path.insert(0, os.path.join(os.getcwd(), 'lib'))
- logging.basicConfig()
-
- testlib.config.configure_defaults()
- _setup = True
-
</del></span></pre>
</div>
</div>
</body>
</html>
|