# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Translates gRPC's client-side API into gRPC's client-side Beta API."""

import grpc
from grpc import _common
from grpc.beta import interfaces
from grpc.framework.common import cardinality
from grpc.framework.foundation import future
from grpc.framework.interfaces.face import face

# pylint: disable=too-many-arguments,too-many-locals,unused-argument

# Maps a grpc.StatusCode to the (face.Abortion.Kind, face error class) pair
# used to report it. Codes that are not listed here fall back to
# face.Abortion.Kind.LOCAL_FAILURE (in _abortion) and face.AbortionError
# (in _abortion_error).
_STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS = {
    grpc.StatusCode.CANCELLED: (face.Abortion.Kind.CANCELLED,
                                face.CancellationError),
    grpc.StatusCode.UNKNOWN: (face.Abortion.Kind.REMOTE_FAILURE,
                              face.RemoteError),
    grpc.StatusCode.DEADLINE_EXCEEDED: (face.Abortion.Kind.EXPIRED,
                                        face.ExpirationError),
    grpc.StatusCode.UNIMPLEMENTED: (face.Abortion.Kind.LOCAL_FAILURE,
                                    face.LocalError),
}


def _effective_metadata(metadata, metadata_transformer):
    """Returns the metadata to send with an RPC.

    Args:
      metadata: The caller-supplied metadata, or None.
      metadata_transformer: A callable applied to the metadata, or None.

    Returns:
      An empty tuple in place of a None metadata, passed through
        metadata_transformer when one was given.
    """
    non_none_metadata = () if metadata is None else metadata
    if metadata_transformer is None:
        return non_none_metadata
    else:
        return metadata_transformer(non_none_metadata)


def _credentials(grpc_call_options):
    """Returns the credentials attribute of the options, or None if absent."""
    return None if grpc_call_options is None else grpc_call_options.credentials


def _abortion(rpc_error_call):
    """Builds a face.Abortion describing a terminated gRPC call.

    Args:
      rpc_error_call: An object exposing code(), initial_metadata(),
        trailing_metadata() and details().

    Returns:
      A face.Abortion whose kind is looked up from the call's status code,
        defaulting to face.Abortion.Kind.LOCAL_FAILURE for unmapped codes.
    """
    code = rpc_error_call.code()
    pair = _STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS.get(code)
    error_kind = face.Abortion.Kind.LOCAL_FAILURE if pair is None else pair[0]
    return face.Abortion(error_kind, rpc_error_call.initial_metadata(),
                         rpc_error_call.trailing_metadata(), code,
                         rpc_error_call.details())


def _abortion_error(rpc_error_call):
    """Builds the face exception corresponding to a failed gRPC call.

    Args:
      rpc_error_call: An object exposing code(), initial_metadata(),
        trailing_metadata() and details().

    Returns:
      An instance of the error class mapped to the call's status code, or of
        face.AbortionError when the code is unmapped. The exception is
        returned, not raised.
    """
    code = rpc_error_call.code()
    pair = _STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS.get(code)
    exception_class = face.AbortionError if pair is None else pair[1]
    return exception_class(rpc_error_call.initial_metadata(),
                           rpc_error_call.trailing_metadata(), code,
                           rpc_error_call.details())


class _InvocationProtocolContext(interfaces.GRPCInvocationContext):
    """Invocation-side protocol context; currently a no-op placeholder."""

    def disable_next_request_compression(self):
        pass  # TODO(https://github.com/grpc/grpc/issues/4078): design, implement.


