diff options
author | Ken Payson <kpayson@google.com> | 2016-06-10 13:19:31 -0700 |
---|---|---|
committer | Ken Payson <kpayson@google.com> | 2016-06-13 15:38:53 -0700 |
commit | ed7428526cea1f7b54a126ef9cdb2c1456c47f2c (patch) | |
tree | 64d4239e374dbbc2da4be8153db816c29b5a4350 /src | |
parent | 68e5ecbee4882a74b16a6e22d935d68798016804 (diff) |
Added cleanup to the server thread's join method.
When the Python Interpreter exits, it first attempts to join any
outstanding threads. This is problematic if a server is created
as a top-level variable or referenced by a reference cycle, as join()
will hang. This adds cleanup behavior to the server thread's join().
Diffstat (limited to 'src')
-rw-r--r-- | src/python/grpcio/grpc/_common.py | 42 | ||||
-rw-r--r-- | src/python/grpcio/grpc/_server.py | 31 | ||||
-rw-r--r-- | src/python/grpcio/tests/tests.json | 1 | ||||
-rw-r--r-- | src/python/grpcio/tests/unit/_rpc_test.py | 7 | ||||
-rw-r--r-- | src/python/grpcio/tests/unit/_thread_cleanup_test.py | 117 |
5 files changed, 180 insertions, 18 deletions
diff --git a/src/python/grpcio/grpc/_common.py b/src/python/grpcio/grpc/_common.py index b8688a0149..1fd1704f18 100644 --- a/src/python/grpcio/grpc/_common.py +++ b/src/python/grpcio/grpc/_common.py @@ -30,6 +30,8 @@ """Shared implementation.""" import logging +import threading +import time import six @@ -110,3 +112,43 @@ def fully_qualified_method(group, method): group = _encode(group) method = _encode(method) return b'/' + group + b'/' + method + + +class CleanupThread(threading.Thread): + """A threading.Thread subclass supporting custom behavior on join(). + + On Python Interpreter exit, Python will attempt to join outstanding threads + prior to garbage collection. We may need to do additional cleanup, and + we accomplish this by overriding the join() method. + """ + + def __init__(self, behavior, group=None, target=None, name=None, + args=(), kwargs={}): + """Constructor. + + Args: + behavior (function): Function called on join() with a single + argument, timeout, indicating the maximum duration of + `behavior`, or None indicating `behavior` has no deadline. + `behavior` must be idempotent. + group (None): should be None. Reseved for future extensions + when ThreadGroup is implemented. + target (function): The function to invoke when this thread is + run. Defaults to None. + name (str): The name of this thread. Defaults to None. + args (tuple[object]): A tuple of arguments to pass to `target`. + kwargs (dict[str,object]): A dictionary of keyword arguments to + pass to `target`. + """ + super(CleanupThread, self).__init__(group=group, target=target, + name=name, args=args, kwargs=kwargs) + self._behavior = behavior + + def join(self, timeout=None): + start_time = time.time() + self._behavior(timeout) + end_time = time.time() + if timeout is not None: + timeout -= end_time - start_time + timeout = max(timeout, 0) + super(CleanupThread, self).join(timeout) diff --git a/src/python/grpcio/grpc/_server.py b/src/python/grpcio/grpc/_server.py index f4f6720497..2d9b96afbf 100644 --- a/src/python/grpcio/grpc/_server.py +++ b/src/python/grpcio/grpc/_server.py @@ -60,6 +60,8 @@ _CANCELLED = 'cancelled' _EMPTY_FLAGS = 0 _EMPTY_METADATA = cygrpc.Metadata(()) +_UNEXPECTED_EXIT_SERVER_GRACE = 1.0 + def _serialized_request(request_event): return request_event.batch_operations[0].received_message.bytes() @@ -670,17 +672,6 @@ def _serve(state): return -def _start(state): - with state.lock: - if state.stage is not _ServerStage.STOPPED: - raise ValueError('Cannot start already-started server!') - state.server.start() - state.stage = _ServerStage.STARTED - _request_call(state) - thread = threading.Thread(target=_serve, args=(state,)) - thread.start() - - def _stop(state, grace): with state.lock: if state.stage is _ServerStage.STOPPED: @@ -719,6 +710,24 @@ def _stop(state, grace): return shutdown_event +def _start(state): + with state.lock: + if state.stage is not _ServerStage.STOPPED: + raise ValueError('Cannot start already-started server!') + state.server.start() + state.stage = _ServerStage.STARTED + _request_call(state) + def cleanup_server(timeout): + if timeout is None: + _stop(state, _UNEXPECTED_EXIT_SERVER_GRACE).wait() + else: + _stop(state, timeout).wait() + + thread = _common.CleanupThread( + cleanup_server, target=_serve, args=(state,)) + thread.start() + + class Server(grpc.Server): def __init__(self, generic_handlers, thread_pool): diff --git a/src/python/grpcio/tests/tests.json b/src/python/grpcio/tests/tests.json index 8dc47bf69d..fcf2001b80 100644 --- a/src/python/grpcio/tests/tests.json +++ b/src/python/grpcio/tests/tests.json @@ -52,6 +52,7 @@ "_rpc_test.RPCTest", "_sanity_test.Sanity", "_secure_interop_test.SecureInteropTest", + "_thread_cleanup_test.CleanupThreadTest", "_transmission_test.RoundTripTest", "_transmission_test.TransmissionTest", "_utilities_test.ChannelConnectivityTest", diff --git a/src/python/grpcio/tests/unit/_rpc_test.py b/src/python/grpcio/tests/unit/_rpc_test.py index b33bff490c..8773e96284 100644 --- a/src/python/grpcio/tests/unit/_rpc_test.py +++ b/src/python/grpcio/tests/unit/_rpc_test.py @@ -193,13 +193,6 @@ class RPCTest(unittest.TestCase): self._channel = grpc.insecure_channel('localhost:%d' % port) - # TODO(nathaniel): Why is this necessary, and only in some development - # environments? - def tearDown(self): - del self._channel - del self._server - del self._server_pool - def testUnrecognizedMethod(self): request = b'abc' diff --git a/src/python/grpcio/tests/unit/_thread_cleanup_test.py b/src/python/grpcio/tests/unit/_thread_cleanup_test.py new file mode 100644 index 0000000000..3e4f317edc --- /dev/null +++ b/src/python/grpcio/tests/unit/_thread_cleanup_test.py @@ -0,0 +1,117 @@ +# Copyright 2016, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +"""Tests for CleanupThread.""" + +import threading +import time +import unittest + +from grpc import _common + +_SHORT_TIME = 0.5 +_LONG_TIME = 2.0 +_EPSILON = 0.1 + + +def cleanup(timeout): + if timeout is not None: + time.sleep(timeout) + else: + time.sleep(_LONG_TIME) + + +def slow_cleanup(timeout): + # Don't respect timeout + time.sleep(_LONG_TIME) + + +class CleanupThreadTest(unittest.TestCase): + + def testTargetInvocation(self): + event = threading.Event() + def target(arg1, arg2, arg3=None): + self.assertEqual('arg1', arg1) + self.assertEqual('arg2', arg2) + self.assertEqual('arg3', arg3) + event.set() + + cleanup_thread = _common.CleanupThread(behavior=lambda x: None, + target=target, name='test-name', + args=('arg1', 'arg2'), kwargs={'arg3': 'arg3'}) + cleanup_thread.start() + cleanup_thread.join() + self.assertEqual(cleanup_thread.name, 'test-name') + self.assertTrue(event.is_set()) + + def testJoinNoTimeout(self): + cleanup_thread = _common.CleanupThread(behavior=cleanup) + cleanup_thread.start() + start_time = time.time() + cleanup_thread.join() + end_time = time.time() + self.assertAlmostEqual(_LONG_TIME, end_time - start_time, delta=_EPSILON) + + def testJoinTimeout(self): + cleanup_thread = _common.CleanupThread(behavior=cleanup) + cleanup_thread.start() + start_time = time.time() + cleanup_thread.join(_SHORT_TIME) + end_time = time.time() + self.assertAlmostEqual(_SHORT_TIME, end_time - start_time, delta=_EPSILON) + + def testJoinTimeoutSlowBehavior(self): + cleanup_thread = _common.CleanupThread(behavior=slow_cleanup) + cleanup_thread.start() + start_time = time.time() + cleanup_thread.join(_SHORT_TIME) + end_time = time.time() + self.assertAlmostEqual(_LONG_TIME, end_time - start_time, delta=_EPSILON) + + def testJoinTimeoutSlowTarget(self): + event = threading.Event() + def target(): + event.wait(_LONG_TIME) + cleanup_thread = _common.CleanupThread(behavior=cleanup, target=target) + cleanup_thread.start() + start_time = time.time() + cleanup_thread.join(_SHORT_TIME) + end_time = time.time() + self.assertAlmostEqual(_SHORT_TIME, end_time - start_time, delta=_EPSILON) + event.set() + + def testJoinZeroTimeout(self): + cleanup_thread = _common.CleanupThread(behavior=cleanup) + cleanup_thread.start() + start_time = time.time() + cleanup_thread.join(0) + end_time = time.time() + self.assertAlmostEqual(0, end_time - start_time, delta=_EPSILON) + +if __name__ == '__main__': + unittest.main(verbosity=2) |