mitmproxy-wireguard -> mitmproxy_rs (#5909)

mitmproxy-rs includes all the fantastic WireGuard work,
but will add more non-WireGuard stuff. :)
This commit is contained in:
Maximilian Hils 2023-02-04 22:28:15 +01:00 committed by GitHub
parent 977385ceab
commit a7e50c793e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 25 additions and 67 deletions

View File

@ -9,8 +9,9 @@ from typing import cast
from typing import Optional from typing import Optional
from typing import Union from typing import Union
import mitmproxy_rs
from mitmproxy.connection import Address from mitmproxy.connection import Address
from mitmproxy.net import udp_wireguard
from mitmproxy.utils import human from mitmproxy.utils import human
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -162,14 +163,14 @@ class DatagramReader:
class DatagramWriter: class DatagramWriter:
_transport: asyncio.DatagramTransport _transport: asyncio.DatagramTransport | mitmproxy_rs.DatagramTransport
_remote_addr: Address _remote_addr: Address
_reader: DatagramReader | None _reader: DatagramReader | None
_closed: asyncio.Event | None _closed: asyncio.Event | None
def __init__( def __init__(
self, self,
transport: asyncio.DatagramTransport, transport: asyncio.DatagramTransport | mitmproxy_rs.DatagramTransport,
remote_addr: Address, remote_addr: Address,
reader: DatagramReader | None = None, reader: DatagramReader | None = None,
) -> None: ) -> None:
@ -189,7 +190,7 @@ class DatagramWriter:
@property @property
def _protocol( def _protocol(
self, self,
) -> DrainableDatagramProtocol | udp_wireguard.WireGuardDatagramTransport: ) -> DrainableDatagramProtocol | mitmproxy_rs.DatagramTransport:
return self._transport.get_protocol() # type: ignore return self._transport.get_protocol() # type: ignore
def write(self, data: bytes) -> None: def write(self, data: bytes) -> None:

View File

@ -1,35 +0,0 @@
"""
This module contains a mock DatagramTransport for use with mitmproxy-wireguard.
"""
import asyncio
from typing import Any
import mitmproxy_wireguard as wg
from mitmproxy.connection import Address
class WireGuardDatagramTransport(asyncio.DatagramTransport):
def __init__(self, server: wg.Server, local_addr: Address, remote_addr: Address):
self._server: wg.Server = server
self._local_addr: Address = local_addr
self._remote_addr: Address = remote_addr
super().__init__()
def sendto(self, data, addr=None):
self._server.send_datagram(data, self._local_addr, addr or self._remote_addr)
def get_extra_info(self, name: str, default: Any = None) -> Any:
if name == "sockname":
return self._server.getsockname()
else:
raise NotImplementedError
def get_protocol(self):
return self
async def drain(self) -> None:
pass
async def wait_closed(self) -> None:
pass

View File

@ -87,7 +87,7 @@ class ReverseProxy(DestinationKnown):
class TransparentProxy(DestinationKnown): class TransparentProxy(DestinationKnown):
@expect(events.Start) @expect(events.Start)
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
assert self.context.server.address assert self.context.server.address, "No server address set."
self.child_layer = layer.NextLayer(self.context) self.child_layer = layer.NextLayer(self.context)
err = yield from self.finish_start() err = yield from self.finish_start()
if err: if err:

View File