class _Rendezvous(future.Future, face.Call):
    """Adapts gRPC response objects to the Beta future/iterator/call API.

    Future-style methods delegate to the wrapped response future,
    iteration delegates to the wrapped response iterator, and call-style
    methods delegate to the wrapped call. Callers in this module pass None
    for whichever of the future or iterator does not apply, so only the
    methods backed by a non-None object are usable on a given instance.
    """

    def __init__(self, response_future, response_iterator, call):
        """Stores the wrapped objects.

        Args:
          response_future: The object backing the future methods, or None.
          response_iterator: The object backing iteration, or None.
          call: The object backing the call methods.
        """
        self._future = response_future
        self._iterator = response_iterator
        self._call = call

    def cancel(self):
        """Cancels via the wrapped call and returns its result."""
        return self._call.cancel()

    def cancelled(self):
        """Returns the wrapped future's cancelled() value."""
        return self._future.cancelled()

    def running(self):
        """Returns the wrapped future's running() value."""
        return self._future.running()

    def done(self):
        """Returns the wrapped future's done() value."""
        return self._future.done()

    def result(self, timeout=None):
        """Returns the wrapped future's result, translating gRPC errors.

        Args:
          timeout: Forwarded to the wrapped future's result().

        Raises:
          face.AbortionError (or a mapped subclass, see _abortion_error):
            If the wrapped future raises grpc.RpcError.
          future.TimeoutError: If it raises grpc.FutureTimeoutError.
          future.CancelledError: If it raises grpc.FutureCancelledError.
        """
        try:
            return self._future.result(timeout=timeout)
        except grpc.RpcError as rpc_error_call:
            raise _abortion_error(rpc_error_call)
        except grpc.FutureTimeoutError:
            raise future.TimeoutError()
        except grpc.FutureCancelledError:
            raise future.CancelledError()

    def exception(self, timeout=None):
        """Returns the translated exception of the wrapped future, if any.

        Args:
          timeout: Forwarded to the wrapped future's exception().

        Returns:
          None if the wrapped future reports no exception, otherwise the
            face error built by _abortion_error from the reported one.

        Raises:
          future.TimeoutError: If it raises grpc.FutureTimeoutError.
          future.CancelledError: If it raises grpc.FutureCancelledError.
        """
        try:
            rpc_error_call = self._future.exception(timeout=timeout)
            if rpc_error_call is None:
                return None
            else:
                return _abortion_error(rpc_error_call)
        except grpc.FutureTimeoutError:
            raise future.TimeoutError()
        except grpc.FutureCancelledError:
            raise future.CancelledError()

    def traceback(self, timeout=None):
        """Returns the wrapped future's traceback.

        Args:
          timeout: Forwarded to the wrapped future's traceback().

        Raises:
          future.TimeoutError: If it raises grpc.FutureTimeoutError.
          future.CancelledError: If it raises grpc.FutureCancelledError.
        """
        try:
            return self._future.traceback(timeout=timeout)
        except grpc.FutureTimeoutError:
            raise future.TimeoutError()
        except grpc.FutureCancelledError:
            raise future.CancelledError()

    def add_done_callback(self, fn):
        """Registers fn on the wrapped future; fn is called with this object.

        The argument the wrapped future passes to its callback is discarded
        so that fn receives this adapter instead.
        """
        self._future.add_done_callback(lambda ignored_callback: fn(self))

    def __iter__(self):
        return self

    def _next(self):
        """Returns the next response, translating grpc.RpcError."""
        try:
            return next(self._iterator)
        except grpc.RpcError as rpc_error_call:
            raise _abortion_error(rpc_error_call)

    def __next__(self):
        return self._next()

    def next(self):
        return self._next()

    def is_active(self):
        """Returns the wrapped call's is_active() value."""
        return self._call.is_active()

    def time_remaining(self):
        """Returns the wrapped call's time_remaining() value."""
        return self._call.time_remaining()

    def add_abortion_callback(self, abortion_callback):
        """Arranges for abortion_callback to be told of a non-OK termination.

        Args:
          abortion_callback: A callable taking a face.Abortion.

        Returns:
          None if the wrapped call accepted the callback registration;
            otherwise the result of running the check immediately (which
            is also None).
        """

        def done_callback():
            if self.code() is not grpc.StatusCode.OK:
                abortion_callback(_abortion(self._call))

        registered = self._call.add_callback(done_callback)
        # A falsy return means the callback was not registered, so run the
        # check right away instead of never running it.
        return None if registered else done_callback()

    def protocol_context(self):
        """Returns a fresh _InvocationProtocolContext."""
        return _InvocationProtocolContext()

    def initial_metadata(self):
        """Returns the wrapped call's initial metadata."""
        return self._call.initial_metadata()

    def terminal_metadata(self):
        """Returns the wrapped call's trailing metadata.

        The Beta API names this "terminal" metadata, but the wrapped gRPC
        call exposes it as trailing_metadata() -- the same accessor that
        _abortion reads from this very call object.
        """
        return self._call.trailing_metadata()

    def code(self):
        """Returns the wrapped call's status code."""
        return self._call.code()

    def details(self):
        """Returns the wrapped call's details."""
        return self._call.details()


