diff options
-rw-r--r-- | include/grpc++/impl/codegen/completion_queue.h | 6 | ||||
-rw-r--r-- | include/grpc++/impl/codegen/method_handler_impl.h | 12 | ||||
-rw-r--r-- | include/grpc++/impl/codegen/server_context.h | 6 | ||||
-rw-r--r-- | include/grpc++/server.h | 18 | ||||
-rw-r--r-- | include/grpc++/server_builder.h | 19 | ||||
-rw-r--r-- | src/cpp/client/secure_credentials.cc | 14 | ||||
-rw-r--r-- | src/cpp/server/create_default_thread_pool.cc | 2 | ||||
-rw-r--r-- | src/cpp/server/dynamic_thread_pool.cc | 54 | ||||
-rw-r--r-- | src/cpp/server/dynamic_thread_pool.h | 20 | ||||
-rw-r--r-- | src/cpp/server/secure_server_credentials.cc | 7 | ||||
-rw-r--r-- | src/cpp/server/server_builder.cc | 7 | ||||
-rw-r--r-- | src/cpp/server/server_cc.cc | 46 | ||||
-rw-r--r-- | src/cpp/server/thread_pool_interface.h | 4 | ||||
-rw-r--r-- | src/cpp/thread_manager/thread_manager.cc | 54 | ||||
-rw-r--r-- | src/cpp/thread_manager/thread_manager.h | 28 | ||||
-rw-r--r-- | test/cpp/end2end/thread_stress_test.cc | 157 | ||||
-rw-r--r-- | test/cpp/thread_manager/BUILD | 31 | ||||
-rw-r--r-- | test/cpp/thread_manager/thread_manager_test.cc | 8 |
18 files changed, 361 insertions, 132 deletions
diff --git a/include/grpc++/impl/codegen/completion_queue.h b/include/grpc++/impl/codegen/completion_queue.h index b8a7862578..452eac6646 100644 --- a/include/grpc++/impl/codegen/completion_queue.h +++ b/include/grpc++/impl/codegen/completion_queue.h @@ -78,7 +78,8 @@ template <class ServiceType, class RequestType, class ResponseType> class ServerStreamingHandler; template <class ServiceType, class RequestType, class ResponseType> class BidiStreamingHandler; -class UnknownMethodHandler; +template <StatusCode code> +class ErrorMethodHandler; template <class Streamer, bool WriteNeeded> class TemplatedBidiStreamingHandler; template <class InputMessage, class OutputMessage> @@ -221,7 +222,8 @@ class CompletionQueue : private GrpcLibraryCodegen { friend class ::grpc::internal::ServerStreamingHandler; template <class Streamer, bool WriteNeeded> friend class ::grpc::internal::TemplatedBidiStreamingHandler; - friend class ::grpc::internal::UnknownMethodHandler; + template <StatusCode code> + friend class ::grpc::internal::ErrorMethodHandler; friend class ::grpc::Server; friend class ::grpc::ServerContext; friend class ::grpc::ServerInterface; diff --git a/include/grpc++/impl/codegen/method_handler_impl.h b/include/grpc++/impl/codegen/method_handler_impl.h index c0af4ca130..d98ab7938c 100644 --- a/include/grpc++/impl/codegen/method_handler_impl.h +++ b/include/grpc++/impl/codegen/method_handler_impl.h @@ -242,12 +242,14 @@ class SplitServerStreamingHandler ServerSplitStreamer<RequestType, ResponseType>, false>(func) {} }; -/// Handle unknown method by returning UNIMPLEMENTED error. -class UnknownMethodHandler : public MethodHandler { +/// General method handler class for errors that prevent real method use +/// e.g., handle unknown method by returning UNIMPLEMENTED error. +template <StatusCode code> +class ErrorMethodHandler : public MethodHandler { public: template <class T> static void FillOps(ServerContext* context, T* ops) { - Status status(StatusCode::UNIMPLEMENTED, ""); + Status status(code, ""); if (!context->sent_initial_metadata_) { ops->SendInitialMetadata(context->initial_metadata_, context->initial_metadata_flags()); @@ -267,6 +269,10 @@ class UnknownMethodHandler : public MethodHandler { } }; +typedef ErrorMethodHandler<StatusCode::UNIMPLEMENTED> UnknownMethodHandler; +typedef ErrorMethodHandler<StatusCode::RESOURCE_EXHAUSTED> + ResourceExhaustedHandler; + } // namespace internal } // namespace grpc diff --git a/include/grpc++/impl/codegen/server_context.h b/include/grpc++/impl/codegen/server_context.h index a2d6967bf8..9f20335a2a 100644 --- a/include/grpc++/impl/codegen/server_context.h +++ b/include/grpc++/impl/codegen/server_context.h @@ -63,7 +63,8 @@ template <class ServiceType, class RequestType, class ResponseType> class ServerStreamingHandler; template <class ServiceType, class RequestType, class ResponseType> class BidiStreamingHandler; -class UnknownMethodHandler; +template <StatusCode code> +class ErrorMethodHandler; template <class Streamer, bool WriteNeeded> class TemplatedBidiStreamingHandler; class Call; @@ -255,7 +256,8 @@ class ServerContext { friend class ::grpc::internal::ServerStreamingHandler; template <class Streamer, bool WriteNeeded> friend class ::grpc::internal::TemplatedBidiStreamingHandler; - friend class ::grpc::internal::UnknownMethodHandler; + template <StatusCode code> + friend class ::grpc::internal::ErrorMethodHandler; friend class ::grpc::ClientContext; /// Prevent copying. diff --git a/include/grpc++/server.h b/include/grpc++/server.h index 01c4a60d21..456603e4e7 100644 --- a/include/grpc++/server.h +++ b/include/grpc++/server.h @@ -35,6 +35,7 @@ #include <grpc++/support/config.h> #include <grpc++/support/status.h> #include <grpc/compression.h> +#include <grpc/support/thd.h> struct grpc_server; @@ -138,10 +139,17 @@ class Server final : public ServerInterface, private GrpcLibraryCodegen { /// /// \param sync_cq_timeout_msec The timeout to use when calling AsyncNext() on /// server completion queues passed via sync_server_cqs param. + /// + /// \param thread_creator The thread creation function for the sync + /// server. Typically gpr_thd_new Server(int max_message_size, ChannelArguments* args, std::shared_ptr<std::vector<std::unique_ptr<ServerCompletionQueue>>> sync_server_cqs, - int min_pollers, int max_pollers, int sync_cq_timeout_msec); + int min_pollers, int max_pollers, int sync_cq_timeout_msec, + std::function<int(gpr_thd_id*, const char*, void (*)(void*), void*, + const gpr_thd_options*)> + thread_creator, + std::function<void(gpr_thd_id)> thread_joiner); /// Register a service. This call does not take ownership of the service. /// The service must exist for the lifetime of the Server instance. @@ -220,6 +228,14 @@ class Server final : public ServerInterface, private GrpcLibraryCodegen { std::unique_ptr<HealthCheckServiceInterface> health_check_service_; bool health_check_service_disabled_; + + std::function<int(gpr_thd_id*, const char*, void (*)(void*), void*, + const gpr_thd_options*)> + thread_creator_; + std::function<void(gpr_thd_id)> thread_joiner_; + + // A special handler for resource exhausted in sync case + std::unique_ptr<internal::MethodHandler> resource_exhausted_handler_; }; } // namespace grpc diff --git a/include/grpc++/server_builder.h b/include/grpc++/server_builder.h index e2bae4b41f..25bbacbbc7 100644 --- a/include/grpc++/server_builder.h +++ b/include/grpc++/server_builder.h @@ -20,6 +20,7 @@ #define GRPCXX_SERVER_BUILDER_H #include <climits> +#include <functional> #include <map> #include <memory> #include <vector> @@ -30,6 +31,7 @@ #include <grpc++/support/config.h> #include <grpc/compression.h> #include <grpc/support/cpu.h> +#include <grpc/support/thd.h> #include <grpc/support/useful.h> #include <grpc/support/workaround_list.h> @@ -47,6 +49,7 @@ class Service; namespace testing { class ServerBuilderPluginTest; +class ServerBuilderThreadCreatorOverrideTest; } // namespace testing /// A builder class for the creation and startup of \a grpc::Server instances. @@ -213,6 +216,17 @@ class ServerBuilder { private: friend class ::grpc::testing::ServerBuilderPluginTest; + friend class ::grpc::testing::ServerBuilderThreadCreatorOverrideTest; + + ServerBuilder& SetThreadFunctions( + std::function<int(gpr_thd_id*, const char*, void (*)(void*), void*, + const gpr_thd_options*)> + thread_creator, + std::function<void(gpr_thd_id)> thread_joiner) { + thread_creator_ = thread_creator; + thread_joiner_ = thread_joiner; + return *this; + } struct Port { grpc::string addr; @@ -272,6 +286,11 @@ class ServerBuilder { grpc_compression_algorithm algorithm; } maybe_default_compression_algorithm_; uint32_t enabled_compression_algorithms_bitset_; + + std::function<int(gpr_thd_id*, const char*, void (*)(void*), void*, + const gpr_thd_options*)> + thread_creator_; + std::function<void(gpr_thd_id)> thread_joiner_; }; } // namespace grpc diff --git a/src/cpp/client/secure_credentials.cc b/src/cpp/client/secure_credentials.cc index 4fb128d98b..94519d817b 100644 --- a/src/cpp/client/secure_credentials.cc +++ b/src/cpp/client/secure_credentials.cc @@ -189,10 +189,16 @@ int MetadataCredentialsPluginWrapper::GetMetadata( } if (w->plugin_->IsBlocking()) { // Asynchronous return. - w->thread_pool_->Add( - std::bind(&MetadataCredentialsPluginWrapper::InvokePlugin, w, context, - cb, user_data, nullptr, nullptr, nullptr, nullptr)); - return 0; + if (w->thread_pool_->Add(std::bind( + &MetadataCredentialsPluginWrapper::InvokePlugin, w, context, cb, + user_data, nullptr, nullptr, nullptr, nullptr))) { + return 0; + } else { + *num_creds_md = 0; + *status = GRPC_STATUS_RESOURCE_EXHAUSTED; + *error_details = nullptr; + return true; + } } else { // Synchronous return. w->InvokePlugin(context, cb, user_data, creds_md, num_creds_md, status, diff --git a/src/cpp/server/create_default_thread_pool.cc b/src/cpp/server/create_default_thread_pool.cc index 8ca3e32c2f..2d2abbe9d1 100644 --- a/src/cpp/server/create_default_thread_pool.cc +++ b/src/cpp/server/create_default_thread_pool.cc @@ -28,7 +28,7 @@ namespace { ThreadPoolInterface* CreateDefaultThreadPoolImpl() { int cores = gpr_cpu_num_cores(); if (!cores) cores = 4; - return new DynamicThreadPool(cores); + return new DynamicThreadPool(cores, gpr_thd_new, gpr_thd_join); } CreateThreadPoolFunc g_ctp_impl = CreateDefaultThreadPoolImpl; diff --git a/src/cpp/server/dynamic_thread_pool.cc b/src/cpp/server/dynamic_thread_pool.cc index 81c78fe739..d0e62313f6 100644 --- a/src/cpp/server/dynamic_thread_pool.cc +++ b/src/cpp/server/dynamic_thread_pool.cc @@ -19,19 +19,32 @@ #include "src/cpp/server/dynamic_thread_pool.h" #include <mutex> -#include <thread> #include <grpc/support/log.h> +#include <grpc/support/thd.h> namespace grpc { -DynamicThreadPool::DynamicThread::DynamicThread(DynamicThreadPool* pool) - : pool_(pool), - thd_(new std::thread(&DynamicThreadPool::DynamicThread::ThreadFunc, - this)) {} +DynamicThreadPool::DynamicThread::DynamicThread(DynamicThreadPool* pool, + bool* valid) + : pool_(pool) { + gpr_thd_options opt = gpr_thd_options_default(); + gpr_thd_options_set_joinable(&opt); + + std::lock_guard<std::mutex> l(dt_mu_); + valid_ = *valid = pool->thread_creator_( + &thd_, "dynamic thread", + [](void* th) { + reinterpret_cast<DynamicThreadPool::DynamicThread*>(th)->ThreadFunc(); + }, + this, &opt); +} + DynamicThreadPool::DynamicThread::~DynamicThread() { - thd_->join(); - thd_.reset(); + std::lock_guard<std::mutex> l(dt_mu_); + if (valid_) { + pool_->thread_joiner_(thd_); + } } void DynamicThreadPool::DynamicThread::ThreadFunc() { @@ -73,15 +86,26 @@ void DynamicThreadPool::ThreadFunc() { } } -DynamicThreadPool::DynamicThreadPool(int reserve_threads) +DynamicThreadPool::DynamicThreadPool( + int reserve_threads, + std::function<int(gpr_thd_id*, const char*, void (*)(void*), void*, + const gpr_thd_options*)> + thread_creator, + std::function<void(gpr_thd_id)> thread_joiner) : shutdown_(false), reserve_threads_(reserve_threads), nthreads_(0), - threads_waiting_(0) { + threads_waiting_(0), + thread_creator_(thread_creator), + thread_joiner_(thread_joiner) { for (int i = 0; i < reserve_threads_; i++) { std::lock_guard<std::mutex> lock(mu_); nthreads_++; - new DynamicThread(this); + bool valid; + auto* th = new DynamicThread(this, &valid); + if (!valid) { + delete th; + } } } @@ -101,7 +125,7 @@ DynamicThreadPool::~DynamicThreadPool() { ReapThreads(&dead_threads_); } -void DynamicThreadPool::Add(const std::function<void()>& callback) { +bool DynamicThreadPool::Add(const std::function<void()>& callback) { std::lock_guard<std::mutex> lock(mu_); // Add works to the callbacks list callbacks_.push(callback); @@ -109,7 +133,12 @@ void DynamicThreadPool::Add(const std::function<void()>& callback) { if (threads_waiting_ == 0) { // Kick off a new thread nthreads_++; - new DynamicThread(this); + bool valid; + auto* th = new DynamicThread(this, &valid); + if (!valid) { + delete th; + return false; + } } else { cv_.notify_one(); } @@ -117,6 +146,7 @@ void DynamicThreadPool::Add(const std::function<void()>& callback) { if (!dead_threads_.empty()) { ReapThreads(&dead_threads_); } + return true; } } // namespace grpc diff --git a/src/cpp/server/dynamic_thread_pool.h b/src/cpp/server/dynamic_thread_pool.h index 9237c6e5ca..75d31cd908 100644 --- a/src/cpp/server/dynamic_thread_pool.h +++ b/src/cpp/server/dynamic_thread_pool.h @@ -24,9 +24,9 @@ #include <memory> #include <mutex> #include <queue> -#include <thread> #include <grpc++/support/config.h> +#include <grpc/support/thd.h> #include "src/cpp/server/thread_pool_interface.h" @@ -34,20 +34,26 @@ namespace grpc { class DynamicThreadPool final : public ThreadPoolInterface { public: - explicit DynamicThreadPool(int reserve_threads); + DynamicThreadPool(int reserve_threads, + std::function<int(gpr_thd_id*, const char*, void (*)(void*), + void*, const gpr_thd_options*)> + thread_creator, + std::function<void(gpr_thd_id)> thread_joiner); ~DynamicThreadPool(); - void Add(const std::function<void()>& callback) override; + bool Add(const std::function<void()>& callback) override; private: class DynamicThread { public: - DynamicThread(DynamicThreadPool* pool); + DynamicThread(DynamicThreadPool* pool, bool* valid); ~DynamicThread(); private: DynamicThreadPool* pool_; - std::unique_ptr<std::thread> thd_; + std::mutex dt_mu_; + gpr_thd_id thd_; + bool valid_; void ThreadFunc(); }; std::mutex mu_; @@ -59,6 +65,10 @@ class DynamicThreadPool final : public ThreadPoolInterface { int nthreads_; int threads_waiting_; std::list<DynamicThread*> dead_threads_; + std::function<int(gpr_thd_id*, const char*, void (*)(void*), void*, + const gpr_thd_options*)> + thread_creator_; + std::function<void(gpr_thd_id)> thread_joiner_; void ThreadFunc(); static void ReapThreads(std::list<DynamicThread*>* tlist); diff --git a/src/cpp/server/secure_server_credentials.cc b/src/cpp/server/secure_server_credentials.cc index 0fbe4ccd18..fa08a6200f 100644 --- a/src/cpp/server/secure_server_credentials.cc +++ b/src/cpp/server/secure_server_credentials.cc @@ -43,9 +43,14 @@ void AuthMetadataProcessorAyncWrapper::Process( return; } if (w->processor_->IsBlocking()) { - w->thread_pool_->Add( + bool added = w->thread_pool_->Add( std::bind(&AuthMetadataProcessorAyncWrapper::InvokeProcessor, w, context, md, num_md, cb, user_data)); + if (!added) { + // no thread available, so fail with temporary resource unavailability + cb(user_data, nullptr, 0, nullptr, 0, GRPC_STATUS_UNAVAILABLE, nullptr); + return; + } } else { // invoke directly. w->InvokeProcessor(context, md, num_md, cb, user_data); diff --git a/src/cpp/server/server_builder.cc b/src/cpp/server/server_builder.cc index 200e477822..d91ee7f4e3 100644 --- a/src/cpp/server/server_builder.cc +++ b/src/cpp/server/server_builder.cc @@ -23,6 +23,7 @@ #include <grpc++/server.h> #include <grpc/support/cpu.h> #include <grpc/support/log.h> +#include <grpc/support/thd.h> #include <grpc/support/useful.h> #include "src/cpp/server/thread_pool_interface.h" @@ -43,7 +44,9 @@ ServerBuilder::ServerBuilder() max_send_message_size_(-1), sync_server_settings_(SyncServerSettings()), resource_quota_(nullptr), - generic_service_(nullptr) { + generic_service_(nullptr), + thread_creator_(gpr_thd_new), + thread_joiner_(gpr_thd_join) { gpr_once_init(&once_init_plugin_list, do_plugin_list_init); for (auto it = g_plugin_factory_list->begin(); it != g_plugin_factory_list->end(); it++) { @@ -262,7 +265,7 @@ std::unique_ptr<Server> ServerBuilder::BuildAndStart() { std::unique_ptr<Server> server(new Server( max_receive_message_size_, &args, sync_server_cqs, sync_server_settings_.min_pollers, sync_server_settings_.max_pollers, - sync_server_settings_.cq_timeout_msec)); + sync_server_settings_.cq_timeout_msec, thread_creator_, thread_joiner_)); if (has_sync_methods) { // This is a Sync server diff --git a/src/cpp/server/server_cc.cc b/src/cpp/server/server_cc.cc index 4f8f4e06fc..6ab76a287e 100644 --- a/src/cpp/server/server_cc.cc +++ b/src/cpp/server/server_cc.cc @@ -36,6 +36,7 @@ #include <grpc/grpc.h> #include <grpc/support/alloc.h> #include <grpc/support/log.h> +#include <grpc/support/thd.h> #include "src/core/ext/transport/inproc/inproc_transport.h" #include "src/core/lib/profiling/timers.h" @@ -196,7 +197,8 @@ class Server::SyncRequest final : public internal::CompletionQueueTag { ctx_(mrd->deadline_, &mrd->request_metadata_), has_request_payload_(mrd->has_request_payload_), request_payload_(mrd->request_payload_), - method_(mrd->method_) { + method_(mrd->method_), + server_(server) { ctx_.set_call(mrd->call_); ctx_.cq_ = &cq_; GPR_ASSERT(mrd->in_flight_); @@ -210,10 +212,13 @@ class Server::SyncRequest final : public internal::CompletionQueueTag { } } - void Run(std::shared_ptr<GlobalCallbacks> global_callbacks) { + void Run(std::shared_ptr<GlobalCallbacks> global_callbacks, + bool resources) { ctx_.BeginCompletionOp(&call_); global_callbacks->PreSynchronousRequest(&ctx_); - method_->handler()->RunHandler(internal::MethodHandler::HandlerParameter( + auto* handler = resources ? method_->handler() + : server_->resource_exhausted_handler_.get(); + handler->RunHandler(internal::MethodHandler::HandlerParameter( &call_, &ctx_, request_payload_)); global_callbacks->PostSynchronousRequest(&ctx_); request_payload_ = nullptr; @@ -235,6 +240,7 @@ class Server::SyncRequest final : public internal::CompletionQueueTag { const bool has_request_payload_; grpc_byte_buffer* request_payload_; internal::RpcServiceMethod* const method_; + Server* server_; }; private: @@ -255,11 +261,15 @@ class Server::SyncRequest final : public internal::CompletionQueueTag { // appropriate RPC handlers class Server::SyncRequestThreadManager : public ThreadManager { public: - SyncRequestThreadManager(Server* server, CompletionQueue* server_cq, - std::shared_ptr<GlobalCallbacks> global_callbacks, - int min_pollers, int max_pollers, - int cq_timeout_msec) - : ThreadManager(min_pollers, max_pollers), + SyncRequestThreadManager( + Server* server, CompletionQueue* server_cq, + std::shared_ptr<GlobalCallbacks> global_callbacks, int min_pollers, + int max_pollers, int cq_timeout_msec, + std::function<int(gpr_thd_id*, const char*, void (*)(void*), void*, + const gpr_thd_options*)> + thread_creator, + std::function<void(gpr_thd_id)> thread_joiner) + : ThreadManager(min_pollers, max_pollers, thread_creator, thread_joiner), server_(server), server_cq_(server_cq), cq_timeout_msec_(cq_timeout_msec), @@ -285,7 +295,7 @@ class Server::SyncRequestThreadManager : public ThreadManager { GPR_UNREACHABLE_CODE(return TIMEOUT); } - void DoWork(void* tag, bool ok) override { + void DoWork(void* tag, bool ok, bool resources) override { SyncRequest* sync_req = static_cast<SyncRequest*>(tag); if (!sync_req) { @@ -305,7 +315,7 @@ class Server::SyncRequestThreadManager : public ThreadManager { } GPR_TIMER_SCOPE("cd.Run()", 0); - cd.Run(global_callbacks_); + cd.Run(global_callbacks_, resources); } // TODO (sreek) If ok is false here (which it isn't in case of // grpc_request_registered_call), we should still re-queue the request @@ -367,7 +377,11 @@ Server::Server( int max_receive_message_size, ChannelArguments* args, std::shared_ptr<std::vector<std::unique_ptr<ServerCompletionQueue>>> sync_server_cqs, - int min_pollers, int max_pollers, int sync_cq_timeout_msec) + int min_pollers, int max_pollers, int sync_cq_timeout_msec, + std::function<int(gpr_thd_id*, const char*, void (*)(void*), void*, + const gpr_thd_options*)> + thread_creator, + std::function<void(gpr_thd_id)> thread_joiner) : max_receive_message_size_(max_receive_message_size), sync_server_cqs_(sync_server_cqs), started_(false), @@ -376,7 +390,9 @@ Server::Server( has_generic_service_(false), server_(nullptr), server_initializer_(new ServerInitializer(this)), - health_check_service_disabled_(false) { + health_check_service_disabled_(false), + thread_creator_(thread_creator), + thread_joiner_(thread_joiner) { g_gli_initializer.summon(); gpr_once_init(&g_once_init_callbacks, InitGlobalCallbacks); global_callbacks_ = g_callbacks; @@ -386,7 +402,7 @@ Server::Server( it++) { sync_req_mgrs_.emplace_back(new SyncRequestThreadManager( this, (*it).get(), global_callbacks_, min_pollers, max_pollers, - sync_cq_timeout_msec)); + sync_cq_timeout_msec, thread_creator_, thread_joiner_)); } grpc_channel_args channel_args; @@ -549,6 +565,10 @@ void Server::Start(ServerCompletionQueue** cqs, size_t num_cqs) { } } + if (!sync_server_cqs_->empty()) { + resource_exhausted_handler_.reset(new internal::ResourceExhaustedHandler); + } + for (auto it = sync_req_mgrs_.begin(); it != sync_req_mgrs_.end(); it++) { (*it)->Start(); } diff --git a/src/cpp/server/thread_pool_interface.h b/src/cpp/server/thread_pool_interface.h index 028842a776..656e6673f1 100644 --- a/src/cpp/server/thread_pool_interface.h +++ b/src/cpp/server/thread_pool_interface.h @@ -29,7 +29,9 @@ class ThreadPoolInterface { virtual ~ThreadPoolInterface() {} // Schedule the given callback for execution. - virtual void Add(const std::function<void()>& callback) = 0; + // Return true on success, false on failure + virtual bool Add(const std::function<void()>& callback) + GRPC_MUST_USE_RESULT = 0; }; // Allows different codebases to use their own thread pool impls diff --git a/src/cpp/thread_manager/thread_manager.cc b/src/cpp/thread_manager/thread_manager.cc index 23264f1b5b..107c60f4eb 100644 --- a/src/cpp/thread_manager/thread_manager.cc +++ b/src/cpp/thread_manager/thread_manager.cc @@ -20,18 +20,26 @@ #include <climits> #include <mutex> -#include <thread> #include <grpc/support/log.h> +#include <grpc/support/thd.h> namespace grpc { -ThreadManager::WorkerThread::WorkerThread(ThreadManager* thd_mgr) +ThreadManager::WorkerThread::WorkerThread(ThreadManager* thd_mgr, bool* valid) : thd_mgr_(thd_mgr) { + gpr_thd_options opt = gpr_thd_options_default(); + gpr_thd_options_set_joinable(&opt); + // Make thread creation exclusive with respect to its join happening in // ~WorkerThread(). std::lock_guard<std::mutex> lock(wt_mu_); - thd_ = std::thread(&ThreadManager::WorkerThread::Run, this); + *valid = valid_ = thd_mgr->thread_creator_( + &thd_, "worker thread", + [](void* th) { + reinterpret_cast<ThreadManager::WorkerThread*>(th)->Run(); + }, + this, &opt); } void ThreadManager::WorkerThread::Run() { @@ -42,15 +50,24 @@ void ThreadManager::WorkerThread::Run() { ThreadManager::WorkerThread::~WorkerThread() { // Don't join until the thread is fully constructed. std::lock_guard<std::mutex> lock(wt_mu_); - thd_.join(); + if (valid_) { + thd_mgr_->thread_joiner_(thd_); + } } -ThreadManager::ThreadManager(int min_pollers, int max_pollers) +ThreadManager::ThreadManager( + int min_pollers, int max_pollers, + std::function<int(gpr_thd_id*, const char*, void (*)(void*), void*, + const gpr_thd_options*)> + thread_creator, + std::function<void(gpr_thd_id)> thread_joiner) : shutdown_(false), num_pollers_(0), min_pollers_(min_pollers), max_pollers_(max_pollers == -1 ? INT_MAX : max_pollers), - num_threads_(0) {} + num_threads_(0), + thread_creator_(thread_creator), + thread_joiner_(thread_joiner) {} ThreadManager::~ThreadManager() { { @@ -111,7 +128,9 @@ void ThreadManager::Initialize() { for (int i = 0; i < min_pollers_; i++) { // Create a new thread (which ends up calling the MainWorkLoop() function - new WorkerThread(this); + bool valid; + new WorkerThread(this, &valid); + GPR_ASSERT(valid); // we need to have at least this minimum } } @@ -138,18 +157,27 @@ void ThreadManager::MainWorkLoop() { case WORK_FOUND: // If we got work and there are now insufficient pollers, start a new // one + bool resources; if (!shutdown_ && num_pollers_ < min_pollers_) { - num_pollers_++; - num_threads_++; + bool valid; // Drop lock before spawning thread to avoid contention lock.unlock(); - new WorkerThread(this); + auto* th = new WorkerThread(this, &valid); + lock.lock(); + if (valid) { + num_pollers_++; + num_threads_++; + } else { + delete th; + } + resources = (num_pollers_ > 0); } else { - // Drop lock for consistency with above branch - lock.unlock(); + resources = true; } + // Drop lock before any application work + lock.unlock(); // Lock is always released at this point - do the application work - DoWork(tag, ok); + DoWork(tag, ok, resources); // Take the lock again to check post conditions lock.lock(); // If we're shutdown, we should finish at this point. diff --git a/src/cpp/thread_manager/thread_manager.h b/src/cpp/thread_manager/thread_manager.h index a206e0bd8a..4fa8a6c563 100644 --- a/src/cpp/thread_manager/thread_manager.h +++ b/src/cpp/thread_manager/thread_manager.h @@ -23,15 +23,19 @@ #include <list> #include <memory> #include <mutex> -#include <thread> #include <grpc++/support/config.h> +#include <grpc/support/thd.h> namespace grpc { class ThreadManager { public: - explicit ThreadManager(int min_pollers, int max_pollers); + ThreadManager(int min_pollers, int max_pollers, + std::function<int(gpr_thd_id*, const char*, void (*)(void*), + void*, const gpr_thd_options*)> + thread_creator, + std::function<void(gpr_thd_id)> thread_joiner); virtual ~ThreadManager(); // Initializes and Starts the Rpc Manager threads @@ -50,6 +54,8 @@ class ThreadManager { // - ThreadManager does not interpret the values of 'tag' and 'ok' // - ThreadManager WILL call DoWork() and pass '*tag' and 'ok' as input to // DoWork() + // - ThreadManager will also pass DoWork a bool saying if there are actually + // resources to do the work // // If the return value is SHUTDOWN:, // - ThreadManager WILL NOT call DoWork() and terminates the thead @@ -69,7 +75,7 @@ class ThreadManager { // The implementation of DoWork() should also do any setup needed to ensure // that the next call to PollForWork() (not necessarily by the current thread) // actually finds some work - virtual void DoWork(void* tag, bool ok) = 0; + virtual void DoWork(void* tag, bool ok, bool resources) = 0; // Mark the ThreadManager as shutdown and begin draining the work. This is a // non-blocking call and the caller should call Wait(), a blocking call which @@ -84,15 +90,15 @@ class ThreadManager { virtual void Wait(); private: - // Helper wrapper class around std::thread. This takes a ThreadManager object - // and starts a new std::thread to calls the Run() function. + // Helper wrapper class around thread. This takes a ThreadManager object + // and starts a new thread to calls the Run() function. // // The Run() function calls ThreadManager::MainWorkLoop() function and once // that completes, it marks the WorkerThread completed by calling // ThreadManager::MarkAsCompleted() class WorkerThread { public: - WorkerThread(ThreadManager* thd_mgr); + WorkerThread(ThreadManager* thd_mgr, bool* valid); ~WorkerThread(); private: @@ -102,7 +108,8 @@ class ThreadManager { ThreadManager* const thd_mgr_; std::mutex wt_mu_; - std::thread thd_; + gpr_thd_id thd_; + bool valid_; }; // The main funtion in ThreadManager @@ -129,6 +136,13 @@ class ThreadManager { // currently polling i.e num_pollers_) int num_threads_; + // Functions for creating/joining threads. Normally, these should + // be gpr_thd_new/gpr_thd_join but they are overridable + std::function<int(gpr_thd_id*, const char*, void (*)(void*), void*, + const gpr_thd_options*)> + thread_creator_; + std::function<void(gpr_thd_id)> thread_joiner_; + std::mutex list_mu_; std::list<WorkerThread*> completed_threads_; }; diff --git a/test/cpp/end2end/thread_stress_test.cc b/test/cpp/end2end/thread_stress_test.cc index 90b2eddbbb..fd43c8f584 100644 --- a/test/cpp/end2end/thread_stress_test.cc +++ b/test/cpp/end2end/thread_stress_test.cc @@ -26,6 +26,7 @@ #include <grpc++/server_builder.h> #include <grpc++/server_context.h> #include <grpc/grpc.h> +#include <grpc/support/atm.h> #include <grpc/support/thd.h> #include <grpc/support/time.h> @@ -52,63 +53,13 @@ namespace testing { class TestServiceImpl : public ::grpc::testing::EchoTestService::Service { public: - TestServiceImpl() : signal_client_(false) {} + TestServiceImpl() {} Status Echo(ServerContext* context, const EchoRequest* request, EchoResponse* response) override { response->set_message(request->message()); return Status::OK; } - - // Unimplemented is left unimplemented to test the returned error. - - Status RequestStream(ServerContext* context, - ServerReader<EchoRequest>* reader, - EchoResponse* response) override { - EchoRequest request; - response->set_message(""); - while (reader->Read(&request)) { - response->mutable_message()->append(request.message()); - } - return Status::OK; - } - - // Return 3 messages. - // TODO(yangg) make it generic by adding a parameter into EchoRequest - Status ResponseStream(ServerContext* context, const EchoRequest* request, - ServerWriter<EchoResponse>* writer) override { - EchoResponse response; - response.set_message(request->message() + "0"); - writer->Write(response); - response.set_message(request->message() + "1"); - writer->Write(response); - response.set_message(request->message() + "2"); - writer->Write(response); - - return Status::OK; - } - - Status BidiStream( - ServerContext* context, - ServerReaderWriter<EchoResponse, EchoRequest>* stream) override { - EchoRequest request; - EchoResponse response; - while (stream->Read(&request)) { - gpr_log(GPR_INFO, "recv msg %s", request.message().c_str()); - response.set_message(request.message()); - stream->Write(response); - } - return Status::OK; - } - - bool signal_client() { - std::unique_lock<std::mutex> lock(mu_); - return signal_client_; - } - - private: - bool signal_client_; - std::mutex mu_; }; template <class Service> @@ -119,10 +70,15 @@ class CommonStressTest { virtual void SetUp() = 0; virtual void TearDown() = 0; virtual void ResetStub() = 0; + virtual bool AllowExhaustion() = 0; grpc::testing::EchoTestService::Stub* GetStub() { return stub_.get(); } protected: std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_; + // Some tests use a custom thread creator. This should be declared before the + // server so that it's destructor happens after the server + std::unique_ptr<ServerBuilderThreadCreatorOverrideTest> creator_; + std::unique_ptr<Server> server_; virtual void SetUpStart(ServerBuilder* builder, Service* service) = 0; @@ -147,6 +103,7 @@ class CommonStressTestInsecure : public CommonStressTest<Service> { CreateChannel(server_address_.str(), InsecureChannelCredentials()); this->stub_ = grpc::testing::EchoTestService::NewStub(channel); } + bool AllowExhaustion() override { return false; } protected: void SetUpStart(ServerBuilder* builder, Service* service) override { @@ -162,7 +119,7 @@ class CommonStressTestInsecure : public CommonStressTest<Service> { std::ostringstream server_address_; }; -template <class Service> +template <class Service, bool allow_resource_exhaustion> class CommonStressTestInproc : public CommonStressTest<Service> { public: void ResetStub() override { @@ -170,6 +127,7 @@ class CommonStressTestInproc : public CommonStressTest<Service> { std::shared_ptr<Channel> channel = this->server_->InProcessChannel(args); this->stub_ = grpc::testing::EchoTestService::NewStub(channel); } + bool AllowExhaustion() override { return allow_resource_exhaustion; } protected: void SetUpStart(ServerBuilder* builder, Service* service) override { @@ -194,6 +152,67 @@ class CommonStressTestSyncServer : public BaseClass { TestServiceImpl service_; }; +class ServerBuilderThreadCreatorOverrideTest { + public: + ServerBuilderThreadCreatorOverrideTest(ServerBuilder* builder, size_t limit) + : limit_(limit), threads_(0) { + builder->SetThreadFunctions( + [this](gpr_thd_id* id, const char* name, void (*f)(void*), void* arg, + const gpr_thd_options* options) -> int { + std::unique_lock<std::mutex> l(mu_); + if (threads_ < limit_) { + l.unlock(); + if (gpr_thd_new(id, name, f, arg, options) != 0) { + l.lock(); + threads_++; + return 1; + } + } + return 0; + }, + [this](gpr_thd_id id) { + gpr_thd_join(id); + std::unique_lock<std::mutex> l(mu_); + threads_--; + if (threads_ == 0) { + done_.notify_one(); + } + }); + } + ~ServerBuilderThreadCreatorOverrideTest() { + // Don't allow destruction until all threads are really done and uncounted + std::unique_lock<std::mutex> l(mu_); + done_.wait(l, [this] { return (threads_ == 0); }); + } + + private: + size_t limit_; + size_t threads_; + std::mutex mu_; + std::condition_variable done_; +}; + +template <class BaseClass> +class CommonStressTestSyncServerLowThreadCount : public BaseClass { + public: + void SetUp() override { + ServerBuilder builder; + this->SetUpStart(&builder, &service_); + builder.SetSyncServerOption(ServerBuilder::SyncServerOption::MIN_POLLERS, + 1); + this->creator_.reset( + new ServerBuilderThreadCreatorOverrideTest(&builder, 4)); + this->SetUpEnd(&builder); + } + void TearDown() override { + this->TearDownStart(); + this->TearDownEnd(); + } + + private: + TestServiceImpl service_; +}; + template <class BaseClass> class CommonStressTestAsyncServer : public BaseClass { public: @@ -294,7 +313,8 @@ class End2endTest : public ::testing::Test { Common common_; }; -static void SendRpc(grpc::testing::EchoTestService::Stub* stub, int num_rpcs) { +static void SendRpc(grpc::testing::EchoTestService::Stub* stub, int num_rpcs, + bool allow_exhaustion, gpr_atm* errors) { EchoRequest request; EchoResponse response; request.set_message("Hello"); @@ -302,33 +322,48 @@ static void SendRpc(grpc::testing::EchoTestService::Stub* stub, int num_rpcs) { for (int i = 0; i < num_rpcs; ++i) { ClientContext context; Status s = stub->Echo(&context, request, &response); - EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(s.ok() || (allow_exhaustion && + s.error_code() == StatusCode::RESOURCE_EXHAUSTED)); if (!s.ok()) { - gpr_log(GPR_ERROR, "RPC error: %d: %s", s.error_code(), - s.error_message().c_str()); + if (!(allow_exhaustion && + s.error_code() == StatusCode::RESOURCE_EXHAUSTED)) { + gpr_log(GPR_ERROR, "RPC error: %d: %s", s.error_code(), + s.error_message().c_str()); + } + gpr_atm_no_barrier_fetch_add(errors, static_cast<gpr_atm>(1)); + } else { + EXPECT_EQ(response.message(), request.message()); } - ASSERT_TRUE(s.ok()); } } typedef ::testing::Types< CommonStressTestSyncServer<CommonStressTestInsecure<TestServiceImpl>>, - CommonStressTestSyncServer<CommonStressTestInproc<TestServiceImpl>>, + CommonStressTestSyncServer<CommonStressTestInproc<TestServiceImpl, false>>, + CommonStressTestSyncServerLowThreadCount< + CommonStressTestInproc<TestServiceImpl, true>>, CommonStressTestAsyncServer< CommonStressTestInsecure<grpc::testing::EchoTestService::AsyncService>>, - CommonStressTestAsyncServer< - CommonStressTestInproc<grpc::testing::EchoTestService::AsyncService>>> + CommonStressTestAsyncServer<CommonStressTestInproc< + grpc::testing::EchoTestService::AsyncService, false>>> CommonTypes; TYPED_TEST_CASE(End2endTest, CommonTypes); TYPED_TEST(End2endTest, ThreadStress) { this->common_.ResetStub(); std::vector<std::thread> threads; + gpr_atm errors; + gpr_atm_rel_store(&errors, static_cast<gpr_atm>(0)); for (int i = 0; i < kNumThreads; ++i) { - threads.emplace_back(SendRpc, this->common_.GetStub(), kNumRpcs); + threads.emplace_back(SendRpc, this->common_.GetStub(), kNumRpcs, + this->common_.AllowExhaustion(), &errors); } for (int i = 0; i < kNumThreads; ++i) { threads[i].join(); } + uint64_t error_cnt = static_cast<uint64_t>(gpr_atm_no_barrier_load(&errors)); + if (error_cnt != 0) { + gpr_log(GPR_INFO, "RPC error count: %" PRIu64, error_cnt); + } } template <class Common> diff --git a/test/cpp/thread_manager/BUILD b/test/cpp/thread_manager/BUILD new file mode 100644 index 0000000000..1f0878770b --- /dev/null +++ b/test/cpp/thread_manager/BUILD @@ -0,0 +1,31 @@ +# Copyright 2017 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) # Apache v2 + +load("//bazel:grpc_build_system.bzl", "grpc_cc_library", "grpc_cc_test", "grpc_package") + +grpc_package(name = "test/cpp/thread_manager") + +grpc_cc_test( + name = "thread_manager_test", + srcs = ["thread_manager_test.cc"], + deps = [ + "//:gpr", + "//:grpc", + "//:grpc++", + "//test/cpp/util:test_config", + ], +) + diff --git a/test/cpp/thread_manager/thread_manager_test.cc b/test/cpp/thread_manager/thread_manager_test.cc index 8282d46694..d3d31f9dd9 100644 --- a/test/cpp/thread_manager/thread_manager_test.cc +++ b/test/cpp/thread_manager/thread_manager_test.cc @@ -20,10 +20,10 @@ #include <memory> #include <string> -#include <gflags/gflags.h> #include <grpc++/grpc++.h> #include <grpc/support/log.h> #include <grpc/support/port_platform.h> +#include <grpc/support/thd.h> #include "src/cpp/thread_manager/thread_manager.h" #include "test/cpp/util/test_config.h" @@ -32,13 +32,13 @@ namespace grpc { class ThreadManagerTest final : public grpc::ThreadManager { public: ThreadManagerTest() - : ThreadManager(kMinPollers, kMaxPollers), + : ThreadManager(kMinPollers, kMaxPollers, gpr_thd_new, gpr_thd_join), num_do_work_(0), num_poll_for_work_(0), num_work_found_(0) {} grpc::ThreadManager::WorkStatus PollForWork(void** tag, bool* ok) override; - void DoWork(void* tag, bool ok) override; + void DoWork(void* tag, bool ok, bool resources) override; void PerformTest(); private: @@ -89,7 +89,7 @@ grpc::ThreadManager::WorkStatus ThreadManagerTest::PollForWork(void** tag, } } -void ThreadManagerTest::DoWork(void* tag, bool ok) { +void ThreadManagerTest::DoWork(void* tag, bool ok, bool resources) { gpr_atm_no_barrier_fetch_add(&num_do_work_, 1); SleepForMs(kDoWorkDurationMsec); // Simulate doing work by sleeping } |