diff --git a/__init__.py b/__init__.py
index af5d428..1a423a3 100755
--- a/__init__.py
+++ b/__init__.py
@@ -2,6 +2,7 @@
non-XML syntax that supports inline expressions and an optional
sandboxed environment.
"""
+
from .bccache import BytecodeCache as BytecodeCache
from .bccache import FileSystemBytecodeCache as FileSystemBytecodeCache
from .bccache import MemcachedBytecodeCache as MemcachedBytecodeCache
@@ -34,4 +35,4 @@ from .utils import pass_environment as pass_environment
from .utils import pass_eval_context as pass_eval_context
from .utils import select_autoescape as select_autoescape
-__version__ = "3.1.3"
+__version__ = "3.1.6"
diff --git a/async_utils.py b/async_utils.py
index 715d701..f0c1402 100644
--- a/async_utils.py
+++ b/async_utils.py
@@ -6,6 +6,9 @@ from functools import wraps
from .utils import _PassArg
from .utils import pass_eval_context
+if t.TYPE_CHECKING:
+ import typing_extensions as te
+
V = t.TypeVar("V")
@@ -47,7 +50,7 @@ def async_variant(normal_func): # type: ignore
if need_eval_context:
wrapper = pass_eval_context(wrapper)
- wrapper.jinja_async_variant = True
+ wrapper.jinja_async_variant = True # type: ignore[attr-defined]
return wrapper
return decorator
@@ -64,18 +67,30 @@ async def auto_await(value: t.Union[t.Awaitable["V"], "V"]) -> "V":
if inspect.isawaitable(value):
return await t.cast("t.Awaitable[V]", value)
- return t.cast("V", value)
+ return value
-async def auto_aiter(
+class _IteratorToAsyncIterator(t.Generic[V]):
+ def __init__(self, iterator: "t.Iterator[V]"):
+ self._iterator = iterator
+
+ def __aiter__(self) -> "te.Self":
+ return self
+
+ async def __anext__(self) -> V:
+ try:
+ return next(self._iterator)
+ except StopIteration as e:
+ raise StopAsyncIteration(e.value) from e
+
+
+def auto_aiter(
iterable: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
) -> "t.AsyncIterator[V]":
if hasattr(iterable, "__aiter__"):
- async for item in t.cast("t.AsyncIterable[V]", iterable):
- yield item
+ return iterable.__aiter__()
else:
- for item in iterable:
- yield item
+ return _IteratorToAsyncIterator(iter(iterable))
async def auto_to_list(
diff --git a/bccache.py b/bccache.py
index d0ddf56..ada8b09 100755
--- a/bccache.py
+++ b/bccache.py
@@ -5,6 +5,7 @@ slows down your application too much.
Situations where this is useful are often forking web applications that
are initialized on the first request.
"""
+
import errno
import fnmatch
import marshal
@@ -20,14 +21,15 @@ from types import CodeType
if t.TYPE_CHECKING:
import typing_extensions as te
+
from .environment import Environment
class _MemcachedClient(te.Protocol):
- def get(self, key: str) -> bytes:
- ...
+ def get(self, key: str) -> bytes: ...
- def set(self, key: str, value: bytes, timeout: t.Optional[int] = None) -> None:
- ...
+ def set(
+ self, key: str, value: bytes, timeout: t.Optional[int] = None
+ ) -> None: ...
bc_version = 5
diff --git a/compiler.py b/compiler.py
index ff95c80..a4ff6a1 100755
--- a/compiler.py
+++ b/compiler.py
@@ -1,4 +1,5 @@
"""Compiles nodes from the parser into Python code."""
+
import typing as t
from contextlib import contextmanager
from functools import update_wrapper
@@ -24,6 +25,7 @@ from .visitor import NodeVisitor
if t.TYPE_CHECKING:
import typing_extensions as te
+
from .environment import Environment
F = t.TypeVar("F", bound=t.Callable[..., t.Any])
@@ -53,15 +55,14 @@ def optimizeconst(f: F) -> F:
return f(self, node, frame, **kwargs)
- return update_wrapper(t.cast(F, new_func), f)
+ return update_wrapper(new_func, f) # type: ignore[return-value]
def _make_binop(op: str) -> t.Callable[["CodeGenerator", nodes.BinExpr, "Frame"], None]:
@optimizeconst
def visitor(self: "CodeGenerator", node: nodes.BinExpr, frame: Frame) -> None:
if (
- self.environment.sandboxed
- and op in self.environment.intercepted_binops # type: ignore
+ self.environment.sandboxed and op in self.environment.intercepted_binops # type: ignore
):
self.write(f"environment.call_binop(context, {op!r}, ")
self.visit(node.left, frame)
@@ -84,8 +85,7 @@ def _make_unop(
@optimizeconst
def visitor(self: "CodeGenerator", node: nodes.UnaryExpr, frame: Frame) -> None:
if (
- self.environment.sandboxed
- and op in self.environment.intercepted_unops # type: ignore
+ self.environment.sandboxed and op in self.environment.intercepted_unops # type: ignore
):
self.write(f"environment.call_unop(context, {op!r}, ")
self.visit(node.node, frame)
@@ -133,7 +133,7 @@ def has_safe_repr(value: t.Any) -> bool:
if type(value) in {tuple, list, set, frozenset}:
return all(has_safe_repr(v) for v in value)
- if type(value) is dict:
+ if type(value) is dict: # noqa E721
return all(has_safe_repr(k) and has_safe_repr(v) for k, v in value.items())
return False
@@ -216,7 +216,7 @@ class Frame:
# or compile time.
self.soft_frame = False
- def copy(self) -> "Frame":
+ def copy(self) -> "te.Self":
"""Create a copy of the current one."""
rv = object.__new__(self.__class__)
rv.__dict__.update(self.__dict__)
@@ -229,7 +229,7 @@ class Frame:
return Frame(self.eval_ctx, level=self.symbols.level + 1)
return Frame(self.eval_ctx, self)
- def soft(self) -> "Frame":
+ def soft(self) -> "te.Self":
"""Return a soft frame. A soft frame may not be modified as
standalone thing as it shares the resources with the frame it
was created of, but it's not a rootlevel frame any longer.
@@ -551,10 +551,13 @@ class CodeGenerator(NodeVisitor):
for node in nodes:
visitor.visit(node)
- for id_map, names, dependency in (self.filters, visitor.filters, "filters"), (
- self.tests,
- visitor.tests,
- "tests",
+ for id_map, names, dependency in (
+ (self.filters, visitor.filters, "filters"),
+ (
+ self.tests,
+ visitor.tests,
+ "tests",
+ ),
):
for name in sorted(names):
if name not in id_map:
@@ -808,7 +811,7 @@ class CodeGenerator(NodeVisitor):
self.writeline("_block_vars.update({")
else:
self.writeline("context.vars.update({")
- for idx, name in enumerate(vars):
+ for idx, name in enumerate(sorted(vars)):
if idx:
self.write(", ")
ref = frame.symbols.ref(name)
@@ -818,7 +821,7 @@ class CodeGenerator(NodeVisitor):
if len(public_names) == 1:
self.writeline(f"context.exported_vars.add({public_names[0]!r})")
else:
- names_str = ", ".join(map(repr, public_names))
+ names_str = ", ".join(map(repr, sorted(public_names)))
self.writeline(f"context.exported_vars.update(({names_str}))")
# -- Statement Visitors
@@ -829,7 +832,8 @@ class CodeGenerator(NodeVisitor):
assert frame is None, "no root frame allowed"
eval_ctx = EvalContext(self.environment, self.name)
- from .runtime import exported, async_exported
+ from .runtime import async_exported
+ from .runtime import exported
if self.environment.is_async:
exported_names = sorted(exported + async_exported)
@@ -898,12 +902,15 @@ class CodeGenerator(NodeVisitor):
if not self.environment.is_async:
self.writeline("yield from parent_template.root_render_func(context)")
else:
- self.writeline(
- "async for event in parent_template.root_render_func(context):"
- )
+ self.writeline("agen = parent_template.root_render_func(context)")
+ self.writeline("try:")
+ self.indent()
+ self.writeline("async for event in agen:")
self.indent()
self.writeline("yield event")
self.outdent()
+ self.outdent()
+ self.writeline("finally: await agen.aclose()")
self.outdent(1 + (not self.has_known_extends))
# at this point we now have the blocks collected and can visit them too.
@@ -973,14 +980,20 @@ class CodeGenerator(NodeVisitor):
f"yield from context.blocks[{node.name!r}][0]({context})", node
)
else:
+ self.writeline(f"gen = context.blocks[{node.name!r}][0]({context})")
+ self.writeline("try:")
+ self.indent()
self.writeline(
- f"{self.choose_async()}for event in"
- f" context.blocks[{node.name!r}][0]({context}):",
+ f"{self.choose_async()}for event in gen:",
node,
)
self.indent()
self.simple_write("event", frame)
self.outdent()
+ self.outdent()
+ self.writeline(
+ f"finally: {self.choose_async('await gen.aclose()', 'gen.close()')}"
+ )
self.outdent(level)
@@ -1053,26 +1066,33 @@ class CodeGenerator(NodeVisitor):
self.writeline("else:")
self.indent()
- skip_event_yield = False
+ def loop_body() -> None:
+ self.indent()
+ self.simple_write("event", frame)
+ self.outdent()
+
if node.with_context:
self.writeline(
- f"{self.choose_async()}for event in template.root_render_func("
+ f"gen = template.root_render_func("
"template.new_context(context.get_all(), True,"
- f" {self.dump_local_context(frame)})):"
+ f" {self.dump_local_context(frame)}))"
+ )
+ self.writeline("try:")
+ self.indent()
+ self.writeline(f"{self.choose_async()}for event in gen:")
+ loop_body()
+ self.outdent()
+ self.writeline(
+ f"finally: {self.choose_async('await gen.aclose()', 'gen.close()')}"
)
elif self.environment.is_async:
self.writeline(
"for event in (await template._get_default_module_async())"
"._body_stream:"
)
+ loop_body()
else:
self.writeline("yield from template._get_default_module()._body_stream")
- skip_event_yield = True
-
- if not skip_event_yield:
- self.indent()
- self.simple_write("event", frame)
- self.outdent()
if node.ignore_missing:
self.outdent()
@@ -1121,9 +1141,14 @@ class CodeGenerator(NodeVisitor):
)
self.writeline(f"if {frame.symbols.ref(alias)} is missing:")
self.indent()
+ # The position will contain the template name, and will be formatted
+ # into a string that will be compiled into an f-string. Curly braces
+ # in the name must be replaced with escapes so that they will not be
+ # executed as part of the f-string.
+ position = self.position(node).replace("{", "{{").replace("}", "}}")
message = (
"the template {included_template.__name__!r}"
- f" (imported on {self.position(node)})"
+ f" (imported on {position})"
f" does not export the requested name {name!r}"
)
self.writeline(
@@ -1556,6 +1581,29 @@ class CodeGenerator(NodeVisitor):
def visit_Assign(self, node: nodes.Assign, frame: Frame) -> None:
self.push_assign_tracking()
+
+ # ``a.b`` is allowed for assignment, and is parsed as an NSRef. However,
+ # it is only valid if it references a Namespace object. Emit a check for
+ # that for each ref here, before assignment code is emitted. This can't
+ # be done in visit_NSRef as the ref could be in the middle of a tuple.
+ seen_refs: t.Set[str] = set()
+
+ for nsref in node.find_all(nodes.NSRef):
+ if nsref.name in seen_refs:
+ # Only emit the check for each reference once, in case the same
+ # ref is used multiple times in a tuple, `ns.a, ns.b = c, d`.
+ continue
+
+ seen_refs.add(nsref.name)
+ ref = frame.symbols.ref(nsref.name)
+ self.writeline(f"if not isinstance({ref}, Namespace):")
+ self.indent()
+ self.writeline(
+ "raise TemplateRuntimeError"
+ '("cannot assign attribute on non-namespace object")'
+ )
+ self.outdent()
+
self.newline(node)
self.visit(node.target, frame)
self.write(" = ")
@@ -1612,17 +1660,11 @@ class CodeGenerator(NodeVisitor):
self.write(ref)
def visit_NSRef(self, node: nodes.NSRef, frame: Frame) -> None:
- # NSRefs can only be used to store values; since they use the normal
- # `foo.bar` notation they will be parsed as a normal attribute access
- # when used anywhere but in a `set` context
+ # NSRef is a dotted assignment target a.b=c, but uses a[b]=c internally.
+ # visit_Assign emits code to validate that each ref is to a Namespace
+ # object only. That can't be emitted here as the ref could be in the
+ # middle of a tuple assignment.
ref = frame.symbols.ref(node.name)
- self.writeline(f"if not isinstance({ref}, Namespace):")
- self.indent()
- self.writeline(
- "raise TemplateRuntimeError"
- '("cannot assign attribute on non-namespace object")'
- )
- self.outdent()
self.writeline(f"{ref}[{node.attr!r}]")
def visit_Const(self, node: nodes.Const, frame: Frame) -> None:
diff --git a/debug.py b/debug.py
index 7ed7e92..eeeeee7 100755
--- a/debug.py
+++ b/debug.py
@@ -152,7 +152,7 @@ def get_template_locals(real_locals: t.Mapping[str, t.Any]) -> t.Dict[str, t.Any
available at that point in the template.
"""
# Start with the current template context.
- ctx: "t.Optional[Context]" = real_locals.get("context")
+ ctx: t.Optional[Context] = real_locals.get("context")
if ctx is not None:
data: t.Dict[str, t.Any] = ctx.get_all().copy()
diff --git a/environment.py b/environment.py
index 185d332..0fc6e5b 100755
--- a/environment.py
+++ b/environment.py
@@ -1,6 +1,7 @@
"""Classes for managing templates and their runtime and compile time
options.
"""
+
import os
import typing
import typing as t
@@ -20,10 +21,10 @@ from .defaults import BLOCK_END_STRING
from .defaults import BLOCK_START_STRING
from .defaults import COMMENT_END_STRING
from .defaults import COMMENT_START_STRING
-from .defaults import DEFAULT_FILTERS
+from .defaults import DEFAULT_FILTERS # type: ignore[attr-defined]
from .defaults import DEFAULT_NAMESPACE
from .defaults import DEFAULT_POLICIES
-from .defaults import DEFAULT_TESTS
+from .defaults import DEFAULT_TESTS # type: ignore[attr-defined]
from .defaults import KEEP_TRAILING_NEWLINE
from .defaults import LINE_COMMENT_PREFIX
from .defaults import LINE_STATEMENT_PREFIX
@@ -55,6 +56,7 @@ from .utils import missing
if t.TYPE_CHECKING:
import typing_extensions as te
+
from .bccache import BytecodeCache
from .ext import Extension
from .loaders import BaseLoader
@@ -79,7 +81,7 @@ def get_spontaneous_environment(cls: t.Type[_env_bound], *args: t.Any) -> _env_b
def create_cache(
size: int,
-) -> t.Optional[t.MutableMapping[t.Tuple[weakref.ref, str], "Template"]]:
+) -> t.Optional[t.MutableMapping[t.Tuple["weakref.ref[t.Any]", str], "Template"]]:
"""Return the cache class for the given size."""
if size == 0:
return None
@@ -91,13 +93,13 @@ def create_cache(
def copy_cache(
- cache: t.Optional[t.MutableMapping],
-) -> t.Optional[t.MutableMapping[t.Tuple[weakref.ref, str], "Template"]]:
+ cache: t.Optional[t.MutableMapping[t.Any, t.Any]],
+) -> t.Optional[t.MutableMapping[t.Tuple["weakref.ref[t.Any]", str], "Template"]]:
"""Create an empty copy of the given cache."""
if cache is None:
return None
- if type(cache) is dict:
+ if type(cache) is dict: # noqa E721
return {}
return LRUCache(cache.capacity) # type: ignore
@@ -121,7 +123,7 @@ def load_extensions(
return result
-def _environment_config_check(environment: "Environment") -> "Environment":
+def _environment_config_check(environment: _env_bound) -> _env_bound:
"""Perform a sanity check on the environment."""
assert issubclass(
environment.undefined, Undefined
@@ -404,8 +406,8 @@ class Environment:
cache_size: int = missing,
auto_reload: bool = missing,
bytecode_cache: t.Optional["BytecodeCache"] = missing,
- enable_async: bool = False,
- ) -> "Environment":
+ enable_async: bool = missing,
+ ) -> "te.Self":
"""Create a new overlay environment that shares all the data with the
current environment except for cache and the overridden attributes.
Extensions cannot be removed for an overlayed environment. An overlayed
@@ -417,8 +419,11 @@ class Environment:
copied over so modifications on the original environment may not shine
through.
+ .. versionchanged:: 3.1.5
+ ``enable_async`` is applied correctly.
+
.. versionchanged:: 3.1.2
- Added the ``newline_sequence``,, ``keep_trailing_newline``,
+ Added the ``newline_sequence``, ``keep_trailing_newline``,
and ``enable_async`` parameters to match ``__init__``.
"""
args = dict(locals())
@@ -670,7 +675,7 @@ class Environment:
stream = ext.filter_stream(stream) # type: ignore
if not isinstance(stream, TokenStream):
- stream = TokenStream(stream, name, filename) # type: ignore
+ stream = TokenStream(stream, name, filename)
return stream
@@ -704,15 +709,14 @@ class Environment:
return compile(source, filename, "exec")
@typing.overload
- def compile( # type: ignore
+ def compile(
self,
source: t.Union[str, nodes.Template],
name: t.Optional[str] = None,
filename: t.Optional[str] = None,
raw: "te.Literal[False]" = False,
defer_init: bool = False,
- ) -> CodeType:
- ...
+ ) -> CodeType: ...
@typing.overload
def compile(
@@ -722,8 +726,7 @@ class Environment:
filename: t.Optional[str] = None,
raw: "te.Literal[True]" = ...,
defer_init: bool = False,
- ) -> str:
- ...
+ ) -> str: ...
@internalcode
def compile(
@@ -814,7 +817,7 @@ class Environment:
def compile_templates(
self,
- target: t.Union[str, os.PathLike],
+ target: t.Union[str, "os.PathLike[str]"],
extensions: t.Optional[t.Collection[str]] = None,
filter_func: t.Optional[t.Callable[[str], bool]] = None,
zip: t.Optional[str] = "deflated",
@@ -858,7 +861,10 @@ class Environment:
f.write(data.encode("utf8"))
if zip is not None:
- from zipfile import ZipFile, ZipInfo, ZIP_DEFLATED, ZIP_STORED
+ from zipfile import ZIP_DEFLATED
+ from zipfile import ZIP_STORED
+ from zipfile import ZipFile
+ from zipfile import ZipInfo
zip_file = ZipFile(
target, "w", dict(deflated=ZIP_DEFLATED, stored=ZIP_STORED)[zip]
@@ -1245,7 +1251,7 @@ class Template:
namespace: t.MutableMapping[str, t.Any],
globals: t.MutableMapping[str, t.Any],
) -> "Template":
- t: "Template" = object.__new__(cls)
+ t: Template = object.__new__(cls)
t.environment = environment
t.globals = globals
t.name = namespace["name"]
@@ -1279,19 +1285,7 @@ class Template:
if self.environment.is_async:
import asyncio
- close = False
-
- try:
- loop = asyncio.get_running_loop()
- except RuntimeError:
- loop = asyncio.new_event_loop()
- close = True
-
- try:
- return loop.run_until_complete(self.render_async(*args, **kwargs))
- finally:
- if close:
- loop.close()
+ return asyncio.run(self.render_async(*args, **kwargs))
ctx = self.new_context(dict(*args, **kwargs))
@@ -1355,7 +1349,7 @@ class Template:
async def generate_async(
self, *args: t.Any, **kwargs: t.Any
- ) -> t.AsyncIterator[str]:
+ ) -> t.AsyncGenerator[str, object]:
"""An async version of :meth:`generate`. Works very similarly but
returns an async iterator instead.
"""
@@ -1367,8 +1361,14 @@ class Template:
ctx = self.new_context(dict(*args, **kwargs))
try:
- async for event in self.root_render_func(ctx): # type: ignore
- yield event
+ agen = self.root_render_func(ctx)
+ try:
+ async for event in agen: # type: ignore
+ yield event
+ finally:
+ # we can't use async with aclosing(...) because that's only
+ # in 3.10+
+ await agen.aclose() # type: ignore
except Exception:
yield self.environment.handle_exception()
@@ -1417,7 +1417,9 @@ class Template:
"""
ctx = self.new_context(vars, shared, locals)
return TemplateModule(
- self, ctx, [x async for x in self.root_render_func(ctx)] # type: ignore
+ self,
+ ctx,
+ [x async for x in self.root_render_func(ctx)], # type: ignore
)
@internalcode
@@ -1588,7 +1590,7 @@ class TemplateStream:
def dump(
self,
- fp: t.Union[str, t.IO],
+ fp: t.Union[str, t.IO[bytes]],
encoding: t.Optional[str] = None,
errors: t.Optional[str] = "strict",
) -> None:
@@ -1606,22 +1608,25 @@ class TemplateStream:
if encoding is None:
encoding = "utf-8"
- fp = open(fp, "wb")
+ real_fp: t.IO[bytes] = open(fp, "wb")
close = True
+ else:
+ real_fp = fp
+
try:
if encoding is not None:
iterable = (x.encode(encoding, errors) for x in self) # type: ignore
else:
iterable = self # type: ignore
- if hasattr(fp, "writelines"):
- fp.writelines(iterable)
+ if hasattr(real_fp, "writelines"):
+ real_fp.writelines(iterable)
else:
for item in iterable:
- fp.write(item)
+ real_fp.write(item)
finally:
if close:
- fp.close()
+ real_fp.close()
def disable_buffering(self) -> None:
"""Disable the output buffering."""
diff --git a/ext.py b/ext.py
index fade1fa..c7af8d4 100755
--- a/ext.py
+++ b/ext.py
@@ -1,4 +1,5 @@
"""Extension API for adding custom tags and behavior."""
+
import pprint
import re
import typing as t
@@ -18,23 +19,23 @@ from .utils import pass_context
if t.TYPE_CHECKING:
import typing_extensions as te
+
from .lexer import Token
from .lexer import TokenStream
from .parser import Parser
class _TranslationsBasic(te.Protocol):
- def gettext(self, message: str) -> str:
- ...
+ def gettext(self, message: str) -> str: ...
def ngettext(self, singular: str, plural: str, n: int) -> str:
pass
class _TranslationsContext(_TranslationsBasic):
- def pgettext(self, context: str, message: str) -> str:
- ...
+ def pgettext(self, context: str, message: str) -> str: ...
- def npgettext(self, context: str, singular: str, plural: str, n: int) -> str:
- ...
+ def npgettext(
+ self, context: str, singular: str, plural: str, n: int
+ ) -> str: ...
_SupportedTranslations = t.Union[_TranslationsBasic, _TranslationsContext]
@@ -88,7 +89,7 @@ class Extension:
def __init__(self, environment: Environment) -> None:
self.environment = environment
- def bind(self, environment: Environment) -> "Extension":
+ def bind(self, environment: Environment) -> "te.Self":
"""Create a copy of this extension bound to another environment."""
rv = object.__new__(self.__class__)
rv.__dict__.update(self.__dict__)
@@ -218,7 +219,7 @@ def _make_new_pgettext(func: t.Callable[[str, str], str]) -> t.Callable[..., str
def _make_new_npgettext(
- func: t.Callable[[str, str, str, int], str]
+ func: t.Callable[[str, str, str, int], str],
) -> t.Callable[..., str]:
@pass_context
def npgettext(
@@ -294,14 +295,14 @@ class InternationalizationExtension(Extension):
pgettext = translations.pgettext
else:
- def pgettext(c: str, s: str) -> str:
+ def pgettext(c: str, s: str) -> str: # type: ignore[misc]
return s
if hasattr(translations, "npgettext"):
npgettext = translations.npgettext
else:
- def npgettext(c: str, s: str, p: str, n: int) -> str:
+ def npgettext(c: str, s: str, p: str, n: int) -> str: # type: ignore[misc]
return s if n == 1 else p
self._install_callables(
diff --git a/filters.py b/filters.py
index 8b09247..2bcba4f 100755
--- a/filters.py
+++ b/filters.py
@@ -1,10 +1,12 @@
"""Built-in template filters used with the ``|`` operator."""
+
import math
import random
import re
import typing
import typing as t
from collections import abc
+from inspect import getattr_static
from itertools import chain
from itertools import groupby
@@ -28,6 +30,7 @@ from .utils import urlize
if t.TYPE_CHECKING:
import typing_extensions as te
+
from .environment import Environment
from .nodes import EvalContext
from .runtime import Context
@@ -122,7 +125,7 @@ def make_multi_attrgetter(
def _prepare_attribute_parts(
- attr: t.Optional[t.Union[str, int]]
+ attr: t.Optional[t.Union[str, int]],
) -> t.List[t.Union[str, int]]:
if attr is None:
return []
@@ -142,7 +145,7 @@ def do_forceescape(value: "t.Union[str, HasHTML]") -> Markup:
def do_urlencode(
- value: t.Union[str, t.Mapping[str, t.Any], t.Iterable[t.Tuple[str, t.Any]]]
+ value: t.Union[str, t.Mapping[str, t.Any], t.Iterable[t.Tuple[str, t.Any]]],
) -> str:
"""Quote data for use in a URL path or query using UTF-8.
@@ -267,6 +270,7 @@ def do_xmlattr(
sign, this fails with a ``ValueError``. Regardless of this, user input
should never be used as keys to this filter, or must be separately validated
first.
+
.. sourcecode:: html+jinja
"t.Iterator[V]":
+ return sync_do_unique(
+ environment, await auto_to_list(value), case_sensitive, attribute
+ )
+
+
def _min_or_max(
environment: "Environment",
value: "t.Iterable[V]",
@@ -563,7 +579,7 @@ def do_default(
@pass_eval_context
def sync_do_join(
eval_ctx: "EvalContext",
- value: t.Iterable,
+ value: t.Iterable[t.Any],
d: str = "",
attribute: t.Optional[t.Union[str, int]] = None,
) -> str:
@@ -621,7 +637,7 @@ def sync_do_join(
@async_variant(sync_do_join) # type: ignore
async def do_join(
eval_ctx: "EvalContext",
- value: t.Union[t.AsyncIterable, t.Iterable],
+ value: t.Union[t.AsyncIterable[t.Any], t.Iterable[t.Any]],
d: str = "",
attribute: t.Optional[t.Union[str, int]] = None,
) -> str:
@@ -984,7 +1000,7 @@ def do_int(value: t.Any, default: int = 0, base: int = 10) -> int:
# this quirk is necessary so that "42.23"|int gives 42.
try:
return int(float(value))
- except (TypeError, ValueError):
+ except (TypeError, ValueError, OverflowError):
return default
@@ -1113,7 +1129,7 @@ def do_batch(
{%- endfor %}
"""
- tmp: "t.List[V]" = []
+ tmp: t.List[V] = []
for item in value:
if len(tmp) == linecount:
@@ -1171,7 +1187,7 @@ def do_round(
class _GroupTuple(t.NamedTuple):
grouper: t.Any
- list: t.List
+ list: t.List[t.Any]
# Use the regular tuple repr to hide this subclass if users print
# out the value during debugging.
@@ -1367,13 +1383,11 @@ def do_mark_unsafe(value: str) -> str:
@typing.overload
-def do_reverse(value: str) -> str:
- ...
+def do_reverse(value: str) -> str: ...
@typing.overload
-def do_reverse(value: "t.Iterable[V]") -> "t.Iterable[V]":
- ...
+def do_reverse(value: "t.Iterable[V]") -> "t.Iterable[V]": ...
def do_reverse(value: t.Union[str, t.Iterable[V]]) -> t.Union[str, t.Iterable[V]]:
@@ -1398,55 +1412,51 @@ def do_reverse(value: t.Union[str, t.Iterable[V]]) -> t.Union[str, t.Iterable[V]
def do_attr(
environment: "Environment", obj: t.Any, name: str
) -> t.Union[Undefined, t.Any]:
- """Get an attribute of an object. ``foo|attr("bar")`` works like
- ``foo.bar`` just that always an attribute is returned and items are not
- looked up.
+ """Get an attribute of an object. ``foo|attr("bar")`` works like
+ ``foo.bar``, but returns undefined instead of falling back to ``foo["bar"]``
+ if the attribute doesn't exist.
See :ref:`Notes on subscriptions ` for more details.
"""
+ # Environment.getattr will fall back to obj[name] if obj.name doesn't exist.
+ # But we want to call env.getattr to get behavior such as sandboxing.
+ # Determine if the attr exists first, so we know the fallback won't trigger.
try:
- name = str(name)
- except UnicodeError:
- pass
- else:
- try:
- value = getattr(obj, name)
- except AttributeError:
- pass
- else:
- if environment.sandboxed:
- environment = t.cast("SandboxedEnvironment", environment)
+ # This avoids executing properties/descriptors, but misses __getattr__
+ # and __getattribute__ dynamic attrs.
+ getattr_static(obj, name)
+ except AttributeError:
+ # This finds dynamic attrs, and we know it's not a descriptor at this point.
+ if not hasattr(obj, name):
+ return environment.undefined(obj=obj, name=name)
- if not environment.is_safe_attribute(obj, name, value):
- return environment.unsafe_undefined(obj, name)
-
- return value
-
- return environment.undefined(obj=obj, name=name)
-
-
-@typing.overload
-def sync_do_map(
- context: "Context", value: t.Iterable, name: str, *args: t.Any, **kwargs: t.Any
-) -> t.Iterable:
- ...
+ return environment.getattr(obj, name)
@typing.overload
def sync_do_map(
context: "Context",
- value: t.Iterable,
+ value: t.Iterable[t.Any],
+ name: str,
+ *args: t.Any,
+ **kwargs: t.Any,
+) -> t.Iterable[t.Any]: ...
+
+
+@typing.overload
+def sync_do_map(
+ context: "Context",
+ value: t.Iterable[t.Any],
*,
attribute: str = ...,
default: t.Optional[t.Any] = None,
-) -> t.Iterable:
- ...
+) -> t.Iterable[t.Any]: ...
@pass_context
def sync_do_map(
- context: "Context", value: t.Iterable, *args: t.Any, **kwargs: t.Any
-) -> t.Iterable:
+ context: "Context", value: t.Iterable[t.Any], *args: t.Any, **kwargs: t.Any
+) -> t.Iterable[t.Any]:
"""Applies a filter on a sequence of objects or looks up an attribute.
This is useful when dealing with lists of objects but you are really
only interested in a certain value of it.
@@ -1496,32 +1506,30 @@ def sync_do_map(
@typing.overload
def do_map(
context: "Context",
- value: t.Union[t.AsyncIterable, t.Iterable],
+ value: t.Union[t.AsyncIterable[t.Any], t.Iterable[t.Any]],
name: str,
*args: t.Any,
**kwargs: t.Any,
-) -> t.Iterable:
- ...
+) -> t.Iterable[t.Any]: ...
@typing.overload
def do_map(
context: "Context",
- value: t.Union[t.AsyncIterable, t.Iterable],
+ value: t.Union[t.AsyncIterable[t.Any], t.Iterable[t.Any]],
*,
attribute: str = ...,
default: t.Optional[t.Any] = None,
-) -> t.Iterable:
- ...
+) -> t.Iterable[t.Any]: ...
@async_variant(sync_do_map) # type: ignore
async def do_map(
context: "Context",
- value: t.Union[t.AsyncIterable, t.Iterable],
+ value: t.Union[t.AsyncIterable[t.Any], t.Iterable[t.Any]],
*args: t.Any,
**kwargs: t.Any,
-) -> t.AsyncIterable:
+) -> t.AsyncIterable[t.Any]:
if value:
func = prepare_map(context, args, kwargs)
@@ -1628,8 +1636,8 @@ def sync_do_selectattr(
.. code-block:: python
- (u for user in users if user.is_active)
- (u for user in users if test_none(user.email))
+ (user for user in users if user.is_active)
+ (user for user in users if test_none(user.email))
.. versionadded:: 2.7
"""
@@ -1666,8 +1674,8 @@ def sync_do_rejectattr(
.. code-block:: python
- (u for user in users if not user.is_active)
- (u for user in users if not test_none(user.email))
+ (user for user in users if not user.is_active)
+ (user for user in users if not test_none(user.email))
.. versionadded:: 2.7
"""
@@ -1714,7 +1722,7 @@ def do_tojson(
def prepare_map(
- context: "Context", args: t.Tuple, kwargs: t.Dict[str, t.Any]
+ context: "Context", args: t.Tuple[t.Any, ...], kwargs: t.Dict[str, t.Any]
) -> t.Callable[[t.Any], t.Any]:
if not args and "attribute" in kwargs:
attribute = kwargs.pop("attribute")
@@ -1743,7 +1751,7 @@ def prepare_map(
def prepare_select_or_reject(
context: "Context",
- args: t.Tuple,
+ args: t.Tuple[t.Any, ...],
kwargs: t.Dict[str, t.Any],
modfunc: t.Callable[[t.Any], t.Any],
lookup_attr: bool,
@@ -1767,7 +1775,7 @@ def prepare_select_or_reject(
args = args[1 + off :]
def func(item: t.Any) -> t.Any:
- return context.environment.call_test(name, item, args, kwargs)
+ return context.environment.call_test(name, item, args, kwargs, context)
except LookupError:
func = bool # type: ignore
@@ -1778,7 +1786,7 @@ def prepare_select_or_reject(
def select_or_reject(
context: "Context",
value: "t.Iterable[V]",
- args: t.Tuple,
+ args: t.Tuple[t.Any, ...],
kwargs: t.Dict[str, t.Any],
modfunc: t.Callable[[t.Any], t.Any],
lookup_attr: bool,
@@ -1794,7 +1802,7 @@ def select_or_reject(
async def async_select_or_reject(
context: "Context",
value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
- args: t.Tuple,
+ args: t.Tuple[t.Any, ...],
kwargs: t.Dict[str, t.Any],
modfunc: t.Callable[[t.Any], t.Any],
lookup_attr: bool,
diff --git a/idtracking.py b/idtracking.py
index 995ebaa..e6dd8cd 100755
--- a/idtracking.py
+++ b/idtracking.py
@@ -3,6 +3,9 @@ import typing as t
from . import nodes
from .visitor import NodeVisitor
+if t.TYPE_CHECKING:
+ import typing_extensions as te
+
VAR_LOAD_PARAMETER = "param"
VAR_LOAD_RESOLVE = "resolve"
VAR_LOAD_ALIAS = "alias"
@@ -83,7 +86,7 @@ class Symbols:
)
return rv
- def copy(self) -> "Symbols":
+ def copy(self) -> "te.Self":
rv = object.__new__(self.__class__)
rv.__dict__.update(self.__dict__)
rv.refs = self.refs.copy()
@@ -118,23 +121,20 @@ class Symbols:
self._define_ref(name, load=(VAR_LOAD_RESOLVE, name))
def branch_update(self, branch_symbols: t.Sequence["Symbols"]) -> None:
- stores: t.Dict[str, int] = {}
+ stores: t.Set[str] = set()
+
for branch in branch_symbols:
- for target in branch.stores:
- if target in self.stores:
- continue
- stores[target] = stores.get(target, 0) + 1
+ stores.update(branch.stores)
+
+ stores.difference_update(self.stores)
for sym in branch_symbols:
self.refs.update(sym.refs)
self.loads.update(sym.loads)
self.stores.update(sym.stores)
- for name, branch_count in stores.items():
- if branch_count == len(branch_symbols):
- continue
-
- target = self.find_ref(name) # type: ignore
+ for name in stores:
+ target = self.find_ref(name)
assert target is not None, "should not happen"
if self.parent is not None:
@@ -146,7 +146,7 @@ class Symbols:
def dump_stores(self) -> t.Dict[str, str]:
rv: t.Dict[str, str] = {}
- node: t.Optional["Symbols"] = self
+ node: t.Optional[Symbols] = self
while node is not None:
for name in sorted(node.stores):
@@ -159,7 +159,7 @@ class Symbols:
def dump_param_targets(self) -> t.Set[str]:
rv = set()
- node: t.Optional["Symbols"] = self
+ node: t.Optional[Symbols] = self
while node is not None:
for target, (instr, _) in self.loads.items():
diff --git a/lexer.py b/lexer.py
index aff7e9f..9b1c969 100755
--- a/lexer.py
+++ b/lexer.py
@@ -3,6 +3,7 @@ is used to do some preprocessing. It filters out invalid operators like
the bitshift operators we don't allow in templates. It separates
template code and python code in expressions.
"""
+
import re
import typing as t
from ast import literal_eval
@@ -15,6 +16,7 @@ from .utils import LRUCache
if t.TYPE_CHECKING:
import typing_extensions as te
+
from .environment import Environment
# cache for the lexers. Exists in order to be able to have multiple
@@ -260,7 +262,7 @@ class Failure:
self.message = message
self.error_class = cls
- def __call__(self, lineno: int, filename: str) -> "te.NoReturn":
+ def __call__(self, lineno: int, filename: t.Optional[str]) -> "te.NoReturn":
raise self.error_class(self.message, lineno, filename)
@@ -327,7 +329,7 @@ class TokenStream:
filename: t.Optional[str],
):
self._iter = iter(generator)
- self._pushed: "te.Deque[Token]" = deque()
+ self._pushed: te.Deque[Token] = deque()
self.name = name
self.filename = filename
self.closed = False
@@ -447,7 +449,7 @@ def get_lexer(environment: "Environment") -> "Lexer":
return lexer
-class OptionalLStrip(tuple):
+class OptionalLStrip(tuple): # type: ignore[type-arg]
"""A special tuple for marking a point in the state that can have
lstrip applied.
"""
@@ -755,7 +757,7 @@ class Lexer:
for idx, token in enumerate(tokens):
# failure group
- if token.__class__ is Failure:
+ if isinstance(token, Failure):
raise token(lineno, filename)
# bygroup is a bit more complex, in that case we
# yield for the current token the first named
@@ -776,7 +778,7 @@ class Lexer:
data = groups[idx]
if data or token not in ignore_if_empty:
- yield lineno, token, data
+ yield lineno, token, data # type: ignore[misc]
lineno += data.count("\n") + newlines_stripped
newlines_stripped = 0
diff --git a/loaders.py b/loaders.py
index 32f3a74..3913ee5 100755
--- a/loaders.py
+++ b/loaders.py
@@ -1,6 +1,7 @@
"""API and implementations for loading templates from different data
sources.
"""
+
import importlib.util
import os
import posixpath
@@ -177,7 +178,9 @@ class FileSystemLoader(BaseLoader):
def __init__(
self,
- searchpath: t.Union[str, os.PathLike, t.Sequence[t.Union[str, os.PathLike]]],
+ searchpath: t.Union[
+ str, "os.PathLike[str]", t.Sequence[t.Union[str, "os.PathLike[str]"]]
+ ],
encoding: str = "utf-8",
followlinks: bool = False,
) -> None:
@@ -201,7 +204,12 @@ class FileSystemLoader(BaseLoader):
if os.path.isfile(filename):
break
else:
- raise TemplateNotFound(template)
+ plural = "path" if len(self.searchpath) == 1 else "paths"
+ paths_str = ", ".join(repr(p) for p in self.searchpath)
+ raise TemplateNotFound(
+ template,
+ f"{template!r} not found in search {plural}: {paths_str}",
+ )
with open(filename, encoding=self.encoding) as f:
contents = f.read()
@@ -235,6 +243,30 @@ class FileSystemLoader(BaseLoader):
return sorted(found)
+if sys.version_info >= (3, 13):
+
+ def _get_zipimporter_files(z: t.Any) -> t.Dict[str, object]:
+ try:
+ get_files = z._get_files
+ except AttributeError as e:
+ raise TypeError(
+ "This zip import does not have the required"
+ " metadata to list templates."
+ ) from e
+ return get_files()
+else:
+
+ def _get_zipimporter_files(z: t.Any) -> t.Dict[str, object]:
+ try:
+ files = z._files
+ except AttributeError as e:
+ raise TypeError(
+ "This zip import does not have the required"
+ " metadata to list templates."
+ ) from e
+ return files # type: ignore[no-any-return]
+
+
class PackageLoader(BaseLoader):
"""Load templates from a directory in a Python package.
@@ -295,7 +327,6 @@ class PackageLoader(BaseLoader):
assert loader is not None, "A loader was not found for the package."
self._loader = loader
self._archive = None
- template_root = None
if isinstance(loader, zipimport.zipimporter):
self._archive = loader.archive
@@ -312,18 +343,23 @@ class PackageLoader(BaseLoader):
elif spec.origin is not None:
roots.append(os.path.dirname(spec.origin))
+ if not roots:
+ raise ValueError(
+ f"The {package_name!r} package was not installed in a"
+ " way that PackageLoader understands."
+ )
+
for root in roots:
root = os.path.join(root, package_path)
if os.path.isdir(root):
template_root = root
break
-
- if template_root is None:
- raise ValueError(
- f"The {package_name!r} package was not installed in a"
- " way that PackageLoader understands."
- )
+ else:
+ raise ValueError(
+ f"PackageLoader could not find a {package_path!r} directory"
+ f" in the {package_name!r} package."
+ )
self._template_root = template_root
@@ -379,11 +415,7 @@ class PackageLoader(BaseLoader):
for name in filenames
)
else:
- if not hasattr(self._loader, "_files"):
- raise TypeError(
- "This zip import does not have the required"
- " metadata to list templates."
- )
+ files = _get_zipimporter_files(self._loader)
# Package is a zip file.
prefix = (
@@ -392,7 +424,7 @@ class PackageLoader(BaseLoader):
)
offset = len(prefix)
- for name in self._loader._files.keys():
+ for name in files:
# Find names under the templates directory that aren't directories.
if name.startswith(prefix) and name[-1] != os.path.sep:
results.append(name[offset:].replace(os.path.sep, "/"))
@@ -407,7 +439,7 @@ class DictLoader(BaseLoader):
>>> loader = DictLoader({'index.html': 'source here'})
- Because auto reloading is rarely useful this is disabled per default.
+ Because auto reloading is rarely useful this is disabled by default.
"""
def __init__(self, mapping: t.Mapping[str, str]) -> None:
@@ -590,10 +622,7 @@ class ModuleLoader(BaseLoader):
Example usage:
- >>> loader = ChoiceLoader([
- ... ModuleLoader('/path/to/compiled/templates'),
- ... FileSystemLoader('/path/to/templates')
- ... ])
+ >>> loader = ModuleLoader('/path/to/compiled/templates')
Templates can be precompiled with :meth:`Environment.compile_templates`.
"""
@@ -601,7 +630,10 @@ class ModuleLoader(BaseLoader):
has_source_access = False
def __init__(
- self, path: t.Union[str, os.PathLike, t.Sequence[t.Union[str, os.PathLike]]]
+ self,
+ path: t.Union[
+ str, "os.PathLike[str]", t.Sequence[t.Union[str, "os.PathLike[str]"]]
+ ],
) -> None:
package_name = f"_jinja2_module_templates_{id(self):x}"
diff --git a/meta.py b/meta.py
index 0057d6e..298499e 100755
--- a/meta.py
+++ b/meta.py
@@ -1,6 +1,7 @@
"""Functions that expose information about templates that might be
interesting for introspection.
"""
+
import typing as t
from . import nodes
diff --git a/nodes.py b/nodes.py
index b2f88d9..2f93b90 100755
--- a/nodes.py
+++ b/nodes.py
@@ -2,6 +2,7 @@
some node tree helper functions used by the parser and compiler in order
to normalize nodes.
"""
+
import inspect
import operator
import typing as t
@@ -13,6 +14,7 @@ from .utils import _PassArg
if t.TYPE_CHECKING:
import typing_extensions as te
+
from .environment import Environment
_NodeBound = t.TypeVar("_NodeBound", bound="Node")
@@ -56,7 +58,7 @@ class NodeType(type):
def __new__(mcs, name, bases, d): # type: ignore
for attr in "fields", "attributes":
- storage = []
+ storage: t.List[t.Tuple[str, ...]] = []
storage.extend(getattr(bases[0] if bases else object, attr, ()))
storage.extend(d.get(attr, ()))
assert len(bases) <= 1, "multiple inheritance not allowed"
diff --git a/optimizer.py b/optimizer.py
index fe10107..32d1c71 100755
--- a/optimizer.py
+++ b/optimizer.py
@@ -7,6 +7,7 @@ want. For example, loop unrolling doesn't work because unrolled loops
would have a different scope. The solution would be a second syntax tree
that stored the scoping rules.
"""
+
import typing as t
from . import nodes
diff --git a/parser.py b/parser.py
index 3354bc9..f411775 100755
--- a/parser.py
+++ b/parser.py
@@ -1,4 +1,5 @@
"""Parse tokens from the lexer into nodes for the compiler."""
+
import typing
import typing as t
@@ -10,6 +11,7 @@ from .lexer import describe_token_expr
if t.TYPE_CHECKING:
import typing_extensions as te
+
from .environment import Environment
_ImportInclude = t.TypeVar("_ImportInclude", nodes.Import, nodes.Include)
@@ -62,7 +64,7 @@ class Parser:
self.filename = filename
self.closed = False
self.extensions: t.Dict[
- str, t.Callable[["Parser"], t.Union[nodes.Node, t.List[nodes.Node]]]
+ str, t.Callable[[Parser], t.Union[nodes.Node, t.List[nodes.Node]]]
] = {}
for extension in environment.iter_extensions():
for tag in extension.tags:
@@ -457,8 +459,7 @@ class Parser:
@typing.overload
def parse_assign_target(
self, with_tuple: bool = ..., name_only: "te.Literal[True]" = ...
- ) -> nodes.Name:
- ...
+ ) -> nodes.Name: ...
@typing.overload
def parse_assign_target(
@@ -467,8 +468,7 @@ class Parser:
name_only: bool = False,
extra_end_rules: t.Optional[t.Tuple[str, ...]] = None,
with_namespace: bool = False,
- ) -> t.Union[nodes.NSRef, nodes.Name, nodes.Tuple]:
- ...
+ ) -> t.Union[nodes.NSRef, nodes.Name, nodes.Tuple]: ...
def parse_assign_target(
self,
@@ -487,21 +487,18 @@ class Parser:
"""
target: nodes.Expr
- if with_namespace and self.stream.look().type == "dot":
- token = self.stream.expect("name")
- next(self.stream) # dot
- attr = self.stream.expect("name")
- target = nodes.NSRef(token.value, attr.value, lineno=token.lineno)
- elif name_only:
+ if name_only:
token = self.stream.expect("name")
target = nodes.Name(token.value, "store", lineno=token.lineno)
else:
if with_tuple:
target = self.parse_tuple(
- simplified=True, extra_end_rules=extra_end_rules
+ simplified=True,
+ extra_end_rules=extra_end_rules,
+ with_namespace=with_namespace,
)
else:
- target = self.parse_primary()
+ target = self.parse_primary(with_namespace=with_namespace)
target.set_ctx("store")
@@ -643,17 +640,25 @@ class Parser:
node = self.parse_filter_expr(node)
return node
- def parse_primary(self) -> nodes.Expr:
+ def parse_primary(self, with_namespace: bool = False) -> nodes.Expr:
+ """Parse a name or literal value. If ``with_namespace`` is enabled, also
+ parse namespace attr refs, for use in assignments."""
token = self.stream.current
node: nodes.Expr
if token.type == "name":
+ next(self.stream)
if token.value in ("true", "false", "True", "False"):
node = nodes.Const(token.value in ("true", "True"), lineno=token.lineno)
elif token.value in ("none", "None"):
node = nodes.Const(None, lineno=token.lineno)
+ elif with_namespace and self.stream.current.type == "dot":
+ # If namespace attributes are allowed at this point, and the next
+ # token is a dot, produce a namespace reference.
+ next(self.stream)
+ attr = self.stream.expect("name")
+ node = nodes.NSRef(token.value, attr.value, lineno=token.lineno)
else:
node = nodes.Name(token.value, "load", lineno=token.lineno)
- next(self.stream)
elif token.type == "string":
next(self.stream)
buf = [token.value]
@@ -683,6 +688,7 @@ class Parser:
with_condexpr: bool = True,
extra_end_rules: t.Optional[t.Tuple[str, ...]] = None,
explicit_parentheses: bool = False,
+ with_namespace: bool = False,
) -> t.Union[nodes.Tuple, nodes.Expr]:
"""Works like `parse_expression` but if multiple expressions are
delimited by a comma a :class:`~jinja2.nodes.Tuple` node is created.
@@ -690,8 +696,9 @@ class Parser:
if no commas where found.
The default parsing mode is a full tuple. If `simplified` is `True`
- only names and literals are parsed. The `no_condexpr` parameter is
- forwarded to :meth:`parse_expression`.
+ only names and literals are parsed; ``with_namespace`` allows namespace
+ attr refs as well. The `no_condexpr` parameter is forwarded to
+ :meth:`parse_expression`.
Because tuples do not require delimiters and may end in a bogus comma
an extra hint is needed that marks the end of a tuple. For example
@@ -704,13 +711,14 @@ class Parser:
"""
lineno = self.stream.current.lineno
if simplified:
- parse = self.parse_primary
- elif with_condexpr:
- parse = self.parse_expression
+
+ def parse() -> nodes.Expr:
+ return self.parse_primary(with_namespace=with_namespace)
+
else:
def parse() -> nodes.Expr:
- return self.parse_expression(with_condexpr=False)
+ return self.parse_expression(with_condexpr=with_condexpr)
args: t.List[nodes.Expr] = []
is_tuple = False
@@ -861,7 +869,14 @@ class Parser:
return nodes.Slice(lineno=lineno, *args) # noqa: B026
- def parse_call_args(self) -> t.Tuple:
+ def parse_call_args(
+ self,
+ ) -> t.Tuple[
+ t.List[nodes.Expr],
+ t.List[nodes.Keyword],
+ t.Optional[nodes.Expr],
+ t.Optional[nodes.Expr],
+ ]:
token = self.stream.expect("lparen")
args = []
kwargs = []
@@ -952,7 +967,7 @@ class Parser:
next(self.stream)
name += "." + self.stream.expect("name").value
dyn_args = dyn_kwargs = None
- kwargs = []
+ kwargs: t.List[nodes.Keyword] = []
if self.stream.current.type == "lparen":
args, kwargs, dyn_args, dyn_kwargs = self.parse_call_args()
elif self.stream.current.type in {
diff --git a/runtime.py b/runtime.py
index 58a540b..09119e2 100755
--- a/runtime.py
+++ b/runtime.py
@@ -1,4 +1,5 @@
"""The runtime functions and state used by compiled templates."""
+
import functools
import sys
import typing as t
@@ -28,7 +29,9 @@ F = t.TypeVar("F", bound=t.Callable[..., t.Any])
if t.TYPE_CHECKING:
import logging
+
import typing_extensions as te
+
from .environment import Environment
class LoopRenderFunc(te.Protocol):
@@ -37,8 +40,7 @@ if t.TYPE_CHECKING:
reciter: t.Iterable[V],
loop_render_func: "LoopRenderFunc",
depth: int = 0,
- ) -> str:
- ...
+ ) -> str: ...
# these variables are exported to the template runtime
@@ -170,7 +172,7 @@ class Context:
):
self.parent = parent
self.vars: t.Dict[str, t.Any] = {}
- self.environment: "Environment" = environment
+ self.environment: Environment = environment
self.eval_ctx = EvalContext(self.environment, name)
self.exported_vars: t.Set[str] = set()
self.name = name
@@ -259,7 +261,10 @@ class Context:
@internalcode
def call(
- __self, __obj: t.Callable, *args: t.Any, **kwargs: t.Any # noqa: B902
+ __self,
+ __obj: t.Callable[..., t.Any],
+ *args: t.Any,
+ **kwargs: t.Any, # noqa: B902
) -> t.Union[t.Any, "Undefined"]:
"""Call the callable with the arguments and keyword arguments
provided but inject the active context or environment as first
@@ -362,7 +367,7 @@ class BlockReference:
@internalcode
async def _async_call(self) -> str:
- rv = concat(
+ rv = self._context.environment.concat( # type: ignore
[x async for x in self._stack[self._depth](self._context)] # type: ignore
)
@@ -376,7 +381,9 @@ class BlockReference:
if self._context.environment.is_async:
return self._async_call() # type: ignore
- rv = concat(self._stack[self._depth](self._context))
+ rv = self._context.environment.concat( # type: ignore
+ self._stack[self._depth](self._context)
+ )
if self._context.eval_ctx.autoescape:
return Markup(rv)
@@ -586,7 +593,7 @@ class AsyncLoopContext(LoopContext):
@staticmethod
def _to_iterator( # type: ignore
- iterable: t.Union[t.Iterable[V], t.AsyncIterable[V]]
+ iterable: t.Union[t.Iterable[V], t.AsyncIterable[V]],
) -> t.AsyncIterator[V]:
return auto_aiter(iterable)
@@ -787,8 +794,8 @@ class Macro:
class Undefined:
- """The default undefined type. This undefined type can be printed and
- iterated over, but every other access will raise an :exc:`UndefinedError`:
+ """The default undefined type. This can be printed, iterated, and treated as
+ a boolean. Any other operation will raise an :exc:`UndefinedError`.
>>> foo = Undefined(name='foo')
>>> str(foo)
@@ -853,7 +860,11 @@ class Undefined:
@internalcode
def __getattr__(self, name: str) -> t.Any:
- if name[:2] == "__":
+ # Raise AttributeError on requests for names that appear to be unimplemented
+ # dunder methods to keep Python's internal protocol probing behaviors working
+ # properly in cases where another exception type could cause unexpected or
+ # difficult-to-diagnose failures.
+ if name[:2] == "__" and name[-2:] == "__":
raise AttributeError(name)
return self._fail_with_undefined_error()
@@ -977,10 +988,20 @@ class ChainableUndefined(Undefined):
def __html__(self) -> str:
return str(self)
- def __getattr__(self, _: str) -> "ChainableUndefined":
+ def __getattr__(self, name: str) -> "ChainableUndefined":
+ # Raise AttributeError on requests for names that appear to be unimplemented
+ # dunder methods to avoid confusing Python with truthy non-method objects that
+ # do not implement the protocol being probed for. e.g., copy.copy(Undefined())
+ # fails spectacularly if getattr(Undefined(), '__setstate__') returns an
+ # Undefined object instead of raising AttributeError to signal that it does not
+ # support that style of object initialization.
+ if name[:2] == "__" and name[-2:] == "__":
+ raise AttributeError(name)
+
return self
- __getitem__ = __getattr__ # type: ignore
+ def __getitem__(self, _name: str) -> "ChainableUndefined": # type: ignore[override]
+ return self
class DebugUndefined(Undefined):
@@ -1039,13 +1060,3 @@ class StrictUndefined(Undefined):
__iter__ = __str__ = __len__ = Undefined._fail_with_undefined_error
__eq__ = __ne__ = __bool__ = __hash__ = Undefined._fail_with_undefined_error
__contains__ = Undefined._fail_with_undefined_error
-
-
-# Remove slots attributes, after the metaclass is applied they are
-# unneeded and contain wrong data for subclasses.
-del (
- Undefined.__slots__,
- ChainableUndefined.__slots__,
- DebugUndefined.__slots__,
- StrictUndefined.__slots__,
-)
diff --git a/sandbox.py b/sandbox.py
index 06d7414..9c9dae2 100755
--- a/sandbox.py
+++ b/sandbox.py
@@ -1,12 +1,14 @@
"""A sandbox layer that ensures unsafe operations cannot be performed.
Useful when the template itself comes from an untrusted source.
"""
+
import operator
import types
import typing as t
from _string import formatter_field_name_split # type: ignore
from collections import abc
from collections import deque
+from functools import update_wrapper
from string import Formatter
from markupsafe import EscapeFormatter
@@ -37,7 +39,7 @@ UNSAFE_COROUTINE_ATTRIBUTES = {"cr_frame", "cr_code"}
#: unsafe attributes on async generators
UNSAFE_ASYNC_GENERATOR_ATTRIBUTES = {"ag_code", "ag_frame"}
-_mutable_spec: t.Tuple[t.Tuple[t.Type, t.FrozenSet[str]], ...] = (
+_mutable_spec: t.Tuple[t.Tuple[t.Type[t.Any], t.FrozenSet[str]], ...] = (
(
abc.MutableSet,
frozenset(
@@ -59,7 +61,9 @@ _mutable_spec: t.Tuple[t.Tuple[t.Type, t.FrozenSet[str]], ...] = (
),
(
abc.MutableSequence,
- frozenset(["append", "reverse", "insert", "sort", "extend", "remove"]),
+ frozenset(
+ ["append", "clear", "pop", "reverse", "insert", "sort", "extend", "remove"]
+ ),
),
(
deque,
@@ -80,20 +84,6 @@ _mutable_spec: t.Tuple[t.Tuple[t.Type, t.FrozenSet[str]], ...] = (
)
-def inspect_format_method(callable: t.Callable) -> t.Optional[str]:
- if not isinstance(
- callable, (types.MethodType, types.BuiltinMethodType)
- ) or callable.__name__ not in ("format", "format_map"):
- return None
-
- obj = callable.__self__
-
- if isinstance(obj, str):
- return obj
-
- return None
-
-
def safe_range(*args: int) -> range:
"""A range that can't generate ranges with a length of more than
MAX_RANGE items.
@@ -313,6 +303,9 @@ class SandboxedEnvironment(Environment):
except AttributeError:
pass
else:
+ fmt = self.wrap_str_format(value)
+ if fmt is not None:
+ return fmt
if self.is_safe_attribute(obj, argument, value):
return value
return self.unsafe_undefined(obj, argument)
@@ -330,6 +323,9 @@ class SandboxedEnvironment(Environment):
except (TypeError, LookupError):
pass
else:
+ fmt = self.wrap_str_format(value)
+ if fmt is not None:
+ return fmt
if self.is_safe_attribute(obj, attribute, value):
return value
return self.unsafe_undefined(obj, attribute)
@@ -345,34 +341,49 @@ class SandboxedEnvironment(Environment):
exc=SecurityError,
)
- def format_string(
- self,
- s: str,
- args: t.Tuple[t.Any, ...],
- kwargs: t.Dict[str, t.Any],
- format_func: t.Optional[t.Callable] = None,
- ) -> str:
- """If a format call is detected, then this is routed through this
- method so that our safety sandbox can be used for it.
+ def wrap_str_format(self, value: t.Any) -> t.Optional[t.Callable[..., str]]:
+ """If the given value is a ``str.format`` or ``str.format_map`` method,
+ return a new function than handles sandboxing. This is done at access
+ rather than in :meth:`call`, so that calls made without ``call`` are
+ also sandboxed.
"""
+ if not isinstance(
+ value, (types.MethodType, types.BuiltinMethodType)
+ ) or value.__name__ not in ("format", "format_map"):
+ return None
+
+ f_self: t.Any = value.__self__
+
+ if not isinstance(f_self, str):
+ return None
+
+ str_type: t.Type[str] = type(f_self)
+ is_format_map = value.__name__ == "format_map"
formatter: SandboxedFormatter
- if isinstance(s, Markup):
- formatter = SandboxedEscapeFormatter(self, escape=s.escape)
+
+ if isinstance(f_self, Markup):
+ formatter = SandboxedEscapeFormatter(self, escape=f_self.escape)
else:
formatter = SandboxedFormatter(self)
- if format_func is not None and format_func.__name__ == "format_map":
- if len(args) != 1 or kwargs:
- raise TypeError(
- "format_map() takes exactly one argument"
- f" {len(args) + (kwargs is not None)} given"
- )
+ vformat = formatter.vformat
- kwargs = args[0]
- args = ()
+ def wrapper(*args: t.Any, **kwargs: t.Any) -> str:
+ if is_format_map:
+ if kwargs:
+ raise TypeError("format_map() takes no keyword arguments")
- rv = formatter.vformat(s, args, kwargs)
- return type(s)(rv)
+ if len(args) != 1:
+ raise TypeError(
+ f"format_map() takes exactly one argument ({len(args)} given)"
+ )
+
+ kwargs = args[0]
+ args = ()
+
+ return str_type(vformat(f_self, args, kwargs))
+
+ return update_wrapper(wrapper, value)
def call(
__self, # noqa: B902
@@ -382,9 +393,6 @@ class SandboxedEnvironment(Environment):
**kwargs: t.Any,
) -> t.Any:
"""Call an object from sandboxed code."""
- fmt = inspect_format_method(__obj)
- if fmt is not None:
- return __self.format_string(fmt, args, kwargs, __obj)
# the double prefixes are to avoid double keyword argument
# errors when proxying the call.
diff --git a/tests.py b/tests.py
index a467cf0..1a59e37 100755
--- a/tests.py
+++ b/tests.py
@@ -1,4 +1,5 @@
"""Built-in template tests used with the ``is`` operator."""
+
import operator
import typing as t
from collections import abc
@@ -169,7 +170,7 @@ def test_sequence(value: t.Any) -> bool:
"""
try:
len(value)
- value.__getitem__
+ value.__getitem__ # noqa B018
except Exception:
return False
@@ -204,7 +205,7 @@ def test_escaped(value: t.Any) -> bool:
return hasattr(value, "__html__")
-def test_in(value: t.Any, seq: t.Container) -> bool:
+def test_in(value: t.Any, seq: t.Container[t.Any]) -> bool:
"""Check if value is in seq.
.. versionadded:: 2.10
diff --git a/utils.py b/utils.py
index 18914a5..7c92262 100755
--- a/utils.py
+++ b/utils.py
@@ -18,8 +18,17 @@ if t.TYPE_CHECKING:
F = t.TypeVar("F", bound=t.Callable[..., t.Any])
-# special singleton representing missing values for the runtime
-missing: t.Any = type("MissingType", (), {"__repr__": lambda x: "missing"})()
+
+class _MissingType:
+ def __repr__(self) -> str:
+ return "missing"
+
+ def __reduce__(self) -> str:
+ return "missing"
+
+
+missing: t.Any = _MissingType()
+"""Special singleton representing missing values for the runtime."""
internal_code: t.MutableSet[CodeType] = set()
@@ -152,7 +161,7 @@ def import_string(import_name: str, silent: bool = False) -> t.Any:
raise
-def open_if_exists(filename: str, mode: str = "rb") -> t.Optional[t.IO]:
+def open_if_exists(filename: str, mode: str = "rb") -> t.Optional[t.IO[t.Any]]:
"""Returns a file descriptor for the filename if that file exists,
otherwise ``None``.
"""
@@ -324,6 +333,8 @@ def urlize(
elif (
"@" in middle
and not middle.startswith("www.")
+ # ignore values like `@a@b`
+ and not middle.startswith("@")
and ":" not in middle
and _email_re.match(middle)
):
@@ -428,7 +439,7 @@ class LRUCache:
def __init__(self, capacity: int) -> None:
self.capacity = capacity
self._mapping: t.Dict[t.Any, t.Any] = {}
- self._queue: "te.Deque[t.Any]" = deque()
+ self._queue: te.Deque[t.Any] = deque()
self._postinit()
def _postinit(self) -> None:
@@ -450,10 +461,10 @@ class LRUCache:
self.__dict__.update(d)
self._postinit()
- def __getnewargs__(self) -> t.Tuple:
+ def __getnewargs__(self) -> t.Tuple[t.Any, ...]:
return (self.capacity,)
- def copy(self) -> "LRUCache":
+ def copy(self) -> "te.Self":
"""Return a shallow copy of the instance."""
rv = self.__class__(self.capacity)
rv._mapping.update(self._mapping)
diff --git a/visitor.py b/visitor.py
index 17c6aab..7b8e180 100755
--- a/visitor.py
+++ b/visitor.py
@@ -1,6 +1,7 @@
"""API for traversing the AST nodes. Implemented by the compiler and
meta introspection.
"""
+
import typing as t
from .nodes import Node
@@ -9,8 +10,7 @@ if t.TYPE_CHECKING:
import typing_extensions as te
class VisitCallable(te.Protocol):
- def __call__(self, node: Node, *args: t.Any, **kwargs: t.Any) -> t.Any:
- ...
+ def __call__(self, node: Node, *args: t.Any, **kwargs: t.Any) -> t.Any: ...
class NodeVisitor: