aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--include/grpc++/impl/codegen/completion_queue.h6
-rw-r--r--include/grpc++/impl/codegen/method_handler_impl.h12
-rw-r--r--include/grpc++/impl/codegen/server_context.h6
-rw-r--r--include/grpc++/server.h18
-rw-r--r--include/grpc++/server_builder.h19
-rw-r--r--src/cpp/client/secure_credentials.cc14
-rw-r--r--src/cpp/server/create_default_thread_pool.cc2
-rw-r--r--src/cpp/server/dynamic_thread_pool.cc54
-rw-r--r--src/cpp/server/dynamic_thread_pool.h20
-rw-r--r--src/cpp/server/secure_server_credentials.cc7
-rw-r--r--src/cpp/server/server_builder.cc7
-rw-r--r--src/cpp/server/server_cc.cc46
-rw-r--r--src/cpp/server/thread_pool_interface.h4
-rw-r--r--src/cpp/thread_manager/thread_manager.cc54
-rw-r--r--src/cpp/thread_manager/thread_manager.h28
-rw-r--r--test/cpp/end2end/thread_stress_test.cc157
-rw-r--r--test/cpp/thread_manager/BUILD31
-rw-r--r--test/cpp/thread_manager/thread_manager_test.cc8
18 files changed, 361 insertions, 132 deletions
diff --git a/include/grpc++/impl/codegen/completion_queue.h b/include/grpc++/impl/codegen/completion_queue.h
index b8a7862578..452eac6646 100644
--- a/include/grpc++/impl/codegen/completion_queue.h
+++ b/include/grpc++/impl/codegen/completion_queue.h
@@ -78,7 +78,8 @@ template <class ServiceType, class RequestType, class ResponseType>
class ServerStreamingHandler;
template <class ServiceType, class RequestType, class ResponseType>
class BidiStreamingHandler;
-class UnknownMethodHandler;
+template <StatusCode code>
+class ErrorMethodHandler;
template <class Streamer, bool WriteNeeded>
class TemplatedBidiStreamingHandler;
template <class InputMessage, class OutputMessage>
@@ -221,7 +222,8 @@ class CompletionQueue : private GrpcLibraryCodegen {
friend class ::grpc::internal::ServerStreamingHandler;
template <class Streamer, bool WriteNeeded>
friend class ::grpc::internal::TemplatedBidiStreamingHandler;
- friend class ::grpc::internal::UnknownMethodHandler;
+ template <StatusCode code>
+ friend class ::grpc::internal::ErrorMethodHandler;
friend class ::grpc::Server;
friend class ::grpc::ServerContext;
friend class ::grpc::ServerInterface;
diff --git a/include/grpc++/impl/codegen/method_handler_impl.h b/include/grpc++/impl/codegen/method_handler_impl.h
index c0af4ca130..d98ab7938c 100644
--- a/include/grpc++/impl/codegen/method_handler_impl.h
+++ b/include/grpc++/impl/codegen/method_handler_impl.h
@@ -242,12 +242,14 @@ class SplitServerStreamingHandler
ServerSplitStreamer<RequestType, ResponseType>, false>(func) {}
};
-/// Handle unknown method by returning UNIMPLEMENTED error.
-class UnknownMethodHandler : public MethodHandler {
+/// General method handler class for errors that prevent real method use
+/// e.g., handle unknown method by returning UNIMPLEMENTED error.
+template <StatusCode code>
+class ErrorMethodHandler : public MethodHandler {
public:
template <class T>
static void FillOps(ServerContext* context, T* ops) {
- Status status(StatusCode::UNIMPLEMENTED, "");
+ Status status(code, "");
if (!context->sent_initial_metadata_) {
ops->SendInitialMetadata(context->initial_metadata_,
context->initial_metadata_flags());
@@ -267,6 +269,10 @@ class UnknownMethodHandler : public MethodHandler {
}
};
+typedef ErrorMethodHandler<StatusCode::UNIMPLEMENTED> UnknownMethodHandler;
+typedef ErrorMethodHandler<StatusCode::RESOURCE_EXHAUSTED>
+ ResourceExhaustedHandler;
+
} // namespace internal
} // namespace grpc
diff --git a/include/grpc++/impl/codegen/server_context.h b/include/grpc++/impl/codegen/server_context.h
index a2d6967bf8..9f20335a2a 100644
--- a/include/grpc++/impl/codegen/server_context.h
+++ b/include/grpc++/impl/codegen/server_context.h
@@ -63,7 +63,8 @@ template <class ServiceType, class RequestType, class ResponseType>
class ServerStreamingHandler;
template <class ServiceType, class RequestType, class ResponseType>
class BidiStreamingHandler;
-class UnknownMethodHandler;
+template <StatusCode code>
+class ErrorMethodHandler;
template <class Streamer, bool WriteNeeded>
class TemplatedBidiStreamingHandler;
class Call;
@@ -255,7 +256,8 @@ class ServerContext {
friend class ::grpc::internal::ServerStreamingHandler;
template <class Streamer, bool WriteNeeded>
friend class ::grpc::internal::TemplatedBidiStreamingHandler;
- friend class ::grpc::internal::UnknownMethodHandler;
+ template <StatusCode code>
+ friend class ::grpc::internal::ErrorMethodHandler;
friend class ::grpc::ClientContext;
/// Prevent copying.
diff --git a/include/grpc++/server.h b/include/grpc++/server.h
index 01c4a60d21..456603e4e7 100644
--- a/include/grpc++/server.h
+++ b/include/grpc++/server.h
@@ -35,6 +35,7 @@
#include <grpc++/support/config.h>
#include <grpc++/support/status.h>
#include <grpc/compression.h>
+#include <grpc/support/thd.h>
struct grpc_server;
@@ -138,10 +139,17 @@ class Server final : public ServerInterface, private GrpcLibraryCodegen {
///
/// \param sync_cq_timeout_msec The timeout to use when calling AsyncNext() on
/// server completion queues passed via sync_server_cqs param.
+ ///
+ /// \param thread_creator The thread creation function for the sync
+ /// server. Typically gpr_thd_new
Server(int max_message_size, ChannelArguments* args,
std::shared_ptr<std::vector<std::unique_ptr<ServerCompletionQueue>>>
sync_server_cqs,
- int min_pollers, int max_pollers, int sync_cq_timeout_msec);
+ int min_pollers, int max_pollers, int sync_cq_timeout_msec,
+ std::function<int(gpr_thd_id*, const char*, void (*)(void*), void*,
+ const gpr_thd_options*)>
+ thread_creator,
+ std::function<void(gpr_thd_id)> thread_joiner);
/// Register a service. This call does not take ownership of the service.
/// The service must exist for the lifetime of the Server instance.
@@ -220,6 +228,14 @@ class Server final : public ServerInterface, private GrpcLibraryCodegen {
std::unique_ptr<HealthCheckServiceInterface> health_check_service_;
bool health_check_service_disabled_;
+
+ std::function<int(gpr_thd_id*, const char*, void (*)(void*), void*,
+ const gpr_thd_options*)>
+ thread_creator_;
+ std::function<void(gpr_thd_id)> thread_joiner_;
+
+ // A special handler for resource exhausted in sync case
+ std::unique_ptr<internal::MethodHandler> resource_exhausted_handler_;
};
} // namespace grpc
diff --git a/include/grpc++/server_builder.h b/include/grpc++/server_builder.h
index e2bae4b41f..25bbacbbc7 100644
--- a/include/grpc++/server_builder.h
+++ b/include/grpc++/server_builder.h
@@ -20,6 +20,7 @@
#define GRPCXX_SERVER_BUILDER_H
#include <climits>
+#include <functional>
#include <map>
#include <memory>
#include <vector>
@@ -30,6 +31,7 @@
#include <grpc++/support/config.h>
#include <grpc/compression.h>
#include <grpc/support/cpu.h>
+#include <grpc/support/thd.h>
#include <grpc/support/useful.h>
#include <grpc/support/workaround_list.h>
@@ -47,6 +49,7 @@ class Service;
namespace testing {
class ServerBuilderPluginTest;
+class ServerBuilderThreadCreatorOverrideTest;
} // namespace testing
/// A builder class for the creation and startup of \a grpc::Server instances.
@@ -213,6 +216,17 @@ class ServerBuilder {
private:
friend class ::grpc::testing::ServerBuilderPluginTest;
+ friend class ::grpc::testing::ServerBuilderThreadCreatorOverrideTest;
+
+ ServerBuilder& SetThreadFunctions(
+ std::function<int(gpr_thd_id*, const char*, void (*)(void*), void*,
+ const gpr_thd_options*)>
+ thread_creator,
+ std::function<void(gpr_thd_id)> thread_joiner) {
+ thread_creator_ = thread_creator;
+ thread_joiner_ = thread_joiner;
+ return *this;
+ }
struct Port {
grpc::string addr;
@@ -272,6 +286,11 @@ class ServerBuilder {
grpc_compression_algorithm algorithm;
} maybe_default_compression_algorithm_;
uint32_t enabled_compression_algorithms_bitset_;
+
+ std::function<int(gpr_thd_id*, const char*, void (*)(void*), void*,
+ const gpr_thd_options*)>
+ thread_creator_;
+ std::function<void(gpr_thd_id)> thread_joiner_;
};
} // namespace grpc
diff --git a/src/cpp/client/secure_credentials.cc b/src/cpp/client/secure_credentials.cc
index 4fb128d98b..94519d817b 100644
--- a/src/cpp/client/secure_credentials.cc
+++ b/src/cpp/client/secure_credentials.cc
@@ -189,10 +189,16 @@ int MetadataCredentialsPluginWrapper::GetMetadata(
}
if (w->plugin_->IsBlocking()) {
// Asynchronous return.
- w->thread_pool_->Add(
- std::bind(&MetadataCredentialsPluginWrapper::InvokePlugin, w, context,
- cb, user_data, nullptr, nullptr, nullptr, nullptr));
- return 0;
+ if (w->thread_pool_->Add(std::bind(
+ &MetadataCredentialsPluginWrapper::InvokePlugin, w, context, cb,
+ user_data, nullptr, nullptr, nullptr, nullptr))) {
+ return 0;
+ } else {
+ *num_creds_md = 0;
+ *status = GRPC_STATUS_RESOURCE_EXHAUSTED;
+ *error_details = nullptr;
+ return true;
+ }
} else {
// Synchronous return.
w->InvokePlugin(context, cb, user_data, creds_md, num_creds_md, status,
diff --git a/src/cpp/server/create_default_thread_pool.cc b/src/cpp/server/create_default_thread_pool.cc
index 8ca3e32c2f..2d2abbe9d1 100644
--- a/src/cpp/server/create_default_thread_pool.cc
+++ b/src/cpp/server/create_default_thread_pool.cc
@@ -28,7 +28,7 @@ namespace {
ThreadPoolInterface* CreateDefaultThreadPoolImpl() {
int cores = gpr_cpu_num_cores();
if (!cores) cores = 4;
- return new DynamicThreadPool(cores);
+ return new DynamicThreadPool(cores, gpr_thd_new, gpr_thd_join);
}
CreateThreadPoolFunc g_ctp_impl = CreateDefaultThreadPoolImpl;
diff --git a/src/cpp/server/dynamic_thread_pool.cc b/src/cpp/server/dynamic_thread_pool.cc
index 81c78fe739..d0e62313f6 100644
--- a/src/cpp/server/dynamic_thread_pool.cc
+++ b/src/cpp/server/dynamic_thread_pool.cc
@@ -19,19 +19,32 @@
#include "src/cpp/server/dynamic_thread_pool.h"
#include <mutex>
-#include <thread>
#include <grpc/support/log.h>
+#include <grpc/support/thd.h>
namespace grpc {
-DynamicThreadPool::DynamicThread::DynamicThread(DynamicThreadPool* pool)
- : pool_(pool),
- thd_(new std::thread(&DynamicThreadPool::DynamicThread::ThreadFunc,
- this)) {}
+DynamicThreadPool::DynamicThread::DynamicThread(DynamicThreadPool* pool,
+ bool* valid)
+ : pool_(pool) {
+ gpr_thd_options opt = gpr_thd_options_default();
+ gpr_thd_options_set_joinable(&opt);
+
+ std::lock_guard<std::mutex> l(dt_mu_);
+ valid_ = *valid = pool->thread_creator_(
+ &thd_, "dynamic thread",
+ [](void* th) {
+ reinterpret_cast<DynamicThreadPool::DynamicThread*>(th)->ThreadFunc();
+ },
+ this, &opt);
+}
+
DynamicThreadPool::DynamicThread::~DynamicThread() {
- thd_->join();
- thd_.reset();
+ std::lock_guard<std::mutex> l(dt_mu_);
+ if (valid_) {
+ pool_->thread_joiner_(thd_);
+ }
}
void DynamicThreadPool::DynamicThread::ThreadFunc() {
@@ -73,15 +86,26 @@ void DynamicThreadPool::ThreadFunc() {
}
}
-DynamicThreadPool::DynamicThreadPool(int reserve_threads)
+DynamicThreadPool::DynamicThreadPool(
+ int reserve_threads,
+ std::function<int(gpr_thd_id*, const char*, void (*)(void*), void*,
+ const gpr_thd_options*)>
+ thread_creator,
+ std::function<void(gpr_thd_id)> thread_joiner)
: shutdown_(false),
reserve_threads_(reserve_threads),
nthreads_(0),
- threads_waiting_(0) {
+ threads_waiting_(0),
+ thread_creator_(thread_creator),
+ thread_joiner_(thread_joiner) {
for (int i = 0; i < reserve_threads_; i++) {
std::lock_guard<std::mutex> lock(mu_);
nthreads_++;
- new DynamicThread(this);
+ bool valid;
+ auto* th = new DynamicThread(this, &valid);
+ if (!valid) {
+ delete th;
+ }
}
}
@@ -101,7 +125,7 @@ DynamicThreadPool::~DynamicThreadPool() {
ReapThreads(&dead_threads_);
}
-void DynamicThreadPool::Add(const std::function<void()>& callback) {
+bool DynamicThreadPool::Add(const std::function<void()>& callback) {
std::lock_guard<std::mutex> lock(mu_);
// Add works to the callbacks list
callbacks_.push(callback);
@@ -109,7 +133,12 @@ void DynamicThreadPool::Add(const std::function<void()>& callback) {
if (threads_waiting_ == 0) {
// Kick off a new thread
nthreads_++;
- new DynamicThread(this);
+ bool valid;
+ auto* th = new DynamicThread(this, &valid);
+ if (!valid) {
+ delete th;
+ return false;
+ }
} else {
cv_.notify_one();
}
@@ -117,6 +146,7 @@ void DynamicThreadPool::Add(const std::function<void()>& callback) {
if (!dead_threads_.empty()) {
ReapThreads(&dead_threads_);
}
+ return true;
}
} // namespace grpc
diff --git a/src/cpp/server/dynamic_thread_pool.h b/src/cpp/server/dynamic_thread_pool.h
index 9237c6e5ca..75d31cd908 100644
--- a/src/cpp/server/dynamic_thread_pool.h
+++ b/src/cpp/server/dynamic_thread_pool.h
@@ -24,9 +24,9 @@
#include <memory>
#include <mutex>
#include <queue>
-#include <thread>
#include <grpc++/support/config.h>
+#include <grpc/support/thd.h>
#include "src/cpp/server/thread_pool_interface.h"
@@ -34,20 +34,26 @@ namespace grpc {
class DynamicThreadPool final : public ThreadPoolInterface {
public:
- explicit DynamicThreadPool(int reserve_threads);
+ DynamicThreadPool(int reserve_threads,
+ std::function<int(gpr_thd_id*, const char*, void (*)(void*),
+ void*, const gpr_thd_options*)>
+ thread_creator,
+ std::function<void(gpr_thd_id)> thread_joiner);
~DynamicThreadPool();
- void Add(const std::function<void()>& callback) override;
+ bool Add(const std::function<void()>& callback) override;
private:
class DynamicThread {
public:
- DynamicThread(DynamicThreadPool* pool);
+ DynamicThread(DynamicThreadPool* pool, bool* valid);
~DynamicThread();
private:
DynamicThreadPool* pool_;
- std::unique_ptr<std::thread> thd_;
+ std::mutex dt_mu_;
+ gpr_thd_id thd_;
+ bool valid_;
void ThreadFunc();
};
std::mutex mu_;
@@ -59,6 +65,10 @@ class DynamicThreadPool final : public ThreadPoolInterface {
int nthreads_;
int threads_waiting_;
std::list<DynamicThread*> dead_threads_;
+ std::function<int(gpr_thd_id*, const char*, void (*)(void*), void*,
+ const gpr_thd_options*)>
+ thread_creator_;
+ std::function<void(gpr_thd_id)> thread_joiner_;
void ThreadFunc();
static void ReapThreads(std::list<DynamicThread*>* tlist);
diff --git a/src/cpp/server/secure_server_credentials.cc b/src/cpp/server/secure_server_credentials.cc
index 0fbe4ccd18..fa08a6200f 100644
--- a/src/cpp/server/secure_server_credentials.cc
+++ b/src/cpp/server/secure_server_credentials.cc
@@ -43,9 +43,14 @@ void AuthMetadataProcessorAyncWrapper::Process(
return;
}
if (w->processor_->IsBlocking()) {
- w->thread_pool_->Add(
+ bool added = w->thread_pool_->Add(
std::bind(&AuthMetadataProcessorAyncWrapper::InvokeProcessor, w,
context, md, num_md, cb, user_data));
+ if (!added) {
+ // no thread available, so fail with temporary resource unavailability
+ cb(user_data, nullptr, 0, nullptr, 0, GRPC_STATUS_UNAVAILABLE, nullptr);
+ return;
+ }
} else {
// invoke directly.
w->InvokeProcessor(context, md, num_md, cb, user_data);
diff --git a/src/cpp/server/server_builder.cc b/src/cpp/server/server_builder.cc
index 200e477822..d91ee7f4e3 100644
--- a/src/cpp/server/server_builder.cc
+++ b/src/cpp/server/server_builder.cc
@@ -23,6 +23,7 @@
#include <grpc++/server.h>
#include <grpc/support/cpu.h>
#include <grpc/support/log.h>
+#include <grpc/support/thd.h>
#include <grpc/support/useful.h>
#include "src/cpp/server/thread_pool_interface.h"
@@ -43,7 +44,9 @@ ServerBuilder::ServerBuilder()
max_send_message_size_(-1),
sync_server_settings_(SyncServerSettings()),
resource_quota_(nullptr),
- generic_service_(nullptr) {
+ generic_service_(nullptr),
+ thread_creator_(gpr_thd_new),
+ thread_joiner_(gpr_thd_join) {
gpr_once_init(&once_init_plugin_list, do_plugin_list_init);
for (auto it = g_plugin_factory_list->begin();
it != g_plugin_factory_list->end(); it++) {
@@ -262,7 +265,7 @@ std::unique_ptr<Server> ServerBuilder::BuildAndStart() {
std::unique_ptr<Server> server(new Server(
max_receive_message_size_, &args, sync_server_cqs,
sync_server_settings_.min_pollers, sync_server_settings_.max_pollers,
- sync_server_settings_.cq_timeout_msec));
+ sync_server_settings_.cq_timeout_msec, thread_creator_, thread_joiner_));
if (has_sync_methods) {
// This is a Sync server
diff --git a/src/cpp/server/server_cc.cc b/src/cpp/server/server_cc.cc
index 4f8f4e06fc..6ab76a287e 100644
--- a/src/cpp/server/server_cc.cc
+++ b/src/cpp/server/server_cc.cc
@@ -36,6 +36,7 @@
#include <grpc/grpc.h>
#include <grpc/support/alloc.h>
#include <grpc/support/log.h>
+#include <grpc/support/thd.h>
#include "src/core/ext/transport/inproc/inproc_transport.h"
#include "src/core/lib/profiling/timers.h"
@@ -196,7 +197,8 @@ class Server::SyncRequest final : public internal::CompletionQueueTag {
ctx_(mrd->deadline_, &mrd->request_metadata_),
has_request_payload_(mrd->has_request_payload_),
request_payload_(mrd->request_payload_),
- method_(mrd->method_) {
+ method_(mrd->method_),
+ server_(server) {
ctx_.set_call(mrd->call_);
ctx_.cq_ = &cq_;
GPR_ASSERT(mrd->in_flight_);
@@ -210,10 +212,13 @@ class Server::SyncRequest final : public internal::CompletionQueueTag {
}
}
- void Run(std::shared_ptr<GlobalCallbacks> global_callbacks) {
+ void Run(std::shared_ptr<GlobalCallbacks> global_callbacks,
+ bool resources) {
ctx_.BeginCompletionOp(&call_);
global_callbacks->PreSynchronousRequest(&ctx_);
- method_->handler()->RunHandler(internal::MethodHandler::HandlerParameter(
+ auto* handler = resources ? method_->handler()
+ : server_->resource_exhausted_handler_.get();
+ handler->RunHandler(internal::MethodHandler::HandlerParameter(
&call_, &ctx_, request_payload_));
global_callbacks->PostSynchronousRequest(&ctx_);
request_payload_ = nullptr;
@@ -235,6 +240,7 @@ class Server::SyncRequest final : public internal::CompletionQueueTag {
const bool has_request_payload_;
grpc_byte_buffer* request_payload_;
internal::RpcServiceMethod* const method_;
+ Server* server_;
};
private:
@@ -255,11 +261,15 @@ class Server::SyncRequest final : public internal::CompletionQueueTag {
// appropriate RPC handlers
class Server::SyncRequestThreadManager : public ThreadManager {
public:
- SyncRequestThreadManager(Server* server, CompletionQueue* server_cq,
- std::shared_ptr<GlobalCallbacks> global_callbacks,
- int min_pollers, int max_pollers,
- int cq_timeout_msec)
- : ThreadManager(min_pollers, max_pollers),
+ SyncRequestThreadManager(
+ Server* server, CompletionQueue* server_cq,
+ std::shared_ptr<GlobalCallbacks> global_callbacks, int min_pollers,
+ int max_pollers, int cq_timeout_msec,
+ std::function<int(gpr_thd_id*, const char*, void (*)(void*), void*,
+ const gpr_thd_options*)>
+ thread_creator,
+ std::function<void(gpr_thd_id)> thread_joiner)
+ : ThreadManager(min_pollers, max_pollers, thread_creator, thread_joiner),
server_(server),
server_cq_(server_cq),
cq_timeout_msec_(cq_timeout_msec),
@@ -285,7 +295,7 @@ class Server::SyncRequestThreadManager : public ThreadManager {
GPR_UNREACHABLE_CODE(return TIMEOUT);
}
- void DoWork(void* tag, bool ok) override {
+ void DoWork(void* tag, bool ok, bool resources) override {
SyncRequest* sync_req = static_cast<SyncRequest*>(tag);
if (!sync_req) {
@@ -305,7 +315,7 @@ class Server::SyncRequestThreadManager : public ThreadManager {
}
GPR_TIMER_SCOPE("cd.Run()", 0);
- cd.Run(global_callbacks_);
+ cd.Run(global_callbacks_, resources);
}
// TODO (sreek) If ok is false here (which it isn't in case of
// grpc_request_registered_call), we should still re-queue the request
@@ -367,7 +377,11 @@ Server::Server(
int max_receive_message_size, ChannelArguments* args,
std::shared_ptr<std::vector<std::unique_ptr<ServerCompletionQueue>>>
sync_server_cqs,
- int min_pollers, int max_pollers, int sync_cq_timeout_msec)
+ int min_pollers, int max_pollers, int sync_cq_timeout_msec,
+ std::function<int(gpr_thd_id*, const char*, void (*)(void*), void*,
+ const gpr_thd_options*)>
+ thread_creator,
+ std::function<void(gpr_thd_id)> thread_joiner)
: max_receive_message_size_(max_receive_message_size),
sync_server_cqs_(sync_server_cqs),
started_(false),
@@ -376,7 +390,9 @@ Server::Server(
has_generic_service_(false),
server_(nullptr),
server_initializer_(new ServerInitializer(this)),
- health_check_service_disabled_(false) {
+ health_check_service_disabled_(false),
+ thread_creator_(thread_creator),
+ thread_joiner_(thread_joiner) {
g_gli_initializer.summon();
gpr_once_init(&g_once_init_callbacks, InitGlobalCallbacks);
global_callbacks_ = g_callbacks;
@@ -386,7 +402,7 @@ Server::Server(
it++) {
sync_req_mgrs_.emplace_back(new SyncRequestThreadManager(
this, (*it).get(), global_callbacks_, min_pollers, max_pollers,
- sync_cq_timeout_msec));
+ sync_cq_timeout_msec, thread_creator_, thread_joiner_));
}
grpc_channel_args channel_args;
@@ -549,6 +565,10 @@ void Server::Start(ServerCompletionQueue** cqs, size_t num_cqs) {
}
}
+ if (!sync_server_cqs_->empty()) {
+ resource_exhausted_handler_.reset(new internal::ResourceExhaustedHandler);
+ }
+
for (auto it = sync_req_mgrs_.begin(); it != sync_req_mgrs_.end(); it++) {
(*it)->Start();
}
diff --git a/src/cpp/server/thread_pool_interface.h b/src/cpp/server/thread_pool_interface.h
index 028842a776..656e6673f1 100644
--- a/src/cpp/server/thread_pool_interface.h
+++ b/src/cpp/server/thread_pool_interface.h
@@ -29,7 +29,9 @@ class ThreadPoolInterface {
virtual ~ThreadPoolInterface() {}
// Schedule the given callback for execution.
- virtual void Add(const std::function<void()>& callback) = 0;
+ // Return true on success, false on failure
+ virtual bool Add(const std::function<void()>& callback)
+ GRPC_MUST_USE_RESULT = 0;
};
// Allows different codebases to use their own thread pool impls
diff --git a/src/cpp/thread_manager/thread_manager.cc b/src/cpp/thread_manager/thread_manager.cc
index 23264f1b5b..107c60f4eb 100644
--- a/src/cpp/thread_manager/thread_manager.cc
+++ b/src/cpp/thread_manager/thread_manager.cc
@@ -20,18 +20,26 @@
#include <climits>
#include <mutex>
-#include <thread>
#include <grpc/support/log.h>
+#include <grpc/support/thd.h>
namespace grpc {
-ThreadManager::WorkerThread::WorkerThread(ThreadManager* thd_mgr)
+ThreadManager::WorkerThread::WorkerThread(ThreadManager* thd_mgr, bool* valid)
: thd_mgr_(thd_mgr) {
+ gpr_thd_options opt = gpr_thd_options_default();
+ gpr_thd_options_set_joinable(&opt);
+
// Make thread creation exclusive with respect to its join happening in
// ~WorkerThread().
std::lock_guard<std::mutex> lock(wt_mu_);
- thd_ = std::thread(&ThreadManager::WorkerThread::Run, this);
+ *valid = valid_ = thd_mgr->thread_creator_(
+ &thd_, "worker thread",
+ [](void* th) {
+ reinterpret_cast<ThreadManager::WorkerThread*>(th)->Run();
+ },
+ this, &opt);
}
void ThreadManager::WorkerThread::Run() {
@@ -42,15 +50,24 @@ void ThreadManager::WorkerThread::Run() {
ThreadManager::WorkerThread::~WorkerThread() {
// Don't join until the thread is fully constructed.
std::lock_guard<std::mutex> lock(wt_mu_);
- thd_.join();
+ if (valid_) {
+ thd_mgr_->thread_joiner_(thd_);
+ }
}
-ThreadManager::ThreadManager(int min_pollers, int max_pollers)
+ThreadManager::ThreadManager(
+ int min_pollers, int max_pollers,
+ std::function<int(gpr_thd_id*, const char*, void (*)(void*), void*,
+ const gpr_thd_options*)>
+ thread_creator,
+ std::function<void(gpr_thd_id)> thread_joiner)
: shutdown_(false),
num_pollers_(0),
min_pollers_(min_pollers),
max_pollers_(max_pollers == -1 ? INT_MAX : max_pollers),
- num_threads_(0) {}
+ num_threads_(0),
+ thread_creator_(thread_creator),
+ thread_joiner_(thread_joiner) {}
ThreadManager::~ThreadManager() {
{
@@ -111,7 +128,9 @@ void ThreadManager::Initialize() {
for (int i = 0; i < min_pollers_; i++) {
// Create a new thread (which ends up calling the MainWorkLoop() function
- new WorkerThread(this);
+ bool valid;
+ new WorkerThread(this, &valid);
+ GPR_ASSERT(valid); // we need to have at least this minimum
}
}
@@ -138,18 +157,27 @@ void ThreadManager::MainWorkLoop() {
case WORK_FOUND:
// If we got work and there are now insufficient pollers, start a new
// one
+ bool resources;
if (!shutdown_ && num_pollers_ < min_pollers_) {
- num_pollers_++;
- num_threads_++;
+ bool valid;
// Drop lock before spawning thread to avoid contention
lock.unlock();
- new WorkerThread(this);
+ auto* th = new WorkerThread(this, &valid);
+ lock.lock();
+ if (valid) {
+ num_pollers_++;
+ num_threads_++;
+ } else {
+ delete th;
+ }
+ resources = (num_pollers_ > 0);
} else {
- // Drop lock for consistency with above branch
- lock.unlock();
+ resources = true;
}
+ // Drop lock before any application work
+ lock.unlock();
// Lock is always released at this point - do the application work
- DoWork(tag, ok);
+ DoWork(tag, ok, resources);
// Take the lock again to check post conditions
lock.lock();
// If we're shutdown, we should finish at this point.
diff --git a/src/cpp/thread_manager/thread_manager.h b/src/cpp/thread_manager/thread_manager.h
index a206e0bd8a..4fa8a6c563 100644
--- a/src/cpp/thread_manager/thread_manager.h
+++ b/src/cpp/thread_manager/thread_manager.h
@@ -23,15 +23,19 @@
#include <list>
#include <memory>
#include <mutex>
-#include <thread>
#include <grpc++/support/config.h>
+#include <grpc/support/thd.h>
namespace grpc {
class ThreadManager {
public:
- explicit ThreadManager(int min_pollers, int max_pollers);
+ ThreadManager(int min_pollers, int max_pollers,
+ std::function<int(gpr_thd_id*, const char*, void (*)(void*),
+ void*, const gpr_thd_options*)>
+ thread_creator,
+ std::function<void(gpr_thd_id)> thread_joiner);
virtual ~ThreadManager();
// Initializes and Starts the Rpc Manager threads
@@ -50,6 +54,8 @@ class ThreadManager {
// - ThreadManager does not interpret the values of 'tag' and 'ok'
// - ThreadManager WILL call DoWork() and pass '*tag' and 'ok' as input to
// DoWork()
+ // - ThreadManager will also pass DoWork a bool saying if there are actually
+ // resources to do the work
//
// If the return value is SHUTDOWN:,
// - ThreadManager WILL NOT call DoWork() and terminates the thead
@@ -69,7 +75,7 @@ class ThreadManager {
// The implementation of DoWork() should also do any setup needed to ensure
// that the next call to PollForWork() (not necessarily by the current thread)
// actually finds some work
- virtual void DoWork(void* tag, bool ok) = 0;
+ virtual void DoWork(void* tag, bool ok, bool resources) = 0;
// Mark the ThreadManager as shutdown and begin draining the work. This is a
// non-blocking call and the caller should call Wait(), a blocking call which
@@ -84,15 +90,15 @@ class ThreadManager {
virtual void Wait();
private:
- // Helper wrapper class around std::thread. This takes a ThreadManager object
- // and starts a new std::thread to calls the Run() function.
+ // Helper wrapper class around thread. This takes a ThreadManager object
+ // and starts a new thread to calls the Run() function.
//
// The Run() function calls ThreadManager::MainWorkLoop() function and once
// that completes, it marks the WorkerThread completed by calling
// ThreadManager::MarkAsCompleted()
class WorkerThread {
public:
- WorkerThread(ThreadManager* thd_mgr);
+ WorkerThread(ThreadManager* thd_mgr, bool* valid);
~WorkerThread();
private:
@@ -102,7 +108,8 @@ class ThreadManager {
ThreadManager* const thd_mgr_;
std::mutex wt_mu_;
- std::thread thd_;
+ gpr_thd_id thd_;
+ bool valid_;
};
// The main funtion in ThreadManager
@@ -129,6 +136,13 @@ class ThreadManager {
// currently polling i.e num_pollers_)
int num_threads_;
+ // Functions for creating/joining threads. Normally, these should
+ // be gpr_thd_new/gpr_thd_join but they are overridable
+ std::function<int(gpr_thd_id*, const char*, void (*)(void*), void*,
+ const gpr_thd_options*)>
+ thread_creator_;
+ std::function<void(gpr_thd_id)> thread_joiner_;
+
std::mutex list_mu_;
std::list<WorkerThread*> completed_threads_;
};
diff --git a/test/cpp/end2end/thread_stress_test.cc b/test/cpp/end2end/thread_stress_test.cc
index 90b2eddbbb..fd43c8f584 100644
--- a/test/cpp/end2end/thread_stress_test.cc
+++ b/test/cpp/end2end/thread_stress_test.cc
@@ -26,6 +26,7 @@
#include <grpc++/server_builder.h>
#include <grpc++/server_context.h>
#include <grpc/grpc.h>
+#include <grpc/support/atm.h>
#include <grpc/support/thd.h>
#include <grpc/support/time.h>
@@ -52,63 +53,13 @@ namespace testing {
class TestServiceImpl : public ::grpc::testing::EchoTestService::Service {
public:
- TestServiceImpl() : signal_client_(false) {}
+ TestServiceImpl() {}
Status Echo(ServerContext* context, const EchoRequest* request,
EchoResponse* response) override {
response->set_message(request->message());
return Status::OK;
}
-
- // Unimplemented is left unimplemented to test the returned error.
-
- Status RequestStream(ServerContext* context,
- ServerReader<EchoRequest>* reader,
- EchoResponse* response) override {
- EchoRequest request;
- response->set_message("");
- while (reader->Read(&request)) {
- response->mutable_message()->append(request.message());
- }
- return Status::OK;
- }
-
- // Return 3 messages.
- // TODO(yangg) make it generic by adding a parameter into EchoRequest
- Status ResponseStream(ServerContext* context, const EchoRequest* request,
- ServerWriter<EchoResponse>* writer) override {
- EchoResponse response;
- response.set_message(request->message() + "0");
- writer->Write(response);
- response.set_message(request->message() + "1");
- writer->Write(response);
- response.set_message(request->message() + "2");
- writer->Write(response);
-
- return Status::OK;
- }
-
- Status BidiStream(
- ServerContext* context,
- ServerReaderWriter<EchoResponse, EchoRequest>* stream) override {
- EchoRequest request;
- EchoResponse response;
- while (stream->Read(&request)) {
- gpr_log(GPR_INFO, "recv msg %s", request.message().c_str());
- response.set_message(request.message());
- stream->Write(response);
- }
- return Status::OK;
- }
-
- bool signal_client() {
- std::unique_lock<std::mutex> lock(mu_);
- return signal_client_;
- }
-
- private:
- bool signal_client_;
- std::mutex mu_;
};
template <class Service>
@@ -119,10 +70,15 @@ class CommonStressTest {
virtual void SetUp() = 0;
virtual void TearDown() = 0;
virtual void ResetStub() = 0;
+ virtual bool AllowExhaustion() = 0;
grpc::testing::EchoTestService::Stub* GetStub() { return stub_.get(); }
protected:
std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
+ // Some tests use a custom thread creator. This should be declared before the
+ // server so that it's destructor happens after the server
+ std::unique_ptr<ServerBuilderThreadCreatorOverrideTest> creator_;
+
std::unique_ptr<Server> server_;
virtual void SetUpStart(ServerBuilder* builder, Service* service) = 0;
@@ -147,6 +103,7 @@ class CommonStressTestInsecure : public CommonStressTest<Service> {
CreateChannel(server_address_.str(), InsecureChannelCredentials());
this->stub_ = grpc::testing::EchoTestService::NewStub(channel);
}
+ bool AllowExhaustion() override { return false; }
protected:
void SetUpStart(ServerBuilder* builder, Service* service) override {
@@ -162,7 +119,7 @@ class CommonStressTestInsecure : public CommonStressTest<Service> {
std::ostringstream server_address_;
};
-template <class Service>
+template <class Service, bool allow_resource_exhaustion>
class CommonStressTestInproc : public CommonStressTest<Service> {
public:
void ResetStub() override {
@@ -170,6 +127,7 @@ class CommonStressTestInproc : public CommonStressTest<Service> {
std::shared_ptr<Channel> channel = this->server_->InProcessChannel(args);
this->stub_ = grpc::testing::EchoTestService::NewStub(channel);
}
+ bool AllowExhaustion() override { return allow_resource_exhaustion; }
protected:
void SetUpStart(ServerBuilder* builder, Service* service) override {
@@ -194,6 +152,67 @@ class CommonStressTestSyncServer : public BaseClass {
TestServiceImpl service_;
};
+class ServerBuilderThreadCreatorOverrideTest {
+ public:
+ ServerBuilderThreadCreatorOverrideTest(ServerBuilder* builder, size_t limit)
+ : limit_(limit), threads_(0) {
+ builder->SetThreadFunctions(
+ [this](gpr_thd_id* id, const char* name, void (*f)(void*), void* arg,
+ const gpr_thd_options* options) -> int {
+ std::unique_lock<std::mutex> l(mu_);
+ if (threads_ < limit_) {
+ l.unlock();
+ if (gpr_thd_new(id, name, f, arg, options) != 0) {
+ l.lock();
+ threads_++;
+ return 1;
+ }
+ }
+ return 0;
+ },
+ [this](gpr_thd_id id) {
+ gpr_thd_join(id);
+ std::unique_lock<std::mutex> l(mu_);
+ threads_--;
+ if (threads_ == 0) {
+ done_.notify_one();
+ }
+ });
+ }
+ ~ServerBuilderThreadCreatorOverrideTest() {
+ // Don't allow destruction until all threads are really done and uncounted
+ std::unique_lock<std::mutex> l(mu_);
+ done_.wait(l, [this] { return (threads_ == 0); });
+ }
+
+ private:
+ size_t limit_;
+ size_t threads_;
+ std::mutex mu_;
+ std::condition_variable done_;
+};
+
+template <class BaseClass>
+class CommonStressTestSyncServerLowThreadCount : public BaseClass {
+ public:
+ void SetUp() override {
+ ServerBuilder builder;
+ this->SetUpStart(&builder, &service_);
+ builder.SetSyncServerOption(ServerBuilder::SyncServerOption::MIN_POLLERS,
+ 1);
+ this->creator_.reset(
+ new ServerBuilderThreadCreatorOverrideTest(&builder, 4));
+ this->SetUpEnd(&builder);
+ }
+ void TearDown() override {
+ this->TearDownStart();
+ this->TearDownEnd();
+ }
+
+ private:
+ TestServiceImpl service_;
+};
+
template <class BaseClass>
class CommonStressTestAsyncServer : public BaseClass {
public:
@@ -294,7 +313,8 @@ class End2endTest : public ::testing::Test {
Common common_;
};
-static void SendRpc(grpc::testing::EchoTestService::Stub* stub, int num_rpcs) {
+static void SendRpc(grpc::testing::EchoTestService::Stub* stub, int num_rpcs,
+ bool allow_exhaustion, gpr_atm* errors) {
EchoRequest request;
EchoResponse response;
request.set_message("Hello");
@@ -302,33 +322,48 @@ static void SendRpc(grpc::testing::EchoTestService::Stub* stub, int num_rpcs) {
for (int i = 0; i < num_rpcs; ++i) {
ClientContext context;
Status s = stub->Echo(&context, request, &response);
- EXPECT_EQ(response.message(), request.message());
+ EXPECT_TRUE(s.ok() || (allow_exhaustion &&
+ s.error_code() == StatusCode::RESOURCE_EXHAUSTED));
if (!s.ok()) {
- gpr_log(GPR_ERROR, "RPC error: %d: %s", s.error_code(),
- s.error_message().c_str());
+ if (!(allow_exhaustion &&
+ s.error_code() == StatusCode::RESOURCE_EXHAUSTED)) {
+ gpr_log(GPR_ERROR, "RPC error: %d: %s", s.error_code(),
+ s.error_message().c_str());
+ }
+ gpr_atm_no_barrier_fetch_add(errors, static_cast<gpr_atm>(1));
+ } else {
+ EXPECT_EQ(response.message(), request.message());
}
- ASSERT_TRUE(s.ok());
}
}
typedef ::testing::Types<
CommonStressTestSyncServer<CommonStressTestInsecure<TestServiceImpl>>,
- CommonStressTestSyncServer<CommonStressTestInproc<TestServiceImpl>>,
+ CommonStressTestSyncServer<CommonStressTestInproc<TestServiceImpl, false>>,
+ CommonStressTestSyncServerLowThreadCount<
+ CommonStressTestInproc<TestServiceImpl, true>>,
CommonStressTestAsyncServer<
CommonStressTestInsecure<grpc::testing::EchoTestService::AsyncService>>,
- CommonStressTestAsyncServer<
- CommonStressTestInproc<grpc::testing::EchoTestService::AsyncService>>>
+ CommonStressTestAsyncServer<CommonStressTestInproc<
+ grpc::testing::EchoTestService::AsyncService, false>>>
CommonTypes;
TYPED_TEST_CASE(End2endTest, CommonTypes);
TYPED_TEST(End2endTest, ThreadStress) {
this->common_.ResetStub();
std::vector<std::thread> threads;
+ gpr_atm errors;
+ gpr_atm_rel_store(&errors, static_cast<gpr_atm>(0));
for (int i = 0; i < kNumThreads; ++i) {
- threads.emplace_back(SendRpc, this->common_.GetStub(), kNumRpcs);
+ threads.emplace_back(SendRpc, this->common_.GetStub(), kNumRpcs,
+ this->common_.AllowExhaustion(), &errors);
}
for (int i = 0; i < kNumThreads; ++i) {
threads[i].join();
}
+ uint64_t error_cnt = static_cast<uint64_t>(gpr_atm_no_barrier_load(&errors));
+ if (error_cnt != 0) {
+ gpr_log(GPR_INFO, "RPC error count: %" PRIu64, error_cnt);
+ }
}
template <class Common>
diff --git a/test/cpp/thread_manager/BUILD b/test/cpp/thread_manager/BUILD
new file mode 100644
index 0000000000..1f0878770b
--- /dev/null
+++ b/test/cpp/thread_manager/BUILD
@@ -0,0 +1,31 @@
+# Copyright 2017 gRPC authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+licenses(["notice"]) # Apache v2
+
+load("//bazel:grpc_build_system.bzl", "grpc_cc_library", "grpc_cc_test", "grpc_package")
+
+grpc_package(name = "test/cpp/thread_manager")
+
+grpc_cc_test(
+ name = "thread_manager_test",
+ srcs = ["thread_manager_test.cc"],
+ deps = [
+ "//:gpr",
+ "//:grpc",
+ "//:grpc++",
+ "//test/cpp/util:test_config",
+ ],
+)
+
diff --git a/test/cpp/thread_manager/thread_manager_test.cc b/test/cpp/thread_manager/thread_manager_test.cc
index 8282d46694..d3d31f9dd9 100644
--- a/test/cpp/thread_manager/thread_manager_test.cc
+++ b/test/cpp/thread_manager/thread_manager_test.cc
@@ -20,10 +20,10 @@
#include <memory>
#include <string>
-#include <gflags/gflags.h>
#include <grpc++/grpc++.h>
#include <grpc/support/log.h>
#include <grpc/support/port_platform.h>
+#include <grpc/support/thd.h>
#include "src/cpp/thread_manager/thread_manager.h"
#include "test/cpp/util/test_config.h"
@@ -32,13 +32,13 @@ namespace grpc {
class ThreadManagerTest final : public grpc::ThreadManager {
public:
ThreadManagerTest()
- : ThreadManager(kMinPollers, kMaxPollers),
+ : ThreadManager(kMinPollers, kMaxPollers, gpr_thd_new, gpr_thd_join),
num_do_work_(0),
num_poll_for_work_(0),
num_work_found_(0) {}
grpc::ThreadManager::WorkStatus PollForWork(void** tag, bool* ok) override;
- void DoWork(void* tag, bool ok) override;
+ void DoWork(void* tag, bool ok, bool resources) override;
void PerformTest();
private:
@@ -89,7 +89,7 @@ grpc::ThreadManager::WorkStatus ThreadManagerTest::PollForWork(void** tag,
}
}
-void ThreadManagerTest::DoWork(void* tag, bool ok) {
+void ThreadManagerTest::DoWork(void* tag, bool ok, bool resources) {
gpr_atm_no_barrier_fetch_add(&num_do_work_, 1);
SleepForMs(kDoWorkDurationMsec); // Simulate doing work by sleeping
}