diff options
Diffstat (limited to 'test/http2_test/http2_test_server.py')
-rw-r--r-- | test/http2_test/http2_test_server.py | 164 |
1 files changed, 164 insertions, 0 deletions
diff --git a/test/http2_test/http2_test_server.py b/test/http2_test/http2_test_server.py new file mode 100644 index 0000000000..be5f1593eb --- /dev/null +++ b/test/http2_test/http2_test_server.py @@ -0,0 +1,164 @@ +""" + HTTP2 Test Server. Highly experimental work in progress. +""" +import struct +import messages_pb2 +import argparse +import logging +import time + +from twisted.internet.defer import Deferred, inlineCallbacks +from twisted.internet.protocol import Protocol, Factory +from twisted.internet import endpoints, reactor, error, defer +from h2.connection import H2Connection +from h2.events import RequestReceived, DataReceived, WindowUpdated, RemoteSettingsChanged +from threading import Lock +import http2_base_server + +READ_CHUNK_SIZE = 16384 +GRPC_HEADER_SIZE = 5 + +class TestcaseRstStreamAfterHeader(object): + def __init__(self): + self._base_server = http2_base_server.H2ProtocolBaseServer() + self._base_server._handlers['RequestReceived'] = self.on_request_received + + def get_base_server(self): + return self._base_server + + def on_request_received(self, event): + # send initial headers + self._base_server.on_request_received_default(event) + # send reset stream + self._base_server.send_reset_stream() + +class TestcaseRstStreamAfterData(object): + def __init__(self): + self._base_server = http2_base_server.H2ProtocolBaseServer() + self._base_server._handlers['DataReceived'] = self.on_data_received + + def get_base_server(self): + return self._base_server + + def on_data_received(self, event): + self._base_server.on_data_received_default(event) + sr = self._base_server.parse_received_data(self._base_server._recv_buffer) + assert(sr is not None) + assert(sr.response_size <= 2048) # so it can fit into one flow control window + response_data = self._base_server.default_response_data(sr.response_size) + self._ready_to_send = True + self._base_server.setup_send(response_data) + # send reset stream + self._base_server.send_reset_stream() + +class TestcaseGoaway(object): + """ + Process incoming request normally. After sending trailer response, + send GOAWAY with stream id = 1. + assert that the next request is made on a different connection. + """ + def __init__(self, iteration): + self._base_server = http2_base_server.H2ProtocolBaseServer() + self._base_server._handlers['RequestReceived'] = self.on_request_received + self._base_server._handlers['DataReceived'] = self.on_data_received + self._base_server._handlers['WindowUpdated'] = self.on_window_update_default + self._base_server._handlers['SendDone'] = self.on_send_done + self._base_server._handlers['ConnectionLost'] = self.on_connection_lost + self._ready_to_send = False + self._iteration = iteration + + def get_base_server(self): + return self._base_server + + def on_connection_lost(self, reason): + logging.info('Disconnect received. Count %d'%self._iteration) + # _iteration == 2 => Two different connections have been used. + if self._iteration == 2: + self._base_server.on_connection_lost(reason) + + def on_send_done(self): + self._base_server.on_send_done_default() + if self._base_server._stream_id == 1: + logging.info('Sending GOAWAY for stream 1') + self._base_server._conn.close_connection(error_code=0, additional_data=None, last_stream_id=1) + + def on_request_received(self, event): + self._ready_to_send = False + self._base_server.on_request_received_default(event) + + def on_data_received(self, event): + self._base_server.on_data_received_default(event) + sr = self._base_server.parse_received_data(self._base_server._recv_buffer) + if sr: + time.sleep(1) + logging.info('Creating response size = %s'%sr.response_size) + response_data = self._base_server.default_response_data(sr.response_size) + self._ready_to_send = True + self._base_server.setup_send(response_data) + + def on_window_update_default(self, event): + if self._ready_to_send: + self._base_server.default_send() + +class TestcasePing(object): + """ + """ + def __init__(self, iteration): + self._base_server = http2_base_server.H2ProtocolBaseServer() + self._base_server._handlers['RequestReceived'] = self.on_request_received + self._base_server._handlers['DataReceived'] = self.on_data_received + self._base_server._handlers['ConnectionLost'] = self.on_connection_lost + + def get_base_server(self): + return self._base_server + + def on_request_received(self, event): + self._base_server.default_ping() + self._base_server.on_request_received_default(event) + self._base_server.default_ping() + + def on_data_received(self, event): + self._base_server.on_data_received_default(event) + sr = self._base_server.parse_received_data(self._base_server._recv_buffer) + logging.info('Creating response size = %s'%sr.response_size) + response_data = self._base_server.default_response_data(sr.response_size) + self._base_server.default_ping() + self._base_server.setup_send(response_data) + self._base_server.default_ping() + + def on_connection_lost(self, reason): + logging.info('Disconnect received. Ping Count %d'%self._base_server._outstanding_pings) + assert(self._base_server._outstanding_pings == 0) + self._base_server.on_connection_lost(reason) + +class H2Factory(Factory): + def __init__(self, testcase): + logging.info('In H2Factory') + self._num_streams = 0 + self._testcase = testcase + + def buildProtocol(self, addr): + self._num_streams += 1 + if self._testcase == 'rst_stream_after_header': + t = TestcaseRstStreamAfterHeader(self._num_streams) + elif self._testcase == 'rst_stream_after_data': + t = TestcaseRstStreamAfterData(self._num_streams) + elif self._testcase == 'goaway': + t = TestcaseGoaway(self._num_streams) + elif self._testcase == 'ping': + t = TestcasePing(self._num_streams) + else: + assert(0) + return t.get_base_server() + +if __name__ == "__main__": + logging.basicConfig(format = "%(levelname) -10s %(asctime)s %(module)s:%(lineno)s | %(message)s", level=logging.INFO) + parser = argparse.ArgumentParser() + parser.add_argument("test") + parser.add_argument("port") + args = parser.parse_args() + if args.test not in ['rst_stream_after_header', 'rst_stream_after_data', 'goaway', 'ping']: + print 'unknown test: ', args.test + endpoint = endpoints.TCP4ServerEndpoint(reactor, int(args.port), backlog=128) + endpoint.listen(H2Factory(args.test)) + reactor.run() |