def _blocking_unary_unary(channel, group, method, timeout, with_call,
                          protocol_options, metadata, metadata_transformer,
                          request, request_serializer, response_deserializer):
    """Invokes a unary-unary RPC through channel and returns its response.

    Args:
      channel: The channel providing unary_unary().
      group: Group name given to _common.fully_qualified_method.
      method: Method name given to _common.fully_qualified_method.
      timeout: Forwarded as the invocation timeout.
      with_call: If truthy, also return a _Rendezvous wrapping the call.
      protocol_options: Options whose credentials attribute is forwarded,
        or None.
      metadata: Caller-supplied metadata, or None.
      metadata_transformer: Callable applied to the metadata, or None.
      request: The request value.
      request_serializer: Forwarded to channel.unary_unary().
      response_deserializer: Forwarded to channel.unary_unary().

    Returns:
      The response, or a (response, _Rendezvous) pair when with_call is
        truthy.

    Raises:
      face.AbortionError (or a mapped subclass): If grpc.RpcError is raised.
    """
    try:
        multi_callable = channel.unary_unary(
            _common.fully_qualified_method(group, method),
            request_serializer=request_serializer,
            response_deserializer=response_deserializer)
        effective_metadata = _effective_metadata(metadata,
                                                 metadata_transformer)
        if with_call:
            response, call = multi_callable.with_call(
                request,
                timeout=timeout,
                metadata=effective_metadata,
                credentials=_credentials(protocol_options))
            return response, _Rendezvous(None, None, call)
        else:
            return multi_callable(
                request,
                timeout=timeout,
                metadata=effective_metadata,
                credentials=_credentials(protocol_options))
    except grpc.RpcError as rpc_error_call:
        raise _abortion_error(rpc_error_call)


def _future_unary_unary(channel, group, method, timeout, protocol_options,
                        metadata, metadata_transformer, request,
                        request_serializer, response_deserializer):
    """Starts a unary-unary RPC and returns a _Rendezvous for its result.

    The object returned by the multi-callable's future() is used both as
    the wrapped future and as the wrapped call. Arguments have the same
    meaning as in _blocking_unary_unary.
    """
    multi_callable = channel.unary_unary(
        _common.fully_qualified_method(group, method),
        request_serializer=request_serializer,
        response_deserializer=response_deserializer)
    effective_metadata = _effective_metadata(metadata, metadata_transformer)
    response_future = multi_callable.future(
        request,
        timeout=timeout,
        metadata=effective_metadata,
        credentials=_credentials(protocol_options))
    return _Rendezvous(response_future, None, response_future)


def _unary_stream(channel, group, method, timeout, protocol_options, metadata,
                  metadata_transformer, request, request_serializer,
                  response_deserializer):
    """Starts a unary-stream RPC and returns a _Rendezvous over responses.

    The object returned by invoking the multi-callable is used both as the
    wrapped iterator and as the wrapped call. Arguments have the same
    meaning as in _blocking_unary_unary.
    """
    multi_callable = channel.unary_stream(
        _common.fully_qualified_method(group, method),
        request_serializer=request_serializer,
        response_deserializer=response_deserializer)
    effective_metadata = _effective_metadata(metadata, metadata_transformer)
    response_iterator = multi_callable(
        request,
        timeout=timeout,
        metadata=effective_metadata,
        credentials=_credentials(protocol_options))
    return _Rendezvous(None, response_iterator, response_iterator)


