diff --git a/CHANGELOG b/CHANGELOG index 58d6694b2..05dd44695 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -9,6 +9,7 @@ Unreleased: mitmproxy next * Prevent transparent mode from connecting to itself in the basic cases (@prinzhorn) * Display HTTP trailers in mitmweb (@sanlengjingvv) * Revamp onboarding app (@mhils) + * Add ASGI support for embedded apps (@mhils) * --- TODO: add new PRs above this line --- diff --git a/examples/addons/wsgi-flask-app.py b/examples/addons/wsgi-flask-app.py index 2a9f0e2b7..43fd8cdca 100644 --- a/examples/addons/wsgi-flask-app.py +++ b/examples/addons/wsgi-flask-app.py @@ -6,7 +6,7 @@ instance, we're using the Flask framework (http://flask.pocoo.org/) to expose a single simplest-possible page. """ from flask import Flask -from mitmproxy.addons import wsgiapp +from mitmproxy.addons import asgiapp app = Flask("proxapp") @@ -19,7 +19,7 @@ def hello_world() -> str: addons = [ # Host app at the magic domain "example.com" on port 80. Requests to this # domain and port combination will now be routed to the WSGI app instance. - wsgiapp.WSGIApp(app, "example.com", 80) + asgiapp.WSGIApp(app, "example.com", 80) # SSL works too, but the magic domain needs to be resolvable from the mitmproxy machine due to mitmproxy's design. # mitmproxy will connect to said domain and use serve its certificate (unless --no-upstream-cert is set) # but won't send any data. diff --git a/mitmproxy/addons/asgiapp.py b/mitmproxy/addons/asgiapp.py new file mode 100644 index 000000000..71aed121e --- /dev/null +++ b/mitmproxy/addons/asgiapp.py @@ -0,0 +1,129 @@ +import asyncio +import urllib.parse + +import asgiref.compatibility +import asgiref.wsgi + +from mitmproxy import ctx, http + + +class ASGIApp: + """ + An addon that hosts an ASGI/WSGI HTTP app within mitmproxy, at a specified hostname and port. + + Some important caveats: + - This implementation will block and wait until the entire HTTP response is completed before sending out data. + - It currently only implements the HTTP protocol (Lifespan and WebSocket are unimplemented). + """ + + def __init__(self, asgi_app, host: str, port: int): + asgi_app = asgiref.compatibility.guarantee_single_callable(asgi_app) + self.asgi_app, self.host, self.port = asgi_app, host, port + + @property + def name(self) -> str: + return f"asgiapp:{self.host}:{self.port}" + + def request(self, flow: http.HTTPFlow) -> None: + assert flow.reply + if (flow.request.pretty_host, flow.request.port) == (self.host, self.port) and not flow.reply.has_message: + flow.reply.take() # pause hook completion + asyncio.create_task(serve(self.asgi_app, flow)) + + +class WSGIApp(ASGIApp): + def __init__(self, wsgi_app, host: str, port: int): + asgi_app = asgiref.wsgi.WsgiToAsgi(wsgi_app) + super().__init__(asgi_app, host, port) + + +HTTP_VERSION_MAP = { + "HTTP/1.0": "1.0", + "HTTP/1.1": "1.1", + "HTTP/2.0": "2", +} + + +def make_scope(flow: http.HTTPFlow) -> dict: + # %3F is a quoted question mark + quoted_path = urllib.parse.quote_from_bytes(flow.request.data.path).split("%3F", maxsplit=1) + + # (Unicode string) – HTTP request target excluding any query string, with percent-encoded + # sequences and UTF-8 byte sequences decoded into characters. + path = quoted_path[0] + + # (byte string) – URL portion after the ?, percent-encoded. + query_string: bytes + if len(quoted_path) > 1: + query_string = quoted_path[1].encode() + else: + query_string = b"" + + return { + "type": "http", + "asgi": { + "version": "3.0", + "spec_version": "2.1", + }, + "http_version": HTTP_VERSION_MAP.get(flow.request.http_version, "1.1"), + "method": flow.request.method, + "scheme": flow.request.scheme, + "path": path, + "raw_path": flow.request.path, + "query_string": query_string, + "headers": list(list(x) for x in flow.request.headers.fields), + "client": flow.client_conn.address, + "extensions": { + "mitmproxy.master": ctx.master, + } + } + + +async def serve(app, flow: http.HTTPFlow): + """ + Serves app on flow. + """ + assert flow.reply + + scope = make_scope(flow) + done = asyncio.Event() + received_body = False + + async def receive(): + nonlocal received_body + if not received_body: + received_body = True + return { + "type": "http.request", + "body": flow.request.raw_content, + } + else: # pragma: no cover + # We really don't expect this to be called a second time, but what to do? + # We just wait until the request is done before we continue here with sending a disconnect. + await done.wait() + return { + "type": "http.disconnect" + } + + async def send(event): + if event["type"] == "http.response.start": + flow.response = http.HTTPResponse.make(event["status"], b"", event.get("headers", [])) + flow.response.decode() + elif event["type"] == "http.response.body": + flow.response.content += event.get("body", b"") + if not event.get("more_body", False): + flow.reply.ack() + else: + raise AssertionError(f"Unexpected event: {event['type']}") + + try: + await app(scope, receive, send) + if not flow.reply.has_message: + raise RuntimeError(f"no response sent.") + except Exception as e: + ctx.log.error(f"Error in asgi app: {e}") + flow.response = http.HTTPResponse.make(500, b"ASGI Error.") + flow.reply.ack(force=True) + finally: + flow.reply.commit() + done.set() diff --git a/mitmproxy/addons/onboarding.py b/mitmproxy/addons/onboarding.py index 94ca7c490..b104140e6 100644 --- a/mitmproxy/addons/onboarding.py +++ b/mitmproxy/addons/onboarding.py @@ -1,4 +1,4 @@ -from mitmproxy.addons import wsgiapp +from mitmproxy.addons import asgiapp from mitmproxy.addons.onboardingapp import app from mitmproxy import ctx @@ -6,11 +6,11 @@ APP_HOST = "mitm.it" APP_PORT = 80 -class Onboarding(wsgiapp.WSGIApp): +class Onboarding(asgiapp.WSGIApp): name = "onboarding" def __init__(self): - super().__init__(app, None, None) + super().__init__(app, APP_HOST, APP_PORT) def load(self, loader): loader.add_option( diff --git a/mitmproxy/addons/wsgiapp.py b/mitmproxy/addons/wsgiapp.py deleted file mode 100644 index 549d8c87e..000000000 --- a/mitmproxy/addons/wsgiapp.py +++ /dev/null @@ -1,42 +0,0 @@ -from mitmproxy import ctx -from mitmproxy import exceptions - -from mitmproxy.net import wsgi -from mitmproxy import version - - -class WSGIApp: - """ - An addon that hosts a WSGI app within mitproxy, at a specified - hostname and port. - """ - def __init__(self, app, host, port): - self.app, self.host, self.port = app, host, port - - @property - def name(self): - return "wsgiapp:%s:%s" % (self.host, self.port) - - def serve(self, app, flow): - """ - Serves app on flow, and prevents further handling of the flow. - """ - app = wsgi.WSGIAdaptor( - app, - flow.request.pretty_host, - flow.request.port, - version.MITMPROXY - ) - err = app.serve( - flow, - flow.client_conn.wfile, - **{"mitmproxy.master": ctx.master} - ) - if err: - ctx.log.error("Error in wsgi app. %s" % err) - raise exceptions.AddonHalt() - flow.reply.kill() - - def request(self, f): - if (f.request.pretty_host, f.request.port) == (self.host, self.port): - self.serve(self.app, f) diff --git a/mitmproxy/net/wsgi.py b/mitmproxy/net/wsgi.py deleted file mode 100644 index a40dcecaa..000000000 --- a/mitmproxy/net/wsgi.py +++ /dev/null @@ -1,165 +0,0 @@ -import time -import traceback -import urllib -import io - -from mitmproxy.net import http -from mitmproxy.utils import strutils - - -class ClientConn: - - def __init__(self, address): - self.address = address - - -class Flow: - - def __init__(self, address, request): - self.client_conn = ClientConn(address) - self.request = request - - -class Request: - - def __init__(self, scheme, method, path, http_version, headers, content): - self.scheme, self.method, self.path = scheme, method, path - self.headers, self.content = headers, content - self.http_version = http_version - - -def date_time_string(): - """Return the current date and time formatted for a message header.""" - WEEKS = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'] - MONTHS = [ - None, - 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', - 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec' - ] - now = time.time() - year, month, day, hh, mm, ss, wd, y_, z_ = time.gmtime(now) - s = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % ( - WEEKS[wd], - day, MONTHS[month], year, - hh, mm, ss - ) - return s - - -class WSGIAdaptor: - - def __init__(self, app, domain, port, sversion): - self.app, self.domain, self.port, self.sversion = app, domain, port, sversion - - def make_environ(self, flow, errsoc, **extra): - """ - Raises: - ValueError, if the content-encoding is invalid. - """ - path = strutils.always_str(flow.request.path, "latin-1") - if '?' in path: - path_info, query = strutils.always_str(path, "latin-1").split('?', 1) - else: - path_info = path - query = '' - environ = { - 'wsgi.version': (1, 0), - 'wsgi.url_scheme': strutils.always_str(flow.request.scheme, "latin-1"), - 'wsgi.input': io.BytesIO(flow.request.content or b""), - 'wsgi.errors': errsoc, - 'wsgi.multithread': True, - 'wsgi.multiprocess': False, - 'wsgi.run_once': False, - 'SERVER_SOFTWARE': self.sversion, - 'REQUEST_METHOD': strutils.always_str(flow.request.method, "latin-1"), - 'SCRIPT_NAME': '', - 'PATH_INFO': urllib.parse.unquote(path_info), - 'QUERY_STRING': query, - 'CONTENT_TYPE': strutils.always_str(flow.request.headers.get('Content-Type', ''), "latin-1"), - 'CONTENT_LENGTH': strutils.always_str(flow.request.headers.get('Content-Length', ''), "latin-1"), - 'SERVER_NAME': self.domain, - 'SERVER_PORT': str(self.port), - 'SERVER_PROTOCOL': strutils.always_str(flow.request.http_version, "latin-1"), - } - environ.update(extra) - if flow.client_conn.address: - environ["REMOTE_ADDR"] = strutils.always_str(flow.client_conn.address[0], "latin-1") - environ["REMOTE_PORT"] = flow.client_conn.address[1] - - for key, value in flow.request.headers.items(): - key = 'HTTP_' + strutils.always_str(key, "latin-1").upper().replace('-', '_') - if key not in ('HTTP_CONTENT_TYPE', 'HTTP_CONTENT_LENGTH'): - environ[key] = value - return environ - - def error_page(self, soc, headers_sent, s): - """ - Make a best-effort attempt to write an error page. If headers are - already sent, we just bung the error into the page. - """ - c = """ - -

