[SQL-CVS] r726 - in trunk/SQLObject/sqlobject: . tests
SQLObject is a Python ORM.
Brought to you by:
ianbicking,
phd
From: <sub...@co...> - 2005-04-18 14:23:31
|
Author: phd Date: 2005-04-18 14:23:23 +0000 (Mon, 18 Apr 2005) New Revision: 726 Added: trunk/SQLObject/sqlobject/tests/test_joins_conditional.py Modified: trunk/SQLObject/sqlobject/dbconnection.py trunk/SQLObject/sqlobject/main.py trunk/SQLObject/sqlobject/sqlbuilder.py trunk/SQLObject/sqlobject/sresults.py Log: Implemented LEFT/RIGHT/STRAIGHT INNER/OUTER/CROSS JOINs, either NATURAL or conditional. Modified: trunk/SQLObject/sqlobject/dbconnection.py =================================================================== --- trunk/SQLObject/sqlobject/dbconnection.py 2005-04-15 21:53:30 UTC (rev 725) +++ trunk/SQLObject/sqlobject/dbconnection.py 2005-04-18 14:23:23 UTC (rev 726) @@ -344,8 +344,16 @@ """ Apply an accumulate function(s) (SUM, COUNT, MIN, AVG, MAX, etc...) to the select object. """ + ops = select.ops + join = ops.get('join') + if join: + tables = self._fixTablesForJoins(select) + else: + tables = select.tables q = "SELECT %s" % ", ".join([str(expression) for expression in expressions]) - q += " FROM %s WHERE" % ", ".join(select.tables) + q += " FROM %s " % ", ".join(tables) + if join: q += self._addJoins(select) + q += " WHERE" q = self._addWhereClause(select, q, limit=0, order=0) val = self.queryOne(q) if len(expressions) == 1: @@ -354,29 +362,63 @@ def queryForSelect(self, select): ops = select.ops + join = ops.get('join') cls = select.sourceClass + if join: + tables = self._fixTablesForJoins(select) + else: + tables = select.tables if ops.get('distinct', False): q = 'SELECT DISTINCT ' else: q = 'SELECT ' if ops.get('lazyColumns', 0): - q += "%s.%s FROM %s WHERE " % \ + q += "%s.%s FROM %s " % \ (cls.sqlmeta.table, cls.sqlmeta.idName, - ", ".join(select.tables)) + ", ".join(tables)) else: columns = ", ".join(["%s.%s" % (cls.sqlmeta.table, col.dbName) for col in cls.sqlmeta._columns]) if columns: - q += "%s.%s, %s FROM %s WHERE " % \ + q += "%s.%s, %s FROM %s " % \ (cls.sqlmeta.table, cls.sqlmeta.idName, columns, - ", ".join(select.tables)) + ", 
".join(tables)) else: - q += "%s.%s FROM %s WHERE " % \ + q += "%s.%s FROM %s " % \ (cls.sqlmeta.table, cls.sqlmeta.idName, - ", ".join(select.tables)) + ", ".join(tables)) + if join: q += self._addJoins(select) + q += " WHERE" return self._addWhereClause(select, q) + def _fixTablesForJoins(self, select): + ops = select.ops + join = ops.get('join') + tables = select.tables + if type(join) is str: + return tables + else: + tables = tables[:] # maka a copy for modification + if isinstance(join, sqlbuilder.SQLJoin): + if join.table1 in tables: tables.remove(join.table1) + if join.table2 in tables: tables.remove(join.table2) + else: + for j in join: + if j.table1 in tables: tables.remove(j.table1) + if j.table2 in tables: tables.remove(j.table2) + return tables + + def _addJoins(self, select): + ops = select.ops + join = ops.get('join') + if type(join) is str: + return join + elif isinstance(join, sqlbuilder.SQLJoin): + return self.sqlrepr(join) + else: + return ", ".join([self.sqlrepr(j) for j in join]) + def _addWhereClause(self, select, startSelect, limit=1, order=1): q = select.clause Modified: trunk/SQLObject/sqlobject/main.py =================================================================== --- trunk/SQLObject/sqlobject/main.py 2005-04-15 21:53:30 UTC (rev 725) +++ trunk/SQLObject/sqlobject/main.py 2005-04-18 14:23:23 UTC (rev 726) @@ -1161,8 +1161,8 @@ def select(cls, clause=None, clauseTables=None, orderBy=NoDefault, limit=None, lazyColumns=False, reversed=False, - distinct=False, - connection=None): + distinct=False, connection=None, + join=None): return cls.SelectResultsClass(cls, clause, clauseTables=clauseTables, orderBy=orderBy, @@ -1170,7 +1170,8 @@ lazyColumns=lazyColumns, reversed=reversed, distinct=distinct, - connection=connection) + connection=connection, + join=join) select = classmethod(select) def selectBy(cls, connection=None, **kw): Modified: trunk/SQLObject/sqlobject/sqlbuilder.py 
=================================================================== --- trunk/SQLObject/sqlobject/sqlbuilder.py 2005-04-15 21:53:30 UTC (rev 725) +++ trunk/SQLObject/sqlobject/sqlbuilder.py 2005-04-18 14:23:23 UTC (rev 726) @@ -157,7 +157,7 @@ except AssertionError: return '<%s %s>' % ( self.__class__.__name__, hex(id(self))[2:]) - + def __str__(self): return repr(self) @@ -594,6 +594,197 @@ return "'%s%s%s'" % (self.prefix, s, self.postfix) ######################################## +## SQL JOINs +######################################## + +class SQLJoin(SQLExpression): + def __init__(self, table1, table2, op=','): + if table1 and type(table1) <> str: table1 = table1.sqlmeta.table + if type(table2) <> str: table2 = table2.sqlmeta.table + self.table1 = table1 + self.table2 = table2 + self.op = op + + def __sqlrepr__(self, db): + if self.table1: + return "%s%s %s" % (self.table1, self.op, self.table2) + else: + return "%s %s" % (self.op, self.table2) + +registerConverter(SQLJoin, SQLExprConverter) + +def JOIN(table1, table2): + return SQLJoin(table1, table2, " JOIN") + +def INNERJOIN(table1, table2): + return SQLJoin(table1, table2, " INNER JOIN") + +def CROSSJOIN(table1, table2): + return SQLJoin(table1, table2, " CROSS JOIN") + +def STRAIGHTJOIN(table1, table2): + return SQLJoin(table1, table2, " STRAIGHT JOIN") + +def LEFTJOIN(table1, table2): + return SQLJoin(table1, table2, " LEFT JOIN") + +def LEFTOUTERJOIN(table1, table2): + return SQLJoin(table1, table2, " LEFT OUTER JOIN") + +def NATURALJOIN(table1, table2): + return SQLJoin(table1, table2, " NATURAL JOIN") + +def NATURALLEFTJOIN(table1, table2): + return SQLJoin(table1, table2, " NATURAL LEFT JOIN") + +def NATURALLEFTOUTERJOIN(table1, table2): + return SQLJoin(table1, table2, " NATURAL LEFT OUTER JOIN") + +def RIGHTJOIN(table1, table2): + return SQLJoin(table1, table2, " RIGHT JOIN") + +def RIGHTOUTERJOIN(table1, table2): + return SQLJoin(table1, table2, " RIGHT OUTER JOIN") + +def NATURALRIGHTJOIN(table1, 
table2): + return SQLJoin(table1, table2, " NATURAL RIGHT JOIN") + +def NATURALRIGHTOUTERJOIN(table1, table2): + return SQLJoin(table1, table2, " NATURAL RIGHT OUTER JOIN") + +def FULLJOIN(table1, table2): + return SQLJoin(table1, table2, " FULL JOIN") + +def FULLOUTERJOIN(table1, table2): + return SQLJoin(table1, table2, " FULL OUTER JOIN") + +def NATURALFULLJOIN(table1, table2): + return SQLJoin(table1, table2, " NATURAL FULL JOIN") + +def NATURALFULLOUTERJOIN(table1, table2): + return SQLJoin(table1, table2, " NATURAL FULL OUTER JOIN") + +class SQLJoinConditional(SQLJoin): + """Conditional JOIN""" + def __init__(self, table1, table2, op, on_condition=None, using_columns=None): + """For condition you must give on_condition or using_columns but not both + + on_condition can be a string or SQLExpression, for example + Table1.q.col1 == Table2.q.col2 + using_columns can be a string or a list of columns, e.g. + (Table1.q.col1, Table2.q.col2) + """ + if not on_condition and not using_columns: + raise TypeError, "You must give ON condition or USING columns" + if on_condition and using_columns: + raise TypeError, "You must give ON condition or USING columns but not both" + SQLJoin.__init__(self, table1, table2, op) + self.on_condition = on_condition + self.using_columns = using_columns + + def __sqlrepr__(self, db): + if self.on_condition: + on_condition = self.on_condition + if hasattr(on_condition, "__sqlrepr__"): + on_condition = sqlrepr(on_condition, db) + join = "%s %s ON %s" % (self.op, self.table2, on_condition) + if self.table1: + join = "%s %s" % (self.table1, join) + return join + elif self.using_columns: + using_columns = [] + for col in self.using_columns: + if hasattr(col, "__sqlrepr__"): + col = sqlrepr(col, db) + using_columns.append(col) + using_columns = ", ".join() + join = "%s %s USING (%s)" % (self.op, self.table2, using_columns) + if self.table1: + join = "%s %s" % (self.table1, join) + return join + else: + RuntimeError, "Impossible error" + 
+registerConverter(SQLJoinConditional, SQLExprConverter) + +def INNERJOINConditional(table1, table2, on_condition=None, using_columns=None): + return SQLJoinConditional(table1, table2, "INNER JOIN", on_condition, using_columns) + +def LEFTJOINConditional(table1, table2, on_condition=None, using_columns=None): + return SQLJoinConditional(table1, table2, "LEFT JOIN", on_condition, using_columns) + +def LEFTOUTERJOINConditional(table1, table2, on_condition=None, using_columns=None): + return SQLJoinConditional(table1, table2, "LEFT OUTER JOIN", on_condition, using_columns) + +def RIGHTJOINConditional(table1, table2, on_condition=None, using_columns=None): + return SQLJoinConditional(table1, table2, "RIGHT JOIN", on_condition, using_columns) + +def RIGHTOUTERJOINConditional(table1, table2, on_condition=None, using_columns=None): + return SQLJoinConditional(table1, table2, "RIGHT OUTER JOIN", on_condition, using_columns) + +def FULLJOINConditional(table1, table2, on_condition=None, using_columns=None): + return SQLJoinConditional(table1, table2, "FULL JOIN", on_condition, using_columns) + +def FULLOUTERJOINConditional(table1, table2, on_condition=None, using_columns=None): + return SQLJoinConditional(table1, table2, "FULL OUTER JOIN", on_condition, using_columns) + +class SQLJoinOn(SQLJoinConditional): + """Conditional JOIN ON""" + def __init__(self, table1, table2, op, on_condition): + SQLJoinConditional.__init__(self, table1, table2, op, on_condition) + +registerConverter(SQLJoinOn, SQLExprConverter) + +class SQLJoinUsing(SQLJoinConditional): + """Conditional JOIN USING""" + def __init__(self, table1, table2, op, using_columns): + SQLJoinConditional.__init__(self, table1, table2, op, None, using_columns) + +registerConverter(SQLJoinUsing, SQLExprConverter) + +def INNERJOINOn(table1, table2, on_condition): + return SQLJoinOn(table1, table2, "INNER JOIN", on_condition) + +def LEFTJOINOn(table1, table2, on_condition): + return SQLJoinOn(table1, table2, "LEFT JOIN", 
on_condition) + +def LEFTOUTERJOINOn(table1, table2, on_condition): + return SQLJoinOn(table1, table2, "LEFT OUTER JOIN", on_condition) + +def RIGHTJOINOn(table1, table2, on_condition): + return SQLJoinOn(table1, table2, "RIGHT JOIN", on_condition) + +def RIGHTOUTERJOINOn(table1, table2, on_condition): + return SQLJoinOn(table1, table2, "RIGHT OUTER JOIN", on_condition) + +def FULLJOINOn(table1, table2, on_condition): + return SQLJoinOn(table1, table2, "FULL JOIN", on_condition) + +def FULLOUTERJOINOn(table1, table2, on_condition): + return SQLJoinOn(table1, table2, "FULL OUTER JOIN", on_condition) + +def INNERJOINUsing(table1, table2, using_columns): + return SQLJoinUsing(table1, table2, "INNER JOIN", using_columns) + +def LEFTJOINUsing(table1, table2, using_columns): + return SQLJoinUsing(table1, table2, "LEFT JOIN", using_columns) + +def LEFTOUTERJOINUsing(table1, table2, using_columns): + return SQLJoinUsing(table1, table2, "LEFT OUTER JOIN", using_columns) + +def RIGHTJOINUsing(table1, table2, using_columns): + return SQLJoinUsing(table1, table2, "RIGHT JOIN", using_columns) + +def RIGHTOUTERJOINUsing(table1, table2, using_columns): + return SQLJoinUsing(table1, table2, "RIGHT OUTER JOIN", using_columns) + +def FULLJOINUsing(table1, table2, using_columns): + return SQLJoinUsing(table1, table2, "FULL JOIN", using_columns) + +def FULLOUTERJOINUsing(table1, table2, using_columns): + return SQLJoinUsing(table1, table2, "FULL OUTER JOIN", using_columns) + +######################################## ## Global initializations ######################################## Modified: trunk/SQLObject/sqlobject/sresults.py =================================================================== --- trunk/SQLObject/sqlobject/sresults.py 2005-04-15 21:53:30 UTC (rev 725) +++ trunk/SQLObject/sqlobject/sresults.py 2005-04-18 14:23:23 UTC (rev 726) @@ -31,6 +31,13 @@ if ops.has_key('connection') and ops['connection'] is None: del ops['connection'] + def __repr__(self): + return "<%s at 
%x>" % (self.__class__.__name__, id(self)) + + def __str__(self): + conn = self.ops.get('connection', self.sourceClass._connection) + return conn.queryForSelect(self) + def _mungeOrderBy(self, orderBy): if isinstance(orderBy, str) and orderBy.startswith('-'): orderBy = orderBy[1:] Added: trunk/SQLObject/sqlobject/tests/test_joins_conditional.py =================================================================== --- trunk/SQLObject/sqlobject/tests/test_joins_conditional.py 2005-04-15 21:53:30 UTC (rev 725) +++ trunk/SQLObject/sqlobject/tests/test_joins_conditional.py 2005-04-18 14:23:23 UTC (rev 726) @@ -0,0 +1,48 @@ +from sqlobject import * +from sqlobject.sqlbuilder import * +from sqlobject.tests.dbtest import * + +######################################## +## Conditional joins +######################################## + +class TestJoin1(SQLObject): + col1 = StringCol() + +class TestJoin2(SQLObject): + col2 = StringCol() + +def setup(): + setupClass(TestJoin1) + setupClass(TestJoin2) + +def test_1syntax(): + setup() + join = JOIN("table1", "table2") + assert str(join) == "table1 JOIN table2" + join = LEFTJOIN("table1", "table2") + assert str(join) == "table1 LEFT JOIN table2" + join = LEFTJOINOn("table1", "table2", "tabl1.col1 = table2.col2") + assert getConnection().sqlrepr(join) == "table1 LEFT JOIN table2 ON tabl1.col1 = table2.col2" + +def test_2select_syntax(): + setup() + select = TestJoin1.select( + join=LEFTJOINConditional(None, TestJoin2, + on_condition=(TestJoin1.q.col1 == TestJoin2.q.col2)) + ) + assert str(select) == \ + "SELECT test_join1.id, test_join1.col1 FROM test_join1 LEFT JOIN test_join2 ON (test_join1.col1 = test_join2.col2) WHERE 1 = 1" + +def test_3perform_join(): + setup() + TestJoin1(col1="test1") + TestJoin1(col1="test2") + TestJoin1(col1="test3") + TestJoin2(col2="test1") + TestJoin2(col2="test2") + + select = TestJoin1.select( + join=LEFTJOINOn(None, TestJoin2, TestJoin1.q.col1 == TestJoin2.q.col2) + ) + assert select.count() == 3 |