[autofix.ci] apply automated fixes

This commit is contained in:
autofix-ci[bot] 2023-02-27 07:28:30 +00:00 committed by Maximilian Hils
parent 46bfb35488
commit 51670861e6
121 changed files with 496 additions and 649 deletions

View File

@ -3,6 +3,7 @@
## Unreleased: mitmproxy next
* mitmproxy now requires Python 3.10 or above.
([#5954](https://github.com/mitmproxy/mitmproxy/pull/5954), @mhils)
* Fix a bug where the direction indicator in the message stream view would be in the wrong direction.
([#5921](https://github.com/mitmproxy/mitmproxy/issues/5921), @konradh)
* Fix a bug where peername would be None in tls_passthrough script, which would make it not working.

View File

@ -4,7 +4,6 @@ import subprocess
import threading
import time
from typing import NamedTuple
from typing import Optional
import libtmux
@ -74,7 +73,7 @@ class CliDirector:
self.tmux_session.kill_session()
def press_key(
self, keys: str, count=1, pause: Optional[float] = None, target=None
self, keys: str, count=1, pause: float | None = None, target=None
) -> None:
if pause is None:
pause = self.pause_between_keys
@ -97,7 +96,7 @@ class CliDirector:
real_pause += 2 * pause
self.pause(real_pause)
def type(self, keys: str, pause: Optional[float] = None, target=None) -> None:
def type(self, keys: str, pause: float | None = None, target=None) -> None:
if pause is None:
pause = self.pause_between_keys
if target is None:
@ -128,7 +127,7 @@ class CliDirector:
def message(
self,
msg: str,
duration: Optional[int] = None,
duration: int | None = None,
add_instruction: bool = True,
instruction_html: str = "",
) -> None:
@ -161,7 +160,7 @@ class CliDirector:
self.tmux_pane.cmd("display-popup", "-C")
def instruction(
self, instruction: str, duration: float = 3, time_from: Optional[float] = None
self, instruction: str, duration: float = 3, time_from: float | None = None
) -> None:
if time_from is None:
time_from = self.current_time

View File

@ -5,8 +5,6 @@ This example shows how one can add a custom contentview to mitmproxy,
which is used to pretty-print HTTP bodies for example.
The content view API is explained in the mitmproxy.contentviews module.
"""
from typing import Optional
from mitmproxy import contentviews
from mitmproxy import flow
from mitmproxy import http
@ -19,9 +17,9 @@ class ViewSwapCase(contentviews.View):
self,
data: bytes,
*,
content_type: Optional[str] = None,
flow: Optional[flow.Flow] = None,
http_message: Optional[http.Message] = None,
content_type: str | None = None,
flow: flow.Flow | None = None,
http_message: http.Message | None = None,
**unknown_metadata,
) -> contentviews.TViewResult:
return "case-swapped text", contentviews.format_text(data.swapcase())
@ -30,9 +28,9 @@ class ViewSwapCase(contentviews.View):
self,
data: bytes,
*,
content_type: Optional[str] = None,
flow: Optional[flow.Flow] = None,
http_message: Optional[http.Message] = None,
content_type: str | None = None,
flow: flow.Flow | None = None,
http_message: http.Message | None = None,
**unknown_metadata,
) -> float:
if content_type == "text/plain":

View File

@ -8,10 +8,9 @@ Modifying streamed responses is tricky and brittle:
where one chunk ends with [...]foo" and the next starts with "bar[...].
"""
from collections.abc import Iterable
from typing import Union
def modify(data: bytes) -> Union[bytes, Iterable[bytes]]:
def modify(data: bytes) -> bytes | Iterable[bytes]:
"""
This function will be called for each chunk of request/response body data that arrives at the proxy,
and once at the end of the message with an empty bytes argument (b"").

View File

@ -1,7 +1,5 @@
import json
from dataclasses import dataclass
from typing import Optional
from typing import Union
from mitmproxy import ctx
from mitmproxy.addonmanager import Loader
@ -55,8 +53,8 @@ In the following example, we override the HTTP host header:
@dataclass
class Mapping:
server: Union[str, None]
host: Union[str, None]
server: str | None
host: str | None
class HttpsDomainFronting:
@ -70,7 +68,7 @@ class HttpsDomainFronting:
self.strict_mappings = {}
self.star_mappings = {}
def _resolve_addresses(self, host: str) -> Optional[Mapping]:
def _resolve_addresses(self, host: str) -> Mapping | None:
mapping = self.strict_mappings.get(host)
if mapping is not None:
return mapping

View File

@ -15,7 +15,6 @@ Note:
"""
import json
from typing import Union
from mitmproxy import http
@ -29,7 +28,7 @@ FILTER_COOKIES = {
# -- Helper functions --
def load_json_cookies() -> list[dict[str, Union[str, None]]]:
def load_json_cookies() -> list[dict[str, str | None]]:
"""
Load a particular json file containing a list of cookies.
"""
@ -40,7 +39,7 @@ def load_json_cookies() -> list[dict[str, Union[str, None]]]:
# NOTE: or just hardcode the cookies as [{"name": "", "value": ""}]
def stringify_cookies(cookies: list[dict[str, Union[str, None]]]) -> str:
def stringify_cookies(cookies: list[dict[str, str | None]]) -> str:
"""
Creates a cookie string from a list of cookie dicts.
"""
@ -54,7 +53,7 @@ def stringify_cookies(cookies: list[dict[str, Union[str, None]]]) -> str:
)
def parse_cookies(cookie_string: str) -> list[dict[str, Union[str, None]]]:
def parse_cookies(cookie_string: str) -> list[dict[str, str | None]]:
"""
Parses a cookie string into a list of cookie dicts.
"""

View File

@ -115,7 +115,7 @@ class NTLMUpstreamAuth:
def patched_receive_handshake_data(
self, data
) -> layer.CommandGenerator[tuple[bool, Optional[str]]]:
) -> layer.CommandGenerator[tuple[bool, str | None]]:
self.buf += data
response_head = self.buf.maybe_extract_lines()
if response_head:

View File

@ -1,4 +1,4 @@
from typing import Callable
from collections.abc import Callable
from typing import TextIO
from unittest import mock
from unittest.mock import MagicMock

View File

@ -1,12 +1,11 @@
import itertools
import json
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import MutableMapping
from typing import Any
from typing import Callable
from typing import cast
from typing import TextIO
from typing import Union
from mitmproxy import flowfilter
from mitmproxy.http import HTTPFlow
@ -78,7 +77,7 @@ class URLDict(MutableMapping):
return cls._load(json_obj, value_loader)
def _dump(self, value_dumper: Callable = f_id) -> dict:
dumped: dict[Union[flowfilter.TFilter, str], Any] = {}
dumped: dict[flowfilter.TFilter | str, Any] = {}
for fltr, value in self.store.items():
if hasattr(fltr, "pattern"):
# cast necessary for mypy

View File

@ -3,8 +3,6 @@ import datetime
import json
import logging
from pathlib import Path
from typing import Optional
from typing import Union
from mitmproxy import flowfilter
from mitmproxy.http import HTTPFlow
@ -118,7 +116,7 @@ class UrlIndexAddon:
The injection can be done using the URLInjection Add-on.
"""
index_filter: Optional[Union[str, flowfilter.TFilter]]
index_filter: str | flowfilter.TFilter | None
writer: UrlIndexWriter
OPT_FILEPATH = "URLINDEX_FILEPATH"
@ -127,9 +125,9 @@ class UrlIndexAddon:
def __init__(
self,
file_path: Union[str, Path],
file_path: str | Path,
append: bool = True,
index_filter: Union[str, flowfilter.TFilter] = filter_404,
index_filter: str | flowfilter.TFilter = filter_404,
index_format: str = "json",
):
"""Initializes the urlindex add-on.

View File

@ -2,7 +2,6 @@ import logging
import pathlib
import time
from datetime import datetime
from typing import Union
import mitmproxy.connections
import mitmproxy.http
@ -36,8 +35,8 @@ class WatchdogAddon:
raise RuntimeError("Watchtdog output path must be a directory.")
elif not self.flow_dir.exists():
self.flow_dir.mkdir(parents=True)
self.last_trigger: Union[None, float] = None
self.timeout: Union[None, float] = timeout
self.last_trigger: None | float = None
self.timeout: None | float = timeout
def serverconnect(self, conn: mitmproxy.connections.ServerConnection):
if self.timeout is not None:

View File

@ -40,7 +40,6 @@ import socket
from html.parser import HTMLParser
from typing import NamedTuple
from typing import Optional
from typing import Union
from urllib.parse import urlparse
import requests
@ -93,7 +92,7 @@ def get_cookies(flow: http.HTTPFlow) -> Cookies:
def find_unclaimed_URLs(body, requestUrl):
"""Look for unclaimed URLs in script tags and log them if found"""
def getValue(attrs: list[tuple[str, str]], attrName: str) -> Optional[str]:
def getValue(attrs: list[tuple[str, str]], attrName: str) -> str | None:
for name, value in attrs:
if attrName == name:
return value
@ -188,7 +187,7 @@ def test_query_injection(original_body: str, request_URL: str, cookies: Cookies)
return xss_info, sqli_info
def log_XSS_data(xss_info: Optional[XSSData]) -> None:
def log_XSS_data(xss_info: XSSData | None) -> None:
"""Log information about the given XSS to mitmproxy"""
# If it is None, then there is no info to log
if not xss_info:
@ -200,7 +199,7 @@ def log_XSS_data(xss_info: Optional[XSSData]) -> None:
logging.error("Line: %s" % xss_info.line)
def log_SQLi_data(sqli_info: Optional[SQLiData]) -> None:
def log_SQLi_data(sqli_info: SQLiData | None) -> None:
"""Log information about the given SQLi to mitmproxy"""
if not sqli_info:
return
@ -214,7 +213,7 @@ def log_SQLi_data(sqli_info: Optional[SQLiData]) -> None:
def get_SQLi_data(
new_body: str, original_body: str, request_URL: str, injection_point: str
) -> Optional[SQLiData]:
) -> SQLiData | None:
"""Return a SQLiDict if there is a SQLi otherwise return None
String String URL String -> (SQLiDict or None)"""
# Regexes taken from Damn Small SQLi Scanner: https://github.com/stamparm/DSSS/blob/master/dsss.py#L17
@ -337,8 +336,8 @@ def paths_to_text(html: str, string: str) -> list[str]:
def get_XSS_data(
body: Union[str, bytes], request_URL: str, injection_point: str
) -> Optional[XSSData]:
body: str | bytes, request_URL: str, injection_point: str
) -> XSSData | None:
"""Return a XSSDict if there is a XSS otherwise return None"""
def in_script(text, index, body) -> bool:

View File

@ -9,7 +9,6 @@ from collections.abc import Callable
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
from typing import Optional
from mitmproxy import exceptions
from mitmproxy import flow
@ -71,7 +70,7 @@ class Loader:
typespec: type,
default: Any,
help: str,
choices: Optional[Sequence[str]] = None,
choices: Sequence[str] | None = None,
) -> None:
"""
Add an option to mitmproxy.

View File

@ -2,7 +2,6 @@ import asyncio
import logging
import traceback
import urllib.parse
from typing import Optional
import asgiref.compatibility
import asgiref.wsgi
@ -22,7 +21,7 @@ class ASGIApp:
- It currently only implements the HTTP protocol (Lifespan and WebSocket are unimplemented).
"""
def __init__(self, asgi_app, host: str, port: Optional[int]):
def __init__(self, asgi_app, host: str, port: int | None):
asgi_app = asgiref.compatibility.guarantee_single_callable(asgi_app)
self.asgi_app, self.host, self.port = asgi_app, host, port
@ -45,7 +44,7 @@ class ASGIApp:
class WSGIApp(ASGIApp):
def __init__(self, wsgi_app, host: str, port: Optional[int]):
def __init__(self, wsgi_app, host: str, port: int | None):
asgi_app = asgiref.wsgi.WsgiToAsgi(wsgi_app)
super().__init__(asgi_app, host, port)

View File

