aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Masood Malekghassemi <soltanmm@users.noreply.github.com>2015-08-31 12:26:35 -0400
committerGravatar Masood Malekghassemi <soltanmm@users.noreply.github.com>2015-08-31 12:26:35 -0400
commit5684561251f47bd2109052692e67b94dc21de8d1 (patch)
tree670d47b312a0319446ec6ac5019d32db1b203d79
parent8be7a0491c0e4938d04c677b4f5466aeb99c0cad (diff)
parent154e762ae8f2711cbb7097616859f5ff9c677ecf (diff)
Merge pull request #3147 from nathanielmanistaatgoogle/servicelink-shut-down
Fix gRPC links lifecycle tracking
-rw-r--r--src/python/grpcio/grpc/_links/invocation.py115
-rw-r--r--src/python/grpcio/grpc/_links/service.py80
2 files changed, 123 insertions, 72 deletions
diff --git a/src/python/grpcio/grpc/_links/invocation.py b/src/python/grpcio/grpc/_links/invocation.py
index ee3d72fdbc..729b987dd1 100644
--- a/src/python/grpcio/grpc/_links/invocation.py
+++ b/src/python/grpcio/grpc/_links/invocation.py
@@ -41,6 +41,13 @@ from grpc.framework.foundation import logging_pool
from grpc.framework.foundation import relay
from grpc.framework.interfaces.links import links
+_STOP = _intermediary_low.Event.Kind.STOP
+_WRITE = _intermediary_low.Event.Kind.WRITE_ACCEPTED
+_COMPLETE = _intermediary_low.Event.Kind.COMPLETE_ACCEPTED
+_READ = _intermediary_low.Event.Kind.READ_ACCEPTED
+_METADATA = _intermediary_low.Event.Kind.METADATA_ACCEPTED
+_FINISH = _intermediary_low.Event.Kind.FINISH
+
@enum.unique
class _Read(enum.Enum):
@@ -67,7 +74,7 @@ class _RPCState(object):
def __init__(
self, call, request_serializer, response_deserializer, sequence_number,
- read, allowance, high_write, low_write):
+ read, allowance, high_write, low_write, due):
self.call = call
self.request_serializer = request_serializer
self.response_deserializer = response_deserializer
@@ -76,6 +83,13 @@ class _RPCState(object):
self.allowance = allowance
self.high_write = high_write
self.low_write = low_write
+ self.due = due
+
+
+def _no_longer_due(kind, rpc_state, key, rpc_states):
+ rpc_state.due.remove(kind)
+ if not rpc_state.due:
+ del rpc_states[key]
class _Kernel(object):
@@ -91,12 +105,14 @@ class _Kernel(object):
self._relay = ticket_relay
self._completion_queue = None
- self._rpc_states = None
+ self._rpc_states = {}
self._pool = None
def _on_write_event(self, operation_id, unused_event, rpc_state):
if rpc_state.high_write is _HighWrite.CLOSED:
rpc_state.call.complete(operation_id)
+ rpc_state.due.add(_COMPLETE)
+ rpc_state.due.remove(_WRITE)
rpc_state.low_write = _LowWrite.CLOSED
else:
ticket = links.Ticket(
@@ -105,16 +121,19 @@ class _Kernel(object):
rpc_state.sequence_number += 1
self._relay.add_value(ticket)
rpc_state.low_write = _LowWrite.OPEN
+ _no_longer_due(_WRITE, rpc_state, operation_id, self._rpc_states)
def _on_read_event(self, operation_id, event, rpc_state):
- if event.bytes is None:
+ if event.bytes is None or _FINISH not in rpc_state.due:
rpc_state.read = _Read.CLOSED
+ _no_longer_due(_READ, rpc_state, operation_id, self._rpc_states)
else:
if 0 < rpc_state.allowance:
rpc_state.allowance -= 1
rpc_state.call.read(operation_id)
else:
rpc_state.read = _Read.AWAITING_ALLOWANCE
+ _no_longer_due(_READ, rpc_state, operation_id, self._rpc_states)
ticket = links.Ticket(
operation_id, rpc_state.sequence_number, None, None, None, None, None,
None, rpc_state.response_deserializer(event.bytes), None, None, None,
@@ -123,18 +142,23 @@ class _Kernel(object):
self._relay.add_value(ticket)
def _on_metadata_event(self, operation_id, event, rpc_state):
- rpc_state.allowance -= 1
- rpc_state.call.read(operation_id)
- rpc_state.read = _Read.READING
- ticket = links.Ticket(
- operation_id, rpc_state.sequence_number, None, None,
- links.Ticket.Subscription.FULL, None, None, event.metadata, None, None,
- None, None, None, None)
- rpc_state.sequence_number += 1
- self._relay.add_value(ticket)
+ if _FINISH in rpc_state.due:
+ rpc_state.allowance -= 1
+ rpc_state.call.read(operation_id)
+ rpc_state.read = _Read.READING
+ rpc_state.due.add(_READ)
+ rpc_state.due.remove(_METADATA)
+ ticket = links.Ticket(
+ operation_id, rpc_state.sequence_number, None, None,
+ links.Ticket.Subscription.FULL, None, None, event.metadata, None,
+ None, None, None, None, None)
+ rpc_state.sequence_number += 1
+ self._relay.add_value(ticket)
+ else:
+ _no_longer_due(_METADATA, rpc_state, operation_id, self._rpc_states)
def _on_finish_event(self, operation_id, event, rpc_state):
- self._rpc_states.pop(operation_id, None)
+ _no_longer_due(_FINISH, rpc_state, operation_id, self._rpc_states)
if event.status.code is _intermediary_low.Code.OK:
termination = links.Ticket.Termination.COMPLETION
elif event.status.code is _intermediary_low.Code.CANCELLED:
@@ -155,26 +179,26 @@ class _Kernel(object):
def _spin(self, completion_queue):
while True:
event = completion_queue.get(None)
- if event.kind is _intermediary_low.Event.Kind.STOP:
- return
- operation_id = event.tag
with self._lock:
- if self._completion_queue is None:
- continue
- rpc_state = self._rpc_states.get(operation_id)
- if rpc_state is not None:
- if event.kind is _intermediary_low.Event.Kind.WRITE_ACCEPTED:
- self._on_write_event(operation_id, event, rpc_state)
- elif event.kind is _intermediary_low.Event.Kind.METADATA_ACCEPTED:
- self._on_metadata_event(operation_id, event, rpc_state)
- elif event.kind is _intermediary_low.Event.Kind.READ_ACCEPTED:
- self._on_read_event(operation_id, event, rpc_state)
- elif event.kind is _intermediary_low.Event.Kind.FINISH:
- self._on_finish_event(operation_id, event, rpc_state)
- elif event.kind is _intermediary_low.Event.Kind.COMPLETE_ACCEPTED:
- pass
- else:
- logging.error('Illegal RPC event! %s', (event,))
+ rpc_state = self._rpc_states.get(event.tag, None)
+ if event.kind is _STOP:
+ pass
+ elif event.kind is _WRITE:
+ self._on_write_event(event.tag, event, rpc_state)
+ elif event.kind is _METADATA:
+ self._on_metadata_event(event.tag, event, rpc_state)
+ elif event.kind is _READ:
+ self._on_read_event(event.tag, event, rpc_state)
+ elif event.kind is _FINISH:
+ self._on_finish_event(event.tag, event, rpc_state)
+ elif event.kind is _COMPLETE:
+ _no_longer_due(_COMPLETE, rpc_state, event.tag, self._rpc_states)
+ else:
+ logging.error('Illegal RPC event! %s', (event,))
+
+ if self._completion_queue is None and not self._rpc_states:
+ completion_queue.stop()
+ return
def _invoke(
self, operation_id, group, method, initial_metadata, payload, termination,
@@ -221,26 +245,31 @@ class _Kernel(object):
if high_write is _HighWrite.CLOSED:
call.complete(operation_id)
low_write = _LowWrite.CLOSED
+ due = set((_METADATA, _COMPLETE, _FINISH,))
else:
low_write = _LowWrite.OPEN
+ due = set((_METADATA, _FINISH,))
else:
call.write(request_serializer(payload), operation_id)
low_write = _LowWrite.ACTIVE
+ due = set((_WRITE, _METADATA, _FINISH,))
self._rpc_states[operation_id] = _RPCState(
call, request_serializer, response_deserializer, 0,
_Read.AWAITING_METADATA, 1 if allowance is None else (1 + allowance),
- high_write, low_write)
+ high_write, low_write, due)
def _advance(self, operation_id, rpc_state, payload, termination, allowance):
if payload is not None:
rpc_state.call.write(rpc_state.request_serializer(payload), operation_id)
rpc_state.low_write = _LowWrite.ACTIVE
+ rpc_state.due.add(_WRITE)
if allowance is not None:
if rpc_state.read is _Read.AWAITING_ALLOWANCE:
rpc_state.allowance += allowance - 1
rpc_state.call.read(operation_id)
rpc_state.read = _Read.READING
+ rpc_state.due.add(_READ)
else:
rpc_state.allowance += allowance
@@ -248,19 +277,21 @@ class _Kernel(object):
rpc_state.high_write = _HighWrite.CLOSED
if rpc_state.low_write is _LowWrite.OPEN:
rpc_state.call.complete(operation_id)
+ rpc_state.due.add(_COMPLETE)
rpc_state.low_write = _LowWrite.CLOSED
elif termination is not None:
rpc_state.call.cancel()
def add_ticket(self, ticket):
with self._lock:
- if self._completion_queue is None:
- return
if ticket.sequence_number == 0:
- self._invoke(
- ticket.operation_id, ticket.group, ticket.method,
- ticket.initial_metadata, ticket.payload, ticket.termination,
- ticket.timeout, ticket.allowance)
+ if self._completion_queue is None:
+ logging.error('Received invocation ticket %s after stop!', ticket)
+ else:
+ self._invoke(
+ ticket.operation_id, ticket.group, ticket.method,
+ ticket.initial_metadata, ticket.payload, ticket.termination,
+ ticket.timeout, ticket.allowance)
else:
rpc_state = self._rpc_states.get(ticket.operation_id)
if rpc_state is not None:
@@ -276,7 +307,6 @@ class _Kernel(object):
"""
with self._lock:
self._completion_queue = _intermediary_low.CompletionQueue()
- self._rpc_states = {}
self._pool = logging_pool.pool(1)
self._pool.submit(self._spin, self._completion_queue)
@@ -288,11 +318,10 @@ class _Kernel(object):
has been called.
"""
with self._lock:
- self._completion_queue.stop()
+ if not self._rpc_states:
+ self._completion_queue.stop()
self._completion_queue = None
pool = self._pool
- self._pool = None
- self._rpc_states = None
pool.shutdown(wait=True)
diff --git a/src/python/grpcio/grpc/_links/service.py b/src/python/grpcio/grpc/_links/service.py
index c5ecc47cd9..bbfe9bcd55 100644
--- a/src/python/grpcio/grpc/_links/service.py
+++ b/src/python/grpcio/grpc/_links/service.py
@@ -53,6 +53,13 @@ _TERMINATION_KIND_TO_CODE = {
links.Ticket.Termination.REMOTE_FAILURE: _intermediary_low.Code.UNKNOWN,
}
+_STOP = _intermediary_low.Event.Kind.STOP
+_WRITE = _intermediary_low.Event.Kind.WRITE_ACCEPTED
+_COMPLETE = _intermediary_low.Event.Kind.COMPLETE_ACCEPTED
+_SERVICE = _intermediary_low.Event.Kind.SERVICE_ACCEPTED
+_READ = _intermediary_low.Event.Kind.READ_ACCEPTED
+_FINISH = _intermediary_low.Event.Kind.FINISH
+
@enum.unique
class _Read(enum.Enum):
@@ -84,7 +91,7 @@ class _RPCState(object):
def __init__(
self, request_deserializer, response_serializer, sequence_number, read,
early_read, allowance, high_write, low_write, premetadataed,
- terminal_metadata, code, message):
+ terminal_metadata, code, message, due):
self.request_deserializer = request_deserializer
self.response_serializer = response_serializer
self.sequence_number = sequence_number
@@ -99,6 +106,13 @@ class _RPCState(object):
self.terminal_metadata = terminal_metadata
self.code = code
self.message = message
+ self.due = due
+
+
+def _no_longer_due(kind, rpc_state, key, rpc_states):
+ rpc_state.due.remove(kind)
+ if not rpc_state.due:
+ del rpc_states[key]
def _metadatafy(call, metadata):
@@ -124,6 +138,7 @@ class _Kernel(object):
self._relay = ticket_relay
self._completion_queue = None
+ self._due = set()
self._server = None
self._rpc_states = {}
self._pool = None
@@ -149,7 +164,8 @@ class _Kernel(object):
call.read(call)
self._rpc_states[call] = _RPCState(
request_deserializer, response_serializer, 1, _Read.READING, None, 1,
- _HighWrite.OPEN, _LowWrite.OPEN, False, None, None, None)
+ _HighWrite.OPEN, _LowWrite.OPEN, False, None, None, None,
+ set((_READ, _FINISH,)))
ticket = links.Ticket(
call, 0, group, method, links.Ticket.Subscription.FULL,
service_acceptance.deadline - time.time(), None, event.metadata, None,
@@ -158,14 +174,13 @@ class _Kernel(object):
def _on_read_event(self, event):
call = event.tag
- rpc_state = self._rpc_states.get(call, None)
- if rpc_state is None:
- return
+ rpc_state = self._rpc_states[call]
if event.bytes is None:
rpc_state.read = _Read.CLOSED
payload = None
termination = links.Ticket.Termination.COMPLETION
+ _no_longer_due(_READ, rpc_state, call, self._rpc_states)
else:
if 0 < rpc_state.allowance:
payload = rpc_state.request_deserializer(event.bytes)
@@ -174,6 +189,7 @@ class _Kernel(object):
call.read(call)
else:
rpc_state.early_read = event.bytes
+ _no_longer_due(_READ, rpc_state, call, self._rpc_states)
return
# TODO(issue 2916): Instead of returning:
# rpc_state.read = _Read.AWAITING_ALLOWANCE
@@ -185,9 +201,7 @@ class _Kernel(object):
def _on_write_event(self, event):
call = event.tag
- rpc_state = self._rpc_states.get(call, None)
- if rpc_state is None:
- return
+ rpc_state = self._rpc_states[call]
if rpc_state.high_write is _HighWrite.CLOSED:
if rpc_state.terminal_metadata is not None:
@@ -197,6 +211,8 @@ class _Kernel(object):
rpc_state.message)
call.status(status, call)
rpc_state.low_write = _LowWrite.CLOSED
+ rpc_state.due.add(_COMPLETE)
+ rpc_state.due.remove(_WRITE)
else:
ticket = links.Ticket(
call, rpc_state.sequence_number, None, None, None, None, 1, None,
@@ -204,12 +220,12 @@ class _Kernel(object):
rpc_state.sequence_number += 1
self._relay.add_value(ticket)
rpc_state.low_write = _LowWrite.OPEN
+ _no_longer_due(_WRITE, rpc_state, call, self._rpc_states)
def _on_finish_event(self, event):
call = event.tag
- rpc_state = self._rpc_states.pop(call, None)
- if rpc_state is None:
- return
+ rpc_state = self._rpc_states[call]
+ _no_longer_due(_FINISH, rpc_state, call, self._rpc_states)
code = event.status.code
if code is _intermediary_low.Code.OK:
return
@@ -229,28 +245,33 @@ class _Kernel(object):
def _spin(self, completion_queue, server):
while True:
event = completion_queue.get(None)
- if event.kind is _intermediary_low.Event.Kind.STOP:
- return
with self._lock:
- if self._server is None:
- continue
- elif event.kind is _intermediary_low.Event.Kind.SERVICE_ACCEPTED:
- self._on_service_acceptance_event(event, server)
- elif event.kind is _intermediary_low.Event.Kind.READ_ACCEPTED:
+ if event.kind is _STOP:
+ self._due.remove(_STOP)
+ elif event.kind is _READ:
self._on_read_event(event)
- elif event.kind is _intermediary_low.Event.Kind.WRITE_ACCEPTED:
+ elif event.kind is _WRITE:
self._on_write_event(event)
- elif event.kind is _intermediary_low.Event.Kind.COMPLETE_ACCEPTED:
- pass
+ elif event.kind is _COMPLETE:
+ _no_longer_due(
+ _COMPLETE, self._rpc_states.get(event.tag), event.tag,
+ self._rpc_states)
elif event.kind is _intermediary_low.Event.Kind.FINISH:
self._on_finish_event(event)
+ elif event.kind is _SERVICE:
+ if self._server is None:
+ self._due.remove(_SERVICE)
+ else:
+ self._on_service_acceptance_event(event, server)
else:
logging.error('Illegal event! %s', (event,))
+ if not self._due and not self._rpc_states:
+ completion_queue.stop()
+ return
+
def add_ticket(self, ticket):
with self._lock:
- if self._server is None:
- return
call = ticket.operation_id
rpc_state = self._rpc_states.get(call)
if rpc_state is None:
@@ -278,6 +299,7 @@ class _Kernel(object):
rpc_state.early_read = None
if rpc_state.read is _Read.READING:
call.read(call)
+ rpc_state.due.add(_READ)
termination = None
else:
termination = links.Ticket.Termination.COMPLETION
@@ -289,6 +311,7 @@ class _Kernel(object):
if ticket.payload is not None:
call.write(rpc_state.response_serializer(ticket.payload), call)
+ rpc_state.due.add(_WRITE)
rpc_state.low_write = _LowWrite.ACTIVE
if ticket.terminal_metadata is not None:
@@ -307,6 +330,7 @@ class _Kernel(object):
links.Ticket.Termination.COMPLETION, rpc_state.code,
rpc_state.message)
call.status(status, call)
+ rpc_state.due.add(_COMPLETE)
rpc_state.low_write = _LowWrite.CLOSED
elif ticket.termination is not None:
if rpc_state.terminal_metadata is not None:
@@ -314,7 +338,7 @@ class _Kernel(object):
status = _status(
ticket.termination, rpc_state.code, rpc_state.message)
call.status(status, call)
- self._rpc_states.pop(call, None)
+ rpc_state.due.add(_COMPLETE)
def add_port(self, address, server_credentials):
with self._lock:
@@ -335,19 +359,17 @@ class _Kernel(object):
self._pool.submit(self._spin, self._completion_queue, self._server)
self._server.start()
self._server.service(None)
+ self._due.add(_SERVICE)
def begin_stop(self):
with self._lock:
self._server.stop()
+ self._due.add(_STOP)
self._server = None
def end_stop(self):
with self._lock:
- self._completion_queue.stop()
- self._completion_queue = None
pool = self._pool
- self._pool = None
- self._rpc_states = None
pool.shutdown(wait=True)
@@ -369,7 +391,7 @@ class ServiceLink(links.Link):
None for insecure service.
Returns:
- A integer port on which RPCs will be serviced after this link has been
+ An integer port on which RPCs will be serviced after this link has been
started. This is typically the same number as the port number contained
in the passed address, but will likely be different if the port number
contained in the passed address was zero.