From 4350e748e470be92abab9bb1f72ad8eb1ede5486 Mon Sep 17 00:00:00 2001 From: Makarand Dharmapurikar Date: Thu, 1 Dec 2016 14:24:22 -0800 Subject: ability to deal with multiple streams in flight. --- test/http2_test/http2_base_server.py | 76 ++++++++++++++++++++-------------- test/http2_test/test_goaway.py | 19 ++++----- test/http2_test/test_max_streams.py | 9 ++-- test/http2_test/test_ping.py | 13 +++--- test/http2_test/test_rst_after_data.py | 14 +++---- 5 files changed, 72 insertions(+), 59 deletions(-) (limited to 'test/http2_test') diff --git a/test/http2_test/http2_base_server.py b/test/http2_test/http2_base_server.py index 91caa74fcc..44fb575c0f 100644 --- a/test/http2_test/http2_base_server.py +++ b/test/http2_test/http2_base_server.py @@ -6,6 +6,7 @@ from twisted.internet.protocol import Protocol from twisted.internet import reactor from h2.connection import H2Connection from h2.events import RequestReceived, DataReceived, WindowUpdated, RemoteSettingsChanged, PingAcknowledged +from h2.exceptions import ProtocolError READ_CHUNK_SIZE = 16384 GRPC_HEADER_SIZE = 5 @@ -13,7 +14,7 @@ GRPC_HEADER_SIZE = 5 class H2ProtocolBaseServer(Protocol): def __init__(self): self._conn = H2Connection(client_side=False) - self._recv_buffer = '' + self._recv_buffer = {} self._handlers = {} self._handlers['ConnectionMade'] = self.on_connection_made_default self._handlers['DataReceived'] = self.on_data_received_default @@ -23,6 +24,7 @@ class H2ProtocolBaseServer(Protocol): self._handlers['ConnectionLost'] = self.on_connection_lost self._handlers['PingAcknowledged'] = self.on_ping_acknowledged_default self._stream_status = {} + self._send_remaining = {} self._outstanding_pings = 0 def set_handlers(self, handlers): @@ -45,18 +47,23 @@ class H2ProtocolBaseServer(Protocol): reactor.callFromThread(reactor.stop) def dataReceived(self, data): - events = self._conn.receive_data(data) + try: + events = self._conn.receive_data(data) + except ProtocolError: + # this try/except block catches exceptions due to race between sending + # GOAWAY and processing a response in flight. + return if self._conn.data_to_send: self.transport.write(self._conn.data_to_send()) for event in events: if isinstance(event, RequestReceived) and self._handlers.has_key('RequestReceived'): - logging.info('RequestReceived Event') + logging.info('RequestReceived Event for stream: %d'%event.stream_id) self._handlers['RequestReceived'](event) elif isinstance(event, DataReceived) and self._handlers.has_key('DataReceived'): - logging.info('DataReceived Event') + logging.info('DataReceived Event for stream: %d'%event.stream_id) self._handlers['DataReceived'](event) elif isinstance(event, WindowUpdated) and self._handlers.has_key('WindowUpdated'): - logging.info('WindowUpdated Event') + logging.info('WindowUpdated Event for stream: %d'%event.stream_id) self._handlers['WindowUpdated'](event) elif isinstance(event, PingAcknowledged) and self._handlers.has_key('PingAcknowledged'): logging.info('PingAcknowledged Event') @@ -68,10 +75,10 @@ class H2ProtocolBaseServer(Protocol): def on_data_received_default(self, event): self._conn.acknowledge_received_data(len(event.data), event.stream_id) - self._recv_buffer += event.data + self._recv_buffer[event.stream_id] += event.data def on_request_received_default(self, event): - self._recv_buffer = '' + self._recv_buffer[event.stream_id] = '' self._stream_id = event.stream_id self._stream_status[event.stream_id] = True self._conn.send_headers( @@ -86,48 +93,57 @@ class H2ProtocolBaseServer(Protocol): self.transport.write(self._conn.data_to_send()) def on_window_update_default(self, event): - pass + # send pending data, if any + self.default_send(event.stream_id) def send_reset_stream(self): self._conn.reset_stream(self._stream_id) self.transport.write(self._conn.data_to_send()) - def setup_send(self, data_to_send): - self._send_remaining = len(data_to_send) + def setup_send(self, data_to_send, stream_id): + logging.info('Setting up data to send for stream_id: %d'%stream_id) + self._send_remaining[stream_id] = len(data_to_send) self._send_offset = 0 self._data_to_send = data_to_send - self.default_send() + self.default_send(stream_id) - def default_send(self): - while self._send_remaining > 0: - lfcw = self._conn.local_flow_control_window(self._stream_id) + def default_send(self, stream_id): + if not self._send_remaining.has_key(stream_id): + # not setup to send data yet + return + + while self._send_remaining[stream_id] > 0: + if self._stream_status[stream_id] is False: + logging.info('Stream %d is closed.'%stream_id) + break + lfcw = self._conn.local_flow_control_window(stream_id) if lfcw == 0: break chunk_size = min(lfcw, READ_CHUNK_SIZE) - bytes_to_send = min(chunk_size, self._send_remaining) + bytes_to_send = min(chunk_size, self._send_remaining[stream_id]) logging.info('flow_control_window = %d. sending [%d:%d] stream_id %d'% (lfcw, self._send_offset, self._send_offset + bytes_to_send, - self._stream_id)) + stream_id)) data = self._data_to_send[self._send_offset : self._send_offset + bytes_to_send] - self._conn.send_data(self._stream_id, data, False) - self._send_remaining -= bytes_to_send + self._conn.send_data(stream_id, data, False) + self._send_remaining[stream_id] -= bytes_to_send self._send_offset += bytes_to_send - if self._send_remaining == 0: - self._handlers['SendDone']() + if self._send_remaining[stream_id] == 0: + self._handlers['SendDone'](stream_id) def default_ping(self): self._outstanding_pings += 1 self._conn.ping(b'\x00'*8) self.transport.write(self._conn.data_to_send()) - def on_send_done_default(self): - if self._stream_status[self._stream_id]: - self._stream_status[self._stream_id] = False - self.default_send_trailer() + def on_send_done_default(self, stream_id): + if self._stream_status[stream_id]: + self._stream_status[stream_id] = False + self.default_send_trailer(stream_id) - def default_send_trailer(self): - logging.info('Sending trailer for stream id %d'%self._stream_id) - self._conn.send_headers(self._stream_id, + def default_send_trailer(self, stream_id): + logging.info('Sending trailer for stream id %d'%stream_id) + self._conn.send_headers(stream_id, headers=[ ('grpc-status', '0') ], end_stream=True ) @@ -141,8 +157,8 @@ class H2ProtocolBaseServer(Protocol): response_data = b'\x00' + struct.pack('i', len(serialized_resp_proto))[::-1] + serialized_resp_proto return response_data - @staticmethod - def parse_received_data(recv_buffer): + def parse_received_data(self, stream_id): + recv_buffer = self._recv_buffer[stream_id] """ returns a grpc framed string of bytes containing response proto of the size asked in request """ grpc_msg_size = struct.unpack('i',recv_buffer[1:5][::-1])[0] @@ -152,5 +168,5 @@ class H2ProtocolBaseServer(Protocol): req_proto_str = recv_buffer[5:5+grpc_msg_size] sr = messages_pb2.SimpleRequest() sr.ParseFromString(req_proto_str) - logging.info('Parsed request: response_size=%s'%sr.response_size) + logging.info('Parsed request for stream %d: response_size=%s'%(stream_id, sr.response_size)) return sr diff --git a/test/http2_test/test_goaway.py b/test/http2_test/test_goaway.py index 419bd7b3f8..7dd7cb7948 100644 --- a/test/http2_test/test_goaway.py +++ b/test/http2_test/test_goaway.py @@ -12,7 +12,6 @@ class TestcaseGoaway(object): self._base_server = http2_base_server.H2ProtocolBaseServer() self._base_server._handlers['RequestReceived'] = self.on_request_received self._base_server._handlers['DataReceived'] = self.on_data_received - self._base_server._handlers['WindowUpdated'] = self.on_window_update_default self._base_server._handlers['SendDone'] = self.on_send_done self._base_server._handlers['ConnectionLost'] = self.on_connection_lost self._ready_to_send = False @@ -27,11 +26,11 @@ class TestcaseGoaway(object): if self._iteration == 2: self._base_server.on_connection_lost(reason) - def on_send_done(self): - self._base_server.on_send_done_default() - if self._base_server._stream_id == 1: - logging.info('Sending GOAWAY for stream 1') - self._base_server._conn.close_connection(error_code=0, additional_data=None, last_stream_id=1) + def on_send_done(self, stream_id): + self._base_server.on_send_done_default(stream_id) + logging.info('Sending GOAWAY for stream %d:'%stream_id) + self._base_server._conn.close_connection(error_code=0, additional_data=None, last_stream_id=stream_id) + self._base_server._stream_status[stream_id] = False def on_request_received(self, event): self._ready_to_send = False @@ -39,13 +38,9 @@ class TestcaseGoaway(object): def on_data_received(self, event): self._base_server.on_data_received_default(event) - sr = self._base_server.parse_received_data(self._base_server._recv_buffer) + sr = self._base_server.parse_received_data(event.stream_id) if sr: logging.info('Creating response size = %s'%sr.response_size) response_data = self._base_server.default_response_data(sr.response_size) self._ready_to_send = True - self._base_server.setup_send(response_data) - - def on_window_update_default(self, event): - if self._ready_to_send: - self._base_server.default_send() + self._base_server.setup_send(response_data, event.stream_id) diff --git a/test/http2_test/test_max_streams.py b/test/http2_test/test_max_streams.py index a85dde48b5..deb26770c3 100644 --- a/test/http2_test/test_max_streams.py +++ b/test/http2_test/test_max_streams.py @@ -24,7 +24,8 @@ class TestcaseSettingsMaxStreams(object): def on_data_received(self, event): self._base_server.on_data_received_default(event) - sr = self._base_server.parse_received_data(self._base_server._recv_buffer) - logging.info('Creating response size = %s'%sr.response_size) - response_data = self._base_server.default_response_data(sr.response_size) - self._base_server.setup_send(response_data) + sr = self._base_server.parse_received_data(event.stream_id) + if sr: + logging.info('Creating response of size = %s'%sr.response_size) + response_data = self._base_server.default_response_data(sr.response_size) + self._base_server.setup_send(response_data, event.stream_id) diff --git a/test/http2_test/test_ping.py b/test/http2_test/test_ping.py index bade9df9b1..2e6dadbc07 100644 --- a/test/http2_test/test_ping.py +++ b/test/http2_test/test_ping.py @@ -23,12 +23,13 @@ class TestcasePing(object): def on_data_received(self, event): self._base_server.on_data_received_default(event) - sr = self._base_server.parse_received_data(self._base_server._recv_buffer) - logging.info('Creating response size = %s'%sr.response_size) - response_data = self._base_server.default_response_data(sr.response_size) - self._base_server.default_ping() - self._base_server.setup_send(response_data) - self._base_server.default_ping() + sr = self._base_server.parse_received_data(event.stream_id) + if sr: + logging.info('Creating response size = %s'%sr.response_size) + response_data = self._base_server.default_response_data(sr.response_size) + self._base_server.default_ping() + self._base_server.setup_send(response_data, event.stream_id) + self._base_server.default_ping() def on_connection_lost(self, reason): logging.info('Disconnect received. Ping Count %d'%self._base_server._outstanding_pings) diff --git a/test/http2_test/test_rst_after_data.py b/test/http2_test/test_rst_after_data.py index ef8d4084d9..c4ff56c889 100644 --- a/test/http2_test/test_rst_after_data.py +++ b/test/http2_test/test_rst_after_data.py @@ -14,10 +14,10 @@ class TestcaseRstStreamAfterData(object): def on_data_received(self, event): self._base_server.on_data_received_default(event) - sr = self._base_server.parse_received_data(self._base_server._recv_buffer) - assert(sr is not None) - response_data = self._base_server.default_response_data(sr.response_size) - self._ready_to_send = True - self._base_server.setup_send(response_data) - # send reset stream - self._base_server.send_reset_stream() + sr = self._base_server.parse_received_data(event.stream_id) + if sr: + response_data = self._base_server.default_response_data(sr.response_size) + self._ready_to_send = True + self._base_server.setup_send(response_data, event.stream_id) + # send reset stream + self._base_server.send_reset_stream() -- cgit v1.2.3