diff options
author | 2018-06-20 12:41:19 -0700 | |
---|---|---|
committer | 2018-06-20 12:41:19 -0700 | |
commit | 8f1f0a8e4eaa5ae7593dc596b9b69a6cd88fa16a (patch) | |
tree | 7a341bfcec45046067e1753b2fe11e3f543ccfbe /tensorflow/core/distributed_runtime | |
parent | 39ea5a7044a16b868e38717b358c46d6e3191373 (diff) | |
parent | 4fdb7cc4f92e76a168810e9b420bf1b90eb544e9 (diff) |
Merge commit for internal changes
Diffstat (limited to 'tensorflow/core/distributed_runtime')
7 files changed, 50 insertions, 133 deletions
diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD index 66c4e5d7a9..4a10d99a60 100644 --- a/tensorflow/core/distributed_runtime/rpc/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/BUILD @@ -286,7 +286,9 @@ cc_library( "//tensorflow/core/distributed_runtime:rpc_collective_executor_mgr", "//tensorflow/core/distributed_runtime:server_lib", "//tensorflow/core/distributed_runtime:session_mgr", + "//tensorflow/core/distributed_runtime:worker_cache_wrapper", "//tensorflow/core/distributed_runtime:worker_env", + "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_service_impl", ], alwayslink = 1, ) diff --git a/tensorflow/core/distributed_runtime/rpc/eager/BUILD b/tensorflow/core/distributed_runtime/rpc/eager/BUILD index 6b44d8cecf..d09a85c6a5 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/eager/BUILD @@ -43,25 +43,11 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:ptr_util", "//tensorflow/core/distributed_runtime/eager:eager_service_impl", + "//tensorflow/core/distributed_runtime/rpc:async_service_interface", "//tensorflow/core/distributed_runtime/rpc:grpc_call", "//tensorflow/core/distributed_runtime/rpc:grpc_channel", - "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", "//tensorflow/core/distributed_runtime/rpc:grpc_util", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", ], ) - -cc_library( - name = "eager_grpc_server_lib", - hdrs = ["eager_grpc_server_lib.h"], - deps = [ - ":grpc_eager_service_impl", - "//tensorflow/core:core_cpu", - "//tensorflow/core/distributed_runtime:rendezvous_mgr_interface", - "//tensorflow/core/distributed_runtime:worker_cache_wrapper", - "//tensorflow/core/distributed_runtime/eager:eager_service_impl", - "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", - "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", - ], -) diff --git a/tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h deleted file mode 100644 index 9b863ccee5..0000000000 --- a/tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h +++ /dev/null @@ -1,97 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_EAGER_GRPC_SERVER_LIB_H_ -#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_EAGER_GRPC_SERVER_LIB_H_ - -#include "tensorflow/core/common_runtime/device_factory.h" -#include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h" -#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" -#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h" -#include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h" - -namespace tensorflow { -namespace eager { - -class EagerGrpcServer : public GrpcServer { - public: - static Status Create(const ServerDef& server_def, - std::unique_ptr<EagerGrpcServer>* server) { - std::unique_ptr<EagerGrpcServer> ret(new EagerGrpcServer(server_def)); - - TF_RETURN_IF_ERROR(ret->InitEager()); - - *server = std::move(ret); - - return Status::OK(); - } - - Status Start() override { - TF_RETURN_IF_ERROR(GrpcServer::Start()); - - eager_service_->Start(); - - return Status::OK(); - } - - Status Stop() override { - TF_RETURN_IF_ERROR(GrpcServer::Stop()); - - eager_service_->Stop(); - - return Status::OK(); - } - - using GrpcServer::channel_cache; - using GrpcServer::master_env; - using GrpcServer::worker_env; - - private: - EagerGrpcServer(const ServerDef& server_def) - : GrpcServer(server_def, Env::Default()), - worker_name_( - strings::StrCat("/job:", server_def.job_name(), - "/replica:0/task:", server_def.task_index())) {} - - Status InitEager() { - TF_RETURN_IF_ERROR(this->Init( - [this](const WorkerEnv* worker_env, - ::grpc::ServerBuilder* server_builder) { - this->eager_service_.reset( - new eager::GrpcEagerServiceImpl(worker_env, server_builder)); - }, - nullptr, nullptr)); - - worker_session_ = WorkerSession::CreateWithBorrowedDeviceMgr( - "", worker_name_, - std::unique_ptr<WorkerCacheInterface>( - new WorkerCacheWrapper(master_env()->worker_cache)), - worker_env()->device_mgr, {}); - - auto* r = worker_env()->rendezvous_mgr->Find(0); - return r->Initialize(worker_session_.get()); - } - - std::unique_ptr<GrpcEagerServiceImpl> eager_service_; - std::shared_ptr<WorkerSession> worker_session_; - const string worker_name_; -}; // namespace eager - -} // namespace eager -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_EAGER_GRPC_SERVER_LIB_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc index b36c6dce86..52e06c263d 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc @@ -18,10 +18,8 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_call.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h" #include "tensorflow/core/util/ptr_util.h" namespace tensorflow { @@ -36,7 +34,7 @@ GrpcEagerServiceImpl::GrpcEagerServiceImpl( cq_ = server_builder->AddCompletionQueue(); } -void GrpcEagerServiceImpl::DriveCQ() { +void GrpcEagerServiceImpl::HandleRPCsLoop() { #define ENQUEUE_REQUEST(method) \ do { \ Call<GrpcEagerServiceImpl, \ @@ -74,12 +72,7 @@ void GrpcEagerServiceImpl::DriveCQ() { } } -void GrpcEagerServiceImpl::Start() { - // TODO(nareshmodi) separate thread for driving CQ - request_handler_threadpool_->Schedule([this]() { DriveCQ(); }); -} - -void GrpcEagerServiceImpl::Stop() { +void GrpcEagerServiceImpl::Shutdown() { // This enqueues a special event (with a null tag) // that causes the completion queue to be shut down on the // polling thread. diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h index e94aedf535..9a94026342 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h +++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h @@ -20,16 +20,16 @@ limitations under the License. #include "grpcpp/completion_queue.h" #include "grpcpp/server_builder.h" #include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h" +#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h" #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_call.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" namespace tensorflow { namespace eager { // This class is a wrapper that handles communication for gRPC. -class GrpcEagerServiceImpl { +class GrpcEagerServiceImpl : public AsyncServiceInterface { public: template <class RequestMessage, class ResponseMessage> using EagerCall = Call<GrpcEagerServiceImpl, grpc::EagerService::AsyncService, @@ -39,8 +39,8 @@ class GrpcEagerServiceImpl { ::grpc::ServerBuilder* server_builder); virtual ~GrpcEagerServiceImpl() {} - void Start(); - void Stop(); + void HandleRPCsLoop() override; + void Shutdown() override; private: #define HANDLER(method) \ @@ -66,8 +66,6 @@ class GrpcEagerServiceImpl { EagerServiceImpl local_impl_; - void DriveCQ(); - std::unique_ptr<::grpc::Alarm> shutdown_alarm_; std::unique_ptr<::grpc::ServerCompletionQueue> cq_; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc index e7914740ae..aa334f9424 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/master_env.h" #include "tensorflow/core/distributed_runtime/master_session.h" #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h" +#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" @@ -42,6 +43,7 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" #include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h" #include "tensorflow/core/distributed_runtime/server_lib.h" +#include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h" #include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -81,6 +83,7 @@ GrpcServer::~GrpcServer() { delete master_service_; delete worker_service_; + delete eager_service_; // TODO(mrry): Refactor the *Env classes so that it is less fiddly // to destroy them. @@ -192,6 +195,8 @@ Status GrpcServer::Init( worker_func ? worker_func(&worker_env_) : NewGrpcWorker(&worker_env_); worker_service_ = NewGrpcWorkerService(worker_impl_.get(), &builder).release(); + eager_service_ = new eager::GrpcEagerServiceImpl(&worker_env_, &builder); + // extra service: if (service_func != nullptr) { service_func(&worker_env_, &builder); @@ -264,7 +269,15 @@ Status GrpcServer::Init( LocalMaster::Register(target(), master_impl_.get(), config.operation_timeout_in_ms()); - return Status::OK(); + // Generate a dummy worker session that is used to register the + // Rendezvous for eager (we use Step 0 for eager). + worker_session_ = WorkerSession::CreateWithBorrowedDeviceMgr( + "", name_prefix, + std::unique_ptr<WorkerCacheInterface>( + new WorkerCacheWrapper(master_env_.worker_cache)), + worker_env_.device_mgr, {}); + auto* r = worker_env()->rendezvous_mgr->Find(0); + return r->Initialize(worker_session_.get()); } Status GrpcServer::Init( @@ -365,6 +378,9 @@ Status GrpcServer::Start() { worker_thread_.reset( env_->StartThread(ThreadOptions(), "TF_worker_service", [this] { worker_service_->HandleRPCsLoop(); })); + eager_thread_.reset( + env_->StartThread(ThreadOptions(), "TF_eager_service", + [this] { eager_service_->HandleRPCsLoop(); })); state_ = STARTED; LOG(INFO) << "Started server with target: " << target(); return Status::OK(); @@ -407,6 +423,7 @@ Status GrpcServer::Join() { case STOPPED: master_thread_.reset(); worker_thread_.reset(); + eager_thread_.reset(); return Status::OK(); default: LOG(FATAL); @@ -443,6 +460,17 @@ Status GrpcServer::Create(const ServerDef& server_def, Env* env, return Status::OK(); } +/* static */ +Status GrpcServer::Create(const ServerDef& server_def, Env* env, + std::unique_ptr<GrpcServer>* out_server) { + std::unique_ptr<GrpcServer> ret( + new GrpcServer(server_def, env == nullptr ? Env::Default() : env)); + ServiceInitFunction service_func = nullptr; + TF_RETURN_IF_ERROR(ret->Init(service_func, NewRpcRendezvousMgr, nullptr)); + *out_server = std::move(ret); + return Status::OK(); +} + namespace { class GrpcServerFactory : public ServerFactory { diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h index 9e53330f85..115148b84e 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h @@ -63,6 +63,8 @@ class GrpcServer : public ServerInterface { public: static Status Create(const ServerDef& server_def, Env* env, std::unique_ptr<ServerInterface>* out_server); + static Status Create(const ServerDef& server_def, Env* env, + std::unique_ptr<GrpcServer>* out_server); // Destruction is only supported in the factory method. Clean // shutdown is not currently implemented for this server type. @@ -74,6 +76,11 @@ class GrpcServer : public ServerInterface { Status Join() override; const string target() const override; + WorkerEnv* worker_env() { return &worker_env_; } + MasterEnv* master_env() { return &master_env_; } + + std::shared_ptr<GrpcChannelCache> channel_cache() { return channel_cache_; } + protected: Status Init(ServiceInitFunction service_func, const RendezvousMgrCreationFunction& rendezvous_mgr_func, @@ -115,11 +122,6 @@ class GrpcServer : public ServerInterface { // This method may only be called after `this->Init()` returns successfully. int bound_port() const { return bound_port_; } - WorkerEnv* worker_env() { return &worker_env_; } - MasterEnv* master_env() { return &master_env_; } - - std::shared_ptr<GrpcChannelCache> channel_cache() { return channel_cache_; } - const ServerDef& server_def() const { return server_def_; } private: @@ -158,6 +160,11 @@ class GrpcServer : public ServerInterface { AsyncServiceInterface* worker_service_ = nullptr; std::unique_ptr<Thread> worker_thread_ GUARDED_BY(mu_); + // TensorFlow Eager implementation, and RPC polling thread. + AsyncServiceInterface* eager_service_ = nullptr; + std::unique_ptr<Thread> eager_thread_ GUARDED_BY(mu_); + std::shared_ptr<WorkerSession> worker_session_; + std::unique_ptr<::grpc::Server> server_ GUARDED_BY(mu_); }; |