[Sqlalchemy-commits] [5324] sqlalchemy/branches/reflection/lib/sqlalchemy: Refactored Postgresql dialect to implement reflection api.
Brought to you by:
zzzeek
From: <co...@sq...> - 2008-11-22 19:29:25
|
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.1//EN" "http://www.w3.org/TR/xhtml11/DTD/xhtml11.dtd"> <html xmlns="http://www.w3.org/1999/xhtml"> <head><meta http-equiv="content-type" content="text/html; charset=utf-8" /><style type="text/css"><!-- #msg dl { border: 1px #006 solid; background: #369; padding: 6px; color: #fff; } #msg dt { float: left; width: 6em; font-weight: bold; } #msg dt:after { content:':';} #msg dl, #msg dt, #msg ul, #msg li, #header, #footer { font-family: verdana,arial,helvetica,sans-serif; font-size: 10pt; } #msg dl a { font-weight: bold} #msg dl a:link { color:#fc3; } #msg dl a:active { color:#ff0; } #msg dl a:visited { color:#cc6; } h3 { font-family: verdana,arial,helvetica,sans-serif; font-size: 10pt; font-weight: bold; } #msg pre { overflow: auto; background: #ffc; border: 1px #fc0 solid; padding: 6px; } #msg ul, pre { overflow: auto; } #header, #footer { color: #fff; background: #636; border: 1px #300 solid; padding: 6px; } #patch { width: 100%; } #patch h4 {font-family: verdana,arial,helvetica,sans-serif;font-size:10pt;padding:8px;background:#369;color:#fff;margin:0;} #patch .propset h4, #patch .binary h4 {margin:0;} #patch pre {padding:0;line-height:1.2em;margin:0;} #patch .diff {width:100%;background:#eee;padding: 0 0 10px 0;overflow:auto;} #patch .propset .diff, #patch .binary .diff {padding:10px 0;} #patch span {display:block;padding:0 10px;} #patch .modfile, #patch .addfile, #patch .delfile, #patch .propset, #patch .binary, #patch .copfile {border:1px solid #ccc;margin:10px 0;} #patch ins {background:#dfd;text-decoration:none;display:block;padding:0 10px;} #patch del {background:#fdd;text-decoration:none;display:block;padding:0 10px;} #patch .lines, .info {color:#888;background:#fff;} --></style> <title>[5324] sqlalchemy/branches/reflection/lib/sqlalchemy: Refactored Postgresql dialect to implement reflection api.</title> </head> <body> <div id="msg"> <dl> <dt>Revision</dt> <dd>5324</dd> <dt>Author</dt> <dd>randall</dd> <dt>Date</dt> 
<dd>2008-11-22 14:29:20 -0500 (Sat, 22 Nov 2008)</dd> </dl> <h3>Log Message</h3> <pre>Refactored Postgresql dialect to implement reflection api. It passes both dialect and reflection tests.</pre> <h3>Modified Paths</h3> <ul> <li><a href="#sqlalchemybranchesreflectionlibsqlalchemydatabasespostgrespy">sqlalchemy/branches/reflection/lib/sqlalchemy/databases/postgres.py</a></li> <li><a href="#sqlalchemybranchesreflectionlibsqlalchemyenginebasepy">sqlalchemy/branches/reflection/lib/sqlalchemy/engine/base.py</a></li> </ul> </div> <div id="patch"> <h3>Diff</h3> <a id="sqlalchemybranchesreflectionlibsqlalchemydatabasespostgrespy"></a> <div class="modfile"><h4>Modified: sqlalchemy/branches/reflection/lib/sqlalchemy/databases/postgres.py (5323 => 5324)</h4> <pre class="diff"><span> <span class="info">--- sqlalchemy/branches/reflection/lib/sqlalchemy/databases/postgres.py 2008-11-22 19:22:42 UTC (rev 5323) +++ sqlalchemy/branches/reflection/lib/sqlalchemy/databases/postgres.py 2008-11-22 19:29:20 UTC (rev 5324) </span><span class="lines">@@ -368,24 +368,47 @@ </span><span class="cx"> """ % locals() </span><span class="cx"> return [row[0].decode(self.encoding) for row in connection.execute(s)] </span><span class="cx"> </span><del>- def server_version_info(self, connection): - v = connection.execute("select version()").scalar() - m = re.match('PostgreSQL (\d+)\.(\d+)\.(\d+)', v) - if not m: - raise AssertionError("Could not determine version from string '%s'" % v) - return tuple([int(x) for x in m.group(1, 2, 3)]) - - def reflecttable(self, connection, table, include_columns): - preparer = self.identifier_preparer - if table.schema is not None: </del><ins>+ def __make_schema_where_clause(self, schema): + if schema is not None: </ins><span class="cx"> schema_where_clause = "n.nspname = :schema" </span><del>- schemaname = table.schema </del><ins>+ schemaname = schema </ins><span class="cx"> if isinstance(schemaname, str): </span><span class="cx"> schemaname = 
schemaname.decode(self.encoding) </span><span class="cx"> else: </span><span class="cx"> schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)" </span><span class="cx"> schemaname = None </span><ins>+ return schema_where_clause </ins><span class="cx"> </span><ins>+ def _get_table_oid(self, connection, table_name, schema=None): + schema_where_clause = self.__make_schema_where_clause(schema) + query = """ + SELECT DISTINCT a.attrelid as table_oid + FROM pg_catalog.pg_attribute a + WHERE a.attrelid = ( + SELECT c.oid + FROM pg_catalog.pg_class c + LEFT JOIN pg_catalog.pg_namespace n + ON n.oid = c.relnamespace + WHERE (%s) + AND c.relname = :table_name AND c.relkind in ('r','v') + ) AND a.attnum > 0 AND NOT a.attisdropped; + """ % schema_where_clause + if isinstance(table_name, str): + table_name = table_name.decode(self.encoding) + if isinstance(schema, str): + schema = schema.decode(self.encoding) + s = sql.text(query, bindparams=[sql.bindparam('table_name', type_=sqltypes.Unicode), sql.bindparam('schema', type_=sqltypes.Unicode)]) + c = connection.execute(s, table_name=table_name, schema=schema) + rows = c.fetchall() + if not rows: + raise exc.NoSuchTableError(table_name) + return rows[0].table_oid + + def get_columns(self, connection, table_name, schema=None, info_cache=None): + preparer = self.identifier_preparer + schema_where_clause = self.__make_schema_where_clause(schema) + schemaname = schema + # This could probably be simplified since we have _get_table_oid. 
</ins><span class="cx"> SQL_COLS = """ </span><span class="cx"> SELECT a.attname, </span><span class="cx"> pg_catalog.format_type(a.atttypid, a.atttypmod), </span><span class="lines">@@ -405,31 +428,27 @@ </span><span class="cx"> """ % schema_where_clause </span><span class="cx"> </span><span class="cx"> s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type_=sqltypes.Unicode), sql.bindparam('schema', type_=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode, 'default':sqltypes.Unicode}) </span><del>- tablename = table.name </del><ins>+ tablename = table_name </ins><span class="cx"> if isinstance(tablename, str): </span><span class="cx"> tablename = tablename.decode(self.encoding) </span><ins>+ if isinstance(schemaname, str): + schemaname = schemaname.decode(self.encoding) </ins><span class="cx"> c = connection.execute(s, table_name=tablename, schema=schemaname) </span><span class="cx"> rows = c.fetchall() </span><del>- </del><span class="cx"> if not rows: </span><span class="cx"> raise exc.NoSuchTableError(table.name) </span><del>- </del><span class="cx"> domains = self._load_domains(connection) </span><del>- </del><ins>+ # Format the results. 
+ columns = [] </ins><span class="cx"> for name, format_type, default, notnull, attnum, table_oid in rows: </span><del>- if include_columns and name not in include_columns: - continue - </del><span class="cx"> ## strip (30) from character varying(30) </span><span class="cx"> attype = re.search('([^\([]+)', format_type).group(1) </span><span class="cx"> nullable = not notnull </span><span class="cx"> is_array = format_type.endswith('[]') </span><del>- </del><span class="cx"> try: </span><span class="cx"> charlen = re.search('\(([\d,]+)\)', format_type).group(1) </span><span class="cx"> except: </span><span class="cx"> charlen = False </span><del>- </del><span class="cx"> numericprec = False </span><span class="cx"> numericscale = False </span><span class="cx"> if attype == 'numeric': </span><span class="lines">@@ -444,20 +463,17 @@ </span><span class="cx"> if attype == 'integer': </span><span class="cx"> numericprec, numericscale = (32, 0) </span><span class="cx"> charlen = False </span><del>- </del><span class="cx"> args = [] </span><span class="cx"> for a in (charlen, numericprec, numericscale): </span><span class="cx"> if a is None: </span><span class="cx"> args.append(None) </span><span class="cx"> elif a is not False: </span><span class="cx"> args.append(int(a)) </span><del>- </del><span class="cx"> kwargs = {} </span><span class="cx"> if attype == 'timestamp with time zone': </span><span class="cx"> kwargs['timezone'] = True </span><span class="cx"> elif attype == 'timestamp without time zone': </span><span class="cx"> kwargs['timezone'] = False </span><del>- </del><span class="cx"> if attype in ischema_names: </span><span class="cx"> coltype = ischema_names[attype] </span><span class="cx"> else: </span><span class="lines">@@ -466,14 +482,12 @@ </span><span class="cx"> if domain['attype'] in ischema_names: </span><span class="cx"> # A table can't override whether the domain is nullable. 
</span><span class="cx"> nullable = domain['nullable'] </span><del>- </del><span class="cx"> if domain['default'] and not default: </span><span class="cx"> # It can, however, override the default value, but can't set it to null. </span><span class="cx"> default = domain['default'] </span><span class="cx"> coltype = ischema_names[domain['attype']] </span><span class="cx"> else: </span><span class="cx"> coltype = None </span><del>- </del><span class="cx"> if coltype: </span><span class="cx"> coltype = coltype(*args, **kwargs) </span><span class="cx"> if is_array: </span><span class="lines">@@ -482,75 +496,98 @@ </span><span class="cx"> util.warn("Did not recognize type '%s' of column '%s'" % </span><span class="cx"> (attype, name)) </span><span class="cx"> coltype = sqltypes.NULLTYPE </span><ins>+ other_args = [] + columns.append((name, coltype, nullable, default, other_args)) + if table_oid is not None and info_cache is not None: + info_cache['table_oid'] = table_oid + return columns </ins><span class="cx"> </span><del>- colargs = [] - if default is not None: - match = re.search(r"""(nextval\(')([^']+)('.*$)""", default) - if match is not None: - # the default is related to a Sequence - sch = table.schema - if '.' not in match.group(2) and sch is not None: - # unconditionally quote the schema name. this could - # later be enhanced to obey quoting rules / "quote schema" - default = match.group(1) + ('"%s"' % sch) + '.' 
+ match.group(2) + match.group(3) - colargs.append(schema.DefaultClause(sql.text(default))) - table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs)) - - - # Primary keys </del><ins>+ def get_primary_keys(self, connection, table_name, schema=None, + info_cache=None): + if info_cache is None: + info_cache = {} </ins><span class="cx"> PK_SQL = """ </span><del>- SELECT attname FROM pg_attribute </del><ins>+ SELECT attname AS colname FROM pg_attribute </ins><span class="cx"> WHERE attrelid = ( </span><span class="cx"> SELECT indexrelid FROM pg_index i </span><del>- WHERE i.indrelid = :table </del><ins>+ WHERE i.indrelid = :table_oid </ins><span class="cx"> AND i.indisprimary = 't') </span><span class="cx"> ORDER BY attnum </span><span class="cx"> """ </span><span class="cx"> t = sql.text(PK_SQL, typemap={'attname':sqltypes.Unicode}) </span><del>- c = connection.execute(t, table=table_oid) - for row in c.fetchall(): - pk = row[0] - col = table.c[pk] - table.primary_key.add(col) - if col.default is None: - col.autoincrement = False </del><ins>+ table_oid = info_cache.get('table_oid') + if table_oid is None: + table_oid = self._get_table_oid(connection, table_name, schema) + info_cache['table_oid'] = table_oid + c = connection.execute(t, table_oid=table_oid) + return [tuple(r) for r in c.fetchall()] </ins><span class="cx"> </span><del>- # Foreign keys </del><ins>+ def get_foreign_keys(self, connection, table_name, schema=None, + info_cache=None): + if info_cache is None: + info_cache = {} + preparer = self.identifier_preparer </ins><span class="cx"> FK_SQL = """ </span><span class="cx"> SELECT conname, pg_catalog.pg_get_constraintdef(oid, true) as condef </span><span class="cx"> FROM pg_catalog.pg_constraint r </span><span class="cx"> WHERE r.conrelid = :table AND r.contype = 'f' </span><span class="cx"> ORDER BY 1 </span><span class="cx"> """ </span><del>- </del><span class="cx"> t = sql.text(FK_SQL, typemap={'conname':sqltypes.Unicode, 
'condef':sqltypes.Unicode}) </span><ins>+ table_oid = info_cache.get('table_oid') + if table_oid is None: + table_oid = self._get_table_oid(connection, table_name, schema) + info_cache['table_oid'] = table_oid </ins><span class="cx"> c = connection.execute(t, table=table_oid) </span><ins>+ fkeys = [] </ins><span class="cx"> for conname, condef in c.fetchall(): </span><span class="cx"> m = re.search('FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)', condef).groups() </span><span class="cx"> (constrained_columns, referred_schema, referred_table, referred_columns) = m </span><span class="cx"> constrained_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s*', constrained_columns)] </span><del>- if referred_schema: - referred_schema = preparer._unquote_identifier(referred_schema) - elif table.schema is not None and table.schema == self.get_default_schema_name(connection): - # no schema (i.e. its the default schema), and the table we're - # reflecting has the default schema explicit, then use that. - # i.e. 
try to use the user's conventions - referred_schema = table.schema - referred_table = preparer._unquote_identifier(referred_table) - referred_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s', referred_columns)] </del><ins>+ referred_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s*', referred_columns)] + fkeys.append((conname, constrained_columns, referred_schema, + referred_table, referred_columns)) + return fkeys </ins><span class="cx"> </span><del>- refspec = [] - if referred_schema is not None: - schema.Table(referred_table, table.metadata, autoload=True, schema=referred_schema, - autoload_with=connection) - for column in referred_columns: - refspec.append(".".join([referred_schema, referred_table, column])) - else: - schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection) - for column in referred_columns: - refspec.append(".".join([referred_table, column])) </del><ins>+ def server_version_info(self, connection): + v = connection.execute("select version()").scalar() + m = re.match('PostgreSQL (\d+)\.(\d+)\.(\d+)', v) + if not m: + raise AssertionError("Could not determine version from string '%s'" % v) + return tuple([int(x) for x in m.group(1, 2, 3)]) </ins><span class="cx"> </span><del>- table.append_constraint(schema.ForeignKeyConstraint(constrained_columns, refspec, conname)) </del><ins>+ def reflecttable(self, connection, table, include_columns): + info_cache = {} + preparer = self.identifier_preparer + # Columns + columns = self.get_columns(connection, table.name, table.schema, + info_cache) + for (name, coltype, nullable, default, other_args) in columns: + if include_columns and name not in include_columns: + continue + colargs = [] + if default is not None: + match = re.search(r"""(nextval\(')([^']+)('.*$)""", default) + if match is not None: + # the default is related to a Sequence + if '.' 
not in match.group(2) and table.schema is not None: + # unconditionally quote the schema name. this could + # later be enhanced to obey quoting rules / "quote schema" + default = match.group(1) + ('"%s"' % table.schema) + '.' + match.group(2) + match.group(3) + colargs.append(schema.DefaultClause(sql.text(default))) + table.append_column(schema.Column(name, coltype, nullable=nullable, + *colargs)) + # Primary Keys + for row in self.get_primary_keys(connection, table.name, table.schema, + info_cache): + pk = row[0] + col = table.c[pk] + table.primary_key.add(col) + if col.default is None: + col.autoincrement = False + # Foreign keys + self._reflect_foreign_keys(connection, table, info_cache) + return None </ins><span class="cx"> </span><span class="cx"> def _load_domains(self, connection): </span><span class="cx"> ## Load data types for domains: </span></span></pre></div> <a id="sqlalchemybranchesreflectionlibsqlalchemyenginebasepy"></a> <div class="modfile"><h4>Modified: sqlalchemy/branches/reflection/lib/sqlalchemy/engine/base.py (5323 => 5324)</h4> <pre class="diff"><span> <span class="info">--- sqlalchemy/branches/reflection/lib/sqlalchemy/engine/base.py 2008-11-22 19:22:42 UTC (rev 5323) +++ sqlalchemy/branches/reflection/lib/sqlalchemy/engine/base.py 2008-11-22 19:29:20 UTC (rev 5324) </span><span class="lines">@@ -144,6 +144,40 @@ </span><span class="cx"> names. </span><span class="cx"> """ </span><span class="cx"> </span><ins>+ def _reflect_foreign_keys(self, connection, table, info_cache=None): + """Reflect foreign keys onto `table`. + + If there's no special behaviour required in a subclass, it can call + this method for foreign key reflection. 
+ + """ + # Foreign keys (attempting same approache as postgresql.py) + preparer = self.identifier_preparer + fkeys = self.get_foreign_keys(connection, table.name, table.schema, + info_cache) + for (constraint_name, constrained_columns, referred_schema, + referred_table, referred_columns) in fkeys: + if referred_schema: + referred_schema = preparer._unquote_identifier(referred_schema) + elif table.schema is not None and table.schema == self.get_default_schema_name(connection): + # no schema (i.e. its the default schema), and the table we're + # reflecting has the default schema explicit, then use that. + # i.e. try to use the user's conventions + referred_schema = table.schema + referred_table = preparer._unquote_identifier(referred_table) + referred_columns = [preparer._unquote_identifier(x) for x in referred_columns] + refspec = [] + if referred_schema is not None: + schema.Table(referred_table, table.metadata, autoload=True, schema=referred_schema, + autoload_with=connection) + for column in referred_columns: + refspec.append(".".join([referred_schema, referred_table, column])) + else: + schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection) + for column in referred_columns: + refspec.append(".".join([referred_table, column])) + table.append_constraint(schema.ForeignKeyConstraint(constrained_columns, refspec, constraint_name)) + </ins><span class="cx"> def get_columns(self, connection, table_name, schema=None, info_cache=None): </span><span class="cx"> """Return information about columns in `table_name`. </span><span class="cx"> </span></span></pre> </div> </div> </body> </html> |