mirror of
https://github.com/mitmproxy/mitmproxy.git
synced 2024-11-23 13:19:48 +00:00
optimize tnetstring parsing (#7121)
* Use memoryview to represent tnetstring * Allow :data: in pop to be bytes | memory view to accomodate test * Update CHANGELOG.md * [autofix.ci] apply automated fixes * Use str() instead of decode() to avoid one copy Co-authored-by: Maximilian Hils <github@maximilianhils.com> * Keep diff minimal Co-authored-by: Maximilian Hils <github@maximilianhils.com> * Make pop only accept argument of type memory view * cache `ord()` --------- Co-authored-by: Michele Russo <michele.russo@huawei.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Maximilian Hils <github@maximilianhils.com>
This commit is contained in:
parent
499e8e8742
commit
332f222994
@ -7,6 +7,8 @@
|
||||
|
||||
## Unreleased: mitmproxy next
|
||||
|
||||
- Fix endless tnetstring parsing in case of very large tnetstring
|
||||
([#7121](https://github.com/mitmproxy/mitmproxy/pull/7121), @mik1904)
|
||||
- Tighten HTTP detection heuristic to better support custom TCP-based protocols.
|
||||
([#7087](https://github.com/mitmproxy/mitmproxy/pull/7087))
|
||||
- Improve the error message when users specify the `certs` option without a matching private key.
|
||||
|
@ -154,7 +154,7 @@ def loads(string: bytes) -> TSerializable:
|
||||
"""
|
||||
This function parses a tnetstring into a python object.
|
||||
"""
|
||||
return pop(string)[0]
|
||||
return pop(memoryview(string))[0]
|
||||
|
||||
|
||||
def load(file_handle: BinaryIO) -> TSerializable:
|
||||
@ -178,17 +178,17 @@ def load(file_handle: BinaryIO) -> TSerializable:
|
||||
if c != b":":
|
||||
raise ValueError("not a tnetstring: missing or invalid length prefix")
|
||||
|
||||
data = file_handle.read(int(data_length))
|
||||
data = memoryview(file_handle.read(int(data_length)))
|
||||
data_type = file_handle.read(1)[0]
|
||||
|
||||
return parse(data_type, data)
|
||||
|
||||
|
||||
def parse(data_type: int, data: bytes) -> TSerializable:
|
||||
def parse(data_type: int, data: memoryview) -> TSerializable:
|
||||
if data_type == ord(b","):
|
||||
return data
|
||||
return data.tobytes()
|
||||
if data_type == ord(b";"):
|
||||
return data.decode("utf8")
|
||||
return str(data, "utf8")
|
||||
if data_type == ord(b"#"):
|
||||
try:
|
||||
return int(data)
|
||||
@ -226,20 +226,28 @@ def parse(data_type: int, data: bytes) -> TSerializable:
|
||||
raise ValueError(f"unknown type tag: {data_type}")
|
||||
|
||||
|
||||
def pop(data: bytes) -> tuple[TSerializable, bytes]:
|
||||
def split(data: memoryview, sep: bytes) -> tuple[int, memoryview]:
|
||||
i = 0
|
||||
try:
|
||||
ord_sep = ord(sep)
|
||||
while data[i] != ord_sep:
|
||||
i += 1
|
||||
# here i is the position of b":" in the memoryview
|
||||
return int(data[:i]), data[i + 1 :]
|
||||
except (IndexError, ValueError):
|
||||
raise ValueError(
|
||||
f"not a tnetstring: missing or invalid length prefix: {data.tobytes()!r}"
|
||||
)
|
||||
|
||||
|
||||
def pop(data: memoryview) -> tuple[TSerializable, memoryview]:
|
||||
"""
|
||||
This function parses a tnetstring into a python object.
|
||||
It returns a tuple giving the parsed object and a string
|
||||
containing any unparsed data from the end of the string.
|
||||
"""
|
||||
# Parse out data length, type and remaining string.
|
||||
try:
|
||||
blength, data = data.split(b":", 1)
|
||||
length = int(blength)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"not a tnetstring: missing or invalid length prefix: {data!r}"
|
||||
)
|
||||
# Parse out data length, type and remaining string.
|
||||
length, data = split(data, b":")
|
||||
try:
|
||||
data, data_type, remain = data[:length], data[length], data[length + 1 :]
|
||||
except IndexError:
|
||||
|
@ -72,19 +72,19 @@ class Test_Format(unittest.TestCase):
|
||||
for data, expect in FORMAT_EXAMPLES.items():
|
||||
self.assertEqual(expect, tnetstring.loads(data))
|
||||
self.assertEqual(expect, tnetstring.loads(tnetstring.dumps(expect)))
|
||||
self.assertEqual((expect, b""), tnetstring.pop(data))
|
||||
self.assertEqual((expect, b""), tnetstring.pop(memoryview(data)))
|
||||
|
||||
def test_roundtrip_format_random(self):
|
||||
for _ in range(10):
|
||||
v = get_random_object()
|
||||
self.assertEqual(v, tnetstring.loads(tnetstring.dumps(v)))
|
||||
self.assertEqual((v, b""), tnetstring.pop(tnetstring.dumps(v)))
|
||||
self.assertEqual((v, b""), tnetstring.pop(memoryview(tnetstring.dumps(v))))
|
||||
|
||||
def test_roundtrip_format_unicode(self):
|
||||
for _ in range(10):
|
||||
v = get_random_object()
|
||||
self.assertEqual(v, tnetstring.loads(tnetstring.dumps(v)))
|
||||
self.assertEqual((v, b""), tnetstring.pop(tnetstring.dumps(v)))
|
||||
self.assertEqual((v, b""), tnetstring.pop(memoryview(tnetstring.dumps(v))))
|
||||
|
||||
def test_roundtrip_big_integer(self):
|
||||
# Recent Python versions do not like ints above 4300 digits, https://github.com/python/cpython/issues/95778
|
||||
|
Loading…
Reference in New Issue
Block a user