aboutsummaryrefslogtreecommitdiffhomepage
path: root/src/python/grpcio_tests
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/grpcio_tests')
-rw-r--r--src/python/grpcio_tests/grpc_version.py2
-rw-r--r--src/python/grpcio_tests/tests/tests.json1
-rw-r--r--src/python/grpcio_tests/tests/unit/_api_test.py13
-rw-r--r--src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py15
-rw-r--r--src/python/grpcio_tests/tests/unit/_cython/_common.py15
-rw-r--r--src/python/grpcio_tests/tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py13
-rw-r--r--src/python/grpcio_tests/tests/unit/_cython/_no_messages_single_server_completion_queue_test.py13
-rw-r--r--src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py33
-rw-r--r--src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py81
-rw-r--r--src/python/grpcio_tests/tests/unit/_interceptor_test.py571
-rw-r--r--src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py237
11 files changed, 811 insertions, 183 deletions
diff --git a/src/python/grpcio_tests/grpc_version.py b/src/python/grpcio_tests/grpc_version.py
index 99ca3fd82d..b1b4d7e0c2 100644
--- a/src/python/grpcio_tests/grpc_version.py
+++ b/src/python/grpcio_tests/grpc_version.py
@@ -14,4 +14,4 @@
# AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio_tests/grpc_version.py.template`!!!
-VERSION='1.9.0.dev0'
+VERSION = '1.9.0.dev0'
diff --git a/src/python/grpcio_tests/tests/tests.json b/src/python/grpcio_tests/tests/tests.json
index 34cbade92c..3bf5308749 100644
--- a/src/python/grpcio_tests/tests/tests.json
+++ b/src/python/grpcio_tests/tests/tests.json
@@ -39,6 +39,7 @@
"unit._cython.cygrpc_test.TypeSmokeTest",
"unit._empty_message_test.EmptyMessageTest",
"unit._exit_test.ExitTest",
+ "unit._interceptor_test.InterceptorTest",
"unit._invalid_metadata_test.InvalidMetadataTest",
"unit._invocation_defects_test.InvocationDefectsTest",
"unit._metadata_code_details_test.MetadataCodeDetailsTest",
diff --git a/src/python/grpcio_tests/tests/unit/_api_test.py b/src/python/grpcio_tests/tests/unit/_api_test.py
index b14e8d5c75..d6f4447532 100644
--- a/src/python/grpcio_tests/tests/unit/_api_test.py
+++ b/src/python/grpcio_tests/tests/unit/_api_test.py
@@ -33,18 +33,21 @@ class AllTest(unittest.TestCase):
'AuthMetadataPlugin', 'ServerCertificateConfiguration',
'ServerCredentials', 'UnaryUnaryMultiCallable',
'UnaryStreamMultiCallable', 'StreamUnaryMultiCallable',
- 'StreamStreamMultiCallable', 'Channel', 'ServicerContext',
+ 'StreamStreamMultiCallable', 'UnaryUnaryClientInterceptor',
+ 'UnaryStreamClientInterceptor', 'StreamUnaryClientInterceptor',
+ 'StreamStreamClientInterceptor', 'Channel', 'ServicerContext',
'RpcMethodHandler', 'HandlerCallDetails', 'GenericRpcHandler',
- 'ServiceRpcHandler', 'Server', 'unary_unary_rpc_method_handler',
- 'unary_stream_rpc_method_handler',
- 'stream_unary_rpc_method_handler',
+ 'ServiceRpcHandler', 'Server', 'ServerInterceptor',
+ 'unary_unary_rpc_method_handler', 'unary_stream_rpc_method_handler',
+ 'stream_unary_rpc_method_handler', 'ClientCallDetails',
'stream_stream_rpc_method_handler',
'method_handlers_generic_handler', 'ssl_channel_credentials',
'metadata_call_credentials', 'access_token_call_credentials',
'composite_call_credentials', 'composite_channel_credentials',
'ssl_server_credentials', 'ssl_server_certificate_configuration',
'dynamic_ssl_server_credentials', 'channel_ready_future',
- 'insecure_channel', 'secure_channel', 'server',)
+ 'insecure_channel', 'secure_channel', 'intercept_channel',
+ 'server',)
six.assertCountEqual(self, expected_grpc_code_elements,
_from_grpc_import_star.GRPC_ELEMENTS)
diff --git a/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py b/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py
index 5b97b7b542..a8a7175cc7 100644
--- a/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py
+++ b/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py
@@ -22,7 +22,7 @@ from tests.unit.framework.common import test_constants
_INFINITE_FUTURE = cygrpc.Timespec(float('+inf'))
_EMPTY_FLAGS = 0
-_EMPTY_METADATA = cygrpc.Metadata(())
+_EMPTY_METADATA = ()
_SERVER_SHUTDOWN_TAG = 'server_shutdown'
_REQUEST_CALL_TAG = 'request_call'
@@ -65,12 +65,10 @@ class _Handler(object):
with self._lock:
self._call.start_server_batch(
- cygrpc.Operations(
- (cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),)),
+ (cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),),
_RECEIVE_CLOSE_ON_SERVER_TAG)
self._call.start_server_batch(
- cygrpc.Operations(
- (cygrpc.operation_receive_message(_EMPTY_FLAGS),)),
+ (cygrpc.operation_receive_message(_EMPTY_FLAGS),),
_RECEIVE_MESSAGE_TAG)
first_event = self._completion_queue.poll()
if _is_cancellation_event(first_event):
@@ -84,8 +82,8 @@ class _Handler(object):
cygrpc.operation_send_status_from_server(
_EMPTY_METADATA, cygrpc.StatusCode.ok, b'test details!',
_EMPTY_FLAGS),)
- self._call.start_server_batch(
- cygrpc.Operations(operations), _SERVER_COMPLETE_CALL_TAG)
+ self._call.start_server_batch(operations,
+ _SERVER_COMPLETE_CALL_TAG)
self._completion_queue.poll()
self._completion_queue.poll()
@@ -179,8 +177,7 @@ class CancelManyCallsTest(unittest.TestCase):
cygrpc.operation_receive_message(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),)
tag = 'client_complete_call_{0:04d}_tag'.format(index)
- client_call.start_client_batch(
- cygrpc.Operations(operations), tag)
+ client_call.start_client_batch(operations, tag)
client_due.add(tag)
client_calls.append(client_call)
diff --git a/src/python/grpcio_tests/tests/unit/_cython/_common.py b/src/python/grpcio_tests/tests/unit/_cython/_common.py
index ac66d1db3d..96f0f1589b 100644
--- a/src/python/grpcio_tests/tests/unit/_cython/_common.py
+++ b/src/python/grpcio_tests/tests/unit/_cython/_common.py
@@ -23,17 +23,14 @@ RPC_COUNT = 4000
INFINITE_FUTURE = cygrpc.Timespec(float('+inf'))
EMPTY_FLAGS = 0
-INVOCATION_METADATA = cygrpc.Metadata(
- (cygrpc.Metadatum(b'client-md-key', b'client-md-key'),
- cygrpc.Metadatum(b'client-md-key-bin', b'\x00\x01' * 3000),))
+INVOCATION_METADATA = (('client-md-key', 'client-md-key'),
+ ('client-md-key-bin', b'\x00\x01' * 3000),)
-INITIAL_METADATA = cygrpc.Metadata(
- (cygrpc.Metadatum(b'server-initial-md-key', b'server-initial-md-value'),
- cygrpc.Metadatum(b'server-initial-md-key-bin', b'\x00\x02' * 3000),))
+INITIAL_METADATA = (('server-initial-md-key', 'server-initial-md-value'),
+ ('server-initial-md-key-bin', b'\x00\x02' * 3000),)
-TRAILING_METADATA = cygrpc.Metadata(
- (cygrpc.Metadatum(b'server-trailing-md-key', b'server-trailing-md-value'),
- cygrpc.Metadatum(b'server-trailing-md-key-bin', b'\x00\x03' * 3000),))
+TRAILING_METADATA = (('server-trailing-md-key', 'server-trailing-md-value'),
+ ('server-trailing-md-key-bin', b'\x00\x03' * 3000),)
class QueueDriver(object):
diff --git a/src/python/grpcio_tests/tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py b/src/python/grpcio_tests/tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py
index 14cc66675c..d08003af44 100644
--- a/src/python/grpcio_tests/tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py
+++ b/src/python/grpcio_tests/tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py
@@ -48,20 +48,19 @@ class Test(_common.RpcTest, unittest.TestCase):
client_complete_rpc_tag = 'client_complete_rpc_tag'
with self.client_condition:
client_receive_initial_metadata_start_batch_result = (
- client_call.start_client_batch(
- cygrpc.Operations([
- cygrpc.operation_receive_initial_metadata(
- _common.EMPTY_FLAGS),
- ]), client_receive_initial_metadata_tag))
+ client_call.start_client_batch([
+ cygrpc.operation_receive_initial_metadata(
+ _common.EMPTY_FLAGS),
+ ], client_receive_initial_metadata_tag))
client_complete_rpc_start_batch_result = client_call.start_client_batch(
- cygrpc.Operations([
+ [
cygrpc.operation_send_initial_metadata(
_common.INVOCATION_METADATA, _common.EMPTY_FLAGS),
cygrpc.operation_send_close_from_client(
_common.EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(
_common.EMPTY_FLAGS),
- ]), client_complete_rpc_tag)
+ ], client_complete_rpc_tag)
self.client_driver.add_due({
client_receive_initial_metadata_tag,
client_complete_rpc_tag,
diff --git a/src/python/grpcio_tests/tests/unit/_cython/_no_messages_single_server_completion_queue_test.py b/src/python/grpcio_tests/tests/unit/_cython/_no_messages_single_server_completion_queue_test.py
index 1e44bcc4dc..d0166a2b29 100644
--- a/src/python/grpcio_tests/tests/unit/_cython/_no_messages_single_server_completion_queue_test.py
+++ b/src/python/grpcio_tests/tests/unit/_cython/_no_messages_single_server_completion_queue_test.py
@@ -43,20 +43,19 @@ class Test(_common.RpcTest, unittest.TestCase):
client_complete_rpc_tag = 'client_complete_rpc_tag'
with self.client_condition:
client_receive_initial_metadata_start_batch_result = (
- client_call.start_client_batch(
- cygrpc.Operations([
- cygrpc.operation_receive_initial_metadata(
- _common.EMPTY_FLAGS),
- ]), client_receive_initial_metadata_tag))
+ client_call.start_client_batch([
+ cygrpc.operation_receive_initial_metadata(
+ _common.EMPTY_FLAGS),
+ ], client_receive_initial_metadata_tag))
client_complete_rpc_start_batch_result = client_call.start_client_batch(
- cygrpc.Operations([
+ [
cygrpc.operation_send_initial_metadata(
_common.INVOCATION_METADATA, _common.EMPTY_FLAGS),
cygrpc.operation_send_close_from_client(
_common.EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(
_common.EMPTY_FLAGS),
- ]), client_complete_rpc_tag)
+ ], client_complete_rpc_tag)
self.client_driver.add_due({
client_receive_initial_metadata_tag,
client_complete_rpc_tag,
diff --git a/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py b/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py
index 0105612b47..1deb15ba03 100644
--- a/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py
+++ b/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py
@@ -20,7 +20,7 @@ from grpc._cython import cygrpc
_INFINITE_FUTURE = cygrpc.Timespec(float('+inf'))
_EMPTY_FLAGS = 0
-_EMPTY_METADATA = cygrpc.Metadata(())
+_EMPTY_METADATA = ()
class _ServerDriver(object):
@@ -157,19 +157,17 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase):
client_complete_rpc_tag = 'client_complete_rpc_tag'
with client_condition:
client_receive_initial_metadata_start_batch_result = (
- client_call.start_client_batch(
- cygrpc.Operations([
- cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
- ]), client_receive_initial_metadata_tag))
+ client_call.start_client_batch([
+ cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
+ ], client_receive_initial_metadata_tag))
client_due.add(client_receive_initial_metadata_tag)
client_complete_rpc_start_batch_result = (
- client_call.start_client_batch(
- cygrpc.Operations([
- cygrpc.operation_send_initial_metadata(_EMPTY_METADATA,
- _EMPTY_FLAGS),
- cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
- cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
- ]), client_complete_rpc_tag))
+ client_call.start_client_batch([
+ cygrpc.operation_send_initial_metadata(_EMPTY_METADATA,
+ _EMPTY_FLAGS),
+ cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
+ cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
+ ], client_complete_rpc_tag))
client_due.add(client_complete_rpc_tag)
server_rpc_event = server_driver.first_event()
@@ -197,8 +195,8 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase):
server_rpc_event.operation_call.start_server_batch([
cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),
cygrpc.operation_send_status_from_server(
- cygrpc.Metadata(()), cygrpc.StatusCode.ok,
- b'test details', _EMPTY_FLAGS),
+ (), cygrpc.StatusCode.ok, b'test details',
+ _EMPTY_FLAGS),
], server_complete_rpc_tag))
server_send_second_message_event = server_call_driver.event_with_tag(
server_send_second_message_tag)
@@ -209,10 +207,9 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase):
with client_condition:
client_receive_first_message_tag = 'client_receive_first_message_tag'
client_receive_first_message_start_batch_result = (
- client_call.start_client_batch(
- cygrpc.Operations([
- cygrpc.operation_receive_message(_EMPTY_FLAGS),
- ]), client_receive_first_message_tag))
+ client_call.start_client_batch([
+ cygrpc.operation_receive_message(_EMPTY_FLAGS),
+ ], client_receive_first_message_tag))
client_due.add(client_receive_first_message_tag)
client_receive_first_message_event = client_driver.event_with_tag(
client_receive_first_message_tag)
diff --git a/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py b/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py
index da94cf8028..4eda685486 100644
--- a/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py
+++ b/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py
@@ -29,50 +29,12 @@ _EMPTY_FLAGS = 0
def _metadata_plugin(context, callback):
- callback(
- cygrpc.Metadata([
- cygrpc.Metadatum(_CALL_CREDENTIALS_METADATA_KEY,
- _CALL_CREDENTIALS_METADATA_VALUE)
- ]), cygrpc.StatusCode.ok, b'')
+ callback(((_CALL_CREDENTIALS_METADATA_KEY,
+ _CALL_CREDENTIALS_METADATA_VALUE,),), cygrpc.StatusCode.ok, b'')
class TypeSmokeTest(unittest.TestCase):
- def testStringsInUtilitiesUpDown(self):
- self.assertEqual(0, cygrpc.StatusCode.ok)
- metadatum = cygrpc.Metadatum(b'a', b'b')
- self.assertEqual(b'a', metadatum.key)
- self.assertEqual(b'b', metadatum.value)
- metadata = cygrpc.Metadata([metadatum])
- self.assertEqual(1, len(metadata))
- self.assertEqual(metadatum.key, metadata[0].key)
-
- def testMetadataIteration(self):
- metadata = cygrpc.Metadata(
- [cygrpc.Metadatum(b'a', b'b'), cygrpc.Metadatum(b'c', b'd')])
- iterator = iter(metadata)
- metadatum = next(iterator)
- self.assertIsInstance(metadatum, cygrpc.Metadatum)
- self.assertEqual(metadatum.key, b'a')
- self.assertEqual(metadatum.value, b'b')
- metadatum = next(iterator)
- self.assertIsInstance(metadatum, cygrpc.Metadatum)
- self.assertEqual(metadatum.key, b'c')
- self.assertEqual(metadatum.value, b'd')
- with self.assertRaises(StopIteration):
- next(iterator)
-
- def testOperationsIteration(self):
- operations = cygrpc.Operations(
- [cygrpc.operation_send_message(b'asdf', _EMPTY_FLAGS)])
- iterator = iter(operations)
- operation = next(iterator)
- self.assertIsInstance(operation, cygrpc.Operation)
- # `Operation`s are write-only structures; can't directly debug anything out
- # of them. Just check that we stop iterating.
- with self.assertRaises(StopIteration):
- next(iterator)
-
def testOperationFlags(self):
operation = cygrpc.operation_send_message(b'asdf',
cygrpc.WriteFlag.no_compress)
@@ -182,8 +144,7 @@ class ServerClientMixin(object):
def performer():
tag = object()
try:
- call_result = call.start_client_batch(
- cygrpc.Operations(operations), tag)
+ call_result = call.start_client_batch(operations, tag)
self.assertEqual(cygrpc.CallError.ok, call_result)
event = queue.poll(deadline)
self.assertEqual(cygrpc.CompletionType.operation_complete,
@@ -200,14 +161,14 @@ class ServerClientMixin(object):
def test_echo(self):
DEADLINE = time.time() + 5
DEADLINE_TOLERANCE = 0.25
- CLIENT_METADATA_ASCII_KEY = b'key'
- CLIENT_METADATA_ASCII_VALUE = b'val'
- CLIENT_METADATA_BIN_KEY = b'key-bin'
+ CLIENT_METADATA_ASCII_KEY = 'key'
+ CLIENT_METADATA_ASCII_VALUE = 'val'
+ CLIENT_METADATA_BIN_KEY = 'key-bin'
CLIENT_METADATA_BIN_VALUE = b'\0' * 1000
- SERVER_INITIAL_METADATA_KEY = b'init_me_me_me'
- SERVER_INITIAL_METADATA_VALUE = b'whodawha?'
- SERVER_TRAILING_METADATA_KEY = b'california_is_in_a_drought'
- SERVER_TRAILING_METADATA_VALUE = b'zomg it is'
+ SERVER_INITIAL_METADATA_KEY = 'init_me_me_me'
+ SERVER_INITIAL_METADATA_VALUE = 'whodawha?'
+ SERVER_TRAILING_METADATA_KEY = 'california_is_in_a_drought'
+ SERVER_TRAILING_METADATA_VALUE = 'zomg it is'
SERVER_STATUS_CODE = cygrpc.StatusCode.ok
SERVER_STATUS_DETAILS = b'our work is never over'
REQUEST = b'in death a member of project mayhem has a name'
@@ -227,11 +188,9 @@ class ServerClientMixin(object):
client_call = self.client_channel.create_call(
None, 0, self.client_completion_queue, METHOD, self.host_argument,
cygrpc_deadline)
- client_initial_metadata = cygrpc.Metadata([
- cygrpc.Metadatum(CLIENT_METADATA_ASCII_KEY,
- CLIENT_METADATA_ASCII_VALUE),
- cygrpc.Metadatum(CLIENT_METADATA_BIN_KEY, CLIENT_METADATA_BIN_VALUE)
- ])
+ client_initial_metadata = (
+ (CLIENT_METADATA_ASCII_KEY, CLIENT_METADATA_ASCII_VALUE,),
+ (CLIENT_METADATA_BIN_KEY, CLIENT_METADATA_BIN_VALUE,),)
client_start_batch_result = client_call.start_client_batch([
cygrpc.operation_send_initial_metadata(client_initial_metadata,
_EMPTY_FLAGS),
@@ -263,14 +222,10 @@ class ServerClientMixin(object):
server_call_tag = object()
server_call = request_event.operation_call
- server_initial_metadata = cygrpc.Metadata([
- cygrpc.Metadatum(SERVER_INITIAL_METADATA_KEY,
- SERVER_INITIAL_METADATA_VALUE)
- ])
- server_trailing_metadata = cygrpc.Metadata([
- cygrpc.Metadatum(SERVER_TRAILING_METADATA_KEY,
- SERVER_TRAILING_METADATA_VALUE)
- ])
+ server_initial_metadata = (
+ (SERVER_INITIAL_METADATA_KEY, SERVER_INITIAL_METADATA_VALUE,),)
+ server_trailing_metadata = (
+ (SERVER_TRAILING_METADATA_KEY, SERVER_TRAILING_METADATA_VALUE,),)
server_start_batch_result = server_call.start_server_batch([
cygrpc.operation_send_initial_metadata(
server_initial_metadata,
@@ -347,7 +302,7 @@ class ServerClientMixin(object):
METHOD = b'twinkies'
cygrpc_deadline = cygrpc.Timespec(DEADLINE)
- empty_metadata = cygrpc.Metadata([])
+ empty_metadata = ()
server_request_tag = object()
self.server.request_call(self.server_completion_queue,
diff --git a/src/python/grpcio_tests/tests/unit/_interceptor_test.py b/src/python/grpcio_tests/tests/unit/_interceptor_test.py
new file mode 100644
index 0000000000..cf875ed7da
--- /dev/null
+++ b/src/python/grpcio_tests/tests/unit/_interceptor_test.py
@@ -0,0 +1,571 @@
+# Copyright 2017 gRPC authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Test of gRPC Python interceptors."""
+
+import collections
+import itertools
+import threading
+import unittest
+from concurrent import futures
+
+import grpc
+from grpc.framework.foundation import logging_pool
+
+from tests.unit.framework.common import test_constants
+from tests.unit.framework.common import test_control
+
+_SERIALIZE_REQUEST = lambda bytestring: bytestring * 2
+_DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) // 2:]
+_SERIALIZE_RESPONSE = lambda bytestring: bytestring * 3
+_DESERIALIZE_RESPONSE = lambda bytestring: bytestring[:len(bytestring) // 3]
+
+_UNARY_UNARY = '/test/UnaryUnary'
+_UNARY_STREAM = '/test/UnaryStream'
+_STREAM_UNARY = '/test/StreamUnary'
+_STREAM_STREAM = '/test/StreamStream'
+
+
+class _Callback(object):
+
+ def __init__(self):
+ self._condition = threading.Condition()
+ self._value = None
+ self._called = False
+
+ def __call__(self, value):
+ with self._condition:
+ self._value = value
+ self._called = True
+ self._condition.notify_all()
+
+ def value(self):
+ with self._condition:
+ while not self._called:
+ self._condition.wait()
+ return self._value
+
+
+class _Handler(object):
+
+ def __init__(self, control):
+ self._control = control
+
+ def handle_unary_unary(self, request, servicer_context):
+ self._control.control()
+ if servicer_context is not None:
+ servicer_context.set_trailing_metadata((('testkey', 'testvalue',),))
+ return request
+
+ def handle_unary_stream(self, request, servicer_context):
+ for _ in range(test_constants.STREAM_LENGTH):
+ self._control.control()
+ yield request
+ self._control.control()
+ if servicer_context is not None:
+ servicer_context.set_trailing_metadata((('testkey', 'testvalue',),))
+
+ def handle_stream_unary(self, request_iterator, servicer_context):
+ if servicer_context is not None:
+ servicer_context.invocation_metadata()
+ self._control.control()
+ response_elements = []
+ for request in request_iterator:
+ self._control.control()
+ response_elements.append(request)
+ self._control.control()
+ if servicer_context is not None:
+ servicer_context.set_trailing_metadata((('testkey', 'testvalue',),))
+ return b''.join(response_elements)
+
+ def handle_stream_stream(self, request_iterator, servicer_context):
+ self._control.control()
+ if servicer_context is not None:
+ servicer_context.set_trailing_metadata((('testkey', 'testvalue',),))
+ for request in request_iterator:
+ self._control.control()
+ yield request
+ self._control.control()
+
+
+class _MethodHandler(grpc.RpcMethodHandler):
+
+ def __init__(self, request_streaming, response_streaming,
+ request_deserializer, response_serializer, unary_unary,
+ unary_stream, stream_unary, stream_stream):
+ self.request_streaming = request_streaming
+ self.response_streaming = response_streaming
+ self.request_deserializer = request_deserializer
+ self.response_serializer = response_serializer
+ self.unary_unary = unary_unary
+ self.unary_stream = unary_stream
+ self.stream_unary = stream_unary
+ self.stream_stream = stream_stream
+
+
+class _GenericHandler(grpc.GenericRpcHandler):
+
+ def __init__(self, handler):
+ self._handler = handler
+
+ def service(self, handler_call_details):
+ if handler_call_details.method == _UNARY_UNARY:
+ return _MethodHandler(False, False, None, None,
+ self._handler.handle_unary_unary, None, None,
+ None)
+ elif handler_call_details.method == _UNARY_STREAM:
+ return _MethodHandler(False, True, _DESERIALIZE_REQUEST,
+ _SERIALIZE_RESPONSE, None,
+ self._handler.handle_unary_stream, None, None)
+ elif handler_call_details.method == _STREAM_UNARY:
+ return _MethodHandler(True, False, _DESERIALIZE_REQUEST,
+ _SERIALIZE_RESPONSE, None, None,
+ self._handler.handle_stream_unary, None)
+ elif handler_call_details.method == _STREAM_STREAM:
+ return _MethodHandler(True, True, None, None, None, None, None,
+ self._handler.handle_stream_stream)
+ else:
+ return None
+
+
+def _unary_unary_multi_callable(channel):
+ return channel.unary_unary(_UNARY_UNARY)
+
+
+def _unary_stream_multi_callable(channel):
+ return channel.unary_stream(
+ _UNARY_STREAM,
+ request_serializer=_SERIALIZE_REQUEST,
+ response_deserializer=_DESERIALIZE_RESPONSE)
+
+
+def _stream_unary_multi_callable(channel):
+ return channel.stream_unary(
+ _STREAM_UNARY,
+ request_serializer=_SERIALIZE_REQUEST,
+ response_deserializer=_DESERIALIZE_RESPONSE)
+
+
+def _stream_stream_multi_callable(channel):
+ return channel.stream_stream(_STREAM_STREAM)
+
+
+class _ClientCallDetails(
+ collections.namedtuple('_ClientCallDetails',
+ ('method', 'timeout', 'metadata',
+ 'credentials')), grpc.ClientCallDetails):
+ pass
+
+
+class _GenericClientInterceptor(
+ grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor,
+ grpc.StreamUnaryClientInterceptor, grpc.StreamStreamClientInterceptor):
+
+ def __init__(self, interceptor_function):
+ self._fn = interceptor_function
+
+ def intercept_unary_unary(self, continuation, client_call_details, request):
+ new_details, new_request_iterator, postprocess = self._fn(
+ client_call_details, iter((request,)), False, False)
+ response = continuation(new_details, next(new_request_iterator))
+ return postprocess(response) if postprocess else response
+
+ def intercept_unary_stream(self, continuation, client_call_details,
+ request):
+ new_details, new_request_iterator, postprocess = self._fn(
+ client_call_details, iter((request,)), False, True)
+ response_it = continuation(new_details, new_request_iterator)
+ return postprocess(response_it) if postprocess else response_it
+
+ def intercept_stream_unary(self, continuation, client_call_details,
+ request_iterator):
+ new_details, new_request_iterator, postprocess = self._fn(
+ client_call_details, request_iterator, True, False)
+ response = continuation(new_details, next(new_request_iterator))
+ return postprocess(response) if postprocess else response
+
+ def intercept_stream_stream(self, continuation, client_call_details,
+ request_iterator):
+ new_details, new_request_iterator, postprocess = self._fn(
+ client_call_details, request_iterator, True, True)
+ response_it = continuation(new_details, new_request_iterator)
+ return postprocess(response_it) if postprocess else response_it
+
+
+class _LoggingInterceptor(
+ grpc.ServerInterceptor, grpc.UnaryUnaryClientInterceptor,
+ grpc.UnaryStreamClientInterceptor, grpc.StreamUnaryClientInterceptor,
+ grpc.StreamStreamClientInterceptor):
+
+ def __init__(self, tag, record):
+ self.tag = tag
+ self.record = record
+
+ def intercept_service(self, continuation, handler_call_details):
+ self.record.append(self.tag + ':intercept_service')
+ return continuation(handler_call_details)
+
+ def intercept_unary_unary(self, continuation, client_call_details, request):
+ self.record.append(self.tag + ':intercept_unary_unary')
+ return continuation(client_call_details, request)
+
+ def intercept_unary_stream(self, continuation, client_call_details,
+ request):
+ self.record.append(self.tag + ':intercept_unary_stream')
+ return continuation(client_call_details, request)
+
+ def intercept_stream_unary(self, continuation, client_call_details,
+ request_iterator):
+ self.record.append(self.tag + ':intercept_stream_unary')
+ return continuation(client_call_details, request_iterator)
+
+ def intercept_stream_stream(self, continuation, client_call_details,
+ request_iterator):
+ self.record.append(self.tag + ':intercept_stream_stream')
+ return continuation(client_call_details, request_iterator)
+
+
+class _DefectiveClientInterceptor(grpc.UnaryUnaryClientInterceptor):
+
+ def intercept_unary_unary(self, ignored_continuation,
+ ignored_client_call_details, ignored_request):
+ raise test_control.Defect()
+
+
+def _wrap_request_iterator_stream_interceptor(wrapper):
+
+ def intercept_call(client_call_details, request_iterator, request_streaming,
+ ignored_response_streaming):
+ if request_streaming:
+ return client_call_details, wrapper(request_iterator), None
+ else:
+ return client_call_details, request_iterator, None
+
+ return _GenericClientInterceptor(intercept_call)
+
+
+def _append_request_header_interceptor(header, value):
+
+ def intercept_call(client_call_details, request_iterator,
+ ignored_request_streaming, ignored_response_streaming):
+ metadata = []
+ if client_call_details.metadata:
+ metadata = list(client_call_details.metadata)
+ metadata.append((header, value,))
+ client_call_details = _ClientCallDetails(
+ client_call_details.method, client_call_details.timeout, metadata,
+ client_call_details.credentials)
+ return client_call_details, request_iterator, None
+
+ return _GenericClientInterceptor(intercept_call)
+
+
+class _GenericServerInterceptor(grpc.ServerInterceptor):
+
+ def __init__(self, fn):
+ self._fn = fn
+
+ def intercept_service(self, continuation, handler_call_details):
+ return self._fn(continuation, handler_call_details)
+
+
+def _filter_server_interceptor(condition, interceptor):
+
+ def intercept_service(continuation, handler_call_details):
+ if condition(handler_call_details):
+ return interceptor.intercept_service(continuation,
+ handler_call_details)
+ return continuation(handler_call_details)
+
+ return _GenericServerInterceptor(intercept_service)
+
+
+class InterceptorTest(unittest.TestCase):
+
+ def setUp(self):
+ self._control = test_control.PauseFailControl()
+ self._handler = _Handler(self._control)
+ self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
+
+ self._record = []
+ conditional_interceptor = _filter_server_interceptor(
+ lambda x: ('secret', '42') in x.invocation_metadata,
+ _LoggingInterceptor('s3', self._record))
+
+ self._server = grpc.server(
+ self._server_pool,
+ interceptors=(_LoggingInterceptor('s1', self._record),
+ conditional_interceptor,
+ _LoggingInterceptor('s2', self._record),))
+ port = self._server.add_insecure_port('[::]:0')
+ self._server.add_generic_rpc_handlers((_GenericHandler(self._handler),))
+ self._server.start()
+
+ self._channel = grpc.insecure_channel('localhost:%d' % port)
+
+ def tearDown(self):
+ self._server.stop(None)
+ self._server_pool.shutdown(wait=True)
+
+ def testTripleRequestMessagesClientInterceptor(self):
+
+ def triple(request_iterator):
+ while True:
+ try:
+ item = next(request_iterator)
+ yield item
+ yield item
+ yield item
+ except StopIteration:
+ break
+
+ interceptor = _wrap_request_iterator_stream_interceptor(triple)
+ channel = grpc.intercept_channel(self._channel, interceptor)
+ requests = tuple(b'\x07\x08'
+ for _ in range(test_constants.STREAM_LENGTH))
+
+ multi_callable = _stream_stream_multi_callable(channel)
+ response_iterator = multi_callable(
+ iter(requests),
+ metadata=(
+ ('test',
+ 'InterceptedStreamRequestBlockingUnaryResponseWithCall'),))
+
+ responses = tuple(response_iterator)
+ self.assertEqual(len(responses), 3 * test_constants.STREAM_LENGTH)
+
+ multi_callable = _stream_stream_multi_callable(self._channel)
+ response_iterator = multi_callable(
+ iter(requests),
+ metadata=(
+ ('test',
+ 'InterceptedStreamRequestBlockingUnaryResponseWithCall'),))
+
+ responses = tuple(response_iterator)
+ self.assertEqual(len(responses), test_constants.STREAM_LENGTH)
+
+ def testDefectiveClientInterceptor(self):
+ interceptor = _DefectiveClientInterceptor()
+ defective_channel = grpc.intercept_channel(self._channel, interceptor)
+
+ request = b'\x07\x08'
+
+ multi_callable = _unary_unary_multi_callable(defective_channel)
+ call_future = multi_callable.future(
+ request,
+ metadata=(
+ ('test', 'InterceptedUnaryRequestBlockingUnaryResponse'),))
+
+ self.assertIsNotNone(call_future.exception())
+ self.assertEqual(call_future.code(), grpc.StatusCode.INTERNAL)
+
+ def testInterceptedHeaderManipulationWithServerSideVerification(self):
+ request = b'\x07\x08'
+
+ channel = grpc.intercept_channel(
+ self._channel, _append_request_header_interceptor('secret', '42'))
+ channel = grpc.intercept_channel(
+ channel,
+ _LoggingInterceptor('c1', self._record),
+ _LoggingInterceptor('c2', self._record))
+
+ self._record[:] = []
+
+ multi_callable = _unary_unary_multi_callable(channel)
+ multi_callable.with_call(
+ request,
+ metadata=(
+ ('test',
+ 'InterceptedUnaryRequestBlockingUnaryResponseWithCall'),))
+
+ self.assertSequenceEqual(self._record, [
+ 'c1:intercept_unary_unary', 'c2:intercept_unary_unary',
+ 's1:intercept_service', 's3:intercept_service',
+ 's2:intercept_service'
+ ])
+
+ def testInterceptedUnaryRequestBlockingUnaryResponse(self):
+ request = b'\x07\x08'
+
+ self._record[:] = []
+
+ channel = grpc.intercept_channel(
+ self._channel,
+ _LoggingInterceptor('c1', self._record),
+ _LoggingInterceptor('c2', self._record))
+
+ multi_callable = _unary_unary_multi_callable(channel)
+ multi_callable(
+ request,
+ metadata=(
+ ('test', 'InterceptedUnaryRequestBlockingUnaryResponse'),))
+
+ self.assertSequenceEqual(self._record, [
+ 'c1:intercept_unary_unary', 'c2:intercept_unary_unary',
+ 's1:intercept_service', 's2:intercept_service'
+ ])
+
+ def testInterceptedUnaryRequestBlockingUnaryResponseWithCall(self):
+ request = b'\x07\x08'
+
+ channel = grpc.intercept_channel(
+ self._channel,
+ _LoggingInterceptor('c1', self._record),
+ _LoggingInterceptor('c2', self._record))
+
+ self._record[:] = []
+
+ multi_callable = _unary_unary_multi_callable(channel)
+ multi_callable.with_call(
+ request,
+ metadata=(
+ ('test',
+ 'InterceptedUnaryRequestBlockingUnaryResponseWithCall'),))
+
+ self.assertSequenceEqual(self._record, [
+ 'c1:intercept_unary_unary', 'c2:intercept_unary_unary',
+ 's1:intercept_service', 's2:intercept_service'
+ ])
+
+ def testInterceptedUnaryRequestFutureUnaryResponse(self):
+ request = b'\x07\x08'
+
+ self._record[:] = []
+ channel = grpc.intercept_channel(
+ self._channel,
+ _LoggingInterceptor('c1', self._record),
+ _LoggingInterceptor('c2', self._record))
+
+ multi_callable = _unary_unary_multi_callable(channel)
+ response_future = multi_callable.future(
+ request,
+ metadata=(('test', 'InterceptedUnaryRequestFutureUnaryResponse'),))
+ response_future.result()
+
+ self.assertSequenceEqual(self._record, [
+ 'c1:intercept_unary_unary', 'c2:intercept_unary_unary',
+ 's1:intercept_service', 's2:intercept_service'
+ ])
+
+ def testInterceptedUnaryRequestStreamResponse(self):
+ request = b'\x37\x58'
+
+ self._record[:] = []
+ channel = grpc.intercept_channel(
+ self._channel,
+ _LoggingInterceptor('c1', self._record),
+ _LoggingInterceptor('c2', self._record))
+
+ multi_callable = _unary_stream_multi_callable(channel)
+ response_iterator = multi_callable(
+ request,
+ metadata=(('test', 'InterceptedUnaryRequestStreamResponse'),))
+ tuple(response_iterator)
+
+ self.assertSequenceEqual(self._record, [
+ 'c1:intercept_unary_stream', 'c2:intercept_unary_stream',
+ 's1:intercept_service', 's2:intercept_service'
+ ])
+
+ def testInterceptedStreamRequestBlockingUnaryResponse(self):
+ requests = tuple(b'\x07\x08'
+ for _ in range(test_constants.STREAM_LENGTH))
+ request_iterator = iter(requests)
+
+ self._record[:] = []
+ channel = grpc.intercept_channel(
+ self._channel,
+ _LoggingInterceptor('c1', self._record),
+ _LoggingInterceptor('c2', self._record))
+
+ multi_callable = _stream_unary_multi_callable(channel)
+ multi_callable(
+ request_iterator,
+ metadata=(
+ ('test', 'InterceptedStreamRequestBlockingUnaryResponse'),))
+
+ self.assertSequenceEqual(self._record, [
+ 'c1:intercept_stream_unary', 'c2:intercept_stream_unary',
+ 's1:intercept_service', 's2:intercept_service'
+ ])
+
+ def testInterceptedStreamRequestBlockingUnaryResponseWithCall(self):
+ requests = tuple(b'\x07\x08'
+ for _ in range(test_constants.STREAM_LENGTH))
+ request_iterator = iter(requests)
+
+ self._record[:] = []
+ channel = grpc.intercept_channel(
+ self._channel,
+ _LoggingInterceptor('c1', self._record),
+ _LoggingInterceptor('c2', self._record))
+
+ multi_callable = _stream_unary_multi_callable(channel)
+ multi_callable.with_call(
+ request_iterator,
+ metadata=(
+ ('test',
+ 'InterceptedStreamRequestBlockingUnaryResponseWithCall'),))
+
+ self.assertSequenceEqual(self._record, [
+ 'c1:intercept_stream_unary', 'c2:intercept_stream_unary',
+ 's1:intercept_service', 's2:intercept_service'
+ ])
+
+ def testInterceptedStreamRequestFutureUnaryResponse(self):
+ requests = tuple(b'\x07\x08'
+ for _ in range(test_constants.STREAM_LENGTH))
+ request_iterator = iter(requests)
+
+ self._record[:] = []
+ channel = grpc.intercept_channel(
+ self._channel,
+ _LoggingInterceptor('c1', self._record),
+ _LoggingInterceptor('c2', self._record))
+
+ multi_callable = _stream_unary_multi_callable(channel)
+ response_future = multi_callable.future(
+ request_iterator,
+ metadata=(('test', 'InterceptedStreamRequestFutureUnaryResponse'),))
+ response_future.result()
+
+ self.assertSequenceEqual(self._record, [
+ 'c1:intercept_stream_unary', 'c2:intercept_stream_unary',
+ 's1:intercept_service', 's2:intercept_service'
+ ])
+
+ def testInterceptedStreamRequestStreamResponse(self):
+ requests = tuple(b'\x77\x58'
+ for _ in range(test_constants.STREAM_LENGTH))
+ request_iterator = iter(requests)
+
+ self._record[:] = []
+ channel = grpc.intercept_channel(
+ self._channel,
+ _LoggingInterceptor('c1', self._record),
+ _LoggingInterceptor('c2', self._record))
+
+ multi_callable = _stream_stream_multi_callable(channel)
+ response_iterator = multi_callable(
+ request_iterator,
+ metadata=(('test', 'InterceptedStreamRequestStreamResponse'),))
+ tuple(response_iterator)
+
+ self.assertSequenceEqual(self._record, [
+ 'c1:intercept_stream_stream', 'c2:intercept_stream_stream',
+ 's1:intercept_service', 's2:intercept_service'
+ ])
+
+
+if __name__ == '__main__':
+ unittest.main(verbosity=2)
diff --git a/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py b/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py
index 6faab94be6..cb59cd3769 100644
--- a/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py
+++ b/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py
@@ -56,6 +56,7 @@ class _Servicer(object):
def __init__(self):
self._lock = threading.Lock()
+ self._abort_call = False
self._code = None
self._details = None
self._exception = False
@@ -67,10 +68,13 @@ class _Servicer(object):
self._received_client_metadata = context.invocation_metadata()
context.send_initial_metadata(_SERVER_INITIAL_METADATA)
context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
- if self._code is not None:
- context.set_code(self._code)
- if self._details is not None:
- context.set_details(self._details)
+ if self._abort_call:
+ context.abort(self._code, self._details)
+ else:
+ if self._code is not None:
+ context.set_code(self._code)
+ if self._details is not None:
+ context.set_details(self._details)
if self._exception:
raise test_control.Defect()
else:
@@ -81,10 +85,13 @@ class _Servicer(object):
self._received_client_metadata = context.invocation_metadata()
context.send_initial_metadata(_SERVER_INITIAL_METADATA)
context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
- if self._code is not None:
- context.set_code(self._code)
- if self._details is not None:
- context.set_details(self._details)
+ if self._abort_call:
+ context.abort(self._code, self._details)
+ else:
+ if self._code is not None:
+ context.set_code(self._code)
+ if self._details is not None:
+ context.set_details(self._details)
for _ in range(test_constants.STREAM_LENGTH // 2):
yield _SERIALIZED_RESPONSE
if self._exception:
@@ -95,14 +102,16 @@ class _Servicer(object):
self._received_client_metadata = context.invocation_metadata()
context.send_initial_metadata(_SERVER_INITIAL_METADATA)
context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
- if self._code is not None:
- context.set_code(self._code)
- if self._details is not None:
- context.set_details(self._details)
# TODO(https://github.com/grpc/grpc/issues/6891): just ignore the
# request iterator.
- for ignored_request in request_iterator:
- pass
+ list(request_iterator)
+ if self._abort_call:
+ context.abort(self._code, self._details)
+ else:
+ if self._code is not None:
+ context.set_code(self._code)
+ if self._details is not None:
+ context.set_details(self._details)
if self._exception:
raise test_control.Defect()
else:
@@ -113,19 +122,25 @@ class _Servicer(object):
self._received_client_metadata = context.invocation_metadata()
context.send_initial_metadata(_SERVER_INITIAL_METADATA)
context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
- if self._code is not None:
- context.set_code(self._code)
- if self._details is not None:
- context.set_details(self._details)
# TODO(https://github.com/grpc/grpc/issues/6891): just ignore the
# request iterator.
- for ignored_request in request_iterator:
- pass
+ list(request_iterator)
+ if self._abort_call:
+ context.abort(self._code, self._details)
+ else:
+ if self._code is not None:
+ context.set_code(self._code)
+ if self._details is not None:
+ context.set_details(self._details)
for _ in range(test_constants.STREAM_LENGTH // 3):
yield object()
if self._exception:
raise test_control.Defect()
+ def set_abort_call(self):
+ with self._lock:
+ self._abort_call = True
+
def set_code(self, code):
with self._lock:
self._code = code
@@ -212,11 +227,10 @@ class MetadataCodeDetailsTest(unittest.TestCase):
def testSuccessfulUnaryStream(self):
self._servicer.set_details(_DETAILS)
- call = self._unary_stream(
+ response_iterator_call = self._unary_stream(
_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
- received_initial_metadata = call.initial_metadata()
- for _ in call:
- pass
+ received_initial_metadata = response_iterator_call.initial_metadata()
+ list(response_iterator_call)
self.assertTrue(
test_common.metadata_transmitted(
@@ -225,10 +239,11 @@ class MetadataCodeDetailsTest(unittest.TestCase):
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
received_initial_metadata))
self.assertTrue(
- test_common.metadata_transmitted(_SERVER_TRAILING_METADATA,
- call.trailing_metadata()))
- self.assertIs(grpc.StatusCode.OK, call.code())
- self.assertEqual(_DETAILS, call.details())
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA,
+ response_iterator_call.trailing_metadata()))
+ self.assertIs(grpc.StatusCode.OK, response_iterator_call.code())
+ self.assertEqual(_DETAILS, response_iterator_call.details())
def testSuccessfulStreamUnary(self):
self._servicer.set_details(_DETAILS)
@@ -252,12 +267,11 @@ class MetadataCodeDetailsTest(unittest.TestCase):
def testSuccessfulStreamStream(self):
self._servicer.set_details(_DETAILS)
- call = self._stream_stream(
+ response_iterator_call = self._stream_stream(
iter([object()] * test_constants.STREAM_LENGTH),
metadata=_CLIENT_METADATA)
- received_initial_metadata = call.initial_metadata()
- for _ in call:
- pass
+ received_initial_metadata = response_iterator_call.initial_metadata()
+ list(response_iterator_call)
self.assertTrue(
test_common.metadata_transmitted(
@@ -266,10 +280,106 @@ class MetadataCodeDetailsTest(unittest.TestCase):
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
received_initial_metadata))
self.assertTrue(
- test_common.metadata_transmitted(_SERVER_TRAILING_METADATA,
- call.trailing_metadata()))
- self.assertIs(grpc.StatusCode.OK, call.code())
- self.assertEqual(_DETAILS, call.details())
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA,
+ response_iterator_call.trailing_metadata()))
+ self.assertIs(grpc.StatusCode.OK, response_iterator_call.code())
+ self.assertEqual(_DETAILS, response_iterator_call.details())
+
+ def testAbortedUnaryUnary(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+ self._servicer.set_abort_call()
+
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA)
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA,
+ exception_context.exception.initial_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA,
+ exception_context.exception.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, exception_context.exception.code())
+ self.assertEqual(_DETAILS, exception_context.exception.details())
+
+ def testAbortedUnaryStream(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+ self._servicer.set_abort_call()
+
+ response_iterator_call = self._unary_stream(
+ _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
+ received_initial_metadata = response_iterator_call.initial_metadata()
+ with self.assertRaises(grpc.RpcError):
+ self.assertEqual(len(list(response_iterator_call)), 0)
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
+ received_initial_metadata))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA,
+ response_iterator_call.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, response_iterator_call.code())
+ self.assertEqual(_DETAILS, response_iterator_call.details())
+
+ def testAbortedStreamUnary(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+ self._servicer.set_abort_call()
+
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ self._stream_unary.with_call(
+ iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),
+ metadata=_CLIENT_METADATA)
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA,
+ exception_context.exception.initial_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA,
+ exception_context.exception.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, exception_context.exception.code())
+ self.assertEqual(_DETAILS, exception_context.exception.details())
+
+ def testAbortedStreamStream(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+ self._servicer.set_abort_call()
+
+ response_iterator_call = self._stream_stream(
+ iter([object()] * test_constants.STREAM_LENGTH),
+ metadata=_CLIENT_METADATA)
+ received_initial_metadata = response_iterator_call.initial_metadata()
+ with self.assertRaises(grpc.RpcError):
+ self.assertEqual(len(list(response_iterator_call)), 0)
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
+ received_initial_metadata))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA,
+ response_iterator_call.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, response_iterator_call.code())
+ self.assertEqual(_DETAILS, response_iterator_call.details())
def testCustomCodeUnaryUnary(self):
self._servicer.set_code(_NON_OK_CODE)
@@ -296,12 +406,11 @@ class MetadataCodeDetailsTest(unittest.TestCase):
self._servicer.set_code(_NON_OK_CODE)
self._servicer.set_details(_DETAILS)
- call = self._unary_stream(
+ response_iterator_call = self._unary_stream(
_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
- received_initial_metadata = call.initial_metadata()
+ received_initial_metadata = response_iterator_call.initial_metadata()
with self.assertRaises(grpc.RpcError):
- for _ in call:
- pass
+ list(response_iterator_call)
self.assertTrue(
test_common.metadata_transmitted(
@@ -310,10 +419,11 @@ class MetadataCodeDetailsTest(unittest.TestCase):
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
received_initial_metadata))
self.assertTrue(
- test_common.metadata_transmitted(_SERVER_TRAILING_METADATA,
- call.trailing_metadata()))
- self.assertIs(_NON_OK_CODE, call.code())
- self.assertEqual(_DETAILS, call.details())
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA,
+ response_iterator_call.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, response_iterator_call.code())
+ self.assertEqual(_DETAILS, response_iterator_call.details())
def testCustomCodeStreamUnary(self):
self._servicer.set_code(_NON_OK_CODE)
@@ -342,13 +452,12 @@ class MetadataCodeDetailsTest(unittest.TestCase):
self._servicer.set_code(_NON_OK_CODE)
self._servicer.set_details(_DETAILS)
- call = self._stream_stream(
+ response_iterator_call = self._stream_stream(
iter([object()] * test_constants.STREAM_LENGTH),
metadata=_CLIENT_METADATA)
- received_initial_metadata = call.initial_metadata()
+ received_initial_metadata = response_iterator_call.initial_metadata()
with self.assertRaises(grpc.RpcError) as exception_context:
- for _ in call:
- pass
+ list(response_iterator_call)
self.assertTrue(
test_common.metadata_transmitted(
@@ -390,12 +499,11 @@ class MetadataCodeDetailsTest(unittest.TestCase):
self._servicer.set_details(_DETAILS)
self._servicer.set_exception()
- call = self._unary_stream(
+ response_iterator_call = self._unary_stream(
_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
- received_initial_metadata = call.initial_metadata()
+ received_initial_metadata = response_iterator_call.initial_metadata()
with self.assertRaises(grpc.RpcError):
- for _ in call:
- pass
+ list(response_iterator_call)
self.assertTrue(
test_common.metadata_transmitted(
@@ -404,10 +512,11 @@ class MetadataCodeDetailsTest(unittest.TestCase):
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
received_initial_metadata))
self.assertTrue(
- test_common.metadata_transmitted(_SERVER_TRAILING_METADATA,
- call.trailing_metadata()))
- self.assertIs(_NON_OK_CODE, call.code())
- self.assertEqual(_DETAILS, call.details())
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA,
+ response_iterator_call.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, response_iterator_call.code())
+ self.assertEqual(_DETAILS, response_iterator_call.details())
def testCustomCodeExceptionStreamUnary(self):
self._servicer.set_code(_NON_OK_CODE)
@@ -438,13 +547,12 @@ class MetadataCodeDetailsTest(unittest.TestCase):
self._servicer.set_details(_DETAILS)
self._servicer.set_exception()
- call = self._stream_stream(
+ response_iterator_call = self._stream_stream(
iter([object()] * test_constants.STREAM_LENGTH),
metadata=_CLIENT_METADATA)
- received_initial_metadata = call.initial_metadata()
+ received_initial_metadata = response_iterator_call.initial_metadata()
with self.assertRaises(grpc.RpcError):
- for _ in call:
- pass
+ list(response_iterator_call)
self.assertTrue(
test_common.metadata_transmitted(
@@ -453,10 +561,11 @@ class MetadataCodeDetailsTest(unittest.TestCase):
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
received_initial_metadata))
self.assertTrue(
- test_common.metadata_transmitted(_SERVER_TRAILING_METADATA,
- call.trailing_metadata()))
- self.assertIs(_NON_OK_CODE, call.code())
- self.assertEqual(_DETAILS, call.details())
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA,
+ response_iterator_call.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, response_iterator_call.code())
+ self.assertEqual(_DETAILS, response_iterator_call.details())
def testCustomCodeReturnNoneUnaryUnary(self):
self._servicer.set_code(_NON_OK_CODE)