diff options
Diffstat (limited to 'src/python/grpcio')
-rw-r--r-- | src/python/grpcio/grpc/__init__.py | 100 | ||||
-rw-r--r-- | src/python/grpcio/grpc/_common.py | 42 | ||||
-rw-r--r-- | src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pxd.pxi | 3 | ||||
-rw-r--r-- | src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi | 59 | ||||
-rw-r--r-- | src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi | 6 | ||||
-rw-r--r-- | src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi | 2 | ||||
-rw-r--r-- | src/python/grpcio/grpc/_cython/imports.generated.h | 4 | ||||
-rw-r--r-- | src/python/grpcio/grpc/_server.py | 36 | ||||
-rw-r--r-- | src/python/grpcio/grpc/_utilities.py | 26 | ||||
-rw-r--r-- | src/python/grpcio/grpc_core_dependencies.py | 1 | ||||
-rw-r--r-- | src/python/grpcio/tests/protoc_plugin/_python_plugin_test.py | 583 | ||||
-rw-r--r-- | src/python/grpcio/tests/tests.json | 2 | ||||
-rw-r--r-- | src/python/grpcio/tests/unit/_rpc_test.py | 9 | ||||
-rw-r--r-- | src/python/grpcio/tests/unit/_thread_cleanup_test.py | 117 | ||||
-rw-r--r-- | src/python/grpcio/tests/unit/beta/test_utilities.py | 4 |
15 files changed, 921 insertions, 73 deletions
diff --git a/src/python/grpcio/grpc/__init__.py b/src/python/grpcio/grpc/__init__.py index 5ba5a4e1fd..2fbb9f45e6 100644 --- a/src/python/grpcio/grpc/__init__.py +++ b/src/python/grpcio/grpc/__init__.py @@ -900,6 +900,102 @@ class Server(six.with_metaclass(abc.ABCMeta)): ################################# Functions ################################ +def unary_unary_rpc_method_handler( + behavior, request_deserializer=None, response_serializer=None): + """Creates an RpcMethodHandler for a unary-unary RPC method. + + Args: + behavior: The implementation of an RPC method as a callable behavior taking + a single request value and returning a single response value. + request_deserializer: An optional request deserialization behavior. + response_serializer: An optional response serialization behavior. + + Returns: + An RpcMethodHandler for a unary-unary RPC method constructed from the given + parameters. + """ + from grpc import _utilities + return _utilities.RpcMethodHandler( + False, False, request_deserializer, response_serializer, behavior, None, + None, None) + + +def unary_stream_rpc_method_handler( + behavior, request_deserializer=None, response_serializer=None): + """Creates an RpcMethodHandler for a unary-stream RPC method. + + Args: + behavior: The implementation of an RPC method as a callable behavior taking + a single request value and returning an iterator of response values. + request_deserializer: An optional request deserialization behavior. + response_serializer: An optional response serialization behavior. + + Returns: + An RpcMethodHandler for a unary-stream RPC method constructed from the + given parameters. + """ + from grpc import _utilities + return _utilities.RpcMethodHandler( + False, True, request_deserializer, response_serializer, None, behavior, + None, None) + + +def stream_unary_rpc_method_handler( + behavior, request_deserializer=None, response_serializer=None): + """Creates an RpcMethodHandler for a stream-unary RPC method. + + Args: + behavior: The implementation of an RPC method as a callable behavior taking + an iterator of request values and returning a single response value. + request_deserializer: An optional request deserialization behavior. + response_serializer: An optional response serialization behavior. + + Returns: + An RpcMethodHandler for a stream-unary RPC method constructed from the + given parameters. + """ + from grpc import _utilities + return _utilities.RpcMethodHandler( + True, False, request_deserializer, response_serializer, None, None, + behavior, None) + + +def stream_stream_rpc_method_handler( + behavior, request_deserializer=None, response_serializer=None): + """Creates an RpcMethodHandler for a stream-stream RPC method. + + Args: + behavior: The implementation of an RPC method as a callable behavior taking + an iterator of request values and returning an iterator of response + values. + request_deserializer: An optional request deserialization behavior. + response_serializer: An optional response serialization behavior. + + Returns: + An RpcMethodHandler for a stream-stream RPC method constructed from the + given parameters. + """ + from grpc import _utilities + return _utilities.RpcMethodHandler( + True, True, request_deserializer, response_serializer, None, None, None, + behavior) + + +def method_handlers_generic_handler(service, method_handlers): + """Creates a grpc.GenericRpcHandler from RpcMethodHandlers. + + Args: + service: A service name to be used for the given method handlers. + method_handlers: A dictionary from method name to RpcMethodHandler + implementing the named method. + + Returns: + A GenericRpcHandler constructed from the given parameters. + """ + from grpc import _utilities + return _utilities.DictionaryGenericHandler(service, method_handlers) + + def ssl_channel_credentials( root_certificates=None, private_key=None, certificate_chain=None): """Creates a ChannelCredentials for use with an SSL-enabled Channel. @@ -1059,7 +1155,7 @@ def insecure_channel(target, options=None): A Channel to the target through which RPCs may be conducted. """ from grpc import _channel - return _channel.Channel(target, None, options) + return _channel.Channel(target, options, None) def secure_channel(target, credentials, options=None): @@ -1075,7 +1171,7 @@ def secure_channel(target, credentials, options=None): A Channel to the target through which RPCs may be conducted. """ from grpc import _channel - return _channel.Channel(target, credentials, options) + return _channel.Channel(target, options, credentials) def server(generic_rpc_handlers, thread_pool, options=None): diff --git a/src/python/grpcio/grpc/_common.py b/src/python/grpcio/grpc/_common.py index b8688a0149..1fd1704f18 100644 --- a/src/python/grpcio/grpc/_common.py +++ b/src/python/grpcio/grpc/_common.py @@ -30,6 +30,8 @@ """Shared implementation.""" import logging +import threading +import time import six @@ -110,3 +112,43 @@ def fully_qualified_method(group, method): group = _encode(group) method = _encode(method) return b'/' + group + b'/' + method + + +class CleanupThread(threading.Thread): + """A threading.Thread subclass supporting custom behavior on join(). + + On Python Interpreter exit, Python will attempt to join outstanding threads + prior to garbage collection. We may need to do additional cleanup, and + we accomplish this by overriding the join() method. + """ + + def __init__(self, behavior, group=None, target=None, name=None, + args=(), kwargs={}): + """Constructor. + + Args: + behavior (function): Function called on join() with a single + argument, timeout, indicating the maximum duration of + `behavior`, or None indicating `behavior` has no deadline. + `behavior` must be idempotent. + group (None): should be None. Reseved for future extensions + when ThreadGroup is implemented. + target (function): The function to invoke when this thread is + run. Defaults to None. + name (str): The name of this thread. Defaults to None. + args (tuple[object]): A tuple of arguments to pass to `target`. + kwargs (dict[str,object]): A dictionary of keyword arguments to + pass to `target`. + """ + super(CleanupThread, self).__init__(group=group, target=target, + name=name, args=args, kwargs=kwargs) + self._behavior = behavior + + def join(self, timeout=None): + start_time = time.time() + self._behavior(timeout) + end_time = time.time() + if timeout is not None: + timeout -= end_time - start_time + timeout = max(timeout, 0) + super(CleanupThread, self).join(timeout) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pxd.pxi index a67c963684..01089c3dc0 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pxd.pxi @@ -31,9 +31,6 @@ cdef class CompletionQueue: cdef grpc_completion_queue *c_completion_queue - cdef object pluck_condition - cdef int num_plucking - cdef int num_polling cdef bint is_shutting_down cdef bint is_shutdown diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi index cdae39d519..90266516fe 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi @@ -32,6 +32,8 @@ cimport cpython import threading import time +cdef int _INTERRUPT_CHECK_PERIOD_MS = 200 + cdef class CompletionQueue: @@ -40,9 +42,6 @@ cdef class CompletionQueue: self.c_completion_queue = grpc_completion_queue_create(NULL) self.is_shutting_down = False self.is_shutdown = False - self.pluck_condition = threading.Condition() - self.num_plucking = 0 - self.num_polling = 0 cdef _interpret_event(self, grpc_event event): cdef OperationTag tag = None @@ -83,45 +82,27 @@ cdef class CompletionQueue: def poll(self, Timespec deadline=None): # We name this 'poll' to avoid problems with CPython's expectations for # 'special' methods (like next and __next__). + cdef gpr_timespec c_increment + cdef gpr_timespec c_timeout cdef gpr_timespec c_deadline with nogil: + c_increment = gpr_time_from_millis(_INTERRUPT_CHECK_PERIOD_MS, GPR_TIMESPAN) c_deadline = gpr_inf_future(GPR_CLOCK_REALTIME) - if deadline is not None: - c_deadline = deadline.c_time - cdef grpc_event event - - # Poll within a critical section to detect contention - with self.pluck_condition: - assert self.num_plucking == 0, 'cannot simultaneously pluck and poll' - self.num_polling += 1 - with nogil: - event = grpc_completion_queue_next( - self.c_completion_queue, c_deadline, NULL) - with self.pluck_condition: - self.num_polling -= 1 - return self._interpret_event(event) - - def pluck(self, OperationTag tag, Timespec deadline=None): - # Plucking a 'None' tag is equivalent to passing control to GRPC core until - # the deadline. - cdef gpr_timespec c_deadline = gpr_inf_future( - GPR_CLOCK_REALTIME) - if deadline is not None: - c_deadline = deadline.c_time - cdef grpc_event event - - # Pluck within a critical section to detect contention - with self.pluck_condition: - assert self.num_polling == 0, 'cannot simultaneously pluck and poll' - assert self.num_plucking < GRPC_MAX_COMPLETION_QUEUE_PLUCKERS, ( - 'cannot pluck more than {} times simultaneously'.format( - GRPC_MAX_COMPLETION_QUEUE_PLUCKERS)) - self.num_plucking += 1 - with nogil: - event = grpc_completion_queue_pluck( - self.c_completion_queue, <cpython.PyObject *>tag, c_deadline, NULL) - with self.pluck_condition: - self.num_plucking -= 1 + if deadline is not None: + c_deadline = deadline.c_time + + while True: + c_timeout = gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), c_increment) + if gpr_time_cmp(c_timeout, c_deadline) > 0: + c_timeout = c_deadline + event = grpc_completion_queue_next( + self.c_completion_queue, c_timeout, NULL) + if event.type != GRPC_QUEUE_TIMEOUT or gpr_time_cmp(c_timeout, c_deadline) == 0: + break; + + # Handle any signals + with gil: + cpython.PyErr_CheckSignals() return self._interpret_event(event) def shutdown(self): diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi index 05b8886df7..168b9751aa 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi @@ -80,6 +80,12 @@ cdef extern from "grpc/_cython/loader.h": gpr_timespec gpr_convert_clock_type(gpr_timespec t, gpr_clock_type target_clock) nogil + gpr_timespec gpr_time_from_millis(int64_t ms, gpr_clock_type type) nogil + + gpr_timespec gpr_time_add(gpr_timespec a, gpr_timespec b) nogil + + int gpr_time_cmp(gpr_timespec a, gpr_timespec b) nogil + ctypedef enum grpc_status_code: GRPC_STATUS_OK GRPC_STATUS_CANCELLED diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi index c8a73e65d6..42afeb8498 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi @@ -99,7 +99,7 @@ cdef class Server: with nogil: grpc_server_start(self.c_server) # Ensure the core has gotten a chance to do the start-up work - self.backup_shutdown_queue.pluck(None, Timespec(None)) + self.backup_shutdown_queue.poll(Timespec(None)) def add_http2_port(self, address, ServerCredentials server_credentials=None): diff --git a/src/python/grpcio/grpc/_cython/imports.generated.h b/src/python/grpcio/grpc/_cython/imports.generated.h index a14f4f5183..d52e8591b3 100644 --- a/src/python/grpcio/grpc/_cython/imports.generated.h +++ b/src/python/grpcio/grpc/_cython/imports.generated.h @@ -482,7 +482,7 @@ extern grpc_byte_buffer_reader_readall_type grpc_byte_buffer_reader_readall_impo typedef grpc_byte_buffer *(*grpc_raw_byte_buffer_from_reader_type)(grpc_byte_buffer_reader *reader); extern grpc_raw_byte_buffer_from_reader_type grpc_raw_byte_buffer_from_reader_import; #define grpc_raw_byte_buffer_from_reader grpc_raw_byte_buffer_from_reader_import -typedef void(*gpr_log_type)(const char *file, int line, gpr_log_severity severity, const char *format, ...); +typedef void(*gpr_log_type)(const char *file, int line, gpr_log_severity severity, const char *format, ...) GPRC_PRINT_FORMAT_CHECK(4, 5); extern gpr_log_type gpr_log_import; #define gpr_log gpr_log_import typedef void(*gpr_log_message_type)(const char *file, int line, gpr_log_severity severity, const char *message); @@ -827,7 +827,7 @@ extern gpr_format_message_type gpr_format_message_import; typedef char *(*gpr_strdup_type)(const char *src); extern gpr_strdup_type gpr_strdup_import; #define gpr_strdup gpr_strdup_import -typedef int(*gpr_asprintf_type)(char **strp, const char *format, ...); +typedef int(*gpr_asprintf_type)(char **strp, const char *format, ...) GPRC_PRINT_FORMAT_CHECK(2, 3); extern gpr_asprintf_type gpr_asprintf_import; #define gpr_asprintf gpr_asprintf_import typedef const char *(*gpr_subprocess_binary_extension_type)(); diff --git a/src/python/grpcio/grpc/_server.py b/src/python/grpcio/grpc/_server.py index f4f6720497..bf20f15f72 100644 --- a/src/python/grpcio/grpc/_server.py +++ b/src/python/grpcio/grpc/_server.py @@ -60,6 +60,8 @@ _CANCELLED = 'cancelled' _EMPTY_FLAGS = 0 _EMPTY_METADATA = cygrpc.Metadata(()) +_UNEXPECTED_EXIT_SERVER_GRACE = 1.0 + def _serialized_request(request_event): return request_event.batch_operations[0].received_message.bytes() @@ -254,7 +256,7 @@ class _Context(grpc.ServicerContext): else: if self._state.initial_metadata_allowed: operation = cygrpc.operation_send_initial_metadata( - cygrpc.Metadata(initial_metadata), _EMPTY_FLAGS) + _common.metadata(initial_metadata), _EMPTY_FLAGS) self._rpc_event.operation_call.start_batch( cygrpc.Operations((operation,)), _send_initial_metadata(self._state)) @@ -342,10 +344,9 @@ def _unary_request(rpc_event, state, request_deserializer): if state.client is _CLOSED: details = '"{}" requires exactly one request message.'.format( rpc_event.request_call_details.method) - # TODO(5992#issuecomment-220761992): really, what status code? _abort( state, rpc_event.operation_call, - cygrpc.StatusCode.unavailable, details) + cygrpc.StatusCode.unimplemented, details) return None elif state.client is _CANCELLED: return None @@ -670,17 +671,6 @@ def _serve(state): return -def _start(state): - with state.lock: - if state.stage is not _ServerStage.STOPPED: - raise ValueError('Cannot start already-started server!') - state.server.start() - state.stage = _ServerStage.STARTED - _request_call(state) - thread = threading.Thread(target=_serve, args=(state,)) - thread.start() - - def _stop(state, grace): with state.lock: if state.stage is _ServerStage.STOPPED: @@ -719,6 +709,24 @@ def _stop(state, grace): return shutdown_event +def _start(state): + with state.lock: + if state.stage is not _ServerStage.STOPPED: + raise ValueError('Cannot start already-started server!') + state.server.start() + state.stage = _ServerStage.STARTED + _request_call(state) + def cleanup_server(timeout): + if timeout is None: + _stop(state, _UNEXPECTED_EXIT_SERVER_GRACE).wait() + else: + _stop(state, timeout).wait() + + thread = _common.CleanupThread( + cleanup_server, target=_serve, args=(state,)) + thread.start() + + class Server(grpc.Server): def __init__(self, generic_handlers, thread_pool): diff --git a/src/python/grpcio/grpc/_utilities.py b/src/python/grpcio/grpc/_utilities.py index a4ca9b7282..4850967fbc 100644 --- a/src/python/grpcio/grpc/_utilities.py +++ b/src/python/grpcio/grpc/_utilities.py @@ -29,16 +29,41 @@ """Internal utilities for gRPC Python.""" +import collections import threading import time +import six + import grpc +from grpc import _common from grpc.framework.foundation import callable_util _DONE_CALLBACK_EXCEPTION_LOG_MESSAGE = ( 'Exception calling connectivity future "done" callback!') +class RpcMethodHandler( + collections.namedtuple( + '_RpcMethodHandler', + ('request_streaming', 'response_streaming', 'request_deserializer', + 'response_serializer', 'unary_unary', 'unary_stream', 'stream_unary', + 'stream_stream',)), + grpc.RpcMethodHandler): + pass + + +class DictionaryGenericHandler(grpc.GenericRpcHandler): + + def __init__(self, service, method_handlers): + self._method_handlers = { + _common.fully_qualified_method(service, method): method_handler + for method, method_handler in six.iteritems(method_handlers)} + + def service(self, handler_call_details): + return self._method_handlers.get(handler_call_details.method) + + class _ChannelReadyFuture(grpc.Future): def __init__(self, channel): @@ -144,4 +169,3 @@ def channel_ready_future(channel): ready_future = _ChannelReadyFuture(channel) ready_future.start() return ready_future - diff --git a/src/python/grpcio/grpc_core_dependencies.py b/src/python/grpcio/grpc_core_dependencies.py index 0f9212f6a0..839c555f05 100644 --- a/src/python/grpcio/grpc_core_dependencies.py +++ b/src/python/grpcio/grpc_core_dependencies.py @@ -162,6 +162,7 @@ CORE_SOURCE_FILES = [ 'src/core/lib/transport/transport.c', 'src/core/lib/transport/transport_op_string.c', 'src/core/ext/transport/chttp2/server/secure/server_secure_chttp2.c', + 'src/core/ext/transport/chttp2/transport/bin_decoder.c', 'src/core/ext/transport/chttp2/transport/bin_encoder.c', 'src/core/ext/transport/chttp2/transport/chttp2_plugin.c', 'src/core/ext/transport/chttp2/transport/chttp2_transport.c', diff --git a/src/python/grpcio/tests/protoc_plugin/_python_plugin_test.py b/src/python/grpcio/tests/protoc_plugin/_python_plugin_test.py new file mode 100644 index 0000000000..1c9cbb0d0c --- /dev/null +++ b/src/python/grpcio/tests/protoc_plugin/_python_plugin_test.py @@ -0,0 +1,583 @@ +# Copyright 2015, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import collections +from concurrent import futures +import contextlib +import distutils.spawn +import errno +import os +import shutil +import subprocess +import sys +import tempfile +import threading +import unittest + +from six import moves + +import grpc +from tests.unit.framework.common import test_constants + +# Identifiers of entities we expect to find in the generated module. +STUB_IDENTIFIER = 'TestServiceStub' +SERVICER_IDENTIFIER = 'TestServiceServicer' +ADD_SERVICER_TO_SERVER_IDENTIFIER = 'add_TestServiceServicer_to_server' + + +class _ServicerMethods(object): + + def __init__(self, response_pb2, payload_pb2): + self._condition = threading.Condition() + self._paused = False + self._fail = False + self._response_pb2 = response_pb2 + self._payload_pb2 = payload_pb2 + + @contextlib.contextmanager + def pause(self): # pylint: disable=invalid-name + with self._condition: + self._paused = True + yield + with self._condition: + self._paused = False + self._condition.notify_all() + + @contextlib.contextmanager + def fail(self): # pylint: disable=invalid-name + with self._condition: + self._fail = True + yield + with self._condition: + self._fail = False + + def _control(self): # pylint: disable=invalid-name + with self._condition: + if self._fail: + raise ValueError() + while self._paused: + self._condition.wait() + + def UnaryCall(self, request, unused_rpc_context): + response = self._response_pb2.SimpleResponse() + response.payload.payload_type = self._payload_pb2.COMPRESSABLE + response.payload.payload_compressable = 'a' * request.response_size + self._control() + return response + + def StreamingOutputCall(self, request, unused_rpc_context): + for parameter in request.response_parameters: + response = self._response_pb2.StreamingOutputCallResponse() + response.payload.payload_type = self._payload_pb2.COMPRESSABLE + response.payload.payload_compressable = 'a' * parameter.size + self._control() + yield response + + def StreamingInputCall(self, request_iter, unused_rpc_context): + response = self._response_pb2.StreamingInputCallResponse() + aggregated_payload_size = 0 + for request in request_iter: + aggregated_payload_size += len(request.payload.payload_compressable) + response.aggregated_payload_size = aggregated_payload_size + self._control() + return response + + def FullDuplexCall(self, request_iter, unused_rpc_context): + for request in request_iter: + for parameter in request.response_parameters: + response = self._response_pb2.StreamingOutputCallResponse() + response.payload.payload_type = self._payload_pb2.COMPRESSABLE + response.payload.payload_compressable = 'a' * parameter.size + self._control() + yield response + + def HalfDuplexCall(self, request_iter, unused_rpc_context): + responses = [] + for request in request_iter: + for parameter in request.response_parameters: + response = self._response_pb2.StreamingOutputCallResponse() + response.payload.payload_type = self._payload_pb2.COMPRESSABLE + response.payload.payload_compressable = 'a' * parameter.size + self._control() + responses.append(response) + for response in responses: + yield response + + +class _Service( + collections.namedtuple( + '_Service', ('servicer_methods', 'server', 'stub',))): + """A live and running service. + + Attributes: + servicer_methods: The _ServicerMethods servicing RPCs. + server: The grpc.Server servicing RPCs. + stub: A stub on which to invoke RPCs. + """ + + +def _CreateService(service_pb2, response_pb2, payload_pb2): + """Provides a servicer backend and a stub. + + Args: + service_pb2: The service_pb2 module generated by this test. + response_pb2: The response_pb2 module generated by this test. + payload_pb2: The payload_pb2 module generated by this test. + + Returns: + A _Service with which to test RPCs. + """ + servicer_methods = _ServicerMethods(response_pb2, payload_pb2) + + class Servicer(getattr(service_pb2, SERVICER_IDENTIFIER)): + + def UnaryCall(self, request, context): + return servicer_methods.UnaryCall(request, context) + + def StreamingOutputCall(self, request, context): + return servicer_methods.StreamingOutputCall(request, context) + + def StreamingInputCall(self, request_iter, context): + return servicer_methods.StreamingInputCall(request_iter, context) + + def FullDuplexCall(self, request_iter, context): + return servicer_methods.FullDuplexCall(request_iter, context) + + def HalfDuplexCall(self, request_iter, context): + return servicer_methods.HalfDuplexCall(request_iter, context) + + server = grpc.server( + (), futures.ThreadPoolExecutor(max_workers=test_constants.POOL_SIZE)) + getattr(service_pb2, ADD_SERVICER_TO_SERVER_IDENTIFIER)(Servicer(), server) + port = server.add_insecure_port('[::]:0') + server.start() + channel = grpc.insecure_channel('localhost:{}'.format(port)) + stub = getattr(service_pb2, STUB_IDENTIFIER)(channel) + return _Service(servicer_methods, server, stub) + + +def _CreateIncompleteService(service_pb2): + """Provides a servicer backend that fails to implement methods and its stub. + + Args: + service_pb2: The service_pb2 module generated by this test. + + Returns: + A _Service with which to test RPCs. The returned _Service's + servicer_methods implements none of the methods required of it. + """ + + class Servicer(getattr(service_pb2, SERVICER_IDENTIFIER)): + pass + + server = grpc.server( + (), futures.ThreadPoolExecutor(max_workers=test_constants.POOL_SIZE)) + getattr(service_pb2, ADD_SERVICER_TO_SERVER_IDENTIFIER)(Servicer(), server) + port = server.add_insecure_port('[::]:0') + server.start() + channel = grpc.insecure_channel('localhost:{}'.format(port)) + stub = getattr(service_pb2, STUB_IDENTIFIER)(channel) + return _Service(None, server, stub) + + +def _streaming_input_request_iterator(request_pb2, payload_pb2): + for _ in range(3): + request = request_pb2.StreamingInputCallRequest() + request.payload.payload_type = payload_pb2.COMPRESSABLE + request.payload.payload_compressable = 'a' + yield request + + +def _streaming_output_request(request_pb2): + request = request_pb2.StreamingOutputCallRequest() + sizes = [1, 2, 3] + request.response_parameters.add(size=sizes[0], interval_us=0) + request.response_parameters.add(size=sizes[1], interval_us=0) + request.response_parameters.add(size=sizes[2], interval_us=0) + return request + + +def _full_duplex_request_iterator(request_pb2): + request = request_pb2.StreamingOutputCallRequest() + request.response_parameters.add(size=1, interval_us=0) + yield request + request = request_pb2.StreamingOutputCallRequest() + request.response_parameters.add(size=2, interval_us=0) + request.response_parameters.add(size=3, interval_us=0) + yield request + + +class PythonPluginTest(unittest.TestCase): + """Test case for the gRPC Python protoc-plugin. + + While reading these tests, remember that the futures API + (`stub.method.future()`) only gives futures for the *response-unary* + methods and does not exist for response-streaming methods. + """ + + def setUp(self): + # Assume that the appropriate protoc and grpc_python_plugins are on the + # path. + protoc_command = 'protoc' + protoc_plugin_filename = distutils.spawn.find_executable( + 'grpc_python_plugin') + if not os.path.isfile(protoc_command): + # Assume that if we haven't built protoc that it's on the system. + protoc_command = 'protoc' + + # Ensure that the output directory exists. + self.outdir = tempfile.mkdtemp() + + # Find all proto files + paths = [] + root_dir = os.path.dirname(os.path.realpath(__file__)) + proto_dir = os.path.join(root_dir, 'protos') + for walk_root, _, filenames in os.walk(proto_dir): + for filename in filenames: + if filename.endswith('.proto'): + path = os.path.join(walk_root, filename) + paths.append(path) + + # Invoke protoc with the plugin. + cmd = [ + protoc_command, + '--plugin=protoc-gen-python-grpc=%s' % protoc_plugin_filename, + '-I %s' % root_dir, + '--python_out=%s' % self.outdir, + '--python-grpc_out=%s' % self.outdir + ] + paths + subprocess.check_call(' '.join(cmd), shell=True, env=os.environ, + cwd=os.path.dirname(os.path.realpath(__file__))) + + # Generated proto directories dont include __init__.py, but + # these are needed for python package resolution + for walk_root, _, _ in os.walk(os.path.join(self.outdir, 'protos')): + path = os.path.join(walk_root, '__init__.py') + open(path, 'a').close() + + sys.path.insert(0, self.outdir) + + import protos.payload.test_payload_pb2 as payload_pb2 + import protos.requests.r.test_requests_pb2 as request_pb2 + import protos.responses.test_responses_pb2 as response_pb2 + import protos.service.test_service_pb2 as service_pb2 + self._payload_pb2 = payload_pb2 + self._request_pb2 = request_pb2 + self._response_pb2 = response_pb2 + self._service_pb2 = service_pb2 + + def tearDown(self): + try: + shutil.rmtree(self.outdir) + except OSError as exc: + if exc.errno != errno.ENOENT: + raise + sys.path.remove(self.outdir) + + def testImportAttributes(self): + # check that we can access the generated module and its members. + self.assertIsNotNone( + getattr(self._service_pb2, STUB_IDENTIFIER, None)) + self.assertIsNotNone( + getattr(self._service_pb2, SERVICER_IDENTIFIER, None)) + self.assertIsNotNone( + getattr(self._service_pb2, ADD_SERVICER_TO_SERVER_IDENTIFIER, None)) + + def testUpDown(self): + service = _CreateService( + self._service_pb2, self._response_pb2, self._payload_pb2) + self.assertIsNotNone(service.servicer_methods) + self.assertIsNotNone(service.server) + self.assertIsNotNone(service.stub) + + def testIncompleteServicer(self): + service = _CreateIncompleteService(self._service_pb2) + request = self._request_pb2.SimpleRequest(response_size=13) + with self.assertRaises(grpc.RpcError) as exception_context: + service.stub.UnaryCall(request) + self.assertIs( + exception_context.exception.code(), grpc.StatusCode.UNIMPLEMENTED) + + def testUnaryCall(self): + service = _CreateService( + self._service_pb2, self._response_pb2, self._payload_pb2) + request = self._request_pb2.SimpleRequest(response_size=13) + response = service.stub.UnaryCall(request) + expected_response = service.servicer_methods.UnaryCall( + request, 'not a real context!') + self.assertEqual(expected_response, response) + + def testUnaryCallFuture(self): + service = _CreateService( + self._service_pb2, self._response_pb2, self._payload_pb2) + request = self._request_pb2.SimpleRequest(response_size=13) + # Check that the call does not block waiting for the server to respond. + with service.servicer_methods.pause(): + response_future = service.stub.UnaryCall.future(request) + response = response_future.result() + expected_response = service.servicer_methods.UnaryCall( + request, 'not a real RpcContext!') + self.assertEqual(expected_response, response) + + def testUnaryCallFutureExpired(self): + service = _CreateService( + self._service_pb2, self._response_pb2, self._payload_pb2) + request = self._request_pb2.SimpleRequest(response_size=13) + with service.servicer_methods.pause(): + response_future = service.stub.UnaryCall.future( + request, timeout=test_constants.SHORT_TIMEOUT) + with self.assertRaises(grpc.RpcError) as exception_context: + response_future.result() + self.assertIs( + exception_context.exception.code(), grpc.StatusCode.DEADLINE_EXCEEDED) + self.assertIs(response_future.code(), grpc.StatusCode.DEADLINE_EXCEEDED) + + def testUnaryCallFutureCancelled(self): + service = _CreateService( + self._service_pb2, self._response_pb2, self._payload_pb2) + request = self._request_pb2.SimpleRequest(response_size=13) + with service.servicer_methods.pause(): + response_future = service.stub.UnaryCall.future(request) + response_future.cancel() + self.assertTrue(response_future.cancelled()) + self.assertIs(response_future.code(), grpc.StatusCode.CANCELLED) + + def testUnaryCallFutureFailed(self): + service = _CreateService( + self._service_pb2, self._response_pb2, self._payload_pb2) + request = self._request_pb2.SimpleRequest(response_size=13) + with service.servicer_methods.fail(): + response_future = service.stub.UnaryCall.future(request) + self.assertIsNotNone(response_future.exception()) + self.assertIs(response_future.code(), grpc.StatusCode.UNKNOWN) + + def testStreamingOutputCall(self): + service = _CreateService( + self._service_pb2, self._response_pb2, self._payload_pb2) + request = _streaming_output_request(self._request_pb2) + responses = service.stub.StreamingOutputCall(request) + expected_responses = service.servicer_methods.StreamingOutputCall( + request, 'not a real RpcContext!') + for expected_response, response in moves.zip_longest( + expected_responses, responses): + self.assertEqual(expected_response, response) + + def testStreamingOutputCallExpired(self): + service = _CreateService( + self._service_pb2, self._response_pb2, self._payload_pb2) + request = _streaming_output_request(self._request_pb2) + with service.servicer_methods.pause(): + responses = service.stub.StreamingOutputCall( + request, timeout=test_constants.SHORT_TIMEOUT) + with self.assertRaises(grpc.RpcError) as exception_context: + list(responses) + self.assertIs( + exception_context.exception.code(), grpc.StatusCode.DEADLINE_EXCEEDED) + + def testStreamingOutputCallCancelled(self): + service = _CreateService( + self._service_pb2, self._response_pb2, self._payload_pb2) + request = _streaming_output_request(self._request_pb2) + responses = service.stub.StreamingOutputCall(request) + next(responses) + responses.cancel() + with self.assertRaises(grpc.RpcError) as exception_context: + next(responses) + self.assertIs(responses.code(), grpc.StatusCode.CANCELLED) + + def testStreamingOutputCallFailed(self): + service = _CreateService( + self._service_pb2, self._response_pb2, self._payload_pb2) + request = _streaming_output_request(self._request_pb2) + with service.servicer_methods.fail(): + responses = service.stub.StreamingOutputCall(request) + self.assertIsNotNone(responses) + with self.assertRaises(grpc.RpcError) as exception_context: + next(responses) + self.assertIs(exception_context.exception.code(), grpc.StatusCode.UNKNOWN) + + def testStreamingInputCall(self): + service = _CreateService( + self._service_pb2, self._response_pb2, self._payload_pb2) + response = service.stub.StreamingInputCall( + _streaming_input_request_iterator( + self._request_pb2, self._payload_pb2)) + expected_response = service.servicer_methods.StreamingInputCall( + _streaming_input_request_iterator(self._request_pb2, self._payload_pb2), + 'not a real RpcContext!') + self.assertEqual(expected_response, response) + + def testStreamingInputCallFuture(self): + service = _CreateService( + self._service_pb2, self._response_pb2, self._payload_pb2) + with service.servicer_methods.pause(): + response_future = service.stub.StreamingInputCall.future( + _streaming_input_request_iterator( + self._request_pb2, self._payload_pb2)) + response = response_future.result() + expected_response = service.servicer_methods.StreamingInputCall( + _streaming_input_request_iterator(self._request_pb2, self._payload_pb2), + 'not a real RpcContext!') + self.assertEqual(expected_response, response) + + def testStreamingInputCallFutureExpired(self): + service = _CreateService( + self._service_pb2, self._response_pb2, self._payload_pb2) + with service.servicer_methods.pause(): + response_future = service.stub.StreamingInputCall.future( + _streaming_input_request_iterator( + self._request_pb2, self._payload_pb2), + timeout=test_constants.SHORT_TIMEOUT) + with self.assertRaises(grpc.RpcError) as exception_context: + response_future.result() + self.assertIsInstance(response_future.exception(), grpc.RpcError) + self.assertIs( + response_future.exception().code(), grpc.StatusCode.DEADLINE_EXCEEDED) + self.assertIs( + exception_context.exception.code(), grpc.StatusCode.DEADLINE_EXCEEDED) + + def testStreamingInputCallFutureCancelled(self): + service = _CreateService( + self._service_pb2, self._response_pb2, self._payload_pb2) + with service.servicer_methods.pause(): + response_future = service.stub.StreamingInputCall.future( + _streaming_input_request_iterator( + self._request_pb2, self._payload_pb2)) + response_future.cancel() + self.assertTrue(response_future.cancelled()) + with self.assertRaises(grpc.FutureCancelledError): + response_future.result() + + def testStreamingInputCallFutureFailed(self): + service = _CreateService( + self._service_pb2, self._response_pb2, self._payload_pb2) + with service.servicer_methods.fail(): + response_future = service.stub.StreamingInputCall.future( + _streaming_input_request_iterator( + self._request_pb2, self._payload_pb2)) + self.assertIsNotNone(response_future.exception()) + self.assertIs(response_future.code(), grpc.StatusCode.UNKNOWN) + + def testFullDuplexCall(self): + service = _CreateService( + self._service_pb2, self._response_pb2, self._payload_pb2) + responses = service.stub.FullDuplexCall( + _full_duplex_request_iterator(self._request_pb2)) + expected_responses = service.servicer_methods.FullDuplexCall( + _full_duplex_request_iterator(self._request_pb2), + 'not a real RpcContext!') + for expected_response, response in moves.zip_longest( + expected_responses, responses): + self.assertEqual(expected_response, response) + + def testFullDuplexCallExpired(self): + request_iterator = _full_duplex_request_iterator(self._request_pb2) + service = _CreateService( + self._service_pb2, self._response_pb2, self._payload_pb2) + with service.servicer_methods.pause(): + responses = service.stub.FullDuplexCall( + request_iterator, timeout=test_constants.SHORT_TIMEOUT) + with self.assertRaises(grpc.RpcError) as exception_context: + list(responses) + self.assertIs( + exception_context.exception.code(), grpc.StatusCode.DEADLINE_EXCEEDED) + + def testFullDuplexCallCancelled(self): + service = _CreateService( + self._service_pb2, self._response_pb2, self._payload_pb2) + request_iterator = _full_duplex_request_iterator(self._request_pb2) + responses = service.stub.FullDuplexCall(request_iterator) + next(responses) + responses.cancel() + with self.assertRaises(grpc.RpcError) as exception_context: + next(responses) + self.assertIs( + exception_context.exception.code(), grpc.StatusCode.CANCELLED) + + def testFullDuplexCallFailed(self): + request_iterator = _full_duplex_request_iterator(self._request_pb2) + service = _CreateService( + self._service_pb2, self._response_pb2, self._payload_pb2) + with service.servicer_methods.fail(): + responses = service.stub.FullDuplexCall(request_iterator) + with self.assertRaises(grpc.RpcError) as exception_context: + next(responses) + self.assertIs(exception_context.exception.code(), grpc.StatusCode.UNKNOWN) + + def testHalfDuplexCall(self): + service = _CreateService( + self._service_pb2, self._response_pb2, self._payload_pb2) + def half_duplex_request_iterator(): + request = self._request_pb2.StreamingOutputCallRequest() + request.response_parameters.add(size=1, interval_us=0) + yield request + request = self._request_pb2.StreamingOutputCallRequest() + request.response_parameters.add(size=2, interval_us=0) + request.response_parameters.add(size=3, interval_us=0) + yield request + responses = service.stub.HalfDuplexCall(half_duplex_request_iterator()) + expected_responses = service.servicer_methods.HalfDuplexCall( + half_duplex_request_iterator(), 'not a real RpcContext!') + for expected_response, response in moves.zip_longest( + expected_responses, responses): + self.assertEqual(expected_response, response) + + def testHalfDuplexCallWedged(self): + condition = threading.Condition() + wait_cell = [False] + @contextlib.contextmanager + def wait(): # pylint: disable=invalid-name + # Where's Python 3's 'nonlocal' statement when you need it? + with condition: + wait_cell[0] = True + yield + with condition: + wait_cell[0] = False + condition.notify_all() + def half_duplex_request_iterator(): + request = self._request_pb2.StreamingOutputCallRequest() + request.response_parameters.add(size=1, interval_us=0) + yield request + with condition: + while wait_cell[0]: + condition.wait() + service = _CreateService( + self._service_pb2, self._response_pb2, self._payload_pb2) + with wait(): + responses = service.stub.HalfDuplexCall( + half_duplex_request_iterator(), timeout=test_constants.SHORT_TIMEOUT) + # half-duplex waits for the client to send all info + with self.assertRaises(grpc.RpcError) as exception_context: + next(responses) + self.assertIs( + exception_context.exception.code(), grpc.StatusCode.DEADLINE_EXCEEDED) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/src/python/grpcio/tests/tests.json b/src/python/grpcio/tests/tests.json index 53b2998b78..e64d2d04a7 100644 --- a/src/python/grpcio/tests/tests.json +++ b/src/python/grpcio/tests/tests.json @@ -49,10 +49,12 @@ "_low_test.HangingServerShutdown", "_low_test.InsecureServerInsecureClient", "_not_found_test.NotFoundTest", + "_python_plugin_test.PythonPluginTest", "_read_some_but_not_all_responses_test.ReadSomeButNotAllResponsesTest", "_rpc_test.RPCTest", "_sanity_test.Sanity", "_secure_interop_test.SecureInteropTest", + "_thread_cleanup_test.CleanupThreadTest", "_transmission_test.RoundTripTest", "_transmission_test.TransmissionTest", "_utilities_test.ChannelConnectivityTest", diff --git a/src/python/grpcio/tests/unit/_rpc_test.py b/src/python/grpcio/tests/unit/_rpc_test.py index b33bff490c..8407593c86 100644 --- a/src/python/grpcio/tests/unit/_rpc_test.py +++ b/src/python/grpcio/tests/unit/_rpc_test.py @@ -29,8 +29,6 @@ """Test of gRPC Python's application-layer API.""" -from __future__ import division - import itertools import threading import unittest @@ -193,13 +191,6 @@ class RPCTest(unittest.TestCase): self._channel = grpc.insecure_channel('localhost:%d' % port) - # TODO(nathaniel): Why is this necessary, and only in some development - # environments? - def tearDown(self): - del self._channel - del self._server - del self._server_pool - def testUnrecognizedMethod(self): request = b'abc' diff --git a/src/python/grpcio/tests/unit/_thread_cleanup_test.py b/src/python/grpcio/tests/unit/_thread_cleanup_test.py new file mode 100644 index 0000000000..3e4f317edc --- /dev/null +++ b/src/python/grpcio/tests/unit/_thread_cleanup_test.py @@ -0,0 +1,117 @@ +# Copyright 2016, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +"""Tests for CleanupThread.""" + +import threading +import time +import unittest + +from grpc import _common + +_SHORT_TIME = 0.5 +_LONG_TIME = 2.0 +_EPSILON = 0.1 + + +def cleanup(timeout): + if timeout is not None: + time.sleep(timeout) + else: + time.sleep(_LONG_TIME) + + +def slow_cleanup(timeout): + # Don't respect timeout + time.sleep(_LONG_TIME) + + +class CleanupThreadTest(unittest.TestCase): + + def testTargetInvocation(self): + event = threading.Event() + def target(arg1, arg2, arg3=None): + self.assertEqual('arg1', arg1) + self.assertEqual('arg2', arg2) + self.assertEqual('arg3', arg3) + event.set() + + cleanup_thread = _common.CleanupThread(behavior=lambda x: None, + target=target, name='test-name', + args=('arg1', 'arg2'), kwargs={'arg3': 'arg3'}) + cleanup_thread.start() + cleanup_thread.join() + self.assertEqual(cleanup_thread.name, 'test-name') + self.assertTrue(event.is_set()) + + def testJoinNoTimeout(self): + cleanup_thread = _common.CleanupThread(behavior=cleanup) + cleanup_thread.start() + start_time = time.time() + cleanup_thread.join() + end_time = time.time() + self.assertAlmostEqual(_LONG_TIME, end_time - start_time, delta=_EPSILON) + + def testJoinTimeout(self): + cleanup_thread = _common.CleanupThread(behavior=cleanup) + cleanup_thread.start() + start_time = time.time() + cleanup_thread.join(_SHORT_TIME) + end_time = time.time() + self.assertAlmostEqual(_SHORT_TIME, end_time - start_time, delta=_EPSILON) + + def testJoinTimeoutSlowBehavior(self): + cleanup_thread = _common.CleanupThread(behavior=slow_cleanup) + cleanup_thread.start() + start_time = time.time() + cleanup_thread.join(_SHORT_TIME) + end_time = time.time() + self.assertAlmostEqual(_LONG_TIME, end_time - start_time, delta=_EPSILON) + + def testJoinTimeoutSlowTarget(self): + event = threading.Event() + def target(): + event.wait(_LONG_TIME) + cleanup_thread = _common.CleanupThread(behavior=cleanup, target=target) + cleanup_thread.start() + start_time = time.time() + cleanup_thread.join(_SHORT_TIME) + end_time = time.time() + self.assertAlmostEqual(_SHORT_TIME, end_time - start_time, delta=_EPSILON) + event.set() + + def testJoinZeroTimeout(self): + cleanup_thread = _common.CleanupThread(behavior=cleanup) + cleanup_thread.start() + start_time = time.time() + cleanup_thread.join(0) + end_time = time.time() + self.assertAlmostEqual(0, end_time - start_time, delta=_EPSILON) + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/src/python/grpcio/tests/unit/beta/test_utilities.py b/src/python/grpcio/tests/unit/beta/test_utilities.py index 66b5f72897..e374b203ce 100644 --- a/src/python/grpcio/tests/unit/beta/test_utilities.py +++ b/src/python/grpcio/tests/unit/beta/test_utilities.py @@ -50,6 +50,6 @@ def not_really_secure_channel( """ target = '%s:%d' % (host, port) channel = grpc.secure_channel( - target, ((b'grpc.ssl_target_name_override', server_host_override,),), - channel_credentials._credentials) + target, channel_credentials._credentials, + ((b'grpc.ssl_target_name_override', server_host_override,),)) return implementations.Channel(channel) |