aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc')
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc14
1 files changed, 11 insertions, 3 deletions
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
index 3867dd1f4d..4883e503e6 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
@@ -105,7 +105,8 @@ GrpcServer::~GrpcServer() {
Status GrpcServer::Init(
ServiceInitFunction service_func,
- const RendezvousMgrCreationFunction& rendezvous_mgr_func) {
+ const RendezvousMgrCreationFunction& rendezvous_mgr_func,
+ const WorkerCreationFunction& worker_func) {
mutex_lock l(mu_);
CHECK_EQ(state_, NEW);
master_env_.env = env_;
@@ -183,7 +184,8 @@ Status GrpcServer::Init(
master_impl_ = CreateMaster(&master_env_);
master_service_ = NewGrpcMasterService(
master_impl_.get(), config.operation_timeout_in_ms(), &builder);
- worker_impl_ = NewGrpcWorker(&worker_env_);
+ worker_impl_ =
+ worker_func ? worker_func(&worker_env_) : NewGrpcWorker(&worker_env_);
worker_service_ =
NewGrpcWorkerService(worker_impl_.get(), &builder).release();
// extra service:
@@ -239,7 +241,13 @@ Status GrpcServer::Init(
return Status::OK();
}
-Status GrpcServer::Init() { return Init(nullptr, nullptr); }
+Status GrpcServer::Init(
+ ServiceInitFunction service_func,
+ const RendezvousMgrCreationFunction& rendezvous_mgr_func) {
+ return Init(service_func, rendezvous_mgr_func, nullptr);
+}
+
+Status GrpcServer::Init() { return Init(nullptr, nullptr, nullptr); }
Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options,
GrpcChannelSpec* channel_spec) {