diff options
author | Nathaniel Manista <nathaniel@google.com> | 2015-09-05 02:46:08 +0000 |
---|---|---|
committer | Nathaniel Manista <nathaniel@google.com> | 2015-09-05 02:46:43 +0000 |
commit | 41abb052b8ed180ca14b9ff1427a215b8e4dcd60 (patch) | |
tree | bbba3a18aacfec705f16f79144ac38a9ded65ca4 /src/python/grpcio | |
parent | 13db8e517dd390cb1ede0146f5b0182e5f1e4dd5 (diff) |
gRPC protocol objects
Diffstat (limited to 'src/python/grpcio')
-rw-r--r-- | src/python/grpcio/grpc/_adapter/_intermediary_low.py | 10 | ||||
-rw-r--r-- | src/python/grpcio/grpc/_adapter/fore.py | 4 | ||||
-rw-r--r-- | src/python/grpcio/grpc/_adapter/rear.py | 6 | ||||
-rw-r--r-- | src/python/grpcio/grpc/_links/invocation.py | 57 | ||||
-rw-r--r-- | src/python/grpcio/grpc/_links/service.py | 39 | ||||
-rw-r--r-- | src/python/grpcio/grpc/beta/interfaces.py | 58 |
6 files changed, 154 insertions, 20 deletions
diff --git a/src/python/grpcio/grpc/_adapter/_intermediary_low.py b/src/python/grpcio/grpc/_adapter/_intermediary_low.py index 735ad205a4..e2feec6ffb 100644 --- a/src/python/grpcio/grpc/_adapter/_intermediary_low.py +++ b/src/python/grpcio/grpc/_adapter/_intermediary_low.py @@ -59,6 +59,7 @@ from grpc._adapter import _types _IGNORE_ME_TAG = object() Code = _types.StatusCode +WriteFlags = _types.OpWriteFlags class Status(collections.namedtuple('Status', ['code', 'details'])): @@ -125,9 +126,9 @@ class Call(object): ], _TagAdapter(finish_tag, Event.Kind.FINISH)) return err0 if err0 != _types.CallError.OK else err1 if err1 != _types.CallError.OK else err2 if err2 != _types.CallError.OK else _types.CallError.OK - def write(self, message, tag): + def write(self, message, tag, flags): return self._internal.start_batch([ - _types.OpArgs.send_message(message, 0) + _types.OpArgs.send_message(message, flags) ], _TagAdapter(tag, Event.Kind.WRITE_ACCEPTED)) def complete(self, tag): @@ -163,8 +164,11 @@ class Call(object): def cancel(self): return self._internal.cancel() + def peer(self): + return self._internal.peer() + def set_credentials(self, creds): - return self._internal.set_credentials(creds) + return self._internal.set_credentials(creds._internal) class Channel(object): diff --git a/src/python/grpcio/grpc/_adapter/fore.py b/src/python/grpcio/grpc/_adapter/fore.py index daa41e8bde..acdd69c420 100644 --- a/src/python/grpcio/grpc/_adapter/fore.py +++ b/src/python/grpcio/grpc/_adapter/fore.py @@ -56,7 +56,7 @@ class _LowWrite(enum.Enum): def _write(call, rpc_state, payload): serialized_payload = rpc_state.serializer(payload) if rpc_state.write.low is _LowWrite.OPEN: - call.write(serialized_payload, call) + call.write(serialized_payload, call, 0) rpc_state.write.low = _LowWrite.ACTIVE else: rpc_state.write.pending.append(serialized_payload) @@ -164,7 +164,7 @@ class ForeLink(base_interfaces.ForeLink, activated.Activated): if rpc_state.write.pending: serialized_payload = rpc_state.write.pending.pop(0) - call.write(serialized_payload, call) + call.write(serialized_payload, call, 0) elif rpc_state.write.high is _common.HighWrite.CLOSED: _status(call, rpc_state) else: diff --git a/src/python/grpcio/grpc/_adapter/rear.py b/src/python/grpcio/grpc/_adapter/rear.py index fd6f45f7a7..17fa47f746 100644 --- a/src/python/grpcio/grpc/_adapter/rear.py +++ b/src/python/grpcio/grpc/_adapter/rear.py @@ -78,7 +78,7 @@ class _RPCState(object): def _write(operation_id, call, outstanding, write_state, serialized_payload): if write_state.low is _LowWrite.OPEN: - call.write(serialized_payload, operation_id) + call.write(serialized_payload, operation_id, 0) outstanding.add(_low.Event.Kind.WRITE_ACCEPTED) write_state.low = _LowWrite.ACTIVE elif write_state.low is _LowWrite.ACTIVE: @@ -144,7 +144,7 @@ class RearLink(base_interfaces.RearLink, activated.Activated): if event.write_accepted: if rpc_state.common.write.pending: rpc_state.call.write( - rpc_state.common.write.pending.pop(0), operation_id) + rpc_state.common.write.pending.pop(0), operation_id, 0) rpc_state.outstanding.add(_low.Event.Kind.WRITE_ACCEPTED) elif rpc_state.common.write.high is _common.HighWrite.CLOSED: rpc_state.call.complete(operation_id) @@ -263,7 +263,7 @@ class RearLink(base_interfaces.RearLink, activated.Activated): low_state = _LowWrite.OPEN else: serialized_payload = request_serializer(payload) - call.write(serialized_payload, operation_id) + call.write(serialized_payload, operation_id, 0) outstanding.add(_low.Event.Kind.WRITE_ACCEPTED) low_state = _LowWrite.ACTIVE diff --git a/src/python/grpcio/grpc/_links/invocation.py b/src/python/grpcio/grpc/_links/invocation.py index fecb550ae0..67ef86a176 100644 --- a/src/python/grpcio/grpc/_links/invocation.py +++ b/src/python/grpcio/grpc/_links/invocation.py @@ -37,6 +37,7 @@ import time from grpc._adapter import _intermediary_low from grpc._links import _constants +from grpc.beta import interfaces as beta_interfaces from grpc.framework.foundation import activated from grpc.framework.foundation import logging_pool from grpc.framework.foundation import relay @@ -73,11 +74,28 @@ class _LowWrite(enum.Enum): CLOSED = 'CLOSED' +class _Context(beta_interfaces.GRPCInvocationContext): + + def __init__(self): + self._lock = threading.Lock() + self._disable_next_compression = False + + def disable_next_request_compression(self): + with self._lock: + self._disable_next_compression = True + + def next_compression_disabled(self): + with self._lock: + disabled = self._disable_next_compression + self._disable_next_compression = False + return disabled + + class _RPCState(object): def __init__( self, call, request_serializer, response_deserializer, sequence_number, - read, allowance, high_write, low_write, due): + read, allowance, high_write, low_write, due, context): self.call = call self.request_serializer = request_serializer self.response_deserializer = response_deserializer @@ -87,6 +105,7 @@ class _RPCState(object): self.high_write = high_write self.low_write = low_write self.due = due + self.context = context def _no_longer_due(kind, rpc_state, key, rpc_states): @@ -209,7 +228,7 @@ class _Kernel(object): def _invoke( self, operation_id, group, method, initial_metadata, payload, termination, - timeout, allowance): + timeout, allowance, options): """Invoke an RPC. Args: @@ -224,6 +243,7 @@ class _Kernel(object): timeout: A duration of time in seconds to allow for the RPC. allowance: The number of payloads (beyond the free first one) that the local ticket exchange mate has granted permission to be read. + options: A beta_interfaces.GRPCCallOptions value or None. """ if termination is links.Ticket.Termination.COMPLETION: high_write = _HighWrite.CLOSED @@ -241,6 +261,8 @@ class _Kernel(object): call = _intermediary_low.Call( self._channel, self._completion_queue, '/%s/%s' % (group, method), self._host, time.time() + timeout) + if options is not None and options.credentials is not None: + call.set_credentials(options.credentials._intermediary_low_credentials) if transformed_initial_metadata is not None: for metadata_key, metadata_value in transformed_initial_metadata: call.add_metadata(metadata_key, metadata_value) @@ -254,17 +276,33 @@ class _Kernel(object): low_write = _LowWrite.OPEN due = set((_METADATA, _FINISH,)) else: - call.write(request_serializer(payload), operation_id) + if options is not None and options.disable_compression: + flags = _intermediary_low.WriteFlags.WRITE_NO_COMPRESS + else: + flags = 0 + call.write(request_serializer(payload), operation_id, flags) low_write = _LowWrite.ACTIVE due = set((_WRITE, _METADATA, _FINISH,)) + context = _Context() self._rpc_states[operation_id] = _RPCState( - call, request_serializer, response_deserializer, 0, + call, request_serializer, response_deserializer, 1, _Read.AWAITING_METADATA, 1 if allowance is None else (1 + allowance), - high_write, low_write, due) + high_write, low_write, due, context) + protocol = links.Protocol(links.Protocol.Kind.INVOCATION_CONTEXT, context) + ticket = links.Ticket( + operation_id, 0, None, None, None, None, None, None, None, None, None, + None, None, protocol) + self._relay.add_value(ticket) def _advance(self, operation_id, rpc_state, payload, termination, allowance): if payload is not None: - rpc_state.call.write(rpc_state.request_serializer(payload), operation_id) + disable_compression = rpc_state.context.next_compression_disabled() + if disable_compression: + flags = _intermediary_low.WriteFlags.WRITE_NO_COMPRESS + else: + flags = 0 + rpc_state.call.write( + rpc_state.request_serializer(payload), operation_id, flags) rpc_state.low_write = _LowWrite.ACTIVE rpc_state.due.add(_WRITE) @@ -292,10 +330,15 @@ class _Kernel(object): if self._completion_queue is None: logging.error('Received invocation ticket %s after stop!', ticket) else: + if (ticket.protocol is not None and + ticket.protocol.kind is links.Protocol.Kind.CALL_OPTION): + grpc_call_options = ticket.protocol.value + else: + grpc_call_options = None self._invoke( ticket.operation_id, ticket.group, ticket.method, ticket.initial_metadata, ticket.payload, ticket.termination, - ticket.timeout, ticket.allowance) + ticket.timeout, ticket.allowance, grpc_call_options) else: rpc_state = self._rpc_states.get(ticket.operation_id) if rpc_state is not None: diff --git a/src/python/grpcio/grpc/_links/service.py b/src/python/grpcio/grpc/_links/service.py index 07772c7de3..f56df84007 100644 --- a/src/python/grpcio/grpc/_links/service.py +++ b/src/python/grpcio/grpc/_links/service.py @@ -37,6 +37,7 @@ import time from grpc._adapter import _intermediary_low from grpc._links import _constants +from grpc.beta import interfaces as beta_interfaces from grpc.framework.foundation import logging_pool from grpc.framework.foundation import relay from grpc.framework.interfaces.links import links @@ -89,12 +90,34 @@ class _LowWrite(enum.Enum): CLOSED = 'CLOSED' +class _Context(beta_interfaces.GRPCServicerContext): + + def __init__(self, call): + self._lock = threading.Lock() + self._call = call + self._disable_next_compression = False + + def peer(self): + with self._lock: + return self._call.peer() + + def disable_next_response_compression(self): + with self._lock: + self._disable_next_compression = True + + def next_compression_disabled(self): + with self._lock: + disabled = self._disable_next_compression + self._disable_next_compression = False + return disabled + + class _RPCState(object): def __init__( self, request_deserializer, response_serializer, sequence_number, read, early_read, allowance, high_write, low_write, premetadataed, - terminal_metadata, code, message, due): + terminal_metadata, code, message, due, context): self.request_deserializer = request_deserializer self.response_serializer = response_serializer self.sequence_number = sequence_number @@ -110,6 +133,7 @@ class _RPCState(object): self.code = code self.message = message self.due = due + self.context = context def _no_longer_due(kind, rpc_state, key, rpc_states): @@ -163,12 +187,12 @@ class _Kernel(object): (group, method), _IDENTITY) call.read(call) + context = _Context(call) self._rpc_states[call] = _RPCState( request_deserializer, response_serializer, 1, _Read.READING, None, 1, _HighWrite.OPEN, _LowWrite.OPEN, False, None, None, None, - set((_READ, _FINISH,))) - protocol = links.Protocol( - links.Protocol.Kind.SERVICER_CONTEXT, 'TODO: Service Context Object!') + set((_READ, _FINISH,)), context) + protocol = links.Protocol(links.Protocol.Kind.SERVICER_CONTEXT, context) ticket = links.Ticket( call, 0, group, method, links.Ticket.Subscription.FULL, service_acceptance.deadline - time.time(), None, event.metadata, None, @@ -313,7 +337,12 @@ class _Kernel(object): self._relay.add_value(early_read_ticket) if ticket.payload is not None: - call.write(rpc_state.response_serializer(ticket.payload), call) + disable_compression = rpc_state.context.next_compression_disabled() + if disable_compression: + flags = _intermediary_low.WriteFlags.WRITE_NO_COMPRESS + else: + flags = 0 + call.write(rpc_state.response_serializer(ticket.payload), call, flags) rpc_state.due.add(_WRITE) rpc_state.low_write = _LowWrite.ACTIVE diff --git a/src/python/grpcio/grpc/beta/interfaces.py b/src/python/grpcio/grpc/beta/interfaces.py index 25e6a9c66b..79f2620dd4 100644 --- a/src/python/grpcio/grpc/beta/interfaces.py +++ b/src/python/grpcio/grpc/beta/interfaces.py @@ -29,6 +29,7 @@ """Constants and interfaces of the Beta API of gRPC Python.""" +import abc import enum @@ -52,3 +53,60 @@ class StatusCode(enum.Enum): UNAVAILABLE = 14 DATA_LOSS = 15 UNAUTHENTICATED = 16 + + +class GRPCCallOptions(object): + """A value encapsulating gRPC-specific options passed on RPC invocation. + + This class and its instances have no supported interface - it exists to + define the type of its instances and its instances exist to be passed to + other functions. + """ + + def __init__(self, disable_compression, subcall_of, credentials): + self.disable_compression = disable_compression + self.subcall_of = subcall_of + self.credentials = credentials + + +def grpc_call_options(disable_compression=False, credentials=None): + """Creates a GRPCCallOptions value to be passed at RPC invocation. + + All parameters are optional and should always be passed by keyword. + + Args: + disable_compression: A boolean indicating whether or not compression should + be disabled for the request object of the RPC. Only valid for + request-unary RPCs. + credentials: A ClientCredentials object to use for the invoked RPC. + """ + return GRPCCallOptions(disable_compression, None, credentials) + + +class GRPCServicerContext(object): + """Exposes gRPC-specific options and behaviors to code servicing RPCs.""" + __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def peer(self): + """Identifies the peer that invoked the RPC being serviced. + + Returns: + A string identifying the peer that invoked the RPC being serviced. + """ + raise NotImplementedError() + + @abc.abstractmethod + def disable_next_response_compression(self): + """Disables compression of the next response passed by the application.""" + raise NotImplementedError() + + +class GRPCInvocationContext(object): + """Exposes gRPC-specific options and behaviors to code invoking RPCs.""" + __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def disable_next_request_compression(self): + """Disables compression of the next request passed by the application.""" + raise NotImplementedError() |