aboutsummaryrefslogtreecommitdiffhomepage
path: root/src/python/grpcio
diff options
context:
space:
mode:
authorGravatar Nathaniel Manista <nathaniel@google.com>2015-09-05 02:46:08 +0000
committerGravatar Nathaniel Manista <nathaniel@google.com>2015-09-05 02:46:43 +0000
commit41abb052b8ed180ca14b9ff1427a215b8e4dcd60 (patch)
treebbba3a18aacfec705f16f79144ac38a9ded65ca4 /src/python/grpcio
parent13db8e517dd390cb1ede0146f5b0182e5f1e4dd5 (diff)
gRPC protocol objects
Diffstat (limited to 'src/python/grpcio')
-rw-r--r--src/python/grpcio/grpc/_adapter/_intermediary_low.py10
-rw-r--r--src/python/grpcio/grpc/_adapter/fore.py4
-rw-r--r--src/python/grpcio/grpc/_adapter/rear.py6
-rw-r--r--src/python/grpcio/grpc/_links/invocation.py57
-rw-r--r--src/python/grpcio/grpc/_links/service.py39
-rw-r--r--src/python/grpcio/grpc/beta/interfaces.py58
6 files changed, 154 insertions, 20 deletions
diff --git a/src/python/grpcio/grpc/_adapter/_intermediary_low.py b/src/python/grpcio/grpc/_adapter/_intermediary_low.py
index 735ad205a4..e2feec6ffb 100644
--- a/src/python/grpcio/grpc/_adapter/_intermediary_low.py
+++ b/src/python/grpcio/grpc/_adapter/_intermediary_low.py
@@ -59,6 +59,7 @@ from grpc._adapter import _types
_IGNORE_ME_TAG = object()
Code = _types.StatusCode
+WriteFlags = _types.OpWriteFlags
class Status(collections.namedtuple('Status', ['code', 'details'])):
@@ -125,9 +126,9 @@ class Call(object):
], _TagAdapter(finish_tag, Event.Kind.FINISH))
return err0 if err0 != _types.CallError.OK else err1 if err1 != _types.CallError.OK else err2 if err2 != _types.CallError.OK else _types.CallError.OK
- def write(self, message, tag):
+ def write(self, message, tag, flags):
return self._internal.start_batch([
- _types.OpArgs.send_message(message, 0)
+ _types.OpArgs.send_message(message, flags)
], _TagAdapter(tag, Event.Kind.WRITE_ACCEPTED))
def complete(self, tag):
@@ -163,8 +164,11 @@ class Call(object):
def cancel(self):
return self._internal.cancel()
+ def peer(self):
+ return self._internal.peer()
+
def set_credentials(self, creds):
- return self._internal.set_credentials(creds)
+ return self._internal.set_credentials(creds._internal)
class Channel(object):
diff --git a/src/python/grpcio/grpc/_adapter/fore.py b/src/python/grpcio/grpc/_adapter/fore.py
index daa41e8bde..acdd69c420 100644
--- a/src/python/grpcio/grpc/_adapter/fore.py
+++ b/src/python/grpcio/grpc/_adapter/fore.py
@@ -56,7 +56,7 @@ class _LowWrite(enum.Enum):
def _write(call, rpc_state, payload):
serialized_payload = rpc_state.serializer(payload)
if rpc_state.write.low is _LowWrite.OPEN:
- call.write(serialized_payload, call)
+ call.write(serialized_payload, call, 0)
rpc_state.write.low = _LowWrite.ACTIVE
else:
rpc_state.write.pending.append(serialized_payload)
@@ -164,7 +164,7 @@ class ForeLink(base_interfaces.ForeLink, activated.Activated):
if rpc_state.write.pending:
serialized_payload = rpc_state.write.pending.pop(0)
- call.write(serialized_payload, call)
+ call.write(serialized_payload, call, 0)
elif rpc_state.write.high is _common.HighWrite.CLOSED:
_status(call, rpc_state)
else:
diff --git a/src/python/grpcio/grpc/_adapter/rear.py b/src/python/grpcio/grpc/_adapter/rear.py
index fd6f45f7a7..17fa47f746 100644
--- a/src/python/grpcio/grpc/_adapter/rear.py
+++ b/src/python/grpcio/grpc/_adapter/rear.py
@@ -78,7 +78,7 @@ class _RPCState(object):
def _write(operation_id, call, outstanding, write_state, serialized_payload):
if write_state.low is _LowWrite.OPEN:
- call.write(serialized_payload, operation_id)
+ call.write(serialized_payload, operation_id, 0)
outstanding.add(_low.Event.Kind.WRITE_ACCEPTED)
write_state.low = _LowWrite.ACTIVE
elif write_state.low is _LowWrite.ACTIVE:
@@ -144,7 +144,7 @@ class RearLink(base_interfaces.RearLink, activated.Activated):
if event.write_accepted:
if rpc_state.common.write.pending:
rpc_state.call.write(
- rpc_state.common.write.pending.pop(0), operation_id)
+ rpc_state.common.write.pending.pop(0), operation_id, 0)
rpc_state.outstanding.add(_low.Event.Kind.WRITE_ACCEPTED)
elif rpc_state.common.write.high is _common.HighWrite.CLOSED:
rpc_state.call.complete(operation_id)
@@ -263,7 +263,7 @@ class RearLink(base_interfaces.RearLink, activated.Activated):
low_state = _LowWrite.OPEN
else:
serialized_payload = request_serializer(payload)
- call.write(serialized_payload, operation_id)
+ call.write(serialized_payload, operation_id, 0)
outstanding.add(_low.Event.Kind.WRITE_ACCEPTED)
low_state = _LowWrite.ACTIVE
diff --git a/src/python/grpcio/grpc/_links/invocation.py b/src/python/grpcio/grpc/_links/invocation.py
index fecb550ae0..67ef86a176 100644
--- a/src/python/grpcio/grpc/_links/invocation.py
+++ b/src/python/grpcio/grpc/_links/invocation.py
@@ -37,6 +37,7 @@ import time
from grpc._adapter import _intermediary_low
from grpc._links import _constants
+from grpc.beta import interfaces as beta_interfaces
from grpc.framework.foundation import activated
from grpc.framework.foundation import logging_pool
from grpc.framework.foundation import relay
@@ -73,11 +74,28 @@ class _LowWrite(enum.Enum):
CLOSED = 'CLOSED'
+class _Context(beta_interfaces.GRPCInvocationContext):
+
+ def __init__(self):
+ self._lock = threading.Lock()
+ self._disable_next_compression = False
+
+ def disable_next_request_compression(self):
+ with self._lock:
+ self._disable_next_compression = True
+
+ def next_compression_disabled(self):
+ with self._lock:
+ disabled = self._disable_next_compression
+ self._disable_next_compression = False
+ return disabled
+
+
class _RPCState(object):
def __init__(
self, call, request_serializer, response_deserializer, sequence_number,
- read, allowance, high_write, low_write, due):
+ read, allowance, high_write, low_write, due, context):
self.call = call
self.request_serializer = request_serializer
self.response_deserializer = response_deserializer
@@ -87,6 +105,7 @@ class _RPCState(object):
self.high_write = high_write
self.low_write = low_write
self.due = due
+ self.context = context
def _no_longer_due(kind, rpc_state, key, rpc_states):
@@ -209,7 +228,7 @@ class _Kernel(object):
def _invoke(
self, operation_id, group, method, initial_metadata, payload, termination,
- timeout, allowance):
+ timeout, allowance, options):
"""Invoke an RPC.
Args:
@@ -224,6 +243,7 @@ class _Kernel(object):
timeout: A duration of time in seconds to allow for the RPC.
allowance: The number of payloads (beyond the free first one) that the
local ticket exchange mate has granted permission to be read.
+ options: A beta_interfaces.GRPCCallOptions value or None.
"""
if termination is links.Ticket.Termination.COMPLETION:
high_write = _HighWrite.CLOSED
@@ -241,6 +261,8 @@ class _Kernel(object):
call = _intermediary_low.Call(
self._channel, self._completion_queue, '/%s/%s' % (group, method),
self._host, time.time() + timeout)
+ if options is not None and options.credentials is not None:
+ call.set_credentials(options.credentials._intermediary_low_credentials)
if transformed_initial_metadata is not None:
for metadata_key, metadata_value in transformed_initial_metadata:
call.add_metadata(metadata_key, metadata_value)
@@ -254,17 +276,33 @@ class _Kernel(object):
low_write = _LowWrite.OPEN
due = set((_METADATA, _FINISH,))
else:
- call.write(request_serializer(payload), operation_id)
+ if options is not None and options.disable_compression:
+ flags = _intermediary_low.WriteFlags.WRITE_NO_COMPRESS
+ else:
+ flags = 0
+ call.write(request_serializer(payload), operation_id, flags)
low_write = _LowWrite.ACTIVE
due = set((_WRITE, _METADATA, _FINISH,))
+ context = _Context()
self._rpc_states[operation_id] = _RPCState(
- call, request_serializer, response_deserializer, 0,
+ call, request_serializer, response_deserializer, 1,
_Read.AWAITING_METADATA, 1 if allowance is None else (1 + allowance),
- high_write, low_write, due)
+ high_write, low_write, due, context)
+ protocol = links.Protocol(links.Protocol.Kind.INVOCATION_CONTEXT, context)
+ ticket = links.Ticket(
+ operation_id, 0, None, None, None, None, None, None, None, None, None,
+ None, None, protocol)
+ self._relay.add_value(ticket)
def _advance(self, operation_id, rpc_state, payload, termination, allowance):
if payload is not None:
- rpc_state.call.write(rpc_state.request_serializer(payload), operation_id)
+ disable_compression = rpc_state.context.next_compression_disabled()
+ if disable_compression:
+ flags = _intermediary_low.WriteFlags.WRITE_NO_COMPRESS
+ else:
+ flags = 0
+ rpc_state.call.write(
+ rpc_state.request_serializer(payload), operation_id, flags)
rpc_state.low_write = _LowWrite.ACTIVE
rpc_state.due.add(_WRITE)
@@ -292,10 +330,15 @@ class _Kernel(object):
if self._completion_queue is None:
logging.error('Received invocation ticket %s after stop!', ticket)
else:
+ if (ticket.protocol is not None and
+ ticket.protocol.kind is links.Protocol.Kind.CALL_OPTION):
+ grpc_call_options = ticket.protocol.value
+ else:
+ grpc_call_options = None
self._invoke(
ticket.operation_id, ticket.group, ticket.method,
ticket.initial_metadata, ticket.payload, ticket.termination,
- ticket.timeout, ticket.allowance)
+ ticket.timeout, ticket.allowance, grpc_call_options)
else:
rpc_state = self._rpc_states.get(ticket.operation_id)
if rpc_state is not None:
diff --git a/src/python/grpcio/grpc/_links/service.py b/src/python/grpcio/grpc/_links/service.py
index 07772c7de3..f56df84007 100644
--- a/src/python/grpcio/grpc/_links/service.py
+++ b/src/python/grpcio/grpc/_links/service.py
@@ -37,6 +37,7 @@ import time
from grpc._adapter import _intermediary_low
from grpc._links import _constants
+from grpc.beta import interfaces as beta_interfaces
from grpc.framework.foundation import logging_pool
from grpc.framework.foundation import relay
from grpc.framework.interfaces.links import links
@@ -89,12 +90,34 @@ class _LowWrite(enum.Enum):
CLOSED = 'CLOSED'
+class _Context(beta_interfaces.GRPCServicerContext):
+
+ def __init__(self, call):
+ self._lock = threading.Lock()
+ self._call = call
+ self._disable_next_compression = False
+
+ def peer(self):
+ with self._lock:
+ return self._call.peer()
+
+ def disable_next_response_compression(self):
+ with self._lock:
+ self._disable_next_compression = True
+
+ def next_compression_disabled(self):
+ with self._lock:
+ disabled = self._disable_next_compression
+ self._disable_next_compression = False
+ return disabled
+
+
class _RPCState(object):
def __init__(
self, request_deserializer, response_serializer, sequence_number, read,
early_read, allowance, high_write, low_write, premetadataed,
- terminal_metadata, code, message, due):
+ terminal_metadata, code, message, due, context):
self.request_deserializer = request_deserializer
self.response_serializer = response_serializer
self.sequence_number = sequence_number
@@ -110,6 +133,7 @@ class _RPCState(object):
self.code = code
self.message = message
self.due = due
+ self.context = context
def _no_longer_due(kind, rpc_state, key, rpc_states):
@@ -163,12 +187,12 @@ class _Kernel(object):
(group, method), _IDENTITY)
call.read(call)
+ context = _Context(call)
self._rpc_states[call] = _RPCState(
request_deserializer, response_serializer, 1, _Read.READING, None, 1,
_HighWrite.OPEN, _LowWrite.OPEN, False, None, None, None,
- set((_READ, _FINISH,)))
- protocol = links.Protocol(
- links.Protocol.Kind.SERVICER_CONTEXT, 'TODO: Service Context Object!')
+ set((_READ, _FINISH,)), context)
+ protocol = links.Protocol(links.Protocol.Kind.SERVICER_CONTEXT, context)
ticket = links.Ticket(
call, 0, group, method, links.Ticket.Subscription.FULL,
service_acceptance.deadline - time.time(), None, event.metadata, None,
@@ -313,7 +337,12 @@ class _Kernel(object):
self._relay.add_value(early_read_ticket)
if ticket.payload is not None:
- call.write(rpc_state.response_serializer(ticket.payload), call)
+ disable_compression = rpc_state.context.next_compression_disabled()
+ if disable_compression:
+ flags = _intermediary_low.WriteFlags.WRITE_NO_COMPRESS
+ else:
+ flags = 0
+ call.write(rpc_state.response_serializer(ticket.payload), call, flags)
rpc_state.due.add(_WRITE)
rpc_state.low_write = _LowWrite.ACTIVE
diff --git a/src/python/grpcio/grpc/beta/interfaces.py b/src/python/grpcio/grpc/beta/interfaces.py
index 25e6a9c66b..79f2620dd4 100644
--- a/src/python/grpcio/grpc/beta/interfaces.py
+++ b/src/python/grpcio/grpc/beta/interfaces.py
@@ -29,6 +29,7 @@
"""Constants and interfaces of the Beta API of gRPC Python."""
+import abc
import enum
@@ -52,3 +53,60 @@ class StatusCode(enum.Enum):
UNAVAILABLE = 14
DATA_LOSS = 15
UNAUTHENTICATED = 16
+
+
+class GRPCCallOptions(object):
+ """A value encapsulating gRPC-specific options passed on RPC invocation.
+
+ This class and its instances have no supported interface - it exists to
+ define the type of its instances and its instances exist to be passed to
+ other functions.
+ """
+
+ def __init__(self, disable_compression, subcall_of, credentials):
+ self.disable_compression = disable_compression
+ self.subcall_of = subcall_of
+ self.credentials = credentials
+
+
+def grpc_call_options(disable_compression=False, credentials=None):
+ """Creates a GRPCCallOptions value to be passed at RPC invocation.
+
+ All parameters are optional and should always be passed by keyword.
+
+ Args:
+ disable_compression: A boolean indicating whether or not compression should
+ be disabled for the request object of the RPC. Only valid for
+ request-unary RPCs.
+ credentials: A ClientCredentials object to use for the invoked RPC.
+ """
+ return GRPCCallOptions(disable_compression, None, credentials)
+
+
+class GRPCServicerContext(object):
+ """Exposes gRPC-specific options and behaviors to code servicing RPCs."""
+ __metaclass__ = abc.ABCMeta
+
+ @abc.abstractmethod
+ def peer(self):
+ """Identifies the peer that invoked the RPC being serviced.
+
+ Returns:
+ A string identifying the peer that invoked the RPC being serviced.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def disable_next_response_compression(self):
+ """Disables compression of the next response passed by the application."""
+ raise NotImplementedError()
+
+
+class GRPCInvocationContext(object):
+ """Exposes gRPC-specific options and behaviors to code invoking RPCs."""
+ __metaclass__ = abc.ABCMeta
+
+ @abc.abstractmethod
+ def disable_next_request_compression(self):
+ """Disables compression of the next request passed by the application."""
+ raise NotImplementedError()