Split mitmproxy.proxy.layers.quic into subpackages (#7187)

* individual coverage: skip logic-free __init__ files

* split quic layer into subpackages

this commit should not introduce any functional changes
This commit is contained in:
Maximilian Hils 2024-09-18 19:22:51 +02:00 committed by GitHub
parent e7d1ad69b9
commit b53d2bd19a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 1905 additions and 1706 deletions

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,41 @@
from ._client_hello_parser import quic_parse_client_hello_from_datagrams
from ._commands import CloseQuicConnection
from ._commands import ResetQuicStream
from ._commands import SendQuicStreamData
from ._commands import StopSendingQuicStream
from ._events import QuicConnectionClosed
from ._events import QuicStreamDataReceived
from ._events import QuicStreamEvent
from ._events import QuicStreamReset
from ._events import QuicStreamStopSending
from ._hooks import QuicStartClientHook
from ._hooks import QuicStartServerHook
from ._hooks import QuicTlsData
from ._hooks import QuicTlsSettings
from ._raw_layers import QuicStreamLayer
from ._raw_layers import RawQuicLayer
from ._stream_layers import ClientQuicLayer
from ._stream_layers import error_code_to_str
from ._stream_layers import ServerQuicLayer
__all__ = [
"quic_parse_client_hello_from_datagrams",
"CloseQuicConnection",
"ResetQuicStream",
"SendQuicStreamData",
"StopSendingQuicStream",
"QuicConnectionClosed",
"QuicStreamDataReceived",
"QuicStreamEvent",
"QuicStreamReset",
"QuicStreamStopSending",
"QuicStartClientHook",
"QuicStartServerHook",
"QuicTlsData",
"QuicTlsSettings",
"QuicStreamLayer",
"RawQuicLayer",
"ClientQuicLayer",
"error_code_to_str",
"ServerQuicLayer",
]

View File

@ -0,0 +1,111 @@
"""
This module contains a very terrible QUIC client hello parser.
Nothing is more permanent than a temporary solution!
"""
from __future__ import annotations
import time
from dataclasses import dataclass
from typing import Optional
from aioquic.buffer import Buffer as QuicBuffer
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.connection import QuicConnection
from aioquic.quic.connection import QuicConnectionError
from aioquic.quic.logger import QuicLogger
from aioquic.quic.packet import PACKET_TYPE_INITIAL
from aioquic.quic.packet import pull_quic_header
from aioquic.tls import HandshakeType
from mitmproxy.tls import ClientHello
@dataclass
class QuicClientHello(Exception):
"""Helper error only used in `quic_parse_client_hello_from_datagrams`."""
data: bytes
def quic_parse_client_hello_from_datagrams(
datagrams: list[bytes],
) -> Optional[ClientHello]:
"""
Check if the supplied bytes contain a full ClientHello message,
and if so, parse it.
Args:
- msgs: list of ClientHello fragments received from client
Returns:
- A ClientHello object on success
- None, if the QUIC record is incomplete
Raises:
- A ValueError, if the passed ClientHello is invalid
"""
# ensure the first packet is indeed the initial one
buffer = QuicBuffer(data=datagrams[0])
header = pull_quic_header(buffer, 8)
if header.packet_type != PACKET_TYPE_INITIAL:
raise ValueError("Packet is not initial one.")
# patch aioquic to intercept the client hello
quic = QuicConnection(
configuration=QuicConfiguration(
is_client=False,
certificate="",
private_key="",
quic_logger=QuicLogger(),
),
original_destination_connection_id=header.destination_cid,
)
_initialize = quic._initialize
def server_handle_hello_replacement(
input_buf: QuicBuffer,
initial_buf: QuicBuffer,
handshake_buf: QuicBuffer,
onertt_buf: QuicBuffer,
) -> None:
assert input_buf.pull_uint8() == HandshakeType.CLIENT_HELLO
length = 0
for b in input_buf.pull_bytes(3):
length = (length << 8) | b
offset = input_buf.tell()
raise QuicClientHello(input_buf.data_slice(offset, offset + length))
def initialize_replacement(peer_cid: bytes) -> None:
try:
return _initialize(peer_cid)
finally:
quic.tls._server_handle_hello = server_handle_hello_replacement # type: ignore
quic._initialize = initialize_replacement # type: ignore
try:
for dgm in datagrams:
quic.receive_datagram(dgm, ("0.0.0.0", 0), now=time.time())
except QuicClientHello as hello:
try:
return ClientHello(hello.data)
except EOFError as e:
raise ValueError("Invalid ClientHello data.") from e
except QuicConnectionError as e:
raise ValueError(e.reason_phrase) from e
quic_logger = quic._configuration.quic_logger
assert isinstance(quic_logger, QuicLogger)
traces = quic_logger.to_dict().get("traces")
assert isinstance(traces, list)
for trace in traces:
quic_events = trace.get("events")
for event in quic_events:
if event["name"] == "transport:packet_dropped":
raise ValueError(
f"Invalid ClientHello packet: {event['data']['trigger']}"
)
return None # pragma: no cover # FIXME: this should have test coverage

View File

@ -0,0 +1,92 @@
from __future__ import annotations
from mitmproxy import connection
from mitmproxy.proxy import commands
class QuicStreamCommand(commands.ConnectionCommand):
"""Base class for all QUIC stream commands."""
stream_id: int
"""The ID of the stream the command was issued for."""
def __init__(self, connection: connection.Connection, stream_id: int) -> None:
super().__init__(connection)
self.stream_id = stream_id
class SendQuicStreamData(QuicStreamCommand):
"""Command that sends data on a stream."""
data: bytes
"""The data which should be sent."""
end_stream: bool
"""Whether the FIN bit should be set in the STREAM frame."""
def __init__(
self,
connection: connection.Connection,
stream_id: int,
data: bytes,
end_stream: bool = False,
) -> None:
super().__init__(connection, stream_id)
self.data = data
self.end_stream = end_stream
def __repr__(self):
target = repr(self.connection).partition("(")[0].lower()
end_stream = "[end_stream] " if self.end_stream else ""
return f"SendQuicStreamData({target} on {self.stream_id}, {end_stream}{self.data!r})"
class ResetQuicStream(QuicStreamCommand):
"""Abruptly terminate the sending part of a stream."""
error_code: int
"""An error code indicating why the stream is being reset."""
def __init__(
self, connection: connection.Connection, stream_id: int, error_code: int
) -> None:
super().__init__(connection, stream_id)
self.error_code = error_code
class StopSendingQuicStream(QuicStreamCommand):
"""Request termination of the receiving part of a stream."""
error_code: int
"""An error code indicating why the stream is being stopped."""
def __init__(
self, connection: connection.Connection, stream_id: int, error_code: int
) -> None:
super().__init__(connection, stream_id)
self.error_code = error_code
class CloseQuicConnection(commands.CloseConnection):
"""Close a QUIC connection."""
error_code: int
"The error code which was specified when closing the connection."
frame_type: int | None
"The frame type which caused the connection to be closed, or `None`."
reason_phrase: str
"The human-readable reason for which the connection was closed."
# XXX: A bit much boilerplate right now. Should switch to dataclasses.
def __init__(
self,
conn: connection.Connection,
error_code: int,
frame_type: int | None,
reason_phrase: str,
) -> None:
super().__init__(conn)
self.error_code = error_code
self.frame_type = frame_type
self.reason_phrase = reason_phrase

View File

@ -0,0 +1,70 @@
from __future__ import annotations
from dataclasses import dataclass
from mitmproxy import connection
from mitmproxy.proxy import events
@dataclass
class QuicStreamEvent(events.ConnectionEvent):
"""Base class for all QUIC stream events."""
stream_id: int
"""The ID of the stream the event was fired for."""
@dataclass
class QuicStreamDataReceived(QuicStreamEvent):
"""Event that is fired whenever data is received on a stream."""
data: bytes
"""The data which was received."""
end_stream: bool
"""Whether the STREAM frame had the FIN bit set."""
def __repr__(self):
target = repr(self.connection).partition("(")[0].lower()
end_stream = "[end_stream] " if self.end_stream else ""
return f"QuicStreamDataReceived({target} on {self.stream_id}, {end_stream}{self.data!r})"
@dataclass
class QuicStreamReset(QuicStreamEvent):
"""Event that is fired when the remote peer resets a stream."""
error_code: int
"""The error code that triggered the reset."""
@dataclass
class QuicStreamStopSending(QuicStreamEvent):
"""Event that is fired when the remote peer sends a STOP_SENDING frame."""
error_code: int
"""The application protocol error code."""
class QuicConnectionClosed(events.ConnectionClosed):
"""QUIC connection has been closed."""
error_code: int
"The error code which was specified when closing the connection."
frame_type: int | None
"The frame type which caused the connection to be closed, or `None`."
reason_phrase: str
"The human-readable reason for which the connection was closed."
def __init__(
self,
conn: connection.Connection,
error_code: int,
frame_type: int | None,
reason_phrase: str,
) -> None:
super().__init__(conn)
self.error_code = error_code
self.frame_type = frame_type
self.reason_phrase = reason_phrase

View File

@ -0,0 +1,77 @@
from __future__ import annotations
from dataclasses import dataclass
from dataclasses import field
from ssl import VerifyMode
from aioquic.tls import CipherSuite
from cryptography import x509
from cryptography.hazmat.primitives.asymmetric import dsa
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.asymmetric import rsa
from mitmproxy.proxy import commands
from mitmproxy.tls import TlsData
@dataclass
class QuicTlsSettings:
"""
Settings necessary to establish QUIC's TLS context.
"""
alpn_protocols: list[str] | None = None
"""A list of supported ALPN protocols."""
certificate: x509.Certificate | None = None
"""The certificate to use for the connection."""
certificate_chain: list[x509.Certificate] = field(default_factory=list)
"""A list of additional certificates to send to the peer."""
certificate_private_key: (
dsa.DSAPrivateKey | ec.EllipticCurvePrivateKey | rsa.RSAPrivateKey | None
) = None
"""The certificate's private key."""
cipher_suites: list[CipherSuite] | None = None
"""An optional list of allowed/advertised cipher suites."""
ca_path: str | None = None
"""An optional path to a directory that contains the necessary information to verify the peer certificate."""
ca_file: str | None = None
"""An optional path to a PEM file that will be used to verify the peer certificate."""
verify_mode: VerifyMode | None = None
"""An optional flag that specifies how/if the peer's certificate should be validated."""
@dataclass
class QuicTlsData(TlsData):
"""
Event data for `quic_start_client` and `quic_start_server` event hooks.
"""
settings: QuicTlsSettings | None = None
"""
The associated `QuicTlsSettings` object.
This will be set by an addon in the `quic_start_*` event hooks.
"""
@dataclass
class QuicStartClientHook(commands.StartHook):
"""
TLS negotiation between mitmproxy and a client over QUIC is about to start.
An addon is expected to initialize data.settings.
(by default, this is done by `mitmproxy.addons.tlsconfig`)
"""
data: QuicTlsData
@dataclass
class QuicStartServerHook(commands.StartHook):
"""
TLS negotiation between mitmproxy and a server over QUIC is about to start.
An addon is expected to initialize data.settings.
(by default, this is done by `mitmproxy.addons.tlsconfig`)
"""
data: QuicTlsData

View File

@ -0,0 +1,429 @@
"""
This module contains the proxy layers for raw QUIC proxying.
This is used if we want to speak QUIC, but we do not want to do HTTP.
"""
from __future__ import annotations
import time
from aioquic.quic.connection import QuicErrorCode
from aioquic.quic.connection import stream_is_client_initiated
from aioquic.quic.connection import stream_is_unidirectional
from ._commands import CloseQuicConnection
from ._commands import ResetQuicStream
from ._commands import SendQuicStreamData
from ._commands import StopSendingQuicStream
from ._events import QuicConnectionClosed
from ._events import QuicStreamDataReceived
from ._events import QuicStreamEvent
from ._events import QuicStreamReset
from mitmproxy import connection
from mitmproxy.proxy import commands
from mitmproxy.proxy import context
from mitmproxy.proxy import events
from mitmproxy.proxy import layer
from mitmproxy.proxy import tunnel
from mitmproxy.proxy.layers.tcp import TCPLayer
from mitmproxy.proxy.layers.udp import UDPLayer
class QuicStreamNextLayer(layer.NextLayer):
"""`NextLayer` variant that callbacks `QuicStreamLayer` after layer decision."""
def __init__(
self,
context: context.Context,
stream: QuicStreamLayer,
ask_on_start: bool = False,
) -> None:
super().__init__(context, ask_on_start)
self._stream = stream
self._layer: layer.Layer | None = None
@property # type: ignore
def layer(self) -> layer.Layer | None: # type: ignore
return self._layer
@layer.setter
def layer(self, value: layer.Layer | None) -> None:
self._layer = value
if self._layer:
self._stream.refresh_metadata()
class QuicStreamLayer(layer.Layer):
"""
Layer for QUIC streams.
Serves as a marker for NextLayer and keeps track of the connection states.
"""
client: connection.Client
"""Virtual client connection for this stream. Use this in QuicRawLayer instead of `context.client`."""
server: connection.Server
"""Virtual server connection for this stream. Use this in QuicRawLayer instead of `context.server`."""
child_layer: layer.Layer
"""The stream's child layer."""
def __init__(self, context: context.Context, ignore: bool, stream_id: int) -> None:
# we mustn't reuse the client from the QUIC connection, as the state and protocol differs
self.client = context.client = context.client.copy()
self.client.transport_protocol = "tcp"
self.client.state = connection.ConnectionState.OPEN
# unidirectional client streams are not fully open, set the appropriate state
if stream_is_unidirectional(stream_id):
self.client.state = (
connection.ConnectionState.CAN_READ
if stream_is_client_initiated(stream_id)
else connection.ConnectionState.CAN_WRITE
)
self._client_stream_id = stream_id
# start with a closed server
self.server = context.server = connection.Server(
address=context.server.address,
transport_protocol="tcp",
)
self._server_stream_id: int | None = None
# ignored connections will be assigned a TCPLayer immediately
super().__init__(context)
self.child_layer = (
TCPLayer(context, ignore=True)
if ignore
else QuicStreamNextLayer(context, self)
)
self.refresh_metadata()
# we don't handle any events, pass everything to the child layer
self.handle_event = self.child_layer.handle_event # type: ignore
self._handle_event = self.child_layer._handle_event # type: ignore
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
raise AssertionError # pragma: no cover
def open_server_stream(self, server_stream_id) -> None:
assert self._server_stream_id is None
self._server_stream_id = server_stream_id
self.server.timestamp_start = time.time()
self.server.state = (
(
connection.ConnectionState.CAN_WRITE
if stream_is_client_initiated(server_stream_id)
else connection.ConnectionState.CAN_READ
)
if stream_is_unidirectional(server_stream_id)
else connection.ConnectionState.OPEN
)
self.refresh_metadata()
def refresh_metadata(self) -> None:
# find the first transport layer
child_layer: layer.Layer | None = self.child_layer
while True:
if isinstance(child_layer, layer.NextLayer):
child_layer = child_layer.layer
elif isinstance(child_layer, tunnel.TunnelLayer):
child_layer = child_layer.child_layer
else:
break # pragma: no cover
if isinstance(child_layer, (UDPLayer, TCPLayer)) and child_layer.flow:
child_layer.flow.metadata["quic_is_unidirectional"] = (
stream_is_unidirectional(self._client_stream_id)
)
child_layer.flow.metadata["quic_initiator"] = (
"client"
if stream_is_client_initiated(self._client_stream_id)
else "server"
)
child_layer.flow.metadata["quic_stream_id_client"] = self._client_stream_id
child_layer.flow.metadata["quic_stream_id_server"] = self._server_stream_id
def stream_id(self, client: bool) -> int | None:
return self._client_stream_id if client else self._server_stream_id
class RawQuicLayer(layer.Layer):
"""
This layer is responsible for de-multiplexing QUIC streams into an individual layer stack per stream.
"""
ignore: bool
"""Indicates whether traffic should be routed as-is."""
datagram_layer: layer.Layer
"""
The layer that is handling datagrams over QUIC. It's like a child_layer, but with a forked context.
Instead of having a datagram-equivalent for all `QuicStream*` classes, we use `SendData` and `DataReceived` instead.
There is also no need for another `NextLayer` marker, as a missing `QuicStreamLayer` implies UDP,
and the connection state is the same as the one of the underlying QUIC connection.
"""
client_stream_ids: dict[int, QuicStreamLayer]
"""Maps stream IDs from the client connection to stream layers."""
server_stream_ids: dict[int, QuicStreamLayer]
"""Maps stream IDs from the server connection to stream layers."""
connections: dict[connection.Connection, layer.Layer]
"""Maps connections to layers."""
command_sources: dict[commands.Command, layer.Layer]
"""Keeps track of blocking commands and wakeup requests."""
next_stream_id: list[int]
"""List containing the next stream ID for all four is_unidirectional/is_client combinations."""
def __init__(self, context: context.Context, ignore: bool = False) -> None:
super().__init__(context)
self.ignore = ignore
self.datagram_layer = (
UDPLayer(self.context.fork(), ignore=True)
if ignore
else layer.NextLayer(self.context.fork())
)
self.client_stream_ids = {}
self.server_stream_ids = {}
self.connections = {
context.client: self.datagram_layer,
context.server: self.datagram_layer,
}
self.command_sources = {}
self.next_stream_id = [0, 1, 2, 3]
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
# we treat the datagram layer as child layer, so forward Start
if isinstance(event, events.Start):
if self.context.server.timestamp_start is None:
err = yield commands.OpenConnection(self.context.server)
if err:
yield commands.CloseConnection(self.context.client)
self._handle_event = self.done # type: ignore
return
yield from self.event_to_child(self.datagram_layer, event)
# properly forward completion events based on their command
elif isinstance(event, events.CommandCompleted):
yield from self.event_to_child(
self.command_sources.pop(event.command), event
)
# route injected messages based on their connections (prefer client, fallback to server)
elif isinstance(event, events.MessageInjected):
if event.flow.client_conn in self.connections:
yield from self.event_to_child(
self.connections[event.flow.client_conn], event
)
elif event.flow.server_conn in self.connections:
yield from self.event_to_child(
self.connections[event.flow.server_conn], event
)
else:
raise AssertionError(f"Flow not associated: {event.flow!r}")
# handle stream events targeting this context
elif isinstance(event, QuicStreamEvent) and (
event.connection is self.context.client
or event.connection is self.context.server
):
from_client = event.connection is self.context.client
# fetch or create the layer
stream_ids = (
self.client_stream_ids if from_client else self.server_stream_ids
)
if event.stream_id in stream_ids:
stream_layer = stream_ids[event.stream_id]
else:
# ensure we haven't just forgotten to register the ID
assert stream_is_client_initiated(event.stream_id) == from_client
# for server-initiated streams we need to open the client as well
if from_client:
client_stream_id = event.stream_id
server_stream_id = None
else:
client_stream_id = self.get_next_available_stream_id(
is_client=False,
is_unidirectional=stream_is_unidirectional(event.stream_id),
)
server_stream_id = event.stream_id
# create, register and start the layer
stream_layer = QuicStreamLayer(
self.context.fork(), self.ignore, client_stream_id
)
self.client_stream_ids[client_stream_id] = stream_layer
if server_stream_id is not None:
stream_layer.open_server_stream(server_stream_id)
self.server_stream_ids[server_stream_id] = stream_layer
self.connections[stream_layer.client] = stream_layer
self.connections[stream_layer.server] = stream_layer
yield from self.event_to_child(stream_layer, events.Start())
# forward data and close events
conn = stream_layer.client if from_client else stream_layer.server
if isinstance(event, QuicStreamDataReceived):
if event.data:
yield from self.event_to_child(
stream_layer, events.DataReceived(conn, event.data)
)
if event.end_stream:
yield from self.close_stream_layer(stream_layer, from_client)
elif isinstance(event, QuicStreamReset):
# preserve stream resets
for command in self.close_stream_layer(stream_layer, from_client):
if (
isinstance(command, SendQuicStreamData)
and command.stream_id == stream_layer.stream_id(not from_client)
and command.end_stream
and not command.data
):
yield ResetQuicStream(
command.connection, command.stream_id, event.error_code
)
else:
yield command
else:
raise AssertionError(f"Unexpected stream event: {event!r}")
# handle close events that target this context
elif isinstance(event, QuicConnectionClosed) and (
event.connection is self.context.client
or event.connection is self.context.server
):
from_client = event.connection is self.context.client
other_conn = self.context.server if from_client else self.context.client
# be done if both connections are closed
if other_conn.connected:
yield CloseQuicConnection(
other_conn, event.error_code, event.frame_type, event.reason_phrase
)
else:
self._handle_event = self.done # type: ignore
# always forward to the datagram layer and swallow `CloseConnection` commands
for command in self.event_to_child(self.datagram_layer, event):
if (
not isinstance(command, commands.CloseConnection)
or command.connection is not other_conn
):
yield command
# forward to either the client or server connection of stream layers and swallow empty stream end
for conn, child_layer in self.connections.items():
if isinstance(child_layer, QuicStreamLayer) and (
(conn is child_layer.client)
if from_client
else (conn is child_layer.server)
):
conn.state &= ~connection.ConnectionState.CAN_WRITE
for command in self.close_stream_layer(child_layer, from_client):
if not isinstance(command, SendQuicStreamData) or command.data:
yield command
# all other connection events are routed to their corresponding layer
elif isinstance(event, events.ConnectionEvent):
yield from self.event_to_child(self.connections[event.connection], event)
else:
raise AssertionError(f"Unexpected event: {event!r}")
def close_stream_layer(
self, stream_layer: QuicStreamLayer, client: bool
) -> layer.CommandGenerator[None]:
"""Closes the incoming part of a connection."""
conn = stream_layer.client if client else stream_layer.server
conn.state &= ~connection.ConnectionState.CAN_READ
assert conn.timestamp_start is not None
if conn.timestamp_end is None:
conn.timestamp_end = time.time()
yield from self.event_to_child(stream_layer, events.ConnectionClosed(conn))
def event_to_child(
self, child_layer: layer.Layer, event: events.Event
) -> layer.CommandGenerator[None]:
"""Forwards events to child layers and translates commands."""
for command in child_layer.handle_event(event):
# intercept commands for streams connections
if (
isinstance(child_layer, QuicStreamLayer)
and isinstance(command, commands.ConnectionCommand)
and (
command.connection is child_layer.client
or command.connection is child_layer.server
)
):
# get the target connection and stream ID
to_client = command.connection is child_layer.client
quic_conn = self.context.client if to_client else self.context.server
stream_id = child_layer.stream_id(to_client)
# write data and check CloseConnection wasn't called before
if isinstance(command, commands.SendData):
assert stream_id is not None
if command.connection.state & connection.ConnectionState.CAN_WRITE:
yield SendQuicStreamData(quic_conn, stream_id, command.data)
# send a FIN and optionally also a STOP frame
elif isinstance(command, commands.CloseConnection):
assert stream_id is not None
if command.connection.state & connection.ConnectionState.CAN_WRITE:
command.connection.state &= (
~connection.ConnectionState.CAN_WRITE
)
yield SendQuicStreamData(
quic_conn, stream_id, b"", end_stream=True
)
# XXX: Use `command.connection.state & connection.ConnectionState.CAN_READ` instead?
only_close_our_half = (
isinstance(command, commands.CloseTcpConnection)
and command.half_close
)
if not only_close_our_half:
if stream_is_client_initiated(
stream_id
) == to_client or not stream_is_unidirectional(stream_id):
yield StopSendingQuicStream(
quic_conn, stream_id, QuicErrorCode.NO_ERROR
)
yield from self.close_stream_layer(child_layer, to_client)
# open server connections by reserving the next stream ID
elif isinstance(command, commands.OpenConnection):
assert not to_client
assert stream_id is None
client_stream_id = child_layer.stream_id(client=True)
assert client_stream_id is not None
stream_id = self.get_next_available_stream_id(
is_client=True,
is_unidirectional=stream_is_unidirectional(client_stream_id),
)
child_layer.open_server_stream(stream_id)
self.server_stream_ids[stream_id] = child_layer
yield from self.event_to_child(
child_layer, events.OpenConnectionCompleted(command, None)
)
else:
raise AssertionError(
f"Unexpected stream connection command: {command!r}"
)
# remember blocking and wakeup commands
else:
if command.blocking or isinstance(command, commands.RequestWakeup):
self.command_sources[command] = child_layer
if isinstance(command, commands.OpenConnection):
self.connections[command.connection] = child_layer
yield command
def get_next_available_stream_id(
self, is_client: bool, is_unidirectional: bool = False
) -> int:
index = (int(is_unidirectional) << 1) | int(not is_client)
stream_id = self.next_stream_id[index]
self.next_stream_id[index] = stream_id + 4
return stream_id
def done(self, _) -> layer.CommandGenerator[None]: # pragma: no cover
yield from ()

View File

@ -0,0 +1,638 @@
"""
This module contains the client and server proxy layers for QUIC streams
which decrypt and encrypt traffic. Decrypted stream data is then forwarded
to either the raw layers, or the HTTP/3 client in ../http/_http3.py.
"""
from __future__ import annotations
import time
from collections.abc import Callable
from logging import DEBUG
from logging import ERROR
from logging import WARNING
from aioquic.buffer import Buffer as QuicBuffer
from aioquic.h3.connection import ErrorCode as H3ErrorCode
from aioquic.quic import events as quic_events
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.connection import QuicConnection
from aioquic.quic.connection import QuicConnectionState
from aioquic.quic.connection import QuicErrorCode
from aioquic.quic.packet import encode_quic_version_negotiation
from aioquic.quic.packet import PACKET_TYPE_INITIAL
from aioquic.quic.packet import pull_quic_header
from cryptography import x509
from ._client_hello_parser import quic_parse_client_hello_from_datagrams
from ._commands import CloseQuicConnection
from ._commands import QuicStreamCommand
from ._commands import ResetQuicStream
from ._commands import SendQuicStreamData
from ._commands import StopSendingQuicStream
from ._events import QuicConnectionClosed
from ._events import QuicStreamDataReceived
from ._events import QuicStreamReset
from ._events import QuicStreamStopSending
from ._hooks import QuicStartClientHook
from ._hooks import QuicStartServerHook
from ._hooks import QuicTlsData
from ._hooks import QuicTlsSettings
from mitmproxy import certs
from mitmproxy import connection
from mitmproxy import ctx
from mitmproxy.net import tls
from mitmproxy.proxy import commands
from mitmproxy.proxy import context
from mitmproxy.proxy import events
from mitmproxy.proxy import layer
from mitmproxy.proxy import tunnel
from mitmproxy.proxy.layers.tls import TlsClienthelloHook
from mitmproxy.proxy.layers.tls import TlsEstablishedClientHook
from mitmproxy.proxy.layers.tls import TlsEstablishedServerHook
from mitmproxy.proxy.layers.tls import TlsFailedClientHook
from mitmproxy.proxy.layers.tls import TlsFailedServerHook
from mitmproxy.proxy.layers.udp import UDPLayer
from mitmproxy.tls import ClientHelloData
SUPPORTED_QUIC_VERSIONS_SERVER = QuicConfiguration(is_client=False).supported_versions
class QuicLayer(tunnel.TunnelLayer):
quic: QuicConnection | None = None
tls: QuicTlsSettings | None = None
def __init__(
self,
context: context.Context,
conn: connection.Connection,
time: Callable[[], float] | None,
) -> None:
super().__init__(context, tunnel_connection=conn, conn=conn)
self.child_layer = layer.NextLayer(self.context, ask_on_start=True)
self._time = time or ctx.master.event_loop.time
self._wakeup_commands: dict[commands.RequestWakeup, float] = dict()
conn.tls = True
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, events.Wakeup) and event.command in self._wakeup_commands:
# TunnelLayer has no understanding of wakeups, so we turn this into an empty DataReceived event
# which TunnelLayer recognizes as belonging to our connection.
assert self.quic
scheduled_time = self._wakeup_commands.pop(event.command)
if self.quic._state is not QuicConnectionState.TERMINATED:
# weird quirk: asyncio sometimes returns a bit ahead of time.
now = max(scheduled_time, self._time())
self.quic.handle_timer(now)
yield from super()._handle_event(
events.DataReceived(self.tunnel_connection, b"")
)
else:
yield from super()._handle_event(event)
def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]:
# the parent will call _handle_command multiple times, we transmit cumulative afterwards
# this will reduce the number of sends, especially if data=b"" and end_stream=True
yield from super().event_to_child(event)
if self.quic:
yield from self.tls_interact()
def _handle_command(
self, command: commands.Command
) -> layer.CommandGenerator[None]:
"""Turns stream commands into aioquic connection invocations."""
if isinstance(command, QuicStreamCommand) and command.connection is self.conn:
assert self.quic
if isinstance(command, SendQuicStreamData):
self.quic.send_stream_data(
command.stream_id, command.data, command.end_stream
)
elif isinstance(command, ResetQuicStream):
stream = self.quic._get_or_create_stream_for_send(command.stream_id)
existing_reset_error_code = stream.sender._reset_error_code
if existing_reset_error_code is None:
self.quic.reset_stream(command.stream_id, command.error_code)
elif self.debug: # pragma: no cover
yield commands.Log(
f"{self.debug}[quic] stream {stream.stream_id} already reset ({existing_reset_error_code=}, {command.error_code=})",
DEBUG,
)
elif isinstance(command, StopSendingQuicStream):
# the stream might have already been closed, check before stopping
if command.stream_id in self.quic._streams:
self.quic.stop_stream(command.stream_id, command.error_code)
else:
raise AssertionError(f"Unexpected stream command: {command!r}")
else:
yield from super()._handle_command(command)
def start_tls(
self, original_destination_connection_id: bytes | None
) -> layer.CommandGenerator[None]:
"""Initiates the aioquic connection."""
# must only be called if QUIC is uninitialized
assert not self.quic
assert not self.tls
# query addons to provide the necessary TLS settings
tls_data = QuicTlsData(self.conn, self.context)
if self.conn is self.context.client:
yield QuicStartClientHook(tls_data)
else:
yield QuicStartServerHook(tls_data)
if not tls_data.settings:
yield commands.Log(
f"No QUIC context was provided, failing connection.", ERROR
)
yield commands.CloseConnection(self.conn)
return
# build the aioquic connection
configuration = tls_settings_to_configuration(
settings=tls_data.settings,
is_client=self.conn is self.context.server,
server_name=self.conn.sni,
)
self.quic = QuicConnection(
configuration=configuration,
original_destination_connection_id=original_destination_connection_id,
)
self.tls = tls_data.settings
# if we act as client, connect to upstream
if original_destination_connection_id is None:
self.quic.connect(self.conn.peername, now=self._time())
yield from self.tls_interact()
def tls_interact(self) -> layer.CommandGenerator[None]:
"""Retrieves all pending outgoing packets from aioquic and sends the data."""
# send all queued datagrams
assert self.quic
now = self._time()
for data, addr in self.quic.datagrams_to_send(now=now):
assert addr == self.conn.peername
yield commands.SendData(self.tunnel_connection, data)
timer = self.quic.get_timer()
if timer is not None:
# smooth wakeups a bit.
smoothed = timer + 0.002
# request a new wakeup if all pending requests trigger at a later time
if not any(
existing <= smoothed for existing in self._wakeup_commands.values()
):
command = commands.RequestWakeup(timer - now)
self._wakeup_commands[command] = timer
yield command
def receive_handshake_data(
self, data: bytes
) -> layer.CommandGenerator[tuple[bool, str | None]]:
assert self.quic
# forward incoming data to aioquic
if data:
self.quic.receive_datagram(data, self.conn.peername, now=self._time())
# handle pre-handshake events
while event := self.quic.next_event():
if isinstance(event, quic_events.ConnectionTerminated):
err = event.reason_phrase or error_code_to_str(event.error_code)
return False, err
elif isinstance(event, quic_events.HandshakeCompleted):
# concatenate all peer certificates
all_certs: list[x509.Certificate] = []
if self.quic.tls._peer_certificate:
all_certs.append(self.quic.tls._peer_certificate)
all_certs.extend(self.quic.tls._peer_certificate_chain)
# set the connection's TLS properties
self.conn.timestamp_tls_setup = time.time()
if event.alpn_protocol:
self.conn.alpn = event.alpn_protocol.encode("ascii")
self.conn.certificate_list = [certs.Cert(cert) for cert in all_certs]
assert self.quic.tls.key_schedule
self.conn.cipher = self.quic.tls.key_schedule.cipher_suite.name
self.conn.tls_version = "QUIC"
# log the result and report the success to addons
if self.debug:
yield commands.Log(
f"{self.debug}[quic] tls established: {self.conn}", DEBUG
)
if self.conn is self.context.client:
yield TlsEstablishedClientHook(
QuicTlsData(self.conn, self.context, settings=self.tls)
)
else:
yield TlsEstablishedServerHook(
QuicTlsData(self.conn, self.context, settings=self.tls)
)
yield from self.tls_interact()
return True, None
elif isinstance(
event,
(
quic_events.ConnectionIdIssued,
quic_events.ConnectionIdRetired,
quic_events.PingAcknowledged,
quic_events.ProtocolNegotiated,
),
):
pass
else:
raise AssertionError(f"Unexpected event: {event!r}")
# transmit buffered data and re-arm timer
yield from self.tls_interact()
return False, None
def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
self.conn.error = err
if self.conn is self.context.client:
yield TlsFailedClientHook(
QuicTlsData(self.conn, self.context, settings=self.tls)
)
else:
yield TlsFailedServerHook(
QuicTlsData(self.conn, self.context, settings=self.tls)
)
yield from super().on_handshake_error(err)
def receive_data(self, data: bytes) -> layer.CommandGenerator[None]:
assert self.quic
# forward incoming data to aioquic
if data:
self.quic.receive_datagram(data, self.conn.peername, now=self._time())
# handle post-handshake events
while event := self.quic.next_event():
if isinstance(event, quic_events.ConnectionTerminated):
if self.debug:
reason = event.reason_phrase or error_code_to_str(event.error_code)
yield commands.Log(
f"{self.debug}[quic] close_notify {self.conn} (reason={reason})",
DEBUG,
)
# We don't rely on `ConnectionTerminated` to dispatch `QuicConnectionClosed`, because
# after aioquic receives a termination frame, it still waits for the next `handle_timer`
# before returning `ConnectionTerminated` in `next_event`. In the meantime, the underlying
# connection could be closed. Therefore, we instead dispatch on `ConnectionClosed` and simply
# close the connection here.
yield commands.CloseConnection(self.tunnel_connection)
return # we don't handle any further events, nor do/can we transmit data, so exit
elif isinstance(event, quic_events.DatagramFrameReceived):
yield from self.event_to_child(
events.DataReceived(self.conn, event.data)
)
elif isinstance(event, quic_events.StreamDataReceived):
yield from self.event_to_child(
QuicStreamDataReceived(
self.conn, event.stream_id, event.data, event.end_stream
)
)
elif isinstance(event, quic_events.StreamReset):
yield from self.event_to_child(
QuicStreamReset(self.conn, event.stream_id, event.error_code)
)
elif isinstance(event, quic_events.StopSendingReceived):
yield from self.event_to_child(
QuicStreamStopSending(self.conn, event.stream_id, event.error_code)
)
elif isinstance(
event,
(
quic_events.ConnectionIdIssued,
quic_events.ConnectionIdRetired,
quic_events.PingAcknowledged,
quic_events.ProtocolNegotiated,
),
):
pass
else:
raise AssertionError(f"Unexpected event: {event!r}")
# transmit buffered data and re-arm timer
yield from self.tls_interact()
def receive_close(self) -> layer.CommandGenerator[None]:
assert self.quic
# if `_close_event` is not set, the underlying connection has been closed
# we turn this into a QUIC close event as well
close_event = self.quic._close_event or quic_events.ConnectionTerminated(
QuicErrorCode.NO_ERROR, None, "Connection closed."
)
yield from self.event_to_child(
QuicConnectionClosed(
self.conn,
close_event.error_code,
close_event.frame_type,
close_event.reason_phrase,
)
)
def send_data(self, data: bytes) -> layer.CommandGenerator[None]:
# non-stream data uses datagram frames
assert self.quic
if data:
self.quic.send_datagram_frame(data)
yield from self.tls_interact()
def send_close(
self, command: commands.CloseConnection
) -> layer.CommandGenerator[None]:
# properly close the QUIC connection
if self.quic:
if isinstance(command, CloseQuicConnection):
self.quic.close(
command.error_code, command.frame_type, command.reason_phrase
)
else:
self.quic.close()
yield from self.tls_interact()
yield from super().send_close(command)
class ServerQuicLayer(QuicLayer):
"""
This layer establishes QUIC for a single server connection.
"""
wait_for_clienthello: bool = False
def __init__(
self,
context: context.Context,
conn: connection.Server | None = None,
time: Callable[[], float] | None = None,
):
super().__init__(context, conn or context.server, time)
def start_handshake(self) -> layer.CommandGenerator[None]:
wait_for_clienthello = not self.command_to_reply_to and isinstance(
self.child_layer, ClientQuicLayer
)
if wait_for_clienthello:
self.wait_for_clienthello = True
self.tunnel_state = tunnel.TunnelState.CLOSED
else:
yield from self.start_tls(None)
def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]:
if self.wait_for_clienthello:
for command in super().event_to_child(event):
if (
isinstance(command, commands.OpenConnection)
and command.connection == self.conn
):
self.wait_for_clienthello = False
else:
yield command
else:
yield from super().event_to_child(event)
def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
yield commands.Log(f"Server QUIC handshake failed. {err}", level=WARNING)
yield from super().on_handshake_error(err)
class ClientQuicLayer(QuicLayer):
"""
This layer establishes QUIC on a single client connection.
"""
server_tls_available: bool
"""Indicates whether the parent layer is a ServerQuicLayer."""
handshake_datagram_buf: list[bytes]
def __init__(
self, context: context.Context, time: Callable[[], float] | None = None
) -> None:
# same as ClientTLSLayer, we might be nested in some other transport
if context.client.tls:
context.client.alpn = None
context.client.cipher = None
context.client.sni = None
context.client.timestamp_tls_setup = None
context.client.tls_version = None
context.client.certificate_list = []
context.client.mitmcert = None
context.client.alpn_offers = []
context.client.cipher_list = []
super().__init__(context, context.client, time)
self.server_tls_available = len(self.context.layers) >= 2 and isinstance(
self.context.layers[-2], ServerQuicLayer
)
self.handshake_datagram_buf = []
def start_handshake(self) -> layer.CommandGenerator[None]:
yield from ()
def receive_handshake_data(
self, data: bytes
) -> layer.CommandGenerator[tuple[bool, str | None]]:
if not self.context.options.http3:
yield commands.Log(
f"Swallowing QUIC handshake because HTTP/3 is disabled.", DEBUG
)
return False, None
# if we already had a valid client hello, don't process further packets
if self.tls:
return (yield from super().receive_handshake_data(data))
# fail if the received data is not a QUIC packet
buffer = QuicBuffer(data=data)
try:
header = pull_quic_header(buffer)
except TypeError:
return False, f"Cannot parse QUIC header: Malformed head ({data.hex()})"
except ValueError as e:
return False, f"Cannot parse QUIC header: {e} ({data.hex()})"
# negotiate version, support all versions known to aioquic
if (
header.version is not None
and header.version not in SUPPORTED_QUIC_VERSIONS_SERVER
):
yield commands.SendData(
self.tunnel_connection,
encode_quic_version_negotiation(
source_cid=header.destination_cid,
destination_cid=header.source_cid,
supported_versions=SUPPORTED_QUIC_VERSIONS_SERVER,
),
)
return False, None
# ensure it's (likely) a client handshake packet
if len(data) < 1200 or header.packet_type != PACKET_TYPE_INITIAL:
return (
False,
f"Invalid handshake received, roaming not supported. ({data.hex()})",
)
self.handshake_datagram_buf.append(data)
# extract the client hello
try:
client_hello = quic_parse_client_hello_from_datagrams(
self.handshake_datagram_buf
)
except ValueError as e:
msgs = b"\n".join(self.handshake_datagram_buf)
dbg = f"Cannot parse ClientHello: {str(e)} ({msgs.hex()})"
self.handshake_datagram_buf.clear()
return False, dbg
if not client_hello:
return False, None
# copy the client hello information
self.conn.sni = client_hello.sni
self.conn.alpn_offers = client_hello.alpn_protocols
# check with addons what we shall do
tls_clienthello = ClientHelloData(self.context, client_hello)
yield TlsClienthelloHook(tls_clienthello)
# replace the QUIC layer with an UDP layer if requested
if tls_clienthello.ignore_connection:
self.conn = self.tunnel_connection = connection.Client(
peername=("ignore-conn", 0),
sockname=("ignore-conn", 0),
transport_protocol="udp",
state=connection.ConnectionState.OPEN,
)
# we need to replace the server layer as well, if there is one
parent_layer = self.context.layers[self.context.layers.index(self) - 1]
if isinstance(parent_layer, ServerQuicLayer):
parent_layer.conn = parent_layer.tunnel_connection = connection.Server(
address=None
)
replacement_layer = UDPLayer(self.context, ignore=True)
parent_layer.handle_event = replacement_layer.handle_event # type: ignore
parent_layer._handle_event = replacement_layer._handle_event # type: ignore
yield from parent_layer.handle_event(events.Start())
for dgm in self.handshake_datagram_buf:
yield from parent_layer.handle_event(
events.DataReceived(self.context.client, dgm)
)
self.handshake_datagram_buf.clear()
return True, None
# start the server QUIC connection if demanded and available
if (
tls_clienthello.establish_server_tls_first
and not self.context.server.tls_established
):
err = yield from self.start_server_tls()
if err:
yield commands.Log(
f"Unable to establish QUIC connection with server ({err}). "
f"Trying to establish QUIC with client anyway. "
f"If you plan to redirect requests away from this server, "
f"consider setting `connection_strategy` to `lazy` to suppress early connections."
)
# start the client QUIC connection
yield from self.start_tls(header.destination_cid)
# XXX copied from TLS, we assume that `CloseConnection` in `start_tls` takes effect immediately
if not self.conn.connected:
return False, "connection closed early"
# send the client hello to aioquic
assert self.quic
for dgm in self.handshake_datagram_buf:
self.quic.receive_datagram(dgm, self.conn.peername, now=self._time())
self.handshake_datagram_buf.clear()
# handle events emanating from `self.quic`
return (yield from super().receive_handshake_data(b""))
def start_server_tls(self) -> layer.CommandGenerator[str | None]:
if not self.server_tls_available:
return f"No server QUIC available."
err = yield commands.OpenConnection(self.context.server)
return err
def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
yield commands.Log(f"Client QUIC handshake failed. {err}", level=WARNING)
yield from super().on_handshake_error(err)
self.event_to_child = self.errored # type: ignore
def errored(self, event: events.Event) -> layer.CommandGenerator[None]:
if self.debug is not None:
yield commands.Log(
f"{self.debug}[quic] Swallowing {event} as handshake failed.", DEBUG
)
class QuicSecretsLogger:
logger: tls.MasterSecretLogger
def __init__(self, logger: tls.MasterSecretLogger) -> None:
super().__init__()
self.logger = logger
def write(self, s: str) -> int:
if s[-1:] == "\n":
s = s[:-1]
data = s.encode("ascii")
self.logger(None, data) # type: ignore
return len(data) + 1
def flush(self) -> None:
# done by the logger during write
pass
def error_code_to_str(error_code: int) -> str:
"""Returns the corresponding name of the given error code or a string containing its numeric value."""
try:
return H3ErrorCode(error_code).name
except ValueError:
try:
return QuicErrorCode(error_code).name
except ValueError:
return f"unknown error (0x{error_code:x})"
def is_success_error_code(error_code: int) -> bool:
"""Returns whether the given error code actually indicates no error."""
return error_code in (QuicErrorCode.NO_ERROR, H3ErrorCode.H3_NO_ERROR)
def tls_settings_to_configuration(
settings: QuicTlsSettings,
is_client: bool,
server_name: str | None = None,
) -> QuicConfiguration:
"""Converts `QuicTlsSettings` to `QuicConfiguration`."""
return QuicConfiguration(
alpn_protocols=settings.alpn_protocols,
is_client=is_client,
secrets_log_file=(
QuicSecretsLogger(tls.log_master_secret) # type: ignore
if tls.log_master_secret is not None
else None
),
server_name=server_name,
cafile=settings.ca_file,
capath=settings.ca_path,
certificate=settings.certificate,
certificate_chain=settings.certificate_chain,
cipher_suites=settings.cipher_suites,
private_key=settings.certificate_private_key,
verify_mode=settings.verify_mode,
max_datagram_frame_size=65536,
)

