aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc')
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc26
1 files changed, 23 insertions, 3 deletions
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
index 1aafa862cb..7160962b16 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
@@ -62,6 +62,13 @@ class NoReusePortOption : public ::grpc::ServerBuilderOption {
plugins) override {}
};
+// static utility function
+RendezvousMgrInterface* NewRpcRendezvousMgr(
+ const WorkerEnv* env, const string& worker_name,
+ WorkerCacheInterface* worker_cache) {
+ return new RpcRendezvousMgr(env, worker_name, worker_cache);
+}
+
} // namespace
GrpcServer::GrpcServer(const ServerDef& server_def, Env* env)
@@ -93,7 +100,8 @@ GrpcServer::~GrpcServer() {
// - worker_env_.compute_pool
}
-Status GrpcServer::Init() {
+Status GrpcServer::Init(ServiceInitFunction service_func,
+ RendezvousMgrCreationFunction rendevous_mgr_func) {
mutex_lock l(mu_);
CHECK_EQ(state_, NEW);
master_env_.env = env_;
@@ -170,6 +178,10 @@ Status GrpcServer::Init() {
worker_impl_ = NewGrpcWorker(&worker_env_);
worker_service_ =
NewGrpcWorkerService(worker_impl_.get(), &builder).release();
+ // extra service:
+ if (service_func != nullptr) {
+ service_func(&worker_env_, &builder);
+ }
server_ = builder.BuildAndStart();
if (!server_) {
@@ -182,7 +194,9 @@ Status GrpcServer::Init() {
// Set up worker environment.
std::unique_ptr<RendezvousMgrInterface> rendezvous_mgr(
- new RpcRendezvousMgr(&worker_env_, name_prefix, worker_cache));
+ rendevous_mgr_func == nullptr ?
+ new RpcRendezvousMgr(&worker_env_, name_prefix, worker_cache) :
+ rendevous_mgr_func(&worker_env_, name_prefix, worker_cache));
worker_env_.session_mgr = new SessionMgr(
&worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_),
std::unique_ptr<WorkerCacheInterface>(worker_cache),
@@ -211,6 +225,10 @@ Status GrpcServer::Init() {
return Status::OK();
}
+Status GrpcServer::Init() {
+ return Init(nullptr, nullptr);
+}
+
Status GrpcServer::ParseChannelSpec(const ServerDef& server_def,
GrpcChannelSpec* channel_spec) {
for (const auto& job : server_def.cluster().job()) {
@@ -248,6 +266,7 @@ Status GrpcServer::WorkerCacheFactory(const ServerDef& server_def,
channel_spec, GetChannelCreationFunction(server_def)));
const string host_port = channel_cache->TranslateTask(name_prefix);
int requested_port;
+
if (!strings::safe_strto32(str_util::Split(host_port, ':')[1],
&requested_port)) {
return errors::Internal("Could not parse port for local server from \"",
@@ -346,7 +365,8 @@ Status GrpcServer::Create(const ServerDef& server_def, Env* env,
std::unique_ptr<ServerInterface>* out_server) {
std::unique_ptr<GrpcServer> ret(
new GrpcServer(server_def, env == nullptr ? Env::Default() : env));
- TF_RETURN_IF_ERROR(ret->Init());
+ ServiceInitFunction service_func = nullptr;
+ TF_RETURN_IF_ERROR(ret->Init(service_func, NewRpcRendezvousMgr));
*out_server = std::move(ret);
return Status::OK();
}