about summary refs log tree commit diff homepage
path: root/tensorflow/core/distributed_runtime
diff options
context:
space:
mode:
author Akshay Modi <nareshmodi@google.com> 2018-06-19 10:25:10 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-06-19 10:28:19 -0700
commit c740b345e8c17cde0dd4691c7e240a065cb8c88c (patch)
tree dd85bcff39031ec09de4507a335b541fb183adb4 /tensorflow/core/distributed_runtime
parent ccaf2ca02739792a8a8e50a95246f2db1197aa97 (diff)
Allow setting server def on the eager context, and add the eager service to the grpc_tensorflow_server.
PiperOrigin-RevId: 201198350
Diffstat (limited to 'tensorflow/core/distributed_runtime')
-rw-r--r-- tensorflow/core/distributed_runtime/rpc/BUILD | 2
-rw-r--r-- tensorflow/core/distributed_runtime/rpc/eager/BUILD | 17
-rw-r--r-- tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h | 97
-rw-r--r-- tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc | 11
-rw-r--r-- tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h | 10
-rw-r--r-- tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc | 30
-rw-r--r-- tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h | 17
7 files changed, 50 insertions, 134 deletions
diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD
index 882271e3f5..7b19427e4b 100644
--- a/tensorflow/core/distributed_runtime/rpc/BUILD
+++ b/tensorflow/core/distributed_runtime/rpc/BUILD
@@ -284,7 +284,9 @@ cc_library(
"//tensorflow/core/distributed_runtime:rpc_collective_executor_mgr",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:session_mgr",
+ "//tensorflow/core/distributed_runtime:worker_cache_wrapper",
"//tensorflow/core/distributed_runtime:worker_env",
+ "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_service_impl",
"@grpc",
"@grpc//:grpc++",
],
diff --git a/tensorflow/core/distributed_runtime/rpc/eager/BUILD b/tensorflow/core/distributed_runtime/rpc/eager/BUILD
index a5472159cc..8cec497361 100644
--- a/tensorflow/core/distributed_runtime/rpc/eager/BUILD
+++ b/tensorflow/core/distributed_runtime/rpc/eager/BUILD
@@ -42,26 +42,11 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:ptr_util",
"//tensorflow/core/distributed_runtime/eager:eager_service_impl",
+ "//tensorflow/core/distributed_runtime/rpc:async_service_interface",
"//tensorflow/core/distributed_runtime/rpc:grpc_call",
"//tensorflow/core/distributed_runtime/rpc:grpc_channel",
- "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
- "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
"@grpc//:grpc++",
],
)
-
-cc_library(
- name = "eager_grpc_server_lib",
- hdrs = ["eager_grpc_server_lib.h"],
- deps = [
- ":grpc_eager_service_impl",
- "//tensorflow/core:core_cpu",
- "//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
- "//tensorflow/core/distributed_runtime:worker_cache_wrapper",
- "//tensorflow/core/distributed_runtime/eager:eager_service_impl",
- "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
- "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
- ],
-)
diff --git a/tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h
deleted file mode 100644
index 9b863ccee5..0000000000
--- a/tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h
+++ /dev/null
@@ -1,97 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_EAGER_GRPC_SERVER_LIB_H_
-#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_EAGER_GRPC_SERVER_LIB_H_
-
-#include "tensorflow/core/common_runtime/device_factory.h"
-#include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h"
-#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
-#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h"
-#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
-#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
-#include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h"
-
-namespace tensorflow {
-namespace eager {
-
-class EagerGrpcServer : public GrpcServer {
- public:
- static Status Create(const ServerDef& server_def,
- std::unique_ptr<EagerGrpcServer>* server) {
- std::unique_ptr<EagerGrpcServer> ret(new EagerGrpcServer(server_def));
-
- TF_RETURN_IF_ERROR(ret->InitEager());
-
- *server = std::move(ret);
-
- return Status::OK();
- }
-
- Status Start() override {
- TF_RETURN_IF_ERROR(GrpcServer::Start());
-
- eager_service_->Start();
-
- return Status::OK();
- }
-
- Status Stop() override {
- TF_RETURN_IF_ERROR(GrpcServer::Stop());
-
- eager_service_->Stop();
-
- return Status::OK();
- }
-
- using GrpcServer::channel_cache;
- using GrpcServer::master_env;
- using GrpcServer::worker_env;
-
- private:
- EagerGrpcServer(const ServerDef& server_def)
- : GrpcServer(server_def, Env::Default()),
- worker_name_(
- strings::StrCat("/job:", server_def.job_name(),
- "/replica:0/task:", server_def.task_index())) {}
-
- Status InitEager() {
- TF_RETURN_IF_ERROR(this->Init(
- [this](const WorkerEnv* worker_env,
- ::grpc::ServerBuilder* server_builder) {
- this->eager_service_.reset(
- new eager::GrpcEagerServiceImpl(worker_env, server_builder));
- },
- nullptr, nullptr));
-
- worker_session_ = WorkerSession::CreateWithBorrowedDeviceMgr(
- "", worker_name_,
- std::unique_ptr<WorkerCacheInterface>(
- new WorkerCacheWrapper(master_env()->worker_cache)),
- worker_env()->device_mgr, {});
-
- auto* r = worker_env()->rendezvous_mgr->Find(0);
- return r->Initialize(worker_session_.get());
- }
-
- std::unique_ptr<GrpcEagerServiceImpl> eager_service_;
- std::shared_ptr<WorkerSession> worker_session_;
- const string worker_name_;
-}; // namespace eager
-
-} // namespace eager
-} // namespace tensorflow
-
-#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_EAGER_GRPC_SERVER_LIB_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc
index b36c6dce86..52e06c263d 100644
--- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc
@@ -18,10 +18,8 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
-#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
-#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -36,7 +34,7 @@ GrpcEagerServiceImpl::GrpcEagerServiceImpl(
cq_ = server_builder->AddCompletionQueue();
}
-void GrpcEagerServiceImpl::DriveCQ() {
+void GrpcEagerServiceImpl::HandleRPCsLoop() {
#define ENQUEUE_REQUEST(method) \
do { \
Call<GrpcEagerServiceImpl, \
@@ -74,12 +72,7 @@ void GrpcEagerServiceImpl::DriveCQ() {
}
}
-void GrpcEagerServiceImpl::Start() {
- // TODO(nareshmodi) separate thread for driving CQ
- request_handler_threadpool_->Schedule([this]() { DriveCQ(); });
-}
-
-void GrpcEagerServiceImpl::Stop() {
+void GrpcEagerServiceImpl::Shutdown() {
// This enqueues a special event (with a null tag)
// that causes the completion queue to be shut down on the
// polling thread.
diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h
index e94aedf535..9a94026342 100644
--- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h
+++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h
@@ -20,16 +20,16 @@ limitations under the License.
#include "grpcpp/completion_queue.h"
#include "grpcpp/server_builder.h"
#include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h"
+#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
-#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
namespace tensorflow {
namespace eager {
// This class is a wrapper that handles communication for gRPC.
-class GrpcEagerServiceImpl {
+class GrpcEagerServiceImpl : public AsyncServiceInterface {
public:
template <class RequestMessage, class ResponseMessage>
using EagerCall = Call<GrpcEagerServiceImpl, grpc::EagerService::AsyncService,
@@ -39,8 +39,8 @@ class GrpcEagerServiceImpl {
::grpc::ServerBuilder* server_builder);
virtual ~GrpcEagerServiceImpl() {}
- void Start();
- void Stop();
+ void HandleRPCsLoop() override;
+ void Shutdown() override;
private:
#define HANDLER(method) \
@@ -66,8 +66,6 @@ class GrpcEagerServiceImpl {
EagerServiceImpl local_impl_;
- void DriveCQ();
-
std::unique_ptr<::grpc::Alarm> shutdown_alarm_;
std::unique_ptr<::grpc::ServerCompletionQueue> cq_;
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
index 43dbe20836..2dd3e8678b 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/master_env.h"
#include "tensorflow/core/distributed_runtime/master_session.h"
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
+#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
@@ -42,6 +43,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
+#include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -81,6 +83,7 @@ GrpcServer::~GrpcServer() {
delete master_service_;
delete worker_service_;
+ delete eager_service_;
// TODO(mrry): Refactor the *Env classes so that it is less fiddly
// to destroy them.
@@ -192,6 +195,8 @@ Status GrpcServer::Init(
worker_func ? worker_func(&worker_env_) : NewGrpcWorker(&worker_env_);
worker_service_ =
NewGrpcWorkerService(worker_impl_.get(), &builder).release();
+ eager_service_ = new eager::GrpcEagerServiceImpl(&worker_env_, &builder);
+
// extra service:
if (service_func != nullptr) {
service_func(&worker_env_, &builder);
@@ -264,7 +269,15 @@ Status GrpcServer::Init(
LocalMaster::Register(target(), master_impl_.get(),
config.operation_timeout_in_ms());
- return Status::OK();
+ // Generate a dummy worker session that is used to register the
+ // Rendezvous for eager (we use Step 0 for eager).
+ worker_session_ = WorkerSession::CreateWithBorrowedDeviceMgr(
+ "", name_prefix,
+ std::unique_ptr<WorkerCacheInterface>(
+ new WorkerCacheWrapper(master_env_.worker_cache)),
+ worker_env_.device_mgr, {});
+ auto* r = worker_env()->rendezvous_mgr->Find(0);
+ return r->Initialize(worker_session_.get());
}
Status GrpcServer::Init(
@@ -357,6 +370,9 @@ Status GrpcServer::Start() {
worker_thread_.reset(
env_->StartThread(ThreadOptions(), "TF_worker_service",
[this] { worker_service_->HandleRPCsLoop(); }));
+ eager_thread_.reset(
+ env_->StartThread(ThreadOptions(), "TF_eager_service",
+ [this] { eager_service_->HandleRPCsLoop(); }));
state_ = STARTED;
LOG(INFO) << "Started server with target: " << target();
return Status::OK();
@@ -399,6 +415,7 @@ Status GrpcServer::Join() {
case STOPPED:
master_thread_.reset();
worker_thread_.reset();
+ eager_thread_.reset();
return Status::OK();
default:
LOG(FATAL);
@@ -435,6 +452,17 @@ Status GrpcServer::Create(const ServerDef& server_def, Env* env,
return Status::OK();
}
+/* static */
+Status GrpcServer::Create(const ServerDef& server_def, Env* env,
+ std::unique_ptr<GrpcServer>* out_server) {
+ std::unique_ptr<GrpcServer> ret(
+ new GrpcServer(server_def, env == nullptr ? Env::Default() : env));
+ ServiceInitFunction service_func = nullptr;
+ TF_RETURN_IF_ERROR(ret->Init(service_func, NewRpcRendezvousMgr, nullptr));
+ *out_server = std::move(ret);
+ return Status::OK();
+}
+
namespace {
class GrpcServerFactory : public ServerFactory {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
index ca9946cafc..c674da9490 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
@@ -63,6 +63,8 @@ class GrpcServer : public ServerInterface {
public:
static Status Create(const ServerDef& server_def, Env* env,
std::unique_ptr<ServerInterface>* out_server);
+ static Status Create(const ServerDef& server_def, Env* env,
+ std::unique_ptr<GrpcServer>* out_server);
// Destruction is only supported in the factory method. Clean
// shutdown is not currently implemented for this server type.
@@ -74,6 +76,11 @@ class GrpcServer : public ServerInterface {
Status Join() override;
const string target() const override;
+ WorkerEnv* worker_env() { return &worker_env_; }
+ MasterEnv* master_env() { return &master_env_; }
+
+ std::shared_ptr<GrpcChannelCache> channel_cache() { return channel_cache_; }
+
protected:
Status Init(ServiceInitFunction service_func,
const RendezvousMgrCreationFunction& rendezvous_mgr_func,
@@ -112,11 +119,6 @@ class GrpcServer : public ServerInterface {
// This method may only be called after `this->Init()` returns successfully.
int bound_port() const { return bound_port_; }
- WorkerEnv* worker_env() { return &worker_env_; }
- MasterEnv* master_env() { return &master_env_; }
-
- std::shared_ptr<GrpcChannelCache> channel_cache() { return channel_cache_; }
-
const ServerDef& server_def() const { return server_def_; }
private:
@@ -155,6 +157,11 @@ class GrpcServer : public ServerInterface {
AsyncServiceInterface* worker_service_ = nullptr;
std::unique_ptr<Thread> worker_thread_ GUARDED_BY(mu_);
+ // TensorFlow Eager implementation, and RPC polling thread.
+ AsyncServiceInterface* eager_service_ = nullptr;
+ std::unique_ptr<Thread> eager_thread_ GUARDED_BY(mu_);
+ std::shared_ptr<WorkerSession> worker_session_;
+
std::unique_ptr<::grpc::Server> server_ GUARDED_BY(mu_);
};