# Copyright 2015-2016, Google Inc. # All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are # met: # # * Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. # * Redistributions in binary form must reproduce the above # copyright notice, this list of conditions and the following disclaimer # in the documentation and/or other materials provided with the # distribution. # * Neither the name of Google Inc. nor the names of its # contributors may be used to endorse or promote products derived from # this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
import threading

from grpc import _grpcio_metadata
from grpc._cython import cygrpc
from grpc._adapter import _implementations
from grpc._adapter import _types

_USER_AGENT = 'Python-gRPC-{}'.format(_grpcio_metadata.__version__)

# Credential types and composition helpers re-exported unchanged from cygrpc.
ChannelCredentials = cygrpc.ChannelCredentials
CallCredentials = cygrpc.CallCredentials
ServerCredentials = cygrpc.ServerCredentials

channel_credentials_composite = cygrpc.channel_credentials_composite
call_credentials_composite = cygrpc.call_credentials_composite


def server_credentials_ssl(root_credentials, pair_sequence, force_client_auth):
    """Build cygrpc SSL server credentials.

    Args:
      root_credentials: Passed through to cygrpc.server_credentials_ssl.
      pair_sequence: An iterable of 2-tuples; each tuple is unpacked and
        handed, in order, to cygrpc.SslPemKeyCertPair.
      force_client_auth: Passed through to cygrpc.server_credentials_ssl.

    Returns:
      Whatever cygrpc.server_credentials_ssl returns.
    """
    return cygrpc.server_credentials_ssl(
        root_credentials,
        [cygrpc.SslPemKeyCertPair(key, pem) for key, pem in pair_sequence],
        force_client_auth)


def channel_credentials_ssl(
        root_certificates, private_key, certificate_chain):
    """Build cygrpc SSL channel credentials.

    A key/certificate pair is only constructed when at least one of
    private_key and certificate_chain is not None; otherwise None is passed
    in its place.

    Args:
      root_certificates: Passed through to cygrpc.channel_credentials_ssl.
      private_key: First argument of the cygrpc.SslPemKeyCertPair, or None.
      certificate_chain: Second argument of the cygrpc.SslPemKeyCertPair, or
        None.

    Returns:
      Whatever cygrpc.channel_credentials_ssl returns.
    """
    pair = None
    if private_key is not None or certificate_chain is not None:
        pair = cygrpc.SslPemKeyCertPair(private_key, certificate_chain)
    return cygrpc.channel_credentials_ssl(root_certificates, pair)


class _WrappedCygrpcCallback(object):
    """Adapts a cygrpc plugin callback to a (metadata, error) callable.

    The wrapper may be invoked at most once; a further invocation raises
    RuntimeError. A failure recorded through notify_failure before the
    wrapper is invoked takes precedence over the arguments of that invocation.
    """

    def __init__(self, cygrpc_callback):
        self.is_called = False
        self.error = None
        # Guards is_called and error.
        self.is_called_lock = threading.Lock()
        self.cygrpc_callback = cygrpc_callback

    def _invoke_failure(self, error):
        """Report error to the cygrpc callback with an internal status."""
        # TODO(atash) translate different Exception superclasses into different
        # status codes.
        # Exceptions only carry a `message` attribute on Python 2 (and it may
        # be empty there); fall back to str(error) so that reporting a failure
        # cannot itself fail with AttributeError.
        details = getattr(error, 'message', None) or str(error)
        self.cygrpc_callback(
            cygrpc.Metadata([]), cygrpc.StatusCode.internal, details)

    def _invoke_success(self, metadata):
        """Report metadata (key/value 2-tuples) to the cygrpc callback.

        If the metadata cannot be converted, the conversion error is reported
        as a failure instead.
        """
        try:
            cygrpc_metadata = cygrpc.Metadata(
                cygrpc.Metadatum(key, value) for key, value in metadata)
        except Exception as error:
            self._invoke_failure(error)
            return
        self.cygrpc_callback(cygrpc_metadata, cygrpc.StatusCode.ok, '')

    def __call__(self, metadata, error):
        """Deliver the plugin's result to the cygrpc callback.

        Args:
          metadata: An iterable of metadata key/value 2-tuples; used only when
            error is None and no failure was recorded earlier.
          error: None on success, otherwise the exception to report.

        Raises:
          RuntimeError: If this wrapper has already been invoked.
        """
        with self.is_called_lock:
            if self.is_called:
                raise RuntimeError('callback should only ever be invoked once')
            # Mark the wrapper as consumed on every path, including the
            # recorded-failure path below, so that the cygrpc callback can
            # never be invoked a second time.
            self.is_called = True
            if self.error:
                self._invoke_failure(self.error)
                return
        if error is None:
            self._invoke_success(metadata)
        else:
            self._invoke_failure(error)

    def notify_failure(self, error):
        """Record error for later delivery unless already invoked."""
        with self.is_called_lock:
            if not self.is_called:
                self.error = error


