Merge pull request #2 from mhils/pr-5435

Pr 5435
This commit is contained in:
Manuel Meitinger 2022-10-24 02:02:29 +02:00 committed by GitHub
commit 98ce1c2b2b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 129 additions and 83 deletions

View File

@ -94,6 +94,8 @@ class CloseConnection(ConnectionCommand):
all other connections will ultimately be closed during cleanup.
"""
class CloseTcpConnection(CloseConnection):
half_close: bool
"""
If True, only close our half of the connection by sending a FIN packet.

View File

@ -45,6 +45,7 @@ from ._hooks import ( # noqa
from ._http1 import Http1Client, Http1Connection, Http1Server
from ._http2 import Http2Client, Http2Server
from ._http3 import Http3Client, Http3Server
from ..quic import QuicStreamEvent
from ...context import Context
from ...mode_specs import ReverseMode, UpstreamMode
@ -790,7 +791,8 @@ class HttpStream(layer.Layer):
# The easiest approach for this is to just always full close for now.
# Alternatively, we could signal that we want a half close only through ResponseProtocolError,
# but that is more complex to implement.
command.half_close = False
if isinstance(command, commands.CloseTcpConnection):
command = commands.CloseConnection(command.connection)
yield command
else:
yield command
@ -886,7 +888,7 @@ class HttpLayer(layer.Layer):
if isinstance(event, events.ConnectionClosed):
# The peer has closed it - let's close it too!
yield commands.CloseConnection(event.connection)
else:
elif isinstance(event, (events.DataReceived, QuicStreamEvent)):
# The peer has sent data or another connection activity occurred.
# This can happen with HTTP/2 servers that already send a settings frame.
child_layer: HttpConnection
@ -899,6 +901,8 @@ class HttpLayer(layer.Layer):
self.connections[self.context.server] = child_layer
yield from self.event_to_child(child_layer, events.Start())
yield from self.event_to_child(child_layer, event)
else:
raise AssertionError(f"Unexpected event: {event}")
else:
handler = self.connections[event.connection]
yield from self.event_to_child(handler, event)

View File

@ -369,7 +369,7 @@ class Http1Client(Http1Connection):
if "chunked" in self.request.headers.get("transfer-encoding", "").lower():
yield commands.SendData(self.conn, b"0\r\n\r\n")
elif http1.expected_http_body_size(self.request, self.response) == -1:
yield commands.CloseConnection(self.conn, half_close=True)
yield commands.CloseTcpConnection(self.conn, half_close=True)
yield from self.mark_done(request=True)
else:
raise AssertionError(f"Unexpected event: {event}")

View File

@ -12,9 +12,9 @@ from mitmproxy import connection, http, version
from mitmproxy.net.http import status_codes
from mitmproxy.proxy import commands, context, events, layer
from mitmproxy.proxy.layers.quic import (
QuicConnectionClosed,
QuicStreamEvent,
error_code_to_str,
get_connection_error,
)
from mitmproxy.proxy.utils import expect
@ -164,12 +164,10 @@ class Http3Connection(HttpConnection):
# report a protocol error for all remaining open streams when a connection is closed
elif isinstance(event, events.ConnectionClosed):
self._handle_event = self.done # type: ignore
close_event = get_connection_error(self.conn)
msg = (
"peer closed connection"
if close_event is None else
close_event.reason_phrase or error_code_to_str(close_event.error_code)
)
if isinstance(event, QuicConnectionClosed):
msg = event.reason_phrase or error_code_to_str(event.error_code)
else:
msg = "peer closed connection"
for stream_id in self.h3_conn.get_reserved_stream_ids():
yield ReceiveHttp(self.ReceiveProtocolError(stream_id, msg))

View File

@ -17,12 +17,12 @@ from aioquic.quic.packet import QuicErrorCode
from mitmproxy import connection
from mitmproxy.proxy import commands, layer
from mitmproxy.proxy.layers.quic import (
CloseQuicConnection,
QuicStreamDataReceived,
QuicStreamEvent,
QuicStreamReset,
ResetQuicStream,
SendQuicStreamData,
set_connection_error,
)
@ -87,13 +87,10 @@ class MockQuic:
# we'll get closed if a protocol error occurs in `H3Connection.handle_event`
# we note the error on the connection and yield a CloseConnection
# this will then call `QuicConnection.close` with the proper values
# once the `Http3Connection` receives `ConnectionClosed`, it will send out `*ProtocolError`
set_connection_error(self.conn, ConnectionTerminated(
error_code=error_code,
frame_type=frame_type,
reason_phrase=reason_phrase,
))
self.pending_commands.append(commands.CloseConnection(self.conn))
# once the `Http3Connection` receives `ConnectionClosed`, it will send out `ProtocolError`
self.pending_commands.append(
CloseQuicConnection(self.conn, error_code, frame_type, reason_phrase)
)
def get_next_available_stream_id(self, is_unidirectional: bool = False) -> int:
# since we always reserve the ID, we have to "find" the next ID like `QuicConnection` does

View File

@ -60,13 +60,14 @@ class ReverseProxy(DestinationKnown):
assert isinstance(spec, ReverseMode)
self.context.server.address = spec.address
if spec.scheme in ("https", "http3", "quic", "tls", "dtls"):
if spec.scheme in ("http3", "quic"):
if not self.context.options.keep_host_header:
self.context.server.sni = spec.address[0]
if spec.scheme == "http3" or spec.scheme == "quic":
self.child_layer = quic.ServerQuicLayer(self.context)
else:
self.child_layer = tls.ServerTLSLayer(self.context)
self.child_layer = quic.ServerQuicLayer(self.context)
elif spec.scheme in ("https", "tls", "dtls"):
if not self.context.options.keep_host_header:
self.context.server.sni = spec.address[0]
self.child_layer = tls.ServerTLSLayer(self.context)
elif spec.scheme == "udp":
self.child_layer = udp.UDPLayer(self.context)
elif spec.scheme == "http" or spec.scheme == "tcp":

View File

@ -177,6 +177,56 @@ class StopQuicStream(QuicStreamCommand):
self.error_code = error_code
class CloseQuicConnection(commands.CloseConnection):
"""Close a QUIC connection."""
error_code: int
"The error code which was specified when closing the connection."
frame_type: int | None
"The frame type which caused the connection to be closed, or `None`."
reason_phrase: str
"The human-readable reason for which the connection was closed."
# XXX: A bit much boilerplate right now. Should switch to dataclasses.
def __init__(
self,
conn: connection.Connection,
error_code: int,
frame_type: int | None,
reason_phrase: str,
):
super().__init__(conn)
self.error_code = error_code
self.frame_type = frame_type
self.reason_phrase = reason_phrase
class QuicConnectionClosed(events.ConnectionClosed):
"""QUIC connection has been closed."""
error_code: int
"The error code which was specified when closing the connection."
frame_type: int | None
"The frame type which caused the connection to be closed, or `None`."
reason_phrase: str
"The human-readable reason for which the connection was closed."
def __init__(
self,
conn: connection.Connection,
error_code: int,
frame_type: int | None,
reason_phrase: str,
):
super().__init__(conn)
self.error_code = error_code
self.frame_type = frame_type
self.reason_phrase = reason_phrase
class QuicSecretsLogger:
logger: tls.MasterSecretLogger
@ -208,28 +258,12 @@ def error_code_to_str(error_code: int) -> str:
return f"unknown error (0x{error_code:x})"
def get_connection_error(conn: connection.Connection) -> quic_events.ConnectionTerminated | None:
"""Returns the QUIC close event that is associated with the given connection."""
close_event = getattr(conn, "quic_error", None)
if close_event is None:
return None
assert isinstance(close_event, quic_events.ConnectionTerminated)
return close_event
def is_success_error_code(error_code: int) -> bool:
"""Returns whether the given error code actually indicates no error."""
return error_code in (QuicErrorCode.NO_ERROR, H3ErrorCode.H3_NO_ERROR)
def set_connection_error(conn: connection.Connection, close_event: quic_events.ConnectionTerminated) -> None:
"""Stores the given close event for the given connection."""
setattr(conn, "quic_error", close_event)
@dataclass
class QuicClientHello(Exception):
"""Helper error only used in `quic_parse_client_hello`."""
@ -299,7 +333,7 @@ class QuicStreamLayer(layer.Layer):
"""Virtual client connection for this stream. Use this in QuicRawLayer instead of `context.client`."""
server: connection.Server
"""Virtual server connection for this stream. Use this in QuicRawLayer instead of `context.server`."""
child_layer: layer.Layer
child_layer: TCPLayer
"""The stream's child layer."""
def __init__(self, context: context.Context, ignore: bool, stream_id: int) -> None:
@ -335,11 +369,21 @@ class QuicStreamLayer(layer.Layer):
if ignore else
layer.NextLayer(context)
)
if ignore:
self.child_layer = TCPLayer(context, ignore=True)
else:
tcp_layer = TCPLayer(context)
# This can potentially move to a smarter place later on,
# but it's useful debugging info in mitmproxy for now.
tcp_layer.flow.metadata["quic_is_unidirectional"] = stream_is_unidirectional(stream_id)
tcp_layer.flow.metadata["quic_initiator"] = "client" if stream_is_client_initiated(stream_id) else "server"
tcp_layer.flow.metadata["quic_stream_id_client"] = stream_id
self.child_layer = tcp_layer
self.handle_event = self.child_layer.handle_event # type: ignore
self._handle_event = self.child_layer._handle_event # type: ignore
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
pass
raise AssertionError
def open_server_stream(self, server_stream_id) -> None:
assert self._server_stream_id is None
@ -354,6 +398,8 @@ class QuicStreamLayer(layer.Layer):
if stream_is_unidirectional(server_stream_id) else
connection.ConnectionState.OPEN
)
if self.child_layer.flow:
self.child_layer.flow.metadata["quic_stream_id_server"] = server_stream_id
def stream_id(self, client: bool) -> int | None:
return self._client_stream_id if client else self._server_stream_id
@ -481,17 +527,13 @@ class RawQuicLayer(layer.Layer):
# handle close events that target this context
elif (
isinstance(event, events.ConnectionClosed)
isinstance(event, QuicConnectionClosed)
and (
event.connection is self.context.client
or event.connection is self.context.server
)
):
# copy the connection error
from_client = event.connection is self.context.client
close_event = get_connection_error(event.connection)
if close_event is not None:
set_connection_error(self.context.server if from_client else self.context.client, close_event)
# always forward to the datagram layer
yield from self.event_to_child(self.datagram_layer, event)
@ -552,7 +594,9 @@ class RawQuicLayer(layer.Layer):
if command.connection.state & connection.ConnectionState.CAN_WRITE:
command.connection.state &= ~connection.ConnectionState.CAN_WRITE
yield SendQuicStreamData(quic_conn, stream_id, b"", end_stream=True)
if not command.half_close:
# XXX: Use `command.connection.state & connection.ConnectionState.CAN_READ` instead?
only_close_our_half = isinstance(command, commands.CloseTcpConnection) and command.half_close
if not only_close_our_half:
if (
stream_is_client_initiated(stream_id) == to_client
or not stream_is_unidirectional(stream_id)
@ -605,21 +649,24 @@ class QuicLayer(tunnel.TunnelLayer):
conn.tls = True
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
# turn Wakeup events into empty DataReceived events
if (
isinstance(event, events.Wakeup)
and event.command in self._wakeup_commands
):
# TunnelLayer has no understanding of wakeups, so we turn this into an empty DataReceived event
# which TunnelLayer recognizes as belonging to our connection.
assert self.quic
timer = self._wakeup_commands.pop(event.command)
if self.quic._state is not QuicConnectionState.TERMINATED:
self.quic.handle_timer(now=max(timer, self._loop.time()))
event = events.DataReceived(self.tunnel_connection, b"")
yield from super()._handle_event(event)
yield from super()._handle_event(
events.DataReceived(self.tunnel_connection, b"")
)
else:
yield from super()._handle_event(event)
def _handle_command(self, command: commands.Command) -> layer.CommandGenerator[None]:
"""Turns stream commands into aioquic connection invocations."""
if (
isinstance(command, QuicStreamCommand)
and command.connection is self.conn
@ -783,13 +830,12 @@ class QuicLayer(tunnel.TunnelLayer):
# handle post-handshake events
while event := self.quic.next_event():
if isinstance(event, quic_events.ConnectionTerminated):
set_connection_error(self.conn, event)
if self.debug:
reason = event.reason_phrase or error_code_to_str(event.error_code)
yield commands.Log(
f"{self.debug}[quic] close_notify {self.conn} (reason={reason})", DEBUG
)
yield commands.CloseConnection(self.conn)
yield CloseQuicConnection(self.conn, event.error_code, event.frame_type, event.reason_phrase)
return # we don't handle any further events, nor do/can we transmit data, so exit
elif isinstance(event, quic_events.DatagramFrameReceived):
yield from self.event_to_child(events.DataReceived(self.conn, event.data))
@ -801,6 +847,7 @@ class QuicLayer(tunnel.TunnelLayer):
quic_events.ConnectionIdIssued,
quic_events.ConnectionIdRetired,
quic_events.PingAcknowledged,
quic_events.ProtocolNegotiated,
)):
pass
else:
@ -820,16 +867,15 @@ class QuicLayer(tunnel.TunnelLayer):
self.quic.send_datagram_frame(data)
yield from self.tls_interact()
def send_close(self, half_close: bool) -> layer.CommandGenerator[None]:
def send_close(self, command: commands.CloseConnection) -> layer.CommandGenerator[None]:
# properly close the QUIC connection
if self.quic is not None:
close_event = get_connection_error(self.conn)
if close_event is None:
self.quic.close()
if isinstance(command, CloseQuicConnection):
self.quic.close(command.error_code, command.frame_type, command.reason_phrase)
else:
self.quic.close(close_event.error_code, close_event.frame_type, close_event.reason_phrase)
self.quic.close()
yield from self.tls_interact()
yield from super().send_close(half_close)
yield from super().send_close(command)
class ServerQuicLayer(QuicLayer):

