diff --git a/docs/scripts/api-events.py b/docs/scripts/api-events.py index 94f65f0de..8f3a72e72 100644 --- a/docs/scripts/api-events.py +++ b/docs/scripts/api-events.py @@ -10,20 +10,21 @@ from mitmproxy import log from mitmproxy.proxy import layer from mitmproxy.proxy import server_hooks from mitmproxy.proxy.layers import dns -from mitmproxy.proxy.layers.http import _hooks as http from mitmproxy.proxy.layers import modes from mitmproxy.proxy.layers import quic from mitmproxy.proxy.layers import tcp from mitmproxy.proxy.layers import tls from mitmproxy.proxy.layers import udp from mitmproxy.proxy.layers import websocket +from mitmproxy.proxy.layers.http import _hooks as http known = set() def category(name: str, desc: str, hooks: list[type[hooks.Hook]]) -> None: all_params = [ - list(inspect.signature(hook.__init__, eval_str=True).parameters.values())[1:] for hook in hooks + list(inspect.signature(hook.__init__, eval_str=True).parameters.values())[1:] + for hook in hooks ] # slightly overengineered, but this was fun to write. ¯\_(ツ)_/¯ diff --git a/docs/scripts/clirecording/clidirector.py b/docs/scripts/clirecording/clidirector.py index db286b2b2..e861e9556 100644 --- a/docs/scripts/clirecording/clidirector.py +++ b/docs/scripts/clirecording/clidirector.py @@ -1,11 +1,12 @@ import json -from typing import NamedTuple, Optional - -import libtmux import random import subprocess import threading import time +from typing import NamedTuple +from typing import Optional + +import libtmux class InstructionSpec(NamedTuple): diff --git a/docs/scripts/clirecording/record.py b/docs/scripts/clirecording/record.py index 54ba1be2a..6e91674e8 100644 --- a/docs/scripts/clirecording/record.py +++ b/docs/scripts/clirecording/record.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 - -from clidirector import CliDirector import screenplays +from clidirector import CliDirector if __name__ == "__main__": diff --git a/docs/scripts/clirecording/screenplays.py b/docs/scripts/clirecording/screenplays.py index ea871e7a7..5f916dac1 100644 --- a/docs/scripts/clirecording/screenplays.py +++ b/docs/scripts/clirecording/screenplays.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 - from clidirector import CliDirector diff --git a/docs/scripts/examples.py b/docs/scripts/examples.py index 4dd742d50..953cd1fcc 100755 --- a/docs/scripts/examples.py +++ b/docs/scripts/examples.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 - import re from pathlib import Path diff --git a/docs/scripts/filters.py b/docs/scripts/filters.py index 32634196a..c002a4ed2 100755 --- a/docs/scripts/filters.py +++ b/docs/scripts/filters.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 - from mitmproxy import flowfilter diff --git a/docs/scripts/options.py b/docs/scripts/options.py index 3747d3fb7..6ee4c34af 100755 --- a/docs/scripts/options.py +++ b/docs/scripts/options.py @@ -1,8 +1,11 @@ #!/usr/bin/env python3 import asyncio -from mitmproxy import options, optmanager -from mitmproxy.tools import dump, console, web +from mitmproxy import options +from mitmproxy import optmanager +from mitmproxy.tools import console +from mitmproxy.tools import dump +from mitmproxy.tools import web masters = { "mitmproxy": console.master.ConsoleMaster, diff --git a/examples/addons/contentview-custom-grpc.py b/examples/addons/contentview-custom-grpc.py index c84da91c8..37ba99dd7 100644 --- a/examples/addons/contentview-custom-grpc.py +++ b/examples/addons/contentview-custom-grpc.py @@ -4,7 +4,9 @@ protobuf messages based on a user defined rule set. """ from mitmproxy import contentviews -from mitmproxy.contentviews.grpc import ViewGrpcProtobuf, ViewConfig, ProtoParser +from mitmproxy.contentviews.grpc import ProtoParser +from mitmproxy.contentviews.grpc import ViewConfig +from mitmproxy.contentviews.grpc import ViewGrpcProtobuf config: ViewConfig = ViewConfig() config.parser_rules = [ @@ -68,13 +70,13 @@ config.parser_rules = [ tag_prefixes=["1.5.1", "1.5.3", "1.5.4", "1.5.5", "1.5.6"], name="latitude", intended_decoding=ProtoParser.DecodedTypes.double, - ), # noqa: E501 + ), ProtoParser.ParserFieldDefinition( tag=".2", tag_prefixes=["1.5.1", "1.5.3", "1.5.4", "1.5.5", "1.5.6"], name="longitude", intended_decoding=ProtoParser.DecodedTypes.double, - ), # noqa: E501 + ), ProtoParser.ParserFieldDefinition(tag="7", name="app"), ], ), diff --git a/examples/addons/contentview.py b/examples/addons/contentview.py index a485c25a8..b96a81ca5 100644 --- a/examples/addons/contentview.py +++ b/examples/addons/contentview.py @@ -7,7 +7,8 @@ The content view API is explained in the mitmproxy.contentviews module. """ from typing import Optional -from mitmproxy import contentviews, flow +from mitmproxy import contentviews +from mitmproxy import flow from mitmproxy import http diff --git a/examples/addons/filter-flows.py b/examples/addons/filter-flows.py index b69406c50..fda546fde 100644 --- a/examples/addons/filter-flows.py +++ b/examples/addons/filter-flows.py @@ -2,10 +2,11 @@ Use mitmproxy's filter pattern in scripts. """ from __future__ import annotations + import logging -from mitmproxy import http from mitmproxy import flowfilter +from mitmproxy import http class Filter: diff --git a/examples/addons/http-stream-modify.py b/examples/addons/http-stream-modify.py index a200fe563..76cb33e51 100644 --- a/examples/addons/http-stream-modify.py +++ b/examples/addons/http-stream-modify.py @@ -7,7 +7,8 @@ Modifying streamed responses is tricky and brittle: - If you want to replace all occurrences of "foobar", make sure to catch the cases where one chunk ends with [...]foo" and the next starts with "bar[...]. """ -from typing import Iterable, Union +from collections.abc import Iterable +from typing import Union def modify(data: bytes) -> Union[bytes, Iterable[bytes]]: diff --git a/examples/addons/http-trailers.py b/examples/addons/http-trailers.py index 26a51f23b..4a7f56d61 100644 --- a/examples/addons/http-trailers.py +++ b/examples/addons/http-trailers.py @@ -6,7 +6,6 @@ the body is fully transmitted. Such trailers need to be announced in the initial headers by name, so the receiving endpoint can wait and read them after the body. """ - from mitmproxy import http from mitmproxy.http import Headers diff --git a/examples/addons/io-read-saved-flows.py b/examples/addons/io-read-saved-flows.py index f6a177be4..32da842d3 100644 --- a/examples/addons/io-read-saved-flows.py +++ b/examples/addons/io-read-saved-flows.py @@ -2,11 +2,13 @@ """ Read a mitmproxy dump file. """ -from mitmproxy import io, http -from mitmproxy.exceptions import FlowReadException import pprint import sys +from mitmproxy import http +from mitmproxy import io +from mitmproxy.exceptions import FlowReadException + with open(sys.argv[1], "rb") as logfile: freader = io.FlowReader(logfile) pp = pprint.PrettyPrinter(indent=4) diff --git a/examples/addons/io-write-flow-file.py b/examples/addons/io-write-flow-file.py index ecc0528e7..a348749de 100644 --- a/examples/addons/io-write-flow-file.py +++ b/examples/addons/io-write-flow-file.py @@ -11,7 +11,8 @@ import random import sys from typing import BinaryIO -from mitmproxy import io, http +from mitmproxy import http +from mitmproxy import io class Writer: diff --git a/examples/addons/log-events.py b/examples/addons/log-events.py index f5a1c91b2..900be1b26 100644 --- a/examples/addons/log-events.py +++ b/examples/addons/log-events.py @@ -8,4 +8,7 @@ def load(l): logging.info("This is some informative text.") logging.warning("This is a warning.") logging.error("This is an error.") - logging.log(ALERT, "This is an alert. It has the same urgency as info, but will also pop up in the status bar.") + logging.log( + ALERT, + "This is an alert. It has the same urgency as info, but will also pop up in the status bar.", + ) diff --git a/examples/addons/nonblocking.py b/examples/addons/nonblocking.py index ae59db80a..fb3d3fee2 100644 --- a/examples/addons/nonblocking.py +++ b/examples/addons/nonblocking.py @@ -3,7 +3,6 @@ Make events hooks non-blocking using async or @concurrent """ import asyncio import logging - import time from mitmproxy.script import concurrent diff --git a/examples/addons/shutdown.py b/examples/addons/shutdown.py index 6a6d5069a..13629eeff 100644 --- a/examples/addons/shutdown.py +++ b/examples/addons/shutdown.py @@ -10,7 +10,8 @@ Usage: """ import logging -from mitmproxy import ctx, http +from mitmproxy import ctx +from mitmproxy import http def request(flow: http.HTTPFlow) -> None: diff --git a/examples/addons/tcp-simple.py b/examples/addons/tcp-simple.py index 242e97140..ed90ba48f 100644 --- a/examples/addons/tcp-simple.py +++ b/examples/addons/tcp-simple.py @@ -12,8 +12,8 @@ Example Invocation: """ import logging -from mitmproxy.utils import strutils from mitmproxy import tcp +from mitmproxy.utils import strutils def tcp_message(flow: tcp.TCPFlow): diff --git a/examples/addons/websocket-inject-message.py b/examples/addons/websocket-inject-message.py index a0b73d24c..5916edc1a 100644 --- a/examples/addons/websocket-inject-message.py +++ b/examples/addons/websocket-inject-message.py @@ -5,7 +5,8 @@ This example shows how to inject a WebSocket message into a running connection. """ import asyncio -from mitmproxy import ctx, http +from mitmproxy import ctx +from mitmproxy import http # Simple example: Inject a message as a response to an event diff --git a/examples/addons/wsgi-flask-app.py b/examples/addons/wsgi-flask-app.py index 4f117f05a..7feab9d58 100644 --- a/examples/addons/wsgi-flask-app.py +++ b/examples/addons/wsgi-flask-app.py @@ -6,6 +6,7 @@ instance, we're using the Flask framework (http://flask.pocoo.org/) to expose a single simplest-possible page. """ from flask import Flask + from mitmproxy.addons import asgiapp app = Flask("proxapp") @@ -24,5 +25,4 @@ addons = [ # mitmproxy will connect to said domain and use its certificate but won't send any data. # By using `--set upstream_cert=false` and `--set connection_strategy_lazy` the local certificate is used instead. # asgiapp.WSGIApp(app, "example.com", 443), - ] diff --git a/examples/contrib/all_markers.py b/examples/contrib/all_markers.py index 4e9043f33..153818d03 100644 --- a/examples/contrib/all_markers.py +++ b/examples/contrib/all_markers.py @@ -1,10 +1,13 @@ -from mitmproxy import ctx, command +from mitmproxy import command +from mitmproxy import ctx from mitmproxy.utils import emoji -@command.command('all.markers') +@command.command("all.markers") def all_markers(): - 'Create a new flow showing all marker values' + "Create a new flow showing all marker values" for marker in emoji.emoji: - ctx.master.commands.call('view.flows.create', 'get', f'https://example.com/{marker}') - ctx.master.commands.call('flow.mark', [ctx.master.view.focus.flow], marker) + ctx.master.commands.call( + "view.flows.create", "get", f"https://example.com/{marker}" + ) + ctx.master.commands.call("flow.mark", [ctx.master.view.focus.flow], marker) diff --git a/examples/contrib/block_dns_over_https.py b/examples/contrib/block_dns_over_https.py index 4fe71c44c..2933ce6ce 100644 --- a/examples/contrib/block_dns_over_https.py +++ b/examples/contrib/block_dns_over_https.py @@ -9,79 +9,250 @@ import logging # known DoH providers' hostnames and IP addresses to block default_blocklist: dict = { "hostnames": [ - "dns.adguard.com", "dns-family.adguard.com", "dns.google", "cloudflare-dns.com", - "mozilla.cloudflare-dns.com", "security.cloudflare-dns.com", "family.cloudflare-dns.com", - "dns.quad9.net", "dns9.quad9.net", "dns10.quad9.net", "dns11.quad9.net", "doh.opendns.com", - "doh.familyshield.opendns.com", "doh.cleanbrowsing.org", "doh.xfinity.com", "dohdot.coxlab.net", - "odvr.nic.cz", "doh.dnslify.com", "dns.nextdns.io", "dns.dnsoverhttps.net", "doh.crypto.sx", - "doh.powerdns.org", "doh-fi.blahdns.com", "doh-jp.blahdns.com", "doh-de.blahdns.com", - "doh.ffmuc.net", "dns.dns-over-https.com", "doh.securedns.eu", "dns.rubyfish.cn", - "dns.containerpi.com", "dns.containerpi.com", "dns.containerpi.com", "doh-2.seby.io", - "doh.seby.io", "commons.host", "doh.dnswarden.com", "doh.dnswarden.com", "doh.dnswarden.com", - "dns-nyc.aaflalo.me", "dns.aaflalo.me", "doh.applied-privacy.net", "doh.captnemo.in", - "doh.tiar.app", "doh.tiarap.org", "doh.dns.sb", "rdns.faelix.net", "doh.li", "doh.armadillodns.net", - "jp.tiar.app", "jp.tiarap.org", "doh.42l.fr", "dns.hostux.net", "dns.hostux.net", "dns.aa.net.uk", - "adblock.mydns.network", "ibksturm.synology.me", "jcdns.fun", "ibuki.cgnat.net", "dns.twnic.tw", - "example.doh.blockerdns.com", "dns.digitale-gesellschaft.ch", "doh.libredns.gr", - "doh.centraleu.pi-dns.com", "doh.northeu.pi-dns.com", "doh.westus.pi-dns.com", - "doh.eastus.pi-dns.com", "dns.flatuslifir.is", "private.canadianshield.cira.ca", - "protected.canadianshield.cira.ca", "family.canadianshield.cira.ca", "dns.google.com", - "dns.google.com" + "dns.adguard.com", + "dns-family.adguard.com", + "dns.google", + "cloudflare-dns.com", + "mozilla.cloudflare-dns.com", + "security.cloudflare-dns.com", + "family.cloudflare-dns.com", + "dns.quad9.net", + "dns9.quad9.net", + "dns10.quad9.net", + "dns11.quad9.net", + "doh.opendns.com", + "doh.familyshield.opendns.com", + "doh.cleanbrowsing.org", + "doh.xfinity.com", + "dohdot.coxlab.net", + "odvr.nic.cz", + "doh.dnslify.com", + "dns.nextdns.io", + "dns.dnsoverhttps.net", + "doh.crypto.sx", + "doh.powerdns.org", + "doh-fi.blahdns.com", + "doh-jp.blahdns.com", + "doh-de.blahdns.com", + "doh.ffmuc.net", + "dns.dns-over-https.com", + "doh.securedns.eu", + "dns.rubyfish.cn", + "dns.containerpi.com", + "dns.containerpi.com", + "dns.containerpi.com", + "doh-2.seby.io", + "doh.seby.io", + "commons.host", + "doh.dnswarden.com", + "doh.dnswarden.com", + "doh.dnswarden.com", + "dns-nyc.aaflalo.me", + "dns.aaflalo.me", + "doh.applied-privacy.net", + "doh.captnemo.in", + "doh.tiar.app", + "doh.tiarap.org", + "doh.dns.sb", + "rdns.faelix.net", + "doh.li", + "doh.armadillodns.net", + "jp.tiar.app", + "jp.tiarap.org", + "doh.42l.fr", + "dns.hostux.net", + "dns.hostux.net", + "dns.aa.net.uk", + "adblock.mydns.network", + "ibksturm.synology.me", + "jcdns.fun", + "ibuki.cgnat.net", + "dns.twnic.tw", + "example.doh.blockerdns.com", + "dns.digitale-gesellschaft.ch", + "doh.libredns.gr", + "doh.centraleu.pi-dns.com", + "doh.northeu.pi-dns.com", + "doh.westus.pi-dns.com", + "doh.eastus.pi-dns.com", + "dns.flatuslifir.is", + "private.canadianshield.cira.ca", + "protected.canadianshield.cira.ca", + "family.canadianshield.cira.ca", + "dns.google.com", + "dns.google.com", ], "ips": [ - "104.16.248.249", "104.16.248.249", "104.16.249.249", "104.16.249.249", "104.18.2.55", - "104.18.26.128", "104.18.27.128", "104.18.3.55", "104.18.44.204", "104.18.44.204", - "104.18.45.204", "104.18.45.204", "104.182.57.196", "104.236.178.232", "104.24.122.53", - "104.24.123.53", "104.28.0.106", "104.28.1.106", "104.31.90.138", "104.31.91.138", - "115.159.131.230", "116.202.176.26", "116.203.115.192", "136.144.215.158", "139.59.48.222", - "139.99.222.72", "146.112.41.2", "146.112.41.3", "146.185.167.43", "149.112.112.10", - "149.112.112.11", "149.112.112.112", "149.112.112.9", "149.112.121.10", "149.112.121.20", - "149.112.121.30", "149.112.122.10", "149.112.122.20", "149.112.122.30", "159.69.198.101", - "168.235.81.167", "172.104.93.80", "172.65.3.223", "174.138.29.175", "174.68.248.77", - "176.103.130.130", "176.103.130.131", "176.103.130.132", "176.103.130.134", "176.56.236.175", - "178.62.214.105", "185.134.196.54", "185.134.197.54", "185.213.26.187", "185.216.27.142", - "185.228.168.10", "185.228.168.168", "185.235.81.1", "185.26.126.37", "185.26.126.37", - "185.43.135.1", "185.95.218.42", "185.95.218.43", "195.30.94.28", "2001:148f:fffe::1", - "2001:19f0:7001:3259:5400:2ff:fe71:bc9", "2001:19f0:7001:5554:5400:2ff:fe57:3077", - "2001:19f0:7001:5554:5400:2ff:fe57:3077", "2001:19f0:7001:5554:5400:2ff:fe57:3077", - "2001:4860:4860::8844", "2001:4860:4860::8888", - "2001:4b98:dc2:43:216:3eff:fe86:1d28", "2001:558:fe21:6b:96:113:151:149", - "2001:608:a01::3", "2001:678:888:69:c45d:2738:c3f2:1878", "2001:8b0::2022", "2001:8b0::2023", - "2001:c50:ffff:1:101:101:101:101", "210.17.9.228", "217.169.20.22", "217.169.20.23", - "2400:6180:0:d0::5f73:4001", "2400:8902::f03c:91ff:feda:c514", "2604:180:f3::42", - "2604:a880:1:20::51:f001", "2606:4700::6810:f8f9", "2606:4700::6810:f9f9", "2606:4700::6812:1a80", - "2606:4700::6812:1b80", "2606:4700::6812:237", "2606:4700::6812:337", "2606:4700:3033::6812:2ccc", - "2606:4700:3033::6812:2dcc", "2606:4700:3033::6818:7b35", "2606:4700:3034::681c:16a", - "2606:4700:3035::6818:7a35", "2606:4700:3035::681f:5a8a", "2606:4700:3036::681c:6a", - "2606:4700:3036::681f:5b8a", "2606:4700:60:0:a71e:6467:cef8:2a56", "2620:10a:80bb::10", - "2620:10a:80bb::20", "2620:10a:80bb::30" "2620:10a:80bc::10", "2620:10a:80bc::20", - "2620:10a:80bc::30", "2620:119:fc::2", "2620:119:fc::3", "2620:fe::10", "2620:fe::11", - "2620:fe::9", "2620:fe::fe:10", "2620:fe::fe:11", "2620:fe::fe:9", "2620:fe::fe", - "2a00:5a60::ad1:ff", "2a00:5a60::ad2:ff", "2a00:5a60::bad1:ff", "2a00:5a60::bad2:ff", - "2a00:d880:5:bf0::7c93", "2a01:4f8:1c0c:8233::1", "2a01:4f8:1c1c:6b4b::1", "2a01:4f8:c2c:52bf::1", - "2a01:4f9:c010:43ce::1", "2a01:4f9:c01f:4::abcd", "2a01:7c8:d002:1ef:5054:ff:fe40:3703", - "2a01:9e00::54", "2a01:9e00::55", "2a01:9e01::54", "2a01:9e01::55", - "2a02:1205:34d5:5070:b26e:bfff:fe1d:e19b", "2a03:4000:38:53c::2", - "2a03:b0c0:0:1010::e9a:3001", "2a04:bdc7:100:70::abcd", "2a05:fc84::42", "2a05:fc84::43", - "2a07:a8c0::", "2a0d:4d00:81::1", "2a0d:5600:33:3::abcd", "35.198.2.76", "35.231.247.227", - "45.32.55.94", "45.67.219.208", "45.76.113.31", "45.77.180.10", "45.90.28.0", - "46.101.66.244", "46.227.200.54", "46.227.200.55", "46.239.223.80", "8.8.4.4", - "8.8.8.8", "83.77.85.7", "88.198.91.187", "9.9.9.10", "9.9.9.11", "9.9.9.9", - "94.130.106.88", "95.216.181.228", "95.216.212.177", "96.113.151.148", - ] + "104.16.248.249", + "104.16.248.249", + "104.16.249.249", + "104.16.249.249", + "104.18.2.55", + "104.18.26.128", + "104.18.27.128", + "104.18.3.55", + "104.18.44.204", + "104.18.44.204", + "104.18.45.204", + "104.18.45.204", + "104.182.57.196", + "104.236.178.232", + "104.24.122.53", + "104.24.123.53", + "104.28.0.106", + "104.28.1.106", + "104.31.90.138", + "104.31.91.138", + "115.159.131.230", + "116.202.176.26", + "116.203.115.192", + "136.144.215.158", + "139.59.48.222", + "139.99.222.72", + "146.112.41.2", + "146.112.41.3", + "146.185.167.43", + "149.112.112.10", + "149.112.112.11", + "149.112.112.112", + "149.112.112.9", + "149.112.121.10", + "149.112.121.20", + "149.112.121.30", + "149.112.122.10", + "149.112.122.20", + "149.112.122.30", + "159.69.198.101", + "168.235.81.167", + "172.104.93.80", + "172.65.3.223", + "174.138.29.175", + "174.68.248.77", + "176.103.130.130", + "176.103.130.131", + "176.103.130.132", + "176.103.130.134", + "176.56.236.175", + "178.62.214.105", + "185.134.196.54", + "185.134.197.54", + "185.213.26.187", + "185.216.27.142", + "185.228.168.10", + "185.228.168.168", + "185.235.81.1", + "185.26.126.37", + "185.26.126.37", + "185.43.135.1", + "185.95.218.42", + "185.95.218.43", + "195.30.94.28", + "2001:148f:fffe::1", + "2001:19f0:7001:3259:5400:2ff:fe71:bc9", + "2001:19f0:7001:5554:5400:2ff:fe57:3077", + "2001:19f0:7001:5554:5400:2ff:fe57:3077", + "2001:19f0:7001:5554:5400:2ff:fe57:3077", + "2001:4860:4860::8844", + "2001:4860:4860::8888", + "2001:4b98:dc2:43:216:3eff:fe86:1d28", + "2001:558:fe21:6b:96:113:151:149", + "2001:608:a01::3", + "2001:678:888:69:c45d:2738:c3f2:1878", + "2001:8b0::2022", + "2001:8b0::2023", + "2001:c50:ffff:1:101:101:101:101", + "210.17.9.228", + "217.169.20.22", + "217.169.20.23", + "2400:6180:0:d0::5f73:4001", + "2400:8902::f03c:91ff:feda:c514", + "2604:180:f3::42", + "2604:a880:1:20::51:f001", + "2606:4700::6810:f8f9", + "2606:4700::6810:f9f9", + "2606:4700::6812:1a80", + "2606:4700::6812:1b80", + "2606:4700::6812:237", + "2606:4700::6812:337", + "2606:4700:3033::6812:2ccc", + "2606:4700:3033::6812:2dcc", + "2606:4700:3033::6818:7b35", + "2606:4700:3034::681c:16a", + "2606:4700:3035::6818:7a35", + "2606:4700:3035::681f:5a8a", + "2606:4700:3036::681c:6a", + "2606:4700:3036::681f:5b8a", + "2606:4700:60:0:a71e:6467:cef8:2a56", + "2620:10a:80bb::10", + "2620:10a:80bb::20", + "2620:10a:80bb::30" "2620:10a:80bc::10", + "2620:10a:80bc::20", + "2620:10a:80bc::30", + "2620:119:fc::2", + "2620:119:fc::3", + "2620:fe::10", + "2620:fe::11", + "2620:fe::9", + "2620:fe::fe:10", + "2620:fe::fe:11", + "2620:fe::fe:9", + "2620:fe::fe", + "2a00:5a60::ad1:ff", + "2a00:5a60::ad2:ff", + "2a00:5a60::bad1:ff", + "2a00:5a60::bad2:ff", + "2a00:d880:5:bf0::7c93", + "2a01:4f8:1c0c:8233::1", + "2a01:4f8:1c1c:6b4b::1", + "2a01:4f8:c2c:52bf::1", + "2a01:4f9:c010:43ce::1", + "2a01:4f9:c01f:4::abcd", + "2a01:7c8:d002:1ef:5054:ff:fe40:3703", + "2a01:9e00::54", + "2a01:9e00::55", + "2a01:9e01::54", + "2a01:9e01::55", + "2a02:1205:34d5:5070:b26e:bfff:fe1d:e19b", + "2a03:4000:38:53c::2", + "2a03:b0c0:0:1010::e9a:3001", + "2a04:bdc7:100:70::abcd", + "2a05:fc84::42", + "2a05:fc84::43", + "2a07:a8c0::", + "2a0d:4d00:81::1", + "2a0d:5600:33:3::abcd", + "35.198.2.76", + "35.231.247.227", + "45.32.55.94", + "45.67.219.208", + "45.76.113.31", + "45.77.180.10", + "45.90.28.0", + "46.101.66.244", + "46.227.200.54", + "46.227.200.55", + "46.239.223.80", + "8.8.4.4", + "8.8.8.8", + "83.77.85.7", + "88.198.91.187", + "9.9.9.10", + "9.9.9.11", + "9.9.9.9", + "94.130.106.88", + "95.216.181.228", + "95.216.212.177", + "96.113.151.148", + ], } # additional hostnames to block -additional_doh_names: list[str] = [ - 'dns.google.com' -] +additional_doh_names: list[str] = ["dns.google.com"] # additional IPs to block -additional_doh_ips: list[str] = [ +additional_doh_ips: list[str] = [] -] - -doh_hostnames, doh_ips = default_blocklist['hostnames'], default_blocklist['ips'] +doh_hostnames, doh_ips = default_blocklist["hostnames"], default_blocklist["ips"] # convert to sets for faster lookups doh_hostnames = set(doh_hostnames) @@ -95,9 +266,9 @@ def _has_dns_message_content_type(flow): :param flow: mitmproxy flow :return: True if 'Content-Type' header is DNS-looking, False otherwise """ - doh_content_types = ['application/dns-message'] - if 'Content-Type' in flow.request.headers: - if flow.request.headers['Content-Type'] in doh_content_types: + doh_content_types = ["application/dns-message"] + if "Content-Type" in flow.request.headers: + if flow.request.headers["Content-Type"] in doh_content_types: return True return False @@ -109,7 +280,7 @@ def _request_has_dns_query_string(flow): :param flow: mitmproxy flow :return: True is 'dns' is a parameter in the query string, False otherwise """ - return 'dns' in flow.request.query + return "dns" in flow.request.query def _request_is_dns_json(flow): @@ -127,12 +298,12 @@ def _request_is_dns_json(flow): """ # Header 'Accept: application/dns-json' is required in Cloudflare's DoH JSON API # or they return a 400 HTTP response code - if 'Accept' in flow.request.headers: - if flow.request.headers['Accept'] == 'application/dns-json': + if "Accept" in flow.request.headers: + if flow.request.headers["Accept"] == "application/dns-json": return True # Google's DoH JSON API is https://dns.google/resolve - path = flow.request.path.split('?')[0] - if flow.request.host == 'dns.google' and path == '/resolve': + path = flow.request.path.split("?")[0] + if flow.request.host == "dns.google" and path == "/resolve": return True return False @@ -146,9 +317,9 @@ def _request_has_doh_looking_path(flow): :return: True if path looks like it's DoH, otherwise False """ doh_paths = [ - '/dns-query', # used in example in RFC 8484 (see https://tools.ietf.org/html/rfc8484#section-4.1.1) + "/dns-query", # used in example in RFC 8484 (see https://tools.ietf.org/html/rfc8484#section-4.1.1) ] - path = flow.request.path.split('?')[0] + path = flow.request.path.split("?")[0] return path in doh_paths @@ -171,7 +342,7 @@ doh_request_detection_checks = [ _request_has_dns_query_string, _request_is_dns_json, _requested_hostname_is_in_doh_blocklist, - _request_has_doh_looking_path + _request_has_doh_looking_path, ] @@ -179,6 +350,9 @@ def request(flow): for check in doh_request_detection_checks: is_doh = check(flow) if is_doh: - logging.warning("[DoH Detection] DNS over HTTPS request detected via method \"%s\"" % check.__name__) + logging.warning( + '[DoH Detection] DNS over HTTPS request detected via method "%s"' + % check.__name__ + ) flow.kill() break diff --git a/examples/contrib/change_upstream_proxy.py b/examples/contrib/change_upstream_proxy.py index 7f6d56bc4..4a824131c 100644 --- a/examples/contrib/change_upstream_proxy.py +++ b/examples/contrib/change_upstream_proxy.py @@ -1,4 +1,3 @@ - from mitmproxy import http from mitmproxy.connection import Server from mitmproxy.net.server_spec import ServerSpec diff --git a/examples/contrib/check_ssl_pinning.py b/examples/contrib/check_ssl_pinning.py index 8bc0b24aa..87a96181f 100644 --- a/examples/contrib/check_ssl_pinning.py +++ b/examples/contrib/check_ssl_pinning.py @@ -1,14 +1,17 @@ +import ipaddress +import time + +import OpenSSL + import mitmproxy from mitmproxy import ctx from mitmproxy.certs import Cert -import ipaddress -import OpenSSL -import time # Certificate for client connection is generated in dummy_cert() in certs.py. Monkeypatching # the function to generate test cases for SSL Pinning. + def monkey_dummy_cert(privkey, cacert, commonname, sans): ss = [] for i in sans: @@ -42,7 +45,7 @@ def monkey_dummy_cert(privkey, cacert, commonname, sans): if ctx.options.certwrongCN: # append an extra char to make certs common name different than original one. # APpending a char in the end of the domain name. - new_cn = commonname + b'm' + new_cn = commonname + b"m" cert.get_subject().CN = new_cn else: @@ -52,7 +55,8 @@ def monkey_dummy_cert(privkey, cacert, commonname, sans): if ss: cert.set_version(2) cert.add_extensions( - [OpenSSL.crypto.X509Extension(b"subjectAltName", False, ss)]) + [OpenSSL.crypto.X509Extension(b"subjectAltName", False, ss)] + ) cert.set_pubkey(cacert.get_pubkey()) cert.sign(privkey, "sha256") return Cert(cert) @@ -61,23 +65,29 @@ def monkey_dummy_cert(privkey, cacert, commonname, sans): class CheckSSLPinning: def load(self, loader): loader.add_option( - "certbeginon", bool, False, + "certbeginon", + bool, + False, """ Sets SSL Certificate's 'Begins On' time in future. - """ + """, ) loader.add_option( - "certexpire", bool, False, + "certexpire", + bool, + False, """ Sets SSL Certificate's 'Expires On' time in the past. - """ + """, ) loader.add_option( - "certwrongCN", bool, False, + "certwrongCN", + bool, + False, """ Sets SSL Certificate's CommonName(CN) different from the domain name. - """ + """, ) def clientconnect(self, layer): diff --git a/examples/contrib/custom_next_layer.py b/examples/contrib/custom_next_layer.py index 917272dcb..31e0887fc 100644 --- a/examples/contrib/custom_next_layer.py +++ b/examples/contrib/custom_next_layer.py @@ -11,7 +11,8 @@ Example usage: import logging from mitmproxy import ctx -from mitmproxy.proxy import layer, layers +from mitmproxy.proxy import layer +from mitmproxy.proxy import layers def running(): diff --git a/examples/contrib/domain_fronting.py b/examples/contrib/domain_fronting.py index fd73d2985..0a477d0b5 100644 --- a/examples/contrib/domain_fronting.py +++ b/examples/contrib/domain_fronting.py @@ -1,6 +1,8 @@ -from typing import Optional, Union import json from dataclasses import dataclass +from typing import Optional +from typing import Union + from mitmproxy import ctx from mitmproxy.addonmanager import Loader from mitmproxy.http import HTTPFlow @@ -79,7 +81,7 @@ class HttpsDomainFronting: index = host.find(".", index) if index == -1: break - super_domain = host[(index + 1):] + super_domain = host[(index + 1) :] mapping = self.star_mappings.get(super_domain) if mapping is not None: return mapping diff --git a/examples/contrib/har_dump.py b/examples/contrib/har_dump.py index e1337af46..0a7d6faae 100644 --- a/examples/contrib/har_dump.py +++ b/examples/contrib/har_dump.py @@ -7,16 +7,14 @@ mitmdump -s ./har_dump.py --set hardump=./dump.har filename endwith '.zhar' will be compressed: mitmdump -s ./har_dump.py --set hardump=./dump.zhar """ - import base64 import json import logging import os +import zlib from datetime import datetime from datetime import timezone -import zlib - import mitmproxy from mitmproxy import connection from mitmproxy import ctx @@ -33,27 +31,28 @@ SERVERS_SEEN: set[connection.Server] = set() def load(l): l.add_option( - "hardump", str, "", "HAR dump path.", + "hardump", + str, + "", + "HAR dump path.", ) def configure(updated): - HAR.update({ - "log": { - "version": "1.2", - "creator": { - "name": "mitmproxy har_dump", - "version": "0.1", - "comment": "mitmproxy version %s" % version.MITMPROXY - }, - "pages": [ - { - "pageTimings": {} - } - ], - "entries": [] + HAR.update( + { + "log": { + "version": "1.2", + "creator": { + "name": "mitmproxy har_dump", + "version": "0.1", + "comment": "mitmproxy version %s" % version.MITMPROXY, + }, + "pages": [{"pageTimings": {}}], + "entries": [], + } } - }) + ) # The `pages` attribute is needed for Firefox Dev Tools to load the HAR file. # An empty value works fine. @@ -65,12 +64,15 @@ def flow_entry(flow: mitmproxy.http.HTTPFlow) -> dict: connect_time = -1 if flow.server_conn and flow.server_conn not in SERVERS_SEEN: - connect_time = (flow.server_conn.timestamp_tcp_setup - - flow.server_conn.timestamp_start) + connect_time = ( + flow.server_conn.timestamp_tcp_setup - flow.server_conn.timestamp_start + ) if flow.server_conn.timestamp_tls_setup is not None: - ssl_time = (flow.server_conn.timestamp_tls_setup - - flow.server_conn.timestamp_tcp_setup) + ssl_time = ( + flow.server_conn.timestamp_tls_setup + - flow.server_conn.timestamp_tcp_setup + ) SERVERS_SEEN.add(flow.server_conn) @@ -81,28 +83,31 @@ def flow_entry(flow: mitmproxy.http.HTTPFlow) -> dict: # spent waiting between request.timestamp_end and response.timestamp_start # thus it correlates to HAR wait instead. timings_raw = { - 'send': flow.request.timestamp_end - flow.request.timestamp_start, - 'receive': flow.response.timestamp_end - flow.response.timestamp_start, - 'wait': flow.response.timestamp_start - flow.request.timestamp_end, - 'connect': connect_time, - 'ssl': ssl_time, + "send": flow.request.timestamp_end - flow.request.timestamp_start, + "receive": flow.response.timestamp_end - flow.response.timestamp_start, + "wait": flow.response.timestamp_start - flow.request.timestamp_end, + "connect": connect_time, + "ssl": ssl_time, } # HAR timings are integers in ms, so we re-encode the raw timings to that format. - timings = { - k: int(1000 * v) if v != -1 else -1 - for k, v in timings_raw.items() - } + timings = {k: int(1000 * v) if v != -1 else -1 for k, v in timings_raw.items()} # full_time is the sum of all timings. # Timings set to -1 will be ignored as per spec. full_time = sum(v for v in timings.values() if v > -1) - started_date_time = datetime.fromtimestamp(flow.request.timestamp_start, timezone.utc).isoformat() + started_date_time = datetime.fromtimestamp( + flow.request.timestamp_start, timezone.utc + ).isoformat() # Response body size and encoding - response_body_size = len(flow.response.raw_content) if flow.response.raw_content else 0 - response_body_decoded_size = len(flow.response.content) if flow.response.content else 0 + response_body_size = ( + len(flow.response.raw_content) if flow.response.raw_content else 0 + ) + response_body_decoded_size = ( + len(flow.response.content) if flow.response.content else 0 + ) response_body_compression = response_body_decoded_size - response_body_size entry = { @@ -127,9 +132,9 @@ def flow_entry(flow: mitmproxy.http.HTTPFlow) -> dict: "content": { "size": response_body_size, "compression": response_body_compression, - "mimeType": flow.response.headers.get('Content-Type', '') + "mimeType": flow.response.headers.get("Content-Type", ""), }, - "redirectURL": flow.response.headers.get('Location', ''), + "redirectURL": flow.response.headers.get("Location", ""), "headersSize": len(str(flow.response.headers)), "bodySize": response_body_size, }, @@ -139,7 +144,9 @@ def flow_entry(flow: mitmproxy.http.HTTPFlow) -> dict: # Store binary data as base64 if strutils.is_mostly_bin(flow.response.content): - entry["response"]["content"]["text"] = base64.b64encode(flow.response.content).decode() + entry["response"]["content"]["text"] = base64.b64encode( + flow.response.content + ).decode() entry["response"]["content"]["encoding"] = "base64" else: entry["response"]["content"]["text"] = flow.response.get_text(strict=False) @@ -152,7 +159,7 @@ def flow_entry(flow: mitmproxy.http.HTTPFlow) -> dict: entry["request"]["postData"] = { "mimeType": flow.request.headers.get("Content-Type", ""), "text": flow.request.get_text(strict=False), - "params": params + "params": params, } if flow.server_conn.connected: @@ -165,7 +172,7 @@ def flow_entry(flow: mitmproxy.http.HTTPFlow) -> dict: def response(flow: mitmproxy.http.HTTPFlow): """ - Called when a server response has been received. + Called when a server response has been received. """ if flow.websocket is None: flow_entry(flow) @@ -182,29 +189,29 @@ def websocket_end(flow: mitmproxy.http.HTTPFlow): else: data = base64.b64encode(message.content).decode() websocket_message = { - 'type': 'send' if message.from_client else 'receive', - 'time': message.timestamp, - 'opcode': message.type.value, - 'data': data + "type": "send" if message.from_client else "receive", + "time": message.timestamp, + "opcode": message.type.value, + "data": data, } websocket_messages.append(websocket_message) - entry['_resourceType'] = 'websocket' - entry['_webSocketMessages'] = websocket_messages + entry["_resourceType"] = "websocket" + entry["_webSocketMessages"] = websocket_messages def done(): """ - Called once on script shutdown, after any other events. + Called once on script shutdown, after any other events. """ if ctx.options.hardump: json_dump: str = json.dumps(HAR, indent=2) - if ctx.options.hardump == '-': + if ctx.options.hardump == "-": print(json_dump) else: raw: bytes = json_dump.encode() - if ctx.options.hardump.endswith('.zhar'): + if ctx.options.hardump.endswith(".zhar"): raw = zlib.compress(raw, 9) with open(os.path.expanduser(ctx.options.hardump), "wb") as f: @@ -234,7 +241,9 @@ def format_cookies(cookie_list): # Expiration time needs to be formatted expire_ts = cookies.get_expiration_ts(attrs) if expire_ts is not None: - cookie_har["expires"] = datetime.fromtimestamp(expire_ts, timezone.utc).isoformat() + cookie_har["expires"] = datetime.fromtimestamp( + expire_ts, timezone.utc + ).isoformat() rv.append(cookie_har) @@ -251,6 +260,6 @@ def format_response_cookies(fields): def name_value(obj): """ - Convert (key, value) pairs to HAR format. + Convert (key, value) pairs to HAR format. """ return [{"name": k, "value": v} for k, v in obj.items()] diff --git a/examples/contrib/http_manipulate_cookies.py b/examples/contrib/http_manipulate_cookies.py index b91018c6e..aaad41227 100644 --- a/examples/contrib/http_manipulate_cookies.py +++ b/examples/contrib/http_manipulate_cookies.py @@ -15,9 +15,10 @@ Note: """ import json -from mitmproxy import http from typing import Union +from mitmproxy import http + PATH_TO_COOKIES = "./cookies.json" # insert your path to the cookie file here FILTER_COOKIES = { @@ -43,7 +44,14 @@ def stringify_cookies(cookies: list[dict[str, Union[str, None]]]) -> str: """ Creates a cookie string from a list of cookie dicts. """ - return "; ".join([f"{c['name']}={c['value']}" if c.get("value", None) is not None else f"{c['name']}" for c in cookies]) + return "; ".join( + [ + f"{c['name']}={c['value']}" + if c.get("value", None) is not None + else f"{c['name']}" + for c in cookies + ] + ) def parse_cookies(cookie_string: str) -> list[dict[str, Union[str, None]]]: @@ -52,7 +60,9 @@ def parse_cookies(cookie_string: str) -> list[dict[str, Union[str, None]]]: """ return [ {"name": g[0], "value": g[1]} if len(g) == 2 else {"name": g[0], "value": None} - for g in [k.split("=", 1) for k in [c.strip() for c in cookie_string.split(";")] if k] + for g in [ + k.split("=", 1) for k in [c.strip() for c in cookie_string.split(";")] if k + ] ] diff --git a/examples/contrib/httpdump.py b/examples/contrib/httpdump.py index e8c166557..532ed1e5d 100644 --- a/examples/contrib/httpdump.py +++ b/examples/contrib/httpdump.py @@ -14,8 +14,9 @@ import mimetypes import os from pathlib import Path -from mitmproxy import ctx, http +from mitmproxy import ctx from mitmproxy import flowfilter +from mitmproxy import http class HTTPDump: @@ -32,7 +33,7 @@ class HTTPDump: name="open_browser", typespec=bool, default=True, - help="open integrated browser at start" + help="open integrated browser at start", ) def running(self): diff --git a/examples/contrib/jsondump.py b/examples/contrib/jsondump.py index 7617902f3..cfde9b75c 100644 --- a/examples/contrib/jsondump.py +++ b/examples/contrib/jsondump.py @@ -34,7 +34,8 @@ import base64 import json import logging from queue import Queue -from threading import Lock, Thread +from threading import Lock +from threading import Thread import requests @@ -66,76 +67,77 @@ class JSONDumper: self.outfile.close() fields = { - 'timestamp': ( - ('error', 'timestamp'), - - ('request', 'timestamp_start'), - ('request', 'timestamp_end'), - - ('response', 'timestamp_start'), - ('response', 'timestamp_end'), - - ('client_conn', 'timestamp_start'), - ('client_conn', 'timestamp_end'), - ('client_conn', 'timestamp_tls_setup'), - - ('server_conn', 'timestamp_start'), - ('server_conn', 'timestamp_end'), - ('server_conn', 'timestamp_tls_setup'), - ('server_conn', 'timestamp_tcp_setup'), + "timestamp": ( + ("error", "timestamp"), + ("request", "timestamp_start"), + ("request", "timestamp_end"), + ("response", "timestamp_start"), + ("response", "timestamp_end"), + ("client_conn", "timestamp_start"), + ("client_conn", "timestamp_end"), + ("client_conn", "timestamp_tls_setup"), + ("server_conn", "timestamp_start"), + ("server_conn", "timestamp_end"), + ("server_conn", "timestamp_tls_setup"), + ("server_conn", "timestamp_tcp_setup"), ), - 'ip': ( - ('server_conn', 'source_address'), - ('server_conn', 'ip_address'), - ('server_conn', 'address'), - ('client_conn', 'address'), + "ip": ( + ("server_conn", "source_address"), + ("server_conn", "ip_address"), + ("server_conn", "address"), + ("client_conn", "address"), ), - 'ws_messages': ( - ('messages',), + "ws_messages": (("messages",),), + "headers": ( + ("request", "headers"), + ("response", "headers"), ), - 'headers': ( - ('request', 'headers'), - ('response', 'headers'), - ), - 'content': ( - ('request', 'content'), - ('response', 'content'), + "content": ( + ("request", "content"), + ("response", "content"), ), } def _init_transformations(self): self.transformations = [ { - 'fields': self.fields['headers'], - 'func': dict, + "fields": self.fields["headers"], + "func": dict, }, { - 'fields': self.fields['timestamp'], - 'func': lambda t: int(t * 1000), + "fields": self.fields["timestamp"], + "func": lambda t: int(t * 1000), }, { - 'fields': self.fields['ip'], - 'func': lambda addr: { - 'host': addr[0].replace('::ffff:', ''), - 'port': addr[1], + "fields": self.fields["ip"], + "func": lambda addr: { + "host": addr[0].replace("::ffff:", ""), + "port": addr[1], }, }, { - 'fields': self.fields['ws_messages'], - 'func': lambda ms: [{ - 'type': m[0], - 'from_client': m[1], - 'content': base64.b64encode(bytes(m[2], 'utf-8')) if self.encode else m[2], - 'timestamp': int(m[3] * 1000), - } for m in ms], - } + "fields": self.fields["ws_messages"], + "func": lambda ms: [ + { + "type": m[0], + "from_client": m[1], + "content": base64.b64encode(bytes(m[2], "utf-8")) + if self.encode + else m[2], + "timestamp": int(m[3] * 1000), + } + for m in ms + ], + }, ] if self.encode: - self.transformations.append({ - 'fields': self.fields['content'], - 'func': base64.b64encode, - }) + self.transformations.append( + { + "fields": self.fields["content"], + "func": base64.b64encode, + } + ) @staticmethod def transform_field(obj, path, func): @@ -156,8 +158,10 @@ class JSONDumper: Recursively convert all list/dict elements of type `bytes` into strings. """ if isinstance(obj, dict): - return {cls.convert_to_strings(key): cls.convert_to_strings(value) - for key, value in obj.items()} + return { + cls.convert_to_strings(key): cls.convert_to_strings(value) + for key, value in obj.items() + } elif isinstance(obj, list) or isinstance(obj, tuple): return [cls.convert_to_strings(element) for element in obj] elif isinstance(obj, bytes): @@ -175,8 +179,8 @@ class JSONDumper: Transform and dump (write / send) a data frame. """ for tfm in self.transformations: - for field in tfm['fields']: - self.transform_field(frame, field, tfm['func']) + for field in tfm["fields"]: + self.transform_field(frame, field, tfm["func"]) frame = self.convert_to_strings(frame) if self.outfile: @@ -191,14 +195,21 @@ class JSONDumper: """ Extra options to be specified in `~/.mitmproxy/config.yaml`. """ - loader.add_option('dump_encodecontent', bool, False, - 'Encode content as base64.') - loader.add_option('dump_destination', str, 'jsondump.out', - 'Output destination: path to a file or URL.') - loader.add_option('dump_username', str, '', - 'Basic auth username for URL destinations.') - loader.add_option('dump_password', str, '', - 'Basic auth password for URL destinations.') + loader.add_option( + "dump_encodecontent", bool, False, "Encode content as base64." + ) + loader.add_option( + "dump_destination", + str, + "jsondump.out", + "Output destination: path to a file or URL.", + ) + loader.add_option( + "dump_username", str, "", "Basic auth username for URL destinations." + ) + loader.add_option( + "dump_password", str, "", "Basic auth password for URL destinations." + ) def configure(self, _): """ @@ -207,18 +218,18 @@ class JSONDumper: """ self.encode = ctx.options.dump_encodecontent - if ctx.options.dump_destination.startswith('http'): + if ctx.options.dump_destination.startswith("http"): self.outfile = None self.url = ctx.options.dump_destination - logging.info('Sending all data frames to %s' % self.url) + logging.info("Sending all data frames to %s" % self.url) if ctx.options.dump_username and ctx.options.dump_password: self.auth = (ctx.options.dump_username, ctx.options.dump_password) - logging.info('HTTP Basic auth enabled.') + logging.info("HTTP Basic auth enabled.") else: - self.outfile = open(ctx.options.dump_destination, 'a') + self.outfile = open(ctx.options.dump_destination, "a") self.url = None self.lock = Lock() - logging.info('Writing all data frames to %s' % ctx.options.dump_destination) + logging.info("Writing all data frames to %s" % ctx.options.dump_destination) self._init_transformations() diff --git a/examples/contrib/link_expander.py b/examples/contrib/link_expander.py index 0edf7c986..7e7e6b5d8 100644 --- a/examples/contrib/link_expander.py +++ b/examples/contrib/link_expander.py @@ -2,27 +2,33 @@ # relative links () and expands them to absolute links # In practice this can be used to front an indexing spider that may not have the capability to expand relative page links. # Usage: mitmdump -s link_expander.py or mitmproxy -s link_expander.py - import re from urllib.parse import urljoin def response(flow): - if "Content-Type" in flow.response.headers and flow.response.headers["Content-Type"].find("text/html") != -1: + if ( + "Content-Type" in flow.response.headers + and flow.response.headers["Content-Type"].find("text/html") != -1 + ): pageUrl = flow.request.url pageText = flow.response.text - pattern = (r"]*?\s+)?href=(?P[\"'])" - r"(?P(?!https?:\/\/|ftps?:\/\/|\/\/|#|javascript:|mailto:).*?)(?P=delimiter)") + pattern = ( + r"]*?\s+)?href=(?P[\"'])" + r"(?P(?!https?:\/\/|ftps?:\/\/|\/\/|#|javascript:|mailto:).*?)(?P=delimiter)" + ) rel_matcher = re.compile(pattern, flags=re.IGNORECASE) rel_matches = rel_matcher.finditer(pageText) map_dict = {} for match_num, match in enumerate(rel_matches): (delimiter, rel_link) = match.group("delimiter", "link") abs_link = urljoin(pageUrl, rel_link) - map_dict["{0}{1}{0}".format(delimiter, rel_link)] = "{0}{1}{0}".format(delimiter, abs_link) + map_dict["{0}{1}{0}".format(delimiter, rel_link)] = "{0}{1}{0}".format( + delimiter, abs_link + ) for map in map_dict.items(): pageText = pageText.replace(*map) # Uncomment the following to print the expansion mapping # print("{0} -> {1}".format(*map)) - flow.response.text = pageText \ No newline at end of file + flow.response.text = pageText diff --git a/examples/contrib/mitmproxywrapper.py b/examples/contrib/mitmproxywrapper.py index 075cc3c04..a484ff386 100644 --- a/examples/contrib/mitmproxywrapper.py +++ b/examples/contrib/mitmproxywrapper.py @@ -6,12 +6,11 @@ # # mitmproxywrapper.py -h # - -import subprocess -import re import argparse import contextlib import os +import re +import subprocess import sys @@ -21,59 +20,50 @@ class Wrapper: self.extra_arguments = extra_arguments def run_networksetup_command(self, *arguments): - return subprocess.check_output( - ['sudo', 'networksetup'] + list(arguments)) + return subprocess.check_output(["sudo", "networksetup"] + list(arguments)) def proxy_state_for_service(self, service): - state = self.run_networksetup_command( - '-getwebproxy', - service).splitlines() - return dict([re.findall(r'([^:]+): (.*)', line)[0] for line in state]) + state = self.run_networksetup_command("-getwebproxy", service).splitlines() + return dict([re.findall(r"([^:]+): (.*)", line)[0] for line in state]) def enable_proxy_for_service(self, service): - print(f'Enabling proxy on {service}...') - for subcommand in ['-setwebproxy', '-setsecurewebproxy']: + print(f"Enabling proxy on {service}...") + for subcommand in ["-setwebproxy", "-setsecurewebproxy"]: self.run_networksetup_command( - subcommand, service, '127.0.0.1', str( - self.port)) + subcommand, service, "127.0.0.1", str(self.port) + ) def disable_proxy_for_service(self, service): - print(f'Disabling proxy on {service}...') - for subcommand in ['-setwebproxystate', '-setsecurewebproxystate']: - self.run_networksetup_command(subcommand, service, 'Off') + print(f"Disabling proxy on {service}...") + for subcommand in ["-setwebproxystate", "-setsecurewebproxystate"]: + self.run_networksetup_command(subcommand, service, "Off") def interface_name_to_service_name_map(self): - order = self.run_networksetup_command('-listnetworkserviceorder') + order = self.run_networksetup_command("-listnetworkserviceorder") mapping = re.findall( - r'\(\d+\)\s(.*)$\n\(.*Device: (.+)\)$', - order, - re.MULTILINE) + r"\(\d+\)\s(.*)$\n\(.*Device: (.+)\)$", order, re.MULTILINE + ) return {b: a for (a, b) in mapping} def run_command_with_input(self, command, input): - popen = subprocess.Popen( - command, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE) + popen = subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE) (stdout, stderr) = popen.communicate(input) return stdout def primary_interace_name(self): - scutil_script = 'get State:/Network/Global/IPv4\nd.show\n' - stdout = self.run_command_with_input('/usr/sbin/scutil', scutil_script) - interface, = re.findall(r'PrimaryInterface\s*:\s*(.+)', stdout) + scutil_script = "get State:/Network/Global/IPv4\nd.show\n" + stdout = self.run_command_with_input("/usr/sbin/scutil", scutil_script) + (interface,) = re.findall(r"PrimaryInterface\s*:\s*(.+)", stdout) return interface def primary_service_name(self): - return self.interface_name_to_service_name_map()[ - self.primary_interace_name()] + return self.interface_name_to_service_name_map()[self.primary_interace_name()] def proxy_enabled_for_service(self, service): - return self.proxy_state_for_service(service)['Enabled'] == 'Yes' + return self.proxy_state_for_service(service)["Enabled"] == "Yes" def toggle_proxy(self): - new_state = not self.proxy_enabled_for_service( - self.primary_service_name()) + new_state = not self.proxy_enabled_for_service(self.primary_service_name()) for service_name in self.connected_service_names(): if self.proxy_enabled_for_service(service_name) and not new_state: self.disable_proxy_for_service(service_name) @@ -81,31 +71,29 @@ class Wrapper: self.enable_proxy_for_service(service_name) def connected_service_names(self): - scutil_script = 'list\n' - stdout = self.run_command_with_input('/usr/sbin/scutil', scutil_script) - service_ids = re.findall(r'State:/Network/Service/(.+)/IPv4', stdout) + scutil_script = "list\n" + stdout = self.run_command_with_input("/usr/sbin/scutil", scutil_script) + service_ids = re.findall(r"State:/Network/Service/(.+)/IPv4", stdout) service_names = [] for service_id in service_ids: scutil_script = f"show Setup:/Network/Service/{service_id}\n" - stdout = self.run_command_with_input( - '/usr/sbin/scutil', - scutil_script) - service_name, = re.findall(r'UserDefinedName\s*:\s*(.+)', stdout) + stdout = self.run_command_with_input("/usr/sbin/scutil", scutil_script) + (service_name,) = re.findall(r"UserDefinedName\s*:\s*(.+)", stdout) service_names.append(service_name) return service_names def wrap_mitmproxy(self): with self.wrap_proxy(): - cmd = ['mitmproxy', '-p', str(self.port)] + cmd = ["mitmproxy", "-p", str(self.port)] if self.extra_arguments: cmd.extend(self.extra_arguments) subprocess.check_call(cmd) def wrap_honeyproxy(self): with self.wrap_proxy(): - popen = subprocess.Popen('honeyproxy.sh') + popen = subprocess.Popen("honeyproxy.sh") try: popen.wait() except KeyboardInterrupt: @@ -127,26 +115,29 @@ class Wrapper: @classmethod def ensure_superuser(cls): if os.getuid() != 0: - print('Relaunching with sudo...') - os.execv('/usr/bin/sudo', ['/usr/bin/sudo'] + sys.argv) + print("Relaunching with sudo...") + os.execv("/usr/bin/sudo", ["/usr/bin/sudo"] + sys.argv) @classmethod def main(cls): parser = argparse.ArgumentParser( - description='Helper tool for OS X proxy configuration and mitmproxy.', - epilog='Any additional arguments will be passed on unchanged to mitmproxy.') + description="Helper tool for OS X proxy configuration and mitmproxy.", + epilog="Any additional arguments will be passed on unchanged to mitmproxy.", + ) parser.add_argument( - '-t', - '--toggle', - action='store_true', - help='just toggle the proxy configuration') + "-t", + "--toggle", + action="store_true", + help="just toggle the proxy configuration", + ) # parser.add_argument('--honeyproxy', action='store_true', help='run honeyproxy instead of mitmproxy') parser.add_argument( - '-p', - '--port', + "-p", + "--port", type=int, - help='override the default port of 8080', - default=8080) + help="override the default port of 8080", + default=8080, + ) args, extra_arguments = parser.parse_known_args() wrapper = cls(port=args.port, extra_arguments=extra_arguments) @@ -159,6 +150,6 @@ class Wrapper: wrapper.wrap_mitmproxy() -if __name__ == '__main__': +if __name__ == "__main__": Wrapper.ensure_superuser() Wrapper.main() diff --git a/examples/contrib/modify_body_inject_iframe.py b/examples/contrib/modify_body_inject_iframe.py index 595bd9f28..1736efd34 100644 --- a/examples/contrib/modify_body_inject_iframe.py +++ b/examples/contrib/modify_body_inject_iframe.py @@ -1,24 +1,21 @@ # (this script works best with --anticache) from bs4 import BeautifulSoup -from mitmproxy import ctx, http + +from mitmproxy import ctx +from mitmproxy import http class Injector: def load(self, loader): - loader.add_option( - "iframe", str, "", "IFrame to inject" - ) + loader.add_option("iframe", str, "", "IFrame to inject") def response(self, flow: http.HTTPFlow) -> None: if ctx.options.iframe: html = BeautifulSoup(flow.response.content, "html.parser") if html.body: iframe = html.new_tag( - "iframe", - src=ctx.options.iframe, - frameborder=0, - height=0, - width=0) + "iframe", src=ctx.options.iframe, frameborder=0, height=0, width=0 + ) html.body.insert(0, iframe) flow.response.content = str(html).encode("utf8") diff --git a/examples/contrib/ntlm_upstream_proxy.py b/examples/contrib/ntlm_upstream_proxy.py index 656d48b3a..f11a0b77a 100644 --- a/examples/contrib/ntlm_upstream_proxy.py +++ b/examples/contrib/ntlm_upstream_proxy.py @@ -1,28 +1,34 @@ import base64 +import binascii import logging import socket -from typing import Any, Optional +from typing import Any +from typing import Optional -import binascii -from ntlm_auth import gss_channel_bindings, ntlm +from ntlm_auth import gss_channel_bindings +from ntlm_auth import ntlm -from mitmproxy import addonmanager, http +from mitmproxy import addonmanager from mitmproxy import ctx +from mitmproxy import http from mitmproxy.net.http import http1 -from mitmproxy.proxy import commands, layer +from mitmproxy.proxy import commands +from mitmproxy.proxy import layer from mitmproxy.proxy.context import Context -from mitmproxy.proxy.layers.http import HttpConnectUpstreamHook, HttpLayer, HttpStream +from mitmproxy.proxy.layers.http import HttpConnectUpstreamHook +from mitmproxy.proxy.layers.http import HttpLayer +from mitmproxy.proxy.layers.http import HttpStream from mitmproxy.proxy.layers.http._upstream_proxy import HttpUpstreamProxy class NTLMUpstreamAuth: """ - This addon handles authentication to systems upstream from us for the - upstream proxy and reverse proxy mode. There are 3 cases: - - Upstream proxy CONNECT requests should have authentication added, and - subsequent already connected requests should not. - - Upstream proxy regular requests - - Reverse proxy regular requests (CONNECT is invalid in this mode) + This addon handles authentication to systems upstream from us for the + upstream proxy and reverse proxy mode. There are 3 cases: + - Upstream proxy CONNECT requests should have authentication added, and + subsequent already connected requests should not. + - Upstream proxy regular requests + - Reverse proxy regular requests (CONNECT is invalid in this mode) """ def load(self, loader: addonmanager.Loader) -> None: @@ -34,7 +40,7 @@ class NTLMUpstreamAuth: help=""" Add HTTP NTLM authentication to upstream proxy requests. Format: username:password. - """ + """, ) loader.add_option( name="upstream_ntlm_domain", @@ -42,7 +48,7 @@ class NTLMUpstreamAuth: default=None, help=""" Add HTTP NTLM domain for authentication to upstream proxy requests. - """ + """, ) loader.add_option( name="upstream_proxy_address", @@ -50,7 +56,7 @@ class NTLMUpstreamAuth: default=None, help=""" upstream poxy address. - """ + """, ) loader.add_option( name="upstream_ntlm_compatibility", @@ -59,7 +65,7 @@ class NTLMUpstreamAuth: help=""" Add HTTP NTLM compatibility for authentication to upstream proxy requests. Valid values are 0-5 (Default: 3) - """ + """, ) logging.debug("AddOn: NTLM Upstream Authentication - Loaded") @@ -69,9 +75,13 @@ class NTLMUpstreamAuth: for l in context.layers: if isinstance(l, HttpLayer): for _, stream in l.streams.items(): - return stream.flow if isinstance(stream, HttpStream) else None + return ( + stream.flow if isinstance(stream, HttpStream) else None + ) - def build_connect_flow(context: Context, connect_header: tuple) -> http.HTTPFlow: + def build_connect_flow( + context: Context, connect_header: tuple + ) -> http.HTTPFlow: flow = extract_flow_from_context(context) if not flow: logging.error("failed to build connect flow") @@ -85,23 +95,27 @@ class NTLMUpstreamAuth: assert self.conn.address self.ntlm_context = CustomNTLMContext(ctx) proxy_authorization = self.ntlm_context.get_ntlm_start_negotiate_message() - self.flow = build_connect_flow(self.context, ("Proxy-Authorization", proxy_authorization)) + self.flow = build_connect_flow( + self.context, ("Proxy-Authorization", proxy_authorization) + ) yield HttpConnectUpstreamHook(self.flow) raw = http1.assemble_request(self.flow.request) yield commands.SendData(self.tunnel_connection, raw) def extract_proxy_authenticate_msg(response_head: list) -> str: for header in response_head: - if b'Proxy-Authenticate' in header: - challenge_message = str(bytes(header).decode('utf-8')) + if b"Proxy-Authenticate" in header: + challenge_message = str(bytes(header).decode("utf-8")) try: - token = challenge_message.split(': ')[1] + token = challenge_message.split(": ")[1] except IndexError: logging.error("Failed to extract challenge_message") raise return token - def patched_receive_handshake_data(self, data) -> layer.CommandGenerator[tuple[bool, Optional[str]]]: + def patched_receive_handshake_data( + self, data + ) -> layer.CommandGenerator[tuple[bool, Optional[str]]]: self.buf += data response_head = self.buf.maybe_extract_lines() if response_head: @@ -119,8 +133,14 @@ class NTLMUpstreamAuth: else: if not challenge_message: return True, None - proxy_authorization = self.ntlm_context.get_ntlm_challenge_response_message(challenge_message) - self.flow = build_connect_flow(self.context, ("Proxy-Authorization", proxy_authorization)) + proxy_authorization = ( + self.ntlm_context.get_ntlm_challenge_response_message( + challenge_message + ) + ) + self.flow = build_connect_flow( + self.context, ("Proxy-Authorization", proxy_authorization) + ) raw = http1.assemble_request(self.flow.request) yield commands.SendData(self.tunnel_connection, raw) return False, None @@ -131,19 +151,19 @@ class NTLMUpstreamAuth: HttpUpstreamProxy.receive_handshake_data = patched_receive_handshake_data def done(self): - logging.info('close ntlm session') + logging.info("close ntlm session") -addons = [ - NTLMUpstreamAuth() -] +addons = [NTLMUpstreamAuth()] class CustomNTLMContext: - def __init__(self, - ctx, - preferred_type: str = 'NTLM', - cbt_data: gss_channel_bindings.GssChannelBindingsStruct = None): + def __init__( + self, + ctx, + preferred_type: str = "NTLM", + cbt_data: gss_channel_bindings.GssChannelBindingsStruct = None, + ): # TODO:// take care the cbt_data auth: str = ctx.options.upstream_ntlm_auth domain: str = str(ctx.options.upstream_ntlm_domain).upper() @@ -158,29 +178,39 @@ class CustomNTLMContext: domain=domain, workstation=workstation, ntlm_compatibility=ntlm_compatibility, - cbt_data=cbt_data) + cbt_data=cbt_data, + ) def get_ntlm_start_negotiate_message(self) -> str: negotiate_message = self.ntlm_context.step() negotiate_message_base_64_in_bytes = base64.b64encode(negotiate_message) - negotiate_message_base_64_ascii = negotiate_message_base_64_in_bytes.decode("ascii") - negotiate_message_base_64_final = f'{self.preferred_type} {negotiate_message_base_64_ascii}' + negotiate_message_base_64_ascii = negotiate_message_base_64_in_bytes.decode( + "ascii" + ) + negotiate_message_base_64_final = ( + f"{self.preferred_type} {negotiate_message_base_64_ascii}" + ) logging.debug( - f'{self.preferred_type} Authentication, negotiate message: {negotiate_message_base_64_final}' + f"{self.preferred_type} Authentication, negotiate message: {negotiate_message_base_64_final}" ) return negotiate_message_base_64_final def get_ntlm_challenge_response_message(self, challenge_message: str) -> Any: challenge_message = challenge_message.replace(self.preferred_type + " ", "", 1) try: - challenge_message_ascii_bytes = base64.b64decode(challenge_message, validate=True) + challenge_message_ascii_bytes = base64.b64decode( + challenge_message, validate=True + ) except binascii.Error as err: - logging.debug(f'{self.preferred_type} Authentication fail with error {err.__str__()}') + logging.debug( + f"{self.preferred_type} Authentication fail with error {err.__str__()}" + ) return False authenticate_message = self.ntlm_context.step(challenge_message_ascii_bytes) - negotiate_message_base_64 = '{} {}'.format(self.preferred_type, - base64.b64encode(authenticate_message).decode('ascii')) + negotiate_message_base_64 = "{} {}".format( + self.preferred_type, base64.b64encode(authenticate_message).decode("ascii") + ) logging.debug( - f'{self.preferred_type} Authentication, response to challenge message: {negotiate_message_base_64}' + f"{self.preferred_type} Authentication, response to challenge message: {negotiate_message_base_64}" ) return negotiate_message_base_64 diff --git a/examples/contrib/remote-debug.py b/examples/contrib/remote-debug.py index 767d828cd..323b88d1b 100644 --- a/examples/contrib/remote-debug.py +++ b/examples/contrib/remote-debug.py @@ -18,4 +18,7 @@ Usage: def load(l): import pydevd_pycharm - pydevd_pycharm.settrace("localhost", port=5678, stdoutToServer=True, stderrToServer=True, suspend=False) + + pydevd_pycharm.settrace( + "localhost", port=5678, stdoutToServer=True, stderrToServer=True, suspend=False + ) diff --git a/examples/contrib/save_streamed_data.py b/examples/contrib/save_streamed_data.py index 283a6b52b..640770596 100644 --- a/examples/contrib/save_streamed_data.py +++ b/examples/contrib/save_streamed_data.py @@ -58,12 +58,13 @@ class StreamSaver: return data if not self.fh: - self.path = datetime.fromtimestamp(self.flow.request.timestamp_start).strftime( - ctx.options.save_streamed_data) - self.path = self.path.replace('%+T', str(self.flow.request.timestamp_start)) - self.path = self.path.replace('%+I', str(self.flow.client_conn.id)) - self.path = self.path.replace('%+D', self.direction) - self.path = self.path.replace('%+C', self.flow.client_conn.address[0]) + self.path = datetime.fromtimestamp( + self.flow.request.timestamp_start + ).strftime(ctx.options.save_streamed_data) + self.path = self.path.replace("%+T", str(self.flow.request.timestamp_start)) + self.path = self.path.replace("%+I", str(self.flow.client_conn.id)) + self.path = self.path.replace("%+D", self.direction) + self.path = self.path.replace("%+C", self.flow.client_conn.address[0]) self.path = os.path.expanduser(self.path) parent = Path(self.path).parent @@ -89,25 +90,27 @@ class StreamSaver: def load(loader): loader.add_option( - "save_streamed_data", Optional[str], None, + "save_streamed_data", + Optional[str], + None, "Format string for saving streamed data to files. If set each streamed request or response is written " "to a file with a name derived from the string. In addition to formating supported by python " "strftime() (using the request start time) the code '%+T' is replaced with the time stamp of the request, " "'%+D' by 'req' or 'rsp' depending on the direction of the data, '%+C' by the client IP addresses and " - "'%+I' by the client connection ID." + "'%+I' by the client connection ID.", ) def requestheaders(flow): if ctx.options.save_streamed_data and flow.request.stream: - flow.request.stream = StreamSaver(flow, 'req') + flow.request.stream = StreamSaver(flow, "req") def responseheaders(flow): if isinstance(flow.request.stream, StreamSaver): flow.request.stream.done() if ctx.options.save_streamed_data and flow.response.stream: - flow.response.stream = StreamSaver(flow, 'rsp') + flow.response.stream = StreamSaver(flow, "rsp") def response(flow): diff --git a/examples/contrib/search.py b/examples/contrib/search.py index e9c935ac6..73d775d2b 100644 --- a/examples/contrib/search.py +++ b/examples/contrib/search.py @@ -3,20 +3,19 @@ import re from collections.abc import Sequence from json import dumps -from mitmproxy import command, flow +from mitmproxy import command +from mitmproxy import flow -MARKER = ':mag:' -RESULTS_STR = 'Search Results: ' +MARKER = ":mag:" +RESULTS_STR = "Search Results: " class Search: def __init__(self): self.exp = None - @command.command('search') - def _search(self, - flows: Sequence[flow.Flow], - regex: str) -> None: + @command.command("search") + def _search(self, flows: Sequence[flow.Flow], regex: str) -> None: """ Defines a command named "search" that matches the given regular expression against most parts @@ -49,11 +48,11 @@ class Search: for _flow in flows: # Erase previous results while preserving other comments: comments = list() - for c in _flow.comment.split('\n'): + for c in _flow.comment.split("\n"): if c.startswith(RESULTS_STR): break comments.append(c) - _flow.comment = '\n'.join(comments) + _flow.comment = "\n".join(comments) if _flow.marked == MARKER: _flow.marked = False @@ -62,7 +61,7 @@ class Search: if results: comments.append(RESULTS_STR) comments.append(dumps(results, indent=2)) - _flow.comment = '\n'.join(comments) + _flow.comment = "\n".join(comments) _flow.marked = MARKER def header_results(self, message): @@ -71,22 +70,16 @@ class Search: def flow_results(self, _flow): results = dict() - results.update( - {'flow_comment': self.exp.findall(_flow.comment)}) + results.update({"flow_comment": self.exp.findall(_flow.comment)}) if _flow.request is not None: - results.update( - {'request_path': self.exp.findall(_flow.request.path)}) - results.update( - {'request_headers': self.header_results(_flow.request)}) + results.update({"request_path": self.exp.findall(_flow.request.path)}) + results.update({"request_headers": self.header_results(_flow.request)}) if _flow.request.text: - results.update( - {'request_body': self.exp.findall(_flow.request.text)}) + results.update({"request_body": self.exp.findall(_flow.request.text)}) if _flow.response is not None: - results.update( - {'response_headers': self.header_results(_flow.response)}) + results.update({"response_headers": self.header_results(_flow.response)}) if _flow.response.text: - results.update( - {'response_body': self.exp.findall(_flow.response.text)}) + results.update({"response_body": self.exp.findall(_flow.response.text)}) return results diff --git a/examples/contrib/sslstrip.py b/examples/contrib/sslstrip.py index 05aa5f3e5..6b88c3956 100644 --- a/examples/contrib/sslstrip.py +++ b/examples/contrib/sslstrip.py @@ -12,15 +12,15 @@ secure_hosts: set[str] = set() def request(flow: http.HTTPFlow) -> None: - flow.request.headers.pop('If-Modified-Since', None) - flow.request.headers.pop('Cache-Control', None) + flow.request.headers.pop("If-Modified-Since", None) + flow.request.headers.pop("Cache-Control", None) # do not force https redirection - flow.request.headers.pop('Upgrade-Insecure-Requests', None) + flow.request.headers.pop("Upgrade-Insecure-Requests", None) # proxy connections to SSL-enabled hosts if flow.request.pretty_host in secure_hosts: - flow.request.scheme = 'https' + flow.request.scheme = "https" flow.request.port = 443 # We need to update the request destination to whatever is specified in the host header: @@ -31,32 +31,36 @@ def request(flow: http.HTTPFlow) -> None: def response(flow: http.HTTPFlow) -> None: assert flow.response - flow.response.headers.pop('Strict-Transport-Security', None) - flow.response.headers.pop('Public-Key-Pins', None) + flow.response.headers.pop("Strict-Transport-Security", None) + flow.response.headers.pop("Public-Key-Pins", None) # strip links in response body - flow.response.content = flow.response.content.replace(b'https://', b'http://') + flow.response.content = flow.response.content.replace(b"https://", b"http://") # strip meta tag upgrade-insecure-requests in response body - csp_meta_tag_pattern = br'' - flow.response.content = re.sub(csp_meta_tag_pattern, b'', flow.response.content, flags=re.IGNORECASE) + csp_meta_tag_pattern = rb'' + flow.response.content = re.sub( + csp_meta_tag_pattern, b"", flow.response.content, flags=re.IGNORECASE + ) # strip links in 'Location' header - if flow.response.headers.get('Location', '').startswith('https://'): - location = flow.response.headers['Location'] + if flow.response.headers.get("Location", "").startswith("https://"): + location = flow.response.headers["Location"] hostname = urllib.parse.urlparse(location).hostname if hostname: secure_hosts.add(hostname) - flow.response.headers['Location'] = location.replace('https://', 'http://', 1) + flow.response.headers["Location"] = location.replace("https://", "http://", 1) # strip upgrade-insecure-requests in Content-Security-Policy header - csp_header = flow.response.headers.get('Content-Security-Policy', '') - if re.search('upgrade-insecure-requests', csp_header, flags=re.IGNORECASE): - csp = flow.response.headers['Content-Security-Policy'] - new_header = re.sub(r'upgrade-insecure-requests[;\s]*', '', csp, flags=re.IGNORECASE) - flow.response.headers['Content-Security-Policy'] = new_header + csp_header = flow.response.headers.get("Content-Security-Policy", "") + if re.search("upgrade-insecure-requests", csp_header, flags=re.IGNORECASE): + csp = flow.response.headers["Content-Security-Policy"] + new_header = re.sub( + r"upgrade-insecure-requests[;\s]*", "", csp, flags=re.IGNORECASE + ) + flow.response.headers["Content-Security-Policy"] = new_header # strip secure flag from 'Set-Cookie' headers - cookies = flow.response.headers.get_all('Set-Cookie') - cookies = [re.sub(r';\s*secure\s*', '', s) for s in cookies] - flow.response.headers.set_all('Set-Cookie', cookies) + cookies = flow.response.headers.get_all("Set-Cookie") + cookies = [re.sub(r";\s*secure\s*", "", s) for s in cookies] + flow.response.headers.set_all("Set-Cookie", cookies) diff --git a/examples/contrib/suppress_error_responses.py b/examples/contrib/suppress_error_responses.py index e087a78da..5cb319ef6 100644 --- a/examples/contrib/suppress_error_responses.py +++ b/examples/contrib/suppress_error_responses.py @@ -10,7 +10,7 @@ from mitmproxy.exceptions import HttpSyntaxException def error(self, flow: http.HTTPFlow): """Kills the flow if it has an error different to HTTPSyntaxException. - Sometimes, web scanners generate malformed HTTP syntax on purpose and we do not want to kill these requests. + Sometimes, web scanners generate malformed HTTP syntax on purpose and we do not want to kill these requests. """ if flow.error is not None and not isinstance(flow.error, HttpSyntaxException): flow.kill() diff --git a/examples/contrib/test_har_dump.py b/examples/contrib/test_har_dump.py index 88c27a9b4..77b8db6c7 100644 --- a/examples/contrib/test_har_dump.py +++ b/examples/contrib/test_har_dump.py @@ -1,13 +1,13 @@ import json +from mitmproxy.net.http import cookies +from mitmproxy.test import taddons from mitmproxy.test import tflow from mitmproxy.test import tutils -from mitmproxy.test import taddons -from mitmproxy.net.http import cookies class TestHARDump: - def flow(self, resp_content=b'message'): + def flow(self, resp_content=b"message"): times = dict( timestamp_start=746203272, timestamp_end=746203272, @@ -15,8 +15,8 @@ class TestHARDump: # Create a dummy flow for testing return tflow.tflow( - req=tutils.treq(method=b'GET', **times), - resp=tutils.tresp(content=resp_content, **times) + req=tutils.treq(method=b"GET", **times), + resp=tutils.tresp(content=resp_content, **times), ) def test_simple(self, tmpdir, tdata): @@ -26,7 +26,7 @@ class TestHARDump: a = tctx.script(tdata.path("../examples/contrib/har_dump.py")) # check script is read without errors assert tctx.master.logs == [] - assert a.name_value # last function in har_dump.py + assert a.name_value # last function in har_dump.py path = str(tmpdir.join("somefile")) tctx.configure(a, hardump=path) @@ -46,7 +46,9 @@ class TestHARDump: a.done() with open(path) as inp: har = json.load(inp) - assert har["log"]["entries"][0]["response"]["content"]["encoding"] == "base64" + assert ( + har["log"]["entries"][0]["response"]["content"]["encoding"] == "base64" + ) def test_format_cookies(self, tdata): with taddons.context() as tctx: @@ -55,17 +57,21 @@ class TestHARDump: CA = cookies.CookieAttrs f = a.format_cookies([("n", "v", CA([("k", "v")]))])[0] - assert f['name'] == "n" - assert f['value'] == "v" - assert not f['httpOnly'] - assert not f['secure'] + assert f["name"] == "n" + assert f["value"] == "v" + assert not f["httpOnly"] + assert not f["secure"] - f = a.format_cookies([("n", "v", CA([("httponly", None), ("secure", None)]))])[0] - assert f['httpOnly'] - assert f['secure'] + f = a.format_cookies( + [("n", "v", CA([("httponly", None), ("secure", None)]))] + )[0] + assert f["httpOnly"] + assert f["secure"] - f = a.format_cookies([("n", "v", CA([("expires", "Mon, 24-Aug-2037 00:00:00 GMT")]))])[0] - assert f['expires'] + f = a.format_cookies( + [("n", "v", CA([("expires", "Mon, 24-Aug-2037 00:00:00 GMT")]))] + )[0] + assert f["expires"] def test_binary(self, tmpdir, tdata): with taddons.context() as tctx: diff --git a/examples/contrib/test_jsondump.py b/examples/contrib/test_jsondump.py index 106a0ecbf..abb85c290 100644 --- a/examples/contrib/test_jsondump.py +++ b/examples/contrib/test_jsondump.py @@ -1,21 +1,21 @@ -import json import base64 - -from mitmproxy.test import tflow -from mitmproxy.test import tutils -from mitmproxy.test import taddons +import json import requests_mock +from mitmproxy.test import taddons +from mitmproxy.test import tflow +from mitmproxy.test import tutils + example_dir = tutils.test_data.push("../examples") class TestJSONDump: def echo_response(self, request, context): - self.request = {'json': request.json(), 'headers': request.headers} - return '' + self.request = {"json": request.json(), "headers": request.headers} + return "" - def flow(self, resp_content=b'message'): + def flow(self, resp_content=b"message"): times = dict( timestamp_start=746203272, timestamp_end=746203272, @@ -23,8 +23,8 @@ class TestJSONDump: # Create a dummy flow for testing return tflow.tflow( - req=tutils.treq(method=b'GET', **times), - resp=tutils.tresp(content=resp_content, **times) + req=tutils.treq(method=b"GET", **times), + resp=tutils.tresp(content=resp_content, **times), ) def test_simple(self, tmpdir): @@ -36,7 +36,7 @@ class TestJSONDump: tctx.invoke(a, "done") with open(path) as inp: entry = json.loads(inp.readline()) - assert entry['response']['content'] == 'message' + assert entry["response"]["content"] == "message" def test_contentencode(self, tmpdir): with taddons.context() as tctx: @@ -45,24 +45,28 @@ class TestJSONDump: content = b"foo" + b"\xFF" * 10 tctx.configure(a, dump_destination=path, dump_encodecontent=True) - tctx.invoke( - a, "response", self.flow(resp_content=content) - ) + tctx.invoke(a, "response", self.flow(resp_content=content)) tctx.invoke(a, "done") with open(path) as inp: entry = json.loads(inp.readline()) - assert entry['response']['content'] == base64.b64encode(content).decode('utf-8') + assert entry["response"]["content"] == base64.b64encode(content).decode( + "utf-8" + ) def test_http(self, tmpdir): with requests_mock.Mocker() as mock: - mock.post('http://my-server', text=self.echo_response) + mock.post("http://my-server", text=self.echo_response) with taddons.context() as tctx: a = tctx.script(example_dir.path("complex/jsondump.py")) - tctx.configure(a, dump_destination='http://my-server', - dump_username='user', dump_password='pass') + tctx.configure( + a, + dump_destination="http://my-server", + dump_username="user", + dump_password="pass", + ) tctx.invoke(a, "response", self.flow()) tctx.invoke(a, "done") - assert self.request['json']['response']['content'] == 'message' - assert self.request['headers']['Authorization'] == 'Basic dXNlcjpwYXNz' + assert self.request["json"]["response"]["content"] == "message" + assert self.request["headers"]["Authorization"] == "Basic dXNlcjpwYXNz" diff --git a/examples/contrib/test_xss_scanner.py b/examples/contrib/test_xss_scanner.py index f27752825..2f89bab48 100644 --- a/examples/contrib/test_xss_scanner.py +++ b/examples/contrib/test_xss_scanner.py @@ -1,229 +1,331 @@ import pytest import requests + from examples.complex import xss_scanner as xss -from mitmproxy.test import tflow, tutils +from mitmproxy.test import tflow +from mitmproxy.test import tutils -class TestXSSScanner(): +class TestXSSScanner: def test_get_XSS_info(self): # First type of exploit: # Exploitable: - xss_info = xss.get_XSS_data(b"" % - xss.FULL_PAYLOAD, - "https://example.com", - "End of URL") - expected_xss_info = xss.XSSData('https://example.com', - "End of URL", - '" % xss.FULL_PAYLOAD, + "https://example.com", + "End of URL", + ) + expected_xss_info = xss.XSSData( + "https://example.com", + "End of URL", + "" % - xss.FULL_PAYLOAD.replace(b"'", b"%27").replace(b'"', b"%22"), - "https://example.com", - "End of URL") - expected_xss_info = xss.XSSData("https://example.com", - "End of URL", - '" + % xss.FULL_PAYLOAD.replace(b"'", b"%27").replace(b'"', b"%22"), + "https://example.com", + "End of URL", + ) + expected_xss_info = xss.XSSData( + "https://example.com", + "End of URL", + "" % - xss.FULL_PAYLOAD.replace(b"'", b"%27").replace(b'"', b"%22").replace(b"/", b"%2F"), - "https://example.com", - "End of URL") + xss_info = xss.get_XSS_data( + b"" + % xss.FULL_PAYLOAD.replace(b"'", b"%27") + .replace(b'"', b"%22") + .replace(b"/", b"%2F"), + "https://example.com", + "End of URL", + ) assert xss_info is None # Second type of exploit: # Exploitable: - xss_info = xss.get_XSS_data(b"" % - xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E").replace(b"\"", b"%22"), - "https://example.com", - "End of URL") - expected_xss_info = xss.XSSData("https://example.com", - "End of URL", - "';alert(0);g='", - xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E") - .replace(b"\"", b"%22").decode('utf-8')) + xss_info = xss.get_XSS_data( + b"" + % xss.FULL_PAYLOAD.replace(b"<", b"%3C") + .replace(b">", b"%3E") + .replace(b'"', b"%22"), + "https://example.com", + "End of URL", + ) + expected_xss_info = xss.XSSData( + "https://example.com", + "End of URL", + "';alert(0);g='", + xss.FULL_PAYLOAD.replace(b"<", b"%3C") + .replace(b">", b"%3E") + .replace(b'"', b"%22") + .decode("utf-8"), + ) assert xss_info == expected_xss_info # Non-Exploitable: - xss_info = xss.get_XSS_data(b"" % - xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b"\"", b"%22").replace(b"'", b"%22"), - "https://example.com", - "End of URL") + xss_info = xss.get_XSS_data( + b"" + % xss.FULL_PAYLOAD.replace(b"<", b"%3C") + .replace(b'"', b"%22") + .replace(b"'", b"%22"), + "https://example.com", + "End of URL", + ) assert xss_info is None # Third type of exploit: # Exploitable: - xss_info = xss.get_XSS_data(b"" % - xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E").replace(b"'", b"%27"), - "https://example.com", - "End of URL") - expected_xss_info = xss.XSSData("https://example.com", - "End of URL", - '";alert(0);g="', - xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E") - .replace(b"'", b"%27").decode('utf-8')) + xss_info = xss.get_XSS_data( + b'' + % xss.FULL_PAYLOAD.replace(b"<", b"%3C") + .replace(b">", b"%3E") + .replace(b"'", b"%27"), + "https://example.com", + "End of URL", + ) + expected_xss_info = xss.XSSData( + "https://example.com", + "End of URL", + '";alert(0);g="', + xss.FULL_PAYLOAD.replace(b"<", b"%3C") + .replace(b">", b"%3E") + .replace(b"'", b"%27") + .decode("utf-8"), + ) assert xss_info == expected_xss_info # Non-Exploitable: - xss_info = xss.get_XSS_data(b"" % - xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b"'", b"%27").replace(b"\"", b"%22"), - "https://example.com", - "End of URL") + xss_info = xss.get_XSS_data( + b'' + % xss.FULL_PAYLOAD.replace(b"<", b"%3C") + .replace(b"'", b"%27") + .replace(b'"', b"%22"), + "https://example.com", + "End of URL", + ) assert xss_info is None # Fourth type of exploit: Test # Exploitable: - xss_info = xss.get_XSS_data(b"Test" % - xss.FULL_PAYLOAD, - "https://example.com", - "End of URL") - expected_xss_info = xss.XSSData("https://example.com", - "End of URL", - "'>", - xss.FULL_PAYLOAD.decode('utf-8')) + xss_info = xss.get_XSS_data( + b"Test" % xss.FULL_PAYLOAD, + "https://example.com", + "End of URL", + ) + expected_xss_info = xss.XSSData( + "https://example.com", + "End of URL", + "'>", + xss.FULL_PAYLOAD.decode("utf-8"), + ) assert xss_info == expected_xss_info # Non-Exploitable: - xss_info = xss.get_XSS_data(b"Test" % - xss.FULL_PAYLOAD.replace(b"'", b"%27"), - "https://example.com", - "End of URL") + xss_info = xss.get_XSS_data( + b"Test" + % xss.FULL_PAYLOAD.replace(b"'", b"%27"), + "https://example.com", + "End of URL", + ) assert xss_info is None # Fifth type of exploit: Test # Exploitable: - xss_info = xss.get_XSS_data(b"Test" % - xss.FULL_PAYLOAD.replace(b"'", b"%27"), - "https://example.com", - "End of URL") - expected_xss_info = xss.XSSData("https://example.com", - "End of URL", - "\">", - xss.FULL_PAYLOAD.replace(b"'", b"%27").decode('utf-8')) + xss_info = xss.get_XSS_data( + b'Test' + % xss.FULL_PAYLOAD.replace(b"'", b"%27"), + "https://example.com", + "End of URL", + ) + expected_xss_info = xss.XSSData( + "https://example.com", + "End of URL", + '">', + xss.FULL_PAYLOAD.replace(b"'", b"%27").decode("utf-8"), + ) assert xss_info == expected_xss_info # Non-Exploitable: - xss_info = xss.get_XSS_data(b"Test" % - xss.FULL_PAYLOAD.replace(b"'", b"%27").replace(b"\"", b"%22"), - "https://example.com", - "End of URL") + xss_info = xss.get_XSS_data( + b'Test' + % xss.FULL_PAYLOAD.replace(b"'", b"%27").replace(b'"', b"%22"), + "https://example.com", + "End of URL", + ) assert xss_info is None # Sixth type of exploit: Test # Exploitable: - xss_info = xss.get_XSS_data(b"Test" % - xss.FULL_PAYLOAD, - "https://example.com", - "End of URL") - expected_xss_info = xss.XSSData("https://example.com", - "End of URL", - ">", - xss.FULL_PAYLOAD.decode('utf-8')) + xss_info = xss.get_XSS_data( + b"Test" % xss.FULL_PAYLOAD, + "https://example.com", + "End of URL", + ) + expected_xss_info = xss.XSSData( + "https://example.com", + "End of URL", + ">", + xss.FULL_PAYLOAD.decode("utf-8"), + ) assert xss_info == expected_xss_info # Non-Exploitable - xss_info = xss.get_XSS_data(b"Test" % - xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E") - .replace(b"=", b"%3D"), - "https://example.com", - "End of URL") + xss_info = xss.get_XSS_data( + b"Test" + % xss.FULL_PAYLOAD.replace(b"<", b"%3C") + .replace(b">", b"%3E") + .replace(b"=", b"%3D"), + "https://example.com", + "End of URL", + ) assert xss_info is None # Seventh type of exploit: PAYLOAD # Exploitable: - xss_info = xss.get_XSS_data(b"%s" % - xss.FULL_PAYLOAD, - "https://example.com", - "End of URL") - expected_xss_info = xss.XSSData("https://example.com", - "End of URL", - "", - xss.FULL_PAYLOAD.decode('utf-8')) + xss_info = xss.get_XSS_data( + b"%s" % xss.FULL_PAYLOAD, + "https://example.com", + "End of URL", + ) + expected_xss_info = xss.XSSData( + "https://example.com", + "End of URL", + "", + xss.FULL_PAYLOAD.decode("utf-8"), + ) assert xss_info == expected_xss_info # Non-Exploitable - xss_info = xss.get_XSS_data(b"%s" % - xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E").replace(b"/", b"%2F"), - "https://example.com", - "End of URL") + xss_info = xss.get_XSS_data( + b"%s" + % xss.FULL_PAYLOAD.replace(b"<", b"%3C") + .replace(b">", b"%3E") + .replace(b"/", b"%2F"), + "https://example.com", + "End of URL", + ) assert xss_info is None # Eighth type of exploit: Test # Exploitable: - xss_info = xss.get_XSS_data(b"Test" % - xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E"), - "https://example.com", - "End of URL") - expected_xss_info = xss.XSSData("https://example.com", - "End of URL", - "Javascript:alert(0)", - xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E").decode('utf-8')) + xss_info = xss.get_XSS_data( + b"Test" + % xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E"), + "https://example.com", + "End of URL", + ) + expected_xss_info = xss.XSSData( + "https://example.com", + "End of URL", + "Javascript:alert(0)", + xss.FULL_PAYLOAD.replace(b"<", b"%3C") + .replace(b">", b"%3E") + .decode("utf-8"), + ) assert xss_info == expected_xss_info # Non-Exploitable: - xss_info = xss.get_XSS_data(b"Test" % - xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E") - .replace(b"=", b"%3D"), - "https://example.com", - "End of URL") + xss_info = xss.get_XSS_data( + b"Test" + % xss.FULL_PAYLOAD.replace(b"<", b"%3C") + .replace(b">", b"%3E") + .replace(b"=", b"%3D"), + "https://example.com", + "End of URL", + ) assert xss_info is None # Ninth type of exploit: Test # Exploitable: - xss_info = xss.get_XSS_data(b"Test" % - xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E"), - "https://example.com", - "End of URL") - expected_xss_info = xss.XSSData("https://example.com", - "End of URL", - '" onmouseover="alert(0)" t="', - xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E").decode('utf-8')) + xss_info = xss.get_XSS_data( + b'Test' + % xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E"), + "https://example.com", + "End of URL", + ) + expected_xss_info = xss.XSSData( + "https://example.com", + "End of URL", + '" onmouseover="alert(0)" t="', + xss.FULL_PAYLOAD.replace(b"<", b"%3C") + .replace(b">", b"%3E") + .decode("utf-8"), + ) assert xss_info == expected_xss_info # Non-Exploitable: - xss_info = xss.get_XSS_data(b"Test" % - xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E") - .replace(b'"', b"%22"), - "https://example.com", - "End of URL") + xss_info = xss.get_XSS_data( + b'Test' + % xss.FULL_PAYLOAD.replace(b"<", b"%3C") + .replace(b">", b"%3E") + .replace(b'"', b"%22"), + "https://example.com", + "End of URL", + ) assert xss_info is None # Tenth type of exploit: Test # Exploitable: - xss_info = xss.get_XSS_data(b"Test" % - xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E"), - "https://example.com", - "End of URL") - expected_xss_info = xss.XSSData("https://example.com", - "End of URL", - "' onmouseover='alert(0)' t='", - xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E").decode('utf-8')) + xss_info = xss.get_XSS_data( + b"Test" + % xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E"), + "https://example.com", + "End of URL", + ) + expected_xss_info = xss.XSSData( + "https://example.com", + "End of URL", + "' onmouseover='alert(0)' t='", + xss.FULL_PAYLOAD.replace(b"<", b"%3C") + .replace(b">", b"%3E") + .decode("utf-8"), + ) assert xss_info == expected_xss_info # Non-Exploitable: - xss_info = xss.get_XSS_data(b"Test" % - xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E") - .replace(b"'", b"%22"), - "https://example.com", - "End of URL") + xss_info = xss.get_XSS_data( + b"Test" + % xss.FULL_PAYLOAD.replace(b"<", b"%3C") + .replace(b">", b"%3E") + .replace(b"'", b"%22"), + "https://example.com", + "End of URL", + ) assert xss_info is None # Eleventh type of exploit: Test # Exploitable: - xss_info = xss.get_XSS_data(b"Test" % - xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E"), - "https://example.com", - "End of URL") - expected_xss_info = xss.XSSData("https://example.com", - "End of URL", - " onmouseover=alert(0) t=", - xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E").decode('utf-8')) + xss_info = xss.get_XSS_data( + b"Test" + % xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E"), + "https://example.com", + "End of URL", + ) + expected_xss_info = xss.XSSData( + "https://example.com", + "End of URL", + " onmouseover=alert(0) t=", + xss.FULL_PAYLOAD.replace(b"<", b"%3C") + .replace(b">", b"%3E") + .decode("utf-8"), + ) assert xss_info == expected_xss_info # Non-Exploitable: - xss_info = xss.get_XSS_data(b"Test" % - xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E") - .replace(b"=", b"%3D"), - "https://example.com", - "End of URL") + xss_info = xss.get_XSS_data( + b"Test" + % xss.FULL_PAYLOAD.replace(b"<", b"%3C") + .replace(b">", b"%3E") + .replace(b"=", b"%3D"), + "https://example.com", + "End of URL", + ) assert xss_info is None def test_get_SQLi_data(self): - sqli_data = xss.get_SQLi_data("SQL syntax MySQL", - "", - "https://example.com", - "End of URL") - expected_sqli_data = xss.SQLiData("https://example.com", - "End of URL", - "SQL syntax.*MySQL", - "MySQL") + sqli_data = xss.get_SQLi_data( + "SQL syntax MySQL", + "", + "https://example.com", + "End of URL", + ) + expected_sqli_data = xss.SQLiData( + "https://example.com", "End of URL", "SQL syntax.*MySQL", "MySQL" + ) assert sqli_data == expected_sqli_data - sqli_data = xss.get_SQLi_data("SQL syntax MySQL", - "SQL syntax MySQL", - "https://example.com", - "End of URL") + sqli_data = xss.get_SQLi_data( + "SQL syntax MySQL", + "SQL syntax MySQL", + "https://example.com", + "End of URL", + ) assert sqli_data is None def test_inside_quote(self): @@ -233,9 +335,12 @@ class TestXSSScanner(): assert not xss.inside_quote("'", b"longStringNotInIt", 1, b"short") def test_paths_to_text(self): - text = xss.paths_to_text("""

STRING

+ text = xss.paths_to_text( + """

STRING

- """, "STRING") + """, + "STRING", + ) expected_text = ["/html/head/h1", "/html/script"] assert text == expected_text assert xss.paths_to_text("""""", "STRING") == [] @@ -244,114 +349,156 @@ class TestXSSScanner(): class MockResponse: def __init__(self, html, headers=None, cookies=None): self.text = html + return MockResponse("%s" % xss.FULL_PAYLOAD) def mocked_requests_invuln(*args, headers=None, cookies=None): class MockResponse: def __init__(self, html, headers=None, cookies=None): self.text = html + return MockResponse("") def test_test_end_of_url_injection(self, get_request_vuln): - xss_info = xss.test_end_of_URL_injection("", "https://example.com/index.html", {})[0] - expected_xss_info = xss.XSSData('https://example.com/index.html/1029zxcs\'d"aoso[sb]po(pc)se;sl/bsl\\eq=3847asd', - 'End of URL', - '', - '1029zxcs\\\'d"aoso[sb]po(pc)se;sl/bsl\\\\eq=3847asd') - sqli_info = xss.test_end_of_URL_injection("", "https://example.com/", {})[1] + xss_info = xss.test_end_of_URL_injection( + "", "https://example.com/index.html", {} + )[0] + expected_xss_info = xss.XSSData( + "https://example.com/index.html/1029zxcs'd\"aoso[sb]po(pc)se;sl/bsl\\eq=3847asd", + "End of URL", + "", + "1029zxcs\\'d\"aoso[sb]po(pc)se;sl/bsl\\\\eq=3847asd", + ) + sqli_info = xss.test_end_of_URL_injection( + "", "https://example.com/", {} + )[1] assert xss_info == expected_xss_info assert sqli_info is None def test_test_referer_injection(self, get_request_vuln): - xss_info = xss.test_referer_injection("", "https://example.com/", {})[0] - expected_xss_info = xss.XSSData('https://example.com/', - 'Referer', - '', - '1029zxcs\\\'d"aoso[sb]po(pc)se;sl/bsl\\\\eq=3847asd') - sqli_info = xss.test_referer_injection("", "https://example.com/", {})[1] + xss_info = xss.test_referer_injection( + "", "https://example.com/", {} + )[0] + expected_xss_info = xss.XSSData( + "https://example.com/", + "Referer", + "", + "1029zxcs\\'d\"aoso[sb]po(pc)se;sl/bsl\\\\eq=3847asd", + ) + sqli_info = xss.test_referer_injection( + "", "https://example.com/", {} + )[1] assert xss_info == expected_xss_info assert sqli_info is None def test_test_user_agent_injection(self, get_request_vuln): - xss_info = xss.test_user_agent_injection("", "https://example.com/", {})[0] - expected_xss_info = xss.XSSData('https://example.com/', - 'User Agent', - '', - '1029zxcs\\\'d"aoso[sb]po(pc)se;sl/bsl\\\\eq=3847asd') - sqli_info = xss.test_user_agent_injection("", "https://example.com/", {})[1] + xss_info = xss.test_user_agent_injection( + "", "https://example.com/", {} + )[0] + expected_xss_info = xss.XSSData( + "https://example.com/", + "User Agent", + "", + "1029zxcs\\'d\"aoso[sb]po(pc)se;sl/bsl\\\\eq=3847asd", + ) + sqli_info = xss.test_user_agent_injection( + "", "https://example.com/", {} + )[1] assert xss_info == expected_xss_info assert sqli_info is None def test_test_query_injection(self, get_request_vuln): - xss_info = xss.test_query_injection("", "https://example.com/vuln.php?cmd=ls", {})[0] - expected_xss_info = xss.XSSData('https://example.com/vuln.php?cmd=1029zxcs\'d"aoso[sb]po(pc)se;sl/bsl\\eq=3847asd', - 'Query', - '', - '1029zxcs\\\'d"aoso[sb]po(pc)se;sl/bsl\\\\eq=3847asd') - sqli_info = xss.test_query_injection("", "https://example.com/vuln.php?cmd=ls", {})[1] + xss_info = xss.test_query_injection( + "", "https://example.com/vuln.php?cmd=ls", {} + )[0] + expected_xss_info = xss.XSSData( + "https://example.com/vuln.php?cmd=1029zxcs'd\"aoso[sb]po(pc)se;sl/bsl\\eq=3847asd", + "Query", + "", + "1029zxcs\\'d\"aoso[sb]po(pc)se;sl/bsl\\\\eq=3847asd", + ) + sqli_info = xss.test_query_injection( + "", "https://example.com/vuln.php?cmd=ls", {} + )[1] assert xss_info == expected_xss_info assert sqli_info is None - @pytest.fixture(scope='function') + @pytest.fixture(scope="function") def get_request_vuln(self, monkeypatch): - monkeypatch.setattr(requests, 'get', self.mocked_requests_vuln) + monkeypatch.setattr(requests, "get", self.mocked_requests_vuln) - @pytest.fixture(scope='function') + @pytest.fixture(scope="function") def get_request_invuln(self, monkeypatch): - monkeypatch.setattr(requests, 'get', self.mocked_requests_invuln) + monkeypatch.setattr(requests, "get", self.mocked_requests_invuln) - @pytest.fixture(scope='function') + @pytest.fixture(scope="function") def mock_gethostbyname(self, monkeypatch): def gethostbyname(domain): claimed_domains = ["google.com"] if domain not in claimed_domains: from socket import gaierror + raise gaierror("[Errno -2] Name or service not known") else: - return '216.58.221.46' + return "216.58.221.46" monkeypatch.setattr("socket.gethostbyname", gethostbyname) def test_find_unclaimed_URLs(self, logger, mock_gethostbyname): - xss.find_unclaimed_URLs("", - "https://example.com") + xss.find_unclaimed_URLs( + '', + "https://example.com", + ) assert logger.args == [] - xss.find_unclaimed_URLs("", - "https://example.com") - assert logger.args[0] == 'XSS found in https://example.com due to unclaimed URL "http://unclaimedDomainName.com".' - xss.find_unclaimed_URLs("", - "https://example.com") - assert logger.args[1] == 'XSS found in https://example.com due to unclaimed URL "http://unclaimedDomainName.com".' - xss.find_unclaimed_URLs("", - "https://example.com") - assert logger.args[2] == 'XSS found in https://example.com due to unclaimed URL "http://unclaimedDomainName.com".' + xss.find_unclaimed_URLs( + '', + "https://example.com", + ) + assert ( + logger.args[0] + == 'XSS found in https://example.com due to unclaimed URL "http://unclaimedDomainName.com".' + ) + xss.find_unclaimed_URLs( + '', + "https://example.com", + ) + assert ( + logger.args[1] + == 'XSS found in https://example.com due to unclaimed URL "http://unclaimedDomainName.com".' + ) + xss.find_unclaimed_URLs( + '', + "https://example.com", + ) + assert ( + logger.args[2] + == 'XSS found in https://example.com due to unclaimed URL "http://unclaimedDomainName.com".' + ) def test_log_XSS_data(self, logger): xss.log_XSS_data(None) assert logger.args == [] # self, url: str, injection_point: str, exploit: str, line: str - xss.log_XSS_data(xss.XSSData('https://example.com', - 'Location', - 'String', - 'Line of HTML')) - assert logger.args[0] == '===== XSS Found ====' - assert logger.args[1] == 'XSS URL: https://example.com' - assert logger.args[2] == 'Injection Point: Location' - assert logger.args[3] == 'Suggested Exploit: String' - assert logger.args[4] == 'Line: Line of HTML' + xss.log_XSS_data( + xss.XSSData("https://example.com", "Location", "String", "Line of HTML") + ) + assert logger.args[0] == "===== XSS Found ====" + assert logger.args[1] == "XSS URL: https://example.com" + assert logger.args[2] == "Injection Point: Location" + assert logger.args[3] == "Suggested Exploit: String" + assert logger.args[4] == "Line: Line of HTML" def test_log_SQLi_data(self, logger): xss.log_SQLi_data(None) assert logger.args == [] - xss.log_SQLi_data(xss.SQLiData('https://example.com', - 'Location', - 'Oracle.*Driver', - 'Oracle')) - assert logger.args[0] == '===== SQLi Found =====' - assert logger.args[1] == 'SQLi URL: https://example.com' - assert logger.args[2] == 'Injection Point: Location' - assert logger.args[3] == 'Regex used: Oracle.*Driver' + xss.log_SQLi_data( + xss.SQLiData("https://example.com", "Location", "Oracle.*Driver", "Oracle") + ) + assert logger.args[0] == "===== SQLi Found =====" + assert logger.args[1] == "SQLi URL: https://example.com" + assert logger.args[2] == "Injection Point: Location" + assert logger.args[3] == "Regex used: Oracle.*Driver" def test_get_cookies(self): mocked_req = tutils.treq() @@ -363,7 +510,7 @@ class TestXSSScanner(): def test_response(self, get_request_invuln, logger): mocked_flow = tflow.tflow( req=tutils.treq(path=b"index.html?q=1"), - resp=tutils.tresp(content=b'') + resp=tutils.tresp(content=b""), ) xss.response(mocked_flow) assert logger.args == [] diff --git a/examples/contrib/tls_passthrough.py b/examples/contrib/tls_passthrough.py index 16d90ddad..ab50d4191 100644 --- a/examples/contrib/tls_passthrough.py +++ b/examples/contrib/tls_passthrough.py @@ -17,10 +17,13 @@ Example: import collections import logging import random -from abc import ABC, abstractmethod +from abc import ABC +from abc import abstractmethod from enum import Enum -from mitmproxy import connection, ctx, tls +from mitmproxy import connection +from mitmproxy import ctx +from mitmproxy import tls from mitmproxy.utils import human @@ -54,6 +57,7 @@ class ConservativeStrategy(TlsStrategy): Conservative Interception Strategy - only intercept if there haven't been any failed attempts in the history. """ + def should_intercept(self, server_address: connection.Address) -> bool: return InterceptionResult.FAILURE not in self.history[server_address] @@ -62,6 +66,7 @@ class ProbabilisticStrategy(TlsStrategy): """ Fixed probability that we intercept a given connection. """ + def __init__(self, p: float): self.p = p super().__init__() @@ -75,7 +80,9 @@ class MaybeTls: def load(self, l): l.add_option( - "tls_strategy", int, 0, + "tls_strategy", + int, + 0, "TLS passthrough strategy. If set to 0, connections will be passed through after the first unsuccessful " "handshake. If set to 0 < p <= 100, connections with be passed through with probability p.", ) @@ -97,7 +104,9 @@ class MaybeTls: def tls_established_client(self, data: tls.TlsData): server_address = data.context.server.peername - logging.info(f"TLS handshake successful: {human.format_address(server_address)}") + logging.info( + f"TLS handshake successful: {human.format_address(server_address)}" + ) self.strategy.record_success(server_address) def tls_failed_client(self, data: tls.TlsData): diff --git a/examples/contrib/webscanner_helper/mapping.py b/examples/contrib/webscanner_helper/mapping.py index 333809f53..52509730d 100644 --- a/examples/contrib/webscanner_helper/mapping.py +++ b/examples/contrib/webscanner_helper/mapping.py @@ -3,8 +3,8 @@ import logging from bs4 import BeautifulSoup -from mitmproxy.http import HTTPFlow from examples.contrib.webscanner_helper.urldict import URLDict +from mitmproxy.http import HTTPFlow NO_CONTENT = object() @@ -14,7 +14,7 @@ class MappingAddonConfig: class MappingAddon: - """ The mapping add-on can be used in combination with web application scanners to reduce their false positives. + """The mapping add-on can be used in combination with web application scanners to reduce their false positives. Many web application scanners produce false positives caused by dynamically changing content of web applications such as the current time or current measurements. When testing for injection vulnerabilities, web application @@ -45,7 +45,7 @@ class MappingAddon: """Whether to store all new content in the configuration file.""" def __init__(self, filename: str, persistent: bool = False) -> None: - """ Initializes the mapping add-on + """Initializes the mapping add-on Args: filename: str that provides the name of the file in which the urls and css selectors to mapped content is @@ -71,12 +71,16 @@ class MappingAddon: def load(self, loader): loader.add_option( - self.OPT_MAPPING_FILE, str, "", - "File where replacement configuration is stored." + self.OPT_MAPPING_FILE, + str, + "", + "File where replacement configuration is stored.", ) loader.add_option( - self.OPT_MAP_PERSISTENT, bool, False, - "Whether to store all new content in the configuration file." + self.OPT_MAP_PERSISTENT, + bool, + False, + "Whether to store all new content in the configuration file.", ) def configure(self, updated): @@ -88,23 +92,33 @@ class MappingAddon: if self.OPT_MAP_PERSISTENT in updated: self.persistent = updated[self.OPT_MAP_PERSISTENT] - def replace(self, soup: BeautifulSoup, css_sel: str, replace: BeautifulSoup) -> None: + def replace( + self, soup: BeautifulSoup, css_sel: str, replace: BeautifulSoup + ) -> None: """Replaces the content of soup that matches the css selector with the given replace content.""" for content in soup.select(css_sel): - self.logger.debug(f"replace \"{content}\" with \"{replace}\"") + self.logger.debug(f'replace "{content}" with "{replace}"') content.replace_with(copy.copy(replace)) - def apply_template(self, soup: BeautifulSoup, template: dict[str, BeautifulSoup]) -> None: + def apply_template( + self, soup: BeautifulSoup, template: dict[str, BeautifulSoup] + ) -> None: """Applies the given mapping template to the given soup.""" for css_sel, replace in template.items(): mapped = soup.select(css_sel) if not mapped: - self.logger.warning(f"Could not find \"{css_sel}\", can not freeze anything.") + self.logger.warning( + f'Could not find "{css_sel}", can not freeze anything.' + ) else: - self.replace(soup, css_sel, BeautifulSoup(replace, features=MappingAddonConfig.HTML_PARSER)) + self.replace( + soup, + css_sel, + BeautifulSoup(replace, features=MappingAddonConfig.HTML_PARSER), + ) def response(self, flow: HTTPFlow) -> None: - """If a response is received, check if we should replace some content. """ + """If a response is received, check if we should replace some content.""" try: templates = self.mapping_templates[flow] res = flow.response @@ -118,7 +132,9 @@ class MappingAddon: self.apply_template(content, template) res.content = content.encode(encoding) else: - self.logger.warning(f"Unsupported content type '{content_type}' or content encoding '{encoding}'") + self.logger.warning( + f"Unsupported content type '{content_type}' or content encoding '{encoding}'" + ) except KeyError: pass diff --git a/examples/contrib/webscanner_helper/proxyauth_selenium.py b/examples/contrib/webscanner_helper/proxyauth_selenium.py index 6ac1d94de..579fcc3d2 100644 --- a/examples/contrib/webscanner_helper/proxyauth_selenium.py +++ b/examples/contrib/webscanner_helper/proxyauth_selenium.py @@ -3,13 +3,15 @@ import logging import random import string import time -from typing import Any, cast +from typing import Any +from typing import cast + +from selenium import webdriver import mitmproxy.http from mitmproxy import flowfilter from mitmproxy import master from mitmproxy.script import concurrent -from selenium import webdriver logger = logging.getLogger(__name__) @@ -18,14 +20,14 @@ cookie_key_name = { "expires": "Expires", "domain": "Domain", "is_http_only": "HttpOnly", - "is_secure": "Secure" + "is_secure": "Secure", } def randomString(string_length=10): - """Generate a random string of fixed length """ + """Generate a random string of fixed length""" letters = string.ascii_lowercase - return ''.join(random.choice(letters) for i in range(string_length)) + return "".join(random.choice(letters) for i in range(string_length)) class AuthorizationOracle(abc.ABC): @@ -41,7 +43,7 @@ class AuthorizationOracle(abc.ABC): class SeleniumAddon: - """ This Addon can be used in combination with web application scanners in order to help them to authenticate + """This Addon can be used in combination with web application scanners in order to help them to authenticate against a web application. Since the authentication is highly dependant on the web application, this add-on includes the abstract method @@ -50,8 +52,7 @@ class SeleniumAddon: application. In addition, an authentication oracle which inherits from AuthorizationOracle should be created. """ - def __init__(self, fltr: str, domain: str, - auth_oracle: AuthorizationOracle): + def __init__(self, fltr: str, domain: str, auth_oracle: AuthorizationOracle): self.filter = flowfilter.parse(fltr) self.auth_oracle = auth_oracle self.domain = domain @@ -62,9 +63,8 @@ class SeleniumAddon: options.headless = True profile = webdriver.FirefoxProfile() - profile.set_preference('network.proxy.type', 0) - self.browser = webdriver.Firefox(firefox_profile=profile, - options=options) + profile.set_preference("network.proxy.type", 0) + self.browser = webdriver.Firefox(firefox_profile=profile, options=options) self.cookies: list[dict[str, str]] = [] def _login(self, flow): @@ -76,7 +76,9 @@ class SeleniumAddon: def request(self, flow: mitmproxy.http.HTTPFlow): if flow.request.is_replay: logger.warning("Caught replayed request: " + str(flow)) - if (not self.filter or self.filter(flow)) and self.auth_oracle.is_unauthorized_request(flow): + if ( + not self.filter or self.filter(flow) + ) and self.auth_oracle.is_unauthorized_request(flow): logger.debug("unauthorized request detected, perform login") self._login(flow) @@ -88,7 +90,7 @@ class SeleniumAddon: if self.auth_oracle.is_unauthorized_response(flow): self._login(flow) new_flow = flow.copy() - if master and hasattr(master, 'commands'): + if master and hasattr(master, "commands"): # cast necessary for mypy cast(Any, master).commands.call("replay.client", [new_flow]) count = 0 @@ -99,7 +101,9 @@ class SeleniumAddon: if new_flow.response: flow.response = new_flow.response else: - logger.warning("Could not call 'replay.client' command since master was not initialized yet.") + logger.warning( + "Could not call 'replay.client' command since master was not initialized yet." + ) if self.set_cookies and flow.response: logger.debug("set set-cookie header for response") @@ -124,7 +128,8 @@ class SeleniumAddon: def _set_request_cookies(self, flow: mitmproxy.http.HTTPFlow): if self.cookies: cookies = "; ".join( - map(lambda c: f"{c['name']}={c['value']}", self.cookies)) + map(lambda c: f"{c['name']}={c['value']}", self.cookies) + ) flow.request.headers["cookie"] = cookies @abc.abstractmethod diff --git a/examples/contrib/webscanner_helper/test_mapping.py b/examples/contrib/webscanner_helper/test_mapping.py index 340522837..c88b11983 100644 --- a/examples/contrib/webscanner_helper/test_mapping.py +++ b/examples/contrib/webscanner_helper/test_mapping.py @@ -1,15 +1,15 @@ -from typing import TextIO, Callable +from typing import Callable +from typing import TextIO from unittest import mock from unittest.mock import MagicMock +from examples.contrib.webscanner_helper.mapping import MappingAddon +from examples.contrib.webscanner_helper.mapping import MappingAddonConfig from mitmproxy.test import tflow from mitmproxy.test import tutils -from examples.contrib.webscanner_helper.mapping import MappingAddon, MappingAddonConfig - class TestConfig: - def test_config(self): assert MappingAddonConfig.HTML_PARSER == "html.parser" @@ -20,7 +20,6 @@ mapping_content = f'{{"{url}": {{"body": "{new_content}"}}}}' class TestMappingAddon: - def test_init(self, tmpdir): tmpfile = tmpdir.join("tmpfile") with open(tmpfile, "w") as tfile: @@ -36,8 +35,8 @@ class TestMappingAddon: loader = MagicMock() mapping.load(loader) - assert 'mapping_file' in str(loader.add_option.call_args_list) - assert 'map_persistent' in str(loader.add_option.call_args_list) + assert "mapping_file" in str(loader.add_option.call_args_list) + assert "map_persistent" in str(loader.add_option.call_args_list) def test_configure(self, tmpdir): tmpfile = tmpdir.join("tmpfile") @@ -45,7 +44,10 @@ class TestMappingAddon: tfile.write(mapping_content) mapping = MappingAddon(tmpfile) new_filename = "My new filename" - updated = {str(mapping.OPT_MAPPING_FILE): new_filename, str(mapping.OPT_MAP_PERSISTENT): True} + updated = { + str(mapping.OPT_MAPPING_FILE): new_filename, + str(mapping.OPT_MAP_PERSISTENT): True, + } open_mock = mock.mock_open(read_data="{}") with mock.patch("builtins.open", open_mock): @@ -161,5 +163,8 @@ class TestMappingAddon: with open(tmpfile, "w") as tfile: tfile.write("{}") mapping = MappingAddon(tmpfile, persistent=True) - with mock.patch('examples.complex.webscanner_helper.urldict.URLDict.dump', selfself.mock_dump): + with mock.patch( + "examples.complex.webscanner_helper.urldict.URLDict.dump", + selfself.mock_dump, + ): mapping.done() diff --git a/examples/contrib/webscanner_helper/test_proxyauth_selenium.py b/examples/contrib/webscanner_helper/test_proxyauth_selenium.py index 58e035068..e755c776c 100644 --- a/examples/contrib/webscanner_helper/test_proxyauth_selenium.py +++ b/examples/contrib/webscanner_helper/test_proxyauth_selenium.py @@ -3,16 +3,16 @@ from unittest.mock import MagicMock import pytest +from examples.contrib.webscanner_helper.proxyauth_selenium import AuthorizationOracle +from examples.contrib.webscanner_helper.proxyauth_selenium import logger +from examples.contrib.webscanner_helper.proxyauth_selenium import randomString +from examples.contrib.webscanner_helper.proxyauth_selenium import SeleniumAddon +from mitmproxy.http import HTTPFlow from mitmproxy.test import tflow from mitmproxy.test import tutils -from mitmproxy.http import HTTPFlow - -from examples.contrib.webscanner_helper.proxyauth_selenium import logger, randomString, AuthorizationOracle, \ - SeleniumAddon class TestRandomString: - def test_random_string(self): res = randomString() assert isinstance(res, str) @@ -36,8 +36,11 @@ oracle = AuthenticationOracleTest() @pytest.fixture(scope="module", autouse=True) def selenium_addon(request): - addon = SeleniumAddon(fltr=r"~u http://example\.com/login\.php", domain=r"~d http://example\.com", - auth_oracle=oracle) + addon = SeleniumAddon( + fltr=r"~u http://example\.com/login\.php", + domain=r"~d http://example\.com", + auth_oracle=oracle, + ) browser = MagicMock() addon.browser = browser yield addon @@ -49,11 +52,10 @@ def selenium_addon(request): class TestSeleniumAddon: - def test_request_replay(self, selenium_addon): f = tflow.tflow(resp=tutils.tresp()) f.request.is_replay = True - with mock.patch.object(logger, 'warning') as mock_warning: + with mock.patch.object(logger, "warning") as mock_warning: selenium_addon.request(f) mock_warning.assert_called() @@ -62,7 +64,7 @@ class TestSeleniumAddon: f.request.url = "http://example.com/login.php" selenium_addon.set_cookies = False assert not selenium_addon.set_cookies - with mock.patch.object(logger, 'debug') as mock_debug: + with mock.patch.object(logger, "debug") as mock_debug: selenium_addon.request(f) mock_debug.assert_called() assert selenium_addon.set_cookies @@ -79,9 +81,11 @@ class TestSeleniumAddon: f.request.url = "http://example.com/login.php" selenium_addon.set_cookies = False assert not selenium_addon.set_cookies - with mock.patch.object(logger, 'debug') as mock_debug: - with mock.patch('examples.complex.webscanner_helper.proxyauth_selenium.SeleniumAddon.login', - return_value=[{"name": "cookie", "value": "test"}]) as mock_login: + with mock.patch.object(logger, "debug") as mock_debug: + with mock.patch( + "examples.complex.webscanner_helper.proxyauth_selenium.SeleniumAddon.login", + return_value=[{"name": "cookie", "value": "test"}], + ) as mock_login: selenium_addon.request(f) mock_debug.assert_called() assert selenium_addon.set_cookies @@ -95,7 +99,7 @@ class TestSeleniumAddon: selenium_addon.set_cookies = False assert not selenium_addon.set_cookies - with mock.patch.object(logger, 'debug') as mock_debug: + with mock.patch.object(logger, "debug") as mock_debug: selenium_addon.request(f) mock_debug.assert_called() selenium_addon.filter = fltr @@ -105,8 +109,10 @@ class TestSeleniumAddon: f = tflow.tflow(resp=tutils.tresp()) f.request.url = "http://example.com/login.php" selenium_addon.set_cookies = False - with mock.patch('examples.complex.webscanner_helper.proxyauth_selenium.SeleniumAddon.login', - return_value=[]) as mock_login: + with mock.patch( + "examples.complex.webscanner_helper.proxyauth_selenium.SeleniumAddon.login", + return_value=[], + ) as mock_login: selenium_addon.response(f) mock_login.assert_called() @@ -114,7 +120,9 @@ class TestSeleniumAddon: f = tflow.tflow(resp=tutils.tresp()) f.request.url = "http://example.com/login.php" selenium_addon.set_cookies = False - with mock.patch('examples.complex.webscanner_helper.proxyauth_selenium.SeleniumAddon.login', - return_value=[{"name": "cookie", "value": "test"}]) as mock_login: + with mock.patch( + "examples.complex.webscanner_helper.proxyauth_selenium.SeleniumAddon.login", + return_value=[{"name": "cookie", "value": "test"}], + ) as mock_login: selenium_addon.response(f) mock_login.assert_called() diff --git a/examples/contrib/webscanner_helper/test_urldict.py b/examples/contrib/webscanner_helper/test_urldict.py index 102c9ee35..066566237 100644 --- a/examples/contrib/webscanner_helper/test_urldict.py +++ b/examples/contrib/webscanner_helper/test_urldict.py @@ -1,5 +1,6 @@ -from mitmproxy.test import tflow, tutils from examples.contrib.webscanner_helper.urldict import URLDict +from mitmproxy.test import tflow +from mitmproxy.test import tutils url = "http://10.10.10.10" new_content_body = "New Body" @@ -11,11 +12,10 @@ input_file_content_error = f'{{"{url_error}": {content}}}' class TestUrlDict: - def test_urldict_empty(self): urldict = URLDict() dump = urldict.dumps() - assert dump == '{}' + assert dump == "{}" def test_urldict_loads(self): urldict = URLDict.loads(input_file_content) diff --git a/examples/contrib/webscanner_helper/test_urlindex.py b/examples/contrib/webscanner_helper/test_urlindex.py index 058a36068..d3dd5f480 100644 --- a/examples/contrib/webscanner_helper/test_urlindex.py +++ b/examples/contrib/webscanner_helper/test_urlindex.py @@ -4,17 +4,18 @@ from pathlib import Path from unittest import mock from unittest.mock import patch +from examples.contrib.webscanner_helper.urlindex import filter_404 +from examples.contrib.webscanner_helper.urlindex import JSONUrlIndexWriter +from examples.contrib.webscanner_helper.urlindex import SetEncoder +from examples.contrib.webscanner_helper.urlindex import TextUrlIndexWriter +from examples.contrib.webscanner_helper.urlindex import UrlIndexAddon +from examples.contrib.webscanner_helper.urlindex import UrlIndexWriter +from examples.contrib.webscanner_helper.urlindex import WRITER from mitmproxy.test import tflow from mitmproxy.test import tutils -from examples.contrib.webscanner_helper.urlindex import UrlIndexWriter, SetEncoder, JSONUrlIndexWriter, \ - TextUrlIndexWriter, WRITER, \ - filter_404, \ - UrlIndexAddon - class TestBaseClass: - @patch.multiple(UrlIndexWriter, __abstractmethods__=set()) def test_base_class(self, tmpdir): tmpfile = tmpdir.join("tmpfile") @@ -25,14 +26,13 @@ class TestBaseClass: class TestSetEncoder: - def test_set_encoder_set(self): test_set = {"foo", "bar", "42"} result = SetEncoder.default(SetEncoder(), test_set) assert isinstance(result, list) - assert 'foo' in result - assert 'bar' in result - assert '42' in result + assert "foo" in result + assert "bar" in result + assert "42" in result def test_set_encoder_str(self): test_str = "test" @@ -45,18 +45,18 @@ class TestSetEncoder: class TestJSONUrlIndexWriter: - def test_load(self, tmpdir): tmpfile = tmpdir.join("tmpfile") with open(tmpfile, "w") as tfile: tfile.write( - "{\"http://example.com:80\": {\"/\": {\"GET\": [301]}}, \"http://www.example.com:80\": {\"/\": {\"GET\": [302]}}}") + '{"http://example.com:80": {"/": {"GET": [301]}}, "http://www.example.com:80": {"/": {"GET": [302]}}}' + ) writer = JSONUrlIndexWriter(filename=tmpfile) writer.load() - assert 'http://example.com:80' in writer.host_urls - assert '/' in writer.host_urls['http://example.com:80'] - assert 'GET' in writer.host_urls['http://example.com:80']['/'] - assert 301 in writer.host_urls['http://example.com:80']['/']['GET'] + assert "http://example.com:80" in writer.host_urls + assert "/" in writer.host_urls["http://example.com:80"] + assert "GET" in writer.host_urls["http://example.com:80"]["/"] + assert 301 in writer.host_urls["http://example.com:80"]["/"]["GET"] def test_load_empty(self, tmpdir): tmpfile = tmpdir.join("tmpfile") @@ -102,7 +102,8 @@ class TestTestUrlIndexWriter: tmpfile = tmpdir.join("tmpfile") with open(tmpfile, "w") as tfile: tfile.write( - "2020-04-22T05:41:08.679231 STATUS: 200 METHOD: GET URL:http://example.com") + "2020-04-22T05:41:08.679231 STATUS: 200 METHOD: GET URL:http://example.com" + ) writer = TextUrlIndexWriter(filename=tmpfile) writer.load() assert True @@ -173,7 +174,6 @@ class TestFilter: class TestUrlIndexAddon: - def test_init(self, tmpdir): tmpfile = tmpdir.join("tmpfile") UrlIndexAddon(tmpfile) @@ -202,7 +202,9 @@ class TestUrlIndexAddon: tfile.write("") url_index = UrlIndexAddon(tmpfile, append=False) f = tflow.tflow(resp=tutils.tresp()) - with mock.patch('examples.complex.webscanner_helper.urlindex.JSONUrlIndexWriter.add_url'): + with mock.patch( + "examples.complex.webscanner_helper.urlindex.JSONUrlIndexWriter.add_url" + ): url_index.response(f) assert not Path(tmpfile).exists() @@ -210,7 +212,9 @@ class TestUrlIndexAddon: tmpfile = tmpdir.join("tmpfile") url_index = UrlIndexAddon(tmpfile) f = tflow.tflow(resp=tutils.tresp()) - with mock.patch('examples.complex.webscanner_helper.urlindex.JSONUrlIndexWriter.add_url') as mock_add_url: + with mock.patch( + "examples.complex.webscanner_helper.urlindex.JSONUrlIndexWriter.add_url" + ) as mock_add_url: url_index.response(f) mock_add_url.assert_called() @@ -229,6 +233,8 @@ class TestUrlIndexAddon: def test_done(self, tmpdir): tmpfile = tmpdir.join("tmpfile") url_index = UrlIndexAddon(tmpfile) - with mock.patch('examples.complex.webscanner_helper.urlindex.JSONUrlIndexWriter.save') as mock_save: + with mock.patch( + "examples.complex.webscanner_helper.urlindex.JSONUrlIndexWriter.save" + ) as mock_save: url_index.done() mock_save.assert_called() diff --git a/examples/contrib/webscanner_helper/test_urlinjection.py b/examples/contrib/webscanner_helper/test_urlinjection.py index b1c412d21..b6a841721 100644 --- a/examples/contrib/webscanner_helper/test_urlinjection.py +++ b/examples/contrib/webscanner_helper/test_urlinjection.py @@ -1,20 +1,22 @@ import json from unittest import mock +from examples.contrib.webscanner_helper.urlinjection import HTMLInjection +from examples.contrib.webscanner_helper.urlinjection import InjectionGenerator +from examples.contrib.webscanner_helper.urlinjection import logger +from examples.contrib.webscanner_helper.urlinjection import RobotsInjection +from examples.contrib.webscanner_helper.urlinjection import SitemapInjection +from examples.contrib.webscanner_helper.urlinjection import UrlInjectionAddon from mitmproxy import flowfilter from mitmproxy.test import tflow from mitmproxy.test import tutils -from examples.contrib.webscanner_helper.urlinjection import InjectionGenerator, HTMLInjection, RobotsInjection, \ - SitemapInjection, \ - UrlInjectionAddon, logger - index = json.loads( - "{\"http://example.com:80\": {\"/\": {\"GET\": [301]}}, \"http://www.example.com:80\": {\"/test\": {\"POST\": [302]}}}") + '{"http://example.com:80": {"/": {"GET": [301]}}, "http://www.example.com:80": {"/test": {"POST": [302]}}}' +) class TestInjectionGenerator: - def test_inject(self): f = tflow.tflow(resp=tutils.tresp()) injection_generator = InjectionGenerator() @@ -23,12 +25,11 @@ class TestInjectionGenerator: class TestHTMLInjection: - def test_inject_not404(self): html_injection = HTMLInjection() f = tflow.tflow(resp=tutils.tresp()) - with mock.patch.object(logger, 'warning') as mock_warning: + with mock.patch.object(logger, "warning") as mock_warning: html_injection.inject(index, f) assert mock_warning.called @@ -57,12 +58,11 @@ class TestHTMLInjection: class TestRobotsInjection: - def test_inject_not404(self): robots_injection = RobotsInjection() f = tflow.tflow(resp=tutils.tresp()) - with mock.patch.object(logger, 'warning') as mock_warning: + with mock.patch.object(logger, "warning") as mock_warning: robots_injection.inject(index, f) assert mock_warning.called @@ -76,12 +76,11 @@ class TestRobotsInjection: class TestSitemapInjection: - def test_inject_not404(self): sitemap_injection = SitemapInjection() f = tflow.tflow(resp=tutils.tresp()) - with mock.patch.object(logger, 'warning') as mock_warning: + with mock.patch.object(logger, "warning") as mock_warning: sitemap_injection.inject(index, f) assert mock_warning.called @@ -89,19 +88,22 @@ class TestSitemapInjection: sitemap_injection = SitemapInjection() f = tflow.tflow(resp=tutils.tresp()) f.response.status_code = 404 - assert "http://example.com:80/" not in str(f.response.content) + assert "http://example.com:80/" not in str( + f.response.content + ) sitemap_injection.inject(index, f) assert "http://example.com:80/" in str(f.response.content) class TestUrlInjectionAddon: - def test_init(self, tmpdir): tmpfile = tmpdir.join("tmpfile") with open(tmpfile, "w") as tfile: json.dump(index, tfile) flt = f"~u .*/site.html$" - url_injection = UrlInjectionAddon(f"~u .*/site.html$", tmpfile, HTMLInjection(insert=True)) + url_injection = UrlInjectionAddon( + f"~u .*/site.html$", tmpfile, HTMLInjection(insert=True) + ) assert "http://example.com:80" in url_injection.url_store fltr = flowfilter.parse(flt) f = tflow.tflow(resp=tutils.tresp()) diff --git a/examples/contrib/webscanner_helper/test_watchdog.py b/examples/contrib/webscanner_helper/test_watchdog.py index f6a34a61b..d5382072b 100644 --- a/examples/contrib/webscanner_helper/test_watchdog.py +++ b/examples/contrib/webscanner_helper/test_watchdog.py @@ -1,18 +1,17 @@ +import multiprocessing import time from pathlib import Path from unittest import mock +from examples.contrib.webscanner_helper.watchdog import logger +from examples.contrib.webscanner_helper.watchdog import WatchdogAddon from mitmproxy.connections import ServerConnection from mitmproxy.exceptions import HttpSyntaxException from mitmproxy.test import tflow from mitmproxy.test import tutils -import multiprocessing - -from examples.contrib.webscanner_helper.watchdog import WatchdogAddon, logger class TestWatchdog: - def test_init_file(self, tmpdir): tmpfile = tmpdir.join("tmpfile") with open(tmpfile, "w") as tfile: @@ -35,14 +34,18 @@ class TestWatchdog: def test_serverconnect(self, tmpdir): event = multiprocessing.Event() w = WatchdogAddon(event, Path(tmpdir), timeout=10) - with mock.patch('mitmproxy.connections.ServerConnection.settimeout') as mock_set_timeout: + with mock.patch( + "mitmproxy.connections.ServerConnection.settimeout" + ) as mock_set_timeout: w.serverconnect(ServerConnection("127.0.0.1")) mock_set_timeout.assert_called() def test_serverconnect_None(self, tmpdir): event = multiprocessing.Event() w = WatchdogAddon(event, Path(tmpdir)) - with mock.patch('mitmproxy.connections.ServerConnection.settimeout') as mock_set_timeout: + with mock.patch( + "mitmproxy.connections.ServerConnection.settimeout" + ) as mock_set_timeout: w.serverconnect(ServerConnection("127.0.0.1")) assert not mock_set_timeout.called @@ -52,7 +55,7 @@ class TestWatchdog: f = tflow.tflow(resp=tutils.tresp()) f.error = "Test Error" - with mock.patch.object(logger, 'error') as mock_error: + with mock.patch.object(logger, "error") as mock_error: open_mock = mock.mock_open() with mock.patch("pathlib.Path.open", open_mock, create=True): w.error(f) @@ -66,7 +69,7 @@ class TestWatchdog: f.error = HttpSyntaxException() assert isinstance(f.error, HttpSyntaxException) - with mock.patch.object(logger, 'error') as mock_error: + with mock.patch.object(logger, "error") as mock_error: open_mock = mock.mock_open() with mock.patch("pathlib.Path.open", open_mock, create=True): w.error(f) @@ -79,6 +82,6 @@ class TestWatchdog: assert w.not_in_timeout(None, None) assert w.not_in_timeout(time.time, None) - with mock.patch('time.time', return_value=5): + with mock.patch("time.time", return_value=5): assert not w.not_in_timeout(3, 20) assert w.not_in_timeout(3, 1) diff --git a/examples/contrib/webscanner_helper/urldict.py b/examples/contrib/webscanner_helper/urldict.py index 7e990f1af..a5b02af21 100644 --- a/examples/contrib/webscanner_helper/urldict.py +++ b/examples/contrib/webscanner_helper/urldict.py @@ -1,7 +1,12 @@ import itertools import json +from collections.abc import Generator from collections.abc import MutableMapping -from typing import Any, Callable, Generator, TextIO, Union, cast +from typing import Any +from typing import Callable +from typing import cast +from typing import TextIO +from typing import Union from mitmproxy import flowfilter from mitmproxy.http import HTTPFlow @@ -76,7 +81,7 @@ class URLDict(MutableMapping): def _dump(self, value_dumper: Callable = f_id) -> dict: dumped: dict[Union[flowfilter.TFilter, str], Any] = {} for fltr, value in self.store.items(): - if hasattr(fltr, 'pattern'): + if hasattr(fltr, "pattern"): # cast necessary for mypy dumped[cast(Any, fltr).pattern] = value_dumper(value) else: diff --git a/examples/contrib/webscanner_helper/urlindex.py b/examples/contrib/webscanner_helper/urlindex.py index 650e47c01..09f9ef2e8 100644 --- a/examples/contrib/webscanner_helper/urlindex.py +++ b/examples/contrib/webscanner_helper/urlindex.py @@ -3,7 +3,8 @@ import datetime import json import logging from pathlib import Path -from typing import Optional, Union +from typing import Optional +from typing import Union from mitmproxy import flowfilter from mitmproxy.http import HTTPFlow @@ -67,7 +68,9 @@ class JSONUrlIndexWriter(UrlIndexWriter): res = flow.response if req is not None and res is not None: - urls = self.host_urls.setdefault(f"{req.scheme}://{req.host}:{req.port}", dict()) + urls = self.host_urls.setdefault( + f"{req.scheme}://{req.host}:{req.port}", dict() + ) methods = urls.setdefault(req.path, {}) codes = methods.setdefault(req.method, set()) codes.add(res.status_code) @@ -88,8 +91,10 @@ class TextUrlIndexWriter(UrlIndexWriter): req = flow.request if res is not None and req is not None: with self.filepath.open("a+") as f: - f.write(f"{datetime.datetime.utcnow().isoformat()} STATUS: {res.status_code} METHOD: " - f"{req.method} URL:{req.url}\n") + f.write( + f"{datetime.datetime.utcnow().isoformat()} STATUS: {res.status_code} METHOD: " + f"{req.method} URL:{req.url}\n" + ) def save(self): pass @@ -120,9 +125,14 @@ class UrlIndexAddon: OPT_APPEND = "URLINDEX_APPEND" OPT_INDEX_FILTER = "URLINDEX_FILTER" - def __init__(self, file_path: Union[str, Path], append: bool = True, - index_filter: Union[str, flowfilter.TFilter] = filter_404, index_format: str = "json"): - """ Initializes the urlindex add-on. + def __init__( + self, + file_path: Union[str, Path], + append: bool = True, + index_filter: Union[str, flowfilter.TFilter] = filter_404, + index_format: str = "json", + ): + """Initializes the urlindex add-on. Args: file_path: Path to file to which the URL index will be written. Can either be given as str or Path. @@ -153,7 +163,7 @@ class UrlIndexAddon: def response(self, flow: HTTPFlow): """Checks if the response should be included in the URL based on the index_filter and adds it to the URL index - if appropriate. + if appropriate. """ if isinstance(self.index_filter, str) or self.index_filter is None: raise ValueError("Invalid filter expression.") diff --git a/examples/contrib/webscanner_helper/urlinjection.py b/examples/contrib/webscanner_helper/urlinjection.py index 6c4f98291..8cd96313d 100644 --- a/examples/contrib/webscanner_helper/urlinjection.py +++ b/examples/contrib/webscanner_helper/urlinjection.py @@ -11,6 +11,7 @@ logger = logging.getLogger(__name__) class InjectionGenerator: """Abstract class for an generator of the injection content in order to inject the URL index.""" + ENCODING = "UTF8" @abc.abstractmethod @@ -32,11 +33,11 @@ class HTMLInjection(InjectionGenerator): @classmethod def _form_html(cls, url): - return f"
" + return f'
' @classmethod def _link_html(cls, url): - return f"link to {url}" + return f'link to {url}' @classmethod def index_html(cls, index): @@ -54,9 +55,9 @@ class HTMLInjection(InjectionGenerator): @classmethod def landing_page(cls, index): return ( - "" - + cls.index_html(index) - + "" + '' + + cls.index_html(index) + + "" ) def inject(self, index, flow: HTTPFlow): @@ -64,19 +65,21 @@ class HTMLInjection(InjectionGenerator): if flow.response.status_code != 404 and not self.insert: logger.warning( f"URL '{flow.request.url}' didn't return 404 status, " - f"index page would overwrite valid page.") + f"index page would overwrite valid page." + ) elif self.insert: - content = (flow.response - .content - .decode(self.ENCODING, "backslashreplace")) + content = flow.response.content.decode( + self.ENCODING, "backslashreplace" + ) if "" in content: - content = content.replace("", self.index_html(index) + "") + content = content.replace( + "", self.index_html(index) + "" + ) else: content += self.index_html(index) flow.response.content = content.encode(self.ENCODING) else: - flow.response.content = (self.landing_page(index) - .encode(self.ENCODING)) + flow.response.content = self.landing_page(index).encode(self.ENCODING) class RobotsInjection(InjectionGenerator): @@ -98,11 +101,12 @@ class RobotsInjection(InjectionGenerator): if flow.response.status_code != 404: logger.warning( f"URL '{flow.request.url}' didn't return 404 status, " - f"index page would overwrite valid page.") + f"index page would overwrite valid page." + ) else: - flow.response.content = self.robots_txt(index, - self.directive).encode( - self.ENCODING) + flow.response.content = self.robots_txt(index, self.directive).encode( + self.ENCODING + ) class SitemapInjection(InjectionGenerator): @@ -111,7 +115,8 @@ class SitemapInjection(InjectionGenerator): @classmethod def sitemap(cls, index): lines = [ - ""] + '' + ] for scheme_netloc, paths in index.items(): for path, methods in paths.items(): url = scheme_netloc + path @@ -124,13 +129,14 @@ class SitemapInjection(InjectionGenerator): if flow.response.status_code != 404: logger.warning( f"URL '{flow.request.url}' didn't return 404 status, " - f"index page would overwrite valid page.") + f"index page would overwrite valid page." + ) else: flow.response.content = self.sitemap(index).encode(self.ENCODING) class UrlInjectionAddon: - """ The UrlInjection add-on can be used in combination with web application scanners to improve their crawling + """The UrlInjection add-on can be used in combination with web application scanners to improve their crawling performance. The given URls will be injected into the web application. With this, web application scanners can find pages to @@ -143,8 +149,9 @@ class UrlInjectionAddon: The URL index needed for the injection can be generated by the UrlIndex Add-on. """ - def __init__(self, flt: str, url_index_file: str, - injection_gen: InjectionGenerator): + def __init__( + self, flt: str, url_index_file: str, injection_gen: InjectionGenerator + ): """Initializes the UrlIndex add-on. Args: @@ -168,5 +175,7 @@ class UrlInjectionAddon: self.injection_gen.inject(self.url_store, flow) flow.response.status_code = 200 flow.response.headers["content-type"] = "text/html" - logger.debug(f"Set status code to 200 and set content to logged " - f"urls. Method: {self.injection_gen}") + logger.debug( + f"Set status code to 200 and set content to logged " + f"urls. Method: {self.injection_gen}" + ) diff --git a/examples/contrib/webscanner_helper/watchdog.py b/examples/contrib/webscanner_helper/watchdog.py index 48f58d9c0..361f72a43 100644 --- a/examples/contrib/webscanner_helper/watchdog.py +++ b/examples/contrib/webscanner_helper/watchdog.py @@ -1,19 +1,20 @@ +import logging import pathlib import time -import logging from datetime import datetime from typing import Union import mitmproxy.connections import mitmproxy.http -from mitmproxy.addons.export import curl_command, raw +from mitmproxy.addons.export import curl_command +from mitmproxy.addons.export import raw from mitmproxy.exceptions import HttpSyntaxException logger = logging.getLogger(__name__) -class WatchdogAddon(): - """ The Watchdog Add-on can be used in combination with web application scanners in oder to check if the device +class WatchdogAddon: + """The Watchdog Add-on can be used in combination with web application scanners in oder to check if the device under test responds correctls to the scanner's responses. The Watchdog Add-on checks if the device under test responds correctly to the scanner's responses. @@ -45,10 +46,14 @@ class WatchdogAddon(): @classmethod def not_in_timeout(cls, last_triggered, timeout): """Checks if current error lies not in timeout after last trigger (potential reset of connection).""" - return last_triggered is None or timeout is None or (time.time() - last_triggered > timeout) + return ( + last_triggered is None + or timeout is None + or (time.time() - last_triggered > timeout) + ) def error(self, flow): - """ Checks if the watchdog will be triggered. + """Checks if the watchdog will be triggered. Only triggers watchdog for timeouts after last reset and if flow.error is set (shows that error is a server error). Ignores HttpSyntaxException Errors since this can be triggered on purpose by web application scanner. @@ -56,8 +61,11 @@ class WatchdogAddon(): Args: flow: mitmproxy.http.flow """ - if (self.not_in_timeout(self.last_trigger, self.timeout) - and flow.error is not None and not isinstance(flow.error, HttpSyntaxException)): + if ( + self.not_in_timeout(self.last_trigger, self.timeout) + and flow.error is not None + and not isinstance(flow.error, HttpSyntaxException) + ): self.last_trigger = time.time() logger.error(f"Watchdog triggered! Cause: {flow}") @@ -65,7 +73,11 @@ class WatchdogAddon(): # save the request which might have caused the problem if flow.request: - with (self.flow_dir / f"{datetime.utcnow().isoformat()}.curl").open("w") as f: + with (self.flow_dir / f"{datetime.utcnow().isoformat()}.curl").open( + "w" + ) as f: f.write(curl_command(flow)) - with (self.flow_dir / f"{datetime.utcnow().isoformat()}.raw").open("wb") as f: + with (self.flow_dir / f"{datetime.utcnow().isoformat()}.raw").open( + "wb" + ) as f: f.write(raw(flow)) diff --git a/examples/contrib/xss_scanner.py b/examples/contrib/xss_scanner.py index 287982fb3..c942281c8 100644 --- a/examples/contrib/xss_scanner.py +++ b/examples/contrib/xss_scanner.py @@ -38,7 +38,9 @@ import logging import re import socket from html.parser import HTMLParser -from typing import NamedTuple, Optional, Union +from typing import NamedTuple +from typing import Optional +from typing import Union from urllib.parse import urlparse import requests @@ -82,14 +84,14 @@ Cookies = dict[str, str] def get_cookies(flow: http.HTTPFlow) -> Cookies: - """ Return a dict going from cookie names to cookie values - - Note that it includes both the cookies sent in the original request and - the cookies sent by the server """ + """Return a dict going from cookie names to cookie values + - Note that it includes both the cookies sent in the original request and + the cookies sent by the server""" return {name: value for name, value in flow.request.cookies.fields} def find_unclaimed_URLs(body, requestUrl): - """ Look for unclaimed URLs in script tags and log them if found""" + """Look for unclaimed URLs in script tags and log them if found""" def getValue(attrs: list[tuple[str, str]], attrName: str) -> Optional[str]: for name, value in attrs: @@ -101,9 +103,15 @@ def find_unclaimed_URLs(body, requestUrl): script_URLs: list[str] = [] def handle_starttag(self, tag, attrs): - if (tag == "script" or tag == "iframe") and "src" in [name for name, value in attrs]: + if (tag == "script" or tag == "iframe") and "src" in [ + name for name, value in attrs + ]: self.script_URLs.append(getValue(attrs, "src")) - if tag == "link" and getValue(attrs, "rel") == "stylesheet" and "href" in [name for name, value in attrs]: + if ( + tag == "link" + and getValue(attrs, "rel") == "stylesheet" + and "href" in [name for name, value in attrs] + ): self.script_URLs.append(getValue(attrs, "href")) parser = ScriptURLExtractor() @@ -114,17 +122,21 @@ def find_unclaimed_URLs(body, requestUrl): try: socket.gethostbyname(domain) except socket.gaierror: - logging.error(f"XSS found in {requestUrl} due to unclaimed URL \"{url}\".") + logging.error(f'XSS found in {requestUrl} due to unclaimed URL "{url}".') -def test_end_of_URL_injection(original_body: str, request_URL: str, cookies: Cookies) -> VulnData: - """ Test the given URL for XSS via injection onto the end of the URL and - log the XSS if found """ +def test_end_of_URL_injection( + original_body: str, request_URL: str, cookies: Cookies +) -> VulnData: + """Test the given URL for XSS via injection onto the end of the URL and + log the XSS if found""" parsed_URL = urlparse(request_URL) path = parsed_URL.path if path != "" and path[-1] != "/": # ensure the path ends in a / path += "/" - path += FULL_PAYLOAD.decode('utf-8') # the path must be a string while the payload is bytes + path += FULL_PAYLOAD.decode( + "utf-8" + ) # the path must be a string while the payload is bytes url = parsed_URL._replace(path=path).geturl() body = requests.get(url, cookies=cookies).text.lower() xss_info = get_XSS_data(body, url, "End of URL") @@ -132,31 +144,42 @@ def test_end_of_URL_injection(original_body: str, request_URL: str, cookies: Coo return xss_info, sqli_info -def test_referer_injection(original_body: str, request_URL: str, cookies: Cookies) -> VulnData: - """ Test the given URL for XSS via injection into the referer and - log the XSS if found """ - body = requests.get(request_URL, headers={'referer': FULL_PAYLOAD}, cookies=cookies).text.lower() +def test_referer_injection( + original_body: str, request_URL: str, cookies: Cookies +) -> VulnData: + """Test the given URL for XSS via injection into the referer and + log the XSS if found""" + body = requests.get( + request_URL, headers={"referer": FULL_PAYLOAD}, cookies=cookies + ).text.lower() xss_info = get_XSS_data(body, request_URL, "Referer") sqli_info = get_SQLi_data(body, original_body, request_URL, "Referer") return xss_info, sqli_info -def test_user_agent_injection(original_body: str, request_URL: str, cookies: Cookies) -> VulnData: - """ Test the given URL for XSS via injection into the user agent and - log the XSS if found """ - body = requests.get(request_URL, headers={'User-Agent': FULL_PAYLOAD}, cookies=cookies).text.lower() +def test_user_agent_injection( + original_body: str, request_URL: str, cookies: Cookies +) -> VulnData: + """Test the given URL for XSS via injection into the user agent and + log the XSS if found""" + body = requests.get( + request_URL, headers={"User-Agent": FULL_PAYLOAD}, cookies=cookies + ).text.lower() xss_info = get_XSS_data(body, request_URL, "User Agent") sqli_info = get_SQLi_data(body, original_body, request_URL, "User Agent") return xss_info, sqli_info def test_query_injection(original_body: str, request_URL: str, cookies: Cookies): - """ Test the given URL for XSS via injection into URL queries and - log the XSS if found """ + """Test the given URL for XSS via injection into URL queries and + log the XSS if found""" parsed_URL = urlparse(request_URL) query_string = parsed_URL.query # queries is a list of parameters where each parameter is set to the payload - queries = [query.split("=")[0] + "=" + FULL_PAYLOAD.decode('utf-8') for query in query_string.split("&")] + queries = [ + query.split("=")[0] + "=" + FULL_PAYLOAD.decode("utf-8") + for query in query_string.split("&") + ] new_query_string = "&".join(queries) new_URL = parsed_URL._replace(query=new_query_string).geturl() body = requests.get(new_URL, cookies=cookies).text.lower() @@ -166,7 +189,7 @@ def test_query_injection(original_body: str, request_URL: str, cookies: Cookies) def log_XSS_data(xss_info: Optional[XSSData]) -> None: - """ Log information about the given XSS to mitmproxy """ + """Log information about the given XSS to mitmproxy""" # If it is None, then there is no info to log if not xss_info: return @@ -178,7 +201,7 @@ def log_XSS_data(xss_info: Optional[XSSData]) -> None: def log_SQLi_data(sqli_info: Optional[SQLiData]) -> None: - """ Log information about the given SQLi to mitmproxy """ + """Log information about the given SQLi to mitmproxy""" if not sqli_info: return logging.error("===== SQLi Found =====") @@ -189,51 +212,88 @@ def log_SQLi_data(sqli_info: Optional[SQLiData]) -> None: return -def get_SQLi_data(new_body: str, original_body: str, request_URL: str, injection_point: str) -> Optional[SQLiData]: - """ Return a SQLiDict if there is a SQLi otherwise return None - String String URL String -> (SQLiDict or None) """ +def get_SQLi_data( + new_body: str, original_body: str, request_URL: str, injection_point: str +) -> Optional[SQLiData]: + """Return a SQLiDict if there is a SQLi otherwise return None + String String URL String -> (SQLiDict or None)""" # Regexes taken from Damn Small SQLi Scanner: https://github.com/stamparm/DSSS/blob/master/dsss.py#L17 DBMS_ERRORS = { - "MySQL": (r"SQL syntax.*MySQL", r"Warning.*mysql_.*", r"valid MySQL result", r"MySqlClient\."), - "PostgreSQL": (r"PostgreSQL.*ERROR", r"Warning.*\Wpg_.*", r"valid PostgreSQL result", r"Npgsql\."), - "Microsoft SQL Server": (r"Driver.* SQL[\-\_\ ]*Server", r"OLE DB.* SQL Server", r"(\W|\A)SQL Server.*Driver", - r"Warning.*mssql_.*", r"(\W|\A)SQL Server.*[0-9a-fA-F]{8}", - r"(?s)Exception.*\WSystem\.Data\.SqlClient\.", r"(?s)Exception.*\WRoadhouse\.Cms\."), - "Microsoft Access": (r"Microsoft Access Driver", r"JET Database Engine", r"Access Database Engine"), - "Oracle": (r"\bORA-[0-9][0-9][0-9][0-9]", r"Oracle error", r"Oracle.*Driver", r"Warning.*\Woci_.*", r"Warning.*\Wora_.*"), + "MySQL": ( + r"SQL syntax.*MySQL", + r"Warning.*mysql_.*", + r"valid MySQL result", + r"MySqlClient\.", + ), + "PostgreSQL": ( + r"PostgreSQL.*ERROR", + r"Warning.*\Wpg_.*", + r"valid PostgreSQL result", + r"Npgsql\.", + ), + "Microsoft SQL Server": ( + r"Driver.* SQL[\-\_\ ]*Server", + r"OLE DB.* SQL Server", + r"(\W|\A)SQL Server.*Driver", + r"Warning.*mssql_.*", + r"(\W|\A)SQL Server.*[0-9a-fA-F]{8}", + r"(?s)Exception.*\WSystem\.Data\.SqlClient\.", + r"(?s)Exception.*\WRoadhouse\.Cms\.", + ), + "Microsoft Access": ( + r"Microsoft Access Driver", + r"JET Database Engine", + r"Access Database Engine", + ), + "Oracle": ( + r"\bORA-[0-9][0-9][0-9][0-9]", + r"Oracle error", + r"Oracle.*Driver", + r"Warning.*\Woci_.*", + r"Warning.*\Wora_.*", + ), "IBM DB2": (r"CLI Driver.*DB2", r"DB2 SQL error", r"\bdb2_\w+\("), - "SQLite": (r"SQLite/JDBCDriver", r"SQLite.Exception", r"System.Data.SQLite.SQLiteException", r"Warning.*sqlite_.*", - r"Warning.*SQLite3::", r"\[SQLITE_ERROR\]"), - "Sybase": (r"(?i)Warning.*sybase.*", r"Sybase message", r"Sybase.*Server message.*"), + "SQLite": ( + r"SQLite/JDBCDriver", + r"SQLite.Exception", + r"System.Data.SQLite.SQLiteException", + r"Warning.*sqlite_.*", + r"Warning.*SQLite3::", + r"\[SQLITE_ERROR\]", + ), + "Sybase": ( + r"(?i)Warning.*sybase.*", + r"Sybase message", + r"Sybase.*Server message.*", + ), } for dbms, regexes in DBMS_ERRORS.items(): for regex in regexes: # type: ignore - if re.search(regex, new_body, re.IGNORECASE) and not re.search(regex, original_body, re.IGNORECASE): - return SQLiData(request_URL, - injection_point, - regex, - dbms) + if re.search(regex, new_body, re.IGNORECASE) and not re.search( + regex, original_body, re.IGNORECASE + ): + return SQLiData(request_URL, injection_point, regex, dbms) return None # A qc is either ' or " -def inside_quote(qc: str, substring_bytes: bytes, text_index: int, body_bytes: bytes) -> bool: - """ Whether the Numberth occurrence of the first string in the second - string is inside quotes as defined by the supplied QuoteChar """ - substring = substring_bytes.decode('utf-8') - body = body_bytes.decode('utf-8') +def inside_quote( + qc: str, substring_bytes: bytes, text_index: int, body_bytes: bytes +) -> bool: + """Whether the Numberth occurrence of the first string in the second + string is inside quotes as defined by the supplied QuoteChar""" + substring = substring_bytes.decode("utf-8") + body = body_bytes.decode("utf-8") num_substrings_found = 0 in_quote = False for index, char in enumerate(body): # Whether the next chunk of len(substring) chars is the substring - next_part_is_substring = ( - (not (index + len(substring) > len(body))) and - (body[index:index + len(substring)] == substring) + next_part_is_substring = (not (index + len(substring) > len(body))) and ( + body[index : index + len(substring)] == substring ) # Whether this char is escaped with a \ - is_not_escaped = ( - (index - 1 < 0 or index - 1 > len(body)) or - (body[index - 1] != "\\") + is_not_escaped = (index - 1 < 0 or index - 1 > len(body)) or ( + body[index - 1] != "\\" ) if char == qc and is_not_escaped: in_quote = not in_quote @@ -245,25 +305,27 @@ def inside_quote(qc: str, substring_bytes: bytes, text_index: int, body_bytes: b def paths_to_text(html: str, string: str) -> list[str]: - """ Return list of Paths to a given str in the given HTML tree - - Note that it does a BFS """ + """Return list of Paths to a given str in the given HTML tree + - Note that it does a BFS""" def remove_last_occurence_of_sub_string(string: str, substr: str) -> str: - """ Delete the last occurrence of substr from str + """Delete the last occurrence of substr from str String String -> String """ index = string.rfind(substr) - return string[:index] + string[index + len(substr):] + return string[:index] + string[index + len(substr) :] class PathHTMLParser(HTMLParser): currentPath = "" paths: list[str] = [] def handle_starttag(self, tag, attrs): - self.currentPath += ("/" + tag) + self.currentPath += "/" + tag def handle_endtag(self, tag): - self.currentPath = remove_last_occurence_of_sub_string(self.currentPath, "/" + tag) + self.currentPath = remove_last_occurence_of_sub_string( + self.currentPath, "/" + tag + ) def handle_data(self, data): if string in data: @@ -274,13 +336,15 @@ def paths_to_text(html: str, string: str) -> list[str]: return parser.paths -def get_XSS_data(body: Union[str, bytes], request_URL: str, injection_point: str) -> Optional[XSSData]: - """ Return a XSSDict if there is a XSS otherwise return None """ +def get_XSS_data( + body: Union[str, bytes], request_URL: str, injection_point: str +) -> Optional[XSSData]: + """Return a XSSDict if there is a XSS otherwise return None""" def in_script(text, index, body) -> bool: - """ Whether the Numberth occurrence of the first string in the second - string is inside a script tag """ - paths = paths_to_text(body.decode('utf-8'), text.decode("utf-8")) + """Whether the Numberth occurrence of the first string in the second + string is inside a script tag""" + paths = paths_to_text(body.decode("utf-8"), text.decode("utf-8")) try: path = paths[index] return "script" in path @@ -288,12 +352,12 @@ def get_XSS_data(body: Union[str, bytes], request_URL: str, injection_point: str return False def in_HTML(text: bytes, index: int, body: bytes) -> bool: - """ Whether the Numberth occurrence of the first string in the second - string is inside the HTML but not inside a script tag or part of - a HTML attribute""" + """Whether the Numberth occurrence of the first string in the second + string is inside the HTML but not inside a script tag or part of + a HTML attribute""" # if there is a < then lxml will interpret that as a tag, so only search for the stuff before it text = text.split(b"<")[0] - paths = paths_to_text(body.decode('utf-8'), text.decode("utf-8")) + paths = paths_to_text(body.decode("utf-8"), text.decode("utf-8")) try: path = paths[index] return "script" not in path @@ -301,14 +365,14 @@ def get_XSS_data(body: Union[str, bytes], request_URL: str, injection_point: str return False def inject_javascript_handler(html: str) -> bool: - """ Whether you can inject a Javascript:alert(0) as a link """ + """Whether you can inject a Javascript:alert(0) as a link""" class injectJSHandlerHTMLParser(HTMLParser): injectJSHandler = False def handle_starttag(self, tag, attrs): for name, value in attrs: - if name == "href" and value.startswith(FRONT_WALL.decode('utf-8')): + if name == "href" and value.startswith(FRONT_WALL.decode("utf-8")): self.injectJSHandler = True parser = injectJSHandlerHTMLParser() @@ -317,7 +381,7 @@ def get_XSS_data(body: Union[str, bytes], request_URL: str, injection_point: str # Only convert the body to bytes if needed if isinstance(body, str): - body = bytes(body, 'utf-8') + body = bytes(body, "utf-8") # Regex for between 24 and 72 (aka 24*3) characters encapsulated by the walls regex = re.compile(b"""%s.{24,72}?%s""" % (FRONT_WALL, BACK_WALL)) matches = regex.findall(body) @@ -336,64 +400,121 @@ def get_XSS_data(body: Union[str, bytes], request_URL: str, injection_point: str inject_slash = b"sl/bsl" in match # forward slashes inject_semi = b"se;sl" in match # semicolons inject_equals = b"eq=" in match # equals sign - if in_script_val and inject_slash and inject_open_angle and inject_close_angle: # e.g. - return XSSData(request_URL, - injection_point, - ' - return XSSData(request_URL, - injection_point, - "';alert(0);g='", - match.decode('utf-8')) - elif in_script_val and in_double_quotes and inject_double_quotes and inject_semi: # e.g. - return XSSData(request_URL, - injection_point, - '";alert(0);g="', - match.decode('utf-8')) - elif in_tag and in_single_quotes and inject_single_quotes and inject_open_angle and inject_close_angle and inject_slash: + if ( + in_script_val and inject_slash and inject_open_angle and inject_close_angle + ): # e.g. + return XSSData( + request_URL, + injection_point, + " + return XSSData( + request_URL, injection_point, "';alert(0);g='", match.decode("utf-8") + ) + elif ( + in_script_val and in_double_quotes and inject_double_quotes and inject_semi + ): # e.g. + return XSSData( + request_URL, injection_point, '";alert(0);g="', match.decode("utf-8") + ) + elif ( + in_tag + and in_single_quotes + and inject_single_quotes + and inject_open_angle + and inject_close_angle + and inject_slash + ): # e.g. Test - return XSSData(request_URL, - injection_point, - "'>", - match.decode('utf-8')) - elif in_tag and in_double_quotes and inject_double_quotes and inject_open_angle and inject_close_angle and inject_slash: + return XSSData( + request_URL, + injection_point, + "'>", + match.decode("utf-8"), + ) + elif ( + in_tag + and in_double_quotes + and inject_double_quotes + and inject_open_angle + and inject_close_angle + and inject_slash + ): # e.g. Test - return XSSData(request_URL, - injection_point, - '">', - match.decode('utf-8')) - elif in_tag and not in_double_quotes and not in_single_quotes and inject_open_angle and inject_close_angle and inject_slash: + return XSSData( + request_URL, + injection_point, + '">', + match.decode("utf-8"), + ) + elif ( + in_tag + and not in_double_quotes + and not in_single_quotes + and inject_open_angle + and inject_close_angle + and inject_slash + ): # e.g. Test - return XSSData(request_URL, - injection_point, - '>', - match.decode('utf-8')) - elif inject_javascript_handler(body.decode('utf-8')): # e.g. Test - return XSSData(request_URL, - injection_point, - 'Javascript:alert(0)', - match.decode('utf-8')) - elif in_tag and in_double_quotes and inject_double_quotes and inject_equals: # e.g. Test - return XSSData(request_URL, - injection_point, - '" onmouseover="alert(0)" t="', - match.decode('utf-8')) - elif in_tag and in_single_quotes and inject_single_quotes and inject_equals: # e.g. Test - return XSSData(request_URL, - injection_point, - "' onmouseover='alert(0)' t='", - match.decode('utf-8')) - elif in_tag and not in_single_quotes and not in_double_quotes and inject_equals: # e.g. Test - return XSSData(request_URL, - injection_point, - " onmouseover=alert(0) t=", - match.decode('utf-8')) - elif in_HTML_val and not in_script_val and inject_open_angle and inject_close_angle and inject_slash: # e.g. PAYLOAD - return XSSData(request_URL, - injection_point, - '', - match.decode('utf-8')) + return XSSData( + request_URL, + injection_point, + ">", + match.decode("utf-8"), + ) + elif inject_javascript_handler( + body.decode("utf-8") + ): # e.g. Test + return XSSData( + request_URL, + injection_point, + "Javascript:alert(0)", + match.decode("utf-8"), + ) + elif ( + in_tag and in_double_quotes and inject_double_quotes and inject_equals + ): # e.g. Test + return XSSData( + request_URL, + injection_point, + '" onmouseover="alert(0)" t="', + match.decode("utf-8"), + ) + elif ( + in_tag and in_single_quotes and inject_single_quotes and inject_equals + ): # e.g. Test + return XSSData( + request_URL, + injection_point, + "' onmouseover='alert(0)' t='", + match.decode("utf-8"), + ) + elif ( + in_tag and not in_single_quotes and not in_double_quotes and inject_equals + ): # e.g. Test + return XSSData( + request_URL, + injection_point, + " onmouseover=alert(0) t=", + match.decode("utf-8"), + ) + elif ( + in_HTML_val + and not in_script_val + and inject_open_angle + and inject_close_angle + and inject_slash + ): # e.g. PAYLOAD + return XSSData( + request_URL, + injection_point, + "", + match.decode("utf-8"), + ) else: return None return None diff --git a/mitmproxy/addonmanager.py b/mitmproxy/addonmanager.py index 265c81113..5b52750a5 100644 --- a/mitmproxy/addonmanager.py +++ b/mitmproxy/addonmanager.py @@ -2,13 +2,14 @@ import contextlib import inspect import logging import pprint +import sys import traceback import types -from collections.abc import Callable, Sequence +from collections.abc import Callable +from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, Optional - -import sys +from typing import Any +from typing import Optional from mitmproxy import exceptions from mitmproxy import flow diff --git a/mitmproxy/addons/__init__.py b/mitmproxy/addons/__init__.py index 36ea9bfe5..0d801f577 100644 --- a/mitmproxy/addons/__init__.py +++ b/mitmproxy/addons/__init__.py @@ -11,19 +11,19 @@ from mitmproxy.addons import cut from mitmproxy.addons import disable_h2c from mitmproxy.addons import dns_resolver from mitmproxy.addons import export -from mitmproxy.addons import next_layer -from mitmproxy.addons import onboarding -from mitmproxy.addons import proxyserver -from mitmproxy.addons import proxyauth -from mitmproxy.addons import script -from mitmproxy.addons import serverplayback -from mitmproxy.addons import mapremote from mitmproxy.addons import maplocal +from mitmproxy.addons import mapremote from mitmproxy.addons import modifybody from mitmproxy.addons import modifyheaders +from mitmproxy.addons import next_layer +from mitmproxy.addons import onboarding +from mitmproxy.addons import proxyauth +from mitmproxy.addons import proxyserver +from mitmproxy.addons import save +from mitmproxy.addons import script +from mitmproxy.addons import serverplayback from mitmproxy.addons import stickyauth from mitmproxy.addons import stickycookie -from mitmproxy.addons import save from mitmproxy.addons import tlsconfig from mitmproxy.addons import upstream_auth diff --git a/mitmproxy/addons/asgiapp.py b/mitmproxy/addons/asgiapp.py index f85a88ebc..5425275a9 100644 --- a/mitmproxy/addons/asgiapp.py +++ b/mitmproxy/addons/asgiapp.py @@ -7,7 +7,8 @@ from typing import Optional import asgiref.compatibility import asgiref.wsgi -from mitmproxy import ctx, http +from mitmproxy import ctx +from mitmproxy import http logger = logging.getLogger(__name__) diff --git a/mitmproxy/addons/blocklist.py b/mitmproxy/addons/blocklist.py index 4d5b7bedf..6c99bca06 100644 --- a/mitmproxy/addons/blocklist.py +++ b/mitmproxy/addons/blocklist.py @@ -1,7 +1,11 @@ from collections.abc import Sequence from typing import NamedTuple -from mitmproxy import ctx, exceptions, flowfilter, http, version +from mitmproxy import ctx +from mitmproxy import exceptions +from mitmproxy import flowfilter +from mitmproxy import http +from mitmproxy import version from mitmproxy.net.http.status_codes import NO_RESPONSE from mitmproxy.net.http.status_codes import RESPONSES diff --git a/mitmproxy/addons/browser.py b/mitmproxy/addons/browser.py index 7354f7c1e..ab2fcc560 100644 --- a/mitmproxy/addons/browser.py +++ b/mitmproxy/addons/browser.py @@ -74,7 +74,9 @@ class Browser: cmd = get_browser_cmd() if not cmd: - logging.log(ALERT, "Your platform is not supported yet - please submit a patch.") + logging.log( + ALERT, "Your platform is not supported yet - please submit a patch." + ) return tdir = tempfile.TemporaryDirectory() @@ -85,7 +87,8 @@ class Browser: *cmd, "--user-data-dir=%s" % str(tdir.name), "--proxy-server={}:{}".format( - ctx.options.listen_host or "127.0.0.1", ctx.options.listen_port or "8080" + ctx.options.listen_host or "127.0.0.1", + ctx.options.listen_port or "8080", ), "--disable-fre", "--no-default-browser-check", diff --git a/mitmproxy/addons/clientplayback.py b/mitmproxy/addons/clientplayback.py index d435324e8..7abb83622 100644 --- a/mitmproxy/addons/clientplayback.py +++ b/mitmproxy/addons/clientplayback.py @@ -1,10 +1,10 @@ import asyncio import logging +import time import traceback from collections.abc import Sequence -from typing import Optional, cast - -import time +from typing import cast +from typing import Optional import mitmproxy.types from mitmproxy import command @@ -13,11 +13,15 @@ from mitmproxy import exceptions from mitmproxy import flow from mitmproxy import http from mitmproxy import io -from mitmproxy.connection import ConnectionState, Server +from mitmproxy.connection import ConnectionState +from mitmproxy.connection import Server from mitmproxy.hooks import UpdateHook from mitmproxy.log import ALERT from mitmproxy.options import Options -from mitmproxy.proxy import commands, events, layers, server +from mitmproxy.proxy import commands +from mitmproxy.proxy import events +from mitmproxy.proxy import layers +from mitmproxy.proxy import server from mitmproxy.proxy.context import Context from mitmproxy.proxy.layer import CommandGenerator from mitmproxy.proxy.layers.http import HTTPMode @@ -161,9 +165,7 @@ class ClientPlayback: else: await h.replay() except Exception: - logger.error( - f"Client replay has crashed!\n{traceback.format_exc()}" - ) + logger.error(f"Client replay has crashed!\n{traceback.format_exc()}") self.queue.task_done() self.inflight = None diff --git a/mitmproxy/addons/command_history.py b/mitmproxy/addons/command_history.py index d75d3cade..507b60e50 100644 --- a/mitmproxy/addons/command_history.py +++ b/mitmproxy/addons/command_history.py @@ -41,7 +41,7 @@ class CommandHistory: def done(self): if ctx.options.command_history and len(self.history) >= self.VACUUM_SIZE: # vacuum history so that it doesn't grow indefinitely. - history_str = "\n".join(self.history[-self.VACUUM_SIZE // 2:]) + "\n" + history_str = "\n".join(self.history[-self.VACUUM_SIZE // 2 :]) + "\n" try: self.history_file.write_text(history_str) except Exception as e: diff --git a/mitmproxy/addons/comment.py b/mitmproxy/addons/comment.py index 3e9b549c7..ecb303b0c 100644 --- a/mitmproxy/addons/comment.py +++ b/mitmproxy/addons/comment.py @@ -1,6 +1,8 @@ from collections.abc import Sequence -from mitmproxy import command, flow, ctx +from mitmproxy import command +from mitmproxy import ctx +from mitmproxy import flow from mitmproxy.hooks import UpdateHook diff --git a/mitmproxy/addons/core.py b/mitmproxy/addons/core.py index 87eda62cd..5ab1ba0a6 100644 --- a/mitmproxy/addons/core.py +++ b/mitmproxy/addons/core.py @@ -3,15 +3,16 @@ import os from collections.abc import Sequence from typing import Union -from mitmproxy.log import ALERT -from mitmproxy.utils import emoji -from mitmproxy import ctx, hooks -from mitmproxy import exceptions -from mitmproxy import command -from mitmproxy import flow -from mitmproxy import optmanager -from mitmproxy.net.http import status_codes import mitmproxy.types +from mitmproxy import command +from mitmproxy import ctx +from mitmproxy import exceptions +from mitmproxy import flow +from mitmproxy import hooks +from mitmproxy import optmanager +from mitmproxy.log import ALERT +from mitmproxy.net.http import status_codes +from mitmproxy.utils import emoji logger = logging.getLogger(__name__) diff --git a/mitmproxy/addons/cut.py b/mitmproxy/addons/cut.py index 52a21df52..34a7b8fe0 100644 --- a/mitmproxy/addons/cut.py +++ b/mitmproxy/addons/cut.py @@ -1,18 +1,18 @@ -import io import csv +import io import logging import os.path from collections.abc import Sequence -from typing import Any, Union - -from mitmproxy import command -from mitmproxy import exceptions -from mitmproxy import flow -from mitmproxy import certs -import mitmproxy.types +from typing import Any +from typing import Union import pyperclip +import mitmproxy.types +from mitmproxy import certs +from mitmproxy import command +from mitmproxy import exceptions +from mitmproxy import flow from mitmproxy.log import ALERT logger = logging.getLogger(__name__) @@ -132,7 +132,7 @@ class Cut: writer.writerow(vals) logger.log( ALERT, - "Saved %s cuts over %d flows as CSV." % (len(cuts), len(flows)) + "Saved %s cuts over %d flows as CSV." % (len(cuts), len(flows)), ) except OSError as e: logger.error(str(e)) diff --git a/mitmproxy/addons/dns_resolver.py b/mitmproxy/addons/dns_resolver.py index fcfb153f2..714d37e16 100644 --- a/mitmproxy/addons/dns_resolver.py +++ b/mitmproxy/addons/dns_resolver.py @@ -1,7 +1,10 @@ import asyncio import ipaddress import socket -from typing import Callable, Iterable, Union +from collections.abc import Iterable +from typing import Callable +from typing import Union + from mitmproxy import dns from mitmproxy.proxy import mode_specs diff --git a/mitmproxy/addons/dumper.py b/mitmproxy/addons/dumper.py index 8b290d171..42ecd2600 100644 --- a/mitmproxy/addons/dumper.py +++ b/mitmproxy/addons/dumper.py @@ -1,10 +1,12 @@ from __future__ import annotations -import logging import itertools +import logging import shutil import sys -from typing import IO, Optional, Union +from typing import IO +from typing import Optional +from typing import Union from wsproto.frame_protocol import CloseReason @@ -17,12 +19,15 @@ from mitmproxy import flowfilter from mitmproxy import http from mitmproxy.contrib import click as miniclick from mitmproxy.net.dns import response_codes -from mitmproxy.tcp import TCPFlow, TCPMessage -from mitmproxy.udp import UDPFlow, UDPMessage +from mitmproxy.tcp import TCPFlow +from mitmproxy.tcp import TCPMessage +from mitmproxy.udp import UDPFlow +from mitmproxy.udp import UDPMessage from mitmproxy.utils import human from mitmproxy.utils import strutils from mitmproxy.utils import vt_codes -from mitmproxy.websocket import WebSocketData, WebSocketMessage +from mitmproxy.websocket import WebSocketData +from mitmproxy.websocket import WebSocketMessage def indent(n: int, text: str) -> str: @@ -387,9 +392,12 @@ class Dumper: self.style(str(x), fg="bright_blue") for x in f.response.answers ) else: - answers = self.style(response_codes.to_str( - f.response.response_code, - ), fg="red") + answers = self.style( + response_codes.to_str( + f.response.response_code, + ), + fg="red", + ) self.echo(f"{arrows} {answers}") def dns_error(self, f: dns.DNSFlow): diff --git a/mitmproxy/addons/errorcheck.py b/mitmproxy/addons/errorcheck.py index 1015c9efb..b634623cf 100644 --- a/mitmproxy/addons/errorcheck.py +++ b/mitmproxy/addons/errorcheck.py @@ -1,6 +1,5 @@ import asyncio import logging - import sys from mitmproxy import log diff --git a/mitmproxy/addons/eventstore.py b/mitmproxy/addons/eventstore.py index 6f597b9c3..e719e6979 100644 --- a/mitmproxy/addons/eventstore.py +++ b/mitmproxy/addons/eventstore.py @@ -4,7 +4,8 @@ import logging from collections.abc import Callable from typing import Optional -from mitmproxy import command, log +from mitmproxy import command +from mitmproxy import log from mitmproxy.log import LogEntry from mitmproxy.utils import signals diff --git a/mitmproxy/addons/export.py b/mitmproxy/addons/export.py index 1e2267bff..1339111c5 100644 --- a/mitmproxy/addons/export.py +++ b/mitmproxy/addons/export.py @@ -1,15 +1,18 @@ import logging import shlex -from collections.abc import Callable, Sequence -from typing import Any, Union +from collections.abc import Callable +from collections.abc import Sequence +from typing import Any +from typing import Union import pyperclip import mitmproxy.types from mitmproxy import command -from mitmproxy import ctx, http +from mitmproxy import ctx from mitmproxy import exceptions from mitmproxy import flow +from mitmproxy import http from mitmproxy.net.http.http1 import assemble from mitmproxy.utils import strutils diff --git a/mitmproxy/addons/intercept.py b/mitmproxy/addons/intercept.py index f91876269..24a9173f9 100644 --- a/mitmproxy/addons/intercept.py +++ b/mitmproxy/addons/intercept.py @@ -1,8 +1,9 @@ from typing import Optional -from mitmproxy import flow, flowfilter -from mitmproxy import exceptions from mitmproxy import ctx +from mitmproxy import exceptions +from mitmproxy import flow +from mitmproxy import flowfilter class Intercept: diff --git a/mitmproxy/addons/keepserving.py b/mitmproxy/addons/keepserving.py index 5a149fdec..199cf2548 100644 --- a/mitmproxy/addons/keepserving.py +++ b/mitmproxy/addons/keepserving.py @@ -1,4 +1,5 @@ import asyncio + from mitmproxy import ctx diff --git a/mitmproxy/addons/maplocal.py b/mitmproxy/addons/maplocal.py index 54c522d93..5b1abd0b5 100644 --- a/mitmproxy/addons/maplocal.py +++ b/mitmproxy/addons/maplocal.py @@ -8,7 +8,11 @@ from typing import NamedTuple from werkzeug.security import safe_join -from mitmproxy import ctx, exceptions, flowfilter, http, version +from mitmproxy import ctx +from mitmproxy import exceptions +from mitmproxy import flowfilter +from mitmproxy import http +from mitmproxy import version from mitmproxy.utils.spec import parse_spec diff --git a/mitmproxy/addons/mapremote.py b/mitmproxy/addons/mapremote.py index 245323a03..31a759ada 100644 --- a/mitmproxy/addons/mapremote.py +++ b/mitmproxy/addons/mapremote.py @@ -2,7 +2,10 @@ import re from collections.abc import Sequence from typing import NamedTuple -from mitmproxy import ctx, exceptions, flowfilter, http +from mitmproxy import ctx +from mitmproxy import exceptions +from mitmproxy import flowfilter +from mitmproxy import http from mitmproxy.utils.spec import parse_spec diff --git a/mitmproxy/addons/modifybody.py b/mitmproxy/addons/modifybody.py index b82059afe..4148c4799 100644 --- a/mitmproxy/addons/modifybody.py +++ b/mitmproxy/addons/modifybody.py @@ -2,8 +2,10 @@ import logging import re from collections.abc import Sequence -from mitmproxy import ctx, exceptions -from mitmproxy.addons.modifyheaders import parse_modify_spec, ModifySpec +from mitmproxy import ctx +from mitmproxy import exceptions +from mitmproxy.addons.modifyheaders import ModifySpec +from mitmproxy.addons.modifyheaders import parse_modify_spec class ModifyBody: diff --git a/mitmproxy/addons/modifyheaders.py b/mitmproxy/addons/modifyheaders.py index 995005f30..a7e45b0dd 100644 --- a/mitmproxy/addons/modifyheaders.py +++ b/mitmproxy/addons/modifyheaders.py @@ -4,7 +4,10 @@ from collections.abc import Sequence from pathlib import Path from typing import NamedTuple -from mitmproxy import ctx, exceptions, flowfilter, http +from mitmproxy import ctx +from mitmproxy import exceptions +from mitmproxy import flowfilter +from mitmproxy import http from mitmproxy.http import Headers from mitmproxy.utils import strutils from mitmproxy.utils.spec import parse_spec diff --git a/mitmproxy/addons/next_layer.py b/mitmproxy/addons/next_layer.py index bf4be48df..3495ed919 100644 --- a/mitmproxy/addons/next_layer.py +++ b/mitmproxy/addons/next_layer.py @@ -15,22 +15,39 @@ In that case it's not necessary to modify mitmproxy's source, adding a custom ad that sets nextlayer.layer works just as well. """ import re -from collections.abc import Sequence import struct -from typing import Any, Callable, Iterable, Optional, Union, cast +from collections.abc import Iterable +from collections.abc import Sequence +from typing import Any +from typing import Callable +from typing import cast +from typing import Optional +from typing import Union -from mitmproxy import ctx, dns, exceptions, connection +from mitmproxy import connection +from mitmproxy import ctx +from mitmproxy import dns +from mitmproxy import exceptions from mitmproxy.net.tls import is_tls_record_magic -from mitmproxy.proxy.layers.http import HTTPMode -from mitmproxy.proxy import context, layer, layers, mode_specs +from mitmproxy.proxy import context +from mitmproxy.proxy import layer +from mitmproxy.proxy import layers +from mitmproxy.proxy import mode_specs from mitmproxy.proxy.layers import modes +from mitmproxy.proxy.layers.http import HTTPMode from mitmproxy.proxy.layers.quic import quic_parse_client_hello -from mitmproxy.proxy.layers.tls import HTTP_ALPNS, dtls_parse_client_hello, parse_client_hello +from mitmproxy.proxy.layers.tls import dtls_parse_client_hello +from mitmproxy.proxy.layers.tls import HTTP_ALPNS +from mitmproxy.proxy.layers.tls import parse_client_hello from mitmproxy.tls import ClientHello LayerCls = type[layer.Layer] -ClientSecurityLayerCls = Union[type[layers.ClientTLSLayer], type[layers.ClientQuicLayer]] -ServerSecurityLayerCls = Union[type[layers.ServerTLSLayer], type[layers.ServerQuicLayer]] +ClientSecurityLayerCls = Union[ + type[layers.ClientTLSLayer], type[layers.ClientQuicLayer] +] +ServerSecurityLayerCls = Union[ + type[layers.ServerTLSLayer], type[layers.ServerQuicLayer] +] def stack_match( @@ -77,7 +94,7 @@ class NextLayer: data_client: bytes, *, is_tls: Callable[[bytes], bool] = is_tls_record_magic, - client_hello: Callable[[bytes], Optional[ClientHello]] = parse_client_hello + client_hello: Callable[[bytes], Optional[ClientHello]] = parse_client_hello, ) -> Optional[bool]: """ Returns: @@ -148,7 +165,9 @@ class NextLayer: ret.child_layer = client_layer_cls(context) return ret - def is_destination_in_hosts(self, context: context.Context, hosts: Iterable[re.Pattern]) -> bool: + def is_destination_in_hosts( + self, context: context.Context, hosts: Iterable[re.Pattern] + ) -> bool: return any( (context.server.address and rex.search(context.server.address[0])) or (context.client.sni and rex.search(context.client.sni)) @@ -168,15 +187,15 @@ class NextLayer: ): return layers.HttpLayer(context, HTTPMode.regular) # ... or an upstream proxy. - if ( - s(modes.HttpUpstreamProxy) - or - s(modes.HttpUpstreamProxy, (layers.ClientTLSLayer, layers.ClientQuicLayer)) + if s(modes.HttpUpstreamProxy) or s( + modes.HttpUpstreamProxy, (layers.ClientTLSLayer, layers.ClientQuicLayer) ): return layers.HttpLayer(context, HTTPMode.upstream) return None - def detect_udp_tls(self, data_client: bytes) -> Optional[tuple[ClientHello, ClientSecurityLayerCls, ServerSecurityLayerCls]]: + def detect_udp_tls( + self, data_client: bytes + ) -> Optional[tuple[ClientHello, ClientSecurityLayerCls, ServerSecurityLayerCls]]: if len(data_client) == 0: return None @@ -198,23 +217,23 @@ class NextLayer: # that's all we currently have to offer return None - def raw_udp_layer(self, context: context.Context, ignore: bool = False) -> layer.Layer: + def raw_udp_layer( + self, context: context.Context, ignore: bool = False + ) -> layer.Layer: def s(*layers): return stack_match(context, layers) # for regular and upstream HTTP3, if we already created a client QUIC layer # we need a server and raw QUIC layer as well - if ( - s(modes.HttpProxy, layers.ClientQuicLayer) - or - s(modes.HttpUpstreamProxy, layers.ClientQuicLayer) + if s(modes.HttpProxy, layers.ClientQuicLayer) or s( + modes.HttpUpstreamProxy, layers.ClientQuicLayer ): server_layer = layers.ServerQuicLayer(context) server_layer.child_layer = layers.RawQuicLayer(context, ignore=ignore) return server_layer # for reverse HTTP3 and QUIC, we need a client and raw QUIC layer - elif (s(modes.ReverseProxy, layers.ServerQuicLayer)): + elif s(modes.ReverseProxy, layers.ServerQuicLayer): client_layer = layers.ClientQuicLayer(context) client_layer.child_layer = layers.RawQuicLayer(context, ignore=ignore) return client_layer @@ -243,11 +262,7 @@ class NextLayer: if context.client.transport_protocol == "tcp": is_quic_stream = isinstance(context.layers[-1], layers.QuicStreamLayer) - if ( - len(data_client) < 3 - and not data_server - and not is_quic_stream - ): + if len(data_client) < 3 and not data_server and not is_quic_stream: return None # not enough data yet to make a decision # 1. check for --ignore/--allow @@ -292,7 +307,7 @@ class NextLayer: context.server.address, data_client, is_tls=lambda _: tls is not None, - client_hello=lambda _: None if tls is None else tls[0] + client_hello=lambda _: None if tls is None else tls[0], ): return self.raw_udp_layer(context, ignore=True) @@ -310,7 +325,7 @@ class NextLayer: return self.raw_udp_layer(context) # 5. Check for reverse modes - if (isinstance(context.layers[0], modes.ReverseProxy)): + if isinstance(context.layers[0], modes.ReverseProxy): scheme = cast(mode_specs.ReverseMode, context.client.proxy_mode).scheme if scheme in ("udp", "dtls"): return layers.UDPLayer(context) diff --git a/mitmproxy/addons/onboarding.py b/mitmproxy/addons/onboarding.py index 78ff11033..02cf4bd20 100644 --- a/mitmproxy/addons/onboarding.py +++ b/mitmproxy/addons/onboarding.py @@ -1,6 +1,6 @@ +from mitmproxy import ctx from mitmproxy.addons import asgiapp from mitmproxy.addons.onboardingapp import app -from mitmproxy import ctx APP_HOST = "mitm.it" diff --git a/mitmproxy/addons/onboardingapp/__init__.py b/mitmproxy/addons/onboardingapp/__init__.py index f8fafd20c..a4aed1985 100644 --- a/mitmproxy/addons/onboardingapp/__init__.py +++ b/mitmproxy/addons/onboardingapp/__init__.py @@ -1,8 +1,10 @@ import os -from flask import Flask, render_template +from flask import Flask +from flask import render_template -from mitmproxy.options import CONF_BASENAME, CONF_DIR +from mitmproxy.options import CONF_BASENAME +from mitmproxy.options import CONF_DIR from mitmproxy.utils.magisk import write_magisk_module app = Flask(__name__) diff --git a/mitmproxy/addons/proxyauth.py b/mitmproxy/addons/proxyauth.py index 653fa48ae..5ff57feee 100644 --- a/mitmproxy/addons/proxyauth.py +++ b/mitmproxy/addons/proxyauth.py @@ -2,14 +2,16 @@ from __future__ import annotations import binascii import weakref -from abc import ABC, abstractmethod -from typing import MutableMapping +from abc import ABC +from abc import abstractmethod +from collections.abc import MutableMapping from typing import Optional import ldap3 import passlib.apache -from mitmproxy import connection, ctx +from mitmproxy import connection +from mitmproxy import ctx from mitmproxy import exceptions from mitmproxy import http from mitmproxy.net.http import status_codes @@ -141,7 +143,9 @@ def is_http_proxy(f: http.HTTPFlow) -> bool: - True, if authentication is done as if mitmproxy is a proxy - False, if authentication is done as if mitmproxy is an HTTP server """ - return isinstance(f.client_conn.proxy_mode, (mode_specs.RegularMode, mode_specs.UpstreamMode)) + return isinstance( + f.client_conn.proxy_mode, (mode_specs.RegularMode, mode_specs.UpstreamMode) + ) def mkauth(username: str, password: str, scheme: str = "basic") -> str: diff --git a/mitmproxy/addons/proxyserver.py b/mitmproxy/addons/proxyserver.py index c28980bf0..cb17bcfbf 100644 --- a/mitmproxy/addons/proxyserver.py +++ b/mitmproxy/addons/proxyserver.py @@ -7,29 +7,34 @@ import asyncio import collections import ipaddress import logging +from collections.abc import Iterable +from collections.abc import Iterator from contextlib import contextmanager -from typing import Iterable, Iterator, Optional +from typing import Optional from wsproto.frame_protocol import Opcode -from mitmproxy import ( - command, - ctx, - exceptions, - http, - platform, - tcp, - udp, - websocket, -) +from mitmproxy import command +from mitmproxy import ctx +from mitmproxy import exceptions +from mitmproxy import http +from mitmproxy import platform +from mitmproxy import tcp +from mitmproxy import udp +from mitmproxy import websocket from mitmproxy.connection import Address from mitmproxy.flow import Flow -from mitmproxy.proxy import events, mode_specs, server_hooks +from mitmproxy.proxy import events +from mitmproxy.proxy import mode_specs +from mitmproxy.proxy import server_hooks from mitmproxy.proxy.layers.tcp import TcpMessageInjected from mitmproxy.proxy.layers.udp import UdpMessageInjected from mitmproxy.proxy.layers.websocket import WebSocketMessageInjected -from mitmproxy.proxy.mode_servers import ProxyConnectionHandler, ServerInstance, ServerManager -from mitmproxy.utils import human, signals +from mitmproxy.proxy.mode_servers import ProxyConnectionHandler +from mitmproxy.proxy.mode_servers import ServerInstance +from mitmproxy.proxy.mode_servers import ServerManager +from mitmproxy.utils import human +from mitmproxy.utils import signals logger = logging.getLogger(__name__) @@ -64,7 +69,8 @@ class Servers: # Shutdown modes that have been removed from the list. stop_tasks = [ - s.stop() for spec, s in self._instances.items() + s.stop() + for spec, s in self._instances.items() if spec not in new_instances ] @@ -101,6 +107,7 @@ class Proxyserver(ServerManager): """ This addon runs the actual proxy server. """ + connections: dict[tuple, ProxyConnectionHandler] servers: Servers @@ -116,7 +123,9 @@ class Proxyserver(ServerManager): return f"Proxyserver({len(self.connections)} active conns)" @contextmanager - def register_connection(self, connection_id: tuple, handler: ProxyConnectionHandler): + def register_connection( + self, connection_id: tuple, handler: ProxyConnectionHandler + ): self.connections[connection_id] = handler try: yield @@ -217,7 +226,10 @@ class Proxyserver(ServerManager): if "connect_addr" in updated: try: if ctx.options.connect_addr: - self._connect_addr = str(ipaddress.ip_address(ctx.options.connect_addr)), 0 + self._connect_addr = ( + str(ipaddress.ip_address(ctx.options.connect_addr)), + 0, + ) else: self._connect_addr = None except ValueError: @@ -229,25 +241,27 @@ class Proxyserver(ServerManager): modes: list[mode_specs.ProxyMode] = [] for mode in ctx.options.mode: try: - modes.append( - mode_specs.ProxyMode.parse(mode) - ) + modes.append(mode_specs.ProxyMode.parse(mode)) except ValueError as e: - raise exceptions.OptionsError(f"Invalid proxy mode specification: {mode} ({e})") + raise exceptions.OptionsError( + f"Invalid proxy mode specification: {mode} ({e})" + ) # ...and don't listen on the same address. listen_addrs = [ ( m.listen_host(ctx.options.listen_host), m.listen_port(ctx.options.listen_port), - m.transport_protocol + m.transport_protocol, ) for m in modes ] if len(set(listen_addrs)) != len(listen_addrs): (host, port, _) = collections.Counter(listen_addrs).most_common(1)[0][0] dup_addr = human.format_address((host or "0.0.0.0", port)) - raise exceptions.OptionsError(f"Cannot spawn multiple servers on the same address: {dup_addr}") + raise exceptions.OptionsError( + f"Cannot spawn multiple servers on the same address: {dup_addr}" + ) if ctx.options.mode and not ctx.master.addons.get("nextlayer"): logger.warning("Warning: Running proxyserver without nextlayer addon!") @@ -255,20 +269,20 @@ class Proxyserver(ServerManager): if platform.original_addr: platform.init_transparent_mode() else: - raise exceptions.OptionsError("Transparent mode not supported on this platform.") + raise exceptions.OptionsError( + "Transparent mode not supported on this platform." + ) if self.is_running: asyncio.create_task(self.servers.update(modes)) async def setup_servers(self) -> bool: - return await self.servers.update([mode_specs.ProxyMode.parse(m) for m in ctx.options.mode]) + return await self.servers.update( + [mode_specs.ProxyMode.parse(m) for m in ctx.options.mode] + ) def listen_addrs(self) -> list[Address]: - return [ - addr - for server in self.servers - for addr in server.listen_addrs - ] + return [addr for server in self.servers for addr in server.listen_addrs] def inject_event(self, event: events.MessageInjected): connection_id = ( @@ -330,12 +344,7 @@ class Proxyserver(ServerManager): for listen_host, listen_port, *_ in server.listen_addrs: self_connect = ( connect_port == listen_port - and connect_host in ( - "localhost", - "127.0.0.1", - "::1", - listen_host - ) + and connect_host in ("localhost", "127.0.0.1", "::1", listen_host) and server.mode.transport_protocol == data.server.transport_protocol ) if self_connect: diff --git a/mitmproxy/addons/readfile.py b/mitmproxy/addons/readfile.py index f54708659..e9f0a5cbf 100644 --- a/mitmproxy/addons/readfile.py +++ b/mitmproxy/addons/readfile.py @@ -2,13 +2,14 @@ import asyncio import logging import os.path import sys -from typing import BinaryIO, Optional +from typing import BinaryIO +from typing import Optional +from mitmproxy import command from mitmproxy import ctx from mitmproxy import exceptions from mitmproxy import flowfilter from mitmproxy import io -from mitmproxy import command class ReadFile: diff --git a/mitmproxy/addons/save.py b/mitmproxy/addons/save.py index 2a3db39b7..f2b0b2133 100644 --- a/mitmproxy/addons/save.py +++ b/mitmproxy/addons/save.py @@ -5,10 +5,11 @@ from collections.abc import Sequence from datetime import datetime from functools import lru_cache from pathlib import Path -from typing import Literal, Optional +from typing import Literal +from typing import Optional import mitmproxy.types -from mitmproxy import command, tcp, udp +from mitmproxy import command from mitmproxy import ctx from mitmproxy import dns from mitmproxy import exceptions @@ -16,6 +17,8 @@ from mitmproxy import flow from mitmproxy import flowfilter from mitmproxy import http from mitmproxy import io +from mitmproxy import tcp +from mitmproxy import udp from mitmproxy.log import ALERT diff --git a/mitmproxy/addons/script.py b/mitmproxy/addons/script.py index 12586d327..3e5f0beff 100644 --- a/mitmproxy/addons/script.py +++ b/mitmproxy/addons/script.py @@ -1,21 +1,22 @@ import asyncio +import importlib.machinery +import importlib.util import logging import os -import importlib.util -import importlib.machinery import sys -import types import traceback +import types from collections.abc import Sequence from typing import Optional -from mitmproxy import addonmanager, hooks +import mitmproxy.types as mtypes +from mitmproxy import addonmanager +from mitmproxy import command +from mitmproxy import ctx +from mitmproxy import eventsequence from mitmproxy import exceptions from mitmproxy import flow -from mitmproxy import command -from mitmproxy import eventsequence -from mitmproxy import ctx -import mitmproxy.types as mtypes +from mitmproxy import hooks from mitmproxy.utils import asyncio_utils logger = logging.getLogger(__name__) diff --git a/mitmproxy/addons/serverplayback.py b/mitmproxy/addons/serverplayback.py index 1973547c2..7419f8ac5 100644 --- a/mitmproxy/addons/serverplayback.py +++ b/mitmproxy/addons/serverplayback.py @@ -1,14 +1,18 @@ import hashlib import logging import urllib -from collections.abc import Hashable, Sequence -from typing import Any, Optional +from collections.abc import Hashable +from collections.abc import Sequence +from typing import Any +from typing import Optional import mitmproxy.types -from mitmproxy import command, hooks -from mitmproxy import ctx, http +from mitmproxy import command +from mitmproxy import ctx from mitmproxy import exceptions from mitmproxy import flow +from mitmproxy import hooks +from mitmproxy import http from mitmproxy import io diff --git a/mitmproxy/addons/stickyauth.py b/mitmproxy/addons/stickyauth.py index 15d98c33d..bd3b4e49d 100644 --- a/mitmproxy/addons/stickyauth.py +++ b/mitmproxy/addons/stickyauth.py @@ -1,8 +1,8 @@ from typing import Optional +from mitmproxy import ctx from mitmproxy import exceptions from mitmproxy import flowfilter -from mitmproxy import ctx class StickyAuth: diff --git a/mitmproxy/addons/stickycookie.py b/mitmproxy/addons/stickycookie.py index ef33f5bc4..becaafa44 100644 --- a/mitmproxy/addons/stickycookie.py +++ b/mitmproxy/addons/stickycookie.py @@ -2,7 +2,10 @@ import collections from http import cookiejar from typing import Optional -from mitmproxy import http, flowfilter, ctx, exceptions +from mitmproxy import ctx +from mitmproxy import exceptions +from mitmproxy import flowfilter +from mitmproxy import http from mitmproxy.net.http import cookies TOrigin = tuple[str, int, str] @@ -31,7 +34,9 @@ def domain_match(a: str, b: str) -> bool: class StickyCookie: def __init__(self) -> None: - self.jar: collections.defaultdict[TOrigin, dict[str, str]] = collections.defaultdict(dict) + self.jar: collections.defaultdict[ + TOrigin, dict[str, str] + ] = collections.defaultdict(dict) self.flt: Optional[flowfilter.TFilter] = None def load(self, loader): diff --git a/mitmproxy/addons/termlog.py b/mitmproxy/addons/termlog.py index aa15ae39b..9da47e141 100644 --- a/mitmproxy/addons/termlog.py +++ b/mitmproxy/addons/termlog.py @@ -1,19 +1,17 @@ from __future__ import annotations + import asyncio import logging +import sys from typing import IO -import sys - -from mitmproxy import ctx, log +from mitmproxy import ctx +from mitmproxy import log from mitmproxy.utils import vt_codes class TermLog: - def __init__( - self, - out: IO[str] | None = None - ): + def __init__(self, out: IO[str] | None = None): self.logger = TermLogHandler(out) self.logger.install() @@ -41,10 +39,7 @@ class TermLog: class TermLogHandler(log.MitmLogHandler): - def __init__( - self, - out: IO[str] | None = None - ): + def __init__(self, out: IO[str] | None = None): super().__init__() self.file: IO[str] = out or sys.stdout self.has_vt_codes = vt_codes.ensure_supported(self.file) diff --git a/mitmproxy/addons/tlsconfig.py b/mitmproxy/addons/tlsconfig.py index 40bbee735..0d7faf52d 100644 --- a/mitmproxy/addons/tlsconfig.py +++ b/mitmproxy/addons/tlsconfig.py @@ -1,19 +1,28 @@ import ipaddress import logging import os -from pathlib import Path import ssl -from typing import Any, Optional, TypedDict +from pathlib import Path +from typing import Any +from typing import Optional +from typing import TypedDict from aioquic.h3.connection import H3_ALPN from aioquic.tls import CipherSuite -from OpenSSL import SSL, crypto -from mitmproxy import certs, ctx, exceptions, connection, tls +from OpenSSL import crypto +from OpenSSL import SSL + +from mitmproxy import certs +from mitmproxy import connection +from mitmproxy import ctx +from mitmproxy import exceptions +from mitmproxy import tls from mitmproxy.net import tls as net_tls from mitmproxy.options import CONF_BASENAME from mitmproxy.proxy import context from mitmproxy.proxy.layers import modes -from mitmproxy.proxy.layers import tls as proxy_tls, quic +from mitmproxy.proxy.layers import quic +from mitmproxy.proxy.layers import tls as proxy_tls # We manually need to specify this, otherwise OpenSSL may select a non-HTTP2 cipher by default. # https://ssl-config.mozilla.org/#config=old @@ -166,7 +175,9 @@ class TlsConfig: extra_chain_certs = [] ssl_ctx = net_tls.create_client_proxy_context( - method=net_tls.Method.DTLS_SERVER_METHOD if tls_start.is_dtls else net_tls.Method.TLS_SERVER_METHOD, + method=net_tls.Method.DTLS_SERVER_METHOD + if tls_start.is_dtls + else net_tls.Method.TLS_SERVER_METHOD, min_version=net_tls.Version[ctx.options.tls_version_client_min], max_version=net_tls.Version[ctx.options.tls_version_client_max], cipher_list=tuple(cipher_list), @@ -179,7 +190,9 @@ class TlsConfig: tls_start.ssl_conn = SSL.Connection(ssl_ctx) tls_start.ssl_conn.use_certificate(entry.cert.to_pyopenssl()) - tls_start.ssl_conn.use_privatekey(crypto.PKey.from_cryptography_key(entry.privatekey)) + tls_start.ssl_conn.use_privatekey( + crypto.PKey.from_cryptography_key(entry.privatekey) + ) # Force HTTP/1 for secure web proxies, we currently don't support CONNECT over HTTP/2. # There is a proof-of-concept branch at https://github.com/mhils/mitmproxy/tree/http2-proxy, @@ -256,7 +269,9 @@ class TlsConfig: client_cert = p ssl_ctx = net_tls.create_proxy_server_context( - method=net_tls.Method.DTLS_CLIENT_METHOD if tls_start.is_dtls else net_tls.Method.TLS_CLIENT_METHOD, + method=net_tls.Method.DTLS_CLIENT_METHOD + if tls_start.is_dtls + else net_tls.Method.TLS_CLIENT_METHOD, min_version=net_tls.Version[ctx.options.tls_version_server_min], max_version=net_tls.Version[ctx.options.tls_version_server_max], cipher_list=tuple(cipher_list), @@ -328,9 +343,8 @@ class TlsConfig: # if we don't have upstream ALPN, we allow all offered by the client tls_start.settings.alpn_protocols = [ alpn.decode("ascii") - for alpn in [ - alpn for alpn in (client.alpn, server.alpn) if alpn - ] or client.alpn_offers + for alpn in [alpn for alpn in (client.alpn, server.alpn) if alpn] + or client.alpn_offers ] # set the certificates diff --git a/mitmproxy/addons/upstream_auth.py b/mitmproxy/addons/upstream_auth.py index da9b39534..63b0d32bb 100644 --- a/mitmproxy/addons/upstream_auth.py +++ b/mitmproxy/addons/upstream_auth.py @@ -1,9 +1,9 @@ -import re import base64 +import re from typing import Optional -from mitmproxy import exceptions from mitmproxy import ctx +from mitmproxy import exceptions from mitmproxy import http from mitmproxy.proxy import mode_specs from mitmproxy.utils import strutils @@ -53,7 +53,10 @@ class UpstreamAuth: def requestheaders(self, f: http.HTTPFlow): if self.auth: - if isinstance(f.client_conn.proxy_mode, mode_specs.UpstreamMode) and f.request.scheme == "http": + if ( + isinstance(f.client_conn.proxy_mode, mode_specs.UpstreamMode) + and f.request.scheme == "http" + ): f.request.headers["Proxy-Authorization"] = self.auth elif isinstance(f.client_conn.proxy_mode, mode_specs.ReverseMode): f.request.headers["Authorization"] = self.auth diff --git a/mitmproxy/addons/view.py b/mitmproxy/addons/view.py index 965d9fc48..06049715f 100644 --- a/mitmproxy/addons/view.py +++ b/mitmproxy/addons/view.py @@ -11,25 +11,29 @@ The View: import collections import logging import re -from collections.abc import Iterator, MutableMapping, Sequence -from typing import Any, Optional +from collections.abc import Iterator +from collections.abc import MutableMapping +from collections.abc import Sequence +from typing import Any +from typing import Optional import sortedcontainers import mitmproxy.flow from mitmproxy import command +from mitmproxy import connection from mitmproxy import ctx from mitmproxy import dns from mitmproxy import exceptions -from mitmproxy import hooks -from mitmproxy import connection from mitmproxy import flowfilter +from mitmproxy import hooks from mitmproxy import http from mitmproxy import io from mitmproxy import tcp from mitmproxy import udp from mitmproxy.log import ALERT -from mitmproxy.utils import human, signals +from mitmproxy.utils import human +from mitmproxy.utils import signals # The underlying sorted list implementation expects the sort key to be stable @@ -144,7 +148,9 @@ def _sig_view_remove(flow: mitmproxy.flow.Flow, index: int) -> None: class View(collections.abc.Sequence): def __init__(self) -> None: super().__init__() - self._store: collections.OrderedDict[str, mitmproxy.flow.Flow] = collections.OrderedDict() + self._store: collections.OrderedDict[ + str, mitmproxy.flow.Flow + ] = collections.OrderedDict() self.filter = flowfilter.match_all # Should we show only marked flows? self.show_marked = False @@ -475,7 +481,11 @@ class View(collections.abc.Sequence): except ValueError as e: raise exceptions.CommandError("Invalid URL: %s" % e) - c = connection.Client(peername=("", 0), sockname=("", 0), timestamp_start=req.timestamp_start - 0.0001) + c = connection.Client( + peername=("", 0), + sockname=("", 0), + timestamp_start=req.timestamp_start - 0.0001, + ) s = connection.Server(address=(req.host, req.port)) f = http.HTTPFlow(c, s) diff --git a/mitmproxy/certs.py b/mitmproxy/certs.py index 70d847075..9b0a82489 100644 --- a/mitmproxy/certs.py +++ b/mitmproxy/certs.py @@ -6,15 +6,21 @@ import re import sys from dataclasses import dataclass from pathlib import Path -from typing import NewType, Optional, Union - -from cryptography import x509 -from cryptography.hazmat.primitives import hashes, serialization -from cryptography.hazmat.primitives.asymmetric import rsa, dsa, ec -from cryptography.hazmat.primitives.serialization import pkcs12 -from cryptography.x509 import NameOID, ExtendedKeyUsageOID +from typing import NewType +from typing import Optional +from typing import Union import OpenSSL +from cryptography import x509 +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import dsa +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.serialization import pkcs12 +from cryptography.x509 import ExtendedKeyUsageOID +from cryptography.x509 import NameOID + from mitmproxy.coretypes import serializable # Default expiry must not be too long: https://github.com/mitmproxy/mitmproxy/issues/815 @@ -315,7 +321,10 @@ class CertStore: self.default_chain_certs = ( [ Cert.from_pem(chunk) - for chunk in re.split(rb"(?=-----BEGIN( [A-Z]+)+-----)", self.default_chain_file.read_bytes()) + for chunk in re.split( + rb"(?=-----BEGIN( [A-Z]+)+-----)", + self.default_chain_file.read_bytes(), + ) if chunk.startswith(b"-----BEGIN CERTIFICATE-----") ] if self.default_chain_file diff --git a/mitmproxy/command.py b/mitmproxy/command.py index 950fa44ef..f62b93c22 100644 --- a/mitmproxy/command.py +++ b/mitmproxy/command.py @@ -4,16 +4,21 @@ import functools import inspect import logging - -import pyparsing import sys import textwrap import types -from collections.abc import Sequence, Callable, Iterable -from typing import Any, NamedTuple, Optional +from collections.abc import Callable +from collections.abc import Iterable +from collections.abc import Sequence +from typing import Any +from typing import NamedTuple +from typing import Optional + +import pyparsing import mitmproxy.types -from mitmproxy import exceptions, command_lexer +from mitmproxy import command_lexer +from mitmproxy import exceptions from mitmproxy.command_lexer import unquote @@ -195,7 +200,9 @@ class CommandManager: Parse a possibly partial command. Return a sequence of ParseResults and a sequence of remainder type help items. """ - parts: pyparsing.ParseResults = command_lexer.expr.parseString(cmdstr, parseAll=True) + parts: pyparsing.ParseResults = command_lexer.expr.parseString( + cmdstr, parseAll=True + ) parsed: list[ParseResult] = [] next_params: list[CommandParameter] = [ diff --git a/mitmproxy/connection.py b/mitmproxy/connection.py index 20782057a..e0e748ff0 100644 --- a/mitmproxy/connection.py +++ b/mitmproxy/connection.py @@ -1,18 +1,20 @@ import dataclasses import sys import time -from dataclasses import dataclass, field import uuid import warnings from abc import ABCMeta from collections.abc import Sequence +from dataclasses import dataclass +from dataclasses import field from enum import Flag -from typing import Literal, Optional +from typing import Literal +from typing import Optional from mitmproxy import certs from mitmproxy.coretypes import serializable -from mitmproxy.proxy import mode_specs from mitmproxy.net import server_spec +from mitmproxy.proxy import mode_specs from mitmproxy.utils import human @@ -48,12 +50,15 @@ class Connection(serializable.SerializableDataclass, metaclass=ABCMeta): The connection object only exposes metadata about the connection, but not the underlying socket object. This is intentional, all I/O should be handled by `mitmproxy.proxy.server` exclusively. """ + peername: Optional[Address] """The remote's `(ip, port)` tuple for this connection.""" sockname: Optional[Address] """Our local `(ip, port)` tuple for this connection.""" - state: ConnectionState = field(default=ConnectionState.CLOSED, metadata={"serialize": False}) + state: ConnectionState = field( + default=ConnectionState.CLOSED, metadata={"serialize": False} + ) """The current connection state.""" # all connections have a unique id. While @@ -177,7 +182,9 @@ class Client(Connection): The certificate used by mitmproxy to establish TLS with the client. """ - proxy_mode: mode_specs.ProxyMode = field(default=mode_specs.ProxyMode.parse("regular")) + proxy_mode: mode_specs.ProxyMode = field( + default=mode_specs.ProxyMode.parse("regular") + ) """The proxy server type this client has been connecting to.""" timestamp_start: float = field(default_factory=time.time) diff --git a/mitmproxy/contentviews/__init__.py b/mitmproxy/contentviews/__init__.py index 7c62b1f6f..2949c8fa3 100644 --- a/mitmproxy/contentviews/__init__.py +++ b/mitmproxy/contentviews/__init__.py @@ -12,37 +12,41 @@ metadata depend on the protocol in use. Known attributes can be found in `base.View`. """ import traceback -from typing import Union from typing import Optional +from typing import Union -from mitmproxy import flow, tcp, udp -from mitmproxy import http -from mitmproxy.utils import signals, strutils -from . import ( - auto, - raw, - hex, - json, - xml_html, - wbxml, - javascript, - css, - urlencoded, - multipart, - image, - query, - protobuf, - msgpack, - graphql, - grpc, - mqtt, - http3, -) - -from .base import View, KEY_MAX, format_text, format_dict, TViewResult +from . import auto +from . import css +from . import graphql +from . import grpc +from . import hex +from . import http3 +from . import image +from . import javascript +from . import json +from . import mqtt +from . import msgpack +from . import multipart +from . import protobuf +from . import query +from . import raw +from . import urlencoded +from . import wbxml +from . import xml_html from ..tcp import TCPMessage from ..udp import UDPMessage from ..websocket import WebSocketMessage +from .base import format_dict +from .base import format_text +from .base import KEY_MAX +from .base import TViewResult +from .base import View +from mitmproxy import flow +from mitmproxy import http +from mitmproxy import tcp +from mitmproxy import udp +from mitmproxy.utils import signals +from mitmproxy.utils import strutils views: list[View] = [] diff --git a/mitmproxy/contentviews/auto.py b/mitmproxy/contentviews/auto.py index d86dcf810..e8acfd94e 100644 --- a/mitmproxy/contentviews/auto.py +++ b/mitmproxy/contentviews/auto.py @@ -1,5 +1,5 @@ -from mitmproxy import contentviews from . import base +from mitmproxy import contentviews class ViewAuto(base.View): diff --git a/mitmproxy/contentviews/base.py b/mitmproxy/contentviews/base.py index d8baa9f25..9788eb688 100644 --- a/mitmproxy/contentviews/base.py +++ b/mitmproxy/contentviews/base.py @@ -1,7 +1,12 @@ # Default view cutoff *in lines* -from abc import ABC, abstractmethod -from collections.abc import Iterable, Iterator, Mapping -from typing import ClassVar, Optional, Union +from abc import ABC +from abc import abstractmethod +from collections.abc import Iterable +from collections.abc import Iterator +from collections.abc import Mapping +from typing import ClassVar +from typing import Optional +from typing import Union from mitmproxy import flow from mitmproxy import http diff --git a/mitmproxy/contentviews/graphql.py b/mitmproxy/contentviews/graphql.py index c179828e8..198082e80 100644 --- a/mitmproxy/contentviews/graphql.py +++ b/mitmproxy/contentviews/graphql.py @@ -1,8 +1,10 @@ import json -from typing import Any, Optional +from typing import Any +from typing import Optional from mitmproxy.contentviews import base -from mitmproxy.contentviews.json import parse_json, PARSE_ERROR +from mitmproxy.contentviews.json import PARSE_ERROR +from mitmproxy.contentviews.json import parse_json def format_graphql(data): diff --git a/mitmproxy/contentviews/grpc.py b/mitmproxy/contentviews/grpc.py index 332310af2..faa60079e 100644 --- a/mitmproxy/contentviews/grpc.py +++ b/mitmproxy/contentviews/grpc.py @@ -2,11 +2,17 @@ from __future__ import annotations import logging import struct -from dataclasses import dataclass, field +from collections.abc import Generator +from collections.abc import Iterable +from collections.abc import Iterator +from dataclasses import dataclass +from dataclasses import field from enum import Enum -from typing import Generator, Iterable, Iterator -from mitmproxy import contentviews, flow, flowfilter, http +from mitmproxy import contentviews +from mitmproxy import flow +from mitmproxy import flowfilter +from mitmproxy import http from mitmproxy.contentviews import base from mitmproxy.net.encoding import decode @@ -259,7 +265,9 @@ class ProtoParser: packed_field: ProtoParser.Field, ) -> list[ProtoParser.Field]: if not isinstance(packed_field.wire_value, bytes): - raise ValueError(f"can not unpack field with data other than bytes: {type(packed_field.wire_value)}") + raise ValueError( + f"can not unpack field with data other than bytes: {type(packed_field.wire_value)}" + ) wire_data: bytes = packed_field.wire_value tag: int = packed_field.tag options: ProtoParser.ParserOptions = packed_field.options @@ -953,7 +961,9 @@ def format_grpc( @dataclass class ViewConfig: - parser_options: ProtoParser.ParserOptions = field(default_factory=ProtoParser.ParserOptions) + parser_options: ProtoParser.ParserOptions = field( + default_factory=ProtoParser.ParserOptions + ) parser_rules: list[ProtoParser.ParserRule] = field(default_factory=list) diff --git a/mitmproxy/contentviews/hex.py b/mitmproxy/contentviews/hex.py index 5b53202c6..c5079929c 100644 --- a/mitmproxy/contentviews/hex.py +++ b/mitmproxy/contentviews/hex.py @@ -1,5 +1,5 @@ -from mitmproxy.utils import strutils from . import base +from mitmproxy.utils import strutils class ViewHex(base.View): diff --git a/mitmproxy/contentviews/http3.py b/mitmproxy/contentviews/http3.py index df014b99d..423772034 100644 --- a/mitmproxy/contentviews/http3.py +++ b/mitmproxy/contentviews/http3.py @@ -1,22 +1,27 @@ from collections import defaultdict from collections.abc import Iterator -from dataclasses import dataclass, field -from typing import Optional, Union +from dataclasses import dataclass +from dataclasses import field +from typing import Optional +from typing import Union -from aioquic.h3.connection import Setting, parse_settings - -from mitmproxy import flow, tcp -from . import base -from .hex import ViewHex -from ..proxy.layers.http import is_h3_alpn - -from aioquic.buffer import Buffer, BufferReadError import pylsqpack +from aioquic.buffer import Buffer +from aioquic.buffer import BufferReadError +from aioquic.h3.connection import parse_settings +from aioquic.h3.connection import Setting + +from . import base +from ..proxy.layers.http import is_h3_alpn +from .hex import ViewHex +from mitmproxy import flow +from mitmproxy import tcp @dataclass(frozen=True) class Frame: """Representation of an HTTP/3 frame.""" + type: int data: bytes @@ -27,10 +32,7 @@ class Frame: elif self.type == 1: try: hdrs = pylsqpack.Decoder(4096, 16).feed_header(0, self.data)[1] - return [ - [("header", "HEADERS Frame")], - *base.format_pairs(hdrs) - ] + return [[("header", "HEADERS Frame")], *base.format_pairs(hdrs)] except Exception as e: frame_name = f"HEADERS Frame (error: {e})" elif self.type == 4: @@ -46,10 +48,7 @@ class Frame: except ValueError: key = f"0x{k:x}" settings.append((key, f"0x{v:x}")) - return [ - [("header", "SETTINGS Frame")], - *base.format_pairs(settings) - ] + return [[("header", "SETTINGS Frame")], *base.format_pairs(settings)] return [ [("header", frame_name)], *ViewHex._format(self.data), @@ -59,6 +58,7 @@ class Frame: @dataclass(frozen=True) class StreamType: """Representation of an HTTP/3 stream types.""" + type: int def pretty(self): @@ -68,9 +68,7 @@ class StreamType: 0x02: "QPACK Encoder Stream", 0x03: "QPACK Decoder Stream", }.get(self.type, f"0x{self.type:x} Stream") - return [ - [("header", stream_type)] - ] + return [[("header", stream_type)]] @dataclass @@ -85,21 +83,23 @@ class ViewHttp3(base.View): name = "HTTP/3 Frames" def __init__(self) -> None: - self.connections: defaultdict[tcp.TCPFlow, ConnectionState] = defaultdict(ConnectionState) + self.connections: defaultdict[tcp.TCPFlow, ConnectionState] = defaultdict( + ConnectionState + ) def __call__( self, data, flow: Optional[flow.Flow] = None, tcp_message: Optional[tcp.TCPMessage] = None, - **metadata + **metadata, ): assert isinstance(flow, tcp.TCPFlow) assert tcp_message state = self.connections[flow] - for message in flow.messages[state.message_count:]: + for message in flow.messages[state.message_count :]: if message.from_client: buf = state.client_buf else: @@ -111,9 +111,7 @@ class ViewHttp3(base.View): stream_type = h3_buf.pull_uint_var() consumed = h3_buf.tell() del buf[:consumed] - state.frames[0] = [ - StreamType(stream_type) - ] + state.frames[0] = [StreamType(stream_type)] while True: h3_buf = Buffer(data=bytes(buf[:16])) @@ -128,29 +126,33 @@ class ViewHttp3(base.View): if len(buf) < consumed + frame_size: break - frame_data = bytes(buf[consumed:consumed + frame_size]) + frame_data = bytes(buf[consumed : consumed + frame_size]) frame = Frame(frame_type, frame_data) state.frames.setdefault(state.message_count, []).append(frame) - del buf[:consumed + frame_size] + del buf[: consumed + frame_size] state.message_count += 1 frames = state.frames.get(flow.messages.index(tcp_message), []) if not frames: - return "HTTP/3", [] # base.format_text(f"(no complete frames here, {state=})") + return ( + "HTTP/3", + [], + ) # base.format_text(f"(no complete frames here, {state=})") else: return "HTTP/3", fmt_frames(frames) def render_priority( - self, - data: bytes, - flow: Optional[flow.Flow] = None, - **metadata + self, data: bytes, flow: Optional[flow.Flow] = None, **metadata ) -> float: - return 2 * float(bool(flow and is_h3_alpn(flow.client_conn.alpn))) * float(isinstance(flow, tcp.TCPFlow)) + return ( + 2 + * float(bool(flow and is_h3_alpn(flow.client_conn.alpn))) + * float(isinstance(flow, tcp.TCPFlow)) + ) def fmt_frames(frames: list[Union[Frame, StreamType]]) -> Iterator[base.TViewLine]: diff --git a/mitmproxy/contentviews/image/view.py b/mitmproxy/contentviews/image/view.py index a414a1a7f..5d621133f 100644 --- a/mitmproxy/contentviews/image/view.py +++ b/mitmproxy/contentviews/image/view.py @@ -1,9 +1,9 @@ import imghdr from typing import Optional +from . import image_parser from mitmproxy.contentviews import base from mitmproxy.coretypes import multidict -from . import image_parser def test_ico(h, f): diff --git a/mitmproxy/contentviews/javascript.py b/mitmproxy/contentviews/javascript.py index de0466838..33ecac2ae 100644 --- a/mitmproxy/contentviews/javascript.py +++ b/mitmproxy/contentviews/javascript.py @@ -2,8 +2,8 @@ import io import re from typing import Optional -from mitmproxy.utils import strutils from mitmproxy.contentviews import base +from mitmproxy.utils import strutils DELIMITERS = "{};\n" SPECIAL_AREAS = ( diff --git a/mitmproxy/contentviews/json.py b/mitmproxy/contentviews/json.py index d8952e80b..23ec86a0d 100644 --- a/mitmproxy/contentviews/json.py +++ b/mitmproxy/contentviews/json.py @@ -1,8 +1,9 @@ -import re import json +import re from collections.abc import Iterator from functools import lru_cache -from typing import Any, Optional +from typing import Any +from typing import Optional from mitmproxy.contentviews import base @@ -28,7 +29,11 @@ def format_json(data: Any) -> Iterator[base.TViewLine]: yield current_line current_line = [] if re.match(r'\s*"', chunk): - if len(current_line) == 1 and current_line[0][0] == "text" and current_line[0][1].isspace(): + if ( + len(current_line) == 1 + and current_line[0][0] == "text" + and current_line[0][1].isspace() + ): current_line.append(("Token_Name_Tag", chunk)) else: current_line.append(("Token_Literal_String", chunk)) diff --git a/mitmproxy/contentviews/mqtt.py b/mitmproxy/contentviews/mqtt.py index 1b870341c..1c3b92a37 100644 --- a/mitmproxy/contentviews/mqtt.py +++ b/mitmproxy/contentviews/mqtt.py @@ -1,10 +1,9 @@ +import struct from typing import Optional from mitmproxy.contentviews import base from mitmproxy.utils import strutils -import struct - # from https://github.com/nikitastupin/mitmproxy-mqtt-script @@ -211,9 +210,13 @@ Password: {strutils.bytes_to_escaped_str(self.payload.get('Password', b'None'))} self.payload["WillTopic"] = f.decode("utf-8") elif self.connect_flags["Will"] and "WillMessage" not in self.payload: self.payload["WillMessage"] = f - elif self.connect_flags["UserName"] and "UserName" not in self.payload: # pragma: no cover + elif ( + self.connect_flags["UserName"] and "UserName" not in self.payload + ): # pragma: no cover self.payload["UserName"] = f.decode("utf-8") - elif self.connect_flags["Password"] and "Password" not in self.payload: # pragma: no cover + elif ( + self.connect_flags["Password"] and "Password" not in self.payload + ): # pragma: no cover self.payload["Password"] = f else: raise AssertionError(f"Unknown field in CONNECT payload: {f}") diff --git a/mitmproxy/contentviews/msgpack.py b/mitmproxy/contentviews/msgpack.py index 7e845bd11..92aeb8b39 100644 --- a/mitmproxy/contentviews/msgpack.py +++ b/mitmproxy/contentviews/msgpack.py @@ -1,8 +1,8 @@ -from typing import Any, Optional +from typing import Any +from typing import Optional import msgpack - from mitmproxy.contentviews import base PARSE_ERROR = object() @@ -15,14 +15,16 @@ def parse_msgpack(s: bytes) -> Any: return PARSE_ERROR -def format_msgpack(data: Any, output = None, indent_count: int = 0) -> list[base.TViewLine]: +def format_msgpack( + data: Any, output=None, indent_count: int = 0 +) -> list[base.TViewLine]: if output is None: output = [[]] indent = ("text", " " * indent_count) if type(data) is str: - token = [("Token_Literal_String", f"\"{data}\"")] + token = [("Token_Literal_String", f'"{data}"')] output[-1] += token # Need to return if single value, but return is discarded in dict/list loop @@ -43,7 +45,14 @@ def format_msgpack(data: Any, output = None, indent_count: int = 0) -> list[base elif type(data) is dict: output[-1] += [("text", "{")] for key in data: - output.append([indent, ("text", " "), ("Token_Name_Tag", f'"{key}"'), ("text", ": ")]) + output.append( + [ + indent, + ("text", " "), + ("Token_Name_Tag", f'"{key}"'), + ("text", ": "), + ] + ) format_msgpack(data[key], output, indent_count + 1) if key != list(data)[-1]: diff --git a/mitmproxy/contentviews/multipart.py b/mitmproxy/contentviews/multipart.py index 9485824ca..a8cef5f66 100644 --- a/mitmproxy/contentviews/multipart.py +++ b/mitmproxy/contentviews/multipart.py @@ -1,8 +1,8 @@ from typing import Optional +from . import base from mitmproxy.coretypes import multidict from mitmproxy.net.http import multipart -from . import base class ViewMultipart(base.View): diff --git a/mitmproxy/contentviews/protobuf.py b/mitmproxy/contentviews/protobuf.py index 50d349eb5..0d836be10 100644 --- a/mitmproxy/contentviews/protobuf.py +++ b/mitmproxy/contentviews/protobuf.py @@ -2,6 +2,7 @@ import io from typing import Optional from kaitaistruct import KaitaiStream + from . import base from mitmproxy.contrib.kaitaistruct import google_protobuf @@ -26,7 +27,9 @@ def _parse_proto(raw: bytes) -> list[google_protobuf.GoogleProtobuf.Pair]: """Parse a bytestring into protobuf pairs and make sure that all pairs have a valid wire type.""" buf = google_protobuf.GoogleProtobuf(KaitaiStream(io.BytesIO(raw))) for pair in buf.pairs: - if not isinstance(pair.wire_type, google_protobuf.GoogleProtobuf.Pair.WireTypes): + if not isinstance( + pair.wire_type, google_protobuf.GoogleProtobuf.Pair.WireTypes + ): raise ValueError("Not a protobuf.") return buf.pairs diff --git a/mitmproxy/contentviews/raw.py b/mitmproxy/contentviews/raw.py index a0b0884ec..c19872534 100644 --- a/mitmproxy/contentviews/raw.py +++ b/mitmproxy/contentviews/raw.py @@ -1,5 +1,5 @@ -from mitmproxy.utils import strutils from . import base +from mitmproxy.utils import strutils class ViewRaw(base.View): diff --git a/mitmproxy/contentviews/urlencoded.py b/mitmproxy/contentviews/urlencoded.py index 2988d8527..27e4e83d2 100644 --- a/mitmproxy/contentviews/urlencoded.py +++ b/mitmproxy/contentviews/urlencoded.py @@ -1,7 +1,7 @@ from typing import Optional -from mitmproxy.net.http import url from . import base +from mitmproxy.net.http import url class ViewURLEncoded(base.View): diff --git a/mitmproxy/contentviews/wbxml.py b/mitmproxy/contentviews/wbxml.py index 4cd7fda89..3faaa86d2 100644 --- a/mitmproxy/contentviews/wbxml.py +++ b/mitmproxy/contentviews/wbxml.py @@ -1,7 +1,7 @@ from typing import Optional -from mitmproxy.contrib.wbxml import ASCommandResponse from . import base +from mitmproxy.contrib.wbxml import ASCommandResponse class ViewWBXML(base.View): diff --git a/mitmproxy/contentviews/xml_html.py b/mitmproxy/contentviews/xml_html.py index b8e8f05f6..747b5b9c3 100644 --- a/mitmproxy/contentviews/xml_html.py +++ b/mitmproxy/contentviews/xml_html.py @@ -1,10 +1,12 @@ import io import re import textwrap -from typing import Iterable, Optional +from collections.abc import Iterable +from typing import Optional from mitmproxy.contentviews import base -from mitmproxy.utils import sliding_window, strutils +from mitmproxy.utils import sliding_window +from mitmproxy.utils import strutils """ A custom XML/HTML prettifier. Compared to other prettifiers, its main features are: @@ -46,7 +48,7 @@ class Token: self.data = data def __repr__(self): - return "{}({})".format(type(self).__name__, self.data) + return f"{type(self).__name__}({self.data})" class Text(Token): diff --git a/mitmproxy/coretypes/multidict.py b/mitmproxy/coretypes/multidict.py index 6710346e6..8a90c7327 100644 --- a/mitmproxy/coretypes/multidict.py +++ b/mitmproxy/coretypes/multidict.py @@ -1,6 +1,8 @@ from abc import ABCMeta from abc import abstractmethod -from collections.abc import Iterator, MutableMapping, Sequence +from collections.abc import Iterator +from collections.abc import MutableMapping +from collections.abc import Sequence from typing import TypeVar from mitmproxy.coretypes import serializable diff --git a/mitmproxy/coretypes/serializable.py b/mitmproxy/coretypes/serializable.py index 8bfd959d1..f91d202c8 100644 --- a/mitmproxy/coretypes/serializable.py +++ b/mitmproxy/coretypes/serializable.py @@ -10,8 +10,10 @@ from typing import TypeVar try: from types import UnionType, NoneType except ImportError: # pragma: no cover + class UnionType: # type: ignore pass + NoneType = type(None) # type: ignore T = TypeVar("T", bound="Serializable") @@ -59,7 +61,6 @@ U = TypeVar("U", bound="SerializableDataclass") class SerializableDataclass(Serializable): - @classmethod @cache def __fields(cls) -> tuple[dataclasses.Field, ...]: @@ -111,7 +112,9 @@ class SerializableDataclass(Serializable): raise if state: - raise ValueError(f"Unexpected fields in {type(self).__name__}.set_state: {state}") + raise ValueError( + f"Unexpected fields in {type(self).__name__}.set_state: {state}" + ) V = TypeVar("V") @@ -121,11 +124,15 @@ def _process(attr_val: typing.Any, attr_type: type[V], attr_name: str, make: boo origin = typing.get_origin(attr_type) if origin is typing.Literal: if attr_val not in typing.get_args(attr_type): - raise ValueError(f"Invalid value for {attr_name}: {attr_val!r} does not match any literal value.") + raise ValueError( + f"Invalid value for {attr_name}: {attr_val!r} does not match any literal value." + ) return attr_val if origin in (UnionType, typing.Union): attr_type, nt = typing.get_args(attr_type) - assert nt is NoneType, f"{attr_name}: only `x | None` union types are supported`" # noqa + assert ( + nt is NoneType + ), f"{attr_name}: only `x | None` union types are supported`" if attr_val is None: return None # type: ignore else: @@ -146,24 +153,32 @@ def _process(attr_val: typing.Any, attr_type: type[V], attr_name: str, make: boo # We don't have a good way to represent tuple[str,int] | tuple[str,int,int,int], so we do a dirty hack here. if attr_name in ("peername", "sockname"): return tuple( - _process(x, T, attr_name, make) for x, T in zip(attr_val, [str, int, int, int]) + _process(x, T, attr_name, make) + for x, T in zip(attr_val, [str, int, int, int]) ) # type: ignore Ts = typing.get_args(attr_type) if len(Ts) != len(attr_val): - raise ValueError(f"Invalid data for {attr_name}. Expected {Ts}, got {attr_val}.") + raise ValueError( + f"Invalid data for {attr_name}. Expected {Ts}, got {attr_val}." + ) return tuple(_process(x, T, attr_name, make) for T, x in zip(Ts, attr_val)) # type: ignore elif origin is dict: k_cls, v_cls = typing.get_args(attr_type) return { - _process(k, k_cls, attr_name, make): _process(v, v_cls, attr_name, make) for k, v in attr_val.items() + _process(k, k_cls, attr_name, make): _process(v, v_cls, attr_name, make) + for k, v in attr_val.items() } # type: ignore elif attr_type in (int, float): if not isinstance(attr_val, (int, float)): - raise ValueError(f"Invalid value for {attr_name}. Expected {attr_type}, got {attr_val} ({type(attr_val)}).") + raise ValueError( + f"Invalid value for {attr_name}. Expected {attr_type}, got {attr_val} ({type(attr_val)})." + ) return attr_type(attr_val) # type: ignore elif attr_type in (str, bytes, bool): if not isinstance(attr_val, attr_type): - raise ValueError(f"Invalid value for {attr_name}. Expected {attr_type}, got {attr_val} ({type(attr_val)}).") + raise ValueError( + f"Invalid value for {attr_name}. Expected {attr_type}, got {attr_val} ({type(attr_val)})." + ) return attr_type(attr_val) # type: ignore elif isinstance(attr_type, type) and issubclass(attr_type, enum.Enum): if make: diff --git a/mitmproxy/dns.py b/mitmproxy/dns.py index 8ccd48c05..8d0e2879d 100644 --- a/mitmproxy/dns.py +++ b/mitmproxy/dns.py @@ -1,15 +1,21 @@ from __future__ import annotations -from dataclasses import dataclass + import itertools import random import struct -from ipaddress import IPv4Address, IPv6Address import time +from dataclasses import dataclass +from ipaddress import IPv4Address +from ipaddress import IPv6Address from typing import ClassVar from mitmproxy import flow from mitmproxy.coretypes import serializable -from mitmproxy.net.dns import classes, domain_names, op_codes, response_codes, types +from mitmproxy.net.dns import classes +from mitmproxy.net.dns import domain_names +from mitmproxy.net.dns import op_codes +from mitmproxy.net.dns import response_codes +from mitmproxy.net.dns import types # DNS parameters taken from https://www.iana.org/assignments/dns-parameters/dns-parameters.xml diff --git a/mitmproxy/eventsequence.py b/mitmproxy/eventsequence.py index b00feaa34..57cf241d0 100644 --- a/mitmproxy/eventsequence.py +++ b/mitmproxy/eventsequence.py @@ -1,4 +1,6 @@ -from typing import Any, Callable, Iterator +from collections.abc import Iterator +from typing import Any +from typing import Callable from mitmproxy import dns from mitmproxy import flow diff --git a/mitmproxy/flow.py b/mitmproxy/flow.py index 69514f298..889e3ae44 100644 --- a/mitmproxy/flow.py +++ b/mitmproxy/flow.py @@ -1,10 +1,14 @@ from __future__ import annotations + import asyncio import copy import time import uuid -from dataclasses import dataclass, field -from typing import Any, ClassVar, Optional +from dataclasses import dataclass +from dataclasses import field +from typing import Any +from typing import ClassVar +from typing import Optional from mitmproxy import connection from mitmproxy import exceptions @@ -128,7 +132,9 @@ class Flow(serializable.Serializable): __types: dict[str, type[Flow]] = {} - type: ClassVar[str] # automatically derived from the class name in __init_subclass__ + type: ClassVar[ + str + ] # automatically derived from the class name in __init_subclass__ """The flow type, for example `http`, `tcp`, or `dns`.""" def __init_subclass__(cls, **kwargs): diff --git a/mitmproxy/flowfilter.py b/mitmproxy/flowfilter.py index aaec9e1f5..a596e288a 100644 --- a/mitmproxy/flowfilter.py +++ b/mitmproxy/flowfilter.py @@ -32,15 +32,21 @@ ~c CODE Response code. rex Equivalent to ~u rex """ - import functools import re import sys from collections.abc import Sequence -from typing import ClassVar, Protocol, Union +from typing import ClassVar +from typing import Protocol +from typing import Union + import pyparsing as pp -from mitmproxy import dns, flow, http, tcp, udp +from mitmproxy import dns +from mitmproxy import flow +from mitmproxy import http +from mitmproxy import tcp +from mitmproxy import udp def only(*types): @@ -288,10 +294,16 @@ class FBod(_Rex): @only(http.HTTPFlow, tcp.TCPFlow, udp.UDPFlow, dns.DNSFlow) def __call__(self, f): if isinstance(f, http.HTTPFlow): - if f.request and (content := f.request.get_content(strict=False)) is not None: + if ( + f.request + and (content := f.request.get_content(strict=False)) is not None + ): if self.re.search(content): return True - if f.response and (content := f.response.get_content(strict=False)) is not None: + if ( + f.response + and (content := f.response.get_content(strict=False)) is not None + ): if self.re.search(content): return True if f.websocket: @@ -318,7 +330,10 @@ class FBodRequest(_Rex): @only(http.HTTPFlow, tcp.TCPFlow, udp.UDPFlow, dns.DNSFlow) def __call__(self, f): if isinstance(f, http.HTTPFlow): - if f.request and (content := f.request.get_content(strict=False)) is not None: + if ( + f.request + and (content := f.request.get_content(strict=False)) is not None + ): if self.re.search(content): return True if f.websocket: @@ -342,7 +357,10 @@ class FBodResponse(_Rex): @only(http.HTTPFlow, tcp.TCPFlow, udp.UDPFlow, dns.DNSFlow) def __call__(self, f): if isinstance(f, http.HTTPFlow): - if f.response and (content := f.response.get_content(strict=False)) is not None: + if ( + f.response + and (content := f.response.get_content(strict=False)) is not None + ): if self.re.search(content): return True if f.websocket: diff --git a/mitmproxy/hooks.py b/mitmproxy/hooks.py index 2c6c8574e..4acb92669 100644 --- a/mitmproxy/hooks.py +++ b/mitmproxy/hooks.py @@ -1,8 +1,12 @@ import re import warnings from collections.abc import Sequence -from dataclasses import dataclass, is_dataclass, fields -from typing import Any, ClassVar, TYPE_CHECKING +from dataclasses import dataclass +from dataclasses import fields +from dataclasses import is_dataclass +from typing import Any +from typing import ClassVar +from typing import TYPE_CHECKING import mitmproxy.flow diff --git a/mitmproxy/http.py b/mitmproxy/http.py index 640948044..f9f32beb4 100644 --- a/mitmproxy/http.py +++ b/mitmproxy/http.py @@ -1,26 +1,25 @@ import binascii +import json import os import re import time import urllib.parse -import json import warnings +from collections.abc import Iterable +from collections.abc import Iterator +from collections.abc import Mapping from dataclasses import dataclass from dataclasses import fields from email.utils import formatdate from email.utils import mktime_tz from email.utils import parsedate_tz +from typing import Any from typing import Callable -from typing import Iterable -from typing import Iterator -from typing import Mapping +from typing import cast from typing import Optional from typing import Union -from typing import cast -from typing import Any from mitmproxy import flow -from mitmproxy.websocket import WebSocketData from mitmproxy.coretypes import multidict from mitmproxy.coretypes import serializable from mitmproxy.net import encoding @@ -35,6 +34,7 @@ from mitmproxy.utils import strutils from mitmproxy.utils import typecheck from mitmproxy.utils.strutils import always_bytes from mitmproxy.utils.strutils import always_str +from mitmproxy.websocket import WebSocketData # While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded. @@ -1262,7 +1262,9 @@ class HTTPFlow(flow.Flow): def set_state(self, state: serializable.State) -> None: self.request = Request.from_state(state.pop("request")) self.response = Response.from_state(r) if (r := state.pop("response")) else None - self.websocket = WebSocketData.from_state(w) if (w := state.pop("websocket")) else None + self.websocket = ( + WebSocketData.from_state(w) if (w := state.pop("websocket")) else None + ) super().set_state(state) def __repr__(self): diff --git a/mitmproxy/io/__init__.py b/mitmproxy/io/__init__.py index 8d068d569..541f743ab 100644 --- a/mitmproxy/io/__init__.py +++ b/mitmproxy/io/__init__.py @@ -1,4 +1,7 @@ -from .io import FlowWriter, FlowReader, FilteredFlowWriter, read_flows_from_paths +from .io import FilteredFlowWriter +from .io import FlowReader +from .io import FlowWriter +from .io import read_flows_from_paths __all__ = ["FlowWriter", "FlowReader", "FilteredFlowWriter", "read_flows_from_paths"] diff --git a/mitmproxy/io/compat.py b/mitmproxy/io/compat.py index 466c14e03..9a923c67a 100644 --- a/mitmproxy/io/compat.py +++ b/mitmproxy/io/compat.py @@ -7,7 +7,8 @@ version number, this prevents issues with developer builds and snapshots. """ import copy import uuid -from typing import Any, Union +from typing import Any +from typing import Union from mitmproxy import version from mitmproxy.utils import strutils @@ -406,7 +407,9 @@ def convert_18_19(data): for name in ["peername", "sockname", "address"]: if data[conn].get(name) and isinstance(data[conn][name][0], bytes): - data[conn][name][0] = data[conn][name][0].decode(errors="backslashreplace") + data[conn][name][0] = data[conn][name][0].decode( + errors="backslashreplace" + ) if data["server_conn"]["sni"] is True: data["server_conn"]["sni"] = data["server_conn"]["address"][0] diff --git a/mitmproxy/io/io.py b/mitmproxy/io/io.py index cd2b095bb..f5957bf72 100644 --- a/mitmproxy/io/io.py +++ b/mitmproxy/io/io.py @@ -1,5 +1,9 @@ import os -from typing import Any, BinaryIO, Iterable, Union, cast +from collections.abc import Iterable +from typing import Any +from typing import BinaryIO +from typing import cast +from typing import Union from mitmproxy import exceptions from mitmproxy import flow diff --git a/mitmproxy/io/tnetstring.py b/mitmproxy/io/tnetstring.py index e08a729eb..b11580e9e 100644 --- a/mitmproxy/io/tnetstring.py +++ b/mitmproxy/io/tnetstring.py @@ -39,9 +39,9 @@ all other strings are returned as plain bytes. :License: MIT """ - import collections -from typing import BinaryIO, Union +from typing import BinaryIO +from typing import Union TSerializable = Union[None, str, bool, int, float, bytes, list, tuple, dict] diff --git a/mitmproxy/log.py b/mitmproxy/log.py index 05b609c7d..d56d5f29d 100644 --- a/mitmproxy/log.py +++ b/mitmproxy/log.py @@ -1,10 +1,12 @@ from __future__ import annotations + import logging import os import warnings from dataclasses import dataclass -from mitmproxy import hooks, master +from mitmproxy import hooks +from mitmproxy import master from mitmproxy.contrib import click as miniclick from mitmproxy.utils import human @@ -42,7 +44,7 @@ class MitmFormatter(logging.Formatter): self.without_client = f"{time} %s" default_time_format = "%H:%M:%S" - default_msec_format = '%s.%03d' + default_msec_format = "%s.%03d" def format(self, record: logging.LogRecord) -> str: time = self.formatTime(record) @@ -67,19 +69,18 @@ class MitmLogHandler(logging.Handler): def filter(self, record: logging.LogRecord) -> bool: # We can't remove stale handlers here because that would modify .handlers during iteration! - return ( - super().filter(record) - and - ( - not self._initiated_in_test - or self._initiated_in_test == os.environ.get("PYTEST_CURRENT_TEST") - ) + return super().filter(record) and ( + not self._initiated_in_test + or self._initiated_in_test == os.environ.get("PYTEST_CURRENT_TEST") ) def install(self) -> None: if self._initiated_in_test: for h in list(logging.getLogger().handlers): - if isinstance(h, MitmLogHandler) and h._initiated_in_test != self._initiated_in_test: + if ( + isinstance(h, MitmLogHandler) + and h._initiated_in_test != self._initiated_in_test + ): h.uninstall() logging.getLogger().addHandler(self) @@ -90,6 +91,7 @@ class MitmLogHandler(logging.Handler): # everything below is deprecated! + class LogEntry: def __init__(self, msg, level): # it's important that we serialize to string here already so that we don't pick up changes @@ -194,6 +196,7 @@ LOGGING_LEVELS_TO_LOGENTRY = { class LegacyLogEvents(MitmLogHandler): """Emit deprecated `add_log` events from stdlib logging.""" + def __init__( self, master: master.Master, diff --git a/mitmproxy/master.py b/mitmproxy/master.py index 0947e6eb0..a8e6bfc34 100644 --- a/mitmproxy/master.py +++ b/mitmproxy/master.py @@ -3,14 +3,15 @@ import logging import traceback from typing import Optional -from mitmproxy import addonmanager, hooks +from . import ctx as mitmproxy_ctx +from .proxy.mode_specs import ReverseMode +from mitmproxy import addonmanager from mitmproxy import command from mitmproxy import eventsequence +from mitmproxy import hooks from mitmproxy import http from mitmproxy import log from mitmproxy import options -from . import ctx as mitmproxy_ctx -from .proxy.mode_specs import ReverseMode logger = logging.getLogger(__name__) @@ -22,7 +23,11 @@ class Master: event_loop: asyncio.AbstractEventLoop - def __init__(self, opts: options.Options, event_loop: Optional[asyncio.AbstractEventLoop] = None): + def __init__( + self, + opts: options.Options, + event_loop: Optional[asyncio.AbstractEventLoop] = None, + ): self.options: options.Options = opts or options.Options() self.commands = command.CommandManager(self) self.addons = addonmanager.AddonManager(self) @@ -102,7 +107,11 @@ class Master: Loads a flow """ - if isinstance(f, http.HTTPFlow) and len(self.options.mode) == 1 and self.options.mode[0].startswith("reverse:"): + if ( + isinstance(f, http.HTTPFlow) + and len(self.options.mode) == 1 + and self.options.mode[0].startswith("reverse:") + ): # When we load flows in reverse proxy mode, we adjust the target host to # the reverse proxy destination for all flows we load. This makes it very # easy to replay saved flows against a different host. diff --git a/mitmproxy/net/check.py b/mitmproxy/net/check.py index 9a0bdec49..476170032 100644 --- a/mitmproxy/net/check.py +++ b/mitmproxy/net/check.py @@ -1,11 +1,11 @@ import ipaddress import re +from typing import AnyStr # Allow underscore in host name # Note: This could be a DNS label, a hostname, a FQDN, or an IP -from typing import AnyStr -_label_valid = re.compile(br"[A-Z\d\-_]{1,63}$", re.IGNORECASE) +_label_valid = re.compile(rb"[A-Z\d\-_]{1,63}$", re.IGNORECASE) def is_valid_host(host: AnyStr) -> bool: diff --git a/mitmproxy/net/encoding.py b/mitmproxy/net/encoding.py index 32e61b62c..29553651f 100644 --- a/mitmproxy/net/encoding.py +++ b/mitmproxy/net/encoding.py @@ -1,13 +1,13 @@ """ Utility functions for decoding response bodies. """ - import codecs import collections import gzip import zlib from io import BytesIO -from typing import Union, overload +from typing import overload +from typing import Union import brotli import zstandard as zstd @@ -184,7 +184,7 @@ def decode_zstd(content: bytes) -> bytes: except zstd.ZstdError: # If the zstd stream is streamed without a size header, # try decoding with a 10MiB output buffer - return zstd_ctx.decompress(content, max_output_size=10 * 2 ** 20) + return zstd_ctx.decompress(content, max_output_size=10 * 2**20) def encode_zstd(content: bytes) -> bytes: diff --git a/mitmproxy/net/http/cookies.py b/mitmproxy/net/http/cookies.py index 4b2ddd941..3e961ae83 100644 --- a/mitmproxy/net/http/cookies.py +++ b/mitmproxy/net/http/cookies.py @@ -1,7 +1,7 @@ import email.utils import re import time -from typing import Iterable +from collections.abc import Iterable from mitmproxy.coretypes import multidict diff --git a/mitmproxy/net/http/headers.py b/mitmproxy/net/http/headers.py index e3c00994a..6204040aa 100644 --- a/mitmproxy/net/http/headers.py +++ b/mitmproxy/net/http/headers.py @@ -33,4 +33,4 @@ def assemble_content_type(type, subtype, parameters): if not parameters: return f"{type}/{subtype}" params = "; ".join(f"{k}={v}" for k, v in parameters.items()) - return "{}/{}; {}".format(type, subtype, params) + return f"{type}/{subtype}; {params}" diff --git a/mitmproxy/net/http/http1/__init__.py b/mitmproxy/net/http/http1/__init__.py index 3049e02fb..b9b6e071e 100644 --- a/mitmproxy/net/http/http1/__init__.py +++ b/mitmproxy/net/http/http1/__init__.py @@ -1,17 +1,13 @@ -from .read import ( - read_request_head, - read_response_head, - connection_close, - expected_http_body_size, - validate_headers, -) -from .assemble import ( - assemble_request, - assemble_request_head, - assemble_response, - assemble_response_head, - assemble_body, -) +from .assemble import assemble_body +from .assemble import assemble_request +from .assemble import assemble_request_head +from .assemble import assemble_response +from .assemble import assemble_response_head +from .read import connection_close +from .read import expected_http_body_size +from .read import read_request_head +from .read import read_response_head +from .read import validate_headers __all__ = [ diff --git a/mitmproxy/net/http/http1/read.py b/mitmproxy/net/http/http1/read.py index 1da4583e1..2986c489d 100644 --- a/mitmproxy/net/http/http1/read.py +++ b/mitmproxy/net/http/http1/read.py @@ -1,8 +1,11 @@ import re import time -from typing import Iterable, Optional +from collections.abc import Iterable +from typing import Optional -from mitmproxy.http import Request, Headers, Response +from mitmproxy.http import Headers +from mitmproxy.http import Request +from mitmproxy.http import Response from mitmproxy.net.http import url @@ -214,7 +217,7 @@ def expected_http_body_size( def raise_if_http_version_unknown(http_version: bytes) -> None: - if not re.match(br"^HTTP/\d\.\d$", http_version): + if not re.match(rb"^HTTP/\d\.\d$", http_version): raise ValueError(f"Unknown HTTP version: {http_version!r}") diff --git a/mitmproxy/net/http/multipart.py b/mitmproxy/net/http/multipart.py index 007999587..4685d80e0 100644 --- a/mitmproxy/net/http/multipart.py +++ b/mitmproxy/net/http/multipart.py @@ -56,7 +56,7 @@ def decode(content_type: Optional[str], content: bytes) -> list[tuple[bytes, byt except (KeyError, UnicodeError): return [] - rx = re.compile(br'\bname="([^"]+)"') + rx = re.compile(rb'\bname="([^"]+)"') r = [] if content is not None: for i in content.split(b"--" + boundary): diff --git a/mitmproxy/net/http/url.py b/mitmproxy/net/http/url.py index 274f229fb..abc038abf 100644 --- a/mitmproxy/net/http/url.py +++ b/mitmproxy/net/http/url.py @@ -1,16 +1,19 @@ from __future__ import annotations + import re import urllib.parse from collections.abc import Sequence -from typing import AnyStr, Optional +from typing import AnyStr +from typing import Optional from mitmproxy.net import check +from mitmproxy.net.check import is_valid_host +from mitmproxy.net.check import is_valid_port +from mitmproxy.utils.strutils import always_str # This regex extracts & splits the host header into host and port. # Handles the edge case of IPv6 addresses containing colons. # https://bugzilla.mozilla.org/show_bug.cgi?id=45891 -from mitmproxy.net.check import is_valid_host, is_valid_port -from mitmproxy.utils.strutils import always_str _authority_re = re.compile(r"^(?P[^:]+|\[.+\])(?::(?P\d+))?$") diff --git a/mitmproxy/net/http/user_agents.py b/mitmproxy/net/http/user_agents.py index 58aa21eab..6a83f8fb7 100644 --- a/mitmproxy/net/http/user_agents.py +++ b/mitmproxy/net/http/user_agents.py @@ -2,9 +2,7 @@ A small collection of useful user-agent header strings. These should be kept reasonably current to reflect common usage. """ - # pylint: line-too-long - # A collection of (name, shortcut, string) tuples. UASTRINGS = [ diff --git a/mitmproxy/net/local_ip.py b/mitmproxy/net/local_ip.py index 27468c05c..bc3087263 100644 --- a/mitmproxy/net/local_ip.py +++ b/mitmproxy/net/local_ip.py @@ -1,4 +1,5 @@ from __future__ import annotations + import socket diff --git a/mitmproxy/net/server_spec.py b/mitmproxy/net/server_spec.py index 945565f19..f0d5edd60 100644 --- a/mitmproxy/net/server_spec.py +++ b/mitmproxy/net/server_spec.py @@ -9,7 +9,7 @@ from mitmproxy.net import check ServerSpec = tuple[ Literal["http", "https", "http3", "tls", "dtls", "tcp", "udp", "dns", "quic"], - tuple[str, int] + tuple[str, int], ] server_spec_re = re.compile( @@ -45,7 +45,17 @@ def parse(server_spec: str, default_scheme: str) -> ServerSpec: scheme = m.group("scheme") else: scheme = default_scheme - if scheme not in ("http", "https", "http3", "tls", "dtls", "tcp", "udp", "dns", "quic"): + if scheme not in ( + "http", + "https", + "http3", + "tls", + "dtls", + "tcp", + "udp", + "dns", + "quic", + ): raise ValueError(f"Invalid server scheme: {scheme}") host = m.group("host") diff --git a/mitmproxy/net/tls.py b/mitmproxy/net/tls.py index d87dc1a67..59fd229e3 100644 --- a/mitmproxy/net/tls.py +++ b/mitmproxy/net/tls.py @@ -1,15 +1,18 @@ import os import threading +from collections.abc import Iterable from enum import Enum from functools import lru_cache from pathlib import Path -from typing import Any, BinaryIO, Callable, Iterable, Optional +from typing import Any +from typing import BinaryIO +from typing import Callable +from typing import Optional import certifi - +from OpenSSL import SSL from OpenSSL.crypto import X509 -from OpenSSL import SSL from mitmproxy import certs @@ -18,8 +21,8 @@ class Method(Enum): TLS_SERVER_METHOD = SSL.TLS_SERVER_METHOD TLS_CLIENT_METHOD = SSL.TLS_CLIENT_METHOD # Type-pyopenssl does not know about these DTLS constants. - DTLS_SERVER_METHOD = SSL.DTLS_SERVER_METHOD # type: ignore - DTLS_CLIENT_METHOD = SSL.DTLS_CLIENT_METHOD # type: ignore + DTLS_SERVER_METHOD = SSL.DTLS_SERVER_METHOD # type: ignore + DTLS_CLIENT_METHOD = SSL.DTLS_CLIENT_METHOD # type: ignore try: diff --git a/mitmproxy/net/udp.py b/mitmproxy/net/udp.py index 00565f6b4..d51aee123 100644 --- a/mitmproxy/net/udp.py +++ b/mitmproxy/net/udp.py @@ -3,7 +3,11 @@ from __future__ import annotations import asyncio import logging import socket -from typing import Any, Callable, Optional, Union, cast +from typing import Any +from typing import Callable +from typing import cast +from typing import Optional +from typing import Union from mitmproxy.connection import Address from mitmproxy.net import udp_wireguard @@ -183,7 +187,9 @@ class DatagramWriter: self._closed = None @property - def _protocol(self) -> DrainableDatagramProtocol | udp_wireguard.WireGuardDatagramTransport: + def _protocol( + self, + ) -> DrainableDatagramProtocol | udp_wireguard.WireGuardDatagramTransport: return self._transport.get_protocol() # type: ignore def write(self, data: bytes) -> None: diff --git a/mitmproxy/options.py b/mitmproxy/options.py index 2ca6ccb58..f3cf11bed 100644 --- a/mitmproxy/options.py +++ b/mitmproxy/options.py @@ -90,11 +90,19 @@ class Options(optmanager.OptManager): """, ) self.add_option("allow_hosts", Sequence[str], [], "Opposite of --ignore-hosts.") - self.add_option("listen_host", str, "", - "Address to bind proxy server(s) to (may be overridden for individual modes, see `mode`).") - self.add_option("listen_port", Optional[int], None, - "Port to bind proxy server(s) to (may be overridden for individual modes, see `mode`). " - "By default, the port is mode-specific. The default regular HTTP proxy spawns on port 8080.") + self.add_option( + "listen_host", + str, + "", + "Address to bind proxy server(s) to (may be overridden for individual modes, see `mode`).", + ) + self.add_option( + "listen_port", + Optional[int], + None, + "Port to bind proxy server(s) to (may be overridden for individual modes, see `mode`). " + "By default, the port is mode-specific. The default regular HTTP proxy spawns on port 8080.", + ) self.add_option( "mode", Sequence[str], diff --git a/mitmproxy/optmanager.py b/mitmproxy/optmanager.py index ae3bb54ac..32d448323 100644 --- a/mitmproxy/optmanager.py +++ b/mitmproxy/optmanager.py @@ -1,18 +1,25 @@ from __future__ import annotations + import contextlib import copy -import weakref -from collections.abc import Callable, Iterable, Sequence -from dataclasses import dataclass import os import pprint import textwrap -from typing import Any, Optional, TextIO, Union +import weakref +from collections.abc import Callable +from collections.abc import Iterable +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any +from typing import Optional +from typing import TextIO +from typing import Union import ruamel.yaml from mitmproxy import exceptions -from mitmproxy.utils import signals, typecheck +from mitmproxy.utils import signals +from mitmproxy.utils import typecheck """ The base implementation for Options. @@ -150,9 +157,7 @@ class OptManager: if i not in self._options: raise exceptions.OptionsError("No such option: %s" % i) - self._subscriptions.append( - (signals.make_weak_ref(func), set(opts)) - ) + self._subscriptions.append((signals.make_weak_ref(func), set(opts))) def _notify_subscribers(self, updated) -> None: cleanup = False @@ -526,7 +531,7 @@ def parse(text): snip = v.problem_mark.get_snippet() raise exceptions.OptionsError( "Config error at line %s:\n%s\n%s" - % (v.problem_mark.line + 1, snip, getattr(v, 'problem', '')) + % (v.problem_mark.line + 1, snip, getattr(v, "problem", "")) ) else: raise exceptions.OptionsError("Could not parse options.") diff --git a/mitmproxy/platform/__init__.py b/mitmproxy/platform/__init__.py index e6fdcd7c8..0b0c492ad 100644 --- a/mitmproxy/platform/__init__.py +++ b/mitmproxy/platform/__init__.py @@ -1,7 +1,8 @@ import re import socket import sys -from typing import Callable, Optional +from typing import Callable +from typing import Optional def init_transparent_mode() -> None: diff --git a/mitmproxy/platform/windows.py b/mitmproxy/platform/windows.py index 1ff887661..1e065544b 100644 --- a/mitmproxy/platform/windows.py +++ b/mitmproxy/platform/windows.py @@ -1,8 +1,7 @@ from __future__ import annotations -import collections + import collections.abc import contextlib -import ctypes import ctypes.wintypes import json import os @@ -12,12 +11,16 @@ import socketserver import threading import time from collections.abc import Callable -from typing import Any, ClassVar, IO, Optional, cast +from typing import Any +from typing import cast +from typing import ClassVar +from typing import IO +from typing import Optional -import pydivert import pydivert.consts -from mitmproxy.net.local_ip import get_local_ip, get_local_ip6 +from mitmproxy.net.local_ip import get_local_ip +from mitmproxy.net.local_ip import get_local_ip6 REDIRECT_API_HOST = "127.0.0.1" REDIRECT_API_PORT = 8085 @@ -98,7 +101,9 @@ class APIRequestHandler(socketserver.StreamRequestHandler): if c is None: return try: - server = proxifier.client_server_map[cast(tuple[str, int], tuple(c))] + server = proxifier.client_server_map[ + cast(tuple[str, int], tuple(c)) + ] except KeyError: server = None write(server, self.wfile) @@ -397,7 +402,7 @@ class TransparentProxy: local: Optional[RedirectLocal] = None # really weird linting error here. - forward: Optional[Redirect] = None # noqa + forward: Optional[Redirect] = None response: Redirect icmp: Redirect diff --git a/mitmproxy/proxy/commands.py b/mitmproxy/proxy/commands.py index 04b471e02..e6749dccd 100644 --- a/mitmproxy/proxy/commands.py +++ b/mitmproxy/proxy/commands.py @@ -8,10 +8,12 @@ The counterpart to commands are events. """ import logging import warnings -from typing import Union, TYPE_CHECKING +from typing import TYPE_CHECKING +from typing import Union import mitmproxy.hooks -from mitmproxy.connection import Connection, Server +from mitmproxy.connection import Connection +from mitmproxy.connection import Server if TYPE_CHECKING: import mitmproxy.proxy.layer @@ -133,6 +135,7 @@ class Log(Command): This could also be implemented with some more playbook magic in the future, but for now we keep the current approach as the fully sans-io one. """ + message: str level: int @@ -144,7 +147,8 @@ class Log(Command): if isinstance(level, str): # pragma: no cover warnings.warn( "commands.Log() now expects an integer log level, not a string.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) level = getattr(logging, level.upper()) self.message = message diff --git a/mitmproxy/proxy/context.py b/mitmproxy/proxy/context.py index 1be73b39b..29987418f 100644 --- a/mitmproxy/proxy/context.py +++ b/mitmproxy/proxy/context.py @@ -38,8 +38,7 @@ class Context: self.client = client self.options = options self.server = connection.Server( - address=None, - transport_protocol=client.transport_protocol + address=None, transport_protocol=client.transport_protocol ) self.layers = [] diff --git a/mitmproxy/proxy/events.py b/mitmproxy/proxy/events.py index fad8029ae..e741fbcfb 100644 --- a/mitmproxy/proxy/events.py +++ b/mitmproxy/proxy/events.py @@ -5,12 +5,16 @@ The counterpart to events are commands. """ import typing import warnings -from dataclasses import dataclass, is_dataclass -from typing import Any, Generic, Optional, TypeVar +from dataclasses import dataclass +from dataclasses import is_dataclass +from typing import Any +from typing import Generic +from typing import Optional +from typing import TypeVar from mitmproxy import flow -from mitmproxy.proxy import commands from mitmproxy.connection import Connection +from mitmproxy.proxy import commands class Event: diff --git a/mitmproxy/proxy/layer.py b/mitmproxy/proxy/layer.py index d486e9b8f..79aed95c9 100644 --- a/mitmproxy/proxy/layer.py +++ b/mitmproxy/proxy/layer.py @@ -5,13 +5,20 @@ import collections import textwrap from abc import abstractmethod from collections.abc import Callable +from collections.abc import Generator from dataclasses import dataclass from logging import DEBUG -from typing import Any, ClassVar, Generator, NamedTuple, Optional, TypeVar +from typing import Any +from typing import ClassVar +from typing import NamedTuple +from typing import Optional +from typing import TypeVar from mitmproxy.connection import Connection -from mitmproxy.proxy import commands, events -from mitmproxy.proxy.commands import Command, StartHook +from mitmproxy.proxy import commands +from mitmproxy.proxy import events +from mitmproxy.proxy.commands import Command +from mitmproxy.proxy.commands import StartHook from mitmproxy.proxy.context import Context T = TypeVar("T") diff --git a/mitmproxy/proxy/layers/__init__.py b/mitmproxy/proxy/layers/__init__.py index 349c32cfc..e21ba60e0 100644 --- a/mitmproxy/proxy/layers/__init__.py +++ b/mitmproxy/proxy/layers/__init__.py @@ -1,10 +1,14 @@ from . import modes from .dns import DNSLayer from .http import HttpLayer -from .quic import QuicStreamLayer, RawQuicLayer, ClientQuicLayer, ServerQuicLayer +from .quic import ClientQuicLayer +from .quic import QuicStreamLayer +from .quic import RawQuicLayer +from .quic import ServerQuicLayer from .tcp import TCPLayer +from .tls import ClientTLSLayer +from .tls import ServerTLSLayer from .udp import UDPLayer -from .tls import ClientTLSLayer, ServerTLSLayer from .websocket import WebsocketLayer __all__ = [ diff --git a/mitmproxy/proxy/layers/dns.py b/mitmproxy/proxy/layers/dns.py index 0b85ad05a..e2e5c701b 100644 --- a/mitmproxy/proxy/layers/dns.py +++ b/mitmproxy/proxy/layers/dns.py @@ -1,8 +1,11 @@ -from dataclasses import dataclass import struct +from dataclasses import dataclass -from mitmproxy import dns, flow as mflow -from mitmproxy.proxy import commands, events, layer +from mitmproxy import dns +from mitmproxy import flow as mflow +from mitmproxy.proxy import commands +from mitmproxy.proxy import events +from mitmproxy.proxy import layer from mitmproxy.proxy.context import Context from mitmproxy.proxy.utils import expect @@ -45,13 +48,17 @@ class DNSLayer(layer.Layer): super().__init__(context) self.flows = {} - def handle_request(self, flow: dns.DNSFlow, msg: dns.Message) -> layer.CommandGenerator[None]: + def handle_request( + self, flow: dns.DNSFlow, msg: dns.Message + ) -> layer.CommandGenerator[None]: flow.request = msg # if already set, continue and query upstream again yield DnsRequestHook(flow) if flow.response: yield from self.handle_response(flow, flow.response) elif not self.context.server.address: - yield from self.handle_error(flow, "No hook has set a response and there is no upstream server.") + yield from self.handle_error( + flow, "No hook has set a response and there is no upstream server." + ) else: if not self.context.server.connected: err = yield commands.OpenConnection(self.context.server) @@ -61,7 +68,9 @@ class DNSLayer(layer.Layer): return yield commands.SendData(self.context.server, flow.request.packed) - def handle_response(self, flow: dns.DNSFlow, msg: dns.Message) -> layer.CommandGenerator[None]: + def handle_response( + self, flow: dns.DNSFlow, msg: dns.Message + ) -> layer.CommandGenerator[None]: flow.response = msg yield DnsResponseHook(flow) if flow.response: @@ -92,7 +101,9 @@ class DNSLayer(layer.Layer): try: flow = self.flows[msg.id] except KeyError: - flow = dns.DNSFlow(self.context.client, self.context.server, live=True) + flow = dns.DNSFlow( + self.context.client, self.context.server, live=True + ) self.flows[msg.id] = flow if from_client: yield from self.handle_request(flow, msg) diff --git a/mitmproxy/proxy/layers/http/__init__.py b/mitmproxy/proxy/layers/http/__init__.py index 09a31024b..9d7cba4ce 100644 --- a/mitmproxy/proxy/layers/http/__init__.py +++ b/mitmproxy/proxy/layers/http/__init__.py @@ -1,53 +1,68 @@ import collections import enum -from logging import DEBUG, WARNING - import time from dataclasses import dataclass from functools import cached_property -from typing import Optional, Union +from logging import DEBUG +from logging import WARNING +from typing import Optional +from typing import Union import wsproto.handshake -from mitmproxy import flow, http -from mitmproxy.connection import Connection, Server, TransportProtocol + +from ...context import Context +from ...mode_specs import ReverseMode +from ...mode_specs import UpstreamMode +from ..quic import QuicStreamEvent +from ._base import HttpCommand +from ._base import HttpConnection +from ._base import ReceiveHttp +from ._base import StreamId +from ._events import HttpEvent +from ._events import RequestData +from ._events import RequestEndOfMessage +from ._events import RequestHeaders +from ._events import RequestProtocolError +from ._events import RequestTrailers +from ._events import ResponseData +from ._events import ResponseEndOfMessage +from ._events import ResponseHeaders +from ._events import ResponseProtocolError +from ._events import ResponseTrailers +from ._hooks import HttpConnectHook +from ._hooks import HttpErrorHook +from ._hooks import HttpRequestHeadersHook +from ._hooks import HttpRequestHook +from ._hooks import HttpResponseHeadersHook +from ._hooks import HttpResponseHook +from ._http1 import Http1Client +from ._http1 import Http1Connection +from ._http1 import Http1Server +from ._http2 import Http2Client +from ._http2 import Http2Server +from ._http3 import Http3Client +from ._http3 import Http3Server +from mitmproxy import flow +from mitmproxy import http +from mitmproxy.connection import Connection +from mitmproxy.connection import Server +from mitmproxy.connection import TransportProtocol from mitmproxy.net import server_spec -from mitmproxy.net.http import status_codes, url +from mitmproxy.net.http import status_codes +from mitmproxy.net.http import url from mitmproxy.net.http.http1 import expected_http_body_size -from mitmproxy.proxy import commands, events, layer, tunnel -from mitmproxy.proxy.layers import quic, tcp, tls, websocket +from mitmproxy.proxy import commands +from mitmproxy.proxy import events +from mitmproxy.proxy import layer +from mitmproxy.proxy import tunnel +from mitmproxy.proxy.layers import quic +from mitmproxy.proxy.layers import tcp +from mitmproxy.proxy.layers import tls +from mitmproxy.proxy.layers import websocket from mitmproxy.proxy.layers.http import _upstream_proxy from mitmproxy.proxy.utils import expect from mitmproxy.utils import human from mitmproxy.websocket import WebSocketData -from ._base import HttpCommand, HttpConnection, ReceiveHttp, StreamId -from ._events import ( - HttpEvent, - RequestData, - RequestEndOfMessage, - RequestHeaders, - RequestProtocolError, - RequestTrailers, - ResponseData, - ResponseEndOfMessage, - ResponseHeaders, - ResponseProtocolError, - ResponseTrailers, -) -from ._hooks import ( # noqa - HttpConnectHook, - HttpConnectUpstreamHook, - HttpErrorHook, - HttpRequestHeadersHook, - HttpRequestHook, - HttpResponseHeadersHook, - HttpResponseHook, -) -from ._http1 import Http1Client, Http1Connection, Http1Server -from ._http2 import Http2Client, Http2Server -from ._http3 import Http3Client, Http3Server -from ..quic import QuicStreamEvent -from ...context import Context -from ...mode_specs import ReverseMode, UpstreamMode class HTTPMode(enum.Enum): @@ -228,7 +243,9 @@ class HttpStream(layer.Layer): "https" if self.context.client.tls else "http" ) - if self.mode is HTTPMode.regular and not (self.flow.request.is_http2 or self.flow.request.is_http3): + if self.mode is HTTPMode.regular and not ( + self.flow.request.is_http2 or self.flow.request.is_http3 + ): # Set the request target to origin-form for HTTP/1, some servers don't support absolute-form requests. # see https://github.com/mitmproxy/mitmproxy/issues/1759 self.flow.request.authority = "" @@ -707,7 +724,7 @@ class HttpStream(layer.Layer): 502, f"Cannot connect to {human.format_address(self.context.server.address)}: {err} " f"If you plan to redirect requests away from this server, " - f"consider setting `connection_strategy` to `lazy` to suppress early connections." + f"consider setting `connection_strategy` to `lazy` to suppress early connections.", ) self.child_layer = layer.NextLayer(self.context) yield from self.handle_connect_finish() @@ -1012,7 +1029,9 @@ class HttpLayer(layer.Layer): if not can_use_context_connection: - context.server = Server(address=event.address, transport_protocol=event.transport_protocol) + context.server = Server( + address=event.address, transport_protocol=event.transport_protocol + ) if event.via: context.server.via = event.via @@ -1034,7 +1053,9 @@ class HttpLayer(layer.Layer): elif context.server.transport_protocol == "udp": stack /= quic.ServerQuicLayer(context) else: - raise AssertionError(context.server.transport_protocol) # pragma: no cover + raise AssertionError( + context.server.transport_protocol + ) # pragma: no cover stack /= HttpClient(context) diff --git a/mitmproxy/proxy/layers/http/_base.py b/mitmproxy/proxy/layers/http/_base.py index b5f66d46b..198fa77fb 100644 --- a/mitmproxy/proxy/layers/http/_base.py +++ b/mitmproxy/proxy/layers/http/_base.py @@ -4,7 +4,9 @@ from dataclasses import dataclass from mitmproxy import http from mitmproxy.connection import Connection -from mitmproxy.proxy import commands, events, layer +from mitmproxy.proxy import commands +from mitmproxy.proxy import events +from mitmproxy.proxy import layer from mitmproxy.proxy.context import Context StreamId = int diff --git a/mitmproxy/proxy/layers/http/_events.py b/mitmproxy/proxy/layers/http/_events.py index f67217b03..ecdbcd7a2 100644 --- a/mitmproxy/proxy/layers/http/_events.py +++ b/mitmproxy/proxy/layers/http/_events.py @@ -1,9 +1,9 @@ from dataclasses import dataclass from typing import Optional +from ._base import HttpEvent from mitmproxy import http from mitmproxy.http import HTTPFlow -from ._base import HttpEvent @dataclass diff --git a/mitmproxy/proxy/layers/http/_http1.py b/mitmproxy/proxy/layers/http/_http1.py index c7c79cac8..0affd6d68 100644 --- a/mitmproxy/proxy/layers/http/_http1.py +++ b/mitmproxy/proxy/layers/http/_http1.py @@ -1,30 +1,39 @@ import abc -from typing import Callable, Optional, Union +from typing import Callable +from typing import Optional +from typing import Union import h11 -from h11._readers import ChunkedReader, ContentLengthReader, Http10Reader +from h11._readers import ChunkedReader +from h11._readers import ContentLengthReader +from h11._readers import Http10Reader from h11._receivebuffer import ReceiveBuffer -from mitmproxy import http, version -from mitmproxy.connection import Connection, ConnectionState -from mitmproxy.net.http import http1, status_codes -from mitmproxy.proxy import commands, events, layer -from mitmproxy.proxy.layers.http._base import ReceiveHttp, StreamId +from ...context import Context +from ._base import format_error +from ._base import HttpConnection +from ._events import HttpEvent +from ._events import RequestData +from ._events import RequestEndOfMessage +from ._events import RequestHeaders +from ._events import RequestProtocolError +from ._events import ResponseData +from ._events import ResponseEndOfMessage +from ._events import ResponseHeaders +from ._events import ResponseProtocolError +from mitmproxy import http +from mitmproxy import version +from mitmproxy.connection import Connection +from mitmproxy.connection import ConnectionState +from mitmproxy.net.http import http1 +from mitmproxy.net.http import status_codes +from mitmproxy.proxy import commands +from mitmproxy.proxy import events +from mitmproxy.proxy import layer +from mitmproxy.proxy.layers.http._base import ReceiveHttp +from mitmproxy.proxy.layers.http._base import StreamId from mitmproxy.proxy.utils import expect from mitmproxy.utils import human -from ._base import HttpConnection, format_error -from ._events import ( - HttpEvent, - RequestData, - RequestEndOfMessage, - RequestHeaders, - RequestProtocolError, - ResponseData, - ResponseEndOfMessage, - ResponseHeaders, - ResponseProtocolError, -) -from ...context import Context TBodyReader = Union[ChunkedReader, Http10Reader, ContentLengthReader] @@ -189,7 +198,10 @@ class Http1Connection(HttpConnection, metaclass=abc.ABCMeta): # If we proxy HTTP/2 to HTTP/1, we only use upstream connections for one request. # This simplifies our connection management quite a bit as we can rely on # the proxyserver's max-connection-per-server throttling. - or ((self.request.is_http2 or self.request.is_http3) and isinstance(self, Http1Client)) + or ( + (self.request.is_http2 or self.request.is_http3) + and isinstance(self, Http1Client) + ) ) if connection_done: yield commands.CloseConnection(self.conn) @@ -245,7 +257,11 @@ class Http1Server(Http1Connection): elif isinstance(event, ResponseEndOfMessage): assert self.request assert self.response - if self.request.method.upper() != "HEAD" and "chunked" in self.response.headers.get("transfer-encoding", "").lower(): + if ( + self.request.method.upper() != "HEAD" + and "chunked" + in self.response.headers.get("transfer-encoding", "").lower() + ): yield commands.SendData(self.conn, b"0\r\n\r\n") yield from self.mark_done(response=True) elif isinstance(event, ResponseProtocolError): diff --git a/mitmproxy/proxy/layers/http/_http2.py b/mitmproxy/proxy/layers/http/_http2.py index f881612cc..7c7af9dac 100644 --- a/mitmproxy/proxy/layers/http/_http2.py +++ b/mitmproxy/proxy/layers/http/_http2.py @@ -1,10 +1,12 @@ import collections -from logging import DEBUG, ERROR - import time from collections.abc import Sequence from enum import Enum -from typing import ClassVar, Optional, Union +from logging import DEBUG +from logging import ERROR +from typing import ClassVar +from typing import Optional +from typing import Union import h2.config import h2.connection @@ -15,29 +17,40 @@ import h2.settings import h2.stream import h2.utilities -from mitmproxy import http, version -from mitmproxy.connection import Connection -from mitmproxy.net.http import status_codes, url -from mitmproxy.utils import human -from . import ( - RequestData, - RequestEndOfMessage, - RequestHeaders, - RequestProtocolError, - ResponseData, - ResponseEndOfMessage, - ResponseHeaders, - RequestTrailers, - ResponseTrailers, - ResponseProtocolError, -) -from ._base import HttpConnection, HttpEvent, ReceiveHttp, format_error -from ._http_h2 import BufferedH2Connection, H2ConnectionLogger -from ...commands import CloseConnection, Log, SendData, RequestWakeup +from . import RequestData +from . import RequestEndOfMessage +from . import RequestHeaders +from . import RequestProtocolError +from . import RequestTrailers +from . import ResponseData +from . import ResponseEndOfMessage +from . import ResponseHeaders +from . import ResponseProtocolError +from . import ResponseTrailers +from ...commands import CloseConnection +from ...commands import Log +from ...commands import RequestWakeup +from ...commands import SendData from ...context import Context -from ...events import ConnectionClosed, DataReceived, Event, Start, Wakeup +from ...events import ConnectionClosed +from ...events import DataReceived +from ...events import Event +from ...events import Start +from ...events import Wakeup from ...layer import CommandGenerator from ...utils import expect +from ._base import format_error +from ._base import HttpConnection +from ._base import HttpEvent +from ._base import ReceiveHttp +from ._http_h2 import BufferedH2Connection +from ._http_h2 import H2ConnectionLogger +from mitmproxy import http +from mitmproxy import version +from mitmproxy.connection import Connection +from mitmproxy.net.http import status_codes +from mitmproxy.net.http import url +from mitmproxy.utils import human class StreamState(Enum): @@ -70,8 +83,7 @@ class Http2Connection(HttpConnection): super().__init__(context, conn) if self.debug: self.h2_conf.logger = H2ConnectionLogger( - self.context.client.peername, - self.__class__.__name__ + self.context.client.peername, self.__class__.__name__ ) self.h2_conf.validate_inbound_headers = ( self.context.options.validate_inbound_headers @@ -374,7 +386,9 @@ class Http2Server(Http2Connection): if self.is_open_for_us(event.stream_id): self.h2_conn.send_headers( event.stream_id, - headers=(yield from format_h2_response_headers(self.context, event)), + headers=( + yield from format_h2_response_headers(self.context, event) + ), end_stream=event.end_stream, ) yield SendData(self.conn, self.h2_conn.data_to_send()) diff --git a/mitmproxy/proxy/layers/http/_http3.py b/mitmproxy/proxy/layers/http/_http3.py index 8a696a24f..ccafe36f4 100644 --- a/mitmproxy/proxy/layers/http/_http3.py +++ b/mitmproxy/proxy/layers/http/_http3.py @@ -1,50 +1,48 @@ -from abc import abstractmethod import time -from typing import Dict, Union +from abc import abstractmethod +from typing import Union -from aioquic.h3.connection import ( - ErrorCode as H3ErrorCode, - FrameUnexpected as H3FrameUnexpected, -) -from aioquic.h3.events import DataReceived, HeadersReceived, PushPromiseReceived +from aioquic.h3.connection import ErrorCode as H3ErrorCode +from aioquic.h3.connection import FrameUnexpected as H3FrameUnexpected +from aioquic.h3.events import DataReceived +from aioquic.h3.events import HeadersReceived +from aioquic.h3.events import PushPromiseReceived -from mitmproxy import connection, http, version +from . import RequestData +from . import RequestEndOfMessage +from . import RequestHeaders +from . import RequestProtocolError +from . import RequestTrailers +from . import ResponseData +from . import ResponseEndOfMessage +from . import ResponseHeaders +from . import ResponseProtocolError +from . import ResponseTrailers +from ._base import format_error +from ._base import HttpConnection +from ._base import HttpEvent +from ._base import ReceiveHttp +from ._http2 import format_h2_request_headers +from ._http2 import format_h2_response_headers +from ._http2 import parse_h2_request_headers +from ._http2 import parse_h2_response_headers +from ._http_h3 import LayeredH3Connection +from ._http_h3 import StreamReset +from ._http_h3 import TrailersReceived +from mitmproxy import connection +from mitmproxy import http +from mitmproxy import version from mitmproxy.net.http import status_codes -from mitmproxy.proxy import commands, context, events, layer -from mitmproxy.proxy.layers.quic import ( - QuicConnectionClosed, - QuicStreamEvent, - StopQuicStream, - error_code_to_str, -) +from mitmproxy.proxy import commands +from mitmproxy.proxy import context +from mitmproxy.proxy import events +from mitmproxy.proxy import layer +from mitmproxy.proxy.layers.quic import error_code_to_str +from mitmproxy.proxy.layers.quic import QuicConnectionClosed +from mitmproxy.proxy.layers.quic import QuicStreamEvent +from mitmproxy.proxy.layers.quic import StopQuicStream from mitmproxy.proxy.utils import expect -from . import ( - RequestData, - RequestEndOfMessage, - RequestHeaders, - RequestProtocolError, - RequestTrailers, - ResponseData, - ResponseEndOfMessage, - ResponseHeaders, - ResponseProtocolError, - ResponseTrailers, -) -from ._base import ( - HttpConnection, - HttpEvent, - ReceiveHttp, - format_error, -) -from ._http2 import ( - format_h2_request_headers, - format_h2_response_headers, - parse_h2_request_headers, - parse_h2_response_headers, -) -from ._http_h3 import LayeredH3Connection, StreamReset, TrailersReceived - class Http3Connection(HttpConnection): h3_conn: LayeredH3Connection @@ -56,7 +54,9 @@ class Http3Connection(HttpConnection): def __init__(self, context: context.Context, conn: connection.Connection): super().__init__(context, conn) - self.h3_conn = LayeredH3Connection(self.conn, is_client=self.conn is self.context.server) + self.h3_conn = LayeredH3Connection( + self.conn, is_client=self.conn is self.context.server + ) self._stream_protocol_errors: dict[int, int] = {} def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: @@ -74,9 +74,13 @@ class Http3Connection(HttpConnection): if isinstance(event, RequestHeaders) else format_h2_response_headers(self.context, event) ) - self.h3_conn.send_headers(event.stream_id, headers, end_stream=event.end_stream) + self.h3_conn.send_headers( + event.stream_id, headers, end_stream=event.end_stream + ) elif isinstance(event, (RequestTrailers, ResponseTrailers)): - self.h3_conn.send_trailers(event.stream_id, [*event.trailers.fields]) + self.h3_conn.send_trailers( + event.stream_id, [*event.trailers.fields] + ) elif isinstance(event, (RequestEndOfMessage, ResponseEndOfMessage)): self.h3_conn.end_stream(event.stream_id) elif isinstance(event, (RequestProtocolError, ResponseProtocolError)): @@ -145,9 +149,13 @@ class Http3Connection(HttpConnection): elif isinstance(h3_event, DataReceived): if h3_event.push_id is None: if h3_event.data: - yield ReceiveHttp(self.ReceiveData(h3_event.stream_id, h3_event.data)) + yield ReceiveHttp( + self.ReceiveData(h3_event.stream_id, h3_event.data) + ) if h3_event.stream_ended: - yield ReceiveHttp(self.ReceiveEndOfMessage(h3_event.stream_id)) + yield ReceiveHttp( + self.ReceiveEndOfMessage(h3_event.stream_id) + ) elif isinstance(h3_event, HeadersReceived): if h3_event.push_id is None: try: @@ -160,12 +168,20 @@ class Http3Connection(HttpConnection): else: yield ReceiveHttp(receive_event) if h3_event.stream_ended: - yield ReceiveHttp(self.ReceiveEndOfMessage(h3_event.stream_id)) + yield ReceiveHttp( + self.ReceiveEndOfMessage(h3_event.stream_id) + ) elif isinstance(h3_event, TrailersReceived): if h3_event.push_id is None: - yield ReceiveHttp(self.ReceiveTrailers(h3_event.stream_id, http.Headers(h3_event.trailers))) + yield ReceiveHttp( + self.ReceiveTrailers( + h3_event.stream_id, http.Headers(h3_event.trailers) + ) + ) if h3_event.stream_ended: - yield ReceiveHttp(self.ReceiveEndOfMessage(h3_event.stream_id)) + yield ReceiveHttp( + self.ReceiveEndOfMessage(h3_event.stream_id) + ) elif isinstance(h3_event, PushPromiseReceived): # pragma: no cover # we don't support push pass @@ -204,7 +220,9 @@ class Http3Server(Http3Connection): def __init__(self, context: context.Context): super().__init__(context, context.client) - def parse_headers(self, event: HeadersReceived) -> Union[RequestHeaders, ResponseHeaders]: + def parse_headers( + self, event: HeadersReceived + ) -> Union[RequestHeaders, ResponseHeaders]: # same as HTTP/2 ( host, @@ -238,8 +256,8 @@ class Http3Client(Http3Connection): ReceiveProtocolError = ResponseProtocolError ReceiveTrailers = ResponseTrailers - our_stream_id: Dict[int, int] - their_stream_id: Dict[int, int] + our_stream_id: dict[int, int] + their_stream_id: dict[int, int] def __init__(self, context: context.Context): super().__init__(context, context.server) @@ -263,7 +281,9 @@ class Http3Client(Http3Connection): cmd.event.stream_id = self.their_stream_id[cmd.event.stream_id] yield cmd - def parse_headers(self, event: HeadersReceived) -> Union[RequestHeaders, ResponseHeaders]: + def parse_headers( + self, event: HeadersReceived + ) -> Union[RequestHeaders, ResponseHeaders]: # same as HTTP/2 status_code, headers = parse_h2_response_headers(event.headers) response = http.Response( diff --git a/mitmproxy/proxy/layers/http/_http_h2.py b/mitmproxy/proxy/layers/http/_http_h2.py index 8533b7afb..f5b08d64c 100644 --- a/mitmproxy/proxy/layers/http/_http_h2.py +++ b/mitmproxy/proxy/layers/http/_http_h2.py @@ -1,6 +1,6 @@ import collections import logging -from typing import Dict, List, NamedTuple, Tuple +from typing import NamedTuple import h2.config import h2.connection @@ -21,9 +21,7 @@ class H2ConnectionLogger(h2.config.DummyLogger): def debug(self, fmtstr, *args): logger.debug( - f"{self.conn_type} {fmtstr}", - *args, - extra={"client": self.peername} + f"{self.conn_type} {fmtstr}", *args, extra={"client": self.peername} ) def trace(self, fmtstr, *args): @@ -31,7 +29,7 @@ class H2ConnectionLogger(h2.config.DummyLogger): logging.DEBUG - 1, f"{self.conn_type} {fmtstr}", *args, - extra={"client": self.peername} + extra={"client": self.peername}, ) @@ -48,7 +46,7 @@ class BufferedH2Connection(h2.connection.H2Connection): """ stream_buffers: collections.defaultdict[int, collections.deque[SendH2Data]] - stream_trailers: Dict[int, List[Tuple[bytes, bytes]]] + stream_trailers: dict[int, list[tuple[bytes, bytes]]] def __init__(self, config: h2.config.H2Configuration): super().__init__(config) @@ -93,7 +91,7 @@ class BufferedH2Connection(h2.connection.H2Connection): # We can't send right now, so we buffer. self.stream_buffers[stream_id].append(SendH2Data(data, end_stream)) - def send_trailers(self, stream_id: int, trailers: List[Tuple[bytes, bytes]]): + def send_trailers(self, stream_id: int, trailers: list[tuple[bytes, bytes]]): if self.stream_buffers.get(stream_id, None): # Though trailers are not subject to flow control, we need to queue them and send strictly after data frames self.stream_trailers[stream_id] = trailers @@ -173,7 +171,9 @@ class BufferedH2Connection(h2.connection.H2Connection): if not self.stream_buffers[stream_id]: del self.stream_buffers[stream_id] if stream_id in self.stream_trailers: - self.send_headers(stream_id, self.stream_trailers.pop(stream_id), end_stream=True) + self.send_headers( + stream_id, self.stream_trailers.pop(stream_id), end_stream=True + ) sent_any_data = True return sent_any_data diff --git a/mitmproxy/proxy/layers/http/_http_h3.py b/mitmproxy/proxy/layers/http/_http_h3.py index 849f3b159..52b36e57e 100644 --- a/mitmproxy/proxy/layers/http/_http_h3.py +++ b/mitmproxy/proxy/layers/http/_http_h3.py @@ -1,31 +1,29 @@ +from collections.abc import Iterable from dataclasses import dataclass -from typing import Iterable, Optional +from typing import Optional -from aioquic.h3.connection import ( - FrameUnexpected, - H3Connection, - H3Event, - H3Stream, - Headers, - HeadersState, - StreamType, -) +from aioquic.h3.connection import FrameUnexpected +from aioquic.h3.connection import H3Connection +from aioquic.h3.connection import H3Event +from aioquic.h3.connection import H3Stream +from aioquic.h3.connection import Headers +from aioquic.h3.connection import HeadersState +from aioquic.h3.connection import StreamType from aioquic.h3.events import HeadersReceived from aioquic.quic.configuration import QuicConfiguration from aioquic.quic.events import StreamDataReceived from aioquic.quic.packet import QuicErrorCode from mitmproxy import connection -from mitmproxy.proxy import commands, layer -from mitmproxy.proxy.layers.quic import ( - CloseQuicConnection, - QuicConnectionClosed, - QuicStreamDataReceived, - QuicStreamEvent, - QuicStreamReset, - ResetQuicStream, - SendQuicStreamData, -) +from mitmproxy.proxy import commands +from mitmproxy.proxy import layer +from mitmproxy.proxy.layers.quic import CloseQuicConnection +from mitmproxy.proxy.layers.quic import QuicConnectionClosed +from mitmproxy.proxy.layers.quic import QuicStreamDataReceived +from mitmproxy.proxy.layers.quic import QuicStreamEvent +from mitmproxy.proxy.layers.quic import QuicStreamReset +from mitmproxy.proxy.layers.quic import ResetQuicStream +from mitmproxy.proxy.layers.quic import SendQuicStreamData @dataclass @@ -104,8 +102,12 @@ class MockQuic: def reset_stream(self, stream_id: int, error_code: int) -> None: self.pending_commands.append(ResetQuicStream(self.conn, stream_id, error_code)) - def send_stream_data(self, stream_id: int, data: bytes, end_stream: bool = False) -> None: - self.pending_commands.append(SendQuicStreamData(self.conn, stream_id, data, end_stream)) + def send_stream_data( + self, stream_id: int, data: bytes, end_stream: bool = False + ) -> None: + self.pending_commands.append( + SendQuicStreamData(self.conn, stream_id, data, end_stream) + ) class LayeredH3Connection(H3Connection): @@ -114,7 +116,12 @@ class LayeredH3Connection(H3Connection): Also ensures that headers, data and trailers are sent in that order. """ - def __init__(self, conn: connection.Connection, is_client: bool, enable_webtransport: bool = False) -> None: + def __init__( + self, + conn: connection.Connection, + is_client: bool, + enable_webtransport: bool = False, + ) -> None: self._mock = MockQuic(conn, is_client) super().__init__(self._mock, enable_webtransport) # type: ignore @@ -132,13 +139,18 @@ class LayeredH3Connection(H3Connection): stream_ended: bool, ) -> list[H3Event]: # turn HeadersReceived into TrailersReceived for trailers - events = super()._handle_request_or_push_frame(frame_type, frame_data, stream, stream_ended) + events = super()._handle_request_or_push_frame( + frame_type, frame_data, stream, stream_ended + ) for index, event in enumerate(events): if ( isinstance(event, HeadersReceived) - and self._stream[event.stream_id].headers_recv_state == HeadersState.AFTER_TRAILERS + and self._stream[event.stream_id].headers_recv_state + == HeadersState.AFTER_TRAILERS ): - events[index] = TrailersReceived(event.headers, event.stream_id, event.stream_ended, event.push_id) + events[index] = TrailersReceived( + event.headers, event.stream_id, event.stream_ended, event.push_id + ) return events def close_connection( @@ -173,11 +185,7 @@ class LayeredH3Connection(H3Connection): for stream in self._stream.values() if ( stream.push_id == push_id - and stream.stream_type == ( - None - if push_id is None else - StreamType.PUSH - ) + and stream.stream_type == (None if push_id is None else StreamType.PUSH) and not ( stream.headers_recv_state == HeadersState.AFTER_TRAILERS and stream.headers_send_state == HeadersState.AFTER_TRAILERS @@ -206,10 +214,15 @@ class LayeredH3Connection(H3Connection): if self._get_or_create_stream(event.stream_id).ended: # aioquic will not send us any data events once a stream has ended. # Instead, it will close the connection. We simulate this here for H3 tests. - self.close_connection(error_code=QuicErrorCode.PROTOCOL_VIOLATION, reason_phrase="stream already ended") + self.close_connection( + error_code=QuicErrorCode.PROTOCOL_VIOLATION, + reason_phrase="stream already ended", + ) return [] else: - return self.handle_event(StreamDataReceived(event.data, event.end_stream, event.stream_id)) + return self.handle_event( + StreamDataReceived(event.data, event.end_stream, event.stream_id) + ) # should never happen else: # pragma: no cover @@ -241,7 +254,9 @@ class LayeredH3Connection(H3Connection): # supporting datagrams would require additional information from the underlying QUIC connection raise NotImplementedError() # pragma: no cover - def send_headers(self, stream_id: int, headers: Headers, end_stream: bool = False) -> None: + def send_headers( + self, stream_id: int, headers: Headers, end_stream: bool = False + ) -> None: """Sends headers over the given stream.""" # ensure we haven't sent something before diff --git a/mitmproxy/proxy/layers/http/_upstream_proxy.py b/mitmproxy/proxy/layers/http/_upstream_proxy.py index 4029d5347..9034fe145 100644 --- a/mitmproxy/proxy/layers/http/_upstream_proxy.py +++ b/mitmproxy/proxy/layers/http/_upstream_proxy.py @@ -1,15 +1,18 @@ -from logging import DEBUG - import time +from logging import DEBUG from typing import Optional from h11._receivebuffer import ReceiveBuffer -from mitmproxy import http, connection +from mitmproxy import connection +from mitmproxy import http from mitmproxy.net.http import http1 -from mitmproxy.proxy import commands, context, layer, tunnel -from mitmproxy.proxy.layers.http._hooks import HttpConnectUpstreamHook +from mitmproxy.proxy import commands +from mitmproxy.proxy import context +from mitmproxy.proxy import layer +from mitmproxy.proxy import tunnel from mitmproxy.proxy.layers import tls +from mitmproxy.proxy.layers.http._hooks import HttpConnectUpstreamHook from mitmproxy.utils import human @@ -48,7 +51,9 @@ class HttpUpstreamProxy(tunnel.TunnelLayer): return (yield from super().start_handshake()) assert self.conn.address flow = http.HTTPFlow(self.context.client, self.tunnel_connection) - authority = self.conn.address[0].encode("idna") + f":{self.conn.address[1]}".encode() + authority = ( + self.conn.address[0].encode("idna") + f":{self.conn.address[1]}".encode() + ) flow.request = http.Request( host=self.conn.address[0], port=self.conn.address[1], @@ -76,9 +81,7 @@ class HttpUpstreamProxy(tunnel.TunnelLayer): response_head = self.buf.maybe_extract_lines() if response_head: try: - response = http1.read_response_head([ - bytes(x) for x in response_head - ]) + response = http1.read_response_head([bytes(x) for x in response_head]) except ValueError as e: proxyaddr = human.format_address(self.tunnel_connection.address) yield commands.Log(f"{proxyaddr}: {e}") diff --git a/mitmproxy/proxy/layers/modes.py b/mitmproxy/proxy/layers/modes.py index cb243d754..bfc274967 100644 --- a/mitmproxy/proxy/layers/modes.py +++ b/mitmproxy/proxy/layers/modes.py @@ -1,14 +1,19 @@ from __future__ import annotations + import socket import struct from abc import ABCMeta from dataclasses import dataclass -from typing import Callable, Optional +from typing import Callable +from typing import Optional from mitmproxy import connection -from mitmproxy.proxy import commands, events, layer +from mitmproxy.proxy import commands +from mitmproxy.proxy import events +from mitmproxy.proxy import layer from mitmproxy.proxy.commands import StartHook -from mitmproxy.proxy.layers import quic, tls +from mitmproxy.proxy.layers import quic +from mitmproxy.proxy.layers import tls from mitmproxy.proxy.mode_specs import ReverseMode from mitmproxy.proxy.utils import expect diff --git a/mitmproxy/proxy/layers/quic.py b/mitmproxy/proxy/layers/quic.py index 657df3766..808cb584b 100644 --- a/mitmproxy/proxy/layers/quic.py +++ b/mitmproxy/proxy/layers/quic.py @@ -1,44 +1,54 @@ from __future__ import annotations -from dataclasses import dataclass, field -from logging import DEBUG, ERROR, WARNING -from ssl import VerifyMode + import time +from dataclasses import dataclass +from dataclasses import field +from logging import DEBUG +from logging import ERROR +from logging import WARNING +from ssl import VerifyMode from typing import Callable from aioquic.buffer import Buffer as QuicBuffer from aioquic.h3.connection import ErrorCode as H3ErrorCode from aioquic.quic import events as quic_events from aioquic.quic.configuration import QuicConfiguration -from aioquic.quic.connection import ( - QuicConnection, - QuicConnectionError, - QuicConnectionState, - QuicErrorCode, - stream_is_client_initiated, - stream_is_unidirectional, -) -from aioquic.tls import CipherSuite, HandshakeType -from aioquic.quic.packet import ( - PACKET_TYPE_INITIAL, - QuicProtocolVersion, - encode_quic_version_negotiation, - pull_quic_header, -) +from aioquic.quic.connection import QuicConnection +from aioquic.quic.connection import QuicConnectionError +from aioquic.quic.connection import QuicConnectionState +from aioquic.quic.connection import QuicErrorCode +from aioquic.quic.connection import stream_is_client_initiated +from aioquic.quic.connection import stream_is_unidirectional +from aioquic.quic.packet import encode_quic_version_negotiation +from aioquic.quic.packet import PACKET_TYPE_INITIAL +from aioquic.quic.packet import pull_quic_header +from aioquic.quic.packet import QuicProtocolVersion +from aioquic.tls import CipherSuite +from aioquic.tls import HandshakeType from cryptography import x509 -from cryptography.hazmat.primitives.asymmetric import dsa, ec, rsa -from mitmproxy import certs, connection, ctx +from cryptography.hazmat.primitives.asymmetric import dsa +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.asymmetric import rsa + +from mitmproxy import certs +from mitmproxy import connection +from mitmproxy import ctx from mitmproxy.net import tls -from mitmproxy.proxy import commands, context, events, layer, tunnel +from mitmproxy.proxy import commands +from mitmproxy.proxy import context +from mitmproxy.proxy import events +from mitmproxy.proxy import layer +from mitmproxy.proxy import tunnel from mitmproxy.proxy.layers.tcp import TCPLayer -from mitmproxy.proxy.layers.tls import ( - TlsClienthelloHook, - TlsEstablishedClientHook, - TlsEstablishedServerHook, - TlsFailedClientHook, - TlsFailedServerHook, -) +from mitmproxy.proxy.layers.tls import TlsClienthelloHook +from mitmproxy.proxy.layers.tls import TlsEstablishedClientHook +from mitmproxy.proxy.layers.tls import TlsEstablishedServerHook +from mitmproxy.proxy.layers.tls import TlsFailedClientHook +from mitmproxy.proxy.layers.tls import TlsFailedServerHook from mitmproxy.proxy.layers.udp import UDPLayer -from mitmproxy.tls import ClientHello, ClientHelloData, TlsData +from mitmproxy.tls import ClientHello +from mitmproxy.tls import ClientHelloData +from mitmproxy.tls import TlsData @dataclass @@ -53,7 +63,9 @@ class QuicTlsSettings: """The certificate to use for the connection.""" certificate_chain: list[x509.Certificate] = field(default_factory=list) """A list of additional certificates to send to the peer.""" - certificate_private_key: dsa.DSAPrivateKey | ec.EllipticCurvePrivateKey | rsa.RSAPrivateKey | None = None + certificate_private_key: dsa.DSAPrivateKey | ec.EllipticCurvePrivateKey | rsa.RSAPrivateKey | None = ( + None + ) """The certificate's private key.""" cipher_suites: list[CipherSuite] | None = None """An optional list of allowed/advertised cipher suites.""" @@ -147,7 +159,13 @@ class SendQuicStreamData(QuicStreamCommand): end_stream: bool """Whether the FIN bit should be set in the STREAM frame.""" - def __init__(self, connection: connection.Connection, stream_id: int, data: bytes, end_stream: bool = False) -> None: + def __init__( + self, + connection: connection.Connection, + stream_id: int, + data: bytes, + end_stream: bool = False, + ) -> None: super().__init__(connection, stream_id) self.data = data self.end_stream = end_stream @@ -159,7 +177,9 @@ class ResetQuicStream(QuicStreamCommand): error_code: int """An error code indicating why the stream is being reset.""" - def __init__(self, connection: connection.Connection, stream_id: int, error_code: int) -> None: + def __init__( + self, connection: connection.Connection, stream_id: int, error_code: int + ) -> None: super().__init__(connection, stream_id) self.error_code = error_code @@ -170,7 +190,9 @@ class StopQuicStream(QuicStreamCommand): error_code: int """An error code indicating why the stream is being stopped.""" - def __init__(self, connection: connection.Connection, stream_id: int, error_code: int) -> None: + def __init__( + self, connection: connection.Connection, stream_id: int, error_code: int + ) -> None: super().__init__(connection, stream_id) self.error_code = error_code @@ -203,6 +225,7 @@ class CloseQuicConnection(commands.CloseConnection): class QuicConnectionClosed(events.ConnectionClosed): """QUIC connection has been closed.""" + error_code: int "The error code which was specified when closing the connection." @@ -351,7 +374,12 @@ def quic_parse_client_hello(data: bytes) -> ClientHello: class QuicStreamNextLayer(layer.NextLayer): """`NextLayer` variant that callbacks `QuicStreamLayer` after layer decision.""" - def __init__(self, context: context.Context, stream: QuicStreamLayer, ask_on_start: bool = False) -> None: + def __init__( + self, + context: context.Context, + stream: QuicStreamLayer, + ask_on_start: bool = False, + ) -> None: super().__init__(context, ask_on_start) self._stream = stream self._layer: layer.Layer | None = None @@ -390,8 +418,8 @@ class QuicStreamLayer(layer.Layer): if stream_is_unidirectional(stream_id): self.client.state = ( connection.ConnectionState.CAN_READ - if stream_is_client_initiated(stream_id) else - connection.ConnectionState.CAN_WRITE + if stream_is_client_initiated(stream_id) + else connection.ConnectionState.CAN_WRITE ) self._client_stream_id = stream_id @@ -406,8 +434,8 @@ class QuicStreamLayer(layer.Layer): super().__init__(context) self.child_layer = ( TCPLayer(context, ignore=True) - if ignore else - QuicStreamNextLayer(context, self) + if ignore + else QuicStreamNextLayer(context, self) ) self.refresh_metadata() @@ -425,11 +453,11 @@ class QuicStreamLayer(layer.Layer): self.server.state = ( ( connection.ConnectionState.CAN_WRITE - if stream_is_client_initiated(server_stream_id) else - connection.ConnectionState.CAN_READ + if stream_is_client_initiated(server_stream_id) + else connection.ConnectionState.CAN_READ ) - if stream_is_unidirectional(server_stream_id) else - connection.ConnectionState.OPEN + if stream_is_unidirectional(server_stream_id) + else connection.ConnectionState.OPEN ) self.refresh_metadata() @@ -444,8 +472,14 @@ class QuicStreamLayer(layer.Layer): else: break # pragma: no cover if isinstance(child_layer, (UDPLayer, TCPLayer)) and child_layer.flow: - child_layer.flow.metadata["quic_is_unidirectional"] = stream_is_unidirectional(self._client_stream_id) - child_layer.flow.metadata["quic_initiator"] = "client" if stream_is_client_initiated(self._client_stream_id) else "server" + child_layer.flow.metadata[ + "quic_is_unidirectional" + ] = stream_is_unidirectional(self._client_stream_id) + child_layer.flow.metadata["quic_initiator"] = ( + "client" + if stream_is_client_initiated(self._client_stream_id) + else "server" + ) child_layer.flow.metadata["quic_stream_id_client"] = self._client_stream_id child_layer.flow.metadata["quic_stream_id_server"] = self._server_stream_id @@ -483,8 +517,8 @@ class RawQuicLayer(layer.Layer): self.ignore = ignore self.datagram_layer = ( UDPLayer(self.context.fork(), ignore=True) - if ignore else - layer.NextLayer(self.context.fork()) + if ignore + else layer.NextLayer(self.context.fork()) ) self.client_stream_ids = {} self.server_stream_ids = {} @@ -508,29 +542,34 @@ class RawQuicLayer(layer.Layer): # properly forward completion events based on their command elif isinstance(event, events.CommandCompleted): - yield from self.event_to_child(self.command_sources.pop(event.command), event) + yield from self.event_to_child( + self.command_sources.pop(event.command), event + ) # route injected messages based on their connections (prefer client, fallback to server) elif isinstance(event, events.MessageInjected): if event.flow.client_conn in self.connections: - yield from self.event_to_child(self.connections[event.flow.client_conn], event) + yield from self.event_to_child( + self.connections[event.flow.client_conn], event + ) elif event.flow.server_conn in self.connections: - yield from self.event_to_child(self.connections[event.flow.server_conn], event) + yield from self.event_to_child( + self.connections[event.flow.server_conn], event + ) else: raise AssertionError(f"Flow not associated: {event.flow!r}") # handle stream events targeting this context - elif ( - isinstance(event, QuicStreamEvent) - and ( - event.connection is self.context.client - or event.connection is self.context.server - ) + elif isinstance(event, QuicStreamEvent) and ( + event.connection is self.context.client + or event.connection is self.context.server ): from_client = event.connection is self.context.client # fetch or create the layer - stream_ids = self.client_stream_ids if from_client else self.server_stream_ids + stream_ids = ( + self.client_stream_ids if from_client else self.server_stream_ids + ) if event.stream_id in stream_ids: stream_layer = stream_ids[event.stream_id] else: @@ -549,7 +588,9 @@ class RawQuicLayer(layer.Layer): server_stream_id = event.stream_id # create, register and start the layer - stream_layer = QuicStreamLayer(self.context.fork(), self.ignore, client_stream_id) + stream_layer = QuicStreamLayer( + self.context.fork(), self.ignore, client_stream_id + ) self.client_stream_ids[client_stream_id] = stream_layer if server_stream_id is not None: stream_layer.open_server_stream(server_stream_id) @@ -562,7 +603,9 @@ class RawQuicLayer(layer.Layer): conn = stream_layer.client if from_client else stream_layer.server if isinstance(event, QuicStreamDataReceived): if event.data: - yield from self.event_to_child(stream_layer, events.DataReceived(conn, event.data)) + yield from self.event_to_child( + stream_layer, events.DataReceived(conn, event.data) + ) if event.end_stream: yield from self.close_stream_layer(stream_layer, from_client) elif isinstance(event, QuicStreamReset): @@ -574,26 +617,27 @@ class RawQuicLayer(layer.Layer): and command.end_stream and not command.data ): - yield ResetQuicStream(command.connection, command.stream_id, event.error_code) + yield ResetQuicStream( + command.connection, command.stream_id, event.error_code + ) else: yield command else: raise AssertionError(f"Unexpected stream event: {event!r}") # handle close events that target this context - elif ( - isinstance(event, QuicConnectionClosed) - and ( - event.connection is self.context.client - or event.connection is self.context.server - ) + elif isinstance(event, QuicConnectionClosed) and ( + event.connection is self.context.client + or event.connection is self.context.server ): from_client = event.connection is self.context.client other_conn = self.context.server if from_client else self.context.client # be done if both connections are closed if other_conn.connected: - yield CloseQuicConnection(other_conn, event.error_code, event.frame_type, event.reason_phrase) + yield CloseQuicConnection( + other_conn, event.error_code, event.frame_type, event.reason_phrase + ) else: self._handle_event = self.done # type: ignore @@ -607,16 +651,14 @@ class RawQuicLayer(layer.Layer): # forward to either the client or server connection of stream layers and swallow empty stream end for conn, child_layer in self.connections.items(): - if ( - isinstance(child_layer, QuicStreamLayer) - and ((conn is child_layer.client) if from_client else (conn is child_layer.server)) + if isinstance(child_layer, QuicStreamLayer) and ( + (conn is child_layer.client) + if from_client + else (conn is child_layer.server) ): conn.state &= ~connection.ConnectionState.CAN_WRITE for command in self.close_stream_layer(child_layer, from_client): - if ( - not isinstance(command, SendQuicStreamData) - or command.data - ): + if not isinstance(command, SendQuicStreamData) or command.data: yield command # all other connection events are routed to their corresponding layer @@ -626,7 +668,9 @@ class RawQuicLayer(layer.Layer): else: raise AssertionError(f"Unexpected event: {event!r}") - def close_stream_layer(self, stream_layer: QuicStreamLayer, client: bool) -> layer.CommandGenerator[None]: + def close_stream_layer( + self, stream_layer: QuicStreamLayer, client: bool + ) -> layer.CommandGenerator[None]: """Closes the incoming part of a connection.""" conn = stream_layer.client if client else stream_layer.server @@ -636,7 +680,9 @@ class RawQuicLayer(layer.Layer): conn.timestamp_end = time.time() yield from self.event_to_child(stream_layer, events.ConnectionClosed(conn)) - def event_to_child(self, child_layer: layer.Layer, event: events.Event) -> layer.CommandGenerator[None]: + def event_to_child( + self, child_layer: layer.Layer, event: events.Event + ) -> layer.CommandGenerator[None]: """Forwards events to child layers and translates commands.""" for command in child_layer.handle_event(event): @@ -664,16 +710,24 @@ class RawQuicLayer(layer.Layer): elif isinstance(command, commands.CloseConnection): assert stream_id is not None if command.connection.state & connection.ConnectionState.CAN_WRITE: - command.connection.state &= ~connection.ConnectionState.CAN_WRITE - yield SendQuicStreamData(quic_conn, stream_id, b"", end_stream=True) + command.connection.state &= ( + ~connection.ConnectionState.CAN_WRITE + ) + yield SendQuicStreamData( + quic_conn, stream_id, b"", end_stream=True + ) # XXX: Use `command.connection.state & connection.ConnectionState.CAN_READ` instead? - only_close_our_half = isinstance(command, commands.CloseTcpConnection) and command.half_close + only_close_our_half = ( + isinstance(command, commands.CloseTcpConnection) + and command.half_close + ) if not only_close_our_half: - if ( - stream_is_client_initiated(stream_id) == to_client - or not stream_is_unidirectional(stream_id) - ): - yield StopQuicStream(quic_conn, stream_id, QuicErrorCode.NO_ERROR) + if stream_is_client_initiated( + stream_id + ) == to_client or not stream_is_unidirectional(stream_id): + yield StopQuicStream( + quic_conn, stream_id, QuicErrorCode.NO_ERROR + ) yield from self.close_stream_layer(child_layer, to_client) # open server connections by reserving the next stream ID @@ -684,14 +738,18 @@ class RawQuicLayer(layer.Layer): assert client_stream_id is not None stream_id = self.get_next_available_stream_id( is_client=True, - is_unidirectional=stream_is_unidirectional(client_stream_id) + is_unidirectional=stream_is_unidirectional(client_stream_id), ) child_layer.open_server_stream(stream_id) self.server_stream_ids[stream_id] = child_layer - yield from self.event_to_child(child_layer, events.OpenConnectionCompleted(command, None)) + yield from self.event_to_child( + child_layer, events.OpenConnectionCompleted(command, None) + ) else: - raise AssertionError(f"Unexpected stream connection command: {command!r}") + raise AssertionError( + f"Unexpected stream connection command: {command!r}" + ) # remember blocking and wakeup commands else: @@ -701,7 +759,9 @@ class RawQuicLayer(layer.Layer): self.connections[command.connection] = child_layer yield command - def get_next_available_stream_id(self, is_client: bool, is_unidirectional: bool = False) -> int: + def get_next_available_stream_id( + self, is_client: bool, is_unidirectional: bool = False + ) -> int: index = (int(is_unidirectional) << 1) | int(not is_client) stream_id = self.next_stream_id[index] self.next_stream_id[index] = stream_id + 4 @@ -715,7 +775,12 @@ class QuicLayer(tunnel.TunnelLayer): quic: QuicConnection | None = None tls: QuicTlsSettings | None = None - def __init__(self, context: context.Context, conn: connection.Connection, time: Callable[[], float] | None) -> None: + def __init__( + self, + context: context.Context, + conn: connection.Connection, + time: Callable[[], float] | None, + ) -> None: super().__init__(context, tunnel_connection=conn, conn=conn) self.child_layer = layer.NextLayer(self.context, ask_on_start=True) self._time = time or ctx.master.event_loop.time @@ -723,10 +788,7 @@ class QuicLayer(tunnel.TunnelLayer): conn.tls = True def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: - if ( - isinstance(event, events.Wakeup) - and event.command in self._wakeup_commands - ): + if isinstance(event, events.Wakeup) and event.command in self._wakeup_commands: # TunnelLayer has no understanding of wakeups, so we turn this into an empty DataReceived event # which TunnelLayer recognizes as belonging to our connection. assert self.quic @@ -746,15 +808,16 @@ class QuicLayer(tunnel.TunnelLayer): if self.quic: yield from self.tls_interact() - def _handle_command(self, command: commands.Command) -> layer.CommandGenerator[None]: + def _handle_command( + self, command: commands.Command + ) -> layer.CommandGenerator[None]: """Turns stream commands into aioquic connection invocations.""" - if ( - isinstance(command, QuicStreamCommand) - and command.connection is self.conn - ): + if isinstance(command, QuicStreamCommand) and command.connection is self.conn: assert self.quic if isinstance(command, SendQuicStreamData): - self.quic.send_stream_data(command.stream_id, command.data, command.end_stream) + self.quic.send_stream_data( + command.stream_id, command.data, command.end_stream + ) elif isinstance(command, ResetQuicStream): self.quic.reset_stream(command.stream_id, command.error_code) elif isinstance(command, StopQuicStream): @@ -766,7 +829,9 @@ class QuicLayer(tunnel.TunnelLayer): else: yield from super()._handle_command(command) - def start_tls(self, original_destination_connection_id: bytes | None) -> layer.CommandGenerator[None]: + def start_tls( + self, original_destination_connection_id: bytes | None + ) -> layer.CommandGenerator[None]: """Initiates the aioquic connection.""" # must only be called if QUIC is uninitialized @@ -780,7 +845,9 @@ class QuicLayer(tunnel.TunnelLayer): else: yield QuicStartServerHook(tls_data) if not tls_data.settings: - yield commands.Log(f"No QUIC context was provided, failing connection.", ERROR) + yield commands.Log( + f"No QUIC context was provided, failing connection.", ERROR + ) yield commands.CloseConnection(self.conn) return @@ -812,15 +879,16 @@ class QuicLayer(tunnel.TunnelLayer): # request a new wakeup if all pending requests trigger at a later time timer = self.quic.get_timer() - if ( - timer is not None - and not any(existing <= timer for existing in self._wakeup_commands.values()) + if timer is not None and not any( + existing <= timer for existing in self._wakeup_commands.values() ): command = commands.RequestWakeup(timer - self._time()) self._wakeup_commands[command] = timer yield command - def receive_handshake_data(self, data: bytes) -> layer.CommandGenerator[tuple[bool, str | None]]: + def receive_handshake_data( + self, data: bytes + ) -> layer.CommandGenerator[tuple[bool, str | None]]: assert self.quic # forward incoming data to aioquic @@ -854,18 +922,25 @@ class QuicLayer(tunnel.TunnelLayer): f"{self.debug}[quic] tls established: {self.conn}", DEBUG ) if self.conn is self.context.client: - yield TlsEstablishedClientHook(QuicTlsData(self.conn, self.context, settings=self.tls)) + yield TlsEstablishedClientHook( + QuicTlsData(self.conn, self.context, settings=self.tls) + ) else: - yield TlsEstablishedServerHook(QuicTlsData(self.conn, self.context, settings=self.tls)) + yield TlsEstablishedServerHook( + QuicTlsData(self.conn, self.context, settings=self.tls) + ) yield from self.tls_interact() return True, None - elif isinstance(event, ( - quic_events.ConnectionIdIssued, - quic_events.ConnectionIdRetired, - quic_events.PingAcknowledged, - quic_events.ProtocolNegotiated, - )): + elif isinstance( + event, + ( + quic_events.ConnectionIdIssued, + quic_events.ConnectionIdRetired, + quic_events.PingAcknowledged, + quic_events.ProtocolNegotiated, + ), + ): pass else: raise AssertionError(f"Unexpected event: {event!r}") @@ -877,9 +952,13 @@ class QuicLayer(tunnel.TunnelLayer): def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]: self.conn.error = err if self.conn is self.context.client: - yield TlsFailedClientHook(QuicTlsData(self.conn, self.context, settings=self.tls)) + yield TlsFailedClientHook( + QuicTlsData(self.conn, self.context, settings=self.tls) + ) else: - yield TlsFailedServerHook(QuicTlsData(self.conn, self.context, settings=self.tls)) + yield TlsFailedServerHook( + QuicTlsData(self.conn, self.context, settings=self.tls) + ) yield from super().on_handshake_error(err) def receive_data(self, data: bytes) -> layer.CommandGenerator[None]: @@ -895,7 +974,8 @@ class QuicLayer(tunnel.TunnelLayer): if self.debug: reason = event.reason_phrase or error_code_to_str(event.error_code) yield commands.Log( - f"{self.debug}[quic] close_notify {self.conn} (reason={reason})", DEBUG + f"{self.debug}[quic] close_notify {self.conn} (reason={reason})", + DEBUG, ) # We don't rely on `ConnectionTerminated` to dispatch `QuicConnectionClosed`, because # after aioquic receives a termination frame, it still waits for the next `handle_timer` @@ -905,17 +985,28 @@ class QuicLayer(tunnel.TunnelLayer): yield commands.CloseConnection(self.tunnel_connection) return # we don't handle any further events, nor do/can we transmit data, so exit elif isinstance(event, quic_events.DatagramFrameReceived): - yield from self.event_to_child(events.DataReceived(self.conn, event.data)) + yield from self.event_to_child( + events.DataReceived(self.conn, event.data) + ) elif isinstance(event, quic_events.StreamDataReceived): - yield from self.event_to_child(QuicStreamDataReceived(self.conn, event.stream_id, event.data, event.end_stream)) + yield from self.event_to_child( + QuicStreamDataReceived( + self.conn, event.stream_id, event.data, event.end_stream + ) + ) elif isinstance(event, quic_events.StreamReset): - yield from self.event_to_child(QuicStreamReset(self.conn, event.stream_id, event.error_code)) - elif isinstance(event, ( - quic_events.ConnectionIdIssued, - quic_events.ConnectionIdRetired, - quic_events.PingAcknowledged, - quic_events.ProtocolNegotiated, - )): + yield from self.event_to_child( + QuicStreamReset(self.conn, event.stream_id, event.error_code) + ) + elif isinstance( + event, + ( + quic_events.ConnectionIdIssued, + quic_events.ConnectionIdRetired, + quic_events.PingAcknowledged, + quic_events.ProtocolNegotiated, + ), + ): pass else: raise AssertionError(f"Unexpected event: {event!r}") @@ -931,7 +1022,12 @@ class QuicLayer(tunnel.TunnelLayer): QuicErrorCode.NO_ERROR, None, "Connection closed." ) yield from self.event_to_child( - QuicConnectionClosed(self.conn, close_event.error_code, close_event.frame_type, close_event.reason_phrase) + QuicConnectionClosed( + self.conn, + close_event.error_code, + close_event.frame_type, + close_event.reason_phrase, + ) ) def send_data(self, data: bytes) -> layer.CommandGenerator[None]: @@ -941,11 +1037,15 @@ class QuicLayer(tunnel.TunnelLayer): self.quic.send_datagram_frame(data) yield from self.tls_interact() - def send_close(self, command: commands.CloseConnection) -> layer.CommandGenerator[None]: + def send_close( + self, command: commands.CloseConnection + ) -> layer.CommandGenerator[None]: # properly close the QUIC connection if self.quic: if isinstance(command, CloseQuicConnection): - self.quic.close(command.error_code, command.frame_type, command.reason_phrase) + self.quic.close( + command.error_code, command.frame_type, command.reason_phrase + ) else: self.quic.close() yield from self.tls_interact() @@ -959,13 +1059,17 @@ class ServerQuicLayer(QuicLayer): wait_for_clienthello: bool = False - def __init__(self, context: context.Context, conn: connection.Server | None = None, time: Callable[[], float] | None = None): + def __init__( + self, + context: context.Context, + conn: connection.Server | None = None, + time: Callable[[], float] | None = None, + ): super().__init__(context, conn or context.server, time) def start_handshake(self) -> layer.CommandGenerator[None]: - wait_for_clienthello = ( - not self.command_to_reply_to - and isinstance(self.child_layer, ClientQuicLayer) + wait_for_clienthello = not self.command_to_reply_to and isinstance( + self.child_layer, ClientQuicLayer ) if wait_for_clienthello: self.wait_for_clienthello = True @@ -999,7 +1103,9 @@ class ClientQuicLayer(QuicLayer): server_tls_available: bool """Indicates whether the parent layer is a ServerQuicLayer.""" - def __init__(self, context: context.Context, time: Callable[[], float] | None = None) -> None: + def __init__( + self, context: context.Context, time: Callable[[], float] | None = None + ) -> None: # same as ClientTLSLayer, we might be nested in some other transport if context.client.tls: context.client.alpn = None @@ -1018,7 +1124,9 @@ class ClientQuicLayer(QuicLayer): def start_handshake(self) -> layer.CommandGenerator[None]: yield from () - def receive_handshake_data(self, data: bytes) -> layer.CommandGenerator[tuple[bool, str | None]]: + def receive_handshake_data( + self, data: bytes + ) -> layer.CommandGenerator[tuple[bool, str | None]]: # if we already had a valid client hello, don't process further packets if self.tls: return (yield from super().receive_handshake_data(data)) @@ -1049,7 +1157,10 @@ class ClientQuicLayer(QuicLayer): # ensure it's (likely) a client handshake packet if len(data) < 1200 or header.packet_type != PACKET_TYPE_INITIAL: - return False, f"Invalid handshake received, roaming not supported. ({data.hex()})" + return ( + False, + f"Invalid handshake received, roaming not supported. ({data.hex()})", + ) # extract the client hello try: @@ -1068,7 +1179,8 @@ class ClientQuicLayer(QuicLayer): # replace the QUIC layer with an UDP layer if requested if tls_clienthello.ignore_connection: self.conn = self.tunnel_connection = connection.Client( - peername=("ignore-conn", 0), sockname=("ignore-conn", 0), + peername=("ignore-conn", 0), + sockname=("ignore-conn", 0), transport_protocol="udp", state=connection.ConnectionState.OPEN, ) @@ -1083,7 +1195,9 @@ class ClientQuicLayer(QuicLayer): parent_layer.handle_event = replacement_layer.handle_event # type: ignore parent_layer._handle_event = replacement_layer._handle_event # type: ignore yield from parent_layer.handle_event(events.Start()) - yield from parent_layer.handle_event(events.DataReceived(self.context.client, data)) + yield from parent_layer.handle_event( + events.DataReceived(self.context.client, data) + ) return True, None # start the server QUIC connection if demanded and available @@ -1122,4 +1236,6 @@ class ClientQuicLayer(QuicLayer): def errored(self, event: events.Event) -> layer.CommandGenerator[None]: if self.debug is not None: - yield commands.Log(f"{self.debug}[quic] Swallowing {event} as handshake failed.", DEBUG) + yield commands.Log( + f"{self.debug}[quic] Swallowing {event} as handshake failed.", DEBUG + ) diff --git a/mitmproxy/proxy/layers/tcp.py b/mitmproxy/proxy/layers/tcp.py index 2d1ff7305..0272d4ed5 100644 --- a/mitmproxy/proxy/layers/tcp.py +++ b/mitmproxy/proxy/layers/tcp.py @@ -1,10 +1,14 @@ from dataclasses import dataclass from typing import Optional -from mitmproxy import flow, tcp -from mitmproxy.proxy import commands, events, layer +from mitmproxy import flow +from mitmproxy import tcp +from mitmproxy.connection import Connection +from mitmproxy.connection import ConnectionState +from mitmproxy.proxy import commands +from mitmproxy.proxy import events +from mitmproxy.proxy import layer from mitmproxy.proxy.commands import StartHook -from mitmproxy.connection import ConnectionState, Connection from mitmproxy.proxy.context import Context from mitmproxy.proxy.events import MessageInjected from mitmproxy.proxy.utils import expect diff --git a/mitmproxy/proxy/layers/tls.py b/mitmproxy/proxy/layers/tls.py index 1e00b8be8..bb74b8b49 100644 --- a/mitmproxy/proxy/layers/tls.py +++ b/mitmproxy/proxy/layers/tls.py @@ -1,18 +1,28 @@ import struct -from logging import DEBUG, ERROR, INFO, WARNING - import time +from collections.abc import Iterator from dataclasses import dataclass -from typing import Iterator, Optional +from logging import DEBUG +from logging import ERROR +from logging import INFO +from logging import WARNING +from typing import Optional from OpenSSL import SSL -from mitmproxy import certs, connection -from mitmproxy.proxy import commands, events, layer, tunnel +from mitmproxy import certs +from mitmproxy import connection +from mitmproxy.proxy import commands from mitmproxy.proxy import context +from mitmproxy.proxy import events +from mitmproxy.proxy import layer +from mitmproxy.proxy import tunnel from mitmproxy.proxy.commands import StartHook -from mitmproxy.proxy.layers import tcp, udp -from mitmproxy.tls import ClientHello, ClientHelloData, TlsData +from mitmproxy.proxy.layers import tcp +from mitmproxy.proxy.layers import udp +from mitmproxy.tls import ClientHello +from mitmproxy.tls import ClientHelloData +from mitmproxy.tls import TlsData from mitmproxy.utils import human @@ -96,7 +106,7 @@ def is_dtls_handshake_record(d: bytes) -> bool: True, if the passed bytes start with the DTLS record magic bytes False, otherwise. """ - return len(d) >= 3 and d[0] == 0x16 and d[1] == 0xfe and d[2] == 0xfd + return len(d) >= 3 and d[0] == 0x16 and d[1] == 0xFE and d[2] == 0xFD def dtls_handshake_record_contents(data: bytes) -> Iterator[bytes]: @@ -136,7 +146,9 @@ def get_dtls_client_hello(data: bytes) -> Optional[bytes]: client_hello += d if len(client_hello) >= 13: # comment about slicing: we skip the epoch and sequence number - client_hello_size = struct.unpack("!I", b"\x00" + client_hello[9:12])[0] + 12 + client_hello_size = ( + struct.unpack("!I", b"\x00" + client_hello[9:12])[0] + 12 + ) if len(client_hello) >= client_hello_size: return client_hello[:client_hello_size] return None @@ -257,7 +269,9 @@ class TLSLayer(tunnel.TunnelLayer): conn.tls = True def __repr__(self): - return super().__repr__().replace(")", f" {self.conn.sni!r} {self.conn.alpn!r})") + return ( + super().__repr__().replace(")", f" {self.conn.sni!r} {self.conn.alpn!r})") + ) @property def is_dtls(self): @@ -265,7 +279,7 @@ class TLSLayer(tunnel.TunnelLayer): @property def proto_name(self): - return 'DTLS' if self.is_dtls else 'TLS' + return "DTLS" if self.is_dtls else "TLS" def start_tls(self) -> layer.CommandGenerator[None]: assert not self.tls @@ -421,9 +435,7 @@ class TLSLayer(tunnel.TunnelLayer): if close: self.conn.state &= ~connection.ConnectionState.CAN_READ if self.debug: - yield commands.Log( - f"{self.debug}[tls] close_notify {self.conn}", DEBUG - ) + yield commands.Log(f"{self.debug}[tls] close_notify {self.conn}", DEBUG) yield from self.event_to_child(events.ConnectionClosed(self.conn)) def receive_close(self) -> layer.CommandGenerator[None]: @@ -440,7 +452,9 @@ class TLSLayer(tunnel.TunnelLayer): pass yield from self.tls_interact() - def send_close(self, command: commands.CloseConnection) -> layer.CommandGenerator[None]: + def send_close( + self, command: commands.CloseConnection + ) -> layer.CommandGenerator[None]: # We should probably shutdown the TLS connection properly here. yield from super().send_close(command) @@ -659,7 +673,9 @@ class ClientTLSLayer(TLSLayer): def errored(self, event: events.Event) -> layer.CommandGenerator[None]: if self.debug is not None: - yield commands.Log(f"{self.debug}[tls] Swallowing {event} as handshake failed.", DEBUG) + yield commands.Log( + f"{self.debug}[tls] Swallowing {event} as handshake failed.", DEBUG + ) class MockTLSLayer(TLSLayer): diff --git a/mitmproxy/proxy/layers/udp.py b/mitmproxy/proxy/layers/udp.py index 3026a71ce..e80fc7b9d 100644 --- a/mitmproxy/proxy/layers/udp.py +++ b/mitmproxy/proxy/layers/udp.py @@ -1,10 +1,13 @@ from dataclasses import dataclass from typing import Optional -from mitmproxy import flow, udp -from mitmproxy.proxy import commands, events, layer -from mitmproxy.proxy.commands import StartHook +from mitmproxy import flow +from mitmproxy import udp from mitmproxy.connection import Connection +from mitmproxy.proxy import commands +from mitmproxy.proxy import events +from mitmproxy.proxy import layer +from mitmproxy.proxy.commands import StartHook from mitmproxy.proxy.context import Context from mitmproxy.proxy.events import MessageInjected from mitmproxy.proxy.utils import expect diff --git a/mitmproxy/proxy/layers/websocket.py b/mitmproxy/proxy/layers/websocket.py index a2c57fee9..24c291b76 100644 --- a/mitmproxy/proxy/layers/websocket.py +++ b/mitmproxy/proxy/layers/websocket.py @@ -1,19 +1,23 @@ import time +from collections.abc import Iterator from dataclasses import dataclass -from typing import Iterator -import wsproto import wsproto.extensions import wsproto.frame_protocol import wsproto.utilities -from mitmproxy import connection, http, websocket -from mitmproxy.proxy import commands, events, layer +from wsproto import ConnectionState +from wsproto.frame_protocol import Opcode + +from mitmproxy import connection +from mitmproxy import http +from mitmproxy import websocket +from mitmproxy.proxy import commands +from mitmproxy.proxy import events +from mitmproxy.proxy import layer from mitmproxy.proxy.commands import StartHook from mitmproxy.proxy.context import Context from mitmproxy.proxy.events import MessageInjected from mitmproxy.proxy.utils import expect -from wsproto import ConnectionState -from wsproto.frame_protocol import Opcode @dataclass diff --git a/mitmproxy/proxy/mode_servers.py b/mitmproxy/proxy/mode_servers.py index 3c7587f0a..b40615670 100644 --- a/mitmproxy/proxy/mode_servers.py +++ b/mitmproxy/proxy/mode_servers.py @@ -12,25 +12,36 @@ Example: from __future__ import annotations import asyncio +import errno import json import logging import socket import textwrap import typing -from abc import ABCMeta, abstractmethod +from abc import ABCMeta +from abc import abstractmethod from contextlib import contextmanager from pathlib import Path -from typing import ClassVar, Generic, TypeVar, cast, get_args +from typing import cast +from typing import ClassVar +from typing import Generic +from typing import get_args +from typing import TypeVar -import errno import mitmproxy_wireguard as wg -from mitmproxy import ctx, flow, platform +from mitmproxy import ctx +from mitmproxy import flow +from mitmproxy import platform from mitmproxy.connection import Address from mitmproxy.master import Master -from mitmproxy.net import local_ip, udp +from mitmproxy.net import local_ip +from mitmproxy.net import udp from mitmproxy.net.udp_wireguard import WireGuardDatagramTransport -from mitmproxy.proxy import commands, layers, mode_specs, server +from mitmproxy.proxy import commands +from mitmproxy.proxy import layers +from mitmproxy.proxy import mode_specs +from mitmproxy.proxy import server from mitmproxy.proxy.context import Context from mitmproxy.proxy.layer import Layer from mitmproxy.utils import human @@ -55,14 +66,16 @@ class ProxyConnectionHandler(server.LiveConnectionHandler): await data.wait_for_resume() # pragma: no cover -M = TypeVar('M', bound=mode_specs.ProxyMode) +M = TypeVar("M", bound=mode_specs.ProxyMode) class ServerManager(typing.Protocol): connections: dict[tuple, ProxyConnectionHandler] @contextmanager - def register_connection(self, connection_id: tuple, handler: ProxyConnectionHandler): + def register_connection( + self, connection_id: tuple, handler: ProxyConnectionHandler + ): ... # pragma: no cover @@ -89,7 +102,7 @@ class ServerInstance(Generic[M], metaclass=ABCMeta): @classmethod def make( - cls: typing.Type[Self], + cls: type[Self], mode: mode_specs.ProxyMode | str, manager: ServerManager, ) -> Self: @@ -196,7 +209,9 @@ class ServerInstance(Generic[M], metaclass=ABCMeta): reader = cast(udp.DatagramReader, handler.transports[handler.client].reader) reader.feed_data(data, remote_addr) - async def handle_udp_connection(self, connection_id: tuple, handler: ProxyConnectionHandler) -> None: + async def handle_udp_connection( + self, connection_id: tuple, handler: ProxyConnectionHandler + ) -> None: with self.manager.register_connection(connection_id, handler): await handler.handle_client() @@ -220,7 +235,9 @@ class AsyncioServerInstance(ServerInstance[M], metaclass=ABCMeta): self.last_exception = e message = f"{self.mode.description} failed to listen on {host or '*'}:{port} with {e}" if e.errno == errno.EADDRINUSE and self.mode.custom_listen_port is None: - assert self.mode.custom_listen_host is None # since [@ [listen_addr:]listen_port] + assert ( + self.mode.custom_listen_host is None + ) # since [@ [listen_addr:]listen_port] message += f"\nTry specifying a different port by using `--mode {self.mode.full_spec}@{port + 1}`." raise OSError(e.errno, message, e.filename) from e except Exception as e: @@ -261,9 +278,13 @@ class AsyncioServerInstance(ServerInstance[M], metaclass=ABCMeta): s.bind(("", 0)) fixed_port = s.getsockname()[1] s.close() - return await asyncio.start_server(self.handle_tcp_connection, host, fixed_port) + return await asyncio.start_server( + self.handle_tcp_connection, host, fixed_port + ) except Exception as e: - logger.debug(f"Failed to listen on a single port ({e!r}), falling back to default behavior.") + logger.debug( + f"Failed to listen on a single port ({e!r}), falling back to default behavior." + ) return await asyncio.start_server(self.handle_tcp_connection, host, port) elif self.mode.transport_protocol == "udp": # create_datagram_endpoint only creates one socket, so the workaround above doesn't apply @@ -309,17 +330,24 @@ class WireGuardServerInstance(ServerInstance[mode_specs.WireGuardMode]): try: if not conf_path.exists(): conf_path.parent.mkdir(parents=True, exist_ok=True) - conf_path.write_text(json.dumps({ - "server_key": wg.genkey(), - "client_key": wg.genkey(), - }, indent=4)) + conf_path.write_text( + json.dumps( + { + "server_key": wg.genkey(), + "client_key": wg.genkey(), + }, + indent=4, + ) + ) try: c = json.loads(conf_path.read_text()) self.server_key = c["server_key"] self.client_key = c["client_key"] except Exception as e: - raise ValueError(f"Invalid configuration file ({conf_path}): {e}") from e + raise ValueError( + f"Invalid configuration file ({conf_path}): {e}" + ) from e # error early on invalid keys p = wg.pubkey(self.client_key) _ = wg.pubkey(self.server_key) @@ -355,7 +383,8 @@ class WireGuardServerInstance(ServerInstance[mode_specs.WireGuardMode]): return None host = local_ip.get_local_ip() or local_ip.get_local_ip6() port = self.mode.listen_port(ctx.options.listen_port) - return textwrap.dedent(f""" + return textwrap.dedent( + f""" [Interface] PrivateKey = {self.client_key} Address = 10.0.0.1/32 @@ -365,13 +394,11 @@ class WireGuardServerInstance(ServerInstance[mode_specs.WireGuardMode]): PublicKey = {wg.pubkey(self.server_key)} AllowedIPs = 0.0.0.0/0 Endpoint = {host}:{port} - """).strip() + """ + ).strip() def to_json(self) -> dict: - return { - "wireguard_conf": self.client_conf(), - **super().to_json() - } + return {"wireguard_conf": self.client_conf(), **super().to_json()} async def stop(self) -> None: assert self._server is not None @@ -390,15 +417,12 @@ class WireGuardServerInstance(ServerInstance[mode_specs.WireGuardMode]): async def wg_handle_tcp_connection(self, stream: wg.TcpStream) -> None: await self.handle_tcp_connection(stream, stream) - def wg_handle_udp_datagram(self, data: bytes, remote_addr: Address, local_addr: Address) -> None: + def wg_handle_udp_datagram( + self, data: bytes, remote_addr: Address, local_addr: Address + ) -> None: assert self._server is not None transport = WireGuardDatagramTransport(self._server, local_addr, remote_addr) - self.handle_udp_datagram( - transport, - data, - remote_addr, - local_addr - ) + self.handle_udp_datagram(transport, data, remote_addr, local_addr) class RegularInstance(AsyncioServerInstance[mode_specs.RegularMode]): diff --git a/mitmproxy/proxy/mode_specs.py b/mitmproxy/proxy/mode_specs.py index 0092c8077..b13762ced 100644 --- a/mitmproxy/proxy/mode_specs.py +++ b/mitmproxy/proxy/mode_specs.py @@ -19,14 +19,16 @@ Examples: RegularMode.parse("socks5") # ValueError """ - from __future__ import annotations import dataclasses -from abc import ABCMeta, abstractmethod +from abc import ABCMeta +from abc import abstractmethod from dataclasses import dataclass from functools import cache -from typing import ClassVar, Literal, Type, TypeVar +from typing import ClassVar +from typing import Literal +from typing import TypeVar from mitmproxy.coretypes.serializable import Serializable from mitmproxy.net import server_spec @@ -41,6 +43,7 @@ class ProxyMode(Serializable, metaclass=ABCMeta): Parsed representation of a proxy mode spec. Subclassed for each specific mode, which then does its own data validation. """ + full_spec: str """The full proxy mode spec as entered by the user.""" data: str @@ -50,9 +53,11 @@ class ProxyMode(Serializable, metaclass=ABCMeta): custom_listen_port: int | None """A custom listen port, if specified in the spec.""" - type_name: ClassVar[str] # automatically derived from the class name in __init_subclass__ + type_name: ClassVar[ + str + ] # automatically derived from the class name in __init_subclass__ """The unique name for this proxy mode, e.g. "regular" or "reverse".""" - __types: ClassVar[dict[str, Type[ProxyMode]]] = {} + __types: ClassVar[dict[str, type[ProxyMode]]] = {} def __init_subclass__(cls, **kwargs): cls.type_name = cls.__name__.removesuffix("Mode").lower() @@ -85,7 +90,7 @@ class ProxyMode(Serializable, metaclass=ABCMeta): @classmethod @cache - def parse(cls: Type[Self], spec: str) -> Self: + def parse(cls: type[Self], spec: str) -> Self: """ Parse a proxy mode specification and return the corresponding `ProxyMode` instance. """ @@ -121,10 +126,7 @@ class ProxyMode(Serializable, metaclass=ABCMeta): raise ValueError(f"{mode!r} is not a spec for a {cls.type_name} mode") return mode_cls( - full_spec=spec, - data=data, - custom_listen_host=host, - custom_listen_port=port + full_spec=spec, data=data, custom_listen_host=host, custom_listen_port=port ) def listen_host(self, default: str | None = None) -> str: @@ -165,8 +167,8 @@ class ProxyMode(Serializable, metaclass=ABCMeta): raise dataclasses.FrozenInstanceError("Proxy modes are immutable.") -TCP: Literal['tcp', 'udp'] = "tcp" -UDP: Literal['tcp', 'udp'] = "udp" +TCP: Literal["tcp", "udp"] = "tcp" +UDP: Literal["tcp", "udp"] = "udp" def _check_empty(data): @@ -176,6 +178,7 @@ def _check_empty(data): class RegularMode(ProxyMode): """A regular HTTP(S) proxy that is interfaced with `HTTP CONNECT` calls (or absolute-form HTTP requests).""" + description = "HTTP(S) proxy" transport_protocol = TCP @@ -185,6 +188,7 @@ class RegularMode(ProxyMode): class TransparentMode(ProxyMode): """A transparent proxy, see https://docs.mitmproxy.org/dev/howto-transparent/""" + description = "transparent proxy" transport_protocol = TCP @@ -194,6 +198,7 @@ class TransparentMode(ProxyMode): class UpstreamMode(ProxyMode): """A regular HTTP(S) proxy, but all connections are forwarded to a second upstream HTTP(S) proxy.""" + description = "HTTP(S) proxy (upstream mode)" transport_protocol = TCP scheme: Literal["http", "https"] @@ -209,9 +214,12 @@ class UpstreamMode(ProxyMode): class ReverseMode(ProxyMode): """A reverse proxy. This acts like a normal server, but redirects all requests to a fixed target.""" + description = "reverse proxy" transport_protocol = TCP - scheme: Literal["http", "https", "http3", "tls", "dtls", "tcp", "udp", "dns", "quic"] + scheme: Literal[ + "http", "https", "http3", "tls", "dtls", "tcp", "udp", "dns", "quic" + ] address: tuple[str, int] # noinspection PyDataclass @@ -230,6 +238,7 @@ class ReverseMode(ProxyMode): class Socks5Mode(ProxyMode): """A SOCKSv5 proxy.""" + description = "SOCKS v5 proxy" default_port = 1080 transport_protocol = TCP @@ -240,6 +249,7 @@ class Socks5Mode(ProxyMode): class DnsMode(ProxyMode): """A DNS server.""" + description = "DNS server" default_port = 53 transport_protocol = UDP @@ -253,6 +263,7 @@ class Http3Mode(ProxyMode): A regular HTTP3 proxy that is interfaced with absolute-form HTTP requests. (This class will be merged into `RegularMode` once the UDP implementation is deemed stable enough.) """ + description = "HTTP3 proxy" transport_protocol = UDP @@ -262,6 +273,7 @@ class Http3Mode(ProxyMode): class WireGuardMode(ProxyMode): """Proxy Server based on WireGuard""" + description = "WireGuard server" default_port = 51820 transport_protocol = UDP diff --git a/mitmproxy/proxy/server.py b/mitmproxy/proxy/server.py index 24959848c..ac717dfae 100644 --- a/mitmproxy/proxy/server.py +++ b/mitmproxy/proxy/server.py @@ -10,23 +10,35 @@ import abc import asyncio import collections import logging - import time import traceback -from collections.abc import Awaitable, Callable, MutableMapping +from collections.abc import Awaitable +from collections.abc import Callable +from collections.abc import MutableMapping from contextlib import contextmanager from dataclasses import dataclass -from typing import Optional, Union +from typing import Optional +from typing import Union import mitmproxy_wireguard as wg from OpenSSL import SSL -from mitmproxy import http, options as moptions, tls +from mitmproxy import http +from mitmproxy import options as moptions +from mitmproxy import tls +from mitmproxy.connection import Address +from mitmproxy.connection import Client +from mitmproxy.connection import Connection +from mitmproxy.connection import ConnectionState +from mitmproxy.net import udp +from mitmproxy.proxy import commands +from mitmproxy.proxy import events +from mitmproxy.proxy import layer +from mitmproxy.proxy import layers +from mitmproxy.proxy import mode_specs +from mitmproxy.proxy import server_hooks from mitmproxy.proxy.context import Context from mitmproxy.proxy.layers.http import HTTPMode -from mitmproxy.proxy import commands, events, layer, layers, mode_specs, server_hooks -from mitmproxy.connection import Address, Client, Connection, ConnectionState -from mitmproxy.net import udp from mitmproxy.utils import asyncio_utils from mitmproxy.utils import human from mitmproxy.utils.data import pkg_data @@ -80,8 +92,12 @@ class TimeoutWatchdog: @dataclass class ConnectionIO: handler: Optional[asyncio.Task] = None - reader: Optional[Union[asyncio.StreamReader, udp.DatagramReader, wg.TcpStream]] = None - writer: Optional[Union[asyncio.StreamWriter, udp.DatagramWriter, wg.TcpStream]] = None + reader: Optional[ + Union[asyncio.StreamReader, udp.DatagramReader, wg.TcpStream] + ] = None + writer: Optional[ + Union[asyncio.StreamWriter, udp.DatagramWriter, wg.TcpStream] + ] = None class ConnectionHandler(metaclass=abc.ABCMeta): @@ -135,7 +151,10 @@ class ConnectionHandler(metaclass=abc.ABCMeta): self.server_event(events.Start()) await asyncio.wait([handler]) if not handler.cancelled() and (e := handler.exception()): - self.log(f"mitmproxy has crashed!\n{traceback.format_exception(e)}", logging.ERROR) + self.log( + f"mitmproxy has crashed!\n{traceback.format_exception(e)}", + logging.ERROR, + ) watch.cancel() while self.wakeup_timer: @@ -331,11 +350,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta): pass def log(self, message: str, level: int = logging.INFO) -> None: - logger.log( - level, - message, - extra={"client": self.client.peername} - ) + logger.log(level, message, extra={"client": self.client.peername}) def server_event(self, event: events.Event) -> None: self.timeout_watchdog.register_activity() diff --git a/mitmproxy/proxy/server_hooks.py b/mitmproxy/proxy/server_hooks.py index 22e1e5418..a00c6ca19 100644 --- a/mitmproxy/proxy/server_hooks.py +++ b/mitmproxy/proxy/server_hooks.py @@ -1,7 +1,7 @@ from dataclasses import dataclass -from mitmproxy import connection from . import commands +from mitmproxy import connection @dataclass diff --git a/mitmproxy/proxy/tunnel.py b/mitmproxy/proxy/tunnel.py index 435693de6..5aa42cded 100644 --- a/mitmproxy/proxy/tunnel.py +++ b/mitmproxy/proxy/tunnel.py @@ -1,9 +1,14 @@ import time -from enum import Enum, auto -from typing import Optional, Union +from enum import auto +from enum import Enum +from typing import Optional +from typing import Union from mitmproxy import connection -from mitmproxy.proxy import commands, context, events, layer +from mitmproxy.proxy import commands +from mitmproxy.proxy import context +from mitmproxy.proxy import events +from mitmproxy.proxy import layer from mitmproxy.proxy.layer import Layer @@ -108,7 +113,9 @@ class TunnelLayer(layer.Layer): yield from self.event_to_child(evt) self._event_queue.clear() - def _handle_command(self, command: commands.Command) -> layer.CommandGenerator[None]: + def _handle_command( + self, command: commands.Command + ) -> layer.CommandGenerator[None]: if ( isinstance(command, commands.ConnectionCommand) and command.connection == self.conn @@ -170,7 +177,9 @@ class TunnelLayer(layer.Layer): def send_data(self, data: bytes) -> layer.CommandGenerator[None]: yield commands.SendData(self.tunnel_connection, data) - def send_close(self, command: commands.CloseConnection) -> layer.CommandGenerator[None]: + def send_close( + self, command: commands.CloseConnection + ) -> layer.CommandGenerator[None]: yield command diff --git a/mitmproxy/script/concurrent.py b/mitmproxy/script/concurrent.py index 9d9546568..587a65e78 100644 --- a/mitmproxy/script/concurrent.py +++ b/mitmproxy/script/concurrent.py @@ -2,9 +2,9 @@ This module provides a @concurrent decorator primitive to offload computations from mitmproxy's main master thread. """ - import asyncio import inspect + from mitmproxy import hooks diff --git a/mitmproxy/tcp.py b/mitmproxy/tcp.py index 2ad13336c..eec4c26fc 100644 --- a/mitmproxy/tcp.py +++ b/mitmproxy/tcp.py @@ -1,6 +1,7 @@ import time -from mitmproxy import connection, flow +from mitmproxy import connection +from mitmproxy import flow from mitmproxy.coretypes import serializable diff --git a/mitmproxy/test/taddons.py b/mitmproxy/test/taddons.py index 24dbaebda..82ee2de30 100644 --- a/mitmproxy/test/taddons.py +++ b/mitmproxy/test/taddons.py @@ -2,10 +2,11 @@ import asyncio import mitmproxy.master import mitmproxy.options -from mitmproxy import hooks from mitmproxy import command from mitmproxy import eventsequence -from mitmproxy.addons import script, core +from mitmproxy import hooks +from mitmproxy.addons import core +from mitmproxy.addons import script class context: diff --git a/mitmproxy/test/tflow.py b/mitmproxy/test/tflow.py index bcb2d211b..b632abc67 100644 --- a/mitmproxy/test/tflow.py +++ b/mitmproxy/test/tflow.py @@ -1,5 +1,8 @@ import uuid -from typing import Optional, Union +from typing import Optional +from typing import Union + +from wsproto.frame_protocol import Opcode from mitmproxy import connection from mitmproxy import dns @@ -10,9 +13,10 @@ from mitmproxy import udp from mitmproxy import websocket from mitmproxy.connection import ConnectionState from mitmproxy.proxy.mode_specs import ProxyMode -from mitmproxy.test.tutils import tdnsreq, tdnsresp -from mitmproxy.test.tutils import treq, tresp -from wsproto.frame_protocol import Opcode +from mitmproxy.test.tutils import tdnsreq +from mitmproxy.test.tutils import tdnsresp +from mitmproxy.test.tutils import treq +from mitmproxy.test.tutils import tresp def ttcpflow( diff --git a/mitmproxy/tls.py b/mitmproxy/tls.py index 98fed77a4..a8d93ffda 100644 --- a/mitmproxy/tls.py +++ b/mitmproxy/tls.py @@ -3,10 +3,11 @@ from dataclasses import dataclass from typing import Optional from kaitaistruct import KaitaiStream - from OpenSSL import SSL + from mitmproxy import connection -from mitmproxy.contrib.kaitaistruct import tls_client_hello, dtls_client_hello +from mitmproxy.contrib.kaitaistruct import dtls_client_hello +from mitmproxy.contrib.kaitaistruct import tls_client_hello from mitmproxy.net import check from mitmproxy.proxy import context @@ -18,7 +19,7 @@ class ClientHello: _raw_bytes: bytes - def __init__(self, raw_client_hello: bytes, dtls: bool=False): + def __init__(self, raw_client_hello: bytes, dtls: bool = False): """Create a TLS ClientHello object from raw bytes.""" self._raw_bytes = raw_client_hello if dtls: diff --git a/mitmproxy/tools/console/commander/commander.py b/mitmproxy/tools/console/commander/commander.py index 53acbf4bd..9c2fcb3a8 100644 --- a/mitmproxy/tools/console/commander/commander.py +++ b/mitmproxy/tools/console/commander/commander.py @@ -1,6 +1,7 @@ import abc from collections.abc import Sequence -from typing import NamedTuple, Optional +from typing import NamedTuple +from typing import Optional import urwid from urwid.text_layout import calc_coords diff --git a/mitmproxy/tools/console/commandexecutor.py b/mitmproxy/tools/console/commandexecutor.py index d683a1e4b..fea007fe9 100644 --- a/mitmproxy/tools/console/commandexecutor.py +++ b/mitmproxy/tools/console/commandexecutor.py @@ -3,7 +3,6 @@ from collections.abc import Sequence from mitmproxy import exceptions from mitmproxy import flow - from mitmproxy.tools.console import overlay from mitmproxy.tools.console import signals diff --git a/mitmproxy/tools/console/commands.py b/mitmproxy/tools/console/commands.py index 78564ef17..ee1049254 100644 --- a/mitmproxy/tools/console/commands.py +++ b/mitmproxy/tools/console/commands.py @@ -1,6 +1,7 @@ -import urwid import textwrap +import urwid + from mitmproxy import command from mitmproxy.tools.console import layoutwidget from mitmproxy.tools.console import signals diff --git a/mitmproxy/tools/console/common.py b/mitmproxy/tools/console/common.py index 65c12db15..58f2be69a 100644 --- a/mitmproxy/tools/console/common.py +++ b/mitmproxy/tools/console/common.py @@ -1,22 +1,23 @@ import enum -import platform import math +import platform from collections.abc import Iterable from functools import lru_cache -from typing import Optional, Union +from typing import Optional +from typing import Union -from publicsuffix2 import get_sld, get_tld - -import urwid import urwid.util +from publicsuffix2 import get_sld +from publicsuffix2 import get_tld +from mitmproxy import dns from mitmproxy import flow +from mitmproxy.dns import DNSFlow from mitmproxy.http import HTTPFlow -from mitmproxy.utils import human, emoji from mitmproxy.tcp import TCPFlow from mitmproxy.udp import UDPFlow -from mitmproxy import dns -from mitmproxy.dns import DNSFlow +from mitmproxy.utils import emoji +from mitmproxy.utils import human # Detect Windows Subsystem for Linux and Windows IS_WINDOWS_OR_WSL = ( diff --git a/mitmproxy/tools/console/consoleaddons.py b/mitmproxy/tools/console/consoleaddons.py index f88603d0a..6810bc42a 100644 --- a/mitmproxy/tools/console/consoleaddons.py +++ b/mitmproxy/tools/console/consoleaddons.py @@ -3,7 +3,8 @@ import logging from collections.abc import Sequence import mitmproxy.types -from mitmproxy import command, command_lexer +from mitmproxy import command +from mitmproxy import command_lexer from mitmproxy import contentviews from mitmproxy import ctx from mitmproxy import dns diff --git a/mitmproxy/tools/console/eventlog.py b/mitmproxy/tools/console/eventlog.py index da53539e1..ab6a03c65 100644 --- a/mitmproxy/tools/console/eventlog.py +++ b/mitmproxy/tools/console/eventlog.py @@ -1,8 +1,9 @@ import collections import urwid -from mitmproxy.tools.console import layoutwidget + from mitmproxy import log +from mitmproxy.tools.console import layoutwidget class LogBufferWalker(urwid.SimpleListWalker): diff --git a/mitmproxy/tools/console/flowdetailview.py b/mitmproxy/tools/console/flowdetailview.py index 56161d50c..9b55ef2d7 100644 --- a/mitmproxy/tools/console/flowdetailview.py +++ b/mitmproxy/tools/console/flowdetailview.py @@ -4,8 +4,10 @@ import urwid import mitmproxy.flow from mitmproxy import http -from mitmproxy.tools.console import common, searchable -from mitmproxy.utils import human, strutils +from mitmproxy.tools.console import common +from mitmproxy.tools.console import searchable +from mitmproxy.utils import human +from mitmproxy.utils import strutils def maybe_timestamp(base, attr): diff --git a/mitmproxy/tools/console/flowview.py b/mitmproxy/tools/console/flowview.py index 20e969c39..8aae522a8 100644 --- a/mitmproxy/tools/console/flowview.py +++ b/mitmproxy/tools/console/flowview.py @@ -1,5 +1,4 @@ import logging - import math import sys from functools import lru_cache diff --git a/mitmproxy/tools/console/grideditor/__init__.py b/mitmproxy/tools/console/grideditor/__init__.py index c13ea70e3..6bcae5b94 100644 --- a/mitmproxy/tools/console/grideditor/__init__.py +++ b/mitmproxy/tools/console/grideditor/__init__.py @@ -1,17 +1,15 @@ from . import base -from .editors import ( - CookieAttributeEditor, - CookieEditor, - DataViewer, - OptionsEditor, - PathEditor, - QueryEditor, - RequestHeaderEditor, - RequestMultipartEditor, - RequestUrlEncodedEditor, - ResponseHeaderEditor, - SetCookieEditor, -) +from .editors import CookieAttributeEditor +from .editors import CookieEditor +from .editors import DataViewer +from .editors import OptionsEditor +from .editors import PathEditor +from .editors import QueryEditor +from .editors import RequestHeaderEditor +from .editors import RequestMultipartEditor +from .editors import RequestUrlEncodedEditor +from .editors import ResponseHeaderEditor +from .editors import SetCookieEditor __all__ = [ "base", diff --git a/mitmproxy/tools/console/grideditor/base.py b/mitmproxy/tools/console/grideditor/base.py index 5759c0cf3..a59717487 100644 --- a/mitmproxy/tools/console/grideditor/base.py +++ b/mitmproxy/tools/console/grideditor/base.py @@ -1,16 +1,23 @@ import abc import copy import os -from collections.abc import Callable, Container, Iterable, MutableSequence, Sequence -from typing import Any, AnyStr, ClassVar, Optional +from collections.abc import Callable +from collections.abc import Container +from collections.abc import Iterable +from collections.abc import MutableSequence +from collections.abc import Sequence +from typing import Any +from typing import AnyStr +from typing import ClassVar +from typing import Optional import urwid -from mitmproxy.utils import strutils -from mitmproxy import exceptions -from mitmproxy.tools.console import signals -from mitmproxy.tools.console import layoutwidget import mitmproxy.tools.console.master +from mitmproxy import exceptions +from mitmproxy.tools.console import layoutwidget +from mitmproxy.tools.console import signals +from mitmproxy.utils import strutils def read_file(filename: str, escaped: bool) -> AnyStr: diff --git a/mitmproxy/tools/console/grideditor/col_bytes.py b/mitmproxy/tools/console/grideditor/col_bytes.py index f29147417..9af1a3544 100644 --- a/mitmproxy/tools/console/grideditor/col_bytes.py +++ b/mitmproxy/tools/console/grideditor/col_bytes.py @@ -1,4 +1,5 @@ import urwid + from mitmproxy.tools.console import signals from mitmproxy.tools.console.grideditor import base from mitmproxy.utils import strutils diff --git a/mitmproxy/tools/console/grideditor/col_subgrid.py b/mitmproxy/tools/console/grideditor/col_subgrid.py index be4b4271b..17887b1af 100644 --- a/mitmproxy/tools/console/grideditor/col_subgrid.py +++ b/mitmproxy/tools/console/grideditor/col_subgrid.py @@ -1,7 +1,8 @@ import urwid -from mitmproxy.tools.console.grideditor import base -from mitmproxy.tools.console import signals + from mitmproxy.net.http import cookies +from mitmproxy.tools.console import signals +from mitmproxy.tools.console.grideditor import base class Column(base.Column): @@ -20,9 +21,7 @@ class Column(base.Column): def keypress(self, key: str, editor): if key in "rRe": - signals.status_message.send( - message="Press enter to edit this field." - ) + signals.status_message.send(message="Press enter to edit this field.") return elif key == "m_select": self.subeditor.grideditor = editor diff --git a/mitmproxy/tools/console/grideditor/col_text.py b/mitmproxy/tools/console/grideditor/col_text.py index d5ad1cba0..04dbb5ab0 100644 --- a/mitmproxy/tools/console/grideditor/col_text.py +++ b/mitmproxy/tools/console/grideditor/col_text.py @@ -4,7 +4,6 @@ Welcome to the encoding dance! In a nutshell, text columns are actually a proxy class for byte columns, which just encode/decodes contents. """ - from mitmproxy.tools.console import signals from mitmproxy.tools.console.grideditor import col_bytes diff --git a/mitmproxy/tools/console/grideditor/col_viewany.py b/mitmproxy/tools/console/grideditor/col_viewany.py index 2801587c0..b6ffe1f44 100644 --- a/mitmproxy/tools/console/grideditor/col_viewany.py +++ b/mitmproxy/tools/console/grideditor/col_viewany.py @@ -4,6 +4,7 @@ A display-only column that displays any data type. from typing import Any import urwid + from mitmproxy.tools.console.grideditor import base from mitmproxy.utils import strutils diff --git a/mitmproxy/tools/console/grideditor/editors.py b/mitmproxy/tools/console/grideditor/editors.py index 4e2677b65..bfc9b3862 100644 --- a/mitmproxy/tools/console/grideditor/editors.py +++ b/mitmproxy/tools/console/grideditor/editors.py @@ -1,4 +1,5 @@ -from typing import Any, Union +from typing import Any +from typing import Union import urwid diff --git a/mitmproxy/tools/console/keybindings.py b/mitmproxy/tools/console/keybindings.py index 903f71e43..5cb3819b5 100644 --- a/mitmproxy/tools/console/keybindings.py +++ b/mitmproxy/tools/console/keybindings.py @@ -1,6 +1,7 @@ -import urwid import textwrap +import urwid + from mitmproxy.tools.console import layoutwidget from mitmproxy.tools.console import signals from mitmproxy.utils import signals as utils_signals diff --git a/mitmproxy/tools/console/keymap.py b/mitmproxy/tools/console/keymap.py index d4fbcaf81..5cadd010d 100644 --- a/mitmproxy/tools/console/keymap.py +++ b/mitmproxy/tools/console/keymap.py @@ -4,15 +4,14 @@ from collections.abc import Sequence from functools import cache from typing import Optional -import ruamel.yaml import ruamel.yaml.error +import mitmproxy.types from mitmproxy import command -from mitmproxy.tools.console import commandexecutor -from mitmproxy.tools.console import signals from mitmproxy import ctx from mitmproxy import exceptions -import mitmproxy.types +from mitmproxy.tools.console import commandexecutor +from mitmproxy.tools.console import signals class KeyBindingError(Exception): @@ -62,7 +61,9 @@ class Binding: return self.key.replace("space", " ") def key_short(self) -> str: - return self.key.replace("enter", "⏎").replace("right", "→").replace("space", "␣") + return ( + self.key.replace("enter", "⏎").replace("right", "→").replace("space", "␣") + ) def sortkey(self): return self.key + ",".join(self.contexts) diff --git a/mitmproxy/tools/console/master.py b/mitmproxy/tools/console/master.py index 415e4a675..3bad84096 100644 --- a/mitmproxy/tools/console/master.py +++ b/mitmproxy/tools/console/master.py @@ -1,6 +1,6 @@ import asyncio +import contextlib import mimetypes -import os import os.path import shlex import shutil @@ -8,20 +8,19 @@ import stat import subprocess import sys import tempfile -import contextlib import threading from typing import TypeVar +import urwid from tornado.platform.asyncio import AddThreadSelectorEventLoop -import urwid - from mitmproxy import addons +from mitmproxy import log from mitmproxy import master from mitmproxy import options -from mitmproxy import log -from mitmproxy.addons import errorcheck, intercept +from mitmproxy.addons import errorcheck from mitmproxy.addons import eventstore +from mitmproxy.addons import intercept from mitmproxy.addons import readfile from mitmproxy.addons import view from mitmproxy.contrib.tornado import patch_tornado diff --git a/mitmproxy/tools/console/options.py b/mitmproxy/tools/console/options.py index 01b055f9e..8aca078ba 100644 --- a/mitmproxy/tools/console/options.py +++ b/mitmproxy/tools/console/options.py @@ -1,16 +1,17 @@ from __future__ import annotations + +import pprint +import textwrap from collections.abc import Sequence +from typing import Optional import urwid -import textwrap -import pprint -from typing import Optional from mitmproxy import exceptions from mitmproxy import optmanager from mitmproxy.tools.console import layoutwidget -from mitmproxy.tools.console import signals from mitmproxy.tools.console import overlay +from mitmproxy.tools.console import signals HELP_HEIGHT = 5 diff --git a/mitmproxy/tools/console/overlay.py b/mitmproxy/tools/console/overlay.py index cd3216191..17b55bc10 100644 --- a/mitmproxy/tools/console/overlay.py +++ b/mitmproxy/tools/console/overlay.py @@ -2,10 +2,10 @@ import math import urwid -from mitmproxy.tools.console import signals from mitmproxy.tools.console import grideditor -from mitmproxy.tools.console import layoutwidget from mitmproxy.tools.console import keymap +from mitmproxy.tools.console import layoutwidget +from mitmproxy.tools.console import signals class SimpleOverlay(urwid.Overlay, layoutwidget.LayoutWidget): diff --git a/mitmproxy/tools/console/palettes.py b/mitmproxy/tools/console/palettes.py index afbec3a01..415322be3 100644 --- a/mitmproxy/tools/console/palettes.py +++ b/mitmproxy/tools/console/palettes.py @@ -4,7 +4,9 @@ # http://urwid.org/manual/displayattributes.html # from __future__ import annotations -from collections.abc import Mapping, Sequence + +from collections.abc import Mapping +from collections.abc import Sequence from typing import Optional diff --git a/mitmproxy/tools/console/quickhelp.py b/mitmproxy/tools/console/quickhelp.py index 18e6c90d7..a24f81004 100644 --- a/mitmproxy/tools/console/quickhelp.py +++ b/mitmproxy/tools/console/quickhelp.py @@ -2,7 +2,8 @@ This module is reponsible for drawing the quick key help at the bottom of mitmproxy. """ from dataclasses import dataclass -from typing import Optional, Union +from typing import Optional +from typing import Union import urwid @@ -21,6 +22,7 @@ from mitmproxy.tools.console.options import Options @dataclass class BasicKeyHelp: """Quick help for urwid-builtin keybindings (i.e. those keys that do not appear in the keymap)""" + key: str @@ -181,7 +183,7 @@ def _make_row(label: str, items: HelpItems, keymap: Keymap) -> urwid.Columns: " ", short, ], - wrap="clip" + wrap="clip", ) cols.append((14, txt)) diff --git a/mitmproxy/tools/console/signals.py b/mitmproxy/tools/console/signals.py index 024fad500..38c3f5cfa 100644 --- a/mitmproxy/tools/console/signals.py +++ b/mitmproxy/tools/console/signals.py @@ -18,7 +18,9 @@ status_message = signals.SyncSignal(_status_message) # Prompt for input -def _status_prompt(prompt: str, text: str | None, callback: Callable[[str], None]) -> None: +def _status_prompt( + prompt: str, text: str | None, callback: Callable[[str], None] +) -> None: ... @@ -26,7 +28,9 @@ status_prompt = signals.SyncSignal(_status_prompt) # Prompt for a single keystroke -def _status_prompt_onekey(prompt: str, keys: list[tuple[str, str]], callback: Callable[[str], None]) -> None: +def _status_prompt_onekey( + prompt: str, keys: list[tuple[str, str]], callback: Callable[[str], None] +) -> None: ... diff --git a/mitmproxy/tools/console/statusbar.py b/mitmproxy/tools/console/statusbar.py index 1ea3fe7fc..a09950000 100644 --- a/mitmproxy/tools/console/statusbar.py +++ b/mitmproxy/tools/console/statusbar.py @@ -1,4 +1,5 @@ from __future__ import annotations + from collections.abc import Callable from functools import lru_cache from typing import Optional @@ -6,15 +7,19 @@ from typing import Optional import urwid import mitmproxy.tools.console.master -from mitmproxy.tools.console import commandexecutor, flowlist, quickhelp +from mitmproxy.tools.console import commandexecutor from mitmproxy.tools.console import common +from mitmproxy.tools.console import flowlist +from mitmproxy.tools.console import quickhelp from mitmproxy.tools.console import signals from mitmproxy.tools.console.commander import commander from mitmproxy.utils import human @lru_cache -def shorten_message(msg: tuple[str, str] | str, max_width: int) -> list[tuple[str, str]]: +def shorten_message( + msg: tuple[str, str] | str, max_width: int +) -> list[tuple[str, str]]: """ Shorten message so that it fits into a single line in the statusbar. """ @@ -69,7 +74,9 @@ class ActionBar(urwid.WidgetWrap): if not self.prompting and flow is None or flow == self.master.view.focus.flow: self.show_quickhelp() - def sig_message(self, message: tuple[str, str] | str, expire: int | None = 1) -> None: + def sig_message( + self, message: tuple[str, str] | str, expire: int | None = 1 + ) -> None: if self.prompting: return cols, _ = self.master.ui.get_cols_rows() @@ -84,7 +91,9 @@ class ActionBar(urwid.WidgetWrap): signals.call_in.send(seconds=expire, callback=cb) - def sig_prompt(self, prompt: str, text: str | None, callback: Callable[[str], None]) -> None: + def sig_prompt( + self, prompt: str, text: str | None, callback: Callable[[str], None] + ) -> None: signals.focus.send(section="footer") self.top._w = urwid.Edit(f"{prompt.strip()}: ", text or "") self.bottom._w = urwid.Text("") @@ -109,7 +118,9 @@ class ActionBar(urwid.WidgetWrap): execute = commandexecutor.CommandExecutor(self.master) execute(txt) - def sig_prompt_onekey(self, prompt: str, keys: list[tuple[str, str]], callback: Callable[[str], None]) -> None: + def sig_prompt_onekey( + self, prompt: str, keys: list[tuple[str, str]], callback: Callable[[str], None] + ) -> None: """ Keys are a set of (word, key) tuples. The appropriate key in the word is highlighted. @@ -315,10 +326,12 @@ class StatusBar(urwid.WidgetWrap): ("heading", f"{arrow} {marked} [{offset}/{fc}]".ljust(11)), ] - listen_addrs: list[str] = list(dict.fromkeys( - human.format_address(a) - for a in self.master.addons.get("proxyserver").listen_addrs() - )) + listen_addrs: list[str] = list( + dict.fromkeys( + human.format_address(a) + for a in self.master.addons.get("proxyserver").listen_addrs() + ) + ) if listen_addrs: boundaddr = f"[{', '.join(listen_addrs)}]" else: diff --git a/mitmproxy/tools/console/window.py b/mitmproxy/tools/console/window.py index 5c2a2392e..7c9379b57 100644 --- a/mitmproxy/tools/console/window.py +++ b/mitmproxy/tools/console/window.py @@ -2,6 +2,7 @@ import os import re import urwid + from mitmproxy import flow from mitmproxy.tools.console import commands from mitmproxy.tools.console import common @@ -148,7 +149,9 @@ class Window(urwid.Frame): signals.flow_change.connect(self.flow_changed) signals.pop_view_state.connect(self.pop) - self.master.options.subscribe(self.configure, ["console_layout", "console_layout_headers"]) + self.master.options.subscribe( + self.configure, ["console_layout", "console_layout_headers"] + ) self.pane = 0 self.stacks = [WindowStack(master, "flowlist"), WindowStack(master, "eventlog")] diff --git a/mitmproxy/tools/dump.py b/mitmproxy/tools/dump.py index 527a93e8b..6bb269ee3 100644 --- a/mitmproxy/tools/dump.py +++ b/mitmproxy/tools/dump.py @@ -1,7 +1,11 @@ from mitmproxy import addons from mitmproxy import master from mitmproxy import options -from mitmproxy.addons import dumper, errorcheck, keepserving, readfile, termlog +from mitmproxy.addons import dumper +from mitmproxy.addons import errorcheck +from mitmproxy.addons import keepserving +from mitmproxy.addons import readfile +from mitmproxy.addons import termlog class DumpMaster(master.Master): diff --git a/mitmproxy/tools/main.py b/mitmproxy/tools/main.py index 414745109..dfcf944d9 100644 --- a/mitmproxy/tools/main.py +++ b/mitmproxy/tools/main.py @@ -1,18 +1,24 @@ from __future__ import annotations + import argparse import asyncio import logging import os import signal import sys -from collections.abc import Callable, Sequence -from typing import Any, Optional, TypeVar +from collections.abc import Callable +from collections.abc import Sequence +from typing import Any +from typing import Optional +from typing import TypeVar -from mitmproxy import exceptions, master +from mitmproxy import exceptions +from mitmproxy import master from mitmproxy import options from mitmproxy import optmanager from mitmproxy.tools import cmdline -from mitmproxy.utils import debug, arg_check +from mitmproxy.utils import arg_check +from mitmproxy.utils import debug def process_options(parser, opts, args): @@ -29,9 +35,7 @@ def process_options(parser, opts, args): args.flow_detail = 2 adict = { - key: val - for key, val in vars(args).items() - if key in opts and val is not None + key: val for key, val in vars(args).items() if key in opts and val is not None } opts.update(**adict) @@ -55,7 +59,9 @@ def run( logging.getLogger("tornado").setLevel(logging.WARNING) logging.getLogger("asyncio").setLevel(logging.WARNING) logging.getLogger("hpack").setLevel(logging.WARNING) - logging.getLogger("quic").setLevel(logging.WARNING) # aioquic uses a different prefix... + logging.getLogger("quic").setLevel( + logging.WARNING + ) # aioquic uses a different prefix... debug.register_info_dumpers() opts = options.Options() diff --git a/mitmproxy/tools/web/app.py b/mitmproxy/tools/web/app.py index 4b58dc271..153b1afe0 100644 --- a/mitmproxy/tools/web/app.py +++ b/mitmproxy/tools/web/app.py @@ -1,14 +1,18 @@ from __future__ import annotations + import asyncio import hashlib import json import logging import os.path import re -from collections.abc import Callable, Sequence +from collections.abc import Callable +from collections.abc import Sequence from io import BytesIO from itertools import islice -from typing import ClassVar, Optional, Union +from typing import ClassVar +from typing import Optional +from typing import Union import tornado.escape import tornado.web @@ -16,7 +20,9 @@ import tornado.websocket import mitmproxy.flow import mitmproxy.tools.web.master -from mitmproxy import certs, command, contentviews +from mitmproxy import certs +from mitmproxy import command +from mitmproxy import contentviews from mitmproxy import flowfilter from mitmproxy import http from mitmproxy import io @@ -25,8 +31,10 @@ from mitmproxy import optmanager from mitmproxy import version from mitmproxy.dns import DNSFlow from mitmproxy.http import HTTPFlow -from mitmproxy.tcp import TCPFlow, TCPMessage -from mitmproxy.udp import UDPFlow, UDPMessage +from mitmproxy.tcp import TCPFlow +from mitmproxy.tcp import TCPMessage +from mitmproxy.udp import UDPFlow +from mitmproxy.udp import UDPMessage from mitmproxy.utils.emoji import emoji from mitmproxy.utils.strutils import always_str from mitmproxy.websocket import WebSocketMessage @@ -191,7 +199,7 @@ class APIError(tornado.web.HTTPError): class RequestHandler(tornado.web.RequestHandler): - application: "Application" + application: Application def write(self, chunk: Union[str, bytes, dict, list]): # Writing arrays on the top level is ok nowadays. @@ -237,11 +245,11 @@ class RequestHandler(tornado.web.RequestHandler): return self.request.body @property - def view(self) -> "mitmproxy.addons.view.View": + def view(self) -> mitmproxy.addons.view.View: return self.application.master.view @property - def master(self) -> "mitmproxy.tools.web.master.WebMaster": + def master(self) -> mitmproxy.tools.web.master.WebMaster: return self.application.master @property @@ -322,7 +330,11 @@ class DumpFlows(RequestHandler): match = flowfilter.parse(self.request.arguments["filter"][0].decode()) except ValueError: # thrown py flowfilter.parse if filter is invalid raise APIError(400, f"Invalid filter argument / regex") - except (KeyError, IndexError): # Key+Index: ["filter"][0] can fail, if it's not set + except ( + KeyError, + IndexError, + ): # Key+Index: ["filter"][0] can fail, if it's not set + def match(_) -> bool: return True @@ -617,31 +629,35 @@ class DnsRebind(RequestHandler): raise tornado.web.HTTPError( 403, reason="To protect against DNS rebinding, mitmweb can only be accessed by IP at the moment. " - "(https://github.com/mitmproxy/mitmproxy/issues/3234)", + "(https://github.com/mitmproxy/mitmproxy/issues/3234)", ) class State(RequestHandler): def get(self): - self.write({ - "version": version.VERSION, - "contentViews": [v.name for v in contentviews.views if v.name != "Query"], - "servers": [s.to_json() for s in self.master.proxyserver.servers] - }) + self.write( + { + "version": version.VERSION, + "contentViews": [ + v.name for v in contentviews.views if v.name != "Query" + ], + "servers": [s.to_json() for s in self.master.proxyserver.servers], + } + ) class GZipContentAndFlowFiles(tornado.web.GZipContentEncoding): CONTENT_TYPES = { "application/octet-stream", - *tornado.web.GZipContentEncoding.CONTENT_TYPES + *tornado.web.GZipContentEncoding.CONTENT_TYPES, } class Application(tornado.web.Application): - master: "mitmproxy.tools.web.master.WebMaster" + master: mitmproxy.tools.web.master.WebMaster def __init__( - self, master: "mitmproxy.tools.web.master.WebMaster", debug: bool + self, master: mitmproxy.tools.web.master.WebMaster, debug: bool ) -> None: self.master = master super().__init__( diff --git a/mitmproxy/tools/web/master.py b/mitmproxy/tools/web/master.py index 62119e09c..3d0b59bc9 100644 --- a/mitmproxy/tools/web/master.py +++ b/mitmproxy/tools/web/master.py @@ -1,6 +1,6 @@ +import errno import logging -import errno import tornado.httpserver import tornado.ioloop @@ -8,16 +8,19 @@ from mitmproxy import addons from mitmproxy import flow from mitmproxy import log from mitmproxy import master -from mitmproxy import optmanager from mitmproxy import options -from mitmproxy.addons import errorcheck, eventstore +from mitmproxy import optmanager +from mitmproxy.addons import errorcheck +from mitmproxy.addons import eventstore from mitmproxy.addons import intercept from mitmproxy.addons import readfile from mitmproxy.addons import termlog from mitmproxy.addons import view from mitmproxy.addons.proxyserver import Proxyserver from mitmproxy.contrib.tornado import patch_tornado -from mitmproxy.tools.web import app, webaddons, static_viewer +from mitmproxy.tools.web import app +from mitmproxy.tools.web import static_viewer +from mitmproxy.tools.web import webaddons logger = logging.getLogger(__name__) @@ -87,7 +90,7 @@ class WebMaster(master.Master): app.ClientConnection.broadcast( resource="state", cmd="update", - data={"servers": [s.to_json() for s in self.proxyserver.servers]} + data={"servers": [s.to_json() for s in self.proxyserver.servers]}, ) async def running(self): diff --git a/mitmproxy/tools/web/static_viewer.py b/mitmproxy/tools/web/static_viewer.py index 7decf12de..3f9b5dbc6 100644 --- a/mitmproxy/tools/web/static_viewer.py +++ b/mitmproxy/tools/web/static_viewer.py @@ -7,10 +7,12 @@ import time from collections.abc import Iterable from typing import Optional -from mitmproxy import contentviews, http +from mitmproxy import contentviews from mitmproxy import ctx +from mitmproxy import flow from mitmproxy import flowfilter -from mitmproxy import io, flow +from mitmproxy import http +from mitmproxy import io from mitmproxy import version from mitmproxy.tools.web.app import flow_to_json diff --git a/mitmproxy/types.py b/mitmproxy/types.py index e3161aeea..8645811ac 100644 --- a/mitmproxy/types.py +++ b/mitmproxy/types.py @@ -1,13 +1,17 @@ import codecs -import os import glob +import os import re from collections.abc import Sequence -from typing import Any, Optional, TYPE_CHECKING, Union +from typing import Any +from typing import Optional +from typing import TYPE_CHECKING +from typing import Union from mitmproxy import exceptions from mitmproxy import flow -from mitmproxy.utils import emoji, strutils +from mitmproxy.utils import emoji +from mitmproxy.utils import strutils if TYPE_CHECKING: # pragma: no cover from mitmproxy.command import CommandManager diff --git a/mitmproxy/udp.py b/mitmproxy/udp.py index e719de0ac..2d716e910 100644 --- a/mitmproxy/udp.py +++ b/mitmproxy/udp.py @@ -1,6 +1,7 @@ import time -from mitmproxy import connection, flow +from mitmproxy import connection +from mitmproxy import flow from mitmproxy.coretypes import serializable diff --git a/mitmproxy/utils/arg_check.py b/mitmproxy/utils/arg_check.py index ad43eaf9b..24923d963 100644 --- a/mitmproxy/utils/arg_check.py +++ b/mitmproxy/utils/arg_check.py @@ -1,5 +1,5 @@ -import sys import re +import sys DEPRECATED = """ --confdir diff --git a/mitmproxy/utils/data.py b/mitmproxy/utils/data.py index 091640ec9..baa8c6abd 100644 --- a/mitmproxy/utils/data.py +++ b/mitmproxy/utils/data.py @@ -1,6 +1,6 @@ -import os.path import importlib import inspect +import os.path class Data: diff --git a/mitmproxy/utils/emoji.py b/mitmproxy/utils/emoji.py index 2bb31432e..2a45eddc9 100644 --- a/mitmproxy/utils/emoji.py +++ b/mitmproxy/utils/emoji.py @@ -2,7 +2,6 @@ """ All of the emoji and characters that can be used as flow markers. """ - # auto-generated. run this file to refresh. emoji = { diff --git a/mitmproxy/utils/human.py b/mitmproxy/utils/human.py index ddb336340..a417c84af 100644 --- a/mitmproxy/utils/human.py +++ b/mitmproxy/utils/human.py @@ -5,11 +5,11 @@ import time from typing import Optional SIZE_UNITS = { - "b": 1024 ** 0, - "k": 1024 ** 1, - "m": 1024 ** 2, - "g": 1024 ** 3, - "t": 1024 ** 4, + "b": 1024**0, + "k": 1024**1, + "m": 1024**2, + "g": 1024**3, + "t": 1024**4, } diff --git a/mitmproxy/utils/magisk.py b/mitmproxy/utils/magisk.py index 286aee926..815e513d9 100644 --- a/mitmproxy/utils/magisk.py +++ b/mitmproxy/utils/magisk.py @@ -1,13 +1,14 @@ -from zipfile import ZipFile import hashlib +import os +from zipfile import ZipFile + from cryptography import x509 from cryptography.hazmat.primitives import serialization -from mitmproxy import certs, ctx +from mitmproxy import certs +from mitmproxy import ctx from mitmproxy.options import CONF_BASENAME -import os - # The following 3 variables are for including in the magisk module as text file MODULE_PROP_TEXT = """id=mitmproxycert name=MITMProxy cert @@ -84,14 +85,14 @@ def get_ca_from_files() -> x509.Certificate: return certstore.default_ca._cert -def subject_hash_old(ca : x509.Certificate) -> str: +def subject_hash_old(ca: x509.Certificate) -> str: # Mimics the -subject_hash_old option of openssl used for android certificate names full_hash = hashlib.md5(ca.subject.public_bytes()).digest() - sho = (full_hash[0] | (full_hash[1] << 8) | (full_hash[2] << 16) | full_hash[3] << 24) + sho = full_hash[0] | (full_hash[1] << 8) | (full_hash[2] << 16) | full_hash[3] << 24 return hex(sho)[2:] -def write_magisk_module(path : str): +def write_magisk_module(path: str): # Makes a zip file that can be loaded by Magisk # Android certs are stored as DER files ca = get_ca_from_files() @@ -103,7 +104,9 @@ def write_magisk_module(path : str): zipp.writestr("config.sh", CONFIG_SH_TEXT) zipp.writestr("META-INF/com/google/android/updater-script", "#MAGISK") zipp.writestr("META-INF/com/google/android/update-binary", UPDATE_BINARY_TEXT) - zipp.writestr("common/file_contexts_image", "/magisk(/.*)? u:object_r:system_file:s0") + zipp.writestr( + "common/file_contexts_image", "/magisk(/.*)? u:object_r:system_file:s0" + ) zipp.writestr("common/post-fs-data.sh", "MODDIR=${0%/*}") zipp.writestr("common/service.sh", "MODDIR=${0%/*}") zipp.writestr("common/system.prop", "") diff --git a/mitmproxy/utils/signals.py b/mitmproxy/utils/signals.py index 37900f68b..cad5e5d1f 100644 --- a/mitmproxy/utils/signals.py +++ b/mitmproxy/utils/signals.py @@ -8,11 +8,16 @@ This is similar to the Blinker library (https://pypi.org/project/blinker/), with - supports async receivers. """ from __future__ import annotations + import asyncio import inspect import weakref -from collections.abc import Callable, Awaitable -from typing import Any, Generic, TypeVar, cast +from collections.abc import Awaitable +from collections.abc import Callable +from typing import Any +from typing import cast +from typing import Generic +from typing import TypeVar try: from typing import ParamSpec @@ -85,11 +90,13 @@ class _AsyncSignal(Generic[P], _SignalMixin): super().disconnect(receiver) async def send(self, *args: P.args, **kwargs: P.kwargs) -> None: - await asyncio.gather(*[ - aws - for aws in super().notify(*args, **kwargs) - if aws is not None and inspect.isawaitable(aws) - ]) + await asyncio.gather( + *[ + aws + for aws in super().notify(*args, **kwargs) + if aws is not None and inspect.isawaitable(aws) + ] + ) # noinspection PyPep8Naming diff --git a/mitmproxy/utils/sliding_window.py b/mitmproxy/utils/sliding_window.py index dca71cbd6..a2cbb300d 100644 --- a/mitmproxy/utils/sliding_window.py +++ b/mitmproxy/utils/sliding_window.py @@ -1,5 +1,8 @@ import itertools -from typing import Iterable, Iterator, Optional, TypeVar +from collections.abc import Iterable +from collections.abc import Iterator +from typing import Optional +from typing import TypeVar T = TypeVar("T") diff --git a/mitmproxy/utils/strutils.py b/mitmproxy/utils/strutils.py index 6f61ff54d..0e2c1b207 100644 --- a/mitmproxy/utils/strutils.py +++ b/mitmproxy/utils/strutils.py @@ -1,7 +1,9 @@ import codecs import io import re -from typing import Iterable, Union, overload +from collections.abc import Iterable +from typing import overload +from typing import Union # https://mypy.readthedocs.io/en/stable/more_types.html#function-overloading @@ -236,7 +238,7 @@ def escape_special_areas( """ buf = io.StringIO() parts = split_special_areas(data, area_delimiter) - rex = re.compile(fr"[{control_characters}]") + rex = re.compile(rf"[{control_characters}]") for i, x in enumerate(parts): if i % 2: x = rex.sub(_move_to_private_code_plane, x) diff --git a/mitmproxy/utils/typecheck.py b/mitmproxy/utils/typecheck.py index 6279cfaeb..a898a5b38 100644 --- a/mitmproxy/utils/typecheck.py +++ b/mitmproxy/utils/typecheck.py @@ -17,7 +17,7 @@ def check_option_type(name: str, value: typing.Any, typeinfo: Type) -> None: TypeError otherwise. This function supports only those types required for options. """ - e = TypeError("Expected {} for {}, but got {}.".format(typeinfo, name, type(value))) + e = TypeError(f"Expected {typeinfo} for {name}, but got {type(value)}.") origin = typing.get_origin(typeinfo) diff --git a/mitmproxy/utils/vt_codes.py b/mitmproxy/utils/vt_codes.py index e33a8f249..7e71d446c 100644 --- a/mitmproxy/utils/vt_codes.py +++ b/mitmproxy/utils/vt_codes.py @@ -49,7 +49,6 @@ if os.name == "nt": ) return ok - else: def ensure_supported(f: IO[str]) -> bool: diff --git a/mitmproxy/websocket.py b/mitmproxy/websocket.py index 6f301922c..657f2f6cd 100644 --- a/mitmproxy/websocket.py +++ b/mitmproxy/websocket.py @@ -7,12 +7,14 @@ This module only defines the classes for individual `WebSocketMessage`s and the """ import time import warnings -from dataclasses import dataclass, field -from typing import Union +from dataclasses import dataclass +from dataclasses import field from typing import Optional +from typing import Union + +from wsproto.frame_protocol import Opcode from mitmproxy.coretypes import serializable -from wsproto.frame_protocol import Opcode WebSocketMessageState = tuple[int, bool, bytes, float, bool, bool] diff --git a/release/build-and-deploy-docker.py b/release/build-and-deploy-docker.py index 50a6ce496..f40783468 100644 --- a/release/build-and-deploy-docker.py +++ b/release/build-and-deploy-docker.py @@ -51,7 +51,8 @@ r = subprocess.run( f"{root / 'release'}:/release", "localtesting", "mitmdump", - "-s", "/release/selftest.py", + "-s", + "/release/selftest.py", ], capture_output=True, ) diff --git a/release/build.py b/release/build.py index 60177e31a..6f7f10c77 100644 --- a/release/build.py +++ b/release/build.py @@ -84,7 +84,9 @@ def archive(path: Path) -> tarfile.TarFile | ZipFile2: def version() -> str: - return os.environ.get("GITHUB_REF_NAME", "").replace("/", "-") or os.environ.get("BUILD_VERSION", "dev") + return os.environ.get("GITHUB_REF_NAME", "").replace("/", "-") or os.environ.get( + "BUILD_VERSION", "dev" + ) def operating_system() -> Literal["windows", "linux", "macos", "unknown"]: @@ -170,7 +172,13 @@ def msix_installer(): manifest = TEMP_DIR / "msix/AppxManifest.xml" app_version = version() if not re.match(r"\d+\.\d+\.\d+", app_version): - app_version = datetime.now().strftime("%y%m.%d.%H%M").replace(".0", ".").replace(".0", ".").replace(".0", ".") + app_version = ( + datetime.now() + .strftime("%y%m.%d.%H%M") + .replace(".0", ".") + .replace(".0", ".") + .replace(".0", ".") + ) manifest.write_text(manifest.read_text().replace("1.2.3", app_version)) makeappx_exe = ( @@ -237,7 +245,9 @@ def installbuilder_installer(): break ib_setup_hash.update(data) if ib_setup_hash.hexdigest() != IB_SETUP_SHA256: # pragma: no cover - raise RuntimeError(f"InstallBuilder hashes don't match: {ib_setup_hash.hexdigest()}") + raise RuntimeError( + f"InstallBuilder hashes don't match: {ib_setup_hash.hexdigest()}" + ) print("Install InstallBuilder...") subprocess.run( diff --git a/release/release.py b/release/release.py index c64602526..48c866ead 100755 --- a/release/release.py +++ b/release/release.py @@ -46,7 +46,10 @@ if __name__ == "__main__": branch = subprocess.run( ["git", "branch", "--show-current"], - cwd=root, check=True, capture_output=True, text=True + cwd=root, + check=True, + capture_output=True, + text=True, ).stdout.strip() print("➡️ Working dir clean?") @@ -56,7 +59,12 @@ if __name__ == "__main__": print(f"⚠️ Skipping status check for {branch}.") else: print(f"➡️ CI is passing for {branch}?") - assert get_json(f"https://api.github.com/repos/{repo}/commits/{branch}/status")["state"] == "success" + assert ( + get_json(f"https://api.github.com/repos/{repo}/commits/{branch}/status")[ + "state" + ] + == "success" + ) print("➡️ Updating CHANGELOG.md...") changelog = root / "CHANGELOG.md" @@ -70,7 +78,9 @@ if __name__ == "__main__": print("➡️ Updating web assets...") subprocess.run(["npm", "ci"], cwd=root / "web", check=True, capture_output=True) - subprocess.run(["npm", "start", "prod"], cwd=root / "web", check=True, capture_output=True) + subprocess.run( + ["npm", "start", "prod"], cwd=root / "web", check=True, capture_output=True + ) print("➡️ Updating version...") version_py = root / "mitmproxy" / "version.py" @@ -80,13 +90,22 @@ if __name__ == "__main__": version_py.write_text(ver, "utf8") print("➡️ Do release commit...") - subprocess.run(["git", "config", "user.email", "noreply@mitmproxy.org"], cwd=root, check=True) - subprocess.run(["git", "config", "user.name", "mitmproxy release bot"], cwd=root, check=True) - subprocess.run(["git", "commit", "-a", "-m", f"mitmproxy {version}"], cwd=root, check=True) + subprocess.run( + ["git", "config", "user.email", "noreply@mitmproxy.org"], cwd=root, check=True + ) + subprocess.run( + ["git", "config", "user.name", "mitmproxy release bot"], cwd=root, check=True + ) + subprocess.run( + ["git", "commit", "-a", "-m", f"mitmproxy {version}"], cwd=root, check=True + ) subprocess.run(["git", "tag", version], cwd=root, check=True) release_sha = subprocess.run( ["git", "rev-parse", "HEAD"], - cwd=root, check=True, capture_output=True, text=True + cwd=root, + check=True, + capture_output=True, + text=True, ).stdout.strip() if branch == "main": @@ -97,15 +116,32 @@ if __name__ == "__main__": version_py.write_text(ver, "utf8") print("➡️ Reopen main for development...") - subprocess.run(["git", "commit", "-a", "-m", f"reopen main for development"], cwd=root, check=True) + subprocess.run( + ["git", "commit", "-a", "-m", f"reopen main for development"], + cwd=root, + check=True, + ) print("➡️ Pushing...") - subprocess.run(["git", "push", "--atomic", "origin", branch, version], cwd=root, check=True) + subprocess.run( + ["git", "push", "--atomic", "origin", branch, version], cwd=root, check=True + ) print("➡️ Creating release on GitHub...") - subprocess.run(["gh", "release", "create", version, - "--title", f"mitmproxy {version}", - "--notes-file", "release/github-release-notes.txt"], cwd=root, check=True) + subprocess.run( + [ + "gh", + "release", + "create", + version, + "--title", + f"mitmproxy {version}", + "--notes-file", + "release/github-release-notes.txt", + ], + cwd=root, + check=True, + ) # We currently have to use a personal access token, which auto-triggers CI. # The default GITHUB_TOKEN cannot push to protected branches, @@ -118,7 +154,9 @@ if __name__ == "__main__": while True: print("⌛ Waiting for CI...") - workflows = get_json(f"https://api.github.com/repos/{repo}/actions/runs?head_sha={release_sha}")["workflow_runs"] + workflows = get_json( + f"https://api.github.com/repos/{repo}/actions/runs?head_sha={release_sha}" + )["workflow_runs"] all_done = True if not workflows: @@ -150,16 +188,23 @@ if __name__ == "__main__": assert resp.status == 200 print(f"➡️ Checking Docker ({version} tag)...") - resp = get(f"https://hub.docker.com/v2/repositories/mitmproxy/mitmproxy/tags/{version}") + resp = get( + f"https://hub.docker.com/v2/repositories/mitmproxy/mitmproxy/tags/{version}" + ) assert resp.status == 200 if branch == "main": print("➡️ Checking Docker (latest tag)...") - docker_latest_data = get_json("https://hub.docker.com/v2/repositories/mitmproxy/mitmproxy/tags/latest") + docker_latest_data = get_json( + "https://hub.docker.com/v2/repositories/mitmproxy/mitmproxy/tags/latest" + ) docker_last_updated = datetime.datetime.fromisoformat( - docker_latest_data["last_updated"].replace("Z", "+00:00")) + docker_latest_data["last_updated"].replace("Z", "+00:00") + ) print(f"Last update: {docker_last_updated.isoformat(timespec='minutes')}") - assert docker_last_updated > datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=2) + assert docker_last_updated > datetime.datetime.now( + datetime.timezone.utc + ) - datetime.timedelta(hours=2) print("") print("✅ All done. 🥳") diff --git a/release/selftest.py b/release/selftest.py index cf1130d6a..01eae5ca0 100644 --- a/release/selftest.py +++ b/release/selftest.py @@ -27,10 +27,7 @@ async def make_request(): cafile = Path(ctx.options.confdir).expanduser() / "mitmproxy-ca.pem" ssl_ctx = ssl.create_default_context(cafile=cafile) port = ctx.master.addons.get("proxyserver").listen_addrs()[0][1] - reader, writer = await asyncio.open_connection( - "127.0.0.1", port, - ssl=ssl_ctx - ) + reader, writer = await asyncio.open_connection("127.0.0.1", port, ssl=ssl_ctx) writer.write(b"GET / HTTP/1.1\r\nHost: mitm.it\r\nConnection: close\r\n\r\n") await writer.drain() resp = await reader.read() diff --git a/setup.py b/setup.py index c1c4617fc..187a8bcda 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,8 @@ import os import re from codecs import open -from setuptools import find_packages, setup +from setuptools import find_packages +from setuptools import setup # Based on https://github.com/pypa/sampleproject/blob/main/setup.py # and https://python-packaging-user-guide.readthedocs.org/ @@ -67,7 +68,7 @@ setup( ], "pyinstaller40": [ "hook-dirs = mitmproxy.utils.pyinstaller:hook_dirs", - ] + ], }, python_requires=">=3.9", # https://packaging.python.org/en/latest/discussions/install-requires-vs-requirements/#install-requires diff --git a/test/conftest.py b/test/conftest.py index 89d78c0e9..77be0686b 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,4 +1,5 @@ from __future__ import annotations + import asyncio import os import socket diff --git a/test/examples/test_examples.py b/test/examples/test_examples.py index 50b08a196..1cf0bd304 100644 --- a/test/examples/test_examples.py +++ b/test/examples/test_examples.py @@ -1,8 +1,8 @@ from mitmproxy import contentviews +from mitmproxy.http import Headers +from mitmproxy.test import taddons from mitmproxy.test import tflow from mitmproxy.test import tutils -from mitmproxy.test import taddons -from mitmproxy.http import Headers class TestScripts: diff --git a/test/filename_matching.py b/test/filename_matching.py index 3e9878f02..9d64ede25 100755 --- a/test/filename_matching.py +++ b/test/filename_matching.py @@ -1,8 +1,7 @@ #!/usr/bin/env python3 - +import glob import os import re -import glob import sys diff --git a/test/full_coverage_plugin.py b/test/full_coverage_plugin.py index 3d1b8b678..7513e6600 100644 --- a/test/full_coverage_plugin.py +++ b/test/full_coverage_plugin.py @@ -1,8 +1,9 @@ -import os import configparser -import pytest +import os import sys +import pytest + here = os.path.abspath(os.path.dirname(__file__)) diff --git a/test/helper_tools/dumperview.py b/test/helper_tools/dumperview.py index 450b7f12f..9c90f784b 100755 --- a/test/helper_tools/dumperview.py +++ b/test/helper_tools/dumperview.py @@ -1,10 +1,11 @@ #!/usr/bin/env python3 import asyncio + import click from mitmproxy.addons import dumper -from mitmproxy.test import tflow from mitmproxy.test import taddons +from mitmproxy.test import tflow def run_async(coro): diff --git a/test/helper_tools/getcert b/test/helper_tools/getcert index 43ebf11dc..841fac644 100644 --- a/test/helper_tools/getcert +++ b/test/helper_tools/getcert @@ -2,9 +2,7 @@ import sys sys.path.insert(0, "../..") import socket -import tempfile import ssl -import subprocess addr = socket.gethostbyname(sys.argv[1]) print(ssl.get_server_certificate((addr, 443))) diff --git a/test/helper_tools/linkify-changelog.py b/test/helper_tools/linkify-changelog.py index f0db26175..77558d87c 100644 --- a/test/helper_tools/linkify-changelog.py +++ b/test/helper_tools/linkify-changelog.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -from pathlib import Path import re +from pathlib import Path changelog = Path(__file__).parent / "../../CHANGELOG.md" diff --git a/test/helper_tools/loggrep.py b/test/helper_tools/loggrep.py index a986e47c4..c9528f491 100755 --- a/test/helper_tools/loggrep.py +++ b/test/helper_tools/loggrep.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 import fileinput -import sys import re +import sys if __name__ == "__main__": if len(sys.argv) < 3: diff --git a/test/helper_tools/memoryleak.py b/test/helper_tools/memoryleak.py index d02482353..12e733ae2 100644 --- a/test/helper_tools/memoryleak.py +++ b/test/helper_tools/memoryleak.py @@ -1,7 +1,9 @@ import gc import threading -from pympler import muppy, refbrowser + from OpenSSL import SSL +from pympler import muppy +from pympler import refbrowser # import os # os.environ["TK_LIBRARY"] = r"C:\Python27\tcl\tcl8.5" diff --git a/test/individual_coverage.py b/test/individual_coverage.py index e3c5ef534..d6d3a40a5 100755 --- a/test/individual_coverage.py +++ b/test/individual_coverage.py @@ -1,13 +1,13 @@ #!/usr/bin/env python3 - -import io +import configparser import contextlib +import glob +import io +import itertools +import multiprocessing import os import sys -import glob -import multiprocessing -import configparser -import itertools + import pytest diff --git a/test/mitmproxy/addons/test_anticache.py b/test/mitmproxy/addons/test_anticache.py index b3eb00d33..a0746decc 100644 --- a/test/mitmproxy/addons/test_anticache.py +++ b/test/mitmproxy/addons/test_anticache.py @@ -1,7 +1,6 @@ -from mitmproxy.test import tflow - from mitmproxy.addons import anticache from mitmproxy.test import taddons +from mitmproxy.test import tflow class TestAntiCache: diff --git a/test/mitmproxy/addons/test_anticomp.py b/test/mitmproxy/addons/test_anticomp.py index 92650332c..70a97dbf5 100644 --- a/test/mitmproxy/addons/test_anticomp.py +++ b/test/mitmproxy/addons/test_anticomp.py @@ -1,7 +1,6 @@ -from mitmproxy.test import tflow - from mitmproxy.addons import anticomp from mitmproxy.test import taddons +from mitmproxy.test import tflow class TestAntiComp: diff --git a/test/mitmproxy/addons/test_browser.py b/test/mitmproxy/addons/test_browser.py index 31cbe292a..e5b9c1557 100644 --- a/test/mitmproxy/addons/test_browser.py +++ b/test/mitmproxy/addons/test_browser.py @@ -6,7 +6,9 @@ from mitmproxy.test import taddons def test_browser(caplog): caplog.set_level("INFO") - with mock.patch("subprocess.Popen") as po, mock.patch("shutil.which") as which, taddons.context(): + with mock.patch("subprocess.Popen") as po, mock.patch( + "shutil.which" + ) as which, taddons.context(): which.return_value = "chrome" b = browser.Browser() b.start() diff --git a/test/mitmproxy/addons/test_clientplayback.py b/test/mitmproxy/addons/test_clientplayback.py index 432111a2d..013d6f1b3 100644 --- a/test/mitmproxy/addons/test_clientplayback.py +++ b/test/mitmproxy/addons/test_clientplayback.py @@ -3,11 +3,14 @@ from contextlib import asynccontextmanager import pytest -from mitmproxy.addons.clientplayback import ClientPlayback, ReplayHandler +from mitmproxy.addons.clientplayback import ClientPlayback +from mitmproxy.addons.clientplayback import ReplayHandler from mitmproxy.addons.proxyserver import Proxyserver -from mitmproxy.exceptions import CommandError, OptionsError from mitmproxy.connection import Address -from mitmproxy.test import taddons, tflow +from mitmproxy.exceptions import CommandError +from mitmproxy.exceptions import OptionsError +from mitmproxy.test import taddons +from mitmproxy.test import tflow @asynccontextmanager diff --git a/test/mitmproxy/addons/test_command_history.py b/test/mitmproxy/addons/test_command_history.py index 915eddf3a..7871e4809 100644 --- a/test/mitmproxy/addons/test_command_history.py +++ b/test/mitmproxy/addons/test_command_history.py @@ -1,6 +1,6 @@ import os -from unittest.mock import patch from pathlib import Path +from unittest.mock import patch from mitmproxy.addons import command_history from mitmproxy.test import taddons diff --git a/test/mitmproxy/addons/test_comment.py b/test/mitmproxy/addons/test_comment.py index ba628cd49..b3c9833bb 100644 --- a/test/mitmproxy/addons/test_comment.py +++ b/test/mitmproxy/addons/test_comment.py @@ -1,5 +1,6 @@ -from mitmproxy.test import tflow, taddons from mitmproxy.addons.comment import Comment +from mitmproxy.test import taddons +from mitmproxy.test import tflow def test_comment(): diff --git a/test/mitmproxy/addons/test_core.py b/test/mitmproxy/addons/test_core.py index fc219769d..be2cb1b13 100644 --- a/test/mitmproxy/addons/test_core.py +++ b/test/mitmproxy/addons/test_core.py @@ -1,8 +1,9 @@ +import pytest + +from mitmproxy import exceptions from mitmproxy.addons import core from mitmproxy.test import taddons from mitmproxy.test import tflow -from mitmproxy import exceptions -import pytest def test_set(): diff --git a/test/mitmproxy/addons/test_cut.py b/test/mitmproxy/addons/test_cut.py index ba045bdfd..d2a30abb4 100644 --- a/test/mitmproxy/addons/test_cut.py +++ b/test/mitmproxy/addons/test_cut.py @@ -1,12 +1,14 @@ +from unittest import mock + +import pyperclip +import pytest + +from mitmproxy import certs +from mitmproxy import exceptions from mitmproxy.addons import cut from mitmproxy.addons import view -from mitmproxy import exceptions -from mitmproxy import certs from mitmproxy.test import taddons from mitmproxy.test import tflow -import pytest -import pyperclip -from unittest import mock def test_extract(tdata): diff --git a/test/mitmproxy/addons/test_disable_h2c.py b/test/mitmproxy/addons/test_disable_h2c.py index 98ec0e3dd..4d55ecfe4 100644 --- a/test/mitmproxy/addons/test_disable_h2c.py +++ b/test/mitmproxy/addons/test_disable_h2c.py @@ -1,7 +1,8 @@ from mitmproxy import flow from mitmproxy.addons import disable_h2c -from mitmproxy.test import taddons, tutils +from mitmproxy.test import taddons from mitmproxy.test import tflow +from mitmproxy.test import tutils class TestDisableH2CleartextUpgrade: diff --git a/test/mitmproxy/addons/test_dns_resolver.py b/test/mitmproxy/addons/test_dns_resolver.py index db91894b4..0937c97a5 100644 --- a/test/mitmproxy/addons/test_dns_resolver.py +++ b/test/mitmproxy/addons/test_dns_resolver.py @@ -6,10 +6,13 @@ from typing import Callable import pytest from mitmproxy import dns -from mitmproxy.addons import dns_resolver, proxyserver +from mitmproxy.addons import dns_resolver +from mitmproxy.addons import proxyserver from mitmproxy.connection import Address from mitmproxy.proxy.mode_specs import ProxyMode -from mitmproxy.test import taddons, tflow, tutils +from mitmproxy.test import taddons +from mitmproxy.test import tflow +from mitmproxy.test import tutils async def test_simple(monkeypatch): diff --git a/test/mitmproxy/addons/test_export.py b/test/mitmproxy/addons/test_export.py index 17187d85e..f3dcb8d76 100644 --- a/test/mitmproxy/addons/test_export.py +++ b/test/mitmproxy/addons/test_export.py @@ -1,15 +1,15 @@ import os import shlex +from unittest import mock -import pytest import pyperclip +import pytest from mitmproxy import exceptions from mitmproxy.addons import export # heh +from mitmproxy.test import taddons from mitmproxy.test import tflow from mitmproxy.test import tutils -from mitmproxy.test import taddons -from unittest import mock @pytest.fixture diff --git a/test/mitmproxy/addons/test_intercept.py b/test/mitmproxy/addons/test_intercept.py index 3cabfda28..0e45a2892 100644 --- a/test/mitmproxy/addons/test_intercept.py +++ b/test/mitmproxy/addons/test_intercept.py @@ -1,7 +1,7 @@ import pytest -from mitmproxy.addons import intercept from mitmproxy import exceptions +from mitmproxy.addons import intercept from mitmproxy.test import taddons from mitmproxy.test import tflow diff --git a/test/mitmproxy/addons/test_keepserving.py b/test/mitmproxy/addons/test_keepserving.py index 99459bf37..8ee85f8ce 100644 --- a/test/mitmproxy/addons/test_keepserving.py +++ b/test/mitmproxy/addons/test_keepserving.py @@ -1,8 +1,8 @@ import asyncio +from mitmproxy import command from mitmproxy.addons import keepserving from mitmproxy.test import taddons -from mitmproxy import command class Dummy: diff --git a/test/mitmproxy/addons/test_maplocal.py b/test/mitmproxy/addons/test_maplocal.py index 917664718..7edf4e8e9 100644 --- a/test/mitmproxy/addons/test_maplocal.py +++ b/test/mitmproxy/addons/test_maplocal.py @@ -3,10 +3,12 @@ from pathlib import Path import pytest -from mitmproxy.addons.maplocal import MapLocal, MapLocalSpec, file_candidates -from mitmproxy.utils.spec import parse_spec +from mitmproxy.addons.maplocal import file_candidates +from mitmproxy.addons.maplocal import MapLocal +from mitmproxy.addons.maplocal import MapLocalSpec from mitmproxy.test import taddons from mitmproxy.test import tflow +from mitmproxy.utils.spec import parse_spec @pytest.mark.parametrize( diff --git a/test/mitmproxy/addons/test_modifyheaders.py b/test/mitmproxy/addons/test_modifyheaders.py index 430c824b9..83a46256f 100644 --- a/test/mitmproxy/addons/test_modifyheaders.py +++ b/test/mitmproxy/addons/test_modifyheaders.py @@ -1,6 +1,7 @@ import pytest -from mitmproxy.addons.modifyheaders import parse_modify_spec, ModifyHeaders +from mitmproxy.addons.modifyheaders import ModifyHeaders +from mitmproxy.addons.modifyheaders import parse_modify_spec from mitmproxy.test import taddons from mitmproxy.test import tflow from mitmproxy.test.tutils import tresp diff --git a/test/mitmproxy/addons/test_next_layer.py b/test/mitmproxy/addons/test_next_layer.py index 4160b88d2..6004c8770 100644 --- a/test/mitmproxy/addons/test_next_layer.py +++ b/test/mitmproxy/addons/test_next_layer.py @@ -5,8 +5,10 @@ import pytest from mitmproxy import connection from mitmproxy.addons.next_layer import NextLayer +from mitmproxy.proxy import context +from mitmproxy.proxy import layer +from mitmproxy.proxy import layers from mitmproxy.proxy.layers.http import HTTPMode -from mitmproxy.proxy import context, layer, layers from mitmproxy.test import taddons from mitmproxy.test import tflow @@ -14,7 +16,11 @@ from mitmproxy.test import tflow @pytest.fixture def tctx(): context.Context( - connection.Client(peername=("client", 1234), sockname=("127.0.0.1", 8080), timestamp_start=1605699329), + connection.Client( + peername=("client", 1234), + sockname=("127.0.0.1", 8080), + timestamp_start=1605699329, + ), tctx.options, ) @@ -38,8 +44,8 @@ client_hello_with_extensions = bytes.fromhex( dtls_client_hello_with_extensions = bytes.fromhex( - "16fefd00000000000000000085" # record layer - "010000790000000000000079" # handshake layer + "16fefd00000000000000000085" # record layer + "010000790000000000000079" # handshake layer "fefd62bf0e0bf809df43e7669197be831919878b1a72c07a584d3c0a8ca6665878010000000cc02bc02fc00ac014c02cc0" "3001000043000d0010000e0403050306030401050106010807ff01000100000a00080006001d00170018000b00020100001" "7000000000010000e00000b6578616d706c652e636f6d" @@ -185,7 +191,7 @@ class TestNextLayer: [ ("dtls", layers.ClientTLSLayer, layers.ServerTLSLayer), ("quic", layers.ClientQuicLayer, layers.ServerQuicLayer), - ] + ], ) def test_next_layer_udp( self, @@ -200,10 +206,7 @@ class TestNextLayer: return isinstance(layer, layers.UDPLayer) and layer.flow is not None def is_http(layer: Optional[layer.Layer], mode: HTTPMode): - return ( - isinstance(layer, layers.HttpLayer) - and layer.mode is mode - ) + return isinstance(layer, layers.HttpLayer) and layer.mode is mode client_hello = { "dtls": dtls_client_hello_with_extensions, @@ -246,11 +249,15 @@ class TestNextLayer: ctx.layers = [layers.modes.TransparentProxy(ctx)] tctx.configure(nl, udp_hosts=["example.com"]) - assert isinstance(nl._next_layer(ctx, tflow.tdnsreq().packed, b""), layers.UDPLayer) + assert isinstance( + nl._next_layer(ctx, tflow.tdnsreq().packed, b""), layers.UDPLayer + ) ctx.layers = [layers.modes.TransparentProxy(ctx)] tctx.configure(nl, udp_hosts=[]) - assert isinstance(nl._next_layer(ctx, tflow.tdnsreq().packed, b""), layers.DNSLayer) + assert isinstance( + nl._next_layer(ctx, tflow.tdnsreq().packed, b""), layers.DNSLayer + ) def test_next_layer_reverse_raw(self): nl = NextLayer() @@ -273,7 +280,9 @@ class TestNextLayer: ctx.layers = [ layers.modes.ReverseProxy(ctx), layers.ServerQuicLayer(ctx), - layers.ClientQuicLayer(ctx,), + layers.ClientQuicLayer( + ctx, + ), ] assert isinstance(nl._next_layer(ctx, b"", b""), layers.RawQuicLayer) @@ -306,12 +315,16 @@ class TestNextLayer: layers.ServerQuicLayer(ctx), ] assert nl._next_layer(ctx, b"", b"") is None - assert isinstance(nl._next_layer(ctx, b"notahandshake", b""), layers.UDPLayer) + assert isinstance( + nl._next_layer(ctx, b"notahandshake", b""), layers.UDPLayer + ) ctx.layers = [ layers.modes.ReverseProxy(ctx), layers.ServerQuicLayer(ctx), ] - assert isinstance(nl._next_layer(ctx, quic_client_hello, b""), layers.ClientQuicLayer) + assert isinstance( + nl._next_layer(ctx, quic_client_hello, b""), layers.ClientQuicLayer + ) def test_next_layer_reverse_http3_mode(self): nl = NextLayer() @@ -325,7 +338,10 @@ class TestNextLayer: layers.modes.ReverseProxy(ctx), layers.ServerQuicLayer(ctx), ] - assert isinstance(nl._next_layer(ctx, b"notahandshakebutignore", b""), layers.ClientQuicLayer) + assert isinstance( + nl._next_layer(ctx, b"notahandshakebutignore", b""), + layers.ClientQuicLayer, + ) assert len(ctx.layers) == 3 decision = nl._next_layer(ctx, b"", b"") assert isinstance(decision, layers.HttpLayer) @@ -352,7 +368,10 @@ class TestNextLayer: ctx.layers = [layers.modes.ReverseProxy(ctx), layers.ServerTLSLayer(ctx)] assert isinstance(nl._next_layer(ctx, b"", b""), layers.UDPLayer) ctx.layers = [layers.modes.ReverseProxy(ctx), layers.ServerTLSLayer(ctx)] - assert isinstance(nl._next_layer(ctx, dtls_client_hello_with_extensions, b""), layers.ClientTLSLayer) + assert isinstance( + nl._next_layer(ctx, dtls_client_hello_with_extensions, b""), + layers.ClientTLSLayer, + ) assert len(ctx.layers) == 3 assert isinstance(nl._next_layer(ctx, b"", b""), layers.UDPLayer) @@ -366,7 +385,10 @@ class TestNextLayer: ctx.layers = [layers.modes.ReverseProxy(ctx)] assert isinstance(nl._next_layer(ctx, b"", b""), layers.UDPLayer) ctx.layers = [layers.modes.ReverseProxy(ctx)] - assert isinstance(nl._next_layer(ctx, dtls_client_hello_with_extensions, b""), layers.ClientTLSLayer) + assert isinstance( + nl._next_layer(ctx, dtls_client_hello_with_extensions, b""), + layers.ClientTLSLayer, + ) assert len(ctx.layers) == 2 assert isinstance(nl._next_layer(ctx, b"", b""), layers.UDPLayer) @@ -380,7 +402,10 @@ class TestNextLayer: ctx.layers = [layers.modes.ReverseProxy(ctx)] assert isinstance(nl._next_layer(ctx, b"", b""), layers.DNSLayer) ctx.layers = [layers.modes.ReverseProxy(ctx)] - assert isinstance(nl._next_layer(ctx, dtls_client_hello_with_extensions, b""), layers.ClientTLSLayer) + assert isinstance( + nl._next_layer(ctx, dtls_client_hello_with_extensions, b""), + layers.ClientTLSLayer, + ) assert len(ctx.layers) == 2 assert isinstance(nl._next_layer(ctx, b"", b""), layers.DNSLayer) diff --git a/test/mitmproxy/addons/test_proxyserver.py b/test/mitmproxy/addons/test_proxyserver.py index 5912d8eac..3027dc920 100644 --- a/test/mitmproxy/addons/test_proxyserver.py +++ b/test/mitmproxy/addons/test_proxyserver.py @@ -1,35 +1,46 @@ from __future__ import annotations import asyncio -from contextlib import asynccontextmanager -from dataclasses import dataclass import socket import ssl -from typing import Any, AsyncGenerator, Callable, ClassVar, Optional, TypeVar +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from dataclasses import dataclass +from typing import Any +from typing import Callable +from typing import ClassVar +from typing import Optional +from typing import TypeVar from unittest.mock import Mock +import pytest from aioquic.asyncio.protocol import QuicConnectionProtocol from aioquic.asyncio.server import QuicServer from aioquic.h3 import events as h3_events -from aioquic.h3.connection import H3Connection, FrameUnexpected +from aioquic.h3.connection import FrameUnexpected +from aioquic.h3.connection import H3Connection from aioquic.quic import events as quic_events from aioquic.quic.configuration import QuicConfiguration -from aioquic.quic.connection import QuicConnection, QuicConnectionError -import pytest -from mitmproxy.addons.next_layer import NextLayer -from mitmproxy.addons.tlsconfig import TlsConfig +from aioquic.quic.connection import QuicConnection +from aioquic.quic.connection import QuicConnectionError import mitmproxy.platform -from mitmproxy import dns, exceptions +from mitmproxy import dns +from mitmproxy import exceptions from mitmproxy.addons import dns_resolver +from mitmproxy.addons.next_layer import NextLayer from mitmproxy.addons.proxyserver import Proxyserver +from mitmproxy.addons.tlsconfig import TlsConfig from mitmproxy.connection import Address from mitmproxy.net import udp -from mitmproxy.proxy import layers, server_hooks +from mitmproxy.proxy import layers +from mitmproxy.proxy import server_hooks from mitmproxy.proxy.layers import tls from mitmproxy.proxy.layers.http import HTTPMode -from mitmproxy.test import taddons, tflow -from mitmproxy.test.tflow import tclient_conn, tserver_conn +from mitmproxy.test import taddons +from mitmproxy.test import tflow +from mitmproxy.test.tflow import tclient_conn +from mitmproxy.test.tflow import tserver_conn from mitmproxy.test.tutils import tdnsreq from mitmproxy.utils import data @@ -298,7 +309,9 @@ async def test_dns(caplog_async) -> None: resp = dns.Message.unpack(await r.read(udp.MAX_DATAGRAM_SIZE)) assert req.id == resp.id and "8.8.8.8" in str(resp) assert len(ps.connections) == 1 - dns_layer = ps.connections[("udp", w.get_extra_info("sockname"), dns_addr)].layer + dns_layer = ps.connections[ + ("udp", w.get_extra_info("sockname"), dns_addr) + ].layer assert isinstance(dns_layer, layers.DNSLayer) assert len(dns_layer.flows) == 2 @@ -339,10 +352,10 @@ async def test_dtls(monkeypatch, caplog_async) -> None: caplog_async.set_level("INFO") def server_handler( - transport: asyncio.DatagramTransport, - data: bytes, - remote_addr: Address, - _: Address, + transport: asyncio.DatagramTransport, + data: bytes, + remote_addr: Address, + _: Address, ): assert data == b"\x16" transport.sendto(b"\x01", remote_addr) @@ -360,7 +373,9 @@ async def test_dtls(monkeypatch, caplog_async) -> None: tctx.configure(ps, mode=[mode]) assert await ps.setup_servers() ps.running() - await caplog_async.await_log(f"reverse proxy to dtls://{server_addr[0]}:{server_addr[1]} listening") + await caplog_async.await_log( + f"reverse proxy to dtls://{server_addr[0]}:{server_addr[1]} listening" + ) assert ps.servers addr = ps.servers[mode].listen_addrs[0] r, w = await udp.open_connection(*addr) @@ -392,9 +407,7 @@ class H3EchoServer(QuicConnectionProtocol): response.append((b":status", b"200")) response.append((b"x-response", headers[b"x-request"])) self.http.send_headers( - stream_id=event.stream_id, - headers=response, - end_stream=event.stream_ended + stream_id=event.stream_id, headers=response, end_stream=event.stream_ended ) self.transmit() @@ -441,7 +454,9 @@ class QuicDatagramEchoServer(QuicConnectionProtocol): @asynccontextmanager -async def quic_server(create_protocol, alpn: list[str]) -> AsyncGenerator[Address, None]: +async def quic_server( + create_protocol, alpn: list[str] +) -> AsyncGenerator[Address, None]: configuration = QuicConfiguration( is_client=False, alpn_protocols=alpn, @@ -475,9 +490,11 @@ class QuicClient(QuicConnectionProtocol): def quic_event_received(self, event: quic_events.QuicEvent) -> None: if not self._waiter.done(): if isinstance(event, quic_events.ConnectionTerminated): - self._waiter.set_exception(QuicConnectionError( - event.error_code, event.frame_type, event.reason_phrase - )) + self._waiter.set_exception( + QuicConnectionError( + event.error_code, event.frame_type, event.reason_phrase + ) + ) elif isinstance(event, quic_events.HandshakeCompleted): self._waiter.set_result(None) @@ -501,9 +518,11 @@ class QuicDatagramClient(QuicClient): if isinstance(event, quic_events.DatagramFrameReceived): self._datagram.set_result(event.data) elif isinstance(event, quic_events.ConnectionTerminated): - self._datagram.set_exception(QuicConnectionError( - event.error_code, event.frame_type, event.reason_phrase - )) + self._datagram.set_exception( + QuicConnectionError( + event.error_code, event.frame_type, event.reason_phrase + ) + ) def send_datagram(self, data: bytes) -> None: self._quic.send_datagram_frame(data) @@ -532,7 +551,6 @@ class H3Response: class H3Client(QuicClient): - def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._responses: dict[int, H3Response] = dict() @@ -765,7 +783,9 @@ async def test_reverse_http3_and_quic_stream( tctx.configure(ps, mode=[mode]) assert await ps.setup_servers() ps.running() - await caplog_async.await_log(f"reverse proxy to {scheme}://{server_addr[0]}:{server_addr[1]} listening") + await caplog_async.await_log( + f"reverse proxy to {scheme}://{server_addr[0]}:{server_addr[1]} listening" + ) assert ps.servers addr = ps.servers[mode].listen_addrs[0] async with quic_connect(H3Client, alpn=["h3"], address=addr) as client: @@ -797,10 +817,14 @@ async def test_reverse_quic_datagram(caplog_async, connection_strategy: str) -> tctx.configure(ps, mode=[mode]) assert await ps.setup_servers() ps.running() - await caplog_async.await_log(f"reverse proxy to quic://{server_addr[0]}:{server_addr[1]} listening") + await caplog_async.await_log( + f"reverse proxy to quic://{server_addr[0]}:{server_addr[1]} listening" + ) assert ps.servers addr = ps.servers[mode].listen_addrs[0] - async with quic_connect(QuicDatagramClient, alpn=["dgram"], address=addr) as client: + async with quic_connect( + QuicDatagramClient, alpn=["dgram"], address=addr + ) as client: client.send_datagram(b"echo") assert await client.recv_datagram() == b"echo" diff --git a/test/mitmproxy/addons/test_readfile.py b/test/mitmproxy/addons/test_readfile.py index 5c01e95e8..78513d586 100644 --- a/test/mitmproxy/addons/test_readfile.py +++ b/test/mitmproxy/addons/test_readfile.py @@ -1,8 +1,8 @@ import asyncio import io +from unittest import mock import pytest -from unittest import mock import mitmproxy.io from mitmproxy import exceptions diff --git a/test/mitmproxy/addons/test_save.py b/test/mitmproxy/addons/test_save.py index 8361409ac..b14f5ccac 100644 --- a/test/mitmproxy/addons/test_save.py +++ b/test/mitmproxy/addons/test_save.py @@ -196,4 +196,4 @@ def test_disk_full(tmp_path, monkeypatch, capsys): with pytest.raises(SystemExit): sa.response(f) - assert "Error while writing" in capsys.readouterr().err \ No newline at end of file + assert "Error while writing" in capsys.readouterr().err diff --git a/test/mitmproxy/addons/test_script.py b/test/mitmproxy/addons/test_script.py index 0971fa0e9..678550a99 100644 --- a/test/mitmproxy/addons/test_script.py +++ b/test/mitmproxy/addons/test_script.py @@ -129,7 +129,10 @@ class TestScript: tdata.path("mitmproxy/data/addonscripts/import_error.py"), False, ) - assert "Note that mitmproxy's binaries include their own Python environment" in caplog.text + assert ( + "Note that mitmproxy's binaries include their own Python environment" + in caplog.text + ) async def test_optionexceptions(self, tdata, caplog_async): with taddons.context() as tctx: @@ -183,7 +186,9 @@ class TestScriptLoader: with taddons.context(sc): sc.script_run([tflow.tflow(resp=True)], rp) await caplog_async.await_log("recorder response") - debug = [i.msg for i in caplog_async.caplog.records if i.levelname == "DEBUG"] + debug = [ + i.msg for i in caplog_async.caplog.records if i.levelname == "DEBUG" + ] assert debug == [ "recorder configure", "recorder running", @@ -267,7 +272,9 @@ class TestScriptLoader: ], ) await caplog_async.await_log("configure") - debug = [i.msg for i in caplog_async.caplog.records if i.levelname == "DEBUG"] + debug = [ + i.msg for i in caplog_async.caplog.records if i.levelname == "DEBUG" + ] assert debug == [ "a load", "a configure", @@ -291,7 +298,9 @@ class TestScriptLoader: ) await caplog_async.await_log("b configure") - debug = [i.msg for i in caplog_async.caplog.records if i.levelname == "DEBUG"] + debug = [ + i.msg for i in caplog_async.caplog.records if i.levelname == "DEBUG" + ] assert debug == [ "c configure", "a configure", @@ -307,7 +316,9 @@ class TestScriptLoader: ], ) await caplog_async.await_log("e configure") - debug = [i.msg for i in caplog_async.caplog.records if i.levelname == "DEBUG"] + debug = [ + i.msg for i in caplog_async.caplog.records if i.levelname == "DEBUG" + ] assert debug == [ "c done", "b done", diff --git a/test/mitmproxy/addons/test_stickyauth.py b/test/mitmproxy/addons/test_stickyauth.py index 7b422fdd1..a684b8162 100644 --- a/test/mitmproxy/addons/test_stickyauth.py +++ b/test/mitmproxy/addons/test_stickyauth.py @@ -1,10 +1,9 @@ import pytest -from mitmproxy.test import tflow -from mitmproxy.test import taddons - -from mitmproxy.addons import stickyauth from mitmproxy import exceptions +from mitmproxy.addons import stickyauth +from mitmproxy.test import taddons +from mitmproxy.test import tflow def test_configure(): diff --git a/test/mitmproxy/addons/test_stickycookie.py b/test/mitmproxy/addons/test_stickycookie.py index d3edbbdb7..906087c0d 100644 --- a/test/mitmproxy/addons/test_stickycookie.py +++ b/test/mitmproxy/addons/test_stickycookie.py @@ -1,9 +1,8 @@ import pytest -from mitmproxy.test import tflow -from mitmproxy.test import taddons - from mitmproxy.addons import stickycookie +from mitmproxy.test import taddons +from mitmproxy.test import tflow from mitmproxy.test import tutils as ntutils diff --git a/test/mitmproxy/addons/test_termlog.py b/test/mitmproxy/addons/test_termlog.py index 62573e218..0ebf3601a 100644 --- a/test/mitmproxy/addons/test_termlog.py +++ b/test/mitmproxy/addons/test_termlog.py @@ -13,10 +13,7 @@ from mitmproxy.utils import vt_codes @pytest.fixture(autouse=True) def ensure_cleanup(): yield - assert not any( - isinstance(x, termlog.TermLogHandler) - for x in logging.root.handlers - ) + assert not any(isinstance(x, termlog.TermLogHandler) for x in logging.root.handlers) async def test_delayed_teardown(): diff --git a/test/mitmproxy/addons/test_tlsconfig.py b/test/mitmproxy/addons/test_tlsconfig.py index 9bd6142b9..1da747d04 100644 --- a/test/mitmproxy/addons/test_tlsconfig.py +++ b/test/mitmproxy/addons/test_tlsconfig.py @@ -4,15 +4,21 @@ from pathlib import Path from typing import Union import pytest - from cryptography import x509 from OpenSSL import SSL -from mitmproxy import certs, connection, tls, options + +from mitmproxy import certs +from mitmproxy import connection +from mitmproxy import options +from mitmproxy import tls from mitmproxy.addons import tlsconfig from mitmproxy.proxy import context -from mitmproxy.proxy.layers import modes, quic, tls as proxy_tls +from mitmproxy.proxy.layers import modes +from mitmproxy.proxy.layers import quic +from mitmproxy.proxy.layers import tls as proxy_tls from mitmproxy.test import taddons -from test.mitmproxy.proxy.layers import test_quic, test_tls +from test.mitmproxy.proxy.layers import test_quic +from test.mitmproxy.proxy.layers import test_tls def test_alpn_select_callback(): @@ -63,7 +69,11 @@ here = Path(__file__).parent def _ctx(opts: options.Options) -> context.Context: return context.Context( - connection.Client(peername=("client", 1234), sockname=("127.0.0.1", 8080), timestamp_start=1605699329), + connection.Client( + peername=("client", 1234), + sockname=("127.0.0.1", 8080), + timestamp_start=1605699329, + ), opts, ) @@ -172,10 +182,7 @@ class TestTlsConfig: tssl_server.write(tssl_client.read()) tssl_client.write(tssl_server.read()) tssl_server.write(tssl_client.read()) - return ( - tssl_client.handshake_completed() - and tssl_server.handshake_completed() - ) + return tssl_client.handshake_completed() and tssl_server.handshake_completed() def test_tls_start_client(self, tdata): ta = tlsconfig.TlsConfig() @@ -229,8 +236,12 @@ class TestTlsConfig: tssl_client = test_quic.SSLTest(alpn=["h3"]) assert self.quic_do_handshake(tssl_client, tssl_server) - san = tssl_client.quic.tls._peer_certificate.extensions.get_extension_for_class(x509.SubjectAlternativeName) - assert san.value.get_values_for_type(x509.DNSName) == ["example.mitmproxy.org"] + san = tssl_client.quic.tls._peer_certificate.extensions.get_extension_for_class( + x509.SubjectAlternativeName + ) + assert san.value.get_values_for_type(x509.DNSName) == [ + "example.mitmproxy.org" + ] def test_tls_start_server_cannot_verify(self): ta = tlsconfig.TlsConfig() @@ -240,7 +251,9 @@ class TestTlsConfig: ctx.server.sni = "" # explicitly opt out of using the address. tls_start = tls.TlsData(ctx.server, context=ctx) - with pytest.raises(ValueError, match="Cannot validate certificate hostname without SNI"): + with pytest.raises( + ValueError, match="Cannot validate certificate hostname without SNI" + ): ta.tls_start_server(tls_start) def test_tls_start_server_verify_failed(self): @@ -305,7 +318,9 @@ class TestTlsConfig: ta.quic_start_server(tls_start) assert settings_client is tls_start.settings - tssl_server = test_quic.SSLTest(server_side=True, sni=hostname.encode(), alpn=["h3"]) + tssl_server = test_quic.SSLTest( + server_side=True, sni=hostname.encode(), alpn=["h3"] + ) assert self.quic_do_handshake(tssl_client, tssl_server) def test_tls_start_server_insecure(self): diff --git a/test/mitmproxy/addons/test_upstream_auth.py b/test/mitmproxy/addons/test_upstream_auth.py index 1eb8eb013..883dabc2f 100644 --- a/test/mitmproxy/addons/test_upstream_auth.py +++ b/test/mitmproxy/addons/test_upstream_auth.py @@ -1,11 +1,12 @@ import base64 + import pytest from mitmproxy import exceptions +from mitmproxy.addons import upstream_auth from mitmproxy.proxy.mode_specs import ProxyMode from mitmproxy.test import taddons from mitmproxy.test import tflow -from mitmproxy.addons import upstream_auth def test_configure(): diff --git a/test/mitmproxy/addons/test_view.py b/test/mitmproxy/addons/test_view.py index 7c1b041cf..199fc00c4 100644 --- a/test/mitmproxy/addons/test_view.py +++ b/test/mitmproxy/addons/test_view.py @@ -1,14 +1,14 @@ import pytest -from mitmproxy.test import tflow - -from mitmproxy.addons import view -from mitmproxy import flowfilter from mitmproxy import exceptions +from mitmproxy import flowfilter from mitmproxy import io +from mitmproxy.addons import view from mitmproxy.test import taddons +from mitmproxy.test import tflow from mitmproxy.tools.console import consoleaddons -from mitmproxy.tools.console.common import render_marker, SYMBOL_MARK +from mitmproxy.tools.console.common import render_marker +from mitmproxy.tools.console.common import SYMBOL_MARK def tft(*, method="get", start=0): diff --git a/test/mitmproxy/contentviews/image/test_view.py b/test/mitmproxy/contentviews/image/test_view.py index 67c4b81b4..61c0c6379 100644 --- a/test/mitmproxy/contentviews/image/test_view.py +++ b/test/mitmproxy/contentviews/image/test_view.py @@ -1,5 +1,5 @@ -from mitmproxy.contentviews import image from .. import full_eval +from mitmproxy.contentviews import image def test_view_image(tdata): diff --git a/test/mitmproxy/contentviews/test_auto.py b/test/mitmproxy/contentviews/test_auto.py index 459d839f0..5dfbe2aaf 100644 --- a/test/mitmproxy/contentviews/test_auto.py +++ b/test/mitmproxy/contentviews/test_auto.py @@ -1,6 +1,6 @@ +from . import full_eval from mitmproxy.contentviews import auto from mitmproxy.test import tflow -from . import full_eval def test_view_auto(): diff --git a/test/mitmproxy/contentviews/test_base.py b/test/mitmproxy/contentviews/test_base.py index cd879bfda..efa971534 100644 --- a/test/mitmproxy/contentviews/test_base.py +++ b/test/mitmproxy/contentviews/test_base.py @@ -1,4 +1,5 @@ import pytest + from mitmproxy.contentviews import base diff --git a/test/mitmproxy/contentviews/test_css.py b/test/mitmproxy/contentviews/test_css.py index 7474a6b36..a2192d12f 100644 --- a/test/mitmproxy/contentviews/test_css.py +++ b/test/mitmproxy/contentviews/test_css.py @@ -1,7 +1,7 @@ import pytest -from mitmproxy.contentviews import css from . import full_eval +from mitmproxy.contentviews import css @pytest.mark.parametrize( diff --git a/test/mitmproxy/contentviews/test_graphql.py b/test/mitmproxy/contentviews/test_graphql.py index a38eedea0..89beda814 100644 --- a/test/mitmproxy/contentviews/test_graphql.py +++ b/test/mitmproxy/contentviews/test_graphql.py @@ -1,8 +1,8 @@ from hypothesis import given from hypothesis.strategies import binary -from mitmproxy.contentviews import graphql from . import full_eval +from mitmproxy.contentviews import graphql def test_render_priority(): diff --git a/test/mitmproxy/contentviews/test_grpc.py b/test/mitmproxy/contentviews/test_grpc.py index 8d296fe65..6a8803526 100644 --- a/test/mitmproxy/contentviews/test_grpc.py +++ b/test/mitmproxy/contentviews/test_grpc.py @@ -1,16 +1,16 @@ +import struct + import pytest -from mitmproxy.contentviews import grpc -from mitmproxy.contentviews.grpc import ( - ViewGrpcProtobuf, - ViewConfig, - ProtoParser, - parse_grpc_messages, -) -from mitmproxy.net.encoding import encode -from mitmproxy.test import tflow, tutils -import struct from . import full_eval +from mitmproxy.contentviews import grpc +from mitmproxy.contentviews.grpc import parse_grpc_messages +from mitmproxy.contentviews.grpc import ProtoParser +from mitmproxy.contentviews.grpc import ViewConfig +from mitmproxy.contentviews.grpc import ViewGrpcProtobuf +from mitmproxy.net.encoding import encode +from mitmproxy.test import tflow +from mitmproxy.test import tutils datadir = "mitmproxy/contentviews/test_grpc_data/" diff --git a/test/mitmproxy/contentviews/test_hex.py b/test/mitmproxy/contentviews/test_hex.py index 90db4bd7c..8eadf8436 100644 --- a/test/mitmproxy/contentviews/test_hex.py +++ b/test/mitmproxy/contentviews/test_hex.py @@ -1,5 +1,5 @@ -from mitmproxy.contentviews import hex from . import full_eval +from mitmproxy.contentviews import hex def test_view_hex(): diff --git a/test/mitmproxy/contentviews/test_http3.py b/test/mitmproxy/contentviews/test_http3.py index 157ee6914..d1b9fc641 100644 --- a/test/mitmproxy/contentviews/test_http3.py +++ b/test/mitmproxy/contentviews/test_http3.py @@ -1,58 +1,59 @@ import pytest +from . import full_eval +from mitmproxy.contentviews import http3 from mitmproxy.tcp import TCPMessage from mitmproxy.test import tflow -from mitmproxy.contentviews import http3 - -from . import full_eval if http3 is None: pytest.skip("HTTP/3 not available.", allow_module_level=True) -@pytest.mark.parametrize("data", [ - # HEADERS - b"\x01\x1d\x00\x00\xd1\xc1\xd7P\x8a\x08\x9d\\\x0b\x81p\xdcx\x0f\x03_P\x88%\xb6P\xc3\xab\xbc\xda\xe0\xdd", - # broken HEADERS - b"\x01\x1d\x00\x00\xd1\xc1\xd7P\x8a\x08\x9d\\\x0b\x81p\xdcx\x0f\x03_P\x88%\xb6P\xc3\xab\xff\xff\xff\xff", - # headers + data - ( - b'\x01@I\x00\x00\xdb_\'\x93I|\xa5\x89\xd3M\x1fj\x12q\xd8\x82\xa6\x0bP\xb0\xd0C\x1b_M\x90\xd0bXt\x1eT\xad\x8f~\xfdp' - b'\xeb\xc8\xc0\x97\x07V\x96\xd0z\xbe\x94\x08\x94\xdcZ\xd4\x10\x04%\x02\xe5\xc6\xde\xb8\x17\x14\xc5\xa3\x7fT\x03315' - b'\x00A;\r\n<' - b'TITLE>Not Found\r\n\r\n

Not Found

\r\n

HTTP Error 404. The requested resource is not found.

\r\n\r\n' - ), - b"", -]) +@pytest.mark.parametrize( + "data", + [ + # HEADERS + b"\x01\x1d\x00\x00\xd1\xc1\xd7P\x8a\x08\x9d\\\x0b\x81p\xdcx\x0f\x03_P\x88%\xb6P\xc3\xab\xbc\xda\xe0\xdd", + # broken HEADERS + b"\x01\x1d\x00\x00\xd1\xc1\xd7P\x8a\x08\x9d\\\x0b\x81p\xdcx\x0f\x03_P\x88%\xb6P\xc3\xab\xff\xff\xff\xff", + # headers + data + ( + b"\x01@I\x00\x00\xdb_'\x93I|\xa5\x89\xd3M\x1fj\x12q\xd8\x82\xa6\x0bP\xb0\xd0C\x1b_M\x90\xd0bXt\x1eT\xad\x8f~\xfdp" + b"\xeb\xc8\xc0\x97\x07V\x96\xd0z\xbe\x94\x08\x94\xdcZ\xd4\x10\x04%\x02\xe5\xc6\xde\xb8\x17\x14\xc5\xa3\x7fT\x03315" + b'\x00A;\r\n<' + b'TITLE>Not Found\r\n\r\n

Not Found

\r\n

HTTP Error 404. The requested resource is not found.

\r\n\r\n" + ), + b"", + ], +) def test_view_http3(data): v = full_eval(http3.ViewHttp3()) - t = tflow.ttcpflow(messages=[ - TCPMessage(from_client=len(data) > 16, content=data) - ]) + t = tflow.ttcpflow(messages=[TCPMessage(from_client=len(data) > 16, content=data)]) t.metadata["quic_is_unidirectional"] = False - assert (v(b"", flow=t, tcp_message=t.messages[0])) + assert v(b"", flow=t, tcp_message=t.messages[0]) -@pytest.mark.parametrize("data", [ - # SETTINGS - b"\x00\x04\r\x06\xff\xff\xff\xff\xff\xff\xff\xff\x01\x00\x07\x00", - # unknown setting - b"\x00\x04\r\x3f\xff\xff\xff\xff\xff\xff\xff\xff\x01\x00\x07\x00", - # out of bounds - b"\x00\x04\r\x06\xff\xff\xff\xff\xff\xff\xff\xff\x01\x00\x42\x00", - # incomplete - b"\x00\x04\r\x06\xff\xff\xff", - # QPACK encoder stream - b"\x02", -]) +@pytest.mark.parametrize( + "data", + [ + # SETTINGS + b"\x00\x04\r\x06\xff\xff\xff\xff\xff\xff\xff\xff\x01\x00\x07\x00", + # unknown setting + b"\x00\x04\r\x3f\xff\xff\xff\xff\xff\xff\xff\xff\x01\x00\x07\x00", + # out of bounds + b"\x00\x04\r\x06\xff\xff\xff\xff\xff\xff\xff\xff\x01\x00\x42\x00", + # incomplete + b"\x00\x04\r\x06\xff\xff\xff", + # QPACK encoder stream + b"\x02", + ], +) def test_view_http3_unidirectional(data): v = full_eval(http3.ViewHttp3()) - t = tflow.ttcpflow(messages=[ - TCPMessage(from_client=len(data) > 16, content=data) - ]) + t = tflow.ttcpflow(messages=[TCPMessage(from_client=len(data) > 16, content=data)]) t.metadata["quic_is_unidirectional"] = True - assert (v(b"", flow=t, tcp_message=t.messages[0])) + assert v(b"", flow=t, tcp_message=t.messages[0]) def test_render_priority(): diff --git a/test/mitmproxy/contentviews/test_javascript.py b/test/mitmproxy/contentviews/test_javascript.py index c050adee4..64647446d 100644 --- a/test/mitmproxy/contentviews/test_javascript.py +++ b/test/mitmproxy/contentviews/test_javascript.py @@ -1,7 +1,7 @@ import pytest -from mitmproxy.contentviews import javascript from . import full_eval +from mitmproxy.contentviews import javascript def test_view_javascript(): diff --git a/test/mitmproxy/contentviews/test_json.py b/test/mitmproxy/contentviews/test_json.py index 5b3883060..9711a6465 100644 --- a/test/mitmproxy/contentviews/test_json.py +++ b/test/mitmproxy/contentviews/test_json.py @@ -1,8 +1,8 @@ from hypothesis import given from hypothesis.strategies import binary -from mitmproxy.contentviews import json from . import full_eval +from mitmproxy.contentviews import json def test_parse_json(): @@ -18,31 +18,66 @@ def test_parse_json(): def test_format_json(): assert list(json.format_json({"data": ["str", 42, True, False, None, {}, []]})) assert list(json.format_json({"string": "test"})) == [ - [('text', '{'), ('text', '')], - [('text', ' '), ('Token_Name_Tag', '"string"'), ('text', ': '), ('Token_Literal_String', '"test"'), ('text', '')], - [('text', ''), ('text', '}')]] + [("text", "{"), ("text", "")], + [ + ("text", " "), + ("Token_Name_Tag", '"string"'), + ("text", ": "), + ("Token_Literal_String", '"test"'), + ("text", ""), + ], + [("text", ""), ("text", "}")], + ] assert list(json.format_json({"num": 4})) == [ - [('text', '{'), ('text', '')], - [('text', ' '), ('Token_Name_Tag', '"num"'), ('text', ': '), ('Token_Literal_Number', '4'), ('text', '')], - [('text', ''), ('text', '}')]] + [("text", "{"), ("text", "")], + [ + ("text", " "), + ("Token_Name_Tag", '"num"'), + ("text", ": "), + ("Token_Literal_Number", "4"), + ("text", ""), + ], + [("text", ""), ("text", "}")], + ] assert list(json.format_json({"bool": True})) == [ - [('text', '{'), ('text', '')], - [('text', ' '), ('Token_Name_Tag', '"bool"'), ('text', ': '), ('Token_Keyword_Constant', 'true'), ('text', '')], - [('text', ''), ('text', '}')]] + [("text", "{"), ("text", "")], + [ + ("text", " "), + ("Token_Name_Tag", '"bool"'), + ("text", ": "), + ("Token_Keyword_Constant", "true"), + ("text", ""), + ], + [("text", ""), ("text", "}")], + ] assert list(json.format_json({"object": {"int": 1}})) == [ - [('text', '{'), ('text', '')], - [('text', ' '), ('Token_Name_Tag', '"object"'), ('text', ': '), ('text', '{'), ('text', '')], - [('text', ' '), ('Token_Name_Tag', '"int"'), ('text', ': '), ('Token_Literal_Number', '1'), ('text', '')], - [('text', ' '), ('text', '}'), ('text', '')], - [('text', ''), ('text', '}')]] + [("text", "{"), ("text", "")], + [ + ("text", " "), + ("Token_Name_Tag", '"object"'), + ("text", ": "), + ("text", "{"), + ("text", ""), + ], + [ + ("text", " "), + ("Token_Name_Tag", '"int"'), + ("text", ": "), + ("Token_Literal_Number", "1"), + ("text", ""), + ], + [("text", " "), ("text", "}"), ("text", "")], + [("text", ""), ("text", "}")], + ] assert list(json.format_json({"list": ["string", 1, True]})) == [ - [('text', '{'), ('text', '')], - [('text', ' '), ('Token_Name_Tag', '"list"'), ('text', ': '), ('text', '[')], - [('Token_Literal_String', ' "string"'), ('text', ',')], - [('Token_Literal_Number', ' 1'), ('text', ',')], - [('Token_Keyword_Constant', ' true'), ('text', '')], - [('text', ' '), ('text', ']'), ('text', '')], - [('text', ''), ('text', '}')]] + [("text", "{"), ("text", "")], + [("text", " "), ("Token_Name_Tag", '"list"'), ("text", ": "), ("text", "[")], + [("Token_Literal_String", ' "string"'), ("text", ",")], + [("Token_Literal_Number", " 1"), ("text", ",")], + [("Token_Keyword_Constant", " true"), ("text", "")], + [("text", " "), ("text", "]"), ("text", "")], + [("text", ""), ("text", "}")], + ] def test_view_json(): diff --git a/test/mitmproxy/contentviews/test_mqtt.py b/test/mitmproxy/contentviews/test_mqtt.py index 7acc33541..87cc09d40 100644 --- a/test/mitmproxy/contentviews/test_mqtt.py +++ b/test/mitmproxy/contentviews/test_mqtt.py @@ -1,7 +1,7 @@ import pytest -from mitmproxy.contentviews import mqtt from . import full_eval +from mitmproxy.contentviews import mqtt @pytest.mark.parametrize( @@ -9,8 +9,14 @@ from . import full_eval [ pytest.param(b"\xC0\x00", "[PINGREQ]", id="PINGREQ"), pytest.param(b"\xD0\x00", "[PINGRESP]", id="PINGRESP"), - pytest.param(b"\x90\x00", "Packet type SUBACK is not supported yet!", id="SUBACK"), - pytest.param(b"\xA0\x00", "Packet type UNSUBSCRIBE is not supported yet!", id="UNSUBSCRIBE"), + pytest.param( + b"\x90\x00", "Packet type SUBACK is not supported yet!", id="SUBACK" + ), + pytest.param( + b"\xA0\x00", + "Packet type UNSUBSCRIBE is not supported yet!", + id="UNSUBSCRIBE", + ), pytest.param( b"\x82\x31\x00\x03\x00\x2cxxxx/yy/zzzzzz/56:6F:5E:6A:01:05/messages/in\x01", "[SUBSCRIBE] sent topic filters: 'xxxx/yy/zzzzzz/56:6F:5E:6A:01:05/messages/in'", @@ -52,10 +58,7 @@ def test_view_mqtt(data, expected_text): assert output == [[("text", expected_text)]] -@pytest.mark.parametrize( - "data", - [b"\xC0\xFF\xFF\xFF\xFF"] -) +@pytest.mark.parametrize("data", [b"\xC0\xFF\xFF\xFF\xFF"]) def test_mqtt_malformed(data): v = full_eval(mqtt.ViewMQTT()) with pytest.raises(Exception): diff --git a/test/mitmproxy/contentviews/test_msgpack.py b/test/mitmproxy/contentviews/test_msgpack.py index eeba8b2d1..65c8487f0 100644 --- a/test/mitmproxy/contentviews/test_msgpack.py +++ b/test/mitmproxy/contentviews/test_msgpack.py @@ -1,10 +1,9 @@ from hypothesis import given from hypothesis.strategies import binary - from msgpack import packb -from mitmproxy.contentviews import msgpack from . import full_eval +from mitmproxy.contentviews import msgpack def msgpack_encode(content): @@ -18,32 +17,86 @@ def test_parse_msgpack(): def test_format_msgpack(): - assert list(msgpack.format_msgpack({"string": "test", "int": 1, "float": 1.44, "bool": True})) == [ - [('text', '{')], - [('text', ''), ('text', ' '), ('Token_Name_Tag', '"string"'), ('text', ': '), ('Token_Literal_String', '"test"'), ('text', ',')], - [('text', ''), ('text', ' '), ('Token_Name_Tag', '"int"'), ('text', ': '), ('Token_Literal_Number', '1'), ('text', ',')], - [('text', ''), ('text', ' '), ('Token_Name_Tag', '"float"'), ('text', ': '), ('Token_Literal_Number', '1.44'), ('text', ',')], - [('text', ''), ('text', ' '), ('Token_Name_Tag', '"bool"'), ('text', ': '), ('Token_Keyword_Constant', 'True')], - [('text', ''), ('text', '}')] + assert list( + msgpack.format_msgpack( + {"string": "test", "int": 1, "float": 1.44, "bool": True} + ) + ) == [ + [("text", "{")], + [ + ("text", ""), + ("text", " "), + ("Token_Name_Tag", '"string"'), + ("text", ": "), + ("Token_Literal_String", '"test"'), + ("text", ","), + ], + [ + ("text", ""), + ("text", " "), + ("Token_Name_Tag", '"int"'), + ("text", ": "), + ("Token_Literal_Number", "1"), + ("text", ","), + ], + [ + ("text", ""), + ("text", " "), + ("Token_Name_Tag", '"float"'), + ("text", ": "), + ("Token_Literal_Number", "1.44"), + ("text", ","), + ], + [ + ("text", ""), + ("text", " "), + ("Token_Name_Tag", '"bool"'), + ("text", ": "), + ("Token_Keyword_Constant", "True"), + ], + [("text", ""), ("text", "}")], ] assert list(msgpack.format_msgpack({"object": {"key": "value"}, "list": [1]})) == [ - [('text', '{')], - [('text', ''), ('text', ' '), ('Token_Name_Tag', '"object"'), ('text', ': '), ('text', '{')], - [('text', ' '), ('text', ' '), ('Token_Name_Tag', '"key"'), ('text', ': '), ('Token_Literal_String', '"value"')], - [('text', ' '), ('text', '}'), ('text', ',')], - [('text', ''), ('text', ' '), ('Token_Name_Tag', '"list"'), ('text', ': '), ('text', '[')], - [('text', ' '), ('text', ' '), ('Token_Literal_Number', '1')], - [('text', ' '), ('text', ']')], - [('text', ''), ('text', '}')]] + [("text", "{")], + [ + ("text", ""), + ("text", " "), + ("Token_Name_Tag", '"object"'), + ("text", ": "), + ("text", "{"), + ], + [ + ("text", " "), + ("text", " "), + ("Token_Name_Tag", '"key"'), + ("text", ": "), + ("Token_Literal_String", '"value"'), + ], + [("text", " "), ("text", "}"), ("text", ",")], + [ + ("text", ""), + ("text", " "), + ("Token_Name_Tag", '"list"'), + ("text", ": "), + ("text", "["), + ], + [("text", " "), ("text", " "), ("Token_Literal_Number", "1")], + [("text", " "), ("text", "]")], + [("text", ""), ("text", "}")], + ] - assert list(msgpack.format_msgpack('string')) == [[('Token_Literal_String', '"string"')]] + assert list(msgpack.format_msgpack("string")) == [ + [("Token_Literal_String", '"string"')] + ] - assert list(msgpack.format_msgpack(1.2)) == [[('Token_Literal_Number', '1.2')]] + assert list(msgpack.format_msgpack(1.2)) == [[("Token_Literal_Number", "1.2")]] - assert list(msgpack.format_msgpack(True)) == [[('Token_Keyword_Constant', 'True')]] + assert list(msgpack.format_msgpack(True)) == [[("Token_Keyword_Constant", "True")]] - assert list(msgpack.format_msgpack(b'\x01\x02\x03')) == [[('text', "b'\\x01\\x02\\x03'")]] + assert list(msgpack.format_msgpack(b"\x01\x02\x03")) == [ + [("text", "b'\\x01\\x02\\x03'")] + ] def test_view_msgpack(): diff --git a/test/mitmproxy/contentviews/test_multipart.py b/test/mitmproxy/contentviews/test_multipart.py index da1f723e0..a748231d6 100644 --- a/test/mitmproxy/contentviews/test_multipart.py +++ b/test/mitmproxy/contentviews/test_multipart.py @@ -1,5 +1,5 @@ -from mitmproxy.contentviews import multipart from . import full_eval +from mitmproxy.contentviews import multipart def test_view_multipart(): diff --git a/test/mitmproxy/contentviews/test_protobuf.py b/test/mitmproxy/contentviews/test_protobuf.py index 5f8d84d2e..99d6768ed 100644 --- a/test/mitmproxy/contentviews/test_protobuf.py +++ b/test/mitmproxy/contentviews/test_protobuf.py @@ -1,7 +1,7 @@ import pytest -from mitmproxy.contentviews import protobuf from . import full_eval +from mitmproxy.contentviews import protobuf datadir = "mitmproxy/contentviews/test_protobuf_data/" diff --git a/test/mitmproxy/contentviews/test_query.py b/test/mitmproxy/contentviews/test_query.py index af47a02f8..b4b1408ef 100644 --- a/test/mitmproxy/contentviews/test_query.py +++ b/test/mitmproxy/contentviews/test_query.py @@ -1,6 +1,6 @@ +from . import full_eval from mitmproxy.contentviews import query from mitmproxy.test import tutils -from . import full_eval def test_view_query(): diff --git a/test/mitmproxy/contentviews/test_raw.py b/test/mitmproxy/contentviews/test_raw.py index d9fa44f89..0cffcf869 100644 --- a/test/mitmproxy/contentviews/test_raw.py +++ b/test/mitmproxy/contentviews/test_raw.py @@ -1,5 +1,5 @@ -from mitmproxy.contentviews import raw from . import full_eval +from mitmproxy.contentviews import raw def test_view_raw(): diff --git a/test/mitmproxy/contentviews/test_urlencoded.py b/test/mitmproxy/contentviews/test_urlencoded.py index 84c33dfce..e6005c0c8 100644 --- a/test/mitmproxy/contentviews/test_urlencoded.py +++ b/test/mitmproxy/contentviews/test_urlencoded.py @@ -1,6 +1,6 @@ +from . import full_eval from mitmproxy.contentviews import urlencoded from mitmproxy.net.http import url -from . import full_eval def test_view_urlencoded(): diff --git a/test/mitmproxy/contentviews/test_wbxml.py b/test/mitmproxy/contentviews/test_wbxml.py index e37f0da21..11f2886bf 100644 --- a/test/mitmproxy/contentviews/test_wbxml.py +++ b/test/mitmproxy/contentviews/test_wbxml.py @@ -1,5 +1,5 @@ -from mitmproxy.contentviews import wbxml from . import full_eval +from mitmproxy.contentviews import wbxml datadir = "mitmproxy/contentviews/test_wbxml_data/" diff --git a/test/mitmproxy/contentviews/test_xml_html.py b/test/mitmproxy/contentviews/test_xml_html.py index 4bb007972..de2b8d59f 100644 --- a/test/mitmproxy/contentviews/test_xml_html.py +++ b/test/mitmproxy/contentviews/test_xml_html.py @@ -1,7 +1,7 @@ import pytest -from mitmproxy.contentviews import xml_html from . import full_eval +from mitmproxy.contentviews import xml_html datadir = "mitmproxy/contentviews/test_xml_html_data/" diff --git a/test/mitmproxy/coretypes/test_bidi.py b/test/mitmproxy/coretypes/test_bidi.py index 3bdad3c2c..b4cff33cb 100644 --- a/test/mitmproxy/coretypes/test_bidi.py +++ b/test/mitmproxy/coretypes/test_bidi.py @@ -1,4 +1,5 @@ import pytest + from mitmproxy.coretypes import bidi diff --git a/test/mitmproxy/coretypes/test_serializable.py b/test/mitmproxy/coretypes/test_serializable.py index 70980a625..06e10fc54 100644 --- a/test/mitmproxy/coretypes/test_serializable.py +++ b/test/mitmproxy/coretypes/test_serializable.py @@ -5,7 +5,8 @@ import dataclasses import enum from collections.abc import Mapping from dataclasses import dataclass -from typing import Literal, Optional +from typing import Literal +from typing import Optional import pytest @@ -108,16 +109,31 @@ class FrozenWrapper(SerializableDataclass): class TestSerializableDataclass: - @pytest.mark.parametrize("cls, state", [ - (Simple, {"x": 42, "y": 'foo'}), - (Simple, {"x": 42, "y": None}), - (SerializableChild, {"foo": {"x": 42, "y": "foo"}, "maybe_foo": None}), - (SerializableChild, {"foo": {"x": 42, "y": "foo"}, "maybe_foo": {"x": 42, "y": "foo"}}), - (Inheritance, {"x": 42, "y": "foo", "z": True}), - (BuiltinChildren, {"a": [1, 2, 3], "b": {"foo": 42}, "c": (1, 2), "d": [{"x": 42, "y": "foo"}], "e": 1}), - (BuiltinChildren, {"a": None, "b": None, "c": None, "d": [], "e": None}), - (TLiteral, {"l": "foo"}), - ]) + @pytest.mark.parametrize( + "cls, state", + [ + (Simple, {"x": 42, "y": "foo"}), + (Simple, {"x": 42, "y": None}), + (SerializableChild, {"foo": {"x": 42, "y": "foo"}, "maybe_foo": None}), + ( + SerializableChild, + {"foo": {"x": 42, "y": "foo"}, "maybe_foo": {"x": 42, "y": "foo"}}, + ), + (Inheritance, {"x": 42, "y": "foo", "z": True}), + ( + BuiltinChildren, + { + "a": [1, 2, 3], + "b": {"foo": 42}, + "c": (1, 2), + "d": [{"x": 42, "y": "foo"}], + "e": 1, + }, + ), + (BuiltinChildren, {"a": None, "b": None, "c": None, "d": [], "e": None}), + (TLiteral, {"l": "foo"}), + ], + ) def test_roundtrip(self, cls, state): a = cls.from_state(copy.deepcopy(state)) assert a.get_state() == state @@ -142,7 +158,9 @@ class TestSerializableDataclass: with pytest.raises(ValueError): Simple.from_state({"x": 42, "y": 42}) with pytest.raises(ValueError): - BuiltinChildren.from_state({"a": None, "b": None, "c": ("foo",), "d": [], "e": None}) + BuiltinChildren.from_state( + {"a": None, "b": None, "c": ("foo",), "d": [], "e": None} + ) def test_invalid_key(self): with pytest.raises(ValueError): @@ -150,7 +168,15 @@ class TestSerializableDataclass: def test_invalid_type_in_list(self): with pytest.raises(ValueError, match="Invalid value for x"): - BuiltinChildren.from_state({"a": None, "b": None, "c": None, "d": [{"x": "foo", "y": "foo"}], "e": None}) + BuiltinChildren.from_state( + { + "a": None, + "b": None, + "c": None, + "d": [{"x": "foo", "y": "foo"}], + "e": None, + } + ) def test_unsupported_type(self): with pytest.raises(TypeError): @@ -162,8 +188,12 @@ class TestSerializableDataclass: TLiteral.from_state({"l": "unknown"}) def test_peername(self): - assert Addr.from_state({"peername": ("addr", 42)}).get_state() == {"peername": ("addr", 42)} - assert Addr.from_state({"peername": ("addr", 42, 0, 0)}).get_state() == {"peername": ("addr", 42, 0, 0)} + assert Addr.from_state({"peername": ("addr", 42)}).get_state() == { + "peername": ("addr", 42) + } + assert Addr.from_state({"peername": ("addr", 42, 0, 0)}).get_state() == { + "peername": ("addr", 42, 0, 0) + } def test_set_immutable(self): w = FrozenWrapper(Frozen(42)) diff --git a/test/mitmproxy/data/addonscripts/concurrent_decorator.py b/test/mitmproxy/data/addonscripts/concurrent_decorator.py index bf2628958..0af96c486 100644 --- a/test/mitmproxy/data/addonscripts/concurrent_decorator.py +++ b/test/mitmproxy/data/addonscripts/concurrent_decorator.py @@ -1,4 +1,5 @@ import time + from mitmproxy.script import concurrent diff --git a/test/mitmproxy/data/addonscripts/concurrent_decorator_class.py b/test/mitmproxy/data/addonscripts/concurrent_decorator_class.py index b4ef75292..e08ca0cb1 100644 --- a/test/mitmproxy/data/addonscripts/concurrent_decorator_class.py +++ b/test/mitmproxy/data/addonscripts/concurrent_decorator_class.py @@ -1,4 +1,5 @@ import time + from mitmproxy.script import concurrent diff --git a/test/mitmproxy/io/test_compat.py b/test/mitmproxy/io/test_compat.py index 85ba5a0ee..35b11d619 100644 --- a/test/mitmproxy/io/test_compat.py +++ b/test/mitmproxy/io/test_compat.py @@ -1,7 +1,7 @@ import pytest -from mitmproxy import io from mitmproxy import exceptions +from mitmproxy import io @pytest.mark.parametrize( diff --git a/test/mitmproxy/io/test_io.py b/test/mitmproxy/io/test_io.py index 9d7ad8080..73d067d95 100644 --- a/test/mitmproxy/io/test_io.py +++ b/test/mitmproxy/io/test_io.py @@ -1,11 +1,14 @@ import io import pytest -from hypothesis import example, given +from hypothesis import example +from hypothesis import given from hypothesis.strategies import binary -from mitmproxy import exceptions, version -from mitmproxy.io import FlowReader, tnetstring +from mitmproxy import exceptions +from mitmproxy import version +from mitmproxy.io import FlowReader +from mitmproxy.io import tnetstring class TestFlowReader: diff --git a/test/mitmproxy/io/test_tnetstring.py b/test/mitmproxy/io/test_tnetstring.py index ccda4cfb8..6ae50bc82 100644 --- a/test/mitmproxy/io/test_tnetstring.py +++ b/test/mitmproxy/io/test_tnetstring.py @@ -1,8 +1,8 @@ -import unittest -import random -import math import io +import math +import random import struct +import unittest from mitmproxy.io import tnetstring diff --git a/test/mitmproxy/net/dns/test_domain_names.py b/test/mitmproxy/net/dns/test_domain_names.py index 72e6e5391..5f1847b55 100644 --- a/test/mitmproxy/net/dns/test_domain_names.py +++ b/test/mitmproxy/net/dns/test_domain_names.py @@ -1,5 +1,6 @@ import re import struct + import pytest from mitmproxy.net.dns import domain_names @@ -15,14 +16,11 @@ def test_unpack_from_with_compression(): domain_names.unpack_from_with_compression( b"\x03www\xc0\x00", 0, domain_names.cache() ) - assert ( - domain_names.unpack_from_with_compression( - b"\xFF\xFF\xFF\x07example\x03org\x00\xFF\xFF\xFF\x03www\xc0\x03", - 19, - domain_names.cache(), - ) - == ("www.example.org", 6) - ) + assert domain_names.unpack_from_with_compression( + b"\xFF\xFF\xFF\x07example\x03org\x00\xFF\xFF\xFF\x03www\xc0\x03", + 19, + domain_names.cache(), + ) == ("www.example.org", 6) def test_unpack(): diff --git a/test/mitmproxy/net/http/http1/test_assemble.py b/test/mitmproxy/net/http/http1/test_assemble.py index 5d17e1bfb..eb246cf19 100644 --- a/test/mitmproxy/net/http/http1/test_assemble.py +++ b/test/mitmproxy/net/http/http1/test_assemble.py @@ -1,17 +1,16 @@ import pytest from mitmproxy.http import Headers -from mitmproxy.net.http.http1.assemble import ( - assemble_request, - assemble_request_head, - assemble_response, - assemble_response_head, - _assemble_request_line, - _assemble_request_headers, - _assemble_response_headers, - assemble_body, -) -from mitmproxy.test.tutils import treq, tresp +from mitmproxy.net.http.http1.assemble import _assemble_request_headers +from mitmproxy.net.http.http1.assemble import _assemble_request_line +from mitmproxy.net.http.http1.assemble import _assemble_response_headers +from mitmproxy.net.http.http1.assemble import assemble_body +from mitmproxy.net.http.http1.assemble import assemble_request +from mitmproxy.net.http.http1.assemble import assemble_request_head +from mitmproxy.net.http.http1.assemble import assemble_response +from mitmproxy.net.http.http1.assemble import assemble_response_head +from mitmproxy.test.tutils import treq +from mitmproxy.test.tutils import tresp def test_assemble_request(): diff --git a/test/mitmproxy/net/http/http1/test_read.py b/test/mitmproxy/net/http/http1/test_read.py index 3f48a672e..a9148e7ab 100644 --- a/test/mitmproxy/net/http/http1/test_read.py +++ b/test/mitmproxy/net/http/http1/test_read.py @@ -1,18 +1,17 @@ import pytest from mitmproxy.http import Headers -from mitmproxy.net.http.http1.read import ( - read_request_head, - read_response_head, - connection_close, - expected_http_body_size, - _read_request_line, - _read_response_line, - _read_headers, - get_header_tokens, - validate_headers, -) -from mitmproxy.test.tutils import treq, tresp +from mitmproxy.net.http.http1.read import _read_headers +from mitmproxy.net.http.http1.read import _read_request_line +from mitmproxy.net.http.http1.read import _read_response_line +from mitmproxy.net.http.http1.read import connection_close +from mitmproxy.net.http.http1.read import expected_http_body_size +from mitmproxy.net.http.http1.read import get_header_tokens +from mitmproxy.net.http.http1.read import read_request_head +from mitmproxy.net.http.http1.read import read_response_head +from mitmproxy.net.http.http1.read import validate_headers +from mitmproxy.test.tutils import treq +from mitmproxy.test.tutils import tresp def test_get_header_tokens(): diff --git a/test/mitmproxy/net/http/test_cookies.py b/test/mitmproxy/net/http/test_cookies.py index 4b7f3dd65..be5a57c76 100644 --- a/test/mitmproxy/net/http/test_cookies.py +++ b/test/mitmproxy/net/http/test_cookies.py @@ -1,7 +1,8 @@ import time -import pytest from unittest import mock +import pytest + from mitmproxy.net.http import cookies diff --git a/test/mitmproxy/net/http/test_headers.py b/test/mitmproxy/net/http/test_headers.py index b7dff51d9..473b930f8 100644 --- a/test/mitmproxy/net/http/test_headers.py +++ b/test/mitmproxy/net/http/test_headers.py @@ -1,6 +1,7 @@ import collections -from mitmproxy.net.http.headers import parse_content_type, assemble_content_type +from mitmproxy.net.http.headers import assemble_content_type +from mitmproxy.net.http.headers import parse_content_type def test_parse_content_type(): diff --git a/test/mitmproxy/net/test_encoding.py b/test/mitmproxy/net/test_encoding.py index 9d155961b..640d318ae 100644 --- a/test/mitmproxy/net/test_encoding.py +++ b/test/mitmproxy/net/test_encoding.py @@ -1,4 +1,5 @@ from unittest import mock + import pytest from mitmproxy.net import encoding diff --git a/test/mitmproxy/net/test_tls.py b/test/mitmproxy/net/test_tls.py index c4fb16062..9e2600348 100644 --- a/test/mitmproxy/net/test_tls.py +++ b/test/mitmproxy/net/test_tls.py @@ -1,6 +1,8 @@ from pathlib import Path -from OpenSSL import SSL, crypto +from OpenSSL import crypto +from OpenSSL import SSL + from mitmproxy import certs from mitmproxy.net import tls @@ -58,7 +60,7 @@ def test_sslkeylogfile(tdata, monkeypatch): try: read.do_handshake() except SSL.WantReadError: - write.bio_write(read.bio_read(2 ** 16)) + write.bio_write(read.bio_read(2**16)) else: break read, write = write, read diff --git a/test/mitmproxy/net/test_udp.py b/test/mitmproxy/net/test_udp.py index 29a848cd1..d52636529 100644 --- a/test/mitmproxy/net/test_udp.py +++ b/test/mitmproxy/net/test_udp.py @@ -1,8 +1,14 @@ import asyncio from typing import Optional + import pytest + from mitmproxy.connection import Address -from mitmproxy.net.udp import MAX_DATAGRAM_SIZE, DatagramReader, DatagramWriter, open_connection, start_server +from mitmproxy.net.udp import DatagramReader +from mitmproxy.net.udp import DatagramWriter +from mitmproxy.net.udp import MAX_DATAGRAM_SIZE +from mitmproxy.net.udp import open_connection +from mitmproxy.net.udp import start_server async def test_client_server(): @@ -13,7 +19,7 @@ async def test_client_server(): transport: asyncio.DatagramTransport, data: bytes, remote_addr: Address, - local_addr: Address + local_addr: Address, ): nonlocal server_reader, server_writer if server_writer is None: @@ -23,9 +29,14 @@ async def test_client_server(): server = await start_server(handle_datagram, "127.0.0.1", 0) assert repr(server).startswith(" context.Context: opts = options.Options() Proxyserver().load(opts) return context.Context( - connection.Client(peername=("client", 1234), sockname=("127.0.0.1", 8080), - timestamp_start=1605699329, state=connection.ConnectionState.OPEN), - opts + connection.Client( + peername=("client", 1234), + sockname=("127.0.0.1", 8080), + timestamp_start=1605699329, + state=connection.ConnectionState.OPEN, + ), + opts, ) diff --git a/test/mitmproxy/proxy/layers/http/hyper_h2_test_helpers.py b/test/mitmproxy/proxy/layers/http/hyper_h2_test_helpers.py index d5f8e0182..9b1e2676d 100644 --- a/test/mitmproxy/proxy/layers/http/hyper_h2_test_helpers.py +++ b/test/mitmproxy/proxy/layers/http/hyper_h2_test_helpers.py @@ -1,6 +1,5 @@ # This file has been copied from https://github.com/python-hyper/hyper-h2/blob/master/test/helpers.py, # MIT License - # -*- coding: utf-8 -*- """ helpers @@ -9,19 +8,17 @@ helpers This module contains helpers for the h2 tests. """ from hpack.hpack import Encoder -from hyperframe.frame import ( - HeadersFrame, - DataFrame, - SettingsFrame, - WindowUpdateFrame, - PingFrame, - GoAwayFrame, - RstStreamFrame, - PushPromiseFrame, - PriorityFrame, - ContinuationFrame, - AltSvcFrame, -) +from hyperframe.frame import AltSvcFrame +from hyperframe.frame import ContinuationFrame +from hyperframe.frame import DataFrame +from hyperframe.frame import GoAwayFrame +from hyperframe.frame import HeadersFrame +from hyperframe.frame import PingFrame +from hyperframe.frame import PriorityFrame +from hyperframe.frame import PushPromiseFrame +from hyperframe.frame import RstStreamFrame +from hyperframe.frame import SettingsFrame +from hyperframe.frame import WindowUpdateFrame SAMPLE_SETTINGS = { SettingsFrame.HEADER_TABLE_SIZE: 4096, diff --git a/test/mitmproxy/proxy/layers/http/test_http.py b/test/mitmproxy/proxy/layers/http/test_http.py index 86ff6fc3a..0d8e638df 100644 --- a/test/mitmproxy/proxy/layers/http/test_http.py +++ b/test/mitmproxy/proxy/layers/http/test_http.py @@ -1,27 +1,34 @@ -from logging import WARNING - import gc +from logging import WARNING import pytest -from mitmproxy.connection import ConnectionState, Server -from mitmproxy.http import HTTPFlow, Response +from mitmproxy.connection import ConnectionState +from mitmproxy.connection import Server +from mitmproxy.http import HTTPFlow +from mitmproxy.http import Response from mitmproxy.proxy import layer -from mitmproxy.proxy.commands import CloseConnection, Log, OpenConnection, SendData -from mitmproxy.proxy.events import ConnectionClosed, DataReceived -from mitmproxy.proxy.layers import TCPLayer, http, tls +from mitmproxy.proxy.commands import CloseConnection +from mitmproxy.proxy.commands import Log +from mitmproxy.proxy.commands import OpenConnection +from mitmproxy.proxy.commands import SendData +from mitmproxy.proxy.events import ConnectionClosed +from mitmproxy.proxy.events import DataReceived +from mitmproxy.proxy.layers import http +from mitmproxy.proxy.layers import TCPLayer +from mitmproxy.proxy.layers import tls from mitmproxy.proxy.layers.http import HTTPMode -from mitmproxy.proxy.layers.tcp import TcpMessageInjected, TcpStartHook +from mitmproxy.proxy.layers.tcp import TcpMessageInjected +from mitmproxy.proxy.layers.tcp import TcpStartHook from mitmproxy.proxy.layers.websocket import WebsocketStartHook from mitmproxy.proxy.mode_specs import ProxyMode -from mitmproxy.tcp import TCPFlow, TCPMessage -from test.mitmproxy.proxy.tutils import ( - BytesMatching, - Placeholder, - Playbook, - reply, - reply_next_layer, -) +from mitmproxy.tcp import TCPFlow +from mitmproxy.tcp import TCPMessage +from test.mitmproxy.proxy.tutils import BytesMatching +from test.mitmproxy.proxy.tutils import Placeholder +from test.mitmproxy.proxy.tutils import Playbook +from test.mitmproxy.proxy.tutils import reply +from test.mitmproxy.proxy.tutils import reply_next_layer def test_http_proxy(tctx): @@ -744,7 +751,8 @@ def test_upstream_proxy(tctx, redirect, domain, scheme): << OpenConnection(server) >> reply(None) << SendData( - server, b"GET http://%s/ HTTP/1.1\r\nHost: %s\r\n\r\n" % (domain, domain), + server, + b"GET http://%s/ HTTP/1.1\r\nHost: %s\r\n\r\n" % (domain, domain), ) ) @@ -799,7 +807,8 @@ def test_upstream_proxy(tctx, redirect, domain, scheme): if redirect == "change-destination": playbook << SendData( server2, - b"GET http://%s.test/two HTTP/1.1\r\nHost: %s\r\n\r\n" % (domain, domain), + b"GET http://%s.test/two HTTP/1.1\r\nHost: %s\r\n\r\n" + % (domain, domain), ) else: playbook << SendData( @@ -808,7 +817,9 @@ def test_upstream_proxy(tctx, redirect, domain, scheme): ) else: if redirect == "change-destination": - playbook << SendData(server2, b"CONNECT %s.test:443 HTTP/1.1\r\n\r\n" % domain) + playbook << SendData( + server2, b"CONNECT %s.test:443 HTTP/1.1\r\n\r\n" % domain + ) playbook >> DataReceived( server2, b"HTTP/1.1 200 Connection established\r\n\r\n" ) @@ -830,9 +841,7 @@ def test_upstream_proxy(tctx, redirect, domain, scheme): assert flow().server_conn.address[0] == domain.decode("idna") if redirect == "change-proxy": - assert ( - server2().address == flow().server_conn.via[1] == ("other-proxy", 1234) - ) + assert server2().address == flow().server_conn.via[1] == ("other-proxy", 1234) else: assert server2().address == flow().server_conn.via[1] == ("proxy", 8080) @@ -958,8 +967,12 @@ def test_http_proxy_without_empty_chunk_in_head_request(tctx): << OpenConnection(server) >> reply(None) << SendData(server, b"HEAD / HTTP/1.1\r\n\r\n") - >> DataReceived(server, b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n") - << SendData(tctx.client, b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n") + >> DataReceived( + server, b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n" + ) + << SendData( + tctx.client, b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n" + ) ) @@ -1658,7 +1671,7 @@ def test_drop_stream_with_paused_events(tctx): << http.HttpRequestHeadersHook(flow) >> reply() << OpenConnection(server) - >> reply('Connection killed: error') + >> reply("Connection killed: error") << http.HttpErrorHook(flow) >> reply() << SendData(tctx.client, BytesMatching(b"502 Bad Gateway.+Connection killed")) diff --git a/test/mitmproxy/proxy/layers/http/test_http1.py b/test/mitmproxy/proxy/layers/http/test_http1.py index 05e84e1fc..160493ad1 100644 --- a/test/mitmproxy/proxy/layers/http/test_http1.py +++ b/test/mitmproxy/proxy/layers/http/test_http1.py @@ -3,18 +3,17 @@ import pytest from mitmproxy import http from mitmproxy.proxy.commands import SendData from mitmproxy.proxy.events import DataReceived -from mitmproxy.proxy.layers.http import ( - Http1Server, - ReceiveHttp, - RequestHeaders, - RequestEndOfMessage, - ResponseHeaders, - ResponseEndOfMessage, - RequestData, - Http1Client, - ResponseData, -) -from test.mitmproxy.proxy.tutils import Placeholder, Playbook +from mitmproxy.proxy.layers.http import Http1Client +from mitmproxy.proxy.layers.http import Http1Server +from mitmproxy.proxy.layers.http import ReceiveHttp +from mitmproxy.proxy.layers.http import RequestData +from mitmproxy.proxy.layers.http import RequestEndOfMessage +from mitmproxy.proxy.layers.http import RequestHeaders +from mitmproxy.proxy.layers.http import ResponseData +from mitmproxy.proxy.layers.http import ResponseEndOfMessage +from mitmproxy.proxy.layers.http import ResponseHeaders +from test.mitmproxy.proxy.tutils import Placeholder +from test.mitmproxy.proxy.tutils import Playbook class TestServer: diff --git a/test/mitmproxy/proxy/layers/http/test_http2.py b/test/mitmproxy/proxy/layers/http/test_http2.py index e28bf72f3..a6a5412d8 100644 --- a/test/mitmproxy/proxy/layers/http/test_http2.py +++ b/test/mitmproxy/proxy/layers/http/test_http2.py @@ -1,28 +1,34 @@ +import time + import h2.settings import hpack import hyperframe.frame import pytest -import time from h2.errors import ErrorCodes -from mitmproxy.connection import ConnectionState, Server +from mitmproxy.connection import ConnectionState +from mitmproxy.connection import Server from mitmproxy.flow import Error -from mitmproxy.http import HTTPFlow, Headers, Request +from mitmproxy.http import Headers +from mitmproxy.http import HTTPFlow +from mitmproxy.http import Request from mitmproxy.net.http import status_codes -from mitmproxy.proxy.commands import ( - CloseConnection, - Log, - OpenConnection, - SendData, - RequestWakeup, -) +from mitmproxy.proxy.commands import CloseConnection +from mitmproxy.proxy.commands import Log +from mitmproxy.proxy.commands import OpenConnection +from mitmproxy.proxy.commands import RequestWakeup +from mitmproxy.proxy.commands import SendData from mitmproxy.proxy.context import Context -from mitmproxy.proxy.events import ConnectionClosed, DataReceived +from mitmproxy.proxy.events import ConnectionClosed +from mitmproxy.proxy.events import DataReceived from mitmproxy.proxy.layers import http from mitmproxy.proxy.layers.http import HTTPMode -from mitmproxy.proxy.layers.http._http2 import Http2Client, split_pseudo_headers +from mitmproxy.proxy.layers.http._http2 import Http2Client +from mitmproxy.proxy.layers.http._http2 import split_pseudo_headers from test.mitmproxy.proxy.layers.http.hyper_h2_test_helpers import FrameFactory -from test.mitmproxy.proxy.tutils import Placeholder, Playbook, reply +from test.mitmproxy.proxy.tutils import Placeholder +from test.mitmproxy.proxy.tutils import Playbook +from test.mitmproxy.proxy.tutils import reply example_request_headers = ( (b":method", b"GET"), @@ -325,8 +331,7 @@ def test_long_response(tctx: Context, trailers): << http.HttpResponseHeadersHook(flow) >> reply() >> DataReceived( - server, - sff.build_data_frame(b"a" * 10000, flags=[]).serialize() + server, sff.build_data_frame(b"a" * 10000, flags=[]).serialize() ) >> DataReceived( server, @@ -373,9 +378,7 @@ def test_long_response(tctx: Context, trailers): playbook >> DataReceived( server, - sff.build_data_frame( - b'', flags=["END_STREAM"] - ).serialize(), + sff.build_data_frame(b"", flags=["END_STREAM"]).serialize(), ) ) ( @@ -412,10 +415,7 @@ def test_long_response(tctx: Context, trailers): tctx.client, cff.build_data_frame(b"a" * 1).serialize(), ) - << SendData( - tctx.client, - cff.build_data_frame(b"a" * 4464).serialize() - ) + << SendData(tctx.client, cff.build_data_frame(b"a" * 4464).serialize()) << SendData( tctx.client, cff.build_headers_frame( @@ -430,15 +430,10 @@ def test_long_response(tctx: Context, trailers): tctx.client, cff.build_data_frame(b"a" * 1).serialize(), ) + << SendData(tctx.client, cff.build_data_frame(b"a" * 4464).serialize()) << SendData( tctx.client, - cff.build_data_frame(b"a" * 4464).serialize() - ) - << SendData( - tctx.client, - cff.build_data_frame( - b"", flags=["END_STREAM"] - ).serialize(), + cff.build_data_frame(b"", flags=["END_STREAM"]).serialize(), ) ) assert flow().request.url == "http://example.com/" diff --git a/test/mitmproxy/proxy/layers/http/test_http3.py b/test/mitmproxy/proxy/layers/http/test_http3.py index eccdd7bf4..59ffb670a 100644 --- a/test/mitmproxy/proxy/layers/http/test_http3.py +++ b/test/mitmproxy/proxy/layers/http/test_http3.py @@ -1,26 +1,33 @@ import collections.abc -from typing import Callable, Iterable, Optional -import pytest +from collections.abc import Iterable +from typing import Callable +from typing import Optional + import pylsqpack - +import pytest from aioquic._buffer import Buffer -from aioquic.h3.connection import ( - ErrorCode, - FrameType, - Headers as H3Headers, - Setting, - StreamType, - encode_frame, - encode_uint_var, - encode_settings, - parse_settings, -) +from aioquic.h3.connection import encode_frame +from aioquic.h3.connection import encode_settings +from aioquic.h3.connection import encode_uint_var +from aioquic.h3.connection import ErrorCode +from aioquic.h3.connection import FrameType +from aioquic.h3.connection import Headers as H3Headers +from aioquic.h3.connection import parse_settings +from aioquic.h3.connection import Setting +from aioquic.h3.connection import StreamType -from mitmproxy import connection, version +from mitmproxy import connection +from mitmproxy import version from mitmproxy.flow import Error -from mitmproxy.http import Headers, HTTPFlow, Request -from mitmproxy.proxy import commands, context, events, layers -from mitmproxy.proxy.layers import http, quic +from mitmproxy.http import Headers +from mitmproxy.http import HTTPFlow +from mitmproxy.http import Request +from mitmproxy.proxy import commands +from mitmproxy.proxy import context +from mitmproxy.proxy import events +from mitmproxy.proxy import layers +from mitmproxy.proxy.layers import http +from mitmproxy.proxy.layers import quic from mitmproxy.proxy.layers.http._http3 import Http3Client from test.mitmproxy.proxy import tutils @@ -47,6 +54,7 @@ def decode_frame(frame_type: int, frame_data: bytes) -> bytes: class CallbackPlaceholder(tutils._Placeholder[bytes]): """Data placeholder that invokes a callback once its bytes get set.""" + def __init__(self, cb: Callable[[bytes], None]): super().__init__(bytes) self._cb = cb @@ -59,6 +67,7 @@ class CallbackPlaceholder(tutils._Placeholder[bytes]): class DelayedPlaceholder(tutils._Placeholder[bytes]): """Data placeholder that resolves its bytes when needed.""" + def __init__(self, resolve: Callable[[], bytes]): super().__init__(bytes) self._resolve = resolve @@ -71,6 +80,7 @@ class DelayedPlaceholder(tutils._Placeholder[bytes]): class MultiPlaybook(tutils.Playbook): """Playbook that allows multiple events and commands to be registered at once.""" + def __lshift__(self, c): if isinstance(c, collections.abc.Iterable): for c_i in c: @@ -90,11 +100,8 @@ class MultiPlaybook(tutils.Playbook): class FrameFactory: """Helper class for generating QUIC stream events and commands.""" - def __init__( - self, - conn: connection.Connection, - is_client: bool - ) -> None: + + def __init__(self, conn: connection.Connection, is_client: bool) -> None: self.conn = conn self.is_client = is_client self.decoder = pylsqpack.Decoder( @@ -108,11 +115,7 @@ class FrameFactory: self.local_stream_id: dict[StreamType, int] = {} self.max_push_id: Optional[int] = None - def get_default_stream_id( - self, - stream_type: StreamType, - for_local: bool - ) -> int: + def get_default_stream_id(self, stream_type: StreamType, for_local: bool) -> int: if stream_type == StreamType.CONTROL: stream_id = 2 elif stream_type == StreamType.QPACK_ENCODER: @@ -132,9 +135,7 @@ class FrameFactory: ) -> quic.SendQuicStreamData: assert stream_type not in self.peer_stream_id if stream_id is None: - stream_id = self.get_default_stream_id( - stream_type, for_local=False - ) + stream_id = self.get_default_stream_id(stream_type, for_local=False) self.peer_stream_id[stream_type] = stream_id return quic.SendQuicStreamData( connection=self.conn, @@ -150,9 +151,7 @@ class FrameFactory: ) -> quic.QuicStreamDataReceived: assert stream_type not in self.local_stream_id if stream_id is None: - stream_id = self.get_default_stream_id( - stream_type, for_local=True - ) + stream_id = self.get_default_stream_id(stream_type, for_local=True) self.local_stream_id[stream_type] = stream_id return quic.QuicStreamDataReceived( connection=self.conn, @@ -185,10 +184,12 @@ class FrameFactory: buf = Buffer(data=data) assert buf.pull_uint_var() == FrameType.SETTINGS settings = parse_settings(buf.pull_bytes(buf.pull_uint_var())) - placeholder.setdefault(self.encoder.apply_settings( - max_table_capacity=settings[Setting.QPACK_MAX_TABLE_CAPACITY], - blocked_streams=settings[Setting.QPACK_BLOCKED_STREAMS], - )) + placeholder.setdefault( + self.encoder.apply_settings( + max_table_capacity=settings[Setting.QPACK_MAX_TABLE_CAPACITY], + blocked_streams=settings[Setting.QPACK_BLOCKED_STREAMS], + ) + ) return quic.SendQuicStreamData( connection=self.conn, @@ -368,10 +369,7 @@ class FrameFactory: @property def is_done(self) -> bool: - return ( - self.encoder_placeholder is None - and not self.decoder_placeholders - ) + return self.encoder_placeholder is None and not self.decoder_placeholders @pytest.fixture @@ -428,11 +426,14 @@ def test_invalid_header(tctx: context.Context): playbook, cff = start_h3_client(tctx) assert ( playbook - >> cff.receive_headers([ - (b":method", b"CONNECT"), - (b":path", b"/"), - (b":authority", b"example.com"), - ], end_stream=True) + >> cff.receive_headers( + [ + (b":method", b"CONNECT"), + (b":path", b"/"), + (b":authority", b"example.com"), + ], + end_stream=True, + ) << cff.send_decoder() # for receive_headers << quic.CloseQuicConnection( tctx.client, @@ -441,11 +442,14 @@ def test_invalid_header(tctx: context.Context): reason_phrase="Invalid HTTP/3 request headers: Required pseudo header is missing: b':scheme'", ) # ensure that once we close, we don't process messages anymore - >> cff.receive_headers([ - (b":method", b"CONNECT"), - (b":path", b"/"), - (b":authority", b"example.com"), - ], end_stream=True) + >> cff.receive_headers( + [ + (b":method", b"CONNECT"), + (b":path", b"/"), + (b":authority", b"example.com"), + ], + end_stream=True, + ) ) @@ -621,10 +625,7 @@ def test_request_trailers( >> tutils.reply(to=request) << sff.send_headers(example_request_trailers, end_stream=True) ) - assert ( - playbook - >> sff.receive_decoder() # for send_headers - ) + assert playbook >> sff.receive_decoder() # for send_headers assert cff.is_done and sff.is_done @@ -648,11 +649,13 @@ def test_upstream_error(tctx: context.Context): >> tutils.reply("oops server <> error") << http.HttpErrorHook(flow) >> tutils.reply() - << cff.send_headers([ - (b":status", b"502"), - (b'server', version.MITMPROXY.encode()), - (b'content-type', b'text/html'), - ]) + << cff.send_headers( + [ + (b":status", b"502"), + (b"server", version.MITMPROXY.encode()), + (b"content-type", b"text/html"), + ] + ) << quic.SendQuicStreamData( tctx.client, stream_id=0, @@ -670,9 +673,7 @@ def test_upstream_error(tctx: context.Context): @pytest.mark.parametrize("stream", ["stream", ""]) @pytest.mark.parametrize("when", ["request", "response"]) @pytest.mark.parametrize("how", ["RST", "disconnect", "RST+disconnect"]) -def test_http3_client_aborts( - tctx: context.Context, stream: str, when: str, how: str -): +def test_http3_client_aborts(tctx: context.Context, stream: str, when: str, how: str): """ Test handling of the case where a client aborts during request or response transmission. @@ -698,12 +699,12 @@ def test_http3_client_aborts( if stream and when == "request": assert ( playbook - >> tutils.reply( - side_effect=enable_request_streaming, to=request_headers - ) + >> tutils.reply(side_effect=enable_request_streaming, to=request_headers) << commands.OpenConnection(server) >> tutils.reply(None) - << commands.SendData(server, b"GET / HTTP/1.1\r\n" b"Host: example.com\r\n\r\n") + << commands.SendData( + server, b"GET / HTTP/1.1\r\n" b"Host: example.com\r\n\r\n" + ) ) else: assert playbook >> tutils.reply(to=request_headers) @@ -716,7 +717,7 @@ def test_http3_client_aborts( tctx.client, error_code=ErrorCode.H3_REQUEST_CANCELLED, frame_type=None, - reason_phrase="peer closed connection" + reason_phrase="peer closed connection", ) if stream: @@ -729,7 +730,7 @@ def test_http3_client_aborts( tctx.client, error_code=ErrorCode.H3_NO_ERROR, frame_type=None, - reason_phrase="peer closed connection" + reason_phrase="peer closed connection", ) assert playbook assert ( @@ -746,17 +747,21 @@ def test_http3_client_aborts( << commands.OpenConnection(server) >> tutils.reply(None) << commands.SendData(server, b"GET / HTTP/1.1\r\n" b"Host: example.com\r\n\r\n") - >> events.DataReceived(server, b"HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\n123") + >> events.DataReceived( + server, b"HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\n123" + ) << http.HttpResponseHeadersHook(flow) ) if stream: assert ( playbook >> tutils.reply(side_effect=enable_response_streaming) - << cff.send_headers([ - (b":status", b"200"), - (b"content-length", b"6"), - ]) + << cff.send_headers( + [ + (b":status", b"200"), + (b"content-length", b"6"), + ] + ) << cff.send_data(b"123") ) else: @@ -769,7 +774,7 @@ def test_http3_client_aborts( tctx.client, error_code=ErrorCode.H3_REQUEST_CANCELLED, frame_type=None, - reason_phrase="peer closed connection" + reason_phrase="peer closed connection", ) playbook << commands.CloseConnection(server) @@ -782,7 +787,7 @@ def test_http3_client_aborts( tctx.client, error_code=ErrorCode.H3_REQUEST_CANCELLED, frame_type=None, - reason_phrase="peer closed connection" + reason_phrase="peer closed connection", ) assert playbook @@ -920,14 +925,10 @@ def test_stream_concurrency(tctx: context.Context): assert ( playbook # request client - >> cff.receive_headers( - headers1, stream_id=0, end_stream=True - ) + >> cff.receive_headers(headers1, stream_id=0, end_stream=True) << (request_header1 := http.HttpRequestHeadersHook(flow1)) << cff.send_decoder() # for receive_headers - >> cff.receive_headers( - headers2, stream_id=4, end_stream=True - ) + >> cff.receive_headers(headers2, stream_id=4, end_stream=True) << (request_header2 := http.HttpRequestHeadersHook(flow2)) << cff.send_decoder() # for receive_headers >> tutils.reply(to=request_header1) @@ -940,17 +941,13 @@ def test_stream_concurrency(tctx: context.Context): << commands.OpenConnection(server) >> tutils.reply(None, side_effect=make_h3) << sff.send_init() - << sff.send_headers( - headers2, stream_id=0, end_stream=True - ) + << sff.send_headers(headers2, stream_id=0, end_stream=True) >> sff.receive_init() << sff.send_encoder() >> sff.receive_encoder() >> sff.receive_decoder() # for send_headers >> tutils.reply(to=request1) - << sff.send_headers( - headers1, stream_id=4, end_stream=True - ) + << sff.send_headers(headers1, stream_id=4, end_stream=True) >> sff.receive_decoder() # for send_headers ) assert cff.is_done and sff.is_done @@ -964,23 +961,15 @@ def test_stream_concurrent_get_connection(tctx: context.Context): sff = FrameFactory(server, is_client=False) assert ( playbook - >> cff.receive_headers( - example_request_headers, stream_id=0, end_stream=True - ) + >> cff.receive_headers(example_request_headers, stream_id=0, end_stream=True) << cff.send_decoder() # for receive_headers << (o := commands.OpenConnection(server)) - >> cff.receive_headers( - example_request_headers, stream_id=4, end_stream=True - ) + >> cff.receive_headers(example_request_headers, stream_id=4, end_stream=True) << cff.send_decoder() # for receive_headers >> tutils.reply(None, to=o, side_effect=make_h3) << sff.send_init() - << sff.send_headers( - example_request_headers, stream_id=0, end_stream=True - ) - << sff.send_headers( - example_request_headers, stream_id=4, end_stream=True - ) + << sff.send_headers(example_request_headers, stream_id=0, end_stream=True) + << sff.send_headers(example_request_headers, stream_id=4, end_stream=True) >> sff.receive_init() << sff.send_encoder() >> sff.receive_encoder() @@ -1007,14 +996,10 @@ def test_kill_stream(tctx: context.Context): assert ( playbook # request client - >> cff.receive_headers( - headers1, stream_id=0, end_stream=True - ) + >> cff.receive_headers(headers1, stream_id=0, end_stream=True) << (request_header1 := http.HttpRequestHeadersHook(flow1)) << cff.send_decoder() # for receive_headers - >> cff.receive_headers( - headers2, stream_id=4, end_stream=True - ) + >> cff.receive_headers(headers2, stream_id=4, end_stream=True) << (request_header2 := http.HttpRequestHeadersHook(flow2)) << cff.send_decoder() # for receive_headers >> tutils.reply(to=request_header2, side_effect=kill) @@ -1028,9 +1013,7 @@ def test_kill_stream(tctx: context.Context): << commands.OpenConnection(server) >> tutils.reply(None, side_effect=make_h3) << sff.send_init() - << sff.send_headers( - headers1, stream_id=0, end_stream=True - ) + << sff.send_headers(headers1, stream_id=0, end_stream=True) >> sff.receive_init() << sff.send_encoder() >> sff.receive_encoder() @@ -1052,12 +1035,15 @@ class TestClient: << frame_factory.send_encoder() >> frame_factory.receive_encoder() >> http.RequestHeaders(1, req, end_stream=True) - << frame_factory.send_headers([ - (b":method", b"GET"), - (b':scheme', b'http'), - (b':path', b'/'), - (b'content-length', b'0'), - ], end_stream=True) + << frame_factory.send_headers( + [ + (b":method", b"GET"), + (b":scheme", b"http"), + (b":path", b"/"), + (b"content-length", b"0"), + ], + end_stream=True, + ) >> frame_factory.receive_decoder() # for send_headers >> http.RequestEndOfMessage(1) >> frame_factory.receive_headers(resp) @@ -1099,12 +1085,15 @@ class TestClient: "DATA frame is not allowed in this state" ) >> http.RequestHeaders(1, req, end_stream=False) - << frame_factory.send_headers([ - (b":method", b"GET"), - (b':scheme', b'http'), - (b':path', b'/'), - (b'content-length', b'0'), - ], end_stream=False) + << frame_factory.send_headers( + [ + (b":method", b"GET"), + (b":scheme", b"http"), + (b":path", b"/"), + (b"content-length", b"0"), + ], + end_stream=False, + ) >> frame_factory.receive_decoder() # for send_headers >> http.RequestHeaders(1, req, end_stream=False) << commands.Log( diff --git a/test/mitmproxy/proxy/layers/http/test_http_fuzz.py b/test/mitmproxy/proxy/layers/http/test_http_fuzz.py index 67478a02f..8e139e0c2 100644 --- a/test/mitmproxy/proxy/layers/http/test_http_fuzz.py +++ b/test/mitmproxy/proxy/layers/http/test_http_fuzz.py @@ -2,44 +2,44 @@ from typing import Any import pytest from h2.settings import SettingCodes -from hypothesis import example, given -from hypothesis.strategies import ( - binary, - booleans, - composite, - dictionaries, - integers, - lists, - sampled_from, - sets, - text, - data, -) +from hypothesis import example +from hypothesis import given +from hypothesis.strategies import binary +from hypothesis.strategies import booleans +from hypothesis.strategies import composite +from hypothesis.strategies import data +from hypothesis.strategies import dictionaries +from hypothesis.strategies import integers +from hypothesis.strategies import lists +from hypothesis.strategies import sampled_from +from hypothesis.strategies import sets +from hypothesis.strategies import text -from mitmproxy import options, connection +from mitmproxy import connection +from mitmproxy import options from mitmproxy.addons.proxyserver import Proxyserver from mitmproxy.connection import Server from mitmproxy.http import HTTPFlow -from mitmproxy.proxy.layers.http import HTTPMode -from mitmproxy.proxy import context, events -from mitmproxy.proxy.commands import OpenConnection, SendData -from mitmproxy.proxy.events import DataReceived, Start, ConnectionClosed +from mitmproxy.proxy import context +from mitmproxy.proxy import events +from mitmproxy.proxy.commands import OpenConnection +from mitmproxy.proxy.commands import SendData +from mitmproxy.proxy.events import ConnectionClosed +from mitmproxy.proxy.events import DataReceived +from mitmproxy.proxy.events import Start from mitmproxy.proxy.layers import http -from test.mitmproxy.proxy.layers.http.hyper_h2_test_helpers import FrameFactory -from test.mitmproxy.proxy.layers.http.test_http2 import ( - make_h2, - example_response_headers, - example_request_headers, - start_h2_client, -) -from test.mitmproxy.proxy.tutils import ( - Placeholder, - Playbook, - reply, - _TracebackInPlaybook, - _eq, -) from mitmproxy.proxy.layers.http import _http2 +from mitmproxy.proxy.layers.http import HTTPMode +from test.mitmproxy.proxy.layers.http.hyper_h2_test_helpers import FrameFactory +from test.mitmproxy.proxy.layers.http.test_http2 import example_request_headers +from test.mitmproxy.proxy.layers.http.test_http2 import example_response_headers +from test.mitmproxy.proxy.layers.http.test_http2 import make_h2 +from test.mitmproxy.proxy.layers.http.test_http2 import start_h2_client +from test.mitmproxy.proxy.tutils import _eq +from test.mitmproxy.proxy.tutils import _TracebackInPlaybook +from test.mitmproxy.proxy.tutils import Placeholder +from test.mitmproxy.proxy.tutils import Playbook +from test.mitmproxy.proxy.tutils import reply opts = options.Options() Proxyserver().load(opts) @@ -217,7 +217,7 @@ def h2_frames(draw): settings=draw( dictionaries( keys=sampled_from(SettingCodes), - values=integers(0, 2 ** 32 - 1), + values=integers(0, 2**32 - 1), max_size=5, ) ), @@ -244,7 +244,7 @@ def h2_frames(draw): draw(binary()), draw(h2_flags), stream_id=draw(h2_stream_ids_nonzero) ) window_update = ff.build_window_update_frame( - draw(h2_stream_ids), draw(integers(0, 2 ** 32 - 1)) + draw(h2_stream_ids), draw(integers(0, 2**32 - 1)) ) frames = draw( @@ -318,8 +318,12 @@ def test_fuzz_h2_request_mutations(chunks): def _tctx() -> context.Context: return context.Context( - connection.Client(peername=("client", 1234), sockname=("127.0.0.1", 8080), timestamp_start=1605699329), - opts + connection.Client( + peername=("client", 1234), + sockname=("127.0.0.1", 8080), + timestamp_start=1605699329, + ), + opts, ) diff --git a/test/mitmproxy/proxy/layers/http/test_http_version_interop.py b/test/mitmproxy/proxy/layers/http/test_http_version_interop.py index d0ae84248..599793e32 100644 --- a/test/mitmproxy/proxy/layers/http/test_http_version_interop.py +++ b/test/mitmproxy/proxy/layers/http/test_http_version_interop.py @@ -2,19 +2,21 @@ import h2.config import h2.connection import h2.events -from mitmproxy.http import HTTPFlow -from mitmproxy.proxy.context import Context -from mitmproxy.proxy.layers.http import HTTPMode -from mitmproxy.proxy.commands import CloseConnection, OpenConnection, SendData from mitmproxy.connection import Server +from mitmproxy.http import HTTPFlow +from mitmproxy.proxy.commands import CloseConnection +from mitmproxy.proxy.commands import OpenConnection +from mitmproxy.proxy.commands import SendData +from mitmproxy.proxy.context import Context from mitmproxy.proxy.events import DataReceived from mitmproxy.proxy.layers import http +from mitmproxy.proxy.layers.http import HTTPMode from test.mitmproxy.proxy.layers.http.hyper_h2_test_helpers import FrameFactory -from test.mitmproxy.proxy.layers.http.test_http2 import ( - example_response_headers, - make_h2, -) -from test.mitmproxy.proxy.tutils import Placeholder, Playbook, reply +from test.mitmproxy.proxy.layers.http.test_http2 import example_response_headers +from test.mitmproxy.proxy.layers.http.test_http2 import make_h2 +from test.mitmproxy.proxy.tutils import Placeholder +from test.mitmproxy.proxy.tutils import Playbook +from test.mitmproxy.proxy.tutils import reply example_request_headers = ( (b":method", b"GET"), @@ -77,7 +79,9 @@ def test_h2_to_h1(tctx): >> reply() << OpenConnection(server) >> reply(None) - << SendData(server, b"GET / HTTP/1.1\r\nHost: example.com\r\ncookie: a=1; b=2\r\n\r\n") + << SendData( + server, b"GET / HTTP/1.1\r\nHost: example.com\r\ncookie: a=1; b=2\r\n\r\n" + ) >> DataReceived(server, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\n") << http.HttpResponseHeadersHook(flow) >> reply() diff --git a/test/mitmproxy/proxy/layers/test_dns.py b/test/mitmproxy/proxy/layers/test_dns.py index ab520bb62..133ae20b3 100644 --- a/test/mitmproxy/proxy/layers/test_dns.py +++ b/test/mitmproxy/proxy/layers/test_dns.py @@ -1,11 +1,18 @@ import time -from mitmproxy.proxy.commands import CloseConnection, Log, OpenConnection, SendData -from mitmproxy.proxy.events import ConnectionClosed, DataReceived -from mitmproxy.proxy.layers import dns +from ..tutils import Placeholder +from ..tutils import Playbook +from ..tutils import reply from mitmproxy.dns import DNSFlow -from mitmproxy.test.tutils import tdnsreq, tdnsresp -from ..tutils import Placeholder, Playbook, reply +from mitmproxy.proxy.commands import CloseConnection +from mitmproxy.proxy.commands import Log +from mitmproxy.proxy.commands import OpenConnection +from mitmproxy.proxy.commands import SendData +from mitmproxy.proxy.events import ConnectionClosed +from mitmproxy.proxy.events import DataReceived +from mitmproxy.proxy.layers import dns +from mitmproxy.test.tutils import tdnsreq +from mitmproxy.test.tutils import tdnsresp def test_invalid_and_dummy_end(tctx): diff --git a/test/mitmproxy/proxy/layers/test_modes.py b/test/mitmproxy/proxy/layers/test_modes.py index 96ca863f5..6d8040e3b 100644 --- a/test/mitmproxy/proxy/layers/test_modes.py +++ b/test/mitmproxy/proxy/layers/test_modes.py @@ -1,38 +1,46 @@ import copy import pytest -from mitmproxy import dns +from mitmproxy import dns from mitmproxy.addons.proxyauth import ProxyAuth -from mitmproxy.connection import Client, ConnectionState, Server +from mitmproxy.connection import Client +from mitmproxy.connection import ConnectionState +from mitmproxy.connection import Server from mitmproxy.proxy import layers -from mitmproxy.proxy.commands import ( - CloseConnection, - Log, - OpenConnection, - RequestWakeup, - SendData, -) +from mitmproxy.proxy.commands import CloseConnection +from mitmproxy.proxy.commands import Log +from mitmproxy.proxy.commands import OpenConnection +from mitmproxy.proxy.commands import RequestWakeup +from mitmproxy.proxy.commands import SendData from mitmproxy.proxy.context import Context -from mitmproxy.proxy.events import ConnectionClosed, DataReceived -from mitmproxy.proxy.layer import NextLayer, NextLayerHook -from mitmproxy.proxy.layers import http, modes, quic, tcp, tls, udp +from mitmproxy.proxy.events import ConnectionClosed +from mitmproxy.proxy.events import DataReceived +from mitmproxy.proxy.layer import NextLayer +from mitmproxy.proxy.layer import NextLayerHook +from mitmproxy.proxy.layers import http +from mitmproxy.proxy.layers import modes +from mitmproxy.proxy.layers import quic +from mitmproxy.proxy.layers import tcp +from mitmproxy.proxy.layers import tls +from mitmproxy.proxy.layers import udp from mitmproxy.proxy.layers.http import HTTPMode -from mitmproxy.proxy.layers.tcp import TcpMessageHook, TcpStartHook -from mitmproxy.proxy.layers.tls import ( - ClientTLSLayer, - TlsStartClientHook, - TlsStartServerHook, -) +from mitmproxy.proxy.layers.tcp import TcpMessageHook +from mitmproxy.proxy.layers.tcp import TcpStartHook +from mitmproxy.proxy.layers.tls import ClientTLSLayer +from mitmproxy.proxy.layers.tls import TlsStartClientHook +from mitmproxy.proxy.layers.tls import TlsStartServerHook from mitmproxy.proxy.mode_specs import ProxyMode from mitmproxy.tcp import TCPFlow -from mitmproxy.test import taddons, tflow +from mitmproxy.test import taddons +from mitmproxy.test import tflow from mitmproxy.udp import UDPFlow -from test.mitmproxy.proxy.layers.test_tls import ( - reply_tls_start_client, - reply_tls_start_server, -) -from test.mitmproxy.proxy.tutils import Placeholder, Playbook, reply, reply_next_layer +from test.mitmproxy.proxy.layers.test_tls import reply_tls_start_client +from test.mitmproxy.proxy.layers.test_tls import reply_tls_start_server +from test.mitmproxy.proxy.tutils import Placeholder +from test.mitmproxy.proxy.tutils import Playbook +from test.mitmproxy.proxy.tutils import reply +from test.mitmproxy.proxy.tutils import reply_next_layer def test_upstream_https(tctx): @@ -45,12 +53,24 @@ def test_upstream_https(tctx): curl -x localhost:8080 -k http://example.com """ tctx1 = Context( - Client(peername=("client", 1234), sockname=("127.0.0.1", 8080), timestamp_start=1605699329, state=ConnectionState.OPEN), + Client( + peername=("client", 1234), + sockname=("127.0.0.1", 8080), + timestamp_start=1605699329, + state=ConnectionState.OPEN, + ), copy.deepcopy(tctx.options), ) - tctx1.client.proxy_mode = ProxyMode.parse("upstream:https://example.mitmproxy.org:8081") + tctx1.client.proxy_mode = ProxyMode.parse( + "upstream:https://example.mitmproxy.org:8081" + ) tctx2 = Context( - Client(peername=("client", 4321), sockname=("127.0.0.1", 8080), timestamp_start=1605699329, state=ConnectionState.OPEN), + Client( + peername=("client", 4321), + sockname=("127.0.0.1", 8080), + timestamp_start=1605699329, + state=ConnectionState.OPEN, + ), copy.deepcopy(tctx.options), ) assert tctx2.client.proxy_mode == ProxyMode.parse("regular") diff --git a/test/mitmproxy/proxy/layers/test_quic.py b/test/mitmproxy/proxy/layers/test_quic.py index 06da7725c..d635a3385 100644 --- a/test/mitmproxy/proxy/layers/test_quic.py +++ b/test/mitmproxy/proxy/layers/test_quic.py @@ -1,28 +1,42 @@ -from logging import DEBUG, ERROR, WARNING import ssl import time +from logging import DEBUG +from logging import ERROR +from logging import WARNING +from typing import Literal +from typing import Optional +from typing import TypeVar +from unittest.mock import MagicMock + +import pytest from aioquic.buffer import Buffer as QuicBuffer from aioquic.quic import events as quic_events from aioquic.quic.configuration import QuicConfiguration -from aioquic.quic.connection import QuicConnection, pull_quic_header -from typing import Literal, Optional, TypeVar -from unittest.mock import MagicMock -import pytest +from aioquic.quic.connection import pull_quic_header +from aioquic.quic.connection import QuicConnection + from mitmproxy import connection -from mitmproxy.proxy import commands, context, events, layer, tunnel +from mitmproxy.proxy import commands +from mitmproxy.proxy import context +from mitmproxy.proxy import events +from mitmproxy.proxy import layer from mitmproxy.proxy import layers -from mitmproxy.proxy.layers import quic, tcp, tls, udp -from mitmproxy.udp import UDPFlow, UDPMessage +from mitmproxy.proxy import tunnel +from mitmproxy.proxy.layers import quic +from mitmproxy.proxy.layers import tcp +from mitmproxy.proxy.layers import tls +from mitmproxy.proxy.layers import udp +from mitmproxy.tcp import TCPFlow +from mitmproxy.udp import UDPFlow +from mitmproxy.udp import UDPMessage from mitmproxy.utils import data from test.mitmproxy.proxy import tutils -from mitmproxy.tcp import TCPFlow - tlsdata = data.Data(__name__) -T = TypeVar('T', bound=layer.Layer) +T = TypeVar("T", bound=layer.Layer) class DummyLayer(layer.Layer): @@ -44,26 +58,44 @@ class TlsEchoLayer(tutils.EchoLayer): yield commands.SendData( event.connection, f"open-connection failed: {err}".encode() ) - elif isinstance(event, events.DataReceived) and event.data == b"close-connection": + elif ( + isinstance(event, events.DataReceived) and event.data == b"close-connection" + ): yield commands.CloseConnection(event.connection) - elif isinstance(event, events.DataReceived) and event.data == b"close-connection-error": + elif ( + isinstance(event, events.DataReceived) + and event.data == b"close-connection-error" + ): yield quic.CloseQuicConnection(event.connection, 123, None, "error") elif isinstance(event, events.DataReceived) and event.data == b"stop-stream": yield quic.StopQuicStream(event.connection, 24, 123) - elif isinstance(event, events.DataReceived) and event.data == b"invalid-command": + elif ( + isinstance(event, events.DataReceived) and event.data == b"invalid-command" + ): + class InvalidConnectionCommand(commands.ConnectionCommand): pass + yield InvalidConnectionCommand(event.connection) - elif isinstance(event, events.DataReceived) and event.data == b"invalid-stream-command": + elif ( + isinstance(event, events.DataReceived) + and event.data == b"invalid-stream-command" + ): + class InvalidStreamCommand(quic.QuicStreamCommand): pass + yield InvalidStreamCommand(event.connection, 42) elif isinstance(event, quic.QuicConnectionClosed): self.closed = event elif isinstance(event, quic.QuicStreamDataReceived): - yield quic.SendQuicStreamData(event.connection, event.stream_id, event.data, event.end_stream) + yield quic.SendQuicStreamData( + event.connection, event.stream_id, event.data, event.end_stream + ) elif isinstance(event, quic.QuicStreamReset): - yield quic.ResetQuicStream(event.connection, event.stream_id, event.error_code) + yield quic.ResetQuicStream( + event.connection, event.stream_id, event.error_code + ) else: yield from super()._handle_event(event) @@ -104,7 +136,7 @@ client_hello = bytes.fromhex( def test_error_code_to_str(): assert quic.error_code_to_str(0x6) == "FINAL_SIZE_ERROR" assert quic.error_code_to_str(0x104) == "H3_CLOSED_CRITICAL_STREAM" - assert quic.error_code_to_str(0xdead) == f"unknown error (0xdead)" + assert quic.error_code_to_str(0xDEAD) == f"unknown error (0xdead)" def test_is_success_error_code(): @@ -112,7 +144,7 @@ def test_is_success_error_code(): assert not quic.is_success_error_code(0x6) assert quic.is_success_error_code(0x100) assert not quic.is_success_error_code(0x104) - assert not quic.is_success_error_code(0xdead) + assert not quic.is_success_error_code(0xDEAD) @pytest.mark.parametrize("value", ["s1 s2\n", "s1 s2"]) @@ -133,7 +165,7 @@ class TestParseClientHello: ) with pytest.raises(ValueError, match="not initial"): quic.quic_parse_client_hello( - b'\\s\xd8\xd8\xa5dT\x8bc\xd3\xae\x1c\xb2\x8a7-\x1d\x19j\x85\xb0~\x8c\x80\xa5\x8cY\xac\x0ecK\x7fC2f\xbcm\x1b\xac~' + b"\\s\xd8\xd8\xa5dT\x8bc\xd3\xae\x1c\xb2\x8a7-\x1d\x19j\x85\xb0~\x8c\x80\xa5\x8cY\xac\x0ecK\x7fC2f\xbcm\x1b\xac~" ) def test_invalid(self, monkeypatch): @@ -156,7 +188,9 @@ class TestParseClientHello: def test_no_return(self): with pytest.raises(ValueError, match="No ClientHello"): - quic.quic_parse_client_hello(client_hello[0:1200] + b'\x00' + client_hello[1200:]) + quic.quic_parse_client_hello( + client_hello[0:1200] + b"\x00" + client_hello[1200:] + ) class TestQuicStreamLayer: @@ -247,13 +281,17 @@ class TestRawQuicLayer: << udp.UdpMessageHook(udpflow) >> tutils.reply() << commands.SendData(tctx.server, b"msg2") - >> udp.UdpMessageInjected(UDPFlow(("other", 80), tctx.server), UDPMessage(True, b"msg3")) + >> udp.UdpMessageInjected( + UDPFlow(("other", 80), tctx.server), UDPMessage(True, b"msg3") + ) << udp.UdpMessageHook(udpflow) >> tutils.reply() << commands.SendData(tctx.server, b"msg3") ) with pytest.raises(AssertionError, match="not associated"): - playbook >> udp.UdpMessageInjected(UDPFlow(("notfound", 0), ("noexist", 0)), UDPMessage(True, b"msg2")) + playbook >> udp.UdpMessageInjected( + UDPFlow(("notfound", 0), ("noexist", 0)), UDPMessage(True, b"msg2") + ) assert playbook def test_reset_with_end_hook(self, tctx: context.Context): @@ -316,8 +354,10 @@ class TestRawQuicLayer: >> tutils.reply(None) ) with pytest.raises(AssertionError, match="Unexpected stream event"): + class InvalidStreamEvent(quic.QuicStreamEvent): pass + playbook >> InvalidStreamEvent(tctx.client, 0) assert playbook @@ -329,8 +369,10 @@ class TestRawQuicLayer: >> tutils.reply(None) ) with pytest.raises(AssertionError, match="Unexpected event"): + class InvalidEvent(events.Event): pass + playbook >> InvalidEvent() assert playbook @@ -359,12 +401,16 @@ class TestRawQuicLayer: tutils.Playbook(quic.RawQuicLayer(tctx)) << commands.OpenConnection(tctx.server) >> tutils.reply(None) - >> quic.QuicStreamDataReceived(tctx.client, 0, b"open-connection", end_stream=False) + >> quic.QuicStreamDataReceived( + tctx.client, 0, b"open-connection", end_stream=False + ) << layer.NextLayerHook(tutils.Placeholder()) >> tutils.reply_next_layer(echo_new_server) << commands.OpenConnection(server) >> tutils.reply("uhoh") - << quic.SendQuicStreamData(tctx.client, 0, b"open-connection failed: uhoh", end_stream=False) + << quic.SendQuicStreamData( + tctx.client, 0, b"open-connection failed: uhoh", end_stream=False + ) ) def test_invalid_connection_command(self, tctx: context.Context): @@ -378,8 +424,12 @@ class TestRawQuicLayer: >> tutils.reply_next_layer(TlsEchoLayer) << quic.SendQuicStreamData(tctx.client, 0, b"msg1", end_stream=False) ) - with pytest.raises(AssertionError, match="Unexpected stream connection command"): - playbook >> quic.QuicStreamDataReceived(tctx.client, 0, b"invalid-command", end_stream=False) + with pytest.raises( + AssertionError, match="Unexpected stream connection command" + ): + playbook >> quic.QuicStreamDataReceived( + tctx.client, 0, b"invalid-command", end_stream=False + ) assert playbook @@ -412,8 +462,8 @@ def make_mock_quic( quic_layer.quic = mock quic_layer.tunnel_state = ( tls.tunnel.TunnelState.OPEN - if established else - tls.tunnel.TunnelState.ESTABLISHING + if established + else tls.tunnel.TunnelState.ESTABLISHING ) return tutils.Playbook(quic_layer), mock @@ -423,21 +473,19 @@ class TestQuicLayer: def test_invalid_event(self, tctx: context.Context, established: bool): class InvalidEvent(quic_events.QuicEvent): pass + playbook, conn = make_mock_quic( tctx, event=InvalidEvent(), established=established ) with pytest.raises(AssertionError, match="Unexpected event"): - assert ( - playbook - >> events.DataReceived(tctx.client, b"") - ) + assert playbook >> events.DataReceived(tctx.client, b"") def test_invalid_stream_command(self, tctx: context.Context): playbook, conn = make_mock_quic( tctx, quic_events.DatagramFrameReceived(b"invalid-stream-command") ) with pytest.raises(AssertionError, match="Unexpected stream command"): - assert (playbook >> events.DataReceived(tctx.client, b"")) + assert playbook >> events.DataReceived(tctx.client, b"") def test_close(self, tctx: context.Context): playbook, conn = make_mock_quic( @@ -470,7 +518,7 @@ class TestQuicLayer: tctx, quic_events.DatagramFrameReceived(b"packet") ) assert not conn._datagrams_pending - assert (playbook >> events.DataReceived(tctx.client, b"")) + assert playbook >> events.DataReceived(tctx.client, b"") assert len(conn._datagrams_pending) == 1 assert conn._datagrams_pending[0] == b"packet" @@ -479,15 +527,13 @@ class TestQuicLayer: tctx, quic_events.StreamDataReceived(b"packet", False, 42) ) assert 42 not in conn._streams - assert (playbook >> events.DataReceived(tctx.client, b"")) + assert playbook >> events.DataReceived(tctx.client, b"") assert b"packet" == conn._streams[42].sender._buffer def test_stream_reset(self, tctx: context.Context): - playbook, conn = make_mock_quic( - tctx, quic_events.StreamReset(123, 42) - ) + playbook, conn = make_mock_quic(tctx, quic_events.StreamReset(123, 42)) assert 42 not in conn._streams - assert (playbook >> events.DataReceived(tctx.client, b"")) + assert playbook >> events.DataReceived(tctx.client, b"") assert conn._streams[42].sender.reset_pending assert conn._streams[42].sender._reset_error_code == 123 @@ -497,7 +543,7 @@ class TestQuicLayer: ) assert 24 not in conn._streams conn._get_or_create_stream_for_send(24) - assert (playbook >> events.DataReceived(tctx.client, b"")) + assert playbook >> events.DataReceived(tctx.client, b"") assert conn._streams[24].receiver.stop_pending assert conn._streams[24].receiver._stop_error_code == 123 @@ -521,7 +567,9 @@ class SSLTest: self.ctx.verify_mode = ssl.CERT_OPTIONAL self.ctx.load_verify_locations( - cafile=tlsdata.path("../../net/data/verificationcerts/trusted-root.crt"), + cafile=tlsdata.path( + "../../net/data/verificationcerts/trusted-root.crt" + ), ) if alpn: @@ -613,7 +661,7 @@ def finish_handshake( playbook: tutils.Playbook, conn: connection.Connection, tssl: SSLTest, - child_layer: type[T] + child_layer: type[T], ) -> T: result: Optional[T] = None @@ -644,9 +692,7 @@ def finish_handshake( return result -def reply_tls_start_client( - alpn: Optional[str] = None, *args, **kwargs -) -> tutils.reply: +def reply_tls_start_client(alpn: Optional[str] = None, *args, **kwargs) -> tutils.reply: """ Helper function to simplify the syntax for quic_start_client hooks. """ @@ -658,9 +704,9 @@ def reply_tls_start_client( tlsdata.path("../../net/data/verificationcerts/trusted-leaf.key"), ) tls_start.settings = quic.QuicTlsSettings( - certificate = config.certificate, - certificate_chain = config.certificate_chain, - certificate_private_key = config.private_key, + certificate=config.certificate, + certificate_chain=config.certificate_chain, + certificate_private_key=config.private_key, ) if alpn is not None: tls_start.settings.alpn_protocols = [alpn] @@ -668,9 +714,7 @@ def reply_tls_start_client( return tutils.reply(*args, side_effect=make_client_conn, **kwargs) -def reply_tls_start_server( - alpn: Optional[str] = None, *args, **kwargs -) -> tutils.reply: +def reply_tls_start_server(alpn: Optional[str] = None, *args, **kwargs) -> tutils.reply: """ Helper function to simplify the syntax for quic_start_server hooks. """ @@ -799,7 +843,7 @@ class TestServerTLS: >> events.Wakeup(playbook.actual[9]) << commands.Log( "Server QUIC handshake failed. hostname 'wrong.host.mitmproxy.org' doesn't match 'example.mitmproxy.org'", - WARNING + WARNING, ) << tls.TlsFailedServerHook(tls_hook_data) >> tutils.reply() @@ -810,7 +854,8 @@ class TestServerTLS: ) ) assert ( - tls_hook_data().conn.error == "hostname 'wrong.host.mitmproxy.org' doesn't match 'example.mitmproxy.org'" + tls_hook_data().conn.error + == "hostname 'wrong.host.mitmproxy.org' doesn't match 'example.mitmproxy.org'" ) assert not tctx.server.tls_established @@ -822,7 +867,11 @@ def make_client_tls_layer( # This is a bit contrived as the client layer expects a server layer as parent. # We also set child layers manually to avoid NextLayer noise. - server_layer = DummyLayer(tctx) if no_server else quic.ServerQuicLayer(tctx, time=lambda: tssl_client.now) + server_layer = ( + DummyLayer(tctx) + if no_server + else quic.ServerQuicLayer(tctx, time=lambda: tssl_client.now) + ) client_layer = quic.ClientQuicLayer(tctx, time=lambda: tssl_client.now) server_layer.child_layer = client_layer playbook = tutils.Playbook(server_layer) @@ -881,13 +930,21 @@ class TestClientTLS: assert ( playbook >> events.Wakeup(playbook.actual[16]) - << commands.Log(" >> Wakeup(command=RequestWakeup({'delay': 0.20000000000000004}))", DEBUG) - << commands.Log(" [quic] close_notify Client(client:1234, state=open, tls) (reason=Idle timeout)", DEBUG) + << commands.Log( + " >> Wakeup(command=RequestWakeup({'delay': 0.20000000000000004}))", + DEBUG, + ) + << commands.Log( + " [quic] close_notify Client(client:1234, state=open, tls) (reason=Idle timeout)", + DEBUG, + ) << commands.CloseConnection(tctx.client) ) @pytest.mark.parametrize("server_state", ["open", "closed"]) - def test_server_required(self, tctx: context.Context, server_state: Literal["open", "closed"]): + def test_server_required( + self, tctx: context.Context, server_state: Literal["open", "closed"] + ): """ Test the scenario where a server connection is required (for example, because of an unknown ALPN) to establish TLS with the client. @@ -959,7 +1016,9 @@ class TestClientTLS: _test_echo(playbook, tssl_server, tctx.server) @pytest.mark.parametrize("server_state", ["open", "closed"]) - def test_passthrough_from_clienthello(self, tctx: context.Context, server_state: Literal["open", "closed"]): + def test_passthrough_from_clienthello( + self, tctx: context.Context, server_state: Literal["open", "closed"] + ): """ Test the scenario where the connection is moved to passthrough mode in the tls_clienthello hook. """ @@ -1017,7 +1076,9 @@ class TestClientTLS: client_layer.debug = "" assert ( playbook - >> events.DataReceived(connection.Server(address=None), b"data on other stream") + >> events.DataReceived( + connection.Server(address=None), b"data on other stream" + ) << commands.Log(">> DataReceived(server, b'data on other stream')", DEBUG) << commands.Log( "[quic] Swallowing DataReceived(server, b'data on other stream') as handshake failed.", @@ -1093,12 +1154,16 @@ class TestClientTLS: >> tutils.reply() << commands.Log(f"No QUIC context was provided, failing connection.", ERROR) << commands.CloseConnection(tctx.client) - << commands.Log("Client QUIC handshake failed. connection closed early", WARNING) + << commands.Log( + "Client QUIC handshake failed. connection closed early", WARNING + ) << tls.TlsFailedClientHook(tutils.Placeholder()) ) def test_no_server_tls(self, tctx: context.Context): - playbook, client_layer, tssl_client = make_client_tls_layer(tctx, no_server=True) + playbook, client_layer, tssl_client = make_client_tls_layer( + tctx, no_server=True + ) def require_server_conn(client_hello: tls.ClientHelloData) -> None: client_hello.establish_server_tls_first = True @@ -1129,29 +1194,35 @@ class TestClientTLS: def test_non_init_clienthello(self, tctx: context.Context): playbook, client_layer, tssl_client = make_client_tls_layer(tctx) data = ( - b'\xc2\x00\x00\x00\x01\x08q\xda\x98\x03X-\x13o\x08y\xa5RQv\xbe\xe3\xeb\x00@a\x98\x19\xf95t\xad-\x1c\\a\xdd\x8c\xd0\x15F' - b'\xdf\xdc\x87cb\x1eu\xb0\x95*\xac\xa8\xf7a \xb8\nQ\xbd=\xf5x\xca\r\xe6\x8b\x05 w\x9f\xcd\x8d\xcb\xa0\x06\x1e \x8d.\x8f' - b'T\xda\x12et\xe4\x83\x93X\x8aa\xd1\xb2\x18\xb6\xa7\xf50y\x9b\xc5T\xe1\x87\xdd\x9fqv\xb0\x90\xa7s' - b'\xee\x00\x00\x00\x01\x08q\xda\x98\x03X-\x13o\x08y\xa5RQv\xbe\xe3\xeb@a*.\xa8j\x90\x1b\x1a\x7fZ\x04\x0b\\\xc7\x00\x03' - b'\xd7sC\xf8G\x84\x1e\xba\xcf\x08Z\xdd\x98+\xaa\x98J\xca\xe3\xb7u1\x89\x00\xdf\x8e\x16`\xd9^\xc0@i\x1a\x10\x99\r\xd8' - b'\x1dv3\xc6\xb8"\xb9\xa8F\x95K\x9a/\xbc\'\xd8\xd8\x94\x8f\xe7B/\x05\x9d\xfb\x80\xa9\xda@\xe6\xb0J\xfe\xe0\x0f\x02L}' - b'\xd9\xed\xd2L\xa7\xcf' + b"\xc2\x00\x00\x00\x01\x08q\xda\x98\x03X-\x13o\x08y\xa5RQv\xbe\xe3\xeb\x00@a\x98\x19\xf95t\xad-\x1c\\a\xdd\x8c\xd0\x15F" + b"\xdf\xdc\x87cb\x1eu\xb0\x95*\xac\xa8\xf7a \xb8\nQ\xbd=\xf5x\xca\r\xe6\x8b\x05 w\x9f\xcd\x8d\xcb\xa0\x06\x1e \x8d.\x8f" + b"T\xda\x12et\xe4\x83\x93X\x8aa\xd1\xb2\x18\xb6\xa7\xf50y\x9b\xc5T\xe1\x87\xdd\x9fqv\xb0\x90\xa7s" + b"\xee\x00\x00\x00\x01\x08q\xda\x98\x03X-\x13o\x08y\xa5RQv\xbe\xe3\xeb@a*.\xa8j\x90\x1b\x1a\x7fZ\x04\x0b\\\xc7\x00\x03" + b"\xd7sC\xf8G\x84\x1e\xba\xcf\x08Z\xdd\x98+\xaa\x98J\xca\xe3\xb7u1\x89\x00\xdf\x8e\x16`\xd9^\xc0@i\x1a\x10\x99\r\xd8" + b"\x1dv3\xc6\xb8\"\xb9\xa8F\x95K\x9a/\xbc'\xd8\xd8\x94\x8f\xe7B/\x05\x9d\xfb\x80\xa9\xda@\xe6\xb0J\xfe\xe0\x0f\x02L}" + b"\xd9\xed\xd2L\xa7\xcf" ) assert ( playbook >> events.DataReceived(tctx.client, data) - << commands.Log(f"Client QUIC handshake failed. Invalid handshake received, roaming not supported. ({data.hex()})", WARNING) + << commands.Log( + f"Client QUIC handshake failed. Invalid handshake received, roaming not supported. ({data.hex()})", + WARNING, + ) << tls.TlsFailedClientHook(tutils.Placeholder()) ) assert client_layer.tunnel_state == tls.tunnel.TunnelState.ESTABLISHING def test_invalid_clienthello(self, tctx: context.Context): playbook, client_layer, tssl_client = make_client_tls_layer(tctx) - data = client_hello[0:1200] + b'\x00' + client_hello[1200:] + data = client_hello[0:1200] + b"\x00" + client_hello[1200:] assert ( playbook >> events.DataReceived(tctx.client, data) - << commands.Log(f"Client QUIC handshake failed. Cannot parse ClientHello: No ClientHello returned. ({data.hex()})", WARNING) + << commands.Log( + f"Client QUIC handshake failed. Cannot parse ClientHello: No ClientHello returned. ({data.hex()})", + WARNING, + ) << tls.TlsFailedClientHook(tutils.Placeholder()) ) assert client_layer.tunnel_state == tls.tunnel.TunnelState.ESTABLISHING diff --git a/test/mitmproxy/proxy/layers/test_socks5_fuzz.py b/test/mitmproxy/proxy/layers/test_socks5_fuzz.py index bbefa4b01..e8762fd35 100644 --- a/test/mitmproxy/proxy/layers/test_socks5_fuzz.py +++ b/test/mitmproxy/proxy/layers/test_socks5_fuzz.py @@ -8,11 +8,14 @@ from mitmproxy.proxy.events import DataReceived from mitmproxy.proxy.layers.modes import Socks5Proxy opts = options.Options() -tctx = Context(Client( - peername=("client", 1234), - sockname=("127.0.0.1", 8080), - timestamp_start=1605699329 -), opts) +tctx = Context( + Client( + peername=("client", 1234), + sockname=("127.0.0.1", 8080), + timestamp_start=1605699329, + ), + opts, +) @given(binary()) diff --git a/test/mitmproxy/proxy/layers/test_tcp.py b/test/mitmproxy/proxy/layers/test_tcp.py index df01fa9f9..786462259 100644 --- a/test/mitmproxy/proxy/layers/test_tcp.py +++ b/test/mitmproxy/proxy/layers/test_tcp.py @@ -1,11 +1,18 @@ import pytest -from mitmproxy.proxy.commands import CloseConnection, CloseTcpConnection, OpenConnection, SendData -from mitmproxy.proxy.events import ConnectionClosed, DataReceived +from ..tutils import Placeholder +from ..tutils import Playbook +from ..tutils import reply +from mitmproxy.proxy.commands import CloseConnection +from mitmproxy.proxy.commands import CloseTcpConnection +from mitmproxy.proxy.commands import OpenConnection +from mitmproxy.proxy.commands import SendData +from mitmproxy.proxy.events import ConnectionClosed +from mitmproxy.proxy.events import DataReceived from mitmproxy.proxy.layers import tcp from mitmproxy.proxy.layers.tcp import TcpMessageInjected -from mitmproxy.tcp import TCPFlow, TCPMessage -from ..tutils import Placeholder, Playbook, reply +from mitmproxy.tcp import TCPFlow +from mitmproxy.tcp import TCPMessage def test_open_connection(tctx): diff --git a/test/mitmproxy/proxy/layers/test_tls.py b/test/mitmproxy/proxy/layers/test_tls.py index 7422a5322..1fde306fa 100644 --- a/test/mitmproxy/proxy/layers/test_tls.py +++ b/test/mitmproxy/proxy/layers/test_tls.py @@ -1,20 +1,26 @@ import ssl -from logging import DEBUG, WARNING - import time +from logging import DEBUG +from logging import WARNING from typing import Optional import pytest - from OpenSSL import SSL + from mitmproxy import connection -from mitmproxy.connection import ConnectionState, Server -from mitmproxy.proxy import commands, context, events, layer +from mitmproxy.connection import ConnectionState +from mitmproxy.connection import Server +from mitmproxy.proxy import commands +from mitmproxy.proxy import context +from mitmproxy.proxy import events +from mitmproxy.proxy import layer from mitmproxy.proxy.layers import tls -from mitmproxy.tls import ClientHelloData, TlsData +from mitmproxy.tls import ClientHelloData +from mitmproxy.tls import TlsData from mitmproxy.utils import data from test.mitmproxy.proxy import tutils -from test.mitmproxy.proxy.tutils import BytesMatching, StrMatching +from test.mitmproxy.proxy.tutils import BytesMatching +from test.mitmproxy.proxy.tutils import StrMatching tlsdata = data.Data(__name__) @@ -136,7 +142,7 @@ class SSLTest: def bio_write(self, buf: bytes) -> int: return self.inc.write(buf) - def bio_read(self, bufsize: int = 2 ** 16) -> bytes: + def bio_read(self, bufsize: int = 2**16) -> bytes: return self.out.read(bufsize) def do_handshake(self) -> None: @@ -365,7 +371,9 @@ class TestServerTLS: >> events.DataReceived(tctx.server, tssl.bio_read()) << commands.Log( # different casing in OpenSSL < 3.0 - StrMatching("Server TLS handshake failed. Certificate verify failed: [Hh]ostname mismatch"), + StrMatching( + "Server TLS handshake failed. Certificate verify failed: [Hh]ostname mismatch" + ), WARNING, ) << tls.TlsFailedServerHook(tls_hook_data) @@ -374,11 +382,14 @@ class TestServerTLS: << commands.SendData( tctx.client, # different casing in OpenSSL < 3.0 - BytesMatching(b"open-connection failed: Certificate verify failed: [Hh]ostname mismatch"), + BytesMatching( + b"open-connection failed: Certificate verify failed: [Hh]ostname mismatch" + ), ) ) assert ( - tls_hook_data().conn.error.lower() == "Certificate verify failed: Hostname mismatch".lower() + tls_hook_data().conn.error.lower() + == "Certificate verify failed: Hostname mismatch".lower() ) assert not tctx.server.tls_established @@ -780,7 +791,9 @@ def test_is_dtls_handshake_record(): def test_dtls_record_contents(): - data = bytes.fromhex("16fefd00000000000000000002beef" "16fefd00000000000000000001ff") + data = bytes.fromhex( + "16fefd00000000000000000002beef" "16fefd00000000000000000001ff" + ) assert list(tls.dtls_handshake_record_contents(data)) == [b"\xbe\xef", b"\xff"] for i in range(12): assert list(tls.dtls_handshake_record_contents(data[:i])) == [] @@ -800,8 +813,8 @@ dtls_client_hello_no_extensions = bytes.fromhex( "cc02bc02fc00ac014c02cc03001000000" ) dtls_client_hello_with_extensions = bytes.fromhex( - "16fefd00000000000000000085" # record layer - "010000790000000000000079" # hanshake layer + "16fefd00000000000000000085" # record layer + "010000790000000000000079" # hanshake layer "fefd62bf0e0bf809df43e7669197be831919878b1a72c07a584d3c0a8ca6665878010000000cc02bc02fc00ac014c02cc0" "3001000043000d0010000e0403050306030401050106010807ff01000100000a00080006001d00170018000b00020100001" "7000000000010000e00000b6578616d706c652e636f6d" @@ -809,26 +822,35 @@ dtls_client_hello_with_extensions = bytes.fromhex( def test_dtls_get_client_hello(): - single_record = bytes.fromhex("16fefd00000000000000000042") + dtls_client_hello_no_extensions + single_record = ( + bytes.fromhex("16fefd00000000000000000042") + dtls_client_hello_no_extensions + ) assert tls.get_dtls_client_hello(single_record) == dtls_client_hello_no_extensions split_over_two_records = ( - bytes.fromhex("16fefd00000000000000000020") - + dtls_client_hello_no_extensions[:32] - + bytes.fromhex("16fefd00000000000000000022") - + dtls_client_hello_no_extensions[32:] + bytes.fromhex("16fefd00000000000000000020") + + dtls_client_hello_no_extensions[:32] + + bytes.fromhex("16fefd00000000000000000022") + + dtls_client_hello_no_extensions[32:] + ) + assert ( + tls.get_dtls_client_hello(split_over_two_records) + == dtls_client_hello_no_extensions ) - assert tls.get_dtls_client_hello(split_over_two_records) == dtls_client_hello_no_extensions incomplete = split_over_two_records[:42] assert tls.get_dtls_client_hello(incomplete) is None def test_dtls_parse_client_hello(): - assert tls.dtls_parse_client_hello(dtls_client_hello_with_extensions).sni == "example.com" + assert ( + tls.dtls_parse_client_hello(dtls_client_hello_with_extensions).sni + == "example.com" + ) assert tls.dtls_parse_client_hello(dtls_client_hello_with_extensions[:50]) is None with pytest.raises(ValueError): tls.dtls_parse_client_hello( # Server Name Length longer than actual Server Name - dtls_client_hello_with_extensions[:-16] + b"\x00\x0e\x00\x00\x20\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + dtls_client_hello_with_extensions[:-16] + + b"\x00\x0e\x00\x00\x20\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" ) diff --git a/test/mitmproxy/proxy/layers/test_tls_fuzz.py b/test/mitmproxy/proxy/layers/test_tls_fuzz.py index 402c08a5f..e15740989 100644 --- a/test/mitmproxy/proxy/layers/test_tls_fuzz.py +++ b/test/mitmproxy/proxy/layers/test_tls_fuzz.py @@ -1,8 +1,10 @@ -from hypothesis import given, example -from hypothesis.strategies import binary, integers +from hypothesis import example +from hypothesis import given +from hypothesis.strategies import binary +from hypothesis.strategies import integers -from mitmproxy.tls import ClientHello from mitmproxy.proxy.layers.tls import parse_client_hello +from mitmproxy.tls import ClientHello client_hello_with_extensions = bytes.fromhex( "16030300bb" # record layer diff --git a/test/mitmproxy/proxy/layers/test_udp.py b/test/mitmproxy/proxy/layers/test_udp.py index 14b344a57..9b8d3b419 100644 --- a/test/mitmproxy/proxy/layers/test_udp.py +++ b/test/mitmproxy/proxy/layers/test_udp.py @@ -1,11 +1,17 @@ import pytest -from mitmproxy.proxy.commands import CloseConnection, OpenConnection, SendData -from mitmproxy.proxy.events import ConnectionClosed, DataReceived +from ..tutils import Placeholder +from ..tutils import Playbook +from ..tutils import reply +from mitmproxy.proxy.commands import CloseConnection +from mitmproxy.proxy.commands import OpenConnection +from mitmproxy.proxy.commands import SendData +from mitmproxy.proxy.events import ConnectionClosed +from mitmproxy.proxy.events import DataReceived from mitmproxy.proxy.layers import udp from mitmproxy.proxy.layers.udp import UdpMessageInjected -from mitmproxy.udp import UDPFlow, UDPMessage -from ..tutils import Placeholder, Playbook, reply +from mitmproxy.udp import UDPFlow +from mitmproxy.udp import UDPMessage def test_open_connection(tctx): diff --git a/test/mitmproxy/proxy/layers/test_websocket.py b/test/mitmproxy/proxy/layers/test_websocket.py index a1d96133f..eebca259d 100644 --- a/test/mitmproxy/proxy/layers/test_websocket.py +++ b/test/mitmproxy/proxy/layers/test_websocket.py @@ -2,20 +2,28 @@ import secrets from dataclasses import dataclass import pytest - -import wsproto import wsproto.events -from mitmproxy.http import HTTPFlow, Request, Response -from mitmproxy.proxy.layers.http import HTTPMode -from mitmproxy.proxy.commands import SendData, CloseConnection, Log -from mitmproxy.connection import ConnectionState -from mitmproxy.proxy.events import DataReceived, ConnectionClosed -from mitmproxy.proxy.layers import http, websocket -from mitmproxy.proxy.layers.websocket import WebSocketMessageInjected -from mitmproxy.websocket import WebSocketData, WebSocketMessage -from test.mitmproxy.proxy.tutils import Placeholder, Playbook, reply from wsproto.frame_protocol import Opcode +from mitmproxy.connection import ConnectionState +from mitmproxy.http import HTTPFlow +from mitmproxy.http import Request +from mitmproxy.http import Response +from mitmproxy.proxy.commands import CloseConnection +from mitmproxy.proxy.commands import Log +from mitmproxy.proxy.commands import SendData +from mitmproxy.proxy.events import ConnectionClosed +from mitmproxy.proxy.events import DataReceived +from mitmproxy.proxy.layers import http +from mitmproxy.proxy.layers import websocket +from mitmproxy.proxy.layers.http import HTTPMode +from mitmproxy.proxy.layers.websocket import WebSocketMessageInjected +from mitmproxy.websocket import WebSocketData +from mitmproxy.websocket import WebSocketMessage +from test.mitmproxy.proxy.tutils import Placeholder +from test.mitmproxy.proxy.tutils import Playbook +from test.mitmproxy.proxy.tutils import reply + @dataclass class _Masked: diff --git a/test/mitmproxy/proxy/test_context.py b/test/mitmproxy/proxy/test_context.py index 62e9c9b7a..1ac2bd833 100644 --- a/test/mitmproxy/proxy/test_context.py +++ b/test/mitmproxy/proxy/test_context.py @@ -1,5 +1,6 @@ from mitmproxy.proxy import context -from mitmproxy.test import tflow, taddons +from mitmproxy.test import taddons +from mitmproxy.test import tflow def test_context(): diff --git a/test/mitmproxy/proxy/test_events.py b/test/mitmproxy/proxy/test_events.py index c415fadad..5e867369e 100644 --- a/test/mitmproxy/proxy/test_events.py +++ b/test/mitmproxy/proxy/test_events.py @@ -3,7 +3,8 @@ from unittest.mock import Mock import pytest from mitmproxy import connection -from mitmproxy.proxy import events, commands +from mitmproxy.proxy import commands +from mitmproxy.proxy import events @pytest.fixture diff --git a/test/mitmproxy/proxy/test_layer.py b/test/mitmproxy/proxy/test_layer.py index 1d4baef7e..b646c2e02 100644 --- a/test/mitmproxy/proxy/test_layer.py +++ b/test/mitmproxy/proxy/test_layer.py @@ -2,7 +2,9 @@ from logging import DEBUG import pytest -from mitmproxy.proxy import commands, events, layer +from mitmproxy.proxy import commands +from mitmproxy.proxy import events +from mitmproxy.proxy import layer from mitmproxy.proxy.context import Context from test.mitmproxy.proxy import tutils diff --git a/test/mitmproxy/proxy/test_mode_servers.py b/test/mitmproxy/proxy/test_mode_servers.py index eea437ebb..917645fc8 100644 --- a/test/mitmproxy/proxy/test_mode_servers.py +++ b/test/mitmproxy/proxy/test_mode_servers.py @@ -1,14 +1,18 @@ import asyncio import platform from typing import cast -from unittest.mock import AsyncMock, MagicMock, Mock +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import Mock import pytest import mitmproxy.platform from mitmproxy.addons.proxyserver import Proxyserver from mitmproxy.net import udp -from mitmproxy.proxy.mode_servers import DnsInstance, ServerInstance, WireGuardServerInstance +from mitmproxy.proxy.mode_servers import DnsInstance +from mitmproxy.proxy.mode_servers import ServerInstance +from mitmproxy.proxy.mode_servers import WireGuardServerInstance from mitmproxy.proxy.server import ConnectionHandler from mitmproxy.test import taddons @@ -18,14 +22,23 @@ def test_make(): context = MagicMock() assert ServerInstance.make("regular", manager) - for mode in ["regular", "http3", "upstream:example.com", "transparent", "reverse:example.com", "socks5"]: + for mode in [ + "regular", + "http3", + "upstream:example.com", + "transparent", + "reverse:example.com", + "socks5", + ]: inst = ServerInstance.make(mode, manager) assert inst assert inst.make_top_layer(context) assert inst.mode.description assert inst.to_json() - with pytest.raises(ValueError, match="is not a spec for a WireGuardServerInstance server."): + with pytest.raises( + ValueError, match="is not a spec for a WireGuardServerInstance server." + ): WireGuardServerInstance.make("regular", manager) @@ -86,7 +99,9 @@ async def test_transparent(failure, monkeypatch, caplog_async): if failure: monkeypatch.setattr(mitmproxy.platform, "original_addr", None) else: - monkeypatch.setattr(mitmproxy.platform, "original_addr", lambda s: ("address", 42)) + monkeypatch.setattr( + mitmproxy.platform, "original_addr", lambda s: ("address", 42) + ) with taddons.context(Proxyserver()) as tctx: tctx.options.connection_strategy = "lazy" @@ -199,12 +214,16 @@ async def test_wireguard_invalid_conf(tmp_path): async def test_tcp_start_error(): manager = MagicMock() - server = await asyncio.start_server(MagicMock(), host="127.0.0.1", port=0, reuse_address=False) + server = await asyncio.start_server( + MagicMock(), host="127.0.0.1", port=0, reuse_address=False + ) port = server.sockets[0].getsockname()[1] with taddons.context() as tctx: inst = ServerInstance.make(f"regular@127.0.0.1:{port}", manager) - with pytest.raises(OSError, match=f"proxy failed to listen on 127\\.0\\.0\\.1:{port}"): + with pytest.raises( + OSError, match=f"proxy failed to listen on 127\\.0\\.0\\.1:{port}" + ): await inst.start() tctx.options.listen_host = "127.0.0.1" tctx.options.listen_port = port @@ -253,7 +272,9 @@ async def test_udp_start_error(): await inst.start() port = inst.listen_addrs[0][1] inst2 = ServerInstance.make(f"dns@127.0.0.1:{port}", manager) - with pytest.raises(OSError, match=f"server failed to listen on 127\\.0\\.0\\.1:{port}"): + with pytest.raises( + OSError, match=f"server failed to listen on 127\\.0\\.0\\.1:{port}" + ): await inst2.start() await inst.stop() @@ -267,8 +288,12 @@ async def test_udp_connection_reuse(monkeypatch): with taddons.context(): inst = cast(DnsInstance, ServerInstance.make("dns", manager)) - inst.handle_udp_datagram(MagicMock(), b"\x00\x00\x01", ("remoteaddr", 0), ("localaddr", 0)) - inst.handle_udp_datagram(MagicMock(), b"\x00\x00\x02", ("remoteaddr", 0), ("localaddr", 0)) + inst.handle_udp_datagram( + MagicMock(), b"\x00\x00\x01", ("remoteaddr", 0), ("localaddr", 0) + ) + inst.handle_udp_datagram( + MagicMock(), b"\x00\x00\x02", ("remoteaddr", 0), ("localaddr", 0) + ) await asyncio.sleep(0) assert len(inst.manager.connections) == 1 diff --git a/test/mitmproxy/proxy/test_mode_specs.py b/test/mitmproxy/proxy/test_mode_specs.py index be83e5238..b52c72196 100644 --- a/test/mitmproxy/proxy/test_mode_specs.py +++ b/test/mitmproxy/proxy/test_mode_specs.py @@ -2,7 +2,8 @@ import dataclasses import pytest -from mitmproxy.proxy.mode_specs import ProxyMode, Socks5Mode +from mitmproxy.proxy.mode_specs import ProxyMode +from mitmproxy.proxy.mode_specs import Socks5Mode def test_parse(): @@ -45,7 +46,10 @@ def test_listen_addr(): assert ProxyMode.parse("regular").listen_host() == "" assert ProxyMode.parse("regular@127.0.0.2:8080").listen_host() == "127.0.0.2" assert ProxyMode.parse("regular").listen_host(default="127.0.0.3") == "127.0.0.3" - assert ProxyMode.parse("regular@127.0.0.2:8080").listen_host(default="127.0.0.3") == "127.0.0.2" + assert ( + ProxyMode.parse("regular@127.0.0.2:8080").listen_host(default="127.0.0.3") + == "127.0.0.2" + ) assert ProxyMode.parse("reverse:https://1.2.3.4").listen_port() == 8080 assert ProxyMode.parse("reverse:dns://8.8.8.8").listen_port() == 53 diff --git a/test/mitmproxy/proxy/test_tunnel.py b/test/mitmproxy/proxy/test_tunnel.py index ad9af4112..0256390cf 100644 --- a/test/mitmproxy/proxy/test_tunnel.py +++ b/test/mitmproxy/proxy/test_tunnel.py @@ -2,12 +2,22 @@ from typing import Optional import pytest -from mitmproxy.proxy import tunnel, layer -from mitmproxy.proxy.commands import CloseTcpConnection, SendData, Log, CloseConnection, OpenConnection -from mitmproxy.connection import Server, ConnectionState +from mitmproxy.connection import ConnectionState +from mitmproxy.connection import Server +from mitmproxy.proxy import layer +from mitmproxy.proxy import tunnel +from mitmproxy.proxy.commands import CloseConnection +from mitmproxy.proxy.commands import CloseTcpConnection +from mitmproxy.proxy.commands import Log +from mitmproxy.proxy.commands import OpenConnection +from mitmproxy.proxy.commands import SendData from mitmproxy.proxy.context import Context -from mitmproxy.proxy.events import Event, DataReceived, Start, ConnectionClosed -from test.mitmproxy.proxy.tutils import Playbook, reply +from mitmproxy.proxy.events import ConnectionClosed +from mitmproxy.proxy.events import DataReceived +from mitmproxy.proxy.events import Event +from mitmproxy.proxy.events import Start +from test.mitmproxy.proxy.tutils import Playbook +from test.mitmproxy.proxy.tutils import reply class TChildLayer(layer.Layer): diff --git a/test/mitmproxy/proxy/test_tutils.py b/test/mitmproxy/proxy/test_tutils.py index 04880990e..ec676405b 100644 --- a/test/mitmproxy/proxy/test_tutils.py +++ b/test/mitmproxy/proxy/test_tutils.py @@ -4,8 +4,10 @@ from typing import Any import pytest -from mitmproxy.proxy import commands, events, layer from . import tutils +from mitmproxy.proxy import commands +from mitmproxy.proxy import events +from mitmproxy.proxy import layer class TEvent(events.Event): diff --git a/test/mitmproxy/proxy/tutils.py b/test/mitmproxy/proxy/tutils.py index 087db7131..618f6da0b 100644 --- a/test/mitmproxy/proxy/tutils.py +++ b/test/mitmproxy/proxy/tutils.py @@ -1,17 +1,24 @@ import collections.abc import difflib -import logging - import itertools +import logging import re import textwrap import traceback -from collections.abc import Callable, Iterable -from typing import Any, AnyStr, Generic, Optional, TypeVar, Union +from collections.abc import Callable +from collections.abc import Iterable +from typing import Any +from typing import AnyStr +from typing import Generic +from typing import Optional +from typing import TypeVar +from typing import Union -from mitmproxy.proxy import commands, context, layer -from mitmproxy.proxy import events from mitmproxy.connection import ConnectionState +from mitmproxy.proxy import commands +from mitmproxy.proxy import context +from mitmproxy.proxy import events +from mitmproxy.proxy import layer from mitmproxy.proxy.events import command_reply_subclasses from mitmproxy.proxy.layer import Layer diff --git a/test/mitmproxy/script/test_concurrent.py b/test/mitmproxy/script/test_concurrent.py index 4bc5f1582..446fbb8d9 100644 --- a/test/mitmproxy/script/test_concurrent.py +++ b/test/mitmproxy/script/test_concurrent.py @@ -4,8 +4,8 @@ import time import pytest -from mitmproxy.test import tflow from mitmproxy.test import taddons +from mitmproxy.test import tflow class TestConcurrent: diff --git a/test/mitmproxy/test_addonmanager.py b/test/mitmproxy/test_addonmanager.py index 77057d11b..4122f377c 100644 --- a/test/mitmproxy/test_addonmanager.py +++ b/test/mitmproxy/test_addonmanager.py @@ -7,7 +7,8 @@ from mitmproxy import exceptions from mitmproxy import hooks from mitmproxy import master from mitmproxy import options -from mitmproxy.proxy.layers.http import HttpRequestHook, HttpResponseHook +from mitmproxy.proxy.layers.http import HttpRequestHook +from mitmproxy.proxy.layers.http import HttpResponseHook from mitmproxy.test import taddons from mitmproxy.test import tflow diff --git a/test/mitmproxy/test_certs.py b/test/mitmproxy/test_certs.py index 448efdfea..815a84c61 100644 --- a/test/mitmproxy/test_certs.py +++ b/test/mitmproxy/test_certs.py @@ -1,13 +1,14 @@ import os -from datetime import datetime, timezone +from datetime import datetime +from datetime import timezone from pathlib import Path + +import pytest from cryptography import x509 from cryptography.x509 import NameOID -import pytest - -from mitmproxy import certs from ..conftest import skip_windows +from mitmproxy import certs # class TestDNTree: diff --git a/test/mitmproxy/test_command_lexer.py b/test/mitmproxy/test_command_lexer.py index dfe9b2719..dd7be31dd 100644 --- a/test/mitmproxy/test_command_lexer.py +++ b/test/mitmproxy/test_command_lexer.py @@ -1,6 +1,7 @@ import pyparsing import pytest -from hypothesis import given, example +from hypothesis import example +from hypothesis import given from hypothesis.strategies import text from mitmproxy import command_lexer diff --git a/test/mitmproxy/test_connection.py b/test/mitmproxy/test_connection.py index 300015c76..eee660785 100644 --- a/test/mitmproxy/test_connection.py +++ b/test/mitmproxy/test_connection.py @@ -1,14 +1,20 @@ import pytest -from mitmproxy.connection import Server, Client, ConnectionState -from mitmproxy.test.tflow import tclient_conn, tserver_conn +from mitmproxy.connection import Client +from mitmproxy.connection import ConnectionState +from mitmproxy.connection import Server +from mitmproxy.test.tflow import tclient_conn +from mitmproxy.test.tflow import tserver_conn class TestConnection: def test_basic(self): - c = Client(peername=("127.0.0.1", 52314), sockname=("127.0.0.1", 8080), - timestamp_start=1607780791, - state=ConnectionState.OPEN) + c = Client( + peername=("127.0.0.1", 52314), + sockname=("127.0.0.1", 8080), + timestamp_start=1607780791, + state=ConnectionState.OPEN, + ) assert not c.tls_established c.timestamp_tls_setup = 1607780792 assert c.tls_established @@ -34,7 +40,7 @@ class TestClient: peername=("127.0.0.1", 52314), sockname=("127.0.0.1", 8080), timestamp_start=1607780791, - cipher_list=["foo", "bar"] + cipher_list=["foo", "bar"], ) assert repr(c) assert str(c) diff --git a/test/mitmproxy/test_dns.py b/test/mitmproxy/test_dns.py index 6a5075c91..cf9af2ac3 100644 --- a/test/mitmproxy/test_dns.py +++ b/test/mitmproxy/test_dns.py @@ -1,5 +1,6 @@ import ipaddress import struct + import pytest from mitmproxy import dns @@ -111,7 +112,7 @@ class TestMessage: with pytest.raises(ValueError): req.packed - test("id", 0, 2 ** 16 - 1) + test("id", 0, 2**16 - 1) test("reserved", 0, 7) test("op_code", 0, 0b1111) test("response_code", 0, 0b1111) diff --git a/test/mitmproxy/test_flow.py b/test/mitmproxy/test_flow.py index a44408c25..29f2eeb93 100644 --- a/test/mitmproxy/test_flow.py +++ b/test/mitmproxy/test_flow.py @@ -8,8 +8,10 @@ from mitmproxy import flowfilter from mitmproxy import options from mitmproxy.exceptions import FlowReadException from mitmproxy.io import tnetstring -from mitmproxy.proxy import server_hooks, layers -from mitmproxy.test import taddons, tflow +from mitmproxy.proxy import layers +from mitmproxy.proxy import server_hooks +from mitmproxy.test import taddons +from mitmproxy.test import tflow class State: diff --git a/test/mitmproxy/test_flowfilter.py b/test/mitmproxy/test_flowfilter.py index cf6388c1e..b494c9333 100644 --- a/test/mitmproxy/test_flowfilter.py +++ b/test/mitmproxy/test_flowfilter.py @@ -1,8 +1,11 @@ import io -import pytest from unittest.mock import patch + +import pytest + +from mitmproxy import flowfilter +from mitmproxy import http from mitmproxy.test import tflow -from mitmproxy import flowfilter, http class TestParsing: diff --git a/test/mitmproxy/test_http.py b/test/mitmproxy/test_http.py index 169aafd71..acdff67d9 100644 --- a/test/mitmproxy/test_http.py +++ b/test/mitmproxy/test_http.py @@ -1,17 +1,21 @@ import asyncio import email -import time import json +import time from unittest import mock import pytest from mitmproxy import flow from mitmproxy import flowfilter -from mitmproxy.http import Headers, Request, Response, HTTPFlow +from mitmproxy.http import Headers +from mitmproxy.http import HTTPFlow +from mitmproxy.http import Request +from mitmproxy.http import Response from mitmproxy.net.http.cookies import CookieAttrs from mitmproxy.test.tflow import tflow -from mitmproxy.test.tutils import treq, tresp +from mitmproxy.test.tutils import treq +from mitmproxy.test.tutils import tresp class TestRequest: diff --git a/test/mitmproxy/test_optmanager.py b/test/mitmproxy/test_optmanager.py index 391a7d0b7..62df7d0b5 100644 --- a/test/mitmproxy/test_optmanager.py +++ b/test/mitmproxy/test_optmanager.py @@ -1,14 +1,14 @@ +import argparse import copy import io from collections.abc import Sequence from typing import Optional import pytest -import argparse +from mitmproxy import exceptions from mitmproxy import options from mitmproxy import optmanager -from mitmproxy import exceptions class TO(optmanager.OptManager): diff --git a/test/mitmproxy/test_tcp.py b/test/mitmproxy/test_tcp.py index 13001c862..3fdde8b23 100644 --- a/test/mitmproxy/test_tcp.py +++ b/test/mitmproxy/test_tcp.py @@ -1,7 +1,7 @@ import pytest -from mitmproxy import tcp from mitmproxy import flowfilter +from mitmproxy import tcp from mitmproxy.test import tflow diff --git a/test/mitmproxy/test_tls.py b/test/mitmproxy/test_tls.py index af23bc3ab..b04491b81 100644 --- a/test/mitmproxy/test_tls.py +++ b/test/mitmproxy/test_tls.py @@ -103,10 +103,11 @@ class TestDTLSClientHello: assert c.cipher_suites == [2, 3, 10, 5, 4, 9] assert c.alpn_protocols == [b"h2", b"http/1.1"] assert c.extensions == [ - (13, b'\x00\x0e\x04\x03\x05\x03\x06\x03\x04\x01\x05\x01\x06\x01\x08\x07'), - (65281, b'\x00'), - (10, b'\x00\x06\x00\x1d\x00\x17\x00\x18'), - (11, b'\x01\x00'), (23, b''), - (0, b'\x00\x0e\x00\x00\x0bexample.com'), - (16, b'\x00\x0c\x02h2\x08http/1.1') + (13, b"\x00\x0e\x04\x03\x05\x03\x06\x03\x04\x01\x05\x01\x06\x01\x08\x07"), + (65281, b"\x00"), + (10, b"\x00\x06\x00\x1d\x00\x17\x00\x18"), + (11, b"\x01\x00"), + (23, b""), + (0, b"\x00\x0e\x00\x00\x0bexample.com"), + (16, b"\x00\x0c\x02h2\x08http/1.1"), ] diff --git a/test/mitmproxy/test_types.py b/test/mitmproxy/test_types.py index 29d2b1f03..31a33b429 100644 --- a/test/mitmproxy/test_types.py +++ b/test/mitmproxy/test_types.py @@ -1,17 +1,16 @@ +import contextlib +import os from collections.abc import Sequence import pytest -import os -import contextlib import mitmproxy.exceptions import mitmproxy.types -from mitmproxy.test import taddons -from mitmproxy.test import tflow +from . import test_command from mitmproxy import command from mitmproxy import flow - -from . import test_command +from mitmproxy.test import taddons +from mitmproxy.test import tflow @contextlib.contextmanager diff --git a/test/mitmproxy/test_udp.py b/test/mitmproxy/test_udp.py index 2a6a8dd12..ba652f74f 100644 --- a/test/mitmproxy/test_udp.py +++ b/test/mitmproxy/test_udp.py @@ -1,7 +1,7 @@ import pytest -from mitmproxy import udp from mitmproxy import flowfilter +from mitmproxy import udp from mitmproxy.test import tflow diff --git a/test/mitmproxy/test_websocket.py b/test/mitmproxy/test_websocket.py index f227d0dc5..08117b0fb 100644 --- a/test/mitmproxy/test_websocket.py +++ b/test/mitmproxy/test_websocket.py @@ -1,9 +1,9 @@ import pytest +from wsproto.frame_protocol import Opcode from mitmproxy import http from mitmproxy import websocket from mitmproxy.test import tflow -from wsproto.frame_protocol import Opcode class TestWebSocketData: diff --git a/test/mitmproxy/tools/console/test_contentview.py b/test/mitmproxy/tools/console/test_contentview.py index 9819ea9a2..ee0b72757 100644 --- a/test/mitmproxy/tools/console/test_contentview.py +++ b/test/mitmproxy/tools/console/test_contentview.py @@ -1,6 +1,6 @@ -from mitmproxy.test import tflow from mitmproxy import contentviews from mitmproxy.contentviews.base import format_text +from mitmproxy.test import tflow class TContentView(contentviews.View): diff --git a/test/mitmproxy/tools/console/test_keymap.py b/test/mitmproxy/tools/console/test_keymap.py index fb73471e3..624412ab2 100644 --- a/test/mitmproxy/tools/console/test_keymap.py +++ b/test/mitmproxy/tools/console/test_keymap.py @@ -1,8 +1,10 @@ -from mitmproxy.tools.console import keymap -from mitmproxy.test import taddons from unittest import mock + import pytest +from mitmproxy.test import taddons +from mitmproxy.tools.console import keymap + def test_binding(): b = keymap.Binding("space", "cmd", ["options"], "") diff --git a/test/mitmproxy/tools/console/test_quickhelp.py b/test/mitmproxy/tools/console/test_quickhelp.py index 958bf6ae4..722af3dab 100644 --- a/test/mitmproxy/tools/console/test_quickhelp.py +++ b/test/mitmproxy/tools/console/test_quickhelp.py @@ -1,7 +1,8 @@ import pytest from mitmproxy.test.tflow import tflow -from mitmproxy.tools.console import defaultkeys, quickhelp +from mitmproxy.tools.console import defaultkeys +from mitmproxy.tools.console import quickhelp from mitmproxy.tools.console.eventlog import EventLog from mitmproxy.tools.console.flowlist import FlowListBox from mitmproxy.tools.console.flowview import FlowView @@ -38,7 +39,7 @@ tflow2.marked = "x" (EventLog, None, True), (PathEditor, None, False), (SimpleOverlay, None, False), - ] + ], ) def test_quickhelp(widget, flow, keymap, is_root_widget): qh = quickhelp.make(widget, flow, is_root_widget) diff --git a/test/mitmproxy/tools/web/test_app.py b/test/mitmproxy/tools/web/test_app.py index 9a8ece62a..d4b29906b 100644 --- a/test/mitmproxy/tools/web/test_app.py +++ b/test/mitmproxy/tools/web/test_app.py @@ -1,5 +1,5 @@ -import io import gzip +import io import json import logging import textwrap @@ -14,7 +14,10 @@ import tornado.testing from tornado import httpclient from tornado import websocket -from mitmproxy import certs, log, options, optmanager +from mitmproxy import certs +from mitmproxy import log +from mitmproxy import options +from mitmproxy import optmanager from mitmproxy.http import Headers from mitmproxy.proxy.mode_servers import ServerInstance from mitmproxy.test import tflow @@ -159,7 +162,9 @@ class TestApp(tornado.testing.AsyncHTTPTestCase): o = options.Options(http2=False) return webmaster.WebMaster(o, with_termlog=False) - m: webmaster.WebMaster = self.io_loop.asyncio_loop.run_until_complete(make_master()) + m: webmaster.WebMaster = self.io_loop.asyncio_loop.run_until_complete( + make_master() + ) f = tflow.tflow(resp=True) f.id = "42" f.request.content = b"foo\nbar" @@ -177,11 +182,13 @@ class TestApp(tornado.testing.AsyncHTTPTestCase): si2 = ServerInstance.make("reverse:example.com", m.proxyserver) si2.last_exception = RuntimeError("I failed somehow.") si3 = ServerInstance.make("socks5", m.proxyserver) - m.proxyserver.servers._instances.update({ - si1.mode: si1, - si2.mode: si2, - si3.mode: si3, - }) + m.proxyserver.servers._instances.update( + { + si1.mode: si1, + si2.mode: si2, + si3.mode: si3, + } + ) self.master = m self.view = m.view self.events = m.events @@ -497,11 +504,14 @@ class TestApp(tornado.testing.AsyncHTTPTestCase): "export function TBackendState(): Required {\n" " return %s\n" "}\n" - % textwrap.indent(json.dumps(data, indent=4, sort_keys=True), " ").lstrip() + % textwrap.indent( + json.dumps(data, indent=4, sort_keys=True), " " + ).lstrip() ) ( - Path(__file__).parent / "../../../../web/src/js/__tests__/ducks/_tbackendstate.ts" + Path(__file__).parent + / "../../../../web/src/js/__tests__/ducks/_tbackendstate.ts" ).write_bytes(content.encode()) def test_err(self): diff --git a/test/mitmproxy/tools/web/test_master.py b/test/mitmproxy/tools/web/test_master.py index c193ea87b..14b104694 100644 --- a/test/mitmproxy/tools/web/test_master.py +++ b/test/mitmproxy/tools/web/test_master.py @@ -2,12 +2,15 @@ import asyncio from unittest.mock import MagicMock import pytest + from mitmproxy.options import Options from mitmproxy.tools.web.master import WebMaster async def test_reuse(): - server = await asyncio.start_server(MagicMock(), host="127.0.0.1", port=0, reuse_address=False) + server = await asyncio.start_server( + MagicMock(), host="127.0.0.1", port=0, reuse_address=False + ) port = server.sockets[0].getsockname()[1] master = WebMaster(Options(), with_termlog=False) master.options.web_host = "127.0.0.1" diff --git a/test/mitmproxy/tools/web/test_static_viewer.py b/test/mitmproxy/tools/web/test_static_viewer.py index 4364e2557..74473a18f 100644 --- a/test/mitmproxy/tools/web/test_static_viewer.py +++ b/test/mitmproxy/tools/web/test_static_viewer.py @@ -1,14 +1,13 @@ import json from unittest import mock +from mitmproxy import flowfilter +from mitmproxy.addons import readfile +from mitmproxy.addons import save from mitmproxy.test import taddons from mitmproxy.test import tflow - -from mitmproxy import flowfilter -from mitmproxy.tools.web.app import flow_to_json - from mitmproxy.tools.web import static_viewer -from mitmproxy.addons import save, readfile +from mitmproxy.tools.web.app import flow_to_json def test_save_static(tmpdir): diff --git a/test/mitmproxy/utils/test_arg_check.py b/test/mitmproxy/utils/test_arg_check.py index 97102f49b..d498d1c43 100644 --- a/test/mitmproxy/utils/test_arg_check.py +++ b/test/mitmproxy/utils/test_arg_check.py @@ -1,5 +1,5 @@ -import io import contextlib +import io from unittest import mock import pytest diff --git a/test/mitmproxy/utils/test_data.py b/test/mitmproxy/utils/test_data.py index f40fc8665..4e7c7af2a 100644 --- a/test/mitmproxy/utils/test_data.py +++ b/test/mitmproxy/utils/test_data.py @@ -1,4 +1,5 @@ import pytest + from mitmproxy.utils import data diff --git a/test/mitmproxy/utils/test_debug.py b/test/mitmproxy/utils/test_debug.py index a61bff868..6384a5981 100644 --- a/test/mitmproxy/utils/test_debug.py +++ b/test/mitmproxy/utils/test_debug.py @@ -1,6 +1,7 @@ import io import sys from unittest import mock + import pytest from mitmproxy.utils import debug diff --git a/test/mitmproxy/utils/test_emoji.py b/test/mitmproxy/utils/test_emoji.py index a147ba885..2b099926b 100644 --- a/test/mitmproxy/utils/test_emoji.py +++ b/test/mitmproxy/utils/test_emoji.py @@ -1,5 +1,5 @@ -from mitmproxy.utils import emoji from mitmproxy.tools.console.common import SYMBOL_MARK +from mitmproxy.utils import emoji def test_emoji(): diff --git a/test/mitmproxy/utils/test_human.py b/test/mitmproxy/utils/test_human.py index 944740611..d4791de3c 100644 --- a/test/mitmproxy/utils/test_human.py +++ b/test/mitmproxy/utils/test_human.py @@ -1,5 +1,7 @@ import time + import pytest + from mitmproxy.utils import human @@ -16,8 +18,8 @@ def test_parse_size(): assert human.parse_size("0b") == 0 assert human.parse_size("1") == 1 assert human.parse_size("1k") == 1024 - assert human.parse_size("1m") == 1024 ** 2 - assert human.parse_size("1g") == 1024 ** 3 + assert human.parse_size("1m") == 1024**2 + assert human.parse_size("1g") == 1024**3 with pytest.raises(ValueError): human.parse_size("1f") with pytest.raises(ValueError): diff --git a/test/mitmproxy/utils/test_magisk.py b/test/mitmproxy/utils/test_magisk.py index 83116d7f3..2382e3921 100644 --- a/test/mitmproxy/utils/test_magisk.py +++ b/test/mitmproxy/utils/test_magisk.py @@ -1,8 +1,10 @@ -from mitmproxy.utils import magisk -from cryptography import x509 -from mitmproxy.test import taddons import os +from cryptography import x509 + +from mitmproxy.test import taddons +from mitmproxy.utils import magisk + def test_get_ca(tdata): with taddons.context() as tctx: diff --git a/test/mitmproxy/utils/test_signals.py b/test/mitmproxy/utils/test_signals.py index 1fcc4f26d..dd856eb58 100644 --- a/test/mitmproxy/utils/test_signals.py +++ b/test/mitmproxy/utils/test_signals.py @@ -1,7 +1,9 @@ from unittest import mock import pytest -from mitmproxy.utils.signals import AsyncSignal, SyncSignal + +from mitmproxy.utils.signals import AsyncSignal +from mitmproxy.utils.signals import SyncSignal def test_sync_signal() -> None: diff --git a/test/mitmproxy/utils/test_spec.py b/test/mitmproxy/utils/test_spec.py index 6cefcacc7..630dd1799 100644 --- a/test/mitmproxy/utils/test_spec.py +++ b/test/mitmproxy/utils/test_spec.py @@ -1,4 +1,5 @@ import pytest + from mitmproxy.utils.spec import parse_spec diff --git a/test/mitmproxy/utils/test_strutils.py b/test/mitmproxy/utils/test_strutils.py index 3459a673f..f5b2894ac 100644 --- a/test/mitmproxy/utils/test_strutils.py +++ b/test/mitmproxy/utils/test_strutils.py @@ -44,7 +44,7 @@ def test_escape_control_characters(): def test_bytes_to_escaped_str(): assert strutils.bytes_to_escaped_str(b"foo") == "foo" assert strutils.bytes_to_escaped_str(b"\b") == r"\x08" - assert strutils.bytes_to_escaped_str(br"&!?=\)") == r"&!?=\\)" + assert strutils.bytes_to_escaped_str(rb"&!?=\)") == r"&!?=\\)" assert strutils.bytes_to_escaped_str(b"\xc3\xbc") == r"\xc3\xbc" assert strutils.bytes_to_escaped_str(b"'") == r"'" assert strutils.bytes_to_escaped_str(b'"') == r'"' @@ -69,9 +69,9 @@ def test_bytes_to_escaped_str(): def test_escaped_str_to_bytes(): assert strutils.escaped_str_to_bytes("foo") == b"foo" assert strutils.escaped_str_to_bytes("\x08") == b"\b" - assert strutils.escaped_str_to_bytes("&!?=\\\\)") == br"&!?=\)" + assert strutils.escaped_str_to_bytes("&!?=\\\\)") == rb"&!?=\)" assert strutils.escaped_str_to_bytes("\\x08") == b"\b" - assert strutils.escaped_str_to_bytes("&!?=\\\\)") == br"&!?=\)" + assert strutils.escaped_str_to_bytes("&!?=\\\\)") == rb"&!?=\)" assert strutils.escaped_str_to_bytes("\u00fc") == b"\xc3\xbc" with pytest.raises(ValueError): diff --git a/test/mitmproxy/utils/test_typecheck.py b/test/mitmproxy/utils/test_typecheck.py index 347e39f30..0f480157c 100644 --- a/test/mitmproxy/utils/test_typecheck.py +++ b/test/mitmproxy/utils/test_typecheck.py @@ -1,7 +1,10 @@ import io import typing from collections.abc import Sequence -from typing import Any, Optional, TextIO, Union +from typing import Any +from typing import Optional +from typing import TextIO +from typing import Union import pytest @@ -84,4 +87,4 @@ def test_typesec_to_str(): def test_typing_aliases(): assert (typecheck.typespec_to_str(typing.Sequence[str])) == "sequence of str" typecheck.check_option_type("foo", [10], typing.Sequence[int]) - typecheck.check_option_type("foo", (42, "42"), typing.Tuple[int, str]) + typecheck.check_option_type("foo", (42, "42"), tuple[int, str])