View File

@ -18,8 +18,8 @@ from mitmproxy.proxy.layers import modes
from mitmproxy.proxy.layers import quic
from mitmproxy.proxy.layers import tls as proxy_tls
from mitmproxy.test import taddons
from test.mitmproxy.proxy.layers import test_quic
from test.mitmproxy.proxy.layers import test_tls
from test.mitmproxy.proxy.layers.quic import test__stream_layers as test_quic
def test_alpn_select_callback():

View File

@ -13,6 +13,7 @@ from aioquic.h3.connection import Headers as H3Headers
from aioquic.h3.connection import parse_settings
from aioquic.h3.connection import Setting
from aioquic.h3.connection import StreamType
from aioquic.quic.packet import QuicErrorCode
from mitmproxy import connection
from mitmproxy import version
@ -806,13 +807,13 @@ def test_rst_then_close(tctx):
>> cff.receive_data(b"unexpected data frame")
<< quic.CloseQuicConnection(
tctx.client,
error_code=quic.QuicErrorCode.PROTOCOL_VIOLATION.value,
error_code=QuicErrorCode.PROTOCOL_VIOLATION.value,
frame_type=None,
reason_phrase=err,
)
>> quic.QuicConnectionClosed(
tctx.client,
error_code=quic.QuicErrorCode.PROTOCOL_VIOLATION.value,
error_code=QuicErrorCode.PROTOCOL_VIOLATION.value,
frame_type=None,
reason_phrase=err,
)

View File

@ -0,0 +1,53 @@
import pytest
from aioquic.quic.connection import QuicConnection
from aioquic.quic.connection import QuicConnectionError
from mitmproxy.proxy.layers.quic import _client_hello_parser
from mitmproxy.proxy.layers.quic._client_hello_parser import (
quic_parse_client_hello_from_datagrams,
)
from test.mitmproxy.proxy.layers.quic.test__stream_layers import client_hello
class TestParseClientHello:
def test_input(self):
assert (
quic_parse_client_hello_from_datagrams([client_hello]).sni == "example.com"
)
with pytest.raises(ValueError):
quic_parse_client_hello_from_datagrams(
[client_hello[:183] + b"\x00\x00\x00\x00\x00\x00\x00\x00\x00"]
)
with pytest.raises(ValueError, match="not initial"):
quic_parse_client_hello_from_datagrams(
[
b"\\s\xd8\xd8\xa5dT\x8bc\xd3\xae\x1c\xb2\x8a7-\x1d\x19j\x85\xb0~\x8c\x80\xa5\x8cY\xac\x0ecK\x7fC2f\xbcm\x1b\xac~"
]
)
def test_invalid(self, monkeypatch):
# XXX: This test is terrible, it should use actual invalid data.
class InvalidClientHello(Exception):
@property
def data(self):
raise EOFError()
monkeypatch.setattr(_client_hello_parser, "QuicClientHello", InvalidClientHello)
with pytest.raises(ValueError, match="Invalid ClientHello"):
quic_parse_client_hello_from_datagrams([client_hello])
def test_connection_error(self, monkeypatch):
def raise_conn_err(self, data, addr, now):
raise QuicConnectionError(0, 0, "Conn err")
monkeypatch.setattr(QuicConnection, "receive_datagram", raise_conn_err)
with pytest.raises(ValueError, match="Conn err"):
quic_parse_client_hello_from_datagrams([client_hello])
def test_no_return(self):
with pytest.raises(
ValueError, match="Invalid ClientHello packet: payload_decrypt_error"
):
quic_parse_client_hello_from_datagrams(
[client_hello[0:1200] + b"\x00" + client_hello[1200:]]
)

