diff options
author | Masood Malekghassemi <soltanmm@users.noreply.github.com> | 2015-11-24 16:39:21 -0800 |
---|---|---|
committer | Masood Malekghassemi <soltanmm@users.noreply.github.com> | 2015-12-07 17:10:55 -0800 |
commit | 0f1bf3238709044d7d2575244e10193221361a84 (patch) | |
tree | a297a43240b70c5c4659d16aa5b526b57cc53da3 /src/python/grpcio | |
parent | 25ef5c8fad43124959d5d8d7586d5bd61dbb1194 (diff) |
Add metadata auth plugin API support
Diffstat (limited to 'src/python/grpcio')
18 files changed, 622 insertions, 65 deletions
diff --git a/src/python/grpcio/grpc/_adapter/_implementations.py b/src/python/grpcio/grpc/_adapter/_implementations.py new file mode 100644 index 0000000000..b85f228bf6 --- /dev/null +++ b/src/python/grpcio/grpc/_adapter/_implementations.py @@ -0,0 +1,48 @@ +# Copyright 2015, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import collections + +from grpc.beta import interfaces + +class AuthMetadataContext(collections.namedtuple( + 'AuthMetadataContext', [ + 'service_url', + 'method_name' + ]), interfaces.GRPCAuthMetadataContext): + pass + + +class AuthMetadataPluginCallback(interfaces.GRPCAuthMetadataContext): + + def __init__(self, callback): + self._callback = callback + + def __call__(self, metadata, error): + self._callback(metadata, error) diff --git a/src/python/grpcio/grpc/_adapter/_intermediary_low.py b/src/python/grpcio/grpc/_adapter/_intermediary_low.py index f87446eae1..9698ffeabf 100644 --- a/src/python/grpcio/grpc/_adapter/_intermediary_low.py +++ b/src/python/grpcio/grpc/_adapter/_intermediary_low.py @@ -173,20 +173,17 @@ class Call(object): return self._internal.peer() def set_credentials(self, creds): - return self._internal.set_credentials(creds._internal) + return self._internal.set_credentials(creds) class Channel(object): """Adapter from old _low.Channel interface to new _low.Channel.""" - def __init__(self, hostport, client_credentials, server_host_override=None): + def __init__(self, hostport, channel_credentials, server_host_override=None): args = [] if server_host_override: args.append((_types.GrpcChannelArgumentKeys.SSL_TARGET_NAME_OVERRIDE.value, server_host_override)) - creds = None - if client_credentials: - creds = client_credentials._internal - self._internal = _low.Channel(hostport, args, creds) + self._internal = _low.Channel(hostport, args, channel_credentials) class CompletionQueue(object): @@ -245,7 +242,7 @@ class Server(object): if server_credentials is None: return self._internal.add_http2_port(addr, None) else: - return self._internal.add_http2_port(addr, server_credentials._internal) + return self._internal.add_http2_port(addr, server_credentials) def start(self): return self._internal.start() @@ -259,17 +256,3 @@ class Server(object): def stop(self): return self._internal.shutdown(_TagAdapter(None, Event.Kind.STOP)) - -class ClientCredentials(object): - """Adapter from old _low.ClientCredentials interface to new _low.ChannelCredentials.""" - - def __init__(self, root_certificates, private_key, certificate_chain): - self._internal = _low.channel_credentials_ssl(root_certificates, private_key, certificate_chain) - - -class ServerCredentials(object): - """Adapter from old _low.ServerCredentials interface to new _low.ServerCredentials.""" - - def __init__(self, root_credentials, pair_sequence, force_client_auth): - self._internal = _low.server_credentials_ssl( - root_credentials, pair_sequence, force_client_auth) diff --git a/src/python/grpcio/grpc/_adapter/_low.py b/src/python/grpcio/grpc/_adapter/_low.py index 264c33b484..b13d8dd9dd 100644 --- a/src/python/grpcio/grpc/_adapter/_low.py +++ b/src/python/grpcio/grpc/_adapter/_low.py @@ -27,8 +27,11 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import threading + from grpc import _grpcio_metadata from grpc._cython import cygrpc +from grpc._adapter import _implementations from grpc._adapter import _types _USER_AGENT = 'Python-gRPC-{}'.format(_grpcio_metadata.__version__) @@ -37,6 +40,9 @@ ChannelCredentials = cygrpc.ChannelCredentials CallCredentials = cygrpc.CallCredentials ServerCredentials = cygrpc.ServerCredentials +channel_credentials_composite = cygrpc.channel_credentials_composite +call_credentials_composite = cygrpc.call_credentials_composite + def server_credentials_ssl(root_credentials, pair_sequence, force_client_auth): return cygrpc.server_credentials_ssl( root_credentials, @@ -51,6 +57,80 @@ def channel_credentials_ssl( return cygrpc.channel_credentials_ssl(root_certificates, pair) +class _WrappedCygrpcCallback(object): + + def __init__(self, cygrpc_callback): + self.is_called = False + self.error = None + self.is_called_lock = threading.Lock() + self.cygrpc_callback = cygrpc_callback + + def _invoke_failure(self, error): + # TODO(atash) translate different Exception superclasses into different + # status codes. + self.cygrpc_callback( + cygrpc.Metadata([]), cygrpc.StatusCode.internal, error.message) + + def _invoke_success(self, metadata): + try: + cygrpc_metadata = cygrpc.Metadata( + cygrpc.Metadatum(key, value) + for key, value in metadata) + except Exception as error: + self._invoke_failure(error) + return + self.cygrpc_callback(cygrpc_metadata, cygrpc.StatusCode.ok, '') + + def __call__(self, metadata, error): + with self.is_called_lock: + if self.is_called: + raise RuntimeError('callback should only ever be invoked once') + if self.error: + self._invoke_failure(self.error) + return + self.is_called = True + if error is None: + self._invoke_success(metadata) + else: + self._invoke_failure(error) + + def notify_failure(self, error): + with self.is_called_lock: + if not self.is_called: + self.error = error + + +class _WrappedPlugin(object): + + def __init__(self, plugin): + self.plugin = plugin + + def __call__(self, context, cygrpc_callback): + wrapped_cygrpc_callback = _WrappedCygrpcCallback(cygrpc_callback) + wrapped_context = _implementations.AuthMetadataContext(context.service_url, + context.method_name) + try: + self.plugin( + wrapped_context, + _implementations.AuthMetadataPluginCallback(wrapped_cygrpc_callback)) + except Exception as error: + wrapped_cygrpc_callback.notify_failure(error) + raise + + +def call_credentials_metadata_plugin(plugin, name): + """ + Args: + plugin: A callable accepting a _types.AuthMetadataContext + object and a callback (itself accepting a list of metadata key/value + 2-tuples and a None-able exception value). The callback must be eventually + called, but need not be called in plugin's invocation. + plugin's invocation must be non-blocking. + """ + return cygrpc.call_credentials_metadata_plugin( + cygrpc.CredentialsMetadataPlugin(_WrappedPlugin(plugin), name)) + + class CompletionQueue(_types.CompletionQueue): def __init__(self): diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd index 7a9fa7b76d..db9f8ddec9 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd +++ b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd @@ -27,7 +27,10 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +cimport cpython + from grpc._cython._cygrpc cimport grpc +from grpc._cython._cygrpc cimport records cdef class ChannelCredentials: @@ -49,3 +52,23 @@ cdef class ServerCredentials: cdef grpc.grpc_ssl_pem_key_cert_pair *c_ssl_pem_key_cert_pairs cdef size_t c_ssl_pem_key_cert_pairs_count cdef list references + + +cdef class CredentialsMetadataPlugin: + + cdef object plugin_callback + cdef str plugin_name + + cdef grpc.grpc_metadata_credentials_plugin make_c_plugin(self) + + +cdef class AuthMetadataContext: + + cdef grpc.grpc_auth_metadata_context context + + +cdef void plugin_get_metadata( + void *state, grpc.grpc_auth_metadata_context context, + grpc.grpc_credentials_plugin_metadata_cb cb, void *user_data) with gil + +cdef void plugin_destroy_c_plugin_state(void *state) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx index e6a22e7625..a968894967 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx +++ b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx @@ -27,6 +27,8 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +cimport cpython + from grpc._cython._cygrpc cimport grpc from grpc._cython._cygrpc cimport records @@ -78,6 +80,66 @@ cdef class ServerCredentials: grpc.grpc_server_credentials_release(self.c_credentials) +cdef class CredentialsMetadataPlugin: + + def __cinit__(self, object plugin_callback, str name): + """ + Args: + plugin_callback (callable): Callback accepting a service URL (str/bytes) + and callback object (accepting a records.Metadata, + grpc.grpc_status_code, and a str/bytes error message). This argument + when called should be non-blocking and eventually call the callback + object with the appropriate status code/details and metadata (if + successful). + name (str): Plugin name. + """ + if not callable(plugin_callback): + raise ValueError('expected callable plugin_callback') + self.plugin_callback = plugin_callback + self.plugin_name = name + + @staticmethod + cdef grpc.grpc_metadata_credentials_plugin make_c_plugin(self): + cdef grpc.grpc_metadata_credentials_plugin result + result.get_metadata = plugin_get_metadata + result.destroy = plugin_destroy_c_plugin_state + result.state = <void *>self + result.type = self.plugin_name + cpython.Py_INCREF(self) + return result + + +cdef class AuthMetadataContext: + + def __cinit__(self): + self.context.service_url = NULL + self.context.method_name = NULL + + @property + def service_url(self): + return self.context.service_url + + @property + def method_name(self): + return self.context.method_name + + +cdef void plugin_get_metadata( + void *state, grpc.grpc_auth_metadata_context context, + grpc.grpc_credentials_plugin_metadata_cb cb, void *user_data) with gil: + def python_callback( + records.Metadata metadata, grpc.grpc_status_code status, + const char *error_details): + cb(user_data, metadata.c_metadata_array.metadata, + metadata.c_metadata_array.count, status, error_details) + cdef CredentialsMetadataPlugin self = <CredentialsMetadataPlugin>state + cdef AuthMetadataContext cy_context = AuthMetadataContext() + cy_context.context = context + self.plugin_callback(cy_context, python_callback) + +cdef void plugin_destroy_c_plugin_state(void *state): + cpython.Py_DECREF(<CredentialsMetadataPlugin>state) + def channel_credentials_google_default(): cdef ChannelCredentials credentials = ChannelCredentials(); credentials.c_credentials = grpc.grpc_google_default_credentials_create() @@ -185,6 +247,15 @@ def call_credentials_google_iam(authorization_token, authority_selector): credentials.references.append(authority_selector) return credentials +def call_credentials_metadata_plugin(CredentialsMetadataPlugin plugin): + cdef CallCredentials credentials = CallCredentials() + credentials.c_credentials = ( + grpc.grpc_metadata_credentials_create_from_plugin(plugin.make_c_plugin(), + NULL)) + # TODO(atash): the following held reference is *probably* never necessary + credentials.references.append(plugin) + return credentials + def server_credentials_ssl(pem_root_certs, pem_key_cert_pairs, bint force_client_auth): cdef char *c_pem_root_certs = NULL diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxd b/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxd index 054ac7796a..10c948cd0a 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxd +++ b/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxd @@ -137,8 +137,6 @@ cdef extern from "grpc/grpc.h": const char *GRPC_ARG_MAX_CONCURRENT_STREAMS const char *GRPC_ARG_MAX_MESSAGE_LENGTH const char *GRPC_ARG_HTTP2_INITIAL_SEQUENCE_NUMBER - const char *GRPC_ARG_HTTP2_HPACK_TABLE_SIZE_DECODER - const char *GRPC_ARG_HTTP2_HPACK_TABLE_SIZE_ENCODER const char *GRPC_ARG_DEFAULT_AUTHORITY const char *GRPC_ARG_PRIMARY_USER_AGENT_STRING const char *GRPC_ARG_SECONDARY_USER_AGENT_STRING @@ -396,3 +394,27 @@ cdef extern from "grpc/grpc_security.h": grpc_call_error grpc_call_set_credentials(grpc_call *call, grpc_call_credentials *creds) + + ctypedef struct grpc_auth_context: + # We don't care about the internals (and in fact don't know them) + pass + + ctypedef struct grpc_auth_metadata_context: + const char *service_url + const char *method_name + const grpc_auth_context *channel_auth_context + + ctypedef void (*grpc_credentials_plugin_metadata_cb)( + void *user_data, const grpc_metadata *creds_md, size_t num_creds_md, + grpc_status_code status, const char *error_details) + + ctypedef struct grpc_metadata_credentials_plugin: + void (*get_metadata)( + void *state, grpc_auth_metadata_context context, + grpc_credentials_plugin_metadata_cb cb, void *user_data) + void (*destroy)(void *state) + void *state + const char *type + + grpc_call_credentials *grpc_metadata_credentials_create_from_plugin( + grpc_metadata_credentials_plugin plugin, void *reserved) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx b/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx index be89db8846..79a7f8f563 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx +++ b/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx @@ -45,8 +45,6 @@ class ChannelArgKey: max_concurrent_streams = grpc.GRPC_ARG_MAX_CONCURRENT_STREAMS max_message_length = grpc.GRPC_ARG_MAX_MESSAGE_LENGTH http2_initial_sequence_number = grpc.GRPC_ARG_HTTP2_INITIAL_SEQUENCE_NUMBER - http2_hpack_table_size_decoder = grpc.GRPC_ARG_HTTP2_HPACK_TABLE_SIZE_DECODER - http2_hpack_table_size_encoder = grpc.GRPC_ARG_HTTP2_HPACK_TABLE_SIZE_ENCODER default_authority = grpc.GRPC_ARG_DEFAULT_AUTHORITY primary_user_agent_string = grpc.GRPC_ARG_PRIMARY_USER_AGENT_STRING secondary_user_agent_string = grpc.GRPC_ARG_SECONDARY_USER_AGENT_STRING diff --git a/src/python/grpcio/grpc/_cython/cygrpc.pyx b/src/python/grpcio/grpc/_cython/cygrpc.pyx index 635bf1918a..16ec12dac0 100644 --- a/src/python/grpcio/grpc/_cython/cygrpc.pyx +++ b/src/python/grpcio/grpc/_cython/cygrpc.pyx @@ -76,6 +76,8 @@ Operations = records.Operations CallCredentials = credentials.CallCredentials ChannelCredentials = credentials.ChannelCredentials ServerCredentials = credentials.ServerCredentials +CredentialsMetadataPlugin = credentials.CredentialsMetadataPlugin +AuthMetadataContext = credentials.AuthMetadataContext channel_credentials_google_default = ( credentials.channel_credentials_google_default) @@ -91,6 +93,7 @@ call_credentials_jwt_access = ( call_credentials_refresh_token = ( credentials.call_credentials_google_refresh_token) call_credentials_google_iam = credentials.call_credentials_google_iam +call_credentials_metadata_plugin = credentials.call_credentials_metadata_plugin server_credentials_ssl = credentials.server_credentials_ssl CompletionQueue = completion_queue.CompletionQueue diff --git a/src/python/grpcio/grpc/_links/invocation.py b/src/python/grpcio/grpc/_links/invocation.py index 23ce12a787..5ca0a0ee60 100644 --- a/src/python/grpcio/grpc/_links/invocation.py +++ b/src/python/grpcio/grpc/_links/invocation.py @@ -262,7 +262,7 @@ class _Kernel(object): self._channel, self._completion_queue, '/%s/%s' % (group, method), self._host, time.time() + timeout) if options is not None and options.credentials is not None: - call.set_credentials(options.credentials._intermediary_low_credentials) + call.set_credentials(options.credentials._low_credentials) if transformed_initial_metadata is not None: for metadata_key, metadata_value in transformed_initial_metadata: call.add_metadata(metadata_key, metadata_value) diff --git a/src/python/grpcio/grpc/beta/_server.py b/src/python/grpcio/grpc/beta/_server.py index 4f454437c0..2b520cc7e5 100644 --- a/src/python/grpcio/grpc/beta/_server.py +++ b/src/python/grpcio/grpc/beta/_server.py @@ -170,7 +170,7 @@ class _Server(interfaces.Server): with self._lock: if self._end_link is None: return self._grpc_link.add_port( - address, server_credentials._intermediary_low_credentials) # pylint: disable=protected-access + address, server_credentials._low_credentials) # pylint: disable=protected-access else: raise ValueError('Can\'t add port to serving server!') diff --git a/src/python/grpcio/grpc/beta/implementations.py b/src/python/grpcio/grpc/beta/implementations.py index c9d64ad35a..a0ca330d2c 100644 --- a/src/python/grpcio/grpc/beta/implementations.py +++ b/src/python/grpcio/grpc/beta/implementations.py @@ -36,6 +36,7 @@ import threading # pylint: disable=unused-import # cardinality and face are referenced from specification in this module. from grpc._adapter import _intermediary_low +from grpc._adapter import _low from grpc._adapter import _types from grpc.beta import _connectivity_channel from grpc.beta import _server @@ -48,7 +49,7 @@ _CHANNEL_SUBSCRIPTION_CALLBACK_ERROR_LOG_MESSAGE = ( 'Exception calling channel subscription callback!') -class ClientCredentials(object): +class ChannelCredentials(object): """A value encapsulating the data required to create a secure Channel. This class and its instances have no supported interface - it exists to define @@ -56,13 +57,12 @@ class ClientCredentials(object): functions. """ - def __init__(self, low_credentials, intermediary_low_credentials): + def __init__(self, low_credentials): self._low_credentials = low_credentials - self._intermediary_low_credentials = intermediary_low_credentials -def ssl_client_credentials(root_certificates, private_key, certificate_chain): - """Creates a ClientCredentials for use with an SSL-enabled Channel. +def ssl_channel_credentials(root_certificates, private_key, certificate_chain): + """Creates a ChannelCredentials for use with an SSL-enabled Channel. Args: root_certificates: The PEM-encoded root certificates or None to ask for @@ -73,12 +73,73 @@ def ssl_client_credentials(root_certificates, private_key, certificate_chain): certificate chain should be used. Returns: - A ClientCredentials for use with an SSL-enabled Channel. + A ChannelCredentials for use with an SSL-enabled Channel. """ - intermediary_low_credentials = _intermediary_low.ClientCredentials( - root_certificates, private_key, certificate_chain) - return ClientCredentials( - intermediary_low_credentials._internal, intermediary_low_credentials) # pylint: disable=protected-access + return ChannelCredentials(_low.channel_credentials_ssl( + root_certificates, private_key, certificate_chain)) + + +class CallCredentials(object): + """A value encapsulating data asserting an identity over an *established* + channel. May be composed with ChannelCredentials to always assert identity for + every call over that channel. + + This class and its instances have no supported interface - it exists to define + the type of its instances and its instances exist to be passed to other + functions. + """ + + def __init__(self, low_credentials): + self._low_credentials = low_credentials + + +def metadata_call_credentials(metadata_plugin, name=None): + """Construct CallCredentials from an interfaces.GRPCAuthMetadataPlugin. + + Args: + metadata_plugin: An interfaces.GRPCAuthMetadataPlugin to use in constructing + the CallCredentials object. + + Returns: + A CallCredentials object for use in a GRPCCallOptions object. + """ + if name is None: + name = metadata_plugin.__name__ + return CallCredentials( + _low.call_credentials_metadata_plugin(metadata_plugin, name)) + +def composite_call_credentials(call_credentials, additional_call_credentials): + """Compose two CallCredentials to make a new one. + + Args: + call_credentials: A CallCredentials object. + additional_call_credentials: Another CallCredentials object to compose on + top of call_credentials. + + Returns: + A CallCredentials object for use in a GRPCCallOptions object. + """ + return CallCredentials( + _low.call_credentials_composite( + call_credentials._low_credentials, + additional_call_credentials._low_credentials)) + +def composite_channel_credentials(channel_credentials, + additional_call_credentials): + """Compose ChannelCredentials on top of client credentials to make a new one. + + Args: + channel_credentials: A ChannelCredentials object. + additional_call_credentials: A CallCredentials object to compose on + top of channel_credentials. + + Returns: + A ChannelCredentials object for use in a GRPCCallOptions object. + """ + return ChannelCredentials( + _low.channel_credentials_composite( + channel_credentials._low_credentials, + additional_call_credentials._low_credentials)) class Channel(object): @@ -135,19 +196,19 @@ def insecure_channel(host, port): return Channel(intermediary_low_channel._internal, intermediary_low_channel) # pylint: disable=protected-access -def secure_channel(host, port, client_credentials): +def secure_channel(host, port, channel_credentials): """Creates a secure Channel to a remote host. Args: host: The name of the remote host to which to connect. port: The port of the remote host to which to connect. - client_credentials: A ClientCredentials. + channel_credentials: A ChannelCredentials. Returns: A secure Channel to the remote host through which RPCs may be conducted. """ intermediary_low_channel = _intermediary_low.Channel( - '%s:%d' % (host, port), client_credentials._intermediary_low_credentials) + '%s:%d' % (host, port), channel_credentials._low_credentials) return Channel(intermediary_low_channel._internal, intermediary_low_channel) # pylint: disable=protected-access @@ -251,9 +312,8 @@ class ServerCredentials(object): functions. """ - def __init__(self, low_credentials, intermediary_low_credentials): + def __init__(self, low_credentials): self._low_credentials = low_credentials - self._intermediary_low_credentials = intermediary_low_credentials def ssl_server_credentials( @@ -282,11 +342,9 @@ def ssl_server_credentials( raise ValueError( 'Illegal to require client auth without providing root certificates!') else: - intermediary_low_credentials = _intermediary_low.ServerCredentials( + return ServerCredentials(_low.server_credentials_ssl( root_certificates, private_key_certificate_chain_pairs, - require_client_auth) - return ServerCredentials( - intermediary_low_credentials._internal, intermediary_low_credentials) # pylint: disable=protected-access + require_client_auth)) class ServerOptions(object): diff --git a/src/python/grpcio/grpc/beta/interfaces.py b/src/python/grpcio/grpc/beta/interfaces.py index d4ca56500f..0663119163 100644 --- a/src/python/grpcio/grpc/beta/interfaces.py +++ b/src/python/grpcio/grpc/beta/interfaces.py @@ -100,14 +100,55 @@ def grpc_call_options(disable_compression=False, credentials=None): disable_compression: A boolean indicating whether or not compression should be disabled for the request object of the RPC. Only valid for request-unary RPCs. - credentials: Reserved for gRPC per-call credentials. The type for this does - not exist yet at the Python level. + credentials: A CallCredentials object to use for the invoked RPC. """ - if credentials is not None: - raise ValueError('`credentials` is a reserved argument') return GRPCCallOptions(disable_compression, None, credentials) +class GRPCAuthMetadataContext(object): + """Provides information to call credentials metadata plugins. + + Attributes: + service_url: A string URL of the service being called into. + method_name: A string of the fully qualified method name being called. + """ + __metaclass__ = abc.ABCMeta + + +class GRPCAuthMetadataPluginCallback(object): + """Callback object received by a metadata plugin.""" + __metaclass__ = abc.ABCMeta + + def __call__(self, metadata, error): + """Inform the gRPC runtime of the metadata to construct a CallCredentials. + + Args: + metadata: An iterable of 2-sequences (e.g. tuples) of metadata key/value + pairs. + error: An Exception to indicate error or None to indicate success. + """ + raise NotImplementedError() + + +class GRPCAuthMetadataPlugin(object): + """ + """ + __metaclass__ = abc.ABCMeta + + def __call__(self, context, callback): + """Invoke the plugin. + + Must not block. Need only be called by the gRPC runtime. + + Args: + context: A GRPCAuthMetadataContext providing information on what the + plugin is being used for. + callback: A GRPCAuthMetadataPluginCallback to be invoked either + synchronously or asynchronously. + """ + raise NotImplementedError() + + class GRPCServicerContext(object): """Exposes gRPC-specific options and behaviors to code servicing RPCs.""" __metaclass__ = abc.ABCMeta diff --git a/src/python/grpcio/tests/interop/_secure_interop_test.py b/src/python/grpcio/tests/interop/_secure_interop_test.py index a0fef1fc20..7e3061133f 100644 --- a/src/python/grpcio/tests/interop/_secure_interop_test.py +++ b/src/python/grpcio/tests/interop/_secure_interop_test.py @@ -55,7 +55,7 @@ class SecureInteropTest( self.server.start() self.stub = test_pb2.beta_create_TestService_stub( test_utilities.not_really_secure_channel( - '[::]', port, implementations.ssl_client_credentials( + '[::]', port, implementations.ssl_channel_credentials( resources.test_root_certificates(), None, None), _SERVER_HOST_OVERRIDE)) diff --git a/src/python/grpcio/tests/interop/client.py b/src/python/grpcio/tests/interop/client.py index 9449ff5429..5c00bce014 100644 --- a/src/python/grpcio/tests/interop/client.py +++ b/src/python/grpcio/tests/interop/client.py @@ -94,7 +94,7 @@ def _stub(args): channel = test_utilities.not_really_secure_channel( args.server_host, args.server_port, - implementations.ssl_client_credentials(root_certificates, None, None), + implementations.ssl_channel_credentials(root_certificates, None, None), args.server_host_override) stub = test_pb2.beta_create_TestService_stub( channel, metadata_transformer=metadata_transformer) diff --git a/src/python/grpcio/tests/unit/_cython/cygrpc_test.py b/src/python/grpcio/tests/unit/_cython/cygrpc_test.py index 13ecae817c..876da88de9 100644 --- a/src/python/grpcio/tests/unit/_cython/cygrpc_test.py +++ b/src/python/grpcio/tests/unit/_cython/cygrpc_test.py @@ -28,11 +28,24 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import time +import threading import unittest from grpc._cython import cygrpc from tests.unit._cython import test_utilities from tests.unit import test_common +from tests.unit import resources + + +_SSL_HOST_OVERRIDE = 'foo.test.google.fr' +_CALL_CREDENTIALS_METADATA_KEY = 'call-creds-key' +_CALL_CREDENTIALS_METADATA_VALUE = 'call-creds-value' + +def _metadata_plugin_callback(context, callback): + callback(cygrpc.Metadata( + [cygrpc.Metadatum(_CALL_CREDENTIALS_METADATA_KEY, + _CALL_CREDENTIALS_METADATA_VALUE)]), + cygrpc.StatusCode.ok, '') class TypeSmokeTest(unittest.TestCase): @@ -89,6 +102,17 @@ class TypeSmokeTest(unittest.TestCase): channel = cygrpc.Channel('[::]:0', cygrpc.ChannelArgs([])) del channel + def testCredentialsMetadataPluginUpDown(self): + plugin = cygrpc.CredentialsMetadataPlugin( + lambda ignored_a, ignored_b: None, '') + del plugin + + def testCallCredentialsFromPluginUpDown(self): + plugin = cygrpc.CredentialsMetadataPlugin(_metadata_plugin_callback, '') + call_credentials = cygrpc.call_credentials_metadata_plugin(plugin) + del plugin + del call_credentials + def testServerStartNoExplicitShutdown(self): server = cygrpc.Server() completion_queue = cygrpc.CompletionQueue() @@ -260,5 +284,169 @@ class InsecureServerInsecureClient(unittest.TestCase): del server_call +class SecureServerSecureClient(unittest.TestCase): + + def setUp(self): + server_credentials = cygrpc.server_credentials_ssl( + None, [cygrpc.SslPemKeyCertPair(resources.private_key(), + resources.certificate_chain())], False) + channel_credentials = cygrpc.channel_credentials_ssl( + resources.test_root_certificates(), None) + self.server_completion_queue = cygrpc.CompletionQueue() + self.server = cygrpc.Server() + self.server.register_completion_queue(self.server_completion_queue) + self.port = self.server.add_http2_port('[::]:0', server_credentials) + self.server.start() + self.client_completion_queue = cygrpc.CompletionQueue() + client_channel_arguments = cygrpc.ChannelArgs([ + cygrpc.ChannelArg(cygrpc.ChannelArgKey.ssl_target_name_override, + _SSL_HOST_OVERRIDE)]) + self.client_channel = cygrpc.Channel( + 'localhost:{}'.format(self.port), client_channel_arguments, + channel_credentials) + + def tearDown(self): + del self.server + del self.client_completion_queue + del self.server_completion_queue + + def testEcho(self): + DEADLINE = time.time()+5 + DEADLINE_TOLERANCE = 0.25 + CLIENT_METADATA_ASCII_KEY = b'key' + CLIENT_METADATA_ASCII_VALUE = b'val' + CLIENT_METADATA_BIN_KEY = b'key-bin' + CLIENT_METADATA_BIN_VALUE = b'\0'*1000 + SERVER_INITIAL_METADATA_KEY = b'init_me_me_me' + SERVER_INITIAL_METADATA_VALUE = b'whodawha?' + SERVER_TRAILING_METADATA_KEY = b'california_is_in_a_drought' + SERVER_TRAILING_METADATA_VALUE = b'zomg it is' + SERVER_STATUS_CODE = cygrpc.StatusCode.ok + SERVER_STATUS_DETAILS = b'our work is never over' + REQUEST = b'in death a member of project mayhem has a name' + RESPONSE = b'his name is robert paulson' + METHOD = b'/twinkies' + HOST = None # Default host + + cygrpc_deadline = cygrpc.Timespec(DEADLINE) + + server_request_tag = object() + request_call_result = self.server.request_call( + self.server_completion_queue, self.server_completion_queue, + server_request_tag) + + self.assertEqual(cygrpc.CallError.ok, request_call_result) + + plugin = cygrpc.CredentialsMetadataPlugin(_metadata_plugin_callback, '') + call_credentials = cygrpc.call_credentials_metadata_plugin(plugin) + + client_call_tag = object() + client_call = self.client_channel.create_call( + None, 0, self.client_completion_queue, METHOD, HOST, cygrpc_deadline) + client_call.set_credentials(call_credentials) + client_initial_metadata = cygrpc.Metadata([ + cygrpc.Metadatum(CLIENT_METADATA_ASCII_KEY, + CLIENT_METADATA_ASCII_VALUE), + cygrpc.Metadatum(CLIENT_METADATA_BIN_KEY, CLIENT_METADATA_BIN_VALUE)]) + client_start_batch_result = client_call.start_batch(cygrpc.Operations([ + cygrpc.operation_send_initial_metadata(client_initial_metadata), + cygrpc.operation_send_message(REQUEST), + cygrpc.operation_send_close_from_client(), + cygrpc.operation_receive_initial_metadata(), + cygrpc.operation_receive_message(), + cygrpc.operation_receive_status_on_client() + ]), client_call_tag) + self.assertEqual(cygrpc.CallError.ok, client_start_batch_result) + client_event_future = test_utilities.CompletionQueuePollFuture( + self.client_completion_queue, cygrpc_deadline) + + request_event = self.server_completion_queue.poll(cygrpc_deadline) + self.assertEqual(cygrpc.CompletionType.operation_complete, + request_event.type) + self.assertIsInstance(request_event.operation_call, cygrpc.Call) + self.assertIs(server_request_tag, request_event.tag) + self.assertEqual(0, len(request_event.batch_operations)) + client_metadata_with_credentials = list(client_initial_metadata) + [ + (_CALL_CREDENTIALS_METADATA_KEY, _CALL_CREDENTIALS_METADATA_VALUE)] + self.assertTrue( + test_common.metadata_transmitted(client_metadata_with_credentials, + request_event.request_metadata)) + self.assertEqual(METHOD, request_event.request_call_details.method) + self.assertEqual(_SSL_HOST_OVERRIDE, + request_event.request_call_details.host) + self.assertLess( + abs(DEADLINE - float(request_event.request_call_details.deadline)), + DEADLINE_TOLERANCE) + + server_call_tag = object() + server_call = request_event.operation_call + server_initial_metadata = cygrpc.Metadata([ + cygrpc.Metadatum(SERVER_INITIAL_METADATA_KEY, + SERVER_INITIAL_METADATA_VALUE)]) + server_trailing_metadata = cygrpc.Metadata([ + cygrpc.Metadatum(SERVER_TRAILING_METADATA_KEY, + SERVER_TRAILING_METADATA_VALUE)]) + server_start_batch_result = server_call.start_batch([ + cygrpc.operation_send_initial_metadata(server_initial_metadata), + cygrpc.operation_receive_message(), + cygrpc.operation_send_message(RESPONSE), + cygrpc.operation_receive_close_on_server(), + cygrpc.operation_send_status_from_server( + server_trailing_metadata, SERVER_STATUS_CODE, SERVER_STATUS_DETAILS) + ], server_call_tag) + self.assertEqual(cygrpc.CallError.ok, server_start_batch_result) + + client_event = client_event_future.result() + server_event = self.server_completion_queue.poll(cygrpc_deadline) + + self.assertEqual(6, len(client_event.batch_operations)) + found_client_op_types = set() + for client_result in client_event.batch_operations: + # we expect each op type to be unique + self.assertNotIn(client_result.type, found_client_op_types) + found_client_op_types.add(client_result.type) + if client_result.type == cygrpc.OperationType.receive_initial_metadata: + self.assertTrue( + test_common.metadata_transmitted(server_initial_metadata, + client_result.received_metadata)) + elif client_result.type == cygrpc.OperationType.receive_message: + self.assertEqual(RESPONSE, client_result.received_message.bytes()) + elif client_result.type == cygrpc.OperationType.receive_status_on_client: + self.assertTrue( + test_common.metadata_transmitted(server_trailing_metadata, + client_result.received_metadata)) + self.assertEqual(SERVER_STATUS_DETAILS, + client_result.received_status_details) + self.assertEqual(SERVER_STATUS_CODE, client_result.received_status_code) + self.assertEqual(set([ + cygrpc.OperationType.send_initial_metadata, + cygrpc.OperationType.send_message, + cygrpc.OperationType.send_close_from_client, + cygrpc.OperationType.receive_initial_metadata, + cygrpc.OperationType.receive_message, + cygrpc.OperationType.receive_status_on_client + ]), found_client_op_types) + + self.assertEqual(5, len(server_event.batch_operations)) + found_server_op_types = set() + for server_result in server_event.batch_operations: + self.assertNotIn(client_result.type, found_server_op_types) + found_server_op_types.add(server_result.type) + if server_result.type == cygrpc.OperationType.receive_message: + self.assertEqual(REQUEST, server_result.received_message.bytes()) + elif server_result.type == cygrpc.OperationType.receive_close_on_server: + self.assertFalse(server_result.received_cancelled) + self.assertEqual(set([ + cygrpc.OperationType.send_initial_metadata, + cygrpc.OperationType.receive_message, + cygrpc.OperationType.send_message, + cygrpc.OperationType.receive_close_on_server, + cygrpc.OperationType.send_status_from_server + ]), found_server_op_types) + + del client_call + del server_call + + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/src/python/grpcio/tests/unit/beta/_beta_features_test.py b/src/python/grpcio/tests/unit/beta/_beta_features_test.py index 5a7492ee9e..ea44177b49 100644 --- a/src/python/grpcio/tests/unit/beta/_beta_features_test.py +++ b/src/python/grpcio/tests/unit/beta/_beta_features_test.py @@ -42,6 +42,9 @@ from tests.unit.framework.common import test_constants _SERVER_HOST_OVERRIDE = 'foo.test.google.fr' +_PER_RPC_CREDENTIALS_METADATA_KEY = 'my-call-credentials-metadata-key' +_PER_RPC_CREDENTIALS_METADATA_VALUE = 'my-call-credentials-metadata-value' + _GROUP = 'group' _UNARY_UNARY = 'unary-unary' _UNARY_STREAM = 'unary-stream' @@ -63,6 +66,7 @@ class _Servicer(object): with self._condition: self._request = request self._peer = context.protocol_context().peer() + self._invocation_metadata = context.invocation_metadata() context.protocol_context().disable_next_response_compression() self._serviced = True self._condition.notify_all() @@ -72,6 +76,7 @@ class _Servicer(object): with self._condition: self._request = request self._peer = context.protocol_context().peer() + self._invocation_metadata = context.invocation_metadata() context.protocol_context().disable_next_response_compression() self._serviced = True self._condition.notify_all() @@ -83,6 +88,7 @@ class _Servicer(object): self._request = request with self._condition: self._peer = context.protocol_context().peer() + self._invocation_metadata = context.invocation_metadata() context.protocol_context().disable_next_response_compression() self._serviced = True self._condition.notify_all() @@ -95,6 +101,7 @@ class _Servicer(object): context.protocol_context().disable_next_response_compression() yield _RESPONSE with self._condition: + self._invocation_metadata = context.invocation_metadata() self._serviced = True self._condition.notify_all() @@ -137,6 +144,11 @@ class _BlockingIterator(object): self._condition.notify_all() +def _metadata_plugin(context, callback): + callback([(_PER_RPC_CREDENTIALS_METADATA_KEY, + _PER_RPC_CREDENTIALS_METADATA_VALUE)], None) + + class BetaFeaturesTest(unittest.TestCase): def setUp(self): @@ -167,10 +179,12 @@ class BetaFeaturesTest(unittest.TestCase): [(resources.private_key(), resources.certificate_chain(),),]) port = self._server.add_secure_port('[::]:0', server_credentials) self._server.start() - self._client_credentials = implementations.ssl_client_credentials( + self._channel_credentials = implementations.ssl_channel_credentials( resources.test_root_certificates(), None, None) + self._call_credentials = implementations.metadata_call_credentials( + _metadata_plugin) channel = test_utilities.not_really_secure_channel( - 'localhost', port, self._client_credentials, _SERVER_HOST_OVERRIDE) + 'localhost', port, self._channel_credentials, _SERVER_HOST_OVERRIDE) stub_options = implementations.stub_options( thread_pool_size=test_constants.POOL_SIZE) self._dynamic_stub = implementations.dynamic_stub( @@ -181,21 +195,36 @@ class BetaFeaturesTest(unittest.TestCase): self._server.stop(test_constants.SHORT_TIMEOUT).wait() def test_unary_unary(self): - call_options = interfaces.grpc_call_options(disable_compression=True) + call_options = interfaces.grpc_call_options( + disable_compression=True, credentials=self._call_credentials) response = getattr(self._dynamic_stub, _UNARY_UNARY)( _REQUEST, test_constants.LONG_TIMEOUT, protocol_options=call_options) self.assertEqual(_RESPONSE, response) self.assertIsNotNone(self._servicer.peer()) + invocation_metadata = [(metadatum.key, metadatum.value) for metadatum in + self._servicer._invocation_metadata] + self.assertIn( + (_PER_RPC_CREDENTIALS_METADATA_KEY, + _PER_RPC_CREDENTIALS_METADATA_VALUE), + invocation_metadata) def test_unary_stream(self): - call_options = interfaces.grpc_call_options(disable_compression=True) + call_options = interfaces.grpc_call_options( + disable_compression=True, credentials=self._call_credentials) response_iterator = getattr(self._dynamic_stub, _UNARY_STREAM)( _REQUEST, test_constants.LONG_TIMEOUT, protocol_options=call_options) self._servicer.block_until_serviced() self.assertIsNotNone(self._servicer.peer()) + invocation_metadata = [(metadatum.key, metadatum.value) for metadatum in + self._servicer._invocation_metadata] + self.assertIn( + (_PER_RPC_CREDENTIALS_METADATA_KEY, + _PER_RPC_CREDENTIALS_METADATA_VALUE), + invocation_metadata) def test_stream_unary(self): - call_options = interfaces.grpc_call_options() + call_options = interfaces.grpc_call_options( + credentials=self._call_credentials) request_iterator = _BlockingIterator(iter((_REQUEST,))) response_future = getattr(self._dynamic_stub, _STREAM_UNARY).future( request_iterator, test_constants.LONG_TIMEOUT, @@ -207,9 +236,16 @@ class BetaFeaturesTest(unittest.TestCase): self._servicer.block_until_serviced() self.assertIsNotNone(self._servicer.peer()) self.assertEqual(_RESPONSE, response_future.result()) + invocation_metadata = [(metadatum.key, metadatum.value) for metadatum in + self._servicer._invocation_metadata] + self.assertIn( + (_PER_RPC_CREDENTIALS_METADATA_KEY, + _PER_RPC_CREDENTIALS_METADATA_VALUE), + invocation_metadata) def test_stream_stream(self): - call_options = interfaces.grpc_call_options() + call_options = interfaces.grpc_call_options( + credentials=self._call_credentials) request_iterator = _BlockingIterator(iter((_REQUEST,))) response_iterator = getattr(self._dynamic_stub, _STREAM_STREAM)( request_iterator, test_constants.SHORT_TIMEOUT, @@ -222,6 +258,12 @@ class BetaFeaturesTest(unittest.TestCase): self._servicer.block_until_serviced() self.assertIsNotNone(self._servicer.peer()) self.assertEqual(_RESPONSE, response) + invocation_metadata = [(metadatum.key, metadatum.value) for metadatum in + self._servicer._invocation_metadata] + self.assertIn( + (_PER_RPC_CREDENTIALS_METADATA_KEY, + _PER_RPC_CREDENTIALS_METADATA_VALUE), + invocation_metadata) class ContextManagementAndLifecycleTest(unittest.TestCase): @@ -250,7 +292,7 @@ class ContextManagementAndLifecycleTest(unittest.TestCase): thread_pool_size=test_constants.POOL_SIZE) self._server_credentials = implementations.ssl_server_credentials( [(resources.private_key(), resources.certificate_chain(),),]) - self._client_credentials = implementations.ssl_client_credentials( + self._channel_credentials = implementations.ssl_channel_credentials( resources.test_root_certificates(), None, None) self._stub_options = implementations.stub_options( thread_pool_size=test_constants.POOL_SIZE) @@ -262,7 +304,7 @@ class ContextManagementAndLifecycleTest(unittest.TestCase): server.start() channel = test_utilities.not_really_secure_channel( - 'localhost', port, self._client_credentials, _SERVER_HOST_OVERRIDE) + 'localhost', port, self._channel_credentials, _SERVER_HOST_OVERRIDE) dynamic_stub = implementations.dynamic_stub( channel, _GROUP, self._cardinalities, options=self._stub_options) for _ in range(100): diff --git a/src/python/grpcio/tests/unit/beta/_face_interface_test.py b/src/python/grpcio/tests/unit/beta/_face_interface_test.py index 55c0d20060..1c21dfd03d 100644 --- a/src/python/grpcio/tests/unit/beta/_face_interface_test.py +++ b/src/python/grpcio/tests/unit/beta/_face_interface_test.py @@ -91,10 +91,10 @@ class _Implementation(test_interfaces.Implementation): [(resources.private_key(), resources.certificate_chain(),),]) port = server.add_secure_port('[::]:0', server_credentials) server.start() - client_credentials = implementations.ssl_client_credentials( + channel_credentials = implementations.ssl_channel_credentials( resources.test_root_certificates(), None, None) channel = test_utilities.not_really_secure_channel( - 'localhost', port, client_credentials, _SERVER_HOST_OVERRIDE) + 'localhost', port, channel_credentials, _SERVER_HOST_OVERRIDE) stub_options = implementations.stub_options( request_serializers=serialization_behaviors.request_serializers, response_deserializers=serialization_behaviors.response_deserializers, diff --git a/src/python/grpcio/tests/unit/beta/test_utilities.py b/src/python/grpcio/tests/unit/beta/test_utilities.py index 24a8600e12..0313e06a93 100644 --- a/src/python/grpcio/tests/unit/beta/test_utilities.py +++ b/src/python/grpcio/tests/unit/beta/test_utilities.py @@ -34,13 +34,13 @@ from grpc.beta import implementations def not_really_secure_channel( - host, port, client_credentials, server_host_override): + host, port, channel_credentials, server_host_override): """Creates an insecure Channel to a remote host. Args: host: The name of the remote host to which to connect. port: The port of the remote host to which to connect. - client_credentials: The implementations.ClientCredentials with which to + channel_credentials: The implementations.ChannelCredentials with which to connect. server_host_override: The target name used for SSL host name checking. @@ -50,7 +50,7 @@ def not_really_secure_channel( """ hostport = '%s:%d' % (host, port) intermediary_low_channel = _intermediary_low.Channel( - hostport, client_credentials._intermediary_low_credentials, + hostport, channel_credentials._low_credentials, server_host_override=server_host_override) return implementations.Channel( intermediary_low_channel._internal, intermediary_low_channel) |