diff options
author | Sree Kuchibhotla <sreek@google.com> | 2016-06-06 10:31:45 -0700 |
---|---|---|
committer | Sree Kuchibhotla <sreek@google.com> | 2016-06-06 10:31:45 -0700 |
commit | 1fd6235791370e511093941e23432b446328f3e1 (patch) | |
tree | 27cae5ec3d74d36f5b0c8e8298fc97849a347579 /src/python/grpcio/tests | |
parent | 9bc3d2d67f32f4dad8cd1319dd4f3fce48c1abee (diff) | |
parent | 37e26d922fc443b67df50abd8b11288eba2920cd (diff) |
Merge branch 'master' into epoll_changes
Diffstat (limited to 'src/python/grpcio/tests')
-rw-r--r-- | src/python/grpcio/tests/interop/client.py | 39 | ||||
-rw-r--r-- | src/python/grpcio/tests/interop/methods.py | 30 | ||||
-rw-r--r-- | src/python/grpcio/tests/tests.json | 7 | ||||
-rw-r--r-- | src/python/grpcio/tests/unit/_channel_connectivity_test.py | 161 | ||||
-rw-r--r-- | src/python/grpcio/tests/unit/_channel_ready_future_test.py | 103 | ||||
-rw-r--r-- | src/python/grpcio/tests/unit/_cython/_read_some_but_not_all_responses_test.py | 251 | ||||
-rw-r--r-- | src/python/grpcio/tests/unit/_rpc_test.py | 775 | ||||
-rw-r--r-- | src/python/grpcio/tests/unit/beta/_auth_test.py | 96 | ||||
-rw-r--r-- | src/python/grpcio/tests/unit/beta/_implementations_test.py | 17 | ||||
-rw-r--r-- | src/python/grpcio/tests/unit/framework/common/test_control.py | 26 |
10 files changed, 1476 insertions, 29 deletions
diff --git a/src/python/grpcio/tests/interop/client.py b/src/python/grpcio/tests/interop/client.py index db29eb4aa7..e3d5545a02 100644 --- a/src/python/grpcio/tests/interop/client.py +++ b/src/python/grpcio/tests/interop/client.py @@ -65,39 +65,34 @@ def _args(): help='email address of the default service account', type=str) return parser.parse_args() -def _oauth_access_token(args): - credentials = oauth2client_client.GoogleCredentials.get_application_default() - scoped_credentials = credentials.create_scoped([args.oauth_scope]) - return scoped_credentials.get_access_token().access_token def _stub(args): - if args.oauth_scope: - if args.test_case == 'oauth2_auth_token': - # TODO(jtattermusch): This testcase sets the auth metadata key-value - # manually, which also means that the user would need to do the same - # thing every time he/she would like to use and out of band oauth token. - # The transformer function that produces the metadata key-value from - # the access token should be provided by gRPC auth library. - access_token = _oauth_access_token(args) - metadata_transformer = lambda x: [ - ('authorization', 'Bearer %s' % access_token)] - else: - metadata_transformer = lambda x: [ - ('authorization', 'Bearer %s' % _oauth_access_token(args))] + if args.test_case == 'oauth2_auth_token': + creds = oauth2client_client.GoogleCredentials.get_application_default() + scoped_creds = creds.create_scoped([args.oauth_scope]) + access_token = scoped_creds.get_access_token().access_token + call_creds = implementations.access_token_call_credentials(access_token) + elif args.test_case == 'compute_engine_creds': + creds = oauth2client_client.GoogleCredentials.get_application_default() + scoped_creds = creds.create_scoped([args.oauth_scope]) + call_creds = implementations.google_call_credentials(scoped_creds) else: - metadata_transformer = lambda x: [] + call_creds = None if args.use_tls: if args.use_test_ca: root_certificates = resources.test_root_certificates() else: root_certificates = None # will load default roots. + channel_creds = implementations.ssl_channel_credentials(root_certificates) + if call_creds is not None: + channel_creds = implementations.composite_channel_credentials( + channel_creds, call_creds) + channel = test_utilities.not_really_secure_channel( - args.server_host, args.server_port, - implementations.ssl_channel_credentials(root_certificates), + args.server_host, args.server_port, channel_creds, args.server_host_override) - stub = test_pb2.beta_create_TestService_stub( - channel, metadata_transformer=metadata_transformer) + stub = test_pb2.beta_create_TestService_stub(channel) else: channel = implementations.insecure_channel( args.server_host, args.server_port) diff --git a/src/python/grpcio/tests/interop/methods.py b/src/python/grpcio/tests/interop/methods.py index 67862ed7d3..d5ef0c68bb 100644 --- a/src/python/grpcio/tests/interop/methods.py +++ b/src/python/grpcio/tests/interop/methods.py @@ -39,6 +39,8 @@ import time from oauth2client import client as oauth2client_client +from grpc.beta import implementations +from grpc.beta import interfaces from grpc.framework.common import cardinality from grpc.framework.interfaces.face import face @@ -88,13 +90,15 @@ class TestService(test_pb2.BetaTestServiceServicer): return self.FullDuplexCall(request_iterator, context) -def _large_unary_common_behavior(stub, fill_username, fill_oauth_scope): +def _large_unary_common_behavior(stub, fill_username, fill_oauth_scope, + protocol_options=None): with stub: request = messages_pb2.SimpleRequest( response_type=messages_pb2.COMPRESSABLE, response_size=314159, payload=messages_pb2.Payload(body=b'\x00' * 271828), fill_username=fill_username, fill_oauth_scope=fill_oauth_scope) - response_future = stub.UnaryCall.future(request, _TIMEOUT) + response_future = stub.UnaryCall.future(request, _TIMEOUT, + protocol_options=protocol_options) response = response_future.result() if response.payload.type is not messages_pb2.COMPRESSABLE: raise ValueError( @@ -303,7 +307,24 @@ def _oauth2_auth_token(stub, args): if args.oauth_scope.find(response.oauth_scope) == -1: raise ValueError( 'expected to find oauth scope "%s" in received "%s"' % - (response.oauth_scope, args.oauth_scope)) + (response.oauth_scope, args.oauth_scope)) + + +def _per_rpc_creds(stub, args): + json_key_filename = os.environ[ + oauth2client_client.GOOGLE_APPLICATION_CREDENTIALS] + wanted_email = json.load(open(json_key_filename, 'rb'))['client_email'] + credentials = oauth2client_client.GoogleCredentials.get_application_default() + scoped_credentials = credentials.create_scoped([args.oauth_scope]) + call_creds = implementations.google_call_credentials(scoped_credentials) + options = interfaces.grpc_call_options(disable_compression=False, + credentials=call_creds) + response = _large_unary_common_behavior(stub, True, False, + protocol_options=options) + if wanted_email != response.username: + raise ValueError( + 'expected username %s, got %s' % (wanted_email, response.username)) + @enum.unique class TestCase(enum.Enum): @@ -317,6 +338,7 @@ class TestCase(enum.Enum): EMPTY_STREAM = 'empty_stream' COMPUTE_ENGINE_CREDS = 'compute_engine_creds' OAUTH2_AUTH_TOKEN = 'oauth2_auth_token' + PER_RPC_CREDS = 'per_rpc_creds' TIMEOUT_ON_SLEEPING_SERVER = 'timeout_on_sleeping_server' def test_interoperability(self, stub, args): @@ -342,5 +364,7 @@ class TestCase(enum.Enum): _compute_engine_creds(stub, args) elif self is TestCase.OAUTH2_AUTH_TOKEN: _oauth2_auth_token(stub, args) + elif self is TestCase.PER_RPC_CREDS: + _per_rpc_creds(stub, args) else: raise NotImplementedError('Test case "%s" not implemented!' % self.name) diff --git a/src/python/grpcio/tests/tests.json b/src/python/grpcio/tests/tests.json index 1beb619f87..8dc47bf69d 100644 --- a/src/python/grpcio/tests/tests.json +++ b/src/python/grpcio/tests/tests.json @@ -1,4 +1,6 @@ [ + "_auth_test.AccessTokenCallCredentialsTest", + "_auth_test.GoogleCallCredentialsTest", "_base_interface_test.AsyncEasyTest", "_base_interface_test.AsyncPeasyTest", "_base_interface_test.SyncEasyTest", @@ -6,6 +8,8 @@ "_beta_features_test.BetaFeaturesTest", "_beta_features_test.ContextManagementAndLifecycleTest", "_cancel_many_calls_test.CancelManyCallsTest", + "_channel_connectivity_test.ChannelConnectivityTest", + "_channel_ready_future_test.ChannelReadyFutureTest", "_channel_test.ChannelTest", "_connectivity_channel_test.ChannelConnectivityTest", "_core_over_links_base_interface_test.AsyncEasyTest", @@ -31,6 +35,7 @@ "_face_interface_test.MultiCallableInvokerBlockingInvocationInlineServiceTest", "_face_interface_test.MultiCallableInvokerFutureInvocationAsynchronousEventServiceTest", "_health_servicer_test.HealthServicerTest", + "_implementations_test.CallCredentialsTest", "_implementations_test.ChannelCredentialsTest", "_insecure_interop_test.InsecureInteropTest", "_intermediary_low_test.CancellationTest", @@ -43,6 +48,8 @@ "_low_test.HangingServerShutdown", "_low_test.InsecureServerInsecureClient", "_not_found_test.NotFoundTest", + "_read_some_but_not_all_responses_test.ReadSomeButNotAllResponsesTest", + "_rpc_test.RPCTest", "_sanity_test.Sanity", "_secure_interop_test.SecureInteropTest", "_transmission_test.RoundTripTest", diff --git a/src/python/grpcio/tests/unit/_channel_connectivity_test.py b/src/python/grpcio/tests/unit/_channel_connectivity_test.py new file mode 100644 index 0000000000..a1575efada --- /dev/null +++ b/src/python/grpcio/tests/unit/_channel_connectivity_test.py @@ -0,0 +1,161 @@ +# Copyright 2015, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests of grpc._channel.Channel connectivity.""" + +import threading +import time +import unittest +from concurrent import futures + +import grpc +from grpc import _channel +from grpc import _server +from tests.unit.framework.common import test_constants + + +def _ready_in_connectivities(connectivities): + return grpc.ChannelConnectivity.READY in connectivities + + +def _last_connectivity_is_not_ready(connectivities): + return connectivities[-1] is not grpc.ChannelConnectivity.READY + + +class _Callback(object): + + def __init__(self): + self._condition = threading.Condition() + self._connectivities = [] + + def update(self, connectivity): + with self._condition: + self._connectivities.append(connectivity) + self._condition.notify() + + def connectivities(self): + with self._condition: + return tuple(self._connectivities) + + def block_until_connectivities_satisfy(self, predicate): + with self._condition: + while True: + connectivities = tuple(self._connectivities) + if predicate(connectivities): + return connectivities + else: + self._condition.wait() + + +class ChannelConnectivityTest(unittest.TestCase): + + def test_lonely_channel_connectivity(self): + callback = _Callback() + + channel = _channel.Channel('localhost:12345', None, None) + channel.subscribe(callback.update, try_to_connect=False) + first_connectivities = callback.block_until_connectivities_satisfy(bool) + channel.subscribe(callback.update, try_to_connect=True) + second_connectivities = callback.block_until_connectivities_satisfy( + lambda connectivities: 2 <= len(connectivities)) + # Wait for a connection that will never happen. + time.sleep(test_constants.SHORT_TIMEOUT) + third_connectivities = callback.connectivities() + channel.unsubscribe(callback.update) + fourth_connectivities = callback.connectivities() + channel.unsubscribe(callback.update) + fifth_connectivities = callback.connectivities() + + self.assertSequenceEqual( + (grpc.ChannelConnectivity.IDLE,), first_connectivities) + self.assertNotIn( + grpc.ChannelConnectivity.READY, second_connectivities) + self.assertNotIn( + grpc.ChannelConnectivity.READY, third_connectivities) + self.assertNotIn( + grpc.ChannelConnectivity.READY, fourth_connectivities) + self.assertNotIn( + grpc.ChannelConnectivity.READY, fifth_connectivities) + + def test_immediately_connectable_channel_connectivity(self): + server = _server.Server((), futures.ThreadPoolExecutor(max_workers=0)) + port = server.add_insecure_port('[::]:0') + server.start() + first_callback = _Callback() + second_callback = _Callback() + + channel = _channel.Channel('localhost:{}'.format(port), None, None) + channel.subscribe(first_callback.update, try_to_connect=False) + first_connectivities = first_callback.block_until_connectivities_satisfy( + bool) + # Wait for a connection that will never happen because try_to_connect=True + # has not yet been passed. + time.sleep(test_constants.SHORT_TIMEOUT) + second_connectivities = first_callback.connectivities() + channel.subscribe(second_callback.update, try_to_connect=True) + third_connectivities = first_callback.block_until_connectivities_satisfy( + lambda connectivities: 2 <= len(connectivities)) + fourth_connectivities = second_callback.block_until_connectivities_satisfy( + bool) + # Wait for a connection that will happen (or may already have happened). + first_callback.block_until_connectivities_satisfy(_ready_in_connectivities) + second_callback.block_until_connectivities_satisfy(_ready_in_connectivities) + del channel + + self.assertSequenceEqual( + (grpc.ChannelConnectivity.IDLE,), first_connectivities) + self.assertSequenceEqual( + (grpc.ChannelConnectivity.IDLE,), second_connectivities) + self.assertNotIn( + grpc.ChannelConnectivity.TRANSIENT_FAILURE, third_connectivities) + self.assertNotIn( + grpc.ChannelConnectivity.FATAL_FAILURE, third_connectivities) + self.assertNotIn( + grpc.ChannelConnectivity.TRANSIENT_FAILURE, + fourth_connectivities) + self.assertNotIn( + grpc.ChannelConnectivity.FATAL_FAILURE, fourth_connectivities) + + def test_reachable_then_unreachable_channel_connectivity(self): + server = _server.Server((), futures.ThreadPoolExecutor(max_workers=0)) + port = server.add_insecure_port('[::]:0') + server.start() + callback = _Callback() + + channel = _channel.Channel('localhost:{}'.format(port), None, None) + channel.subscribe(callback.update, try_to_connect=True) + callback.block_until_connectivities_satisfy(_ready_in_connectivities) + # Now take down the server and confirm that channel readiness is repudiated. + server.stop(None) + callback.block_until_connectivities_satisfy(_last_connectivity_is_not_ready) + channel.unsubscribe(callback.update) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/src/python/grpcio/tests/unit/_channel_ready_future_test.py b/src/python/grpcio/tests/unit/_channel_ready_future_test.py new file mode 100644 index 0000000000..b84bc0197a --- /dev/null +++ b/src/python/grpcio/tests/unit/_channel_ready_future_test.py @@ -0,0 +1,103 @@ +# Copyright 2015, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests of grpc.channel_ready_future.""" + +import threading +import unittest +from concurrent import futures + +import grpc +from grpc import _channel +from grpc import _server +from tests.unit.framework.common import test_constants + + +class _Callback(object): + + def __init__(self): + self._condition = threading.Condition() + self._value = None + + def accept_value(self, value): + with self._condition: + self._value = value + self._condition.notify_all() + + def block_until_called(self): + with self._condition: + while self._value is None: + self._condition.wait() + return self._value + + +class ChannelReadyFutureTest(unittest.TestCase): + + def test_lonely_channel_connectivity(self): + channel = grpc.insecure_channel('localhost:12345') + callback = _Callback() + + ready_future = grpc.channel_ready_future(channel) + ready_future.add_done_callback(callback.accept_value) + with self.assertRaises(grpc.FutureTimeoutError): + ready_future.result(test_constants.SHORT_TIMEOUT) + self.assertFalse(ready_future.cancelled()) + self.assertFalse(ready_future.done()) + self.assertTrue(ready_future.running()) + ready_future.cancel() + value_passed_to_callback = callback.block_until_called() + self.assertIs(ready_future, value_passed_to_callback) + self.assertTrue(ready_future.cancelled()) + self.assertTrue(ready_future.done()) + self.assertFalse(ready_future.running()) + + def test_immediately_connectable_channel_connectivity(self): + server = _server.Server((), futures.ThreadPoolExecutor(max_workers=0)) + port = server.add_insecure_port('[::]:0') + server.start() + channel = grpc.insecure_channel('localhost:{}'.format(port)) + callback = _Callback() + + ready_future = grpc.channel_ready_future(channel) + ready_future.add_done_callback(callback.accept_value) + self.assertIsNone(ready_future.result(test_constants.SHORT_TIMEOUT)) + value_passed_to_callback = callback.block_until_called() + self.assertIs(ready_future, value_passed_to_callback) + self.assertFalse(ready_future.cancelled()) + self.assertTrue(ready_future.done()) + self.assertFalse(ready_future.running()) + # Cancellation after maturity has no effect. + ready_future.cancel() + self.assertFalse(ready_future.cancelled()) + self.assertTrue(ready_future.done()) + self.assertFalse(ready_future.running()) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/src/python/grpcio/tests/unit/_cython/_read_some_but_not_all_responses_test.py b/src/python/grpcio/tests/unit/_cython/_read_some_but_not_all_responses_test.py new file mode 100644 index 0000000000..6ae7a90fbe --- /dev/null +++ b/src/python/grpcio/tests/unit/_cython/_read_some_but_not_all_responses_test.py @@ -0,0 +1,251 @@ +# Copyright 2016, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Test a corner-case at the level of the Cython API.""" + +import threading +import unittest + +from grpc._cython import cygrpc + +_INFINITE_FUTURE = cygrpc.Timespec(float('+inf')) +_EMPTY_FLAGS = 0 +_EMPTY_METADATA = cygrpc.Metadata(()) + + +class _ServerDriver(object): + + def __init__(self, completion_queue, shutdown_tag): + self._condition = threading.Condition() + self._completion_queue = completion_queue + self._shutdown_tag = shutdown_tag + self._events = [] + self._saw_shutdown_tag = False + + def start(self): + def in_thread(): + while True: + event = self._completion_queue.poll() + with self._condition: + self._events.append(event) + self._condition.notify() + if event.tag is self._shutdown_tag: + self._saw_shutdown_tag = True + break + thread = threading.Thread(target=in_thread) + thread.start() + + def done(self): + with self._condition: + return self._saw_shutdown_tag + + def first_event(self): + with self._condition: + while not self._events: + self._condition.wait() + return self._events[0] + + def events(self): + with self._condition: + while not self._saw_shutdown_tag: + self._condition.wait() + return tuple(self._events) + + +class _QueueDriver(object): + + def __init__(self, condition, completion_queue, due): + self._condition = condition + self._completion_queue = completion_queue + self._due = due + self._events = [] + self._returned = False + + def start(self): + def in_thread(): + while True: + event = self._completion_queue.poll() + with self._condition: + self._events.append(event) + self._due.remove(event.tag) + self._condition.notify_all() + if not self._due: + self._returned = True + return + thread = threading.Thread(target=in_thread) + thread.start() + + def done(self): + with self._condition: + return self._returned + + def event_with_tag(self, tag): + with self._condition: + while True: + for event in self._events: + if event.tag is tag: + return event + self._condition.wait() + + def events(self): + with self._condition: + while not self._returned: + self._condition.wait() + return tuple(self._events) + + +class ReadSomeButNotAllResponsesTest(unittest.TestCase): + + def testReadSomeButNotAllResponses(self): + server_completion_queue = cygrpc.CompletionQueue() + server = cygrpc.Server() + server.register_completion_queue(server_completion_queue) + port = server.add_http2_port('[::]:0') + server.start() + channel = cygrpc.Channel('localhost:{}'.format(port)) + + server_shutdown_tag = 'server_shutdown_tag' + server_driver = _ServerDriver(server_completion_queue, server_shutdown_tag) + server_driver.start() + + client_condition = threading.Condition() + client_due = set() + client_completion_queue = cygrpc.CompletionQueue() + client_driver = _QueueDriver( + client_condition, client_completion_queue, client_due) + client_driver.start() + + server_call_condition = threading.Condition() + server_send_initial_metadata_tag = 'server_send_initial_metadata_tag' + server_send_first_message_tag = 'server_send_first_message_tag' + server_send_second_message_tag = 'server_send_second_message_tag' + server_complete_rpc_tag = 'server_complete_rpc_tag' + server_call_due = set(( + server_send_initial_metadata_tag, + server_send_first_message_tag, + server_send_second_message_tag, + server_complete_rpc_tag, + )) + server_call_completion_queue = cygrpc.CompletionQueue() + server_call_driver = _QueueDriver( + server_call_condition, server_call_completion_queue, server_call_due) + server_call_driver.start() + + server_rpc_tag = 'server_rpc_tag' + request_call_result = server.request_call( + server_call_completion_queue, server_completion_queue, server_rpc_tag) + + client_call = channel.create_call( + None, _EMPTY_FLAGS, client_completion_queue, b'/twinkies', None, + _INFINITE_FUTURE) + client_receive_initial_metadata_tag = 'client_receive_initial_metadata_tag' + client_complete_rpc_tag = 'client_complete_rpc_tag' + with client_condition: + client_receive_initial_metadata_start_batch_result = ( + client_call.start_batch(cygrpc.Operations([ + cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS), + ]), client_receive_initial_metadata_tag)) + client_due.add(client_receive_initial_metadata_tag) + client_complete_rpc_start_batch_result = ( + client_call.start_batch(cygrpc.Operations([ + cygrpc.operation_send_initial_metadata( + _EMPTY_METADATA, _EMPTY_FLAGS), + cygrpc.operation_send_close_from_client(_EMPTY_FLAGS), + cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS), + ]), client_complete_rpc_tag)) + client_due.add(client_complete_rpc_tag) + + server_rpc_event = server_driver.first_event() + + with server_call_condition: + server_send_initial_metadata_start_batch_result = ( + server_rpc_event.operation_call.start_batch(cygrpc.Operations([ + cygrpc.operation_send_initial_metadata( + _EMPTY_METADATA, _EMPTY_FLAGS), + ]), server_send_initial_metadata_tag)) + server_send_first_message_start_batch_result = ( + server_rpc_event.operation_call.start_batch(cygrpc.Operations([ + cygrpc.operation_send_message(b'\x07', _EMPTY_FLAGS), + ]), server_send_first_message_tag)) + server_send_initial_metadata_event = server_call_driver.event_with_tag( + server_send_initial_metadata_tag) + server_send_first_message_event = server_call_driver.event_with_tag( + server_send_first_message_tag) + with server_call_condition: + server_send_second_message_start_batch_result = ( + server_rpc_event.operation_call.start_batch(cygrpc.Operations([ + cygrpc.operation_send_message(b'\x07', _EMPTY_FLAGS), + ]), server_send_second_message_tag)) + server_complete_rpc_start_batch_result = ( + server_rpc_event.operation_call.start_batch(cygrpc.Operations([ + cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS), + cygrpc.operation_send_status_from_server( + cygrpc.Metadata(()), cygrpc.StatusCode.ok, b'test details', + _EMPTY_FLAGS), + ]), server_complete_rpc_tag)) + server_send_second_message_event = server_call_driver.event_with_tag( + server_send_second_message_tag) + server_complete_rpc_event = server_call_driver.event_with_tag( + server_complete_rpc_tag) + server_call_driver.events() + + with client_condition: + client_receive_first_message_tag = 'client_receive_first_message_tag' + client_receive_first_message_start_batch_result = ( + client_call.start_batch(cygrpc.Operations([ + cygrpc.operation_receive_message(_EMPTY_FLAGS), + ]), client_receive_first_message_tag)) + client_due.add(client_receive_first_message_tag) + client_receive_first_message_event = client_driver.event_with_tag( + client_receive_first_message_tag) + + client_call_cancel_result = client_call.cancel() + client_driver.events() + + server.shutdown(server_completion_queue, server_shutdown_tag) + server.cancel_all_calls() + server_driver.events() + + self.assertEqual(cygrpc.CallError.ok, request_call_result) + self.assertEqual( + cygrpc.CallError.ok, server_send_initial_metadata_start_batch_result) + self.assertEqual( + cygrpc.CallError.ok, client_receive_initial_metadata_start_batch_result) + self.assertEqual( + cygrpc.CallError.ok, client_complete_rpc_start_batch_result) + self.assertEqual(cygrpc.CallError.ok, client_call_cancel_result) + self.assertIs(server_rpc_tag, server_rpc_event.tag) + self.assertEqual( + cygrpc.CompletionType.operation_complete, server_rpc_event.type) + self.assertIsInstance(server_rpc_event.operation_call, cygrpc.Call) + self.assertEqual(0, len(server_rpc_event.batch_operations)) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/src/python/grpcio/tests/unit/_rpc_test.py b/src/python/grpcio/tests/unit/_rpc_test.py new file mode 100644 index 0000000000..1c7a14c5d0 --- /dev/null +++ b/src/python/grpcio/tests/unit/_rpc_test.py @@ -0,0 +1,775 @@ +# Copyright 2016, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Test of gRPC Python's application-layer API.""" + +import itertools +import threading +import unittest +from concurrent import futures + +import grpc +from grpc.framework.foundation import logging_pool + +from tests.unit.framework.common import test_constants +from tests.unit.framework.common import test_control + +_SERIALIZE_REQUEST = lambda bytestring: bytestring * 2 +_DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) / 2:] +_SERIALIZE_RESPONSE = lambda bytestring: bytestring * 3 +_DESERIALIZE_RESPONSE = lambda bytestring: bytestring[:len(bytestring) / 3] + +_UNARY_UNARY = b'/test/UnaryUnary' +_UNARY_STREAM = b'/test/UnaryStream' +_STREAM_UNARY = b'/test/StreamUnary' +_STREAM_STREAM = b'/test/StreamStream' + + +class _Callback(object): + + def __init__(self): + self._condition = threading.Condition() + self._value = None + self._called = False + + def __call__(self, value): + with self._condition: + self._value = value + self._called = True + self._condition.notify_all() + + def value(self): + with self._condition: + while not self._called: + self._condition.wait() + return self._value + + +class _Handler(object): + + def __init__(self, control): + self._control = control + + def handle_unary_unary(self, request, servicer_context): + self._control.control() + if servicer_context is not None: + servicer_context.set_trailing_metadata(((b'testkey', b'testvalue',),)) + return request + + def handle_unary_stream(self, request, servicer_context): + for _ in range(test_constants.STREAM_LENGTH): + self._control.control() + yield request + self._control.control() + if servicer_context is not None: + servicer_context.set_trailing_metadata(((b'testkey', b'testvalue',),)) + + def handle_stream_unary(self, request_iterator, servicer_context): + if servicer_context is not None: + servicer_context.invocation_metadata() + self._control.control() + response_elements = [] + for request in request_iterator: + self._control.control() + response_elements.append(request) + self._control.control() + if servicer_context is not None: + servicer_context.set_trailing_metadata(((b'testkey', b'testvalue',),)) + return b''.join(response_elements) + + def handle_stream_stream(self, request_iterator, servicer_context): + self._control.control() + if servicer_context is not None: + servicer_context.set_trailing_metadata(((b'testkey', b'testvalue',),)) + for request in request_iterator: + self._control.control() + yield request + self._control.control() + + +class _MethodHandler(grpc.RpcMethodHandler): + + def __init__( + self, request_streaming, response_streaming, request_deserializer, + response_serializer, unary_unary, unary_stream, stream_unary, + stream_stream): + self.request_streaming = request_streaming + self.response_streaming = response_streaming + self.request_deserializer = request_deserializer + self.response_serializer = response_serializer + self.unary_unary = unary_unary + self.unary_stream = unary_stream + self.stream_unary = stream_unary + self.stream_stream = stream_stream + + +class _GenericHandler(grpc.GenericRpcHandler): + + def __init__(self, handler): + self._handler = handler + + def service(self, handler_call_details): + if handler_call_details.method == _UNARY_UNARY: + return _MethodHandler( + False, False, None, None, self._handler.handle_unary_unary, None, + None, None) + elif handler_call_details.method == _UNARY_STREAM: + return _MethodHandler( + False, True, _DESERIALIZE_REQUEST, _SERIALIZE_RESPONSE, None, + self._handler.handle_unary_stream, None, None) + elif handler_call_details.method == _STREAM_UNARY: + return _MethodHandler( + True, False, _DESERIALIZE_REQUEST, _SERIALIZE_RESPONSE, None, None, + self._handler.handle_stream_unary, None) + elif handler_call_details.method == _STREAM_STREAM: + return _MethodHandler( + True, True, None, None, None, None, None, + self._handler.handle_stream_stream) + else: + return None + + +def _unary_unary_multi_callable(channel): + return channel.unary_unary(_UNARY_UNARY) + + +def _unary_stream_multi_callable(channel): + return channel.unary_stream( + _UNARY_STREAM, + request_serializer=_SERIALIZE_REQUEST, + response_deserializer=_DESERIALIZE_RESPONSE) + + +def _stream_unary_multi_callable(channel): + return channel.stream_unary( + _STREAM_UNARY, + request_serializer=_SERIALIZE_REQUEST, + response_deserializer=_DESERIALIZE_RESPONSE) + + +def _stream_stream_multi_callable(channel): + return channel.stream_stream(_STREAM_STREAM) + + +class RPCTest(unittest.TestCase): + + def setUp(self): + self._control = test_control.PauseFailControl() + self._handler = _Handler(self._control) + self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY) + + self._server = grpc.server((), self._server_pool) + port = self._server.add_insecure_port(b'[::]:0') + self._server.add_generic_rpc_handlers((_GenericHandler(self._handler),)) + self._server.start() + + self._channel = grpc.insecure_channel(b'localhost:%d' % port) + + # TODO(nathaniel): Why is this necessary, and only in some development + # environments? + def tearDown(self): + del self._channel + del self._server + del self._server_pool + + def testUnrecognizedMethod(self): + request = b'abc' + + with self.assertRaises(grpc.RpcError) as exception_context: + self._channel.unary_unary(b'NoSuchMethod')(request) + + self.assertEqual( + grpc.StatusCode.UNIMPLEMENTED, exception_context.exception.code()) + + def testSuccessfulUnaryRequestBlockingUnaryResponse(self): + request = b'\x07\x08' + expected_response = self._handler.handle_unary_unary(request, None) + + multi_callable = _unary_unary_multi_callable(self._channel) + response = multi_callable( + request, metadata=( + (b'test', b'SuccessfulUnaryRequestBlockingUnaryResponse'),)) + + self.assertEqual(expected_response, response) + + def testSuccessfulUnaryRequestBlockingUnaryResponseWithCall(self): + request = b'\x07\x08' + expected_response = self._handler.handle_unary_unary(request, None) + + multi_callable = _unary_unary_multi_callable(self._channel) + response, call = multi_callable( + request, metadata=( + (b'test', b'SuccessfulUnaryRequestBlockingUnaryResponseWithCall'),), + with_call=True) + + self.assertEqual(expected_response, response) + self.assertIs(grpc.StatusCode.OK, call.code()) + + def testSuccessfulUnaryRequestFutureUnaryResponse(self): + request = b'\x07\x08' + expected_response = self._handler.handle_unary_unary(request, None) + + multi_callable = _unary_unary_multi_callable(self._channel) + response_future = multi_callable.future( + request, metadata=( + (b'test', b'SuccessfulUnaryRequestFutureUnaryResponse'),)) + response = response_future.result() + + self.assertEqual(expected_response, response) + + def testSuccessfulUnaryRequestStreamResponse(self): + request = b'\x37\x58' + expected_responses = tuple(self._handler.handle_unary_stream(request, None)) + + multi_callable = _unary_stream_multi_callable(self._channel) + response_iterator = multi_callable( + request, + metadata=((b'test', b'SuccessfulUnaryRequestStreamResponse'),)) + responses = tuple(response_iterator) + + self.assertSequenceEqual(expected_responses, responses) + + def testSuccessfulStreamRequestBlockingUnaryResponse(self): + requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + expected_response = self._handler.handle_stream_unary(iter(requests), None) + request_iterator = iter(requests) + + multi_callable = _stream_unary_multi_callable(self._channel) + response = multi_callable( + request_iterator, + metadata=((b'test', b'SuccessfulStreamRequestBlockingUnaryResponse'),)) + + self.assertEqual(expected_response, response) + + def testSuccessfulStreamRequestBlockingUnaryResponseWithCall(self): + requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + expected_response = self._handler.handle_stream_unary(iter(requests), None) + request_iterator = iter(requests) + + multi_callable = _stream_unary_multi_callable(self._channel) + response, call = multi_callable( + request_iterator, + metadata=( + (b'test', b'SuccessfulStreamRequestBlockingUnaryResponseWithCall'), + ), with_call=True) + + self.assertEqual(expected_response, response) + self.assertIs(grpc.StatusCode.OK, call.code()) + + def testSuccessfulStreamRequestFutureUnaryResponse(self): + requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + expected_response = self._handler.handle_stream_unary(iter(requests), None) + request_iterator = iter(requests) + + multi_callable = _stream_unary_multi_callable(self._channel) + response_future = multi_callable.future( + request_iterator, + metadata=( + (b'test', b'SuccessfulStreamRequestFutureUnaryResponse'),)) + response = response_future.result() + + self.assertEqual(expected_response, response) + + def testSuccessfulStreamRequestStreamResponse(self): + requests = tuple(b'\x77\x58' for _ in range(test_constants.STREAM_LENGTH)) + expected_responses = tuple( + self._handler.handle_stream_stream(iter(requests), None)) + request_iterator = iter(requests) + + multi_callable = _stream_stream_multi_callable(self._channel) + response_iterator = multi_callable( + request_iterator, + metadata=((b'test', b'SuccessfulStreamRequestStreamResponse'),)) + responses = tuple(response_iterator) + + self.assertSequenceEqual(expected_responses, responses) + + def testSequentialInvocations(self): + first_request = b'\x07\x08' + second_request = b'\x0809' + expected_first_response = self._handler.handle_unary_unary( + first_request, None) + expected_second_response = self._handler.handle_unary_unary( + second_request, None) + + multi_callable = _unary_unary_multi_callable(self._channel) + first_response = multi_callable( + first_request, metadata=((b'test', b'SequentialInvocations'),)) + second_response = multi_callable( + second_request, metadata=((b'test', b'SequentialInvocations'),)) + + self.assertEqual(expected_first_response, first_response) + self.assertEqual(expected_second_response, second_response) + + def testConcurrentBlockingInvocations(self): + pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY) + requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + expected_response = self._handler.handle_stream_unary(iter(requests), None) + expected_responses = [expected_response] * test_constants.THREAD_CONCURRENCY + response_futures = [None] * test_constants.THREAD_CONCURRENCY + + multi_callable = _stream_unary_multi_callable(self._channel) + for index in range(test_constants.THREAD_CONCURRENCY): + request_iterator = iter(requests) + response_future = pool.submit( + multi_callable, request_iterator, + metadata=((b'test', b'ConcurrentBlockingInvocations'),)) + response_futures[index] = response_future + responses = tuple( + response_future.result() for response_future in response_futures) + + pool.shutdown(wait=True) + self.assertSequenceEqual(expected_responses, responses) + + def testConcurrentFutureInvocations(self): + requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + expected_response = self._handler.handle_stream_unary(iter(requests), None) + expected_responses = [expected_response] * test_constants.THREAD_CONCURRENCY + response_futures = [None] * test_constants.THREAD_CONCURRENCY + + multi_callable = _stream_unary_multi_callable(self._channel) + for index in range(test_constants.THREAD_CONCURRENCY): + request_iterator = iter(requests) + response_future = multi_callable.future( + request_iterator, + metadata=((b'test', b'ConcurrentFutureInvocations'),)) + response_futures[index] = response_future + responses = tuple( + response_future.result() for response_future in response_futures) + + self.assertSequenceEqual(expected_responses, responses) + + def testWaitingForSomeButNotAllConcurrentFutureInvocations(self): + pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY) + request = b'\x67\x68' + expected_response = self._handler.handle_unary_unary(request, None) + response_futures = [None] * test_constants.THREAD_CONCURRENCY + lock = threading.Lock() + test_is_running_cell = [True] + def wrap_future(future): + def wrap(): + try: + return future.result() + except grpc.RpcError: + with lock: + if test_is_running_cell[0]: + raise + return None + return wrap + + multi_callable = _unary_unary_multi_callable(self._channel) + for index in range(test_constants.THREAD_CONCURRENCY): + inner_response_future = multi_callable.future( + request, + metadata=( + (b'test', + b'WaitingForSomeButNotAllConcurrentFutureInvocations'),)) + outer_response_future = pool.submit(wrap_future(inner_response_future)) + response_futures[index] = outer_response_future + + some_completed_response_futures_iterator = itertools.islice( + futures.as_completed(response_futures), + test_constants.THREAD_CONCURRENCY // 2) + for response_future in some_completed_response_futures_iterator: + self.assertEqual(expected_response, response_future.result()) + with lock: + test_is_running_cell[0] = False + + def testConsumingOneStreamResponseUnaryRequest(self): + request = b'\x57\x38' + + multi_callable = _unary_stream_multi_callable(self._channel) + response_iterator = multi_callable( + request, + metadata=( + (b'test', b'ConsumingOneStreamResponseUnaryRequest'),)) + next(response_iterator) + + def testConsumingSomeButNotAllStreamResponsesUnaryRequest(self): + request = b'\x57\x38' + + multi_callable = _unary_stream_multi_callable(self._channel) + response_iterator = multi_callable( + request, + metadata=( + (b'test', b'ConsumingSomeButNotAllStreamResponsesUnaryRequest'),)) + for _ in range(test_constants.STREAM_LENGTH // 2): + next(response_iterator) + + def testConsumingSomeButNotAllStreamResponsesStreamRequest(self): + requests = tuple(b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + multi_callable = _stream_stream_multi_callable(self._channel) + response_iterator = multi_callable( + request_iterator, + metadata=( + (b'test', b'ConsumingSomeButNotAllStreamResponsesStreamRequest'),)) + for _ in range(test_constants.STREAM_LENGTH // 2): + next(response_iterator) + + def testConsumingTooManyStreamResponsesStreamRequest(self): + requests = tuple(b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + multi_callable = _stream_stream_multi_callable(self._channel) + response_iterator = multi_callable( + request_iterator, + metadata=( + (b'test', b'ConsumingTooManyStreamResponsesStreamRequest'),)) + for _ in range(test_constants.STREAM_LENGTH): + next(response_iterator) + for _ in range(test_constants.STREAM_LENGTH): + with self.assertRaises(StopIteration): + next(response_iterator) + + self.assertIsNotNone(response_iterator.initial_metadata()) + self.assertIs(grpc.StatusCode.OK, response_iterator.code()) + self.assertIsNotNone(response_iterator.details()) + self.assertIsNotNone(response_iterator.trailing_metadata()) + + def testCancelledUnaryRequestUnaryResponse(self): + request = b'\x07\x17' + + multi_callable = _unary_unary_multi_callable(self._channel) + with self._control.pause(): + response_future = multi_callable.future( + request, + metadata=((b'test', b'CancelledUnaryRequestUnaryResponse'),)) + response_future.cancel() + + self.assertTrue(response_future.cancelled()) + with self.assertRaises(grpc.FutureCancelledError): + response_future.result() + self.assertIs(grpc.StatusCode.CANCELLED, response_future.code()) + + def testCancelledUnaryRequestStreamResponse(self): + request = b'\x07\x19' + + multi_callable = _unary_stream_multi_callable(self._channel) + with self._control.pause(): + response_iterator = multi_callable( + request, + metadata=((b'test', b'CancelledUnaryRequestStreamResponse'),)) + self._control.block_until_paused() + response_iterator.cancel() + + with self.assertRaises(grpc.RpcError) as exception_context: + next(response_iterator) + self.assertIs(grpc.StatusCode.CANCELLED, exception_context.exception.code()) + self.assertIsNotNone(response_iterator.initial_metadata()) + self.assertIs(grpc.StatusCode.CANCELLED, response_iterator.code()) + self.assertIsNotNone(response_iterator.details()) + self.assertIsNotNone(response_iterator.trailing_metadata()) + + def testCancelledStreamRequestUnaryResponse(self): + requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + multi_callable = _stream_unary_multi_callable(self._channel) + with self._control.pause(): + response_future = multi_callable.future( + request_iterator, + metadata=((b'test', b'CancelledStreamRequestUnaryResponse'),)) + self._control.block_until_paused() + response_future.cancel() + + self.assertTrue(response_future.cancelled()) + with self.assertRaises(grpc.FutureCancelledError): + response_future.result() + self.assertIsNotNone(response_future.initial_metadata()) + self.assertIs(grpc.StatusCode.CANCELLED, response_future.code()) + self.assertIsNotNone(response_future.details()) + self.assertIsNotNone(response_future.trailing_metadata()) + + def testCancelledStreamRequestStreamResponse(self): + requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + multi_callable = _stream_stream_multi_callable(self._channel) + with self._control.pause(): + response_iterator = multi_callable( + request_iterator, + metadata=((b'test', b'CancelledStreamRequestStreamResponse'),)) + response_iterator.cancel() + + with self.assertRaises(grpc.RpcError): + next(response_iterator) + self.assertIsNotNone(response_iterator.initial_metadata()) + self.assertIs(grpc.StatusCode.CANCELLED, response_iterator.code()) + self.assertIsNotNone(response_iterator.details()) + self.assertIsNotNone(response_iterator.trailing_metadata()) + + def testExpiredUnaryRequestBlockingUnaryResponse(self): + request = b'\x07\x17' + + multi_callable = _unary_unary_multi_callable(self._channel) + with self._control.pause(): + with self.assertRaises(grpc.RpcError) as exception_context: + multi_callable( + request, timeout=test_constants.SHORT_TIMEOUT, + metadata=((b'test', b'ExpiredUnaryRequestBlockingUnaryResponse'),), + with_call=True) + + self.assertIsNotNone(exception_context.exception.initial_metadata()) + self.assertIs( + grpc.StatusCode.DEADLINE_EXCEEDED, exception_context.exception.code()) + self.assertIsNotNone(exception_context.exception.details()) + self.assertIsNotNone(exception_context.exception.trailing_metadata()) + + def testExpiredUnaryRequestFutureUnaryResponse(self): + request = b'\x07\x17' + callback = _Callback() + + multi_callable = _unary_unary_multi_callable(self._channel) + with self._control.pause(): + response_future = multi_callable.future( + request, timeout=test_constants.SHORT_TIMEOUT, + metadata=((b'test', b'ExpiredUnaryRequestFutureUnaryResponse'),)) + response_future.add_done_callback(callback) + value_passed_to_callback = callback.value() + + self.assertIs(response_future, value_passed_to_callback) + self.assertIsNotNone(response_future.initial_metadata()) + self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, response_future.code()) + self.assertIsNotNone(response_future.details()) + self.assertIsNotNone(response_future.trailing_metadata()) + with self.assertRaises(grpc.RpcError) as exception_context: + response_future.result() + self.assertIs( + grpc.StatusCode.DEADLINE_EXCEEDED, exception_context.exception.code()) + self.assertIsInstance(response_future.exception(), grpc.RpcError) + self.assertIs( + grpc.StatusCode.DEADLINE_EXCEEDED, response_future.exception().code()) + + def testExpiredUnaryRequestStreamResponse(self): + request = b'\x07\x19' + + multi_callable = _unary_stream_multi_callable(self._channel) + with self._control.pause(): + with self.assertRaises(grpc.RpcError) as exception_context: + response_iterator = multi_callable( + request, timeout=test_constants.SHORT_TIMEOUT, + metadata=((b'test', b'ExpiredUnaryRequestStreamResponse'),)) + next(response_iterator) + + self.assertIs( + grpc.StatusCode.DEADLINE_EXCEEDED, exception_context.exception.code()) + self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, response_iterator.code()) + + def testExpiredStreamRequestBlockingUnaryResponse(self): + requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + multi_callable = _stream_unary_multi_callable(self._channel) + with self._control.pause(): + with self.assertRaises(grpc.RpcError) as exception_context: + multi_callable( + request_iterator, timeout=test_constants.SHORT_TIMEOUT, + metadata=((b'test', b'ExpiredStreamRequestBlockingUnaryResponse'),)) + + self.assertIsNotNone(exception_context.exception.initial_metadata()) + self.assertIs( + grpc.StatusCode.DEADLINE_EXCEEDED, exception_context.exception.code()) + self.assertIsNotNone(exception_context.exception.details()) + self.assertIsNotNone(exception_context.exception.trailing_metadata()) + + def testExpiredStreamRequestFutureUnaryResponse(self): + requests = tuple(b'\x07\x18' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + callback = _Callback() + + multi_callable = _stream_unary_multi_callable(self._channel) + with self._control.pause(): + response_future = multi_callable.future( + request_iterator, timeout=test_constants.SHORT_TIMEOUT, + metadata=((b'test', b'ExpiredStreamRequestFutureUnaryResponse'),)) + response_future.add_done_callback(callback) + value_passed_to_callback = callback.value() + + with self.assertRaises(grpc.RpcError) as exception_context: + response_future.result() + self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, response_future.code()) + self.assertIs( + grpc.StatusCode.DEADLINE_EXCEEDED, exception_context.exception.code()) + self.assertIsInstance(response_future.exception(), grpc.RpcError) + self.assertIs(response_future, value_passed_to_callback) + self.assertIsNotNone(response_future.initial_metadata()) + self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, response_future.code()) + self.assertIsNotNone(response_future.details()) + self.assertIsNotNone(response_future.trailing_metadata()) + + def testExpiredStreamRequestStreamResponse(self): + requests = tuple(b'\x67\x18' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + multi_callable = _stream_stream_multi_callable(self._channel) + with self._control.pause(): + with self.assertRaises(grpc.RpcError) as exception_context: + response_iterator = multi_callable( + request_iterator, timeout=test_constants.SHORT_TIMEOUT, + metadata=((b'test', b'ExpiredStreamRequestStreamResponse'),)) + next(response_iterator) + + self.assertIs( + grpc.StatusCode.DEADLINE_EXCEEDED, exception_context.exception.code()) + self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, response_iterator.code()) + + def testFailedUnaryRequestBlockingUnaryResponse(self): + request = b'\x37\x17' + + multi_callable = _unary_unary_multi_callable(self._channel) + with self._control.fail(): + with self.assertRaises(grpc.RpcError) as exception_context: + multi_callable( + request, + metadata=((b'test', b'FailedUnaryRequestBlockingUnaryResponse'),), + with_call=True) + + self.assertIs(grpc.StatusCode.UNKNOWN, exception_context.exception.code()) + + def testFailedUnaryRequestFutureUnaryResponse(self): + request = b'\x37\x17' + callback = _Callback() + + multi_callable = _unary_unary_multi_callable(self._channel) + with self._control.fail(): + response_future = multi_callable.future( + request, + metadata=((b'test', b'FailedUnaryRequestFutureUnaryResponse'),)) + response_future.add_done_callback(callback) + value_passed_to_callback = callback.value() + + with self.assertRaises(grpc.RpcError) as exception_context: + response_future.result() + self.assertIs( + grpc.StatusCode.UNKNOWN, exception_context.exception.code()) + self.assertIsInstance(response_future.exception(), grpc.RpcError) + self.assertIs(grpc.StatusCode.UNKNOWN, response_future.exception().code()) + self.assertIs(response_future, value_passed_to_callback) + + def testFailedUnaryRequestStreamResponse(self): + request = b'\x37\x17' + + multi_callable = _unary_stream_multi_callable(self._channel) + with self.assertRaises(grpc.RpcError) as exception_context: + with self._control.fail(): + response_iterator = multi_callable( + request, + metadata=((b'test', b'FailedUnaryRequestStreamResponse'),)) + next(response_iterator) + + self.assertIs(grpc.StatusCode.UNKNOWN, exception_context.exception.code()) + + def testFailedStreamRequestBlockingUnaryResponse(self): + requests = tuple(b'\x47\x58' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + multi_callable = _stream_unary_multi_callable(self._channel) + with self._control.fail(): + with self.assertRaises(grpc.RpcError) as exception_context: + multi_callable( + request_iterator, + metadata=((b'test', b'FailedStreamRequestBlockingUnaryResponse'),)) + + self.assertIs(grpc.StatusCode.UNKNOWN, exception_context.exception.code()) + + def testFailedStreamRequestFutureUnaryResponse(self): + requests = tuple(b'\x07\x18' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + callback = _Callback() + + multi_callable = _stream_unary_multi_callable(self._channel) + with self._control.fail(): + response_future = multi_callable.future( + request_iterator, + metadata=((b'test', b'FailedStreamRequestFutureUnaryResponse'),)) + response_future.add_done_callback(callback) + value_passed_to_callback = callback.value() + + with self.assertRaises(grpc.RpcError) as exception_context: + response_future.result() + self.assertIs(grpc.StatusCode.UNKNOWN, response_future.code()) + self.assertIs( + grpc.StatusCode.UNKNOWN, exception_context.exception.code()) + self.assertIsInstance(response_future.exception(), grpc.RpcError) + self.assertIs(response_future, value_passed_to_callback) + + def testFailedStreamRequestStreamResponse(self): + requests = tuple(b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + multi_callable = _stream_stream_multi_callable(self._channel) + with self._control.fail(): + with self.assertRaises(grpc.RpcError) as exception_context: + response_iterator = multi_callable( + request_iterator, + metadata=((b'test', b'FailedStreamRequestStreamResponse'),)) + tuple(response_iterator) + + self.assertIs(grpc.StatusCode.UNKNOWN, exception_context.exception.code()) + self.assertIs(grpc.StatusCode.UNKNOWN, response_iterator.code()) + + def testIgnoredUnaryRequestFutureUnaryResponse(self): + request = b'\x37\x17' + + multi_callable = _unary_unary_multi_callable(self._channel) + multi_callable.future( + request, + metadata=((b'test', b'IgnoredUnaryRequestFutureUnaryResponse'),)) + + def testIgnoredUnaryRequestStreamResponse(self): + request = b'\x37\x17' + + multi_callable = _unary_stream_multi_callable(self._channel) + multi_callable( + request, + metadata=((b'test', b'IgnoredUnaryRequestStreamResponse'),)) + + def testIgnoredStreamRequestFutureUnaryResponse(self): + requests = tuple(b'\x07\x18' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + multi_callable = _stream_unary_multi_callable(self._channel) + multi_callable.future( + request_iterator, + metadata=((b'test', b'IgnoredStreamRequestFutureUnaryResponse'),)) + + def testIgnoredStreamRequestStreamResponse(self): + requests = tuple(b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + multi_callable = _stream_stream_multi_callable(self._channel) + multi_callable( + request_iterator, + metadata=((b'test', b'IgnoredStreamRequestStreamResponse'),)) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/src/python/grpcio/tests/unit/beta/_auth_test.py b/src/python/grpcio/tests/unit/beta/_auth_test.py new file mode 100644 index 0000000000..694928a91b --- /dev/null +++ b/src/python/grpcio/tests/unit/beta/_auth_test.py @@ -0,0 +1,96 @@ +# Copyright 2016, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests of standard AuthMetadataPlugins.""" + +import collections +import threading +import unittest + +from grpc.beta import _auth + + +class MockGoogleCreds(object): + + def get_access_token(self): + token = collections.namedtuple('MockAccessTokenInfo', + ('access_token', 'expires_in')) + token.access_token = 'token' + return token + + +class MockExceptionGoogleCreds(object): + + def get_access_token(self): + raise Exception() + + +class GoogleCallCredentialsTest(unittest.TestCase): + + def test_google_call_credentials_success(self): + callback_event = threading.Event() + + def mock_callback(metadata, error): + self.assertEqual(metadata, (('authorization', 'Bearer token'),)) + self.assertIsNone(error) + callback_event.set() + + call_creds = _auth.GoogleCallCredentials(MockGoogleCreds()) + call_creds(None, mock_callback) + self.assertTrue(callback_event.wait(1.0)) + + def test_google_call_credentials_error(self): + callback_event = threading.Event() + + def mock_callback(metadata, error): + self.assertIsNotNone(error) + callback_event.set() + + call_creds = _auth.GoogleCallCredentials(MockExceptionGoogleCreds()) + call_creds(None, mock_callback) + self.assertTrue(callback_event.wait(1.0)) + + +class AccessTokenCallCredentialsTest(unittest.TestCase): + + def test_google_call_credentials_success(self): + callback_event = threading.Event() + + def mock_callback(metadata, error): + self.assertEqual(metadata, (('authorization', 'Bearer token'),)) + self.assertIsNone(error) + callback_event.set() + + call_creds = _auth.AccessTokenCallCredentials('token') + call_creds(None, mock_callback) + self.assertTrue(callback_event.wait(1.0)) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/src/python/grpcio/tests/unit/beta/_implementations_test.py b/src/python/grpcio/tests/unit/beta/_implementations_test.py index 26be670c45..127f93e9bb 100644 --- a/src/python/grpcio/tests/unit/beta/_implementations_test.py +++ b/src/python/grpcio/tests/unit/beta/_implementations_test.py @@ -29,8 +29,11 @@ """Tests the implementations module of the gRPC Python Beta API.""" +import datetime import unittest +from oauth2client import client as oauth2client_client + from grpc.beta import implementations from tests.unit import resources @@ -49,5 +52,19 @@ class ChannelCredentialsTest(unittest.TestCase): channel_credentials, implementations.ChannelCredentials) +class CallCredentialsTest(unittest.TestCase): + + def test_google_call_credentials(self): + creds = oauth2client_client.GoogleCredentials( + 'token', 'client_id', 'secret', 'refresh_token', + datetime.datetime(2008, 6, 24), 'https://refresh.uri.com/', + 'user_agent') + call_creds = implementations.google_call_credentials(creds) + self.assertIsInstance(call_creds, implementations.CallCredentials) + + def test_access_token_call_credentials(self): + call_creds = implementations.access_token_call_credentials('token') + self.assertIsInstance(call_creds, implementations.CallCredentials) + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/src/python/grpcio/tests/unit/framework/common/test_control.py b/src/python/grpcio/tests/unit/framework/common/test_control.py index ca5ba3a854..088e2f8b88 100644 --- a/src/python/grpcio/tests/unit/framework/common/test_control.py +++ b/src/python/grpcio/tests/unit/framework/common/test_control.py @@ -60,10 +60,16 @@ class Control(six.with_metaclass(abc.ABCMeta)): class PauseFailControl(Control): - """A Control that can be used to pause or fail code under control.""" + """A Control that can be used to pause or fail code under control. + + This object is only safe for use from two threads: one of the system under + test calling control and the other from the test system calling pause, + block_until_paused, and fail. + """ def __init__(self): self._condition = threading.Condition() + self._pause = False self._paused = False self._fail = False @@ -72,19 +78,31 @@ class PauseFailControl(Control): if self._fail: raise Defect() - while self._paused: + while self._pause: + self._paused = True + self._condition.notify_all() self._condition.wait() + self._paused = False @contextlib.contextmanager def pause(self): """Pauses code under control while controlling code is in context.""" with self._condition: - self._paused = True + self._pause = True yield with self._condition: - self._paused = False + self._pause = False self._condition.notify_all() + def block_until_paused(self): + """Blocks controlling code until code under control is paused. + + May only be called within the context of a pause call. + """ + with self._condition: + while not self._paused: + self._condition.wait() + @contextlib.contextmanager def fail(self): """Fails code under control while controlling code is in context.""" |