diff options
Diffstat (limited to 'src/python/grpcio/tests/unit/_adapter/_low_test.py')
-rw-r--r-- | src/python/grpcio/tests/unit/_adapter/_low_test.py | 315 |
1 files changed, 315 insertions, 0 deletions
diff --git a/src/python/grpcio/tests/unit/_adapter/_low_test.py b/src/python/grpcio/tests/unit/_adapter/_low_test.py new file mode 100644 index 0000000000..39b6f247b4 --- /dev/null +++ b/src/python/grpcio/tests/unit/_adapter/_low_test.py @@ -0,0 +1,315 @@ +# Copyright 2015, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import threading +import time +import unittest + +from grpc import _grpcio_metadata +from grpc._adapter import _types +from grpc._adapter import _low +from tests.unit import test_common + + +def wait_for_events(completion_queues, deadline): + """ + Args: + completion_queues: list of completion queues to wait for events on + deadline: absolute deadline to wait until + + Returns: + a sequence of events of length len(completion_queues). + """ + + results = [None] * len(completion_queues) + lock = threading.Lock() + threads = [] + def set_ith_result(i, completion_queue): + result = completion_queue.next(deadline) + with lock: + results[i] = result + for i, completion_queue in enumerate(completion_queues): + thread = threading.Thread(target=set_ith_result, + args=[i, completion_queue]) + thread.start() + threads.append(thread) + for thread in threads: + thread.join() + return results + + +class InsecureServerInsecureClient(unittest.TestCase): + + def setUp(self): + self.server_completion_queue = _low.CompletionQueue() + self.server = _low.Server(self.server_completion_queue, []) + self.port = self.server.add_http2_port('[::]:0') + self.client_completion_queue = _low.CompletionQueue() + self.client_channel = _low.Channel('localhost:%d'%self.port, []) + + self.server.start() + + def tearDown(self): + self.server.shutdown() + del self.client_channel + + self.client_completion_queue.shutdown() + while (self.client_completion_queue.next().type != + _types.EventType.QUEUE_SHUTDOWN): + pass + self.server_completion_queue.shutdown() + while (self.server_completion_queue.next().type != + _types.EventType.QUEUE_SHUTDOWN): + pass + + del self.client_completion_queue + del self.server_completion_queue + del self.server + + def testEcho(self): + deadline = time.time() + 5 + event_time_tolerance = 2 + deadline_tolerance = 0.25 + client_metadata_ascii_key = 'key' + client_metadata_ascii_value = 'val' + client_metadata_bin_key = 'key-bin' + client_metadata_bin_value = b'\0'*1000 + server_initial_metadata_key = 'init_me_me_me' + server_initial_metadata_value = 'whodawha?' + server_trailing_metadata_key = 'california_is_in_a_drought' + server_trailing_metadata_value = 'zomg it is' + server_status_code = _types.StatusCode.OK + server_status_details = 'our work is never over' + request = 'blarghaflargh' + response = 'his name is robert paulson' + method = 'twinkies' + host = 'hostess' + server_request_tag = object() + request_call_result = self.server.request_call(self.server_completion_queue, + server_request_tag) + + self.assertEqual(_types.CallError.OK, request_call_result) + + client_call_tag = object() + client_call = self.client_channel.create_call( + self.client_completion_queue, method, host, deadline) + client_initial_metadata = [ + (client_metadata_ascii_key, client_metadata_ascii_value), + (client_metadata_bin_key, client_metadata_bin_value) + ] + client_start_batch_result = client_call.start_batch([ + _types.OpArgs.send_initial_metadata(client_initial_metadata), + _types.OpArgs.send_message(request, 0), + _types.OpArgs.send_close_from_client(), + _types.OpArgs.recv_initial_metadata(), + _types.OpArgs.recv_message(), + _types.OpArgs.recv_status_on_client() + ], client_call_tag) + self.assertEqual(_types.CallError.OK, client_start_batch_result) + + client_no_event, request_event, = wait_for_events( + [self.client_completion_queue, self.server_completion_queue], + time.time() + event_time_tolerance) + self.assertEqual(client_no_event, None) + self.assertEqual(_types.EventType.OP_COMPLETE, request_event.type) + self.assertIsInstance(request_event.call, _low.Call) + self.assertIs(server_request_tag, request_event.tag) + self.assertEqual(1, len(request_event.results)) + received_initial_metadata = request_event.results[0].initial_metadata + # Check that our metadata were transmitted + self.assertTrue(test_common.metadata_transmitted(client_initial_metadata, + received_initial_metadata)) + # Check that Python's user agent string is a part of the full user agent + # string + received_initial_metadata_dict = dict(received_initial_metadata) + self.assertIn('user-agent', received_initial_metadata_dict) + self.assertIn('Python-gRPC-{}'.format(_grpcio_metadata.__version__), + received_initial_metadata_dict['user-agent']) + self.assertEqual(method, request_event.call_details.method) + self.assertEqual(host, request_event.call_details.host) + self.assertLess(abs(deadline - request_event.call_details.deadline), + deadline_tolerance) + + # Check that the channel is connected, and that both it and the call have + # the proper target and peer; do this after the first flurry of messages to + # avoid the possibility that connection was delayed by the core until the + # first message was sent. + self.assertEqual(_types.ConnectivityState.READY, + self.client_channel.check_connectivity_state(False)) + self.assertIsNotNone(self.client_channel.target()) + self.assertIsNotNone(client_call.peer()) + + server_call_tag = object() + server_call = request_event.call + server_initial_metadata = [ + (server_initial_metadata_key, server_initial_metadata_value) + ] + server_trailing_metadata = [ + (server_trailing_metadata_key, server_trailing_metadata_value) + ] + server_start_batch_result = server_call.start_batch([ + _types.OpArgs.send_initial_metadata(server_initial_metadata), + _types.OpArgs.recv_message(), + _types.OpArgs.send_message(response, 0), + _types.OpArgs.recv_close_on_server(), + _types.OpArgs.send_status_from_server( + server_trailing_metadata, server_status_code, server_status_details) + ], server_call_tag) + self.assertEqual(_types.CallError.OK, server_start_batch_result) + + client_event, server_event, = wait_for_events( + [self.client_completion_queue, self.server_completion_queue], + time.time() + event_time_tolerance) + + self.assertEqual(6, len(client_event.results)) + found_client_op_types = set() + for client_result in client_event.results: + # we expect each op type to be unique + self.assertNotIn(client_result.type, found_client_op_types) + found_client_op_types.add(client_result.type) + if client_result.type == _types.OpType.RECV_INITIAL_METADATA: + self.assertTrue( + test_common.metadata_transmitted(server_initial_metadata, + client_result.initial_metadata)) + elif client_result.type == _types.OpType.RECV_MESSAGE: + self.assertEqual(response, client_result.message) + elif client_result.type == _types.OpType.RECV_STATUS_ON_CLIENT: + self.assertTrue( + test_common.metadata_transmitted(server_trailing_metadata, + client_result.trailing_metadata)) + self.assertEqual(server_status_details, client_result.status.details) + self.assertEqual(server_status_code, client_result.status.code) + self.assertEqual(set([ + _types.OpType.SEND_INITIAL_METADATA, + _types.OpType.SEND_MESSAGE, + _types.OpType.SEND_CLOSE_FROM_CLIENT, + _types.OpType.RECV_INITIAL_METADATA, + _types.OpType.RECV_MESSAGE, + _types.OpType.RECV_STATUS_ON_CLIENT + ]), found_client_op_types) + + self.assertEqual(5, len(server_event.results)) + found_server_op_types = set() + for server_result in server_event.results: + self.assertNotIn(client_result.type, found_server_op_types) + found_server_op_types.add(server_result.type) + if server_result.type == _types.OpType.RECV_MESSAGE: + self.assertEqual(request, server_result.message) + elif server_result.type == _types.OpType.RECV_CLOSE_ON_SERVER: + self.assertFalse(server_result.cancelled) + self.assertEqual(set([ + _types.OpType.SEND_INITIAL_METADATA, + _types.OpType.RECV_MESSAGE, + _types.OpType.SEND_MESSAGE, + _types.OpType.RECV_CLOSE_ON_SERVER, + _types.OpType.SEND_STATUS_FROM_SERVER + ]), found_server_op_types) + + del client_call + del server_call + + +class HangingServerShutdown(unittest.TestCase): + + def setUp(self): + self.server_completion_queue = _low.CompletionQueue() + self.server = _low.Server(self.server_completion_queue, []) + self.port = self.server.add_http2_port('[::]:0') + self.client_completion_queue = _low.CompletionQueue() + self.client_channel = _low.Channel('localhost:%d'%self.port, []) + + self.server.start() + + def tearDown(self): + self.server.shutdown() + del self.client_channel + + self.client_completion_queue.shutdown() + self.server_completion_queue.shutdown() + while True: + client_event, server_event = wait_for_events( + [self.client_completion_queue, self.server_completion_queue], + float("+inf")) + if (client_event.type == _types.EventType.QUEUE_SHUTDOWN and + server_event.type == _types.EventType.QUEUE_SHUTDOWN): + break + + del self.client_completion_queue + del self.server_completion_queue + del self.server + + def testHangingServerCall(self): + deadline = time.time() + 5 + deadline_tolerance = 0.25 + event_time_tolerance = 2 + cancel_all_calls_time_tolerance = 0.5 + request = 'blarghaflargh' + method = 'twinkies' + host = 'hostess' + server_request_tag = object() + request_call_result = self.server.request_call(self.server_completion_queue, + server_request_tag) + + client_call_tag = object() + client_call = self.client_channel.create_call(self.client_completion_queue, + method, host, deadline) + client_start_batch_result = client_call.start_batch([ + _types.OpArgs.send_initial_metadata([]), + _types.OpArgs.send_message(request, 0), + _types.OpArgs.send_close_from_client(), + _types.OpArgs.recv_initial_metadata(), + _types.OpArgs.recv_message(), + _types.OpArgs.recv_status_on_client() + ], client_call_tag) + + client_no_event, request_event, = wait_for_events( + [self.client_completion_queue, self.server_completion_queue], + time.time() + event_time_tolerance) + + # Now try to shutdown the server and expect that we see server shutdown + # almost immediately after calling cancel_all_calls. + with self.assertRaises(RuntimeError): + self.server.cancel_all_calls() + shutdown_tag = object() + self.server.shutdown(shutdown_tag) + pre_cancel_timestamp = time.time() + self.server.cancel_all_calls() + finish_shutdown_timestamp = None + client_call_event, server_shutdown_event = wait_for_events( + [self.client_completion_queue, self.server_completion_queue], + time.time() + event_time_tolerance) + self.assertIs(shutdown_tag, server_shutdown_event.tag) + self.assertGreater(pre_cancel_timestamp + cancel_all_calls_time_tolerance, + time.time()) + + del client_call + + +if __name__ == '__main__': + unittest.main(verbosity=2) |