diff --git a/__init__.py b/__init__.py index af5d428..1a423a3 100755 --- a/__init__.py +++ b/__init__.py @@ -2,6 +2,7 @@ non-XML syntax that supports inline expressions and an optional sandboxed environment. """ + from .bccache import BytecodeCache as BytecodeCache from .bccache import FileSystemBytecodeCache as FileSystemBytecodeCache from .bccache import MemcachedBytecodeCache as MemcachedBytecodeCache @@ -34,4 +35,4 @@ from .utils import pass_environment as pass_environment from .utils import pass_eval_context as pass_eval_context from .utils import select_autoescape as select_autoescape -__version__ = "3.1.3" +__version__ = "3.1.6" diff --git a/async_utils.py b/async_utils.py index 715d701..f0c1402 100644 --- a/async_utils.py +++ b/async_utils.py @@ -6,6 +6,9 @@ from functools import wraps from .utils import _PassArg from .utils import pass_eval_context +if t.TYPE_CHECKING: + import typing_extensions as te + V = t.TypeVar("V") @@ -47,7 +50,7 @@ def async_variant(normal_func): # type: ignore if need_eval_context: wrapper = pass_eval_context(wrapper) - wrapper.jinja_async_variant = True + wrapper.jinja_async_variant = True # type: ignore[attr-defined] return wrapper return decorator @@ -64,18 +67,30 @@ async def auto_await(value: t.Union[t.Awaitable["V"], "V"]) -> "V": if inspect.isawaitable(value): return await t.cast("t.Awaitable[V]", value) - return t.cast("V", value) + return value -async def auto_aiter( +class _IteratorToAsyncIterator(t.Generic[V]): + def __init__(self, iterator: "t.Iterator[V]"): + self._iterator = iterator + + def __aiter__(self) -> "te.Self": + return self + + async def __anext__(self) -> V: + try: + return next(self._iterator) + except StopIteration as e: + raise StopAsyncIteration(e.value) from e + + +def auto_aiter( iterable: "t.Union[t.AsyncIterable[V], t.Iterable[V]]", ) -> "t.AsyncIterator[V]": if hasattr(iterable, "__aiter__"): - async for item in t.cast("t.AsyncIterable[V]", iterable): - yield item + return iterable.__aiter__() else: - for item in iterable: - yield item + return _IteratorToAsyncIterator(iter(iterable)) async def auto_to_list( diff --git a/bccache.py b/bccache.py index d0ddf56..ada8b09 100755 --- a/bccache.py +++ b/bccache.py @@ -5,6 +5,7 @@ slows down your application too much. Situations where this is useful are often forking web applications that are initialized on the first request. """ + import errno import fnmatch import marshal @@ -20,14 +21,15 @@ from types import CodeType if t.TYPE_CHECKING: import typing_extensions as te + from .environment import Environment class _MemcachedClient(te.Protocol): - def get(self, key: str) -> bytes: - ... + def get(self, key: str) -> bytes: ... - def set(self, key: str, value: bytes, timeout: t.Optional[int] = None) -> None: - ... + def set( + self, key: str, value: bytes, timeout: t.Optional[int] = None + ) -> None: ... bc_version = 5 diff --git a/compiler.py b/compiler.py index ff95c80..a4ff6a1 100755 --- a/compiler.py +++ b/compiler.py @@ -1,4 +1,5 @@ """Compiles nodes from the parser into Python code.""" + import typing as t from contextlib import contextmanager from functools import update_wrapper @@ -24,6 +25,7 @@ from .visitor import NodeVisitor if t.TYPE_CHECKING: import typing_extensions as te + from .environment import Environment F = t.TypeVar("F", bound=t.Callable[..., t.Any]) @@ -53,15 +55,14 @@ def optimizeconst(f: F) -> F: return f(self, node, frame, **kwargs) - return update_wrapper(t.cast(F, new_func), f) + return update_wrapper(new_func, f) # type: ignore[return-value] def _make_binop(op: str) -> t.Callable[["CodeGenerator", nodes.BinExpr, "Frame"], None]: @optimizeconst def visitor(self: "CodeGenerator", node: nodes.BinExpr, frame: Frame) -> None: if ( - self.environment.sandboxed - and op in self.environment.intercepted_binops # type: ignore + self.environment.sandboxed and op in self.environment.intercepted_binops # type: ignore ): self.write(f"environment.call_binop(context, {op!r}, ") self.visit(node.left, frame) @@ -84,8 +85,7 @@ def _make_unop( @optimizeconst def visitor(self: "CodeGenerator", node: nodes.UnaryExpr, frame: Frame) -> None: if ( - self.environment.sandboxed - and op in self.environment.intercepted_unops # type: ignore + self.environment.sandboxed and op in self.environment.intercepted_unops # type: ignore ): self.write(f"environment.call_unop(context, {op!r}, ") self.visit(node.node, frame) @@ -133,7 +133,7 @@ def has_safe_repr(value: t.Any) -> bool: if type(value) in {tuple, list, set, frozenset}: return all(has_safe_repr(v) for v in value) - if type(value) is dict: + if type(value) is dict: # noqa E721 return all(has_safe_repr(k) and has_safe_repr(v) for k, v in value.items()) return False @@ -216,7 +216,7 @@ class Frame: # or compile time. self.soft_frame = False - def copy(self) -> "Frame": + def copy(self) -> "te.Self": """Create a copy of the current one.""" rv = object.__new__(self.__class__) rv.__dict__.update(self.__dict__) @@ -229,7 +229,7 @@ class Frame: return Frame(self.eval_ctx, level=self.symbols.level + 1) return Frame(self.eval_ctx, self) - def soft(self) -> "Frame": + def soft(self) -> "te.Self": """Return a soft frame. A soft frame may not be modified as standalone thing as it shares the resources with the frame it was created of, but it's not a rootlevel frame any longer. @@ -551,10 +551,13 @@ class CodeGenerator(NodeVisitor): for node in nodes: visitor.visit(node) - for id_map, names, dependency in (self.filters, visitor.filters, "filters"), ( - self.tests, - visitor.tests, - "tests", + for id_map, names, dependency in ( + (self.filters, visitor.filters, "filters"), + ( + self.tests, + visitor.tests, + "tests", + ), ): for name in sorted(names): if name not in id_map: @@ -808,7 +811,7 @@ class CodeGenerator(NodeVisitor): self.writeline("_block_vars.update({") else: self.writeline("context.vars.update({") - for idx, name in enumerate(vars): + for idx, name in enumerate(sorted(vars)): if idx: self.write(", ") ref = frame.symbols.ref(name) @@ -818,7 +821,7 @@ class CodeGenerator(NodeVisitor): if len(public_names) == 1: self.writeline(f"context.exported_vars.add({public_names[0]!r})") else: - names_str = ", ".join(map(repr, public_names)) + names_str = ", ".join(map(repr, sorted(public_names))) self.writeline(f"context.exported_vars.update(({names_str}))") # -- Statement Visitors @@ -829,7 +832,8 @@ class CodeGenerator(NodeVisitor): assert frame is None, "no root frame allowed" eval_ctx = EvalContext(self.environment, self.name) - from .runtime import exported, async_exported + from .runtime import async_exported + from .runtime import exported if self.environment.is_async: exported_names = sorted(exported + async_exported) @@ -898,12 +902,15 @@ class CodeGenerator(NodeVisitor): if not self.environment.is_async: self.writeline("yield from parent_template.root_render_func(context)") else: - self.writeline( - "async for event in parent_template.root_render_func(context):" - ) + self.writeline("agen = parent_template.root_render_func(context)") + self.writeline("try:") + self.indent() + self.writeline("async for event in agen:") self.indent() self.writeline("yield event") self.outdent() + self.outdent() + self.writeline("finally: await agen.aclose()") self.outdent(1 + (not self.has_known_extends)) # at this point we now have the blocks collected and can visit them too. @@ -973,14 +980,20 @@ class CodeGenerator(NodeVisitor): f"yield from context.blocks[{node.name!r}][0]({context})", node ) else: + self.writeline(f"gen = context.blocks[{node.name!r}][0]({context})") + self.writeline("try:") + self.indent() self.writeline( - f"{self.choose_async()}for event in" - f" context.blocks[{node.name!r}][0]({context}):", + f"{self.choose_async()}for event in gen:", node, ) self.indent() self.simple_write("event", frame) self.outdent() + self.outdent() + self.writeline( + f"finally: {self.choose_async('await gen.aclose()', 'gen.close()')}" + ) self.outdent(level) @@ -1053,26 +1066,33 @@ class CodeGenerator(NodeVisitor): self.writeline("else:") self.indent() - skip_event_yield = False + def loop_body() -> None: + self.indent() + self.simple_write("event", frame) + self.outdent() + if node.with_context: self.writeline( - f"{self.choose_async()}for event in template.root_render_func(" + f"gen = template.root_render_func(" "template.new_context(context.get_all(), True," - f" {self.dump_local_context(frame)})):" + f" {self.dump_local_context(frame)}))" + ) + self.writeline("try:") + self.indent() + self.writeline(f"{self.choose_async()}for event in gen:") + loop_body() + self.outdent() + self.writeline( + f"finally: {self.choose_async('await gen.aclose()', 'gen.close()')}" ) elif self.environment.is_async: self.writeline( "for event in (await template._get_default_module_async())" "._body_stream:" ) + loop_body() else: self.writeline("yield from template._get_default_module()._body_stream") - skip_event_yield = True - - if not skip_event_yield: - self.indent() - self.simple_write("event", frame) - self.outdent() if node.ignore_missing: self.outdent() @@ -1121,9 +1141,14 @@ class CodeGenerator(NodeVisitor): ) self.writeline(f"if {frame.symbols.ref(alias)} is missing:") self.indent() + # The position will contain the template name, and will be formatted + # into a string that will be compiled into an f-string. Curly braces + # in the name must be replaced with escapes so that they will not be + # executed as part of the f-string. + position = self.position(node).replace("{", "{{").replace("}", "}}") message = ( "the template {included_template.__name__!r}" - f" (imported on {self.position(node)})" + f" (imported on {position})" f" does not export the requested name {name!r}" ) self.writeline( @@ -1556,6 +1581,29 @@ class CodeGenerator(NodeVisitor): def visit_Assign(self, node: nodes.Assign, frame: Frame) -> None: self.push_assign_tracking() + + # ``a.b`` is allowed for assignment, and is parsed as an NSRef. However, + # it is only valid if it references a Namespace object. Emit a check for + # that for each ref here, before assignment code is emitted. This can't + # be done in visit_NSRef as the ref could be in the middle of a tuple. + seen_refs: t.Set[str] = set() + + for nsref in node.find_all(nodes.NSRef): + if nsref.name in seen_refs: + # Only emit the check for each reference once, in case the same + # ref is used multiple times in a tuple, `ns.a, ns.b = c, d`. + continue + + seen_refs.add(nsref.name) + ref = frame.symbols.ref(nsref.name) + self.writeline(f"if not isinstance({ref}, Namespace):") + self.indent() + self.writeline( + "raise TemplateRuntimeError" + '("cannot assign attribute on non-namespace object")' + ) + self.outdent() + self.newline(node) self.visit(node.target, frame) self.write(" = ") @@ -1612,17 +1660,11 @@ class CodeGenerator(NodeVisitor): self.write(ref) def visit_NSRef(self, node: nodes.NSRef, frame: Frame) -> None: - # NSRefs can only be used to store values; since they use the normal - # `foo.bar` notation they will be parsed as a normal attribute access - # when used anywhere but in a `set` context + # NSRef is a dotted assignment target a.b=c, but uses a[b]=c internally. + # visit_Assign emits code to validate that each ref is to a Namespace + # object only. That can't be emitted here as the ref could be in the + # middle of a tuple assignment. ref = frame.symbols.ref(node.name) - self.writeline(f"if not isinstance({ref}, Namespace):") - self.indent() - self.writeline( - "raise TemplateRuntimeError" - '("cannot assign attribute on non-namespace object")' - ) - self.outdent() self.writeline(f"{ref}[{node.attr!r}]") def visit_Const(self, node: nodes.Const, frame: Frame) -> None: diff --git a/debug.py b/debug.py index 7ed7e92..eeeeee7 100755 --- a/debug.py +++ b/debug.py @@ -152,7 +152,7 @@ def get_template_locals(real_locals: t.Mapping[str, t.Any]) -> t.Dict[str, t.Any available at that point in the template. """ # Start with the current template context. - ctx: "t.Optional[Context]" = real_locals.get("context") + ctx: t.Optional[Context] = real_locals.get("context") if ctx is not None: data: t.Dict[str, t.Any] = ctx.get_all().copy() diff --git a/environment.py b/environment.py index 185d332..0fc6e5b 100755 --- a/environment.py +++ b/environment.py @@ -1,6 +1,7 @@ """Classes for managing templates and their runtime and compile time options. """ + import os import typing import typing as t @@ -20,10 +21,10 @@ from .defaults import BLOCK_END_STRING from .defaults import BLOCK_START_STRING from .defaults import COMMENT_END_STRING from .defaults import COMMENT_START_STRING -from .defaults import DEFAULT_FILTERS +from .defaults import DEFAULT_FILTERS # type: ignore[attr-defined] from .defaults import DEFAULT_NAMESPACE from .defaults import DEFAULT_POLICIES -from .defaults import DEFAULT_TESTS +from .defaults import DEFAULT_TESTS # type: ignore[attr-defined] from .defaults import KEEP_TRAILING_NEWLINE from .defaults import LINE_COMMENT_PREFIX from .defaults import LINE_STATEMENT_PREFIX @@ -55,6 +56,7 @@ from .utils import missing if t.TYPE_CHECKING: import typing_extensions as te + from .bccache import BytecodeCache from .ext import Extension from .loaders import BaseLoader @@ -79,7 +81,7 @@ def get_spontaneous_environment(cls: t.Type[_env_bound], *args: t.Any) -> _env_b def create_cache( size: int, -) -> t.Optional[t.MutableMapping[t.Tuple[weakref.ref, str], "Template"]]: +) -> t.Optional[t.MutableMapping[t.Tuple["weakref.ref[t.Any]", str], "Template"]]: """Return the cache class for the given size.""" if size == 0: return None @@ -91,13 +93,13 @@ def create_cache( def copy_cache( - cache: t.Optional[t.MutableMapping], -) -> t.Optional[t.MutableMapping[t.Tuple[weakref.ref, str], "Template"]]: + cache: t.Optional[t.MutableMapping[t.Any, t.Any]], +) -> t.Optional[t.MutableMapping[t.Tuple["weakref.ref[t.Any]", str], "Template"]]: """Create an empty copy of the given cache.""" if cache is None: return None - if type(cache) is dict: + if type(cache) is dict: # noqa E721 return {} return LRUCache(cache.capacity) # type: ignore @@ -121,7 +123,7 @@ def load_extensions( return result -def _environment_config_check(environment: "Environment") -> "Environment": +def _environment_config_check(environment: _env_bound) -> _env_bound: """Perform a sanity check on the environment.""" assert issubclass( environment.undefined, Undefined @@ -404,8 +406,8 @@ class Environment: cache_size: int = missing, auto_reload: bool = missing, bytecode_cache: t.Optional["BytecodeCache"] = missing, - enable_async: bool = False, - ) -> "Environment": + enable_async: bool = missing, + ) -> "te.Self": """Create a new overlay environment that shares all the data with the current environment except for cache and the overridden attributes. Extensions cannot be removed for an overlayed environment. An overlayed @@ -417,8 +419,11 @@ class Environment: copied over so modifications on the original environment may not shine through. + .. versionchanged:: 3.1.5 + ``enable_async`` is applied correctly. + .. versionchanged:: 3.1.2 - Added the ``newline_sequence``,, ``keep_trailing_newline``, + Added the ``newline_sequence``, ``keep_trailing_newline``, and ``enable_async`` parameters to match ``__init__``. """ args = dict(locals()) @@ -670,7 +675,7 @@ class Environment: stream = ext.filter_stream(stream) # type: ignore if not isinstance(stream, TokenStream): - stream = TokenStream(stream, name, filename) # type: ignore + stream = TokenStream(stream, name, filename) return stream @@ -704,15 +709,14 @@ class Environment: return compile(source, filename, "exec") @typing.overload - def compile( # type: ignore + def compile( self, source: t.Union[str, nodes.Template], name: t.Optional[str] = None, filename: t.Optional[str] = None, raw: "te.Literal[False]" = False, defer_init: bool = False, - ) -> CodeType: - ... + ) -> CodeType: ... @typing.overload def compile( @@ -722,8 +726,7 @@ class Environment: filename: t.Optional[str] = None, raw: "te.Literal[True]" = ..., defer_init: bool = False, - ) -> str: - ... + ) -> str: ... @internalcode def compile( @@ -814,7 +817,7 @@ class Environment: def compile_templates( self, - target: t.Union[str, os.PathLike], + target: t.Union[str, "os.PathLike[str]"], extensions: t.Optional[t.Collection[str]] = None, filter_func: t.Optional[t.Callable[[str], bool]] = None, zip: t.Optional[str] = "deflated", @@ -858,7 +861,10 @@ class Environment: f.write(data.encode("utf8")) if zip is not None: - from zipfile import ZipFile, ZipInfo, ZIP_DEFLATED, ZIP_STORED + from zipfile import ZIP_DEFLATED + from zipfile import ZIP_STORED + from zipfile import ZipFile + from zipfile import ZipInfo zip_file = ZipFile( target, "w", dict(deflated=ZIP_DEFLATED, stored=ZIP_STORED)[zip] @@ -1245,7 +1251,7 @@ class Template: namespace: t.MutableMapping[str, t.Any], globals: t.MutableMapping[str, t.Any], ) -> "Template": - t: "Template" = object.__new__(cls) + t: Template = object.__new__(cls) t.environment = environment t.globals = globals t.name = namespace["name"] @@ -1279,19 +1285,7 @@ class Template: if self.environment.is_async: import asyncio - close = False - - try: - loop = asyncio.get_running_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - close = True - - try: - return loop.run_until_complete(self.render_async(*args, **kwargs)) - finally: - if close: - loop.close() + return asyncio.run(self.render_async(*args, **kwargs)) ctx = self.new_context(dict(*args, **kwargs)) @@ -1355,7 +1349,7 @@ class Template: async def generate_async( self, *args: t.Any, **kwargs: t.Any - ) -> t.AsyncIterator[str]: + ) -> t.AsyncGenerator[str, object]: """An async version of :meth:`generate`. Works very similarly but returns an async iterator instead. """ @@ -1367,8 +1361,14 @@ class Template: ctx = self.new_context(dict(*args, **kwargs)) try: - async for event in self.root_render_func(ctx): # type: ignore - yield event + agen = self.root_render_func(ctx) + try: + async for event in agen: # type: ignore + yield event + finally: + # we can't use async with aclosing(...) because that's only + # in 3.10+ + await agen.aclose() # type: ignore except Exception: yield self.environment.handle_exception() @@ -1417,7 +1417,9 @@ class Template: """ ctx = self.new_context(vars, shared, locals) return TemplateModule( - self, ctx, [x async for x in self.root_render_func(ctx)] # type: ignore + self, + ctx, + [x async for x in self.root_render_func(ctx)], # type: ignore ) @internalcode @@ -1588,7 +1590,7 @@ class TemplateStream: def dump( self, - fp: t.Union[str, t.IO], + fp: t.Union[str, t.IO[bytes]], encoding: t.Optional[str] = None, errors: t.Optional[str] = "strict", ) -> None: @@ -1606,22 +1608,25 @@ class TemplateStream: if encoding is None: encoding = "utf-8" - fp = open(fp, "wb") + real_fp: t.IO[bytes] = open(fp, "wb") close = True + else: + real_fp = fp + try: if encoding is not None: iterable = (x.encode(encoding, errors) for x in self) # type: ignore else: iterable = self # type: ignore - if hasattr(fp, "writelines"): - fp.writelines(iterable) + if hasattr(real_fp, "writelines"): + real_fp.writelines(iterable) else: for item in iterable: - fp.write(item) + real_fp.write(item) finally: if close: - fp.close() + real_fp.close() def disable_buffering(self) -> None: """Disable the output buffering.""" diff --git a/ext.py b/ext.py index fade1fa..c7af8d4 100755 --- a/ext.py +++ b/ext.py @@ -1,4 +1,5 @@ """Extension API for adding custom tags and behavior.""" + import pprint import re import typing as t @@ -18,23 +19,23 @@ from .utils import pass_context if t.TYPE_CHECKING: import typing_extensions as te + from .lexer import Token from .lexer import TokenStream from .parser import Parser class _TranslationsBasic(te.Protocol): - def gettext(self, message: str) -> str: - ... + def gettext(self, message: str) -> str: ... def ngettext(self, singular: str, plural: str, n: int) -> str: pass class _TranslationsContext(_TranslationsBasic): - def pgettext(self, context: str, message: str) -> str: - ... + def pgettext(self, context: str, message: str) -> str: ... - def npgettext(self, context: str, singular: str, plural: str, n: int) -> str: - ... + def npgettext( + self, context: str, singular: str, plural: str, n: int + ) -> str: ... _SupportedTranslations = t.Union[_TranslationsBasic, _TranslationsContext] @@ -88,7 +89,7 @@ class Extension: def __init__(self, environment: Environment) -> None: self.environment = environment - def bind(self, environment: Environment) -> "Extension": + def bind(self, environment: Environment) -> "te.Self": """Create a copy of this extension bound to another environment.""" rv = object.__new__(self.__class__) rv.__dict__.update(self.__dict__) @@ -218,7 +219,7 @@ def _make_new_pgettext(func: t.Callable[[str, str], str]) -> t.Callable[..., str def _make_new_npgettext( - func: t.Callable[[str, str, str, int], str] + func: t.Callable[[str, str, str, int], str], ) -> t.Callable[..., str]: @pass_context def npgettext( @@ -294,14 +295,14 @@ class InternationalizationExtension(Extension): pgettext = translations.pgettext else: - def pgettext(c: str, s: str) -> str: + def pgettext(c: str, s: str) -> str: # type: ignore[misc] return s if hasattr(translations, "npgettext"): npgettext = translations.npgettext else: - def npgettext(c: str, s: str, p: str, n: int) -> str: + def npgettext(c: str, s: str, p: str, n: int) -> str: # type: ignore[misc] return s if n == 1 else p self._install_callables( diff --git a/filters.py b/filters.py index 8b09247..2bcba4f 100755 --- a/filters.py +++ b/filters.py @@ -1,10 +1,12 @@ """Built-in template filters used with the ``|`` operator.""" + import math import random import re import typing import typing as t from collections import abc +from inspect import getattr_static from itertools import chain from itertools import groupby @@ -28,6 +30,7 @@ from .utils import urlize if t.TYPE_CHECKING: import typing_extensions as te + from .environment import Environment from .nodes import EvalContext from .runtime import Context @@ -122,7 +125,7 @@ def make_multi_attrgetter( def _prepare_attribute_parts( - attr: t.Optional[t.Union[str, int]] + attr: t.Optional[t.Union[str, int]], ) -> t.List[t.Union[str, int]]: if attr is None: return [] @@ -142,7 +145,7 @@ def do_forceescape(value: "t.Union[str, HasHTML]") -> Markup: def do_urlencode( - value: t.Union[str, t.Mapping[str, t.Any], t.Iterable[t.Tuple[str, t.Any]]] + value: t.Union[str, t.Mapping[str, t.Any], t.Iterable[t.Tuple[str, t.Any]]], ) -> str: """Quote data for use in a URL path or query using UTF-8. @@ -267,6 +270,7 @@ def do_xmlattr( sign, this fails with a ``ValueError``. Regardless of this, user input should never be used as keys to this filter, or must be separately validated first. + .. sourcecode:: html+jinja "t.Iterator[V]": + return sync_do_unique( + environment, await auto_to_list(value), case_sensitive, attribute + ) + + def _min_or_max( environment: "Environment", value: "t.Iterable[V]", @@ -563,7 +579,7 @@ def do_default( @pass_eval_context def sync_do_join( eval_ctx: "EvalContext", - value: t.Iterable, + value: t.Iterable[t.Any], d: str = "", attribute: t.Optional[t.Union[str, int]] = None, ) -> str: @@ -621,7 +637,7 @@ def sync_do_join( @async_variant(sync_do_join) # type: ignore async def do_join( eval_ctx: "EvalContext", - value: t.Union[t.AsyncIterable, t.Iterable], + value: t.Union[t.AsyncIterable[t.Any], t.Iterable[t.Any]], d: str = "", attribute: t.Optional[t.Union[str, int]] = None, ) -> str: @@ -984,7 +1000,7 @@ def do_int(value: t.Any, default: int = 0, base: int = 10) -> int: # this quirk is necessary so that "42.23"|int gives 42. try: return int(float(value)) - except (TypeError, ValueError): + except (TypeError, ValueError, OverflowError): return default @@ -1113,7 +1129,7 @@ def do_batch( {%- endfor %} """ - tmp: "t.List[V]" = [] + tmp: t.List[V] = [] for item in value: if len(tmp) == linecount: @@ -1171,7 +1187,7 @@ def do_round( class _GroupTuple(t.NamedTuple): grouper: t.Any - list: t.List + list: t.List[t.Any] # Use the regular tuple repr to hide this subclass if users print # out the value during debugging. @@ -1367,13 +1383,11 @@ def do_mark_unsafe(value: str) -> str: @typing.overload -def do_reverse(value: str) -> str: - ... +def do_reverse(value: str) -> str: ... @typing.overload -def do_reverse(value: "t.Iterable[V]") -> "t.Iterable[V]": - ... +def do_reverse(value: "t.Iterable[V]") -> "t.Iterable[V]": ... def do_reverse(value: t.Union[str, t.Iterable[V]]) -> t.Union[str, t.Iterable[V]]: @@ -1398,55 +1412,51 @@ def do_reverse(value: t.Union[str, t.Iterable[V]]) -> t.Union[str, t.Iterable[V] def do_attr( environment: "Environment", obj: t.Any, name: str ) -> t.Union[Undefined, t.Any]: - """Get an attribute of an object. ``foo|attr("bar")`` works like - ``foo.bar`` just that always an attribute is returned and items are not - looked up. + """Get an attribute of an object. ``foo|attr("bar")`` works like + ``foo.bar``, but returns undefined instead of falling back to ``foo["bar"]`` + if the attribute doesn't exist. See :ref:`Notes on subscriptions ` for more details. """ + # Environment.getattr will fall back to obj[name] if obj.name doesn't exist. + # But we want to call env.getattr to get behavior such as sandboxing. + # Determine if the attr exists first, so we know the fallback won't trigger. try: - name = str(name) - except UnicodeError: - pass - else: - try: - value = getattr(obj, name) - except AttributeError: - pass - else: - if environment.sandboxed: - environment = t.cast("SandboxedEnvironment", environment) + # This avoids executing properties/descriptors, but misses __getattr__ + # and __getattribute__ dynamic attrs. + getattr_static(obj, name) + except AttributeError: + # This finds dynamic attrs, and we know it's not a descriptor at this point. + if not hasattr(obj, name): + return environment.undefined(obj=obj, name=name) - if not environment.is_safe_attribute(obj, name, value): - return environment.unsafe_undefined(obj, name) - - return value - - return environment.undefined(obj=obj, name=name) - - -@typing.overload -def sync_do_map( - context: "Context", value: t.Iterable, name: str, *args: t.Any, **kwargs: t.Any -) -> t.Iterable: - ... + return environment.getattr(obj, name) @typing.overload def sync_do_map( context: "Context", - value: t.Iterable, + value: t.Iterable[t.Any], + name: str, + *args: t.Any, + **kwargs: t.Any, +) -> t.Iterable[t.Any]: ... + + +@typing.overload +def sync_do_map( + context: "Context", + value: t.Iterable[t.Any], *, attribute: str = ..., default: t.Optional[t.Any] = None, -) -> t.Iterable: - ... +) -> t.Iterable[t.Any]: ... @pass_context def sync_do_map( - context: "Context", value: t.Iterable, *args: t.Any, **kwargs: t.Any -) -> t.Iterable: + context: "Context", value: t.Iterable[t.Any], *args: t.Any, **kwargs: t.Any +) -> t.Iterable[t.Any]: """Applies a filter on a sequence of objects or looks up an attribute. This is useful when dealing with lists of objects but you are really only interested in a certain value of it. @@ -1496,32 +1506,30 @@ def sync_do_map( @typing.overload def do_map( context: "Context", - value: t.Union[t.AsyncIterable, t.Iterable], + value: t.Union[t.AsyncIterable[t.Any], t.Iterable[t.Any]], name: str, *args: t.Any, **kwargs: t.Any, -) -> t.Iterable: - ... +) -> t.Iterable[t.Any]: ... @typing.overload def do_map( context: "Context", - value: t.Union[t.AsyncIterable, t.Iterable], + value: t.Union[t.AsyncIterable[t.Any], t.Iterable[t.Any]], *, attribute: str = ..., default: t.Optional[t.Any] = None, -) -> t.Iterable: - ... +) -> t.Iterable[t.Any]: ... @async_variant(sync_do_map) # type: ignore async def do_map( context: "Context", - value: t.Union[t.AsyncIterable, t.Iterable], + value: t.Union[t.AsyncIterable[t.Any], t.Iterable[t.Any]], *args: t.Any, **kwargs: t.Any, -) -> t.AsyncIterable: +) -> t.AsyncIterable[t.Any]: if value: func = prepare_map(context, args, kwargs) @@ -1628,8 +1636,8 @@ def sync_do_selectattr( .. code-block:: python - (u for user in users if user.is_active) - (u for user in users if test_none(user.email)) + (user for user in users if user.is_active) + (user for user in users if test_none(user.email)) .. versionadded:: 2.7 """ @@ -1666,8 +1674,8 @@ def sync_do_rejectattr( .. code-block:: python - (u for user in users if not user.is_active) - (u for user in users if not test_none(user.email)) + (user for user in users if not user.is_active) + (user for user in users if not test_none(user.email)) .. versionadded:: 2.7 """ @@ -1714,7 +1722,7 @@ def do_tojson( def prepare_map( - context: "Context", args: t.Tuple, kwargs: t.Dict[str, t.Any] + context: "Context", args: t.Tuple[t.Any, ...], kwargs: t.Dict[str, t.Any] ) -> t.Callable[[t.Any], t.Any]: if not args and "attribute" in kwargs: attribute = kwargs.pop("attribute") @@ -1743,7 +1751,7 @@ def prepare_map( def prepare_select_or_reject( context: "Context", - args: t.Tuple, + args: t.Tuple[t.Any, ...], kwargs: t.Dict[str, t.Any], modfunc: t.Callable[[t.Any], t.Any], lookup_attr: bool, @@ -1767,7 +1775,7 @@ def prepare_select_or_reject( args = args[1 + off :] def func(item: t.Any) -> t.Any: - return context.environment.call_test(name, item, args, kwargs) + return context.environment.call_test(name, item, args, kwargs, context) except LookupError: func = bool # type: ignore @@ -1778,7 +1786,7 @@ def prepare_select_or_reject( def select_or_reject( context: "Context", value: "t.Iterable[V]", - args: t.Tuple, + args: t.Tuple[t.Any, ...], kwargs: t.Dict[str, t.Any], modfunc: t.Callable[[t.Any], t.Any], lookup_attr: bool, @@ -1794,7 +1802,7 @@ def select_or_reject( async def async_select_or_reject( context: "Context", value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]", - args: t.Tuple, + args: t.Tuple[t.Any, ...], kwargs: t.Dict[str, t.Any], modfunc: t.Callable[[t.Any], t.Any], lookup_attr: bool, diff --git a/idtracking.py b/idtracking.py index 995ebaa..e6dd8cd 100755 --- a/idtracking.py +++ b/idtracking.py @@ -3,6 +3,9 @@ import typing as t from . import nodes from .visitor import NodeVisitor +if t.TYPE_CHECKING: + import typing_extensions as te + VAR_LOAD_PARAMETER = "param" VAR_LOAD_RESOLVE = "resolve" VAR_LOAD_ALIAS = "alias" @@ -83,7 +86,7 @@ class Symbols: ) return rv - def copy(self) -> "Symbols": + def copy(self) -> "te.Self": rv = object.__new__(self.__class__) rv.__dict__.update(self.__dict__) rv.refs = self.refs.copy() @@ -118,23 +121,20 @@ class Symbols: self._define_ref(name, load=(VAR_LOAD_RESOLVE, name)) def branch_update(self, branch_symbols: t.Sequence["Symbols"]) -> None: - stores: t.Dict[str, int] = {} + stores: t.Set[str] = set() + for branch in branch_symbols: - for target in branch.stores: - if target in self.stores: - continue - stores[target] = stores.get(target, 0) + 1 + stores.update(branch.stores) + + stores.difference_update(self.stores) for sym in branch_symbols: self.refs.update(sym.refs) self.loads.update(sym.loads) self.stores.update(sym.stores) - for name, branch_count in stores.items(): - if branch_count == len(branch_symbols): - continue - - target = self.find_ref(name) # type: ignore + for name in stores: + target = self.find_ref(name) assert target is not None, "should not happen" if self.parent is not None: @@ -146,7 +146,7 @@ class Symbols: def dump_stores(self) -> t.Dict[str, str]: rv: t.Dict[str, str] = {} - node: t.Optional["Symbols"] = self + node: t.Optional[Symbols] = self while node is not None: for name in sorted(node.stores): @@ -159,7 +159,7 @@ class Symbols: def dump_param_targets(self) -> t.Set[str]: rv = set() - node: t.Optional["Symbols"] = self + node: t.Optional[Symbols] = self while node is not None: for target, (instr, _) in self.loads.items(): diff --git a/lexer.py b/lexer.py index aff7e9f..9b1c969 100755 --- a/lexer.py +++ b/lexer.py @@ -3,6 +3,7 @@ is used to do some preprocessing. It filters out invalid operators like the bitshift operators we don't allow in templates. It separates template code and python code in expressions. """ + import re import typing as t from ast import literal_eval @@ -15,6 +16,7 @@ from .utils import LRUCache if t.TYPE_CHECKING: import typing_extensions as te + from .environment import Environment # cache for the lexers. Exists in order to be able to have multiple @@ -260,7 +262,7 @@ class Failure: self.message = message self.error_class = cls - def __call__(self, lineno: int, filename: str) -> "te.NoReturn": + def __call__(self, lineno: int, filename: t.Optional[str]) -> "te.NoReturn": raise self.error_class(self.message, lineno, filename) @@ -327,7 +329,7 @@ class TokenStream: filename: t.Optional[str], ): self._iter = iter(generator) - self._pushed: "te.Deque[Token]" = deque() + self._pushed: te.Deque[Token] = deque() self.name = name self.filename = filename self.closed = False @@ -447,7 +449,7 @@ def get_lexer(environment: "Environment") -> "Lexer": return lexer -class OptionalLStrip(tuple): +class OptionalLStrip(tuple): # type: ignore[type-arg] """A special tuple for marking a point in the state that can have lstrip applied. """ @@ -755,7 +757,7 @@ class Lexer: for idx, token in enumerate(tokens): # failure group - if token.__class__ is Failure: + if isinstance(token, Failure): raise token(lineno, filename) # bygroup is a bit more complex, in that case we # yield for the current token the first named @@ -776,7 +778,7 @@ class Lexer: data = groups[idx] if data or token not in ignore_if_empty: - yield lineno, token, data + yield lineno, token, data # type: ignore[misc] lineno += data.count("\n") + newlines_stripped newlines_stripped = 0 diff --git a/loaders.py b/loaders.py index 32f3a74..3913ee5 100755 --- a/loaders.py +++ b/loaders.py @@ -1,6 +1,7 @@ """API and implementations for loading templates from different data sources. """ + import importlib.util import os import posixpath @@ -177,7 +178,9 @@ class FileSystemLoader(BaseLoader): def __init__( self, - searchpath: t.Union[str, os.PathLike, t.Sequence[t.Union[str, os.PathLike]]], + searchpath: t.Union[ + str, "os.PathLike[str]", t.Sequence[t.Union[str, "os.PathLike[str]"]] + ], encoding: str = "utf-8", followlinks: bool = False, ) -> None: @@ -201,7 +204,12 @@ class FileSystemLoader(BaseLoader): if os.path.isfile(filename): break else: - raise TemplateNotFound(template) + plural = "path" if len(self.searchpath) == 1 else "paths" + paths_str = ", ".join(repr(p) for p in self.searchpath) + raise TemplateNotFound( + template, + f"{template!r} not found in search {plural}: {paths_str}", + ) with open(filename, encoding=self.encoding) as f: contents = f.read() @@ -235,6 +243,30 @@ class FileSystemLoader(BaseLoader): return sorted(found) +if sys.version_info >= (3, 13): + + def _get_zipimporter_files(z: t.Any) -> t.Dict[str, object]: + try: + get_files = z._get_files + except AttributeError as e: + raise TypeError( + "This zip import does not have the required" + " metadata to list templates." + ) from e + return get_files() +else: + + def _get_zipimporter_files(z: t.Any) -> t.Dict[str, object]: + try: + files = z._files + except AttributeError as e: + raise TypeError( + "This zip import does not have the required" + " metadata to list templates." + ) from e + return files # type: ignore[no-any-return] + + class PackageLoader(BaseLoader): """Load templates from a directory in a Python package. @@ -295,7 +327,6 @@ class PackageLoader(BaseLoader): assert loader is not None, "A loader was not found for the package." self._loader = loader self._archive = None - template_root = None if isinstance(loader, zipimport.zipimporter): self._archive = loader.archive @@ -312,18 +343,23 @@ class PackageLoader(BaseLoader): elif spec.origin is not None: roots.append(os.path.dirname(spec.origin)) + if not roots: + raise ValueError( + f"The {package_name!r} package was not installed in a" + " way that PackageLoader understands." + ) + for root in roots: root = os.path.join(root, package_path) if os.path.isdir(root): template_root = root break - - if template_root is None: - raise ValueError( - f"The {package_name!r} package was not installed in a" - " way that PackageLoader understands." - ) + else: + raise ValueError( + f"PackageLoader could not find a {package_path!r} directory" + f" in the {package_name!r} package." + ) self._template_root = template_root @@ -379,11 +415,7 @@ class PackageLoader(BaseLoader): for name in filenames ) else: - if not hasattr(self._loader, "_files"): - raise TypeError( - "This zip import does not have the required" - " metadata to list templates." - ) + files = _get_zipimporter_files(self._loader) # Package is a zip file. prefix = ( @@ -392,7 +424,7 @@ class PackageLoader(BaseLoader): ) offset = len(prefix) - for name in self._loader._files.keys(): + for name in files: # Find names under the templates directory that aren't directories. if name.startswith(prefix) and name[-1] != os.path.sep: results.append(name[offset:].replace(os.path.sep, "/")) @@ -407,7 +439,7 @@ class DictLoader(BaseLoader): >>> loader = DictLoader({'index.html': 'source here'}) - Because auto reloading is rarely useful this is disabled per default. + Because auto reloading is rarely useful this is disabled by default. """ def __init__(self, mapping: t.Mapping[str, str]) -> None: @@ -590,10 +622,7 @@ class ModuleLoader(BaseLoader): Example usage: - >>> loader = ChoiceLoader([ - ... ModuleLoader('/path/to/compiled/templates'), - ... FileSystemLoader('/path/to/templates') - ... ]) + >>> loader = ModuleLoader('/path/to/compiled/templates') Templates can be precompiled with :meth:`Environment.compile_templates`. """ @@ -601,7 +630,10 @@ class ModuleLoader(BaseLoader): has_source_access = False def __init__( - self, path: t.Union[str, os.PathLike, t.Sequence[t.Union[str, os.PathLike]]] + self, + path: t.Union[ + str, "os.PathLike[str]", t.Sequence[t.Union[str, "os.PathLike[str]"]] + ], ) -> None: package_name = f"_jinja2_module_templates_{id(self):x}" diff --git a/meta.py b/meta.py index 0057d6e..298499e 100755 --- a/meta.py +++ b/meta.py @@ -1,6 +1,7 @@ """Functions that expose information about templates that might be interesting for introspection. """ + import typing as t from . import nodes diff --git a/nodes.py b/nodes.py index b2f88d9..2f93b90 100755 --- a/nodes.py +++ b/nodes.py @@ -2,6 +2,7 @@ some node tree helper functions used by the parser and compiler in order to normalize nodes. """ + import inspect import operator import typing as t @@ -13,6 +14,7 @@ from .utils import _PassArg if t.TYPE_CHECKING: import typing_extensions as te + from .environment import Environment _NodeBound = t.TypeVar("_NodeBound", bound="Node") @@ -56,7 +58,7 @@ class NodeType(type): def __new__(mcs, name, bases, d): # type: ignore for attr in "fields", "attributes": - storage = [] + storage: t.List[t.Tuple[str, ...]] = [] storage.extend(getattr(bases[0] if bases else object, attr, ())) storage.extend(d.get(attr, ())) assert len(bases) <= 1, "multiple inheritance not allowed" diff --git a/optimizer.py b/optimizer.py index fe10107..32d1c71 100755 --- a/optimizer.py +++ b/optimizer.py @@ -7,6 +7,7 @@ want. For example, loop unrolling doesn't work because unrolled loops would have a different scope. The solution would be a second syntax tree that stored the scoping rules. """ + import typing as t from . import nodes diff --git a/parser.py b/parser.py index 3354bc9..f411775 100755 --- a/parser.py +++ b/parser.py @@ -1,4 +1,5 @@ """Parse tokens from the lexer into nodes for the compiler.""" + import typing import typing as t @@ -10,6 +11,7 @@ from .lexer import describe_token_expr if t.TYPE_CHECKING: import typing_extensions as te + from .environment import Environment _ImportInclude = t.TypeVar("_ImportInclude", nodes.Import, nodes.Include) @@ -62,7 +64,7 @@ class Parser: self.filename = filename self.closed = False self.extensions: t.Dict[ - str, t.Callable[["Parser"], t.Union[nodes.Node, t.List[nodes.Node]]] + str, t.Callable[[Parser], t.Union[nodes.Node, t.List[nodes.Node]]] ] = {} for extension in environment.iter_extensions(): for tag in extension.tags: @@ -457,8 +459,7 @@ class Parser: @typing.overload def parse_assign_target( self, with_tuple: bool = ..., name_only: "te.Literal[True]" = ... - ) -> nodes.Name: - ... + ) -> nodes.Name: ... @typing.overload def parse_assign_target( @@ -467,8 +468,7 @@ class Parser: name_only: bool = False, extra_end_rules: t.Optional[t.Tuple[str, ...]] = None, with_namespace: bool = False, - ) -> t.Union[nodes.NSRef, nodes.Name, nodes.Tuple]: - ... + ) -> t.Union[nodes.NSRef, nodes.Name, nodes.Tuple]: ... def parse_assign_target( self, @@ -487,21 +487,18 @@ class Parser: """ target: nodes.Expr - if with_namespace and self.stream.look().type == "dot": - token = self.stream.expect("name") - next(self.stream) # dot - attr = self.stream.expect("name") - target = nodes.NSRef(token.value, attr.value, lineno=token.lineno) - elif name_only: + if name_only: token = self.stream.expect("name") target = nodes.Name(token.value, "store", lineno=token.lineno) else: if with_tuple: target = self.parse_tuple( - simplified=True, extra_end_rules=extra_end_rules + simplified=True, + extra_end_rules=extra_end_rules, + with_namespace=with_namespace, ) else: - target = self.parse_primary() + target = self.parse_primary(with_namespace=with_namespace) target.set_ctx("store") @@ -643,17 +640,25 @@ class Parser: node = self.parse_filter_expr(node) return node - def parse_primary(self) -> nodes.Expr: + def parse_primary(self, with_namespace: bool = False) -> nodes.Expr: + """Parse a name or literal value. If ``with_namespace`` is enabled, also + parse namespace attr refs, for use in assignments.""" token = self.stream.current node: nodes.Expr if token.type == "name": + next(self.stream) if token.value in ("true", "false", "True", "False"): node = nodes.Const(token.value in ("true", "True"), lineno=token.lineno) elif token.value in ("none", "None"): node = nodes.Const(None, lineno=token.lineno) + elif with_namespace and self.stream.current.type == "dot": + # If namespace attributes are allowed at this point, and the next + # token is a dot, produce a namespace reference. + next(self.stream) + attr = self.stream.expect("name") + node = nodes.NSRef(token.value, attr.value, lineno=token.lineno) else: node = nodes.Name(token.value, "load", lineno=token.lineno) - next(self.stream) elif token.type == "string": next(self.stream) buf = [token.value] @@ -683,6 +688,7 @@ class Parser: with_condexpr: bool = True, extra_end_rules: t.Optional[t.Tuple[str, ...]] = None, explicit_parentheses: bool = False, + with_namespace: bool = False, ) -> t.Union[nodes.Tuple, nodes.Expr]: """Works like `parse_expression` but if multiple expressions are delimited by a comma a :class:`~jinja2.nodes.Tuple` node is created. @@ -690,8 +696,9 @@ class Parser: if no commas where found. The default parsing mode is a full tuple. If `simplified` is `True` - only names and literals are parsed. The `no_condexpr` parameter is - forwarded to :meth:`parse_expression`. + only names and literals are parsed; ``with_namespace`` allows namespace + attr refs as well. The `no_condexpr` parameter is forwarded to + :meth:`parse_expression`. Because tuples do not require delimiters and may end in a bogus comma an extra hint is needed that marks the end of a tuple. For example @@ -704,13 +711,14 @@ class Parser: """ lineno = self.stream.current.lineno if simplified: - parse = self.parse_primary - elif with_condexpr: - parse = self.parse_expression + + def parse() -> nodes.Expr: + return self.parse_primary(with_namespace=with_namespace) + else: def parse() -> nodes.Expr: - return self.parse_expression(with_condexpr=False) + return self.parse_expression(with_condexpr=with_condexpr) args: t.List[nodes.Expr] = [] is_tuple = False @@ -861,7 +869,14 @@ class Parser: return nodes.Slice(lineno=lineno, *args) # noqa: B026 - def parse_call_args(self) -> t.Tuple: + def parse_call_args( + self, + ) -> t.Tuple[ + t.List[nodes.Expr], + t.List[nodes.Keyword], + t.Optional[nodes.Expr], + t.Optional[nodes.Expr], + ]: token = self.stream.expect("lparen") args = [] kwargs = [] @@ -952,7 +967,7 @@ class Parser: next(self.stream) name += "." + self.stream.expect("name").value dyn_args = dyn_kwargs = None - kwargs = [] + kwargs: t.List[nodes.Keyword] = [] if self.stream.current.type == "lparen": args, kwargs, dyn_args, dyn_kwargs = self.parse_call_args() elif self.stream.current.type in { diff --git a/runtime.py b/runtime.py index 58a540b..09119e2 100755 --- a/runtime.py +++ b/runtime.py @@ -1,4 +1,5 @@ """The runtime functions and state used by compiled templates.""" + import functools import sys import typing as t @@ -28,7 +29,9 @@ F = t.TypeVar("F", bound=t.Callable[..., t.Any]) if t.TYPE_CHECKING: import logging + import typing_extensions as te + from .environment import Environment class LoopRenderFunc(te.Protocol): @@ -37,8 +40,7 @@ if t.TYPE_CHECKING: reciter: t.Iterable[V], loop_render_func: "LoopRenderFunc", depth: int = 0, - ) -> str: - ... + ) -> str: ... # these variables are exported to the template runtime @@ -170,7 +172,7 @@ class Context: ): self.parent = parent self.vars: t.Dict[str, t.Any] = {} - self.environment: "Environment" = environment + self.environment: Environment = environment self.eval_ctx = EvalContext(self.environment, name) self.exported_vars: t.Set[str] = set() self.name = name @@ -259,7 +261,10 @@ class Context: @internalcode def call( - __self, __obj: t.Callable, *args: t.Any, **kwargs: t.Any # noqa: B902 + __self, + __obj: t.Callable[..., t.Any], + *args: t.Any, + **kwargs: t.Any, # noqa: B902 ) -> t.Union[t.Any, "Undefined"]: """Call the callable with the arguments and keyword arguments provided but inject the active context or environment as first @@ -362,7 +367,7 @@ class BlockReference: @internalcode async def _async_call(self) -> str: - rv = concat( + rv = self._context.environment.concat( # type: ignore [x async for x in self._stack[self._depth](self._context)] # type: ignore ) @@ -376,7 +381,9 @@ class BlockReference: if self._context.environment.is_async: return self._async_call() # type: ignore - rv = concat(self._stack[self._depth](self._context)) + rv = self._context.environment.concat( # type: ignore + self._stack[self._depth](self._context) + ) if self._context.eval_ctx.autoescape: return Markup(rv) @@ -586,7 +593,7 @@ class AsyncLoopContext(LoopContext): @staticmethod def _to_iterator( # type: ignore - iterable: t.Union[t.Iterable[V], t.AsyncIterable[V]] + iterable: t.Union[t.Iterable[V], t.AsyncIterable[V]], ) -> t.AsyncIterator[V]: return auto_aiter(iterable) @@ -787,8 +794,8 @@ class Macro: class Undefined: - """The default undefined type. This undefined type can be printed and - iterated over, but every other access will raise an :exc:`UndefinedError`: + """The default undefined type. This can be printed, iterated, and treated as + a boolean. Any other operation will raise an :exc:`UndefinedError`. >>> foo = Undefined(name='foo') >>> str(foo) @@ -853,7 +860,11 @@ class Undefined: @internalcode def __getattr__(self, name: str) -> t.Any: - if name[:2] == "__": + # Raise AttributeError on requests for names that appear to be unimplemented + # dunder methods to keep Python's internal protocol probing behaviors working + # properly in cases where another exception type could cause unexpected or + # difficult-to-diagnose failures. + if name[:2] == "__" and name[-2:] == "__": raise AttributeError(name) return self._fail_with_undefined_error() @@ -977,10 +988,20 @@ class ChainableUndefined(Undefined): def __html__(self) -> str: return str(self) - def __getattr__(self, _: str) -> "ChainableUndefined": + def __getattr__(self, name: str) -> "ChainableUndefined": + # Raise AttributeError on requests for names that appear to be unimplemented + # dunder methods to avoid confusing Python with truthy non-method objects that + # do not implement the protocol being probed for. e.g., copy.copy(Undefined()) + # fails spectacularly if getattr(Undefined(), '__setstate__') returns an + # Undefined object instead of raising AttributeError to signal that it does not + # support that style of object initialization. + if name[:2] == "__" and name[-2:] == "__": + raise AttributeError(name) + return self - __getitem__ = __getattr__ # type: ignore + def __getitem__(self, _name: str) -> "ChainableUndefined": # type: ignore[override] + return self class DebugUndefined(Undefined): @@ -1039,13 +1060,3 @@ class StrictUndefined(Undefined): __iter__ = __str__ = __len__ = Undefined._fail_with_undefined_error __eq__ = __ne__ = __bool__ = __hash__ = Undefined._fail_with_undefined_error __contains__ = Undefined._fail_with_undefined_error - - -# Remove slots attributes, after the metaclass is applied they are -# unneeded and contain wrong data for subclasses. -del ( - Undefined.__slots__, - ChainableUndefined.__slots__, - DebugUndefined.__slots__, - StrictUndefined.__slots__, -) diff --git a/sandbox.py b/sandbox.py index 06d7414..9c9dae2 100755 --- a/sandbox.py +++ b/sandbox.py @@ -1,12 +1,14 @@ """A sandbox layer that ensures unsafe operations cannot be performed. Useful when the template itself comes from an untrusted source. """ + import operator import types import typing as t from _string import formatter_field_name_split # type: ignore from collections import abc from collections import deque +from functools import update_wrapper from string import Formatter from markupsafe import EscapeFormatter @@ -37,7 +39,7 @@ UNSAFE_COROUTINE_ATTRIBUTES = {"cr_frame", "cr_code"} #: unsafe attributes on async generators UNSAFE_ASYNC_GENERATOR_ATTRIBUTES = {"ag_code", "ag_frame"} -_mutable_spec: t.Tuple[t.Tuple[t.Type, t.FrozenSet[str]], ...] = ( +_mutable_spec: t.Tuple[t.Tuple[t.Type[t.Any], t.FrozenSet[str]], ...] = ( ( abc.MutableSet, frozenset( @@ -59,7 +61,9 @@ _mutable_spec: t.Tuple[t.Tuple[t.Type, t.FrozenSet[str]], ...] = ( ), ( abc.MutableSequence, - frozenset(["append", "reverse", "insert", "sort", "extend", "remove"]), + frozenset( + ["append", "clear", "pop", "reverse", "insert", "sort", "extend", "remove"] + ), ), ( deque, @@ -80,20 +84,6 @@ _mutable_spec: t.Tuple[t.Tuple[t.Type, t.FrozenSet[str]], ...] = ( ) -def inspect_format_method(callable: t.Callable) -> t.Optional[str]: - if not isinstance( - callable, (types.MethodType, types.BuiltinMethodType) - ) or callable.__name__ not in ("format", "format_map"): - return None - - obj = callable.__self__ - - if isinstance(obj, str): - return obj - - return None - - def safe_range(*args: int) -> range: """A range that can't generate ranges with a length of more than MAX_RANGE items. @@ -313,6 +303,9 @@ class SandboxedEnvironment(Environment): except AttributeError: pass else: + fmt = self.wrap_str_format(value) + if fmt is not None: + return fmt if self.is_safe_attribute(obj, argument, value): return value return self.unsafe_undefined(obj, argument) @@ -330,6 +323,9 @@ class SandboxedEnvironment(Environment): except (TypeError, LookupError): pass else: + fmt = self.wrap_str_format(value) + if fmt is not None: + return fmt if self.is_safe_attribute(obj, attribute, value): return value return self.unsafe_undefined(obj, attribute) @@ -345,34 +341,49 @@ class SandboxedEnvironment(Environment): exc=SecurityError, ) - def format_string( - self, - s: str, - args: t.Tuple[t.Any, ...], - kwargs: t.Dict[str, t.Any], - format_func: t.Optional[t.Callable] = None, - ) -> str: - """If a format call is detected, then this is routed through this - method so that our safety sandbox can be used for it. + def wrap_str_format(self, value: t.Any) -> t.Optional[t.Callable[..., str]]: + """If the given value is a ``str.format`` or ``str.format_map`` method, + return a new function than handles sandboxing. This is done at access + rather than in :meth:`call`, so that calls made without ``call`` are + also sandboxed. """ + if not isinstance( + value, (types.MethodType, types.BuiltinMethodType) + ) or value.__name__ not in ("format", "format_map"): + return None + + f_self: t.Any = value.__self__ + + if not isinstance(f_self, str): + return None + + str_type: t.Type[str] = type(f_self) + is_format_map = value.__name__ == "format_map" formatter: SandboxedFormatter - if isinstance(s, Markup): - formatter = SandboxedEscapeFormatter(self, escape=s.escape) + + if isinstance(f_self, Markup): + formatter = SandboxedEscapeFormatter(self, escape=f_self.escape) else: formatter = SandboxedFormatter(self) - if format_func is not None and format_func.__name__ == "format_map": - if len(args) != 1 or kwargs: - raise TypeError( - "format_map() takes exactly one argument" - f" {len(args) + (kwargs is not None)} given" - ) + vformat = formatter.vformat - kwargs = args[0] - args = () + def wrapper(*args: t.Any, **kwargs: t.Any) -> str: + if is_format_map: + if kwargs: + raise TypeError("format_map() takes no keyword arguments") - rv = formatter.vformat(s, args, kwargs) - return type(s)(rv) + if len(args) != 1: + raise TypeError( + f"format_map() takes exactly one argument ({len(args)} given)" + ) + + kwargs = args[0] + args = () + + return str_type(vformat(f_self, args, kwargs)) + + return update_wrapper(wrapper, value) def call( __self, # noqa: B902 @@ -382,9 +393,6 @@ class SandboxedEnvironment(Environment): **kwargs: t.Any, ) -> t.Any: """Call an object from sandboxed code.""" - fmt = inspect_format_method(__obj) - if fmt is not None: - return __self.format_string(fmt, args, kwargs, __obj) # the double prefixes are to avoid double keyword argument # errors when proxying the call. diff --git a/tests.py b/tests.py index a467cf0..1a59e37 100755 --- a/tests.py +++ b/tests.py @@ -1,4 +1,5 @@ """Built-in template tests used with the ``is`` operator.""" + import operator import typing as t from collections import abc @@ -169,7 +170,7 @@ def test_sequence(value: t.Any) -> bool: """ try: len(value) - value.__getitem__ + value.__getitem__ # noqa B018 except Exception: return False @@ -204,7 +205,7 @@ def test_escaped(value: t.Any) -> bool: return hasattr(value, "__html__") -def test_in(value: t.Any, seq: t.Container) -> bool: +def test_in(value: t.Any, seq: t.Container[t.Any]) -> bool: """Check if value is in seq. .. versionadded:: 2.10 diff --git a/utils.py b/utils.py index 18914a5..7c92262 100755 --- a/utils.py +++ b/utils.py @@ -18,8 +18,17 @@ if t.TYPE_CHECKING: F = t.TypeVar("F", bound=t.Callable[..., t.Any]) -# special singleton representing missing values for the runtime -missing: t.Any = type("MissingType", (), {"__repr__": lambda x: "missing"})() + +class _MissingType: + def __repr__(self) -> str: + return "missing" + + def __reduce__(self) -> str: + return "missing" + + +missing: t.Any = _MissingType() +"""Special singleton representing missing values for the runtime.""" internal_code: t.MutableSet[CodeType] = set() @@ -152,7 +161,7 @@ def import_string(import_name: str, silent: bool = False) -> t.Any: raise -def open_if_exists(filename: str, mode: str = "rb") -> t.Optional[t.IO]: +def open_if_exists(filename: str, mode: str = "rb") -> t.Optional[t.IO[t.Any]]: """Returns a file descriptor for the filename if that file exists, otherwise ``None``. """ @@ -324,6 +333,8 @@ def urlize( elif ( "@" in middle and not middle.startswith("www.") + # ignore values like `@a@b` + and not middle.startswith("@") and ":" not in middle and _email_re.match(middle) ): @@ -428,7 +439,7 @@ class LRUCache: def __init__(self, capacity: int) -> None: self.capacity = capacity self._mapping: t.Dict[t.Any, t.Any] = {} - self._queue: "te.Deque[t.Any]" = deque() + self._queue: te.Deque[t.Any] = deque() self._postinit() def _postinit(self) -> None: @@ -450,10 +461,10 @@ class LRUCache: self.__dict__.update(d) self._postinit() - def __getnewargs__(self) -> t.Tuple: + def __getnewargs__(self) -> t.Tuple[t.Any, ...]: return (self.capacity,) - def copy(self) -> "LRUCache": + def copy(self) -> "te.Self": """Return a shallow copy of the instance.""" rv = self.__class__(self.capacity) rv._mapping.update(self._mapping) diff --git a/visitor.py b/visitor.py index 17c6aab..7b8e180 100755 --- a/visitor.py +++ b/visitor.py @@ -1,6 +1,7 @@ """API for traversing the AST nodes. Implemented by the compiler and meta introspection. """ + import typing as t from .nodes import Node @@ -9,8 +10,7 @@ if t.TYPE_CHECKING: import typing_extensions as te class VisitCallable(te.Protocol): - def __call__(self, node: Node, *args: t.Any, **kwargs: t.Any) -> t.Any: - ... + def __call__(self, node: Node, *args: t.Any, **kwargs: t.Any) -> t.Any: ... class NodeVisitor: