diff options
Diffstat (limited to 'src/python/grpcio/tests/unit')
15 files changed, 1189 insertions, 273 deletions
diff --git a/src/python/grpcio/tests/unit/_compression_test.py b/src/python/grpcio/tests/unit/_compression_test.py new file mode 100644 index 0000000000..9e8b8578c1 --- /dev/null +++ b/src/python/grpcio/tests/unit/_compression_test.py @@ -0,0 +1,133 @@ +# Copyright 2016, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +"""Tests server and client side compression.""" + +import unittest + +import grpc +from grpc import _grpcio_metadata +from grpc.framework.foundation import logging_pool + +from tests.unit import test_common +from tests.unit.framework.common import test_constants + +_UNARY_UNARY = '/test/UnaryUnary' +_STREAM_STREAM = '/test/StreamStream' + + +def handle_unary(request, servicer_context): + servicer_context.send_initial_metadata([ + ('grpc-internal-encoding-request', 'gzip')]) + return request + + +def handle_stream(request_iterator, servicer_context): + # TODO(issue:#6891) We should be able to remove this loop, + # and replace with return; yield + servicer_context.send_initial_metadata([ + ('grpc-internal-encoding-request', 'gzip')]) + for request in request_iterator: + yield request + + +class _MethodHandler(grpc.RpcMethodHandler): + + def __init__(self, request_streaming, response_streaming): + self.request_streaming = request_streaming + self.response_streaming = response_streaming + self.request_deserializer = None + self.response_serializer = None + self.unary_unary = None + self.unary_stream = None + self.stream_unary = None + self.stream_stream = None + if self.request_streaming and self.response_streaming: + self.stream_stream = lambda x, y: handle_stream(x, y) + elif not self.request_streaming and not self.response_streaming: + self.unary_unary = lambda x, y: handle_unary(x, y) + + +class _GenericHandler(grpc.GenericRpcHandler): + + def service(self, handler_call_details): + if handler_call_details.method == _UNARY_UNARY: + return _MethodHandler(False, False) + elif handler_call_details.method == _STREAM_STREAM: + return _MethodHandler(True, True) + else: + return None + + +class CompressionTest(unittest.TestCase): + + def setUp(self): + self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY) + self._server = grpc.server((_GenericHandler(),), self._server_pool) + self._port = self._server.add_insecure_port('[::]:0') + self._server.start() + + def testUnary(self): + request = b'\x00' * 100 + + # Client -> server compressed through default client channel compression + # settings. Server -> client compressed via server-side metadata setting. + # TODO(https://github.com/grpc/grpc/issues/4078): replace the "1" integer + # literal with proper use of the public API. + compressed_channel = grpc.insecure_channel('localhost:%d' % self._port, + options=[('grpc.default_compression_algorithm', 1)]) + multi_callable = compressed_channel.unary_unary(_UNARY_UNARY) + response = multi_callable(request) + self.assertEqual(request, response) + + # Client -> server compressed through client metadata setting. Server -> + # client compressed via server-side metadata setting. + # TODO(https://github.com/grpc/grpc/issues/4078): replace the "0" integer + # literal with proper use of the public API. + uncompressed_channel = grpc.insecure_channel('localhost:%d' % self._port, + options=[('grpc.default_compression_algorithm', 0)]) + multi_callable = compressed_channel.unary_unary(_UNARY_UNARY) + response = multi_callable(request, metadata=[ + ('grpc-internal-encoding-request', 'gzip')]) + self.assertEqual(request, response) + + def testStreaming(self): + request = b'\x00' * 100 + + # TODO(https://github.com/grpc/grpc/issues/4078): replace the "1" integer + # literal with proper use of the public API. + compressed_channel = grpc.insecure_channel('localhost:%d' % self._port, + options=[('grpc.default_compression_algorithm', 1)]) + multi_callable = compressed_channel.stream_stream(_STREAM_STREAM) + call = multi_callable([request] * test_constants.STREAM_LENGTH) + for response in call: + self.assertEqual(request, response) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/src/python/grpcio/tests/unit/_cython/_cancel_many_calls_test.py b/src/python/grpcio/tests/unit/_cython/_cancel_many_calls_test.py index c1de779014..cac0c8b3b9 100644 --- a/src/python/grpcio/tests/unit/_cython/_cancel_many_calls_test.py +++ b/src/python/grpcio/tests/unit/_cython/_cancel_many_calls_test.py @@ -159,9 +159,9 @@ class CancelManyCallsTest(unittest.TestCase): server_completion_queue = cygrpc.CompletionQueue() server = cygrpc.Server() server.register_completion_queue(server_completion_queue) - port = server.add_http2_port('[::]:0') + port = server.add_http2_port(b'[::]:0') server.start() - channel = cygrpc.Channel('localhost:{}'.format(port)) + channel = cygrpc.Channel('localhost:{}'.format(port).encode()) state = _State() diff --git a/src/python/grpcio/tests/unit/_cython/_channel_test.py b/src/python/grpcio/tests/unit/_cython/_channel_test.py index 3dc7a246ae..f9c8a3ac62 100644 --- a/src/python/grpcio/tests/unit/_cython/_channel_test.py +++ b/src/python/grpcio/tests/unit/_cython/_channel_test.py @@ -37,7 +37,7 @@ from tests.unit.framework.common import test_constants def _channel_and_completion_queue(): - channel = cygrpc.Channel('localhost:54321', cygrpc.ChannelArgs(())) + channel = cygrpc.Channel(b'localhost:54321', cygrpc.ChannelArgs(())) completion_queue = cygrpc.CompletionQueue() return channel, completion_queue diff --git a/src/python/grpcio/tests/unit/_cython/_read_some_but_not_all_responses_test.py b/src/python/grpcio/tests/unit/_cython/_read_some_but_not_all_responses_test.py index 6ae7a90fbe..27fcee0d6f 100644 --- a/src/python/grpcio/tests/unit/_cython/_read_some_but_not_all_responses_test.py +++ b/src/python/grpcio/tests/unit/_cython/_read_some_but_not_all_responses_test.py @@ -126,9 +126,9 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase): server_completion_queue = cygrpc.CompletionQueue() server = cygrpc.Server() server.register_completion_queue(server_completion_queue) - port = server.add_http2_port('[::]:0') + port = server.add_http2_port(b'[::]:0') server.start() - channel = cygrpc.Channel('localhost:{}'.format(port)) + channel = cygrpc.Channel('localhost:{}'.format(port).encode()) server_shutdown_tag = 'server_shutdown_tag' server_driver = _ServerDriver(server_completion_queue, server_shutdown_tag) diff --git a/src/python/grpcio/tests/unit/_cython/cygrpc_test.py b/src/python/grpcio/tests/unit/_cython/cygrpc_test.py index a006a20ce3..b740695e35 100644 --- a/src/python/grpcio/tests/unit/_cython/cygrpc_test.py +++ b/src/python/grpcio/tests/unit/_cython/cygrpc_test.py @@ -46,38 +46,38 @@ def _metadata_plugin_callback(context, callback): callback(cygrpc.Metadata( [cygrpc.Metadatum(_CALL_CREDENTIALS_METADATA_KEY, _CALL_CREDENTIALS_METADATA_VALUE)]), - cygrpc.StatusCode.ok, '') + cygrpc.StatusCode.ok, b'') class TypeSmokeTest(unittest.TestCase): def testStringsInUtilitiesUpDown(self): self.assertEqual(0, cygrpc.StatusCode.ok) - metadatum = cygrpc.Metadatum('a', 'b') - self.assertEqual('a'.encode(), metadatum.key) - self.assertEqual('b'.encode(), metadatum.value) + metadatum = cygrpc.Metadatum(b'a', b'b') + self.assertEqual(b'a', metadatum.key) + self.assertEqual(b'b', metadatum.value) metadata = cygrpc.Metadata([metadatum]) self.assertEqual(1, len(metadata)) self.assertEqual(metadatum.key, metadata[0].key) def testMetadataIteration(self): metadata = cygrpc.Metadata([ - cygrpc.Metadatum('a', 'b'), cygrpc.Metadatum('c', 'd')]) + cygrpc.Metadatum(b'a', b'b'), cygrpc.Metadatum(b'c', b'd')]) iterator = iter(metadata) metadatum = next(iterator) self.assertIsInstance(metadatum, cygrpc.Metadatum) - self.assertEqual(metadatum.key, 'a'.encode()) - self.assertEqual(metadatum.value, 'b'.encode()) + self.assertEqual(metadatum.key, b'a') + self.assertEqual(metadatum.value, b'b') metadatum = next(iterator) self.assertIsInstance(metadatum, cygrpc.Metadatum) - self.assertEqual(metadatum.key, 'c'.encode()) - self.assertEqual(metadatum.value, 'd'.encode()) + self.assertEqual(metadatum.key, b'c') + self.assertEqual(metadatum.value, b'd') with self.assertRaises(StopIteration): next(iterator) def testOperationsIteration(self): operations = cygrpc.Operations([ - cygrpc.operation_send_message('asdf', _EMPTY_FLAGS)]) + cygrpc.operation_send_message(b'asdf', _EMPTY_FLAGS)]) iterator = iter(operations) operation = next(iterator) self.assertIsInstance(operation, cygrpc.Operation) @@ -87,7 +87,7 @@ class TypeSmokeTest(unittest.TestCase): next(iterator) def testOperationFlags(self): - operation = cygrpc.operation_send_message('asdf', + operation = cygrpc.operation_send_message(b'asdf', cygrpc.WriteFlag.no_compress) self.assertEqual(cygrpc.WriteFlag.no_compress, operation.flags) @@ -105,16 +105,16 @@ class TypeSmokeTest(unittest.TestCase): del server def testChannelUpDown(self): - channel = cygrpc.Channel('[::]:0', cygrpc.ChannelArgs([])) + channel = cygrpc.Channel(b'[::]:0', cygrpc.ChannelArgs([])) del channel def testCredentialsMetadataPluginUpDown(self): plugin = cygrpc.CredentialsMetadataPlugin( - lambda ignored_a, ignored_b: None, '') + lambda ignored_a, ignored_b: None, b'') del plugin def testCallCredentialsFromPluginUpDown(self): - plugin = cygrpc.CredentialsMetadataPlugin(_metadata_plugin_callback, '') + plugin = cygrpc.CredentialsMetadataPlugin(_metadata_plugin_callback, b'') call_credentials = cygrpc.call_credentials_metadata_plugin(plugin) del plugin del call_credentials @@ -123,7 +123,7 @@ class TypeSmokeTest(unittest.TestCase): server = cygrpc.Server() completion_queue = cygrpc.CompletionQueue() server.register_completion_queue(completion_queue) - port = server.add_http2_port('[::]:0') + port = server.add_http2_port(b'[::]:0') self.assertIsInstance(port, int) server.start() del server @@ -131,7 +131,7 @@ class TypeSmokeTest(unittest.TestCase): def testServerStartShutdown(self): completion_queue = cygrpc.CompletionQueue() server = cygrpc.Server() - server.add_http2_port('[::]:0') + server.add_http2_port(b'[::]:0') server.register_completion_queue(completion_queue) server.start() shutdown_tag = object() @@ -150,9 +150,9 @@ class ServerClientMixin(object): self.server = cygrpc.Server() self.server.register_completion_queue(self.server_completion_queue) if server_credentials: - self.port = self.server.add_http2_port('[::]:0', server_credentials) + self.port = self.server.add_http2_port(b'[::]:0', server_credentials) else: - self.port = self.server.add_http2_port('[::]:0') + self.port = self.server.add_http2_port(b'[::]:0') self.server.start() self.client_completion_queue = cygrpc.CompletionQueue() if client_credentials: @@ -160,10 +160,10 @@ class ServerClientMixin(object): cygrpc.ChannelArg(cygrpc.ChannelArgKey.ssl_target_name_override, host_override)]) self.client_channel = cygrpc.Channel( - 'localhost:{}'.format(self.port), client_channel_arguments, + 'localhost:{}'.format(self.port).encode(), client_channel_arguments, client_credentials) else: - self.client_channel = cygrpc.Channel('localhost:{}'.format(self.port)) + self.client_channel = cygrpc.Channel('localhost:{}'.format(self.port).encode()) if host_override: self.host_argument = None # default host self.expected_host = host_override diff --git a/src/python/grpcio/tests/unit/_empty_message_test.py b/src/python/grpcio/tests/unit/_empty_message_test.py index f324f6216b..8c7d697728 100644 --- a/src/python/grpcio/tests/unit/_empty_message_test.py +++ b/src/python/grpcio/tests/unit/_empty_message_test.py @@ -37,10 +37,10 @@ from tests.unit.framework.common import test_constants _REQUEST = b'' _RESPONSE = b'' -_UNARY_UNARY = b'/test/UnaryUnary' -_UNARY_STREAM = b'/test/UnaryStream' -_STREAM_UNARY = b'/test/StreamUnary' -_STREAM_STREAM = b'/test/StreamStream' +_UNARY_UNARY = '/test/UnaryUnary' +_UNARY_STREAM = '/test/UnaryStream' +_STREAM_UNARY = '/test/StreamUnary' +_STREAM_STREAM = '/test/StreamStream' def handle_unary_unary(request, servicer_context): diff --git a/src/python/grpcio/tests/unit/_exit_scenarios.py b/src/python/grpcio/tests/unit/_exit_scenarios.py new file mode 100644 index 0000000000..24a2faef85 --- /dev/null +++ b/src/python/grpcio/tests/unit/_exit_scenarios.py @@ -0,0 +1,249 @@ +# Copyright 2016, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Defines a number of module-scope gRPC scenarios to test clean exit.""" + +import argparse +import threading +import time + +import grpc + +from tests.unit.framework.common import test_constants + +WAIT_TIME = 1000 + +REQUEST = b'request' + +UNSTARTED_SERVER = 'unstarted_server' +RUNNING_SERVER = 'running_server' +POLL_CONNECTIVITY_NO_SERVER = 'poll_connectivity_no_server' +POLL_CONNECTIVITY = 'poll_connectivity' +IN_FLIGHT_UNARY_UNARY_CALL = 'in_flight_unary_unary_call' +IN_FLIGHT_UNARY_STREAM_CALL = 'in_flight_unary_stream_call' +IN_FLIGHT_STREAM_UNARY_CALL = 'in_flight_stream_unary_call' +IN_FLIGHT_STREAM_STREAM_CALL = 'in_flight_stream_stream_call' +IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL = 'in_flight_partial_unary_stream_call' +IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL = 'in_flight_partial_stream_unary_call' +IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL = 'in_flight_partial_stream_stream_call' + +UNARY_UNARY = b'/test/UnaryUnary' +UNARY_STREAM = b'/test/UnaryStream' +STREAM_UNARY = b'/test/StreamUnary' +STREAM_STREAM = b'/test/StreamStream' +PARTIAL_UNARY_STREAM = b'/test/PartialUnaryStream' +PARTIAL_STREAM_UNARY = b'/test/PartialStreamUnary' +PARTIAL_STREAM_STREAM = b'/test/PartialStreamStream' + +TEST_TO_METHOD = { + IN_FLIGHT_UNARY_UNARY_CALL: UNARY_UNARY, + IN_FLIGHT_UNARY_STREAM_CALL: UNARY_STREAM, + IN_FLIGHT_STREAM_UNARY_CALL: STREAM_UNARY, + IN_FLIGHT_STREAM_STREAM_CALL: STREAM_STREAM, + IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL: PARTIAL_UNARY_STREAM, + IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL: PARTIAL_STREAM_UNARY, + IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL: PARTIAL_STREAM_STREAM, +} + + +def hang_unary_unary(request, servicer_context): + time.sleep(WAIT_TIME) + + +def hang_unary_stream(request, servicer_context): + time.sleep(WAIT_TIME) + + +def hang_partial_unary_stream(request, servicer_context): + for _ in range(test_constants.STREAM_LENGTH // 2): + yield request + time.sleep(WAIT_TIME) + + +def hang_stream_unary(request_iterator, servicer_context): + time.sleep(WAIT_TIME) + + +def hang_partial_stream_unary(request_iterator, servicer_context): + for _ in range(test_constants.STREAM_LENGTH // 2): + next(request_iterator) + time.sleep(WAIT_TIME) + + +def hang_stream_stream(request_iterator, servicer_context): + time.sleep(WAIT_TIME) + + +def hang_partial_stream_stream(request_iterator, servicer_context): + for _ in range(test_constants.STREAM_LENGTH // 2): + yield next(request_iterator) + time.sleep(WAIT_TIME) + + +class MethodHandler(grpc.RpcMethodHandler): + + def __init__(self, request_streaming, response_streaming, partial_hang): + self.request_streaming = request_streaming + self.response_streaming = response_streaming + self.request_deserializer = None + self.response_serializer = None + self.unary_unary = None + self.unary_stream = None + self.stream_unary = None + self.stream_stream = None + if self.request_streaming and self.response_streaming: + if partial_hang: + self.stream_stream = hang_partial_stream_stream + else: + self.stream_stream = hang_stream_stream + elif self.request_streaming: + if partial_hang: + self.stream_unary = hang_partial_stream_unary + else: + self.stream_unary = hang_stream_unary + elif self.response_streaming: + if partial_hang: + self.unary_stream = hang_partial_unary_stream + else: + self.unary_stream = hang_unary_stream + else: + self.unary_unary = hang_unary_unary + + +class GenericHandler(grpc.GenericRpcHandler): + + def service(self, handler_call_details): + if handler_call_details.method == UNARY_UNARY: + return MethodHandler(False, False, False) + elif handler_call_details.method == UNARY_STREAM: + return MethodHandler(False, True, False) + elif handler_call_details.method == STREAM_UNARY: + return MethodHandler(True, False, False) + elif handler_call_details.method == STREAM_STREAM: + return MethodHandler(True, True, False) + elif handler_call_details.method == PARTIAL_UNARY_STREAM: + return MethodHandler(False, True, True) + elif handler_call_details.method == PARTIAL_STREAM_UNARY: + return MethodHandler(True, False, True) + elif handler_call_details.method == PARTIAL_STREAM_STREAM: + return MethodHandler(True, True, True) + else: + return None + + +# Traditional executors will not exit until all their +# current jobs complete. Because we submit jobs that will +# never finish, we don't want to block exit on these jobs. +class DaemonPool(object): + + def submit(self, fn, *args, **kwargs): + thread = threading.Thread(target=fn, args=args, kwargs=kwargs) + thread.daemon = True + thread.start() + + def shutdown(self, wait=True): + pass + + +def infinite_request_iterator(): + while True: + yield REQUEST + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('scenario', type=str) + parser.add_argument( + '--wait_for_interrupt', dest='wait_for_interrupt', action='store_true') + args = parser.parse_args() + + if args.scenario == UNSTARTED_SERVER: + server = grpc.server((), DaemonPool()) + if args.wait_for_interrupt: + time.sleep(WAIT_TIME) + elif args.scenario == RUNNING_SERVER: + server = grpc.server((), DaemonPool()) + port = server.add_insecure_port('[::]:0') + server.start() + if args.wait_for_interrupt: + time.sleep(WAIT_TIME) + elif args.scenario == POLL_CONNECTIVITY_NO_SERVER: + channel = grpc.insecure_channel('localhost:12345') + + def connectivity_callback(connectivity): + pass + + channel.subscribe(connectivity_callback, try_to_connect=True) + if args.wait_for_interrupt: + time.sleep(WAIT_TIME) + elif args.scenario == POLL_CONNECTIVITY: + server = grpc.server((), DaemonPool()) + port = server.add_insecure_port('[::]:0') + server.start() + channel = grpc.insecure_channel('localhost:%d' % port) + + def connectivity_callback(connectivity): + pass + + channel.subscribe(connectivity_callback, try_to_connect=True) + if args.wait_for_interrupt: + time.sleep(WAIT_TIME) + + else: + handler = GenericHandler() + server = grpc.server((), DaemonPool()) + port = server.add_insecure_port('[::]:0') + server.add_generic_rpc_handlers((handler,)) + server.start() + channel = grpc.insecure_channel('localhost:%d' % port) + + method = TEST_TO_METHOD[args.scenario] + + if args.scenario == IN_FLIGHT_UNARY_UNARY_CALL: + multi_callable = channel.unary_unary(method) + future = multi_callable.future(REQUEST) + result, call = multi_callable.with_call(REQUEST) + elif (args.scenario == IN_FLIGHT_UNARY_STREAM_CALL or + args.scenario == IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL): + multi_callable = channel.unary_stream(method) + response_iterator = multi_callable(REQUEST) + for response in response_iterator: + pass + elif (args.scenario == IN_FLIGHT_STREAM_UNARY_CALL or + args.scenario == IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL): + multi_callable = channel.stream_unary(method) + future = multi_callable.future(infinite_request_iterator()) + result, call = multi_callable.with_call( + [REQUEST] * test_constants.STREAM_LENGTH) + elif (args.scenario == IN_FLIGHT_STREAM_STREAM_CALL or + args.scenario == IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL): + multi_callable = channel.stream_stream(method) + response_iterator = multi_callable(infinite_request_iterator()) + for response in response_iterator: + pass diff --git a/src/python/grpcio/tests/unit/_exit_test.py b/src/python/grpcio/tests/unit/_exit_test.py new file mode 100644 index 0000000000..b0d6af73e5 --- /dev/null +++ b/src/python/grpcio/tests/unit/_exit_test.py @@ -0,0 +1,185 @@ +# Copyright 2016, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests clean exit of server/client on Python Interpreter exit/sigint. + +The tests in this module spawn a subprocess for each test case, the +test is considered successful if it doesn't hang/timeout. +""" + +import atexit +import os +import signal +import six +import subprocess +import sys +import threading +import time +import unittest + +from tests.unit import _exit_scenarios + +SCENARIO_FILE = os.path.abspath(os.path.join( + os.path.dirname(os.path.realpath(__file__)), '_exit_scenarios.py')) +INTERPRETER = sys.executable +BASE_COMMAND = [INTERPRETER, SCENARIO_FILE] +BASE_SIGTERM_COMMAND = BASE_COMMAND + ['--wait_for_interrupt'] + +INIT_TIME = 1.0 + + +processes = [] +process_lock = threading.Lock() + + +# Make sure we attempt to clean up any +# processes we may have left running +def cleanup_processes(): + with process_lock: + for process in processes: + try: + process.kill() + except Exception: + pass +atexit.register(cleanup_processes) + + +def interrupt_and_wait(process): + with process_lock: + processes.append(process) + time.sleep(INIT_TIME) + os.kill(process.pid, signal.SIGINT) + process.wait() + + +def wait(process): + with process_lock: + processes.append(process) + process.wait() + + +class ExitTest(unittest.TestCase): + + def test_unstarted_server(self): + process = subprocess.Popen( + BASE_COMMAND + [_exit_scenarios.UNSTARTED_SERVER], + stdout=sys.stdout, stderr=sys.stderr) + wait(process) + + def test_unstarted_server_terminate(self): + process = subprocess.Popen( + BASE_SIGTERM_COMMAND + [_exit_scenarios.UNSTARTED_SERVER], + stdout=sys.stdout) + interrupt_and_wait(process) + + def test_running_server(self): + process = subprocess.Popen( + BASE_COMMAND + [_exit_scenarios.RUNNING_SERVER], + stdout=sys.stdout, stderr=sys.stderr) + wait(process) + + def test_running_server_terminate(self): + process = subprocess.Popen( + BASE_SIGTERM_COMMAND + [_exit_scenarios.RUNNING_SERVER], + stdout=sys.stdout, stderr=sys.stderr) + interrupt_and_wait(process) + + def test_poll_connectivity_no_server(self): + process = subprocess.Popen( + BASE_COMMAND + [_exit_scenarios.POLL_CONNECTIVITY_NO_SERVER], + stdout=sys.stdout, stderr=sys.stderr) + wait(process) + + def test_poll_connectivity_no_server_terminate(self): + process = subprocess.Popen( + BASE_SIGTERM_COMMAND + [_exit_scenarios.POLL_CONNECTIVITY_NO_SERVER], + stdout=sys.stdout, stderr=sys.stderr) + interrupt_and_wait(process) + + def test_poll_connectivity(self): + process = subprocess.Popen( + BASE_COMMAND + [_exit_scenarios.POLL_CONNECTIVITY], + stdout=sys.stdout, stderr=sys.stderr) + wait(process) + + def test_poll_connectivity_terminate(self): + process = subprocess.Popen( + BASE_SIGTERM_COMMAND + [_exit_scenarios.POLL_CONNECTIVITY], + stdout=sys.stdout, stderr=sys.stderr) + interrupt_and_wait(process) + + def test_in_flight_unary_unary_call(self): + process = subprocess.Popen( + BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_UNARY_UNARY_CALL], + stdout=sys.stdout, stderr=sys.stderr) + interrupt_and_wait(process) + + @unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999') + def test_in_flight_unary_stream_call(self): + process = subprocess.Popen( + BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_UNARY_STREAM_CALL], + stdout=sys.stdout, stderr=sys.stderr) + interrupt_and_wait(process) + + def test_in_flight_stream_unary_call(self): + process = subprocess.Popen( + BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_STREAM_UNARY_CALL], + stdout=sys.stdout, stderr=sys.stderr) + interrupt_and_wait(process) + + @unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999') + def test_in_flight_stream_stream_call(self): + process = subprocess.Popen( + BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_STREAM_STREAM_CALL], + stdout=sys.stdout, stderr=sys.stderr) + interrupt_and_wait(process) + + @unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999') + def test_in_flight_partial_unary_stream_call(self): + process = subprocess.Popen( + BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL], + stdout=sys.stdout, stderr=sys.stderr) + interrupt_and_wait(process) + + def test_in_flight_partial_stream_unary_call(self): + process = subprocess.Popen( + BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL], + stdout=sys.stdout, stderr=sys.stderr) + interrupt_and_wait(process) + + @unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999') + def test_in_flight_partial_stream_stream_call(self): + process = subprocess.Popen( + BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL], + stdout=sys.stdout, stderr=sys.stderr) + interrupt_and_wait(process) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/src/python/grpcio/tests/unit/_metadata_code_details_test.py b/src/python/grpcio/tests/unit/_metadata_code_details_test.py new file mode 100644 index 0000000000..0fd02d2a22 --- /dev/null +++ b/src/python/grpcio/tests/unit/_metadata_code_details_test.py @@ -0,0 +1,523 @@ +# Copyright 2016, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests application-provided metadata, status code, and details.""" + +import threading +import unittest + +import grpc +from grpc.framework.foundation import logging_pool + +from tests.unit import test_common +from tests.unit.framework.common import test_constants +from tests.unit.framework.common import test_control + +_SERIALIZED_REQUEST = b'\x46\x47\x48' +_SERIALIZED_RESPONSE = b'\x49\x50\x51' + +_REQUEST_SERIALIZER = lambda unused_request: _SERIALIZED_REQUEST +_REQUEST_DESERIALIZER = lambda unused_serialized_request: object() +_RESPONSE_SERIALIZER = lambda unused_response: _SERIALIZED_RESPONSE +_RESPONSE_DESERIALIZER = lambda unused_serialized_resopnse: object() + +_SERVICE = 'test.TestService' +_UNARY_UNARY = 'UnaryUnary' +_UNARY_STREAM = 'UnaryStream' +_STREAM_UNARY = 'StreamUnary' +_STREAM_STREAM = 'StreamStream' + +_CLIENT_METADATA = ( + ('client-md-key', 'client-md-key'), + ('client-md-key-bin', b'\x00\x01') +) + +_SERVER_INITIAL_METADATA = ( + ('server-initial-md-key', 'server-initial-md-value'), + ('server-initial-md-key-bin', b'\x00\x02') +) + +_SERVER_TRAILING_METADATA = ( + ('server-trailing-md-key', 'server-trailing-md-value'), + ('server-trailing-md-key-bin', b'\x00\x03') +) + +_NON_OK_CODE = grpc.StatusCode.NOT_FOUND +_DETAILS = 'Test details!' + + +class _Servicer(object): + + def __init__(self): + self._lock = threading.Lock() + self._code = None + self._details = None + self._exception = False + self._return_none = False + self._received_client_metadata = None + + def unary_unary(self, request, context): + with self._lock: + self._received_client_metadata = context.invocation_metadata() + context.send_initial_metadata(_SERVER_INITIAL_METADATA) + context.set_trailing_metadata(_SERVER_TRAILING_METADATA) + if self._code is not None: + context.set_code(self._code) + if self._details is not None: + context.set_details(self._details) + if self._exception: + raise test_control.Defect() + else: + return None if self._return_none else object() + + def unary_stream(self, request, context): + with self._lock: + self._received_client_metadata = context.invocation_metadata() + context.send_initial_metadata(_SERVER_INITIAL_METADATA) + context.set_trailing_metadata(_SERVER_TRAILING_METADATA) + if self._code is not None: + context.set_code(self._code) + if self._details is not None: + context.set_details(self._details) + for _ in range(test_constants.STREAM_LENGTH // 2): + yield _SERIALIZED_RESPONSE + if self._exception: + raise test_control.Defect() + + def stream_unary(self, request_iterator, context): + with self._lock: + self._received_client_metadata = context.invocation_metadata() + context.send_initial_metadata(_SERVER_INITIAL_METADATA) + context.set_trailing_metadata(_SERVER_TRAILING_METADATA) + if self._code is not None: + context.set_code(self._code) + if self._details is not None: + context.set_details(self._details) + # TODO(https://github.com/grpc/grpc/issues/6891): just ignore the + # request iterator. + for ignored_request in request_iterator: + pass + if self._exception: + raise test_control.Defect() + else: + return None if self._return_none else _SERIALIZED_RESPONSE + + def stream_stream(self, request_iterator, context): + with self._lock: + self._received_client_metadata = context.invocation_metadata() + context.send_initial_metadata(_SERVER_INITIAL_METADATA) + context.set_trailing_metadata(_SERVER_TRAILING_METADATA) + if self._code is not None: + context.set_code(self._code) + if self._details is not None: + context.set_details(self._details) + # TODO(https://github.com/grpc/grpc/issues/6891): just ignore the + # request iterator. + for ignored_request in request_iterator: + pass + for _ in range(test_constants.STREAM_LENGTH // 3): + yield object() + if self._exception: + raise test_control.Defect() + + def set_code(self, code): + with self._lock: + self._code = code + + def set_details(self, details): + with self._lock: + self._details = details + + def set_exception(self): + with self._lock: + self._exception = True + + def set_return_none(self): + with self._lock: + self._return_none = True + + def received_client_metadata(self): + with self._lock: + return self._received_client_metadata + + +def _generic_handler(servicer): + method_handlers = { + _UNARY_UNARY: grpc.unary_unary_rpc_method_handler( + servicer.unary_unary, request_deserializer=_REQUEST_DESERIALIZER, + response_serializer=_RESPONSE_SERIALIZER), + _UNARY_STREAM: grpc.unary_stream_rpc_method_handler( + servicer.unary_stream), + _STREAM_UNARY: grpc.stream_unary_rpc_method_handler( + servicer.stream_unary), + _STREAM_STREAM: grpc.stream_stream_rpc_method_handler( + servicer.stream_stream, request_deserializer=_REQUEST_DESERIALIZER, + response_serializer=_RESPONSE_SERIALIZER), + } + return grpc.method_handlers_generic_handler(_SERVICE, method_handlers) + + +class MetadataCodeDetailsTest(unittest.TestCase): + + def setUp(self): + self._servicer = _Servicer() + self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY) + self._server = grpc.server( + (_generic_handler(self._servicer),), self._server_pool) + port = self._server.add_insecure_port('[::]:0') + self._server.start() + + channel = grpc.insecure_channel('localhost:{}'.format(port)) + self._unary_unary = channel.unary_unary( + '/'.join(('', _SERVICE, _UNARY_UNARY,)), + request_serializer=_REQUEST_SERIALIZER, + response_deserializer=_RESPONSE_DESERIALIZER,) + self._unary_stream = channel.unary_stream( + '/'.join(('', _SERVICE, _UNARY_STREAM,)),) + self._stream_unary = channel.stream_unary( + '/'.join(('', _SERVICE, _STREAM_UNARY,)),) + self._stream_stream = channel.stream_stream( + '/'.join(('', _SERVICE, _STREAM_STREAM,)), + request_serializer=_REQUEST_SERIALIZER, + response_deserializer=_RESPONSE_DESERIALIZER,) + + + def testSuccessfulUnaryUnary(self): + self._servicer.set_details(_DETAILS) + + unused_response, call = self._unary_unary.with_call( + object(), metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, call.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, call.trailing_metadata())) + self.assertIs(grpc.StatusCode.OK, call.code()) + self.assertEqual(_DETAILS, call.details()) + + def testSuccessfulUnaryStream(self): + self._servicer.set_details(_DETAILS) + + call = self._unary_stream(_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA) + received_initial_metadata = call.initial_metadata() + for _ in call: + pass + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, received_initial_metadata)) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, call.trailing_metadata())) + self.assertIs(grpc.StatusCode.OK, call.code()) + self.assertEqual(_DETAILS, call.details()) + + def testSuccessfulStreamUnary(self): + self._servicer.set_details(_DETAILS) + + unused_response, call = self._stream_unary.with_call( + iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, call.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, call.trailing_metadata())) + self.assertIs(grpc.StatusCode.OK, call.code()) + self.assertEqual(_DETAILS, call.details()) + + def testSuccessfulStreamStream(self): + self._servicer.set_details(_DETAILS) + + call = self._stream_stream( + iter([object()] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA) + received_initial_metadata = call.initial_metadata() + for _ in call: + pass + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, received_initial_metadata)) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, call.trailing_metadata())) + self.assertIs(grpc.StatusCode.OK, call.code()) + self.assertEqual(_DETAILS, call.details()) + + def testCustomCodeUnaryUnary(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + + with self.assertRaises(grpc.RpcError) as exception_context: + self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, + exception_context.exception.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + exception_context.exception.trailing_metadata())) + self.assertIs(_NON_OK_CODE, exception_context.exception.code()) + self.assertEqual(_DETAILS, exception_context.exception.details()) + + def testCustomCodeUnaryStream(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + + call = self._unary_stream(_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA) + received_initial_metadata = call.initial_metadata() + with self.assertRaises(grpc.RpcError): + for _ in call: + pass + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, received_initial_metadata)) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, call.trailing_metadata())) + self.assertIs(_NON_OK_CODE, call.code()) + self.assertEqual(_DETAILS, call.details()) + + def testCustomCodeStreamUnary(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + + with self.assertRaises(grpc.RpcError) as exception_context: + self._stream_unary.with_call( + iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, + exception_context.exception.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + exception_context.exception.trailing_metadata())) + self.assertIs(_NON_OK_CODE, exception_context.exception.code()) + self.assertEqual(_DETAILS, exception_context.exception.details()) + + def testCustomCodeStreamStream(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + + call = self._stream_stream( + iter([object()] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA) + received_initial_metadata = call.initial_metadata() + with self.assertRaises(grpc.RpcError) as exception_context: + for _ in call: + pass + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, received_initial_metadata)) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + exception_context.exception.trailing_metadata())) + self.assertIs(_NON_OK_CODE, exception_context.exception.code()) + self.assertEqual(_DETAILS, exception_context.exception.details()) + + def testCustomCodeExceptionUnaryUnary(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + self._servicer.set_exception() + + with self.assertRaises(grpc.RpcError) as exception_context: + self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, + exception_context.exception.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + exception_context.exception.trailing_metadata())) + self.assertIs(_NON_OK_CODE, exception_context.exception.code()) + self.assertEqual(_DETAILS, exception_context.exception.details()) + + def testCustomCodeExceptionUnaryStream(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + self._servicer.set_exception() + + call = self._unary_stream(_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA) + received_initial_metadata = call.initial_metadata() + with self.assertRaises(grpc.RpcError): + for _ in call: + pass + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, received_initial_metadata)) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, call.trailing_metadata())) + self.assertIs(_NON_OK_CODE, call.code()) + self.assertEqual(_DETAILS, call.details()) + + def testCustomCodeExceptionStreamUnary(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + self._servicer.set_exception() + + with self.assertRaises(grpc.RpcError) as exception_context: + self._stream_unary.with_call( + iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, + exception_context.exception.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + exception_context.exception.trailing_metadata())) + self.assertIs(_NON_OK_CODE, exception_context.exception.code()) + self.assertEqual(_DETAILS, exception_context.exception.details()) + + def testCustomCodeExceptionStreamStream(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + self._servicer.set_exception() + + call = self._stream_stream( + iter([object()] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA) + received_initial_metadata = call.initial_metadata() + with self.assertRaises(grpc.RpcError): + for _ in call: + pass + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, received_initial_metadata)) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, call.trailing_metadata())) + self.assertIs(_NON_OK_CODE, call.code()) + self.assertEqual(_DETAILS, call.details()) + + def testCustomCodeReturnNoneUnaryUnary(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + self._servicer.set_return_none() + + with self.assertRaises(grpc.RpcError) as exception_context: + self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, + exception_context.exception.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + exception_context.exception.trailing_metadata())) + self.assertIs(_NON_OK_CODE, exception_context.exception.code()) + self.assertEqual(_DETAILS, exception_context.exception.details()) + + def testCustomCodeReturnNoneStreamUnary(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + self._servicer.set_return_none() + + with self.assertRaises(grpc.RpcError) as exception_context: + self._stream_unary.with_call( + iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, + exception_context.exception.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + exception_context.exception.trailing_metadata())) + self.assertIs(_NON_OK_CODE, exception_context.exception.code()) + self.assertEqual(_DETAILS, exception_context.exception.details()) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/src/python/grpcio/tests/unit/_metadata_test.py b/src/python/grpcio/tests/unit/_metadata_test.py index 2cb13f236b..c637a28039 100644 --- a/src/python/grpcio/tests/unit/_metadata_test.py +++ b/src/python/grpcio/tests/unit/_metadata_test.py @@ -44,33 +44,33 @@ _CHANNEL_ARGS = (('grpc.primary_user_agent', 'primary-agent'), _REQUEST = b'\x00\x00\x00' _RESPONSE = b'\x00\x00\x00' -_UNARY_UNARY = b'/test/UnaryUnary' -_UNARY_STREAM = b'/test/UnaryStream' -_STREAM_UNARY = b'/test/StreamUnary' -_STREAM_STREAM = b'/test/StreamStream' +_UNARY_UNARY = '/test/UnaryUnary' +_UNARY_STREAM = '/test/UnaryStream' +_STREAM_UNARY = '/test/StreamUnary' +_STREAM_STREAM = '/test/StreamStream' _USER_AGENT = 'Python-gRPC-{}'.format(_grpcio_metadata.__version__) _CLIENT_METADATA = ( - (b'client-md-key', b'client-md-key'), - (b'client-md-key-bin', b'\x00\x01') + ('client-md-key', 'client-md-key'), + ('client-md-key-bin', b'\x00\x01') ) _SERVER_INITIAL_METADATA = ( - (b'server-initial-md-key', b'server-initial-md-value'), - (b'server-initial-md-key-bin', b'\x00\x02') + ('server-initial-md-key', 'server-initial-md-value'), + ('server-initial-md-key-bin', b'\x00\x02') ) _SERVER_TRAILING_METADATA = ( - (b'server-trailing-md-key', b'server-trailing-md-value'), - (b'server-trailing-md-key-bin', b'\x00\x03') + ('server-trailing-md-key', 'server-trailing-md-value'), + ('server-trailing-md-key-bin', b'\x00\x03') ) def user_agent(metadata): for key, val in metadata: - if key == b'user-agent': - return val.decode('ascii') + if key == 'user-agent': + return val raise KeyError('No user agent!') diff --git a/src/python/grpcio/tests/unit/_rpc_test.py b/src/python/grpcio/tests/unit/_rpc_test.py index 9814504edf..c70d65a6df 100644 --- a/src/python/grpcio/tests/unit/_rpc_test.py +++ b/src/python/grpcio/tests/unit/_rpc_test.py @@ -45,10 +45,10 @@ _DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) // 2:] _SERIALIZE_RESPONSE = lambda bytestring: bytestring * 3 _DESERIALIZE_RESPONSE = lambda bytestring: bytestring[:len(bytestring) // 3] -_UNARY_UNARY = b'/test/UnaryUnary' -_UNARY_STREAM = b'/test/UnaryStream' -_STREAM_UNARY = b'/test/StreamUnary' -_STREAM_STREAM = b'/test/StreamStream' +_UNARY_UNARY = '/test/UnaryUnary' +_UNARY_STREAM = '/test/UnaryStream' +_STREAM_UNARY = '/test/StreamUnary' +_STREAM_STREAM = '/test/StreamStream' class _Callback(object): @@ -79,7 +79,7 @@ class _Handler(object): def handle_unary_unary(self, request, servicer_context): self._control.control() if servicer_context is not None: - servicer_context.set_trailing_metadata(((b'testkey', b'testvalue',),)) + servicer_context.set_trailing_metadata((('testkey', 'testvalue',),)) return request def handle_unary_stream(self, request, servicer_context): @@ -88,7 +88,7 @@ class _Handler(object): yield request self._control.control() if servicer_context is not None: - servicer_context.set_trailing_metadata(((b'testkey', b'testvalue',),)) + servicer_context.set_trailing_metadata((('testkey', 'testvalue',),)) def handle_stream_unary(self, request_iterator, servicer_context): if servicer_context is not None: @@ -100,13 +100,13 @@ class _Handler(object): response_elements.append(request) self._control.control() if servicer_context is not None: - servicer_context.set_trailing_metadata(((b'testkey', b'testvalue',),)) + servicer_context.set_trailing_metadata((('testkey', 'testvalue',),)) return b''.join(response_elements) def handle_stream_stream(self, request_iterator, servicer_context): self._control.control() if servicer_context is not None: - servicer_context.set_trailing_metadata(((b'testkey', b'testvalue',),)) + servicer_context.set_trailing_metadata((('testkey', 'testvalue',),)) for request in request_iterator: self._control.control() yield request @@ -185,7 +185,7 @@ class RPCTest(unittest.TestCase): self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY) self._server = grpc.server((), self._server_pool) - port = self._server.add_insecure_port(b'[::]:0') + port = self._server.add_insecure_port('[::]:0') self._server.add_generic_rpc_handlers((_GenericHandler(self._handler),)) self._server.start() @@ -195,7 +195,7 @@ class RPCTest(unittest.TestCase): request = b'abc' with self.assertRaises(grpc.RpcError) as exception_context: - self._channel.unary_unary(b'NoSuchMethod')(request) + self._channel.unary_unary('NoSuchMethod')(request) self.assertEqual( grpc.StatusCode.UNIMPLEMENTED, exception_context.exception.code()) @@ -207,7 +207,7 @@ class RPCTest(unittest.TestCase): multi_callable = _unary_unary_multi_callable(self._channel) response = multi_callable( request, metadata=( - (b'test', b'SuccessfulUnaryRequestBlockingUnaryResponse'),)) + ('test', 'SuccessfulUnaryRequestBlockingUnaryResponse'),)) self.assertEqual(expected_response, response) @@ -218,7 +218,7 @@ class RPCTest(unittest.TestCase): multi_callable = _unary_unary_multi_callable(self._channel) response, call = multi_callable.with_call( request, metadata=( - (b'test', b'SuccessfulUnaryRequestBlockingUnaryResponseWithCall'),)) + ('test', 'SuccessfulUnaryRequestBlockingUnaryResponseWithCall'),)) self.assertEqual(expected_response, response) self.assertIs(grpc.StatusCode.OK, call.code()) @@ -230,7 +230,7 @@ class RPCTest(unittest.TestCase): multi_callable = _unary_unary_multi_callable(self._channel) response_future = multi_callable.future( request, metadata=( - (b'test', b'SuccessfulUnaryRequestFutureUnaryResponse'),)) + ('test', 'SuccessfulUnaryRequestFutureUnaryResponse'),)) response = response_future.result() self.assertEqual(expected_response, response) @@ -242,7 +242,7 @@ class RPCTest(unittest.TestCase): multi_callable = _unary_stream_multi_callable(self._channel) response_iterator = multi_callable( request, - metadata=((b'test', b'SuccessfulUnaryRequestStreamResponse'),)) + metadata=(('test', 'SuccessfulUnaryRequestStreamResponse'),)) responses = tuple(response_iterator) self.assertSequenceEqual(expected_responses, responses) @@ -255,7 +255,7 @@ class RPCTest(unittest.TestCase): multi_callable = _stream_unary_multi_callable(self._channel) response = multi_callable( request_iterator, - metadata=((b'test', b'SuccessfulStreamRequestBlockingUnaryResponse'),)) + metadata=(('test', 'SuccessfulStreamRequestBlockingUnaryResponse'),)) self.assertEqual(expected_response, response) @@ -268,7 +268,7 @@ class RPCTest(unittest.TestCase): response, call = multi_callable.with_call( request_iterator, metadata=( - (b'test', b'SuccessfulStreamRequestBlockingUnaryResponseWithCall'), + ('test', 'SuccessfulStreamRequestBlockingUnaryResponseWithCall'), )) self.assertEqual(expected_response, response) @@ -283,7 +283,7 @@ class RPCTest(unittest.TestCase): response_future = multi_callable.future( request_iterator, metadata=( - (b'test', b'SuccessfulStreamRequestFutureUnaryResponse'),)) + ('test', 'SuccessfulStreamRequestFutureUnaryResponse'),)) response = response_future.result() self.assertEqual(expected_response, response) @@ -297,7 +297,7 @@ class RPCTest(unittest.TestCase): multi_callable = _stream_stream_multi_callable(self._channel) response_iterator = multi_callable( request_iterator, - metadata=((b'test', b'SuccessfulStreamRequestStreamResponse'),)) + metadata=(('test', 'SuccessfulStreamRequestStreamResponse'),)) responses = tuple(response_iterator) self.assertSequenceEqual(expected_responses, responses) @@ -312,9 +312,9 @@ class RPCTest(unittest.TestCase): multi_callable = _unary_unary_multi_callable(self._channel) first_response = multi_callable( - first_request, metadata=((b'test', b'SequentialInvocations'),)) + first_request, metadata=(('test', 'SequentialInvocations'),)) second_response = multi_callable( - second_request, metadata=((b'test', b'SequentialInvocations'),)) + second_request, metadata=(('test', 'SequentialInvocations'),)) self.assertEqual(expected_first_response, first_response) self.assertEqual(expected_second_response, second_response) @@ -331,7 +331,7 @@ class RPCTest(unittest.TestCase): request_iterator = iter(requests) response_future = pool.submit( multi_callable, request_iterator, - metadata=((b'test', b'ConcurrentBlockingInvocations'),)) + metadata=(('test', 'ConcurrentBlockingInvocations'),)) response_futures[index] = response_future responses = tuple( response_future.result() for response_future in response_futures) @@ -350,7 +350,7 @@ class RPCTest(unittest.TestCase): request_iterator = iter(requests) response_future = multi_callable.future( request_iterator, - metadata=((b'test', b'ConcurrentFutureInvocations'),)) + metadata=(('test', 'ConcurrentFutureInvocations'),)) response_futures[index] = response_future responses = tuple( response_future.result() for response_future in response_futures) @@ -380,8 +380,8 @@ class RPCTest(unittest.TestCase): inner_response_future = multi_callable.future( request, metadata=( - (b'test', - b'WaitingForSomeButNotAllConcurrentFutureInvocations'),)) + ('test', + 'WaitingForSomeButNotAllConcurrentFutureInvocations'),)) outer_response_future = pool.submit(wrap_future(inner_response_future)) response_futures[index] = outer_response_future @@ -400,7 +400,7 @@ class RPCTest(unittest.TestCase): response_iterator = multi_callable( request, metadata=( - (b'test', b'ConsumingOneStreamResponseUnaryRequest'),)) + ('test', 'ConsumingOneStreamResponseUnaryRequest'),)) next(response_iterator) def testConsumingSomeButNotAllStreamResponsesUnaryRequest(self): @@ -410,7 +410,7 @@ class RPCTest(unittest.TestCase): response_iterator = multi_callable( request, metadata=( - (b'test', b'ConsumingSomeButNotAllStreamResponsesUnaryRequest'),)) + ('test', 'ConsumingSomeButNotAllStreamResponsesUnaryRequest'),)) for _ in range(test_constants.STREAM_LENGTH // 2): next(response_iterator) @@ -422,7 +422,7 @@ class RPCTest(unittest.TestCase): response_iterator = multi_callable( request_iterator, metadata=( - (b'test', b'ConsumingSomeButNotAllStreamResponsesStreamRequest'),)) + ('test', 'ConsumingSomeButNotAllStreamResponsesStreamRequest'),)) for _ in range(test_constants.STREAM_LENGTH // 2): next(response_iterator) @@ -434,7 +434,7 @@ class RPCTest(unittest.TestCase): response_iterator = multi_callable( request_iterator, metadata=( - (b'test', b'ConsumingTooManyStreamResponsesStreamRequest'),)) + ('test', 'ConsumingTooManyStreamResponsesStreamRequest'),)) for _ in range(test_constants.STREAM_LENGTH): next(response_iterator) for _ in range(test_constants.STREAM_LENGTH): @@ -453,7 +453,7 @@ class RPCTest(unittest.TestCase): with self._control.pause(): response_future = multi_callable.future( request, - metadata=((b'test', b'CancelledUnaryRequestUnaryResponse'),)) + metadata=(('test', 'CancelledUnaryRequestUnaryResponse'),)) response_future.cancel() self.assertTrue(response_future.cancelled()) @@ -468,7 +468,7 @@ class RPCTest(unittest.TestCase): with self._control.pause(): response_iterator = multi_callable( request, - metadata=((b'test', b'CancelledUnaryRequestStreamResponse'),)) + metadata=(('test', 'CancelledUnaryRequestStreamResponse'),)) self._control.block_until_paused() response_iterator.cancel() @@ -488,7 +488,7 @@ class RPCTest(unittest.TestCase): with self._control.pause(): response_future = multi_callable.future( request_iterator, - metadata=((b'test', b'CancelledStreamRequestUnaryResponse'),)) + metadata=(('test', 'CancelledStreamRequestUnaryResponse'),)) self._control.block_until_paused() response_future.cancel() @@ -508,7 +508,7 @@ class RPCTest(unittest.TestCase): with self._control.pause(): response_iterator = multi_callable( request_iterator, - metadata=((b'test', b'CancelledStreamRequestStreamResponse'),)) + metadata=(('test', 'CancelledStreamRequestStreamResponse'),)) response_iterator.cancel() with self.assertRaises(grpc.RpcError): @@ -526,7 +526,7 @@ class RPCTest(unittest.TestCase): with self.assertRaises(grpc.RpcError) as exception_context: multi_callable.with_call( request, timeout=test_constants.SHORT_TIMEOUT, - metadata=((b'test', b'ExpiredUnaryRequestBlockingUnaryResponse'),)) + metadata=(('test', 'ExpiredUnaryRequestBlockingUnaryResponse'),)) self.assertIsNotNone(exception_context.exception.initial_metadata()) self.assertIs( @@ -542,7 +542,7 @@ class RPCTest(unittest.TestCase): with self._control.pause(): response_future = multi_callable.future( request, timeout=test_constants.SHORT_TIMEOUT, - metadata=((b'test', b'ExpiredUnaryRequestFutureUnaryResponse'),)) + metadata=(('test', 'ExpiredUnaryRequestFutureUnaryResponse'),)) response_future.add_done_callback(callback) value_passed_to_callback = callback.value() @@ -567,7 +567,7 @@ class RPCTest(unittest.TestCase): with self.assertRaises(grpc.RpcError) as exception_context: response_iterator = multi_callable( request, timeout=test_constants.SHORT_TIMEOUT, - metadata=((b'test', b'ExpiredUnaryRequestStreamResponse'),)) + metadata=(('test', 'ExpiredUnaryRequestStreamResponse'),)) next(response_iterator) self.assertIs( @@ -583,7 +583,7 @@ class RPCTest(unittest.TestCase): with self.assertRaises(grpc.RpcError) as exception_context: multi_callable( request_iterator, timeout=test_constants.SHORT_TIMEOUT, - metadata=((b'test', b'ExpiredStreamRequestBlockingUnaryResponse'),)) + metadata=(('test', 'ExpiredStreamRequestBlockingUnaryResponse'),)) self.assertIsNotNone(exception_context.exception.initial_metadata()) self.assertIs( @@ -600,7 +600,7 @@ class RPCTest(unittest.TestCase): with self._control.pause(): response_future = multi_callable.future( request_iterator, timeout=test_constants.SHORT_TIMEOUT, - metadata=((b'test', b'ExpiredStreamRequestFutureUnaryResponse'),)) + metadata=(('test', 'ExpiredStreamRequestFutureUnaryResponse'),)) response_future.add_done_callback(callback) value_passed_to_callback = callback.value() @@ -625,7 +625,7 @@ class RPCTest(unittest.TestCase): with self.assertRaises(grpc.RpcError) as exception_context: response_iterator = multi_callable( request_iterator, timeout=test_constants.SHORT_TIMEOUT, - metadata=((b'test', b'ExpiredStreamRequestStreamResponse'),)) + metadata=(('test', 'ExpiredStreamRequestStreamResponse'),)) next(response_iterator) self.assertIs( @@ -640,7 +640,7 @@ class RPCTest(unittest.TestCase): with self.assertRaises(grpc.RpcError) as exception_context: multi_callable.with_call( request, - metadata=((b'test', b'FailedUnaryRequestBlockingUnaryResponse'),)) + metadata=(('test', 'FailedUnaryRequestBlockingUnaryResponse'),)) self.assertIs(grpc.StatusCode.UNKNOWN, exception_context.exception.code()) @@ -652,7 +652,7 @@ class RPCTest(unittest.TestCase): with self._control.fail(): response_future = multi_callable.future( request, - metadata=((b'test', b'FailedUnaryRequestFutureUnaryResponse'),)) + metadata=(('test', 'FailedUnaryRequestFutureUnaryResponse'),)) response_future.add_done_callback(callback) value_passed_to_callback = callback.value() @@ -672,7 +672,7 @@ class RPCTest(unittest.TestCase): with self._control.fail(): response_iterator = multi_callable( request, - metadata=((b'test', b'FailedUnaryRequestStreamResponse'),)) + metadata=(('test', 'FailedUnaryRequestStreamResponse'),)) next(response_iterator) self.assertIs(grpc.StatusCode.UNKNOWN, exception_context.exception.code()) @@ -686,7 +686,7 @@ class RPCTest(unittest.TestCase): with self.assertRaises(grpc.RpcError) as exception_context: multi_callable( request_iterator, - metadata=((b'test', b'FailedStreamRequestBlockingUnaryResponse'),)) + metadata=(('test', 'FailedStreamRequestBlockingUnaryResponse'),)) self.assertIs(grpc.StatusCode.UNKNOWN, exception_context.exception.code()) @@ -699,7 +699,7 @@ class RPCTest(unittest.TestCase): with self._control.fail(): response_future = multi_callable.future( request_iterator, - metadata=((b'test', b'FailedStreamRequestFutureUnaryResponse'),)) + metadata=(('test', 'FailedStreamRequestFutureUnaryResponse'),)) response_future.add_done_callback(callback) value_passed_to_callback = callback.value() @@ -720,7 +720,7 @@ class RPCTest(unittest.TestCase): with self.assertRaises(grpc.RpcError) as exception_context: response_iterator = multi_callable( request_iterator, - metadata=((b'test', b'FailedStreamRequestStreamResponse'),)) + metadata=(('test', 'FailedStreamRequestStreamResponse'),)) tuple(response_iterator) self.assertIs(grpc.StatusCode.UNKNOWN, exception_context.exception.code()) @@ -732,7 +732,7 @@ class RPCTest(unittest.TestCase): multi_callable = _unary_unary_multi_callable(self._channel) multi_callable.future( request, - metadata=((b'test', b'IgnoredUnaryRequestFutureUnaryResponse'),)) + metadata=(('test', 'IgnoredUnaryRequestFutureUnaryResponse'),)) def testIgnoredUnaryRequestStreamResponse(self): request = b'\x37\x17' @@ -740,7 +740,7 @@ class RPCTest(unittest.TestCase): multi_callable = _unary_stream_multi_callable(self._channel) multi_callable( request, - metadata=((b'test', b'IgnoredUnaryRequestStreamResponse'),)) + metadata=(('test', 'IgnoredUnaryRequestStreamResponse'),)) def testIgnoredStreamRequestFutureUnaryResponse(self): requests = tuple(b'\x07\x18' for _ in range(test_constants.STREAM_LENGTH)) @@ -749,7 +749,7 @@ class RPCTest(unittest.TestCase): multi_callable = _stream_unary_multi_callable(self._channel) multi_callable.future( request_iterator, - metadata=((b'test', b'IgnoredStreamRequestFutureUnaryResponse'),)) + metadata=(('test', 'IgnoredStreamRequestFutureUnaryResponse'),)) def testIgnoredStreamRequestStreamResponse(self): requests = tuple(b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH)) @@ -758,7 +758,7 @@ class RPCTest(unittest.TestCase): multi_callable = _stream_stream_multi_callable(self._channel) multi_callable( request_iterator, - metadata=((b'test', b'IgnoredStreamRequestStreamResponse'),)) + metadata=(('test', 'IgnoredStreamRequestStreamResponse'),)) if __name__ == '__main__': diff --git a/src/python/grpcio/tests/unit/beta/_connectivity_channel_test.py b/src/python/grpcio/tests/unit/beta/_connectivity_channel_test.py index 488f7d7141..5d826a269d 100644 --- a/src/python/grpcio/tests/unit/beta/_connectivity_channel_test.py +++ b/src/python/grpcio/tests/unit/beta/_connectivity_channel_test.py @@ -29,162 +29,9 @@ """Tests of grpc.beta._connectivity_channel.""" -import threading -import time import unittest -from grpc._adapter import _low -from grpc._adapter import _types -from grpc.beta import _connectivity_channel from grpc.beta import interfaces -from tests.unit.framework.common import test_constants - - -def _drive_completion_queue(completion_queue): - while True: - event = completion_queue.next(time.time() + 24 * 60 * 60) - if event.type == _types.EventType.QUEUE_SHUTDOWN: - break - - -class _Callback(object): - - def __init__(self): - self._condition = threading.Condition() - self._connectivities = [] - - def update(self, connectivity): - with self._condition: - self._connectivities.append(connectivity) - self._condition.notify() - - def connectivities(self): - with self._condition: - return tuple(self._connectivities) - - def block_until_connectivities_satisfy(self, predicate): - with self._condition: - while True: - connectivities = tuple(self._connectivities) - if predicate(connectivities): - return connectivities - else: - self._condition.wait() - - -class ChannelConnectivityTest(unittest.TestCase): - - def test_lonely_channel_connectivity(self): - low_channel = _low.Channel('localhost:12345', ()) - callback = _Callback() - - connectivity_channel = _connectivity_channel.ConnectivityChannel( - low_channel) - connectivity_channel.subscribe(callback.update, try_to_connect=False) - first_connectivities = callback.block_until_connectivities_satisfy(bool) - connectivity_channel.subscribe(callback.update, try_to_connect=True) - second_connectivities = callback.block_until_connectivities_satisfy( - lambda connectivities: 2 <= len(connectivities)) - # Wait for a connection that will never happen. - time.sleep(test_constants.SHORT_TIMEOUT) - third_connectivities = callback.connectivities() - connectivity_channel.unsubscribe(callback.update) - fourth_connectivities = callback.connectivities() - connectivity_channel.unsubscribe(callback.update) - fifth_connectivities = callback.connectivities() - - self.assertSequenceEqual( - (interfaces.ChannelConnectivity.IDLE,), first_connectivities) - self.assertNotIn( - interfaces.ChannelConnectivity.READY, second_connectivities) - self.assertNotIn( - interfaces.ChannelConnectivity.READY, third_connectivities) - self.assertNotIn( - interfaces.ChannelConnectivity.READY, fourth_connectivities) - self.assertNotIn( - interfaces.ChannelConnectivity.READY, fifth_connectivities) - - def test_immediately_connectable_channel_connectivity(self): - server_completion_queue = _low.CompletionQueue() - server = _low.Server(server_completion_queue, []) - port = server.add_http2_port('[::]:0') - server.start() - server_completion_queue_thread = threading.Thread( - target=_drive_completion_queue, args=(server_completion_queue,)) - server_completion_queue_thread.start() - low_channel = _low.Channel('localhost:%d' % port, ()) - first_callback = _Callback() - second_callback = _Callback() - - connectivity_channel = _connectivity_channel.ConnectivityChannel( - low_channel) - connectivity_channel.subscribe(first_callback.update, try_to_connect=False) - first_connectivities = first_callback.block_until_connectivities_satisfy( - bool) - # Wait for a connection that will never happen because try_to_connect=True - # has not yet been passed. - time.sleep(test_constants.SHORT_TIMEOUT) - second_connectivities = first_callback.connectivities() - connectivity_channel.subscribe(second_callback.update, try_to_connect=True) - third_connectivities = first_callback.block_until_connectivities_satisfy( - lambda connectivities: 2 <= len(connectivities)) - fourth_connectivities = second_callback.block_until_connectivities_satisfy( - bool) - # Wait for a connection that will happen (or may already have happened). - first_callback.block_until_connectivities_satisfy( - lambda connectivities: - interfaces.ChannelConnectivity.READY in connectivities) - second_callback.block_until_connectivities_satisfy( - lambda connectivities: - interfaces.ChannelConnectivity.READY in connectivities) - connectivity_channel.unsubscribe(first_callback.update) - connectivity_channel.unsubscribe(second_callback.update) - - server.shutdown() - server_completion_queue.shutdown() - server_completion_queue_thread.join() - - self.assertSequenceEqual( - (interfaces.ChannelConnectivity.IDLE,), first_connectivities) - self.assertSequenceEqual( - (interfaces.ChannelConnectivity.IDLE,), second_connectivities) - self.assertNotIn( - interfaces.ChannelConnectivity.TRANSIENT_FAILURE, third_connectivities) - self.assertNotIn( - interfaces.ChannelConnectivity.FATAL_FAILURE, third_connectivities) - self.assertNotIn( - interfaces.ChannelConnectivity.TRANSIENT_FAILURE, - fourth_connectivities) - self.assertNotIn( - interfaces.ChannelConnectivity.FATAL_FAILURE, fourth_connectivities) - - def test_reachable_then_unreachable_channel_connectivity(self): - server_completion_queue = _low.CompletionQueue() - server = _low.Server(server_completion_queue, []) - port = server.add_http2_port('[::]:0') - server.start() - server_completion_queue_thread = threading.Thread( - target=_drive_completion_queue, args=(server_completion_queue,)) - server_completion_queue_thread.start() - low_channel = _low.Channel('localhost:%d' % port, ()) - callback = _Callback() - - connectivity_channel = _connectivity_channel.ConnectivityChannel( - low_channel) - connectivity_channel.subscribe(callback.update, try_to_connect=True) - callback.block_until_connectivities_satisfy( - lambda connectivities: - interfaces.ChannelConnectivity.READY in connectivities) - # Now take down the server and confirm that channel readiness is repudiated. - server.shutdown() - callback.block_until_connectivities_satisfy( - lambda connectivities: - connectivities[-1] is not interfaces.ChannelConnectivity.READY) - connectivity_channel.unsubscribe(callback.update) - - server.shutdown() - server_completion_queue.shutdown() - server_completion_queue_thread.join() class ConnectivityStatesTest(unittest.TestCase): diff --git a/src/python/grpcio/tests/unit/beta/_utilities_test.py b/src/python/grpcio/tests/unit/beta/_utilities_test.py index 08ce98e751..90fe10c77c 100644 --- a/src/python/grpcio/tests/unit/beta/_utilities_test.py +++ b/src/python/grpcio/tests/unit/beta/_utilities_test.py @@ -33,21 +33,12 @@ import threading import time import unittest -from grpc._adapter import _low -from grpc._adapter import _types from grpc.beta import implementations from grpc.beta import utilities from grpc.framework.foundation import future from tests.unit.framework.common import test_constants -def _drive_completion_queue(completion_queue): - while True: - event = completion_queue.next(time.time() + 24 * 60 * 60) - if event.type == _types.EventType.QUEUE_SHUTDOWN: - break - - class _Callback(object): def __init__(self): @@ -87,13 +78,9 @@ class ChannelConnectivityTest(unittest.TestCase): self.assertFalse(ready_future.running()) def test_immediately_connectable_channel_connectivity(self): - server_completion_queue = _low.CompletionQueue() - server = _low.Server(server_completion_queue, []) - port = server.add_http2_port('[::]:0') + server = implementations.server({}) + port = server.add_insecure_port('[::]:0') server.start() - server_completion_queue_thread = threading.Thread( - target=_drive_completion_queue, args=(server_completion_queue,)) - server_completion_queue_thread.start() channel = implementations.insecure_channel('localhost', port) callback = _Callback() @@ -114,9 +101,7 @@ class ChannelConnectivityTest(unittest.TestCase): self.assertFalse(ready_future.running()) finally: ready_future.cancel() - server.shutdown() - server_completion_queue.shutdown() - server_completion_queue_thread.join() + server.stop(0) if __name__ == '__main__': diff --git a/src/python/grpcio/tests/unit/beta/test_utilities.py b/src/python/grpcio/tests/unit/beta/test_utilities.py index 8ccad04e05..692da9c97d 100644 --- a/src/python/grpcio/tests/unit/beta/test_utilities.py +++ b/src/python/grpcio/tests/unit/beta/test_utilities.py @@ -51,5 +51,5 @@ def not_really_secure_channel( target = '%s:%d' % (host, port) channel = grpc.secure_channel( target, channel_credentials, - ((b'grpc.ssl_target_name_override', server_host_override,),)) + (('grpc.ssl_target_name_override', server_host_override,),)) return implementations.Channel(channel) diff --git a/src/python/grpcio/tests/unit/test_common.py b/src/python/grpcio/tests/unit/test_common.py index b779f65e7e..c8886bf4ca 100644 --- a/src/python/grpcio/tests/unit/test_common.py +++ b/src/python/grpcio/tests/unit/test_common.py @@ -33,10 +33,10 @@ import collections import six -INVOCATION_INITIAL_METADATA = ((b'0', b'abc'), (b'1', b'def'), (b'2', b'ghi'),) -SERVICE_INITIAL_METADATA = ((b'3', b'jkl'), (b'4', b'mno'), (b'5', b'pqr'),) -SERVICE_TERMINAL_METADATA = ((b'6', b'stu'), (b'7', b'vwx'), (b'8', b'yza'),) -DETAILS = b'test details' +INVOCATION_INITIAL_METADATA = (('0', 'abc'), ('1', 'def'), ('2', 'ghi'),) +SERVICE_INITIAL_METADATA = (('3', 'jkl'), ('4', 'mno'), ('5', 'pqr'),) +SERVICE_TERMINAL_METADATA = (('6', 'stu'), ('7', 'vwx'), ('8', 'yza'),) +DETAILS = 'test details' def metadata_transmitted(original_metadata, transmitted_metadata): @@ -59,16 +59,10 @@ def metadata_transmitted(original_metadata, transmitted_metadata): original_metadata after having been transmitted via gRPC. """ original = collections.defaultdict(list) - for key_value_pair in original_metadata: - key, value = tuple(key_value_pair) - if not isinstance(key, bytes): - key = key.encode() - if not isinstance(value, bytes): - value = value.encode() + for key, value in original_metadata: original[key].append(value) transmitted = collections.defaultdict(list) - for key_value_pair in transmitted_metadata: - key, value = tuple(key_value_pair) + for key, value in transmitted_metadata: transmitted[key].append(value) for key, values in six.iteritems(original): |