View File

@ -0,0 +1,15 @@
from mitmproxy.proxy.layers.quic._commands import CloseQuicConnection
from mitmproxy.proxy.layers.quic._commands import QuicStreamCommand
from mitmproxy.proxy.layers.quic._commands import ResetQuicStream
from mitmproxy.proxy.layers.quic._commands import SendQuicStreamData
from mitmproxy.proxy.layers.quic._commands import StopSendingQuicStream
from mitmproxy.test.tflow import tclient_conn
def test_reprs():
client = tclient_conn()
assert repr(QuicStreamCommand(client, 42))
assert repr(SendQuicStreamData(client, 42, b"data"))
assert repr(ResetQuicStream(client, 42, 0xFF))
assert repr(StopSendingQuicStream(client, 42, 0xFF))
assert repr(CloseQuicConnection(client, 0xFF, None, "reason"))

View File

@ -0,0 +1,9 @@
from mitmproxy.proxy.layers.quic._events import QuicConnectionClosed
from mitmproxy.proxy.layers.quic._events import QuicStreamDataReceived
from mitmproxy.test.tflow import tclient_conn
def test_reprs():
client = tclient_conn()
assert repr(QuicStreamDataReceived(client, 42, b"data", end_stream=False))
assert repr(QuicConnectionClosed(client, 0xFF, None, "reason"))