Internal Server Error

-
{err}"
- - """.format(err=s).strip().encode() - - if not headers_sent: - soc.write(b"HTTP/1.1 500 Internal Server Error\r\n") - soc.write(b"Content-Type: text/html\r\n") - soc.write("Content-Length: {length}\r\n".format(length=len(c)).encode()) - soc.write(b"\r\n") - soc.write(c) - - def serve(self, request, soc, **env): - state = dict( - response_started=False, - headers_sent=False, - status=None, - headers=None - ) - - def write(data): - if not state["headers_sent"]: - soc.write("HTTP/1.1 {status}\r\n".format(status=state["status"]).encode()) - headers = state["headers"] - if 'server' not in headers: - headers["Server"] = self.sversion - if 'date' not in headers: - headers["Date"] = date_time_string() - soc.write(bytes(headers)) - soc.write(b"\r\n") - state["headers_sent"] = True - if data: - soc.write(data) - soc.flush() - - def start_response(status, headers, exc_info=None): - if exc_info: - if state["headers_sent"]: - raise exc_info[1] - elif state["status"]: - raise AssertionError('Response already started') - state["status"] = status - state["headers"] = http.Headers([[strutils.always_bytes(k), strutils.always_bytes(v)] for k, v in headers]) - if exc_info: - self.error_page(soc, state["headers_sent"], traceback.format_tb(exc_info[2])) - state["headers_sent"] = True - - errs = io.BytesIO() - try: - dataiter = self.app( - self.make_environ(request, errs, **env), start_response - ) - for i in dataiter: - write(i) - if not state["headers_sent"]: - write(b"") - except Exception: - try: - s = traceback.format_exc() - errs.write(s.encode("utf-8", "replace")) - self.error_page(soc, state["headers_sent"], s) - except Exception: # pragma: no cover - pass - return errs.getvalue() diff --git a/setup.py b/setup.py index 8b65a5029..770cb8dad 100644 --- a/setup.py +++ b/setup.py @@ -61,6 +61,7 @@ setup( # https://packaging.python.org/en/latest/requirements/#install-requires # It is not considered best practice to use install_requires to pin dependencies to specific versions. install_requires=[ + "asgiref>=3.2.10, <3.3", "blinker>=1.4, <1.5", "Brotli>=1.0,<1.1", "certifi>=2019.9.11", # no semver here - this should always be on the last release! diff --git a/test/mitmproxy/addons/test_asgiapp.py b/test/mitmproxy/addons/test_asgiapp.py new file mode 100644 index 000000000..5031e2c64 --- /dev/null +++ b/test/mitmproxy/addons/test_asgiapp.py @@ -0,0 +1,53 @@ +import flask + +from .. import tservers +from mitmproxy.addons import asgiapp + +tapp = flask.Flask(__name__) + + +@tapp.route("/") +def hello(): + return "testapp" + + +@tapp.route("/error") +def error(): + raise ValueError("An exception...") + + +async def errapp(scope, receive, send): + raise ValueError("errapp") + + +async def noresponseapp(scope, receive, send): + return + + +class TestApp(tservers.HTTPProxyTest): + def addons(self): + return [ + asgiapp.WSGIApp(tapp, "testapp", 80), + asgiapp.ASGIApp(errapp, "errapp", 80), + asgiapp.ASGIApp(noresponseapp, "noresponseapp", 80), + ] + + def test_simple(self): + p = self.pathoc() + with p.connect(): + ret = p.request("get:'http://testapp/'") + assert b"testapp" in ret.content + + def test_app_err(self): + p = self.pathoc() + with p.connect(): + ret = p.request("get:'http://errapp/?foo=bar'") + assert ret.status_code == 500 + assert b"ASGI Error" in ret.content + + def test_app_no_response(self): + p = self.pathoc() + with p.connect(): + ret = p.request("get:'http://noresponseapp/'") + assert ret.status_code == 500 + assert b"ASGI Error" in ret.content \ No newline at end of file diff --git a/test/mitmproxy/addons/test_wsgiapp.py b/test/mitmproxy/addons/test_wsgiapp.py deleted file mode 100644 index 760ee460b..000000000 --- a/test/mitmproxy/addons/test_wsgiapp.py +++ /dev/null @@ -1,41 +0,0 @@ -import flask - -from .. import tservers -from mitmproxy.addons import wsgiapp - -tapp = flask.Flask(__name__) - - -@tapp.route("/") -def hello(): - return "testapp" - - -@tapp.route("/error") -def error(): - raise ValueError("An exception...") - - -def errapp(environ, start_response): - raise ValueError("errapp") - - -class TestApp(tservers.HTTPProxyTest): - def addons(self): - return [ - wsgiapp.WSGIApp(tapp, "testapp", 80), - wsgiapp.WSGIApp(errapp, "errapp", 80) - ] - - def test_simple(self): - p = self.pathoc() - with p.connect(): - ret = p.request("get:'http://testapp/'") - assert ret.status_code == 200 - - def test_app_err(self): - p = self.pathoc() - with p.connect(): - ret = p.request("get:'http://errapp/'") - assert ret.status_code == 500 - assert b"ValueError" in ret.content diff --git a/test/mitmproxy/net/test_wsgi.py b/test/mitmproxy/net/test_wsgi.py deleted file mode 100644 index b4d6b53f2..000000000 --- a/test/mitmproxy/net/test_wsgi.py +++ /dev/null @@ -1,106 +0,0 @@ -from io import BytesIO -import sys -from mitmproxy.net import wsgi -from mitmproxy.net.http import Headers - - -def tflow(): - headers = Headers(test=b"value") - req = wsgi.Request("http", "GET", "/", "HTTP/1.1", headers, "") - return wsgi.Flow(("127.0.0.1", 8888), req) - - -class ExampleApp: - - def __init__(self): - self.called = False - - def __call__(self, environ, start_response): - self.called = True - status = '200 OK' - response_headers = [('Content-type', 'text/plain')] - start_response(status, response_headers) - return [b'Hello', b' world!\n'] - - -class TestWSGI: - - def test_make_environ(self): - w = wsgi.WSGIAdaptor(None, "foo", 80, "version") - tf = tflow() - assert w.make_environ(tf, None) - - tf.request.path = "/foo?bar=voing" - r = w.make_environ(tf, None) - assert r["QUERY_STRING"] == "bar=voing" - - def test_serve(self): - ta = ExampleApp() - w = wsgi.WSGIAdaptor(ta, "foo", 80, "version") - f = tflow() - f.request.host = "foo" - f.request.port = 80 - - wfile = BytesIO() - err = w.serve(f, wfile) - assert ta.called - assert not err - - val = wfile.getvalue() - assert b"Hello world" in val - assert b"Server:" in val - - def _serve(self, app): - w = wsgi.WSGIAdaptor(app, "foo", 80, "version") - f = tflow() - f.request.host = "foo" - f.request.port = 80 - wfile = BytesIO() - w.serve(f, wfile) - return wfile.getvalue() - - def test_serve_empty_body(self): - def app(environ, start_response): - status = '200 OK' - response_headers = [('Foo', 'bar')] - start_response(status, response_headers) - return [] - assert self._serve(app) - - def test_serve_double_start(self): - def app(environ, start_response): - try: - raise ValueError("foo") - except: - sys.exc_info() - status = '200 OK' - response_headers = [('Content-type', 'text/plain')] - start_response(status, response_headers) - start_response(status, response_headers) - assert b"Internal Server Error" in self._serve(app) - - def test_serve_single_err(self): - def app(environ, start_response): - try: - raise ValueError("foo") - except: - ei = sys.exc_info() - status = '200 OK' - response_headers = [('Content-type', 'text/plain')] - start_response(status, response_headers, ei) - yield b"" - assert b"Internal Server Error" in self._serve(app) - - def test_serve_double_err(self): - def app(environ, start_response): - try: - raise ValueError("foo") - except: - ei = sys.exc_info() - status = '200 OK' - response_headers = [('Content-type', 'text/plain')] - start_response(status, response_headers) - yield b"aaa" - start_response(status, response_headers, ei) - yield b"bbb" - assert b"Internal Server Error" in self._serve(app)