View File

@ -132,7 +132,7 @@ class TCPLayer(layer.Layer):
yield TcpEndHook(self.flow)
self.flow.live = False
else:
yield commands.CloseConnection(send_to, half_close=True)
yield commands.CloseTcpConnection(send_to, half_close=True)
else:
raise AssertionError(f"Unexpected event: {event}")

View File

@ -440,9 +440,9 @@ class TLSLayer(tunnel.TunnelLayer):
pass
yield from self.tls_interact()
def send_close(self, half_close: bool) -> layer.CommandGenerator[None]:
def send_close(self, command: commands.CloseConnection) -> layer.CommandGenerator[None]:
# We should probably shutdown the TLS connection properly here.
yield from super().send_close(half_close)
yield from super().send_close(command)
class ServerTLSLayer(TLSLayer):

View File

@ -369,8 +369,10 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
assert writer
if not writer.is_closing():
writer.write(command.data)
elif isinstance(command, commands.CloseConnection):
elif isinstance(command, commands.CloseTcpConnection):
self.close_connection(command.connection, command.half_close)
elif isinstance(command, commands.CloseConnection):
self.close_connection(command.connection, False)
elif isinstance(command, commands.StartHook):
asyncio_utils.create_task(
self.hook_task(command),

View File

@ -117,11 +117,9 @@ class TunnelLayer(layer.Layer):
yield from self.send_data(command.data)
elif isinstance(command, commands.CloseConnection):
if self.conn != self.tunnel_connection:
if command.half_close:
self.conn.state &= ~connection.ConnectionState.CAN_WRITE
else:
self.conn.state = connection.ConnectionState.CLOSED
yield from self.send_close(command.half_close)
self.conn.state &= ~connection.ConnectionState.CAN_WRITE
command.connection = self.tunnel_connection
yield from self.send_close(command)
elif isinstance(command, commands.OpenConnection):
# create our own OpenConnection command object that blocks here.
self.command_to_reply_to = command
@ -172,8 +170,8 @@ class TunnelLayer(layer.Layer):
def send_data(self, data: bytes) -> layer.CommandGenerator[None]:
yield commands.SendData(self.tunnel_connection, data)
def send_close(self, half_close: bool) -> layer.CommandGenerator[None]:
yield commands.CloseConnection(self.tunnel_connection, half_close=half_close)
def send_close(self, command: commands.CloseConnection) -> layer.CommandGenerator[None]:
yield command
class LayerStack:

View File

@ -54,6 +54,7 @@ def run(
logging.getLogger("tornado").setLevel(logging.WARNING)
logging.getLogger("asyncio").setLevel(logging.WARNING)
logging.getLogger("hpack").setLevel(logging.WARNING)
logging.getLogger("quic").setLevel(logging.WARNING) # aioquic uses a different prefix...
debug.register_info_dumpers()
opts = options.Options()

View File

@ -1,6 +1,6 @@
import pytest
from mitmproxy.proxy.commands import CloseConnection, OpenConnection, SendData
from mitmproxy.proxy.commands import CloseConnection, CloseTcpConnection, OpenConnection, SendData
from mitmproxy.proxy.events import ConnectionClosed, DataReceived
from mitmproxy.proxy.layers import tcp
from mitmproxy.proxy.layers.tcp import TcpMessageInjected
@ -52,7 +52,7 @@ def test_simple(tctx):
>> reply()
<< SendData(tctx.client, b"hi")
>> ConnectionClosed(tctx.server)
<< CloseConnection(tctx.client, half_close=True)
<< CloseTcpConnection(tctx.client, half_close=True)
>> ConnectionClosed(tctx.client)
<< CloseConnection(tctx.server)
<< tcp.TcpEndHook(f)
@ -88,7 +88,7 @@ def test_receive_data_after_half_close(tctx):
>> DataReceived(tctx.client, b"eof-delimited-request")
<< SendData(tctx.server, b"eof-delimited-request")
>> ConnectionClosed(tctx.client)
<< CloseConnection(tctx.server, half_close=True)
<< CloseTcpConnection(tctx.server, half_close=True)
>> DataReceived(tctx.server, b"i'm late")
<< SendData(tctx.client, b"i'm late")
>> ConnectionClosed(tctx.server)

View File

@ -3,7 +3,7 @@ from typing import Optional
import pytest
from mitmproxy.proxy import tunnel, layer
from mitmproxy.proxy.commands import SendData, Log, CloseConnection, OpenConnection
from mitmproxy.proxy.commands import CloseTcpConnection, SendData, Log, CloseConnection, OpenConnection
from mitmproxy.connection import Server, ConnectionState
from mitmproxy.proxy.context import Context
from mitmproxy.proxy.events import Event, DataReceived, Start, ConnectionClosed
@ -24,7 +24,7 @@ class TChildLayer(layer.Layer):
err = yield OpenConnection(self.context.server)
yield Log(f"Opened: {err=}. Server state: {self.context.server.state.name}")
elif isinstance(event, DataReceived) and event.data == b"half-close":
err = yield CloseConnection(event.connection, half_close=True)
err = yield CloseTcpConnection(event.connection, half_close=True)
elif isinstance(event, ConnectionClosed):
yield Log(f"Got {event.connection.__class__.__name__.lower()} close.")
yield CloseConnection(event.connection)
@ -164,7 +164,7 @@ def test_tunnel_default_impls(tctx: Context):
>> reply(None)
<< Log("Opened: err=None. Server state: OPEN")
>> DataReceived(server, b"half-close")
<< CloseConnection(server, half_close=True)
<< CloseTcpConnection(server, half_close=True)
)

View File

@ -224,11 +224,10 @@ class Playbook:
for cmd in cmds:
pos += 1
assert self.actual[pos] == cmd
if isinstance(cmd, commands.CloseConnection):
if cmd.half_close:
cmd.connection.state &= ~ConnectionState.CAN_WRITE
else:
cmd.connection.state = ConnectionState.CLOSED
if isinstance(cmd, commands.CloseTcpConnection) and cmd.half_close:
cmd.connection.state &= ~ConnectionState.CAN_WRITE
elif isinstance(cmd, commands.CloseConnection):
cmd.connection.state = ConnectionState.CLOSED
elif isinstance(cmd, commands.Log):
need_to_emulate_log = (
not self.logs

View File

@ -43,7 +43,6 @@ export interface OptionsState {
proxy_debug: boolean
proxyauth: string | undefined
rawtcp: boolean
rawudp: boolean
readfile_filter: string | undefined
rfile: string | undefined
save_stream_file: string | undefined
@ -135,7 +134,6 @@ export const defaultState: OptionsState = {
proxy_debug: false,
proxyauth: undefined,
rawtcp: true,
rawudp: true,
readfile_filter: undefined,
rfile: undefined,
save_stream_file: undefined,