diff options
author | Sree Kuchibhotla <sreek@google.com> | 2016-06-06 10:31:45 -0700 |
---|---|---|
committer | Sree Kuchibhotla <sreek@google.com> | 2016-06-06 10:31:45 -0700 |
commit | 1fd6235791370e511093941e23432b446328f3e1 (patch) | |
tree | 27cae5ec3d74d36f5b0c8e8298fc97849a347579 /src/python | |
parent | 9bc3d2d67f32f4dad8cd1319dd4f3fce48c1abee (diff) | |
parent | 37e26d922fc443b67df50abd8b11288eba2920cd (diff) |
Merge branch 'master' into epoll_changes
Diffstat (limited to 'src/python')
26 files changed, 4615 insertions, 51 deletions
diff --git a/src/python/grpcio/grpc/__init__.py b/src/python/grpcio/grpc/__init__.py index b844a14c48..bbf04ad03e 100644 --- a/src/python/grpcio/grpc/__init__.py +++ b/src/python/grpcio/grpc/__init__.py @@ -27,5 +27,1058 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +"""gRPC's Python API.""" + __import__('pkg_resources').declare_namespace(__name__) +import abc +import enum + +import six + +from grpc._cython import cygrpc as _cygrpc + + +############################## Future Interface ############################### + + +class FutureTimeoutError(Exception): + """Indicates that a method call on a Future timed out.""" + + +class FutureCancelledError(Exception): + """Indicates that the computation underlying a Future was cancelled.""" + + +class Future(six.with_metaclass(abc.ABCMeta)): + """A representation of a computation in another control flow. + + Computations represented by a Future may be yet to be begun, may be ongoing, + or may have already completed. + """ + + @abc.abstractmethod + def cancel(self): + """Attempts to cancel the computation. + + This method does not block. + + Returns: + True if the computation has not yet begun, will not be allowed to take + place, and determination of both was possible without blocking. False + under all other circumstances including but not limited to the + computation's already having begun, the computation's already having + finished, and the computation's having been scheduled for execution on a + remote system for which a determination of whether or not it commenced + before being cancelled cannot be made without blocking. + """ + raise NotImplementedError() + + @abc.abstractmethod + def cancelled(self): + """Describes whether the computation was cancelled. + + This method does not block. + + Returns: + True if the computation was cancelled any time before its result became + immediately available. False under all other circumstances including but + not limited to this object's cancel method not having been called and + the computation's result having become immediately available. + """ + raise NotImplementedError() + + @abc.abstractmethod + def running(self): + """Describes whether the computation is taking place. + + This method does not block. + + Returns: + True if the computation is scheduled to take place in the future or is + taking place now, or False if the computation took place in the past or + was cancelled. + """ + raise NotImplementedError() + + @abc.abstractmethod + def done(self): + """Describes whether the computation has taken place. + + This method does not block. + + Returns: + True if the computation is known to have either completed or have been + unscheduled or interrupted. False if the computation may possibly be + executing or scheduled to execute later. + """ + raise NotImplementedError() + + @abc.abstractmethod + def result(self, timeout=None): + """Accesses the outcome of the computation or raises its exception. + + This method may return immediately or may block. + + Args: + timeout: The length of time in seconds to wait for the computation to + finish or be cancelled, or None if this method should block until the + computation has finished or is cancelled no matter how long that takes. + + Returns: + The return value of the computation. + + Raises: + FutureTimeoutError: If a timeout value is passed and the computation does + not terminate within the allotted time. + FutureCancelledError: If the computation was cancelled. + Exception: If the computation raised an exception, this call will raise + the same exception. + """ + raise NotImplementedError() + + @abc.abstractmethod + def exception(self, timeout=None): + """Return the exception raised by the computation. + + This method may return immediately or may block. + + Args: + timeout: The length of time in seconds to wait for the computation to + terminate or be cancelled, or None if this method should block until + the computation is terminated or is cancelled no matter how long that + takes. + + Returns: + The exception raised by the computation, or None if the computation did + not raise an exception. + + Raises: + FutureTimeoutError: If a timeout value is passed and the computation does + not terminate within the allotted time. + FutureCancelledError: If the computation was cancelled. + """ + raise NotImplementedError() + + @abc.abstractmethod + def traceback(self, timeout=None): + """Access the traceback of the exception raised by the computation. + + This method may return immediately or may block. + + Args: + timeout: The length of time in seconds to wait for the computation to + terminate or be cancelled, or None if this method should block until + the computation is terminated or is cancelled no matter how long that + takes. + + Returns: + The traceback of the exception raised by the computation, or None if the + computation did not raise an exception. + + Raises: + FutureTimeoutError: If a timeout value is passed and the computation does + not terminate within the allotted time. + FutureCancelledError: If the computation was cancelled. + """ + raise NotImplementedError() + + @abc.abstractmethod + def add_done_callback(self, fn): + """Adds a function to be called at completion of the computation. + + The callback will be passed this Future object describing the outcome of + the computation. + + If the computation has already completed, the callback will be called + immediately. + + Args: + fn: A callable taking this Future object as its single parameter. + """ + raise NotImplementedError() + + +################################ gRPC Enums ################################## + + +@enum.unique +class ChannelConnectivity(enum.Enum): + """Mirrors grpc_connectivity_state in the gRPC Core. + + Attributes: + IDLE: The channel is idle. + CONNECTING: The channel is connecting. + READY: The channel is ready to conduct RPCs. + TRANSIENT_FAILURE: The channel has seen a failure from which it expects to + recover. + FATAL_FAILURE: The channel has seen a failure from which it cannot recover. + """ + IDLE = (_cygrpc.ConnectivityState.idle, 'idle') + CONNECTING = (_cygrpc.ConnectivityState.connecting, 'connecting') + READY = (_cygrpc.ConnectivityState.ready, 'ready') + TRANSIENT_FAILURE = ( + _cygrpc.ConnectivityState.transient_failure, 'transient failure') + FATAL_FAILURE = (_cygrpc.ConnectivityState.fatal_failure, 'fatal failure') + + +@enum.unique +class StatusCode(enum.Enum): + """Mirrors grpc_status_code in the gRPC Core.""" + OK = (_cygrpc.StatusCode.ok, 'ok') + CANCELLED = (_cygrpc.StatusCode.cancelled, 'cancelled') + UNKNOWN = (_cygrpc.StatusCode.unknown, 'unknown') + INVALID_ARGUMENT = ( + _cygrpc.StatusCode.invalid_argument, 'invalid argument') + DEADLINE_EXCEEDED = ( + _cygrpc.StatusCode.deadline_exceeded, 'deadline exceeded') + NOT_FOUND = (_cygrpc.StatusCode.not_found, 'not found') + ALREADY_EXISTS = (_cygrpc.StatusCode.already_exists, 'already exists') + PERMISSION_DENIED = ( + _cygrpc.StatusCode.permission_denied, 'permission denied') + RESOURCE_EXHAUSTED = ( + _cygrpc.StatusCode.resource_exhausted, 'resource exhausted') + FAILED_PRECONDITION = ( + _cygrpc.StatusCode.failed_precondition, 'failed precondition') + ABORTED = (_cygrpc.StatusCode.aborted, 'aborted') + OUT_OF_RANGE = (_cygrpc.StatusCode.out_of_range, 'out of range') + UNIMPLEMENTED = (_cygrpc.StatusCode.unimplemented, 'unimplemented') + INTERNAL = (_cygrpc.StatusCode.internal, 'internal') + UNAVAILABLE = (_cygrpc.StatusCode.unavailable, 'unavailable') + DATA_LOSS = (_cygrpc.StatusCode.data_loss, 'data loss') + UNAUTHENTICATED = (_cygrpc.StatusCode.unauthenticated, 'unauthenticated') + + +############################# gRPC Exceptions ################################ + + +class RpcError(Exception): + """Raised by the gRPC library to indicate non-OK-status RPC termination.""" + + +############################## Shared Context ################################ + + +class RpcContext(six.with_metaclass(abc.ABCMeta)): + """Provides RPC-related information and action.""" + + @abc.abstractmethod + def is_active(self): + """Describes whether the RPC is active or has terminated.""" + raise NotImplementedError() + + @abc.abstractmethod + def time_remaining(self): + """Describes the length of allowed time remaining for the RPC. + + Returns: + A nonnegative float indicating the length of allowed time in seconds + remaining for the RPC to complete before it is considered to have timed + out, or None if no deadline was specified for the RPC. + """ + raise NotImplementedError() + + @abc.abstractmethod + def cancel(self): + """Cancels the RPC. + + Idempotent and has no effect if the RPC has already terminated. + """ + raise NotImplementedError() + + @abc.abstractmethod + def add_callback(self, callback): + """Registers a callback to be called on RPC termination. + + Args: + callback: A no-parameter callable to be called on RPC termination. + + Returns: + True if the callback was added and will be called later; False if the + callback was not added and will not later be called (because the RPC + already terminated or some other reason). + """ + raise NotImplementedError() + + +######################### Invocation-Side Context ############################ + + +class Call(six.with_metaclass(abc.ABCMeta, RpcContext)): + """Invocation-side utility object for an RPC.""" + + @abc.abstractmethod + def initial_metadata(self): + """Accesses the initial metadata from the service-side of the RPC. + + This method blocks until the value is available. + + Returns: + The initial metadata as a sequence of pairs of bytes. + """ + raise NotImplementedError() + + @abc.abstractmethod + def trailing_metadata(self): + """Accesses the trailing metadata from the service-side of the RPC. + + This method blocks until the value is available. + + Returns: + The trailing metadata as a sequence of pairs of bytes. + """ + raise NotImplementedError() + + @abc.abstractmethod + def code(self): + """Accesses the status code emitted by the service-side of the RPC. + + This method blocks until the value is available. + + Returns: + The StatusCode value for the RPC. + """ + raise NotImplementedError() + + @abc.abstractmethod + def details(self): + """Accesses the details value emitted by the service-side of the RPC. + + This method blocks until the value is available. + + Returns: + The bytes of the details of the RPC. + """ + raise NotImplementedError() + + +############ Authentication & Authorization Interfaces & Classes ############# + + +class ChannelCredentials(object): + """A value encapsulating the data required to create a secure Channel. + + This class has no supported interface - it exists to define the type of its + instances and its instances exist to be passed to other functions. + """ + + def __init__(self, credentials): + self._credentials = credentials + + +class CallCredentials(object): + """A value encapsulating data asserting an identity over a channel. + + A CallCredentials may be composed with ChannelCredentials to always assert + identity for every call over that Channel. + + This class has no supported interface - it exists to define the type of its + instances and its instances exist to be passed to other functions. + """ + + def __init__(self, credentials): + self._credentials = credentials + + +class AuthMetadataContext(six.with_metaclass(abc.ABCMeta)): + """Provides information to call credentials metadata plugins. + + Attributes: + service_url: A string URL of the service being called into. + method_name: A string of the fully qualified method name being called. + """ + + +class AuthMetadataPluginCallback(six.with_metaclass(abc.ABCMeta)): + """Callback object received by a metadata plugin.""" + + def __call__(self, metadata, error): + """Inform the gRPC runtime of the metadata to construct a CallCredentials. + + Args: + metadata: An iterable of 2-sequences (e.g. tuples) of metadata key/value + pairs. + error: An Exception to indicate error or None to indicate success. + """ + raise NotImplementedError() + + +class AuthMetadataPlugin(six.with_metaclass(abc.ABCMeta)): + """A specification for custom authentication.""" + + def __call__(self, context, callback): + """Implements authentication by passing metadata to a callback. + + Implementations of this method must not block. + + Args: + context: An AuthMetadataContext providing information on the RPC that the + plugin is being called to authenticate. + callback: An AuthMetadataPluginCallback to be invoked either synchronously + or asynchronously. + """ + raise NotImplementedError() + + +class ServerCredentials(object): + """A value encapsulating the data required to open a secure port on a Server. + + This class has no supported interface - it exists to define the type of its + instances and its instances exist to be passed to other functions. + """ + + def __init__(self, credentials): + self._credentials = credentials + + +######################## Multi-Callable Interfaces ########################### + + +class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): + """Affords invoking a unary-unary RPC.""" + + @abc.abstractmethod + def __call__( + self, request, timeout=None, metadata=None, credentials=None, + with_call=False): + """Synchronously invokes the underlying RPC. + + Args: + request: The request value for the RPC. + timeout: An optional duration of time in seconds to allow for the RPC. + metadata: An optional sequence of pairs of bytes to be transmitted to the + service-side of the RPC. + credentials: An optional CallCredentials for the RPC. + with_call: Whether or not to include return a Call for the RPC in addition + to the response. + + Returns: + The response value for the RPC, and a Call for the RPC if with_call was + set to True at invocation. + + Raises: + RpcError: Indicating that the RPC terminated with non-OK status. The + raised RpcError will also be a Call for the RPC affording the RPC's + metadata, status code, and details. + """ + raise NotImplementedError() + + @abc.abstractmethod + def future(self, request, timeout=None, metadata=None, credentials=None): + """Asynchronously invokes the underlying RPC. + + Args: + request: The request value for the RPC. + timeout: An optional duration of time in seconds to allow for the RPC. + metadata: An optional sequence of pairs of bytes to be transmitted to the + service-side of the RPC. + credentials: An optional CallCredentials for the RPC. + + Returns: + An object that is both a Call for the RPC and a Future. In the event of + RPC completion, the return Future's result value will be the response + message of the RPC. Should the event terminate with non-OK status, the + returned Future's exception value will be an RpcError. + """ + raise NotImplementedError() + + +class UnaryStreamMultiCallable(six.with_metaclass(abc.ABCMeta)): + """Affords invoking a unary-stream RPC.""" + + @abc.abstractmethod + def __call__(self, request, timeout=None, metadata=None, credentials=None): + """Invokes the underlying RPC. + + Args: + request: The request value for the RPC. + timeout: An optional duration of time in seconds to allow for the RPC. + metadata: An optional sequence of pairs of bytes to be transmitted to the + service-side of the RPC. + credentials: An optional CallCredentials for the RPC. + + Returns: + An object that is both a Call for the RPC and an iterator of response + values. Drawing response values from the returned iterator may raise + RpcError indicating termination of the RPC with non-OK status. + """ + raise NotImplementedError() + + +class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): + """Affords invoking a stream-unary RPC in any call style.""" + + @abc.abstractmethod + def __call__( + self, request_iterator, timeout=None, metadata=None, credentials=None, + with_call=False): + """Synchronously invokes the underlying RPC. + + Args: + request_iterator: An iterator that yields request values for the RPC. + timeout: An optional duration of time in seconds to allow for the RPC. + metadata: An optional sequence of pairs of bytes to be transmitted to the + service-side of the RPC. + credentials: An optional CallCredentials for the RPC. + with_call: Whether or not to include return a Call for the RPC in addition + to the response. + + Returns: + The response value for the RPC, and a Call for the RPC if with_call was + set to True at invocation. + + Raises: + RpcError: Indicating that the RPC terminated with non-OK status. The + raised RpcError will also be a Call for the RPC affording the RPC's + metadata, status code, and details. + """ + raise NotImplementedError() + + @abc.abstractmethod + def future( + self, request_iterator, timeout=None, metadata=None, credentials=None): + """Asynchronously invokes the underlying RPC. + + Args: + request_iterator: An iterator that yields request values for the RPC. + timeout: An optional duration of time in seconds to allow for the RPC. + metadata: An optional sequence of pairs of bytes to be transmitted to the + service-side of the RPC. + credentials: An optional CallCredentials for the RPC. + + Returns: + An object that is both a Call for the RPC and a Future. In the event of + RPC completion, the return Future's result value will be the response + message of the RPC. Should the event terminate with non-OK status, the + returned Future's exception value will be an RpcError. + """ + raise NotImplementedError() + + +class StreamStreamMultiCallable(six.with_metaclass(abc.ABCMeta)): + """Affords invoking a stream-stream RPC in any call style.""" + + @abc.abstractmethod + def __call__( + self, request_iterator, timeout=None, metadata=None, credentials=None): + """Invokes the underlying RPC. + + Args: + request_iterator: An iterator that yields request values for the RPC. + timeout: An optional duration of time in seconds to allow for the RPC. + metadata: An optional sequence of pairs of bytes to be transmitted to the + service-side of the RPC. + credentials: An optional CallCredentials for the RPC. + + Returns: + An object that is both a Call for the RPC and an iterator of response + values. Drawing response values from the returned iterator may raise + RpcError indicating termination of the RPC with non-OK status. + """ + raise NotImplementedError() + + +############################# Channel Interface ############################## + + +class Channel(six.with_metaclass(abc.ABCMeta)): + """Affords RPC invocation via generic methods.""" + + @abc.abstractmethod + def subscribe(self, callback, try_to_connect=False): + """Subscribes to this Channel's connectivity. + + Args: + callback: A callable to be invoked and passed a ChannelConnectivity value + describing this Channel's connectivity. The callable will be invoked + immediately upon subscription and again for every change to this + Channel's connectivity thereafter until it is unsubscribed or this + Channel object goes out of scope. + try_to_connect: A boolean indicating whether or not this Channel should + attempt to connect if it is not already connected and ready to conduct + RPCs. + """ + raise NotImplementedError() + + @abc.abstractmethod + def unsubscribe(self, callback): + """Unsubscribes a callback from this Channel's connectivity. + + Args: + callback: A callable previously registered with this Channel from having + been passed to its "subscribe" method. + """ + raise NotImplementedError() + + @abc.abstractmethod + def unary_unary( + self, method, request_serializer=None, response_deserializer=None): + """Creates a UnaryUnaryMultiCallable for a unary-unary method. + + Args: + method: The name of the RPC method. + + Returns: + A UnaryUnaryMultiCallable value for the named unary-unary method. + """ + raise NotImplementedError() + + @abc.abstractmethod + def unary_stream( + self, method, request_serializer=None, response_deserializer=None): + """Creates a UnaryStreamMultiCallable for a unary-stream method. + + Args: + method: The name of the RPC method. + + Returns: + A UnaryStreamMultiCallable value for the name unary-stream method. + """ + raise NotImplementedError() + + @abc.abstractmethod + def stream_unary( + self, method, request_serializer=None, response_deserializer=None): + """Creates a StreamUnaryMultiCallable for a stream-unary method. + + Args: + method: The name of the RPC method. + + Returns: + A StreamUnaryMultiCallable value for the named stream-unary method. + """ + raise NotImplementedError() + + @abc.abstractmethod + def stream_stream( + self, method, request_serializer=None, response_deserializer=None): + """Creates a StreamStreamMultiCallable for a stream-stream method. + + Args: + method: The name of the RPC method. + + Returns: + A StreamStreamMultiCallable value for the named stream-stream method. + """ + raise NotImplementedError() + + +########################## Service-Side Context ############################## + + +class ServicerContext(six.with_metaclass(abc.ABCMeta, RpcContext)): + """A context object passed to method implementations.""" + + @abc.abstractmethod + def invocation_metadata(self): + """Accesses the metadata from the invocation-side of the RPC. + + Returns: + The invocation metadata object as a sequence of pairs of bytes. + """ + raise NotImplementedError() + + @abc.abstractmethod + def peer(self): + """Identifies the peer that invoked the RPC being serviced. + + Returns: + A string identifying the peer that invoked the RPC being serviced. + """ + raise NotImplementedError() + + @abc.abstractmethod + def send_initial_metadata(self, initial_metadata): + """Sends the initial metadata value to the invocation-side of the RPC. + + This method need not be called by method implementations if they have no + service-side initial metadata to transmit. + + Args: + initial_metadata: The initial metadata of the RPC as a sequence of pairs + of bytes. + """ + raise NotImplementedError() + + @abc.abstractmethod + def set_trailing_metadata(self, trailing_metadata): + """Accepts the trailing metadata value of the RPC. + + This method need not be called by method implementations if they have no + service-side trailing metadata to transmit. + + Args: + trailing_metadata: The trailing metadata of the RPC as a sequence of pairs + of bytes. + """ + raise NotImplementedError() + + @abc.abstractmethod + def set_code(self, code): + """Accepts the status code of the RPC. + + This method need not be called by method implementations if they wish the + gRPC runtime to determine the status code of the RPC. + + Args: + code: The integer status code of the RPC to be transmitted to the + invocation side of the RPC. + """ + raise NotImplementedError() + + @abc.abstractmethod + def set_details(self, details): + """Accepts the service-side details of the RPC. + + This method need not be called by method implementations if they have no + details to transmit. + + Args: + details: The details bytes of the RPC to be transmitted to + the invocation side of the RPC. + """ + raise NotImplementedError() + + +##################### Service-Side Handler Interfaces ######################## + + +class RpcMethodHandler(six.with_metaclass(abc.ABCMeta)): + """An implementation of a single RPC method. + + Attributes: + request_streaming: Whether the RPC supports exactly one request message or + any arbitrary number of request messages. + response_streaming: Whether the RPC supports exactly one response message or + any arbitrary number of response messages. + request_deserializer: A callable behavior that accepts a byte string and + returns an object suitable to be passed to this object's business logic, + or None to indicate that this object's business logic should be passed the + raw request bytes. + response_serializer: A callable behavior that accepts an object produced by + this object's business logic and returns a byte string, or None to + indicate that the byte strings produced by this object's business logic + should be transmitted on the wire as they are. + unary_unary: This object's application-specific business logic as a callable + value that takes a request value and a ServicerContext object and returns + a response value. Only non-None if both request_streaming and + response_streaming are False. + unary_stream: This object's application-specific business logic as a + callable value that takes a request value and a ServicerContext object and + returns an iterator of response values. Only non-None if request_streaming + is False and response_streaming is True. + stream_unary: This object's application-specific business logic as a + callable value that takes an iterator of request values and a + ServicerContext object and returns a response value. Only non-None if + request_streaming is True and response_streaming is False. + stream_stream: This object's application-specific business logic as a + callable value that takes an iterator of request values and a + ServicerContext object and returns an iterator of response values. Only + non-None if request_streaming and response_streaming are both True. + """ + + +class HandlerCallDetails(six.with_metaclass(abc.ABCMeta)): + """Describes an RPC that has just arrived for service. + Attributes: + method: The method name of the RPC. + invocation_metadata: The metadata from the invocation side of the RPC. + """ + + +class GenericRpcHandler(six.with_metaclass(abc.ABCMeta)): + """An implementation of arbitrarily many RPC methods.""" + + @abc.abstractmethod + def service(self, handler_call_details): + """Services an RPC (or not). + + Args: + handler_call_details: A HandlerCallDetails describing the RPC. + + Returns: + An RpcMethodHandler with which the RPC may be serviced, or None to + indicate that this object will not be servicing the RPC. + """ + raise NotImplementedError() + + +############################# Server Interface ############################### + + +class Server(six.with_metaclass(abc.ABCMeta)): + """Services RPCs.""" + + @abc.abstractmethod + def add_generic_rpc_handlers(self, generic_rpc_handlers): + """Registers GenericRpcHandlers with this Server. + + This method is only safe to call before the server is started. + + Args: + generic_rpc_handlers: An iterable of GenericRpcHandlers that will be used + to service RPCs after this Server is started. + """ + raise NotImplementedError() + + @abc.abstractmethod + def add_insecure_port(self, address): + """Reserves a port for insecure RPC service once this Server becomes active. + + This method may only be called before calling this Server's start method is + called. + + Args: + address: The address for which to open a port. + + Returns: + An integer port on which RPCs will be serviced after this link has been + started. This is typically the same number as the port number contained + in the passed address, but will likely be different if the port number + contained in the passed address was zero. + """ + raise NotImplementedError() + + @abc.abstractmethod + def add_secure_port(self, address, server_credentials): + """Reserves a port for secure RPC service after this Server becomes active. + + This method may only be called before calling this Server's start method is + called. + + Args: + address: The address for which to open a port. + server_credentials: A ServerCredentials. + + Returns: + An integer port on which RPCs will be serviced after this link has been + started. This is typically the same number as the port number contained + in the passed address, but will likely be different if the port number + contained in the passed address was zero. + """ + raise NotImplementedError() + + @abc.abstractmethod + def start(self): + """Starts this Server's service of RPCs. + + This method may only be called while the server is not serving RPCs (i.e. it + is not idempotent). + """ + raise NotImplementedError() + + @abc.abstractmethod + def stop(self, grace): + """Stops this Server's service of RPCs. + + All calls to this method immediately stop service of new RPCs. When existing + RPCs are aborted is controlled by the grace period parameter passed to this + method. + + This method may be called at any time and is idempotent. Passing a smaller + grace value than has been passed in a previous call will have the effect of + stopping the Server sooner. Passing a larger grace value than has been + passed in a previous call will not have the effect of stopping the server + later. + + Args: + grace: A duration of time in seconds to allow existing RPCs to complete + before being aborted by this Server's stopping. If None, this method + will block until the server is completely stopped. + + Returns: + A threading.Event that will be set when this Server has completely + stopped. The returned event may not be set until after the full grace + period (if some ongoing RPC continues for the full length of the period) + of it may be set much sooner (such as if this Server had no RPCs underway + at the time it was stopped or if all RPCs that it had underway completed + very early in the grace period). + """ + raise NotImplementedError() + + +################################# Functions ################################ + + +def ssl_channel_credentials( + root_certificates=None, private_key=None, certificate_chain=None): + """Creates a ChannelCredentials for use with an SSL-enabled Channel. + + Args: + root_certificates: The PEM-encoded root certificates or unset to ask for + them to be retrieved from a default location. + private_key: The PEM-encoded private key to use or unset if no private key + should be used. + certificate_chain: The PEM-encoded certificate chain to use or unset if no + certificate chain should be used. + + Returns: + A ChannelCredentials for use with an SSL-enabled Channel. + """ + if private_key is not None or certificate_chain is not None: + pair = _cygrpc.SslPemKeyCertPair(private_key, certificate_chain) + else: + pair = None + return ChannelCredentials( + _cygrpc.channel_credentials_ssl(root_certificates, pair)) + + +def metadata_call_credentials(metadata_plugin, name=None): + """Construct CallCredentials from an AuthMetadataPlugin. + + Args: + metadata_plugin: An AuthMetadataPlugin to use as the authentication behavior + in the created CallCredentials. + name: A name for the plugin. + + Returns: + A CallCredentials. + """ + from grpc import _plugin_wrapping + if name is None: + try: + effective_name = metadata_plugin.__name__ + except AttributeError: + effective_name = metadata_plugin.__class__.__name__ + else: + effective_name = name + return CallCredentials( + _plugin_wrapping.call_credentials_metadata_plugin( + metadata_plugin, effective_name)) + + +def composite_call_credentials(call_credentials, additional_call_credentials): + """Compose two CallCredentials to make a new one. + + Args: + call_credentials: A CallCredentials object. + additional_call_credentials: Another CallCredentials object to compose on + top of call_credentials. + + Returns: + A new CallCredentials composed of the two given CallCredentials. + """ + return CallCredentials( + _cygrpc.call_credentials_composite( + call_credentials._credentials, + additional_call_credentials._credentials)) + + +def composite_channel_credentials(channel_credentials, call_credentials): + """Compose a ChannelCredentials and a CallCredentials. + + Args: + channel_credentials: A ChannelCredentials. + call_credentials: A CallCredentials. + + Returns: + A ChannelCredentials composed of the given ChannelCredentials and + CallCredentials. + """ + return ChannelCredentials( + _cygrpc.channel_credentials_composite( + channel_credentials._credentials, call_credentials._credentials)) + + +def ssl_server_credentials( + private_key_certificate_chain_pairs, root_certificates=None, + require_client_auth=False): + """Creates a ServerCredentials for use with an SSL-enabled Server. + + Args: + private_key_certificate_chain_pairs: A nonempty sequence each element of + which is a pair the first element of which is a PEM-encoded private key + and the second element of which is the corresponding PEM-encoded + certificate chain. + root_certificates: PEM-encoded client root certificates to be used for + verifying authenticated clients. If omitted, require_client_auth must also + be omitted or be False. + require_client_auth: A boolean indicating whether or not to require clients + to be authenticated. May only be True if root_certificates is not None. + + Returns: + A ServerCredentials for use with an SSL-enabled Server. + """ + if len(private_key_certificate_chain_pairs) == 0: + raise ValueError( + 'At least one private key-certificate chain pair is required!') + elif require_client_auth and root_certificates is None: + raise ValueError( + 'Illegal to require client auth without providing root certificates!') + else: + return ServerCredentials( + _cygrpc.server_credentials_ssl( + root_certificates, + [_cygrpc.SslPemKeyCertPair(key, pem) + for key, pem in private_key_certificate_chain_pairs], + require_client_auth)) + + +def channel_ready_future(channel): + """Creates a Future tracking when a Channel is ready. + + Cancelling the returned Future does not tell the given Channel to abandon + attempts it may have been making to connect; cancelling merely deactivates the + returned Future's subscription to the given Channel's connectivity. + + Args: + channel: A Channel. + + Returns: + A Future that matures when the given Channel has connectivity + ChannelConnectivity.READY. + """ + from grpc import _utilities + return _utilities.channel_ready_future(channel) + + +def insecure_channel(target, options=None): + """Creates an insecure Channel to a server. + + Args: + target: The target to which to connect. + options: A sequence of string-value pairs according to which to configure + the created channel. + + Returns: + A Channel to the target through which RPCs may be conducted. + """ + from grpc import _channel + return _channel.Channel(target, None, options) + + +def secure_channel(target, credentials, options=None): + """Creates an insecure Channel to a server. + + Args: + target: The target to which to connect. + credentials: A ChannelCredentials instance. + options: A sequence of string-value pairs according to which to configure + the created channel. + + Returns: + A Channel to the target through which RPCs may be conducted. + """ + from grpc import _channel + return _channel.Channel(target, credentials, options) + + +def server(generic_rpc_handlers, thread_pool, options=None): + """Creates a Server with which RPCs can be serviced. + + The GenericRpcHandlers passed to this function needn't be the only + GenericRpcHandlers that will be used to serve RPCs; others may be added later + by calling add_generic_rpc_handlers any time before the returned server is + started. + + Args: + generic_rpc_handlers: Some number of GenericRpcHandlers that will be used + to service RPCs after the returned Server is started. + thread_pool: A futures.ThreadPoolExecutor to be used by the returned Server + to service RPCs. + + Returns: + A Server with which RPCs can be serviced. + """ + from grpc import _server + return _server.Server(generic_rpc_handlers, thread_pool) diff --git a/src/python/grpcio/grpc/_channel.py b/src/python/grpcio/grpc/_channel.py new file mode 100644 index 0000000000..d9eb5a4b77 --- /dev/null +++ b/src/python/grpcio/grpc/_channel.py @@ -0,0 +1,852 @@ +# Copyright 2016, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Invocation-side implementation of gRPC Python.""" + +import sys +import threading +import time + +import grpc +from grpc import _common +from grpc import _grpcio_metadata +from grpc.framework.foundation import callable_util +from grpc._cython import cygrpc + +_USER_AGENT = 'Python-gRPC-{}'.format(_grpcio_metadata.__version__) + +_EMPTY_FLAGS = 0 +_INFINITE_FUTURE = cygrpc.Timespec(float('+inf')) +_EMPTY_METADATA = cygrpc.Metadata(()) + +_UNARY_UNARY_INITIAL_DUE = ( + cygrpc.OperationType.send_initial_metadata, + cygrpc.OperationType.send_message, + cygrpc.OperationType.send_close_from_client, + cygrpc.OperationType.receive_initial_metadata, + cygrpc.OperationType.receive_message, + cygrpc.OperationType.receive_status_on_client, +) +_UNARY_STREAM_INITIAL_DUE = ( + cygrpc.OperationType.send_initial_metadata, + cygrpc.OperationType.send_message, + cygrpc.OperationType.send_close_from_client, + cygrpc.OperationType.receive_initial_metadata, + cygrpc.OperationType.receive_status_on_client, +) +_STREAM_UNARY_INITIAL_DUE = ( + cygrpc.OperationType.send_initial_metadata, + cygrpc.OperationType.receive_initial_metadata, + cygrpc.OperationType.receive_message, + cygrpc.OperationType.receive_status_on_client, +) +_STREAM_STREAM_INITIAL_DUE = ( + cygrpc.OperationType.send_initial_metadata, + cygrpc.OperationType.receive_initial_metadata, + cygrpc.OperationType.receive_status_on_client, +) + +_CHANNEL_SUBSCRIPTION_CALLBACK_ERROR_LOG_MESSAGE = ( + 'Exception calling channel subscription callback!') + + +def _deadline(timeout): + if timeout is None: + return None, _INFINITE_FUTURE + else: + deadline = time.time() + timeout + return deadline, cygrpc.Timespec(deadline) + + +def _unknown_code_details(unknown_cygrpc_code, details): + return b'Server sent unknown code {} and details "{}"'.format( + unknown_cygrpc_code, details) + + +def _wait_once_until(condition, until): + if until is None: + condition.wait() + else: + remaining = until - time.time() + if remaining < 0: + raise grpc.FutureTimeoutError() + else: + condition.wait(timeout=remaining) + + +class _RPCState(object): + + def __init__(self, due, initial_metadata, trailing_metadata, code, details): + self.condition = threading.Condition() + # The cygrpc.OperationType objects representing events due from the RPC's + # completion queue. + self.due = set(due) + self.initial_metadata = initial_metadata + self.response = None + self.trailing_metadata = trailing_metadata + self.code = code + self.details = details + # The semantics of grpc.Future.cancel and grpc.Future.cancelled are + # slightly wonky, so they have to be tracked separately from the rest of the + # result of the RPC. This field tracks whether cancellation was requested + # prior to termination of the RPC. + self.cancelled = False + self.callbacks = [] + + +def _abort(state, code, details): + if state.code is None: + state.code = code + state.details = details + if state.initial_metadata is None: + state.initial_metadata = _EMPTY_METADATA + state.trailing_metadata = _EMPTY_METADATA + + +def _handle_event(event, state, response_deserializer): + callbacks = [] + for batch_operation in event.batch_operations: + operation_type = batch_operation.type + state.due.remove(operation_type) + if operation_type is cygrpc.OperationType.receive_initial_metadata: + state.initial_metadata = batch_operation.received_metadata + elif operation_type is cygrpc.OperationType.receive_message: + serialized_response = batch_operation.received_message.bytes() + if serialized_response is not None: + response = _common.deserialize( + serialized_response, response_deserializer) + if response is None: + details = b'Exception deserializing response!' + _abort(state, grpc.StatusCode.INTERNAL, details) + else: + state.response = response + elif operation_type is cygrpc.OperationType.receive_status_on_client: + state.trailing_metadata = batch_operation.received_metadata + if state.code is None: + code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE.get( + batch_operation.received_status_code) + if code is None: + state.code = grpc.StatusCode.UNKNOWN + state.details = _unknown_code_details( + batch_operation.received_status_code, + batch_operation.received_status_details) + else: + state.code = code + state.details = batch_operation.received_status_details + callbacks.extend(state.callbacks) + state.callbacks = None + return callbacks + + +def _event_handler(state, call, response_deserializer): + def handle_event(event): + with state.condition: + callbacks = _handle_event(event, state, response_deserializer) + state.condition.notify_all() + done = not state.due + for callback in callbacks: + callback() + return call if done else None + return handle_event + + +def _consume_request_iterator( + request_iterator, state, call, request_serializer): + event_handler = _event_handler(state, call, None) + def consume_request_iterator(): + for request in request_iterator: + serialized_request = _common.serialize(request, request_serializer) + with state.condition: + if state.code is None and not state.cancelled: + if serialized_request is None: + call.cancel() + details = b'Exception serializing request!' + _abort(state, grpc.StatusCode.INTERNAL, details) + return + else: + operations = ( + cygrpc.operation_send_message( + serialized_request, _EMPTY_FLAGS), + ) + call.start_batch(cygrpc.Operations(operations), event_handler) + state.due.add(cygrpc.OperationType.send_message) + while True: + state.condition.wait() + if state.code is None: + if cygrpc.OperationType.send_message not in state.due: + break + else: + return + else: + return + with state.condition: + if state.code is None: + operations = ( + cygrpc.operation_send_close_from_client(_EMPTY_FLAGS), + ) + call.start_batch(cygrpc.Operations(operations), event_handler) + state.due.add(cygrpc.OperationType.send_close_from_client) + thread = threading.Thread(target=consume_request_iterator) + thread.start() + + +class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call): + + def __init__(self, state, call, response_deserializer, deadline): + super(_Rendezvous, self).__init__() + self._state = state + self._call = call + self._response_deserializer = response_deserializer + self._deadline = deadline + + def cancel(self): + with self._state.condition: + if self._state.code is None: + self._call.cancel() + self._state.cancelled = True + _abort(self._state, grpc.StatusCode.CANCELLED, b'Cancelled!') + self._state.condition.notify_all() + return False + + def cancelled(self): + with self._state.condition: + return self._state.cancelled + + def running(self): + with self._state.condition: + return self._state.code is None + + def done(self): + with self._state.condition: + return self._state.code is not None + + def result(self, timeout=None): + until = None if timeout is None else time.time() + timeout + with self._state.condition: + while True: + if self._state.code is None: + _wait_once_until(self._state.condition, until) + elif self._state.code is grpc.StatusCode.OK: + return self._state.response + elif self._state.cancelled: + raise grpc.FutureCancelledError() + else: + raise self + + def exception(self, timeout=None): + until = None if timeout is None else time.time() + timeout + with self._state.condition: + while True: + if self._state.code is None: + _wait_once_until(self._state.condition, until) + elif self._state.code is grpc.StatusCode.OK: + return None + elif self._state.cancelled: + raise grpc.FutureCancelledError() + else: + return self + + def traceback(self, timeout=None): + until = None if timeout is None else time.time() + timeout + with self._state.condition: + while True: + if self._state.code is None: + _wait_once_until(self._state.condition, until) + elif self._state.code is grpc.StatusCode.OK: + return None + elif self._state.cancelled: + raise grpc.FutureCancelledError() + else: + try: + raise self + except grpc.RpcError: + return sys.exc_info()[2] + + def add_done_callback(self, fn): + with self._state.condition: + if self._state.code is None: + self._state.callbacks.append(lambda: fn(self)) + return + + fn(self) + + def _next(self): + with self._state.condition: + if self._state.code is None: + event_handler = _event_handler( + self._state, self._call, self._response_deserializer) + self._call.start_batch( + cygrpc.Operations( + (cygrpc.operation_receive_message(_EMPTY_FLAGS),)), + event_handler) + self._state.due.add(cygrpc.OperationType.receive_message) + elif self._state.code is grpc.StatusCode.OK: + raise StopIteration() + else: + raise self + while True: + self._state.condition.wait() + if self._state.response is not None: + response = self._state.response + self._state.response = None + return response + elif cygrpc.OperationType.receive_message not in self._state.due: + if self._state.code is grpc.StatusCode.OK: + raise StopIteration() + elif self._state.code is not None: + raise self + + def __iter__(self): + return self + + def __next__(self): + return self._next() + + def next(self): + return self._next() + + def is_active(self): + with self._state.condition: + return self._state.code is None + + def time_remaining(self): + if self._deadline is None: + return None + else: + return max(self._deadline - time.time(), 0) + + def add_cancellation_callback(self, callback): + with self._state.condition: + if self._state.callbacks is None: + return False + else: + self._state.callbacks.append(lambda unused_future: callback()) + return True + + def initial_metadata(self): + with self._state.condition: + while self._state.initial_metadata is None: + self._state.condition.wait() + return self._state.initial_metadata + + def trailing_metadata(self): + with self._state.condition: + while self._state.trailing_metadata is None: + self._state.condition.wait() + return self._state.trailing_metadata + + def code(self): + with self._state.condition: + while self._state.code is None: + self._state.condition.wait() + return self._state.code + + def details(self): + with self._state.condition: + while self._state.details is None: + self._state.condition.wait() + return self._state.details + + def _repr(self): + with self._state.condition: + if self._state.code is None: + return '<_Rendezvous object of in-flight RPC>' + else: + return '<_Rendezvous of RPC that terminated with ({}, {})>'.format( + self._state.code, self._state.details) + + def __repr__(self): + return self._repr() + + def __str__(self): + return self._repr() + + def __del__(self): + with self._state.condition: + if self._state.code is None: + self._call.cancel() + self._state.cancelled = True + self._state.code = grpc.StatusCode.CANCELLED + self._state.condition.notify_all() + + +def _start_unary_request(request, timeout, request_serializer): + deadline, deadline_timespec = _deadline(timeout) + serialized_request = _common.serialize(request, request_serializer) + if serialized_request is None: + state = _RPCState( + (), _EMPTY_METADATA, _EMPTY_METADATA, grpc.StatusCode.INTERNAL, + b'Exception serializing request!') + rendezvous = _Rendezvous(state, None, None, deadline) + return deadline, deadline_timespec, None, rendezvous + else: + return deadline, deadline_timespec, serialized_request, None + + +def _end_unary_response_blocking(state, with_call, deadline): + if state.code is grpc.StatusCode.OK: + if with_call: + rendezvous = _Rendezvous(state, None, None, deadline) + return state.response, rendezvous + else: + return state.response + else: + raise _Rendezvous(state, None, None, deadline) + + +class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): + + def __init__( + self, channel, create_managed_call, method, request_serializer, + response_deserializer): + self._channel = channel + self._create_managed_call = create_managed_call + self._method = method + self._request_serializer = request_serializer + self._response_deserializer = response_deserializer + + def _prepare(self, request, timeout, metadata): + deadline, deadline_timespec, serialized_request, rendezvous = ( + _start_unary_request(request, timeout, self._request_serializer)) + if serialized_request is None: + return None, None, None, None, rendezvous + else: + state = _RPCState(_UNARY_UNARY_INITIAL_DUE, None, None, None, None) + operations = ( + cygrpc.operation_send_initial_metadata( + _common.metadata(metadata), _EMPTY_FLAGS), + cygrpc.operation_send_message(serialized_request, _EMPTY_FLAGS), + cygrpc.operation_send_close_from_client(_EMPTY_FLAGS), + cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS), + cygrpc.operation_receive_message(_EMPTY_FLAGS), + cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS), + ) + return state, operations, deadline, deadline_timespec, None + + def __call__( + self, request, timeout=None, metadata=None, credentials=None, + with_call=False): + state, operations, deadline, deadline_timespec, rendezvous = self._prepare( + request, timeout, metadata) + if rendezvous: + raise rendezvous + else: + completion_queue = cygrpc.CompletionQueue() + call = self._channel.create_call( + None, 0, completion_queue, self._method, None, deadline_timespec) + if credentials is not None: + call.set_credentials(credentials._credentials) + call.start_batch(cygrpc.Operations(operations), None) + _handle_event(completion_queue.poll(), state, self._response_deserializer) + return _end_unary_response_blocking(state, with_call, deadline) + + def future(self, request, timeout=None, metadata=None, credentials=None): + state, operations, deadline, deadline_timespec, rendezvous = self._prepare( + request, timeout, metadata) + if rendezvous: + return rendezvous + else: + call = self._create_managed_call( + None, 0, self._method, None, deadline_timespec) + if credentials is not None: + call.set_credentials(credentials._credentials) + event_handler = _event_handler(state, call, self._response_deserializer) + with state.condition: + call.start_batch(cygrpc.Operations(operations), event_handler) + return _Rendezvous(state, call, self._response_deserializer, deadline) + + +class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): + + def __init__( + self, channel, create_managed_call, method, request_serializer, + response_deserializer): + self._channel = channel + self._create_managed_call = create_managed_call + self._method = method + self._request_serializer = request_serializer + self._response_deserializer = response_deserializer + + def __call__(self, request, timeout=None, metadata=None, credentials=None): + deadline, deadline_timespec, serialized_request, rendezvous = ( + _start_unary_request(request, timeout, self._request_serializer)) + if serialized_request is None: + raise rendezvous + else: + state = _RPCState(_UNARY_STREAM_INITIAL_DUE, None, None, None, None) + call = self._create_managed_call( + None, 0, self._method, None, deadline_timespec) + if credentials is not None: + call.set_credentials(credentials._credentials) + event_handler = _event_handler(state, call, self._response_deserializer) + with state.condition: + call.start_batch( + cygrpc.Operations( + (cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),)), + event_handler) + operations = ( + cygrpc.operation_send_initial_metadata( + _common.metadata(metadata), _EMPTY_FLAGS), + cygrpc.operation_send_message(serialized_request, _EMPTY_FLAGS), + cygrpc.operation_send_close_from_client(_EMPTY_FLAGS), + cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS), + ) + call.start_batch(cygrpc.Operations(operations), event_handler) + return _Rendezvous(state, call, self._response_deserializer, deadline) + + +class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): + + def __init__( + self, channel, create_managed_call, method, request_serializer, + response_deserializer): + self._channel = channel + self._create_managed_call = create_managed_call + self._method = method + self._request_serializer = request_serializer + self._response_deserializer = response_deserializer + + def __call__( + self, request_iterator, timeout=None, metadata=None, credentials=None, + with_call=False): + deadline, deadline_timespec = _deadline(timeout) + state = _RPCState(_STREAM_UNARY_INITIAL_DUE, None, None, None, None) + completion_queue = cygrpc.CompletionQueue() + call = self._channel.create_call( + None, 0, completion_queue, self._method, None, deadline_timespec) + if credentials is not None: + call.set_credentials(credentials._credentials) + with state.condition: + call.start_batch( + cygrpc.Operations( + (cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),)), + None) + operations = ( + cygrpc.operation_send_initial_metadata( + _common.metadata(metadata), _EMPTY_FLAGS), + cygrpc.operation_receive_message(_EMPTY_FLAGS), + cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS), + ) + call.start_batch(cygrpc.Operations(operations), None) + _consume_request_iterator( + request_iterator, state, call, self._request_serializer) + while True: + event = completion_queue.poll() + with state.condition: + _handle_event(event, state, self._response_deserializer) + state.condition.notify_all() + if not state.due: + break + return _end_unary_response_blocking(state, with_call, deadline) + + def future( + self, request_iterator, timeout=None, metadata=None, credentials=None): + deadline, deadline_timespec = _deadline(timeout) + state = _RPCState(_STREAM_UNARY_INITIAL_DUE, None, None, None, None) + call = self._create_managed_call( + None, 0, self._method, None, deadline_timespec) + if credentials is not None: + call.set_credentials(credentials._credentials) + event_handler = _event_handler(state, call, self._response_deserializer) + with state.condition: + call.start_batch( + cygrpc.Operations( + (cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),)), + event_handler) + operations = ( + cygrpc.operation_send_initial_metadata( + _common.metadata(metadata), _EMPTY_FLAGS), + cygrpc.operation_receive_message(_EMPTY_FLAGS), + cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS), + ) + call.start_batch(cygrpc.Operations(operations), event_handler) + _consume_request_iterator( + request_iterator, state, call, self._request_serializer) + return _Rendezvous(state, call, self._response_deserializer, deadline) + + +class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): + + def __init__( + self, channel, create_managed_call, method, request_serializer, + response_deserializer): + self._channel = channel + self._create_managed_call = create_managed_call + self._method = method + self._request_serializer = request_serializer + self._response_deserializer = response_deserializer + + def __call__( + self, request_iterator, timeout=None, metadata=None, credentials=None): + deadline, deadline_timespec = _deadline(timeout) + state = _RPCState(_STREAM_STREAM_INITIAL_DUE, None, None, None, None) + call = self._create_managed_call( + None, 0, self._method, None, deadline_timespec) + if credentials is not None: + call.set_credentials(credentials._credentials) + event_handler = _event_handler(state, call, self._response_deserializer) + with state.condition: + call.start_batch( + cygrpc.Operations( + (cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),)), + event_handler) + operations = ( + cygrpc.operation_send_initial_metadata( + _common.metadata(metadata), _EMPTY_FLAGS), + cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS), + ) + call.start_batch(cygrpc.Operations(operations), event_handler) + _consume_request_iterator( + request_iterator, state, call, self._request_serializer) + return _Rendezvous(state, call, self._response_deserializer, deadline) + + +class _ChannelCallState(object): + + def __init__(self, channel): + self.lock = threading.Lock() + self.channel = channel + self.completion_queue = cygrpc.CompletionQueue() + self.managed_calls = None + + +def _call_spin(state): + while True: + event = state.completion_queue.poll() + completed_call = event.tag(event) + if completed_call is not None: + with state.lock: + state.managed_calls.remove(completed_call) + if not state.managed_calls: + state.managed_calls = None + return + + +def _create_channel_managed_call(state): + def create_channel_managed_call(parent, flags, method, host, deadline): + """Creates a managed cygrpc.Call. + + Callers of this function must conduct at least one operation on the returned + call. The tags associated with operations conducted on the returned call + must be no-argument callables that return None to indicate that this channel + should continue polling for events associated with the call and return the + call itself to indicate that no more events associated with the call will be + generated. + + Args: + parent: A cygrpc.Call to be used as the parent of the created call. + flags: An integer bitfield of call flags. + method: The RPC method. + host: A host string for the created call. + deadline: A cygrpc.Timespec to be the deadline of the created call. + + Returns: + A cygrpc.Call with which to conduct an RPC. + """ + with state.lock: + call = state.channel.create_call( + parent, flags, state.completion_queue, method, host, deadline) + if state.managed_calls is None: + state.managed_calls = set((call,)) + spin_thread = threading.Thread(target=_call_spin, args=(state,)) + spin_thread.start() + else: + state.managed_calls.add(call) + return call + return create_channel_managed_call + + +class _ChannelConnectivityState(object): + + def __init__(self, channel): + self.lock = threading.Lock() + self.channel = channel + self.polling = False + self.connectivity = None + self.try_to_connect = False + self.callbacks_and_connectivities = [] + self.delivering = False + + +def _deliveries(state): + callbacks_needing_update = [] + for callback_and_connectivity in state.callbacks_and_connectivities: + callback, callback_connectivity, = callback_and_connectivity + if callback_connectivity is not state.connectivity: + callbacks_needing_update.append(callback) + callback_and_connectivity[1] = state.connectivity + return callbacks_needing_update + + +def _deliver(state, initial_connectivity, initial_callbacks): + connectivity = initial_connectivity + callbacks = initial_callbacks + while True: + for callback in callbacks: + callable_util.call_logging_exceptions( + callback, _CHANNEL_SUBSCRIPTION_CALLBACK_ERROR_LOG_MESSAGE, + connectivity) + with state.lock: + callbacks = _deliveries(state) + if callbacks: + connectivity = state.connectivity + else: + state.delivering = False + return + + +def _spawn_delivery(state, callbacks): + delivering_thread = threading.Thread( + target=_deliver, args=(state, state.connectivity, callbacks,)) + delivering_thread.start() + state.delivering = True + + +# NOTE(https://github.com/grpc/grpc/issues/3064): We'd rather not poll. +def _poll_connectivity(state, channel, initial_try_to_connect): + try_to_connect = initial_try_to_connect + connectivity = channel.check_connectivity_state(try_to_connect) + with state.lock: + state.connectivity = ( + _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[ + connectivity]) + callbacks = tuple( + callback for callback, unused_but_known_to_be_none_connectivity + in state.callbacks_and_connectivities) + for callback_and_connectivity in state.callbacks_and_connectivities: + callback_and_connectivity[1] = state.connectivity + if callbacks: + _spawn_delivery(state, callbacks) + completion_queue = cygrpc.CompletionQueue() + while True: + channel.watch_connectivity_state( + connectivity, cygrpc.Timespec(time.time() + 0.2), + completion_queue, None) + event = completion_queue.poll() + with state.lock: + if not state.callbacks_and_connectivities and not state.try_to_connect: + state.polling = False + state.connectivity = None + break + try_to_connect = state.try_to_connect + state.try_to_connect = False + if event.success or try_to_connect: + connectivity = channel.check_connectivity_state(try_to_connect) + with state.lock: + state.connectivity = ( + _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[ + connectivity]) + if not state.delivering: + callbacks = _deliveries(state) + if callbacks: + _spawn_delivery(state, callbacks) + + +def _subscribe(state, callback, try_to_connect): + with state.lock: + if not state.callbacks_and_connectivities and not state.polling: + polling_thread = threading.Thread( + target=_poll_connectivity, + args=(state, state.channel, bool(try_to_connect))) + polling_thread.start() + state.polling = True + state.callbacks_and_connectivities.append([callback, None]) + elif not state.delivering and state.connectivity is not None: + _spawn_delivery(state, (callback,)) + state.try_to_connect |= bool(try_to_connect) + state.callbacks_and_connectivities.append( + [callback, state.connectivity]) + else: + state.try_to_connect |= bool(try_to_connect) + state.callbacks_and_connectivities.append([callback, None]) + + +def _unsubscribe(state, callback): + with state.lock: + for index, (subscribed_callback, unused_connectivity) in enumerate( + state.callbacks_and_connectivities): + if callback == subscribed_callback: + state.callbacks_and_connectivities.pop(index) + break + + +def _moot(state): + with state.lock: + del state.callbacks_and_connectivities[:] + + +def _options(options): + if options is None: + pairs = ((cygrpc.ChannelArgKey.primary_user_agent_string, _USER_AGENT),) + else: + pairs = list(options) + [ + (cygrpc.ChannelArgKey.primary_user_agent_string, _USER_AGENT)] + return cygrpc.ChannelArgs( + cygrpc.ChannelArg(arg_name, arg_value) for arg_name, arg_value in pairs) + + +class Channel(grpc.Channel): + + def __init__(self, target, options, credentials): + self._channel = cygrpc.Channel(target, _options(options), credentials) + self._call_state = _ChannelCallState(self._channel) + self._connectivity_state = _ChannelConnectivityState(self._channel) + + def subscribe(self, callback, try_to_connect=None): + _subscribe(self._connectivity_state, callback, try_to_connect) + + def unsubscribe(self, callback): + _unsubscribe(self._connectivity_state, callback) + + def unary_unary( + self, method, request_serializer=None, response_deserializer=None): + return _UnaryUnaryMultiCallable( + self._channel, _create_channel_managed_call(self._call_state), method, + request_serializer, response_deserializer) + + def unary_stream( + self, method, request_serializer=None, response_deserializer=None): + return _UnaryStreamMultiCallable( + self._channel, _create_channel_managed_call(self._call_state), method, + request_serializer, response_deserializer) + + def stream_unary( + self, method, request_serializer=None, response_deserializer=None): + return _StreamUnaryMultiCallable( + self._channel, _create_channel_managed_call(self._call_state), method, + request_serializer, response_deserializer) + + def stream_stream( + self, method, request_serializer=None, response_deserializer=None): + return _StreamStreamMultiCallable( + self._channel, _create_channel_managed_call(self._call_state), method, + request_serializer, response_deserializer) + + def __del__(self): + _moot(self._connectivity_state) diff --git a/src/python/grpcio/grpc/_common.py b/src/python/grpcio/grpc/_common.py new file mode 100644 index 0000000000..a3fb66cd07 --- /dev/null +++ b/src/python/grpcio/grpc/_common.py @@ -0,0 +1,99 @@ +# Copyright 2016, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Shared implementation.""" + +import logging + +import six + +import grpc +from grpc._cython import cygrpc + +_EMPTY_METADATA = cygrpc.Metadata(()) + +CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY = { + cygrpc.ConnectivityState.idle: grpc.ChannelConnectivity.IDLE, + cygrpc.ConnectivityState.connecting: grpc.ChannelConnectivity.CONNECTING, + cygrpc.ConnectivityState.ready: grpc.ChannelConnectivity.READY, + cygrpc.ConnectivityState.transient_failure: + grpc.ChannelConnectivity.TRANSIENT_FAILURE, + cygrpc.ConnectivityState.fatal_failure: + grpc.ChannelConnectivity.FATAL_FAILURE, +} + +CYGRPC_STATUS_CODE_TO_STATUS_CODE = { + cygrpc.StatusCode.ok: grpc.StatusCode.OK, + cygrpc.StatusCode.cancelled: grpc.StatusCode.CANCELLED, + cygrpc.StatusCode.unknown: grpc.StatusCode.UNKNOWN, + cygrpc.StatusCode.invalid_argument: grpc.StatusCode.INVALID_ARGUMENT, + cygrpc.StatusCode.deadline_exceeded: grpc.StatusCode.DEADLINE_EXCEEDED, + cygrpc.StatusCode.not_found: grpc.StatusCode.NOT_FOUND, + cygrpc.StatusCode.already_exists: grpc.StatusCode.ALREADY_EXISTS, + cygrpc.StatusCode.permission_denied: grpc.StatusCode.PERMISSION_DENIED, + cygrpc.StatusCode.unauthenticated: grpc.StatusCode.UNAUTHENTICATED, + cygrpc.StatusCode.resource_exhausted: grpc.StatusCode.RESOURCE_EXHAUSTED, + cygrpc.StatusCode.failed_precondition: grpc.StatusCode.FAILED_PRECONDITION, + cygrpc.StatusCode.aborted: grpc.StatusCode.ABORTED, + cygrpc.StatusCode.out_of_range: grpc.StatusCode.OUT_OF_RANGE, + cygrpc.StatusCode.unimplemented: grpc.StatusCode.UNIMPLEMENTED, + cygrpc.StatusCode.internal: grpc.StatusCode.INTERNAL, + cygrpc.StatusCode.unavailable: grpc.StatusCode.UNAVAILABLE, + cygrpc.StatusCode.data_loss: grpc.StatusCode.DATA_LOSS, +} +STATUS_CODE_TO_CYGRPC_STATUS_CODE = { + grpc_code: cygrpc_code + for cygrpc_code, grpc_code in six.iteritems( + CYGRPC_STATUS_CODE_TO_STATUS_CODE) +} + + +def metadata(application_metadata): + return _EMPTY_METADATA if application_metadata is None else cygrpc.Metadata( + cygrpc.Metadatum(key, value) for key, value in application_metadata) + + +def _transform(message, transformer, exception_message): + if transformer is None: + return message + else: + try: + return transformer(message) + except Exception: # pylint: disable=broad-except + logging.exception(exception_message) + return None + + +def serialize(message, serializer): + return _transform(message, serializer, 'Exception serializing message!') + + +def deserialize(serialized_message, deserializer): + return _transform(serialized_message, deserializer, + 'Exception deserializing message!') diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd.pxi index c793c8f5e5..19a59e08f3 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd.pxi @@ -68,4 +68,4 @@ cdef void plugin_get_metadata( void *state, grpc_auth_metadata_context context, grpc_credentials_plugin_metadata_cb cb, void *user_data) with gil -cdef void plugin_destroy_c_plugin_state(void *state) +cdef void plugin_destroy_c_plugin_state(void *state) with gil diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi index 94d13b5999..1ba86457af 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi @@ -137,7 +137,7 @@ cdef void plugin_get_metadata( cy_context.context = context self.plugin_callback(cy_context, python_callback) -cdef void plugin_destroy_c_plugin_state(void *state): +cdef void plugin_destroy_c_plugin_state(void *state) with gil: cpython.Py_DECREF(<CredentialsMetadataPlugin>state) def channel_credentials_google_default(): diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi index d42c58050f..05b8886df7 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi @@ -208,7 +208,7 @@ cdef extern from "grpc/_cython/loader.h": GRPC_CHANNEL_CONNECTING GRPC_CHANNEL_READY GRPC_CHANNEL_TRANSIENT_FAILURE - GRPC_CHANNEL_FATAL_FAILURE + GRPC_CHANNEL_SHUTDOWN ctypedef struct grpc_metadata: const char *key diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi index c7539f0d49..e0219b0086 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi @@ -33,7 +33,7 @@ class ConnectivityState: connecting = GRPC_CHANNEL_CONNECTING ready = GRPC_CHANNEL_READY transient_failure = GRPC_CHANNEL_TRANSIENT_FAILURE - fatal_failure = GRPC_CHANNEL_FATAL_FAILURE + fatal_failure = GRPC_CHANNEL_SHUTDOWN class ChannelArgKey: @@ -274,6 +274,7 @@ cdef class ByteBuffer: data_slice_length = gpr_slice_length(data_slice) with gil: result += (<char *>data_slice_pointer)[:data_slice_length] + gpr_slice_unref(data_slice) with nogil: grpc_byte_buffer_reader_destroy(&reader) return bytes(result) diff --git a/src/python/grpcio/grpc/_cython/imports.generated.c b/src/python/grpcio/grpc/_cython/imports.generated.c index f71cf12844..5d09329680 100644 --- a/src/python/grpcio/grpc/_cython/imports.generated.c +++ b/src/python/grpcio/grpc/_cython/imports.generated.c @@ -35,7 +35,7 @@ #include "imports.generated.h" -#ifdef GPR_WIN32 +#ifdef GPR_WINDOWS census_initialize_type census_initialize_import; census_shutdown_type census_shutdown_import; @@ -581,4 +581,4 @@ void pygrpc_load_imports(HMODULE library) { } #endif /* __cpluslus */ -#endif /* !GPR_WIN32 */ +#endif /* !GPR_WINDOWS */ diff --git a/src/python/grpcio/grpc/_cython/imports.generated.h b/src/python/grpcio/grpc/_cython/imports.generated.h index a364075e9e..cfadad0527 100644 --- a/src/python/grpcio/grpc/_cython/imports.generated.h +++ b/src/python/grpcio/grpc/_cython/imports.generated.h @@ -36,7 +36,7 @@ #include <grpc/support/port_platform.h> -#ifdef GPR_WIN32 +#ifdef GPR_WINDOWS #include <windows.h> @@ -57,7 +57,7 @@ #include <grpc/support/cpu.h> #include <grpc/support/histogram.h> #include <grpc/support/host_port.h> -#include <grpc/support/log_win32.h> +#include <grpc/support/log_windows.h> #include <grpc/support/string_util.h> #include <grpc/support/subprocess.h> #include <grpc/support/thd.h> @@ -871,7 +871,7 @@ void pygrpc_load_imports(HMODULE library); } #endif /* __cpluslus */ -#else /* !GPR_WIN32 */ +#else /* !GPR_WINDOWS */ #include <grpc/byte_buffer.h> #include <grpc/byte_buffer_reader.h> @@ -883,6 +883,6 @@ void pygrpc_load_imports(HMODULE library); #include <grpc/support/time.h> #include <grpc/status.h> -#endif /* !GPR_WIN32 */ +#endif /* !GPR_WINDOWS */ #endif diff --git a/src/python/grpcio/grpc/_cython/loader.c b/src/python/grpcio/grpc/_cython/loader.c index 3b72806ea1..b909ad594e 100644 --- a/src/python/grpcio/grpc/_cython/loader.c +++ b/src/python/grpcio/grpc/_cython/loader.c @@ -37,7 +37,7 @@ extern "C" { #endif /* __cpluslus */ -#if GPR_WIN32 +#if GPR_WINDOWS int pygrpc_load_core(char *path) { HMODULE grpc_c; @@ -60,7 +60,7 @@ int pygrpc_load_core(char *path) { int pygrpc_load_core(char *path) { return 1; } -#endif /* !GPR_WIN32 */ +#endif /* !GPR_WINDOWS */ #ifdef __cplusplus } diff --git a/src/python/grpcio/grpc/_plugin_wrapping.py b/src/python/grpcio/grpc/_plugin_wrapping.py new file mode 100644 index 0000000000..4e9cfe710c --- /dev/null +++ b/src/python/grpcio/grpc/_plugin_wrapping.py @@ -0,0 +1,123 @@ +# Copyright 2015, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import collections +import threading + +import grpc +from grpc._cython import cygrpc + + +class AuthMetadataContext( + collections.namedtuple( + 'AuthMetadataContext', ('service_url', 'method_name',)), + grpc.AuthMetadataContext): + pass + + +class AuthMetadataPluginCallback(grpc.AuthMetadataContext): + + def __init__(self, callback): + self._callback = callback + + def __call__(self, metadata, error): + self._callback(metadata, error) + + +class _WrappedCygrpcCallback(object): + + def __init__(self, cygrpc_callback): + self.is_called = False + self.error = None + self.is_called_lock = threading.Lock() + self.cygrpc_callback = cygrpc_callback + + def _invoke_failure(self, error): + # TODO(atash) translate different Exception superclasses into different + # status codes. + self.cygrpc_callback( + cygrpc.Metadata([]), cygrpc.StatusCode.internal, error.message) + + def _invoke_success(self, metadata): + try: + cygrpc_metadata = cygrpc.Metadata( + cygrpc.Metadatum(key, value) + for key, value in metadata) + except Exception as error: + self._invoke_failure(error) + return + self.cygrpc_callback(cygrpc_metadata, cygrpc.StatusCode.ok, '') + + def __call__(self, metadata, error): + with self.is_called_lock: + if self.is_called: + raise RuntimeError('callback should only ever be invoked once') + if self.error: + self._invoke_failure(self.error) + return + self.is_called = True + if error is None: + self._invoke_success(metadata) + else: + self._invoke_failure(error) + + def notify_failure(self, error): + with self.is_called_lock: + if not self.is_called: + self.error = error + + +class _WrappedPlugin(object): + + def __init__(self, plugin): + self.plugin = plugin + + def __call__(self, context, cygrpc_callback): + wrapped_cygrpc_callback = _WrappedCygrpcCallback(cygrpc_callback) + wrapped_context = AuthMetadataContext( + context.service_url, context.method_name) + try: + self.plugin( + wrapped_context, AuthMetadataPluginCallback(wrapped_cygrpc_callback)) + except Exception as error: + wrapped_cygrpc_callback.notify_failure(error) + raise + + +def call_credentials_metadata_plugin(plugin, name): + """ + Args: + plugin: A callable accepting a grpc.AuthMetadataContext + object and a callback (itself accepting a list of metadata key/value + 2-tuples and a None-able exception value). The callback must be eventually + called, but need not be called in plugin's invocation. + plugin's invocation must be non-blocking. + """ + return cygrpc.call_credentials_metadata_plugin( + cygrpc.CredentialsMetadataPlugin(_WrappedPlugin(plugin), name)) diff --git a/src/python/grpcio/grpc/_server.py b/src/python/grpcio/grpc/_server.py new file mode 100644 index 0000000000..c65070f1b3 --- /dev/null +++ b/src/python/grpcio/grpc/_server.py @@ -0,0 +1,734 @@ +# Copyright 2016, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Service-side implementation of gRPC Python.""" + +import collections +import enum +import logging +import threading +import time + +import grpc +from grpc import _common +from grpc._cython import cygrpc +from grpc.framework.foundation import callable_util + +_SHUTDOWN_TAG = 'shutdown' +_REQUEST_CALL_TAG = 'request_call' + +_RECEIVE_CLOSE_ON_SERVER_TOKEN = 'receive_close_on_server' +_SEND_INITIAL_METADATA_TOKEN = 'send_initial_metadata' +_RECEIVE_MESSAGE_TOKEN = 'receive_message' +_SEND_MESSAGE_TOKEN = 'send_message' +_SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN = ( + 'send_initial_metadata * send_message') +_SEND_STATUS_FROM_SERVER_TOKEN = 'send_status_from_server' +_SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN = ( + 'send_initial_metadata * send_status_from_server') + +_OPEN = 'open' +_CLOSED = 'closed' +_CANCELLED = 'cancelled' + +_EMPTY_FLAGS = 0 +_EMPTY_METADATA = cygrpc.Metadata(()) + + +def _serialized_request(request_event): + return request_event.batch_operations[0].received_message.bytes() + + +def _code(state): + if state.code is None: + return cygrpc.StatusCode.ok + else: + code = _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE.get(state.code) + return cygrpc.StatusCode.unknown if code is None else code + + +def _details(state): + return b'' if state.details is None else state.details + + +class _HandlerCallDetails( + collections.namedtuple( + '_HandlerCallDetails', ('method', 'invocation_metadata',)), + grpc.HandlerCallDetails): + pass + + +class _RPCState(object): + + def __init__(self): + self.condition = threading.Condition() + self.due = set() + self.request = None + self.client = _OPEN + self.initial_metadata_allowed = True + self.disable_next_compression = False + self.trailing_metadata = None + self.code = None + self.details = None + self.statused = False + self.rpc_errors = [] + self.callbacks = [] + + +def _raise_rpc_error(state): + rpc_error = grpc.RpcError() + state.rpc_errors.append(rpc_error) + raise rpc_error + + +def _possibly_finish_call(state, token): + state.due.remove(token) + if (state.client is _CANCELLED or state.statused) and not state.due: + callbacks = state.callbacks + state.callbacks = None + return state, callbacks + else: + return None, () + + +def _send_status_from_server(state, token): + def send_status_from_server(unused_send_status_from_server_event): + with state.condition: + return _possibly_finish_call(state, token) + return send_status_from_server + + +def _abort(state, call, code, details): + if state.client is not _CANCELLED: + if state.initial_metadata_allowed: + operations = ( + cygrpc.operation_send_initial_metadata( + _EMPTY_METADATA, _EMPTY_FLAGS), + cygrpc.operation_send_status_from_server( + _common.metadata(state.trailing_metadata), code, details, + _EMPTY_FLAGS), + ) + token = _SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN + else: + operations = ( + cygrpc.operation_send_status_from_server( + _common.metadata(state.trailing_metadata), code, details, + _EMPTY_FLAGS), + ) + token = _SEND_STATUS_FROM_SERVER_TOKEN + call.start_batch( + cygrpc.Operations(operations), + _send_status_from_server(state, token)) + state.statused = True + state.due.add(token) + + +def _receive_close_on_server(state): + def receive_close_on_server(receive_close_on_server_event): + with state.condition: + if receive_close_on_server_event.batch_operations[0].received_cancelled: + state.client = _CANCELLED + elif state.client is _OPEN: + state.client = _CLOSED + state.condition.notify_all() + return _possibly_finish_call(state, _RECEIVE_CLOSE_ON_SERVER_TOKEN) + return receive_close_on_server + + +def _receive_message(state, call, request_deserializer): + def receive_message(receive_message_event): + serialized_request = _serialized_request(receive_message_event) + if serialized_request is None: + with state.condition: + if state.client is _OPEN: + state.client = _CLOSED + state.condition.notify_all() + return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN) + else: + request = _common.deserialize(serialized_request, request_deserializer) + with state.condition: + if request is None: + _abort( + state, call, cygrpc.StatusCode.internal, + b'Exception deserializing request!') + else: + state.request = request + state.condition.notify_all() + return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN) + return receive_message + + +def _send_initial_metadata(state): + def send_initial_metadata(unused_send_initial_metadata_event): + with state.condition: + return _possibly_finish_call(state, _SEND_INITIAL_METADATA_TOKEN) + return send_initial_metadata + + +def _send_message(state, token): + def send_message(unused_send_message_event): + with state.condition: + state.condition.notify_all() + return _possibly_finish_call(state, token) + return send_message + + +class _Context(grpc.ServicerContext): + + def __init__(self, rpc_event, state, request_deserializer): + self._rpc_event = rpc_event + self._state = state + self._request_deserializer = request_deserializer + + def is_active(self): + with self._state.condition: + return self._state.client is not _CANCELLED and not self._state.statused + + def time_remaining(self): + return max(self._rpc_event.request_call_details.deadline - time.time(), 0) + + def cancel(self): + self._rpc_event.operation_call.cancel() + + def add_callback(self, callback): + with self._state.condition: + if self._state.callbacks is None: + return False + else: + self._state.callbacks.append(callback) + return True + + def disable_next_message_compression(self): + with self._state.condition: + self._state.disable_next_compression = True + + def invocation_metadata(self): + return self._rpc_event.request_metadata + + def peer(self): + return self._rpc_event.operation_call.peer() + + def send_initial_metadata(self, initial_metadata): + with self._state.condition: + if self._state.client is _CANCELLED: + _raise_rpc_error(self._state) + else: + if self._state.initial_metadata_allowed: + operation = cygrpc.operation_send_initial_metadata( + cygrpc.Metadata(initial_metadata), _EMPTY_FLAGS) + self._rpc_event.operation_call.start_batch( + cygrpc.Operations((operation,)), + _send_initial_metadata(self._state)) + self._state.initial_metadata_allowed = False + self._state.due.add(_SEND_INITIAL_METADATA_TOKEN) + else: + raise ValueError('Initial metadata no longer allowed!') + + def set_trailing_metadata(self, trailing_metadata): + with self._state.condition: + self._state.trailing_metadata = trailing_metadata + + def set_code(self, code): + with self._state.condition: + self._state.code = code + + def set_details(self, details): + with self._state.condition: + self._state.details = details + + +class _RequestIterator(object): + + def __init__(self, state, call, request_deserializer): + self._state = state + self._call = call + self._request_deserializer = request_deserializer + + def _raise_or_start_receive_message(self): + if self._state.client is _CANCELLED: + _raise_rpc_error(self._state) + elif self._state.client is _CLOSED or self._state.statused: + raise StopIteration() + else: + self._call.start_batch( + cygrpc.Operations((cygrpc.operation_receive_message(_EMPTY_FLAGS),)), + _receive_message(self._state, self._call, self._request_deserializer)) + self._state.due.add(_RECEIVE_MESSAGE_TOKEN) + + def _look_for_request(self): + if self._state.client is _CANCELLED: + _raise_rpc_error(self._state) + elif (self._state.request is None and + _RECEIVE_MESSAGE_TOKEN not in self._state.due): + raise StopIteration() + else: + request = self._state.request + self._state.request = None + return request + + def _next(self): + with self._state.condition: + self._raise_or_start_receive_message() + while True: + self._state.condition.wait() + request = self._look_for_request() + if request is not None: + return request + + def __iter__(self): + return self + + def __next__(self): + return self._next() + + def next(self): + return self._next() + + +def _unary_request(rpc_event, state, request_deserializer): + def unary_request(): + with state.condition: + if state.client is _CANCELLED or state.statused: + return None + else: + start_batch_result = rpc_event.operation_call.start_batch( + cygrpc.Operations( + (cygrpc.operation_receive_message(_EMPTY_FLAGS),)), + _receive_message( + state, rpc_event.operation_call, request_deserializer)) + state.due.add(_RECEIVE_MESSAGE_TOKEN) + while True: + state.condition.wait() + if state.request is None: + if state.client is _CLOSED: + details = b'"{}" requires exactly one request message.'.format( + rpc_event.request_call_details.method) + # TODO(5992#issuecomment-220761992): really, what status code? + _abort( + state, rpc_event.operation_call, + cygrpc.StatusCode.unavailable, details) + return None + elif state.client is _CANCELLED: + return None + else: + request = state.request + state.request = None + return request + return unary_request + + +def _call_behavior(rpc_event, state, behavior, argument, request_deserializer): + context = _Context(rpc_event, state, request_deserializer) + try: + return behavior(argument, context) + except Exception as e: # pylint: disable=broad-except + with state.condition: + if e not in state.rpc_errors: + details = b'Exception calling application: {}'.format(e) + logging.exception(details) + _abort( + state, rpc_event.operation_call, cygrpc.StatusCode.unknown, details) + return None + + +def _take_response_from_response_iterator(rpc_event, state, response_iterator): + try: + return next(response_iterator), True + except StopIteration: + return None, True + except Exception as e: # pylint: disable=broad-except + with state.condition: + if e not in state.rpc_errors: + details = b'Exception iterating responses: {}'.format(e) + logging.exception(details) + _abort( + state, rpc_event.operation_call, cygrpc.StatusCode.unknown, details) + return None, False + + +def _serialize_response(rpc_event, state, response, response_serializer): + serialized_response = _common.serialize(response, response_serializer) + if serialized_response is None: + with state.condition: + _abort( + state, rpc_event.operation_call, cygrpc.StatusCode.internal, + b'Failed to serialize response!') + return None + else: + return serialized_response + + +def _send_response(rpc_event, state, serialized_response): + with state.condition: + if state.client is _CANCELLED or state.statused: + return False + else: + if state.initial_metadata_allowed: + operations = ( + cygrpc.operation_send_initial_metadata( + _EMPTY_METADATA, _EMPTY_FLAGS), + cygrpc.operation_send_message(serialized_response, _EMPTY_FLAGS), + ) + state.initial_metadata_allowed = False + token = _SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN + else: + operations = ( + cygrpc.operation_send_message(serialized_response, _EMPTY_FLAGS), + ) + token = _SEND_MESSAGE_TOKEN + rpc_event.operation_call.start_batch( + cygrpc.Operations(operations), _send_message(state, token)) + state.due.add(token) + while True: + state.condition.wait() + if token not in state.due: + return state.client is not _CANCELLED and not state.statused + + +def _status(rpc_event, state, serialized_response): + with state.condition: + if state.client is not _CANCELLED: + trailing_metadata = _common.metadata(state.trailing_metadata) + code = _code(state) + details = _details(state) + operations = [ + cygrpc.operation_send_status_from_server( + trailing_metadata, code, details, _EMPTY_FLAGS), + ] + if state.initial_metadata_allowed: + operations.append( + cygrpc.operation_send_initial_metadata( + _EMPTY_METADATA, _EMPTY_FLAGS)) + if serialized_response is not None: + operations.append(cygrpc.operation_send_message( + serialized_response, _EMPTY_FLAGS)) + rpc_event.operation_call.start_batch( + cygrpc.Operations(operations), + _send_status_from_server(state, _SEND_STATUS_FROM_SERVER_TOKEN)) + state.statused = True + state.due.add(_SEND_STATUS_FROM_SERVER_TOKEN) + + +def _unary_response_in_pool( + rpc_event, state, behavior, argument_thunk, request_deserializer, + response_serializer): + argument = argument_thunk() + if argument is not None: + response = _call_behavior( + rpc_event, state, behavior, argument, request_deserializer) + if response is not None: + serialized_response = _serialize_response( + rpc_event, state, response, response_serializer) + if serialized_response is not None: + _status(rpc_event, state, serialized_response) + return + + +def _stream_response_in_pool( + rpc_event, state, behavior, argument_thunk, request_deserializer, + response_serializer): + argument = argument_thunk() + if argument is not None: + response_iterator = _call_behavior( + rpc_event, state, behavior, argument, request_deserializer) + if response_iterator is not None: + while True: + response, proceed = _take_response_from_response_iterator( + rpc_event, state, response_iterator) + if proceed: + if response is None: + _status(rpc_event, state, None) + break + else: + serialized_response = _serialize_response( + rpc_event, state, response, response_serializer) + if serialized_response is not None: + proceed = _send_response(rpc_event, state, serialized_response) + if not proceed: + break + else: + break + else: + break + + +def _handle_unary_unary(rpc_event, state, method_handler, thread_pool): + unary_request = _unary_request( + rpc_event, state, method_handler.request_deserializer) + thread_pool.submit( + _unary_response_in_pool, rpc_event, state, method_handler.unary_unary, + unary_request, method_handler.request_deserializer, + method_handler.response_serializer) + + +def _handle_unary_stream(rpc_event, state, method_handler, thread_pool): + unary_request = _unary_request( + rpc_event, state, method_handler.request_deserializer) + thread_pool.submit( + _stream_response_in_pool, rpc_event, state, method_handler.unary_stream, + unary_request, method_handler.request_deserializer, + method_handler.response_serializer) + + +def _handle_stream_unary(rpc_event, state, method_handler, thread_pool): + request_iterator = _RequestIterator( + state, rpc_event.operation_call, method_handler.request_deserializer) + thread_pool.submit( + _unary_response_in_pool, rpc_event, state, method_handler.stream_unary, + lambda: request_iterator, method_handler.request_deserializer, + method_handler.response_serializer) + + +def _handle_stream_stream(rpc_event, state, method_handler, thread_pool): + request_iterator = _RequestIterator( + state, rpc_event.operation_call, method_handler.request_deserializer) + thread_pool.submit( + _stream_response_in_pool, rpc_event, state, method_handler.stream_stream, + lambda: request_iterator, method_handler.request_deserializer, + method_handler.response_serializer) + + +def _find_method_handler(rpc_event, generic_handlers): + for generic_handler in generic_handlers: + method_handler = generic_handler.service( + _HandlerCallDetails( + rpc_event.request_call_details.method, rpc_event.request_metadata)) + if method_handler is not None: + return method_handler + else: + return None + + +def _handle_unrecognized_method(rpc_event): + operations = ( + cygrpc.operation_send_initial_metadata(_EMPTY_METADATA, _EMPTY_FLAGS), + cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS), + cygrpc.operation_send_status_from_server( + _EMPTY_METADATA, cygrpc.StatusCode.unimplemented, + b'Method not found!', _EMPTY_FLAGS), + ) + rpc_state = _RPCState() + rpc_event.operation_call.start_batch( + operations, lambda ignored_event: (rpc_state, (),)) + return rpc_state + + +def _handle_with_method_handler(rpc_event, method_handler, thread_pool): + state = _RPCState() + with state.condition: + rpc_event.operation_call.start_batch( + cygrpc.Operations( + (cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),)), + _receive_close_on_server(state)) + state.due.add(_RECEIVE_CLOSE_ON_SERVER_TOKEN) + if method_handler.request_streaming: + if method_handler.response_streaming: + _handle_stream_stream(rpc_event, state, method_handler, thread_pool) + else: + _handle_stream_unary(rpc_event, state, method_handler, thread_pool) + else: + if method_handler.response_streaming: + _handle_unary_stream(rpc_event, state, method_handler, thread_pool) + else: + _handle_unary_unary(rpc_event, state, method_handler, thread_pool) + return state + + +def _handle_call(rpc_event, generic_handlers, thread_pool): + if rpc_event.request_call_details.method is not None: + method_handler = _find_method_handler(rpc_event, generic_handlers) + if method_handler is None: + return _handle_unrecognized_method(rpc_event) + else: + return _handle_with_method_handler(rpc_event, method_handler, thread_pool) + else: + return None + + +@enum.unique +class _ServerStage(enum.Enum): + STOPPED = 'stopped' + STARTED = 'started' + GRACE = 'grace' + + +class _ServerState(object): + + def __init__(self, completion_queue, server, generic_handlers, thread_pool): + self.lock = threading.Lock() + self.completion_queue = completion_queue + self.server = server + self.generic_handlers = list(generic_handlers) + self.thread_pool = thread_pool + self.stage = _ServerStage.STOPPED + self.shutdown_events = None + + # TODO(https://github.com/grpc/grpc/issues/6597): eliminate these fields. + self.rpc_states = set() + self.due = set() + + +def _add_generic_handlers(state, generic_handlers): + with state.lock: + state.generic_handlers.extend(generic_handlers) + + +def _add_insecure_port(state, address): + with state.lock: + return state.server.add_http2_port(address) + + +def _add_secure_port(state, address, server_credentials): + with state.lock: + return state.server.add_http2_port(address, server_credentials._credentials) + + +def _request_call(state): + state.server.request_call( + state.completion_queue, state.completion_queue, _REQUEST_CALL_TAG) + state.due.add(_REQUEST_CALL_TAG) + + +# TODO(https://github.com/grpc/grpc/issues/6597): delete this function. +def _stop_serving(state): + if not state.rpc_states and not state.due: + for shutdown_event in state.shutdown_events: + shutdown_event.set() + state.stage = _ServerStage.STOPPED + return True + else: + return False + + +def _serve(state): + while True: + event = state.completion_queue.poll() + if event.tag is _SHUTDOWN_TAG: + with state.lock: + state.due.remove(_SHUTDOWN_TAG) + if _stop_serving(state): + return + elif event.tag is _REQUEST_CALL_TAG: + with state.lock: + state.due.remove(_REQUEST_CALL_TAG) + rpc_state = _handle_call( + event, state.generic_handlers, state.thread_pool) + if rpc_state is not None: + state.rpc_states.add(rpc_state) + if state.stage is _ServerStage.STARTED: + _request_call(state) + elif _stop_serving(state): + return + else: + rpc_state, callbacks = event.tag(event) + for callback in callbacks: + callable_util.call_logging_exceptions( + callback, 'Exception calling callback!') + if rpc_state is not None: + with state.lock: + state.rpc_states.remove(rpc_state) + if _stop_serving(state): + return + + +def _start(state): + with state.lock: + if state.stage is not _ServerStage.STOPPED: + raise ValueError('Cannot start already-started server!') + state.server.start() + state.stage = _ServerStage.STARTED + _request_call(state) + thread = threading.Thread(target=_serve, args=(state,)) + thread.start() + + +def _stop(state, grace): + with state.lock: + if state.stage is _ServerStage.STOPPED: + shutdown_event = threading.Event() + shutdown_event.set() + return shutdown_event + else: + if state.stage is _ServerStage.STARTED: + state.server.shutdown(state.completion_queue, _SHUTDOWN_TAG) + state.stage = _ServerStage.GRACE + state.shutdown_events = [] + state.due.add(_SHUTDOWN_TAG) + shutdown_event = threading.Event() + state.shutdown_events.append(shutdown_event) + if grace is None: + state.server.cancel_all_calls() + # TODO(https://github.com/grpc/grpc/issues/6597): delete this loop. + for rpc_state in state.rpc_states: + with rpc_state.condition: + rpc_state.client = _CANCELLED + rpc_state.condition.notify_all() + else: + def cancel_all_calls_after_grace(): + shutdown_event.wait(timeout=grace) + with state.lock: + state.server.cancel_all_calls() + # TODO(https://github.com/grpc/grpc/issues/6597): delete this loop. + for rpc_state in state.rpc_states: + with rpc_state.condition: + rpc_state.client = _CANCELLED + rpc_state.condition.notify_all() + thread = threading.Thread(target=cancel_all_calls_after_grace) + thread.start() + return shutdown_event + shutdown_event.wait() + return shutdown_event + + +class Server(grpc.Server): + + def __init__(self, generic_handlers, thread_pool): + completion_queue = cygrpc.CompletionQueue() + server = cygrpc.Server() + server.register_completion_queue(completion_queue) + self._state = _ServerState( + completion_queue, server, generic_handlers, thread_pool) + + def add_generic_rpc_handlers(self, generic_rpc_handlers): + _add_generic_handlers(self._state, generic_rpc_handlers) + + def add_insecure_port(self, address): + return _add_insecure_port(self._state, address) + + def add_secure_port(self, address, server_credentials): + return _add_secure_port(self._state, address, server_credentials) + + def start(self): + _start(self._state) + + def stop(self, grace): + return _stop(self._state, grace) + + def __del__(self): + _stop(self._state, None) diff --git a/src/python/grpcio/grpc/_utilities.py b/src/python/grpcio/grpc/_utilities.py new file mode 100644 index 0000000000..a4ca9b7282 --- /dev/null +++ b/src/python/grpcio/grpc/_utilities.py @@ -0,0 +1,147 @@ +# Copyright 2015, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Internal utilities for gRPC Python.""" + +import threading +import time + +import grpc +from grpc.framework.foundation import callable_util + +_DONE_CALLBACK_EXCEPTION_LOG_MESSAGE = ( + 'Exception calling connectivity future "done" callback!') + + +class _ChannelReadyFuture(grpc.Future): + + def __init__(self, channel): + self._condition = threading.Condition() + self._channel = channel + + self._matured = False + self._cancelled = False + self._done_callbacks = [] + + def _block(self, timeout): + until = None if timeout is None else time.time() + timeout + with self._condition: + while True: + if self._cancelled: + raise grpc.FutureCancelledError() + elif self._matured: + return + else: + if until is None: + self._condition.wait() + else: + remaining = until - time.time() + if remaining < 0: + raise grpc.FutureTimeoutError() + else: + self._condition.wait(timeout=remaining) + + def _update(self, connectivity): + with self._condition: + if (not self._cancelled and + connectivity is grpc.ChannelConnectivity.READY): + self._matured = True + self._channel.unsubscribe(self._update) + self._condition.notify_all() + done_callbacks = tuple(self._done_callbacks) + self._done_callbacks = None + else: + return + + for done_callback in done_callbacks: + callable_util.call_logging_exceptions( + done_callback, _DONE_CALLBACK_EXCEPTION_LOG_MESSAGE, self) + + def cancel(self): + with self._condition: + if not self._matured: + self._cancelled = True + self._channel.unsubscribe(self._update) + self._condition.notify_all() + done_callbacks = tuple(self._done_callbacks) + self._done_callbacks = None + else: + return False + + for done_callback in done_callbacks: + callable_util.call_logging_exceptions( + done_callback, _DONE_CALLBACK_EXCEPTION_LOG_MESSAGE, self) + + def cancelled(self): + with self._condition: + return self._cancelled + + def running(self): + with self._condition: + return not self._cancelled and not self._matured + + def done(self): + with self._condition: + return self._cancelled or self._matured + + def result(self, timeout=None): + self._block(timeout) + return None + + def exception(self, timeout=None): + self._block(timeout) + return None + + def traceback(self, timeout=None): + self._block(timeout) + return None + + def add_done_callback(self, fn): + with self._condition: + if not self._cancelled and not self._matured: + self._done_callbacks.append(fn) + return + + fn(self) + + def start(self): + with self._condition: + self._channel.subscribe(self._update, try_to_connect=True) + + def __del__(self): + with self._condition: + if not self._cancelled and not self._matured: + self._channel.unsubscribe(self._update) + + +def channel_ready_future(channel): + ready_future = _ChannelReadyFuture(channel) + ready_future.start() + return ready_future + diff --git a/src/python/grpcio/grpc/beta/_auth.py b/src/python/grpcio/grpc/beta/_auth.py new file mode 100644 index 0000000000..553d4b9991 --- /dev/null +++ b/src/python/grpcio/grpc/beta/_auth.py @@ -0,0 +1,73 @@ +# Copyright 2016, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""GRPCAuthMetadataPlugins for standard authentication.""" + +from concurrent import futures + +from grpc.beta import interfaces + + +def _sign_request(callback, token, error): + metadata = (('authorization', 'Bearer {}'.format(token)),) + callback(metadata, error) + + +class GoogleCallCredentials(interfaces.GRPCAuthMetadataPlugin): + """Metadata wrapper for GoogleCredentials from the oauth2client library.""" + + def __init__(self, credentials): + self._credentials = credentials + self._pool = futures.ThreadPoolExecutor(max_workers=1) + + def __call__(self, context, callback): + # MetadataPlugins cannot block (see grpc.beta.interfaces.py) + future = self._pool.submit(self._credentials.get_access_token) + future.add_done_callback(lambda x: self._get_token_callback(callback, x)) + + def _get_token_callback(self, callback, future): + try: + access_token = future.result().access_token + except Exception as e: + _sign_request(callback, None, e) + else: + _sign_request(callback, access_token, None) + + def __del__(self): + self._pool.shutdown(wait=False) + + +class AccessTokenCallCredentials(interfaces.GRPCAuthMetadataPlugin): + """Metadata wrapper for raw access token credentials.""" + + def __init__(self, access_token): + self._access_token = access_token + + def __call__(self, context, callback): + _sign_request(callback, self._access_token, None) diff --git a/src/python/grpcio/grpc/beta/implementations.py b/src/python/grpcio/grpc/beta/implementations.py index 822f593323..d8c32dd2f5 100644 --- a/src/python/grpcio/grpc/beta/implementations.py +++ b/src/python/grpcio/grpc/beta/implementations.py @@ -38,6 +38,7 @@ import threading # pylint: disable=unused-import from grpc._adapter import _intermediary_low from grpc._adapter import _low from grpc._adapter import _types +from grpc.beta import _auth from grpc.beta import _connectivity_channel from grpc.beta import _server from grpc.beta import _stub @@ -105,10 +106,40 @@ def metadata_call_credentials(metadata_plugin, name=None): A CallCredentials object for use in a GRPCCallOptions object. """ if name is None: - name = metadata_plugin.__name__ + try: + name = metadata_plugin.__name__ + except AttributeError: + name = metadata_plugin.__class__.__name__ return CallCredentials( _low.call_credentials_metadata_plugin(metadata_plugin, name)) + +def google_call_credentials(credentials): + """Construct CallCredentials from GoogleCredentials. + + Args: + credentials: A GoogleCredentials object from the oauth2client library. + + Returns: + A CallCredentials object for use in a GRPCCallOptions object. + """ + return metadata_call_credentials(_auth.GoogleCallCredentials(credentials)) + + +def access_token_call_credentials(access_token): + """Construct CallCredentials from an access token. + + Args: + access_token: A string to place directly in the http request + authorization header, ie "Authorization: Bearer <access_token>". + + Returns: + A CallCredentials object for use in a GRPCCallOptions object. + """ + return metadata_call_credentials( + _auth.AccessTokenCallCredentials(access_token)) + + def composite_call_credentials(call_credentials, additional_call_credentials): """Compose two CallCredentials to make a new one. diff --git a/src/python/grpcio/grpc_core_dependencies.py b/src/python/grpcio/grpc_core_dependencies.py index 50a6f196d8..59525f6183 100644 --- a/src/python/grpcio/grpc_core_dependencies.py +++ b/src/python/grpcio/grpc_core_dependencies.py @@ -42,7 +42,7 @@ CORE_SOURCE_FILES = [ 'src/core/lib/support/cpu_windows.c', 'src/core/lib/support/env_linux.c', 'src/core/lib/support/env_posix.c', - 'src/core/lib/support/env_win32.c', + 'src/core/lib/support/env_windows.c', 'src/core/lib/support/histogram.c', 'src/core/lib/support/host_port.c', 'src/core/lib/support/load_file.c', @@ -50,31 +50,31 @@ CORE_SOURCE_FILES = [ 'src/core/lib/support/log_android.c', 'src/core/lib/support/log_linux.c', 'src/core/lib/support/log_posix.c', - 'src/core/lib/support/log_win32.c', + 'src/core/lib/support/log_windows.c', 'src/core/lib/support/murmur_hash.c', 'src/core/lib/support/slice.c', 'src/core/lib/support/slice_buffer.c', 'src/core/lib/support/stack_lockfree.c', 'src/core/lib/support/string.c', 'src/core/lib/support/string_posix.c', - 'src/core/lib/support/string_util_win32.c', - 'src/core/lib/support/string_win32.c', + 'src/core/lib/support/string_util_windows.c', + 'src/core/lib/support/string_windows.c', 'src/core/lib/support/subprocess_posix.c', 'src/core/lib/support/subprocess_windows.c', 'src/core/lib/support/sync.c', 'src/core/lib/support/sync_posix.c', - 'src/core/lib/support/sync_win32.c', + 'src/core/lib/support/sync_windows.c', 'src/core/lib/support/thd.c', 'src/core/lib/support/thd_posix.c', - 'src/core/lib/support/thd_win32.c', + 'src/core/lib/support/thd_windows.c', 'src/core/lib/support/time.c', 'src/core/lib/support/time_posix.c', 'src/core/lib/support/time_precise.c', - 'src/core/lib/support/time_win32.c', + 'src/core/lib/support/time_windows.c', 'src/core/lib/support/tls_pthread.c', 'src/core/lib/support/tmpfile_msys.c', 'src/core/lib/support/tmpfile_posix.c', - 'src/core/lib/support/tmpfile_win32.c', + 'src/core/lib/support/tmpfile_windows.c', 'src/core/lib/support/wrap_memcpy.c', 'src/core/lib/surface/init.c', 'src/core/lib/channel/channel_args.c', @@ -95,6 +95,7 @@ CORE_SOURCE_FILES = [ 'src/core/lib/iomgr/endpoint_pair_posix.c', 'src/core/lib/iomgr/endpoint_pair_windows.c', 'src/core/lib/iomgr/ev_epoll_linux.c', + 'src/core/lib/iomgr/ev_poll_and_epoll_posix.c', 'src/core/lib/iomgr/ev_poll_posix.c', 'src/core/lib/iomgr/ev_posix.c', 'src/core/lib/iomgr/exec_ctx.c', @@ -189,7 +190,7 @@ CORE_SOURCE_FILES = [ 'src/core/lib/security/credentials/credentials_metadata.c', 'src/core/lib/security/credentials/fake/fake_credentials.c', 'src/core/lib/security/credentials/google_default/credentials_posix.c', - 'src/core/lib/security/credentials/google_default/credentials_win32.c', + 'src/core/lib/security/credentials/google_default/credentials_windows.c', 'src/core/lib/security/credentials/google_default/google_default_credentials.c', 'src/core/lib/security/credentials/iam/iam_credentials.c', 'src/core/lib/security/credentials/jwt/json_token.c', @@ -243,7 +244,10 @@ CORE_SOURCE_FILES = [ 'src/core/ext/lb_policy/round_robin/round_robin.c', 'src/core/ext/resolver/dns/native/dns_resolver.c', 'src/core/ext/resolver/sockaddr/sockaddr_resolver.c', + 'src/core/ext/load_reporting/load_reporting.c', + 'src/core/ext/load_reporting/load_reporting_filter.c', 'src/core/ext/census/context.c', + 'src/core/ext/census/gen/census.pb.c', 'src/core/ext/census/grpc_context.c', 'src/core/ext/census/grpc_filter.c', 'src/core/ext/census/grpc_plugin.c', diff --git a/src/python/grpcio/tests/interop/client.py b/src/python/grpcio/tests/interop/client.py index db29eb4aa7..e3d5545a02 100644 --- a/src/python/grpcio/tests/interop/client.py +++ b/src/python/grpcio/tests/interop/client.py @@ -65,39 +65,34 @@ def _args(): help='email address of the default service account', type=str) return parser.parse_args() -def _oauth_access_token(args): - credentials = oauth2client_client.GoogleCredentials.get_application_default() - scoped_credentials = credentials.create_scoped([args.oauth_scope]) - return scoped_credentials.get_access_token().access_token def _stub(args): - if args.oauth_scope: - if args.test_case == 'oauth2_auth_token': - # TODO(jtattermusch): This testcase sets the auth metadata key-value - # manually, which also means that the user would need to do the same - # thing every time he/she would like to use and out of band oauth token. - # The transformer function that produces the metadata key-value from - # the access token should be provided by gRPC auth library. - access_token = _oauth_access_token(args) - metadata_transformer = lambda x: [ - ('authorization', 'Bearer %s' % access_token)] - else: - metadata_transformer = lambda x: [ - ('authorization', 'Bearer %s' % _oauth_access_token(args))] + if args.test_case == 'oauth2_auth_token': + creds = oauth2client_client.GoogleCredentials.get_application_default() + scoped_creds = creds.create_scoped([args.oauth_scope]) + access_token = scoped_creds.get_access_token().access_token + call_creds = implementations.access_token_call_credentials(access_token) + elif args.test_case == 'compute_engine_creds': + creds = oauth2client_client.GoogleCredentials.get_application_default() + scoped_creds = creds.create_scoped([args.oauth_scope]) + call_creds = implementations.google_call_credentials(scoped_creds) else: - metadata_transformer = lambda x: [] + call_creds = None if args.use_tls: if args.use_test_ca: root_certificates = resources.test_root_certificates() else: root_certificates = None # will load default roots. + channel_creds = implementations.ssl_channel_credentials(root_certificates) + if call_creds is not None: + channel_creds = implementations.composite_channel_credentials( + channel_creds, call_creds) + channel = test_utilities.not_really_secure_channel( - args.server_host, args.server_port, - implementations.ssl_channel_credentials(root_certificates), + args.server_host, args.server_port, channel_creds, args.server_host_override) - stub = test_pb2.beta_create_TestService_stub( - channel, metadata_transformer=metadata_transformer) + stub = test_pb2.beta_create_TestService_stub(channel) else: channel = implementations.insecure_channel( args.server_host, args.server_port) diff --git a/src/python/grpcio/tests/interop/methods.py b/src/python/grpcio/tests/interop/methods.py index 67862ed7d3..d5ef0c68bb 100644 --- a/src/python/grpcio/tests/interop/methods.py +++ b/src/python/grpcio/tests/interop/methods.py @@ -39,6 +39,8 @@ import time from oauth2client import client as oauth2client_client +from grpc.beta import implementations +from grpc.beta import interfaces from grpc.framework.common import cardinality from grpc.framework.interfaces.face import face @@ -88,13 +90,15 @@ class TestService(test_pb2.BetaTestServiceServicer): return self.FullDuplexCall(request_iterator, context) -def _large_unary_common_behavior(stub, fill_username, fill_oauth_scope): +def _large_unary_common_behavior(stub, fill_username, fill_oauth_scope, + protocol_options=None): with stub: request = messages_pb2.SimpleRequest( response_type=messages_pb2.COMPRESSABLE, response_size=314159, payload=messages_pb2.Payload(body=b'\x00' * 271828), fill_username=fill_username, fill_oauth_scope=fill_oauth_scope) - response_future = stub.UnaryCall.future(request, _TIMEOUT) + response_future = stub.UnaryCall.future(request, _TIMEOUT, + protocol_options=protocol_options) response = response_future.result() if response.payload.type is not messages_pb2.COMPRESSABLE: raise ValueError( @@ -303,7 +307,24 @@ def _oauth2_auth_token(stub, args): if args.oauth_scope.find(response.oauth_scope) == -1: raise ValueError( 'expected to find oauth scope "%s" in received "%s"' % - (response.oauth_scope, args.oauth_scope)) + (response.oauth_scope, args.oauth_scope)) + + +def _per_rpc_creds(stub, args): + json_key_filename = os.environ[ + oauth2client_client.GOOGLE_APPLICATION_CREDENTIALS] + wanted_email = json.load(open(json_key_filename, 'rb'))['client_email'] + credentials = oauth2client_client.GoogleCredentials.get_application_default() + scoped_credentials = credentials.create_scoped([args.oauth_scope]) + call_creds = implementations.google_call_credentials(scoped_credentials) + options = interfaces.grpc_call_options(disable_compression=False, + credentials=call_creds) + response = _large_unary_common_behavior(stub, True, False, + protocol_options=options) + if wanted_email != response.username: + raise ValueError( + 'expected username %s, got %s' % (wanted_email, response.username)) + @enum.unique class TestCase(enum.Enum): @@ -317,6 +338,7 @@ class TestCase(enum.Enum): EMPTY_STREAM = 'empty_stream' COMPUTE_ENGINE_CREDS = 'compute_engine_creds' OAUTH2_AUTH_TOKEN = 'oauth2_auth_token' + PER_RPC_CREDS = 'per_rpc_creds' TIMEOUT_ON_SLEEPING_SERVER = 'timeout_on_sleeping_server' def test_interoperability(self, stub, args): @@ -342,5 +364,7 @@ class TestCase(enum.Enum): _compute_engine_creds(stub, args) elif self is TestCase.OAUTH2_AUTH_TOKEN: _oauth2_auth_token(stub, args) + elif self is TestCase.PER_RPC_CREDS: + _per_rpc_creds(stub, args) else: raise NotImplementedError('Test case "%s" not implemented!' % self.name) diff --git a/src/python/grpcio/tests/tests.json b/src/python/grpcio/tests/tests.json index 1beb619f87..8dc47bf69d 100644 --- a/src/python/grpcio/tests/tests.json +++ b/src/python/grpcio/tests/tests.json @@ -1,4 +1,6 @@ [ + "_auth_test.AccessTokenCallCredentialsTest", + "_auth_test.GoogleCallCredentialsTest", "_base_interface_test.AsyncEasyTest", "_base_interface_test.AsyncPeasyTest", "_base_interface_test.SyncEasyTest", @@ -6,6 +8,8 @@ "_beta_features_test.BetaFeaturesTest", "_beta_features_test.ContextManagementAndLifecycleTest", "_cancel_many_calls_test.CancelManyCallsTest", + "_channel_connectivity_test.ChannelConnectivityTest", + "_channel_ready_future_test.ChannelReadyFutureTest", "_channel_test.ChannelTest", "_connectivity_channel_test.ChannelConnectivityTest", "_core_over_links_base_interface_test.AsyncEasyTest", @@ -31,6 +35,7 @@ "_face_interface_test.MultiCallableInvokerBlockingInvocationInlineServiceTest", "_face_interface_test.MultiCallableInvokerFutureInvocationAsynchronousEventServiceTest", "_health_servicer_test.HealthServicerTest", + "_implementations_test.CallCredentialsTest", "_implementations_test.ChannelCredentialsTest", "_insecure_interop_test.InsecureInteropTest", "_intermediary_low_test.CancellationTest", @@ -43,6 +48,8 @@ "_low_test.HangingServerShutdown", "_low_test.InsecureServerInsecureClient", "_not_found_test.NotFoundTest", + "_read_some_but_not_all_responses_test.ReadSomeButNotAllResponsesTest", + "_rpc_test.RPCTest", "_sanity_test.Sanity", "_secure_interop_test.SecureInteropTest", "_transmission_test.RoundTripTest", diff --git a/src/python/grpcio/tests/unit/_channel_connectivity_test.py b/src/python/grpcio/tests/unit/_channel_connectivity_test.py new file mode 100644 index 0000000000..a1575efada --- /dev/null +++ b/src/python/grpcio/tests/unit/_channel_connectivity_test.py @@ -0,0 +1,161 @@ +# Copyright 2015, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests of grpc._channel.Channel connectivity.""" + +import threading +import time +import unittest +from concurrent import futures + +import grpc +from grpc import _channel +from grpc import _server +from tests.unit.framework.common import test_constants + + +def _ready_in_connectivities(connectivities): + return grpc.ChannelConnectivity.READY in connectivities + + +def _last_connectivity_is_not_ready(connectivities): + return connectivities[-1] is not grpc.ChannelConnectivity.READY + + +class _Callback(object): + + def __init__(self): + self._condition = threading.Condition() + self._connectivities = [] + + def update(self, connectivity): + with self._condition: + self._connectivities.append(connectivity) + self._condition.notify() + + def connectivities(self): + with self._condition: + return tuple(self._connectivities) + + def block_until_connectivities_satisfy(self, predicate): + with self._condition: + while True: + connectivities = tuple(self._connectivities) + if predicate(connectivities): + return connectivities + else: + self._condition.wait() + + +class ChannelConnectivityTest(unittest.TestCase): + + def test_lonely_channel_connectivity(self): + callback = _Callback() + + channel = _channel.Channel('localhost:12345', None, None) + channel.subscribe(callback.update, try_to_connect=False) + first_connectivities = callback.block_until_connectivities_satisfy(bool) + channel.subscribe(callback.update, try_to_connect=True) + second_connectivities = callback.block_until_connectivities_satisfy( + lambda connectivities: 2 <= len(connectivities)) + # Wait for a connection that will never happen. + time.sleep(test_constants.SHORT_TIMEOUT) + third_connectivities = callback.connectivities() + channel.unsubscribe(callback.update) + fourth_connectivities = callback.connectivities() + channel.unsubscribe(callback.update) + fifth_connectivities = callback.connectivities() + + self.assertSequenceEqual( + (grpc.ChannelConnectivity.IDLE,), first_connectivities) + self.assertNotIn( + grpc.ChannelConnectivity.READY, second_connectivities) + self.assertNotIn( + grpc.ChannelConnectivity.READY, third_connectivities) + self.assertNotIn( + grpc.ChannelConnectivity.READY, fourth_connectivities) + self.assertNotIn( + grpc.ChannelConnectivity.READY, fifth_connectivities) + + def test_immediately_connectable_channel_connectivity(self): + server = _server.Server((), futures.ThreadPoolExecutor(max_workers=0)) + port = server.add_insecure_port('[::]:0') + server.start() + first_callback = _Callback() + second_callback = _Callback() + + channel = _channel.Channel('localhost:{}'.format(port), None, None) + channel.subscribe(first_callback.update, try_to_connect=False) + first_connectivities = first_callback.block_until_connectivities_satisfy( + bool) + # Wait for a connection that will never happen because try_to_connect=True + # has not yet been passed. + time.sleep(test_constants.SHORT_TIMEOUT) + second_connectivities = first_callback.connectivities() + channel.subscribe(second_callback.update, try_to_connect=True) + third_connectivities = first_callback.block_until_connectivities_satisfy( + lambda connectivities: 2 <= len(connectivities)) + fourth_connectivities = second_callback.block_until_connectivities_satisfy( + bool) + # Wait for a connection that will happen (or may already have happened). + first_callback.block_until_connectivities_satisfy(_ready_in_connectivities) + second_callback.block_until_connectivities_satisfy(_ready_in_connectivities) + del channel + + self.assertSequenceEqual( + (grpc.ChannelConnectivity.IDLE,), first_connectivities) + self.assertSequenceEqual( + (grpc.ChannelConnectivity.IDLE,), second_connectivities) + self.assertNotIn( + grpc.ChannelConnectivity.TRANSIENT_FAILURE, third_connectivities) + self.assertNotIn( + grpc.ChannelConnectivity.FATAL_FAILURE, third_connectivities) + self.assertNotIn( + grpc.ChannelConnectivity.TRANSIENT_FAILURE, + fourth_connectivities) + self.assertNotIn( + grpc.ChannelConnectivity.FATAL_FAILURE, fourth_connectivities) + + def test_reachable_then_unreachable_channel_connectivity(self): + server = _server.Server((), futures.ThreadPoolExecutor(max_workers=0)) + port = server.add_insecure_port('[::]:0') + server.start() + callback = _Callback() + + channel = _channel.Channel('localhost:{}'.format(port), None, None) + channel.subscribe(callback.update, try_to_connect=True) + callback.block_until_connectivities_satisfy(_ready_in_connectivities) + # Now take down the server and confirm that channel readiness is repudiated. + server.stop(None) + callback.block_until_connectivities_satisfy(_last_connectivity_is_not_ready) + channel.unsubscribe(callback.update) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/src/python/grpcio/tests/unit/_channel_ready_future_test.py b/src/python/grpcio/tests/unit/_channel_ready_future_test.py new file mode 100644 index 0000000000..b84bc0197a --- /dev/null +++ b/src/python/grpcio/tests/unit/_channel_ready_future_test.py @@ -0,0 +1,103 @@ +# Copyright 2015, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests of grpc.channel_ready_future.""" + +import threading +import unittest +from concurrent import futures + +import grpc +from grpc import _channel +from grpc import _server +from tests.unit.framework.common import test_constants + + +class _Callback(object): + + def __init__(self): + self._condition = threading.Condition() + self._value = None + + def accept_value(self, value): + with self._condition: + self._value = value + self._condition.notify_all() + + def block_until_called(self): + with self._condition: + while self._value is None: + self._condition.wait() + return self._value + + +class ChannelReadyFutureTest(unittest.TestCase): + + def test_lonely_channel_connectivity(self): + channel = grpc.insecure_channel('localhost:12345') + callback = _Callback() + + ready_future = grpc.channel_ready_future(channel) + ready_future.add_done_callback(callback.accept_value) + with self.assertRaises(grpc.FutureTimeoutError): + ready_future.result(test_constants.SHORT_TIMEOUT) + self.assertFalse(ready_future.cancelled()) + self.assertFalse(ready_future.done()) + self.assertTrue(ready_future.running()) + ready_future.cancel() + value_passed_to_callback = callback.block_until_called() + self.assertIs(ready_future, value_passed_to_callback) + self.assertTrue(ready_future.cancelled()) + self.assertTrue(ready_future.done()) + self.assertFalse(ready_future.running()) + + def test_immediately_connectable_channel_connectivity(self): + server = _server.Server((), futures.ThreadPoolExecutor(max_workers=0)) + port = server.add_insecure_port('[::]:0') + server.start() + channel = grpc.insecure_channel('localhost:{}'.format(port)) + callback = _Callback() + + ready_future = grpc.channel_ready_future(channel) + ready_future.add_done_callback(callback.accept_value) + self.assertIsNone(ready_future.result(test_constants.SHORT_TIMEOUT)) + value_passed_to_callback = callback.block_until_called() + self.assertIs(ready_future, value_passed_to_callback) + self.assertFalse(ready_future.cancelled()) + self.assertTrue(ready_future.done()) + self.assertFalse(ready_future.running()) + # Cancellation after maturity has no effect. + ready_future.cancel() + self.assertFalse(ready_future.cancelled()) + self.assertTrue(ready_future.done()) + self.assertFalse(ready_future.running()) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/src/python/grpcio/tests/unit/_cython/_read_some_but_not_all_responses_test.py b/src/python/grpcio/tests/unit/_cython/_read_some_but_not_all_responses_test.py new file mode 100644 index 0000000000..6ae7a90fbe --- /dev/null +++ b/src/python/grpcio/tests/unit/_cython/_read_some_but_not_all_responses_test.py @@ -0,0 +1,251 @@ +# Copyright 2016, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Test a corner-case at the level of the Cython API.""" + +import threading +import unittest + +from grpc._cython import cygrpc + +_INFINITE_FUTURE = cygrpc.Timespec(float('+inf')) +_EMPTY_FLAGS = 0 +_EMPTY_METADATA = cygrpc.Metadata(()) + + +class _ServerDriver(object): + + def __init__(self, completion_queue, shutdown_tag): + self._condition = threading.Condition() + self._completion_queue = completion_queue + self._shutdown_tag = shutdown_tag + self._events = [] + self._saw_shutdown_tag = False + + def start(self): + def in_thread(): + while True: + event = self._completion_queue.poll() + with self._condition: + self._events.append(event) + self._condition.notify() + if event.tag is self._shutdown_tag: + self._saw_shutdown_tag = True + break + thread = threading.Thread(target=in_thread) + thread.start() + + def done(self): + with self._condition: + return self._saw_shutdown_tag + + def first_event(self): + with self._condition: + while not self._events: + self._condition.wait() + return self._events[0] + + def events(self): + with self._condition: + while not self._saw_shutdown_tag: + self._condition.wait() + return tuple(self._events) + + +class _QueueDriver(object): + + def __init__(self, condition, completion_queue, due): + self._condition = condition + self._completion_queue = completion_queue + self._due = due + self._events = [] + self._returned = False + + def start(self): + def in_thread(): + while True: + event = self._completion_queue.poll() + with self._condition: + self._events.append(event) + self._due.remove(event.tag) + self._condition.notify_all() + if not self._due: + self._returned = True + return + thread = threading.Thread(target=in_thread) + thread.start() + + def done(self): + with self._condition: + return self._returned + + def event_with_tag(self, tag): + with self._condition: + while True: + for event in self._events: + if event.tag is tag: + return event + self._condition.wait() + + def events(self): + with self._condition: + while not self._returned: + self._condition.wait() + return tuple(self._events) + + +class ReadSomeButNotAllResponsesTest(unittest.TestCase): + + def testReadSomeButNotAllResponses(self): + server_completion_queue = cygrpc.CompletionQueue() + server = cygrpc.Server() + server.register_completion_queue(server_completion_queue) + port = server.add_http2_port('[::]:0') + server.start() + channel = cygrpc.Channel('localhost:{}'.format(port)) + + server_shutdown_tag = 'server_shutdown_tag' + server_driver = _ServerDriver(server_completion_queue, server_shutdown_tag) + server_driver.start() + + client_condition = threading.Condition() + client_due = set() + client_completion_queue = cygrpc.CompletionQueue() + client_driver = _QueueDriver( + client_condition, client_completion_queue, client_due) + client_driver.start() + + server_call_condition = threading.Condition() + server_send_initial_metadata_tag = 'server_send_initial_metadata_tag' + server_send_first_message_tag = 'server_send_first_message_tag' + server_send_second_message_tag = 'server_send_second_message_tag' + server_complete_rpc_tag = 'server_complete_rpc_tag' + server_call_due = set(( + server_send_initial_metadata_tag, + server_send_first_message_tag, + server_send_second_message_tag, + server_complete_rpc_tag, + )) + server_call_completion_queue = cygrpc.CompletionQueue() + server_call_driver = _QueueDriver( + server_call_condition, server_call_completion_queue, server_call_due) + server_call_driver.start() + + server_rpc_tag = 'server_rpc_tag' + request_call_result = server.request_call( + server_call_completion_queue, server_completion_queue, server_rpc_tag) + + client_call = channel.create_call( + None, _EMPTY_FLAGS, client_completion_queue, b'/twinkies', None, + _INFINITE_FUTURE) + client_receive_initial_metadata_tag = 'client_receive_initial_metadata_tag' + client_complete_rpc_tag = 'client_complete_rpc_tag' + with client_condition: + client_receive_initial_metadata_start_batch_result = ( + client_call.start_batch(cygrpc.Operations([ + cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS), + ]), client_receive_initial_metadata_tag)) + client_due.add(client_receive_initial_metadata_tag) + client_complete_rpc_start_batch_result = ( + client_call.start_batch(cygrpc.Operations([ + cygrpc.operation_send_initial_metadata( + _EMPTY_METADATA, _EMPTY_FLAGS), + cygrpc.operation_send_close_from_client(_EMPTY_FLAGS), + cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS), + ]), client_complete_rpc_tag)) + client_due.add(client_complete_rpc_tag) + + server_rpc_event = server_driver.first_event() + + with server_call_condition: + server_send_initial_metadata_start_batch_result = ( + server_rpc_event.operation_call.start_batch(cygrpc.Operations([ + cygrpc.operation_send_initial_metadata( + _EMPTY_METADATA, _EMPTY_FLAGS), + ]), server_send_initial_metadata_tag)) + server_send_first_message_start_batch_result = ( + server_rpc_event.operation_call.start_batch(cygrpc.Operations([ + cygrpc.operation_send_message(b'\x07', _EMPTY_FLAGS), + ]), server_send_first_message_tag)) + server_send_initial_metadata_event = server_call_driver.event_with_tag( + server_send_initial_metadata_tag) + server_send_first_message_event = server_call_driver.event_with_tag( + server_send_first_message_tag) + with server_call_condition: + server_send_second_message_start_batch_result = ( + server_rpc_event.operation_call.start_batch(cygrpc.Operations([ + cygrpc.operation_send_message(b'\x07', _EMPTY_FLAGS), + ]), server_send_second_message_tag)) + server_complete_rpc_start_batch_result = ( + server_rpc_event.operation_call.start_batch(cygrpc.Operations([ + cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS), + cygrpc.operation_send_status_from_server( + cygrpc.Metadata(()), cygrpc.StatusCode.ok, b'test details', + _EMPTY_FLAGS), + ]), server_complete_rpc_tag)) + server_send_second_message_event = server_call_driver.event_with_tag( + server_send_second_message_tag) + server_complete_rpc_event = server_call_driver.event_with_tag( + server_complete_rpc_tag) + server_call_driver.events() + + with client_condition: + client_receive_first_message_tag = 'client_receive_first_message_tag' + client_receive_first_message_start_batch_result = ( + client_call.start_batch(cygrpc.Operations([ + cygrpc.operation_receive_message(_EMPTY_FLAGS), + ]), client_receive_first_message_tag)) + client_due.add(client_receive_first_message_tag) + client_receive_first_message_event = client_driver.event_with_tag( + client_receive_first_message_tag) + + client_call_cancel_result = client_call.cancel() + client_driver.events() + + server.shutdown(server_completion_queue, server_shutdown_tag) + server.cancel_all_calls() + server_driver.events() + + self.assertEqual(cygrpc.CallError.ok, request_call_result) + self.assertEqual( + cygrpc.CallError.ok, server_send_initial_metadata_start_batch_result) + self.assertEqual( + cygrpc.CallError.ok, client_receive_initial_metadata_start_batch_result) + self.assertEqual( + cygrpc.CallError.ok, client_complete_rpc_start_batch_result) + self.assertEqual(cygrpc.CallError.ok, client_call_cancel_result) + self.assertIs(server_rpc_tag, server_rpc_event.tag) + self.assertEqual( + cygrpc.CompletionType.operation_complete, server_rpc_event.type) + self.assertIsInstance(server_rpc_event.operation_call, cygrpc.Call) + self.assertEqual(0, len(server_rpc_event.batch_operations)) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/src/python/grpcio/tests/unit/_rpc_test.py b/src/python/grpcio/tests/unit/_rpc_test.py new file mode 100644 index 0000000000..1c7a14c5d0 --- /dev/null +++ b/src/python/grpcio/tests/unit/_rpc_test.py @@ -0,0 +1,775 @@ +# Copyright 2016, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Test of gRPC Python's application-layer API.""" + +import itertools +import threading +import unittest +from concurrent import futures + +import grpc +from grpc.framework.foundation import logging_pool + +from tests.unit.framework.common import test_constants +from tests.unit.framework.common import test_control + +_SERIALIZE_REQUEST = lambda bytestring: bytestring * 2 +_DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) / 2:] +_SERIALIZE_RESPONSE = lambda bytestring: bytestring * 3 +_DESERIALIZE_RESPONSE = lambda bytestring: bytestring[:len(bytestring) / 3] + +_UNARY_UNARY = b'/test/UnaryUnary' +_UNARY_STREAM = b'/test/UnaryStream' +_STREAM_UNARY = b'/test/StreamUnary' +_STREAM_STREAM = b'/test/StreamStream' + + +class _Callback(object): + + def __init__(self): + self._condition = threading.Condition() + self._value = None + self._called = False + + def __call__(self, value): + with self._condition: + self._value = value + self._called = True + self._condition.notify_all() + + def value(self): + with self._condition: + while not self._called: + self._condition.wait() + return self._value + + +class _Handler(object): + + def __init__(self, control): + self._control = control + + def handle_unary_unary(self, request, servicer_context): + self._control.control() + if servicer_context is not None: + servicer_context.set_trailing_metadata(((b'testkey', b'testvalue',),)) + return request + + def handle_unary_stream(self, request, servicer_context): + for _ in range(test_constants.STREAM_LENGTH): + self._control.control() + yield request + self._control.control() + if servicer_context is not None: + servicer_context.set_trailing_metadata(((b'testkey', b'testvalue',),)) + + def handle_stream_unary(self, request_iterator, servicer_context): + if servicer_context is not None: + servicer_context.invocation_metadata() + self._control.control() + response_elements = [] + for request in request_iterator: + self._control.control() + response_elements.append(request) + self._control.control() + if servicer_context is not None: + servicer_context.set_trailing_metadata(((b'testkey', b'testvalue',),)) + return b''.join(response_elements) + + def handle_stream_stream(self, request_iterator, servicer_context): + self._control.control() + if servicer_context is not None: + servicer_context.set_trailing_metadata(((b'testkey', b'testvalue',),)) + for request in request_iterator: + self._control.control() + yield request + self._control.control() + + +class _MethodHandler(grpc.RpcMethodHandler): + + def __init__( + self, request_streaming, response_streaming, request_deserializer, + response_serializer, unary_unary, unary_stream, stream_unary, + stream_stream): + self.request_streaming = request_streaming + self.response_streaming = response_streaming + self.request_deserializer = request_deserializer + self.response_serializer = response_serializer + self.unary_unary = unary_unary + self.unary_stream = unary_stream + self.stream_unary = stream_unary + self.stream_stream = stream_stream + + +class _GenericHandler(grpc.GenericRpcHandler): + + def __init__(self, handler): + self._handler = handler + + def service(self, handler_call_details): + if handler_call_details.method == _UNARY_UNARY: + return _MethodHandler( + False, False, None, None, self._handler.handle_unary_unary, None, + None, None) + elif handler_call_details.method == _UNARY_STREAM: + return _MethodHandler( + False, True, _DESERIALIZE_REQUEST, _SERIALIZE_RESPONSE, None, + self._handler.handle_unary_stream, None, None) + elif handler_call_details.method == _STREAM_UNARY: + return _MethodHandler( + True, False, _DESERIALIZE_REQUEST, _SERIALIZE_RESPONSE, None, None, + self._handler.handle_stream_unary, None) + elif handler_call_details.method == _STREAM_STREAM: + return _MethodHandler( + True, True, None, None, None, None, None, + self._handler.handle_stream_stream) + else: + return None + + +def _unary_unary_multi_callable(channel): + return channel.unary_unary(_UNARY_UNARY) + + +def _unary_stream_multi_callable(channel): + return channel.unary_stream( + _UNARY_STREAM, + request_serializer=_SERIALIZE_REQUEST, + response_deserializer=_DESERIALIZE_RESPONSE) + + +def _stream_unary_multi_callable(channel): + return channel.stream_unary( + _STREAM_UNARY, + request_serializer=_SERIALIZE_REQUEST, + response_deserializer=_DESERIALIZE_RESPONSE) + + +def _stream_stream_multi_callable(channel): + return channel.stream_stream(_STREAM_STREAM) + + +class RPCTest(unittest.TestCase): + + def setUp(self): + self._control = test_control.PauseFailControl() + self._handler = _Handler(self._control) + self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY) + + self._server = grpc.server((), self._server_pool) + port = self._server.add_insecure_port(b'[::]:0') + self._server.add_generic_rpc_handlers((_GenericHandler(self._handler),)) + self._server.start() + + self._channel = grpc.insecure_channel(b'localhost:%d' % port) + + # TODO(nathaniel): Why is this necessary, and only in some development + # environments? + def tearDown(self): + del self._channel + del self._server + del self._server_pool + + def testUnrecognizedMethod(self): + request = b'abc' + + with self.assertRaises(grpc.RpcError) as exception_context: + self._channel.unary_unary(b'NoSuchMethod')(request) + + self.assertEqual( + grpc.StatusCode.UNIMPLEMENTED, exception_context.exception.code()) + + def testSuccessfulUnaryRequestBlockingUnaryResponse(self): + request = b'\x07\x08' + expected_response = self._handler.handle_unary_unary(request, None) + + multi_callable = _unary_unary_multi_callable(self._channel) + response = multi_callable( + request, metadata=( + (b'test', b'SuccessfulUnaryRequestBlockingUnaryResponse'),)) + + self.assertEqual(expected_response, response) + + def testSuccessfulUnaryRequestBlockingUnaryResponseWithCall(self): + request = b'\x07\x08' + expected_response = self._handler.handle_unary_unary(request, None) + + multi_callable = _unary_unary_multi_callable(self._channel) + response, call = multi_callable( + request, metadata=( + (b'test', b'SuccessfulUnaryRequestBlockingUnaryResponseWithCall'),), + with_call=True) + + self.assertEqual(expected_response, response) + self.assertIs(grpc.StatusCode.OK, call.code()) + + def testSuccessfulUnaryRequestFutureUnaryResponse(self): + request = b'\x07\x08' + expected_response = self._handler.handle_unary_unary(request, None) + + multi_callable = _unary_unary_multi_callable(self._channel) + response_future = multi_callable.future( + request, metadata=( + (b'test', b'SuccessfulUnaryRequestFutureUnaryResponse'),)) + response = response_future.result() + + self.assertEqual(expected_response, response) + + def testSuccessfulUnaryRequestStreamResponse(self): + request = b'\x37\x58' + expected_responses = tuple(self._handler.handle_unary_stream(request, None)) + + multi_callable = _unary_stream_multi_callable(self._channel) + response_iterator = multi_callable( + request, + metadata=((b'test', b'SuccessfulUnaryRequestStreamResponse'),)) + responses = tuple(response_iterator) + + self.assertSequenceEqual(expected_responses, responses) + + def testSuccessfulStreamRequestBlockingUnaryResponse(self): + requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + expected_response = self._handler.handle_stream_unary(iter(requests), None) + request_iterator = iter(requests) + + multi_callable = _stream_unary_multi_callable(self._channel) + response = multi_callable( + request_iterator, + metadata=((b'test', b'SuccessfulStreamRequestBlockingUnaryResponse'),)) + + self.assertEqual(expected_response, response) + + def testSuccessfulStreamRequestBlockingUnaryResponseWithCall(self): + requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + expected_response = self._handler.handle_stream_unary(iter(requests), None) + request_iterator = iter(requests) + + multi_callable = _stream_unary_multi_callable(self._channel) + response, call = multi_callable( + request_iterator, + metadata=( + (b'test', b'SuccessfulStreamRequestBlockingUnaryResponseWithCall'), + ), with_call=True) + + self.assertEqual(expected_response, response) + self.assertIs(grpc.StatusCode.OK, call.code()) + + def testSuccessfulStreamRequestFutureUnaryResponse(self): + requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + expected_response = self._handler.handle_stream_unary(iter(requests), None) + request_iterator = iter(requests) + + multi_callable = _stream_unary_multi_callable(self._channel) + response_future = multi_callable.future( + request_iterator, + metadata=( + (b'test', b'SuccessfulStreamRequestFutureUnaryResponse'),)) + response = response_future.result() + + self.assertEqual(expected_response, response) + + def testSuccessfulStreamRequestStreamResponse(self): + requests = tuple(b'\x77\x58' for _ in range(test_constants.STREAM_LENGTH)) + expected_responses = tuple( + self._handler.handle_stream_stream(iter(requests), None)) + request_iterator = iter(requests) + + multi_callable = _stream_stream_multi_callable(self._channel) + response_iterator = multi_callable( + request_iterator, + metadata=((b'test', b'SuccessfulStreamRequestStreamResponse'),)) + responses = tuple(response_iterator) + + self.assertSequenceEqual(expected_responses, responses) + + def testSequentialInvocations(self): + first_request = b'\x07\x08' + second_request = b'\x0809' + expected_first_response = self._handler.handle_unary_unary( + first_request, None) + expected_second_response = self._handler.handle_unary_unary( + second_request, None) + + multi_callable = _unary_unary_multi_callable(self._channel) + first_response = multi_callable( + first_request, metadata=((b'test', b'SequentialInvocations'),)) + second_response = multi_callable( + second_request, metadata=((b'test', b'SequentialInvocations'),)) + + self.assertEqual(expected_first_response, first_response) + self.assertEqual(expected_second_response, second_response) + + def testConcurrentBlockingInvocations(self): + pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY) + requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + expected_response = self._handler.handle_stream_unary(iter(requests), None) + expected_responses = [expected_response] * test_constants.THREAD_CONCURRENCY + response_futures = [None] * test_constants.THREAD_CONCURRENCY + + multi_callable = _stream_unary_multi_callable(self._channel) + for index in range(test_constants.THREAD_CONCURRENCY): + request_iterator = iter(requests) + response_future = pool.submit( + multi_callable, request_iterator, + metadata=((b'test', b'ConcurrentBlockingInvocations'),)) + response_futures[index] = response_future + responses = tuple( + response_future.result() for response_future in response_futures) + + pool.shutdown(wait=True) + self.assertSequenceEqual(expected_responses, responses) + + def testConcurrentFutureInvocations(self): + requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + expected_response = self._handler.handle_stream_unary(iter(requests), None) + expected_responses = [expected_response] * test_constants.THREAD_CONCURRENCY + response_futures = [None] * test_constants.THREAD_CONCURRENCY + + multi_callable = _stream_unary_multi_callable(self._channel) + for index in range(test_constants.THREAD_CONCURRENCY): + request_iterator = iter(requests) + response_future = multi_callable.future( + request_iterator, + metadata=((b'test', b'ConcurrentFutureInvocations'),)) + response_futures[index] = response_future + responses = tuple( + response_future.result() for response_future in response_futures) + + self.assertSequenceEqual(expected_responses, responses) + + def testWaitingForSomeButNotAllConcurrentFutureInvocations(self): + pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY) + request = b'\x67\x68' + expected_response = self._handler.handle_unary_unary(request, None) + response_futures = [None] * test_constants.THREAD_CONCURRENCY + lock = threading.Lock() + test_is_running_cell = [True] + def wrap_future(future): + def wrap(): + try: + return future.result() + except grpc.RpcError: + with lock: + if test_is_running_cell[0]: + raise + return None + return wrap + + multi_callable = _unary_unary_multi_callable(self._channel) + for index in range(test_constants.THREAD_CONCURRENCY): + inner_response_future = multi_callable.future( + request, + metadata=( + (b'test', + b'WaitingForSomeButNotAllConcurrentFutureInvocations'),)) + outer_response_future = pool.submit(wrap_future(inner_response_future)) + response_futures[index] = outer_response_future + + some_completed_response_futures_iterator = itertools.islice( + futures.as_completed(response_futures), + test_constants.THREAD_CONCURRENCY // 2) + for response_future in some_completed_response_futures_iterator: + self.assertEqual(expected_response, response_future.result()) + with lock: + test_is_running_cell[0] = False + + def testConsumingOneStreamResponseUnaryRequest(self): + request = b'\x57\x38' + + multi_callable = _unary_stream_multi_callable(self._channel) + response_iterator = multi_callable( + request, + metadata=( + (b'test', b'ConsumingOneStreamResponseUnaryRequest'),)) + next(response_iterator) + + def testConsumingSomeButNotAllStreamResponsesUnaryRequest(self): + request = b'\x57\x38' + + multi_callable = _unary_stream_multi_callable(self._channel) + response_iterator = multi_callable( + request, + metadata=( + (b'test', b'ConsumingSomeButNotAllStreamResponsesUnaryRequest'),)) + for _ in range(test_constants.STREAM_LENGTH // 2): + next(response_iterator) + + def testConsumingSomeButNotAllStreamResponsesStreamRequest(self): + requests = tuple(b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + multi_callable = _stream_stream_multi_callable(self._channel) + response_iterator = multi_callable( + request_iterator, + metadata=( + (b'test', b'ConsumingSomeButNotAllStreamResponsesStreamRequest'),)) + for _ in range(test_constants.STREAM_LENGTH // 2): + next(response_iterator) + + def testConsumingTooManyStreamResponsesStreamRequest(self): + requests = tuple(b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + multi_callable = _stream_stream_multi_callable(self._channel) + response_iterator = multi_callable( + request_iterator, + metadata=( + (b'test', b'ConsumingTooManyStreamResponsesStreamRequest'),)) + for _ in range(test_constants.STREAM_LENGTH): + next(response_iterator) + for _ in range(test_constants.STREAM_LENGTH): + with self.assertRaises(StopIteration): + next(response_iterator) + + self.assertIsNotNone(response_iterator.initial_metadata()) + self.assertIs(grpc.StatusCode.OK, response_iterator.code()) + self.assertIsNotNone(response_iterator.details()) + self.assertIsNotNone(response_iterator.trailing_metadata()) + + def testCancelledUnaryRequestUnaryResponse(self): + request = b'\x07\x17' + + multi_callable = _unary_unary_multi_callable(self._channel) + with self._control.pause(): + response_future = multi_callable.future( + request, + metadata=((b'test', b'CancelledUnaryRequestUnaryResponse'),)) + response_future.cancel() + + self.assertTrue(response_future.cancelled()) + with self.assertRaises(grpc.FutureCancelledError): + response_future.result() + self.assertIs(grpc.StatusCode.CANCELLED, response_future.code()) + + def testCancelledUnaryRequestStreamResponse(self): + request = b'\x07\x19' + + multi_callable = _unary_stream_multi_callable(self._channel) + with self._control.pause(): + response_iterator = multi_callable( + request, + metadata=((b'test', b'CancelledUnaryRequestStreamResponse'),)) + self._control.block_until_paused() + response_iterator.cancel() + + with self.assertRaises(grpc.RpcError) as exception_context: + next(response_iterator) + self.assertIs(grpc.StatusCode.CANCELLED, exception_context.exception.code()) + self.assertIsNotNone(response_iterator.initial_metadata()) + self.assertIs(grpc.StatusCode.CANCELLED, response_iterator.code()) + self.assertIsNotNone(response_iterator.details()) + self.assertIsNotNone(response_iterator.trailing_metadata()) + + def testCancelledStreamRequestUnaryResponse(self): + requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + multi_callable = _stream_unary_multi_callable(self._channel) + with self._control.pause(): + response_future = multi_callable.future( + request_iterator, + metadata=((b'test', b'CancelledStreamRequestUnaryResponse'),)) + self._control.block_until_paused() + response_future.cancel() + + self.assertTrue(response_future.cancelled()) + with self.assertRaises(grpc.FutureCancelledError): + response_future.result() + self.assertIsNotNone(response_future.initial_metadata()) + self.assertIs(grpc.StatusCode.CANCELLED, response_future.code()) + self.assertIsNotNone(response_future.details()) + self.assertIsNotNone(response_future.trailing_metadata()) + + def testCancelledStreamRequestStreamResponse(self): + requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + multi_callable = _stream_stream_multi_callable(self._channel) + with self._control.pause(): + response_iterator = multi_callable( + request_iterator, + metadata=((b'test', b'CancelledStreamRequestStreamResponse'),)) + response_iterator.cancel() + + with self.assertRaises(grpc.RpcError): + next(response_iterator) + self.assertIsNotNone(response_iterator.initial_metadata()) + self.assertIs(grpc.StatusCode.CANCELLED, response_iterator.code()) + self.assertIsNotNone(response_iterator.details()) + self.assertIsNotNone(response_iterator.trailing_metadata()) + + def testExpiredUnaryRequestBlockingUnaryResponse(self): + request = b'\x07\x17' + + multi_callable = _unary_unary_multi_callable(self._channel) + with self._control.pause(): + with self.assertRaises(grpc.RpcError) as exception_context: + multi_callable( + request, timeout=test_constants.SHORT_TIMEOUT, + metadata=((b'test', b'ExpiredUnaryRequestBlockingUnaryResponse'),), + with_call=True) + + self.assertIsNotNone(exception_context.exception.initial_metadata()) + self.assertIs( + grpc.StatusCode.DEADLINE_EXCEEDED, exception_context.exception.code()) + self.assertIsNotNone(exception_context.exception.details()) + self.assertIsNotNone(exception_context.exception.trailing_metadata()) + + def testExpiredUnaryRequestFutureUnaryResponse(self): + request = b'\x07\x17' + callback = _Callback() + + multi_callable = _unary_unary_multi_callable(self._channel) + with self._control.pause(): + response_future = multi_callable.future( + request, timeout=test_constants.SHORT_TIMEOUT, + metadata=((b'test', b'ExpiredUnaryRequestFutureUnaryResponse'),)) + response_future.add_done_callback(callback) + value_passed_to_callback = callback.value() + + self.assertIs(response_future, value_passed_to_callback) + self.assertIsNotNone(response_future.initial_metadata()) + self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, response_future.code()) + self.assertIsNotNone(response_future.details()) + self.assertIsNotNone(response_future.trailing_metadata()) + with self.assertRaises(grpc.RpcError) as exception_context: + response_future.result() + self.assertIs( + grpc.StatusCode.DEADLINE_EXCEEDED, exception_context.exception.code()) + self.assertIsInstance(response_future.exception(), grpc.RpcError) + self.assertIs( + grpc.StatusCode.DEADLINE_EXCEEDED, response_future.exception().code()) + + def testExpiredUnaryRequestStreamResponse(self): + request = b'\x07\x19' + + multi_callable = _unary_stream_multi_callable(self._channel) + with self._control.pause(): + with self.assertRaises(grpc.RpcError) as exception_context: + response_iterator = multi_callable( + request, timeout=test_constants.SHORT_TIMEOUT, + metadata=((b'test', b'ExpiredUnaryRequestStreamResponse'),)) + next(response_iterator) + + self.assertIs( + grpc.StatusCode.DEADLINE_EXCEEDED, exception_context.exception.code()) + self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, response_iterator.code()) + + def testExpiredStreamRequestBlockingUnaryResponse(self): + requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + multi_callable = _stream_unary_multi_callable(self._channel) + with self._control.pause(): + with self.assertRaises(grpc.RpcError) as exception_context: + multi_callable( + request_iterator, timeout=test_constants.SHORT_TIMEOUT, + metadata=((b'test', b'ExpiredStreamRequestBlockingUnaryResponse'),)) + + self.assertIsNotNone(exception_context.exception.initial_metadata()) + self.assertIs( + grpc.StatusCode.DEADLINE_EXCEEDED, exception_context.exception.code()) + self.assertIsNotNone(exception_context.exception.details()) + self.assertIsNotNone(exception_context.exception.trailing_metadata()) + + def testExpiredStreamRequestFutureUnaryResponse(self): + requests = tuple(b'\x07\x18' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + callback = _Callback() + + multi_callable = _stream_unary_multi_callable(self._channel) + with self._control.pause(): + response_future = multi_callable.future( + request_iterator, timeout=test_constants.SHORT_TIMEOUT, + metadata=((b'test', b'ExpiredStreamRequestFutureUnaryResponse'),)) + response_future.add_done_callback(callback) + value_passed_to_callback = callback.value() + + with self.assertRaises(grpc.RpcError) as exception_context: + response_future.result() + self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, response_future.code()) + self.assertIs( + grpc.StatusCode.DEADLINE_EXCEEDED, exception_context.exception.code()) + self.assertIsInstance(response_future.exception(), grpc.RpcError) + self.assertIs(response_future, value_passed_to_callback) + self.assertIsNotNone(response_future.initial_metadata()) + self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, response_future.code()) + self.assertIsNotNone(response_future.details()) + self.assertIsNotNone(response_future.trailing_metadata()) + + def testExpiredStreamRequestStreamResponse(self): + requests = tuple(b'\x67\x18' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + multi_callable = _stream_stream_multi_callable(self._channel) + with self._control.pause(): + with self.assertRaises(grpc.RpcError) as exception_context: + response_iterator = multi_callable( + request_iterator, timeout=test_constants.SHORT_TIMEOUT, + metadata=((b'test', b'ExpiredStreamRequestStreamResponse'),)) + next(response_iterator) + + self.assertIs( + grpc.StatusCode.DEADLINE_EXCEEDED, exception_context.exception.code()) + self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, response_iterator.code()) + + def testFailedUnaryRequestBlockingUnaryResponse(self): + request = b'\x37\x17' + + multi_callable = _unary_unary_multi_callable(self._channel) + with self._control.fail(): + with self.assertRaises(grpc.RpcError) as exception_context: + multi_callable( + request, + metadata=((b'test', b'FailedUnaryRequestBlockingUnaryResponse'),), + with_call=True) + + self.assertIs(grpc.StatusCode.UNKNOWN, exception_context.exception.code()) + + def testFailedUnaryRequestFutureUnaryResponse(self): + request = b'\x37\x17' + callback = _Callback() + + multi_callable = _unary_unary_multi_callable(self._channel) + with self._control.fail(): + response_future = multi_callable.future( + request, + metadata=((b'test', b'FailedUnaryRequestFutureUnaryResponse'),)) + response_future.add_done_callback(callback) + value_passed_to_callback = callback.value() + + with self.assertRaises(grpc.RpcError) as exception_context: + response_future.result() + self.assertIs( + grpc.StatusCode.UNKNOWN, exception_context.exception.code()) + self.assertIsInstance(response_future.exception(), grpc.RpcError) + self.assertIs(grpc.StatusCode.UNKNOWN, response_future.exception().code()) + self.assertIs(response_future, value_passed_to_callback) + + def testFailedUnaryRequestStreamResponse(self): + request = b'\x37\x17' + + multi_callable = _unary_stream_multi_callable(self._channel) + with self.assertRaises(grpc.RpcError) as exception_context: + with self._control.fail(): + response_iterator = multi_callable( + request, + metadata=((b'test', b'FailedUnaryRequestStreamResponse'),)) + next(response_iterator) + + self.assertIs(grpc.StatusCode.UNKNOWN, exception_context.exception.code()) + + def testFailedStreamRequestBlockingUnaryResponse(self): + requests = tuple(b'\x47\x58' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + multi_callable = _stream_unary_multi_callable(self._channel) + with self._control.fail(): + with self.assertRaises(grpc.RpcError) as exception_context: + multi_callable( + request_iterator, + metadata=((b'test', b'FailedStreamRequestBlockingUnaryResponse'),)) + + self.assertIs(grpc.StatusCode.UNKNOWN, exception_context.exception.code()) + + def testFailedStreamRequestFutureUnaryResponse(self): + requests = tuple(b'\x07\x18' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + callback = _Callback() + + multi_callable = _stream_unary_multi_callable(self._channel) + with self._control.fail(): + response_future = multi_callable.future( + request_iterator, + metadata=((b'test', b'FailedStreamRequestFutureUnaryResponse'),)) + response_future.add_done_callback(callback) + value_passed_to_callback = callback.value() + + with self.assertRaises(grpc.RpcError) as exception_context: + response_future.result() + self.assertIs(grpc.StatusCode.UNKNOWN, response_future.code()) + self.assertIs( + grpc.StatusCode.UNKNOWN, exception_context.exception.code()) + self.assertIsInstance(response_future.exception(), grpc.RpcError) + self.assertIs(response_future, value_passed_to_callback) + + def testFailedStreamRequestStreamResponse(self): + requests = tuple(b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + multi_callable = _stream_stream_multi_callable(self._channel) + with self._control.fail(): + with self.assertRaises(grpc.RpcError) as exception_context: + response_iterator = multi_callable( + request_iterator, + metadata=((b'test', b'FailedStreamRequestStreamResponse'),)) + tuple(response_iterator) + + self.assertIs(grpc.StatusCode.UNKNOWN, exception_context.exception.code()) + self.assertIs(grpc.StatusCode.UNKNOWN, response_iterator.code()) + + def testIgnoredUnaryRequestFutureUnaryResponse(self): + request = b'\x37\x17' + + multi_callable = _unary_unary_multi_callable(self._channel) + multi_callable.future( + request, + metadata=((b'test', b'IgnoredUnaryRequestFutureUnaryResponse'),)) + + def testIgnoredUnaryRequestStreamResponse(self): + request = b'\x37\x17' + + multi_callable = _unary_stream_multi_callable(self._channel) + multi_callable( + request, + metadata=((b'test', b'IgnoredUnaryRequestStreamResponse'),)) + + def testIgnoredStreamRequestFutureUnaryResponse(self): + requests = tuple(b'\x07\x18' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + multi_callable = _stream_unary_multi_callable(self._channel) + multi_callable.future( + request_iterator, + metadata=((b'test', b'IgnoredStreamRequestFutureUnaryResponse'),)) + + def testIgnoredStreamRequestStreamResponse(self): + requests = tuple(b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH)) + request_iterator = iter(requests) + + multi_callable = _stream_stream_multi_callable(self._channel) + multi_callable( + request_iterator, + metadata=((b'test', b'IgnoredStreamRequestStreamResponse'),)) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/src/python/grpcio/tests/unit/beta/_auth_test.py b/src/python/grpcio/tests/unit/beta/_auth_test.py new file mode 100644 index 0000000000..694928a91b --- /dev/null +++ b/src/python/grpcio/tests/unit/beta/_auth_test.py @@ -0,0 +1,96 @@ +# Copyright 2016, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests of standard AuthMetadataPlugins.""" + +import collections +import threading +import unittest + +from grpc.beta import _auth + + +class MockGoogleCreds(object): + + def get_access_token(self): + token = collections.namedtuple('MockAccessTokenInfo', + ('access_token', 'expires_in')) + token.access_token = 'token' + return token + + +class MockExceptionGoogleCreds(object): + + def get_access_token(self): + raise Exception() + + +class GoogleCallCredentialsTest(unittest.TestCase): + + def test_google_call_credentials_success(self): + callback_event = threading.Event() + + def mock_callback(metadata, error): + self.assertEqual(metadata, (('authorization', 'Bearer token'),)) + self.assertIsNone(error) + callback_event.set() + + call_creds = _auth.GoogleCallCredentials(MockGoogleCreds()) + call_creds(None, mock_callback) + self.assertTrue(callback_event.wait(1.0)) + + def test_google_call_credentials_error(self): + callback_event = threading.Event() + + def mock_callback(metadata, error): + self.assertIsNotNone(error) + callback_event.set() + + call_creds = _auth.GoogleCallCredentials(MockExceptionGoogleCreds()) + call_creds(None, mock_callback) + self.assertTrue(callback_event.wait(1.0)) + + +class AccessTokenCallCredentialsTest(unittest.TestCase): + + def test_google_call_credentials_success(self): + callback_event = threading.Event() + + def mock_callback(metadata, error): + self.assertEqual(metadata, (('authorization', 'Bearer token'),)) + self.assertIsNone(error) + callback_event.set() + + call_creds = _auth.AccessTokenCallCredentials('token') + call_creds(None, mock_callback) + self.assertTrue(callback_event.wait(1.0)) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/src/python/grpcio/tests/unit/beta/_implementations_test.py b/src/python/grpcio/tests/unit/beta/_implementations_test.py index 26be670c45..127f93e9bb 100644 --- a/src/python/grpcio/tests/unit/beta/_implementations_test.py +++ b/src/python/grpcio/tests/unit/beta/_implementations_test.py @@ -29,8 +29,11 @@ """Tests the implementations module of the gRPC Python Beta API.""" +import datetime import unittest +from oauth2client import client as oauth2client_client + from grpc.beta import implementations from tests.unit import resources @@ -49,5 +52,19 @@ class ChannelCredentialsTest(unittest.TestCase): channel_credentials, implementations.ChannelCredentials) +class CallCredentialsTest(unittest.TestCase): + + def test_google_call_credentials(self): + creds = oauth2client_client.GoogleCredentials( + 'token', 'client_id', 'secret', 'refresh_token', + datetime.datetime(2008, 6, 24), 'https://refresh.uri.com/', + 'user_agent') + call_creds = implementations.google_call_credentials(creds) + self.assertIsInstance(call_creds, implementations.CallCredentials) + + def test_access_token_call_credentials(self): + call_creds = implementations.access_token_call_credentials('token') + self.assertIsInstance(call_creds, implementations.CallCredentials) + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/src/python/grpcio/tests/unit/framework/common/test_control.py b/src/python/grpcio/tests/unit/framework/common/test_control.py index ca5ba3a854..088e2f8b88 100644 --- a/src/python/grpcio/tests/unit/framework/common/test_control.py +++ b/src/python/grpcio/tests/unit/framework/common/test_control.py @@ -60,10 +60,16 @@ class Control(six.with_metaclass(abc.ABCMeta)): class PauseFailControl(Control): - """A Control that can be used to pause or fail code under control.""" + """A Control that can be used to pause or fail code under control. + + This object is only safe for use from two threads: one of the system under + test calling control and the other from the test system calling pause, + block_until_paused, and fail. + """ def __init__(self): self._condition = threading.Condition() + self._pause = False self._paused = False self._fail = False @@ -72,19 +78,31 @@ class PauseFailControl(Control): if self._fail: raise Defect() - while self._paused: + while self._pause: + self._paused = True + self._condition.notify_all() self._condition.wait() + self._paused = False @contextlib.contextmanager def pause(self): """Pauses code under control while controlling code is in context.""" with self._condition: - self._paused = True + self._pause = True yield with self._condition: - self._paused = False + self._pause = False self._condition.notify_all() + def block_until_paused(self): + """Blocks controlling code until code under control is paused. + + May only be called within the context of a pause call. + """ + with self._condition: + while not self._paused: + self._condition.wait() + @contextlib.contextmanager def fail(self): """Fails code under control while controlling code is in context.""" |