diff options
author | David Lord <davidism@gmail.com> | 2021-04-10 16:12:25 -0700 |
---|---|---|
committer | David Lord <davidism@gmail.com> | 2021-04-10 16:12:25 -0700 |
commit | 1932ce3dc44e16394888cf04666d9b7a1795da76 (patch) | |
tree | 240f231e7ca27ae18fa35855ea3d4a4a7eddcca1 /src/jinja2 | |
parent | a9b06f4bd271de7f56347f602dd90b41c3db8327 (diff) | |
download | jinja2-inline-async.tar.gz |
async support doesn't require patchinginline-async
Diffstat (limited to 'src/jinja2')
-rw-r--r-- | src/jinja2/async_utils.py | 76 | ||||
-rw-r--r-- | src/jinja2/asyncfilters.py | 261 | ||||
-rw-r--r-- | src/jinja2/asyncsupport.py | 249 | ||||
-rw-r--r-- | src/jinja2/compiler.py | 15 | ||||
-rw-r--r-- | src/jinja2/environment.py | 133 | ||||
-rw-r--r-- | src/jinja2/filters.py | 190 | ||||
-rw-r--r-- | src/jinja2/nativetypes.py | 7 | ||||
-rw-r--r-- | src/jinja2/runtime.py | 104 | ||||
-rw-r--r-- | src/jinja2/utils.py | 11 |
9 files changed, 453 insertions, 593 deletions
diff --git a/src/jinja2/async_utils.py b/src/jinja2/async_utils.py new file mode 100644 index 0000000..cb011b2 --- /dev/null +++ b/src/jinja2/async_utils.py @@ -0,0 +1,76 @@ +import inspect +import typing as t +from functools import wraps + +from .utils import _PassArg +from .utils import pass_eval_context + +if t.TYPE_CHECKING: + V = t.TypeVar("V") + + +def async_variant(normal_func): + def decorator(async_func): + pass_arg = _PassArg.from_obj(normal_func) + need_eval_context = pass_arg is None + + if pass_arg is _PassArg.environment: + + def is_async(args): + return args[0].is_async + + else: + + def is_async(args): + return args[0].environment.is_async + + @wraps(normal_func) + def wrapper(*args, **kwargs): + b = is_async(args) + + if need_eval_context: + args = args[1:] + + if b: + return async_func(*args, **kwargs) + + return normal_func(*args, **kwargs) + + if need_eval_context: + wrapper = pass_eval_context(wrapper) + + wrapper.jinja_async_variant = True + return wrapper + + return decorator + + +async def auto_await(value): + if inspect.isawaitable(value): + return await value + + return value + + +async def auto_aiter(iterable): + if hasattr(iterable, "__aiter__"): + async for item in iterable: + yield item + else: + for item in iterable: + yield item + + +async def auto_to_list( + value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]", +) -> "t.List[V]": + seq = [] + + if hasattr(value, "__aiter__"): + async for item in t.cast(t.AsyncIterable, value): + seq.append(item) + else: + for item in t.cast(t.Iterable, value): + seq.append(item) + + return seq diff --git a/src/jinja2/asyncfilters.py b/src/jinja2/asyncfilters.py deleted file mode 100644 index 00cae01..0000000 --- a/src/jinja2/asyncfilters.py +++ /dev/null @@ -1,261 +0,0 @@ -import typing -import typing as t -import warnings -from functools import wraps -from itertools import groupby - -from . import filters -from .asyncsupport import auto_aiter -from .asyncsupport import auto_await -from .utils import _PassArg -from .utils import pass_eval_context - -if t.TYPE_CHECKING: - from .environment import Environment - from .nodes import EvalContext - from .runtime import Context - from .runtime import Undefined - - V = t.TypeVar("V") - - -async def auto_to_seq( - value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]", -) -> "t.List[V]": - seq = [] - - if hasattr(value, "__aiter__"): - async for item in t.cast(t.AsyncIterable, value): - seq.append(item) - else: - for item in t.cast(t.Iterable, value): - seq.append(item) - - return seq - - -async def async_select_or_reject( - context: "Context", - value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]", - args: t.Tuple, - kwargs: t.Dict[str, t.Any], - modfunc: t.Callable[[t.Any], t.Any], - lookup_attr: bool, -) -> "t.AsyncIterator[V]": - if value: - func = filters.prepare_select_or_reject( - context, args, kwargs, modfunc, lookup_attr - ) - - async for item in auto_aiter(value): - if func(item): - yield item - - -def dual_filter(normal_func, async_func): - pass_arg = _PassArg.from_obj(normal_func) - wrapper_has_eval_context = False - - if pass_arg is _PassArg.environment: - wrapper_has_eval_context = False - - def is_async(args): - return args[0].is_async - - else: - wrapper_has_eval_context = pass_arg is None - - def is_async(args): - return args[0].environment.is_async - - @wraps(normal_func) - def wrapper(*args, **kwargs): - b = is_async(args) - - if wrapper_has_eval_context: - args = args[1:] - - if b: - return async_func(*args, **kwargs) - - return normal_func(*args, **kwargs) - - if wrapper_has_eval_context: - wrapper = pass_eval_context(wrapper) - - wrapper.jinja_async_variant = True - return wrapper - - -def async_variant(original): - def decorator(f): - return dual_filter(original, f) - - return decorator - - -def asyncfiltervariant(original): - warnings.warn( - "'asyncfiltervariant' is renamed to 'async_variant', the old" - " name will be removed in Jinja 3.1.", - DeprecationWarning, - stacklevel=2, - ) - return async_variant(original) - - -@async_variant(filters.do_first) -async def do_first( - environment: "Environment", seq: "t.Union[t.AsyncIterable[V], t.Iterable[V]]" -) -> "t.Union[V, Undefined]": - try: - return t.cast("V", await auto_aiter(seq).__anext__()) - except StopAsyncIteration: - return environment.undefined("No first item, sequence was empty.") - - -@async_variant(filters.do_groupby) -async def do_groupby( - environment: "Environment", - value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]", - attribute: t.Union[str, int], - default: t.Optional[t.Any] = None, -) -> "t.List[t.Tuple[t.Any, t.List[V]]]": - expr = filters.make_attrgetter(environment, attribute, default=default) - return [ - filters._GroupTuple(key, await auto_to_seq(values)) - for key, values in groupby(sorted(await auto_to_seq(value), key=expr), expr) - ] - - -@async_variant(filters.do_join) -async def do_join( - eval_ctx: "EvalContext", - value: t.Union[t.AsyncIterable, t.Iterable], - d: str = "", - attribute: t.Optional[t.Union[str, int]] = None, -) -> str: - return filters.do_join(eval_ctx, await auto_to_seq(value), d, attribute) - - -@async_variant(filters.do_list) -async def do_list(value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]") -> "t.List[V]": - return await auto_to_seq(value) - - -@async_variant(filters.do_reject) -async def do_reject( - context: "Context", - value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]", - *args: t.Any, - **kwargs: t.Any, -) -> "t.AsyncIterator[V]": - return async_select_or_reject(context, value, args, kwargs, lambda x: not x, False) - - -@async_variant(filters.do_rejectattr) -async def do_rejectattr( - context: "Context", - value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]", - *args: t.Any, - **kwargs: t.Any, -) -> "t.AsyncIterator[V]": - return async_select_or_reject(context, value, args, kwargs, lambda x: not x, True) - - -@async_variant(filters.do_select) -async def do_select( - context: "Context", - value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]", - *args: t.Any, - **kwargs: t.Any, -) -> "t.AsyncIterator[V]": - return async_select_or_reject(context, value, args, kwargs, lambda x: x, False) - - -@async_variant(filters.do_selectattr) -async def do_selectattr( - context: "Context", - value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]", - *args: t.Any, - **kwargs: t.Any, -) -> "t.AsyncIterator[V]": - return async_select_or_reject(context, value, args, kwargs, lambda x: x, True) - - -@typing.overload -def do_map( - context: "Context", - value: t.Union[t.AsyncIterable, t.Iterable], - name: str, - *args: t.Any, - **kwargs: t.Any, -) -> t.Iterable: - ... - - -@typing.overload -def do_map( - context: "Context", - value: t.Union[t.AsyncIterable, t.Iterable], - *, - attribute: str = ..., - default: t.Optional[t.Any] = None, -) -> t.Iterable: - ... - - -@async_variant(filters.do_map) -async def do_map(context, value, *args, **kwargs): - if value: - func = filters.prepare_map(context, args, kwargs) - - async for item in auto_aiter(value): - yield await auto_await(func(item)) - - -@async_variant(filters.do_sum) -async def do_sum( - environment: "Environment", - iterable: "t.Union[t.AsyncIterable[V], t.Iterable[V]]", - attribute: t.Optional[t.Union[str, int]] = None, - start: "V" = 0, # type: ignore -) -> "V": - rv = start - - if attribute is not None: - func = filters.make_attrgetter(environment, attribute) - else: - - def func(x): - return x - - async for item in auto_aiter(iterable): - rv += func(item) - - return rv - - -@async_variant(filters.do_slice) -async def do_slice( - value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]", - slices: int, - fill_with: t.Optional[t.Any] = None, -) -> "t.Iterator[t.List[V]]": - return filters.do_slice(await auto_to_seq(value), slices, fill_with) - - -ASYNC_FILTERS = { - "first": do_first, - "groupby": do_groupby, - "join": do_join, - "list": do_list, - # we intentionally do not support do_last because it may not be safe in async - "reject": do_reject, - "rejectattr": do_rejectattr, - "map": do_map, - "select": do_select, - "selectattr": do_selectattr, - "sum": do_sum, - "slice": do_slice, -} diff --git a/src/jinja2/asyncsupport.py b/src/jinja2/asyncsupport.py deleted file mode 100644 index e46a85a..0000000 --- a/src/jinja2/asyncsupport.py +++ /dev/null @@ -1,249 +0,0 @@ -"""The code for async support. Importing this patches Jinja.""" -import asyncio -import inspect -from functools import update_wrapper - -from markupsafe import Markup - -from .environment import TemplateModule -from .runtime import LoopContext -from .utils import concat -from .utils import internalcode -from .utils import missing - - -async def concat_async(async_gen): - rv = [] - - async def collect(): - async for event in async_gen: - rv.append(event) - - await collect() - return concat(rv) - - -async def generate_async(self, *args, **kwargs): - vars = dict(*args, **kwargs) - try: - async for event in self.root_render_func(self.new_context(vars)): - yield event - except Exception: - yield self.environment.handle_exception() - - -def wrap_generate_func(original_generate): - def _convert_generator(self, loop, args, kwargs): - async_gen = self.generate_async(*args, **kwargs) - try: - while 1: - yield loop.run_until_complete(async_gen.__anext__()) - except StopAsyncIteration: - pass - - def generate(self, *args, **kwargs): - if not self.environment.is_async: - return original_generate(self, *args, **kwargs) - return _convert_generator(self, asyncio.get_event_loop(), args, kwargs) - - return update_wrapper(generate, original_generate) - - -async def render_async(self, *args, **kwargs): - if not self.environment.is_async: - raise RuntimeError("The environment was not created with async mode enabled.") - - vars = dict(*args, **kwargs) - ctx = self.new_context(vars) - - try: - return await concat_async(self.root_render_func(ctx)) - except Exception: - return self.environment.handle_exception() - - -def wrap_render_func(original_render): - def render(self, *args, **kwargs): - if not self.environment.is_async: - return original_render(self, *args, **kwargs) - loop = asyncio.get_event_loop() - return loop.run_until_complete(self.render_async(*args, **kwargs)) - - return update_wrapper(render, original_render) - - -def wrap_block_reference_call(original_call): - @internalcode - async def async_call(self): - rv = await concat_async(self._stack[self._depth](self._context)) - if self._context.eval_ctx.autoescape: - rv = Markup(rv) - return rv - - @internalcode - def __call__(self): - if not self._context.environment.is_async: - return original_call(self) - return async_call(self) - - return update_wrapper(__call__, original_call) - - -def wrap_macro_invoke(original_invoke): - @internalcode - async def async_invoke(self, arguments, autoescape): - rv = await self._func(*arguments) - if autoescape: - rv = Markup(rv) - return rv - - @internalcode - def _invoke(self, arguments, autoescape): - if not self._environment.is_async: - return original_invoke(self, arguments, autoescape) - return async_invoke(self, arguments, autoescape) - - return update_wrapper(_invoke, original_invoke) - - -@internalcode -async def get_default_module_async(self): - if self._module is not None: - return self._module - self._module = rv = await self.make_module_async() - return rv - - -def wrap_default_module(original_default_module): - @internalcode - def _get_default_module(self, ctx=None): - if self.environment.is_async: - raise RuntimeError("Template module attribute is unavailable in async mode") - return original_default_module(self, ctx) - - return _get_default_module - - -async def make_module_async(self, vars=None, shared=False, locals=None): - context = self.new_context(vars, shared, locals) - body_stream = [] - async for item in self.root_render_func(context): - body_stream.append(item) - return TemplateModule(self, context, body_stream) - - -def patch_template(): - from . import Template - - Template.generate = wrap_generate_func(Template.generate) - Template.generate_async = update_wrapper(generate_async, Template.generate_async) - Template.render_async = update_wrapper(render_async, Template.render_async) - Template.render = wrap_render_func(Template.render) - Template._get_default_module = wrap_default_module(Template._get_default_module) - Template._get_default_module_async = get_default_module_async - Template.make_module_async = update_wrapper( - make_module_async, Template.make_module_async - ) - - -def patch_runtime(): - from .runtime import BlockReference, Macro - - BlockReference.__call__ = wrap_block_reference_call(BlockReference.__call__) - Macro._invoke = wrap_macro_invoke(Macro._invoke) - - -def patch_filters(): - from .filters import FILTERS - from .asyncfilters import ASYNC_FILTERS - - FILTERS.update(ASYNC_FILTERS) - - -def patch_all(): - patch_template() - patch_runtime() - patch_filters() - - -async def auto_await(value): - if inspect.isawaitable(value): - return await value - return value - - -async def auto_aiter(iterable): - if hasattr(iterable, "__aiter__"): - async for item in iterable: - yield item - return - for item in iterable: - yield item - - -class AsyncLoopContext(LoopContext): - _to_iterator = staticmethod(auto_aiter) - - @property - async def length(self): - if self._length is not None: - return self._length - - try: - self._length = len(self._iterable) - except TypeError: - iterable = [x async for x in self._iterator] - self._iterator = self._to_iterator(iterable) - self._length = len(iterable) + self.index + (self._after is not missing) - - return self._length - - @property - async def revindex0(self): - return await self.length - self.index - - @property - async def revindex(self): - return await self.length - self.index0 - - async def _peek_next(self): - if self._after is not missing: - return self._after - - try: - self._after = await self._iterator.__anext__() - except StopAsyncIteration: - self._after = missing - - return self._after - - @property - async def last(self): - return await self._peek_next() is missing - - @property - async def nextitem(self): - rv = await self._peek_next() - - if rv is missing: - return self._undefined("there is no next item") - - return rv - - def __aiter__(self): - return self - - async def __anext__(self): - if self._after is not missing: - rv = self._after - self._after = missing - else: - rv = await self._iterator.__anext__() - - self.index0 += 1 - self._before = self._current - self._current = rv - return rv, self - - -patch_all() diff --git a/src/jinja2/compiler.py b/src/jinja2/compiler.py index 1d73f7d..b15fb67 100644 --- a/src/jinja2/compiler.py +++ b/src/jinja2/compiler.py @@ -727,16 +727,15 @@ class CodeGenerator(NodeVisitor): assert frame is None, "no root frame allowed" eval_ctx = EvalContext(self.environment, self.name) - from .runtime import exported - - self.writeline("from __future__ import generator_stop") # Python < 3.7 - self.writeline("from jinja2.runtime import " + ", ".join(exported)) + from .runtime import exported, async_exported if self.environment.is_async: - self.writeline( - "from jinja2.asyncsupport import auto_await, " - "auto_aiter, AsyncLoopContext" - ) + exported_names = sorted(exported + async_exported) + else: + exported_names = sorted(exported) + + self.writeline("from __future__ import generator_stop") # Python < 3.7 + self.writeline("from jinja2.runtime import " + ", ".join(exported_names)) # if we want a deferred initialization we cannot move the # environment into a local name diff --git a/src/jinja2/environment.py b/src/jinja2/environment.py index 2a64a0a..ae68738 100644 --- a/src/jinja2/environment.py +++ b/src/jinja2/environment.py @@ -45,7 +45,6 @@ from .runtime import Undefined from .utils import _PassArg from .utils import concat from .utils import consume -from .utils import have_async_gen from .utils import import_string from .utils import internalcode from .utils import LRUCache @@ -342,12 +341,7 @@ class Environment: # load extensions self.extensions = load_extensions(self, extensions) - self.enable_async = enable_async - self.is_async = self.enable_async and have_async_gen - if self.is_async: - # runs patch_all() to enable async support - from . import asyncsupport # noqa: F401 - + self.is_async = enable_async _environment_sanity_check(self) def add_extension(self, extension): @@ -1119,13 +1113,20 @@ class Template: This will return the rendered template as a string. """ - vars = dict(*args, **kwargs) + if self.environment.is_async: + import asyncio + + loop = asyncio.get_event_loop() + return loop.run_until_complete(self.render_async(*args, **kwargs)) + + ctx = self.new_context(dict(*args, **kwargs)) + try: - return concat(self.root_render_func(self.new_context(vars))) + return concat(self.root_render_func(ctx)) except Exception: self.environment.handle_exception() - def render_async(self, *args, **kwargs): + async def render_async(self, *args, **kwargs): """This works similar to :meth:`render` but returns a coroutine that when awaited returns the entire rendered template string. This requires the async feature to be enabled. @@ -1134,10 +1135,17 @@ class Template: await template.render_async(knights='that say nih; asynchronously') """ - # see asyncsupport for the actual implementation - raise NotImplementedError( - "This feature is not available for this version of Python" - ) + if not self.environment.is_async: + raise RuntimeError( + "The environment was not created with async mode enabled." + ) + + ctx = self.new_context(dict(*args, **kwargs)) + + try: + return concat([n async for n in self.root_render_func(ctx)]) + except Exception: + return self.environment.handle_exception() def stream(self, *args, **kwargs): """Works exactly like :meth:`generate` but returns a @@ -1153,20 +1161,41 @@ class Template: It accepts the same arguments as :meth:`render`. """ - vars = dict(*args, **kwargs) + if self.environment.is_async: + import asyncio + + loop = asyncio.get_event_loop() + async_gen = self.generate_async(*args, **kwargs) + + try: + while True: + yield loop.run_until_complete(async_gen.__anext__()) + except StopAsyncIteration: + return + + ctx = self.new_context(dict(*args, **kwargs)) + try: - yield from self.root_render_func(self.new_context(vars)) + yield from self.root_render_func(ctx) except Exception: yield self.environment.handle_exception() - def generate_async(self, *args, **kwargs): + async def generate_async(self, *args, **kwargs): """An async version of :meth:`generate`. Works very similarly but returns an async iterator instead. """ - # see asyncsupport for the actual implementation - raise NotImplementedError( - "This feature is not available for this version of Python" - ) + if not self.environment.is_async: + raise RuntimeError( + "The environment was not created with async mode enabled." + ) + + ctx = self.new_context(dict(*args, **kwargs)) + + try: + async for event in self.root_render_func(ctx): + yield event + except Exception: + yield self.environment.handle_exception() def new_context(self, vars=None, shared=False, locals=None): """Create a new :class:`Context` for this template. The vars @@ -1187,42 +1216,56 @@ class Template: a dict which is then used as context. The arguments are the same as for the :meth:`new_context` method. """ - return TemplateModule(self, self.new_context(vars, shared, locals)) + ctx = self.new_context(vars, shared, locals) + return TemplateModule(self, ctx) - def make_module_async(self, vars=None, shared=False, locals=None): + async def make_module_async(self, vars=None, shared=False, locals=None): """As template module creation can invoke template code for asynchronous executions this method must be used instead of the normal :meth:`make_module` one. Likewise the module attribute becomes unavailable in async mode. """ - # see asyncsupport for the actual implementation - raise NotImplementedError( - "This feature is not available for this version of Python" - ) + ctx = self.new_context(vars, shared, locals) + return TemplateModule(self, ctx, [x async for x in self.root_render_func(ctx)]) @internalcode def _get_default_module(self, ctx=None): """If a context is passed in, this means that the template was - imported. Imported templates have access to the current template's - globals by default, but they can only be accessed via the context - during runtime. - - If there are new globals, we need to create a new - module because the cached module is already rendered and will not have - access to globals from the current context. This new module is not - cached as :attr:`_module` because the template can be imported elsewhere, - and it should have access to only the current template's globals. + imported. Imported templates have access to the current + template's globals by default, but they can only be accessed via + the context during runtime. + + If there are new globals, we need to create a new module because + the cached module is already rendered and will not have access + to globals from the current context. This new module is not + cached because the template can be imported elsewhere, and it + should have access to only the current template's globals. """ + if self.environment.is_async: + raise RuntimeError("Module is not available in async mode.") + if ctx is not None: - globals = { - key: ctx.parent[key] for key in ctx.globals_keys - self.globals.keys() - } - if globals: - return self.make_module(globals) - if self._module is not None: - return self._module - self._module = rv = self.make_module() - return rv + keys = ctx.globals_keys - self.globals.keys() + + if keys: + return self.make_module({k: ctx.parent[k] for k in keys}) + + if self._module is None: + self._module = self.make_module() + + return self._module + + async def _get_default_module_async(self, ctx=None): + if ctx is not None: + keys = ctx.globals_keys - self.globals.keys() + + if keys: + return await self.make_module_async({k: ctx.parent[k] for k in keys}) + + if self._module is None: + self._module = await self.make_module_async() + + return self._module @property def module(self): diff --git a/src/jinja2/filters.py b/src/jinja2/filters.py index 82f2ff2..8aa11c2 100644 --- a/src/jinja2/filters.py +++ b/src/jinja2/filters.py @@ -13,6 +13,10 @@ from markupsafe import escape from markupsafe import Markup from markupsafe import soft_str +from .async_utils import async_variant +from .async_utils import auto_aiter +from .async_utils import auto_await +from .async_utils import auto_to_list from .exceptions import FilterArgumentError from .runtime import Undefined from .utils import htmlsafe_json_dumps @@ -550,7 +554,7 @@ def do_default( @pass_eval_context -def do_join( +def sync_do_join( eval_ctx: "EvalContext", value: t.Iterable, d: str = "", @@ -607,13 +611,23 @@ def do_join( return soft_str(d).join(map(soft_str, value)) +@async_variant(sync_do_join) +async def do_join( + eval_ctx: "EvalContext", + value: t.Union[t.AsyncIterable, t.Iterable], + d: str = "", + attribute: t.Optional[t.Union[str, int]] = None, +) -> str: + return sync_do_join(eval_ctx, await auto_to_list(value), d, attribute) + + def do_center(value: str, width: int = 80) -> str: """Centers the value in a field of a given width.""" return soft_str(value).center(width) @pass_environment -def do_first( +def sync_do_first( environment: "Environment", seq: "t.Iterable[V]" ) -> "t.Union[V, Undefined]": """Return the first item of a sequence.""" @@ -623,6 +637,16 @@ def do_first( return environment.undefined("No first item, sequence was empty.") +@async_variant(sync_do_first) +async def do_first( + environment: "Environment", seq: "t.Union[t.AsyncIterable[V], t.Iterable[V]]" +) -> "t.Union[V, Undefined]": + try: + return t.cast("V", await auto_aiter(seq).__anext__()) + except StopAsyncIteration: + return environment.undefined("No first item, sequence was empty.") + + @pass_environment def do_last( environment: "Environment", seq: "t.Reversible[V]" @@ -642,6 +666,9 @@ def do_last( return environment.undefined("No last item, sequence was empty.") +# No async do_last, it may not be safe in async mode. + + @pass_context def do_random(context: "Context", seq: "t.Sequence[V]") -> "t.Union[V, Undefined]": """Return a random item from the sequence.""" @@ -1006,7 +1033,7 @@ def do_striptags(value: "t.Union[str, HasHTML]") -> str: return Markup(str(value)).striptags() -def do_slice( +def sync_do_slice( value: "t.Collection[V]", slices: int, fill_with: "t.Optional[V]" = None ) -> "t.Iterator[t.List[V]]": """Slice an iterator and return a list of lists containing @@ -1049,6 +1076,15 @@ def do_slice( yield tmp +@async_variant(sync_do_slice) +async def do_slice( + value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]", + slices: int, + fill_with: t.Optional[t.Any] = None, +) -> "t.Iterator[t.List[V]]": + return sync_do_slice(await auto_to_list(value), slices, fill_with) + + def do_batch( value: "t.Iterable[V]", linecount: int, fill_with: "t.Optional[V]" = None ) -> "t.Iterator[t.List[V]]": @@ -1140,7 +1176,7 @@ class _GroupTuple(t.NamedTuple): @pass_environment -def do_groupby( +def sync_do_groupby( environment: "Environment", value: "t.Iterable[V]", attribute: t.Union[str, int], @@ -1198,8 +1234,22 @@ def do_groupby( ] +@async_variant(sync_do_groupby) +async def do_groupby( + environment: "Environment", + value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]", + attribute: t.Union[str, int], + default: t.Optional[t.Any] = None, +) -> "t.List[t.Tuple[t.Any, t.List[V]]]": + expr = make_attrgetter(environment, attribute, default=default) + return [ + _GroupTuple(key, await auto_to_list(values)) + for key, values in groupby(sorted(await auto_to_list(value), key=expr), expr) + ] + + @pass_environment -def do_sum( +def sync_do_sum( environment: "Environment", iterable: "t.Iterable[V]", attribute: t.Optional[t.Union[str, int]] = None, @@ -1225,13 +1275,40 @@ def do_sum( return sum(iterable, start) -def do_list(value: "t.Iterable[V]") -> "t.List[V]": +@async_variant(sync_do_sum) +async def do_sum( + environment: "Environment", + iterable: "t.Union[t.AsyncIterable[V], t.Iterable[V]]", + attribute: t.Optional[t.Union[str, int]] = None, + start: "V" = 0, # type: ignore +) -> "V": + rv = start + + if attribute is not None: + func = make_attrgetter(environment, attribute) + else: + + def func(x): + return x + + async for item in auto_aiter(iterable): + rv += func(item) + + return rv + + +def sync_do_list(value: "t.Iterable[V]") -> "t.List[V]": """Convert the value into a list. If it was a string the returned list will be a list of characters. """ return list(value) +@async_variant(sync_do_list) +async def do_list(value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]") -> "t.List[V]": + return await auto_to_list(value) + + def do_mark_safe(value: str) -> Markup: """Mark the value as safe which means that in an environment with automatic escaping enabled this variable will not be escaped. @@ -1304,14 +1381,14 @@ def do_attr( @typing.overload -def do_map( +def sync_do_map( context: "Context", value: t.Iterable, name: str, *args: t.Any, **kwargs: t.Any ) -> t.Iterable: ... @typing.overload -def do_map( +def sync_do_map( context: "Context", value: t.Iterable, *, @@ -1322,7 +1399,7 @@ def do_map( @pass_context -def do_map(context, value, *args, **kwargs): +def sync_do_map(context, value, *args, **kwargs): """Applies a filter on a sequence of objects or looks up an attribute. This is useful when dealing with lists of objects but you are really only interested in a certain value of it. @@ -1369,8 +1446,39 @@ def do_map(context, value, *args, **kwargs): yield func(item) +@typing.overload +def do_map( + context: "Context", + value: t.Union[t.AsyncIterable, t.Iterable], + name: str, + *args: t.Any, + **kwargs: t.Any, +) -> t.Iterable: + ... + + +@typing.overload +def do_map( + context: "Context", + value: t.Union[t.AsyncIterable, t.Iterable], + *, + attribute: str = ..., + default: t.Optional[t.Any] = None, +) -> t.Iterable: + ... + + +@async_variant(sync_do_map) +async def do_map(context, value, *args, **kwargs): + if value: + func = prepare_map(context, args, kwargs) + + async for item in auto_aiter(value): + yield await auto_await(func(item)) + + @pass_context -def do_select( +def sync_do_select( context: "Context", value: "t.Iterable[V]", *args: t.Any, **kwargs: t.Any ) -> "t.Iterator[V]": """Filters a sequence of objects by applying a test to each object, @@ -1400,8 +1508,18 @@ def do_select( return select_or_reject(context, value, args, kwargs, lambda x: x, False) +@async_variant(sync_do_select) +async def do_select( + context: "Context", + value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]", + *args: t.Any, + **kwargs: t.Any, +) -> "t.AsyncIterator[V]": + return async_select_or_reject(context, value, args, kwargs, lambda x: x, False) + + @pass_context -def do_reject( +def sync_do_reject( context: "Context", value: "t.Iterable[V]", *args: t.Any, **kwargs: t.Any ) -> "t.Iterator[V]": """Filters a sequence of objects by applying a test to each object, @@ -1426,8 +1544,18 @@ def do_reject( return select_or_reject(context, value, args, kwargs, lambda x: not x, False) +@async_variant(sync_do_reject) +async def do_reject( + context: "Context", + value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]", + *args: t.Any, + **kwargs: t.Any, +) -> "t.AsyncIterator[V]": + return async_select_or_reject(context, value, args, kwargs, lambda x: not x, False) + + @pass_context -def do_selectattr( +def sync_do_selectattr( context: "Context", value: "t.Iterable[V]", *args: t.Any, **kwargs: t.Any ) -> "t.Iterator[V]": """Filters a sequence of objects by applying a test to the specified @@ -1456,8 +1584,18 @@ def do_selectattr( return select_or_reject(context, value, args, kwargs, lambda x: x, True) +@async_variant(sync_do_selectattr) +async def do_selectattr( + context: "Context", + value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]", + *args: t.Any, + **kwargs: t.Any, +) -> "t.AsyncIterator[V]": + return async_select_or_reject(context, value, args, kwargs, lambda x: x, True) + + @pass_context -def do_rejectattr( +def sync_do_rejectattr( context: "Context", value: "t.Iterable[V]", *args: t.Any, **kwargs: t.Any ) -> "t.Iterator[V]": """Filters a sequence of objects by applying a test to the specified @@ -1484,6 +1622,16 @@ def do_rejectattr( return select_or_reject(context, value, args, kwargs, lambda x: not x, True) +@async_variant(sync_do_rejectattr) +async def do_rejectattr( + context: "Context", + value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]", + *args: t.Any, + **kwargs: t.Any, +) -> "t.AsyncIterator[V]": + return async_select_or_reject(context, value, args, kwargs, lambda x: not x, True) + + @pass_eval_context def do_tojson( eval_ctx: "EvalContext", value: t.Any, indent: t.Optional[int] = None @@ -1591,6 +1739,22 @@ def select_or_reject( yield item +async def async_select_or_reject( + context: "Context", + value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]", + args: t.Tuple, + kwargs: t.Dict[str, t.Any], + modfunc: t.Callable[[t.Any], t.Any], + lookup_attr: bool, +) -> "t.AsyncIterator[V]": + if value: + func = prepare_select_or_reject(context, args, kwargs, modfunc, lookup_attr) + + async for item in auto_aiter(value): + if func(item): + yield item + + FILTERS = { "abs": abs, "attr": do_attr, diff --git a/src/jinja2/nativetypes.py b/src/jinja2/nativetypes.py index 8867a31..6cca518 100644 --- a/src/jinja2/nativetypes.py +++ b/src/jinja2/nativetypes.py @@ -86,10 +86,10 @@ class NativeTemplate(Template): with :func:`ast.literal_eval`, the parsed value is returned. Otherwise, the string is returned. """ - vars = dict(*args, **kwargs) + ctx = self.new_context(dict(*args, **kwargs)) try: - return native_concat(self.root_render_func(self.new_context(vars))) + return native_concat(self.root_render_func(ctx)) except Exception: return self.environment.handle_exception() @@ -99,8 +99,7 @@ class NativeTemplate(Template): "The environment was not created with async mode enabled." ) - vars = dict(*args, **kwargs) - ctx = self.new_context(vars) + ctx = self.new_context(dict(*args, **kwargs)) try: return native_concat([n async for n in self.root_render_func(ctx)]) diff --git a/src/jinja2/runtime.py b/src/jinja2/runtime.py index 3d55819..0ce4930 100644 --- a/src/jinja2/runtime.py +++ b/src/jinja2/runtime.py @@ -9,6 +9,8 @@ from markupsafe import escape # noqa: F401 from markupsafe import Markup from markupsafe import soft_str +from .async_utils import auto_aiter +from .async_utils import auto_await # noqa: F401 from .exceptions import TemplateNotFound # noqa: F401 from .exceptions import TemplateRuntimeError # noqa: F401 from .exceptions import UndefinedError @@ -42,6 +44,11 @@ exported = [ "Undefined", "internalcode", ] +async_exported = [ + "AsyncLoopContext", + "auto_aiter", + "auto_await", +] def identity(x): @@ -369,10 +376,24 @@ class BlockReference: return BlockReference(self.name, self._context, self._stack, self._depth + 1) @internalcode + async def _async_call(self): + rv = concat([x async for x in self._stack[self._depth](self._context)]) + + if self._context.eval_ctx.autoescape: + return Markup(rv) + + return rv + + @internalcode def __call__(self): + if self._context.environment.is_async: + return self._async_call() + rv = concat(self._stack[self._depth](self._context)) + if self._context.eval_ctx.autoescape: - rv = Markup(rv) + return Markup(rv) + return rv @@ -567,6 +588,73 @@ class LoopContext: return f"<{self.__class__.__name__} {self.index}/{self.length}>" +class AsyncLoopContext(LoopContext): + @staticmethod + def _to_iterator(iterable): + return auto_aiter(iterable) + + @property + async def length(self): + if self._length is not None: + return self._length + + try: + self._length = len(self._iterable) + except TypeError: + iterable = [x async for x in self._iterator] + self._iterator = self._to_iterator(iterable) + self._length = len(iterable) + self.index + (self._after is not missing) + + return self._length + + @property + async def revindex0(self): + return await self.length - self.index + + @property + async def revindex(self): + return await self.length - self.index0 + + async def _peek_next(self): + if self._after is not missing: + return self._after + + try: + self._after = await self._iterator.__anext__() + except StopAsyncIteration: + self._after = missing + + return self._after + + @property + async def last(self): + return await self._peek_next() is missing + + @property + async def nextitem(self): + rv = await self._peek_next() + + if rv is missing: + return self._undefined("there is no next item") + + return rv + + def __aiter__(self): + return self + + async def __anext__(self): + if self._after is not missing: + rv = self._after + self._after = missing + else: + rv = await self._iterator.__anext__() + + self.index0 += 1 + self._before = self._current + self._current = rv + return rv, self + + class Macro: """Wraps a macro function.""" @@ -672,11 +760,23 @@ class Macro: return self._invoke(arguments, autoescape) + async def _async_invoke(self, arguments, autoescape): + rv = await self._func(*arguments) + + if autoescape: + return Markup(rv) + + return rv + def _invoke(self, arguments, autoescape): - """This method is being swapped out by the async implementation.""" + if self._environment.is_async: + return self._async_invoke(arguments, autoescape) + rv = self._func(*arguments) + if autoescape: rv = Markup(rv) + return rv def __repr__(self): diff --git a/src/jinja2/utils.py b/src/jinja2/utils.py index 80769a7..c49dbb5 100644 --- a/src/jinja2/utils.py +++ b/src/jinja2/utils.py @@ -20,13 +20,10 @@ if t.TYPE_CHECKING: # special singleton representing missing values for the runtime missing = type("MissingType", (), {"__repr__": lambda x: "missing"})() -# internal code internal_code: t.MutableSet[CodeType] = set() concat = "".join -_slash_escape = "\\/" not in json.dumps("/") - def pass_context(f: "F") -> "F": """Pass the :class:`~jinja2.runtime.Context` as the first argument @@ -832,14 +829,6 @@ class Namespace: return f"<Namespace {self.__attrs!r}>" -# does this python version support async for in and async generators? -try: - exec("async def _():\n async for _ in ():\n yield _") - have_async_gen = True -except SyntaxError: - have_async_gen = False - - class Markup(markupsafe.Markup): def __init__(self, *args, **kwargs): warnings.warn( |