diff options
-rw-r--r-- | src/python/grpcio_tests/tests/qps/benchmark_client.py | 8 | ||||
-rw-r--r-- | src/python/grpcio_tests/tests/stress/client.py | 9 | ||||
-rw-r--r-- | test/http2_test/http2_base_server.py | 1 | ||||
-rw-r--r-- | test/http2_test/http2_test_server.py | 34 |
4 files changed, 32 insertions, 20 deletions
diff --git a/src/python/grpcio_tests/tests/qps/benchmark_client.py b/src/python/grpcio_tests/tests/qps/benchmark_client.py index 83b46c914e..650e4756e7 100644 --- a/src/python/grpcio_tests/tests/qps/benchmark_client.py +++ b/src/python/grpcio_tests/tests/qps/benchmark_client.py @@ -68,12 +68,8 @@ class BenchmarkClient: else: channel = grpc.insecure_channel(server) - connected_event = threading.Event() - def wait_for_ready(connectivity): - if connectivity == grpc.ChannelConnectivity.READY: - connected_event.set() - channel.subscribe(wait_for_ready, try_to_connect=True) - connected_event.wait() + # waits for the channel to be ready before we start sending messages + grpc.channel_ready_future(channel).result() if config.payload_config.WhichOneof('payload') == 'simple_params': self._generic = False diff --git a/src/python/grpcio_tests/tests/stress/client.py b/src/python/grpcio_tests/tests/stress/client.py index 390ea13021..b8116729b5 100644 --- a/src/python/grpcio_tests/tests/stress/client.py +++ b/src/python/grpcio_tests/tests/stress/client.py @@ -110,10 +110,13 @@ def _get_channel(target, args): channel_credentials = grpc.ssl_channel_credentials( root_certificates=root_certificates) options = (('grpc.ssl_target_name_override', args.server_host_override,),) - return grpc.secure_channel( - target, channel_credentials, options=options) + channel = grpc.secure_channel(target, channel_credentials, options=options) else: - return grpc.insecure_channel(target) + channel = grpc.insecure_channel(target) + + # waits for the channel to be ready before we start sending messages + grpc.channel_ready_future(channel).result() + return channel def run_test(args): test_cases = _parse_weighted_test_cases(args.test_cases) diff --git a/test/http2_test/http2_base_server.py b/test/http2_test/http2_base_server.py index ee7719b1a8..8de028ceb1 100644 --- a/test/http2_test/http2_base_server.py +++ b/test/http2_test/http2_base_server.py @@ -73,7 +73,6 @@ class H2ProtocolBaseServer(twisted.internet.protocol.Protocol): def on_connection_lost(self, reason): logging.info('Disconnected %s' % reason) - twisted.internet.reactor.callFromThread(twisted.internet.reactor.stop) def dataReceived(self, data): try: diff --git a/test/http2_test/http2_test_server.py b/test/http2_test/http2_test_server.py index 44e36d34b6..abde3433ad 100644 --- a/test/http2_test/http2_test_server.py +++ b/test/http2_test/http2_test_server.py @@ -73,18 +73,32 @@ class H2Factory(twisted.internet.protocol.Factory): else: return t().get_base_server() +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--base_port', type=int, default=8080, + help='base port to run the servers (default: 8080). One test server is ' + 'started on each incrementing port, beginning with base_port, in the ' + 'following order: goaway,max_streams,ping,rst_after_data,rst_after_header,' + 'rst_during_data' + ) + return parser.parse_args() + +def start_test_servers(base_port): + """ Start one server per test case on incrementing port numbers + beginning with base_port """ + index = 0 + for test_case in sorted(_TEST_CASE_MAPPING.keys()): + portnum = base_port + index + logging.warning('serving on port %d : %s'%(portnum, test_case)) + endpoint = twisted.internet.endpoints.TCP4ServerEndpoint( + twisted.internet.reactor, portnum, backlog=128) + endpoint.listen(H2Factory(test_case)) + index += 1 + if __name__ == '__main__': logging.basicConfig( format='%(levelname) -10s %(asctime)s %(module)s:%(lineno)s | %(message)s', level=logging.INFO) - parser = argparse.ArgumentParser() - parser.add_argument('--test_case', choices=sorted(_TEST_CASE_MAPPING.keys()), - help='test case to run', required=True) - parser.add_argument('--port', type=int, default=8080, - help='port to run the server (default: 8080)') - args = parser.parse_args() - logging.info('Running test case %s on port %d' % (args.test_case, args.port)) - endpoint = twisted.internet.endpoints.TCP4ServerEndpoint( - twisted.internet.reactor, args.port, backlog=128) - endpoint.listen(H2Factory(args.test_case)) + args = parse_arguments() + start_test_servers(args.base_port) twisted.internet.reactor.run() |