summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2020-05-21 19:03:32 +0000
committerGerrit Code Review <gerrit@bbpush.zzzcomputing.com>2020-05-21 19:03:32 +0000
commitd45657a2f5b880dc22dda2d1eb1687af5234a470 (patch)
tree5cebf0c4b0d9f12071176bbdc8a4de47cb31b151 /lib/sqlalchemy
parentb67548ad788fc0eb8782dfd5a1d2a016dc5c7f78 (diff)
parent4550983e0ce2f35b3585e53894c941c23693e71d (diff)
downloadsqlalchemy-d45657a2f5b880dc22dda2d1eb1687af5234a470.tar.gz
Merge "Performance fixes for new result set"
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/cextension/resultproxy.c137
-rw-r--r--lib/sqlalchemy/engine/base.py4
-rw-r--r--lib/sqlalchemy/engine/cursor.py86
-rw-r--r--lib/sqlalchemy/engine/default.py131
-rw-r--r--lib/sqlalchemy/engine/result.py103
-rw-r--r--lib/sqlalchemy/engine/row.py103
-rw-r--r--lib/sqlalchemy/orm/loading.py16
-rw-r--r--lib/sqlalchemy/orm/query.py2
-rw-r--r--lib/sqlalchemy/sql/compiler.py19
-rw-r--r--lib/sqlalchemy/sql/selectable.py29
-rw-r--r--lib/sqlalchemy/testing/assertions.py4
-rw-r--r--lib/sqlalchemy/util/__init__.py3
-rw-r--r--lib/sqlalchemy/util/_collections.py8
13 files changed, 415 insertions, 230 deletions
diff --git a/lib/sqlalchemy/cextension/resultproxy.c b/lib/sqlalchemy/cextension/resultproxy.c
index d5a6ea0c8..244379116 100644
--- a/lib/sqlalchemy/cextension/resultproxy.c
+++ b/lib/sqlalchemy/cextension/resultproxy.c
@@ -45,12 +45,19 @@ typedef struct {
PyObject *parent;
PyObject *row;
PyObject *keymap;
+ long key_style;
} BaseRow;
static PyObject *sqlalchemy_engine_row = NULL;
static PyObject *sqlalchemy_engine_result = NULL;
+
+//static int KEY_INTEGER_ONLY = 0;
+//static int KEY_OBJECTS_ONLY = 1;
+static int KEY_OBJECTS_BUT_WARN = 2;
+//static int KEY_OBJECTS_NO_WARN = 3;
+
/****************
* BaseRow *
****************/
@@ -90,13 +97,13 @@ safe_rowproxy_reconstructor(PyObject *self, PyObject *args)
static int
BaseRow_init(BaseRow *self, PyObject *args, PyObject *kwds)
{
- PyObject *parent, *keymap, *row, *processors;
+ PyObject *parent, *keymap, *row, *processors, *key_style;
Py_ssize_t num_values, num_processors;
PyObject **valueptr, **funcptr, **resultptr;
PyObject *func, *result, *processed_value, *values_fastseq;
- if (!PyArg_UnpackTuple(args, "BaseRow", 4, 4,
- &parent, &processors, &keymap, &row))
+ if (!PyArg_UnpackTuple(args, "BaseRow", 5, 5,
+ &parent, &processors, &keymap, &key_style, &row))
return -1;
Py_INCREF(parent);
@@ -107,44 +114,61 @@ BaseRow_init(BaseRow *self, PyObject *args, PyObject *kwds)
return -1;
num_values = PySequence_Length(values_fastseq);
- num_processors = PySequence_Size(processors);
- if (num_values != num_processors) {
- PyErr_Format(PyExc_RuntimeError,
- "number of values in row (%d) differ from number of column "
- "processors (%d)",
- (int)num_values, (int)num_processors);
- return -1;
+
+
+ if (processors != Py_None) {
+ num_processors = PySequence_Size(processors);
+ if (num_values != num_processors) {
+ PyErr_Format(PyExc_RuntimeError,
+ "number of values in row (%d) differ from number of column "
+ "processors (%d)",
+ (int)num_values, (int)num_processors);
+ return -1;
+ }
+
+ } else {
+ num_processors = -1;
}
result = PyTuple_New(num_values);
if (result == NULL)
return -1;
- valueptr = PySequence_Fast_ITEMS(values_fastseq);
- funcptr = PySequence_Fast_ITEMS(processors);
- resultptr = PySequence_Fast_ITEMS(result);
- while (--num_values >= 0) {
- func = *funcptr;
- if (func != Py_None) {
- processed_value = PyObject_CallFunctionObjArgs(
- func, *valueptr, NULL);
- if (processed_value == NULL) {
- Py_DECREF(values_fastseq);
- Py_DECREF(result);
- return -1;
+ if (num_processors != -1) {
+ valueptr = PySequence_Fast_ITEMS(values_fastseq);
+ funcptr = PySequence_Fast_ITEMS(processors);
+ resultptr = PySequence_Fast_ITEMS(result);
+ while (--num_values >= 0) {
+ func = *funcptr;
+ if (func != Py_None) {
+ processed_value = PyObject_CallFunctionObjArgs(
+ func, *valueptr, NULL);
+ if (processed_value == NULL) {
+ Py_DECREF(values_fastseq);
+ Py_DECREF(result);
+ return -1;
+ }
+ *resultptr = processed_value;
+ } else {
+ Py_INCREF(*valueptr);
+ *resultptr = *valueptr;
}
- *resultptr = processed_value;
- } else {
+ valueptr++;
+ funcptr++;
+ resultptr++;
+ }
+ } else {
+ valueptr = PySequence_Fast_ITEMS(values_fastseq);
+ resultptr = PySequence_Fast_ITEMS(result);
+ while (--num_values >= 0) {
Py_INCREF(*valueptr);
*resultptr = *valueptr;
+ valueptr++;
+ resultptr++;
}
- valueptr++;
- funcptr++;
- resultptr++;
}
Py_DECREF(values_fastseq);
-
self->row = result;
if (!PyDict_CheckExact(keymap)) {
@@ -153,7 +177,7 @@ BaseRow_init(BaseRow *self, PyObject *args, PyObject *kwds)
}
Py_INCREF(keymap);
self->keymap = keymap;
-
+ self->key_style = PyLong_AsLong(key_style);
return 0;
}
@@ -202,7 +226,7 @@ BaseRow_reduce(PyObject *self)
static PyObject *
BaseRow_filter_on_values(BaseRow *self, PyObject *filters)
{
- PyObject *module, *row_class, *new_obj;
+ PyObject *module, *row_class, *new_obj, *key_style;
if (sqlalchemy_engine_row == NULL) {
module = PyImport_ImportModule("sqlalchemy.engine.row");
@@ -216,7 +240,12 @@ BaseRow_filter_on_values(BaseRow *self, PyObject *filters)
// at the same time
row_class = PyObject_GetAttrString(sqlalchemy_engine_row, "Row");
- new_obj = PyObject_CallFunction(row_class, "OOOO", self->parent, filters, self->keymap, self->row);
+ key_style = PyLong_FromLong(self->key_style);
+ Py_INCREF(key_style);
+
+ new_obj = PyObject_CallFunction(
+ row_class, "OOOOO", self->parent, filters, self->keymap,
+ key_style, self->row);
Py_DECREF(row_class);
if (new_obj == NULL) {
return NULL;
@@ -356,7 +385,7 @@ BaseRow_getitem_by_object(BaseRow *self, PyObject *key, int asmapping)
/* -1 can be either the actual value, or an error flag. */
return NULL;
- if (!asmapping) {
+ if (!asmapping && self->key_style == KEY_OBJECTS_BUT_WARN) {
PyObject *tmp;
tmp = PyObject_CallMethod(self->parent, "_warn_for_nonint", "O", key);
@@ -416,7 +445,12 @@ BaseRow_subscript(BaseRow *self, PyObject *key)
static PyObject *
BaseRow_subscript_mapping(BaseRow *self, PyObject *key)
{
- return BaseRow_subscript_impl(self, key, 1);
+ if (self->key_style == KEY_OBJECTS_BUT_WARN) {
+ return BaseRow_subscript_impl(self, key, 0);
+ }
+ else {
+ return BaseRow_subscript_impl(self, key, 1);
+ }
}
@@ -567,6 +601,39 @@ BaseRow_setkeymap(BaseRow *self, PyObject *value, void *closure)
return 0;
}
+static PyObject *
+BaseRow_getkeystyle(BaseRow *self, void *closure)
+{
+    /* Getter for the ``_key_style`` attribute.  PyLong_FromLong already
+       returns a new (owned) reference; an additional Py_INCREF here would
+       leak one reference on every attribute access. */
+    return PyLong_FromLong(self->key_style);
+}
+
+
+static int
+BaseRow_setkeystyle(BaseRow *self, PyObject *value, void *closure)
+{
+    /* Setter for the ``_key_style`` attribute: requires an exact int,
+       rejects deletion, stores the value in self->key_style. */
+    long style;
+
+    if (value == NULL) {
+        PyErr_SetString(
+            PyExc_TypeError,
+            "Cannot delete the 'key_style' attribute");
+        return -1;
+    }
+
+    if (!PyLong_CheckExact(value)) {
+        PyErr_SetString(
+            PyExc_TypeError,
+            "The 'key_style' attribute value must be an integer");
+        return -1;
+    }
+
+    /* PyLong_AsLong returns -1 with an exception set when the value does
+       not fit in a C long; propagate that error instead of silently
+       storing -1 as the key style with a pending exception. */
+    style = PyLong_AsLong(value);
+    if (style == -1 && PyErr_Occurred()) {
+        return -1;
+    }
+
+    self->key_style = style;
+
+    return 0;
+}
+
static PyGetSetDef BaseRow_getseters[] = {
{"_parent",
(getter)BaseRow_getparent, (setter)BaseRow_setparent,
@@ -580,6 +647,10 @@ static PyGetSetDef BaseRow_getseters[] = {
(getter)BaseRow_getkeymap, (setter)BaseRow_setkeymap,
"Key to (obj, index) dict",
NULL},
+ {"_key_style",
+ (getter)BaseRow_getkeystyle, (setter)BaseRow_setkeystyle,
+ "Return the key style",
+ NULL},
{NULL}
};
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index f169655e0..bbfafe8f1 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -1482,8 +1482,10 @@ class Connection(Connectable):
if (
not self._is_future
- and context.should_autocommit
+ # usually we're in a transaction so avoid relatively
+ # expensive / legacy should_autocommit call
and self._transaction is None
+ and context.should_autocommit
):
self._commit_impl(autocommit=True)
diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py
index a886d2025..a393f8da7 100644
--- a/lib/sqlalchemy/engine/cursor.py
+++ b/lib/sqlalchemy/engine/cursor.py
@@ -671,6 +671,8 @@ class CursorResultMetaData(ResultMetaData):
class LegacyCursorResultMetaData(CursorResultMetaData):
+ __slots__ = ()
+
def _contains(self, value, row):
key = value
if key in self._keymap:
@@ -813,17 +815,15 @@ class NoCursorFetchStrategy(ResultFetchStrategy):
"""
- __slots__ = ("closed",)
+ __slots__ = ()
- def __init__(self, closed):
- self.closed = closed
- self.cursor_description = None
+ cursor_description = None
def soft_close(self, result):
pass
def hard_close(self, result):
- self.closed = True
+ pass
def fetchone(self, result):
return self._non_result(result, None)
@@ -849,8 +849,10 @@ class NoCursorDQLFetchStrategy(NoCursorFetchStrategy):
"""
+ __slots__ = ()
+
def _non_result(self, result, default, err=None):
- if self.closed:
+ if result.closed:
util.raise_(
exc.ResourceClosedError("This result object is closed."),
replace_context=err,
@@ -859,6 +861,9 @@ class NoCursorDQLFetchStrategy(NoCursorFetchStrategy):
return default
+_NO_CURSOR_DQL = NoCursorDQLFetchStrategy()
+
+
class NoCursorDMLFetchStrategy(NoCursorFetchStrategy):
"""Cursor strategy for a DML result that has no open cursor.
@@ -867,12 +872,17 @@ class NoCursorDMLFetchStrategy(NoCursorFetchStrategy):
"""
+ __slots__ = ()
+
def _non_result(self, result, default, err=None):
# we only expect to have a _NoResultMetaData() here right now.
assert not result._metadata.returns_rows
result._metadata._we_dont_return_rows(err)
+_NO_CURSOR_DML = NoCursorDMLFetchStrategy()
+
+
class CursorFetchStrategy(ResultFetchStrategy):
"""Call fetch methods from a DBAPI cursor.
@@ -893,15 +903,15 @@ class CursorFetchStrategy(ResultFetchStrategy):
description = dbapi_cursor.description
if description is None:
- return NoCursorDMLFetchStrategy(False)
+ return _NO_CURSOR_DML
else:
return cls(dbapi_cursor, description)
def soft_close(self, result):
- result.cursor_strategy = NoCursorDQLFetchStrategy(False)
+ result.cursor_strategy = _NO_CURSOR_DQL
def hard_close(self, result):
- result.cursor_strategy = NoCursorDQLFetchStrategy(True)
+ result.cursor_strategy = _NO_CURSOR_DQL
def handle_exception(self, result, err):
result.connection._handle_dbapi_exception(
@@ -1016,7 +1026,7 @@ class BufferedRowCursorFetchStrategy(CursorFetchStrategy):
description = dbapi_cursor.description
if description is None:
- return NoCursorDMLFetchStrategy(False)
+ return _NO_CURSOR_DML
else:
max_row_buffer = result.context.execution_options.get(
"max_row_buffer", 1000
@@ -1184,7 +1194,7 @@ class _NoResultMetaData(ResultMetaData):
self._we_dont_return_rows()
-_no_result_metadata = _NoResultMetaData()
+_NO_RESULT_METADATA = _NoResultMetaData()
class BaseCursorResult(object):
@@ -1199,11 +1209,12 @@ class BaseCursorResult(object):
@classmethod
def _create_for_context(cls, context):
+
if context._is_future_result:
- obj = object.__new__(CursorResult)
+ obj = CursorResult(context)
else:
- obj = object.__new__(LegacyCursorResult)
- obj.__init__(context)
+ obj = LegacyCursorResult(context)
+
return obj
def __init__(self, context):
@@ -1214,35 +1225,33 @@ class BaseCursorResult(object):
self._echo = (
self.connection._echo and context.engine._should_log_debug()
)
- self._init_metadata()
- def _init_metadata(self):
- self.cursor_strategy = strat = self.context.get_result_cursor_strategy(
- self
- )
+ # this is a hook used by dialects to change the strategy,
+ # so for the moment we have to keep calling this every time
+ # :(
+ self.cursor_strategy = strat = context.get_result_cursor_strategy(self)
if strat.cursor_description is not None:
- if self.context.compiled:
- if self.context.compiled._cached_metadata:
- cached_md = self.context.compiled._cached_metadata
- self._metadata = cached_md._adapt_to_context(self.context)
+ self._init_metadata(context, strat.cursor_description)
+ else:
+ self._metadata = _NO_RESULT_METADATA
+
+ def _init_metadata(self, context, cursor_description):
+ if context.compiled:
+ if context.compiled._cached_metadata:
+ cached_md = context.compiled._cached_metadata
+ self._metadata = cached_md._adapt_to_context(context)
- else:
- self._metadata = (
- self.context.compiled._cached_metadata
- ) = self._cursor_metadata(self, strat.cursor_description)
else:
- self._metadata = self._cursor_metadata(
- self, strat.cursor_description
- )
- if self._echo:
- self.context.engine.logger.debug(
- "Col %r", tuple(x[0] for x in strat.cursor_description)
- )
+ self._metadata = (
+ context.compiled._cached_metadata
+ ) = self._cursor_metadata(self, cursor_description)
else:
- self._metadata = _no_result_metadata
- # leave cursor open so that execution context can continue
- # setting up things like rowcount
+ self._metadata = self._cursor_metadata(self, cursor_description)
+ if self._echo:
+ context.engine.logger.debug(
+ "Col %r", tuple(x[0] for x in cursor_description)
+ )
def _soft_close(self, hard=False):
"""Soft close this :class:`_engine.CursorResult`.
@@ -1638,9 +1647,6 @@ class CursorResult(BaseCursorResult, Result):
def _fetchmany_impl(self, size=None):
return self.cursor_strategy.fetchmany(self, size)
- def _soft_close(self, **kw):
- BaseCursorResult._soft_close(self, **kw)
-
def _raw_row_iterator(self):
return self._fetchiter_impl()
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index d9b4cdda6..094ab3d55 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -372,6 +372,8 @@ class DefaultDialect(interfaces.Dialect):
return None
def _check_unicode_returns(self, connection, additional_tests=None):
+ # this now runs in py2k only and will be removed in 2.0; disabled for
+ # Python 3 in all cases under #5315
if util.py2k and not self.supports_unicode_statements:
cast_to = util.binary_type
else:
@@ -752,15 +754,9 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
self.compiled = compiled = compiled_ddl
self.isddl = True
- self.execution_options = compiled.execution_options
- if connection._execution_options:
- self.execution_options = self.execution_options.union(
- connection._execution_options
- )
- if execution_options:
- self.execution_options = self.execution_options.union(
- execution_options
- )
+ self.execution_options = compiled.execution_options.merge_with(
+ connection._execution_options, execution_options
+ )
self._is_future_result = (
connection._is_future
@@ -815,15 +811,9 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
# we get here
assert compiled.can_execute
- self.execution_options = compiled.execution_options
- if connection._execution_options:
- self.execution_options = self.execution_options.union(
- connection._execution_options
- )
- if execution_options:
- self.execution_options = self.execution_options.union(
- execution_options
- )
+ self.execution_options = compiled.execution_options.merge_with(
+ connection._execution_options, execution_options
+ )
self._is_future_result = (
connection._is_future
@@ -921,42 +911,32 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
# Convert the dictionary of bind parameter values
# into a dict or list to be sent to the DBAPI's
# execute() or executemany() method.
- parameters = []
if compiled.positional:
- for compiled_params in self.compiled_parameters:
- param = [
- processors[key](compiled_params[key])
- if key in processors
- else compiled_params[key]
- for key in positiontup
- ]
- parameters.append(dialect.execute_sequence_format(param))
+ parameters = [
+ dialect.execute_sequence_format(
+ [
+ processors[key](compiled_params[key])
+ if key in processors
+ else compiled_params[key]
+ for key in positiontup
+ ]
+ )
+ for compiled_params in self.compiled_parameters
+ ]
else:
encode = not dialect.supports_unicode_statements
- for compiled_params in self.compiled_parameters:
- if encode:
- param = dict(
- (
- dialect._encoder(key)[0],
- processors[key](compiled_params[key])
- if key in processors
- else compiled_params[key],
- )
- for key in compiled_params
- )
- else:
- param = dict(
- (
- key,
- processors[key](compiled_params[key])
- if key in processors
- else compiled_params[key],
- )
- for key in compiled_params
- )
-
- parameters.append(param)
+ parameters = [
+ {
+ dialect._encoder(key)[0]
+ if encode
+ else key: processors[key](value)
+ if key in processors
+ else value
+ for key, value in compiled_params.items()
+ }
+ for compiled_params in self.compiled_parameters
+ ]
self.parameters = dialect.execute_sequence_format(parameters)
@@ -980,14 +960,9 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
self.dialect = connection.dialect
self.is_text = True
- if connection._execution_options:
- self.execution_options = self.execution_options.union(
- connection._execution_options
- )
- if execution_options:
- self.execution_options = self.execution_options.union(
- execution_options
- )
+ self.execution_options = self.execution_options.merge_with(
+ connection._execution_options, execution_options
+ )
self._is_future_result = (
connection._is_future
@@ -1038,14 +1013,9 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
self._dbapi_connection = dbapi_connection
self.dialect = connection.dialect
- if connection._execution_options:
- self.execution_options = self.execution_options.union(
- connection._execution_options
- )
- if execution_options:
- self.execution_options = self.execution_options.union(
- execution_options
- )
+ self.execution_options = self.execution_options.merge_with(
+ connection._execution_options, execution_options
+ )
self._is_future_result = (
connection._is_future
@@ -1173,7 +1143,17 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
return use_server_side
def create_cursor(self):
- if self._use_server_side_cursor():
+ if (
+ # inlining initial preference checks for SS cursors
+ self.dialect.supports_server_side_cursors
+ and (
+ self.execution_options.get("stream_results", False)
+ or (
+ self.dialect.server_side_cursors
+ and self._use_server_side_cursor()
+ )
+ )
+ ):
self._is_server_side = True
return self.create_server_side_cursor()
else:
@@ -1227,6 +1207,17 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
pass
def get_result_cursor_strategy(self, result):
+        """Dialect-overridable hook to return the internal strategy that
+ fetches results.
+
+
+ Some dialects will in some cases return special objects here that
+ have pre-buffered rows from some source or another, such as turning
+ Oracle OUT parameters into rows to accommodate for "returning",
+ SQL Server fetching "returning" before it resets "identity insert",
+ etc.
+
+ """
if self._is_server_side:
strat_cls = _cursor.BufferedRowCursorFetchStrategy
else:
@@ -1312,7 +1303,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
# the first row will have been fetched and current assumptions
# are that the result has only one row, until executemany()
# support is added here.
- assert result.returns_rows
+ assert result._metadata.returns_rows
result._soft_close()
elif not self._is_explicit_returning:
result._soft_close()
@@ -1330,9 +1321,9 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
# test that it has a cursor metadata that is accurate.
# the rows have all been fetched however.
- assert result.returns_rows
+ assert result._metadata.returns_rows
- elif not result.returns_rows:
+ elif not result._metadata.returns_rows:
# no results, get rowcount
# (which requires open cursor on some drivers
# such as kintersbasdb, mxodbc)
diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py
index fe0abf0bb..109ab41fe 100644
--- a/lib/sqlalchemy/engine/result.py
+++ b/lib/sqlalchemy/engine/result.py
@@ -86,7 +86,7 @@ class ResultMetaData(object):
index = self._index_for_key(key, raiseerr)
if index is not None:
- return operator.methodcaller("_get_by_key_impl_mapping", index)
+ return operator.itemgetter(index)
else:
return None
@@ -169,10 +169,7 @@ class SimpleResultMetaData(ResultMetaData):
self._keymap = {key: rec for keys, rec in recs_names for key in keys}
- if _processors is None:
- self._processors = [None] * len_keys
- else:
- self._processors = _processors
+ self._processors = _processors
def _for_freeze(self):
unique_filters = self._unique_filters
@@ -256,7 +253,9 @@ class SimpleResultMetaData(ResultMetaData):
def result_tuple(fields, extra=None):
parent = SimpleResultMetaData(fields, extra)
- return functools.partial(Row, parent, parent._processors, parent._keymap)
+ return functools.partial(
+ Row, parent, parent._processors, parent._keymap, Row._default_key_style
+ )
# a symbol that indicates to internal Result methods that
@@ -280,6 +279,8 @@ class Result(InPlaceGenerative):
_row_logging_fn = None
+ _source_supports_scalars = False
+ _generate_rows = True
_column_slice_filter = None
_post_creational_filter = None
_unique_filter_state = None
@@ -388,11 +389,14 @@ class Result(InPlaceGenerative):
uniques, strategy = self._unique_filter_state
if not strategy and self._metadata._unique_filters:
- filters = self._metadata._unique_filters
- if self._metadata._tuplefilter:
- filters = self._metadata._tuplefilter(filters)
+ if self._source_supports_scalars:
+ strategy = self._metadata._unique_filters[0]
+ else:
+ filters = self._metadata._unique_filters
+ if self._metadata._tuplefilter:
+ filters = self._metadata._tuplefilter(filters)
- strategy = operator.methodcaller("_filter_on_values", filters)
+ strategy = operator.methodcaller("_filter_on_values", filters)
return uniques, strategy
def columns(self, *col_expressions):
@@ -489,7 +493,8 @@ class Result(InPlaceGenerative):
"""
result = self._column_slices([index])
- result._post_creational_filter = operator.itemgetter(0)
+ if self._generate_rows:
+ result._post_creational_filter = operator.itemgetter(0)
result._no_scalar_onerow = True
return result
@@ -497,11 +502,20 @@ class Result(InPlaceGenerative):
def _column_slices(self, indexes):
self._metadata = self._metadata._reduce(indexes)
+ if self._source_supports_scalars and len(indexes) == 1:
+ self._generate_rows = False
+ else:
+ self._generate_rows = True
+
def _getter(self, key, raiseerr=True):
"""return a callable that will retrieve the given key from a
:class:`.Row`.
"""
+ if self._source_supports_scalars:
+ raise NotImplementedError(
+ "can't use this function in 'only scalars' mode"
+ )
return self._metadata._getter(key, raiseerr)
def _tuple_getter(self, keys):
@@ -509,6 +523,10 @@ class Result(InPlaceGenerative):
:class:`.Row`.
"""
+ if self._source_supports_scalars:
+ raise NotImplementedError(
+ "can't use this function in 'only scalars' mode"
+ )
return self._metadata._row_as_tuple_getter(keys)
@_generative
@@ -527,9 +545,14 @@ class Result(InPlaceGenerative):
"""
self._post_creational_filter = operator.attrgetter("_mapping")
self._no_scalar_onerow = False
+ self._generate_rows = True
def _row_getter(self):
+ if self._source_supports_scalars and not self._generate_rows:
+ return None
+
process_row = self._process_row
+ key_style = self._process_row._default_key_style
metadata = self._metadata
keymap = metadata._keymap
@@ -537,10 +560,11 @@ class Result(InPlaceGenerative):
tf = metadata._tuplefilter
if tf:
- processors = tf(processors)
+ if processors:
+ processors = tf(processors)
_make_row_orig = functools.partial(
- process_row, metadata, processors, keymap
+ process_row, metadata, processors, keymap, key_style
)
def make_row(row):
@@ -548,7 +572,7 @@ class Result(InPlaceGenerative):
else:
make_row = functools.partial(
- process_row, metadata, processors, keymap
+ process_row, metadata, processors, keymap, key_style
)
fns = ()
@@ -626,7 +650,7 @@ class Result(InPlaceGenerative):
def iterrows(self):
for row in self._fetchiter_impl():
- obj = make_row(row)
+ obj = make_row(row) if make_row else row
hashed = strategy(obj) if strategy else obj
if hashed in uniques:
continue
@@ -639,7 +663,7 @@ class Result(InPlaceGenerative):
def iterrows(self):
for row in self._fetchiter_impl():
- row = make_row(row)
+ row = make_row(row) if make_row else row
if post_creational_filter:
row = post_creational_filter(row)
yield row
@@ -658,6 +682,10 @@ class Result(InPlaceGenerative):
def allrows(self):
rows = self._fetchall_impl()
+ if make_row:
+ made_rows = [make_row(row) for row in rows]
+ else:
+ made_rows = rows
rows = [
made_row
for made_row, sig_row in [
@@ -665,7 +693,7 @@ class Result(InPlaceGenerative):
made_row,
strategy(made_row) if strategy else made_row,
)
- for made_row in [make_row(row) for row in rows]
+ for made_row in made_rows
]
if sig_row not in uniques and not uniques.add(sig_row)
]
@@ -678,11 +706,16 @@ class Result(InPlaceGenerative):
def allrows(self):
rows = self._fetchall_impl()
+
if post_creational_filter:
- rows = [
- post_creational_filter(make_row(row)) for row in rows
- ]
- else:
+ if make_row:
+ rows = [
+ post_creational_filter(make_row(row))
+ for row in rows
+ ]
+ else:
+ rows = [post_creational_filter(row) for row in rows]
+ elif make_row:
rows = [make_row(row) for row in rows]
return rows
@@ -708,7 +741,7 @@ class Result(InPlaceGenerative):
if row is None:
return _NO_ROW
else:
- obj = make_row(row)
+ obj = make_row(row) if make_row else row
hashed = strategy(obj) if strategy else obj
if hashed in uniques:
continue
@@ -725,7 +758,7 @@ class Result(InPlaceGenerative):
if row is None:
return _NO_ROW
else:
- row = make_row(row)
+ row = make_row(row) if make_row else row
if post_creational_filter:
row = post_creational_filter(row)
return row
@@ -1042,6 +1075,8 @@ class FrozenResult(object):
def __init__(self, result):
self.metadata = result._metadata._for_freeze()
self._post_creational_filter = result._post_creational_filter
+ self._source_supports_scalars = result._source_supports_scalars
+ self._generate_rows = result._generate_rows
result._post_creational_filter = None
self.data = result.fetchall()
@@ -1056,6 +1091,8 @@ class FrozenResult(object):
def __call__(self):
result = IteratorResult(self.metadata, iter(self.data))
result._post_creational_filter = self._post_creational_filter
+ result._source_supports_scalars = self._source_supports_scalars
+ result._generate_rows = self._generate_rows
return result
@@ -1112,16 +1149,28 @@ class ChunkedIteratorResult(IteratorResult):
"""
- def __init__(self, cursor_metadata, chunks):
+ def __init__(self, cursor_metadata, chunks, source_supports_scalars=False):
self._metadata = cursor_metadata
self.chunks = chunks
+ self._source_supports_scalars = source_supports_scalars
+
+ self.iterator = itertools.chain.from_iterable(
+ self.chunks(None, self._generate_rows)
+ )
- self.iterator = itertools.chain.from_iterable(self.chunks(None))
+ def _column_slices(self, indexes):
+ result = super(ChunkedIteratorResult, self)._column_slices(indexes)
+ self.iterator = itertools.chain.from_iterable(
+ self.chunks(self._yield_per, self._generate_rows)
+ )
+ return result
@_generative
def yield_per(self, num):
self._yield_per = num
- self.iterator = itertools.chain.from_iterable(self.chunks(num))
+ self.iterator = itertools.chain.from_iterable(
+ self.chunks(num, self._generate_rows)
+ )
class MergedResult(IteratorResult):
@@ -1149,6 +1198,8 @@ class MergedResult(IteratorResult):
self._post_creational_filter = results[0]._post_creational_filter
self._no_scalar_onerow = results[0]._no_scalar_onerow
self._yield_per = results[0]._yield_per
+ self._source_supports_scalars = results[0]._source_supports_scalars
+ self._generate_rows = results[0]._generate_rows
def close(self):
self._soft_close(hard=True)
diff --git a/lib/sqlalchemy/engine/row.py b/lib/sqlalchemy/engine/row.py
index 6cd020110..d279776ce 100644
--- a/lib/sqlalchemy/engine/row.py
+++ b/lib/sqlalchemy/engine/row.py
@@ -14,7 +14,6 @@ from .. import util
from ..sql import util as sql_util
from ..util.compat import collections_abc
-
MD_INDEX = 0 # integer index in cursor.description
# This reconstructor is necessary so that pickles with the C extension or
@@ -40,6 +39,11 @@ except ImportError:
return obj
+KEY_INTEGER_ONLY = 0
+KEY_OBJECTS_ONLY = 1
+KEY_OBJECTS_BUT_WARN = 2
+KEY_OBJECTS_NO_WARN = 3
+
try:
from sqlalchemy.cresultproxy import BaseRow
@@ -48,21 +52,27 @@ except ImportError:
_baserow_usecext = False
class BaseRow(object):
- __slots__ = ("_parent", "_data", "_keymap")
+ __slots__ = ("_parent", "_data", "_keymap", "_key_style")
- def __init__(self, parent, processors, keymap, data):
+ def __init__(self, parent, processors, keymap, key_style, data):
"""Row objects are constructed by CursorResult objects."""
self._parent = parent
- self._data = tuple(
- [
- proc(value) if proc else value
- for proc, value in zip(processors, data)
- ]
- )
+ if processors:
+ self._data = tuple(
+ [
+ proc(value) if proc else value
+ for proc, value in zip(processors, data)
+ ]
+ )
+ else:
+ self._data = tuple(data)
+
self._keymap = keymap
+ self._key_style = key_style
+
def __reduce__(self):
return (
rowproxy_reconstructor,
@@ -70,7 +80,13 @@ except ImportError:
)
def _filter_on_values(self, filters):
- return Row(self._parent, filters, self._keymap, self._data)
+ return Row(
+ self._parent,
+ filters,
+ self._keymap,
+ self._key_style,
+ self._data,
+ )
def _values_impl(self):
return list(self)
@@ -105,10 +121,14 @@ except ImportError:
mdindex = rec[MD_INDEX]
if mdindex is None:
self._parent._raise_for_ambiguous_column_name(rec)
- elif not ismapping and mdindex != key and not isinstance(key, int):
- self._parent._warn_for_nonint(key)
- # TODO: warn for non-int here, RemovedIn20Warning when available
+ elif (
+ self._key_style == KEY_OBJECTS_BUT_WARN
+ and not ismapping
+ and mdindex != key
+ and not isinstance(key, int)
+ ):
+ self._parent._warn_for_nonint(key)
return self._data[mdindex]
@@ -164,6 +184,8 @@ class Row(BaseRow, collections_abc.Sequence):
__slots__ = ()
+ _default_key_style = KEY_INTEGER_ONLY
+
@property
def _mapping(self):
"""Return a :class:`.RowMapping` for this :class:`.Row`.
@@ -182,19 +204,29 @@ class Row(BaseRow, collections_abc.Sequence):
.. versionadded:: 1.4
"""
-
- return RowMapping(self)
+ return RowMapping(
+ self._parent,
+ None,
+ self._keymap,
+ RowMapping._default_key_style,
+ self._data,
+ )
def __contains__(self, key):
return key in self._data
def __getstate__(self):
- return {"_parent": self._parent, "_data": self._data}
+ return {
+ "_parent": self._parent,
+ "_data": self._data,
+ "_key_style": self._key_style,
+ }
def __setstate__(self, state):
self._parent = parent = state["_parent"]
self._data = state["_data"]
self._keymap = parent._keymap
+ self._key_style = state["_key_style"]
def _op(self, other, op):
return (
@@ -305,11 +337,20 @@ class LegacyRow(Row):
"""
+ __slots__ = ()
+
+ if util.SQLALCHEMY_WARN_20:
+ _default_key_style = KEY_OBJECTS_BUT_WARN
+ else:
+ _default_key_style = KEY_OBJECTS_NO_WARN
+
def __contains__(self, key):
return self._parent._contains(key, self)
- def __getitem__(self, key):
- return self._get_by_key_impl(key)
+ if not _baserow_usecext:
+
+ def __getitem__(self, key):
+ return self._get_by_key_impl(key)
@util.deprecated(
"1.4",
@@ -441,7 +482,7 @@ class ROMappingView(
return list(other) != list(self)
-class RowMapping(collections_abc.Mapping):
+class RowMapping(BaseRow, collections_abc.Mapping):
"""A ``Mapping`` that maps column names and objects to :class:`.Row` values.
The :class:`.RowMapping` is available from a :class:`.Row` via the
@@ -463,22 +504,26 @@ class RowMapping(collections_abc.Mapping):
"""
- __slots__ = ("row",)
+ __slots__ = ()
- def __init__(self, row):
- self.row = row
+ _default_key_style = KEY_OBJECTS_ONLY
- def __getitem__(self, key):
- return self.row._get_by_key_impl_mapping(key)
+ if not _baserow_usecext:
+
+ def __getitem__(self, key):
+ return self._get_by_key_impl(key)
+
+ def _values_impl(self):
+ return list(self._data)
def __iter__(self):
- return (k for k in self.row._parent.keys if k is not None)
+ return (k for k in self._parent.keys if k is not None)
def __len__(self):
- return len(self.row)
+ return len(self._data)
def __contains__(self, key):
- return self.row._parent._has_key(key)
+ return self._parent._has_key(key)
def __repr__(self):
return repr(dict(self))
@@ -496,11 +541,11 @@ class RowMapping(collections_abc.Mapping):
"""
- return self.row._parent.keys
+ return self._parent.keys
def values(self):
"""Return a view of values for the values represented in the
underlying :class:`.Row`.
"""
- return ROMappingView(self, self.row._values_impl())
+ return ROMappingView(self, self._values_impl())
diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py
index 10d937945..0394d999c 100644
--- a/lib/sqlalchemy/orm/loading.py
+++ b/lib/sqlalchemy/orm/loading.py
@@ -78,7 +78,7 @@ def instances(query, cursor, context):
],
)
- def chunks(size):
+ def chunks(size, as_tuples):
while True:
yield_per = size
@@ -91,7 +91,13 @@ def instances(query, cursor, context):
else:
fetch = cursor.fetchall()
- rows = [tuple([proc(row) for proc in process]) for row in fetch]
+ if not as_tuples:
+ proc = process[0]
+ rows = [proc(row) for row in fetch]
+ else:
+ rows = [
+ tuple([proc(row) for proc in process]) for row in fetch
+ ]
for path, post_load in context.post_load_paths.items():
post_load.invoke(context, path)
@@ -101,14 +107,15 @@ def instances(query, cursor, context):
if not yield_per:
break
- result = ChunkedIteratorResult(row_metadata, chunks)
+ result = ChunkedIteratorResult(
+ row_metadata, chunks, source_supports_scalars=single_entity
+ )
if query._yield_per:
result.yield_per(query._yield_per)
if single_entity:
result = result.scalars()
- # filtered = context.loaders_require_uniquing
filtered = query._has_mapper_entities
if filtered:
@@ -796,6 +803,7 @@ def _populate_full(
for key, set_callable in populators["expire"]:
if set_callable:
state.expired_attributes.add(key)
+
for key, populator in populators["new"]:
populator(state, dict_, row)
for key, populator in populators["delayed"]:
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index db1fbea2c..70b8a71e3 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -3411,7 +3411,7 @@ class Query(Generative):
querycontext, self._connection_from_session, close_with_result=True
)
- result = conn.execute(querycontext.statement, self._params)
+ result = conn._execute_20(querycontext.statement, self._params)
return loading.instances(querycontext.query, result, querycontext)
def _execute_crud(self, stmt, mapper):
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index d32e3fd7a..ccc1b53fe 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -840,16 +840,17 @@ class SQLCompiler(Compiled):
),
replace_context=err,
)
- resolved_extracted = dict(
- zip([b.key for b in orig_extracted], extracted_parameters)
- )
+
+ resolved_extracted = {
+ b.key: extracted
+ for b, extracted in zip(orig_extracted, extracted_parameters)
+ }
else:
resolved_extracted = None
if params:
pd = {}
- for bindparam in self.bind_names:
- name = self.bind_names[bindparam]
+ for bindparam, name in self.bind_names.items():
if bindparam.key in params:
pd[name] = params[bindparam.key]
elif name in params:
@@ -884,7 +885,7 @@ class SQLCompiler(Compiled):
return pd
else:
pd = {}
- for bindparam in self.bind_names:
+ for bindparam, name in self.bind_names.items():
if _check and bindparam.required:
if _group_number:
raise exc.InvalidRequestError(
@@ -908,11 +909,9 @@ class SQLCompiler(Compiled):
value_param = bindparam
if bindparam.callable:
- pd[
- self.bind_names[bindparam]
- ] = value_param.effective_value
+ pd[name] = value_param.effective_value
else:
- pd[self.bind_names[bindparam]] = value_param.value
+ pd[name] = value_param.value
return pd
@property
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index bcab46d84..cc82c509b 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -3527,12 +3527,14 @@ class Select(
@classmethod
def _create_select_from_fromclause(cls, target, entities, *arg, **kw):
if arg or kw:
- util.warn_deprecated_20(
- "Passing arguments to %s.select() is deprecated and "
- "will be removed in SQLAlchemy 2.0. Please use generative "
- "methods such as select().where(), etc."
- % (target.__class__.__name__,)
- )
+ if util.SQLALCHEMY_WARN_20:
+ util.warn_deprecated_20(
+ "Passing arguments to %s.select() is deprecated and "
+ "will be removed in SQLAlchemy 2.0. "
+ "Please use generative "
+ "methods such as select().where(), etc."
+ % (target.__class__.__name__,)
+ )
return Select(entities, *arg, **kw)
else:
return Select._create_select(*entities)
@@ -3744,13 +3746,14 @@ class Select(
:meth:`_expression.Select.apply_labels`
"""
- util.warn_deprecated_20(
- "The select() function in SQLAlchemy 2.0 will accept a "
- "series of columns / tables and other entities only, "
- "passed positionally. For forwards compatibility, use the "
- "sqlalchemy.future.select() construct.",
- stacklevel=4,
- )
+ if util.SQLALCHEMY_WARN_20:
+ util.warn_deprecated_20(
+ "The select() function in SQLAlchemy 2.0 will accept a "
+ "series of columns / tables and other entities only, "
+ "passed positionally. For forwards compatibility, use the "
+ "sqlalchemy.future.select() construct.",
+ stacklevel=4,
+ )
self._auto_correlate = correlate
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
index 87e5ba0d2..ba4a2de72 100644
--- a/lib/sqlalchemy/testing/assertions.py
+++ b/lib/sqlalchemy/testing/assertions.py
@@ -168,7 +168,9 @@ def _expect_warnings(
else:
real_warn(msg, *arg, **kw)
- with mock.patch("warnings.warn", our_warn):
+ with mock.patch("warnings.warn", our_warn), mock.patch(
+ "sqlalchemy.util.SQLALCHEMY_WARN_20", True
+ ), mock.patch("sqlalchemy.engine.row.LegacyRow._default_key_style", 2):
yield
if assert_ and (not py2konly or not compat.py3k):
diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py
index 695985a91..6a0b065ee 100644
--- a/lib/sqlalchemy/util/__init__.py
+++ b/lib/sqlalchemy/util/__init__.py
@@ -148,5 +148,4 @@ from .langhelpers import warn_limited # noqa
from .langhelpers import wrap_callable # noqa
-# things that used to be not always available,
-# but are now as of current support Python versions
+SQLALCHEMY_WARN_20 = False
diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py
index 10d80fc98..0990acb83 100644
--- a/lib/sqlalchemy/util/_collections.py
+++ b/lib/sqlalchemy/util/_collections.py
@@ -52,6 +52,14 @@ class immutabledict(ImmutableContainer, dict):
dict.update(new, d)
return new
+ def merge_with(self, *dicts):
+ new = dict.__new__(self.__class__)
+ dict.__init__(new, self)
+ for d in dicts:
+ if d:
+ dict.update(new, d)
+ return new
+
def __repr__(self):
return "immutabledict(%s)" % dict.__repr__(self)