aboutsummaryrefslogtreecommitdiffhomepage
path: root/src/python/grpcio_tests/tests
diff options
context:
space:
mode:
authorGravatar Mehrdad Afshari <mehrdada@users.noreply.github.com>2017-12-10 18:27:21 -0800
committerGravatar Mehrdad Afshari <mehrdada@users.noreply.github.com>2017-12-10 20:45:18 -0800
commit90ab995cb0f3ef29c1f284a4c361a9c2750ef2dd (patch)
treed317a8a1ed8249c70e08f597af1ee93cb81198ad /src/python/grpcio_tests/tests
parent9bc44e38296fe5c8001929b61851d0a1ba326a15 (diff)
Tests for ServicerContext.abort
Diffstat (limited to 'src/python/grpcio_tests/tests')
-rw-r--r--src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py166
1 files changed, 134 insertions, 32 deletions
diff --git a/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py b/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py
index 6faab94be6..68df6643e7 100644
--- a/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py
+++ b/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py
@@ -56,6 +56,7 @@ class _Servicer(object):
def __init__(self):
self._lock = threading.Lock()
+ self._abort_call = False
self._code = None
self._details = None
self._exception = False
@@ -67,10 +68,13 @@ class _Servicer(object):
self._received_client_metadata = context.invocation_metadata()
context.send_initial_metadata(_SERVER_INITIAL_METADATA)
context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
- if self._code is not None:
- context.set_code(self._code)
- if self._details is not None:
- context.set_details(self._details)
+ if self._abort_call:
+ context.abort(self._code, self._details)
+ else:
+ if self._code is not None:
+ context.set_code(self._code)
+ if self._details is not None:
+ context.set_details(self._details)
if self._exception:
raise test_control.Defect()
else:
@@ -81,10 +85,13 @@ class _Servicer(object):
self._received_client_metadata = context.invocation_metadata()
context.send_initial_metadata(_SERVER_INITIAL_METADATA)
context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
- if self._code is not None:
- context.set_code(self._code)
- if self._details is not None:
- context.set_details(self._details)
+ if self._abort_call:
+ context.abort(self._code, self._details)
+ else:
+ if self._code is not None:
+ context.set_code(self._code)
+ if self._details is not None:
+ context.set_details(self._details)
for _ in range(test_constants.STREAM_LENGTH // 2):
yield _SERIALIZED_RESPONSE
if self._exception:
@@ -95,14 +102,16 @@ class _Servicer(object):
self._received_client_metadata = context.invocation_metadata()
context.send_initial_metadata(_SERVER_INITIAL_METADATA)
context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
- if self._code is not None:
- context.set_code(self._code)
- if self._details is not None:
- context.set_details(self._details)
# TODO(https://github.com/grpc/grpc/issues/6891): just ignore the
# request iterator.
- for ignored_request in request_iterator:
- pass
+ list(request_iterator)
+ if self._abort_call:
+ context.abort(self._code, self._details)
+ else:
+ if self._code is not None:
+ context.set_code(self._code)
+ if self._details is not None:
+ context.set_details(self._details)
if self._exception:
raise test_control.Defect()
else:
@@ -113,19 +122,25 @@ class _Servicer(object):
self._received_client_metadata = context.invocation_metadata()
context.send_initial_metadata(_SERVER_INITIAL_METADATA)
context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
- if self._code is not None:
- context.set_code(self._code)
- if self._details is not None:
- context.set_details(self._details)
# TODO(https://github.com/grpc/grpc/issues/6891): just ignore the
# request iterator.
- for ignored_request in request_iterator:
- pass
+ list(request_iterator)
+ if self._abort_call:
+ context.abort(self._code, self._details)
+ else:
+ if self._code is not None:
+ context.set_code(self._code)
+ if self._details is not None:
+ context.set_details(self._details)
for _ in range(test_constants.STREAM_LENGTH // 3):
yield object()
if self._exception:
raise test_control.Defect()
+ def set_abort_call(self):
+ with self._lock:
+ self._abort_call = True
+
def set_code(self, code):
with self._lock:
self._code = code
@@ -215,8 +230,7 @@ class MetadataCodeDetailsTest(unittest.TestCase):
call = self._unary_stream(
_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
received_initial_metadata = call.initial_metadata()
- for _ in call:
- pass
+ list(call)
self.assertTrue(
test_common.metadata_transmitted(
@@ -256,8 +270,7 @@ class MetadataCodeDetailsTest(unittest.TestCase):
iter([object()] * test_constants.STREAM_LENGTH),
metadata=_CLIENT_METADATA)
received_initial_metadata = call.initial_metadata()
- for _ in call:
- pass
+ list(call)
self.assertTrue(
test_common.metadata_transmitted(
@@ -271,6 +284,99 @@ class MetadataCodeDetailsTest(unittest.TestCase):
self.assertIs(grpc.StatusCode.OK, call.code())
self.assertEqual(_DETAILS, call.details())
+ def testAbortedUnaryUnary(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+ self._servicer.set_abort_call()
+
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA)
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA,
+ exception_context.exception.initial_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA,
+ exception_context.exception.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, exception_context.exception.code())
+ self.assertEqual(_DETAILS, exception_context.exception.details())
+
+ def testAbortedUnaryStream(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+ self._servicer.set_abort_call()
+
+ call = self._unary_stream(
+ _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
+ received_initial_metadata = call.initial_metadata()
+ with self.assertRaises(grpc.RpcError):
+ self.assertEqual(len(list(call)), 0)
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
+ received_initial_metadata))
+ self.assertTrue(
+ test_common.metadata_transmitted(_SERVER_TRAILING_METADATA,
+ call.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, call.code())
+ self.assertEqual(_DETAILS, call.details())
+
+ def testAbortedStreamUnary(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+ self._servicer.set_abort_call()
+
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ self._stream_unary.with_call(
+ iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),
+ metadata=_CLIENT_METADATA)
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA,
+ exception_context.exception.initial_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA,
+ exception_context.exception.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, exception_context.exception.code())
+ self.assertEqual(_DETAILS, exception_context.exception.details())
+
+ def testAbortedStreamStream(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+ self._servicer.set_abort_call()
+
+ call = self._stream_stream(
+ iter([object()] * test_constants.STREAM_LENGTH),
+ metadata=_CLIENT_METADATA)
+ received_initial_metadata = call.initial_metadata()
+ with self.assertRaises(grpc.RpcError):
+ self.assertEqual(len(list(call)), 0)
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
+ received_initial_metadata))
+ self.assertTrue(
+ test_common.metadata_transmitted(_SERVER_TRAILING_METADATA,
+ call.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, call.code())
+ self.assertEqual(_DETAILS, call.details())
+
def testCustomCodeUnaryUnary(self):
self._servicer.set_code(_NON_OK_CODE)
self._servicer.set_details(_DETAILS)
@@ -300,8 +406,7 @@ class MetadataCodeDetailsTest(unittest.TestCase):
_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
received_initial_metadata = call.initial_metadata()
with self.assertRaises(grpc.RpcError):
- for _ in call:
- pass
+ list(call)
self.assertTrue(
test_common.metadata_transmitted(
@@ -347,8 +452,7 @@ class MetadataCodeDetailsTest(unittest.TestCase):
metadata=_CLIENT_METADATA)
received_initial_metadata = call.initial_metadata()
with self.assertRaises(grpc.RpcError) as exception_context:
- for _ in call:
- pass
+ list(call)
self.assertTrue(
test_common.metadata_transmitted(
@@ -394,8 +498,7 @@ class MetadataCodeDetailsTest(unittest.TestCase):
_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
received_initial_metadata = call.initial_metadata()
with self.assertRaises(grpc.RpcError):
- for _ in call:
- pass
+ list(call)
self.assertTrue(
test_common.metadata_transmitted(
@@ -443,8 +546,7 @@ class MetadataCodeDetailsTest(unittest.TestCase):
metadata=_CLIENT_METADATA)
received_initial_metadata = call.initial_metadata()
with self.assertRaises(grpc.RpcError):
- for _ in call:
- pass
+ list(call)
self.assertTrue(
test_common.metadata_transmitted(