diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/dialects/mssql/base.py | 13 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/mysql/base.py | 31 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/oracle/base.py | 13 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/base.py | 65 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/interfaces.py | 13 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/reflection.py | 35 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/util.py | 9 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/context.py | 86 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/query.py | 16 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/base.py | 10 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/fixtures.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/suite/test_sequence.py | 130 |
12 files changed, 291 insertions, 134 deletions
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index c35ab2880..5aaecf23a 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -2640,6 +2640,19 @@ class MSDialect(default.DefaultDialect): return c.first() is not None @reflection.cache + @_db_plus_owner_listing + def get_sequence_names(self, connection, dbname, owner, schema, **kw): + sequences = ischema.sequences + + s = sql.select([sequences.c.sequence_name]) + if owner: + s = s.where(sequences.c.sequence_schema == owner) + + c = connection.execute(s) + + return [row[0] for row in c] + + @reflection.cache def get_schema_names(self, connection, **kw): s = sql.select( [ischema.schemata.c.schema_name], diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index d009d656e..b34422e65 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -2517,6 +2517,8 @@ class MySQLDialect(default.DefaultDialect): rs.close() def has_sequence(self, connection, sequence_name, schema=None): + if not self.supports_sequences: + self._sequences_not_supported() if not schema: schema = self.default_schema_name # MariaDB implements sequences as a special type of table @@ -2524,13 +2526,40 @@ class MySQLDialect(default.DefaultDialect): cursor = connection.execute( sql.text( "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES " - "WHERE TABLE_NAME=:name AND " + "WHERE TABLE_TYPE='SEQUENCE' and TABLE_NAME=:name AND " "TABLE_SCHEMA=:schema_name" ), dict(name=sequence_name, schema_name=schema), ) return cursor.first() is not None + def _sequences_not_supported(self): + raise NotImplementedError( + "Sequences are supported only by the " + "MariaDB series 10.3 or greater" + ) + + @reflection.cache + def get_sequence_names(self, connection, schema=None, **kw): + if not self.supports_sequences: + self._sequences_not_supported() + if not schema: + schema = self.default_schema_name + # MariaDB implements sequences as a special type of table + cursor = connection.execute( + sql.text( + "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES " + "WHERE TABLE_TYPE='SEQUENCE' and TABLE_SCHEMA=:schema_name" + ), + dict(schema_name=schema), + ) + return [ + row[0] + for row in self._compat_fetchall( + cursor, charset=self._connection_charset + ) + ] + def initialize(self, connection): self._connection_charset = self._detect_charset(connection) self._detect_sql_mode(connection) diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 481ea7263..5e912a0c2 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -1635,6 +1635,19 @@ class OracleDialect(default.DefaultDialect): return [self.normalize_name(row[0]) for row in cursor] @reflection.cache + def get_sequence_names(self, connection, schema=None, **kw): + if not schema: + schema = self.default_schema_name + cursor = connection.execute( + sql.text( + "SELECT sequence_name FROM all_sequences " + "WHERE sequence_owner = :schema_name" + ), + schema_name=self.denormalize_name(schema), + ) + return [self.normalize_name(row[0]) for row in cursor] + + @reflection.cache def get_table_options(self, connection, table_name, schema=None, **kw): options = {} diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 924cd6908..f3e775354 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -2726,39 +2726,23 @@ class PGDialect(default.DefaultDialect): def has_sequence(self, connection, sequence_name, schema=None): if schema is None: - cursor = connection.execute( - sql.text( - "SELECT relname FROM pg_class c join pg_namespace n on " - "n.oid=c.relnamespace where relkind='S' and " - "n.nspname=current_schema() " - "and relname=:name" - ).bindparams( - sql.bindparam( - "name", - util.text_type(sequence_name), - type_=sqltypes.Unicode, - ) - ) - ) - else: - cursor = connection.execute( - sql.text( - "SELECT relname FROM pg_class c join pg_namespace n on " - "n.oid=c.relnamespace where relkind='S' and " - "n.nspname=:schema and relname=:name" - ).bindparams( - sql.bindparam( - "name", - util.text_type(sequence_name), - type_=sqltypes.Unicode, - ), - sql.bindparam( - "schema", - util.text_type(schema), - type_=sqltypes.Unicode, - ), - ) + schema = self.default_schema_name + cursor = connection.execute( + sql.text( + "SELECT relname FROM pg_class c join pg_namespace n on " + "n.oid=c.relnamespace where relkind='S' and " + "n.nspname=:schema and relname=:name" + ).bindparams( + sql.bindparam( + "name", + util.text_type(sequence_name), + type_=sqltypes.Unicode, + ), + sql.bindparam( + "schema", util.text_type(schema), type_=sqltypes.Unicode, + ), ) + ) return bool(cursor.first()) @@ -2915,6 +2899,23 @@ class PGDialect(default.DefaultDialect): return [name for name, in result] @reflection.cache + def get_sequence_names(self, connection, schema=None, **kw): + if not schema: + schema = self.default_schema_name + cursor = connection.execute( + sql.text( + "SELECT relname FROM pg_class c join pg_namespace n on " + "n.oid=c.relnamespace where relkind='S' and " + "n.nspname=:schema" + ).bindparams( + sql.bindparam( + "schema", util.text_type(schema), type_=sqltypes.Unicode, + ), + ) + ) + return [row[0] for row in cursor] + + @reflection.cache def get_view_definition(self, connection, view_name, schema=None, **kw): view_def = connection.scalar( sql.text( diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index 49d9af966..59b9cd4ce 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -304,8 +304,17 @@ class Dialect(object): def get_view_names(self, connection, schema=None, **kw): """Return a list of all view names available in the database. - schema: - Optional, retrieve names from a non-default schema. + :param schema: schema name to query, if not the default schema. + """ + + raise NotImplementedError() + + def get_sequence_names(self, connection, schema=None, **kw): + """Return a list of all sequence names available in the database. + + :param schema: schema name to query, if not the default schema. + + .. versionadded:: 1.4 """ raise NotImplementedError() diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index 344d5511d..fded37b2a 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -255,11 +255,6 @@ class Inspector(object): support named schemas, behavior is undefined if ``schema`` is not passed as ``None``. For special quoting, use :class:`.quoted_name`. - :param order_by: Optional, may be the string "foreign_key" to sort - the result on foreign key dependencies. Does not automatically - resolve cycles, and will raise :class:`.CircularDependencyError` - if cycles exist. - .. seealso:: :meth:`_reflection.Inspector.get_sorted_table_and_fkc_names` @@ -276,6 +271,10 @@ class Inspector(object): def has_table(self, table_name, schema=None): """Return True if the backend has a table of the given name. + + :param table_name: name of the table to check + :param schema: schema name to query, if not the default schema. + .. versionadded:: 1.4 """ @@ -283,6 +282,19 @@ class Inspector(object): with self._operation_context() as conn: return self.dialect.has_table(conn, table_name, schema) + def has_sequence(self, sequence_name, schema=None): + """Return True if the backend has a table of the given name. + + :param sequence_name: name of the table to check + :param schema: schema name to query, if not the default schema. + + .. versionadded:: 1.4 + + """ + # TODO: info_cache? + with self._operation_context() as conn: + return self.dialect.has_sequence(conn, sequence_name, schema) + def get_sorted_table_and_fkc_names(self, schema=None): """Return dependency-sorted table and foreign key constraint names in referred to within a particular schema. @@ -401,6 +413,19 @@ class Inspector(object): conn, schema, info_cache=self.info_cache ) + def get_sequence_names(self, schema=None): + """Return all sequence names in `schema`. + + :param schema: Optional, retrieve names from a non-default schema. + For special quoting, use :class:`.quoted_name`. + + """ + + with self._operation_context() as conn: + return self.dialect.get_sequence_names( + conn, schema, info_cache=self.info_cache + ) + def get_view_definition(self, view_name, schema=None): """Return definition for `view_name`. diff --git a/lib/sqlalchemy/engine/util.py b/lib/sqlalchemy/engine/util.py index 8fb04646f..fc0260ae2 100644 --- a/lib/sqlalchemy/engine/util.py +++ b/lib/sqlalchemy/engine/util.py @@ -8,6 +8,7 @@ from .. import exc from .. import util from ..util import collections_abc +from ..util import immutabledict def connection_memoize(key): @@ -85,9 +86,11 @@ _no_kw = util.immutabledict() def _distill_params_20(params): + # TODO: this has to be in C if params is None: return _no_tuple, _no_kw, [] - elif isinstance(params, collections_abc.MutableSequence): # list + elif isinstance(params, list): + # collections_abc.MutableSequence): # avoid abc.__instancecheck__ if params and not isinstance( params[0], (collections_abc.Mapping, tuple) ): @@ -99,7 +102,9 @@ def _distill_params_20(params): return tuple(params), _no_kw, params elif isinstance( params, - (collections_abc.Sequence, collections_abc.Mapping), # tuple or dict + (tuple, dict, immutabledict), + # avoid abc.__instancecheck__ + # (collections_abc.Sequence, collections_abc.Mapping), ): return _no_tuple, params, [params] else: diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index 588b83571..f380229e1 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -2245,7 +2245,7 @@ class _BundleEntity(_QueryEntity): class _ColumnEntity(_QueryEntity): - __slots__ = () + __slots__ = ("_fetch_column", "_row_processor") @classmethod def _for_columns(cls, compile_state, columns, parent_bundle=None): @@ -2275,6 +2275,44 @@ class _ColumnEntity(_QueryEntity): def use_id_for_hash(self): return not self.column.type.hashable + def row_processor(self, context, result): + compile_state = context.compile_state + + # the resulting callable is entirely cacheable so just return + # it if we already made one + if self._row_processor is not None: + return self._row_processor + + # retrieve the column that would have been set up in + # setup_compile_state, to avoid doing redundant work + if self._fetch_column is not None: + column = self._fetch_column + else: + # fetch_column will be None when we are doing a from_statement + # and setup_compile_state may not have been called. + column = self.column + + # previously, the RawColumnEntity didn't look for from_obj_alias + # however I can't think of a case where we would be here and + # we'd want to ignore it if this is the from_statement use case. + # it's not really a use case to have raw columns + from_statement + if compile_state._from_obj_alias: + column = compile_state._from_obj_alias.columns[column] + + if column._annotations: + # annotated columns perform more slowly in compiler and + # result due to the __eq__() method, so use deannotated + column = column._deannotate() + + if compile_state.compound_eager_adapter: + column = compile_state.compound_eager_adapter.columns[column] + + getter = result._getter(column) + + ret = getter, self._label_name, self._extra_entities + self._row_processor = ret + return ret + class _RawColumnEntity(_ColumnEntity): entity_zero = None @@ -2303,28 +2341,11 @@ class _RawColumnEntity(_ColumnEntity): self.column._from_objects[0] if self.column._from_objects else None ) self._extra_entities = (self.expr, self.column) + self._fetch_column = self._row_processor = None def corresponds_to(self, entity): return False - def row_processor(self, context, result): - if ("fetch_column", self) in context.attributes: - column = context.attributes[("fetch_column", self)] - else: - column = self.column - - if column._annotations: - # annotated columns perform more slowly in compiler and - # result due to the __eq__() method, so use deannotated - column = column._deannotate() - - compile_state = context.compile_state - if compile_state.compound_eager_adapter: - column = compile_state.compound_eager_adapter.columns[column] - - getter = result._getter(column) - return getter, self._label_name, self._extra_entities - def setup_compile_state(self, compile_state): current_adapter = compile_state._get_current_adapter() if current_adapter: @@ -2338,7 +2359,7 @@ class _RawColumnEntity(_ColumnEntity): column = column._deannotate() compile_state.primary_columns.append(column) - compile_state.attributes[("fetch_column", self)] = column + self._fetch_column = column class _ORMColumnEntity(_ColumnEntity): @@ -2386,6 +2407,7 @@ class _ORMColumnEntity(_ColumnEntity): compile_state._has_orm_entities = True self.column = column + self._fetch_column = self._row_processor = None self._extra_entities = (self.expr, self.column) @@ -2407,27 +2429,6 @@ class _ORMColumnEntity(_ColumnEntity): self.entity_zero ) and entity.common_parent(self.entity_zero) - def row_processor(self, context, result): - compile_state = context.compile_state - - if ("fetch_column", self) in context.attributes: - column = context.attributes[("fetch_column", self)] - else: - column = self.column - if compile_state._from_obj_alias: - column = compile_state._from_obj_alias.columns[column] - - if column._annotations: - # annotated columns perform more slowly in compiler and - # result due to the __eq__() method, so use deannotated - column = column._deannotate() - - if compile_state.compound_eager_adapter: - column = compile_state.compound_eager_adapter.columns[column] - - getter = result._getter(column) - return getter, self._label_name, self._extra_entities - def setup_compile_state(self, compile_state): current_adapter = compile_state._get_current_adapter() if current_adapter: @@ -2460,5 +2461,4 @@ class _ORMColumnEntity(_ColumnEntity): compile_state._fallback_from_clauses.append(ezero.selectable) compile_state.primary_columns.append(column) - - compile_state.attributes[("fetch_column", self)] = column + self._fetch_column = column diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index cdad55320..458217e22 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1637,8 +1637,7 @@ class Query( "params() takes zero or one positional argument, " "which is a dictionary." ) - params = dict(self.load_options._params) - params.update(kwargs) + params = self.load_options._params.union(kwargs) self.load_options += {"_params": params} @_generative @@ -1965,6 +1964,19 @@ class Query( join(Order.items).\ join(Item.keywords) + .. note:: as seen in the above example, **the order in which each + call to the join() method occurs is important**. Query would not, + for example, know how to join correctly if we were to specify + ``User``, then ``Item``, then ``Order``, in our chain of joins; in + such a case, depending on the arguments passed, it may raise an + error that it doesn't know how to join, or it may produce invalid + SQL in which case the database will raise an error. In correct + practice, the + :meth:`_query.Query.join` method is invoked in such a way that lines + up with how we would want the JOIN clauses in SQL to be + rendered, and each call should represent a clear link from what + precedes it. + **Joins to a Target Entity or Selectable** A second form of :meth:`_query.Query.join` allows any mapped entity or diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 5f2ce8f14..4c603b6dd 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -495,8 +495,14 @@ class Generative(HasMemoized): def _generate(self): skip = self._memoized_keys - s = self.__class__.__new__(self.__class__) - s.__dict__ = {k: v for k, v in self.__dict__.items() if k not in skip} + cls = self.__class__ + s = cls.__new__(cls) + if skip: + s.__dict__ = { + k: v for k, v in self.__dict__.items() if k not in skip + } + else: + s.__dict__ = self.__dict__.copy() return s diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index 2cc34448d..1eac76598 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -136,6 +136,7 @@ class TablesTest(TestBase): metadata = None tables = None other = None + sequences = None @classmethod def setup_class(cls): @@ -154,6 +155,7 @@ class TablesTest(TestBase): cls.other = adict() cls.tables = adict() + cls.sequences = adict() cls.bind = cls.setup_bind() cls.metadata = sa.MetaData() @@ -173,6 +175,7 @@ class TablesTest(TestBase): if cls.run_create_tables == "once": cls.metadata.create_all(cls.bind) cls.tables.update(cls.metadata.tables) + cls.sequences.update(cls.metadata._sequences) def _setup_each_tables(self): if self.run_define_tables == "each": @@ -180,6 +183,7 @@ class TablesTest(TestBase): if self.run_create_tables == "each": self.metadata.create_all(self.bind) self.tables.update(self.metadata.tables) + self.sequences.update(self.metadata._sequences) elif self.run_create_tables == "each": self.metadata.create_all(self.bind) diff --git a/lib/sqlalchemy/testing/suite/test_sequence.py b/lib/sqlalchemy/testing/suite/test_sequence.py index dda447c0d..55e8e8406 100644 --- a/lib/sqlalchemy/testing/suite/test_sequence.py +++ b/lib/sqlalchemy/testing/suite/test_sequence.py @@ -1,12 +1,13 @@ from .. import config from .. import fixtures from ..assertions import eq_ +from ..assertions import is_true from ..config import requirements from ..schema import Column from ..schema import Table +from ... import inspect from ... import Integer from ... import MetaData -from ... import schema from ... import Sequence from ... import String from ... import testing @@ -88,69 +89,108 @@ class SequenceCompilerTest(testing.AssertsCompiledSQL, fixtures.TestBase): ) -class HasSequenceTest(fixtures.TestBase): +class HasSequenceTest(fixtures.TablesTest): + run_deletes = None + __requires__ = ("sequences",) __backend__ = True - def test_has_sequence(self, connection): - s1 = Sequence("user_id_seq") - connection.execute(schema.CreateSequence(s1)) - try: - eq_( - connection.dialect.has_sequence(connection, "user_id_seq"), - True, + @classmethod + def define_tables(cls, metadata): + Sequence("user_id_seq", metadata=metadata) + Sequence("other_seq", metadata=metadata) + if testing.requires.schemas.enabled: + Sequence( + "user_id_seq", schema=config.test_schema, metadata=metadata + ) + Sequence( + "schema_seq", schema=config.test_schema, metadata=metadata ) - finally: - connection.execute(schema.DropSequence(s1)) + Table( + "user_id_table", metadata, Column("id", Integer, primary_key=True), + ) + + def test_has_sequence(self, connection): + eq_( + inspect(connection).has_sequence("user_id_seq"), True, + ) + + def test_has_sequence_other_object(self, connection): + eq_( + inspect(connection).has_sequence("user_id_table"), False, + ) @testing.requires.schemas def test_has_sequence_schema(self, connection): - s1 = Sequence("user_id_seq", schema=config.test_schema) - connection.execute(schema.CreateSequence(s1)) - try: - eq_( - connection.dialect.has_sequence( - connection, "user_id_seq", schema=config.test_schema - ), - True, - ) - finally: - connection.execute(schema.DropSequence(s1)) + eq_( + inspect(connection).has_sequence( + "user_id_seq", schema=config.test_schema + ), + True, + ) def test_has_sequence_neg(self, connection): - eq_(connection.dialect.has_sequence(connection, "user_id_seq"), False) + eq_( + inspect(connection).has_sequence("some_sequence"), False, + ) @testing.requires.schemas def test_has_sequence_schemas_neg(self, connection): eq_( - connection.dialect.has_sequence( - connection, "user_id_seq", schema=config.test_schema + inspect(connection).has_sequence( + "some_sequence", schema=config.test_schema ), False, ) @testing.requires.schemas def test_has_sequence_default_not_in_remote(self, connection): - s1 = Sequence("user_id_seq") - connection.execute(schema.CreateSequence(s1)) - try: - eq_( - connection.dialect.has_sequence( - connection, "user_id_seq", schema=config.test_schema - ), - False, - ) - finally: - connection.execute(schema.DropSequence(s1)) + eq_( + inspect(connection).has_sequence( + "other_sequence", schema=config.test_schema + ), + False, + ) @testing.requires.schemas def test_has_sequence_remote_not_in_default(self, connection): - s1 = Sequence("user_id_seq", schema=config.test_schema) - connection.execute(schema.CreateSequence(s1)) - try: - eq_( - connection.dialect.has_sequence(connection, "user_id_seq"), - False, - ) - finally: - connection.execute(schema.DropSequence(s1)) + eq_( + inspect(connection).has_sequence("schema_seq"), False, + ) + + def test_get_sequence_names(self, connection): + exp = {"other_seq", "user_id_seq"} + + res = set(inspect(connection).get_sequence_names()) + is_true(res.intersection(exp) == exp) + is_true("schema_seq" not in res) + + @testing.requires.schemas + def test_get_sequence_names_no_sequence_schema(self, connection): + eq_( + inspect(connection).get_sequence_names( + schema=config.test_schema_2 + ), + [], + ) + + @testing.requires.schemas + def test_get_sequence_names_sequences_schema(self, connection): + eq_( + sorted( + inspect(connection).get_sequence_names( + schema=config.test_schema + ) + ), + ["schema_seq", "user_id_seq"], + ) + + +class HasSequenceTestEmpty(fixtures.TestBase): + __requires__ = ("sequences",) + __backend__ = True + + def test_get_sequence_names_no_sequence(self, connection): + eq_( + inspect(connection).get_sequence_names(), [], + ) |
