path: root/src/python/grpcio
Diffstat (limited to 'src/python/grpcio')
 src/python/grpcio/commands.py | 2
 src/python/grpcio/grpc/__init__.py | 1210
 src/python/grpcio/grpc/_adapter/_low.py | 76
 src/python/grpcio/grpc/_adapter/_types.py | 2
 src/python/grpcio/grpc/_auth.py | 86
 src/python/grpcio/grpc/_channel.py | 852
 src/python/grpcio/grpc/_common.py | 154
 src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi | 7
 src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi | 24
 src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pxd.pxi | 3
 src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi | 59
 src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd.pxi | 4
 src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi | 55
 src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi | 10
 src/python/grpcio/grpc/_cython/_cygrpc/grpc_string.pyx.pxi | 39
 src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi | 71
 src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi | 21
 src/python/grpcio/grpc/_cython/cygrpc.pyx | 1
 src/python/grpcio/grpc/_cython/imports.generated.c | 10
 src/python/grpcio/grpc/_cython/imports.generated.h | 23
 src/python/grpcio/grpc/_cython/loader.c | 4
 src/python/grpcio/grpc/_links/service.py | 6
 src/python/grpcio/grpc/_plugin_wrapping.py | 123
 src/python/grpcio/grpc/_server.py | 755
 src/python/grpcio/grpc/_utilities.py | 171
 src/python/grpcio/grpc/beta/_client_adaptations.py | 563
 src/python/grpcio/grpc/beta/_server_adaptations.py | 367
 src/python/grpcio/grpc/beta/implementations.py | 197
 src/python/grpcio/grpc/beta/interfaces.py | 90
 src/python/grpcio/grpc_core_dependencies.py | 31
 src/python/grpcio/tests/interop/client.py | 42
 src/python/grpcio/tests/interop/methods.py | 43
 src/python/grpcio/tests/protoc_plugin/_python_plugin_test.py | 583
 src/python/grpcio/tests/qps/benchmark_client.py | 102
 src/python/grpcio/tests/qps/client_runner.py | 14
 src/python/grpcio/tests/qps/qps_worker.py | 4
 src/python/grpcio/tests/qps/worker_server.py | 5
 src/python/grpcio/tests/stress/client.py | 6
 src/python/grpcio/tests/tests.json | 15
 src/python/grpcio/tests/unit/_adapter/_intermediary_low_test.py | 43
 src/python/grpcio/tests/unit/_adapter/_low_test.py | 16
 src/python/grpcio/tests/unit/_api_test.py | 104
 src/python/grpcio/tests/unit/_auth_test.py | 96
 src/python/grpcio/tests/unit/_channel_connectivity_test.py | 161
 src/python/grpcio/tests/unit/_channel_ready_future_test.py | 103
 src/python/grpcio/tests/unit/_cython/_cancel_many_calls_test.py | 222
 src/python/grpcio/tests/unit/_cython/_read_some_but_not_all_responses_test.py | 251
 src/python/grpcio/tests/unit/_cython/cygrpc_test.py | 286
 src/python/grpcio/tests/unit/_cython/test_utilities.py | 34
 src/python/grpcio/tests/unit/_empty_message_test.py | 137
 src/python/grpcio/tests/unit/_from_grpc_import_star.py (renamed from src/python/grpcio/grpc/_adapter/_implementations.py) | 24
 src/python/grpcio/tests/unit/_links/_transmission_test.py | 2
 src/python/grpcio/tests/unit/_metadata_test.py | 216
 src/python/grpcio/tests/unit/_rpc_test.py | 768
 src/python/grpcio/tests/unit/_thread_cleanup_test.py | 117
 src/python/grpcio/tests/unit/beta/_beta_features_test.py | 4
 src/python/grpcio/tests/unit/beta/_connectivity_channel_test.py | 10
 src/python/grpcio/tests/unit/beta/_implementations_test.py | 17
 src/python/grpcio/tests/unit/beta/_not_found_test.py | 2
 src/python/grpcio/tests/unit/beta/test_utilities.py | 13
 src/python/grpcio/tests/unit/framework/common/test_control.py | 26
 src/python/grpcio/tests/unit/framework/interfaces/base/test_cases.py | 6
 src/python/grpcio/tests/unit/test_common.py | 4
 63 files changed, 7639 insertions(+), 853 deletions(-)
diff --git a/src/python/grpcio/commands.py b/src/python/grpcio/commands.py
index 295dab2d27..3e974eba0a 100644
--- a/src/python/grpcio/commands.py
+++ b/src/python/grpcio/commands.py
@@ -191,7 +191,7 @@ class BuildProtoModules(setuptools.Command):
except subprocess.CalledProcessError as e:
sys.stderr.write(
'warning: Command:\n{}\nMessage:\n{}\nOutput:\n{}'.format(
- command, e.message, e.output))
+ command, str(e), e.output))
# Generated proto directories dont include __init__.py, but
# these are needed for python package resolution
diff --git a/src/python/grpcio/grpc/__init__.py b/src/python/grpcio/grpc/__init__.py
index b844a14c48..28adca3772 100644
--- a/src/python/grpcio/grpc/__init__.py
+++ b/src/python/grpcio/grpc/__init__.py
@@ -27,5 +27,1215 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""gRPC's Python API."""
+
__import__('pkg_resources').declare_namespace(__name__)
+import abc
+import enum
+
+import six
+
+from grpc._cython import cygrpc as _cygrpc
+
+
+############################## Future Interface ###############################
+
+
+class FutureTimeoutError(Exception):
+ """Indicates that a method call on a Future timed out."""
+
+
+class FutureCancelledError(Exception):
+ """Indicates that the computation underlying a Future was cancelled."""
+
+
+class Future(six.with_metaclass(abc.ABCMeta)):
+ """A representation of a computation in another control flow.
+
+ Computations represented by a Future may be yet to be begun, may be ongoing,
+ or may have already completed.
+ """
+
+ @abc.abstractmethod
+ def cancel(self):
+ """Attempts to cancel the computation.
+
+ This method does not block.
+
+ Returns:
+ True if the computation has not yet begun, will not be allowed to take
+ place, and determination of both was possible without blocking. False
+ under all other circumstances including but not limited to the
+ computation's already having begun, the computation's already having
+ finished, and the computation's having been scheduled for execution on a
+ remote system for which a determination of whether or not it commenced
+ before being cancelled cannot be made without blocking.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def cancelled(self):
+ """Describes whether the computation was cancelled.
+
+ This method does not block.
+
+ Returns:
+ True if the computation was cancelled any time before its result became
+ immediately available. False under all other circumstances including but
+ not limited to this object's cancel method not having been called and
+ the computation's result having become immediately available.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def running(self):
+ """Describes whether the computation is taking place.
+
+ This method does not block.
+
+ Returns:
+ True if the computation is scheduled to take place in the future or is
+ taking place now, or False if the computation took place in the past or
+ was cancelled.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def done(self):
+ """Describes whether the computation has taken place.
+
+ This method does not block.
+
+ Returns:
+ True if the computation is known to have either completed or have been
+ unscheduled or interrupted. False if the computation may possibly be
+ executing or scheduled to execute later.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def result(self, timeout=None):
+ """Accesses the outcome of the computation or raises its exception.
+
+ This method may return immediately or may block.
+
+ Args:
+ timeout: The length of time in seconds to wait for the computation to
+ finish or be cancelled, or None if this method should block until the
+ computation has finished or is cancelled no matter how long that takes.
+
+ Returns:
+ The return value of the computation.
+
+ Raises:
+ FutureTimeoutError: If a timeout value is passed and the computation does
+ not terminate within the allotted time.
+ FutureCancelledError: If the computation was cancelled.
+ Exception: If the computation raised an exception, this call will raise
+ the same exception.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def exception(self, timeout=None):
+ """Return the exception raised by the computation.
+
+ This method may return immediately or may block.
+
+ Args:
+ timeout: The length of time in seconds to wait for the computation to
+ terminate or be cancelled, or None if this method should block until
+ the computation is terminated or is cancelled no matter how long that
+ takes.
+
+ Returns:
+ The exception raised by the computation, or None if the computation did
+ not raise an exception.
+
+ Raises:
+ FutureTimeoutError: If a timeout value is passed and the computation does
+ not terminate within the allotted time.
+ FutureCancelledError: If the computation was cancelled.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def traceback(self, timeout=None):
+ """Access the traceback of the exception raised by the computation.
+
+ This method may return immediately or may block.
+
+ Args:
+ timeout: The length of time in seconds to wait for the computation to
+ terminate or be cancelled, or None if this method should block until
+ the computation is terminated or is cancelled no matter how long that
+ takes.
+
+ Returns:
+ The traceback of the exception raised by the computation, or None if the
+ computation did not raise an exception.
+
+ Raises:
+ FutureTimeoutError: If a timeout value is passed and the computation does
+ not terminate within the allotted time.
+ FutureCancelledError: If the computation was cancelled.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def add_done_callback(self, fn):
+ """Adds a function to be called at completion of the computation.
+
+ The callback will be passed this Future object describing the outcome of
+ the computation.
+
+ If the computation has already completed, the callback will be called
+ immediately.
+
+ Args:
+ fn: A callable taking this Future object as its single parameter.
+ """
+ raise NotImplementedError()
+
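As an illustrative sketch only, a caller holding an object implementing this Future interface (for example the value returned by the multi-callable future() methods defined later in this module, here assumed to be bound to rpc_future) might drive it as follows; the five-second timeout is an arbitrary example value.

    import grpc

    def _log_outcome(rpc_future):
        # The callback receives the Future itself once the computation is done.
        if rpc_future.cancelled():
            print('computation cancelled')
        elif rpc_future.exception() is not None:
            print('computation failed: {}'.format(rpc_future.exception()))
        else:
            print('computation result: {}'.format(rpc_future.result()))

    rpc_future.add_done_callback(_log_outcome)
    try:
        result = rpc_future.result(timeout=5.0)  # block for at most five seconds
    except grpc.FutureTimeoutError:
        rpc_future.cancel()  # stop waiting and attempt to cancel the computation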
+
+################################ gRPC Enums ##################################
+
+
+@enum.unique
+class ChannelConnectivity(enum.Enum):
+ """Mirrors grpc_connectivity_state in the gRPC Core.
+
+ Attributes:
+ IDLE: The channel is idle.
+ CONNECTING: The channel is connecting.
+ READY: The channel is ready to conduct RPCs.
+ TRANSIENT_FAILURE: The channel has seen a failure from which it expects to
+ recover.
+ SHUTDOWN: The channel has seen a failure from which it cannot recover.
+ """
+ IDLE = (_cygrpc.ConnectivityState.idle, 'idle')
+ CONNECTING = (_cygrpc.ConnectivityState.connecting, 'connecting')
+ READY = (_cygrpc.ConnectivityState.ready, 'ready')
+ TRANSIENT_FAILURE = (
+ _cygrpc.ConnectivityState.transient_failure, 'transient failure')
+ SHUTDOWN = (_cygrpc.ConnectivityState.shutdown, 'shutdown')
+
+
+@enum.unique
+class StatusCode(enum.Enum):
+ """Mirrors grpc_status_code in the gRPC Core."""
+ OK = (_cygrpc.StatusCode.ok, 'ok')
+ CANCELLED = (_cygrpc.StatusCode.cancelled, 'cancelled')
+ UNKNOWN = (_cygrpc.StatusCode.unknown, 'unknown')
+ INVALID_ARGUMENT = (
+ _cygrpc.StatusCode.invalid_argument, 'invalid argument')
+ DEADLINE_EXCEEDED = (
+ _cygrpc.StatusCode.deadline_exceeded, 'deadline exceeded')
+ NOT_FOUND = (_cygrpc.StatusCode.not_found, 'not found')
+ ALREADY_EXISTS = (_cygrpc.StatusCode.already_exists, 'already exists')
+ PERMISSION_DENIED = (
+ _cygrpc.StatusCode.permission_denied, 'permission denied')
+ RESOURCE_EXHAUSTED = (
+ _cygrpc.StatusCode.resource_exhausted, 'resource exhausted')
+ FAILED_PRECONDITION = (
+ _cygrpc.StatusCode.failed_precondition, 'failed precondition')
+ ABORTED = (_cygrpc.StatusCode.aborted, 'aborted')
+ OUT_OF_RANGE = (_cygrpc.StatusCode.out_of_range, 'out of range')
+ UNIMPLEMENTED = (_cygrpc.StatusCode.unimplemented, 'unimplemented')
+ INTERNAL = (_cygrpc.StatusCode.internal, 'internal')
+ UNAVAILABLE = (_cygrpc.StatusCode.unavailable, 'unavailable')
+ DATA_LOSS = (_cygrpc.StatusCode.data_loss, 'data loss')
+ UNAUTHENTICATED = (_cygrpc.StatusCode.unauthenticated, 'unauthenticated')
+
+
+############################# gRPC Exceptions ################################
+
+
+class RpcError(Exception):
+ """Raised by the gRPC library to indicate non-OK-status RPC termination."""
+
+
+############################## Shared Context ################################
+
+
+class RpcContext(six.with_metaclass(abc.ABCMeta)):
+ """Provides RPC-related information and action."""
+
+ @abc.abstractmethod
+ def is_active(self):
+ """Describes whether the RPC is active or has terminated."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def time_remaining(self):
+ """Describes the length of allowed time remaining for the RPC.
+
+ Returns:
+ A nonnegative float indicating the length of allowed time in seconds
+ remaining for the RPC to complete before it is considered to have timed
+ out, or None if no deadline was specified for the RPC.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def cancel(self):
+ """Cancels the RPC.
+
+ Idempotent and has no effect if the RPC has already terminated.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def add_callback(self, callback):
+ """Registers a callback to be called on RPC termination.
+
+ Args:
+ callback: A no-parameter callable to be called on RPC termination.
+
+ Returns:
+ True if the callback was added and will be called later; False if the
+ callback was not added and will not later be called (because the RPC
+ already terminated or some other reason).
+ """
+ raise NotImplementedError()
+
+
+######################### Invocation-Side Context ############################
+
+
+class Call(six.with_metaclass(abc.ABCMeta, RpcContext)):
+ """Invocation-side utility object for an RPC."""
+
+ @abc.abstractmethod
+ def initial_metadata(self):
+ """Accesses the initial metadata from the service-side of the RPC.
+
+ This method blocks until the value is available.
+
+ Returns:
+ The initial metadata as a sequence of pairs of bytes.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def trailing_metadata(self):
+ """Accesses the trailing metadata from the service-side of the RPC.
+
+ This method blocks until the value is available.
+
+ Returns:
+ The trailing metadata as a sequence of pairs of bytes.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def code(self):
+ """Accesses the status code emitted by the service-side of the RPC.
+
+ This method blocks until the value is available.
+
+ Returns:
+ The StatusCode value for the RPC.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def details(self):
+ """Accesses the details value emitted by the service-side of the RPC.
+
+ This method blocks until the value is available.
+
+ Returns:
+ The bytes of the details of the RPC.
+ """
+ raise NotImplementedError()
+
+
+############ Authentication & Authorization Interfaces & Classes #############
+
+
+class ChannelCredentials(object):
+ """A value encapsulating the data required to create a secure Channel.
+
+ This class has no supported interface - it exists to define the type of its
+ instances and its instances exist to be passed to other functions.
+ """
+
+ def __init__(self, credentials):
+ self._credentials = credentials
+
+
+class CallCredentials(object):
+ """A value encapsulating data asserting an identity over a channel.
+
+ A CallCredentials may be composed with ChannelCredentials to always assert
+ identity for every call over that Channel.
+
+ This class has no supported interface - it exists to define the type of its
+ instances and its instances exist to be passed to other functions.
+ """
+
+ def __init__(self, credentials):
+ self._credentials = credentials
+
+
+class AuthMetadataContext(six.with_metaclass(abc.ABCMeta)):
+ """Provides information to call credentials metadata plugins.
+
+ Attributes:
+ service_url: A string URL of the service being called into.
+ method_name: A string of the fully qualified method name being called.
+ """
+
+
+class AuthMetadataPluginCallback(six.with_metaclass(abc.ABCMeta)):
+ """Callback object received by a metadata plugin."""
+
+ def __call__(self, metadata, error):
+ """Inform the gRPC runtime of the metadata to construct a CallCredentials.
+
+ Args:
+ metadata: An iterable of 2-sequences (e.g. tuples) of metadata key/value
+ pairs.
+ error: An Exception to indicate error or None to indicate success.
+ """
+ raise NotImplementedError()
+
+
+class AuthMetadataPlugin(six.with_metaclass(abc.ABCMeta)):
+ """A specification for custom authentication."""
+
+ def __call__(self, context, callback):
+ """Implements authentication by passing metadata to a callback.
+
+ Implementations of this method must not block.
+
+ Args:
+ context: An AuthMetadataContext providing information on the RPC that the
+ plugin is being called to authenticate.
+ callback: An AuthMetadataPluginCallback to be invoked either synchronously
+ or asynchronously.
+ """
+ raise NotImplementedError()
+
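A minimal sketch of a concrete plugin (the class name and header value are hypothetical): it hands its metadata to the callback immediately, which satisfies the non-blocking requirement. A plugin like this is turned into a CallCredentials by the metadata_call_credentials function defined later in this module.

    import grpc

    class _StaticTokenPlugin(grpc.AuthMetadataPlugin):
        """Hypothetical plugin attaching one fixed authorization header."""

        def __init__(self, token):
            self._token = token

        def __call__(self, context, callback):
            # Must not block: pass the metadata (and no error) straight through.
            callback((('authorization', 'Bearer {}'.format(self._token)),), None)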
+
+class ServerCredentials(object):
+ """A value encapsulating the data required to open a secure port on a Server.
+
+ This class has no supported interface - it exists to define the type of its
+ instances and its instances exist to be passed to other functions.
+ """
+
+ def __init__(self, credentials):
+ self._credentials = credentials
+
+
+######################## Multi-Callable Interfaces ###########################
+
+
+class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
+ """Affords invoking a unary-unary RPC."""
+
+ @abc.abstractmethod
+ def __call__(
+ self, request, timeout=None, metadata=None, credentials=None,
+ with_call=False):
+ """Synchronously invokes the underlying RPC.
+
+ Args:
+ request: The request value for the RPC.
+ timeout: An optional duration of time in seconds to allow for the RPC.
+ metadata: An optional sequence of pairs of bytes to be transmitted to the
+ service-side of the RPC.
+ credentials: An optional CallCredentials for the RPC.
+ with_call: Whether or not to include a Call for the RPC in addition to
+ the response.
+
+ Returns:
+ The response value for the RPC, and a Call for the RPC if with_call was
+ set to True at invocation.
+
+ Raises:
+ RpcError: Indicating that the RPC terminated with non-OK status. The
+ raised RpcError will also be a Call for the RPC affording the RPC's
+ metadata, status code, and details.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def future(self, request, timeout=None, metadata=None, credentials=None):
+ """Asynchronously invokes the underlying RPC.
+
+ Args:
+ request: The request value for the RPC.
+ timeout: An optional duration of time in seconds to allow for the RPC.
+ metadata: An optional sequence of pairs of bytes to be transmitted to the
+ service-side of the RPC.
+ credentials: An optional CallCredentials for the RPC.
+
+ Returns:
+ An object that is both a Call for the RPC and a Future. In the event of
+ RPC completion, the returned Future's result value will be the response
+ message of the RPC. Should the RPC terminate with non-OK status, the
+ returned Future's exception value will be an RpcError.
+ """
+ raise NotImplementedError()
+
+
+class UnaryStreamMultiCallable(six.with_metaclass(abc.ABCMeta)):
+ """Affords invoking a unary-stream RPC."""
+
+ @abc.abstractmethod
+ def __call__(self, request, timeout=None, metadata=None, credentials=None):
+ """Invokes the underlying RPC.
+
+ Args:
+ request: The request value for the RPC.
+ timeout: An optional duration of time in seconds to allow for the RPC.
+ metadata: An optional sequence of pairs of bytes to be transmitted to the
+ service-side of the RPC.
+ credentials: An optional CallCredentials for the RPC.
+
+ Returns:
+ An object that is both a Call for the RPC and an iterator of response
+ values. Drawing response values from the returned iterator may raise
+ RpcError indicating termination of the RPC with non-OK status.
+ """
+ raise NotImplementedError()
+
+
+class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
+ """Affords invoking a stream-unary RPC in any call style."""
+
+ @abc.abstractmethod
+ def __call__(
+ self, request_iterator, timeout=None, metadata=None, credentials=None,
+ with_call=False):
+ """Synchronously invokes the underlying RPC.
+
+ Args:
+ request_iterator: An iterator that yields request values for the RPC.
+ timeout: An optional duration of time in seconds to allow for the RPC.
+ metadata: An optional sequence of pairs of bytes to be transmitted to the
+ service-side of the RPC.
+ credentials: An optional CallCredentials for the RPC.
+ with_call: Whether or not to include a Call for the RPC in addition to
+ the response.
+
+ Returns:
+ The response value for the RPC, and a Call for the RPC if with_call was
+ set to True at invocation.
+
+ Raises:
+ RpcError: Indicating that the RPC terminated with non-OK status. The
+ raised RpcError will also be a Call for the RPC affording the RPC's
+ metadata, status code, and details.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def future(
+ self, request_iterator, timeout=None, metadata=None, credentials=None):
+ """Asynchronously invokes the underlying RPC.
+
+ Args:
+ request_iterator: An iterator that yields request values for the RPC.
+ timeout: An optional duration of time in seconds to allow for the RPC.
+ metadata: An optional sequence of pairs of bytes to be transmitted to the
+ service-side of the RPC.
+ credentials: An optional CallCredentials for the RPC.
+
+ Returns:
+ An object that is both a Call for the RPC and a Future. In the event of
+ RPC completion, the returned Future's result value will be the response
+ message of the RPC. Should the RPC terminate with non-OK status, the
+ returned Future's exception value will be an RpcError.
+ """
+ raise NotImplementedError()
+
+
+class StreamStreamMultiCallable(six.with_metaclass(abc.ABCMeta)):
+ """Affords invoking a stream-stream RPC in any call style."""
+
+ @abc.abstractmethod
+ def __call__(
+ self, request_iterator, timeout=None, metadata=None, credentials=None):
+ """Invokes the underlying RPC.
+
+ Args:
+ request_iterator: An iterator that yields request values for the RPC.
+ timeout: An optional duration of time in seconds to allow for the RPC.
+ metadata: An optional sequence of pairs of bytes to be transmitted to the
+ service-side of the RPC.
+ credentials: An optional CallCredentials for the RPC.
+
+ Returns:
+ An object that is both a Call for the RPC and an iterator of response
+ values. Drawing response values from the returned iterator may raise
+ RpcError indicating termination of the RPC with non-OK status.
+ """
+ raise NotImplementedError()
+
+
+############################# Channel Interface ##############################
+
+
+class Channel(six.with_metaclass(abc.ABCMeta)):
+ """Affords RPC invocation via generic methods."""
+
+ @abc.abstractmethod
+ def subscribe(self, callback, try_to_connect=False):
+ """Subscribes to this Channel's connectivity.
+
+ Args:
+ callback: A callable to be invoked and passed a ChannelConnectivity value
+ describing this Channel's connectivity. The callable will be invoked
+ immediately upon subscription and again for every change to this
+ Channel's connectivity thereafter until it is unsubscribed or this
+ Channel object goes out of scope.
+ try_to_connect: A boolean indicating whether or not this Channel should
+ attempt to connect if it is not already connected and ready to conduct
+ RPCs.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def unsubscribe(self, callback):
+ """Unsubscribes a callback from this Channel's connectivity.
+
+ Args:
+ callback: A callable previously registered with this Channel by having
+ been passed to its "subscribe" method.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def unary_unary(
+ self, method, request_serializer=None, response_deserializer=None):
+ """Creates a UnaryUnaryMultiCallable for a unary-unary method.
+
+ Args:
+ method: The name of the RPC method.
+
+ Returns:
+ A UnaryUnaryMultiCallable value for the named unary-unary method.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def unary_stream(
+ self, method, request_serializer=None, response_deserializer=None):
+ """Creates a UnaryStreamMultiCallable for a unary-stream method.
+
+ Args:
+ method: The name of the RPC method.
+
+ Returns:
+ A UnaryStreamMultiCallable value for the named unary-stream method.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def stream_unary(
+ self, method, request_serializer=None, response_deserializer=None):
+ """Creates a StreamUnaryMultiCallable for a stream-unary method.
+
+ Args:
+ method: The name of the RPC method.
+
+ Returns:
+ A StreamUnaryMultiCallable value for the named stream-unary method.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def stream_stream(
+ self, method, request_serializer=None, response_deserializer=None):
+ """Creates a StreamStreamMultiCallable for a stream-stream method.
+
+ Args:
+ method: The name of the RPC method.
+
+ Returns:
+ A StreamStreamMultiCallable value for the named stream-stream method.
+ """
+ raise NotImplementedError()
+
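An invocation-side sketch under stated assumptions: the target address, the method names, and the use of raw bytes (no request serializer or response deserializer supplied) are all hypothetical.

    import grpc

    def _on_connectivity_change(connectivity):
        pass  # e.g. log the new ChannelConnectivity value

    channel = grpc.insecure_channel('localhost:50051')
    channel.subscribe(_on_connectivity_change, try_to_connect=True)

    unary_unary = channel.unary_unary('/example.Example/UnaryCall')
    response = unary_unary(b'request payload', timeout=10)      # blocking invocation
    response_future = unary_unary.future(b'request payload')    # asynchronous invocation
    response = response_future.result()

    unary_stream = channel.unary_stream('/example.Example/StreamingOutputCall')
    for response in unary_stream(b'request payload', timeout=10):
        pass  # each value drawn from the iterator is one response message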
+
+########################## Service-Side Context ##############################
+
+
+class ServicerContext(six.with_metaclass(abc.ABCMeta, RpcContext)):
+ """A context object passed to method implementations."""
+
+ @abc.abstractmethod
+ def invocation_metadata(self):
+ """Accesses the metadata from the invocation-side of the RPC.
+
+ Returns:
+ The invocation metadata object as a sequence of pairs of bytes.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def peer(self):
+ """Identifies the peer that invoked the RPC being serviced.
+
+ Returns:
+ A string identifying the peer that invoked the RPC being serviced.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def send_initial_metadata(self, initial_metadata):
+ """Sends the initial metadata value to the invocation-side of the RPC.
+
+ This method need not be called by method implementations if they have no
+ service-side initial metadata to transmit.
+
+ Args:
+ initial_metadata: The initial metadata of the RPC as a sequence of pairs
+ of bytes.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def set_trailing_metadata(self, trailing_metadata):
+ """Accepts the trailing metadata value of the RPC.
+
+ This method need not be called by method implementations if they have no
+ service-side trailing metadata to transmit.
+
+ Args:
+ trailing_metadata: The trailing metadata of the RPC as a sequence of pairs
+ of bytes.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def set_code(self, code):
+ """Accepts the status code of the RPC.
+
+ This method need not be called by method implementations if they wish the
+ gRPC runtime to determine the status code of the RPC.
+
+ Args:
+ code: The integer status code of the RPC to be transmitted to the
+ invocation side of the RPC.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def set_details(self, details):
+ """Accepts the service-side details of the RPC.
+
+ This method need not be called by method implementations if they have no
+ details to transmit.
+
+ Args:
+ details: The details bytes of the RPC to be transmitted to
+ the invocation side of the RPC.
+ """
+ raise NotImplementedError()
+
+
+##################### Service-Side Handler Interfaces ########################
+
+
+class RpcMethodHandler(six.with_metaclass(abc.ABCMeta)):
+ """An implementation of a single RPC method.
+
+ Attributes:
+ request_streaming: Whether the RPC supports exactly one request message or
+ any arbitrary number of request messages.
+ response_streaming: Whether the RPC supports exactly one response message or
+ any arbitrary number of response messages.
+ request_deserializer: A callable behavior that accepts a byte string and
+ returns an object suitable to be passed to this object's business logic,
+ or None to indicate that this object's business logic should be passed the
+ raw request bytes.
+ response_serializer: A callable behavior that accepts an object produced by
+ this object's business logic and returns a byte string, or None to
+ indicate that the byte strings produced by this object's business logic
+ should be transmitted on the wire as they are.
+ unary_unary: This object's application-specific business logic as a callable
+ value that takes a request value and a ServicerContext object and returns
+ a response value. Only non-None if both request_streaming and
+ response_streaming are False.
+ unary_stream: This object's application-specific business logic as a
+ callable value that takes a request value and a ServicerContext object and
+ returns an iterator of response values. Only non-None if request_streaming
+ is False and response_streaming is True.
+ stream_unary: This object's application-specific business logic as a
+ callable value that takes an iterator of request values and a
+ ServicerContext object and returns a response value. Only non-None if
+ request_streaming is True and response_streaming is False.
+ stream_stream: This object's application-specific business logic as a
+ callable value that takes an iterator of request values and a
+ ServicerContext object and returns an iterator of response values. Only
+ non-None if request_streaming and response_streaming are both True.
+ """
+
+
+class HandlerCallDetails(six.with_metaclass(abc.ABCMeta)):
+ """Describes an RPC that has just arrived for service.
+ Attributes:
+ method: The method name of the RPC.
+ invocation_metadata: The metadata from the invocation side of the RPC.
+ """
+
+
+class GenericRpcHandler(six.with_metaclass(abc.ABCMeta)):
+ """An implementation of arbitrarily many RPC methods."""
+
+ @abc.abstractmethod
+ def service(self, handler_call_details):
+ """Services an RPC (or not).
+
+ Args:
+ handler_call_details: A HandlerCallDetails describing the RPC.
+
+ Returns:
+ An RpcMethodHandler with which the RPC may be serviced, or None to
+ indicate that this object will not be servicing the RPC.
+ """
+ raise NotImplementedError()
+
+
+############################# Server Interface ###############################
+
+
+class Server(six.with_metaclass(abc.ABCMeta)):
+ """Services RPCs."""
+
+ @abc.abstractmethod
+ def add_generic_rpc_handlers(self, generic_rpc_handlers):
+ """Registers GenericRpcHandlers with this Server.
+
+ This method is only safe to call before the server is started.
+
+ Args:
+ generic_rpc_handlers: An iterable of GenericRpcHandlers that will be used
+ to service RPCs after this Server is started.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def add_insecure_port(self, address):
+ """Reserves a port for insecure RPC service once this Server becomes active.
+
+ This method may only be called before this Server's start method is
+ called.
+
+ Args:
+ address: The address for which to open a port.
+
+ Returns:
+ An integer port on which RPCs will be serviced after this link has been
+ started. This is typically the same number as the port number contained
+ in the passed address, but will likely be different if the port number
+ contained in the passed address was zero.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def add_secure_port(self, address, server_credentials):
+ """Reserves a port for secure RPC service after this Server becomes active.
+
+ This method may only be called before this Server's start method is
+ called.
+
+ Args:
+ address: The address for which to open a port.
+ server_credentials: A ServerCredentials.
+
+ Returns:
+ An integer port on which RPCs will be serviced after this link has been
+ started. This is typically the same number as the port number contained
+ in the passed address, but will likely be different if the port number
+ contained in the passed address was zero.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def start(self):
+ """Starts this Server's service of RPCs.
+
+ This method may only be called while the server is not serving RPCs (i.e. it
+ is not idempotent).
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def stop(self, grace):
+ """Stops this Server's service of RPCs.
+
+ All calls to this method immediately stop service of new RPCs. When existing
+ RPCs are aborted is controlled by the grace period parameter passed to this
+ method.
+
+ This method may be called at any time and is idempotent. Passing a smaller
+ grace value than has been passed in a previous call will have the effect of
+ stopping the Server sooner. Passing a larger grace value than has been
+ passed in a previous call will not have the effect of stopping the server
+ later.
+
+ Args:
+ grace: A duration of time in seconds to allow existing RPCs to complete
+ before being aborted by this Server's stopping. If None, this method
+ will block until the server is completely stopped.
+
+ Returns:
+ A threading.Event that will be set when this Server has completely
+ stopped. The returned event may not be set until after the full grace
+ period (if some ongoing RPC continues for the full length of the period)
+ or it may be set much sooner (such as if this Server had no RPCs underway
+ at the time it was stopped or if all RPCs that it had underway completed
+ very early in the grace period).
+ """
+ raise NotImplementedError()
+
+
+################################# Functions ################################
+
+
+def unary_unary_rpc_method_handler(
+ behavior, request_deserializer=None, response_serializer=None):
+ """Creates an RpcMethodHandler for a unary-unary RPC method.
+
+ Args:
+ behavior: The implementation of an RPC method as a callable behavior taking
+ a single request value and returning a single response value.
+ request_deserializer: An optional request deserialization behavior.
+ response_serializer: An optional response serialization behavior.
+
+ Returns:
+ An RpcMethodHandler for a unary-unary RPC method constructed from the given
+ parameters.
+ """
+ from grpc import _utilities
+ return _utilities.RpcMethodHandler(
+ False, False, request_deserializer, response_serializer, behavior, None,
+ None, None)
+
+
+def unary_stream_rpc_method_handler(
+ behavior, request_deserializer=None, response_serializer=None):
+ """Creates an RpcMethodHandler for a unary-stream RPC method.
+
+ Args:
+ behavior: The implementation of an RPC method as a callable behavior taking
+ a single request value and returning an iterator of response values.
+ request_deserializer: An optional request deserialization behavior.
+ response_serializer: An optional response serialization behavior.
+
+ Returns:
+ An RpcMethodHandler for a unary-stream RPC method constructed from the
+ given parameters.
+ """
+ from grpc import _utilities
+ return _utilities.RpcMethodHandler(
+ False, True, request_deserializer, response_serializer, None, behavior,
+ None, None)
+
+
+def stream_unary_rpc_method_handler(
+ behavior, request_deserializer=None, response_serializer=None):
+ """Creates an RpcMethodHandler for a stream-unary RPC method.
+
+ Args:
+ behavior: The implementation of an RPC method as a callable behavior taking
+ an iterator of request values and returning a single response value.
+ request_deserializer: An optional request deserialization behavior.
+ response_serializer: An optional response serialization behavior.
+
+ Returns:
+ An RpcMethodHandler for a stream-unary RPC method constructed from the
+ given parameters.
+ """
+ from grpc import _utilities
+ return _utilities.RpcMethodHandler(
+ True, False, request_deserializer, response_serializer, None, None,
+ behavior, None)
+
+
+def stream_stream_rpc_method_handler(
+ behavior, request_deserializer=None, response_serializer=None):
+ """Creates an RpcMethodHandler for a stream-stream RPC method.
+
+ Args:
+ behavior: The implementation of an RPC method as a callable behavior taking
+ an iterator of request values and returning an iterator of response
+ values.
+ request_deserializer: An optional request deserialization behavior.
+ response_serializer: An optional response serialization behavior.
+
+ Returns:
+ An RpcMethodHandler for a stream-stream RPC method constructed from the
+ given parameters.
+ """
+ from grpc import _utilities
+ return _utilities.RpcMethodHandler(
+ True, True, request_deserializer, response_serializer, None, None, None,
+ behavior)
+
+
+def method_handlers_generic_handler(service, method_handlers):
+ """Creates a grpc.GenericRpcHandler from RpcMethodHandlers.
+
+ Args:
+ service: A service name to be used for the given method handlers.
+ method_handlers: A dictionary from method name to RpcMethodHandler
+ implementing the named method.
+
+ Returns:
+ A GenericRpcHandler constructed from the given parameters.
+ """
+ from grpc import _utilities
+ return _utilities.DictionaryGenericHandler(service, method_handlers)
+
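For example (the service name, method name, and behavior below are hypothetical), a unary-unary handler and a generic handler wrapping it might be built as in this sketch and later passed to the server function defined at the end of this module.

    import grpc

    def _echo(request, servicer_context):
        return request  # unary-unary behavior: one request in, one response out

    echo_handler = grpc.method_handlers_generic_handler(
        'example.Echo',
        {'UnaryEcho': grpc.unary_unary_rpc_method_handler(_echo)})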
+
+def ssl_channel_credentials(
+ root_certificates=None, private_key=None, certificate_chain=None):
+ """Creates a ChannelCredentials for use with an SSL-enabled Channel.
+
+ Args:
+ root_certificates: The PEM-encoded root certificates or unset to ask for
+ them to be retrieved from a default location.
+ private_key: The PEM-encoded private key to use or unset if no private key
+ should be used.
+ certificate_chain: The PEM-encoded certificate chain to use or unset if no
+ certificate chain should be used.
+
+ Returns:
+ A ChannelCredentials for use with an SSL-enabled Channel.
+ """
+ if private_key is not None or certificate_chain is not None:
+ pair = _cygrpc.SslPemKeyCertPair(private_key, certificate_chain)
+ else:
+ pair = None
+ return ChannelCredentials(
+ _cygrpc.channel_credentials_ssl(root_certificates, pair))
+
+
+def metadata_call_credentials(metadata_plugin, name=None):
+ """Construct CallCredentials from an AuthMetadataPlugin.
+
+ Args:
+ metadata_plugin: An AuthMetadataPlugin to use as the authentication behavior
+ in the created CallCredentials.
+ name: A name for the plugin.
+
+ Returns:
+ A CallCredentials.
+ """
+ from grpc import _plugin_wrapping
+ if name is None:
+ try:
+ effective_name = metadata_plugin.__name__
+ except AttributeError:
+ effective_name = metadata_plugin.__class__.__name__
+ else:
+ effective_name = name
+ return CallCredentials(
+ _plugin_wrapping.call_credentials_metadata_plugin(
+ metadata_plugin, effective_name))
+
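Continuing the hypothetical _StaticTokenPlugin sketch from the AuthMetadataPlugin section above, such a plugin would be wrapped like this (the name argument is optional):

    call_credentials = grpc.metadata_call_credentials(
        _StaticTokenPlugin('<token>'), name='static_token_plugin')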
+
+def access_token_call_credentials(access_token):
+ """Construct CallCredentials from an access token.
+
+ Args:
+ access_token: A string to place directly in the http request
+ authorization header, i.e. "Authorization: Bearer <access_token>".
+
+ Returns:
+ A CallCredentials.
+ """
+ from grpc import _auth
+ return metadata_call_credentials(
+ _auth.AccessTokenCallCredentials(access_token))
+
+
+def composite_call_credentials(call_credentials, additional_call_credentials):
+ """Compose two CallCredentials to make a new one.
+
+ Args:
+ call_credentials: A CallCredentials object.
+ additional_call_credentials: Another CallCredentials object to compose on
+ top of call_credentials.
+
+ Returns:
+ A new CallCredentials composed of the two given CallCredentials.
+ """
+ return CallCredentials(
+ _cygrpc.call_credentials_composite(
+ call_credentials._credentials,
+ additional_call_credentials._credentials))
+
+
+def composite_channel_credentials(channel_credentials, call_credentials):
+ """Compose a ChannelCredentials and a CallCredentials.
+
+ Args:
+ channel_credentials: A ChannelCredentials.
+ call_credentials: A CallCredentials.
+
+ Returns:
+ A ChannelCredentials composed of the given ChannelCredentials and
+ CallCredentials.
+ """
+ return ChannelCredentials(
+ _cygrpc.channel_credentials_composite(
+ channel_credentials._credentials, call_credentials._credentials))
+
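A sketch composing the functions above (the token and target are placeholders): SSL channel credentials combined with per-call access-token credentials, then used with the secure_channel function defined below.

    import grpc

    channel_credentials = grpc.ssl_channel_credentials()
    call_credentials = grpc.access_token_call_credentials('<access-token>')
    composite_credentials = grpc.composite_channel_credentials(
        channel_credentials, call_credentials)
    channel = grpc.secure_channel('greeter.example.com:443', composite_credentials)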
+
+def ssl_server_credentials(
+ private_key_certificate_chain_pairs, root_certificates=None,
+ require_client_auth=False):
+ """Creates a ServerCredentials for use with an SSL-enabled Server.
+
+ Args:
+ private_key_certificate_chain_pairs: A nonempty sequence of pairs, each of
+ which consists of a PEM-encoded private key and the corresponding
+ PEM-encoded certificate chain.
+ root_certificates: PEM-encoded client root certificates to be used for
+ verifying authenticated clients. If omitted, require_client_auth must also
+ be omitted or be False.
+ require_client_auth: A boolean indicating whether or not to require clients
+ to be authenticated. May only be True if root_certificates is not None.
+
+ Returns:
+ A ServerCredentials for use with an SSL-enabled Server.
+ """
+ if len(private_key_certificate_chain_pairs) == 0:
+ raise ValueError(
+ 'At least one private key-certificate chain pair is required!')
+ elif require_client_auth and root_certificates is None:
+ raise ValueError(
+ 'Illegal to require client auth without providing root certificates!')
+ else:
+ return ServerCredentials(
+ _cygrpc.server_credentials_ssl(
+ root_certificates,
+ [_cygrpc.SslPemKeyCertPair(key, pem)
+ for key, pem in private_key_certificate_chain_pairs],
+ require_client_auth))
+
+
+def channel_ready_future(channel):
+ """Creates a Future tracking when a Channel is ready.
+
+ Cancelling the returned Future does not tell the given Channel to abandon
+ attempts it may have been making to connect; cancelling merely deactivates the
+ returned Future's subscription to the given Channel's connectivity.
+
+ Args:
+ channel: A Channel.
+
+ Returns:
+ A Future that matures when the given Channel has connectivity
+ ChannelConnectivity.READY.
+ """
+ from grpc import _utilities
+ return _utilities.channel_ready_future(channel)
+
+
+def insecure_channel(target, options=None):
+ """Creates an insecure Channel to a server.
+
+ Args:
+ target: The target to which to connect.
+ options: A sequence of string-value pairs according to which to configure
+ the created channel.
+
+ Returns:
+ A Channel to the target through which RPCs may be conducted.
+ """
+ from grpc import _channel
+ return _channel.Channel(target, options, None)
+
+
+def secure_channel(target, credentials, options=None):
+ """Creates an insecure Channel to a server.
+
+ Args:
+ target: The target to which to connect.
+ credentials: A ChannelCredentials instance.
+ options: A sequence of string-value pairs according to which to configure
+ the created channel.
+
+ Returns:
+ A Channel to the target through which RPCs may be conducted.
+ """
+ from grpc import _channel
+ return _channel.Channel(target, options, credentials)
+
+
+def server(generic_rpc_handlers, thread_pool, options=None):
+ """Creates a Server with which RPCs can be serviced.
+
+ The GenericRpcHandlers passed to this function needn't be the only
+ GenericRpcHandlers that will be used to serve RPCs; others may be added later
+ by calling add_generic_rpc_handlers any time before the returned server is
+ started.
+
+ Args:
+ generic_rpc_handlers: Some number of GenericRpcHandlers that will be used
+ to service RPCs after the returned Server is started.
+ thread_pool: A futures.ThreadPoolExecutor to be used by the returned Server
+ to service RPCs.
+
+ Returns:
+ A Server with which RPCs can be serviced.
+ """
+ from grpc import _server
+ return _server.Server(generic_rpc_handlers, thread_pool)
+
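Pulling the service side together, a sketch reusing the hypothetical echo_handler from the method_handlers_generic_handler example above; the bind address and grace period are arbitrary example values.

    from concurrent import futures

    echo_server = grpc.server(
        (echo_handler,), futures.ThreadPoolExecutor(max_workers=10))
    port = echo_server.add_insecure_port('[::]:0')  # zero requests any free port
    echo_server.start()
    # ... serve RPCs, then allow up to five seconds for in-flight RPCs to finish:
    echo_server.stop(5).wait()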
+
+################################### __all__ #################################
+
+
+__all__ = (
+ 'FutureTimeoutError',
+ 'FutureCancelledError',
+ 'Future',
+ 'ChannelConnectivity',
+ 'StatusCode',
+ 'RpcError',
+ 'RpcContext',
+ 'Call',
+ 'ChannelCredentials',
+ 'CallCredentials',
+ 'AuthMetadataContext',
+ 'AuthMetadataPluginCallback',
+ 'AuthMetadataPlugin',
+ 'ServerCredentials',
+ 'UnaryUnaryMultiCallable',
+ 'UnaryStreamMultiCallable',
+ 'StreamUnaryMultiCallable',
+ 'StreamStreamMultiCallable',
+ 'Channel',
+ 'ServicerContext',
+ 'RpcMethodHandler',
+ 'HandlerCallDetails',
+ 'GenericRpcHandler',
+ 'Server',
+ 'unary_unary_rpc_method_handler',
+ 'unary_stream_rpc_method_handler',
+ 'stream_unary_rpc_method_handler',
+ 'stream_stream_rpc_method_handler',
+ 'method_handlers_generic_handler',
+ 'ssl_channel_credentials',
+ 'metadata_call_credentials',
+ 'access_token_call_credentials',
+ 'composite_call_credentials',
+ 'composite_channel_credentials',
+ 'ssl_server_credentials',
+ 'channel_ready_future',
+ 'insecure_channel',
+ 'secure_channel',
+ 'server',
+)
diff --git a/src/python/grpcio/grpc/_adapter/_low.py b/src/python/grpcio/grpc/_adapter/_low.py
index 00788bd4cf..48410167a0 100644
--- a/src/python/grpcio/grpc/_adapter/_low.py
+++ b/src/python/grpcio/grpc/_adapter/_low.py
@@ -30,8 +30,8 @@
import threading
from grpc import _grpcio_metadata
+from grpc import _plugin_wrapping
from grpc._cython import cygrpc
-from grpc._adapter import _implementations
from grpc._adapter import _types
_USER_AGENT = 'Python-gRPC-{}'.format(_grpcio_metadata.__version__)
@@ -57,78 +57,8 @@ def channel_credentials_ssl(
return cygrpc.channel_credentials_ssl(root_certificates, pair)
-class _WrappedCygrpcCallback(object):
-
- def __init__(self, cygrpc_callback):
- self.is_called = False
- self.error = None
- self.is_called_lock = threading.Lock()
- self.cygrpc_callback = cygrpc_callback
-
- def _invoke_failure(self, error):
- # TODO(atash) translate different Exception superclasses into different
- # status codes.
- self.cygrpc_callback(
- cygrpc.Metadata([]), cygrpc.StatusCode.internal, error.message)
-
- def _invoke_success(self, metadata):
- try:
- cygrpc_metadata = cygrpc.Metadata(
- cygrpc.Metadatum(key, value)
- for key, value in metadata)
- except Exception as error:
- self._invoke_failure(error)
- return
- self.cygrpc_callback(cygrpc_metadata, cygrpc.StatusCode.ok, '')
-
- def __call__(self, metadata, error):
- with self.is_called_lock:
- if self.is_called:
- raise RuntimeError('callback should only ever be invoked once')
- if self.error:
- self._invoke_failure(self.error)
- return
- self.is_called = True
- if error is None:
- self._invoke_success(metadata)
- else:
- self._invoke_failure(error)
-
- def notify_failure(self, error):
- with self.is_called_lock:
- if not self.is_called:
- self.error = error
-
-
-class _WrappedPlugin(object):
-
- def __init__(self, plugin):
- self.plugin = plugin
-
- def __call__(self, context, cygrpc_callback):
- wrapped_cygrpc_callback = _WrappedCygrpcCallback(cygrpc_callback)
- wrapped_context = _implementations.AuthMetadataContext(context.service_url,
- context.method_name)
- try:
- self.plugin(
- wrapped_context,
- _implementations.AuthMetadataPluginCallback(wrapped_cygrpc_callback))
- except Exception as error:
- wrapped_cygrpc_callback.notify_failure(error)
- raise
-
-
-def call_credentials_metadata_plugin(plugin, name):
- """
- Args:
- plugin: A callable accepting a _types.AuthMetadataContext
- object and a callback (itself accepting a list of metadata key/value
- 2-tuples and a None-able exception value). The callback must be eventually
- called, but need not be called in plugin's invocation.
- plugin's invocation must be non-blocking.
- """
- return cygrpc.call_credentials_metadata_plugin(
- cygrpc.CredentialsMetadataPlugin(_WrappedPlugin(plugin), name))
+call_credentials_metadata_plugin = (
+ _plugin_wrapping.call_credentials_metadata_plugin)
class CompletionQueue(_types.CompletionQueue):
diff --git a/src/python/grpcio/grpc/_adapter/_types.py b/src/python/grpcio/grpc/_adapter/_types.py
index f8405949d4..b7cc6fbbb5 100644
--- a/src/python/grpcio/grpc/_adapter/_types.py
+++ b/src/python/grpcio/grpc/_adapter/_types.py
@@ -114,7 +114,7 @@ class ConnectivityState(enum.IntEnum):
CONNECTING = cygrpc.ConnectivityState.connecting
READY = cygrpc.ConnectivityState.ready
TRANSIENT_FAILURE = cygrpc.ConnectivityState.transient_failure
- FATAL_FAILURE = cygrpc.ConnectivityState.fatal_failure
+ FATAL_FAILURE = cygrpc.ConnectivityState.shutdown
class Status(collections.namedtuple(
diff --git a/src/python/grpcio/grpc/_auth.py b/src/python/grpcio/grpc/_auth.py
new file mode 100644
index 0000000000..dea3221c9d
--- /dev/null
+++ b/src/python/grpcio/grpc/_auth.py
@@ -0,0 +1,86 @@
+# Copyright 2016, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""GRPCAuthMetadataPlugins for standard authentication."""
+
+import inspect
+from concurrent import futures
+
+import grpc
+
+
+def _sign_request(callback, token, error):
+ metadata = (('authorization', 'Bearer {}'.format(token)),)
+ callback(metadata, error)
+
+
+class GoogleCallCredentials(grpc.AuthMetadataPlugin):
+ """Metadata wrapper for GoogleCredentials from the oauth2client library."""
+
+ def __init__(self, credentials):
+ self._credentials = credentials
+ self._pool = futures.ThreadPoolExecutor(max_workers=1)
+
+ # Hack to determine if these are JWT creds and we need to pass
+ # additional_claims when getting a token
+ if 'additional_claims' in inspect.getargspec(
+ credentials.get_access_token).args:
+ self._is_jwt = True
+ else:
+ self._is_jwt = False
+
+ def __call__(self, context, callback):
+ # MetadataPlugins cannot block (see grpc.beta.interfaces.py)
+ if self._is_jwt:
+ future = self._pool.submit(self._credentials.get_access_token,
+ additional_claims={'aud': context.service_url})
+ else:
+ future = self._pool.submit(self._credentials.get_access_token)
+ future.add_done_callback(lambda x: self._get_token_callback(callback, x))
+
+ def _get_token_callback(self, callback, future):
+ try:
+ access_token = future.result().access_token
+ except Exception as e:
+ _sign_request(callback, None, e)
+ else:
+ _sign_request(callback, access_token, None)
+
+ def __del__(self):
+ self._pool.shutdown(wait=False)
+
+
+class AccessTokenCallCredentials(grpc.AuthMetadataPlugin):
+ """Metadata wrapper for raw access token credentials."""
+
+ def __init__(self, access_token):
+ self._access_token = access_token
+
+ def __call__(self, context, callback):
+ _sign_request(callback, self._access_token, None)
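For orientation, these plugin classes are consumed through grpc.metadata_call_credentials; the access-token form below is a sketch equivalent to what grpc.access_token_call_credentials does, and GoogleCallCredentials would be wrapped the same way given an oauth2client credentials object.

    import grpc
    from grpc import _auth

    call_credentials = grpc.metadata_call_credentials(
        _auth.AccessTokenCallCredentials('<token>'))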
diff --git a/src/python/grpcio/grpc/_channel.py b/src/python/grpcio/grpc/_channel.py
new file mode 100644
index 0000000000..d9315d2e6c
--- /dev/null
+++ b/src/python/grpcio/grpc/_channel.py
@@ -0,0 +1,852 @@
+# Copyright 2016, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Invocation-side implementation of gRPC Python."""
+
+import sys
+import threading
+import time
+
+import grpc
+from grpc import _common
+from grpc import _grpcio_metadata
+from grpc.framework.foundation import callable_util
+from grpc._cython import cygrpc
+
+_USER_AGENT = 'Python-gRPC-{}'.format(_grpcio_metadata.__version__)
+
+_EMPTY_FLAGS = 0
+_INFINITE_FUTURE = cygrpc.Timespec(float('+inf'))
+_EMPTY_METADATA = cygrpc.Metadata(())
+
+_UNARY_UNARY_INITIAL_DUE = (
+ cygrpc.OperationType.send_initial_metadata,
+ cygrpc.OperationType.send_message,
+ cygrpc.OperationType.send_close_from_client,
+ cygrpc.OperationType.receive_initial_metadata,
+ cygrpc.OperationType.receive_message,
+ cygrpc.OperationType.receive_status_on_client,
+)
+_UNARY_STREAM_INITIAL_DUE = (
+ cygrpc.OperationType.send_initial_metadata,
+ cygrpc.OperationType.send_message,
+ cygrpc.OperationType.send_close_from_client,
+ cygrpc.OperationType.receive_initial_metadata,
+ cygrpc.OperationType.receive_status_on_client,
+)
+_STREAM_UNARY_INITIAL_DUE = (
+ cygrpc.OperationType.send_initial_metadata,
+ cygrpc.OperationType.receive_initial_metadata,
+ cygrpc.OperationType.receive_message,
+ cygrpc.OperationType.receive_status_on_client,
+)
+_STREAM_STREAM_INITIAL_DUE = (
+ cygrpc.OperationType.send_initial_metadata,
+ cygrpc.OperationType.receive_initial_metadata,
+ cygrpc.OperationType.receive_status_on_client,
+)
+
+_CHANNEL_SUBSCRIPTION_CALLBACK_ERROR_LOG_MESSAGE = (
+ 'Exception calling channel subscription callback!')
+
+
+def _deadline(timeout):
+ if timeout is None:
+ return None, _INFINITE_FUTURE
+ else:
+ deadline = time.time() + timeout
+ return deadline, cygrpc.Timespec(deadline)
+
+
+def _unknown_code_details(unknown_cygrpc_code, details):
+ return 'Server sent unknown code {} and details "{}"'.format(
+ unknown_cygrpc_code, details)
+
+
+def _wait_once_until(condition, until):
+ if until is None:
+ condition.wait()
+ else:
+ remaining = until - time.time()
+ if remaining < 0:
+ raise grpc.FutureTimeoutError()
+ else:
+ condition.wait(timeout=remaining)
+
+
+class _RPCState(object):
+
+ def __init__(self, due, initial_metadata, trailing_metadata, code, details):
+ self.condition = threading.Condition()
+ # The cygrpc.OperationType objects representing events due from the RPC's
+ # completion queue.
+ self.due = set(due)
+ self.initial_metadata = initial_metadata
+ self.response = None
+ self.trailing_metadata = trailing_metadata
+ self.code = code
+ self.details = details
+ # The semantics of grpc.Future.cancel and grpc.Future.cancelled are
+ # slightly wonky, so they have to be tracked separately from the rest of the
+ # result of the RPC. This field tracks whether cancellation was requested
+ # prior to termination of the RPC.
+ self.cancelled = False
+ self.callbacks = []
+
+
+def _abort(state, code, details):
+ if state.code is None:
+ state.code = code
+ state.details = details
+ if state.initial_metadata is None:
+ state.initial_metadata = _EMPTY_METADATA
+ state.trailing_metadata = _EMPTY_METADATA
+
+
+def _handle_event(event, state, response_deserializer):
+ callbacks = []
+ for batch_operation in event.batch_operations:
+ operation_type = batch_operation.type
+ state.due.remove(operation_type)
+ if operation_type == cygrpc.OperationType.receive_initial_metadata:
+ state.initial_metadata = batch_operation.received_metadata
+ elif operation_type == cygrpc.OperationType.receive_message:
+ serialized_response = batch_operation.received_message.bytes()
+ if serialized_response is not None:
+ response = _common.deserialize(
+ serialized_response, response_deserializer)
+ if response is None:
+ details = 'Exception deserializing response!'
+ _abort(state, grpc.StatusCode.INTERNAL, details)
+ else:
+ state.response = response
+ elif operation_type == cygrpc.OperationType.receive_status_on_client:
+ state.trailing_metadata = batch_operation.received_metadata
+ if state.code is None:
+ code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE.get(
+ batch_operation.received_status_code)
+ if code is None:
+ state.code = grpc.StatusCode.UNKNOWN
+ state.details = _unknown_code_details(
+ batch_operation.received_status_code,
+ batch_operation.received_status_details)
+ else:
+ state.code = code
+ state.details = batch_operation.received_status_details
+ callbacks.extend(state.callbacks)
+ state.callbacks = None
+ return callbacks
+
+
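+# The handler returned here doubles as the completion-queue tag for the call:
+# it is invoked with each event and returns the call once no operations remain
+# due (and None otherwise), which is the protocol that
+# _create_channel_managed_call and _call_spin below rely on.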
+def _event_handler(state, call, response_deserializer):
+ def handle_event(event):
+ with state.condition:
+ callbacks = _handle_event(event, state, response_deserializer)
+ state.condition.notify_all()
+ done = not state.due
+ for callback in callbacks:
+ callback()
+ return call if done else None
+ return handle_event
+
+
+def _consume_request_iterator(
+ request_iterator, state, call, request_serializer):
+ event_handler = _event_handler(state, call, None)
+ def consume_request_iterator():
+ for request in request_iterator:
+ serialized_request = _common.serialize(request, request_serializer)
+ with state.condition:
+ if state.code is None and not state.cancelled:
+ if serialized_request is None:
+ call.cancel()
+ details = 'Exception serializing request!'
+ _abort(state, grpc.StatusCode.INTERNAL, details)
+ return
+ else:
+ operations = (
+ cygrpc.operation_send_message(
+ serialized_request, _EMPTY_FLAGS),
+ )
+ call.start_batch(cygrpc.Operations(operations), event_handler)
+ state.due.add(cygrpc.OperationType.send_message)
+ while True:
+ state.condition.wait()
+ if state.code is None:
+ if cygrpc.OperationType.send_message not in state.due:
+ break
+ else:
+ return
+ else:
+ return
+ with state.condition:
+ if state.code is None:
+ operations = (
+ cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
+ )
+ call.start_batch(cygrpc.Operations(operations), event_handler)
+ state.due.add(cygrpc.OperationType.send_close_from_client)
+ thread = threading.Thread(target=consume_request_iterator)
+ thread.start()
+
+
+class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call):
+
+ def __init__(self, state, call, response_deserializer, deadline):
+ super(_Rendezvous, self).__init__()
+ self._state = state
+ self._call = call
+ self._response_deserializer = response_deserializer
+ self._deadline = deadline
+
+ def cancel(self):
+ with self._state.condition:
+ if self._state.code is None:
+ self._call.cancel()
+ self._state.cancelled = True
+ _abort(self._state, grpc.StatusCode.CANCELLED, 'Cancelled!')
+ self._state.condition.notify_all()
+ return False
+
+ def cancelled(self):
+ with self._state.condition:
+ return self._state.cancelled
+
+ def running(self):
+ with self._state.condition:
+ return self._state.code is None
+
+ def done(self):
+ with self._state.condition:
+ return self._state.code is not None
+
+ def result(self, timeout=None):
+ until = None if timeout is None else time.time() + timeout
+ with self._state.condition:
+ while True:
+ if self._state.code is None:
+ _wait_once_until(self._state.condition, until)
+ elif self._state.code is grpc.StatusCode.OK:
+ return self._state.response
+ elif self._state.cancelled:
+ raise grpc.FutureCancelledError()
+ else:
+ raise self
+
+ def exception(self, timeout=None):
+ until = None if timeout is None else time.time() + timeout
+ with self._state.condition:
+ while True:
+ if self._state.code is None:
+ _wait_once_until(self._state.condition, until)
+ elif self._state.code is grpc.StatusCode.OK:
+ return None
+ elif self._state.cancelled:
+ raise grpc.FutureCancelledError()
+ else:
+ return self
+
+ def traceback(self, timeout=None):
+ until = None if timeout is None else time.time() + timeout
+ with self._state.condition:
+ while True:
+ if self._state.code is None:
+ _wait_once_until(self._state.condition, until)
+ elif self._state.code is grpc.StatusCode.OK:
+ return None
+ elif self._state.cancelled:
+ raise grpc.FutureCancelledError()
+ else:
+ try:
+ raise self
+ except grpc.RpcError:
+ return sys.exc_info()[2]
+
+ def add_done_callback(self, fn):
+ with self._state.condition:
+ if self._state.code is None:
+ self._state.callbacks.append(lambda: fn(self))
+ return
+
+ fn(self)
+
+ def _next(self):
+ with self._state.condition:
+ if self._state.code is None:
+ event_handler = _event_handler(
+ self._state, self._call, self._response_deserializer)
+ self._call.start_batch(
+ cygrpc.Operations(
+ (cygrpc.operation_receive_message(_EMPTY_FLAGS),)),
+ event_handler)
+ self._state.due.add(cygrpc.OperationType.receive_message)
+ elif self._state.code is grpc.StatusCode.OK:
+ raise StopIteration()
+ else:
+ raise self
+ while True:
+ self._state.condition.wait()
+ if self._state.response is not None:
+ response = self._state.response
+ self._state.response = None
+ return response
+ elif cygrpc.OperationType.receive_message not in self._state.due:
+ if self._state.code is grpc.StatusCode.OK:
+ raise StopIteration()
+ elif self._state.code is not None:
+ raise self
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ return self._next()
+
+ def next(self):
+ return self._next()
+
+ def is_active(self):
+ with self._state.condition:
+ return self._state.code is None
+
+ def time_remaining(self):
+ if self._deadline is None:
+ return None
+ else:
+ return max(self._deadline - time.time(), 0)
+
+ def add_cancellation_callback(self, callback):
+ with self._state.condition:
+ if self._state.callbacks is None:
+ return False
+ else:
+ self._state.callbacks.append(lambda unused_future: callback())
+ return True
+
+ def initial_metadata(self):
+ with self._state.condition:
+ while self._state.initial_metadata is None:
+ self._state.condition.wait()
+ return self._state.initial_metadata
+
+ def trailing_metadata(self):
+ with self._state.condition:
+ while self._state.trailing_metadata is None:
+ self._state.condition.wait()
+ return self._state.trailing_metadata
+
+ def code(self):
+ with self._state.condition:
+ while self._state.code is None:
+ self._state.condition.wait()
+ return self._state.code
+
+ def details(self):
+ with self._state.condition:
+ while self._state.details is None:
+ self._state.condition.wait()
+ return self._state.details
+
+ def _repr(self):
+ with self._state.condition:
+ if self._state.code is None:
+ return '<_Rendezvous object of in-flight RPC>'
+ else:
+ return '<_Rendezvous of RPC that terminated with ({}, {})>'.format(
+ self._state.code, self._state.details)
+
+ def __repr__(self):
+ return self._repr()
+
+ def __str__(self):
+ return self._repr()
+
+ def __del__(self):
+ with self._state.condition:
+ if self._state.code is None:
+ self._call.cancel()
+ self._state.cancelled = True
+ self._state.code = grpc.StatusCode.CANCELLED
+ self._state.condition.notify_all()
+
+
+def _start_unary_request(request, timeout, request_serializer):
+ deadline, deadline_timespec = _deadline(timeout)
+ serialized_request = _common.serialize(request, request_serializer)
+ if serialized_request is None:
+ state = _RPCState(
+ (), _EMPTY_METADATA, _EMPTY_METADATA, grpc.StatusCode.INTERNAL,
+ 'Exception serializing request!')
+ rendezvous = _Rendezvous(state, None, None, deadline)
+ return deadline, deadline_timespec, None, rendezvous
+ else:
+ return deadline, deadline_timespec, serialized_request, None
+
+
+def _end_unary_response_blocking(state, with_call, deadline):
+ if state.code is grpc.StatusCode.OK:
+ if with_call:
+ rendezvous = _Rendezvous(state, None, None, deadline)
+ return state.response, rendezvous
+ else:
+ return state.response
+ else:
+ raise _Rendezvous(state, None, None, deadline)
+
+
+class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
+
+ def __init__(
+ self, channel, create_managed_call, method, request_serializer,
+ response_deserializer):
+ self._channel = channel
+ self._create_managed_call = create_managed_call
+ self._method = method
+ self._request_serializer = request_serializer
+ self._response_deserializer = response_deserializer
+
+ def _prepare(self, request, timeout, metadata):
+ deadline, deadline_timespec, serialized_request, rendezvous = (
+ _start_unary_request(request, timeout, self._request_serializer))
+ if serialized_request is None:
+ return None, None, None, None, rendezvous
+ else:
+ state = _RPCState(_UNARY_UNARY_INITIAL_DUE, None, None, None, None)
+ operations = (
+ cygrpc.operation_send_initial_metadata(
+ _common.metadata(metadata), _EMPTY_FLAGS),
+ cygrpc.operation_send_message(serialized_request, _EMPTY_FLAGS),
+ cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
+ cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
+ cygrpc.operation_receive_message(_EMPTY_FLAGS),
+ cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
+ )
+ return state, operations, deadline, deadline_timespec, None
+
+ def __call__(
+ self, request, timeout=None, metadata=None, credentials=None,
+ with_call=False):
+ state, operations, deadline, deadline_timespec, rendezvous = self._prepare(
+ request, timeout, metadata)
+ if rendezvous:
+ raise rendezvous
+ else:
+ completion_queue = cygrpc.CompletionQueue()
+ call = self._channel.create_call(
+ None, 0, completion_queue, self._method, None, deadline_timespec)
+ if credentials is not None:
+ call.set_credentials(credentials._credentials)
+ call.start_batch(cygrpc.Operations(operations), None)
+ _handle_event(completion_queue.poll(), state, self._response_deserializer)
+ return _end_unary_response_blocking(state, with_call, deadline)
+
+ def future(self, request, timeout=None, metadata=None, credentials=None):
+ state, operations, deadline, deadline_timespec, rendezvous = self._prepare(
+ request, timeout, metadata)
+ if rendezvous:
+ return rendezvous
+ else:
+ call = self._create_managed_call(
+ None, 0, self._method, None, deadline_timespec)
+ if credentials is not None:
+ call.set_credentials(credentials._credentials)
+ event_handler = _event_handler(state, call, self._response_deserializer)
+ with state.condition:
+ call.start_batch(cygrpc.Operations(operations), event_handler)
+ return _Rendezvous(state, call, self._response_deserializer, deadline)
+
+
+class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
+
+ def __init__(
+ self, channel, create_managed_call, method, request_serializer,
+ response_deserializer):
+ self._channel = channel
+ self._create_managed_call = create_managed_call
+ self._method = method
+ self._request_serializer = request_serializer
+ self._response_deserializer = response_deserializer
+
+ def __call__(self, request, timeout=None, metadata=None, credentials=None):
+ deadline, deadline_timespec, serialized_request, rendezvous = (
+ _start_unary_request(request, timeout, self._request_serializer))
+ if serialized_request is None:
+ raise rendezvous
+ else:
+ state = _RPCState(_UNARY_STREAM_INITIAL_DUE, None, None, None, None)
+ call = self._create_managed_call(
+ None, 0, self._method, None, deadline_timespec)
+ if credentials is not None:
+ call.set_credentials(credentials._credentials)
+ event_handler = _event_handler(state, call, self._response_deserializer)
+ with state.condition:
+ call.start_batch(
+ cygrpc.Operations(
+ (cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),)),
+ event_handler)
+ operations = (
+ cygrpc.operation_send_initial_metadata(
+ _common.metadata(metadata), _EMPTY_FLAGS),
+ cygrpc.operation_send_message(serialized_request, _EMPTY_FLAGS),
+ cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
+ cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
+ )
+ call.start_batch(cygrpc.Operations(operations), event_handler)
+ return _Rendezvous(state, call, self._response_deserializer, deadline)
+
+
+class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
+
+ def __init__(
+ self, channel, create_managed_call, method, request_serializer,
+ response_deserializer):
+ self._channel = channel
+ self._create_managed_call = create_managed_call
+ self._method = method
+ self._request_serializer = request_serializer
+ self._response_deserializer = response_deserializer
+
+ def __call__(
+ self, request_iterator, timeout=None, metadata=None, credentials=None,
+ with_call=False):
+ deadline, deadline_timespec = _deadline(timeout)
+ state = _RPCState(_STREAM_UNARY_INITIAL_DUE, None, None, None, None)
+ completion_queue = cygrpc.CompletionQueue()
+ call = self._channel.create_call(
+ None, 0, completion_queue, self._method, None, deadline_timespec)
+ if credentials is not None:
+ call.set_credentials(credentials._credentials)
+ with state.condition:
+ call.start_batch(
+ cygrpc.Operations(
+ (cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),)),
+ None)
+ operations = (
+ cygrpc.operation_send_initial_metadata(
+ _common.metadata(metadata), _EMPTY_FLAGS),
+ cygrpc.operation_receive_message(_EMPTY_FLAGS),
+ cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
+ )
+ call.start_batch(cygrpc.Operations(operations), None)
+ _consume_request_iterator(
+ request_iterator, state, call, self._request_serializer)
+ while True:
+ event = completion_queue.poll()
+ with state.condition:
+ _handle_event(event, state, self._response_deserializer)
+ state.condition.notify_all()
+ if not state.due:
+ break
+ return _end_unary_response_blocking(state, with_call, deadline)
+
+ def future(
+ self, request_iterator, timeout=None, metadata=None, credentials=None):
+ deadline, deadline_timespec = _deadline(timeout)
+ state = _RPCState(_STREAM_UNARY_INITIAL_DUE, None, None, None, None)
+ call = self._create_managed_call(
+ None, 0, self._method, None, deadline_timespec)
+ if credentials is not None:
+ call.set_credentials(credentials._credentials)
+ event_handler = _event_handler(state, call, self._response_deserializer)
+ with state.condition:
+ call.start_batch(
+ cygrpc.Operations(
+ (cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),)),
+ event_handler)
+ operations = (
+ cygrpc.operation_send_initial_metadata(
+ _common.metadata(metadata), _EMPTY_FLAGS),
+ cygrpc.operation_receive_message(_EMPTY_FLAGS),
+ cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
+ )
+ call.start_batch(cygrpc.Operations(operations), event_handler)
+ _consume_request_iterator(
+ request_iterator, state, call, self._request_serializer)
+ return _Rendezvous(state, call, self._response_deserializer, deadline)
+
+
+class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
+
+ def __init__(
+ self, channel, create_managed_call, method, request_serializer,
+ response_deserializer):
+ self._channel = channel
+ self._create_managed_call = create_managed_call
+ self._method = method
+ self._request_serializer = request_serializer
+ self._response_deserializer = response_deserializer
+
+ def __call__(
+ self, request_iterator, timeout=None, metadata=None, credentials=None):
+ deadline, deadline_timespec = _deadline(timeout)
+ state = _RPCState(_STREAM_STREAM_INITIAL_DUE, None, None, None, None)
+ call = self._create_managed_call(
+ None, 0, self._method, None, deadline_timespec)
+ if credentials is not None:
+ call.set_credentials(credentials._credentials)
+ event_handler = _event_handler(state, call, self._response_deserializer)
+ with state.condition:
+ call.start_batch(
+ cygrpc.Operations(
+ (cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),)),
+ event_handler)
+ operations = (
+ cygrpc.operation_send_initial_metadata(
+ _common.metadata(metadata), _EMPTY_FLAGS),
+ cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
+ )
+ call.start_batch(cygrpc.Operations(operations), event_handler)
+ _consume_request_iterator(
+ request_iterator, state, call, self._request_serializer)
+ return _Rendezvous(state, call, self._response_deserializer, deadline)
+
+
+class _ChannelCallState(object):
+
+ def __init__(self, channel):
+ self.lock = threading.Lock()
+ self.channel = channel
+ self.completion_queue = cygrpc.CompletionQueue()
+ self.managed_calls = None
+
+
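+# Spins on the channel's completion queue, invoking each event's tag and
+# retiring calls whose tags return the call (i.e. report completion); the
+# spin thread exits once no managed calls remain.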
+def _call_spin(state):
+ while True:
+ event = state.completion_queue.poll()
+ completed_call = event.tag(event)
+ if completed_call is not None:
+ with state.lock:
+ state.managed_calls.remove(completed_call)
+ if not state.managed_calls:
+ state.managed_calls = None
+ return
+
+
+def _create_channel_managed_call(state):
+ def create_channel_managed_call(parent, flags, method, host, deadline):
+ """Creates a managed cygrpc.Call.
+
+ Callers of this function must conduct at least one operation on the returned
+ call. The tags associated with operations conducted on the returned call must
+ be callables that accept the completion-queue event, and that return None to
+ indicate that this channel should continue polling for events associated with
+ the call, or return the call itself to indicate that no more events
+ associated with the call will be generated.
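+
+ For illustration only, a conforming tag has roughly this shape (the names
+ are hypothetical; the _event_handler helper above produces exactly such
+ callables):
+
+   def tag(event):
+     ...  # inspect event.batch_operations, update RPC state
+     return call if call_is_complete else None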
+
+ Args:
+ parent: A cygrpc.Call to be used as the parent of the created call.
+ flags: An integer bitfield of call flags.
+ method: The RPC method.
+ host: A host string for the created call.
+ deadline: A cygrpc.Timespec to be the deadline of the created call.
+
+ Returns:
+ A cygrpc.Call with which to conduct an RPC.
+ """
+ with state.lock:
+ call = state.channel.create_call(
+ parent, flags, state.completion_queue, method, host, deadline)
+ if state.managed_calls is None:
+ state.managed_calls = set((call,))
+ spin_thread = threading.Thread(target=_call_spin, args=(state,))
+ spin_thread.start()
+ else:
+ state.managed_calls.add(call)
+ return call
+ return create_channel_managed_call
+
+
+class _ChannelConnectivityState(object):
+
+ def __init__(self, channel):
+ self.lock = threading.Lock()
+ self.channel = channel
+ self.polling = False
+ self.connectivity = None
+ self.try_to_connect = False
+ self.callbacks_and_connectivities = []
+ self.delivering = False
+
+
+def _deliveries(state):
+ callbacks_needing_update = []
+ for callback_and_connectivity in state.callbacks_and_connectivities:
+ callback, callback_connectivity = callback_and_connectivity
+ if callback_connectivity is not state.connectivity:
+ callbacks_needing_update.append(callback)
+ callback_and_connectivity[1] = state.connectivity
+ return callbacks_needing_update
+
+
+def _deliver(state, initial_connectivity, initial_callbacks):
+ connectivity = initial_connectivity
+ callbacks = initial_callbacks
+ while True:
+ for callback in callbacks:
+ callable_util.call_logging_exceptions(
+ callback, _CHANNEL_SUBSCRIPTION_CALLBACK_ERROR_LOG_MESSAGE,
+ connectivity)
+ with state.lock:
+ callbacks = _deliveries(state)
+ if callbacks:
+ connectivity = state.connectivity
+ else:
+ state.delivering = False
+ return
+
+
+def _spawn_delivery(state, callbacks):
+ delivering_thread = threading.Thread(
+ target=_deliver, args=(state, state.connectivity, callbacks,))
+ delivering_thread.start()
+ state.delivering = True
+
+
+# NOTE(https://github.com/grpc/grpc/issues/3064): We'd rather not poll.
+def _poll_connectivity(state, channel, initial_try_to_connect):
+ try_to_connect = initial_try_to_connect
+ connectivity = channel.check_connectivity_state(try_to_connect)
+ with state.lock:
+ state.connectivity = (
+ _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[
+ connectivity])
+ callbacks = tuple(
+ callback for callback, unused_but_known_to_be_none_connectivity
+ in state.callbacks_and_connectivities)
+ for callback_and_connectivity in state.callbacks_and_connectivities:
+ callback_and_connectivity[1] = state.connectivity
+ if callbacks:
+ _spawn_delivery(state, callbacks)
+ completion_queue = cygrpc.CompletionQueue()
+ while True:
+ channel.watch_connectivity_state(
+ connectivity, cygrpc.Timespec(time.time() + 0.2),
+ completion_queue, None)
+ event = completion_queue.poll()
+ with state.lock:
+ if not state.callbacks_and_connectivities and not state.try_to_connect:
+ state.polling = False
+ state.connectivity = None
+ break
+ try_to_connect = state.try_to_connect
+ state.try_to_connect = False
+ if event.success or try_to_connect:
+ connectivity = channel.check_connectivity_state(try_to_connect)
+ with state.lock:
+ state.connectivity = (
+ _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[
+ connectivity])
+ if not state.delivering:
+ callbacks = _deliveries(state)
+ if callbacks:
+ _spawn_delivery(state, callbacks)
+
+
+def _subscribe(state, callback, try_to_connect):
+ with state.lock:
+ if not state.callbacks_and_connectivities and not state.polling:
+ polling_thread = threading.Thread(
+ target=_poll_connectivity,
+ args=(state, state.channel, bool(try_to_connect)))
+ polling_thread.start()
+ state.polling = True
+ state.callbacks_and_connectivities.append([callback, None])
+ elif not state.delivering and state.connectivity is not None:
+ _spawn_delivery(state, (callback,))
+ state.try_to_connect |= bool(try_to_connect)
+ state.callbacks_and_connectivities.append(
+ [callback, state.connectivity])
+ else:
+ state.try_to_connect |= bool(try_to_connect)
+ state.callbacks_and_connectivities.append([callback, None])
+
+
+def _unsubscribe(state, callback):
+ with state.lock:
+ for index, (subscribed_callback, unused_connectivity) in enumerate(
+ state.callbacks_and_connectivities):
+ if callback == subscribed_callback:
+ state.callbacks_and_connectivities.pop(index)
+ break
+
+
+def _moot(state):
+ with state.lock:
+ del state.callbacks_and_connectivities[:]
+
+
+def _options(options):
+ if options is None:
+ pairs = ((cygrpc.ChannelArgKey.primary_user_agent_string, _USER_AGENT),)
+ else:
+ pairs = list(options) + [
+ (cygrpc.ChannelArgKey.primary_user_agent_string, _USER_AGENT)]
+ return cygrpc.ChannelArgs(
+ cygrpc.ChannelArg(arg_name, arg_value) for arg_name, arg_value in pairs)
+
+
+class Channel(grpc.Channel):
+
+ def __init__(self, target, options, credentials):
+ self._channel = cygrpc.Channel(target, _options(options), credentials)
+ self._call_state = _ChannelCallState(self._channel)
+ self._connectivity_state = _ChannelConnectivityState(self._channel)
+
+ def subscribe(self, callback, try_to_connect=None):
+ _subscribe(self._connectivity_state, callback, try_to_connect)
+
+ def unsubscribe(self, callback):
+ _unsubscribe(self._connectivity_state, callback)
+
+ def unary_unary(
+ self, method, request_serializer=None, response_deserializer=None):
+ return _UnaryUnaryMultiCallable(
+ self._channel, _create_channel_managed_call(self._call_state), method,
+ request_serializer, response_deserializer)
+
+ def unary_stream(
+ self, method, request_serializer=None, response_deserializer=None):
+ return _UnaryStreamMultiCallable(
+ self._channel, _create_channel_managed_call(self._call_state), method,
+ request_serializer, response_deserializer)
+
+ def stream_unary(
+ self, method, request_serializer=None, response_deserializer=None):
+ return _StreamUnaryMultiCallable(
+ self._channel, _create_channel_managed_call(self._call_state), method,
+ request_serializer, response_deserializer)
+
+ def stream_stream(
+ self, method, request_serializer=None, response_deserializer=None):
+ return _StreamStreamMultiCallable(
+ self._channel, _create_channel_managed_call(self._call_state), method,
+ request_serializer, response_deserializer)
+
+ def __del__(self):
+ _moot(self._connectivity_state)
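
Taken together, Channel and the multi-callables above form the invocation-side
surface of the new API. A rough usage sketch follows; it is illustrative only:
the address and method path are hypothetical, serializers are omitted so
requests and responses are raw bytes, and a server is assumed to be listening
(application code would normally obtain a channel through the public
grpc.insecure_channel helper rather than constructing Channel directly):

    from grpc._channel import Channel

    channel = Channel('localhost:50051', None, None)  # no credentials: insecure
    say_hello = channel.unary_unary('/example.Greeter/SayHello')

    # Blocking invocation: returns the response bytes or raises a
    # _Rendezvous, which is a grpc.RpcError.
    response = say_hello(b'serialized-request', timeout=5.0)

    # Future-style invocation: returns a _Rendezvous usable as a grpc.Future.
    response_future = say_hello.future(b'serialized-request', timeout=5.0)
    response = response_future.result()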
diff --git a/src/python/grpcio/grpc/_common.py b/src/python/grpcio/grpc/_common.py
new file mode 100644
index 0000000000..f351bea9e3
--- /dev/null
+++ b/src/python/grpcio/grpc/_common.py
@@ -0,0 +1,154 @@
+# Copyright 2016, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Shared implementation."""
+
+import logging
+import threading
+import time
+
+import six
+
+import grpc
+from grpc._cython import cygrpc
+
+_EMPTY_METADATA = cygrpc.Metadata(())
+
+CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY = {
+ cygrpc.ConnectivityState.idle: grpc.ChannelConnectivity.IDLE,
+ cygrpc.ConnectivityState.connecting: grpc.ChannelConnectivity.CONNECTING,
+ cygrpc.ConnectivityState.ready: grpc.ChannelConnectivity.READY,
+ cygrpc.ConnectivityState.transient_failure:
+ grpc.ChannelConnectivity.TRANSIENT_FAILURE,
+ cygrpc.ConnectivityState.shutdown:
+ grpc.ChannelConnectivity.SHUTDOWN,
+}
+
+CYGRPC_STATUS_CODE_TO_STATUS_CODE = {
+ cygrpc.StatusCode.ok: grpc.StatusCode.OK,
+ cygrpc.StatusCode.cancelled: grpc.StatusCode.CANCELLED,
+ cygrpc.StatusCode.unknown: grpc.StatusCode.UNKNOWN,
+ cygrpc.StatusCode.invalid_argument: grpc.StatusCode.INVALID_ARGUMENT,
+ cygrpc.StatusCode.deadline_exceeded: grpc.StatusCode.DEADLINE_EXCEEDED,
+ cygrpc.StatusCode.not_found: grpc.StatusCode.NOT_FOUND,
+ cygrpc.StatusCode.already_exists: grpc.StatusCode.ALREADY_EXISTS,
+ cygrpc.StatusCode.permission_denied: grpc.StatusCode.PERMISSION_DENIED,
+ cygrpc.StatusCode.unauthenticated: grpc.StatusCode.UNAUTHENTICATED,
+ cygrpc.StatusCode.resource_exhausted: grpc.StatusCode.RESOURCE_EXHAUSTED,
+ cygrpc.StatusCode.failed_precondition: grpc.StatusCode.FAILED_PRECONDITION,
+ cygrpc.StatusCode.aborted: grpc.StatusCode.ABORTED,
+ cygrpc.StatusCode.out_of_range: grpc.StatusCode.OUT_OF_RANGE,
+ cygrpc.StatusCode.unimplemented: grpc.StatusCode.UNIMPLEMENTED,
+ cygrpc.StatusCode.internal: grpc.StatusCode.INTERNAL,
+ cygrpc.StatusCode.unavailable: grpc.StatusCode.UNAVAILABLE,
+ cygrpc.StatusCode.data_loss: grpc.StatusCode.DATA_LOSS,
+}
+STATUS_CODE_TO_CYGRPC_STATUS_CODE = {
+ grpc_code: cygrpc_code
+ for cygrpc_code, grpc_code in six.iteritems(
+ CYGRPC_STATUS_CODE_TO_STATUS_CODE)
+}
+
+
+def metadata(application_metadata):
+ return _EMPTY_METADATA if application_metadata is None else cygrpc.Metadata(
+ cygrpc.Metadatum(key, value) for key, value in application_metadata)
+
+
+def _transform(message, transformer, exception_message):
+ if transformer is None:
+ return message
+ else:
+ try:
+ return transformer(message)
+ except Exception: # pylint: disable=broad-except
+ logging.exception(exception_message)
+ return None
+
+
+def serialize(message, serializer):
+ return _transform(message, serializer, 'Exception serializing message!')
+
+
+def deserialize(serialized_message, deserializer):
+ return _transform(serialized_message, deserializer,
+ 'Exception deserializing message!')
+
+
+def _encode(s):
+ if isinstance(s, bytes):
+ return s
+ else:
+ return s.encode('ascii')
+
+
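+# For example, fully_qualified_method('foo.Bar', 'Baz') evaluates to
+# b'/foo.Bar/Baz'; both arguments may be str or bytes.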
+def fully_qualified_method(group, method):
+ group = _encode(group)
+ method = _encode(method)
+ return b'/' + group + b'/' + method
+
+
+class CleanupThread(threading.Thread):
+ """A threading.Thread subclass supporting custom behavior on join().
+
+ On Python interpreter exit, Python will attempt to join outstanding threads
+ prior to garbage collection. We may need to do additional cleanup, and
+ we accomplish this by overriding the join() method.
+ """
+
+ def __init__(self, behavior, group=None, target=None, name=None,
+ args=(), kwargs={}):
+ """Constructor.
+
+ Args:
+ behavior (function): Function called on join() with a single
+ argument, timeout, indicating the maximum duration of
+ `behavior`, or None indicating `behavior` has no deadline.
+ `behavior` must be idempotent.
+ group (None): should be None. Reserved for future extensions
+ when ThreadGroup is implemented.
+ target (function): The function to invoke when this thread is
+ run. Defaults to None.
+ name (str): The name of this thread. Defaults to None.
+ args (tuple[object]): A tuple of arguments to pass to `target`.
+ kwargs (dict[str,object]): A dictionary of keyword arguments to
+ pass to `target`.
+ """
+ super(CleanupThread, self).__init__(group=group, target=target,
+ name=name, args=args, kwargs=kwargs)
+ self._behavior = behavior
+
+ def join(self, timeout=None):
+ start_time = time.time()
+ self._behavior(timeout)
+ end_time = time.time()
+ if timeout is not None:
+ timeout -= end_time - start_time
+ timeout = max(timeout, 0)
+ super(CleanupThread, self).join(timeout)
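
A minimal usage sketch of CleanupThread (illustrative only; the worker and
stop_event names below are hypothetical, not part of this change): join() first
runs the supplied cleanup behavior within the caller's timeout budget, then
delegates to the ordinary Thread.join with whatever time remains:

    import threading
    import time

    from grpc._common import CleanupThread

    stop_event = threading.Event()

    def worker():
        while not stop_event.is_set():
            time.sleep(0.1)  # stand-in for real work

    def cleanup(timeout):
        # Must be idempotent; ask the worker to stop, bounded by `timeout`.
        stop_event.set()

    thread = CleanupThread(cleanup, target=worker)
    thread.start()
    thread.join(timeout=1.0)  # runs cleanup(1.0), then joins with the remainder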
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi
index 1bfe6344e0..a09bbc75fe 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi
@@ -55,6 +55,7 @@ cdef class Call:
def cancel(
self, grpc_status_code error_code=GRPC_STATUS__DO_NOT_USE,
details=None):
+ details = str_to_bytes(details)
if not self.is_valid:
raise ValueError("invalid call object cannot be used from Python")
if (details is None) != (error_code == GRPC_STATUS__DO_NOT_USE):
@@ -63,12 +64,6 @@ cdef class Call:
cdef grpc_call_error result
cdef char *c_details = NULL
if error_code != GRPC_STATUS__DO_NOT_USE:
- if isinstance(details, bytes):
- pass
- elif isinstance(details, basestring):
- details = details.encode()
- else:
- raise TypeError("expected details to be str or bytes")
self.references.append(details)
c_details = details
with nogil:
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi
index c26bc083cf..866cff0d01 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi
@@ -34,18 +34,13 @@ cdef class Channel:
def __cinit__(self, target, ChannelArgs arguments=None,
ChannelCredentials channel_credentials=None):
+ target = str_to_bytes(target)
cdef grpc_channel_args *c_arguments = NULL
cdef char *c_target = NULL
self.c_channel = NULL
self.references = []
if arguments is not None:
c_arguments = &arguments.c_args
- if isinstance(target, bytes):
- pass
- elif isinstance(target, basestring):
- target = target.encode()
- else:
- raise TypeError("expected target to be str or bytes")
c_target = target
if channel_credentials is None:
with nogil:
@@ -62,25 +57,14 @@ cdef class Channel:
def create_call(self, Call parent, int flags,
CompletionQueue queue not None,
method, host, Timespec deadline not None):
+ method = str_to_bytes(method)
+ host = str_to_bytes(host)
if queue.is_shutting_down:
raise ValueError("queue must not be shutting down or shutdown")
- if isinstance(method, bytes):
- pass
- elif isinstance(method, basestring):
- method = method.encode()
- else:
- raise TypeError("expected method to be str or bytes")
cdef char *method_c_string = method
cdef char *host_c_string = NULL
- if host is None:
- pass
- elif isinstance(host, bytes):
+ if host is not None:
host_c_string = host
- elif isinstance(host, basestring):
- host = host.encode()
- host_c_string = host
- else:
- raise TypeError("expected host to be str, bytes, or None")
cdef Call operation_call = Call()
operation_call.references = [self, method, host, queue]
cdef grpc_call *parent_call = NULL
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pxd.pxi
index a67c963684..01089c3dc0 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pxd.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pxd.pxi
@@ -31,9 +31,6 @@
cdef class CompletionQueue:
cdef grpc_completion_queue *c_completion_queue
- cdef object pluck_condition
- cdef int num_plucking
- cdef int num_polling
cdef bint is_shutting_down
cdef bint is_shutdown
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi
index cdae39d519..90266516fe 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi
@@ -32,6 +32,8 @@ cimport cpython
import threading
import time
+cdef int _INTERRUPT_CHECK_PERIOD_MS = 200
+
cdef class CompletionQueue:
@@ -40,9 +42,6 @@ cdef class CompletionQueue:
self.c_completion_queue = grpc_completion_queue_create(NULL)
self.is_shutting_down = False
self.is_shutdown = False
- self.pluck_condition = threading.Condition()
- self.num_plucking = 0
- self.num_polling = 0
cdef _interpret_event(self, grpc_event event):
cdef OperationTag tag = None
@@ -83,45 +82,27 @@ cdef class CompletionQueue:
def poll(self, Timespec deadline=None):
# We name this 'poll' to avoid problems with CPython's expectations for
# 'special' methods (like next and __next__).
+ cdef gpr_timespec c_increment
+ cdef gpr_timespec c_timeout
cdef gpr_timespec c_deadline
with nogil:
+ c_increment = gpr_time_from_millis(_INTERRUPT_CHECK_PERIOD_MS, GPR_TIMESPAN)
c_deadline = gpr_inf_future(GPR_CLOCK_REALTIME)
- if deadline is not None:
- c_deadline = deadline.c_time
- cdef grpc_event event
-
- # Poll within a critical section to detect contention
- with self.pluck_condition:
- assert self.num_plucking == 0, 'cannot simultaneously pluck and poll'
- self.num_polling += 1
- with nogil:
- event = grpc_completion_queue_next(
- self.c_completion_queue, c_deadline, NULL)
- with self.pluck_condition:
- self.num_polling -= 1
- return self._interpret_event(event)
-
- def pluck(self, OperationTag tag, Timespec deadline=None):
- # Plucking a 'None' tag is equivalent to passing control to GRPC core until
- # the deadline.
- cdef gpr_timespec c_deadline = gpr_inf_future(
- GPR_CLOCK_REALTIME)
- if deadline is not None:
- c_deadline = deadline.c_time
- cdef grpc_event event
-
- # Pluck within a critical section to detect contention
- with self.pluck_condition:
- assert self.num_polling == 0, 'cannot simultaneously pluck and poll'
- assert self.num_plucking < GRPC_MAX_COMPLETION_QUEUE_PLUCKERS, (
- 'cannot pluck more than {} times simultaneously'.format(
- GRPC_MAX_COMPLETION_QUEUE_PLUCKERS))
- self.num_plucking += 1
- with nogil:
- event = grpc_completion_queue_pluck(
- self.c_completion_queue, <cpython.PyObject *>tag, c_deadline, NULL)
- with self.pluck_condition:
- self.num_plucking -= 1
+ if deadline is not None:
+ c_deadline = deadline.c_time
+
+ while True:
+ c_timeout = gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), c_increment)
+ if gpr_time_cmp(c_timeout, c_deadline) > 0:
+ c_timeout = c_deadline
+ event = grpc_completion_queue_next(
+ self.c_completion_queue, c_timeout, NULL)
+ if event.type != GRPC_QUEUE_TIMEOUT or gpr_time_cmp(c_timeout, c_deadline) == 0:
+ break
+
+ # Handle any signals
+ with gil:
+ cpython.PyErr_CheckSignals()
return self._interpret_event(event)
def shutdown(self):
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd.pxi
index c793c8f5e5..d377a67520 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd.pxi
@@ -54,7 +54,7 @@ cdef class ServerCredentials:
cdef class CredentialsMetadataPlugin:
cdef object plugin_callback
- cdef str plugin_name
+ cdef bytes plugin_name
cdef grpc_metadata_credentials_plugin make_c_plugin(self)
@@ -68,4 +68,4 @@ cdef void plugin_get_metadata(
void *state, grpc_auth_metadata_context context,
grpc_credentials_plugin_metadata_cb cb, void *user_data) with gil
-cdef void plugin_destroy_c_plugin_state(void *state)
+cdef void plugin_destroy_c_plugin_state(void *state) with gil
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi
index 94d13b5999..470382d609 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi
@@ -82,7 +82,7 @@ cdef class ServerCredentials:
cdef class CredentialsMetadataPlugin:
- def __cinit__(self, object plugin_callback, str name):
+ def __cinit__(self, object plugin_callback, name):
"""
Args:
plugin_callback (callable): Callback accepting a service URL (str/bytes)
@@ -93,6 +93,7 @@ cdef class CredentialsMetadataPlugin:
successful).
name (str): Plugin name.
"""
+ name = str_to_bytes(name)
if not callable(plugin_callback):
raise ValueError('expected callable plugin_callback')
self.plugin_callback = plugin_callback
@@ -129,7 +130,8 @@ cdef void plugin_get_metadata(
grpc_credentials_plugin_metadata_cb cb, void *user_data) with gil:
def python_callback(
Metadata metadata, grpc_status_code status,
- const char *error_details):
+ error_details):
+ error_details = str_to_bytes(error_details)
cb(user_data, metadata.c_metadata_array.metadata,
metadata.c_metadata_array.count, status, error_details)
cdef CredentialsMetadataPlugin self = <CredentialsMetadataPlugin>state
@@ -137,7 +139,7 @@ cdef void plugin_get_metadata(
cy_context.context = context
self.plugin_callback(cy_context, python_callback)
-cdef void plugin_destroy_c_plugin_state(void *state):
+cdef void plugin_destroy_c_plugin_state(void *state) with gil:
cpython.Py_DECREF(<CredentialsMetadataPlugin>state)
def channel_credentials_google_default():
@@ -148,14 +150,7 @@ def channel_credentials_google_default():
def channel_credentials_ssl(pem_root_certificates,
SslPemKeyCertPair ssl_pem_key_cert_pair):
- if pem_root_certificates is None:
- pass
- elif isinstance(pem_root_certificates, bytes):
- pass
- elif isinstance(pem_root_certificates, basestring):
- pem_root_certificates = pem_root_certificates.encode()
- else:
- raise TypeError("expected str or bytes for pem_root_certificates")
+ pem_root_certificates = str_to_bytes(pem_root_certificates)
cdef ChannelCredentials credentials = ChannelCredentials()
cdef const char *c_pem_root_certificates = NULL
if pem_root_certificates is not None:
@@ -207,12 +202,7 @@ def call_credentials_google_compute_engine():
def call_credentials_service_account_jwt_access(
json_key, Timespec token_lifetime not None):
- if isinstance(json_key, bytes):
- pass
- elif isinstance(json_key, basestring):
- json_key = json_key.encode()
- else:
- raise TypeError("expected json_key to be str or bytes")
+ json_key = str_to_bytes(json_key)
cdef CallCredentials credentials = CallCredentials()
cdef char *json_key_c_string = json_key
with nogil:
@@ -223,12 +213,7 @@ def call_credentials_service_account_jwt_access(
return credentials
def call_credentials_google_refresh_token(json_refresh_token):
- if isinstance(json_refresh_token, bytes):
- pass
- elif isinstance(json_refresh_token, basestring):
- json_refresh_token = json_refresh_token.encode()
- else:
- raise TypeError("expected json_refresh_token to be str or bytes")
+ json_refresh_token = str_to_bytes(json_refresh_token)
cdef CallCredentials credentials = CallCredentials()
cdef char *json_refresh_token_c_string = json_refresh_token
with nogil:
@@ -238,18 +223,8 @@ def call_credentials_google_refresh_token(json_refresh_token):
return credentials
def call_credentials_google_iam(authorization_token, authority_selector):
- if isinstance(authorization_token, bytes):
- pass
- elif isinstance(authorization_token, basestring):
- authorization_token = authorization_token.encode()
- else:
- raise TypeError("expected authorization_token to be str or bytes")
- if isinstance(authority_selector, bytes):
- pass
- elif isinstance(authority_selector, basestring):
- authority_selector = authority_selector.encode()
- else:
- raise TypeError("expected authority_selector to be str or bytes")
+ authorization_token = str_to_bytes(authorization_token)
+ authority_selector = str_to_bytes(authority_selector)
cdef CallCredentials credentials = CallCredentials()
cdef char *authorization_token_c_string = authorization_token
cdef char *authority_selector_c_string = authority_selector
@@ -272,16 +247,10 @@ def call_credentials_metadata_plugin(CredentialsMetadataPlugin plugin):
def server_credentials_ssl(pem_root_certs, pem_key_cert_pairs,
bint force_client_auth):
+ pem_root_certs = str_to_bytes(pem_root_certs)
cdef char *c_pem_root_certs = NULL
- if pem_root_certs is None:
- pass
- elif isinstance(pem_root_certs, bytes):
- c_pem_root_certs = pem_root_certs
- elif isinstance(pem_root_certs, basestring):
- pem_root_certs = pem_root_certs.encode()
+ if pem_root_certs is not None:
c_pem_root_certs = pem_root_certs
- else:
- raise TypeError("expected pem_root_certs to be str or bytes")
pem_key_cert_pairs = list(pem_key_cert_pairs)
for pair in pem_key_cert_pairs:
if not isinstance(pair, SslPemKeyCertPair):
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi
index 66e6e6b549..168b9751aa 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi
@@ -80,6 +80,12 @@ cdef extern from "grpc/_cython/loader.h":
gpr_timespec gpr_convert_clock_type(gpr_timespec t,
gpr_clock_type target_clock) nogil
+ gpr_timespec gpr_time_from_millis(int64_t ms, gpr_clock_type type) nogil
+
+ gpr_timespec gpr_time_add(gpr_timespec a, gpr_timespec b) nogil
+
+ int gpr_time_cmp(gpr_timespec a, gpr_timespec b) nogil
+
ctypedef enum grpc_status_code:
GRPC_STATUS_OK
GRPC_STATUS_CANCELLED
@@ -208,7 +214,7 @@ cdef extern from "grpc/_cython/loader.h":
GRPC_CHANNEL_CONNECTING
GRPC_CHANNEL_READY
GRPC_CHANNEL_TRANSIENT_FAILURE
- GRPC_CHANNEL_FATAL_FAILURE
+ GRPC_CHANNEL_SHUTDOWN
ctypedef struct grpc_metadata:
const char *key
@@ -336,6 +342,8 @@ cdef extern from "grpc/_cython/loader.h":
void grpc_server_register_completion_queue(grpc_server *server,
grpc_completion_queue *cq,
void *reserved) nogil
+ void grpc_server_register_non_listening_completion_queue(
+ grpc_server *server, grpc_completion_queue *cq, void *reserved) nogil
int grpc_server_add_insecure_http2_port(
grpc_server *server, const char *addr) nogil
void grpc_server_start(grpc_server *server) nogil
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/grpc_string.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/grpc_string.pyx.pxi
new file mode 100644
index 0000000000..274f038e0a
--- /dev/null
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/grpc_string.pyx.pxi
@@ -0,0 +1,39 @@
+# Copyright 2016, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+# This function ASCII-encodes unicode string inputs if necessary.
+# In Python3, unicode strings are the default str type.
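+# For example: str_to_bytes(None) is None, str_to_bytes(b'abc') == b'abc', and
+# str_to_bytes(u'abc') == b'abc'; non-ASCII text raises UnicodeEncodeError and
+# any other type raises TypeError.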
+cdef bytes str_to_bytes(object s):
+ if s is None or isinstance(s, bytes):
+ return s
+ elif isinstance(s, unicode):
+ return s.encode('ascii')
+ else:
+ raise TypeError('Expected bytes, str, or unicode, not {}'.format(type(s)))
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi
index c7539f0d49..0055d0d3a2 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi
@@ -33,7 +33,7 @@ class ConnectivityState:
connecting = GRPC_CHANNEL_CONNECTING
ready = GRPC_CHANNEL_READY
transient_failure = GRPC_CHANNEL_TRANSIENT_FAILURE
- fatal_failure = GRPC_CHANNEL_FATAL_FAILURE
+ shutdown = GRPC_CHANNEL_SHUTDOWN
class ChannelArgKey:
@@ -235,18 +235,13 @@ cdef class ByteBuffer:
if data is None:
self.c_byte_buffer = NULL
return
- if isinstance(data, bytes):
- pass
- elif isinstance(data, basestring):
- data = data.encode()
- elif isinstance(data, ByteBuffer):
+ if isinstance(data, ByteBuffer):
data = (<ByteBuffer>data).bytes()
if data is None:
self.c_byte_buffer = NULL
return
else:
- raise TypeError("expected value to be of type str, bytes, or "
- "ByteBuffer, not {}".format(type(data)))
+ data = str_to_bytes(data)
cdef char *c_data = data
cdef gpr_slice data_slice
@@ -274,6 +269,7 @@ cdef class ByteBuffer:
data_slice_length = gpr_slice_length(data_slice)
with gil:
result += (<char *>data_slice_pointer)[:data_slice_length]
+ gpr_slice_unref(data_slice)
with nogil:
grpc_byte_buffer_reader_destroy(&reader)
return bytes(result)
@@ -301,19 +297,8 @@ cdef class ByteBuffer:
cdef class SslPemKeyCertPair:
def __cinit__(self, private_key, certificate_chain):
- if isinstance(private_key, bytes):
- self.private_key = private_key
- elif isinstance(private_key, basestring):
- self.private_key = private_key.encode()
- else:
- raise TypeError("expected private_key to be of type str or bytes")
- if isinstance(certificate_chain, bytes):
- self.certificate_chain = certificate_chain
- elif isinstance(certificate_chain, basestring):
- self.certificate_chain = certificate_chain.encode()
- else:
- raise TypeError("expected certificate_chain to be of type str or bytes "
- "or int")
+ self.private_key = str_to_bytes(private_key)
+ self.certificate_chain = str_to_bytes(certificate_chain)
self.c_pair.private_key = self.private_key
self.c_pair.certificate_chain = self.certificate_chain
@@ -321,27 +306,16 @@ cdef class SslPemKeyCertPair:
cdef class ChannelArg:
def __cinit__(self, key, value):
- if isinstance(key, bytes):
- self.key = key
- elif isinstance(key, basestring):
- self.key = key.encode()
- else:
- raise TypeError("expected key to be of type str or bytes")
- if isinstance(value, bytes):
- self.value = value
- self.c_arg.type = GRPC_ARG_STRING
- self.c_arg.value.string = self.value
- elif isinstance(value, basestring):
- self.value = value.encode()
- self.c_arg.type = GRPC_ARG_STRING
- self.c_arg.value.string = self.value
- elif isinstance(value, int):
+ self.key = str_to_bytes(key)
+ self.c_arg.key = self.key
+ if isinstance(value, int):
self.value = int(value)
self.c_arg.type = GRPC_ARG_INTEGER
self.c_arg.value.integer = self.value
else:
- raise TypeError("expected value to be of type str or bytes or int")
- self.c_arg.key = self.key
+ self.value = str_to_bytes(value)
+ self.c_arg.type = GRPC_ARG_STRING
+ self.c_arg.value.string = self.value
cdef class ChannelArgs:
@@ -374,18 +348,8 @@ cdef class ChannelArgs:
cdef class Metadatum:
def __cinit__(self, key, value):
- if isinstance(key, bytes):
- self._key = key
- elif isinstance(key, basestring):
- self._key = key.encode()
- else:
- raise TypeError("expected key to be of type str or bytes")
- if isinstance(value, bytes):
- self._value = value
- elif isinstance(value, basestring):
- self._value = value.encode()
- else:
- raise TypeError("expected value to be of type str or bytes")
+ self._key = str_to_bytes(key)
+ self._value = str_to_bytes(value)
self.c_metadata.key = self._key
self.c_metadata.value = self._value
self.c_metadata.value_length = len(self._value)
@@ -600,12 +564,7 @@ def operation_send_close_from_client(int flags):
def operation_send_status_from_server(
Metadata metadata, grpc_status_code code, details, int flags):
- if isinstance(details, bytes):
- pass
- elif isinstance(details, basestring):
- details = details.encode()
- else:
- raise TypeError("expected a str or bytes object for details")
+ details = str_to_bytes(details)
cdef Operation op = Operation()
op.c_op.type = GRPC_OP_SEND_STATUS_FROM_SERVER
op.c_op.flags = flags
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi
index 8419a59068..42afeb8498 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi
@@ -81,25 +81,29 @@ cdef class Server:
self.c_server, queue.c_completion_queue, NULL)
self.registered_completion_queues.append(queue)
+ def register_non_listening_completion_queue(
+ self, CompletionQueue queue not None):
+ if self.is_started:
+ raise ValueError("cannot register completion queues after start")
+ with nogil:
+ grpc_server_register_non_listening_completion_queue(
+ self.c_server, queue.c_completion_queue, NULL)
+ self.registered_completion_queues.append(queue)
+
def start(self):
if self.is_started:
raise ValueError("the server has already started")
self.backup_shutdown_queue = CompletionQueue()
- self.register_completion_queue(self.backup_shutdown_queue)
+ self.register_non_listening_completion_queue(self.backup_shutdown_queue)
self.is_started = True
with nogil:
grpc_server_start(self.c_server)
# Ensure the core has gotten a chance to do the start-up work
- self.backup_shutdown_queue.pluck(None, Timespec(None))
+ self.backup_shutdown_queue.poll(Timespec(None))
def add_http2_port(self, address,
ServerCredentials server_credentials=None):
- if isinstance(address, bytes):
- pass
- elif isinstance(address, basestring):
- address = address.encode()
- else:
- raise TypeError("expected address to be a str or bytes")
+ address = str_to_bytes(address)
self.references.append(address)
cdef int result
cdef char *address_c_string = address
@@ -169,4 +173,3 @@ cdef class Server:
time.sleep(0)
with nogil:
grpc_server_destroy(self.c_server)
-
diff --git a/src/python/grpcio/grpc/_cython/cygrpc.pyx b/src/python/grpcio/grpc/_cython/cygrpc.pyx
index 8823ea5ef5..cf146f5a04 100644
--- a/src/python/grpcio/grpc/_cython/cygrpc.pyx
+++ b/src/python/grpcio/grpc/_cython/cygrpc.pyx
@@ -35,6 +35,7 @@ import sys
# TODO(atash): figure out why the coverage tool gets confused about the Cython
# coverage plugin when the following files don't have a '.pxi' suffix.
+include "grpc/_cython/_cygrpc/grpc_string.pyx.pxi"
include "grpc/_cython/_cygrpc/call.pyx.pxi"
include "grpc/_cython/_cygrpc/channel.pyx.pxi"
include "grpc/_cython/_cygrpc/credentials.pyx.pxi"
diff --git a/src/python/grpcio/grpc/_cython/imports.generated.c b/src/python/grpcio/grpc/_cython/imports.generated.c
index 3cf4cb491a..8437e74ba0 100644
--- a/src/python/grpcio/grpc/_cython/imports.generated.c
+++ b/src/python/grpcio/grpc/_cython/imports.generated.c
@@ -35,7 +35,7 @@
#include "imports.generated.h"
-#ifdef GPR_WIN32
+#ifdef GPR_WINDOWS
census_initialize_type census_initialize_import;
census_shutdown_type census_shutdown_import;
@@ -126,7 +126,8 @@ grpc_header_key_is_legal_type grpc_header_key_is_legal_import;
grpc_header_nonbin_value_is_legal_type grpc_header_nonbin_value_is_legal_import;
grpc_is_binary_header_type grpc_is_binary_header_import;
grpc_call_error_to_string_type grpc_call_error_to_string_import;
-grpc_cronet_secure_channel_create_type grpc_cronet_secure_channel_create_import;
+grpc_insecure_channel_create_from_fd_type grpc_insecure_channel_create_from_fd_import;
+grpc_server_add_insecure_channel_from_fd_type grpc_server_add_insecure_channel_from_fd_import;
grpc_auth_property_iterator_next_type grpc_auth_property_iterator_next_import;
grpc_auth_context_property_iterator_type grpc_auth_context_property_iterator_import;
grpc_auth_context_peer_identity_type grpc_auth_context_peer_identity_import;
@@ -400,7 +401,8 @@ void pygrpc_load_imports(HMODULE library) {
grpc_header_nonbin_value_is_legal_import = (grpc_header_nonbin_value_is_legal_type) GetProcAddress(library, "grpc_header_nonbin_value_is_legal");
grpc_is_binary_header_import = (grpc_is_binary_header_type) GetProcAddress(library, "grpc_is_binary_header");
grpc_call_error_to_string_import = (grpc_call_error_to_string_type) GetProcAddress(library, "grpc_call_error_to_string");
- grpc_cronet_secure_channel_create_import = (grpc_cronet_secure_channel_create_type) GetProcAddress(library, "grpc_cronet_secure_channel_create");
+ grpc_insecure_channel_create_from_fd_import = (grpc_insecure_channel_create_from_fd_type) GetProcAddress(library, "grpc_insecure_channel_create_from_fd");
+ grpc_server_add_insecure_channel_from_fd_import = (grpc_server_add_insecure_channel_from_fd_type) GetProcAddress(library, "grpc_server_add_insecure_channel_from_fd");
grpc_auth_property_iterator_next_import = (grpc_auth_property_iterator_next_type) GetProcAddress(library, "grpc_auth_property_iterator_next");
grpc_auth_context_property_iterator_import = (grpc_auth_context_property_iterator_type) GetProcAddress(library, "grpc_auth_context_property_iterator");
grpc_auth_context_peer_identity_import = (grpc_auth_context_peer_identity_type) GetProcAddress(library, "grpc_auth_context_peer_identity");
@@ -585,4 +587,4 @@ void pygrpc_load_imports(HMODULE library) {
}
#endif /* __cpluslus */
-#endif /* !GPR_WIN32 */
+#endif /* !GPR_WINDOWS */
diff --git a/src/python/grpcio/grpc/_cython/imports.generated.h b/src/python/grpcio/grpc/_cython/imports.generated.h
index 8a3e4bf86e..d52e8591b3 100644
--- a/src/python/grpcio/grpc/_cython/imports.generated.h
+++ b/src/python/grpcio/grpc/_cython/imports.generated.h
@@ -36,14 +36,14 @@
#include <grpc/support/port_platform.h>
-#ifdef GPR_WIN32
+#ifdef GPR_WINDOWS
#include <windows.h>
#include <grpc/census.h>
#include <grpc/compression.h>
#include <grpc/grpc.h>
-#include <grpc/grpc_cronet.h>
+#include <grpc/grpc_posix.h>
#include <grpc/grpc_security.h>
#include <grpc/impl/codegen/alloc.h>
#include <grpc/impl/codegen/byte_buffer.h>
@@ -57,7 +57,7 @@
#include <grpc/support/cpu.h>
#include <grpc/support/histogram.h>
#include <grpc/support/host_port.h>
-#include <grpc/support/log_win32.h>
+#include <grpc/support/log_windows.h>
#include <grpc/support/string_util.h>
#include <grpc/support/subprocess.h>
#include <grpc/support/thd.h>
@@ -329,9 +329,12 @@ extern grpc_is_binary_header_type grpc_is_binary_header_import;
typedef const char *(*grpc_call_error_to_string_type)(grpc_call_error error);
extern grpc_call_error_to_string_type grpc_call_error_to_string_import;
#define grpc_call_error_to_string grpc_call_error_to_string_import
-typedef grpc_channel *(*grpc_cronet_secure_channel_create_type)(void *engine, const char *target, const grpc_channel_args *args, void *reserved);
-extern grpc_cronet_secure_channel_create_type grpc_cronet_secure_channel_create_import;
-#define grpc_cronet_secure_channel_create grpc_cronet_secure_channel_create_import
+typedef grpc_channel *(*grpc_insecure_channel_create_from_fd_type)(const char *target, int fd, const grpc_channel_args *args);
+extern grpc_insecure_channel_create_from_fd_type grpc_insecure_channel_create_from_fd_import;
+#define grpc_insecure_channel_create_from_fd grpc_insecure_channel_create_from_fd_import
+typedef void(*grpc_server_add_insecure_channel_from_fd_type)(grpc_server *server, grpc_completion_queue *cq, int fd);
+extern grpc_server_add_insecure_channel_from_fd_type grpc_server_add_insecure_channel_from_fd_import;
+#define grpc_server_add_insecure_channel_from_fd grpc_server_add_insecure_channel_from_fd_import
typedef const grpc_auth_property *(*grpc_auth_property_iterator_next_type)(grpc_auth_property_iterator *it);
extern grpc_auth_property_iterator_next_type grpc_auth_property_iterator_next_import;
#define grpc_auth_property_iterator_next grpc_auth_property_iterator_next_import
@@ -479,7 +482,7 @@ extern grpc_byte_buffer_reader_readall_type grpc_byte_buffer_reader_readall_impo
typedef grpc_byte_buffer *(*grpc_raw_byte_buffer_from_reader_type)(grpc_byte_buffer_reader *reader);
extern grpc_raw_byte_buffer_from_reader_type grpc_raw_byte_buffer_from_reader_import;
#define grpc_raw_byte_buffer_from_reader grpc_raw_byte_buffer_from_reader_import
-typedef void(*gpr_log_type)(const char *file, int line, gpr_log_severity severity, const char *format, ...);
+typedef void(*gpr_log_type)(const char *file, int line, gpr_log_severity severity, const char *format, ...) GPRC_PRINT_FORMAT_CHECK(4, 5);
extern gpr_log_type gpr_log_import;
#define gpr_log gpr_log_import
typedef void(*gpr_log_message_type)(const char *file, int line, gpr_log_severity severity, const char *message);
@@ -824,7 +827,7 @@ extern gpr_format_message_type gpr_format_message_import;
typedef char *(*gpr_strdup_type)(const char *src);
extern gpr_strdup_type gpr_strdup_import;
#define gpr_strdup gpr_strdup_import
-typedef int(*gpr_asprintf_type)(char **strp, const char *format, ...);
+typedef int(*gpr_asprintf_type)(char **strp, const char *format, ...) GPRC_PRINT_FORMAT_CHECK(2, 3);
extern gpr_asprintf_type gpr_asprintf_import;
#define gpr_asprintf gpr_asprintf_import
typedef const char *(*gpr_subprocess_binary_extension_type)();
@@ -877,7 +880,7 @@ void pygrpc_load_imports(HMODULE library);
}
#endif /* __cpluslus */
-#else /* !GPR_WIN32 */
+#else /* !GPR_WINDOWS */
#include <grpc/byte_buffer.h>
#include <grpc/byte_buffer_reader.h>
@@ -889,6 +892,6 @@ void pygrpc_load_imports(HMODULE library);
#include <grpc/support/time.h>
#include <grpc/status.h>
-#endif /* !GPR_WIN32 */
+#endif /* !GPR_WINDOWS */
#endif
diff --git a/src/python/grpcio/grpc/_cython/loader.c b/src/python/grpcio/grpc/_cython/loader.c
index 3b72806ea1..b909ad594e 100644
--- a/src/python/grpcio/grpc/_cython/loader.c
+++ b/src/python/grpcio/grpc/_cython/loader.c
@@ -37,7 +37,7 @@
extern "C" {
#endif /* __cpluslus */
-#if GPR_WIN32
+#if GPR_WINDOWS
int pygrpc_load_core(char *path) {
HMODULE grpc_c;
@@ -60,7 +60,7 @@ int pygrpc_load_core(char *path) {
int pygrpc_load_core(char *path) { return 1; }
-#endif /* !GPR_WIN32 */
+#endif /* !GPR_WINDOWS */
#ifdef __cplusplus
}
diff --git a/src/python/grpcio/grpc/_links/service.py b/src/python/grpcio/grpc/_links/service.py
index 11310e2240..5fc4994ca0 100644
--- a/src/python/grpcio/grpc/_links/service.py
+++ b/src/python/grpcio/grpc/_links/service.py
@@ -33,6 +33,7 @@ import abc
import enum
import logging
import threading
+import six
import time
from grpc._adapter import _intermediary_low
@@ -177,7 +178,10 @@ class _Kernel(object):
call = service_acceptance.call
call.accept(self._completion_queue, call)
try:
- group, method = service_acceptance.method.split(b'/')[1:3]
+ service_method = service_acceptance.method
+ if six.PY3:
+ service_method = service_method.decode('latin1')
+ group, method = service_method.split('/')[1:3]
except ValueError:
logging.info('Illegal path "%s"!', service_acceptance.method)
return
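The decoded path is split on '/' to recover the group (the fully qualified service name) and the method name. A minimal illustration of the slicing used above, with a hypothetical path:

    # '/package.Service/Method' splits into ['', 'package.Service', 'Method'];
    # indices 1:3 select the group and method components.
    service_method = '/foo.TestService/UnaryCall'
    group, method = service_method.split('/')[1:3]
    assert (group, method) == ('foo.TestService', 'UnaryCall')
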
diff --git a/src/python/grpcio/grpc/_plugin_wrapping.py b/src/python/grpcio/grpc/_plugin_wrapping.py
new file mode 100644
index 0000000000..4e9cfe710c
--- /dev/null
+++ b/src/python/grpcio/grpc/_plugin_wrapping.py
@@ -0,0 +1,123 @@
+# Copyright 2015, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import collections
+import threading
+
+import grpc
+from grpc._cython import cygrpc
+
+
+class AuthMetadataContext(
+ collections.namedtuple(
+ 'AuthMetadataContext', ('service_url', 'method_name',)),
+ grpc.AuthMetadataContext):
+ pass
+
+
+class AuthMetadataPluginCallback(grpc.AuthMetadataPluginCallback):
+
+ def __init__(self, callback):
+ self._callback = callback
+
+ def __call__(self, metadata, error):
+ self._callback(metadata, error)
+
+
+class _WrappedCygrpcCallback(object):
+
+ def __init__(self, cygrpc_callback):
+ self.is_called = False
+ self.error = None
+ self.is_called_lock = threading.Lock()
+ self.cygrpc_callback = cygrpc_callback
+
+ def _invoke_failure(self, error):
+ # TODO(atash) translate different Exception superclasses into different
+ # status codes.
+ self.cygrpc_callback(
+        cygrpc.Metadata([]), cygrpc.StatusCode.internal, str(error))
+
+ def _invoke_success(self, metadata):
+ try:
+ cygrpc_metadata = cygrpc.Metadata(
+ cygrpc.Metadatum(key, value)
+ for key, value in metadata)
+ except Exception as error:
+ self._invoke_failure(error)
+ return
+ self.cygrpc_callback(cygrpc_metadata, cygrpc.StatusCode.ok, '')
+
+ def __call__(self, metadata, error):
+ with self.is_called_lock:
+ if self.is_called:
+ raise RuntimeError('callback should only ever be invoked once')
+ if self.error:
+ self._invoke_failure(self.error)
+ return
+ self.is_called = True
+ if error is None:
+ self._invoke_success(metadata)
+ else:
+ self._invoke_failure(error)
+
+ def notify_failure(self, error):
+ with self.is_called_lock:
+ if not self.is_called:
+ self.error = error
+
+
+class _WrappedPlugin(object):
+
+ def __init__(self, plugin):
+ self.plugin = plugin
+
+ def __call__(self, context, cygrpc_callback):
+ wrapped_cygrpc_callback = _WrappedCygrpcCallback(cygrpc_callback)
+ wrapped_context = AuthMetadataContext(
+ context.service_url, context.method_name)
+ try:
+ self.plugin(
+ wrapped_context, AuthMetadataPluginCallback(wrapped_cygrpc_callback))
+ except Exception as error:
+ wrapped_cygrpc_callback.notify_failure(error)
+ raise
+
+
+def call_credentials_metadata_plugin(plugin, name):
+  """
+  Args:
+    plugin: A callable accepting a grpc.AuthMetadataContext
+      object and a callback (itself accepting a list of metadata key/value
+      2-tuples and a None-able exception value). The callback must eventually
+      be called, but need not be called during the plugin's invocation; the
+      plugin's invocation itself must be non-blocking.
+  """
+ return cygrpc.call_credentials_metadata_plugin(
+ cygrpc.CredentialsMetadataPlugin(_WrappedPlugin(plugin), name))
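A minimal sketch of a plugin callable satisfying the contract documented for call_credentials_metadata_plugin above; the header name and token value are hypothetical:

    def example_plugin(context, callback):
      # context.service_url and context.method_name identify the RPC being invoked.
      # The callback must eventually be invoked exactly once; a None error value
      # signals success. Invoking it immediately keeps the plugin non-blocking.
      callback([('authorization', 'Bearer hypothetical-token')], None)

    # credentials = call_credentials_metadata_plugin(example_plugin, 'example_plugin')
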
diff --git a/src/python/grpcio/grpc/_server.py b/src/python/grpcio/grpc/_server.py
new file mode 100644
index 0000000000..bf20f15f72
--- /dev/null
+++ b/src/python/grpcio/grpc/_server.py
@@ -0,0 +1,755 @@
+# Copyright 2016, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Service-side implementation of gRPC Python."""
+
+import collections
+import enum
+import logging
+import threading
+import time
+
+import grpc
+from grpc import _common
+from grpc._cython import cygrpc
+from grpc.framework.foundation import callable_util
+
+_SHUTDOWN_TAG = 'shutdown'
+_REQUEST_CALL_TAG = 'request_call'
+
+_RECEIVE_CLOSE_ON_SERVER_TOKEN = 'receive_close_on_server'
+_SEND_INITIAL_METADATA_TOKEN = 'send_initial_metadata'
+_RECEIVE_MESSAGE_TOKEN = 'receive_message'
+_SEND_MESSAGE_TOKEN = 'send_message'
+_SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN = (
+ 'send_initial_metadata * send_message')
+_SEND_STATUS_FROM_SERVER_TOKEN = 'send_status_from_server'
+_SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN = (
+ 'send_initial_metadata * send_status_from_server')
+
+_OPEN = 'open'
+_CLOSED = 'closed'
+_CANCELLED = 'cancelled'
+
+_EMPTY_FLAGS = 0
+_EMPTY_METADATA = cygrpc.Metadata(())
+
+_UNEXPECTED_EXIT_SERVER_GRACE = 1.0
+
+
+def _serialized_request(request_event):
+ return request_event.batch_operations[0].received_message.bytes()
+
+
+def _application_code(code):
+ cygrpc_code = _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE.get(code)
+ return cygrpc.StatusCode.unknown if cygrpc_code is None else cygrpc_code
+
+
+def _completion_code(state):
+ if state.code is None:
+ return cygrpc.StatusCode.ok
+ else:
+ return _application_code(state.code)
+
+
+def _abortion_code(state, code):
+ if state.code is None:
+ return code
+ else:
+ return _application_code(state.code)
+
+
+def _details(state):
+ return '' if state.details is None else state.details
+
+
+class _HandlerCallDetails(
+ collections.namedtuple(
+ '_HandlerCallDetails', ('method', 'invocation_metadata',)),
+ grpc.HandlerCallDetails):
+ pass
+
+
+class _RPCState(object):
+
+ def __init__(self):
+ self.condition = threading.Condition()
+ self.due = set()
+ self.request = None
+ self.client = _OPEN
+ self.initial_metadata_allowed = True
+ self.disable_next_compression = False
+ self.trailing_metadata = None
+ self.code = None
+ self.details = None
+ self.statused = False
+ self.rpc_errors = []
+ self.callbacks = []
+
+
+def _raise_rpc_error(state):
+ rpc_error = grpc.RpcError()
+ state.rpc_errors.append(rpc_error)
+ raise rpc_error
+
+
+def _possibly_finish_call(state, token):
+ state.due.remove(token)
+ if (state.client is _CANCELLED or state.statused) and not state.due:
+ callbacks = state.callbacks
+ state.callbacks = None
+ return state, callbacks
+ else:
+ return None, ()
+
+
+def _send_status_from_server(state, token):
+ def send_status_from_server(unused_send_status_from_server_event):
+ with state.condition:
+ return _possibly_finish_call(state, token)
+ return send_status_from_server
+
+
+def _abort(state, call, code, details):
+ if state.client is not _CANCELLED:
+ effective_code = _abortion_code(state, code)
+ effective_details = details if state.details is None else state.details
+ if state.initial_metadata_allowed:
+ operations = (
+ cygrpc.operation_send_initial_metadata(
+ _EMPTY_METADATA, _EMPTY_FLAGS),
+ cygrpc.operation_send_status_from_server(
+ _common.metadata(state.trailing_metadata), effective_code,
+ effective_details, _EMPTY_FLAGS),
+ )
+ token = _SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN
+ else:
+ operations = (
+ cygrpc.operation_send_status_from_server(
+ _common.metadata(state.trailing_metadata), effective_code,
+ effective_details, _EMPTY_FLAGS),
+ )
+ token = _SEND_STATUS_FROM_SERVER_TOKEN
+ call.start_batch(
+ cygrpc.Operations(operations),
+ _send_status_from_server(state, token))
+ state.statused = True
+ state.due.add(token)
+
+
+def _receive_close_on_server(state):
+ def receive_close_on_server(receive_close_on_server_event):
+ with state.condition:
+ if receive_close_on_server_event.batch_operations[0].received_cancelled:
+ state.client = _CANCELLED
+ elif state.client is _OPEN:
+ state.client = _CLOSED
+ state.condition.notify_all()
+ return _possibly_finish_call(state, _RECEIVE_CLOSE_ON_SERVER_TOKEN)
+ return receive_close_on_server
+
+
+def _receive_message(state, call, request_deserializer):
+ def receive_message(receive_message_event):
+ serialized_request = _serialized_request(receive_message_event)
+ if serialized_request is None:
+ with state.condition:
+ if state.client is _OPEN:
+ state.client = _CLOSED
+ state.condition.notify_all()
+ return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN)
+ else:
+ request = _common.deserialize(serialized_request, request_deserializer)
+ with state.condition:
+ if request is None:
+ _abort(
+ state, call, cygrpc.StatusCode.internal,
+ 'Exception deserializing request!')
+ else:
+ state.request = request
+ state.condition.notify_all()
+ return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN)
+ return receive_message
+
+
+def _send_initial_metadata(state):
+ def send_initial_metadata(unused_send_initial_metadata_event):
+ with state.condition:
+ return _possibly_finish_call(state, _SEND_INITIAL_METADATA_TOKEN)
+ return send_initial_metadata
+
+
+def _send_message(state, token):
+ def send_message(unused_send_message_event):
+ with state.condition:
+ state.condition.notify_all()
+ return _possibly_finish_call(state, token)
+ return send_message
+
+
+class _Context(grpc.ServicerContext):
+
+ def __init__(self, rpc_event, state, request_deserializer):
+ self._rpc_event = rpc_event
+ self._state = state
+ self._request_deserializer = request_deserializer
+
+ def is_active(self):
+ with self._state.condition:
+ return self._state.client is not _CANCELLED and not self._state.statused
+
+ def time_remaining(self):
+ return max(self._rpc_event.request_call_details.deadline - time.time(), 0)
+
+ def cancel(self):
+ self._rpc_event.operation_call.cancel()
+
+ def add_callback(self, callback):
+ with self._state.condition:
+ if self._state.callbacks is None:
+ return False
+ else:
+ self._state.callbacks.append(callback)
+ return True
+
+ def disable_next_message_compression(self):
+ with self._state.condition:
+ self._state.disable_next_compression = True
+
+ def invocation_metadata(self):
+ return self._rpc_event.request_metadata
+
+ def peer(self):
+ return self._rpc_event.operation_call.peer()
+
+ def send_initial_metadata(self, initial_metadata):
+ with self._state.condition:
+ if self._state.client is _CANCELLED:
+ _raise_rpc_error(self._state)
+ else:
+ if self._state.initial_metadata_allowed:
+ operation = cygrpc.operation_send_initial_metadata(
+ _common.metadata(initial_metadata), _EMPTY_FLAGS)
+ self._rpc_event.operation_call.start_batch(
+ cygrpc.Operations((operation,)),
+ _send_initial_metadata(self._state))
+ self._state.initial_metadata_allowed = False
+ self._state.due.add(_SEND_INITIAL_METADATA_TOKEN)
+ else:
+ raise ValueError('Initial metadata no longer allowed!')
+
+ def set_trailing_metadata(self, trailing_metadata):
+ with self._state.condition:
+ self._state.trailing_metadata = trailing_metadata
+
+ def set_code(self, code):
+ with self._state.condition:
+ self._state.code = code
+
+ def set_details(self, details):
+ with self._state.condition:
+ self._state.details = details
+
+
+class _RequestIterator(object):
+
+ def __init__(self, state, call, request_deserializer):
+ self._state = state
+ self._call = call
+ self._request_deserializer = request_deserializer
+
+ def _raise_or_start_receive_message(self):
+ if self._state.client is _CANCELLED:
+ _raise_rpc_error(self._state)
+ elif self._state.client is _CLOSED or self._state.statused:
+ raise StopIteration()
+ else:
+ self._call.start_batch(
+ cygrpc.Operations((cygrpc.operation_receive_message(_EMPTY_FLAGS),)),
+ _receive_message(self._state, self._call, self._request_deserializer))
+ self._state.due.add(_RECEIVE_MESSAGE_TOKEN)
+
+ def _look_for_request(self):
+ if self._state.client is _CANCELLED:
+ _raise_rpc_error(self._state)
+ elif (self._state.request is None and
+ _RECEIVE_MESSAGE_TOKEN not in self._state.due):
+ raise StopIteration()
+ else:
+ request = self._state.request
+ self._state.request = None
+ return request
+
+ def _next(self):
+ with self._state.condition:
+ self._raise_or_start_receive_message()
+ while True:
+ self._state.condition.wait()
+ request = self._look_for_request()
+ if request is not None:
+ return request
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ return self._next()
+
+ def next(self):
+ return self._next()
+
+
+def _unary_request(rpc_event, state, request_deserializer):
+ def unary_request():
+ with state.condition:
+ if state.client is _CANCELLED or state.statused:
+ return None
+ else:
+ start_batch_result = rpc_event.operation_call.start_batch(
+ cygrpc.Operations(
+ (cygrpc.operation_receive_message(_EMPTY_FLAGS),)),
+ _receive_message(
+ state, rpc_event.operation_call, request_deserializer))
+ state.due.add(_RECEIVE_MESSAGE_TOKEN)
+ while True:
+ state.condition.wait()
+ if state.request is None:
+ if state.client is _CLOSED:
+ details = '"{}" requires exactly one request message.'.format(
+ rpc_event.request_call_details.method)
+ _abort(
+ state, rpc_event.operation_call,
+ cygrpc.StatusCode.unimplemented, details)
+ return None
+ elif state.client is _CANCELLED:
+ return None
+ else:
+ request = state.request
+ state.request = None
+ return request
+ return unary_request
+
+
+def _call_behavior(rpc_event, state, behavior, argument, request_deserializer):
+ context = _Context(rpc_event, state, request_deserializer)
+ try:
+ return behavior(argument, context), True
+ except Exception as e: # pylint: disable=broad-except
+ with state.condition:
+ if e not in state.rpc_errors:
+ details = 'Exception calling application: {}'.format(e)
+ logging.exception(details)
+ _abort(
+ state, rpc_event.operation_call, cygrpc.StatusCode.unknown, details)
+ return None, False
+
+
+def _take_response_from_response_iterator(rpc_event, state, response_iterator):
+ try:
+ return next(response_iterator), True
+ except StopIteration:
+ return None, True
+ except Exception as e: # pylint: disable=broad-except
+ with state.condition:
+ if e not in state.rpc_errors:
+ details = 'Exception iterating responses: {}'.format(e)
+ logging.exception(details)
+ _abort(
+ state, rpc_event.operation_call, cygrpc.StatusCode.unknown, details)
+ return None, False
+
+
+def _serialize_response(rpc_event, state, response, response_serializer):
+ serialized_response = _common.serialize(response, response_serializer)
+ if serialized_response is None:
+ with state.condition:
+ _abort(
+ state, rpc_event.operation_call, cygrpc.StatusCode.internal,
+ 'Failed to serialize response!')
+ return None
+ else:
+ return serialized_response
+
+
+def _send_response(rpc_event, state, serialized_response):
+ with state.condition:
+ if state.client is _CANCELLED or state.statused:
+ return False
+ else:
+ if state.initial_metadata_allowed:
+ operations = (
+ cygrpc.operation_send_initial_metadata(
+ _EMPTY_METADATA, _EMPTY_FLAGS),
+ cygrpc.operation_send_message(serialized_response, _EMPTY_FLAGS),
+ )
+ state.initial_metadata_allowed = False
+ token = _SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN
+ else:
+ operations = (
+ cygrpc.operation_send_message(serialized_response, _EMPTY_FLAGS),
+ )
+ token = _SEND_MESSAGE_TOKEN
+ rpc_event.operation_call.start_batch(
+ cygrpc.Operations(operations), _send_message(state, token))
+ state.due.add(token)
+ while True:
+ state.condition.wait()
+ if token not in state.due:
+ return state.client is not _CANCELLED and not state.statused
+
+
+def _status(rpc_event, state, serialized_response):
+ with state.condition:
+ if state.client is not _CANCELLED:
+ trailing_metadata = _common.metadata(state.trailing_metadata)
+ code = _completion_code(state)
+ details = _details(state)
+ operations = [
+ cygrpc.operation_send_status_from_server(
+ trailing_metadata, code, details, _EMPTY_FLAGS),
+ ]
+ if state.initial_metadata_allowed:
+ operations.append(
+ cygrpc.operation_send_initial_metadata(
+ _EMPTY_METADATA, _EMPTY_FLAGS))
+ if serialized_response is not None:
+ operations.append(cygrpc.operation_send_message(
+ serialized_response, _EMPTY_FLAGS))
+ rpc_event.operation_call.start_batch(
+ cygrpc.Operations(operations),
+ _send_status_from_server(state, _SEND_STATUS_FROM_SERVER_TOKEN))
+ state.statused = True
+ state.due.add(_SEND_STATUS_FROM_SERVER_TOKEN)
+
+
+def _unary_response_in_pool(
+ rpc_event, state, behavior, argument_thunk, request_deserializer,
+ response_serializer):
+ argument = argument_thunk()
+ if argument is not None:
+ response, proceed = _call_behavior(
+ rpc_event, state, behavior, argument, request_deserializer)
+ if proceed:
+ serialized_response = _serialize_response(
+ rpc_event, state, response, response_serializer)
+ if serialized_response is not None:
+ _status(rpc_event, state, serialized_response)
+ return
+
+
+def _stream_response_in_pool(
+ rpc_event, state, behavior, argument_thunk, request_deserializer,
+ response_serializer):
+ argument = argument_thunk()
+ if argument is not None:
+ response_iterator, proceed = _call_behavior(
+ rpc_event, state, behavior, argument, request_deserializer)
+ if proceed:
+ while True:
+ response, proceed = _take_response_from_response_iterator(
+ rpc_event, state, response_iterator)
+ if proceed:
+ if response is None:
+ _status(rpc_event, state, None)
+ break
+ else:
+ serialized_response = _serialize_response(
+ rpc_event, state, response, response_serializer)
+ if serialized_response is not None:
+ proceed = _send_response(rpc_event, state, serialized_response)
+ if not proceed:
+ break
+ else:
+ break
+ else:
+ break
+
+
+def _handle_unary_unary(rpc_event, state, method_handler, thread_pool):
+ unary_request = _unary_request(
+ rpc_event, state, method_handler.request_deserializer)
+ thread_pool.submit(
+ _unary_response_in_pool, rpc_event, state, method_handler.unary_unary,
+ unary_request, method_handler.request_deserializer,
+ method_handler.response_serializer)
+
+
+def _handle_unary_stream(rpc_event, state, method_handler, thread_pool):
+ unary_request = _unary_request(
+ rpc_event, state, method_handler.request_deserializer)
+ thread_pool.submit(
+ _stream_response_in_pool, rpc_event, state, method_handler.unary_stream,
+ unary_request, method_handler.request_deserializer,
+ method_handler.response_serializer)
+
+
+def _handle_stream_unary(rpc_event, state, method_handler, thread_pool):
+ request_iterator = _RequestIterator(
+ state, rpc_event.operation_call, method_handler.request_deserializer)
+ thread_pool.submit(
+ _unary_response_in_pool, rpc_event, state, method_handler.stream_unary,
+ lambda: request_iterator, method_handler.request_deserializer,
+ method_handler.response_serializer)
+
+
+def _handle_stream_stream(rpc_event, state, method_handler, thread_pool):
+ request_iterator = _RequestIterator(
+ state, rpc_event.operation_call, method_handler.request_deserializer)
+ thread_pool.submit(
+ _stream_response_in_pool, rpc_event, state, method_handler.stream_stream,
+ lambda: request_iterator, method_handler.request_deserializer,
+ method_handler.response_serializer)
+
+
+def _find_method_handler(rpc_event, generic_handlers):
+ for generic_handler in generic_handlers:
+ method_handler = generic_handler.service(
+ _HandlerCallDetails(
+ rpc_event.request_call_details.method, rpc_event.request_metadata))
+ if method_handler is not None:
+ return method_handler
+ else:
+ return None
+
+
+def _handle_unrecognized_method(rpc_event):
+ operations = (
+ cygrpc.operation_send_initial_metadata(_EMPTY_METADATA, _EMPTY_FLAGS),
+ cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),
+ cygrpc.operation_send_status_from_server(
+ _EMPTY_METADATA, cygrpc.StatusCode.unimplemented,
+ 'Method not found!', _EMPTY_FLAGS),
+ )
+ rpc_state = _RPCState()
+  rpc_event.operation_call.start_batch(
+      cygrpc.Operations(operations), lambda ignored_event: (rpc_state, (),))
+ return rpc_state
+
+
+def _handle_with_method_handler(rpc_event, method_handler, thread_pool):
+ state = _RPCState()
+ with state.condition:
+ rpc_event.operation_call.start_batch(
+ cygrpc.Operations(
+ (cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),)),
+ _receive_close_on_server(state))
+ state.due.add(_RECEIVE_CLOSE_ON_SERVER_TOKEN)
+ if method_handler.request_streaming:
+ if method_handler.response_streaming:
+ _handle_stream_stream(rpc_event, state, method_handler, thread_pool)
+ else:
+ _handle_stream_unary(rpc_event, state, method_handler, thread_pool)
+ else:
+ if method_handler.response_streaming:
+ _handle_unary_stream(rpc_event, state, method_handler, thread_pool)
+ else:
+ _handle_unary_unary(rpc_event, state, method_handler, thread_pool)
+ return state
+
+
+def _handle_call(rpc_event, generic_handlers, thread_pool):
+ if rpc_event.request_call_details.method is not None:
+ method_handler = _find_method_handler(rpc_event, generic_handlers)
+ if method_handler is None:
+ return _handle_unrecognized_method(rpc_event)
+ else:
+ return _handle_with_method_handler(rpc_event, method_handler, thread_pool)
+ else:
+ return None
+
+
+@enum.unique
+class _ServerStage(enum.Enum):
+ STOPPED = 'stopped'
+ STARTED = 'started'
+ GRACE = 'grace'
+
+
+class _ServerState(object):
+
+ def __init__(self, completion_queue, server, generic_handlers, thread_pool):
+ self.lock = threading.Lock()
+ self.completion_queue = completion_queue
+ self.server = server
+ self.generic_handlers = list(generic_handlers)
+ self.thread_pool = thread_pool
+ self.stage = _ServerStage.STOPPED
+ self.shutdown_events = None
+
+ # TODO(https://github.com/grpc/grpc/issues/6597): eliminate these fields.
+ self.rpc_states = set()
+ self.due = set()
+
+
+def _add_generic_handlers(state, generic_handlers):
+ with state.lock:
+ state.generic_handlers.extend(generic_handlers)
+
+
+def _add_insecure_port(state, address):
+ with state.lock:
+ return state.server.add_http2_port(address)
+
+
+def _add_secure_port(state, address, server_credentials):
+ with state.lock:
+ return state.server.add_http2_port(address, server_credentials._credentials)
+
+
+def _request_call(state):
+ state.server.request_call(
+ state.completion_queue, state.completion_queue, _REQUEST_CALL_TAG)
+ state.due.add(_REQUEST_CALL_TAG)
+
+
+# TODO(https://github.com/grpc/grpc/issues/6597): delete this function.
+def _stop_serving(state):
+ if not state.rpc_states and not state.due:
+ for shutdown_event in state.shutdown_events:
+ shutdown_event.set()
+ state.stage = _ServerStage.STOPPED
+ return True
+ else:
+ return False
+
+
+def _serve(state):
+ while True:
+ event = state.completion_queue.poll()
+ if event.tag is _SHUTDOWN_TAG:
+ with state.lock:
+ state.due.remove(_SHUTDOWN_TAG)
+ if _stop_serving(state):
+ return
+ elif event.tag is _REQUEST_CALL_TAG:
+ with state.lock:
+ state.due.remove(_REQUEST_CALL_TAG)
+ rpc_state = _handle_call(
+ event, state.generic_handlers, state.thread_pool)
+ if rpc_state is not None:
+ state.rpc_states.add(rpc_state)
+ if state.stage is _ServerStage.STARTED:
+ _request_call(state)
+ elif _stop_serving(state):
+ return
+ else:
+ rpc_state, callbacks = event.tag(event)
+ for callback in callbacks:
+ callable_util.call_logging_exceptions(
+ callback, 'Exception calling callback!')
+ if rpc_state is not None:
+ with state.lock:
+ state.rpc_states.remove(rpc_state)
+ if _stop_serving(state):
+ return
+
+
+def _stop(state, grace):
+ with state.lock:
+ if state.stage is _ServerStage.STOPPED:
+ shutdown_event = threading.Event()
+ shutdown_event.set()
+ return shutdown_event
+ else:
+ if state.stage is _ServerStage.STARTED:
+ state.server.shutdown(state.completion_queue, _SHUTDOWN_TAG)
+ state.stage = _ServerStage.GRACE
+ state.shutdown_events = []
+ state.due.add(_SHUTDOWN_TAG)
+ shutdown_event = threading.Event()
+ state.shutdown_events.append(shutdown_event)
+ if grace is None:
+ state.server.cancel_all_calls()
+ # TODO(https://github.com/grpc/grpc/issues/6597): delete this loop.
+ for rpc_state in state.rpc_states:
+ with rpc_state.condition:
+ rpc_state.client = _CANCELLED
+ rpc_state.condition.notify_all()
+ else:
+ def cancel_all_calls_after_grace():
+ shutdown_event.wait(timeout=grace)
+ with state.lock:
+ state.server.cancel_all_calls()
+ # TODO(https://github.com/grpc/grpc/issues/6597): delete this loop.
+ for rpc_state in state.rpc_states:
+ with rpc_state.condition:
+ rpc_state.client = _CANCELLED
+ rpc_state.condition.notify_all()
+ thread = threading.Thread(target=cancel_all_calls_after_grace)
+ thread.start()
+ return shutdown_event
+ shutdown_event.wait()
+ return shutdown_event
+
+
+def _start(state):
+ with state.lock:
+ if state.stage is not _ServerStage.STOPPED:
+ raise ValueError('Cannot start already-started server!')
+ state.server.start()
+ state.stage = _ServerStage.STARTED
+ _request_call(state)
+ def cleanup_server(timeout):
+ if timeout is None:
+ _stop(state, _UNEXPECTED_EXIT_SERVER_GRACE).wait()
+ else:
+ _stop(state, timeout).wait()
+
+ thread = _common.CleanupThread(
+ cleanup_server, target=_serve, args=(state,))
+ thread.start()
+
+
+class Server(grpc.Server):
+
+ def __init__(self, generic_handlers, thread_pool):
+ completion_queue = cygrpc.CompletionQueue()
+ server = cygrpc.Server()
+ server.register_completion_queue(completion_queue)
+ self._state = _ServerState(
+ completion_queue, server, generic_handlers, thread_pool)
+
+ def add_generic_rpc_handlers(self, generic_rpc_handlers):
+ _add_generic_handlers(self._state, generic_rpc_handlers)
+
+ def add_insecure_port(self, address):
+ return _add_insecure_port(self._state, address)
+
+ def add_secure_port(self, address, server_credentials):
+ return _add_secure_port(self._state, address, server_credentials)
+
+ def start(self):
+ _start(self._state)
+
+ def stop(self, grace):
+ return _stop(self._state, grace)
+
+ def __del__(self):
+ _stop(self._state, None)
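A brief usage sketch for the Server class defined above, assuming a concurrent.futures thread pool for the required submit() interface; an application would ordinarily obtain a server through the public grpc API rather than instantiating grpc._server.Server directly:

    from concurrent import futures

    from grpc import _server

    # generic_handlers may be any iterable of grpc.GenericRpcHandler objects.
    server = _server.Server(
        generic_handlers=(),
        thread_pool=futures.ThreadPoolExecutor(max_workers=8))
    port = server.add_insecure_port('[::]:0')  # returns the bound port number
    server.start()
    # ... issue RPCs against the server ...
    shutdown_event = server.stop(grace=1.0)  # returns a threading.Event
    shutdown_event.wait()
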
diff --git a/src/python/grpcio/grpc/_utilities.py b/src/python/grpcio/grpc/_utilities.py
new file mode 100644
index 0000000000..4850967fbc
--- /dev/null
+++ b/src/python/grpcio/grpc/_utilities.py
@@ -0,0 +1,171 @@
+# Copyright 2015, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Internal utilities for gRPC Python."""
+
+import collections
+import threading
+import time
+
+import six
+
+import grpc
+from grpc import _common
+from grpc.framework.foundation import callable_util
+
+_DONE_CALLBACK_EXCEPTION_LOG_MESSAGE = (
+ 'Exception calling connectivity future "done" callback!')
+
+
+class RpcMethodHandler(
+ collections.namedtuple(
+ '_RpcMethodHandler',
+ ('request_streaming', 'response_streaming', 'request_deserializer',
+ 'response_serializer', 'unary_unary', 'unary_stream', 'stream_unary',
+ 'stream_stream',)),
+ grpc.RpcMethodHandler):
+ pass
+
+
+class DictionaryGenericHandler(grpc.GenericRpcHandler):
+
+ def __init__(self, service, method_handlers):
+ self._method_handlers = {
+ _common.fully_qualified_method(service, method): method_handler
+ for method, method_handler in six.iteritems(method_handlers)}
+
+ def service(self, handler_call_details):
+ return self._method_handlers.get(handler_call_details.method)
+
+
+class _ChannelReadyFuture(grpc.Future):
+
+ def __init__(self, channel):
+ self._condition = threading.Condition()
+ self._channel = channel
+
+ self._matured = False
+ self._cancelled = False
+ self._done_callbacks = []
+
+ def _block(self, timeout):
+ until = None if timeout is None else time.time() + timeout
+ with self._condition:
+ while True:
+ if self._cancelled:
+ raise grpc.FutureCancelledError()
+ elif self._matured:
+ return
+ else:
+ if until is None:
+ self._condition.wait()
+ else:
+ remaining = until - time.time()
+ if remaining < 0:
+ raise grpc.FutureTimeoutError()
+ else:
+ self._condition.wait(timeout=remaining)
+
+ def _update(self, connectivity):
+ with self._condition:
+ if (not self._cancelled and
+ connectivity is grpc.ChannelConnectivity.READY):
+ self._matured = True
+ self._channel.unsubscribe(self._update)
+ self._condition.notify_all()
+ done_callbacks = tuple(self._done_callbacks)
+ self._done_callbacks = None
+ else:
+ return
+
+ for done_callback in done_callbacks:
+ callable_util.call_logging_exceptions(
+ done_callback, _DONE_CALLBACK_EXCEPTION_LOG_MESSAGE, self)
+
+ def cancel(self):
+ with self._condition:
+ if not self._matured:
+ self._cancelled = True
+ self._channel.unsubscribe(self._update)
+ self._condition.notify_all()
+ done_callbacks = tuple(self._done_callbacks)
+ self._done_callbacks = None
+ else:
+ return False
+
+ for done_callback in done_callbacks:
+ callable_util.call_logging_exceptions(
+ done_callback, _DONE_CALLBACK_EXCEPTION_LOG_MESSAGE, self)
+
+ def cancelled(self):
+ with self._condition:
+ return self._cancelled
+
+ def running(self):
+ with self._condition:
+ return not self._cancelled and not self._matured
+
+ def done(self):
+ with self._condition:
+ return self._cancelled or self._matured
+
+ def result(self, timeout=None):
+ self._block(timeout)
+ return None
+
+ def exception(self, timeout=None):
+ self._block(timeout)
+ return None
+
+ def traceback(self, timeout=None):
+ self._block(timeout)
+ return None
+
+ def add_done_callback(self, fn):
+ with self._condition:
+ if not self._cancelled and not self._matured:
+ self._done_callbacks.append(fn)
+ return
+
+ fn(self)
+
+ def start(self):
+ with self._condition:
+ self._channel.subscribe(self._update, try_to_connect=True)
+
+ def __del__(self):
+ with self._condition:
+ if not self._cancelled and not self._matured:
+ self._channel.unsubscribe(self._update)
+
+
+def channel_ready_future(channel):
+ ready_future = _ChannelReadyFuture(channel)
+ ready_future.start()
+ return ready_future
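A short sketch of blocking on channel readiness with the helper above, assuming a channel that exposes the subscribe/unsubscribe interface used by _ChannelReadyFuture (the target address is hypothetical):

    import grpc

    from grpc import _utilities

    channel = grpc.insecure_channel('localhost:50051')  # hypothetical target
    ready_future = _utilities.channel_ready_future(channel)
    try:
      ready_future.result(timeout=10)  # returns None once connectivity is READY
    except grpc.FutureTimeoutError:
      ready_future.cancel()  # stop watching connectivity if the deadline passes
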
diff --git a/src/python/grpcio/grpc/beta/_client_adaptations.py b/src/python/grpcio/grpc/beta/_client_adaptations.py
new file mode 100644
index 0000000000..024808c540
--- /dev/null
+++ b/src/python/grpcio/grpc/beta/_client_adaptations.py
@@ -0,0 +1,563 @@
+# Copyright 2016, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Translates gRPC's client-side API into gRPC's client-side Beta API."""
+
+import grpc
+from grpc import _common
+from grpc._cython import cygrpc
+from grpc.beta import interfaces
+from grpc.framework.common import cardinality
+from grpc.framework.foundation import future
+from grpc.framework.interfaces.face import face
+
+_STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS = {
+ grpc.StatusCode.CANCELLED: (
+ face.Abortion.Kind.CANCELLED, face.CancellationError),
+ grpc.StatusCode.UNKNOWN: (
+ face.Abortion.Kind.REMOTE_FAILURE, face.RemoteError),
+ grpc.StatusCode.DEADLINE_EXCEEDED: (
+ face.Abortion.Kind.EXPIRED, face.ExpirationError),
+ grpc.StatusCode.UNIMPLEMENTED: (
+ face.Abortion.Kind.LOCAL_FAILURE, face.LocalError),
+}
+
+
+def _effective_metadata(metadata, metadata_transformer):
+ non_none_metadata = () if metadata is None else metadata
+ if metadata_transformer is None:
+ return non_none_metadata
+ else:
+ return metadata_transformer(non_none_metadata)
+
+
+def _credentials(grpc_call_options):
+ return None if grpc_call_options is None else grpc_call_options.credentials
+
+
+def _abortion(rpc_error_call):
+ code = rpc_error_call.code()
+ pair = _STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS.get(code)
+ error_kind = face.Abortion.Kind.LOCAL_FAILURE if pair is None else pair[0]
+ return face.Abortion(
+ error_kind, rpc_error_call.initial_metadata(),
+      rpc_error_call.trailing_metadata(), code, rpc_error_call.details())
+
+
+def _abortion_error(rpc_error_call):
+ code = rpc_error_call.code()
+ pair = _STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS.get(code)
+ exception_class = face.AbortionError if pair is None else pair[1]
+ return exception_class(
+ rpc_error_call.initial_metadata(), rpc_error_call.trailing_metadata(),
+ code, rpc_error_call.details())
+
+
+class _InvocationProtocolContext(interfaces.GRPCInvocationContext):
+
+ def disable_next_request_compression(self):
+ pass # TODO(https://github.com/grpc/grpc/issues/4078): design, implement.
+
+
+class _Rendezvous(future.Future, face.Call):
+
+ def __init__(self, response_future, response_iterator, call):
+ self._future = response_future
+ self._iterator = response_iterator
+ self._call = call
+
+ def cancel(self):
+ return self._call.cancel()
+
+ def cancelled(self):
+ return self._future.cancelled()
+
+ def running(self):
+ return self._future.running()
+
+ def done(self):
+ return self._future.done()
+
+ def result(self, timeout=None):
+ try:
+ return self._future.result(timeout=timeout)
+ except grpc.RpcError as rpc_error_call:
+ raise _abortion_error(rpc_error_call)
+ except grpc.FutureTimeoutError:
+ raise future.TimeoutError()
+ except grpc.FutureCancelledError:
+ raise future.CancelledError()
+
+ def exception(self, timeout=None):
+ try:
+ rpc_error_call = self._future.exception(timeout=timeout)
+ return _abortion_error(rpc_error_call)
+ except grpc.FutureTimeoutError:
+ raise future.TimeoutError()
+ except grpc.FutureCancelledError:
+ raise future.CancelledError()
+
+ def traceback(self, timeout=None):
+ try:
+ return self._future.traceback(timeout=timeout)
+ except grpc.FutureTimeoutError:
+ raise future.TimeoutError()
+ except grpc.FutureCancelledError:
+ raise future.CancelledError()
+
+ def add_done_callback(self, fn):
+ self._future.add_done_callback(lambda ignored_callback: fn(self))
+
+ def __iter__(self):
+ return self
+
+ def _next(self):
+ try:
+ return next(self._iterator)
+ except grpc.RpcError as rpc_error_call:
+ raise _abortion_error(rpc_error_call)
+
+ def __next__(self):
+ return self._next()
+
+ def next(self):
+ return self._next()
+
+ def is_active(self):
+ return self._call.is_active()
+
+ def time_remaining(self):
+ return self._call.time_remaining()
+
+ def add_abortion_callback(self, abortion_callback):
+ registered = self._call.add_callback(
+ lambda: abortion_callback(_abortion(self._call)))
+ return None if registered else _abortion(self._call)
+
+ def protocol_context(self):
+ return _InvocationProtocolContext()
+
+ def initial_metadata(self):
+ return self._call.initial_metadata()
+
+ def terminal_metadata(self):
+ return self._call.terminal_metadata()
+
+ def code(self):
+ return self._call.code()
+
+ def details(self):
+ return self._call.details()
+
+
+def _blocking_unary_unary(
+ channel, group, method, timeout, with_call, protocol_options, metadata,
+ metadata_transformer, request, request_serializer, response_deserializer):
+ try:
+ multi_callable = channel.unary_unary(
+ _common.fully_qualified_method(group, method),
+ request_serializer=request_serializer,
+ response_deserializer=response_deserializer)
+ effective_metadata = _effective_metadata(metadata, metadata_transformer)
+ if with_call:
+ response, call = multi_callable(
+ request, timeout=timeout, metadata=effective_metadata,
+ credentials=_credentials(protocol_options), with_call=True)
+ return response, _Rendezvous(None, None, call)
+ else:
+ return multi_callable(
+ request, timeout=timeout, metadata=effective_metadata,
+ credentials=_credentials(protocol_options))
+ except grpc.RpcError as rpc_error_call:
+ raise _abortion_error(rpc_error_call)
+
+
+def _future_unary_unary(
+ channel, group, method, timeout, protocol_options, metadata,
+ metadata_transformer, request, request_serializer, response_deserializer):
+ multi_callable = channel.unary_unary(
+ _common.fully_qualified_method(group, method),
+ request_serializer=request_serializer,
+ response_deserializer=response_deserializer)
+ effective_metadata = _effective_metadata(metadata, metadata_transformer)
+ response_future = multi_callable.future(
+ request, timeout=timeout, metadata=effective_metadata,
+ credentials=_credentials(protocol_options))
+ return _Rendezvous(response_future, None, response_future)
+
+
+def _unary_stream(
+ channel, group, method, timeout, protocol_options, metadata,
+ metadata_transformer, request, request_serializer, response_deserializer):
+ multi_callable = channel.unary_stream(
+ _common.fully_qualified_method(group, method),
+ request_serializer=request_serializer,
+ response_deserializer=response_deserializer)
+ effective_metadata = _effective_metadata(metadata, metadata_transformer)
+ response_iterator = multi_callable(
+ request, timeout=timeout, metadata=effective_metadata,
+ credentials=_credentials(protocol_options))
+ return _Rendezvous(None, response_iterator, response_iterator)
+
+
+def _blocking_stream_unary(
+ channel, group, method, timeout, with_call, protocol_options, metadata,
+ metadata_transformer, request_iterator, request_serializer,
+ response_deserializer):
+ try:
+ multi_callable = channel.stream_unary(
+ _common.fully_qualified_method(group, method),
+ request_serializer=request_serializer,
+ response_deserializer=response_deserializer)
+ effective_metadata = _effective_metadata(metadata, metadata_transformer)
+ if with_call:
+ response, call = multi_callable(
+ request_iterator, timeout=timeout, metadata=effective_metadata,
+ credentials=_credentials(protocol_options), with_call=True)
+ return response, _Rendezvous(None, None, call)
+ else:
+ return multi_callable(
+ request_iterator, timeout=timeout, metadata=effective_metadata,
+ credentials=_credentials(protocol_options))
+ except grpc.RpcError as rpc_error_call:
+ raise _abortion_error(rpc_error_call)
+
+
+def _future_stream_unary(
+ channel, group, method, timeout, protocol_options, metadata,
+ metadata_transformer, request_iterator, request_serializer,
+ response_deserializer):
+ multi_callable = channel.stream_unary(
+ _common.fully_qualified_method(group, method),
+ request_serializer=request_serializer,
+ response_deserializer=response_deserializer)
+ effective_metadata = _effective_metadata(metadata, metadata_transformer)
+ response_future = multi_callable.future(
+ request_iterator, timeout=timeout, metadata=effective_metadata,
+ credentials=_credentials(protocol_options))
+ return _Rendezvous(response_future, None, response_future)
+
+
+def _stream_stream(
+ channel, group, method, timeout, protocol_options, metadata,
+ metadata_transformer, request_iterator, request_serializer,
+ response_deserializer):
+ multi_callable = channel.stream_stream(
+ _common.fully_qualified_method(group, method),
+ request_serializer=request_serializer,
+ response_deserializer=response_deserializer)
+ effective_metadata = _effective_metadata(metadata, metadata_transformer)
+ response_iterator = multi_callable(
+ request_iterator, timeout=timeout, metadata=effective_metadata,
+ credentials=_credentials(protocol_options))
+ return _Rendezvous(None, response_iterator, response_iterator)
+
+
+class _UnaryUnaryMultiCallable(face.UnaryUnaryMultiCallable):
+
+ def __init__(
+ self, channel, group, method, metadata_transformer, request_serializer,
+ response_deserializer):
+ self._channel = channel
+ self._group = group
+ self._method = method
+ self._metadata_transformer = metadata_transformer
+ self._request_serializer = request_serializer
+ self._response_deserializer = response_deserializer
+
+ def __call__(
+ self, request, timeout, metadata=None, with_call=False,
+ protocol_options=None):
+ return _blocking_unary_unary(
+ self._channel, self._group, self._method, timeout, with_call,
+ protocol_options, metadata, self._metadata_transformer, request,
+ self._request_serializer, self._response_deserializer)
+
+ def future(self, request, timeout, metadata=None, protocol_options=None):
+ return _future_unary_unary(
+ self._channel, self._group, self._method, timeout, protocol_options,
+ metadata, self._metadata_transformer, request, self._request_serializer,
+ self._response_deserializer)
+
+ def event(
+ self, request, receiver, abortion_callback, timeout,
+ metadata=None, protocol_options=None):
+ raise NotImplementedError()
+
+
+class _UnaryStreamMultiCallable(face.UnaryStreamMultiCallable):
+
+ def __init__(
+ self, channel, group, method, metadata_transformer, request_serializer,
+ response_deserializer):
+ self._channel = channel
+ self._group = group
+ self._method = method
+ self._metadata_transformer = metadata_transformer
+ self._request_serializer = request_serializer
+ self._response_deserializer = response_deserializer
+
+ def __call__(self, request, timeout, metadata=None, protocol_options=None):
+ return _unary_stream(
+ self._channel, self._group, self._method, timeout, protocol_options,
+ metadata, self._metadata_transformer, request, self._request_serializer,
+ self._response_deserializer)
+
+ def event(
+ self, request, receiver, abortion_callback, timeout,
+ metadata=None, protocol_options=None):
+ raise NotImplementedError()
+
+
+class _StreamUnaryMultiCallable(face.StreamUnaryMultiCallable):
+
+ def __init__(
+ self, channel, group, method, metadata_transformer, request_serializer,
+ response_deserializer):
+ self._channel = channel
+ self._group = group
+ self._method = method
+ self._metadata_transformer = metadata_transformer
+ self._request_serializer = request_serializer
+ self._response_deserializer = response_deserializer
+
+ def __call__(
+ self, request_iterator, timeout, metadata=None, with_call=False,
+ protocol_options=None):
+ return _blocking_stream_unary(
+ self._channel, self._group, self._method, timeout, with_call,
+ protocol_options, metadata, self._metadata_transformer,
+ request_iterator, self._request_serializer, self._response_deserializer)
+
+ def future(
+ self, request_iterator, timeout, metadata=None, protocol_options=None):
+ return _future_stream_unary(
+ self._channel, self._group, self._method, timeout, protocol_options,
+ metadata, self._metadata_transformer, request_iterator,
+ self._request_serializer, self._response_deserializer)
+
+ def event(
+ self, receiver, abortion_callback, timeout, metadata=None,
+ protocol_options=None):
+ raise NotImplementedError()
+
+
+class _StreamStreamMultiCallable(face.StreamStreamMultiCallable):
+
+ def __init__(
+ self, channel, group, method, metadata_transformer, request_serializer,
+ response_deserializer):
+ self._channel = channel
+ self._group = group
+ self._method = method
+ self._metadata_transformer = metadata_transformer
+ self._request_serializer = request_serializer
+ self._response_deserializer = response_deserializer
+
+ def __call__(
+ self, request_iterator, timeout, metadata=None, protocol_options=None):
+ return _stream_stream(
+ self._channel, self._group, self._method, timeout, protocol_options,
+ metadata, self._metadata_transformer, request_iterator,
+ self._request_serializer, self._response_deserializer)
+
+ def event(
+ self, receiver, abortion_callback, timeout, metadata=None,
+ protocol_options=None):
+ raise NotImplementedError()
+
+
+class _GenericStub(face.GenericStub):
+
+ def __init__(
+ self, channel, metadata_transformer, request_serializers,
+ response_deserializers):
+ self._channel = channel
+ self._metadata_transformer = metadata_transformer
+ self._request_serializers = request_serializers or {}
+ self._response_deserializers = response_deserializers or {}
+
+ def blocking_unary_unary(
+ self, group, method, request, timeout, metadata=None,
+ with_call=None, protocol_options=None):
+ request_serializer = self._request_serializers.get((group, method,))
+ response_deserializer = self._response_deserializers.get((group, method,))
+ return _blocking_unary_unary(
+ self._channel, group, method, timeout, with_call, protocol_options,
+ metadata, self._metadata_transformer, request, request_serializer,
+ response_deserializer)
+
+ def future_unary_unary(
+ self, group, method, request, timeout, metadata=None,
+ protocol_options=None):
+ request_serializer = self._request_serializers.get((group, method,))
+ response_deserializer = self._response_deserializers.get((group, method,))
+ return _future_unary_unary(
+ self._channel, group, method, timeout, protocol_options, metadata,
+ self._metadata_transformer, request, request_serializer,
+ response_deserializer)
+
+ def inline_unary_stream(
+ self, group, method, request, timeout, metadata=None,
+ protocol_options=None):
+ request_serializer = self._request_serializers.get((group, method,))
+ response_deserializer = self._response_deserializers.get((group, method,))
+ return _unary_stream(
+ self._channel, group, method, timeout, protocol_options, metadata,
+ self._metadata_transformer, request, request_serializer,
+ response_deserializer)
+
+ def blocking_stream_unary(
+ self, group, method, request_iterator, timeout, metadata=None,
+ with_call=None, protocol_options=None):
+ request_serializer = self._request_serializers.get((group, method,))
+ response_deserializer = self._response_deserializers.get((group, method,))
+ return _blocking_stream_unary(
+ self._channel, group, method, timeout, with_call, protocol_options,
+ metadata, self._metadata_transformer, request_iterator,
+ request_serializer, response_deserializer)
+
+ def future_stream_unary(
+ self, group, method, request_iterator, timeout, metadata=None,
+ protocol_options=None):
+ request_serializer = self._request_serializers.get((group, method,))
+ response_deserializer = self._response_deserializers.get((group, method,))
+ return _future_stream_unary(
+ self._channel, group, method, timeout, protocol_options, metadata,
+ self._metadata_transformer, request_iterator, request_serializer,
+ response_deserializer)
+
+ def inline_stream_stream(
+ self, group, method, request_iterator, timeout, metadata=None,
+ protocol_options=None):
+ request_serializer = self._request_serializers.get((group, method,))
+ response_deserializer = self._response_deserializers.get((group, method,))
+ return _stream_stream(
+ self._channel, group, method, timeout, protocol_options, metadata,
+ self._metadata_transformer, request_iterator, request_serializer,
+ response_deserializer)
+
+ def event_unary_unary(
+ self, group, method, request, receiver, abortion_callback, timeout,
+ metadata=None, protocol_options=None):
+ raise NotImplementedError()
+
+ def event_unary_stream(
+ self, group, method, request, receiver, abortion_callback, timeout,
+ metadata=None, protocol_options=None):
+ raise NotImplementedError()
+
+ def event_stream_unary(
+ self, group, method, receiver, abortion_callback, timeout,
+ metadata=None, protocol_options=None):
+ raise NotImplementedError()
+
+ def event_stream_stream(
+ self, group, method, receiver, abortion_callback, timeout,
+ metadata=None, protocol_options=None):
+ raise NotImplementedError()
+
+ def unary_unary(self, group, method):
+ request_serializer = self._request_serializers.get((group, method,))
+ response_deserializer = self._response_deserializers.get((group, method,))
+ return _UnaryUnaryMultiCallable(
+ self._channel, group, method, self._metadata_transformer,
+ request_serializer, response_deserializer)
+
+ def unary_stream(self, group, method):
+ request_serializer = self._request_serializers.get((group, method,))
+ response_deserializer = self._response_deserializers.get((group, method,))
+ return _UnaryStreamMultiCallable(
+ self._channel, group, method, self._metadata_transformer,
+ request_serializer, response_deserializer)
+
+ def stream_unary(self, group, method):
+ request_serializer = self._request_serializers.get((group, method,))
+ response_deserializer = self._response_deserializers.get((group, method,))
+ return _StreamUnaryMultiCallable(
+ self._channel, group, method, self._metadata_transformer,
+ request_serializer, response_deserializer)
+
+ def stream_stream(self, group, method):
+ request_serializer = self._request_serializers.get((group, method,))
+ response_deserializer = self._response_deserializers.get((group, method,))
+ return _StreamStreamMultiCallable(
+ self._channel, group, method, self._metadata_transformer,
+ request_serializer, response_deserializer)
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ return False
+
+
+class _DynamicStub(face.DynamicStub):
+
+ def __init__(self, generic_stub, group, cardinalities):
+ self._generic_stub = generic_stub
+ self._group = group
+ self._cardinalities = cardinalities
+
+ def __getattr__(self, attr):
+ method_cardinality = self._cardinalities.get(attr)
+ if method_cardinality is cardinality.Cardinality.UNARY_UNARY:
+ return self._generic_stub.unary_unary(self._group, attr)
+ elif method_cardinality is cardinality.Cardinality.UNARY_STREAM:
+ return self._generic_stub.unary_stream(self._group, attr)
+ elif method_cardinality is cardinality.Cardinality.STREAM_UNARY:
+ return self._generic_stub.stream_unary(self._group, attr)
+ elif method_cardinality is cardinality.Cardinality.STREAM_STREAM:
+ return self._generic_stub.stream_stream(self._group, attr)
+ else:
+ raise AttributeError('_DynamicStub object has no attribute "%s"!' % attr)
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ return False
+
+
+def generic_stub(
+ channel, host, metadata_transformer, request_serializers,
+ response_deserializers):
+ return _GenericStub(
+ channel, metadata_transformer, request_serializers,
+ response_deserializers)
+
+
+def dynamic_stub(
+ channel, service, cardinalities, host, metadata_transformer,
+ request_serializers, response_deserializers):
+ return _DynamicStub(
+ _GenericStub(
+ channel, metadata_transformer, request_serializers,
+ response_deserializers),
+ service, cardinalities)
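(Editor's note) The generic_stub/dynamic_stub factories above back the public
stub constructors in grpc.beta.implementations. A minimal client-side sketch,
assuming a hypothetical service 'my.package.MyService' with a unary-unary
method 'MyMethod' and an already-built request message:

    from grpc.beta import implementations
    from grpc.framework.common import cardinality

    channel = implementations.insecure_channel('localhost', 50051)
    stub = implementations.dynamic_stub(
        channel, 'my.package.MyService',
        {'MyMethod': cardinality.Cardinality.UNARY_UNARY})
    # Attribute access dispatches on cardinality to one of the multi-callables
    # defined above; the second positional argument is the timeout in seconds.
    response = stub.MyMethod(request, 10)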
diff --git a/src/python/grpcio/grpc/beta/_server_adaptations.py b/src/python/grpcio/grpc/beta/_server_adaptations.py
new file mode 100644
index 0000000000..79e6ca87eb
--- /dev/null
+++ b/src/python/grpcio/grpc/beta/_server_adaptations.py
@@ -0,0 +1,367 @@
+# Copyright 2016, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Translates gRPC's server-side API into gRPC's server-side Beta API."""
+
+import collections
+import threading
+
+import grpc
+from grpc import _common
+from grpc.beta import interfaces
+from grpc.framework.common import cardinality
+from grpc.framework.common import style
+from grpc.framework.foundation import abandonment
+from grpc.framework.foundation import logging_pool
+from grpc.framework.foundation import stream
+from grpc.framework.interfaces.face import face
+
+_DEFAULT_POOL_SIZE = 8
+
+
+class _ServerProtocolContext(interfaces.GRPCServicerContext):
+
+ def __init__(self, servicer_context):
+ self._servicer_context = servicer_context
+
+ def peer(self):
+ return self._servicer_context.peer()
+
+ def disable_next_response_compression(self):
+ pass # TODO(https://github.com/grpc/grpc/issues/4078): design, implement.
+
+
+class _FaceServicerContext(face.ServicerContext):
+
+ def __init__(self, servicer_context):
+ self._servicer_context = servicer_context
+
+ def is_active(self):
+ return self._servicer_context.is_active()
+
+ def time_remaining(self):
+ return self._servicer_context.time_remaining()
+
+ def add_abortion_callback(self, abortion_callback):
+ raise NotImplementedError(
+ 'add_abortion_callback no longer supported server-side!')
+
+ def cancel(self):
+ self._servicer_context.cancel()
+
+ def protocol_context(self):
+ return _ServerProtocolContext(self._servicer_context)
+
+ def invocation_metadata(self):
+ return self._servicer_context.invocation_metadata()
+
+ def initial_metadata(self, initial_metadata):
+ self._servicer_context.send_initial_metadata(initial_metadata)
+
+ def terminal_metadata(self, terminal_metadata):
+ self._servicer_context.set_terminal_metadata(terminal_metadata)
+
+ def code(self, code):
+ self._servicer_context.set_code(code)
+
+ def details(self, details):
+ self._servicer_context.set_details(details)
+
+
+def _adapt_unary_request_inline(unary_request_inline):
+ def adaptation(request, servicer_context):
+ return unary_request_inline(request, _FaceServicerContext(servicer_context))
+ return adaptation
+
+
+def _adapt_stream_request_inline(stream_request_inline):
+ def adaptation(request_iterator, servicer_context):
+ return stream_request_inline(
+ request_iterator, _FaceServicerContext(servicer_context))
+ return adaptation
+
+
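+# Editor's illustrative note (not part of the change): _Callback below bridges
+# the Beta API's event/consumer style and the new API's return-value/generator
+# style. Event-style servicer code pushes responses into it through the
+# stream.Consumer methods consume()/terminate(), while the adapted handler
+# blocks in draw_one_value()/draw_all_values() to turn those pushed values
+# into the single response or response generator that grpc expects.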
+class _Callback(stream.Consumer):
+
+ def __init__(self):
+ self._condition = threading.Condition()
+ self._values = []
+ self._terminated = False
+ self._cancelled = False
+
+ def consume(self, value):
+ with self._condition:
+ self._values.append(value)
+ self._condition.notify_all()
+
+ def terminate(self):
+ with self._condition:
+ self._terminated = True
+ self._condition.notify_all()
+
+ def consume_and_terminate(self, value):
+ with self._condition:
+ self._values.append(value)
+ self._terminated = True
+ self._condition.notify_all()
+
+ def cancel(self):
+ with self._condition:
+ self._cancelled = True
+ self._condition.notify_all()
+
+ def draw_one_value(self):
+ with self._condition:
+ while True:
+ if self._cancelled:
+ raise abandonment.Abandoned()
+ elif self._values:
+ return self._values.pop(0)
+ elif self._terminated:
+ return None
+ else:
+ self._condition.wait()
+
+ def draw_all_values(self):
+ with self._condition:
+ while True:
+ if self._cancelled:
+ raise abandonment.Abandoned()
+ elif self._terminated:
+ all_values = tuple(self._values)
+ self._values = None
+ return all_values
+ else:
+ self._condition.wait()
+
+
+def _pipe_requests(request_iterator, request_consumer, servicer_context):
+ for request in request_iterator:
+ if not servicer_context.is_active():
+ return
+ request_consumer.consume(request)
+ if not servicer_context.is_active():
+ return
+ request_consumer.terminate()
+
+
+def _adapt_unary_unary_event(unary_unary_event):
+ def adaptation(request, servicer_context):
+ callback = _Callback()
+ if not servicer_context.add_callback(callback.cancel):
+ raise abandonment.Abandoned()
+ unary_unary_event(
+ request, callback.consume_and_terminate,
+ _FaceServicerContext(servicer_context))
+ return callback.draw_all_values()[0]
+ return adaptation
+
+
+def _adapt_unary_stream_event(unary_stream_event):
+ def adaptation(request, servicer_context):
+ callback = _Callback()
+ if not servicer_context.add_callback(callback.cancel):
+ raise abandonment.Abandoned()
+ unary_stream_event(
+ request, callback, _FaceServicerContext(servicer_context))
+ while True:
+ response = callback.draw_one_value()
+ if response is None:
+ return
+ else:
+ yield response
+ return adaptation
+
+
+def _adapt_stream_unary_event(stream_unary_event):
+ def adaptation(request_iterator, servicer_context):
+ callback = _Callback()
+ if not servicer_context.add_callback(callback.cancel):
+ raise abandonment.Abandoned()
+ request_consumer = stream_unary_event(
+ callback.consume_and_terminate, _FaceServicerContext(servicer_context))
+ request_pipe_thread = threading.Thread(
+ target=_pipe_requests,
+ args=(request_iterator, request_consumer, servicer_context,))
+ request_pipe_thread.start()
+ return callback.draw_all_values()[0]
+ return adaptation
+
+
+def _adapt_stream_stream_event(stream_stream_event):
+ def adaptation(request_iterator, servicer_context):
+ callback = _Callback()
+ if not servicer_context.add_callback(callback.cancel):
+ raise abandonment.Abandoned()
+ request_consumer = stream_stream_event(
+ callback, _FaceServicerContext(servicer_context))
+ request_pipe_thread = threading.Thread(
+ target=_pipe_requests,
+ args=(request_iterator, request_consumer, servicer_context,))
+ request_pipe_thread.start()
+ while True:
+ response = callback.draw_one_value()
+ if response is None:
+ return
+ else:
+ yield response
+ return adaptation
+
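+# Editor's illustrative note (not part of the change): each event-style
+# adapter above registers callback.cancel with the servicer context so that
+# RPC termination unblocks the waiting handler, and the request-streaming
+# variants feed the incoming request iterator to the Beta request consumer on
+# a separate thread via _pipe_requests so that supplying requests and drawing
+# responses can proceed concurrently.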
+
+class _SimpleMethodHandler(
+ collections.namedtuple(
+ '_MethodHandler',
+ ('request_streaming', 'response_streaming', 'request_deserializer',
+ 'response_serializer', 'unary_unary', 'unary_stream', 'stream_unary',
+ 'stream_stream',)),
+ grpc.RpcMethodHandler):
+ pass
+
+
+def _simple_method_handler(
+ implementation, request_deserializer, response_serializer):
+ if implementation.style is style.Service.INLINE:
+ if implementation.cardinality is cardinality.Cardinality.UNARY_UNARY:
+ return _SimpleMethodHandler(
+ False, False, request_deserializer, response_serializer,
+ _adapt_unary_request_inline(implementation.unary_unary_inline), None,
+ None, None)
+ elif implementation.cardinality is cardinality.Cardinality.UNARY_STREAM:
+ return _SimpleMethodHandler(
+ False, True, request_deserializer, response_serializer, None,
+ _adapt_unary_request_inline(implementation.unary_stream_inline), None,
+ None)
+ elif implementation.cardinality is cardinality.Cardinality.STREAM_UNARY:
+ return _SimpleMethodHandler(
+ True, False, request_deserializer, response_serializer, None, None,
+ _adapt_stream_request_inline(implementation.stream_unary_inline),
+ None)
+ elif implementation.cardinality is cardinality.Cardinality.STREAM_STREAM:
+ return _SimpleMethodHandler(
+ True, True, request_deserializer, response_serializer, None, None,
+ None,
+ _adapt_stream_request_inline(implementation.stream_stream_inline))
+ elif implementation.style is style.Service.EVENT:
+ if implementation.cardinality is cardinality.Cardinality.UNARY_UNARY:
+ return _SimpleMethodHandler(
+ False, False, request_deserializer, response_serializer,
+ _adapt_unary_unary_event(implementation.unary_unary_event), None,
+ None, None)
+ elif implementation.cardinality is cardinality.Cardinality.UNARY_STREAM:
+ return _SimpleMethodHandler(
+ False, True, request_deserializer, response_serializer, None,
+ _adapt_unary_stream_event(implementation.unary_stream_event), None,
+ None)
+ elif implementation.cardinality is cardinality.Cardinality.STREAM_UNARY:
+ return _SimpleMethodHandler(
+ True, False, request_deserializer, response_serializer, None, None,
+ _adapt_stream_unary_event(implementation.stream_unary_event), None)
+ elif implementation.cardinality is cardinality.Cardinality.STREAM_STREAM:
+ return _SimpleMethodHandler(
+ True, True, request_deserializer, response_serializer, None, None,
+ None, _adapt_stream_stream_event(implementation.stream_stream_event))
+
+
+def _flatten_method_pair_map(method_pair_map):
+ method_pair_map = method_pair_map or {}
+ flat_map = {}
+ for method_pair in method_pair_map:
+ method = _common.fully_qualified_method(method_pair[0], method_pair[1])
+ flat_map[method] = method_pair_map[method_pair]
+ return flat_map
+
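+# Editor's note: _flatten_method_pair_map re-keys the Beta-style
+# (group, method) maps by fully-qualified method name -- presumably the
+# '/group/method' form carried in handler_call_details.method -- so that
+# _GenericRpcHandler.service() below can look implementations up directly.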
+
+class _GenericRpcHandler(grpc.GenericRpcHandler):
+
+ def __init__(
+ self, method_implementations, multi_method_implementation,
+ request_deserializers, response_serializers):
+ self._method_implementations = _flatten_method_pair_map(
+ method_implementations)
+ self._request_deserializers = _flatten_method_pair_map(
+ request_deserializers)
+ self._response_serializers = _flatten_method_pair_map(
+ response_serializers)
+ self._multi_method_implementation = multi_method_implementation
+
+ def service(self, handler_call_details):
+ method_implementation = self._method_implementations.get(
+ handler_call_details.method)
+ if method_implementation is not None:
+ return _simple_method_handler(
+ method_implementation,
+ self._request_deserializers.get(handler_call_details.method),
+ self._response_serializers.get(handler_call_details.method))
+ elif self._multi_method_implementation is None:
+ return None
+ else:
+ try:
+ return None #TODO(nathaniel): call the multimethod.
+ except face.NoSuchMethodError:
+ return None
+
+
+class _Server(interfaces.Server):
+
+ def __init__(self, server):
+ self._server = server
+
+ def add_insecure_port(self, address):
+ return self._server.add_insecure_port(address)
+
+ def add_secure_port(self, address, server_credentials):
+ return self._server.add_secure_port(address, server_credentials)
+
+ def start(self):
+ self._server.start()
+
+ def stop(self, grace):
+ return self._server.stop(grace)
+
+ def __enter__(self):
+ self._server.start()
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self._server.stop(None)
+ return False
+
+
+def server(
+ service_implementations, multi_method_implementation, request_deserializers,
+ response_serializers, thread_pool, thread_pool_size):
+ generic_rpc_handler = _GenericRpcHandler(
+ service_implementations, multi_method_implementation,
+ request_deserializers, response_serializers)
+ if thread_pool is None:
+ effective_thread_pool = logging_pool.pool(
+ _DEFAULT_POOL_SIZE if thread_pool_size is None else thread_pool_size)
+ else:
+ effective_thread_pool = thread_pool
+ return _Server(grpc.server((generic_rpc_handler,), effective_thread_pool))
diff --git a/src/python/grpcio/grpc/beta/implementations.py b/src/python/grpcio/grpc/beta/implementations.py
index 822f593323..4ae6e7d675 100644
--- a/src/python/grpcio/grpc/beta/implementations.py
+++ b/src/python/grpcio/grpc/beta/implementations.py
@@ -35,112 +35,36 @@ import enum
import threading # pylint: disable=unused-import
# cardinality and face are referenced from specification in this module.
-from grpc._adapter import _intermediary_low
-from grpc._adapter import _low
+import grpc
+from grpc import _auth
from grpc._adapter import _types
-from grpc.beta import _connectivity_channel
-from grpc.beta import _server
-from grpc.beta import _stub
+from grpc.beta import _client_adaptations
+from grpc.beta import _server_adaptations
from grpc.beta import interfaces
from grpc.framework.common import cardinality # pylint: disable=unused-import
from grpc.framework.interfaces.face import face # pylint: disable=unused-import
-_CHANNEL_SUBSCRIPTION_CALLBACK_ERROR_LOG_MESSAGE = (
- 'Exception calling channel subscription callback!')
+ChannelCredentials = grpc.ChannelCredentials
+ssl_channel_credentials = grpc.ssl_channel_credentials
+CallCredentials = grpc.CallCredentials
+metadata_call_credentials = grpc.metadata_call_credentials
-class ChannelCredentials(object):
- """A value encapsulating the data required to create a secure Channel.
- This class and its instances have no supported interface - it exists to define
- the type of its instances and its instances exist to be passed to other
- functions.
- """
-
- def __init__(self, low_credentials):
- self._low_credentials = low_credentials
-
-
-def ssl_channel_credentials(root_certificates=None, private_key=None,
- certificate_chain=None):
- """Creates a ChannelCredentials for use with an SSL-enabled Channel.
-
- Args:
- root_certificates: The PEM-encoded root certificates or unset to ask for
- them to be retrieved from a default location.
- private_key: The PEM-encoded private key to use or unset if no private key
- should be used.
- certificate_chain: The PEM-encoded certificate chain to use or unset if no
- certificate chain should be used.
-
- Returns:
- A ChannelCredentials for use with an SSL-enabled Channel.
- """
- return ChannelCredentials(_low.channel_credentials_ssl(
- root_certificates, private_key, certificate_chain))
-
-
-class CallCredentials(object):
- """A value encapsulating data asserting an identity over an *established*
- channel. May be composed with ChannelCredentials to always assert identity for
- every call over that channel.
-
- This class and its instances have no supported interface - it exists to define
- the type of its instances and its instances exist to be passed to other
- functions.
- """
-
- def __init__(self, low_credentials):
- self._low_credentials = low_credentials
-
-
-def metadata_call_credentials(metadata_plugin, name=None):
- """Construct CallCredentials from an interfaces.GRPCAuthMetadataPlugin.
-
- Args:
- metadata_plugin: An interfaces.GRPCAuthMetadataPlugin to use in constructing
- the CallCredentials object.
-
- Returns:
- A CallCredentials object for use in a GRPCCallOptions object.
- """
- if name is None:
- name = metadata_plugin.__name__
- return CallCredentials(
- _low.call_credentials_metadata_plugin(metadata_plugin, name))
-
-def composite_call_credentials(call_credentials, additional_call_credentials):
- """Compose two CallCredentials to make a new one.
+def google_call_credentials(credentials):
+ """Construct CallCredentials from GoogleCredentials.
Args:
- call_credentials: A CallCredentials object.
- additional_call_credentials: Another CallCredentials object to compose on
- top of call_credentials.
+ credentials: A GoogleCredentials object from the oauth2client library.
Returns:
A CallCredentials object for use in a GRPCCallOptions object.
"""
- return CallCredentials(
- _low.call_credentials_composite(
- call_credentials._low_credentials,
- additional_call_credentials._low_credentials))
+ return metadata_call_credentials(_auth.GoogleCallCredentials(credentials))
-def composite_channel_credentials(channel_credentials,
- additional_call_credentials):
- """Compose ChannelCredentials on top of client credentials to make a new one.
-
- Args:
- channel_credentials: A ChannelCredentials object.
- additional_call_credentials: A CallCredentials object to compose on
- top of channel_credentials.
-
- Returns:
- A ChannelCredentials object for use in a GRPCCallOptions object.
- """
- return ChannelCredentials(
- _low.channel_credentials_composite(
- channel_credentials._low_credentials,
- additional_call_credentials._low_credentials))
+access_token_call_credentials = grpc.access_token_call_credentials
+composite_call_credentials = grpc.composite_call_credentials
+composite_channel_credentials = grpc.composite_channel_credentials
class Channel(object):
@@ -151,11 +75,8 @@ class Channel(object):
unsupported.
"""
- def __init__(self, low_channel, intermediary_low_channel):
- self._low_channel = low_channel
- self._intermediary_low_channel = intermediary_low_channel
- self._connectivity_channel = _connectivity_channel.ConnectivityChannel(
- low_channel)
+ def __init__(self, channel):
+ self._channel = channel
def subscribe(self, callback, try_to_connect=None):
"""Subscribes to this Channel's connectivity.
@@ -170,7 +91,7 @@ class Channel(object):
attempt to connect if it is not already connected and ready to conduct
RPCs.
"""
- self._connectivity_channel.subscribe(callback, try_to_connect)
+ self._channel.subscribe(callback, try_to_connect=try_to_connect)
def unsubscribe(self, callback):
"""Unsubscribes a callback from this Channel's connectivity.
@@ -179,7 +100,7 @@ class Channel(object):
callback: A callable previously registered with this Channel from having
been passed to its "subscribe" method.
"""
- self._connectivity_channel.unsubscribe(callback)
+ self._channel.unsubscribe(callback)
def insecure_channel(host, port):
@@ -193,9 +114,9 @@ def insecure_channel(host, port):
Returns:
A Channel to the remote host through which RPCs may be conducted.
"""
- intermediary_low_channel = _intermediary_low.Channel(
- '%s:%d' % (host, port) if port else host, None)
- return Channel(intermediary_low_channel._internal, intermediary_low_channel) # pylint: disable=protected-access
+ channel = grpc.insecure_channel(
+ host if port is None else '%s:%d' % (host, port))
+ return Channel(channel)
def secure_channel(host, port, channel_credentials):
@@ -210,10 +131,9 @@ def secure_channel(host, port, channel_credentials):
Returns:
A secure Channel to the remote host through which RPCs may be conducted.
"""
- intermediary_low_channel = _intermediary_low.Channel(
- '%s:%d' % (host, port) if port else host,
- channel_credentials._low_credentials)
- return Channel(intermediary_low_channel._internal, intermediary_low_channel) # pylint: disable=protected-access
+ channel = grpc.secure_channel(
+ host if port is None else '%s:%d' % (host, port), channel_credentials)
+ return Channel(channel)
class StubOptions(object):
@@ -277,12 +197,11 @@ def generic_stub(channel, options=None):
A face.GenericStub on which RPCs can be made.
"""
effective_options = _EMPTY_STUB_OPTIONS if options is None else options
- return _stub.generic_stub(
- channel._intermediary_low_channel, effective_options.host, # pylint: disable=protected-access
- effective_options.metadata_transformer,
+ return _client_adaptations.generic_stub(
+ channel._channel, # pylint: disable=protected-access
+ effective_options.host, effective_options.metadata_transformer,
effective_options.request_serializers,
- effective_options.response_deserializers, effective_options.thread_pool,
- effective_options.thread_pool_size)
+ effective_options.response_deserializers)
def dynamic_stub(channel, service, cardinalities, options=None):
@@ -300,55 +219,16 @@ def dynamic_stub(channel, service, cardinalities, options=None):
A face.DynamicStub with which RPCs can be invoked.
"""
effective_options = StubOptions() if options is None else options
- return _stub.dynamic_stub(
- channel._intermediary_low_channel, effective_options.host, service, # pylint: disable=protected-access
- cardinalities, effective_options.metadata_transformer,
+ return _client_adaptations.dynamic_stub(
+ channel._channel, # pylint: disable=protected-access
+ service, cardinalities, effective_options.host,
+ effective_options.metadata_transformer,
effective_options.request_serializers,
- effective_options.response_deserializers, effective_options.thread_pool,
- effective_options.thread_pool_size)
-
+ effective_options.response_deserializers)
-class ServerCredentials(object):
- """A value encapsulating the data required to open a secure port on a Server.
- This class and its instances have no supported interface - it exists to define
- the type of its instances and its instances exist to be passed to other
- functions.
- """
-
- def __init__(self, low_credentials):
- self._low_credentials = low_credentials
-
-
-def ssl_server_credentials(
- private_key_certificate_chain_pairs, root_certificates=None,
- require_client_auth=False):
- """Creates a ServerCredentials for use with an SSL-enabled Server.
-
- Args:
- private_key_certificate_chain_pairs: A nonempty sequence each element of
- which is a pair the first element of which is a PEM-encoded private key
- and the second element of which is the corresponding PEM-encoded
- certificate chain.
- root_certificates: PEM-encoded client root certificates to be used for
- verifying authenticated clients. If omitted, require_client_auth must also
- be omitted or be False.
- require_client_auth: A boolean indicating whether or not to require clients
- to be authenticated. May only be True if root_certificates is not None.
-
- Returns:
- A ServerCredentials for use with an SSL-enabled Server.
- """
- if len(private_key_certificate_chain_pairs) == 0:
- raise ValueError(
- 'At least one private key-certificate chain pairis required!')
- elif require_client_auth and root_certificates is None:
- raise ValueError(
- 'Illegal to require client auth without providing root certificates!')
- else:
- return ServerCredentials(_low.server_credentials_ssl(
- root_certificates, private_key_certificate_chain_pairs,
- require_client_auth))
+ServerCredentials = grpc.ServerCredentials
+ssl_server_credentials = grpc.ssl_server_credentials
class ServerOptions(object):
@@ -421,9 +301,8 @@ def server(service_implementations, options=None):
An interfaces.Server with which RPCs can be serviced.
"""
effective_options = _EMPTY_SERVER_OPTIONS if options is None else options
- return _server.server(
+ return _server_adaptations.server(
service_implementations, effective_options.multi_method_implementation,
effective_options.request_deserializers,
effective_options.response_serializers, effective_options.thread_pool,
- effective_options.thread_pool_size, effective_options.default_timeout,
- effective_options.maximum_timeout)
+ effective_options.thread_pool_size)
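(Editor's note) A minimal server-side sketch for the rewritten factories above,
assuming the face utilities helper module and hypothetical names (the
'pkg.Svc'/'Method' pair, key_pem, chain_pem); an illustration only, not part of
the change:

    from grpc.beta import implementations
    from grpc.framework.interfaces.face import utilities as face_utilities

    def method_behavior(request, context):
      return request  # placeholder echo behavior

    method_implementations = {
        ('pkg.Svc', 'Method'): face_utilities.unary_unary_inline(
            method_behavior),
    }
    server = implementations.server(method_implementations)
    server_credentials = implementations.ssl_server_credentials(
        [(key_pem, chain_pem)])
    port = server.add_secure_port('[::]:0', server_credentials)
    server.start()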
diff --git a/src/python/grpcio/grpc/beta/interfaces.py b/src/python/grpcio/grpc/beta/interfaces.py
index 24de9ad1a8..90f6bbbfcc 100644
--- a/src/python/grpcio/grpc/beta/interfaces.py
+++ b/src/python/grpcio/grpc/beta/interfaces.py
@@ -30,53 +30,16 @@
"""Constants and interfaces of the Beta API of gRPC Python."""
import abc
-import enum
import six
-from grpc._adapter import _types
+import grpc
+ChannelConnectivity = grpc.ChannelConnectivity
+# FATAL_FAILURE was a Beta-API name for SHUTDOWN
+ChannelConnectivity.FATAL_FAILURE = ChannelConnectivity.SHUTDOWN
-@enum.unique
-class ChannelConnectivity(enum.Enum):
- """Mirrors grpc_connectivity_state in the gRPC Core.
-
- Attributes:
- IDLE: The channel is idle.
- CONNECTING: The channel is connecting.
- READY: The channel is ready to conduct RPCs.
- TRANSIENT_FAILURE: The channel has seen a failure from which it expects to
- recover.
- FATAL_FAILURE: The channel has seen a failure from which it cannot recover.
- """
- IDLE = (_types.ConnectivityState.IDLE, 'idle',)
- CONNECTING = (_types.ConnectivityState.CONNECTING, 'connecting',)
- READY = (_types.ConnectivityState.READY, 'ready',)
- TRANSIENT_FAILURE = (
- _types.ConnectivityState.TRANSIENT_FAILURE, 'transient failure',)
- FATAL_FAILURE = (_types.ConnectivityState.FATAL_FAILURE, 'fatal failure',)
-
-
-@enum.unique
-class StatusCode(enum.Enum):
- """Mirrors grpc_status_code in the C core."""
- OK = 0
- CANCELLED = 1
- UNKNOWN = 2
- INVALID_ARGUMENT = 3
- DEADLINE_EXCEEDED = 4
- NOT_FOUND = 5
- ALREADY_EXISTS = 6
- PERMISSION_DENIED = 7
- RESOURCE_EXHAUSTED = 8
- FAILED_PRECONDITION = 9
- ABORTED = 10
- OUT_OF_RANGE = 11
- UNIMPLEMENTED = 12
- INTERNAL = 13
- UNAVAILABLE = 14
- DATA_LOSS = 15
- UNAUTHENTICATED = 16
+StatusCode = grpc.StatusCode
class GRPCCallOptions(object):
@@ -106,46 +69,9 @@ def grpc_call_options(disable_compression=False, credentials=None):
"""
return GRPCCallOptions(disable_compression, None, credentials)
-
-class GRPCAuthMetadataContext(six.with_metaclass(abc.ABCMeta)):
- """Provides information to call credentials metadata plugins.
-
- Attributes:
- service_url: A string URL of the service being called into.
- method_name: A string of the fully qualified method name being called.
- """
-
-
-class GRPCAuthMetadataPluginCallback(six.with_metaclass(abc.ABCMeta)):
- """Callback object received by a metadata plugin."""
-
- def __call__(self, metadata, error):
- """Inform the gRPC runtime of the metadata to construct a CallCredentials.
-
- Args:
- metadata: An iterable of 2-sequences (e.g. tuples) of metadata key/value
- pairs.
- error: An Exception to indicate error or None to indicate success.
- """
- raise NotImplementedError()
-
-
-class GRPCAuthMetadataPlugin(six.with_metaclass(abc.ABCMeta)):
- """
- """
-
- def __call__(self, context, callback):
- """Invoke the plugin.
-
- Must not block. Need only be called by the gRPC runtime.
-
- Args:
- context: A GRPCAuthMetadataContext providing information on what the
- plugin is being used for.
- callback: A GRPCAuthMetadataPluginCallback to be invoked either
- synchronously or asynchronously.
- """
- raise NotImplementedError()
+GRPCAuthMetadataContext = grpc.AuthMetadataContext
+GRPCAuthMetadataPluginCallback = grpc.AuthMetadataPluginCallback
+GRPCAuthMetadataPlugin = grpc.AuthMetadataPlugin
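+# Editor's illustrative note (not part of the change): with these aliases a
+# plain callable still satisfies the plugin interface, e.g.:
+#
+#   def auth_plugin(context, callback):
+#     callback([('authorization', 'Bearer some-token')], None)
+#
+#   call_credentials = grpc.metadata_call_credentials(
+#       auth_plugin, name='static_token_plugin')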
class GRPCServicerContext(six.with_metaclass(abc.ABCMeta)):
diff --git a/src/python/grpcio/grpc_core_dependencies.py b/src/python/grpcio/grpc_core_dependencies.py
index 33636db29e..839c555f05 100644
--- a/src/python/grpcio/grpc_core_dependencies.py
+++ b/src/python/grpcio/grpc_core_dependencies.py
@@ -42,38 +42,38 @@ CORE_SOURCE_FILES = [
'src/core/lib/support/cpu_windows.c',
'src/core/lib/support/env_linux.c',
'src/core/lib/support/env_posix.c',
- 'src/core/lib/support/env_win32.c',
+ 'src/core/lib/support/env_windows.c',
'src/core/lib/support/histogram.c',
'src/core/lib/support/host_port.c',
'src/core/lib/support/log.c',
'src/core/lib/support/log_android.c',
'src/core/lib/support/log_linux.c',
'src/core/lib/support/log_posix.c',
- 'src/core/lib/support/log_win32.c',
+ 'src/core/lib/support/log_windows.c',
'src/core/lib/support/murmur_hash.c',
'src/core/lib/support/slice.c',
'src/core/lib/support/slice_buffer.c',
'src/core/lib/support/stack_lockfree.c',
'src/core/lib/support/string.c',
'src/core/lib/support/string_posix.c',
- 'src/core/lib/support/string_util_win32.c',
- 'src/core/lib/support/string_win32.c',
+ 'src/core/lib/support/string_util_windows.c',
+ 'src/core/lib/support/string_windows.c',
'src/core/lib/support/subprocess_posix.c',
'src/core/lib/support/subprocess_windows.c',
'src/core/lib/support/sync.c',
'src/core/lib/support/sync_posix.c',
- 'src/core/lib/support/sync_win32.c',
+ 'src/core/lib/support/sync_windows.c',
'src/core/lib/support/thd.c',
'src/core/lib/support/thd_posix.c',
- 'src/core/lib/support/thd_win32.c',
+ 'src/core/lib/support/thd_windows.c',
'src/core/lib/support/time.c',
'src/core/lib/support/time_posix.c',
'src/core/lib/support/time_precise.c',
- 'src/core/lib/support/time_win32.c',
+ 'src/core/lib/support/time_windows.c',
'src/core/lib/support/tls_pthread.c',
'src/core/lib/support/tmpfile_msys.c',
'src/core/lib/support/tmpfile_posix.c',
- 'src/core/lib/support/tmpfile_win32.c',
+ 'src/core/lib/support/tmpfile_windows.c',
'src/core/lib/support/wrap_memcpy.c',
'src/core/lib/surface/init.c',
'src/core/lib/channel/channel_args.c',
@@ -83,7 +83,7 @@ CORE_SOURCE_FILES = [
'src/core/lib/channel/connected_channel.c',
'src/core/lib/channel/http_client_filter.c',
'src/core/lib/channel/http_server_filter.c',
- 'src/core/lib/compression/compression_algorithm.c',
+ 'src/core/lib/compression/compression.c',
'src/core/lib/compression/message_compress.c',
'src/core/lib/debug/trace.c',
'src/core/lib/http/format_request.c',
@@ -94,6 +94,7 @@ CORE_SOURCE_FILES = [
'src/core/lib/iomgr/endpoint_pair_posix.c',
'src/core/lib/iomgr/endpoint_pair_windows.c',
'src/core/lib/iomgr/error.c',
+ 'src/core/lib/iomgr/ev_poll_and_epoll_posix.c',
'src/core/lib/iomgr/ev_poll_posix.c',
'src/core/lib/iomgr/ev_posix.c',
'src/core/lib/iomgr/exec_ctx.c',
@@ -103,6 +104,7 @@ CORE_SOURCE_FILES = [
'src/core/lib/iomgr/iomgr_posix.c',
'src/core/lib/iomgr/iomgr_windows.c',
'src/core/lib/iomgr/load_file.c',
+ 'src/core/lib/iomgr/polling_entity.c',
'src/core/lib/iomgr/pollset_set_windows.c',
'src/core/lib/iomgr/pollset_windows.c',
'src/core/lib/iomgr/resolve_address_posix.c',
@@ -160,6 +162,7 @@ CORE_SOURCE_FILES = [
'src/core/lib/transport/transport.c',
'src/core/lib/transport/transport_op_string.c',
'src/core/ext/transport/chttp2/server/secure/server_secure_chttp2.c',
+ 'src/core/ext/transport/chttp2/transport/bin_decoder.c',
'src/core/ext/transport/chttp2/transport/bin_encoder.c',
'src/core/ext/transport/chttp2/transport/chttp2_plugin.c',
'src/core/ext/transport/chttp2/transport/chttp2_transport.c',
@@ -189,7 +192,7 @@ CORE_SOURCE_FILES = [
'src/core/lib/security/credentials/credentials_metadata.c',
'src/core/lib/security/credentials/fake/fake_credentials.c',
'src/core/lib/security/credentials/google_default/credentials_posix.c',
- 'src/core/lib/security/credentials/google_default/credentials_win32.c',
+ 'src/core/lib/security/credentials/google_default/credentials_windows.c',
'src/core/lib/security/credentials/google_default/google_default_credentials.c',
'src/core/lib/security/credentials/iam/iam_credentials.c',
'src/core/lib/security/credentials/jwt/json_token.c',
@@ -231,10 +234,9 @@ CORE_SOURCE_FILES = [
'src/core/ext/client_config/subchannel_index.c',
'src/core/ext/client_config/uri_parser.c',
'src/core/ext/transport/chttp2/server/insecure/server_chttp2.c',
+ 'src/core/ext/transport/chttp2/server/insecure/server_chttp2_posix.c',
'src/core/ext/transport/chttp2/client/insecure/channel_create.c',
- 'src/core/ext/transport/cronet/client/secure/cronet_channel_create.c',
- 'src/core/ext/transport/cronet/transport/cronet_api_dummy.c',
- 'src/core/ext/transport/cronet/transport/cronet_transport.c',
+ 'src/core/ext/transport/chttp2/client/insecure/channel_create_posix.c',
'src/core/ext/lb_policy/grpclb/load_balancer_api.c',
'src/core/ext/lb_policy/grpclb/proto/grpc/lb/v1/load_balancer.pb.c',
'third_party/nanopb/pb_common.c',
@@ -244,7 +246,10 @@ CORE_SOURCE_FILES = [
'src/core/ext/lb_policy/round_robin/round_robin.c',
'src/core/ext/resolver/dns/native/dns_resolver.c',
'src/core/ext/resolver/sockaddr/sockaddr_resolver.c',
+ 'src/core/ext/load_reporting/load_reporting.c',
+ 'src/core/ext/load_reporting/load_reporting_filter.c',
'src/core/ext/census/context.c',
+ 'src/core/ext/census/gen/census.pb.c',
'src/core/ext/census/grpc_context.c',
'src/core/ext/census/grpc_filter.c',
'src/core/ext/census/grpc_plugin.c',
diff --git a/src/python/grpcio/tests/interop/client.py b/src/python/grpcio/tests/interop/client.py
index db29eb4aa7..8aa1ce30c1 100644
--- a/src/python/grpcio/tests/interop/client.py
+++ b/src/python/grpcio/tests/interop/client.py
@@ -65,39 +65,37 @@ def _args():
help='email address of the default service account', type=str)
return parser.parse_args()
-def _oauth_access_token(args):
- credentials = oauth2client_client.GoogleCredentials.get_application_default()
- scoped_credentials = credentials.create_scoped([args.oauth_scope])
- return scoped_credentials.get_access_token().access_token
def _stub(args):
- if args.oauth_scope:
- if args.test_case == 'oauth2_auth_token':
- # TODO(jtattermusch): This testcase sets the auth metadata key-value
- # manually, which also means that the user would need to do the same
- # thing every time he/she would like to use and out of band oauth token.
- # The transformer function that produces the metadata key-value from
- # the access token should be provided by gRPC auth library.
- access_token = _oauth_access_token(args)
- metadata_transformer = lambda x: [
- ('authorization', 'Bearer %s' % access_token)]
- else:
- metadata_transformer = lambda x: [
- ('authorization', 'Bearer %s' % _oauth_access_token(args))]
+ if args.test_case == 'oauth2_auth_token':
+ creds = oauth2client_client.GoogleCredentials.get_application_default()
+ scoped_creds = creds.create_scoped([args.oauth_scope])
+ access_token = scoped_creds.get_access_token().access_token
+ call_creds = implementations.access_token_call_credentials(access_token)
+ elif args.test_case == 'compute_engine_creds':
+ creds = oauth2client_client.GoogleCredentials.get_application_default()
+ scoped_creds = creds.create_scoped([args.oauth_scope])
+ call_creds = implementations.google_call_credentials(scoped_creds)
+ elif args.test_case == 'jwt_token_creds':
+ creds = oauth2client_client.GoogleCredentials.get_application_default()
+ call_creds = implementations.google_call_credentials(creds)
else:
- metadata_transformer = lambda x: []
+ call_creds = None
if args.use_tls:
if args.use_test_ca:
root_certificates = resources.test_root_certificates()
else:
root_certificates = None # will load default roots.
+ channel_creds = implementations.ssl_channel_credentials(root_certificates)
+ if call_creds is not None:
+ channel_creds = implementations.composite_channel_credentials(
+ channel_creds, call_creds)
+
channel = test_utilities.not_really_secure_channel(
- args.server_host, args.server_port,
- implementations.ssl_channel_credentials(root_certificates),
+ args.server_host, args.server_port, channel_creds,
args.server_host_override)
- stub = test_pb2.beta_create_TestService_stub(
- channel, metadata_transformer=metadata_transformer)
+ stub = test_pb2.beta_create_TestService_stub(channel)
else:
channel = implementations.insecure_channel(
args.server_host, args.server_port)
diff --git a/src/python/grpcio/tests/interop/methods.py b/src/python/grpcio/tests/interop/methods.py
index 67862ed7d3..7eac511525 100644
--- a/src/python/grpcio/tests/interop/methods.py
+++ b/src/python/grpcio/tests/interop/methods.py
@@ -39,6 +39,8 @@ import time
from oauth2client import client as oauth2client_client
+from grpc.beta import implementations
+from grpc.beta import interfaces
from grpc.framework.common import cardinality
from grpc.framework.interfaces.face import face
@@ -88,13 +90,15 @@ class TestService(test_pb2.BetaTestServiceServicer):
return self.FullDuplexCall(request_iterator, context)
-def _large_unary_common_behavior(stub, fill_username, fill_oauth_scope):
+def _large_unary_common_behavior(stub, fill_username, fill_oauth_scope,
+ protocol_options=None):
with stub:
request = messages_pb2.SimpleRequest(
response_type=messages_pb2.COMPRESSABLE, response_size=314159,
payload=messages_pb2.Payload(body=b'\x00' * 271828),
fill_username=fill_username, fill_oauth_scope=fill_oauth_scope)
- response_future = stub.UnaryCall.future(request, _TIMEOUT)
+ response_future = stub.UnaryCall.future(request, _TIMEOUT,
+ protocol_options=protocol_options)
response = response_future.result()
if response.payload.type is not messages_pb2.COMPRESSABLE:
raise ValueError(
@@ -303,7 +307,34 @@ def _oauth2_auth_token(stub, args):
if args.oauth_scope.find(response.oauth_scope) == -1:
raise ValueError(
'expected to find oauth scope "%s" in received "%s"' %
- (response.oauth_scope, args.oauth_scope))
+ (response.oauth_scope, args.oauth_scope))
+
+
+def _jwt_token_creds(stub, args):
+ json_key_filename = os.environ[
+ oauth2client_client.GOOGLE_APPLICATION_CREDENTIALS]
+ wanted_email = json.load(open(json_key_filename, 'rb'))['client_email']
+ response = _large_unary_common_behavior(stub, True, False)
+ if wanted_email != response.username:
+ raise ValueError(
+ 'expected username %s, got %s' % (wanted_email, response.username))
+
+
+def _per_rpc_creds(stub, args):
+ json_key_filename = os.environ[
+ oauth2client_client.GOOGLE_APPLICATION_CREDENTIALS]
+ wanted_email = json.load(open(json_key_filename, 'rb'))['client_email']
+ credentials = oauth2client_client.GoogleCredentials.get_application_default()
+ scoped_credentials = credentials.create_scoped([args.oauth_scope])
+ call_creds = implementations.google_call_credentials(scoped_credentials)
+ options = interfaces.grpc_call_options(disable_compression=False,
+ credentials=call_creds)
+ response = _large_unary_common_behavior(stub, True, False,
+ protocol_options=options)
+ if wanted_email != response.username:
+ raise ValueError(
+ 'expected username %s, got %s' % (wanted_email, response.username))
+
@enum.unique
class TestCase(enum.Enum):
@@ -317,6 +348,8 @@ class TestCase(enum.Enum):
EMPTY_STREAM = 'empty_stream'
COMPUTE_ENGINE_CREDS = 'compute_engine_creds'
OAUTH2_AUTH_TOKEN = 'oauth2_auth_token'
+ JWT_TOKEN_CREDS = 'jwt_token_creds'
+ PER_RPC_CREDS = 'per_rpc_creds'
TIMEOUT_ON_SLEEPING_SERVER = 'timeout_on_sleeping_server'
def test_interoperability(self, stub, args):
@@ -342,5 +375,9 @@ class TestCase(enum.Enum):
_compute_engine_creds(stub, args)
elif self is TestCase.OAUTH2_AUTH_TOKEN:
_oauth2_auth_token(stub, args)
+ elif self is TestCase.JWT_TOKEN_CREDS:
+ _jwt_token_creds(stub, args)
+ elif self is TestCase.PER_RPC_CREDS:
+ _per_rpc_creds(stub, args)
else:
raise NotImplementedError('Test case "%s" not implemented!' % self.name)
diff --git a/src/python/grpcio/tests/protoc_plugin/_python_plugin_test.py b/src/python/grpcio/tests/protoc_plugin/_python_plugin_test.py
new file mode 100644
index 0000000000..1c9cbb0d0c
--- /dev/null
+++ b/src/python/grpcio/tests/protoc_plugin/_python_plugin_test.py
@@ -0,0 +1,583 @@
+# Copyright 2015, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import collections
+from concurrent import futures
+import contextlib
+import distutils.spawn
+import errno
+import os
+import shutil
+import subprocess
+import sys
+import tempfile
+import threading
+import unittest
+
+from six import moves
+
+import grpc
+from tests.unit.framework.common import test_constants
+
+# Identifiers of entities we expect to find in the generated module.
+STUB_IDENTIFIER = 'TestServiceStub'
+SERVICER_IDENTIFIER = 'TestServiceServicer'
+ADD_SERVICER_TO_SERVER_IDENTIFIER = 'add_TestServiceServicer_to_server'
+
+
+class _ServicerMethods(object):
+
+ def __init__(self, response_pb2, payload_pb2):
+ self._condition = threading.Condition()
+ self._paused = False
+ self._fail = False
+ self._response_pb2 = response_pb2
+ self._payload_pb2 = payload_pb2
+
+ @contextlib.contextmanager
+ def pause(self): # pylint: disable=invalid-name
+ with self._condition:
+ self._paused = True
+ yield
+ with self._condition:
+ self._paused = False
+ self._condition.notify_all()
+
+ @contextlib.contextmanager
+ def fail(self): # pylint: disable=invalid-name
+ with self._condition:
+ self._fail = True
+ yield
+ with self._condition:
+ self._fail = False
+
+ def _control(self): # pylint: disable=invalid-name
+ with self._condition:
+ if self._fail:
+ raise ValueError()
+ while self._paused:
+ self._condition.wait()
+
+ def UnaryCall(self, request, unused_rpc_context):
+ response = self._response_pb2.SimpleResponse()
+ response.payload.payload_type = self._payload_pb2.COMPRESSABLE
+ response.payload.payload_compressable = 'a' * request.response_size
+ self._control()
+ return response
+
+ def StreamingOutputCall(self, request, unused_rpc_context):
+ for parameter in request.response_parameters:
+ response = self._response_pb2.StreamingOutputCallResponse()
+ response.payload.payload_type = self._payload_pb2.COMPRESSABLE
+ response.payload.payload_compressable = 'a' * parameter.size
+ self._control()
+ yield response
+
+ def StreamingInputCall(self, request_iter, unused_rpc_context):
+ response = self._response_pb2.StreamingInputCallResponse()
+ aggregated_payload_size = 0
+ for request in request_iter:
+ aggregated_payload_size += len(request.payload.payload_compressable)
+ response.aggregated_payload_size = aggregated_payload_size
+ self._control()
+ return response
+
+ def FullDuplexCall(self, request_iter, unused_rpc_context):
+ for request in request_iter:
+ for parameter in request.response_parameters:
+ response = self._response_pb2.StreamingOutputCallResponse()
+ response.payload.payload_type = self._payload_pb2.COMPRESSABLE
+ response.payload.payload_compressable = 'a' * parameter.size
+ self._control()
+ yield response
+
+ def HalfDuplexCall(self, request_iter, unused_rpc_context):
+ responses = []
+ for request in request_iter:
+ for parameter in request.response_parameters:
+ response = self._response_pb2.StreamingOutputCallResponse()
+ response.payload.payload_type = self._payload_pb2.COMPRESSABLE
+ response.payload.payload_compressable = 'a' * parameter.size
+ self._control()
+ responses.append(response)
+ for response in responses:
+ yield response
+
+
+class _Service(
+ collections.namedtuple(
+ '_Service', ('servicer_methods', 'server', 'stub',))):
+ """A live and running service.
+
+ Attributes:
+ servicer_methods: The _ServicerMethods servicing RPCs.
+ server: The grpc.Server servicing RPCs.
+ stub: A stub on which to invoke RPCs.
+ """
+
+
+def _CreateService(service_pb2, response_pb2, payload_pb2):
+ """Provides a servicer backend and a stub.
+
+ Args:
+ service_pb2: The service_pb2 module generated by this test.
+ response_pb2: The response_pb2 module generated by this test.
+ payload_pb2: The payload_pb2 module generated by this test.
+
+ Returns:
+ A _Service with which to test RPCs.
+ """
+ servicer_methods = _ServicerMethods(response_pb2, payload_pb2)
+
+ class Servicer(getattr(service_pb2, SERVICER_IDENTIFIER)):
+
+ def UnaryCall(self, request, context):
+ return servicer_methods.UnaryCall(request, context)
+
+ def StreamingOutputCall(self, request, context):
+ return servicer_methods.StreamingOutputCall(request, context)
+
+ def StreamingInputCall(self, request_iter, context):
+ return servicer_methods.StreamingInputCall(request_iter, context)
+
+ def FullDuplexCall(self, request_iter, context):
+ return servicer_methods.FullDuplexCall(request_iter, context)
+
+ def HalfDuplexCall(self, request_iter, context):
+ return servicer_methods.HalfDuplexCall(request_iter, context)
+
+ server = grpc.server(
+ (), futures.ThreadPoolExecutor(max_workers=test_constants.POOL_SIZE))
+ getattr(service_pb2, ADD_SERVICER_TO_SERVER_IDENTIFIER)(Servicer(), server)
+ port = server.add_insecure_port('[::]:0')
+ server.start()
+ channel = grpc.insecure_channel('localhost:{}'.format(port))
+ stub = getattr(service_pb2, STUB_IDENTIFIER)(channel)
+ return _Service(servicer_methods, server, stub)
+
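+# Editor's note: typical use in the tests below is
+#   service = _CreateService(service_pb2, response_pb2, payload_pb2)
+#   response = service.stub.UnaryCall(request)
+# i.e. the returned stub talks to the in-process server over an insecure
+# local channel on an ephemeral port.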
+
+def _CreateIncompleteService(service_pb2):
+ """Provides a servicer backend that fails to implement methods and its stub.
+
+ Args:
+ service_pb2: The service_pb2 module generated by this test.
+
+ Returns:
+ A _Service with which to test RPCs. The returned _Service's
+ servicer_methods implements none of the methods required of it.
+ """
+
+ class Servicer(getattr(service_pb2, SERVICER_IDENTIFIER)):
+ pass
+
+ server = grpc.server(
+ (), futures.ThreadPoolExecutor(max_workers=test_constants.POOL_SIZE))
+ getattr(service_pb2, ADD_SERVICER_TO_SERVER_IDENTIFIER)(Servicer(), server)
+ port = server.add_insecure_port('[::]:0')
+ server.start()
+ channel = grpc.insecure_channel('localhost:{}'.format(port))
+ stub = getattr(service_pb2, STUB_IDENTIFIER)(channel)
+ return _Service(None, server, stub)
+
+
+def _streaming_input_request_iterator(request_pb2, payload_pb2):
+ for _ in range(3):
+ request = request_pb2.StreamingInputCallRequest()
+ request.payload.payload_type = payload_pb2.COMPRESSABLE
+ request.payload.payload_compressable = 'a'
+ yield request
+
+
+def _streaming_output_request(request_pb2):
+ request = request_pb2.StreamingOutputCallRequest()
+ sizes = [1, 2, 3]
+ request.response_parameters.add(size=sizes[0], interval_us=0)
+ request.response_parameters.add(size=sizes[1], interval_us=0)
+ request.response_parameters.add(size=sizes[2], interval_us=0)
+ return request
+
+
+def _full_duplex_request_iterator(request_pb2):
+ request = request_pb2.StreamingOutputCallRequest()
+ request.response_parameters.add(size=1, interval_us=0)
+ yield request
+ request = request_pb2.StreamingOutputCallRequest()
+ request.response_parameters.add(size=2, interval_us=0)
+ request.response_parameters.add(size=3, interval_us=0)
+ yield request
+
+
+class PythonPluginTest(unittest.TestCase):
+ """Test case for the gRPC Python protoc-plugin.
+
+ While reading these tests, remember that the futures API
+ (`stub.method.future()`) only gives futures for the *response-unary*
+ methods and does not exist for response-streaming methods.
+ """
+
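+  # Editor's note: concretely, stub.UnaryCall.future(request) and
+  # stub.StreamingInputCall.future(request_iterator) are exercised below,
+  # while the response-streaming methods (StreamingOutputCall, FullDuplexCall,
+  # HalfDuplexCall) are invoked directly and yield response iterators.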
+ def setUp(self):
+ # Assume that the appropriate protoc and grpc_python_plugins are on the
+ # path.
+ protoc_command = 'protoc'
+ protoc_plugin_filename = distutils.spawn.find_executable(
+ 'grpc_python_plugin')
+ if not os.path.isfile(protoc_command):
+      # Assume that if we haven't built protoc, it's on the system.
+ protoc_command = 'protoc'
+
+ # Ensure that the output directory exists.
+ self.outdir = tempfile.mkdtemp()
+
+ # Find all proto files
+ paths = []
+ root_dir = os.path.dirname(os.path.realpath(__file__))
+ proto_dir = os.path.join(root_dir, 'protos')
+ for walk_root, _, filenames in os.walk(proto_dir):
+ for filename in filenames:
+ if filename.endswith('.proto'):
+ path = os.path.join(walk_root, filename)
+ paths.append(path)
+
+ # Invoke protoc with the plugin.
+ cmd = [
+ protoc_command,
+ '--plugin=protoc-gen-python-grpc=%s' % protoc_plugin_filename,
+ '-I %s' % root_dir,
+ '--python_out=%s' % self.outdir,
+ '--python-grpc_out=%s' % self.outdir
+ ] + paths
+ subprocess.check_call(' '.join(cmd), shell=True, env=os.environ,
+ cwd=os.path.dirname(os.path.realpath(__file__)))
+
+ # Generated proto directories dont include __init__.py, but
+ # these are needed for python package resolution
+ for walk_root, _, _ in os.walk(os.path.join(self.outdir, 'protos')):
+ path = os.path.join(walk_root, '__init__.py')
+ open(path, 'a').close()
+
+ sys.path.insert(0, self.outdir)
+
+ import protos.payload.test_payload_pb2 as payload_pb2
+ import protos.requests.r.test_requests_pb2 as request_pb2
+ import protos.responses.test_responses_pb2 as response_pb2
+ import protos.service.test_service_pb2 as service_pb2
+ self._payload_pb2 = payload_pb2
+ self._request_pb2 = request_pb2
+ self._response_pb2 = response_pb2
+ self._service_pb2 = service_pb2
+
+ def tearDown(self):
+ try:
+ shutil.rmtree(self.outdir)
+ except OSError as exc:
+ if exc.errno != errno.ENOENT:
+ raise
+ sys.path.remove(self.outdir)
+
+ def testImportAttributes(self):
+ # check that we can access the generated module and its members.
+ self.assertIsNotNone(
+ getattr(self._service_pb2, STUB_IDENTIFIER, None))
+ self.assertIsNotNone(
+ getattr(self._service_pb2, SERVICER_IDENTIFIER, None))
+ self.assertIsNotNone(
+ getattr(self._service_pb2, ADD_SERVICER_TO_SERVER_IDENTIFIER, None))
+
+ def testUpDown(self):
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ self.assertIsNotNone(service.servicer_methods)
+ self.assertIsNotNone(service.server)
+ self.assertIsNotNone(service.stub)
+
+ def testIncompleteServicer(self):
+ service = _CreateIncompleteService(self._service_pb2)
+ request = self._request_pb2.SimpleRequest(response_size=13)
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ service.stub.UnaryCall(request)
+ self.assertIs(
+ exception_context.exception.code(), grpc.StatusCode.UNIMPLEMENTED)
+
+ def testUnaryCall(self):
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ request = self._request_pb2.SimpleRequest(response_size=13)
+ response = service.stub.UnaryCall(request)
+ expected_response = service.servicer_methods.UnaryCall(
+ request, 'not a real context!')
+ self.assertEqual(expected_response, response)
+
+ def testUnaryCallFuture(self):
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ request = self._request_pb2.SimpleRequest(response_size=13)
+ # Check that the call does not block waiting for the server to respond.
+ with service.servicer_methods.pause():
+ response_future = service.stub.UnaryCall.future(request)
+ response = response_future.result()
+ expected_response = service.servicer_methods.UnaryCall(
+ request, 'not a real RpcContext!')
+ self.assertEqual(expected_response, response)
+
+ def testUnaryCallFutureExpired(self):
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ request = self._request_pb2.SimpleRequest(response_size=13)
+ with service.servicer_methods.pause():
+ response_future = service.stub.UnaryCall.future(
+ request, timeout=test_constants.SHORT_TIMEOUT)
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ response_future.result()
+ self.assertIs(
+ exception_context.exception.code(), grpc.StatusCode.DEADLINE_EXCEEDED)
+ self.assertIs(response_future.code(), grpc.StatusCode.DEADLINE_EXCEEDED)
+
+ def testUnaryCallFutureCancelled(self):
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ request = self._request_pb2.SimpleRequest(response_size=13)
+ with service.servicer_methods.pause():
+ response_future = service.stub.UnaryCall.future(request)
+ response_future.cancel()
+ self.assertTrue(response_future.cancelled())
+ self.assertIs(response_future.code(), grpc.StatusCode.CANCELLED)
+
+ def testUnaryCallFutureFailed(self):
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ request = self._request_pb2.SimpleRequest(response_size=13)
+ with service.servicer_methods.fail():
+ response_future = service.stub.UnaryCall.future(request)
+ self.assertIsNotNone(response_future.exception())
+ self.assertIs(response_future.code(), grpc.StatusCode.UNKNOWN)
+
+ def testStreamingOutputCall(self):
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ request = _streaming_output_request(self._request_pb2)
+ responses = service.stub.StreamingOutputCall(request)
+ expected_responses = service.servicer_methods.StreamingOutputCall(
+ request, 'not a real RpcContext!')
+ for expected_response, response in moves.zip_longest(
+ expected_responses, responses):
+ self.assertEqual(expected_response, response)
+
+ def testStreamingOutputCallExpired(self):
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ request = _streaming_output_request(self._request_pb2)
+ with service.servicer_methods.pause():
+ responses = service.stub.StreamingOutputCall(
+ request, timeout=test_constants.SHORT_TIMEOUT)
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ list(responses)
+ self.assertIs(
+ exception_context.exception.code(), grpc.StatusCode.DEADLINE_EXCEEDED)
+
+ def testStreamingOutputCallCancelled(self):
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ request = _streaming_output_request(self._request_pb2)
+ responses = service.stub.StreamingOutputCall(request)
+ next(responses)
+ responses.cancel()
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ next(responses)
+ self.assertIs(responses.code(), grpc.StatusCode.CANCELLED)
+
+ def testStreamingOutputCallFailed(self):
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ request = _streaming_output_request(self._request_pb2)
+ with service.servicer_methods.fail():
+ responses = service.stub.StreamingOutputCall(request)
+ self.assertIsNotNone(responses)
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ next(responses)
+ self.assertIs(exception_context.exception.code(), grpc.StatusCode.UNKNOWN)
+
+ def testStreamingInputCall(self):
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ response = service.stub.StreamingInputCall(
+ _streaming_input_request_iterator(
+ self._request_pb2, self._payload_pb2))
+ expected_response = service.servicer_methods.StreamingInputCall(
+ _streaming_input_request_iterator(self._request_pb2, self._payload_pb2),
+ 'not a real RpcContext!')
+ self.assertEqual(expected_response, response)
+
+ def testStreamingInputCallFuture(self):
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ with service.servicer_methods.pause():
+ response_future = service.stub.StreamingInputCall.future(
+ _streaming_input_request_iterator(
+ self._request_pb2, self._payload_pb2))
+ response = response_future.result()
+ expected_response = service.servicer_methods.StreamingInputCall(
+ _streaming_input_request_iterator(self._request_pb2, self._payload_pb2),
+ 'not a real RpcContext!')
+ self.assertEqual(expected_response, response)
+
+ def testStreamingInputCallFutureExpired(self):
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ with service.servicer_methods.pause():
+ response_future = service.stub.StreamingInputCall.future(
+ _streaming_input_request_iterator(
+ self._request_pb2, self._payload_pb2),
+ timeout=test_constants.SHORT_TIMEOUT)
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ response_future.result()
+ self.assertIsInstance(response_future.exception(), grpc.RpcError)
+ self.assertIs(
+ response_future.exception().code(), grpc.StatusCode.DEADLINE_EXCEEDED)
+ self.assertIs(
+ exception_context.exception.code(), grpc.StatusCode.DEADLINE_EXCEEDED)
+
+ def testStreamingInputCallFutureCancelled(self):
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ with service.servicer_methods.pause():
+ response_future = service.stub.StreamingInputCall.future(
+ _streaming_input_request_iterator(
+ self._request_pb2, self._payload_pb2))
+ response_future.cancel()
+ self.assertTrue(response_future.cancelled())
+ with self.assertRaises(grpc.FutureCancelledError):
+ response_future.result()
+
+ def testStreamingInputCallFutureFailed(self):
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ with service.servicer_methods.fail():
+ response_future = service.stub.StreamingInputCall.future(
+ _streaming_input_request_iterator(
+ self._request_pb2, self._payload_pb2))
+ self.assertIsNotNone(response_future.exception())
+ self.assertIs(response_future.code(), grpc.StatusCode.UNKNOWN)
+
+ def testFullDuplexCall(self):
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ responses = service.stub.FullDuplexCall(
+ _full_duplex_request_iterator(self._request_pb2))
+ expected_responses = service.servicer_methods.FullDuplexCall(
+ _full_duplex_request_iterator(self._request_pb2),
+ 'not a real RpcContext!')
+ for expected_response, response in moves.zip_longest(
+ expected_responses, responses):
+ self.assertEqual(expected_response, response)
+
+ def testFullDuplexCallExpired(self):
+ request_iterator = _full_duplex_request_iterator(self._request_pb2)
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ with service.servicer_methods.pause():
+ responses = service.stub.FullDuplexCall(
+ request_iterator, timeout=test_constants.SHORT_TIMEOUT)
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ list(responses)
+ self.assertIs(
+ exception_context.exception.code(), grpc.StatusCode.DEADLINE_EXCEEDED)
+
+ def testFullDuplexCallCancelled(self):
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ request_iterator = _full_duplex_request_iterator(self._request_pb2)
+ responses = service.stub.FullDuplexCall(request_iterator)
+ next(responses)
+ responses.cancel()
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ next(responses)
+ self.assertIs(
+ exception_context.exception.code(), grpc.StatusCode.CANCELLED)
+
+ def testFullDuplexCallFailed(self):
+ request_iterator = _full_duplex_request_iterator(self._request_pb2)
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ with service.servicer_methods.fail():
+ responses = service.stub.FullDuplexCall(request_iterator)
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ next(responses)
+ self.assertIs(exception_context.exception.code(), grpc.StatusCode.UNKNOWN)
+
+ def testHalfDuplexCall(self):
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ def half_duplex_request_iterator():
+ request = self._request_pb2.StreamingOutputCallRequest()
+ request.response_parameters.add(size=1, interval_us=0)
+ yield request
+ request = self._request_pb2.StreamingOutputCallRequest()
+ request.response_parameters.add(size=2, interval_us=0)
+ request.response_parameters.add(size=3, interval_us=0)
+ yield request
+ responses = service.stub.HalfDuplexCall(half_duplex_request_iterator())
+ expected_responses = service.servicer_methods.HalfDuplexCall(
+ half_duplex_request_iterator(), 'not a real RpcContext!')
+ for expected_response, response in moves.zip_longest(
+ expected_responses, responses):
+ self.assertEqual(expected_response, response)
+
+ def testHalfDuplexCallWedged(self):
+ condition = threading.Condition()
+ wait_cell = [False]
+ @contextlib.contextmanager
+ def wait(): # pylint: disable=invalid-name
+ # Where's Python 3's 'nonlocal' statement when you need it?
+ with condition:
+ wait_cell[0] = True
+ yield
+ with condition:
+ wait_cell[0] = False
+ condition.notify_all()
+ def half_duplex_request_iterator():
+ request = self._request_pb2.StreamingOutputCallRequest()
+ request.response_parameters.add(size=1, interval_us=0)
+ yield request
+ with condition:
+ while wait_cell[0]:
+ condition.wait()
+ service = _CreateService(
+ self._service_pb2, self._response_pb2, self._payload_pb2)
+ with wait():
+ responses = service.stub.HalfDuplexCall(
+ half_duplex_request_iterator(), timeout=test_constants.SHORT_TIMEOUT)
+      # A half-duplex call waits for the client to finish sending its requests before responding.
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ next(responses)
+ self.assertIs(
+ exception_context.exception.code(), grpc.StatusCode.DEADLINE_EXCEEDED)
+
+
+if __name__ == '__main__':
+ unittest.main(verbosity=2)
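
The plugin tests above drive the regenerated stubs through the new grpc surface: a deadline surfaces as a grpc.RpcError carrying StatusCode.DEADLINE_EXCEEDED, cancellation as StatusCode.CANCELLED, and a servicer exception as StatusCode.UNKNOWN. A minimal, self-contained sketch of that status-code handling, with a placeholder address and a hypothetical method path and no generated code (nothing is expected to be listening):

    import grpc

    channel = grpc.insecure_channel('localhost:12345')  # placeholder address
    multi_callable = channel.unary_unary('/test/UnaryCall')  # hypothetical method path
    try:
        # With no server behind the address, the call fails with a status code.
        multi_callable(b'raw request bytes', timeout=0.1)
    except grpc.RpcError as rpc_error:
        # The same pattern the tests rely on: branch on the code of the error.
        assert rpc_error.code() in (grpc.StatusCode.DEADLINE_EXCEEDED,
                                    grpc.StatusCode.UNAVAILABLE)
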
diff --git a/src/python/grpcio/tests/qps/benchmark_client.py b/src/python/grpcio/tests/qps/benchmark_client.py
index b372ea01ad..080281415d 100644
--- a/src/python/grpcio/tests/qps/benchmark_client.py
+++ b/src/python/grpcio/tests/qps/benchmark_client.py
@@ -30,14 +30,13 @@
"""Defines test client behaviors (UNARY/STREAMING) (SYNC/ASYNC)."""
import abc
+import threading
import time
-try:
- import Queue as queue # Python 2.x
-except ImportError:
- import queue # Python 3
from concurrent import futures
+from six.moves import queue
+import grpc
from grpc.beta import implementations
from grpc.framework.interfaces.face import face
from src.proto.grpc.testing import messages_pb2
@@ -65,6 +64,13 @@ class BenchmarkClient:
else:
channel = implementations.insecure_channel(host, port)
+ connected_event = threading.Event()
+ def wait_for_ready(connectivity):
+ if connectivity == grpc.ChannelConnectivity.READY:
+ connected_event.set()
+ channel.subscribe(wait_for_ready, try_to_connect=True)
+ connected_event.wait()
+
if config.payload_config.WhichOneof('payload') == 'simple_params':
self._generic = False
self._stub = services_pb2.beta_create_BenchmarkService_stub(channel)
@@ -82,6 +88,7 @@ class BenchmarkClient:
self._response_callbacks = []
def add_response_callback(self, callback):
+ """callback will be invoked as callback(client, query_time)"""
self._response_callbacks.append(callback)
@abc.abstractmethod
@@ -95,10 +102,10 @@ class BenchmarkClient:
def stop(self):
pass
- def _handle_response(self, query_time):
+ def _handle_response(self, client, query_time):
self._hist.add(query_time * 1e9) # Report times in nanoseconds
for callback in self._response_callbacks:
- callback(query_time)
+ callback(client, query_time)
class UnarySyncBenchmarkClient(BenchmarkClient):
@@ -121,7 +128,7 @@ class UnarySyncBenchmarkClient(BenchmarkClient):
start_time = time.time()
self._stub.UnaryCall(self._request, _TIMEOUT)
end_time = time.time()
- self._handle_response(end_time - start_time)
+ self._handle_response(self, end_time - start_time)
class UnaryAsyncBenchmarkClient(BenchmarkClient):
@@ -136,19 +143,20 @@ class UnaryAsyncBenchmarkClient(BenchmarkClient):
def _response_received(self, start_time, resp):
resp.result()
end_time = time.time()
- self._handle_response(end_time - start_time)
+ self._handle_response(self, end_time - start_time)
def stop(self):
self._stub = None
-class StreamingSyncBenchmarkClient(BenchmarkClient):
+class _SyncStream(object):
- def __init__(self, server, config, hist):
- super(StreamingSyncBenchmarkClient, self).__init__(server, config, hist)
+ def __init__(self, stub, generic, request, handle_response):
+ self._stub = stub
+ self._generic = generic
+ self._request = request
+ self._handle_response = handle_response
self._is_streaming = False
- self._pool = futures.ThreadPoolExecutor(max_workers=1)
- # Use a thread-safe queue to put requests on the stream
self._request_queue = queue.Queue()
self._send_time_queue = queue.Queue()
@@ -158,15 +166,6 @@ class StreamingSyncBenchmarkClient(BenchmarkClient):
def start(self):
self._is_streaming = True
- self._pool.submit(self._request_stream)
-
- def stop(self):
- self._is_streaming = False
- self._pool.shutdown(wait=True)
- self._stub = None
-
- def _request_stream(self):
- self._is_streaming = True
if self._generic:
stream_callable = self._stub.stream_stream(
'grpc.testing.BenchmarkService', 'StreamingCall')
@@ -175,8 +174,11 @@ class StreamingSyncBenchmarkClient(BenchmarkClient):
response_stream = stream_callable(self._request_generator(), _TIMEOUT)
for _ in response_stream:
- end_time = time.time()
- self._handle_response(end_time - self._send_time_queue.get_nowait())
+ self._handle_response(
+ self, time.time() - self._send_time_queue.get_nowait())
+
+ def stop(self):
+ self._is_streaming = False
def _request_generator(self):
while self._is_streaming:
@@ -187,46 +189,28 @@ class StreamingSyncBenchmarkClient(BenchmarkClient):
pass
-class AsyncReceiver(face.ResponseReceiver):
- """Receiver for async stream responses."""
-
- def __init__(self, send_time_queue, response_handler):
- self._send_time_queue = send_time_queue
- self._response_handler = response_handler
-
- def initial_metadata(self, initial_mdetadata):
- pass
-
- def response(self, response):
- end_time = time.time()
- self._response_handler(end_time - self._send_time_queue.get_nowait())
-
- def complete(self, terminal_metadata, code, details):
- pass
-
-
-class StreamingAsyncBenchmarkClient(BenchmarkClient):
+class StreamingSyncBenchmarkClient(BenchmarkClient):
def __init__(self, server, config, hist):
- super(StreamingAsyncBenchmarkClient, self).__init__(server, config, hist)
- self._send_time_queue = queue.Queue()
- self._receiver = AsyncReceiver(self._send_time_queue, self._handle_response)
- self._rendezvous = None
+ super(StreamingSyncBenchmarkClient, self).__init__(server, config, hist)
+ self._pool = futures.ThreadPoolExecutor(
+ max_workers=config.outstanding_rpcs_per_channel)
+ self._streams = [_SyncStream(self._stub, self._generic,
+ self._request, self._handle_response)
+ for _ in xrange(config.outstanding_rpcs_per_channel)]
+ self._curr_stream = 0
def send_request(self):
- if self._rendezvous is not None:
- self._send_time_queue.put(time.time())
- self._rendezvous.consume(self._request)
+    # Round-robin across the streams to choose which one carries this request.
+ self._streams[self._curr_stream].send_request()
+ self._curr_stream = (self._curr_stream + 1) % len(self._streams)
def start(self):
- if self._generic:
- stream_callable = self._stub.stream_stream(
- 'grpc.testing.BenchmarkService', 'StreamingCall')
- else:
- stream_callable = self._stub.StreamingCall
- self._rendezvous = stream_callable.event(
- self._receiver, lambda *args: None, _TIMEOUT)
+ for stream in self._streams:
+ self._pool.submit(stream.start)
def stop(self):
- self._rendezvous.terminate()
- self._rendezvous = None
+ for stream in self._streams:
+ stream.stop()
+ self._pool.shutdown(wait=True)
+ self._stub = None
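
The hunk above also makes the benchmark client block until its channel is usable: it subscribes a connectivity callback and waits on an Event until READY is observed. A rough equivalent against the public grpc channel API (placeholder address; with no server running the wait simply times out and returns False):

    import threading

    import grpc

    channel = grpc.insecure_channel('localhost:50051')  # placeholder address
    connected = threading.Event()

    def _note_connectivity(connectivity):
        # Invoked on every state change; latch once the channel reports READY.
        if connectivity is grpc.ChannelConnectivity.READY:
            connected.set()

    channel.subscribe(_note_connectivity, try_to_connect=True)
    connected.wait(timeout=5.0)
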
diff --git a/src/python/grpcio/tests/qps/client_runner.py b/src/python/grpcio/tests/qps/client_runner.py
index 1ede7d2af1..1fd58687ad 100644
--- a/src/python/grpcio/tests/qps/client_runner.py
+++ b/src/python/grpcio/tests/qps/client_runner.py
@@ -34,7 +34,7 @@ ClientRunner invokes either periodically or in response to some event.
"""
import abc
-import thread
+import threading
import time
@@ -61,15 +61,18 @@ class OpenLoopClientRunner(ClientRunner):
super(OpenLoopClientRunner, self).__init__(client)
self._is_running = False
self._interval_generator = interval_generator
+ self._dispatch_thread = threading.Thread(
+ target=self._dispatch_requests, args=())
def start(self):
self._is_running = True
self._client.start()
- thread.start_new_thread(self._dispatch_requests, ())
-
+ self._dispatch_thread.start()
+
def stop(self):
self._is_running = False
self._client.stop()
+ self._dispatch_thread.join()
self._client = None
def _dispatch_requests(self):
@@ -98,7 +101,6 @@ class ClosedLoopClientRunner(ClientRunner):
self._client.stop()
self._client = None
- def _send_request(self, response_time):
+ def _send_request(self, client, response_time):
if self._is_running:
- self._client.send_request()
-
+ client.send_request()
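
The client_runner change above replaces the Python-2-only thread.start_new_thread with a threading.Thread owned by the runner, so stop() can join the dispatch thread and guarantee no request is issued after shutdown. A standalone sketch of that open-loop shape, with a stand-in client and interval generator (the names here are illustrative, not the benchmark's classes):

    import itertools
    import threading
    import time

    class FakeClient(object):
        def send_request(self):
            pass  # a real client would issue an RPC here

    class OpenLoopRunner(object):
        def __init__(self, client, interval_generator):
            self._client = client
            self._intervals = interval_generator
            self._is_running = False
            self._thread = threading.Thread(target=self._dispatch)

        def _dispatch(self):
            # Runs on the owned thread until stop() clears the flag.
            while self._is_running:
                self._client.send_request()
                time.sleep(next(self._intervals))

        def start(self):
            self._is_running = True
            self._thread.start()

        def stop(self):
            self._is_running = False
            self._thread.join()  # joining guarantees no request after stop()

    runner = OpenLoopRunner(FakeClient(), itertools.repeat(0.01))
    runner.start()
    time.sleep(0.1)
    runner.stop()
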
diff --git a/src/python/grpcio/tests/qps/qps_worker.py b/src/python/grpcio/tests/qps/qps_worker.py
index 3dda718638..16926379a5 100644
--- a/src/python/grpcio/tests/qps/qps_worker.py
+++ b/src/python/grpcio/tests/qps/qps_worker.py
@@ -43,9 +43,7 @@ def run_worker_server(port):
server.add_insecure_port('[::]:{}'.format(port))
server.start()
servicer.wait_for_quit()
- # Drain outstanding requests for clean exit
- time.sleep(2)
- server.stop(0)
+ server.stop(2)
if __name__ == '__main__':
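
server.stop(2) folds the old sleep-then-stop sequence into the server itself: stop() is non-blocking, gives in-flight RPCs the requested grace period, and hands back an event to wait on. A sketch of that shutdown pattern; the grpc.server constructor call is an assumption based on the current public API, not taken from this tree:

    from concurrent import futures

    import grpc

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))  # assumed constructor
    server.add_insecure_port('[::]:0')
    server.start()
    shutdown_event = server.stop(2)  # non-blocking; in-flight RPCs get two seconds
    shutdown_event.wait()            # block until shutdown has actually completed
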
diff --git a/src/python/grpcio/tests/qps/worker_server.py b/src/python/grpcio/tests/qps/worker_server.py
index 1f9af5482c..d41f8377c2 100644
--- a/src/python/grpcio/tests/qps/worker_server.py
+++ b/src/python/grpcio/tests/qps/worker_server.py
@@ -153,9 +153,8 @@ class WorkerServer(services_pb2.BetaWorkerServiceServicer):
if config.rpc_type == control_pb2.UNARY:
client = benchmark_client.UnaryAsyncBenchmarkClient(
server, config, qps_data)
- elif config.rpc_type == control_pb2.STREAMING:
- client = benchmark_client.StreamingAsyncBenchmarkClient(
- server, config, qps_data)
+ else:
+ raise Exception('Async streaming client not supported')
else:
raise Exception('Unsupported client type {}'.format(config.client_type))
diff --git a/src/python/grpcio/tests/stress/client.py b/src/python/grpcio/tests/stress/client.py
index e2e016760c..0de2532cd8 100644
--- a/src/python/grpcio/tests/stress/client.py
+++ b/src/python/grpcio/tests/stress/client.py
@@ -30,10 +30,10 @@
"""Entry point for running stress tests."""
import argparse
-import Queue
import threading
from grpc.beta import implementations
+from six.moves import queue
from src.proto.grpc.testing import metrics_pb2
from src.proto.grpc.testing import test_pb2
@@ -94,7 +94,7 @@ def run_test(args):
test_cases = _parse_weighted_test_cases(args.test_cases)
test_servers = args.server_addresses.split(',')
# Propagate any client exceptions with a queue
- exception_queue = Queue.Queue()
+ exception_queue = queue.Queue()
stop_event = threading.Event()
hist = histogram.Histogram(1, 1)
runners = []
@@ -121,7 +121,7 @@ def run_test(args):
if timeout_secs < 0:
timeout_secs = None
raise exception_queue.get(block=True, timeout=timeout_secs)
- except Queue.Empty:
+ except queue.Empty:
# No exceptions thrown, success
pass
finally:
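
The stress client now gets its queue module through six.moves, which resolves to Queue on Python 2 and queue on Python 3, so exception classes such as queue.Empty keep working under both. The portable shape of the exception-propagation pattern used above:

    from six.moves import queue

    exception_queue = queue.Queue()
    try:
        # Re-raise the first exception a worker pushed, or time out if none did.
        raise exception_queue.get(block=True, timeout=0.1)
    except queue.Empty:
        pass  # nothing was queued within the timeout; treat it as success
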
diff --git a/src/python/grpcio/tests/tests.json b/src/python/grpcio/tests/tests.json
index 691062f25a..20d11b20b5 100644
--- a/src/python/grpcio/tests/tests.json
+++ b/src/python/grpcio/tests/tests.json
@@ -1,12 +1,20 @@
[
+ "_api_test.AllTest",
+ "_api_test.ChannelConnectivityTest",
+ "_auth_test.AccessTokenCallCredentialsTest",
+ "_auth_test.GoogleCallCredentialsTest",
"_base_interface_test.AsyncEasyTest",
"_base_interface_test.AsyncPeasyTest",
"_base_interface_test.SyncEasyTest",
"_base_interface_test.SyncPeasyTest",
"_beta_features_test.BetaFeaturesTest",
"_beta_features_test.ContextManagementAndLifecycleTest",
+ "_cancel_many_calls_test.CancelManyCallsTest",
+ "_channel_connectivity_test.ChannelConnectivityTest",
+ "_channel_ready_future_test.ChannelReadyFutureTest",
"_channel_test.ChannelTest",
"_connectivity_channel_test.ChannelConnectivityTest",
+ "_connectivity_channel_test.ConnectivityStatesTest",
"_core_over_links_base_interface_test.AsyncEasyTest",
"_core_over_links_base_interface_test.AsyncPeasyTest",
"_core_over_links_base_interface_test.SyncEasyTest",
@@ -23,6 +31,7 @@
"_crust_over_core_over_links_face_interface_test.GenericInvokerFutureInvocationAsynchronousEventServiceTest",
"_crust_over_core_over_links_face_interface_test.MultiCallableInvokerBlockingInvocationInlineServiceTest",
"_crust_over_core_over_links_face_interface_test.MultiCallableInvokerFutureInvocationAsynchronousEventServiceTest",
+ "_empty_message_test.EmptyMessageTest",
"_face_interface_test.DynamicInvokerBlockingInvocationInlineServiceTest",
"_face_interface_test.DynamicInvokerFutureInvocationAsynchronousEventServiceTest",
"_face_interface_test.GenericInvokerBlockingInvocationInlineServiceTest",
@@ -30,6 +39,7 @@
"_face_interface_test.MultiCallableInvokerBlockingInvocationInlineServiceTest",
"_face_interface_test.MultiCallableInvokerFutureInvocationAsynchronousEventServiceTest",
"_health_servicer_test.HealthServicerTest",
+ "_implementations_test.CallCredentialsTest",
"_implementations_test.ChannelCredentialsTest",
"_insecure_interop_test.InsecureInteropTest",
"_intermediary_low_test.CancellationTest",
@@ -41,9 +51,14 @@
"_lonely_invocation_link_test.LonelyInvocationLinkTest",
"_low_test.HangingServerShutdown",
"_low_test.InsecureServerInsecureClient",
+ "_metadata_test.MetadataTest",
"_not_found_test.NotFoundTest",
+ "_python_plugin_test.PythonPluginTest",
+ "_read_some_but_not_all_responses_test.ReadSomeButNotAllResponsesTest",
+ "_rpc_test.RPCTest",
"_sanity_test.Sanity",
"_secure_interop_test.SecureInteropTest",
+ "_thread_cleanup_test.CleanupThreadTest",
"_transmission_test.RoundTripTest",
"_transmission_test.TransmissionTest",
"_utilities_test.ChannelConnectivityTest",
diff --git a/src/python/grpcio/tests/unit/_adapter/_intermediary_low_test.py b/src/python/grpcio/tests/unit/_adapter/_intermediary_low_test.py
index 22d4b019c7..09ebdeff33 100644
--- a/src/python/grpcio/tests/unit/_adapter/_intermediary_low_test.py
+++ b/src/python/grpcio/tests/unit/_adapter/_intermediary_low_test.py
@@ -164,15 +164,15 @@ class EchoTest(unittest.TestCase):
self.assertIsNotNone(service_accepted)
self.assertIs(service_accepted.kind, _low.Event.Kind.SERVICE_ACCEPTED)
self.assertIs(service_accepted.tag, service_tag)
- self.assertEqual(method, service_accepted.service_acceptance.method)
- self.assertEqual(self.host, service_accepted.service_acceptance.host)
+ self.assertEqual(method.encode(), service_accepted.service_acceptance.method)
+ self.assertEqual(self.host.encode(), service_accepted.service_acceptance.host)
self.assertIsNotNone(service_accepted.service_acceptance.call)
metadata = dict(service_accepted.metadata)
- self.assertIn(client_metadata_key, metadata)
- self.assertEqual(client_metadata_value, metadata[client_metadata_key])
- self.assertIn(client_binary_metadata_key, metadata)
+ self.assertIn(client_metadata_key.encode(), metadata)
+ self.assertEqual(client_metadata_value.encode(), metadata[client_metadata_key.encode()])
+ self.assertIn(client_binary_metadata_key.encode(), metadata)
self.assertEqual(client_binary_metadata_value,
- metadata[client_binary_metadata_key])
+ metadata[client_binary_metadata_key.encode()])
server_call = service_accepted.service_acceptance.call
server_call.accept(self.server_completion_queue, finish_tag)
server_call.add_metadata(server_leading_metadata_key,
@@ -186,12 +186,12 @@ class EchoTest(unittest.TestCase):
self.assertEqual(_low.Event.Kind.METADATA_ACCEPTED, metadata_accepted.kind)
self.assertEqual(metadata_tag, metadata_accepted.tag)
metadata = dict(metadata_accepted.metadata)
- self.assertIn(server_leading_metadata_key, metadata)
- self.assertEqual(server_leading_metadata_value,
- metadata[server_leading_metadata_key])
- self.assertIn(server_leading_binary_metadata_key, metadata)
+ self.assertIn(server_leading_metadata_key.encode(), metadata)
+ self.assertEqual(server_leading_metadata_value.encode(),
+ metadata[server_leading_metadata_key.encode()])
+ self.assertIn(server_leading_binary_metadata_key.encode(), metadata)
self.assertEqual(server_leading_binary_metadata_value,
- metadata[server_leading_binary_metadata_key])
+ metadata[server_leading_binary_metadata_key.encode()])
for datum in test_data:
client_call.write(datum, write_tag, _low.WriteFlags.WRITE_NO_COMPRESS)
@@ -277,17 +277,17 @@ class EchoTest(unittest.TestCase):
self.assertIsNone(read_accepted.bytes)
self.assertEqual(_low.Event.Kind.FINISH, finish_accepted.kind)
self.assertEqual(finish_tag, finish_accepted.tag)
- self.assertEqual(_low.Status(_low.Code.OK, details), finish_accepted.status)
+ self.assertEqual(_low.Status(_low.Code.OK, details.encode()), finish_accepted.status)
metadata = dict(finish_accepted.metadata)
- self.assertIn(server_trailing_metadata_key, metadata)
- self.assertEqual(server_trailing_metadata_value,
- metadata[server_trailing_metadata_key])
- self.assertIn(server_trailing_binary_metadata_key, metadata)
+ self.assertIn(server_trailing_metadata_key.encode(), metadata)
+ self.assertEqual(server_trailing_metadata_value.encode(),
+ metadata[server_trailing_metadata_key.encode()])
+ self.assertIn(server_trailing_binary_metadata_key.encode(), metadata)
self.assertEqual(server_trailing_binary_metadata_value,
- metadata[server_trailing_binary_metadata_key])
+ metadata[server_trailing_binary_metadata_key.encode()])
self.assertSetEqual(set(key for key, _ in finish_accepted.metadata),
- set((server_trailing_metadata_key,
- server_trailing_binary_metadata_key,)))
+ set((server_trailing_metadata_key.encode(),
+ server_trailing_binary_metadata_key.encode(),)))
self.assertSequenceEqual(test_data, server_data)
self.assertSequenceEqual(test_data, client_data)
@@ -302,7 +302,8 @@ class EchoTest(unittest.TestCase):
self._perform_echo_test([_BYTE_SEQUENCE])
def testManyOneByteEchoes(self):
- self._perform_echo_test(_BYTE_SEQUENCE)
+ self._perform_echo_test(
+ [_BYTE_SEQUENCE[i:i+1] for i in range(len(_BYTE_SEQUENCE))])
def testManyManyByteEchoes(self):
self._perform_echo_test(_BYTE_SEQUENCE_SEQUENCE)
@@ -409,7 +410,7 @@ class CancellationTest(unittest.TestCase):
finish_event = self.client_events.get()
self.assertEqual(_low.Event.Kind.FINISH, finish_event.kind)
- self.assertEqual(_low.Status(_low.Code.CANCELLED, 'Cancelled'),
+ self.assertEqual(_low.Status(_low.Code.CANCELLED, b'Cancelled'),
finish_event.status)
self.assertSequenceEqual(test_data, server_data)
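
The assertions above now encode their expected text because the low-level layer hands metadata keys, method names, and status details back as bytes; under Python 3 a str-to-bytes comparison would never match. The shape of the fix in isolation (the values here are made up):

    received_metadata = {b'user-agent': b'Python-gRPC-x.y.z'}  # as returned by the core
    client_metadata_key = 'user-agent'                         # as written in the test

    # Encoding the text key makes the lookup valid on Python 2 and Python 3 alike.
    assert client_metadata_key.encode() in received_metadata
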
diff --git a/src/python/grpcio/tests/unit/_adapter/_low_test.py b/src/python/grpcio/tests/unit/_adapter/_low_test.py
index ec46617996..e09a1f2564 100644
--- a/src/python/grpcio/tests/unit/_adapter/_low_test.py
+++ b/src/python/grpcio/tests/unit/_adapter/_low_test.py
@@ -148,11 +148,11 @@ class InsecureServerInsecureClient(unittest.TestCase):
# Check that Python's user agent string is a part of the full user agent
# string
received_initial_metadata_dict = dict(received_initial_metadata)
- self.assertIn('user-agent', received_initial_metadata_dict)
- self.assertIn('Python-gRPC-{}'.format(_grpcio_metadata.__version__),
- received_initial_metadata_dict['user-agent'])
- self.assertEqual(method, request_event.call_details.method)
- self.assertEqual(host, request_event.call_details.host)
+ self.assertIn(b'user-agent', received_initial_metadata_dict)
+ self.assertIn('Python-gRPC-{}'.format(_grpcio_metadata.__version__).encode(),
+ received_initial_metadata_dict[b'user-agent'])
+ self.assertEqual(method.encode(), request_event.call_details.method)
+ self.assertEqual(host.encode(), request_event.call_details.host)
self.assertLess(abs(deadline - request_event.call_details.deadline),
deadline_tolerance)
@@ -198,12 +198,12 @@ class InsecureServerInsecureClient(unittest.TestCase):
test_common.metadata_transmitted(server_initial_metadata,
client_result.initial_metadata))
elif client_result.type == _types.OpType.RECV_MESSAGE:
- self.assertEqual(response, client_result.message)
+ self.assertEqual(response.encode(), client_result.message)
elif client_result.type == _types.OpType.RECV_STATUS_ON_CLIENT:
self.assertTrue(
test_common.metadata_transmitted(server_trailing_metadata,
client_result.trailing_metadata))
- self.assertEqual(server_status_details, client_result.status.details)
+ self.assertEqual(server_status_details.encode(), client_result.status.details)
self.assertEqual(server_status_code, client_result.status.code)
self.assertEqual(set([
_types.OpType.SEND_INITIAL_METADATA,
@@ -220,7 +220,7 @@ class InsecureServerInsecureClient(unittest.TestCase):
self.assertNotIn(client_result.type, found_server_op_types)
found_server_op_types.add(server_result.type)
if server_result.type == _types.OpType.RECV_MESSAGE:
- self.assertEqual(request, server_result.message)
+ self.assertEqual(request.encode(), server_result.message)
elif server_result.type == _types.OpType.RECV_CLOSE_ON_SERVER:
self.assertFalse(server_result.cancelled)
self.assertEqual(set([
diff --git a/src/python/grpcio/tests/unit/_api_test.py b/src/python/grpcio/tests/unit/_api_test.py
new file mode 100644
index 0000000000..1501ec2a04
--- /dev/null
+++ b/src/python/grpcio/tests/unit/_api_test.py
@@ -0,0 +1,104 @@
+# Copyright 2016, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Test of gRPC Python's application-layer API."""
+
+import unittest
+
+import six
+
+import grpc
+
+from tests.unit import _from_grpc_import_star
+
+
+class AllTest(unittest.TestCase):
+
+ def testAll(self):
+ expected_grpc_code_elements = (
+ 'FutureTimeoutError',
+ 'FutureCancelledError',
+ 'Future',
+ 'ChannelConnectivity',
+ 'StatusCode',
+ 'RpcError',
+ 'RpcContext',
+ 'Call',
+ 'ChannelCredentials',
+ 'CallCredentials',
+ 'AuthMetadataContext',
+ 'AuthMetadataPluginCallback',
+ 'AuthMetadataPlugin',
+ 'ServerCredentials',
+ 'UnaryUnaryMultiCallable',
+ 'UnaryStreamMultiCallable',
+ 'StreamUnaryMultiCallable',
+ 'StreamStreamMultiCallable',
+ 'Channel',
+ 'ServicerContext',
+ 'RpcMethodHandler',
+ 'HandlerCallDetails',
+ 'GenericRpcHandler',
+ 'Server',
+ 'unary_unary_rpc_method_handler',
+ 'unary_stream_rpc_method_handler',
+ 'stream_unary_rpc_method_handler',
+ 'stream_stream_rpc_method_handler',
+ 'method_handlers_generic_handler',
+ 'ssl_channel_credentials',
+ 'metadata_call_credentials',
+ 'access_token_call_credentials',
+ 'composite_call_credentials',
+ 'composite_channel_credentials',
+ 'ssl_server_credentials',
+ 'channel_ready_future',
+ 'insecure_channel',
+ 'secure_channel',
+ 'server',
+ )
+
+ six.assertCountEqual(
+ self, expected_grpc_code_elements,
+ _from_grpc_import_star.GRPC_ELEMENTS)
+
+
+class ChannelConnectivityTest(unittest.TestCase):
+
+ def testChannelConnectivity(self):
+ self.assertSequenceEqual(
+ (grpc.ChannelConnectivity.IDLE,
+ grpc.ChannelConnectivity.CONNECTING,
+ grpc.ChannelConnectivity.READY,
+ grpc.ChannelConnectivity.TRANSIENT_FAILURE,
+ grpc.ChannelConnectivity.SHUTDOWN,),
+ tuple(grpc.ChannelConnectivity))
+
+
+if __name__ == '__main__':
+ unittest.main(verbosity=2)
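
AllTest compares a fixed list of names against whatever "from grpc import *" actually binds; the helper module it imports can capture that set by diffing module globals before and after the star import, roughly as follows (illustrative only, not the helper's actual contents):

    _BEFORE_IMPORT = tuple(globals())

    from grpc import *  # noqa: F401,F403

    GRPC_ELEMENTS = tuple(
        element for element in globals()
        if element not in _BEFORE_IMPORT and element != '_BEFORE_IMPORT')
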
diff --git a/src/python/grpcio/tests/unit/_auth_test.py b/src/python/grpcio/tests/unit/_auth_test.py
new file mode 100644
index 0000000000..c31f7b06f7
--- /dev/null
+++ b/src/python/grpcio/tests/unit/_auth_test.py
@@ -0,0 +1,96 @@
+# Copyright 2016, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Tests of standard AuthMetadataPlugins."""
+
+import collections
+import threading
+import unittest
+
+from grpc import _auth
+
+
+class MockGoogleCreds(object):
+
+ def get_access_token(self):
+ token = collections.namedtuple('MockAccessTokenInfo',
+ ('access_token', 'expires_in'))
+ token.access_token = 'token'
+ return token
+
+
+class MockExceptionGoogleCreds(object):
+
+ def get_access_token(self):
+ raise Exception()
+
+
+class GoogleCallCredentialsTest(unittest.TestCase):
+
+ def test_google_call_credentials_success(self):
+ callback_event = threading.Event()
+
+ def mock_callback(metadata, error):
+ self.assertEqual(metadata, (('authorization', 'Bearer token'),))
+ self.assertIsNone(error)
+ callback_event.set()
+
+ call_creds = _auth.GoogleCallCredentials(MockGoogleCreds())
+ call_creds(None, mock_callback)
+ self.assertTrue(callback_event.wait(1.0))
+
+ def test_google_call_credentials_error(self):
+ callback_event = threading.Event()
+
+ def mock_callback(metadata, error):
+ self.assertIsNotNone(error)
+ callback_event.set()
+
+ call_creds = _auth.GoogleCallCredentials(MockExceptionGoogleCreds())
+ call_creds(None, mock_callback)
+ self.assertTrue(callback_event.wait(1.0))
+
+
+class AccessTokenCallCredentialsTest(unittest.TestCase):
+
+ def test_google_call_credentials_success(self):
+ callback_event = threading.Event()
+
+ def mock_callback(metadata, error):
+ self.assertEqual(metadata, (('authorization', 'Bearer token'),))
+ self.assertIsNone(error)
+ callback_event.set()
+
+ call_creds = _auth.AccessTokenCallCredentials('token')
+ call_creds(None, mock_callback)
+ self.assertTrue(callback_event.wait(1.0))
+
+
+if __name__ == '__main__':
+ unittest.main(verbosity=2)
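
The credentials objects under test are callables taking (context, callback) and reporting either metadata or an error through the callback. Wired into the public API, such a plugin becomes call credentials via grpc.metadata_call_credentials; a minimal sketch (the token value is a placeholder):

    import grpc

    class StaticTokenAuthMetadataPlugin(grpc.AuthMetadataPlugin):
        """Supplies a fixed bearer token for every call (placeholder token)."""

        def __call__(self, context, callback):
            # Report metadata and no error, mirroring the success path above.
            callback((('authorization', 'Bearer token'),), None)

    call_credentials = grpc.metadata_call_credentials(StaticTokenAuthMetadataPlugin())
    channel_credentials = grpc.composite_channel_credentials(
        grpc.ssl_channel_credentials(), call_credentials)
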
diff --git a/src/python/grpcio/tests/unit/_channel_connectivity_test.py b/src/python/grpcio/tests/unit/_channel_connectivity_test.py
new file mode 100644
index 0000000000..a1575efada
--- /dev/null
+++ b/src/python/grpcio/tests/unit/_channel_connectivity_test.py
@@ -0,0 +1,161 @@
+# Copyright 2015, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Tests of grpc._channel.Channel connectivity."""
+
+import threading
+import time
+import unittest
+from concurrent import futures
+
+import grpc
+from grpc import _channel
+from grpc import _server
+from tests.unit.framework.common import test_constants
+
+
+def _ready_in_connectivities(connectivities):
+ return grpc.ChannelConnectivity.READY in connectivities
+
+
+def _last_connectivity_is_not_ready(connectivities):
+ return connectivities[-1] is not grpc.ChannelConnectivity.READY
+
+
+class _Callback(object):
+
+ def __init__(self):
+ self._condition = threading.Condition()
+ self._connectivities = []
+
+ def update(self, connectivity):
+ with self._condition:
+ self._connectivities.append(connectivity)
+ self._condition.notify()
+
+ def connectivities(self):
+ with self._condition:
+ return tuple(self._connectivities)
+
+ def block_until_connectivities_satisfy(self, predicate):
+ with self._condition:
+ while True:
+ connectivities = tuple(self._connectivities)
+ if predicate(connectivities):
+ return connectivities
+ else:
+ self._condition.wait()
+
+
+class ChannelConnectivityTest(unittest.TestCase):
+
+ def test_lonely_channel_connectivity(self):
+ callback = _Callback()
+
+ channel = _channel.Channel('localhost:12345', None, None)
+ channel.subscribe(callback.update, try_to_connect=False)
+ first_connectivities = callback.block_until_connectivities_satisfy(bool)
+ channel.subscribe(callback.update, try_to_connect=True)
+ second_connectivities = callback.block_until_connectivities_satisfy(
+ lambda connectivities: 2 <= len(connectivities))
+ # Wait for a connection that will never happen.
+ time.sleep(test_constants.SHORT_TIMEOUT)
+ third_connectivities = callback.connectivities()
+ channel.unsubscribe(callback.update)
+ fourth_connectivities = callback.connectivities()
+ channel.unsubscribe(callback.update)
+ fifth_connectivities = callback.connectivities()
+
+ self.assertSequenceEqual(
+ (grpc.ChannelConnectivity.IDLE,), first_connectivities)
+ self.assertNotIn(
+ grpc.ChannelConnectivity.READY, second_connectivities)
+ self.assertNotIn(
+ grpc.ChannelConnectivity.READY, third_connectivities)
+ self.assertNotIn(
+ grpc.ChannelConnectivity.READY, fourth_connectivities)
+ self.assertNotIn(
+ grpc.ChannelConnectivity.READY, fifth_connectivities)
+
+ def test_immediately_connectable_channel_connectivity(self):
+ server = _server.Server((), futures.ThreadPoolExecutor(max_workers=0))
+ port = server.add_insecure_port('[::]:0')
+ server.start()
+ first_callback = _Callback()
+ second_callback = _Callback()
+
+ channel = _channel.Channel('localhost:{}'.format(port), None, None)
+ channel.subscribe(first_callback.update, try_to_connect=False)
+ first_connectivities = first_callback.block_until_connectivities_satisfy(
+ bool)
+ # Wait for a connection that will never happen because try_to_connect=True
+ # has not yet been passed.
+ time.sleep(test_constants.SHORT_TIMEOUT)
+ second_connectivities = first_callback.connectivities()
+ channel.subscribe(second_callback.update, try_to_connect=True)
+ third_connectivities = first_callback.block_until_connectivities_satisfy(
+ lambda connectivities: 2 <= len(connectivities))
+ fourth_connectivities = second_callback.block_until_connectivities_satisfy(
+ bool)
+ # Wait for a connection that will happen (or may already have happened).
+ first_callback.block_until_connectivities_satisfy(_ready_in_connectivities)
+ second_callback.block_until_connectivities_satisfy(_ready_in_connectivities)
+ del channel
+
+ self.assertSequenceEqual(
+ (grpc.ChannelConnectivity.IDLE,), first_connectivities)
+ self.assertSequenceEqual(
+ (grpc.ChannelConnectivity.IDLE,), second_connectivities)
+ self.assertNotIn(
+ grpc.ChannelConnectivity.TRANSIENT_FAILURE, third_connectivities)
+ self.assertNotIn(
+ grpc.ChannelConnectivity.FATAL_FAILURE, third_connectivities)
+ self.assertNotIn(
+ grpc.ChannelConnectivity.TRANSIENT_FAILURE,
+ fourth_connectivities)
+ self.assertNotIn(
+ grpc.ChannelConnectivity.FATAL_FAILURE, fourth_connectivities)
+
+ def test_reachable_then_unreachable_channel_connectivity(self):
+ server = _server.Server((), futures.ThreadPoolExecutor(max_workers=0))
+ port = server.add_insecure_port('[::]:0')
+ server.start()
+ callback = _Callback()
+
+ channel = _channel.Channel('localhost:{}'.format(port), None, None)
+ channel.subscribe(callback.update, try_to_connect=True)
+ callback.block_until_connectivities_satisfy(_ready_in_connectivities)
+ # Now take down the server and confirm that channel readiness is repudiated.
+ server.stop(None)
+ callback.block_until_connectivities_satisfy(_last_connectivity_is_not_ready)
+ channel.unsubscribe(callback.update)
+
+
+if __name__ == '__main__':
+ unittest.main(verbosity=2)
diff --git a/src/python/grpcio/tests/unit/_channel_ready_future_test.py b/src/python/grpcio/tests/unit/_channel_ready_future_test.py
new file mode 100644
index 0000000000..b84bc0197a
--- /dev/null
+++ b/src/python/grpcio/tests/unit/_channel_ready_future_test.py
@@ -0,0 +1,103 @@
+# Copyright 2015, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Tests of grpc.channel_ready_future."""
+
+import threading
+import unittest
+from concurrent import futures
+
+import grpc
+from grpc import _channel
+from grpc import _server
+from tests.unit.framework.common import test_constants
+
+
+class _Callback(object):
+
+ def __init__(self):
+ self._condition = threading.Condition()
+ self._value = None
+
+ def accept_value(self, value):
+ with self._condition:
+ self._value = value
+ self._condition.notify_all()
+
+ def block_until_called(self):
+ with self._condition:
+ while self._value is None:
+ self._condition.wait()
+ return self._value
+
+
+class ChannelReadyFutureTest(unittest.TestCase):
+
+ def test_lonely_channel_connectivity(self):
+ channel = grpc.insecure_channel('localhost:12345')
+ callback = _Callback()
+
+ ready_future = grpc.channel_ready_future(channel)
+ ready_future.add_done_callback(callback.accept_value)
+ with self.assertRaises(grpc.FutureTimeoutError):
+ ready_future.result(test_constants.SHORT_TIMEOUT)
+ self.assertFalse(ready_future.cancelled())
+ self.assertFalse(ready_future.done())
+ self.assertTrue(ready_future.running())
+ ready_future.cancel()
+ value_passed_to_callback = callback.block_until_called()
+ self.assertIs(ready_future, value_passed_to_callback)
+ self.assertTrue(ready_future.cancelled())
+ self.assertTrue(ready_future.done())
+ self.assertFalse(ready_future.running())
+
+ def test_immediately_connectable_channel_connectivity(self):
+ server = _server.Server((), futures.ThreadPoolExecutor(max_workers=0))
+ port = server.add_insecure_port('[::]:0')
+ server.start()
+ channel = grpc.insecure_channel('localhost:{}'.format(port))
+ callback = _Callback()
+
+ ready_future = grpc.channel_ready_future(channel)
+ ready_future.add_done_callback(callback.accept_value)
+ self.assertIsNone(ready_future.result(test_constants.SHORT_TIMEOUT))
+ value_passed_to_callback = callback.block_until_called()
+ self.assertIs(ready_future, value_passed_to_callback)
+ self.assertFalse(ready_future.cancelled())
+ self.assertTrue(ready_future.done())
+ self.assertFalse(ready_future.running())
+ # Cancellation after maturity has no effect.
+ ready_future.cancel()
+ self.assertFalse(ready_future.cancelled())
+ self.assertTrue(ready_future.done())
+ self.assertFalse(ready_future.running())
+
+
+if __name__ == '__main__':
+ unittest.main(verbosity=2)
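
channel_ready_future wraps the connectivity subscription in a Future that matures once the channel first reaches READY. The common blocking use is result() with a timeout (placeholder address; no server is expected, so the timeout path fires):

    import grpc

    channel = grpc.insecure_channel('localhost:12345')  # placeholder address
    try:
        grpc.channel_ready_future(channel).result(timeout=1.0)
    except grpc.FutureTimeoutError:
        pass  # no server became reachable within one second
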
diff --git a/src/python/grpcio/tests/unit/_cython/_cancel_many_calls_test.py b/src/python/grpcio/tests/unit/_cython/_cancel_many_calls_test.py
new file mode 100644
index 0000000000..c1de779014
--- /dev/null
+++ b/src/python/grpcio/tests/unit/_cython/_cancel_many_calls_test.py
@@ -0,0 +1,222 @@
+# Copyright 2016, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Test making many calls and immediately cancelling most of them."""
+
+import threading
+import unittest
+
+from grpc._cython import cygrpc
+from grpc.framework.foundation import logging_pool
+from tests.unit.framework.common import test_constants
+
+_INFINITE_FUTURE = cygrpc.Timespec(float('+inf'))
+_EMPTY_FLAGS = 0
+_EMPTY_METADATA = cygrpc.Metadata(())
+
+_SERVER_SHUTDOWN_TAG = 'server_shutdown'
+_REQUEST_CALL_TAG = 'request_call'
+_RECEIVE_CLOSE_ON_SERVER_TAG = 'receive_close_on_server'
+_RECEIVE_MESSAGE_TAG = 'receive_message'
+_SERVER_COMPLETE_CALL_TAG = 'server_complete_call'
+
+_SUCCESS_CALL_FRACTION = 1.0 / 8.0
+
+
+class _State(object):
+
+ def __init__(self):
+ self.condition = threading.Condition()
+ self.handlers_released = False
+ self.parked_handlers = 0
+ self.handled_rpcs = 0
+
+
+def _is_cancellation_event(event):
+ return (
+ event.tag is _RECEIVE_CLOSE_ON_SERVER_TAG and
+ event.batch_operations[0].received_cancelled)
+
+
+class _Handler(object):
+
+ def __init__(self, state, completion_queue, rpc_event):
+ self._state = state
+ self._lock = threading.Lock()
+ self._completion_queue = completion_queue
+ self._call = rpc_event.operation_call
+
+ def __call__(self):
+ with self._state.condition:
+ self._state.parked_handlers += 1
+ if self._state.parked_handlers == test_constants.THREAD_CONCURRENCY:
+ self._state.condition.notify_all()
+ while not self._state.handlers_released:
+ self._state.condition.wait()
+
+ with self._lock:
+ self._call.start_batch(
+ cygrpc.Operations(
+ (cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),)),
+ _RECEIVE_CLOSE_ON_SERVER_TAG)
+ self._call.start_batch(
+ cygrpc.Operations((cygrpc.operation_receive_message(_EMPTY_FLAGS),)),
+ _RECEIVE_MESSAGE_TAG)
+ first_event = self._completion_queue.poll()
+ if _is_cancellation_event(first_event):
+ self._completion_queue.poll()
+ else:
+ with self._lock:
+ operations = (
+ cygrpc.operation_send_initial_metadata(
+ _EMPTY_METADATA, _EMPTY_FLAGS),
+ cygrpc.operation_send_message(b'\x79\x57', _EMPTY_FLAGS),
+ cygrpc.operation_send_status_from_server(
+ _EMPTY_METADATA, cygrpc.StatusCode.ok, b'test details!',
+ _EMPTY_FLAGS),
+ )
+ self._call.start_batch(
+ cygrpc.Operations(operations), _SERVER_COMPLETE_CALL_TAG)
+ self._completion_queue.poll()
+ self._completion_queue.poll()
+
+
+def _serve(state, server, server_completion_queue, thread_pool):
+ for _ in range(test_constants.RPC_CONCURRENCY):
+ call_completion_queue = cygrpc.CompletionQueue()
+ server.request_call(
+ call_completion_queue, server_completion_queue, _REQUEST_CALL_TAG)
+ rpc_event = server_completion_queue.poll()
+ thread_pool.submit(_Handler(state, call_completion_queue, rpc_event))
+ with state.condition:
+ state.handled_rpcs += 1
+ if test_constants.RPC_CONCURRENCY <= state.handled_rpcs:
+ state.condition.notify_all()
+ server_completion_queue.poll()
+
+
+class _QueueDriver(object):
+
+ def __init__(self, condition, completion_queue, due):
+ self._condition = condition
+ self._completion_queue = completion_queue
+ self._due = due
+ self._events = []
+ self._returned = False
+
+ def start(self):
+ def in_thread():
+ while True:
+ event = self._completion_queue.poll()
+ with self._condition:
+ self._events.append(event)
+ self._due.remove(event.tag)
+ self._condition.notify_all()
+ if not self._due:
+ self._returned = True
+ return
+ thread = threading.Thread(target=in_thread)
+ thread.start()
+
+ def events(self, at_least):
+ with self._condition:
+ while len(self._events) < at_least:
+ self._condition.wait()
+ return tuple(self._events)
+
+
+class CancelManyCallsTest(unittest.TestCase):
+
+ def testCancelManyCalls(self):
+ server_thread_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
+
+ server_completion_queue = cygrpc.CompletionQueue()
+ server = cygrpc.Server()
+ server.register_completion_queue(server_completion_queue)
+ port = server.add_http2_port('[::]:0')
+ server.start()
+ channel = cygrpc.Channel('localhost:{}'.format(port))
+
+ state = _State()
+
+ server_thread_args = (
+ state, server, server_completion_queue, server_thread_pool,)
+ server_thread = threading.Thread(target=_serve, args=server_thread_args)
+ server_thread.start()
+
+ client_condition = threading.Condition()
+ client_due = set()
+ client_completion_queue = cygrpc.CompletionQueue()
+ client_driver = _QueueDriver(
+ client_condition, client_completion_queue, client_due)
+ client_driver.start()
+
+ with client_condition:
+ client_calls = []
+ for index in range(test_constants.RPC_CONCURRENCY):
+ client_call = channel.create_call(
+ None, _EMPTY_FLAGS, client_completion_queue, b'/twinkies', None,
+ _INFINITE_FUTURE)
+ operations = (
+ cygrpc.operation_send_initial_metadata(
+ _EMPTY_METADATA, _EMPTY_FLAGS),
+ cygrpc.operation_send_message(b'\x45\x56', _EMPTY_FLAGS),
+ cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
+ cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
+ cygrpc.operation_receive_message(_EMPTY_FLAGS),
+ cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
+ )
+ tag = 'client_complete_call_{0:04d}_tag'.format(index)
+ client_call.start_batch(cygrpc.Operations(operations), tag)
+ client_due.add(tag)
+ client_calls.append(client_call)
+
+ with state.condition:
+ while True:
+ if state.parked_handlers < test_constants.THREAD_CONCURRENCY:
+ state.condition.wait()
+ elif state.handled_rpcs < test_constants.RPC_CONCURRENCY:
+ state.condition.wait()
+ else:
+ state.handlers_released = True
+ state.condition.notify_all()
+ break
+
+ client_driver.events(
+ test_constants.RPC_CONCURRENCY * _SUCCESS_CALL_FRACTION)
+ with client_condition:
+ for client_call in client_calls:
+ client_call.cancel()
+
+ with state.condition:
+ server.shutdown(server_completion_queue, _SERVER_SHUTDOWN_TAG)
+
+
+if __name__ == '__main__':
+ unittest.main(verbosity=2)
diff --git a/src/python/grpcio/tests/unit/_cython/_read_some_but_not_all_responses_test.py b/src/python/grpcio/tests/unit/_cython/_read_some_but_not_all_responses_test.py
new file mode 100644
index 0000000000..6ae7a90fbe
--- /dev/null
+++ b/src/python/grpcio/tests/unit/_cython/_read_some_but_not_all_responses_test.py
@@ -0,0 +1,251 @@
+# Copyright 2016, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Test a corner-case at the level of the Cython API."""
+
+import threading
+import unittest
+
+from grpc._cython import cygrpc
+
+_INFINITE_FUTURE = cygrpc.Timespec(float('+inf'))
+_EMPTY_FLAGS = 0
+_EMPTY_METADATA = cygrpc.Metadata(())
+
+
+class _ServerDriver(object):
+
+ def __init__(self, completion_queue, shutdown_tag):
+ self._condition = threading.Condition()
+ self._completion_queue = completion_queue
+ self._shutdown_tag = shutdown_tag
+ self._events = []
+ self._saw_shutdown_tag = False
+
+ def start(self):
+ def in_thread():
+ while True:
+ event = self._completion_queue.poll()
+ with self._condition:
+ self._events.append(event)
+ self._condition.notify()
+ if event.tag is self._shutdown_tag:
+ self._saw_shutdown_tag = True
+ break
+ thread = threading.Thread(target=in_thread)
+ thread.start()
+
+ def done(self):
+ with self._condition:
+ return self._saw_shutdown_tag
+
+ def first_event(self):
+ with self._condition:
+ while not self._events:
+ self._condition.wait()
+ return self._events[0]
+
+ def events(self):
+ with self._condition:
+ while not self._saw_shutdown_tag:
+ self._condition.wait()
+ return tuple(self._events)
+
+
+class _QueueDriver(object):
+
+ def __init__(self, condition, completion_queue, due):
+ self._condition = condition
+ self._completion_queue = completion_queue
+ self._due = due
+ self._events = []
+ self._returned = False
+
+ def start(self):
+ def in_thread():
+ while True:
+ event = self._completion_queue.poll()
+ with self._condition:
+ self._events.append(event)
+ self._due.remove(event.tag)
+ self._condition.notify_all()
+ if not self._due:
+ self._returned = True
+ return
+ thread = threading.Thread(target=in_thread)
+ thread.start()
+
+ def done(self):
+ with self._condition:
+ return self._returned
+
+ def event_with_tag(self, tag):
+ with self._condition:
+ while True:
+ for event in self._events:
+ if event.tag is tag:
+ return event
+ self._condition.wait()
+
+ def events(self):
+ with self._condition:
+ while not self._returned:
+ self._condition.wait()
+ return tuple(self._events)
+
+
+class ReadSomeButNotAllResponsesTest(unittest.TestCase):
+
+ def testReadSomeButNotAllResponses(self):
+ server_completion_queue = cygrpc.CompletionQueue()
+ server = cygrpc.Server()
+ server.register_completion_queue(server_completion_queue)
+ port = server.add_http2_port('[::]:0')
+ server.start()
+ channel = cygrpc.Channel('localhost:{}'.format(port))
+
+ server_shutdown_tag = 'server_shutdown_tag'
+ server_driver = _ServerDriver(server_completion_queue, server_shutdown_tag)
+ server_driver.start()
+
+ client_condition = threading.Condition()
+ client_due = set()
+ client_completion_queue = cygrpc.CompletionQueue()
+ client_driver = _QueueDriver(
+ client_condition, client_completion_queue, client_due)
+ client_driver.start()
+
+ server_call_condition = threading.Condition()
+ server_send_initial_metadata_tag = 'server_send_initial_metadata_tag'
+ server_send_first_message_tag = 'server_send_first_message_tag'
+ server_send_second_message_tag = 'server_send_second_message_tag'
+ server_complete_rpc_tag = 'server_complete_rpc_tag'
+ server_call_due = set((
+ server_send_initial_metadata_tag,
+ server_send_first_message_tag,
+ server_send_second_message_tag,
+ server_complete_rpc_tag,
+ ))
+ server_call_completion_queue = cygrpc.CompletionQueue()
+ server_call_driver = _QueueDriver(
+ server_call_condition, server_call_completion_queue, server_call_due)
+ server_call_driver.start()
+
+ server_rpc_tag = 'server_rpc_tag'
+ request_call_result = server.request_call(
+ server_call_completion_queue, server_completion_queue, server_rpc_tag)
+
+ client_call = channel.create_call(
+ None, _EMPTY_FLAGS, client_completion_queue, b'/twinkies', None,
+ _INFINITE_FUTURE)
+ client_receive_initial_metadata_tag = 'client_receive_initial_metadata_tag'
+ client_complete_rpc_tag = 'client_complete_rpc_tag'
+ with client_condition:
+ client_receive_initial_metadata_start_batch_result = (
+ client_call.start_batch(cygrpc.Operations([
+ cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
+ ]), client_receive_initial_metadata_tag))
+ client_due.add(client_receive_initial_metadata_tag)
+ client_complete_rpc_start_batch_result = (
+ client_call.start_batch(cygrpc.Operations([
+ cygrpc.operation_send_initial_metadata(
+ _EMPTY_METADATA, _EMPTY_FLAGS),
+ cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
+ cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
+ ]), client_complete_rpc_tag))
+ client_due.add(client_complete_rpc_tag)
+
+ server_rpc_event = server_driver.first_event()
+
+ with server_call_condition:
+ server_send_initial_metadata_start_batch_result = (
+ server_rpc_event.operation_call.start_batch(cygrpc.Operations([
+ cygrpc.operation_send_initial_metadata(
+ _EMPTY_METADATA, _EMPTY_FLAGS),
+ ]), server_send_initial_metadata_tag))
+ server_send_first_message_start_batch_result = (
+ server_rpc_event.operation_call.start_batch(cygrpc.Operations([
+ cygrpc.operation_send_message(b'\x07', _EMPTY_FLAGS),
+ ]), server_send_first_message_tag))
+ server_send_initial_metadata_event = server_call_driver.event_with_tag(
+ server_send_initial_metadata_tag)
+ server_send_first_message_event = server_call_driver.event_with_tag(
+ server_send_first_message_tag)
+ with server_call_condition:
+ server_send_second_message_start_batch_result = (
+ server_rpc_event.operation_call.start_batch(cygrpc.Operations([
+ cygrpc.operation_send_message(b'\x07', _EMPTY_FLAGS),
+ ]), server_send_second_message_tag))
+ server_complete_rpc_start_batch_result = (
+ server_rpc_event.operation_call.start_batch(cygrpc.Operations([
+ cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),
+ cygrpc.operation_send_status_from_server(
+ cygrpc.Metadata(()), cygrpc.StatusCode.ok, b'test details',
+ _EMPTY_FLAGS),
+ ]), server_complete_rpc_tag))
+ server_send_second_message_event = server_call_driver.event_with_tag(
+ server_send_second_message_tag)
+ server_complete_rpc_event = server_call_driver.event_with_tag(
+ server_complete_rpc_tag)
+ server_call_driver.events()
+
+ with client_condition:
+ client_receive_first_message_tag = 'client_receive_first_message_tag'
+ client_receive_first_message_start_batch_result = (
+ client_call.start_batch(cygrpc.Operations([
+ cygrpc.operation_receive_message(_EMPTY_FLAGS),
+ ]), client_receive_first_message_tag))
+ client_due.add(client_receive_first_message_tag)
+ client_receive_first_message_event = client_driver.event_with_tag(
+ client_receive_first_message_tag)
+
+ client_call_cancel_result = client_call.cancel()
+ client_driver.events()
+
+ server.shutdown(server_completion_queue, server_shutdown_tag)
+ server.cancel_all_calls()
+ server_driver.events()
+
+ self.assertEqual(cygrpc.CallError.ok, request_call_result)
+ self.assertEqual(
+ cygrpc.CallError.ok, server_send_initial_metadata_start_batch_result)
+ self.assertEqual(
+ cygrpc.CallError.ok, client_receive_initial_metadata_start_batch_result)
+ self.assertEqual(
+ cygrpc.CallError.ok, client_complete_rpc_start_batch_result)
+ self.assertEqual(cygrpc.CallError.ok, client_call_cancel_result)
+ self.assertIs(server_rpc_tag, server_rpc_event.tag)
+ self.assertEqual(
+ cygrpc.CompletionType.operation_complete, server_rpc_event.type)
+ self.assertIsInstance(server_rpc_event.operation_call, cygrpc.Call)
+ self.assertEqual(0, len(server_rpc_event.batch_operations))
+
+
+if __name__ == '__main__':
+ unittest.main(verbosity=2)
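
The test above leans on driver threads that drain completion queues until every expected tag has been observed. A minimal sketch of that tag-draining pattern (illustrative only, not part of the patch; `completion_queue` is any cygrpc.CompletionQueue and `deadline` a cygrpc.Timespec):

def drain_tags(completion_queue, due_tags, deadline):
  # Poll until every tag in due_tags has produced an event, collecting the
  # events along the way; event.tag and poll(deadline) are the same cygrpc
  # APIs exercised by the drivers above.
  due = set(due_tags)
  events = []
  while due:
    event = completion_queue.poll(deadline)
    events.append(event)
    due.discard(event.tag)
  return events
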
diff --git a/src/python/grpcio/tests/unit/_cython/cygrpc_test.py b/src/python/grpcio/tests/unit/_cython/cygrpc_test.py
index 0a511101f0..a006a20ce3 100644
--- a/src/python/grpcio/tests/unit/_cython/cygrpc_test.py
+++ b/src/python/grpcio/tests/unit/_cython/cygrpc_test.py
@@ -37,7 +37,7 @@ from tests.unit import test_common
from tests.unit import resources
-_SSL_HOST_OVERRIDE = 'foo.test.google.fr'
+_SSL_HOST_OVERRIDE = b'foo.test.google.fr'
_CALL_CREDENTIALS_METADATA_KEY = 'call-creds-key'
_CALL_CREDENTIALS_METADATA_VALUE = 'call-creds-value'
_EMPTY_FLAGS = 0
@@ -143,22 +143,60 @@ class TypeSmokeTest(unittest.TestCase):
del completion_queue
-class InsecureServerInsecureClient(unittest.TestCase):
+class ServerClientMixin(object):
- def setUp(self):
+ def setUpMixin(self, server_credentials, client_credentials, host_override):
self.server_completion_queue = cygrpc.CompletionQueue()
self.server = cygrpc.Server()
self.server.register_completion_queue(self.server_completion_queue)
- self.port = self.server.add_http2_port('[::]:0')
+ if server_credentials:
+ self.port = self.server.add_http2_port('[::]:0', server_credentials)
+ else:
+ self.port = self.server.add_http2_port('[::]:0')
self.server.start()
self.client_completion_queue = cygrpc.CompletionQueue()
- self.client_channel = cygrpc.Channel('localhost:{}'.format(self.port))
-
- def tearDown(self):
+ if client_credentials:
+ client_channel_arguments = cygrpc.ChannelArgs([
+ cygrpc.ChannelArg(cygrpc.ChannelArgKey.ssl_target_name_override,
+ host_override)])
+ self.client_channel = cygrpc.Channel(
+ 'localhost:{}'.format(self.port), client_channel_arguments,
+ client_credentials)
+ else:
+ self.client_channel = cygrpc.Channel('localhost:{}'.format(self.port))
+ if host_override:
+ self.host_argument = None # default host
+ self.expected_host = host_override
+ else:
+ # arbitrary host name necessitating no further identification
+ self.host_argument = b'hostess'
+ self.expected_host = self.host_argument
+
+ def tearDownMixin(self):
del self.server
del self.client_completion_queue
del self.server_completion_queue
+ def _perform_operations(self, operations, call, queue, deadline, description):
+ """Perform the list of operations with given call, queue, and deadline.
+
+    Performs the operations asynchronously and returns a future for the
+    resulting event. Invocation errors are reported as an exception whose
+    message includes `description`.
+ """
+ def performer():
+ tag = object()
+ try:
+ call_result = call.start_batch(cygrpc.Operations(operations), tag)
+ self.assertEqual(cygrpc.CallError.ok, call_result)
+ event = queue.poll(deadline)
+ self.assertEqual(cygrpc.CompletionType.operation_complete, event.type)
+ self.assertTrue(event.success)
+ self.assertIs(tag, event.tag)
+ except Exception as error:
+        raise Exception("Error in '{}': {}".format(description, error))
+ return event
+ return test_utilities.SimpleFuture(performer)
+
def testEcho(self):
DEADLINE = time.time()+5
DEADLINE_TOLERANCE = 0.25
@@ -175,7 +213,6 @@ class InsecureServerInsecureClient(unittest.TestCase):
REQUEST = b'in death a member of project mayhem has a name'
RESPONSE = b'his name is robert paulson'
METHOD = b'twinkies'
- HOST = b'hostess'
cygrpc_deadline = cygrpc.Timespec(DEADLINE)
@@ -188,7 +225,8 @@ class InsecureServerInsecureClient(unittest.TestCase):
client_call_tag = object()
client_call = self.client_channel.create_call(
- None, 0, self.client_completion_queue, METHOD, HOST, cygrpc_deadline)
+ None, 0, self.client_completion_queue, METHOD, self.host_argument,
+ cygrpc_deadline)
client_initial_metadata = cygrpc.Metadata([
cygrpc.Metadatum(CLIENT_METADATA_ASCII_KEY,
CLIENT_METADATA_ASCII_VALUE),
@@ -216,7 +254,8 @@ class InsecureServerInsecureClient(unittest.TestCase):
test_common.metadata_transmitted(client_initial_metadata,
request_event.request_metadata))
self.assertEqual(METHOD, request_event.request_call_details.method)
- self.assertEqual(HOST, request_event.request_call_details.host)
+ self.assertEqual(self.expected_host,
+ request_event.request_call_details.host)
self.assertLess(
abs(DEADLINE - float(request_event.request_call_details.deadline)),
DEADLINE_TOLERANCE)
@@ -292,172 +331,101 @@ class InsecureServerInsecureClient(unittest.TestCase):
del client_call
del server_call
-
-class SecureServerSecureClient(unittest.TestCase):
-
- def setUp(self):
- server_credentials = cygrpc.server_credentials_ssl(
- None, [cygrpc.SslPemKeyCertPair(resources.private_key(),
- resources.certificate_chain())], False)
- channel_credentials = cygrpc.channel_credentials_ssl(
- resources.test_root_certificates(), None)
- self.server_completion_queue = cygrpc.CompletionQueue()
- self.server = cygrpc.Server()
- self.server.register_completion_queue(self.server_completion_queue)
- self.port = self.server.add_http2_port('[::]:0', server_credentials)
- self.server.start()
- self.client_completion_queue = cygrpc.CompletionQueue()
- client_channel_arguments = cygrpc.ChannelArgs([
- cygrpc.ChannelArg(cygrpc.ChannelArgKey.ssl_target_name_override,
- _SSL_HOST_OVERRIDE)])
- self.client_channel = cygrpc.Channel(
- 'localhost:{}'.format(self.port), client_channel_arguments,
- channel_credentials)
-
- def tearDown(self):
- del self.server
- del self.client_completion_queue
- del self.server_completion_queue
-
- def testEcho(self):
+ def test6522(self):
DEADLINE = time.time()+5
DEADLINE_TOLERANCE = 0.25
- CLIENT_METADATA_ASCII_KEY = b'key'
- CLIENT_METADATA_ASCII_VALUE = b'val'
- CLIENT_METADATA_BIN_KEY = b'key-bin'
- CLIENT_METADATA_BIN_VALUE = b'\0'*1000
- SERVER_INITIAL_METADATA_KEY = b'init_me_me_me'
- SERVER_INITIAL_METADATA_VALUE = b'whodawha?'
- SERVER_TRAILING_METADATA_KEY = b'california_is_in_a_drought'
- SERVER_TRAILING_METADATA_VALUE = b'zomg it is'
- SERVER_STATUS_CODE = cygrpc.StatusCode.ok
- SERVER_STATUS_DETAILS = b'our work is never over'
- REQUEST = b'in death a member of project mayhem has a name'
- RESPONSE = b'his name is robert paulson'
- METHOD = b'/twinkies'
- HOST = None # Default host
+ METHOD = b'twinkies'
cygrpc_deadline = cygrpc.Timespec(DEADLINE)
+ empty_metadata = cygrpc.Metadata([])
server_request_tag = object()
- request_call_result = self.server.request_call(
+ self.server.request_call(
self.server_completion_queue, self.server_completion_queue,
server_request_tag)
+ client_call = self.client_channel.create_call(
+ None, 0, self.client_completion_queue, METHOD, self.host_argument,
+ cygrpc_deadline)
- self.assertEqual(cygrpc.CallError.ok, request_call_result)
-
- plugin = cygrpc.CredentialsMetadataPlugin(_metadata_plugin_callback, '')
- call_credentials = cygrpc.call_credentials_metadata_plugin(plugin)
+ # Prologue
+ def perform_client_operations(operations, description):
+ return self._perform_operations(
+ operations, client_call,
+ self.client_completion_queue, cygrpc_deadline, description)
- client_call_tag = object()
- client_call = self.client_channel.create_call(
- None, 0, self.client_completion_queue, METHOD, HOST, cygrpc_deadline)
- client_call.set_credentials(call_credentials)
- client_initial_metadata = cygrpc.Metadata([
- cygrpc.Metadatum(CLIENT_METADATA_ASCII_KEY,
- CLIENT_METADATA_ASCII_VALUE),
- cygrpc.Metadatum(CLIENT_METADATA_BIN_KEY, CLIENT_METADATA_BIN_VALUE)])
- client_start_batch_result = client_call.start_batch(cygrpc.Operations([
- cygrpc.operation_send_initial_metadata(client_initial_metadata,
- _EMPTY_FLAGS),
- cygrpc.operation_send_message(REQUEST, _EMPTY_FLAGS),
- cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
- cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
- cygrpc.operation_receive_message(_EMPTY_FLAGS),
- cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS)
- ]), client_call_tag)
- self.assertEqual(cygrpc.CallError.ok, client_start_batch_result)
- client_event_future = test_utilities.CompletionQueuePollFuture(
- self.client_completion_queue, cygrpc_deadline)
+ client_event_future = perform_client_operations([
+ cygrpc.operation_send_initial_metadata(empty_metadata,
+ _EMPTY_FLAGS),
+ cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
+ ], "Client prologue")
request_event = self.server_completion_queue.poll(cygrpc_deadline)
- self.assertEqual(cygrpc.CompletionType.operation_complete,
- request_event.type)
- self.assertIsInstance(request_event.operation_call, cygrpc.Call)
- self.assertIs(server_request_tag, request_event.tag)
- self.assertEqual(0, len(request_event.batch_operations))
- client_metadata_with_credentials = list(client_initial_metadata) + [
- (_CALL_CREDENTIALS_METADATA_KEY, _CALL_CREDENTIALS_METADATA_VALUE)]
- self.assertTrue(
- test_common.metadata_transmitted(client_metadata_with_credentials,
- request_event.request_metadata))
- self.assertEqual(METHOD, request_event.request_call_details.method)
- self.assertEqual(_SSL_HOST_OVERRIDE,
- request_event.request_call_details.host)
- self.assertLess(
- abs(DEADLINE - float(request_event.request_call_details.deadline)),
- DEADLINE_TOLERANCE)
-
- server_call_tag = object()
server_call = request_event.operation_call
- server_initial_metadata = cygrpc.Metadata([
- cygrpc.Metadatum(SERVER_INITIAL_METADATA_KEY,
- SERVER_INITIAL_METADATA_VALUE)])
- server_trailing_metadata = cygrpc.Metadata([
- cygrpc.Metadatum(SERVER_TRAILING_METADATA_KEY,
- SERVER_TRAILING_METADATA_VALUE)])
- server_start_batch_result = server_call.start_batch([
- cygrpc.operation_send_initial_metadata(server_initial_metadata,
- _EMPTY_FLAGS),
- cygrpc.operation_receive_message(_EMPTY_FLAGS),
- cygrpc.operation_send_message(RESPONSE, _EMPTY_FLAGS),
- cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),
- cygrpc.operation_send_status_from_server(
- server_trailing_metadata, SERVER_STATUS_CODE,
- SERVER_STATUS_DETAILS, _EMPTY_FLAGS)
- ], server_call_tag)
- self.assertEqual(cygrpc.CallError.ok, server_start_batch_result)
- client_event = client_event_future.result()
- server_event = self.server_completion_queue.poll(cygrpc_deadline)
+ def perform_server_operations(operations, description):
+ return self._perform_operations(
+ operations, server_call,
+ self.server_completion_queue, cygrpc_deadline, description)
- self.assertEqual(6, len(client_event.batch_operations))
- found_client_op_types = set()
- for client_result in client_event.batch_operations:
- # we expect each op type to be unique
- self.assertNotIn(client_result.type, found_client_op_types)
- found_client_op_types.add(client_result.type)
- if client_result.type == cygrpc.OperationType.receive_initial_metadata:
- self.assertTrue(
- test_common.metadata_transmitted(server_initial_metadata,
- client_result.received_metadata))
- elif client_result.type == cygrpc.OperationType.receive_message:
- self.assertEqual(RESPONSE, client_result.received_message.bytes())
- elif client_result.type == cygrpc.OperationType.receive_status_on_client:
- self.assertTrue(
- test_common.metadata_transmitted(server_trailing_metadata,
- client_result.received_metadata))
- self.assertEqual(SERVER_STATUS_DETAILS,
- client_result.received_status_details)
- self.assertEqual(SERVER_STATUS_CODE, client_result.received_status_code)
- self.assertEqual(set([
- cygrpc.OperationType.send_initial_metadata,
- cygrpc.OperationType.send_message,
- cygrpc.OperationType.send_close_from_client,
- cygrpc.OperationType.receive_initial_metadata,
- cygrpc.OperationType.receive_message,
- cygrpc.OperationType.receive_status_on_client
- ]), found_client_op_types)
+ server_event_future = perform_server_operations([
+ cygrpc.operation_send_initial_metadata(empty_metadata,
+ _EMPTY_FLAGS),
+ ], "Server prologue")
- self.assertEqual(5, len(server_event.batch_operations))
- found_server_op_types = set()
- for server_result in server_event.batch_operations:
- self.assertNotIn(client_result.type, found_server_op_types)
- found_server_op_types.add(server_result.type)
- if server_result.type == cygrpc.OperationType.receive_message:
- self.assertEqual(REQUEST, server_result.received_message.bytes())
- elif server_result.type == cygrpc.OperationType.receive_close_on_server:
- self.assertFalse(server_result.received_cancelled)
- self.assertEqual(set([
- cygrpc.OperationType.send_initial_metadata,
- cygrpc.OperationType.receive_message,
- cygrpc.OperationType.send_message,
- cygrpc.OperationType.receive_close_on_server,
- cygrpc.OperationType.send_status_from_server
- ]), found_server_op_types)
+ client_event_future.result() # force completion
+ server_event_future.result()
- del client_call
- del server_call
+ # Messaging
+ for _ in range(10):
+ client_event_future = perform_client_operations([
+ cygrpc.operation_send_message(b'', _EMPTY_FLAGS),
+ cygrpc.operation_receive_message(_EMPTY_FLAGS),
+ ], "Client message")
+ server_event_future = perform_server_operations([
+ cygrpc.operation_send_message(b'', _EMPTY_FLAGS),
+ cygrpc.operation_receive_message(_EMPTY_FLAGS),
+ ], "Server receive")
+
+ client_event_future.result() # force completion
+ server_event_future.result()
+
+ # Epilogue
+ client_event_future = perform_client_operations([
+ cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
+ cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS)
+ ], "Client epilogue")
+
+ server_event_future = perform_server_operations([
+ cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),
+ cygrpc.operation_send_status_from_server(
+ empty_metadata, cygrpc.StatusCode.ok, b'', _EMPTY_FLAGS)
+ ], "Server epilogue")
+
+ client_event_future.result() # force completion
+ server_event_future.result()
+
+
+class InsecureServerInsecureClient(unittest.TestCase, ServerClientMixin):
+
+ def setUp(self):
+ self.setUpMixin(None, None, None)
+
+ def tearDown(self):
+ self.tearDownMixin()
+
+
+class SecureServerSecureClient(unittest.TestCase, ServerClientMixin):
+
+ def setUp(self):
+ server_credentials = cygrpc.server_credentials_ssl(
+ None, [cygrpc.SslPemKeyCertPair(resources.private_key(),
+ resources.certificate_chain())], False)
+ client_credentials = cygrpc.channel_credentials_ssl(
+ resources.test_root_certificates(), None)
+ self.setUpMixin(server_credentials, client_credentials, _SSL_HOST_OVERRIDE)
+
+ def tearDown(self):
+ self.tearDownMixin()
if __name__ == '__main__':
diff --git a/src/python/grpcio/tests/unit/_cython/test_utilities.py b/src/python/grpcio/tests/unit/_cython/test_utilities.py
index 6a09739643..6280ce74c4 100644
--- a/src/python/grpcio/tests/unit/_cython/test_utilities.py
+++ b/src/python/grpcio/tests/unit/_cython/test_utilities.py
@@ -32,15 +32,35 @@ import threading
from grpc._cython import cygrpc
-class CompletionQueuePollFuture:
+class SimpleFuture(object):
+ """A simple future mechanism."""
- def __init__(self, completion_queue, deadline):
- def poller_function():
- self._event_result = completion_queue.poll(deadline)
- self._event_result = None
- self._thread = threading.Thread(target=poller_function)
+ def __init__(self, function, *args, **kwargs):
+ def wrapped_function():
+ try:
+ self._result = function(*args, **kwargs)
+ except Exception as error:
+ self._error = error
+ self._result = None
+ self._error = None
+ self._thread = threading.Thread(target=wrapped_function)
self._thread.start()
def result(self):
+    """Blocks until the wrapped function completes and returns its result.
+
+    Re-raises any exception raised by the wrapped function.
+ """
self._thread.join()
- return self._event_result
+ if self._error:
+ # TODO(atash): re-raise exceptions in a way that preserves tracebacks
+ raise self._error
+ return self._result
+
+
+class CompletionQueuePollFuture(SimpleFuture):
+
+ def __init__(self, completion_queue, deadline):
+ super(CompletionQueuePollFuture, self).__init__(
+ lambda: completion_queue.poll(deadline))
+
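
Illustrative use of the SimpleFuture helper above (a sketch, not part of the patch): run a blocking completion-queue poll off the main thread and join on it later, with any worker-thread exception re-raised at result() time.

import time

from grpc._cython import cygrpc
from tests.unit._cython import test_utilities

completion_queue = cygrpc.CompletionQueue()
deadline = cygrpc.Timespec(time.time() + 5)
poll_future = test_utilities.SimpleFuture(
    lambda: completion_queue.poll(deadline))
# ... start a batch whose completion the poll should observe ...
event = poll_future.result()  # re-raises any exception from the poll thread
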
diff --git a/src/python/grpcio/tests/unit/_empty_message_test.py b/src/python/grpcio/tests/unit/_empty_message_test.py
new file mode 100644
index 0000000000..f324f6216b
--- /dev/null
+++ b/src/python/grpcio/tests/unit/_empty_message_test.py
@@ -0,0 +1,137 @@
+# Copyright 2016, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import unittest
+
+import grpc
+from grpc.framework.foundation import logging_pool
+
+from tests.unit.framework.common import test_constants
+
+_REQUEST = b''
+_RESPONSE = b''
+
+_UNARY_UNARY = b'/test/UnaryUnary'
+_UNARY_STREAM = b'/test/UnaryStream'
+_STREAM_UNARY = b'/test/StreamUnary'
+_STREAM_STREAM = b'/test/StreamStream'
+
+
+def handle_unary_unary(request, servicer_context):
+ return _RESPONSE
+
+
+def handle_unary_stream(request, servicer_context):
+ for _ in range(test_constants.STREAM_LENGTH):
+ yield _RESPONSE
+
+
+def handle_stream_unary(request_iterator, servicer_context):
+ for request in request_iterator:
+ pass
+ return _RESPONSE
+
+
+def handle_stream_stream(request_iterator, servicer_context):
+ for request in request_iterator:
+ yield _RESPONSE
+
+
+class _MethodHandler(grpc.RpcMethodHandler):
+
+ def __init__(self, request_streaming, response_streaming):
+ self.request_streaming = request_streaming
+ self.response_streaming = response_streaming
+ self.request_deserializer = None
+ self.response_serializer = None
+ self.unary_unary = None
+ self.unary_stream = None
+ self.stream_unary = None
+ self.stream_stream = None
+ if self.request_streaming and self.response_streaming:
+ self.stream_stream = handle_stream_stream
+ elif self.request_streaming:
+ self.stream_unary = handle_stream_unary
+ elif self.response_streaming:
+ self.unary_stream = handle_unary_stream
+ else:
+ self.unary_unary = handle_unary_unary
+
+
+class _GenericHandler(grpc.GenericRpcHandler):
+
+ def service(self, handler_call_details):
+ if handler_call_details.method == _UNARY_UNARY:
+ return _MethodHandler(False, False)
+ elif handler_call_details.method == _UNARY_STREAM:
+ return _MethodHandler(False, True)
+ elif handler_call_details.method == _STREAM_UNARY:
+ return _MethodHandler(True, False)
+ elif handler_call_details.method == _STREAM_STREAM:
+ return _MethodHandler(True, True)
+ else:
+ return None
+
+
+class EmptyMessageTest(unittest.TestCase):
+
+ def setUp(self):
+ self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
+ self._server = grpc.server((_GenericHandler(),), self._server_pool)
+ port = self._server.add_insecure_port('[::]:0')
+ self._server.start()
+ self._channel = grpc.insecure_channel('localhost:%d' % port)
+
+ def tearDown(self):
+ self._server.stop(0)
+
+ def testUnaryUnary(self):
+ response = self._channel.unary_unary(_UNARY_UNARY)(_REQUEST)
+ self.assertEqual(_RESPONSE, response)
+
+ def testUnaryStream(self):
+ response_iterator = self._channel.unary_stream(_UNARY_STREAM)(_REQUEST)
+ self.assertSequenceEqual(
+ [_RESPONSE] * test_constants.STREAM_LENGTH, list(response_iterator))
+
+ def testStreamUnary(self):
+ response = self._channel.stream_unary(_STREAM_UNARY)(
+ [_REQUEST] * test_constants.STREAM_LENGTH)
+ self.assertEqual(_RESPONSE, response)
+
+ def testStreamStream(self):
+ response_iterator = self._channel.stream_stream(_STREAM_STREAM)(
+ [_REQUEST] * test_constants.STREAM_LENGTH)
+ self.assertSequenceEqual(
+ [_RESPONSE] * test_constants.STREAM_LENGTH, list(response_iterator))
+
+
+if __name__ == '__main__':
+ unittest.main(verbosity=2)
+
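
A design-note sketch (not part of the patch): the if/elif dispatch in `_GenericHandler.service` above could be table-driven instead, reusing the method constants and `_MethodHandler` defined in the new test module.

import grpc

from tests.unit import _empty_message_test as emt

_HANDLERS = {
    emt._UNARY_UNARY: emt._MethodHandler(False, False),
    emt._UNARY_STREAM: emt._MethodHandler(False, True),
    emt._STREAM_UNARY: emt._MethodHandler(True, False),
    emt._STREAM_STREAM: emt._MethodHandler(True, True),
}


class _TableDrivenGenericHandler(grpc.GenericRpcHandler):

  def service(self, handler_call_details):
    # dict.get returns None for unrecognized methods, matching the
    # if/elif chain in _GenericHandler above.
    return _HANDLERS.get(handler_call_details.method)
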
diff --git a/src/python/grpcio/grpc/_adapter/_implementations.py b/src/python/grpcio/tests/unit/_from_grpc_import_star.py
index b85f228bf6..78d2fb7dc5 100644
--- a/src/python/grpcio/grpc/_adapter/_implementations.py
+++ b/src/python/grpcio/tests/unit/_from_grpc_import_star.py
@@ -1,4 +1,4 @@
-# Copyright 2015, Google Inc.
+# Copyright 2016, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
@@ -27,22 +27,12 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-import collections
+_BEFORE_IMPORT = tuple(globals())
-from grpc.beta import interfaces
+from grpc import *
-class AuthMetadataContext(collections.namedtuple(
- 'AuthMetadataContext', [
- 'service_url',
- 'method_name'
- ]), interfaces.GRPCAuthMetadataContext):
- pass
+_AFTER_IMPORT = tuple(globals())
-
-class AuthMetadataPluginCallback(interfaces.GRPCAuthMetadataContext):
-
- def __init__(self, callback):
- self._callback = callback
-
- def __call__(self, metadata, error):
- self._callback(metadata, error)
+GRPC_ELEMENTS = tuple(
+ element for element in _AFTER_IMPORT
+ if element not in _BEFORE_IMPORT and element != '_BEFORE_IMPORT')
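
A hypothetical consumer of GRPC_ELEMENTS (the real tests.unit._api_test is not reproduced in this excerpt): assert that a few names exercised elsewhere in this patch survive `from grpc import *`.

import unittest

from tests.unit import _from_grpc_import_star


class StarImportSmokeTest(unittest.TestCase):

  def testExpectedNamesPresent(self):
    # grpc.insecure_channel, grpc.server and grpc.StatusCode are all used by
    # the tests in this patch, so they should appear in the star-import set.
    for name in ('insecure_channel', 'server', 'StatusCode'):
      self.assertIn(name, _from_grpc_import_star.GRPC_ELEMENTS)


if __name__ == '__main__':
  unittest.main(verbosity=2)
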
diff --git a/src/python/grpcio/tests/unit/_links/_transmission_test.py b/src/python/grpcio/tests/unit/_links/_transmission_test.py
index 888684d197..1f6edd18ca 100644
--- a/src/python/grpcio/tests/unit/_links/_transmission_test.py
+++ b/src/python/grpcio/tests/unit/_links/_transmission_test.py
@@ -153,7 +153,7 @@ class RoundTripTest(unittest.TestCase):
invocation_mate.tickets()[-1].termination,
links.Ticket.Termination.COMPLETION)
self.assertIs(invocation_mate.tickets()[-1].code, test_code)
- self.assertEqual(invocation_mate.tickets()[-1].message, test_message)
+ self.assertEqual(invocation_mate.tickets()[-1].message, test_message.encode())
def _perform_scenario_test(self, scenario):
test_operation_id = object()
diff --git a/src/python/grpcio/tests/unit/_metadata_test.py b/src/python/grpcio/tests/unit/_metadata_test.py
new file mode 100644
index 0000000000..77b3901261
--- /dev/null
+++ b/src/python/grpcio/tests/unit/_metadata_test.py
@@ -0,0 +1,216 @@
+# Copyright 2016, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""Tests server and client side metadata API."""
+
+import unittest
+import weakref
+
+import grpc
+from grpc import _grpcio_metadata
+from grpc.framework.foundation import logging_pool
+
+from tests.unit import test_common
+from tests.unit.framework.common import test_constants
+
+_CHANNEL_ARGS = (('grpc.primary_user_agent', 'primary-agent'),
+ ('grpc.secondary_user_agent', 'secondary-agent'))
+
+_REQUEST = b'\x00\x00\x00'
+_RESPONSE = b'\x00\x00\x00'
+
+_UNARY_UNARY = b'/test/UnaryUnary'
+_UNARY_STREAM = b'/test/UnaryStream'
+_STREAM_UNARY = b'/test/StreamUnary'
+_STREAM_STREAM = b'/test/StreamStream'
+
+_USER_AGENT = 'Python-gRPC-{}'.format(_grpcio_metadata.__version__)
+
+_CLIENT_METADATA = (
+ (b'client-md-key', b'client-md-key'),
+ (b'client-md-key-bin', b'\x00\x01')
+)
+
+_SERVER_INITIAL_METADATA = (
+ (b'server-initial-md-key', b'server-initial-md-value'),
+ (b'server-initial-md-key-bin', b'\x00\x02')
+)
+
+_SERVER_TRAILING_METADATA = (
+ (b'server-trailing-md-key', b'server-trailing-md-value'),
+ (b'server-trailing-md-key-bin', b'\x00\x03')
+)
+
+
+def user_agent(metadata):
+ for key, val in metadata:
+ if key == b'user-agent':
+ return val.decode('ascii')
+ raise KeyError('No user agent!')
+
+
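+# Server-side checks shared by every handler below: the client metadata must
+# arrive intact, and the composed user-agent must start with the
+# grpc.primary_user_agent channel option and end with
+# grpc.secondary_user_agent.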
+def validate_client_metadata(test, servicer_context):
+ test.assertTrue(test_common.metadata_transmitted(
+ _CLIENT_METADATA, servicer_context.invocation_metadata()))
+ test.assertTrue(user_agent(servicer_context.invocation_metadata())
+ .startswith('primary-agent ' + _USER_AGENT))
+ test.assertTrue(user_agent(servicer_context.invocation_metadata())
+ .endswith('secondary-agent'))
+
+
+def handle_unary_unary(test, request, servicer_context):
+ validate_client_metadata(test, servicer_context)
+ servicer_context.send_initial_metadata(_SERVER_INITIAL_METADATA)
+ servicer_context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
+ return _RESPONSE
+
+
+def handle_unary_stream(test, request, servicer_context):
+ validate_client_metadata(test, servicer_context)
+ servicer_context.send_initial_metadata(_SERVER_INITIAL_METADATA)
+ servicer_context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
+ for _ in range(test_constants.STREAM_LENGTH):
+ yield _RESPONSE
+
+
+def handle_stream_unary(test, request_iterator, servicer_context):
+ validate_client_metadata(test, servicer_context)
+ servicer_context.send_initial_metadata(_SERVER_INITIAL_METADATA)
+ servicer_context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
+ # TODO(issue:#6891) We should be able to remove this loop
+ for request in request_iterator:
+ pass
+ return _RESPONSE
+
+
+def handle_stream_stream(test, request_iterator, servicer_context):
+ validate_client_metadata(test, servicer_context)
+ servicer_context.send_initial_metadata(_SERVER_INITIAL_METADATA)
+ servicer_context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
+ # TODO(issue:#6891) We should be able to remove this loop,
+ # and replace with return; yield
+ for request in request_iterator:
+ yield _RESPONSE
+
+
+class _MethodHandler(grpc.RpcMethodHandler):
+
+ def __init__(self, test, request_streaming, response_streaming):
+ self.request_streaming = request_streaming
+ self.response_streaming = response_streaming
+ self.request_deserializer = None
+ self.response_serializer = None
+ self.unary_unary = None
+ self.unary_stream = None
+ self.stream_unary = None
+ self.stream_stream = None
+ if self.request_streaming and self.response_streaming:
+ self.stream_stream = lambda x, y: handle_stream_stream(test, x, y)
+ elif self.request_streaming:
+ self.stream_unary = lambda x, y: handle_stream_unary(test, x, y)
+ elif self.response_streaming:
+ self.unary_stream = lambda x, y: handle_unary_stream(test, x, y)
+ else:
+ self.unary_unary = lambda x, y: handle_unary_unary(test, x, y)
+
+
+class _GenericHandler(grpc.GenericRpcHandler):
+
+ def __init__(self, test):
+ self._test = test
+
+ def service(self, handler_call_details):
+ if handler_call_details.method == _UNARY_UNARY:
+ return _MethodHandler(self._test, False, False)
+ elif handler_call_details.method == _UNARY_STREAM:
+ return _MethodHandler(self._test, False, True)
+ elif handler_call_details.method == _STREAM_UNARY:
+ return _MethodHandler(self._test, True, False)
+ elif handler_call_details.method == _STREAM_STREAM:
+ return _MethodHandler(self._test, True, True)
+ else:
+ return None
+
+
+class MetadataTest(unittest.TestCase):
+
+ def setUp(self):
+ self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
+ self._server = grpc.server((_GenericHandler(weakref.proxy(self)),),
+ self._server_pool)
+ port = self._server.add_insecure_port('[::]:0')
+ self._server.start()
+ self._channel = grpc.insecure_channel('localhost:%d' % port,
+ options=_CHANNEL_ARGS)
+
+ def tearDown(self):
+ self._server.stop(0)
+
+ def testUnaryUnary(self):
+ multi_callable = self._channel.unary_unary(_UNARY_UNARY)
+ unused_response, call = multi_callable(
+ _REQUEST, metadata=_CLIENT_METADATA, with_call=True)
+ self.assertTrue(test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA, call.initial_metadata()))
+ self.assertTrue(test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA, call.trailing_metadata()))
+
+ def testUnaryStream(self):
+ multi_callable = self._channel.unary_stream(_UNARY_STREAM)
+ call = multi_callable(_REQUEST, metadata=_CLIENT_METADATA)
+ self.assertTrue(test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA, call.initial_metadata()))
+ for _ in call:
+ pass
+ self.assertTrue(test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA, call.trailing_metadata()))
+
+ def testStreamUnary(self):
+ multi_callable = self._channel.stream_unary(_STREAM_UNARY)
+ unused_response, call = multi_callable(
+ [_REQUEST] * test_constants.STREAM_LENGTH,
+ metadata=_CLIENT_METADATA, with_call=True)
+ self.assertTrue(test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA, call.initial_metadata()))
+ self.assertTrue(test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA, call.trailing_metadata()))
+
+ def testStreamStream(self):
+ multi_callable = self._channel.stream_stream(_STREAM_STREAM)
+ call = multi_callable([_REQUEST] * test_constants.STREAM_LENGTH,
+ metadata=_CLIENT_METADATA)
+ self.assertTrue(test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA, call.initial_metadata()))
+ for _ in call:
+ pass
+ self.assertTrue(test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA, call.trailing_metadata()))
+
+
+if __name__ == '__main__':
+ unittest.main(verbosity=2)
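
The lambdas that bind `test` into the handlers in `_MethodHandler` above could also be written with functools.partial; a small self-contained sketch of the equivalence, using illustrative names only (not part of the patch):

import functools


def _handle(test, request, servicer_context):
  # Stand-in for the handle_* functions above; simply echoes the request.
  return request

# Two equivalent ways of fixing the `test` argument ahead of time.
_by_lambda = lambda x, y: _handle('test-object', x, y)
_by_partial = functools.partial(_handle, 'test-object')

assert _by_lambda(b'req', None) == _by_partial(b'req', None) == b'req'
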
diff --git a/src/python/grpcio/tests/unit/_rpc_test.py b/src/python/grpcio/tests/unit/_rpc_test.py
new file mode 100644
index 0000000000..8407593c86
--- /dev/null
+++ b/src/python/grpcio/tests/unit/_rpc_test.py
@@ -0,0 +1,768 @@
+# Copyright 2016, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Test of gRPC Python's application-layer API."""
+
+import itertools
+import threading
+import unittest
+from concurrent import futures
+
+import grpc
+from grpc.framework.foundation import logging_pool
+
+from tests.unit.framework.common import test_constants
+from tests.unit.framework.common import test_control
+
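+# These serializers round-trip: _SERIALIZE_REQUEST doubles the bytes and
+# _DESERIALIZE_REQUEST keeps the second half, so
+# _DESERIALIZE_REQUEST(_SERIALIZE_REQUEST(x)) == x; the response pair
+# triples the bytes and keeps the first third, round-tripping as well.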
+_SERIALIZE_REQUEST = lambda bytestring: bytestring * 2
+_DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) // 2:]
+_SERIALIZE_RESPONSE = lambda bytestring: bytestring * 3
+_DESERIALIZE_RESPONSE = lambda bytestring: bytestring[:len(bytestring) // 3]
+
+_UNARY_UNARY = b'/test/UnaryUnary'
+_UNARY_STREAM = b'/test/UnaryStream'
+_STREAM_UNARY = b'/test/StreamUnary'
+_STREAM_STREAM = b'/test/StreamStream'
+
+
+class _Callback(object):
+
+ def __init__(self):
+ self._condition = threading.Condition()
+ self._value = None
+ self._called = False
+
+ def __call__(self, value):
+ with self._condition:
+ self._value = value
+ self._called = True
+ self._condition.notify_all()
+
+ def value(self):
+ with self._condition:
+ while not self._called:
+ self._condition.wait()
+ return self._value
+
+
+class _Handler(object):
+
+ def __init__(self, control):
+ self._control = control
+
+ def handle_unary_unary(self, request, servicer_context):
+ self._control.control()
+ if servicer_context is not None:
+ servicer_context.set_trailing_metadata(((b'testkey', b'testvalue',),))
+ return request
+
+ def handle_unary_stream(self, request, servicer_context):
+ for _ in range(test_constants.STREAM_LENGTH):
+ self._control.control()
+ yield request
+ self._control.control()
+ if servicer_context is not None:
+ servicer_context.set_trailing_metadata(((b'testkey', b'testvalue',),))
+
+ def handle_stream_unary(self, request_iterator, servicer_context):
+ if servicer_context is not None:
+ servicer_context.invocation_metadata()
+ self._control.control()
+ response_elements = []
+ for request in request_iterator:
+ self._control.control()
+ response_elements.append(request)
+ self._control.control()
+ if servicer_context is not None:
+ servicer_context.set_trailing_metadata(((b'testkey', b'testvalue',),))
+ return b''.join(response_elements)
+
+ def handle_stream_stream(self, request_iterator, servicer_context):
+ self._control.control()
+ if servicer_context is not None:
+ servicer_context.set_trailing_metadata(((b'testkey', b'testvalue',),))
+ for request in request_iterator:
+ self._control.control()
+ yield request
+ self._control.control()
+
+
+class _MethodHandler(grpc.RpcMethodHandler):
+
+ def __init__(
+ self, request_streaming, response_streaming, request_deserializer,
+ response_serializer, unary_unary, unary_stream, stream_unary,
+ stream_stream):
+ self.request_streaming = request_streaming
+ self.response_streaming = response_streaming
+ self.request_deserializer = request_deserializer
+ self.response_serializer = response_serializer
+ self.unary_unary = unary_unary
+ self.unary_stream = unary_stream
+ self.stream_unary = stream_unary
+ self.stream_stream = stream_stream
+
+
+class _GenericHandler(grpc.GenericRpcHandler):
+
+ def __init__(self, handler):
+ self._handler = handler
+
+ def service(self, handler_call_details):
+ if handler_call_details.method == _UNARY_UNARY:
+ return _MethodHandler(
+ False, False, None, None, self._handler.handle_unary_unary, None,
+ None, None)
+ elif handler_call_details.method == _UNARY_STREAM:
+ return _MethodHandler(
+ False, True, _DESERIALIZE_REQUEST, _SERIALIZE_RESPONSE, None,
+ self._handler.handle_unary_stream, None, None)
+ elif handler_call_details.method == _STREAM_UNARY:
+ return _MethodHandler(
+ True, False, _DESERIALIZE_REQUEST, _SERIALIZE_RESPONSE, None, None,
+ self._handler.handle_stream_unary, None)
+ elif handler_call_details.method == _STREAM_STREAM:
+ return _MethodHandler(
+ True, True, None, None, None, None, None,
+ self._handler.handle_stream_stream)
+ else:
+ return None
+
+
+def _unary_unary_multi_callable(channel):
+ return channel.unary_unary(_UNARY_UNARY)
+
+
+def _unary_stream_multi_callable(channel):
+ return channel.unary_stream(
+ _UNARY_STREAM,
+ request_serializer=_SERIALIZE_REQUEST,
+ response_deserializer=_DESERIALIZE_RESPONSE)
+
+
+def _stream_unary_multi_callable(channel):
+ return channel.stream_unary(
+ _STREAM_UNARY,
+ request_serializer=_SERIALIZE_REQUEST,
+ response_deserializer=_DESERIALIZE_RESPONSE)
+
+
+def _stream_stream_multi_callable(channel):
+ return channel.stream_stream(_STREAM_STREAM)
+
+
+class RPCTest(unittest.TestCase):
+
+ def setUp(self):
+ self._control = test_control.PauseFailControl()
+ self._handler = _Handler(self._control)
+ self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
+
+ self._server = grpc.server((), self._server_pool)
+ port = self._server.add_insecure_port(b'[::]:0')
+ self._server.add_generic_rpc_handlers((_GenericHandler(self._handler),))
+ self._server.start()
+
+ self._channel = grpc.insecure_channel('localhost:%d' % port)
+
+ def testUnrecognizedMethod(self):
+ request = b'abc'
+
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ self._channel.unary_unary(b'NoSuchMethod')(request)
+
+ self.assertEqual(
+ grpc.StatusCode.UNIMPLEMENTED, exception_context.exception.code())
+
+ def testSuccessfulUnaryRequestBlockingUnaryResponse(self):
+ request = b'\x07\x08'
+ expected_response = self._handler.handle_unary_unary(request, None)
+
+ multi_callable = _unary_unary_multi_callable(self._channel)
+ response = multi_callable(
+ request, metadata=(
+ (b'test', b'SuccessfulUnaryRequestBlockingUnaryResponse'),))
+
+ self.assertEqual(expected_response, response)
+
+ def testSuccessfulUnaryRequestBlockingUnaryResponseWithCall(self):
+ request = b'\x07\x08'
+ expected_response = self._handler.handle_unary_unary(request, None)
+
+ multi_callable = _unary_unary_multi_callable(self._channel)
+ response, call = multi_callable(
+ request, metadata=(
+ (b'test', b'SuccessfulUnaryRequestBlockingUnaryResponseWithCall'),),
+ with_call=True)
+
+ self.assertEqual(expected_response, response)
+ self.assertIs(grpc.StatusCode.OK, call.code())
+
+ def testSuccessfulUnaryRequestFutureUnaryResponse(self):
+ request = b'\x07\x08'
+ expected_response = self._handler.handle_unary_unary(request, None)
+
+ multi_callable = _unary_unary_multi_callable(self._channel)
+ response_future = multi_callable.future(
+ request, metadata=(
+ (b'test', b'SuccessfulUnaryRequestFutureUnaryResponse'),))
+ response = response_future.result()
+
+ self.assertEqual(expected_response, response)
+
+ def testSuccessfulUnaryRequestStreamResponse(self):
+ request = b'\x37\x58'
+ expected_responses = tuple(self._handler.handle_unary_stream(request, None))
+
+ multi_callable = _unary_stream_multi_callable(self._channel)
+ response_iterator = multi_callable(
+ request,
+ metadata=((b'test', b'SuccessfulUnaryRequestStreamResponse'),))
+ responses = tuple(response_iterator)
+
+ self.assertSequenceEqual(expected_responses, responses)
+
+ def testSuccessfulStreamRequestBlockingUnaryResponse(self):
+ requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
+ expected_response = self._handler.handle_stream_unary(iter(requests), None)
+ request_iterator = iter(requests)
+
+ multi_callable = _stream_unary_multi_callable(self._channel)
+ response = multi_callable(
+ request_iterator,
+ metadata=((b'test', b'SuccessfulStreamRequestBlockingUnaryResponse'),))
+
+ self.assertEqual(expected_response, response)
+
+ def testSuccessfulStreamRequestBlockingUnaryResponseWithCall(self):
+ requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
+ expected_response = self._handler.handle_stream_unary(iter(requests), None)
+ request_iterator = iter(requests)
+
+ multi_callable = _stream_unary_multi_callable(self._channel)
+ response, call = multi_callable(
+ request_iterator,
+ metadata=(
+ (b'test', b'SuccessfulStreamRequestBlockingUnaryResponseWithCall'),
+ ), with_call=True)
+
+ self.assertEqual(expected_response, response)
+ self.assertIs(grpc.StatusCode.OK, call.code())
+
+ def testSuccessfulStreamRequestFutureUnaryResponse(self):
+ requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
+ expected_response = self._handler.handle_stream_unary(iter(requests), None)
+ request_iterator = iter(requests)
+
+ multi_callable = _stream_unary_multi_callable(self._channel)
+ response_future = multi_callable.future(
+ request_iterator,
+ metadata=(
+ (b'test', b'SuccessfulStreamRequestFutureUnaryResponse'),))
+ response = response_future.result()
+
+ self.assertEqual(expected_response, response)
+
+ def testSuccessfulStreamRequestStreamResponse(self):
+ requests = tuple(b'\x77\x58' for _ in range(test_constants.STREAM_LENGTH))
+ expected_responses = tuple(
+ self._handler.handle_stream_stream(iter(requests), None))
+ request_iterator = iter(requests)
+
+ multi_callable = _stream_stream_multi_callable(self._channel)
+ response_iterator = multi_callable(
+ request_iterator,
+ metadata=((b'test', b'SuccessfulStreamRequestStreamResponse'),))
+ responses = tuple(response_iterator)
+
+ self.assertSequenceEqual(expected_responses, responses)
+
+ def testSequentialInvocations(self):
+ first_request = b'\x07\x08'
+ second_request = b'\x0809'
+ expected_first_response = self._handler.handle_unary_unary(
+ first_request, None)
+ expected_second_response = self._handler.handle_unary_unary(
+ second_request, None)
+
+ multi_callable = _unary_unary_multi_callable(self._channel)
+ first_response = multi_callable(
+ first_request, metadata=((b'test', b'SequentialInvocations'),))
+ second_response = multi_callable(
+ second_request, metadata=((b'test', b'SequentialInvocations'),))
+
+ self.assertEqual(expected_first_response, first_response)
+ self.assertEqual(expected_second_response, second_response)
+
+ def testConcurrentBlockingInvocations(self):
+ pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
+ requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
+ expected_response = self._handler.handle_stream_unary(iter(requests), None)
+ expected_responses = [expected_response] * test_constants.THREAD_CONCURRENCY
+ response_futures = [None] * test_constants.THREAD_CONCURRENCY
+
+ multi_callable = _stream_unary_multi_callable(self._channel)
+ for index in range(test_constants.THREAD_CONCURRENCY):
+ request_iterator = iter(requests)
+ response_future = pool.submit(
+ multi_callable, request_iterator,
+ metadata=((b'test', b'ConcurrentBlockingInvocations'),))
+ response_futures[index] = response_future
+ responses = tuple(
+ response_future.result() for response_future in response_futures)
+
+ pool.shutdown(wait=True)
+ self.assertSequenceEqual(expected_responses, responses)
+
+ def testConcurrentFutureInvocations(self):
+ requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
+ expected_response = self._handler.handle_stream_unary(iter(requests), None)
+ expected_responses = [expected_response] * test_constants.THREAD_CONCURRENCY
+ response_futures = [None] * test_constants.THREAD_CONCURRENCY
+
+ multi_callable = _stream_unary_multi_callable(self._channel)
+ for index in range(test_constants.THREAD_CONCURRENCY):
+ request_iterator = iter(requests)
+ response_future = multi_callable.future(
+ request_iterator,
+ metadata=((b'test', b'ConcurrentFutureInvocations'),))
+ response_futures[index] = response_future
+ responses = tuple(
+ response_future.result() for response_future in response_futures)
+
+ self.assertSequenceEqual(expected_responses, responses)
+
+ def testWaitingForSomeButNotAllConcurrentFutureInvocations(self):
+ pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
+ request = b'\x67\x68'
+ expected_response = self._handler.handle_unary_unary(request, None)
+ response_futures = [None] * test_constants.THREAD_CONCURRENCY
+ lock = threading.Lock()
+ test_is_running_cell = [True]
+ def wrap_future(future):
+ def wrap():
+ try:
+ return future.result()
+ except grpc.RpcError:
+ with lock:
+ if test_is_running_cell[0]:
+ raise
+ return None
+ return wrap
+
+ multi_callable = _unary_unary_multi_callable(self._channel)
+ for index in range(test_constants.THREAD_CONCURRENCY):
+ inner_response_future = multi_callable.future(
+ request,
+ metadata=(
+ (b'test',
+ b'WaitingForSomeButNotAllConcurrentFutureInvocations'),))
+ outer_response_future = pool.submit(wrap_future(inner_response_future))
+ response_futures[index] = outer_response_future
+
+ some_completed_response_futures_iterator = itertools.islice(
+ futures.as_completed(response_futures),
+ test_constants.THREAD_CONCURRENCY // 2)
+ for response_future in some_completed_response_futures_iterator:
+ self.assertEqual(expected_response, response_future.result())
+ with lock:
+ test_is_running_cell[0] = False
+
+ def testConsumingOneStreamResponseUnaryRequest(self):
+ request = b'\x57\x38'
+
+ multi_callable = _unary_stream_multi_callable(self._channel)
+ response_iterator = multi_callable(
+ request,
+ metadata=(
+ (b'test', b'ConsumingOneStreamResponseUnaryRequest'),))
+ next(response_iterator)
+
+ def testConsumingSomeButNotAllStreamResponsesUnaryRequest(self):
+ request = b'\x57\x38'
+
+ multi_callable = _unary_stream_multi_callable(self._channel)
+ response_iterator = multi_callable(
+ request,
+ metadata=(
+ (b'test', b'ConsumingSomeButNotAllStreamResponsesUnaryRequest'),))
+ for _ in range(test_constants.STREAM_LENGTH // 2):
+ next(response_iterator)
+
+ def testConsumingSomeButNotAllStreamResponsesStreamRequest(self):
+ requests = tuple(b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH))
+ request_iterator = iter(requests)
+
+ multi_callable = _stream_stream_multi_callable(self._channel)
+ response_iterator = multi_callable(
+ request_iterator,
+ metadata=(
+ (b'test', b'ConsumingSomeButNotAllStreamResponsesStreamRequest'),))
+ for _ in range(test_constants.STREAM_LENGTH // 2):
+ next(response_iterator)
+
+ def testConsumingTooManyStreamResponsesStreamRequest(self):
+ requests = tuple(b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH))
+ request_iterator = iter(requests)
+
+ multi_callable = _stream_stream_multi_callable(self._channel)
+ response_iterator = multi_callable(
+ request_iterator,
+ metadata=(
+ (b'test', b'ConsumingTooManyStreamResponsesStreamRequest'),))
+ for _ in range(test_constants.STREAM_LENGTH):
+ next(response_iterator)
+ for _ in range(test_constants.STREAM_LENGTH):
+ with self.assertRaises(StopIteration):
+ next(response_iterator)
+
+ self.assertIsNotNone(response_iterator.initial_metadata())
+ self.assertIs(grpc.StatusCode.OK, response_iterator.code())
+ self.assertIsNotNone(response_iterator.details())
+ self.assertIsNotNone(response_iterator.trailing_metadata())
+
+ def testCancelledUnaryRequestUnaryResponse(self):
+ request = b'\x07\x17'
+
+ multi_callable = _unary_unary_multi_callable(self._channel)
+ with self._control.pause():
+ response_future = multi_callable.future(
+ request,
+ metadata=((b'test', b'CancelledUnaryRequestUnaryResponse'),))
+ response_future.cancel()
+
+ self.assertTrue(response_future.cancelled())
+ with self.assertRaises(grpc.FutureCancelledError):
+ response_future.result()
+ self.assertIs(grpc.StatusCode.CANCELLED, response_future.code())
+
+ def testCancelledUnaryRequestStreamResponse(self):
+ request = b'\x07\x19'
+
+ multi_callable = _unary_stream_multi_callable(self._channel)
+ with self._control.pause():
+ response_iterator = multi_callable(
+ request,
+ metadata=((b'test', b'CancelledUnaryRequestStreamResponse'),))
+ self._control.block_until_paused()
+ response_iterator.cancel()
+
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ next(response_iterator)
+ self.assertIs(grpc.StatusCode.CANCELLED, exception_context.exception.code())
+ self.assertIsNotNone(response_iterator.initial_metadata())
+ self.assertIs(grpc.StatusCode.CANCELLED, response_iterator.code())
+ self.assertIsNotNone(response_iterator.details())
+ self.assertIsNotNone(response_iterator.trailing_metadata())
+
+ def testCancelledStreamRequestUnaryResponse(self):
+ requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
+ request_iterator = iter(requests)
+
+ multi_callable = _stream_unary_multi_callable(self._channel)
+ with self._control.pause():
+ response_future = multi_callable.future(
+ request_iterator,
+ metadata=((b'test', b'CancelledStreamRequestUnaryResponse'),))
+ self._control.block_until_paused()
+ response_future.cancel()
+
+ self.assertTrue(response_future.cancelled())
+ with self.assertRaises(grpc.FutureCancelledError):
+ response_future.result()
+ self.assertIsNotNone(response_future.initial_metadata())
+ self.assertIs(grpc.StatusCode.CANCELLED, response_future.code())
+ self.assertIsNotNone(response_future.details())
+ self.assertIsNotNone(response_future.trailing_metadata())
+
+ def testCancelledStreamRequestStreamResponse(self):
+ requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
+ request_iterator = iter(requests)
+
+ multi_callable = _stream_stream_multi_callable(self._channel)
+ with self._control.pause():
+ response_iterator = multi_callable(
+ request_iterator,
+ metadata=((b'test', b'CancelledStreamRequestStreamResponse'),))
+ response_iterator.cancel()
+
+ with self.assertRaises(grpc.RpcError):
+ next(response_iterator)
+ self.assertIsNotNone(response_iterator.initial_metadata())
+ self.assertIs(grpc.StatusCode.CANCELLED, response_iterator.code())
+ self.assertIsNotNone(response_iterator.details())
+ self.assertIsNotNone(response_iterator.trailing_metadata())
+
+ def testExpiredUnaryRequestBlockingUnaryResponse(self):
+ request = b'\x07\x17'
+
+ multi_callable = _unary_unary_multi_callable(self._channel)
+ with self._control.pause():
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ multi_callable(
+ request, timeout=test_constants.SHORT_TIMEOUT,
+ metadata=((b'test', b'ExpiredUnaryRequestBlockingUnaryResponse'),),
+ with_call=True)
+
+ self.assertIsNotNone(exception_context.exception.initial_metadata())
+ self.assertIs(
+ grpc.StatusCode.DEADLINE_EXCEEDED, exception_context.exception.code())
+ self.assertIsNotNone(exception_context.exception.details())
+ self.assertIsNotNone(exception_context.exception.trailing_metadata())
+
+ def testExpiredUnaryRequestFutureUnaryResponse(self):
+ request = b'\x07\x17'
+ callback = _Callback()
+
+ multi_callable = _unary_unary_multi_callable(self._channel)
+ with self._control.pause():
+ response_future = multi_callable.future(
+ request, timeout=test_constants.SHORT_TIMEOUT,
+ metadata=((b'test', b'ExpiredUnaryRequestFutureUnaryResponse'),))
+ response_future.add_done_callback(callback)
+ value_passed_to_callback = callback.value()
+
+ self.assertIs(response_future, value_passed_to_callback)
+ self.assertIsNotNone(response_future.initial_metadata())
+ self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, response_future.code())
+ self.assertIsNotNone(response_future.details())
+ self.assertIsNotNone(response_future.trailing_metadata())
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ response_future.result()
+ self.assertIs(
+ grpc.StatusCode.DEADLINE_EXCEEDED, exception_context.exception.code())
+ self.assertIsInstance(response_future.exception(), grpc.RpcError)
+ self.assertIs(
+ grpc.StatusCode.DEADLINE_EXCEEDED, response_future.exception().code())
+
+ def testExpiredUnaryRequestStreamResponse(self):
+ request = b'\x07\x19'
+
+ multi_callable = _unary_stream_multi_callable(self._channel)
+ with self._control.pause():
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ response_iterator = multi_callable(
+ request, timeout=test_constants.SHORT_TIMEOUT,
+ metadata=((b'test', b'ExpiredUnaryRequestStreamResponse'),))
+ next(response_iterator)
+
+ self.assertIs(
+ grpc.StatusCode.DEADLINE_EXCEEDED, exception_context.exception.code())
+ self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, response_iterator.code())
+
+ def testExpiredStreamRequestBlockingUnaryResponse(self):
+ requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
+ request_iterator = iter(requests)
+
+ multi_callable = _stream_unary_multi_callable(self._channel)
+ with self._control.pause():
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ multi_callable(
+ request_iterator, timeout=test_constants.SHORT_TIMEOUT,
+ metadata=((b'test', b'ExpiredStreamRequestBlockingUnaryResponse'),))
+
+ self.assertIsNotNone(exception_context.exception.initial_metadata())
+ self.assertIs(
+ grpc.StatusCode.DEADLINE_EXCEEDED, exception_context.exception.code())
+ self.assertIsNotNone(exception_context.exception.details())
+ self.assertIsNotNone(exception_context.exception.trailing_metadata())
+
+ def testExpiredStreamRequestFutureUnaryResponse(self):
+ requests = tuple(b'\x07\x18' for _ in range(test_constants.STREAM_LENGTH))
+ request_iterator = iter(requests)
+ callback = _Callback()
+
+ multi_callable = _stream_unary_multi_callable(self._channel)
+ with self._control.pause():
+ response_future = multi_callable.future(
+ request_iterator, timeout=test_constants.SHORT_TIMEOUT,
+ metadata=((b'test', b'ExpiredStreamRequestFutureUnaryResponse'),))
+ response_future.add_done_callback(callback)
+ value_passed_to_callback = callback.value()
+
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ response_future.result()
+ self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, response_future.code())
+ self.assertIs(
+ grpc.StatusCode.DEADLINE_EXCEEDED, exception_context.exception.code())
+ self.assertIsInstance(response_future.exception(), grpc.RpcError)
+ self.assertIs(response_future, value_passed_to_callback)
+ self.assertIsNotNone(response_future.initial_metadata())
+ self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, response_future.code())
+ self.assertIsNotNone(response_future.details())
+ self.assertIsNotNone(response_future.trailing_metadata())
+
+ def testExpiredStreamRequestStreamResponse(self):
+ requests = tuple(b'\x67\x18' for _ in range(test_constants.STREAM_LENGTH))
+ request_iterator = iter(requests)
+
+ multi_callable = _stream_stream_multi_callable(self._channel)
+ with self._control.pause():
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ response_iterator = multi_callable(
+ request_iterator, timeout=test_constants.SHORT_TIMEOUT,
+ metadata=((b'test', b'ExpiredStreamRequestStreamResponse'),))
+ next(response_iterator)
+
+ self.assertIs(
+ grpc.StatusCode.DEADLINE_EXCEEDED, exception_context.exception.code())
+ self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, response_iterator.code())
+
+ def testFailedUnaryRequestBlockingUnaryResponse(self):
+ request = b'\x37\x17'
+
+ multi_callable = _unary_unary_multi_callable(self._channel)
+ with self._control.fail():
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ multi_callable(
+ request,
+ metadata=((b'test', b'FailedUnaryRequestBlockingUnaryResponse'),),
+ with_call=True)
+
+ self.assertIs(grpc.StatusCode.UNKNOWN, exception_context.exception.code())
+
+ def testFailedUnaryRequestFutureUnaryResponse(self):
+ request = b'\x37\x17'
+ callback = _Callback()
+
+ multi_callable = _unary_unary_multi_callable(self._channel)
+ with self._control.fail():
+ response_future = multi_callable.future(
+ request,
+ metadata=((b'test', b'FailedUnaryRequestFutureUnaryResponse'),))
+ response_future.add_done_callback(callback)
+ value_passed_to_callback = callback.value()
+
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ response_future.result()
+ self.assertIs(
+ grpc.StatusCode.UNKNOWN, exception_context.exception.code())
+ self.assertIsInstance(response_future.exception(), grpc.RpcError)
+ self.assertIs(grpc.StatusCode.UNKNOWN, response_future.exception().code())
+ self.assertIs(response_future, value_passed_to_callback)
+
+ def testFailedUnaryRequestStreamResponse(self):
+ request = b'\x37\x17'
+
+ multi_callable = _unary_stream_multi_callable(self._channel)
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ with self._control.fail():
+ response_iterator = multi_callable(
+ request,
+ metadata=((b'test', b'FailedUnaryRequestStreamResponse'),))
+ next(response_iterator)
+
+ self.assertIs(grpc.StatusCode.UNKNOWN, exception_context.exception.code())
+
+ def testFailedStreamRequestBlockingUnaryResponse(self):
+ requests = tuple(b'\x47\x58' for _ in range(test_constants.STREAM_LENGTH))
+ request_iterator = iter(requests)
+
+ multi_callable = _stream_unary_multi_callable(self._channel)
+ with self._control.fail():
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ multi_callable(
+ request_iterator,
+ metadata=((b'test', b'FailedStreamRequestBlockingUnaryResponse'),))
+
+ self.assertIs(grpc.StatusCode.UNKNOWN, exception_context.exception.code())
+
+ def testFailedStreamRequestFutureUnaryResponse(self):
+ requests = tuple(b'\x07\x18' for _ in range(test_constants.STREAM_LENGTH))
+ request_iterator = iter(requests)
+ callback = _Callback()
+
+ multi_callable = _stream_unary_multi_callable(self._channel)
+ with self._control.fail():
+ response_future = multi_callable.future(
+ request_iterator,
+ metadata=((b'test', b'FailedStreamRequestFutureUnaryResponse'),))
+ response_future.add_done_callback(callback)
+ value_passed_to_callback = callback.value()
+
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ response_future.result()
+ self.assertIs(grpc.StatusCode.UNKNOWN, response_future.code())
+ self.assertIs(
+ grpc.StatusCode.UNKNOWN, exception_context.exception.code())
+ self.assertIsInstance(response_future.exception(), grpc.RpcError)
+ self.assertIs(response_future, value_passed_to_callback)
+
+ def testFailedStreamRequestStreamResponse(self):
+ requests = tuple(b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH))
+ request_iterator = iter(requests)
+
+ multi_callable = _stream_stream_multi_callable(self._channel)
+ with self._control.fail():
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ response_iterator = multi_callable(
+ request_iterator,
+ metadata=((b'test', b'FailedStreamRequestStreamResponse'),))
+ tuple(response_iterator)
+
+ self.assertIs(grpc.StatusCode.UNKNOWN, exception_context.exception.code())
+ self.assertIs(grpc.StatusCode.UNKNOWN, response_iterator.code())
+
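+  # The Ignored* tests below intentionally discard the RPCs they start; they
+  # pass as long as abandoning a call neither raises nor hangs.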
+ def testIgnoredUnaryRequestFutureUnaryResponse(self):
+ request = b'\x37\x17'
+
+ multi_callable = _unary_unary_multi_callable(self._channel)
+ multi_callable.future(
+ request,
+ metadata=((b'test', b'IgnoredUnaryRequestFutureUnaryResponse'),))
+
+ def testIgnoredUnaryRequestStreamResponse(self):
+ request = b'\x37\x17'
+
+ multi_callable = _unary_stream_multi_callable(self._channel)
+ multi_callable(
+ request,
+ metadata=((b'test', b'IgnoredUnaryRequestStreamResponse'),))
+
+ def testIgnoredStreamRequestFutureUnaryResponse(self):
+ requests = tuple(b'\x07\x18' for _ in range(test_constants.STREAM_LENGTH))
+ request_iterator = iter(requests)
+
+ multi_callable = _stream_unary_multi_callable(self._channel)
+ multi_callable.future(
+ request_iterator,
+ metadata=((b'test', b'IgnoredStreamRequestFutureUnaryResponse'),))
+
+ def testIgnoredStreamRequestStreamResponse(self):
+ requests = tuple(b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH))
+ request_iterator = iter(requests)
+
+ multi_callable = _stream_stream_multi_callable(self._channel)
+ multi_callable(
+ request_iterator,
+ metadata=((b'test', b'IgnoredStreamRequestStreamResponse'),))
+
+
+if __name__ == '__main__':
+ unittest.main(verbosity=2)
diff --git a/src/python/grpcio/tests/unit/_thread_cleanup_test.py b/src/python/grpcio/tests/unit/_thread_cleanup_test.py
new file mode 100644
index 0000000000..3e4f317edc
--- /dev/null
+++ b/src/python/grpcio/tests/unit/_thread_cleanup_test.py
@@ -0,0 +1,117 @@
+# Copyright 2016, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""Tests for CleanupThread."""
+
+import threading
+import time
+import unittest
+
+from grpc import _common
+
+_SHORT_TIME = 0.5
+_LONG_TIME = 2.0
+_EPSILON = 0.1
+
+
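+# A well-behaved cleanup behavior: sleeps only for the timeout it is given,
+# or for _LONG_TIME when no timeout is supplied.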
+def cleanup(timeout):
+ if timeout is not None:
+ time.sleep(timeout)
+ else:
+ time.sleep(_LONG_TIME)
+
+
+def slow_cleanup(timeout):
+  # Deliberately ignores its timeout, simulating a misbehaving cleanup.
+ time.sleep(_LONG_TIME)
+
+
+class CleanupThreadTest(unittest.TestCase):
+
+ def testTargetInvocation(self):
+ event = threading.Event()
+ def target(arg1, arg2, arg3=None):
+ self.assertEqual('arg1', arg1)
+ self.assertEqual('arg2', arg2)
+ self.assertEqual('arg3', arg3)
+ event.set()
+
+ cleanup_thread = _common.CleanupThread(behavior=lambda x: None,
+ target=target, name='test-name',
+ args=('arg1', 'arg2'), kwargs={'arg3': 'arg3'})
+ cleanup_thread.start()
+ cleanup_thread.join()
+ self.assertEqual(cleanup_thread.name, 'test-name')
+ self.assertTrue(event.is_set())
+
+ def testJoinNoTimeout(self):
+ cleanup_thread = _common.CleanupThread(behavior=cleanup)
+ cleanup_thread.start()
+ start_time = time.time()
+ cleanup_thread.join()
+ end_time = time.time()
+ self.assertAlmostEqual(_LONG_TIME, end_time - start_time, delta=_EPSILON)
+
+ def testJoinTimeout(self):
+ cleanup_thread = _common.CleanupThread(behavior=cleanup)
+ cleanup_thread.start()
+ start_time = time.time()
+ cleanup_thread.join(_SHORT_TIME)
+ end_time = time.time()
+ self.assertAlmostEqual(_SHORT_TIME, end_time - start_time, delta=_EPSILON)
+
+ def testJoinTimeoutSlowBehavior(self):
+ cleanup_thread = _common.CleanupThread(behavior=slow_cleanup)
+ cleanup_thread.start()
+ start_time = time.time()
+ cleanup_thread.join(_SHORT_TIME)
+ end_time = time.time()
+ self.assertAlmostEqual(_LONG_TIME, end_time - start_time, delta=_EPSILON)
+
+ def testJoinTimeoutSlowTarget(self):
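+    # join() should return after roughly the requested timeout even though
+    # the target is still blocked waiting on the event.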
+ event = threading.Event()
+ def target():
+ event.wait(_LONG_TIME)
+ cleanup_thread = _common.CleanupThread(behavior=cleanup, target=target)
+ cleanup_thread.start()
+ start_time = time.time()
+ cleanup_thread.join(_SHORT_TIME)
+ end_time = time.time()
+ self.assertAlmostEqual(_SHORT_TIME, end_time - start_time, delta=_EPSILON)
+ event.set()
+
+ def testJoinZeroTimeout(self):
+ cleanup_thread = _common.CleanupThread(behavior=cleanup)
+ cleanup_thread.start()
+ start_time = time.time()
+ cleanup_thread.join(0)
+ end_time = time.time()
+ self.assertAlmostEqual(0, end_time - start_time, delta=_EPSILON)
+
+if __name__ == '__main__':
+ unittest.main(verbosity=2)
diff --git a/src/python/grpcio/tests/unit/beta/_beta_features_test.py b/src/python/grpcio/tests/unit/beta/_beta_features_test.py
index bb2893a21b..3a9701b8eb 100644
--- a/src/python/grpcio/tests/unit/beta/_beta_features_test.py
+++ b/src/python/grpcio/tests/unit/beta/_beta_features_test.py
@@ -42,8 +42,8 @@ from tests.unit.framework.common import test_constants
_SERVER_HOST_OVERRIDE = 'foo.test.google.fr'
-_PER_RPC_CREDENTIALS_METADATA_KEY = 'my-call-credentials-metadata-key'
-_PER_RPC_CREDENTIALS_METADATA_VALUE = 'my-call-credentials-metadata-value'
+_PER_RPC_CREDENTIALS_METADATA_KEY = b'my-call-credentials-metadata-key'
+_PER_RPC_CREDENTIALS_METADATA_VALUE = b'my-call-credentials-metadata-value'
_GROUP = 'group'
_UNARY_UNARY = 'unary-unary'
diff --git a/src/python/grpcio/tests/unit/beta/_connectivity_channel_test.py b/src/python/grpcio/tests/unit/beta/_connectivity_channel_test.py
index 5dc8720639..488f7d7141 100644
--- a/src/python/grpcio/tests/unit/beta/_connectivity_channel_test.py
+++ b/src/python/grpcio/tests/unit/beta/_connectivity_channel_test.py
@@ -187,5 +187,15 @@ class ChannelConnectivityTest(unittest.TestCase):
server_completion_queue_thread.join()
+class ConnectivityStatesTest(unittest.TestCase):
+
+ def testBetaConnectivityStates(self):
+ self.assertIsNotNone(interfaces.ChannelConnectivity.IDLE)
+ self.assertIsNotNone(interfaces.ChannelConnectivity.CONNECTING)
+ self.assertIsNotNone(interfaces.ChannelConnectivity.READY)
+ self.assertIsNotNone(interfaces.ChannelConnectivity.TRANSIENT_FAILURE)
+ self.assertIsNotNone(interfaces.ChannelConnectivity.FATAL_FAILURE)
+
+
if __name__ == '__main__':
unittest.main(verbosity=2)
diff --git a/src/python/grpcio/tests/unit/beta/_implementations_test.py b/src/python/grpcio/tests/unit/beta/_implementations_test.py
index 26be670c45..127f93e9bb 100644
--- a/src/python/grpcio/tests/unit/beta/_implementations_test.py
+++ b/src/python/grpcio/tests/unit/beta/_implementations_test.py
@@ -29,8 +29,11 @@
"""Tests the implementations module of the gRPC Python Beta API."""
+import datetime
import unittest
+from oauth2client import client as oauth2client_client
+
from grpc.beta import implementations
from tests.unit import resources
@@ -49,5 +52,19 @@ class ChannelCredentialsTest(unittest.TestCase):
channel_credentials, implementations.ChannelCredentials)
+class CallCredentialsTest(unittest.TestCase):
+
+ def test_google_call_credentials(self):
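+    # The credential values are placeholder strings; the test only verifies
+    # the type of the returned call credentials.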
+ creds = oauth2client_client.GoogleCredentials(
+ 'token', 'client_id', 'secret', 'refresh_token',
+ datetime.datetime(2008, 6, 24), 'https://refresh.uri.com/',
+ 'user_agent')
+ call_creds = implementations.google_call_credentials(creds)
+ self.assertIsInstance(call_creds, implementations.CallCredentials)
+
+ def test_access_token_call_credentials(self):
+ call_creds = implementations.access_token_call_credentials('token')
+ self.assertIsInstance(call_creds, implementations.CallCredentials)
+
if __name__ == '__main__':
unittest.main(verbosity=2)
diff --git a/src/python/grpcio/tests/unit/beta/_not_found_test.py b/src/python/grpcio/tests/unit/beta/_not_found_test.py
index 44fcd1e13c..37b8c49120 100644
--- a/src/python/grpcio/tests/unit/beta/_not_found_test.py
+++ b/src/python/grpcio/tests/unit/beta/_not_found_test.py
@@ -61,7 +61,7 @@ class NotFoundTest(unittest.TestCase):
def test_future_stream_unary_not_found(self):
rpc_future = self._generic_stub.future_stream_unary(
- 'grupe', 'mevvod', b'def', test_constants.LONG_TIMEOUT)
+ 'grupe', 'mevvod', [b'def'], test_constants.LONG_TIMEOUT)
with self.assertRaises(face.LocalError) as exception_assertion_context:
rpc_future.result()
self.assertIs(
diff --git a/src/python/grpcio/tests/unit/beta/test_utilities.py b/src/python/grpcio/tests/unit/beta/test_utilities.py
index 0313e06a93..e374b203ce 100644
--- a/src/python/grpcio/tests/unit/beta/test_utilities.py
+++ b/src/python/grpcio/tests/unit/beta/test_utilities.py
@@ -29,7 +29,7 @@
"""Test-appropriate entry points into the gRPC Python Beta API."""
-from grpc._adapter import _intermediary_low
+import grpc
from grpc.beta import implementations
@@ -48,9 +48,8 @@ def not_really_secure_channel(
An implementations.Channel to the remote host through which RPCs may be
conducted.
"""
- hostport = '%s:%d' % (host, port)
- intermediary_low_channel = _intermediary_low.Channel(
- hostport, channel_credentials._low_credentials,
- server_host_override=server_host_override)
- return implementations.Channel(
- intermediary_low_channel._internal, intermediary_low_channel)
+ target = '%s:%d' % (host, port)
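+  # Overriding the SSL target name lets the test certificates validate even
+  # though the channel actually connects to a local address.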
+ channel = grpc.secure_channel(
+ target, channel_credentials._credentials,
+ ((b'grpc.ssl_target_name_override', server_host_override,),))
+ return implementations.Channel(channel)
diff --git a/src/python/grpcio/tests/unit/framework/common/test_control.py b/src/python/grpcio/tests/unit/framework/common/test_control.py
index ca5ba3a854..088e2f8b88 100644
--- a/src/python/grpcio/tests/unit/framework/common/test_control.py
+++ b/src/python/grpcio/tests/unit/framework/common/test_control.py
@@ -60,10 +60,16 @@ class Control(six.with_metaclass(abc.ABCMeta)):
class PauseFailControl(Control):
- """A Control that can be used to pause or fail code under control."""
+ """A Control that can be used to pause or fail code under control.
+
+  This object is only safe for use from two threads: one belonging to the
+  system under test, which calls control, and the other belonging to the test
+  system, which calls pause, block_until_paused, and fail.
+ """
def __init__(self):
self._condition = threading.Condition()
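+    # _pause records a pause request from the controlling code; _paused is the
+    # acknowledgement set by the code under control (see control()).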
+ self._pause = False
self._paused = False
self._fail = False
@@ -72,19 +78,31 @@ class PauseFailControl(Control):
if self._fail:
raise Defect()
- while self._paused:
+ while self._pause:
+ self._paused = True
+ self._condition.notify_all()
self._condition.wait()
+ self._paused = False
@contextlib.contextmanager
def pause(self):
"""Pauses code under control while controlling code is in context."""
with self._condition:
- self._paused = True
+ self._pause = True
yield
with self._condition:
- self._paused = False
+ self._pause = False
self._condition.notify_all()
+ def block_until_paused(self):
+ """Blocks controlling code until code under control is paused.
+
+ May only be called within the context of a pause call.
+ """
+ with self._condition:
+ while not self._paused:
+ self._condition.wait()
+
@contextlib.contextmanager
def fail(self):
"""Fails code under control while controlling code is in context."""
diff --git a/src/python/grpcio/tests/unit/framework/interfaces/base/test_cases.py b/src/python/grpcio/tests/unit/framework/interfaces/base/test_cases.py
index 4f8e26c9a2..5d16bf98be 100644
--- a/src/python/grpcio/tests/unit/framework/interfaces/base/test_cases.py
+++ b/src/python/grpcio/tests/unit/framework/interfaces/base/test_cases.py
@@ -29,6 +29,8 @@
"""Tests of the base interface of RPC Framework."""
+from __future__ import division
+
import logging
import random
import threading
@@ -54,13 +56,13 @@ class _Serialization(test_interfaces.Serialization):
return request + request
def deserialize_request(self, serialized_request):
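+    # Floor division keeps the slice index an int now that true division is
+    # enabled via the __future__ import above.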
- return serialized_request[:len(serialized_request) / 2]
+ return serialized_request[:len(serialized_request) // 2]
def serialize_response(self, response):
return response * 3
def deserialize_response(self, serialized_response):
- return serialized_response[2 * len(serialized_response) / 3:]
+ return serialized_response[2 * len(serialized_response) // 3:]
def _advance(quadruples, operator, controller):
diff --git a/src/python/grpcio/tests/unit/test_common.py b/src/python/grpcio/tests/unit/test_common.py
index 7b4d20dccb..b779f65e7e 100644
--- a/src/python/grpcio/tests/unit/test_common.py
+++ b/src/python/grpcio/tests/unit/test_common.py
@@ -61,6 +61,10 @@ def metadata_transmitted(original_metadata, transmitted_metadata):
original = collections.defaultdict(list)
for key_value_pair in original_metadata:
key, value = tuple(key_value_pair)
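+    # Normalize keys and values to bytes so str- and bytes-valued metadata
+    # compare equal however the test supplied them.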
+ if not isinstance(key, bytes):
+ key = key.encode()
+ if not isinstance(value, bytes):
+ value = value.encode()
original[key].append(value)
transmitted = collections.defaultdict(list)
for key_value_pair in transmitted_metadata: