? SQLObject/include/Validator.pyc
? SQLObject/include/__init__.pyc
Index: SQLObject/Col.py
===================================================================
RCS file: /cvsroot/sqlobject/SQLObject/SQLObject/Col.py,v
retrieving revision 1.32
diff -u -r1.32 Col.py
--- SQLObject/Col.py	12 Nov 2003 17:04:55 -0000	1.32
+++ SQLObject/Col.py	3 Dec 2003 20:54:54 -0000
@@ -33,7 +33,8 @@
                  sqlType=None,
                  columnDef=None,
                  validator=None,
-                 immutable=False):
+                 immutable=False,
+                 cascade=None):
 
         # This isn't strictly true, since we *could* use backquotes or
         # " or something (database-specific) around column names, but
@@ -49,6 +50,14 @@
 
         self.immutable = immutable
 
+        # cascade (only meaningful for foreign keys) can be one of:
+        # None: no constraint is generated
+        # True: a CASCADE constraint is generated, and destroySelf()
+        #       also destroys the dependent rows itself
+        # False: a RESTRICT constraint is generated
+        # The constraint is currently only generated for PostgreSQL.
+        self.cascade = cascade
+
         if type(constraints) not in (type([]), type(())):
             constraints = [constraints]
         self.constraints = self.autoConstraints() + constraints
@@ -342,6 +351,29 @@
         if not kw['name'].upper().endswith('ID'):
             kw['name'] = style.instanceAttrToIDAttr(kw['name'])
         SOKeyCol.__init__(self, **kw)
+
+    def postgresCreateSQL(self):
+        """Return the column SQL plus, when cascade is not None, a
+        FOREIGN KEY constraint (ON DELETE CASCADE or RESTRICT)."""
+        from SQLObject import findClass
+        sql = SOKeyCol.postgresCreateSQL(self)
+        if self.cascade is not None:
+            other = findClass(self.foreignKey)
+            tName = other._table
+            idName = other._idName
+            action = self.cascade and 'CASCADE' or 'RESTRICT'
+            # Name the constraint after this column rather than the
+            # referenced table: two keys to one table must not clash.
+            constraint = ('CONSTRAINT %(colName)s_exists '
+                          'FOREIGN KEY(%(colName)s) '
+                          'REFERENCES %(tName)s(%(idName)s) '
+                          'ON DELETE %(action)s' %
+                          {'tName':tName,
+                           'colName':self.dbName,
+                           'idName':idName,
+                           'action':action})
+            sql = ', '.join([sql, constraint])
+        return sql
 
 class ForeignKey(KeyCol):
 
Index: SQLObject/DBConnection.py
===================================================================
RCS file: /cvsroot/sqlobject/SQLObject/SQLObject/DBConnection.py,v
retrieving revision 1.55
diff -u -r1.55 DBConnection.py
--- SQLObject/DBConnection.py	12 Nov 2003 17:06:34 -0000	1.55
+++ SQLObject/DBConnection.py	3 Dec 2003 20:54:54 -0000
@@ -278,7 +278,7 @@
     def createColumn(self, soClass, col):
         assert 0, "Implement in subclasses"
 
-    def dropTable(self, tableName):
+    def dropTable(self, tableName, cascade=False):
         self.query("DROP TABLE %s" % tableName)
 
     def clearTable(self, tableName):
@@ -605,6 +605,10 @@
     def createIDColumn(self, soClass):
         return '%s SERIAL PRIMARY KEY' % soClass._idName
 
+    def dropTable(self, tableName, cascade=False):
+        self.query("DROP TABLE %s %s" % (tableName,
+                                         cascade and 'CASCADE' or ''))
+
     def joinSQLType(self, join):
         return 'INT NOT NULL'
 
@@ -859,7 +863,7 @@
                    (tableName,
                     column.firebirdCreateSQL()))
 
-    def dropTable(self, tableName):
+    def dropTable(self, tableName, cascade=False):
         self.query("DROP TABLE %s" % tableName)
         self.query("DROP GENERATOR GEN_%s" % tableName)
 
@@ -1091,7 +1095,7 @@
     def createTable(self, soClass):
         self._meta["%s.id" % soClass._table] = "1"
 
-    def dropTable(self, tableName):
+    def dropTable(self, tableName, cascade=False):
         try:
             del self._meta["%s.id" % tableName]
         except KeyError:
Index: SQLObject/SQLObject.py
===================================================================
RCS file: /cvsroot/sqlobject/SQLObject/SQLObject/SQLObject.py,v
retrieving revision 1.64
diff -u -r1.64 SQLObject.py
--- SQLObject/SQLObject.py	12 Nov 2003 17:06:00 -0000	1.64
+++ SQLObject/SQLObject.py	3 Dec 2003 20:54:55 -0000
@@ -311,6 +311,17 @@
     #assert classRegistry.get(registry, {}).has_key(name), "No class by the name %s found (I have %s)" % (repr(name), ', '.join(map(str, classRegistry.keys())))
     return classRegistry[registry][name]
 