def _blocking_stream_unary(channel, group, method, timeout, with_call,
                           protocol_options, metadata, metadata_transformer,
                           request_iterator, request_serializer,
                           response_deserializer):
    """Invokes a stream-unary RPC through channel and returns its response.

    Same contract as _blocking_unary_unary, except that request_iterator
    (the request values) is passed to channel.stream_unary()'s
    multi-callable instead of a single request.

    Returns:
      The response, or a (response, _Rendezvous) pair when with_call is
        truthy.

    Raises:
      face.AbortionError (or a mapped subclass): If grpc.RpcError is raised.
    """
    try:
        multi_callable = channel.stream_unary(
            _common.fully_qualified_method(group, method),
            request_serializer=request_serializer,
            response_deserializer=response_deserializer)
        effective_metadata = _effective_metadata(metadata,
                                                 metadata_transformer)
        if with_call:
            response, call = multi_callable.with_call(
                request_iterator,
                timeout=timeout,
                metadata=effective_metadata,
                credentials=_credentials(protocol_options))
            return response, _Rendezvous(None, None, call)
        else:
            return multi_callable(
                request_iterator,
                timeout=timeout,
                metadata=effective_metadata,
                credentials=_credentials(protocol_options))
    except grpc.RpcError as rpc_error_call:
        raise _abortion_error(rpc_error_call)


def _future_stream_unary(channel, group, method, timeout, protocol_options,
                         metadata, metadata_transformer, request_iterator,
                         request_serializer, response_deserializer):
    """Starts a stream-unary RPC and returns a _Rendezvous for its result.

    The object returned by the multi-callable's future() is used both as
    the wrapped future and as the wrapped call.
    """
    multi_callable = channel.stream_unary(
        _common.fully_qualified_method(group, method),
        request_serializer=request_serializer,
        response_deserializer=response_deserializer)
    effective_metadata = _effective_metadata(metadata, metadata_transformer)
    response_future = multi_callable.future(
        request_iterator,
        timeout=timeout,
        metadata=effective_metadata,
        credentials=_credentials(protocol_options))
    return _Rendezvous(response_future, None, response_future)


def _stream_stream(channel, group, method, timeout, protocol_options, metadata,
                   metadata_transformer, request_iterator, request_serializer,
                   response_deserializer):
    """Starts a stream-stream RPC and returns a _Rendezvous over responses.

    The object returned by invoking the multi-callable is used both as the
    wrapped iterator and as the wrapped call.
    """
    multi_callable = channel.stream_stream(
        _common.fully_qualified_method(group, method),
        request_serializer=request_serializer,
        response_deserializer=response_deserializer)
    effective_metadata = _effective_metadata(metadata, metadata_transformer)
    response_iterator = multi_callable(
        request_iterator,
        timeout=timeout,
        metadata=effective_metadata,
        credentials=_credentials(protocol_options))
    return _Rendezvous(None, response_iterator, response_iterator)


class _UnaryUnaryMultiCallable(face.UnaryUnaryMultiCallable):
    """Beta multi-callable for one unary-unary method bound to a channel."""

    def __init__(self, channel, group, method, metadata_transformer,
                 request_serializer, response_deserializer):
        self._channel = channel
        self._group = group
        self._method = method
        self._metadata_transformer = metadata_transformer
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __call__(self,
                 request,
                 timeout,
                 metadata=None,
                 with_call=False,
                 protocol_options=None):
        """Invokes the RPC via _blocking_unary_unary."""
        return _blocking_unary_unary(
            self._channel, self._group, self._method, timeout, with_call,
            protocol_options, metadata, self._metadata_transformer, request,
            self._request_serializer, self._response_deserializer)

    def future(self, request, timeout, metadata=None, protocol_options=None):
        """Invokes the RPC via _future_unary_unary."""
        return _future_unary_unary(
            self._channel, self._group, self._method, timeout,
            protocol_options, metadata, self._metadata_transformer, request,
            self._request_serializer, self._response_deserializer)

    def event(self,
              request,
              receiver,
              abortion_callback,
              timeout,
              metadata=None,
              protocol_options=None):
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError()


