[SQL-CVS] r689 - in trunk/SQLObject/sqlobject: . tests
SQLObject is a Python ORM.
Brought to you by:
ianbicking,
phd
From: <sub...@co...> - 2005-03-27 19:37:40
|
Author: ianb Date: 2005-03-27 19:37:29 +0000 (Sun, 27 Mar 2005) New Revision: 689 Added: trunk/SQLObject/sqlobject/tests/test_conngetter.py Modified: trunk/SQLObject/sqlobject/classregistry.py trunk/SQLObject/sqlobject/dbconnection.py trunk/SQLObject/sqlobject/tests/dbtest.py Log: * Allow keyword arguments to connectionForURI (and pass through from dbtest.getConnection) * Add callback to registry that is called for all classes * Wrappers added to dbconnection instances that create a pseudo-class that is bound to a connection Modified: trunk/SQLObject/sqlobject/classregistry.py =================================================================== --- trunk/SQLObject/sqlobject/classregistry.py 2005-03-25 17:39:04 UTC (rev 688) +++ trunk/SQLObject/sqlobject/classregistry.py 2005-03-27 19:37:29 UTC (rev 689) @@ -46,6 +46,7 @@ self.name = name self.classes = {} self.callbacks = {} + self.genericCallbacks = [] def addClassCallback(self, className, callback, *args, **kw): """ @@ -59,6 +60,15 @@ else: self.callbacks.setdefault(className, []).append((callback, args, kw)) + def addCallback(self, callback, *args, **kw): + """ + This callback is called for all classes, not just specific + ones (like addClassCallback). 
+ """ + self.genericCallbacks.append((callback, args, kw)) + for cls in self.classes.values(): + callback(cls, *args, **kw) + def addClass(self, cls): """ Everytime a class is created, we add it to the registry, so @@ -84,6 +94,8 @@ for callback, args, kw in self.callbacks[cls.__name__]: callback(cls, *args, **kw) del self.callbacks[cls.__name__] + for callback, args, kw in self.genericCallbacks: + callback(cls, *args, **kw) def getClass(self, className): return self.classes[className] Modified: trunk/SQLObject/sqlobject/dbconnection.py =================================================================== --- trunk/SQLObject/sqlobject/dbconnection.py 2005-03-25 17:39:04 UTC (rev 688) +++ trunk/SQLObject/sqlobject/dbconnection.py 2005-03-27 19:37:29 UTC (rev 689) @@ -8,14 +8,16 @@ import atexit import os import new +import types +import urllib +import weakref +import inspect import sqlbuilder from cache import CacheSet import col from joins import sorter from converters import sqlrepr -import urllib -import weakref -from classregistry import findClass +import classregistry warnings.filterwarnings("ignore", "DB-API extension cursor.lastrowid used") @@ -30,7 +32,7 @@ def __init__(self, name=None, debug=False, debugOutput=False, cache=True, style=None, autoCommit=True, - debugThreading=False): + debugThreading=False, registry=None): self.name = name self.debug = debug self.debugOutput = debugOutput @@ -41,6 +43,9 @@ self._connectionNumbers = {} self._connectionCount = 1 self.autoCommit = autoCommit + self.registry = registry or None + classregistry.registry(self.registry).addCallback( + self.soClassAdded) registerConnectionInstance(self) atexit.register(_closeConnection, weakref.ref(self)) @@ -119,6 +124,75 @@ return user, password, host, port, path, args _parseURI = staticmethod(_parseURI) + def soClassAdded(self, soClass): + """ + This is called for each new class; we use this opportunity + to create an instance method that is bound to the class + and this connection. 
+ """ + name = soClass.__name__ + assert not hasattr(self, name), ( + "Connection %r already has an attribute with the name " + "%r (and you just created the conflicting class %r)" + % (self, name, soClass)) + setattr(self, name, ConnWrapper(soClass, self)) + +class ConnWrapper(object): + + """ + This represents a SQLObject class that is bound to a specific + connection (instances have a connection instance variable, but + classes are global, so this is binds the connection variable + lazily when a class method is accessed) + """ + # @@: methods that take connection arguments should be explicitly + # marked up instead of the implicit use of a connection argument + # and inspect.getargspec() + + def __init__(self, soClass, connection): + self._soClass = soClass + self._connection = connection + + def __call__(self, *args, **kw): + kw['connection'] = self._connection + return self._soClass(*args, **kw) + + def __getattr__(self, attr): + meth = getattr(self._soClass, attr) + if not isinstance(meth, types.MethodType): + # We don't need to wrap non-methods + return meth + try: + takes_conn = meth.takes_connection + except AttributeError: + args, varargs, varkw, defaults = inspect.getargspec(meth) + assert not varkw and not varargs, ( + "I cannot tell whether I must wrap this method, " + "because it takes **kw: %r" + % meth) + takes_conn = 'connection' in args + meth.im_func.takes_connection = takes_conn + if not takes_conn: + return meth + return ConnMethodWrapper(meth, self._connection) + +class ConnMethodWrapper(object): + + def __init__(self, method, connection): + self._method = method + self._connection = connection + + def __getattr__(self, attr): + return getattr(self._method, attr) + + def __call__(self, *args, **kw): + kw['connection'] = self._connection + return self._method(*args, **kw) + + def __repr__(self): + return '<Wrapped %r with connection %r>' % ( + self._method, self._connection) + class DBAPI(DBConnection): """ @@ -627,7 +701,10 @@ try: func = 
attr.im_func except AttributeError: - return attr + if isinstance(attr, ConnWrapper): + return ConnWrapper(attr._soClass, self) + else: + return attr else: meth = new.instancemethod(func, self, self.__class__) return meth @@ -674,13 +751,18 @@ assert inst.name.find(':') == -1, "You cannot include ':' in your class names (%r)" % cls.name self.instanceNames[inst.name] = inst - def connectionForURI(self, uri): + def connectionForURI(self, uri, **args): + if args: + if '?' not in uri: + uri += '?' + uri += urllib.urlencode(args) if self.cachedURIs.has_key(uri): return self.cachedURIs[uri] if uri.find(':') != -1: scheme, rest = uri.split(':', 1) - assert self.schemeBuilders.has_key(scheme), \ - "No SQLObject driver exists for %s" % scheme + assert self.schemeBuilders.has_key(scheme), ( + "No SQLObject driver exists for %s (only %s)" + % (scheme, ', '.join(self.schemeBuilders.keys()))) conn = self.schemeBuilders[scheme]().connectionFromURI(uri) else: # We just have a name, not a URI Modified: trunk/SQLObject/sqlobject/tests/dbtest.py =================================================================== --- trunk/SQLObject/sqlobject/tests/dbtest.py 2005-03-25 17:39:04 UTC (rev 688) +++ trunk/SQLObject/sqlobject/tests/dbtest.py 2005-03-27 19:37:29 UTC (rev 689) @@ -69,12 +69,12 @@ installedDBTracker = sqlobject.connectionForURI( 'sqlite:///' + installedDBFilename) -def getConnection(): +def getConnection(**kw): name = os.environ.get('TESTDB') assert name, 'You must set $TESTDB to do database operations' if connectionShortcuts.has_key(name): name = connectionShortcuts[name] - return sqlobject.connectionForURI(name) + return sqlobject.connectionForURI(name, **kw) connection = getConnection() Added: trunk/SQLObject/sqlobject/tests/test_conngetter.py =================================================================== --- trunk/SQLObject/sqlobject/tests/test_conngetter.py 2005-03-25 17:39:04 UTC (rev 688) +++ trunk/SQLObject/sqlobject/tests/test_conngetter.py 2005-03-27 19:37:29 
UTC (rev 689) @@ -0,0 +1,37 @@ +from sqlobject import * +from sqlobject.tests.dbtest import * + +class TestSimple(SQLObject): + + class sqlmeta: + registry = 'conngetter' + + name = StringCol(alternateID=True) + +class TestJoined(SQLObject): + + class sqlmeta: + registry = 'conngetter' + + this_name = StringCol(alternateID=True) + simple = ForeignKey('TestSimple') + +def test_autogetter(): + conn = getConnection(registry='conngetter') + TestJoined.dropTable(connection=conn, ifExists=True) + TestSimple.dropTable(connection=conn, ifExists=True) + TestSimple.createTable(connection=conn, ifNotExists=True) + TestJoined.createTable(connection=conn, ifNotExists=True) + assert conn.TestSimple.__sqlobject_class__ is TestSimple + obj = conn.TestSimple(name='test') + assert (TestSimple.get(obj.id, connection=conn) is obj) + assert obj._connection is conn + obj2 = TestSimple(name='test2', connection=conn) + assert (conn.TestSimple.byName('test2') is obj2) + joined = conn.TestJoined(this_name='join_test', simple=obj) + assert joined.simple is obj + assert joined.simple._connection is conn + assert joined._connection is conn + for item in conn.TestSimple.select(): + assert item._connection is conn + |