Index: sqlobject/main.py =================================================================== --- sqlobject/main.py (revision 126) +++ sqlobject/main.py (working copy) @@ -344,7 +344,7 @@ get = classmethod(get) - def addColumn(cls, columnDef, changeSchema=False): + def addColumn(cls, columnDef, changeSchema=False, connection=None): column = columnDef.withClass(cls) name = column.name assert name != 'id', "The 'id' column is implicit, and should not be defined as a column" @@ -449,15 +449,15 @@ setattr(cls, column.alternateMethodName, classmethod(func)) if changeSchema: - cls._connection.addColumn(cls._table, column) + (connection or cls._connection).addColumn(cls._table, column) if cls._SO_finishedClassCreation: makeProperties(cls) addColumn = classmethod(addColumn) - def addColumnsFromDatabase(cls): - for columnDef in cls._connection.columnsFromSchema(cls._table, cls): + def addColumnsFromDatabase(cls, connection=None): + for columnDef in (connection or cls._connection).columnsFromSchema(cls._table, cls): alreadyExists = False for c in cls._columns: if c.kw['name'] == columnDef.kw['name']: @@ -468,7 +468,7 @@ addColumnsFromDatabase = classmethod(addColumnsFromDatabase) - def delColumn(cls, column, changeSchema=False): + def delColumn(cls, column, changeSchema=False, connection=None): if isinstance(column, str): column = cls._SO_columnDict[column] if isinstance(column, col.Col): @@ -497,7 +497,7 @@ delattr(cls, column.alternateMethodName) if changeSchema: - cls._connection.delColumn(cls._table, column) + (connection or cls._connection).delColumn(cls._table, column) if cls._SO_finishedClassCreation: unmakeProperties(cls) @@ -933,29 +933,43 @@ def selectBy(cls, connection=None, **kw): return SelectResults(cls, - cls._connection._SO_columnClause(cls, kw), + (connection or cls._connection)._SO_columnClause(cls, kw), connection=connection) selectBy = classmethod(selectBy) - # 3-03 @@: Should these have a connection argument? - def dropTable(cls, ifExists=False, dropJoinTables=True, cascade=False): - if ifExists and not cls._connection.tableExists(cls._table): + def dropTable(cls, ifExists=False, dropJoinTables=True, cascade=False, connection=None): + if connection: + conn = connection + else: + conn = cls._connection + + if ifExists and not conn.tableExists(cls._table): return - cls._connection.dropTable(cls._table, cascade) + conn.dropTable(cls._table, cascade) if dropJoinTables: - cls.dropJoinTables(ifExists=ifExists) + cls.dropJoinTables(ifExists=ifExists, connection=connection) dropTable = classmethod(dropTable) - def createTable(cls, ifNotExists=False, createJoinTables=True): - if ifNotExists and cls._connection.tableExists(cls._table): + def createTable(cls, ifNotExists=False, createJoinTables=True, connection=None): + if connection: + conn = connection + else: + conn = cls._connection + + if ifNotExists and conn.tableExists(cls._table): return - cls._connection.createTable(cls) + conn.createTable(cls) if createJoinTables: - cls.createJoinTables(ifNotExists=ifNotExists) + cls.createJoinTables(ifNotExists=ifNotExists, connection = connection) createTable = classmethod(createTable) - def createJoinTables(cls, ifNotExists=False): + def createJoinTables(cls, ifNotExists=False, connection=None): + if connection: + conn = connection + else: + conn = cls._connection + for join in cls._SO_joinList: if not join: continue @@ -968,13 +982,18 @@ if join.soClass.__name__ > join.otherClass.__name__: continue if ifNotExists and \ - cls._connection.tableExists(join.intermediateTable): + conn.tableExists(join.intermediateTable): continue - cls._connection._SO_createJoinTable(join) + conn._SO_createJoinTable(join) createJoinTables = classmethod(createJoinTables) - def dropJoinTables(cls, ifExists=False): + def dropJoinTables(cls, ifExists=False, connection=None): + if connection: + conn = connection + else: + conn = cls._connection + for join in cls._SO_joinList: if not join: continue @@ -983,16 +1002,21 @@ if join.soClass.__name__ > join.otherClass.__name__: continue if ifExists and \ - not cls._connection.tableExists(join.intermediateTable): + not conn.tableExists(join.intermediateTable): continue - cls._connection._SO_dropJoinTable(join) + conn._SO_dropJoinTable(join) dropJoinTables = classmethod(dropJoinTables) - def clearTable(cls): + def clearTable(cls, connection=None): + if connection: + conn = connection + else: + conn = cls._connection + # 3-03 @@: Maybe this should check the cache... but it's # kind of crude anyway, so... - cls._connection.clearTable(cls._table) + conn.clearTable(cls._table) clearTable = classmethod(clearTable) def destroySelf(self): @@ -1036,8 +1060,8 @@ self.id, ' '.join(['%s=%s' % (name, repr(value)) for name, value in self._reprItems()])) - def sqlrepr(cls, value): - return cls._connection.sqlrepr(value) + def sqlrepr(cls, value, connection=None): + return (connection or cls._connection).sqlrepr(value) sqlrepr = classmethod(sqlrepr)