From a16ea7f9b129d34edbdb867e720d4a2ea71e9a7f Mon Sep 17 00:00:00 2001 From: Makarand Dharmapurikar Date: Fri, 2 Dec 2016 10:17:03 -0800 Subject: added new test (rst_during_data) --- test/http2_test/http2_test_server.py | 41 +++++++++++++++++++++--------------- 1 file changed, 24 insertions(+), 17 deletions(-) (limited to 'test/http2_test/http2_test_server.py') diff --git a/test/http2_test/http2_test_server.py b/test/http2_test/http2_test_server.py index c74fc4b1fb..5270ee4255 100644 --- a/test/http2_test/http2_test_server.py +++ b/test/http2_test/http2_test_server.py @@ -9,10 +9,20 @@ from twisted.internet import endpoints, reactor import http2_base_server import test_rst_after_header import test_rst_after_data +import test_rst_during_data import test_goaway import test_ping import test_max_streams +test_case_mappings = { + 'rst_after_header': test_rst_after_header.TestcaseRstStreamAfterHeader, + 'rst_after_data': test_rst_after_data.TestcaseRstStreamAfterData, + 'rst_during_data': test_rst_during_data.TestcaseRstStreamDuringData, + 'goaway': test_goaway.TestcaseGoaway, + 'ping': test_ping.TestcasePing, + 'max_streams': test_max_streams.TestcaseSettingsMaxStreams, +} + class H2Factory(Factory): def __init__(self, testcase): logging.info('In H2Factory') @@ -22,20 +32,16 @@ class H2Factory(Factory): def buildProtocol(self, addr): self._num_streams += 1 logging.info('New Connection: %d'%self._num_streams) - if self._testcase == 'rst_after_header': - t = test_rst_after_header.TestcaseRstStreamAfterHeader() - elif self._testcase == 'rst_after_data': - t = test_rst_after_data.TestcaseRstStreamAfterData() - elif self._testcase == 'goaway': - t = test_goaway.TestcaseGoaway(self._num_streams) - elif self._testcase == 'ping': - t = test_ping.TestcasePing() - elif self._testcase == 'max_streams': - t = test_max_streams.TestcaseSettingsMaxStreams() - else: + if not test_case_mappings.has_key(self._testcase): logging.error('Unknown test case: %s'%self._testcase) assert(0) - return t.get_base_server() + else: + t = test_case_mappings[self._testcase] + + if self._testcase == 'goaway': + return t(self._num_streams).get_base_server() + else: + return t().get_base_server() if __name__ == "__main__": logging.basicConfig(format = "%(levelname) -10s %(asctime)s %(module)s:%(lineno)s | %(message)s", level=logging.INFO) @@ -43,8 +49,9 @@ if __name__ == "__main__": parser.add_argument("test") parser.add_argument("port") args = parser.parse_args() - if args.test not in ['rst_after_header', 'rst_after_data', 'goaway', 'ping', 'max_streams']: - print 'unknown test: ', args.test - endpoint = endpoints.TCP4ServerEndpoint(reactor, int(args.port), backlog=128) - endpoint.listen(H2Factory(args.test)) - reactor.run() + if args.test not in test_case_mappings.keys(): + logging.error('unknown test: %s'%args.test) + else: + endpoint = endpoints.TCP4ServerEndpoint(reactor, int(args.port), backlog=128) + endpoint.listen(H2Factory(args.test)) + reactor.run() -- cgit v1.2.3