Decouple message type from message class name.

This commit is contained in:
Aldo Cortesi 2014-01-04 14:42:32 +13:00
parent 1e07d9e6e7
commit 45eab17e0c
3 changed files with 20 additions and 20 deletions

View File

@ -39,13 +39,13 @@ class Channel:
def __init__(self, q):
self.q = q
def ask(self, m):
def ask(self, mtype, m):
"""
Decorate a message with a reply attribute, and send it to the
master. then wait for a response.
"""
m.reply = Reply(m)
self.q.put(m)
self.q.put((mtype, m))
while not should_exit:
try:
# The timeout is here so we can handle a should_exit event.
@ -54,13 +54,13 @@ class Channel:
continue
return g
def tell(self, m):
def tell(self, mtype, m):
"""
Decorate a message with a dummy reply attribute, send it to the
master, then return immediately.
"""
m.reply = DummyReply()
self.q.put(m)
self.q.put((mtype, m))
class Slave(threading.Thread):
@ -98,7 +98,7 @@ class Master:
while True:
# Small timeout to prevent pegging the CPU
msg = q.get(timeout=0.01)
self.handle(msg)
self.handle(*msg)
changed = True
except Queue.Empty:
pass
@ -112,13 +112,13 @@ class Master:
self.tick(self.masterq)
self.shutdown()
def handle(self, msg):
c = "handle_" + msg.__class__.__name__.lower()
def handle(self, mtype, obj):
c = "handle_" + mtype
m = getattr(self, c, None)
if m:
m(msg)
m(obj)
else:
msg.reply()
obj.reply()
def shutdown(self):
global should_exit

View File

@ -97,10 +97,10 @@ class RequestReplayThread(threading.Thread):
self.flow.request, httpversion, code, msg, headers, content, server.cert,
server.rfile.first_byte_timestamp
)
self.channel.ask(response)
self.channel.ask("response", response)
except (ProxyError, http.HttpError, tcp.NetLibError), v:
err = flow.Error(self.flow.request, str(v))
self.channel.ask(err)
self.channel.ask("error", err)
class HandleSNI:
@ -173,7 +173,7 @@ class ProxyHandler(tcp.BaseHandler):
self.server_conn.require_request = False
self.server_conn.conn_info = conn_info
self.channel.ask(self.server_conn)
self.channel.ask("serverconnect", self.server_conn)
self.server_conn.connect()
except tcp.NetLibError, v:
raise ProxyError(502, v)
@ -187,7 +187,7 @@ class ProxyHandler(tcp.BaseHandler):
def handle(self):
cc = flow.ClientConnect(self.client_address)
self.log(cc, "connect")
self.channel.ask(cc)
self.channel.ask("clientconnect", cc)
while self.handle_request(cc) and not cc.close:
pass
cc.close = True
@ -199,7 +199,7 @@ class ProxyHandler(tcp.BaseHandler):
[
"handled %s requests"%cc.requestcount]
)
self.channel.tell(cd)
self.channel.tell("clientdisconnect", cd)
def handle_request(self, cc):
try:
@ -209,13 +209,13 @@ class ProxyHandler(tcp.BaseHandler):
return
cc.requestcount += 1
request_reply = self.channel.ask(request)
request_reply = self.channel.ask("request", request)
if request_reply is None or request_reply == KILL:
return
elif isinstance(request_reply, flow.Response):
request = False
response = request_reply
response_reply = self.channel.ask(response)
response_reply = self.channel.ask("response", response)
else:
request = request_reply
if self.config.reverse_proxy:
@ -261,7 +261,7 @@ class ProxyHandler(tcp.BaseHandler):
request, httpversion, code, msg, headers, content, sc.cert,
sc.rfile.first_byte_timestamp
)
response_reply = self.channel.ask(response)
response_reply = self.channel.ask("response", response)
# Not replying to the server invalidates the server
# connection, so we terminate.
if response_reply == KILL:
@ -288,7 +288,7 @@ class ProxyHandler(tcp.BaseHandler):
if request:
err = flow.Error(request, cc.error)
self.channel.ask(err)
self.channel.ask("error", err)
self.log(
cc, cc.error,
["url: %s"%request.get_url()]
@ -308,7 +308,7 @@ class ProxyHandler(tcp.BaseHandler):
msg.append(" -> "+i)
msg = "\n".join(msg)
l = Log(msg)
self.channel.tell(l)
self.channel.tell("log", l)
def find_cert(self, cc, host, port, sni):
if self.config.certfile:

View File

@ -6,7 +6,7 @@ class TestMaster:
def test_default_handler(self):
m = controller.Master(None)
msg = mock.MagicMock()
m.handle(msg)
m.handle("type", msg)
assert msg.reply.call_count == 1