@ -28,7 +28,7 @@ from typing import Generic
from typing import get_args from typing import get_args
from typing import TypeVar from typing import TypeVar
import mitmproxy_wireguard as wg import mitmproxy_rs
from mitmproxy import ctx from mitmproxy import ctx
from mitmproxy import flow from mitmproxy import flow
@ -37,7 +37,6 @@ from mitmproxy.connection import Address
from mitmproxy.master import Master from mitmproxy.master import Master
from mitmproxy.net import local_ip from mitmproxy.net import local_ip
from mitmproxy.net import udp from mitmproxy.net import udp
from mitmproxy.net.udp_wireguard import WireGuardDatagramTransport
from mitmproxy.proxy import commands from mitmproxy.proxy import commands
from mitmproxy.proxy import layers from mitmproxy.proxy import layers
from mitmproxy.proxy import mode_specs from mitmproxy.proxy import mode_specs
@ -149,8 +148,8 @@ class ServerInstance(Generic[M], metaclass=ABCMeta):
async def handle_tcp_connection( async def handle_tcp_connection(
self, self,
reader: asyncio.StreamReader | wg.TcpStream, reader: asyncio.StreamReader | mitmproxy_rs.TcpStream,
writer: asyncio.StreamWriter | wg.TcpStream, writer: asyncio.StreamWriter | mitmproxy_rs.TcpStream,
) -> None: ) -> None:
handler = ProxyConnectionHandler( handler = ProxyConnectionHandler(
ctx.master, reader, writer, ctx.options, self.mode ctx.master, reader, writer, ctx.options, self.mode
@ -182,7 +181,7 @@ class ServerInstance(Generic[M], metaclass=ABCMeta):
def handle_udp_datagram( def handle_udp_datagram(
self, self,
transport: asyncio.DatagramTransport, transport: asyncio.DatagramTransport | mitmproxy_rs.DatagramTransport,
data: bytes, data: bytes,
remote_addr: Address, remote_addr: Address,
local_addr: Address, local_addr: Address,
@ -304,7 +303,7 @@ class AsyncioServerInstance(ServerInstance[M], metaclass=ABCMeta):
class WireGuardServerInstance(ServerInstance[mode_specs.WireGuardMode]): class WireGuardServerInstance(ServerInstance[mode_specs.WireGuardMode]):
_server: wg.Server | None = None _server: mitmproxy_rs.WireGuardServer | None = None
_listen_addrs: tuple[Address, ...] = tuple() _listen_addrs: tuple[Address, ...] = tuple()
server_key: str server_key: str
@ -333,8 +332,8 @@ class WireGuardServerInstance(ServerInstance[mode_specs.WireGuardMode]):
conf_path.write_text( conf_path.write_text(
json.dumps( json.dumps(
{ {
"server_key": wg.genkey(), "server_key": mitmproxy_rs.genkey(),
"client_key": wg.genkey(), "client_key": mitmproxy_rs.genkey(),
}, },
indent=4, indent=4,
) )
@ -349,16 +348,16 @@ class WireGuardServerInstance(ServerInstance[mode_specs.WireGuardMode]):
f"Invalid configuration file ({conf_path}): {e}" f"Invalid configuration file ({conf_path}): {e}"
) from e ) from e
# error early on invalid keys # error early on invalid keys
p = wg.pubkey(self.client_key) p = mitmproxy_rs.pubkey(self.client_key)
_ = wg.pubkey(self.server_key) _ = mitmproxy_rs.pubkey(self.server_key)
self._server = await wg.start_server( self._server = await mitmproxy_rs.start_wireguard_server(
host, host,
port, port,
self.server_key, self.server_key,
[p], [p],
self.wg_handle_tcp_connection, self.wg_handle_tcp_connection,
self.wg_handle_udp_datagram, self.handle_udp_datagram,
) )
self._listen_addrs = (self._server.getsockname(),) self._listen_addrs = (self._server.getsockname(),)
except Exception as e: except Exception as e:
@ -391,7 +390,7 @@ class WireGuardServerInstance(ServerInstance[mode_specs.WireGuardMode]):
DNS = 10.0.0.53 DNS = 10.0.0.53
[Peer] [Peer]
PublicKey = {wg.pubkey(self.server_key)} PublicKey = {mitmproxy_rs.pubkey(self.server_key)}
AllowedIPs = 0.0.0.0/0 AllowedIPs = 0.0.0.0/0
Endpoint = {host}:{port} Endpoint = {host}:{port}
""" """
@ -414,16 +413,9 @@ class WireGuardServerInstance(ServerInstance[mode_specs.WireGuardMode]):
def listen_addrs(self) -> tuple[Address, ...]: def listen_addrs(self) -> tuple[Address, ...]:
return self._listen_addrs return self._listen_addrs
async def wg_handle_tcp_connection(self, stream: wg.TcpStream) -> None: async def wg_handle_tcp_connection(self, stream: mitmproxy_rs.TcpStream) -> None:
await self.handle_tcp_connection(stream, stream) await self.handle_tcp_connection(stream, stream)
def wg_handle_udp_datagram(
self, data: bytes, remote_addr: Address, local_addr: Address
) -> None:
assert self._server is not None
transport = WireGuardDatagramTransport(self._server, local_addr, remote_addr)
self.handle_udp_datagram(transport, data, remote_addr, local_addr)
class RegularInstance(AsyncioServerInstance[mode_specs.RegularMode]): class RegularInstance(AsyncioServerInstance[mode_specs.RegularMode]):
def make_top_layer(self, context: Context) -> Layer: def make_top_layer(self, context: Context) -> Layer:

View File

@ -20,7 +20,7 @@ from dataclasses import dataclass
from typing import Optional from typing import Optional
from typing import Union from typing import Union
import mitmproxy_wireguard as wg import mitmproxy_rs
from OpenSSL import SSL from OpenSSL import SSL
from mitmproxy import http from mitmproxy import http
@ -93,10 +93,10 @@ class TimeoutWatchdog:
class ConnectionIO: class ConnectionIO:
handler: Optional[asyncio.Task] = None handler: Optional[asyncio.Task] = None
reader: Optional[ reader: Optional[
Union[asyncio.StreamReader, udp.DatagramReader, wg.TcpStream] Union[asyncio.StreamReader, udp.DatagramReader, mitmproxy_rs.TcpStream]
] = None ] = None
writer: Optional[ writer: Optional[
Union[asyncio.StreamWriter, udp.DatagramWriter, wg.TcpStream] Union[asyncio.StreamWriter, udp.DatagramWriter, mitmproxy_rs.TcpStream]
] = None ] = None
@ -429,8 +429,8 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
class LiveConnectionHandler(ConnectionHandler, metaclass=abc.ABCMeta): class LiveConnectionHandler(ConnectionHandler, metaclass=abc.ABCMeta):
def __init__( def __init__(
self, self,
reader: Union[asyncio.StreamReader, wg.TcpStream], reader: Union[asyncio.StreamReader, mitmproxy_rs.TcpStream],
writer: Union[asyncio.StreamWriter, wg.TcpStream], writer: Union[asyncio.StreamWriter, mitmproxy_rs.TcpStream],
options: moptions.Options, options: moptions.Options,
mode: mode_specs.ProxyMode, mode: mode_specs.ProxyMode,
) -> None: ) -> None:

View File

@ -85,7 +85,7 @@ setup(
"hyperframe>=6.0,<7", "hyperframe>=6.0,<7",
"kaitaistruct>=0.10,<0.11", "kaitaistruct>=0.10,<0.11",
"ldap3>=2.8,<2.10", "ldap3>=2.8,<2.10",
"mitmproxy_wireguard>=0.1.6,<0.2", "mitmproxy_rs>=0.2.0b1,<0.3",
"msgpack>=1.0.0, <1.1.0", "msgpack>=1.0.0, <1.1.0",
"passlib>=1.6.5, <1.8", "passlib>=1.6.5, <1.8",
"protobuf>=3.14,<5", "protobuf>=3.14,<5",