diff options
Diffstat (limited to 'tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc')
-rw-r--r-- | tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc | 14 |
1 files changed, 11 insertions, 3 deletions
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc index 3867dd1f4d..4883e503e6 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc @@ -105,7 +105,8 @@ GrpcServer::~GrpcServer() { Status GrpcServer::Init( ServiceInitFunction service_func, - const RendezvousMgrCreationFunction& rendezvous_mgr_func) { + const RendezvousMgrCreationFunction& rendezvous_mgr_func, + const WorkerCreationFunction& worker_func) { mutex_lock l(mu_); CHECK_EQ(state_, NEW); master_env_.env = env_; @@ -183,7 +184,8 @@ Status GrpcServer::Init( master_impl_ = CreateMaster(&master_env_); master_service_ = NewGrpcMasterService( master_impl_.get(), config.operation_timeout_in_ms(), &builder); - worker_impl_ = NewGrpcWorker(&worker_env_); + worker_impl_ = + worker_func ? worker_func(&worker_env_) : NewGrpcWorker(&worker_env_); worker_service_ = NewGrpcWorkerService(worker_impl_.get(), &builder).release(); // extra service: @@ -239,7 +241,13 @@ Status GrpcServer::Init( return Status::OK(); } -Status GrpcServer::Init() { return Init(nullptr, nullptr); } +Status GrpcServer::Init( + ServiceInitFunction service_func, + const RendezvousMgrCreationFunction& rendezvous_mgr_func) { + return Init(service_func, rendezvous_mgr_func, nullptr); +} + +Status GrpcServer::Init() { return Init(nullptr, nullptr, nullptr); } Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options, GrpcChannelSpec* channel_spec) { |