mirror of
https://github.com/mitmproxy/mitmproxy.git
synced 2024-11-25 14:20:03 +00:00
Decouple message type from message class name.
This commit is contained in:
parent
1e07d9e6e7
commit
45eab17e0c
@ -39,13 +39,13 @@ class Channel:
|
||||
def __init__(self, q):
|
||||
self.q = q
|
||||
|
||||
def ask(self, m):
|
||||
def ask(self, mtype, m):
|
||||
"""
|
||||
Decorate a message with a reply attribute, and send it to the
|
||||
master. then wait for a response.
|
||||
"""
|
||||
m.reply = Reply(m)
|
||||
self.q.put(m)
|
||||
self.q.put((mtype, m))
|
||||
while not should_exit:
|
||||
try:
|
||||
# The timeout is here so we can handle a should_exit event.
|
||||
@ -54,13 +54,13 @@ class Channel:
|
||||
continue
|
||||
return g
|
||||
|
||||
def tell(self, m):
|
||||
def tell(self, mtype, m):
|
||||
"""
|
||||
Decorate a message with a dummy reply attribute, send it to the
|
||||
master, then return immediately.
|
||||
"""
|
||||
m.reply = DummyReply()
|
||||
self.q.put(m)
|
||||
self.q.put((mtype, m))
|
||||
|
||||
|
||||
class Slave(threading.Thread):
|
||||
@ -98,7 +98,7 @@ class Master:
|
||||
while True:
|
||||
# Small timeout to prevent pegging the CPU
|
||||
msg = q.get(timeout=0.01)
|
||||
self.handle(msg)
|
||||
self.handle(*msg)
|
||||
changed = True
|
||||
except Queue.Empty:
|
||||
pass
|
||||
@ -112,13 +112,13 @@ class Master:
|
||||
self.tick(self.masterq)
|
||||
self.shutdown()
|
||||
|
||||
def handle(self, msg):
|
||||
c = "handle_" + msg.__class__.__name__.lower()
|
||||
def handle(self, mtype, obj):
|
||||
c = "handle_" + mtype
|
||||
m = getattr(self, c, None)
|
||||
if m:
|
||||
m(msg)
|
||||
m(obj)
|
||||
else:
|
||||
msg.reply()
|
||||
obj.reply()
|
||||
|
||||
def shutdown(self):
|
||||
global should_exit
|
||||
|
@ -97,10 +97,10 @@ class RequestReplayThread(threading.Thread):
|
||||
self.flow.request, httpversion, code, msg, headers, content, server.cert,
|
||||
server.rfile.first_byte_timestamp
|
||||
)
|
||||
self.channel.ask(response)
|
||||
self.channel.ask("response", response)
|
||||
except (ProxyError, http.HttpError, tcp.NetLibError), v:
|
||||
err = flow.Error(self.flow.request, str(v))
|
||||
self.channel.ask(err)
|
||||
self.channel.ask("error", err)
|
||||
|
||||
|
||||
class HandleSNI:
|
||||
@ -173,7 +173,7 @@ class ProxyHandler(tcp.BaseHandler):
|
||||
self.server_conn.require_request = False
|
||||
|
||||
self.server_conn.conn_info = conn_info
|
||||
self.channel.ask(self.server_conn)
|
||||
self.channel.ask("serverconnect", self.server_conn)
|
||||
self.server_conn.connect()
|
||||
except tcp.NetLibError, v:
|
||||
raise ProxyError(502, v)
|
||||
@ -187,7 +187,7 @@ class ProxyHandler(tcp.BaseHandler):
|
||||
def handle(self):
|
||||
cc = flow.ClientConnect(self.client_address)
|
||||
self.log(cc, "connect")
|
||||
self.channel.ask(cc)
|
||||
self.channel.ask("clientconnect", cc)
|
||||
while self.handle_request(cc) and not cc.close:
|
||||
pass
|
||||
cc.close = True
|
||||
@ -199,7 +199,7 @@ class ProxyHandler(tcp.BaseHandler):
|
||||
[
|
||||
"handled %s requests"%cc.requestcount]
|
||||
)
|
||||
self.channel.tell(cd)
|
||||
self.channel.tell("clientdisconnect", cd)
|
||||
|
||||
def handle_request(self, cc):
|
||||
try:
|
||||
@ -209,13 +209,13 @@ class ProxyHandler(tcp.BaseHandler):
|
||||
return
|
||||
cc.requestcount += 1
|
||||
|
||||
request_reply = self.channel.ask(request)
|
||||
request_reply = self.channel.ask("request", request)
|
||||
if request_reply is None or request_reply == KILL:
|
||||
return
|
||||
elif isinstance(request_reply, flow.Response):
|
||||
request = False
|
||||
response = request_reply
|
||||
response_reply = self.channel.ask(response)
|
||||
response_reply = self.channel.ask("response", response)
|
||||
else:
|
||||
request = request_reply
|
||||
if self.config.reverse_proxy:
|
||||
@ -261,7 +261,7 @@ class ProxyHandler(tcp.BaseHandler):
|
||||
request, httpversion, code, msg, headers, content, sc.cert,
|
||||
sc.rfile.first_byte_timestamp
|
||||
)
|
||||
response_reply = self.channel.ask(response)
|
||||
response_reply = self.channel.ask("response", response)
|
||||
# Not replying to the server invalidates the server
|
||||
# connection, so we terminate.
|
||||
if response_reply == KILL:
|
||||
@ -288,7 +288,7 @@ class ProxyHandler(tcp.BaseHandler):
|
||||
|
||||
if request:
|
||||
err = flow.Error(request, cc.error)
|
||||
self.channel.ask(err)
|
||||
self.channel.ask("error", err)
|
||||
self.log(
|
||||
cc, cc.error,
|
||||
["url: %s"%request.get_url()]
|
||||
@ -308,7 +308,7 @@ class ProxyHandler(tcp.BaseHandler):
|
||||
msg.append(" -> "+i)
|
||||
msg = "\n".join(msg)
|
||||
l = Log(msg)
|
||||
self.channel.tell(l)
|
||||
self.channel.tell("log", l)
|
||||
|
||||
def find_cert(self, cc, host, port, sni):
|
||||
if self.config.certfile:
|
||||
|
@ -6,7 +6,7 @@ class TestMaster:
|
||||
def test_default_handler(self):
|
||||
m = controller.Master(None)
|
||||
msg = mock.MagicMock()
|
||||
m.handle(msg)
|
||||
m.handle("type", msg)
|
||||
assert msg.reply.call_count == 1
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user