aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--src/python/grpcio_tests/tests/qps/benchmark_client.py8
-rw-r--r--src/python/grpcio_tests/tests/stress/client.py9
-rw-r--r--test/http2_test/http2_base_server.py1
-rw-r--r--test/http2_test/http2_test_server.py34
4 files changed, 32 insertions, 20 deletions
diff --git a/src/python/grpcio_tests/tests/qps/benchmark_client.py b/src/python/grpcio_tests/tests/qps/benchmark_client.py
index 83b46c914e..650e4756e7 100644
--- a/src/python/grpcio_tests/tests/qps/benchmark_client.py
+++ b/src/python/grpcio_tests/tests/qps/benchmark_client.py
@@ -68,12 +68,8 @@ class BenchmarkClient:
else:
channel = grpc.insecure_channel(server)
- connected_event = threading.Event()
- def wait_for_ready(connectivity):
- if connectivity == grpc.ChannelConnectivity.READY:
- connected_event.set()
- channel.subscribe(wait_for_ready, try_to_connect=True)
- connected_event.wait()
+ # waits for the channel to be ready before we start sending messages
+ grpc.channel_ready_future(channel).result()
if config.payload_config.WhichOneof('payload') == 'simple_params':
self._generic = False
diff --git a/src/python/grpcio_tests/tests/stress/client.py b/src/python/grpcio_tests/tests/stress/client.py
index 390ea13021..b8116729b5 100644
--- a/src/python/grpcio_tests/tests/stress/client.py
+++ b/src/python/grpcio_tests/tests/stress/client.py
@@ -110,10 +110,13 @@ def _get_channel(target, args):
channel_credentials = grpc.ssl_channel_credentials(
root_certificates=root_certificates)
options = (('grpc.ssl_target_name_override', args.server_host_override,),)
- return grpc.secure_channel(
- target, channel_credentials, options=options)
+ channel = grpc.secure_channel(target, channel_credentials, options=options)
else:
- return grpc.insecure_channel(target)
+ channel = grpc.insecure_channel(target)
+
+ # waits for the channel to be ready before we start sending messages
+ grpc.channel_ready_future(channel).result()
+ return channel
def run_test(args):
test_cases = _parse_weighted_test_cases(args.test_cases)
diff --git a/test/http2_test/http2_base_server.py b/test/http2_test/http2_base_server.py
index ee7719b1a8..8de028ceb1 100644
--- a/test/http2_test/http2_base_server.py
+++ b/test/http2_test/http2_base_server.py
@@ -73,7 +73,6 @@ class H2ProtocolBaseServer(twisted.internet.protocol.Protocol):
def on_connection_lost(self, reason):
logging.info('Disconnected %s' % reason)
- twisted.internet.reactor.callFromThread(twisted.internet.reactor.stop)
def dataReceived(self, data):
try:
diff --git a/test/http2_test/http2_test_server.py b/test/http2_test/http2_test_server.py
index 44e36d34b6..abde3433ad 100644
--- a/test/http2_test/http2_test_server.py
+++ b/test/http2_test/http2_test_server.py
@@ -73,18 +73,32 @@ class H2Factory(twisted.internet.protocol.Factory):
else:
return t().get_base_server()
+def parse_arguments():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--base_port', type=int, default=8080,
+ help='base port to run the servers (default: 8080). One test server is '
+ 'started on each incrementing port, beginning with base_port, in the '
+ 'following order: goaway,max_streams,ping,rst_after_data,rst_after_header,'
+ 'rst_during_data'
+ )
+ return parser.parse_args()
+
+def start_test_servers(base_port):
+ """ Start one server per test case on incrementing port numbers
+ beginning with base_port """
+ index = 0
+ for test_case in sorted(_TEST_CASE_MAPPING.keys()):
+ portnum = base_port + index
+ logging.warning('serving on port %d : %s'%(portnum, test_case))
+ endpoint = twisted.internet.endpoints.TCP4ServerEndpoint(
+ twisted.internet.reactor, portnum, backlog=128)
+ endpoint.listen(H2Factory(test_case))
+ index += 1
+
if __name__ == '__main__':
logging.basicConfig(
format='%(levelname) -10s %(asctime)s %(module)s:%(lineno)s | %(message)s',
level=logging.INFO)
- parser = argparse.ArgumentParser()
- parser.add_argument('--test_case', choices=sorted(_TEST_CASE_MAPPING.keys()),
- help='test case to run', required=True)
- parser.add_argument('--port', type=int, default=8080,
- help='port to run the server (default: 8080)')
- args = parser.parse_args()
- logging.info('Running test case %s on port %d' % (args.test_case, args.port))
- endpoint = twisted.internet.endpoints.TCP4ServerEndpoint(
- twisted.internet.reactor, args.port, backlog=128)
- endpoint.listen(H2Factory(args.test_case))
+ args = parse_arguments()
+ start_test_servers(args.base_port)
twisted.internet.reactor.run()