View File

@ -0,0 +1,19 @@
from mitmproxy.options import Options
from mitmproxy.proxy.context import Context
from mitmproxy.proxy.layers.quic._hooks import QuicStartServerHook
from mitmproxy.proxy.layers.quic._hooks import QuicTlsData
from mitmproxy.proxy.layers.quic._hooks import QuicTlsSettings
from mitmproxy.test.tflow import tclient_conn
def test_reprs():
client = tclient_conn()
assert repr(
QuicStartServerHook(
data=QuicTlsData(
conn=client,
context=Context(client, Options()),
settings=QuicTlsSettings(),
)
)
)

View File

@ -0,0 +1,266 @@
import pytest
from mitmproxy import connection
from mitmproxy.proxy import commands
from mitmproxy.proxy import context
from mitmproxy.proxy import events
from mitmproxy.proxy import layer
from mitmproxy.proxy import layers
from mitmproxy.proxy import tunnel
from mitmproxy.proxy.layers import tcp
from mitmproxy.proxy.layers import udp
from mitmproxy.proxy.layers.quic._commands import CloseQuicConnection
from mitmproxy.proxy.layers.quic._commands import ResetQuicStream
from mitmproxy.proxy.layers.quic._commands import SendQuicStreamData
from mitmproxy.proxy.layers.quic._commands import StopSendingQuicStream
from mitmproxy.proxy.layers.quic._events import QuicConnectionClosed
from mitmproxy.proxy.layers.quic._events import QuicStreamDataReceived
from mitmproxy.proxy.layers.quic._events import QuicStreamEvent
from mitmproxy.proxy.layers.quic._events import QuicStreamReset
from mitmproxy.proxy.layers.quic._raw_layers import QuicStreamLayer
from mitmproxy.proxy.layers.quic._raw_layers import RawQuicLayer
from mitmproxy.tcp import TCPFlow
from mitmproxy.udp import UDPFlow
from mitmproxy.udp import UDPMessage
from test.mitmproxy.proxy import tutils
from test.mitmproxy.proxy.layers.quic.test__stream_layers import TlsEchoLayer
class TestQuicStreamLayer:
def test_ignored(self, tctx: context.Context):
quic_layer = QuicStreamLayer(tctx, True, 1)
assert isinstance(quic_layer.child_layer, layers.TCPLayer)
assert not quic_layer.child_layer.flow
quic_layer.child_layer.flow = TCPFlow(tctx.client, tctx.server)
quic_layer.refresh_metadata()
assert quic_layer.child_layer.flow.metadata["quic_is_unidirectional"] is False
assert quic_layer.child_layer.flow.metadata["quic_initiator"] == "server"
assert quic_layer.child_layer.flow.metadata["quic_stream_id_client"] == 1
assert quic_layer.child_layer.flow.metadata["quic_stream_id_server"] is None
assert quic_layer.stream_id(True) == 1
assert quic_layer.stream_id(False) is None
def test_simple(self, tctx: context.Context):
quic_layer = QuicStreamLayer(tctx, False, 2)
assert isinstance(quic_layer.child_layer, layer.NextLayer)
tunnel_layer = tunnel.TunnelLayer(tctx, tctx.client, tctx.server)
quic_layer.child_layer.layer = tunnel_layer
tcp_layer = layers.TCPLayer(tctx)
tunnel_layer.child_layer = tcp_layer
quic_layer.open_server_stream(3)
assert tcp_layer.flow.metadata["quic_is_unidirectional"] is True
assert tcp_layer.flow.metadata["quic_initiator"] == "client"
assert tcp_layer.flow.metadata["quic_stream_id_client"] == 2
assert tcp_layer.flow.metadata["quic_stream_id_server"] == 3
assert quic_layer.stream_id(True) == 2
assert quic_layer.stream_id(False) == 3
class TestRawQuicLayer:
@pytest.mark.parametrize("ignore", [True, False])
def test_error(self, tctx: context.Context, ignore: bool):
quic_layer = RawQuicLayer(tctx, ignore=ignore)
assert (
tutils.Playbook(quic_layer)
<< commands.OpenConnection(tctx.server)
>> tutils.reply("failed to open")
<< commands.CloseConnection(tctx.client)
)
assert quic_layer._handle_event == quic_layer.done
def test_ignored(self, tctx: context.Context):
quic_layer = RawQuicLayer(tctx, ignore=True)
assert (
tutils.Playbook(quic_layer)
<< commands.OpenConnection(tctx.server)
>> tutils.reply(None)
>> events.DataReceived(tctx.client, b"msg1")
<< commands.SendData(tctx.server, b"msg1")
>> events.DataReceived(tctx.server, b"msg2")
<< commands.SendData(tctx.client, b"msg2")
>> QuicStreamDataReceived(tctx.client, 0, b"msg3", end_stream=False)
<< SendQuicStreamData(tctx.server, 0, b"msg3", end_stream=False)
>> QuicStreamDataReceived(tctx.client, 6, b"msg4", end_stream=False)
<< SendQuicStreamData(tctx.server, 2, b"msg4", end_stream=False)
>> QuicStreamDataReceived(tctx.server, 9, b"msg5", end_stream=False)
<< SendQuicStreamData(tctx.client, 1, b"msg5", end_stream=False)
>> QuicStreamDataReceived(tctx.client, 0, b"", end_stream=True)
<< SendQuicStreamData(tctx.server, 0, b"", end_stream=True)
>> QuicStreamReset(tctx.client, 6, 142)
<< ResetQuicStream(tctx.server, 2, 142)
>> QuicConnectionClosed(tctx.client, 42, None, "closed")
<< CloseQuicConnection(tctx.server, 42, None, "closed")
>> QuicConnectionClosed(tctx.server, 42, None, "closed")
<< None
)
assert quic_layer._handle_event == quic_layer.done
def test_msg_inject(self, tctx: context.Context):
udpflow = tutils.Placeholder(UDPFlow)
playbook = tutils.Playbook(RawQuicLayer(tctx))
assert (
playbook
<< commands.OpenConnection(tctx.server)
>> tutils.reply(None)
>> events.DataReceived(tctx.client, b"msg1")
<< layer.NextLayerHook(tutils.Placeholder())
>> tutils.reply_next_layer(udp.UDPLayer)
<< udp.UdpStartHook(udpflow)
>> tutils.reply()
<< udp.UdpMessageHook(udpflow)
>> tutils.reply()
<< commands.SendData(tctx.server, b"msg1")
>> udp.UdpMessageInjected(udpflow, UDPMessage(True, b"msg2"))
<< udp.UdpMessageHook(udpflow)
>> tutils.reply()
<< commands.SendData(tctx.server, b"msg2")
>> udp.UdpMessageInjected(
UDPFlow(("other", 80), tctx.server), UDPMessage(True, b"msg3")
)
<< udp.UdpMessageHook(udpflow)
>> tutils.reply()
<< commands.SendData(tctx.server, b"msg3")
)
with pytest.raises(AssertionError, match="not associated"):
playbook >> udp.UdpMessageInjected(
UDPFlow(("notfound", 0), ("noexist", 0)), UDPMessage(True, b"msg2")
)
assert playbook
def test_reset_with_end_hook(self, tctx: context.Context):
tcpflow = tutils.Placeholder(TCPFlow)
assert (
tutils.Playbook(RawQuicLayer(tctx))
<< commands.OpenConnection(tctx.server)
>> tutils.reply(None)
>> QuicStreamDataReceived(tctx.client, 2, b"msg1", end_stream=False)
<< layer.NextLayerHook(tutils.Placeholder())
>> tutils.reply_next_layer(tcp.TCPLayer)
<< tcp.TcpStartHook(tcpflow)
>> tutils.reply()
<< tcp.TcpMessageHook(tcpflow)
>> tutils.reply()
<< SendQuicStreamData(tctx.server, 2, b"msg1", end_stream=False)
>> QuicStreamReset(tctx.client, 2, 42)
<< ResetQuicStream(tctx.server, 2, 42)
<< tcp.TcpEndHook(tcpflow)
>> tutils.reply()
)
def test_close_with_end_hooks(self, tctx: context.Context):
udpflow = tutils.Placeholder(UDPFlow)
tcpflow = tutils.Placeholder(TCPFlow)
assert (
tutils.Playbook(RawQuicLayer(tctx))
<< commands.OpenConnection(tctx.server)
>> tutils.reply(None)
>> events.DataReceived(tctx.client, b"msg1")
<< layer.NextLayerHook(tutils.Placeholder())
>> tutils.reply_next_layer(udp.UDPLayer)
<< udp.UdpStartHook(udpflow)
>> tutils.reply()
<< udp.UdpMessageHook(udpflow)
>> tutils.reply()
<< commands.SendData(tctx.server, b"msg1")
>> QuicStreamDataReceived(tctx.client, 2, b"msg2", end_stream=False)
<< layer.NextLayerHook(tutils.Placeholder())
>> tutils.reply_next_layer(tcp.TCPLayer)
<< tcp.TcpStartHook(tcpflow)
>> tutils.reply()
<< tcp.TcpMessageHook(tcpflow)
>> tutils.reply()
<< SendQuicStreamData(tctx.server, 2, b"msg2", end_stream=False)
>> QuicConnectionClosed(tctx.client, 42, None, "bye")
<< CloseQuicConnection(tctx.server, 42, None, "bye")
<< udp.UdpEndHook(udpflow)
<< tcp.TcpEndHook(tcpflow)
>> tutils.reply(to=-2)
>> tutils.reply(to=-2)
>> QuicConnectionClosed(tctx.server, 42, None, "bye")
)
def test_invalid_stream_event(self, tctx: context.Context):
playbook = tutils.Playbook(RawQuicLayer(tctx))
assert (
tutils.Playbook(RawQuicLayer(tctx))
<< commands.OpenConnection(tctx.server)
>> tutils.reply(None)
)
with pytest.raises(AssertionError, match="Unexpected stream event"):
class InvalidStreamEvent(QuicStreamEvent):
pass
playbook >> InvalidStreamEvent(tctx.client, 0)
assert playbook
def test_invalid_event(self, tctx: context.Context):
playbook = tutils.Playbook(RawQuicLayer(tctx))
assert (
tutils.Playbook(RawQuicLayer(tctx))
<< commands.OpenConnection(tctx.server)
>> tutils.reply(None)
)
with pytest.raises(AssertionError, match="Unexpected event"):
class InvalidEvent(events.Event):
pass
playbook >> InvalidEvent()
assert playbook
def test_full_close(self, tctx: context.Context):
assert (
tutils.Playbook(RawQuicLayer(tctx))
<< commands.OpenConnection(tctx.server)
>> tutils.reply(None)
>> QuicStreamDataReceived(tctx.client, 0, b"msg1", end_stream=True)
<< layer.NextLayerHook(tutils.Placeholder())
>> tutils.reply_next_layer(lambda ctx: udp.UDPLayer(ctx, ignore=True))
<< SendQuicStreamData(tctx.server, 0, b"msg1", end_stream=False)
<< SendQuicStreamData(tctx.server, 0, b"", end_stream=True)
<< StopSendingQuicStream(tctx.server, 0, 0)
)
def test_open_connection(self, tctx: context.Context):
server = connection.Server(address=("other", 80))
def echo_new_server(ctx: context.Context):
echo_layer = TlsEchoLayer(ctx)
echo_layer.context.server = server
return echo_layer
assert (
tutils.Playbook(RawQuicLayer(tctx))
<< commands.OpenConnection(tctx.server)
>> tutils.reply(None)
>> QuicStreamDataReceived(
tctx.client, 0, b"open-connection", end_stream=False
)
<< layer.NextLayerHook(tutils.Placeholder())
>> tutils.reply_next_layer(echo_new_server)
<< commands.OpenConnection(server)
>> tutils.reply("uhoh")
<< SendQuicStreamData(
tctx.client, 0, b"open-connection failed: uhoh", end_stream=False
)
)
def test_invalid_connection_command(self, tctx: context.Context):
playbook = tutils.Playbook(RawQuicLayer(tctx))
assert (
playbook
<< commands.OpenConnection(tctx.server)
>> tutils.reply(None)
>> QuicStreamDataReceived(tctx.client, 0, b"msg1", end_stream=False)
<< layer.NextLayerHook(tutils.Placeholder())
>> tutils.reply_next_layer(TlsEchoLayer)
<< SendQuicStreamData(tctx.client, 0, b"msg1", end_stream=False)
)
with pytest.raises(
AssertionError, match="Unexpected stream connection command"
):
playbook >> QuicStreamDataReceived(
tctx.client, 0, b"invalid-command", end_stream=False
)
assert playbook