+def findDependencies(name, registry=None):
+    """Return the classes in `registry` that have at least one
+    cascading foreign key to the class named `name`."""
+    return [klass for klass in classRegistry[registry].values()
+            if findDependantColumns(name, klass)]
+
+def findDependantColumns(name, klass):
+    """Return the columns of `klass` that are foreign keys to the class
+    named `name` and whose cascade flag is true."""
+    return [col for col in klass._SO_columns
+            if col.foreignKey == name and col.cascade]
 
 class CreateNewSQLObject:
     """
@@ -932,6 +943,11 @@
         return obj
     _SO_fetchAlternateID = classmethod(_SO_fetchAlternateID)
 
+    def _SO_depends(cls):
+        """Return the classes with cascading foreign keys to this class."""
+        return findDependencies(cls.__name__, cls._registry)
+    _SO_depends = classmethod(_SO_depends)
+
     def select(cls, clause=None, clauseTables=None,
                orderBy=NoDefault, limit=None,
                lazyColumns=False, reversed=False,
@@ -951,10 +967,12 @@
     selectBy = classmethod(selectBy)
 
     # 3-03 @@: Should these have a connection argument?
-    def dropTable(cls, ifExists=False, dropJoinTables=True):
+    def dropTable(cls, ifExists=False, dropJoinTables=True, cascade=False):
         if ifExists and not cls._connection.tableExists(cls._table):
             return
-        cls._connection.dropTable(cls._table)
+        # cascade is passed through to the connection: the PostgreSQL
+        # connection issues DROP TABLE ... CASCADE, the base one ignores it.
+        cls._connection.dropTable(cls._table, cascade)
         if dropJoinTables:
             cls.dropJoinTables(ifExists=ifExists)
     dropTable = classmethod(dropTable)
@@ -1009,6 +1027,18 @@
 
     def destroySelf(self):
         # Kills this object.  Kills it dead!
+        # First destroy every row that references this one through a
+        # cascading foreign key.  This runs on every backend, so the
+        # cascade works even where no ON DELETE CASCADE constraint exists.
+        # NOTE(review): cyclic cascading references (including a row
+        # that references itself) recurse without bound here.
+        className = self.__class__.__name__
+        for k in self._SO_depends():
+            cols = findDependantColumns(className, k)
+            query = ' OR '.join(["%s = %s" % (col.dbName, self.id)
+                                 for col in cols])
+            for row in k.select(query):
+                row.destroySelf()
         self._SO_obsolete = True
         self._connection._SO_delete(self)
         self._connection.cache.expire(self.id, self.__class__)
Index: tests/SQLObjectTest.py
===================================================================
RCS file: /cvsroot/sqlobject/SQLObject/tests/SQLObjectTest.py,v
retrieving revision 1.18
diff -u -r1.18 SQLObjectTest.py
--- tests/SQLObjectTest.py	1 Oct 2003 01:53:48 -0000	1.18
+++ tests/SQLObjectTest.py	3 Dec 2003 20:54:55 -0000
@@ -103,8 +103,8 @@
             elif hasattr(c, 'drop'):
                 __connection__.query(c.drop)
             elif hasattr(c, 'dropTable'):
-                c.dropTable(ifExists=True)
+                c.dropTable(ifExists=True, cascade=True)
 
             if hasattr(c, '%sCreate' % self.databaseName):
                 if not __connection__.tableExists(c._table):
                     sql = getattr(c, '%sCreate' % self.databaseName)
Index: tests/test.py
===================================================================
RCS file: /cvsroot/sqlobject/SQLObject/tests/test.py,v
retrieving revision 1.34
diff -u -r1.34 test.py
--- tests/test.py	4 Nov 2003 02:28:35 -0000	1.34
+++ tests/test.py	3 Dec 2003 20:54:55 -0000
@@ -119,6 +119,126 @@
         tcc2 = TestSO3.new(name='c', other=tc4a.id)
         self.assertEqual(tcc2.other, tc4a)
 
+class TestSO5(SQLObject):
+    name = StringCol(length=10, dbName='name_col')
+    other = ForeignKey('TestSO6', default=None, cascade=True)
+    another = ForeignKey('TestSO7', default=None, cascade=True)
+
+class TestSO6(SQLObject):
+    name = StringCol(length=10, dbName='name_col')
+    other = ForeignKey('TestSO7', default=None, cascade=True)
+
+class TestSO7(SQLObject):
+    name = StringCol(length=10, dbName='name_col')
+
+class TestCase567(SQLObjectTest):
+
+    classes = [TestSO7, TestSO6, TestSO5]
+
+    def testForeignKeyDestroySelfCascade(self):
+        tc5 = TestSO5.new(name='a')
+        tc6a = TestSO6.new(name='1')
+        tc5.other = tc6a
+        tc7a = TestSO7.new(name='2')
+        tc6a.other = tc7a
+        tc5.another = tc7a
+        self.assertEqual(tc5.other, tc6a)
+        self.assertEqual(tc5.otherID, tc6a.id)
+        self.assertEqual(tc6a.other, tc7a)
+        self.assertEqual(tc6a.otherID, tc7a.id)
+        self.assertEqual(tc5.other.other, tc7a)
+        self.assertEqual(tc5.other.otherID, tc7a.id)
+        self.assertEqual(tc5.another, tc7a)
+        self.assertEqual(tc5.anotherID, tc7a.id)
+        self.assertEqual(tc5.other.other, tc5.another)
+        self.assertEqual(TestSO5.select().count(), 1)
+        self.assertEqual(TestSO6.select().count(), 1)
+        self.assertEqual(TestSO7.select().count(), 1)
+        tc6b = TestSO6.new(name='3')
+        tc6c = TestSO6.new(name='4')
+        tc7b = TestSO7.new(name='5')
+        tc6b.other = tc7b
+        tc6c.other = tc7b
+        self.assertEqual(TestSO5.select().count(), 1)
+        self.assertEqual(TestSO6.select().count(), 3)
+        self.assertEqual(TestSO7.select().count(), 2)
+        tc6b.destroySelf()
+        self.assertEqual(TestSO5.select().count(), 1)
+        self.assertEqual(TestSO6.select().count(), 2)
+        self.assertEqual(TestSO7.select().count(), 2)
+        tc7b.destroySelf()
+        self.assertEqual(TestSO5.select().count(), 1)
+        self.assertEqual(TestSO6.select().count(), 1)
+        self.assertEqual(TestSO7.select().count(), 1)
+        tc7a.destroySelf()
+        self.assertEqual(TestSO5.select().count(), 0)
+        self.assertEqual(TestSO6.select().count(), 0)
+        self.assertEqual(TestSO7.select().count(), 0)
+
+    def testForeignKeyDropTableCascade(self):
+        tc5a = TestSO5.new(name='a')
+        tc6a = TestSO6.new(name='1')
+        tc5a.other = tc6a
+        tc7a = TestSO7.new(name='2')
+        tc6a.other = tc7a
+        tc5a.another = tc7a
+        tc5b = TestSO5.new(name='b')
+        tc5c = TestSO5.new(name='c')
+        tc6b = TestSO6.new(name='3')
+        tc5c.other = tc6b
+        self.assertEqual(TestSO5.select().count(), 3)
+        self.assertEqual(TestSO6.select().count(), 2)
+        self.assertEqual(TestSO7.select().count(), 1)
+        TestSO7.dropTable(cascade=True)
+        self.assertEqual(TestSO5.select().count(), 3)
+        self.assertEqual(TestSO6.select().count(), 2)
+        tc6a.destroySelf()
+        self.assertEqual(TestSO5.select().count(), 2)
+        self.assertEqual(TestSO6.select().count(), 1)
+        tc6b.destroySelf()
+        self.assertEqual(TestSO5.select().count(), 1)
+        self.assertEqual(TestSO6.select().count(), 0)
+        self.assertEqual(iter(TestSO5.select()).next(), tc5b)
+        tc6c = TestSO6.new(name='3')
+        tc5b.other = tc6c
+        self.assertEqual(TestSO5.select().count(), 1)
+        self.assertEqual(TestSO6.select().count(), 1)
+        tc6c.destroySelf()
+        self.assertEqual(TestSO5.select().count(), 0)
+        self.assertEqual(TestSO6.select().count(), 0)
+
+class TestSO8(SQLObject):
+    name = StringCol(length=10, dbName='name_col')
+    other = ForeignKey('TestSO9', default=None, cascade=False)
+
+class TestSO9(SQLObject):
+    name = StringCol(length=10, dbName='name_col')
+
+class TestCase89(SQLObjectTest):
+
+    classes = [TestSO9, TestSO8]
+
+    def testForeignKeyDestroySelfRestrict(self):
+        tc8a = TestSO8.new(name='a')
+        tc9a = TestSO9.new(name='1')
+        tc8a.other = tc9a
+        tc8b = TestSO8.new(name='b')
+        tc9b = TestSO9.new(name='2')
+        self.assertEqual(tc8a.other, tc9a)
+        self.assertEqual(tc8a.otherID, tc9a.id)
+        self.assertEqual(TestSO8.select().count(), 2)
+        self.assertEqual(TestSO9.select().count(), 2)
+        if self.databaseName == 'postgres':
+            # Only PostgreSQL gets the RESTRICT constraint.
+            self.assertRaises(Exception, tc9a.destroySelf)
+        tc9b.destroySelf()
+        self.assertEqual(TestSO8.select().count(), 2)
+        self.assertEqual(TestSO9.select().count(), 1)
+        tc8a.destroySelf()
+        tc8b.destroySelf()
+        tc9a.destroySelf()
+        self.assertEqual(TestSO8.select().count(), 0)
+        self.assertEqual(TestSO9.select().count(), 0)
 
 ########################################
 ## Fancy sort