add server_connect_error hook (#6806)

* Add server connection error hook

* Add new hook to api-events doc

* Rename and add test

* Forgot to commit

* Small fix

* [autofix.ci] apply automated fixes

* fixed test_server.py

* fixed 'Import block is un-sorted or un-formatted'

* [autofix.ci] apply automated fixes

* test++, doc++

* add CHANGELOG entry

* [autofix.ci] apply automated fixes

* fix authors

* fix test

* [autofix.ci] apply automated fixes

---------

Co-authored-by: haanhvu <haanh6594@gmail.com>
Co-authored-by: spacewasp <spacewasp1982@gmail.com>
This commit is contained in:
Maximilian Hils 2024-04-17 12:47:06 +02:00 committed by GitHub
parent b858a826bf
commit f997cd3a21
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 87 additions and 2 deletions

View File

@ -9,6 +9,8 @@
* Add support for editing non text files in a hex editor * Add support for editing non text files in a hex editor
([#6768](https://github.com/mitmproxy/mitmproxy/pull/6768), @wnyyyy) ([#6768](https://github.com/mitmproxy/mitmproxy/pull/6768), @wnyyyy)
* Add `server_connect_error` hook that is triggered when connection establishment fails.
([#6806](https://github.com/mitmproxy/mitmproxy/pull/6806), @haanhvu, @spacewasp, @mhils)
* Add section in mitmweb for rendering, adding and removing a comment * Add section in mitmweb for rendering, adding and removing a comment
([#6709](https://github.com/mitmproxy/mitmproxy/pull/6709), @lups2000) ([#6709](https://github.com/mitmproxy/mitmproxy/pull/6709), @lups2000)
* Fix multipart form content view being unusable. * Fix multipart form content view being unusable.

View File

@ -97,6 +97,7 @@ with outfile.open("w") as f, contextlib.redirect_stdout(f):
server_hooks.ServerConnectHook, server_hooks.ServerConnectHook,
server_hooks.ServerConnectedHook, server_hooks.ServerConnectedHook,
server_hooks.ServerDisconnectedHook, server_hooks.ServerDisconnectedHook,
server_hooks.ServerConnectErrorHook,
], ],
) )

View File

@ -196,6 +196,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
self.log( self.log(
f"server connection to {human.format_address(command.connection.address)} killed before connect: {err}" f"server connection to {human.format_address(command.connection.address)} killed before connect: {err}"
) )
await self.handle_hook(server_hooks.ServerConnectErrorHook(hook_data))
self.server_event( self.server_event(
events.OpenConnectionCompleted(command, f"Connection killed: {err}") events.OpenConnectionCompleted(command, f"Connection killed: {err}")
) )
@ -224,6 +225,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
err = "connection cancelled" err = "connection cancelled"
self.log(f"error establishing server connection: {err}") self.log(f"error establishing server connection: {err}")
command.connection.error = err command.connection.error = err
await self.handle_hook(server_hooks.ServerConnectErrorHook(hook_data))
self.server_event(events.OpenConnectionCompleted(command, err)) self.server_event(events.OpenConnectionCompleted(command, err))
if isinstance(e, asyncio.CancelledError): if isinstance(e, asyncio.CancelledError):
# From https://docs.python.org/3/library/asyncio-exceptions.html#asyncio.CancelledError: # From https://docs.python.org/3/library/asyncio-exceptions.html#asyncio.CancelledError:
@ -237,8 +239,11 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
command.connection.state = ConnectionState.OPEN command.connection.state = ConnectionState.OPEN
command.connection.peername = writer.get_extra_info("peername") command.connection.peername = writer.get_extra_info("peername")
command.connection.sockname = writer.get_extra_info("sockname") command.connection.sockname = writer.get_extra_info("sockname")
self.transports[command.connection].reader = reader self.transports[command.connection] = ConnectionIO(
self.transports[command.connection].writer = writer handler=asyncio.current_task(),
reader=reader,
writer=writer,
)
assert command.connection.peername assert command.connection.peername
if command.connection.address[0] != command.connection.peername[0]: if command.connection.address[0] != command.connection.peername[0]:

View File

@ -63,3 +63,14 @@ class ServerDisconnectedHook(commands.StartHook):
""" """
data: ServerConnectionHookData data: ServerConnectionHookData
@dataclass
class ServerConnectErrorHook(commands.StartHook):
"""
Mitmproxy failed to connect to a server.
Every server connection will receive either a server_connected or a server_connect_error event, but not both.
"""
data: ServerConnectionHookData

View File

@ -0,0 +1,66 @@
import asyncio
import collections
from unittest import mock
import pytest
from mitmproxy import options
from mitmproxy.connection import Server
from mitmproxy.proxy import commands
from mitmproxy.proxy import server
from mitmproxy.proxy import server_hooks
from mitmproxy.proxy.mode_specs import ProxyMode
class MockConnectionHandler(server.SimpleConnectionHandler):
hook_handlers: dict[str, mock.Mock]
def __init__(self):
super().__init__(
reader=mock.Mock(),
writer=mock.Mock(),
options=options.Options(),
mode=ProxyMode.parse("regular"),
hooks=collections.defaultdict(lambda: mock.Mock()),
)
@pytest.mark.parametrize("result", ("success", "killed", "failed"))
async def test_open_connection(result, monkeypatch):
handler = MockConnectionHandler()
server_connect = handler.hook_handlers["server_connect"]
server_connected = handler.hook_handlers["server_connected"]
server_connect_error = handler.hook_handlers["server_connect_error"]
server_disconnected = handler.hook_handlers["server_disconnected"]
match result:
case "success":
monkeypatch.setattr(
asyncio,
"open_connection",
mock.AsyncMock(return_value=(mock.MagicMock(), mock.MagicMock())),
)
monkeypatch.setattr(
MockConnectionHandler, "handle_connection", mock.AsyncMock()
)
case "failed":
monkeypatch.setattr(
asyncio, "open_connection", mock.AsyncMock(side_effect=OSError)
)
case "killed":
def _kill(d: server_hooks.ServerConnectionHookData) -> None:
d.server.error = "do not connect"
server_connect.side_effect = _kill
await handler.open_connection(
commands.OpenConnection(connection=Server(address=("server", 1234)))
)
assert server_connect.call_args[0][0].server.address == ("server", 1234)
assert server_connected.called == (result == "success")
assert server_connect_error.called == (result != "success")
assert server_disconnected.called == (result == "success")