mirror of
https://github.com/mitmproxy/mitmproxy.git
synced 2024-11-27 15:20:51 +00:00
Add experimental Windows OS proxy mode (#5912)
* add experimental Windows OS proxy mode this is merely a proof-of-concept now, but works under the most ideal circumstances.
This commit is contained in:
parent
a7e50c793e
commit
54185c2c8d
@ -15,6 +15,7 @@ import asyncio
|
||||
import errno
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
import textwrap
|
||||
import typing
|
||||
@ -85,10 +86,11 @@ Self = TypeVar("Self", bound="ServerInstance")
|
||||
class ServerInstance(Generic[M], metaclass=ABCMeta):
|
||||
__modes: ClassVar[dict[str, type[ServerInstance]]] = {}
|
||||
|
||||
last_exception: Exception | None = None
|
||||
|
||||
def __init__(self, mode: M, manager: ServerManager):
|
||||
self.mode: M = mode
|
||||
self.manager: ServerManager = manager
|
||||
self.last_exception: Exception | None = None
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
"""Register all subclasses so that make() finds them."""
|
||||
@ -119,12 +121,41 @@ class ServerInstance(Generic[M], metaclass=ABCMeta):
|
||||
def is_running(self) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def start(self) -> None:
|
||||
try:
|
||||
await self._start()
|
||||
except Exception as e:
|
||||
self.last_exception = e
|
||||
raise
|
||||
else:
|
||||
self.last_exception = None
|
||||
if self.listen_addrs:
|
||||
addrs = " and ".join({human.format_address(a) for a in self.listen_addrs})
|
||||
logger.info(f"{self.mode.description} listening at {addrs}.")
|
||||
else:
|
||||
logger.info(f"{self.mode.description} started.")
|
||||
|
||||
async def stop(self) -> None:
|
||||
listen_addrs = self.listen_addrs
|
||||
try:
|
||||
await self._stop()
|
||||
except Exception as e:
|
||||
self.last_exception = e
|
||||
raise
|
||||
else:
|
||||
self.last_exception = None
|
||||
if listen_addrs:
|
||||
addrs = " and ".join({human.format_address(a) for a in listen_addrs})
|
||||
logger.info(f"{self.mode.description} at {addrs} stopped.")
|
||||
else:
|
||||
logger.info(f"{self.mode.description} stopped.")
|
||||
|
||||
@abstractmethod
|
||||
async def _start(self) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def stop(self) -> None:
|
||||
async def _stop(self) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
@ -166,10 +197,8 @@ class ServerInstance(Generic[M], metaclass=ABCMeta):
|
||||
else:
|
||||
handler.layer.context.client.sockname = original_dst
|
||||
handler.layer.context.server.address = original_dst
|
||||
elif isinstance(self.mode, mode_specs.WireGuardMode):
|
||||
original_dst = writer.get_extra_info("original_dst")
|
||||
handler.layer.context.client.sockname = original_dst
|
||||
handler.layer.context.server.address = original_dst
|
||||
elif isinstance(self.mode, (mode_specs.WireGuardMode, mode_specs.OsProxyMode)):
|
||||
handler.layer.context.server.address = handler.layer.context.client.sockname
|
||||
|
||||
connection_id = (
|
||||
handler.layer.context.client.transport_protocol,
|
||||
@ -197,7 +226,9 @@ class ServerInstance(Generic[M], metaclass=ABCMeta):
|
||||
handler.layer = self.make_top_layer(handler.layer.context)
|
||||
handler.layer.context.client.transport_protocol = "udp"
|
||||
handler.layer.context.server.transport_protocol = "udp"
|
||||
if isinstance(self.mode, mode_specs.WireGuardMode):
|
||||
if isinstance(
|
||||
self.mode, (mode_specs.WireGuardMode, mode_specs.OsProxyMode)
|
||||
):
|
||||
handler.layer.context.server.address = local_addr
|
||||
|
||||
# pre-register here - we may get datagrams before the task is executed.
|
||||
@ -217,13 +248,19 @@ class ServerInstance(Generic[M], metaclass=ABCMeta):
|
||||
|
||||
class AsyncioServerInstance(ServerInstance[M], metaclass=ABCMeta):
|
||||
_server: asyncio.Server | udp.UdpServer | None = None
|
||||
_listen_addrs: tuple[Address, ...] = tuple()
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return self._server is not None
|
||||
|
||||
async def start(self) -> None:
|
||||
@property
|
||||
def listen_addrs(self) -> tuple[Address, ...]:
|
||||
if self._server is not None:
|
||||
return tuple(s.getsockname() for s in self._server.sockets)
|
||||
else:
|
||||
return tuple()
|
||||
|
||||
async def _start(self) -> None:
|
||||
assert self._server is None
|
||||
host = self.mode.listen_host(ctx.options.listen_host)
|
||||
port = self.mode.listen_port(ctx.options.listen_port)
|
||||
@ -231,7 +268,6 @@ class AsyncioServerInstance(ServerInstance[M], metaclass=ABCMeta):
|
||||
self._server = await self.listen(host, port)
|
||||
self._listen_addrs = tuple(s.getsockname() for s in self._server.sockets)
|
||||
except OSError as e:
|
||||
self.last_exception = e
|
||||
message = f"{self.mode.description} failed to listen on {host or '*'}:{port} with {e}"
|
||||
if e.errno == errno.EADDRINUSE and self.mode.custom_listen_port is None:
|
||||
assert (
|
||||
@ -239,31 +275,15 @@ class AsyncioServerInstance(ServerInstance[M], metaclass=ABCMeta):
|
||||
) # since [@ [listen_addr:]listen_port]
|
||||
message += f"\nTry specifying a different port by using `--mode {self.mode.full_spec}@{port + 1}`."
|
||||
raise OSError(e.errno, message, e.filename) from e
|
||||
except Exception as e:
|
||||
self.last_exception = e
|
||||
raise
|
||||
else:
|
||||
self.last_exception = None
|
||||
addrs = " and ".join({human.format_address(a) for a in self._listen_addrs})
|
||||
logger.info(f"{self.mode.description} listening at {addrs}.")
|
||||
|
||||
async def stop(self) -> None:
|
||||
async def _stop(self) -> None:
|
||||
assert self._server is not None
|
||||
# we always reset _server and _listen_addrs and ignore failures
|
||||
server = self._server
|
||||
listen_addrs = self._listen_addrs
|
||||
self._server = None
|
||||
self._listen_addrs = tuple()
|
||||
try:
|
||||
server.close()
|
||||
await server.wait_closed()
|
||||
except Exception as e:
|
||||
self.last_exception = e
|
||||
raise
|
||||
else:
|
||||
self.last_exception = None
|
||||
addrs = " and ".join({human.format_address(a) for a in listen_addrs})
|
||||
logger.info(f"Stopped {self.mode.description} at {addrs}.")
|
||||
self._server.close()
|
||||
await self._server.wait_closed()
|
||||
finally:
|
||||
# we always reset _server and ignore failures
|
||||
self._server = None
|
||||
|
||||
async def listen(self, host: str, port: int) -> asyncio.Server | udp.UdpServer:
|
||||
if self.mode.transport_protocol == "tcp":
|
||||
@ -297,14 +317,9 @@ class AsyncioServerInstance(ServerInstance[M], metaclass=ABCMeta):
|
||||
else:
|
||||
raise AssertionError(self.mode.transport_protocol)
|
||||
|
||||
@property
|
||||
def listen_addrs(self) -> tuple[Address, ...]:
|
||||
return self._listen_addrs
|
||||
|
||||
|
||||
class WireGuardServerInstance(ServerInstance[mode_specs.WireGuardMode]):
|
||||
_server: mitmproxy_rs.WireGuardServer | None = None
|
||||
_listen_addrs: tuple[Address, ...] = tuple()
|
||||
|
||||
server_key: str
|
||||
client_key: str
|
||||
@ -316,7 +331,14 @@ class WireGuardServerInstance(ServerInstance[mode_specs.WireGuardMode]):
|
||||
def is_running(self) -> bool:
|
||||
return self._server is not None
|
||||
|
||||
async def start(self) -> None:
|
||||
@property
|
||||
def listen_addrs(self) -> tuple[Address, ...]:
|
||||
if self._server:
|
||||
return (self._server.getsockname(),)
|
||||
else:
|
||||
return tuple()
|
||||
|
||||
async def _start(self) -> None:
|
||||
assert self._server is None
|
||||
host = self.mode.listen_host(ctx.options.listen_host)
|
||||
port = self.mode.listen_port(ctx.options.listen_port)
|
||||
@ -326,56 +348,40 @@ class WireGuardServerInstance(ServerInstance[mode_specs.WireGuardMode]):
|
||||
else:
|
||||
conf_path = Path(ctx.options.confdir).expanduser() / "wireguard.conf"
|
||||
|
||||
try:
|
||||
if not conf_path.exists():
|
||||
conf_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
conf_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"server_key": mitmproxy_rs.genkey(),
|
||||
"client_key": mitmproxy_rs.genkey(),
|
||||
},
|
||||
indent=4,
|
||||
)
|
||||
if not conf_path.exists():
|
||||
conf_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
conf_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"server_key": mitmproxy_rs.genkey(),
|
||||
"client_key": mitmproxy_rs.genkey(),
|
||||
},
|
||||
indent=4,
|
||||
)
|
||||
|
||||
try:
|
||||
c = json.loads(conf_path.read_text())
|
||||
self.server_key = c["server_key"]
|
||||
self.client_key = c["client_key"]
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Invalid configuration file ({conf_path}): {e}"
|
||||
) from e
|
||||
# error early on invalid keys
|
||||
p = mitmproxy_rs.pubkey(self.client_key)
|
||||
_ = mitmproxy_rs.pubkey(self.server_key)
|
||||
|
||||
self._server = await mitmproxy_rs.start_wireguard_server(
|
||||
host,
|
||||
port,
|
||||
self.server_key,
|
||||
[p],
|
||||
self.wg_handle_tcp_connection,
|
||||
self.handle_udp_datagram,
|
||||
)
|
||||
self._listen_addrs = (self._server.getsockname(),)
|
||||
except Exception as e:
|
||||
self.last_exception = e
|
||||
message = f"{self.mode.description} failed to listen on {host or '*'}:{port} with {e}"
|
||||
raise OSError(message) from e
|
||||
else:
|
||||
self.last_exception = None
|
||||
|
||||
addrs = " and ".join({human.format_address(a) for a in self.listen_addrs})
|
||||
try:
|
||||
c = json.loads(conf_path.read_text())
|
||||
self.server_key = c["server_key"]
|
||||
self.client_key = c["client_key"]
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid configuration file ({conf_path}): {e}") from e
|
||||
# error early on invalid keys
|
||||
p = mitmproxy_rs.pubkey(self.client_key)
|
||||
_ = mitmproxy_rs.pubkey(self.server_key)
|
||||
|
||||
self._server = await mitmproxy_rs.start_wireguard_server(
|
||||
host,
|
||||
port,
|
||||
self.server_key,
|
||||
[p],
|
||||
self.wg_handle_tcp_connection,
|
||||
self.handle_udp_datagram,
|
||||
)
|
||||
|
||||
conf = self.client_conf()
|
||||
assert conf
|
||||
logger.info(
|
||||
f"{self.mode.description} listening at {addrs}.\n"
|
||||
+ "------------------------------------------------------------\n"
|
||||
+ conf
|
||||
+ "\n------------------------------------------------------------"
|
||||
)
|
||||
logger.info("-" * 60 + "\n" + conf + "\n" + "-" * 60)
|
||||
|
||||
def client_conf(self) -> str | None:
|
||||
if not self._server:
|
||||
@ -399,24 +405,86 @@ class WireGuardServerInstance(ServerInstance[mode_specs.WireGuardMode]):
|
||||
def to_json(self) -> dict:
|
||||
return {"wireguard_conf": self.client_conf(), **super().to_json()}
|
||||
|
||||
async def stop(self) -> None:
|
||||
async def _stop(self) -> None:
|
||||
assert self._server is not None
|
||||
self._server.close()
|
||||
await self._server.wait_closed()
|
||||
self._server = None
|
||||
self.last_exception = None
|
||||
|
||||
addrs = " and ".join({human.format_address(a) for a in self.listen_addrs})
|
||||
logger.info(f"Stopped {self.mode.description} at {addrs}.")
|
||||
|
||||
@property
|
||||
def listen_addrs(self) -> tuple[Address, ...]:
|
||||
return self._listen_addrs
|
||||
try:
|
||||
self._server.close()
|
||||
await self._server.wait_closed()
|
||||
finally:
|
||||
self._server = None
|
||||
|
||||
async def wg_handle_tcp_connection(self, stream: mitmproxy_rs.TcpStream) -> None:
|
||||
await self.handle_tcp_connection(stream, stream)
|
||||
|
||||
|
||||
class OsProxyInstance(ServerInstance[mode_specs.OsProxyMode]):
|
||||
_server: ClassVar[mitmproxy_rs.OsProxy | None] = None
|
||||
"""The OsProxy server. Will be started once and then reused for all future instances."""
|
||||
_instance: ClassVar[OsProxyInstance | None] = None
|
||||
"""The current OsProxy Instance. Will be unset again if an instance is stopped."""
|
||||
listen_addrs = ()
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return self._instance is not None
|
||||
|
||||
def make_top_layer(self, context: Context) -> Layer:
|
||||
return layers.modes.TransparentProxy(context)
|
||||
|
||||
@classmethod
|
||||
async def os_handle_tcp_connection(cls, stream: mitmproxy_rs.TcpStream) -> None:
|
||||
if cls._instance is not None:
|
||||
await cls._instance.handle_tcp_connection(stream, stream)
|
||||
|
||||
@classmethod
|
||||
def os_handle_datagram(
|
||||
cls,
|
||||
transport: mitmproxy_rs.DatagramTransport,
|
||||
data: bytes,
|
||||
remote_addr: Address,
|
||||
local_addr: Address,
|
||||
) -> None:
|
||||
if cls._instance is not None:
|
||||
cls._instance.handle_udp_datagram(
|
||||
transport=transport,
|
||||
data=data,
|
||||
remote_addr=remote_addr,
|
||||
local_addr=local_addr,
|
||||
)
|
||||
|
||||
async def _start(self) -> None:
|
||||
if self._instance:
|
||||
raise RuntimeError("Cannot spawn more than one OS proxy instance.")
|
||||
|
||||
if self.mode.data.startswith("!"):
|
||||
spec = f"{self.mode.data},{os.getpid()}"
|
||||
elif self.mode.data:
|
||||
spec = self.mode.data
|
||||
else:
|
||||
spec = f"!{os.getpid()}"
|
||||
|
||||
cls = self.__class__
|
||||
cls._instance = self # assign before awaiting to avoid races
|
||||
if cls._server is None:
|
||||
try:
|
||||
cls._server = await mitmproxy_rs.start_os_proxy(
|
||||
cls.os_handle_tcp_connection,
|
||||
cls.os_handle_datagram,
|
||||
)
|
||||
except Exception:
|
||||
cls._instance = None
|
||||
raise
|
||||
|
||||
cls._server.set_intercept(spec)
|
||||
|
||||
async def _stop(self) -> None:
|
||||
assert self._instance
|
||||
assert self._server
|
||||
self.__class__._instance = None
|
||||
# We're not shutting down the server because we want to avoid additional UAC prompts.
|
||||
self._server.set_intercept("")
|
||||
|
||||
|
||||
class RegularInstance(AsyncioServerInstance[mode_specs.RegularMode]):
|
||||
def make_top_layer(self, context: Context) -> Layer:
|
||||
return layers.modes.HttpProxy(context)
|
||||
|
@ -30,6 +30,8 @@ from typing import ClassVar
|
||||
from typing import Literal
|
||||
from typing import TypeVar
|
||||
|
||||
import mitmproxy_rs
|
||||
|
||||
from mitmproxy.coretypes.serializable import Serializable
|
||||
from mitmproxy.net import server_spec
|
||||
|
||||
@ -85,7 +87,7 @@ class ProxyMode(Serializable, metaclass=ABCMeta):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def transport_protocol(self) -> Literal["tcp", "udp"]:
|
||||
def transport_protocol(self) -> Literal["tcp", "udp"] | None:
|
||||
"""The transport protocol used by this mode's server."""
|
||||
|
||||
@classmethod
|
||||
@ -189,7 +191,7 @@ class RegularMode(ProxyMode):
|
||||
class TransparentMode(ProxyMode):
|
||||
"""A transparent proxy, see https://docs.mitmproxy.org/dev/howto-transparent/"""
|
||||
|
||||
description = "transparent proxy"
|
||||
description = "Transparent Proxy"
|
||||
transport_protocol = TCP
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
@ -280,3 +282,14 @@ class WireGuardMode(ProxyMode):
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class OsProxyMode(ProxyMode):
|
||||
"""OS-level transparent proxy."""
|
||||
|
||||
description = "OS proxy"
|
||||
transport_protocol = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# should not raise
|
||||
mitmproxy_rs.OsProxy.describe_spec(self.data)
|
||||
|
@ -111,7 +111,7 @@ async def test_start_stop(caplog_async):
|
||||
|
||||
await ps.setup_servers() # assert this can always be called without side effects
|
||||
tctx.configure(ps, server=False)
|
||||
await caplog_async.await_log("Stopped HTTP(S) proxy at")
|
||||
await caplog_async.await_log("stopped")
|
||||
if ps.servers.is_updating:
|
||||
async with ps.servers._lock:
|
||||
pass # wait until start/stop is finished.
|
||||
@ -318,7 +318,7 @@ async def test_dns(caplog_async) -> None:
|
||||
w.write(b"\x00")
|
||||
await caplog_async.await_log("sent an invalid message")
|
||||
tctx.configure(ps, server=False)
|
||||
await caplog_async.await_log("Stopped DNS server at")
|
||||
await caplog_async.await_log("stopped")
|
||||
|
||||
|
||||
def test_validation_no_transparent(monkeypatch):
|
||||
@ -384,7 +384,7 @@ async def test_dtls(monkeypatch, caplog_async) -> None:
|
||||
assert repr(ps) == "Proxyserver(1 active conns)"
|
||||
assert len(ps.connections) == 1
|
||||
tctx.configure(ps, server=False)
|
||||
await caplog_async.await_log("Stopped reverse proxy to dtls")
|
||||
await caplog_async.await_log("stopped")
|
||||
|
||||
|
||||
class H3EchoServer(QuicConnectionProtocol):
|
||||
@ -793,7 +793,7 @@ async def test_reverse_http3_and_quic_stream(
|
||||
assert len(ps.connections) == 1
|
||||
|
||||
tctx.configure(ps, server=False)
|
||||
await caplog_async.await_log(f"Stopped reverse proxy to {scheme}")
|
||||
await caplog_async.await_log(f"stopped")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("connection_strategy", ["lazy", "eager"])
|
||||
@ -829,7 +829,7 @@ async def test_reverse_quic_datagram(caplog_async, connection_strategy: str) ->
|
||||
assert await client.recv_datagram() == b"echo"
|
||||
|
||||
tctx.configure(ps, server=False)
|
||||
await caplog_async.await_log("Stopped reverse proxy to quic")
|
||||
await caplog_async.await_log("stopped")
|
||||
|
||||
|
||||
async def test_regular_http3(caplog_async, monkeypatch) -> None:
|
||||
@ -869,4 +869,4 @@ async def test_regular_http3(caplog_async, monkeypatch) -> None:
|
||||
assert len(ps.connections) == 1
|
||||
|
||||
tctx.configure(ps, server=False)
|
||||
await caplog_async.await_log("Stopped HTTP3 proxy")
|
||||
await caplog_async.await_log("stopped")
|
||||
|
@ -86,6 +86,9 @@ def start_h2_client(tctx: Context, keepalive: int = 0) -> tuple[Playbook, FrameF
|
||||
|
||||
|
||||
def make_h2(open_connection: OpenConnection) -> None:
|
||||
assert isinstance(
|
||||
open_connection, OpenConnection
|
||||
), f"Expected OpenConnection event, not {open_connection}"
|
||||
open_connection.connection.alpn = b"h2"
|
||||
|
||||
|
||||
|
@ -273,7 +273,6 @@ def h2_frames(draw):
|
||||
|
||||
def h2_layer(opts):
|
||||
tctx = _tctx()
|
||||
tctx.options.http2_ping_keepalive = 0
|
||||
tctx.client.alpn = b"h2"
|
||||
|
||||
layer = http.HttpLayer(tctx, HTTPMode.regular)
|
||||
@ -317,7 +316,7 @@ def test_fuzz_h2_request_mutations(chunks):
|
||||
|
||||
|
||||
def _tctx() -> context.Context:
|
||||
return context.Context(
|
||||
tctx = context.Context(
|
||||
connection.Client(
|
||||
peername=("client", 1234),
|
||||
sockname=("127.0.0.1", 8080),
|
||||
@ -325,6 +324,8 @@ def _tctx() -> context.Context:
|
||||
),
|
||||
opts,
|
||||
)
|
||||
tctx.options.http2_ping_keepalive = 0
|
||||
return tctx
|
||||
|
||||
|
||||
def _h2_response(chunks):
|
||||
|
@ -5,12 +5,14 @@ from unittest.mock import AsyncMock
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import Mock
|
||||
|
||||
import mitmproxy_rs
|
||||
import pytest
|
||||
|
||||
import mitmproxy.platform
|
||||
from mitmproxy.addons.proxyserver import Proxyserver
|
||||
from mitmproxy.net import udp
|
||||
from mitmproxy.proxy.mode_servers import DnsInstance
|
||||
from mitmproxy.proxy.mode_servers import OsProxyInstance
|
||||
from mitmproxy.proxy.mode_servers import ServerInstance
|
||||
from mitmproxy.proxy.mode_servers import WireGuardServerInstance
|
||||
from mitmproxy.proxy.server import ConnectionHandler
|
||||
@ -88,7 +90,7 @@ async def test_tcp_start_stop(caplog_async):
|
||||
assert await caplog_async.await_log("client disconnect")
|
||||
|
||||
await inst.stop()
|
||||
assert await caplog_async.await_log("Stopped HTTP(S) proxy")
|
||||
assert await caplog_async.await_log("stopped")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("failure", [True, False])
|
||||
@ -107,7 +109,7 @@ async def test_transparent(failure, monkeypatch, caplog_async):
|
||||
tctx.options.connection_strategy = "lazy"
|
||||
inst = ServerInstance.make("transparent@127.0.0.1:0", manager)
|
||||
await inst.start()
|
||||
await caplog_async.await_log("proxy listening")
|
||||
await caplog_async.await_log("listening")
|
||||
|
||||
host, port, *_ = inst.listen_addrs[0]
|
||||
reader, writer = await asyncio.open_connection(host, port)
|
||||
@ -123,7 +125,7 @@ async def test_transparent(failure, monkeypatch, caplog_async):
|
||||
assert await caplog_async.await_log("client disconnect")
|
||||
|
||||
await inst.stop()
|
||||
assert await caplog_async.await_log("Stopped transparent proxy")
|
||||
assert await caplog_async.await_log("stopped")
|
||||
|
||||
|
||||
async def test_wireguard(tdata, monkeypatch, caplog):
|
||||
@ -176,7 +178,7 @@ async def test_wireguard(tdata, monkeypatch, caplog):
|
||||
raise
|
||||
|
||||
await inst.stop()
|
||||
assert "Stopped WireGuard server" in caplog.text
|
||||
assert "stopped" in caplog.text
|
||||
|
||||
|
||||
async def test_wireguard_generate_conf(tmp_path):
|
||||
@ -205,7 +207,7 @@ async def test_wireguard_invalid_conf(tmp_path):
|
||||
# directory instead of filename
|
||||
inst = WireGuardServerInstance.make(f"wireguard:{tmp_path}", MagicMock())
|
||||
|
||||
with pytest.raises(OSError):
|
||||
with pytest.raises(ValueError, match="Invalid configuration file"):
|
||||
await inst.start()
|
||||
|
||||
assert "Invalid configuration file" in repr(inst.last_exception)
|
||||
@ -261,7 +263,7 @@ async def test_udp_start_stop(caplog_async):
|
||||
writer.close()
|
||||
|
||||
await inst.stop()
|
||||
assert await caplog_async.await_log("Stopped")
|
||||
assert await caplog_async.await_log("stopped")
|
||||
|
||||
|
||||
async def test_udp_start_error():
|
||||
@ -297,3 +299,79 @@ async def test_udp_connection_reuse(monkeypatch):
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert len(inst.manager.connections) == 1
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def patched_os_proxy(monkeypatch):
|
||||
start_os_proxy = AsyncMock()
|
||||
monkeypatch.setattr(mitmproxy_rs, "start_os_proxy", start_os_proxy)
|
||||
# make sure _server and _instance are restored after this test
|
||||
monkeypatch.setattr(OsProxyInstance, "_server", None)
|
||||
monkeypatch.setattr(OsProxyInstance, "_instance", None)
|
||||
return start_os_proxy
|
||||
|
||||
|
||||
async def test_os_proxy(patched_os_proxy, caplog_async):
|
||||
caplog_async.set_level("INFO")
|
||||
|
||||
with taddons.context():
|
||||
inst = ServerInstance.make(f"osproxy", MagicMock())
|
||||
assert not inst.is_running
|
||||
|
||||
await inst.start()
|
||||
assert patched_os_proxy.called
|
||||
assert await caplog_async.await_log("OS proxy started.")
|
||||
assert inst.is_running
|
||||
|
||||
await inst.stop()
|
||||
assert await caplog_async.await_log("OS proxy stopped")
|
||||
assert not inst.is_running
|
||||
|
||||
# just called for coverage
|
||||
inst.make_top_layer(MagicMock())
|
||||
|
||||
|
||||
async def test_os_proxy_startup_err(patched_os_proxy):
|
||||
patched_os_proxy.side_effect = RuntimeError("OS proxy startup error")
|
||||
|
||||
with taddons.context():
|
||||
inst = ServerInstance.make(f"osproxy:!curl", MagicMock())
|
||||
with pytest.raises(RuntimeError):
|
||||
await inst.start()
|
||||
assert not inst.is_running
|
||||
|
||||
|
||||
async def test_multiple_os_proxies(patched_os_proxy):
|
||||
manager = MagicMock()
|
||||
|
||||
with taddons.context():
|
||||
inst1 = ServerInstance.make(f"osproxy:curl", manager)
|
||||
await inst1.start()
|
||||
|
||||
inst2 = ServerInstance.make(f"osproxy:wget", manager)
|
||||
with pytest.raises(
|
||||
RuntimeError, match="Cannot spawn more than one OS proxy instance"
|
||||
):
|
||||
await inst2.start()
|
||||
|
||||
|
||||
async def test_always_uses_current_instance(patched_os_proxy, monkeypatch):
|
||||
manager = MagicMock()
|
||||
|
||||
with taddons.context():
|
||||
inst1 = ServerInstance.make(f"osproxy:curl", manager)
|
||||
await inst1.start()
|
||||
await inst1.stop()
|
||||
|
||||
handle_tcp, handle_udp = patched_os_proxy.await_args[0]
|
||||
|
||||
inst2 = ServerInstance.make(f"osproxy:wget", manager)
|
||||
await inst2.start()
|
||||
|
||||
monkeypatch.setattr(inst2, "handle_tcp_connection", h_tcp := AsyncMock())
|
||||
await handle_tcp(Mock())
|
||||
assert h_tcp.await_count
|
||||
|
||||
monkeypatch.setattr(inst2, "handle_udp_datagram", h_udp := Mock())
|
||||
handle_udp(Mock(), b"", ("", 0), ("", 0))
|
||||
assert h_udp.called
|
||||
|
@ -70,6 +70,8 @@ def test_parse_specific_modes():
|
||||
assert ProxyMode.parse("wireguard:foo.conf").data == "foo.conf"
|
||||
assert ProxyMode.parse("wireguard@51821").listen_port() == 51821
|
||||
|
||||
assert ProxyMode.parse("osproxy")
|
||||
|
||||
with pytest.raises(ValueError, match="invalid port"):
|
||||
ProxyMode.parse("regular@invalid-port")
|
||||
|
||||
@ -87,3 +89,6 @@ def test_parse_specific_modes():
|
||||
|
||||
with pytest.raises(ValueError, match="Port specification missing."):
|
||||
ProxyMode.parse("reverse:dtls://127.0.0.1")
|
||||
|
||||
with pytest.raises(ValueError, match="invalid intercept spec"):
|
||||
ProxyMode.parse("osproxy:,,,")
|
||||
|
@ -8,6 +8,7 @@ from contextlib import redirect_stdout
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from unittest import mock
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
import tornado.testing
|
||||
@ -176,8 +177,13 @@ class TestApp(tornado.testing.AsyncHTTPTestCase):
|
||||
m.events._add_log(log.LogEntry("test log", "info"))
|
||||
m.events.done()
|
||||
si1 = ServerInstance.make("regular", m.proxyserver)
|
||||
si1._listen_addrs = [("127.0.0.1", 8080), ("::1", 8080)]
|
||||
si1._server = True # spoof is_running
|
||||
sock1 = Mock()
|
||||
sock1.getsockname.return_value = ("127.0.0.1", 8080)
|
||||
sock2 = Mock()
|
||||
sock2.getsockname.return_value = ("::1", 8080)
|
||||
server = Mock()
|
||||
server.sockets = [sock1, sock2]
|
||||
si1._server = server
|
||||
si2 = ServerInstance.make("reverse:example.com", m.proxyserver)
|
||||
si2.last_exception = RuntimeError("I failed somehow.")
|
||||
si3 = ServerInstance.make("socks5", m.proxyserver)
|
||||
|
Loading…
Reference in New Issue
Block a user