mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-26 22:00:34 +00:00
Merge pull request #4650 from mhils/prinzhorn
[WIP] Fix WebSocket/TCP injection
This commit is contained in:
commit
5b4ac96f4c
@ -55,6 +55,10 @@ To document all event hooks, we do a bit of hackery:
|
||||
{% if doc.qualname.startswith("ServerConnectionHookData") and doc.name != "__init__" %}
|
||||
{{ default_is_public(doc) }}
|
||||
{% endif %}
|
||||
{% elif doc.modulename == "mitmproxy.websocket" %}
|
||||
{% if doc.qualname != "WebSocketMessage.type" %}
|
||||
{{ default_is_public(doc) }}
|
||||
{% endif %}
|
||||
{% else %}
|
||||
{{ default_is_public(doc) }}
|
||||
{% endif %}
|
||||
|
@ -8,7 +8,7 @@ from mitmproxy.proxy import commands, events, server_hooks
|
||||
from mitmproxy.proxy import server
|
||||
from mitmproxy.proxy.layers.tcp import TcpMessageInjected
|
||||
from mitmproxy.proxy.layers.websocket import WebSocketMessageInjected
|
||||
from mitmproxy.utils import asyncio_utils, human, strutils
|
||||
from mitmproxy.utils import asyncio_utils, human
|
||||
from wsproto.frame_protocol import Opcode
|
||||
|
||||
|
||||
@ -190,15 +190,14 @@ class Proxyserver:
|
||||
self._connections[event.flow.client_conn.peername].server_event(event)
|
||||
|
||||
@command.command("inject.websocket")
|
||||
def inject_websocket(self, flow: Flow, to_client: bool, message: str, is_text: bool = True):
|
||||
def inject_websocket(self, flow: Flow, to_client: bool, message: bytes, is_text: bool = True):
|
||||
if not isinstance(flow, http.HTTPFlow) or not flow.websocket:
|
||||
ctx.log.warn("Cannot inject WebSocket messages into non-WebSocket flows.")
|
||||
|
||||
message_bytes = strutils.escaped_str_to_bytes(message)
|
||||
msg = websocket.WebSocketMessage(
|
||||
Opcode.TEXT if is_text else Opcode.BINARY,
|
||||
not to_client,
|
||||
message_bytes
|
||||
message
|
||||
)
|
||||
event = WebSocketMessageInjected(flow, msg)
|
||||
try:
|
||||
@ -207,12 +206,11 @@ class Proxyserver:
|
||||
ctx.log.warn(str(e))
|
||||
|
||||
@command.command("inject.tcp")
|
||||
def inject_tcp(self, flow: Flow, to_client: bool, message: str):
|
||||
def inject_tcp(self, flow: Flow, to_client: bool, message: bytes):
|
||||
if not isinstance(flow, tcp.TCPFlow):
|
||||
ctx.log.warn("Cannot inject TCP messages into non-TCP flows.")
|
||||
|
||||
message_bytes = strutils.escaped_str_to_bytes(message)
|
||||
event = TcpMessageInjected(flow, tcp.TCPMessage(not to_client, message_bytes))
|
||||
event = TcpMessageInjected(flow, tcp.TCPMessage(not to_client, message))
|
||||
try:
|
||||
self.inject_event(event)
|
||||
except ValueError as e:
|
||||
|
@ -73,7 +73,7 @@ class Command:
|
||||
for name, parameter in self.signature.parameters.items():
|
||||
t = parameter.annotation
|
||||
if not mitmproxy.types.CommandTypes.get(parameter.annotation, None):
|
||||
raise exceptions.CommandError(f"Argument {name} has an unknown type ({_empty_as_none(t)}) in {func}.")
|
||||
raise exceptions.CommandError(f"Argument {name} has an unknown type {t} in {func}.")
|
||||
if self.return_type and not mitmproxy.types.CommandTypes.get(self.return_type, None):
|
||||
raise exceptions.CommandError(f"Return type has an unknown type ({self.return_type}) in {func}.")
|
||||
|
||||
@ -106,8 +106,15 @@ class Command:
|
||||
raise exceptions.CommandError(f"Command argument mismatch: \n {expected}\n {received}")
|
||||
|
||||
for name, value in bound_arguments.arguments.items():
|
||||
convert_to = self.signature.parameters[name].annotation
|
||||
bound_arguments.arguments[name] = parsearg(self.manager, value, convert_to)
|
||||
param = self.signature.parameters[name]
|
||||
convert_to = param.annotation
|
||||
if param.kind == param.VAR_POSITIONAL:
|
||||
bound_arguments.arguments[name] = tuple(
|
||||
parsearg(self.manager, x, convert_to)
|
||||
for x in value
|
||||
)
|
||||
else:
|
||||
bound_arguments.arguments[name] = parsearg(self.manager, value, convert_to)
|
||||
|
||||
bound_arguments.apply_defaults()
|
||||
|
||||
|
@ -1,4 +1,3 @@
|
||||
import ast
|
||||
import re
|
||||
|
||||
import pyparsing
|
||||
@ -10,13 +9,9 @@ import pyparsing
|
||||
PartialQuotedString = pyparsing.Regex(
|
||||
re.compile(
|
||||
r'''
|
||||
(["']) # start quote
|
||||
(?:
|
||||
(?:\\.) # escape sequence
|
||||
|
|
||||
(?!\1). # unescaped character that is not our quote nor the begin of an escape sequence. We can't use \1 in []
|
||||
)*
|
||||
(?:\1|$) # end quote
|
||||
"[^"]*(?:"|$) # double-quoted string that ends with double quote or EOF
|
||||
|
|
||||
'[^']*(?:'|$) # single-quoted string that ends with double quote or EOF
|
||||
''',
|
||||
re.VERBOSE
|
||||
)
|
||||
@ -32,18 +27,15 @@ expr = pyparsing.ZeroOrMore(
|
||||
def quote(val: str) -> str:
|
||||
if val and all(char not in val for char in "'\" \r\n\t"):
|
||||
return val
|
||||
return repr(val) # TODO: More of a hack.
|
||||
if '"' not in val:
|
||||
return f'"{val}"'
|
||||
if "'" not in val:
|
||||
return f"'{val}'"
|
||||
return '"' + val.replace('"', r"\x22") + '"'
|
||||
|
||||
|
||||
def unquote(x: str) -> str:
|
||||
quoted = (
|
||||
(x.startswith('"') and x.endswith('"'))
|
||||
or
|
||||
(x.startswith("'") and x.endswith("'"))
|
||||
)
|
||||
if quoted:
|
||||
try:
|
||||
x = ast.literal_eval(x)
|
||||
except Exception:
|
||||
x = x[1:-1]
|
||||
return x
|
||||
if len(x) > 1 and x[0] in "'\"" and x[0] == x[-1]:
|
||||
return x[1:-1]
|
||||
else:
|
||||
return x
|
||||
|
@ -34,8 +34,8 @@ class Display(base.Cell):
|
||||
|
||||
class Edit(base.Cell):
|
||||
def __init__(self, data: bytes) -> None:
|
||||
data = strutils.bytes_to_escaped_str(data)
|
||||
w = urwid.Edit(edit_text=data, wrap="any", multiline=True)
|
||||
d = strutils.bytes_to_escaped_str(data)
|
||||
w = urwid.Edit(edit_text=d, wrap="any", multiline=True)
|
||||
w = urwid.AttrWrap(w, "editfield")
|
||||
super().__init__(w)
|
||||
|
||||
|
@ -1,10 +1,12 @@
|
||||
import codecs
|
||||
import os
|
||||
import glob
|
||||
import re
|
||||
import typing
|
||||
|
||||
from mitmproxy import exceptions
|
||||
from mitmproxy import flow
|
||||
from mitmproxy.utils import emoji
|
||||
from mitmproxy.utils import emoji, strutils
|
||||
|
||||
if typing.TYPE_CHECKING: # pragma: no cover
|
||||
from mitmproxy.command import CommandManager
|
||||
@ -104,16 +106,52 @@ class _StrType(_BaseType):
|
||||
typ = str
|
||||
display = "str"
|
||||
|
||||
# https://docs.python.org/3/reference/lexical_analysis.html#string-and-bytes-literals
|
||||
escape_sequences = re.compile(r"""
|
||||
\\ (
|
||||
[\\'"abfnrtv] # Standard C escape sequence
|
||||
| [0-7]{1,3} # Character with octal value
|
||||
| x.. # Character with hex value
|
||||
| N{[^}]+} # Character name in the Unicode database
|
||||
| u.... # Character with 16-bit hex value
|
||||
| U........ # Character with 32-bit hex value
|
||||
)
|
||||
""", re.VERBOSE)
|
||||
|
||||
@staticmethod
|
||||
def _unescape(match: re.Match) -> str:
|
||||
return codecs.decode(match.group(0), "unicode-escape") # type: ignore
|
||||
|
||||
def completion(self, manager: "CommandManager", t: type, s: str) -> typing.Sequence[str]:
|
||||
return []
|
||||
|
||||
def parse(self, manager: "CommandManager", t: type, s: str) -> str:
|
||||
return s
|
||||
try:
|
||||
return self.escape_sequences.sub(self._unescape, s)
|
||||
except ValueError as e:
|
||||
raise exceptions.TypeError(f"Invalid str: {e}") from e
|
||||
|
||||
def is_valid(self, manager: "CommandManager", typ: typing.Any, val: typing.Any) -> bool:
|
||||
return isinstance(val, str)
|
||||
|
||||
|
||||
class _BytesType(_BaseType):
|
||||
typ = bytes
|
||||
display = "bytes"
|
||||
|
||||
def completion(self, manager: "CommandManager", t: type, s: str) -> typing.Sequence[str]:
|
||||
return []
|
||||
|
||||
def parse(self, manager: "CommandManager", t: type, s: str) -> bytes:
|
||||
try:
|
||||
return strutils.escaped_str_to_bytes(s)
|
||||
except ValueError as e:
|
||||
raise exceptions.TypeError(str(e))
|
||||
|
||||
def is_valid(self, manager: "CommandManager", typ: typing.Any, val: typing.Any) -> bool:
|
||||
return isinstance(val, bytes)
|
||||
|
||||
|
||||
class _UnknownType(_BaseType):
|
||||
typ = Unknown
|
||||
display = "unknown"
|
||||
@ -460,4 +498,5 @@ CommandTypes = TypeManager(
|
||||
_PathType,
|
||||
_StrType,
|
||||
_StrSeqType,
|
||||
_BytesType,
|
||||
)
|
||||
|
@ -79,7 +79,7 @@ def escape_control_characters(text: str, keep_spacing=True) -> str:
|
||||
return text.translate(trans)
|
||||
|
||||
|
||||
def bytes_to_escaped_str(data, keep_spacing=False, escape_single_quotes=False):
|
||||
def bytes_to_escaped_str(data: bytes, keep_spacing: bool = False, escape_single_quotes: bool = False) -> str:
|
||||
"""
|
||||
Take bytes and return a safe string that can be displayed to the user.
|
||||
|
||||
@ -107,7 +107,7 @@ def bytes_to_escaped_str(data, keep_spacing=False, escape_single_quotes=False):
|
||||
return ret
|
||||
|
||||
|
||||
def escaped_str_to_bytes(data):
|
||||
def escaped_str_to_bytes(data: str) -> bytes:
|
||||
"""
|
||||
Take an escaped string and return the unescaped bytes equivalent.
|
||||
|
||||
@ -119,7 +119,7 @@ def escaped_str_to_bytes(data):
|
||||
|
||||
# This one is difficult - we use an undocumented Python API here
|
||||
# as per http://stackoverflow.com/a/23151714/934719
|
||||
return codecs.escape_decode(data)[0]
|
||||
return codecs.escape_decode(data)[0] # type: ignore
|
||||
|
||||
|
||||
def is_mostly_bin(s: bytes) -> bool:
|
||||
|
@ -20,18 +20,16 @@ class WebSocketMessage(serializable.Serializable):
|
||||
"""
|
||||
A single WebSocket message sent from one peer to the other.
|
||||
|
||||
Fragmented WebSocket messages are reassembled by mitmproxy and the
|
||||
Fragmented WebSocket messages are reassembled by mitmproxy and then
|
||||
represented as a single instance of this class.
|
||||
|
||||
The [WebSocket RFC](https://tools.ietf.org/html/rfc6455) specifies both
|
||||
text and binary messages. To avoid a whole class of nasty type confusion bugs,
|
||||
mitmproxy stores all message contents as binary. If you need text, you can decode the `content` property:
|
||||
mitmproxy stores all message contents as `bytes`. If you need a `str`, you can access the `text` property
|
||||
on text messages:
|
||||
|
||||
>>> from wsproto.frame_protocol import Opcode
|
||||
>>> if message.type == Opcode.TEXT:
|
||||
>>> text = message.content.decode()
|
||||
|
||||
Per the WebSocket spec, text messages always use UTF-8 encoding.
|
||||
>>> if message.is_text:
|
||||
>>> text = message.text
|
||||
"""
|
||||
|
||||
from_client: bool
|
||||
@ -40,8 +38,7 @@ class WebSocketMessage(serializable.Serializable):
|
||||
"""
|
||||
The message type, as per RFC 6455's [opcode](https://tools.ietf.org/html/rfc6455#section-5.2).
|
||||
|
||||
Note that mitmproxy will always store the message contents as *bytes*.
|
||||
A dedicated `.text` property for text messages is planned, see https://github.com/mitmproxy/mitmproxy/pull/4486.
|
||||
Mitmproxy currently only exposes messages assembled from `TEXT` and `BINARY` frames.
|
||||
"""
|
||||
content: bytes
|
||||
"""A byte-string representing the content of this message."""
|
||||
@ -81,10 +78,39 @@ class WebSocketMessage(serializable.Serializable):
|
||||
else:
|
||||
return repr(self.content)
|
||||
|
||||
@property
|
||||
def is_text(self) -> bool:
|
||||
"""
|
||||
`True` if this message is assembled from WebSocket `TEXT` frames,
|
||||
`False` if it is assembled from `BINARY` frames.
|
||||
"""
|
||||
return self.type == Opcode.TEXT
|
||||
|
||||
def kill(self):
|
||||
# Likely to be replaced with .drop() in the future, see https://github.com/mitmproxy/mitmproxy/pull/4486
|
||||
self.killed = True
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
"""
|
||||
The message content as text.
|
||||
|
||||
This attribute is only available if `WebSocketMessage.is_text` is `True`.
|
||||
|
||||
*See also:* `WebSocketMessage.content`
|
||||
"""
|
||||
if self.type != Opcode.TEXT:
|
||||
raise AttributeError(f"{self.type.name.title()} WebSocket frames do not have a 'text' attribute.")
|
||||
|
||||
return self.content.decode()
|
||||
|
||||
@text.setter
|
||||
def text(self, value: str) -> None:
|
||||
if self.type != Opcode.TEXT:
|
||||
raise AttributeError(f"{self.type.name.title()} WebSocket frames do not have a 'text' attribute.")
|
||||
|
||||
self.content = value.encode()
|
||||
|
||||
|
||||
class WebSocketData(stateobject.StateObject):
|
||||
"""
|
||||
@ -97,9 +123,9 @@ class WebSocketData(stateobject.StateObject):
|
||||
|
||||
closed_by_client: Optional[bool] = None
|
||||
"""
|
||||
True if the client closed the connection,
|
||||
False if the server closed the connection,
|
||||
None if the connection is active.
|
||||
`True` if the client closed the connection,
|
||||
`False` if the server closed the connection,
|
||||
`None` if the connection is active.
|
||||
"""
|
||||
close_code: Optional[int] = None
|
||||
"""[Close Code](https://tools.ietf.org/html/rfc6455#section-7.1.5)"""
|
||||
|
@ -90,7 +90,7 @@ async def test_start_stop():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inject():
|
||||
async def test_inject() -> None:
|
||||
async def server_handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
|
||||
while s := await reader.read(1):
|
||||
writer.write(s.upper())
|
||||
@ -112,39 +112,39 @@ async def test_inject():
|
||||
|
||||
writer.write(b"a")
|
||||
assert await reader.read(1) == b"A"
|
||||
ps.inject_tcp(state.flows[0], False, "b")
|
||||
ps.inject_tcp(state.flows[0], False, b"b")
|
||||
assert await reader.read(1) == b"B"
|
||||
ps.inject_tcp(state.flows[0], True, "c")
|
||||
ps.inject_tcp(state.flows[0], True, b"c")
|
||||
assert await reader.read(1) == b"c"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inject_fail():
|
||||
async def test_inject_fail() -> None:
|
||||
ps = Proxyserver()
|
||||
with taddons.context(ps) as tctx:
|
||||
ps.inject_websocket(
|
||||
tflow.tflow(),
|
||||
True,
|
||||
"test"
|
||||
b"test"
|
||||
)
|
||||
await tctx.master.await_log("Cannot inject WebSocket messages into non-WebSocket flows.", level="warn")
|
||||
ps.inject_tcp(
|
||||
tflow.tflow(),
|
||||
True,
|
||||
"test"
|
||||
b"test"
|
||||
)
|
||||
await tctx.master.await_log("Cannot inject TCP messages into non-TCP flows.", level="warn")
|
||||
|
||||
ps.inject_websocket(
|
||||
tflow.twebsocketflow(),
|
||||
True,
|
||||
"test"
|
||||
b"test"
|
||||
)
|
||||
await tctx.master.await_log("Flow is not from a live connection.", level="warn")
|
||||
ps.inject_websocket(
|
||||
tflow.ttcpflow(),
|
||||
True,
|
||||
"test"
|
||||
b"test"
|
||||
)
|
||||
await tctx.master.await_log("Flow is not from a live connection.", level="warn")
|
||||
|
||||
|
@ -367,24 +367,6 @@ class TestCommand:
|
||||
],
|
||||
[],
|
||||
],
|
||||
[
|
||||
r'cmd13 "a \"b\" c"',
|
||||
[
|
||||
command.ParseResult(value="cmd13", type=mitmproxy.types.Cmd, valid=False),
|
||||
command.ParseResult(value=" ", type=mitmproxy.types.Space, valid=True),
|
||||
command.ParseResult(value=r'"a \"b\" c"', type=mitmproxy.types.Unknown, valid=False),
|
||||
],
|
||||
[],
|
||||
],
|
||||
[
|
||||
r"cmd14 'a \'b\' c'",
|
||||
[
|
||||
command.ParseResult(value="cmd14", type=mitmproxy.types.Cmd, valid=False),
|
||||
command.ParseResult(value=" ", type=mitmproxy.types.Space, valid=True),
|
||||
command.ParseResult(value=r"'a \'b\' c'", type=mitmproxy.types.Unknown, valid=False),
|
||||
],
|
||||
[],
|
||||
],
|
||||
[
|
||||
" spaces_at_the_begining_are_not_stripped",
|
||||
[
|
||||
@ -436,12 +418,6 @@ def test_simple():
|
||||
c.call("nonexistent")
|
||||
with pytest.raises(exceptions.CommandError, match="Unknown"):
|
||||
c.execute("\\")
|
||||
with pytest.raises(exceptions.CommandError, match="Unknown"):
|
||||
c.execute(r"\'")
|
||||
with pytest.raises(exceptions.CommandError, match="Unknown"):
|
||||
c.execute(r"\"")
|
||||
with pytest.raises(exceptions.CommandError, match="Unknown"):
|
||||
c.execute(r"\"")
|
||||
|
||||
c.add("empty", a.empty)
|
||||
c.execute("empty")
|
||||
|
@ -11,7 +11,6 @@ from mitmproxy import command_lexer
|
||||
("'foo'", True),
|
||||
('"foo"', True),
|
||||
("'foo' bar'", False),
|
||||
("'foo\\' bar'", True),
|
||||
("'foo' 'bar'", False),
|
||||
("'foo'x", False),
|
||||
('''"foo ''', True),
|
||||
@ -43,8 +42,19 @@ def test_expr(test_input, expected):
|
||||
|
||||
|
||||
@given(text())
|
||||
@example(r"foo")
|
||||
@example(r"'foo\''")
|
||||
@example(r"'foo\"'")
|
||||
@example(r'"foo\""')
|
||||
@example(r'"foo\'"')
|
||||
@example("'foo\\'")
|
||||
@example("'foo\\\\'")
|
||||
@example("\"foo\\'\"")
|
||||
@example("\"foo\\\\'\"")
|
||||
@example('\'foo\\"\'')
|
||||
@example(r"\\\foo")
|
||||
def test_quote_unquote_cycle(s):
|
||||
assert command_lexer.unquote(command_lexer.quote(s)) == s
|
||||
assert command_lexer.unquote(command_lexer.quote(s)).replace(r"\x22", '"') == s
|
||||
|
||||
|
||||
@given(text())
|
||||
|
@ -40,6 +40,21 @@ def test_str():
|
||||
assert b.is_valid(tctx.master.commands, str, 1) is False
|
||||
assert b.completion(tctx.master.commands, str, "") == []
|
||||
assert b.parse(tctx.master.commands, str, "foo") == "foo"
|
||||
assert b.parse(tctx.master.commands, str, r"foo\nbar") == "foo\nbar"
|
||||
assert b.parse(tctx.master.commands, str, r"\N{BELL}") == "🔔"
|
||||
with pytest.raises(mitmproxy.exceptions.TypeError):
|
||||
b.parse(tctx.master.commands, bool, r"\N{UNKNOWN UNICODE SYMBOL!}")
|
||||
|
||||
|
||||
def test_bytes():
|
||||
with taddons.context() as tctx:
|
||||
b = mitmproxy.types._BytesType()
|
||||
assert b.is_valid(tctx.master.commands, bytes, b"foo") is True
|
||||
assert b.is_valid(tctx.master.commands, bytes, 1) is False
|
||||
assert b.completion(tctx.master.commands, bytes, "") == []
|
||||
assert b.parse(tctx.master.commands, bytes, "foo") == b"foo"
|
||||
with pytest.raises(mitmproxy.exceptions.TypeError):
|
||||
b.parse(tctx.master.commands, bytes, "incomplete escape sequence\\")
|
||||
|
||||
|
||||
def test_unknown():
|
||||
|
@ -1,3 +1,5 @@
|
||||
import pytest
|
||||
|
||||
from mitmproxy import http
|
||||
from mitmproxy import websocket
|
||||
from mitmproxy.test import tflow
|
||||
@ -26,3 +28,18 @@ class TestWebSocketMessage:
|
||||
assert not m.killed
|
||||
m.kill()
|
||||
assert m.killed
|
||||
|
||||
def test_text(self):
|
||||
txt = websocket.WebSocketMessage(Opcode.TEXT, True, b"foo")
|
||||
bin = websocket.WebSocketMessage(Opcode.BINARY, True, b"foo")
|
||||
|
||||
assert txt.is_text
|
||||
assert txt.text == "foo"
|
||||
txt.text = "bar"
|
||||
assert txt.content == b"bar"
|
||||
|
||||
assert not bin.is_text
|
||||
with pytest.raises(AttributeError, match="do not have a 'text' attribute."):
|
||||
_ = bin.text
|
||||
with pytest.raises(AttributeError, match="do not have a 'text' attribute."):
|
||||
bin.text = "bar"
|
||||
|
@ -20,18 +20,18 @@ async def test_commands_exist():
|
||||
await m.load_flow(tflow())
|
||||
|
||||
for binding in km.bindings:
|
||||
parsed, _ = command_manager.parse_partial(binding.command.strip())
|
||||
|
||||
cmd = parsed[0].value
|
||||
args = [
|
||||
a.value for a in parsed[1:]
|
||||
if a.type != mitmproxy.types.Space
|
||||
]
|
||||
|
||||
assert cmd in m.commands.commands
|
||||
|
||||
cmd_obj = m.commands.commands[cmd]
|
||||
try:
|
||||
parsed, _ = command_manager.parse_partial(binding.command.strip())
|
||||
|
||||
cmd = parsed[0].value
|
||||
args = [
|
||||
a.value for a in parsed[1:]
|
||||
if a.type != mitmproxy.types.Space
|
||||
]
|
||||
|
||||
assert cmd in m.commands.commands
|
||||
|
||||
cmd_obj = m.commands.commands[cmd]
|
||||
cmd_obj.prepare_args(args)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid command: {binding.command}") from e
|
||||
raise ValueError(f"Invalid binding: {binding.command}") from e
|
||||
|
Loading…
Reference in New Issue
Block a user