diff options
Diffstat (limited to 'tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc')
-rw-r--r-- | tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc | 26 |
1 files changed, 23 insertions, 3 deletions
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc index 1aafa862cb..7160962b16 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc @@ -62,6 +62,13 @@ class NoReusePortOption : public ::grpc::ServerBuilderOption { plugins) override {} }; +// static utility function +RendezvousMgrInterface* NewRpcRendezvousMgr( + const WorkerEnv* env, const string& worker_name, + WorkerCacheInterface* worker_cache) { + return new RpcRendezvousMgr(env, worker_name, worker_cache); +} + } // namespace GrpcServer::GrpcServer(const ServerDef& server_def, Env* env) @@ -93,7 +100,8 @@ GrpcServer::~GrpcServer() { // - worker_env_.compute_pool } -Status GrpcServer::Init() { +Status GrpcServer::Init(ServiceInitFunction service_func, + RendezvousMgrCreationFunction rendevous_mgr_func) { mutex_lock l(mu_); CHECK_EQ(state_, NEW); master_env_.env = env_; @@ -170,6 +178,10 @@ Status GrpcServer::Init() { worker_impl_ = NewGrpcWorker(&worker_env_); worker_service_ = NewGrpcWorkerService(worker_impl_.get(), &builder).release(); + // extra service: + if (service_func != nullptr) { + service_func(&worker_env_, &builder); + } server_ = builder.BuildAndStart(); if (!server_) { @@ -182,7 +194,9 @@ Status GrpcServer::Init() { // Set up worker environment. std::unique_ptr<RendezvousMgrInterface> rendezvous_mgr( - new RpcRendezvousMgr(&worker_env_, name_prefix, worker_cache)); + rendevous_mgr_func == nullptr ? + new RpcRendezvousMgr(&worker_env_, name_prefix, worker_cache) : + rendevous_mgr_func(&worker_env_, name_prefix, worker_cache)); worker_env_.session_mgr = new SessionMgr( &worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_), std::unique_ptr<WorkerCacheInterface>(worker_cache), @@ -211,6 +225,10 @@ Status GrpcServer::Init() { return Status::OK(); } +Status GrpcServer::Init() { + return Init(nullptr, nullptr); +} + Status GrpcServer::ParseChannelSpec(const ServerDef& server_def, GrpcChannelSpec* channel_spec) { for (const auto& job : server_def.cluster().job()) { @@ -248,6 +266,7 @@ Status GrpcServer::WorkerCacheFactory(const ServerDef& server_def, channel_spec, GetChannelCreationFunction(server_def))); const string host_port = channel_cache->TranslateTask(name_prefix); int requested_port; + if (!strings::safe_strto32(str_util::Split(host_port, ':')[1], &requested_port)) { return errors::Internal("Could not parse port for local server from \"", @@ -346,7 +365,8 @@ Status GrpcServer::Create(const ServerDef& server_def, Env* env, std::unique_ptr<ServerInterface>* out_server) { std::unique_ptr<GrpcServer> ret( new GrpcServer(server_def, env == nullptr ? Env::Default() : env)); - TF_RETURN_IF_ERROR(ret->Init()); + ServiceInitFunction service_func = nullptr; + TF_RETURN_IF_ERROR(ret->Init(service_func, NewRpcRendezvousMgr)); *out_server = std::move(ret); return Status::OK(); } |