aboutsummaryrefslogtreecommitdiffhomepage
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/core/ext/transport/chttp2/transport/chttp2_transport.cc7
-rw-r--r--src/core/ext/transport/chttp2/transport/context_list.cc38
-rw-r--r--src/core/ext/transport/chttp2/transport/context_list.h35
-rw-r--r--src/cpp/common/channel_arguments.cc2
-rw-r--r--src/python/grpcio_health_checking/grpc_health/v1/health.py80
-rw-r--r--src/python/grpcio_tests/commands.py2
-rw-r--r--src/python/grpcio_tests/tests/health_check/BUILD.bazel1
-rw-r--r--src/python/grpcio_tests/tests/health_check/_health_servicer_test.py187
-rw-r--r--src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py8
-rw-r--r--src/python/grpcio_tests/tests/unit/_api_test.py1
-rw-r--r--src/python/grpcio_tests/tests/unit/_auth_context_test.py6
-rw-r--r--src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py6
-rw-r--r--src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py5
-rw-r--r--src/python/grpcio_tests/tests/unit/_compression_test.py5
-rw-r--r--src/python/grpcio_tests/tests/unit/_empty_message_test.py1
-rw-r--r--src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py1
-rw-r--r--src/python/grpcio_tests/tests/unit/_interceptor_test.py1
-rw-r--r--src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py3
-rw-r--r--src/python/grpcio_tests/tests/unit/_invocation_defects_test.py1
-rw-r--r--src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py14
-rw-r--r--src/python/grpcio_tests/tests/unit/_metadata_flags_test.py29
-rw-r--r--src/python/grpcio_tests/tests/unit/_metadata_test.py1
-rw-r--r--src/python/grpcio_tests/tests/unit/_reconnect_test.py2
-rw-r--r--src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py1
-rw-r--r--src/python/grpcio_tests/tests/unit/_rpc_test.py1
-rwxr-xr-xsrc/ruby/end2end/graceful_sig_handling_client.rb61
-rwxr-xr-xsrc/ruby/end2end/graceful_sig_handling_driver.rb83
-rwxr-xr-xsrc/ruby/end2end/graceful_sig_stop_client.rb78
-rwxr-xr-xsrc/ruby/end2end/graceful_sig_stop_driver.rb62
-rw-r--r--src/ruby/lib/grpc/generic/rpc_server.rb61
30 files changed, 689 insertions, 94 deletions
diff --git a/src/core/ext/transport/chttp2/transport/chttp2_transport.cc b/src/core/ext/transport/chttp2/transport/chttp2_transport.cc
index 9b6574b612..7f4627fa77 100644
--- a/src/core/ext/transport/chttp2/transport/chttp2_transport.cc
+++ b/src/core/ext/transport/chttp2/transport/chttp2_transport.cc
@@ -170,7 +170,12 @@ grpc_chttp2_transport::~grpc_chttp2_transport() {
grpc_slice_buffer_destroy_internal(&outbuf);
grpc_chttp2_hpack_compressor_destroy(&hpack_compressor);
- grpc_core::ContextList::Execute(cl, nullptr, GRPC_ERROR_NONE);
+ grpc_error* error =
+ GRPC_ERROR_CREATE_FROM_STATIC_STRING("Transport destroyed");
+ // ContextList::Execute follows semantics of a callback function and does not
+ // take a ref on error
+ grpc_core::ContextList::Execute(cl, nullptr, error);
+ GRPC_ERROR_UNREF(error);
cl = nullptr;
grpc_slice_buffer_destroy_internal(&read_buffer);
diff --git a/src/core/ext/transport/chttp2/transport/context_list.cc b/src/core/ext/transport/chttp2/transport/context_list.cc
index f30d41c332..df09809067 100644
--- a/src/core/ext/transport/chttp2/transport/context_list.cc
+++ b/src/core/ext/transport/chttp2/transport/context_list.cc
@@ -21,31 +21,47 @@
#include "src/core/ext/transport/chttp2/transport/context_list.h"
namespace {
-void (*write_timestamps_callback_g)(void*, grpc_core::Timestamps*) = nullptr;
-}
+void (*write_timestamps_callback_g)(void*, grpc_core::Timestamps*,
+ grpc_error* error) = nullptr;
+void* (*get_copied_context_fn_g)(void*) = nullptr;
+} // namespace
namespace grpc_core {
+void ContextList::Append(ContextList** head, grpc_chttp2_stream* s) {
+ if (get_copied_context_fn_g == nullptr ||
+ write_timestamps_callback_g == nullptr) {
+ return;
+ }
+ /* Create a new element in the list and add it at the front */
+ ContextList* elem = grpc_core::New<ContextList>();
+ elem->trace_context_ = get_copied_context_fn_g(s->context);
+ elem->byte_offset_ = s->byte_counter;
+ elem->next_ = *head;
+ *head = elem;
+}
+
void ContextList::Execute(void* arg, grpc_core::Timestamps* ts,
grpc_error* error) {
ContextList* head = static_cast<ContextList*>(arg);
ContextList* to_be_freed;
while (head != nullptr) {
- if (error == GRPC_ERROR_NONE && ts != nullptr) {
- if (write_timestamps_callback_g) {
- ts->byte_offset = static_cast<uint32_t>(head->byte_offset_);
- write_timestamps_callback_g(head->s_->context, ts);
- }
+ if (write_timestamps_callback_g) {
+ ts->byte_offset = static_cast<uint32_t>(head->byte_offset_);
+ write_timestamps_callback_g(head->trace_context_, ts, error);
}
- GRPC_CHTTP2_STREAM_UNREF(static_cast<grpc_chttp2_stream*>(head->s_),
- "timestamp");
to_be_freed = head;
head = head->next_;
grpc_core::Delete(to_be_freed);
}
}
-void grpc_http2_set_write_timestamps_callback(
- void (*fn)(void*, grpc_core::Timestamps*)) {
+void grpc_http2_set_write_timestamps_callback(void (*fn)(void*,
+ grpc_core::Timestamps*,
+ grpc_error* error)) {
write_timestamps_callback_g = fn;
}
+
+void grpc_http2_set_fn_get_copied_context(void* (*fn)(void*)) {
+ get_copied_context_fn_g = fn;
+}
} /* namespace grpc_core */
diff --git a/src/core/ext/transport/chttp2/transport/context_list.h b/src/core/ext/transport/chttp2/transport/context_list.h
index d870107749..5b9d2ab378 100644
--- a/src/core/ext/transport/chttp2/transport/context_list.h
+++ b/src/core/ext/transport/chttp2/transport/context_list.h
@@ -31,42 +31,23 @@ class ContextList {
public:
/* Creates a new element with \a context as the value and appends it to the
* list. */
- static void Append(ContextList** head, grpc_chttp2_stream* s) {
- /* Make sure context is not already present */
- GRPC_CHTTP2_STREAM_REF(s, "timestamp");
-
-#ifndef NDEBUG
- ContextList* ptr = *head;
- while (ptr != nullptr) {
- if (ptr->s_ == s) {
- GPR_ASSERT(
- false &&
- "Trying to append a stream that is already present in the list");
- }
- ptr = ptr->next_;
- }
-#endif
-
- /* Create a new element in the list and add it at the front */
- ContextList* elem = grpc_core::New<ContextList>();
- elem->s_ = s;
- elem->byte_offset_ = s->byte_counter;
- elem->next_ = *head;
- *head = elem;
- }
+ static void Append(ContextList** head, grpc_chttp2_stream* s);
/* Executes a function \a fn with each context in the list and \a ts. It also
- * frees up the entire list after this operation. */
+ * frees up the entire list after this operation. It is intended as a callback
+ * and hence does not take a ref on \a error */
static void Execute(void* arg, grpc_core::Timestamps* ts, grpc_error* error);
private:
- grpc_chttp2_stream* s_ = nullptr;
+ void* trace_context_ = nullptr;
ContextList* next_ = nullptr;
size_t byte_offset_ = 0;
};
-void grpc_http2_set_write_timestamps_callback(
- void (*fn)(void*, grpc_core::Timestamps*));
+void grpc_http2_set_write_timestamps_callback(void (*fn)(void*,
+ grpc_core::Timestamps*,
+ grpc_error* error));
+void grpc_http2_set_fn_get_copied_context(void* (*fn)(void*));
} /* namespace grpc_core */
#endif /* GRPC_CORE_EXT_TRANSPORT_CHTTP2_TRANSPORT_CONTEXT_LIST_H */
diff --git a/src/cpp/common/channel_arguments.cc b/src/cpp/common/channel_arguments.cc
index 50ee9d871f..214d72f853 100644
--- a/src/cpp/common/channel_arguments.cc
+++ b/src/cpp/common/channel_arguments.cc
@@ -106,7 +106,9 @@ void ChannelArguments::SetSocketMutator(grpc_socket_mutator* mutator) {
}
if (!replaced) {
+ strings_.push_back(grpc::string(mutator_arg.key));
args_.push_back(mutator_arg);
+ args_.back().key = const_cast<char*>(strings_.back().c_str());
}
}
diff --git a/src/python/grpcio_health_checking/grpc_health/v1/health.py b/src/python/grpcio_health_checking/grpc_health/v1/health.py
index 0583659428..0a5bbb5504 100644
--- a/src/python/grpcio_health_checking/grpc_health/v1/health.py
+++ b/src/python/grpcio_health_checking/grpc_health/v1/health.py
@@ -23,15 +23,61 @@ from grpc_health.v1 import health_pb2_grpc as _health_pb2_grpc
SERVICE_NAME = _health_pb2.DESCRIPTOR.services_by_name['Health'].full_name
+class _Watcher():
+
+ def __init__(self):
+ self._condition = threading.Condition()
+ self._responses = list()
+ self._open = True
+
+ def __iter__(self):
+ return self
+
+ def _next(self):
+ with self._condition:
+ while not self._responses and self._open:
+ self._condition.wait()
+ if self._responses:
+ return self._responses.pop(0)
+ else:
+ raise StopIteration()
+
+ def next(self):
+ return self._next()
+
+ def __next__(self):
+ return self._next()
+
+ def add(self, response):
+ with self._condition:
+ self._responses.append(response)
+ self._condition.notify()
+
+ def close(self):
+ with self._condition:
+ self._open = False
+ self._condition.notify()
+
+
class HealthServicer(_health_pb2_grpc.HealthServicer):
"""Servicer handling RPCs for service statuses."""
def __init__(self):
- self._server_status_lock = threading.Lock()
+ self._lock = threading.RLock()
self._server_status = {}
+ self._watchers = {}
+
+ def _on_close_callback(self, watcher, service):
+
+ def callback():
+ with self._lock:
+ self._watchers[service].remove(watcher)
+ watcher.close()
+
+ return callback
def Check(self, request, context):
- with self._server_status_lock:
+ with self._lock:
status = self._server_status.get(request.service)
if status is None:
context.set_code(grpc.StatusCode.NOT_FOUND)
@@ -39,14 +85,30 @@ class HealthServicer(_health_pb2_grpc.HealthServicer):
else:
return _health_pb2.HealthCheckResponse(status=status)
+ def Watch(self, request, context):
+ service = request.service
+ with self._lock:
+ status = self._server_status.get(service)
+ if status is None:
+ status = _health_pb2.HealthCheckResponse.SERVICE_UNKNOWN # pylint: disable=no-member
+ watcher = _Watcher()
+ watcher.add(_health_pb2.HealthCheckResponse(status=status))
+ if service not in self._watchers:
+ self._watchers[service] = set()
+ self._watchers[service].add(watcher)
+ context.add_callback(self._on_close_callback(watcher, service))
+ return watcher
+
def set(self, service, status):
"""Sets the status of a service.
- Args:
- service: string, the name of the service.
- NOTE, '' must be set.
- status: HealthCheckResponse.status enum value indicating
- the status of the service
- """
- with self._server_status_lock:
+ Args:
+ service: string, the name of the service. NOTE, '' must be set.
+ status: HealthCheckResponse.status enum value indicating the status of
+ the service
+ """
+ with self._lock:
self._server_status[service] = status
+ if service in self._watchers:
+ for watcher in self._watchers[service]:
+ watcher.add(_health_pb2.HealthCheckResponse(status=status))
diff --git a/src/python/grpcio_tests/commands.py b/src/python/grpcio_tests/commands.py
index 496bcfbcbf..582ce898de 100644
--- a/src/python/grpcio_tests/commands.py
+++ b/src/python/grpcio_tests/commands.py
@@ -133,6 +133,7 @@ class TestGevent(setuptools.Command):
# This test will stuck while running higher version of gevent
'unit._auth_context_test.AuthContextTest.testSessionResumption',
# TODO(https://github.com/grpc/grpc/issues/15411) enable these tests
+ 'unit._metadata_flags_test',
'unit._exit_test.ExitTest.test_in_flight_unary_unary_call',
'unit._exit_test.ExitTest.test_in_flight_unary_stream_call',
'unit._exit_test.ExitTest.test_in_flight_stream_unary_call',
@@ -140,6 +141,7 @@ class TestGevent(setuptools.Command):
'unit._exit_test.ExitTest.test_in_flight_partial_unary_stream_call',
'unit._exit_test.ExitTest.test_in_flight_partial_stream_unary_call',
'unit._exit_test.ExitTest.test_in_flight_partial_stream_stream_call',
+ 'health_check._health_servicer_test.HealthServicerTest.test_cancelled_watch_removed_from_watch_list',
# TODO(https://github.com/grpc/grpc/issues/17330) enable these three tests
'channelz._channelz_servicer_test.ChannelzServicerTest.test_many_subchannels',
'channelz._channelz_servicer_test.ChannelzServicerTest.test_many_subchannels_and_sockets',
diff --git a/src/python/grpcio_tests/tests/health_check/BUILD.bazel b/src/python/grpcio_tests/tests/health_check/BUILD.bazel
index 19e1e1b2e1..77bc61aa30 100644
--- a/src/python/grpcio_tests/tests/health_check/BUILD.bazel
+++ b/src/python/grpcio_tests/tests/health_check/BUILD.bazel
@@ -9,6 +9,7 @@ py_test(
"//src/python/grpcio/grpc:grpcio",
"//src/python/grpcio_health_checking/grpc_health/v1:grpc_health",
"//src/python/grpcio_tests/tests/unit:test_common",
+ "//src/python/grpcio_tests/tests/unit/framework/common:common",
],
imports = ["../../",],
)
diff --git a/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py b/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py
index 350b5eebe5..35794987bc 100644
--- a/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py
+++ b/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py
@@ -13,6 +13,8 @@
# limitations under the License.
"""Tests of grpc_health.v1.health."""
+import threading
+import time
import unittest
import grpc
@@ -21,58 +23,199 @@ from grpc_health.v1 import health_pb2
from grpc_health.v1 import health_pb2_grpc
from tests.unit import test_common
+from tests.unit.framework.common import test_constants
+
+from six.moves import queue
+
+_SERVING_SERVICE = 'grpc.test.TestServiceServing'
+_UNKNOWN_SERVICE = 'grpc.test.TestServiceUnknown'
+_NOT_SERVING_SERVICE = 'grpc.test.TestServiceNotServing'
+_WATCH_SERVICE = 'grpc.test.WatchService'
+
+
+def _consume_responses(response_iterator, response_queue):
+ for response in response_iterator:
+ response_queue.put(response)
class HealthServicerTest(unittest.TestCase):
def setUp(self):
- servicer = health.HealthServicer()
- servicer.set('', health_pb2.HealthCheckResponse.SERVING)
- servicer.set('grpc.test.TestServiceServing',
- health_pb2.HealthCheckResponse.SERVING)
- servicer.set('grpc.test.TestServiceUnknown',
- health_pb2.HealthCheckResponse.UNKNOWN)
- servicer.set('grpc.test.TestServiceNotServing',
- health_pb2.HealthCheckResponse.NOT_SERVING)
+ self._servicer = health.HealthServicer()
+ self._servicer.set('', health_pb2.HealthCheckResponse.SERVING)
+ self._servicer.set(_SERVING_SERVICE,
+ health_pb2.HealthCheckResponse.SERVING)
+ self._servicer.set(_UNKNOWN_SERVICE,
+ health_pb2.HealthCheckResponse.UNKNOWN)
+ self._servicer.set(_NOT_SERVING_SERVICE,
+ health_pb2.HealthCheckResponse.NOT_SERVING)
self._server = test_common.test_server()
port = self._server.add_insecure_port('[::]:0')
- health_pb2_grpc.add_HealthServicer_to_server(servicer, self._server)
+ health_pb2_grpc.add_HealthServicer_to_server(self._servicer,
+ self._server)
self._server.start()
- channel = grpc.insecure_channel('localhost:%d' % port)
- self._stub = health_pb2_grpc.HealthStub(channel)
+ self._channel = grpc.insecure_channel('localhost:%d' % port)
+ self._stub = health_pb2_grpc.HealthStub(self._channel)
- def test_empty_service(self):
+ def tearDown(self):
+ self._server.stop(None)
+ self._channel.close()
+
+ def test_check_empty_service(self):
request = health_pb2.HealthCheckRequest()
resp = self._stub.Check(request)
self.assertEqual(health_pb2.HealthCheckResponse.SERVING, resp.status)
- def test_serving_service(self):
- request = health_pb2.HealthCheckRequest(
- service='grpc.test.TestServiceServing')
+ def test_check_serving_service(self):
+ request = health_pb2.HealthCheckRequest(service=_SERVING_SERVICE)
resp = self._stub.Check(request)
self.assertEqual(health_pb2.HealthCheckResponse.SERVING, resp.status)
- def test_unknown_serivce(self):
- request = health_pb2.HealthCheckRequest(
- service='grpc.test.TestServiceUnknown')
+ def test_check_unknown_serivce(self):
+ request = health_pb2.HealthCheckRequest(service=_UNKNOWN_SERVICE)
resp = self._stub.Check(request)
self.assertEqual(health_pb2.HealthCheckResponse.UNKNOWN, resp.status)
- def test_not_serving_service(self):
- request = health_pb2.HealthCheckRequest(
- service='grpc.test.TestServiceNotServing')
+ def test_check_not_serving_service(self):
+ request = health_pb2.HealthCheckRequest(service=_NOT_SERVING_SERVICE)
resp = self._stub.Check(request)
self.assertEqual(health_pb2.HealthCheckResponse.NOT_SERVING,
resp.status)
- def test_not_found_service(self):
+ def test_check_not_found_service(self):
request = health_pb2.HealthCheckRequest(service='not-found')
with self.assertRaises(grpc.RpcError) as context:
resp = self._stub.Check(request)
self.assertEqual(grpc.StatusCode.NOT_FOUND, context.exception.code())
+ def test_watch_empty_service(self):
+ request = health_pb2.HealthCheckRequest(service='')
+ response_queue = queue.Queue()
+ rendezvous = self._stub.Watch(request)
+ thread = threading.Thread(
+ target=_consume_responses, args=(rendezvous, response_queue))
+ thread.start()
+
+ response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT)
+ self.assertEqual(health_pb2.HealthCheckResponse.SERVING,
+ response.status)
+
+ rendezvous.cancel()
+ thread.join()
+ self.assertTrue(response_queue.empty())
+
+ def test_watch_new_service(self):
+ request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE)
+ response_queue = queue.Queue()
+ rendezvous = self._stub.Watch(request)
+ thread = threading.Thread(
+ target=_consume_responses, args=(rendezvous, response_queue))
+ thread.start()
+
+ response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT)
+ self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN,
+ response.status)
+
+ self._servicer.set(_WATCH_SERVICE,
+ health_pb2.HealthCheckResponse.SERVING)
+ response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT)
+ self.assertEqual(health_pb2.HealthCheckResponse.SERVING,
+ response.status)
+
+ self._servicer.set(_WATCH_SERVICE,
+ health_pb2.HealthCheckResponse.NOT_SERVING)
+ response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT)
+ self.assertEqual(health_pb2.HealthCheckResponse.NOT_SERVING,
+ response.status)
+
+ rendezvous.cancel()
+ thread.join()
+ self.assertTrue(response_queue.empty())
+
+ def test_watch_service_isolation(self):
+ request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE)
+ response_queue = queue.Queue()
+ rendezvous = self._stub.Watch(request)
+ thread = threading.Thread(
+ target=_consume_responses, args=(rendezvous, response_queue))
+ thread.start()
+
+ response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT)
+ self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN,
+ response.status)
+
+ self._servicer.set('some-other-service',
+ health_pb2.HealthCheckResponse.SERVING)
+ with self.assertRaises(queue.Empty):
+ response_queue.get(timeout=test_constants.SHORT_TIMEOUT)
+
+ rendezvous.cancel()
+ thread.join()
+ self.assertTrue(response_queue.empty())
+
+ def test_two_watchers(self):
+ request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE)
+ response_queue1 = queue.Queue()
+ response_queue2 = queue.Queue()
+ rendezvous1 = self._stub.Watch(request)
+ rendezvous2 = self._stub.Watch(request)
+ thread1 = threading.Thread(
+ target=_consume_responses, args=(rendezvous1, response_queue1))
+ thread2 = threading.Thread(
+ target=_consume_responses, args=(rendezvous2, response_queue2))
+ thread1.start()
+ thread2.start()
+
+ response1 = response_queue1.get(timeout=test_constants.SHORT_TIMEOUT)
+ response2 = response_queue2.get(timeout=test_constants.SHORT_TIMEOUT)
+ self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN,
+ response1.status)
+ self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN,
+ response2.status)
+
+ self._servicer.set(_WATCH_SERVICE,
+ health_pb2.HealthCheckResponse.SERVING)
+ response1 = response_queue1.get(timeout=test_constants.SHORT_TIMEOUT)
+ response2 = response_queue2.get(timeout=test_constants.SHORT_TIMEOUT)
+ self.assertEqual(health_pb2.HealthCheckResponse.SERVING,
+ response1.status)
+ self.assertEqual(health_pb2.HealthCheckResponse.SERVING,
+ response2.status)
+
+ rendezvous1.cancel()
+ rendezvous2.cancel()
+ thread1.join()
+ thread2.join()
+ self.assertTrue(response_queue1.empty())
+ self.assertTrue(response_queue2.empty())
+
+ def test_cancelled_watch_removed_from_watch_list(self):
+ request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE)
+ response_queue = queue.Queue()
+ rendezvous = self._stub.Watch(request)
+ thread = threading.Thread(
+ target=_consume_responses, args=(rendezvous, response_queue))
+ thread.start()
+
+ response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT)
+ self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN,
+ response.status)
+
+ rendezvous.cancel()
+ self._servicer.set(_WATCH_SERVICE,
+ health_pb2.HealthCheckResponse.SERVING)
+ thread.join()
+
+ # Wait, if necessary, for serving thread to process client cancellation
+ timeout = time.time() + test_constants.SHORT_TIMEOUT
+ while time.time() < timeout and self._servicer._watchers[_WATCH_SERVICE]:
+ time.sleep(1)
+ self.assertFalse(self._servicer._watchers[_WATCH_SERVICE],
+ 'watch set should be empty')
+ self.assertTrue(response_queue.empty())
+
def test_health_service_name(self):
self.assertEqual(health.SERVICE_NAME, 'grpc.health.v1.Health')
diff --git a/src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py b/src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py
index bcd9e14a38..560f6d3ddb 100644
--- a/src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py
+++ b/src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py
@@ -56,8 +56,12 @@ class ReflectionServicerTest(unittest.TestCase):
port = self._server.add_insecure_port('[::]:0')
self._server.start()
- channel = grpc.insecure_channel('localhost:%d' % port)
- self._stub = reflection_pb2_grpc.ServerReflectionStub(channel)
+ self._channel = grpc.insecure_channel('localhost:%d' % port)
+ self._stub = reflection_pb2_grpc.ServerReflectionStub(self._channel)
+
+ def tearDown(self):
+ self._server.stop(None)
+ self._channel.close()
def testFileByName(self):
requests = (
diff --git a/src/python/grpcio_tests/tests/unit/_api_test.py b/src/python/grpcio_tests/tests/unit/_api_test.py
index 427894bfe9..0dc6a8718c 100644
--- a/src/python/grpcio_tests/tests/unit/_api_test.py
+++ b/src/python/grpcio_tests/tests/unit/_api_test.py
@@ -101,6 +101,7 @@ class ChannelTest(unittest.TestCase):
def test_secure_channel(self):
channel_credentials = grpc.ssl_channel_credentials()
channel = grpc.secure_channel('google.com:443', channel_credentials)
+ channel.close()
if __name__ == '__main__':
diff --git a/src/python/grpcio_tests/tests/unit/_auth_context_test.py b/src/python/grpcio_tests/tests/unit/_auth_context_test.py
index b1b5bbdcab..96c4e9ec76 100644
--- a/src/python/grpcio_tests/tests/unit/_auth_context_test.py
+++ b/src/python/grpcio_tests/tests/unit/_auth_context_test.py
@@ -71,8 +71,8 @@ class AuthContextTest(unittest.TestCase):
port = server.add_insecure_port('[::]:0')
server.start()
- channel = grpc.insecure_channel('localhost:%d' % port)
- response = channel.unary_unary(_UNARY_UNARY)(_REQUEST)
+ with grpc.insecure_channel('localhost:%d' % port) as channel:
+ response = channel.unary_unary(_UNARY_UNARY)(_REQUEST)
server.stop(None)
auth_data = pickle.loads(response)
@@ -98,6 +98,7 @@ class AuthContextTest(unittest.TestCase):
channel_creds,
options=_PROPERTY_OPTIONS)
response = channel.unary_unary(_UNARY_UNARY)(_REQUEST)
+ channel.close()
server.stop(None)
auth_data = pickle.loads(response)
@@ -132,6 +133,7 @@ class AuthContextTest(unittest.TestCase):
options=_PROPERTY_OPTIONS)
response = channel.unary_unary(_UNARY_UNARY)(_REQUEST)
+ channel.close()
server.stop(None)
auth_data = pickle.loads(response)
diff --git a/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py b/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py
index 727fb7d65f..565bd39b3a 100644
--- a/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py
+++ b/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py
@@ -75,6 +75,8 @@ class ChannelConnectivityTest(unittest.TestCase):
channel.unsubscribe(callback.update)
fifth_connectivities = callback.connectivities()
+ channel.close()
+
self.assertSequenceEqual((grpc.ChannelConnectivity.IDLE,),
first_connectivities)
self.assertNotIn(grpc.ChannelConnectivity.READY, second_connectivities)
@@ -108,7 +110,8 @@ class ChannelConnectivityTest(unittest.TestCase):
_ready_in_connectivities)
second_callback.block_until_connectivities_satisfy(
_ready_in_connectivities)
- del channel
+ channel.close()
+ server.stop(None)
self.assertSequenceEqual((grpc.ChannelConnectivity.IDLE,),
first_connectivities)
@@ -139,6 +142,7 @@ class ChannelConnectivityTest(unittest.TestCase):
callback.block_until_connectivities_satisfy(
_last_connectivity_is_not_ready)
channel.unsubscribe(callback.update)
+ channel.close()
self.assertFalse(thread_pool.was_used())
diff --git a/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py b/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py
index 345460ef40..46a4eb9bb6 100644
--- a/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py
+++ b/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py
@@ -60,6 +60,8 @@ class ChannelReadyFutureTest(unittest.TestCase):
self.assertTrue(ready_future.done())
self.assertFalse(ready_future.running())
+ channel.close()
+
def test_immediately_connectable_channel_connectivity(self):
thread_pool = _thread_pool.RecordingThreadPool(max_workers=None)
server = grpc.server(thread_pool, options=(('grpc.so_reuseport', 0),))
@@ -84,6 +86,9 @@ class ChannelReadyFutureTest(unittest.TestCase):
self.assertFalse(ready_future.running())
self.assertFalse(thread_pool.was_used())
+ channel.close()
+ server.stop(None)
+
if __name__ == '__main__':
logging.basicConfig()
diff --git a/src/python/grpcio_tests/tests/unit/_compression_test.py b/src/python/grpcio_tests/tests/unit/_compression_test.py
index 876d8e827e..87884a19dc 100644
--- a/src/python/grpcio_tests/tests/unit/_compression_test.py
+++ b/src/python/grpcio_tests/tests/unit/_compression_test.py
@@ -77,6 +77,9 @@ class CompressionTest(unittest.TestCase):
self._port = self._server.add_insecure_port('[::]:0')
self._server.start()
+ def tearDown(self):
+ self._server.stop(None)
+
def testUnary(self):
request = b'\x00' * 100
@@ -102,6 +105,7 @@ class CompressionTest(unittest.TestCase):
response = multi_callable(
request, metadata=[('grpc-internal-encoding-request', 'gzip')])
self.assertEqual(request, response)
+ compressed_channel.close()
def testStreaming(self):
request = b'\x00' * 100
@@ -115,6 +119,7 @@ class CompressionTest(unittest.TestCase):
call = multi_callable(iter([request] * test_constants.STREAM_LENGTH))
for response in call:
self.assertEqual(request, response)
+ compressed_channel.close()
if __name__ == '__main__':
diff --git a/src/python/grpcio_tests/tests/unit/_empty_message_test.py b/src/python/grpcio_tests/tests/unit/_empty_message_test.py
index 3e8393b53c..f27ea422d0 100644
--- a/src/python/grpcio_tests/tests/unit/_empty_message_test.py
+++ b/src/python/grpcio_tests/tests/unit/_empty_message_test.py
@@ -96,6 +96,7 @@ class EmptyMessageTest(unittest.TestCase):
def tearDown(self):
self._server.stop(0)
+ self._channel.close()
def testUnaryUnary(self):
response = self._channel.unary_unary(_UNARY_UNARY)(_REQUEST)
diff --git a/src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py b/src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py
index 6c551df3ec..81de1dae1d 100644
--- a/src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py
+++ b/src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py
@@ -71,6 +71,7 @@ class ErrorMessageEncodingTest(unittest.TestCase):
def tearDown(self):
self._server.stop(0)
+ self._channel.close()
def testMessageEncoding(self):
for message in _UNICODE_ERROR_MESSAGES:
diff --git a/src/python/grpcio_tests/tests/unit/_interceptor_test.py b/src/python/grpcio_tests/tests/unit/_interceptor_test.py
index 99db0ac58b..a647e5e720 100644
--- a/src/python/grpcio_tests/tests/unit/_interceptor_test.py
+++ b/src/python/grpcio_tests/tests/unit/_interceptor_test.py
@@ -337,6 +337,7 @@ class InterceptorTest(unittest.TestCase):
def tearDown(self):
self._server.stop(None)
self._server_pool.shutdown(wait=True)
+ self._channel.close()
def testTripleRequestMessagesClientInterceptor(self):
diff --git a/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py b/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py
index 0ff49490d5..7ed7c83893 100644
--- a/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py
+++ b/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py
@@ -62,6 +62,9 @@ class InvalidMetadataTest(unittest.TestCase):
self._stream_unary = _stream_unary_multi_callable(self._channel)
self._stream_stream = _stream_stream_multi_callable(self._channel)
+ def tearDown(self):
+ self._channel.close()
+
def testUnaryRequestBlockingUnaryResponse(self):
request = b'\x07\x08'
metadata = (('InVaLiD', 'UnaryRequestBlockingUnaryResponse'),)
diff --git a/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py b/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py
index 00949e2236..e89b521cc5 100644
--- a/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py
+++ b/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py
@@ -215,6 +215,7 @@ class InvocationDefectsTest(unittest.TestCase):
def tearDown(self):
self._server.stop(0)
+ self._channel.close()
def testIterableStreamRequestBlockingUnaryResponse(self):
requests = [b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)]
diff --git a/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py b/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py
index 0dafab827a..a63664ac5d 100644
--- a/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py
+++ b/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py
@@ -198,8 +198,8 @@ class MetadataCodeDetailsTest(unittest.TestCase):
port = self._server.add_insecure_port('[::]:0')
self._server.start()
- channel = grpc.insecure_channel('localhost:{}'.format(port))
- self._unary_unary = channel.unary_unary(
+ self._channel = grpc.insecure_channel('localhost:{}'.format(port))
+ self._unary_unary = self._channel.unary_unary(
'/'.join((
'',
_SERVICE,
@@ -208,17 +208,17 @@ class MetadataCodeDetailsTest(unittest.TestCase):
request_serializer=_REQUEST_SERIALIZER,
response_deserializer=_RESPONSE_DESERIALIZER,
)
- self._unary_stream = channel.unary_stream('/'.join((
+ self._unary_stream = self._channel.unary_stream('/'.join((
'',
_SERVICE,
_UNARY_STREAM,
)),)
- self._stream_unary = channel.stream_unary('/'.join((
+ self._stream_unary = self._channel.stream_unary('/'.join((
'',
_SERVICE,
_STREAM_UNARY,
)),)
- self._stream_stream = channel.stream_stream(
+ self._stream_stream = self._channel.stream_stream(
'/'.join((
'',
_SERVICE,
@@ -228,6 +228,10 @@ class MetadataCodeDetailsTest(unittest.TestCase):
response_deserializer=_RESPONSE_DESERIALIZER,
)
+ def tearDown(self):
+ self._server.stop(None)
+ self._channel.close()
+
def testSuccessfulUnaryUnary(self):
self._servicer.set_details(_DETAILS)
diff --git a/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py b/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py
index 2d352e99d4..7b32b5b5f3 100644
--- a/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py
+++ b/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py
@@ -187,13 +187,14 @@ class MetadataFlagsTest(unittest.TestCase):
def test_call_wait_for_ready_default(self):
for perform_call in _ALL_CALL_CASES:
- self.check_connection_does_failfast(perform_call,
- create_dummy_channel())
+ with create_dummy_channel() as channel:
+ self.check_connection_does_failfast(perform_call, channel)
def test_call_wait_for_ready_disabled(self):
for perform_call in _ALL_CALL_CASES:
- self.check_connection_does_failfast(
- perform_call, create_dummy_channel(), wait_for_ready=False)
+ with create_dummy_channel() as channel:
+ self.check_connection_does_failfast(
+ perform_call, channel, wait_for_ready=False)
def test_call_wait_for_ready_enabled(self):
# To test the wait mechanism, Python thread is required to make
@@ -210,16 +211,16 @@ class MetadataFlagsTest(unittest.TestCase):
wg.done()
def test_call(perform_call):
- try:
- channel = grpc.insecure_channel(addr)
- channel.subscribe(wait_for_transient_failure)
- perform_call(channel, wait_for_ready=True)
- except BaseException as e: # pylint: disable=broad-except
- # If the call failed, the thread would be destroyed. The channel
- # object can be collected before calling the callback, which
- # will result in a deadlock.
- wg.done()
- unhandled_exceptions.put(e, True)
+ with grpc.insecure_channel(addr) as channel:
+ try:
+ channel.subscribe(wait_for_transient_failure)
+ perform_call(channel, wait_for_ready=True)
+ except BaseException as e: # pylint: disable=broad-except
+ # If the call failed, the thread would be destroyed. The
+ # channel object can be collected before calling the
+ # callback, which will result in a deadlock.
+ wg.done()
+ unhandled_exceptions.put(e, True)
test_threads = []
for perform_call in _ALL_CALL_CASES:
diff --git a/src/python/grpcio_tests/tests/unit/_metadata_test.py b/src/python/grpcio_tests/tests/unit/_metadata_test.py
index 777ab683e3..892df3df08 100644
--- a/src/python/grpcio_tests/tests/unit/_metadata_test.py
+++ b/src/python/grpcio_tests/tests/unit/_metadata_test.py
@@ -186,6 +186,7 @@ class MetadataTest(unittest.TestCase):
def tearDown(self):
self._server.stop(0)
+ self._channel.close()
def testUnaryUnary(self):
multi_callable = self._channel.unary_unary(_UNARY_UNARY)
diff --git a/src/python/grpcio_tests/tests/unit/_reconnect_test.py b/src/python/grpcio_tests/tests/unit/_reconnect_test.py
index f6d4fcbd0a..d4ea126e2b 100644
--- a/src/python/grpcio_tests/tests/unit/_reconnect_test.py
+++ b/src/python/grpcio_tests/tests/unit/_reconnect_test.py
@@ -98,6 +98,8 @@ class ReconnectTest(unittest.TestCase):
server.add_insecure_port('[::]:{}'.format(port))
server.start()
self.assertEqual(_RESPONSE, multi_callable(_REQUEST))
+ server.stop(None)
+ channel.close()
if __name__ == '__main__':
diff --git a/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py b/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py
index 4fead8fcd5..517c2d2f97 100644
--- a/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py
+++ b/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py
@@ -148,6 +148,7 @@ class ResourceExhaustedTest(unittest.TestCase):
def tearDown(self):
self._server.stop(0)
+ self._channel.close()
def testUnaryUnary(self):
multi_callable = self._channel.unary_unary(_UNARY_UNARY)
diff --git a/src/python/grpcio_tests/tests/unit/_rpc_test.py b/src/python/grpcio_tests/tests/unit/_rpc_test.py
index a768d6c7c1..a99121cee5 100644
--- a/src/python/grpcio_tests/tests/unit/_rpc_test.py
+++ b/src/python/grpcio_tests/tests/unit/_rpc_test.py
@@ -193,6 +193,7 @@ class RPCTest(unittest.TestCase):
def tearDown(self):
self._server.stop(None)
+ self._channel.close()
def testUnrecognizedMethod(self):
request = b'abc'
diff --git a/src/ruby/end2end/graceful_sig_handling_client.rb b/src/ruby/end2end/graceful_sig_handling_client.rb
new file mode 100755
index 0000000000..14a67a62cc
--- /dev/null
+++ b/src/ruby/end2end/graceful_sig_handling_client.rb
@@ -0,0 +1,61 @@
+#!/usr/bin/env ruby
+
+# Copyright 2015 gRPC authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+require_relative './end2end_common'
+
+# Test client. Sends RPC's as normal but process also has signal handlers
+class SigHandlingClientController < ClientControl::ClientController::Service
+ def initialize(stub)
+ @stub = stub
+ end
+
+ def do_echo_rpc(req, _)
+ response = @stub.echo(Echo::EchoRequest.new(request: req.request))
+ fail 'bad response' unless response.response == req.request
+ ClientControl::Void.new
+ end
+end
+
+def main
+ client_control_port = ''
+ server_port = ''
+ OptionParser.new do |opts|
+ opts.on('--client_control_port=P', String) do |p|
+ client_control_port = p
+ end
+ opts.on('--server_port=P', String) do |p|
+ server_port = p
+ end
+ end.parse!
+
+ # Allow a few seconds to be safe.
+ srv = new_rpc_server_for_testing
+ srv.add_http2_port("0.0.0.0:#{client_control_port}",
+ :this_port_is_insecure)
+ stub = Echo::EchoServer::Stub.new("localhost:#{server_port}",
+ :this_channel_is_insecure)
+ control_service = SigHandlingClientController.new(stub)
+ srv.handle(control_service)
+ server_thread = Thread.new do
+ srv.run_till_terminated_or_interrupted(['int'])
+ end
+ srv.wait_till_running
+ # send a first RPC to notify the parent process that we've started
+ stub.echo(Echo::EchoRequest.new(request: 'client/child started'))
+ server_thread.join
+end
+
+main
diff --git a/src/ruby/end2end/graceful_sig_handling_driver.rb b/src/ruby/end2end/graceful_sig_handling_driver.rb
new file mode 100755
index 0000000000..e12ae28485
--- /dev/null
+++ b/src/ruby/end2end/graceful_sig_handling_driver.rb
@@ -0,0 +1,83 @@
+#!/usr/bin/env ruby
+
+# Copyright 2016 gRPC authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# smoke test for a grpc-using app that receives and
+# handles process-ending signals
+
+require_relative './end2end_common'
+
+# A service that calls back it's received_rpc_callback
+# upon receiving an RPC. Used for synchronization/waiting
+# for child process to start.
+class ClientStartedService < Echo::EchoServer::Service
+ def initialize(received_rpc_callback)
+ @received_rpc_callback = received_rpc_callback
+ end
+
+ def echo(echo_req, _)
+ @received_rpc_callback.call unless @received_rpc_callback.nil?
+ @received_rpc_callback = nil
+ Echo::EchoReply.new(response: echo_req.request)
+ end
+end
+
+def main
+ STDERR.puts 'start server'
+ client_started = false
+ client_started_mu = Mutex.new
+ client_started_cv = ConditionVariable.new
+ received_rpc_callback = proc do
+ client_started_mu.synchronize do
+ client_started = true
+ client_started_cv.signal
+ end
+ end
+
+ client_started_service = ClientStartedService.new(received_rpc_callback)
+ server_runner = ServerRunner.new(client_started_service)
+ server_port = server_runner.run
+ STDERR.puts 'start client'
+ control_stub, client_pid = start_client('graceful_sig_handling_client.rb', server_port)
+
+ client_started_mu.synchronize do
+ client_started_cv.wait(client_started_mu) until client_started
+ end
+
+ control_stub.do_echo_rpc(
+ ClientControl::DoEchoRpcRequest.new(request: 'hello'))
+
+ STDERR.puts 'killing client'
+ Process.kill('SIGINT', client_pid)
+ Process.wait(client_pid)
+ client_exit_status = $CHILD_STATUS
+
+ if client_exit_status.exited?
+ if client_exit_status.exitstatus != 0
+ STDERR.puts 'Client did not close gracefully'
+ exit(1)
+ end
+ else
+ STDERR.puts 'Client did not close gracefully'
+ exit(1)
+ end
+
+ STDERR.puts 'Client ended gracefully'
+
+ # no need to call cleanup, client should already be dead
+ server_runner.stop
+end
+
+main
diff --git a/src/ruby/end2end/graceful_sig_stop_client.rb b/src/ruby/end2end/graceful_sig_stop_client.rb
new file mode 100755
index 0000000000..b672dc3f2a
--- /dev/null
+++ b/src/ruby/end2end/graceful_sig_stop_client.rb
@@ -0,0 +1,78 @@
+#!/usr/bin/env ruby
+
+# Copyright 2015 gRPC authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+require_relative './end2end_common'
+
+# Test client. Sends RPC's as normal but process also has signal handlers
+class SigHandlingClientController < ClientControl::ClientController::Service
+ def initialize(srv, stub)
+ @srv = srv
+ @stub = stub
+ end
+
+ def do_echo_rpc(req, _)
+ response = @stub.echo(Echo::EchoRequest.new(request: req.request))
+ fail 'bad response' unless response.response == req.request
+ ClientControl::Void.new
+ end
+
+ def shutdown(_, _)
+ # Spawn a new thread because RpcServer#stop is
+ # synchronous and blocks until either this RPC has finished,
+ # or the server's "poll_period" seconds have passed.
+ @shutdown_thread = Thread.new do
+ @srv.stop
+ end
+ ClientControl::Void.new
+ end
+
+ def join_shutdown_thread
+ @shutdown_thread.join
+ end
+end
+
+def main
+ client_control_port = ''
+ server_port = ''
+ OptionParser.new do |opts|
+ opts.on('--client_control_port=P', String) do |p|
+ client_control_port = p
+ end
+ opts.on('--server_port=P', String) do |p|
+ server_port = p
+ end
+ end.parse!
+
+ # The "shutdown" RPC should end very quickly.
+ # Allow a few seconds to be safe.
+ srv = new_rpc_server_for_testing(poll_period: 3)
+ srv.add_http2_port("0.0.0.0:#{client_control_port}",
+ :this_port_is_insecure)
+ stub = Echo::EchoServer::Stub.new("localhost:#{server_port}",
+ :this_channel_is_insecure)
+ control_service = SigHandlingClientController.new(srv, stub)
+ srv.handle(control_service)
+ server_thread = Thread.new do
+ srv.run_till_terminated_or_interrupted(['int'])
+ end
+ srv.wait_till_running
+ # send a first RPC to notify the parent process that we've started
+ stub.echo(Echo::EchoRequest.new(request: 'client/child started'))
+ server_thread.join
+ control_service.join_shutdown_thread
+end
+
+main
diff --git a/src/ruby/end2end/graceful_sig_stop_driver.rb b/src/ruby/end2end/graceful_sig_stop_driver.rb
new file mode 100755
index 0000000000..7a132403eb
--- /dev/null
+++ b/src/ruby/end2end/graceful_sig_stop_driver.rb
@@ -0,0 +1,62 @@
+#!/usr/bin/env ruby
+
+# Copyright 2016 gRPC authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# smoke test for a grpc-using app that receives and
+# handles process-ending signals
+
+require_relative './end2end_common'
+
+# A service that calls back it's received_rpc_callback
+# upon receiving an RPC. Used for synchronization/waiting
+# for child process to start.
+class ClientStartedService < Echo::EchoServer::Service
+ def initialize(received_rpc_callback)
+ @received_rpc_callback = received_rpc_callback
+ end
+
+ def echo(echo_req, _)
+ @received_rpc_callback.call unless @received_rpc_callback.nil?
+ @received_rpc_callback = nil
+ Echo::EchoReply.new(response: echo_req.request)
+ end
+end
+
+def main
+ STDERR.puts 'start server'
+ client_started = false
+ client_started_mu = Mutex.new
+ client_started_cv = ConditionVariable.new
+ received_rpc_callback = proc do
+ client_started_mu.synchronize do
+ client_started = true
+ client_started_cv.signal
+ end
+ end
+
+ client_started_service = ClientStartedService.new(received_rpc_callback)
+ server_runner = ServerRunner.new(client_started_service)
+ server_port = server_runner.run
+ STDERR.puts 'start client'
+ control_stub, client_pid = start_client('./graceful_sig_stop_client.rb', server_port)
+
+ client_started_mu.synchronize do
+ client_started_cv.wait(client_started_mu) until client_started
+ end
+
+ cleanup(control_stub, client_pid, server_runner)
+end
+
+main
diff --git a/src/ruby/lib/grpc/generic/rpc_server.rb b/src/ruby/lib/grpc/generic/rpc_server.rb
index 3b5a0ce27f..f0f73dc56e 100644
--- a/src/ruby/lib/grpc/generic/rpc_server.rb
+++ b/src/ruby/lib/grpc/generic/rpc_server.rb
@@ -240,6 +240,13 @@ module GRPC
# the call has no impact if the server is already stopped, otherwise
# server's current call loop is it's last.
def stop
+ # if called via run_till_terminated_or_interrupted,
+ # signal stop_server_thread and dont do anything
+ if @stop_server.nil? == false && @stop_server == false
+ @stop_server = true
+ @stop_server_cv.broadcast
+ return
+ end
@run_mutex.synchronize do
fail 'Cannot stop before starting' if @running_state == :not_started
return if @running_state != :running
@@ -354,6 +361,60 @@ module GRPC
alias_method :run_till_terminated, :run
+ # runs the server with signal handlers
+ # @param signals
+ # List of String, Integer or both representing signals that the user
+ # would like to send to the server for graceful shutdown
+ # @param wait_interval (optional)
+ # Integer seconds that user would like stop_server_thread to poll
+ # stop_server
+ def run_till_terminated_or_interrupted(signals, wait_interval = 60)
+ @stop_server = false
+ @stop_server_mu = Mutex.new
+ @stop_server_cv = ConditionVariable.new
+
+ @stop_server_thread = Thread.new do
+ loop do
+ break if @stop_server
+ @stop_server_mu.synchronize do
+ @stop_server_cv.wait(@stop_server_mu, wait_interval)
+ end
+ end
+
+ # stop is surrounded by mutex, should handle multiple calls to stop
+ # correctly
+ stop
+ end
+
+ valid_signals = Signal.list
+
+ # register signal handlers
+ signals.each do |sig|
+ # input validation
+ if sig.class == String
+ sig.upcase!
+ if sig.start_with?('SIG')
+ # cut out the SIG prefix to see if valid signal
+ sig = sig[3..-1]
+ end
+ end
+
+ # register signal traps for all valid signals
+ if valid_signals.value?(sig) || valid_signals.key?(sig)
+ Signal.trap(sig) do
+ @stop_server = true
+ @stop_server_cv.broadcast
+ end
+ else
+ fail "#{sig} not a valid signal"
+ end
+ end
+
+ run
+
+ @stop_server_thread.join
+ end
+
# Sends RESOURCE_EXHAUSTED if there are too many unprocessed jobs
def available?(an_rpc)
return an_rpc if @pool.ready_for_work?