class _UnaryStreamMultiCallable(face.UnaryStreamMultiCallable):
    """Beta multi-callable for one unary-stream method bound to a channel."""

    def __init__(self, channel, group, method, metadata_transformer,
                 request_serializer, response_deserializer):
        self._channel = channel
        self._group = group
        self._method = method
        self._metadata_transformer = metadata_transformer
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __call__(self, request, timeout, metadata=None, protocol_options=None):
        """Invokes the RPC via _unary_stream."""
        return _unary_stream(
            self._channel, self._group, self._method, timeout,
            protocol_options, metadata, self._metadata_transformer, request,
            self._request_serializer, self._response_deserializer)

    def event(self,
              request,
              receiver,
              abortion_callback,
              timeout,
              metadata=None,
              protocol_options=None):
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError()


class _StreamUnaryMultiCallable(face.StreamUnaryMultiCallable):
    """Beta multi-callable for one stream-unary method bound to a channel."""

    def __init__(self, channel, group, method, metadata_transformer,
                 request_serializer, response_deserializer):
        self._channel = channel
        self._group = group
        self._method = method
        self._metadata_transformer = metadata_transformer
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __call__(self,
                 request_iterator,
                 timeout,
                 metadata=None,
                 with_call=False,
                 protocol_options=None):
        """Invokes the RPC via _blocking_stream_unary."""
        return _blocking_stream_unary(
            self._channel, self._group, self._method, timeout, with_call,
            protocol_options, metadata, self._metadata_transformer,
            request_iterator, self._request_serializer,
            self._response_deserializer)

    def future(self,
               request_iterator,
               timeout,
               metadata=None,
               protocol_options=None):
        """Invokes the RPC via _future_stream_unary."""
        return _future_stream_unary(
            self._channel, self._group, self._method, timeout,
            protocol_options, metadata, self._metadata_transformer,
            request_iterator, self._request_serializer,
            self._response_deserializer)

    def event(self,
              receiver,
              abortion_callback,
              timeout,
              metadata=None,
              protocol_options=None):
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError()


class _StreamStreamMultiCallable(face.StreamStreamMultiCallable):
    """Beta multi-callable for one stream-stream method bound to a channel."""

    def __init__(self, channel, group, method, metadata_transformer,
                 request_serializer, response_deserializer):
        self._channel = channel
        self._group = group
        self._method = method
        self._metadata_transformer = metadata_transformer
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __call__(self,
                 request_iterator,
                 timeout,
                 metadata=None,
                 protocol_options=None):
        """Invokes the RPC via _stream_stream."""
        return _stream_stream(
            self._channel, self._group, self._method, timeout,
            protocol_options, metadata, self._metadata_transformer,
            request_iterator, self._request_serializer,
            self._response_deserializer)

    def event(self,
              receiver,
              abortion_callback,
              timeout,
              metadata=None,
              protocol_options=None):
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError()