class _WrappedPlugin(object):
    """Adapts a user metadata plugin to the (context, callback) call made
    with cygrpc objects."""

    def __init__(self, plugin):
        self.plugin = plugin

    def __call__(self, context, cygrpc_callback):
        """Invoke the wrapped plugin with wrapped context and callback.

        Any exception raised by the plugin is recorded on the wrapped callback
        and then re-raised.
        """
        wrapped_cygrpc_callback = _WrappedCygrpcCallback(cygrpc_callback)
        wrapped_context = _implementations.AuthMetadataContext(
            context.service_url, context.method_name)
        try:
            self.plugin(
                wrapped_context,
                _implementations.AuthMetadataPluginCallback(
                    wrapped_cygrpc_callback))
        except Exception as error:
            wrapped_cygrpc_callback.notify_failure(error)
            raise


def call_credentials_metadata_plugin(plugin, name):
    """Create call credentials backed by a metadata plugin.

    Args:
      plugin: A callable accepting a _types.AuthMetadataContext object and a
        callback (itself accepting a list of metadata key/value 2-tuples and a
        None-able exception value). The callback must be eventually called, but
        need not be called in plugin's invocation.
        plugin's invocation must be non-blocking.
      name: Passed through to cygrpc.CredentialsMetadataPlugin alongside the
        wrapped plugin.

    Returns:
      Whatever cygrpc.call_credentials_metadata_plugin returns.
    """
    return cygrpc.call_credentials_metadata_plugin(
        cygrpc.CredentialsMetadataPlugin(_WrappedPlugin(plugin), name))


class CompletionQueue(_types.CompletionQueue):
    """_types.CompletionQueue implemented on top of cygrpc.CompletionQueue."""

    def __init__(self):
        self.completion_queue = cygrpc.CompletionQueue()

    def next(self, deadline=float('+inf')):
        """Poll the underlying queue and translate the event.

        Args:
          deadline: Value wrapped in cygrpc.Timespec and passed to poll.

        Returns:
          None if the poll timed out, otherwise a _types.Event built from the
          raw cygrpc event.
        """
        raw_event = self.completion_queue.poll(cygrpc.Timespec(deadline))
        if raw_event.type == cygrpc.CompletionType.queue_timeout:
            return None
        event_type = raw_event.type
        event_tag = raw_event.tag
        event_call = Call(raw_event.operation_call)
        if raw_event.request_call_details:
            event_call_details = _types.CallDetails(
                raw_event.request_call_details.method,
                raw_event.request_call_details.host,
                float(raw_event.request_call_details.deadline))
        else:
            event_call_details = None
        event_success = raw_event.success
        event_results = []
        if raw_event.is_new_request:
            # A new request yields a single result carrying only the request
            # metadata.
            event_results.append(_types.OpResult(
                _types.OpType.RECV_INITIAL_METADATA, raw_event.request_metadata,
                None, None, None, None))
        elif raw_event.batch_operations:
            for operation in raw_event.batch_operations:
                result_type = operation.type
                # Initial and trailing metadata are both read from the same
                # received_metadata_or_none attribute of the operation.
                result_initial_metadata = operation.received_metadata_or_none
                result_trailing_metadata = operation.received_metadata_or_none
                result_message = operation.received_message_or_none
                if result_message is not None:
                    result_message = result_message.bytes()
                result_cancelled = operation.received_cancelled_or_none
                if operation.has_status:
                    result_status = _types.Status(
                        operation.received_status_code_or_none,
                        operation.received_status_details_or_none)
                else:
                    result_status = None
                event_results.append(
                    _types.OpResult(result_type, result_initial_metadata,
                                    result_trailing_metadata, result_message,
                                    result_status, result_cancelled))
        return _types.Event(event_type, event_tag, event_call,
                            event_call_details, event_results, event_success)

    def shutdown(self):
        """Shut down the underlying cygrpc completion queue."""
        self.completion_queue.shutdown()


