aboutsummaryrefslogtreecommitdiffhomepage
path: root/src/python/src/grpc/_adapter/rear.py
blob: 2b93aa633141f5230ef15f357ed6f664f28f07bb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
# Copyright 2015, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
#     * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""The RPC-invocation-side bridge between RPC Framework and GRPC-on-the-wire."""

import enum
import logging
import threading
import time

from grpc._adapter import _common
from grpc._adapter import _low
from grpc.framework.base import interfaces as base_interfaces
from grpc.framework.base import null
from grpc.framework.foundation import activated
from grpc.framework.foundation import logging_pool

_THREAD_POOL_SIZE = 100

_INVOCATION_EVENT_KINDS = (
    _low.Event.Kind.METADATA_ACCEPTED,
    _low.Event.Kind.FINISH
)


@enum.unique
class _LowWrite(enum.Enum):
  """The possible categories of low-level write state."""

  OPEN = 'OPEN'
  ACTIVE = 'ACTIVE'
  CLOSED = 'CLOSED'


class _RPCState(object):
  """The full state of any tracked RPC.

  Attributes:
    call: The _low.Call object for the RPC.
    outstanding: The set of Event.Kind values describing expected future events
      for the RPC.
    active: A boolean indicating whether or not the RPC is active.
    common: An _common.RPCState describing additional state for the RPC.
  """

  def __init__(self, call, outstanding, active, common):
    self.call = call
    self.outstanding = outstanding
    self.active = active
    self.common = common


def _write(operation_id, call, outstanding, write_state, serialized_payload):
  if write_state.low is _LowWrite.OPEN:
    call.write(serialized_payload, operation_id)
    outstanding.add(_low.Event.Kind.WRITE_ACCEPTED)
    write_state.low = _LowWrite.ACTIVE
  elif write_state.low is _LowWrite.ACTIVE:
    write_state.pending.append(serialized_payload)
  else:
    raise ValueError('Write attempted after writes completed!')