class _GenericStub(face.GenericStub):
    """Beta generic stub that invokes any (group, method) over a channel.

    Serializers and deserializers are looked up per call in dictionaries
    keyed by (group, method) tuples; a missing entry yields None, which is
    passed on unchanged.
    """

    def __init__(self, channel, metadata_transformer, request_serializers,
                 response_deserializers):
        """Stores the channel and lookup tables.

        Args:
          channel: The channel used for every invocation.
          metadata_transformer: Callable applied to metadata, or None.
          request_serializers: Dictionary keyed by (group, method), or a
            falsy value to use an empty dictionary.
          response_deserializers: Dictionary keyed by (group, method), or a
            falsy value to use an empty dictionary.
        """
        self._channel = channel
        self._metadata_transformer = metadata_transformer
        self._request_serializers = request_serializers or {}
        self._response_deserializers = response_deserializers or {}

    def blocking_unary_unary(self,
                             group,
                             method,
                             request,
                             timeout,
                             metadata=None,
                             with_call=None,
                             protocol_options=None):
        """Invokes a unary-unary RPC via _blocking_unary_unary."""
        request_serializer = self._request_serializers.get((
            group,
            method,
        ))
        response_deserializer = self._response_deserializers.get((
            group,
            method,
        ))
        return _blocking_unary_unary(self._channel, group, method, timeout,
                                     with_call, protocol_options, metadata,
                                     self._metadata_transformer, request,
                                     request_serializer, response_deserializer)

    def future_unary_unary(self,
                           group,
                           method,
                           request,
                           timeout,
                           metadata=None,
                           protocol_options=None):
        """Invokes a unary-unary RPC via _future_unary_unary."""
        request_serializer = self._request_serializers.get((
            group,
            method,
        ))
        response_deserializer = self._response_deserializers.get((
            group,
            method,
        ))
        return _future_unary_unary(self._channel, group, method, timeout,
                                   protocol_options, metadata,
                                   self._metadata_transformer, request,
                                   request_serializer, response_deserializer)

    def inline_unary_stream(self,
                            group,
                            method,
                            request,
                            timeout,
                            metadata=None,
                            protocol_options=None):
        """Invokes a unary-stream RPC via _unary_stream."""
        request_serializer = self._request_serializers.get((
            group,
            method,
        ))
        response_deserializer = self._response_deserializers.get((
            group,
            method,
        ))
        return _unary_stream(self._channel, group, method, timeout,
                             protocol_options, metadata,
                             self._metadata_transformer, request,
                             request_serializer, response_deserializer)

    def blocking_stream_unary(self,
                              group,
                              method,
                              request_iterator,
                              timeout,
                              metadata=None,
                              with_call=None,
                              protocol_options=None):
        """Invokes a stream-unary RPC via _blocking_stream_unary."""
        request_serializer = self._request_serializers.get((
            group,
            method,
        ))
        response_deserializer = self._response_deserializers.get((
            group,
            method,
        ))
        return _blocking_stream_unary(
            self._channel, group, method, timeout, with_call,
            protocol_options, metadata, self._metadata_transformer,
            request_iterator, request_serializer, response_deserializer)

    def future_stream_unary(self,
                            group,
                            method,
                            request_iterator,
                            timeout,
                            metadata=None,
                            protocol_options=None):
        """Invokes a stream-unary RPC via _future_stream_unary."""
        request_serializer = self._request_serializers.get((
            group,
            method,
        ))
        response_deserializer = self._response_deserializers.get((
            group,
            method,
        ))
        return _future_stream_unary(
            self._channel, group, method, timeout, protocol_options, metadata,
            self._metadata_transformer, request_iterator, request_serializer,
            response_deserializer)

    def inline_stream_stream(self,
                             group,
                             method,
                             request_iterator,
                             timeout,
                             metadata=None,
                             protocol_options=None):
        """Invokes a stream-stream RPC via _stream_stream."""
        request_serializer = self._request_serializers.get((
            group,
            method,
        ))
        response_deserializer = self._response_deserializers.get((
            group,
            method,
        ))
        return _stream_stream(self._channel, group, method, timeout,
                              protocol_options, metadata,
                              self._metadata_transformer, request_iterator,
                              request_serializer, response_deserializer)

    def event_unary_unary(self,
                          group,
                          method,
                          request,
                          receiver,
                          abortion_callback,
                          timeout,
                          metadata=None,
                          protocol_options=None):
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError()

    def event_unary_stream(self,
                           group,
                           method,
                           request,
                           receiver,
                           abortion_callback,
                           timeout,
                           metadata=None,
                           protocol_options=None):
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError()

    def event_stream_unary(self,
                           group,
                           method,
                           receiver,
                           abortion_callback,
                           timeout,
                           metadata=None,
                           protocol_options=None):
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError()

    def event_stream_stream(self,
                            group,
                            method,
                            receiver,
                            abortion_callback,
                            timeout,
                            metadata=None,
                            protocol_options=None):
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError()

    def unary_unary(self, group, method):
        """Returns a _UnaryUnaryMultiCallable for (group, method)."""
        request_serializer = self._request_serializers.get((
            group,
            method,
        ))
        response_deserializer = self._response_deserializers.get((
            group,
            method,
        ))
        return _UnaryUnaryMultiCallable(
            self._channel, group, method, self._metadata_transformer,
            request_serializer, response_deserializer)

    def unary_stream(self, group, method):
        """Returns a _UnaryStreamMultiCallable for (group, method)."""
        request_serializer = self._request_serializers.get((
            group,
            method,
        ))
        response_deserializer = self._response_deserializers.get((
            group,
            method,
        ))
        return _UnaryStreamMultiCallable(
            self._channel, group, method, self._metadata_transformer,
            request_serializer, response_deserializer)

    def stream_unary(self, group, method):
        """Returns a _StreamUnaryMultiCallable for (group, method)."""
        request_serializer = self._request_serializers.get((
            group,
            method,
        ))
        response_deserializer = self._response_deserializers.get((
            group,
            method,
        ))
        return _StreamUnaryMultiCallable(
            self._channel, group, method, self._metadata_transformer,
            request_serializer, response_deserializer)

    def stream_stream(self, group, method):
        """Returns a _StreamStreamMultiCallable for (group, method)."""
        request_serializer = self._request_serializers.get((
            group,
            method,
        ))
        response_deserializer = self._response_deserializers.get((
            group,
            method,
        ))
        return _StreamStreamMultiCallable(
            self._channel, group, method, self._metadata_transformer,
            request_serializer, response_deserializer)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returning False lets any in-flight exception propagate.
        return False


