[SQL-CVS] r807 - in trunk/SQLObject/sqlobject: . tests
SQLObject is a Python ORM.
Brought to you by: ianbicking, phd
From: <sub...@co...> - 2005-05-26 22:22:47
|
Author: ianb Date: 2005-05-26 22:22:34 +0000 (Thu, 26 May 2005) New Revision: 807 Modified: trunk/SQLObject/sqlobject/col.py trunk/SQLObject/sqlobject/main.py trunk/SQLObject/sqlobject/tests/test_basic.py Log: Fixed cascade='null', which was cascading deletes, instead of setting columns to null Modified: trunk/SQLObject/sqlobject/col.py =================================================================== --- trunk/SQLObject/sqlobject/col.py 2005-05-26 13:59:05 UTC (rev 806) +++ trunk/SQLObject/sqlobject/col.py 2005-05-26 22:22:34 UTC (rev 807) @@ -122,6 +122,10 @@ # None: no constraint is generated # True: a CASCADE constraint is generated # False: a RESTRICT constraint is generated + # 'null': a SET NULL trigger is generated + if isinstance(cascade, str): + assert cascade == 'null', ( + "The only string value allowed for cascade is 'null' (you gave: %r)" % cascade) self.cascade = cascade if type(constraints) not in (type([]), type(())): Modified: trunk/SQLObject/sqlobject/main.py =================================================================== --- trunk/SQLObject/sqlobject/main.py 2005-05-26 13:59:05 UTC (rev 806) +++ trunk/SQLObject/sqlobject/main.py 2005-05-26 22:22:34 UTC (rev 807) @@ -1205,19 +1205,22 @@ connection=conn) createTable = classmethod(createTable) - def createTableSQL(cls, createJoinTables=True, connection=None): + def createTableSQL(cls, createJoinTables=True, connection=None, + createIndexes=True): conn = connection or cls._connection sql = conn.createTableSQL(cls) if createJoinTables: sql += '\n' + cls.createJoinTablesSQL(connection=conn) + if createIndexes: + sql += '\n' + cls.createIndexesSQL(connection=conn) return sql createTableSQL = classmethod(createTableSQL) def createJoinTables(cls, ifNotExists=False, connection=None): conn = connection or cls._connection for join in cls._getJoinsToCreate(): - if ifNotExists and \ - conn.tableExists(join.intermediateTable): + if (ifNotExists and + conn.tableExists(join.intermediateTable)): continue 
conn._SO_createJoinTable(join) createJoinTables = classmethod(createJoinTables) @@ -1238,6 +1241,16 @@ conn._SO_createIndex(cls, index) createIndexes = classmethod(createIndexes) + def createIndexesSQL(cls, connection=None): + conn = connection or cls._connection + sql = [] + for index in cls.sqlmeta._indexList: + if not index: + continue + sql.append(conn.createIndexSQL(cls, index)) + return '\n'.join(sql) + createIndexesSQL = classmethod(createIndexesSQL) + def _getJoinsToCreate(cls): joins = [] for join in cls.sqlmeta._joinList: @@ -1282,23 +1295,37 @@ for k in depends: cols = findDependantColumns(klass.__name__, k) query = [] - restrict = False + delete = setnull = restrict = False for col in cols: if col.cascade == False: # Found a restriction restrict = True query.append("%s = %s" % (col.dbName, self.id)) + if col.cascade == 'null': + setnull = col.name + elif col.cascade: + delete = True + assert delete or setnull or restrict, ( + "Class %s depends on %s accoriding to " + "findDependantColumns, but this seems inaccurate" + % (k, klass)) query = ' OR '.join(query) results = k.select(query, connection=self._connection) - if restrict and results.count(): - # Restrictions only apply if there are - # matching records on the related table - raise SQLObjectIntegrityError, ( - "Tried to delete %s::%s but " - "table %s has a restriction against it" % - (klass.__name__, self.id, k.__name__)) - for row in results: - row.destroySelf() + if restrict: + if results.count(): + # Restrictions only apply if there are + # matching records on the related table + raise SQLObjectIntegrityError, ( + "Tried to delete %s::%s but " + "table %s has a restriction against it" % + (klass.__name__, self.id, k.__name__)) + else: + for row in results: + if delete: + row.destroySelf() + else: + row.set(**{setnull: None}) + self.sqlmeta._obsolete = True self._connection._SO_delete(self) self._connection.cache.expire(self.id, self.__class__) Modified: 
trunk/SQLObject/sqlobject/tests/test_basic.py =================================================================== --- trunk/SQLObject/sqlobject/tests/test_basic.py 2005-05-26 13:59:05 UTC (rev 806) +++ trunk/SQLObject/sqlobject/tests/test_basic.py 2005-05-26 22:22:34 UTC (rev 807) @@ -120,7 +120,6 @@ assert TestSO3.selectBy(otherID=tc4.id)[0] == tc3 assert list(TestSO3.selectBy(otherID=tc4.id)[:10]) == [tc3] assert list(TestSO3.selectBy(other=tc4)[:10]) == [tc3] - assert 0 class TestSO5(SQLObject): name = StringCol(length=10, dbName='name_col') @@ -242,3 +241,27 @@ tc9a.destroySelf() assert TestSO8.select().count() == 0 assert TestSO9.select().count() == 0 + +class TestSO10(SQLObject): + name = StringCol() + +class TestSO11(SQLObject): + name = StringCol() + other = ForeignKey('TestSO10', default=None, cascade='null') + +def testForeignKeySetNull(): + setupClass([TestSO10, TestSO11]) + obj1 = TestSO10(name='foo') + obj2 = TestSO10(name='bar') + dep1 = TestSO11(name='xxx', other=obj1) + dep2 = TestSO11(name='yyy', other=obj1) + dep3 = TestSO11(name='zzz', other=obj2) + for name in 'xxx', 'yyy', 'zzz': + assert len(list(TestSO11.selectBy(name=name))) == 1 + obj1.destroySelf() + for name in 'xxx', 'yyy', 'zzz': + assert len(list(TestSO11.selectBy(name=name))) == 1 + assert dep1.other is None + assert dep2.other is None + assert dep3.other is obj2 + |