aboutsummaryrefslogtreecommitdiffhomepage
path: root/src/python/grpcio
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/grpcio')
-rw-r--r--src/python/grpcio/commands.py4
-rw-r--r--src/python/grpcio/grpc/__init__.py224
-rw-r--r--src/python/grpcio/grpc/_adapter/_low.py76
-rw-r--r--src/python/grpcio/grpc/_adapter/_types.py2
-rw-r--r--src/python/grpcio/grpc/_auth.py (renamed from src/python/grpcio/grpc/beta/_auth.py)21
-rw-r--r--src/python/grpcio/grpc/_channel.py55
-rw-r--r--src/python/grpcio/grpc/_common.py59
-rw-r--r--src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi7
-rw-r--r--src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi24
-rw-r--r--src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pxd.pxi3
-rw-r--r--src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi59
-rw-r--r--src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd.pxi2
-rw-r--r--src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi53
-rw-r--r--src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi6
-rw-r--r--src/python/grpcio/grpc/_cython/_cygrpc/grpc_string.pyx.pxi39
-rw-r--r--src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi70
-rw-r--r--src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi9
-rw-r--r--src/python/grpcio/grpc/_cython/cygrpc.pyx1
-rw-r--r--src/python/grpcio/grpc/_cython/imports.generated.c10
-rw-r--r--src/python/grpcio/grpc/_cython/imports.generated.h21
-rw-r--r--src/python/grpcio/grpc/_links/service.py6
-rw-r--r--src/python/grpcio/grpc/_server.py91
-rw-r--r--src/python/grpcio/grpc/_utilities.py26
-rw-r--r--src/python/grpcio/grpc/beta/_client_adaptations.py563
-rw-r--r--src/python/grpcio/grpc/beta/_server_adaptations.py367
-rw-r--r--src/python/grpcio/grpc/beta/implementations.py220
-rw-r--r--src/python/grpcio/grpc/beta/interfaces.py90
-rw-r--r--src/python/grpcio/grpc_core_dependencies.py13
-rw-r--r--src/python/grpcio/grpc_version.py2
-rw-r--r--src/python/grpcio/tests/interop/client.py3
-rw-r--r--src/python/grpcio/tests/interop/methods.py13
-rw-r--r--src/python/grpcio/tests/protoc_plugin/_python_plugin_test.py583
-rw-r--r--src/python/grpcio/tests/qps/benchmark_client.py102
-rw-r--r--src/python/grpcio/tests/qps/client_runner.py14
-rw-r--r--src/python/grpcio/tests/qps/qps_worker.py4
-rw-r--r--src/python/grpcio/tests/qps/worker_server.py5
-rw-r--r--src/python/grpcio/tests/stress/client.py6
-rw-r--r--src/python/grpcio/tests/tests.json40
-rw-r--r--src/python/grpcio/tests/unit/_adapter/_intermediary_low_test.py428
-rw-r--r--src/python/grpcio/tests/unit/_adapter/_low_test.py319
-rw-r--r--src/python/grpcio/tests/unit/_api_test.py111
-rw-r--r--src/python/grpcio/tests/unit/_auth_test.py (renamed from src/python/grpcio/tests/unit/beta/_auth_test.py)2
-rw-r--r--src/python/grpcio/tests/unit/_channel_connectivity_test.py4
-rw-r--r--src/python/grpcio/tests/unit/_core_over_links_base_interface_test.py157
-rw-r--r--src/python/grpcio/tests/unit/_crust_over_core_over_links_face_interface_test.py163
-rw-r--r--src/python/grpcio/tests/unit/_cython/_read_some_but_not_all_responses_test.py251
-rw-r--r--src/python/grpcio/tests/unit/_cython/cygrpc_test.py286
-rw-r--r--src/python/grpcio/tests/unit/_cython/test_utilities.py34
-rw-r--r--src/python/grpcio/tests/unit/_empty_message_test.py137
-rw-r--r--src/python/grpcio/tests/unit/_from_grpc_import_star.py (renamed from src/python/grpcio/grpc/_adapter/_implementations.py)24
-rw-r--r--src/python/grpcio/tests/unit/_links/_lonely_invocation_link_test.py88
-rw-r--r--src/python/grpcio/tests/unit/_links/_transmission_test.py239
-rw-r--r--src/python/grpcio/tests/unit/_metadata_code_details_test.py523
-rw-r--r--src/python/grpcio/tests/unit/_metadata_test.py216
-rw-r--r--src/python/grpcio/tests/unit/_rpc_test.py34
-rw-r--r--src/python/grpcio/tests/unit/_thread_cleanup_test.py117
-rw-r--r--src/python/grpcio/tests/unit/beta/_beta_features_test.py4
-rw-r--r--src/python/grpcio/tests/unit/beta/_connectivity_channel_test.py10
-rw-r--r--src/python/grpcio/tests/unit/beta/_not_found_test.py2
-rw-r--r--src/python/grpcio/tests/unit/beta/test_utilities.py13
-rw-r--r--src/python/grpcio/tests/unit/framework/_crust_over_core_face_interface_test.py113
-rw-r--r--src/python/grpcio/tests/unit/framework/core/_base_interface_test.py96
-rw-r--r--src/python/grpcio/tests/unit/framework/foundation/_later_test.py151
-rw-r--r--src/python/grpcio/tests/unit/framework/interfaces/base/test_cases.py6
-rw-r--r--src/python/grpcio/tests/unit/test_common.py4
65 files changed, 3740 insertions, 2685 deletions
diff --git a/src/python/grpcio/commands.py b/src/python/grpcio/commands.py
index 295dab2d27..f498ed4190 100644
--- a/src/python/grpcio/commands.py
+++ b/src/python/grpcio/commands.py
@@ -182,6 +182,8 @@ class BuildProtoModules(setuptools.Command):
'--plugin=protoc-gen-python-grpc={}'.format(
self.grpc_python_plugin_command),
'-I {}'.format(GRPC_STEM),
+ '-I .',
+ '-I {}/third_party/protobuf/src'.format(GRPC_STEM),
'--python_out={}'.format(PROTO_GEN_STEM),
'--python-grpc_out={}'.format(PROTO_GEN_STEM),
] + [path]
@@ -191,7 +193,7 @@ class BuildProtoModules(setuptools.Command):
except subprocess.CalledProcessError as e:
sys.stderr.write(
'warning: Command:\n{}\nMessage:\n{}\nOutput:\n{}'.format(
- command, e.message, e.output))
+ command, str(e), e.output))
# Generated proto directories dont include __init__.py, but
# these are needed for python package resolution
diff --git a/src/python/grpcio/grpc/__init__.py b/src/python/grpcio/grpc/__init__.py
index bbf04ad03e..b3eeaad1f7 100644
--- a/src/python/grpcio/grpc/__init__.py
+++ b/src/python/grpcio/grpc/__init__.py
@@ -29,8 +29,6 @@
"""gRPC's Python API."""
-__import__('pkg_resources').declare_namespace(__name__)
-
import abc
import enum
@@ -212,14 +210,14 @@ class ChannelConnectivity(enum.Enum):
READY: The channel is ready to conduct RPCs.
TRANSIENT_FAILURE: The channel has seen a failure from which it expects to
recover.
- FATAL_FAILURE: The channel has seen a failure from which it cannot recover.
+ SHUTDOWN: The channel has seen a failure from which it cannot recover.
"""
IDLE = (_cygrpc.ConnectivityState.idle, 'idle')
CONNECTING = (_cygrpc.ConnectivityState.connecting, 'connecting')
READY = (_cygrpc.ConnectivityState.ready, 'ready')
TRANSIENT_FAILURE = (
_cygrpc.ConnectivityState.transient_failure, 'transient failure')
- FATAL_FAILURE = (_cygrpc.ConnectivityState.fatal_failure, 'fatal failure')
+ SHUTDOWN = (_cygrpc.ConnectivityState.shutdown, 'shutdown')
@enum.unique
@@ -438,9 +436,7 @@ class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
"""Affords invoking a unary-unary RPC."""
@abc.abstractmethod
- def __call__(
- self, request, timeout=None, metadata=None, credentials=None,
- with_call=False):
+ def __call__(self, request, timeout=None, metadata=None, credentials=None):
"""Synchronously invokes the underlying RPC.
Args:
@@ -449,12 +445,30 @@ class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
metadata: An optional sequence of pairs of bytes to be transmitted to the
service-side of the RPC.
credentials: An optional CallCredentials for the RPC.
- with_call: Whether or not to include return a Call for the RPC in addition
- to the response.
Returns:
- The response value for the RPC, and a Call for the RPC if with_call was
- set to True at invocation.
+ The response value for the RPC.
+
+ Raises:
+ RpcError: Indicating that the RPC terminated with non-OK status. The
+ raised RpcError will also be a Call for the RPC affording the RPC's
+ metadata, status code, and details.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def with_call(self, request, timeout=None, metadata=None, credentials=None):
+ """Synchronously invokes the underlying RPC.
+
+ Args:
+ request: The request value for the RPC.
+ timeout: An optional durating of time in seconds to allow for the RPC.
+ metadata: An optional sequence of pairs of bytes to be transmitted to the
+ service-side of the RPC.
+ credentials: An optional CallCredentials for the RPC.
+
+ Returns:
+ The response value for the RPC and a Call value for the RPC.
Raises:
RpcError: Indicating that the RPC terminated with non-OK status. The
@@ -510,8 +524,7 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
@abc.abstractmethod
def __call__(
- self, request_iterator, timeout=None, metadata=None, credentials=None,
- with_call=False):
+ self, request_iterator, timeout=None, metadata=None, credentials=None):
"""Synchronously invokes the underlying RPC.
Args:
@@ -520,8 +533,6 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
metadata: An optional sequence of pairs of bytes to be transmitted to the
service-side of the RPC.
credentials: An optional CallCredentials for the RPC.
- with_call: Whether or not to include return a Call for the RPC in addition
- to the response.
Returns:
The response value for the RPC, and a Call for the RPC if with_call was
@@ -535,6 +546,28 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
raise NotImplementedError()
@abc.abstractmethod
+ def with_call(
+ self, request_iterator, timeout=None, metadata=None, credentials=None):
+ """Synchronously invokes the underlying RPC.
+
+ Args:
+ request_iterator: An iterator that yields request values for the RPC.
+ timeout: An optional duration of time in seconds to allow for the RPC.
+ metadata: An optional sequence of pairs of bytes to be transmitted to the
+ service-side of the RPC.
+ credentials: An optional CallCredentials for the RPC.
+
+ Returns:
+ The response value for the RPC and a Call for the RPC.
+
+ Raises:
+ RpcError: Indicating that the RPC terminated with non-OK status. The
+ raised RpcError will also be a Call for the RPC affording the RPC's
+ metadata, status code, and details.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
def future(
self, request_iterator, timeout=None, metadata=None, credentials=None):
"""Asynchronously invokes the underlying RPC.
@@ -900,6 +933,102 @@ class Server(six.with_metaclass(abc.ABCMeta)):
################################# Functions ################################
+def unary_unary_rpc_method_handler(
+ behavior, request_deserializer=None, response_serializer=None):
+ """Creates an RpcMethodHandler for a unary-unary RPC method.
+
+ Args:
+ behavior: The implementation of an RPC method as a callable behavior taking
+ a single request value and returning a single response value.
+ request_deserializer: An optional request deserialization behavior.
+ response_serializer: An optional response serialization behavior.
+
+ Returns:
+ An RpcMethodHandler for a unary-unary RPC method constructed from the given
+ parameters.
+ """
+ from grpc import _utilities
+ return _utilities.RpcMethodHandler(
+ False, False, request_deserializer, response_serializer, behavior, None,
+ None, None)
+
+
+def unary_stream_rpc_method_handler(
+ behavior, request_deserializer=None, response_serializer=None):
+ """Creates an RpcMethodHandler for a unary-stream RPC method.
+
+ Args:
+ behavior: The implementation of an RPC method as a callable behavior taking
+ a single request value and returning an iterator of response values.
+ request_deserializer: An optional request deserialization behavior.
+ response_serializer: An optional response serialization behavior.
+
+ Returns:
+ An RpcMethodHandler for a unary-stream RPC method constructed from the
+ given parameters.
+ """
+ from grpc import _utilities
+ return _utilities.RpcMethodHandler(
+ False, True, request_deserializer, response_serializer, None, behavior,
+ None, None)
+
+
+def stream_unary_rpc_method_handler(
+ behavior, request_deserializer=None, response_serializer=None):
+ """Creates an RpcMethodHandler for a stream-unary RPC method.
+
+ Args:
+ behavior: The implementation of an RPC method as a callable behavior taking
+ an iterator of request values and returning a single response value.
+ request_deserializer: An optional request deserialization behavior.
+ response_serializer: An optional response serialization behavior.
+
+ Returns:
+ An RpcMethodHandler for a stream-unary RPC method constructed from the
+ given parameters.
+ """
+ from grpc import _utilities
+ return _utilities.RpcMethodHandler(
+ True, False, request_deserializer, response_serializer, None, None,
+ behavior, None)
+
+
+def stream_stream_rpc_method_handler(
+ behavior, request_deserializer=None, response_serializer=None):
+ """Creates an RpcMethodHandler for a stream-stream RPC method.
+
+ Args:
+ behavior: The implementation of an RPC method as a callable behavior taking
+ an iterator of request values and returning an iterator of response
+ values.
+ request_deserializer: An optional request deserialization behavior.
+ response_serializer: An optional response serialization behavior.
+
+ Returns:
+ An RpcMethodHandler for a stream-stream RPC method constructed from the
+ given parameters.
+ """
+ from grpc import _utilities
+ return _utilities.RpcMethodHandler(
+ True, True, request_deserializer, response_serializer, None, None, None,
+ behavior)
+
+
+def method_handlers_generic_handler(service, method_handlers):
+ """Creates a grpc.GenericRpcHandler from RpcMethodHandlers.
+
+ Args:
+ service: A service name to be used for the given method handlers.
+ method_handlers: A dictionary from method name to RpcMethodHandler
+ implementing the named method.
+
+ Returns:
+ A GenericRpcHandler constructed from the given parameters.
+ """
+ from grpc import _utilities
+ return _utilities.DictionaryGenericHandler(service, method_handlers)
+
+
def ssl_channel_credentials(
root_certificates=None, private_key=None, certificate_chain=None):
"""Creates a ChannelCredentials for use with an SSL-enabled Channel.
@@ -947,6 +1076,21 @@ def metadata_call_credentials(metadata_plugin, name=None):
metadata_plugin, effective_name))
+def access_token_call_credentials(access_token):
+ """Construct CallCredentials from an access token.
+
+ Args:
+ access_token: A string to place directly in the http request
+ authorization header, ie "Authorization: Bearer <access_token>".
+
+ Returns:
+ A CallCredentials.
+ """
+ from grpc import _auth
+ return metadata_call_credentials(
+ _auth.AccessTokenCallCredentials(access_token))
+
+
def composite_call_credentials(call_credentials, additional_call_credentials):
"""Compose two CallCredentials to make a new one.
@@ -1044,7 +1188,7 @@ def insecure_channel(target, options=None):
A Channel to the target through which RPCs may be conducted.
"""
from grpc import _channel
- return _channel.Channel(target, None, options)
+ return _channel.Channel(target, options, None)
def secure_channel(target, credentials, options=None):
@@ -1060,7 +1204,7 @@ def secure_channel(target, credentials, options=None):
A Channel to the target through which RPCs may be conducted.
"""
from grpc import _channel
- return _channel.Channel(target, credentials, options)
+ return _channel.Channel(target, options, credentials._credentials)
def server(generic_rpc_handlers, thread_pool, options=None):
@@ -1082,3 +1226,49 @@ def server(generic_rpc_handlers, thread_pool, options=None):
"""
from grpc import _server
return _server.Server(generic_rpc_handlers, thread_pool)
+
+
+################################### __all__ #################################
+
+
+__all__ = (
+ 'FutureTimeoutError',
+ 'FutureCancelledError',
+ 'Future',
+ 'ChannelConnectivity',
+ 'StatusCode',
+ 'RpcError',
+ 'RpcContext',
+ 'Call',
+ 'ChannelCredentials',
+ 'CallCredentials',
+ 'AuthMetadataContext',
+ 'AuthMetadataPluginCallback',
+ 'AuthMetadataPlugin',
+ 'ServerCredentials',
+ 'UnaryUnaryMultiCallable',
+ 'UnaryStreamMultiCallable',
+ 'StreamUnaryMultiCallable',
+ 'StreamStreamMultiCallable',
+ 'Channel',
+ 'ServicerContext',
+ 'RpcMethodHandler',
+ 'HandlerCallDetails',
+ 'GenericRpcHandler',
+ 'Server',
+ 'unary_unary_rpc_method_handler',
+ 'unary_stream_rpc_method_handler',
+ 'stream_unary_rpc_method_handler',
+ 'stream_stream_rpc_method_handler',
+ 'method_handlers_generic_handler',
+ 'ssl_channel_credentials',
+ 'metadata_call_credentials',
+ 'access_token_call_credentials',
+ 'composite_call_credentials',
+ 'composite_channel_credentials',
+ 'ssl_server_credentials',
+ 'channel_ready_future',
+ 'insecure_channel',
+ 'secure_channel',
+ 'server',
+)
diff --git a/src/python/grpcio/grpc/_adapter/_low.py b/src/python/grpcio/grpc/_adapter/_low.py
index 00788bd4cf..48410167a0 100644
--- a/src/python/grpcio/grpc/_adapter/_low.py
+++ b/src/python/grpcio/grpc/_adapter/_low.py
@@ -30,8 +30,8 @@
import threading
from grpc import _grpcio_metadata
+from grpc import _plugin_wrapping
from grpc._cython import cygrpc
-from grpc._adapter import _implementations
from grpc._adapter import _types
_USER_AGENT = 'Python-gRPC-{}'.format(_grpcio_metadata.__version__)
@@ -57,78 +57,8 @@ def channel_credentials_ssl(
return cygrpc.channel_credentials_ssl(root_certificates, pair)
-class _WrappedCygrpcCallback(object):
-
- def __init__(self, cygrpc_callback):
- self.is_called = False
- self.error = None
- self.is_called_lock = threading.Lock()
- self.cygrpc_callback = cygrpc_callback
-
- def _invoke_failure(self, error):
- # TODO(atash) translate different Exception superclasses into different
- # status codes.
- self.cygrpc_callback(
- cygrpc.Metadata([]), cygrpc.StatusCode.internal, error.message)
-
- def _invoke_success(self, metadata):
- try:
- cygrpc_metadata = cygrpc.Metadata(
- cygrpc.Metadatum(key, value)
- for key, value in metadata)
- except Exception as error:
- self._invoke_failure(error)
- return
- self.cygrpc_callback(cygrpc_metadata, cygrpc.StatusCode.ok, '')
-
- def __call__(self, metadata, error):
- with self.is_called_lock:
- if self.is_called:
- raise RuntimeError('callback should only ever be invoked once')
- if self.error:
- self._invoke_failure(self.error)
- return
- self.is_called = True
- if error is None:
- self._invoke_success(metadata)
- else:
- self._invoke_failure(error)
-
- def notify_failure(self, error):
- with self.is_called_lock:
- if not self.is_called:
- self.error = error
-
-
-class _WrappedPlugin(object):
-
- def __init__(self, plugin):
- self.plugin = plugin
-
- def __call__(self, context, cygrpc_callback):
- wrapped_cygrpc_callback = _WrappedCygrpcCallback(cygrpc_callback)
- wrapped_context = _implementations.AuthMetadataContext(context.service_url,
- context.method_name)
- try:
- self.plugin(
- wrapped_context,
- _implementations.AuthMetadataPluginCallback(wrapped_cygrpc_callback))
- except Exception as error:
- wrapped_cygrpc_callback.notify_failure(error)
- raise
-
-
-def call_credentials_metadata_plugin(plugin, name):
- """
- Args:
- plugin: A callable accepting a _types.AuthMetadataContext
- object and a callback (itself accepting a list of metadata key/value
- 2-tuples and a None-able exception value). The callback must be eventually
- called, but need not be called in plugin's invocation.
- plugin's invocation must be non-blocking.
- """
- return cygrpc.call_credentials_metadata_plugin(
- cygrpc.CredentialsMetadataPlugin(_WrappedPlugin(plugin), name))
+call_credentials_metadata_plugin = (
+ _plugin_wrapping.call_credentials_metadata_plugin)
class CompletionQueue(_types.CompletionQueue):
diff --git a/src/python/grpcio/grpc/_adapter/_types.py b/src/python/grpcio/grpc/_adapter/_types.py
index f8405949d4..b7cc6fbbb5 100644
--- a/src/python/grpcio/grpc/_adapter/_types.py
+++ b/src/python/grpcio/grpc/_adapter/_types.py
@@ -114,7 +114,7 @@ class ConnectivityState(enum.IntEnum):
CONNECTING = cygrpc.ConnectivityState.connecting
READY = cygrpc.ConnectivityState.ready
TRANSIENT_FAILURE = cygrpc.ConnectivityState.transient_failure
- FATAL_FAILURE = cygrpc.ConnectivityState.fatal_failure
+ FATAL_FAILURE = cygrpc.ConnectivityState.shutdown
class Status(collections.namedtuple(
diff --git a/src/python/grpcio/grpc/beta/_auth.py b/src/python/grpcio/grpc/_auth.py
index 553d4b9991..dea3221c9d 100644
--- a/src/python/grpcio/grpc/beta/_auth.py
+++ b/src/python/grpcio/grpc/_auth.py
@@ -29,9 +29,10 @@
"""GRPCAuthMetadataPlugins for standard authentication."""
+import inspect
from concurrent import futures
-from grpc.beta import interfaces
+import grpc
def _sign_request(callback, token, error):
@@ -39,16 +40,28 @@ def _sign_request(callback, token, error):
callback(metadata, error)
-class GoogleCallCredentials(interfaces.GRPCAuthMetadataPlugin):
+class GoogleCallCredentials(grpc.AuthMetadataPlugin):
"""Metadata wrapper for GoogleCredentials from the oauth2client library."""
def __init__(self, credentials):
self._credentials = credentials
self._pool = futures.ThreadPoolExecutor(max_workers=1)
+ # Hack to determine if these are JWT creds and we need to pass
+ # additional_claims when getting a token
+ if 'additional_claims' in inspect.getargspec(
+ credentials.get_access_token).args:
+ self._is_jwt = True
+ else:
+ self._is_jwt = False
+
def __call__(self, context, callback):
# MetadataPlugins cannot block (see grpc.beta.interfaces.py)
- future = self._pool.submit(self._credentials.get_access_token)
+ if self._is_jwt:
+ future = self._pool.submit(self._credentials.get_access_token,
+ additional_claims={'aud': context.service_url})
+ else:
+ future = self._pool.submit(self._credentials.get_access_token)
future.add_done_callback(lambda x: self._get_token_callback(callback, x))
def _get_token_callback(self, callback, future):
@@ -63,7 +76,7 @@ class GoogleCallCredentials(interfaces.GRPCAuthMetadataPlugin):
self._pool.shutdown(wait=False)
-class AccessTokenCallCredentials(interfaces.GRPCAuthMetadataPlugin):
+class AccessTokenCallCredentials(grpc.AuthMetadataPlugin):
"""Metadata wrapper for raw access token credentials."""
def __init__(self, access_token):
diff --git a/src/python/grpcio/grpc/_channel.py b/src/python/grpcio/grpc/_channel.py
index d9eb5a4b77..1f34beeb2c 100644
--- a/src/python/grpcio/grpc/_channel.py
+++ b/src/python/grpcio/grpc/_channel.py
@@ -85,7 +85,7 @@ def _deadline(timeout):
def _unknown_code_details(unknown_cygrpc_code, details):
- return b'Server sent unknown code {} and details "{}"'.format(
+ return 'Server sent unknown code {} and details "{}"'.format(
unknown_cygrpc_code, details)
@@ -134,19 +134,19 @@ def _handle_event(event, state, response_deserializer):
for batch_operation in event.batch_operations:
operation_type = batch_operation.type
state.due.remove(operation_type)
- if operation_type is cygrpc.OperationType.receive_initial_metadata:
+ if operation_type == cygrpc.OperationType.receive_initial_metadata:
state.initial_metadata = batch_operation.received_metadata
- elif operation_type is cygrpc.OperationType.receive_message:
+ elif operation_type == cygrpc.OperationType.receive_message:
serialized_response = batch_operation.received_message.bytes()
if serialized_response is not None:
response = _common.deserialize(
serialized_response, response_deserializer)
if response is None:
- details = b'Exception deserializing response!'
+ details = 'Exception deserializing response!'
_abort(state, grpc.StatusCode.INTERNAL, details)
else:
state.response = response
- elif operation_type is cygrpc.OperationType.receive_status_on_client:
+ elif operation_type == cygrpc.OperationType.receive_status_on_client:
state.trailing_metadata = batch_operation.received_metadata
if state.code is None:
code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE.get(
@@ -186,7 +186,7 @@ def _consume_request_iterator(
if state.code is None and not state.cancelled:
if serialized_request is None:
call.cancel()
- details = b'Exception serializing request!'
+ details = 'Exception serializing request!'
_abort(state, grpc.StatusCode.INTERNAL, details)
return
else:
@@ -230,7 +230,7 @@ class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call):
if self._state.code is None:
self._call.cancel()
self._state.cancelled = True
- _abort(self._state, grpc.StatusCode.CANCELLED, b'Cancelled!')
+ _abort(self._state, grpc.StatusCode.CANCELLED, 'Cancelled!')
self._state.condition.notify_all()
return False
@@ -402,7 +402,7 @@ def _start_unary_request(request, timeout, request_serializer):
if serialized_request is None:
state = _RPCState(
(), _EMPTY_METADATA, _EMPTY_METADATA, grpc.StatusCode.INTERNAL,
- b'Exception serializing request!')
+ 'Exception serializing request!')
rendezvous = _Rendezvous(state, None, None, deadline)
return deadline, deadline_timespec, None, rendezvous
else:
@@ -449,9 +449,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
)
return state, operations, deadline, deadline_timespec, None
- def __call__(
- self, request, timeout=None, metadata=None, credentials=None,
- with_call=False):
+ def _blocking(self, request, timeout, metadata, credentials):
state, operations, deadline, deadline_timespec, rendezvous = self._prepare(
request, timeout, metadata)
if rendezvous:
@@ -464,7 +462,15 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
call.set_credentials(credentials._credentials)
call.start_batch(cygrpc.Operations(operations), None)
_handle_event(completion_queue.poll(), state, self._response_deserializer)
- return _end_unary_response_blocking(state, with_call, deadline)
+ return state, deadline
+
+ def __call__(self, request, timeout=None, metadata=None, credentials=None):
+ state, deadline, = self._blocking(request, timeout, metadata, credentials)
+ return _end_unary_response_blocking(state, False, deadline)
+
+ def with_call(self, request, timeout=None, metadata=None, credentials=None):
+ state, deadline, = self._blocking(request, timeout, metadata, credentials)
+ return _end_unary_response_blocking(state, True, deadline)
def future(self, request, timeout=None, metadata=None, credentials=None):
state, operations, deadline, deadline_timespec, rendezvous = self._prepare(
@@ -532,9 +538,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
- def __call__(
- self, request_iterator, timeout=None, metadata=None, credentials=None,
- with_call=False):
+ def _blocking(self, request_iterator, timeout, metadata, credentials):
deadline, deadline_timespec = _deadline(timeout)
state = _RPCState(_STREAM_UNARY_INITIAL_DUE, None, None, None, None)
completion_queue = cygrpc.CompletionQueue()
@@ -563,7 +567,19 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
state.condition.notify_all()
if not state.due:
break
- return _end_unary_response_blocking(state, with_call, deadline)
+ return state, deadline
+
+ def __call__(
+ self, request_iterator, timeout=None, metadata=None, credentials=None):
+ state, deadline, = self._blocking(
+ request_iterator, timeout, metadata, credentials)
+ return _end_unary_response_blocking(state, False, deadline)
+
+ def with_call(
+ self, request_iterator, timeout=None, metadata=None, credentials=None):
+ state, deadline, = self._blocking(
+ request_iterator, timeout, metadata, credentials)
+ return _end_unary_response_blocking(state, True, deadline)
def future(
self, request_iterator, timeout=None, metadata=None, credentials=None):
@@ -814,6 +830,13 @@ def _options(options):
class Channel(grpc.Channel):
def __init__(self, target, options, credentials):
+ """Constructor.
+
+ Args:
+ target: The target to which to connect.
+ options: Configuration options for the channel.
+ credentials: A cygrpc.ChannelCredentials or None.
+ """
self._channel = cygrpc.Channel(target, _options(options), credentials)
self._call_state = _ChannelCallState(self._channel)
self._connectivity_state = _ChannelConnectivityState(self._channel)
diff --git a/src/python/grpcio/grpc/_common.py b/src/python/grpcio/grpc/_common.py
index a3fb66cd07..f351bea9e3 100644
--- a/src/python/grpcio/grpc/_common.py
+++ b/src/python/grpcio/grpc/_common.py
@@ -30,6 +30,8 @@
"""Shared implementation."""
import logging
+import threading
+import time
import six
@@ -44,8 +46,8 @@ CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY = {
cygrpc.ConnectivityState.ready: grpc.ChannelConnectivity.READY,
cygrpc.ConnectivityState.transient_failure:
grpc.ChannelConnectivity.TRANSIENT_FAILURE,
- cygrpc.ConnectivityState.fatal_failure:
- grpc.ChannelConnectivity.FATAL_FAILURE,
+ cygrpc.ConnectivityState.shutdown:
+ grpc.ChannelConnectivity.SHUTDOWN,
}
CYGRPC_STATUS_CODE_TO_STATUS_CODE = {
@@ -97,3 +99,56 @@ def serialize(message, serializer):
def deserialize(serialized_message, deserializer):
return _transform(serialized_message, deserializer,
'Exception deserializing message!')
+
+
+def _encode(s):
+ if isinstance(s, bytes):
+ return s
+ else:
+ return s.encode('ascii')
+
+
+def fully_qualified_method(group, method):
+ group = _encode(group)
+ method = _encode(method)
+ return b'/' + group + b'/' + method
+
+
+class CleanupThread(threading.Thread):
+ """A threading.Thread subclass supporting custom behavior on join().
+
+ On Python Interpreter exit, Python will attempt to join outstanding threads
+ prior to garbage collection. We may need to do additional cleanup, and
+ we accomplish this by overriding the join() method.
+ """
+
+ def __init__(self, behavior, group=None, target=None, name=None,
+ args=(), kwargs={}):
+ """Constructor.
+
+ Args:
+ behavior (function): Function called on join() with a single
+ argument, timeout, indicating the maximum duration of
+ `behavior`, or None indicating `behavior` has no deadline.
+ `behavior` must be idempotent.
+ group (None): should be None. Reseved for future extensions
+ when ThreadGroup is implemented.
+ target (function): The function to invoke when this thread is
+ run. Defaults to None.
+ name (str): The name of this thread. Defaults to None.
+ args (tuple[object]): A tuple of arguments to pass to `target`.
+ kwargs (dict[str,object]): A dictionary of keyword arguments to
+ pass to `target`.
+ """
+ super(CleanupThread, self).__init__(group=group, target=target,
+ name=name, args=args, kwargs=kwargs)
+ self._behavior = behavior
+
+ def join(self, timeout=None):
+ start_time = time.time()
+ self._behavior(timeout)
+ end_time = time.time()
+ if timeout is not None:
+ timeout -= end_time - start_time
+ timeout = max(timeout, 0)
+ super(CleanupThread, self).join(timeout)
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi
index 1bfe6344e0..a09bbc75fe 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi
@@ -55,6 +55,7 @@ cdef class Call:
def cancel(
self, grpc_status_code error_code=GRPC_STATUS__DO_NOT_USE,
details=None):
+ details = str_to_bytes(details)
if not self.is_valid:
raise ValueError("invalid call object cannot be used from Python")
if (details is None) != (error_code == GRPC_STATUS__DO_NOT_USE):
@@ -63,12 +64,6 @@ cdef class Call:
cdef grpc_call_error result
cdef char *c_details = NULL
if error_code != GRPC_STATUS__DO_NOT_USE:
- if isinstance(details, bytes):
- pass
- elif isinstance(details, basestring):
- details = details.encode()
- else:
- raise TypeError("expected details to be str or bytes")
self.references.append(details)
c_details = details
with nogil:
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi
index c26bc083cf..866cff0d01 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi
@@ -34,18 +34,13 @@ cdef class Channel:
def __cinit__(self, target, ChannelArgs arguments=None,
ChannelCredentials channel_credentials=None):
+ target = str_to_bytes(target)
cdef grpc_channel_args *c_arguments = NULL
cdef char *c_target = NULL
self.c_channel = NULL
self.references = []
if arguments is not None:
c_arguments = &arguments.c_args
- if isinstance(target, bytes):
- pass
- elif isinstance(target, basestring):
- target = target.encode()
- else:
- raise TypeError("expected target to be str or bytes")
c_target = target
if channel_credentials is None:
with nogil:
@@ -62,25 +57,14 @@ cdef class Channel:
def create_call(self, Call parent, int flags,
CompletionQueue queue not None,
method, host, Timespec deadline not None):
+ method = str_to_bytes(method)
+ host = str_to_bytes(host)
if queue.is_shutting_down:
raise ValueError("queue must not be shutting down or shutdown")
- if isinstance(method, bytes):
- pass
- elif isinstance(method, basestring):
- method = method.encode()
- else:
- raise TypeError("expected method to be str or bytes")
cdef char *method_c_string = method
cdef char *host_c_string = NULL
- if host is None:
- pass
- elif isinstance(host, bytes):
+ if host is not None:
host_c_string = host
- elif isinstance(host, basestring):
- host = host.encode()
- host_c_string = host
- else:
- raise TypeError("expected host to be str, bytes, or None")
cdef Call operation_call = Call()
operation_call.references = [self, method, host, queue]
cdef grpc_call *parent_call = NULL
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pxd.pxi
index a67c963684..01089c3dc0 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pxd.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pxd.pxi
@@ -31,9 +31,6 @@
cdef class CompletionQueue:
cdef grpc_completion_queue *c_completion_queue
- cdef object pluck_condition
- cdef int num_plucking
- cdef int num_polling
cdef bint is_shutting_down
cdef bint is_shutdown
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi
index cdae39d519..90266516fe 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi
@@ -32,6 +32,8 @@ cimport cpython
import threading
import time
+cdef int _INTERRUPT_CHECK_PERIOD_MS = 200
+
cdef class CompletionQueue:
@@ -40,9 +42,6 @@ cdef class CompletionQueue:
self.c_completion_queue = grpc_completion_queue_create(NULL)
self.is_shutting_down = False
self.is_shutdown = False
- self.pluck_condition = threading.Condition()
- self.num_plucking = 0
- self.num_polling = 0
cdef _interpret_event(self, grpc_event event):
cdef OperationTag tag = None
@@ -83,45 +82,27 @@ cdef class CompletionQueue:
def poll(self, Timespec deadline=None):
# We name this 'poll' to avoid problems with CPython's expectations for
# 'special' methods (like next and __next__).
+ cdef gpr_timespec c_increment
+ cdef gpr_timespec c_timeout
cdef gpr_timespec c_deadline
with nogil:
+ c_increment = gpr_time_from_millis(_INTERRUPT_CHECK_PERIOD_MS, GPR_TIMESPAN)
c_deadline = gpr_inf_future(GPR_CLOCK_REALTIME)
- if deadline is not None:
- c_deadline = deadline.c_time
- cdef grpc_event event
-
- # Poll within a critical section to detect contention
- with self.pluck_condition:
- assert self.num_plucking == 0, 'cannot simultaneously pluck and poll'
- self.num_polling += 1
- with nogil:
- event = grpc_completion_queue_next(
- self.c_completion_queue, c_deadline, NULL)
- with self.pluck_condition:
- self.num_polling -= 1
- return self._interpret_event(event)
-
- def pluck(self, OperationTag tag, Timespec deadline=None):
- # Plucking a 'None' tag is equivalent to passing control to GRPC core until
- # the deadline.
- cdef gpr_timespec c_deadline = gpr_inf_future(
- GPR_CLOCK_REALTIME)
- if deadline is not None:
- c_deadline = deadline.c_time
- cdef grpc_event event
-
- # Pluck within a critical section to detect contention
- with self.pluck_condition:
- assert self.num_polling == 0, 'cannot simultaneously pluck and poll'
- assert self.num_plucking < GRPC_MAX_COMPLETION_QUEUE_PLUCKERS, (
- 'cannot pluck more than {} times simultaneously'.format(
- GRPC_MAX_COMPLETION_QUEUE_PLUCKERS))
- self.num_plucking += 1
- with nogil:
- event = grpc_completion_queue_pluck(
- self.c_completion_queue, <cpython.PyObject *>tag, c_deadline, NULL)
- with self.pluck_condition:
- self.num_plucking -= 1
+ if deadline is not None:
+ c_deadline = deadline.c_time
+
+ while True:
+ c_timeout = gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), c_increment)
+ if gpr_time_cmp(c_timeout, c_deadline) > 0:
+ c_timeout = c_deadline
+ event = grpc_completion_queue_next(
+ self.c_completion_queue, c_timeout, NULL)
+ if event.type != GRPC_QUEUE_TIMEOUT or gpr_time_cmp(c_timeout, c_deadline) == 0:
+ break;
+
+ # Handle any signals
+ with gil:
+ cpython.PyErr_CheckSignals()
return self._interpret_event(event)
def shutdown(self):
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd.pxi
index 19a59e08f3..d377a67520 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd.pxi
@@ -54,7 +54,7 @@ cdef class ServerCredentials:
cdef class CredentialsMetadataPlugin:
cdef object plugin_callback
- cdef str plugin_name
+ cdef bytes plugin_name
cdef grpc_metadata_credentials_plugin make_c_plugin(self)
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi
index 1ba86457af..470382d609 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi
@@ -82,7 +82,7 @@ cdef class ServerCredentials:
cdef class CredentialsMetadataPlugin:
- def __cinit__(self, object plugin_callback, str name):
+ def __cinit__(self, object plugin_callback, name):
"""
Args:
plugin_callback (callable): Callback accepting a service URL (str/bytes)
@@ -93,6 +93,7 @@ cdef class CredentialsMetadataPlugin:
successful).
name (str): Plugin name.
"""
+ name = str_to_bytes(name)
if not callable(plugin_callback):
raise ValueError('expected callable plugin_callback')
self.plugin_callback = plugin_callback
@@ -129,7 +130,8 @@ cdef void plugin_get_metadata(
grpc_credentials_plugin_metadata_cb cb, void *user_data) with gil:
def python_callback(
Metadata metadata, grpc_status_code status,
- const char *error_details):
+ error_details):
+ error_details = str_to_bytes(error_details)
cb(user_data, metadata.c_metadata_array.metadata,
metadata.c_metadata_array.count, status, error_details)
cdef CredentialsMetadataPlugin self = <CredentialsMetadataPlugin>state
@@ -148,14 +150,7 @@ def channel_credentials_google_default():
def channel_credentials_ssl(pem_root_certificates,
SslPemKeyCertPair ssl_pem_key_cert_pair):
- if pem_root_certificates is None:
- pass
- elif isinstance(pem_root_certificates, bytes):
- pass
- elif isinstance(pem_root_certificates, basestring):
- pem_root_certificates = pem_root_certificates.encode()
- else:
- raise TypeError("expected str or bytes for pem_root_certificates")
+ pem_root_certificates = str_to_bytes(pem_root_certificates)
cdef ChannelCredentials credentials = ChannelCredentials()
cdef const char *c_pem_root_certificates = NULL
if pem_root_certificates is not None:
@@ -207,12 +202,7 @@ def call_credentials_google_compute_engine():
def call_credentials_service_account_jwt_access(
json_key, Timespec token_lifetime not None):
- if isinstance(json_key, bytes):
- pass
- elif isinstance(json_key, basestring):
- json_key = json_key.encode()
- else:
- raise TypeError("expected json_key to be str or bytes")
+ json_key = str_to_bytes(json_key)
cdef CallCredentials credentials = CallCredentials()
cdef char *json_key_c_string = json_key
with nogil:
@@ -223,12 +213,7 @@ def call_credentials_service_account_jwt_access(
return credentials
def call_credentials_google_refresh_token(json_refresh_token):
- if isinstance(json_refresh_token, bytes):
- pass
- elif isinstance(json_refresh_token, basestring):
- json_refresh_token = json_refresh_token.encode()
- else:
- raise TypeError("expected json_refresh_token to be str or bytes")
+ json_refresh_token = str_to_bytes(json_refresh_token)
cdef CallCredentials credentials = CallCredentials()
cdef char *json_refresh_token_c_string = json_refresh_token
with nogil:
@@ -238,18 +223,8 @@ def call_credentials_google_refresh_token(json_refresh_token):
return credentials
def call_credentials_google_iam(authorization_token, authority_selector):
- if isinstance(authorization_token, bytes):
- pass
- elif isinstance(authorization_token, basestring):
- authorization_token = authorization_token.encode()
- else:
- raise TypeError("expected authorization_token to be str or bytes")
- if isinstance(authority_selector, bytes):
- pass
- elif isinstance(authority_selector, basestring):
- authority_selector = authority_selector.encode()
- else:
- raise TypeError("expected authority_selector to be str or bytes")
+ authorization_token = str_to_bytes(authorization_token)
+ authority_selector = str_to_bytes(authority_selector)
cdef CallCredentials credentials = CallCredentials()
cdef char *authorization_token_c_string = authorization_token
cdef char *authority_selector_c_string = authority_selector
@@ -272,16 +247,10 @@ def call_credentials_metadata_plugin(CredentialsMetadataPlugin plugin):
def server_credentials_ssl(pem_root_certs, pem_key_cert_pairs,
bint force_client_auth):
+ pem_root_certs = str_to_bytes(pem_root_certs)
cdef char *c_pem_root_certs = NULL
- if pem_root_certs is None:
- pass
- elif isinstance(pem_root_certs, bytes):
- c_pem_root_certs = pem_root_certs
- elif isinstance(pem_root_certs, basestring):
- pem_root_certs = pem_root_certs.encode()
+ if pem_root_certs is not None:
c_pem_root_certs = pem_root_certs
- else:
- raise TypeError("expected pem_root_certs to be str or bytes")
pem_key_cert_pairs = list(pem_key_cert_pairs)
for pair in pem_key_cert_pairs:
if not isinstance(pair, SslPemKeyCertPair):
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi
index 05b8886df7..168b9751aa 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi
@@ -80,6 +80,12 @@ cdef extern from "grpc/_cython/loader.h":
gpr_timespec gpr_convert_clock_type(gpr_timespec t,
gpr_clock_type target_clock) nogil
+ gpr_timespec gpr_time_from_millis(int64_t ms, gpr_clock_type type) nogil
+
+ gpr_timespec gpr_time_add(gpr_timespec a, gpr_timespec b) nogil
+
+ int gpr_time_cmp(gpr_timespec a, gpr_timespec b) nogil
+
ctypedef enum grpc_status_code:
GRPC_STATUS_OK
GRPC_STATUS_CANCELLED
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/grpc_string.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/grpc_string.pyx.pxi
new file mode 100644
index 0000000000..274f038e0a
--- /dev/null
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/grpc_string.pyx.pxi
@@ -0,0 +1,39 @@
+# Copyright 2016, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+# This function will ascii encode unicode string inputs if neccesary.
+# In Python3, unicode strings are the default str type.
+cdef bytes str_to_bytes(object s):
+ if s is None or isinstance(s, bytes):
+ return s
+ elif isinstance(s, unicode):
+ return s.encode('ascii')
+ else:
+ raise TypeError('Expected bytes, str, or unicode, not {}'.format(type(s)))
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi
index e0219b0086..0055d0d3a2 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi
@@ -33,7 +33,7 @@ class ConnectivityState:
connecting = GRPC_CHANNEL_CONNECTING
ready = GRPC_CHANNEL_READY
transient_failure = GRPC_CHANNEL_TRANSIENT_FAILURE
- fatal_failure = GRPC_CHANNEL_SHUTDOWN
+ shutdown = GRPC_CHANNEL_SHUTDOWN
class ChannelArgKey:
@@ -235,18 +235,13 @@ cdef class ByteBuffer:
if data is None:
self.c_byte_buffer = NULL
return
- if isinstance(data, bytes):
- pass
- elif isinstance(data, basestring):
- data = data.encode()
- elif isinstance(data, ByteBuffer):
+ if isinstance(data, ByteBuffer):
data = (<ByteBuffer>data).bytes()
if data is None:
self.c_byte_buffer = NULL
return
else:
- raise TypeError("expected value to be of type str, bytes, or "
- "ByteBuffer, not {}".format(type(data)))
+ data = str_to_bytes(data)
cdef char *c_data = data
cdef gpr_slice data_slice
@@ -302,19 +297,8 @@ cdef class ByteBuffer:
cdef class SslPemKeyCertPair:
def __cinit__(self, private_key, certificate_chain):
- if isinstance(private_key, bytes):
- self.private_key = private_key
- elif isinstance(private_key, basestring):
- self.private_key = private_key.encode()
- else:
- raise TypeError("expected private_key to be of type str or bytes")
- if isinstance(certificate_chain, bytes):
- self.certificate_chain = certificate_chain
- elif isinstance(certificate_chain, basestring):
- self.certificate_chain = certificate_chain.encode()
- else:
- raise TypeError("expected certificate_chain to be of type str or bytes "
- "or int")
+ self.private_key = str_to_bytes(private_key)
+ self.certificate_chain = str_to_bytes(certificate_chain)
self.c_pair.private_key = self.private_key
self.c_pair.certificate_chain = self.certificate_chain
@@ -322,27 +306,16 @@ cdef class SslPemKeyCertPair:
cdef class ChannelArg:
def __cinit__(self, key, value):
- if isinstance(key, bytes):
- self.key = key
- elif isinstance(key, basestring):
- self.key = key.encode()
- else:
- raise TypeError("expected key to be of type str or bytes")
- if isinstance(value, bytes):
- self.value = value
- self.c_arg.type = GRPC_ARG_STRING
- self.c_arg.value.string = self.value
- elif isinstance(value, basestring):
- self.value = value.encode()
- self.c_arg.type = GRPC_ARG_STRING
- self.c_arg.value.string = self.value
- elif isinstance(value, int):
+ self.key = str_to_bytes(key)
+ self.c_arg.key = self.key
+ if isinstance(value, int):
self.value = int(value)
self.c_arg.type = GRPC_ARG_INTEGER
self.c_arg.value.integer = self.value
else:
- raise TypeError("expected value to be of type str or bytes or int")
- self.c_arg.key = self.key
+ self.value = str_to_bytes(value)
+ self.c_arg.type = GRPC_ARG_STRING
+ self.c_arg.value.string = self.value
cdef class ChannelArgs:
@@ -375,18 +348,8 @@ cdef class ChannelArgs:
cdef class Metadatum:
def __cinit__(self, key, value):
- if isinstance(key, bytes):
- self._key = key
- elif isinstance(key, basestring):
- self._key = key.encode()
- else:
- raise TypeError("expected key to be of type str or bytes")
- if isinstance(value, bytes):
- self._value = value
- elif isinstance(value, basestring):
- self._value = value.encode()
- else:
- raise TypeError("expected value to be of type str or bytes")
+ self._key = str_to_bytes(key)
+ self._value = str_to_bytes(value)
self.c_metadata.key = self._key
self.c_metadata.value = self._value
self.c_metadata.value_length = len(self._value)
@@ -601,12 +564,7 @@ def operation_send_close_from_client(int flags):
def operation_send_status_from_server(
Metadata metadata, grpc_status_code code, details, int flags):
- if isinstance(details, bytes):
- pass
- elif isinstance(details, basestring):
- details = details.encode()
- else:
- raise TypeError("expected a str or bytes object for details")
+ details = str_to_bytes(details)
cdef Operation op = Operation()
op.c_op.type = GRPC_OP_SEND_STATUS_FROM_SERVER
op.c_op.flags = flags
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi
index 55948755b2..42afeb8498 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi
@@ -99,16 +99,11 @@ cdef class Server:
with nogil:
grpc_server_start(self.c_server)
# Ensure the core has gotten a chance to do the start-up work
- self.backup_shutdown_queue.pluck(None, Timespec(None))
+ self.backup_shutdown_queue.poll(Timespec(None))
def add_http2_port(self, address,
ServerCredentials server_credentials=None):
- if isinstance(address, bytes):
- pass
- elif isinstance(address, basestring):
- address = address.encode()
- else:
- raise TypeError("expected address to be a str or bytes")
+ address = str_to_bytes(address)
self.references.append(address)
cdef int result
cdef char *address_c_string = address
diff --git a/src/python/grpcio/grpc/_cython/cygrpc.pyx b/src/python/grpcio/grpc/_cython/cygrpc.pyx
index 8823ea5ef5..cf146f5a04 100644
--- a/src/python/grpcio/grpc/_cython/cygrpc.pyx
+++ b/src/python/grpcio/grpc/_cython/cygrpc.pyx
@@ -35,6 +35,7 @@ import sys
# TODO(atash): figure out why the coverage tool gets confused about the Cython
# coverage plugin when the following files don't have a '.pxi' suffix.
+include "grpc/_cython/_cygrpc/grpc_string.pyx.pxi"
include "grpc/_cython/_cygrpc/call.pyx.pxi"
include "grpc/_cython/_cygrpc/channel.pyx.pxi"
include "grpc/_cython/_cygrpc/credentials.pyx.pxi"
diff --git a/src/python/grpcio/grpc/_cython/imports.generated.c b/src/python/grpcio/grpc/_cython/imports.generated.c
index 5d09329680..8437e74ba0 100644
--- a/src/python/grpcio/grpc/_cython/imports.generated.c
+++ b/src/python/grpcio/grpc/_cython/imports.generated.c
@@ -126,7 +126,8 @@ grpc_header_key_is_legal_type grpc_header_key_is_legal_import;
grpc_header_nonbin_value_is_legal_type grpc_header_nonbin_value_is_legal_import;
grpc_is_binary_header_type grpc_is_binary_header_import;
grpc_call_error_to_string_type grpc_call_error_to_string_import;
-grpc_cronet_secure_channel_create_type grpc_cronet_secure_channel_create_import;
+grpc_insecure_channel_create_from_fd_type grpc_insecure_channel_create_from_fd_import;
+grpc_server_add_insecure_channel_from_fd_type grpc_server_add_insecure_channel_from_fd_import;
grpc_auth_property_iterator_next_type grpc_auth_property_iterator_next_import;
grpc_auth_context_property_iterator_type grpc_auth_context_property_iterator_import;
grpc_auth_context_peer_identity_type grpc_auth_context_peer_identity_import;
@@ -259,6 +260,8 @@ gpr_avl_unref_type gpr_avl_unref_import;
gpr_avl_add_type gpr_avl_add_import;
gpr_avl_remove_type gpr_avl_remove_import;
gpr_avl_get_type gpr_avl_get_import;
+gpr_avl_maybe_get_type gpr_avl_maybe_get_import;
+gpr_avl_is_empty_type gpr_avl_is_empty_import;
gpr_cmdline_create_type gpr_cmdline_create_import;
gpr_cmdline_add_int_type gpr_cmdline_add_int_import;
gpr_cmdline_add_flag_type gpr_cmdline_add_flag_import;
@@ -398,7 +401,8 @@ void pygrpc_load_imports(HMODULE library) {
grpc_header_nonbin_value_is_legal_import = (grpc_header_nonbin_value_is_legal_type) GetProcAddress(library, "grpc_header_nonbin_value_is_legal");
grpc_is_binary_header_import = (grpc_is_binary_header_type) GetProcAddress(library, "grpc_is_binary_header");
grpc_call_error_to_string_import = (grpc_call_error_to_string_type) GetProcAddress(library, "grpc_call_error_to_string");
- grpc_cronet_secure_channel_create_import = (grpc_cronet_secure_channel_create_type) GetProcAddress(library, "grpc_cronet_secure_channel_create");
+ grpc_insecure_channel_create_from_fd_import = (grpc_insecure_channel_create_from_fd_type) GetProcAddress(library, "grpc_insecure_channel_create_from_fd");
+ grpc_server_add_insecure_channel_from_fd_import = (grpc_server_add_insecure_channel_from_fd_type) GetProcAddress(library, "grpc_server_add_insecure_channel_from_fd");
grpc_auth_property_iterator_next_import = (grpc_auth_property_iterator_next_type) GetProcAddress(library, "grpc_auth_property_iterator_next");
grpc_auth_context_property_iterator_import = (grpc_auth_context_property_iterator_type) GetProcAddress(library, "grpc_auth_context_property_iterator");
grpc_auth_context_peer_identity_import = (grpc_auth_context_peer_identity_type) GetProcAddress(library, "grpc_auth_context_peer_identity");
@@ -531,6 +535,8 @@ void pygrpc_load_imports(HMODULE library) {
gpr_avl_add_import = (gpr_avl_add_type) GetProcAddress(library, "gpr_avl_add");
gpr_avl_remove_import = (gpr_avl_remove_type) GetProcAddress(library, "gpr_avl_remove");
gpr_avl_get_import = (gpr_avl_get_type) GetProcAddress(library, "gpr_avl_get");
+ gpr_avl_maybe_get_import = (gpr_avl_maybe_get_type) GetProcAddress(library, "gpr_avl_maybe_get");
+ gpr_avl_is_empty_import = (gpr_avl_is_empty_type) GetProcAddress(library, "gpr_avl_is_empty");
gpr_cmdline_create_import = (gpr_cmdline_create_type) GetProcAddress(library, "gpr_cmdline_create");
gpr_cmdline_add_int_import = (gpr_cmdline_add_int_type) GetProcAddress(library, "gpr_cmdline_add_int");
gpr_cmdline_add_flag_import = (gpr_cmdline_add_flag_type) GetProcAddress(library, "gpr_cmdline_add_flag");
diff --git a/src/python/grpcio/grpc/_cython/imports.generated.h b/src/python/grpcio/grpc/_cython/imports.generated.h
index cfadad0527..d52e8591b3 100644
--- a/src/python/grpcio/grpc/_cython/imports.generated.h
+++ b/src/python/grpcio/grpc/_cython/imports.generated.h
@@ -43,7 +43,7 @@
#include <grpc/census.h>
#include <grpc/compression.h>
#include <grpc/grpc.h>
-#include <grpc/grpc_cronet.h>
+#include <grpc/grpc_posix.h>
#include <grpc/grpc_security.h>
#include <grpc/impl/codegen/alloc.h>
#include <grpc/impl/codegen/byte_buffer.h>
@@ -329,9 +329,12 @@ extern grpc_is_binary_header_type grpc_is_binary_header_import;
typedef const char *(*grpc_call_error_to_string_type)(grpc_call_error error);
extern grpc_call_error_to_string_type grpc_call_error_to_string_import;
#define grpc_call_error_to_string grpc_call_error_to_string_import
-typedef grpc_channel *(*grpc_cronet_secure_channel_create_type)(void *engine, const char *target, const grpc_channel_args *args, void *reserved);
-extern grpc_cronet_secure_channel_create_type grpc_cronet_secure_channel_create_import;
-#define grpc_cronet_secure_channel_create grpc_cronet_secure_channel_create_import
+typedef grpc_channel *(*grpc_insecure_channel_create_from_fd_type)(const char *target, int fd, const grpc_channel_args *args);
+extern grpc_insecure_channel_create_from_fd_type grpc_insecure_channel_create_from_fd_import;
+#define grpc_insecure_channel_create_from_fd grpc_insecure_channel_create_from_fd_import
+typedef void(*grpc_server_add_insecure_channel_from_fd_type)(grpc_server *server, grpc_completion_queue *cq, int fd);
+extern grpc_server_add_insecure_channel_from_fd_type grpc_server_add_insecure_channel_from_fd_import;
+#define grpc_server_add_insecure_channel_from_fd grpc_server_add_insecure_channel_from_fd_import
typedef const grpc_auth_property *(*grpc_auth_property_iterator_next_type)(grpc_auth_property_iterator *it);
extern grpc_auth_property_iterator_next_type grpc_auth_property_iterator_next_import;
#define grpc_auth_property_iterator_next grpc_auth_property_iterator_next_import
@@ -479,7 +482,7 @@ extern grpc_byte_buffer_reader_readall_type grpc_byte_buffer_reader_readall_impo
typedef grpc_byte_buffer *(*grpc_raw_byte_buffer_from_reader_type)(grpc_byte_buffer_reader *reader);
extern grpc_raw_byte_buffer_from_reader_type grpc_raw_byte_buffer_from_reader_import;
#define grpc_raw_byte_buffer_from_reader grpc_raw_byte_buffer_from_reader_import
-typedef void(*gpr_log_type)(const char *file, int line, gpr_log_severity severity, const char *format, ...);
+typedef void(*gpr_log_type)(const char *file, int line, gpr_log_severity severity, const char *format, ...) GPRC_PRINT_FORMAT_CHECK(4, 5);
extern gpr_log_type gpr_log_import;
#define gpr_log gpr_log_import
typedef void(*gpr_log_message_type)(const char *file, int line, gpr_log_severity severity, const char *message);
@@ -728,6 +731,12 @@ extern gpr_avl_remove_type gpr_avl_remove_import;
typedef void *(*gpr_avl_get_type)(gpr_avl avl, void *key);
extern gpr_avl_get_type gpr_avl_get_import;
#define gpr_avl_get gpr_avl_get_import
+typedef int(*gpr_avl_maybe_get_type)(gpr_avl avl, void *key, void **value);
+extern gpr_avl_maybe_get_type gpr_avl_maybe_get_import;
+#define gpr_avl_maybe_get gpr_avl_maybe_get_import
+typedef int(*gpr_avl_is_empty_type)(gpr_avl avl);
+extern gpr_avl_is_empty_type gpr_avl_is_empty_import;
+#define gpr_avl_is_empty gpr_avl_is_empty_import
typedef gpr_cmdline *(*gpr_cmdline_create_type)(const char *description);
extern gpr_cmdline_create_type gpr_cmdline_create_import;
#define gpr_cmdline_create gpr_cmdline_create_import
@@ -818,7 +827,7 @@ extern gpr_format_message_type gpr_format_message_import;
typedef char *(*gpr_strdup_type)(const char *src);
extern gpr_strdup_type gpr_strdup_import;
#define gpr_strdup gpr_strdup_import
-typedef int(*gpr_asprintf_type)(char **strp, const char *format, ...);
+typedef int(*gpr_asprintf_type)(char **strp, const char *format, ...) GPRC_PRINT_FORMAT_CHECK(2, 3);
extern gpr_asprintf_type gpr_asprintf_import;
#define gpr_asprintf gpr_asprintf_import
typedef const char *(*gpr_subprocess_binary_extension_type)();
diff --git a/src/python/grpcio/grpc/_links/service.py b/src/python/grpcio/grpc/_links/service.py
index 11310e2240..5fc4994ca0 100644
--- a/src/python/grpcio/grpc/_links/service.py
+++ b/src/python/grpcio/grpc/_links/service.py
@@ -33,6 +33,7 @@ import abc
import enum
import logging
import threading
+import six
import time
from grpc._adapter import _intermediary_low
@@ -177,7 +178,10 @@ class _Kernel(object):
call = service_acceptance.call
call.accept(self._completion_queue, call)
try:
- group, method = service_acceptance.method.split(b'/')[1:3]
+ service_method = service_acceptance.method
+ if six.PY3:
+ service_method = service_method.decode('latin1')
+ group, method = service_method.split('/')[1:3]
except ValueError:
logging.info('Illegal path "%s"!', service_acceptance.method)
return
diff --git a/src/python/grpcio/grpc/_server.py b/src/python/grpcio/grpc/_server.py
index c65070f1b3..bf20f15f72 100644
--- a/src/python/grpcio/grpc/_server.py
+++ b/src/python/grpcio/grpc/_server.py
@@ -60,21 +60,34 @@ _CANCELLED = 'cancelled'
_EMPTY_FLAGS = 0
_EMPTY_METADATA = cygrpc.Metadata(())
+_UNEXPECTED_EXIT_SERVER_GRACE = 1.0
+
def _serialized_request(request_event):
return request_event.batch_operations[0].received_message.bytes()
-def _code(state):
+def _application_code(code):
+ cygrpc_code = _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE.get(code)
+ return cygrpc.StatusCode.unknown if cygrpc_code is None else cygrpc_code
+
+
+def _completion_code(state):
if state.code is None:
return cygrpc.StatusCode.ok
else:
- code = _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE.get(state.code)
- return cygrpc.StatusCode.unknown if code is None else code
+ return _application_code(state.code)
+
+
+def _abortion_code(state, code):
+ if state.code is None:
+ return code
+ else:
+ return _application_code(state.code)
def _details(state):
- return b'' if state.details is None else state.details
+ return '' if state.details is None else state.details
class _HandlerCallDetails(
@@ -126,20 +139,22 @@ def _send_status_from_server(state, token):
def _abort(state, call, code, details):
if state.client is not _CANCELLED:
+ effective_code = _abortion_code(state, code)
+ effective_details = details if state.details is None else state.details
if state.initial_metadata_allowed:
operations = (
cygrpc.operation_send_initial_metadata(
_EMPTY_METADATA, _EMPTY_FLAGS),
cygrpc.operation_send_status_from_server(
- _common.metadata(state.trailing_metadata), code, details,
- _EMPTY_FLAGS),
+ _common.metadata(state.trailing_metadata), effective_code,
+ effective_details, _EMPTY_FLAGS),
)
token = _SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN
else:
operations = (
cygrpc.operation_send_status_from_server(
- _common.metadata(state.trailing_metadata), code, details,
- _EMPTY_FLAGS),
+ _common.metadata(state.trailing_metadata), effective_code,
+ effective_details, _EMPTY_FLAGS),
)
token = _SEND_STATUS_FROM_SERVER_TOKEN
call.start_batch(
@@ -176,7 +191,7 @@ def _receive_message(state, call, request_deserializer):
if request is None:
_abort(
state, call, cygrpc.StatusCode.internal,
- b'Exception deserializing request!')
+ 'Exception deserializing request!')
else:
state.request = request
state.condition.notify_all()
@@ -241,7 +256,7 @@ class _Context(grpc.ServicerContext):
else:
if self._state.initial_metadata_allowed:
operation = cygrpc.operation_send_initial_metadata(
- cygrpc.Metadata(initial_metadata), _EMPTY_FLAGS)
+ _common.metadata(initial_metadata), _EMPTY_FLAGS)
self._rpc_event.operation_call.start_batch(
cygrpc.Operations((operation,)),
_send_initial_metadata(self._state))
@@ -327,12 +342,11 @@ def _unary_request(rpc_event, state, request_deserializer):
state.condition.wait()
if state.request is None:
if state.client is _CLOSED:
- details = b'"{}" requires exactly one request message.'.format(
+ details = '"{}" requires exactly one request message.'.format(
rpc_event.request_call_details.method)
- # TODO(5992#issuecomment-220761992): really, what status code?
_abort(
state, rpc_event.operation_call,
- cygrpc.StatusCode.unavailable, details)
+ cygrpc.StatusCode.unimplemented, details)
return None
elif state.client is _CANCELLED:
return None
@@ -346,15 +360,15 @@ def _unary_request(rpc_event, state, request_deserializer):
def _call_behavior(rpc_event, state, behavior, argument, request_deserializer):
context = _Context(rpc_event, state, request_deserializer)
try:
- return behavior(argument, context)
+ return behavior(argument, context), True
except Exception as e: # pylint: disable=broad-except
with state.condition:
if e not in state.rpc_errors:
- details = b'Exception calling application: {}'.format(e)
+ details = 'Exception calling application: {}'.format(e)
logging.exception(details)
_abort(
state, rpc_event.operation_call, cygrpc.StatusCode.unknown, details)
- return None
+ return None, False
def _take_response_from_response_iterator(rpc_event, state, response_iterator):
@@ -365,7 +379,7 @@ def _take_response_from_response_iterator(rpc_event, state, response_iterator):
except Exception as e: # pylint: disable=broad-except
with state.condition:
if e not in state.rpc_errors:
- details = b'Exception iterating responses: {}'.format(e)
+ details = 'Exception iterating responses: {}'.format(e)
logging.exception(details)
_abort(
state, rpc_event.operation_call, cygrpc.StatusCode.unknown, details)
@@ -378,7 +392,7 @@ def _serialize_response(rpc_event, state, response, response_serializer):
with state.condition:
_abort(
state, rpc_event.operation_call, cygrpc.StatusCode.internal,
- b'Failed to serialize response!')
+ 'Failed to serialize response!')
return None
else:
return serialized_response
@@ -415,7 +429,7 @@ def _status(rpc_event, state, serialized_response):
with state.condition:
if state.client is not _CANCELLED:
trailing_metadata = _common.metadata(state.trailing_metadata)
- code = _code(state)
+ code = _completion_code(state)
details = _details(state)
operations = [
cygrpc.operation_send_status_from_server(
@@ -440,9 +454,9 @@ def _unary_response_in_pool(
response_serializer):
argument = argument_thunk()
if argument is not None:
- response = _call_behavior(
+ response, proceed = _call_behavior(
rpc_event, state, behavior, argument, request_deserializer)
- if response is not None:
+ if proceed:
serialized_response = _serialize_response(
rpc_event, state, response, response_serializer)
if serialized_response is not None:
@@ -455,9 +469,9 @@ def _stream_response_in_pool(
response_serializer):
argument = argument_thunk()
if argument is not None:
- response_iterator = _call_behavior(
+ response_iterator, proceed = _call_behavior(
rpc_event, state, behavior, argument, request_deserializer)
- if response_iterator is not None:
+ if proceed:
while True:
response, proceed = _take_response_from_response_iterator(
rpc_event, state, response_iterator)
@@ -531,7 +545,7 @@ def _handle_unrecognized_method(rpc_event):
cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),
cygrpc.operation_send_status_from_server(
_EMPTY_METADATA, cygrpc.StatusCode.unimplemented,
- b'Method not found!', _EMPTY_FLAGS),
+ 'Method not found!', _EMPTY_FLAGS),
)
rpc_state = _RPCState()
rpc_event.operation_call.start_batch(
@@ -657,17 +671,6 @@ def _serve(state):
return
-def _start(state):
- with state.lock:
- if state.stage is not _ServerStage.STOPPED:
- raise ValueError('Cannot start already-started server!')
- state.server.start()
- state.stage = _ServerStage.STARTED
- _request_call(state)
- thread = threading.Thread(target=_serve, args=(state,))
- thread.start()
-
-
def _stop(state, grace):
with state.lock:
if state.stage is _ServerStage.STOPPED:
@@ -706,6 +709,24 @@ def _stop(state, grace):
return shutdown_event
+def _start(state):
+ with state.lock:
+ if state.stage is not _ServerStage.STOPPED:
+ raise ValueError('Cannot start already-started server!')
+ state.server.start()
+ state.stage = _ServerStage.STARTED
+ _request_call(state)
+ def cleanup_server(timeout):
+ if timeout is None:
+ _stop(state, _UNEXPECTED_EXIT_SERVER_GRACE).wait()
+ else:
+ _stop(state, timeout).wait()
+
+ thread = _common.CleanupThread(
+ cleanup_server, target=_serve, args=(state,))
+ thread.start()
+
+
class Server(grpc.Server):
def __init__(self, generic_handlers, thread_pool):
diff --git a/src/python/grpcio/grpc/_utilities.py b/src/python/grpcio/grpc/_utilities.py
index a4ca9b7282..4850967fbc 100644
--- a/src/python/grpcio/grpc/_utilities.py
+++ b/src/python/grpcio/grpc/_utilities.py
@@ -29,16 +29,41 @@
"""Internal utilities for gRPC Python."""
+import collections
import threading
import time
+import six
+
import grpc
+from grpc import _common
from grpc.framework.foundation import callable_util
_DONE_CALLBACK_EXCEPTION_LOG_MESSAGE = (
'Exception calling connectivity future "done" callback!')
+class RpcMethodHandler(
+ collections.namedtuple(
+ '_RpcMethodHandler',
+ ('request_streaming', 'response_streaming', 'request_deserializer',
+ 'response_serializer', 'unary_unary', 'unary_stream', 'stream_unary',
+ 'stream_stream',)),
+ grpc.RpcMethodHandler):
+ pass
+
+
+class DictionaryGenericHandler(grpc.GenericRpcHandler):
+
+ def __init__(self, service, method_handlers):
+ self._method_handlers = {
+ _common.fully_qualified_method(service, method): method_handler
+ for method, method_handler in six.iteritems(method_handlers)}
+
+ def service(self, handler_call_details):
+ return self._method_handlers.get(handler_call_details.method)
+
+
class _ChannelReadyFuture(grpc.Future):
def __init__(self, channel):
@@ -144,4 +169,3 @@ def channel_ready_future(channel):
ready_future = _ChannelReadyFuture(channel)
ready_future.start()
return ready_future
-
diff --git a/src/python/grpcio/grpc/beta/_client_adaptations.py b/src/python/grpcio/grpc/beta/_client_adaptations.py
new file mode 100644
index 0000000000..56456cc117
--- /dev/null
+++ b/src/python/grpcio/grpc/beta/_client_adaptations.py
@@ -0,0 +1,563 @@
+# Copyright 2016, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Translates gRPC's client-side API into gRPC's client-side Beta API."""
+
+import grpc
+from grpc import _common
+from grpc._cython import cygrpc
+from grpc.beta import interfaces
+from grpc.framework.common import cardinality
+from grpc.framework.foundation import future
+from grpc.framework.interfaces.face import face
+
+_STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS = {
+ grpc.StatusCode.CANCELLED: (
+ face.Abortion.Kind.CANCELLED, face.CancellationError),
+ grpc.StatusCode.UNKNOWN: (
+ face.Abortion.Kind.REMOTE_FAILURE, face.RemoteError),
+ grpc.StatusCode.DEADLINE_EXCEEDED: (
+ face.Abortion.Kind.EXPIRED, face.ExpirationError),
+ grpc.StatusCode.UNIMPLEMENTED: (
+ face.Abortion.Kind.LOCAL_FAILURE, face.LocalError),
+}
+
+
+def _effective_metadata(metadata, metadata_transformer):
+ non_none_metadata = () if metadata is None else metadata
+ if metadata_transformer is None:
+ return non_none_metadata
+ else:
+ return metadata_transformer(non_none_metadata)
+
+
+def _credentials(grpc_call_options):
+ return None if grpc_call_options is None else grpc_call_options.credentials
+
+
+def _abortion(rpc_error_call):
+ code = rpc_error_call.code()
+ pair = _STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS.get(code)
+ error_kind = face.Abortion.Kind.LOCAL_FAILURE if pair is None else pair[0]
+ return face.Abortion(
+ error_kind, rpc_error_call.initial_metadata(),
+ rpc_error_call.trailing_metadata(), code, rpc_error_code.details())
+
+
+def _abortion_error(rpc_error_call):
+ code = rpc_error_call.code()
+ pair = _STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS.get(code)
+ exception_class = face.AbortionError if pair is None else pair[1]
+ return exception_class(
+ rpc_error_call.initial_metadata(), rpc_error_call.trailing_metadata(),
+ code, rpc_error_call.details())
+
+
+class _InvocationProtocolContext(interfaces.GRPCInvocationContext):
+
+ def disable_next_request_compression(self):
+ pass # TODO(https://github.com/grpc/grpc/issues/4078): design, implement.
+
+
+class _Rendezvous(future.Future, face.Call):
+
+ def __init__(self, response_future, response_iterator, call):
+ self._future = response_future
+ self._iterator = response_iterator
+ self._call = call
+
+ def cancel(self):
+ return self._call.cancel()
+
+ def cancelled(self):
+ return self._future.cancelled()
+
+ def running(self):
+ return self._future.running()
+
+ def done(self):
+ return self._future.done()
+
+ def result(self, timeout=None):
+ try:
+ return self._future.result(timeout=timeout)
+ except grpc.RpcError as rpc_error_call:
+ raise _abortion_error(rpc_error_call)
+ except grpc.FutureTimeoutError:
+ raise future.TimeoutError()
+ except grpc.FutureCancelledError:
+ raise future.CancelledError()
+
+ def exception(self, timeout=None):
+ try:
+ rpc_error_call = self._future.exception(timeout=timeout)
+ return _abortion_error(rpc_error_call)
+ except grpc.FutureTimeoutError:
+ raise future.TimeoutError()
+ except grpc.FutureCancelledError:
+ raise future.CancelledError()
+
+ def traceback(self, timeout=None):
+ try:
+ return self._future.traceback(timeout=timeout)
+ except grpc.FutureTimeoutError:
+ raise future.TimeoutError()
+ except grpc.FutureCancelledError:
+ raise future.CancelledError()
+
+ def add_done_callback(self, fn):
+ self._future.add_done_callback(lambda ignored_callback: fn(self))
+
+ def __iter__(self):
+ return self
+
+ def _next(self):
+ try:
+ return next(self._iterator)
+ except grpc.RpcError as rpc_error_call:
+ raise _abortion_error(rpc_error_call)
+
+ def __next__(self):
+ return self._next()
+
+ def next(self):
+ return self._next()
+
+ def is_active(self):
+ return self._call.is_active()
+
+ def time_remaining(self):
+ return self._call.time_remaining()
+
+ def add_abortion_callback(self, abortion_callback):
+ registered = self._call.add_callback(
+ lambda: abortion_callback(_abortion(self._call)))
+ return None if registered else _abortion(self._call)
+
+ def protocol_context(self):
+ return _InvocationProtocolContext()
+
+ def initial_metadata(self):
+ return self._call.initial_metadata()
+
+ def terminal_metadata(self):
+ return self._call.terminal_metadata()
+
+ def code(self):
+ return self._call.code()
+
+ def details(self):
+ return self._call.details()
+
+
+def _blocking_unary_unary(
+ channel, group, method, timeout, with_call, protocol_options, metadata,
+ metadata_transformer, request, request_serializer, response_deserializer):
+ try:
+ multi_callable = channel.unary_unary(
+ _common.fully_qualified_method(group, method),
+ request_serializer=request_serializer,
+ response_deserializer=response_deserializer)
+ effective_metadata = _effective_metadata(metadata, metadata_transformer)
+ if with_call:
+ response, call = multi_callable.with_call(
+ request, timeout=timeout, metadata=effective_metadata,
+ credentials=_credentials(protocol_options))
+ return response, _Rendezvous(None, None, call)
+ else:
+ return multi_callable(
+ request, timeout=timeout, metadata=effective_metadata,
+ credentials=_credentials(protocol_options))
+ except grpc.RpcError as rpc_error_call:
+ raise _abortion_error(rpc_error_call)
+
+
+def _future_unary_unary(
+ channel, group, method, timeout, protocol_options, metadata,
+ metadata_transformer, request, request_serializer, response_deserializer):
+ multi_callable = channel.unary_unary(
+ _common.fully_qualified_method(group, method),
+ request_serializer=request_serializer,
+ response_deserializer=response_deserializer)
+ effective_metadata = _effective_metadata(metadata, metadata_transformer)
+ response_future = multi_callable.future(
+ request, timeout=timeout, metadata=effective_metadata,
+ credentials=_credentials(protocol_options))
+ return _Rendezvous(response_future, None, response_future)
+
+
+def _unary_stream(
+ channel, group, method, timeout, protocol_options, metadata,
+ metadata_transformer, request, request_serializer, response_deserializer):
+ multi_callable = channel.unary_stream(
+ _common.fully_qualified_method(group, method),
+ request_serializer=request_serializer,
+ response_deserializer=response_deserializer)
+ effective_metadata = _effective_metadata(metadata, metadata_transformer)
+ response_iterator = multi_callable(
+ request, timeout=timeout, metadata=effective_metadata,
+ credentials=_credentials(protocol_options))
+ return _Rendezvous(None, response_iterator, response_iterator)
+
+
+def _blocking_stream_unary(
+ channel, group, method, timeout, with_call, protocol_options, metadata,
+ metadata_transformer, request_iterator, request_serializer,
+ response_deserializer):
+ try:
+ multi_callable = channel.stream_unary(
+ _common.fully_qualified_method(group, method),
+ request_serializer=request_serializer,
+ response_deserializer=response_deserializer)
+ effective_metadata = _effective_metadata(metadata, metadata_transformer)
+ if with_call:
+ response, call = multi_callable.with_call(
+ request_iterator, timeout=timeout, metadata=effective_metadata,
+ credentials=_credentials(protocol_options))
+ return response, _Rendezvous(None, None, call)
+ else:
+ return multi_callable(
+ request_iterator, timeout=timeout, metadata=effective_metadata,
+ credentials=_credentials(protocol_options))
+ except grpc.RpcError as rpc_error_call:
+ raise _abortion_error(rpc_error_call)
+
+
+def _future_stream_unary(
+ channel, group, method, timeout, protocol_options, metadata,
+ metadata_transformer, request_iterator, request_serializer,
+ response_deserializer):
+ multi_callable = channel.stream_unary(
+ _common.fully_qualified_method(group, method),
+ request_serializer=request_serializer,
+ response_deserializer=response_deserializer)
+ effective_metadata = _effective_metadata(metadata, metadata_transformer)
+ response_future = multi_callable.future(
+ request_iterator, timeout=timeout, metadata=effective_metadata,
+ credentials=_credentials(protocol_options))
+ return _Rendezvous(response_future, None, response_future)
+
+
+def _stream_stream(
+ channel, group, method, timeout, protocol_options, metadata,
+ metadata_transformer, request_iterator, request_serializer,
+ response_deserializer):
+ multi_callable = channel.stream_stream(
+ _common.fully_qualified_method(group, method),
+ request_serializer=request_serializer,
+ response_deserializer=response_deserializer)
+ effective_metadata = _effective_metadata(metadata, metadata_transformer)
+ response_iterator = multi_callable(
+ request_iterator, timeout=timeout, metadata=effective_metadata,
+ credentials=_credentials(protocol_options))
+ return _Rendezvous(None, response_iterator, response_iterator)
+
+
+class _UnaryUnaryMultiCallable(face.UnaryUnaryMultiCallable):
+
+ def __init__(
+ self, channel, group, method, metadata_transformer, request_serializer,
+ response_deserializer):
+ self._channel = channel
+ self._group = group
+ self._method = method
+ self._metadata_transformer = metadata_transformer
+ self._request_serializer = request_serializer
+ self._response_deserializer = response_deserializer
+
+ def __call__(
+ self, request, timeout, metadata=None, with_call=False,
+ protocol_options=None):
+ return _blocking_unary_unary(
+ self._channel, self._group, self._method, timeout, with_call,
+ protocol_options, metadata, self._metadata_transformer, request,
+ self._request_serializer, self._response_deserializer)
+
+ def future(self, request, timeout, metadata=None, protocol_options=None):
+ return _future_unary_unary(
+ self._channel, self._group, self._method, timeout, protocol_options,
+ metadata, self._metadata_transformer, request, self._request_serializer,
+ self._response_deserializer)
+
+ def event(
+ self, request, receiver, abortion_callback, timeout,
+ metadata=None, protocol_options=None):
+ raise NotImplementedError()
+
+
+class _UnaryStreamMultiCallable(face.UnaryStreamMultiCallable):
+
+ def __init__(
+ self, channel, group, method, metadata_transformer, request_serializer,
+ response_deserializer):
+ self._channel = channel
+ self._group = group
+ self._method = method
+ self._metadata_transformer = metadata_transformer
+ self._request_serializer = request_serializer
+ self._response_deserializer = response_deserializer
+
+ def __call__(self, request, timeout, metadata=None, protocol_options=None):
+ return _unary_stream(
+ self._channel, self._group, self._method, timeout, protocol_options,
+ metadata, self._metadata_transformer, request, self._request_serializer,
+ self._response_deserializer)
+
+ def event(
+ self, request, receiver, abortion_callback, timeout,
+ metadata=None, protocol_options=None):
+ raise NotImplementedError()
+
+
+class _StreamUnaryMultiCallable(face.StreamUnaryMultiCallable):
+
+ def __init__(
+ self, channel, group, method, metadata_transformer, request_serializer,
+ response_deserializer):
+ self._channel = channel
+ self._group = group
+ self._method = method
+ self._metadata_transformer = metadata_transformer
+ self._request_serializer = request_serializer
+ self._response_deserializer = response_deserializer
+
+ def __call__(
+ self, request_iterator, timeout, metadata=None, with_call=False,
+ protocol_options=None):
+ return _blocking_stream_unary(
+ self._channel, self._group, self._method, timeout, with_call,
+ protocol_options, metadata, self._metadata_transformer,
+ request_iterator, self._request_serializer, self._response_deserializer)
+
+ def future(
+ self, request_iterator, timeout, metadata=None, protocol_options=None):
+ return _future_stream_unary(
+ self._channel, self._group, self._method, timeout, protocol_options,
+ metadata, self._metadata_transformer, request_iterator,
+ self._request_serializer, self._response_deserializer)
+
+ def event(
+ self, receiver, abortion_callback, timeout, metadata=None,
+ protocol_options=None):
+ raise NotImplementedError()
+
+
+class _StreamStreamMultiCallable(face.StreamStreamMultiCallable):
+
+ def __init__(
+ self, channel, group, method, metadata_transformer, request_serializer,
+ response_deserializer):
+ self._channel = channel
+ self._group = group
+ self._method = method
+ self._metadata_transformer = metadata_transformer
+ self._request_serializer = request_serializer
+ self._response_deserializer = response_deserializer
+
+ def __call__(
+ self, request_iterator, timeout, metadata=None, protocol_options=None):
+ return _stream_stream(
+ self._channel, self._group, self._method, timeout, protocol_options,
+ metadata, self._metadata_transformer, request_iterator,
+ self._request_serializer, self._response_deserializer)
+
+ def event(
+ self, receiver, abortion_callback, timeout, metadata=None,
+ protocol_options=None):
+ raise NotImplementedError()
+
+
+class _GenericStub(face.GenericStub):
+
+ def __init__(
+ self, channel, metadata_transformer, request_serializers,
+ response_deserializers):
+ self._channel = channel
+ self._metadata_transformer = metadata_transformer
+ self._request_serializers = request_serializers or {}
+ self._response_deserializers = response_deserializers or {}
+
+ def blocking_unary_unary(
+ self, group, method, request, timeout, metadata=None,
+ with_call=None, protocol_options=None):
+ request_serializer = self._request_serializers.get((group, method,))
+ response_deserializer = self._response_deserializers.get((group, method,))
+ return _blocking_unary_unary(
+ self._channel, group, method, timeout, with_call, protocol_options,
+ metadata, self._metadata_transformer, request, request_serializer,
+ response_deserializer)
+
+ def future_unary_unary(
+ self, group, method, request, timeout, metadata=None,
+ protocol_options=None):
+ request_serializer = self._request_serializers.get((group, method,))
+ response_deserializer = self._response_deserializers.get((group, method,))
+ return _future_unary_unary(
+ self._channel, group, method, timeout, protocol_options, metadata,
+ self._metadata_transformer, request, request_serializer,
+ response_deserializer)
+
+ def inline_unary_stream(
+ self, group, method, request, timeout, metadata=None,
+ protocol_options=None):
+ request_serializer = self._request_serializers.get((group, method,))
+ response_deserializer = self._response_deserializers.get((group, method,))
+ return _unary_stream(
+ self._channel, group, method, timeout, protocol_options, metadata,
+ self._metadata_transformer, request, request_serializer,
+ response_deserializer)
+
+ def blocking_stream_unary(
+ self, group, method, request_iterator, timeout, metadata=None,
+ with_call=None, protocol_options=None):
+ request_serializer = self._request_serializers.get((group, method,))
+ response_deserializer = self._response_deserializers.get((group, method,))
+ return _blocking_stream_unary(
+ self._channel, group, method, timeout, with_call, protocol_options,
+ metadata, self._metadata_transformer, request_iterator,
+ request_serializer, response_deserializer)
+
+ def future_stream_unary(
+ self, group, method, request_iterator, timeout, metadata=None,
+ protocol_options=None):
+ request_serializer = self._request_serializers.get((group, method,))
+ response_deserializer = self._response_deserializers.get((group, method,))
+ return _future_stream_unary(
+ self._channel, group, method, timeout, protocol_options, metadata,
+ self._metadata_transformer, request_iterator, request_serializer,
+ response_deserializer)
+
+ def inline_stream_stream(
+ self, group, method, request_iterator, timeout, metadata=None,
+ protocol_options=None):
+ request_serializer = self._request_serializers.get((group, method,))
+ response_deserializer = self._response_deserializers.get((group, method,))
+ return _stream_stream(
+ self._channel, group, method, timeout, protocol_options, metadata,
+ self._metadata_transformer, request_iterator, request_serializer,
+ response_deserializer)
+
+ def event_unary_unary(
+ self, group, method, request, receiver, abortion_callback, timeout,
+ metadata=None, protocol_options=None):
+ raise NotImplementedError()
+
+ def event_unary_stream(
+ self, group, method, request, receiver, abortion_callback, timeout,
+ metadata=None, protocol_options=None):
+ raise NotImplementedError()
+
+ def event_stream_unary(
+ self, group, method, receiver, abortion_callback, timeout,
+ metadata=None, protocol_options=None):
+ raise NotImplementedError()
+
+ def event_stream_stream(
+ self, group, method, receiver, abortion_callback, timeout,
+ metadata=None, protocol_options=None):
+ raise NotImplementedError()
+
+ def unary_unary(self, group, method):
+ request_serializer = self._request_serializers.get((group, method,))
+ response_deserializer = self._response_deserializers.get((group, method,))
+ return _UnaryUnaryMultiCallable(
+ self._channel, group, method, self._metadata_transformer,
+ request_serializer, response_deserializer)
+
+ def unary_stream(self, group, method):
+ request_serializer = self._request_serializers.get((group, method,))
+ response_deserializer = self._response_deserializers.get((group, method,))
+ return _UnaryStreamMultiCallable(
+ self._channel, group, method, self._metadata_transformer,
+ request_serializer, response_deserializer)
+
+ def stream_unary(self, group, method):
+ request_serializer = self._request_serializers.get((group, method,))
+ response_deserializer = self._response_deserializers.get((group, method,))
+ return _StreamUnaryMultiCallable(
+ self._channel, group, method, self._metadata_transformer,
+ request_serializer, response_deserializer)
+
+ def stream_stream(self, group, method):
+ request_serializer = self._request_serializers.get((group, method,))
+ response_deserializer = self._response_deserializers.get((group, method,))
+ return _StreamStreamMultiCallable(
+ self._channel, group, method, self._metadata_transformer,
+ request_serializer, response_deserializer)
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ return False
+
+
+class _DynamicStub(face.DynamicStub):
+
+ def __init__(self, generic_stub, group, cardinalities):
+ self._generic_stub = generic_stub
+ self._group = group
+ self._cardinalities = cardinalities
+
+ def __getattr__(self, attr):
+ method_cardinality = self._cardinalities.get(attr)
+ if method_cardinality is cardinality.Cardinality.UNARY_UNARY:
+ return self._generic_stub.unary_unary(self._group, attr)
+ elif method_cardinality is cardinality.Cardinality.UNARY_STREAM:
+ return self._generic_stub.unary_stream(self._group, attr)
+ elif method_cardinality is cardinality.Cardinality.STREAM_UNARY:
+ return self._generic_stub.stream_unary(self._group, attr)
+ elif method_cardinality is cardinality.Cardinality.STREAM_STREAM:
+ return self._generic_stub.stream_stream(self._group, attr)
+ else:
+ raise AttributeError('_DynamicStub object has no attribute "%s"!' % attr)
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ return False
+
+
+def generic_stub(
+ channel, host, metadata_transformer, request_serializers,
+ response_deserializers):
+ return _GenericStub(
+ channel, metadata_transformer, request_serializers,
+ response_deserializers)
+
+
+def dynamic_stub(
+ channel, service, cardinalities, host, metadata_transformer,
+ request_serializers, response_deserializers):
+ return _DynamicStub(
+ _GenericStub(
+ channel, metadata_transformer, request_serializers,
+ response_deserializers),
+ service, cardinalities)
diff --git a/src/python/grpcio/grpc/beta/_server_adaptations.py b/src/python/grpcio/grpc/beta/_server_adaptations.py
new file mode 100644
index 0000000000..79e6ca87eb
--- /dev/null
+++ b/src/python/grpcio/grpc/beta/_server_adaptations.py
@@ -0,0 +1,367 @@
+# Copyright 2016, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Translates gRPC's server-side API into gRPC's server-side Beta API."""
+
+import collections
+import threading
+
+import grpc
+from grpc import _common
+from grpc.beta import interfaces
+from grpc.framework.common import cardinality
+from grpc.framework.common import style
+from grpc.framework.foundation import abandonment
+from grpc.framework.foundation import logging_pool
+from grpc.framework.foundation import stream
+from grpc.framework.interfaces.face import face
+
+_DEFAULT_POOL_SIZE = 8
+
+
+class _ServerProtocolContext(interfaces.GRPCServicerContext):
+
+ def __init__(self, servicer_context):
+ self._servicer_context = servicer_context
+
+ def peer(self):
+ return self._servicer_context.peer()
+
+ def disable_next_response_compression(self):
+ pass # TODO(https://github.com/grpc/grpc/issues/4078): design, implement.
+
+
+class _FaceServicerContext(face.ServicerContext):
+
+ def __init__(self, servicer_context):
+ self._servicer_context = servicer_context
+
+ def is_active(self):
+ return self._servicer_context.is_active()
+
+ def time_remaining(self):
+ return self._servicer_context.time_remaining()
+
+ def add_abortion_callback(self, abortion_callback):
+ raise NotImplementedError(
+ 'add_abortion_callback no longer supported server-side!')
+
+ def cancel(self):
+ self._servicer_context.cancel()
+
+ def protocol_context(self):
+ return _ServerProtocolContext(self._servicer_context)
+
+ def invocation_metadata(self):
+ return self._servicer_context.invocation_metadata()
+
+ def initial_metadata(self, initial_metadata):
+ self._servicer_context.send_initial_metadata(initial_metadata)
+
+ def terminal_metadata(self, terminal_metadata):
+ self._servicer_context.set_terminal_metadata(terminal_metadata)
+
+ def code(self, code):
+ self._servicer_context.set_code(code)
+
+ def details(self, details):
+ self._servicer_context.set_details(details)
+
+
+def _adapt_unary_request_inline(unary_request_inline):
+ def adaptation(request, servicer_context):
+ return unary_request_inline(request, _FaceServicerContext(servicer_context))
+ return adaptation
+
+
+def _adapt_stream_request_inline(stream_request_inline):
+ def adaptation(request_iterator, servicer_context):
+ return stream_request_inline(
+ request_iterator, _FaceServicerContext(servicer_context))
+ return adaptation
+
+
+class _Callback(stream.Consumer):
+
+ def __init__(self):
+ self._condition = threading.Condition()
+ self._values = []
+ self._terminated = False
+ self._cancelled = False
+
+ def consume(self, value):
+ with self._condition:
+ self._values.append(value)
+ self._condition.notify_all()
+
+ def terminate(self):
+ with self._condition:
+ self._terminated = True
+ self._condition.notify_all()
+
+ def consume_and_terminate(self, value):
+ with self._condition:
+ self._values.append(value)
+ self._terminated = True
+ self._condition.notify_all()
+
+ def cancel(self):
+ with self._condition:
+ self._cancelled = True
+ self._condition.notify_all()
+
+ def draw_one_value(self):
+ with self._condition:
+ while True:
+ if self._cancelled:
+ raise abandonment.Abandoned()
+ elif self._values:
+ return self._values.pop(0)
+ elif self._terminated:
+ return None
+ else:
+ self._condition.wait()
+
+ def draw_all_values(self):
+ with self._condition:
+ while True:
+ if self._cancelled:
+ raise abandonment.Abandoned()
+ elif self._terminated:
+ all_values = tuple(self._values)
+ self._values = None
+ return all_values
+ else:
+ self._condition.wait()
+
+
+def _pipe_requests(request_iterator, request_consumer, servicer_context):
+ for request in request_iterator:
+ if not servicer_context.is_active():
+ return
+ request_consumer.consume(request)
+ if not servicer_context.is_active():
+ return
+ request_consumer.terminate()
+
+
+def _adapt_unary_unary_event(unary_unary_event):
+ def adaptation(request, servicer_context):
+ callback = _Callback()
+ if not servicer_context.add_callback(callback.cancel):
+ raise abandonment.Abandoned()
+ unary_unary_event(
+ request, callback.consume_and_terminate,
+ _FaceServicerContext(servicer_context))
+ return callback.draw_all_values()[0]
+ return adaptation
+
+
+def _adapt_unary_stream_event(unary_stream_event):
+ def adaptation(request, servicer_context):
+ callback = _Callback()
+ if not servicer_context.add_callback(callback.cancel):
+ raise abandonment.Abandoned()
+ unary_stream_event(
+ request, callback, _FaceServicerContext(servicer_context))
+ while True:
+ response = callback.draw_one_value()
+ if response is None:
+ return
+ else:
+ yield response
+ return adaptation
+
+
+def _adapt_stream_unary_event(stream_unary_event):
+ def adaptation(request_iterator, servicer_context):
+ callback = _Callback()
+ if not servicer_context.add_callback(callback.cancel):
+ raise abandonment.Abandoned()
+ request_consumer = stream_unary_event(
+ callback.consume_and_terminate, _FaceServicerContext(servicer_context))
+ request_pipe_thread = threading.Thread(
+ target=_pipe_requests,
+ args=(request_iterator, request_consumer, servicer_context,))
+ request_pipe_thread.start()
+ return callback.draw_all_values()[0]
+ return adaptation
+
+
+def _adapt_stream_stream_event(stream_stream_event):
+ def adaptation(request_iterator, servicer_context):
+ callback = _Callback()
+ if not servicer_context.add_callback(callback.cancel):
+ raise abandonment.Abandoned()
+ request_consumer = stream_stream_event(
+ callback, _FaceServicerContext(servicer_context))
+ request_pipe_thread = threading.Thread(
+ target=_pipe_requests,
+ args=(request_iterator, request_consumer, servicer_context,))
+ request_pipe_thread.start()
+ while True:
+ response = callback.draw_one_value()
+ if response is None:
+ return
+ else:
+ yield response
+ return adaptation
+
+
+class _SimpleMethodHandler(
+ collections.namedtuple(
+ '_MethodHandler',
+ ('request_streaming', 'response_streaming', 'request_deserializer',
+ 'response_serializer', 'unary_unary', 'unary_stream', 'stream_unary',
+ 'stream_stream',)),
+ grpc.RpcMethodHandler):
+ pass
+
+
+def _simple_method_handler(
+ implementation, request_deserializer, response_serializer):
+ if implementation.style is style.Service.INLINE:
+ if implementation.cardinality is cardinality.Cardinality.UNARY_UNARY:
+ return _SimpleMethodHandler(
+ False, False, request_deserializer, response_serializer,
+ _adapt_unary_request_inline(implementation.unary_unary_inline), None,
+ None, None)
+ elif implementation.cardinality is cardinality.Cardinality.UNARY_STREAM:
+ return _SimpleMethodHandler(
+ False, True, request_deserializer, response_serializer, None,
+ _adapt_unary_request_inline(implementation.unary_stream_inline), None,
+ None)
+ elif implementation.cardinality is cardinality.Cardinality.STREAM_UNARY:
+ return _SimpleMethodHandler(
+ True, False, request_deserializer, response_serializer, None, None,
+ _adapt_stream_request_inline(implementation.stream_unary_inline),
+ None)
+ elif implementation.cardinality is cardinality.Cardinality.STREAM_STREAM:
+ return _SimpleMethodHandler(
+ True, True, request_deserializer, response_serializer, None, None,
+ None,
+ _adapt_stream_request_inline(implementation.stream_stream_inline))
+ elif implementation.style is style.Service.EVENT:
+ if implementation.cardinality is cardinality.Cardinality.UNARY_UNARY:
+ return _SimpleMethodHandler(
+ False, False, request_deserializer, response_serializer,
+ _adapt_unary_unary_event(implementation.unary_unary_event), None,
+ None, None)
+ elif implementation.cardinality is cardinality.Cardinality.UNARY_STREAM:
+ return _SimpleMethodHandler(
+ False, True, request_deserializer, response_serializer, None,
+ _adapt_unary_stream_event(implementation.unary_stream_event), None,
+ None)
+ elif implementation.cardinality is cardinality.Cardinality.STREAM_UNARY:
+ return _SimpleMethodHandler(
+ True, False, request_deserializer, response_serializer, None, None,
+ _adapt_stream_unary_event(implementation.stream_unary_event), None)
+ elif implementation.cardinality is cardinality.Cardinality.STREAM_STREAM:
+ return _SimpleMethodHandler(
+ True, True, request_deserializer, response_serializer, None, None,
+ None, _adapt_stream_stream_event(implementation.stream_stream_event))
+
+
+def _flatten_method_pair_map(method_pair_map):
+ method_pair_map = method_pair_map or {}
+ flat_map = {}
+ for method_pair in method_pair_map:
+ method = _common.fully_qualified_method(method_pair[0], method_pair[1])
+ flat_map[method] = method_pair_map[method_pair]
+ return flat_map
+
+
+class _GenericRpcHandler(grpc.GenericRpcHandler):
+
+ def __init__(
+ self, method_implementations, multi_method_implementation,
+ request_deserializers, response_serializers):
+ self._method_implementations = _flatten_method_pair_map(
+ method_implementations)
+ self._request_deserializers = _flatten_method_pair_map(
+ request_deserializers)
+ self._response_serializers = _flatten_method_pair_map(
+ response_serializers)
+ self._multi_method_implementation = multi_method_implementation
+
+ def service(self, handler_call_details):
+ method_implementation = self._method_implementations.get(
+ handler_call_details.method)
+ if method_implementation is not None:
+ return _simple_method_handler(
+ method_implementation,
+ self._request_deserializers.get(handler_call_details.method),
+ self._response_serializers.get(handler_call_details.method))
+ elif self._multi_method_implementation is None:
+ return None
+ else:
+ try:
+ return None #TODO(nathaniel): call the multimethod.
+ except face.NoSuchMethodError:
+ return None
+
+
+class _Server(interfaces.Server):
+
+ def __init__(self, server):
+ self._server = server
+
+ def add_insecure_port(self, address):
+ return self._server.add_insecure_port(address)
+
+ def add_secure_port(self, address, server_credentials):
+ return self._server.add_secure_port(address, server_credentials)
+
+ def start(self):
+ self._server.start()
+
+ def stop(self, grace):
+ return self._server.stop(grace)
+
+ def __enter__(self):
+ self._server.start()
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self._server.stop(None)
+ return False
+
+
+def server(
+ service_implementations, multi_method_implementation, request_deserializers,
+ response_serializers, thread_pool, thread_pool_size):
+ generic_rpc_handler = _GenericRpcHandler(
+ service_implementations, multi_method_implementation,
+ request_deserializers, response_serializers)
+ if thread_pool is None:
+ effective_thread_pool = logging_pool.pool(
+ _DEFAULT_POOL_SIZE if thread_pool_size is None else thread_pool_size)
+ else:
+ effective_thread_pool = thread_pool
+ return _Server(grpc.server((generic_rpc_handler,), effective_thread_pool))
diff --git a/src/python/grpcio/grpc/beta/implementations.py b/src/python/grpcio/grpc/beta/implementations.py
index d8c32dd2f5..4ae6e7d675 100644
--- a/src/python/grpcio/grpc/beta/implementations.py
+++ b/src/python/grpcio/grpc/beta/implementations.py
@@ -35,83 +35,20 @@ import enum
import threading # pylint: disable=unused-import
# cardinality and face are referenced from specification in this module.
-from grpc._adapter import _intermediary_low
-from grpc._adapter import _low
+import grpc
+from grpc import _auth
from grpc._adapter import _types
-from grpc.beta import _auth
-from grpc.beta import _connectivity_channel
-from grpc.beta import _server
-from grpc.beta import _stub
+from grpc.beta import _client_adaptations
+from grpc.beta import _server_adaptations
from grpc.beta import interfaces
from grpc.framework.common import cardinality # pylint: disable=unused-import
from grpc.framework.interfaces.face import face # pylint: disable=unused-import
-_CHANNEL_SUBSCRIPTION_CALLBACK_ERROR_LOG_MESSAGE = (
- 'Exception calling channel subscription callback!')
-
-class ChannelCredentials(object):
- """A value encapsulating the data required to create a secure Channel.
-
- This class and its instances have no supported interface - it exists to define
- the type of its instances and its instances exist to be passed to other
- functions.
- """
-
- def __init__(self, low_credentials):
- self._low_credentials = low_credentials
-
-
-def ssl_channel_credentials(root_certificates=None, private_key=None,
- certificate_chain=None):
- """Creates a ChannelCredentials for use with an SSL-enabled Channel.
-
- Args:
- root_certificates: The PEM-encoded root certificates or unset to ask for
- them to be retrieved from a default location.
- private_key: The PEM-encoded private key to use or unset if no private key
- should be used.
- certificate_chain: The PEM-encoded certificate chain to use or unset if no
- certificate chain should be used.
-
- Returns:
- A ChannelCredentials for use with an SSL-enabled Channel.
- """
- return ChannelCredentials(_low.channel_credentials_ssl(
- root_certificates, private_key, certificate_chain))
-
-
-class CallCredentials(object):
- """A value encapsulating data asserting an identity over an *established*
- channel. May be composed with ChannelCredentials to always assert identity for
- every call over that channel.
-
- This class and its instances have no supported interface - it exists to define
- the type of its instances and its instances exist to be passed to other
- functions.
- """
-
- def __init__(self, low_credentials):
- self._low_credentials = low_credentials
-
-
-def metadata_call_credentials(metadata_plugin, name=None):
- """Construct CallCredentials from an interfaces.GRPCAuthMetadataPlugin.
-
- Args:
- metadata_plugin: An interfaces.GRPCAuthMetadataPlugin to use in constructing
- the CallCredentials object.
-
- Returns:
- A CallCredentials object for use in a GRPCCallOptions object.
- """
- if name is None:
- try:
- name = metadata_plugin.__name__
- except AttributeError:
- name = metadata_plugin.__class__.__name__
- return CallCredentials(
- _low.call_credentials_metadata_plugin(metadata_plugin, name))
+ChannelCredentials = grpc.ChannelCredentials
+ssl_channel_credentials = grpc.ssl_channel_credentials
+CallCredentials = grpc.CallCredentials
+metadata_call_credentials = grpc.metadata_call_credentials
def google_call_credentials(credentials):
@@ -125,53 +62,9 @@ def google_call_credentials(credentials):
"""
return metadata_call_credentials(_auth.GoogleCallCredentials(credentials))
-
-def access_token_call_credentials(access_token):
- """Construct CallCredentials from an access token.
-
- Args:
- access_token: A string to place directly in the http request
- authorization header, ie "Authorization: Bearer <access_token>".
-
- Returns:
- A CallCredentials object for use in a GRPCCallOptions object.
- """
- return metadata_call_credentials(
- _auth.AccessTokenCallCredentials(access_token))
-
-
-def composite_call_credentials(call_credentials, additional_call_credentials):
- """Compose two CallCredentials to make a new one.
-
- Args:
- call_credentials: A CallCredentials object.
- additional_call_credentials: Another CallCredentials object to compose on
- top of call_credentials.
-
- Returns:
- A CallCredentials object for use in a GRPCCallOptions object.
- """
- return CallCredentials(
- _low.call_credentials_composite(
- call_credentials._low_credentials,
- additional_call_credentials._low_credentials))
-
-def composite_channel_credentials(channel_credentials,
- additional_call_credentials):
- """Compose ChannelCredentials on top of client credentials to make a new one.
-
- Args:
- channel_credentials: A ChannelCredentials object.
- additional_call_credentials: A CallCredentials object to compose on
- top of channel_credentials.
-
- Returns:
- A ChannelCredentials object for use in a GRPCCallOptions object.
- """
- return ChannelCredentials(
- _low.channel_credentials_composite(
- channel_credentials._low_credentials,
- additional_call_credentials._low_credentials))
+access_token_call_credentials = grpc.access_token_call_credentials
+composite_call_credentials = grpc.composite_call_credentials
+composite_channel_credentials = grpc.composite_channel_credentials
class Channel(object):
@@ -182,11 +75,8 @@ class Channel(object):
unsupported.
"""
- def __init__(self, low_channel, intermediary_low_channel):
- self._low_channel = low_channel
- self._intermediary_low_channel = intermediary_low_channel
- self._connectivity_channel = _connectivity_channel.ConnectivityChannel(
- low_channel)
+ def __init__(self, channel):
+ self._channel = channel
def subscribe(self, callback, try_to_connect=None):
"""Subscribes to this Channel's connectivity.
@@ -201,7 +91,7 @@ class Channel(object):
attempt to connect if it is not already connected and ready to conduct
RPCs.
"""
- self._connectivity_channel.subscribe(callback, try_to_connect)
+ self._channel.subscribe(callback, try_to_connect=try_to_connect)
def unsubscribe(self, callback):
"""Unsubscribes a callback from this Channel's connectivity.
@@ -210,7 +100,7 @@ class Channel(object):
callback: A callable previously registered with this Channel from having
been passed to its "subscribe" method.
"""
- self._connectivity_channel.unsubscribe(callback)
+ self._channel.unsubscribe(callback)
def insecure_channel(host, port):
@@ -224,9 +114,9 @@ def insecure_channel(host, port):
Returns:
A Channel to the remote host through which RPCs may be conducted.
"""
- intermediary_low_channel = _intermediary_low.Channel(
- '%s:%d' % (host, port) if port else host, None)
- return Channel(intermediary_low_channel._internal, intermediary_low_channel) # pylint: disable=protected-access
+ channel = grpc.insecure_channel(
+ host if port is None else '%s:%d' % (host, port))
+ return Channel(channel)
def secure_channel(host, port, channel_credentials):
@@ -241,10 +131,9 @@ def secure_channel(host, port, channel_credentials):
Returns:
A secure Channel to the remote host through which RPCs may be conducted.
"""
- intermediary_low_channel = _intermediary_low.Channel(
- '%s:%d' % (host, port) if port else host,
- channel_credentials._low_credentials)
- return Channel(intermediary_low_channel._internal, intermediary_low_channel) # pylint: disable=protected-access
+ channel = grpc.secure_channel(
+ host if port is None else '%s:%d' % (host, port), channel_credentials)
+ return Channel(channel)
class StubOptions(object):
@@ -308,12 +197,11 @@ def generic_stub(channel, options=None):
A face.GenericStub on which RPCs can be made.
"""
effective_options = _EMPTY_STUB_OPTIONS if options is None else options
- return _stub.generic_stub(
- channel._intermediary_low_channel, effective_options.host, # pylint: disable=protected-access
- effective_options.metadata_transformer,
+ return _client_adaptations.generic_stub(
+ channel._channel, # pylint: disable=protected-access
+ effective_options.host, effective_options.metadata_transformer,
effective_options.request_serializers,
- effective_options.response_deserializers, effective_options.thread_pool,
- effective_options.thread_pool_size)
+ effective_options.response_deserializers)
def dynamic_stub(channel, service, cardinalities, options=None):
@@ -331,55 +219,16 @@ def dynamic_stub(channel, service, cardinalities, options=None):
A face.DynamicStub with which RPCs can be invoked.
"""
effective_options = StubOptions() if options is None else options
- return _stub.dynamic_stub(
- channel._intermediary_low_channel, effective_options.host, service, # pylint: disable=protected-access
- cardinalities, effective_options.metadata_transformer,
+ return _client_adaptations.dynamic_stub(
+ channel._channel, # pylint: disable=protected-access
+ service, cardinalities, effective_options.host,
+ effective_options.metadata_transformer,
effective_options.request_serializers,
- effective_options.response_deserializers, effective_options.thread_pool,
- effective_options.thread_pool_size)
-
-
-class ServerCredentials(object):
- """A value encapsulating the data required to open a secure port on a Server.
+ effective_options.response_deserializers)
- This class and its instances have no supported interface - it exists to define
- the type of its instances and its instances exist to be passed to other
- functions.
- """
- def __init__(self, low_credentials):
- self._low_credentials = low_credentials
-
-
-def ssl_server_credentials(
- private_key_certificate_chain_pairs, root_certificates=None,
- require_client_auth=False):
- """Creates a ServerCredentials for use with an SSL-enabled Server.
-
- Args:
- private_key_certificate_chain_pairs: A nonempty sequence each element of
- which is a pair the first element of which is a PEM-encoded private key
- and the second element of which is the corresponding PEM-encoded
- certificate chain.
- root_certificates: PEM-encoded client root certificates to be used for
- verifying authenticated clients. If omitted, require_client_auth must also
- be omitted or be False.
- require_client_auth: A boolean indicating whether or not to require clients
- to be authenticated. May only be True if root_certificates is not None.
-
- Returns:
- A ServerCredentials for use with an SSL-enabled Server.
- """
- if len(private_key_certificate_chain_pairs) == 0:
- raise ValueError(
- 'At least one private key-certificate chain pairis required!')
- elif require_client_auth and root_certificates is None:
- raise ValueError(
- 'Illegal to require client auth without providing root certificates!')
- else:
- return ServerCredentials(_low.server_credentials_ssl(
- root_certificates, private_key_certificate_chain_pairs,
- require_client_auth))
+ServerCredentials = grpc.ServerCredentials
+ssl_server_credentials = grpc.ssl_server_credentials
class ServerOptions(object):
@@ -452,9 +301,8 @@ def server(service_implementations, options=None):
An interfaces.Server with which RPCs can be serviced.
"""
effective_options = _EMPTY_SERVER_OPTIONS if options is None else options
- return _server.server(
+ return _server_adaptations.server(
service_implementations, effective_options.multi_method_implementation,
effective_options.request_deserializers,
effective_options.response_serializers, effective_options.thread_pool,
- effective_options.thread_pool_size, effective_options.default_timeout,
- effective_options.maximum_timeout)
+ effective_options.thread_pool_size)
diff --git a/src/python/grpcio/grpc/beta/interfaces.py b/src/python/grpcio/grpc/beta/interfaces.py
index 24de9ad1a8..90f6bbbfcc 100644
--- a/src/python/grpcio/grpc/beta/interfaces.py
+++ b/src/python/grpcio/grpc/beta/interfaces.py
@@ -30,53 +30,16 @@
"""Constants and interfaces of the Beta API of gRPC Python."""
import abc
-import enum
import six
-from grpc._adapter import _types
+import grpc
+ChannelConnectivity = grpc.ChannelConnectivity
+# FATAL_FAILURE was a Beta-API name for SHUTDOWN
+ChannelConnectivity.FATAL_FAILURE = ChannelConnectivity.SHUTDOWN
-@enum.unique
-class ChannelConnectivity(enum.Enum):
- """Mirrors grpc_connectivity_state in the gRPC Core.
-
- Attributes:
- IDLE: The channel is idle.
- CONNECTING: The channel is connecting.
- READY: The channel is ready to conduct RPCs.
- TRANSIENT_FAILURE: The channel has seen a failure from which it expects to
- recover.
- FATAL_FAILURE: The channel has seen a failure from which it cannot recover.
- """
- IDLE = (_types.ConnectivityState.IDLE, 'idle',)
- CONNECTING = (_types.ConnectivityState.CONNECTING, 'connecting',)
- READY = (_types.ConnectivityState.READY, 'ready',)
- TRANSIENT_FAILURE = (
- _types.ConnectivityState.TRANSIENT_FAILURE, 'transient failure',)
- FATAL_FAILURE = (_types.ConnectivityState.FATAL_FAILURE, 'fatal failure',)
-
-
-@enum.unique
-class StatusCode(enum.Enum):
- """Mirrors grpc_status_code in the C core."""
- OK = 0
- CANCELLED = 1
- UNKNOWN = 2
- INVALID_ARGUMENT = 3
- DEADLINE_EXCEEDED = 4
- NOT_FOUND = 5
- ALREADY_EXISTS = 6
- PERMISSION_DENIED = 7
- RESOURCE_EXHAUSTED = 8
- FAILED_PRECONDITION = 9
- ABORTED = 10
- OUT_OF_RANGE = 11
- UNIMPLEMENTED = 12
- INTERNAL = 13
- UNAVAILABLE = 14
- DATA_LOSS = 15
- UNAUTHENTICATED = 16
+StatusCode = grpc.StatusCode
class GRPCCallOptions(object):
@@ -106,46 +69,9 @@ def grpc_call_options(disable_compression=False, credentials=None):
"""
return GRPCCallOptions(disable_compression, None, credentials)
-
-class GRPCAuthMetadataContext(six.with_metaclass(abc.ABCMeta)):
- """Provides information to call credentials metadata plugins.
-
- Attributes:
- service_url: A string URL of the service being called into.
- method_name: A string of the fully qualified method name being called.
- """
-
-
-class GRPCAuthMetadataPluginCallback(six.with_metaclass(abc.ABCMeta)):
- """Callback object received by a metadata plugin."""
-
- def __call__(self, metadata, error):
- """Inform the gRPC runtime of the metadata to construct a CallCredentials.
-
- Args:
- metadata: An iterable of 2-sequences (e.g. tuples) of metadata key/value
- pairs.
- error: An Exception to indicate error or None to indicate success.
- """
- raise NotImplementedError()
-
-
-class GRPCAuthMetadataPlugin(six.with_metaclass(abc.ABCMeta)):
- """
- """
-
- def __call__(self, context, callback):
- """Invoke the plugin.
-
- Must not block. Need only be called by the gRPC runtime.
-
- Args:
- context: A GRPCAuthMetadataContext providing information on what the
- plugin is being used for.
- callback: A GRPCAuthMetadataPluginCallback to be invoked either
- synchronously or asynchronously.
- """
- raise NotImplementedError()
+GRPCAuthMetadataContext = grpc.AuthMetadataContext
+GRPCAuthMetadataPluginCallback = grpc.AuthMetadataPluginCallback
+GRPCAuthMetadataPlugin = grpc.AuthMetadataPlugin
class GRPCServicerContext(six.with_metaclass(abc.ABCMeta)):
diff --git a/src/python/grpcio/grpc_core_dependencies.py b/src/python/grpcio/grpc_core_dependencies.py
index 84c83af8a8..839c555f05 100644
--- a/src/python/grpcio/grpc_core_dependencies.py
+++ b/src/python/grpcio/grpc_core_dependencies.py
@@ -45,7 +45,6 @@ CORE_SOURCE_FILES = [
'src/core/lib/support/env_windows.c',
'src/core/lib/support/histogram.c',
'src/core/lib/support/host_port.c',
- 'src/core/lib/support/load_file.c',
'src/core/lib/support/log.c',
'src/core/lib/support/log_android.c',
'src/core/lib/support/log_linux.c',
@@ -84,7 +83,7 @@ CORE_SOURCE_FILES = [
'src/core/lib/channel/connected_channel.c',
'src/core/lib/channel/http_client_filter.c',
'src/core/lib/channel/http_server_filter.c',
- 'src/core/lib/compression/compression_algorithm.c',
+ 'src/core/lib/compression/compression.c',
'src/core/lib/compression/message_compress.c',
'src/core/lib/debug/trace.c',
'src/core/lib/http/format_request.c',
@@ -94,6 +93,7 @@ CORE_SOURCE_FILES = [
'src/core/lib/iomgr/endpoint.c',
'src/core/lib/iomgr/endpoint_pair_posix.c',
'src/core/lib/iomgr/endpoint_pair_windows.c',
+ 'src/core/lib/iomgr/error.c',
'src/core/lib/iomgr/ev_poll_and_epoll_posix.c',
'src/core/lib/iomgr/ev_poll_posix.c',
'src/core/lib/iomgr/ev_posix.c',
@@ -103,6 +103,8 @@ CORE_SOURCE_FILES = [
'src/core/lib/iomgr/iomgr.c',
'src/core/lib/iomgr/iomgr_posix.c',
'src/core/lib/iomgr/iomgr_windows.c',
+ 'src/core/lib/iomgr/load_file.c',
+ 'src/core/lib/iomgr/polling_entity.c',
'src/core/lib/iomgr/pollset_set_windows.c',
'src/core/lib/iomgr/pollset_windows.c',
'src/core/lib/iomgr/resolve_address_posix.c',
@@ -160,6 +162,7 @@ CORE_SOURCE_FILES = [
'src/core/lib/transport/transport.c',
'src/core/lib/transport/transport_op_string.c',
'src/core/ext/transport/chttp2/server/secure/server_secure_chttp2.c',
+ 'src/core/ext/transport/chttp2/transport/bin_decoder.c',
'src/core/ext/transport/chttp2/transport/bin_encoder.c',
'src/core/ext/transport/chttp2/transport/chttp2_plugin.c',
'src/core/ext/transport/chttp2/transport/chttp2_transport.c',
@@ -203,6 +206,7 @@ CORE_SOURCE_FILES = [
'src/core/lib/security/transport/secure_endpoint.c',
'src/core/lib/security/transport/security_connector.c',
'src/core/lib/security/transport/server_auth_filter.c',
+ 'src/core/lib/security/transport/tsi_error.c',
'src/core/lib/security/util/b64.c',
'src/core/lib/security/util/json_util.c',
'src/core/lib/surface/init_secure.c',
@@ -230,10 +234,9 @@ CORE_SOURCE_FILES = [
'src/core/ext/client_config/subchannel_index.c',
'src/core/ext/client_config/uri_parser.c',
'src/core/ext/transport/chttp2/server/insecure/server_chttp2.c',
+ 'src/core/ext/transport/chttp2/server/insecure/server_chttp2_posix.c',
'src/core/ext/transport/chttp2/client/insecure/channel_create.c',
- 'src/core/ext/transport/cronet/client/secure/cronet_channel_create.c',
- 'src/core/ext/transport/cronet/transport/cronet_api_dummy.c',
- 'src/core/ext/transport/cronet/transport/cronet_transport.c',
+ 'src/core/ext/transport/chttp2/client/insecure/channel_create_posix.c',
'src/core/ext/lb_policy/grpclb/load_balancer_api.c',
'src/core/ext/lb_policy/grpclb/proto/grpc/lb/v1/load_balancer.pb.c',
'third_party/nanopb/pb_common.c',
diff --git a/src/python/grpcio/grpc_version.py b/src/python/grpcio/grpc_version.py
index 0c13104d9d..0f4db9d972 100644
--- a/src/python/grpcio/grpc_version.py
+++ b/src/python/grpcio/grpc_version.py
@@ -29,4 +29,4 @@
# AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio/grpc_version.py.template`!!!
-VERSION='0.15.0.dev0'
+VERSION='0.16.0.dev0'
diff --git a/src/python/grpcio/tests/interop/client.py b/src/python/grpcio/tests/interop/client.py
index e3d5545a02..8aa1ce30c1 100644
--- a/src/python/grpcio/tests/interop/client.py
+++ b/src/python/grpcio/tests/interop/client.py
@@ -76,6 +76,9 @@ def _stub(args):
creds = oauth2client_client.GoogleCredentials.get_application_default()
scoped_creds = creds.create_scoped([args.oauth_scope])
call_creds = implementations.google_call_credentials(scoped_creds)
+ elif args.test_case == 'jwt_token_creds':
+ creds = oauth2client_client.GoogleCredentials.get_application_default()
+ call_creds = implementations.google_call_credentials(creds)
else:
call_creds = None
if args.use_tls:
diff --git a/src/python/grpcio/tests/interop/methods.py b/src/python/grpcio/tests/interop/methods.py
index d5ef0c68bb..7eac511525 100644
--- a/src/python/grpcio/tests/interop/methods.py
+++ b/src/python/grpcio/tests/interop/methods.py
@@ -310,6 +310,16 @@ def _oauth2_auth_token(stub, args):
(response.oauth_scope, args.oauth_scope))
+def _jwt_token_creds(stub, args):
+ json_key_filename = os.environ[
+ oauth2client_client.GOOGLE_APPLICATION_CREDENTIALS]
+ wanted_email = json.load(open(json_key_filename, 'rb'))['client_email']
+ response = _large_unary_common_behavior(stub, True, False)
+ if wanted_email != response.username:
+ raise ValueError(
+ 'expected username %s, got %s' % (wanted_email, response.username))
+
+
def _per_rpc_creds(stub, args):
json_key_filename = os.environ[
oauth2client_client.GOOGLE_APPLICATION_CREDENTIALS]
@@ -338,6 +348,7 @@ class TestCase(enum.Enum):
EMPTY_STREAM = 'empty_stream'
COMPUTE_ENGINE_CREDS = 'compute_engine_creds'
OAUTH2_AUTH_TOKEN = 'oauth2_auth_token'
+ JWT_TOKEN_CREDS = 'jwt_token_creds'
PER_RPC_CREDS = 'per_rpc_creds'
TIMEOUT_ON_SLEEPING_SERVER = 'timeout_on_sleeping_server'
@@ -364,6 +375,8 @@ class TestCase(enum.Enum):
_compute_engine_creds(stub, args)
elif self is TestCase.OAUTH2_AUTH_TOKEN:
_oauth2_auth_token(stub, args)
+ elif self is TestCase.JWT_TOKEN_CREDS:
+ _jwt_token_creds(stub, args)
elif self is TestCase.PER_RPC_CREDS:
_per_rpc_creds(stub, args)
else:
diff --git a/src/python/grpcio/tests/protoc_plugin/_python_plugin_test.py b/src/python/grpcio/tests/protoc_plugin/_python_plugin_test.py
new file mode 100644
index 0000000000..1c9cbb0d0c
--- /dev/null
+++ b/src/python/grpcio/tests/protoc_plugin/_python_plugin_test.py
@@ -0,0 +1,583 @@
+# Copyright 2015, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import collections
+from concurrent import futures
+import contextlib
+import distutils.spawn
+import errno
+import os
+import shutil
+import subprocess
+import sys
+import tempfile
+import threading
+import unittest
+
+from six import moves
+
+import grpc
+from tests.unit.framework.common import test_constants
+
+# Identifiers of entities we expect to find in the generated module.
+STUB_IDENTIFIER = 'TestServiceStub'
+SERVICER_IDENTIFIER = 'TestServiceServicer'
+ADD_SERVICER_TO_SERVER_IDENTIFIER = 'add_TestServiceServicer_to_server'
+
+
+class _ServicerMethods(object):
+
+ def __init__(self, response_pb2, payload_pb2):
+ self._condition = threading.Condition()
+ self._paused = False
+ self._fail = False
+ self._response_pb2 = response_pb2
+ self._payload_pb2 = payload_pb2
+
+ @contextlib.contextmanager
+ def pause(self): # pylint: disable=invalid-name
+ with self._condition:
+ self._paused = True
+ yield
+ with self._condition:
+ self._paused = False
+ self._condition.notify_all()
+
+ @contextlib.contextmanager
+ def fail(self): # pylint: disable=invalid-name
+ with self._condition:
+ self._fail = True
+ yield
+ with self._condition:
+ self._fail = False
+
+ def _control(self): # pylint: disable=invalid-name
+ with self._condition:
+ if self._fail:
+ raise ValueError()
+ while self._paused:
+ self._condition.wait()
+
+ def UnaryCall(self, request, unused_rpc_context):
+ response = self._response_pb2.SimpleResponse()
+ response.payload.payload_type = self._payload_pb2.COMPRESSABLE
+ response.payload.payload_compressable = 'a' * request.response_size
+ self._control()
+ return response
+
+ def StreamingOutputCall(self, request, unused_rpc_context):
+ for parameter in request.response_parameters:
+ response = self._response_pb2.StreamingOutputCallResponse()
+ response.payload.payload_type = self._payload_pb2.COMPRESSABLE
+ response.payload.payload_compressable = 'a' * parameter.size
+ self._control()
+ yield response
+
+ def StreamingInputCall(self, request_iter, unused_rpc_context):
+ response = self._response_pb2.StreamingInputCallResponse()
+ aggregated_payload_size = 0
+ for request in request_iter:
+ aggregated_payload_size += len(request.payload.payload_compressable)
+ response.aggregated_payload_size = aggregated_payload_size
+ self._control()
+ return response
+
+ def FullDuplexCall(self, request_iter, unused_rpc_context):
+ for request in request_iter:
+ for parameter in request.response_parameters:
+ response = self._response_pb2.StreamingOutputCallResponse()
+ response.payload.payload_type = self._payload_pb2.COMPRESSABLE
+ response.payload.payload_compressable = 'a' * parameter.size
+ self._control()
+ yield response
+
+ def HalfDuplexCall(self, request_iter, unused_rpc_context):
+ responses = []
+ for request in request_iter:
+ for parameter in request.response_parameters:
+ response = self._response_pb2.StreamingOutputCallResponse()
+ response.payload.payload_type = self._payload_pb2.COMPRESSABLE
+ response.payload.payload_compressable = 'a' * parameter.size
+ self._control()
+ responses.append(response)
+ for response in responses:
+ yield response
+
+
+class _Service(
+ collections.namedtuple(
+ '_Service', ('servicer_methods', 'server', 'stub',))):
+ """A live and running service.
+
+ Attributes:
+ servicer_methods: The _ServicerMethods servicing RPCs.
+ server: The grpc.Server servicing RPCs.
+ stub: A stub on which to invoke RPCs.
+ """
+
+
+def _CreateService(service_pb2, response_pb2, payload_pb2):
+ """Provides a servicer backend and a stub.
+
+ Args:
+ service_pb2: The service_pb2 module generated by this test.
+ response_pb2: The response_pb2 module generated by this test.
+ payload_pb2: The payload_pb2 module generated by this test.
+
+ Returns:
+ A _Service with which to test RPCs.
+ """
+ servicer_methods = _ServicerMethods(response_pb2, payload_pb2)
+
+ class Servicer(getattr(service_pb2, SERVICER_IDENTIFIER)):
+
+ def UnaryCall(self, request, context):
+ return servicer_methods.UnaryCall(request, context)
+
+ def StreamingOutputCall(self, request, context):
+ return servicer_methods.StreamingOutputCall(request, context)
+
+ def StreamingInputCall(self, request_iter, context):
+ return servicer_methods.StreamingInputCall(request_iter, context)
+
+ def FullDuplexCall(self, request_iter, context):
+ return servicer_methods.FullDuplexCall(request_iter, context)
+
+ def HalfDuplexCall(self, request_iter, context):
+ return servicer_methods.HalfDuplexCall(request_iter, context)
+
+ server = grpc.server(
+ (), futures.ThreadPoolExecutor(max_workers=test_constants.POOL_SIZE))
+ getattr(service_pb2, ADD_SERVICER_TO_SERVER_IDENTIFIER)(Servicer(), server)
+ port = server.add_insecure_port('[::]:0')
+ server.start()
+ channel = grpc.insecure_channel('localhost:{}'.format(port))
+ stub = getattr(service_pb2, STUB_IDENTIFIER)(channel)
+ return _Service(servicer_methods, server, stub)
+
+
+def _CreateIncompleteService(service_pb2):
+ """Provides a servicer backend that fails to implement methods and its stub.
+
+ Args:
+ service_pb2: The service_pb2 module generated by this test.
+
+ Returns:
+ A _Service with which to test RPCs. The returned _Service's
+ servicer_methods implements none of the methods required of it.
+ """
+
+ class Servicer(getattr(service_pb2, SERVICER_IDENTIFIER)):
+ pass
+
+ server = grpc.server(
+ (), futures.ThreadPoolExecutor(max_workers=test_constants.POOL_SIZE))
+ getattr(service_pb2, ADD_SERVICER_TO_SERVER_IDENTIFIER)(Servicer(), server)
+ port = server.add_insecure_port('[::]:0')
+ server.start()
+ channel = grpc.insecure_channel('localhost:{}'.format(port))
+ stub = getattr(service_pb2, STUB_IDENTIFIER)(channel)
+ return _Service(None, server, stub)
+
+
+def _streaming_input_request_iterator(request_pb2, payload_pb2):
+ for _ in range(3):
+ request = request_pb2.StreamingInputCallRequest()
+ request.payload.payload_type = payload_pb2.COMPRESSABLE
+ request.payload.payload_compressable = 'a'
+ yield request
+
+
+def _streaming_output_request(request_pb2):
+ request = request_pb2.StreamingOutputCallRequest()
+ sizes = [1, 2, 3]
+ request.response_parameters.add(size=sizes[0], interval_us=0)
+ request.response_parameters.add(size=sizes[1], interval_us=0)
+ request.response_parameters.add(size=sizes[2], interval_us=0)
+ return request
+
+
+def _full_duplex_request_iterator(request_pb2):
+ request = request_pb2.StreamingOutputCallRequest()
+ request.response_parameters.add(size=1, interval_us=0)
+ yield request
+ request = request_pb2.StreamingOutputCallRequest()
+ request.response_parameters.add(size=2, interval_us=0)
+ request.response_parameters.add(size=3, interval_us=0)
+ yield request
+
+
+class PythonPluginTest(unittest.TestCase):
+ """Test case for the gRPC Python protoc-plugin.
+
+ While reading these tests, remember that the futures API
+ (`stub.method.future()`) only gives futures for the *response-unary*
+ methods and does not exist for response-streaming methods.
+ """
+
+ def setUp(self):
+ # Assume that the appropriate protoc and grpc_python_plugins are on the
+ # path.
+ protoc_command = 'protoc'
+ protoc_plugin_filename = distutils.spawn.find_executable(
+ 'grpc_python_plugin')
+ if not os.path.isfile(protoc_command):
+ # Assume that if we haven't built protoc that it's on the system.
+ protoc_command = 'protoc'
+
+ # Ensure that the output directory exists.
+ self.outdir = tempfile.mkdtemp()
+
+ # Find all proto files
+ paths = []
+ root_dir = os.path.dirname(os.path.realpath(__file__))
+ proto_dir = os.path.join(root_dir, 'protos')
+ for walk_root, _, filenames in os.walk(proto_dir):
+ for filename in filenames:
+ if filename.endswith('.proto'):
+ path = os.path.join(walk_root, filename)
+ paths.append(path)
+
+ # Invoke protoc with the plugin.
+ cmd = [
+ protoc_command,
+ '--plugin=protoc-gen-python-grpc=%s' % protoc_plugin_filename,
+ '-I %s' % root_dir,
+ '--python_out=%s' % self.outdir,
+ '--python-grpc_out=%s' % self.outdir
+ ] + paths
+ subprocess.check_call(' '.join(cmd), shell=True, env=os.environ,
+ cwd=os.path.dirname(os.path.realpath(__file__)))
+
+ # Generated proto directories dont include __init__.py, but
+ # these are needed for python package resolution
+ for walk_root, _, _ in os.walk(os.path.join(self.outdir, 'protos')):
+ path = os.path.join(walk_root, '__init__.py')
+ open(path, 'a').close()
+
+ sys.path.insert(0, self.outdir)
+
+ import protos.payload.test_payload_pb2 as payload_pb2
+ import protos.requests.r.test_requests_pb2 as request_pb2
+ import protos.responses.test_responses_pb2 as response_pb2
+ import protos.service.test_service_pb2 as service_pb2
+ self._payload_pb2 = payload_pb2
+ self._request_pb2 = request_pb2
+ self._response_pb2 = response_pb2
+ self._service_pb2 = service_pb2
+
+ def tearDown(self):
+ try:
+ shutil.rmtree(self.outdir)
+ except OSError as exc:
+ if exc.errno != errno.ENOENT:
+ raise
+ sys.path.remove(self.outdir)
+
+ def testImportAttributes(self):
+ # check that we can access the generated module and its members.
+ self.assertIsNotNone(
+ getattr(self._service_pb2, STUB_IDENTIFIER, None))
+ self.assertIsNotNone(
+ getattr(self._service_pb2, SERVICER_IDENTIFIER, None))
+ self.assertIsNotNone(
+ getattr(self._service_pb2, ADD_SERVICER_TO_SERVER_IDENTIFIER, None))
+
+ def testUpDown(self):
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ self.assertIsNotNone(service.servicer_methods)
+ self.assertIsNotNone(service.server)
+ self.assertIsNotNone(service.stub)
+
+ def testIncompleteServicer(self):
+ service = _CreateIncompleteService(self._service_pb2)
+ request = self._request_pb2.SimpleRequest(response_size=13)
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ service.stub.UnaryCall(request)
+ self.assertIs(
+ exception_context.exception.code(), grpc.StatusCode.UNIMPLEMENTED)
+
+ def testUnaryCall(self):
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ request = self._request_pb2.SimpleRequest(response_size=13)
+ response = service.stub.UnaryCall(request)
+ expected_response = service.servicer_methods.UnaryCall(
+ request, 'not a real context!')
+ self.assertEqual(expected_response, response)
+
+ def testUnaryCallFuture(self):
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ request = self._request_pb2.SimpleRequest(response_size=13)
+ # Check that the call does not block waiting for the server to respond.
+ with service.servicer_methods.pause():
+ response_future = service.stub.UnaryCall.future(request)
+ response = response_future.result()
+ expected_response = service.servicer_methods.UnaryCall(
+ request, 'not a real RpcContext!')
+ self.assertEqual(expected_response, response)
+
+ def testUnaryCallFutureExpired(self):
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ request = self._request_pb2.SimpleRequest(response_size=13)
+ with service.servicer_methods.pause():
+ response_future = service.stub.UnaryCall.future(
+ request, timeout=test_constants.SHORT_TIMEOUT)
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ response_future.result()
+ self.assertIs(
+ exception_context.exception.code(), grpc.StatusCode.DEADLINE_EXCEEDED)
+ self.assertIs(response_future.code(), grpc.StatusCode.DEADLINE_EXCEEDED)
+
+ def testUnaryCallFutureCancelled(self):
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ request = self._request_pb2.SimpleRequest(response_size=13)
+ with service.servicer_methods.pause():
+ response_future = service.stub.UnaryCall.future(request)
+ response_future.cancel()
+ self.assertTrue(response_future.cancelled())
+ self.assertIs(response_future.code(), grpc.StatusCode.CANCELLED)
+
+ def testUnaryCallFutureFailed(self):
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ request = self._request_pb2.SimpleRequest(response_size=13)
+ with service.servicer_methods.fail():
+ response_future = service.stub.UnaryCall.future(request)
+ self.assertIsNotNone(response_future.exception())
+ self.assertIs(response_future.code(), grpc.StatusCode.UNKNOWN)
+
+ def testStreamingOutputCall(self):
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ request = _streaming_output_request(self._request_pb2)
+ responses = service.stub.StreamingOutputCall(request)
+ expected_responses = service.servicer_methods.StreamingOutputCall(
+ request, 'not a real RpcContext!')
+ for expected_response, response in moves.zip_longest(
+ expected_responses, responses):
+ self.assertEqual(expected_response, response)
+
+ def testStreamingOutputCallExpired(self):
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ request = _streaming_output_request(self._request_pb2)
+ with service.servicer_methods.pause():
+ responses = service.stub.StreamingOutputCall(
+ request, timeout=test_constants.SHORT_TIMEOUT)
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ list(responses)
+ self.assertIs(
+ exception_context.exception.code(), grpc.StatusCode.DEADLINE_EXCEEDED)
+
+ def testStreamingOutputCallCancelled(self):
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ request = _streaming_output_request(self._request_pb2)
+ responses = service.stub.StreamingOutputCall(request)
+ next(responses)
+ responses.cancel()
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ next(responses)
+ self.assertIs(responses.code(), grpc.StatusCode.CANCELLED)
+
+ def testStreamingOutputCallFailed(self):
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ request = _streaming_output_request(self._request_pb2)
+ with service.servicer_methods.fail():
+ responses = service.stub.StreamingOutputCall(request)
+ self.assertIsNotNone(responses)
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ next(responses)
+ self.assertIs(exception_context.exception.code(), grpc.StatusCode.UNKNOWN)
+
+ def testStreamingInputCall(self):
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ response = service.stub.StreamingInputCall(
+ _streaming_input_request_iterator(
+ self._request_pb2, self._payload_pb2))
+ expected_response = service.servicer_methods.StreamingInputCall(
+ _streaming_input_request_iterator(self._request_pb2, self._payload_pb2),
+ 'not a real RpcContext!')
+ self.assertEqual(expected_response, response)
+
+ def testStreamingInputCallFuture(self):
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ with service.servicer_methods.pause():
+ response_future = service.stub.StreamingInputCall.future(
+ _streaming_input_request_iterator(
+ self._request_pb2, self._payload_pb2))
+ response = response_future.result()
+ expected_response = service.servicer_methods.StreamingInputCall(
+ _streaming_input_request_iterator(self._request_pb2, self._payload_pb2),
+ 'not a real RpcContext!')
+ self.assertEqual(expected_response, response)
+
+ def testStreamingInputCallFutureExpired(self):
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ with service.servicer_methods.pause():
+ response_future = service.stub.StreamingInputCall.future(
+ _streaming_input_request_iterator(
+ self._request_pb2, self._payload_pb2),
+ timeout=test_constants.SHORT_TIMEOUT)
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ response_future.result()
+ self.assertIsInstance(response_future.exception(), grpc.RpcError)
+ self.assertIs(
+ response_future.exception().code(), grpc.StatusCode.DEADLINE_EXCEEDED)
+ self.assertIs(
+ exception_context.exception.code(), grpc.StatusCode.DEADLINE_EXCEEDED)
+
+ def testStreamingInputCallFutureCancelled(self):
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ with service.servicer_methods.pause():
+ response_future = service.stub.StreamingInputCall.future(
+ _streaming_input_request_iterator(
+ self._request_pb2, self._payload_pb2))
+ response_future.cancel()
+ self.assertTrue(response_future.cancelled())
+ with self.assertRaises(grpc.FutureCancelledError):
+ response_future.result()
+
+ def testStreamingInputCallFutureFailed(self):
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ with service.servicer_methods.fail():
+ response_future = service.stub.StreamingInputCall.future(
+ _streaming_input_request_iterator(
+ self._request_pb2, self._payload_pb2))
+ self.assertIsNotNone(response_future.exception())
+ self.assertIs(response_future.code(), grpc.StatusCode.UNKNOWN)
+
+ def testFullDuplexCall(self):
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ responses = service.stub.FullDuplexCall(
+ _full_duplex_request_iterator(self._request_pb2))
+ expected_responses = service.servicer_methods.FullDuplexCall(
+ _full_duplex_request_iterator(self._request_pb2),
+ 'not a real RpcContext!')
+ for expected_response, response in moves.zip_longest(
+ expected_responses, responses):
+ self.assertEqual(expected_response, response)
+
+ def testFullDuplexCallExpired(self):
+ request_iterator = _full_duplex_request_iterator(self._request_pb2)
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ with service.servicer_methods.pause():
+ responses = service.stub.FullDuplexCall(
+ request_iterator, timeout=test_constants.SHORT_TIMEOUT)
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ list(responses)
+ self.assertIs(
+ exception_context.exception.code(), grpc.StatusCode.DEADLINE_EXCEEDED)
+
+ def testFullDuplexCallCancelled(self):
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ request_iterator = _full_duplex_request_iterator(self._request_pb2)
+ responses = service.stub.FullDuplexCall(request_iterator)
+ next(responses)
+ responses.cancel()
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ next(responses)
+ self.assertIs(
+ exception_context.exception.code(), grpc.StatusCode.CANCELLED)
+
+ def testFullDuplexCallFailed(self):
+ request_iterator = _full_duplex_request_iterator(self._request_pb2)
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ with service.servicer_methods.fail():
+ responses = service.stub.FullDuplexCall(request_iterator)
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ next(responses)
+ self.assertIs(exception_context.exception.code(), grpc.StatusCode.UNKNOWN)
+
+ def testHalfDuplexCall(self):
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ def half_duplex_request_iterator():
+ request = self._request_pb2.StreamingOutputCallRequest()
+ request.response_parameters.add(size=1, interval_us=0)
+ yield request
+ request = self._request_pb2.StreamingOutputCallRequest()
+ request.response_parameters.add(size=2, interval_us=0)
+ request.response_parameters.add(size=3, interval_us=0)
+ yield request
+ responses = service.stub.HalfDuplexCall(half_duplex_request_iterator())
+ expected_responses = service.servicer_methods.HalfDuplexCall(
+ half_duplex_request_iterator(), 'not a real RpcContext!')
+ for expected_response, response in moves.zip_longest(
+ expected_responses, responses):
+ self.assertEqual(expected_response, response)
+
+ def testHalfDuplexCallWedged(self):
+ condition = threading.Condition()
+ wait_cell = [False]
+ @contextlib.contextmanager
+ def wait(): # pylint: disable=invalid-name
+ # Where's Python 3's 'nonlocal' statement when you need it?
+ with condition:
+ wait_cell[0] = True
+ yield
+ with condition:
+ wait_cell[0] = False
+ condition.notify_all()
+ def half_duplex_request_iterator():
+ request = self._request_pb2.StreamingOutputCallRequest()
+ request.response_parameters.add(size=1, interval_us=0)
+ yield request
+ with condition:
+ while wait_cell[0]:
+ condition.wait()
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ with wait():
+ responses = service.stub.HalfDuplexCall(
+ half_duplex_request_iterator(), timeout=test_constants.SHORT_TIMEOUT)
+ # half-duplex waits for the client to send all info
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ next(responses)
+ self.assertIs(
+ exception_context.exception.code(), grpc.StatusCode.DEADLINE_EXCEEDED)
+
+
+if __name__ == '__main__':
+ unittest.main(verbosity=2)
diff --git a/src/python/grpcio/tests/qps/benchmark_client.py b/src/python/grpcio/tests/qps/benchmark_client.py
index b372ea01ad..080281415d 100644
--- a/src/python/grpcio/tests/qps/benchmark_client.py
+++ b/src/python/grpcio/tests/qps/benchmark_client.py
@@ -30,14 +30,13 @@
"""Defines test client behaviors (UNARY/STREAMING) (SYNC/ASYNC)."""
import abc
+import threading
import time
-try:
- import Queue as queue # Python 2.x
-except ImportError:
- import queue # Python 3
from concurrent import futures
+from six.moves import queue
+import grpc
from grpc.beta import implementations
from grpc.framework.interfaces.face import face
from src.proto.grpc.testing import messages_pb2
@@ -65,6 +64,13 @@ class BenchmarkClient:
else:
channel = implementations.insecure_channel(host, port)
+ connected_event = threading.Event()
+ def wait_for_ready(connectivity):
+ if connectivity == grpc.ChannelConnectivity.READY:
+ connected_event.set()
+ channel.subscribe(wait_for_ready, try_to_connect=True)
+ connected_event.wait()
+
if config.payload_config.WhichOneof('payload') == 'simple_params':
self._generic = False
self._stub = services_pb2.beta_create_BenchmarkService_stub(channel)
@@ -82,6 +88,7 @@ class BenchmarkClient:
self._response_callbacks = []
def add_response_callback(self, callback):
+ """callback will be invoked as callback(client, query_time)"""
self._response_callbacks.append(callback)
@abc.abstractmethod
@@ -95,10 +102,10 @@ class BenchmarkClient:
def stop(self):
pass
- def _handle_response(self, query_time):
+ def _handle_response(self, client, query_time):
self._hist.add(query_time * 1e9) # Report times in nanoseconds
for callback in self._response_callbacks:
- callback(query_time)
+ callback(client, query_time)
class UnarySyncBenchmarkClient(BenchmarkClient):
@@ -121,7 +128,7 @@ class UnarySyncBenchmarkClient(BenchmarkClient):
start_time = time.time()
self._stub.UnaryCall(self._request, _TIMEOUT)
end_time = time.time()
- self._handle_response(end_time - start_time)
+ self._handle_response(self, end_time - start_time)
class UnaryAsyncBenchmarkClient(BenchmarkClient):
@@ -136,19 +143,20 @@ class UnaryAsyncBenchmarkClient(BenchmarkClient):
def _response_received(self, start_time, resp):
resp.result()
end_time = time.time()
- self._handle_response(end_time - start_time)
+ self._handle_response(self, end_time - start_time)
def stop(self):
self._stub = None
-class StreamingSyncBenchmarkClient(BenchmarkClient):
+class _SyncStream(object):
- def __init__(self, server, config, hist):
- super(StreamingSyncBenchmarkClient, self).__init__(server, config, hist)
+ def __init__(self, stub, generic, request, handle_response):
+ self._stub = stub
+ self._generic = generic
+ self._request = request
+ self._handle_response = handle_response
self._is_streaming = False
- self._pool = futures.ThreadPoolExecutor(max_workers=1)
- # Use a thread-safe queue to put requests on the stream
self._request_queue = queue.Queue()
self._send_time_queue = queue.Queue()
@@ -158,15 +166,6 @@ class StreamingSyncBenchmarkClient(BenchmarkClient):
def start(self):
self._is_streaming = True
- self._pool.submit(self._request_stream)
-
- def stop(self):
- self._is_streaming = False
- self._pool.shutdown(wait=True)
- self._stub = None
-
- def _request_stream(self):
- self._is_streaming = True
if self._generic:
stream_callable = self._stub.stream_stream(
'grpc.testing.BenchmarkService', 'StreamingCall')
@@ -175,8 +174,11 @@ class StreamingSyncBenchmarkClient(BenchmarkClient):
response_stream = stream_callable(self._request_generator(), _TIMEOUT)
for _ in response_stream:
- end_time = time.time()
- self._handle_response(end_time - self._send_time_queue.get_nowait())
+ self._handle_response(
+ self, time.time() - self._send_time_queue.get_nowait())
+
+ def stop(self):
+ self._is_streaming = False
def _request_generator(self):
while self._is_streaming:
@@ -187,46 +189,28 @@ class StreamingSyncBenchmarkClient(BenchmarkClient):
pass
-class AsyncReceiver(face.ResponseReceiver):
- """Receiver for async stream responses."""
-
- def __init__(self, send_time_queue, response_handler):
- self._send_time_queue = send_time_queue
- self._response_handler = response_handler
-
- def initial_metadata(self, initial_mdetadata):
- pass
-
- def response(self, response):
- end_time = time.time()
- self._response_handler(end_time - self._send_time_queue.get_nowait())
-
- def complete(self, terminal_metadata, code, details):
- pass
-
-
-class StreamingAsyncBenchmarkClient(BenchmarkClient):
+class StreamingSyncBenchmarkClient(BenchmarkClient):
def __init__(self, server, config, hist):
- super(StreamingAsyncBenchmarkClient, self).__init__(server, config, hist)
- self._send_time_queue = queue.Queue()
- self._receiver = AsyncReceiver(self._send_time_queue, self._handle_response)
- self._rendezvous = None
+ super(StreamingSyncBenchmarkClient, self).__init__(server, config, hist)
+ self._pool = futures.ThreadPoolExecutor(
+ max_workers=config.outstanding_rpcs_per_channel)
+ self._streams = [_SyncStream(self._stub, self._generic,
+ self._request, self._handle_response)
+ for _ in xrange(config.outstanding_rpcs_per_channel)]
+ self._curr_stream = 0
def send_request(self):
- if self._rendezvous is not None:
- self._send_time_queue.put(time.time())
- self._rendezvous.consume(self._request)
+ # Use a round_robin scheduler to determine what stream to send on
+ self._streams[self._curr_stream].send_request()
+ self._curr_stream = (self._curr_stream + 1) % len(self._streams)
def start(self):
- if self._generic:
- stream_callable = self._stub.stream_stream(
- 'grpc.testing.BenchmarkService', 'StreamingCall')
- else:
- stream_callable = self._stub.StreamingCall
- self._rendezvous = stream_callable.event(
- self._receiver, lambda *args: None, _TIMEOUT)
+ for stream in self._streams:
+ self._pool.submit(stream.start)
def stop(self):
- self._rendezvous.terminate()
- self._rendezvous = None
+ for stream in self._streams:
+ stream.stop()
+ self._pool.shutdown(wait=True)
+ self._stub = None
diff --git a/src/python/grpcio/tests/qps/client_runner.py b/src/python/grpcio/tests/qps/client_runner.py
index 1ede7d2af1..1fd58687ad 100644
--- a/src/python/grpcio/tests/qps/client_runner.py
+++ b/src/python/grpcio/tests/qps/client_runner.py
@@ -34,7 +34,7 @@ ClientRunner invokes either periodically or in response to some event.
"""
import abc
-import thread
+import threading
import time
@@ -61,15 +61,18 @@ class OpenLoopClientRunner(ClientRunner):
super(OpenLoopClientRunner, self).__init__(client)
self._is_running = False
self._interval_generator = interval_generator
+ self._dispatch_thread = threading.Thread(
+ target=self._dispatch_requests, args=())
def start(self):
self._is_running = True
self._client.start()
- thread.start_new_thread(self._dispatch_requests, ())
-
+ self._dispatch_thread.start()
+
def stop(self):
self._is_running = False
self._client.stop()
+ self._dispatch_thread.join()
self._client = None
def _dispatch_requests(self):
@@ -98,7 +101,6 @@ class ClosedLoopClientRunner(ClientRunner):
self._client.stop()
self._client = None
- def _send_request(self, response_time):
+ def _send_request(self, client, response_time):
if self._is_running:
- self._client.send_request()
-
+ client.send_request()
diff --git a/src/python/grpcio/tests/qps/qps_worker.py b/src/python/grpcio/tests/qps/qps_worker.py
index 3dda718638..16926379a5 100644
--- a/src/python/grpcio/tests/qps/qps_worker.py
+++ b/src/python/grpcio/tests/qps/qps_worker.py
@@ -43,9 +43,7 @@ def run_worker_server(port):
server.add_insecure_port('[::]:{}'.format(port))
server.start()
servicer.wait_for_quit()
- # Drain outstanding requests for clean exit
- time.sleep(2)
- server.stop(0)
+ server.stop(2)
if __name__ == '__main__':
diff --git a/src/python/grpcio/tests/qps/worker_server.py b/src/python/grpcio/tests/qps/worker_server.py
index 1f9af5482c..d41f8377c2 100644
--- a/src/python/grpcio/tests/qps/worker_server.py
+++ b/src/python/grpcio/tests/qps/worker_server.py
@@ -153,9 +153,8 @@ class WorkerServer(services_pb2.BetaWorkerServiceServicer):
if config.rpc_type == control_pb2.UNARY:
client = benchmark_client.UnaryAsyncBenchmarkClient(
server, config, qps_data)
- elif config.rpc_type == control_pb2.STREAMING:
- client = benchmark_client.StreamingAsyncBenchmarkClient(
- server, config, qps_data)
+ else:
+ raise Exception('Async streaming client not supported')
else:
raise Exception('Unsupported client type {}'.format(config.client_type))
diff --git a/src/python/grpcio/tests/stress/client.py b/src/python/grpcio/tests/stress/client.py
index e2e016760c..0de2532cd8 100644
--- a/src/python/grpcio/tests/stress/client.py
+++ b/src/python/grpcio/tests/stress/client.py
@@ -30,10 +30,10 @@
"""Entry point for running stress tests."""
import argparse
-import Queue
import threading
from grpc.beta import implementations
+from six.moves import queue
from src.proto.grpc.testing import metrics_pb2
from src.proto.grpc.testing import test_pb2
@@ -94,7 +94,7 @@ def run_test(args):
test_cases = _parse_weighted_test_cases(args.test_cases)
test_servers = args.server_addresses.split(',')
# Propagate any client exceptions with a queue
- exception_queue = Queue.Queue()
+ exception_queue = queue.Queue()
stop_event = threading.Event()
hist = histogram.Histogram(1, 1)
runners = []
@@ -121,7 +121,7 @@ def run_test(args):
if timeout_secs < 0:
timeout_secs = None
raise exception_queue.get(block=True, timeout=timeout_secs)
- except Queue.Empty:
+ except queue.Empty:
# No exceptions thrown, success
pass
finally:
diff --git a/src/python/grpcio/tests/tests.json b/src/python/grpcio/tests/tests.json
index 81458b11da..e384a2fc13 100644
--- a/src/python/grpcio/tests/tests.json
+++ b/src/python/grpcio/tests/tests.json
@@ -1,10 +1,9 @@
[
+ "_api_test.AllTest",
+ "_api_test.ChannelConnectivityTest",
+ "_api_test.ChannelTest",
"_auth_test.AccessTokenCallCredentialsTest",
"_auth_test.GoogleCallCredentialsTest",
- "_base_interface_test.AsyncEasyTest",
- "_base_interface_test.AsyncPeasyTest",
- "_base_interface_test.SyncEasyTest",
- "_base_interface_test.SyncPeasyTest",
"_beta_features_test.BetaFeaturesTest",
"_beta_features_test.ContextManagementAndLifecycleTest",
"_cancel_many_calls_test.CancelManyCallsTest",
@@ -12,22 +11,8 @@
"_channel_ready_future_test.ChannelReadyFutureTest",
"_channel_test.ChannelTest",
"_connectivity_channel_test.ChannelConnectivityTest",
- "_core_over_links_base_interface_test.AsyncEasyTest",
- "_core_over_links_base_interface_test.AsyncPeasyTest",
- "_core_over_links_base_interface_test.SyncEasyTest",
- "_core_over_links_base_interface_test.SyncPeasyTest",
- "_crust_over_core_face_interface_test.DynamicInvokerBlockingInvocationInlineServiceTest",
- "_crust_over_core_face_interface_test.DynamicInvokerFutureInvocationAsynchronousEventServiceTest",
- "_crust_over_core_face_interface_test.GenericInvokerBlockingInvocationInlineServiceTest",
- "_crust_over_core_face_interface_test.GenericInvokerFutureInvocationAsynchronousEventServiceTest",
- "_crust_over_core_face_interface_test.MultiCallableInvokerBlockingInvocationInlineServiceTest",
- "_crust_over_core_face_interface_test.MultiCallableInvokerFutureInvocationAsynchronousEventServiceTest",
- "_crust_over_core_over_links_face_interface_test.DynamicInvokerBlockingInvocationInlineServiceTest",
- "_crust_over_core_over_links_face_interface_test.DynamicInvokerFutureInvocationAsynchronousEventServiceTest",
- "_crust_over_core_over_links_face_interface_test.GenericInvokerBlockingInvocationInlineServiceTest",
- "_crust_over_core_over_links_face_interface_test.GenericInvokerFutureInvocationAsynchronousEventServiceTest",
- "_crust_over_core_over_links_face_interface_test.MultiCallableInvokerBlockingInvocationInlineServiceTest",
- "_crust_over_core_over_links_face_interface_test.MultiCallableInvokerFutureInvocationAsynchronousEventServiceTest",
+ "_connectivity_channel_test.ConnectivityStatesTest",
+ "_empty_message_test.EmptyMessageTest",
"_face_interface_test.DynamicInvokerBlockingInvocationInlineServiceTest",
"_face_interface_test.DynamicInvokerFutureInvocationAsynchronousEventServiceTest",
"_face_interface_test.GenericInvokerBlockingInvocationInlineServiceTest",
@@ -38,21 +23,16 @@
"_implementations_test.CallCredentialsTest",
"_implementations_test.ChannelCredentialsTest",
"_insecure_interop_test.InsecureInteropTest",
- "_intermediary_low_test.CancellationTest",
- "_intermediary_low_test.EchoTest",
- "_intermediary_low_test.ExpirationTest",
- "_intermediary_low_test.LonelyClientTest",
- "_later_test.LaterTest",
"_logging_pool_test.LoggingPoolTest",
- "_lonely_invocation_link_test.LonelyInvocationLinkTest",
- "_low_test.HangingServerShutdown",
- "_low_test.InsecureServerInsecureClient",
+ "_metadata_code_details_test.MetadataCodeDetailsTest",
+ "_metadata_test.MetadataTest",
"_not_found_test.NotFoundTest",
+ "_python_plugin_test.PythonPluginTest",
+ "_read_some_but_not_all_responses_test.ReadSomeButNotAllResponsesTest",
"_rpc_test.RPCTest",
"_sanity_test.Sanity",
"_secure_interop_test.SecureInteropTest",
- "_transmission_test.RoundTripTest",
- "_transmission_test.TransmissionTest",
+ "_thread_cleanup_test.CleanupThreadTest",
"_utilities_test.ChannelConnectivityTest",
"beta_python_plugin_test.PythonPluginTest",
"cygrpc_test.InsecureServerInsecureClient",
diff --git a/src/python/grpcio/tests/unit/_adapter/_intermediary_low_test.py b/src/python/grpcio/tests/unit/_adapter/_intermediary_low_test.py
deleted file mode 100644
index 22d4b019c7..0000000000
--- a/src/python/grpcio/tests/unit/_adapter/_intermediary_low_test.py
+++ /dev/null
@@ -1,428 +0,0 @@
-# Copyright 2015, Google Inc.
-# All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Tests for the old '_low'."""
-
-import threading
-import time
-import unittest
-
-import six
-from six.moves import queue
-
-from grpc._adapter import _intermediary_low as _low
-
-_STREAM_LENGTH = 300
-_TIMEOUT = 5
-_AFTER_DELAY = 2
-_FUTURE = time.time() + 60 * 60 * 24
-_BYTE_SEQUENCE = b'\abcdefghijklmnopqrstuvwxyz0123456789' * 200
-_BYTE_SEQUENCE_SEQUENCE = tuple(
- bytes(bytearray((row + column) % 256 for column in range(row)))
- for row in range(_STREAM_LENGTH))
-
-
-class LonelyClientTest(unittest.TestCase):
-
- def testLonelyClient(self):
- host = 'nosuchhostexists'
- port = 54321
- method = 'test method'
- deadline = time.time() + _TIMEOUT
- after_deadline = deadline + _AFTER_DELAY
- metadata_tag = object()
- finish_tag = object()
-
- completion_queue = _low.CompletionQueue()
- channel = _low.Channel('%s:%d' % (host, port), None)
- client_call = _low.Call(channel, completion_queue, method, host, deadline)
-
- client_call.invoke(completion_queue, metadata_tag, finish_tag)
- first_event = completion_queue.get(after_deadline)
- self.assertIsNotNone(first_event)
- second_event = completion_queue.get(after_deadline)
- self.assertIsNotNone(second_event)
- kinds = [event.kind for event in (first_event, second_event)]
- six.assertCountEqual(self,
- (_low.Event.Kind.METADATA_ACCEPTED, _low.Event.Kind.FINISH),
- kinds)
-
- self.assertIsNone(completion_queue.get(after_deadline))
-
- completion_queue.stop()
- stop_event = completion_queue.get(_FUTURE)
- self.assertEqual(_low.Event.Kind.STOP, stop_event.kind)
-
- del client_call
- del channel
- del completion_queue
-
-
-def _drive_completion_queue(completion_queue, event_queue):
- while True:
- event = completion_queue.get(_FUTURE)
- if event.kind is _low.Event.Kind.STOP:
- break
- event_queue.put(event)
-
-
-class EchoTest(unittest.TestCase):
-
- def setUp(self):
- self.host = 'localhost'
-
- self.server_completion_queue = _low.CompletionQueue()
- self.server = _low.Server(self.server_completion_queue)
- port = self.server.add_http2_addr('[::]:0')
- self.server.start()
- self.server_events = queue.Queue()
- self.server_completion_queue_thread = threading.Thread(
- target=_drive_completion_queue,
- args=(self.server_completion_queue, self.server_events))
- self.server_completion_queue_thread.start()
-
- self.client_completion_queue = _low.CompletionQueue()
- self.channel = _low.Channel('%s:%d' % (self.host, port), None)
- self.client_events = queue.Queue()
- self.client_completion_queue_thread = threading.Thread(
- target=_drive_completion_queue,
- args=(self.client_completion_queue, self.client_events))
- self.client_completion_queue_thread.start()
-
- def tearDown(self):
- self.server.stop()
- self.server.cancel_all_calls()
- self.server_completion_queue.stop()
- self.client_completion_queue.stop()
- self.server_completion_queue_thread.join()
- self.client_completion_queue_thread.join()
- del self.server
-
- def _perform_echo_test(self, test_data):
- method = 'test method'
- details = 'test details'
- server_leading_metadata_key = 'my_server_leading_key'
- server_leading_metadata_value = 'my_server_leading_value'
- server_trailing_metadata_key = 'my_server_trailing_key'
- server_trailing_metadata_value = 'my_server_trailing_value'
- client_metadata_key = 'my_client_key'
- client_metadata_value = 'my_client_value'
- server_leading_binary_metadata_key = 'my_server_leading_key-bin'
- server_leading_binary_metadata_value = b'\0'*2047
- server_trailing_binary_metadata_key = 'my_server_trailing_key-bin'
- server_trailing_binary_metadata_value = b'\0'*2047
- client_binary_metadata_key = 'my_client_key-bin'
- client_binary_metadata_value = b'\0'*2047
- deadline = _FUTURE
- metadata_tag = object()
- finish_tag = object()
- write_tag = object()
- complete_tag = object()
- service_tag = object()
- read_tag = object()
- status_tag = object()
-
- server_data = []
- client_data = []
-
- client_call = _low.Call(self.channel, self.client_completion_queue,
- method, self.host, deadline)
- client_call.add_metadata(client_metadata_key, client_metadata_value)
- client_call.add_metadata(client_binary_metadata_key,
- client_binary_metadata_value)
-
- client_call.invoke(self.client_completion_queue, metadata_tag, finish_tag)
-
- self.server.service(service_tag)
- service_accepted = self.server_events.get()
- self.assertIsNotNone(service_accepted)
- self.assertIs(service_accepted.kind, _low.Event.Kind.SERVICE_ACCEPTED)
- self.assertIs(service_accepted.tag, service_tag)
- self.assertEqual(method, service_accepted.service_acceptance.method)
- self.assertEqual(self.host, service_accepted.service_acceptance.host)
- self.assertIsNotNone(service_accepted.service_acceptance.call)
- metadata = dict(service_accepted.metadata)
- self.assertIn(client_metadata_key, metadata)
- self.assertEqual(client_metadata_value, metadata[client_metadata_key])
- self.assertIn(client_binary_metadata_key, metadata)
- self.assertEqual(client_binary_metadata_value,
- metadata[client_binary_metadata_key])
- server_call = service_accepted.service_acceptance.call
- server_call.accept(self.server_completion_queue, finish_tag)
- server_call.add_metadata(server_leading_metadata_key,
- server_leading_metadata_value)
- server_call.add_metadata(server_leading_binary_metadata_key,
- server_leading_binary_metadata_value)
- server_call.premetadata()
-
- metadata_accepted = self.client_events.get()
- self.assertIsNotNone(metadata_accepted)
- self.assertEqual(_low.Event.Kind.METADATA_ACCEPTED, metadata_accepted.kind)
- self.assertEqual(metadata_tag, metadata_accepted.tag)
- metadata = dict(metadata_accepted.metadata)
- self.assertIn(server_leading_metadata_key, metadata)
- self.assertEqual(server_leading_metadata_value,
- metadata[server_leading_metadata_key])
- self.assertIn(server_leading_binary_metadata_key, metadata)
- self.assertEqual(server_leading_binary_metadata_value,
- metadata[server_leading_binary_metadata_key])
-
- for datum in test_data:
- client_call.write(datum, write_tag, _low.WriteFlags.WRITE_NO_COMPRESS)
- write_accepted = self.client_events.get()
- self.assertIsNotNone(write_accepted)
- self.assertIs(write_accepted.kind, _low.Event.Kind.WRITE_ACCEPTED)
- self.assertIs(write_accepted.tag, write_tag)
- self.assertIs(write_accepted.write_accepted, True)
-
- server_call.read(read_tag)
- read_accepted = self.server_events.get()
- self.assertIsNotNone(read_accepted)
- self.assertEqual(_low.Event.Kind.READ_ACCEPTED, read_accepted.kind)
- self.assertEqual(read_tag, read_accepted.tag)
- self.assertIsNotNone(read_accepted.bytes)
- server_data.append(read_accepted.bytes)
-
- server_call.write(read_accepted.bytes, write_tag, 0)
- write_accepted = self.server_events.get()
- self.assertIsNotNone(write_accepted)
- self.assertEqual(_low.Event.Kind.WRITE_ACCEPTED, write_accepted.kind)
- self.assertEqual(write_tag, write_accepted.tag)
- self.assertTrue(write_accepted.write_accepted)
-
- client_call.read(read_tag)
- read_accepted = self.client_events.get()
- self.assertIsNotNone(read_accepted)
- self.assertEqual(_low.Event.Kind.READ_ACCEPTED, read_accepted.kind)
- self.assertEqual(read_tag, read_accepted.tag)
- self.assertIsNotNone(read_accepted.bytes)
- client_data.append(read_accepted.bytes)
-
- client_call.complete(complete_tag)
- complete_accepted = self.client_events.get()
- self.assertIsNotNone(complete_accepted)
- self.assertIs(complete_accepted.kind, _low.Event.Kind.COMPLETE_ACCEPTED)
- self.assertIs(complete_accepted.tag, complete_tag)
- self.assertIs(complete_accepted.complete_accepted, True)
-
- server_call.read(read_tag)
- read_accepted = self.server_events.get()
- self.assertIsNotNone(read_accepted)
- self.assertEqual(_low.Event.Kind.READ_ACCEPTED, read_accepted.kind)
- self.assertEqual(read_tag, read_accepted.tag)
- self.assertIsNone(read_accepted.bytes)
-
- server_call.add_metadata(server_trailing_metadata_key,
- server_trailing_metadata_value)
- server_call.add_metadata(server_trailing_binary_metadata_key,
- server_trailing_binary_metadata_value)
-
- server_call.status(_low.Status(_low.Code.OK, details), status_tag)
- server_terminal_event_one = self.server_events.get()
- server_terminal_event_two = self.server_events.get()
- if server_terminal_event_one.kind == _low.Event.Kind.COMPLETE_ACCEPTED:
- status_accepted = server_terminal_event_one
- rpc_accepted = server_terminal_event_two
- else:
- status_accepted = server_terminal_event_two
- rpc_accepted = server_terminal_event_one
- self.assertIsNotNone(status_accepted)
- self.assertIsNotNone(rpc_accepted)
- self.assertEqual(_low.Event.Kind.COMPLETE_ACCEPTED, status_accepted.kind)
- self.assertEqual(status_tag, status_accepted.tag)
- self.assertTrue(status_accepted.complete_accepted)
- self.assertEqual(_low.Event.Kind.FINISH, rpc_accepted.kind)
- self.assertEqual(finish_tag, rpc_accepted.tag)
- self.assertEqual(_low.Status(_low.Code.OK, ''), rpc_accepted.status)
-
- client_call.read(read_tag)
- client_terminal_event_one = self.client_events.get()
- client_terminal_event_two = self.client_events.get()
- if client_terminal_event_one.kind == _low.Event.Kind.READ_ACCEPTED:
- read_accepted = client_terminal_event_one
- finish_accepted = client_terminal_event_two
- else:
- read_accepted = client_terminal_event_two
- finish_accepted = client_terminal_event_one
- self.assertIsNotNone(read_accepted)
- self.assertIsNotNone(finish_accepted)
- self.assertEqual(_low.Event.Kind.READ_ACCEPTED, read_accepted.kind)
- self.assertEqual(read_tag, read_accepted.tag)
- self.assertIsNone(read_accepted.bytes)
- self.assertEqual(_low.Event.Kind.FINISH, finish_accepted.kind)
- self.assertEqual(finish_tag, finish_accepted.tag)
- self.assertEqual(_low.Status(_low.Code.OK, details), finish_accepted.status)
- metadata = dict(finish_accepted.metadata)
- self.assertIn(server_trailing_metadata_key, metadata)
- self.assertEqual(server_trailing_metadata_value,
- metadata[server_trailing_metadata_key])
- self.assertIn(server_trailing_binary_metadata_key, metadata)
- self.assertEqual(server_trailing_binary_metadata_value,
- metadata[server_trailing_binary_metadata_key])
- self.assertSetEqual(set(key for key, _ in finish_accepted.metadata),
- set((server_trailing_metadata_key,
- server_trailing_binary_metadata_key,)))
-
- self.assertSequenceEqual(test_data, server_data)
- self.assertSequenceEqual(test_data, client_data)
-
- def testNoEcho(self):
- self._perform_echo_test(())
-
- def testOneByteEcho(self):
- self._perform_echo_test([b'\x07'])
-
- def testOneManyByteEcho(self):
- self._perform_echo_test([_BYTE_SEQUENCE])
-
- def testManyOneByteEchoes(self):
- self._perform_echo_test(_BYTE_SEQUENCE)
-
- def testManyManyByteEchoes(self):
- self._perform_echo_test(_BYTE_SEQUENCE_SEQUENCE)
-
-
-class CancellationTest(unittest.TestCase):
-
- def setUp(self):
- self.host = 'localhost'
-
- self.server_completion_queue = _low.CompletionQueue()
- self.server = _low.Server(self.server_completion_queue)
- port = self.server.add_http2_addr('[::]:0')
- self.server.start()
- self.server_events = queue.Queue()
- self.server_completion_queue_thread = threading.Thread(
- target=_drive_completion_queue,
- args=(self.server_completion_queue, self.server_events))
- self.server_completion_queue_thread.start()
-
- self.client_completion_queue = _low.CompletionQueue()
- self.channel = _low.Channel('%s:%d' % (self.host, port), None)
- self.client_events = queue.Queue()
- self.client_completion_queue_thread = threading.Thread(
- target=_drive_completion_queue,
- args=(self.client_completion_queue, self.client_events))
- self.client_completion_queue_thread.start()
-
- def tearDown(self):
- self.server.stop()
- self.server.cancel_all_calls()
- self.server_completion_queue.stop()
- self.client_completion_queue.stop()
- self.server_completion_queue_thread.join()
- self.client_completion_queue_thread.join()
- del self.server
-
- def testCancellation(self):
- method = 'test method'
- deadline = _FUTURE
- metadata_tag = object()
- finish_tag = object()
- write_tag = object()
- service_tag = object()
- read_tag = object()
- test_data = _BYTE_SEQUENCE_SEQUENCE
-
- server_data = []
- client_data = []
-
- client_call = _low.Call(self.channel, self.client_completion_queue,
- method, self.host, deadline)
-
- client_call.invoke(self.client_completion_queue, metadata_tag, finish_tag)
-
- self.server.service(service_tag)
- service_accepted = self.server_events.get()
- server_call = service_accepted.service_acceptance.call
-
- server_call.accept(self.server_completion_queue, finish_tag)
- server_call.premetadata()
-
- metadata_accepted = self.client_events.get()
- self.assertIsNotNone(metadata_accepted)
-
- for datum in test_data:
- client_call.write(datum, write_tag, 0)
- write_accepted = self.client_events.get()
-
- server_call.read(read_tag)
- read_accepted = self.server_events.get()
- server_data.append(read_accepted.bytes)
-
- server_call.write(read_accepted.bytes, write_tag, 0)
- write_accepted = self.server_events.get()
- self.assertIsNotNone(write_accepted)
-
- client_call.read(read_tag)
- read_accepted = self.client_events.get()
- client_data.append(read_accepted.bytes)
-
- client_call.cancel()
- # cancel() is idempotent.
- client_call.cancel()
- client_call.cancel()
- client_call.cancel()
-
- server_call.read(read_tag)
-
- server_terminal_event_one = self.server_events.get()
- server_terminal_event_two = self.server_events.get()
- if server_terminal_event_one.kind == _low.Event.Kind.READ_ACCEPTED:
- read_accepted = server_terminal_event_one
- rpc_accepted = server_terminal_event_two
- else:
- read_accepted = server_terminal_event_two
- rpc_accepted = server_terminal_event_one
- self.assertIsNotNone(read_accepted)
- self.assertIsNotNone(rpc_accepted)
- self.assertEqual(_low.Event.Kind.READ_ACCEPTED, read_accepted.kind)
- self.assertIsNone(read_accepted.bytes)
- self.assertEqual(_low.Event.Kind.FINISH, rpc_accepted.kind)
- self.assertEqual(_low.Status(_low.Code.CANCELLED, ''), rpc_accepted.status)
-
- finish_event = self.client_events.get()
- self.assertEqual(_low.Event.Kind.FINISH, finish_event.kind)
- self.assertEqual(_low.Status(_low.Code.CANCELLED, 'Cancelled'),
- finish_event.status)
-
- self.assertSequenceEqual(test_data, server_data)
- self.assertSequenceEqual(test_data, client_data)
-
-
-class ExpirationTest(unittest.TestCase):
-
- @unittest.skip('TODO(nathaniel): Expiration test!')
- def testExpiration(self):
- pass
-
-
-if __name__ == '__main__':
- unittest.main(verbosity=2)
-
diff --git a/src/python/grpcio/tests/unit/_adapter/_low_test.py b/src/python/grpcio/tests/unit/_adapter/_low_test.py
deleted file mode 100644
index ec46617996..0000000000
--- a/src/python/grpcio/tests/unit/_adapter/_low_test.py
+++ /dev/null
@@ -1,319 +0,0 @@
-# Copyright 2015, Google Inc.
-# All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-import threading
-import time
-import unittest
-
-from grpc import _grpcio_metadata
-from grpc._adapter import _types
-from grpc._adapter import _low
-from tests.unit import test_common
-
-
-def wait_for_events(completion_queues, deadline):
- """
- Args:
- completion_queues: list of completion queues to wait for events on
- deadline: absolute deadline to wait until
-
- Returns:
- a sequence of events of length len(completion_queues).
- """
-
- results = [None] * len(completion_queues)
- lock = threading.Lock()
- threads = []
- def set_ith_result(i, completion_queue):
- result = completion_queue.next(deadline)
- with lock:
- results[i] = result
- for i, completion_queue in enumerate(completion_queues):
- thread = threading.Thread(target=set_ith_result,
- args=[i, completion_queue])
- thread.start()
- threads.append(thread)
- for thread in threads:
- thread.join()
- return results
-
-
-class InsecureServerInsecureClient(unittest.TestCase):
-
- def setUp(self):
- self.server_completion_queue = _low.CompletionQueue()
- self.server = _low.Server(self.server_completion_queue, [])
- self.port = self.server.add_http2_port('[::]:0')
- self.client_completion_queue = _low.CompletionQueue()
- self.client_channel = _low.Channel('localhost:%d'%self.port, [])
-
- self.server.start()
-
- def tearDown(self):
- self.server.shutdown()
- del self.client_channel
-
- self.client_completion_queue.shutdown()
- while (self.client_completion_queue.next(float('+inf')).type !=
- _types.EventType.QUEUE_SHUTDOWN):
- pass
- self.server_completion_queue.shutdown()
- while (self.server_completion_queue.next(float('+inf')).type !=
- _types.EventType.QUEUE_SHUTDOWN):
- pass
-
- del self.client_completion_queue
- del self.server_completion_queue
- del self.server
-
- def testEcho(self):
- deadline = time.time() + 5
- event_time_tolerance = 2
- deadline_tolerance = 0.25
- client_metadata_ascii_key = 'key'
- client_metadata_ascii_value = 'val'
- client_metadata_bin_key = 'key-bin'
- client_metadata_bin_value = b'\0'*1000
- server_initial_metadata_key = 'init_me_me_me'
- server_initial_metadata_value = 'whodawha?'
- server_trailing_metadata_key = 'california_is_in_a_drought'
- server_trailing_metadata_value = 'zomg it is'
- server_status_code = _types.StatusCode.OK
- server_status_details = 'our work is never over'
- request = 'blarghaflargh'
- response = 'his name is robert paulson'
- method = 'twinkies'
- host = 'hostess'
- server_request_tag = object()
- request_call_result = self.server.request_call(self.server_completion_queue,
- server_request_tag)
-
- self.assertEqual(_types.CallError.OK, request_call_result)
-
- client_call_tag = object()
- client_call = self.client_channel.create_call(
- self.client_completion_queue, method, host, deadline)
- client_initial_metadata = [
- (client_metadata_ascii_key, client_metadata_ascii_value),
- (client_metadata_bin_key, client_metadata_bin_value)
- ]
- client_start_batch_result = client_call.start_batch([
- _types.OpArgs.send_initial_metadata(client_initial_metadata),
- _types.OpArgs.send_message(request, 0),
- _types.OpArgs.send_close_from_client(),
- _types.OpArgs.recv_initial_metadata(),
- _types.OpArgs.recv_message(),
- _types.OpArgs.recv_status_on_client()
- ], client_call_tag)
- self.assertEqual(_types.CallError.OK, client_start_batch_result)
-
- client_no_event, request_event, = wait_for_events(
- [self.client_completion_queue, self.server_completion_queue],
- time.time() + event_time_tolerance)
- self.assertEqual(client_no_event, None)
- self.assertEqual(_types.EventType.OP_COMPLETE, request_event.type)
- self.assertIsInstance(request_event.call, _low.Call)
- self.assertIs(server_request_tag, request_event.tag)
- self.assertEqual(1, len(request_event.results))
- received_initial_metadata = request_event.results[0].initial_metadata
- # Check that our metadata were transmitted
- self.assertTrue(test_common.metadata_transmitted(client_initial_metadata,
- received_initial_metadata))
- # Check that Python's user agent string is a part of the full user agent
- # string
- received_initial_metadata_dict = dict(received_initial_metadata)
- self.assertIn('user-agent', received_initial_metadata_dict)
- self.assertIn('Python-gRPC-{}'.format(_grpcio_metadata.__version__),
- received_initial_metadata_dict['user-agent'])
- self.assertEqual(method, request_event.call_details.method)
- self.assertEqual(host, request_event.call_details.host)
- self.assertLess(abs(deadline - request_event.call_details.deadline),
- deadline_tolerance)
-
- # Check that the channel is connected, and that both it and the call have
- # the proper target and peer; do this after the first flurry of messages to
- # avoid the possibility that connection was delayed by the core until the
- # first message was sent.
- self.assertEqual(_types.ConnectivityState.READY,
- self.client_channel.check_connectivity_state(False))
- self.assertIsNotNone(self.client_channel.target())
- self.assertIsNotNone(client_call.peer())
-
- server_call_tag = object()
- server_call = request_event.call
- server_initial_metadata = [
- (server_initial_metadata_key, server_initial_metadata_value)
- ]
- server_trailing_metadata = [
- (server_trailing_metadata_key, server_trailing_metadata_value)
- ]
- server_start_batch_result = server_call.start_batch([
- _types.OpArgs.send_initial_metadata(server_initial_metadata),
- _types.OpArgs.recv_message(),
- _types.OpArgs.send_message(response, 0),
- _types.OpArgs.recv_close_on_server(),
- _types.OpArgs.send_status_from_server(
- server_trailing_metadata, server_status_code, server_status_details)
- ], server_call_tag)
- self.assertEqual(_types.CallError.OK, server_start_batch_result)
-
- client_event, server_event, = wait_for_events(
- [self.client_completion_queue, self.server_completion_queue],
- time.time() + event_time_tolerance)
-
- self.assertEqual(6, len(client_event.results))
- found_client_op_types = set()
- for client_result in client_event.results:
- # we expect each op type to be unique
- self.assertNotIn(client_result.type, found_client_op_types)
- found_client_op_types.add(client_result.type)
- if client_result.type == _types.OpType.RECV_INITIAL_METADATA:
- self.assertTrue(
- test_common.metadata_transmitted(server_initial_metadata,
- client_result.initial_metadata))
- elif client_result.type == _types.OpType.RECV_MESSAGE:
- self.assertEqual(response, client_result.message)
- elif client_result.type == _types.OpType.RECV_STATUS_ON_CLIENT:
- self.assertTrue(
- test_common.metadata_transmitted(server_trailing_metadata,
- client_result.trailing_metadata))
- self.assertEqual(server_status_details, client_result.status.details)
- self.assertEqual(server_status_code, client_result.status.code)
- self.assertEqual(set([
- _types.OpType.SEND_INITIAL_METADATA,
- _types.OpType.SEND_MESSAGE,
- _types.OpType.SEND_CLOSE_FROM_CLIENT,
- _types.OpType.RECV_INITIAL_METADATA,
- _types.OpType.RECV_MESSAGE,
- _types.OpType.RECV_STATUS_ON_CLIENT
- ]), found_client_op_types)
-
- self.assertEqual(5, len(server_event.results))
- found_server_op_types = set()
- for server_result in server_event.results:
- self.assertNotIn(client_result.type, found_server_op_types)
- found_server_op_types.add(server_result.type)
- if server_result.type == _types.OpType.RECV_MESSAGE:
- self.assertEqual(request, server_result.message)
- elif server_result.type == _types.OpType.RECV_CLOSE_ON_SERVER:
- self.assertFalse(server_result.cancelled)
- self.assertEqual(set([
- _types.OpType.SEND_INITIAL_METADATA,
- _types.OpType.RECV_MESSAGE,
- _types.OpType.SEND_MESSAGE,
- _types.OpType.RECV_CLOSE_ON_SERVER,
- _types.OpType.SEND_STATUS_FROM_SERVER
- ]), found_server_op_types)
-
- del client_call
- del server_call
-
-
-class HangingServerShutdown(unittest.TestCase):
-
- def setUp(self):
- self.server_completion_queue = _low.CompletionQueue()
- self.server = _low.Server(self.server_completion_queue, [])
- self.port = self.server.add_http2_port('[::]:0')
- self.client_completion_queue = _low.CompletionQueue()
- self.client_channel = _low.Channel('localhost:%d'%self.port, [])
-
- self.server.start()
-
- def tearDown(self):
- self.server.shutdown()
- del self.client_channel
-
- self.client_completion_queue.shutdown()
- self.server_completion_queue.shutdown()
- while True:
- client_event, server_event = wait_for_events(
- [self.client_completion_queue, self.server_completion_queue],
- float("+inf"))
- if (client_event.type == _types.EventType.QUEUE_SHUTDOWN and
- server_event.type == _types.EventType.QUEUE_SHUTDOWN):
- break
-
- del self.client_completion_queue
- del self.server_completion_queue
- del self.server
-
- def testHangingServerCall(self):
- deadline = time.time() + 5
- deadline_tolerance = 0.25
- event_time_tolerance = 2
- cancel_all_calls_time_tolerance = 0.5
- request = 'blarghaflargh'
- method = 'twinkies'
- host = 'hostess'
- server_request_tag = object()
- request_call_result = self.server.request_call(self.server_completion_queue,
- server_request_tag)
-
- client_call_tag = object()
- client_call = self.client_channel.create_call(self.client_completion_queue,
- method, host, deadline)
- client_start_batch_result = client_call.start_batch([
- _types.OpArgs.send_initial_metadata([]),
- _types.OpArgs.send_message(request, 0),
- _types.OpArgs.send_close_from_client(),
- _types.OpArgs.recv_initial_metadata(),
- _types.OpArgs.recv_message(),
- _types.OpArgs.recv_status_on_client()
- ], client_call_tag)
-
- client_no_event, request_event, = wait_for_events(
- [self.client_completion_queue, self.server_completion_queue],
- time.time() + event_time_tolerance)
-
- # Now try to shutdown the server and expect that we see server shutdown
- # almost immediately after calling cancel_all_calls.
-
- # First attempt to cancel all calls before shutting down, and expect
- # our state machine to catch the erroneous API use.
- with self.assertRaises(RuntimeError):
- self.server.cancel_all_calls()
-
- shutdown_tag = object()
- self.server.shutdown(shutdown_tag)
- pre_cancel_timestamp = time.time()
- self.server.cancel_all_calls()
- finish_shutdown_timestamp = None
- client_call_event, server_shutdown_event = wait_for_events(
- [self.client_completion_queue, self.server_completion_queue],
- time.time() + event_time_tolerance)
- self.assertIs(shutdown_tag, server_shutdown_event.tag)
- self.assertGreater(pre_cancel_timestamp + cancel_all_calls_time_tolerance,
- time.time())
-
- del client_call
-
-
-if __name__ == '__main__':
- unittest.main(verbosity=2)
diff --git a/src/python/grpcio/tests/unit/_api_test.py b/src/python/grpcio/tests/unit/_api_test.py
new file mode 100644
index 0000000000..2fe89499f5
--- /dev/null
+++ b/src/python/grpcio/tests/unit/_api_test.py
@@ -0,0 +1,111 @@
+# Copyright 2016, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Test of gRPC Python's application-layer API."""
+
+import unittest
+
+import six
+
+import grpc
+
+from tests.unit import _from_grpc_import_star
+
+
+class AllTest(unittest.TestCase):
+
+ def testAll(self):
+ expected_grpc_code_elements = (
+ 'FutureTimeoutError',
+ 'FutureCancelledError',
+ 'Future',
+ 'ChannelConnectivity',
+ 'StatusCode',
+ 'RpcError',
+ 'RpcContext',
+ 'Call',
+ 'ChannelCredentials',
+ 'CallCredentials',
+ 'AuthMetadataContext',
+ 'AuthMetadataPluginCallback',
+ 'AuthMetadataPlugin',
+ 'ServerCredentials',
+ 'UnaryUnaryMultiCallable',
+ 'UnaryStreamMultiCallable',
+ 'StreamUnaryMultiCallable',
+ 'StreamStreamMultiCallable',
+ 'Channel',
+ 'ServicerContext',
+ 'RpcMethodHandler',
+ 'HandlerCallDetails',
+ 'GenericRpcHandler',
+ 'Server',
+ 'unary_unary_rpc_method_handler',
+ 'unary_stream_rpc_method_handler',
+ 'stream_unary_rpc_method_handler',
+ 'stream_stream_rpc_method_handler',
+ 'method_handlers_generic_handler',
+ 'ssl_channel_credentials',
+ 'metadata_call_credentials',
+ 'access_token_call_credentials',
+ 'composite_call_credentials',
+ 'composite_channel_credentials',
+ 'ssl_server_credentials',
+ 'channel_ready_future',
+ 'insecure_channel',
+ 'secure_channel',
+ 'server',
+ )
+
+ six.assertCountEqual(
+ self, expected_grpc_code_elements,
+ _from_grpc_import_star.GRPC_ELEMENTS)
+
+
+class ChannelConnectivityTest(unittest.TestCase):
+
+ def testChannelConnectivity(self):
+ self.assertSequenceEqual(
+ (grpc.ChannelConnectivity.IDLE,
+ grpc.ChannelConnectivity.CONNECTING,
+ grpc.ChannelConnectivity.READY,
+ grpc.ChannelConnectivity.TRANSIENT_FAILURE,
+ grpc.ChannelConnectivity.SHUTDOWN,),
+ tuple(grpc.ChannelConnectivity))
+
+
+class ChannelTest(unittest.TestCase):
+
+ def test_secure_channel(self):
+ channel_credentials = grpc.ssl_channel_credentials()
+ channel = grpc.secure_channel('google.com:443', channel_credentials)
+
+
+if __name__ == '__main__':
+ unittest.main(verbosity=2)
diff --git a/src/python/grpcio/tests/unit/beta/_auth_test.py b/src/python/grpcio/tests/unit/_auth_test.py
index 694928a91b..c31f7b06f7 100644
--- a/src/python/grpcio/tests/unit/beta/_auth_test.py
+++ b/src/python/grpcio/tests/unit/_auth_test.py
@@ -33,7 +33,7 @@ import collections
import threading
import unittest
-from grpc.beta import _auth
+from grpc import _auth
class MockGoogleCreds(object):
diff --git a/src/python/grpcio/tests/unit/_channel_connectivity_test.py b/src/python/grpcio/tests/unit/_channel_connectivity_test.py
index a1575efada..ae8de523ec 100644
--- a/src/python/grpcio/tests/unit/_channel_connectivity_test.py
+++ b/src/python/grpcio/tests/unit/_channel_connectivity_test.py
@@ -135,12 +135,12 @@ class ChannelConnectivityTest(unittest.TestCase):
self.assertNotIn(
grpc.ChannelConnectivity.TRANSIENT_FAILURE, third_connectivities)
self.assertNotIn(
- grpc.ChannelConnectivity.FATAL_FAILURE, third_connectivities)
+ grpc.ChannelConnectivity.SHUTDOWN, third_connectivities)
self.assertNotIn(
grpc.ChannelConnectivity.TRANSIENT_FAILURE,
fourth_connectivities)
self.assertNotIn(
- grpc.ChannelConnectivity.FATAL_FAILURE, fourth_connectivities)
+ grpc.ChannelConnectivity.SHUTDOWN, fourth_connectivities)
def test_reachable_then_unreachable_channel_connectivity(self):
server = _server.Server((), futures.ThreadPoolExecutor(max_workers=0))
diff --git a/src/python/grpcio/tests/unit/_core_over_links_base_interface_test.py b/src/python/grpcio/tests/unit/_core_over_links_base_interface_test.py
deleted file mode 100644
index 2b8981c752..0000000000
--- a/src/python/grpcio/tests/unit/_core_over_links_base_interface_test.py
+++ /dev/null
@@ -1,157 +0,0 @@
-# Copyright 2015, Google Inc.
-# All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Tests Base interface compliance of the core-over-gRPC-links stack."""
-
-import collections
-import logging
-import random
-import time
-import unittest
-
-import six
-
-from grpc._adapter import _intermediary_low
-from grpc._links import invocation
-from grpc._links import service
-from grpc.beta import interfaces as beta_interfaces
-from grpc.framework.core import implementations
-from grpc.framework.interfaces.base import utilities
-from tests.unit import test_common as grpc_test_common
-from tests.unit.framework.common import test_constants
-from tests.unit.framework.interfaces.base import test_cases
-from tests.unit.framework.interfaces.base import test_interfaces
-
-
-class _SerializationBehaviors(
- collections.namedtuple(
- '_SerializationBehaviors',
- ('request_serializers', 'request_deserializers', 'response_serializers',
- 'response_deserializers',))):
- pass
-
-
-class _Links(
- collections.namedtuple(
- '_Links',
- ('invocation_end_link', 'invocation_grpc_link', 'service_grpc_link',
- 'service_end_link'))):
- pass
-
-
-def _serialization_behaviors_from_serializations(serializations):
- request_serializers = {}
- request_deserializers = {}
- response_serializers = {}
- response_deserializers = {}
- for (group, method), serialization in six.iteritems(serializations):
- request_serializers[group, method] = serialization.serialize_request
- request_deserializers[group, method] = serialization.deserialize_request
- response_serializers[group, method] = serialization.serialize_response
- response_deserializers[group, method] = serialization.deserialize_response
- return _SerializationBehaviors(
- request_serializers, request_deserializers, response_serializers,
- response_deserializers)
-
-
-class _Implementation(test_interfaces.Implementation):
-
- def instantiate(self, serializations, servicer):
- serialization_behaviors = _serialization_behaviors_from_serializations(
- serializations)
- invocation_end_link = implementations.invocation_end_link()
- service_end_link = implementations.service_end_link(
- servicer, test_constants.DEFAULT_TIMEOUT,
- test_constants.MAXIMUM_TIMEOUT)
- service_grpc_link = service.service_link(
- serialization_behaviors.request_deserializers,
- serialization_behaviors.response_serializers)
- port = service_grpc_link.add_port('[::]:0', None)
- channel = _intermediary_low.Channel('localhost:%d' % port, None)
- invocation_grpc_link = invocation.invocation_link(
- channel, b'localhost', None,
- serialization_behaviors.request_serializers,
- serialization_behaviors.response_deserializers)
-
- invocation_end_link.join_link(invocation_grpc_link)
- invocation_grpc_link.join_link(invocation_end_link)
- service_end_link.join_link(service_grpc_link)
- service_grpc_link.join_link(service_end_link)
- invocation_grpc_link.start()
- service_grpc_link.start()
- return invocation_end_link, service_end_link, (
- invocation_grpc_link, service_grpc_link)
-
- def destantiate(self, memo):
- invocation_grpc_link, service_grpc_link = memo
- invocation_grpc_link.stop()
- service_grpc_link.begin_stop()
- service_grpc_link.end_stop()
-
- def invocation_initial_metadata(self):
- return grpc_test_common.INVOCATION_INITIAL_METADATA
-
- def service_initial_metadata(self):
- return grpc_test_common.SERVICE_INITIAL_METADATA
-
- def invocation_completion(self):
- return utilities.completion(None, None, None)
-
- def service_completion(self):
- return utilities.completion(
- grpc_test_common.SERVICE_TERMINAL_METADATA,
- beta_interfaces.StatusCode.OK, grpc_test_common.DETAILS)
-
- def metadata_transmitted(self, original_metadata, transmitted_metadata):
- return original_metadata is None or grpc_test_common.metadata_transmitted(
- original_metadata, transmitted_metadata)
-
- def completion_transmitted(self, original_completion, transmitted_completion):
- if (original_completion.terminal_metadata is not None and
- not grpc_test_common.metadata_transmitted(
- original_completion.terminal_metadata,
- transmitted_completion.terminal_metadata)):
- return False
- elif original_completion.code is not transmitted_completion.code:
- return False
- elif original_completion.message != transmitted_completion.message:
- return False
- else:
- return True
-
-
-def load_tests(loader, tests, pattern):
- return unittest.TestSuite(
- tests=tuple(
- loader.loadTestsFromTestCase(test_case_class)
- for test_case_class in test_cases.test_cases(_Implementation())))
-
-
-if __name__ == '__main__':
- unittest.main(verbosity=2)
diff --git a/src/python/grpcio/tests/unit/_crust_over_core_over_links_face_interface_test.py b/src/python/grpcio/tests/unit/_crust_over_core_over_links_face_interface_test.py
deleted file mode 100644
index 50b9a5a824..0000000000
--- a/src/python/grpcio/tests/unit/_crust_over_core_over_links_face_interface_test.py
+++ /dev/null
@@ -1,163 +0,0 @@
-# Copyright 2015, Google Inc.
-# All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Tests Face compliance of the crust-over-core-over-gRPC-links stack."""
-
-import collections
-import unittest
-
-import six
-
-from grpc._adapter import _intermediary_low
-from grpc._links import invocation
-from grpc._links import service
-from grpc.beta import interfaces as beta_interfaces
-from grpc.framework.core import implementations as core_implementations
-from grpc.framework.crust import implementations as crust_implementations
-from grpc.framework.foundation import logging_pool
-from grpc.framework.interfaces.links import utilities
-from tests.unit import test_common as grpc_test_common
-from tests.unit.framework.common import test_constants
-from tests.unit.framework.interfaces.face import test_cases
-from tests.unit.framework.interfaces.face import test_interfaces
-
-
-class _SerializationBehaviors(
- collections.namedtuple(
- '_SerializationBehaviors',
- ('request_serializers', 'request_deserializers', 'response_serializers',
- 'response_deserializers',))):
- pass
-
-
-def _serialization_behaviors_from_test_methods(test_methods):
- request_serializers = {}
- request_deserializers = {}
- response_serializers = {}
- response_deserializers = {}
- for (group, method), test_method in six.iteritems(test_methods):
- request_serializers[group, method] = test_method.serialize_request
- request_deserializers[group, method] = test_method.deserialize_request
- response_serializers[group, method] = test_method.serialize_response
- response_deserializers[group, method] = test_method.deserialize_response
- return _SerializationBehaviors(
- request_serializers, request_deserializers, response_serializers,
- response_deserializers)
-
-
-class _Implementation(test_interfaces.Implementation):
-
- def instantiate(
- self, methods, method_implementations, multi_method_implementation):
- pool = logging_pool.pool(test_constants.POOL_SIZE)
- servicer = crust_implementations.servicer(
- method_implementations, multi_method_implementation, pool)
- serialization_behaviors = _serialization_behaviors_from_test_methods(
- methods)
- invocation_end_link = core_implementations.invocation_end_link()
- service_end_link = core_implementations.service_end_link(
- servicer, test_constants.DEFAULT_TIMEOUT,
- test_constants.MAXIMUM_TIMEOUT)
- service_grpc_link = service.service_link(
- serialization_behaviors.request_deserializers,
- serialization_behaviors.response_serializers)
- port = service_grpc_link.add_port('[::]:0', None)
- channel = _intermediary_low.Channel('localhost:%d' % port, None)
- invocation_grpc_link = invocation.invocation_link(
- channel, b'localhost', None,
- serialization_behaviors.request_serializers,
- serialization_behaviors.response_deserializers)
-
- invocation_end_link.join_link(invocation_grpc_link)
- invocation_grpc_link.join_link(invocation_end_link)
- service_grpc_link.join_link(service_end_link)
- service_end_link.join_link(service_grpc_link)
- service_end_link.start()
- invocation_end_link.start()
- invocation_grpc_link.start()
- service_grpc_link.start()
-
- generic_stub = crust_implementations.generic_stub(invocation_end_link, pool)
- # TODO(nathaniel): Add a "groups" attribute to _digest.TestServiceDigest.
- group = next(iter(methods))[0]
- # TODO(nathaniel): Add a "cardinalities_by_group" attribute to
- # _digest.TestServiceDigest.
- cardinalities = {
- method: method_object.cardinality()
- for (group, method), method_object in six.iteritems(methods)}
- dynamic_stub = crust_implementations.dynamic_stub(
- invocation_end_link, group, cardinalities, pool)
-
- return generic_stub, {group: dynamic_stub}, (
- invocation_end_link, invocation_grpc_link, service_grpc_link,
- service_end_link, pool)
-
- def destantiate(self, memo):
- (invocation_end_link, invocation_grpc_link, service_grpc_link,
- service_end_link, pool) = memo
- invocation_end_link.stop(0).wait()
- invocation_grpc_link.stop()
- service_grpc_link.begin_stop()
- service_end_link.stop(0).wait()
- service_grpc_link.end_stop()
- invocation_end_link.join_link(utilities.NULL_LINK)
- invocation_grpc_link.join_link(utilities.NULL_LINK)
- service_grpc_link.join_link(utilities.NULL_LINK)
- service_end_link.join_link(utilities.NULL_LINK)
- pool.shutdown(wait=True)
-
- def invocation_metadata(self):
- return grpc_test_common.INVOCATION_INITIAL_METADATA
-
- def initial_metadata(self):
- return grpc_test_common.SERVICE_INITIAL_METADATA
-
- def terminal_metadata(self):
- return grpc_test_common.SERVICE_TERMINAL_METADATA
-
- def code(self):
- return beta_interfaces.StatusCode.OK
-
- def details(self):
- return grpc_test_common.DETAILS
-
- def metadata_transmitted(self, original_metadata, transmitted_metadata):
- return original_metadata is None or grpc_test_common.metadata_transmitted(
- original_metadata, transmitted_metadata)
-
-
-def load_tests(loader, tests, pattern):
- return unittest.TestSuite(
- tests=tuple(
- loader.loadTestsFromTestCase(test_case_class)
- for test_case_class in test_cases.test_cases(_Implementation())))
-
-
-if __name__ == '__main__':
- unittest.main(verbosity=2)
diff --git a/src/python/grpcio/tests/unit/_cython/_read_some_but_not_all_responses_test.py b/src/python/grpcio/tests/unit/_cython/_read_some_but_not_all_responses_test.py
new file mode 100644
index 0000000000..6ae7a90fbe
--- /dev/null
+++ b/src/python/grpcio/tests/unit/_cython/_read_some_but_not_all_responses_test.py
@@ -0,0 +1,251 @@
+# Copyright 2016, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Test a corner-case at the level of the Cython API."""
+
+import threading
+import unittest
+
+from grpc._cython import cygrpc
+
+_INFINITE_FUTURE = cygrpc.Timespec(float('+inf'))
+_EMPTY_FLAGS = 0
+_EMPTY_METADATA = cygrpc.Metadata(())
+
+
+class _ServerDriver(object):
+
+ def __init__(self, completion_queue, shutdown_tag):
+ self._condition = threading.Condition()
+ self._completion_queue = completion_queue
+ self._shutdown_tag = shutdown_tag
+ self._events = []
+ self._saw_shutdown_tag = False
+
+ def start(self):
+ def in_thread():
+ while True:
+ event = self._completion_queue.poll()
+ with self._condition:
+ self._events.append(event)
+ self._condition.notify()
+ if event.tag is self._shutdown_tag:
+ self._saw_shutdown_tag = True
+ break
+ thread = threading.Thread(target=in_thread)
+ thread.start()
+
+ def done(self):
+ with self._condition:
+ return self._saw_shutdown_tag
+
+ def first_event(self):
+ with self._condition:
+ while not self._events:
+ self._condition.wait()
+ return self._events[0]
+
+ def events(self):
+ with self._condition:
+ while not self._saw_shutdown_tag:
+ self._condition.wait()
+ return tuple(self._events)
+
+
+class _QueueDriver(object):
+
+ def __init__(self, condition, completion_queue, due):
+ self._condition = condition
+ self._completion_queue = completion_queue
+ self._due = due
+ self._events = []
+ self._returned = False
+
+ def start(self):
+ def in_thread():
+ while True:
+ event = self._completion_queue.poll()
+ with self._condition:
+ self._events.append(event)
+ self._due.remove(event.tag)
+ self._condition.notify_all()
+ if not self._due:
+ self._returned = True
+ return
+ thread = threading.Thread(target=in_thread)
+ thread.start()
+
+ def done(self):
+ with self._condition:
+ return self._returned
+
+ def event_with_tag(self, tag):
+ with self._condition:
+ while True:
+ for event in self._events:
+ if event.tag is tag:
+ return event
+ self._condition.wait()
+
+ def events(self):
+ with self._condition:
+ while not self._returned:
+ self._condition.wait()
+ return tuple(self._events)
+
+
+class ReadSomeButNotAllResponsesTest(unittest.TestCase):
+
+ def testReadSomeButNotAllResponses(self):
+ server_completion_queue = cygrpc.CompletionQueue()
+ server = cygrpc.Server()
+ server.register_completion_queue(server_completion_queue)
+ port = server.add_http2_port('[::]:0')
+ server.start()
+ channel = cygrpc.Channel('localhost:{}'.format(port))
+
+ server_shutdown_tag = 'server_shutdown_tag'
+ server_driver = _ServerDriver(server_completion_queue, server_shutdown_tag)
+ server_driver.start()
+
+ client_condition = threading.Condition()
+ client_due = set()
+ client_completion_queue = cygrpc.CompletionQueue()
+ client_driver = _QueueDriver(
+ client_condition, client_completion_queue, client_due)
+ client_driver.start()
+
+ server_call_condition = threading.Condition()
+ server_send_initial_metadata_tag = 'server_send_initial_metadata_tag'
+ server_send_first_message_tag = 'server_send_first_message_tag'
+ server_send_second_message_tag = 'server_send_second_message_tag'
+ server_complete_rpc_tag = 'server_complete_rpc_tag'
+ server_call_due = set((
+ server_send_initial_metadata_tag,
+ server_send_first_message_tag,
+ server_send_second_message_tag,
+ server_complete_rpc_tag,
+ ))
+ server_call_completion_queue = cygrpc.CompletionQueue()
+ server_call_driver = _QueueDriver(
+ server_call_condition, server_call_completion_queue, server_call_due)
+ server_call_driver.start()
+
+ server_rpc_tag = 'server_rpc_tag'
+ request_call_result = server.request_call(
+ server_call_completion_queue, server_completion_queue, server_rpc_tag)
+
+ client_call = channel.create_call(
+ None, _EMPTY_FLAGS, client_completion_queue, b'/twinkies', None,
+ _INFINITE_FUTURE)
+ client_receive_initial_metadata_tag = 'client_receive_initial_metadata_tag'
+ client_complete_rpc_tag = 'client_complete_rpc_tag'
+ with client_condition:
+ client_receive_initial_metadata_start_batch_result = (
+ client_call.start_batch(cygrpc.Operations([
+ cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
+ ]), client_receive_initial_metadata_tag))
+ client_due.add(client_receive_initial_metadata_tag)
+ client_complete_rpc_start_batch_result = (
+ client_call.start_batch(cygrpc.Operations([
+ cygrpc.operation_send_initial_metadata(
+ _EMPTY_METADATA, _EMPTY_FLAGS),
+ cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
+ cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
+ ]), client_complete_rpc_tag))
+ client_due.add(client_complete_rpc_tag)
+
+ server_rpc_event = server_driver.first_event()
+
+ with server_call_condition:
+ server_send_initial_metadata_start_batch_result = (
+ server_rpc_event.operation_call.start_batch(cygrpc.Operations([
+ cygrpc.operation_send_initial_metadata(
+ _EMPTY_METADATA, _EMPTY_FLAGS),
+ ]), server_send_initial_metadata_tag))
+ server_send_first_message_start_batch_result = (
+ server_rpc_event.operation_call.start_batch(cygrpc.Operations([
+ cygrpc.operation_send_message(b'\x07', _EMPTY_FLAGS),
+ ]), server_send_first_message_tag))
+ server_send_initial_metadata_event = server_call_driver.event_with_tag(
+ server_send_initial_metadata_tag)
+ server_send_first_message_event = server_call_driver.event_with_tag(
+ server_send_first_message_tag)
+ with server_call_condition:
+ server_send_second_message_start_batch_result = (
+ server_rpc_event.operation_call.start_batch(cygrpc.Operations([
+ cygrpc.operation_send_message(b'\x07', _EMPTY_FLAGS),
+ ]), server_send_second_message_tag))
+ server_complete_rpc_start_batch_result = (
+ server_rpc_event.operation_call.start_batch(cygrpc.Operations([
+ cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),
+ cygrpc.operation_send_status_from_server(
+ cygrpc.Metadata(()), cygrpc.StatusCode.ok, b'test details',
+ _EMPTY_FLAGS),
+ ]), server_complete_rpc_tag))
+ server_send_second_message_event = server_call_driver.event_with_tag(
+ server_send_second_message_tag)
+ server_complete_rpc_event = server_call_driver.event_with_tag(
+ server_complete_rpc_tag)
+ server_call_driver.events()
+
+ with client_condition:
+ client_receive_first_message_tag = 'client_receive_first_message_tag'
+ client_receive_first_message_start_batch_result = (
+ client_call.start_batch(cygrpc.Operations([
+ cygrpc.operation_receive_message(_EMPTY_FLAGS),
+ ]), client_receive_first_message_tag))
+ client_due.add(client_receive_first_message_tag)
+ client_receive_first_message_event = client_driver.event_with_tag(
+ client_receive_first_message_tag)
+
+ client_call_cancel_result = client_call.cancel()
+ client_driver.events()
+
+ server.shutdown(server_completion_queue, server_shutdown_tag)
+ server.cancel_all_calls()
+ server_driver.events()
+
+ self.assertEqual(cygrpc.CallError.ok, request_call_result)
+ self.assertEqual(
+ cygrpc.CallError.ok, server_send_initial_metadata_start_batch_result)
+ self.assertEqual(
+ cygrpc.CallError.ok, client_receive_initial_metadata_start_batch_result)
+ self.assertEqual(
+ cygrpc.CallError.ok, client_complete_rpc_start_batch_result)
+ self.assertEqual(cygrpc.CallError.ok, client_call_cancel_result)
+ self.assertIs(server_rpc_tag, server_rpc_event.tag)
+ self.assertEqual(
+ cygrpc.CompletionType.operation_complete, server_rpc_event.type)
+ self.assertIsInstance(server_rpc_event.operation_call, cygrpc.Call)
+ self.assertEqual(0, len(server_rpc_event.batch_operations))
+
+
+if __name__ == '__main__':
+ unittest.main(verbosity=2)
diff --git a/src/python/grpcio/tests/unit/_cython/cygrpc_test.py b/src/python/grpcio/tests/unit/_cython/cygrpc_test.py
index 0a511101f0..a006a20ce3 100644
--- a/src/python/grpcio/tests/unit/_cython/cygrpc_test.py
+++ b/src/python/grpcio/tests/unit/_cython/cygrpc_test.py
@@ -37,7 +37,7 @@ from tests.unit import test_common
from tests.unit import resources
-_SSL_HOST_OVERRIDE = 'foo.test.google.fr'
+_SSL_HOST_OVERRIDE = b'foo.test.google.fr'
_CALL_CREDENTIALS_METADATA_KEY = 'call-creds-key'
_CALL_CREDENTIALS_METADATA_VALUE = 'call-creds-value'
_EMPTY_FLAGS = 0
@@ -143,22 +143,60 @@ class TypeSmokeTest(unittest.TestCase):
del completion_queue
-class InsecureServerInsecureClient(unittest.TestCase):
+class ServerClientMixin(object):
- def setUp(self):
+ def setUpMixin(self, server_credentials, client_credentials, host_override):
self.server_completion_queue = cygrpc.CompletionQueue()
self.server = cygrpc.Server()
self.server.register_completion_queue(self.server_completion_queue)
- self.port = self.server.add_http2_port('[::]:0')
+ if server_credentials:
+ self.port = self.server.add_http2_port('[::]:0', server_credentials)
+ else:
+ self.port = self.server.add_http2_port('[::]:0')
self.server.start()
self.client_completion_queue = cygrpc.CompletionQueue()
- self.client_channel = cygrpc.Channel('localhost:{}'.format(self.port))
-
- def tearDown(self):
+ if client_credentials:
+ client_channel_arguments = cygrpc.ChannelArgs([
+ cygrpc.ChannelArg(cygrpc.ChannelArgKey.ssl_target_name_override,
+ host_override)])
+ self.client_channel = cygrpc.Channel(
+ 'localhost:{}'.format(self.port), client_channel_arguments,
+ client_credentials)
+ else:
+ self.client_channel = cygrpc.Channel('localhost:{}'.format(self.port))
+ if host_override:
+ self.host_argument = None # default host
+ self.expected_host = host_override
+ else:
+ # arbitrary host name necessitating no further identification
+ self.host_argument = b'hostess'
+ self.expected_host = self.host_argument
+
+ def tearDownMixin(self):
del self.server
del self.client_completion_queue
del self.server_completion_queue
+ def _perform_operations(self, operations, call, queue, deadline, description):
+ """Perform the list of operations with given call, queue, and deadline.
+
+ Invocation errors are reported with as an exception with `description` in
+ the message. Performs the operations asynchronously, returning a future.
+ """
+ def performer():
+ tag = object()
+ try:
+ call_result = call.start_batch(cygrpc.Operations(operations), tag)
+ self.assertEqual(cygrpc.CallError.ok, call_result)
+ event = queue.poll(deadline)
+ self.assertEqual(cygrpc.CompletionType.operation_complete, event.type)
+ self.assertTrue(event.success)
+ self.assertIs(tag, event.tag)
+ except Exception as error:
+ raise Exception("Error in '{}': {}".format(description, error.message))
+ return event
+ return test_utilities.SimpleFuture(performer)
+
def testEcho(self):
DEADLINE = time.time()+5
DEADLINE_TOLERANCE = 0.25
@@ -175,7 +213,6 @@ class InsecureServerInsecureClient(unittest.TestCase):
REQUEST = b'in death a member of project mayhem has a name'
RESPONSE = b'his name is robert paulson'
METHOD = b'twinkies'
- HOST = b'hostess'
cygrpc_deadline = cygrpc.Timespec(DEADLINE)
@@ -188,7 +225,8 @@ class InsecureServerInsecureClient(unittest.TestCase):
client_call_tag = object()
client_call = self.client_channel.create_call(
- None, 0, self.client_completion_queue, METHOD, HOST, cygrpc_deadline)
+ None, 0, self.client_completion_queue, METHOD, self.host_argument,
+ cygrpc_deadline)
client_initial_metadata = cygrpc.Metadata([
cygrpc.Metadatum(CLIENT_METADATA_ASCII_KEY,
CLIENT_METADATA_ASCII_VALUE),
@@ -216,7 +254,8 @@ class InsecureServerInsecureClient(unittest.TestCase):
test_common.metadata_transmitted(client_initial_metadata,
request_event.request_metadata))
self.assertEqual(METHOD, request_event.request_call_details.method)
- self.assertEqual(HOST, request_event.request_call_details.host)
+ self.assertEqual(self.expected_host,
+ request_event.request_call_details.host)
self.assertLess(
abs(DEADLINE - float(request_event.request_call_details.deadline)),
DEADLINE_TOLERANCE)
@@ -292,172 +331,101 @@ class InsecureServerInsecureClient(unittest.TestCase):
del client_call
del server_call
-
-class SecureServerSecureClient(unittest.TestCase):
-
- def setUp(self):
- server_credentials = cygrpc.server_credentials_ssl(
- None, [cygrpc.SslPemKeyCertPair(resources.private_key(),
- resources.certificate_chain())], False)
- channel_credentials = cygrpc.channel_credentials_ssl(
- resources.test_root_certificates(), None)
- self.server_completion_queue = cygrpc.CompletionQueue()
- self.server = cygrpc.Server()
- self.server.register_completion_queue(self.server_completion_queue)
- self.port = self.server.add_http2_port('[::]:0', server_credentials)
- self.server.start()
- self.client_completion_queue = cygrpc.CompletionQueue()
- client_channel_arguments = cygrpc.ChannelArgs([
- cygrpc.ChannelArg(cygrpc.ChannelArgKey.ssl_target_name_override,
- _SSL_HOST_OVERRIDE)])
- self.client_channel = cygrpc.Channel(
- 'localhost:{}'.format(self.port), client_channel_arguments,
- channel_credentials)
-
- def tearDown(self):
- del self.server
- del self.client_completion_queue
- del self.server_completion_queue
-
- def testEcho(self):
+ def test6522(self):
DEADLINE = time.time()+5
DEADLINE_TOLERANCE = 0.25
- CLIENT_METADATA_ASCII_KEY = b'key'
- CLIENT_METADATA_ASCII_VALUE = b'val'
- CLIENT_METADATA_BIN_KEY = b'key-bin'
- CLIENT_METADATA_BIN_VALUE = b'\0'*1000
- SERVER_INITIAL_METADATA_KEY = b'init_me_me_me'
- SERVER_INITIAL_METADATA_VALUE = b'whodawha?'
- SERVER_TRAILING_METADATA_KEY = b'california_is_in_a_drought'
- SERVER_TRAILING_METADATA_VALUE = b'zomg it is'
- SERVER_STATUS_CODE = cygrpc.StatusCode.ok
- SERVER_STATUS_DETAILS = b'our work is never over'
- REQUEST = b'in death a member of project mayhem has a name'
- RESPONSE = b'his name is robert paulson'
- METHOD = b'/twinkies'
- HOST = None # Default host
+ METHOD = b'twinkies'
cygrpc_deadline = cygrpc.Timespec(DEADLINE)
+ empty_metadata = cygrpc.Metadata([])
server_request_tag = object()
- request_call_result = self.server.request_call(
+ self.server.request_call(
self.server_completion_queue, self.server_completion_queue,
server_request_tag)
+ client_call = self.client_channel.create_call(
+ None, 0, self.client_completion_queue, METHOD, self.host_argument,
+ cygrpc_deadline)
- self.assertEqual(cygrpc.CallError.ok, request_call_result)
-
- plugin = cygrpc.CredentialsMetadataPlugin(_metadata_plugin_callback, '')
- call_credentials = cygrpc.call_credentials_metadata_plugin(plugin)
+ # Prologue
+ def perform_client_operations(operations, description):
+ return self._perform_operations(
+ operations, client_call,
+ self.client_completion_queue, cygrpc_deadline, description)
- client_call_tag = object()
- client_call = self.client_channel.create_call(
- None, 0, self.client_completion_queue, METHOD, HOST, cygrpc_deadline)
- client_call.set_credentials(call_credentials)
- client_initial_metadata = cygrpc.Metadata([
- cygrpc.Metadatum(CLIENT_METADATA_ASCII_KEY,
- CLIENT_METADATA_ASCII_VALUE),
- cygrpc.Metadatum(CLIENT_METADATA_BIN_KEY, CLIENT_METADATA_BIN_VALUE)])
- client_start_batch_result = client_call.start_batch(cygrpc.Operations([
- cygrpc.operation_send_initial_metadata(client_initial_metadata,
- _EMPTY_FLAGS),
- cygrpc.operation_send_message(REQUEST, _EMPTY_FLAGS),
- cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
- cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
- cygrpc.operation_receive_message(_EMPTY_FLAGS),
- cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS)
- ]), client_call_tag)
- self.assertEqual(cygrpc.CallError.ok, client_start_batch_result)
- client_event_future = test_utilities.CompletionQueuePollFuture(
- self.client_completion_queue, cygrpc_deadline)
+ client_event_future = perform_client_operations([
+ cygrpc.operation_send_initial_metadata(empty_metadata,
+ _EMPTY_FLAGS),
+ cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
+ ], "Client prologue")
request_event = self.server_completion_queue.poll(cygrpc_deadline)
- self.assertEqual(cygrpc.CompletionType.operation_complete,
- request_event.type)
- self.assertIsInstance(request_event.operation_call, cygrpc.Call)
- self.assertIs(server_request_tag, request_event.tag)
- self.assertEqual(0, len(request_event.batch_operations))
- client_metadata_with_credentials = list(client_initial_metadata) + [
- (_CALL_CREDENTIALS_METADATA_KEY, _CALL_CREDENTIALS_METADATA_VALUE)]
- self.assertTrue(
- test_common.metadata_transmitted(client_metadata_with_credentials,
- request_event.request_metadata))
- self.assertEqual(METHOD, request_event.request_call_details.method)
- self.assertEqual(_SSL_HOST_OVERRIDE,
- request_event.request_call_details.host)
- self.assertLess(
- abs(DEADLINE - float(request_event.request_call_details.deadline)),
- DEADLINE_TOLERANCE)
-
- server_call_tag = object()
server_call = request_event.operation_call
- server_initial_metadata = cygrpc.Metadata([
- cygrpc.Metadatum(SERVER_INITIAL_METADATA_KEY,
- SERVER_INITIAL_METADATA_VALUE)])
- server_trailing_metadata = cygrpc.Metadata([
- cygrpc.Metadatum(SERVER_TRAILING_METADATA_KEY,
- SERVER_TRAILING_METADATA_VALUE)])
- server_start_batch_result = server_call.start_batch([
- cygrpc.operation_send_initial_metadata(server_initial_metadata,
- _EMPTY_FLAGS),
- cygrpc.operation_receive_message(_EMPTY_FLAGS),
- cygrpc.operation_send_message(RESPONSE, _EMPTY_FLAGS),
- cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),
- cygrpc.operation_send_status_from_server(
- server_trailing_metadata, SERVER_STATUS_CODE,
- SERVER_STATUS_DETAILS, _EMPTY_FLAGS)
- ], server_call_tag)
- self.assertEqual(cygrpc.CallError.ok, server_start_batch_result)
- client_event = client_event_future.result()
- server_event = self.server_completion_queue.poll(cygrpc_deadline)
+ def perform_server_operations(operations, description):
+ return self._perform_operations(
+ operations, server_call,
+ self.server_completion_queue, cygrpc_deadline, description)
- self.assertEqual(6, len(client_event.batch_operations))
- found_client_op_types = set()
- for client_result in client_event.batch_operations:
- # we expect each op type to be unique
- self.assertNotIn(client_result.type, found_client_op_types)
- found_client_op_types.add(client_result.type)
- if client_result.type == cygrpc.OperationType.receive_initial_metadata:
- self.assertTrue(
- test_common.metadata_transmitted(server_initial_metadata,
- client_result.received_metadata))
- elif client_result.type == cygrpc.OperationType.receive_message:
- self.assertEqual(RESPONSE, client_result.received_message.bytes())
- elif client_result.type == cygrpc.OperationType.receive_status_on_client:
- self.assertTrue(
- test_common.metadata_transmitted(server_trailing_metadata,
- client_result.received_metadata))
- self.assertEqual(SERVER_STATUS_DETAILS,
- client_result.received_status_details)
- self.assertEqual(SERVER_STATUS_CODE, client_result.received_status_code)
- self.assertEqual(set([
- cygrpc.OperationType.send_initial_metadata,
- cygrpc.OperationType.send_message,
- cygrpc.OperationType.send_close_from_client,
- cygrpc.OperationType.receive_initial_metadata,
- cygrpc.OperationType.receive_message,
- cygrpc.OperationType.receive_status_on_client
- ]), found_client_op_types)
+ server_event_future = perform_server_operations([
+ cygrpc.operation_send_initial_metadata(empty_metadata,
+ _EMPTY_FLAGS),
+ ], "Server prologue")
- self.assertEqual(5, len(server_event.batch_operations))
- found_server_op_types = set()
- for server_result in server_event.batch_operations:
- self.assertNotIn(client_result.type, found_server_op_types)
- found_server_op_types.add(server_result.type)
- if server_result.type == cygrpc.OperationType.receive_message:
- self.assertEqual(REQUEST, server_result.received_message.bytes())
- elif server_result.type == cygrpc.OperationType.receive_close_on_server:
- self.assertFalse(server_result.received_cancelled)
- self.assertEqual(set([
- cygrpc.OperationType.send_initial_metadata,
- cygrpc.OperationType.receive_message,
- cygrpc.OperationType.send_message,
- cygrpc.OperationType.receive_close_on_server,
- cygrpc.OperationType.send_status_from_server
- ]), found_server_op_types)
+ client_event_future.result() # force completion
+ server_event_future.result()
- del client_call
- del server_call
+ # Messaging
+ for _ in range(10):
+ client_event_future = perform_client_operations([
+ cygrpc.operation_send_message(b'', _EMPTY_FLAGS),
+ cygrpc.operation_receive_message(_EMPTY_FLAGS),
+ ], "Client message")
+ server_event_future = perform_server_operations([
+ cygrpc.operation_send_message(b'', _EMPTY_FLAGS),
+ cygrpc.operation_receive_message(_EMPTY_FLAGS),
+ ], "Server receive")
+
+ client_event_future.result() # force completion
+ server_event_future.result()
+
+ # Epilogue
+ client_event_future = perform_client_operations([
+ cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
+ cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS)
+ ], "Client epilogue")
+
+ server_event_future = perform_server_operations([
+ cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),
+ cygrpc.operation_send_status_from_server(
+ empty_metadata, cygrpc.StatusCode.ok, b'', _EMPTY_FLAGS)
+ ], "Server epilogue")
+
+ client_event_future.result() # force completion
+ server_event_future.result()
+
+
+class InsecureServerInsecureClient(unittest.TestCase, ServerClientMixin):
+
+ def setUp(self):
+ self.setUpMixin(None, None, None)
+
+ def tearDown(self):
+ self.tearDownMixin()
+
+
+class SecureServerSecureClient(unittest.TestCase, ServerClientMixin):
+
+ def setUp(self):
+ server_credentials = cygrpc.server_credentials_ssl(
+ None, [cygrpc.SslPemKeyCertPair(resources.private_key(),
+ resources.certificate_chain())], False)
+ client_credentials = cygrpc.channel_credentials_ssl(
+ resources.test_root_certificates(), None)
+ self.setUpMixin(server_credentials, client_credentials, _SSL_HOST_OVERRIDE)
+
+ def tearDown(self):
+ self.tearDownMixin()
if __name__ == '__main__':
diff --git a/src/python/grpcio/tests/unit/_cython/test_utilities.py b/src/python/grpcio/tests/unit/_cython/test_utilities.py
index 6a09739643..6280ce74c4 100644
--- a/src/python/grpcio/tests/unit/_cython/test_utilities.py
+++ b/src/python/grpcio/tests/unit/_cython/test_utilities.py
@@ -32,15 +32,35 @@ import threading
from grpc._cython import cygrpc
-class CompletionQueuePollFuture:
+class SimpleFuture(object):
+ """A simple future mechanism."""
- def __init__(self, completion_queue, deadline):
- def poller_function():
- self._event_result = completion_queue.poll(deadline)
- self._event_result = None
- self._thread = threading.Thread(target=poller_function)
+ def __init__(self, function, *args, **kwargs):
+ def wrapped_function():
+ try:
+ self._result = function(*args, **kwargs)
+ except Exception as error:
+ self._error = error
+ self._result = None
+ self._error = None
+ self._thread = threading.Thread(target=wrapped_function)
self._thread.start()
def result(self):
+ """The resulting value of this future.
+
+ Re-raises any exceptions.
+ """
self._thread.join()
- return self._event_result
+ if self._error:
+ # TODO(atash): re-raise exceptions in a way that preserves tracebacks
+ raise self._error
+ return self._result
+
+
+class CompletionQueuePollFuture(SimpleFuture):
+
+ def __init__(self, completion_queue, deadline):
+ super(CompletionQueuePollFuture, self).__init__(
+ lambda: completion_queue.poll(deadline))
+
diff --git a/src/python/grpcio/tests/unit/_empty_message_test.py b/src/python/grpcio/tests/unit/_empty_message_test.py
new file mode 100644
index 0000000000..f324f6216b
--- /dev/null
+++ b/src/python/grpcio/tests/unit/_empty_message_test.py
@@ -0,0 +1,137 @@
+# Copyright 2016, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import unittest
+
+import grpc
+from grpc.framework.foundation import logging_pool
+
+from tests.unit.framework.common import test_constants
+
+_REQUEST = b''
+_RESPONSE = b''
+
+_UNARY_UNARY = b'/test/UnaryUnary'
+_UNARY_STREAM = b'/test/UnaryStream'
+_STREAM_UNARY = b'/test/StreamUnary'
+_STREAM_STREAM = b'/test/StreamStream'
+
+
+def handle_unary_unary(request, servicer_context):
+ return _RESPONSE
+
+
+def handle_unary_stream(request, servicer_context):
+ for _ in range(test_constants.STREAM_LENGTH):
+ yield _RESPONSE
+
+
+def handle_stream_unary(request_iterator, servicer_context):
+ for request in request_iterator:
+ pass
+ return _RESPONSE
+
+
+def handle_stream_stream(request_iterator, servicer_context):
+ for request in request_iterator:
+ yield _RESPONSE
+
+
+class _MethodHandler(grpc.RpcMethodHandler):
+
+ def __init__(self, request_streaming, response_streaming):
+ self.request_streaming = request_streaming
+ self.response_streaming = response_streaming
+ self.request_deserializer = None
+ self.response_serializer = None
+ self.unary_unary = None
+ self.unary_stream = None
+ self.stream_unary = None
+ self.stream_stream = None
+ if self.request_streaming and self.response_streaming:
+ self.stream_stream = handle_stream_stream
+ elif self.request_streaming:
+ self.stream_unary = handle_stream_unary
+ elif self.response_streaming:
+ self.unary_stream = handle_unary_stream
+ else:
+ self.unary_unary = handle_unary_unary
+
+
+class _GenericHandler(grpc.GenericRpcHandler):
+
+ def service(self, handler_call_details):
+ if handler_call_details.method == _UNARY_UNARY:
+ return _MethodHandler(False, False)
+ elif handler_call_details.method == _UNARY_STREAM:
+ return _MethodHandler(False, True)
+ elif handler_call_details.method == _STREAM_UNARY:
+ return _MethodHandler(True, False)
+ elif handler_call_details.method == _STREAM_STREAM:
+ return _MethodHandler(True, True)
+ else:
+ return None
+
+
+class EmptyMessageTest(unittest.TestCase):
+
+ def setUp(self):
+ self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
+ self._server = grpc.server((_GenericHandler(),), self._server_pool)
+ port = self._server.add_insecure_port('[::]:0')
+ self._server.start()
+ self._channel = grpc.insecure_channel('localhost:%d' % port)
+
+ def tearDown(self):
+ self._server.stop(0)
+
+ def testUnaryUnary(self):
+ response = self._channel.unary_unary(_UNARY_UNARY)(_REQUEST)
+ self.assertEqual(_RESPONSE, response)
+
+ def testUnaryStream(self):
+ response_iterator = self._channel.unary_stream(_UNARY_STREAM)(_REQUEST)
+ self.assertSequenceEqual(
+ [_RESPONSE] * test_constants.STREAM_LENGTH, list(response_iterator))
+
+ def testStreamUnary(self):
+ response = self._channel.stream_unary(_STREAM_UNARY)(
+ [_REQUEST] * test_constants.STREAM_LENGTH)
+ self.assertEqual(_RESPONSE, response)
+
+ def testStreamStream(self):
+ response_iterator = self._channel.stream_stream(_STREAM_STREAM)(
+ [_REQUEST] * test_constants.STREAM_LENGTH)
+ self.assertSequenceEqual(
+ [_RESPONSE] * test_constants.STREAM_LENGTH, list(response_iterator))
+
+
+if __name__ == '__main__':
+ unittest.main(verbosity=2)
+
diff --git a/src/python/grpcio/grpc/_adapter/_implementations.py b/src/python/grpcio/tests/unit/_from_grpc_import_star.py
index b85f228bf6..78d2fb7dc5 100644
--- a/src/python/grpcio/grpc/_adapter/_implementations.py
+++ b/src/python/grpcio/tests/unit/_from_grpc_import_star.py
@@ -1,4 +1,4 @@
-# Copyright 2015, Google Inc.
+# Copyright 2016, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
@@ -27,22 +27,12 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-import collections
+_BEFORE_IMPORT = tuple(globals())
-from grpc.beta import interfaces
+from grpc import *
-class AuthMetadataContext(collections.namedtuple(
- 'AuthMetadataContext', [
- 'service_url',
- 'method_name'
- ]), interfaces.GRPCAuthMetadataContext):
- pass
+_AFTER_IMPORT = tuple(globals())
-
-class AuthMetadataPluginCallback(interfaces.GRPCAuthMetadataContext):
-
- def __init__(self, callback):
- self._callback = callback
-
- def __call__(self, metadata, error):
- self._callback(metadata, error)
+GRPC_ELEMENTS = tuple(
+ element for element in _AFTER_IMPORT
+ if element not in _BEFORE_IMPORT and element != '_BEFORE_IMPORT')
diff --git a/src/python/grpcio/tests/unit/_links/_lonely_invocation_link_test.py b/src/python/grpcio/tests/unit/_links/_lonely_invocation_link_test.py
deleted file mode 100644
index 890755f81c..0000000000
--- a/src/python/grpcio/tests/unit/_links/_lonely_invocation_link_test.py
+++ /dev/null
@@ -1,88 +0,0 @@
-# Copyright 2015, Google Inc.
-# All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""A test of invocation-side code unconnected to an RPC server."""
-
-import unittest
-
-from grpc._adapter import _intermediary_low
-from grpc._links import invocation
-from grpc.framework.interfaces.links import links
-from tests.unit.framework.common import test_constants
-from tests.unit.framework.interfaces.links import test_cases
-from tests.unit.framework.interfaces.links import test_utilities
-
-_NULL_BEHAVIOR = lambda unused_argument: None
-
-
-class LonelyInvocationLinkTest(unittest.TestCase):
-
- def testUpAndDown(self):
- channel = _intermediary_low.Channel('nonexistent:54321', None)
- invocation_link = invocation.invocation_link(
- channel, 'nonexistent', None, {}, {})
-
- invocation_link.start()
- invocation_link.stop()
-
- def _test_lonely_invocation_with_termination(self, termination):
- test_operation_id = object()
- test_group = 'test package.Test Service'
- test_method = 'test method'
- invocation_link_mate = test_utilities.RecordingLink()
-
- channel = _intermediary_low.Channel('nonexistent:54321', None)
- invocation_link = invocation.invocation_link(
- channel, 'nonexistent', None, {}, {})
- invocation_link.join_link(invocation_link_mate)
- invocation_link.start()
-
- ticket = links.Ticket(
- test_operation_id, 0, test_group, test_method,
- links.Ticket.Subscription.FULL, test_constants.SHORT_TIMEOUT, 1, None,
- None, None, None, None, termination, None)
- invocation_link.accept_ticket(ticket)
- invocation_link_mate.block_until_tickets_satisfy(test_cases.terminated)
-
- invocation_link.stop()
-
- self.assertIsNot(
- invocation_link_mate.tickets()[-1].termination,
- links.Ticket.Termination.COMPLETION)
-
- def testLonelyInvocationLinkWithCommencementTicket(self):
- self._test_lonely_invocation_with_termination(None)
-
- def testLonelyInvocationLinkWithEntireTicket(self):
- self._test_lonely_invocation_with_termination(
- links.Ticket.Termination.COMPLETION)
-
-
-if __name__ == '__main__':
- unittest.main()
diff --git a/src/python/grpcio/tests/unit/_links/_transmission_test.py b/src/python/grpcio/tests/unit/_links/_transmission_test.py
deleted file mode 100644
index 888684d197..0000000000
--- a/src/python/grpcio/tests/unit/_links/_transmission_test.py
+++ /dev/null
@@ -1,239 +0,0 @@
-# Copyright 2015, Google Inc.
-# All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Tests transmission of tickets across gRPC-on-the-wire."""
-
-import unittest
-
-from grpc._adapter import _intermediary_low
-from grpc._links import invocation
-from grpc._links import service
-from grpc.beta import interfaces as beta_interfaces
-from grpc.framework.interfaces.links import links
-from tests.unit import test_common
-from tests.unit._links import _proto_scenarios
-from tests.unit.framework.common import test_constants
-from tests.unit.framework.interfaces.links import test_cases
-from tests.unit.framework.interfaces.links import test_utilities
-
-_IDENTITY = lambda x: x
-
-
-class TransmissionTest(test_cases.TransmissionTest, unittest.TestCase):
-
- def create_transmitting_links(self):
- service_link = service.service_link(
- {self.group_and_method(): self.deserialize_request},
- {self.group_and_method(): self.serialize_response})
- port = service_link.add_port('[::]:0', None)
- service_link.start()
- channel = _intermediary_low.Channel('localhost:%d' % port, None)
- invocation_link = invocation.invocation_link(
- channel, 'localhost', None,
- {self.group_and_method(): self.serialize_request},
- {self.group_and_method(): self.deserialize_response})
- invocation_link.start()
- return invocation_link, service_link
-
- def destroy_transmitting_links(self, invocation_side_link, service_side_link):
- invocation_side_link.stop()
- service_side_link.begin_stop()
- service_side_link.end_stop()
-
- def create_invocation_initial_metadata(self):
- return (
- ('first_invocation_initial_metadata_key', 'just a string value'),
- ('second_invocation_initial_metadata_key', '0123456789'),
- ('third_invocation_initial_metadata_key-bin', '\x00\x57' * 100),
- )
-
- def create_invocation_terminal_metadata(self):
- return None
-
- def create_service_initial_metadata(self):
- return (
- ('first_service_initial_metadata_key', 'just another string value'),
- ('second_service_initial_metadata_key', '9876543210'),
- ('third_service_initial_metadata_key-bin', '\x00\x59\x02' * 100),
- )
-
- def create_service_terminal_metadata(self):
- return (
- ('first_service_terminal_metadata_key', 'yet another string value'),
- ('second_service_terminal_metadata_key', 'abcdefghij'),
- ('third_service_terminal_metadata_key-bin', '\x00\x37' * 100),
- )
-
- def create_invocation_completion(self):
- return None, None
-
- def create_service_completion(self):
- return (
- beta_interfaces.StatusCode.OK, b'An exuberant test "details" message!')
-
- def assertMetadataTransmitted(self, original_metadata, transmitted_metadata):
- self.assertTrue(
- test_common.metadata_transmitted(
- original_metadata, transmitted_metadata),
- '%s erroneously transmitted as %s' % (
- original_metadata, transmitted_metadata))
-
-
-class RoundTripTest(unittest.TestCase):
-
- def testZeroMessageRoundTrip(self):
- test_operation_id = object()
- test_group = 'test package.Test Group'
- test_method = 'test method'
- identity_transformation = {(test_group, test_method): _IDENTITY}
- test_code = beta_interfaces.StatusCode.OK
- test_message = 'a test message'
-
- service_link = service.service_link(
- identity_transformation, identity_transformation)
- service_mate = test_utilities.RecordingLink()
- service_link.join_link(service_mate)
- port = service_link.add_port('[::]:0', None)
- service_link.start()
- channel = _intermediary_low.Channel('localhost:%d' % port, None)
- invocation_link = invocation.invocation_link(
- channel, None, None, identity_transformation, identity_transformation)
- invocation_mate = test_utilities.RecordingLink()
- invocation_link.join_link(invocation_mate)
- invocation_link.start()
-
- invocation_ticket = links.Ticket(
- test_operation_id, 0, test_group, test_method,
- links.Ticket.Subscription.FULL, test_constants.LONG_TIMEOUT, None, None,
- None, None, None, None, links.Ticket.Termination.COMPLETION, None)
- invocation_link.accept_ticket(invocation_ticket)
- service_mate.block_until_tickets_satisfy(test_cases.terminated)
-
- service_ticket = links.Ticket(
- service_mate.tickets()[-1].operation_id, 0, None, None, None, None,
- None, None, None, None, test_code, test_message,
- links.Ticket.Termination.COMPLETION, None)
- service_link.accept_ticket(service_ticket)
- invocation_mate.block_until_tickets_satisfy(test_cases.terminated)
-
- invocation_link.stop()
- service_link.begin_stop()
- service_link.end_stop()
-
- self.assertIs(
- service_mate.tickets()[-1].termination,
- links.Ticket.Termination.COMPLETION)
- self.assertIs(
- invocation_mate.tickets()[-1].termination,
- links.Ticket.Termination.COMPLETION)
- self.assertIs(invocation_mate.tickets()[-1].code, test_code)
- self.assertEqual(invocation_mate.tickets()[-1].message, test_message)
-
- def _perform_scenario_test(self, scenario):
- test_operation_id = object()
- test_group, test_method = scenario.group_and_method()
- test_code = beta_interfaces.StatusCode.OK
- test_message = 'a scenario test message'
-
- service_link = service.service_link(
- {(test_group, test_method): scenario.deserialize_request},
- {(test_group, test_method): scenario.serialize_response})
- service_mate = test_utilities.RecordingLink()
- service_link.join_link(service_mate)
- port = service_link.add_port('[::]:0', None)
- service_link.start()
- channel = _intermediary_low.Channel('localhost:%d' % port, None)
- invocation_link = invocation.invocation_link(
- channel, 'localhost', None,
- {(test_group, test_method): scenario.serialize_request},
- {(test_group, test_method): scenario.deserialize_response})
- invocation_mate = test_utilities.RecordingLink()
- invocation_link.join_link(invocation_mate)
- invocation_link.start()
-
- invocation_ticket = links.Ticket(
- test_operation_id, 0, test_group, test_method,
- links.Ticket.Subscription.FULL, test_constants.LONG_TIMEOUT, None, None,
- None, None, None, None, None, None)
- invocation_link.accept_ticket(invocation_ticket)
- requests = scenario.requests()
- for request_index, request in enumerate(requests):
- request_ticket = links.Ticket(
- test_operation_id, 1 + request_index, None, None, None, None, 1, None,
- request, None, None, None, None, None)
- invocation_link.accept_ticket(request_ticket)
- service_mate.block_until_tickets_satisfy(
- test_cases.at_least_n_payloads_received_predicate(1 + request_index))
- response_ticket = links.Ticket(
- service_mate.tickets()[0].operation_id, request_index, None, None,
- None, None, 1, None, scenario.response_for_request(request), None,
- None, None, None, None)
- service_link.accept_ticket(response_ticket)
- invocation_mate.block_until_tickets_satisfy(
- test_cases.at_least_n_payloads_received_predicate(1 + request_index))
- request_count = len(requests)
- invocation_completion_ticket = links.Ticket(
- test_operation_id, request_count + 1, None, None, None, None, None,
- None, None, None, None, None, links.Ticket.Termination.COMPLETION,
- None)
- invocation_link.accept_ticket(invocation_completion_ticket)
- service_mate.block_until_tickets_satisfy(test_cases.terminated)
- service_completion_ticket = links.Ticket(
- service_mate.tickets()[0].operation_id, request_count, None, None, None,
- None, None, None, None, None, test_code, test_message,
- links.Ticket.Termination.COMPLETION, None)
- service_link.accept_ticket(service_completion_ticket)
- invocation_mate.block_until_tickets_satisfy(test_cases.terminated)
-
- invocation_link.stop()
- service_link.begin_stop()
- service_link.end_stop()
-
- observed_requests = tuple(
- ticket.payload for ticket in service_mate.tickets()
- if ticket.payload is not None)
- observed_responses = tuple(
- ticket.payload for ticket in invocation_mate.tickets()
- if ticket.payload is not None)
- self.assertTrue(scenario.verify_requests(observed_requests))
- self.assertTrue(scenario.verify_responses(observed_responses))
-
- def testEmptyScenario(self):
- self._perform_scenario_test(_proto_scenarios.EmptyScenario())
-
- def testBidirectionallyUnaryScenario(self):
- self._perform_scenario_test(_proto_scenarios.BidirectionallyUnaryScenario())
-
- def testBidirectionallyStreamingScenario(self):
- self._perform_scenario_test(
- _proto_scenarios.BidirectionallyStreamingScenario())
-
-
-if __name__ == '__main__':
- unittest.main(verbosity=2)
diff --git a/src/python/grpcio/tests/unit/_metadata_code_details_test.py b/src/python/grpcio/tests/unit/_metadata_code_details_test.py
new file mode 100644
index 0000000000..dd74268cbf
--- /dev/null
+++ b/src/python/grpcio/tests/unit/_metadata_code_details_test.py
@@ -0,0 +1,523 @@
+# Copyright 2016, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Tests application-provided metadata, status code, and details."""
+
+import threading
+import unittest
+
+import grpc
+from grpc.framework.foundation import logging_pool
+
+from tests.unit import test_common
+from tests.unit.framework.common import test_constants
+from tests.unit.framework.common import test_control
+
+_SERIALIZED_REQUEST = b'\x46\x47\x48'
+_SERIALIZED_RESPONSE = b'\x49\x50\x51'
+
+_REQUEST_SERIALIZER = lambda unused_request: _SERIALIZED_REQUEST
+_REQUEST_DESERIALIZER = lambda unused_serialized_request: object()
+_RESPONSE_SERIALIZER = lambda unused_response: _SERIALIZED_RESPONSE
+_RESPONSE_DESERIALIZER = lambda unused_serialized_resopnse: object()
+
+_SERVICE = b'test.TestService'
+_UNARY_UNARY = b'UnaryUnary'
+_UNARY_STREAM = b'UnaryStream'
+_STREAM_UNARY = b'StreamUnary'
+_STREAM_STREAM = b'StreamStream'
+
+_CLIENT_METADATA = (
+ (b'client-md-key', b'client-md-key'),
+ (b'client-md-key-bin', b'\x00\x01')
+)
+
+_SERVER_INITIAL_METADATA = (
+ (b'server-initial-md-key', b'server-initial-md-value'),
+ (b'server-initial-md-key-bin', b'\x00\x02')
+)
+
+_SERVER_TRAILING_METADATA = (
+ (b'server-trailing-md-key', b'server-trailing-md-value'),
+ (b'server-trailing-md-key-bin', b'\x00\x03')
+)
+
+_NON_OK_CODE = grpc.StatusCode.NOT_FOUND
+_DETAILS = b'Test details!'
+
+
+class _Servicer(object):
+
+ def __init__(self):
+ self._lock = threading.Lock()
+ self._code = None
+ self._details = None
+ self._exception = False
+ self._return_none = False
+ self._received_client_metadata = None
+
+ def unary_unary(self, request, context):
+ with self._lock:
+ self._received_client_metadata = context.invocation_metadata()
+ context.send_initial_metadata(_SERVER_INITIAL_METADATA)
+ context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
+ if self._code is not None:
+ context.set_code(self._code)
+ if self._details is not None:
+ context.set_details(self._details)
+ if self._exception:
+ raise test_control.Defect()
+ else:
+ return None if self._return_none else object()
+
+ def unary_stream(self, request, context):
+ with self._lock:
+ self._received_client_metadata = context.invocation_metadata()
+ context.send_initial_metadata(_SERVER_INITIAL_METADATA)
+ context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
+ if self._code is not None:
+ context.set_code(self._code)
+ if self._details is not None:
+ context.set_details(self._details)
+ for _ in range(test_constants.STREAM_LENGTH // 2):
+ yield _SERIALIZED_RESPONSE
+ if self._exception:
+ raise test_control.Defect()
+
+ def stream_unary(self, request_iterator, context):
+ with self._lock:
+ self._received_client_metadata = context.invocation_metadata()
+ context.send_initial_metadata(_SERVER_INITIAL_METADATA)
+ context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
+ if self._code is not None:
+ context.set_code(self._code)
+ if self._details is not None:
+ context.set_details(self._details)
+ # TODO(https://github.com/grpc/grpc/issues/6891): just ignore the
+ # request iterator.
+ for ignored_request in request_iterator:
+ pass
+ if self._exception:
+ raise test_control.Defect()
+ else:
+ return None if self._return_none else _SERIALIZED_RESPONSE
+
+ def stream_stream(self, request_iterator, context):
+ with self._lock:
+ self._received_client_metadata = context.invocation_metadata()
+ context.send_initial_metadata(_SERVER_INITIAL_METADATA)
+ context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
+ if self._code is not None:
+ context.set_code(self._code)
+ if self._details is not None:
+ context.set_details(self._details)
+ # TODO(https://github.com/grpc/grpc/issues/6891): just ignore the
+ # request iterator.
+ for ignored_request in request_iterator:
+ pass
+ for _ in range(test_constants.STREAM_LENGTH // 3):
+ yield object()
+ if self._exception:
+ raise test_control.Defect()
+
+ def set_code(self, code):
+ with self._lock:
+ self._code = code
+
+ def set_details(self, details):
+ with self._lock:
+ self._details = details
+
+ def set_exception(self):
+ with self._lock:
+ self._exception = True
+
+ def set_return_none(self):
+ with self._lock:
+ self._return_none = True
+
+ def received_client_metadata(self):
+ with self._lock:
+ return self._received_client_metadata
+
+
+def _generic_handler(servicer):
+ method_handlers = {
+ _UNARY_UNARY: grpc.unary_unary_rpc_method_handler(
+ servicer.unary_unary, request_deserializer=_REQUEST_DESERIALIZER,
+ response_serializer=_RESPONSE_SERIALIZER),
+ _UNARY_STREAM: grpc.unary_stream_rpc_method_handler(
+ servicer.unary_stream),
+ _STREAM_UNARY: grpc.stream_unary_rpc_method_handler(
+ servicer.stream_unary),
+ _STREAM_STREAM: grpc.stream_stream_rpc_method_handler(
+ servicer.stream_stream, request_deserializer=_REQUEST_DESERIALIZER,
+ response_serializer=_RESPONSE_SERIALIZER),
+ }
+ return grpc.method_handlers_generic_handler(_SERVICE, method_handlers)
+
+
+class MetadataCodeDetailsTest(unittest.TestCase):
+
+ def setUp(self):
+ self._servicer = _Servicer()
+ self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
+ self._server = grpc.server(
+ (_generic_handler(self._servicer),), self._server_pool)
+ port = self._server.add_insecure_port('[::]:0')
+ self._server.start()
+
+ channel = grpc.insecure_channel('localhost:{}'.format(port))
+ self._unary_unary = channel.unary_unary(
+ b'/'.join((b'', _SERVICE, _UNARY_UNARY,)),
+ request_serializer=_REQUEST_SERIALIZER,
+ response_deserializer=_RESPONSE_DESERIALIZER,)
+ self._unary_stream = channel.unary_stream(
+ b'/'.join((b'', _SERVICE, _UNARY_STREAM,)),)
+ self._stream_unary = channel.stream_unary(
+ b'/'.join((b'', _SERVICE, _STREAM_UNARY,)),)
+ self._stream_stream = channel.stream_stream(
+ b'/'.join((b'', _SERVICE, _STREAM_STREAM,)),
+ request_serializer=_REQUEST_SERIALIZER,
+ response_deserializer=_RESPONSE_DESERIALIZER,)
+
+
+ def testSuccessfulUnaryUnary(self):
+ self._servicer.set_details(_DETAILS)
+
+ unused_response, call = self._unary_unary.with_call(
+ object(), metadata=_CLIENT_METADATA)
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA, call.initial_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA, call.trailing_metadata()))
+ self.assertIs(grpc.StatusCode.OK, call.code())
+ self.assertEqual(_DETAILS, call.details())
+
+ def testSuccessfulUnaryStream(self):
+ self._servicer.set_details(_DETAILS)
+
+ call = self._unary_stream(_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
+ received_initial_metadata = call.initial_metadata()
+ for _ in call:
+ pass
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA, received_initial_metadata))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA, call.trailing_metadata()))
+ self.assertIs(grpc.StatusCode.OK, call.code())
+ self.assertEqual(_DETAILS, call.details())
+
+ def testSuccessfulStreamUnary(self):
+ self._servicer.set_details(_DETAILS)
+
+ unused_response, call = self._stream_unary.with_call(
+ iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),
+ metadata=_CLIENT_METADATA)
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA, call.initial_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA, call.trailing_metadata()))
+ self.assertIs(grpc.StatusCode.OK, call.code())
+ self.assertEqual(_DETAILS, call.details())
+
+ def testSuccessfulStreamStream(self):
+ self._servicer.set_details(_DETAILS)
+
+ call = self._stream_stream(
+ iter([object()] * test_constants.STREAM_LENGTH),
+ metadata=_CLIENT_METADATA)
+ received_initial_metadata = call.initial_metadata()
+ for _ in call:
+ pass
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA, received_initial_metadata))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA, call.trailing_metadata()))
+ self.assertIs(grpc.StatusCode.OK, call.code())
+ self.assertEqual(_DETAILS, call.details())
+
+ def testCustomCodeUnaryUnary(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA)
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA,
+ exception_context.exception.initial_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA,
+ exception_context.exception.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, exception_context.exception.code())
+ self.assertEqual(_DETAILS, exception_context.exception.details())
+
+ def testCustomCodeUnaryStream(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+
+ call = self._unary_stream(_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
+ received_initial_metadata = call.initial_metadata()
+ with self.assertRaises(grpc.RpcError):
+ for _ in call:
+ pass
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA, received_initial_metadata))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA, call.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, call.code())
+ self.assertEqual(_DETAILS, call.details())
+
+ def testCustomCodeStreamUnary(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ self._stream_unary.with_call(
+ iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),
+ metadata=_CLIENT_METADATA)
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA,
+ exception_context.exception.initial_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA,
+ exception_context.exception.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, exception_context.exception.code())
+ self.assertEqual(_DETAILS, exception_context.exception.details())
+
+ def testCustomCodeStreamStream(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+
+ call = self._stream_stream(
+ iter([object()] * test_constants.STREAM_LENGTH),
+ metadata=_CLIENT_METADATA)
+ received_initial_metadata = call.initial_metadata()
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ for _ in call:
+ pass
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA, received_initial_metadata))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA,
+ exception_context.exception.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, exception_context.exception.code())
+ self.assertEqual(_DETAILS, exception_context.exception.details())
+
+ def testCustomCodeExceptionUnaryUnary(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+ self._servicer.set_exception()
+
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA)
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA,
+ exception_context.exception.initial_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA,
+ exception_context.exception.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, exception_context.exception.code())
+ self.assertEqual(_DETAILS, exception_context.exception.details())
+
+ def testCustomCodeExceptionUnaryStream(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+ self._servicer.set_exception()
+
+ call = self._unary_stream(_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
+ received_initial_metadata = call.initial_metadata()
+ with self.assertRaises(grpc.RpcError):
+ for _ in call:
+ pass
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA, received_initial_metadata))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA, call.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, call.code())
+ self.assertEqual(_DETAILS, call.details())
+
+ def testCustomCodeExceptionStreamUnary(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+ self._servicer.set_exception()
+
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ self._stream_unary.with_call(
+ iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),
+ metadata=_CLIENT_METADATA)
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA,
+ exception_context.exception.initial_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA,
+ exception_context.exception.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, exception_context.exception.code())
+ self.assertEqual(_DETAILS, exception_context.exception.details())
+
+ def testCustomCodeExceptionStreamStream(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+ self._servicer.set_exception()
+
+ call = self._stream_stream(
+ iter([object()] * test_constants.STREAM_LENGTH),
+ metadata=_CLIENT_METADATA)
+ received_initial_metadata = call.initial_metadata()
+ with self.assertRaises(grpc.RpcError):
+ for _ in call:
+ pass
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA, received_initial_metadata))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA, call.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, call.code())
+ self.assertEqual(_DETAILS, call.details())
+
+ def testCustomCodeReturnNoneUnaryUnary(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+ self._servicer.set_return_none()
+
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA)
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA,
+ exception_context.exception.initial_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA,
+ exception_context.exception.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, exception_context.exception.code())
+ self.assertEqual(_DETAILS, exception_context.exception.details())
+
+ def testCustomCodeReturnNoneStreamUnary(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+ self._servicer.set_return_none()
+
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ self._stream_unary.with_call(
+ iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),
+ metadata=_CLIENT_METADATA)
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA,
+ exception_context.exception.initial_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA,
+ exception_context.exception.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, exception_context.exception.code())
+ self.assertEqual(_DETAILS, exception_context.exception.details())
+
+
+if __name__ == '__main__':
+ unittest.main(verbosity=2)
diff --git a/src/python/grpcio/tests/unit/_metadata_test.py b/src/python/grpcio/tests/unit/_metadata_test.py
new file mode 100644
index 0000000000..2cb13f236b
--- /dev/null
+++ b/src/python/grpcio/tests/unit/_metadata_test.py
@@ -0,0 +1,216 @@
+# Copyright 2016, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""Tests server and client side metadata API."""
+
+import unittest
+import weakref
+
+import grpc
+from grpc import _grpcio_metadata
+from grpc.framework.foundation import logging_pool
+
+from tests.unit import test_common
+from tests.unit.framework.common import test_constants
+
+_CHANNEL_ARGS = (('grpc.primary_user_agent', 'primary-agent'),
+ ('grpc.secondary_user_agent', 'secondary-agent'))
+
+_REQUEST = b'\x00\x00\x00'
+_RESPONSE = b'\x00\x00\x00'
+
+_UNARY_UNARY = b'/test/UnaryUnary'
+_UNARY_STREAM = b'/test/UnaryStream'
+_STREAM_UNARY = b'/test/StreamUnary'
+_STREAM_STREAM = b'/test/StreamStream'
+
+_USER_AGENT = 'Python-gRPC-{}'.format(_grpcio_metadata.__version__)
+
+_CLIENT_METADATA = (
+ (b'client-md-key', b'client-md-key'),
+ (b'client-md-key-bin', b'\x00\x01')
+)
+
+_SERVER_INITIAL_METADATA = (
+ (b'server-initial-md-key', b'server-initial-md-value'),
+ (b'server-initial-md-key-bin', b'\x00\x02')
+)
+
+_SERVER_TRAILING_METADATA = (
+ (b'server-trailing-md-key', b'server-trailing-md-value'),
+ (b'server-trailing-md-key-bin', b'\x00\x03')
+)
+
+
+def user_agent(metadata):
+ for key, val in metadata:
+ if key == b'user-agent':
+ return val.decode('ascii')
+ raise KeyError('No user agent!')
+
+
+def validate_client_metadata(test, servicer_context):
+ test.assertTrue(test_common.metadata_transmitted(
+ _CLIENT_METADATA, servicer_context.invocation_metadata()))
+ test.assertTrue(user_agent(servicer_context.invocation_metadata())
+ .startswith('primary-agent ' + _USER_AGENT))
+ test.assertTrue(user_agent(servicer_context.invocation_metadata())
+ .endswith('secondary-agent'))
+
+
+def handle_unary_unary(test, request, servicer_context):
+ validate_client_metadata(test, servicer_context)
+ servicer_context.send_initial_metadata(_SERVER_INITIAL_METADATA)
+ servicer_context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
+ return _RESPONSE
+
+
+def handle_unary_stream(test, request, servicer_context):
+ validate_client_metadata(test, servicer_context)
+ servicer_context.send_initial_metadata(_SERVER_INITIAL_METADATA)
+ servicer_context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
+ for _ in range(test_constants.STREAM_LENGTH):
+ yield _RESPONSE
+
+
+def handle_stream_unary(test, request_iterator, servicer_context):
+ validate_client_metadata(test, servicer_context)
+ servicer_context.send_initial_metadata(_SERVER_INITIAL_METADATA)
+ servicer_context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
+ # TODO(issue:#6891) We should be able to remove this loop
+ for request in request_iterator:
+ pass
+ return _RESPONSE
+
+
+def handle_stream_stream(test, request_iterator, servicer_context):
+ validate_client_metadata(test, servicer_context)
+ servicer_context.send_initial_metadata(_SERVER_INITIAL_METADATA)
+ servicer_context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
+ # TODO(issue:#6891) We should be able to remove this loop,
+ # and replace with return; yield
+ for request in request_iterator:
+ yield _RESPONSE
+
+
+class _MethodHandler(grpc.RpcMethodHandler):
+
+ def __init__(self, test, request_streaming, response_streaming):
+ self.request_streaming = request_streaming
+ self.response_streaming = response_streaming
+ self.request_deserializer = None
+ self.response_serializer = None
+ self.unary_unary = None
+ self.unary_stream = None
+ self.stream_unary = None
+ self.stream_stream = None
+ if self.request_streaming and self.response_streaming:
+ self.stream_stream = lambda x, y: handle_stream_stream(test, x, y)
+ elif self.request_streaming:
+ self.stream_unary = lambda x, y: handle_stream_unary(test, x, y)
+ elif self.response_streaming:
+ self.unary_stream = lambda x, y: handle_unary_stream(test, x, y)
+ else:
+ self.unary_unary = lambda x, y: handle_unary_unary(test, x, y)
+
+
+class _GenericHandler(grpc.GenericRpcHandler):
+
+ def __init__(self, test):
+ self._test = test
+
+ def service(self, handler_call_details):
+ if handler_call_details.method == _UNARY_UNARY:
+ return _MethodHandler(self._test, False, False)
+ elif handler_call_details.method == _UNARY_STREAM:
+ return _MethodHandler(self._test, False, True)
+ elif handler_call_details.method == _STREAM_UNARY:
+ return _MethodHandler(self._test, True, False)
+ elif handler_call_details.method == _STREAM_STREAM:
+ return _MethodHandler(self._test, True, True)
+ else:
+ return None
+
+
+class MetadataTest(unittest.TestCase):
+
+ def setUp(self):
+ self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
+ self._server = grpc.server((_GenericHandler(weakref.proxy(self)),),
+ self._server_pool)
+ port = self._server.add_insecure_port('[::]:0')
+ self._server.start()
+ self._channel = grpc.insecure_channel('localhost:%d' % port,
+ options=_CHANNEL_ARGS)
+
+ def tearDown(self):
+ self._server.stop(0)
+
+ def testUnaryUnary(self):
+ multi_callable = self._channel.unary_unary(_UNARY_UNARY)
+ unused_response, call = multi_callable.with_call(
+ _REQUEST, metadata=_CLIENT_METADATA)
+ self.assertTrue(test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA, call.initial_metadata()))
+ self.assertTrue(test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA, call.trailing_metadata()))
+
+ def testUnaryStream(self):
+ multi_callable = self._channel.unary_stream(_UNARY_STREAM)
+ call = multi_callable(_REQUEST, metadata=_CLIENT_METADATA)
+ self.assertTrue(test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA, call.initial_metadata()))
+ for _ in call:
+ pass
+ self.assertTrue(test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA, call.trailing_metadata()))
+
+ def testStreamUnary(self):
+ multi_callable = self._channel.stream_unary(_STREAM_UNARY)
+ unused_response, call = multi_callable.with_call(
+ [_REQUEST] * test_constants.STREAM_LENGTH,
+ metadata=_CLIENT_METADATA)
+ self.assertTrue(test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA, call.initial_metadata()))
+ self.assertTrue(test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA, call.trailing_metadata()))
+
+ def testStreamStream(self):
+ multi_callable = self._channel.stream_stream(_STREAM_STREAM)
+ call = multi_callable([_REQUEST] * test_constants.STREAM_LENGTH,
+ metadata=_CLIENT_METADATA)
+ self.assertTrue(test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA, call.initial_metadata()))
+ for _ in call:
+ pass
+ self.assertTrue(test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA, call.trailing_metadata()))
+
+
+if __name__ == '__main__':
+ unittest.main(verbosity=2)
diff --git a/src/python/grpcio/tests/unit/_rpc_test.py b/src/python/grpcio/tests/unit/_rpc_test.py
index 1c7a14c5d0..9814504edf 100644
--- a/src/python/grpcio/tests/unit/_rpc_test.py
+++ b/src/python/grpcio/tests/unit/_rpc_test.py
@@ -27,7 +27,7 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-"""Test of gRPC Python's application-layer API."""
+"""Test of RPCs made against gRPC Python's application-layer API."""
import itertools
import threading
@@ -41,9 +41,9 @@ from tests.unit.framework.common import test_constants
from tests.unit.framework.common import test_control
_SERIALIZE_REQUEST = lambda bytestring: bytestring * 2
-_DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) / 2:]
+_DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) // 2:]
_SERIALIZE_RESPONSE = lambda bytestring: bytestring * 3
-_DESERIALIZE_RESPONSE = lambda bytestring: bytestring[:len(bytestring) / 3]
+_DESERIALIZE_RESPONSE = lambda bytestring: bytestring[:len(bytestring) // 3]
_UNARY_UNARY = b'/test/UnaryUnary'
_UNARY_STREAM = b'/test/UnaryStream'
@@ -189,14 +189,7 @@ class RPCTest(unittest.TestCase):
self._server.add_generic_rpc_handlers((_GenericHandler(self._handler),))
self._server.start()
- self._channel = grpc.insecure_channel(b'localhost:%d' % port)
-
- # TODO(nathaniel): Why is this necessary, and only in some development
- # environments?
- def tearDown(self):
- del self._channel
- del self._server
- del self._server_pool
+ self._channel = grpc.insecure_channel('localhost:%d' % port)
def testUnrecognizedMethod(self):
request = b'abc'
@@ -223,10 +216,9 @@ class RPCTest(unittest.TestCase):
expected_response = self._handler.handle_unary_unary(request, None)
multi_callable = _unary_unary_multi_callable(self._channel)
- response, call = multi_callable(
+ response, call = multi_callable.with_call(
request, metadata=(
- (b'test', b'SuccessfulUnaryRequestBlockingUnaryResponseWithCall'),),
- with_call=True)
+ (b'test', b'SuccessfulUnaryRequestBlockingUnaryResponseWithCall'),))
self.assertEqual(expected_response, response)
self.assertIs(grpc.StatusCode.OK, call.code())
@@ -273,11 +265,11 @@ class RPCTest(unittest.TestCase):
request_iterator = iter(requests)
multi_callable = _stream_unary_multi_callable(self._channel)
- response, call = multi_callable(
+ response, call = multi_callable.with_call(
request_iterator,
metadata=(
(b'test', b'SuccessfulStreamRequestBlockingUnaryResponseWithCall'),
- ), with_call=True)
+ ))
self.assertEqual(expected_response, response)
self.assertIs(grpc.StatusCode.OK, call.code())
@@ -532,10 +524,9 @@ class RPCTest(unittest.TestCase):
multi_callable = _unary_unary_multi_callable(self._channel)
with self._control.pause():
with self.assertRaises(grpc.RpcError) as exception_context:
- multi_callable(
+ multi_callable.with_call(
request, timeout=test_constants.SHORT_TIMEOUT,
- metadata=((b'test', b'ExpiredUnaryRequestBlockingUnaryResponse'),),
- with_call=True)
+ metadata=((b'test', b'ExpiredUnaryRequestBlockingUnaryResponse'),))
self.assertIsNotNone(exception_context.exception.initial_metadata())
self.assertIs(
@@ -647,10 +638,9 @@ class RPCTest(unittest.TestCase):
multi_callable = _unary_unary_multi_callable(self._channel)
with self._control.fail():
with self.assertRaises(grpc.RpcError) as exception_context:
- multi_callable(
+ multi_callable.with_call(
request,
- metadata=((b'test', b'FailedUnaryRequestBlockingUnaryResponse'),),
- with_call=True)
+ metadata=((b'test', b'FailedUnaryRequestBlockingUnaryResponse'),))
self.assertIs(grpc.StatusCode.UNKNOWN, exception_context.exception.code())
diff --git a/src/python/grpcio/tests/unit/_thread_cleanup_test.py b/src/python/grpcio/tests/unit/_thread_cleanup_test.py
new file mode 100644
index 0000000000..3e4f317edc
--- /dev/null
+++ b/src/python/grpcio/tests/unit/_thread_cleanup_test.py
@@ -0,0 +1,117 @@
+# Copyright 2016, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""Tests for CleanupThread."""
+
+import threading
+import time
+import unittest
+
+from grpc import _common
+
+_SHORT_TIME = 0.5
+_LONG_TIME = 2.0
+_EPSILON = 0.1
+
+
+def cleanup(timeout):
+ if timeout is not None:
+ time.sleep(timeout)
+ else:
+ time.sleep(_LONG_TIME)
+
+
+def slow_cleanup(timeout):
+ # Don't respect timeout
+ time.sleep(_LONG_TIME)
+
+
+class CleanupThreadTest(unittest.TestCase):
+
+ def testTargetInvocation(self):
+ event = threading.Event()
+ def target(arg1, arg2, arg3=None):
+ self.assertEqual('arg1', arg1)
+ self.assertEqual('arg2', arg2)
+ self.assertEqual('arg3', arg3)
+ event.set()
+
+ cleanup_thread = _common.CleanupThread(behavior=lambda x: None,
+ target=target, name='test-name',
+ args=('arg1', 'arg2'), kwargs={'arg3': 'arg3'})
+ cleanup_thread.start()
+ cleanup_thread.join()
+ self.assertEqual(cleanup_thread.name, 'test-name')
+ self.assertTrue(event.is_set())
+
+ def testJoinNoTimeout(self):
+ cleanup_thread = _common.CleanupThread(behavior=cleanup)
+ cleanup_thread.start()
+ start_time = time.time()
+ cleanup_thread.join()
+ end_time = time.time()
+ self.assertAlmostEqual(_LONG_TIME, end_time - start_time, delta=_EPSILON)
+
+ def testJoinTimeout(self):
+ cleanup_thread = _common.CleanupThread(behavior=cleanup)
+ cleanup_thread.start()
+ start_time = time.time()
+ cleanup_thread.join(_SHORT_TIME)
+ end_time = time.time()
+ self.assertAlmostEqual(_SHORT_TIME, end_time - start_time, delta=_EPSILON)
+
+ def testJoinTimeoutSlowBehavior(self):
+ cleanup_thread = _common.CleanupThread(behavior=slow_cleanup)
+ cleanup_thread.start()
+ start_time = time.time()
+ cleanup_thread.join(_SHORT_TIME)
+ end_time = time.time()
+ self.assertAlmostEqual(_LONG_TIME, end_time - start_time, delta=_EPSILON)
+
+ def testJoinTimeoutSlowTarget(self):
+ event = threading.Event()
+ def target():
+ event.wait(_LONG_TIME)
+ cleanup_thread = _common.CleanupThread(behavior=cleanup, target=target)
+ cleanup_thread.start()
+ start_time = time.time()
+ cleanup_thread.join(_SHORT_TIME)
+ end_time = time.time()
+ self.assertAlmostEqual(_SHORT_TIME, end_time - start_time, delta=_EPSILON)
+ event.set()
+
+ def testJoinZeroTimeout(self):
+ cleanup_thread = _common.CleanupThread(behavior=cleanup)
+ cleanup_thread.start()
+ start_time = time.time()
+ cleanup_thread.join(0)
+ end_time = time.time()
+ self.assertAlmostEqual(0, end_time - start_time, delta=_EPSILON)
+
+if __name__ == '__main__':
+ unittest.main(verbosity=2)
diff --git a/src/python/grpcio/tests/unit/beta/_beta_features_test.py b/src/python/grpcio/tests/unit/beta/_beta_features_test.py
index bb2893a21b..3a9701b8eb 100644
--- a/src/python/grpcio/tests/unit/beta/_beta_features_test.py
+++ b/src/python/grpcio/tests/unit/beta/_beta_features_test.py
@@ -42,8 +42,8 @@ from tests.unit.framework.common import test_constants
_SERVER_HOST_OVERRIDE = 'foo.test.google.fr'
-_PER_RPC_CREDENTIALS_METADATA_KEY = 'my-call-credentials-metadata-key'
-_PER_RPC_CREDENTIALS_METADATA_VALUE = 'my-call-credentials-metadata-value'
+_PER_RPC_CREDENTIALS_METADATA_KEY = b'my-call-credentials-metadata-key'
+_PER_RPC_CREDENTIALS_METADATA_VALUE = b'my-call-credentials-metadata-value'
_GROUP = 'group'
_UNARY_UNARY = 'unary-unary'
diff --git a/src/python/grpcio/tests/unit/beta/_connectivity_channel_test.py b/src/python/grpcio/tests/unit/beta/_connectivity_channel_test.py
index 5dc8720639..488f7d7141 100644
--- a/src/python/grpcio/tests/unit/beta/_connectivity_channel_test.py
+++ b/src/python/grpcio/tests/unit/beta/_connectivity_channel_test.py
@@ -187,5 +187,15 @@ class ChannelConnectivityTest(unittest.TestCase):
server_completion_queue_thread.join()
+class ConnectivityStatesTest(unittest.TestCase):
+
+ def testBetaConnectivityStates(self):
+ self.assertIsNotNone(interfaces.ChannelConnectivity.IDLE)
+ self.assertIsNotNone(interfaces.ChannelConnectivity.CONNECTING)
+ self.assertIsNotNone(interfaces.ChannelConnectivity.READY)
+ self.assertIsNotNone(interfaces.ChannelConnectivity.TRANSIENT_FAILURE)
+ self.assertIsNotNone(interfaces.ChannelConnectivity.FATAL_FAILURE)
+
+
if __name__ == '__main__':
unittest.main(verbosity=2)
diff --git a/src/python/grpcio/tests/unit/beta/_not_found_test.py b/src/python/grpcio/tests/unit/beta/_not_found_test.py
index 44fcd1e13c..37b8c49120 100644
--- a/src/python/grpcio/tests/unit/beta/_not_found_test.py
+++ b/src/python/grpcio/tests/unit/beta/_not_found_test.py
@@ -61,7 +61,7 @@ class NotFoundTest(unittest.TestCase):
def test_future_stream_unary_not_found(self):
rpc_future = self._generic_stub.future_stream_unary(
- 'grupe', 'mevvod', b'def', test_constants.LONG_TIMEOUT)
+ 'grupe', 'mevvod', [b'def'], test_constants.LONG_TIMEOUT)
with self.assertRaises(face.LocalError) as exception_assertion_context:
rpc_future.result()
self.assertIs(
diff --git a/src/python/grpcio/tests/unit/beta/test_utilities.py b/src/python/grpcio/tests/unit/beta/test_utilities.py
index 0313e06a93..8ccad04e05 100644
--- a/src/python/grpcio/tests/unit/beta/test_utilities.py
+++ b/src/python/grpcio/tests/unit/beta/test_utilities.py
@@ -29,7 +29,7 @@
"""Test-appropriate entry points into the gRPC Python Beta API."""
-from grpc._adapter import _intermediary_low
+import grpc
from grpc.beta import implementations
@@ -48,9 +48,8 @@ def not_really_secure_channel(
An implementations.Channel to the remote host through which RPCs may be
conducted.
"""
- hostport = '%s:%d' % (host, port)
- intermediary_low_channel = _intermediary_low.Channel(
- hostport, channel_credentials._low_credentials,
- server_host_override=server_host_override)
- return implementations.Channel(
- intermediary_low_channel._internal, intermediary_low_channel)
+ target = '%s:%d' % (host, port)
+ channel = grpc.secure_channel(
+ target, channel_credentials,
+ ((b'grpc.ssl_target_name_override', server_host_override,),))
+ return implementations.Channel(channel)
diff --git a/src/python/grpcio/tests/unit/framework/_crust_over_core_face_interface_test.py b/src/python/grpcio/tests/unit/framework/_crust_over_core_face_interface_test.py
deleted file mode 100644
index 43457be362..0000000000
--- a/src/python/grpcio/tests/unit/framework/_crust_over_core_face_interface_test.py
+++ /dev/null
@@ -1,113 +0,0 @@
-# Copyright 2015, Google Inc.
-# All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Tests Face interface compliance of the crust-over-core stack."""
-
-import collections
-import unittest
-
-import six
-
-from grpc.framework.core import implementations as core_implementations
-from grpc.framework.crust import implementations as crust_implementations
-from grpc.framework.foundation import logging_pool
-from grpc.framework.interfaces.links import utilities
-from tests.unit.framework.common import test_constants
-from tests.unit.framework.interfaces.face import test_cases
-from tests.unit.framework.interfaces.face import test_interfaces
-from tests.unit.framework.interfaces.links import test_utilities
-
-
-class _Implementation(test_interfaces.Implementation):
-
- def instantiate(
- self, methods, method_implementations, multi_method_implementation):
- pool = logging_pool.pool(test_constants.POOL_SIZE)
- servicer = crust_implementations.servicer(
- method_implementations, multi_method_implementation, pool)
-
- service_end_link = core_implementations.service_end_link(
- servicer, test_constants.DEFAULT_TIMEOUT,
- test_constants.MAXIMUM_TIMEOUT)
- invocation_end_link = core_implementations.invocation_end_link()
- invocation_end_link.join_link(service_end_link)
- service_end_link.join_link(invocation_end_link)
- service_end_link.start()
- invocation_end_link.start()
-
- generic_stub = crust_implementations.generic_stub(invocation_end_link, pool)
- # TODO(nathaniel): Add a "groups" attribute to _digest.TestServiceDigest.
- group = next(iter(methods))[0]
- # TODO(nathaniel): Add a "cardinalities_by_group" attribute to
- # _digest.TestServiceDigest.
- cardinalities = {
- method: method_object.cardinality()
- for (group, method), method_object in six.iteritems(methods)}
- dynamic_stub = crust_implementations.dynamic_stub(
- invocation_end_link, group, cardinalities, pool)
-
- return generic_stub, {group: dynamic_stub}, (
- invocation_end_link, service_end_link, pool)
-
- def destantiate(self, memo):
- invocation_end_link, service_end_link, pool = memo
- invocation_end_link.stop(0).wait()
- service_end_link.stop(0).wait()
- invocation_end_link.join_link(utilities.NULL_LINK)
- service_end_link.join_link(utilities.NULL_LINK)
- pool.shutdown(wait=True)
-
- def invocation_metadata(self):
- return object()
-
- def initial_metadata(self):
- return object()
-
- def terminal_metadata(self):
- return object()
-
- def code(self):
- return object()
-
- def details(self):
- return object()
-
- def metadata_transmitted(self, original_metadata, transmitted_metadata):
- return original_metadata is transmitted_metadata
-
-
-def load_tests(loader, tests, pattern):
- return unittest.TestSuite(
- tests=tuple(
- loader.loadTestsFromTestCase(test_case_class)
- for test_case_class in test_cases.test_cases(_Implementation())))
-
-
-if __name__ == '__main__':
- unittest.main(verbosity=2)
diff --git a/src/python/grpcio/tests/unit/framework/core/_base_interface_test.py b/src/python/grpcio/tests/unit/framework/core/_base_interface_test.py
deleted file mode 100644
index 1310292306..0000000000
--- a/src/python/grpcio/tests/unit/framework/core/_base_interface_test.py
+++ /dev/null
@@ -1,96 +0,0 @@
-# Copyright 2015, Google Inc.
-# All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Tests the RPC Framework Core's implementation of the Base interface."""
-
-import logging
-import random
-import time
-import unittest
-
-from grpc.framework.core import implementations
-from grpc.framework.interfaces.base import utilities
-from tests.unit.framework.common import test_constants
-from tests.unit.framework.interfaces.base import test_cases
-from tests.unit.framework.interfaces.base import test_interfaces
-
-
-class _Implementation(test_interfaces.Implementation):
-
- def __init__(self):
- self._invocation_initial_metadata = object()
- self._service_initial_metadata = object()
- self._invocation_terminal_metadata = object()
- self._service_terminal_metadata = object()
-
- def instantiate(self, serializations, servicer):
- invocation = implementations.invocation_end_link()
- service = implementations.service_end_link(
- servicer, test_constants.DEFAULT_TIMEOUT,
- test_constants.MAXIMUM_TIMEOUT)
- invocation.join_link(service)
- service.join_link(invocation)
- return invocation, service, None
-
- def destantiate(self, memo):
- pass
-
- def invocation_initial_metadata(self):
- return self._invocation_initial_metadata
-
- def service_initial_metadata(self):
- return self._service_initial_metadata
-
- def invocation_completion(self):
- return utilities.completion(self._invocation_terminal_metadata, None, None)
-
- def service_completion(self):
- return utilities.completion(self._service_terminal_metadata, None, None)
-
- def metadata_transmitted(self, original_metadata, transmitted_metadata):
- return transmitted_metadata is original_metadata
-
- def completion_transmitted(self, original_completion, transmitted_completion):
- return (
- (original_completion.terminal_metadata is
- transmitted_completion.terminal_metadata) and
- original_completion.code is transmitted_completion.code and
- original_completion.message is transmitted_completion.message
- )
-
-
-def load_tests(loader, tests, pattern):
- return unittest.TestSuite(
- tests=tuple(
- loader.loadTestsFromTestCase(test_case_class)
- for test_case_class in test_cases.test_cases(_Implementation())))
-
-
-if __name__ == '__main__':
- unittest.main(verbosity=2)
diff --git a/src/python/grpcio/tests/unit/framework/foundation/_later_test.py b/src/python/grpcio/tests/unit/framework/foundation/_later_test.py
deleted file mode 100644
index 6c2459e185..0000000000
--- a/src/python/grpcio/tests/unit/framework/foundation/_later_test.py
+++ /dev/null
@@ -1,151 +0,0 @@
-# Copyright 2015, Google Inc.
-# All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Tests of the later module."""
-
-import threading
-import time
-import unittest
-
-from grpc.framework.foundation import later
-
-TICK = 0.1
-
-
-class LaterTest(unittest.TestCase):
-
- def test_simple_delay(self):
- lock = threading.Lock()
- cell = [0]
- return_value = object()
-
- def computation():
- with lock:
- cell[0] += 1
- return return_value
- computation_future = later.later(TICK * 2, computation)
-
- self.assertFalse(computation_future.done())
- self.assertFalse(computation_future.cancelled())
- time.sleep(TICK)
- self.assertFalse(computation_future.done())
- self.assertFalse(computation_future.cancelled())
- with lock:
- self.assertEqual(0, cell[0])
- time.sleep(TICK * 2)
- self.assertTrue(computation_future.done())
- self.assertFalse(computation_future.cancelled())
- with lock:
- self.assertEqual(1, cell[0])
- self.assertEqual(return_value, computation_future.result())
-
- def test_callback(self):
- lock = threading.Lock()
- cell = [0]
- callback_called = [False]
- future_passed_to_callback = [None]
- def computation():
- with lock:
- cell[0] += 1
- computation_future = later.later(TICK * 2, computation)
- def callback(outcome):
- with lock:
- callback_called[0] = True
- future_passed_to_callback[0] = outcome
- computation_future.add_done_callback(callback)
- time.sleep(TICK)
- with lock:
- self.assertFalse(callback_called[0])
- time.sleep(TICK * 2)
- with lock:
- self.assertTrue(callback_called[0])
- self.assertTrue(future_passed_to_callback[0].done())
-
- callback_called[0] = False
- future_passed_to_callback[0] = None
-
- computation_future.add_done_callback(callback)
- with lock:
- self.assertTrue(callback_called[0])
- self.assertTrue(future_passed_to_callback[0].done())
-
- def test_cancel(self):
- lock = threading.Lock()
- cell = [0]
- callback_called = [False]
- future_passed_to_callback = [None]
- def computation():
- with lock:
- cell[0] += 1
- computation_future = later.later(TICK * 2, computation)
- def callback(outcome):
- with lock:
- callback_called[0] = True
- future_passed_to_callback[0] = outcome
- computation_future.add_done_callback(callback)
- time.sleep(TICK)
- with lock:
- self.assertFalse(callback_called[0])
- computation_future.cancel()
- self.assertTrue(computation_future.cancelled())
- self.assertFalse(computation_future.running())
- self.assertTrue(computation_future.done())
- with lock:
- self.assertTrue(callback_called[0])
- self.assertTrue(future_passed_to_callback[0].cancelled())
-
- def test_result(self):
- lock = threading.Lock()
- cell = [0]
- callback_called = [False]
- future_passed_to_callback_cell = [None]
- return_value = object()
-
- def computation():
- with lock:
- cell[0] += 1
- return return_value
- computation_future = later.later(TICK * 2, computation)
-
- def callback(future_passed_to_callback):
- with lock:
- callback_called[0] = True
- future_passed_to_callback_cell[0] = future_passed_to_callback
- computation_future.add_done_callback(callback)
- returned_value = computation_future.result()
- self.assertEqual(return_value, returned_value)
-
- # The callback may not yet have been called! Sleep a tick.
- time.sleep(TICK)
- with lock:
- self.assertTrue(callback_called[0])
- self.assertEqual(return_value, future_passed_to_callback_cell[0].result())
-
-if __name__ == '__main__':
- unittest.main(verbosity=2)
diff --git a/src/python/grpcio/tests/unit/framework/interfaces/base/test_cases.py b/src/python/grpcio/tests/unit/framework/interfaces/base/test_cases.py
index 4f8e26c9a2..5d16bf98be 100644
--- a/src/python/grpcio/tests/unit/framework/interfaces/base/test_cases.py
+++ b/src/python/grpcio/tests/unit/framework/interfaces/base/test_cases.py
@@ -29,6 +29,8 @@
"""Tests of the base interface of RPC Framework."""
+from __future__ import division
+
import logging
import random
import threading
@@ -54,13 +56,13 @@ class _Serialization(test_interfaces.Serialization):
return request + request
def deserialize_request(self, serialized_request):
- return serialized_request[:len(serialized_request) / 2]
+ return serialized_request[:len(serialized_request) // 2]
def serialize_response(self, response):
return response * 3
def deserialize_response(self, serialized_response):
- return serialized_response[2 * len(serialized_response) / 3:]
+ return serialized_response[2 * len(serialized_response) // 3:]
def _advance(quadruples, operator, controller):
diff --git a/src/python/grpcio/tests/unit/test_common.py b/src/python/grpcio/tests/unit/test_common.py
index 7b4d20dccb..b779f65e7e 100644
--- a/src/python/grpcio/tests/unit/test_common.py
+++ b/src/python/grpcio/tests/unit/test_common.py
@@ -61,6 +61,10 @@ def metadata_transmitted(original_metadata, transmitted_metadata):
original = collections.defaultdict(list)
for key_value_pair in original_metadata:
key, value = tuple(key_value_pair)
+ if not isinstance(key, bytes):
+ key = key.encode()
+ if not isinstance(value, bytes):
+ value = value.encode()
original[key].append(value)
transmitted = collections.defaultdict(list)
for key_value_pair in transmitted_metadata: