diff options
author | Akshay Modi <nareshmodi@google.com> | 2018-06-19 10:25:10 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-19 10:28:19 -0700 |
commit | c740b345e8c17cde0dd4691c7e240a065cb8c88c (patch) | |
tree | dd85bcff39031ec09de4507a335b541fb183adb4 /tensorflow/core/distributed_runtime | |
parent | ccaf2ca02739792a8a8e50a95246f2db1197aa97 (diff) |
Allow setting server def on the eager context, and add the eager service to the grpc_tensorflow_server.
PiperOrigin-RevId: 201198350
Diffstat (limited to 'tensorflow/core/distributed_runtime')
7 files changed, 50 insertions, 134 deletions
diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD index 882271e3f5..7b19427e4b 100644 --- a/tensorflow/core/distributed_runtime/rpc/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/BUILD @@ -284,7 +284,9 @@ cc_library( "//tensorflow/core/distributed_runtime:rpc_collective_executor_mgr", "//tensorflow/core/distributed_runtime:server_lib", "//tensorflow/core/distributed_runtime:session_mgr", + "//tensorflow/core/distributed_runtime:worker_cache_wrapper", "//tensorflow/core/distributed_runtime:worker_env", + "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_service_impl", "@grpc", "@grpc//:grpc++", ], diff --git a/tensorflow/core/distributed_runtime/rpc/eager/BUILD b/tensorflow/core/distributed_runtime/rpc/eager/BUILD index a5472159cc..8cec497361 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/eager/BUILD @@ -42,26 +42,11 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:ptr_util", "//tensorflow/core/distributed_runtime/eager:eager_service_impl", + "//tensorflow/core/distributed_runtime/rpc:async_service_interface", "//tensorflow/core/distributed_runtime/rpc:grpc_call", "//tensorflow/core/distributed_runtime/rpc:grpc_channel", - "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", "//tensorflow/core/distributed_runtime/rpc:grpc_util", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", - "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", "@grpc//:grpc++", ], ) - -cc_library( - name = "eager_grpc_server_lib", - hdrs = ["eager_grpc_server_lib.h"], - deps = [ - ":grpc_eager_service_impl", - "//tensorflow/core:core_cpu", - "//tensorflow/core/distributed_runtime:rendezvous_mgr_interface", - "//tensorflow/core/distributed_runtime:worker_cache_wrapper", - "//tensorflow/core/distributed_runtime/eager:eager_service_impl", - "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", - "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", - ], -) diff --git a/tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h deleted file mode 100644 index 9b863ccee5..0000000000 --- a/tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h +++ /dev/null @@ -1,97 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_EAGER_GRPC_SERVER_LIB_H_ -#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_EAGER_GRPC_SERVER_LIB_H_ - -#include "tensorflow/core/common_runtime/device_factory.h" -#include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h" -#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" -#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h" -#include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h" - -namespace tensorflow { -namespace eager { - -class EagerGrpcServer : public GrpcServer { - public: - static Status Create(const ServerDef& server_def, - std::unique_ptr<EagerGrpcServer>* server) { - std::unique_ptr<EagerGrpcServer> ret(new EagerGrpcServer(server_def)); - - TF_RETURN_IF_ERROR(ret->InitEager()); - - *server = std::move(ret); - - return Status::OK(); - } - - Status Start() override { - TF_RETURN_IF_ERROR(GrpcServer::Start()); - - eager_service_->Start(); - - return Status::OK(); - } - - Status Stop() override { - TF_RETURN_IF_ERROR(GrpcServer::Stop()); - - eager_service_->Stop(); - - return Status::OK(); - } - - using GrpcServer::channel_cache; - using GrpcServer::master_env; - using GrpcServer::worker_env; - - private: - EagerGrpcServer(const ServerDef& server_def) - : GrpcServer(server_def, Env::Default()), - worker_name_( - strings::StrCat("/job:", server_def.job_name(), - "/replica:0/task:", server_def.task_index())) {} - - Status InitEager() { - TF_RETURN_IF_ERROR(this->Init( - [this](const WorkerEnv* worker_env, - ::grpc::ServerBuilder* server_builder) { - this->eager_service_.reset( - new eager::GrpcEagerServiceImpl(worker_env, server_builder)); - }, - nullptr, nullptr)); - - worker_session_ = WorkerSession::CreateWithBorrowedDeviceMgr( - "", worker_name_, - std::unique_ptr<WorkerCacheInterface>( - new WorkerCacheWrapper(master_env()->worker_cache)), - worker_env()->device_mgr, {}); - - auto* r = worker_env()->rendezvous_mgr->Find(0); - return r->Initialize(worker_session_.get()); - } - - std::unique_ptr<GrpcEagerServiceImpl> eager_service_; - std::shared_ptr<WorkerSession> worker_session_; - const string worker_name_; -}; // namespace eager - -} // namespace eager -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_EAGER_GRPC_SERVER_LIB_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc index b36c6dce86..52e06c263d 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc @@ -18,10 +18,8 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_call.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h" #include "tensorflow/core/util/ptr_util.h" namespace tensorflow { @@ -36,7 +34,7 @@ GrpcEagerServiceImpl::GrpcEagerServiceImpl( cq_ = server_builder->AddCompletionQueue(); } -void GrpcEagerServiceImpl::DriveCQ() { +void GrpcEagerServiceImpl::HandleRPCsLoop() { #define ENQUEUE_REQUEST(method) \ do { \ Call<GrpcEagerServiceImpl, \ @@ -74,12 +72,7 @@ void GrpcEagerServiceImpl::DriveCQ() { } } -void GrpcEagerServiceImpl::Start() { - // TODO(nareshmodi) separate thread for driving CQ - request_handler_threadpool_->Schedule([this]() { DriveCQ(); }); -} - -void GrpcEagerServiceImpl::Stop() { +void GrpcEagerServiceImpl::Shutdown() { // This enqueues a special event (with a null tag) // that causes the completion queue to be shut down on the // polling thread. diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h index e94aedf535..9a94026342 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h +++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h @@ -20,16 +20,16 @@ limitations under the License. #include "grpcpp/completion_queue.h" #include "grpcpp/server_builder.h" #include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h" +#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h" #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_call.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" namespace tensorflow { namespace eager { // This class is a wrapper that handles communication for gRPC. -class GrpcEagerServiceImpl { +class GrpcEagerServiceImpl : public AsyncServiceInterface { public: template <class RequestMessage, class ResponseMessage> using EagerCall = Call<GrpcEagerServiceImpl, grpc::EagerService::AsyncService, @@ -39,8 +39,8 @@ class GrpcEagerServiceImpl { ::grpc::ServerBuilder* server_builder); virtual ~GrpcEagerServiceImpl() {} - void Start(); - void Stop(); + void HandleRPCsLoop() override; + void Shutdown() override; private: #define HANDLER(method) \ @@ -66,8 +66,6 @@ class GrpcEagerServiceImpl { EagerServiceImpl local_impl_; - void DriveCQ(); - std::unique_ptr<::grpc::Alarm> shutdown_alarm_; std::unique_ptr<::grpc::ServerCompletionQueue> cq_; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc index 43dbe20836..2dd3e8678b 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/master_env.h" #include "tensorflow/core/distributed_runtime/master_session.h" #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h" +#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" @@ -42,6 +43,7 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" #include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h" #include "tensorflow/core/distributed_runtime/server_lib.h" +#include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h" #include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -81,6 +83,7 @@ GrpcServer::~GrpcServer() { delete master_service_; delete worker_service_; + delete eager_service_; // TODO(mrry): Refactor the *Env classes so that it is less fiddly // to destroy them. @@ -192,6 +195,8 @@ Status GrpcServer::Init( worker_func ? worker_func(&worker_env_) : NewGrpcWorker(&worker_env_); worker_service_ = NewGrpcWorkerService(worker_impl_.get(), &builder).release(); + eager_service_ = new eager::GrpcEagerServiceImpl(&worker_env_, &builder); + // extra service: if (service_func != nullptr) { service_func(&worker_env_, &builder); @@ -264,7 +269,15 @@ Status GrpcServer::Init( LocalMaster::Register(target(), master_impl_.get(), config.operation_timeout_in_ms()); - return Status::OK(); + // Generate a dummy worker session that is used to register the + // Rendezvous for eager (we use Step 0 for eager). + worker_session_ = WorkerSession::CreateWithBorrowedDeviceMgr( + "", name_prefix, + std::unique_ptr<WorkerCacheInterface>( + new WorkerCacheWrapper(master_env_.worker_cache)), + worker_env_.device_mgr, {}); + auto* r = worker_env()->rendezvous_mgr->Find(0); + return r->Initialize(worker_session_.get()); } Status GrpcServer::Init( @@ -357,6 +370,9 @@ Status GrpcServer::Start() { worker_thread_.reset( env_->StartThread(ThreadOptions(), "TF_worker_service", [this] { worker_service_->HandleRPCsLoop(); })); + eager_thread_.reset( + env_->StartThread(ThreadOptions(), "TF_eager_service", + [this] { eager_service_->HandleRPCsLoop(); })); state_ = STARTED; LOG(INFO) << "Started server with target: " << target(); return Status::OK(); @@ -399,6 +415,7 @@ Status GrpcServer::Join() { case STOPPED: master_thread_.reset(); worker_thread_.reset(); + eager_thread_.reset(); return Status::OK(); default: LOG(FATAL); @@ -435,6 +452,17 @@ Status GrpcServer::Create(const ServerDef& server_def, Env* env, return Status::OK(); } +/* static */ +Status GrpcServer::Create(const ServerDef& server_def, Env* env, + std::unique_ptr<GrpcServer>* out_server) { + std::unique_ptr<GrpcServer> ret( + new GrpcServer(server_def, env == nullptr ? Env::Default() : env)); + ServiceInitFunction service_func = nullptr; + TF_RETURN_IF_ERROR(ret->Init(service_func, NewRpcRendezvousMgr, nullptr)); + *out_server = std::move(ret); + return Status::OK(); +} + namespace { class GrpcServerFactory : public ServerFactory { diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h index ca9946cafc..c674da9490 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h @@ -63,6 +63,8 @@ class GrpcServer : public ServerInterface { public: static Status Create(const ServerDef& server_def, Env* env, std::unique_ptr<ServerInterface>* out_server); + static Status Create(const ServerDef& server_def, Env* env, + std::unique_ptr<GrpcServer>* out_server); // Destruction is only supported in the factory method. Clean // shutdown is not currently implemented for this server type. @@ -74,6 +76,11 @@ class GrpcServer : public ServerInterface { Status Join() override; const string target() const override; + WorkerEnv* worker_env() { return &worker_env_; } + MasterEnv* master_env() { return &master_env_; } + + std::shared_ptr<GrpcChannelCache> channel_cache() { return channel_cache_; } + protected: Status Init(ServiceInitFunction service_func, const RendezvousMgrCreationFunction& rendezvous_mgr_func, @@ -112,11 +119,6 @@ class GrpcServer : public ServerInterface { // This method may only be called after `this->Init()` returns successfully. int bound_port() const { return bound_port_; } - WorkerEnv* worker_env() { return &worker_env_; } - MasterEnv* master_env() { return &master_env_; } - - std::shared_ptr<GrpcChannelCache> channel_cache() { return channel_cache_; } - const ServerDef& server_def() const { return server_def_; } private: @@ -155,6 +157,11 @@ class GrpcServer : public ServerInterface { AsyncServiceInterface* worker_service_ = nullptr; std::unique_ptr<Thread> worker_thread_ GUARDED_BY(mu_); + // TensorFlow Eager implementation, and RPC polling thread. + AsyncServiceInterface* eager_service_ = nullptr; + std::unique_ptr<Thread> eager_thread_ GUARDED_BY(mu_); + std::shared_ptr<WorkerSession> worker_session_; + std::unique_ptr<::grpc::Server> server_ GUARDED_BY(mu_); }; |