[SQL-CVS] r3 - in trunk/SQLObject: sqlobject tests
SQLObject is a Python ORM.
Brought to you by:
ianbicking,
phd
From: <sub...@co...> - 2004-02-06 06:17:06
|
Author: ianb Date: Thu Feb 5 21:09:52 2004 New Revision: 3 Modified: trunk/SQLObject/sqlobject/dbconnection.py trunk/SQLObject/sqlobject/joins.py trunk/SQLObject/sqlobject/main.py trunk/SQLObject/tests/ (props changed) trunk/SQLObject/tests/test.py Log: Renamed __new__ to get, renamed new to __init__ (in other words, class instantiation creates a row, while the get() class method fetches an object) Modified: trunk/SQLObject/sqlobject/dbconnection.py ============================================================================== --- trunk/SQLObject/sqlobject/dbconnection.py (original) +++ trunk/SQLObject/sqlobject/dbconnection.py Thu Feb 5 21:09:52 2004 @@ -176,10 +176,10 @@ self.releaseConnection(conn) break if select.ops.get('lazyColumns', 0): - obj = select.sourceClass(result[0], connection=withConnection) + obj = select.sourceClass.get(result[0], connection=withConnection) yield obj else: - obj = select.sourceClass(result[0], selectResults=result[1:], connection=withConnection) + obj = select.sourceClass.get(result[0], selectResults=result[1:], connection=withConnection) yield obj def iterSelect(self, select): @@ -1358,7 +1358,7 @@ if not self._maxNext: raise StopIteration self._maxNext -= 1 - return self.select.sourceClass(int(idList[self.tableDict[self.select.sourceClass._table]])) + return self.select.sourceClass.get(int(idList[self.tableDict[self.select.sourceClass._table]])) raise StopIteration def field(self, table, field): Modified: trunk/SQLObject/sqlobject/joins.py ============================================================================== --- trunk/SQLObject/sqlobject/joins.py (original) +++ trunk/SQLObject/sqlobject/joins.py Thu Feb 5 21:09:52 2004 @@ -115,7 +115,7 @@ conn = inst._connection else: conn = None - return self._applyOrderBy([self.otherClass(id, conn) for (id,) in ids if id is not None], self.otherClass) + return self._applyOrderBy([self.otherClass.get(id, conn) for (id,) in ids if id is not None], self.otherClass) class MultipleJoin(Join): baseClass = SOMultipleJoin @@ -155,7 +155,7 @@ conn = inst._connection else: conn = None - return self._applyOrderBy([self.otherClass(id, conn) for (id,) in ids if id is not None], self.otherClass) + return self._applyOrderBy([self.otherClass.get(id, conn) for (id,) in ids if id is not None], self.otherClass) def remove(self, inst, other): inst._connection._SO_intermediateDelete( Modified: trunk/SQLObject/sqlobject/main.py ============================================================================== --- trunk/SQLObject/sqlobject/main.py (original) +++ trunk/SQLObject/sqlobject/main.py Thu Feb 5 21:09:52 2004 @@ -309,26 +309,10 @@ # when necessary: (bad clever? maybe) _expired = False - def __new__(cls, id, connection=None, selectResults=None): + def get(cls, id, connection=None, selectResults=None): assert id is not None, 'None is not a possible id for %s' % cls.__name - # When id is CreateNewSQLObject, that means we are trying to - # create a new object. This is a contract of sorts with the - # `new()` method. - if id is CreateNewSQLObject: - # Create an actual new object: - inst = object.__new__(cls) - inst._SO_creating = True - inst._SO_validatorState = SQLObjectState(inst) - # This is a dictionary of column-names to - # column-values for the new row: - inst._SO_createValues = {} - if connection is not None: - inst._connection = connection - assert selectResults is None - return inst - # Some databases annoyingly return longs for INT if isinstance(id, long): id = int(id) @@ -343,7 +327,7 @@ val = cache.get(id, cls) if val is None: try: - val = object.__new__(cls) + val = cls(_SO_fetch_no_create=1) val._SO_validatorState = SQLObjectState(val) val._init(id, connection, selectResults) cache.put(id, cls, val) @@ -351,6 +335,8 @@ cache.finishPut(cls) return val + get = classmethod(get) + def addColumn(cls, columnDef, changeSchema=False): column = columnDef.withClass(cls) name = column.name @@ -755,22 +741,25 @@ if id is None: return None elif self._SO_perConnection: - return joinClass(id, connection=self._connection) + return joinClass.get(id, connection=self._connection) else: - return joinClass(id) - - def new(cls, **kw): - # This is what creates a new row, plus the new Python - # object to go with it. + return joinClass.get(id) + def __init__(self, **kw): + # The get() classmethod/constructor uses a magic keyword + # argument when it wants an empty object, fetched from the + # database. So we have nothing more to do in that case: + if kw.has_key('_SO_fetch_no_create'): + return + # Pass the connection object along if we were given one. # Passing None for the ID tells __new__ we want to create # a new object. if kw.has_key('connection'): - inst = cls(CreateNewSQLObject, connection=kw['connection']) + self._connection = kw['connection'] + self._SO_perConnection = True del kw['connection'] - else: - inst = cls(CreateNewSQLObject) + self._SO_writeLock = threading.Lock() if kw.has_key('id'): id = kw['id'] @@ -778,9 +767,13 @@ else: id = None + self._SO_creating = True + self._SO_createValues = {} + self._SO_validatorState = SQLObjectState(self) + # First we do a little fix-up on the keywords we were # passed: - for column in inst._SO_columns: + for column in self._SO_columns: # If a foreign key is given, we get the ID of the object # and put that in instead @@ -805,27 +798,25 @@ forDB = {} others = {} for name, value in kw.items(): - if name in inst._SO_plainSetters: + if name in self._SO_plainSetters: forDB[name] = value else: others[name] = value # We take all the straight-to-DB values and use set() to # set them: - inst.set(**forDB) + self.set(**forDB) # The rest go through setattr(): for name, value in others.items(): try: - getattr(cls, name) + getattr(self.__class__, name) except AttributeError: - raise TypeError, "%s.new() got an unexpected keyword argument %s" % (cls.__name__, name) - setattr(inst, name, value) + raise TypeError, "%s.new() got an unexpected keyword argument %s" % (self.__class__.__name__, name) + setattr(self, name, value) # Then we finalize the process: - inst._SO_finishCreate(id) - return inst - new = classmethod(new) + self._SO_finishCreate(id) def _SO_finishCreate(self, id=None): # Here's where an INSERT is finalized. @@ -863,9 +854,9 @@ if not result: raise SQLObjectNotFound, "The %s by alternateID %s=%s does not exist" % (cls.__name__, dbIDName, repr(value)) if connection: - obj = cls(result[0], connection=connection) + obj = cls.get(result[0], connection=connection) else: - obj = cls(result[0]) + obj = cls.get(result[0]) if not obj._cacheValues: obj._SO_writeLock.acquire() try: Modified: trunk/SQLObject/tests/test.py ============================================================================== --- trunk/SQLObject/tests/test.py (original) +++ trunk/SQLObject/tests/test.py Thu Feb 5 21:09:52 2004 @@ -46,7 +46,7 @@ def inserts(self): for name, passwd in self.info: - self.MyClass.new(name=name, passwd=passwd) + self.MyClass(name=name, passwd=passwd) def testGet(self): bob = self.MyClass.selectBy(name='bob')[0] @@ -96,7 +96,7 @@ classes = [Student] def testBoolCol(self): - student = Student.new(is_smart = False) + student = Student(is_smart=False) self.assertEqual(student.is_smart, False) class TestCase34(SQLObjectTest): @@ -104,30 +104,30 @@ classes = [TestSO3, TestSO4] def testForeignKey(self): - tc3 = TestSO3.new(name='a') + tc3 = TestSO3(name='a') self.assertEqual(tc3.other, None) self.assertEqual(tc3.other2, None) self.assertEqual(tc3.otherID, None) self.assertEqual(tc3.other2ID, None) - tc4a = TestSO4.new(me='1') + tc4a = TestSO4(me='1') tc3.other = tc4a self.assertEqual(tc3.other, tc4a) self.assertEqual(tc3.otherID, tc4a.id) - tc4b = TestSO4.new(me='2') + tc4b = TestSO4(me='2') tc3.other = tc4b.id self.assertEqual(tc3.other, tc4b) self.assertEqual(tc3.otherID, tc4b.id) - tc4c = TestSO4.new(me='3') + tc4c = TestSO4(me='3') tc3.other2 = tc4c self.assertEqual(tc3.other2, tc4c) self.assertEqual(tc3.other2ID, tc4c.id) - tc4d = TestSO4.new(me='4') + tc4d = TestSO4(me='4') tc3.other2 = tc4d.id self.assertEqual(tc3.other2, tc4d) self.assertEqual(tc3.other2ID, tc4d.id) - tcc = TestSO3.new(name='b', other=tc4a) + tcc = TestSO3(name='b', other=tc4a) self.assertEqual(tcc.other, tc4a) - tcc2 = TestSO3.new(name='c', other=tc4a.id) + tcc2 = TestSO3(name='c', other=tc4a.id) self.assertEqual(tcc2.other, tc4a) class TestSO5(SQLObject): @@ -147,10 +147,10 @@ classes = [TestSO7, TestSO6, TestSO5] def testForeignKeyDestroySelfCascade(self): - tc5 = TestSO5.new(name='a') - tc6a = TestSO6.new(name='1') + tc5 = TestSO5(name='a') + tc6a = TestSO6(name='1') tc5.other = tc6a - tc7a = TestSO7.new(name='2') + tc7a = TestSO7(name='2') tc6a.other = tc7a tc5.another = tc7a self.assertEqual(tc5.other, tc6a) @@ -165,9 +165,9 @@ self.assertEqual(TestSO5.select().count(), 1) self.assertEqual(TestSO6.select().count(), 1) self.assertEqual(TestSO7.select().count(), 1) - tc6b = TestSO6.new(name='3') - tc6c = TestSO6.new(name='4') - tc7b = TestSO7.new(name='5') + tc6b = TestSO6(name='3') + tc6c = TestSO6(name='4') + tc7b = TestSO7(name='5') tc6b.other = tc7b tc6c.other = tc7b self.assertEqual(TestSO5.select().count(), 1) @@ -187,15 +187,15 @@ self.assertEqual(TestSO7.select().count(), 0) def testForeignKeyDropTableCascade(self): - tc5a = TestSO5.new(name='a') - tc6a = TestSO6.new(name='1') + tc5a = TestSO5(name='a') + tc6a = TestSO6(name='1') tc5a.other = tc6a - tc7a = TestSO7.new(name='2') + tc7a = TestSO7(name='2') tc6a.other = tc7a tc5a.another = tc7a - tc5b = TestSO5.new(name='b') - tc5c = TestSO5.new(name='c') - tc6b = TestSO6.new(name='3') + tc5b = TestSO5(name='b') + tc5c = TestSO5(name='c') + tc6b = TestSO6(name='3') tc5c.other = tc6b self.assertEqual(TestSO5.select().count(), 3) self.assertEqual(TestSO6.select().count(), 2) @@ -210,7 +210,7 @@ self.assertEqual(TestSO5.select().count(), 1) self.assertEqual(TestSO6.select().count(), 0) self.assertEqual(iter(TestSO5.select()).next(), tc5b) - tc6c = TestSO6.new(name='3') + tc6c = TestSO6(name='3') tc5b.other = tc6c self.assertEqual(TestSO5.select().count(), 1) self.assertEqual(TestSO6.select().count(), 1) @@ -230,11 +230,11 @@ classes = [TestSO9, TestSO8] def testForeignKeyDestroySelfRestrict(self): - tc8a = TestSO8.new(name='a') - tc9a = TestSO9.new(name='1') + tc8a = TestSO8(name='a') + tc9a = TestSO9(name='1') tc8a.other = tc9a - tc8b = TestSO8.new(name='b') - tc9b = TestSO9.new(name='2') + tc8b = TestSO8(name='b') + tc9b = TestSO9(name='2') self.assertEqual(tc8a.other, tc9a) self.assertEqual(tc8a.otherID, tc9a.id) self.assertEqual(TestSO8.select().count(), 2) @@ -270,7 +270,7 @@ for fname, lname in [('aj', 'baker'), ('joe', 'robbins'), ('tim', 'jackson'), ('joe', 'baker'), ('zoe', 'robbins')]: - Names.new(fname=fname, lname=lname) + Names(fname=fname, lname=lname) def testDefaultOrder(self): self.assertEqual([(n.fname, n.lname) for n in Names.select()], @@ -300,7 +300,7 @@ def inserts(self): for name in self.names: - IterTest.new(name=name) + IterTest(name=name) def test_00_normal(self): count = 0 @@ -379,15 +379,15 @@ classes = [TestSOTrans] def inserts(self): - TestSOTrans.new(name='bob') - TestSOTrans.new(name='tim') + TestSOTrans(name='bob') + TestSOTrans(name='tim') def testTransaction(self): if not self.supportTransactions: return trans = TestSOTrans._connection.transaction() try: TestSOTrans._connection.autoCommit = 'exception' - TestSOTrans.new(name='joe', connection=trans) + TestSOTrans(name='joe', connection=trans) trans.rollback() self.assertEqual([n.name for n in TestSOTrans.select(connection=trans)], ['bob', 'tim']) @@ -418,12 +418,12 @@ def inserts(self): for l in ['a', 'bcd', 'a', 'e']: - Enum1.new(l=l) + Enum1(l=l) def testBad(self): if self.supportRestrictedEnum: try: - v = Enum1.new(l='b') + v = Enum1(l='b') except Exception, e: pass else: @@ -447,7 +447,7 @@ def inserts(self): for i in range(100): - Counter.new(number=i) + Counter(number=i) def counterEqual(self, counters, value): self.assertEquals([c.number for c in counters], value) @@ -492,7 +492,7 @@ def inserts(self): for i in range(10): for j in range(10): - Counter2.new(n1=i, n2=j) + Counter2(n1=i, n2=j) def counterEqual(self, counters, value): self.assertEquals([(c.n1, c.n2) for c in counters], value) @@ -517,10 +517,10 @@ def inserts(self): for n in ['jane', 'tim', 'bob', 'jake']: - Person.new(name=n) + Person(name=n) for p in ['555-555-5555', '555-394-2930', '444-382-4854']: - Phone.new(phone=p) + Phone(phone=p) def testDefaultOrder(self): self.assertEqual(list(Person.select('all')), @@ -531,7 +531,7 @@ return nickname = StringCol('nickname', length=10) Person.addColumn(nickname, changeSchema=True) - n = Person.new(name='robert', nickname='bob') + n = Person(name='robert', nickname='bob') self.assertEqual([p.name for p in Person.select('all')], ['bob', 'jake', 'jane', 'robert', 'tim']) Person.delColumn(nickname, changeSchema=True) @@ -613,20 +613,19 @@ def testClassCreate(self): if not self.supportAuto: return - import sys class AutoTest(SQLObject): _fromDatabase = True _connection = connection() - john = AutoTest.new(firstName='john', - lastName='doe', - age=10, - created=DateTime.now(), - wannahavefun=False) - jane = AutoTest.new(firstName='jane', - lastName='doe', - happy='N', - created=DateTime.now(), - wannahavefun=True) + john = AutoTest(firstName='john', + lastName='doe', + age=10, + created=DateTime.now(), + wannahavefun=False) + jane = AutoTest(firstName='jane', + lastName='doe', + happy='N', + created=DateTime.now(), + wannahavefun=True) self.failIf(john.wannahavefun) self.failUnless(jane.wannahavefun) del classregistry.registry(AutoTest._registry).classes['AutoTest'] @@ -651,9 +650,9 @@ def inserts(self): for n in ['bob', 'tim', 'jane', 'joe', 'fred', 'barb']: - PersonJoiner.new(name=n) + PersonJoiner(name=n) for z in ['11111', '22222', '33333', '44444']: - AddressJoiner.new(zip=z) + AddressJoiner(zip=z) def testJoin(self): b = PersonJoiner.byName('bob') @@ -693,12 +692,12 @@ classes = [PersonJoiner2, AddressJoiner2] def inserts(self): - p1 = PersonJoiner2.new(name='bob') - p2 = PersonJoiner2.new(name='sally') + p1 = PersonJoiner2(name='bob') + p2 = PersonJoiner2(name='sally') for z in ['11111', '22222', '33333']: - a = AddressJoiner2.new(zip=z, personJoiner2=p1) + a = AddressJoiner2(zip=z, personJoiner2=p1) #p1.addAddressJoiner2(a) - AddressJoiner2.new(zip='00000', personJoiner2=p2) + AddressJoiner2(zip='00000', personJoiner2=p2) def test(self): bob = PersonJoiner2.byName('bob') @@ -711,7 +710,7 @@ z.zip = 'xxxxx' id = z.id del z - z = AddressJoiner2(id) + z = AddressJoiner2.get(id) self.assertEqual(z.zip, 'xxxxx') def testDefaultOrder(self): @@ -737,15 +736,15 @@ classes = [Super, Sub] def testSuper(self): - s1 = Super.new(name='one') - s2 = Super.new(name='two') - s3 = Super(s1.id) + s1 = Super(name='one') + s2 = Super(name='two') + s3 = Super.get(s1.id) self.assertEqual(s1, s3) def testSub(self): - s1 = Sub.new(name='one', name2='1') - s2 = Sub.new(name='two', name2='2') - s3 = Sub(s1.id) + s1 = Sub(name='one', name2='1') + s2 = Sub(name='two', name2='2') + s3 = Sub.get(s1.id) self.assertEqual(s1, s3) @@ -761,8 +760,8 @@ classes = [SyncTest] def inserts(self): - SyncTest.new(name='bob') - SyncTest.new(name='tim') + SyncTest(name='bob') + SyncTest(name='tim') def testExpire(self): conn = SyncTest._connection @@ -792,19 +791,19 @@ classes = [SOValidation] def testValidate(self): - t = SOValidation.new(name='hey') + t = SOValidation(name='hey') self.assertRaises(validators.InvalidField, setattr, t, 'name', '!!!') t.name = 'you' def testConfirmType(self): - t = SOValidation.new(name2='hey') + t = SOValidation(name2='hey') self.assertRaises(validators.InvalidField, setattr, t, 'name2', 1) t.name2 = 'you' def testWrapType(self): - t = SOValidation.new(name3=1) + t = SOValidation(name3=1) self.assertRaises(validators.InvalidField, setattr, t, 'name3', 'x') t.name3 = 1L @@ -863,11 +862,11 @@ classes = [SOStringID] def testStringID(self): - t = SOStringID.new(id='hey', val='whatever') + t = SOStringID(id='hey', val='whatever') t2 = SOStringID.byVal('whatever') self.assertEqual(t, t2) - t3 = SOStringID.new(id='you', val='nowhere') - t4 = SOStringID('you') + t3 = SOStringID(id='you', val='nowhere') + t4 = SOStringID.get('you') self.assertEqual(t3, t4) @@ -894,8 +893,8 @@ def test(self): - st1 = SOStyleTest1.new(a='something', st2=None) - st2 = SOStyleTest2.new(b='whatever') + st1 = SOStyleTest1(a='something', st2=None) + st2 = SOStyleTest2(b='whatever') st1.st2 = st2 self.assertEqual(st1._SO_columnDict['st2ID'].dbName, 'idst2') self.assertEqual(st1.st2, st2) |