mirror of
https://github.com/mitmproxy/mitmproxy.git
synced 2024-12-11 23:23:58 +00:00
handle script hooks in replay, fix tests, fix #402
This commit is contained in:
parent
9b5a8af12d
commit
0c52b4e3b9
@ -169,6 +169,7 @@ class ClientPlaybackState:
|
||||
def __init__(self, flows, exit):
|
||||
self.flows, self.exit = flows, exit
|
||||
self.current = None
|
||||
self.testing = False # Disables actual replay for testing.
|
||||
|
||||
def count(self):
|
||||
return len(self.flows)
|
||||
@ -186,19 +187,16 @@ class ClientPlaybackState:
|
||||
if flow is self.current:
|
||||
self.current = None
|
||||
|
||||
def tick(self, master, testing=False):
|
||||
"""
|
||||
testing: Disables actual replay for testing.
|
||||
"""
|
||||
def tick(self, master):
|
||||
if self.flows and not self.current:
|
||||
n = self.flows.pop(0).copy()
|
||||
n.response = None
|
||||
n.reply = controller.DummyReply()
|
||||
self.current = master.handle_request(n)
|
||||
if not testing and not self.current.response:
|
||||
master.replay_request(self.current) # pragma: no cover
|
||||
elif self.current.response:
|
||||
master.handle_response(self.current)
|
||||
self.current = self.flows.pop(0).copy()
|
||||
if not self.testing:
|
||||
master.replay_request(self.current)
|
||||
else:
|
||||
self.current.reply = controller.DummyReply()
|
||||
master.handle_request(self.current)
|
||||
if self.current.response:
|
||||
master.handle_response(self.current)
|
||||
|
||||
|
||||
class ServerPlaybackState:
|
||||
@ -371,6 +369,8 @@ class State(object):
|
||||
"""
|
||||
Add a request to the state. Returns the matching flow.
|
||||
"""
|
||||
if flow in self._flow_list: # catch flow replay
|
||||
return flow
|
||||
self._flow_list.append(flow)
|
||||
if flow.match(self._limit):
|
||||
self.view.append(flow)
|
||||
|
@ -1040,7 +1040,7 @@ class HTTPHandler(ProtocolHandler):
|
||||
# inline script to set flow.stream = True
|
||||
flow = self.c.channel.ask("responseheaders", flow)
|
||||
if flow == KILL:
|
||||
raise KillSignal
|
||||
raise KillSignal()
|
||||
else:
|
||||
# now get the rest of the request body, if body still needs to be
|
||||
# read but not streaming this response
|
||||
@ -1085,7 +1085,7 @@ class HTTPHandler(ProtocolHandler):
|
||||
self.process_server_address(flow) # The inline script may have changed request.host
|
||||
|
||||
if request_reply is None or request_reply == KILL:
|
||||
return False
|
||||
raise KillSignal()
|
||||
|
||||
if isinstance(request_reply, HTTPResponse):
|
||||
flow.response = request_reply
|
||||
@ -1099,7 +1099,7 @@ class HTTPHandler(ProtocolHandler):
|
||||
self.c.log("response", "debug", [flow.response._assemble_first_line()])
|
||||
response_reply = self.c.channel.ask("response", flow)
|
||||
if response_reply is None or response_reply == KILL:
|
||||
return False
|
||||
raise KillSignal()
|
||||
|
||||
self.send_response_to_client(flow)
|
||||
|
||||
@ -1140,7 +1140,6 @@ class HTTPHandler(ProtocolHandler):
|
||||
self.handle_error(e, flow)
|
||||
except KillSignal:
|
||||
self.c.log("Connection killed", "info")
|
||||
flow.live = None
|
||||
finally:
|
||||
flow.live = None # Connection is not live anymore.
|
||||
return False
|
||||
@ -1437,32 +1436,43 @@ class RequestReplayThread(threading.Thread):
|
||||
r = self.flow.request
|
||||
form_out_backup = r.form_out
|
||||
try:
|
||||
# In all modes, we directly connect to the server displayed
|
||||
if self.config.mode == "upstream":
|
||||
server_address = self.config.mode.get_upstream_server(self.flow.client_conn)[2:]
|
||||
server = ServerConnection(server_address)
|
||||
server.connect()
|
||||
if r.scheme == "https":
|
||||
send_connect_request(server, r.host, r.port)
|
||||
server.establish_ssl(self.config.clientcerts, sni=self.flow.server_conn.sni)
|
||||
r.form_out = "relative"
|
||||
else:
|
||||
r.form_out = "absolute"
|
||||
self.flow.response = None
|
||||
request_reply = self.channel.ask("request", self.flow)
|
||||
if request_reply is None or request_reply == KILL:
|
||||
raise KillSignal()
|
||||
elif isinstance(request_reply, HTTPResponse):
|
||||
self.flow.response = request_reply
|
||||
else:
|
||||
server_address = (r.host, r.port)
|
||||
server = ServerConnection(server_address)
|
||||
server.connect()
|
||||
if r.scheme == "https":
|
||||
server.establish_ssl(self.config.clientcerts, sni=self.flow.server_conn.sni)
|
||||
r.form_out = "relative"
|
||||
# In all modes, we directly connect to the server displayed
|
||||
if self.config.mode == "upstream":
|
||||
server_address = self.config.mode.get_upstream_server(self.flow.client_conn)[2:]
|
||||
server = ServerConnection(server_address)
|
||||
server.connect()
|
||||
if r.scheme == "https":
|
||||
send_connect_request(server, r.host, r.port)
|
||||
server.establish_ssl(self.config.clientcerts, sni=self.flow.server_conn.sni)
|
||||
r.form_out = "relative"
|
||||
else:
|
||||
r.form_out = "absolute"
|
||||
else:
|
||||
server_address = (r.host, r.port)
|
||||
server = ServerConnection(server_address)
|
||||
server.connect()
|
||||
if r.scheme == "https":
|
||||
server.establish_ssl(self.config.clientcerts, sni=self.flow.server_conn.sni)
|
||||
r.form_out = "relative"
|
||||
|
||||
server.send(r.assemble())
|
||||
self.flow.server_conn = server
|
||||
self.flow.response = HTTPResponse.from_stream(server.rfile, r.method,
|
||||
body_size_limit=self.config.body_size_limit)
|
||||
self.channel.ask("response", self.flow)
|
||||
except (proxy.ProxyError, http.HttpError, tcp.NetLibError), v:
|
||||
server.send(r.assemble())
|
||||
self.flow.server_conn = server
|
||||
self.flow.response = HTTPResponse.from_stream(server.rfile, r.method,
|
||||
body_size_limit=self.config.body_size_limit)
|
||||
response_reply = self.channel.ask("response", self.flow)
|
||||
if response_reply is None or response_reply == KILL:
|
||||
raise KillSignal()
|
||||
except (proxy.ProxyError, http.HttpError, tcp.NetLibError) as v:
|
||||
self.flow.error = Error(repr(v))
|
||||
self.channel.ask("error", self.flow)
|
||||
except KillSignal:
|
||||
self.channel.tell("log", proxy.Log("Connection killed", "info"))
|
||||
finally:
|
||||
r.form_out = form_out_backup
|
||||
|
@ -86,19 +86,20 @@ class TestClientPlaybackState:
|
||||
fm = flow.FlowMaster(None, s)
|
||||
fm.start_client_playback([first, tutils.tflow()], True)
|
||||
c = fm.client_playback
|
||||
c.testing = True
|
||||
|
||||
assert not c.done()
|
||||
assert not s.flow_count()
|
||||
assert c.count() == 2
|
||||
c.tick(fm, testing=True)
|
||||
c.tick(fm)
|
||||
assert s.flow_count()
|
||||
assert c.count() == 1
|
||||
|
||||
c.tick(fm, testing=True)
|
||||
c.tick(fm)
|
||||
assert c.count() == 1
|
||||
|
||||
c.clear(c.current)
|
||||
c.tick(fm, testing=True)
|
||||
c.tick(fm)
|
||||
assert c.count() == 0
|
||||
c.clear(c.current)
|
||||
assert c.done()
|
||||
@ -696,6 +697,7 @@ class TestFlowMaster:
|
||||
fm = flow.FlowMaster(DummyServer(ProxyConfig()), s)
|
||||
assert not fm.start_server_playback(pb, False, [], False, False, None, False)
|
||||
assert not fm.start_client_playback(pb, False)
|
||||
fm.client_playback.testing = True
|
||||
|
||||
q = Queue.Queue()
|
||||
assert not fm.state.flow_count()
|
||||
|
Loading…
Reference in New Issue
Block a user