Author: phd
Date: 2007-05-15 10:03:11 -0600 (Tue, 15 May 2007)
New Revision: 2692
Modified:
SQLObject/trunk/sqlobject/joins.py
SQLObject/trunk/sqlobject/tests/test_SQLMultipleJoin.py
SQLObject/trunk/sqlobject/tests/test_SingleJoin.py
SQLObject/trunk/sqlobject/views.py
Log:
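Replaced calls to style.dbColumnToPythonAttr() in joins.py with a lookup of the
joined class's columns by dbName, so MultipleJoin/SQLMultipleJoin find the right
Python attribute even when the user gave the foreign-key column a custom dbName
(falling back to style.dbColumnToPythonAttr() when no column matches).
Updated test_SQLMultipleJoin.py and test_SingleJoin.py to use foreign keys with
custom dbNames, changed views.py to store the plain column name in col.dbName,
and made several whitespace-only cleanups.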
Modified: SQLObject/trunk/sqlobject/joins.py
===================================================================
--- SQLObject/trunk/sqlobject/joins.py 2007-05-15 15:58:31 UTC (rev 2691)
+++ SQLObject/trunk/sqlobject/joins.py 2007-05-15 16:03:11 UTC (rev 2692)
@@ -152,6 +152,12 @@
conn = None
return self._applyOrderBy([self.otherClass.get(id, conn) for (id,) in ids if id is not None], self.otherClass)
+ def _dbNameToPythonName(self):
+ for column in self.otherClass.sqlmeta.columns.values():
+ if column.dbName == self.joinColumn:
+ return column.name
+ return self.soClass.sqlmeta.style.dbColumnToPythonAttr(self.joinColumn)
+
class MultipleJoin(Join):
baseClass = SOMultipleJoin
@@ -162,7 +168,8 @@
conn = inst._connection
else:
conn = None
- results = self.otherClass.select(getattr(self.otherClass.q, self.soClass.sqlmeta.style.dbColumnToPythonAttr(self.joinColumn)) == inst.id, connection=conn)
+ pythonColumn = self._dbNameToPythonName()
+ results = self.otherClass.select(getattr(self.otherClass.q, pythonColumn) == inst.id, connection=conn)
return results.orderBy(self.orderBy)
class SQLMultipleJoin(Join):
@@ -252,7 +259,7 @@
def tablesUsedImmediate(self):
return [self.table, self.interTable]
-
+
def __sqlrepr__(self, db):
return '%s.%s = %s.%s' % (self.interTable, self.joinColumn, self.table, self.idName)
@@ -276,7 +283,7 @@
conn = None
results = self.otherClass.select(sqlbuilder.AND(
OtherTableToJoin(
- self.otherClass.sqlmeta.table, self.otherClass.sqlmeta.idName,
+ self.otherClass.sqlmeta.table, self.otherClass.sqlmeta.idName,
self.intermediateTable, self.otherColumn
),
JoinToTable(
@@ -305,7 +312,7 @@
conn = inst._connection
else:
conn = None
- pythonColumn = self.soClass.sqlmeta.style.dbColumnToPythonAttr(self.joinColumn)
+ pythonColumn = self._dbNameToPythonName()
results = self.otherClass.select(
getattr(self.otherClass.q, pythonColumn) == inst.id,
connection=conn
@@ -387,7 +394,7 @@
& (sqlbuilder.Field(self.intermediateTable, self.joinColumn)
== obj.id))
select = self.otherClass.select(query)
- return _ManyToManySelectWrapper(obj, self, select)
+ return _ManyToManySelectWrapper(obj, self, select)
def event_CreateTableSignal(self, soClass, connection, extra_sql,
post_funcs):
@@ -411,14 +418,14 @@
joinColumn = None
otherColumn = None
createJoinTable = True
-
+
class _ManyToManySelectWrapper(object):
def __init__(self, forObject, join, select):
self.forObject = forObject
self.join = join
self.select = select
-
+
def __getattr__(self, attr):
# @@: This passes through private variable access too... should it?
# Also magic methods, like __str__
@@ -456,7 +463,7 @@
obj = self.join.otherClass(**kw)
self.add(obj)
return obj
-
+
class SOOneToMany(object):
def __init__(self, soClass, name, join, joinColumn, **attrs):
@@ -502,7 +509,7 @@
self.forObject = forObject
self.join = join
self.select = select
-
+
def __getattr__(self, attr):
# @@: This passes through private variable access too... should it?
# Also magic methods, like __str__
@@ -523,4 +530,3 @@
def create(self, **kw):
kw[self.join.joinColumn] = self.forObject.id
return self.join.otherClass(**kw)
-
Modified: SQLObject/trunk/sqlobject/tests/test_SQLMultipleJoin.py
===================================================================
--- SQLObject/trunk/sqlobject/tests/test_SQLMultipleJoin.py 2007-05-15 15:58:31 UTC (rev 2691)
+++ SQLObject/trunk/sqlobject/tests/test_SQLMultipleJoin.py 2007-05-15 16:03:11 UTC (rev 2692)
@@ -3,12 +3,12 @@
class Race(SQLObject):
name = StringCol()
- fightersAsList = MultipleJoin('RFighter')
- fightersAsSResult = SQLMultipleJoin('RFighter')
+ fightersAsList = MultipleJoin('RFighter', joinColumn="rf_id")
+ fightersAsSResult = SQLMultipleJoin('RFighter', joinColumn="rf_id")
class RFighter(SQLObject):
name = StringCol()
- race = ForeignKey('Race')
+ race = ForeignKey('Race', dbName="rf_id")
power = IntCol()
def createAllTables():
Modified: SQLObject/trunk/sqlobject/tests/test_SingleJoin.py
===================================================================
--- SQLObject/trunk/sqlobject/tests/test_SingleJoin.py 2007-05-15 15:58:31 UTC (rev 2691)
+++ SQLObject/trunk/sqlobject/tests/test_SingleJoin.py 2007-05-15 16:03:11 UTC (rev 2692)
@@ -4,13 +4,13 @@
class PersonWithAlbum(SQLObject):
name = StringCol()
# albumNone returns the album or none
- albumNone = SingleJoin('PhotoAlbum', joinColumn='person_id')
+ albumNone = SingleJoin('PhotoAlbum', joinColumn='test_person_id')
# albumInstance returns the album or an default album instance
- albumInstance = SingleJoin('PhotoAlbum', makeDefault=True, joinColumn='person_id')
+ albumInstance = SingleJoin('PhotoAlbum', makeDefault=True, joinColumn='test_person_id')
class PhotoAlbum(SQLObject):
color = StringCol(default='red')
- person = ForeignKey('PersonWithAlbum')
+ person = ForeignKey('PersonWithAlbum', dbName='test_person_id')
def test_1():
setupClass([PersonWithAlbum, PhotoAlbum])
Modified: SQLObject/trunk/sqlobject/views.py
===================================================================
--- SQLObject/trunk/sqlobject/views.py 2007-05-15 15:58:31 UTC (rev 2691)
+++ SQLObject/trunk/sqlobject/views.py 2007-05-15 16:03:11 UTC (rev 2692)
@@ -28,7 +28,7 @@
class ViewSQLObjectTable(SQLObjectTable):
FieldClass = ViewSQLObjectField
UnicodeFieldClass = UnicodeViewSQLObjectField
-
+
def __getattr__(self, attr):
if attr == 'sqlmeta':
raise AttributeError
@@ -53,13 +53,13 @@
table as an optional alternate name for the class alias
See test_views.py for simple examples.
'''
-
+
def __classinit__(cls, new_attrs):
SQLObject.__classinit__(cls, new_attrs)
# like is_base
if cls.__name__ != 'ViewSQLObject':
dbName = hasattr(cls,'_connection') and (cls._connection and cls._connection.dbName) or None
-
+
if getattr(cls.sqlmeta, 'table', None):
cls.sqlmeta.alias = cls.sqlmeta.table
else:
@@ -81,7 +81,7 @@
aggregates[''].append(ascol)
else:
columns.append(ascol)
-
+
metajoin = getattr(cls.sqlmeta, 'join', NoDefault)
clause = getattr(cls.sqlmeta, 'clause', NoDefault)
select = Select(columns,
@@ -91,17 +91,17 @@
#distinctOn=cls.sqlmeta.idName,
join=metajoin,
clause=clause)
-
+
aggregates = aggregates.values()
#print cls.__name__, sqlrepr(aggregates, dbName)
-
+
if aggregates != [[None]]:
join = []
last_alias = "%s_base" % alias
last_id = "id"
last = Alias(select, last_alias)
columns = [ColumnAS(SQLConstant("%s.%s"%(last_alias,x.expr2)), x.expr2) for x in columns]
-
+
for i, agg in enumerate(aggregates):
restriction = agg[0]
if restriction is None:
@@ -121,22 +121,22 @@
agg_join = LEFTJOINOn(last,
new_alias,
"%s.%s = %s.%s" % (last_alias, last_id, agg_alias, agg_id))
-
+
join.append(agg_join)
for col in agg:
columns.append(ColumnAS(SQLConstant("%s.%s"%(agg_alias, col.expr2)),col.expr2))
-
+
last = new_alias
last_alias = agg_alias
last_id = agg_id
select = Select(columns,
join=join)
-
+
cls.sqlmeta.table = Alias(select, alias)
cls.q = ViewSQLObjectTable(cls)
for n,col in cls.sqlmeta.columns.iteritems():
- col.dbName = getattr(cls.q, n)
-
+ col.dbName = n
+
def isAggregate(expr):
if isinstance(expr, SQLCall):
return True
|