diff options
Diffstat (limited to 'src/python/grpcio')
35 files changed, 1406 insertions, 429 deletions
diff --git a/src/python/grpcio/README.rst b/src/python/grpcio/README.rst index afc4fe6a37..3fc318539e 100644 --- a/src/python/grpcio/README.rst +++ b/src/python/grpcio/README.rst @@ -46,7 +46,7 @@ package named :code:`python-dev`). :: $ export REPO_ROOT=grpc # REPO_ROOT can be any directory of your choice - $ git clone https://github.com/grpc/grpc.git $REPO_ROOT + $ git clone -b $(curl -L http://grpc.io/release) https://github.com/grpc/grpc $REPO_ROOT $ cd $REPO_ROOT $ git submodule update --init diff --git a/src/python/grpcio/grpc/_channel.py b/src/python/grpcio/grpc/_channel.py index 1f34beeb2c..a89b501303 100644 --- a/src/python/grpcio/grpc/_channel.py +++ b/src/python/grpcio/grpc/_channel.py @@ -179,6 +179,7 @@ def _event_handler(state, call, response_deserializer): def _consume_request_iterator( request_iterator, state, call, request_serializer): event_handler = _event_handler(state, call, None) + def consume_request_iterator(): for request in request_iterator: serialized_request = _common.serialize(request, request_serializer) @@ -212,8 +213,18 @@ def _consume_request_iterator( ) call.start_batch(cygrpc.Operations(operations), event_handler) state.due.add(cygrpc.OperationType.send_close_from_client) - thread = threading.Thread(target=consume_request_iterator) - thread.start() + + def stop_consumption_thread(timeout): + with state.condition: + if state.code is None: + call.cancel() + state.cancelled = True + _abort(state, grpc.StatusCode.CANCELLED, 'Cancelled!') + state.condition.notify_all() + + consumption_thread = _common.CleanupThread( + stop_consumption_thread, target=consume_request_iterator) + consumption_thread.start() class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call): @@ -353,13 +364,13 @@ class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call): with self._state.condition: while self._state.initial_metadata is None: self._state.condition.wait() - return self._state.initial_metadata + return _common.application_metadata(self._state.initial_metadata) def trailing_metadata(self): with self._state.condition: while self._state.trailing_metadata is None: self._state.condition.wait() - return self._state.trailing_metadata + return _common.application_metadata(self._state.trailing_metadata) def code(self): with self._state.condition: @@ -371,7 +382,7 @@ class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call): with self._state.condition: while self._state.details is None: self._state.condition.wait() - return self._state.details + return _common.decode(self._state.details) def _repr(self): with self._state.condition: @@ -379,7 +390,7 @@ class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call): return '<_Rendezvous object of in-flight RPC>' else: return '<_Rendezvous of RPC that terminated with ({}, {})>'.format( - self._state.code, self._state.details) + self._state.code, _common.decode(self._state.details)) def __repr__(self): return self._repr() @@ -440,7 +451,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): state = _RPCState(_UNARY_UNARY_INITIAL_DUE, None, None, None, None) operations = ( cygrpc.operation_send_initial_metadata( - _common.metadata(metadata), _EMPTY_FLAGS), + _common.cygrpc_metadata(metadata), _EMPTY_FLAGS), cygrpc.operation_send_message(serialized_request, _EMPTY_FLAGS), cygrpc.operation_send_close_from_client(_EMPTY_FLAGS), cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS), @@ -518,7 +529,7 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): event_handler) operations = ( cygrpc.operation_send_initial_metadata( - _common.metadata(metadata), _EMPTY_FLAGS), + _common.cygrpc_metadata(metadata), _EMPTY_FLAGS), cygrpc.operation_send_message(serialized_request, _EMPTY_FLAGS), cygrpc.operation_send_close_from_client(_EMPTY_FLAGS), cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS), @@ -553,7 +564,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): None) operations = ( cygrpc.operation_send_initial_metadata( - _common.metadata(metadata), _EMPTY_FLAGS), + _common.cygrpc_metadata(metadata), _EMPTY_FLAGS), cygrpc.operation_receive_message(_EMPTY_FLAGS), cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS), ) @@ -597,7 +608,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): event_handler) operations = ( cygrpc.operation_send_initial_metadata( - _common.metadata(metadata), _EMPTY_FLAGS), + _common.cygrpc_metadata(metadata), _EMPTY_FLAGS), cygrpc.operation_receive_message(_EMPTY_FLAGS), cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS), ) @@ -634,7 +645,7 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): event_handler) operations = ( cygrpc.operation_send_initial_metadata( - _common.metadata(metadata), _EMPTY_FLAGS), + _common.cygrpc_metadata(metadata), _EMPTY_FLAGS), cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS), ) call.start_batch(cygrpc.Operations(operations), event_handler) @@ -652,16 +663,27 @@ class _ChannelCallState(object): self.managed_calls = None -def _call_spin(state): - while True: - event = state.completion_queue.poll() - completed_call = event.tag(event) - if completed_call is not None: - with state.lock: - state.managed_calls.remove(completed_call) - if not state.managed_calls: - state.managed_calls = None - return +def _run_channel_spin_thread(state): + def channel_spin(): + while True: + event = state.completion_queue.poll() + completed_call = event.tag(event) + if completed_call is not None: + with state.lock: + state.managed_calls.remove(completed_call) + if not state.managed_calls: + state.managed_calls = None + return + + def stop_channel_spin(timeout): + with state.lock: + if state.managed_calls is not None: + for call in state.managed_calls: + call.cancel() + + channel_spin_thread = _common.CleanupThread( + stop_channel_spin, target=channel_spin) + channel_spin_thread.start() def _create_channel_managed_call(state): @@ -690,8 +712,7 @@ def _create_channel_managed_call(state): parent, flags, state.completion_queue, method, host, deadline) if state.managed_calls is None: state.managed_calls = set((call,)) - spin_thread = threading.Thread(target=_call_spin, args=(state,)) - spin_thread.start() + _run_channel_spin_thread(state) else: state.managed_calls.add(call) return call @@ -784,11 +805,18 @@ def _poll_connectivity(state, channel, initial_try_to_connect): _spawn_delivery(state, callbacks) +def _moot(state): + with state.lock: + del state.callbacks_and_connectivities[:] + + def _subscribe(state, callback, try_to_connect): with state.lock: if not state.callbacks_and_connectivities and not state.polling: - polling_thread = threading.Thread( - target=_poll_connectivity, + def cancel_all_subscriptions(timeout): + _moot(state) + polling_thread = _common.CleanupThread( + cancel_all_subscriptions, target=_poll_connectivity, args=(state, state.channel, bool(try_to_connect))) polling_thread.start() state.polling = True @@ -812,19 +840,19 @@ def _unsubscribe(state, callback): break -def _moot(state): - with state.lock: - del state.callbacks_and_connectivities[:] - - def _options(options): if options is None: pairs = ((cygrpc.ChannelArgKey.primary_user_agent_string, _USER_AGENT),) else: pairs = list(options) + [ (cygrpc.ChannelArgKey.primary_user_agent_string, _USER_AGENT)] - return cygrpc.ChannelArgs( - cygrpc.ChannelArg(arg_name, arg_value) for arg_name, arg_value in pairs) + encoded_pairs = [ + (_common.encode(arg_name), arg_value) if isinstance(arg_value, int) + else (_common.encode(arg_name), _common.encode(arg_value)) + for arg_name, arg_value in pairs] + return cygrpc.ChannelArgs([ + cygrpc.ChannelArg(arg_name, arg_value) + for arg_name, arg_value in encoded_pairs]) class Channel(grpc.Channel): @@ -837,7 +865,8 @@ class Channel(grpc.Channel): options: Configuration options for the channel. credentials: A cygrpc.ChannelCredentials or None. """ - self._channel = cygrpc.Channel(target, _options(options), credentials) + self._channel = cygrpc.Channel( + _common.encode(target), _options(options), credentials) self._call_state = _ChannelCallState(self._channel) self._connectivity_state = _ChannelConnectivityState(self._channel) @@ -850,26 +879,26 @@ class Channel(grpc.Channel): def unary_unary( self, method, request_serializer=None, response_deserializer=None): return _UnaryUnaryMultiCallable( - self._channel, _create_channel_managed_call(self._call_state), method, - request_serializer, response_deserializer) + self._channel, _create_channel_managed_call(self._call_state), + _common.encode(method), request_serializer, response_deserializer) def unary_stream( self, method, request_serializer=None, response_deserializer=None): return _UnaryStreamMultiCallable( - self._channel, _create_channel_managed_call(self._call_state), method, - request_serializer, response_deserializer) + self._channel, _create_channel_managed_call(self._call_state), + _common.encode(method), request_serializer, response_deserializer) def stream_unary( self, method, request_serializer=None, response_deserializer=None): return _StreamUnaryMultiCallable( - self._channel, _create_channel_managed_call(self._call_state), method, - request_serializer, response_deserializer) + self._channel, _create_channel_managed_call(self._call_state), + _common.encode(method), request_serializer, response_deserializer) def stream_stream( self, method, request_serializer=None, response_deserializer=None): return _StreamStreamMultiCallable( - self._channel, _create_channel_managed_call(self._call_state), method, - request_serializer, response_deserializer) + self._channel, _create_channel_managed_call(self._call_state), + _common.encode(method), request_serializer, response_deserializer) def __del__(self): _moot(self._connectivity_state) diff --git a/src/python/grpcio/grpc/_common.py b/src/python/grpcio/grpc/_common.py index f351bea9e3..4d7d521419 100644 --- a/src/python/grpcio/grpc/_common.py +++ b/src/python/grpcio/grpc/_common.py @@ -76,9 +76,37 @@ STATUS_CODE_TO_CYGRPC_STATUS_CODE = { } -def metadata(application_metadata): +def encode(s): + if isinstance(s, bytes): + return s + else: + return s.encode('ascii') + + +def decode(b): + if isinstance(b, str): + return b + else: + try: + return b.decode('utf8') + except UnicodeDecodeError: + logging.exception('Invalid encoding on {}'.format(b)) + return b.decode('latin1') + + +def cygrpc_metadata(application_metadata): return _EMPTY_METADATA if application_metadata is None else cygrpc.Metadata( - cygrpc.Metadatum(key, value) for key, value in application_metadata) + cygrpc.Metadatum(encode(key), encode(value)) + for key, value in application_metadata) + + +def application_metadata(cygrpc_metadata): + if cygrpc_metadata is None: + return () + else: + return tuple( + (decode(key), value if key[-4:] == b'-bin' else decode(value)) + for key, value in cygrpc_metadata) def _transform(message, transformer, exception_message): @@ -101,17 +129,8 @@ def deserialize(serialized_message, deserializer): 'Exception deserializing message!') -def _encode(s): - if isinstance(s, bytes): - return s - else: - return s.encode('ascii') - - def fully_qualified_method(group, method): - group = _encode(group) - method = _encode(method) - return b'/' + group + b'/' + method + return '/{}/{}'.format(group, method) class CleanupThread(threading.Thread): diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi index 866cff0d01..1406696510 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi @@ -32,9 +32,8 @@ cimport cpython cdef class Channel: - def __cinit__(self, target, ChannelArgs arguments=None, + def __cinit__(self, bytes target, ChannelArgs arguments=None, ChannelCredentials channel_credentials=None): - target = str_to_bytes(target) cdef grpc_channel_args *c_arguments = NULL cdef char *c_target = NULL self.c_channel = NULL @@ -57,8 +56,6 @@ cdef class Channel: def create_call(self, Call parent, int flags, CompletionQueue queue not None, method, host, Timespec deadline not None): - method = str_to_bytes(method) - host = str_to_bytes(host) if queue.is_shutting_down: raise ValueError("queue must not be shutting down or shutdown") cdef char *method_c_string = method diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi index 470382d609..b24e69243e 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi @@ -82,7 +82,7 @@ cdef class ServerCredentials: cdef class CredentialsMetadataPlugin: - def __cinit__(self, object plugin_callback, name): + def __cinit__(self, object plugin_callback, bytes name): """ Args: plugin_callback (callable): Callback accepting a service URL (str/bytes) @@ -91,9 +91,8 @@ cdef class CredentialsMetadataPlugin: when called should be non-blocking and eventually call the callback object with the appropriate status code/details and metadata (if successful). - name (str): Plugin name. + name (bytes): Plugin name. """ - name = str_to_bytes(name) if not callable(plugin_callback): raise ValueError('expected callable plugin_callback') self.plugin_callback = plugin_callback @@ -130,8 +129,7 @@ cdef void plugin_get_metadata( grpc_credentials_plugin_metadata_cb cb, void *user_data) with gil: def python_callback( Metadata metadata, grpc_status_code status, - error_details): - error_details = str_to_bytes(error_details) + bytes error_details): cb(user_data, metadata.c_metadata_array.metadata, metadata.c_metadata_array.count, status, error_details) cdef CredentialsMetadataPlugin self = <CredentialsMetadataPlugin>state diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi index 168b9751aa..f3b3d61273 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi @@ -37,6 +37,7 @@ cdef extern from "grpc/_cython/loader.h": ctypedef long int64_t int pygrpc_load_core(char*) + int pygrpc_initialize_core() void *gpr_malloc(size_t size) nogil void gpr_free(void *ptr) nogil diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi index 0055d0d3a2..8e651e880f 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi @@ -231,17 +231,10 @@ cdef class Event: cdef class ByteBuffer: - def __cinit__(self, data): + def __cinit__(self, bytes data): if data is None: self.c_byte_buffer = NULL return - if isinstance(data, ByteBuffer): - data = (<ByteBuffer>data).bytes() - if data is None: - self.c_byte_buffer = NULL - return - else: - data = str_to_bytes(data) cdef char *c_data = data cdef gpr_slice data_slice @@ -296,26 +289,28 @@ cdef class ByteBuffer: cdef class SslPemKeyCertPair: - def __cinit__(self, private_key, certificate_chain): - self.private_key = str_to_bytes(private_key) - self.certificate_chain = str_to_bytes(certificate_chain) + def __cinit__(self, bytes private_key, bytes certificate_chain): + self.private_key = private_key + self.certificate_chain = certificate_chain self.c_pair.private_key = self.private_key self.c_pair.certificate_chain = self.certificate_chain cdef class ChannelArg: - def __cinit__(self, key, value): - self.key = str_to_bytes(key) + def __cinit__(self, bytes key, value): + self.key = key self.c_arg.key = self.key if isinstance(value, int): - self.value = int(value) + self.value = value self.c_arg.type = GRPC_ARG_INTEGER self.c_arg.value.integer = self.value - else: - self.value = str_to_bytes(value) + elif isinstance(value, bytes): + self.value = value self.c_arg.type = GRPC_ARG_STRING self.c_arg.value.string = self.value + else: + raise TypeError('Expected int or bytes, got {}'.format(type(value))) cdef class ChannelArgs: @@ -347,9 +342,9 @@ cdef class ChannelArgs: cdef class Metadatum: - def __cinit__(self, key, value): - self._key = str_to_bytes(key) - self._value = str_to_bytes(value) + def __cinit__(self, bytes key, bytes value): + self._key = key + self._value = value self.c_metadata.key = self._key self.c_metadata.value = self._value self.c_metadata.value_length = len(self._value) @@ -563,8 +558,7 @@ def operation_send_close_from_client(int flags): return op def operation_send_status_from_server( - Metadata metadata, grpc_status_code code, details, int flags): - details = str_to_bytes(details) + Metadata metadata, grpc_status_code code, bytes details, int flags): cdef Operation op = Operation() op.c_op.type = GRPC_OP_SEND_STATUS_FROM_SERVER op.c_op.flags = flags diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi index 42afeb8498..3e03b6efe1 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi @@ -101,7 +101,7 @@ cdef class Server: # Ensure the core has gotten a chance to do the start-up work self.backup_shutdown_queue.poll(Timespec(None)) - def add_http2_port(self, address, + def add_http2_port(self, bytes address, ServerCredentials server_credentials=None): address = str_to_bytes(address) self.references.append(address) diff --git a/src/python/grpcio/grpc/_cython/cygrpc.pyx b/src/python/grpcio/grpc/_cython/cygrpc.pyx index cf146f5a04..7a8d0dd8a1 100644 --- a/src/python/grpcio/grpc/_cython/cygrpc.pyx +++ b/src/python/grpcio/grpc/_cython/cygrpc.pyx @@ -45,30 +45,22 @@ include "grpc/_cython/_cygrpc/security.pyx.pxi" include "grpc/_cython/_cygrpc/server.pyx.pxi" # -# Global state +# initialize gRPC # -cdef class _ModuleState: - cdef bint is_loaded +def _initialize(): + if 'win32' in sys.platform: + filename = pkg_resources.resource_filename( + 'grpc._cython', '_windows/grpc_c.64.python') + if not isinstance(filename, bytes): + filename = filename.encode() + if not pygrpc_load_core(filename): + raise ImportError('failed to load core gRPC library') + if not pygrpc_initialize_core(): + raise ImportError('failed to initialize core gRPC library') - def __cinit__(self): - if 'win32' in sys.platform: - filename = pkg_resources.resource_filename( - 'grpc._cython', '_windows/grpc_c.64.python') - if not pygrpc_load_core(filename): - raise ImportError('failed to load core gRPC library') - with nogil: - grpc_init() - self.is_loaded = True - with nogil: - grpc_set_ssl_roots_override_callback( + grpc_set_ssl_roots_override_callback( <grpc_ssl_roots_override_callback>ssl_roots_override_callback) - def __dealloc__(self): - if self.is_loaded: - with nogil: - grpc_shutdown() - -_module_state = _ModuleState() - +_initialize() diff --git a/src/python/grpcio/grpc/_cython/imports.generated.c b/src/python/grpcio/grpc/_cython/imports.generated.c index 8437e74ba0..d78ec2f66e 100644 --- a/src/python/grpcio/grpc/_cython/imports.generated.c +++ b/src/python/grpcio/grpc/_cython/imports.generated.c @@ -128,6 +128,7 @@ grpc_is_binary_header_type grpc_is_binary_header_import; grpc_call_error_to_string_type grpc_call_error_to_string_import; grpc_insecure_channel_create_from_fd_type grpc_insecure_channel_create_from_fd_import; grpc_server_add_insecure_channel_from_fd_type grpc_server_add_insecure_channel_from_fd_import; +grpc_use_signal_type grpc_use_signal_import; grpc_auth_property_iterator_next_type grpc_auth_property_iterator_next_import; grpc_auth_context_property_iterator_type grpc_auth_context_property_iterator_import; grpc_auth_context_peer_identity_type grpc_auth_context_peer_identity_import; @@ -403,6 +404,7 @@ void pygrpc_load_imports(HMODULE library) { grpc_call_error_to_string_import = (grpc_call_error_to_string_type) GetProcAddress(library, "grpc_call_error_to_string"); grpc_insecure_channel_create_from_fd_import = (grpc_insecure_channel_create_from_fd_type) GetProcAddress(library, "grpc_insecure_channel_create_from_fd"); grpc_server_add_insecure_channel_from_fd_import = (grpc_server_add_insecure_channel_from_fd_type) GetProcAddress(library, "grpc_server_add_insecure_channel_from_fd"); + grpc_use_signal_import = (grpc_use_signal_type) GetProcAddress(library, "grpc_use_signal"); grpc_auth_property_iterator_next_import = (grpc_auth_property_iterator_next_type) GetProcAddress(library, "grpc_auth_property_iterator_next"); grpc_auth_context_property_iterator_import = (grpc_auth_context_property_iterator_type) GetProcAddress(library, "grpc_auth_context_property_iterator"); grpc_auth_context_peer_identity_import = (grpc_auth_context_peer_identity_type) GetProcAddress(library, "grpc_auth_context_peer_identity"); diff --git a/src/python/grpcio/grpc/_cython/imports.generated.h b/src/python/grpcio/grpc/_cython/imports.generated.h index d52e8591b3..b3e341fe25 100644 --- a/src/python/grpcio/grpc/_cython/imports.generated.h +++ b/src/python/grpcio/grpc/_cython/imports.generated.h @@ -335,6 +335,9 @@ extern grpc_insecure_channel_create_from_fd_type grpc_insecure_channel_create_fr typedef void(*grpc_server_add_insecure_channel_from_fd_type)(grpc_server *server, grpc_completion_queue *cq, int fd); extern grpc_server_add_insecure_channel_from_fd_type grpc_server_add_insecure_channel_from_fd_import; #define grpc_server_add_insecure_channel_from_fd grpc_server_add_insecure_channel_from_fd_import +typedef void(*grpc_use_signal_type)(int signum); +extern grpc_use_signal_type grpc_use_signal_import; +#define grpc_use_signal grpc_use_signal_import typedef const grpc_auth_property *(*grpc_auth_property_iterator_next_type)(grpc_auth_property_iterator *it); extern grpc_auth_property_iterator_next_type grpc_auth_property_iterator_next_import; #define grpc_auth_property_iterator_next grpc_auth_property_iterator_next_import diff --git a/src/python/grpcio/grpc/_cython/loader.c b/src/python/grpcio/grpc/_cython/loader.c index b909ad594e..86b70dbb02 100644 --- a/src/python/grpcio/grpc/_cython/loader.c +++ b/src/python/grpcio/grpc/_cython/loader.c @@ -31,6 +31,7 @@ * */ +#include <Python.h> #include "loader.h" #ifdef __cplusplus @@ -62,6 +63,12 @@ int pygrpc_load_core(char *path) { return 1; } #endif /* !GPR_WINDOWS */ +// Cython doesn't have Py_AtExit bindings, so we call the C_API directly +int pygrpc_initialize_core(void) { + grpc_init(); + return Py_AtExit(grpc_shutdown) < 0 ? 0 : 1; +} + #ifdef __cplusplus } #endif /* __cpluslus */ diff --git a/src/python/grpcio/grpc/_cython/loader.h b/src/python/grpcio/grpc/_cython/loader.h index 3b8796d39f..eb4b1a1b01 100644 --- a/src/python/grpcio/grpc/_cython/loader.h +++ b/src/python/grpcio/grpc/_cython/loader.h @@ -46,6 +46,11 @@ extern "C" { /* Attempts to load the core if necessary, and return non-zero upon succes. */ int pygrpc_load_core(char *path); +/* Initializes grpc and registers grpc_shutdown() to be called right before + * interpreter exit. Returns non-zero upon success. + */ +int pygrpc_initialize_core(void); + #ifdef __cplusplus } #endif /* __cpluslus */ diff --git a/src/python/grpcio/grpc/_plugin_wrapping.py b/src/python/grpcio/grpc/_plugin_wrapping.py index 4e9cfe710c..7cb5218c22 100644 --- a/src/python/grpcio/grpc/_plugin_wrapping.py +++ b/src/python/grpcio/grpc/_plugin_wrapping.py @@ -31,6 +31,7 @@ import collections import threading import grpc +from grpc import _common from grpc._cython import cygrpc @@ -62,17 +63,16 @@ class _WrappedCygrpcCallback(object): # TODO(atash) translate different Exception superclasses into different # status codes. self.cygrpc_callback( - cygrpc.Metadata([]), cygrpc.StatusCode.internal, error.message) + _common.EMPTY_METADATA, cygrpc.StatusCode.internal, + _common.encode(str(error))) def _invoke_success(self, metadata): try: - cygrpc_metadata = cygrpc.Metadata( - cygrpc.Metadatum(key, value) - for key, value in metadata) + cygrpc_metadata = _common.cygrpc_metadata(metadata) except Exception as error: self._invoke_failure(error) return - self.cygrpc_callback(cygrpc_metadata, cygrpc.StatusCode.ok, '') + self.cygrpc_callback(cygrpc_metadata, cygrpc.StatusCode.ok, b'') def __call__(self, metadata, error): with self.is_called_lock: @@ -101,7 +101,7 @@ class _WrappedPlugin(object): def __call__(self, context, cygrpc_callback): wrapped_cygrpc_callback = _WrappedCygrpcCallback(cygrpc_callback) wrapped_context = AuthMetadataContext( - context.service_url, context.method_name) + _common.decode(context.service_url), _common.decode(context.method_name)) try: self.plugin( wrapped_context, AuthMetadataPluginCallback(wrapped_cygrpc_callback)) @@ -120,4 +120,4 @@ def call_credentials_metadata_plugin(plugin, name): plugin's invocation must be non-blocking. """ return cygrpc.call_credentials_metadata_plugin( - cygrpc.CredentialsMetadataPlugin(_WrappedPlugin(plugin), name)) + cygrpc.CredentialsMetadataPlugin(_WrappedPlugin(plugin), _common.encode(name))) diff --git a/src/python/grpcio/grpc/_server.py b/src/python/grpcio/grpc/_server.py index bf20f15f72..f4c114056f 100644 --- a/src/python/grpcio/grpc/_server.py +++ b/src/python/grpcio/grpc/_server.py @@ -87,7 +87,7 @@ def _abortion_code(state, code): def _details(state): - return '' if state.details is None else state.details + return b'' if state.details is None else state.details class _HandlerCallDetails( @@ -146,14 +146,14 @@ def _abort(state, call, code, details): cygrpc.operation_send_initial_metadata( _EMPTY_METADATA, _EMPTY_FLAGS), cygrpc.operation_send_status_from_server( - _common.metadata(state.trailing_metadata), effective_code, + _common.cygrpc_metadata(state.trailing_metadata), effective_code, effective_details, _EMPTY_FLAGS), ) token = _SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN else: operations = ( cygrpc.operation_send_status_from_server( - _common.metadata(state.trailing_metadata), effective_code, + _common.cygrpc_metadata(state.trailing_metadata), effective_code, effective_details, _EMPTY_FLAGS), ) token = _SEND_STATUS_FROM_SERVER_TOKEN @@ -191,7 +191,7 @@ def _receive_message(state, call, request_deserializer): if request is None: _abort( state, call, cygrpc.StatusCode.internal, - 'Exception deserializing request!') + b'Exception deserializing request!') else: state.request = request state.condition.notify_all() @@ -244,10 +244,10 @@ class _Context(grpc.ServicerContext): self._state.disable_next_compression = True def invocation_metadata(self): - return self._rpc_event.request_metadata + return _common.application_metadata(self._rpc_event.request_metadata) def peer(self): - return self._rpc_event.operation_call.peer() + return _common.decode(self._rpc_event.operation_call.peer()) def send_initial_metadata(self, initial_metadata): with self._state.condition: @@ -256,7 +256,7 @@ class _Context(grpc.ServicerContext): else: if self._state.initial_metadata_allowed: operation = cygrpc.operation_send_initial_metadata( - _common.metadata(initial_metadata), _EMPTY_FLAGS) + _common.cygrpc_metadata(initial_metadata), _EMPTY_FLAGS) self._rpc_event.operation_call.start_batch( cygrpc.Operations((operation,)), _send_initial_metadata(self._state)) @@ -267,7 +267,8 @@ class _Context(grpc.ServicerContext): def set_trailing_metadata(self, trailing_metadata): with self._state.condition: - self._state.trailing_metadata = trailing_metadata + self._state.trailing_metadata = _common.cygrpc_metadata( + trailing_metadata) def set_code(self, code): with self._state.condition: @@ -275,7 +276,7 @@ class _Context(grpc.ServicerContext): def set_details(self, details): with self._state.condition: - self._state.details = details + self._state.details = _common.encode(details) class _RequestIterator(object): @@ -346,7 +347,7 @@ def _unary_request(rpc_event, state, request_deserializer): rpc_event.request_call_details.method) _abort( state, rpc_event.operation_call, - cygrpc.StatusCode.unimplemented, details) + cygrpc.StatusCode.unimplemented, _common.encode(details)) return None elif state.client is _CANCELLED: return None @@ -366,8 +367,8 @@ def _call_behavior(rpc_event, state, behavior, argument, request_deserializer): if e not in state.rpc_errors: details = 'Exception calling application: {}'.format(e) logging.exception(details) - _abort( - state, rpc_event.operation_call, cygrpc.StatusCode.unknown, details) + _abort(state, rpc_event.operation_call, + cygrpc.StatusCode.unknown, _common.encode(details)) return None, False @@ -381,8 +382,8 @@ def _take_response_from_response_iterator(rpc_event, state, response_iterator): if e not in state.rpc_errors: details = 'Exception iterating responses: {}'.format(e) logging.exception(details) - _abort( - state, rpc_event.operation_call, cygrpc.StatusCode.unknown, details) + _abort(state, rpc_event.operation_call, + cygrpc.StatusCode.unknown, _common.encode(details)) return None, False @@ -392,7 +393,7 @@ def _serialize_response(rpc_event, state, response, response_serializer): with state.condition: _abort( state, rpc_event.operation_call, cygrpc.StatusCode.internal, - 'Failed to serialize response!') + b'Failed to serialize response!') return None else: return serialized_response @@ -428,7 +429,7 @@ def _send_response(rpc_event, state, serialized_response): def _status(rpc_event, state, serialized_response): with state.condition: if state.client is not _CANCELLED: - trailing_metadata = _common.metadata(state.trailing_metadata) + trailing_metadata = _common.cygrpc_metadata(state.trailing_metadata) code = _completion_code(state) details = _details(state) operations = [ @@ -532,7 +533,8 @@ def _find_method_handler(rpc_event, generic_handlers): for generic_handler in generic_handlers: method_handler = generic_handler.service( _HandlerCallDetails( - rpc_event.request_call_details.method, rpc_event.request_metadata)) + _common.decode(rpc_event.request_call_details.method), + rpc_event.request_metadata)) if method_handler is not None: return method_handler else: @@ -545,7 +547,7 @@ def _handle_unrecognized_method(rpc_event): cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS), cygrpc.operation_send_status_from_server( _EMPTY_METADATA, cygrpc.StatusCode.unimplemented, - 'Method not found!', _EMPTY_FLAGS), + b'Method not found!', _EMPTY_FLAGS), ) rpc_state = _RPCState() rpc_event.operation_call.start_batch( @@ -740,10 +742,10 @@ class Server(grpc.Server): _add_generic_handlers(self._state, generic_rpc_handlers) def add_insecure_port(self, address): - return _add_insecure_port(self._state, address) + return _add_insecure_port(self._state, _common.encode(address)) def add_secure_port(self, address, server_credentials): - return _add_secure_port(self._state, address, server_credentials) + return _add_secure_port(self._state, _common.encode(address), server_credentials) def start(self): _start(self._state) diff --git a/src/python/grpcio/grpc/beta/_server_adaptations.py b/src/python/grpcio/grpc/beta/_server_adaptations.py index 79e6ca87eb..1e1f80156a 100644 --- a/src/python/grpcio/grpc/beta/_server_adaptations.py +++ b/src/python/grpcio/grpc/beta/_server_adaptations.py @@ -79,7 +79,8 @@ class _FaceServicerContext(face.ServicerContext): return _ServerProtocolContext(self._servicer_context) def invocation_metadata(self): - return self._servicer_context.invocation_metadata() + return _common.cygrpc_metadata( + self._servicer_context.invocation_metadata()) def initial_metadata(self, initial_metadata): self._servicer_context.send_initial_metadata(initial_metadata) @@ -161,14 +162,24 @@ class _Callback(stream.Consumer): self._condition.wait() -def _pipe_requests(request_iterator, request_consumer, servicer_context): - for request in request_iterator: - if not servicer_context.is_active(): - return - request_consumer.consume(request) - if not servicer_context.is_active(): - return - request_consumer.terminate() +def _run_request_pipe_thread(request_iterator, request_consumer, + servicer_context): + thread_joined = threading.Event() + def pipe_requests(): + for request in request_iterator: + if not servicer_context.is_active() or thread_joined.is_set(): + return + request_consumer.consume(request) + if not servicer_context.is_active() or thread_joined.is_set(): + return + request_consumer.terminate() + + def stop_request_pipe(timeout): + thread_joined.set() + + request_pipe_thread = _common.CleanupThread( + stop_request_pipe, target=pipe_requests) + request_pipe_thread.start() def _adapt_unary_unary_event(unary_unary_event): @@ -206,10 +217,8 @@ def _adapt_stream_unary_event(stream_unary_event): raise abandonment.Abandoned() request_consumer = stream_unary_event( callback.consume_and_terminate, _FaceServicerContext(servicer_context)) - request_pipe_thread = threading.Thread( - target=_pipe_requests, - args=(request_iterator, request_consumer, servicer_context,)) - request_pipe_thread.start() + _run_request_pipe_thread( + request_iterator, request_consumer, servicer_context) return callback.draw_all_values()[0] return adaptation @@ -221,10 +230,8 @@ def _adapt_stream_stream_event(stream_stream_event): raise abandonment.Abandoned() request_consumer = stream_stream_event( callback, _FaceServicerContext(servicer_context)) - request_pipe_thread = threading.Thread( - target=_pipe_requests, - args=(request_iterator, request_consumer, servicer_context,)) - request_pipe_thread.start() + _run_request_pipe_thread( + request_iterator, request_consumer, servicer_context) while True: response = callback.draw_one_value() if response is None: diff --git a/src/python/grpcio/grpc_core_dependencies.py b/src/python/grpcio/grpc_core_dependencies.py index 839c555f05..b37e27c27e 100644 --- a/src/python/grpcio/grpc_core_dependencies.py +++ b/src/python/grpcio/grpc_core_dependencies.py @@ -94,6 +94,7 @@ CORE_SOURCE_FILES = [ 'src/core/lib/iomgr/endpoint_pair_posix.c', 'src/core/lib/iomgr/endpoint_pair_windows.c', 'src/core/lib/iomgr/error.c', + 'src/core/lib/iomgr/ev_epoll_linux.c', 'src/core/lib/iomgr/ev_poll_and_epoll_posix.c', 'src/core/lib/iomgr/ev_poll_posix.c', 'src/core/lib/iomgr/ev_posix.c', @@ -104,6 +105,7 @@ CORE_SOURCE_FILES = [ 'src/core/lib/iomgr/iomgr_posix.c', 'src/core/lib/iomgr/iomgr_windows.c', 'src/core/lib/iomgr/load_file.c', + 'src/core/lib/iomgr/network_status_tracker.c', 'src/core/lib/iomgr/polling_entity.c', 'src/core/lib/iomgr/pollset_set_windows.c', 'src/core/lib/iomgr/pollset_windows.c', diff --git a/src/python/grpcio/grpc_version.py b/src/python/grpcio/grpc_version.py index 0c13104d9d..0f4db9d972 100644 --- a/src/python/grpcio/grpc_version.py +++ b/src/python/grpcio/grpc_version.py @@ -29,4 +29,4 @@ # AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio/grpc_version.py.template`!!! -VERSION='0.15.0.dev0' +VERSION='0.16.0.dev0' diff --git a/src/python/grpcio/tests/interop/methods.py b/src/python/grpcio/tests/interop/methods.py index 7eac511525..86aa0495a2 100644 --- a/src/python/grpcio/tests/interop/methods.py +++ b/src/python/grpcio/tests/interop/methods.py @@ -79,10 +79,11 @@ class TestService(test_pb2.BetaTestServiceServicer): def FullDuplexCall(self, request_iterator, context): for request in request_iterator: - yield messages_pb2.StreamingOutputCallResponse( - payload=messages_pb2.Payload( - type=request.payload.type, - body=b'\x00' * request.response_parameters[0].size)) + for response_parameters in request.response_parameters: + yield messages_pb2.StreamingOutputCallResponse( + payload=messages_pb2.Payload( + type=request.payload.type, + body=b'\x00' * response_parameters.size)) # NOTE(nathaniel): Apparently this is the same as the full-duplex call? # NOTE(atash): It isn't even called in the interop spec (Oct 22 2015)... diff --git a/src/python/grpcio/tests/tests.json b/src/python/grpcio/tests/tests.json index 8e509621a8..45eb75b242 100644 --- a/src/python/grpcio/tests/tests.json +++ b/src/python/grpcio/tests/tests.json @@ -10,9 +10,10 @@ "_channel_connectivity_test.ChannelConnectivityTest", "_channel_ready_future_test.ChannelReadyFutureTest", "_channel_test.ChannelTest", - "_connectivity_channel_test.ChannelConnectivityTest", + "_compression_test.CompressionTest", "_connectivity_channel_test.ConnectivityStatesTest", "_empty_message_test.EmptyMessageTest", + "_exit_test.ExitTest", "_face_interface_test.DynamicInvokerBlockingInvocationInlineServiceTest", "_face_interface_test.DynamicInvokerFutureInvocationAsynchronousEventServiceTest", "_face_interface_test.GenericInvokerBlockingInvocationInlineServiceTest", @@ -24,6 +25,7 @@ "_implementations_test.ChannelCredentialsTest", "_insecure_interop_test.InsecureInteropTest", "_logging_pool_test.LoggingPoolTest", + "_metadata_code_details_test.MetadataCodeDetailsTest", "_metadata_test.MetadataTest", "_not_found_test.NotFoundTest", "_python_plugin_test.PythonPluginTest", diff --git a/src/python/grpcio/tests/unit/_compression_test.py b/src/python/grpcio/tests/unit/_compression_test.py new file mode 100644 index 0000000000..9e8b8578c1 --- /dev/null +++ b/src/python/grpcio/tests/unit/_compression_test.py @@ -0,0 +1,133 @@ +# Copyright 2016, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +"""Tests server and client side compression.""" + +import unittest + +import grpc +from grpc import _grpcio_metadata +from grpc.framework.foundation import logging_pool + +from tests.unit import test_common +from tests.unit.framework.common import test_constants + +_UNARY_UNARY = '/test/UnaryUnary' +_STREAM_STREAM = '/test/StreamStream' + + +def handle_unary(request, servicer_context): + servicer_context.send_initial_metadata([ + ('grpc-internal-encoding-request', 'gzip')]) + return request + + +def handle_stream(request_iterator, servicer_context): + # TODO(issue:#6891) We should be able to remove this loop, + # and replace with return; yield + servicer_context.send_initial_metadata([ + ('grpc-internal-encoding-request', 'gzip')]) + for request in request_iterator: + yield request + + +class _MethodHandler(grpc.RpcMethodHandler): + + def __init__(self, request_streaming, response_streaming): + self.request_streaming = request_streaming + self.response_streaming = response_streaming + self.request_deserializer = None + self.response_serializer = None + self.unary_unary = None + self.unary_stream = None + self.stream_unary = None + self.stream_stream = None + if self.request_streaming and self.response_streaming: + self.stream_stream = lambda x, y: handle_stream(x, y) + elif not self.request_streaming and not self.response_streaming: + self.unary_unary = lambda x, y: handle_unary(x, y) + + +class _GenericHandler(grpc.GenericRpcHandler): + + def service(self, handler_call_details): + if handler_call_details.method == _UNARY_UNARY: + return _MethodHandler(False, False) + elif handler_call_details.method == _STREAM_STREAM: + return _MethodHandler(True, True) + else: + return None + + +class CompressionTest(unittest.TestCase): + + def setUp(self): + self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY) + self._server = grpc.server((_GenericHandler(),), self._server_pool) + self._port = self._server.add_insecure_port('[::]:0') + self._server.start() + + def testUnary(self): + request = b'\x00' * 100 + + # Client -> server compressed through default client channel compression + # settings. Server -> client compressed via server-side metadata setting. + # TODO(https://github.com/grpc/grpc/issues/4078): replace the "1" integer + # literal with proper use of the public API. + compressed_channel = grpc.insecure_channel('localhost:%d' % self._port, + options=[('grpc.default_compression_algorithm', 1)]) + multi_callable = compressed_channel.unary_unary(_UNARY_UNARY) + response = multi_callable(request) + self.assertEqual(request, response) + + # Client -> server compressed through client metadata setting. Server -> + # client compressed via server-side metadata setting. + # TODO(https://github.com/grpc/grpc/issues/4078): replace the "0" integer + # literal with proper use of the public API. + uncompressed_channel = grpc.insecure_channel('localhost:%d' % self._port, + options=[('grpc.default_compression_algorithm', 0)]) + multi_callable = compressed_channel.unary_unary(_UNARY_UNARY) + response = multi_callable(request, metadata=[ + ('grpc-internal-encoding-request', 'gzip')]) + self.assertEqual(request, response) + + def testStreaming(self): + request = b'\x00' * 100 + + # TODO(https://github.com/grpc/grpc/issues/4078): replace the "1" integer + # literal with proper use of the public API. + compressed_channel = grpc.insecure_channel('localhost:%d' % self._port, + options=[('grpc.default_compression_algorithm', 1)]) + multi_callable = compressed_channel.stream_stream(_STREAM_STREAM) + call = multi_callable([request] * test_constants.STREAM_LENGTH) + for response in call: + self.assertEqual(request, response) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/src/python/grpcio/tests/unit/_cython/_cancel_many_calls_test.py b/src/python/grpcio/tests/unit/_cython/_cancel_many_calls_test.py index c1de779014..cac0c8b3b9 100644 --- a/src/python/grpcio/tests/unit/_cython/_cancel_many_calls_test.py +++ b/src/python/grpcio/tests/unit/_cython/_cancel_many_calls_test.py @@ -159,9 +159,9 @@ class CancelManyCallsTest(unittest.TestCase): server_completion_queue = cygrpc.CompletionQueue() server = cygrpc.Server() server.register_completion_queue(server_completion_queue) - port = server.add_http2_port('[::]:0') + port = server.add_http2_port(b'[::]:0') server.start() - channel = cygrpc.Channel('localhost:{}'.format(port)) + channel = cygrpc.Channel('localhost:{}'.format(port).encode()) state = _State() diff --git a/src/python/grpcio/tests/unit/_cython/_channel_test.py b/src/python/grpcio/tests/unit/_cython/_channel_test.py index 3dc7a246ae..f9c8a3ac62 100644 --- a/src/python/grpcio/tests/unit/_cython/_channel_test.py +++ b/src/python/grpcio/tests/unit/_cython/_channel_test.py @@ -37,7 +37,7 @@ from tests.unit.framework.common import test_constants def _channel_and_completion_queue(): - channel = cygrpc.Channel('localhost:54321', cygrpc.ChannelArgs(())) + channel = cygrpc.Channel(b'localhost:54321', cygrpc.ChannelArgs(())) completion_queue = cygrpc.CompletionQueue() return channel, completion_queue diff --git a/src/python/grpcio/tests/unit/_cython/_read_some_but_not_all_responses_test.py b/src/python/grpcio/tests/unit/_cython/_read_some_but_not_all_responses_test.py index 6ae7a90fbe..27fcee0d6f 100644 --- a/src/python/grpcio/tests/unit/_cython/_read_some_but_not_all_responses_test.py +++ b/src/python/grpcio/tests/unit/_cython/_read_some_but_not_all_responses_test.py @@ -126,9 +126,9 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase): server_completion_queue = cygrpc.CompletionQueue() server = cygrpc.Server() server.register_completion_queue(server_completion_queue) - port = server.add_http2_port('[::]:0') + port = server.add_http2_port(b'[::]:0') server.start() - channel = cygrpc.Channel('localhost:{}'.format(port)) + channel = cygrpc.Channel('localhost:{}'.format(port).encode()) server_shutdown_tag = 'server_shutdown_tag' server_driver = _ServerDriver(server_completion_queue, server_shutdown_tag) diff --git a/src/python/grpcio/tests/unit/_cython/cygrpc_test.py b/src/python/grpcio/tests/unit/_cython/cygrpc_test.py index a006a20ce3..b740695e35 100644 --- a/src/python/grpcio/tests/unit/_cython/cygrpc_test.py +++ b/src/python/grpcio/tests/unit/_cython/cygrpc_test.py @@ -46,38 +46,38 @@ def _metadata_plugin_callback(context, callback): callback(cygrpc.Metadata( [cygrpc.Metadatum(_CALL_CREDENTIALS_METADATA_KEY, _CALL_CREDENTIALS_METADATA_VALUE)]), - cygrpc.StatusCode.ok, '') + cygrpc.StatusCode.ok, b'') class TypeSmokeTest(unittest.TestCase): def testStringsInUtilitiesUpDown(self): self.assertEqual(0, cygrpc.StatusCode.ok) - metadatum = cygrpc.Metadatum('a', 'b') - self.assertEqual('a'.encode(), metadatum.key) - self.assertEqual('b'.encode(), metadatum.value) + metadatum = cygrpc.Metadatum(b'a', b'b') + self.assertEqual(b'a', metadatum.key) + self.assertEqual(b'b', metadatum.value) metadata = cygrpc.Metadata([metadatum]) self.assertEqual(1, len(metadata)) self.assertEqual(metadatum.key, metadata[0].key) def testMetadataIteration(self): metadata = cygrpc.Metadata([ - cygrpc.Metadatum('a', 'b'), cygrpc.Metadatum('c', 'd')]) + cygrpc.Metadatum(b'a', b'b'), cygrpc.Metadatum(b'c', b'd')]) iterator = iter(metadata) metadatum = next(iterator) self.assertIsInstance(metadatum, cygrpc.Metadatum) - self.assertEqual(metadatum.key, 'a'.encode()) - self.assertEqual(metadatum.value, 'b'.encode()) + self.assertEqual(metadatum.key, b'a') + self.assertEqual(metadatum.value, b'b') metadatum = next(iterator) self.assertIsInstance(metadatum, cygrpc.Metadatum) - self.assertEqual(metadatum.key, 'c'.encode()) - self.assertEqual(metadatum.value, 'd'.encode()) + self.assertEqual(metadatum.key, b'c') + self.assertEqual(metadatum.value, b'd') with self.assertRaises(StopIteration): next(iterator) def testOperationsIteration(self): operations = cygrpc.Operations([ - cygrpc.operation_send_message('asdf', _EMPTY_FLAGS)]) + cygrpc.operation_send_message(b'asdf', _EMPTY_FLAGS)]) iterator = iter(operations) operation = next(iterator) self.assertIsInstance(operation, cygrpc.Operation) @@ -87,7 +87,7 @@ class TypeSmokeTest(unittest.TestCase): next(iterator) def testOperationFlags(self): - operation = cygrpc.operation_send_message('asdf', + operation = cygrpc.operation_send_message(b'asdf', cygrpc.WriteFlag.no_compress) self.assertEqual(cygrpc.WriteFlag.no_compress, operation.flags) @@ -105,16 +105,16 @@ class TypeSmokeTest(unittest.TestCase): del server def testChannelUpDown(self): - channel = cygrpc.Channel('[::]:0', cygrpc.ChannelArgs([])) + channel = cygrpc.Channel(b'[::]:0', cygrpc.ChannelArgs([])) del channel def testCredentialsMetadataPluginUpDown(self): plugin = cygrpc.CredentialsMetadataPlugin( - lambda ignored_a, ignored_b: None, '') + lambda ignored_a, ignored_b: None, b'') del plugin def testCallCredentialsFromPluginUpDown(self): - plugin = cygrpc.CredentialsMetadataPlugin(_metadata_plugin_callback, '') + plugin = cygrpc.CredentialsMetadataPlugin(_metadata_plugin_callback, b'') call_credentials = cygrpc.call_credentials_metadata_plugin(plugin) del plugin del call_credentials @@ -123,7 +123,7 @@ class TypeSmokeTest(unittest.TestCase): server = cygrpc.Server() completion_queue = cygrpc.CompletionQueue() server.register_completion_queue(completion_queue) - port = server.add_http2_port('[::]:0') + port = server.add_http2_port(b'[::]:0') self.assertIsInstance(port, int) server.start() del server @@ -131,7 +131,7 @@ class TypeSmokeTest(unittest.TestCase): def testServerStartShutdown(self): completion_queue = cygrpc.CompletionQueue() server = cygrpc.Server() - server.add_http2_port('[::]:0') + server.add_http2_port(b'[::]:0') server.register_completion_queue(completion_queue) server.start() shutdown_tag = object() @@ -150,9 +150,9 @@ class ServerClientMixin(object): self.server = cygrpc.Server() self.server.register_completion_queue(self.server_completion_queue) if server_credentials: - self.port = self.server.add_http2_port('[::]:0', server_credentials) + self.port = self.server.add_http2_port(b'[::]:0', server_credentials) else: - self.port = self.server.add_http2_port('[::]:0') + self.port = self.server.add_http2_port(b'[::]:0') self.server.start() self.client_completion_queue = cygrpc.CompletionQueue() if client_credentials: @@ -160,10 +160,10 @@ class ServerClientMixin(object): cygrpc.ChannelArg(cygrpc.ChannelArgKey.ssl_target_name_override, host_override)]) self.client_channel = cygrpc.Channel( - 'localhost:{}'.format(self.port), client_channel_arguments, + 'localhost:{}'.format(self.port).encode(), client_channel_arguments, client_credentials) else: - self.client_channel = cygrpc.Channel('localhost:{}'.format(self.port)) + self.client_channel = cygrpc.Channel('localhost:{}'.format(self.port).encode()) if host_override: self.host_argument = None # default host self.expected_host = host_override diff --git a/src/python/grpcio/tests/unit/_empty_message_test.py b/src/python/grpcio/tests/unit/_empty_message_test.py index f324f6216b..8c7d697728 100644 --- a/src/python/grpcio/tests/unit/_empty_message_test.py +++ b/src/python/grpcio/tests/unit/_empty_message_test.py @@ -37,10 +37,10 @@ from tests.unit.framework.common import test_constants _REQUEST = b'' _RESPONSE = b'' -_UNARY_UNARY = b'/test/UnaryUnary' -_UNARY_STREAM = b'/test/UnaryStream' -_STREAM_UNARY = b'/test/StreamUnary' -_STREAM_STREAM = b'/test/StreamStream' +_UNARY_UNARY = '/test/UnaryUnary' +_UNARY_STREAM = '/test/UnaryStream' +_STREAM_UNARY = '/test/StreamUnary' +_STREAM_STREAM = '/test/StreamStream' def handle_unary_unary(request, servicer_context): diff --git a/src/python/grpcio/tests/unit/_exit_scenarios.py b/src/python/grpcio/tests/unit/_exit_scenarios.py new file mode 100644 index 0000000000..24a2faef85 --- /dev/null +++ b/src/python/grpcio/tests/unit/_exit_scenarios.py @@ -0,0 +1,249 @@ +# Copyright 2016, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Defines a number of module-scope gRPC scenarios to test clean exit.""" + +import argparse +import threading +import time + +import grpc + +from tests.unit.framework.common import test_constants + +WAIT_TIME = 1000 + +REQUEST = b'request' + +UNSTARTED_SERVER = 'unstarted_server' +RUNNING_SERVER = 'running_server' +POLL_CONNECTIVITY_NO_SERVER = 'poll_connectivity_no_server' +POLL_CONNECTIVITY = 'poll_connectivity' +IN_FLIGHT_UNARY_UNARY_CALL = 'in_flight_unary_unary_call' +IN_FLIGHT_UNARY_STREAM_CALL = 'in_flight_unary_stream_call' +IN_FLIGHT_STREAM_UNARY_CALL = 'in_flight_stream_unary_call' +IN_FLIGHT_STREAM_STREAM_CALL = 'in_flight_stream_stream_call' +IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL = 'in_flight_partial_unary_stream_call' +IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL = 'in_flight_partial_stream_unary_call' +IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL = 'in_flight_partial_stream_stream_call' + +UNARY_UNARY = b'/test/UnaryUnary' +UNARY_STREAM = b'/test/UnaryStream' +STREAM_UNARY = b'/test/StreamUnary' +STREAM_STREAM = b'/test/StreamStream' +PARTIAL_UNARY_STREAM = b'/test/PartialUnaryStream' +PARTIAL_STREAM_UNARY = b'/test/PartialStreamUnary' +PARTIAL_STREAM_STREAM = b'/test/PartialStreamStream' + +TEST_TO_METHOD = { + IN_FLIGHT_UNARY_UNARY_CALL: UNARY_UNARY, + IN_FLIGHT_UNARY_STREAM_CALL: UNARY_STREAM, + IN_FLIGHT_STREAM_UNARY_CALL: STREAM_UNARY, + IN_FLIGHT_STREAM_STREAM_CALL: STREAM_STREAM, + IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL: PARTIAL_UNARY_STREAM, + IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL: PARTIAL_STREAM_UNARY, + IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL: PARTIAL_STREAM_STREAM, +} + + +def hang_unary_unary(request, servicer_context): + time.sleep(WAIT_TIME) + + +def hang_unary_stream(request, servicer_context): + time.sleep(WAIT_TIME) + + +def hang_partial_unary_stream(request, servicer_context): + for _ in range(test_constants.STREAM_LENGTH // 2): + yield request + time.sleep(WAIT_TIME) + + +def hang_stream_unary(request_iterator, servicer_context): + time.sleep(WAIT_TIME) + + +def hang_partial_stream_unary(request_iterator, servicer_context): + for _ in range(test_constants.STREAM_LENGTH // 2): + next(request_iterator) + time.sleep(WAIT_TIME) + + +def hang_stream_stream(request_iterator, servicer_context): + time.sleep(WAIT_TIME) + + +def hang_partial_stream_stream(request_iterator, servicer_context): + for _ in range(test_constants.STREAM_LENGTH // 2): + yield next(request_iterator) + time.sleep(WAIT_TIME) + + +class MethodHandler(grpc.RpcMethodHandler): + + def __init__(self, request_streaming, response_streaming, partial_hang): + self.request_streaming = request_streaming + self.response_streaming = response_streaming + self.request_deserializer = None + self.response_serializer = None + self.unary_unary = None + self.unary_stream = None + self.stream_unary = None + self.stream_stream = None + if self.request_streaming and self.response_streaming: + if partial_hang: + self.stream_stream = hang_partial_stream_stream + else: + self.stream_stream = hang_stream_stream + elif self.request_streaming: + if partial_hang: + self.stream_unary = hang_partial_stream_unary + else: + self.stream_unary = hang_stream_unary + elif self.response_streaming: + if partial_hang: + self.unary_stream = hang_partial_unary_stream + else: + self.unary_stream = hang_unary_stream + else: + self.unary_unary = hang_unary_unary + + +class GenericHandler(grpc.GenericRpcHandler): + + def service(self, handler_call_details): + if handler_call_details.method == UNARY_UNARY: + return MethodHandler(False, False, False) + elif handler_call_details.method == UNARY_STREAM: + return MethodHandler(False, True, False) + elif handler_call_details.method == STREAM_UNARY: + return MethodHandler(True, False, False) + elif handler_call_details.method == STREAM_STREAM: + return MethodHandler(True, True, False) + elif handler_call_details.method == PARTIAL_UNARY_STREAM: + return MethodHandler(False, True, True) + elif handler_call_details.method == PARTIAL_STREAM_UNARY: + return MethodHandler(True, False, True) + elif handler_call_details.method == PARTIAL_STREAM_STREAM: + return MethodHandler(True, True, True) + else: + return None + + +# Traditional executors will not exit until all their +# current jobs complete. Because we submit jobs that will +# never finish, we don't want to block exit on these jobs. +class DaemonPool(object): + + def submit(self, fn, *args, **kwargs): + thread = threading.Thread(target=fn, args=args, kwargs=kwargs) + thread.daemon = True + thread.start() + + def shutdown(self, wait=True): + pass + + +def infinite_request_iterator(): + while True: + yield REQUEST + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('scenario', type=str) + parser.add_argument( + '--wait_for_interrupt', dest='wait_for_interrupt', action='store_true') + args = parser.parse_args() + + if args.scenario == UNSTARTED_SERVER: + server = grpc.server((), DaemonPool()) + if args.wait_for_interrupt: + time.sleep(WAIT_TIME) + elif args.scenario == RUNNING_SERVER: + server = grpc.server((), DaemonPool()) + port = server.add_insecure_port('[::]:0') + server.start() + if args.wait_for_interrupt: + time.sleep(WAIT_TIME) + elif args.scenario == POLL_CONNECTIVITY_NO_SERVER: + channel = grpc.insecure_channel('localhost:12345') + + def connectivity_callback(connectivity): + pass + + channel.subscribe(connectivity_callback, try_to_connect=True) + if args.wait_for_interrupt: + time.sleep(WAIT_TIME) + elif args.scenario == POLL_CONNECTIVITY: + server = grpc.server((), DaemonPool()) + port = server.add_insecure_port('[::]:0') + server.start() + channel = grpc.insecure_channel('localhost:%d' % port) + + def connectivity_callback(connectivity): + pass + + channel.subscribe(connectivity_callback, try_to_connect=True) + if args.wait_for_interrupt: + time.sleep(WAIT_TIME) + + else: + handler = GenericHandler() + server = grpc.server((), DaemonPool()) + port = server.add_insecure_port('[::]:0') + server.add_generic_rpc_handlers((handler,)) + server.start() + channel = grpc.insecure_channel('localhost:%d' % port) + + method = TEST_TO_METHOD[args.scenario] + + if args.scenario == IN_FLIGHT_UNARY_UNARY_CALL: + multi_callable = channel.unary_unary(method) + future = multi_callable.future(REQUEST) + result, call = multi_callable.with_call(REQUEST) + elif (args.scenario == IN_FLIGHT_UNARY_STREAM_CALL or + args.scenario == IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL): + multi_callable = channel.unary_stream(method) + response_iterator = multi_callable(REQUEST) + for response in response_iterator: + pass + elif (args.scenario == IN_FLIGHT_STREAM_UNARY_CALL or + args.scenario == IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL): + multi_callable = channel.stream_unary(method) + future = multi_callable.future(infinite_request_iterator()) + result, call = multi_callable.with_call( + [REQUEST] * test_constants.STREAM_LENGTH) + elif (args.scenario == IN_FLIGHT_STREAM_STREAM_CALL or + args.scenario == IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL): + multi_callable = channel.stream_stream(method) + response_iterator = multi_callable(infinite_request_iterator()) + for response in response_iterator: + pass diff --git a/src/python/grpcio/tests/unit/_exit_test.py b/src/python/grpcio/tests/unit/_exit_test.py new file mode 100644 index 0000000000..b0d6af73e5 --- /dev/null +++ b/src/python/grpcio/tests/unit/_exit_test.py @@ -0,0 +1,185 @@ +# Copyright 2016, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests clean exit of server/client on Python Interpreter exit/sigint. + +The tests in this module spawn a subprocess for each test case, the +test is considered successful if it doesn't hang/timeout. +""" + +import atexit +import os +import signal +import six +import subprocess +import sys +import threading +import time +import unittest + +from tests.unit import _exit_scenarios + +SCENARIO_FILE = os.path.abspath(os.path.join( + os.path.dirname(os.path.realpath(__file__)), '_exit_scenarios.py')) +INTERPRETER = sys.executable +BASE_COMMAND = [INTERPRETER, SCENARIO_FILE] +BASE_SIGTERM_COMMAND = BASE_COMMAND + ['--wait_for_interrupt'] + +INIT_TIME = 1.0 + + +processes = [] +process_lock = threading.Lock() + + +# Make sure we attempt to clean up any +# processes we may have left running +def cleanup_processes(): + with process_lock: + for process in processes: + try: + process.kill() + except Exception: + pass +atexit.register(cleanup_processes) + + +def interrupt_and_wait(process): + with process_lock: + processes.append(process) + time.sleep(INIT_TIME) + os.kill(process.pid, signal.SIGINT) + process.wait() + + +def wait(process): + with process_lock: + processes.append(process) + process.wait() + + +class ExitTest(unittest.TestCase): + + def test_unstarted_server(self): + process = subprocess.Popen( + BASE_COMMAND + [_exit_scenarios.UNSTARTED_SERVER], + stdout=sys.stdout, stderr=sys.stderr) + wait(process) + + def test_unstarted_server_terminate(self): + process = subprocess.Popen( + BASE_SIGTERM_COMMAND + [_exit_scenarios.UNSTARTED_SERVER], + stdout=sys.stdout) + interrupt_and_wait(process) + + def test_running_server(self): + process = subprocess.Popen( + BASE_COMMAND + [_exit_scenarios.RUNNING_SERVER], + stdout=sys.stdout, stderr=sys.stderr) + wait(process) + + def test_running_server_terminate(self): + process = subprocess.Popen( + BASE_SIGTERM_COMMAND + [_exit_scenarios.RUNNING_SERVER], + stdout=sys.stdout, stderr=sys.stderr) + interrupt_and_wait(process) + + def test_poll_connectivity_no_server(self): + process = subprocess.Popen( + BASE_COMMAND + [_exit_scenarios.POLL_CONNECTIVITY_NO_SERVER], + stdout=sys.stdout, stderr=sys.stderr) + wait(process) + + def test_poll_connectivity_no_server_terminate(self): + process = subprocess.Popen( + BASE_SIGTERM_COMMAND + [_exit_scenarios.POLL_CONNECTIVITY_NO_SERVER], + stdout=sys.stdout, stderr=sys.stderr) + interrupt_and_wait(process) + + def test_poll_connectivity(self): + process = subprocess.Popen( + BASE_COMMAND + [_exit_scenarios.POLL_CONNECTIVITY], + stdout=sys.stdout, stderr=sys.stderr) + wait(process) + + def test_poll_connectivity_terminate(self): + process = subprocess.Popen( + BASE_SIGTERM_COMMAND + [_exit_scenarios.POLL_CONNECTIVITY], + stdout=sys.stdout, stderr=sys.stderr) + interrupt_and_wait(process) + + def test_in_flight_unary_unary_call(self): + process = subprocess.Popen( + BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_UNARY_UNARY_CALL], + stdout=sys.stdout, stderr=sys.stderr) + interrupt_and_wait(process) + + @unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999') + def test_in_flight_unary_stream_call(self): + process = subprocess.Popen( + BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_UNARY_STREAM_CALL], + stdout=sys.stdout, stderr=sys.stderr) + interrupt_and_wait(process) + + def test_in_flight_stream_unary_call(self): + process = subprocess.Popen( + BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_STREAM_UNARY_CALL], + stdout=sys.stdout, stderr=sys.stderr) + interrupt_and_wait(process) + + @unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999') + def test_in_flight_stream_stream_call(self): + process = subprocess.Popen( + BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_STREAM_STREAM_CALL], + stdout=sys.stdout, stderr=sys.stderr) + interrupt_and_wait(process) + + @unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999') + def test_in_flight_partial_unary_stream_call(self): + process = subprocess.Popen( + BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL], + stdout=sys.stdout, stderr=sys.stderr) + interrupt_and_wait(process) + + def test_in_flight_partial_stream_unary_call(self): + process = subprocess.Popen( + BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL], + stdout=sys.stdout, stderr=sys.stderr) + interrupt_and_wait(process) + + @unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999') + def test_in_flight_partial_stream_stream_call(self): + process = subprocess.Popen( + BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL], + stdout=sys.stdout, stderr=sys.stderr) + interrupt_and_wait(process) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/src/python/grpcio/tests/unit/_metadata_code_details_test.py b/src/python/grpcio/tests/unit/_metadata_code_details_test.py new file mode 100644 index 0000000000..0fd02d2a22 --- /dev/null +++ b/src/python/grpcio/tests/unit/_metadata_code_details_test.py @@ -0,0 +1,523 @@ +# Copyright 2016, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests application-provided metadata, status code, and details.""" + +import threading +import unittest + +import grpc +from grpc.framework.foundation import logging_pool + +from tests.unit import test_common +from tests.unit.framework.common import test_constants +from tests.unit.framework.common import test_control + +_SERIALIZED_REQUEST = b'\x46\x47\x48' +_SERIALIZED_RESPONSE = b'\x49\x50\x51' + +_REQUEST_SERIALIZER = lambda unused_request: _SERIALIZED_REQUEST +_REQUEST_DESERIALIZER = lambda unused_serialized_request: object() +_RESPONSE_SERIALIZER = lambda unused_response: _SERIALIZED_RESPONSE +_RESPONSE_DESERIALIZER = lambda unused_serialized_resopnse: object() + +_SERVICE = 'test.TestService' +_UNARY_UNARY = 'UnaryUnary' +_UNARY_STREAM = 'UnaryStream' +_STREAM_UNARY = 'StreamUnary' +_STREAM_STREAM = 'StreamStream' + +_CLIENT_METADATA = ( + ('client-md-key', 'client-md-key'), + ('client-md-key-bin', b'\x00\x01') +) + +_SERVER_INITIAL_METADATA = ( + ('server-initial-md-key', 'server-initial-md-value'), + ('server-initial-md-key-bin', b'\x00\x02') +) + +_SERVER_TRAILING_METADATA = ( + ('server-trailing-md-key', 'server-trailing-md-value'), + ('server-trailing-md-key-bin', b'\x00\x03') +) + +_NON_OK_CODE = grpc.StatusCode.NOT_FOUND +_DETAILS = 'Test details!' + + +class _Servicer(object): + + def __init__(self): + self._lock = threading.Lock() + self._code = None + self._details = None + self._exception = False + self._return_none = False + self._received_client_metadata = None + + def unary_unary(self, request, context): + with self._lock: + self._received_client_metadata = context.invocation_metadata() + context.send_initial_metadata(_SERVER_INITIAL_METADATA) + context.set_trailing_metadata(_SERVER_TRAILING_METADATA) + if self._code is not None: + context.set_code(self._code) + if self._details is not None: + context.set_details(self._details) + if self._exception: + raise test_control.Defect() + else: + return None if self._return_none else object() + + def unary_stream(self, request, context): + with self._lock: + self._received_client_metadata = context.invocation_metadata() + context.send_initial_metadata(_SERVER_INITIAL_METADATA) + context.set_trailing_metadata(_SERVER_TRAILING_METADATA) + if self._code is not None: + context.set_code(self._code) + if self._details is not None: + context.set_details(self._details) + for _ in range(test_constants.STREAM_LENGTH // 2): + yield _SERIALIZED_RESPONSE + if self._exception: + raise test_control.Defect() + + def stream_unary(self, request_iterator, context): + with self._lock: + self._received_client_metadata = context.invocation_metadata() + context.send_initial_metadata(_SERVER_INITIAL_METADATA) + context.set_trailing_metadata(_SERVER_TRAILING_METADATA) + if self._code is not None: + context.set_code(self._code) + if self._details is not None: + context.set_details(self._details) + # TODO(https://github.com/grpc/grpc/issues/6891): just ignore the + # request iterator. + for ignored_request in request_iterator: + pass + if self._exception: + raise test_control.Defect() + else: + return None if self._return_none else _SERIALIZED_RESPONSE + + def stream_stream(self, request_iterator, context): + with self._lock: + self._received_client_metadata = context.invocation_metadata() + context.send_initial_metadata(_SERVER_INITIAL_METADATA) + context.set_trailing_metadata(_SERVER_TRAILING_METADATA) + if self._code is not None: + context.set_code(self._code) + if self._details is not None: + context.set_details(self._details) + # TODO(https://github.com/grpc/grpc/issues/6891): just ignore the + # request iterator. + for ignored_request in request_iterator: + pass + for _ in range(test_constants.STREAM_LENGTH // 3): + yield object() + if self._exception: + raise test_control.Defect() + + def set_code(self, code): + with self._lock: + self._code = code + + def set_details(self, details): + with self._lock: + self._details = details + + def set_exception(self): + with self._lock: + self._exception = True + + def set_return_none(self): + with self._lock: + self._return_none = True + + def received_client_metadata(self): + with self._lock: + return self._received_client_metadata + + +def _generic_handler(servicer): + method_handlers = { + _UNARY_UNARY: grpc.unary_unary_rpc_method_handler( + servicer.unary_unary, request_deserializer=_REQUEST_DESERIALIZER, + response_serializer=_RESPONSE_SERIALIZER), + _UNARY_STREAM: grpc.unary_stream_rpc_method_handler( + servicer.unary_stream), + _STREAM_UNARY: grpc.stream_unary_rpc_method_handler( + servicer.stream_unary), + _STREAM_STREAM: grpc.stream_stream_rpc_method_handler( + servicer.stream_stream, request_deserializer=_REQUEST_DESERIALIZER, + response_serializer=_RESPONSE_SERIALIZER), + } + return grpc.method_handlers_generic_handler(_SERVICE, method_handlers) + + +class MetadataCodeDetailsTest(unittest.TestCase): + + def setUp(self): + self._servicer = _Servicer() + self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY) + self._server = grpc.server( + (_generic_handler(self._servicer),), self._server_pool) + port = self._server.add_insecure_port('[::]:0') + self._server.start() + + channel = grpc.insecure_channel('localhost:{}'.format(port)) + self._unary_unary = channel.unary_unary( + '/'.join(('', _SERVICE, _UNARY_UNARY,)), + request_serializer=_REQUEST_SERIALIZER, + response_deserializer=_RESPONSE_DESERIALIZER,) + self._unary_stream = channel.unary_stream( + '/'.join(('', _SERVICE, _UNARY_STREAM,)),) + self._stream_unary = channel.stream_unary( + '/'.join(('', _SERVICE, _STREAM_UNARY,)),) + self._stream_stream = channel.stream_stream( + '/'.join(('', _SERVICE, _STREAM_STREAM,)), + request_serializer=_REQUEST_SERIALIZER, + response_deserializer=_RESPONSE_DESERIALIZER,) + + + def testSuccessfulUnaryUnary(self): + self._servicer.set_details(_DETAILS) + + unused_response, call = self._unary_unary.with_call( + object(), metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, call.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, call.trailing_metadata())) + self.assertIs(grpc.StatusCode.OK, call.code()) + self.assertEqual(_DETAILS, call.details()) + + def testSuccessfulUnaryStream(self): + self._servicer.set_details(_DETAILS) + + call = self._unary_stream(_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA) + received_initial_metadata = call.initial_metadata() + for _ in call: + pass + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, received_initial_metadata)) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, call.trailing_metadata())) + self.assertIs(grpc.StatusCode.OK, call.code()) + self.assertEqual(_DETAILS, call.details()) + + def testSuccessfulStreamUnary(self): + self._servicer.set_details(_DETAILS) + + unused_response, call = self._stream_unary.with_call( + iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, call.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, call.trailing_metadata())) + self.assertIs(grpc.StatusCode.OK, call.code()) + self.assertEqual(_DETAILS, call.details()) + + def testSuccessfulStreamStream(self): + self._servicer.set_details(_DETAILS) + + call = self._stream_stream( + iter([object()] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA) + received_initial_metadata = call.initial_metadata() + for _ in call: + pass + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, received_initial_metadata)) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, call.trailing_metadata())) + self.assertIs(grpc.StatusCode.OK, call.code()) + self.assertEqual(_DETAILS, call.details()) + + def testCustomCodeUnaryUnary(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + + with self.assertRaises(grpc.RpcError) as exception_context: + self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, + exception_context.exception.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + exception_context.exception.trailing_metadata())) + self.assertIs(_NON_OK_CODE, exception_context.exception.code()) + self.assertEqual(_DETAILS, exception_context.exception.details()) + + def testCustomCodeUnaryStream(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + + call = self._unary_stream(_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA) + received_initial_metadata = call.initial_metadata() + with self.assertRaises(grpc.RpcError): + for _ in call: + pass + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, received_initial_metadata)) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, call.trailing_metadata())) + self.assertIs(_NON_OK_CODE, call.code()) + self.assertEqual(_DETAILS, call.details()) + + def testCustomCodeStreamUnary(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + + with self.assertRaises(grpc.RpcError) as exception_context: + self._stream_unary.with_call( + iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, + exception_context.exception.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + exception_context.exception.trailing_metadata())) + self.assertIs(_NON_OK_CODE, exception_context.exception.code()) + self.assertEqual(_DETAILS, exception_context.exception.details()) + + def testCustomCodeStreamStream(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + + call = self._stream_stream( + iter([object()] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA) + received_initial_metadata = call.initial_metadata() + with self.assertRaises(grpc.RpcError) as exception_context: + for _ in call: + pass + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, received_initial_metadata)) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + exception_context.exception.trailing_metadata())) + self.assertIs(_NON_OK_CODE, exception_context.exception.code()) + self.assertEqual(_DETAILS, exception_context.exception.details()) + + def testCustomCodeExceptionUnaryUnary(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + self._servicer.set_exception() + + with self.assertRaises(grpc.RpcError) as exception_context: + self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, + exception_context.exception.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + exception_context.exception.trailing_metadata())) + self.assertIs(_NON_OK_CODE, exception_context.exception.code()) + self.assertEqual(_DETAILS, exception_context.exception.details()) + + def testCustomCodeExceptionUnaryStream(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + self._servicer.set_exception() + + call = self._unary_stream(_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA) + received_initial_metadata = call.initial_metadata() + with self.assertRaises(grpc.RpcError): + for _ in call: + pass + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, received_initial_metadata)) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, call.trailing_metadata())) + self.assertIs(_NON_OK_CODE, call.code()) + self.assertEqual(_DETAILS, call.details()) + + def testCustomCodeExceptionStreamUnary(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + self._servicer.set_exception() + + with self.assertRaises(grpc.RpcError) as exception_context: + self._stream_unary.with_call( + iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, + exception_context.exception.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + exception_context.exception.trailing_metadata())) + self.assertIs(_NON_OK_CODE, exception_context.exception.code()) + self.assertEqual(_DETAILS, exception_context.exception.details()) + + def testCustomCodeExceptionStreamStream(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + self._servicer.set_exception() + + call = self._stream_stream( + iter([object()] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA) + received_initial_metadata = call.initial_metadata() + with self.assertRaises(grpc.RpcError): + for _ in call: + pass + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, received_initial_metadata)) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, call.trailing_metadata())) + self.assertIs(_NON_OK_CODE, call.code()) + self.assertEqual(_DETAILS, call.details()) + + def testCustomCodeReturnNoneUnaryUnary(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + self._servicer.set_return_none() + + with self.assertRaises(grpc.RpcError) as exception_context: + self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, + exception_context.exception.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + exception_context.exception.trailing_metadata())) + self.assertIs(_NON_OK_CODE, exception_context.exception.code()) + self.assertEqual(_DETAILS, exception_context.exception.details()) + + def testCustomCodeReturnNoneStreamUnary(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + self._servicer.set_return_none() + + with self.assertRaises(grpc.RpcError) as exception_context: + self._stream_unary.with_call( + iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, + exception_context.exception.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + exception_context.exception.trailing_metadata())) + self.assertIs(_NON_OK_CODE, exception_context.exception.code()) + self.assertEqual(_DETAILS, exception_context.exception.details()) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/src/python/grpcio/tests/unit/_metadata_test.py b/src/python/grpcio/tests/unit/_metadata_test.py index 2cb13f236b..c637a28039 100644 --- a/src/python/grpcio/tests/unit/_metadata_test.py +++ b/src/python/grpcio/tests/unit/_metadata_test.py @@ -44,33 +44,33 @@ _CHANNEL_ARGS = (('grpc.primary_user_agent', 'primary-agent'), _REQUEST = b'\x00\x00\x00' _RESPONSE = b'\x00\x00\x00' -_UNARY_UNARY = b'/test/UnaryUnary' -_UNARY_STREAM = b'/test/UnaryStream' -_STREAM_UNARY = b'/test/StreamUnary' -_STREAM_STREAM = b'/test/StreamStream' +_UNARY_UNARY = '/test/UnaryUnary' +_UNARY_STREAM = '/test/UnaryStream' +_STREAM_UNARY = '/test/StreamUnary' +_STREAM_STREAM = '/test/StreamStream' _USER_AGENT = 'Python-gRPC-{}'.format(_grpcio_metadata.__version__) _CLIENT_METADATA = ( - (b'client-md-key', b'client-md-key'), - (b'client-md-key-bin', b'\x00\x01') + ('client-md-key', 'client-md-key'), + ('client-md-key-bin', b'\x00\x01') ) _SERVER_INITIAL_METADATA = ( - (b'server-initial-md-key', b'server-initial-md-value'), - (b'server-initial-md-key-bin', b'\x00\x02') + ('server-initial-md-key', 'server-initial-md-value'), + ('server-initial-md-key-bin', b'\x00\x02') ) _SERVER_TRAILING_METADATA = ( - (b'server-trailing-md-key', b'server-trailing-md-value'), - (b'server-trailing-md-key-bin', b'\x00\x03') + ('server-trailing-md-key', 'server-trailing-md-value'), + ('server-trailing-md-key-bin', b'\x00\x03') ) def user_agent(metadata): for key, val in metadata: - if key == b'user-agent': - return val.decode('ascii') + if key == 'user-agent': + return val raise KeyError('No user agent!') diff --git a/src/python/grpcio/tests/unit/_rpc_test.py b/src/python/grpcio/tests/unit/_rpc_test.py index 9814504edf..c70d65a6df 100644 --- a/src/python/grpcio/tests/unit/_rpc_test.py +++ b/src/python/grpcio/tests/unit/_rpc_test.py @@ -45,10 +45,10 @@ _DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) // 2:] _SERIALIZE_RESPONSE = lambda bytestring: bytestring * 3 _DESERIALIZE_RESPONSE = lambda bytestring: bytestring[:len(bytestring) // 3] -_UNARY_UNARY = b'/test/UnaryUnary' -_UNARY_STREAM = b'/test/UnaryStream' -_STREAM_UNARY = b'/test/StreamUnary' -_STREAM_STREAM = b'/test/StreamStream' +_UNARY_UNARY = '/test/UnaryUnary' +_UNARY_STREAM = '/test/UnaryStream' +_STREAM_UNARY = '/test/StreamUnary' +_STREAM_STREAM = '/test/StreamStream' class _Callback(object): @@ -79,7 +79,7 @@ class _Handler(object): def handle_unary_unary(self, request, servicer_context): self._control.control() if servicer_context is not None: - servicer_context.set_trailing_metadata(((b'testkey', b'testvalue',),)) + servicer_context.set_trailing_metadata((('testkey', 'testvalue',),)) return request def handle_unary_stream(self, request, servicer_context): @@ -88,7 +88,7 @@ class _Handler(object): yield request self._control.control() if servicer_context is not None: - servicer_context.set_trailing_metadata(((b'testkey', b'testvalue',),)) + servicer_context.set_trailing_metadata((('testkey', 'testvalue',),)) def handle_stream_unary(self, request_iterator, servicer_context): if servicer_context is not None: @@ -100,13 +100,13 @@ class _Handler(object): response_elements.append(request) self._control.control() if servicer_context is not None: - servicer_context.set_trailing_metadata(((b'testkey', b'testvalue',),)) + servicer_context.set_trailing_metadata((('testkey', 'testvalue',),)) return b''.join(response_elements) def handle_stream_stream(self, request_iterator, servicer_context): self._control.control() if servicer_context is not None: - servicer_context.set_trailing_metadata(((b'testkey', b'testvalue',),)) + servicer_context.set_trailing_metadata((('testkey', 'testvalue',),)) for request in request_iterator: self._control.control() yield request @@ -185,7 +185,7 @@ class RPCTest(unittest.TestCase): self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY) self._server = grpc.server((), self._server_pool) - port = self._server.add_insecure_port(b'[::]:0') + port = self._server.add_insecure_port('[::]:0') self._server.add_generic_rpc_handlers((_GenericHandler(self._handler),)) self._server.start() @@ -195,7 +195,7 @@ class RPCTest(unittest.TestCase): request = b'abc' with self.assertRaises(grpc.RpcError) as exception_context: - self._channel.unary_unary(b'NoSuchMethod')(request) + self._channel.unary_unary('NoSuchMethod')(request) self.assertEqual( grpc.StatusCode.UNIMPLEMENTED, exception_context.exception.code()) @@ -207,7 +207,7 @@ class RPCTest(unittest.TestCase): multi_callable = _unary_unary_multi_callable(self._channel) response = multi_callable( request, metadata=( - (b'test', b'SuccessfulUnaryRequestBlockingUnaryResponse'),)) + ('test', 'SuccessfulUnaryRequestBlockingUnaryResponse'),)) self.assertEqual(expected_response, response) @@ -218,7 +218,7 @@ class RPCTest(unittest.TestCase): multi_callable = _unary_unary_multi_callable(self._channel) response, call = multi_callable.with_call( request, metadata=( - (b'test', b'SuccessfulUnaryRequestBlockingUnaryResponseWithCall'),)) + ('test', 'SuccessfulUnaryRequestBlockingUnaryResponseWithCall'),)) self.assertEqual(expected_response, response) self.assertIs(grpc.StatusCode.OK, call.code()) @@ -230,7 +230,7 @@ class RPCTest(unittest.TestCase): multi_callable = _unary_unary_multi_callable(self._channel) response_future = multi_callable.future( request, metadata=( - (b'test', b'SuccessfulUnaryRequestFutureUnaryResponse'),)) + ('test', 'SuccessfulUnaryRequestFutureUnaryResponse'),)) response = response_future.result() self.assertEqual(expected_response, response) @@ -242,7 +242,7 @@ class RPCTest(unittest.TestCase): multi_callable = _unary_stream_multi_callable(self._channel) response_iterator = multi_callable( request, - metadata=((b'test', b'SuccessfulUnaryRequestStreamResponse'),)) + metadata=(('test', 'SuccessfulUnaryRequestStreamResponse'),)) responses = tuple(response_iterator) self.assertSequenceEqual(expected_responses, responses) @@ -255,7 +255,7 @@ class RPCTest(unittest.TestCase): multi_callable = _stream_unary_multi_callable(self._channel) response = multi_callable( request_iterator, - metadata=((b'test', b'SuccessfulStreamRequestBlockingUnaryResponse'),)) + metadata=(('test', 'SuccessfulStreamRequestBlockingUnaryResponse'),)) self.assertEqual(expected_response, response) @@ -268,7 +268,7 @@ class RPCTest(unittest.TestCase): response, call = multi_callable.with_call( request_iterator, metadata=( - (b'test', b'SuccessfulStreamRequestBlockingUnaryResponseWithCall'), + ('test', 'SuccessfulStreamRequestBlockingUnaryResponseWithCall'), )) self.assertEqual(expected_response, response) @@ -283,7 +283,7 @@ class RPCTest(unittest.TestCase): response_future = multi_callable.future( request_iterator, metadata=( - (b'test', b'SuccessfulStreamRequestFutureUnaryResponse'),)) + ('test', 'SuccessfulStreamRequestFutureUnaryResponse'),)) response = response_future.result() self.assertEqual(expected_response, response) @@ -297,7 +297,7 @@ class RPCTest(unittest.TestCase): multi_callable = _stream_stream_multi_callable(self._channel) response_iterator = multi_callable( request_iterator, - metadata=((b'test', b'SuccessfulStreamRequestStreamResponse'),)) + metadata=(('test', 'SuccessfulStreamRequestStreamResponse'),)) responses = tuple(response_iterator) self.assertSequenceEqual(expected_responses, responses) @@ -312,9 +312,9 @@ class RPCTest(unittest.TestCase): multi_callable = _unary_unary_multi_callable(self._channel) first_response = multi_callable( - first_request, metadata=((b'test', b'SequentialInvocations'),)) + first_request, metadata=(('test', 'SequentialInvocations'),)) second_response = multi_callable( - second_request, metadata=((b'test', b'SequentialInvocations'),)) + second_request, metadata=(('test', 'SequentialInvocations'),)) self.assertEqual(expected_first_response, first_response) self.assertEqual(expected_second_response, second_response) @@ -331,7 +331,7 @@ class RPCTest(unittest.TestCase): request_iterator = iter(requests) response_future = pool.submit( multi_callable, request_iterator, - metadata=((b'test', b'ConcurrentBlockingInvocations'),)) + metadata=(('test', 'ConcurrentBlockingInvocations'),)) response_futures[index] = response_future responses = tuple( response_future.result() for response_future in response_futures) @@ -350,7 +350,7 @@ class RPCTest(unittest.TestCase): request_iterator = iter(requests) response_future = multi_callable.future( request_iterator, - metadata=((b'test', b'ConcurrentFutureInvocations'),)) + metadata=(('test', 'ConcurrentFutureInvocations'),)) response_futures[index] = response_future responses = tuple( response_future.result() for response_future in response_futures) @@ -380,8 +380,8 @@ class RPCTest(unittest.TestCase): inner_response_future = multi_callable.future( request, metadata=( - (b'test', - b'WaitingForSomeButNotAllConcurrentFutureInvocations'),)) + ('test', + 'WaitingForSomeButNotAllConcurrentFutureInvocations'),)) outer_response_future = pool.submit(wrap_future(inner_response_future)) response_futures[index] = outer_response_future @@ -400,7 +400,7 @@ class RPCTest(unittest.TestCase): response_iterator = multi_callable( request, metadata=( - (b'test', b'ConsumingOneStreamResponseUnaryRequest'),)) + ('test', 'ConsumingOneStreamResponseUnaryRequest'),)) next(response_iterator) def testConsumingSomeButNotAllStreamResponsesUnaryRequest(self): @@ -410,7 +410,7 @@ class RPCTest(unittest.TestCase): response_iterator = multi_callable( request, metadata=( - (b'test', b'ConsumingSomeButNotAllStreamResponsesUnaryRequest'),)) + ('test', 'ConsumingSomeButNotAllStreamResponsesUnaryRequest'),)) for _ in range(test_constants.STREAM_LENGTH // 2): next(response_iterator) @@ -422,7 +422,7 @@ class RPCTest(unittest.TestCase): response_iterator = multi_callable( request_iterator, metadata=( - (b'test', b'ConsumingSomeButNotAllStreamResponsesStreamRequest'),)) + ('test', 'ConsumingSomeButNotAllStreamResponsesStreamRequest'),)) for _ in range(test_constants.STREAM_LENGTH // 2): next(response_iterator) @@ -434,7 +434,7 @@ class RPCTest(unittest.TestCase): response_iterator = multi_callable( request_iterator, metadata=( - (b'test', b'ConsumingTooManyStreamResponsesStreamRequest'),)) + ('test', 'ConsumingTooManyStreamResponsesStreamRequest'),)) for _ in range(test_constants.STREAM_LENGTH): next(response_iterator) for _ in range(test_constants.STREAM_LENGTH): @@ -453,7 +453,7 @@ class RPCTest(unittest.TestCase): with self._control.pause(): response_future = multi_callable.future( request, - metadata=((b'test', b'CancelledUnaryRequestUnaryResponse'),)) + metadata=(('test', 'CancelledUnaryRequestUnaryResponse'),)) response_future.cancel() self.assertTrue(response_future.cancelled()) @@ -468,7 +468,7 @@ class RPCTest(unittest.TestCase): with self._control.pause(): response_iterator = multi_callable( request, - metadata=((b'test', b'CancelledUnaryRequestStreamResponse'),)) + metadata=(('test', 'CancelledUnaryRequestStreamResponse'),)) self._control.block_until_paused() response_iterator.cancel() @@ -488,7 +488,7 @@ class RPCTest(unittest.TestCase): with self._control.pause(): response_future = multi_callable.future( request_iterator, - metadata=((b'test', b'CancelledStreamRequestUnaryResponse'),)) + metadata=(('test', 'CancelledStreamRequestUnaryResponse'),)) self._control.block_until_paused() response_future.cancel() @@ -508,7 +508,7 @@ class RPCTest(unittest.TestCase): with self._control.pause(): response_iterator = multi_callable( request_iterator, - metadata=((b'test', b'CancelledStreamRequestStreamResponse'),)) + metadata=(('test', 'CancelledStreamRequestStreamResponse'),)) response_iterator.cancel() with self.assertRaises(grpc.RpcError): @@ -526,7 +526,7 @@ class RPCTest(unittest.TestCase): with self.assertRaises(grpc.RpcError) as exception_context: multi_callable.with_call( request, timeout=test_constants.SHORT_TIMEOUT, - metadata=((b'test', b'ExpiredUnaryRequestBlockingUnaryResponse'),)) + metadata=(('test', 'ExpiredUnaryRequestBlockingUnaryResponse'),)) self.assertIsNotNone(exception_context.exception.initial_metadata()) self.assertIs( @@ -542,7 +542,7 @@ class RPCTest(unittest.TestCase): with self._control.pause(): response_future = multi_callable.future( request, timeout=test_constants.SHORT_TIMEOUT, - metadata=((b'test', b'ExpiredUnaryRequestFutureUnaryResponse'),)) + metadata=(('test', 'ExpiredUnaryRequestFutureUnaryResponse'),)) response_future.add_done_callback(callback) value_passed_to_callback = callback.value() @@ -567,7 +567,7 @@ class RPCTest(unittest.TestCase): with self.assertRaises(grpc.RpcError) as exception_context: response_iterator = multi_callable( request, timeout=test_constants.SHORT_TIMEOUT, - metadata=((b'test', b'ExpiredUnaryRequestStreamResponse'),)) + metadata=(('test', 'ExpiredUnaryRequestStreamResponse'),)) next(response_iterator) self.assertIs( @@ -583,7 +583,7 @@ class RPCTest(unittest.TestCase): with self.assertRaises(grpc.RpcError) as exception_context: multi_callable( request_iterator, timeout=test_constants.SHORT_TIMEOUT, - metadata=((b'test', b'ExpiredStreamRequestBlockingUnaryResponse'),)) + metadata=(('test', 'ExpiredStreamRequestBlockingUnaryResponse'),)) self.assertIsNotNone(exception_context.exception.initial_metadata()) self.assertIs( @@ -600,7 +600,7 @@ class RPCTest(unittest.TestCase): with self._control.pause(): response_future = multi_callable.future( request_iterator, timeout=test_constants.SHORT_TIMEOUT, - metadata=((b'test', b'ExpiredStreamRequestFutureUnaryResponse'),)) + metadata=(('test', 'ExpiredStreamRequestFutureUnaryResponse'),)) response_future.add_done_callback(callback) value_passed_to_callback = callback.value() @@ -625,7 +625,7 @@ class RPCTest(unittest.TestCase): with self.assertRaises(grpc.RpcError) as exception_context: response_iterator = multi_callable( request_iterator, timeout=test_constants.SHORT_TIMEOUT, - metadata=((b'test', b'ExpiredStreamRequestStreamResponse'),)) + metadata=(('test', 'ExpiredStreamRequestStreamResponse'),)) next(response_iterator) self.assertIs( @@ -640,7 +640,7 @@ class RPCTest(unittest.TestCase): with self.assertRaises(grpc.RpcError) as exception_context: multi_callable.with_call( request, - metadata=((b'test', b'FailedUnaryRequestBlockingUnaryResponse'),)) + metadata=(('test', 'FailedUnaryRequestBlockingUnaryResponse'),)) self.assertIs(grpc.StatusCode.UNKNOWN, exception_context.exception.code()) @@ -652,7 +652,7 @@ class RPCTest(unittest.TestCase): with self._control.fail(): response_future = multi_callable.future( request, - metadata=((b'test', b'FailedUnaryRequestFutureUnaryResponse'),)) + metadata=(('test', 'FailedUnaryRequestFutureUnaryResponse'),)) response_future.add_done_callback(callback) value_passed_to_callback = callback.value() @@ -672,7 +672,7 @@ class RPCTest(unittest.TestCase): with self._control.fail(): response_iterator = multi_callable( request, - metadata=((b'test', b'FailedUnaryRequestStreamResponse'),)) + metadata=(('test', 'FailedUnaryRequestStreamResponse'),)) next(response_iterator) self.assertIs(grpc.StatusCode.UNKNOWN, exception_context.exception.code()) @@ -686,7 +686,7 @@ class RPCTest(unittest.TestCase): with self.assertRaises(grpc.RpcError) as exception_context: multi_callable( request_iterator, - metadata=((b'test', b'FailedStreamRequestBlockingUnaryResponse'),)) + metadata=(('test', 'FailedStreamRequestBlockingUnaryResponse'),)) self.assertIs(grpc.StatusCode.UNKNOWN, exception_context.exception.code()) @@ -699,7 +699,7 @@ class RPCTest(unittest.TestCase): with self._control.fail(): response_future = multi_callable.future( request_iterator, - metadata=((b'test', b'FailedStreamRequestFutureUnaryResponse'),)) + metadata=(('test', 'FailedStreamRequestFutureUnaryResponse'),)) response_future.add_done_callback(callback) value_passed_to_callback = callback.value() @@ -720,7 +720,7 @@ class RPCTest(unittest.TestCase): with self.assertRaises(grpc.RpcError) as exception_context: response_iterator = multi_callable( request_iterator, - metadata=((b'test', b'FailedStreamRequestStreamResponse'),)) + metadata=(('test', 'FailedStreamRequestStreamResponse'),)) tuple(response_iterator) self.assertIs(grpc.StatusCode.UNKNOWN, exception_context.exception.code()) @@ -732,7 +732,7 @@ class RPCTest(unittest.TestCase): multi_callable = _unary_unary_multi_callable(self._channel) multi_callable.future( request, - metadata=((b'test', b'IgnoredUnaryRequestFutureUnaryResponse'),)) + metadata=(('test', 'IgnoredUnaryRequestFutureUnaryResponse'),)) def testIgnoredUnaryRequestStreamResponse(self): request = b'\x37\x17' @@ -740,7 +740,7 @@ class RPCTest(unittest.TestCase): multi_callable = _unary_stream_multi_callable(self._channel) multi_callable( request, - metadata=((b'test', b'IgnoredUnaryRequestStreamResponse'),)) + metadata=(('test', 'IgnoredUnaryRequestStreamResponse'),)) def testIgnoredStreamRequestFutureUnaryResponse(self): requests = tuple(b'\x07\x18' for _ in range(test_constants.STREAM_LENGTH)) @@ -749,7 +749,7 @@ class RPCTest(unittest.TestCase): multi_callable = _stream_unary_multi_callable(self._channel) multi_callable.future( request_iterator, - metadata=((b'test', b'IgnoredStreamRequestFutureUnaryResponse'),)) + metadata=(('test', 'IgnoredStreamRequestFutureUnaryResponse'),)) def testIgnoredStreamRequestStreamResponse(self): requests = tuple(b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH)) @@ -758,7 +758,7 @@ class RPCTest(unittest.TestCase): multi_callable = _stream_stream_multi_callable(self._channel) multi_callable( request_iterator, - metadata=((b'test', b'IgnoredStreamRequestStreamResponse'),)) + metadata=(('test', 'IgnoredStreamRequestStreamResponse'),)) if __name__ == '__main__': diff --git a/src/python/grpcio/tests/unit/beta/_connectivity_channel_test.py b/src/python/grpcio/tests/unit/beta/_connectivity_channel_test.py index 488f7d7141..5d826a269d 100644 --- a/src/python/grpcio/tests/unit/beta/_connectivity_channel_test.py +++ b/src/python/grpcio/tests/unit/beta/_connectivity_channel_test.py @@ -29,162 +29,9 @@ """Tests of grpc.beta._connectivity_channel.""" -import threading -import time import unittest -from grpc._adapter import _low -from grpc._adapter import _types -from grpc.beta import _connectivity_channel from grpc.beta import interfaces -from tests.unit.framework.common import test_constants - - -def _drive_completion_queue(completion_queue): - while True: - event = completion_queue.next(time.time() + 24 * 60 * 60) - if event.type == _types.EventType.QUEUE_SHUTDOWN: - break - - -class _Callback(object): - - def __init__(self): - self._condition = threading.Condition() - self._connectivities = [] - - def update(self, connectivity): - with self._condition: - self._connectivities.append(connectivity) - self._condition.notify() - - def connectivities(self): - with self._condition: - return tuple(self._connectivities) - - def block_until_connectivities_satisfy(self, predicate): - with self._condition: - while True: - connectivities = tuple(self._connectivities) - if predicate(connectivities): - return connectivities - else: - self._condition.wait() - - -class ChannelConnectivityTest(unittest.TestCase): - - def test_lonely_channel_connectivity(self): - low_channel = _low.Channel('localhost:12345', ()) - callback = _Callback() - - connectivity_channel = _connectivity_channel.ConnectivityChannel( - low_channel) - connectivity_channel.subscribe(callback.update, try_to_connect=False) - first_connectivities = callback.block_until_connectivities_satisfy(bool) - connectivity_channel.subscribe(callback.update, try_to_connect=True) - second_connectivities = callback.block_until_connectivities_satisfy( - lambda connectivities: 2 <= len(connectivities)) - # Wait for a connection that will never happen. - time.sleep(test_constants.SHORT_TIMEOUT) - third_connectivities = callback.connectivities() - connectivity_channel.unsubscribe(callback.update) - fourth_connectivities = callback.connectivities() - connectivity_channel.unsubscribe(callback.update) - fifth_connectivities = callback.connectivities() - - self.assertSequenceEqual( - (interfaces.ChannelConnectivity.IDLE,), first_connectivities) - self.assertNotIn( - interfaces.ChannelConnectivity.READY, second_connectivities) - self.assertNotIn( - interfaces.ChannelConnectivity.READY, third_connectivities) - self.assertNotIn( - interfaces.ChannelConnectivity.READY, fourth_connectivities) - self.assertNotIn( - interfaces.ChannelConnectivity.READY, fifth_connectivities) - - def test_immediately_connectable_channel_connectivity(self): - server_completion_queue = _low.CompletionQueue() - server = _low.Server(server_completion_queue, []) - port = server.add_http2_port('[::]:0') - server.start() - server_completion_queue_thread = threading.Thread( - target=_drive_completion_queue, args=(server_completion_queue,)) - server_completion_queue_thread.start() - low_channel = _low.Channel('localhost:%d' % port, ()) - first_callback = _Callback() - second_callback = _Callback() - - connectivity_channel = _connectivity_channel.ConnectivityChannel( - low_channel) - connectivity_channel.subscribe(first_callback.update, try_to_connect=False) - first_connectivities = first_callback.block_until_connectivities_satisfy( - bool) - # Wait for a connection that will never happen because try_to_connect=True - # has not yet been passed. - time.sleep(test_constants.SHORT_TIMEOUT) - second_connectivities = first_callback.connectivities() - connectivity_channel.subscribe(second_callback.update, try_to_connect=True) - third_connectivities = first_callback.block_until_connectivities_satisfy( - lambda connectivities: 2 <= len(connectivities)) - fourth_connectivities = second_callback.block_until_connectivities_satisfy( - bool) - # Wait for a connection that will happen (or may already have happened). - first_callback.block_until_connectivities_satisfy( - lambda connectivities: - interfaces.ChannelConnectivity.READY in connectivities) - second_callback.block_until_connectivities_satisfy( - lambda connectivities: - interfaces.ChannelConnectivity.READY in connectivities) - connectivity_channel.unsubscribe(first_callback.update) - connectivity_channel.unsubscribe(second_callback.update) - - server.shutdown() - server_completion_queue.shutdown() - server_completion_queue_thread.join() - - self.assertSequenceEqual( - (interfaces.ChannelConnectivity.IDLE,), first_connectivities) - self.assertSequenceEqual( - (interfaces.ChannelConnectivity.IDLE,), second_connectivities) - self.assertNotIn( - interfaces.ChannelConnectivity.TRANSIENT_FAILURE, third_connectivities) - self.assertNotIn( - interfaces.ChannelConnectivity.FATAL_FAILURE, third_connectivities) - self.assertNotIn( - interfaces.ChannelConnectivity.TRANSIENT_FAILURE, - fourth_connectivities) - self.assertNotIn( - interfaces.ChannelConnectivity.FATAL_FAILURE, fourth_connectivities) - - def test_reachable_then_unreachable_channel_connectivity(self): - server_completion_queue = _low.CompletionQueue() - server = _low.Server(server_completion_queue, []) - port = server.add_http2_port('[::]:0') - server.start() - server_completion_queue_thread = threading.Thread( - target=_drive_completion_queue, args=(server_completion_queue,)) - server_completion_queue_thread.start() - low_channel = _low.Channel('localhost:%d' % port, ()) - callback = _Callback() - - connectivity_channel = _connectivity_channel.ConnectivityChannel( - low_channel) - connectivity_channel.subscribe(callback.update, try_to_connect=True) - callback.block_until_connectivities_satisfy( - lambda connectivities: - interfaces.ChannelConnectivity.READY in connectivities) - # Now take down the server and confirm that channel readiness is repudiated. - server.shutdown() - callback.block_until_connectivities_satisfy( - lambda connectivities: - connectivities[-1] is not interfaces.ChannelConnectivity.READY) - connectivity_channel.unsubscribe(callback.update) - - server.shutdown() - server_completion_queue.shutdown() - server_completion_queue_thread.join() class ConnectivityStatesTest(unittest.TestCase): diff --git a/src/python/grpcio/tests/unit/beta/_utilities_test.py b/src/python/grpcio/tests/unit/beta/_utilities_test.py index 08ce98e751..90fe10c77c 100644 --- a/src/python/grpcio/tests/unit/beta/_utilities_test.py +++ b/src/python/grpcio/tests/unit/beta/_utilities_test.py @@ -33,21 +33,12 @@ import threading import time import unittest -from grpc._adapter import _low -from grpc._adapter import _types from grpc.beta import implementations from grpc.beta import utilities from grpc.framework.foundation import future from tests.unit.framework.common import test_constants -def _drive_completion_queue(completion_queue): - while True: - event = completion_queue.next(time.time() + 24 * 60 * 60) - if event.type == _types.EventType.QUEUE_SHUTDOWN: - break - - class _Callback(object): def __init__(self): @@ -87,13 +78,9 @@ class ChannelConnectivityTest(unittest.TestCase): self.assertFalse(ready_future.running()) def test_immediately_connectable_channel_connectivity(self): - server_completion_queue = _low.CompletionQueue() - server = _low.Server(server_completion_queue, []) - port = server.add_http2_port('[::]:0') + server = implementations.server({}) + port = server.add_insecure_port('[::]:0') server.start() - server_completion_queue_thread = threading.Thread( - target=_drive_completion_queue, args=(server_completion_queue,)) - server_completion_queue_thread.start() channel = implementations.insecure_channel('localhost', port) callback = _Callback() @@ -114,9 +101,7 @@ class ChannelConnectivityTest(unittest.TestCase): self.assertFalse(ready_future.running()) finally: ready_future.cancel() - server.shutdown() - server_completion_queue.shutdown() - server_completion_queue_thread.join() + server.stop(0) if __name__ == '__main__': diff --git a/src/python/grpcio/tests/unit/beta/test_utilities.py b/src/python/grpcio/tests/unit/beta/test_utilities.py index 8ccad04e05..692da9c97d 100644 --- a/src/python/grpcio/tests/unit/beta/test_utilities.py +++ b/src/python/grpcio/tests/unit/beta/test_utilities.py @@ -51,5 +51,5 @@ def not_really_secure_channel( target = '%s:%d' % (host, port) channel = grpc.secure_channel( target, channel_credentials, - ((b'grpc.ssl_target_name_override', server_host_override,),)) + (('grpc.ssl_target_name_override', server_host_override,),)) return implementations.Channel(channel) diff --git a/src/python/grpcio/tests/unit/test_common.py b/src/python/grpcio/tests/unit/test_common.py index b779f65e7e..c8886bf4ca 100644 --- a/src/python/grpcio/tests/unit/test_common.py +++ b/src/python/grpcio/tests/unit/test_common.py @@ -33,10 +33,10 @@ import collections import six -INVOCATION_INITIAL_METADATA = ((b'0', b'abc'), (b'1', b'def'), (b'2', b'ghi'),) -SERVICE_INITIAL_METADATA = ((b'3', b'jkl'), (b'4', b'mno'), (b'5', b'pqr'),) -SERVICE_TERMINAL_METADATA = ((b'6', b'stu'), (b'7', b'vwx'), (b'8', b'yza'),) -DETAILS = b'test details' +INVOCATION_INITIAL_METADATA = (('0', 'abc'), ('1', 'def'), ('2', 'ghi'),) +SERVICE_INITIAL_METADATA = (('3', 'jkl'), ('4', 'mno'), ('5', 'pqr'),) +SERVICE_TERMINAL_METADATA = (('6', 'stu'), ('7', 'vwx'), ('8', 'yza'),) +DETAILS = 'test details' def metadata_transmitted(original_metadata, transmitted_metadata): @@ -59,16 +59,10 @@ def metadata_transmitted(original_metadata, transmitted_metadata): original_metadata after having been transmitted via gRPC. """ original = collections.defaultdict(list) - for key_value_pair in original_metadata: - key, value = tuple(key_value_pair) - if not isinstance(key, bytes): - key = key.encode() - if not isinstance(value, bytes): - value = value.encode() + for key, value in original_metadata: original[key].append(value) transmitted = collections.defaultdict(list) - for key_value_pair in transmitted_metadata: - key, value = tuple(key_value_pair) + for key, value in transmitted_metadata: transmitted[key].append(value) for key, values in six.iteritems(original): |