@ -2,14 +2,13 @@ import logging
import shutil
import subprocess
import tempfile
from typing import Optional
from mitmproxy import command
from mitmproxy import ctx
from mitmproxy.log import ALERT
def get_chrome_executable() -> Optional[str]:
def get_chrome_executable() -> str | None:
for browser in (
"/Applications/Google Chrome.app/Contents/MacOS/Google Chrome",
# https://stackoverflow.com/questions/40674914/google-chrome-path-in-windows-10
@ -29,7 +28,7 @@ def get_chrome_executable() -> Optional[str]:
return None
def get_chrome_flatpak() -> Optional[str]:
def get_chrome_flatpak() -> str | None:
if shutil.which("flatpak"):
for browser in (
"com.google.Chrome",
@ -50,7 +49,7 @@ def get_chrome_flatpak() -> Optional[str]:
return None
def get_browser_cmd() -> Optional[list[str]]:
def get_browser_cmd() -> list[str] | None:
if browser := get_chrome_executable():
return [browser]
elif browser := get_chrome_flatpak():

View File

@ -6,7 +6,6 @@ import time
import traceback
from collections.abc import Sequence
from typing import cast
from typing import Optional
import mitmproxy.types
from mitmproxy import command
@ -134,8 +133,8 @@ class ReplayHandler(server.ConnectionHandler):
class ClientPlayback:
playback_task: Optional[asyncio.Task] = None
inflight: Optional[http.HTTPFlow]
playback_task: asyncio.Task | None = None
inflight: http.HTTPFlow | None
queue: asyncio.Queue
options: Options
replay_tasks: set[asyncio.Task]
@ -176,7 +175,7 @@ class ClientPlayback:
self.queue.task_done()
self.inflight = None
def check(self, f: flow.Flow) -> Optional[str]:
def check(self, f: flow.Flow) -> str | None:
if f.live or f == self.inflight:
return "Can't replay live flow."
if f.intercepted:

View File

@ -1,7 +1,6 @@
import logging
import os
from collections.abc import Sequence
from typing import Union
import mitmproxy.types
from mitmproxy import command
@ -135,7 +134,7 @@ class Core:
"""
Quickly set a number of common values on flows.
"""
val: Union[int, str] = value
val: int | str = value
if attr == "status_code":
try:
val = int(val) # type: ignore

View File

@ -4,7 +4,6 @@ import logging
import os.path
from collections.abc import Sequence
from typing import Any
from typing import Union
import pyperclip
@ -28,7 +27,7 @@ def is_addr(v):
return isinstance(v, tuple) and len(v) > 1
def extract(cut: str, f: flow.Flow) -> Union[str, bytes]:
def extract(cut: str, f: flow.Flow) -> str | bytes:
path = cut.split(".")
current: Any = f
for i, spec in enumerate(path):
@ -86,7 +85,7 @@ class Cut:
or "false", "bytes" are preserved, and all other values are
converted to strings.
"""
ret: list[list[Union[str, bytes]]] = []
ret: list[list[str | bytes]] = []
for f in flows:
ret.append([extract(c, f) for c in cuts])
return ret # type: ignore
@ -148,7 +147,7 @@ class Cut:
format is UTF-8 encoded CSV. If there is exactly one row and one
column, the data is written to file as-is, with raw bytes preserved.
"""
v: Union[str, bytes]
v: str | bytes
fp = io.StringIO(newline="")
if len(cuts) == 1 and len(flows) == 1:
v = extract_str(cuts[0], flows[0])

View File

@ -1,9 +1,8 @@
import asyncio
import ipaddress
import socket
from collections.abc import Callable
from collections.abc import Iterable
from typing import Callable
from typing import Union
from mitmproxy import dns
from mitmproxy.proxy import mode_specs
@ -24,7 +23,7 @@ async def resolve_question_by_name(
question: dns.Question,
loop: asyncio.AbstractEventLoop,
family: socket.AddressFamily,
ip: Callable[[str], Union[ipaddress.IPv4Address, ipaddress.IPv6Address]],
ip: Callable[[str], ipaddress.IPv4Address | ipaddress.IPv6Address],
) -> Iterable[dns.ResourceRecord]:
try:
addrinfos = await loop.getaddrinfo(host=question.name, port=0, family=family)
@ -51,7 +50,7 @@ async def resolve_question_by_addr(
question: dns.Question,
loop: asyncio.AbstractEventLoop,
suffix: str,
sockaddr: Callable[[list[str]], Union[tuple[str, int], tuple[str, int, int, int]]],
sockaddr: Callable[[list[str]], tuple[str, int] | tuple[str, int, int, int]],
) -> Iterable[dns.ResourceRecord]:
try:
addr = sockaddr(question.name[: -len(suffix)].split(".")[::-1])

View File

@ -6,7 +6,6 @@ import shutil
import sys
from typing import IO
from typing import Optional
from typing import Union
from wsproto.frame_protocol import CloseReason
@ -45,8 +44,8 @@ CONTENTVIEW_STYLES: dict[str, dict[str, str | bool]] = {
class Dumper:
def __init__(self, outfile: Optional[IO[str]] = None):
self.filter: Optional[flowfilter.TFilter] = None
def __init__(self, outfile: IO[str] | None = None):
self.filter: flowfilter.TFilter | None = None
self.outfp: IO[str] = outfile or sys.stdout
self.out_has_vt_codes = vt_codes.ensure_supported(self.outfp)
@ -103,7 +102,7 @@ class Dumper:
vs = strutils.bytes_to_escaped_str(v)
self.echo(f"{ks}: {vs}", ident=4)
def _echo_trailers(self, trailers: Optional[http.Headers]):
def _echo_trailers(self, trailers: http.Headers | None):
if not trailers:
return
self.echo("--- HTTP Trailers", fg="magenta", ident=4)
@ -116,8 +115,8 @@ class Dumper:
def _echo_message(
self,
message: Union[http.Message, TCPMessage, UDPMessage, WebSocketMessage],
flow: Union[http.HTTPFlow, TCPFlow, UDPFlow],
message: http.Message | TCPMessage | UDPMessage | WebSocketMessage,
flow: http.HTTPFlow | TCPFlow | UDPFlow,
):
_, lines, error = contentviews.get_message_content_view(
ctx.options.dumper_default_contentview, message, flow
@ -341,7 +340,7 @@ class Dumper:
def udp_error(self, f):
self._proto_error(f)
def _proto_message(self, f: Union[TCPFlow, UDPFlow]) -> None:
def _proto_message(self, f: TCPFlow | UDPFlow) -> None:
if self.match(f):
message = f.messages[-1]
direction = "->" if message.from_client else "<-"

View File

@ -2,7 +2,6 @@ import asyncio
import collections
import logging
from collections.abc import Callable
from typing import Optional
from mitmproxy import command
from mitmproxy import log
@ -27,7 +26,7 @@ class EventStore:
self.sig_add.send(entry)
@property
def size(self) -> Optional[int]:
def size(self) -> int | None:
return self.data.maxlen
@command.command("eventstore.clear")

View File

@ -3,7 +3,6 @@ import shlex
from collections.abc import Callable
from collections.abc import Sequence
from typing import Any
from typing import Union
import pyperclip
@ -142,7 +141,7 @@ def raw(f: flow.Flow, separator=b"\r\n\r\n") -> bytes:
raise exceptions.CommandError("Can't export flow with no request or response.")
formats: dict[str, Callable[[flow.Flow], Union[str, bytes]]] = dict(
formats: dict[str, Callable[[flow.Flow], str | bytes]] = dict(
curl=curl_command,
httpie=httpie_command,
raw=raw,

View File

@ -7,7 +7,7 @@ from mitmproxy import flowfilter
class Intercept:
filt: Optional[flowfilter.TFilter] = None
filt: flowfilter.TFilter | None = None
def load(self, loader):
loader.add_option("intercept_active", bool, False, "Intercept toggle")

View File

@ -16,12 +16,11 @@ that sets nextlayer.layer works just as well.
"""
import re
import struct
from collections.abc import Callable
from collections.abc import Iterable
from collections.abc import Sequence
from typing import Any
from typing import Callable
from typing import cast
from typing import Optional
from typing import Union
from mitmproxy import connection
@ -51,7 +50,7 @@ ServerSecurityLayerCls = Union[
def stack_match(
context: context.Context, layers: Sequence[Union[LayerCls, tuple[LayerCls, ...]]]
context: context.Context, layers: Sequence[LayerCls | tuple[LayerCls, ...]]
) -> bool:
if len(context.layers) != len(layers):
return False
@ -90,12 +89,12 @@ class NextLayer:
def ignore_connection(
self,
server_address: Optional[connection.Address],
server_address: connection.Address | None,
data_client: bytes,
*,
is_tls: Callable[[bytes], bool] = is_tls_record_magic,
client_hello: Callable[[bytes], Optional[ClientHello]] = parse_client_hello,
) -> Optional[bool]:
client_hello: Callable[[bytes], ClientHello | None] = parse_client_hello,
) -> bool | None:
"""
Returns:
True, if the connection should be ignored.
@ -174,7 +173,7 @@ class NextLayer:
for rex in hosts
)
def get_http_layer(self, context: context.Context) -> Optional[layers.HttpLayer]:
def get_http_layer(self, context: context.Context) -> layers.HttpLayer | None:
def s(*layers):
return stack_match(context, layers)
@ -195,7 +194,7 @@ class NextLayer:
def detect_udp_tls(
self, data_client: bytes
) -> Optional[tuple[ClientHello, ClientSecurityLayerCls, ServerSecurityLayerCls]]:
) -> tuple[ClientHello, ClientSecurityLayerCls, ServerSecurityLayerCls] | None:
if len(data_client) == 0:
return None
@ -257,7 +256,7 @@ class NextLayer:
def _next_layer(
self, context: context.Context, data_client: bytes, data_server: bytes
) -> Optional[layer.Layer]:
) -> layer.Layer | None:
assert context.layers
if context.client.transport_protocol == "tcp":

View File

@ -112,8 +112,8 @@ class Proxyserver(ServerManager):
servers: Servers
is_running: bool
_connect_addr: Optional[Address] = None
_update_task: Optional[asyncio.Task] = None
_connect_addr: Address | None = None
_update_task: asyncio.Task | None = None
def __init__(self):
self.connections = {}

View File

@ -41,10 +41,10 @@ def _mode(path: str) -> Literal["ab", "wb"]:
class Save:
def __init__(self) -> None:
self.stream: Optional[io.FilteredFlowWriter] = None
self.filt: Optional[flowfilter.TFilter] = None
self.stream: io.FilteredFlowWriter | None = None
self.filt: flowfilter.TFilter | None = None
self.active_flows: set[flow.Flow] = set()
self.current_path: Optional[str] = None
self.current_path: str | None = None
def load(self, loader):
loader.add_option(

View File

@ -7,7 +7,6 @@ import sys
import traceback
import types
from collections.abc import Sequence
from typing import Optional
import mitmproxy.types as mtypes
from mitmproxy import addonmanager
@ -22,7 +21,7 @@ from mitmproxy.utils import asyncio_utils
logger = logging.getLogger(__name__)
def load_script(path: str) -> Optional[types.ModuleType]:
def load_script(path: str) -> types.ModuleType | None:
fullname = "__mitmproxy_script__.{}".format(
os.path.splitext(os.path.basename(path))[0]
)

View File

@ -4,7 +4,6 @@ import urllib
from collections.abc import Hashable
from collections.abc import Sequence
from typing import Any
from typing import Optional
import mitmproxy.types
from mitmproxy import command
@ -195,7 +194,7 @@ class ServerPlayback:
key.append(headers)
return hashlib.sha256(repr(key).encode("utf8", "surrogateescape")).digest()
def next_flow(self, flow: http.HTTPFlow) -> Optional[http.HTTPFlow]:
def next_flow(self, flow: http.HTTPFlow) -> http.HTTPFlow | None:
"""
Returns the next flow object, or None if no matching flow was
found.

View File

@ -37,7 +37,7 @@ class StickyCookie:
self.jar: collections.defaultdict[
TOrigin, dict[str, str]
] = collections.defaultdict(dict)
self.flt: Optional[flowfilter.TFilter] = None
self.flt: flowfilter.TFilter | None = None
def load(self, loader):
loader.add_option(

View File

@ -4,7 +4,6 @@ import os
import ssl
from pathlib import Path
from typing import Any
from typing import Optional
from typing import TypedDict
from aioquic.h3.connection import H3_ALPN
@ -64,8 +63,8 @@ DEFAULT_HOSTFLAGS = (
class AppData(TypedDict):
client_alpn: Optional[bytes]
server_alpn: Optional[bytes]
client_alpn: bytes | None
server_alpn: bytes | None
http2: bool
@ -200,7 +199,7 @@ class TlsConfig:
if len(tls_start.context.layers) == 2 and isinstance(
tls_start.context.layers[0], modes.HttpProxy
):
client_alpn: Optional[bytes] = b"http/1.1"
client_alpn: bytes | None = b"http/1.1"
else:
client_alpn = client.alpn
@ -257,7 +256,7 @@ class TlsConfig:
# don't assign to client.cipher_list, doesn't need to be stored.
cipher_list = server.cipher_list or DEFAULT_CIPHERS
client_cert: Optional[str] = None
client_cert: str | None = None
if ctx.options.client_certs:
client_certs = os.path.expanduser(ctx.options.client_certs)
if os.path.isfile(client_certs):
@ -455,7 +454,7 @@ class TlsConfig:
our certificate should have and then fetches a matching cert from the certstore.
"""
altnames: list[str] = []
organization: Optional[str] = None
organization: str | None = None
# Use upstream certificate if available.
if ctx.options.upstream_cert and conn_context.server.certificate_list:

View File

@ -27,7 +27,7 @@ class UpstreamAuth:
- Reverse proxy regular requests (CONNECT is invalid in this mode)
"""
auth: Optional[bytes] = None
auth: bytes | None = None
def load(self, loader):
loader.add_option(

View File

@ -235,7 +235,7 @@ class View(collections.abc.Sequence):
return self._rev(v - 1) + 1
def index(
self, f: mitmproxy.flow.Flow, start: int = 0, stop: Optional[int] = None
self, f: mitmproxy.flow.Flow, start: int = 0, stop: int | None = None
) -> int:
return self._rev(self._view.index(f, start, stop))
@ -353,7 +353,7 @@ class View(collections.abc.Sequence):
raise exceptions.CommandError(str(e)) from e
self.set_filter(filt)
def set_filter(self, flt: Optional[flowfilter.TFilter]):
def set_filter(self, flt: flowfilter.TFilter | None):
self.filter = flt or flowfilter.match_all
self._refilter()
@ -524,7 +524,7 @@ class View(collections.abc.Sequence):
self.focus.flow = f
self.sig_view_add.send(flow=f)
def get_by_id(self, flow_id: str) -> Optional[mitmproxy.flow.Flow]:
def get_by_id(self, flow_id: str) -> mitmproxy.flow.Flow | None:
"""
Get flow with the given id from the store.
Returns None if the flow is not found.
@ -669,7 +669,7 @@ class Focus:
def __init__(self, v: View) -> None:
self.view = v
self._flow: Optional[mitmproxy.flow.Flow] = None
self._flow: mitmproxy.flow.Flow | None = None
self.sig_change = signals.SyncSignal(lambda: None)
if len(self.view):
self.flow = self.view[0]
@ -678,18 +678,18 @@ class Focus:
v.sig_view_refresh.connect(self._sig_view_refresh)
@property
def flow(self) -> Optional[mitmproxy.flow.Flow]:
def flow(self) -> mitmproxy.flow.Flow | None:
return self._flow
@flow.setter
def flow(self, f: Optional[mitmproxy.flow.Flow]):
def flow(self, f: mitmproxy.flow.Flow | None):
if f is not None and f not in self.view:
raise ValueError("Attempt to set focus to flow not in view")
self._flow = f
self.sig_change.send()
@property
def index(self) -> Optional[int]:
def index(self) -> int | None:
if self.flow:
return self.view.index(self.flow)
return None

View File

@ -131,14 +131,14 @@ class Cert(serializable.Serializable):
) # pragma: no cover
@property
def cn(self) -> Optional[str]:
def cn(self) -> str | None:
attrs = self._cert.subject.get_attributes_for_oid(x509.NameOID.COMMON_NAME)
if attrs:
return attrs[0].value
return None
@property
def organization(self) -> Optional[str]:
def organization(self) -> str | None:
attrs = self._cert.subject.get_attributes_for_oid(
x509.NameOID.ORGANIZATION_NAME
)
@ -231,9 +231,9 @@ def create_ca(
def dummy_cert(
privkey: rsa.RSAPrivateKey,
cacert: x509.Certificate,
commonname: Optional[str],
commonname: str | None,
sans: list[str],
organization: Optional[str] = None,
organization: str | None = None,
) -> Cert:
"""
Generates a dummy certificate.
@ -288,7 +288,7 @@ def dummy_cert(
class CertStoreEntry:
cert: Cert
privatekey: rsa.RSAPrivateKey
chain_file: Optional[Path]
chain_file: Path | None
chain_certs: list[Cert]
@ -312,7 +312,7 @@ class CertStore:
self,
default_privatekey: rsa.RSAPrivateKey,
default_ca: Cert,
default_chain_file: Optional[Path],
default_chain_file: Path | None,
dhparams: DHParams,
):
self.default_privatekey = default_privatekey
@ -365,10 +365,10 @@ class CertStore:
@classmethod
def from_store(
cls,
path: Union[Path, str],
path: Path | str,
basename: str,
key_size: int,
passphrase: Optional[bytes] = None,
passphrase: bytes | None = None,
) -> "CertStore":
path = Path(path)
ca_file = path / f"{basename}-ca.pem"
@ -379,7 +379,7 @@ class CertStore:
@classmethod
def from_files(
cls, ca_file: Path, dhparam_file: Path, passphrase: Optional[bytes] = None
cls, ca_file: Path, dhparam_file: Path, passphrase: bytes | None = None
) -> "CertStore":
raw = ca_file.read_bytes()
key = load_pem_private_key(raw, passphrase)
@ -387,7 +387,7 @@ class CertStore:
certs = re.split(rb"(?=-----BEGIN CERTIFICATE-----)", raw)
ca = Cert.from_pem(certs[1])
if len(certs) > 2:
chain_file: Optional[Path] = ca_file
chain_file: Path | None = ca_file
else:
chain_file = None
return cls(key, ca, chain_file, dh)
@ -463,7 +463,7 @@ class CertStore:
(path / f"{basename}-dhparam.pem").write_bytes(DEFAULT_DHPARAM)
def add_cert_file(
self, spec: str, path: Path, passphrase: Optional[bytes] = None
self, spec: str, path: Path, passphrase: bytes | None = None
) -> None:
raw = path.read_bytes()
cert = Cert.from_pem(raw)
@ -500,9 +500,9 @@ class CertStore:
def get_cert(
self,
commonname: Optional[str],
commonname: str | None,
sans: list[str],
organization: Optional[str] = None,
organization: str | None = None,
) -> CertStoreEntry:
"""
commonname: Common name for the generated certificate. Must be a
@ -543,7 +543,7 @@ class CertStore:
return entry
def load_pem_private_key(data: bytes, password: Optional[bytes]) -> rsa.RSAPrivateKey:
def load_pem_private_key(data: bytes, password: bytes | None) -> rsa.RSAPrivateKey:
"""
like cryptography's load_pem_private_key, but silently falls back to not using a password
if the private key is unencrypted.

View File

@ -12,7 +12,6 @@ from collections.abc import Iterable
from collections.abc import Sequence
from typing import Any
from typing import NamedTuple
from typing import Optional
import pyparsing
@ -66,7 +65,7 @@ class Command:
name: str
manager: "CommandManager"
signature: inspect.Signature
help: Optional[str]
help: str | None
def __init__(self, manager: "CommandManager", name: str, func: Callable) -> None:
self.name = name
@ -95,7 +94,7 @@ class Command:
)
@property
def return_type(self) -> Optional[type]:
def return_type(self) -> type | None:
return _empty_as_none(self.signature.return_annotation)
@property
@ -209,7 +208,7 @@ class CommandManager:
CommandParameter("", mitmproxy.types.Cmd),
CommandParameter("", mitmproxy.types.CmdArgs),
]
expected: Optional[CommandParameter] = None
expected: CommandParameter | None = None
for part in parts:
if part.isspace():
parsed.append(
@ -314,7 +313,7 @@ def parsearg(manager: CommandManager, spec: str, argtype: type) -> Any:
raise exceptions.CommandError(str(e)) from e
def command(name: Optional[str] = None):
def command(name: str | None = None):
def decorator(function):
@functools.wraps(function)
def wrapper(*args, **kwargs):

View File

@ -9,7 +9,6 @@ from dataclasses import dataclass
from dataclasses import field
from enum import Flag
from typing import Literal
from typing import Optional
from mitmproxy import certs
from mitmproxy.coretypes import serializable
@ -35,10 +34,7 @@ TransportProtocol = Literal["tcp", "udp"]
# this version at least provides useful type checking messages.
Address = tuple[str, int]
if sys.version_info < (3, 10): # pragma: no cover
kw_only = {}
else:
kw_only = {"kw_only": True}
kw_only = {"kw_only": True}
# noinspection PyDataclass
@ -51,9 +47,9 @@ class Connection(serializable.SerializableDataclass, metaclass=ABCMeta):
This is intentional, all I/O should be handled by `mitmproxy.proxy.server` exclusively.
"""
peername: Optional[Address]
peername: Address | None
"""The remote's `(ip, port)` tuple for this connection."""
sockname: Optional[Address]
sockname: Address | None
"""Our local `(ip, port)` tuple for this connection."""
state: ConnectionState = field(
@ -68,7 +64,7 @@ class Connection(serializable.SerializableDataclass, metaclass=ABCMeta):
"""A unique UUID to identify the connection."""
transport_protocol: TransportProtocol = field(default="tcp")
"""The connection protocol in use."""
error: Optional[str] = None
error: str | None = None
"""
A string describing a general error with connections to this address.
@ -99,27 +95,27 @@ class Connection(serializable.SerializableDataclass, metaclass=ABCMeta):
> TLS version, with the exception of the end-entity certificate which
> MUST be first.
"""
alpn: Optional[bytes] = None
alpn: bytes | None = None
"""The application-layer protocol as negotiated using
[ALPN](https://en.wikipedia.org/wiki/Application-Layer_Protocol_Negotiation)."""
alpn_offers: Sequence[bytes] = ()
"""The ALPN offers as sent in the ClientHello."""
# we may want to add SSL_CIPHER_description here, but that's currently not exposed by cryptography
cipher: Optional[str] = None
cipher: str | None = None
"""The active cipher name as returned by OpenSSL's `SSL_CIPHER_get_name`."""
cipher_list: Sequence[str] = ()
"""Ciphers accepted by the proxy server on this connection."""
tls_version: Optional[str] = None
tls_version: str | None = None
"""The active TLS version."""
sni: Optional[str] = None
sni: str | None = None
"""
The [Server Name Indication (SNI)](https://en.wikipedia.org/wiki/Server_Name_Indication) sent in the ClientHello.
"""
timestamp_start: Optional[float] = None
timestamp_end: Optional[float] = None
timestamp_start: float | None = None
timestamp_end: float | None = None
"""*Timestamp:* Connection has been closed."""
timestamp_tls_setup: Optional[float] = None
timestamp_tls_setup: float | None = None
"""*Timestamp:* TLS handshake has been completed successfully."""
@property
@ -157,7 +153,7 @@ class Connection(serializable.SerializableDataclass, metaclass=ABCMeta):
return f"{type(self).__name__}({attrs!r})"
@property
def alpn_proto_negotiated(self) -> Optional[bytes]: # pragma: no cover
def alpn_proto_negotiated(self) -> bytes | None: # pragma: no cover
"""*Deprecated:* An outdated alias for Connection.alpn."""
warnings.warn(
"Connection.alpn_proto_negotiated is deprecated, use Connection.alpn instead.",
@ -177,7 +173,7 @@ class Client(Connection):
sockname: Address
"""The local address we received this connection on."""
mitmcert: Optional[certs.Cert] = None
mitmcert: certs.Cert | None = None
"""
The certificate used by mitmproxy to establish TLS with the client.
"""
@ -221,7 +217,7 @@ class Client(Connection):
self.peername = x
@property
def cipher_name(self) -> Optional[str]: # pragma: no cover
def cipher_name(self) -> str | None: # pragma: no cover
"""*Deprecated:* An outdated alias for Connection.cipher."""
warnings.warn(
"Client.cipher_name is deprecated, use Client.cipher instead.",
@ -231,7 +227,7 @@ class Client(Connection):
return self.cipher
@property
def clientcert(self) -> Optional[certs.Cert]: # pragma: no cover
def clientcert(self) -> certs.Cert | None: # pragma: no cover
"""*Deprecated:* An outdated alias for Connection.certificate_list[0]."""
warnings.warn(
"Client.clientcert is deprecated, use Client.certificate_list instead.",
@ -261,30 +257,30 @@ class Client(Connection):
class Server(Connection):
"""A connection between mitmproxy and an upstream server."""
address: Optional[Address] # type: ignore
address: Address | None # type: ignore
"""The server's `(host, port)` address tuple. The host can either be a domain or a plain IP address."""
if sys.version_info < (3, 10): # pragma: no cover
# no keyword-only arguments here.
address: Optional[Address] = None
address: Address | None = None
peername: Optional[Address] = None
peername: Address | None = None
"""
The server's resolved `(ip, port)` tuple. Will be set during connection establishment.
May be `None` in upstream proxy mode when the address is resolved by the upstream proxy only.
"""
sockname: Optional[Address] = None
sockname: Address | None = None
timestamp_start: Optional[float] = None
timestamp_start: float | None = None
"""
*Timestamp:* Connection establishment started.
For IP addresses, this corresponds to sending a TCP SYN; for domains, this corresponds to starting a DNS lookup.
"""
timestamp_tcp_setup: Optional[float] = None
timestamp_tcp_setup: float | None = None
"""*Timestamp:* TCP ACK received."""
via: Optional[server_spec.ServerSpec] = None
via: server_spec.ServerSpec | None = None
"""An optional proxy server specification via which the connection should be established."""
def __str__(self):
@ -315,7 +311,7 @@ class Server(Connection):
return super().__setattr__(name, value)
@property
def ip_address(self) -> Optional[Address]: # pragma: no cover
def ip_address(self) -> Address | None: # pragma: no cover
"""*Deprecated:* An outdated alias for `Server.peername`."""
warnings.warn(
"Server.ip_address is deprecated, use Server.peername instead.",
@ -325,7 +321,7 @@ class Server(Connection):
return self.peername
@property
def cert(self) -> Optional[certs.Cert]: # pragma: no cover
def cert(self) -> certs.Cert | None: # pragma: no cover
"""*Deprecated:* An outdated alias for `Connection.certificate_list[0]`."""
warnings.warn(
"Server.cert is deprecated, use Server.certificate_list instead.",

View File

@ -12,8 +12,6 @@ metadata depend on the protocol in use. Known attributes can be found in
`base.View`.
"""
import traceback
from typing import Optional
from typing import Union
from . import auto
from . import css
@ -61,7 +59,7 @@ on_remove = signals.SyncSignal(_update)
"""A contentview has been removed."""
def get(name: str) -> Optional[View]:
def get(name: str) -> View | None:
for i in views:
if i.name.lower() == name.lower():
return i
@ -99,7 +97,7 @@ def safe_to_print(lines, encoding="utf8"):
def get_message_content_view(
viewname: str,
message: Union[http.Message, TCPMessage, UDPMessage, WebSocketMessage],
message: http.Message | TCPMessage | UDPMessage | WebSocketMessage,
flow: flow.Flow,
):
"""
@ -110,7 +108,7 @@ def get_message_content_view(
viewmode = get("auto")
assert viewmode
content: Optional[bytes]
content: bytes | None
try:
content = message.content
except ValueError:
@ -162,11 +160,11 @@ def get_content_view(
viewmode: View,
data: bytes,
*,
content_type: Optional[str] = None,
flow: Optional[flow.Flow] = None,
http_message: Optional[http.Message] = None,
tcp_message: Optional[tcp.TCPMessage] = None,
udp_message: Optional[udp.UDPMessage] = None,
content_type: str | None = None,
flow: flow.Flow | None = None,
http_message: http.Message | None = None,
tcp_message: tcp.TCPMessage | None = None,
udp_message: udp.UDPMessage | None = None,
):
"""
Args:

View File

@ -5,7 +5,6 @@ from collections.abc import Iterable
from collections.abc import Iterator
from collections.abc import Mapping
from typing import ClassVar
from typing import Optional
from typing import Union
from mitmproxy import flow
@ -26,9 +25,9 @@ class View(ABC):
self,
data: bytes,
*,
content_type: Optional[str] = None,
flow: Optional[flow.Flow] = None,
http_message: Optional[http.Message] = None,
content_type: str | None = None,
flow: flow.Flow | None = None,
http_message: http.Message | None = None,
**unknown_metadata,
) -> TViewResult:
"""
@ -52,9 +51,9 @@ class View(ABC):
self,
data: bytes,
*,
content_type: Optional[str] = None,
flow: Optional[flow.Flow] = None,
http_message: Optional[http.Message] = None,
content_type: str | None = None,
flow: flow.Flow | None = None,
http_message: http.Message | None = None,
**unknown_metadata,
) -> float:
"""

View File

@ -1,6 +1,5 @@
import re
import time
from typing import Optional
from mitmproxy.contentviews import base
from mitmproxy.utils import strutils
@ -58,7 +57,7 @@ class ViewCSS(base.View):
return "CSS", base.format_text(beautified)
def render_priority(
self, data: bytes, *, content_type: Optional[str] = None, **metadata
self, data: bytes, *, content_type: str | None = None, **metadata
) -> float:
return float(bool(data) and content_type == "text/css")

View File

@ -1,6 +1,5 @@
import json
from typing import Any
from typing import Optional
from mitmproxy.contentviews import base
from mitmproxy.contentviews.json import PARSE_ERROR
@ -48,7 +47,7 @@ class ViewGraphQL(base.View):
return "GraphQL", base.format_text(format_query_list(data))
def render_priority(
self, data: bytes, *, content_type: Optional[str] = None, **metadata
self, data: bytes, *, content_type: str | None = None, **metadata
) -> float:
if content_type != "application/json" or not data:
return 0

View File

@ -2,8 +2,6 @@ from collections import defaultdict
from collections.abc import Iterator
from dataclasses import dataclass
from dataclasses import field
from typing import Optional
from typing import Union
import pylsqpack
from aioquic.buffer import Buffer
@ -74,7 +72,7 @@ class StreamType:
@dataclass
class ConnectionState:
message_count: int = 0
frames: dict[int, list[Union[Frame, StreamType]]] = field(default_factory=dict)
frames: dict[int, list[Frame | StreamType]] = field(default_factory=dict)
client_buf: bytearray = field(default_factory=bytearray)
server_buf: bytearray = field(default_factory=bytearray)
@ -90,8 +88,8 @@ class ViewHttp3(base.View):
def __call__(
self,
data,
flow: Optional[flow.Flow] = None,
tcp_message: Optional[tcp.TCPMessage] = None,
flow: flow.Flow | None = None,
tcp_message: tcp.TCPMessage | None = None,
**metadata,
):
assert isinstance(flow, tcp.TCPFlow)
@ -146,7 +144,7 @@ class ViewHttp3(base.View):
return "HTTP/3", fmt_frames(frames)
def render_priority(
self, data: bytes, flow: Optional[flow.Flow] = None, **metadata
self, data: bytes, flow: flow.Flow | None = None, **metadata
) -> float:
return (
2
@ -155,7 +153,7 @@ class ViewHttp3(base.View):
)
def fmt_frames(frames: list[Union[Frame, StreamType]]) -> Iterator[base.TViewLine]:
def fmt_frames(frames: list[Frame | StreamType]) -> Iterator[base.TViewLine]:
for i, frame in enumerate(frames):
if i > 0:
yield [("text", "")]

View File

@ -1,5 +1,4 @@
import imghdr
from typing import Optional
from . import image_parser
from mitmproxy.contentviews import base
@ -36,7 +35,7 @@ class ViewImage(base.View):
return view_name, base.format_dict(multidict.MultiDict(image_metadata))
def render_priority(
self, data: bytes, *, content_type: Optional[str] = None, **metadata
self, data: bytes, *, content_type: str | None = None, **metadata
) -> float:
return float(
bool(

View File

@ -1,6 +1,5 @@
import io
import re
from typing import Optional
from mitmproxy.contentviews import base
from mitmproxy.utils import strutils
@ -55,6 +54,6 @@ class ViewJavaScript(base.View):
return "JavaScript", base.format_text(res)
def render_priority(
self, data: bytes, *, content_type: Optional[str] = None, **metadata
self, data: bytes, *, content_type: str | None = None, **metadata
) -> float:
return float(bool(data) and content_type in self.__content_types)

View File

@ -3,7 +3,6 @@ import re
from collections.abc import Iterator
from functools import lru_cache
from typing import Any
from typing import Optional
from mitmproxy.contentviews import base
@ -55,7 +54,7 @@ class ViewJSON(base.View):
return "JSON", format_json(data)
def render_priority(
self, data: bytes, *, content_type: Optional[str] = None, **metadata
self, data: bytes, *, content_type: str | None = None, **metadata
) -> float:
if not data:
return 0

View File

@ -1,5 +1,4 @@
import struct
from typing import Optional
from mitmproxy.contentviews import base
from mitmproxy.utils import strutils
@ -273,6 +272,6 @@ class ViewMQTT(base.View):
return "MQTT", base.format_text(text)
def render_priority(
self, data: bytes, *, content_type: Optional[str] = None, **metadata
self, data: bytes, *, content_type: str | None = None, **metadata
) -> float:
return 0

View File

@ -1,5 +1,4 @@
from typing import Any
from typing import Optional
import msgpack
@ -95,6 +94,6 @@ class ViewMsgPack(base.View):
return "MsgPack", format_msgpack(data)
def render_priority(
self, data: bytes, *, content_type: Optional[str] = None, **metadata
self, data: bytes, *, content_type: str | None = None, **metadata
) -> float:
return float(bool(data) and content_type in self.__content_types)

View File

@ -1,5 +1,3 @@
from typing import Optional
from . import base
from mitmproxy.coretypes import multidict
from mitmproxy.net.http import multipart
@ -13,7 +11,7 @@ class ViewMultipart(base.View):
yield [("highlight", "Form data:\n")]
yield from base.format_dict(multidict.MultiDict(v))
def __call__(self, data: bytes, content_type: Optional[str] = None, **metadata):
def __call__(self, data: bytes, content_type: str | None = None, **metadata):
if content_type is None:
return
v = multipart.decode_multipart(content_type, data)
@ -21,6 +19,6 @@ class ViewMultipart(base.View):
return "Multipart form", self._format(v)
def render_priority(
self, data: bytes, *, content_type: Optional[str] = None, **metadata
self, data: bytes, *, content_type: str | None = None, **metadata
) -> float:
return float(bool(data) and content_type == "multipart/form-data")

View File

@ -1,5 +1,4 @@
import io
from typing import Optional
from kaitaistruct import KaitaiStream
@ -98,6 +97,6 @@ class ViewProtobuf(base.View):
return "Protobuf", base.format_text(decoded)
def render_priority(
self, data: bytes, *, content_type: Optional[str] = None, **metadata
self, data: bytes, *, content_type: str | None = None, **metadata
) -> float:
return float(bool(data) and content_type in self.__content_types)

View File

@ -1,5 +1,3 @@
from typing import Optional
from . import base
from .. import http
@ -8,7 +6,7 @@ class ViewQuery(base.View):
name = "Query"
def __call__(
self, data: bytes, http_message: Optional[http.Message] = None, **metadata
self, data: bytes, http_message: http.Message | None = None, **metadata
):
query = getattr(http_message, "query", None)
if query:
@ -17,6 +15,6 @@ class ViewQuery(base.View):
return "Query", base.format_text("")
def render_priority(
self, data: bytes, *, http_message: Optional[http.Message] = None, **metadata
self, data: bytes, *, http_message: http.Message | None = None, **metadata
) -> float:
return 0.3 * float(bool(getattr(http_message, "query", False) and not data))

View File

@ -1,5 +1,3 @@
from typing import Optional
from . import base
from mitmproxy.net.http import url
@ -16,6 +14,6 @@ class ViewURLEncoded(base.View):
return "URLEncoded form", base.format_pairs(d)
def render_priority(
self, data: bytes, *, content_type: Optional[str] = None, **metadata
self, data: bytes, *, content_type: str | None = None, **metadata
) -> float:
return float(bool(data) and content_type == "application/x-www-form-urlencoded")

View File

@ -1,5 +1,3 @@
from typing import Optional
from . import base
from mitmproxy.contrib.wbxml import ASCommandResponse
@ -18,6 +16,6 @@ class ViewWBXML(base.View):
return None
def render_priority(
self, data: bytes, *, content_type: Optional[str] = None, **metadata
self, data: bytes, *, content_type: str | None = None, **metadata
) -> float:
return float(bool(data) and content_type in self.__content_types)

View File

@ -2,7 +2,6 @@ import io
import re
import textwrap
from collections.abc import Iterable
from typing import Optional
from mitmproxy.contentviews import base
from mitmproxy.utils import sliding_window
@ -140,7 +139,7 @@ def indent_text(data: str, prefix: str) -> str:
return textwrap.indent(dedented, prefix[:32])
def is_inline_text(a: Optional[Token], b: Optional[Token], c: Optional[Token]) -> bool:
def is_inline_text(a: Token | None, b: Token | None, c: Token | None) -> bool:
if isinstance(a, Tag) and isinstance(b, Text) and isinstance(c, Tag):
if a.is_opening and "\n" not in b.data and c.is_closing and a.tag == c.tag:
return True
@ -148,11 +147,11 @@ def is_inline_text(a: Optional[Token], b: Optional[Token], c: Optional[Token]) -
def is_inline(
prev2: Optional[Token],
prev1: Optional[Token],
t: Optional[Token],
next1: Optional[Token],
next2: Optional[Token],
prev2: Token | None,
prev1: Token | None,
t: Token | None,
next1: Token | None,
next2: Token | None,
) -> bool:
if isinstance(t, Text):
return is_inline_text(prev1, t, next1)
@ -267,7 +266,7 @@ class ViewXmlHtml(base.View):
return t, pretty
def render_priority(
self, data: bytes, *, content_type: Optional[str] = None, **metadata
self, data: bytes, *, content_type: str | None = None, **metadata
) -> float:
if not data:
return 0

View File

@ -1,7 +1,7 @@
# This is a generated file! Please edit source .ksy file and use kaitai-struct-compiler to rebuild
import kaitaistruct
from kaitaistruct import KaitaiStruct, KaitaiStream, BytesIO
from kaitaistruct import KaitaiStream, KaitaiStruct
from enum import Enum

View File

@ -1,7 +1,7 @@
# This is a generated file! Please edit source .ksy file and use kaitai-struct-compiler to rebuild
import kaitaistruct
from kaitaistruct import KaitaiStruct, KaitaiStream, BytesIO
from kaitaistruct import KaitaiStream, KaitaiStruct
from enum import Enum

View File

@ -1,7 +1,7 @@
# This is a generated file! Please edit source .ksy file and use kaitai-struct-compiler to rebuild
import kaitaistruct
from kaitaistruct import KaitaiStruct, KaitaiStream, BytesIO
from kaitaistruct import KaitaiStruct
if getattr(kaitaistruct, 'API_VERSION', (0, 9)) < (0, 9):

View File

@ -1,7 +1,7 @@
# This is a generated file! Please edit source .ksy file and use kaitai-struct-compiler to rebuild
import kaitaistruct
from kaitaistruct import KaitaiStruct, KaitaiStream, BytesIO
from kaitaistruct import KaitaiStruct
if getattr(kaitaistruct, 'API_VERSION', (0, 9)) < (0, 9):

View File

@ -1,6 +1,6 @@
from collections.abc import Callable
from collections.abc import Iterator
from typing import Any
from typing import Callable
from mitmproxy import dns
from mitmproxy import flow

View File

@ -8,7 +8,6 @@ from dataclasses import dataclass
from dataclasses import field
from typing import Any
from typing import ClassVar
from typing import Optional
from mitmproxy import connection
from mitmproxy import exceptions
@ -66,7 +65,7 @@ class Flow(serializable.Serializable):
with a `timestamp_start` set to `None`.
"""
error: Optional[Error] = None
error: Error | None = None
"""A connection or protocol error affecting this flow."""
intercepted: bool
@ -89,7 +88,7 @@ class Flow(serializable.Serializable):
The default marker for the view will be used if the Unicode emoji name can not be interpreted.
"""
is_replay: Optional[str]
is_replay: str | None
"""
This attribute indicates if this flow has been replayed in either direction.
@ -123,10 +122,10 @@ class Flow(serializable.Serializable):
self.timestamp_created = time.time()
self.intercepted: bool = False
self._resume_event: Optional[asyncio.Event] = None
self._backup: Optional[Flow] = None
self._resume_event: asyncio.Event | None = None
self._backup: Flow | None = None
self.marked: str = ""
self.is_replay: Optional[str] = None
self.is_replay: str | None = None
self.metadata: dict[str, Any] = dict()
self.comment: str = ""

View File

@ -38,7 +38,6 @@ import sys
from collections.abc import Sequence
from typing import ClassVar
from typing import Protocol
from typing import Union
import pyparsing as pp
@ -662,7 +661,7 @@ def parse(s: str) -> TFilter:
raise ValueError(f"Invalid filter expression: {s!r}") from e
def match(flt: Union[str, TFilter], flow: flow.Flow) -> bool:
def match(flt: str | TFilter, flow: flow.Flow) -> bool:
"""
Matches a flow against a compiled filter expression.
Returns True if matched, False if not.

View File

@ -5,6 +5,7 @@ import re
import time
import urllib.parse
import warnings
from collections.abc import Callable
from collections.abc import Iterable
from collections.abc import Iterator
from collections.abc import Mapping
@ -15,10 +16,7 @@ from email.utils import formatdate
from email.utils import mktime_tz
from email.utils import parsedate_tz
from typing import Any
from typing import Callable
from typing import cast
from typing import Optional
from typing import Union
from mitmproxy import flow
from mitmproxy.coretypes import multidict
@ -43,7 +41,7 @@ def _native(x: bytes) -> str:
return x.decode("utf-8", "surrogateescape")
def _always_bytes(x: Union[str, bytes]) -> bytes:
def _always_bytes(x: str | bytes) -> bytes:
return strutils.always_bytes(x, "utf-8", "surrogateescape")
@ -136,7 +134,7 @@ class Headers(multidict.MultiDict): # type: ignore
else:
return b""
def __delitem__(self, key: Union[str, bytes]) -> None:
def __delitem__(self, key: str | bytes) -> None:
key = _always_bytes(key)
super().__delitem__(key)
@ -144,7 +142,7 @@ class Headers(multidict.MultiDict): # type: ignore
for x in super().__iter__():
yield _native(x)
def get_all(self, name: Union[str, bytes]) -> list[str]:
def get_all(self, name: str | bytes) -> list[str]:
"""
Like `Headers.get`, but does not fold multiple headers into a single one.
This is useful for Set-Cookie and Cookie headers, which do not support folding.
@ -157,7 +155,7 @@ class Headers(multidict.MultiDict): # type: ignore
name = _always_bytes(name)
return [_native(x) for x in super().get_all(name)]
def set_all(self, name: Union[str, bytes], values: Iterable[Union[str, bytes]]):
def set_all(self, name: str | bytes, values: Iterable[str | bytes]):
"""
Explicitly set multiple headers for the given key.
See `Headers.get_all`.
@ -166,7 +164,7 @@ class Headers(multidict.MultiDict): # type: ignore
values = [_always_bytes(x) for x in values]
return super().set_all(name, values)
def insert(self, index: int, key: Union[str, bytes], value: Union[str, bytes]):
def insert(self, index: int, key: str | bytes, value: str | bytes):
key = _always_bytes(key)
value = _always_bytes(value)
super().insert(index, key, value)
@ -182,10 +180,10 @@ class Headers(multidict.MultiDict): # type: ignore
class MessageData(serializable.Serializable):
http_version: bytes
headers: Headers
content: Optional[bytes]
trailers: Optional[Headers]
content: bytes | None
trailers: Headers | None
timestamp_start: float
timestamp_end: Optional[float]
timestamp_end: float | None
# noinspection PyUnreachableCode
if __debug__:
@ -246,7 +244,7 @@ class Message(serializable.Serializable):
self.data.set_state(state)
data: MessageData
stream: Union[Callable[[bytes], Union[Iterable[bytes], bytes]], bool] = False
stream: Callable[[bytes], Iterable[bytes] | bytes] | bool = False
"""
This attribute controls if the message body should be streamed.
@ -269,7 +267,7 @@ class Message(serializable.Serializable):
return self.data.http_version.decode("utf-8", "surrogateescape")
@http_version.setter
def http_version(self, http_version: Union[str, bytes]) -> None:
def http_version(self, http_version: str | bytes) -> None:
self.data.http_version = strutils.always_bytes(
http_version, "utf-8", "surrogateescape"
)
@ -302,18 +300,18 @@ class Message(serializable.Serializable):
self.data.headers = h
@property
def trailers(self) -> Optional[Headers]:
def trailers(self) -> Headers | None:
"""
The [HTTP trailers](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Trailer).
"""
return self.data.trailers
@trailers.setter
def trailers(self, h: Optional[Headers]) -> None:
def trailers(self, h: Headers | None) -> None:
self.data.trailers = h
@property
def raw_content(self) -> Optional[bytes]:
def raw_content(self) -> bytes | None:
"""
The raw (potentially compressed) HTTP message body.
@ -324,11 +322,11 @@ class Message(serializable.Serializable):
return self.data.content
@raw_content.setter
def raw_content(self, content: Optional[bytes]) -> None:
def raw_content(self, content: bytes | None) -> None:
self.data.content = content
@property
def content(self) -> Optional[bytes]:
def content(self) -> bytes | None:
"""
The uncompressed HTTP message body as bytes.
@ -339,11 +337,11 @@ class Message(serializable.Serializable):
return self.get_content()
@content.setter
def content(self, value: Optional[bytes]) -> None:
def content(self, value: bytes | None) -> None:
self.set_content(value)
@property
def text(self) -> Optional[str]:
def text(self) -> str | None:
"""
The uncompressed and decoded HTTP message body as text.
@ -354,10 +352,10 @@ class Message(serializable.Serializable):
return self.get_text()
@text.setter
def text(self, value: Optional[str]) -> None:
def text(self, value: str | None) -> None:
self.set_text(value)
def set_content(self, value: Optional[bytes]) -> None:
def set_content(self, value: bytes | None) -> None:
if value is None:
self.raw_content = None
return
@ -382,7 +380,7 @@ class Message(serializable.Serializable):
else:
self.headers["content-length"] = str(len(self.raw_content))
def get_content(self, strict: bool = True) -> Optional[bytes]:
def get_content(self, strict: bool = True) -> bytes | None:
"""
Similar to `Message.content`, but does not raise if `strict` is `False`.
Instead, the compressed message body is returned as-is.
@ -404,7 +402,7 @@ class Message(serializable.Serializable):
else:
return self.raw_content
def _get_content_type_charset(self) -> Optional[str]:
def _get_content_type_charset(self) -> str | None:
ct = parse_content_type(self.headers.get("content-type", ""))
if ct:
return ct[2].get("charset")
@ -438,7 +436,7 @@ class Message(serializable.Serializable):
return enc
def set_text(self, text: Optional[str]) -> None:
def set_text(self, text: str | None) -> None:
if text is None:
self.content = None
return
@ -458,7 +456,7 @@ class Message(serializable.Serializable):
enc = "utf8"
self.content = text.encode(enc, "surrogateescape")
def get_text(self, strict: bool = True) -> Optional[str]:
def get_text(self, strict: bool = True) -> str | None:
"""
Similar to `Message.text`, but does not raise if `strict` is `False`.
Instead, the message body is returned as surrogate-escaped UTF-8.
@ -486,14 +484,14 @@ class Message(serializable.Serializable):
self.data.timestamp_start = timestamp_start
@property
def timestamp_end(self) -> Optional[float]:
def timestamp_end(self) -> float | None:
"""
*Timestamp:* Last byte received.
"""
return self.data.timestamp_end
@timestamp_end.setter
def timestamp_end(self, timestamp_end: Optional[float]):
def timestamp_end(self, timestamp_end: float | None):
self.data.timestamp_end = timestamp_end
def decode(self, strict: bool = True) -> None:
@ -558,11 +556,11 @@ class Request(Message):
authority: bytes,
path: bytes,
http_version: bytes,
headers: Union[Headers, tuple[tuple[bytes, bytes], ...]],
content: Optional[bytes],
trailers: Union[Headers, tuple[tuple[bytes, bytes], ...], None],
headers: Headers | tuple[tuple[bytes, bytes], ...],
content: bytes | None,
trailers: Headers | tuple[tuple[bytes, bytes], ...] | None,
timestamp_start: float,
timestamp_end: Optional[float],
timestamp_end: float | None,
):
# auto-convert invalid types to retain compatibility with older code.
if isinstance(host, bytes):
@ -613,12 +611,10 @@ class Request(Message):
cls,
method: str,
url: str,
content: Union[bytes, str] = "",
headers: Union[
Headers,
dict[Union[str, bytes], Union[str, bytes]],
Iterable[tuple[bytes, bytes]],
] = (),
content: bytes | str = "",
headers: (
Headers | dict[str | bytes, str | bytes] | Iterable[tuple[bytes, bytes]]
) = (),
) -> "Request":
"""
Simplified API for creating request objects.
@ -693,7 +689,7 @@ class Request(Message):
return self.data.method.decode("utf-8", "surrogateescape").upper()
@method.setter
def method(self, val: Union[str, bytes]) -> None:
def method(self, val: str | bytes) -> None:
self.data.method = always_bytes(val, "utf-8", "surrogateescape")
@property
@ -704,7 +700,7 @@ class Request(Message):
return self.data.scheme.decode("utf-8", "surrogateescape")
@scheme.setter
def scheme(self, val: Union[str, bytes]) -> None:
def scheme(self, val: str | bytes) -> None:
self.data.scheme = always_bytes(val, "utf-8", "surrogateescape")
@property
@ -726,7 +722,7 @@ class Request(Message):
return self.data.authority.decode("utf8", "surrogateescape")
@authority.setter
def authority(self, val: Union[str, bytes]) -> None:
def authority(self, val: str | bytes) -> None:
if isinstance(val, str):
try:
val = val.encode("idna", "strict")
@ -748,12 +744,12 @@ class Request(Message):
return self.data.host
@host.setter
def host(self, val: Union[str, bytes]) -> None:
def host(self, val: str | bytes) -> None:
self.data.host = always_str(val, "idna", "strict")
self._update_host_and_authority()
@property
def host_header(self) -> Optional[str]:
def host_header(self) -> str | None:
"""
The request's host/authority header.
@ -768,7 +764,7 @@ class Request(Message):
return self.data.headers.get("Host", None)
@host_header.setter
def host_header(self, val: Union[None, str, bytes]) -> None:
def host_header(self, val: None | str | bytes) -> None:
if val is None:
if self.is_http2 or self.is_http3:
self.data.authority = b""
@ -814,7 +810,7 @@ class Request(Message):
return self.data.path.decode("utf-8", "surrogateescape")
@path.setter
def path(self, val: Union[str, bytes]) -> None:
def path(self, val: str | bytes) -> None:
self.data.path = always_bytes(val, "utf-8", "surrogateescape")
@property
@ -829,7 +825,7 @@ class Request(Message):
return url.unparse(self.scheme, self.host, self.port, self.path)
@url.setter
def url(self, val: Union[str, bytes]) -> None:
def url(self, val: str | bytes) -> None:
val = always_str(val, "utf-8", "surrogateescape")
self.scheme, self.host, self.port, self.path = url.parse(val)
@ -1050,11 +1046,11 @@ class Response(Message):
http_version: bytes,
status_code: int,
reason: bytes,
headers: Union[Headers, tuple[tuple[bytes, bytes], ...]],
content: Optional[bytes],
trailers: Union[None, Headers, tuple[tuple[bytes, bytes], ...]],
headers: Headers | tuple[tuple[bytes, bytes], ...],
content: bytes | None,
trailers: None | Headers | tuple[tuple[bytes, bytes], ...],
timestamp_start: float,
timestamp_end: Optional[float],
timestamp_end: float | None,
):
# auto-convert invalid types to retain compatibility with older code.
if isinstance(http_version, str):
@ -1093,10 +1089,10 @@ class Response(Message):
def make(
cls,
status_code: int = 200,
content: Union[bytes, str] = b"",
headers: Union[
Headers, Mapping[str, Union[str, bytes]], Iterable[tuple[bytes, bytes]]
] = (),
content: bytes | str = b"",
headers: (
Headers | Mapping[str, str | bytes] | Iterable[tuple[bytes, bytes]]
) = (),
) -> "Response":
"""
Simplified API for creating response objects.
@ -1165,7 +1161,7 @@ class Response(Message):
return self.data.reason.decode("ISO-8859-1")
@reason.setter
def reason(self, reason: Union[str, bytes]) -> None:
def reason(self, reason: str | bytes) -> None:
self.data.reason = strutils.always_bytes(reason, "ISO-8859-1")
def _get_cookies(self):
@ -1183,9 +1179,7 @@ class Response(Message):
@property
def cookies(
self,
) -> multidict.MultiDictView[
str, tuple[str, multidict.MultiDict[str, Optional[str]]]
]:
) -> multidict.MultiDictView[str, tuple[str, multidict.MultiDict[str, str | None]]]:
"""
The response cookies. A possibly empty `MultiDictView`, where the keys are cookie
name strings, and values are `(cookie value, attributes)` tuples. Within
@ -1245,9 +1239,9 @@ class HTTPFlow(flow.Flow):
request: Request
"""The client's HTTP request."""
response: Optional[Response] = None
response: Response | None = None
"""The server's HTTP response."""
error: Optional[flow.Error] = None
error: flow.Error | None = None
"""
A connection or protocol error affecting this flow.
@ -1256,7 +1250,7 @@ class HTTPFlow(flow.Flow):
from the server, but there was an error sending it back to the client.
"""
websocket: Optional[WebSocketData] = None
websocket: WebSocketData | None = None
"""
If this HTTP flow initiated a WebSocket connection, this attribute contains all associated WebSocket data.
"""

View File

@ -8,7 +8,6 @@ version number, this prevents issues with developer builds and snapshots.
import copy
import uuid
from typing import Any
from typing import Union
from mitmproxy import version
from mitmproxy.utils import strutils
@ -491,9 +490,7 @@ converters = {
}
def migrate_flow(
flow_data: dict[Union[bytes, str], Any]
) -> dict[Union[bytes, str], Any]:
def migrate_flow(flow_data: dict[bytes | str, Any]) -> dict[bytes | str, Any]:
while True:
flow_version = flow_data.get(b"version", flow_data.get("version"))

View File

@ -1,7 +1,6 @@
import asyncio
import logging
import traceback
from typing import Optional
from . import ctx as mitmproxy_ctx
from .proxy.mode_specs import ReverseMode
@ -26,7 +25,7 @@ class Master:
def __init__(
self,
opts: options.Options,
event_loop: Optional[asyncio.AbstractEventLoop] = None,
event_loop: asyncio.AbstractEventLoop | None = None,
):
self.options: options.Options = opts or options.Options()
self.commands = command.CommandManager(self)

View File

@ -7,7 +7,6 @@ import gzip
import zlib
from io import BytesIO
from typing import overload
from typing import Union
import brotli
import zstandard as zstd
@ -31,13 +30,13 @@ def decode(encoded: str, encoding: str, errors: str = "strict") -> str:
@overload
def decode(encoded: bytes, encoding: str, errors: str = "strict") -> Union[str, bytes]:
def decode(encoded: bytes, encoding: str, errors: str = "strict") -> str | bytes:
...
def decode(
encoded: Union[None, str, bytes], encoding: str, errors: str = "strict"
) -> Union[None, str, bytes]:
encoded: None | str | bytes, encoding: str, errors: str = "strict"
) -> None | str | bytes:
"""
Decode the given input object
@ -87,7 +86,7 @@ def encode(decoded: None, encoding: str, errors: str = "strict") -> None:
@overload
def encode(decoded: str, encoding: str, errors: str = "strict") -> Union[str, bytes]:
def encode(decoded: str, encoding: str, errors: str = "strict") -> str | bytes:
...
@ -97,8 +96,8 @@ def encode(decoded: bytes, encoding: str, errors: str = "strict") -> bytes:
def encode(
decoded: Union[None, str, bytes], encoding, errors="strict"
) -> Union[None, str, bytes]:
decoded: None | str | bytes, encoding, errors="strict"
) -> None | str | bytes:
"""
Encode the given input object

View File

@ -1,8 +1,7 @@
import collections
from typing import Optional
def parse_content_type(c: str) -> Optional[tuple[str, str, dict[str, str]]]:
def parse_content_type(c: str) -> tuple[str, str, dict[str, str]] | None:
"""
A simple parser for content-type values. Returns a (type, subtype,
parameters) tuple, where type and subtype are strings, and parameters

View File

@ -1,7 +1,6 @@
import re
import time
from collections.abc import Iterable
from typing import Optional
from mitmproxy.http import Headers
from mitmproxy.http import Request
@ -78,8 +77,8 @@ def validate_headers(headers: Headers) -> None:
def expected_http_body_size(
request: Request, response: Optional[Response] = None
) -> Optional[int]:
request: Request, response: Response | None = None
) -> int | None:
"""
Returns:
The expected body length:
@ -226,7 +225,7 @@ def _read_request_line(
) -> tuple[str, int, bytes, bytes, bytes, bytes, bytes]:
try:
method, target, http_version = line.split()
port: Optional[int]
port: int | None
if target == b"*" or target.startswith(b"/"):
scheme, authority, path = b"", b"", target

View File

@ -3,7 +3,6 @@ from __future__ import annotations
import mimetypes
import re
import warnings
from typing import Optional
from urllib.parse import quote
from mitmproxy.net.http import headers
@ -47,7 +46,7 @@ def encode_multipart(content_type: str, parts: list[tuple[bytes, bytes]]) -> byt
def decode_multipart(
content_type: Optional[str], content: bytes
content_type: str | None, content: bytes
) -> list[tuple[bytes, bytes]]:
"""
Takes a multipart boundary encoded string and returns list of (key, value) tuples.

View File

@ -4,7 +4,6 @@ import re
import urllib.parse
from collections.abc import Sequence
from typing import AnyStr
from typing import Optional
from mitmproxy.net import check
from mitmproxy.net.check import is_valid_host
@ -147,7 +146,7 @@ def hostport(scheme: AnyStr, host: AnyStr, port: int) -> AnyStr:
return "%s:%d" % (host, port)
def default_port(scheme: AnyStr) -> Optional[int]:
def default_port(scheme: AnyStr) -> int | None:
return {
"http": 80,
b"http": 80,
@ -156,7 +155,7 @@ def default_port(scheme: AnyStr) -> Optional[int]:
}.get(scheme, None)
def parse_authority(authority: AnyStr, check: bool) -> tuple[str, Optional[int]]:
def parse_authority(authority: AnyStr, check: bool) -> tuple[str, int | None]:
"""Extract the host and port from host header/authority information
Raises:

View File

@ -1,13 +1,12 @@
import os
import threading
from collections.abc import Callable
from collections.abc import Iterable
from enum import Enum
from functools import lru_cache
from pathlib import Path
from typing import Any
from typing import BinaryIO
from typing import Callable
from typing import Optional
import certifi
from OpenSSL import SSL
@ -55,7 +54,7 @@ DEFAULT_OPTIONS = SSL.OP_CIPHER_SERVER_PREFERENCE | SSL.OP_NO_COMPRESSION
class MasterSecretLogger:
def __init__(self, filename: Path):
self.filename = filename.expanduser()
self.f: Optional[BinaryIO] = None
self.f: BinaryIO | None = None
self.lock = threading.Lock()
# required for functools.wraps, which pyOpenSSL uses.
@ -76,7 +75,7 @@ class MasterSecretLogger:
self.f.close()
def make_master_secret_logger(filename: Optional[str]) -> Optional[MasterSecretLogger]:
def make_master_secret_logger(filename: str | None) -> MasterSecretLogger | None:
if filename:
return MasterSecretLogger(Path(filename))
return None
@ -92,7 +91,7 @@ def _create_ssl_context(
method: Method,
min_version: Version,
max_version: Version,
cipher_list: Optional[Iterable[str]],
cipher_list: Iterable[str] | None,
) -> SSL.Context:
context = SSL.Context(method.value)
@ -127,11 +126,11 @@ def create_proxy_server_context(
method: Method,
min_version: Version,
max_version: Version,
cipher_list: Optional[tuple[str, ...]],
cipher_list: tuple[str, ...] | None,
verify: Verify,
ca_path: Optional[str],
ca_pemfile: Optional[str],
client_cert: Optional[str],
ca_path: str | None,
ca_pemfile: str | None,
client_cert: str | None,
) -> SSL.Context:
context: SSL.Context = _create_ssl_context(
method=method,
@ -167,9 +166,9 @@ def create_client_proxy_context(
method: Method,
min_version: Version,
max_version: Version,
cipher_list: Optional[tuple[str, ...]],
chain_file: Optional[Path],
alpn_select_callback: Optional[Callable[[SSL.Connection, list[bytes]], Any]],
cipher_list: tuple[str, ...] | None,
chain_file: Path | None,
alpn_select_callback: Callable[[SSL.Connection, list[bytes]], Any] | None,
request_client_cert: bool,
extra_chain_certs: tuple[certs.Cert, ...],
dhparams: certs.DHParams,

View File

@ -3,10 +3,9 @@ from __future__ import annotations
import asyncio
import logging
import socket
from collections.abc import Callable
from typing import Any
from typing import Callable
from typing import cast
from typing import Optional
from typing import Union
import mitmproxy_rs
@ -267,7 +266,7 @@ async def start_server(
async def open_connection(
host: str, port: int, *, local_addr: Optional[Address] = None
host: str, port: int, *, local_addr: Address | None = None
) -> tuple[DatagramReader, DatagramWriter]:
"""UDP variant of asyncio.open_connection."""

View File

@ -13,7 +13,6 @@ from dataclasses import dataclass
from typing import Any
from typing import Optional
from typing import TextIO
from typing import Union
import ruamel.yaml
@ -34,10 +33,10 @@ class _Option:
def __init__(
self,
name: str,
typespec: Union[type, object], # object for Optional[x], which is not a type.
typespec: type | object, # object for Optional[x], which is not a type.
default: Any,
help: str,
choices: Optional[Sequence[str]],
choices: Sequence[str] | None,
) -> None:
typecheck.check_option_type(name, default, typespec)
self.name = name
@ -123,10 +122,10 @@ class OptManager:
def add_option(
self,
name: str,
typespec: Union[type, object],
typespec: type | object,
default: Any,
help: str,
choices: Optional[Sequence[str]] = None,
choices: Sequence[str] | None = None,
) -> None:
self._options[name] = _Option(name, typespec, default, help, choices)
self.changed.send(updated={name})
@ -373,7 +372,7 @@ class OptManager:
f"Received multiple values for {o.name}: {values}"
)
optstr: Optional[str]
optstr: str | None
if values:
optstr = values[0]
else:

View File

@ -1,8 +1,7 @@
import re
import socket
import sys
from typing import Callable
from typing import Optional
from collections.abc import Callable
def init_transparent_mode() -> None:
@ -11,7 +10,7 @@ def init_transparent_mode() -> None:
"""
original_addr: Optional[Callable[[socket.socket], tuple[str, int]]]
original_addr: Callable[[socket.socket], tuple[str, int]] | None
"""
Get the original destination for the given socket.
This function will be None if transparent mode is not supported.

View File

@ -15,7 +15,6 @@ from typing import Any
from typing import cast
from typing import ClassVar
from typing import IO
from typing import Optional
import pydivert.consts
@ -294,7 +293,7 @@ class Redirect(threading.Thread):
def shutdown(self):
self.windivert.close()
def recv(self) -> Optional[pydivert.Packet]:
def recv(self) -> pydivert.Packet | None:
"""
Convenience function that receives a packet from the passed handler and handles error codes.
If the process has been shut down, None is returned.
@ -402,9 +401,9 @@ class TransparentProxy:
which mitmproxy sees, but this would remove the correct client info from mitmproxy.
"""
local: Optional[RedirectLocal] = None
local: RedirectLocal | None = None
# really weird linting error here.
forward: Optional[Redirect] = None
forward: Redirect | None = None
response: Redirect
icmp: Redirect
@ -418,7 +417,7 @@ class TransparentProxy:
local: bool = True,
forward: bool = True,
proxy_port: int = 8080,
filter: Optional[str] = "tcp.DstPort == 80 or tcp.DstPort == 443",
filter: str | None = "tcp.DstPort == 80 or tcp.DstPort == 443",
) -> None:
self.proxy_port = proxy_port
self.filter = (

View File

@ -9,7 +9,6 @@ from dataclasses import dataclass
from dataclasses import is_dataclass
from typing import Any
from typing import Generic
from typing import Optional
from typing import TypeVar
from mitmproxy import flow
@ -106,7 +105,7 @@ command_reply_subclasses: dict[commands.Command, type[CommandCompleted]] = {}
@dataclass(repr=False)
class OpenConnectionCompleted(CommandCompleted):
command: commands.OpenConnection
reply: Optional[str]
reply: str | None
"""error message"""

View File

@ -11,7 +11,6 @@ from logging import DEBUG
from typing import Any
from typing import ClassVar
from typing import NamedTuple
from typing import Optional
from typing import TypeVar
from mitmproxy.connection import Connection
@ -60,7 +59,7 @@ class Layer:
__last_debug_message: ClassVar[str] = ""
context: Context
_paused: Optional[Paused]
_paused: Paused | None
"""
If execution is currently paused, this attribute stores the paused coroutine
and the command for which we are expecting a reply.
@ -70,7 +69,7 @@ class Layer:
All events that have occurred since execution was paused.
These will be replayed to ._child_layer once we resume.
"""
debug: Optional[str] = None
debug: str | None = None
"""
Enable debug logging by assigning a prefix string for log messages.
Different amounts of whitespace for different layers work well.
@ -242,7 +241,7 @@ mevents = (
class NextLayer(Layer):
layer: Optional[Layer]
layer: Layer | None
"""The next layer. To be set by an addon."""
events: list[mevents.Event]

View File

@ -5,8 +5,6 @@ from dataclasses import dataclass
from functools import cached_property
from logging import DEBUG
from logging import WARNING
from typing import Optional
from typing import Union
import wsproto.handshake
@ -71,7 +69,7 @@ class HTTPMode(enum.Enum):
upstream = 3
def validate_request(mode: HTTPMode, request: http.Request) -> Optional[str]:
def validate_request(mode: HTTPMode, request: http.Request) -> str | None:
if request.scheme not in ("http", "https", ""):
return f"Invalid request scheme: {request.scheme}"
if mode is HTTPMode.transparent and request.method == "CONNECT":
@ -82,7 +80,7 @@ def validate_request(mode: HTTPMode, request: http.Request) -> Optional[str]:
return None
def is_h3_alpn(alpn: Optional[bytes]) -> bool:
def is_h3_alpn(alpn: bytes | None) -> bool:
return alpn == b"h3" or (alpn is not None and alpn.startswith(b"h3-"))
@ -95,7 +93,7 @@ class GetHttpConnection(HttpCommand):
blocking = True
address: tuple[str, int]
tls: bool
via: Optional[server_spec.ServerSpec]
via: server_spec.ServerSpec | None
transport_protocol: TransportProtocol = "tcp"
def __hash__(self):
@ -114,7 +112,7 @@ class GetHttpConnection(HttpCommand):
@dataclass
class GetHttpConnectionCompleted(events.CommandCompleted):
command: GetHttpConnection
reply: Union[tuple[None, str], tuple[Connection, None]]
reply: tuple[None, str] | tuple[Connection, None]
"""connection object, error message"""
@ -125,7 +123,7 @@ class RegisterHttpConnection(HttpCommand):
"""
connection: Connection
err: Optional[str]
err: str | None
@dataclass
@ -149,7 +147,7 @@ class HttpStream(layer.Layer):
response_body_buf: bytes
flow: http.HTTPFlow
stream_id: StreamId
child_layer: Optional[layer.Layer] = None
child_layer: layer.Layer | None = None
@cached_property
def mode(self) -> HTTPMode:
@ -301,7 +299,7 @@ class HttpStream(layer.Layer):
@expect(RequestData, RequestTrailers, RequestEndOfMessage)
def state_stream_request_body(
self, event: Union[RequestData, RequestEndOfMessage]
self, event: RequestData | RequestEndOfMessage
) -> layer.CommandGenerator[None]:
if isinstance(event, RequestData):
if callable(self.flow.request.stream):
@ -554,7 +552,7 @@ class HttpStream(layer.Layer):
# Step 1: Determine the expected body size. This can either come from a known content-length header,
# or from the amount of currently buffered bytes (e.g. for chunked encoding).
response = not request
expected_size: Optional[int]
expected_size: int | None
# the 'late' case: we already started consuming the body
if request and self.request_body_buf:
expected_size = len(self.request_body_buf)
@ -652,7 +650,7 @@ class HttpStream(layer.Layer):
return False
def handle_protocol_error(
self, event: Union[RequestProtocolError, ResponseProtocolError]
self, event: RequestProtocolError | ResponseProtocolError
) -> layer.CommandGenerator[None]:
is_client_error_but_we_already_talk_upstream = (
isinstance(event, RequestProtocolError)
@ -930,7 +928,7 @@ class HttpLayer(layer.Layer):
def event_to_child(
self,
child: Union[layer.Layer, HttpStream],
child: layer.Layer | HttpStream,
event: events.Event,
) -> layer.CommandGenerator[None]:
for command in child.handle_event(event):
@ -1068,7 +1066,7 @@ class HttpLayer(layer.Layer):
) -> layer.CommandGenerator[None]:
waiting = self.waiting_for_establishment.pop(command.connection)
reply: Union[tuple[None, str], tuple[Connection, None]]
reply: tuple[None, str] | tuple[Connection, None]
if command.err:
reply = (None, command.err)
else:
@ -1098,7 +1096,7 @@ class HttpClient(layer.Layer):
@expect(events.Start)
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
err: Optional[str]
err: str | None
if self.context.server.connected:
err = None
else:

View File

@ -1,5 +1,4 @@
from dataclasses import dataclass
from typing import Optional
from ._base import HttpEvent
from mitmproxy import http
@ -15,7 +14,7 @@ class RequestHeaders(HttpEvent):
us to set END_STREAM on headers already (and some servers - Akamai - implicitly expect that).
In either case, this event will nonetheless be followed by RequestEndOfMessage.
"""
replay_flow: Optional[HTTPFlow] = None
replay_flow: HTTPFlow | None = None
"""If set, the current request headers belong to a replayed flow, which should be reused."""

View File

@ -1,6 +1,5 @@
import abc
from typing import Callable
from typing import Optional
from collections.abc import Callable
from typing import Union
import h11
@ -39,19 +38,19 @@ TBodyReader = Union[ChunkedReader, Http10Reader, ContentLengthReader]
class Http1Connection(HttpConnection, metaclass=abc.ABCMeta):
stream_id: Optional[StreamId] = None
request: Optional[http.Request] = None
response: Optional[http.Response] = None
stream_id: StreamId | None = None
request: http.Request | None = None
response: http.Response | None = None
request_done: bool = False
response_done: bool = False
# this is a bit of a hack to make both mypy and PyCharm happy.
state: Union[Callable[[events.Event], layer.CommandGenerator[None]], Callable]
state: Callable[[events.Event], layer.CommandGenerator[None]] | Callable
body_reader: TBodyReader
buf: ReceiveBuffer
ReceiveProtocolError: type[Union[RequestProtocolError, ResponseProtocolError]]
ReceiveData: type[Union[RequestData, ResponseData]]
ReceiveEndOfMessage: type[Union[RequestEndOfMessage, ResponseEndOfMessage]]
ReceiveProtocolError: type[RequestProtocolError | ResponseProtocolError]
ReceiveData: type[RequestData | ResponseData]
ReceiveEndOfMessage: type[RequestEndOfMessage | ResponseEndOfMessage]
def __init__(self, context: Context, conn: Connection):
super().__init__(context, conn)
@ -464,7 +463,7 @@ def should_make_pipe(request: http.Request, response: http.Response) -> bool:
return False
def make_body_reader(expected_size: Optional[int]) -> TBodyReader:
def make_body_reader(expected_size: int | None) -> TBodyReader:
if expected_size is None:
return ChunkedReader()
elif expected_size == -1:

View File

@ -5,8 +5,6 @@ from enum import Enum
from logging import DEBUG
from logging import ERROR
from typing import ClassVar
from typing import Optional
from typing import Union
import h2.config
import h2.connection
@ -74,10 +72,10 @@ class Http2Connection(HttpConnection):
streams: dict[int, StreamState]
"""keep track of all active stream ids to send protocol errors on teardown"""
ReceiveProtocolError: type[Union[RequestProtocolError, ResponseProtocolError]]
ReceiveData: type[Union[RequestData, ResponseData]]
ReceiveTrailers: type[Union[RequestTrailers, ResponseTrailers]]
ReceiveEndOfMessage: type[Union[RequestEndOfMessage, ResponseEndOfMessage]]
ReceiveProtocolError: type[RequestProtocolError | ResponseProtocolError]
ReceiveData: type[RequestData | ResponseData]
ReceiveTrailers: type[RequestTrailers | ResponseTrailers]
ReceiveEndOfMessage: type[RequestEndOfMessage | ResponseEndOfMessage]
def __init__(self, context: Context, conn: Connection):
super().__init__(context, conn)
@ -454,7 +452,7 @@ class Http2Client(Http2Connection):
their_stream_id: dict[int, int]
stream_queue: collections.defaultdict[int, list[Event]]
"""Queue of streams that we haven't sent yet because we have reached MAX_CONCURRENT_STREAMS"""
provisional_max_concurrency: Optional[int] = 10
provisional_max_concurrency: int | None = 10
"""A provisional currency limit before we get the server's first settings frame."""
last_activity: float
"""Timestamp of when we've last seen network activity on this connection."""
@ -583,7 +581,7 @@ class Http2Client(Http2Connection):
# - 102 Processing is WebDAV only and also ignorable.
# - 103 Early Hints is not mission-critical.
headers = http.Headers(event.headers)
status: Union[str, int] = "<unknown status>"
status: str | int = "<unknown status>"
try:
status = int(headers[":status"])
reason = status_codes.RESPONSES.get(status, "")

View File

@ -1,6 +1,5 @@
import time
from abc import abstractmethod
from typing import Union
from aioquic.h3.connection import ErrorCode as H3ErrorCode
from aioquic.h3.connection import FrameUnexpected as H3FrameUnexpected
@ -47,10 +46,10 @@ from mitmproxy.proxy.utils import expect
class Http3Connection(HttpConnection):
h3_conn: LayeredH3Connection
ReceiveData: type[Union[RequestData, ResponseData]]
ReceiveEndOfMessage: type[Union[RequestEndOfMessage, ResponseEndOfMessage]]
ReceiveProtocolError: type[Union[RequestProtocolError, ResponseProtocolError]]
ReceiveTrailers: type[Union[RequestTrailers, ResponseTrailers]]
ReceiveData: type[RequestData | ResponseData]
ReceiveEndOfMessage: type[RequestEndOfMessage | ResponseEndOfMessage]
ReceiveProtocolError: type[RequestProtocolError | ResponseProtocolError]
ReceiveTrailers: type[RequestTrailers | ResponseTrailers]
def __init__(self, context: context.Context, conn: connection.Connection):
super().__init__(context, conn)
@ -205,9 +204,7 @@ class Http3Connection(HttpConnection):
yield from ()
@abstractmethod
def parse_headers(
self, event: HeadersReceived
) -> Union[RequestHeaders, ResponseHeaders]:
def parse_headers(self, event: HeadersReceived) -> RequestHeaders | ResponseHeaders:
pass # pragma: no cover
@ -220,9 +217,7 @@ class Http3Server(Http3Connection):
def __init__(self, context: context.Context):
super().__init__(context, context.client)
def parse_headers(
self, event: HeadersReceived
) -> Union[RequestHeaders, ResponseHeaders]:
def parse_headers(self, event: HeadersReceived) -> RequestHeaders | ResponseHeaders:
# same as HTTP/2
(
host,
@ -281,9 +276,7 @@ class Http3Client(Http3Connection):
cmd.event.stream_id = self.their_stream_id[cmd.event.stream_id]
yield cmd
def parse_headers(
self, event: HeadersReceived
) -> Union[RequestHeaders, ResponseHeaders]:
def parse_headers(self, event: HeadersReceived) -> RequestHeaders | ResponseHeaders:
# same as HTTP/2
status_code, headers = parse_h2_response_headers(event.headers)
response = http.Response(

View File

@ -1,6 +1,5 @@
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Optional
from aioquic.h3.connection import FrameUnexpected
from aioquic.h3.connection import H3Connection
@ -41,7 +40,7 @@ class TrailersReceived(H3Event):
stream_ended: bool
"Whether the STREAM frame had the FIN bit set."
push_id: Optional[int] = None
push_id: int | None = None
"The Push ID or `None` if this is not a push."
@ -57,7 +56,7 @@ class StreamReset(H3Event):
error_code: int
"""The error code indicating why the stream was reset."""
push_id: Optional[int] = None
push_id: int | None = None
"The Push ID or `None` if this is not a push."
@ -81,7 +80,7 @@ class MockQuic:
def close(
self,
error_code: int = QuicErrorCode.NO_ERROR,
frame_type: Optional[int] = None,
frame_type: int | None = None,
reason_phrase: str = "",
) -> None:
# we'll get closed if a protocol error occurs in `H3Connection.handle_event`
@ -134,7 +133,7 @@ class LayeredH3Connection(H3Connection):
def _handle_request_or_push_frame(
self,
frame_type: int,
frame_data: Optional[bytes],
frame_data: bytes | None,
stream: H3Stream,
stream_ended: bool,
) -> list[H3Event]:
@ -156,7 +155,7 @@ class LayeredH3Connection(H3Connection):
def close_connection(
self,
error_code: int = QuicErrorCode.NO_ERROR,
frame_type: Optional[int] = None,
frame_type: int | None = None,
reason_phrase: str = "",
) -> None:
"""Closes the underlying QUIC connection and ignores any incoming events."""
@ -177,7 +176,7 @@ class LayeredH3Connection(H3Connection):
return self._quic.get_next_available_stream_id(is_unidirectional)
def get_open_stream_ids(self, push_id: Optional[int]) -> Iterable[int]:
def get_open_stream_ids(self, push_id: int | None) -> Iterable[int]:
"""Iterates over all non-special open streams, optionally for a given push id."""
return (

View File

@ -1,6 +1,5 @@
import time
from logging import DEBUG
from typing import Optional
from h11._receivebuffer import ReceiveBuffer
@ -74,7 +73,7 @@ class HttpUpstreamProxy(tunnel.TunnelLayer):
def receive_handshake_data(
self, data: bytes
) -> layer.CommandGenerator[tuple[bool, Optional[str]]]:
) -> layer.CommandGenerator[tuple[bool, str | None]]:
if not self.send_connect:
return (yield from super().receive_handshake_data(data))
self.buf += data

View File

@ -3,9 +3,8 @@ from __future__ import annotations
import socket
import struct
from abc import ABCMeta
from collections.abc import Callable
from dataclasses import dataclass
from typing import Callable
from typing import Optional
from mitmproxy import connection
from mitmproxy.proxy import commands
@ -39,7 +38,7 @@ class DestinationKnown(layer.Layer, metaclass=ABCMeta):
child_layer: layer.Layer
def finish_start(self) -> layer.CommandGenerator[Optional[str]]:
def finish_start(self) -> layer.CommandGenerator[str | None]:
if (
self.context.options.connection_strategy == "eager"
and self.context.server.address
@ -134,7 +133,7 @@ class Socks5Proxy(DestinationKnown):
def socks_err(
self,
message: str,
reply_code: Optional[int] = None,
reply_code: int | None = None,
) -> layer.CommandGenerator[None]:
if reply_code is not None:
yield commands.SendData(

View File

@ -1,13 +1,13 @@
from __future__ import annotations
import time
from collections.abc import Callable
from dataclasses import dataclass
from dataclasses import field
from logging import DEBUG
from logging import ERROR
from logging import WARNING
from ssl import VerifyMode
from typing import Callable
from aioquic.buffer import Buffer as QuicBuffer
from aioquic.h3.connection import ErrorCode as H3ErrorCode

View File

@ -1,5 +1,4 @@
from dataclasses import dataclass
from typing import Optional
from mitmproxy import flow
from mitmproxy import tcp
@ -64,7 +63,7 @@ class TCPLayer(layer.Layer):
Simple TCP layer that just relays messages right now.
"""
flow: Optional[tcp.TCPFlow]
flow: tcp.TCPFlow | None
def __init__(self, context: Context, ignore: bool = False):
super().__init__(context)

View File

@ -6,7 +6,6 @@ from logging import DEBUG
from logging import ERROR
from logging import INFO
from logging import WARNING
from typing import Optional
from OpenSSL import SSL
@ -63,7 +62,7 @@ def handshake_record_contents(data: bytes) -> Iterator[bytes]:
offset += record_size
def get_client_hello(data: bytes) -> Optional[bytes]:
def get_client_hello(data: bytes) -> bytes | None:
"""
Read all TLS records that contain the initial ClientHello.
Returns the raw handshake packet bytes, without TLS record headers.
@ -78,7 +77,7 @@ def get_client_hello(data: bytes) -> Optional[bytes]:
return None
def parse_client_hello(data: bytes) -> Optional[ClientHello]:
def parse_client_hello(data: bytes) -> ClientHello | None:
"""
Check if the supplied bytes contain a full ClientHello message,
and if so, parse it.
@ -136,7 +135,7 @@ def dtls_handshake_record_contents(data: bytes) -> Iterator[bytes]:
offset += record_size
def get_dtls_client_hello(data: bytes) -> Optional[bytes]:
def get_dtls_client_hello(data: bytes) -> bytes | None:
"""
Read all DTLS records that contain the initial ClientHello.
Returns the raw handshake packet bytes, without TLS record headers.
@ -154,7 +153,7 @@ def get_dtls_client_hello(data: bytes) -> Optional[bytes]:
return None
def dtls_parse_client_hello(data: bytes) -> Optional[ClientHello]:
def dtls_parse_client_hello(data: bytes) -> ClientHello | None:
"""
Check if the supplied bytes contain a full ClientHello message,
and if so, parse it.
@ -309,7 +308,7 @@ class TLSLayer(tunnel.TunnelLayer):
def receive_handshake_data(
self, data: bytes
) -> layer.CommandGenerator[tuple[bool, Optional[str]]]:
) -> layer.CommandGenerator[tuple[bool, str | None]]:
# bio_write errors for b"", so we need to check first if we actually received something.
if data:
self.tls.bio_write(data)
@ -466,9 +465,7 @@ class ServerTLSLayer(TLSLayer):
wait_for_clienthello: bool = False
def __init__(
self, context: context.Context, conn: Optional[connection.Server] = None
):
def __init__(self, context: context.Context, conn: connection.Server | None = None):
super().__init__(context, conn or context.server)
def start_handshake(self) -> layer.CommandGenerator[None]:
@ -558,7 +555,7 @@ class ClientTLSLayer(TLSLayer):
def receive_handshake_data(
self, data: bytes
) -> layer.CommandGenerator[tuple[bool, Optional[str]]]:
) -> layer.CommandGenerator[tuple[bool, str | None]]:
if self.client_hello_parsed:
return (yield from super().receive_handshake_data(data))
self.recv_buffer.extend(data)
@ -621,7 +618,7 @@ class ClientTLSLayer(TLSLayer):
self.recv_buffer.clear()
return ret
def start_server_tls(self) -> layer.CommandGenerator[Optional[str]]:
def start_server_tls(self) -> layer.CommandGenerator[str | None]:
"""
We often need information from the upstream connection to establish TLS with the client.
For example, we need to check if the client does ALPN or not.

View File

@ -1,5 +1,4 @@
from dataclasses import dataclass
from typing import Optional
from mitmproxy import flow
from mitmproxy import udp
@ -63,7 +62,7 @@ class UDPLayer(layer.Layer):
Simple UDP layer that just relays messages right now.
"""
flow: Optional[udp.UDPFlow]
flow: udp.UDPFlow | None
def __init__(self, context: Context, ignore: bool = False):
super().__init__(context)

View File

@ -17,8 +17,6 @@ from collections.abc import Callable
from collections.abc import MutableMapping
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Optional
from typing import Union
import mitmproxy_rs
from OpenSSL import SSL
@ -91,13 +89,13 @@ class TimeoutWatchdog:
@dataclass
class ConnectionIO:
handler: Optional[asyncio.Task] = None
reader: Optional[
Union[asyncio.StreamReader, udp.DatagramReader, mitmproxy_rs.TcpStream]
] = None
writer: Optional[
Union[asyncio.StreamWriter, udp.DatagramWriter, mitmproxy_rs.TcpStream]
] = None
handler: asyncio.Task | None = None
reader: None | (
asyncio.StreamReader | udp.DatagramReader | mitmproxy_rs.TcpStream
) = None
writer: None | (
asyncio.StreamWriter | udp.DatagramWriter | mitmproxy_rs.TcpStream
) = None
class ConnectionHandler(metaclass=abc.ABCMeta):
@ -201,8 +199,8 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
return
async with self.max_conns[command.connection.address]:
reader: Union[asyncio.StreamReader, udp.DatagramReader]
writer: Union[asyncio.StreamWriter, udp.DatagramWriter]
reader: asyncio.StreamReader | udp.DatagramReader
writer: asyncio.StreamWriter | udp.DatagramWriter
try:
command.connection.timestamp_start = time.time()
if command.connection.transport_protocol == "tcp":
@ -434,8 +432,8 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
class LiveConnectionHandler(ConnectionHandler, metaclass=abc.ABCMeta):
def __init__(
self,
reader: Union[asyncio.StreamReader, mitmproxy_rs.TcpStream],
writer: Union[asyncio.StreamWriter, mitmproxy_rs.TcpStream],
reader: asyncio.StreamReader | mitmproxy_rs.TcpStream,
writer: asyncio.StreamWriter | mitmproxy_rs.TcpStream,
options: moptions.Options,
mode: mode_specs.ProxyMode,
) -> None:

View File

@ -1,7 +1,6 @@
import time
from enum import auto
from enum import Enum
from typing import Optional
from typing import Union
from mitmproxy import connection
@ -31,7 +30,7 @@ class TunnelLayer(layer.Layer):
conn: connection.Connection
"""The 'inner' connection which provides data I/O"""
tunnel_state: TunnelState = TunnelState.INACTIVE
command_to_reply_to: Optional[commands.OpenConnection] = None
command_to_reply_to: commands.OpenConnection | None = None
_event_queue: list[events.Event]
"""
If the connection already exists when we receive the start event,
@ -98,7 +97,7 @@ class TunnelLayer(layer.Layer):
else:
yield from self.event_to_child(event)
def _handshake_finished(self, err: Optional[str]):
def _handshake_finished(self, err: str | None):
if err:
self.tunnel_state = TunnelState.CLOSED
else:
@ -159,7 +158,7 @@ class TunnelLayer(layer.Layer):
def receive_handshake_data(
self, data: bytes
) -> layer.CommandGenerator[tuple[bool, Optional[str]]]:
) -> layer.CommandGenerator[tuple[bool, str | None]]:
"""returns a (done, err) tuple"""
yield from ()
return True, None

View File

@ -1,6 +1,4 @@
import uuid
from typing import Optional
from typing import Union
from wsproto.frame_protocol import Opcode
@ -123,11 +121,11 @@ def twebsocketflow(
def tdnsflow(
*,
client_conn: Optional[connection.Client] = None,
server_conn: Optional[connection.Server] = None,
req: Optional[dns.Message] = None,
resp: Union[bool, dns.Message] = False,
err: Union[bool, flow.Error] = False,
client_conn: connection.Client | None = None,
server_conn: connection.Server | None = None,
req: dns.Message | None = None,
resp: bool | dns.Message = False,
err: bool | flow.Error = False,
live: bool = True,
) -> dns.DNSFlow:
"""Create a DNS flow for testing."""
@ -160,12 +158,12 @@ def tdnsflow(
def tflow(
*,
client_conn: Optional[connection.Client] = None,
server_conn: Optional[connection.Server] = None,
req: Optional[http.Request] = None,
resp: Union[bool, http.Response] = False,
err: Union[bool, flow.Error] = False,
ws: Union[bool, websocket.WebSocketData] = False,
client_conn: connection.Client | None = None,
server_conn: connection.Server | None = None,
req: http.Request | None = None,
resp: bool | http.Response = False,
err: bool | flow.Error = False,
ws: bool | websocket.WebSocketData = False,
live: bool = True,
) -> http.HTTPFlow:
"""Create a flow for testing."""

View File

@ -1,6 +1,5 @@
import io
from dataclasses import dataclass
from typing import Optional
from kaitaistruct import KaitaiStream
from OpenSSL import SSL
@ -68,7 +67,7 @@ class ClientHello:
return self._client_hello.cipher_suites.cipher_suites
@property
def sni(self) -> Optional[str]:
def sni(self) -> str | None:
"""
The [Server Name Indication](https://en.wikipedia.org/wiki/Server_Name_Indication),
which indicates which hostname the client wants to connect to.
@ -142,7 +141,7 @@ class TlsData:
"""The affected connection."""
context: context.Context
"""The context object for this connection."""
ssl_conn: Optional[SSL.Connection] = None
ssl_conn: SSL.Connection | None = None
"""
The associated pyOpenSSL `SSL.Connection` object.
This will be set by an addon in the `tls_start_*` event hooks.

View File

@ -1,7 +1,6 @@
import abc
from collections.abc import Sequence
from typing import NamedTuple
from typing import Optional
import urwid
from urwid.text_layout import calc_coords
@ -54,7 +53,7 @@ class CommandBuffer:
self.text = start
# Cursor is always within the range [0:len(buffer)].
self._cursor = len(self.text)
self.completion: Optional[CompletionState] = None
self.completion: CompletionState | None = None
@property
def cursor(self) -> int:

View File

@ -3,8 +3,6 @@ import math
import platform
from collections.abc import Iterable
from functools import lru_cache
from typing import Optional
from typing import Union
import urwid.util
from publicsuffix2 import get_sld
@ -48,7 +46,7 @@ KEY_MAX = 30
def format_keyvals(
entries: Iterable[tuple[str, Union[None, str, urwid.Widget]]],
entries: Iterable[tuple[str, None | str | urwid.Widget]],
key_format: str = "key",
value_format: str = "text",
indent: int = 0,
@ -360,7 +358,7 @@ def format_size(num_bytes: int) -> tuple[str, str]:
def format_left_indicators(*, focused: bool, intercepted: bool, timestamp: float):
indicators: list[Union[str, tuple[str, str]]] = []
indicators: list[str | tuple[str, str]] = []
if focused:
indicators.append(("focus", ">>"))
else:
@ -378,7 +376,7 @@ def format_right_indicators(
replay: bool,
marked: str,
):
indicators: list[Union[str, tuple[str, str]]] = []
indicators: list[str | tuple[str, str]] = []
if replay:
indicators.append(("replay", SYMBOL_REPLAY))
else:
@ -406,12 +404,12 @@ def format_http_flow_list(
request_timestamp: float,
request_is_push_promise: bool,
intercepted: bool,
response_code: Optional[int],
response_reason: Optional[str],
response_content_length: Optional[int],
response_content_type: Optional[str],
duration: Optional[float],
error_message: Optional[str],
response_code: int | None,
response_reason: str | None,
response_content_length: int | None,
response_content_type: str | None,
duration: float | None,
error_message: str | None,
) -> urwid.Widget:
req = []
@ -494,7 +492,7 @@ def format_http_flow_table(
render_mode: RenderMode,
focused: bool,
marked: str,
is_replay: Optional[str],
is_replay: str | None,
request_method: str,
request_scheme: str,
request_host: str,
@ -504,12 +502,12 @@ def format_http_flow_table(
request_timestamp: float,
request_is_push_promise: bool,
intercepted: bool,
response_code: Optional[int],
response_reason: Optional[str],
response_content_length: Optional[int],
response_content_type: Optional[str],
duration: Optional[float],
error_message: Optional[str],
response_code: int | None,
response_reason: str | None,
response_content_length: int | None,
response_content_type: str | None,
duration: float | None,
error_message: str | None,
) -> urwid.Widget:
items = [
format_left_indicators(
@ -617,8 +615,8 @@ def format_message_flow(
client_address,
server_address,
total_size: int,
duration: Optional[float],
error_message: Optional[str],
duration: float | None,
error_message: str | None,
):
conn = f"{human.format_address(client_address)} <-> {human.format_address(server_address)}"
@ -669,16 +667,16 @@ def format_dns_flow(
focused: bool,
intercepted: bool,
marked: str,
is_replay: Optional[str],
is_replay: str | None,
op_code: str,
request_timestamp: float,
domain: str,
type: str,
response_code: Optional[str],
response_code: str | None,
response_code_http_equiv: int,
answer: Optional[str],
answer: str | None,
error_message: str,
duration: Optional[float],
duration: float | None,
):
items = []
@ -749,8 +747,8 @@ def format_flow(
relevant for display and call the render with only that. This assures that rows
are updated if the flow is changed.
"""
duration: Optional[float]
error_message: Optional[str]
duration: float | None
error_message: str | None
if f.error:
error_message = f.error.msg
else:
@ -783,7 +781,7 @@ def format_flow(
elif isinstance(f, DNSFlow):
if f.response:
duration = f.response.timestamp - f.request.timestamp
response_code_str: Optional[str] = dns.response_codes.to_str(
response_code_str: str | None = dns.response_codes.to_str(
f.response.response_code
)
response_code_http_equiv = dns.response_codes.http_equiv_status_code(
@ -815,14 +813,14 @@ def format_flow(
)
elif isinstance(f, HTTPFlow):
intercepted = f.intercepted
response_content_length: Optional[int]
response_content_length: int | None
if f.response:
if f.response.raw_content is not None:
response_content_length = len(f.response.raw_content)
else:
response_content_length = None
response_code: Optional[int] = f.response.status_code
response_reason: Optional[str] = f.response.reason
response_code: int | None = f.response.status_code
response_reason: str | None = f.response.reason
response_content_type = f.response.headers.get("content-type")
if f.response.timestamp_end:
duration = max(

View File

@ -1,5 +1,3 @@
from typing import Optional
import urwid
import mitmproxy.flow
@ -27,8 +25,8 @@ def flowdetails(state, flow: mitmproxy.flow.Flow):
sc = flow.server_conn
cc = flow.client_conn
req: Optional[http.Request]
resp: Optional[http.Response]
req: http.Request | None
resp: http.Response | None
if isinstance(flow, http.HTTPFlow):
req = flow.request
resp = flow.response

View File

@ -1,5 +1,4 @@
from functools import lru_cache
from typing import Optional
import urwid
@ -70,7 +69,7 @@ class FlowListWalker(urwid.ListWalker):
self.master.view.focus.index = index
@lru_cache(maxsize=None)
def _get(self, pos: int) -> tuple[Optional[FlowItem], Optional[int]]:
def _get(self, pos: int) -> tuple[FlowItem | None, int | None]:
if not self.master.view.inbounds(pos):
return None, None
return FlowItem(self.master, self.master.view[pos]), pos

View File

@ -2,7 +2,6 @@ import logging
import math
import sys
from functools import lru_cache
from typing import Optional
import urwid
@ -417,7 +416,7 @@ class FlowDetails(tabs.Tabs):
return searchable.Searchable(txt)
def dns_message_text(
self, type: str, message: Optional[dns.Message]
self, type: str, message: dns.Message | None
) -> searchable.Searchable:
# Keep in sync with web/src/js/components/FlowView/DnsMessages.tsx
if message:

View File

@ -9,7 +9,6 @@ from collections.abc import Sequence
from typing import Any
from typing import AnyStr
from typing import ClassVar
from typing import Optional
import urwid
@ -65,21 +64,21 @@ class Column(metaclass=abc.ABCMeta):
def blank(self) -> Any:
pass
def keypress(self, key: str, editor: "GridEditor") -> Optional[str]:
def keypress(self, key: str, editor: "GridEditor") -> str | None:
return key
class GridRow(urwid.WidgetWrap):
def __init__(
self,
focused: Optional[int],
focused: int | None,
editing: bool,
editor: "GridEditor",
values: tuple[Iterable[bytes], Container[int]],
) -> None:
self.focused = focused
self.editor = editor
self.edit_col: Optional[Cell] = None
self.edit_col: Cell | None = None
errors = values[1]
self.fields: Sequence[Any] = []
@ -128,7 +127,7 @@ class GridWalker(urwid.ListWalker):
self.editor = editor
self.focus = 0
self.focus_col = 0
self.edit_row: Optional[GridRow] = None
self.edit_row: GridRow | None = None
def _modified(self):
self.editor.show_empty_msg()
@ -360,7 +359,7 @@ class BaseGridEditor(urwid.WidgetWrap):
"""
return data
def is_error(self, col: int, val: Any) -> Optional[str]:
def is_error(self, col: int, val: Any) -> str | None:
"""
Return None, or a string error message.
"""

View File

@ -1,5 +1,4 @@
from typing import Any
from typing import Union
import urwid
@ -191,11 +190,7 @@ class DataViewer(base.GridEditor, layoutwidget.LayoutWidget):
def __init__(
self,
master,
vals: Union[
list[list[Any]],
list[Any],
Any,
],
vals: (list[list[Any]] | list[Any] | Any),
) -> None:
if vals is not None:
# Whatever vals is, make it a list of rows containing lists of column values.

View File

@ -2,7 +2,6 @@ import logging
import os
from collections.abc import Sequence
from functools import cache
from typing import Optional
import ruamel.yaml.error
@ -136,13 +135,13 @@ class Keymap:
self.bindings = [b for b in self.bindings if b != binding]
self._on_change()
def get(self, context: str, key: str) -> Optional[Binding]:
def get(self, context: str, key: str) -> Binding | None:
if context in self.keys:
return self.keys[context].get(key, None)
return None
@cache
def binding_for_help(self, help: str) -> Optional[Binding]:
def binding_for_help(self, help: str) -> Binding | None:
for b in self.bindings:
if b.help == help:
return b
@ -156,7 +155,7 @@ class Keymap:
multi.sort(key=lambda x: x.sortkey())
return single + multi
def handle(self, context: str, key: str) -> Optional[str]:
def handle(self, context: str, key: str) -> str | None:
"""
Returns the key if it has not been handled, or None.
"""
@ -166,7 +165,7 @@ class Keymap:
return None
return key
def handle_only(self, context: str, key: str) -> Optional[str]:
def handle_only(self, context: str, key: str) -> str | None:
"""
Like handle, but ignores global bindings. Returns the key if it has
not been handled, or None.

View File

@ -7,7 +7,6 @@ from __future__ import annotations
from collections.abc import Mapping
from collections.abc import Sequence
from typing import Optional
class Palette:
@ -92,7 +91,7 @@ class Palette:
"commander_hint",
]
_fields.extend(["gradient_%02d" % i for i in range(100)])
high: Optional[Mapping[str, Sequence[str]]] = None
high: Mapping[str, Sequence[str]] | None = None
low: Mapping[str, Sequence[str]]
def palette(self, transparent: bool):

View File

@ -2,7 +2,6 @@
This module is reponsible for drawing the quick key help at the bottom of mitmproxy.
"""
from dataclasses import dataclass
from typing import Optional
from typing import Union
import urwid
@ -51,7 +50,7 @@ class QuickHelp:
def make(
widget: type[urwid.Widget],
focused_flow: Optional[flow.Flow],
focused_flow: flow.Flow | None,
is_root_widget: bool,
) -> QuickHelp:
top_label = ""

View File

@ -2,7 +2,6 @@ from __future__ import annotations
from collections.abc import Callable
from functools import lru_cache
from typing import Optional
import urwid
@ -99,9 +98,7 @@ class ActionBar(urwid.WidgetWrap):
self.bottom._w = urwid.Text("")
self.prompting = callback
def sig_prompt_command(
self, partial: str = "", cursor: Optional[int] = None
) -> None:
def sig_prompt_command(self, partial: str = "", cursor: int | None = None) -> None:
signals.focus.send(section="footer")
self.top._w = commander.CommandEdit(
self.master,

Some files were not shown because too many files have changed in this diff Show More