class RearLink(base_interfaces.RearLink, activated.Activated):
  """An invocation-side bridge between RPC Framework and the C-ish _low code."""

  def __init__(
      self, host, port, pool, request_serializers, response_deserializers,
      secure, root_certificates, private_key, certificate_chain,
      metadata_transformer=None, server_host_override=None):
    """Constructor.

    Args:
      host: The host to which to connect for RPC service.
      port: The port to which to connect for RPC service.
      pool: A thread pool.
      request_serializers: A dict from RPC method names to request object
        serializer behaviors.
      response_deserializers: A dict from RPC method names to response object
        deserializer behaviors.
      secure: A boolean indicating whether or not to use a secure connection.
      root_certificates: The PEM-encoded root certificates or None to ask for
        them to be retrieved from a default location.
      private_key: The PEM-encoded private key to use or None if no private
        key should be used.
      certificate_chain: The PEM-encoded certificate chain to use or None if
        no certificate chain should be used.
      metadata_transformer: A function that given a metadata object produces
        another metadata to be used in the underlying communication on the
        wire.
      server_host_override: (For testing only) the target name used for SSL
        host name checking.
    """
    self._condition = threading.Condition()
    self._host = host
    self._port = port
    self._pool = pool
    self._request_serializers = request_serializers
    self._response_deserializers = response_deserializers

    self._fore_link = null.NULL_FORE_LINK
    self._completion_queue = None
    self._channel = None
    self._rpc_states = {}
    self._spinning = False
    if secure:
      self._client_credentials = _low.ClientCredentials(
          root_certificates, private_key, certificate_chain)
    else:
      self._client_credentials = None
    self._root_certificates = root_certificates
    self._private_key = private_key
    self._certificate_chain = certificate_chain
    self._metadata_transformer = metadata_transformer
    self._server_host_override = server_host_override

  def _on_write_event(self, operation_id, event, rpc_state):
    if event.write_accepted:
      if rpc_state.common.write.pending:
        rpc_state.call.write(
            rpc_state.common.write.pending.pop(0), operation_id)
        rpc_state.outstanding.add(_low.Event.Kind.WRITE_ACCEPTED)
      elif rpc_state.common.write.high is _common.HighWrite.CLOSED:
        rpc_state.call.complete(operation_id)
        rpc_state.outstanding.add(_low.Event.Kind.COMPLETE_ACCEPTED)
        rpc_state.common.write.low = _LowWrite.CLOSED
      else:
        rpc_state.common.write.low = _LowWrite.OPEN
    else:
      logging.error('RPC write not accepted! Event: %s', (event,))
      rpc_state.active = False
      ticket = base_interfaces.BackToFrontTicket(
          operation_id, rpc_state.common.sequence_number,
          base_interfaces.BackToFrontTicket.Kind.TRANSMISSION_FAILURE, None)
      rpc_state.common.sequence_number += 1
      self._fore_link.accept_back_to_front_ticket(ticket)

  def _on_read_event(self, operation_id, event, rpc_state):
    if event.bytes is not None:
      rpc_state.call.read(operation_id)
      rpc_state.outstanding.add(_low.Event.Kind.READ_ACCEPTED)

      ticket = base_interfaces.BackToFrontTicket(
          operation_id, rpc_state.common.sequence_number,
          base_interfaces.BackToFrontTicket.Kind.CONTINUATION,
          rpc_state.common.deserializer(event.bytes))
      rpc_state.common.sequence_number += 1
      self._fore_link.accept_back_to_front_ticket(ticket)

  def _on_complete_event(self, operation_id, event, rpc_state):
    if not event.complete_accepted:
      logging.error('RPC complete not accepted! Event: %s', (event,))
      rpc_state.active = False
      ticket = base_interfaces.BackToFrontTicket(
          operation_id, rpc_state.common.sequence_number,
          base_interfaces.BackToFrontTicket.Kind.TRANSMISSION_FAILURE, None)
      rpc_state.common.sequence_number += 1
      self._fore_link.accept_back_to_front_ticket(ticket)

  # TODO(nathaniel): Metadata support.
  def _on_metadata_event(self, operation_id, event, rpc_state):  # pylint: disable=unused-argument
    rpc_state.call.read(operation_id)
    rpc_state.outstanding.add(_low.Event.Kind.READ_ACCEPTED)

  def _on_finish_event(self, operation_id, event, rpc_state):
    """Handle termination of an RPC."""
    # TODO(nathaniel): Cover all statuses.
    if event.status.code is _low.Code.OK:
      kind = base_interfaces.BackToFrontTicket.Kind.COMPLETION
    elif event.status.code is _low.Code.CANCELLED:
      kind = base_interfaces.BackToFrontTicket.Kind.CANCELLATION
    elif event.status.code is _low.Code.EXPIRED:
      kind = base_interfaces.BackToFrontTicket.Kind.EXPIRATION
    else:
      kind = base_interfaces.BackToFrontTicket.Kind.TRANSMISSION_FAILURE
    ticket = base_interfaces.BackToFrontTicket(
        operation_id, rpc_state.common.sequence_number, kind, None)
    rpc_state.common.sequence_number += 1
    self._fore_link.accept_back_to_front_ticket(ticket)

  def _spin(self, completion_queue):
    while True:
      event = completion_queue.get(None)
      operation_id = event.tag

      with self._condition:
        rpc_state = self._rpc_states[operation_id]
        rpc_state.outstanding.remove(event.kind)
        if rpc_state.active and self._completion_queue is not None:
          if event.kind is _low.Event.Kind.WRITE_ACCEPTED:
            self._on_write_event(operation_id, event, rpc_state)
          elif event.kind is _low.Event.Kind.METADATA_ACCEPTED:
            self._on_metadata_event(operation_id, event, rpc_state)
          elif event.kind is _low.Event.Kind.READ_ACCEPTED:
            self._on_read_event(operation_id, event, rpc_state)
          elif event.kind is _low.Event.Kind.COMPLETE_ACCEPTED:
            self._on_complete_event(operation_id, event, rpc_state)
          elif event.kind is _low.Event.Kind.FINISH:
            self._on_finish_event(operation_id, event, rpc_state)
          else:
            logging.error('Illegal RPC event! %s', (event,))

        if not rpc_state.outstanding:
          self._rpc_states.pop(operation_id)
        if not self._rpc_states:
          self._spinning = False
          self._condition.notify_all()
          return

  def _invoke(self, operation_id, name, high_state, payload, timeout):
    """Invoke an RPC.

    Args:
      operation_id: Any object to be used as an operation ID for the RPC.
      name: The RPC method name.
      high_state: A _common.HighWrite value representing the "high write state"
        of the RPC.
      payload: A payload object for the RPC or None if no payload was given at
        invocation-time.
      timeout: A duration of time in seconds to allow for the RPC.
    """
    request_serializer = self._request_serializers[name]
    call = _low.Call(self._channel, name, self._host, time.time() + timeout)
    if self._metadata_transformer is not None:
      metadata = self._metadata_transformer([])
      for metadata_key, metadata_value in metadata:
        call.add_metadata(metadata_key, metadata_value)
    call.invoke(self._completion_queue, operation_id, operation_id)
    outstanding = set(_INVOCATION_EVENT_KINDS)

    if payload is None:
      if high_state is _common.HighWrite.CLOSED:
        call.complete(operation_id)
        low_state = _LowWrite.CLOSED
        outstanding.add(_low.Event.Kind.COMPLETE_ACCEPTED)
      else:
        low_state = _LowWrite.OPEN
    else:
      serialized_payload = request_serializer(payload)
      call.write(serialized_payload, operation_id)
      outstanding.add(_low.Event.Kind.WRITE_ACCEPTED)
      low_state = _LowWrite.ACTIVE

    write_state = _common.WriteState(low_state, high_state, [])
    common_state = _common.CommonRPCState(
        write_state, 0, self._response_deserializers[name], request_serializer)
    self._rpc_states[operation_id] = _RPCState(
        call, outstanding, True, common_state)

    if not self._spinning:
      self._pool.submit(self._spin, self._completion_queue)
      self._spinning = True

  def _commence(self, operation_id, name, payload, timeout):
    self._invoke(operation_id, name, _common.HighWrite.OPEN, payload, timeout)

  def _continue(self, operation_id, payload):
    rpc_state = self._rpc_states.get(operation_id, None)
    if rpc_state is None or not rpc_state.active:
      return

    _write(
        operation_id, rpc_state.call, rpc_state.outstanding,
        rpc_state.common.write, rpc_state.common.serializer(payload))

  def _complete(self, operation_id, payload):
    """Close writes associated with an ongoing RPC.

    Args:
      operation_id: Any object being use as an operation ID for the RPC.
      payload: A payload object for the RPC (and thus the last payload object
        for the RPC) or None if no payload was given along with the instruction
        to indicate the end of writes for the RPC.
    """
    rpc_state = self._rpc_states.get(operation_id, None)
    if rpc_state is None or not rpc_state.active:
      return

    write_state = rpc_state.common.write
    if payload is None:
      if write_state.low is _LowWrite.OPEN:
        rpc_state.call.complete(operation_id)
        rpc_state.outstanding.add(_low.Event.Kind.COMPLETE_ACCEPTED)
        write_state.low = _LowWrite.CLOSED
    else:
      _write(
          operation_id, rpc_state.call, rpc_state.outstanding, write_state,
          rpc_state.common.serializer(payload))
    write_state.high = _common.HighWrite.CLOSED

  def _entire(self, operation_id, name, payload, timeout):
    self._invoke(operation_id, name, _common.HighWrite.CLOSED, payload, timeout)

  def _cancel(self, operation_id):
    rpc_state = self._rpc_states.get(operation_id, None)
    if rpc_state is not None and rpc_state.active:
      rpc_state.call.cancel()
      rpc_state.active = False

  def join_fore_link(self, fore_link):
    """See base_interfaces.RearLink.join_fore_link for specification."""
    with self._condition:
      self._fore_link = null.NULL_FORE_LINK if fore_link is None else fore_link

  def _start(self):
    """Starts this RearLink.

    This method must be called before attempting to exchange tickets with this
    object.
    """
    with self._condition:
      self._completion_queue = _low.CompletionQueue()
      self._channel = _low.Channel(
          '%s:%d' % (self._host, self._port), self._client_credentials,
          server_host_override=self._server_host_override)
    return self

  def _stop(self):
    """Stops this RearLink.

    This method must be called for proper termination of this object, and no
    attempts to exchange tickets with this object may be made after this method
    has been called.
    """
    with self._condition:
      self._completion_queue.stop()
      self._completion_queue = None

      while self._spinning:
        self._condition.wait()

  def __enter__(self):
    """See activated.Activated.__enter__ for specification."""
    return self._start()

  def __exit__(self, exc_type, exc_val, exc_tb):
    """See activated.Activated.__exit__ for specification."""
    self._stop()
    return False

  def start(self):
    """See activated.Activated.start for specification."""
    return self._start()

  def stop(self):
    """See activated.Activated.stop for specification."""
    self._stop()

  def accept_front_to_back_ticket(self, ticket):
    """See base_interfaces.RearLink.accept_front_to_back_ticket for spec."""
    with self._condition:
      if self._completion_queue is None:
        return

      if ticket.kind is base_interfaces.FrontToBackTicket.Kind.COMMENCEMENT:
        self._commence(
            ticket.operation_id, ticket.name, ticket.payload, ticket.timeout)
      elif ticket.kind is base_interfaces.FrontToBackTicket.Kind.CONTINUATION:
        self._continue(ticket.operation_id, ticket.payload)
      elif ticket.kind is base_interfaces.FrontToBackTicket.Kind.COMPLETION:
        self._complete(ticket.operation_id, ticket.payload)
      elif ticket.kind is base_interfaces.FrontToBackTicket.Kind.ENTIRE:
        self._entire(
            ticket.operation_id, ticket.name, ticket.payload, ticket.timeout)
      elif ticket.kind is base_interfaces.FrontToBackTicket.Kind.CANCELLATION:
        self._cancel(ticket.operation_id)
      else:
        # NOTE(nathaniel): All other categories are treated as cancellation.
        self._cancel(ticket.operation_id)