diff --git a/examples/contrib/tls_passthrough.py b/examples/contrib/tls_passthrough.py index 8652651f2..182f26289 100644 --- a/examples/contrib/tls_passthrough.py +++ b/examples/contrib/tls_passthrough.py @@ -19,8 +19,7 @@ import random from abc import ABC, abstractmethod from enum import Enum -from mitmproxy import connection, ctx -from mitmproxy.proxy.layers import tls +from mitmproxy import connection, ctx, tls from mitmproxy.utils import human @@ -95,7 +94,7 @@ class MaybeTls: data.ignore_connection = True self.strategy.record_skipped(server_address) - def tls_handshake(self, data: tls.TlsHookData): + def tls_handshake(self, data: tls.TlsData): if isinstance(data.conn, connection.Server): return # we are only interested in failing client connections here. server_address = data.context.server.peername diff --git a/mitmproxy/addons/tlsconfig.py b/mitmproxy/addons/tlsconfig.py index 9e2a57bd9..548a95a87 100644 --- a/mitmproxy/addons/tlsconfig.py +++ b/mitmproxy/addons/tlsconfig.py @@ -4,11 +4,12 @@ from pathlib import Path from typing import List, Optional, TypedDict, Any from OpenSSL import SSL -from mitmproxy import certs, ctx, exceptions, connection +from mitmproxy import certs, ctx, exceptions, connection, tls from mitmproxy.net import tls as net_tls from mitmproxy.options import CONF_BASENAME from mitmproxy.proxy import context -from mitmproxy.proxy.layers import tls, modes +from mitmproxy.proxy.layers import modes +from mitmproxy.proxy.layers import tls as proxy_tls # We manually need to specify this, otherwise OpenSSL may select a non-HTTP2 cipher by default. # https://ssl-config.mozilla.org/#config=old @@ -46,7 +47,7 @@ def alpn_select_callback(conn: SSL.Connection, options: List[bytes]) -> Any: # We do have a server connection, but the remote server refused to negotiate a protocol: # We need to mirror this on the client connection. return SSL.NO_OVERLAPPING_PROTOCOLS - http_alpns = tls.HTTP_ALPNS if http2 else tls.HTTP1_ALPNS + http_alpns = proxy_tls.HTTP_ALPNS if http2 else proxy_tls.HTTP1_ALPNS for alpn in options: # client sends in order of preference, so we are nice and respect that. if alpn in http_alpns: return alpn @@ -112,7 +113,7 @@ class TlsConfig: ctx.options.upstream_cert ) - def tls_start_client(self, tls_start: tls.TlsHookData) -> None: + def tls_start_client(self, tls_start: tls.TlsData) -> None: """Establish TLS between client and proxy.""" client: connection.Client = tls_start.context.client server: connection.Server = tls_start.context.server @@ -159,7 +160,7 @@ class TlsConfig: )) tls_start.ssl_conn.set_accept_state() - def tls_start_server(self, tls_start: tls.TlsHookData) -> None: + def tls_start_server(self, tls_start: tls.TlsData) -> None: """Establish TLS between proxy and server.""" client: connection.Client = tls_start.context.client server: connection.Server = tls_start.context.server diff --git a/mitmproxy/proxy/layers/tls.py b/mitmproxy/proxy/layers/tls.py index 8c9c4e0bc..f9896f0df 100644 --- a/mitmproxy/proxy/layers/tls.py +++ b/mitmproxy/proxy/layers/tls.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from typing import Iterator, Literal, Optional, Tuple from OpenSSL import SSL -from mitmproxy.tls import ClientHello +from mitmproxy.tls import ClientHello, ClientHelloData, TlsData from mitmproxy import certs, connection from mitmproxy.proxy import commands, events, layer, tunnel from mitmproxy.proxy import context @@ -98,22 +98,6 @@ HTTP_ALPNS = (b"h2",) + HTTP1_ALPNS # We need these classes as hooks can only have one argument at the moment. -@dataclass -class ClientHelloData: - context: context.Context - """The context object for this connection.""" - client_hello: ClientHello - """The entire parsed TLS ClientHello.""" - ignore_connection: bool = False - """ - If set to `True`, do not intercept this connection and forward encrypted contents unmodified. - """ - establish_server_tls_first: bool = False - """ - If set to `True`, pause this handshake and establish TLS with an upstream server first. - This makes it possible to process the server certificate when generating an interception certificate. - """ - @dataclass class TlsClienthelloHook(StartHook): @@ -126,13 +110,6 @@ class TlsClienthelloHook(StartHook): data: ClientHelloData -@dataclass -class TlsHookData: - conn: connection.Connection - context: context.Context - ssl_conn: Optional[SSL.Connection] = None - - @dataclass class TlsStartClientHook(StartHook): """ @@ -141,7 +118,7 @@ class TlsStartClientHook(StartHook): An addon is expected to initialize data.ssl_conn. (by default, this is done by `mitmproxy.addons.tlsconfig`) """ - data: TlsHookData + data: TlsData @dataclass @@ -152,7 +129,7 @@ class TlsStartServerHook(StartHook): An addon is expected to initialize data.ssl_conn. (by default, this is done by `mitmproxy.addons.tlsconfig`) """ - data: TlsHookData + data: TlsData @dataclass @@ -162,7 +139,7 @@ class TlsHandshakeHook(StartHook): If `data.conn.error` is `None`, negotiation was successful. """ - data: TlsHookData + data: TlsData class _TLSLayer(tunnel.TunnelLayer): @@ -184,7 +161,7 @@ class _TLSLayer(tunnel.TunnelLayer): def start_tls(self) -> layer.CommandGenerator[None]: assert not self.tls - tls_start = TlsHookData(self.conn, self.context) + tls_start = TlsData(self.conn, self.context) if tls_start.conn == tls_start.context.client: yield TlsStartClientHook(tls_start) else: @@ -256,13 +233,13 @@ class _TLSLayer(tunnel.TunnelLayer): self.conn.tls_version = self.tls.get_protocol_version_name() if self.debug: yield commands.Log(f"{self.debug}[tls] tls established: {self.conn}", "debug") - yield TlsHandshakeHook(TlsHookData(self.conn, self.context, self.tls)) + yield TlsHandshakeHook(TlsData(self.conn, self.context, self.tls)) yield from self.receive_data(b"") return True, None def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]: self.conn.error = err - yield TlsHandshakeHook(TlsHookData(self.conn, self.context, self.tls)) + yield TlsHandshakeHook(TlsData(self.conn, self.context, self.tls)) yield from super().on_handshake_error(err) def receive_data(self, data: bytes) -> layer.CommandGenerator[None]: diff --git a/mitmproxy/proxy/server.py b/mitmproxy/proxy/server.py index 50c1c4721..4b26b23c6 100644 --- a/mitmproxy/proxy/server.py +++ b/mitmproxy/proxy/server.py @@ -16,12 +16,11 @@ from contextlib import contextmanager from dataclasses import dataclass from OpenSSL import SSL -from mitmproxy import http, options as moptions +from mitmproxy import http, options as moptions, tls from mitmproxy.proxy.context import Context from mitmproxy.proxy.layers.http import HTTPMode from mitmproxy.proxy import commands, events, layer, layers, server_hooks from mitmproxy.connection import Address, Client, Connection, ConnectionState -from mitmproxy.proxy.layers import tls from mitmproxy.utils import asyncio_utils from mitmproxy.utils import human from mitmproxy.utils.data import pkg_data @@ -414,7 +413,7 @@ if __name__ == "__main__": # pragma: no cover if "redirect" in flow.request.path: flow.request.host = "httpbin.org" - def tls_start_client(tls_start: tls.TlsHookData): + def tls_start_client(tls_start: tls.TlsData): # INSECURE ssl_context = SSL.Context(SSL.SSLv23_METHOD) ssl_context.use_privatekey_file( @@ -426,7 +425,7 @@ if __name__ == "__main__": # pragma: no cover tls_start.ssl_conn = SSL.Connection(ssl_context) tls_start.ssl_conn.set_accept_state() - def tls_start_server(tls_start: tls.TlsHookData): + def tls_start_server(tls_start: tls.TlsData): # INSECURE ssl_context = SSL.Context(SSL.SSLv23_METHOD) tls_start.ssl_conn = SSL.Connection(ssl_context) diff --git a/mitmproxy/tls.py b/mitmproxy/tls.py index 4e95f231d..55b14668b 100644 --- a/mitmproxy/tls.py +++ b/mitmproxy/tls.py @@ -1,13 +1,20 @@ import io +from dataclasses import dataclass from typing import List, Optional, Tuple from kaitaistruct import KaitaiStream +from OpenSSL import SSL +from mitmproxy import connection from mitmproxy.contrib.kaitaistruct import tls_client_hello from mitmproxy.net import check +from mitmproxy.proxy import context class ClientHello: + """ + A TLS ClientHello is the first message sent by the client when initiating TLS. + """ def __init__(self, raw_client_hello): self._client_hello = tls_client_hello.TlsClientHello( @@ -23,10 +30,10 @@ class ClientHello: if self._client_hello.extensions: for extension in self._client_hello.extensions.extensions: is_valid_sni_extension = ( - extension.type == 0x00 and - len(extension.body.server_names) == 1 and - extension.body.server_names[0].name_type == 0 and - check.is_valid_host(extension.body.server_names[0].host_name) + extension.type == 0x00 and + len(extension.body.server_names) == 1 and + extension.body.server_names[0].name_type == 0 and + check.is_valid_host(extension.body.server_names[0].host_name) ) if is_valid_sni_extension: return extension.body.server_names[0].host_name.decode("ascii") @@ -51,3 +58,39 @@ class ClientHello: def __repr__(self): return f"ClientHello(sni: {self.sni}, alpn_protocols: {self.alpn_protocols})" + + +@dataclass +class ClientHelloData: + """ + Event data for `tls_clienthello` event hooks. + """ + context: context.Context + """The context object for this connection.""" + client_hello: ClientHello + """The entire parsed TLS ClientHello.""" + ignore_connection: bool = False + """ + If set to `True`, do not intercept this connection and forward encrypted contents unmodified. + """ + establish_server_tls_first: bool = False + """ + If set to `True`, pause this handshake and establish TLS with an upstream server first. + This makes it possible to process the server certificate when generating an interception certificate. + """ + + +@dataclass +class TlsData: + """ + Event data for `tls_start_client`, `tls_start_server`, and `tls_handshake` event hooks. + """ + conn: connection.Connection + """The affected connection.""" + context: context.Context + """The context object for this connection.""" + ssl_conn: Optional[SSL.Connection] = None + """ + The associated pyOpenSSL `SSL.Connection` object. + This will be set by an addon in the `tls_start_*` event hooks. + """ diff --git a/test/mitmproxy/addons/test_tlsconfig.py b/test/mitmproxy/addons/test_tlsconfig.py index 268fd355f..115d92c0c 100644 --- a/test/mitmproxy/addons/test_tlsconfig.py +++ b/test/mitmproxy/addons/test_tlsconfig.py @@ -6,10 +6,10 @@ from typing import Union import pytest from OpenSSL import SSL -from mitmproxy import certs, connection +from mitmproxy import certs, connection, tls from mitmproxy.addons import tlsconfig from mitmproxy.proxy import context -from mitmproxy.proxy.layers import modes, tls +from mitmproxy.proxy.layers import modes, tls as proxy_tls from mitmproxy.test import taddons from test.mitmproxy.proxy.layers import test_tls @@ -130,7 +130,7 @@ class TestTlsConfig: ) ctx = context.Context(connection.Client(("client", 1234), ("127.0.0.1", 8080), 1605699329), tctx.options) - tls_start = tls.TlsHookData(ctx.client, context=ctx) + tls_start = tls.TlsData(ctx.client, context=ctx) ta.tls_start_client(tls_start) tssl_server = tls_start.ssl_conn tssl_client = test_tls.SSLTest() @@ -145,7 +145,7 @@ class TestTlsConfig: ctx.client.cipher_list = ["TLS_AES_256_GCM_SHA384", "ECDHE-RSA-AES128-SHA"] ctx.server.address = ("example.mitmproxy.org", 443) - tls_start = tls.TlsHookData(ctx.server, context=ctx) + tls_start = tls.TlsData(ctx.server, context=ctx) ta.tls_start_server(tls_start) tssl_client = tls_start.ssl_conn tssl_server = test_tls.SSLTest(server_side=True) @@ -160,7 +160,7 @@ class TestTlsConfig: tctx.configure(ta, ssl_verify_upstream_trusted_ca=tdata.path( "mitmproxy/net/data/verificationcerts/trusted-root.crt")) - tls_start = tls.TlsHookData(ctx.server, context=ctx) + tls_start = tls.TlsData(ctx.server, context=ctx) ta.tls_start_server(tls_start) tssl_client = tls_start.ssl_conn tssl_server = test_tls.SSLTest(server_side=True) @@ -179,7 +179,7 @@ class TestTlsConfig: http2=False, ciphers_server="ALL" ) - tls_start = tls.TlsHookData(ctx.server, context=ctx) + tls_start = tls.TlsData(ctx.server, context=ctx) ta.tls_start_server(tls_start) tssl_client = tls_start.ssl_conn tssl_server = test_tls.SSLTest(server_side=True) @@ -190,7 +190,7 @@ class TestTlsConfig: with taddons.context(ta) as tctx: ctx = context.Context(connection.Client(("client", 1234), ("127.0.0.1", 8080), 1605699329), tctx.options) ctx.server.address = ("example.mitmproxy.org", 443) - tls_start = tls.TlsHookData(ctx.server, context=ctx) + tls_start = tls.TlsData(ctx.server, context=ctx) def assert_alpn(http2, client_offers, expected): tctx.configure(ta, http2=http2) @@ -199,8 +199,8 @@ class TestTlsConfig: ta.tls_start_server(tls_start) assert ctx.server.alpn_offers == expected - assert_alpn(True, tls.HTTP_ALPNS + (b"foo",), tls.HTTP_ALPNS + (b"foo",)) - assert_alpn(False, tls.HTTP_ALPNS + (b"foo",), tls.HTTP1_ALPNS + (b"foo",)) + assert_alpn(True, proxy_tls.HTTP_ALPNS + (b"foo",), proxy_tls.HTTP_ALPNS + (b"foo",)) + assert_alpn(False, proxy_tls.HTTP_ALPNS + (b"foo",), proxy_tls.HTTP1_ALPNS + (b"foo",)) assert_alpn(True, [], []) assert_alpn(False, [], []) ctx.client.timestamp_tls_setup = time.time() @@ -222,7 +222,7 @@ class TestTlsConfig: modes.HttpProxy(ctx), 123 ] - tls_start = tls.TlsHookData(ctx.client, context=ctx) + tls_start = tls.TlsData(ctx.client, context=ctx) ta.tls_start_client(tls_start) assert tls_start.ssl_conn.get_app_data()["client_alpn"] == b"http/1.1" @@ -244,7 +244,7 @@ class TestTlsConfig: ssl_verify_upstream_trusted_ca=tdata.path("mitmproxy/net/data/verificationcerts/trusted-root.crt"), ) - tls_start = tls.TlsHookData(ctx.server, context=ctx) + tls_start = tls.TlsData(ctx.server, context=ctx) ta.tls_start_server(tls_start) tssl_client = tls_start.ssl_conn tssl_server = test_tls.SSLTest(server_side=True) diff --git a/test/mitmproxy/proxy/layers/test_tls.py b/test/mitmproxy/proxy/layers/test_tls.py index 4c4aad7d1..e20537c34 100644 --- a/test/mitmproxy/proxy/layers/test_tls.py +++ b/test/mitmproxy/proxy/layers/test_tls.py @@ -9,6 +9,7 @@ from mitmproxy import connection from mitmproxy.connection import ConnectionState, Server from mitmproxy.proxy import commands, context, events, layer from mitmproxy.proxy.layers import tls +from mitmproxy.tls import ClientHelloData, TlsData from mitmproxy.utils import data from test.mitmproxy.proxy import tutils @@ -68,8 +69,8 @@ def test_get_client_hello(): assert tls.get_client_hello(single_record) == client_hello_no_extensions split_over_two_records = ( - bytes.fromhex("1603010020") + client_hello_no_extensions[:32] + - bytes.fromhex("1603010045") + client_hello_no_extensions[32:] + bytes.fromhex("1603010020") + client_hello_no_extensions[:32] + + bytes.fromhex("1603010045") + client_hello_no_extensions[32:] ) assert tls.get_client_hello(split_over_two_records) == client_hello_no_extensions @@ -134,9 +135,9 @@ def _test_echo(playbook: tutils.Playbook, tssl: SSLTest, conn: connection.Connec tssl.obj.write(b"Hello World") data = tutils.Placeholder(bytes) assert ( - playbook - >> events.DataReceived(conn, tssl.bio_read()) - << commands.SendData(conn, data) + playbook + >> events.DataReceived(conn, tssl.bio_read()) + << commands.SendData(conn, data) ) tssl.bio_write(data()) assert tssl.obj.read() == b"hello world" @@ -156,13 +157,13 @@ class TlsEchoLayer(tutils.EchoLayer): def finish_handshake(playbook: tutils.Playbook, conn: connection.Connection, tssl: SSLTest): data = tutils.Placeholder(bytes) - tls_hook_data = tutils.Placeholder(tls.TlsHookData) + tls_hook_data = tutils.Placeholder(TlsData) assert ( - playbook - >> events.DataReceived(conn, tssl.bio_read()) - << tls.TlsHandshakeHook(tls_hook_data) - >> tutils.reply() - << commands.SendData(conn, data) + playbook + >> events.DataReceived(conn, tssl.bio_read()) + << tls.TlsHandshakeHook(tls_hook_data) + >> tutils.reply() + << commands.SendData(conn, data) ) assert tls_hook_data().conn.error is None tssl.bio_write(data()) @@ -173,7 +174,7 @@ def reply_tls_start_client(alpn: typing.Optional[bytes] = None, *args, **kwargs) Helper function to simplify the syntax for tls_start_client hooks. """ - def make_client_conn(tls_start: tls.TlsHookData) -> None: + def make_client_conn(tls_start: TlsData) -> None: # ssl_context = SSL.Context(Method.TLS_METHOD) # ssl_context.set_min_proto_version(SSL.TLS1_3_VERSION) ssl_context = SSL.Context(SSL.SSLv23_METHOD) @@ -198,7 +199,7 @@ def reply_tls_start_server(alpn: typing.Optional[bytes] = None, *args, **kwargs) Helper function to simplify the syntax for tls_start_server hooks. """ - def make_server_conn(tls_start: tls.TlsHookData) -> None: + def make_server_conn(tls_start: TlsData) -> None: # ssl_context = SSL.Context(Method.TLS_METHOD) # ssl_context.set_min_proto_version(SSL.TLS1_3_VERSION) ssl_context = SSL.Context(SSL.SSLv23_METHOD) @@ -243,9 +244,9 @@ class TestServerTLS: layer.child_layer = TlsEchoLayer(tctx) assert ( - tutils.Playbook(layer) - >> events.DataReceived(tctx.client, b"Hello World") - << commands.SendData(tctx.client, b"hello world") + tutils.Playbook(layer) + >> events.DataReceived(tctx.client, b"Hello World") + << commands.SendData(tctx.client, b"hello world") ) def test_simple(self, tctx): @@ -259,10 +260,10 @@ class TestServerTLS: # send ClientHello, receive ClientHello data = tutils.Placeholder(bytes) assert ( - playbook - << tls.TlsStartServerHook(tutils.Placeholder()) - >> reply_tls_start_server() - << commands.SendData(tctx.server, data) + playbook + << tls.TlsStartServerHook(tutils.Placeholder()) + >> reply_tls_start_server() + << commands.SendData(tctx.server, data) ) tssl.bio_write(data()) with pytest.raises(ssl.SSLWantReadError): @@ -274,31 +275,31 @@ class TestServerTLS: # finish handshake (locally) tssl.do_handshake() assert ( - playbook - >> events.DataReceived(tctx.server, tssl.bio_read()) - << None + playbook + >> events.DataReceived(tctx.server, tssl.bio_read()) + << None ) assert tctx.server.tls_established # Echo assert ( - playbook - >> events.DataReceived(tctx.client, b"foo") - << layer.NextLayerHook(tutils.Placeholder()) - >> tutils.reply_next_layer(TlsEchoLayer) - << commands.SendData(tctx.client, b"foo") + playbook + >> events.DataReceived(tctx.client, b"foo") + << layer.NextLayerHook(tutils.Placeholder()) + >> tutils.reply_next_layer(TlsEchoLayer) + << commands.SendData(tctx.client, b"foo") ) _test_echo(playbook, tssl, tctx.server) with pytest.raises(ssl.SSLWantReadError): tssl.obj.unwrap() assert ( - playbook - >> events.DataReceived(tctx.server, tssl.bio_read()) - << commands.CloseConnection(tctx.server) - >> events.ConnectionClosed(tctx.server) - << None + playbook + >> events.DataReceived(tctx.server, tssl.bio_read()) + << commands.CloseConnection(tctx.server) + >> events.ConnectionClosed(tctx.server) + << None ) def test_untrusted_cert(self, tctx): @@ -312,15 +313,15 @@ class TestServerTLS: # send ClientHello data = tutils.Placeholder(bytes) assert ( - playbook - >> events.DataReceived(tctx.client, b"open-connection") - << layer.NextLayerHook(tutils.Placeholder()) - >> tutils.reply_next_layer(TlsEchoLayer) - << commands.OpenConnection(tctx.server) - >> tutils.reply(None) - << tls.TlsStartServerHook(tutils.Placeholder()) - >> reply_tls_start_server() - << commands.SendData(tctx.server, data) + playbook + >> events.DataReceived(tctx.client, b"open-connection") + << layer.NextLayerHook(tutils.Placeholder()) + >> tutils.reply_next_layer(TlsEchoLayer) + << commands.OpenConnection(tctx.server) + >> tutils.reply(None) + << tls.TlsStartServerHook(tutils.Placeholder()) + >> reply_tls_start_server() + << commands.SendData(tctx.server, data) ) # receive ServerHello, finish client handshake @@ -328,16 +329,16 @@ class TestServerTLS: with pytest.raises(ssl.SSLWantReadError): tssl.do_handshake() - tls_hook_data = tutils.Placeholder(tls.TlsHookData) + tls_hook_data = tutils.Placeholder(TlsData) assert ( - playbook - >> events.DataReceived(tctx.server, tssl.bio_read()) - << commands.Log("Server TLS handshake failed. Certificate verify failed: Hostname mismatch", "warn") - << tls.TlsHandshakeHook(tls_hook_data) - >> tutils.reply() - << commands.CloseConnection(tctx.server) - << commands.SendData(tctx.client, - b"open-connection failed: Certificate verify failed: Hostname mismatch") + playbook + >> events.DataReceived(tctx.server, tssl.bio_read()) + << commands.Log("Server TLS handshake failed. Certificate verify failed: Hostname mismatch", "warn") + << tls.TlsHandshakeHook(tls_hook_data) + >> tutils.reply() + << commands.CloseConnection(tctx.server) + << commands.SendData(tctx.client, + b"open-connection failed: Certificate verify failed: Hostname mismatch") ) assert tls_hook_data().conn.error == "Certificate verify failed: Hostname mismatch" assert not tctx.server.tls_established @@ -349,17 +350,17 @@ class TestServerTLS: # send ClientHello, receive random garbage back data = tutils.Placeholder(bytes) - tls_hook_data = tutils.Placeholder(tls.TlsHookData) + tls_hook_data = tutils.Placeholder(TlsData) assert ( - playbook - << tls.TlsStartServerHook(tutils.Placeholder()) - >> reply_tls_start_server() - << commands.SendData(tctx.server, data) - >> events.DataReceived(tctx.server, b"HTTP/1.1 404 Not Found\r\n") - << commands.Log("Server TLS handshake failed. The remote server does not speak TLS.", "warn") - << tls.TlsHandshakeHook(tls_hook_data) - >> tutils.reply() - << commands.CloseConnection(tctx.server) + playbook + << tls.TlsStartServerHook(tutils.Placeholder()) + >> reply_tls_start_server() + << commands.SendData(tctx.server, data) + >> events.DataReceived(tctx.server, b"HTTP/1.1 404 Not Found\r\n") + << commands.Log("Server TLS handshake failed. The remote server does not speak TLS.", "warn") + << tls.TlsHandshakeHook(tls_hook_data) + >> tutils.reply() + << commands.CloseConnection(tctx.server) ) assert tls_hook_data().conn.error == "The remote server does not speak TLS." @@ -388,7 +389,7 @@ class TestServerTLS: tssl.do_handshake() # send back error - tls_hook_data = tutils.Placeholder(tls.TlsHookData) + tls_hook_data = tutils.Placeholder(TlsData) assert ( playbook >> events.DataReceived(tctx.server, tssl.bio_read()) @@ -402,8 +403,8 @@ class TestServerTLS: def make_client_tls_layer( - tctx: context.Context, - **kwargs + tctx: context.Context, + **kwargs ) -> typing.Tuple[tutils.Playbook, tls.ClientTLSLayer, SSLTest]: # This is a bit contrived as the client layer expects a server layer as parent. # We also set child layers manually to avoid NextLayer noise. @@ -435,13 +436,13 @@ class TestClientTLS: # Send ClientHello, receive ServerHello data = tutils.Placeholder(bytes) assert ( - playbook - >> events.DataReceived(tctx.client, tssl_client.bio_read()) - << tls.TlsClienthelloHook(tutils.Placeholder()) - >> tutils.reply() - << tls.TlsStartClientHook(tutils.Placeholder()) - >> reply_tls_start_client() - << commands.SendData(tctx.client, data) + playbook + >> events.DataReceived(tctx.client, tssl_client.bio_read()) + << tls.TlsClienthelloHook(tutils.Placeholder()) + >> tutils.reply() + << tls.TlsStartClientHook(tutils.Placeholder()) + >> reply_tls_start_client() + << commands.SendData(tctx.client, data) ) tssl_client.bio_write(data()) tssl_client.do_handshake() @@ -455,9 +456,9 @@ class TestClientTLS: _test_echo(playbook, tssl_client, tctx.client) other_server = Server(None) assert ( - playbook - >> events.DataReceived(other_server, b"Plaintext") - << commands.SendData(other_server, b"plaintext") + playbook + >> events.DataReceived(other_server, b"Plaintext") + << commands.SendData(other_server, b"plaintext") ) @pytest.mark.parametrize("server_state", ["open", "closed"]) @@ -474,14 +475,14 @@ class TestClientTLS: # We should now get instructed to open a server connection. data = tutils.Placeholder(bytes) - def require_server_conn(client_hello: tls.ClientHelloData) -> None: + def require_server_conn(client_hello: ClientHelloData) -> None: client_hello.establish_server_tls_first = True ( - playbook - >> events.DataReceived(tctx.client, tssl_client.bio_read()) - << tls.TlsClienthelloHook(tutils.Placeholder()) - >> tutils.reply(side_effect=require_server_conn) + playbook + >> events.DataReceived(tctx.client, tssl_client.bio_read()) + << tls.TlsClienthelloHook(tutils.Placeholder()) + >> tutils.reply(side_effect=require_server_conn) ) if server_state == "closed": ( @@ -503,12 +504,12 @@ class TestClientTLS: data = tutils.Placeholder(bytes) assert ( - playbook - >> events.DataReceived(tctx.server, tssl_server.bio_read()) - << tls.TlsHandshakeHook(tutils.Placeholder()) - >> tutils.reply() - << commands.SendData(tctx.server, data) - << tls.TlsStartClientHook(tutils.Placeholder()) + playbook + >> events.DataReceived(tctx.server, tssl_server.bio_read()) + << tls.TlsHandshakeHook(tutils.Placeholder()) + >> tutils.reply() + << commands.SendData(tctx.server, data) + << tls.TlsStartClientHook(tutils.Placeholder()) ) tssl_server.bio_write(data()) assert tctx.server.tls_established @@ -516,9 +517,9 @@ class TestClientTLS: data = tutils.Placeholder(bytes) assert ( - playbook - >> reply_tls_start_client(alpn=b"quux") - << commands.SendData(tctx.client, data) + playbook + >> reply_tls_start_client(alpn=b"quux") + << commands.SendData(tctx.client, data) ) tssl_client.bio_write(data()) tssl_client.do_handshake() @@ -544,7 +545,7 @@ class TestClientTLS: playbook, client_layer, tssl_client = make_client_tls_layer(tctx, alpn=["quux"]) - def make_passthrough(client_hello: tls.ClientHelloData) -> None: + def make_passthrough(client_hello: ClientHelloData) -> None: client_hello.ignore_connection = True client_hello = tssl_client.bio_read() @@ -570,17 +571,17 @@ class TestClientTLS: def test_cannot_parse_clienthello(self, tctx: context.Context): """Test the scenario where we cannot parse the ClientHello""" playbook, client_layer, tssl_client = make_client_tls_layer(tctx) - tls_hook_data = tutils.Placeholder(tls.TlsHookData) + tls_hook_data = tutils.Placeholder(TlsData) invalid = b"\x16\x03\x01\x00\x00" assert ( - playbook - >> events.DataReceived(tctx.client, invalid) - << commands.Log(f"Client TLS handshake failed. Cannot parse ClientHello: {invalid.hex()}", level="warn") - << tls.TlsHandshakeHook(tls_hook_data) - >> tutils.reply() - << commands.CloseConnection(tctx.client) + playbook + >> events.DataReceived(tctx.client, invalid) + << commands.Log(f"Client TLS handshake failed. Cannot parse ClientHello: {invalid.hex()}", level="warn") + << tls.TlsHandshakeHook(tls_hook_data) + >> tutils.reply() + << commands.CloseConnection(tctx.client) ) assert tls_hook_data().conn.error assert not tctx.client.tls_established @@ -601,28 +602,28 @@ class TestClientTLS: data = tutils.Placeholder(bytes) assert ( - playbook - >> events.DataReceived(tctx.client, tssl_client.bio_read()) - << tls.TlsClienthelloHook(tutils.Placeholder()) - >> tutils.reply() - << tls.TlsStartClientHook(tutils.Placeholder()) - >> reply_tls_start_client() - << commands.SendData(tctx.client, data) + playbook + >> events.DataReceived(tctx.client, tssl_client.bio_read()) + << tls.TlsClienthelloHook(tutils.Placeholder()) + >> tutils.reply() + << tls.TlsStartClientHook(tutils.Placeholder()) + >> reply_tls_start_client() + << commands.SendData(tctx.client, data) ) tssl_client.bio_write(data()) with pytest.raises(ssl.SSLCertVerificationError): tssl_client.do_handshake() # Finish Handshake - tls_hook_data = tutils.Placeholder(tls.TlsHookData) + tls_hook_data = tutils.Placeholder(TlsData) assert ( - playbook - >> events.DataReceived(tctx.client, tssl_client.bio_read()) - << commands.Log("Client TLS handshake failed. The client does not trust the proxy's certificate " - "for wrong.host.mitmproxy.org (sslv3 alert bad certificate)", "warn") - << tls.TlsHandshakeHook(tls_hook_data) - >> tutils.reply() - << commands.CloseConnection(tctx.client) - >> events.ConnectionClosed(tctx.client) + playbook + >> events.DataReceived(tctx.client, tssl_client.bio_read()) + << commands.Log("Client TLS handshake failed. The client does not trust the proxy's certificate " + "for wrong.host.mitmproxy.org (sslv3 alert bad certificate)", "warn") + << tls.TlsHandshakeHook(tls_hook_data) + >> tutils.reply() + << commands.CloseConnection(tctx.client) + >> events.ConnectionClosed(tctx.client) ) assert not tctx.client.tls_established assert tls_hook_data().conn.error @@ -634,7 +635,7 @@ class TestClientTLS: the proxy certificate.""" playbook, client_layer, tssl_client = make_client_tls_layer(tctx, sni=b"wrong.host.mitmproxy.org") playbook.logs = True - tls_hook_data = tutils.Placeholder(tls.TlsHookData) + tls_hook_data = tutils.Placeholder(TlsData) playbook >> events.DataReceived(tctx.client, tssl_client.bio_read()) playbook << tls.TlsClienthelloHook(tutils.Placeholder()) @@ -687,7 +688,7 @@ class TestClientTLS: playbook, client_layer, tssl_client = make_client_tls_layer(tctx, max_ver=ssl.TLSVersion.TLSv1_2) playbook.logs = True - tls_hook_data = tutils.Placeholder(tls.TlsHookData) + tls_hook_data = tutils.Placeholder(TlsData) assert ( playbook >> events.DataReceived(tctx.client, tssl_client.bio_read())