optimize tnetstring parsing (#7121)

* Use memoryview to represent tnetstring

* Allow `data` in pop to be bytes | memoryview to accommodate test

* Update CHANGELOG.md

* [autofix.ci] apply automated fixes

* Use str() instead of decode() to avoid one copy

Co-authored-by: Maximilian Hils <github@maximilianhils.com>

* Keep diff minimal

Co-authored-by: Maximilian Hils <github@maximilianhils.com>

* Make pop only accept an argument of type memoryview

* cache `ord()`

---------

Co-authored-by: Michele Russo <michele.russo@huawei.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Maximilian Hils <github@maximilianhils.com>
This commit is contained in:
Michele Russo 2024-08-20 21:55:47 +02:00 committed by GitHub
parent 499e8e8742
commit 332f222994
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 27 additions and 17 deletions

View File

@ -7,6 +7,8 @@
## Unreleased: mitmproxy next
- Fix endless tnetstring parsing for very large tnetstrings
([#7121](https://github.com/mitmproxy/mitmproxy/pull/7121), @mik1904)
- Tighten HTTP detection heuristic to better support custom TCP-based protocols.
([#7087](https://github.com/mitmproxy/mitmproxy/pull/7087))
- Improve the error message when users specify the `certs` option without a matching private key.

View File

@ -154,7 +154,7 @@ def loads(string: bytes) -> TSerializable:
"""
This function parses a tnetstring into a python object.
"""
return pop(string)[0]
return pop(memoryview(string))[0]
def load(file_handle: BinaryIO) -> TSerializable:
@ -178,17 +178,17 @@ def load(file_handle: BinaryIO) -> TSerializable:
if c != b":":
raise ValueError("not a tnetstring: missing or invalid length prefix")
data = file_handle.read(int(data_length))
data = memoryview(file_handle.read(int(data_length)))
data_type = file_handle.read(1)[0]
return parse(data_type, data)
def parse(data_type: int, data: bytes) -> TSerializable:
def parse(data_type: int, data: memoryview) -> TSerializable:
if data_type == ord(b","):
return data
return data.tobytes()
if data_type == ord(b";"):
return data.decode("utf8")
return str(data, "utf8")
if data_type == ord(b"#"):
try:
return int(data)
@ -226,20 +226,28 @@ def parse(data_type: int, data: bytes) -> TSerializable:
raise ValueError(f"unknown type tag: {data_type}")
def pop(data: bytes) -> tuple[TSerializable, bytes]:
def split(data: memoryview, sep: bytes) -> tuple[int, memoryview]:
    """
    Split off the length prefix of a tnetstring.

    Scans `data` for the first occurrence of the single-byte separator `sep`,
    parses everything before it as an integer length, and returns that length
    together with a view of everything after the separator. The remainder is
    a slice of the incoming memoryview, so the payload is not copied.

    Raises ValueError if the separator is not found, if the prefix is not a
    valid integer, or if the parsed length is negative.
    """
    # Computed once up front: the loop below compares against the integer
    # byte value rather than a bytes object.
    ord_sep = ord(sep)
    try:
        i = 0
        while data[i] != ord_sep:
            i += 1
        # here i is the position of `sep` in the memoryview
        length = int(data[:i])
        if length < 0:
            # A negative length would make the caller slice from the wrong
            # end of the buffer; route it through the same error path as
            # any other malformed prefix.
            raise ValueError("negative length prefix")
    except (IndexError, ValueError) as e:
        raise ValueError(
            f"not a tnetstring: missing or invalid length prefix: {data.tobytes()!r}"
        ) from e
    return length, data[i + 1 :]
def pop(data: memoryview) -> tuple[TSerializable, memoryview]:
"""
This function parses a tnetstring into a python object.
It returns a tuple giving the parsed object and a string
containing any unparsed data from the end of the string.
"""
# Parse out data length, type and remaining string.
try:
blength, data = data.split(b":", 1)
length = int(blength)
except ValueError:
raise ValueError(
f"not a tnetstring: missing or invalid length prefix: {data!r}"
)
# Parse out data length, type and remaining string.
length, data = split(data, b":")
try:
data, data_type, remain = data[:length], data[length], data[length + 1 :]
except IndexError:

View File

@ -72,19 +72,19 @@ class Test_Format(unittest.TestCase):
for data, expect in FORMAT_EXAMPLES.items():
self.assertEqual(expect, tnetstring.loads(data))
self.assertEqual(expect, tnetstring.loads(tnetstring.dumps(expect)))
self.assertEqual((expect, b""), tnetstring.pop(data))
self.assertEqual((expect, b""), tnetstring.pop(memoryview(data)))
def test_roundtrip_format_random(self):
for _ in range(10):
v = get_random_object()
self.assertEqual(v, tnetstring.loads(tnetstring.dumps(v)))
self.assertEqual((v, b""), tnetstring.pop(tnetstring.dumps(v)))
self.assertEqual((v, b""), tnetstring.pop(memoryview(tnetstring.dumps(v))))
def test_roundtrip_format_unicode(self):
for _ in range(10):
v = get_random_object()
self.assertEqual(v, tnetstring.loads(tnetstring.dumps(v)))
self.assertEqual((v, b""), tnetstring.pop(tnetstring.dumps(v)))
self.assertEqual((v, b""), tnetstring.pop(memoryview(tnetstring.dumps(v))))
def test_roundtrip_big_integer(self):
# Recent Python versions do not like ints above 4300 digits, https://github.com/python/cpython/issues/95778