Add experimental Windows OS proxy mode (#5912)

* add experimental Windows OS proxy mode

this is merely a proof-of-concept now, but works under the most ideal circumstances.
This commit is contained in:
Maximilian Hils 2023-02-06 17:34:48 +01:00 committed by GitHub
parent a7e50c793e
commit 54185c2c8d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 288 additions and 114 deletions

View File

@ -15,6 +15,7 @@ import asyncio
import errno
import json
import logging
import os
import socket
import textwrap
import typing
@ -85,10 +86,11 @@ Self = TypeVar("Self", bound="ServerInstance")
class ServerInstance(Generic[M], metaclass=ABCMeta):
__modes: ClassVar[dict[str, type[ServerInstance]]] = {}
last_exception: Exception | None = None
def __init__(self, mode: M, manager: ServerManager):
self.mode: M = mode
self.manager: ServerManager = manager
self.last_exception: Exception | None = None
def __init_subclass__(cls, **kwargs):
"""Register all subclasses so that make() finds them."""
@ -119,12 +121,41 @@ class ServerInstance(Generic[M], metaclass=ABCMeta):
def is_running(self) -> bool:
pass
@abstractmethod
async def start(self) -> None:
try:
await self._start()
except Exception as e:
self.last_exception = e
raise
else:
self.last_exception = None
if self.listen_addrs:
addrs = " and ".join({human.format_address(a) for a in self.listen_addrs})
logger.info(f"{self.mode.description} listening at {addrs}.")
else:
logger.info(f"{self.mode.description} started.")
async def stop(self) -> None:
listen_addrs = self.listen_addrs
try:
await self._stop()
except Exception as e:
self.last_exception = e
raise
else:
self.last_exception = None
if listen_addrs:
addrs = " and ".join({human.format_address(a) for a in listen_addrs})
logger.info(f"{self.mode.description} at {addrs} stopped.")
else:
logger.info(f"{self.mode.description} stopped.")
@abstractmethod
async def _start(self) -> None:
pass
@abstractmethod
async def stop(self) -> None:
async def _stop(self) -> None:
pass
@property
@ -166,10 +197,8 @@ class ServerInstance(Generic[M], metaclass=ABCMeta):
else:
handler.layer.context.client.sockname = original_dst
handler.layer.context.server.address = original_dst
elif isinstance(self.mode, mode_specs.WireGuardMode):
original_dst = writer.get_extra_info("original_dst")
handler.layer.context.client.sockname = original_dst
handler.layer.context.server.address = original_dst
elif isinstance(self.mode, (mode_specs.WireGuardMode, mode_specs.OsProxyMode)):
handler.layer.context.server.address = handler.layer.context.client.sockname
connection_id = (
handler.layer.context.client.transport_protocol,
@ -197,7 +226,9 @@ class ServerInstance(Generic[M], metaclass=ABCMeta):
handler.layer = self.make_top_layer(handler.layer.context)
handler.layer.context.client.transport_protocol = "udp"
handler.layer.context.server.transport_protocol = "udp"
if isinstance(self.mode, mode_specs.WireGuardMode):
if isinstance(
self.mode, (mode_specs.WireGuardMode, mode_specs.OsProxyMode)
):
handler.layer.context.server.address = local_addr
# pre-register here - we may get datagrams before the task is executed.
@ -217,13 +248,19 @@ class ServerInstance(Generic[M], metaclass=ABCMeta):
class AsyncioServerInstance(ServerInstance[M], metaclass=ABCMeta):
_server: asyncio.Server | udp.UdpServer | None = None
_listen_addrs: tuple[Address, ...] = tuple()
@property
def is_running(self) -> bool:
return self._server is not None
async def start(self) -> None:
@property
def listen_addrs(self) -> tuple[Address, ...]:
if self._server is not None:
return tuple(s.getsockname() for s in self._server.sockets)
else:
return tuple()
async def _start(self) -> None:
assert self._server is None
host = self.mode.listen_host(ctx.options.listen_host)
port = self.mode.listen_port(ctx.options.listen_port)
@ -231,7 +268,6 @@ class AsyncioServerInstance(ServerInstance[M], metaclass=ABCMeta):
self._server = await self.listen(host, port)
self._listen_addrs = tuple(s.getsockname() for s in self._server.sockets)
except OSError as e:
self.last_exception = e
message = f"{self.mode.description} failed to listen on {host or '*'}:{port} with {e}"
if e.errno == errno.EADDRINUSE and self.mode.custom_listen_port is None:
assert (
@ -239,31 +275,15 @@ class AsyncioServerInstance(ServerInstance[M], metaclass=ABCMeta):
) # since [@ [listen_addr:]listen_port]
message += f"\nTry specifying a different port by using `--mode {self.mode.full_spec}@{port + 1}`."
raise OSError(e.errno, message, e.filename) from e
except Exception as e:
self.last_exception = e
raise
else:
self.last_exception = None
addrs = " and ".join({human.format_address(a) for a in self._listen_addrs})
logger.info(f"{self.mode.description} listening at {addrs}.")
async def stop(self) -> None:
async def _stop(self) -> None:
assert self._server is not None
# we always reset _server and _listen_addrs and ignore failures
server = self._server
listen_addrs = self._listen_addrs
self._server = None
self._listen_addrs = tuple()
try:
server.close()
await server.wait_closed()
except Exception as e:
self.last_exception = e
raise
else:
self.last_exception = None
addrs = " and ".join({human.format_address(a) for a in listen_addrs})
logger.info(f"Stopped {self.mode.description} at {addrs}.")
self._server.close()
await self._server.wait_closed()
finally:
# we always reset _server and ignore failures
self._server = None
async def listen(self, host: str, port: int) -> asyncio.Server | udp.UdpServer:
if self.mode.transport_protocol == "tcp":
@ -297,14 +317,9 @@ class AsyncioServerInstance(ServerInstance[M], metaclass=ABCMeta):
else:
raise AssertionError(self.mode.transport_protocol)
@property
def listen_addrs(self) -> tuple[Address, ...]:
return self._listen_addrs
class WireGuardServerInstance(ServerInstance[mode_specs.WireGuardMode]):
_server: mitmproxy_rs.WireGuardServer | None = None
_listen_addrs: tuple[Address, ...] = tuple()
server_key: str
client_key: str
@ -316,7 +331,14 @@ class WireGuardServerInstance(ServerInstance[mode_specs.WireGuardMode]):
def is_running(self) -> bool:
return self._server is not None
async def start(self) -> None:
@property
def listen_addrs(self) -> tuple[Address, ...]:
if self._server:
return (self._server.getsockname(),)
else:
return tuple()
async def _start(self) -> None:
assert self._server is None
host = self.mode.listen_host(ctx.options.listen_host)
port = self.mode.listen_port(ctx.options.listen_port)
@ -326,56 +348,40 @@ class WireGuardServerInstance(ServerInstance[mode_specs.WireGuardMode]):
else:
conf_path = Path(ctx.options.confdir).expanduser() / "wireguard.conf"
try:
if not conf_path.exists():
conf_path.parent.mkdir(parents=True, exist_ok=True)
conf_path.write_text(
json.dumps(
{
"server_key": mitmproxy_rs.genkey(),
"client_key": mitmproxy_rs.genkey(),
},
indent=4,
)
if not conf_path.exists():
conf_path.parent.mkdir(parents=True, exist_ok=True)
conf_path.write_text(
json.dumps(
{
"server_key": mitmproxy_rs.genkey(),
"client_key": mitmproxy_rs.genkey(),
},
indent=4,
)
try:
c = json.loads(conf_path.read_text())
self.server_key = c["server_key"]
self.client_key = c["client_key"]
except Exception as e:
raise ValueError(
f"Invalid configuration file ({conf_path}): {e}"
) from e
# error early on invalid keys
p = mitmproxy_rs.pubkey(self.client_key)
_ = mitmproxy_rs.pubkey(self.server_key)
self._server = await mitmproxy_rs.start_wireguard_server(
host,
port,
self.server_key,
[p],
self.wg_handle_tcp_connection,
self.handle_udp_datagram,
)
self._listen_addrs = (self._server.getsockname(),)
except Exception as e:
self.last_exception = e
message = f"{self.mode.description} failed to listen on {host or '*'}:{port} with {e}"
raise OSError(message) from e
else:
self.last_exception = None
addrs = " and ".join({human.format_address(a) for a in self.listen_addrs})
try:
c = json.loads(conf_path.read_text())
self.server_key = c["server_key"]
self.client_key = c["client_key"]
except Exception as e:
raise ValueError(f"Invalid configuration file ({conf_path}): {e}") from e
# error early on invalid keys
p = mitmproxy_rs.pubkey(self.client_key)
_ = mitmproxy_rs.pubkey(self.server_key)
self._server = await mitmproxy_rs.start_wireguard_server(
host,
port,
self.server_key,
[p],
self.wg_handle_tcp_connection,
self.handle_udp_datagram,
)
conf = self.client_conf()
assert conf
logger.info(
f"{self.mode.description} listening at {addrs}.\n"
+ "------------------------------------------------------------\n"
+ conf
+ "\n------------------------------------------------------------"
)
logger.info("-" * 60 + "\n" + conf + "\n" + "-" * 60)
def client_conf(self) -> str | None:
if not self._server:
@ -399,24 +405,86 @@ class WireGuardServerInstance(ServerInstance[mode_specs.WireGuardMode]):
def to_json(self) -> dict:
return {"wireguard_conf": self.client_conf(), **super().to_json()}
async def stop(self) -> None:
async def _stop(self) -> None:
assert self._server is not None
self._server.close()
await self._server.wait_closed()
self._server = None
self.last_exception = None
addrs = " and ".join({human.format_address(a) for a in self.listen_addrs})
logger.info(f"Stopped {self.mode.description} at {addrs}.")
@property
def listen_addrs(self) -> tuple[Address, ...]:
return self._listen_addrs
try:
self._server.close()
await self._server.wait_closed()
finally:
self._server = None
async def wg_handle_tcp_connection(self, stream: mitmproxy_rs.TcpStream) -> None:
await self.handle_tcp_connection(stream, stream)
class OsProxyInstance(ServerInstance[mode_specs.OsProxyMode]):
_server: ClassVar[mitmproxy_rs.OsProxy | None] = None
"""The OsProxy server. Will be started once and then reused for all future instances."""
_instance: ClassVar[OsProxyInstance | None] = None
"""The current OsProxy Instance. Will be unset again if an instance is stopped."""
listen_addrs = ()
@property
def is_running(self) -> bool:
return self._instance is not None
def make_top_layer(self, context: Context) -> Layer:
return layers.modes.TransparentProxy(context)
@classmethod
async def os_handle_tcp_connection(cls, stream: mitmproxy_rs.TcpStream) -> None:
if cls._instance is not None:
await cls._instance.handle_tcp_connection(stream, stream)
@classmethod
def os_handle_datagram(
cls,
transport: mitmproxy_rs.DatagramTransport,
data: bytes,
remote_addr: Address,
local_addr: Address,
) -> None:
if cls._instance is not None:
cls._instance.handle_udp_datagram(
transport=transport,
data=data,
remote_addr=remote_addr,
local_addr=local_addr,
)
async def _start(self) -> None:
if self._instance:
raise RuntimeError("Cannot spawn more than one OS proxy instance.")
if self.mode.data.startswith("!"):
spec = f"{self.mode.data},{os.getpid()}"
elif self.mode.data:
spec = self.mode.data
else:
spec = f"!{os.getpid()}"
cls = self.__class__
cls._instance = self # assign before awaiting to avoid races
if cls._server is None:
try:
cls._server = await mitmproxy_rs.start_os_proxy(
cls.os_handle_tcp_connection,
cls.os_handle_datagram,
)
except Exception:
cls._instance = None
raise
cls._server.set_intercept(spec)
async def _stop(self) -> None:
assert self._instance
assert self._server
self.__class__._instance = None
# We're not shutting down the server because we want to avoid additional UAC prompts.
self._server.set_intercept("")
class RegularInstance(AsyncioServerInstance[mode_specs.RegularMode]):
def make_top_layer(self, context: Context) -> Layer:
return layers.modes.HttpProxy(context)

View File

@ -30,6 +30,8 @@ from typing import ClassVar
from typing import Literal
from typing import TypeVar
import mitmproxy_rs
from mitmproxy.coretypes.serializable import Serializable
from mitmproxy.net import server_spec
@ -85,7 +87,7 @@ class ProxyMode(Serializable, metaclass=ABCMeta):
@property
@abstractmethod
def transport_protocol(self) -> Literal["tcp", "udp"]:
def transport_protocol(self) -> Literal["tcp", "udp"] | None:
"""The transport protocol used by this mode's server."""
@classmethod
@ -189,7 +191,7 @@ class RegularMode(ProxyMode):
class TransparentMode(ProxyMode):
"""A transparent proxy, see https://docs.mitmproxy.org/dev/howto-transparent/"""
description = "transparent proxy"
description = "Transparent Proxy"
transport_protocol = TCP
def __post_init__(self) -> None:
@ -280,3 +282,14 @@ class WireGuardMode(ProxyMode):
def __post_init__(self) -> None:
pass
class OsProxyMode(ProxyMode):
"""OS-level transparent proxy."""
description = "OS proxy"
transport_protocol = None
def __post_init__(self) -> None:
# should not raise
mitmproxy_rs.OsProxy.describe_spec(self.data)

View File

@ -111,7 +111,7 @@ async def test_start_stop(caplog_async):
await ps.setup_servers() # assert this can always be called without side effects
tctx.configure(ps, server=False)
await caplog_async.await_log("Stopped HTTP(S) proxy at")
await caplog_async.await_log("stopped")
if ps.servers.is_updating:
async with ps.servers._lock:
pass # wait until start/stop is finished.
@ -318,7 +318,7 @@ async def test_dns(caplog_async) -> None:
w.write(b"\x00")
await caplog_async.await_log("sent an invalid message")
tctx.configure(ps, server=False)
await caplog_async.await_log("Stopped DNS server at")
await caplog_async.await_log("stopped")
def test_validation_no_transparent(monkeypatch):
@ -384,7 +384,7 @@ async def test_dtls(monkeypatch, caplog_async) -> None:
assert repr(ps) == "Proxyserver(1 active conns)"
assert len(ps.connections) == 1
tctx.configure(ps, server=False)
await caplog_async.await_log("Stopped reverse proxy to dtls")
await caplog_async.await_log("stopped")
class H3EchoServer(QuicConnectionProtocol):
@ -793,7 +793,7 @@ async def test_reverse_http3_and_quic_stream(
assert len(ps.connections) == 1
tctx.configure(ps, server=False)
await caplog_async.await_log(f"Stopped reverse proxy to {scheme}")
await caplog_async.await_log(f"stopped")
@pytest.mark.parametrize("connection_strategy", ["lazy", "eager"])
@ -829,7 +829,7 @@ async def test_reverse_quic_datagram(caplog_async, connection_strategy: str) ->
assert await client.recv_datagram() == b"echo"
tctx.configure(ps, server=False)
await caplog_async.await_log("Stopped reverse proxy to quic")
await caplog_async.await_log("stopped")
async def test_regular_http3(caplog_async, monkeypatch) -> None:
@ -869,4 +869,4 @@ async def test_regular_http3(caplog_async, monkeypatch) -> None:
assert len(ps.connections) == 1
tctx.configure(ps, server=False)
await caplog_async.await_log("Stopped HTTP3 proxy")
await caplog_async.await_log("stopped")

View File

@ -86,6 +86,9 @@ def start_h2_client(tctx: Context, keepalive: int = 0) -> tuple[Playbook, FrameF
def make_h2(open_connection: OpenConnection) -> None:
assert isinstance(
open_connection, OpenConnection
), f"Expected OpenConnection event, not {open_connection}"
open_connection.connection.alpn = b"h2"

View File

@ -273,7 +273,6 @@ def h2_frames(draw):
def h2_layer(opts):
tctx = _tctx()
tctx.options.http2_ping_keepalive = 0
tctx.client.alpn = b"h2"
layer = http.HttpLayer(tctx, HTTPMode.regular)
@ -317,7 +316,7 @@ def test_fuzz_h2_request_mutations(chunks):
def _tctx() -> context.Context:
return context.Context(
tctx = context.Context(
connection.Client(
peername=("client", 1234),
sockname=("127.0.0.1", 8080),
@ -325,6 +324,8 @@ def _tctx() -> context.Context:
),
opts,
)
tctx.options.http2_ping_keepalive = 0
return tctx
def _h2_response(chunks):

View File

@ -5,12 +5,14 @@ from unittest.mock import AsyncMock
from unittest.mock import MagicMock
from unittest.mock import Mock
import mitmproxy_rs
import pytest
import mitmproxy.platform
from mitmproxy.addons.proxyserver import Proxyserver
from mitmproxy.net import udp
from mitmproxy.proxy.mode_servers import DnsInstance
from mitmproxy.proxy.mode_servers import OsProxyInstance
from mitmproxy.proxy.mode_servers import ServerInstance
from mitmproxy.proxy.mode_servers import WireGuardServerInstance
from mitmproxy.proxy.server import ConnectionHandler
@ -88,7 +90,7 @@ async def test_tcp_start_stop(caplog_async):
assert await caplog_async.await_log("client disconnect")
await inst.stop()
assert await caplog_async.await_log("Stopped HTTP(S) proxy")
assert await caplog_async.await_log("stopped")
@pytest.mark.parametrize("failure", [True, False])
@ -107,7 +109,7 @@ async def test_transparent(failure, monkeypatch, caplog_async):
tctx.options.connection_strategy = "lazy"
inst = ServerInstance.make("transparent@127.0.0.1:0", manager)
await inst.start()
await caplog_async.await_log("proxy listening")
await caplog_async.await_log("listening")
host, port, *_ = inst.listen_addrs[0]
reader, writer = await asyncio.open_connection(host, port)
@ -123,7 +125,7 @@ async def test_transparent(failure, monkeypatch, caplog_async):
assert await caplog_async.await_log("client disconnect")
await inst.stop()
assert await caplog_async.await_log("Stopped transparent proxy")
assert await caplog_async.await_log("stopped")
async def test_wireguard(tdata, monkeypatch, caplog):
@ -176,7 +178,7 @@ async def test_wireguard(tdata, monkeypatch, caplog):
raise
await inst.stop()
assert "Stopped WireGuard server" in caplog.text
assert "stopped" in caplog.text
async def test_wireguard_generate_conf(tmp_path):
@ -205,7 +207,7 @@ async def test_wireguard_invalid_conf(tmp_path):
# directory instead of filename
inst = WireGuardServerInstance.make(f"wireguard:{tmp_path}", MagicMock())
with pytest.raises(OSError):
with pytest.raises(ValueError, match="Invalid configuration file"):
await inst.start()
assert "Invalid configuration file" in repr(inst.last_exception)
@ -261,7 +263,7 @@ async def test_udp_start_stop(caplog_async):
writer.close()
await inst.stop()
assert await caplog_async.await_log("Stopped")
assert await caplog_async.await_log("stopped")
async def test_udp_start_error():
@ -297,3 +299,79 @@ async def test_udp_connection_reuse(monkeypatch):
await asyncio.sleep(0)
assert len(inst.manager.connections) == 1
@pytest.fixture()
def patched_os_proxy(monkeypatch):
start_os_proxy = AsyncMock()
monkeypatch.setattr(mitmproxy_rs, "start_os_proxy", start_os_proxy)
# make sure _server and _instance are restored after this test
monkeypatch.setattr(OsProxyInstance, "_server", None)
monkeypatch.setattr(OsProxyInstance, "_instance", None)
return start_os_proxy
async def test_os_proxy(patched_os_proxy, caplog_async):
caplog_async.set_level("INFO")
with taddons.context():
inst = ServerInstance.make(f"osproxy", MagicMock())
assert not inst.is_running
await inst.start()
assert patched_os_proxy.called
assert await caplog_async.await_log("OS proxy started.")
assert inst.is_running
await inst.stop()
assert await caplog_async.await_log("OS proxy stopped")
assert not inst.is_running
# just called for coverage
inst.make_top_layer(MagicMock())
async def test_os_proxy_startup_err(patched_os_proxy):
patched_os_proxy.side_effect = RuntimeError("OS proxy startup error")
with taddons.context():
inst = ServerInstance.make(f"osproxy:!curl", MagicMock())
with pytest.raises(RuntimeError):
await inst.start()
assert not inst.is_running
async def test_multiple_os_proxies(patched_os_proxy):
manager = MagicMock()
with taddons.context():
inst1 = ServerInstance.make(f"osproxy:curl", manager)
await inst1.start()
inst2 = ServerInstance.make(f"osproxy:wget", manager)
with pytest.raises(
RuntimeError, match="Cannot spawn more than one OS proxy instance"
):
await inst2.start()
async def test_always_uses_current_instance(patched_os_proxy, monkeypatch):
manager = MagicMock()
with taddons.context():
inst1 = ServerInstance.make(f"osproxy:curl", manager)
await inst1.start()
await inst1.stop()
handle_tcp, handle_udp = patched_os_proxy.await_args[0]
inst2 = ServerInstance.make(f"osproxy:wget", manager)
await inst2.start()
monkeypatch.setattr(inst2, "handle_tcp_connection", h_tcp := AsyncMock())
await handle_tcp(Mock())
assert h_tcp.await_count
monkeypatch.setattr(inst2, "handle_udp_datagram", h_udp := Mock())
handle_udp(Mock(), b"", ("", 0), ("", 0))
assert h_udp.called

View File

@ -70,6 +70,8 @@ def test_parse_specific_modes():
assert ProxyMode.parse("wireguard:foo.conf").data == "foo.conf"
assert ProxyMode.parse("wireguard@51821").listen_port() == 51821
assert ProxyMode.parse("osproxy")
with pytest.raises(ValueError, match="invalid port"):
ProxyMode.parse("regular@invalid-port")
@ -87,3 +89,6 @@ def test_parse_specific_modes():
with pytest.raises(ValueError, match="Port specification missing."):
ProxyMode.parse("reverse:dtls://127.0.0.1")
with pytest.raises(ValueError, match="invalid intercept spec"):
ProxyMode.parse("osproxy:,,,")

View File

@ -8,6 +8,7 @@ from contextlib import redirect_stdout
from pathlib import Path
from typing import Optional
from unittest import mock
from unittest.mock import Mock
import pytest
import tornado.testing
@ -176,8 +177,13 @@ class TestApp(tornado.testing.AsyncHTTPTestCase):
m.events._add_log(log.LogEntry("test log", "info"))
m.events.done()
si1 = ServerInstance.make("regular", m.proxyserver)
si1._listen_addrs = [("127.0.0.1", 8080), ("::1", 8080)]
si1._server = True # spoof is_running
sock1 = Mock()
sock1.getsockname.return_value = ("127.0.0.1", 8080)
sock2 = Mock()
sock2.getsockname.return_value = ("::1", 8080)
server = Mock()
server.sockets = [sock1, sock2]
si1._server = server
si2 = ServerInstance.make("reverse:example.com", m.proxyserver)
si2.last_exception = RuntimeError("I failed somehow.")
si3 = ServerInstance.make("socks5", m.proxyserver)