Thread: [SQL-CVS] r2349 - in SQLObject/branches/sqlbuilder-views/sqlobject: . tests
SQLObject is a Python ORM.
Brought to you by:
ianbicking,
phd
From: <sub...@co...> - 2007-02-23 17:21:19
|
Author: luke Date: 2007-02-23 10:21:12 -0700 (Fri, 23 Feb 2007) New Revision: 2349 Modified: SQLObject/branches/sqlbuilder-views/sqlobject/conftest.py SQLObject/branches/sqlbuilder-views/sqlobject/converters.py SQLObject/branches/sqlbuilder-views/sqlobject/dbconnection.py SQLObject/branches/sqlbuilder-views/sqlobject/sqlbuilder.py SQLObject/branches/sqlbuilder-views/sqlobject/tests/test_views.py SQLObject/branches/sqlbuilder-views/sqlobject/views.py Log: More compact View aggregate handling, sqlrepr/tablesUsedDict caching (experimental) Modified: SQLObject/branches/sqlbuilder-views/sqlobject/conftest.py =================================================================== --- SQLObject/branches/sqlbuilder-views/sqlobject/conftest.py 2007-02-23 15:09:12 UTC (rev 2348) +++ SQLObject/branches/sqlbuilder-views/sqlobject/conftest.py 2007-02-23 17:21:12 UTC (rev 2349) @@ -16,7 +16,7 @@ except ImportError: # Python 2.2 pass else: - pkg_resources.require('SQLObject') + pass #pkg_resources.require('SQLObject') connectionShortcuts = { 'mysql': 'mysql://test@localhost/test', Modified: SQLObject/branches/sqlbuilder-views/sqlobject/converters.py =================================================================== --- SQLObject/branches/sqlbuilder-views/sqlobject/converters.py 2007-02-23 15:09:12 UTC (rev 2348) +++ SQLObject/branches/sqlbuilder-views/sqlobject/converters.py 2007-02-23 17:21:12 UTC (rev 2349) @@ -227,6 +227,7 @@ registerConverter(Decimal, DecimalConverter) def sqlrepr(obj, db=None): + import sqlbuilder try: reprFunc = obj.__sqlrepr__ except AttributeError: @@ -236,4 +237,17 @@ (type(obj), repr(obj)) return converter(obj, db) else: - return reprFunc(db) +# return reprFunc(db) + cache = getattr(obj, '_sqlreprCache', {}) + if not isinstance(cache, dict): + #Alias etc + cache = {} + ret = cache.get(db, None) + if ret is None: + ret = reprFunc(db) + try: + cache[db] = ret + obj._sqlreprCache = cache + except TypeError: + pass + return ret Modified: SQLObject/branches/sqlbuilder-views/sqlobject/dbconnection.py =================================================================== --- SQLObject/branches/sqlbuilder-views/sqlobject/dbconnection.py 2007-02-23 15:09:12 UTC (rev 2348) +++ SQLObject/branches/sqlbuilder-views/sqlobject/dbconnection.py 2007-02-23 17:21:12 UTC (rev 2349) @@ -495,7 +495,7 @@ def _SO_selectOneAlt(self, so, columnNames, condition): if columnNames: - columns = (isinstance(x, (str, unicode)) and sqlbuilder.SQLConstant(x) or x for x in columnNames) + columns = [isinstance(x, (str, unicode)) and sqlbuilder.SQLConstant(x) or x for x in columnNames] else: columns = None return self.queryOne(self.sqlrepr(sqlbuilder.Select(columns, Modified: SQLObject/branches/sqlbuilder-views/sqlobject/sqlbuilder.py =================================================================== --- SQLObject/branches/sqlbuilder-views/sqlobject/sqlbuilder.py 2007-02-23 15:09:12 UTC (rev 2348) +++ SQLObject/branches/sqlbuilder-views/sqlobject/sqlbuilder.py 2007-02-23 17:21:12 UTC (rev 2349) @@ -188,13 +188,16 @@ def tablesUsed(self, db): return self.tablesUsedDict(db).keys() def tablesUsedDict(self, db): - tables = {} - for table in self.tablesUsedImmediate(): - if hasattr(table, '__sqlrepr__'): - table = sqlrepr(table, db) - tables[table] = 1 - for component in self.components(): - tables.update(tablesUsedDict(component, db)) + tables = getattr(db, 'tableCache', {}).get(id(self), None) + if tables is None: + tables = {} + for table in self.tablesUsedImmediate(): + if hasattr(table, '__sqlrepr__'): + table = sqlrepr(table, db) + tables[table] = 1 + for component in self.components(): + tables.update(tablesUsedDict(component, db)) + getattr(db, 'tableCache', {})[id(self)] = tables return tables def tablesUsedImmediate(self): return [] @@ -434,7 +437,7 @@ def __init__(self, table, alias=None): if hasattr(table, "sqlmeta"): tableName = SQLConstant(table.sqlmeta.table) - elif isinstance(table, Select): + elif isinstance(table, (Select,Union)): assert alias is not None, "Alias name cannot be constructed from Select instances, please provide 'alias' kw." tableName = Subquery('', table) table = None @@ -473,6 +476,24 @@ return [self.q] +class Union(SQLExpression): + def __init__(self, *tables): + tabs = [] + for t in tables: + if not isinstance(t, SQLExpression) and hasattr(t, 'sqlmeta'): + t = t.sqlmeta.table + if isinstance(t, Alias): + t = t.q + if isinstance(t, Table): + t = t.tableName + if not isinstance(t, SQLExpression): + t = SQLConstant(t.sqlmeta.table) + tabs.append(t) + self.tables = tabs + + def __sqlrepr__(self, db): + return " UNION ".join([str(sqlrepr(t, db)) for t in self.tables]) + ######################################## ## SQL Statements ######################################## @@ -1175,6 +1196,12 @@ val = sqlrepr(val, db) return val + def tablesUsedImmediate(self): + return getattr(self._resolve(), 'tablesUsedImmediate', lambda: [])() + + def components(self): + return getattr(self._resolve(), 'components', lambda: [])() + def _resolve(self): return getattr(self.proxy, self.attr) Modified: SQLObject/branches/sqlbuilder-views/sqlobject/tests/test_views.py =================================================================== --- SQLObject/branches/sqlbuilder-views/sqlobject/tests/test_views.py 2007-02-23 15:09:12 UTC (rev 2348) +++ SQLObject/branches/sqlbuilder-views/sqlobject/tests/test_views.py 2007-02-23 17:21:12 UTC (rev 2349) @@ -42,6 +42,8 @@ number = StringCol(dbName=ViewPhone.q.number) timesCalled = IntCol(dbName=func.COUNT(PhoneCall.q.toID)) + timesCalledLong = IntCol(dbName=func.COUNT(PhoneCall.q.toID)) + timesCalledLong.aggregateClause = PhoneCall.q.minutes>10 minutesCalled = IntCol(dbName=func.SUM(PhoneCall.q.minutes)) class ViewPhoneMore2(ViewPhoneMore): @@ -49,6 +51,9 @@ table = 'vpm' +class ViewPhoneInnerAggregate(ViewPhone): + twiceMinutes = IntCol(dbName=func.SUM(PhoneCall.q.minutes)*2) + def setup_module(mod): setupClass([mod.PhoneNumber,mod.PhoneCall]) mod.ViewPhoneCall._connection = mod.PhoneNumber._connection @@ -59,7 +64,8 @@ 'number') calls = inserts(mod.PhoneCall, [(phones[0], phones[1], 5), (phones[0], phones[1], 20), - (phones[1], phones[0], 10)], + (phones[1], phones[0], 10), + (phones[1], phones[0], 25)], 'phoneNumber to minutes') mod.phones = phones mod.calls = calls @@ -100,6 +106,7 @@ checkAttr(ViewPhoneMore, phones[0].id, 'number', phones[0].number) checkAttr(ViewPhoneMore, phones[0].id, 'minutesCalled', phones[0].incoming.sum(PhoneCall.q.minutes)) checkAttr(ViewPhoneMore, phones[0].id, 'timesCalled', phones[0].incoming.count()) + checkAttr(ViewPhoneMore, phones[0].id, 'timesCalledLong', phones[0].incoming.filter(PhoneCall.q.minutes>10).count()) def testJoinView(): p = ViewPhone.get(phones[0].id) @@ -107,6 +114,9 @@ assert p.vCalls.count() == 2 assert p.vCalls[0] == ViewPhoneCall.get(calls[0].id) +def testInnerAggregate(): + checkAttr(ViewPhoneInnerAggregate, phones[0].id, 'twiceMinutes', phones[0].calls.sum(PhoneCall.q.minutes)*2) + def testSelect(): s = ViewPhone.select() assert s.count() == len(phones) Modified: SQLObject/branches/sqlbuilder-views/sqlobject/views.py =================================================================== --- SQLObject/branches/sqlbuilder-views/sqlobject/views.py 2007-02-23 15:09:12 UTC (rev 2348) +++ SQLObject/branches/sqlbuilder-views/sqlobject/views.py 2007-02-23 17:21:12 UTC (rev 2349) @@ -78,18 +78,27 @@ SQLObject.__classinit__(cls, new_attrs) # like is_base if cls.__name__ != 'ViewSQLObject': + dbName = hasattr(cls,'_connection') and cls._connection.dbName or None + if getattr(cls.sqlmeta, 'table', None): cls.sqlmeta.alias = cls.sqlmeta.table else: cls.sqlmeta.alias = cls.sqlmeta.style.pythonClassToDBTable(cls.__name__) alias = cls.sqlmeta.alias columns = [ColumnAS(cls.sqlmeta.idName, 'id')] - aggregates = [] + # {sqlrepr-key: [restriction, *aggregate-column]} + aggregates = {'':[None]} for n,col in cls.sqlmeta.columns.iteritems(): - if isinstance(col.dbName, SQLCall): - aggregates.append(ColumnAS(col.dbName, n)) + ascol = ColumnAS(col.dbName, n) + if isAggregate(col.dbName): + restriction = getattr(col, 'aggregateClause',None) + if restriction: + restrictkey = sqlrepr(restriction, dbName) + aggregates[restrictkey] = aggregates.get(restrictkey, [restriction]) + [ascol] + else: + aggregates[''].append(ascol) else: - columns.append(ColumnAS(col.dbName, n)) + columns.append(ascol) metajoin = getattr(cls.sqlmeta, 'join', NoDefault) clause = getattr(cls.sqlmeta, 'clause', NoDefault) @@ -99,7 +108,9 @@ join=metajoin, clause=clause) - if aggregates: + aggregates = aggregates.values() + + if len(aggregates) > 1: join = [] last_alias = "%s_base" % alias last_id = "id" @@ -107,22 +118,28 @@ columns = [SQLConstant("%s.%s"%(last_alias,x.expr2)) for x in columns] for i, agg in enumerate(aggregates): + restriction = agg[0] + if restriction is None: + restriction = clause + else: + restriction = AND(clause, restriction) + agg = agg[1:] agg_alias = "%s_%s" % (alias, i) agg_id = '%s_id'%agg_alias if not last.q.alias.endswith('base'): last = None - new_alias = Alias( - Select([ColumnAS(cls.sqlmeta.idName, agg_id), agg], - groupBy=cls.sqlmeta.idName, - join=metajoin, - clause=clause), + new_alias = Alias(Select([ColumnAS(cls.sqlmeta.idName, agg_id)]+agg, + groupBy=cls.sqlmeta.idName, + join=metajoin, + clause=restriction), agg_alias) agg_join = LEFTJOINOn(last, new_alias, "%s.%s = %s.%s" % (last_alias, last_id, agg_alias, agg_id)) join.append(agg_join) - columns.append(SQLConstant("%s.%s"%(agg_alias, agg.expr2))) + for col in agg: + columns.append(SQLConstant("%s.%s"%(agg_alias, col.expr2))) last = new_alias last_alias = agg_alias @@ -134,6 +151,12 @@ cls.q = ViewSQLObjectTable(cls) for n,col in cls.sqlmeta.columns.iteritems(): col.dbName = getattr(cls.q, n) + +def isAggregate(expr): + if isinstance(expr, SQLCall): + return True + if isinstance(expr, SQLOp): + return isAggregate(expr.expr1) or isAggregate(expr.expr2) + return False - ###### |