class Call(_types.Call):
    """_types.Call implemented on top of a cygrpc call object."""

    def __init__(self, call):
        self.call = call

    def start_batch(self, ops, tag):
        """Translate ops into cygrpc operations and start them as one batch.

        Args:
          ops: An iterable of operation descriptions exposing `type` and,
            depending on the type, `initial_metadata`, `message`,
            `trailing_metadata` and `status`.
          tag: Passed through to the underlying call's start_batch.

        Returns:
          Whatever the underlying call's start_batch returns.

        Raises:
          ValueError: If an operation has an unrecognised type.
        """
        translated_ops = []
        for op in ops:
            if op.type == _types.OpType.SEND_INITIAL_METADATA:
                translated_op = cygrpc.operation_send_initial_metadata(
                    cygrpc.Metadata(
                        cygrpc.Metadatum(key, value)
                        for key, value in op.initial_metadata))
            elif op.type == _types.OpType.SEND_MESSAGE:
                translated_op = cygrpc.operation_send_message(op.message)
            elif op.type == _types.OpType.SEND_CLOSE_FROM_CLIENT:
                translated_op = cygrpc.operation_send_close_from_client()
            elif op.type == _types.OpType.SEND_STATUS_FROM_SERVER:
                translated_op = cygrpc.operation_send_status_from_server(
                    cygrpc.Metadata(
                        cygrpc.Metadatum(key, value)
                        for key, value in op.trailing_metadata),
                    op.status.code,
                    op.status.details)
            elif op.type == _types.OpType.RECV_INITIAL_METADATA:
                translated_op = cygrpc.operation_receive_initial_metadata()
            elif op.type == _types.OpType.RECV_MESSAGE:
                translated_op = cygrpc.operation_receive_message()
            elif op.type == _types.OpType.RECV_STATUS_ON_CLIENT:
                translated_op = cygrpc.operation_receive_status_on_client()
            elif op.type == _types.OpType.RECV_CLOSE_ON_SERVER:
                translated_op = cygrpc.operation_receive_close_on_server()
            else:
                raise ValueError('unexpected operation type {}'.format(op.type))
            translated_ops.append(translated_op)
        return self.call.start_batch(cygrpc.Operations(translated_ops), tag)

    def cancel(self, code=None, details=None):
        """Cancel the call, forwarding code and details unless both are None."""
        if code is None and details is None:
            return self.call.cancel()
        else:
            return self.call.cancel(code, details)

    def peer(self):
        """Return the underlying call's peer()."""
        return self.call.peer()

    def set_credentials(self, creds):
        """Forward creds to the underlying call's set_credentials."""
        return self.call.set_credentials(creds)


class Channel(_types.Channel):
    """_types.Channel implemented on top of cygrpc.Channel."""

    def __init__(self, target, args, creds=None):
        """Create the underlying cygrpc channel.

        Args:
          target: Passed through to cygrpc.Channel.
          args: An iterable of key/value 2-tuples of channel arguments; the
            primary user agent string argument is appended to them.
          creds: Optional credentials; only passed to cygrpc.Channel when not
            None.
        """
        args = list(args) + [
            (cygrpc.ChannelArgKey.primary_user_agent_string, _USER_AGENT)]
        args = cygrpc.ChannelArgs(
            cygrpc.ChannelArg(key, value) for key, value in args)
        if creds is None:
            self.channel = cygrpc.Channel(target, args)
        else:
            self.channel = cygrpc.Channel(target, args, creds)

    def create_call(self, completion_queue, method, host, deadline=None):
        """Create a Call on this channel bound to completion_queue.

        The deadline is wrapped in cygrpc.Timespec before being handed to the
        underlying channel.
        """
        internal_call = self.channel.create_call(
            None, 0, completion_queue.completion_queue, method, host,
            cygrpc.Timespec(deadline))
        return Call(internal_call)

    def check_connectivity_state(self, try_to_connect):
        """Forward to the underlying channel's check_connectivity_state."""
        return self.channel.check_connectivity_state(try_to_connect)

    def watch_connectivity_state(self, last_observed_state, deadline,
                                 completion_queue, tag):
        """Forward to the underlying channel's watch_connectivity_state.

        The deadline is wrapped in cygrpc.Timespec and the wrapper's underlying
        cygrpc completion queue is passed in place of completion_queue.
        """
        self.channel.watch_connectivity_state(
            last_observed_state, cygrpc.Timespec(deadline),
            completion_queue.completion_queue, tag)

    def target(self):
        """Return the underlying channel's target()."""
        return self.channel.target()


_NO_TAG = object()


class Server(_types.Server):
    """_types.Server implemented on top of cygrpc.Server."""

    def __init__(self, completion_queue, args):
        """Create the underlying cygrpc server.

        Args:
          completion_queue: A CompletionQueue wrapper; its underlying cygrpc
            queue is registered with the server and the wrapper is kept as
            server_queue.
          args: An iterable of key/value 2-tuples of channel arguments.
        """
        args = cygrpc.ChannelArgs(
            cygrpc.ChannelArg(key, value) for key, value in args)
        self.server = cygrpc.Server(args)
        self.server.register_completion_queue(completion_queue.completion_queue)
        self.server_queue = completion_queue

    def add_http2_port(self, addr, creds=None):
        """Add a port at addr, forwarding creds only when it is not None."""
        if creds is None:
            return self.server.add_http2_port(addr)
        else:
            return self.server.add_http2_port(addr, creds)

    def start(self):
        """Start the underlying server."""
        return self.server.start()

    def shutdown(self, tag=None):
        """Shut down the underlying server using the server's own queue."""
        return self.server.shutdown(self.server_queue.completion_queue, tag)

    def request_call(self, completion_queue, tag):
        """Request a call, passing the given queue and the server's own queue."""
        return self.server.request_call(completion_queue.completion_queue,
                                        self.server_queue.completion_queue, tag)

    def cancel_all_calls(self):
        """Forward to the underlying server's cancel_all_calls."""
        return self.server.cancel_all_calls()