[Sqlalchemy-commits] [1463] sqlalchemy/trunk/test: rick morrison's CASE statement + unit test
Brought to you by:
zzzeek
From: <co...@sq...> - 2006-05-15 23:47:19
|
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.1//EN" "http://www.w3.org/TR/xhtml11/DTD/xhtml11.dtd"> <html xmlns="http://www.w3.org/1999/xhtml"> <head><style type="text/css"><!-- #msg dl { border: 1px #006 solid; background: #369; padding: 6px; color: #fff; } #msg dt { float: left; width: 6em; font-weight: bold; } #msg dt:after { content:':';} #msg dl, #msg dt, #msg ul, #msg li { font-family: verdana,arial,helvetica,sans-serif; font-size: 10pt; } #msg dl a { font-weight: bold} #msg dl a:link { color:#fc3; } #msg dl a:active { color:#ff0; } #msg dl a:visited { color:#cc6; } h3 { font-family: verdana,arial,helvetica,sans-serif; font-size: 10pt; font-weight: bold; } #msg pre { overflow: auto; background: #ffc; border: 1px #fc0 solid; padding: 6px; } #msg ul, pre { overflow: auto; } #patch { width: 100%; } #patch h4 {font-family: verdana,arial,helvetica,sans-serif;font-size:10pt;padding:8px;background:#369;color:#fff;margin:0;} #patch .propset h4, #patch .binary h4 {margin:0;} #patch pre {padding:0;line-height:1.2em;margin:0;} #patch .diff {width:100%;background:#eee;padding: 0 0 10px 0;overflow:auto;} #patch .propset .diff, #patch .binary .diff {padding:10px 0;} #patch span {display:block;padding:0 10px;} #patch .modfile, #patch .addfile, #patch .delfile, #patch .propset, #patch .binary, #patch .copfile {border:1px solid #ccc;margin:10px 0;} #patch ins {background:#dfd;text-decoration:none;display:block;padding:0 10px;} #patch del {background:#fdd;text-decoration:none;display:block;padding:0 10px;} #patch .lines, .info {color:#888;background:#fff;} --></style> <title>[1463] sqlalchemy/trunk/test: rick morrison's CASE statement + unit test</title> </head> <body> <div id="msg"> <dl> <dt>Revision</dt> <dd>1463</dd> <dt>Author</dt> <dd>zzzeek</dd> <dt>Date</dt> <dd>2006-05-15 18:47:07 -0500 (Mon, 15 May 2006)</dd> </dl> <h3>Log Message</h3> <pre>rick morrison's CASE statement + unit test</pre> <h3>Modified Paths</h3> <ul> <li><a href="#sqlalchemytrunklibsqlalchemyansisqlpy">sqlalchemy/trunk/lib/sqlalchemy/ansisql.py</a></li> <li><a href="#sqlalchemytrunklibsqlalchemysqlpy">sqlalchemy/trunk/lib/sqlalchemy/sql.py</a></li> <li><a href="#sqlalchemytrunktestalltestspy">sqlalchemy/trunk/test/alltests.py</a></li> </ul> <h3>Added Paths</h3> <ul> <li><a href="#sqlalchemytrunktestcase_statementpy">sqlalchemy/trunk/test/case_statement.py</a></li> </ul> </div> <div id="patch"> <h3>Diff</h3> <a id="sqlalchemytrunklibsqlalchemyansisqlpy"></a> <div class="modfile"><h4>Modified: sqlalchemy/trunk/lib/sqlalchemy/ansisql.py (1462 => 1463)</h4> <pre class="diff"><span> <span class="info">--- sqlalchemy/trunk/lib/sqlalchemy/ansisql.py 2006-05-15 22:47:41 UTC (rev 1462) +++ sqlalchemy/trunk/lib/sqlalchemy/ansisql.py 2006-05-15 23:47:07 UTC (rev 1463) </span><span class="lines">@@ -224,7 +224,13 @@ </span><span class="cx"> </span><span class="cx"> def apply_function_parens(self, func): </span><span class="cx"> return func.name.upper() not in ANSI_FUNCS or len(func.clauses) > 0 </span><del>- </del><ins>+ + def visit_calculatedclause(self, list): + if list.parens: + self.strings[list] = "(" + string.join([self.get_str(c) for c in list.clauses], ' ') + ")" + else: + self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ' ') + </ins><span class="cx"> def visit_function(self, func): </span><span class="cx"> if len(self.select_stack): </span><span class="cx"> self.typemap.setdefault(func.name, func.type) </span></span></pre></div> <a id="sqlalchemytrunklibsqlalchemysqlpy"></a> <div class="modfile"><h4>Modified: sqlalchemy/trunk/lib/sqlalchemy/sql.py (1462 => 1463)</h4> <pre class="diff"><span> <span class="info">--- sqlalchemy/trunk/lib/sqlalchemy/sql.py 2006-05-15 22:47:41 UTC (rev 1462) +++ sqlalchemy/trunk/lib/sqlalchemy/sql.py 2006-05-15 23:47:07 UTC (rev 1463) </span><span class="lines">@@ -1,4 +1,3 @@ </span><del>-# sql.py </del><span class="cx"> # Copyright (C) 2005,2006 Michael Bayer mi...@zz... </span><span class="cx"> # </span><span class="cx"> # This module is part of SQLAlchemy and is released under </span><span class="lines">@@ -13,7 +12,7 @@ </span><span class="cx"> import string, re, random </span><span class="cx"> types = __import__('types') </span><span class="cx"> </span><del>-__all__ = ['text', 'table', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'between_', 'cast', 'union', 'union_all', 'null', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists'] </del><ins>+__all__ = ['text', 'table', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'between_', 'case', 'cast', 'union', 'union_all', 'null', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists'] </ins><span class="cx"> </span><span class="cx"> def desc(column): </span><span class="cx"> """returns a descending ORDER BY clause element, e.g.: </span><span class="lines">@@ -132,6 +131,17 @@ </span><span class="cx"> """ returns BETWEEN predicate clause (clausetest BETWEEN clauseleft AND clauseright) """ </span><span class="cx"> return BooleanExpression(ctest, and_(cleft, cright), 'BETWEEN') </span><span class="cx"> between = between_ </span><ins>+ +def case(whens, value=None, else_=None): + """ SQL CASE statement -- whens are a sequence of pairs to be translated into "when / then" clauses; + optional [value] for simple case statements, and [else_] for case defaults """ + whenlist = [CompoundClause(None, 'WHEN', c, 'THEN', r) for (c,r) in whens] + if else_: + whenlist.append(CompoundClause(None, 'ELSE', else_)) + cc = CalculatedClause(None, 'CASE', value, *whenlist + ['END']) + for c in cc.clauses: + c.parens = False + return cc </ins><span class="cx"> </span><span class="cx"> def cast(clause, totype, **kwargs): </span><span class="cx"> """ returns CAST function CAST(clause AS totype) </span><span class="lines">@@ -295,6 +305,7 @@ </span><span class="cx"> def visit_join(self, join):pass </span><span class="cx"> def visit_null(self, null):pass </span><span class="cx"> def visit_clauselist(self, list):pass </span><ins>+ def visit_calculatedclause(self, calcclause):pass </ins><span class="cx"> def visit_function(self, func):pass </span><span class="cx"> def visit_label(self, label):pass </span><span class="cx"> def visit_typeclause(self, typeclause):pass </span><span class="lines">@@ -831,9 +842,42 @@ </span><span class="cx"> return self.operator == other.operator </span><span class="cx"> else: </span><span class="cx"> return False </span><ins>+ +class CalculatedClause(ClauseList, ColumnElement): + """ describes a calculated SQL expression that has a type, like CASE. extends ColumnElement to + provide column-level comparison operators. """ + def __init__(self, name, *clauses, **kwargs): + self.name = name + self.type = sqltypes.to_instance(kwargs.get('type', None)) + self._engine = kwargs.get('engine', None) + ClauseList.__init__(self, *clauses) + key = property(lambda self:self.name or "_calc_") + def _process_from_dict(self, data, asfrom): + super(CalculatedClause, self)._process_from_dict(data, asfrom) + # this helps a Select object get the engine from us + data.setdefault(self, self) + def copy_container(self): + clauses = [clause.copy_container() for clause in self.clauses] + return CalculatedClause(type=self.type, engine=self._engine, *clauses) + def accept_visitor(self, visitor): + for c in self.clauses: + c.accept_visitor(visitor) + visitor.visit_calculatedclause(self) + def _bind_param(self, obj): + return BindParamClause(self.name, obj, type=self.type) + def select(self): + return select([self]) + def scalar(self): + return select([self]).scalar() + def execute(self): + return select([self]).execute() + def _compare_type(self, obj): + return self.type + </ins><span class="cx"> </span><del>-class Function(ClauseList, ColumnElement): - """describes a SQL function. extends ClauseList to provide comparison operators.""" </del><ins>+class Function(CalculatedClause): + """describes a SQL function. extends CalculatedClause turn the "clauselist" into function + arguments, also adds a "packagenames" argument""" </ins><span class="cx"> def __init__(self, name, *clauses, **kwargs): </span><span class="cx"> self.name = name </span><span class="cx"> self.type = sqltypes.to_instance(kwargs.get('type', None)) </span><span class="lines">@@ -848,10 +892,6 @@ </span><span class="cx"> else: </span><span class="cx"> clause = BindParamClause(self.name, clause, shortname=self.name, type=None) </span><span class="cx"> self.clauses.append(clause) </span><del>- def _process_from_dict(self, data, asfrom): - super(Function, self)._process_from_dict(data, asfrom) - # this helps a Select object get the engine from us - data.setdefault(self, self) </del><span class="cx"> def copy_container(self): </span><span class="cx"> clauses = [clause.copy_container() for clause in self.clauses] </span><span class="cx"> return Function(self.name, type=self.type, packagenames=self.packagenames, engine=self._engine, *clauses) </span><span class="lines">@@ -859,17 +899,8 @@ </span><span class="cx"> for c in self.clauses: </span><span class="cx"> c.accept_visitor(visitor) </span><span class="cx"> visitor.visit_function(self) </span><del>- def _bind_param(self, obj): - return BindParamClause(self.name, obj, shortname=self.name, type=self.type) - def select(self): - return select([self]) - def scalar(self): - return select([self]).scalar() - def execute(self): - return select([self]).execute() - def _compare_type(self, obj): - return self.type </del><span class="cx"> </span><ins>+ </ins><span class="cx"> class FunctionGenerator(object): </span><span class="cx"> """generates Function objects based on getattr calls""" </span><span class="cx"> def __init__(self, engine=None): </span></span></pre></div> <a id="sqlalchemytrunktestalltestspy"></a> <div class="modfile"><h4>Modified: sqlalchemy/trunk/test/alltests.py (1462 => 1463)</h4> <pre class="diff"><span> <span class="info">--- sqlalchemy/trunk/test/alltests.py 2006-05-15 22:47:41 UTC (rev 1462) +++ sqlalchemy/trunk/test/alltests.py 2006-05-15 23:47:07 UTC (rev 1463) </span><span class="lines">@@ -24,6 +24,7 @@ </span><span class="cx"> # SQL syntax </span><span class="cx"> 'select', </span><span class="cx"> 'selectable', </span><ins>+ 'case_statement', </ins><span class="cx"> </span><span class="cx"> # assorted round-trip tests </span><span class="cx"> 'query', </span></span></pre></div> <a id="sqlalchemytrunktestcase_statementpy"></a> <div class="addfile"><h4>Added: sqlalchemy/trunk/test/case_statement.py (1462 => 1463)</h4> <pre class="diff"><span> <span class="info">--- sqlalchemy/trunk/test/case_statement.py 2006-05-15 22:47:41 UTC (rev 1462) +++ sqlalchemy/trunk/test/case_statement.py 2006-05-15 23:47:07 UTC (rev 1463) </span><span class="lines">@@ -0,0 +1,59 @@ </span><ins>+import sys +import testbase +from sqlalchemy import * + + +class CaseTest(testbase.PersistTest): + + def setUpAll(self): + global info_table + info_table = Table('infos', testbase.db, + Column('pk', Integer, primary_key=True), + Column('info', String)) + + info_table.create() + + info_table.insert().execute( + {'pk':1, 'info':'pk_1_data'}, + {'pk':2, 'info':'pk_2_data'}, + {'pk':3, 'info':'pk_3_data'}, + {'pk':4, 'info':'pk_4_data'}, + {'pk':5, 'info':'pk_5_data'}) + def tearDownAll(self): + info_table.drop() + + def testcase(self): + inner = select([case([[info_table.c.pk < 3, literal('lessthan3', type=String)], + [info_table.c.pk >= 3, literal('gt3', type=String)]]).label('x'), + info_table.c.pk, info_table.c.info], from_obj=[info_table]).alias('q_inner') + + inner_result = inner.execute().fetchall() + + # Outputs: + # lessthan3 1 pk_1_data + # lessthan3 2 pk_2_data + # gt3 3 pk_3_data + # gt3 4 pk_4_data + # gt3 5 pk_5_data + assert inner_result == [ + ('lessthan3', 1, 'pk_1_data'), + ('lessthan3', 2, 'pk_2_data'), + ('gt3', 3, 'pk_3_data'), + ('gt3', 4, 'pk_4_data'), + ('gt3', 5, 'pk_5_data'), + ] + + outer = select([inner]) + + outer_result = outer.execute().fetchall() + + assert outer_result == [ + ('lessthan3', 1, 'pk_1_data'), + ('lessthan3', 2, 'pk_2_data'), + ('gt3', 3, 'pk_3_data'), + ('gt3', 4, 'pk_4_data'), + ('gt3', 5, 'pk_5_data'), + ] + +if __name__ == "__main__": + testbase.main() </ins></span></pre> </div> </div> </body> </html> |