[Sqlalchemy-tickets] Issue #3781: implement version_id for bulk_save / bulk_update (zzzeek/sqlalchemy)
Brought to you by:
zzzeek
From: Michael B. <iss...@bi...> - 2016-08-26 14:43:02
|
New issue 3781: implement version_id for bulk_save / bulk_update https://bitbucket.org/zzzeek/sqlalchemy/issues/3781/implement-version_id-for-bulk_save Michael Bayer: ``` #!diff diff --git a/test/orm/test_bulk.py b/test/orm/test_bulk.py index 7e1b052..6acd6d2 100644 --- a/test/orm/test_bulk.py +++ b/test/orm/test_bulk.py @@ -13,6 +13,56 @@ class BulkTest(testing.AssertsExecutionResults): run_define_tables = 'each' +class BulkInsertUpdateVersionId(BulkTest, fixtures.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table('version_table', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('version_id', Integer, nullable=False), + Column('value', String(40), nullable=False)) + + @classmethod + def setup_classes(cls): + class Foo(cls.Comparable): + pass + + @classmethod + def setup_mappers(cls): + Foo, version_table = cls.classes.Foo, cls.tables.version_table + + mapper(Foo, version_table, version_id_col=version_table.c.version_id) + + def test_bulk_insert_via_save(self): + Foo = self.classes.Foo + + s = Session() + + s.bulk_save_objects([Foo(value='value')]) + + eq_( + s.query(Foo).all(), + [Foo(version_id=1, value='value')] + ) + + def test_bulk_update_via_save(self): + Foo = self.classes.Foo + + s = Session() + + s.add(Foo(value='value')) + s.commit() + + f1 = s.query(Foo).first() + f1.value = 'new value' + s.bulk_save_objects([f1]) + + eq_( + s.query(Foo).all(), + [Foo(version_id=2, value='new value')] + ) + + class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest): @classmethod ``` patch: ``` #!diff diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 5d69f51..467f47f 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -82,11 +82,15 @@ def _bulk_update(mapper, mappings, session_transaction, cached_connections = _cached_connection_dict(base_mapper) + search_keys = mapper._primary_key_propkeys + if mapper._version_id_prop: + 
search_keys = set([mapper._version_id_prop.key]).union(search_keys) + def _changed_dict(mapper, state): return dict( (k, v) for k, v in state.dict.items() if k in state.committed_state or k - in mapper._primary_key_propkeys + in search_keys ) if isstates: ``` Also tricky: if you do a bulk update on an object that's also in the session, the object you have locally is now stale, because bulk operations don't do any object bookkeeping. also, tricky, need to close |