aboutsummaryrefslogtreecommitdiffhomepage
path: root/src/python/grpcio
diff options
context:
space:
mode:
authorGravatar Mark D. Roth <roth@google.com>2016-07-01 14:17:55 -0700
committerGravatar Mark D. Roth <roth@google.com>2016-07-01 14:17:55 -0700
commit98aa0c13f0f31584658dd7ad565f7a9a7a866f5e (patch)
tree4b5d3cbfd668fea862882335fde156b358e39de4 /src/python/grpcio
parent97b173dfb8d10bc68dcffa135762d2d152723bc6 (diff)
parent3c945ee2b3c3bf0fd01cc995332e252d1e10e51e (diff)
Merge branch 'filter_call_init_failure' into filter_api
Diffstat (limited to 'src/python/grpcio')
-rw-r--r--src/python/grpcio/README.rst2
-rw-r--r--src/python/grpcio/grpc/_channel.py111
-rw-r--r--src/python/grpcio/grpc/_common.py43
-rw-r--r--src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi5
-rw-r--r--src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi8
-rw-r--r--src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi1
-rw-r--r--src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi36
-rw-r--r--src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi2
-rw-r--r--src/python/grpcio/grpc/_cython/cygrpc.pyx34
-rw-r--r--src/python/grpcio/grpc/_cython/imports.generated.c2
-rw-r--r--src/python/grpcio/grpc/_cython/imports.generated.h3
-rw-r--r--src/python/grpcio/grpc/_cython/loader.c7
-rw-r--r--src/python/grpcio/grpc/_cython/loader.h5
-rw-r--r--src/python/grpcio/grpc/_plugin_wrapping.py14
-rw-r--r--src/python/grpcio/grpc/_server.py42
-rw-r--r--src/python/grpcio/grpc/beta/_server_adaptations.py41
-rw-r--r--src/python/grpcio/grpc_core_dependencies.py2
-rw-r--r--src/python/grpcio/grpc_version.py2
-rw-r--r--src/python/grpcio/tests/interop/methods.py9
-rw-r--r--src/python/grpcio/tests/tests.json4
-rw-r--r--src/python/grpcio/tests/unit/_compression_test.py133
-rw-r--r--src/python/grpcio/tests/unit/_cython/_cancel_many_calls_test.py4
-rw-r--r--src/python/grpcio/tests/unit/_cython/_channel_test.py2
-rw-r--r--src/python/grpcio/tests/unit/_cython/_read_some_but_not_all_responses_test.py4
-rw-r--r--src/python/grpcio/tests/unit/_cython/cygrpc_test.py40
-rw-r--r--src/python/grpcio/tests/unit/_empty_message_test.py8
-rw-r--r--src/python/grpcio/tests/unit/_exit_scenarios.py249
-rw-r--r--src/python/grpcio/tests/unit/_exit_test.py185
-rw-r--r--src/python/grpcio/tests/unit/_metadata_code_details_test.py523
-rw-r--r--src/python/grpcio/tests/unit/_metadata_test.py24
-rw-r--r--src/python/grpcio/tests/unit/_rpc_test.py96
-rw-r--r--src/python/grpcio/tests/unit/beta/_connectivity_channel_test.py153
-rw-r--r--src/python/grpcio/tests/unit/beta/_utilities_test.py21
-rw-r--r--src/python/grpcio/tests/unit/beta/test_utilities.py2
-rw-r--r--src/python/grpcio/tests/unit/test_common.py18
35 files changed, 1406 insertions, 429 deletions
diff --git a/src/python/grpcio/README.rst b/src/python/grpcio/README.rst
index afc4fe6a37..3fc318539e 100644
--- a/src/python/grpcio/README.rst
+++ b/src/python/grpcio/README.rst
@@ -46,7 +46,7 @@ package named :code:`python-dev`).
::
$ export REPO_ROOT=grpc # REPO_ROOT can be any directory of your choice
- $ git clone https://github.com/grpc/grpc.git $REPO_ROOT
+ $ git clone -b $(curl -L http://grpc.io/release) https://github.com/grpc/grpc $REPO_ROOT
$ cd $REPO_ROOT
$ git submodule update --init
diff --git a/src/python/grpcio/grpc/_channel.py b/src/python/grpcio/grpc/_channel.py
index 1f34beeb2c..a89b501303 100644
--- a/src/python/grpcio/grpc/_channel.py
+++ b/src/python/grpcio/grpc/_channel.py
@@ -179,6 +179,7 @@ def _event_handler(state, call, response_deserializer):
def _consume_request_iterator(
request_iterator, state, call, request_serializer):
event_handler = _event_handler(state, call, None)
+
def consume_request_iterator():
for request in request_iterator:
serialized_request = _common.serialize(request, request_serializer)
@@ -212,8 +213,18 @@ def _consume_request_iterator(
)
call.start_batch(cygrpc.Operations(operations), event_handler)
state.due.add(cygrpc.OperationType.send_close_from_client)
- thread = threading.Thread(target=consume_request_iterator)
- thread.start()
+
+ def stop_consumption_thread(timeout):
+ with state.condition:
+ if state.code is None:
+ call.cancel()
+ state.cancelled = True
+ _abort(state, grpc.StatusCode.CANCELLED, 'Cancelled!')
+ state.condition.notify_all()
+
+ consumption_thread = _common.CleanupThread(
+ stop_consumption_thread, target=consume_request_iterator)
+ consumption_thread.start()
class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call):
@@ -353,13 +364,13 @@ class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call):
with self._state.condition:
while self._state.initial_metadata is None:
self._state.condition.wait()
- return self._state.initial_metadata
+ return _common.application_metadata(self._state.initial_metadata)
def trailing_metadata(self):
with self._state.condition:
while self._state.trailing_metadata is None:
self._state.condition.wait()
- return self._state.trailing_metadata
+ return _common.application_metadata(self._state.trailing_metadata)
def code(self):
with self._state.condition:
@@ -371,7 +382,7 @@ class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call):
with self._state.condition:
while self._state.details is None:
self._state.condition.wait()
- return self._state.details
+ return _common.decode(self._state.details)
def _repr(self):
with self._state.condition:
@@ -379,7 +390,7 @@ class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call):
return '<_Rendezvous object of in-flight RPC>'
else:
return '<_Rendezvous of RPC that terminated with ({}, {})>'.format(
- self._state.code, self._state.details)
+ self._state.code, _common.decode(self._state.details))
def __repr__(self):
return self._repr()
@@ -440,7 +451,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
state = _RPCState(_UNARY_UNARY_INITIAL_DUE, None, None, None, None)
operations = (
cygrpc.operation_send_initial_metadata(
- _common.metadata(metadata), _EMPTY_FLAGS),
+ _common.cygrpc_metadata(metadata), _EMPTY_FLAGS),
cygrpc.operation_send_message(serialized_request, _EMPTY_FLAGS),
cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
@@ -518,7 +529,7 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
event_handler)
operations = (
cygrpc.operation_send_initial_metadata(
- _common.metadata(metadata), _EMPTY_FLAGS),
+ _common.cygrpc_metadata(metadata), _EMPTY_FLAGS),
cygrpc.operation_send_message(serialized_request, _EMPTY_FLAGS),
cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
@@ -553,7 +564,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
None)
operations = (
cygrpc.operation_send_initial_metadata(
- _common.metadata(metadata), _EMPTY_FLAGS),
+ _common.cygrpc_metadata(metadata), _EMPTY_FLAGS),
cygrpc.operation_receive_message(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
)
@@ -597,7 +608,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
event_handler)
operations = (
cygrpc.operation_send_initial_metadata(
- _common.metadata(metadata), _EMPTY_FLAGS),
+ _common.cygrpc_metadata(metadata), _EMPTY_FLAGS),
cygrpc.operation_receive_message(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
)
@@ -634,7 +645,7 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
event_handler)
operations = (
cygrpc.operation_send_initial_metadata(
- _common.metadata(metadata), _EMPTY_FLAGS),
+ _common.cygrpc_metadata(metadata), _EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
)
call.start_batch(cygrpc.Operations(operations), event_handler)
@@ -652,16 +663,27 @@ class _ChannelCallState(object):
self.managed_calls = None
-def _call_spin(state):
- while True:
- event = state.completion_queue.poll()
- completed_call = event.tag(event)
- if completed_call is not None:
- with state.lock:
- state.managed_calls.remove(completed_call)
- if not state.managed_calls:
- state.managed_calls = None
- return
+def _run_channel_spin_thread(state):
+ def channel_spin():
+ while True:
+ event = state.completion_queue.poll()
+ completed_call = event.tag(event)
+ if completed_call is not None:
+ with state.lock:
+ state.managed_calls.remove(completed_call)
+ if not state.managed_calls:
+ state.managed_calls = None
+ return
+
+ def stop_channel_spin(timeout):
+ with state.lock:
+ if state.managed_calls is not None:
+ for call in state.managed_calls:
+ call.cancel()
+
+ channel_spin_thread = _common.CleanupThread(
+ stop_channel_spin, target=channel_spin)
+ channel_spin_thread.start()
def _create_channel_managed_call(state):
@@ -690,8 +712,7 @@ def _create_channel_managed_call(state):
parent, flags, state.completion_queue, method, host, deadline)
if state.managed_calls is None:
state.managed_calls = set((call,))
- spin_thread = threading.Thread(target=_call_spin, args=(state,))
- spin_thread.start()
+ _run_channel_spin_thread(state)
else:
state.managed_calls.add(call)
return call
@@ -784,11 +805,18 @@ def _poll_connectivity(state, channel, initial_try_to_connect):
_spawn_delivery(state, callbacks)
+def _moot(state):
+ with state.lock:
+ del state.callbacks_and_connectivities[:]
+
+
def _subscribe(state, callback, try_to_connect):
with state.lock:
if not state.callbacks_and_connectivities and not state.polling:
- polling_thread = threading.Thread(
- target=_poll_connectivity,
+ def cancel_all_subscriptions(timeout):
+ _moot(state)
+ polling_thread = _common.CleanupThread(
+ cancel_all_subscriptions, target=_poll_connectivity,
args=(state, state.channel, bool(try_to_connect)))
polling_thread.start()
state.polling = True
@@ -812,19 +840,19 @@ def _unsubscribe(state, callback):
break
-def _moot(state):
- with state.lock:
- del state.callbacks_and_connectivities[:]
-
-
def _options(options):
if options is None:
pairs = ((cygrpc.ChannelArgKey.primary_user_agent_string, _USER_AGENT),)
else:
pairs = list(options) + [
(cygrpc.ChannelArgKey.primary_user_agent_string, _USER_AGENT)]
- return cygrpc.ChannelArgs(
- cygrpc.ChannelArg(arg_name, arg_value) for arg_name, arg_value in pairs)
+ encoded_pairs = [
+ (_common.encode(arg_name), arg_value) if isinstance(arg_value, int)
+ else (_common.encode(arg_name), _common.encode(arg_value))
+ for arg_name, arg_value in pairs]
+ return cygrpc.ChannelArgs([
+ cygrpc.ChannelArg(arg_name, arg_value)
+ for arg_name, arg_value in encoded_pairs])
class Channel(grpc.Channel):
@@ -837,7 +865,8 @@ class Channel(grpc.Channel):
options: Configuration options for the channel.
credentials: A cygrpc.ChannelCredentials or None.
"""
- self._channel = cygrpc.Channel(target, _options(options), credentials)
+ self._channel = cygrpc.Channel(
+ _common.encode(target), _options(options), credentials)
self._call_state = _ChannelCallState(self._channel)
self._connectivity_state = _ChannelConnectivityState(self._channel)
@@ -850,26 +879,26 @@ class Channel(grpc.Channel):
def unary_unary(
self, method, request_serializer=None, response_deserializer=None):
return _UnaryUnaryMultiCallable(
- self._channel, _create_channel_managed_call(self._call_state), method,
- request_serializer, response_deserializer)
+ self._channel, _create_channel_managed_call(self._call_state),
+ _common.encode(method), request_serializer, response_deserializer)
def unary_stream(
self, method, request_serializer=None, response_deserializer=None):
return _UnaryStreamMultiCallable(
- self._channel, _create_channel_managed_call(self._call_state), method,
- request_serializer, response_deserializer)
+ self._channel, _create_channel_managed_call(self._call_state),
+ _common.encode(method), request_serializer, response_deserializer)
def stream_unary(
self, method, request_serializer=None, response_deserializer=None):
return _StreamUnaryMultiCallable(
- self._channel, _create_channel_managed_call(self._call_state), method,
- request_serializer, response_deserializer)
+ self._channel, _create_channel_managed_call(self._call_state),
+ _common.encode(method), request_serializer, response_deserializer)
def stream_stream(
self, method, request_serializer=None, response_deserializer=None):
return _StreamStreamMultiCallable(
- self._channel, _create_channel_managed_call(self._call_state), method,
- request_serializer, response_deserializer)
+ self._channel, _create_channel_managed_call(self._call_state),
+ _common.encode(method), request_serializer, response_deserializer)
def __del__(self):
_moot(self._connectivity_state)
diff --git a/src/python/grpcio/grpc/_common.py b/src/python/grpcio/grpc/_common.py
index f351bea9e3..4d7d521419 100644
--- a/src/python/grpcio/grpc/_common.py
+++ b/src/python/grpcio/grpc/_common.py
@@ -76,9 +76,37 @@ STATUS_CODE_TO_CYGRPC_STATUS_CODE = {
}
-def metadata(application_metadata):
+def encode(s):
+ if isinstance(s, bytes):
+ return s
+ else:
+ return s.encode('ascii')
+
+
+def decode(b):
+ if isinstance(b, str):
+ return b
+ else:
+ try:
+ return b.decode('utf8')
+ except UnicodeDecodeError:
+ logging.exception('Invalid encoding on {}'.format(b))
+ return b.decode('latin1')
+
+
+def cygrpc_metadata(application_metadata):
return _EMPTY_METADATA if application_metadata is None else cygrpc.Metadata(
- cygrpc.Metadatum(key, value) for key, value in application_metadata)
+ cygrpc.Metadatum(encode(key), encode(value))
+ for key, value in application_metadata)
+
+
+def application_metadata(cygrpc_metadata):
+ if cygrpc_metadata is None:
+ return ()
+ else:
+ return tuple(
+ (decode(key), value if key[-4:] == b'-bin' else decode(value))
+ for key, value in cygrpc_metadata)
def _transform(message, transformer, exception_message):
@@ -101,17 +129,8 @@ def deserialize(serialized_message, deserializer):
'Exception deserializing message!')
-def _encode(s):
- if isinstance(s, bytes):
- return s
- else:
- return s.encode('ascii')
-
-
def fully_qualified_method(group, method):
- group = _encode(group)
- method = _encode(method)
- return b'/' + group + b'/' + method
+ return '/{}/{}'.format(group, method)
class CleanupThread(threading.Thread):
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi
index 866cff0d01..1406696510 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi
@@ -32,9 +32,8 @@ cimport cpython
cdef class Channel:
- def __cinit__(self, target, ChannelArgs arguments=None,
+ def __cinit__(self, bytes target, ChannelArgs arguments=None,
ChannelCredentials channel_credentials=None):
- target = str_to_bytes(target)
cdef grpc_channel_args *c_arguments = NULL
cdef char *c_target = NULL
self.c_channel = NULL
@@ -57,8 +56,6 @@ cdef class Channel:
def create_call(self, Call parent, int flags,
CompletionQueue queue not None,
method, host, Timespec deadline not None):
- method = str_to_bytes(method)
- host = str_to_bytes(host)
if queue.is_shutting_down:
raise ValueError("queue must not be shutting down or shutdown")
cdef char *method_c_string = method
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi
index 470382d609..b24e69243e 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi
@@ -82,7 +82,7 @@ cdef class ServerCredentials:
cdef class CredentialsMetadataPlugin:
- def __cinit__(self, object plugin_callback, name):
+ def __cinit__(self, object plugin_callback, bytes name):
"""
Args:
plugin_callback (callable): Callback accepting a service URL (str/bytes)
@@ -91,9 +91,8 @@ cdef class CredentialsMetadataPlugin:
when called should be non-blocking and eventually call the callback
object with the appropriate status code/details and metadata (if
successful).
- name (str): Plugin name.
+ name (bytes): Plugin name.
"""
- name = str_to_bytes(name)
if not callable(plugin_callback):
raise ValueError('expected callable plugin_callback')
self.plugin_callback = plugin_callback
@@ -130,8 +129,7 @@ cdef void plugin_get_metadata(
grpc_credentials_plugin_metadata_cb cb, void *user_data) with gil:
def python_callback(
Metadata metadata, grpc_status_code status,
- error_details):
- error_details = str_to_bytes(error_details)
+ bytes error_details):
cb(user_data, metadata.c_metadata_array.metadata,
metadata.c_metadata_array.count, status, error_details)
cdef CredentialsMetadataPlugin self = <CredentialsMetadataPlugin>state
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi
index 168b9751aa..f3b3d61273 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi
@@ -37,6 +37,7 @@ cdef extern from "grpc/_cython/loader.h":
ctypedef long int64_t
int pygrpc_load_core(char*)
+ int pygrpc_initialize_core()
void *gpr_malloc(size_t size) nogil
void gpr_free(void *ptr) nogil
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi
index 0055d0d3a2..8e651e880f 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi
@@ -231,17 +231,10 @@ cdef class Event:
cdef class ByteBuffer:
- def __cinit__(self, data):
+ def __cinit__(self, bytes data):
if data is None:
self.c_byte_buffer = NULL
return
- if isinstance(data, ByteBuffer):
- data = (<ByteBuffer>data).bytes()
- if data is None:
- self.c_byte_buffer = NULL
- return
- else:
- data = str_to_bytes(data)
cdef char *c_data = data
cdef gpr_slice data_slice
@@ -296,26 +289,28 @@ cdef class ByteBuffer:
cdef class SslPemKeyCertPair:
- def __cinit__(self, private_key, certificate_chain):
- self.private_key = str_to_bytes(private_key)
- self.certificate_chain = str_to_bytes(certificate_chain)
+ def __cinit__(self, bytes private_key, bytes certificate_chain):
+ self.private_key = private_key
+ self.certificate_chain = certificate_chain
self.c_pair.private_key = self.private_key
self.c_pair.certificate_chain = self.certificate_chain
cdef class ChannelArg:
- def __cinit__(self, key, value):
- self.key = str_to_bytes(key)
+ def __cinit__(self, bytes key, value):
+ self.key = key
self.c_arg.key = self.key
if isinstance(value, int):
- self.value = int(value)
+ self.value = value
self.c_arg.type = GRPC_ARG_INTEGER
self.c_arg.value.integer = self.value
- else:
- self.value = str_to_bytes(value)
+ elif isinstance(value, bytes):
+ self.value = value
self.c_arg.type = GRPC_ARG_STRING
self.c_arg.value.string = self.value
+ else:
+ raise TypeError('Expected int or bytes, got {}'.format(type(value)))
cdef class ChannelArgs:
@@ -347,9 +342,9 @@ cdef class ChannelArgs:
cdef class Metadatum:
- def __cinit__(self, key, value):
- self._key = str_to_bytes(key)
- self._value = str_to_bytes(value)
+ def __cinit__(self, bytes key, bytes value):
+ self._key = key
+ self._value = value
self.c_metadata.key = self._key
self.c_metadata.value = self._value
self.c_metadata.value_length = len(self._value)
@@ -563,8 +558,7 @@ def operation_send_close_from_client(int flags):
return op
def operation_send_status_from_server(
- Metadata metadata, grpc_status_code code, details, int flags):
- details = str_to_bytes(details)
+ Metadata metadata, grpc_status_code code, bytes details, int flags):
cdef Operation op = Operation()
op.c_op.type = GRPC_OP_SEND_STATUS_FROM_SERVER
op.c_op.flags = flags
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi
index 42afeb8498..3e03b6efe1 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi
@@ -101,7 +101,7 @@ cdef class Server:
# Ensure the core has gotten a chance to do the start-up work
self.backup_shutdown_queue.poll(Timespec(None))
- def add_http2_port(self, address,
+ def add_http2_port(self, bytes address,
ServerCredentials server_credentials=None):
address = str_to_bytes(address)
self.references.append(address)
diff --git a/src/python/grpcio/grpc/_cython/cygrpc.pyx b/src/python/grpcio/grpc/_cython/cygrpc.pyx
index cf146f5a04..7a8d0dd8a1 100644
--- a/src/python/grpcio/grpc/_cython/cygrpc.pyx
+++ b/src/python/grpcio/grpc/_cython/cygrpc.pyx
@@ -45,30 +45,22 @@ include "grpc/_cython/_cygrpc/security.pyx.pxi"
include "grpc/_cython/_cygrpc/server.pyx.pxi"
#
-# Global state
+# initialize gRPC
#
-cdef class _ModuleState:
- cdef bint is_loaded
+def _initialize():
+ if 'win32' in sys.platform:
+ filename = pkg_resources.resource_filename(
+ 'grpc._cython', '_windows/grpc_c.64.python')
+ if not isinstance(filename, bytes):
+ filename = filename.encode()
+ if not pygrpc_load_core(filename):
+ raise ImportError('failed to load core gRPC library')
+ if not pygrpc_initialize_core():
+ raise ImportError('failed to initialize core gRPC library')
- def __cinit__(self):
- if 'win32' in sys.platform:
- filename = pkg_resources.resource_filename(
- 'grpc._cython', '_windows/grpc_c.64.python')
- if not pygrpc_load_core(filename):
- raise ImportError('failed to load core gRPC library')
- with nogil:
- grpc_init()
- self.is_loaded = True
- with nogil:
- grpc_set_ssl_roots_override_callback(
+ grpc_set_ssl_roots_override_callback(
<grpc_ssl_roots_override_callback>ssl_roots_override_callback)
- def __dealloc__(self):
- if self.is_loaded:
- with nogil:
- grpc_shutdown()
-
-_module_state = _ModuleState()
-
+_initialize()
diff --git a/src/python/grpcio/grpc/_cython/imports.generated.c b/src/python/grpcio/grpc/_cython/imports.generated.c
index 8437e74ba0..d78ec2f66e 100644
--- a/src/python/grpcio/grpc/_cython/imports.generated.c
+++ b/src/python/grpcio/grpc/_cython/imports.generated.c
@@ -128,6 +128,7 @@ grpc_is_binary_header_type grpc_is_binary_header_import;
grpc_call_error_to_string_type grpc_call_error_to_string_import;
grpc_insecure_channel_create_from_fd_type grpc_insecure_channel_create_from_fd_import;
grpc_server_add_insecure_channel_from_fd_type grpc_server_add_insecure_channel_from_fd_import;
+grpc_use_signal_type grpc_use_signal_import;
grpc_auth_property_iterator_next_type grpc_auth_property_iterator_next_import;
grpc_auth_context_property_iterator_type grpc_auth_context_property_iterator_import;
grpc_auth_context_peer_identity_type grpc_auth_context_peer_identity_import;
@@ -403,6 +404,7 @@ void pygrpc_load_imports(HMODULE library) {
grpc_call_error_to_string_import = (grpc_call_error_to_string_type) GetProcAddress(library, "grpc_call_error_to_string");
grpc_insecure_channel_create_from_fd_import = (grpc_insecure_channel_create_from_fd_type) GetProcAddress(library, "grpc_insecure_channel_create_from_fd");
grpc_server_add_insecure_channel_from_fd_import = (grpc_server_add_insecure_channel_from_fd_type) GetProcAddress(library, "grpc_server_add_insecure_channel_from_fd");
+ grpc_use_signal_import = (grpc_use_signal_type) GetProcAddress(library, "grpc_use_signal");
grpc_auth_property_iterator_next_import = (grpc_auth_property_iterator_next_type) GetProcAddress(library, "grpc_auth_property_iterator_next");
grpc_auth_context_property_iterator_import = (grpc_auth_context_property_iterator_type) GetProcAddress(library, "grpc_auth_context_property_iterator");
grpc_auth_context_peer_identity_import = (grpc_auth_context_peer_identity_type) GetProcAddress(library, "grpc_auth_context_peer_identity");
diff --git a/src/python/grpcio/grpc/_cython/imports.generated.h b/src/python/grpcio/grpc/_cython/imports.generated.h
index d52e8591b3..b3e341fe25 100644
--- a/src/python/grpcio/grpc/_cython/imports.generated.h
+++ b/src/python/grpcio/grpc/_cython/imports.generated.h
@@ -335,6 +335,9 @@ extern grpc_insecure_channel_create_from_fd_type grpc_insecure_channel_create_fr
typedef void(*grpc_server_add_insecure_channel_from_fd_type)(grpc_server *server, grpc_completion_queue *cq, int fd);
extern grpc_server_add_insecure_channel_from_fd_type grpc_server_add_insecure_channel_from_fd_import;
#define grpc_server_add_insecure_channel_from_fd grpc_server_add_insecure_channel_from_fd_import
+typedef void(*grpc_use_signal_type)(int signum);
+extern grpc_use_signal_type grpc_use_signal_import;
+#define grpc_use_signal grpc_use_signal_import
typedef const grpc_auth_property *(*grpc_auth_property_iterator_next_type)(grpc_auth_property_iterator *it);
extern grpc_auth_property_iterator_next_type grpc_auth_property_iterator_next_import;
#define grpc_auth_property_iterator_next grpc_auth_property_iterator_next_import
diff --git a/src/python/grpcio/grpc/_cython/loader.c b/src/python/grpcio/grpc/_cython/loader.c
index b909ad594e..86b70dbb02 100644
--- a/src/python/grpcio/grpc/_cython/loader.c
+++ b/src/python/grpcio/grpc/_cython/loader.c
@@ -31,6 +31,7 @@
*
*/
+#include <Python.h>
#include "loader.h"
#ifdef __cplusplus
@@ -62,6 +63,12 @@ int pygrpc_load_core(char *path) { return 1; }
#endif /* !GPR_WINDOWS */
+// Cython doesn't have Py_AtExit bindings, so we call the C_API directly
+int pygrpc_initialize_core(void) {
+ grpc_init();
+ return Py_AtExit(grpc_shutdown) < 0 ? 0 : 1;
+}
+
#ifdef __cplusplus
}
#endif /* __cpluslus */
diff --git a/src/python/grpcio/grpc/_cython/loader.h b/src/python/grpcio/grpc/_cython/loader.h
index 3b8796d39f..eb4b1a1b01 100644
--- a/src/python/grpcio/grpc/_cython/loader.h
+++ b/src/python/grpcio/grpc/_cython/loader.h
@@ -46,6 +46,11 @@ extern "C" {
/* Attempts to load the core if necessary, and return non-zero upon succes. */
int pygrpc_load_core(char *path);
+/* Initializes grpc and registers grpc_shutdown() to be called right before
+ * interpreter exit. Returns non-zero upon success.
+ */
+int pygrpc_initialize_core(void);
+
#ifdef __cplusplus
}
#endif /* __cpluslus */
diff --git a/src/python/grpcio/grpc/_plugin_wrapping.py b/src/python/grpcio/grpc/_plugin_wrapping.py
index 4e9cfe710c..7cb5218c22 100644
--- a/src/python/grpcio/grpc/_plugin_wrapping.py
+++ b/src/python/grpcio/grpc/_plugin_wrapping.py
@@ -31,6 +31,7 @@ import collections
import threading
import grpc
+from grpc import _common
from grpc._cython import cygrpc
@@ -62,17 +63,16 @@ class _WrappedCygrpcCallback(object):
# TODO(atash) translate different Exception superclasses into different
# status codes.
self.cygrpc_callback(
- cygrpc.Metadata([]), cygrpc.StatusCode.internal, error.message)
+ _common.EMPTY_METADATA, cygrpc.StatusCode.internal,
+ _common.encode(str(error)))
def _invoke_success(self, metadata):
try:
- cygrpc_metadata = cygrpc.Metadata(
- cygrpc.Metadatum(key, value)
- for key, value in metadata)
+ cygrpc_metadata = _common.cygrpc_metadata(metadata)
except Exception as error:
self._invoke_failure(error)
return
- self.cygrpc_callback(cygrpc_metadata, cygrpc.StatusCode.ok, '')
+ self.cygrpc_callback(cygrpc_metadata, cygrpc.StatusCode.ok, b'')
def __call__(self, metadata, error):
with self.is_called_lock:
@@ -101,7 +101,7 @@ class _WrappedPlugin(object):
def __call__(self, context, cygrpc_callback):
wrapped_cygrpc_callback = _WrappedCygrpcCallback(cygrpc_callback)
wrapped_context = AuthMetadataContext(
- context.service_url, context.method_name)
+ _common.decode(context.service_url), _common.decode(context.method_name))
try:
self.plugin(
wrapped_context, AuthMetadataPluginCallback(wrapped_cygrpc_callback))
@@ -120,4 +120,4 @@ def call_credentials_metadata_plugin(plugin, name):
plugin's invocation must be non-blocking.
"""
return cygrpc.call_credentials_metadata_plugin(
- cygrpc.CredentialsMetadataPlugin(_WrappedPlugin(plugin), name))
+ cygrpc.CredentialsMetadataPlugin(_WrappedPlugin(plugin), _common.encode(name)))
diff --git a/src/python/grpcio/grpc/_server.py b/src/python/grpcio/grpc/_server.py
index bf20f15f72..f4c114056f 100644
--- a/src/python/grpcio/grpc/_server.py
+++ b/src/python/grpcio/grpc/_server.py
@@ -87,7 +87,7 @@ def _abortion_code(state, code):
def _details(state):
- return '' if state.details is None else state.details
+ return b'' if state.details is None else state.details
class _HandlerCallDetails(
@@ -146,14 +146,14 @@ def _abort(state, call, code, details):
cygrpc.operation_send_initial_metadata(
_EMPTY_METADATA, _EMPTY_FLAGS),
cygrpc.operation_send_status_from_server(
- _common.metadata(state.trailing_metadata), effective_code,
+ _common.cygrpc_metadata(state.trailing_metadata), effective_code,
effective_details, _EMPTY_FLAGS),
)
token = _SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN
else:
operations = (
cygrpc.operation_send_status_from_server(
- _common.metadata(state.trailing_metadata), effective_code,
+ _common.cygrpc_metadata(state.trailing_metadata), effective_code,
effective_details, _EMPTY_FLAGS),
)
token = _SEND_STATUS_FROM_SERVER_TOKEN
@@ -191,7 +191,7 @@ def _receive_message(state, call, request_deserializer):
if request is None:
_abort(
state, call, cygrpc.StatusCode.internal,
- 'Exception deserializing request!')
+ b'Exception deserializing request!')
else:
state.request = request
state.condition.notify_all()
@@ -244,10 +244,10 @@ class _Context(grpc.ServicerContext):
self._state.disable_next_compression = True
def invocation_metadata(self):
- return self._rpc_event.request_metadata
+ return _common.application_metadata(self._rpc_event.request_metadata)
def peer(self):
- return self._rpc_event.operation_call.peer()
+ return _common.decode(self._rpc_event.operation_call.peer())
def send_initial_metadata(self, initial_metadata):
with self._state.condition:
@@ -256,7 +256,7 @@ class _Context(grpc.ServicerContext):
else:
if self._state.initial_metadata_allowed:
operation = cygrpc.operation_send_initial_metadata(
- _common.metadata(initial_metadata), _EMPTY_FLAGS)
+ _common.cygrpc_metadata(initial_metadata), _EMPTY_FLAGS)
self._rpc_event.operation_call.start_batch(
cygrpc.Operations((operation,)),
_send_initial_metadata(self._state))
@@ -267,7 +267,8 @@ class _Context(grpc.ServicerContext):
def set_trailing_metadata(self, trailing_metadata):
with self._state.condition:
- self._state.trailing_metadata = trailing_metadata
+ self._state.trailing_metadata = _common.cygrpc_metadata(
+ trailing_metadata)
def set_code(self, code):
with self._state.condition:
@@ -275,7 +276,7 @@ class _Context(grpc.ServicerContext):
def set_details(self, details):
with self._state.condition:
- self._state.details = details
+ self._state.details = _common.encode(details)
class _RequestIterator(object):
@@ -346,7 +347,7 @@ def _unary_request(rpc_event, state, request_deserializer):
rpc_event.request_call_details.method)
_abort(
state, rpc_event.operation_call,
- cygrpc.StatusCode.unimplemented, details)
+ cygrpc.StatusCode.unimplemented, _common.encode(details))
return None
elif state.client is _CANCELLED:
return None
@@ -366,8 +367,8 @@ def _call_behavior(rpc_event, state, behavior, argument, request_deserializer):
if e not in state.rpc_errors:
details = 'Exception calling application: {}'.format(e)
logging.exception(details)
- _abort(
- state, rpc_event.operation_call, cygrpc.StatusCode.unknown, details)
+ _abort(state, rpc_event.operation_call,
+ cygrpc.StatusCode.unknown, _common.encode(details))
return None, False
@@ -381,8 +382,8 @@ def _take_response_from_response_iterator(rpc_event, state, response_iterator):
if e not in state.rpc_errors:
details = 'Exception iterating responses: {}'.format(e)
logging.exception(details)
- _abort(
- state, rpc_event.operation_call, cygrpc.StatusCode.unknown, details)
+ _abort(state, rpc_event.operation_call,
+ cygrpc.StatusCode.unknown, _common.encode(details))
return None, False
@@ -392,7 +393,7 @@ def _serialize_response(rpc_event, state, response, response_serializer):
with state.condition:
_abort(
state, rpc_event.operation_call, cygrpc.StatusCode.internal,
- 'Failed to serialize response!')
+ b'Failed to serialize response!')
return None
else:
return serialized_response
@@ -428,7 +429,7 @@ def _send_response(rpc_event, state, serialized_response):
def _status(rpc_event, state, serialized_response):
with state.condition:
if state.client is not _CANCELLED:
- trailing_metadata = _common.metadata(state.trailing_metadata)
+ trailing_metadata = _common.cygrpc_metadata(state.trailing_metadata)
code = _completion_code(state)
details = _details(state)
operations = [
@@ -532,7 +533,8 @@ def _find_method_handler(rpc_event, generic_handlers):
for generic_handler in generic_handlers:
method_handler = generic_handler.service(
_HandlerCallDetails(
- rpc_event.request_call_details.method, rpc_event.request_metadata))
+ _common.decode(rpc_event.request_call_details.method),
+ rpc_event.request_metadata))
if method_handler is not None:
return method_handler
else:
@@ -545,7 +547,7 @@ def _handle_unrecognized_method(rpc_event):
cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),
cygrpc.operation_send_status_from_server(
_EMPTY_METADATA, cygrpc.StatusCode.unimplemented,
- 'Method not found!', _EMPTY_FLAGS),
+ b'Method not found!', _EMPTY_FLAGS),
)
rpc_state = _RPCState()
rpc_event.operation_call.start_batch(
@@ -740,10 +742,10 @@ class Server(grpc.Server):
_add_generic_handlers(self._state, generic_rpc_handlers)
def add_insecure_port(self, address):
- return _add_insecure_port(self._state, address)
+ return _add_insecure_port(self._state, _common.encode(address))
def add_secure_port(self, address, server_credentials):
- return _add_secure_port(self._state, address, server_credentials)
+ return _add_secure_port(self._state, _common.encode(address), server_credentials)
def start(self):
_start(self._state)
diff --git a/src/python/grpcio/grpc/beta/_server_adaptations.py b/src/python/grpcio/grpc/beta/_server_adaptations.py
index 79e6ca87eb..1e1f80156a 100644
--- a/src/python/grpcio/grpc/beta/_server_adaptations.py
+++ b/src/python/grpcio/grpc/beta/_server_adaptations.py
@@ -79,7 +79,8 @@ class _FaceServicerContext(face.ServicerContext):
return _ServerProtocolContext(self._servicer_context)
def invocation_metadata(self):
- return self._servicer_context.invocation_metadata()
+ return _common.cygrpc_metadata(
+ self._servicer_context.invocation_metadata())
def initial_metadata(self, initial_metadata):
self._servicer_context.send_initial_metadata(initial_metadata)
@@ -161,14 +162,24 @@ class _Callback(stream.Consumer):
self._condition.wait()
-def _pipe_requests(request_iterator, request_consumer, servicer_context):
- for request in request_iterator:
- if not servicer_context.is_active():
- return
- request_consumer.consume(request)
- if not servicer_context.is_active():
- return
- request_consumer.terminate()
+def _run_request_pipe_thread(request_iterator, request_consumer,
+ servicer_context):
+ thread_joined = threading.Event()
+ def pipe_requests():
+ for request in request_iterator:
+ if not servicer_context.is_active() or thread_joined.is_set():
+ return
+ request_consumer.consume(request)
+ if not servicer_context.is_active() or thread_joined.is_set():
+ return
+ request_consumer.terminate()
+
+ def stop_request_pipe(timeout):
+ thread_joined.set()
+
+ request_pipe_thread = _common.CleanupThread(
+ stop_request_pipe, target=pipe_requests)
+ request_pipe_thread.start()
def _adapt_unary_unary_event(unary_unary_event):
@@ -206,10 +217,8 @@ def _adapt_stream_unary_event(stream_unary_event):
raise abandonment.Abandoned()
request_consumer = stream_unary_event(
callback.consume_and_terminate, _FaceServicerContext(servicer_context))
- request_pipe_thread = threading.Thread(
- target=_pipe_requests,
- args=(request_iterator, request_consumer, servicer_context,))
- request_pipe_thread.start()
+ _run_request_pipe_thread(
+ request_iterator, request_consumer, servicer_context)
return callback.draw_all_values()[0]
return adaptation
@@ -221,10 +230,8 @@ def _adapt_stream_stream_event(stream_stream_event):
raise abandonment.Abandoned()
request_consumer = stream_stream_event(
callback, _FaceServicerContext(servicer_context))
- request_pipe_thread = threading.Thread(
- target=_pipe_requests,
- args=(request_iterator, request_consumer, servicer_context,))
- request_pipe_thread.start()
+ _run_request_pipe_thread(
+ request_iterator, request_consumer, servicer_context)
while True:
response = callback.draw_one_value()
if response is None:
diff --git a/src/python/grpcio/grpc_core_dependencies.py b/src/python/grpcio/grpc_core_dependencies.py
index 839c555f05..b37e27c27e 100644
--- a/src/python/grpcio/grpc_core_dependencies.py
+++ b/src/python/grpcio/grpc_core_dependencies.py
@@ -94,6 +94,7 @@ CORE_SOURCE_FILES = [
'src/core/lib/iomgr/endpoint_pair_posix.c',
'src/core/lib/iomgr/endpoint_pair_windows.c',
'src/core/lib/iomgr/error.c',
+ 'src/core/lib/iomgr/ev_epoll_linux.c',
'src/core/lib/iomgr/ev_poll_and_epoll_posix.c',
'src/core/lib/iomgr/ev_poll_posix.c',
'src/core/lib/iomgr/ev_posix.c',
@@ -104,6 +105,7 @@ CORE_SOURCE_FILES = [
'src/core/lib/iomgr/iomgr_posix.c',
'src/core/lib/iomgr/iomgr_windows.c',
'src/core/lib/iomgr/load_file.c',
+ 'src/core/lib/iomgr/network_status_tracker.c',
'src/core/lib/iomgr/polling_entity.c',
'src/core/lib/iomgr/pollset_set_windows.c',
'src/core/lib/iomgr/pollset_windows.c',
diff --git a/src/python/grpcio/grpc_version.py b/src/python/grpcio/grpc_version.py
index 0c13104d9d..0f4db9d972 100644
--- a/src/python/grpcio/grpc_version.py
+++ b/src/python/grpcio/grpc_version.py
@@ -29,4 +29,4 @@
# AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio/grpc_version.py.template`!!!
-VERSION='0.15.0.dev0'
+VERSION='0.16.0.dev0'
diff --git a/src/python/grpcio/tests/interop/methods.py b/src/python/grpcio/tests/interop/methods.py
index 7eac511525..86aa0495a2 100644
--- a/src/python/grpcio/tests/interop/methods.py
+++ b/src/python/grpcio/tests/interop/methods.py
@@ -79,10 +79,11 @@ class TestService(test_pb2.BetaTestServiceServicer):
def FullDuplexCall(self, request_iterator, context):
for request in request_iterator:
- yield messages_pb2.StreamingOutputCallResponse(
- payload=messages_pb2.Payload(
- type=request.payload.type,
- body=b'\x00' * request.response_parameters[0].size))
+ for response_parameters in request.response_parameters:
+ yield messages_pb2.StreamingOutputCallResponse(
+ payload=messages_pb2.Payload(
+ type=request.payload.type,
+ body=b'\x00' * response_parameters.size))
# NOTE(nathaniel): Apparently this is the same as the full-duplex call?
# NOTE(atash): It isn't even called in the interop spec (Oct 22 2015)...
diff --git a/src/python/grpcio/tests/tests.json b/src/python/grpcio/tests/tests.json
index 8e509621a8..45eb75b242 100644
--- a/src/python/grpcio/tests/tests.json
+++ b/src/python/grpcio/tests/tests.json
@@ -10,9 +10,10 @@
"_channel_connectivity_test.ChannelConnectivityTest",
"_channel_ready_future_test.ChannelReadyFutureTest",
"_channel_test.ChannelTest",
- "_connectivity_channel_test.ChannelConnectivityTest",
+ "_compression_test.CompressionTest",
"_connectivity_channel_test.ConnectivityStatesTest",
"_empty_message_test.EmptyMessageTest",
+ "_exit_test.ExitTest",
"_face_interface_test.DynamicInvokerBlockingInvocationInlineServiceTest",
"_face_interface_test.DynamicInvokerFutureInvocationAsynchronousEventServiceTest",
"_face_interface_test.GenericInvokerBlockingInvocationInlineServiceTest",
@@ -24,6 +25,7 @@
"_implementations_test.ChannelCredentialsTest",
"_insecure_interop_test.InsecureInteropTest",
"_logging_pool_test.LoggingPoolTest",
+ "_metadata_code_details_test.MetadataCodeDetailsTest",
"_metadata_test.MetadataTest",
"_not_found_test.NotFoundTest",
"_python_plugin_test.PythonPluginTest",
diff --git a/src/python/grpcio/tests/unit/_compression_test.py b/src/python/grpcio/tests/unit/_compression_test.py
new file mode 100644
index 0000000000..9e8b8578c1
--- /dev/null
+++ b/src/python/grpcio/tests/unit/_compression_test.py
@@ -0,0 +1,133 @@
+# Copyright 2016, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""Tests server and client side compression."""
+
+import unittest
+
+import grpc
+from grpc import _grpcio_metadata
+from grpc.framework.foundation import logging_pool
+
+from tests.unit import test_common
+from tests.unit.framework.common import test_constants
+
+_UNARY_UNARY = '/test/UnaryUnary'
+_STREAM_STREAM = '/test/StreamStream'
+
+
+def handle_unary(request, servicer_context):
+ servicer_context.send_initial_metadata([
+ ('grpc-internal-encoding-request', 'gzip')])
+ return request
+
+
+def handle_stream(request_iterator, servicer_context):
+ # TODO(issue:#6891) We should be able to remove this loop,
+ # and replace with return; yield
+ servicer_context.send_initial_metadata([
+ ('grpc-internal-encoding-request', 'gzip')])
+ for request in request_iterator:
+ yield request
+
+
+class _MethodHandler(grpc.RpcMethodHandler):
+
+ def __init__(self, request_streaming, response_streaming):
+ self.request_streaming = request_streaming
+ self.response_streaming = response_streaming
+ self.request_deserializer = None
+ self.response_serializer = None
+ self.unary_unary = None
+ self.unary_stream = None
+ self.stream_unary = None
+ self.stream_stream = None
+ if self.request_streaming and self.response_streaming:
+ self.stream_stream = lambda x, y: handle_stream(x, y)
+ elif not self.request_streaming and not self.response_streaming:
+ self.unary_unary = lambda x, y: handle_unary(x, y)
+
+
+class _GenericHandler(grpc.GenericRpcHandler):
+
+ def service(self, handler_call_details):
+ if handler_call_details.method == _UNARY_UNARY:
+ return _MethodHandler(False, False)
+ elif handler_call_details.method == _STREAM_STREAM:
+ return _MethodHandler(True, True)
+ else:
+ return None
+
+
+class CompressionTest(unittest.TestCase):
+
+ def setUp(self):
+ self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
+ self._server = grpc.server((_GenericHandler(),), self._server_pool)
+ self._port = self._server.add_insecure_port('[::]:0')
+ self._server.start()
+
+ def testUnary(self):
+ request = b'\x00' * 100
+
+ # Client -> server compressed through default client channel compression
+ # settings. Server -> client compressed via server-side metadata setting.
+ # TODO(https://github.com/grpc/grpc/issues/4078): replace the "1" integer
+ # literal with proper use of the public API.
+ compressed_channel = grpc.insecure_channel('localhost:%d' % self._port,
+ options=[('grpc.default_compression_algorithm', 1)])
+ multi_callable = compressed_channel.unary_unary(_UNARY_UNARY)
+ response = multi_callable(request)
+ self.assertEqual(request, response)
+
+ # Client -> server compressed through client metadata setting. Server ->
+ # client compressed via server-side metadata setting.
+ # TODO(https://github.com/grpc/grpc/issues/4078): replace the "0" integer
+ # literal with proper use of the public API.
+ uncompressed_channel = grpc.insecure_channel('localhost:%d' % self._port,
+ options=[('grpc.default_compression_algorithm', 0)])
+ multi_callable = compressed_channel.unary_unary(_UNARY_UNARY)
+ response = multi_callable(request, metadata=[
+ ('grpc-internal-encoding-request', 'gzip')])
+ self.assertEqual(request, response)
+
+ def testStreaming(self):
+ request = b'\x00' * 100
+
+ # TODO(https://github.com/grpc/grpc/issues/4078): replace the "1" integer
+ # literal with proper use of the public API.
+ compressed_channel = grpc.insecure_channel('localhost:%d' % self._port,
+ options=[('grpc.default_compression_algorithm', 1)])
+ multi_callable = compressed_channel.stream_stream(_STREAM_STREAM)
+ call = multi_callable([request] * test_constants.STREAM_LENGTH)
+ for response in call:
+ self.assertEqual(request, response)
+
+
+if __name__ == '__main__':
+ unittest.main(verbosity=2)
diff --git a/src/python/grpcio/tests/unit/_cython/_cancel_many_calls_test.py b/src/python/grpcio/tests/unit/_cython/_cancel_many_calls_test.py
index c1de779014..cac0c8b3b9 100644
--- a/src/python/grpcio/tests/unit/_cython/_cancel_many_calls_test.py
+++ b/src/python/grpcio/tests/unit/_cython/_cancel_many_calls_test.py
@@ -159,9 +159,9 @@ class CancelManyCallsTest(unittest.TestCase):
server_completion_queue = cygrpc.CompletionQueue()
server = cygrpc.Server()
server.register_completion_queue(server_completion_queue)
- port = server.add_http2_port('[::]:0')
+ port = server.add_http2_port(b'[::]:0')
server.start()
- channel = cygrpc.Channel('localhost:{}'.format(port))
+ channel = cygrpc.Channel('localhost:{}'.format(port).encode())
state = _State()
diff --git a/src/python/grpcio/tests/unit/_cython/_channel_test.py b/src/python/grpcio/tests/unit/_cython/_channel_test.py
index 3dc7a246ae..f9c8a3ac62 100644
--- a/src/python/grpcio/tests/unit/_cython/_channel_test.py
+++ b/src/python/grpcio/tests/unit/_cython/_channel_test.py
@@ -37,7 +37,7 @@ from tests.unit.framework.common import test_constants
def _channel_and_completion_queue():
- channel = cygrpc.Channel('localhost:54321', cygrpc.ChannelArgs(()))
+ channel = cygrpc.Channel(b'localhost:54321', cygrpc.ChannelArgs(()))
completion_queue = cygrpc.CompletionQueue()
return channel, completion_queue
diff --git a/src/python/grpcio/tests/unit/_cython/_read_some_but_not_all_responses_test.py b/src/python/grpcio/tests/unit/_cython/_read_some_but_not_all_responses_test.py
index 6ae7a90fbe..27fcee0d6f 100644
--- a/src/python/grpcio/tests/unit/_cython/_read_some_but_not_all_responses_test.py
+++ b/src/python/grpcio/tests/unit/_cython/_read_some_but_not_all_responses_test.py
@@ -126,9 +126,9 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase):
server_completion_queue = cygrpc.CompletionQueue()
server = cygrpc.Server()
server.register_completion_queue(server_completion_queue)
- port = server.add_http2_port('[::]:0')
+ port = server.add_http2_port(b'[::]:0')
server.start()
- channel = cygrpc.Channel('localhost:{}'.format(port))
+ channel = cygrpc.Channel('localhost:{}'.format(port).encode())
server_shutdown_tag = 'server_shutdown_tag'
server_driver = _ServerDriver(server_completion_queue, server_shutdown_tag)
diff --git a/src/python/grpcio/tests/unit/_cython/cygrpc_test.py b/src/python/grpcio/tests/unit/_cython/cygrpc_test.py
index a006a20ce3..b740695e35 100644
--- a/src/python/grpcio/tests/unit/_cython/cygrpc_test.py
+++ b/src/python/grpcio/tests/unit/_cython/cygrpc_test.py
@@ -46,38 +46,38 @@ def _metadata_plugin_callback(context, callback):
callback(cygrpc.Metadata(
[cygrpc.Metadatum(_CALL_CREDENTIALS_METADATA_KEY,
_CALL_CREDENTIALS_METADATA_VALUE)]),
- cygrpc.StatusCode.ok, '')
+ cygrpc.StatusCode.ok, b'')
class TypeSmokeTest(unittest.TestCase):
def testStringsInUtilitiesUpDown(self):
self.assertEqual(0, cygrpc.StatusCode.ok)
- metadatum = cygrpc.Metadatum('a', 'b')
- self.assertEqual('a'.encode(), metadatum.key)
- self.assertEqual('b'.encode(), metadatum.value)
+ metadatum = cygrpc.Metadatum(b'a', b'b')
+ self.assertEqual(b'a', metadatum.key)
+ self.assertEqual(b'b', metadatum.value)
metadata = cygrpc.Metadata([metadatum])
self.assertEqual(1, len(metadata))
self.assertEqual(metadatum.key, metadata[0].key)
def testMetadataIteration(self):
metadata = cygrpc.Metadata([
- cygrpc.Metadatum('a', 'b'), cygrpc.Metadatum('c', 'd')])
+ cygrpc.Metadatum(b'a', b'b'), cygrpc.Metadatum(b'c', b'd')])
iterator = iter(metadata)
metadatum = next(iterator)
self.assertIsInstance(metadatum, cygrpc.Metadatum)
- self.assertEqual(metadatum.key, 'a'.encode())
- self.assertEqual(metadatum.value, 'b'.encode())
+ self.assertEqual(metadatum.key, b'a')
+ self.assertEqual(metadatum.value, b'b')
metadatum = next(iterator)
self.assertIsInstance(metadatum, cygrpc.Metadatum)
- self.assertEqual(metadatum.key, 'c'.encode())
- self.assertEqual(metadatum.value, 'd'.encode())
+ self.assertEqual(metadatum.key, b'c')
+ self.assertEqual(metadatum.value, b'd')
with self.assertRaises(StopIteration):
next(iterator)
def testOperationsIteration(self):
operations = cygrpc.Operations([
- cygrpc.operation_send_message('asdf', _EMPTY_FLAGS)])
+ cygrpc.operation_send_message(b'asdf', _EMPTY_FLAGS)])
iterator = iter(operations)
operation = next(iterator)
self.assertIsInstance(operation, cygrpc.Operation)
@@ -87,7 +87,7 @@ class TypeSmokeTest(unittest.TestCase):
next(iterator)
def testOperationFlags(self):
- operation = cygrpc.operation_send_message('asdf',
+ operation = cygrpc.operation_send_message(b'asdf',
cygrpc.WriteFlag.no_compress)
self.assertEqual(cygrpc.WriteFlag.no_compress, operation.flags)
@@ -105,16 +105,16 @@ class TypeSmokeTest(unittest.TestCase):
del server
def testChannelUpDown(self):
- channel = cygrpc.Channel('[::]:0', cygrpc.ChannelArgs([]))
+ channel = cygrpc.Channel(b'[::]:0', cygrpc.ChannelArgs([]))
del channel
def testCredentialsMetadataPluginUpDown(self):
plugin = cygrpc.CredentialsMetadataPlugin(
- lambda ignored_a, ignored_b: None, '')
+ lambda ignored_a, ignored_b: None, b'')
del plugin
def testCallCredentialsFromPluginUpDown(self):
- plugin = cygrpc.CredentialsMetadataPlugin(_metadata_plugin_callback, '')
+ plugin = cygrpc.CredentialsMetadataPlugin(_metadata_plugin_callback, b'')
call_credentials = cygrpc.call_credentials_metadata_plugin(plugin)
del plugin
del call_credentials
@@ -123,7 +123,7 @@ class TypeSmokeTest(unittest.TestCase):
server = cygrpc.Server()
completion_queue = cygrpc.CompletionQueue()
server.register_completion_queue(completion_queue)
- port = server.add_http2_port('[::]:0')
+ port = server.add_http2_port(b'[::]:0')
self.assertIsInstance(port, int)
server.start()
del server
@@ -131,7 +131,7 @@ class TypeSmokeTest(unittest.TestCase):
def testServerStartShutdown(self):
completion_queue = cygrpc.CompletionQueue()
server = cygrpc.Server()
- server.add_http2_port('[::]:0')
+ server.add_http2_port(b'[::]:0')
server.register_completion_queue(completion_queue)
server.start()
shutdown_tag = object()
@@ -150,9 +150,9 @@ class ServerClientMixin(object):
self.server = cygrpc.Server()
self.server.register_completion_queue(self.server_completion_queue)
if server_credentials:
- self.port = self.server.add_http2_port('[::]:0', server_credentials)
+ self.port = self.server.add_http2_port(b'[::]:0', server_credentials)
else:
- self.port = self.server.add_http2_port('[::]:0')
+ self.port = self.server.add_http2_port(b'[::]:0')
self.server.start()
self.client_completion_queue = cygrpc.CompletionQueue()
if client_credentials:
@@ -160,10 +160,10 @@ class ServerClientMixin(object):
cygrpc.ChannelArg(cygrpc.ChannelArgKey.ssl_target_name_override,
host_override)])
self.client_channel = cygrpc.Channel(
- 'localhost:{}'.format(self.port), client_channel_arguments,
+ 'localhost:{}'.format(self.port).encode(), client_channel_arguments,
client_credentials)
else:
- self.client_channel = cygrpc.Channel('localhost:{}'.format(self.port))
+ self.client_channel = cygrpc.Channel('localhost:{}'.format(self.port).encode())
if host_override:
self.host_argument = None # default host
self.expected_host = host_override
diff --git a/src/python/grpcio/tests/unit/_empty_message_test.py b/src/python/grpcio/tests/unit/_empty_message_test.py
index f324f6216b..8c7d697728 100644
--- a/src/python/grpcio/tests/unit/_empty_message_test.py
+++ b/src/python/grpcio/tests/unit/_empty_message_test.py
@@ -37,10 +37,10 @@ from tests.unit.framework.common import test_constants
_REQUEST = b''
_RESPONSE = b''
-_UNARY_UNARY = b'/test/UnaryUnary'
-_UNARY_STREAM = b'/test/UnaryStream'
-_STREAM_UNARY = b'/test/StreamUnary'
-_STREAM_STREAM = b'/test/StreamStream'
+_UNARY_UNARY = '/test/UnaryUnary'
+_UNARY_STREAM = '/test/UnaryStream'
+_STREAM_UNARY = '/test/StreamUnary'
+_STREAM_STREAM = '/test/StreamStream'
def handle_unary_unary(request, servicer_context):
diff --git a/src/python/grpcio/tests/unit/_exit_scenarios.py b/src/python/grpcio/tests/unit/_exit_scenarios.py
new file mode 100644
index 0000000000..24a2faef85
--- /dev/null
+++ b/src/python/grpcio/tests/unit/_exit_scenarios.py
@@ -0,0 +1,249 @@
+# Copyright 2016, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Defines a number of module-scope gRPC scenarios to test clean exit."""
+
+import argparse
+import threading
+import time
+
+import grpc
+
+from tests.unit.framework.common import test_constants
+
+WAIT_TIME = 1000
+
+REQUEST = b'request'
+
+UNSTARTED_SERVER = 'unstarted_server'
+RUNNING_SERVER = 'running_server'
+POLL_CONNECTIVITY_NO_SERVER = 'poll_connectivity_no_server'
+POLL_CONNECTIVITY = 'poll_connectivity'
+IN_FLIGHT_UNARY_UNARY_CALL = 'in_flight_unary_unary_call'
+IN_FLIGHT_UNARY_STREAM_CALL = 'in_flight_unary_stream_call'
+IN_FLIGHT_STREAM_UNARY_CALL = 'in_flight_stream_unary_call'
+IN_FLIGHT_STREAM_STREAM_CALL = 'in_flight_stream_stream_call'
+IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL = 'in_flight_partial_unary_stream_call'
+IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL = 'in_flight_partial_stream_unary_call'
+IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL = 'in_flight_partial_stream_stream_call'
+
+UNARY_UNARY = b'/test/UnaryUnary'
+UNARY_STREAM = b'/test/UnaryStream'
+STREAM_UNARY = b'/test/StreamUnary'
+STREAM_STREAM = b'/test/StreamStream'
+PARTIAL_UNARY_STREAM = b'/test/PartialUnaryStream'
+PARTIAL_STREAM_UNARY = b'/test/PartialStreamUnary'
+PARTIAL_STREAM_STREAM = b'/test/PartialStreamStream'
+
+TEST_TO_METHOD = {
+ IN_FLIGHT_UNARY_UNARY_CALL: UNARY_UNARY,
+ IN_FLIGHT_UNARY_STREAM_CALL: UNARY_STREAM,
+ IN_FLIGHT_STREAM_UNARY_CALL: STREAM_UNARY,
+ IN_FLIGHT_STREAM_STREAM_CALL: STREAM_STREAM,
+ IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL: PARTIAL_UNARY_STREAM,
+ IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL: PARTIAL_STREAM_UNARY,
+ IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL: PARTIAL_STREAM_STREAM,
+}
+
+
+def hang_unary_unary(request, servicer_context):
+ time.sleep(WAIT_TIME)
+
+
+def hang_unary_stream(request, servicer_context):
+ time.sleep(WAIT_TIME)
+
+
+def hang_partial_unary_stream(request, servicer_context):
+ for _ in range(test_constants.STREAM_LENGTH // 2):
+ yield request
+ time.sleep(WAIT_TIME)
+
+
+def hang_stream_unary(request_iterator, servicer_context):
+ time.sleep(WAIT_TIME)
+
+
+def hang_partial_stream_unary(request_iterator, servicer_context):
+ for _ in range(test_constants.STREAM_LENGTH // 2):
+ next(request_iterator)
+ time.sleep(WAIT_TIME)
+
+
+def hang_stream_stream(request_iterator, servicer_context):
+ time.sleep(WAIT_TIME)
+
+
+def hang_partial_stream_stream(request_iterator, servicer_context):
+ for _ in range(test_constants.STREAM_LENGTH // 2):
+ yield next(request_iterator)
+ time.sleep(WAIT_TIME)
+
+
+class MethodHandler(grpc.RpcMethodHandler):
+
+ def __init__(self, request_streaming, response_streaming, partial_hang):
+ self.request_streaming = request_streaming
+ self.response_streaming = response_streaming
+ self.request_deserializer = None
+ self.response_serializer = None
+ self.unary_unary = None
+ self.unary_stream = None
+ self.stream_unary = None
+ self.stream_stream = None
+ if self.request_streaming and self.response_streaming:
+ if partial_hang:
+ self.stream_stream = hang_partial_stream_stream
+ else:
+ self.stream_stream = hang_stream_stream
+ elif self.request_streaming:
+ if partial_hang:
+ self.stream_unary = hang_partial_stream_unary
+ else:
+ self.stream_unary = hang_stream_unary
+ elif self.response_streaming:
+ if partial_hang:
+ self.unary_stream = hang_partial_unary_stream
+ else:
+ self.unary_stream = hang_unary_stream
+ else:
+ self.unary_unary = hang_unary_unary
+
+
+class GenericHandler(grpc.GenericRpcHandler):
+
+ def service(self, handler_call_details):
+ if handler_call_details.method == UNARY_UNARY:
+ return MethodHandler(False, False, False)
+ elif handler_call_details.method == UNARY_STREAM:
+ return MethodHandler(False, True, False)
+ elif handler_call_details.method == STREAM_UNARY:
+ return MethodHandler(True, False, False)
+ elif handler_call_details.method == STREAM_STREAM:
+ return MethodHandler(True, True, False)
+ elif handler_call_details.method == PARTIAL_UNARY_STREAM:
+ return MethodHandler(False, True, True)
+ elif handler_call_details.method == PARTIAL_STREAM_UNARY:
+ return MethodHandler(True, False, True)
+ elif handler_call_details.method == PARTIAL_STREAM_STREAM:
+ return MethodHandler(True, True, True)
+ else:
+ return None
+
+
+# Traditional executors will not exit until all their
+# current jobs complete. Because we submit jobs that will
+# never finish, we don't want to block exit on these jobs.
+class DaemonPool(object):
+
+ def submit(self, fn, *args, **kwargs):
+ thread = threading.Thread(target=fn, args=args, kwargs=kwargs)
+ thread.daemon = True
+ thread.start()
+
+ def shutdown(self, wait=True):
+ pass
+
+
+def infinite_request_iterator():
+ while True:
+ yield REQUEST
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('scenario', type=str)
+ parser.add_argument(
+ '--wait_for_interrupt', dest='wait_for_interrupt', action='store_true')
+ args = parser.parse_args()
+
+ if args.scenario == UNSTARTED_SERVER:
+ server = grpc.server((), DaemonPool())
+ if args.wait_for_interrupt:
+ time.sleep(WAIT_TIME)
+ elif args.scenario == RUNNING_SERVER:
+ server = grpc.server((), DaemonPool())
+ port = server.add_insecure_port('[::]:0')
+ server.start()
+ if args.wait_for_interrupt:
+ time.sleep(WAIT_TIME)
+ elif args.scenario == POLL_CONNECTIVITY_NO_SERVER:
+ channel = grpc.insecure_channel('localhost:12345')
+
+ def connectivity_callback(connectivity):
+ pass
+
+ channel.subscribe(connectivity_callback, try_to_connect=True)
+ if args.wait_for_interrupt:
+ time.sleep(WAIT_TIME)
+ elif args.scenario == POLL_CONNECTIVITY:
+ server = grpc.server((), DaemonPool())
+ port = server.add_insecure_port('[::]:0')
+ server.start()
+ channel = grpc.insecure_channel('localhost:%d' % port)
+
+ def connectivity_callback(connectivity):
+ pass
+
+ channel.subscribe(connectivity_callback, try_to_connect=True)
+ if args.wait_for_interrupt:
+ time.sleep(WAIT_TIME)
+
+ else:
+ handler = GenericHandler()
+ server = grpc.server((), DaemonPool())
+ port = server.add_insecure_port('[::]:0')
+ server.add_generic_rpc_handlers((handler,))
+ server.start()
+ channel = grpc.insecure_channel('localhost:%d' % port)
+
+ method = TEST_TO_METHOD[args.scenario]
+
+ if args.scenario == IN_FLIGHT_UNARY_UNARY_CALL:
+ multi_callable = channel.unary_unary(method)
+ future = multi_callable.future(REQUEST)
+ result, call = multi_callable.with_call(REQUEST)
+ elif (args.scenario == IN_FLIGHT_UNARY_STREAM_CALL or
+ args.scenario == IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL):
+ multi_callable = channel.unary_stream(method)
+ response_iterator = multi_callable(REQUEST)
+ for response in response_iterator:
+ pass
+ elif (args.scenario == IN_FLIGHT_STREAM_UNARY_CALL or
+ args.scenario == IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL):
+ multi_callable = channel.stream_unary(method)
+ future = multi_callable.future(infinite_request_iterator())
+ result, call = multi_callable.with_call(
+ [REQUEST] * test_constants.STREAM_LENGTH)
+ elif (args.scenario == IN_FLIGHT_STREAM_STREAM_CALL or
+ args.scenario == IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL):
+ multi_callable = channel.stream_stream(method)
+ response_iterator = multi_callable(infinite_request_iterator())
+ for response in response_iterator:
+ pass
diff --git a/src/python/grpcio/tests/unit/_exit_test.py b/src/python/grpcio/tests/unit/_exit_test.py
new file mode 100644
index 0000000000..b0d6af73e5
--- /dev/null
+++ b/src/python/grpcio/tests/unit/_exit_test.py
@@ -0,0 +1,185 @@
+# Copyright 2016, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Tests clean exit of server/client on Python Interpreter exit/sigint.
+
+The tests in this module spawn a subprocess for each test case, the
+test is considered successful if it doesn't hang/timeout.
+"""
+
+import atexit
+import os
+import signal
+import six
+import subprocess
+import sys
+import threading
+import time
+import unittest
+
+from tests.unit import _exit_scenarios
+
+SCENARIO_FILE = os.path.abspath(os.path.join(
+ os.path.dirname(os.path.realpath(__file__)), '_exit_scenarios.py'))
+INTERPRETER = sys.executable
+BASE_COMMAND = [INTERPRETER, SCENARIO_FILE]
+BASE_SIGTERM_COMMAND = BASE_COMMAND + ['--wait_for_interrupt']
+
+INIT_TIME = 1.0
+
+
+processes = []
+process_lock = threading.Lock()
+
+
+# Make sure we attempt to clean up any
+# processes we may have left running
+def cleanup_processes():
+ with process_lock:
+ for process in processes:
+ try:
+ process.kill()
+ except Exception:
+ pass
+atexit.register(cleanup_processes)
+
+
+def interrupt_and_wait(process):
+ with process_lock:
+ processes.append(process)
+ time.sleep(INIT_TIME)
+ os.kill(process.pid, signal.SIGINT)
+ process.wait()
+
+
+def wait(process):
+ with process_lock:
+ processes.append(process)
+ process.wait()
+
+
+class ExitTest(unittest.TestCase):
+
+ def test_unstarted_server(self):
+ process = subprocess.Popen(
+ BASE_COMMAND + [_exit_scenarios.UNSTARTED_SERVER],
+ stdout=sys.stdout, stderr=sys.stderr)
+ wait(process)
+
+ def test_unstarted_server_terminate(self):
+ process = subprocess.Popen(
+ BASE_SIGTERM_COMMAND + [_exit_scenarios.UNSTARTED_SERVER],
+ stdout=sys.stdout)
+ interrupt_and_wait(process)
+
+ def test_running_server(self):
+ process = subprocess.Popen(
+ BASE_COMMAND + [_exit_scenarios.RUNNING_SERVER],
+ stdout=sys.stdout, stderr=sys.stderr)
+ wait(process)
+
+ def test_running_server_terminate(self):
+ process = subprocess.Popen(
+ BASE_SIGTERM_COMMAND + [_exit_scenarios.RUNNING_SERVER],
+ stdout=sys.stdout, stderr=sys.stderr)
+ interrupt_and_wait(process)
+
+ def test_poll_connectivity_no_server(self):
+ process = subprocess.Popen(
+ BASE_COMMAND + [_exit_scenarios.POLL_CONNECTIVITY_NO_SERVER],
+ stdout=sys.stdout, stderr=sys.stderr)
+ wait(process)
+
+ def test_poll_connectivity_no_server_terminate(self):
+ process = subprocess.Popen(
+ BASE_SIGTERM_COMMAND + [_exit_scenarios.POLL_CONNECTIVITY_NO_SERVER],
+ stdout=sys.stdout, stderr=sys.stderr)
+ interrupt_and_wait(process)
+
+ def test_poll_connectivity(self):
+ process = subprocess.Popen(
+ BASE_COMMAND + [_exit_scenarios.POLL_CONNECTIVITY],
+ stdout=sys.stdout, stderr=sys.stderr)
+ wait(process)
+
+ def test_poll_connectivity_terminate(self):
+ process = subprocess.Popen(
+ BASE_SIGTERM_COMMAND + [_exit_scenarios.POLL_CONNECTIVITY],
+ stdout=sys.stdout, stderr=sys.stderr)
+ interrupt_and_wait(process)
+
+ def test_in_flight_unary_unary_call(self):
+ process = subprocess.Popen(
+ BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_UNARY_UNARY_CALL],
+ stdout=sys.stdout, stderr=sys.stderr)
+ interrupt_and_wait(process)
+
+ @unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999')
+ def test_in_flight_unary_stream_call(self):
+ process = subprocess.Popen(
+ BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_UNARY_STREAM_CALL],
+ stdout=sys.stdout, stderr=sys.stderr)
+ interrupt_and_wait(process)
+
+ def test_in_flight_stream_unary_call(self):
+ process = subprocess.Popen(
+ BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_STREAM_UNARY_CALL],
+ stdout=sys.stdout, stderr=sys.stderr)
+ interrupt_and_wait(process)
+
+ @unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999')
+ def test_in_flight_stream_stream_call(self):
+ process = subprocess.Popen(
+ BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_STREAM_STREAM_CALL],
+ stdout=sys.stdout, stderr=sys.stderr)
+ interrupt_and_wait(process)
+
+ @unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999')
+ def test_in_flight_partial_unary_stream_call(self):
+ process = subprocess.Popen(
+ BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL],
+ stdout=sys.stdout, stderr=sys.stderr)
+ interrupt_and_wait(process)
+
+ def test_in_flight_partial_stream_unary_call(self):
+ process = subprocess.Popen(
+ BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL],
+ stdout=sys.stdout, stderr=sys.stderr)
+ interrupt_and_wait(process)
+
+ @unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999')
+ def test_in_flight_partial_stream_stream_call(self):
+ process = subprocess.Popen(
+ BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL],
+ stdout=sys.stdout, stderr=sys.stderr)
+ interrupt_and_wait(process)
+
+
+if __name__ == '__main__':
+ unittest.main(verbosity=2)
diff --git a/src/python/grpcio/tests/unit/_metadata_code_details_test.py b/src/python/grpcio/tests/unit/_metadata_code_details_test.py
new file mode 100644
index 0000000000..0fd02d2a22
--- /dev/null
+++ b/src/python/grpcio/tests/unit/_metadata_code_details_test.py
@@ -0,0 +1,523 @@
+# Copyright 2016, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Tests application-provided metadata, status code, and details."""
+
+import threading
+import unittest
+
+import grpc
+from grpc.framework.foundation import logging_pool
+
+from tests.unit import test_common
+from tests.unit.framework.common import test_constants
+from tests.unit.framework.common import test_control
+
+_SERIALIZED_REQUEST = b'\x46\x47\x48'
+_SERIALIZED_RESPONSE = b'\x49\x50\x51'
+
+_REQUEST_SERIALIZER = lambda unused_request: _SERIALIZED_REQUEST
+_REQUEST_DESERIALIZER = lambda unused_serialized_request: object()
+_RESPONSE_SERIALIZER = lambda unused_response: _SERIALIZED_RESPONSE
+_RESPONSE_DESERIALIZER = lambda unused_serialized_resopnse: object()
+
+_SERVICE = 'test.TestService'
+_UNARY_UNARY = 'UnaryUnary'
+_UNARY_STREAM = 'UnaryStream'
+_STREAM_UNARY = 'StreamUnary'
+_STREAM_STREAM = 'StreamStream'
+
+_CLIENT_METADATA = (
+ ('client-md-key', 'client-md-key'),
+ ('client-md-key-bin', b'\x00\x01')
+)
+
+_SERVER_INITIAL_METADATA = (
+ ('server-initial-md-key', 'server-initial-md-value'),
+ ('server-initial-md-key-bin', b'\x00\x02')
+)
+
+_SERVER_TRAILING_METADATA = (
+ ('server-trailing-md-key', 'server-trailing-md-value'),
+ ('server-trailing-md-key-bin', b'\x00\x03')
+)
+
+_NON_OK_CODE = grpc.StatusCode.NOT_FOUND
+_DETAILS = 'Test details!'
+
+
+class _Servicer(object):
+
+ def __init__(self):
+ self._lock = threading.Lock()
+ self._code = None
+ self._details = None
+ self._exception = False
+ self._return_none = False
+ self._received_client_metadata = None
+
+ def unary_unary(self, request, context):
+ with self._lock:
+ self._received_client_metadata = context.invocation_metadata()
+ context.send_initial_metadata(_SERVER_INITIAL_METADATA)
+ context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
+ if self._code is not None:
+ context.set_code(self._code)
+ if self._details is not None:
+ context.set_details(self._details)
+ if self._exception:
+ raise test_control.Defect()
+ else:
+ return None if self._return_none else object()
+
+ def unary_stream(self, request, context):
+ with self._lock:
+ self._received_client_metadata = context.invocation_metadata()
+ context.send_initial_metadata(_SERVER_INITIAL_METADATA)
+ context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
+ if self._code is not None:
+ context.set_code(self._code)
+ if self._details is not None:
+ context.set_details(self._details)
+ for _ in range(test_constants.STREAM_LENGTH // 2):
+ yield _SERIALIZED_RESPONSE
+ if self._exception:
+ raise test_control.Defect()
+
+ def stream_unary(self, request_iterator, context):
+ with self._lock:
+ self._received_client_metadata = context.invocation_metadata()
+ context.send_initial_metadata(_SERVER_INITIAL_METADATA)
+ context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
+ if self._code is not None:
+ context.set_code(self._code)
+ if self._details is not None:
+ context.set_details(self._details)
+ # TODO(https://github.com/grpc/grpc/issues/6891): just ignore the
+ # request iterator.
+ for ignored_request in request_iterator:
+ pass
+ if self._exception:
+ raise test_control.Defect()
+ else:
+ return None if self._return_none else _SERIALIZED_RESPONSE
+
+ def stream_stream(self, request_iterator, context):
+ with self._lock:
+ self._received_client_metadata = context.invocation_metadata()
+ context.send_initial_metadata(_SERVER_INITIAL_METADATA)
+ context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
+ if self._code is not None:
+ context.set_code(self._code)
+ if self._details is not None:
+ context.set_details(self._details)
+ # TODO(https://github.com/grpc/grpc/issues/6891): just ignore the
+ # request iterator.
+ for ignored_request in request_iterator:
+ pass
+ for _ in range(test_constants.STREAM_LENGTH // 3):
+ yield object()
+ if self._exception:
+ raise test_control.Defect()
+
+ def set_code(self, code):
+ with self._lock:
+ self._code = code
+
+ def set_details(self, details):
+ with self._lock:
+ self._details = details
+
+ def set_exception(self):
+ with self._lock:
+ self._exception = True
+
+ def set_return_none(self):
+ with self._lock:
+ self._return_none = True
+
+ def received_client_metadata(self):
+ with self._lock:
+ return self._received_client_metadata
+
+
+def _generic_handler(servicer):
+ method_handlers = {
+ _UNARY_UNARY: grpc.unary_unary_rpc_method_handler(
+ servicer.unary_unary, request_deserializer=_REQUEST_DESERIALIZER,
+ response_serializer=_RESPONSE_SERIALIZER),
+ _UNARY_STREAM: grpc.unary_stream_rpc_method_handler(
+ servicer.unary_stream),
+ _STREAM_UNARY: grpc.stream_unary_rpc_method_handler(
+ servicer.stream_unary),
+ _STREAM_STREAM: grpc.stream_stream_rpc_method_handler(
+ servicer.stream_stream, request_deserializer=_REQUEST_DESERIALIZER,
+ response_serializer=_RESPONSE_SERIALIZER),
+ }
+ return grpc.method_handlers_generic_handler(_SERVICE, method_handlers)
+
+
+class MetadataCodeDetailsTest(unittest.TestCase):
+
+ def setUp(self):
+ self._servicer = _Servicer()
+ self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
+ self._server = grpc.server(
+ (_generic_handler(self._servicer),), self._server_pool)
+ port = self._server.add_insecure_port('[::]:0')
+ self._server.start()
+
+ channel = grpc.insecure_channel('localhost:{}'.format(port))
+ self._unary_unary = channel.unary_unary(
+ '/'.join(('', _SERVICE, _UNARY_UNARY,)),
+ request_serializer=_REQUEST_SERIALIZER,
+ response_deserializer=_RESPONSE_DESERIALIZER,)
+ self._unary_stream = channel.unary_stream(
+ '/'.join(('', _SERVICE, _UNARY_STREAM,)),)
+ self._stream_unary = channel.stream_unary(
+ '/'.join(('', _SERVICE, _STREAM_UNARY,)),)
+ self._stream_stream = channel.stream_stream(
+ '/'.join(('', _SERVICE, _STREAM_STREAM,)),
+ request_serializer=_REQUEST_SERIALIZER,
+ response_deserializer=_RESPONSE_DESERIALIZER,)
+
+
+ def testSuccessfulUnaryUnary(self):
+ self._servicer.set_details(_DETAILS)
+
+ unused_response, call = self._unary_unary.with_call(
+ object(), metadata=_CLIENT_METADATA)
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA, call.initial_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA, call.trailing_metadata()))
+ self.assertIs(grpc.StatusCode.OK, call.code())
+ self.assertEqual(_DETAILS, call.details())
+
+ def testSuccessfulUnaryStream(self):
+ self._servicer.set_details(_DETAILS)
+
+ call = self._unary_stream(_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
+ received_initial_metadata = call.initial_metadata()
+ for _ in call:
+ pass
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA, received_initial_metadata))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA, call.trailing_metadata()))
+ self.assertIs(grpc.StatusCode.OK, call.code())
+ self.assertEqual(_DETAILS, call.details())
+
+ def testSuccessfulStreamUnary(self):
+ self._servicer.set_details(_DETAILS)
+
+ unused_response, call = self._stream_unary.with_call(
+ iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),
+ metadata=_CLIENT_METADATA)
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA, call.initial_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA, call.trailing_metadata()))
+ self.assertIs(grpc.StatusCode.OK, call.code())
+ self.assertEqual(_DETAILS, call.details())
+
+ def testSuccessfulStreamStream(self):
+ self._servicer.set_details(_DETAILS)
+
+ call = self._stream_stream(
+ iter([object()] * test_constants.STREAM_LENGTH),
+ metadata=_CLIENT_METADATA)
+ received_initial_metadata = call.initial_metadata()
+ for _ in call:
+ pass
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA, received_initial_metadata))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA, call.trailing_metadata()))
+ self.assertIs(grpc.StatusCode.OK, call.code())
+ self.assertEqual(_DETAILS, call.details())
+
+ def testCustomCodeUnaryUnary(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA)
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA,
+ exception_context.exception.initial_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA,
+ exception_context.exception.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, exception_context.exception.code())
+ self.assertEqual(_DETAILS, exception_context.exception.details())
+
+ def testCustomCodeUnaryStream(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+
+ call = self._unary_stream(_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
+ received_initial_metadata = call.initial_metadata()
+ with self.assertRaises(grpc.RpcError):
+ for _ in call:
+ pass
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA, received_initial_metadata))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA, call.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, call.code())
+ self.assertEqual(_DETAILS, call.details())
+
+ def testCustomCodeStreamUnary(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ self._stream_unary.with_call(
+ iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),
+ metadata=_CLIENT_METADATA)
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA,
+ exception_context.exception.initial_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA,
+ exception_context.exception.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, exception_context.exception.code())
+ self.assertEqual(_DETAILS, exception_context.exception.details())
+
+ def testCustomCodeStreamStream(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+
+ call = self._stream_stream(
+ iter([object()] * test_constants.STREAM_LENGTH),
+ metadata=_CLIENT_METADATA)
+ received_initial_metadata = call.initial_metadata()
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ for _ in call:
+ pass
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA, received_initial_metadata))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA,
+ exception_context.exception.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, exception_context.exception.code())
+ self.assertEqual(_DETAILS, exception_context.exception.details())
+
+ def testCustomCodeExceptionUnaryUnary(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+ self._servicer.set_exception()
+
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA)
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA,
+ exception_context.exception.initial_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA,
+ exception_context.exception.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, exception_context.exception.code())
+ self.assertEqual(_DETAILS, exception_context.exception.details())
+
+ def testCustomCodeExceptionUnaryStream(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+ self._servicer.set_exception()
+
+ call = self._unary_stream(_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
+ received_initial_metadata = call.initial_metadata()
+ with self.assertRaises(grpc.RpcError):
+ for _ in call:
+ pass
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA, received_initial_metadata))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA, call.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, call.code())
+ self.assertEqual(_DETAILS, call.details())
+
+ def testCustomCodeExceptionStreamUnary(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+ self._servicer.set_exception()
+
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ self._stream_unary.with_call(
+ iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),
+ metadata=_CLIENT_METADATA)
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA,
+ exception_context.exception.initial_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA,
+ exception_context.exception.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, exception_context.exception.code())
+ self.assertEqual(_DETAILS, exception_context.exception.details())
+
+ def testCustomCodeExceptionStreamStream(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+ self._servicer.set_exception()
+
+ call = self._stream_stream(
+ iter([object()] * test_constants.STREAM_LENGTH),
+ metadata=_CLIENT_METADATA)
+ received_initial_metadata = call.initial_metadata()
+ with self.assertRaises(grpc.RpcError):
+ for _ in call:
+ pass
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA, received_initial_metadata))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA, call.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, call.code())
+ self.assertEqual(_DETAILS, call.details())
+
+ def testCustomCodeReturnNoneUnaryUnary(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+ self._servicer.set_return_none()
+
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA)
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA,
+ exception_context.exception.initial_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA,
+ exception_context.exception.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, exception_context.exception.code())
+ self.assertEqual(_DETAILS, exception_context.exception.details())
+
+ def testCustomCodeReturnNoneStreamUnary(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+ self._servicer.set_return_none()
+
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ self._stream_unary.with_call(
+ iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),
+ metadata=_CLIENT_METADATA)
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA,
+ exception_context.exception.initial_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA,
+ exception_context.exception.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, exception_context.exception.code())
+ self.assertEqual(_DETAILS, exception_context.exception.details())
+
+
+if __name__ == '__main__':
+ unittest.main(verbosity=2)
diff --git a/src/python/grpcio/tests/unit/_metadata_test.py b/src/python/grpcio/tests/unit/_metadata_test.py
index 2cb13f236b..c637a28039 100644
--- a/src/python/grpcio/tests/unit/_metadata_test.py
+++ b/src/python/grpcio/tests/unit/_metadata_test.py
@@ -44,33 +44,33 @@ _CHANNEL_ARGS = (('grpc.primary_user_agent', 'primary-agent'),
_REQUEST = b'\x00\x00\x00'
_RESPONSE = b'\x00\x00\x00'
-_UNARY_UNARY = b'/test/UnaryUnary'
-_UNARY_STREAM = b'/test/UnaryStream'
-_STREAM_UNARY = b'/test/StreamUnary'
-_STREAM_STREAM = b'/test/StreamStream'
+_UNARY_UNARY = '/test/UnaryUnary'
+_UNARY_STREAM = '/test/UnaryStream'
+_STREAM_UNARY = '/test/StreamUnary'
+_STREAM_STREAM = '/test/StreamStream'
_USER_AGENT = 'Python-gRPC-{}'.format(_grpcio_metadata.__version__)
_CLIENT_METADATA = (
- (b'client-md-key', b'client-md-key'),
- (b'client-md-key-bin', b'\x00\x01')
+ ('client-md-key', 'client-md-key'),
+ ('client-md-key-bin', b'\x00\x01')
)
_SERVER_INITIAL_METADATA = (
- (b'server-initial-md-key', b'server-initial-md-value'),
- (b'server-initial-md-key-bin', b'\x00\x02')
+ ('server-initial-md-key', 'server-initial-md-value'),
+ ('server-initial-md-key-bin', b'\x00\x02')
)
_SERVER_TRAILING_METADATA = (
- (b'server-trailing-md-key', b'server-trailing-md-value'),
- (b'server-trailing-md-key-bin', b'\x00\x03')
+ ('server-trailing-md-key', 'server-trailing-md-value'),
+ ('server-trailing-md-key-bin', b'\x00\x03')
)
def user_agent(metadata):
for key, val in metadata:
- if key == b'user-agent':
- return val.decode('ascii')
+ if key == 'user-agent':
+ return val
raise KeyError('No user agent!')
diff --git a/src/python/grpcio/tests/unit/_rpc_test.py b/src/python/grpcio/tests/unit/_rpc_test.py
index 9814504edf..c70d65a6df 100644
--- a/src/python/grpcio/tests/unit/_rpc_test.py
+++ b/src/python/grpcio/tests/unit/_rpc_test.py
@@ -45,10 +45,10 @@ _DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) // 2:]
_SERIALIZE_RESPONSE = lambda bytestring: bytestring * 3
_DESERIALIZE_RESPONSE = lambda bytestring: bytestring[:len(bytestring) // 3]
-_UNARY_UNARY = b'/test/UnaryUnary'
-_UNARY_STREAM = b'/test/UnaryStream'
-_STREAM_UNARY = b'/test/StreamUnary'
-_STREAM_STREAM = b'/test/StreamStream'
+_UNARY_UNARY = '/test/UnaryUnary'
+_UNARY_STREAM = '/test/UnaryStream'
+_STREAM_UNARY = '/test/StreamUnary'
+_STREAM_STREAM = '/test/StreamStream'
class _Callback(object):
@@ -79,7 +79,7 @@ class _Handler(object):
def handle_unary_unary(self, request, servicer_context):
self._control.control()
if servicer_context is not None:
- servicer_context.set_trailing_metadata(((b'testkey', b'testvalue',),))
+ servicer_context.set_trailing_metadata((('testkey', 'testvalue',),))
return request
def handle_unary_stream(self, request, servicer_context):
@@ -88,7 +88,7 @@ class _Handler(object):
yield request
self._control.control()
if servicer_context is not None:
- servicer_context.set_trailing_metadata(((b'testkey', b'testvalue',),))
+ servicer_context.set_trailing_metadata((('testkey', 'testvalue',),))
def handle_stream_unary(self, request_iterator, servicer_context):
if servicer_context is not None:
@@ -100,13 +100,13 @@ class _Handler(object):
response_elements.append(request)
self._control.control()
if servicer_context is not None:
- servicer_context.set_trailing_metadata(((b'testkey', b'testvalue',),))
+ servicer_context.set_trailing_metadata((('testkey', 'testvalue',),))
return b''.join(response_elements)
def handle_stream_stream(self, request_iterator, servicer_context):
self._control.control()
if servicer_context is not None:
- servicer_context.set_trailing_metadata(((b'testkey', b'testvalue',),))
+ servicer_context.set_trailing_metadata((('testkey', 'testvalue',),))
for request in request_iterator:
self._control.control()
yield request
@@ -185,7 +185,7 @@ class RPCTest(unittest.TestCase):
self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
self._server = grpc.server((), self._server_pool)
- port = self._server.add_insecure_port(b'[::]:0')
+ port = self._server.add_insecure_port('[::]:0')
self._server.add_generic_rpc_handlers((_GenericHandler(self._handler),))
self._server.start()
@@ -195,7 +195,7 @@ class RPCTest(unittest.TestCase):
request = b'abc'
with self.assertRaises(grpc.RpcError) as exception_context:
- self._channel.unary_unary(b'NoSuchMethod')(request)
+ self._channel.unary_unary('NoSuchMethod')(request)
self.assertEqual(
grpc.StatusCode.UNIMPLEMENTED, exception_context.exception.code())
@@ -207,7 +207,7 @@ class RPCTest(unittest.TestCase):
multi_callable = _unary_unary_multi_callable(self._channel)
response = multi_callable(
request, metadata=(
- (b'test', b'SuccessfulUnaryRequestBlockingUnaryResponse'),))
+ ('test', 'SuccessfulUnaryRequestBlockingUnaryResponse'),))
self.assertEqual(expected_response, response)
@@ -218,7 +218,7 @@ class RPCTest(unittest.TestCase):
multi_callable = _unary_unary_multi_callable(self._channel)
response, call = multi_callable.with_call(
request, metadata=(
- (b'test', b'SuccessfulUnaryRequestBlockingUnaryResponseWithCall'),))
+ ('test', 'SuccessfulUnaryRequestBlockingUnaryResponseWithCall'),))
self.assertEqual(expected_response, response)
self.assertIs(grpc.StatusCode.OK, call.code())
@@ -230,7 +230,7 @@ class RPCTest(unittest.TestCase):
multi_callable = _unary_unary_multi_callable(self._channel)
response_future = multi_callable.future(
request, metadata=(
- (b'test', b'SuccessfulUnaryRequestFutureUnaryResponse'),))
+ ('test', 'SuccessfulUnaryRequestFutureUnaryResponse'),))
response = response_future.result()
self.assertEqual(expected_response, response)
@@ -242,7 +242,7 @@ class RPCTest(unittest.TestCase):
multi_callable = _unary_stream_multi_callable(self._channel)
response_iterator = multi_callable(
request,
- metadata=((b'test', b'SuccessfulUnaryRequestStreamResponse'),))
+ metadata=(('test', 'SuccessfulUnaryRequestStreamResponse'),))
responses = tuple(response_iterator)
self.assertSequenceEqual(expected_responses, responses)
@@ -255,7 +255,7 @@ class RPCTest(unittest.TestCase):
multi_callable = _stream_unary_multi_callable(self._channel)
response = multi_callable(
request_iterator,
- metadata=((b'test', b'SuccessfulStreamRequestBlockingUnaryResponse'),))
+ metadata=(('test', 'SuccessfulStreamRequestBlockingUnaryResponse'),))
self.assertEqual(expected_response, response)
@@ -268,7 +268,7 @@ class RPCTest(unittest.TestCase):
response, call = multi_callable.with_call(
request_iterator,
metadata=(
- (b'test', b'SuccessfulStreamRequestBlockingUnaryResponseWithCall'),
+ ('test', 'SuccessfulStreamRequestBlockingUnaryResponseWithCall'),
))
self.assertEqual(expected_response, response)
@@ -283,7 +283,7 @@ class RPCTest(unittest.TestCase):
response_future = multi_callable.future(
request_iterator,
metadata=(
- (b'test', b'SuccessfulStreamRequestFutureUnaryResponse'),))
+ ('test', 'SuccessfulStreamRequestFutureUnaryResponse'),))
response = response_future.result()
self.assertEqual(expected_response, response)
@@ -297,7 +297,7 @@ class RPCTest(unittest.TestCase):
multi_callable = _stream_stream_multi_callable(self._channel)
response_iterator = multi_callable(
request_iterator,
- metadata=((b'test', b'SuccessfulStreamRequestStreamResponse'),))
+ metadata=(('test', 'SuccessfulStreamRequestStreamResponse'),))
responses = tuple(response_iterator)
self.assertSequenceEqual(expected_responses, responses)
@@ -312,9 +312,9 @@ class RPCTest(unittest.TestCase):
multi_callable = _unary_unary_multi_callable(self._channel)
first_response = multi_callable(
- first_request, metadata=((b'test', b'SequentialInvocations'),))
+ first_request, metadata=(('test', 'SequentialInvocations'),))
second_response = multi_callable(
- second_request, metadata=((b'test', b'SequentialInvocations'),))
+ second_request, metadata=(('test', 'SequentialInvocations'),))
self.assertEqual(expected_first_response, first_response)
self.assertEqual(expected_second_response, second_response)
@@ -331,7 +331,7 @@ class RPCTest(unittest.TestCase):
request_iterator = iter(requests)
response_future = pool.submit(
multi_callable, request_iterator,
- metadata=((b'test', b'ConcurrentBlockingInvocations'),))
+ metadata=(('test', 'ConcurrentBlockingInvocations'),))
response_futures[index] = response_future
responses = tuple(
response_future.result() for response_future in response_futures)
@@ -350,7 +350,7 @@ class RPCTest(unittest.TestCase):
request_iterator = iter(requests)
response_future = multi_callable.future(
request_iterator,
- metadata=((b'test', b'ConcurrentFutureInvocations'),))
+ metadata=(('test', 'ConcurrentFutureInvocations'),))
response_futures[index] = response_future
responses = tuple(
response_future.result() for response_future in response_futures)
@@ -380,8 +380,8 @@ class RPCTest(unittest.TestCase):
inner_response_future = multi_callable.future(
request,
metadata=(
- (b'test',
- b'WaitingForSomeButNotAllConcurrentFutureInvocations'),))
+ ('test',
+ 'WaitingForSomeButNotAllConcurrentFutureInvocations'),))
outer_response_future = pool.submit(wrap_future(inner_response_future))
response_futures[index] = outer_response_future
@@ -400,7 +400,7 @@ class RPCTest(unittest.TestCase):
response_iterator = multi_callable(
request,
metadata=(
- (b'test', b'ConsumingOneStreamResponseUnaryRequest'),))
+ ('test', 'ConsumingOneStreamResponseUnaryRequest'),))
next(response_iterator)
def testConsumingSomeButNotAllStreamResponsesUnaryRequest(self):
@@ -410,7 +410,7 @@ class RPCTest(unittest.TestCase):
response_iterator = multi_callable(
request,
metadata=(
- (b'test', b'ConsumingSomeButNotAllStreamResponsesUnaryRequest'),))
+ ('test', 'ConsumingSomeButNotAllStreamResponsesUnaryRequest'),))
for _ in range(test_constants.STREAM_LENGTH // 2):
next(response_iterator)
@@ -422,7 +422,7 @@ class RPCTest(unittest.TestCase):
response_iterator = multi_callable(
request_iterator,
metadata=(
- (b'test', b'ConsumingSomeButNotAllStreamResponsesStreamRequest'),))
+ ('test', 'ConsumingSomeButNotAllStreamResponsesStreamRequest'),))
for _ in range(test_constants.STREAM_LENGTH // 2):
next(response_iterator)
@@ -434,7 +434,7 @@ class RPCTest(unittest.TestCase):
response_iterator = multi_callable(
request_iterator,
metadata=(
- (b'test', b'ConsumingTooManyStreamResponsesStreamRequest'),))
+ ('test', 'ConsumingTooManyStreamResponsesStreamRequest'),))
for _ in range(test_constants.STREAM_LENGTH):
next(response_iterator)
for _ in range(test_constants.STREAM_LENGTH):
@@ -453,7 +453,7 @@ class RPCTest(unittest.TestCase):
with self._control.pause():
response_future = multi_callable.future(
request,
- metadata=((b'test', b'CancelledUnaryRequestUnaryResponse'),))
+ metadata=(('test', 'CancelledUnaryRequestUnaryResponse'),))
response_future.cancel()
self.assertTrue(response_future.cancelled())
@@ -468,7 +468,7 @@ class RPCTest(unittest.TestCase):
with self._control.pause():
response_iterator = multi_callable(
request,
- metadata=((b'test', b'CancelledUnaryRequestStreamResponse'),))
+ metadata=(('test', 'CancelledUnaryRequestStreamResponse'),))
self._control.block_until_paused()
response_iterator.cancel()
@@ -488,7 +488,7 @@ class RPCTest(unittest.TestCase):
with self._control.pause():
response_future = multi_callable.future(
request_iterator,
- metadata=((b'test', b'CancelledStreamRequestUnaryResponse'),))
+ metadata=(('test', 'CancelledStreamRequestUnaryResponse'),))
self._control.block_until_paused()
response_future.cancel()
@@ -508,7 +508,7 @@ class RPCTest(unittest.TestCase):
with self._control.pause():
response_iterator = multi_callable(
request_iterator,
- metadata=((b'test', b'CancelledStreamRequestStreamResponse'),))
+ metadata=(('test', 'CancelledStreamRequestStreamResponse'),))
response_iterator.cancel()
with self.assertRaises(grpc.RpcError):
@@ -526,7 +526,7 @@ class RPCTest(unittest.TestCase):
with self.assertRaises(grpc.RpcError) as exception_context:
multi_callable.with_call(
request, timeout=test_constants.SHORT_TIMEOUT,
- metadata=((b'test', b'ExpiredUnaryRequestBlockingUnaryResponse'),))
+ metadata=(('test', 'ExpiredUnaryRequestBlockingUnaryResponse'),))
self.assertIsNotNone(exception_context.exception.initial_metadata())
self.assertIs(
@@ -542,7 +542,7 @@ class RPCTest(unittest.TestCase):
with self._control.pause():
response_future = multi_callable.future(
request, timeout=test_constants.SHORT_TIMEOUT,
- metadata=((b'test', b'ExpiredUnaryRequestFutureUnaryResponse'),))
+ metadata=(('test', 'ExpiredUnaryRequestFutureUnaryResponse'),))
response_future.add_done_callback(callback)
value_passed_to_callback = callback.value()
@@ -567,7 +567,7 @@ class RPCTest(unittest.TestCase):
with self.assertRaises(grpc.RpcError) as exception_context:
response_iterator = multi_callable(
request, timeout=test_constants.SHORT_TIMEOUT,
- metadata=((b'test', b'ExpiredUnaryRequestStreamResponse'),))
+ metadata=(('test', 'ExpiredUnaryRequestStreamResponse'),))
next(response_iterator)
self.assertIs(
@@ -583,7 +583,7 @@ class RPCTest(unittest.TestCase):
with self.assertRaises(grpc.RpcError) as exception_context:
multi_callable(
request_iterator, timeout=test_constants.SHORT_TIMEOUT,
- metadata=((b'test', b'ExpiredStreamRequestBlockingUnaryResponse'),))
+ metadata=(('test', 'ExpiredStreamRequestBlockingUnaryResponse'),))
self.assertIsNotNone(exception_context.exception.initial_metadata())
self.assertIs(
@@ -600,7 +600,7 @@ class RPCTest(unittest.TestCase):
with self._control.pause():
response_future = multi_callable.future(
request_iterator, timeout=test_constants.SHORT_TIMEOUT,
- metadata=((b'test', b'ExpiredStreamRequestFutureUnaryResponse'),))
+ metadata=(('test', 'ExpiredStreamRequestFutureUnaryResponse'),))
response_future.add_done_callback(callback)
value_passed_to_callback = callback.value()
@@ -625,7 +625,7 @@ class RPCTest(unittest.TestCase):
with self.assertRaises(grpc.RpcError) as exception_context:
response_iterator = multi_callable(
request_iterator, timeout=test_constants.SHORT_TIMEOUT,
- metadata=((b'test', b'ExpiredStreamRequestStreamResponse'),))
+ metadata=(('test', 'ExpiredStreamRequestStreamResponse'),))
next(response_iterator)
self.assertIs(
@@ -640,7 +640,7 @@ class RPCTest(unittest.TestCase):
with self.assertRaises(grpc.RpcError) as exception_context:
multi_callable.with_call(
request,
- metadata=((b'test', b'FailedUnaryRequestBlockingUnaryResponse'),))
+ metadata=(('test', 'FailedUnaryRequestBlockingUnaryResponse'),))
self.assertIs(grpc.StatusCode.UNKNOWN, exception_context.exception.code())
@@ -652,7 +652,7 @@ class RPCTest(unittest.TestCase):
with self._control.fail():
response_future = multi_callable.future(
request,
- metadata=((b'test', b'FailedUnaryRequestFutureUnaryResponse'),))
+ metadata=(('test', 'FailedUnaryRequestFutureUnaryResponse'),))
response_future.add_done_callback(callback)
value_passed_to_callback = callback.value()
@@ -672,7 +672,7 @@ class RPCTest(unittest.TestCase):
with self._control.fail():
response_iterator = multi_callable(
request,
- metadata=((b'test', b'FailedUnaryRequestStreamResponse'),))
+ metadata=(('test', 'FailedUnaryRequestStreamResponse'),))
next(response_iterator)
self.assertIs(grpc.StatusCode.UNKNOWN, exception_context.exception.code())
@@ -686,7 +686,7 @@ class RPCTest(unittest.TestCase):
with self.assertRaises(grpc.RpcError) as exception_context:
multi_callable(
request_iterator,
- metadata=((b'test', b'FailedStreamRequestBlockingUnaryResponse'),))
+ metadata=(('test', 'FailedStreamRequestBlockingUnaryResponse'),))
self.assertIs(grpc.StatusCode.UNKNOWN, exception_context.exception.code())
@@ -699,7 +699,7 @@ class RPCTest(unittest.TestCase):
with self._control.fail():
response_future = multi_callable.future(
request_iterator,
- metadata=((b'test', b'FailedStreamRequestFutureUnaryResponse'),))
+ metadata=(('test', 'FailedStreamRequestFutureUnaryResponse'),))
response_future.add_done_callback(callback)
value_passed_to_callback = callback.value()
@@ -720,7 +720,7 @@ class RPCTest(unittest.TestCase):
with self.assertRaises(grpc.RpcError) as exception_context:
response_iterator = multi_callable(
request_iterator,
- metadata=((b'test', b'FailedStreamRequestStreamResponse'),))
+ metadata=(('test', 'FailedStreamRequestStreamResponse'),))
tuple(response_iterator)
self.assertIs(grpc.StatusCode.UNKNOWN, exception_context.exception.code())
@@ -732,7 +732,7 @@ class RPCTest(unittest.TestCase):
multi_callable = _unary_unary_multi_callable(self._channel)
multi_callable.future(
request,
- metadata=((b'test', b'IgnoredUnaryRequestFutureUnaryResponse'),))
+ metadata=(('test', 'IgnoredUnaryRequestFutureUnaryResponse'),))
def testIgnoredUnaryRequestStreamResponse(self):
request = b'\x37\x17'
@@ -740,7 +740,7 @@ class RPCTest(unittest.TestCase):
multi_callable = _unary_stream_multi_callable(self._channel)
multi_callable(
request,
- metadata=((b'test', b'IgnoredUnaryRequestStreamResponse'),))
+ metadata=(('test', 'IgnoredUnaryRequestStreamResponse'),))
def testIgnoredStreamRequestFutureUnaryResponse(self):
requests = tuple(b'\x07\x18' for _ in range(test_constants.STREAM_LENGTH))
@@ -749,7 +749,7 @@ class RPCTest(unittest.TestCase):
multi_callable = _stream_unary_multi_callable(self._channel)
multi_callable.future(
request_iterator,
- metadata=((b'test', b'IgnoredStreamRequestFutureUnaryResponse'),))
+ metadata=(('test', 'IgnoredStreamRequestFutureUnaryResponse'),))
def testIgnoredStreamRequestStreamResponse(self):
requests = tuple(b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH))
@@ -758,7 +758,7 @@ class RPCTest(unittest.TestCase):
multi_callable = _stream_stream_multi_callable(self._channel)
multi_callable(
request_iterator,
- metadata=((b'test', b'IgnoredStreamRequestStreamResponse'),))
+ metadata=(('test', 'IgnoredStreamRequestStreamResponse'),))
if __name__ == '__main__':
diff --git a/src/python/grpcio/tests/unit/beta/_connectivity_channel_test.py b/src/python/grpcio/tests/unit/beta/_connectivity_channel_test.py
index 488f7d7141..5d826a269d 100644
--- a/src/python/grpcio/tests/unit/beta/_connectivity_channel_test.py
+++ b/src/python/grpcio/tests/unit/beta/_connectivity_channel_test.py
@@ -29,162 +29,9 @@
"""Tests of grpc.beta._connectivity_channel."""
-import threading
-import time
import unittest
-from grpc._adapter import _low
-from grpc._adapter import _types
-from grpc.beta import _connectivity_channel
from grpc.beta import interfaces
-from tests.unit.framework.common import test_constants
-
-
-def _drive_completion_queue(completion_queue):
- while True:
- event = completion_queue.next(time.time() + 24 * 60 * 60)
- if event.type == _types.EventType.QUEUE_SHUTDOWN:
- break
-
-
-class _Callback(object):
-
- def __init__(self):
- self._condition = threading.Condition()
- self._connectivities = []
-
- def update(self, connectivity):
- with self._condition:
- self._connectivities.append(connectivity)
- self._condition.notify()
-
- def connectivities(self):
- with self._condition:
- return tuple(self._connectivities)
-
- def block_until_connectivities_satisfy(self, predicate):
- with self._condition:
- while True:
- connectivities = tuple(self._connectivities)
- if predicate(connectivities):
- return connectivities
- else:
- self._condition.wait()
-
-
-class ChannelConnectivityTest(unittest.TestCase):
-
- def test_lonely_channel_connectivity(self):
- low_channel = _low.Channel('localhost:12345', ())
- callback = _Callback()
-
- connectivity_channel = _connectivity_channel.ConnectivityChannel(
- low_channel)
- connectivity_channel.subscribe(callback.update, try_to_connect=False)
- first_connectivities = callback.block_until_connectivities_satisfy(bool)
- connectivity_channel.subscribe(callback.update, try_to_connect=True)
- second_connectivities = callback.block_until_connectivities_satisfy(
- lambda connectivities: 2 <= len(connectivities))
- # Wait for a connection that will never happen.
- time.sleep(test_constants.SHORT_TIMEOUT)
- third_connectivities = callback.connectivities()
- connectivity_channel.unsubscribe(callback.update)
- fourth_connectivities = callback.connectivities()
- connectivity_channel.unsubscribe(callback.update)
- fifth_connectivities = callback.connectivities()
-
- self.assertSequenceEqual(
- (interfaces.ChannelConnectivity.IDLE,), first_connectivities)
- self.assertNotIn(
- interfaces.ChannelConnectivity.READY, second_connectivities)
- self.assertNotIn(
- interfaces.ChannelConnectivity.READY, third_connectivities)
- self.assertNotIn(
- interfaces.ChannelConnectivity.READY, fourth_connectivities)
- self.assertNotIn(
- interfaces.ChannelConnectivity.READY, fifth_connectivities)
-
- def test_immediately_connectable_channel_connectivity(self):
- server_completion_queue = _low.CompletionQueue()
- server = _low.Server(server_completion_queue, [])
- port = server.add_http2_port('[::]:0')
- server.start()
- server_completion_queue_thread = threading.Thread(
- target=_drive_completion_queue, args=(server_completion_queue,))
- server_completion_queue_thread.start()
- low_channel = _low.Channel('localhost:%d' % port, ())
- first_callback = _Callback()
- second_callback = _Callback()
-
- connectivity_channel = _connectivity_channel.ConnectivityChannel(
- low_channel)
- connectivity_channel.subscribe(first_callback.update, try_to_connect=False)
- first_connectivities = first_callback.block_until_connectivities_satisfy(
- bool)
- # Wait for a connection that will never happen because try_to_connect=True
- # has not yet been passed.
- time.sleep(test_constants.SHORT_TIMEOUT)
- second_connectivities = first_callback.connectivities()
- connectivity_channel.subscribe(second_callback.update, try_to_connect=True)
- third_connectivities = first_callback.block_until_connectivities_satisfy(
- lambda connectivities: 2 <= len(connectivities))
- fourth_connectivities = second_callback.block_until_connectivities_satisfy(
- bool)
- # Wait for a connection that will happen (or may already have happened).
- first_callback.block_until_connectivities_satisfy(
- lambda connectivities:
- interfaces.ChannelConnectivity.READY in connectivities)
- second_callback.block_until_connectivities_satisfy(
- lambda connectivities:
- interfaces.ChannelConnectivity.READY in connectivities)
- connectivity_channel.unsubscribe(first_callback.update)
- connectivity_channel.unsubscribe(second_callback.update)
-
- server.shutdown()
- server_completion_queue.shutdown()
- server_completion_queue_thread.join()
-
- self.assertSequenceEqual(
- (interfaces.ChannelConnectivity.IDLE,), first_connectivities)
- self.assertSequenceEqual(
- (interfaces.ChannelConnectivity.IDLE,), second_connectivities)
- self.assertNotIn(
- interfaces.ChannelConnectivity.TRANSIENT_FAILURE, third_connectivities)
- self.assertNotIn(
- interfaces.ChannelConnectivity.FATAL_FAILURE, third_connectivities)
- self.assertNotIn(
- interfaces.ChannelConnectivity.TRANSIENT_FAILURE,
- fourth_connectivities)
- self.assertNotIn(
- interfaces.ChannelConnectivity.FATAL_FAILURE, fourth_connectivities)
-
- def test_reachable_then_unreachable_channel_connectivity(self):
- server_completion_queue = _low.CompletionQueue()
- server = _low.Server(server_completion_queue, [])
- port = server.add_http2_port('[::]:0')
- server.start()
- server_completion_queue_thread = threading.Thread(
- target=_drive_completion_queue, args=(server_completion_queue,))
- server_completion_queue_thread.start()
- low_channel = _low.Channel('localhost:%d' % port, ())
- callback = _Callback()
-
- connectivity_channel = _connectivity_channel.ConnectivityChannel(
- low_channel)
- connectivity_channel.subscribe(callback.update, try_to_connect=True)
- callback.block_until_connectivities_satisfy(
- lambda connectivities:
- interfaces.ChannelConnectivity.READY in connectivities)
- # Now take down the server and confirm that channel readiness is repudiated.
- server.shutdown()
- callback.block_until_connectivities_satisfy(
- lambda connectivities:
- connectivities[-1] is not interfaces.ChannelConnectivity.READY)
- connectivity_channel.unsubscribe(callback.update)
-
- server.shutdown()
- server_completion_queue.shutdown()
- server_completion_queue_thread.join()
class ConnectivityStatesTest(unittest.TestCase):
diff --git a/src/python/grpcio/tests/unit/beta/_utilities_test.py b/src/python/grpcio/tests/unit/beta/_utilities_test.py
index 08ce98e751..90fe10c77c 100644
--- a/src/python/grpcio/tests/unit/beta/_utilities_test.py
+++ b/src/python/grpcio/tests/unit/beta/_utilities_test.py
@@ -33,21 +33,12 @@ import threading
import time
import unittest
-from grpc._adapter import _low
-from grpc._adapter import _types
from grpc.beta import implementations
from grpc.beta import utilities
from grpc.framework.foundation import future
from tests.unit.framework.common import test_constants
-def _drive_completion_queue(completion_queue):
- while True:
- event = completion_queue.next(time.time() + 24 * 60 * 60)
- if event.type == _types.EventType.QUEUE_SHUTDOWN:
- break
-
-
class _Callback(object):
def __init__(self):
@@ -87,13 +78,9 @@ class ChannelConnectivityTest(unittest.TestCase):
self.assertFalse(ready_future.running())
def test_immediately_connectable_channel_connectivity(self):
- server_completion_queue = _low.CompletionQueue()
- server = _low.Server(server_completion_queue, [])
- port = server.add_http2_port('[::]:0')
+ server = implementations.server({})
+ port = server.add_insecure_port('[::]:0')
server.start()
- server_completion_queue_thread = threading.Thread(
- target=_drive_completion_queue, args=(server_completion_queue,))
- server_completion_queue_thread.start()
channel = implementations.insecure_channel('localhost', port)
callback = _Callback()
@@ -114,9 +101,7 @@ class ChannelConnectivityTest(unittest.TestCase):
self.assertFalse(ready_future.running())
finally:
ready_future.cancel()
- server.shutdown()
- server_completion_queue.shutdown()
- server_completion_queue_thread.join()
+ server.stop(0)
if __name__ == '__main__':
diff --git a/src/python/grpcio/tests/unit/beta/test_utilities.py b/src/python/grpcio/tests/unit/beta/test_utilities.py
index 8ccad04e05..692da9c97d 100644
--- a/src/python/grpcio/tests/unit/beta/test_utilities.py
+++ b/src/python/grpcio/tests/unit/beta/test_utilities.py
@@ -51,5 +51,5 @@ def not_really_secure_channel(
target = '%s:%d' % (host, port)
channel = grpc.secure_channel(
target, channel_credentials,
- ((b'grpc.ssl_target_name_override', server_host_override,),))
+ (('grpc.ssl_target_name_override', server_host_override,),))
return implementations.Channel(channel)
diff --git a/src/python/grpcio/tests/unit/test_common.py b/src/python/grpcio/tests/unit/test_common.py
index b779f65e7e..c8886bf4ca 100644
--- a/src/python/grpcio/tests/unit/test_common.py
+++ b/src/python/grpcio/tests/unit/test_common.py
@@ -33,10 +33,10 @@ import collections
import six
-INVOCATION_INITIAL_METADATA = ((b'0', b'abc'), (b'1', b'def'), (b'2', b'ghi'),)
-SERVICE_INITIAL_METADATA = ((b'3', b'jkl'), (b'4', b'mno'), (b'5', b'pqr'),)
-SERVICE_TERMINAL_METADATA = ((b'6', b'stu'), (b'7', b'vwx'), (b'8', b'yza'),)
-DETAILS = b'test details'
+INVOCATION_INITIAL_METADATA = (('0', 'abc'), ('1', 'def'), ('2', 'ghi'),)
+SERVICE_INITIAL_METADATA = (('3', 'jkl'), ('4', 'mno'), ('5', 'pqr'),)
+SERVICE_TERMINAL_METADATA = (('6', 'stu'), ('7', 'vwx'), ('8', 'yza'),)
+DETAILS = 'test details'
def metadata_transmitted(original_metadata, transmitted_metadata):
@@ -59,16 +59,10 @@ def metadata_transmitted(original_metadata, transmitted_metadata):
original_metadata after having been transmitted via gRPC.
"""
original = collections.defaultdict(list)
- for key_value_pair in original_metadata:
- key, value = tuple(key_value_pair)
- if not isinstance(key, bytes):
- key = key.encode()
- if not isinstance(value, bytes):
- value = value.encode()
+ for key, value in original_metadata:
original[key].append(value)
transmitted = collections.defaultdict(list)
- for key_value_pair in transmitted_metadata:
- key, value = tuple(key_value_pair)
+ for key, value in transmitted_metadata:
transmitted[key].append(value)
for key, values in six.iteritems(original):