diff options
author | Alistair Veitch <aveitch@google.com> | 2016-06-07 09:04:44 -0700 |
---|---|---|
committer | Alistair Veitch <aveitch@google.com> | 2016-06-07 09:04:44 -0700 |
commit | ef1c860197742c87819070e6906b96d1e42451b3 (patch) | |
tree | f1636ea326a936262033189ad48f9026b56c3cda /src/python | |
parent | 0cf4defa42dfd0e1f09ea3e1fae7c111a673244f (diff) | |
parent | d941fd8786b7eb9c20a299b84d7a2a7c9228549d (diff) |
merge to master
Diffstat (limited to 'src/python')
24 files changed, 1863 insertions, 412 deletions
diff --git a/src/python/grpcio/grpc/__init__.py b/src/python/grpcio/grpc/__init__.py index 536aaa81cf..5ba5a4e1fd 100644 --- a/src/python/grpcio/grpc/__init__.py +++ b/src/python/grpcio/grpc/__init__.py @@ -352,6 +352,85 @@ class Call(six.with_metaclass(abc.ABCMeta, RpcContext)): raise NotImplementedError() +############ Authentication & Authorization Interfaces & Classes ############# + + +class ChannelCredentials(object): + """A value encapsulating the data required to create a secure Channel. + + This class has no supported interface - it exists to define the type of its + instances and its instances exist to be passed to other functions. + """ + + def __init__(self, credentials): + self._credentials = credentials + + +class CallCredentials(object): + """A value encapsulating data asserting an identity over a channel. + + A CallCredentials may be composed with ChannelCredentials to always assert + identity for every call over that Channel. + + This class has no supported interface - it exists to define the type of its + instances and its instances exist to be passed to other functions. + """ + + def __init__(self, credentials): + self._credentials = credentials + + +class AuthMetadataContext(six.with_metaclass(abc.ABCMeta)): + """Provides information to call credentials metadata plugins. + + Attributes: + service_url: A string URL of the service being called into. + method_name: A string of the fully qualified method name being called. + """ + + +class AuthMetadataPluginCallback(six.with_metaclass(abc.ABCMeta)): + """Callback object received by a metadata plugin.""" + + def __call__(self, metadata, error): + """Inform the gRPC runtime of the metadata to construct a CallCredentials. + + Args: + metadata: An iterable of 2-sequences (e.g. tuples) of metadata key/value + pairs. + error: An Exception to indicate error or None to indicate success. + """ + raise NotImplementedError() + + +class AuthMetadataPlugin(six.with_metaclass(abc.ABCMeta)): + """A specification for custom authentication.""" + + def __call__(self, context, callback): + """Implements authentication by passing metadata to a callback. + + Implementations of this method must not block. + + Args: + context: An AuthMetadataContext providing information on the RPC that the + plugin is being called to authenticate. + callback: An AuthMetadataPluginCallback to be invoked either synchronously + or asynchronously. + """ + raise NotImplementedError() + + +class ServerCredentials(object): + """A value encapsulating the data required to open a secure port on a Server. + + This class has no supported interface - it exists to define the type of its + instances and its instances exist to be passed to other functions. + """ + + def __init__(self, credentials): + self._credentials = credentials + + ######################## Multi-Callable Interfaces ########################### @@ -359,7 +438,9 @@ class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): """Affords invoking a unary-unary RPC.""" @abc.abstractmethod - def __call__(self, request, timeout=None, metadata=None, with_call=False): + def __call__( + self, request, timeout=None, metadata=None, credentials=None, + with_call=False): """Synchronously invokes the underlying RPC. Args: @@ -367,6 +448,7 @@ class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): timeout: An optional duration of time in seconds to allow for the RPC. metadata: An optional sequence of pairs of bytes to be transmitted to the service-side of the RPC. + credentials: An optional CallCredentials for the RPC. with_call: Whether or not to include return a Call for the RPC in addition to the response. @@ -382,7 +464,7 @@ class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): raise NotImplementedError() @abc.abstractmethod - def future(self, request, timeout=None, metadata=None): + def future(self, request, timeout=None, metadata=None, credentials=None): """Asynchronously invokes the underlying RPC. Args: @@ -390,6 +472,7 @@ class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): timeout: An optional duration of time in seconds to allow for the RPC. metadata: An optional sequence of pairs of bytes to be transmitted to the service-side of the RPC. + credentials: An optional CallCredentials for the RPC. Returns: An object that is both a Call for the RPC and a Future. In the event of @@ -404,7 +487,7 @@ class UnaryStreamMultiCallable(six.with_metaclass(abc.ABCMeta)): """Affords invoking a unary-stream RPC.""" @abc.abstractmethod - def __call__(self, request, timeout=None, metadata=None): + def __call__(self, request, timeout=None, metadata=None, credentials=None): """Invokes the underlying RPC. Args: @@ -412,6 +495,7 @@ class UnaryStreamMultiCallable(six.with_metaclass(abc.ABCMeta)): timeout: An optional duration of time in seconds to allow for the RPC. metadata: An optional sequence of pairs of bytes to be transmitted to the service-side of the RPC. + credentials: An optional CallCredentials for the RPC. Returns: An object that is both a Call for the RPC and an iterator of response @@ -426,7 +510,8 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): @abc.abstractmethod def __call__( - self, request_iterator, timeout=None, metadata=None, with_call=False): + self, request_iterator, timeout=None, metadata=None, credentials=None, + with_call=False): """Synchronously invokes the underlying RPC. Args: @@ -434,6 +519,7 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): timeout: An optional duration of time in seconds to allow for the RPC. metadata: An optional sequence of pairs of bytes to be transmitted to the service-side of the RPC. + credentials: An optional CallCredentials for the RPC. with_call: Whether or not to include return a Call for the RPC in addition to the response. @@ -449,7 +535,8 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): raise NotImplementedError() @abc.abstractmethod - def future(self, request_iterator, timeout=None, metadata=None): + def future( + self, request_iterator, timeout=None, metadata=None, credentials=None): """Asynchronously invokes the underlying RPC. Args: @@ -457,6 +544,7 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): timeout: An optional duration of time in seconds to allow for the RPC. metadata: An optional sequence of pairs of bytes to be transmitted to the service-side of the RPC. + credentials: An optional CallCredentials for the RPC. Returns: An object that is both a Call for the RPC and a Future. In the event of @@ -471,7 +559,8 @@ class StreamStreamMultiCallable(six.with_metaclass(abc.ABCMeta)): """Affords invoking a stream-stream RPC in any call style.""" @abc.abstractmethod - def __call__(self, request_iterator, timeout=None, metadata=None): + def __call__( + self, request_iterator, timeout=None, metadata=None, credentials=None): """Invokes the underlying RPC. Args: @@ -479,6 +568,7 @@ class StreamStreamMultiCallable(six.with_metaclass(abc.ABCMeta)): timeout: An optional duration of time in seconds to allow for the RPC. metadata: An optional sequence of pairs of bytes to be transmitted to the service-side of the RPC. + credentials: An optional CallCredentials for the RPC. Returns: An object that is both a Call for the RPC and an iterator of response @@ -690,7 +780,6 @@ class RpcMethodHandler(six.with_metaclass(abc.ABCMeta)): class HandlerCallDetails(six.with_metaclass(abc.ABCMeta)): """Describes an RPC that has just arrived for service. - Attributes: method: The method name of the RPC. invocation_metadata: The metadata from the invocation side of the RPC. @@ -751,6 +840,25 @@ class Server(six.with_metaclass(abc.ABCMeta)): raise NotImplementedError() @abc.abstractmethod + def add_secure_port(self, address, server_credentials): + """Reserves a port for secure RPC service after this Server becomes active. + + This method may only be called before calling this Server's start method is + called. + + Args: + address: The address for which to open a port. + server_credentials: A ServerCredentials. + + Returns: + An integer port on which RPCs will be serviced after this link has been + started. This is typically the same number as the port number contained + in the passed address, but will likely be different if the port number + contained in the passed address was zero. + """ + raise NotImplementedError() + + @abc.abstractmethod def start(self): """Starts this Server's service of RPCs. @@ -792,6 +900,135 @@ class Server(six.with_metaclass(abc.ABCMeta)): ################################# Functions ################################ +def ssl_channel_credentials( + root_certificates=None, private_key=None, certificate_chain=None): + """Creates a ChannelCredentials for use with an SSL-enabled Channel. + + Args: + root_certificates: The PEM-encoded root certificates or unset to ask for + them to be retrieved from a default location. + private_key: The PEM-encoded private key to use or unset if no private key + should be used. + certificate_chain: The PEM-encoded certificate chain to use or unset if no + certificate chain should be used. + + Returns: + A ChannelCredentials for use with an SSL-enabled Channel. + """ + if private_key is not None or certificate_chain is not None: + pair = _cygrpc.SslPemKeyCertPair(private_key, certificate_chain) + else: + pair = None + return ChannelCredentials( + _cygrpc.channel_credentials_ssl(root_certificates, pair)) + + +def metadata_call_credentials(metadata_plugin, name=None): + """Construct CallCredentials from an AuthMetadataPlugin. + + Args: + metadata_plugin: An AuthMetadataPlugin to use as the authentication behavior + in the created CallCredentials. + name: A name for the plugin. + + Returns: + A CallCredentials. + """ + from grpc import _plugin_wrapping + if name is None: + try: + effective_name = metadata_plugin.__name__ + except AttributeError: + effective_name = metadata_plugin.__class__.__name__ + else: + effective_name = name + return CallCredentials( + _plugin_wrapping.call_credentials_metadata_plugin( + metadata_plugin, effective_name)) + + +def access_token_call_credentials(access_token): + """Construct CallCredentials from an access token. + + Args: + access_token: A string to place directly in the http request + authorization header, ie "Authorization: Bearer <access_token>". + + Returns: + A CallCredentials. + """ + from grpc import _auth + return metadata_call_credentials( + _auth.AccessTokenCallCredentials(access_token)) + + +def composite_call_credentials(call_credentials, additional_call_credentials): + """Compose two CallCredentials to make a new one. + + Args: + call_credentials: A CallCredentials object. + additional_call_credentials: Another CallCredentials object to compose on + top of call_credentials. + + Returns: + A new CallCredentials composed of the two given CallCredentials. + """ + return CallCredentials( + _cygrpc.call_credentials_composite( + call_credentials._credentials, + additional_call_credentials._credentials)) + + +def composite_channel_credentials(channel_credentials, call_credentials): + """Compose a ChannelCredentials and a CallCredentials. + + Args: + channel_credentials: A ChannelCredentials. + call_credentials: A CallCredentials. + + Returns: + A ChannelCredentials composed of the given ChannelCredentials and + CallCredentials. + """ + return ChannelCredentials( + _cygrpc.channel_credentials_composite( + channel_credentials._credentials, call_credentials._credentials)) + + +def ssl_server_credentials( + private_key_certificate_chain_pairs, root_certificates=None, + require_client_auth=False): + """Creates a ServerCredentials for use with an SSL-enabled Server. + + Args: + private_key_certificate_chain_pairs: A nonempty sequence each element of + which is a pair the first element of which is a PEM-encoded private key + and the second element of which is the corresponding PEM-encoded + certificate chain. + root_certificates: PEM-encoded client root certificates to be used for + verifying authenticated clients. If omitted, require_client_auth must also + be omitted or be False. + require_client_auth: A boolean indicating whether or not to require clients + to be authenticated. May only be True if root_certificates is not None. + + Returns: + A ServerCredentials for use with an SSL-enabled Server. + """ + if len(private_key_certificate_chain_pairs) == 0: + raise ValueError( + 'At least one private key-certificate chain pair is required!') + elif require_client_auth and root_certificates is None: + raise ValueError( + 'Illegal to require client auth without providing root certificates!') + else: + return ServerCredentials( + _cygrpc.server_credentials_ssl( + root_certificates, + [_cygrpc.SslPemKeyCertPair(key, pem) + for key, pem in private_key_certificate_chain_pairs], + require_client_auth)) + + def channel_ready_future(channel): """Creates a Future tracking when a Channel is ready. @@ -825,6 +1062,22 @@ def insecure_channel(target, options=None): return _channel.Channel(target, None, options) +def secure_channel(target, credentials, options=None): + """Creates an insecure Channel to a server. + + Args: + target: The target to which to connect. + credentials: A ChannelCredentials instance. + options: A sequence of string-value pairs according to which to configure + the created channel. + + Returns: + A Channel to the target through which RPCs may be conducted. + """ + from grpc import _channel + return _channel.Channel(target, credentials, options) + + def server(generic_rpc_handlers, thread_pool, options=None): """Creates a Server with which RPCs can be serviced. diff --git a/src/python/grpcio/grpc/_adapter/_low.py b/src/python/grpcio/grpc/_adapter/_low.py index 00788bd4cf..48410167a0 100644 --- a/src/python/grpcio/grpc/_adapter/_low.py +++ b/src/python/grpcio/grpc/_adapter/_low.py @@ -30,8 +30,8 @@ import threading from grpc import _grpcio_metadata +from grpc import _plugin_wrapping from grpc._cython import cygrpc -from grpc._adapter import _implementations from grpc._adapter import _types _USER_AGENT = 'Python-gRPC-{}'.format(_grpcio_metadata.__version__) @@ -57,78 +57,8 @@ def channel_credentials_ssl( return cygrpc.channel_credentials_ssl(root_certificates, pair) -class _WrappedCygrpcCallback(object): - - def __init__(self, cygrpc_callback): - self.is_called = False - self.error = None - self.is_called_lock = threading.Lock() - self.cygrpc_callback = cygrpc_callback - - def _invoke_failure(self, error): - # TODO(atash) translate different Exception superclasses into different - # status codes. - self.cygrpc_callback( - cygrpc.Metadata([]), cygrpc.StatusCode.internal, error.message) - - def _invoke_success(self, metadata): - try: - cygrpc_metadata = cygrpc.Metadata( - cygrpc.Metadatum(key, value) - for key, value in metadata) - except Exception as error: - self._invoke_failure(error) - return - self.cygrpc_callback(cygrpc_metadata, cygrpc.StatusCode.ok, '') - - def __call__(self, metadata, error): - with self.is_called_lock: - if self.is_called: - raise RuntimeError('callback should only ever be invoked once') - if self.error: - self._invoke_failure(self.error) - return - self.is_called = True - if error is None: - self._invoke_success(metadata) - else: - self._invoke_failure(error) - - def notify_failure(self, error): - with self.is_called_lock: - if not self.is_called: - self.error = error - - -class _WrappedPlugin(object): - - def __init__(self, plugin): - self.plugin = plugin - - def __call__(self, context, cygrpc_callback): - wrapped_cygrpc_callback = _WrappedCygrpcCallback(cygrpc_callback) - wrapped_context = _implementations.AuthMetadataContext(context.service_url, - context.method_name) - try: - self.plugin( - wrapped_context, - _implementations.AuthMetadataPluginCallback(wrapped_cygrpc_callback)) - except Exception as error: - wrapped_cygrpc_callback.notify_failure(error) - raise - - -def call_credentials_metadata_plugin(plugin, name): - """ - Args: - plugin: A callable accepting a _types.AuthMetadataContext - object and a callback (itself accepting a list of metadata key/value - 2-tuples and a None-able exception value). The callback must be eventually - called, but need not be called in plugin's invocation. - plugin's invocation must be non-blocking. - """ - return cygrpc.call_credentials_metadata_plugin( - cygrpc.CredentialsMetadataPlugin(_WrappedPlugin(plugin), name)) +call_credentials_metadata_plugin = ( + _plugin_wrapping.call_credentials_metadata_plugin) class CompletionQueue(_types.CompletionQueue): diff --git a/src/python/grpcio/grpc/_adapter/_implementations.py b/src/python/grpcio/grpc/_auth.py index b85f228bf6..3ae00ca23a 100644 --- a/src/python/grpcio/grpc/_adapter/_implementations.py +++ b/src/python/grpcio/grpc/_auth.py @@ -1,4 +1,4 @@ -# Copyright 2015, Google Inc. +# Copyright 2016, Google Inc. # All rights reserved. # # Redistribution and use in source and binary forms, with or without @@ -27,22 +27,47 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import collections +"""GRPCAuthMetadataPlugins for standard authentication.""" -from grpc.beta import interfaces +from concurrent import futures -class AuthMetadataContext(collections.namedtuple( - 'AuthMetadataContext', [ - 'service_url', - 'method_name' - ]), interfaces.GRPCAuthMetadataContext): - pass +import grpc -class AuthMetadataPluginCallback(interfaces.GRPCAuthMetadataContext): +def _sign_request(callback, token, error): + metadata = (('authorization', 'Bearer {}'.format(token)),) + callback(metadata, error) - def __init__(self, callback): - self._callback = callback - def __call__(self, metadata, error): - self._callback(metadata, error) +class GoogleCallCredentials(grpc.AuthMetadataPlugin): + """Metadata wrapper for GoogleCredentials from the oauth2client library.""" + + def __init__(self, credentials): + self._credentials = credentials + self._pool = futures.ThreadPoolExecutor(max_workers=1) + + def __call__(self, context, callback): + # MetadataPlugins cannot block (see grpc.beta.interfaces.py) + future = self._pool.submit(self._credentials.get_access_token) + future.add_done_callback(lambda x: self._get_token_callback(callback, x)) + + def _get_token_callback(self, callback, future): + try: + access_token = future.result().access_token + except Exception as e: + _sign_request(callback, None, e) + else: + _sign_request(callback, access_token, None) + + def __del__(self): + self._pool.shutdown(wait=False) + + +class AccessTokenCallCredentials(grpc.AuthMetadataPlugin): + """Metadata wrapper for raw access token credentials.""" + + def __init__(self, access_token): + self._access_token = access_token + + def __call__(self, context, callback): + _sign_request(callback, self._access_token, None) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd.pxi index c793c8f5e5..19a59e08f3 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd.pxi @@ -68,4 +68,4 @@ cdef void plugin_get_metadata( void *state, grpc_auth_metadata_context context, grpc_credentials_plugin_metadata_cb cb, void *user_data) with gil -cdef void plugin_destroy_c_plugin_state(void *state) +cdef void plugin_destroy_c_plugin_state(void *state) with gil diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi index 94d13b5999..1ba86457af 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi @@ -137,7 +137,7 @@ cdef void plugin_get_metadata( cy_context.context = context self.plugin_callback(cy_context, python_callback) -cdef void plugin_destroy_c_plugin_state(void *state): +cdef void plugin_destroy_c_plugin_state(void *state) with gil: cpython.Py_DECREF(<CredentialsMetadataPlugin>state) def channel_credentials_google_default(): diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi index d42c58050f..05b8886df7 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi @@ -208,7 +208,7 @@ cdef extern from "grpc/_cython/loader.h": GRPC_CHANNEL_CONNECTING GRPC_CHANNEL_READY GRPC_CHANNEL_TRANSIENT_FAILURE - GRPC_CHANNEL_FATAL_FAILURE + GRPC_CHANNEL_SHUTDOWN ctypedef struct grpc_metadata: const char *key diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi index c7539f0d49..e0219b0086 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi @@ -33,7 +33,7 @@ class ConnectivityState: connecting = GRPC_CHANNEL_CONNECTING ready = GRPC_CHANNEL_READY transient_failure = GRPC_CHANNEL_TRANSIENT_FAILURE - fatal_failure = GRPC_CHANNEL_FATAL_FAILURE + fatal_failure = GRPC_CHANNEL_SHUTDOWN class ChannelArgKey: @@ -274,6 +274,7 @@ cdef class ByteBuffer: data_slice_length = gpr_slice_length(data_slice) with gil: result += (<char *>data_slice_pointer)[:data_slice_length] + gpr_slice_unref(data_slice) with nogil: grpc_byte_buffer_reader_destroy(&reader) return bytes(result) diff --git a/src/python/grpcio/grpc/_cython/imports.generated.c b/src/python/grpcio/grpc/_cython/imports.generated.c index 29ada16afd..d811ec68da 100644 --- a/src/python/grpcio/grpc/_cython/imports.generated.c +++ b/src/python/grpcio/grpc/_cython/imports.generated.c @@ -35,7 +35,7 @@ #include "imports.generated.h" -#ifdef GPR_WIN32 +#ifdef GPR_WINDOWS census_initialize_type census_initialize_import; census_shutdown_type census_shutdown_import; @@ -121,7 +121,6 @@ grpc_header_key_is_legal_type grpc_header_key_is_legal_import; grpc_header_nonbin_value_is_legal_type grpc_header_nonbin_value_is_legal_import; grpc_is_binary_header_type grpc_is_binary_header_import; grpc_call_error_to_string_type grpc_call_error_to_string_import; -grpc_cronet_secure_channel_create_type grpc_cronet_secure_channel_create_import; grpc_auth_property_iterator_next_type grpc_auth_property_iterator_next_import; grpc_auth_context_property_iterator_type grpc_auth_context_property_iterator_import; grpc_auth_context_peer_identity_type grpc_auth_context_peer_identity_import; @@ -388,7 +387,6 @@ void pygrpc_load_imports(HMODULE library) { grpc_header_nonbin_value_is_legal_import = (grpc_header_nonbin_value_is_legal_type) GetProcAddress(library, "grpc_header_nonbin_value_is_legal"); grpc_is_binary_header_import = (grpc_is_binary_header_type) GetProcAddress(library, "grpc_is_binary_header"); grpc_call_error_to_string_import = (grpc_call_error_to_string_type) GetProcAddress(library, "grpc_call_error_to_string"); - grpc_cronet_secure_channel_create_import = (grpc_cronet_secure_channel_create_type) GetProcAddress(library, "grpc_cronet_secure_channel_create"); grpc_auth_property_iterator_next_import = (grpc_auth_property_iterator_next_type) GetProcAddress(library, "grpc_auth_property_iterator_next"); grpc_auth_context_property_iterator_import = (grpc_auth_context_property_iterator_type) GetProcAddress(library, "grpc_auth_context_property_iterator"); grpc_auth_context_peer_identity_import = (grpc_auth_context_peer_identity_type) GetProcAddress(library, "grpc_auth_context_peer_identity"); @@ -571,4 +569,4 @@ void pygrpc_load_imports(HMODULE library) { } #endif /* __cpluslus */ -#endif /* !GPR_WIN32 */ +#endif /* !GPR_WINDOWS */ diff --git a/src/python/grpcio/grpc/_cython/imports.generated.h b/src/python/grpcio/grpc/_cython/imports.generated.h index f6dd79cadc..5e100796cd 100644 --- a/src/python/grpcio/grpc/_cython/imports.generated.h +++ b/src/python/grpcio/grpc/_cython/imports.generated.h @@ -36,14 +36,13 @@ #include <grpc/support/port_platform.h> -#ifdef GPR_WIN32 +#ifdef GPR_WINDOWS #include <windows.h> #include <grpc/census.h> #include <grpc/compression.h> #include <grpc/grpc.h> -#include <grpc/grpc_cronet.h> #include <grpc/grpc_security.h> #include <grpc/impl/codegen/alloc.h> #include <grpc/impl/codegen/byte_buffer.h> @@ -57,7 +56,7 @@ #include <grpc/support/cpu.h> #include <grpc/support/histogram.h> #include <grpc/support/host_port.h> -#include <grpc/support/log_win32.h> +#include <grpc/support/log_windows.h> #include <grpc/support/string_util.h> #include <grpc/support/subprocess.h> #include <grpc/support/thd.h> @@ -314,9 +313,6 @@ extern grpc_is_binary_header_type grpc_is_binary_header_import; typedef const char *(*grpc_call_error_to_string_type)(grpc_call_error error); extern grpc_call_error_to_string_type grpc_call_error_to_string_import; #define grpc_call_error_to_string grpc_call_error_to_string_import -typedef grpc_channel *(*grpc_cronet_secure_channel_create_type)(void *engine, const char *target, const grpc_channel_args *args, void *reserved); -extern grpc_cronet_secure_channel_create_type grpc_cronet_secure_channel_create_import; -#define grpc_cronet_secure_channel_create grpc_cronet_secure_channel_create_import typedef const grpc_auth_property *(*grpc_auth_property_iterator_next_type)(grpc_auth_property_iterator *it); extern grpc_auth_property_iterator_next_type grpc_auth_property_iterator_next_import; #define grpc_auth_property_iterator_next grpc_auth_property_iterator_next_import @@ -856,7 +852,7 @@ void pygrpc_load_imports(HMODULE library); } #endif /* __cpluslus */ -#else /* !GPR_WIN32 */ +#else /* !GPR_WINDOWS */ #include <grpc/byte_buffer.h> #include <grpc/byte_buffer_reader.h> @@ -868,6 +864,6 @@ void pygrpc_load_imports(HMODULE library); #include <grpc/support/time.h> #include <grpc/status.h> -#endif /* !GPR_WIN32 */ +#endif /* !GPR_WINDOWS */ #endif diff --git a/src/python/grpcio/grpc/_cython/loader.c b/src/python/grpcio/grpc/_cython/loader.c index 3b72806ea1..b909ad594e 100644 --- a/src/python/grpcio/grpc/_cython/loader.c +++ b/src/python/grpcio/grpc/_cython/loader.c @@ -37,7 +37,7 @@ extern "C" { #endif /* __cpluslus */ -#if GPR_WIN32 +#if GPR_WINDOWS int pygrpc_load_core(char *path) { HMODULE grpc_c; @@ -60,7 +60,7 @@ int pygrpc_load_core(char *path) { int pygrpc_load_core(char *path) { return 1; } -#endif /* !GPR_WIN32 */ +#endif /* !GPR_WINDOWS */ #ifdef __cplusplus } diff --git a/src/python/grpcio/grpc/_plugin_wrapping.py b/src/python/grpcio/grpc/_plugin_wrapping.py new file mode 100644 index 0000000000..4e9cfe710c --- /dev/null +++ b/src/python/grpcio/grpc/_plugin_wrapping.py @@ -0,0 +1,123 @@ +# Copyright 2015, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import collections +import threading + +import grpc +from grpc._cython import cygrpc + + +class AuthMetadataContext( + collections.namedtuple( + 'AuthMetadataContext', ('service_url', 'method_name',)), + grpc.AuthMetadataContext): + pass + + +class AuthMetadataPluginCallback(grpc.AuthMetadataContext): + + def __init__(self, callback): + self._callback = callback + + def __call__(self, metadata, error): + self._callback(metadata, error) + + +class _WrappedCygrpcCallback(object): + + def __init__(self, cygrpc_callback): + self.is_called = False + self.error = None + self.is_called_lock = threading.Lock() + self.cygrpc_callback = cygrpc_callback + + def _invoke_failure(self, error): + # TODO(atash) translate different Exception superclasses into different + # status codes. + self.cygrpc_callback( + cygrpc.Metadata([]), cygrpc.StatusCode.internal, error.message) + + def _invoke_success(self, metadata): + try: + cygrpc_metadata = cygrpc.Metadata( + cygrpc.Metadatum(key, value) + for key, value in metadata) + except Exception as error: + self._invoke_failure(error) + return + self.cygrpc_callback(cygrpc_metadata, cygrpc.StatusCode.ok, '') + + def __call__(self, metadata, error): + with self.is_called_lock: + if self.is_called: + raise RuntimeError('callback should only ever be invoked once') + if self.error: + self._invoke_failure(self.error) + return + self.is_called = True + if error is None: + self._invoke_success(metadata) + else: + self._invoke_failure(error) + + def notify_failure(self, error): + with self.is_called_lock: + if not self.is_called: + self.error = error + + +class _WrappedPlugin(object): + + def __init__(self, plugin): + self.plugin = plugin + + def __call__(self, context, cygrpc_callback): + wrapped_cygrpc_callback = _WrappedCygrpcCallback(cygrpc_callback) + wrapped_context = AuthMetadataContext( + context.service_url, context.method_name) + try: + self.plugin( + wrapped_context, AuthMetadataPluginCallback(wrapped_cygrpc_callback)) + except Exception as error: + wrapped_cygrpc_callback.notify_failure(error) + raise + + +def call_credentials_metadata_plugin(plugin, name): + """ + Args: + plugin: A callable accepting a grpc.AuthMetadataContext + object and a callback (itself accepting a list of metadata key/value + 2-tuples and a None-able exception value). The callback must be eventually + called, but need not be called in plugin's invocation. + plugin's invocation must be non-blocking. + """ + return cygrpc.call_credentials_metadata_plugin( + cygrpc.CredentialsMetadataPlugin(_WrappedPlugin(plugin), name)) diff --git a/src/python/grpcio/grpc/_server.py b/src/python/grpcio/grpc/_server.py index c65070f1b3..aae9f48ae6 100644 --- a/src/python/grpcio/grpc/_server.py +++ b/src/python/grpcio/grpc/_server.py @@ -65,12 +65,23 @@ def _serialized_request(request_event): return request_event.batch_operations[0].received_message.bytes() -def _code(state): +def _application_code(code): + cygrpc_code = _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE.get(code) + return cygrpc.StatusCode.unknown if cygrpc_code is None else cygrpc_code + + +def _completion_code(state): if state.code is None: return cygrpc.StatusCode.ok else: - code = _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE.get(state.code) - return cygrpc.StatusCode.unknown if code is None else code + return _application_code(state.code) + + +def _abortion_code(state, code): + if state.code is None: + return code + else: + return _application_code(state.code) def _details(state): @@ -126,20 +137,22 @@ def _send_status_from_server(state, token): def _abort(state, call, code, details): if state.client is not _CANCELLED: + effective_code = _abortion_code(state, code) + effective_details = details if state.details is None else state.details if state.initial_metadata_allowed: operations = ( cygrpc.operation_send_initial_metadata( _EMPTY_METADATA, _EMPTY_FLAGS), cygrpc.operation_send_status_from_server( - _common.metadata(state.trailing_metadata), code, details, - _EMPTY_FLAGS), + _common.metadata(state.trailing_metadata), effective_code, + effective_details, _EMPTY_FLAGS), ) token = _SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN else: operations = ( cygrpc.operation_send_status_from_server( - _common.metadata(state.trailing_metadata), code, details, - _EMPTY_FLAGS), + _common.metadata(state.trailing_metadata), effective_code, + effective_details, _EMPTY_FLAGS), ) token = _SEND_STATUS_FROM_SERVER_TOKEN call.start_batch( @@ -346,7 +359,7 @@ def _unary_request(rpc_event, state, request_deserializer): def _call_behavior(rpc_event, state, behavior, argument, request_deserializer): context = _Context(rpc_event, state, request_deserializer) try: - return behavior(argument, context) + return behavior(argument, context), True except Exception as e: # pylint: disable=broad-except with state.condition: if e not in state.rpc_errors: @@ -354,7 +367,7 @@ def _call_behavior(rpc_event, state, behavior, argument, request_deserializer): logging.exception(details) _abort( state, rpc_event.operation_call, cygrpc.StatusCode.unknown, details) - return None + return None, False def _take_response_from_response_iterator(rpc_event, state, response_iterator): @@ -415,7 +428,7 @@ def _status(rpc_event, state, serialized_response): with state.condition: if state.client is not _CANCELLED: trailing_metadata = _common.metadata(state.trailing_metadata) - code = _code(state) + code = _completion_code(state) details = _details(state) operations = [ cygrpc.operation_send_status_from_server( @@ -440,9 +453,9 @@ def _unary_response_in_pool( response_serializer): argument = argument_thunk() if argument is not None: - response = _call_behavior( + response, proceed = _call_behavior( rpc_event, state, behavior, argument, request_deserializer) - if response is not None: + if proceed: serialized_response = _serialize_response( rpc_event, state, response, response_serializer) if serialized_response is not None: @@ -455,9 +468,9 @@ def _stream_response_in_pool( response_serializer): argument = argument_thunk() if argument is not None: - response_iterator = _call_behavior( + response_iterator, proceed = _call_behavior( rpc_event, state, behavior, argument, request_deserializer) - if response_iterator is not None: + if proceed: while True: response, proceed = _take_response_from_response_iterator( rpc_event, state, response_iterator) diff --git a/src/python/grpcio/grpc/beta/_client_adaptations.py b/src/python/grpcio/grpc/beta/_client_adaptations.py new file mode 100644 index 0000000000..621fcf2174 --- /dev/null +++ b/src/python/grpcio/grpc/beta/_client_adaptations.py @@ -0,0 +1,566 @@ +# Copyright 2016, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Translates gRPC's client-side API into gRPC's client-side Beta API.""" + +import grpc +from grpc._cython import cygrpc +from grpc.beta import interfaces +from grpc.framework.common import cardinality +from grpc.framework.foundation import future +from grpc.framework.interfaces.face import face + +_STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS = { + grpc.StatusCode.CANCELLED: ( + face.Abortion.Kind.CANCELLED, face.CancellationError), + grpc.StatusCode.UNKNOWN: ( + face.Abortion.Kind.REMOTE_FAILURE, face.RemoteError), + grpc.StatusCode.DEADLINE_EXCEEDED: ( + face.Abortion.Kind.EXPIRED, face.ExpirationError), + grpc.StatusCode.UNIMPLEMENTED: ( + face.Abortion.Kind.LOCAL_FAILURE, face.LocalError), +} + + +def _fully_qualified_method(group, method): + return b'/{}/{}'.format(group, method) + + +def _effective_metadata(metadata, metadata_transformer): + non_none_metadata = () if metadata is None else metadata + if metadata_transformer is None: + return non_none_metadata + else: + return metadata_transformer(non_none_metadata) + + +def _credentials(grpc_call_options): + return None if grpc_call_options is None else grpc_call_options.credentials + + +def _abortion(rpc_error_call): + code = rpc_error_call.code() + pair = _STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS.get(code) + error_kind = face.Abortion.Kind.LOCAL_FAILURE if pair is None else pair[0] + return face.Abortion( + error_kind, rpc_error_call.initial_metadata(), + rpc_error_call.trailing_metadata(), code, rpc_error_code.details()) + + +def _abortion_error(rpc_error_call): + code = rpc_error_call.code() + pair = _STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS.get(code) + exception_class = face.AbortionError if pair is None else pair[1] + return exception_class( + rpc_error_call.initial_metadata(), rpc_error_call.trailing_metadata(), + code, rpc_error_call.details()) + + +class _InvocationProtocolContext(interfaces.GRPCInvocationContext): + + def disable_next_request_compression(self): + pass # TODO(https://github.com/grpc/grpc/issues/4078): design, implement. + + +class _Rendezvous(future.Future, face.Call): + + def __init__(self, response_future, response_iterator, call): + self._future = response_future + self._iterator = response_iterator + self._call = call + + def cancel(self): + return self._call.cancel() + + def cancelled(self): + return self._future.cancelled() + + def running(self): + return self._future.running() + + def done(self): + return self._future.done() + + def result(self, timeout=None): + try: + return self._future.result(timeout=timeout) + except grpc.RpcError as rpc_error_call: + raise _abortion_error(rpc_error_call) + except grpc.FutureTimeoutError: + raise future.TimeoutError() + except grpc.FutureCancelledError: + raise future.CancelledError() + + def exception(self, timeout=None): + try: + rpc_error_call = self._future.exception(timeout=timeout) + return _abortion_error(rpc_error_call) + except grpc.FutureTimeoutError: + raise future.TimeoutError() + except grpc.FutureCancelledError: + raise future.CancelledError() + + def traceback(self, timeout=None): + try: + return self._future.traceback(timeout=timeout) + except grpc.FutureTimeoutError: + raise future.TimeoutError() + except grpc.FutureCancelledError: + raise future.CancelledError() + + def add_done_callback(self, fn): + self._future.add_done_callback(lambda ignored_callback: fn(self)) + + def __iter__(self): + return self + + def _next(self): + try: + return next(self._iterator) + except grpc.RpcError as rpc_error_call: + raise _abortion_error(rpc_error_call) + + def __next__(self): + return self._next() + + def next(self): + return self._next() + + def is_active(self): + return self._call.is_active() + + def time_remaining(self): + return self._call.time_remaining() + + def add_abortion_callback(self, abortion_callback): + registered = self._call.add_callback( + lambda: abortion_callback(_abortion(self._call))) + return None if registered else _abortion(self._call) + + def protocol_context(self): + return _InvocationProtocolContext() + + def initial_metadata(self): + return self._call.initial_metadata() + + def terminal_metadata(self): + return self._call.terminal_metadata() + + def code(self): + return self._call.code() + + def details(self): + return self._call.details() + + +def _blocking_unary_unary( + channel, group, method, timeout, with_call, protocol_options, metadata, + metadata_transformer, request, request_serializer, response_deserializer): + try: + multi_callable = channel.unary_unary( + _fully_qualified_method(group, method), + request_serializer=request_serializer, + response_deserializer=response_deserializer) + effective_metadata = _effective_metadata(metadata, metadata_transformer) + if with_call: + response, call = multi_callable( + request, timeout=timeout, metadata=effective_metadata, + credentials=_credentials(protocol_options), with_call=True) + return response, _Rendezvous(None, None, call) + else: + return multi_callable( + request, timeout=timeout, metadata=effective_metadata, + credentials=_credentials(protocol_options)) + except grpc.RpcError as rpc_error_call: + raise _abortion_error(rpc_error_call) + + +def _future_unary_unary( + channel, group, method, timeout, protocol_options, metadata, + metadata_transformer, request, request_serializer, response_deserializer): + multi_callable = channel.unary_unary( + _fully_qualified_method(group, method), + request_serializer=request_serializer, + response_deserializer=response_deserializer) + effective_metadata = _effective_metadata(metadata, metadata_transformer) + response_future = multi_callable.future( + request, timeout=timeout, metadata=effective_metadata, + credentials=_credentials(protocol_options)) + return _Rendezvous(response_future, None, response_future) + + +def _unary_stream( + channel, group, method, timeout, protocol_options, metadata, + metadata_transformer, request, request_serializer, response_deserializer): + multi_callable = channel.unary_stream( + _fully_qualified_method(group, method), + request_serializer=request_serializer, + response_deserializer=response_deserializer) + effective_metadata = _effective_metadata(metadata, metadata_transformer) + response_iterator = multi_callable( + request, timeout=timeout, metadata=effective_metadata, + credentials=_credentials(protocol_options)) + return _Rendezvous(None, response_iterator, response_iterator) + + +def _blocking_stream_unary( + channel, group, method, timeout, with_call, protocol_options, metadata, + metadata_transformer, request_iterator, request_serializer, + response_deserializer): + try: + multi_callable = channel.stream_unary( + _fully_qualified_method(group, method), + request_serializer=request_serializer, + response_deserializer=response_deserializer) + effective_metadata = _effective_metadata(metadata, metadata_transformer) + if with_call: + response, call = multi_callable( + request_iterator, timeout=timeout, metadata=effective_metadata, + credentials=_credentials(protocol_options), with_call=True) + return response, _Rendezvous(None, None, call) + else: + return multi_callable( + request_iterator, timeout=timeout, metadata=effective_metadata, + credentials=_credentials(protocol_options)) + except grpc.RpcError as rpc_error_call: + raise _abortion_error(rpc_error_call) + + +def _future_stream_unary( + channel, group, method, timeout, protocol_options, metadata, + metadata_transformer, request_iterator, request_serializer, + response_deserializer): + multi_callable = channel.stream_unary( + _fully_qualified_method(group, method), + request_serializer=request_serializer, + response_deserializer=response_deserializer) + effective_metadata = _effective_metadata(metadata, metadata_transformer) + response_future = multi_callable.future( + request_iterator, timeout=timeout, metadata=effective_metadata, + credentials=_credentials(protocol_options)) + return _Rendezvous(response_future, None, response_future) + + +def _stream_stream( + channel, group, method, timeout, protocol_options, metadata, + metadata_transformer, request_iterator, request_serializer, + response_deserializer): + multi_callable = channel.stream_stream( + _fully_qualified_method(group, method), + request_serializer=request_serializer, + response_deserializer=response_deserializer) + effective_metadata = _effective_metadata(metadata, metadata_transformer) + response_iterator = multi_callable( + request_iterator, timeout=timeout, metadata=effective_metadata, + credentials=_credentials(protocol_options)) + return _Rendezvous(None, response_iterator, response_iterator) + + +class _UnaryUnaryMultiCallable(face.UnaryUnaryMultiCallable): + + def __init__( + self, channel, group, method, metadata_transformer, request_serializer, + response_deserializer): + self._channel = channel + self._group = group + self._method = method + self._metadata_transformer = metadata_transformer + self._request_serializer = request_serializer + self._response_deserializer = response_deserializer + + def __call__( + self, request, timeout, metadata=None, with_call=False, + protocol_options=None): + return _blocking_unary_unary( + self._channel, self._group, self._method, timeout, with_call, + protocol_options, metadata, self._metadata_transformer, request, + self._request_serializer, self._response_deserializer) + + def future(self, request, timeout, metadata=None, protocol_options=None): + return _future_unary_unary( + self._channel, self._group, self._method, timeout, protocol_options, + metadata, self._metadata_transformer, request, self._request_serializer, + self._response_deserializer) + + def event( + self, request, receiver, abortion_callback, timeout, + metadata=None, protocol_options=None): + raise NotImplementedError() + + +class _UnaryStreamMultiCallable(face.UnaryStreamMultiCallable): + + def __init__( + self, channel, group, method, metadata_transformer, request_serializer, + response_deserializer): + self._channel = channel + self._group = group + self._method = method + self._metadata_transformer = metadata_transformer + self._request_serializer = request_serializer + self._response_deserializer = response_deserializer + + def __call__(self, request, timeout, metadata=None, protocol_options=None): + return _unary_stream( + self._channel, self._group, self._method, timeout, protocol_options, + metadata, self._metadata_transformer, request, self._request_serializer, + self._response_deserializer) + + def event( + self, request, receiver, abortion_callback, timeout, + metadata=None, protocol_options=None): + raise NotImplementedError() + + +class _StreamUnaryMultiCallable(face.StreamUnaryMultiCallable): + + def __init__( + self, channel, group, method, metadata_transformer, request_serializer, + response_deserializer): + self._channel = channel + self._group = group + self._method = method + self._metadata_transformer = metadata_transformer + self._request_serializer = request_serializer + self._response_deserializer = response_deserializer + + def __call__( + self, request_iterator, timeout, metadata=None, with_call=False, + protocol_options=None): + return _blocking_stream_unary( + self._channel, self._group, self._method, timeout, with_call, + protocol_options, metadata, self._metadata_transformer, + request_iterator, self._request_serializer, self._response_deserializer) + + def future( + self, request_iterator, timeout, metadata=None, protocol_options=None): + return _future_stream_unary( + self._channel, self._group, self._method, timeout, protocol_options, + metadata, self._metadata_transformer, request_iterator, + self._request_serializer, self._response_deserializer) + + def event( + self, receiver, abortion_callback, timeout, metadata=None, + protocol_options=None): + raise NotImplementedError() + + +class _StreamStreamMultiCallable(face.StreamStreamMultiCallable): + + def __init__( + self, channel, group, method, metadata_transformer, request_serializer, + response_deserializer): + self._channel = channel + self._group = group + self._method = method + self._metadata_transformer = metadata_transformer + self._request_serializer = request_serializer + self._response_deserializer = response_deserializer + + def __call__( + self, request_iterator, timeout, metadata=None, protocol_options=None): + return _stream_stream( + self._channel, self._group, self._method, timeout, protocol_options, + metadata, self._metadata_transformer, request_iterator, + self._request_serializer, self._response_deserializer) + + def event( + self, receiver, abortion_callback, timeout, metadata=None, + protocol_options=None): + raise NotImplementedError() + + +class _GenericStub(face.GenericStub): + + def __init__( + self, channel, metadata_transformer, request_serializers, + response_deserializers): + self._channel = channel + self._metadata_transformer = metadata_transformer + self._request_serializers = request_serializers or {} + self._response_deserializers = response_deserializers or {} + + def blocking_unary_unary( + self, group, method, request, timeout, metadata=None, + with_call=None, protocol_options=None): + request_serializer = self._request_serializers.get((group, method,)) + response_deserializer = self._response_deserializers.get((group, method,)) + return _blocking_unary_unary( + self._channel, group, method, timeout, with_call, protocol_options, + metadata, self._metadata_transformer, request, request_serializer, + response_deserializer) + + def future_unary_unary( + self, group, method, request, timeout, metadata=None, + protocol_options=None): + request_serializer = self._request_serializers.get((group, method,)) + response_deserializer = self._response_deserializers.get((group, method,)) + return _future_unary_unary( + self._channel, group, method, timeout, protocol_options, metadata, + self._metadata_transformer, request, request_serializer, + response_deserializer) + + def inline_unary_stream( + self, group, method, request, timeout, metadata=None, + protocol_options=None): + request_serializer = self._request_serializers.get((group, method,)) + response_deserializer = self._response_deserializers.get((group, method,)) + return _unary_stream( + self._channel, group, method, timeout, protocol_options, metadata, + self._metadata_transformer, request, request_serializer, + response_deserializer) + + def blocking_stream_unary( + self, group, method, request_iterator, timeout, metadata=None, + with_call=None, protocol_options=None): + request_serializer = self._request_serializers.get((group, method,)) + response_deserializer = self._response_deserializers.get((group, method,)) + return _blocking_stream_unary( + self._channel, group, method, timeout, with_call, protocol_options, + metadata, self._metadata_transformer, request_iterator, + request_serializer, response_deserializer) + + def future_stream_unary( + self, group, method, request_iterator, timeout, metadata=None, + protocol_options=None): + request_serializer = self._request_serializers.get((group, method,)) + response_deserializer = self._response_deserializers.get((group, method,)) + return _future_stream_unary( + self._channel, group, method, timeout, protocol_options, metadata, + self._metadata_transformer, request_iterator, request_serializer, + response_deserializer) + + def inline_stream_stream( + self, group, method, request_iterator, timeout, metadata=None, + protocol_options=None): + request_serializer = self._request_serializers.get((group, method,)) + response_deserializer = self._response_deserializers.get((group, method,)) + return _stream_stream( + self._channel, group, method, timeout, protocol_options, metadata, + self._metadata_transformer, request_iterator, request_serializer, + response_deserializer) + + def event_unary_unary( + self, group, method, request, receiver, abortion_callback, timeout, + metadata=None, protocol_options=None): + raise NotImplementedError() + + def event_unary_stream( + self, group, method, request, receiver, abortion_callback, timeout, + metadata=None, protocol_options=None): + raise NotImplementedError() + + def event_stream_unary( + self, group, method, receiver, abortion_callback, timeout, + metadata=None, protocol_options=None): + raise NotImplementedError() + + def event_stream_stream( + self, group, method, receiver, abortion_callback, timeout, + metadata=None, protocol_options=None): + raise NotImplementedError() + + def unary_unary(self, group, method): + request_serializer = self._request_serializers.get((group, method,)) + response_deserializer = self._response_deserializers.get((group, method,)) + return _UnaryUnaryMultiCallable( + self._channel, group, method, self._metadata_transformer, + request_serializer, response_deserializer) + + def unary_stream(self, group, method): + request_serializer = self._request_serializers.get((group, method,)) + response_deserializer = self._response_deserializers.get((group, method,)) + return _UnaryStreamMultiCallable( + self._channel, group, method, self._metadata_transformer, + request_serializer, response_deserializer) + + def stream_unary(self, group, method): + request_serializer = self._request_serializers.get((group, method,)) + response_deserializer = self._response_deserializers.get((group, method,)) + return _StreamUnaryMultiCallable( + self._channel, group, method, self._metadata_transformer, + request_serializer, response_deserializer) + + def stream_stream(self, group, method): + request_serializer = self._request_serializers.get((group, method,)) + response_deserializer = self._response_deserializers.get((group, method,)) + return _StreamStreamMultiCallable( + self._channel, group, method, self._metadata_transformer, + request_serializer, response_deserializer) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + return False + + +class _DynamicStub(face.DynamicStub): + + def __init__(self, generic_stub, group, cardinalities): + self._generic_stub = generic_stub + self._group = group + self._cardinalities = cardinalities + + def __getattr__(self, attr): + method_cardinality = self._cardinalities.get(attr) + if method_cardinality is cardinality.Cardinality.UNARY_UNARY: + return self._generic_stub.unary_unary(self._group, attr) + elif method_cardinality is cardinality.Cardinality.UNARY_STREAM: + return self._generic_stub.unary_stream(self._group, attr) + elif method_cardinality is cardinality.Cardinality.STREAM_UNARY: + return self._generic_stub.stream_unary(self._group, attr) + elif method_cardinality is cardinality.Cardinality.STREAM_STREAM: + return self._generic_stub.stream_stream(self._group, attr) + else: + raise AttributeError('_DynamicStub object has no attribute "%s"!' % attr) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + return False + + +def generic_stub( + channel, host, metadata_transformer, request_serializers, + response_deserializers): + return _GenericStub( + channel, metadata_transformer, request_serializers, + response_deserializers) + + +def dynamic_stub( + channel, service, cardinalities, host, metadata_transformer, + request_serializers, response_deserializers): + return _DynamicStub( + _GenericStub( + channel, metadata_transformer, request_serializers, + response_deserializers), + service, cardinalities) diff --git a/src/python/grpcio/grpc/beta/_server_adaptations.py b/src/python/grpcio/grpc/beta/_server_adaptations.py new file mode 100644 index 0000000000..52eadf2315 --- /dev/null +++ b/src/python/grpcio/grpc/beta/_server_adaptations.py @@ -0,0 +1,359 @@ +# Copyright 2016, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Translates gRPC's server-side API into gRPC's server-side Beta API.""" + +import collections +import threading + +import grpc +from grpc.beta import interfaces +from grpc.framework.common import cardinality +from grpc.framework.common import style +from grpc.framework.foundation import abandonment +from grpc.framework.foundation import logging_pool +from grpc.framework.foundation import stream +from grpc.framework.interfaces.face import face + +_DEFAULT_POOL_SIZE = 8 + + +class _ServerProtocolContext(interfaces.GRPCServicerContext): + + def __init__(self, servicer_context): + self._servicer_context = servicer_context + + def peer(self): + return self._servicer_context.peer() + + def disable_next_response_compression(self): + pass # TODO(https://github.com/grpc/grpc/issues/4078): design, implement. + + +class _FaceServicerContext(face.ServicerContext): + + def __init__(self, servicer_context): + self._servicer_context = servicer_context + + def is_active(self): + return self._servicer_context.is_active() + + def time_remaining(self): + return self._servicer_context.time_remaining() + + def add_abortion_callback(self, abortion_callback): + raise NotImplementedError( + 'add_abortion_callback no longer supported server-side!') + + def cancel(self): + self._servicer_context.cancel() + + def protocol_context(self): + return _ServerProtocolContext(self._servicer_context) + + def invocation_metadata(self): + return self._servicer_context.invocation_metadata() + + def initial_metadata(self, initial_metadata): + self._servicer_context.send_initial_metadata(initial_metadata) + + def terminal_metadata(self, terminal_metadata): + self._servicer_context.set_terminal_metadata(terminal_metadata) + + def code(self, code): + self._servicer_context.set_code(code) + + def details(self, details): + self._servicer_context.set_details(details) + + +def _adapt_unary_request_inline(unary_request_inline): + def adaptation(request, servicer_context): + return unary_request_inline(request, _FaceServicerContext(servicer_context)) + return adaptation + + +def _adapt_stream_request_inline(stream_request_inline): + def adaptation(request_iterator, servicer_context): + return stream_request_inline( + request_iterator, _FaceServicerContext(servicer_context)) + return adaptation + + +class _Callback(stream.Consumer): + + def __init__(self): + self._condition = threading.Condition() + self._values = [] + self._terminated = False + self._cancelled = False + + def consume(self, value): + with self._condition: + self._values.append(value) + self._condition.notify_all() + + def terminate(self): + with self._condition: + self._terminated = True + self._condition.notify_all() + + def consume_and_terminate(self, value): + with self._condition: + self._values.append(value) + self._terminated = True + self._condition.notify_all() + + def cancel(self): + with self._condition: + self._cancelled = True + self._condition.notify_all() + + def draw_one_value(self): + with self._condition: + while True: + if self._cancelled: + raise abandonment.Abandoned() + elif self._values: + return self._values.pop(0) + elif self._terminated: + return None + else: + self._condition.wait() + + def draw_all_values(self): + with self._condition: + while True: + if self._cancelled: + raise abandonment.Abandoned() + elif self._terminated: + all_values = tuple(self._values) + self._values = None + return all_values + else: + self._condition.wait() + + +def _pipe_requests(request_iterator, request_consumer, servicer_context): + for request in request_iterator: + if not servicer_context.is_active(): + return + request_consumer.consume(request) + if not servicer_context.is_active(): + return + request_consumer.terminate() + + +def _adapt_unary_unary_event(unary_unary_event): + def adaptation(request, servicer_context): + callback = _Callback() + if not servicer_context.add_callback(callback.cancel): + raise abandonment.Abandoned() + unary_unary_event( + request, callback.consume_and_terminate, + _FaceServicerContext(servicer_context)) + return callback.draw_all_values()[0] + return adaptation + + +def _adapt_unary_stream_event(unary_stream_event): + def adaptation(request, servicer_context): + callback = _Callback() + if not servicer_context.add_callback(callback.cancel): + raise abandonment.Abandoned() + unary_stream_event( + request, callback, _FaceServicerContext(servicer_context)) + while True: + response = callback.draw_one_value() + if response is None: + return + else: + yield response + return adaptation + + +def _adapt_stream_unary_event(stream_unary_event): + def adaptation(request_iterator, servicer_context): + callback = _Callback() + if not servicer_context.add_callback(callback.cancel): + raise abandonment.Abandoned() + request_consumer = stream_unary_event( + callback.consume_and_terminate, _FaceServicerContext(servicer_context)) + request_pipe_thread = threading.Thread( + target=_pipe_requests, + args=(request_iterator, request_consumer, servicer_context,)) + request_pipe_thread.start() + return callback.draw_all_values()[0] + return adaptation + + +def _adapt_stream_stream_event(stream_stream_event): + def adaptation(request_iterator, servicer_context): + callback = _Callback() + if not servicer_context.add_callback(callback.cancel): + raise abandonment.Abandoned() + request_consumer = stream_stream_event( + callback, _FaceServicerContext(servicer_context)) + request_pipe_thread = threading.Thread( + target=_pipe_requests, + args=(request_iterator, request_consumer, servicer_context,)) + request_pipe_thread.start() + while True: + response = callback.draw_one_value() + if response is None: + return + else: + yield response + return adaptation + + +class _SimpleMethodHandler( + collections.namedtuple( + '_MethodHandler', + ('request_streaming', 'response_streaming', 'request_deserializer', + 'response_serializer', 'unary_unary', 'unary_stream', 'stream_unary', + 'stream_stream',)), + grpc.RpcMethodHandler): + pass + + +def _simple_method_handler( + implementation, request_deserializer, response_serializer): + if implementation.style is style.Service.INLINE: + if implementation.cardinality is cardinality.Cardinality.UNARY_UNARY: + return _SimpleMethodHandler( + False, False, request_deserializer, response_serializer, + _adapt_unary_request_inline(implementation.unary_unary_inline), None, + None, None) + elif implementation.cardinality is cardinality.Cardinality.UNARY_STREAM: + return _SimpleMethodHandler( + False, True, request_deserializer, response_serializer, None, + _adapt_unary_request_inline(implementation.unary_stream_inline), None, + None) + elif implementation.cardinality is cardinality.Cardinality.STREAM_UNARY: + return _SimpleMethodHandler( + True, False, request_deserializer, response_serializer, None, None, + _adapt_stream_request_inline(implementation.stream_unary_inline), + None) + elif implementation.cardinality is cardinality.Cardinality.STREAM_STREAM: + return _SimpleMethodHandler( + True, True, request_deserializer, response_serializer, None, None, + None, + _adapt_stream_request_inline(implementation.stream_stream_inline)) + elif implementation.style is style.Service.EVENT: + if implementation.cardinality is cardinality.Cardinality.UNARY_UNARY: + return _SimpleMethodHandler( + False, False, request_deserializer, response_serializer, + _adapt_unary_unary_event(implementation.unary_unary_event), None, + None, None) + elif implementation.cardinality is cardinality.Cardinality.UNARY_STREAM: + return _SimpleMethodHandler( + False, True, request_deserializer, response_serializer, None, + _adapt_unary_stream_event(implementation.unary_stream_event), None, + None) + elif implementation.cardinality is cardinality.Cardinality.STREAM_UNARY: + return _SimpleMethodHandler( + True, False, request_deserializer, response_serializer, None, None, + _adapt_stream_unary_event(implementation.stream_unary_event), None) + elif implementation.cardinality is cardinality.Cardinality.STREAM_STREAM: + return _SimpleMethodHandler( + True, True, request_deserializer, response_serializer, None, None, + None, _adapt_stream_stream_event(implementation.stream_stream_event)) + + +class _GenericRpcHandler(grpc.GenericRpcHandler): + + def __init__( + self, method_implementations, multi_method_implementation, + request_deserializers, response_serializers): + self._method_implementations = method_implementations + self._multi_method_implementation = multi_method_implementation + self._request_deserializers = request_deserializers or {} + self._response_serializers = response_serializers or {} + + def service(self, handler_call_details): + try: + group_name, method_name = handler_call_details.method.split(b'/')[1:3] + except ValueError: + return None + else: + method_implementation = self._method_implementations.get( + (group_name, method_name,)) + if method_implementation is not None: + return _simple_method_handler( + method_implementation, + self._request_deserializers.get((group_name, method_name,)), + self._response_serializers.get((group_name, method_name,))) + elif self._multi_method_implementation is None: + return None + else: + try: + return None #TODO(nathaniel): call the multimethod. + except face.NoSuchMethodError: + return None + + +class _Server(interfaces.Server): + + def __init__(self, server): + self._server = server + + def add_insecure_port(self, address): + return self._server.add_insecure_port(address) + + def add_secure_port(self, address, server_credentials): + return self._server.add_secure_port(address, server_credentials) + + def start(self): + self._server.start() + + def stop(self, grace): + return self._server.stop(grace) + + def __enter__(self): + self._server.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self._server.stop(None) + return False + + +def server( + service_implementations, multi_method_implementation, request_deserializers, + response_serializers, thread_pool, thread_pool_size): + generic_rpc_handler = _GenericRpcHandler( + service_implementations, multi_method_implementation, + request_deserializers, response_serializers) + if thread_pool is None: + effective_thread_pool = logging_pool.pool( + _DEFAULT_POOL_SIZE if thread_pool_size is None else thread_pool_size) + else: + effective_thread_pool = thread_pool + return _Server(grpc.server((generic_rpc_handler,), effective_thread_pool)) diff --git a/src/python/grpcio/grpc/beta/implementations.py b/src/python/grpcio/grpc/beta/implementations.py index 822f593323..4ae6e7d675 100644 --- a/src/python/grpcio/grpc/beta/implementations.py +++ b/src/python/grpcio/grpc/beta/implementations.py @@ -35,112 +35,36 @@ import enum import threading # pylint: disable=unused-import # cardinality and face are referenced from specification in this module. -from grpc._adapter import _intermediary_low -from grpc._adapter import _low +import grpc +from grpc import _auth from grpc._adapter import _types -from grpc.beta import _connectivity_channel -from grpc.beta import _server -from grpc.beta import _stub +from grpc.beta import _client_adaptations +from grpc.beta import _server_adaptations from grpc.beta import interfaces from grpc.framework.common import cardinality # pylint: disable=unused-import from grpc.framework.interfaces.face import face # pylint: disable=unused-import -_CHANNEL_SUBSCRIPTION_CALLBACK_ERROR_LOG_MESSAGE = ( - 'Exception calling channel subscription callback!') +ChannelCredentials = grpc.ChannelCredentials +ssl_channel_credentials = grpc.ssl_channel_credentials +CallCredentials = grpc.CallCredentials +metadata_call_credentials = grpc.metadata_call_credentials -class ChannelCredentials(object): - """A value encapsulating the data required to create a secure Channel. - This class and its instances have no supported interface - it exists to define - the type of its instances and its instances exist to be passed to other - functions. - """ - - def __init__(self, low_credentials): - self._low_credentials = low_credentials - - -def ssl_channel_credentials(root_certificates=None, private_key=None, - certificate_chain=None): - """Creates a ChannelCredentials for use with an SSL-enabled Channel. - - Args: - root_certificates: The PEM-encoded root certificates or unset to ask for - them to be retrieved from a default location. - private_key: The PEM-encoded private key to use or unset if no private key - should be used. - certificate_chain: The PEM-encoded certificate chain to use or unset if no - certificate chain should be used. - - Returns: - A ChannelCredentials for use with an SSL-enabled Channel. - """ - return ChannelCredentials(_low.channel_credentials_ssl( - root_certificates, private_key, certificate_chain)) - - -class CallCredentials(object): - """A value encapsulating data asserting an identity over an *established* - channel. May be composed with ChannelCredentials to always assert identity for - every call over that channel. - - This class and its instances have no supported interface - it exists to define - the type of its instances and its instances exist to be passed to other - functions. - """ - - def __init__(self, low_credentials): - self._low_credentials = low_credentials - - -def metadata_call_credentials(metadata_plugin, name=None): - """Construct CallCredentials from an interfaces.GRPCAuthMetadataPlugin. - - Args: - metadata_plugin: An interfaces.GRPCAuthMetadataPlugin to use in constructing - the CallCredentials object. - - Returns: - A CallCredentials object for use in a GRPCCallOptions object. - """ - if name is None: - name = metadata_plugin.__name__ - return CallCredentials( - _low.call_credentials_metadata_plugin(metadata_plugin, name)) - -def composite_call_credentials(call_credentials, additional_call_credentials): - """Compose two CallCredentials to make a new one. +def google_call_credentials(credentials): + """Construct CallCredentials from GoogleCredentials. Args: - call_credentials: A CallCredentials object. - additional_call_credentials: Another CallCredentials object to compose on - top of call_credentials. + credentials: A GoogleCredentials object from the oauth2client library. Returns: A CallCredentials object for use in a GRPCCallOptions object. """ - return CallCredentials( - _low.call_credentials_composite( - call_credentials._low_credentials, - additional_call_credentials._low_credentials)) + return metadata_call_credentials(_auth.GoogleCallCredentials(credentials)) -def composite_channel_credentials(channel_credentials, - additional_call_credentials): - """Compose ChannelCredentials on top of client credentials to make a new one. - - Args: - channel_credentials: A ChannelCredentials object. - additional_call_credentials: A CallCredentials object to compose on - top of channel_credentials. - - Returns: - A ChannelCredentials object for use in a GRPCCallOptions object. - """ - return ChannelCredentials( - _low.channel_credentials_composite( - channel_credentials._low_credentials, - additional_call_credentials._low_credentials)) +access_token_call_credentials = grpc.access_token_call_credentials +composite_call_credentials = grpc.composite_call_credentials +composite_channel_credentials = grpc.composite_channel_credentials class Channel(object): @@ -151,11 +75,8 @@ class Channel(object): unsupported. """ - def __init__(self, low_channel, intermediary_low_channel): - self._low_channel = low_channel - self._intermediary_low_channel = intermediary_low_channel - self._connectivity_channel = _connectivity_channel.ConnectivityChannel( - low_channel) + def __init__(self, channel): + self._channel = channel def subscribe(self, callback, try_to_connect=None): """Subscribes to this Channel's connectivity. @@ -170,7 +91,7 @@ class Channel(object): attempt to connect if it is not already connected and ready to conduct RPCs. """ - self._connectivity_channel.subscribe(callback, try_to_connect) + self._channel.subscribe(callback, try_to_connect=try_to_connect) def unsubscribe(self, callback): """Unsubscribes a callback from this Channel's connectivity. @@ -179,7 +100,7 @@ class Channel(object): callback: A callable previously registered with this Channel from having been passed to its "subscribe" method. """ - self._connectivity_channel.unsubscribe(callback) + self._channel.unsubscribe(callback) def insecure_channel(host, port): @@ -193,9 +114,9 @@ def insecure_channel(host, port): Returns: A Channel to the remote host through which RPCs may be conducted. """ - intermediary_low_channel = _intermediary_low.Channel( - '%s:%d' % (host, port) if port else host, None) - return Channel(intermediary_low_channel._internal, intermediary_low_channel) # pylint: disable=protected-access + channel = grpc.insecure_channel( + host if port is None else '%s:%d' % (host, port)) + return Channel(channel) def secure_channel(host, port, channel_credentials): @@ -210,10 +131,9 @@ def secure_channel(host, port, channel_credentials): Returns: A secure Channel to the remote host through which RPCs may be conducted. """ - intermediary_low_channel = _intermediary_low.Channel( - '%s:%d' % (host, port) if port else host, - channel_credentials._low_credentials) - return Channel(intermediary_low_channel._internal, intermediary_low_channel) # pylint: disable=protected-access + channel = grpc.secure_channel( + host if port is None else '%s:%d' % (host, port), channel_credentials) + return Channel(channel) class StubOptions(object): @@ -277,12 +197,11 @@ def generic_stub(channel, options=None): A face.GenericStub on which RPCs can be made. """ effective_options = _EMPTY_STUB_OPTIONS if options is None else options - return _stub.generic_stub( - channel._intermediary_low_channel, effective_options.host, # pylint: disable=protected-access - effective_options.metadata_transformer, + return _client_adaptations.generic_stub( + channel._channel, # pylint: disable=protected-access + effective_options.host, effective_options.metadata_transformer, effective_options.request_serializers, - effective_options.response_deserializers, effective_options.thread_pool, - effective_options.thread_pool_size) + effective_options.response_deserializers) def dynamic_stub(channel, service, cardinalities, options=None): @@ -300,55 +219,16 @@ def dynamic_stub(channel, service, cardinalities, options=None): A face.DynamicStub with which RPCs can be invoked. """ effective_options = StubOptions() if options is None else options - return _stub.dynamic_stub( - channel._intermediary_low_channel, effective_options.host, service, # pylint: disable=protected-access - cardinalities, effective_options.metadata_transformer, + return _client_adaptations.dynamic_stub( + channel._channel, # pylint: disable=protected-access + service, cardinalities, effective_options.host, + effective_options.metadata_transformer, effective_options.request_serializers, - effective_options.response_deserializers, effective_options.thread_pool, - effective_options.thread_pool_size) - + effective_options.response_deserializers) -class ServerCredentials(object): - """A value encapsulating the data required to open a secure port on a Server. - This class and its instances have no supported interface - it exists to define - the type of its instances and its instances exist to be passed to other - functions. - """ - - def __init__(self, low_credentials): - self._low_credentials = low_credentials - - -def ssl_server_credentials( - private_key_certificate_chain_pairs, root_certificates=None, - require_client_auth=False): - """Creates a ServerCredentials for use with an SSL-enabled Server. - - Args: - private_key_certificate_chain_pairs: A nonempty sequence each element of - which is a pair the first element of which is a PEM-encoded private key - and the second element of which is the corresponding PEM-encoded - certificate chain. - root_certificates: PEM-encoded client root certificates to be used for - verifying authenticated clients. If omitted, require_client_auth must also - be omitted or be False. - require_client_auth: A boolean indicating whether or not to require clients - to be authenticated. May only be True if root_certificates is not None. - - Returns: - A ServerCredentials for use with an SSL-enabled Server. - """ - if len(private_key_certificate_chain_pairs) == 0: - raise ValueError( - 'At least one private key-certificate chain pairis required!') - elif require_client_auth and root_certificates is None: - raise ValueError( - 'Illegal to require client auth without providing root certificates!') - else: - return ServerCredentials(_low.server_credentials_ssl( - root_certificates, private_key_certificate_chain_pairs, - require_client_auth)) +ServerCredentials = grpc.ServerCredentials +ssl_server_credentials = grpc.ssl_server_credentials class ServerOptions(object): @@ -421,9 +301,8 @@ def server(service_implementations, options=None): An interfaces.Server with which RPCs can be serviced. """ effective_options = _EMPTY_SERVER_OPTIONS if options is None else options - return _server.server( + return _server_adaptations.server( service_implementations, effective_options.multi_method_implementation, effective_options.request_deserializers, effective_options.response_serializers, effective_options.thread_pool, - effective_options.thread_pool_size, effective_options.default_timeout, - effective_options.maximum_timeout) + effective_options.thread_pool_size) diff --git a/src/python/grpcio/grpc/beta/interfaces.py b/src/python/grpcio/grpc/beta/interfaces.py index 24de9ad1a8..4343b6c4b5 100644 --- a/src/python/grpcio/grpc/beta/interfaces.py +++ b/src/python/grpcio/grpc/beta/interfaces.py @@ -30,53 +30,13 @@ """Constants and interfaces of the Beta API of gRPC Python.""" import abc -import enum import six -from grpc._adapter import _types +import grpc - -@enum.unique -class ChannelConnectivity(enum.Enum): - """Mirrors grpc_connectivity_state in the gRPC Core. - - Attributes: - IDLE: The channel is idle. - CONNECTING: The channel is connecting. - READY: The channel is ready to conduct RPCs. - TRANSIENT_FAILURE: The channel has seen a failure from which it expects to - recover. - FATAL_FAILURE: The channel has seen a failure from which it cannot recover. - """ - IDLE = (_types.ConnectivityState.IDLE, 'idle',) - CONNECTING = (_types.ConnectivityState.CONNECTING, 'connecting',) - READY = (_types.ConnectivityState.READY, 'ready',) - TRANSIENT_FAILURE = ( - _types.ConnectivityState.TRANSIENT_FAILURE, 'transient failure',) - FATAL_FAILURE = (_types.ConnectivityState.FATAL_FAILURE, 'fatal failure',) - - -@enum.unique -class StatusCode(enum.Enum): - """Mirrors grpc_status_code in the C core.""" - OK = 0 - CANCELLED = 1 - UNKNOWN = 2 - INVALID_ARGUMENT = 3 - DEADLINE_EXCEEDED = 4 - NOT_FOUND = 5 - ALREADY_EXISTS = 6 - PERMISSION_DENIED = 7 - RESOURCE_EXHAUSTED = 8 - FAILED_PRECONDITION = 9 - ABORTED = 10 - OUT_OF_RANGE = 11 - UNIMPLEMENTED = 12 - INTERNAL = 13 - UNAVAILABLE = 14 - DATA_LOSS = 15 - UNAUTHENTICATED = 16 +ChannelConnectivity = grpc.ChannelConnectivity +StatusCode = grpc.StatusCode class GRPCCallOptions(object): @@ -106,46 +66,9 @@ def grpc_call_options(disable_compression=False, credentials=None): """ return GRPCCallOptions(disable_compression, None, credentials) - -class GRPCAuthMetadataContext(six.with_metaclass(abc.ABCMeta)): - """Provides information to call credentials metadata plugins. - - Attributes: - service_url: A string URL of the service being called into. - method_name: A string of the fully qualified method name being called. - """ - - -class GRPCAuthMetadataPluginCallback(six.with_metaclass(abc.ABCMeta)): - """Callback object received by a metadata plugin.""" - - def __call__(self, metadata, error): - """Inform the gRPC runtime of the metadata to construct a CallCredentials. - - Args: - metadata: An iterable of 2-sequences (e.g. tuples) of metadata key/value - pairs. - error: An Exception to indicate error or None to indicate success. - """ - raise NotImplementedError() - - -class GRPCAuthMetadataPlugin(six.with_metaclass(abc.ABCMeta)): - """ - """ - - def __call__(self, context, callback): - """Invoke the plugin. - - Must not block. Need only be called by the gRPC runtime. - - Args: - context: A GRPCAuthMetadataContext providing information on what the - plugin is being used for. - callback: A GRPCAuthMetadataPluginCallback to be invoked either - synchronously or asynchronously. - """ - raise NotImplementedError() +GRPCAuthMetadataContext = grpc.AuthMetadataContext +GRPCAuthMetadataPluginCallback = grpc.AuthMetadataPluginCallback +GRPCAuthMetadataPlugin = grpc.AuthMetadataPlugin class GRPCServicerContext(six.with_metaclass(abc.ABCMeta)): diff --git a/src/python/grpcio/grpc_core_dependencies.py b/src/python/grpcio/grpc_core_dependencies.py index c7d042d565..b668455b72 100644 --- a/src/python/grpcio/grpc_core_dependencies.py +++ b/src/python/grpcio/grpc_core_dependencies.py @@ -42,7 +42,7 @@ CORE_SOURCE_FILES = [ 'src/core/lib/support/cpu_windows.c', 'src/core/lib/support/env_linux.c', 'src/core/lib/support/env_posix.c', - 'src/core/lib/support/env_win32.c', + 'src/core/lib/support/env_windows.c', 'src/core/lib/support/histogram.c', 'src/core/lib/support/host_port.c', 'src/core/lib/support/load_file.c', @@ -50,31 +50,31 @@ CORE_SOURCE_FILES = [ 'src/core/lib/support/log_android.c', 'src/core/lib/support/log_linux.c', 'src/core/lib/support/log_posix.c', - 'src/core/lib/support/log_win32.c', + 'src/core/lib/support/log_windows.c', 'src/core/lib/support/murmur_hash.c', 'src/core/lib/support/slice.c', 'src/core/lib/support/slice_buffer.c', 'src/core/lib/support/stack_lockfree.c', 'src/core/lib/support/string.c', 'src/core/lib/support/string_posix.c', - 'src/core/lib/support/string_util_win32.c', - 'src/core/lib/support/string_win32.c', + 'src/core/lib/support/string_util_windows.c', + 'src/core/lib/support/string_windows.c', 'src/core/lib/support/subprocess_posix.c', 'src/core/lib/support/subprocess_windows.c', 'src/core/lib/support/sync.c', 'src/core/lib/support/sync_posix.c', - 'src/core/lib/support/sync_win32.c', + 'src/core/lib/support/sync_windows.c', 'src/core/lib/support/thd.c', 'src/core/lib/support/thd_posix.c', - 'src/core/lib/support/thd_win32.c', + 'src/core/lib/support/thd_windows.c', 'src/core/lib/support/time.c', 'src/core/lib/support/time_posix.c', 'src/core/lib/support/time_precise.c', - 'src/core/lib/support/time_win32.c', + 'src/core/lib/support/time_windows.c', 'src/core/lib/support/tls_pthread.c', 'src/core/lib/support/tmpfile_msys.c', 'src/core/lib/support/tmpfile_posix.c', - 'src/core/lib/support/tmpfile_win32.c', + 'src/core/lib/support/tmpfile_windows.c', 'src/core/lib/support/wrap_memcpy.c', 'src/core/lib/surface/init.c', 'src/core/lib/channel/channel_args.c', @@ -189,7 +189,7 @@ CORE_SOURCE_FILES = [ 'src/core/lib/security/credentials/credentials_metadata.c', 'src/core/lib/security/credentials/fake/fake_credentials.c', 'src/core/lib/security/credentials/google_default/credentials_posix.c', - 'src/core/lib/security/credentials/google_default/credentials_win32.c', + 'src/core/lib/security/credentials/google_default/credentials_windows.c', 'src/core/lib/security/credentials/google_default/google_default_credentials.c', 'src/core/lib/security/credentials/iam/iam_credentials.c', 'src/core/lib/security/credentials/jwt/json_token.c', @@ -231,9 +231,6 @@ CORE_SOURCE_FILES = [ 'src/core/ext/client_config/uri_parser.c', 'src/core/ext/transport/chttp2/server/insecure/server_chttp2.c', 'src/core/ext/transport/chttp2/client/insecure/channel_create.c', - 'src/core/ext/transport/cronet/client/secure/cronet_channel_create.c', - 'src/core/ext/transport/cronet/transport/cronet_api_dummy.c', - 'src/core/ext/transport/cronet/transport/cronet_transport.c', 'src/core/ext/lb_policy/grpclb/load_balancer_api.c', 'src/core/ext/lb_policy/grpclb/proto/grpc/lb/v1/load_balancer.pb.c', 'third_party/nanopb/pb_common.c', @@ -243,6 +240,8 @@ CORE_SOURCE_FILES = [ 'src/core/ext/lb_policy/round_robin/round_robin.c', 'src/core/ext/resolver/dns/native/dns_resolver.c', 'src/core/ext/resolver/sockaddr/sockaddr_resolver.c', + 'src/core/ext/load_reporting/load_reporting.c', + 'src/core/ext/load_reporting/load_reporting_filter.c', 'src/core/ext/census/base_resources.c', 'src/core/ext/census/context.c', 'src/core/ext/census/gen/census.pb.c', diff --git a/src/python/grpcio/tests/interop/client.py b/src/python/grpcio/tests/interop/client.py index db29eb4aa7..e3d5545a02 100644 --- a/src/python/grpcio/tests/interop/client.py +++ b/src/python/grpcio/tests/interop/client.py @@ -65,39 +65,34 @@ def _args(): help='email address of the default service account', type=str) return parser.parse_args() -def _oauth_access_token(args): - credentials = oauth2client_client.GoogleCredentials.get_application_default() - scoped_credentials = credentials.create_scoped([args.oauth_scope]) - return scoped_credentials.get_access_token().access_token def _stub(args): - if args.oauth_scope: - if args.test_case == 'oauth2_auth_token': - # TODO(jtattermusch): This testcase sets the auth metadata key-value - # manually, which also means that the user would need to do the same - # thing every time he/she would like to use and out of band oauth token. - # The transformer function that produces the metadata key-value from - # the access token should be provided by gRPC auth library. - access_token = _oauth_access_token(args) - metadata_transformer = lambda x: [ - ('authorization', 'Bearer %s' % access_token)] - else: - metadata_transformer = lambda x: [ - ('authorization', 'Bearer %s' % _oauth_access_token(args))] + if args.test_case == 'oauth2_auth_token': + creds = oauth2client_client.GoogleCredentials.get_application_default() + scoped_creds = creds.create_scoped([args.oauth_scope]) + access_token = scoped_creds.get_access_token().access_token + call_creds = implementations.access_token_call_credentials(access_token) + elif args.test_case == 'compute_engine_creds': + creds = oauth2client_client.GoogleCredentials.get_application_default() + scoped_creds = creds.create_scoped([args.oauth_scope]) + call_creds = implementations.google_call_credentials(scoped_creds) else: - metadata_transformer = lambda x: [] + call_creds = None if args.use_tls: if args.use_test_ca: root_certificates = resources.test_root_certificates() else: root_certificates = None # will load default roots. + channel_creds = implementations.ssl_channel_credentials(root_certificates) + if call_creds is not None: + channel_creds = implementations.composite_channel_credentials( + channel_creds, call_creds) + channel = test_utilities.not_really_secure_channel( - args.server_host, args.server_port, - implementations.ssl_channel_credentials(root_certificates), + args.server_host, args.server_port, channel_creds, args.server_host_override) - stub = test_pb2.beta_create_TestService_stub( - channel, metadata_transformer=metadata_transformer) + stub = test_pb2.beta_create_TestService_stub(channel) else: channel = implementations.insecure_channel( args.server_host, args.server_port) diff --git a/src/python/grpcio/tests/interop/methods.py b/src/python/grpcio/tests/interop/methods.py index 67862ed7d3..d5ef0c68bb 100644 --- a/src/python/grpcio/tests/interop/methods.py +++ b/src/python/grpcio/tests/interop/methods.py @@ -39,6 +39,8 @@ import time from oauth2client import client as oauth2client_client +from grpc.beta import implementations +from grpc.beta import interfaces from grpc.framework.common import cardinality from grpc.framework.interfaces.face import face @@ -88,13 +90,15 @@ class TestService(test_pb2.BetaTestServiceServicer): return self.FullDuplexCall(request_iterator, context) -def _large_unary_common_behavior(stub, fill_username, fill_oauth_scope): +def _large_unary_common_behavior(stub, fill_username, fill_oauth_scope, + protocol_options=None): with stub: request = messages_pb2.SimpleRequest( response_type=messages_pb2.COMPRESSABLE, response_size=314159, payload=messages_pb2.Payload(body=b'\x00' * 271828), fill_username=fill_username, fill_oauth_scope=fill_oauth_scope) - response_future = stub.UnaryCall.future(request, _TIMEOUT) + response_future = stub.UnaryCall.future(request, _TIMEOUT, + protocol_options=protocol_options) response = response_future.result() if response.payload.type is not messages_pb2.COMPRESSABLE: raise ValueError( @@ -303,7 +307,24 @@ def _oauth2_auth_token(stub, args): if args.oauth_scope.find(response.oauth_scope) == -1: raise ValueError( 'expected to find oauth scope "%s" in received "%s"' % - (response.oauth_scope, args.oauth_scope)) + (response.oauth_scope, args.oauth_scope)) + + +def _per_rpc_creds(stub, args): + json_key_filename = os.environ[ + oauth2client_client.GOOGLE_APPLICATION_CREDENTIALS] + wanted_email = json.load(open(json_key_filename, 'rb'))['client_email'] + credentials = oauth2client_client.GoogleCredentials.get_application_default() + scoped_credentials = credentials.create_scoped([args.oauth_scope]) + call_creds = implementations.google_call_credentials(scoped_credentials) + options = interfaces.grpc_call_options(disable_compression=False, + credentials=call_creds) + response = _large_unary_common_behavior(stub, True, False, + protocol_options=options) + if wanted_email != response.username: + raise ValueError( + 'expected username %s, got %s' % (wanted_email, response.username)) + @enum.unique class TestCase(enum.Enum): @@ -317,6 +338,7 @@ class TestCase(enum.Enum): EMPTY_STREAM = 'empty_stream' COMPUTE_ENGINE_CREDS = 'compute_engine_creds' OAUTH2_AUTH_TOKEN = 'oauth2_auth_token' + PER_RPC_CREDS = 'per_rpc_creds' TIMEOUT_ON_SLEEPING_SERVER = 'timeout_on_sleeping_server' def test_interoperability(self, stub, args): @@ -342,5 +364,7 @@ class TestCase(enum.Enum): _compute_engine_creds(stub, args) elif self is TestCase.OAUTH2_AUTH_TOKEN: _oauth2_auth_token(stub, args) + elif self is TestCase.PER_RPC_CREDS: + _per_rpc_creds(stub, args) else: raise NotImplementedError('Test case "%s" not implemented!' % self.name) diff --git a/src/python/grpcio/tests/tests.json b/src/python/grpcio/tests/tests.json index fb357ea848..8dc47bf69d 100644 --- a/src/python/grpcio/tests/tests.json +++ b/src/python/grpcio/tests/tests.json @@ -1,4 +1,6 @@ [ + "_auth_test.AccessTokenCallCredentialsTest", + "_auth_test.GoogleCallCredentialsTest", "_base_interface_test.AsyncEasyTest", "_base_interface_test.AsyncPeasyTest", "_base_interface_test.SyncEasyTest", @@ -33,6 +35,7 @@ "_face_interface_test.MultiCallableInvokerBlockingInvocationInlineServiceTest", "_face_interface_test.MultiCallableInvokerFutureInvocationAsynchronousEventServiceTest", "_health_servicer_test.HealthServicerTest", + "_implementations_test.CallCredentialsTest", "_implementations_test.ChannelCredentialsTest", "_insecure_interop_test.InsecureInteropTest", "_intermediary_low_test.CancellationTest", @@ -45,6 +48,7 @@ "_low_test.HangingServerShutdown", "_low_test.InsecureServerInsecureClient", "_not_found_test.NotFoundTest", + "_read_some_but_not_all_responses_test.ReadSomeButNotAllResponsesTest", "_rpc_test.RPCTest", "_sanity_test.Sanity", "_secure_interop_test.SecureInteropTest", diff --git a/src/python/grpcio/tests/unit/_auth_test.py b/src/python/grpcio/tests/unit/_auth_test.py new file mode 100644 index 0000000000..c31f7b06f7 --- /dev/null +++ b/src/python/grpcio/tests/unit/_auth_test.py @@ -0,0 +1,96 @@ +# Copyright 2016, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests of standard AuthMetadataPlugins.""" + +import collections +import threading +import unittest + +from grpc import _auth + + +class MockGoogleCreds(object): + + def get_access_token(self): + token = collections.namedtuple('MockAccessTokenInfo', + ('access_token', 'expires_in')) + token.access_token = 'token' + return token + + +class MockExceptionGoogleCreds(object): + + def get_access_token(self): + raise Exception() + + +class GoogleCallCredentialsTest(unittest.TestCase): + + def test_google_call_credentials_success(self): + callback_event = threading.Event() + + def mock_callback(metadata, error): + self.assertEqual(metadata, (('authorization', 'Bearer token'),)) + self.assertIsNone(error) + callback_event.set() + + call_creds = _auth.GoogleCallCredentials(MockGoogleCreds()) + call_creds(None, mock_callback) + self.assertTrue(callback_event.wait(1.0)) + + def test_google_call_credentials_error(self): + callback_event = threading.Event() + + def mock_callback(metadata, error): + self.assertIsNotNone(error) + callback_event.set() + + call_creds = _auth.GoogleCallCredentials(MockExceptionGoogleCreds()) + call_creds(None, mock_callback) + self.assertTrue(callback_event.wait(1.0)) + + +class AccessTokenCallCredentialsTest(unittest.TestCase): + + def test_google_call_credentials_success(self): + callback_event = threading.Event() + + def mock_callback(metadata, error): + self.assertEqual(metadata, (('authorization', 'Bearer token'),)) + self.assertIsNone(error) + callback_event.set() + + call_creds = _auth.AccessTokenCallCredentials('token') + call_creds(None, mock_callback) + self.assertTrue(callback_event.wait(1.0)) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/src/python/grpcio/tests/unit/_cython/_read_some_but_not_all_responses_test.py b/src/python/grpcio/tests/unit/_cython/_read_some_but_not_all_responses_test.py new file mode 100644 index 0000000000..6ae7a90fbe --- /dev/null +++ b/src/python/grpcio/tests/unit/_cython/_read_some_but_not_all_responses_test.py @@ -0,0 +1,251 @@ +# Copyright 2016, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Test a corner-case at the level of the Cython API.""" + +import threading +import unittest + +from grpc._cython import cygrpc + +_INFINITE_FUTURE = cygrpc.Timespec(float('+inf')) +_EMPTY_FLAGS = 0 +_EMPTY_METADATA = cygrpc.Metadata(()) + + +class _ServerDriver(object): + + def __init__(self, completion_queue, shutdown_tag): + self._condition = threading.Condition() + self._completion_queue = completion_queue + self._shutdown_tag = shutdown_tag + self._events = [] + self._saw_shutdown_tag = False + + def start(self): + def in_thread(): + while True: + event = self._completion_queue.poll() + with self._condition: + self._events.append(event) + self._condition.notify() + if event.tag is self._shutdown_tag: + self._saw_shutdown_tag = True + break + thread = threading.Thread(target=in_thread) + thread.start() + + def done(self): + with self._condition: + return self._saw_shutdown_tag + + def first_event(self): + with self._condition: + while not self._events: + self._condition.wait() + return self._events[0] + + def events(self): + with self._condition: + while not self._saw_shutdown_tag: + self._condition.wait() + return tuple(self._events) + + +class _QueueDriver(object): + + def __init__(self, condition, completion_queue, due): + self._condition = condition + self._completion_queue = completion_queue + self._due = due + self._events = [] + self._returned = False + + def start(self): + def in_thread(): + while True: + event = self._completion_queue.poll() + with self._condition: + self._events.append(event) + self._due.remove(event.tag) + self._condition.notify_all() + if not self._due: + self._returned = True + return + thread = threading.Thread(target=in_thread) + thread.start() + + def done(self): + with self._condition: + return self._returned + + def event_with_tag(self, tag): + with self._condition: + while True: + for event in self._events: + if event.tag is tag: + return event + self._condition.wait() + + def events(self): + with self._condition: + while not self._returned: + self._condition.wait() + return tuple(self._events) + + +class ReadSomeButNotAllResponsesTest(unittest.TestCase): + + def testReadSomeButNotAllResponses(self): + server_completion_queue = cygrpc.CompletionQueue() + server = cygrpc.Server() + server.register_completion_queue(server_completion_queue) + port = server.add_http2_port('[::]:0') + server.start() + channel = cygrpc.Channel('localhost:{}'.format(port)) + + server_shutdown_tag = 'server_shutdown_tag' + server_driver = _ServerDriver(server_completion_queue, server_shutdown_tag) + server_driver.start() + + client_condition = threading.Condition() + client_due = set() + client_completion_queue = cygrpc.CompletionQueue() + client_driver = _QueueDriver( + client_condition, client_completion_queue, client_due) + client_driver.start() + + server_call_condition = threading.Condition() + server_send_initial_metadata_tag = 'server_send_initial_metadata_tag' + server_send_first_message_tag = 'server_send_first_message_tag' + server_send_second_message_tag = 'server_send_second_message_tag' + server_complete_rpc_tag = 'server_complete_rpc_tag' + server_call_due = set(( + server_send_initial_metadata_tag, + server_send_first_message_tag, + server_send_second_message_tag, + server_complete_rpc_tag, + )) + server_call_completion_queue = cygrpc.CompletionQueue() + server_call_driver = _QueueDriver( + server_call_condition, server_call_completion_queue, server_call_due) + server_call_driver.start() + + server_rpc_tag = 'server_rpc_tag' + request_call_result = server.request_call( + server_call_completion_queue, server_completion_queue, server_rpc_tag) + + client_call = channel.create_call( + None, _EMPTY_FLAGS, client_completion_queue, b'/twinkies', None, + _INFINITE_FUTURE) + client_receive_initial_metadata_tag = 'client_receive_initial_metadata_tag' + client_complete_rpc_tag = 'client_complete_rpc_tag' + with client_condition: + client_receive_initial_metadata_start_batch_result = ( + client_call.start_batch(cygrpc.Operations([ + cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS), + ]), client_receive_initial_metadata_tag)) + client_due.add(client_receive_initial_metadata_tag) + client_complete_rpc_start_batch_result = ( + client_call.start_batch(cygrpc.Operations([ + cygrpc.operation_send_initial_metadata( + _EMPTY_METADATA, _EMPTY_FLAGS), + cygrpc.operation_send_close_from_client(_EMPTY_FLAGS), + cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS), + ]), client_complete_rpc_tag)) + client_due.add(client_complete_rpc_tag) + + server_rpc_event = server_driver.first_event() + + with server_call_condition: + server_send_initial_metadata_start_batch_result = ( + server_rpc_event.operation_call.start_batch(cygrpc.Operations([ + cygrpc.operation_send_initial_metadata( + _EMPTY_METADATA, _EMPTY_FLAGS), + ]), server_send_initial_metadata_tag)) + server_send_first_message_start_batch_result = ( + server_rpc_event.operation_call.start_batch(cygrpc.Operations([ + cygrpc.operation_send_message(b'\x07', _EMPTY_FLAGS), + ]), server_send_first_message_tag)) + server_send_initial_metadata_event = server_call_driver.event_with_tag( + server_send_initial_metadata_tag) + server_send_first_message_event = server_call_driver.event_with_tag( + server_send_first_message_tag) + with server_call_condition: + server_send_second_message_start_batch_result = ( + server_rpc_event.operation_call.start_batch(cygrpc.Operations([ + cygrpc.operation_send_message(b'\x07', _EMPTY_FLAGS), + ]), server_send_second_message_tag)) + server_complete_rpc_start_batch_result = ( + server_rpc_event.operation_call.start_batch(cygrpc.Operations([ + cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS), + cygrpc.operation_send_status_from_server( + cygrpc.Metadata(()), cygrpc.StatusCode.ok, b'test details', + _EMPTY_FLAGS), + ]), server_complete_rpc_tag)) + server_send_second_message_event = server_call_driver.event_with_tag( + server_send_second_message_tag) + server_complete_rpc_event = server_call_driver.event_with_tag( + server_complete_rpc_tag) + server_call_driver.events() + + with client_condition: + client_receive_first_message_tag = 'client_receive_first_message_tag' + client_receive_first_message_start_batch_result = ( + client_call.start_batch(cygrpc.Operations([ + cygrpc.operation_receive_message(_EMPTY_FLAGS), + ]), client_receive_first_message_tag)) + client_due.add(client_receive_first_message_tag) + client_receive_first_message_event = client_driver.event_with_tag( + client_receive_first_message_tag) + + client_call_cancel_result = client_call.cancel() + client_driver.events() + + server.shutdown(server_completion_queue, server_shutdown_tag) + server.cancel_all_calls() + server_driver.events() + + self.assertEqual(cygrpc.CallError.ok, request_call_result) + self.assertEqual( + cygrpc.CallError.ok, server_send_initial_metadata_start_batch_result) + self.assertEqual( + cygrpc.CallError.ok, client_receive_initial_metadata_start_batch_result) + self.assertEqual( + cygrpc.CallError.ok, client_complete_rpc_start_batch_result) + self.assertEqual(cygrpc.CallError.ok, client_call_cancel_result) + self.assertIs(server_rpc_tag, server_rpc_event.tag) + self.assertEqual( + cygrpc.CompletionType.operation_complete, server_rpc_event.type) + self.assertIsInstance(server_rpc_event.operation_call, cygrpc.Call) + self.assertEqual(0, len(server_rpc_event.batch_operations)) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/src/python/grpcio/tests/unit/beta/_implementations_test.py b/src/python/grpcio/tests/unit/beta/_implementations_test.py index 26be670c45..127f93e9bb 100644 --- a/src/python/grpcio/tests/unit/beta/_implementations_test.py +++ b/src/python/grpcio/tests/unit/beta/_implementations_test.py @@ -29,8 +29,11 @@ """Tests the implementations module of the gRPC Python Beta API.""" +import datetime import unittest +from oauth2client import client as oauth2client_client + from grpc.beta import implementations from tests.unit import resources @@ -49,5 +52,19 @@ class ChannelCredentialsTest(unittest.TestCase): channel_credentials, implementations.ChannelCredentials) +class CallCredentialsTest(unittest.TestCase): + + def test_google_call_credentials(self): + creds = oauth2client_client.GoogleCredentials( + 'token', 'client_id', 'secret', 'refresh_token', + datetime.datetime(2008, 6, 24), 'https://refresh.uri.com/', + 'user_agent') + call_creds = implementations.google_call_credentials(creds) + self.assertIsInstance(call_creds, implementations.CallCredentials) + + def test_access_token_call_credentials(self): + call_creds = implementations.access_token_call_credentials('token') + self.assertIsInstance(call_creds, implementations.CallCredentials) + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/src/python/grpcio/tests/unit/beta/test_utilities.py b/src/python/grpcio/tests/unit/beta/test_utilities.py index 0313e06a93..66b5f72897 100644 --- a/src/python/grpcio/tests/unit/beta/test_utilities.py +++ b/src/python/grpcio/tests/unit/beta/test_utilities.py @@ -29,7 +29,7 @@ """Test-appropriate entry points into the gRPC Python Beta API.""" -from grpc._adapter import _intermediary_low +import grpc from grpc.beta import implementations @@ -48,9 +48,8 @@ def not_really_secure_channel( An implementations.Channel to the remote host through which RPCs may be conducted. """ - hostport = '%s:%d' % (host, port) - intermediary_low_channel = _intermediary_low.Channel( - hostport, channel_credentials._low_credentials, - server_host_override=server_host_override) - return implementations.Channel( - intermediary_low_channel._internal, intermediary_low_channel) + target = '%s:%d' % (host, port) + channel = grpc.secure_channel( + target, ((b'grpc.ssl_target_name_override', server_host_override,),), + channel_credentials._credentials) + return implementations.Channel(channel) |