diff options
Diffstat (limited to 'src/python')
38 files changed, 405 insertions, 262 deletions
diff --git a/src/python/grpcio/grpc/__init__.py b/src/python/grpcio/grpc/__init__.py index b7ed0c8563..d1477fbeb5 100644 --- a/src/python/grpcio/grpc/__init__.py +++ b/src/python/grpcio/grpc/__init__.py @@ -1250,19 +1250,20 @@ class Server(six.with_metaclass(abc.ABCMeta)): """Stops this Server. This method immediately stop service of new RPCs in all cases. + If a grace period is specified, this method returns immediately and all RPCs active at the end of the grace period are aborted. - - If a grace period is not specified, then all existing RPCs are - teriminated immediately and the this method blocks until the last - RPC handler terminates. + If a grace period is not specified (by passing None for `grace`), + all existing RPCs are aborted immediately and this method + blocks until the last RPC handler terminates. This method is idempotent and may be called at any time. - Passing a smaller grace value in subsequent call will have - the effect of stopping the Server sooner. Passing a larger - grace value in subsequent call *will not* have the effect of - stopping the server later (i.e. the most restrictive grace - value is used). + Passing a smaller grace value in a subsequent call will have + the effect of stopping the Server sooner (passing None will + have the effect of stopping the server immediately). Passing + a larger grace value in a subsequent call *will not* have the + effect of stopping the server later (i.e. the most restrictive + grace value is used). Args: grace: A duration of time in seconds or None. @@ -1481,7 +1482,7 @@ def ssl_server_credentials(private_key_certificate_chain_pairs, A ServerCredentials for use with an SSL-enabled Server. Typically, this object is an argument to add_secure_port() method during server setup. """ - if len(private_key_certificate_chain_pairs) == 0: + if not private_key_certificate_chain_pairs: raise ValueError( 'At least one private key-certificate chain pair is required!') elif require_client_auth and root_certificates is None: @@ -1511,15 +1512,15 @@ def ssl_server_certificate_configuration(private_key_certificate_chain_pairs, A ServerCertificateConfiguration that can be returned in the certificate configuration fetching callback. """ - if len(private_key_certificate_chain_pairs) == 0: - raise ValueError( - 'At least one private key-certificate chain pair is required!') - else: + if private_key_certificate_chain_pairs: return ServerCertificateConfiguration( _cygrpc.server_certificate_config_ssl(root_certificates, [ _cygrpc.SslPemKeyCertPair(key, pem) for key, pem in private_key_certificate_chain_pairs ])) + else: + raise ValueError( + 'At least one private key-certificate chain pair is required!') def dynamic_ssl_server_credentials(initial_certificate_configuration, diff --git a/src/python/grpcio/grpc/_channel.py b/src/python/grpcio/grpc/_channel.py index 2017d47130..e9246991df 100644 --- a/src/python/grpcio/grpc/_channel.py +++ b/src/python/grpcio/grpc/_channel.py @@ -24,6 +24,8 @@ from grpc import _grpcio_metadata from grpc._cython import cygrpc from grpc.framework.foundation import callable_util +_LOGGER = logging.getLogger(__name__) + _USER_AGENT = 'grpc-python/{}'.format(_grpcio_metadata.__version__) _EMPTY_FLAGS = 0 @@ -181,7 +183,7 @@ def _consume_request_iterator(request_iterator, state, call, request_serializer, except Exception: # pylint: disable=broad-except code = grpc.StatusCode.UNKNOWN details = 'Exception iterating requests!' - logging.exception(details) + _LOGGER.exception(details) call.cancel(_common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[code], details) _abort(state, code, details) @@ -190,7 +192,7 @@ def _consume_request_iterator(request_iterator, state, call, request_serializer, with state.condition: if state.code is None and not state.cancelled: if serialized_request is None: - code = grpc.StatusCode.INTERNAL # pylint: disable=redefined-variable-type + code = grpc.StatusCode.INTERNAL details = 'Exception serializing request!' call.cancel( _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[code], @@ -811,10 +813,7 @@ def _poll_connectivity(state, channel, initial_try_to_connect): _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[ connectivity]) if not state.delivering: - # NOTE(nathaniel): The field is only ever used as a - # sequence so it's fine that both lists and tuples are - # assigned to it. - callbacks = _deliveries(state) # pylint: disable=redefined-variable-type + callbacks = _deliveries(state) if callbacks: _spawn_delivery(state, callbacks) diff --git a/src/python/grpcio/grpc/_common.py b/src/python/grpcio/grpc/_common.py index 862987a0cd..8358cbec5b 100644 --- a/src/python/grpcio/grpc/_common.py +++ b/src/python/grpcio/grpc/_common.py @@ -20,6 +20,8 @@ import six import grpc from grpc._cython import cygrpc +_LOGGER = logging.getLogger(__name__) + CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY = { cygrpc.ConnectivityState.idle: grpc.ChannelConnectivity.IDLE, @@ -73,7 +75,7 @@ def decode(b): try: return b.decode('utf8') except UnicodeDecodeError: - logging.exception('Invalid encoding on %s', b) + _LOGGER.exception('Invalid encoding on %s', b) return b.decode('latin1') @@ -84,7 +86,7 @@ def _transform(message, transformer, exception_message): try: return transformer(message) except Exception: # pylint: disable=broad-except - logging.exception(exception_message) + _LOGGER.exception(exception_message) return None diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/arguments.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/arguments.pyx.pxi index 65de30884c..aecd3d7b11 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/arguments.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/arguments.pyx.pxi @@ -15,10 +15,12 @@ cimport cpython +# TODO(https://github.com/grpc/grpc/issues/15662): Reform this. cdef void* _copy_pointer(void* pointer): return pointer +# TODO(https://github.com/grpc/grpc/issues/15662): Reform this. cdef void _destroy_pointer(void* pointer): pass diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi index eefc685c0b..f067d76fab 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi @@ -69,3 +69,6 @@ cdef class Channel: cdef grpc_arg_pointer_vtable _vtable cdef _ChannelState _state + + # TODO(https://github.com/grpc/grpc/issues/15662): Eliminate this. + cdef tuple _arguments diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi index 72e74e84ae..8c37a3cf85 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi @@ -390,6 +390,7 @@ cdef class Channel: def __cinit__( self, bytes target, object arguments, ChannelCredentials channel_credentials): + arguments = () if arguments is None else tuple(arguments) grpc_init() self._state = _ChannelState() self._vtable.copy = &_copy_pointer @@ -410,6 +411,7 @@ cdef class Channel: grpc_completion_queue_create_for_next(NULL)) self._state.c_connectivity_completion_queue = ( grpc_completion_queue_create_for_next(NULL)) + self._arguments = arguments def target(self): cdef char *c_target diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd.pxi index 7e9ea33ca0..8d73215247 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd.pxi @@ -57,6 +57,11 @@ cdef class ChannelCredentials: cdef grpc_channel_credentials *c_credentials +cdef class SSLSessionCacheLRU: + + cdef grpc_ssl_session_cache *_cache + + cdef class SSLChannelCredentials(ChannelCredentials): cdef readonly object _pem_root_certificates diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi index 500086f6cb..f4ccfbc016 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi @@ -17,6 +17,21 @@ cimport cpython import grpc import threading +from libc.stdint cimport uintptr_t + + +def _spawn_callback_in_thread(cb_func, args): + threading.Thread(target=cb_func, args=args).start() + +async_callback_func = _spawn_callback_in_thread + +def set_async_callback_func(callback_func): + global async_callback_func + async_callback_func = callback_func + +def _spawn_callback_async(callback, args): + async_callback_func(callback, args) + cdef class CallCredentials: @@ -40,7 +55,7 @@ cdef int _get_metadata( else: cb(user_data, NULL, 0, status, error_details) args = context.service_url, context.method_name, callback, - threading.Thread(target=<object>state, args=args).start() + _spawn_callback_async(<object>state, args) return 0 # Asynchronous return @@ -96,6 +111,21 @@ cdef class ChannelCredentials: raise NotImplementedError() +cdef class SSLSessionCacheLRU: + + def __cinit__(self, capacity): + grpc_init() + self._cache = grpc_ssl_session_cache_create_lru(capacity) + + def __int__(self): + return <uintptr_t>self._cache + + def __dealloc__(self): + if self._cache != NULL: + grpc_ssl_session_cache_destroy(self._cache) + grpc_shutdown() + + cdef class SSLChannelCredentials(ChannelCredentials): def __cinit__(self, pem_root_certificates, private_key, certificate_chain): diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi index 2d6c900c54..cfefeaf938 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi @@ -131,6 +131,7 @@ cdef extern from "grpc/grpc.h": const char *GRPC_ARG_PRIMARY_USER_AGENT_STRING const char *GRPC_ARG_SECONDARY_USER_AGENT_STRING const char *GRPC_SSL_TARGET_NAME_OVERRIDE_ARG + const char *GRPC_SSL_SESSION_CACHE_ARG const char *GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM const char *GRPC_COMPRESSION_CHANNEL_DEFAULT_LEVEL const char *GRPC_COMPRESSION_CHANNEL_ENABLED_ALGORITHMS_BITSET @@ -452,8 +453,16 @@ cdef extern from "grpc/grpc_security.h": # We don't care about the internals (and in fact don't know them) pass + + ctypedef struct grpc_ssl_session_cache: + # We don't care about the internals (and in fact don't know them) + pass + ctypedef void (*grpc_ssl_roots_override_callback)(char **pem_root_certs) + grpc_ssl_session_cache *grpc_ssl_session_cache_create_lru(size_t capacity) + void grpc_ssl_session_cache_destroy(grpc_ssl_session_cache* cache) + void grpc_set_ssl_roots_override_callback( grpc_ssl_roots_override_callback cb) nogil diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/grpc_gevent.pyx b/src/python/grpcio/grpc/_cython/_cygrpc/grpc_gevent.pyx index 31ef671aed..f9a1b2856d 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/grpc_gevent.pyx +++ b/src/python/grpcio/grpc/_cython/_cygrpc/grpc_gevent.pyx @@ -418,6 +418,11 @@ def init_grpc_gevent(): g_event = gevent.event.Event() g_pool = gevent.pool.Group() + + def cb_func(cb, args): + _spawn_greenlet(cb, *args) + set_async_callback_func(cb_func) + gevent_resolver_vtable.resolve = socket_resolve gevent_resolver_vtable.resolve_async = socket_resolve_async diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/grpc_string.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/grpc_string.pyx.pxi index 53e06a1596..00a1b23a67 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/grpc_string.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/grpc_string.pyx.pxi @@ -14,6 +14,7 @@ import logging +_LOGGER = logging.getLogger(__name__) # This function will ascii encode unicode string inputs if neccesary. # In Python3, unicode strings are the default str type. @@ -49,5 +50,5 @@ cdef str _decode(bytes bytestring): try: return bytestring.decode('utf8') except UnicodeDecodeError: - logging.exception('Invalid encoding on %s', bytestring) + _LOGGER.exception('Invalid encoding on %s', bytestring) return bytestring.decode('latin1') diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi index ecd991685f..37b98ebbdb 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi @@ -51,6 +51,7 @@ class ChannelArgKey: default_authority = GRPC_ARG_DEFAULT_AUTHORITY primary_user_agent_string = GRPC_ARG_PRIMARY_USER_AGENT_STRING secondary_user_agent_string = GRPC_ARG_SECONDARY_USER_AGENT_STRING + ssl_session_cache = GRPC_SSL_SESSION_CACHE_ARG ssl_target_name_override = GRPC_SSL_TARGET_NAME_OVERRIDE_ARG diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/server.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/server.pxd.pxi index 4588db30d3..52cfccb677 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/server.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/server.pxd.pxi @@ -23,6 +23,7 @@ cdef class Server: cdef bint is_shutdown # notification of complete shutdown received # used at dealloc when user forgets to shutdown cdef CompletionQueue backup_shutdown_queue + # TODO(https://github.com/grpc/grpc/issues/15662): Elide this. cdef list references cdef list registered_completion_queues diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi index 707ec742dd..da3dd21244 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi @@ -18,6 +18,8 @@ import logging import time import grpc +_LOGGER = logging.getLogger(__name__) + cdef grpc_ssl_certificate_config_reload_status _server_cert_config_fetcher_wrapper( void* user_data, grpc_ssl_server_certificate_config **config) with gil: # This is a credentials.ServerCertificateConfig @@ -34,13 +36,13 @@ cdef grpc_ssl_certificate_config_reload_status _server_cert_config_fetcher_wrapp try: cert_config_wrapper = user_cb() except Exception: - logging.exception('Error fetching certificate config') + _LOGGER.exception('Error fetching certificate config') return GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_FAIL if cert_config_wrapper is None: return GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_UNCHANGED elif not isinstance( cert_config_wrapper, grpc.ServerCertificateConfiguration): - logging.error( + _LOGGER.error( 'Error fetching certificate configuration: certificate ' 'configuration must be of type grpc.ServerCertificateConfiguration, ' 'not %s' % type(cert_config_wrapper).__name__) diff --git a/src/python/grpcio/grpc/_interceptor.py b/src/python/grpcio/grpc/_interceptor.py index f465e35a9c..6b7a912a94 100644 --- a/src/python/grpcio/grpc/_interceptor.py +++ b/src/python/grpcio/grpc/_interceptor.py @@ -100,6 +100,12 @@ class _LocalFailure(grpc.RpcError, grpc.Future, grpc.Call): def cancelled(self): return False + def is_active(self): + return False + + def time_remaining(self): + return None + def running(self): return False @@ -115,6 +121,9 @@ class _LocalFailure(grpc.RpcError, grpc.Future, grpc.Call): def traceback(self, ignored_timeout=None): return self._traceback + def add_callback(self, callback): + return False + def add_done_callback(self, fn): fn(self) @@ -288,11 +297,11 @@ class _Channel(grpc.Channel): self._channel = channel self._interceptor = interceptor - def subscribe(self, *args, **kwargs): - self._channel.subscribe(*args, **kwargs) + def subscribe(self, callback, try_to_connect=False): + self._channel.subscribe(callback, try_to_connect=try_to_connect) - def unsubscribe(self, *args, **kwargs): - self._channel.unsubscribe(*args, **kwargs) + def unsubscribe(self, callback): + self._channel.unsubscribe(callback) def unary_unary(self, method, diff --git a/src/python/grpcio/grpc/_plugin_wrapping.py b/src/python/grpcio/grpc/_plugin_wrapping.py index 6785e5876a..916ee080b6 100644 --- a/src/python/grpcio/grpc/_plugin_wrapping.py +++ b/src/python/grpcio/grpc/_plugin_wrapping.py @@ -20,6 +20,8 @@ import grpc from grpc import _common from grpc._cython import cygrpc +_LOGGER = logging.getLogger(__name__) + class _AuthMetadataContext( collections.namedtuple('AuthMetadataContext', ( @@ -76,7 +78,7 @@ class _Plugin(object): _AuthMetadataPluginCallback( callback_state, callback)) except Exception as exception: # pylint: disable=broad-except - logging.exception( + _LOGGER.exception( 'AuthMetadataPluginCallback "%s" raised exception!', self._metadata_plugin) with callback_state.lock: diff --git a/src/python/grpcio/grpc/_server.py b/src/python/grpcio/grpc/_server.py index d849cadbee..2761022f21 100644 --- a/src/python/grpcio/grpc/_server.py +++ b/src/python/grpcio/grpc/_server.py @@ -27,6 +27,8 @@ from grpc import _interceptor from grpc._cython import cygrpc from grpc.framework.foundation import callable_util +_LOGGER = logging.getLogger(__name__) + _SHUTDOWN_TAG = 'shutdown' _REQUEST_CALL_TAG = 'request_call' @@ -279,7 +281,7 @@ class _Context(grpc.ServicerContext): def abort(self, code, details): # treat OK like other invalid arguments: fail the RPC if code == grpc.StatusCode.OK: - logging.error( + _LOGGER.error( 'abort() called with StatusCode.OK; returning UNKNOWN') code = grpc.StatusCode.UNKNOWN details = '' @@ -328,6 +330,8 @@ class _RequestIterator(object): self._state.request = None return request + raise AssertionError() # should never run + def _next(self): with self._state.condition: self._raise_or_start_receive_message() @@ -390,7 +394,7 @@ def _call_behavior(rpc_event, state, behavior, argument, request_deserializer): b'RPC Aborted') elif exception not in state.rpc_errors: details = 'Exception calling application: {}'.format(exception) - logging.exception(details) + _LOGGER.exception(details) _abort(state, rpc_event.call, cygrpc.StatusCode.unknown, _common.encode(details)) return None, False @@ -408,7 +412,7 @@ def _take_response_from_response_iterator(rpc_event, state, response_iterator): b'RPC Aborted') elif exception not in state.rpc_errors: details = 'Exception iterating responses: {}'.format(exception) - logging.exception(details) + _LOGGER.exception(details) _abort(state, rpc_event.call, cygrpc.StatusCode.unknown, _common.encode(details)) return None, False @@ -617,7 +621,7 @@ def _handle_call(rpc_event, generic_handlers, interceptor_pipeline, thread_pool, interceptor_pipeline) except Exception as exception: # pylint: disable=broad-except details = 'Exception servicing handler: {}'.format(exception) - logging.exception(details) + _LOGGER.exception(details) return _reject_rpc(rpc_event, cygrpc.StatusCode.unknown, b'Error in service handler!'), None if method_handler is None: diff --git a/src/python/grpcio/grpc/_utilities.py b/src/python/grpcio/grpc/_utilities.py index 25bd1ceae2..d90b34bcbd 100644 --- a/src/python/grpcio/grpc/_utilities.py +++ b/src/python/grpcio/grpc/_utilities.py @@ -116,6 +116,8 @@ class _ChannelReadyFuture(grpc.Future): callable_util.call_logging_exceptions( done_callback, _DONE_CALLBACK_EXCEPTION_LOG_MESSAGE, self) + return True + def cancelled(self): with self._condition: return self._cancelled diff --git a/src/python/grpcio/grpc/beta/_server_adaptations.py b/src/python/grpcio/grpc/beta/_server_adaptations.py index ccafec8951..80ac65b649 100644 --- a/src/python/grpcio/grpc/beta/_server_adaptations.py +++ b/src/python/grpcio/grpc/beta/_server_adaptations.py @@ -305,6 +305,7 @@ def _simple_method_handler(implementation, request_deserializer, response_serializer, None, None, None, _adapt_stream_stream_event( implementation.stream_stream_event)) + raise ValueError() def _flatten_method_pair_map(method_pair_map): diff --git a/src/python/grpcio/grpc/beta/utilities.py b/src/python/grpcio/grpc/beta/utilities.py index b5d8aac71a..fe3ce606c9 100644 --- a/src/python/grpcio/grpc/beta/utilities.py +++ b/src/python/grpcio/grpc/beta/utilities.py @@ -85,6 +85,8 @@ class _ChannelReadyFuture(future.Future): callable_util.call_logging_exceptions( done_callback, _DONE_CALLBACK_EXCEPTION_LOG_MESSAGE, self) + return True + def cancelled(self): with self._condition: return self._cancelled diff --git a/src/python/grpcio/grpc/experimental/session_cache.py b/src/python/grpcio/grpc/experimental/session_cache.py new file mode 100644 index 0000000000..5c55f7c327 --- /dev/null +++ b/src/python/grpcio/grpc/experimental/session_cache.py @@ -0,0 +1,45 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""gRPC's APIs for TLS Session Resumption support""" + +from grpc._cython import cygrpc as _cygrpc + + +def ssl_session_cache_lru(capacity): + """Creates an SSLSessionCache with LRU replacement policy + + Args: + capacity: Size of the cache + + Returns: + An SSLSessionCache with LRU replacement policy that can be passed as a value for + the grpc.ssl_session_cache option to a grpc.Channel. SSL session caches are used + to store session tickets, which clients can present to resume previous TLS sessions + with a server. + """ + return SSLSessionCache(_cygrpc.SSLSessionCacheLRU(capacity)) + + +class SSLSessionCache(object): + """An encapsulation of a session cache used for TLS session resumption. + + Instances of this class can be passed to a Channel as values for the + grpc.ssl_session_cache option + """ + + def __init__(self, cache): + self._cache = cache + + def __int__(self): + return int(self._cache) diff --git a/src/python/grpcio/grpc/framework/foundation/callable_util.py b/src/python/grpcio/grpc/framework/foundation/callable_util.py index b9b9c49f17..24daf3406f 100644 --- a/src/python/grpcio/grpc/framework/foundation/callable_util.py +++ b/src/python/grpcio/grpc/framework/foundation/callable_util.py @@ -21,6 +21,8 @@ import logging import six +_LOGGER = logging.getLogger(__name__) + class Outcome(six.with_metaclass(abc.ABCMeta)): """A sum type describing the outcome of some call. @@ -53,7 +55,7 @@ def _call_logging_exceptions(behavior, message, *args, **kwargs): return _EasyOutcome(Outcome.Kind.RETURNED, behavior(*args, **kwargs), None) except Exception as e: # pylint: disable=broad-except - logging.exception(message) + _LOGGER.exception(message) return _EasyOutcome(Outcome.Kind.RAISED, None, e) diff --git a/src/python/grpcio/grpc/framework/foundation/logging_pool.py b/src/python/grpcio/grpc/framework/foundation/logging_pool.py index f75df10042..216e3990db 100644 --- a/src/python/grpcio/grpc/framework/foundation/logging_pool.py +++ b/src/python/grpcio/grpc/framework/foundation/logging_pool.py @@ -17,6 +17,8 @@ import logging from concurrent import futures +_LOGGER = logging.getLogger(__name__) + def _wrap(behavior): """Wraps an arbitrary callable behavior in exception-logging.""" @@ -25,7 +27,7 @@ def _wrap(behavior): try: return behavior(*args, **kwargs) except Exception: - logging.exception( + _LOGGER.exception( 'Unexpected exception from %s executed in logging pool!', behavior) raise diff --git a/src/python/grpcio/grpc/framework/foundation/stream_util.py b/src/python/grpcio/grpc/framework/foundation/stream_util.py index 04631d9899..1faaf29bd7 100644 --- a/src/python/grpcio/grpc/framework/foundation/stream_util.py +++ b/src/python/grpcio/grpc/framework/foundation/stream_util.py @@ -19,6 +19,7 @@ import threading from grpc.framework.foundation import stream _NO_VALUE = object() +_LOGGER = logging.getLogger(__name__) class TransformingConsumer(stream.Consumer): @@ -46,10 +47,10 @@ class IterableConsumer(stream.Consumer): self._values = [] self._active = True - def consume(self, stock_reply): + def consume(self, value): with self._condition: if self._active: - self._values.append(stock_reply) + self._values.append(value) self._condition.notify() def terminate(self): @@ -57,10 +58,10 @@ class IterableConsumer(stream.Consumer): self._active = False self._condition.notify() - def consume_and_terminate(self, stock_reply): + def consume_and_terminate(self, value): with self._condition: if self._active: - self._values.append(stock_reply) + self._values.append(value) self._active = False self._condition.notify() @@ -103,7 +104,7 @@ class ThreadSwitchingConsumer(stream.Consumer): else: sink.consume(value) except Exception as e: # pylint:disable=broad-except - logging.exception(e) + _LOGGER.exception(e) with self._lock: if terminate: diff --git a/src/python/grpcio_testing/grpc_testing/__init__.py b/src/python/grpcio_testing/grpc_testing/__init__.py index e87d0ffc96..65fdd1b8ca 100644 --- a/src/python/grpcio_testing/grpc_testing/__init__.py +++ b/src/python/grpcio_testing/grpc_testing/__init__.py @@ -14,9 +14,9 @@ """Objects for use in testing gRPC Python-using application code.""" import abc +import six from google.protobuf import descriptor -import six import grpc diff --git a/src/python/grpcio_testing/grpc_testing/_channel/_invocation.py b/src/python/grpcio_testing/grpc_testing/_channel/_invocation.py index ebce652eeb..191b1c1726 100644 --- a/src/python/grpcio_testing/grpc_testing/_channel/_invocation.py +++ b/src/python/grpcio_testing/grpc_testing/_channel/_invocation.py @@ -18,6 +18,7 @@ import threading import grpc _NOT_YET_OBSERVED = object() +_LOGGER = logging.getLogger(__name__) def _cancel(handler): @@ -248,7 +249,7 @@ def consume_requests(request_iterator, handler): break except Exception: # pylint: disable=broad-except details = 'Exception iterating requests!' - logging.exception(details) + _LOGGER.exception(details) handler.cancel(grpc.StatusCode.UNKNOWN, details) consumption = threading.Thread(target=_consume) diff --git a/src/python/grpcio_testing/grpc_testing/_server/_handler.py b/src/python/grpcio_testing/grpc_testing/_server/_handler.py index d4f50f6863..0e3404b0d0 100644 --- a/src/python/grpcio_testing/grpc_testing/_server/_handler.py +++ b/src/python/grpcio_testing/grpc_testing/_server/_handler.py @@ -105,10 +105,10 @@ class _Handler(Handler): self._expiration_future.cancel() self._condition.notify_all() - def add_termination_callback(self, termination_callback): + def add_termination_callback(self, callback): with self._condition: if self._code is None: - self._termination_callbacks.append(termination_callback) + self._termination_callbacks.append(callback) return True else: return False diff --git a/src/python/grpcio_testing/grpc_testing/_server/_rpc.py b/src/python/grpcio_testing/grpc_testing/_server/_rpc.py index 2060e8daff..b856da100f 100644 --- a/src/python/grpcio_testing/grpc_testing/_server/_rpc.py +++ b/src/python/grpcio_testing/grpc_testing/_server/_rpc.py @@ -18,6 +18,8 @@ import threading import grpc from grpc_testing import _common +_LOGGER = logging.getLogger(__name__) + class Rpc(object): @@ -47,7 +49,7 @@ class Rpc(object): try: callback() except Exception: # pylint: disable=broad-except - logging.exception('Exception calling server-side callback!') + _LOGGER.exception('Exception calling server-side callback!') callback_calling_thread = threading.Thread(target=call_back) callback_calling_thread.start() @@ -86,7 +88,7 @@ class Rpc(object): def application_exception_abort(self, exception): with self._condition: if exception not in self._rpc_errors: - logging.exception('Exception calling application!') + _LOGGER.exception('Exception calling application!') self._abort( grpc.StatusCode.UNKNOWN, 'Exception calling application: {}'.format(exception)) diff --git a/src/python/grpcio_testing/grpc_testing/_time.py b/src/python/grpcio_testing/grpc_testing/_time.py index afbdad3524..75e6db3458 100644 --- a/src/python/grpcio_testing/grpc_testing/_time.py +++ b/src/python/grpcio_testing/grpc_testing/_time.py @@ -21,13 +21,15 @@ import time as _time import grpc import grpc_testing +_LOGGER = logging.getLogger(__name__) + def _call(behaviors): for behavior in behaviors: try: behavior() except Exception: # pylint: disable=broad-except - logging.exception('Exception calling behavior "%r"!', behavior) + _LOGGER.exception('Exception calling behavior "%r"!', behavior) def _call_in_thread(behaviors): diff --git a/src/python/grpcio_tests/tests/_loader.py b/src/python/grpcio_tests/tests/_loader.py index be0af64646..80c107aa8e 100644 --- a/src/python/grpcio_tests/tests/_loader.py +++ b/src/python/grpcio_tests/tests/_loader.py @@ -48,12 +48,13 @@ class Loader(object): # measure unnecessarily suffers) coverage_context = coverage.Coverage(data_suffix=True) coverage_context.start() - modules = [importlib.import_module(name) for name in names] - for module in modules: - self.visit_module(module) - for module in modules: + imported_modules = tuple( + importlib.import_module(name) for name in names) + for imported_module in imported_modules: + self.visit_module(imported_module) + for imported_module in imported_modules: try: - package_paths = module.__path__ + package_paths = imported_module.__path__ except AttributeError: continue self.walk_packages(package_paths) diff --git a/src/python/grpcio_tests/tests/_result.py b/src/python/grpcio_tests/tests/_result.py index b105f18e78..e5378b7ea3 100644 --- a/src/python/grpcio_tests/tests/_result.py +++ b/src/python/grpcio_tests/tests/_result.py @@ -144,10 +144,6 @@ class AugmentedResult(unittest.TestResult): super(AugmentedResult, self).startTestRun() self.cases = dict() - def stopTestRun(self): - """See unittest.TestResult.stopTestRun.""" - super(AugmentedResult, self).stopTestRun() - def startTest(self, test): """See unittest.TestResult.startTest.""" super(AugmentedResult, self).startTest(test) @@ -155,19 +151,19 @@ class AugmentedResult(unittest.TestResult): self.cases[case_id] = CaseResult( id=case_id, name=test.id(), kind=CaseResult.Kind.RUNNING) - def addError(self, test, error): + def addError(self, test, err): """See unittest.TestResult.addError.""" - super(AugmentedResult, self).addError(test, error) + super(AugmentedResult, self).addError(test, err) case_id = self.id_map(test) self.cases[case_id] = self.cases[case_id].updated( - kind=CaseResult.Kind.ERROR, traceback=error) + kind=CaseResult.Kind.ERROR, traceback=err) - def addFailure(self, test, error): + def addFailure(self, test, err): """See unittest.TestResult.addFailure.""" - super(AugmentedResult, self).addFailure(test, error) + super(AugmentedResult, self).addFailure(test, err) case_id = self.id_map(test) self.cases[case_id] = self.cases[case_id].updated( - kind=CaseResult.Kind.FAILURE, traceback=error) + kind=CaseResult.Kind.FAILURE, traceback=err) def addSuccess(self, test): """See unittest.TestResult.addSuccess.""" @@ -183,12 +179,12 @@ class AugmentedResult(unittest.TestResult): self.cases[case_id] = self.cases[case_id].updated( kind=CaseResult.Kind.SKIP, skip_reason=reason) - def addExpectedFailure(self, test, error): + def addExpectedFailure(self, test, err): """See unittest.TestResult.addExpectedFailure.""" - super(AugmentedResult, self).addExpectedFailure(test, error) + super(AugmentedResult, self).addExpectedFailure(test, err) case_id = self.id_map(test) self.cases[case_id] = self.cases[case_id].updated( - kind=CaseResult.Kind.EXPECTED_FAILURE, traceback=error) + kind=CaseResult.Kind.EXPECTED_FAILURE, traceback=err) def addUnexpectedSuccess(self, test): """See unittest.TestResult.addUnexpectedSuccess.""" @@ -249,13 +245,6 @@ class CoverageResult(AugmentedResult): self.coverage_context.save() self.coverage_context = None - def stopTestRun(self): - """See unittest.TestResult.stopTestRun.""" - super(CoverageResult, self).stopTestRun() - # TODO(atash): Dig deeper into why the following line fails to properly - # combine coverage data from the Cython plugin. - #coverage.Coverage().combine() - class _Colors(object): """Namespaced constants for terminal color magic numbers.""" @@ -295,16 +284,16 @@ class TerminalResult(CoverageResult): self.out.write(summary(self)) self.out.flush() - def addError(self, test, error): + def addError(self, test, err): """See unittest.TestResult.addError.""" - super(TerminalResult, self).addError(test, error) + super(TerminalResult, self).addError(test, err) self.out.write( _Colors.FAIL + 'ERROR {}\n'.format(test.id()) + _Colors.END) self.out.flush() - def addFailure(self, test, error): + def addFailure(self, test, err): """See unittest.TestResult.addFailure.""" - super(TerminalResult, self).addFailure(test, error) + super(TerminalResult, self).addFailure(test, err) self.out.write( _Colors.FAIL + 'FAILURE {}\n'.format(test.id()) + _Colors.END) self.out.flush() @@ -323,9 +312,9 @@ class TerminalResult(CoverageResult): _Colors.INFO + 'SKIP {}\n'.format(test.id()) + _Colors.END) self.out.flush() - def addExpectedFailure(self, test, error): + def addExpectedFailure(self, test, err): """See unittest.TestResult.addExpectedFailure.""" - super(TerminalResult, self).addExpectedFailure(test, error) + super(TerminalResult, self).addExpectedFailure(test, err) self.out.write( _Colors.INFO + 'FAILURE_OK {}\n'.format(test.id()) + _Colors.END) self.out.flush() diff --git a/src/python/grpcio_tests/tests/interop/methods.py b/src/python/grpcio_tests/tests/interop/methods.py index b728ffd704..cda15a68a3 100644 --- a/src/python/grpcio_tests/tests/interop/methods.py +++ b/src/python/grpcio_tests/tests/interop/methods.py @@ -144,8 +144,8 @@ def _large_unary_common_behavior(stub, fill_username, fill_oauth_scope, def _empty_unary(stub): response = stub.EmptyCall(empty_pb2.Empty()) if not isinstance(response, empty_pb2.Empty): - raise TypeError('response is of type "%s", not empty_pb2.Empty!', - type(response)) + raise TypeError( + 'response is of type "%s", not empty_pb2.Empty!' % type(response)) def _large_unary(stub): diff --git a/src/python/grpcio_tests/tests/interop/server.py b/src/python/grpcio_tests/tests/interop/server.py index 0810de2394..fd28d498a1 100644 --- a/src/python/grpcio_tests/tests/interop/server.py +++ b/src/python/grpcio_tests/tests/interop/server.py @@ -26,6 +26,7 @@ from tests.interop import resources from tests.unit import test_common _ONE_DAY_IN_SECONDS = 60 * 60 * 24 +_LOGGER = logging.getLogger(__name__) def serve(): @@ -52,14 +53,14 @@ def serve(): server.add_insecure_port('[::]:{}'.format(args.port)) server.start() - logging.info('Server serving.') + _LOGGER.info('Server serving.') try: while True: time.sleep(_ONE_DAY_IN_SECONDS) except BaseException as e: - logging.info('Caught exception "%s"; stopping server...', e) + _LOGGER.info('Caught exception "%s"; stopping server...', e) server.stop(None) - logging.info('Server stopped; exiting.') + _LOGGER.info('Server stopped; exiting.') if __name__ == '__main__': diff --git a/src/python/grpcio_tests/tests/tests.json b/src/python/grpcio_tests/tests/tests.json index 0d94426413..65460a9540 100644 --- a/src/python/grpcio_tests/tests/tests.json +++ b/src/python/grpcio_tests/tests/tests.json @@ -53,6 +53,7 @@ "unit._server_ssl_cert_config_test.ServerSSLCertReloadTestCertConfigReuse", "unit._server_ssl_cert_config_test.ServerSSLCertReloadTestWithClientAuth", "unit._server_ssl_cert_config_test.ServerSSLCertReloadTestWithoutClientAuth", + "unit._session_cache_test.SSLSessionCacheTest", "unit.beta._beta_features_test.BetaFeaturesTest", "unit.beta._beta_features_test.ContextManagementAndLifecycleTest", "unit.beta._connectivity_channel_test.ConnectivityStatesTest", diff --git a/src/python/grpcio_tests/tests/unit/_auth_context_test.py b/src/python/grpcio_tests/tests/unit/_auth_context_test.py index 8c1a30e032..d174051070 100644 --- a/src/python/grpcio_tests/tests/unit/_auth_context_test.py +++ b/src/python/grpcio_tests/tests/unit/_auth_context_test.py @@ -18,6 +18,7 @@ import unittest import grpc from grpc import _channel +from grpc.experimental import session_cache import six from tests.unit import test_common @@ -140,6 +141,50 @@ class AuthContextTest(unittest.TestCase): self.assertSequenceEqual([b'*.test.google.com'], auth_ctx['x509_common_name']) + def _do_one_shot_client_rpc(self, channel_creds, channel_options, port, + expect_ssl_session_reused): + channel = grpc.secure_channel( + 'localhost:{}'.format(port), channel_creds, options=channel_options) + response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) + auth_data = pickle.loads(response) + self.assertEqual(expect_ssl_session_reused, + auth_data[_AUTH_CTX]['ssl_session_reused']) + channel.close() + + def testSessionResumption(self): + # Set up a secure server + handler = grpc.method_handlers_generic_handler('test', { + 'UnaryUnary': + grpc.unary_unary_rpc_method_handler(handle_unary_unary) + }) + server = test_common.test_server() + server.add_generic_rpc_handlers((handler,)) + server_cred = grpc.ssl_server_credentials(_SERVER_CERTS) + port = server.add_secure_port('[::]:0', server_cred) + server.start() + + # Create a cache for TLS session tickets + cache = session_cache.ssl_session_cache_lru(1) + channel_creds = grpc.ssl_channel_credentials( + root_certificates=_TEST_ROOT_CERTIFICATES) + channel_options = _PROPERTY_OPTIONS + ( + ('grpc.ssl_session_cache', cache),) + + # Initial connection has no session to resume + self._do_one_shot_client_rpc( + channel_creds, + channel_options, + port, + expect_ssl_session_reused=[b'false']) + + # Subsequent connections resume sessions + self._do_one_shot_client_rpc( + channel_creds, + channel_options, + port, + expect_ssl_session_reused=[b'true']) + server.stop(None) + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_junkdrawer/__init__.py b/src/python/grpcio_tests/tests/unit/_junkdrawer/__init__.py deleted file mode 100644 index 5fb4f3c3cf..0000000000 --- a/src/python/grpcio_tests/tests/unit/_junkdrawer/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2015 gRPC authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/src/python/grpcio_tests/tests/unit/_junkdrawer/stock_pb2.py b/src/python/grpcio_tests/tests/unit/_junkdrawer/stock_pb2.py deleted file mode 100644 index 2bf1e1cc0d..0000000000 --- a/src/python/grpcio_tests/tests/unit/_junkdrawer/stock_pb2.py +++ /dev/null @@ -1,164 +0,0 @@ -# Copyright 2015 gRPC authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# TODO(nathaniel): Remove this from source control after having made -# generation from the stock.proto source part of GRPC's build-and-test -# process. - -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: stock.proto - -import sys -_b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode('latin1')) -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -from google.protobuf import descriptor_pb2 -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - -DESCRIPTOR = _descriptor.FileDescriptor( - name='stock.proto', - package='stock', - serialized_pb=_b( - '\n\x0bstock.proto\x12\x05stock\">\n\x0cStockRequest\x12\x0e\n\x06symbol\x18\x01 \x01(\t\x12\x1e\n\x13num_trades_to_watch\x18\x02 \x01(\x05:\x01\x30\"+\n\nStockReply\x12\r\n\x05price\x18\x01 \x01(\x02\x12\x0e\n\x06symbol\x18\x02 \x01(\t2\x96\x02\n\x05Stock\x12=\n\x11GetLastTradePrice\x12\x13.stock.StockRequest\x1a\x11.stock.StockReply\"\x00\x12I\n\x19GetLastTradePriceMultiple\x12\x13.stock.StockRequest\x1a\x11.stock.StockReply\"\x00(\x01\x30\x01\x12?\n\x11WatchFutureTrades\x12\x13.stock.StockRequest\x1a\x11.stock.StockReply\"\x00\x30\x01\x12\x42\n\x14GetHighestTradePrice\x12\x13.stock.StockRequest\x1a\x11.stock.StockReply\"\x00(\x01' - )) -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - -_STOCKREQUEST = _descriptor.Descriptor( - name='StockRequest', - full_name='stock.StockRequest', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='symbol', - full_name='stock.StockRequest.symbol', - index=0, - number=1, - type=9, - cpp_type=9, - label=1, - has_default_value=False, - default_value=_b("").decode('utf-8'), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='num_trades_to_watch', - full_name='stock.StockRequest.num_trades_to_watch', - index=1, - number=2, - type=5, - cpp_type=1, - label=1, - has_default_value=True, - default_value=0, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - ], - extensions=[], - nested_types=[], - enum_types=[], - options=None, - is_extendable=False, - extension_ranges=[], - oneofs=[], - serialized_start=22, - serialized_end=84,) - -_STOCKREPLY = _descriptor.Descriptor( - name='StockReply', - full_name='stock.StockReply', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='price', - full_name='stock.StockReply.price', - index=0, - number=1, - type=2, - cpp_type=6, - label=1, - has_default_value=False, - default_value=0, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='symbol', - full_name='stock.StockReply.symbol', - index=1, - number=2, - type=9, - cpp_type=9, - label=1, - has_default_value=False, - default_value=_b("").decode('utf-8'), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - ], - extensions=[], - nested_types=[], - enum_types=[], - options=None, - is_extendable=False, - extension_ranges=[], - oneofs=[], - serialized_start=86, - serialized_end=129,) - -DESCRIPTOR.message_types_by_name['StockRequest'] = _STOCKREQUEST -DESCRIPTOR.message_types_by_name['StockReply'] = _STOCKREPLY - -StockRequest = _reflection.GeneratedProtocolMessageType( - 'StockRequest', - (_message.Message,), - dict( - DESCRIPTOR=_STOCKREQUEST, - __module__='stock_pb2' - # @@protoc_insertion_point(class_scope:stock.StockRequest) - )) -_sym_db.RegisterMessage(StockRequest) - -StockReply = _reflection.GeneratedProtocolMessageType( - 'StockReply', - (_message.Message,), - dict( - DESCRIPTOR=_STOCKREPLY, - __module__='stock_pb2' - # @@protoc_insertion_point(class_scope:stock.StockReply) - )) -_sym_db.RegisterMessage(StockReply) - -# @@protoc_insertion_point(module_scope) diff --git a/src/python/grpcio_tests/tests/unit/_session_cache_test.py b/src/python/grpcio_tests/tests/unit/_session_cache_test.py new file mode 100644 index 0000000000..b4e4670fa7 --- /dev/null +++ b/src/python/grpcio_tests/tests/unit/_session_cache_test.py @@ -0,0 +1,145 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests experimental TLS Session Resumption API""" + +import pickle +import unittest + +import grpc +from grpc import _channel +from grpc.experimental import session_cache + +from tests.unit import test_common +from tests.unit import resources + +_REQUEST = b'\x00\x00\x00' +_RESPONSE = b'\x00\x00\x00' + +_UNARY_UNARY = '/test/UnaryUnary' + +_SERVER_HOST_OVERRIDE = 'foo.test.google.fr' +_ID = 'id' +_ID_KEY = 'id_key' +_AUTH_CTX = 'auth_ctx' + +_PRIVATE_KEY = resources.private_key() +_CERTIFICATE_CHAIN = resources.certificate_chain() +_TEST_ROOT_CERTIFICATES = resources.test_root_certificates() +_SERVER_CERTS = ((_PRIVATE_KEY, _CERTIFICATE_CHAIN),) +_PROPERTY_OPTIONS = (( + 'grpc.ssl_target_name_override', + _SERVER_HOST_OVERRIDE, +),) + + +def handle_unary_unary(request, servicer_context): + return pickle.dumps({ + _ID: servicer_context.peer_identities(), + _ID_KEY: servicer_context.peer_identity_key(), + _AUTH_CTX: servicer_context.auth_context() + }) + + +def start_secure_server(): + handler = grpc.method_handlers_generic_handler('test', { + 'UnaryUnary': + grpc.unary_unary_rpc_method_handler(handle_unary_unary) + }) + server = test_common.test_server() + server.add_generic_rpc_handlers((handler,)) + server_cred = grpc.ssl_server_credentials(_SERVER_CERTS) + port = server.add_secure_port('[::]:0', server_cred) + server.start() + + return server, port + + +class SSLSessionCacheTest(unittest.TestCase): + + def _do_one_shot_client_rpc(self, channel_creds, channel_options, port, + expect_ssl_session_reused): + channel = grpc.secure_channel( + 'localhost:{}'.format(port), channel_creds, options=channel_options) + response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) + auth_data = pickle.loads(response) + self.assertEqual(expect_ssl_session_reused, + auth_data[_AUTH_CTX]['ssl_session_reused']) + channel.close() + + def testSSLSessionCacheLRU(self): + server_1, port_1 = start_secure_server() + + cache = session_cache.ssl_session_cache_lru(1) + channel_creds = grpc.ssl_channel_credentials( + root_certificates=_TEST_ROOT_CERTIFICATES) + channel_options = _PROPERTY_OPTIONS + ( + ('grpc.ssl_session_cache', cache),) + + # Initial connection has no session to resume + self._do_one_shot_client_rpc( + channel_creds, + channel_options, + port_1, + expect_ssl_session_reused=[b'false']) + + # Connection to server_1 resumes from initial session + self._do_one_shot_client_rpc( + channel_creds, + channel_options, + port_1, + expect_ssl_session_reused=[b'true']) + + # Connection to a different server with the same name overwrites the cache entry + server_2, port_2 = start_secure_server() + self._do_one_shot_client_rpc( + channel_creds, + channel_options, + port_2, + expect_ssl_session_reused=[b'false']) + self._do_one_shot_client_rpc( + channel_creds, + channel_options, + port_2, + expect_ssl_session_reused=[b'true']) + server_2.stop(None) + + # Connection to server_1 now falls back to full TLS handshake + self._do_one_shot_client_rpc( + channel_creds, + channel_options, + port_1, + expect_ssl_session_reused=[b'false']) + + # Re-creating server_1 causes old sessions to become invalid + server_1.stop(None) + server_1, port_1 = start_secure_server() + + # Old sessions should no longer be valid + self._do_one_shot_client_rpc( + channel_creds, + channel_options, + port_1, + expect_ssl_session_reused=[b'false']) + + # Resumption should work for subsequent connections + self._do_one_shot_client_rpc( + channel_creds, + channel_options, + port_1, + expect_ssl_session_reused=[b'true']) + server_1.stop(None) + + +if __name__ == '__main__': + unittest.main(verbosity=2) |