mirror of
https://github.com/mitmproxy/mitmproxy.git
synced 2024-11-24 05:40:05 +00:00
commit
8b74cbed72
@ -2,6 +2,8 @@
|
||||
|
||||
## Unreleased: mitmproxy next
|
||||
|
||||
* Add experimental QUIC support.
|
||||
([#5435](https://github.com/mitmproxy/mitmproxy/issues/5435), @meitinger)
|
||||
* ASGI/WSGI apps can now listen on all ports for a specific hostname.
|
||||
This makes it simpler to accept both HTTP and HTTPS.
|
||||
|
||||
|
@ -6,7 +6,7 @@ from pathlib import Path
|
||||
|
||||
from mitmproxy import hooks, log, addonmanager
|
||||
from mitmproxy.proxy import server_hooks, layer
|
||||
from mitmproxy.proxy.layers import dns, http, modes, tcp, tls, udp, websocket
|
||||
from mitmproxy.proxy.layers import dns, http, modes, quic, tcp, tls, udp, websocket
|
||||
|
||||
known = set()
|
||||
|
||||
@ -139,6 +139,15 @@ with outfile.open("w") as f, contextlib.redirect_stdout(f):
|
||||
],
|
||||
)
|
||||
|
||||
category(
|
||||
"QUIC",
|
||||
"",
|
||||
[
|
||||
quic.QuicStartClientHook,
|
||||
quic.QuicStartServerHook,
|
||||
],
|
||||
)
|
||||
|
||||
category(
|
||||
"TLS",
|
||||
"",
|
||||
|
@ -16,6 +16,7 @@ from mitmproxy import flow
|
||||
from mitmproxy import flowfilter
|
||||
from mitmproxy import http
|
||||
from mitmproxy.contrib import click as miniclick
|
||||
from mitmproxy.net.dns import response_codes
|
||||
from mitmproxy.tcp import TCPFlow, TCPMessage
|
||||
from mitmproxy.udp import UDPFlow, UDPMessage
|
||||
from mitmproxy.utils import human
|
||||
@ -204,7 +205,7 @@ class Dumper:
|
||||
blink=(code_int == 418),
|
||||
)
|
||||
|
||||
if not flow.response.is_http2:
|
||||
if not (flow.response.is_http2 or flow.response.is_http3):
|
||||
reason = flow.response.reason
|
||||
else:
|
||||
reason = http.status_codes.RESPONSES.get(flow.response.status_code, "")
|
||||
@ -335,16 +336,20 @@ class Dumper:
|
||||
def udp_error(self, f):
|
||||
self._proto_error(f)
|
||||
|
||||
def _proto_message(self, f):
|
||||
def _proto_message(self, f: Union[TCPFlow, UDPFlow]) -> None:
|
||||
if self.match(f):
|
||||
message = f.messages[-1]
|
||||
direction = "->" if message.from_client else "<-"
|
||||
if f.client_conn.tls_version == "QUIC":
|
||||
type_ = f"quic/{f.type}"
|
||||
else:
|
||||
type_ = f.type
|
||||
self.echo(
|
||||
"{client} {direction} {type} {direction} {server}".format(
|
||||
client=human.format_address(f.client_conn.peername),
|
||||
server=human.format_address(f.server_conn.address),
|
||||
direction=direction,
|
||||
type=f.type,
|
||||
type=type_,
|
||||
)
|
||||
)
|
||||
if ctx.options.flow_detail >= 3:
|
||||
@ -377,9 +382,14 @@ class Dumper:
|
||||
self._echo_dns_query(f)
|
||||
|
||||
arrows = self.style(" <<", bold=True)
|
||||
answers = ", ".join(
|
||||
self.style(str(x), fg="bright_blue") for x in f.response.answers
|
||||
)
|
||||
if f.response.answers:
|
||||
answers = ", ".join(
|
||||
self.style(str(x), fg="bright_blue") for x in f.response.answers
|
||||
)
|
||||
else:
|
||||
answers = self.style(response_codes.to_str(
|
||||
f.response.response_code,
|
||||
), fg="red")
|
||||
self.echo(f"{arrows} {answers}")
|
||||
|
||||
def dns_error(self, f: dns.DNSFlow):
|
||||
|
@ -17,17 +17,20 @@ that sets nextlayer.layer works just as well.
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
import struct
|
||||
from typing import Any, Callable, Iterable, Optional, Union
|
||||
from typing import Any, Callable, Iterable, Optional, Union, cast
|
||||
|
||||
from mitmproxy import ctx, dns, exceptions, connection
|
||||
from mitmproxy.net.tls import is_tls_record_magic
|
||||
from mitmproxy.proxy.layers.http import HTTPMode
|
||||
from mitmproxy.proxy import context, layer, layers
|
||||
from mitmproxy.proxy import context, layer, layers, mode_specs
|
||||
from mitmproxy.proxy.layers import modes
|
||||
from mitmproxy.proxy.layers.quic import quic_parse_client_hello
|
||||
from mitmproxy.proxy.layers.tls import HTTP_ALPNS, dtls_parse_client_hello, parse_client_hello
|
||||
from mitmproxy.tls import ClientHello
|
||||
|
||||
LayerCls = type[layer.Layer]
|
||||
ClientSecurityLayerCls = Union[type[layers.ClientTLSLayer], type[layers.ClientQuicLayer]]
|
||||
ServerSecurityLayerCls = Union[type[layers.ServerTLSLayer], type[layers.ServerQuicLayer]]
|
||||
|
||||
|
||||
def stack_match(
|
||||
@ -118,7 +121,12 @@ class NextLayer:
|
||||
else: # pragma: no cover
|
||||
raise AssertionError()
|
||||
|
||||
def setup_tls_layer(self, context: context.Context) -> layer.Layer:
|
||||
def setup_tls_layer(
|
||||
self,
|
||||
context: context.Context,
|
||||
client_layer_cls: ClientSecurityLayerCls = layers.ClientTLSLayer,
|
||||
server_layer_cls: ServerSecurityLayerCls = layers.ServerTLSLayer,
|
||||
) -> layer.Layer:
|
||||
def s(*layers):
|
||||
return stack_match(context, layers)
|
||||
|
||||
@ -130,14 +138,14 @@ class NextLayer:
|
||||
s(modes.HttpProxy)
|
||||
or s(modes.HttpUpstreamProxy)
|
||||
or s(modes.ReverseProxy)
|
||||
or s(modes.ReverseProxy, layers.ServerTLSLayer)
|
||||
or s(modes.ReverseProxy, server_layer_cls)
|
||||
):
|
||||
return layers.ClientTLSLayer(context)
|
||||
return client_layer_cls(context)
|
||||
else:
|
||||
# We already assign the next layer here so that ServerTLSLayer
|
||||
# We already assign the next layer here so that the server layer
|
||||
# knows that it can safely wait for a ClientHello.
|
||||
ret = layers.ServerTLSLayer(context)
|
||||
ret.child_layer = layers.ClientTLSLayer(context)
|
||||
ret = server_layer_cls(context)
|
||||
ret.child_layer = client_layer_cls(context)
|
||||
return ret
|
||||
|
||||
def is_destination_in_hosts(self, context: context.Context, hosts: Iterable[re.Pattern]) -> bool:
|
||||
@ -147,6 +155,79 @@ class NextLayer:
|
||||
for rex in hosts
|
||||
)
|
||||
|
||||
def get_http_layer(self, context: context.Context) -> Optional[layers.HttpLayer]:
|
||||
def s(*layers):
|
||||
return stack_match(context, layers)
|
||||
|
||||
# Setup the HTTP layer for a regular HTTP proxy ...
|
||||
if (
|
||||
s(modes.HttpProxy)
|
||||
or
|
||||
# or a "Secure Web Proxy", see https://www.chromium.org/developers/design-documents/secure-web-proxy
|
||||
s(modes.HttpProxy, (layers.ClientTLSLayer, layers.ClientQuicLayer))
|
||||
):
|
||||
return layers.HttpLayer(context, HTTPMode.regular)
|
||||
# ... or an upstream proxy.
|
||||
if (
|
||||
s(modes.HttpUpstreamProxy)
|
||||
or
|
||||
s(modes.HttpUpstreamProxy, (layers.ClientTLSLayer, layers.ClientQuicLayer))
|
||||
):
|
||||
return layers.HttpLayer(context, HTTPMode.upstream)
|
||||
return None
|
||||
|
||||
def detect_udp_tls(self, data_client: bytes) -> Optional[tuple[ClientHello, ClientSecurityLayerCls, ServerSecurityLayerCls]]:
|
||||
if len(data_client) == 0:
|
||||
return None
|
||||
|
||||
# first try DTLS (the parser may return None)
|
||||
try:
|
||||
client_hello = dtls_parse_client_hello(data_client)
|
||||
if client_hello is not None:
|
||||
return (client_hello, layers.ClientTLSLayer, layers.ServerTLSLayer)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# next try QUIC
|
||||
try:
|
||||
client_hello = quic_parse_client_hello(data_client)
|
||||
return (client_hello, layers.ClientQuicLayer, layers.ServerQuicLayer)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
# that's all we currently have to offer
|
||||
return None
|
||||
|
||||
def raw_udp_layer(self, context: context.Context, ignore: bool = False) -> layer.Layer:
|
||||
def s(*layers):
|
||||
return stack_match(context, layers)
|
||||
|
||||
# for regular and upstream HTTP3, if we already created a client QUIC layer
|
||||
# we need a server and raw QUIC layer as well
|
||||
if (
|
||||
s(modes.HttpProxy, layers.ClientQuicLayer)
|
||||
or
|
||||
s(modes.HttpUpstreamProxy, layers.ClientQuicLayer)
|
||||
):
|
||||
server_layer = layers.ServerQuicLayer(context)
|
||||
server_layer.child_layer = layers.RawQuicLayer(context, ignore=ignore)
|
||||
return server_layer
|
||||
|
||||
# for reverse HTTP3 and QUIC, we need a client and raw QUIC layer
|
||||
elif (s(modes.ReverseProxy, layers.ServerQuicLayer)):
|
||||
client_layer = layers.ClientQuicLayer(context)
|
||||
client_layer.child_layer = layers.RawQuicLayer(context, ignore=ignore)
|
||||
return client_layer
|
||||
|
||||
# in other cases we assume `setup_tls_layer` happened, so if the
|
||||
# top layer is `ClientQuicLayer` we return a raw QUIC layer...
|
||||
elif isinstance(context.layers[-1], layers.ClientQuicLayer):
|
||||
return layers.RawQuicLayer(context, ignore=ignore)
|
||||
|
||||
# ... otherwise an UDP layer
|
||||
else:
|
||||
return layers.UDPLayer(context, ignore=ignore)
|
||||
|
||||
def next_layer(self, nextlayer: layer.NextLayer):
|
||||
if nextlayer.layer is None:
|
||||
nextlayer.layer = self._next_layer(
|
||||
@ -160,40 +241,29 @@ class NextLayer:
|
||||
) -> Optional[layer.Layer]:
|
||||
assert context.layers
|
||||
|
||||
# helper function to quickly check if the existing layer stack matches a particular configuration.
|
||||
def s(*layers):
|
||||
return stack_match(context, layers)
|
||||
|
||||
if context.client.transport_protocol == "tcp":
|
||||
if len(data_client) < 3 and not data_server:
|
||||
is_quic_stream = isinstance(context.layers[-1], layers.QuicStreamLayer)
|
||||
if (
|
||||
len(data_client) < 3
|
||||
and not data_server
|
||||
and not is_quic_stream
|
||||
):
|
||||
return None # not enough data yet to make a decision
|
||||
|
||||
# 1. check for --ignore/--allow
|
||||
ignore = self.ignore_connection(context.server.address, data_client)
|
||||
if ignore is True:
|
||||
return layers.TCPLayer(context, ignore=True)
|
||||
if ignore is None:
|
||||
if ignore is None and not is_quic_stream:
|
||||
return None
|
||||
|
||||
# 2. Check for TLS
|
||||
if is_tls_record_magic(data_client):
|
||||
return self.setup_tls_layer(context)
|
||||
|
||||
# 3. Setup the HTTP layer for a regular HTTP proxy
|
||||
if (
|
||||
s(modes.HttpProxy)
|
||||
or
|
||||
# or a "Secure Web Proxy", see https://www.chromium.org/developers/design-documents/secure-web-proxy
|
||||
s(modes.HttpProxy, layers.ClientTLSLayer)
|
||||
):
|
||||
return layers.HttpLayer(context, HTTPMode.regular)
|
||||
# 3b. ... or an upstream proxy.
|
||||
if (
|
||||
s(modes.HttpUpstreamProxy)
|
||||
or
|
||||
s(modes.HttpUpstreamProxy, layers.ClientTLSLayer)
|
||||
):
|
||||
return layers.HttpLayer(context, HTTPMode.upstream)
|
||||
# 3. Check for HTTP
|
||||
if http_layer := self.get_http_layer(context):
|
||||
return http_layer
|
||||
|
||||
# 4. Check for --tcp
|
||||
if self.is_destination_in_hosts(context, self.tcp_hosts):
|
||||
@ -215,31 +285,57 @@ class NextLayer:
|
||||
|
||||
elif context.client.transport_protocol == "udp":
|
||||
# unlike TCP, we make a decision immediately
|
||||
try:
|
||||
dtls_client_hello = dtls_parse_client_hello(data_client)
|
||||
except ValueError:
|
||||
dtls_client_hello = None
|
||||
tls = self.detect_udp_tls(data_client)
|
||||
|
||||
# 1. check for --ignore/--allow
|
||||
if self.ignore_connection(
|
||||
context.server.address,
|
||||
data_client,
|
||||
is_tls=lambda _: dtls_client_hello is not None,
|
||||
client_hello=lambda _: dtls_client_hello
|
||||
is_tls=lambda _: tls is not None,
|
||||
client_hello=lambda _: None if tls is None else tls[0]
|
||||
):
|
||||
return layers.UDPLayer(context, ignore=True)
|
||||
return self.raw_udp_layer(context, ignore=True)
|
||||
|
||||
# 2. Check for DTLS
|
||||
if dtls_client_hello is not None:
|
||||
return self.setup_tls_layer(context)
|
||||
# 2. Check for DTLS/QUIC
|
||||
if tls is not None:
|
||||
_, client_layer_cls, server_layer_cls = tls
|
||||
return self.setup_tls_layer(context, client_layer_cls, server_layer_cls)
|
||||
|
||||
# 3. (skipped for now, until we support HTTP/3)
|
||||
# 3. Check for HTTP
|
||||
if http_layer := self.get_http_layer(context):
|
||||
return http_layer
|
||||
|
||||
# 4. Check for --udp
|
||||
if self.is_destination_in_hosts(context, self.udp_hosts):
|
||||
return layers.UDPLayer(context)
|
||||
return self.raw_udp_layer(context)
|
||||
|
||||
# 5. Check for DNS
|
||||
# 5. Check for reverse modes
|
||||
if (isinstance(context.layers[0], modes.ReverseProxy)):
|
||||
scheme = cast(mode_specs.ReverseMode, context.client.proxy_mode).scheme
|
||||
if scheme in ("udp", "dtls"):
|
||||
return layers.UDPLayer(context)
|
||||
elif scheme == "http3":
|
||||
if isinstance(context.layers[-1], layers.ClientQuicLayer):
|
||||
return layers.HttpLayer(context, HTTPMode.transparent)
|
||||
else:
|
||||
return layers.ClientQuicLayer(context)
|
||||
elif scheme == "quic":
|
||||
if isinstance(context.layers[-1], layers.ClientQuicLayer):
|
||||
# the client supports QUIC, use raw layer
|
||||
return layers.RawQuicLayer(context)
|
||||
elif data_client:
|
||||
# we have received data, which was not a handshake, use UDP
|
||||
# on the client, and send datagrams over QUIC to the server
|
||||
return layers.UDPLayer(context)
|
||||
else:
|
||||
# wait for client data to make a decision
|
||||
return None
|
||||
elif scheme == "dns":
|
||||
return layers.DNSLayer(context)
|
||||
else:
|
||||
raise AssertionError(scheme)
|
||||
|
||||
# 6. Check for DNS
|
||||
try:
|
||||
dns.Message.unpack(data_client)
|
||||
except struct.error:
|
||||
@ -247,12 +343,8 @@ class NextLayer:
|
||||
else:
|
||||
return layers.DNSLayer(context)
|
||||
|
||||
# 6. Check for raw udp mode.
|
||||
if ctx.options.rawudp:
|
||||
return layers.UDPLayer(context)
|
||||
|
||||
# 7. Ignore the connection by default. (In the future, we'll assume HTTP/3)
|
||||
return layers.UDPLayer(context, ignore=True)
|
||||
# 7. Use raw mode.
|
||||
return self.raw_udp_layer(context)
|
||||
|
||||
else:
|
||||
raise AssertionError(context.client.transport_protocol)
|
||||
|
@ -19,12 +19,14 @@ from mitmproxy import (
|
||||
http,
|
||||
platform,
|
||||
tcp,
|
||||
udp,
|
||||
websocket,
|
||||
)
|
||||
from mitmproxy.connection import Address
|
||||
from mitmproxy.flow import Flow
|
||||
from mitmproxy.proxy import events, mode_specs, server_hooks
|
||||
from mitmproxy.proxy.layers.tcp import TcpMessageInjected
|
||||
from mitmproxy.proxy.layers.udp import UdpMessageInjected
|
||||
from mitmproxy.proxy.layers.websocket import WebSocketMessageInjected
|
||||
from mitmproxy.proxy.mode_servers import ProxyConnectionHandler, ServerInstance, ServerManager
|
||||
from mitmproxy.utils import human, signals
|
||||
@ -305,6 +307,17 @@ class Proxyserver(ServerManager):
|
||||
except ValueError as e:
|
||||
logger.warning(str(e))
|
||||
|
||||
@command.command("inject.udp")
|
||||
def inject_udp(self, flow: Flow, to_client: bool, message: bytes):
|
||||
if not isinstance(flow, udp.UDPFlow):
|
||||
logger.warning("Cannot inject UDP messages into non-UDP flows.")
|
||||
|
||||
event = UdpMessageInjected(flow, udp.UDPMessage(not to_client, message))
|
||||
try:
|
||||
self.inject_event(event)
|
||||
except ValueError as e:
|
||||
logger.warning(str(e))
|
||||
|
||||
def server_connect(self, data: server_hooks.ServerConnectionHookData):
|
||||
if data.server.sockname is None:
|
||||
data.server.sockname = self._connect_addr
|
||||
|
@ -2,15 +2,18 @@ import ipaddress
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import ssl
|
||||
from typing import Any, Optional, TypedDict
|
||||
|
||||
from aioquic.h3.connection import H3_ALPN
|
||||
from aioquic.tls import CipherSuite
|
||||
from OpenSSL import SSL, crypto
|
||||
from mitmproxy import certs, ctx, exceptions, connection, tls
|
||||
from mitmproxy.net import tls as net_tls
|
||||
from mitmproxy.options import CONF_BASENAME
|
||||
from mitmproxy.proxy import context
|
||||
from mitmproxy.proxy.layers import modes
|
||||
from mitmproxy.proxy.layers import tls as proxy_tls
|
||||
from mitmproxy.proxy.layers import tls as proxy_tls, quic
|
||||
|
||||
# We manually need to specify this, otherwise OpenSSL may select a non-HTTP2 cipher by default.
|
||||
# https://ssl-config.mozilla.org/#config=old
|
||||
@ -295,6 +298,94 @@ class TlsConfig:
|
||||
|
||||
tls_start.ssl_conn.set_connect_state()
|
||||
|
||||
def quic_start_client(self, tls_start: quic.QuicTlsData) -> None:
|
||||
"""Establish QUIC between client and proxy."""
|
||||
if tls_start.settings is not None:
|
||||
return # a user addon has already provided the settings.
|
||||
tls_start.settings = quic.QuicTlsSettings()
|
||||
|
||||
# keep the following part in sync with `tls_start_client`
|
||||
assert isinstance(tls_start.conn, connection.Client)
|
||||
|
||||
client: connection.Client = tls_start.conn
|
||||
server: connection.Server = tls_start.context.server
|
||||
|
||||
entry = self.get_cert(tls_start.context)
|
||||
|
||||
if not client.cipher_list and ctx.options.ciphers_client:
|
||||
client.cipher_list = ctx.options.ciphers_client.split(":")
|
||||
|
||||
if ctx.options.add_upstream_certs_to_client_chain: # pragma: no cover
|
||||
extra_chain_certs = server.certificate_list
|
||||
else:
|
||||
extra_chain_certs = []
|
||||
|
||||
# set context parameters
|
||||
if client.cipher_list:
|
||||
tls_start.settings.cipher_suites = [
|
||||
CipherSuite[cipher] for cipher in client.cipher_list
|
||||
]
|
||||
# if we don't have upstream ALPN, we allow all offered by the client
|
||||
tls_start.settings.alpn_protocols = [
|
||||
alpn.decode("ascii")
|
||||
for alpn in [
|
||||
alpn for alpn in (client.alpn, server.alpn) if alpn
|
||||
] or client.alpn_offers
|
||||
]
|
||||
|
||||
# set the certificates
|
||||
tls_start.settings.certificate = entry.cert._cert
|
||||
tls_start.settings.certificate_private_key = entry.privatekey
|
||||
tls_start.settings.certificate_chain = [
|
||||
cert._cert for cert in (*entry.chain_certs, *extra_chain_certs)
|
||||
]
|
||||
|
||||
def quic_start_server(self, tls_start: quic.QuicTlsData) -> None:
|
||||
"""Establish QUIC between proxy and server."""
|
||||
if tls_start.settings is not None:
|
||||
return # a user addon has already provided the settings.
|
||||
tls_start.settings = quic.QuicTlsSettings()
|
||||
|
||||
# keep the following part in sync with `tls_start_server`
|
||||
assert isinstance(tls_start.conn, connection.Server)
|
||||
|
||||
client: connection.Client = tls_start.context.client
|
||||
server: connection.Server = tls_start.conn
|
||||
assert server.address
|
||||
|
||||
if ctx.options.ssl_insecure:
|
||||
tls_start.settings.verify_mode = ssl.CERT_NONE
|
||||
else:
|
||||
tls_start.settings.verify_mode = ssl.CERT_REQUIRED
|
||||
|
||||
if server.sni is None:
|
||||
server.sni = client.sni or server.address[0]
|
||||
|
||||
if not server.alpn_offers:
|
||||
if client.alpn_offers:
|
||||
server.alpn_offers = tuple(client.alpn_offers)
|
||||
else:
|
||||
# aioquic fails if no ALPN is offered, so use H3
|
||||
server.alpn_offers = tuple(alpn.encode("ascii") for alpn in H3_ALPN)
|
||||
|
||||
if not server.cipher_list and ctx.options.ciphers_server:
|
||||
server.cipher_list = ctx.options.ciphers_server.split(":")
|
||||
|
||||
# set context parameters
|
||||
if server.cipher_list:
|
||||
tls_start.settings.cipher_suites = [
|
||||
CipherSuite[cipher] for cipher in server.cipher_list
|
||||
]
|
||||
if server.alpn_offers:
|
||||
tls_start.settings.alpn_protocols = [
|
||||
alpn.decode("ascii") for alpn in server.alpn_offers
|
||||
]
|
||||
|
||||
# set the certificates
|
||||
# NOTE client certificates are not supported
|
||||
tls_start.settings.ca_path = ctx.options.ssl_verify_upstream_trusted_confdir
|
||||
tls_start.settings.ca_file = ctx.options.ssl_verify_upstream_trusted_ca
|
||||
|
||||
def running(self):
|
||||
# FIXME: We have a weird bug where the contract for configure is not followed and it is never called with
|
||||
# confdir or command_history as updated.
|
||||
|
@ -283,6 +283,7 @@ class CertStoreEntry:
|
||||
cert: Cert
|
||||
privatekey: rsa.RSAPrivateKey
|
||||
chain_file: Optional[Path]
|
||||
chain_certs: list[Cert]
|
||||
|
||||
|
||||
TCustomCertId = str # manually provided certs (e.g. mitmproxy's --certs)
|
||||
@ -311,6 +312,15 @@ class CertStore:
|
||||
self.default_privatekey = default_privatekey
|
||||
self.default_ca = default_ca
|
||||
self.default_chain_file = default_chain_file
|
||||
self.default_chain_certs = (
|
||||
[
|
||||
Cert.from_pem(chunk)
|
||||
for chunk in re.split(rb"(?=-----BEGIN( [A-Z]+)+-----)", self.default_chain_file.read_bytes())
|
||||
if chunk.startswith(b"-----BEGIN CERTIFICATE-----")
|
||||
]
|
||||
if self.default_chain_file
|
||||
else [default_ca]
|
||||
)
|
||||
self.dhparams = dhparams
|
||||
self.certs = {}
|
||||
self.expire_queue = []
|
||||
@ -453,7 +463,7 @@ class CertStore:
|
||||
except ValueError:
|
||||
key = self.default_privatekey
|
||||
|
||||
self.add_cert(CertStoreEntry(cert, key, path), spec)
|
||||
self.add_cert(CertStoreEntry(cert, key, path, [cert]), spec)
|
||||
|
||||
def add_cert(self, entry: CertStoreEntry, *names: str) -> None:
|
||||
"""
|
||||
@ -516,6 +526,7 @@ class CertStore:
|
||||
),
|
||||
privatekey=self.default_privatekey,
|
||||
chain_file=self.default_chain_file,
|
||||
chain_certs=self.default_chain_certs,
|
||||
)
|
||||
self.certs[(commonname, tuple(sans))] = entry
|
||||
self.expire(entry)
|
||||
|
@ -53,7 +53,7 @@ class Connection(serializable.SerializableDataclass, metaclass=ABCMeta):
|
||||
sockname: Optional[Address]
|
||||
"""Our local `(ip, port)` tuple for this connection."""
|
||||
|
||||
state: ConnectionState
|
||||
state: ConnectionState = field(default=ConnectionState.CLOSED, metadata={"serialize": False})
|
||||
"""The current connection state."""
|
||||
|
||||
# all connections have a unique id. While
|
||||
@ -172,8 +172,6 @@ class Client(Connection):
|
||||
sockname: Address
|
||||
"""The local address we received this connection on."""
|
||||
|
||||
state: ConnectionState = field(default=ConnectionState.OPEN)
|
||||
|
||||
mitmcert: Optional[certs.Cert] = None
|
||||
"""
|
||||
The certificate used by mitmproxy to establish TLS with the client.
|
||||
@ -265,8 +263,6 @@ class Server(Connection):
|
||||
"""The server's resolved `(ip, port)` tuple. Will be set during connection establishment."""
|
||||
sockname: Optional[Address] = None
|
||||
|
||||
state: ConnectionState = field(default=ConnectionState.CLOSED)
|
||||
|
||||
timestamp_start: Optional[float] = None
|
||||
"""*Timestamp:* TCP SYN sent."""
|
||||
timestamp_tcp_setup: Optional[float] = None
|
||||
|
@ -36,13 +36,9 @@ from . import (
|
||||
graphql,
|
||||
grpc,
|
||||
mqtt,
|
||||
http3,
|
||||
)
|
||||
|
||||
try:
|
||||
from . import http3
|
||||
except ImportError:
|
||||
# FIXME: Remove once QUIC is merged.
|
||||
http3 = None # type: ignore
|
||||
from .base import View, KEY_MAX, format_text, format_dict, TViewResult
|
||||
from ..tcp import TCPMessage
|
||||
from ..udp import UDPMessage
|
||||
@ -238,8 +234,7 @@ add(protobuf.ViewProtobuf())
|
||||
add(msgpack.ViewMsgPack())
|
||||
add(grpc.ViewGrpcProtobuf())
|
||||
add(mqtt.ViewMQTT())
|
||||
if http3 is not None:
|
||||
add(http3.ViewHttp3())
|
||||
add(http3.ViewHttp3())
|
||||
|
||||
__all__ = [
|
||||
"View",
|
||||
|
@ -1,13 +1,14 @@
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Union
|
||||
|
||||
from aioquic.h3.connection import Setting, parse_settings
|
||||
|
||||
from mitmproxy import flow, tcp
|
||||
from . import base
|
||||
from .hex import ViewHex
|
||||
from ..proxy.layers.http import is_h3_alpn # type: ignore
|
||||
from ..proxy.layers.http import is_h3_alpn
|
||||
|
||||
from aioquic.buffer import Buffer, BufferReadError
|
||||
import pylsqpack
|
||||
@ -75,7 +76,7 @@ class StreamType:
|
||||
@dataclass
|
||||
class ConnectionState:
|
||||
message_count: int = 0
|
||||
frames: dict[int, list[Frame | StreamType]] = field(default_factory=dict)
|
||||
frames: dict[int, list[Union[Frame, StreamType]]] = field(default_factory=dict)
|
||||
client_buf: bytearray = field(default_factory=bytearray)
|
||||
server_buf: bytearray = field(default_factory=bytearray)
|
||||
|
||||
@ -89,8 +90,8 @@ class ViewHttp3(base.View):
|
||||
def __call__(
|
||||
self,
|
||||
data,
|
||||
flow: flow.Flow | None = None,
|
||||
tcp_message: tcp.TCPMessage | None = None,
|
||||
flow: Optional[flow.Flow] = None,
|
||||
tcp_message: Optional[tcp.TCPMessage] = None,
|
||||
**metadata
|
||||
):
|
||||
assert isinstance(flow, tcp.TCPFlow)
|
||||
@ -109,7 +110,6 @@ class ViewHttp3(base.View):
|
||||
h3_buf = Buffer(data=bytes(buf[:8]))
|
||||
stream_type = h3_buf.pull_uint_var()
|
||||
consumed = h3_buf.tell()
|
||||
assert consumed == 1
|
||||
del buf[:consumed]
|
||||
state.frames[0] = [
|
||||
StreamType(stream_type)
|
||||
@ -147,13 +147,13 @@ class ViewHttp3(base.View):
|
||||
def render_priority(
|
||||
self,
|
||||
data: bytes,
|
||||
flow: flow.Flow | None = None,
|
||||
flow: Optional[flow.Flow] = None,
|
||||
**metadata
|
||||
) -> float:
|
||||
return 2 * float(bool(flow and is_h3_alpn(flow.client_conn.alpn))) * float(isinstance(flow, tcp.TCPFlow))
|
||||
|
||||
|
||||
def fmt_frames(frames: list[Frame | StreamType]) -> Iterator[base.TViewLine]:
|
||||
def fmt_frames(frames: list[Union[Frame, StreamType]]) -> Iterator[base.TViewLine]:
|
||||
for i, frame in enumerate(frames):
|
||||
if i > 0:
|
||||
yield [("text", "")]
|
||||
|
@ -69,6 +69,8 @@ class SerializableDataclass(Serializable):
|
||||
fields = []
|
||||
# noinspection PyDataclass
|
||||
for field in dataclasses.fields(cls):
|
||||
if field.metadata.get("serialize", True) is False:
|
||||
continue
|
||||
if isinstance(field.type, str):
|
||||
field.type = hints[field.name]
|
||||
fields.append(field)
|
||||
|
@ -285,6 +285,10 @@ class Message(serializable.Serializable):
|
||||
def is_http2(self) -> bool:
|
||||
return self.data.http_version == b"HTTP/2.0"
|
||||
|
||||
@property
|
||||
def is_http3(self) -> bool:
|
||||
return self.data.http_version == b"HTTP/3"
|
||||
|
||||
@property
|
||||
def headers(self) -> Headers:
|
||||
"""
|
||||
@ -763,7 +767,7 @@ class Request(Message):
|
||||
|
||||
*See also:* `Request.authority`,`Request.host`, `Request.pretty_host`
|
||||
"""
|
||||
if self.is_http2:
|
||||
if self.is_http2 or self.is_http3:
|
||||
return self.authority or self.data.headers.get("Host", None)
|
||||
else:
|
||||
return self.data.headers.get("Host", None)
|
||||
@ -771,13 +775,13 @@ class Request(Message):
|
||||
@host_header.setter
|
||||
def host_header(self, val: Union[None, str, bytes]) -> None:
|
||||
if val is None:
|
||||
if self.is_http2:
|
||||
if self.is_http2 or self.is_http3:
|
||||
self.data.authority = b""
|
||||
self.headers.pop("Host", None)
|
||||
else:
|
||||
if self.is_http2:
|
||||
if self.is_http2 or self.is_http3:
|
||||
self.authority = val # type: ignore
|
||||
if not self.is_http2 or "Host" in self.headers:
|
||||
if not (self.is_http2 or self.is_http3) or "Host" in self.headers:
|
||||
# For h2, we only overwrite, but not create, as :authority is the h2 host header.
|
||||
self.headers["Host"] = val
|
||||
|
||||
|
@ -414,6 +414,13 @@ def convert_18_19(data):
|
||||
return data
|
||||
|
||||
|
||||
def convert_19_20(data):
|
||||
data["version"] = 20
|
||||
data["client_conn"].pop("state", None)
|
||||
data["server_conn"].pop("state", None)
|
||||
return data
|
||||
|
||||
|
||||
def _convert_dict_keys(o: Any) -> Any:
|
||||
if isinstance(o, dict):
|
||||
return {strutils.always_str(k): _convert_dict_keys(v) for k, v in o.items()}
|
||||
@ -477,6 +484,7 @@ converters = {
|
||||
16: convert_16_17,
|
||||
17: convert_17_18,
|
||||
18: convert_18_19,
|
||||
19: convert_19_20,
|
||||
}
|
||||
|
||||
|
||||
|
@ -8,7 +8,7 @@ from typing import Literal
|
||||
from mitmproxy.net import check
|
||||
|
||||
ServerSpec = tuple[
|
||||
Literal["http", "https", "tls", "dtls", "tcp", "udp", "dns"],
|
||||
Literal["http", "https", "http3", "tls", "dtls", "tcp", "udp", "dns", "quic"],
|
||||
tuple[str, int]
|
||||
]
|
||||
|
||||
@ -45,7 +45,7 @@ def parse(server_spec: str, default_scheme: str) -> ServerSpec:
|
||||
scheme = m.group("scheme")
|
||||
else:
|
||||
scheme = default_scheme
|
||||
if scheme not in ("http", "https", "tls", "dtls", "tcp", "udp", "dns"):
|
||||
if scheme not in ("http", "https", "http3", "tls", "dtls", "tcp", "udp", "dns", "quic"):
|
||||
raise ValueError(f"Invalid server scheme: {scheme}")
|
||||
|
||||
host = m.group("host")
|
||||
@ -62,6 +62,8 @@ def parse(server_spec: str, default_scheme: str) -> ServerSpec:
|
||||
port = {
|
||||
"http": 80,
|
||||
"https": 443,
|
||||
"quic": 443,
|
||||
"http3": 443,
|
||||
"dns": 53,
|
||||
}[scheme]
|
||||
except KeyError:
|
||||
|
@ -114,7 +114,7 @@ class UdpServer(DrainableDatagramProtocol):
|
||||
|
||||
|
||||
class DatagramReader:
|
||||
_packets: asyncio.Queue
|
||||
_packets: asyncio.Queue[bytes]
|
||||
_eof: bool
|
||||
|
||||
def __init__(self) -> None:
|
||||
|
@ -149,13 +149,6 @@ class Options(optmanager.OptManager):
|
||||
"Enable/disable raw TCP connections. "
|
||||
"TCP connections are enabled by default. ",
|
||||
)
|
||||
self.add_option(
|
||||
"rawudp",
|
||||
bool,
|
||||
True,
|
||||
"Enable/disable raw UDP connections. "
|
||||
"UDP connections are enabled by default. ",
|
||||
)
|
||||
self.add_option(
|
||||
"ssl_insecure",
|
||||
bool,
|
||||
|
@ -94,6 +94,8 @@ class CloseConnection(ConnectionCommand):
|
||||
all other connections will ultimately be closed during cleanup.
|
||||
"""
|
||||
|
||||
|
||||
class CloseTcpConnection(CloseConnection):
|
||||
half_close: bool
|
||||
"""
|
||||
If True, only close our half of the connection by sending a FIN packet.
|
||||
|
@ -3,6 +3,7 @@ When IO actions occur at the proxy server, they are passed down to layers as eve
|
||||
Events represent the only way for layers to receive new data from sockets.
|
||||
The counterpart to events are commands.
|
||||
"""
|
||||
import typing
|
||||
import warnings
|
||||
from dataclasses import dataclass, is_dataclass
|
||||
from typing import Any, Generic, Optional, TypeVar
|
||||
@ -72,7 +73,7 @@ class CommandCompleted(Event):
|
||||
return super().__new__(cls)
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
command_cls = cls.__annotations__.get("command", None)
|
||||
command_cls = typing.get_type_hints(cls).get("command", None)
|
||||
valid_command_subclass = (
|
||||
isinstance(command_cls, type)
|
||||
and issubclass(command_cls, commands.Command)
|
||||
@ -80,7 +81,7 @@ class CommandCompleted(Event):
|
||||
)
|
||||
if not valid_command_subclass:
|
||||
warnings.warn(
|
||||
f"{command_cls} needs a properly annotated command attribute.",
|
||||
f"{cls} needs a properly annotated command attribute.",
|
||||
RuntimeWarning,
|
||||
)
|
||||
if command_cls in command_reply_subclasses:
|
||||
|
@ -1,6 +1,7 @@
|
||||
from . import modes
|
||||
from .dns import DNSLayer
|
||||
from .http import HttpLayer
|
||||
from .quic import QuicStreamLayer, RawQuicLayer, ClientQuicLayer, ServerQuicLayer
|
||||
from .tcp import TCPLayer
|
||||
from .udp import UDPLayer
|
||||
from .tls import ClientTLSLayer, ServerTLSLayer
|
||||
@ -10,9 +11,13 @@ __all__ = [
|
||||
"modes",
|
||||
"DNSLayer",
|
||||
"HttpLayer",
|
||||
"QuicStreamLayer",
|
||||
"RawQuicLayer",
|
||||
"TCPLayer",
|
||||
"UDPLayer",
|
||||
"ClientQuicLayer",
|
||||
"ClientTLSLayer",
|
||||
"ServerQuicLayer",
|
||||
"ServerTLSLayer",
|
||||
"WebsocketLayer",
|
||||
]
|
||||
|
@ -9,12 +9,12 @@ from typing import Optional, Union
|
||||
|
||||
import wsproto.handshake
|
||||
from mitmproxy import flow, http
|
||||
from mitmproxy.connection import Connection, Server
|
||||
from mitmproxy.connection import Connection, Server, TransportProtocol
|
||||
from mitmproxy.net import server_spec
|
||||
from mitmproxy.net.http import status_codes, url
|
||||
from mitmproxy.net.http.http1 import expected_http_body_size
|
||||
from mitmproxy.proxy import commands, events, layer, tunnel
|
||||
from mitmproxy.proxy.layers import tcp, tls, websocket
|
||||
from mitmproxy.proxy.layers import quic, tcp, tls, websocket
|
||||
from mitmproxy.proxy.layers.http import _upstream_proxy
|
||||
from mitmproxy.proxy.utils import expect
|
||||
from mitmproxy.utils import human
|
||||
@ -44,6 +44,8 @@ from ._hooks import ( # noqa
|
||||
)
|
||||
from ._http1 import Http1Client, Http1Connection, Http1Server
|
||||
from ._http2 import Http2Client, Http2Server
|
||||
from ._http3 import Http3Client, Http3Server
|
||||
from ..quic import QuicStreamEvent
|
||||
from ...context import Context
|
||||
from ...mode_specs import ReverseMode, UpstreamMode
|
||||
|
||||
@ -65,6 +67,10 @@ def validate_request(mode: HTTPMode, request: http.Request) -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
def is_h3_alpn(alpn: Optional[bytes]) -> bool:
|
||||
return alpn == b"h3" or (alpn is not None and alpn.startswith(b"h3-"))
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetHttpConnection(HttpCommand):
|
||||
"""
|
||||
@ -75,6 +81,7 @@ class GetHttpConnection(HttpCommand):
|
||||
address: tuple[str, int]
|
||||
tls: bool
|
||||
via: Optional[server_spec.ServerSpec]
|
||||
transport_protocol: TransportProtocol = "tcp"
|
||||
|
||||
def __hash__(self):
|
||||
return id(self)
|
||||
@ -85,6 +92,7 @@ class GetHttpConnection(HttpCommand):
|
||||
and self.address == connection.address
|
||||
and self.tls == connection.tls
|
||||
and self.via == connection.via
|
||||
and self.transport_protocol == connection.transport_protocol
|
||||
)
|
||||
|
||||
|
||||
@ -220,7 +228,7 @@ class HttpStream(layer.Layer):
|
||||
"https" if self.context.client.tls else "http"
|
||||
)
|
||||
|
||||
if self.mode is HTTPMode.regular and not self.flow.request.is_http2:
|
||||
if self.mode is HTTPMode.regular and not (self.flow.request.is_http2 or self.flow.request.is_http3):
|
||||
# Set the request target to origin-form for HTTP/1, some servers don't support absolute-form requests.
|
||||
# see https://github.com/mitmproxy/mitmproxy/issues/1759
|
||||
self.flow.request.authority = ""
|
||||
@ -665,6 +673,7 @@ class HttpStream(layer.Layer):
|
||||
(self.flow.request.host, self.flow.request.port),
|
||||
self.flow.request.scheme == "https",
|
||||
self.flow.server_conn.via,
|
||||
self.flow.server_conn.transport_protocol,
|
||||
)
|
||||
if err:
|
||||
yield from self.handle_protocol_error(
|
||||
@ -784,7 +793,8 @@ class HttpStream(layer.Layer):
|
||||
# The easiest approach for this is to just always full close for now.
|
||||
# Alternatively, we could signal that we want a half close only through ResponseProtocolError,
|
||||
# but that is more complex to implement.
|
||||
command.half_close = False
|
||||
if isinstance(command, commands.CloseTcpConnection):
|
||||
command = commands.CloseConnection(command.connection)
|
||||
yield command
|
||||
else:
|
||||
yield command
|
||||
@ -829,7 +839,9 @@ class HttpLayer(layer.Layer):
|
||||
self.command_sources = {}
|
||||
|
||||
http_conn: HttpConnection
|
||||
if self.context.client.alpn == b"h2":
|
||||
if is_h3_alpn(self.context.client.alpn):
|
||||
http_conn = Http3Server(context.fork())
|
||||
elif self.context.client.alpn == b"h2":
|
||||
http_conn = Http2Server(context.fork())
|
||||
else:
|
||||
http_conn = Http1Server(context.fork())
|
||||
@ -846,9 +858,6 @@ class HttpLayer(layer.Layer):
|
||||
proxy_mode = self.context.client.proxy_mode
|
||||
assert isinstance(proxy_mode, UpstreamMode)
|
||||
self.context.server.via = (proxy_mode.scheme, proxy_mode.address)
|
||||
elif isinstance(event, events.Wakeup):
|
||||
stream = self.command_sources.pop(event.command)
|
||||
yield from self.event_to_child(stream, event)
|
||||
elif isinstance(event, events.CommandCompleted):
|
||||
stream = self.command_sources.pop(event.command)
|
||||
yield from self.event_to_child(stream, event)
|
||||
@ -881,10 +890,13 @@ class HttpLayer(layer.Layer):
|
||||
if isinstance(event, events.ConnectionClosed):
|
||||
# The peer has closed it - let's close it too!
|
||||
yield commands.CloseConnection(event.connection)
|
||||
elif isinstance(event, events.DataReceived):
|
||||
# The peer has sent data. This can happen with HTTP/2 servers that already send a settings frame.
|
||||
elif isinstance(event, (events.DataReceived, QuicStreamEvent)):
|
||||
# The peer has sent data or another connection activity occurred.
|
||||
# This can happen with HTTP/2 servers that already send a settings frame.
|
||||
child_layer: HttpConnection
|
||||
if self.context.server.alpn == b"h2":
|
||||
if is_h3_alpn(self.context.server.alpn):
|
||||
child_layer = Http3Client(self.context.fork())
|
||||
elif self.context.server.alpn == b"h2":
|
||||
child_layer = Http2Client(self.context.fork())
|
||||
else:
|
||||
child_layer = Http1Client(self.context.fork())
|
||||
@ -1000,7 +1012,7 @@ class HttpLayer(layer.Layer):
|
||||
|
||||
if not can_use_context_connection:
|
||||
|
||||
context.server = Server(address=event.address)
|
||||
context.server = Server(address=event.address, transport_protocol=event.transport_protocol)
|
||||
|
||||
if event.via:
|
||||
context.server.via = event.via
|
||||
@ -1017,7 +1029,12 @@ class HttpLayer(layer.Layer):
|
||||
context.server.sni = self.context.client.sni or event.address[0]
|
||||
else:
|
||||
context.server.sni = event.address[0]
|
||||
stack /= tls.ServerTLSLayer(context)
|
||||
if context.server.transport_protocol == "tcp":
|
||||
stack /= tls.ServerTLSLayer(context)
|
||||
elif context.server.transport_protocol == "udp":
|
||||
stack /= quic.ServerQuicLayer(context)
|
||||
else:
|
||||
raise AssertionError(context.server.transport_protocol) # pragma: no cover
|
||||
|
||||
stack /= HttpClient(context)
|
||||
|
||||
@ -1067,7 +1084,9 @@ class HttpClient(layer.Layer):
|
||||
else:
|
||||
err = yield commands.OpenConnection(self.context.server)
|
||||
if not err:
|
||||
if self.context.server.alpn == b"h2":
|
||||
if is_h3_alpn(self.context.server.alpn):
|
||||
self.child_layer = Http3Client(self.context)
|
||||
elif self.context.server.alpn == b"h2":
|
||||
self.child_layer = Http2Client(self.context)
|
||||
else:
|
||||
self.child_layer = Http1Client(self.context)
|
||||
|
@ -77,7 +77,7 @@ class Http1Connection(HttpConnection, metaclass=abc.ABCMeta):
|
||||
state = start
|
||||
|
||||
def read_body(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
assert self.stream_id
|
||||
assert self.stream_id is not None
|
||||
while True:
|
||||
try:
|
||||
if isinstance(event, events.DataReceived):
|
||||
@ -189,7 +189,7 @@ class Http1Connection(HttpConnection, metaclass=abc.ABCMeta):
|
||||
# If we proxy HTTP/2 to HTTP/1, we only use upstream connections for one request.
|
||||
# This simplifies our connection management quite a bit as we can rely on
|
||||
# the proxyserver's max-connection-per-server throttling.
|
||||
or (self.request.is_http2 and isinstance(self, Http1Client))
|
||||
or ((self.request.is_http2 or self.request.is_http3) and isinstance(self, Http1Client))
|
||||
)
|
||||
if connection_done:
|
||||
yield commands.CloseConnection(self.conn)
|
||||
@ -223,7 +223,7 @@ class Http1Server(Http1Connection):
|
||||
if isinstance(event, ResponseHeaders):
|
||||
self.response = response = event.response
|
||||
|
||||
if response.is_http2:
|
||||
if response.is_http2 or response.is_http3:
|
||||
response = response.copy()
|
||||
# Convert to an HTTP/1 response.
|
||||
response.http_version = "HTTP/1.1"
|
||||
@ -331,7 +331,7 @@ class Http1Client(Http1Connection):
|
||||
yield commands.CloseConnection(self.conn)
|
||||
return
|
||||
|
||||
if not self.stream_id:
|
||||
if self.stream_id is None:
|
||||
assert isinstance(event, RequestHeaders)
|
||||
self.stream_id = event.stream_id
|
||||
self.request = event.request
|
||||
@ -339,7 +339,7 @@ class Http1Client(Http1Connection):
|
||||
|
||||
if isinstance(event, RequestHeaders):
|
||||
request = event.request
|
||||
if request.is_http2:
|
||||
if request.is_http2 or request.is_http3:
|
||||
# Convert to an HTTP/1 request.
|
||||
request = (
|
||||
request.copy()
|
||||
@ -369,7 +369,7 @@ class Http1Client(Http1Connection):
|
||||
if "chunked" in self.request.headers.get("transfer-encoding", "").lower():
|
||||
yield commands.SendData(self.conn, b"0\r\n\r\n")
|
||||
elif http1.expected_http_body_size(self.request, self.response) == -1:
|
||||
yield commands.CloseConnection(self.conn, half_close=True)
|
||||
yield commands.CloseTcpConnection(self.conn, half_close=True)
|
||||
yield from self.mark_done(request=True)
|
||||
else:
|
||||
raise AssertionError(f"Unexpected event: {event}")
|
||||
@ -383,7 +383,7 @@ class Http1Client(Http1Connection):
|
||||
yield commands.Log(f"Unexpected data from server: {bytes(self.buf)!r}")
|
||||
yield commands.CloseConnection(self.conn)
|
||||
return
|
||||
assert self.stream_id
|
||||
assert self.stream_id is not None
|
||||
|
||||
response_head = self.buf.maybe_extract_lines()
|
||||
if response_head:
|
||||
|
@ -313,6 +313,48 @@ def normalize_h2_headers(headers: list[tuple[bytes, bytes]]) -> CommandGenerator
|
||||
headers[i] = (headers[i][0].lower(), headers[i][1])
|
||||
|
||||
|
||||
def format_h2_request_headers(
|
||||
context: Context,
|
||||
event: RequestHeaders,
|
||||
) -> CommandGenerator[list[tuple[bytes, bytes]]]:
|
||||
pseudo_headers = [
|
||||
(b":method", event.request.data.method),
|
||||
(b":scheme", event.request.data.scheme),
|
||||
(b":path", event.request.data.path),
|
||||
]
|
||||
if event.request.authority:
|
||||
pseudo_headers.append((b":authority", event.request.data.authority))
|
||||
|
||||
if event.request.is_http2 or event.request.is_http3:
|
||||
hdrs = list(event.request.headers.fields)
|
||||
if context.options.normalize_outbound_headers:
|
||||
yield from normalize_h2_headers(hdrs)
|
||||
else:
|
||||
headers = event.request.headers
|
||||
if not event.request.authority and "host" in headers:
|
||||
headers = headers.copy()
|
||||
pseudo_headers.append((b":authority", headers.pop(b"host")))
|
||||
hdrs = normalize_h1_headers(list(headers.fields), True)
|
||||
|
||||
return pseudo_headers + hdrs
|
||||
|
||||
|
||||
def format_h2_response_headers(
|
||||
context: Context,
|
||||
event: ResponseHeaders,
|
||||
) -> CommandGenerator[list[tuple[bytes, bytes]]]:
|
||||
headers = [
|
||||
(b":status", b"%d" % event.response.status_code),
|
||||
*event.response.headers.fields,
|
||||
]
|
||||
if event.response.is_http2 or event.response.is_http3:
|
||||
if context.options.normalize_outbound_headers:
|
||||
yield from normalize_h2_headers(headers)
|
||||
else:
|
||||
headers = normalize_h1_headers(headers, False)
|
||||
return headers
|
||||
|
||||
|
||||
class Http2Server(Http2Connection):
|
||||
h2_conf = h2.config.H2Configuration(
|
||||
**Http2Connection.h2_conf_defaults,
|
||||
@ -330,19 +372,9 @@ class Http2Server(Http2Connection):
|
||||
def _handle_event(self, event: Event) -> CommandGenerator[None]:
|
||||
if isinstance(event, ResponseHeaders):
|
||||
if self.is_open_for_us(event.stream_id):
|
||||
headers = [
|
||||
(b":status", b"%d" % event.response.status_code),
|
||||
*event.response.headers.fields,
|
||||
]
|
||||
if event.response.is_http2:
|
||||
if self.context.options.normalize_outbound_headers:
|
||||
yield from normalize_h2_headers(headers)
|
||||
else:
|
||||
headers = normalize_h1_headers(headers, False)
|
||||
|
||||
self.h2_conn.send_headers(
|
||||
event.stream_id,
|
||||
headers,
|
||||
headers=(yield from format_h2_response_headers(self.context, event)),
|
||||
end_stream=event.end_stream,
|
||||
)
|
||||
yield SendData(self.conn, self.h2_conn.data_to_send())
|
||||
@ -485,28 +517,9 @@ class Http2Client(Http2Connection):
|
||||
yield RequestWakeup(self.context.options.http2_ping_keepalive)
|
||||
yield from super()._handle_event(event)
|
||||
elif isinstance(event, RequestHeaders):
|
||||
pseudo_headers = [
|
||||
(b":method", event.request.data.method),
|
||||
(b":scheme", event.request.data.scheme),
|
||||
(b":path", event.request.data.path),
|
||||
]
|
||||
if event.request.authority:
|
||||
pseudo_headers.append((b":authority", event.request.data.authority))
|
||||
|
||||
if event.request.is_http2:
|
||||
hdrs = list(event.request.headers.fields)
|
||||
if self.context.options.normalize_outbound_headers:
|
||||
yield from normalize_h2_headers(hdrs)
|
||||
else:
|
||||
headers = event.request.headers
|
||||
if not event.request.authority and "host" in headers:
|
||||
headers = headers.copy()
|
||||
pseudo_headers.append((b":authority", headers.pop(b"host")))
|
||||
hdrs = normalize_h1_headers(list(headers.fields), True)
|
||||
|
||||
self.h2_conn.send_headers(
|
||||
event.stream_id,
|
||||
pseudo_headers + hdrs,
|
||||
headers=(yield from format_h2_request_headers(self.context, event)),
|
||||
end_stream=event.end_stream,
|
||||
)
|
||||
self.streams[event.stream_id] = StreamState.EXPECTING_HEADERS
|
||||
@ -642,6 +655,10 @@ def parse_h2_response_headers(
|
||||
|
||||
|
||||
__all__ = [
|
||||
"format_h2_request_headers",
|
||||
"format_h2_response_headers",
|
||||
"parse_h2_request_headers",
|
||||
"parse_h2_response_headers",
|
||||
"Http2Client",
|
||||
"Http2Server",
|
||||
]
|
||||
|
285
mitmproxy/proxy/layers/http/_http3.py
Normal file
285
mitmproxy/proxy/layers/http/_http3.py
Normal file
@ -0,0 +1,285 @@
|
||||
from abc import abstractmethod
|
||||
import time
|
||||
from typing import Dict, Union
|
||||
|
||||
from aioquic.h3.connection import (
|
||||
ErrorCode as H3ErrorCode,
|
||||
FrameUnexpected as H3FrameUnexpected,
|
||||
)
|
||||
from aioquic.h3.events import DataReceived, HeadersReceived, PushPromiseReceived
|
||||
|
||||
from mitmproxy import connection, http, version
|
||||
from mitmproxy.net.http import status_codes
|
||||
from mitmproxy.proxy import commands, context, events, layer
|
||||
from mitmproxy.proxy.layers.quic import (
|
||||
QuicConnectionClosed,
|
||||
QuicStreamEvent,
|
||||
StopQuicStream,
|
||||
error_code_to_str,
|
||||
)
|
||||
from mitmproxy.proxy.utils import expect
|
||||
|
||||
from . import (
|
||||
RequestData,
|
||||
RequestEndOfMessage,
|
||||
RequestHeaders,
|
||||
RequestProtocolError,
|
||||
RequestTrailers,
|
||||
ResponseData,
|
||||
ResponseEndOfMessage,
|
||||
ResponseHeaders,
|
||||
ResponseProtocolError,
|
||||
ResponseTrailers,
|
||||
)
|
||||
from ._base import (
|
||||
HttpConnection,
|
||||
HttpEvent,
|
||||
ReceiveHttp,
|
||||
format_error,
|
||||
)
|
||||
from ._http2 import (
|
||||
format_h2_request_headers,
|
||||
format_h2_response_headers,
|
||||
parse_h2_request_headers,
|
||||
parse_h2_response_headers,
|
||||
)
|
||||
from ._http_h3 import LayeredH3Connection, StreamReset, TrailersReceived
|
||||
|
||||
|
||||
class Http3Connection(HttpConnection):
|
||||
h3_conn: LayeredH3Connection
|
||||
|
||||
ReceiveData: type[Union[RequestData, ResponseData]]
|
||||
ReceiveEndOfMessage: type[Union[RequestEndOfMessage, ResponseEndOfMessage]]
|
||||
ReceiveProtocolError: type[Union[RequestProtocolError, ResponseProtocolError]]
|
||||
ReceiveTrailers: type[Union[RequestTrailers, ResponseTrailers]]
|
||||
|
||||
def __init__(self, context: context.Context, conn: connection.Connection):
|
||||
super().__init__(context, conn)
|
||||
self.h3_conn = LayeredH3Connection(self.conn, is_client=self.conn is self.context.server)
|
||||
self._stream_protocol_errors: dict[int, int] = {}
|
||||
|
||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
if isinstance(event, events.Start):
|
||||
yield from self.h3_conn.transmit()
|
||||
|
||||
# send mitmproxy HTTP events over the H3 connection
|
||||
elif isinstance(event, HttpEvent):
|
||||
try:
|
||||
if isinstance(event, (RequestData, ResponseData)):
|
||||
self.h3_conn.send_data(event.stream_id, event.data)
|
||||
elif isinstance(event, (RequestHeaders, ResponseHeaders)):
|
||||
headers = yield from (
|
||||
format_h2_request_headers(self.context, event)
|
||||
if isinstance(event, RequestHeaders)
|
||||
else format_h2_response_headers(self.context, event)
|
||||
)
|
||||
self.h3_conn.send_headers(event.stream_id, headers, end_stream=event.end_stream)
|
||||
elif isinstance(event, (RequestTrailers, ResponseTrailers)):
|
||||
self.h3_conn.send_trailers(event.stream_id, [*event.trailers.fields])
|
||||
elif isinstance(event, (RequestEndOfMessage, ResponseEndOfMessage)):
|
||||
self.h3_conn.end_stream(event.stream_id)
|
||||
elif isinstance(event, (RequestProtocolError, ResponseProtocolError)):
|
||||
code = {
|
||||
status_codes.CLIENT_CLOSED_REQUEST: H3ErrorCode.H3_REQUEST_CANCELLED.value,
|
||||
}.get(event.code, H3ErrorCode.H3_INTERNAL_ERROR.value)
|
||||
self._stream_protocol_errors[event.stream_id] = code
|
||||
send_error_message = (
|
||||
isinstance(event, ResponseProtocolError)
|
||||
and not self.h3_conn.has_sent_headers(event.stream_id)
|
||||
and event.code != status_codes.NO_RESPONSE
|
||||
)
|
||||
if send_error_message:
|
||||
self.h3_conn.send_headers(
|
||||
event.stream_id,
|
||||
[
|
||||
(b":status", b"%d" % event.code),
|
||||
(b"server", version.MITMPROXY.encode()),
|
||||
(b"content-type", b"text/html"),
|
||||
],
|
||||
)
|
||||
self.h3_conn.send_data(
|
||||
event.stream_id,
|
||||
format_error(event.code, event.message),
|
||||
end_stream=True,
|
||||
)
|
||||
else:
|
||||
self.h3_conn.reset_stream(event.stream_id, code)
|
||||
else: # pragma: no cover
|
||||
raise AssertionError(f"Unexpected event: {event!r}")
|
||||
|
||||
except H3FrameUnexpected as e:
|
||||
# Http2Connection also ignores HttpEvents that violate the current stream state
|
||||
yield commands.Log(f"Received {event!r} unexpectedly: {e}")
|
||||
|
||||
else:
|
||||
# transmit buffered data
|
||||
yield from self.h3_conn.transmit()
|
||||
|
||||
# forward stream messages from the QUIC layer to the H3 connection
|
||||
elif isinstance(event, QuicStreamEvent):
|
||||
h3_events = self.h3_conn.handle_stream_event(event)
|
||||
if event.stream_id in self._stream_protocol_errors:
|
||||
# we already reset or ended the stream, tell the peer to stop
|
||||
# (this is a noop if the peer already did the same)
|
||||
yield StopQuicStream(
|
||||
self.conn,
|
||||
event.stream_id,
|
||||
self._stream_protocol_errors[event.stream_id],
|
||||
)
|
||||
else:
|
||||
for h3_event in h3_events:
|
||||
if isinstance(h3_event, StreamReset):
|
||||
if h3_event.push_id is None:
|
||||
err_str = error_code_to_str(h3_event.error_code)
|
||||
err_code = {
|
||||
H3ErrorCode.H3_REQUEST_CANCELLED.value: status_codes.CLIENT_CLOSED_REQUEST,
|
||||
}.get(h3_event.error_code, self.ReceiveProtocolError.code)
|
||||
yield ReceiveHttp(
|
||||
self.ReceiveProtocolError(
|
||||
h3_event.stream_id,
|
||||
f"stream reset by client ({err_str})",
|
||||
code=err_code,
|
||||
)
|
||||
)
|
||||
elif isinstance(h3_event, DataReceived):
|
||||
if h3_event.push_id is None:
|
||||
if h3_event.data:
|
||||
yield ReceiveHttp(self.ReceiveData(h3_event.stream_id, h3_event.data))
|
||||
if h3_event.stream_ended:
|
||||
yield ReceiveHttp(self.ReceiveEndOfMessage(h3_event.stream_id))
|
||||
elif isinstance(h3_event, HeadersReceived):
|
||||
if h3_event.push_id is None:
|
||||
try:
|
||||
receive_event = self.parse_headers(h3_event)
|
||||
except ValueError as e:
|
||||
self.h3_conn.close_connection(
|
||||
error_code=H3ErrorCode.H3_GENERAL_PROTOCOL_ERROR,
|
||||
reason_phrase=f"Invalid HTTP/3 request headers: {e}",
|
||||
)
|
||||
else:
|
||||
yield ReceiveHttp(receive_event)
|
||||
if h3_event.stream_ended:
|
||||
yield ReceiveHttp(self.ReceiveEndOfMessage(h3_event.stream_id))
|
||||
elif isinstance(h3_event, TrailersReceived):
|
||||
if h3_event.push_id is None:
|
||||
yield ReceiveHttp(self.ReceiveTrailers(h3_event.stream_id, http.Headers(h3_event.trailers)))
|
||||
if h3_event.stream_ended:
|
||||
yield ReceiveHttp(self.ReceiveEndOfMessage(h3_event.stream_id))
|
||||
elif isinstance(h3_event, PushPromiseReceived): # pragma: no cover
|
||||
# we don't support push
|
||||
pass
|
||||
else: # pragma: no cover
|
||||
raise AssertionError(f"Unexpected event: {event!r}")
|
||||
yield from self.h3_conn.transmit()
|
||||
|
||||
# report a protocol error for all remaining open streams when a connection is closed
|
||||
elif isinstance(event, QuicConnectionClosed):
|
||||
self._handle_event = self.done # type: ignore
|
||||
self.h3_conn.handle_connection_closed(event)
|
||||
msg = event.reason_phrase or error_code_to_str(event.error_code)
|
||||
for stream_id in self.h3_conn.get_open_stream_ids(push_id=None):
|
||||
yield ReceiveHttp(self.ReceiveProtocolError(stream_id, msg))
|
||||
|
||||
else: # pragma: no cover
|
||||
raise AssertionError(f"Unexpected event: {event!r}")
|
||||
|
||||
@expect(HttpEvent, QuicStreamEvent, QuicConnectionClosed)
|
||||
def done(self, _) -> layer.CommandGenerator[None]:
|
||||
yield from ()
|
||||
|
||||
@abstractmethod
|
||||
def parse_headers(
|
||||
self, event: HeadersReceived
|
||||
) -> Union[RequestHeaders, ResponseHeaders]:
|
||||
pass # pragma: no cover
|
||||
|
||||
|
||||
class Http3Server(Http3Connection):
|
||||
ReceiveData = RequestData
|
||||
ReceiveEndOfMessage = RequestEndOfMessage
|
||||
ReceiveProtocolError = RequestProtocolError
|
||||
ReceiveTrailers = RequestTrailers
|
||||
|
||||
def __init__(self, context: context.Context):
|
||||
super().__init__(context, context.client)
|
||||
|
||||
def parse_headers(self, event: HeadersReceived) -> Union[RequestHeaders, ResponseHeaders]:
|
||||
# same as HTTP/2
|
||||
(
|
||||
host,
|
||||
port,
|
||||
method,
|
||||
scheme,
|
||||
authority,
|
||||
path,
|
||||
headers,
|
||||
) = parse_h2_request_headers(event.headers)
|
||||
request = http.Request(
|
||||
host=host,
|
||||
port=port,
|
||||
method=method,
|
||||
scheme=scheme,
|
||||
authority=authority,
|
||||
path=path,
|
||||
http_version=b"HTTP/3",
|
||||
headers=headers,
|
||||
content=None,
|
||||
trailers=None,
|
||||
timestamp_start=time.time(),
|
||||
timestamp_end=None,
|
||||
)
|
||||
return RequestHeaders(event.stream_id, request, end_stream=event.stream_ended)
|
||||
|
||||
|
||||
class Http3Client(Http3Connection):
|
||||
ReceiveData = ResponseData
|
||||
ReceiveEndOfMessage = ResponseEndOfMessage
|
||||
ReceiveProtocolError = ResponseProtocolError
|
||||
ReceiveTrailers = ResponseTrailers
|
||||
|
||||
our_stream_id: Dict[int, int]
|
||||
their_stream_id: Dict[int, int]
|
||||
|
||||
def __init__(self, context: context.Context):
|
||||
super().__init__(context, context.server)
|
||||
self.our_stream_id = {}
|
||||
self.their_stream_id = {}
|
||||
|
||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
# QUIC and HTTP/3 would actually allow for direct stream ID mapping, but since we want
|
||||
# to support H2<->H3, we need to translate IDs.
|
||||
# NOTE: We always create bidirectional streams, as we can't safely infer unidirectionality.
|
||||
if isinstance(event, HttpEvent):
|
||||
ours = self.our_stream_id.get(event.stream_id, None)
|
||||
if ours is None:
|
||||
ours = self.h3_conn.get_next_available_stream_id()
|
||||
self.our_stream_id[event.stream_id] = ours
|
||||
self.their_stream_id[ours] = event.stream_id
|
||||
event.stream_id = ours
|
||||
|
||||
for cmd in super()._handle_event(event):
|
||||
if isinstance(cmd, ReceiveHttp):
|
||||
cmd.event.stream_id = self.their_stream_id[cmd.event.stream_id]
|
||||
yield cmd
|
||||
|
||||
def parse_headers(self, event: HeadersReceived) -> Union[RequestHeaders, ResponseHeaders]:
|
||||
# same as HTTP/2
|
||||
status_code, headers = parse_h2_response_headers(event.headers)
|
||||
response = http.Response(
|
||||
http_version=b"HTTP/3",
|
||||
status_code=status_code,
|
||||
reason=b"",
|
||||
headers=headers,
|
||||
content=None,
|
||||
trailers=None,
|
||||
timestamp_start=time.time(),
|
||||
timestamp_end=None,
|
||||
)
|
||||
return ResponseHeaders(event.stream_id, response, event.stream_ended)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Http3Client",
|
||||
"Http3Server",
|
||||
]
|
275
mitmproxy/proxy/layers/http/_http_h3.py
Normal file
275
mitmproxy/proxy/layers/http/_http_h3.py
Normal file
@ -0,0 +1,275 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from aioquic.h3.connection import (
|
||||
FrameUnexpected,
|
||||
H3Connection,
|
||||
H3Event,
|
||||
H3Stream,
|
||||
Headers,
|
||||
HeadersState,
|
||||
StreamType,
|
||||
)
|
||||
from aioquic.h3.events import HeadersReceived
|
||||
from aioquic.quic.configuration import QuicConfiguration
|
||||
from aioquic.quic.events import StreamDataReceived
|
||||
from aioquic.quic.packet import QuicErrorCode
|
||||
|
||||
from mitmproxy import connection
|
||||
from mitmproxy.proxy import commands, layer
|
||||
from mitmproxy.proxy.layers.quic import (
|
||||
CloseQuicConnection,
|
||||
QuicConnectionClosed,
|
||||
QuicStreamDataReceived,
|
||||
QuicStreamEvent,
|
||||
QuicStreamReset,
|
||||
ResetQuicStream,
|
||||
SendQuicStreamData,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrailersReceived(H3Event):
|
||||
"""
|
||||
The TrailersReceived event is fired whenever trailers are received.
|
||||
"""
|
||||
|
||||
trailers: Headers
|
||||
"The trailers."
|
||||
|
||||
stream_id: int
|
||||
"The ID of the stream the trailers were received for."
|
||||
|
||||
stream_ended: bool
|
||||
"Whether the STREAM frame had the FIN bit set."
|
||||
|
||||
push_id: Optional[int] = None
|
||||
"The Push ID or `None` if this is not a push."
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamReset(H3Event):
|
||||
"""
|
||||
The StreamReset event is fired whenever a stream is reset by the peer.
|
||||
"""
|
||||
|
||||
stream_id: int
|
||||
"The ID of the stream that was reset."
|
||||
|
||||
error_code: int
|
||||
"""The error code indicating why the stream was reset."""
|
||||
|
||||
push_id: Optional[int] = None
|
||||
"The Push ID or `None` if this is not a push."
|
||||
|
||||
|
||||
class MockQuic:
|
||||
"""
|
||||
aioquic intermingles QUIC and HTTP/3. This is something we don't want to do because that makes testing much harder.
|
||||
Instead, we mock our QUIC connection object here and then take out the wire data to be sent.
|
||||
"""
|
||||
|
||||
def __init__(self, conn: connection.Connection, is_client: bool) -> None:
|
||||
self.conn = conn
|
||||
self.pending_commands: list[commands.Command] = []
|
||||
self._next_stream_id: list[int] = [0, 1, 2, 3]
|
||||
self._is_client = is_client
|
||||
|
||||
# the following fields are accessed by H3Connection
|
||||
self.configuration = QuicConfiguration(is_client=is_client)
|
||||
self._quic_logger = None
|
||||
self._remote_max_datagram_frame_size = 0
|
||||
|
||||
def close(
|
||||
self,
|
||||
error_code: int = QuicErrorCode.NO_ERROR,
|
||||
frame_type: Optional[int] = None,
|
||||
reason_phrase: str = "",
|
||||
) -> None:
|
||||
# we'll get closed if a protocol error occurs in `H3Connection.handle_event`
|
||||
# we note the error on the connection and yield a CloseConnection
|
||||
# this will then call `QuicConnection.close` with the proper values
|
||||
# once the `Http3Connection` receives `ConnectionClosed`, it will send out `ProtocolError`
|
||||
self.pending_commands.append(
|
||||
CloseQuicConnection(self.conn, error_code, frame_type, reason_phrase)
|
||||
)
|
||||
|
||||
def get_next_available_stream_id(self, is_unidirectional: bool = False) -> int:
|
||||
# since we always reserve the ID, we have to "find" the next ID like `QuicConnection` does
|
||||
index = (int(is_unidirectional) << 1) | int(not self._is_client)
|
||||
stream_id = self._next_stream_id[index]
|
||||
self._next_stream_id[index] = stream_id + 4
|
||||
return stream_id
|
||||
|
||||
def reset_stream(self, stream_id: int, error_code: int) -> None:
|
||||
self.pending_commands.append(ResetQuicStream(self.conn, stream_id, error_code))
|
||||
|
||||
def send_stream_data(self, stream_id: int, data: bytes, end_stream: bool = False) -> None:
|
||||
self.pending_commands.append(SendQuicStreamData(self.conn, stream_id, data, end_stream))
|
||||
|
||||
|
||||
class LayeredH3Connection(H3Connection):
|
||||
"""
|
||||
Creates a H3 connection using a fake QUIC connection, which allows layer separation.
|
||||
Also ensures that headers, data and trailers are sent in that order.
|
||||
"""
|
||||
|
||||
def __init__(self, conn: connection.Connection, is_client: bool, enable_webtransport: bool = False) -> None:
|
||||
self._mock = MockQuic(conn, is_client)
|
||||
super().__init__(self._mock, enable_webtransport) # type: ignore
|
||||
|
||||
def _after_send(self, stream_id: int, end_stream: bool) -> None:
|
||||
# if the stream ended, `QuicConnection` has an assert that no further data is being sent
|
||||
# to catch this more early on, we set the header state on the `H3Stream`
|
||||
if end_stream:
|
||||
self._stream[stream_id].headers_send_state = HeadersState.AFTER_TRAILERS
|
||||
|
||||
def _handle_request_or_push_frame(
|
||||
self,
|
||||
frame_type: int,
|
||||
frame_data: Optional[bytes],
|
||||
stream: H3Stream,
|
||||
stream_ended: bool,
|
||||
) -> list[H3Event]:
|
||||
# turn HeadersReceived into TrailersReceived for trailers
|
||||
events = super()._handle_request_or_push_frame(frame_type, frame_data, stream, stream_ended)
|
||||
for index, event in enumerate(events):
|
||||
if (
|
||||
isinstance(event, HeadersReceived)
|
||||
and self._stream[event.stream_id].headers_recv_state == HeadersState.AFTER_TRAILERS
|
||||
):
|
||||
events[index] = TrailersReceived(event.headers, event.stream_id, event.stream_ended, event.push_id)
|
||||
return events
|
||||
|
||||
def close_connection(
|
||||
self,
|
||||
error_code: int = QuicErrorCode.NO_ERROR,
|
||||
frame_type: Optional[int] = None,
|
||||
reason_phrase: str = "",
|
||||
) -> None:
|
||||
"""Closes the underlying QUIC connection and ignores any incoming events."""
|
||||
|
||||
self._is_done = True
|
||||
self._quic.close(error_code, frame_type, reason_phrase)
|
||||
|
||||
def end_stream(self, stream_id: int) -> None:
|
||||
"""Ends the given stream if not already done so."""
|
||||
|
||||
stream = self._get_or_create_stream(stream_id)
|
||||
if stream.headers_send_state != HeadersState.AFTER_TRAILERS:
|
||||
super().send_data(stream_id, b"", end_stream=True)
|
||||
stream.headers_send_state = HeadersState.AFTER_TRAILERS
|
||||
|
||||
def get_next_available_stream_id(self, is_unidirectional: bool = False):
|
||||
"""Reserves and returns the next available stream ID."""
|
||||
|
||||
return self._quic.get_next_available_stream_id(is_unidirectional)
|
||||
|
||||
def get_open_stream_ids(self, push_id: Optional[int]) -> Iterable[int]:
|
||||
"""Iterates over all non-special open streams, optionally for a given push id."""
|
||||
|
||||
return (
|
||||
stream.stream_id
|
||||
for stream in self._stream.values()
|
||||
if (
|
||||
stream.push_id == push_id
|
||||
and stream.stream_type == (
|
||||
None
|
||||
if push_id is None else
|
||||
StreamType.PUSH
|
||||
)
|
||||
and not (
|
||||
stream.headers_recv_state == HeadersState.AFTER_TRAILERS
|
||||
and stream.headers_send_state == HeadersState.AFTER_TRAILERS
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
def handle_connection_closed(self, event: QuicConnectionClosed) -> None:
|
||||
self._is_done = True
|
||||
|
||||
def handle_stream_event(self, event: QuicStreamEvent) -> list[H3Event]:
|
||||
# don't do anything if we're done
|
||||
if self._is_done:
|
||||
return []
|
||||
|
||||
# treat reset events similar to data events with end_stream=True
|
||||
# We can receive multiple reset events as long as the final size does not change.
|
||||
elif isinstance(event, QuicStreamReset):
|
||||
stream = self._get_or_create_stream(event.stream_id)
|
||||
stream.ended = True
|
||||
stream.headers_recv_state = HeadersState.AFTER_TRAILERS
|
||||
return [StreamReset(event.stream_id, event.error_code, stream.push_id)]
|
||||
|
||||
# convert data events from the QUIC layer back to aioquic events
|
||||
elif isinstance(event, QuicStreamDataReceived):
|
||||
if self._get_or_create_stream(event.stream_id).ended:
|
||||
# aioquic will not send us any data events once a stream has ended.
|
||||
# Instead, it will close the connection. We simulate this here for H3 tests.
|
||||
self.close_connection(error_code=QuicErrorCode.PROTOCOL_VIOLATION, reason_phrase="stream already ended")
|
||||
return []
|
||||
else:
|
||||
return self.handle_event(StreamDataReceived(event.data, event.end_stream, event.stream_id))
|
||||
|
||||
# should never happen
|
||||
else: # pragma: no cover
|
||||
raise AssertionError(f"Unexpected event: {event!r}")
|
||||
|
||||
def has_sent_headers(self, stream_id: int) -> bool:
|
||||
"""Indicates whether headers have been sent over the given stream."""
|
||||
|
||||
try:
|
||||
return self._stream[stream_id].headers_send_state != HeadersState.INITIAL
|
||||
except KeyError:
|
||||
return False
|
||||
|
||||
def reset_stream(self, stream_id: int, error_code: int) -> None:
|
||||
"""Resets a stream that hasn't been ended locally yet."""
|
||||
|
||||
# set the header state and queue a reset event
|
||||
stream = self._get_or_create_stream(stream_id)
|
||||
stream.headers_send_state = HeadersState.AFTER_TRAILERS
|
||||
self._quic.reset_stream(stream_id, error_code)
|
||||
|
||||
def send_data(self, stream_id: int, data: bytes, end_stream: bool = False) -> None:
|
||||
"""Sends data over the given stream."""
|
||||
|
||||
super().send_data(stream_id, data, end_stream)
|
||||
self._after_send(stream_id, end_stream)
|
||||
|
||||
def send_datagram(self, flow_id: int, data: bytes) -> None:
|
||||
# supporting datagrams would require additional information from the underlying QUIC connection
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
def send_headers(self, stream_id: int, headers: Headers, end_stream: bool = False) -> None:
|
||||
"""Sends headers over the given stream."""
|
||||
|
||||
# ensure we haven't sent something before
|
||||
stream = self._get_or_create_stream(stream_id)
|
||||
if stream.headers_send_state != HeadersState.INITIAL:
|
||||
raise FrameUnexpected("initial HEADERS frame is not allowed in this state")
|
||||
super().send_headers(stream_id, headers, end_stream)
|
||||
self._after_send(stream_id, end_stream)
|
||||
|
||||
def send_trailers(self, stream_id: int, trailers: Headers) -> None:
|
||||
"""Sends trailers over the given stream and ends it."""
|
||||
|
||||
# ensure we got some headers first
|
||||
stream = self._get_or_create_stream(stream_id)
|
||||
if stream.headers_send_state != HeadersState.AFTER_HEADERS:
|
||||
raise FrameUnexpected("trailing HEADERS frame is not allowed in this state")
|
||||
super().send_headers(stream_id, trailers, end_stream=True)
|
||||
self._after_send(stream_id, end_stream=True)
|
||||
|
||||
def transmit(self) -> layer.CommandGenerator[None]:
|
||||
"""Yields all pending commands for the upper QUIC layer."""
|
||||
|
||||
while self._mock.pending_commands:
|
||||
yield self._mock.pending_commands.pop(0)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"LayeredH3Connection",
|
||||
"StreamReset",
|
||||
"TrailersReceived",
|
||||
]
|
@ -8,7 +8,7 @@ from typing import Callable, Optional
|
||||
from mitmproxy import connection
|
||||
from mitmproxy.proxy import commands, events, layer
|
||||
from mitmproxy.proxy.commands import StartHook
|
||||
from mitmproxy.proxy.layers import dns, tls
|
||||
from mitmproxy.proxy.layers import quic, tls
|
||||
from mitmproxy.proxy.mode_specs import ReverseMode
|
||||
from mitmproxy.proxy.utils import expect
|
||||
|
||||
@ -61,16 +61,18 @@ class ReverseProxy(DestinationKnown):
|
||||
assert isinstance(spec, ReverseMode)
|
||||
self.context.server.address = spec.address
|
||||
|
||||
if spec.scheme == "https" or spec.scheme == "tls" or spec.scheme == "dtls":
|
||||
if spec.scheme in ("http3", "quic"):
|
||||
if not self.context.options.keep_host_header:
|
||||
self.context.server.sni = spec.address[0]
|
||||
self.child_layer = quic.ServerQuicLayer(self.context)
|
||||
elif spec.scheme in ("https", "tls", "dtls"):
|
||||
if not self.context.options.keep_host_header:
|
||||
self.context.server.sni = spec.address[0]
|
||||
self.child_layer = tls.ServerTLSLayer(self.context)
|
||||
elif spec.scheme == "http" or spec.scheme == "tcp" or spec.scheme == "udp":
|
||||
elif spec.scheme in ("tcp", "http", "udp", "dns"):
|
||||
self.child_layer = layer.NextLayer(self.context)
|
||||
elif spec.scheme == "dns":
|
||||
self.child_layer = dns.DNSLayer(self.context)
|
||||
else:
|
||||
raise AssertionError(self.context.client.transport_protocol) # pragma: no cover
|
||||
raise AssertionError(spec.scheme) # pragma: no cover
|
||||
|
||||
err = yield from self.finish_start()
|
||||
if err:
|
||||
|
1125
mitmproxy/proxy/layers/quic.py
Normal file
1125
mitmproxy/proxy/layers/quic.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -123,16 +123,16 @@ class TCPLayer(layer.Layer):
|
||||
or (self.context.server.state & ConnectionState.CAN_READ)
|
||||
)
|
||||
if all_done:
|
||||
self._handle_event = self.done
|
||||
if self.context.server.state is not ConnectionState.CLOSED:
|
||||
yield commands.CloseConnection(self.context.server)
|
||||
if self.context.client.state is not ConnectionState.CLOSED:
|
||||
yield commands.CloseConnection(self.context.client)
|
||||
self._handle_event = self.done
|
||||
if self.flow:
|
||||
yield TcpEndHook(self.flow)
|
||||
self.flow.live = False
|
||||
else:
|
||||
yield commands.CloseConnection(send_to, half_close=True)
|
||||
yield commands.CloseTcpConnection(send_to, half_close=True)
|
||||
else:
|
||||
raise AssertionError(f"Unexpected event: {event}")
|
||||
|
||||
|
@ -440,9 +440,9 @@ class TLSLayer(tunnel.TunnelLayer):
|
||||
pass
|
||||
yield from self.tls_interact()
|
||||
|
||||
def send_close(self, half_close: bool) -> layer.CommandGenerator[None]:
|
||||
def send_close(self, command: commands.CloseConnection) -> layer.CommandGenerator[None]:
|
||||
# We should probably shutdown the TLS connection properly here.
|
||||
yield from super().send_close(half_close)
|
||||
yield from super().send_close(command)
|
||||
|
||||
|
||||
class ServerTLSLayer(TLSLayer):
|
||||
@ -659,7 +659,7 @@ class ClientTLSLayer(TLSLayer):
|
||||
|
||||
def errored(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
if self.debug is not None:
|
||||
yield commands.Log(f"Swallowing {event} as handshake failed.", DEBUG)
|
||||
yield commands.Log(f"{self.debug}[tls] Swallowing {event} as handshake failed.", DEBUG)
|
||||
|
||||
|
||||
class MockTLSLayer(TLSLayer):
|
||||
|
@ -429,3 +429,8 @@ class Socks5Instance(AsyncioServerInstance[mode_specs.Socks5Mode]):
|
||||
class DnsInstance(AsyncioServerInstance[mode_specs.DnsMode]):
|
||||
def make_top_layer(self, context: Context) -> Layer:
|
||||
return layers.DNSLayer(context)
|
||||
|
||||
|
||||
class Http3Instance(AsyncioServerInstance[mode_specs.Http3Mode]):
|
||||
def make_top_layer(self, context: Context) -> Layer:
|
||||
return layers.modes.HttpProxy(context)
|
||||
|
@ -211,13 +211,13 @@ class ReverseMode(ProxyMode):
|
||||
"""A reverse proxy. This acts like a normal server, but redirects all requests to a fixed target."""
|
||||
description = "reverse proxy"
|
||||
transport_protocol = TCP
|
||||
scheme: Literal["http", "https", "tls", "dtls", "tcp", "udp", "dns"]
|
||||
scheme: Literal["http", "https", "http3", "tls", "dtls", "tcp", "udp", "dns", "quic"]
|
||||
address: tuple[str, int]
|
||||
|
||||
# noinspection PyDataclass
|
||||
def __post_init__(self) -> None:
|
||||
self.scheme, self.address = server_spec.parse(self.data, default_scheme="https")
|
||||
if self.scheme in ("dns", "dtls", "udp"):
|
||||
if self.scheme in ("http3", "dtls", "udp", "dns", "quic"):
|
||||
self.transport_protocol = UDP
|
||||
self.description = f"{self.description} to {self.data}"
|
||||
|
||||
@ -248,6 +248,18 @@ class DnsMode(ProxyMode):
|
||||
_check_empty(self.data)
|
||||
|
||||
|
||||
class Http3Mode(ProxyMode):
|
||||
"""
|
||||
A regular HTTP3 proxy that is interfaced with absolute-form HTTP requests.
|
||||
(This class will be merged into `RegularMode` once the UDP implementation is deemed stable enough.)
|
||||
"""
|
||||
description = "HTTP3 proxy"
|
||||
transport_protocol = UDP
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
_check_empty(self.data)
|
||||
|
||||
|
||||
class WireGuardMode(ProxyMode):
|
||||
"""Proxy Server based on WireGuard"""
|
||||
description = "WireGuard server"
|
||||
|
@ -369,8 +369,10 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||
assert writer
|
||||
if not writer.is_closing():
|
||||
writer.write(command.data)
|
||||
elif isinstance(command, commands.CloseConnection):
|
||||
elif isinstance(command, commands.CloseTcpConnection):
|
||||
self.close_connection(command.connection, command.half_close)
|
||||
elif isinstance(command, commands.CloseConnection):
|
||||
self.close_connection(command.connection, False)
|
||||
elif isinstance(command, commands.StartHook):
|
||||
asyncio_utils.create_task(
|
||||
self.hook_task(command),
|
||||
@ -423,6 +425,7 @@ class LiveConnectionHandler(ConnectionHandler, metaclass=abc.ABCMeta):
|
||||
sockname=writer.get_extra_info("sockname"),
|
||||
timestamp_start=time.time(),
|
||||
proxy_mode=mode,
|
||||
state=ConnectionState.OPEN,
|
||||
)
|
||||
context = Context(client, options)
|
||||
super().__init__(context)
|
||||
|
@ -108,6 +108,35 @@ class TunnelLayer(layer.Layer):
|
||||
yield from self.event_to_child(evt)
|
||||
self._event_queue.clear()
|
||||
|
||||
def _handle_command(self, command: commands.Command) -> layer.CommandGenerator[None]:
|
||||
if (
|
||||
isinstance(command, commands.ConnectionCommand)
|
||||
and command.connection == self.conn
|
||||
):
|
||||
if isinstance(command, commands.SendData):
|
||||
yield from self.send_data(command.data)
|
||||
elif isinstance(command, commands.CloseConnection):
|
||||
if self.conn != self.tunnel_connection:
|
||||
self.conn.state &= ~connection.ConnectionState.CAN_WRITE
|
||||
command.connection = self.tunnel_connection
|
||||
yield from self.send_close(command)
|
||||
elif isinstance(command, commands.OpenConnection):
|
||||
# create our own OpenConnection command object that blocks here.
|
||||
self.command_to_reply_to = command
|
||||
self.tunnel_state = TunnelState.ESTABLISHING
|
||||
err = yield commands.OpenConnection(self.tunnel_connection)
|
||||
if err:
|
||||
yield from self.event_to_child(
|
||||
events.OpenConnectionCompleted(command, err)
|
||||
)
|
||||
self.tunnel_state = TunnelState.CLOSED
|
||||
else:
|
||||
yield from self.start_handshake()
|
||||
else: # pragma: no cover
|
||||
raise AssertionError(f"Unexpected command: {command}")
|
||||
else:
|
||||
yield command
|
||||
|
||||
def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
if (
|
||||
self.tunnel_state is TunnelState.ESTABLISHING
|
||||
@ -116,35 +145,7 @@ class TunnelLayer(layer.Layer):
|
||||
self._event_queue.append(event)
|
||||
return
|
||||
for command in self.child_layer.handle_event(event):
|
||||
if (
|
||||
isinstance(command, commands.ConnectionCommand)
|
||||
and command.connection == self.conn
|
||||
):
|
||||
if isinstance(command, commands.SendData):
|
||||
yield from self.send_data(command.data)
|
||||
elif isinstance(command, commands.CloseConnection):
|
||||
if self.conn != self.tunnel_connection:
|
||||
if command.half_close:
|
||||
self.conn.state &= ~connection.ConnectionState.CAN_WRITE
|
||||
else:
|
||||
self.conn.state = connection.ConnectionState.CLOSED
|
||||
yield from self.send_close(command.half_close)
|
||||
elif isinstance(command, commands.OpenConnection):
|
||||
# create our own OpenConnection command object that blocks here.
|
||||
self.command_to_reply_to = command
|
||||
self.tunnel_state = TunnelState.ESTABLISHING
|
||||
err = yield commands.OpenConnection(self.tunnel_connection)
|
||||
if err:
|
||||
yield from self.event_to_child(
|
||||
events.OpenConnectionCompleted(command, err)
|
||||
)
|
||||
self.tunnel_state = TunnelState.CLOSED
|
||||
else:
|
||||
yield from self.start_handshake()
|
||||
else: # pragma: no cover
|
||||
raise AssertionError(f"Unexpected command: {command}")
|
||||
else:
|
||||
yield command
|
||||
yield from self._handle_command(command)
|
||||
|
||||
def start_handshake(self) -> layer.CommandGenerator[None]:
|
||||
yield from self._handle_event(events.DataReceived(self.tunnel_connection, b""))
|
||||
@ -169,8 +170,8 @@ class TunnelLayer(layer.Layer):
|
||||
def send_data(self, data: bytes) -> layer.CommandGenerator[None]:
|
||||
yield commands.SendData(self.tunnel_connection, data)
|
||||
|
||||
def send_close(self, half_close: bool) -> layer.CommandGenerator[None]:
|
||||
yield commands.CloseConnection(self.tunnel_connection, half_close=half_close)
|
||||
def send_close(self, command: commands.CloseConnection) -> layer.CommandGenerator[None]:
|
||||
yield command
|
||||
|
||||
|
||||
class LayerStack:
|
||||
|
@ -119,6 +119,7 @@ SCHEME_STYLES = {
|
||||
"tcp": "scheme_tcp",
|
||||
"udp": "scheme_udp",
|
||||
"dns": "scheme_dns",
|
||||
"quic": "scheme_quic",
|
||||
}
|
||||
HTTP_REQUEST_METHOD_STYLES = {
|
||||
"GET": "method_get",
|
||||
@ -763,12 +764,16 @@ def format_flow(
|
||||
duration = f.messages[-1].timestamp - f.client_conn.timestamp_start
|
||||
else:
|
||||
duration = None
|
||||
if f.client_conn.tls_version == "QUIC":
|
||||
protocol = "quic"
|
||||
else:
|
||||
protocol = f.type
|
||||
return format_message_flow(
|
||||
render_mode=render_mode,
|
||||
focused=focused,
|
||||
timestamp_start=f.client_conn.timestamp_start,
|
||||
marked=f.marked,
|
||||
protocol=f.type,
|
||||
protocol=protocol,
|
||||
client_address=f.client_conn.peername,
|
||||
server_address=f.server_conn.address,
|
||||
total_size=total_size,
|
||||
|
@ -40,6 +40,7 @@ class Palette:
|
||||
"scheme_tcp",
|
||||
"scheme_udp",
|
||||
"scheme_dns",
|
||||
"scheme_quic",
|
||||
"scheme_other",
|
||||
"url_punctuation",
|
||||
"url_domain",
|
||||
@ -180,6 +181,7 @@ class LowDark(Palette):
|
||||
scheme_tcp=("dark magenta", "default"),
|
||||
scheme_udp=("dark magenta", "default"),
|
||||
scheme_dns=("dark blue", "default"),
|
||||
scheme_quic=("brown", "default"),
|
||||
scheme_other=("dark magenta", "default"),
|
||||
url_punctuation=("light gray", "default"),
|
||||
url_domain=("white", "default"),
|
||||
@ -280,6 +282,7 @@ class LowLight(Palette):
|
||||
scheme_tcp=("light magenta", "default"),
|
||||
scheme_udp=("light magenta", "default"),
|
||||
scheme_dns=("light blue", "default"),
|
||||
scheme_quic=("brown", "default"),
|
||||
scheme_other=("light magenta", "default"),
|
||||
url_punctuation=("dark gray", "default"),
|
||||
url_domain=("dark gray", "default"),
|
||||
@ -401,6 +404,7 @@ class SolarizedLight(LowLight):
|
||||
scheme_tcp=("light magenta", "default"),
|
||||
scheme_udp=("light magenta", "default"),
|
||||
scheme_dns=("light blue", "default"),
|
||||
scheme_quic=(sol_orange, "default"),
|
||||
scheme_other=("light magenta", "default"),
|
||||
url_punctuation=("dark gray", "default"),
|
||||
url_domain=("dark gray", "default"),
|
||||
|
@ -55,6 +55,7 @@ def run(
|
||||
logging.getLogger("tornado").setLevel(logging.WARNING)
|
||||
logging.getLogger("asyncio").setLevel(logging.WARNING)
|
||||
logging.getLogger("hpack").setLevel(logging.WARNING)
|
||||
logging.getLogger("quic").setLevel(logging.WARNING) # aioquic uses a different prefix...
|
||||
debug.register_info_dumpers()
|
||||
|
||||
opts = options.Options()
|
||||
|
2
mitmproxy/tools/web/static/app.css
vendored
2
mitmproxy/tools/web/static/app.css
vendored
File diff suppressed because one or more lines are too long
50
mitmproxy/tools/web/static/app.js
vendored
50
mitmproxy/tools/web/static/app.js
vendored
File diff suppressed because one or more lines are too long
BIN
mitmproxy/tools/web/static/images/resourceQuicIcon.png
vendored
Normal file
BIN
mitmproxy/tools/web/static/images/resourceQuicIcon.png
vendored
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.3 KiB |
@ -7,7 +7,7 @@ MITMPROXY = "mitmproxy " + VERSION
|
||||
|
||||
# Serialization format version. This is displayed nowhere, it just needs to be incremented by one
|
||||
# for each change in the file format.
|
||||
FLOW_FORMAT_VERSION = 19
|
||||
FLOW_FORMAT_VERSION = 20
|
||||
|
||||
|
||||
def get_dev_version() -> str:
|
||||
|
@ -57,7 +57,6 @@ exclude =
|
||||
mitmproxy/connections.py
|
||||
mitmproxy/contentviews/base.py
|
||||
mitmproxy/contentviews/grpc.py
|
||||
mitmproxy/contentviews/http3.py
|
||||
mitmproxy/ctx.py
|
||||
mitmproxy/exceptions.py
|
||||
mitmproxy/flow.py
|
||||
|
1
setup.py
1
setup.py
@ -73,6 +73,7 @@ setup(
|
||||
# https://packaging.python.org/en/latest/discussions/install-requires-vs-requirements/#install-requires
|
||||
# It is not considered best practice to use install_requires to pin dependencies to specific versions.
|
||||
install_requires=[
|
||||
"aioquic_mitmproxy>=0.9.20,<0.10",
|
||||
"asgiref>=3.2.10,<3.6",
|
||||
"Brotli>=1.0,<1.1",
|
||||
"certifi>=2019.9.11", # no semver here - this should always be on the last release!
|
||||
|
@ -7,6 +7,7 @@ import pytest
|
||||
from mitmproxy import exceptions
|
||||
from mitmproxy.addons import dumper
|
||||
from mitmproxy.http import Headers
|
||||
from mitmproxy.net.dns import response_codes
|
||||
from mitmproxy.test import taddons
|
||||
from mitmproxy.test import tflow
|
||||
from mitmproxy.test import tutils
|
||||
@ -226,6 +227,12 @@ def test_dns():
|
||||
assert "8.8.8.8" in sio.getvalue()
|
||||
sio.truncate(0)
|
||||
|
||||
f = tflow.tdnsflow()
|
||||
f.response = f.request.fail(response_codes.NOTIMP)
|
||||
d.dns_response(f)
|
||||
assert "NOTIMP" in sio.getvalue()
|
||||
sio.truncate(0)
|
||||
|
||||
f = tflow.tdnsflow(err=True)
|
||||
d.dns_error(f)
|
||||
assert "error" in sio.getvalue()
|
||||
@ -279,6 +286,16 @@ def test_http2():
|
||||
assert "HTTP/2.0 200 OK" in sio.getvalue()
|
||||
|
||||
|
||||
def test_quic():
|
||||
sio = io.StringIO()
|
||||
d = dumper.Dumper(sio)
|
||||
with taddons.context(d):
|
||||
f = tflow.ttcpflow()
|
||||
f.client_conn.tls_version = "QUIC"
|
||||
d.tcp_message(f)
|
||||
assert "quic/tcp" in sio.getvalue()
|
||||
|
||||
|
||||
def test_styling():
|
||||
sio = io.StringIO()
|
||||
|
||||
|
@ -46,6 +46,39 @@ dtls_client_hello_with_extensions = bytes.fromhex(
|
||||
)
|
||||
|
||||
|
||||
quic_client_hello = bytes.fromhex(
|
||||
"ca0000000108c0618c84b54541320823fcce946c38d8210044e6a93bbb283593f75ffb6f2696b16cfdcb5b1255"
|
||||
"577b2af5fc5894188c9568bc65eef253faf7f0520e41341cfa81d6aae573586665ce4e1e41676364820402feec"
|
||||
"a81f3d22dbb476893422069066104a43e121c951a08c53b83f960becf99cf5304d5bc5346f52f472bd1a04d192"
|
||||
"0bae025064990d27e5e4c325ac46121d3acadebe7babdb96192fb699693d65e2b2e21c53beeb4f40b50673a2f6"
|
||||
"c22091cb7c76a845384fedee58df862464d1da505a280bfef91ca83a10bebbcb07855219dbc14aecf8a48da049"
|
||||
"d03c77459b39d5355c95306cd03d6bdb471694fa998ca3b1f875ce87915b88ead15c5d6313a443f39aad808922"
|
||||
"57ddfa6b4a898d773bb6fb520ede47ebd59d022431b1054a69e0bbbdf9f0fb32fc8bcc4b6879dd8cd5389474b1"
|
||||
"99e18333e14d0347740a11916429a818bb8d93295d36e99840a373bb0e14c8b3adcf5e2165e70803f15316fd5e"
|
||||
"5eeec04ae68d98f1adb22c54611c80fcd8ece619dbdf97b1510032ec374b7a71f94d9492b8b8cb56f56556dd97"
|
||||
"edf1e50fa90e868ff93636a365678bdf3ee3f8e632588cd506b6f44fbfd4d99988238fbd5884c98f6a124108c1"
|
||||
"878970780e42b111e3be6215776ef5be5a0205915e6d720d22c6a81a475c9e41ba94e4983b964cb5c8e1f40607"
|
||||
"76d1d8d1adcef7587ea084231016bd6ee2643d11a3a35eb7fe4cca2b3f1a4b21e040b0d426412cca6c4271ea63"
|
||||
"fb54ed7f57b41cd1af1be5507f87ea4f4a0c997367e883291de2f1b8a49bdaa52bae30064351b1139703400730"
|
||||
"18a4104344ec6b4454b50a42e804bc70e78b9b3c82497273859c82ed241b643642d76df6ceab8f916392113a62"
|
||||
"b231f228c7300624d74a846bec2f479ab8a8c3461f91c7bf806236e3bd2f54ba1ef8e2a1e0bfdde0c5ad227f7d"
|
||||
"364c52510b1ade862ce0c8d7bd24b6d7d21c99b34de6d177eb3d575787b2af55060d76d6c2060befbb7953a816"
|
||||
"6f66ad88ecf929dbb0ad3a16cf7dfd39d925e0b4b649c6d0c07ad46ed0229c17fb6a1395f16e1b138aab3af760"
|
||||
"2b0ac762c4f611f7f3468997224ffbe500a7c53f92f65e41a3765a9f1d7e3f78208f5b4e147962d8c97d6c1a80"
|
||||
"91ffc36090b2043d71853616f34c2185dc883c54ab6d66e10a6c18e0b9a4742597361f8554a42da3373241d0c8"
|
||||
"54119bfadccffaf2335b2d97ffee627cb891bda8140a39399f853da4859f7e19682e152243efbaffb662edd19b"
|
||||
"3819a74107c7dbe05ecb32e79dcdb1260f153b1ef133e978ccca3d9e400a7ed6c458d77e2956d2cb897b7a298b"
|
||||
"fe144b5defdc23dfd2adf69f1fb0917840703402d524987ae3b1dcb85229843c9a419ef46e1ba0ba7783f2a2ec"
|
||||
"d057a57518836aef2a7839ebd3688da98b54c942941f642e434727108d59ea25875b3050ca53d4637c76cbcbb9"
|
||||
"e972c2b0b781131ee0a1403138b55486fe86bbd644920ee6aa578e3bab32d7d784b5c140295286d90c99b14823"
|
||||
"1487f7ea64157001b745aa358c9ea6bec5a8d8b67a7534ec1f7648ff3b435911dfc3dff798d32fbf2efe2c1fcc"
|
||||
"278865157590572387b76b78e727d3e7682cb501cdcdf9a0f17676f99d9aa67f10edccc9a92080294e88bf28c2"
|
||||
"a9f32ae535fdb27fff7706540472abb9eab90af12b2bea005da189874b0ca69e6ae1690a6f2adf75be3853c94e"
|
||||
"fd8098ed579c20cb37be6885d8d713af4ba52958cee383089b98ed9cb26e11127cf88d1b7d254f15f7903dd7ed"
|
||||
"297c0013924e88248684fe8f2098326ce51aa6e5"
|
||||
)
|
||||
|
||||
|
||||
class TestNextLayer:
|
||||
def test_configure(self):
|
||||
nl = NextLayer()
|
||||
@ -147,49 +180,210 @@ class TestNextLayer:
|
||||
assert isinstance(nl._next_layer(ctx, b"GET /foo", b""), layers.HttpLayer)
|
||||
assert isinstance(nl._next_layer(ctx, b"", b"hello"), layers.TCPLayer)
|
||||
|
||||
def test_next_layer_udp(self):
|
||||
@pytest.mark.parametrize(
|
||||
("protocol", "client_layer", "server_layer"),
|
||||
[
|
||||
("dtls", layers.ClientTLSLayer, layers.ServerTLSLayer),
|
||||
("quic", layers.ClientQuicLayer, layers.ServerQuicLayer),
|
||||
]
|
||||
)
|
||||
def test_next_layer_udp(
|
||||
self,
|
||||
protocol: str,
|
||||
client_layer: layer.Layer,
|
||||
server_layer: layer.Layer,
|
||||
):
|
||||
def is_ignored_udp(layer: Optional[layer.Layer]):
|
||||
return isinstance(layer, layers.UDPLayer) and layer.flow is None
|
||||
|
||||
def is_intercepted_udp(layer: Optional[layer.Layer]):
|
||||
return isinstance(layer, layers.UDPLayer) and layer.flow is not None
|
||||
|
||||
def is_http(layer: Optional[layer.Layer], mode: HTTPMode):
|
||||
return (
|
||||
isinstance(layer, layers.HttpLayer)
|
||||
and layer.mode is mode
|
||||
)
|
||||
|
||||
client_hello = {
|
||||
"dtls": dtls_client_hello_with_extensions,
|
||||
"quic": quic_client_hello,
|
||||
}[protocol]
|
||||
nl = NextLayer()
|
||||
ctx = MagicMock()
|
||||
ctx.client.alpn = None
|
||||
ctx.server.address = ("example.com", 443)
|
||||
ctx.client.transport_protocol = "udp"
|
||||
with taddons.context(nl) as tctx:
|
||||
ctx.layers = [layers.modes.HttpProxy(ctx)]
|
||||
tctx.configure(nl, rawudp=False)
|
||||
assert is_ignored_udp(nl._next_layer(ctx, b"", b""))
|
||||
ctx.layers = [layers.modes.HttpProxy(ctx), client_layer(ctx)]
|
||||
assert is_http(nl._next_layer(ctx, b"", b""), HTTPMode.regular)
|
||||
|
||||
ctx.layers = [layers.modes.HttpProxy(ctx)]
|
||||
tctx.configure(nl, rawudp=True)
|
||||
assert is_intercepted_udp(nl._next_layer(ctx, b"", b""))
|
||||
ctx.layers = [layers.modes.HttpUpstreamProxy(ctx), client_layer(ctx)]
|
||||
assert is_http(nl._next_layer(ctx, b"", b""), HTTPMode.upstream)
|
||||
|
||||
ctx.layers = [layers.modes.HttpProxy(ctx)]
|
||||
ctx.layers = [layers.modes.TransparentProxy(ctx)]
|
||||
is_intercepted_udp(nl._next_layer(ctx, b"", b""))
|
||||
|
||||
ctx.layers = [layers.modes.TransparentProxy(ctx)]
|
||||
ctx.server.address = ("nomatch.com", 443)
|
||||
tctx.configure(nl, ignore_hosts=["example.com"])
|
||||
assert is_intercepted_udp(nl._next_layer(ctx, dtls_client_hello_with_extensions[:50], b""))
|
||||
assert is_ignored_udp(nl._next_layer(ctx, dtls_client_hello_with_extensions, b""))
|
||||
assert is_intercepted_udp(nl._next_layer(ctx, client_hello[:50], b""))
|
||||
assert is_ignored_udp(nl._next_layer(ctx, client_hello, b""))
|
||||
|
||||
ctx.layers = [layers.modes.HttpProxy(ctx)]
|
||||
ctx.layers = [layers.modes.TransparentProxy(ctx)]
|
||||
ctx.server.address = ("example.com", 443)
|
||||
assert is_ignored_udp(nl._next_layer(ctx, dtls_client_hello_with_extensions[:50], b""))
|
||||
assert is_ignored_udp(nl._next_layer(ctx, client_hello[:50], b""))
|
||||
|
||||
ctx.layers = [layers.modes.HttpProxy(ctx)]
|
||||
ctx.layers = [layers.modes.TransparentProxy(ctx)]
|
||||
tctx.configure(nl, ignore_hosts=[])
|
||||
assert isinstance(nl._next_layer(ctx, dtls_client_hello_with_extensions, b""), layers.ClientTLSLayer)
|
||||
decision = nl._next_layer(ctx, client_hello, b"")
|
||||
assert isinstance(decision, server_layer)
|
||||
assert isinstance(decision.child_layer, client_layer)
|
||||
|
||||
ctx.layers = [layers.modes.HttpProxy(ctx)]
|
||||
ctx.layers = [layers.modes.ReverseProxy(ctx), server_layer(ctx)]
|
||||
tctx.configure(nl, ignore_hosts=[])
|
||||
assert isinstance(nl._next_layer(ctx, client_hello, b""), client_layer)
|
||||
|
||||
ctx.layers = [layers.modes.TransparentProxy(ctx)]
|
||||
tctx.configure(nl, udp_hosts=["example.com"])
|
||||
assert isinstance(nl._next_layer(ctx, tflow.tdnsreq().packed, b""), layers.UDPLayer)
|
||||
|
||||
ctx.layers = [layers.modes.HttpProxy(ctx)]
|
||||
ctx.layers = [layers.modes.TransparentProxy(ctx)]
|
||||
tctx.configure(nl, udp_hosts=[])
|
||||
assert isinstance(nl._next_layer(ctx, tflow.tdnsreq().packed, b""), layers.DNSLayer)
|
||||
|
||||
def test_next_layer_reverse_raw(self):
|
||||
nl = NextLayer()
|
||||
with taddons.context(nl):
|
||||
ctx = MagicMock()
|
||||
ctx.client.alpn = None
|
||||
ctx.server.address = ("example.com", 443)
|
||||
ctx.client.transport_protocol = "udp"
|
||||
with taddons.context(nl) as tctx:
|
||||
tctx.configure(nl, ignore_hosts=["example.com"])
|
||||
|
||||
ctx.layers = [
|
||||
layers.modes.HttpProxy(ctx),
|
||||
layers.ClientQuicLayer(ctx),
|
||||
]
|
||||
decision = nl._next_layer(ctx, b"", b"")
|
||||
assert isinstance(decision, layers.ServerQuicLayer)
|
||||
assert isinstance(decision.child_layer, layers.RawQuicLayer)
|
||||
|
||||
ctx.layers = [
|
||||
layers.modes.ReverseProxy(ctx),
|
||||
layers.ServerQuicLayer(ctx),
|
||||
layers.ClientQuicLayer(ctx,),
|
||||
]
|
||||
assert isinstance(nl._next_layer(ctx, b"", b""), layers.RawQuicLayer)
|
||||
|
||||
ctx.layers = [
|
||||
layers.modes.ReverseProxy(ctx),
|
||||
layers.ServerQuicLayer(ctx),
|
||||
]
|
||||
decision = nl._next_layer(ctx, b"", b"")
|
||||
assert isinstance(decision, layers.ClientQuicLayer)
|
||||
assert isinstance(decision.child_layer, layers.RawQuicLayer)
|
||||
|
||||
tctx.configure(nl, ignore_hosts=[])
|
||||
|
||||
def test_next_layer_reverse_quic_mode(self):
|
||||
nl = NextLayer()
|
||||
with taddons.context(nl):
|
||||
ctx = MagicMock()
|
||||
ctx.client.alpn = None
|
||||
ctx.server.address = ("example.com", 443)
|
||||
ctx.client.transport_protocol = "udp"
|
||||
ctx.client.proxy_mode.scheme = "quic"
|
||||
ctx.layers = [
|
||||
layers.modes.ReverseProxy(ctx),
|
||||
layers.ServerQuicLayer(ctx),
|
||||
layers.ClientQuicLayer(ctx),
|
||||
]
|
||||
assert isinstance(nl._next_layer(ctx, b"", b""), layers.RawQuicLayer)
|
||||
ctx.layers = [
|
||||
layers.modes.ReverseProxy(ctx),
|
||||
layers.ServerQuicLayer(ctx),
|
||||
]
|
||||
assert nl._next_layer(ctx, b"", b"") is None
|
||||
assert isinstance(nl._next_layer(ctx, b"notahandshake", b""), layers.UDPLayer)
|
||||
ctx.layers = [
|
||||
layers.modes.ReverseProxy(ctx),
|
||||
layers.ServerQuicLayer(ctx),
|
||||
]
|
||||
assert isinstance(nl._next_layer(ctx, quic_client_hello, b""), layers.ClientQuicLayer)
|
||||
|
||||
def test_next_layer_reverse_http3_mode(self):
|
||||
nl = NextLayer()
|
||||
with taddons.context(nl):
|
||||
ctx = MagicMock()
|
||||
ctx.client.alpn = None
|
||||
ctx.server.address = ("example.com", 443)
|
||||
ctx.client.transport_protocol = "udp"
|
||||
ctx.client.proxy_mode.scheme = "http3"
|
||||
ctx.layers = [
|
||||
layers.modes.ReverseProxy(ctx),
|
||||
layers.ServerQuicLayer(ctx),
|
||||
]
|
||||
assert isinstance(nl._next_layer(ctx, b"notahandshakebutignore", b""), layers.ClientQuicLayer)
|
||||
assert len(ctx.layers) == 3
|
||||
decision = nl._next_layer(ctx, b"", b"")
|
||||
assert isinstance(decision, layers.HttpLayer)
|
||||
assert decision.mode is HTTPMode.transparent
|
||||
|
||||
def test_next_layer_reverse_invalid_mode(self):
|
||||
nl = NextLayer()
|
||||
ctx = MagicMock()
|
||||
ctx.client.alpn = None
|
||||
ctx.server.address = ("example.com", 443)
|
||||
ctx.client.transport_protocol = "udp"
|
||||
ctx.client.proxy_mode.scheme = "invalidscheme"
|
||||
ctx.layers = [layers.modes.ReverseProxy(ctx)]
|
||||
with pytest.raises(AssertionError, match="invalidscheme"):
|
||||
nl._next_layer(ctx, b"", b"")
|
||||
|
||||
def test_next_layer_reverse_dtls_mode(self):
|
||||
nl = NextLayer()
|
||||
ctx = MagicMock()
|
||||
ctx.client.alpn = None
|
||||
ctx.server.address = ("example.com", 443)
|
||||
ctx.client.transport_protocol = "udp"
|
||||
ctx.client.proxy_mode.scheme = "dtls"
|
||||
ctx.layers = [layers.modes.ReverseProxy(ctx), layers.ServerTLSLayer(ctx)]
|
||||
assert isinstance(nl._next_layer(ctx, b"", b""), layers.UDPLayer)
|
||||
ctx.layers = [layers.modes.ReverseProxy(ctx), layers.ServerTLSLayer(ctx)]
|
||||
assert isinstance(nl._next_layer(ctx, dtls_client_hello_with_extensions, b""), layers.ClientTLSLayer)
|
||||
assert len(ctx.layers) == 3
|
||||
assert isinstance(nl._next_layer(ctx, b"", b""), layers.UDPLayer)
|
||||
|
||||
def test_next_layer_reverse_udp_mode(self):
|
||||
nl = NextLayer()
|
||||
ctx = MagicMock()
|
||||
ctx.client.alpn = None
|
||||
ctx.server.address = ("example.com", 443)
|
||||
ctx.client.transport_protocol = "udp"
|
||||
ctx.client.proxy_mode.scheme = "udp"
|
||||
ctx.layers = [layers.modes.ReverseProxy(ctx)]
|
||||
assert isinstance(nl._next_layer(ctx, b"", b""), layers.UDPLayer)
|
||||
ctx.layers = [layers.modes.ReverseProxy(ctx)]
|
||||
assert isinstance(nl._next_layer(ctx, dtls_client_hello_with_extensions, b""), layers.ClientTLSLayer)
|
||||
assert len(ctx.layers) == 2
|
||||
assert isinstance(nl._next_layer(ctx, b"", b""), layers.UDPLayer)
|
||||
|
||||
def test_next_layer_reverse_dns_mode(self):
|
||||
nl = NextLayer()
|
||||
ctx = MagicMock()
|
||||
ctx.client.alpn = None
|
||||
ctx.server.address = ("example.com", 443)
|
||||
ctx.client.transport_protocol = "udp"
|
||||
ctx.client.proxy_mode.scheme = "dns"
|
||||
ctx.layers = [layers.modes.ReverseProxy(ctx)]
|
||||
assert isinstance(nl._next_layer(ctx, b"", b""), layers.DNSLayer)
|
||||
ctx.layers = [layers.modes.ReverseProxy(ctx)]
|
||||
assert isinstance(nl._next_layer(ctx, dtls_client_hello_with_extensions, b""), layers.ClientTLSLayer)
|
||||
assert len(ctx.layers) == 2
|
||||
assert isinstance(nl._next_layer(ctx, b"", b""), layers.DNSLayer)
|
||||
|
||||
def test_next_layer_invalid_proto(self):
|
||||
nl = NextLayer()
|
||||
ctx = MagicMock()
|
||||
|
@ -1,9 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
import socket
|
||||
import ssl
|
||||
from typing import Any, AsyncGenerator, Callable, ClassVar, Optional, TypeVar
|
||||
from unittest.mock import Mock
|
||||
|
||||
from aioquic.asyncio.protocol import QuicConnectionProtocol
|
||||
from aioquic.asyncio.server import QuicServer
|
||||
from aioquic.h3 import events as h3_events
|
||||
from aioquic.h3.connection import H3Connection, FrameUnexpected
|
||||
from aioquic.quic import events as quic_events
|
||||
from aioquic.quic.configuration import QuicConfiguration
|
||||
from aioquic.quic.connection import QuicConnection, QuicConnectionError
|
||||
import pytest
|
||||
from mitmproxy.addons.next_layer import NextLayer
|
||||
from mitmproxy.addons.tlsconfig import TlsConfig
|
||||
|
||||
import mitmproxy.platform
|
||||
from mitmproxy import dns, exceptions
|
||||
@ -17,6 +31,10 @@ from mitmproxy.proxy.layers.http import HTTPMode
|
||||
from mitmproxy.test import taddons, tflow
|
||||
from mitmproxy.test.tflow import tclient_conn, tserver_conn
|
||||
from mitmproxy.test.tutils import tdnsreq
|
||||
from mitmproxy.utils import data
|
||||
|
||||
|
||||
tlsdata = data.Data(__name__)
|
||||
|
||||
|
||||
class HelperAddon:
|
||||
@ -146,6 +164,11 @@ async def test_inject_fail(caplog) -> None:
|
||||
ps.inject_tcp(tflow.tflow(), True, b"test")
|
||||
assert "Cannot inject TCP messages into non-TCP flows." in caplog.text
|
||||
|
||||
ps.inject_udp(tflow.tflow(), True, b"test")
|
||||
assert "Cannot inject UDP messages into non-UDP flows." in caplog.text
|
||||
ps.inject_udp(tflow.tudpflow(), True, b"test")
|
||||
assert "Flow is not from a live connection." in caplog.text
|
||||
|
||||
ps.inject_websocket(tflow.twebsocketflow(), True, b"test")
|
||||
assert "Flow is not from a live connection." in caplog.text
|
||||
ps.inject_websocket(tflow.ttcpflow(), True, b"test")
|
||||
@ -347,3 +370,479 @@ async def test_dtls(monkeypatch, caplog_async) -> None:
|
||||
assert len(ps.connections) == 1
|
||||
tctx.configure(ps, server=False)
|
||||
await caplog_async.await_log("Stopped reverse proxy to dtls")
|
||||
|
||||
|
||||
class H3EchoServer(QuicConnectionProtocol):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self._seen_headers: set[int] = set()
|
||||
self.http: Optional[H3Connection] = None
|
||||
|
||||
def http_headers_received(self, event: h3_events.HeadersReceived) -> None:
|
||||
assert event.push_id is None
|
||||
headers: dict[bytes, bytes] = {}
|
||||
for name, value in event.headers:
|
||||
headers[name] = value
|
||||
response = []
|
||||
if event.stream_id not in self._seen_headers:
|
||||
self._seen_headers.add(event.stream_id)
|
||||
assert headers[b":authority"] == b"example.mitmproxy.org"
|
||||
assert headers[b":method"] == b"GET"
|
||||
assert headers[b":path"] == b"/test"
|
||||
response.append((b":status", b"200"))
|
||||
response.append((b"x-response", headers[b"x-request"]))
|
||||
self.http.send_headers(
|
||||
stream_id=event.stream_id,
|
||||
headers=response,
|
||||
end_stream=event.stream_ended
|
||||
)
|
||||
self.transmit()
|
||||
|
||||
def http_data_received(self, event: h3_events.DataReceived) -> None:
|
||||
assert event.push_id is None
|
||||
assert event.stream_id in self._seen_headers
|
||||
try:
|
||||
self.http.send_data(
|
||||
stream_id=event.stream_id,
|
||||
data=event.data,
|
||||
end_stream=event.stream_ended,
|
||||
)
|
||||
except FrameUnexpected:
|
||||
if event.data or not event.stream_ended:
|
||||
raise
|
||||
self._quic.send_stream_data(
|
||||
stream_id=event.stream_id,
|
||||
data=b"",
|
||||
end_stream=True,
|
||||
)
|
||||
self.transmit()
|
||||
|
||||
def http_event_received(self, event: h3_events.H3Event) -> None:
|
||||
if isinstance(event, h3_events.HeadersReceived):
|
||||
self.http_headers_received(event)
|
||||
elif isinstance(event, h3_events.DataReceived):
|
||||
self.http_data_received(event)
|
||||
else:
|
||||
raise AssertionError(event)
|
||||
|
||||
def quic_event_received(self, event: quic_events.QuicEvent) -> None:
|
||||
if isinstance(event, quic_events.ProtocolNegotiated):
|
||||
self.http = H3Connection(self._quic)
|
||||
if self.http is not None:
|
||||
for http_event in self.http.handle_event(event):
|
||||
self.http_event_received(http_event)
|
||||
|
||||
|
||||
class QuicDatagramEchoServer(QuicConnectionProtocol):
|
||||
def quic_event_received(self, event: quic_events.QuicEvent) -> None:
|
||||
if isinstance(event, quic_events.DatagramFrameReceived):
|
||||
self._quic.send_datagram_frame(event.data)
|
||||
self.transmit()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def quic_server(create_protocol, alpn: list[str]) -> AsyncGenerator[Address, None]:
|
||||
configuration = QuicConfiguration(
|
||||
is_client=False,
|
||||
alpn_protocols=alpn,
|
||||
max_datagram_frame_size=65536,
|
||||
)
|
||||
configuration.load_cert_chain(
|
||||
certfile=tlsdata.path("../net/data/verificationcerts/trusted-leaf.crt"),
|
||||
keyfile=tlsdata.path("../net/data/verificationcerts/trusted-leaf.key"),
|
||||
)
|
||||
loop = asyncio.get_running_loop()
|
||||
transport, server = await loop.create_datagram_endpoint(
|
||||
lambda: QuicServer(
|
||||
configuration=configuration,
|
||||
create_protocol=create_protocol,
|
||||
),
|
||||
local_addr=("127.0.0.1", 0),
|
||||
)
|
||||
try:
|
||||
yield transport.get_extra_info("sockname")
|
||||
finally:
|
||||
server.close()
|
||||
|
||||
|
||||
class QuicClient(QuicConnectionProtocol):
|
||||
TIMEOUT: ClassVar[int] = 5
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self._waiter = self._loop.create_future()
|
||||
|
||||
def quic_event_received(self, event: quic_events.QuicEvent) -> None:
|
||||
if not self._waiter.done():
|
||||
if isinstance(event, quic_events.ConnectionTerminated):
|
||||
self._waiter.set_exception(QuicConnectionError(
|
||||
event.error_code, event.frame_type, event.reason_phrase
|
||||
))
|
||||
elif isinstance(event, quic_events.HandshakeCompleted):
|
||||
self._waiter.set_result(None)
|
||||
|
||||
def connection_lost(self, exc: Optional[Exception]) -> None:
|
||||
if not self._waiter.done():
|
||||
self._waiter.set_exception(exc)
|
||||
return super().connection_lost(exc)
|
||||
|
||||
async def wait_handshake(self) -> None:
|
||||
return await asyncio.wait_for(self._waiter, timeout=QuicClient.TIMEOUT)
|
||||
|
||||
|
||||
class QuicDatagramClient(QuicClient):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self._datagram: asyncio.Future[bytes] = self._loop.create_future()
|
||||
|
||||
def quic_event_received(self, event: quic_events.QuicEvent) -> None:
|
||||
super().quic_event_received(event)
|
||||
if not self._datagram.done():
|
||||
if isinstance(event, quic_events.DatagramFrameReceived):
|
||||
self._datagram.set_result(event.data)
|
||||
elif isinstance(event, quic_events.ConnectionTerminated):
|
||||
self._datagram.set_exception(QuicConnectionError(
|
||||
event.error_code, event.frame_type, event.reason_phrase
|
||||
))
|
||||
|
||||
def send_datagram(self, data: bytes) -> None:
|
||||
self._quic.send_datagram_frame(data)
|
||||
self.transmit()
|
||||
|
||||
async def recv_datagram(self) -> bytes:
|
||||
return await asyncio.wait_for(self._datagram, timeout=QuicClient.TIMEOUT)
|
||||
|
||||
|
||||
@dataclass
|
||||
class H3Response:
|
||||
waiter: asyncio.Future[H3Response]
|
||||
stream_id: int
|
||||
headers: Optional[h3_events.H3Event] = None
|
||||
data: Optional[bytes] = None
|
||||
trailers: Optional[h3_events.H3Event] = None
|
||||
callback: Optional[Callable[[str], None]] = None
|
||||
|
||||
async def wait_result(self) -> H3Response:
|
||||
return await asyncio.wait_for(self.waiter, timeout=QuicClient.TIMEOUT)
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
super().__setattr__(name, value)
|
||||
if self.callback:
|
||||
self.callback(name)
|
||||
|
||||
|
||||
class H3Client(QuicClient):
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self._responses: dict[int, H3Response] = dict()
|
||||
self.http = H3Connection(self._quic)
|
||||
|
||||
def http_headers_received(self, event: h3_events.HeadersReceived) -> None:
|
||||
assert event.push_id is None
|
||||
response = self._responses[event.stream_id]
|
||||
if response.waiter.done():
|
||||
return
|
||||
if response.headers is None:
|
||||
response.headers = event.headers
|
||||
if event.stream_ended:
|
||||
response.waiter.set_result(response)
|
||||
elif response.trailers is None:
|
||||
response.trailers = event.headers
|
||||
if event.stream_ended:
|
||||
response.waiter.set_result(response)
|
||||
else:
|
||||
response.waiter.set_exception(Exception("Headers after trailers received."))
|
||||
|
||||
def http_data_received(self, event: h3_events.DataReceived) -> None:
|
||||
assert event.push_id is None
|
||||
response = self._responses[event.stream_id]
|
||||
if response.waiter.done():
|
||||
return
|
||||
if response.headers is None:
|
||||
response.waiter.set_exception(Exception("Data without headers received."))
|
||||
elif response.trailers is None:
|
||||
if response.data is None:
|
||||
response.data = event.data
|
||||
else:
|
||||
response.data = response.data + event.data
|
||||
if event.stream_ended:
|
||||
response.waiter.set_result(response)
|
||||
elif event.data or not event.stream_ended:
|
||||
response.waiter.set_exception(Exception("Data after trailers received."))
|
||||
else:
|
||||
response.waiter.set_result(response)
|
||||
|
||||
def http_event_received(self, event: h3_events.H3Event) -> None:
|
||||
if isinstance(event, h3_events.HeadersReceived):
|
||||
self.http_headers_received(event)
|
||||
elif isinstance(event, h3_events.DataReceived):
|
||||
self.http_data_received(event)
|
||||
else:
|
||||
raise AssertionError(event)
|
||||
|
||||
def quic_event_received(self, event: quic_events.QuicEvent) -> None:
|
||||
super().quic_event_received(event)
|
||||
for http_event in self.http.handle_event(event):
|
||||
self.http_event_received(http_event)
|
||||
|
||||
def request(
|
||||
self,
|
||||
headers: h3_events.H3Event,
|
||||
data: Optional[bytes] = None,
|
||||
trailers: Optional[h3_events.H3Event] = None,
|
||||
end_stream: bool = True,
|
||||
) -> H3Response:
|
||||
stream_id = self._quic.get_next_available_stream_id()
|
||||
self.http.send_headers(
|
||||
stream_id=stream_id,
|
||||
headers=headers,
|
||||
end_stream=data is None and trailers is None and end_stream,
|
||||
)
|
||||
if data is not None:
|
||||
self.http.send_data(
|
||||
stream_id=stream_id,
|
||||
data=data,
|
||||
end_stream=trailers is None and end_stream,
|
||||
)
|
||||
if trailers is not None:
|
||||
self.http.send_headers(
|
||||
stream_id=stream_id,
|
||||
headers=trailers,
|
||||
end_stream=end_stream,
|
||||
)
|
||||
waiter = self._loop.create_future()
|
||||
response = H3Response(waiter=waiter, stream_id=stream_id)
|
||||
self._responses[stream_id] = response
|
||||
self.transmit()
|
||||
return response
|
||||
|
||||
|
||||
T = TypeVar("T", bound=QuicClient)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def quic_connect(
|
||||
cls: type[T],
|
||||
alpn: list[str],
|
||||
address: Address,
|
||||
) -> AsyncGenerator[T, None]:
|
||||
configuration = QuicConfiguration(
|
||||
is_client=True,
|
||||
alpn_protocols=alpn,
|
||||
server_name="example.mitmproxy.org",
|
||||
verify_mode=ssl.CERT_NONE,
|
||||
max_datagram_frame_size=65536,
|
||||
)
|
||||
loop = asyncio.get_running_loop()
|
||||
transport, protocol = await loop.create_datagram_endpoint(
|
||||
lambda: cls(QuicConnection(configuration=configuration)),
|
||||
local_addr=("127.0.0.1", 0),
|
||||
)
|
||||
assert isinstance(protocol, cls)
|
||||
try:
|
||||
protocol.connect(address)
|
||||
await protocol.wait_handshake()
|
||||
yield protocol
|
||||
finally:
|
||||
protocol.close()
|
||||
await protocol.wait_closed()
|
||||
transport.close()
|
||||
|
||||
|
||||
async def _test_echo(client: H3Client, strict: bool) -> None:
|
||||
def assert_no_data(response: H3Response):
|
||||
if strict:
|
||||
assert response.data is None
|
||||
else:
|
||||
assert not response.data
|
||||
|
||||
headers = [
|
||||
(b":scheme", b"https"),
|
||||
(b":authority", b"example.mitmproxy.org"),
|
||||
(b":method", b"GET"),
|
||||
(b":path", b"/test"),
|
||||
]
|
||||
r1 = await client.request(
|
||||
headers=headers + [(b"x-request", b"justheaders")],
|
||||
data=None,
|
||||
trailers=None,
|
||||
).wait_result()
|
||||
assert r1.headers == [
|
||||
(b":status", b"200"),
|
||||
(b"x-response", b"justheaders"),
|
||||
]
|
||||
assert_no_data(r1)
|
||||
assert r1.trailers is None
|
||||
|
||||
r2 = await client.request(
|
||||
headers=headers + [(b"x-request", b"hasdata")],
|
||||
data=b"echo",
|
||||
trailers=None,
|
||||
).wait_result()
|
||||
assert r2.headers == [
|
||||
(b":status", b"200"),
|
||||
(b"x-response", b"hasdata"),
|
||||
]
|
||||
assert r2.data == b"echo"
|
||||
assert r2.trailers is None
|
||||
|
||||
r3 = await client.request(
|
||||
headers=headers + [(b"x-request", b"nodata")],
|
||||
data=None,
|
||||
trailers=[(b"x-request", b"buttrailers")],
|
||||
).wait_result()
|
||||
assert r3.headers == [
|
||||
(b":status", b"200"),
|
||||
(b"x-response", b"nodata"),
|
||||
]
|
||||
assert_no_data(r3)
|
||||
assert r3.trailers == [(b"x-response", b"buttrailers")]
|
||||
|
||||
r4 = await client.request(
|
||||
headers=headers + [(b"x-request", b"this")],
|
||||
data=b"has",
|
||||
trailers=[(b"x-request", b"everything")],
|
||||
).wait_result()
|
||||
assert r4.headers == [
|
||||
(b":status", b"200"),
|
||||
(b"x-response", b"this"),
|
||||
]
|
||||
assert r4.data == b"has"
|
||||
assert r4.trailers == [(b"x-response", b"everything")]
|
||||
|
||||
# the following test makes sure that we behave properly if end_stream is sent separately
|
||||
r5 = client.request(
|
||||
headers=headers + [(b"x-request", b"this")],
|
||||
data=b"has",
|
||||
trailers=[(b"x-request", b"everything but end_stream")],
|
||||
end_stream=False,
|
||||
)
|
||||
if not strict:
|
||||
trailer_waiter = asyncio.get_running_loop().create_future()
|
||||
r5.callback = lambda name: name != "trailers" or trailer_waiter.set_result(None)
|
||||
await asyncio.wait_for(trailer_waiter, timeout=QuicClient.TIMEOUT)
|
||||
assert r5.trailers is not None
|
||||
assert not r5.waiter.done()
|
||||
else:
|
||||
await asyncio.sleep(0)
|
||||
client._quic.send_stream_data(
|
||||
stream_id=r5.stream_id,
|
||||
data=b"",
|
||||
end_stream=True,
|
||||
)
|
||||
client.transmit()
|
||||
await r5.wait_result()
|
||||
assert r5.headers == [
|
||||
(b":status", b"200"),
|
||||
(b"x-response", b"this"),
|
||||
]
|
||||
assert r5.data == b"has"
|
||||
assert r5.trailers == [(b"x-response", b"everything but end_stream")]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("connection_strategy", ["lazy", "eager"])
|
||||
@pytest.mark.parametrize("scheme", ["http3", "quic"])
|
||||
async def test_reverse_http3_and_quic_stream(
|
||||
caplog_async, scheme: str, connection_strategy: str
|
||||
) -> None:
|
||||
caplog_async.set_level("INFO")
|
||||
ps = Proxyserver()
|
||||
nl = NextLayer()
|
||||
ta = TlsConfig()
|
||||
with taddons.context(ps, nl, ta) as tctx:
|
||||
tctx.options.keep_host_header = True
|
||||
tctx.options.connection_strategy = connection_strategy
|
||||
ta.configure(["confdir"])
|
||||
async with quic_server(H3EchoServer, alpn=["h3"]) as server_addr:
|
||||
mode = f"reverse:{scheme}://{server_addr[0]}:{server_addr[1]}@127.0.0.1:0"
|
||||
tctx.configure(
|
||||
ta,
|
||||
ssl_verify_upstream_trusted_ca=tlsdata.path(
|
||||
"../net/data/verificationcerts/trusted-root.crt"
|
||||
),
|
||||
)
|
||||
tctx.configure(ps, mode=[mode])
|
||||
assert await ps.setup_servers()
|
||||
ps.running()
|
||||
await caplog_async.await_log(f"reverse proxy to {scheme}://{server_addr[0]}:{server_addr[1]} listening")
|
||||
assert ps.servers
|
||||
addr = ps.servers[mode].listen_addrs[0]
|
||||
async with quic_connect(H3Client, alpn=["h3"], address=addr) as client:
|
||||
await _test_echo(client, strict=scheme == "http3")
|
||||
assert len(ps.connections) == 1
|
||||
|
||||
tctx.configure(ps, server=False)
|
||||
await caplog_async.await_log(f"Stopped reverse proxy to {scheme}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("connection_strategy", ["lazy", "eager"])
|
||||
async def test_reverse_quic_datagram(caplog_async, connection_strategy: str) -> None:
|
||||
caplog_async.set_level("INFO")
|
||||
ps = Proxyserver()
|
||||
nl = NextLayer()
|
||||
ta = TlsConfig()
|
||||
with taddons.context(ps, nl, ta) as tctx:
|
||||
tctx.options.keep_host_header = True
|
||||
tctx.options.connection_strategy = connection_strategy
|
||||
ta.configure(["confdir"])
|
||||
async with quic_server(QuicDatagramEchoServer, alpn=["dgram"]) as server_addr:
|
||||
mode = f"reverse:quic://{server_addr[0]}:{server_addr[1]}@127.0.0.1:0"
|
||||
tctx.configure(
|
||||
ta,
|
||||
ssl_verify_upstream_trusted_ca=tlsdata.path(
|
||||
"../net/data/verificationcerts/trusted-root.crt"
|
||||
),
|
||||
)
|
||||
tctx.configure(ps, mode=[mode])
|
||||
assert await ps.setup_servers()
|
||||
ps.running()
|
||||
await caplog_async.await_log(f"reverse proxy to quic://{server_addr[0]}:{server_addr[1]} listening")
|
||||
assert ps.servers
|
||||
addr = ps.servers[mode].listen_addrs[0]
|
||||
async with quic_connect(QuicDatagramClient, alpn=["dgram"], address=addr) as client:
|
||||
client.send_datagram(b"echo")
|
||||
assert await client.recv_datagram() == b"echo"
|
||||
|
||||
tctx.configure(ps, server=False)
|
||||
await caplog_async.await_log("Stopped reverse proxy to quic")
|
||||
|
||||
|
||||
async def test_regular_http3(caplog_async, monkeypatch) -> None:
|
||||
caplog_async.set_level("INFO")
|
||||
ps = Proxyserver()
|
||||
nl = NextLayer()
|
||||
ta = TlsConfig()
|
||||
with taddons.context(ps, nl, ta) as tctx:
|
||||
ta.configure(["confdir"])
|
||||
async with quic_server(H3EchoServer, alpn=["h3"]) as server_addr:
|
||||
orig_open_connection = udp.open_connection
|
||||
|
||||
def open_connection_path(
|
||||
host: str, port: int, *args, **kwargs
|
||||
) -> udp.UdpClient:
|
||||
if host == "example.mitmproxy.org" and port == 443:
|
||||
host = server_addr[0]
|
||||
port = server_addr[1]
|
||||
return orig_open_connection(host, port, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(udp, "open_connection", open_connection_path)
|
||||
mode = f"http3@127.0.0.1:0"
|
||||
tctx.configure(
|
||||
ta,
|
||||
ssl_verify_upstream_trusted_ca=tlsdata.path(
|
||||
"../net/data/verificationcerts/trusted-root.crt"
|
||||
),
|
||||
)
|
||||
tctx.configure(ps, mode=[mode])
|
||||
assert await ps.setup_servers()
|
||||
ps.running()
|
||||
await caplog_async.await_log(f"HTTP3 proxy listening")
|
||||
assert ps.servers
|
||||
addr = ps.servers[mode].listen_addrs[0]
|
||||
async with quic_connect(H3Client, alpn=["h3"], address=addr) as client:
|
||||
await _test_echo(client=client, strict=True)
|
||||
assert len(ps.connections) == 1
|
||||
|
||||
tctx.configure(ps, server=False)
|
||||
await caplog_async.await_log("Stopped HTTP3 proxy")
|
||||
|
@ -5,13 +5,14 @@ from typing import Union
|
||||
|
||||
import pytest
|
||||
|
||||
from cryptography import x509
|
||||
from OpenSSL import SSL
|
||||
from mitmproxy import certs, connection, tls, options
|
||||
from mitmproxy.addons import tlsconfig
|
||||
from mitmproxy.proxy import context
|
||||
from mitmproxy.proxy.layers import modes, tls as proxy_tls
|
||||
from mitmproxy.proxy.layers import modes, quic, tls as proxy_tls
|
||||
from mitmproxy.test import taddons
|
||||
from test.mitmproxy.proxy.layers import test_tls
|
||||
from test.mitmproxy.proxy.layers import test_quic, test_tls
|
||||
|
||||
|
||||
def test_alpn_select_callback():
|
||||
@ -163,6 +164,19 @@ class TestTlsConfig:
|
||||
|
||||
return True
|
||||
|
||||
def quic_do_handshake(
|
||||
self,
|
||||
tssl_client: test_quic.SSLTest,
|
||||
tssl_server: test_quic.SSLTest,
|
||||
) -> bool:
|
||||
tssl_server.write(tssl_client.read())
|
||||
tssl_client.write(tssl_server.read())
|
||||
tssl_server.write(tssl_client.read())
|
||||
return (
|
||||
tssl_client.handshake_completed()
|
||||
and tssl_server.handshake_completed()
|
||||
)
|
||||
|
||||
def test_tls_start_client(self, tdata):
|
||||
ta = tlsconfig.TlsConfig()
|
||||
with taddons.context(ta) as tctx:
|
||||
@ -190,6 +204,34 @@ class TestTlsConfig:
|
||||
("DNS", "example.mitmproxy.org"),
|
||||
)
|
||||
|
||||
def test_quic_start_client(self, tdata):
|
||||
ta = tlsconfig.TlsConfig()
|
||||
with taddons.context(ta) as tctx:
|
||||
ta.configure(["confdir"])
|
||||
tctx.configure(
|
||||
ta,
|
||||
certs=[
|
||||
tdata.path("mitmproxy/net/data/verificationcerts/trusted-leaf.pem")
|
||||
],
|
||||
ciphers_client="CHACHA20_POLY1305_SHA256",
|
||||
)
|
||||
ctx = _ctx(tctx.options)
|
||||
|
||||
tls_start = quic.QuicTlsData(ctx.client, context=ctx)
|
||||
ta.quic_start_client(tls_start)
|
||||
settings_server = tls_start.settings
|
||||
settings_server.alpn_protocols = ["h3"]
|
||||
tssl_server = test_quic.SSLTest(server_side=True, settings=settings_server)
|
||||
|
||||
# assert that a preexisting settings is not overwritten
|
||||
ta.quic_start_client(tls_start)
|
||||
assert settings_server is tls_start.settings
|
||||
|
||||
tssl_client = test_quic.SSLTest(alpn=["h3"])
|
||||
assert self.quic_do_handshake(tssl_client, tssl_server)
|
||||
san = tssl_client.quic.tls._peer_certificate.extensions.get_extension_for_class(x509.SubjectAlternativeName)
|
||||
assert san.value.get_values_for_type(x509.DNSName) == ["example.mitmproxy.org"]
|
||||
|
||||
def test_tls_start_server_cannot_verify(self):
|
||||
ta = tlsconfig.TlsConfig()
|
||||
with taddons.context(ta) as tctx:
|
||||
@ -240,6 +282,32 @@ class TestTlsConfig:
|
||||
tssl_server = test_tls.SSLTest(server_side=True, sni=hostname.encode())
|
||||
assert self.do_handshake(tssl_client, tssl_server)
|
||||
|
||||
@pytest.mark.parametrize("hostname", ["example.mitmproxy.org", "192.0.2.42"])
|
||||
def test_quic_start_server_verify_ok(self, hostname, tdata):
|
||||
ta = tlsconfig.TlsConfig()
|
||||
with taddons.context(ta) as tctx:
|
||||
ctx = _ctx(tctx.options)
|
||||
ctx.server.address = (hostname, 443)
|
||||
tctx.configure(
|
||||
ta,
|
||||
ssl_verify_upstream_trusted_ca=tdata.path(
|
||||
"mitmproxy/net/data/verificationcerts/trusted-root.crt"
|
||||
),
|
||||
)
|
||||
|
||||
tls_start = quic.QuicTlsData(ctx.server, context=ctx)
|
||||
ta.quic_start_server(tls_start)
|
||||
settings_client = tls_start.settings
|
||||
settings_client.alpn_protocols = ["h3"]
|
||||
tssl_client = test_quic.SSLTest(settings=settings_client)
|
||||
|
||||
# assert that a preexisting ssl_conn is not overwritten
|
||||
ta.quic_start_server(tls_start)
|
||||
assert settings_client is tls_start.settings
|
||||
|
||||
tssl_server = test_quic.SSLTest(server_side=True, sni=hostname.encode(), alpn=["h3"])
|
||||
assert self.quic_do_handshake(tssl_client, tssl_server)
|
||||
|
||||
def test_tls_start_server_insecure(self):
|
||||
ta = tlsconfig.TlsConfig()
|
||||
with taddons.context(ta) as tctx:
|
||||
@ -259,6 +327,25 @@ class TestTlsConfig:
|
||||
tssl_server = test_tls.SSLTest(server_side=True)
|
||||
assert self.do_handshake(tssl_client, tssl_server)
|
||||
|
||||
def test_quic_start_server_insecure(self):
|
||||
ta = tlsconfig.TlsConfig()
|
||||
with taddons.context(ta) as tctx:
|
||||
ctx = _ctx(tctx.options)
|
||||
ctx.server.address = ("example.mitmproxy.org", 443)
|
||||
ctx.client.alpn_offers = [b"h3"]
|
||||
|
||||
tctx.configure(
|
||||
ta,
|
||||
ssl_verify_upstream_trusted_ca=None,
|
||||
ssl_insecure=True,
|
||||
ciphers_server="CHACHA20_POLY1305_SHA256",
|
||||
)
|
||||
tls_start = quic.QuicTlsData(ctx.server, context=ctx)
|
||||
ta.quic_start_server(tls_start)
|
||||
tssl_client = test_quic.SSLTest(settings=tls_start.settings)
|
||||
tssl_server = test_quic.SSLTest(server_side=True, alpn=["h3"])
|
||||
assert self.quic_do_handshake(tssl_client, tssl_server)
|
||||
|
||||
def test_alpn_selection(self):
|
||||
ta = tlsconfig.TlsConfig()
|
||||
with taddons.context(ta) as tctx:
|
||||
|
@ -13,7 +13,9 @@ def tctx() -> context.Context:
|
||||
opts = options.Options()
|
||||
Proxyserver().load(opts)
|
||||
return context.Context(
|
||||
connection.Client(peername=("client", 1234), sockname=("127.0.0.1", 8080), timestamp_start=1605699329), opts
|
||||
connection.Client(peername=("client", 1234), sockname=("127.0.0.1", 8080),
|
||||
timestamp_start=1605699329, state=connection.ConnectionState.OPEN),
|
||||
opts
|
||||
)
|
||||
|
||||
|
||||
|
1145
test/mitmproxy/proxy/layers/http/test_http3.py
Normal file
1145
test/mitmproxy/proxy/layers/http/test_http3.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -4,18 +4,19 @@ import pytest
|
||||
from mitmproxy import dns
|
||||
|
||||
from mitmproxy.addons.proxyauth import ProxyAuth
|
||||
from mitmproxy.connection import Client, Server
|
||||
from mitmproxy.connection import Client, ConnectionState, Server
|
||||
from mitmproxy.proxy import layers
|
||||
from mitmproxy.proxy.commands import (
|
||||
CloseConnection,
|
||||
Log,
|
||||
OpenConnection,
|
||||
RequestWakeup,
|
||||
SendData,
|
||||
)
|
||||
from mitmproxy.proxy.context import Context
|
||||
from mitmproxy.proxy.events import ConnectionClosed, DataReceived
|
||||
from mitmproxy.proxy.layer import NextLayer, NextLayerHook
|
||||
from mitmproxy.proxy.layers import http, modes, tcp, tls
|
||||
from mitmproxy.proxy.layers import http, modes, quic, tcp, tls, udp
|
||||
from mitmproxy.proxy.layers.http import HTTPMode
|
||||
from mitmproxy.proxy.layers.tcp import TcpMessageHook, TcpStartHook
|
||||
from mitmproxy.proxy.layers.tls import (
|
||||
@ -25,7 +26,8 @@ from mitmproxy.proxy.layers.tls import (
|
||||
)
|
||||
from mitmproxy.proxy.mode_specs import ProxyMode
|
||||
from mitmproxy.tcp import TCPFlow
|
||||
from mitmproxy.test import tflow
|
||||
from mitmproxy.test import taddons, tflow
|
||||
from mitmproxy.udp import UDPFlow
|
||||
from test.mitmproxy.proxy.layers.test_tls import (
|
||||
reply_tls_start_client,
|
||||
reply_tls_start_server,
|
||||
@ -43,12 +45,12 @@ def test_upstream_https(tctx):
|
||||
curl -x localhost:8080 -k http://example.com
|
||||
"""
|
||||
tctx1 = Context(
|
||||
Client(peername=("client", 1234), sockname=("127.0.0.1", 8080), timestamp_start=1605699329),
|
||||
Client(peername=("client", 1234), sockname=("127.0.0.1", 8080), timestamp_start=1605699329, state=ConnectionState.OPEN),
|
||||
copy.deepcopy(tctx.options),
|
||||
)
|
||||
tctx1.client.proxy_mode = ProxyMode.parse("upstream:https://example.mitmproxy.org:8081")
|
||||
tctx2 = Context(
|
||||
Client(peername=("client", 4321), sockname=("127.0.0.1", 8080), timestamp_start=1605699329),
|
||||
Client(peername=("client", 4321), sockname=("127.0.0.1", 8080), timestamp_start=1605699329, state=ConnectionState.OPEN),
|
||||
copy.deepcopy(tctx.options),
|
||||
)
|
||||
assert tctx2.client.proxy_mode == ProxyMode.parse("regular")
|
||||
@ -159,6 +161,8 @@ def test_reverse_dns(tctx):
|
||||
assert (
|
||||
Playbook(modes.ReverseProxy(tctx), hooks=False)
|
||||
>> DataReceived(tctx.client, tflow.tdnsreq().packed)
|
||||
<< NextLayerHook(Placeholder(NextLayer))
|
||||
>> reply_next_layer(layers.DNSLayer)
|
||||
<< layers.dns.DnsRequestHook(f)
|
||||
>> reply(None)
|
||||
<< OpenConnection(server)
|
||||
@ -168,6 +172,53 @@ def test_reverse_dns(tctx):
|
||||
assert server().address == ("8.8.8.8", 53)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("keep_host_header", [True, False])
|
||||
def test_quic(tctx: Context, keep_host_header: bool):
|
||||
with taddons.context():
|
||||
tctx.options.keep_host_header = keep_host_header
|
||||
tctx.server.sni = "other"
|
||||
tctx.client.proxy_mode = ProxyMode.parse("reverse:quic://1.2.3.4:5")
|
||||
client_hello = Placeholder(bytes)
|
||||
|
||||
def set_settings(data: quic.QuicTlsData):
|
||||
data.settings = quic.QuicTlsSettings()
|
||||
|
||||
assert (
|
||||
Playbook(modes.ReverseProxy(tctx))
|
||||
<< OpenConnection(tctx.server)
|
||||
>> reply(None)
|
||||
<< quic.QuicStartServerHook(Placeholder(quic.QuicTlsData))
|
||||
>> reply(side_effect=set_settings)
|
||||
<< SendData(tctx.server, client_hello)
|
||||
<< RequestWakeup(Placeholder(float))
|
||||
)
|
||||
assert tctx.server.address == ("1.2.3.4", 5)
|
||||
assert quic.quic_parse_client_hello(client_hello()).sni == (
|
||||
"other" if keep_host_header else "1.2.3.4"
|
||||
)
|
||||
|
||||
|
||||
def test_udp(tctx: Context):
|
||||
tctx.client.proxy_mode = ProxyMode.parse("reverse:udp://1.2.3.4:5")
|
||||
flow = Placeholder(UDPFlow)
|
||||
assert (
|
||||
Playbook(modes.ReverseProxy(tctx))
|
||||
<< OpenConnection(tctx.server)
|
||||
>> reply(None)
|
||||
>> DataReceived(tctx.client, b"test-input")
|
||||
<< NextLayerHook(Placeholder(NextLayer))
|
||||
>> reply_next_layer(layers.UDPLayer)
|
||||
<< udp.UdpStartHook(flow)
|
||||
>> reply()
|
||||
<< udp.UdpMessageHook(flow)
|
||||
>> reply()
|
||||
<< SendData(tctx.server, b"test-input")
|
||||
)
|
||||
assert tctx.server.address == ("1.2.3.4", 5)
|
||||
assert len(flow().messages) == 1
|
||||
assert flow().messages[0].content == b"test-input"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("patch", [True, False])
|
||||
@pytest.mark.parametrize("connection_strategy", ["eager", "lazy"])
|
||||
def test_reverse_proxy_tcp_over_tls(
|
||||
|
1164
test/mitmproxy/proxy/layers/test_quic.py
Normal file
1164
test/mitmproxy/proxy/layers/test_quic.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from mitmproxy.proxy.commands import CloseConnection, OpenConnection, SendData
|
||||
from mitmproxy.proxy.commands import CloseConnection, CloseTcpConnection, OpenConnection, SendData
|
||||
from mitmproxy.proxy.events import ConnectionClosed, DataReceived
|
||||
from mitmproxy.proxy.layers import tcp
|
||||
from mitmproxy.proxy.layers.tcp import TcpMessageInjected
|
||||
@ -52,7 +52,7 @@ def test_simple(tctx):
|
||||
>> reply()
|
||||
<< SendData(tctx.client, b"hi")
|
||||
>> ConnectionClosed(tctx.server)
|
||||
<< CloseConnection(tctx.client, half_close=True)
|
||||
<< CloseTcpConnection(tctx.client, half_close=True)
|
||||
>> ConnectionClosed(tctx.client)
|
||||
<< CloseConnection(tctx.server)
|
||||
<< tcp.TcpEndHook(f)
|
||||
@ -88,7 +88,7 @@ def test_receive_data_after_half_close(tctx):
|
||||
>> DataReceived(tctx.client, b"eof-delimited-request")
|
||||
<< SendData(tctx.server, b"eof-delimited-request")
|
||||
>> ConnectionClosed(tctx.client)
|
||||
<< CloseConnection(tctx.server, half_close=True)
|
||||
<< CloseTcpConnection(tctx.server, half_close=True)
|
||||
>> DataReceived(tctx.server, b"i'm late")
|
||||
<< SendData(tctx.client, b"i'm late")
|
||||
>> ConnectionClosed(tctx.server)
|
||||
|
@ -639,7 +639,7 @@ class TestClientTLS:
|
||||
>> events.DataReceived(Server(address=None), b"data on other stream")
|
||||
<< commands.Log(">> DataReceived(server, b'data on other stream')", DEBUG)
|
||||
<< commands.Log(
|
||||
"Swallowing DataReceived(server, b'data on other stream') as handshake failed.",
|
||||
"[tls] Swallowing DataReceived(server, b'data on other stream') as handshake failed.",
|
||||
DEBUG,
|
||||
)
|
||||
)
|
||||
|
@ -17,6 +17,7 @@ def test_dataclasses(tconn):
|
||||
assert repr(commands.SendData(tconn, b"foo"))
|
||||
assert repr(commands.OpenConnection(tconn))
|
||||
assert repr(commands.CloseConnection(tconn))
|
||||
assert repr(commands.CloseTcpConnection(tconn, half_close=True))
|
||||
assert repr(commands.Log("hello"))
|
||||
|
||||
|
||||
|
@ -18,7 +18,7 @@ def test_make():
|
||||
context = MagicMock()
|
||||
assert ServerInstance.make("regular", manager)
|
||||
|
||||
for mode in ["regular", "upstream:example.com", "transparent", "reverse:example.com", "socks5"]:
|
||||
for mode in ["regular", "http3", "upstream:example.com", "transparent", "reverse:example.com", "socks5"]:
|
||||
inst = ServerInstance.make(mode, manager)
|
||||
assert inst
|
||||
assert inst.make_top_layer(context)
|
||||
|
@ -53,9 +53,11 @@ def test_listen_addr():
|
||||
|
||||
def test_parse_specific_modes():
|
||||
assert ProxyMode.parse("regular")
|
||||
assert ProxyMode.parse("http3")
|
||||
assert ProxyMode.parse("transparent")
|
||||
assert ProxyMode.parse("upstream:https://proxy")
|
||||
assert ProxyMode.parse("reverse:https://host@443")
|
||||
assert ProxyMode.parse("reverse:http3://host@443")
|
||||
assert ProxyMode.parse("socks5")
|
||||
assert ProxyMode.parse("dns")
|
||||
assert ProxyMode.parse("reverse:dns://8.8.8.8")
|
||||
@ -70,6 +72,9 @@ def test_parse_specific_modes():
|
||||
with pytest.raises(ValueError, match="takes no arguments"):
|
||||
ProxyMode.parse("regular:configuration")
|
||||
|
||||
with pytest.raises(ValueError, match="takes no arguments"):
|
||||
ProxyMode.parse("http3:configuration")
|
||||
|
||||
with pytest.raises(ValueError, match="invalid upstream proxy scheme"):
|
||||
ProxyMode.parse("upstream:dns://example.com")
|
||||
|
||||
|
@ -3,7 +3,7 @@ from typing import Optional
|
||||
import pytest
|
||||
|
||||
from mitmproxy.proxy import tunnel, layer
|
||||
from mitmproxy.proxy.commands import SendData, Log, CloseConnection, OpenConnection
|
||||
from mitmproxy.proxy.commands import CloseTcpConnection, SendData, Log, CloseConnection, OpenConnection
|
||||
from mitmproxy.connection import Server, ConnectionState
|
||||
from mitmproxy.proxy.context import Context
|
||||
from mitmproxy.proxy.events import Event, DataReceived, Start, ConnectionClosed
|
||||
@ -24,7 +24,7 @@ class TChildLayer(layer.Layer):
|
||||
err = yield OpenConnection(self.context.server)
|
||||
yield Log(f"Opened: {err=}. Server state: {self.context.server.state.name}")
|
||||
elif isinstance(event, DataReceived) and event.data == b"half-close":
|
||||
err = yield CloseConnection(event.connection, half_close=True)
|
||||
err = yield CloseTcpConnection(event.connection, half_close=True)
|
||||
elif isinstance(event, ConnectionClosed):
|
||||
yield Log(f"Got {event.connection.__class__.__name__.lower()} close.")
|
||||
yield CloseConnection(event.connection)
|
||||
@ -164,7 +164,7 @@ def test_tunnel_default_impls(tctx: Context):
|
||||
>> reply(None)
|
||||
<< Log("Opened: err=None. Server state: OPEN")
|
||||
>> DataReceived(server, b"half-close")
|
||||
<< CloseConnection(server, half_close=True)
|
||||
<< CloseTcpConnection(server, half_close=True)
|
||||
)
|
||||
|
||||
|
||||
|
@ -224,11 +224,10 @@ class Playbook:
|
||||
for cmd in cmds:
|
||||
pos += 1
|
||||
assert self.actual[pos] == cmd
|
||||
if isinstance(cmd, commands.CloseConnection):
|
||||
if cmd.half_close:
|
||||
cmd.connection.state &= ~ConnectionState.CAN_WRITE
|
||||
else:
|
||||
cmd.connection.state = ConnectionState.CLOSED
|
||||
if isinstance(cmd, commands.CloseTcpConnection) and cmd.half_close:
|
||||
cmd.connection.state &= ~ConnectionState.CAN_WRITE
|
||||
elif isinstance(cmd, commands.CloseConnection):
|
||||
cmd.connection.state = ConnectionState.CLOSED
|
||||
elif isinstance(cmd, commands.Log):
|
||||
need_to_emulate_log = (
|
||||
not self.logs
|
||||
|
@ -65,10 +65,12 @@ class TestCertStore:
|
||||
(tmp_path / "mitmproxy-ca.pem").write_bytes(cert)
|
||||
ca = certs.CertStore.from_store(tmp_path, "mitmproxy", 2048)
|
||||
assert ca.default_chain_file is None
|
||||
assert len(ca.default_chain_certs) == 1
|
||||
|
||||
(tmp_path / "mitmproxy-ca.pem").write_bytes(2 * cert)
|
||||
ca = certs.CertStore.from_store(tmp_path, "mitmproxy", 2048)
|
||||
assert ca.default_chain_file == (tmp_path / "mitmproxy-ca.pem")
|
||||
assert len(ca.default_chain_certs) == 2
|
||||
|
||||
def test_sans(self, tstore):
|
||||
c1 = tstore.get_cert("foo.com", ["*.bar.com"])
|
||||
|
@ -6,7 +6,9 @@ from mitmproxy.test.tflow import tclient_conn, tserver_conn
|
||||
|
||||
class TestConnection:
|
||||
def test_basic(self):
|
||||
c = Client(peername=("127.0.0.1", 52314), sockname=("127.0.0.1", 8080), timestamp_start=1607780791)
|
||||
c = Client(peername=("127.0.0.1", 52314), sockname=("127.0.0.1", 8080),
|
||||
timestamp_start=1607780791,
|
||||
state=ConnectionState.OPEN)
|
||||
assert not c.tls_established
|
||||
c.timestamp_tls_setup = 1607780792
|
||||
assert c.tls_established
|
||||
@ -39,7 +41,7 @@ class TestClient:
|
||||
c.timestamp_tls_setup = 1607780791
|
||||
assert str(c)
|
||||
c.alpn = b"foo"
|
||||
assert str(c) == "Client(127.0.0.1:52314, state=open, alpn=foo)"
|
||||
assert str(c) == "Client(127.0.0.1:52314, state=closed, alpn=foo)"
|
||||
|
||||
def test_state(self):
|
||||
c = tclient_conn()
|
||||
|
@ -60,3 +60,7 @@
|
||||
.resource-icon-dns {
|
||||
background-image: url(images/resourceDnsIcon.png);
|
||||
}
|
||||
|
||||
.resource-icon-quic {
|
||||
background-image: url(images/resourceQuicIcon.png);
|
||||
}
|
||||
|
BIN
web/src/images/resourceQuicIcon.png
Normal file
BIN
web/src/images/resourceQuicIcon.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.3 KiB |
@ -30,7 +30,7 @@ test("FlowView", async () => {
|
||||
|
||||
store.dispatch(flowActions.select(store.getState().flows.list[2].id));
|
||||
|
||||
fireEvent.click(screen.getByText("TCP Messages"));
|
||||
fireEvent.click(screen.getByText("Stream Data"));
|
||||
expect(asFragment()).toMatchSnapshot();
|
||||
|
||||
fireEvent.click(screen.getByText("Error"));
|
||||
@ -49,7 +49,7 @@ test("FlowView", async () => {
|
||||
|
||||
store.dispatch(flowActions.select(store.getState().flows.list[4].id));
|
||||
|
||||
fireEvent.click(screen.getByText("UDP Messages"));
|
||||
fireEvent.click(screen.getByText("Datagrams"));
|
||||
expect(asFragment()).toMatchSnapshot();
|
||||
|
||||
fireEvent.click(screen.getByText("Error"));
|
||||
|
@ -1006,7 +1006,7 @@ exports[`FlowView 7`] = `
|
||||
class="active"
|
||||
href="#"
|
||||
>
|
||||
TCP Messages
|
||||
Stream Data
|
||||
</a>
|
||||
<a
|
||||
class=""
|
||||
@ -1030,9 +1030,6 @@ exports[`FlowView 7`] = `
|
||||
<section
|
||||
class="tcp"
|
||||
>
|
||||
<h4>
|
||||
TCP Data
|
||||
</h4>
|
||||
<div
|
||||
class="contentview"
|
||||
>
|
||||
@ -1079,7 +1076,7 @@ exports[`FlowView 8`] = `
|
||||
class=""
|
||||
href="#"
|
||||
>
|
||||
TCP Messages
|
||||
Stream Data
|
||||
</a>
|
||||
<a
|
||||
class="active"
|
||||
@ -1443,7 +1440,7 @@ exports[`FlowView 12`] = `
|
||||
class="active"
|
||||
href="#"
|
||||
>
|
||||
UDP Messages
|
||||
Datagrams
|
||||
</a>
|
||||
<a
|
||||
class=""
|
||||
@ -1467,9 +1464,6 @@ exports[`FlowView 12`] = `
|
||||
<section
|
||||
class="udp"
|
||||
>
|
||||
<h4>
|
||||
UDP Data
|
||||
</h4>
|
||||
<div
|
||||
class="contentview"
|
||||
>
|
||||
@ -1516,7 +1510,7 @@ exports[`FlowView 13`] = `
|
||||
class=""
|
||||
href="#"
|
||||
>
|
||||
UDP Messages
|
||||
Datagrams
|
||||
</a>
|
||||
<a
|
||||
class="active"
|
||||
|
@ -38,6 +38,9 @@ icon.sortKey = flow => getIcon(flow)
|
||||
|
||||
const getIcon = (flow: Flow): string => {
|
||||
if (flow.type !== "http") {
|
||||
if (flow.client_conn.tls_version === "QUIC") {
|
||||
return `resource-icon-quic`;
|
||||
}
|
||||
return `resource-icon-${flow.type}`
|
||||
}
|
||||
if (flow.websocket) {
|
||||
|
@ -29,7 +29,9 @@ export default function Messages({flow, messages_meta}: MessagesPropTypes) {
|
||||
try {
|
||||
return JSON.parse(content)
|
||||
} catch (e) {
|
||||
const err: ContentViewData = {"description": "Network Error", lines: [[["error", `${content}`]]]};
|
||||
const err: ContentViewData[] = [
|
||||
{"description": "Network Error", lines: [[["error", `${content}`]]]}
|
||||
];
|
||||
return err;
|
||||
}
|
||||
}
|
||||
|
@ -6,9 +6,8 @@ import Messages from "./Messages";
|
||||
export default function TcpMessages({flow}: { flow: TCPFlow }) {
|
||||
return (
|
||||
<section className="tcp">
|
||||
<h4>TCP Data</h4>
|
||||
<Messages flow={flow} messages_meta={flow.messages_meta}/>
|
||||
</section>
|
||||
)
|
||||
}
|
||||
TcpMessages.displayName = "TCP Messages"
|
||||
TcpMessages.displayName = "Stream Data"
|
||||
|
@ -6,9 +6,8 @@ import Messages from "./Messages";
|
||||
export default function UdpMessages({flow}: { flow: UDPFlow }) {
|
||||
return (
|
||||
<section className="udp">
|
||||
<h4>UDP Data</h4>
|
||||
<Messages flow={flow} messages_meta={flow.messages_meta}/>
|
||||
</section>
|
||||
)
|
||||
}
|
||||
UdpMessages.displayName = "UDP Messages"
|
||||
UdpMessages.displayName = "Datagrams"
|
||||
|
@ -42,7 +42,6 @@ export interface OptionsState {
|
||||
proxy_debug: boolean
|
||||
proxyauth: string | undefined
|
||||
rawtcp: boolean
|
||||
rawudp: boolean
|
||||
readfile_filter: string | undefined
|
||||
rfile: string | undefined
|
||||
save_stream_file: string | undefined
|
||||
@ -133,7 +132,6 @@ export const defaultState: OptionsState = {
|
||||
proxy_debug: false,
|
||||
proxyauth: undefined,
|
||||
rawtcp: true,
|
||||
rawudp: true,
|
||||
readfile_filter: undefined,
|
||||
rfile: undefined,
|
||||
save_stream_file: undefined,
|
||||
|
Loading…
Reference in New Issue
Block a user