View File

@ -19,19 +19,31 @@ from mitmproxy.proxy import commands
from mitmproxy.proxy import context
from mitmproxy.proxy import events
from mitmproxy.proxy import layer
from mitmproxy.proxy import layers
from mitmproxy.proxy import tunnel
from mitmproxy.proxy.layers import quic
from mitmproxy.proxy.layers import tcp
from mitmproxy.proxy.layers import tls
from mitmproxy.proxy.layers import udp
from mitmproxy.tcp import TCPFlow
from mitmproxy.udp import UDPFlow
from mitmproxy.udp import UDPMessage
from mitmproxy.proxy.layers.quic._commands import CloseQuicConnection
from mitmproxy.proxy.layers.quic._commands import QuicStreamCommand
from mitmproxy.proxy.layers.quic._commands import ResetQuicStream
from mitmproxy.proxy.layers.quic._commands import SendQuicStreamData
from mitmproxy.proxy.layers.quic._commands import StopSendingQuicStream
from mitmproxy.proxy.layers.quic._events import QuicConnectionClosed
from mitmproxy.proxy.layers.quic._events import QuicStreamDataReceived
from mitmproxy.proxy.layers.quic._events import QuicStreamReset
from mitmproxy.proxy.layers.quic._events import QuicStreamStopSending
from mitmproxy.proxy.layers.quic._hooks import QuicStartClientHook
from mitmproxy.proxy.layers.quic._hooks import QuicStartServerHook
from mitmproxy.proxy.layers.quic._hooks import QuicTlsData
from mitmproxy.proxy.layers.quic._hooks import QuicTlsSettings
from mitmproxy.proxy.layers.quic._stream_layers import ClientQuicLayer
from mitmproxy.proxy.layers.quic._stream_layers import error_code_to_str
from mitmproxy.proxy.layers.quic._stream_layers import is_success_error_code
from mitmproxy.proxy.layers.quic._stream_layers import QuicLayer
from mitmproxy.proxy.layers.quic._stream_layers import QuicSecretsLogger
from mitmproxy.proxy.layers.quic._stream_layers import ServerQuicLayer
from mitmproxy.proxy.layers.quic._stream_layers import tls_settings_to_configuration
from mitmproxy.utils import data
from test.mitmproxy.proxy import tutils
tlsdata = data.Data(__name__)
tdata = data.Data("test")
T = TypeVar("T", bound=layer.Layer)
@ -47,7 +59,7 @@ class DummyLayer(layer.Layer):
class TlsEchoLayer(tutils.EchoLayer):
err: str | None = None
closed: quic.QuicConnectionClosed | None = None
closed: QuicConnectionClosed | None = None
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, events.DataReceived) and event.data == b"open-connection":
@ -64,7 +76,7 @@ class TlsEchoLayer(tutils.EchoLayer):
isinstance(event, events.DataReceived)
and event.data == b"close-connection-error"
):
yield quic.CloseQuicConnection(event.connection, 123, None, "error")
yield CloseQuicConnection(event.connection, 123, None, "error")
elif (
isinstance(event, events.DataReceived) and event.data == b"invalid-command"
):
@ -78,22 +90,20 @@ class TlsEchoLayer(tutils.EchoLayer):
and event.data == b"invalid-stream-command"
):
class InvalidStreamCommand(quic.QuicStreamCommand):
class InvalidStreamCommand(QuicStreamCommand):
pass
yield InvalidStreamCommand(event.connection, 42)
elif isinstance(event, quic.QuicConnectionClosed):
elif isinstance(event, QuicConnectionClosed):
self.closed = event
elif isinstance(event, quic.QuicStreamDataReceived):
yield quic.SendQuicStreamData(
elif isinstance(event, QuicStreamDataReceived):
yield SendQuicStreamData(
event.connection, event.stream_id, event.data, event.end_stream
)
elif isinstance(event, quic.QuicStreamReset):
yield quic.ResetQuicStream(
event.connection, event.stream_id, event.error_code
)
elif isinstance(event, quic.QuicStreamStopSending):
yield quic.StopSendingQuicStream(
elif isinstance(event, QuicStreamReset):
yield ResetQuicStream(event.connection, event.stream_id, event.error_code)
elif isinstance(event, QuicStreamStopSending):
yield StopSendingQuicStream(
event.connection, event.stream_id, event.error_code
)
else:
@ -198,312 +208,28 @@ fragmented_client_hello2 = bytes.fromhex(
def test_error_code_to_str():
assert quic.error_code_to_str(0x6) == "FINAL_SIZE_ERROR"
assert quic.error_code_to_str(0x104) == "H3_CLOSED_CRITICAL_STREAM"
assert quic.error_code_to_str(0xDEAD) == f"unknown error (0xdead)"
assert error_code_to_str(0x6) == "FINAL_SIZE_ERROR"
assert error_code_to_str(0x104) == "H3_CLOSED_CRITICAL_STREAM"
assert error_code_to_str(0xDEAD) == f"unknown error (0xdead)"
def test_is_success_error_code():
assert quic.is_success_error_code(0x0)
assert not quic.is_success_error_code(0x6)
assert quic.is_success_error_code(0x100)
assert not quic.is_success_error_code(0x104)
assert not quic.is_success_error_code(0xDEAD)
assert is_success_error_code(0x0)
assert not is_success_error_code(0x6)
assert is_success_error_code(0x100)
assert not is_success_error_code(0x104)
assert not is_success_error_code(0xDEAD)
@pytest.mark.parametrize("value", ["s1 s2\n", "s1 s2"])
def test_secrets_logger(value: str):
logger = MagicMock()
quic_logger = quic.QuicSecretsLogger(logger)
quic_logger = QuicSecretsLogger(logger)
assert quic_logger.write(value) == 6
quic_logger.flush()
logger.assert_called_once_with(None, b"s1 s2")
class TestParseClientHello:
def test_input(self):
assert (
quic.quic_parse_client_hello_from_datagrams([client_hello]).sni
== "example.com"
)
with pytest.raises(ValueError):
quic.quic_parse_client_hello_from_datagrams(
[client_hello[:183] + b"\x00\x00\x00\x00\x00\x00\x00\x00\x00"]
)
with pytest.raises(ValueError, match="not initial"):
quic.quic_parse_client_hello_from_datagrams(
[
b"\\s\xd8\xd8\xa5dT\x8bc\xd3\xae\x1c\xb2\x8a7-\x1d\x19j\x85\xb0~\x8c\x80\xa5\x8cY\xac\x0ecK\x7fC2f\xbcm\x1b\xac~"
]
)
def test_invalid(self, monkeypatch):
class InvalidClientHello(Exception):
@property
def data(self):
raise EOFError()
monkeypatch.setattr(quic, "QuicClientHello", InvalidClientHello)
with pytest.raises(ValueError, match="Invalid ClientHello"):
quic.quic_parse_client_hello_from_datagrams([client_hello])
def test_connection_error(self, monkeypatch):
def raise_conn_err(self, data, addr, now):
raise quic.QuicConnectionError(0, 0, "Conn err")
monkeypatch.setattr(QuicConnection, "receive_datagram", raise_conn_err)
with pytest.raises(ValueError, match="Conn err"):
quic.quic_parse_client_hello_from_datagrams([client_hello])
def test_no_return(self):
with pytest.raises(
ValueError, match="Invalid ClientHello packet: payload_decrypt_error"
):
quic.quic_parse_client_hello_from_datagrams(
[client_hello[0:1200] + b"\x00" + client_hello[1200:]]
)
class TestQuicStreamLayer:
def test_ignored(self, tctx: context.Context):
quic_layer = quic.QuicStreamLayer(tctx, True, 1)
assert isinstance(quic_layer.child_layer, layers.TCPLayer)
assert not quic_layer.child_layer.flow
quic_layer.child_layer.flow = TCPFlow(tctx.client, tctx.server)
quic_layer.refresh_metadata()
assert quic_layer.child_layer.flow.metadata["quic_is_unidirectional"] is False
assert quic_layer.child_layer.flow.metadata["quic_initiator"] == "server"
assert quic_layer.child_layer.flow.metadata["quic_stream_id_client"] == 1
assert quic_layer.child_layer.flow.metadata["quic_stream_id_server"] is None
assert quic_layer.stream_id(True) == 1
assert quic_layer.stream_id(False) is None
def test_simple(self, tctx: context.Context):
quic_layer = quic.QuicStreamLayer(tctx, False, 2)
assert isinstance(quic_layer.child_layer, layer.NextLayer)
tunnel_layer = tunnel.TunnelLayer(tctx, tctx.client, tctx.server)
quic_layer.child_layer.layer = tunnel_layer
tcp_layer = layers.TCPLayer(tctx)
tunnel_layer.child_layer = tcp_layer
quic_layer.open_server_stream(3)
assert tcp_layer.flow.metadata["quic_is_unidirectional"] is True
assert tcp_layer.flow.metadata["quic_initiator"] == "client"
assert tcp_layer.flow.metadata["quic_stream_id_client"] == 2
assert tcp_layer.flow.metadata["quic_stream_id_server"] == 3
assert quic_layer.stream_id(True) == 2
assert quic_layer.stream_id(False) == 3
class TestRawQuicLayer:
@pytest.mark.parametrize("ignore", [True, False])
def test_error(self, tctx: context.Context, ignore: bool):
quic_layer = quic.RawQuicLayer(tctx, ignore=ignore)
assert (
tutils.Playbook(quic_layer)
<< commands.OpenConnection(tctx.server)
>> tutils.reply("failed to open")
<< commands.CloseConnection(tctx.client)
)
assert quic_layer._handle_event == quic_layer.done
def test_ignored(self, tctx: context.Context):
quic_layer = quic.RawQuicLayer(tctx, ignore=True)
assert (
tutils.Playbook(quic_layer)
<< commands.OpenConnection(tctx.server)
>> tutils.reply(None)
>> events.DataReceived(tctx.client, b"msg1")
<< commands.SendData(tctx.server, b"msg1")
>> events.DataReceived(tctx.server, b"msg2")
<< commands.SendData(tctx.client, b"msg2")
>> quic.QuicStreamDataReceived(tctx.client, 0, b"msg3", end_stream=False)
<< quic.SendQuicStreamData(tctx.server, 0, b"msg3", end_stream=False)
>> quic.QuicStreamDataReceived(tctx.client, 6, b"msg4", end_stream=False)
<< quic.SendQuicStreamData(tctx.server, 2, b"msg4", end_stream=False)
>> quic.QuicStreamDataReceived(tctx.server, 9, b"msg5", end_stream=False)
<< quic.SendQuicStreamData(tctx.client, 1, b"msg5", end_stream=False)
>> quic.QuicStreamDataReceived(tctx.client, 0, b"", end_stream=True)
<< quic.SendQuicStreamData(tctx.server, 0, b"", end_stream=True)
>> quic.QuicStreamReset(tctx.client, 6, 142)
<< quic.ResetQuicStream(tctx.server, 2, 142)
>> quic.QuicConnectionClosed(tctx.client, 42, None, "closed")
<< quic.CloseQuicConnection(tctx.server, 42, None, "closed")
>> quic.QuicConnectionClosed(tctx.server, 42, None, "closed")
<< None
)
assert quic_layer._handle_event == quic_layer.done
def test_msg_inject(self, tctx: context.Context):
udpflow = tutils.Placeholder(UDPFlow)
playbook = tutils.Playbook(quic.RawQuicLayer(tctx))
assert (
playbook
<< commands.OpenConnection(tctx.server)
>> tutils.reply(None)
>> events.DataReceived(tctx.client, b"msg1")
<< layer.NextLayerHook(tutils.Placeholder())
>> tutils.reply_next_layer(udp.UDPLayer)
<< udp.UdpStartHook(udpflow)
>> tutils.reply()
<< udp.UdpMessageHook(udpflow)
>> tutils.reply()
<< commands.SendData(tctx.server, b"msg1")
>> udp.UdpMessageInjected(udpflow, UDPMessage(True, b"msg2"))
<< udp.UdpMessageHook(udpflow)
>> tutils.reply()
<< commands.SendData(tctx.server, b"msg2")
>> udp.UdpMessageInjected(
UDPFlow(("other", 80), tctx.server), UDPMessage(True, b"msg3")
)
<< udp.UdpMessageHook(udpflow)
>> tutils.reply()
<< commands.SendData(tctx.server, b"msg3")
)
with pytest.raises(AssertionError, match="not associated"):
playbook >> udp.UdpMessageInjected(
UDPFlow(("notfound", 0), ("noexist", 0)), UDPMessage(True, b"msg2")
)
assert playbook
def test_reset_with_end_hook(self, tctx: context.Context):
tcpflow = tutils.Placeholder(TCPFlow)
assert (
tutils.Playbook(quic.RawQuicLayer(tctx))
<< commands.OpenConnection(tctx.server)
>> tutils.reply(None)
>> quic.QuicStreamDataReceived(tctx.client, 2, b"msg1", end_stream=False)
<< layer.NextLayerHook(tutils.Placeholder())
>> tutils.reply_next_layer(tcp.TCPLayer)
<< tcp.TcpStartHook(tcpflow)
>> tutils.reply()
<< tcp.TcpMessageHook(tcpflow)
>> tutils.reply()
<< quic.SendQuicStreamData(tctx.server, 2, b"msg1", end_stream=False)
>> quic.QuicStreamReset(tctx.client, 2, 42)
<< quic.ResetQuicStream(tctx.server, 2, 42)
<< tcp.TcpEndHook(tcpflow)
>> tutils.reply()
)
def test_close_with_end_hooks(self, tctx: context.Context):
udpflow = tutils.Placeholder(UDPFlow)
tcpflow = tutils.Placeholder(TCPFlow)
assert (
tutils.Playbook(quic.RawQuicLayer(tctx))
<< commands.OpenConnection(tctx.server)
>> tutils.reply(None)
>> events.DataReceived(tctx.client, b"msg1")
<< layer.NextLayerHook(tutils.Placeholder())
>> tutils.reply_next_layer(udp.UDPLayer)
<< udp.UdpStartHook(udpflow)
>> tutils.reply()
<< udp.UdpMessageHook(udpflow)
>> tutils.reply()
<< commands.SendData(tctx.server, b"msg1")
>> quic.QuicStreamDataReceived(tctx.client, 2, b"msg2", end_stream=False)
<< layer.NextLayerHook(tutils.Placeholder())
>> tutils.reply_next_layer(tcp.TCPLayer)
<< tcp.TcpStartHook(tcpflow)
>> tutils.reply()
<< tcp.TcpMessageHook(tcpflow)
>> tutils.reply()
<< quic.SendQuicStreamData(tctx.server, 2, b"msg2", end_stream=False)
>> quic.QuicConnectionClosed(tctx.client, 42, None, "bye")
<< quic.CloseQuicConnection(tctx.server, 42, None, "bye")
<< udp.UdpEndHook(udpflow)
<< tcp.TcpEndHook(tcpflow)
>> tutils.reply(to=-2)
>> tutils.reply(to=-2)
>> quic.QuicConnectionClosed(tctx.server, 42, None, "bye")
)
def test_invalid_stream_event(self, tctx: context.Context):
playbook = tutils.Playbook(quic.RawQuicLayer(tctx))
assert (
tutils.Playbook(quic.RawQuicLayer(tctx))
<< commands.OpenConnection(tctx.server)
>> tutils.reply(None)
)
with pytest.raises(AssertionError, match="Unexpected stream event"):
class InvalidStreamEvent(quic.QuicStreamEvent):
pass
playbook >> InvalidStreamEvent(tctx.client, 0)
assert playbook
def test_invalid_event(self, tctx: context.Context):
playbook = tutils.Playbook(quic.RawQuicLayer(tctx))
assert (
tutils.Playbook(quic.RawQuicLayer(tctx))
<< commands.OpenConnection(tctx.server)
>> tutils.reply(None)
)
with pytest.raises(AssertionError, match="Unexpected event"):
class InvalidEvent(events.Event):
pass
playbook >> InvalidEvent()
assert playbook
def test_full_close(self, tctx: context.Context):
assert (
tutils.Playbook(quic.RawQuicLayer(tctx))
<< commands.OpenConnection(tctx.server)
>> tutils.reply(None)
>> quic.QuicStreamDataReceived(tctx.client, 0, b"msg1", end_stream=True)
<< layer.NextLayerHook(tutils.Placeholder())
>> tutils.reply_next_layer(lambda ctx: udp.UDPLayer(ctx, ignore=True))
<< quic.SendQuicStreamData(tctx.server, 0, b"msg1", end_stream=False)
<< quic.SendQuicStreamData(tctx.server, 0, b"", end_stream=True)
<< quic.StopSendingQuicStream(tctx.server, 0, 0)
)
def test_open_connection(self, tctx: context.Context):
server = connection.Server(address=("other", 80))
def echo_new_server(ctx: context.Context):
echo_layer = TlsEchoLayer(ctx)
echo_layer.context.server = server
return echo_layer
assert (
tutils.Playbook(quic.RawQuicLayer(tctx))
<< commands.OpenConnection(tctx.server)
>> tutils.reply(None)
>> quic.QuicStreamDataReceived(
tctx.client, 0, b"open-connection", end_stream=False
)
<< layer.NextLayerHook(tutils.Placeholder())
>> tutils.reply_next_layer(echo_new_server)
<< commands.OpenConnection(server)
>> tutils.reply("uhoh")
<< quic.SendQuicStreamData(
tctx.client, 0, b"open-connection failed: uhoh", end_stream=False
)
)
def test_invalid_connection_command(self, tctx: context.Context):
playbook = tutils.Playbook(quic.RawQuicLayer(tctx))
assert (
playbook
<< commands.OpenConnection(tctx.server)
>> tutils.reply(None)
>> quic.QuicStreamDataReceived(tctx.client, 0, b"msg1", end_stream=False)
<< layer.NextLayerHook(tutils.Placeholder())
>> tutils.reply_next_layer(TlsEchoLayer)
<< quic.SendQuicStreamData(tctx.client, 0, b"msg1", end_stream=False)
)
with pytest.raises(
AssertionError, match="Unexpected stream connection command"
):
playbook >> quic.QuicStreamDataReceived(
tctx.client, 0, b"invalid-command", end_stream=False
)
assert playbook
class MockQuic(QuicConnection):
def __init__(self, event) -> None:
super().__init__(configuration=QuicConfiguration(is_client=True))
@ -527,7 +253,7 @@ def make_mock_quic(
established: bool = True,
) -> tuple[tutils.Playbook, MockQuic]:
tctx.client.state = connection.ConnectionState.CLOSED
quic_layer = quic.QuicLayer(tctx, tctx.client, time=lambda: 0)
quic_layer = QuicLayer(tctx, tctx.client, time=lambda: 0)
quic_layer.child_layer = TlsEchoLayer(tctx)
mock = MockQuic(event)
quic_layer.quic = mock
@ -579,7 +305,7 @@ class TestQuicLayer:
assert (
playbook
>> events.DataReceived(tctx.client, b"")
<< quic.CloseQuicConnection(tctx.client, 123, None, "error")
<< CloseQuicConnection(tctx.client, 123, None, "error")
)
assert conn._close_event
assert conn._close_event.error_code == 123
@ -626,7 +352,7 @@ class SSLTest:
alpn: list[str] | None = None,
sni: str | None = "example.mitmproxy.org",
version: int | None = None,
settings: quic.QuicTlsSettings | None = None,
settings: QuicTlsSettings | None = None,
):
if settings is None:
self.ctx = QuicConfiguration(
@ -636,8 +362,8 @@ class SSLTest:
self.ctx.verify_mode = ssl.CERT_OPTIONAL
self.ctx.load_verify_locations(
cafile=tlsdata.path(
"../../net/data/verificationcerts/trusted-root.crt"
cafile=tdata.path(
"mitmproxy/net/data/verificationcerts/trusted-root.crt"
),
)
@ -649,11 +375,11 @@ class SSLTest:
else:
filename = "trusted-leaf"
self.ctx.load_cert_chain(
certfile=tlsdata.path(
f"../../net/data/verificationcerts/{filename}.crt"
certfile=tdata.path(
f"mitmproxy/net/data/verificationcerts/{filename}.crt"
),
keyfile=tlsdata.path(
f"../../net/data/verificationcerts/{filename}.key"
keyfile=tdata.path(
f"mitmproxy/net/data/verificationcerts/{filename}.key"
),
)
@ -664,7 +390,7 @@ class SSLTest:
else:
assert alpn is None
assert version is None
self.ctx = quic.tls_settings_to_configuration(
self.ctx = tls_settings_to_configuration(
settings=settings,
is_client=not server_side,
server_name=sni,
@ -676,7 +402,7 @@ class SSLTest:
if not server_side:
self.quic.connect(self.address, now=self.now)
def write(self, buf: bytes) -> int:
def write(self, buf: bytes):
self.now = self.now + 0.1
if self.quic is None:
quic_buf = QuicBuffer(data=buf)
@ -766,13 +492,13 @@ def reply_tls_start_client(alpn: str | None = None, *args, **kwargs) -> tutils.r
Helper function to simplify the syntax for quic_start_client hooks.
"""
def make_client_conn(tls_start: quic.QuicTlsData) -> None:
def make_client_conn(tls_start: QuicTlsData) -> None:
config = QuicConfiguration()
config.load_cert_chain(
tlsdata.path("../../net/data/verificationcerts/trusted-leaf.crt"),
tlsdata.path("../../net/data/verificationcerts/trusted-leaf.key"),
tdata.path("mitmproxy/net/data/verificationcerts/trusted-leaf.crt"),
tdata.path("mitmproxy/net/data/verificationcerts/trusted-leaf.key"),
)
tls_start.settings = quic.QuicTlsSettings(
tls_start.settings = QuicTlsSettings(
certificate=config.certificate,
certificate_chain=config.certificate_chain,
certificate_private_key=config.private_key,
@ -788,9 +514,9 @@ def reply_tls_start_server(alpn: str | None = None, *args, **kwargs) -> tutils.r
Helper function to simplify the syntax for quic_start_server hooks.
"""
def make_server_conn(tls_start: quic.QuicTlsData) -> None:
tls_start.settings = quic.QuicTlsSettings(
ca_file=tlsdata.path("../../net/data/verificationcerts/trusted-root.crt"),
def make_server_conn(tls_start: QuicTlsData) -> None:
tls_start.settings = QuicTlsSettings(
ca_file=tdata.path("mitmproxy/net/data/verificationcerts/trusted-root.crt"),
verify_mode=ssl.CERT_REQUIRED,
)
if alpn is not None:
@ -801,11 +527,11 @@ def reply_tls_start_server(alpn: str | None = None, *args, **kwargs) -> tutils.r
class TestServerQuic:
def test_repr(self, tctx: context.Context):
assert repr(quic.ServerQuicLayer(tctx, time=lambda: 0))
assert repr(ServerQuicLayer(tctx, time=lambda: 0))
def test_not_connected(self, tctx: context.Context):
"""Test that we don't do anything if no server connection exists."""
layer = quic.ServerQuicLayer(tctx, time=lambda: 0)
layer = ServerQuicLayer(tctx, time=lambda: 0)
layer.child_layer = TlsEchoLayer(tctx)
assert (
@ -817,7 +543,7 @@ class TestServerQuic:
def test_simple(self, tctx: context.Context):
tssl = SSLTest(server_side=True)
playbook = tutils.Playbook(quic.ServerQuicLayer(tctx, time=lambda: tssl.now))
playbook = tutils.Playbook(ServerQuicLayer(tctx, time=lambda: tssl.now))
tctx.server.address = ("example.mitmproxy.org", 443)
tctx.server.state = connection.ConnectionState.OPEN
tctx.server.sni = "example.mitmproxy.org"
@ -826,7 +552,7 @@ class TestServerQuic:
data = tutils.Placeholder(bytes)
assert (
playbook
<< quic.QuicStartServerHook(tutils.Placeholder())
<< QuicStartServerHook(tutils.Placeholder())
>> reply_tls_start_server()
<< commands.SendData(tctx.server, data)
<< commands.RequestWakeup(0.2)
@ -873,7 +599,7 @@ class TestServerQuic:
"""If the certificate is not trusted, we should fail."""
tssl = SSLTest(server_side=True)
playbook = tutils.Playbook(quic.ServerQuicLayer(tctx, time=lambda: tssl.now))
playbook = tutils.Playbook(ServerQuicLayer(tctx, time=lambda: tssl.now))
tctx.server.address = ("wrong.host.mitmproxy.org", 443)
tctx.server.sni = "wrong.host.mitmproxy.org"
@ -886,7 +612,7 @@ class TestServerQuic:
>> events.DataReceived(tctx.client, b"open-connection")
<< commands.OpenConnection(tctx.server)
>> tutils.reply(None)
<< quic.QuicStartServerHook(tutils.Placeholder())
<< QuicStartServerHook(tutils.Placeholder())
>> reply_tls_start_server()
<< commands.SendData(tctx.server, data)
<< commands.RequestWakeup(0.2)
@ -906,7 +632,7 @@ class TestServerQuic:
tssl.write(data())
tssl.now = tssl.now + 60
tls_hook_data = tutils.Placeholder(quic.QuicTlsData)
tls_hook_data = tutils.Placeholder(QuicTlsData)
assert (
playbook
>> tutils.reply(to=commands.RequestWakeup)
@ -934,7 +660,7 @@ class TestServerQuic:
def make_client_tls_layer(
tctx: context.Context, no_server: bool = False, **kwargs
) -> tuple[tutils.Playbook, quic.ClientQuicLayer, SSLTest]:
) -> tuple[tutils.Playbook, ClientQuicLayer, SSLTest]:
tssl_client = SSLTest(**kwargs)
# This is a bit contrived as the client layer expects a server layer as parent.
@ -942,9 +668,9 @@ def make_client_tls_layer(
server_layer = (
DummyLayer(tctx)
if no_server
else quic.ServerQuicLayer(tctx, time=lambda: tssl_client.now)
else ServerQuicLayer(tctx, time=lambda: tssl_client.now)
)
client_layer = quic.ClientQuicLayer(tctx, time=lambda: tssl_client.now)
client_layer = ClientQuicLayer(tctx, time=lambda: tssl_client.now)
server_layer.child_layer = client_layer
playbook = tutils.Playbook(server_layer)
@ -966,7 +692,7 @@ class TestClientQuic:
"""Test that we swallow QUIC packets if QUIC and HTTP/3 are disabled."""
tctx.options.http3 = False
assert (
tutils.Playbook(quic.ClientQuicLayer(tctx, time=time.time), logs=True)
tutils.Playbook(ClientQuicLayer(tctx, time=time.time), logs=True)
>> events.DataReceived(tctx.client, client_hello)
<< commands.Log(
"Swallowing QUIC handshake because HTTP/3 is disabled.", DEBUG
@ -987,7 +713,7 @@ class TestClientQuic:
>> events.DataReceived(tctx.client, tssl_client.read())
<< tls.TlsClienthelloHook(tutils.Placeholder())
>> tutils.reply()
<< quic.QuicStartClientHook(tutils.Placeholder())
<< QuicStartClientHook(tutils.Placeholder())
>> reply_tls_start_client()
<< commands.SendData(tctx.client, data)
<< commands.RequestWakeup(tutils.Placeholder())
@ -1057,7 +783,7 @@ class TestClientQuic:
playbook >> tutils.reply(None)
assert (
playbook
<< quic.QuicStartServerHook(tutils.Placeholder())
<< QuicStartServerHook(tutils.Placeholder())
>> reply_tls_start_server(alpn="quux")
<< commands.SendData(tctx.server, data)
<< commands.RequestWakeup(tutils.Placeholder())
@ -1075,7 +801,7 @@ class TestClientQuic:
>> tutils.reply()
<< commands.SendData(tctx.server, data)
<< commands.RequestWakeup(tutils.Placeholder())
<< quic.QuicStartClientHook(tutils.Placeholder())
<< QuicStartClientHook(tutils.Placeholder())
)
tssl_server.write(data())
assert tctx.server.tls_established
@ -1147,7 +873,7 @@ class TestClientQuic:
def test_fragmented_client_hello(
self, tctx: context.Context, fragments: list[bytes]
):
client_layer = quic.ClientQuicLayer(tctx, time=time.time)
client_layer = ClientQuicLayer(tctx, time=time.time)
playbook = tutils.Playbook(client_layer)
assert not tctx.client.sni
@ -1159,7 +885,7 @@ class TestClientQuic:
>> events.DataReceived(tctx.client, fragments[1])
<< tls.TlsClienthelloHook(tutils.Placeholder())
>> tutils.reply()
<< quic.QuicStartClientHook(tutils.Placeholder())
<< QuicStartClientHook(tutils.Placeholder())
)
assert tctx.client.sni == "localhost"
@ -1176,7 +902,7 @@ class TestClientQuic:
):
"""Test the scenario where we cannot parse the ClientHello"""
playbook, client_layer, tssl_client = make_client_tls_layer(tctx)
tls_hook_data = tutils.Placeholder(quic.QuicTlsData)
tls_hook_data = tutils.Placeholder(QuicTlsData)
assert (
playbook
@ -1219,7 +945,7 @@ class TestClientQuic:
>> events.DataReceived(tctx.client, tssl_client.read())
<< tls.TlsClienthelloHook(tutils.Placeholder())
>> tutils.reply()
<< quic.QuicStartClientHook(tutils.Placeholder())
<< QuicStartClientHook(tutils.Placeholder())
>> reply_tls_start_client()
<< commands.SendData(tctx.client, data)
<< commands.RequestWakeup(tutils.Placeholder())
@ -1228,7 +954,7 @@ class TestClientQuic:
assert not tssl_client.handshake_completed()
# Finish Handshake
tls_hook_data = tutils.Placeholder(quic.QuicTlsData)
tls_hook_data = tutils.Placeholder(QuicTlsData)
playbook >> events.DataReceived(tctx.client, tssl_client.read())
assert playbook
tssl_client.now = tssl_client.now + 60
@ -1268,7 +994,7 @@ class TestClientQuic:
f"If you plan to redirect requests away from this server, "
f"consider setting `connection_strategy` to `lazy` to suppress early connections."
)
<< quic.QuicStartClientHook(tutils.Placeholder())
<< QuicStartClientHook(tutils.Placeholder())
)
tctx.client.state = connection.ConnectionState.CLOSED
assert (
@ -1301,7 +1027,7 @@ class TestClientQuic:
f"If you plan to redirect requests away from this server, "
f"consider setting `connection_strategy` to `lazy` to suppress early connections."
)
<< quic.QuicStartClientHook(tutils.Placeholder())
<< QuicStartClientHook(tutils.Placeholder())
)
def test_version_negotiation(self, tctx: context.Context):
@ -1354,7 +1080,7 @@ class TestClientQuic:
assert client_layer.tunnel_state == tls.tunnel.TunnelState.ESTABLISHING
def test_invalid_fragmented_clienthello(self, tctx: context.Context):
client_layer = quic.ClientQuicLayer(tctx, time=time.time)
client_layer = ClientQuicLayer(tctx, time=time.time)
playbook = tutils.Playbook(client_layer)
assert not tctx.client.sni
@ -1381,5 +1107,5 @@ class TestClientQuic:
tctx.client.tls = True
tctx.client.sni = "some"
DummyLayer(tctx)
quic.ClientQuicLayer(tctx, time=lambda: 0)
ClientQuicLayer(tctx, time=lambda: 0)
assert tctx.client.sni is None