class _DynamicStub(face.DynamicStub):
    """Beta dynamic stub exposing a group's methods as attributes."""

    def __init__(self, backing_generic_stub, group, cardinalities):
        """Stores the backing stub, the group and the cardinality table.

        Args:
          backing_generic_stub: The stub that creates the multi-callables.
          group: The group passed to the backing stub for every method.
          cardinalities: Dictionary from method name to
            cardinality.Cardinality value.
        """
        self._generic_stub = backing_generic_stub
        self._group = group
        self._cardinalities = cardinalities

    def __getattr__(self, attr):
        """Returns the multi-callable for the method named attr.

        Raises:
          AttributeError: If attr has no known cardinality.
        """
        method_cardinality = self._cardinalities.get(attr)
        if method_cardinality is cardinality.Cardinality.UNARY_UNARY:
            return self._generic_stub.unary_unary(self._group, attr)
        elif method_cardinality is cardinality.Cardinality.UNARY_STREAM:
            return self._generic_stub.unary_stream(self._group, attr)
        elif method_cardinality is cardinality.Cardinality.STREAM_UNARY:
            return self._generic_stub.stream_unary(self._group, attr)
        elif method_cardinality is cardinality.Cardinality.STREAM_STREAM:
            return self._generic_stub.stream_stream(self._group, attr)
        else:
            raise AttributeError(
                '_DynamicStub object has no attribute "%s"!' % attr)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returning False lets any in-flight exception propagate.
        return False


def generic_stub(channel, host, metadata_transformer, request_serializers,
                 response_deserializers):
    """Creates a _GenericStub over channel.

    The host argument is accepted but not used by this implementation.
    """
    return _GenericStub(channel, metadata_transformer, request_serializers,
                        response_deserializers)


def dynamic_stub(channel, service, cardinalities, host, metadata_transformer,
                 request_serializers, response_deserializers):
    """Creates a _DynamicStub for service backed by a new _GenericStub.

    The service argument is used as the stub's group. The host argument is
    accepted but not used by this implementation.
    """
    return _DynamicStub(
        _GenericStub(channel, metadata_transformer, request_serializers,
                     response_deserializers), service, cardinalities)