diff options
Diffstat (limited to 'src/python/grpcio_tests/tests')
10 files changed, 393 insertions, 217 deletions
diff --git a/src/python/grpcio_tests/tests/tests.json b/src/python/grpcio_tests/tests/tests.json index 724427adb6..0d94426413 100644 --- a/src/python/grpcio_tests/tests/tests.json +++ b/src/python/grpcio_tests/tests/tests.json @@ -25,6 +25,7 @@ "unit._auth_test.AccessTokenAuthMetadataPluginTest", "unit._auth_test.GoogleCallCredentialsTest", "unit._channel_args_test.ChannelArgsTest", + "unit._channel_close_test.ChannelCloseTest", "unit._channel_connectivity_test.ChannelConnectivityTest", "unit._channel_ready_future_test.ChannelReadyFutureTest", "unit._compression_test.CompressionTest", diff --git a/src/python/grpcio_tests/tests/unit/_channel_close_test.py b/src/python/grpcio_tests/tests/unit/_channel_close_test.py new file mode 100644 index 0000000000..af3a9ee1ee --- /dev/null +++ b/src/python/grpcio_tests/tests/unit/_channel_close_test.py @@ -0,0 +1,185 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests server and client side compression.""" + +import threading +import time +import unittest + +import grpc + +from tests.unit import test_common +from tests.unit.framework.common import test_constants + +_BEAT = 0.5 +_SOME_TIME = 5 +_MORE_TIME = 10 + + +class _MethodHandler(grpc.RpcMethodHandler): + + request_streaming = True + response_streaming = True + request_deserializer = None + response_serializer = None + + def stream_stream(self, request_iterator, servicer_context): + for request in request_iterator: + yield request * 2 + + +_METHOD_HANDLER = _MethodHandler() + + +class _GenericHandler(grpc.GenericRpcHandler): + + def service(self, handler_call_details): + return _METHOD_HANDLER + + +_GENERIC_HANDLER = _GenericHandler() + + +class _Pipe(object): + + def __init__(self, values): + self._condition = threading.Condition() + self._values = list(values) + self._open = True + + def __iter__(self): + return self + + def _next(self): + with self._condition: + while not self._values and self._open: + self._condition.wait() + if self._values: + return self._values.pop(0) + else: + raise StopIteration() + + def next(self): + return self._next() + + def __next__(self): + return self._next() + + def add(self, value): + with self._condition: + self._values.append(value) + self._condition.notify() + + def close(self): + with self._condition: + self._open = False + self._condition.notify() + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + self.close() + + +class ChannelCloseTest(unittest.TestCase): + + def setUp(self): + self._server = test_common.test_server( + max_workers=test_constants.THREAD_CONCURRENCY) + self._server.add_generic_rpc_handlers((_GENERIC_HANDLER,)) + self._port = self._server.add_insecure_port('[::]:0') + self._server.start() + + def tearDown(self): + self._server.stop(None) + + def test_close_immediately_after_call_invocation(self): + channel = grpc.insecure_channel('localhost:{}'.format(self._port)) + multi_callable = channel.stream_stream('Meffod') + request_iterator = _Pipe(()) + response_iterator = multi_callable(request_iterator) + channel.close() + request_iterator.close() + + self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED) + + def test_close_while_call_active(self): + channel = grpc.insecure_channel('localhost:{}'.format(self._port)) + multi_callable = channel.stream_stream('Meffod') + request_iterator = _Pipe((b'abc',)) + response_iterator = multi_callable(request_iterator) + next(response_iterator) + channel.close() + request_iterator.close() + + self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED) + + def test_context_manager_close_while_call_active(self): + with grpc.insecure_channel('localhost:{}'.format( + self._port)) as channel: # pylint: disable=bad-continuation + multi_callable = channel.stream_stream('Meffod') + request_iterator = _Pipe((b'abc',)) + response_iterator = multi_callable(request_iterator) + next(response_iterator) + request_iterator.close() + + self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED) + + def test_context_manager_close_while_many_calls_active(self): + with grpc.insecure_channel('localhost:{}'.format( + self._port)) as channel: # pylint: disable=bad-continuation + multi_callable = channel.stream_stream('Meffod') + request_iterators = tuple( + _Pipe((b'abc',)) + for _ in range(test_constants.THREAD_CONCURRENCY)) + response_iterators = [] + for request_iterator in request_iterators: + response_iterator = multi_callable(request_iterator) + next(response_iterator) + response_iterators.append(response_iterator) + for request_iterator in request_iterators: + request_iterator.close() + + for response_iterator in response_iterators: + self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED) + + def test_many_concurrent_closes(self): + channel = grpc.insecure_channel('localhost:{}'.format(self._port)) + multi_callable = channel.stream_stream('Meffod') + request_iterator = _Pipe((b'abc',)) + response_iterator = multi_callable(request_iterator) + next(response_iterator) + start = time.time() + end = start + _MORE_TIME + + def sleep_some_time_then_close(): + time.sleep(_SOME_TIME) + channel.close() + + for _ in range(test_constants.THREAD_CONCURRENCY): + close_thread = threading.Thread(target=sleep_some_time_then_close) + close_thread.start() + while True: + request_iterator.add(b'def') + time.sleep(_BEAT) + if end < time.time(): + break + request_iterator.close() + + self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py b/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py index 4f8868d346..578a3d79ad 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py @@ -19,6 +19,7 @@ import unittest from grpc._cython import cygrpc from grpc.framework.foundation import logging_pool from tests.unit.framework.common import test_constants +from tests.unit._cython import test_utilities _EMPTY_FLAGS = 0 _EMPTY_METADATA = () @@ -30,6 +31,8 @@ _RECEIVE_MESSAGE_TAG = 'receive_message' _SERVER_COMPLETE_CALL_TAG = 'server_complete_call' _SUCCESS_CALL_FRACTION = 1.0 / 8.0 +_SUCCESSFUL_CALLS = int(test_constants.RPC_CONCURRENCY * _SUCCESS_CALL_FRACTION) +_UNSUCCESSFUL_CALLS = test_constants.RPC_CONCURRENCY - _SUCCESSFUL_CALLS class _State(object): @@ -150,7 +153,8 @@ class CancelManyCallsTest(unittest.TestCase): server.register_completion_queue(server_completion_queue) port = server.add_http2_port(b'[::]:0') server.start() - channel = cygrpc.Channel('localhost:{}'.format(port).encode(), None) + channel = cygrpc.Channel('localhost:{}'.format(port).encode(), None, + None) state = _State() @@ -165,31 +169,33 @@ class CancelManyCallsTest(unittest.TestCase): client_condition = threading.Condition() client_due = set() - client_completion_queue = cygrpc.CompletionQueue() - client_driver = _QueueDriver(client_condition, client_completion_queue, - client_due) - client_driver.start() with client_condition: client_calls = [] for index in range(test_constants.RPC_CONCURRENCY): - client_call = channel.create_call(None, _EMPTY_FLAGS, - client_completion_queue, - b'/twinkies', None, None) - operations = ( - cygrpc.SendInitialMetadataOperation(_EMPTY_METADATA, - _EMPTY_FLAGS), - cygrpc.SendMessageOperation(b'\x45\x56', _EMPTY_FLAGS), - cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), - cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS), - cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), - cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), - ) tag = 'client_complete_call_{0:04d}_tag'.format(index) - client_call.start_client_batch(operations, tag) + client_call = channel.integrated_call( + _EMPTY_FLAGS, b'/twinkies', None, None, _EMPTY_METADATA, + None, (( + ( + cygrpc.SendInitialMetadataOperation( + _EMPTY_METADATA, _EMPTY_FLAGS), + cygrpc.SendMessageOperation(b'\x45\x56', + _EMPTY_FLAGS), + cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), + cygrpc.ReceiveInitialMetadataOperation( + _EMPTY_FLAGS), + cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), + cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), + ), + tag, + ),)) client_due.add(tag) client_calls.append(client_call) + client_events_future = test_utilities.SimpleFuture( + lambda: tuple(channel.next_call_event() for _ in range(_SUCCESSFUL_CALLS))) + with state.condition: while True: if state.parked_handlers < test_constants.THREAD_CONCURRENCY: @@ -201,12 +207,14 @@ class CancelManyCallsTest(unittest.TestCase): state.condition.notify_all() break - client_driver.events( - test_constants.RPC_CONCURRENCY * _SUCCESS_CALL_FRACTION) + client_events_future.result() with client_condition: for client_call in client_calls: - client_call.cancel() + client_call.cancel(cygrpc.StatusCode.cancelled, 'Cancelled!') + for _ in range(_UNSUCCESSFUL_CALLS): + channel.next_call_event() + channel.close(cygrpc.StatusCode.unknown, 'Cancelled on channel close!') with state.condition: server.shutdown(server_completion_queue, _SERVER_SHUTDOWN_TAG) diff --git a/src/python/grpcio_tests/tests/unit/_cython/_channel_test.py b/src/python/grpcio_tests/tests/unit/_cython/_channel_test.py index 7305d0fa3f..d95286071d 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_channel_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_channel_test.py @@ -21,25 +21,20 @@ from grpc._cython import cygrpc from tests.unit.framework.common import test_constants -def _channel_and_completion_queue(): - channel = cygrpc.Channel(b'localhost:54321', ()) - completion_queue = cygrpc.CompletionQueue() - return channel, completion_queue +def _channel(): + return cygrpc.Channel(b'localhost:54321', (), None) -def _connectivity_loop(channel, completion_queue): +def _connectivity_loop(channel): for _ in range(100): connectivity = channel.check_connectivity_state(True) - channel.watch_connectivity_state(connectivity, - time.time() + 0.2, completion_queue, - None) - completion_queue.poll() + channel.watch_connectivity_state(connectivity, time.time() + 0.2) def _create_loop_destroy(): - channel, completion_queue = _channel_and_completion_queue() - _connectivity_loop(channel, completion_queue) - completion_queue.shutdown() + channel = _channel() + _connectivity_loop(channel) + channel.close(cygrpc.StatusCode.ok, 'Channel close!') def _in_parallel(behavior, arguments): @@ -55,12 +50,9 @@ def _in_parallel(behavior, arguments): class ChannelTest(unittest.TestCase): def test_single_channel_lonely_connectivity(self): - channel, completion_queue = _channel_and_completion_queue() - _in_parallel(_connectivity_loop, ( - channel, - completion_queue, - )) - completion_queue.shutdown() + channel = _channel() + _connectivity_loop(channel) + channel.close(cygrpc.StatusCode.ok, 'Channel close!') def test_multiple_channels_lonely_connectivity(self): _in_parallel(_create_loop_destroy, ()) diff --git a/src/python/grpcio_tests/tests/unit/_cython/_common.py b/src/python/grpcio_tests/tests/unit/_cython/_common.py index 7fd3d19b4e..d8210f36f8 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_common.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_common.py @@ -100,7 +100,8 @@ class RpcTest(object): self.server.register_completion_queue(self.server_completion_queue) port = self.server.add_http2_port(b'[::]:0') self.server.start() - self.channel = cygrpc.Channel('localhost:{}'.format(port).encode(), []) + self.channel = cygrpc.Channel('localhost:{}'.format(port).encode(), [], + None) self._server_shutdown_tag = 'server_shutdown_tag' self.server_condition = threading.Condition() diff --git a/src/python/grpcio_tests/tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py b/src/python/grpcio_tests/tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py index 7caa98f72d..8a721788f4 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py @@ -19,6 +19,7 @@ import unittest from grpc._cython import cygrpc from tests.unit._cython import _common +from tests.unit._cython import test_utilities class Test(_common.RpcTest, unittest.TestCase): @@ -41,31 +42,27 @@ class Test(_common.RpcTest, unittest.TestCase): server_request_call_tag, }) - client_call = self.channel.create_call(None, _common.EMPTY_FLAGS, - self.client_completion_queue, - b'/twinkies', None, None) client_receive_initial_metadata_tag = 'client_receive_initial_metadata_tag' client_complete_rpc_tag = 'client_complete_rpc_tag' - with self.client_condition: - client_receive_initial_metadata_start_batch_result = ( - client_call.start_client_batch([ - cygrpc.ReceiveInitialMetadataOperation(_common.EMPTY_FLAGS), - ], client_receive_initial_metadata_tag)) - self.assertEqual(cygrpc.CallError.ok, - client_receive_initial_metadata_start_batch_result) - client_complete_rpc_start_batch_result = client_call.start_client_batch( + client_call = self.channel.integrated_call( + _common.EMPTY_FLAGS, b'/twinkies', None, None, + _common.INVOCATION_METADATA, None, [( [ - cygrpc.SendInitialMetadataOperation( - _common.INVOCATION_METADATA, _common.EMPTY_FLAGS), - cygrpc.SendCloseFromClientOperation(_common.EMPTY_FLAGS), - cygrpc.ReceiveStatusOnClientOperation(_common.EMPTY_FLAGS), - ], client_complete_rpc_tag) - self.assertEqual(cygrpc.CallError.ok, - client_complete_rpc_start_batch_result) - self.client_driver.add_due({ + cygrpc.ReceiveInitialMetadataOperation(_common.EMPTY_FLAGS), + ], client_receive_initial_metadata_tag, - client_complete_rpc_tag, - }) + )]) + client_call.operate([ + cygrpc.SendInitialMetadataOperation(_common.INVOCATION_METADATA, + _common.EMPTY_FLAGS), + cygrpc.SendCloseFromClientOperation(_common.EMPTY_FLAGS), + cygrpc.ReceiveStatusOnClientOperation(_common.EMPTY_FLAGS), + ], client_complete_rpc_tag) + + client_events_future = test_utilities.SimpleFuture( + lambda: [ + self.channel.next_call_event(), + self.channel.next_call_event(),]) server_request_call_event = self.server_driver.event_with_tag( server_request_call_tag) @@ -96,20 +93,23 @@ class Test(_common.RpcTest, unittest.TestCase): server_complete_rpc_event = server_call_driver.event_with_tag( server_complete_rpc_tag) - client_receive_initial_metadata_event = self.client_driver.event_with_tag( - client_receive_initial_metadata_tag) - client_complete_rpc_event = self.client_driver.event_with_tag( - client_complete_rpc_tag) + client_events = client_events_future.result() + if client_events[0].tag is client_receive_initial_metadata_tag: + client_receive_initial_metadata_event = client_events[0] + client_complete_rpc_event = client_events[1] + else: + client_complete_rpc_event = client_events[0] + client_receive_initial_metadata_event = client_events[1] return ( _common.OperationResult(server_request_call_start_batch_result, server_request_call_event.completion_type, server_request_call_event.success), _common.OperationResult( - client_receive_initial_metadata_start_batch_result, + cygrpc.CallError.ok, client_receive_initial_metadata_event.completion_type, client_receive_initial_metadata_event.success), - _common.OperationResult(client_complete_rpc_start_batch_result, + _common.OperationResult(cygrpc.CallError.ok, client_complete_rpc_event.completion_type, client_complete_rpc_event.success), _common.OperationResult( diff --git a/src/python/grpcio_tests/tests/unit/_cython/_no_messages_single_server_completion_queue_test.py b/src/python/grpcio_tests/tests/unit/_cython/_no_messages_single_server_completion_queue_test.py index 8582a39c01..47f39ebce2 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_no_messages_single_server_completion_queue_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_no_messages_single_server_completion_queue_test.py @@ -19,6 +19,7 @@ import unittest from grpc._cython import cygrpc from tests.unit._cython import _common +from tests.unit._cython import test_utilities class Test(_common.RpcTest, unittest.TestCase): @@ -36,28 +37,31 @@ class Test(_common.RpcTest, unittest.TestCase): server_request_call_tag, }) - client_call = self.channel.create_call(None, _common.EMPTY_FLAGS, - self.client_completion_queue, - b'/twinkies', None, None) client_receive_initial_metadata_tag = 'client_receive_initial_metadata_tag' client_complete_rpc_tag = 'client_complete_rpc_tag' - with self.client_condition: - client_receive_initial_metadata_start_batch_result = ( - client_call.start_client_batch([ - cygrpc.ReceiveInitialMetadataOperation(_common.EMPTY_FLAGS), - ], client_receive_initial_metadata_tag)) - client_complete_rpc_start_batch_result = client_call.start_client_batch( - [ - cygrpc.SendInitialMetadataOperation( - _common.INVOCATION_METADATA, _common.EMPTY_FLAGS), - cygrpc.SendCloseFromClientOperation(_common.EMPTY_FLAGS), - cygrpc.ReceiveStatusOnClientOperation(_common.EMPTY_FLAGS), - ], client_complete_rpc_tag) - self.client_driver.add_due({ - client_receive_initial_metadata_tag, - client_complete_rpc_tag, - }) - + client_call = self.channel.integrated_call( + _common.EMPTY_FLAGS, b'/twinkies', None, None, + _common.INVOCATION_METADATA, None, [ + ( + [ + cygrpc.SendInitialMetadataOperation( + _common.INVOCATION_METADATA, _common.EMPTY_FLAGS), + cygrpc.SendCloseFromClientOperation( + _common.EMPTY_FLAGS), + cygrpc.ReceiveStatusOnClientOperation( + _common.EMPTY_FLAGS), + ], + client_complete_rpc_tag, + ), + ]) + client_call.operate([ + cygrpc.ReceiveInitialMetadataOperation(_common.EMPTY_FLAGS), + ], client_receive_initial_metadata_tag) + + client_events_future = test_utilities.SimpleFuture( + lambda: [ + self.channel.next_call_event(), + self.channel.next_call_event(),]) server_request_call_event = self.server_driver.event_with_tag( server_request_call_tag) @@ -87,20 +91,19 @@ class Test(_common.RpcTest, unittest.TestCase): server_complete_rpc_event = self.server_driver.event_with_tag( server_complete_rpc_tag) - client_receive_initial_metadata_event = self.client_driver.event_with_tag( - client_receive_initial_metadata_tag) - client_complete_rpc_event = self.client_driver.event_with_tag( - client_complete_rpc_tag) + client_events = client_events_future.result() + client_receive_initial_metadata_event = client_events[0] + client_complete_rpc_event = client_events[1] return ( _common.OperationResult(server_request_call_start_batch_result, server_request_call_event.completion_type, server_request_call_event.success), _common.OperationResult( - client_receive_initial_metadata_start_batch_result, + cygrpc.CallError.ok, client_receive_initial_metadata_event.completion_type, client_receive_initial_metadata_event.success), - _common.OperationResult(client_complete_rpc_start_batch_result, + _common.OperationResult(cygrpc.CallError.ok, client_complete_rpc_event.completion_type, client_complete_rpc_event.success), _common.OperationResult( diff --git a/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py b/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py index bc63b54879..8a903bfaf9 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py @@ -17,6 +17,7 @@ import threading import unittest from grpc._cython import cygrpc +from tests.unit._cython import test_utilities _EMPTY_FLAGS = 0 _EMPTY_METADATA = () @@ -118,7 +119,8 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase): server.register_completion_queue(server_completion_queue) port = server.add_http2_port(b'[::]:0') server.start() - channel = cygrpc.Channel('localhost:{}'.format(port).encode(), set()) + channel = cygrpc.Channel('localhost:{}'.format(port).encode(), set(), + None) server_shutdown_tag = 'server_shutdown_tag' server_driver = _ServerDriver(server_completion_queue, @@ -127,10 +129,6 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase): client_condition = threading.Condition() client_due = set() - client_completion_queue = cygrpc.CompletionQueue() - client_driver = _QueueDriver(client_condition, client_completion_queue, - client_due) - client_driver.start() server_call_condition = threading.Condition() server_send_initial_metadata_tag = 'server_send_initial_metadata_tag' @@ -154,25 +152,28 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase): server_completion_queue, server_rpc_tag) - client_call = channel.create_call(None, _EMPTY_FLAGS, - client_completion_queue, b'/twinkies', - None, None) client_receive_initial_metadata_tag = 'client_receive_initial_metadata_tag' client_complete_rpc_tag = 'client_complete_rpc_tag' - with client_condition: - client_receive_initial_metadata_start_batch_result = ( - client_call.start_client_batch([ - cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS), - ], client_receive_initial_metadata_tag)) - client_due.add(client_receive_initial_metadata_tag) - client_complete_rpc_start_batch_result = ( - client_call.start_client_batch([ - cygrpc.SendInitialMetadataOperation(_EMPTY_METADATA, - _EMPTY_FLAGS), - cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), - cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), - ], client_complete_rpc_tag)) - client_due.add(client_complete_rpc_tag) + client_call = channel.segregated_call( + _EMPTY_FLAGS, b'/twinkies', None, None, _EMPTY_METADATA, None, ( + ( + [ + cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS), + ], + client_receive_initial_metadata_tag, + ), + ( + [ + cygrpc.SendInitialMetadataOperation( + _EMPTY_METADATA, _EMPTY_FLAGS), + cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), + cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), + ], + client_complete_rpc_tag, + ), + )) + client_receive_initial_metadata_event_future = test_utilities.SimpleFuture( + client_call.next_event) server_rpc_event = server_driver.first_event() @@ -208,19 +209,20 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase): server_complete_rpc_tag) server_call_driver.events() - with client_condition: - client_receive_first_message_tag = 'client_receive_first_message_tag' - client_receive_first_message_start_batch_result = ( - client_call.start_client_batch([ - cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), - ], client_receive_first_message_tag)) - client_due.add(client_receive_first_message_tag) - client_receive_first_message_event = client_driver.event_with_tag( - client_receive_first_message_tag) + client_recieve_initial_metadata_event = client_receive_initial_metadata_event_future.result( + ) + + client_receive_first_message_tag = 'client_receive_first_message_tag' + client_call.operate([ + cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), + ], client_receive_first_message_tag) + client_receive_first_message_event = client_call.next_event() - client_call_cancel_result = client_call.cancel() - client_driver.events() + client_call_cancel_result = client_call.cancel( + cygrpc.StatusCode.cancelled, 'Cancelled during test!') + client_complete_rpc_event = client_call.next_event() + channel.close(cygrpc.StatusCode.unknown, 'Channel closed!') server.shutdown(server_completion_queue, server_shutdown_tag) server.cancel_all_calls() server_driver.events() @@ -228,11 +230,6 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase): self.assertEqual(cygrpc.CallError.ok, request_call_result) self.assertEqual(cygrpc.CallError.ok, server_send_initial_metadata_start_batch_result) - self.assertEqual(cygrpc.CallError.ok, - client_receive_initial_metadata_start_batch_result) - self.assertEqual(cygrpc.CallError.ok, - client_complete_rpc_start_batch_result) - self.assertEqual(cygrpc.CallError.ok, client_call_cancel_result) self.assertIs(server_rpc_tag, server_rpc_event.tag) self.assertEqual(cygrpc.CompletionType.operation_complete, server_rpc_event.completion_type) diff --git a/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py b/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py index 23f5ef605d..724a690746 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py @@ -51,8 +51,8 @@ class TypeSmokeTest(unittest.TestCase): del server def testChannelUpDown(self): - channel = cygrpc.Channel(b'[::]:0', None) - del channel + channel = cygrpc.Channel(b'[::]:0', None, None) + channel.close(cygrpc.StatusCode.cancelled, 'Test method anyway!') def test_metadata_plugin_call_credentials_up_down(self): cygrpc.MetadataPluginCallCredentials(_metadata_plugin, @@ -121,7 +121,7 @@ class ServerClientMixin(object): client_credentials) else: self.client_channel = cygrpc.Channel('localhost:{}'.format( - self.port).encode(), set()) + self.port).encode(), set(), None) if host_override: self.host_argument = None # default host self.expected_host = host_override @@ -131,17 +131,20 @@ class ServerClientMixin(object): self.expected_host = self.host_argument def tearDownMixin(self): + self.client_channel.close(cygrpc.StatusCode.ok, 'test being torn down!') + del self.client_channel del self.server del self.client_completion_queue del self.server_completion_queue - def _perform_operations(self, operations, call, queue, deadline, - description): - """Perform the list of operations with given call, queue, and deadline. + def _perform_queue_operations(self, operations, call, queue, deadline, + description): + """Perform the operations with given call, queue, and deadline. - Invocation errors are reported with as an exception with `description` in - the message. Performs the operations asynchronously, returning a future. - """ + Invocation errors are reported with as an exception with `description` + in the message. Performs the operations asynchronously, returning a + future. + """ def performer(): tag = object() @@ -185,9 +188,6 @@ class ServerClientMixin(object): self.assertEqual(cygrpc.CallError.ok, request_call_result) client_call_tag = object() - client_call = self.client_channel.create_call( - None, 0, self.client_completion_queue, METHOD, self.host_argument, - DEADLINE) client_initial_metadata = ( ( CLIENT_METADATA_ASCII_KEY, @@ -198,18 +198,24 @@ class ServerClientMixin(object): CLIENT_METADATA_BIN_VALUE, ), ) - client_start_batch_result = client_call.start_client_batch([ - cygrpc.SendInitialMetadataOperation(client_initial_metadata, - _EMPTY_FLAGS), - cygrpc.SendMessageOperation(REQUEST, _EMPTY_FLAGS), - cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), - cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS), - cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), - cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), - ], client_call_tag) - self.assertEqual(cygrpc.CallError.ok, client_start_batch_result) - client_event_future = test_utilities.CompletionQueuePollFuture( - self.client_completion_queue, DEADLINE) + client_call = self.client_channel.integrated_call( + 0, METHOD, self.host_argument, DEADLINE, client_initial_metadata, + None, [ + ( + [ + cygrpc.SendInitialMetadataOperation( + client_initial_metadata, _EMPTY_FLAGS), + cygrpc.SendMessageOperation(REQUEST, _EMPTY_FLAGS), + cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), + cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS), + cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), + cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), + ], + client_call_tag, + ), + ]) + client_event_future = test_utilities.SimpleFuture( + self.client_channel.next_call_event) request_event = self.server_completion_queue.poll(deadline=DEADLINE) self.assertEqual(cygrpc.CompletionType.operation_complete, @@ -304,66 +310,76 @@ class ServerClientMixin(object): del client_call del server_call - def test6522(self): + def test_6522(self): DEADLINE = time.time() + 5 DEADLINE_TOLERANCE = 0.25 METHOD = b'twinkies' empty_metadata = () + # Prologue server_request_tag = object() self.server.request_call(self.server_completion_queue, self.server_completion_queue, server_request_tag) - client_call = self.client_channel.create_call( - None, 0, self.client_completion_queue, METHOD, self.host_argument, - DEADLINE) - - # Prologue - def perform_client_operations(operations, description): - return self._perform_operations(operations, client_call, - self.client_completion_queue, - DEADLINE, description) - - client_event_future = perform_client_operations([ - cygrpc.SendInitialMetadataOperation(empty_metadata, _EMPTY_FLAGS), - cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS), - ], "Client prologue") + client_call = self.client_channel.segregated_call( + 0, METHOD, self.host_argument, DEADLINE, None, None, ([( + [ + cygrpc.SendInitialMetadataOperation(empty_metadata, + _EMPTY_FLAGS), + cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS), + ], + object(), + ), ( + [ + cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), + ], + object(), + )])) + + client_initial_metadata_event_future = test_utilities.SimpleFuture( + client_call.next_event) request_event = self.server_completion_queue.poll(deadline=DEADLINE) server_call = request_event.call def perform_server_operations(operations, description): - return self._perform_operations(operations, server_call, - self.server_completion_queue, - DEADLINE, description) + return self._perform_queue_operations(operations, server_call, + self.server_completion_queue, + DEADLINE, description) server_event_future = perform_server_operations([ cygrpc.SendInitialMetadataOperation(empty_metadata, _EMPTY_FLAGS), ], "Server prologue") - client_event_future.result() # force completion + client_initial_metadata_event_future.result() # force completion server_event_future.result() # Messaging for _ in range(10): - client_event_future = perform_client_operations([ + client_call.operate([ cygrpc.SendMessageOperation(b'', _EMPTY_FLAGS), cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), ], "Client message") + client_message_event_future = test_utilities.SimpleFuture( + client_call.next_event) server_event_future = perform_server_operations([ cygrpc.SendMessageOperation(b'', _EMPTY_FLAGS), cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), ], "Server receive") - client_event_future.result() # force completion + client_message_event_future.result() # force completion server_event_future.result() # Epilogue - client_event_future = perform_client_operations([ + client_call.operate([ cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), - cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS) ], "Client epilogue") + # One for ReceiveStatusOnClient, one for SendCloseFromClient. + client_events_future = test_utilities.SimpleFuture( + lambda: { + client_call.next_event(), + client_call.next_event(),}) server_event_future = perform_server_operations([ cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS), @@ -371,7 +387,7 @@ class ServerClientMixin(object): empty_metadata, cygrpc.StatusCode.ok, b'', _EMPTY_FLAGS) ], "Server epilogue") - client_event_future.result() # force completion + client_events_future.result() # force completion server_event_future.result() diff --git a/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py b/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py index 4edf0fc4ad..f153089a24 100644 --- a/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py +++ b/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py @@ -81,29 +81,16 @@ class InvalidMetadataTest(unittest.TestCase): request = b'\x07\x08' metadata = (('InVaLiD', 'UnaryRequestFutureUnaryResponse'),) expected_error_details = "metadata was invalid: %s" % metadata - response_future = self._unary_unary.future(request, metadata=metadata) - with self.assertRaises(grpc.RpcError) as exception_context: - response_future.result() - self.assertEqual(exception_context.exception.details(), - expected_error_details) - self.assertEqual(exception_context.exception.code(), - grpc.StatusCode.INTERNAL) - self.assertEqual(response_future.details(), expected_error_details) - self.assertEqual(response_future.code(), grpc.StatusCode.INTERNAL) + with self.assertRaises(ValueError) as exception_context: + self._unary_unary.future(request, metadata=metadata) def testUnaryRequestStreamResponse(self): request = b'\x37\x58' metadata = (('InVaLiD', 'UnaryRequestStreamResponse'),) expected_error_details = "metadata was invalid: %s" % metadata - response_iterator = self._unary_stream(request, metadata=metadata) - with self.assertRaises(grpc.RpcError) as exception_context: - next(response_iterator) - self.assertEqual(exception_context.exception.details(), - expected_error_details) - self.assertEqual(exception_context.exception.code(), - grpc.StatusCode.INTERNAL) - self.assertEqual(response_iterator.details(), expected_error_details) - self.assertEqual(response_iterator.code(), grpc.StatusCode.INTERNAL) + with self.assertRaises(ValueError) as exception_context: + self._unary_stream(request, metadata=metadata) + self.assertIn(expected_error_details, str(exception_context.exception)) def testStreamRequestBlockingUnaryResponse(self): request_iterator = ( @@ -129,32 +116,18 @@ class InvalidMetadataTest(unittest.TestCase): b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) metadata = (('InVaLiD', 'StreamRequestFutureUnaryResponse'),) expected_error_details = "metadata was invalid: %s" % metadata - response_future = self._stream_unary.future( - request_iterator, metadata=metadata) - with self.assertRaises(grpc.RpcError) as exception_context: - response_future.result() - self.assertEqual(exception_context.exception.details(), - expected_error_details) - self.assertEqual(exception_context.exception.code(), - grpc.StatusCode.INTERNAL) - self.assertEqual(response_future.details(), expected_error_details) - self.assertEqual(response_future.code(), grpc.StatusCode.INTERNAL) + with self.assertRaises(ValueError) as exception_context: + self._stream_unary.future(request_iterator, metadata=metadata) + self.assertIn(expected_error_details, str(exception_context.exception)) def testStreamRequestStreamResponse(self): request_iterator = ( b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) metadata = (('InVaLiD', 'StreamRequestStreamResponse'),) expected_error_details = "metadata was invalid: %s" % metadata - response_iterator = self._stream_stream( - request_iterator, metadata=metadata) - with self.assertRaises(grpc.RpcError) as exception_context: - next(response_iterator) - self.assertEqual(exception_context.exception.details(), - expected_error_details) - self.assertEqual(exception_context.exception.code(), - grpc.StatusCode.INTERNAL) - self.assertEqual(response_iterator.details(), expected_error_details) - self.assertEqual(response_iterator.code(), grpc.StatusCode.INTERNAL) + with self.assertRaises(ValueError) as exception_context: + self._stream_stream(request_iterator, metadata=metadata) + self.assertIn(expected_error_details, str(exception_context.exception)) if __name__ == '__main__': |