From e61fb28461cc5f7b09fdd2504fb6ac6ec5bf2c3d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 17 Jul 2016 12:22:18 -0800 Subject: Automated rollback of change 127435209 Change: 127668670 --- .../distributed_runtime/rpc/rpc_rendezvous_mgr.cc | 133 +++++---------------- .../distributed_runtime/rpc/rpc_rendezvous_mgr.h | 6 +- .../rpc/rpc_rendezvous_mgr_test.cc | 24 ---- tensorflow/core/distributed_runtime/worker_cache.h | 14 +-- 4 files changed, 33 insertions(+), 144 deletions(-) diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc index 92ec9e77bc..96f7db2694 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc @@ -37,9 +37,8 @@ namespace { class RpcRemoteRendezvous : public BaseRemoteRendezvous { public: - RpcRemoteRendezvous(const WorkerEnv* env, WorkerCacheInterface* cache, - int64 step_id) - : BaseRemoteRendezvous(env, step_id, false), cache_(cache) {} + RpcRemoteRendezvous(const WorkerEnv* env, int64 step_id) + : BaseRemoteRendezvous(env, step_id, false) {} protected: void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed, @@ -49,7 +48,6 @@ class RpcRemoteRendezvous : public BaseRemoteRendezvous { private: ~RpcRemoteRendezvous() override {} - WorkerCacheInterface* cache_; // Not owned. TF_DISALLOW_COPY_AND_ASSIGN(RpcRemoteRendezvous); }; @@ -57,12 +55,13 @@ class RpcRemoteRendezvous : public BaseRemoteRendezvous { class RpcRecvTensorCall : public BaseRecvTensorCall { public: RpcRecvTensorCall() - : wi_(nullptr), allocator_(nullptr), dst_device_(nullptr) {} + : wi_(nullptr), wc_(nullptr), allocator_(nullptr), dst_device_(nullptr) {} - void Init(WorkerInterface* wi, int64 step_id, StringPiece key, - Allocator* allocator, Device* dst_device, + void Init(WorkerCacheInterface* wc, WorkerInterface* wi, int64 step_id, + StringPiece key, Allocator* allocator, Device* dst_device, const Rendezvous::Args& recv_args, Rendezvous::DoneCallback done) { wi_ = wi; + wc_ = wc; allocator_ = allocator; dst_device_ = dst_device; recv_args_ = recv_args; @@ -74,6 +73,7 @@ class RpcRecvTensorCall : public BaseRecvTensorCall { void Reset() { delete wi_; wi_ = nullptr; + wc_ = nullptr; allocator_ = nullptr; dst_device_ = nullptr; // We don't clear opts_ and assume that Init will set up the state for @@ -123,8 +123,6 @@ class RpcRecvTensorCall : public BaseRecvTensorCall { const Rendezvous::DoneCallback& done() const { return done_; } private: - friend class RpcRemoteRendezvous; - // Start the main RecvTensor call, checking for an async abort. void StartRTCall(std::function recv_done) { wi_->RecvTensorAsync(&opts_, &req_, &resp_, @@ -139,9 +137,8 @@ class RpcRecvTensorCall : public BaseRecvTensorCall { }); } - string src_worker_; - string src_rel_device_; - WorkerInterface* wi_; + WorkerInterface* wi_; // Owned. + WorkerCacheInterface* wc_; // Not owned. Allocator* allocator_; Device* dst_device_; CallOptions opts_; @@ -156,6 +153,7 @@ class RpcRecvTensorCall : public BaseRecvTensorCall { TF_DISALLOW_COPY_AND_ASSIGN(RpcRecvTensorCall); }; +namespace { class RpcRecvTensorFreeList { public: RpcRecvTensorFreeList() {} @@ -197,99 +195,32 @@ class RpcRecvTensorFreeList { }; static RpcRecvTensorFreeList call_freelist_; - -// A private cache that wraps env->worker_cache and allows reuse of -// WorkerInterface objects. -class WorkerFreeListCache : public WorkerCacheInterface { - public: - explicit WorkerFreeListCache(WorkerCacheInterface* w) : wrapped_(w) {} - - ~WorkerFreeListCache() { - for (auto p : workers_) { - delete p.second.worker; - } - } - - void ListWorkers(std::vector* workers) override { - wrapped_->ListWorkers(workers); - } - - WorkerInterface* CreateWorker(const string& target) override { - mutex_lock l(mu_); - auto p = workers_.find(target); - if (p != workers_.end()) { - return p->second.worker; - } - WorkerState state; - state.worker = wrapped_->CreateWorker(target); - if (state.worker != nullptr) { - workers_.insert(make_pair(target, state)); - } - return state.worker; - } - - void ReleaseWorker(const string& target, WorkerInterface* worker) override { - // TODO(jeff,sanjay): Should decrement ref-count when we implement eviction. - } - - bool GetDeviceBusNonBlocking(const string& device, - BusAdjacency* ba) override { - return wrapped_->GetDeviceBusNonBlocking(device, ba); - } - - void GetDeviceBusAsync(const string& device, BusAdjacency* ba, - StatusCallback done) override { - wrapped_->GetDeviceBusAsync(device, ba, done); - } - - void SetLogging(bool active) override { wrapped_->SetLogging(active); } - - void ClearLogs() override { wrapped_->ClearLogs(); } - - bool RetrieveLogs(int64 step_id, StepStats* ss) override { - return wrapped_->RetrieveLogs(step_id, ss); - } - - private: - WorkerCacheInterface* wrapped_; - - // Information kept per created WorkerInterface. - struct WorkerState { - WorkerInterface* worker; - // TODO(jeff,sanjay): Add reference count if we support eviction. - }; - - // TODO(jeff,sanjay): Eviction when the map becomes too big. - mutex mu_; - std::unordered_map workers_ GUARDED_BY(mu_); -}; +} void RpcRemoteRendezvous::RecvFromRemoteAsync( const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args, DoneCallback done) { Status s; + // key.src_device identifies a remote device. + string src_worker; + string src_rel_device; + if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &src_worker, + &src_rel_device)) { + s = errors::Internal(parsed.src_device, + " is invalid remote source device."); + } // TODO(jeff): Consider checking for a valid worker_cache during the // constructor of RpcRemoteRendezvous, rather than here, to simplify // the twisty logic below. - if (env_->worker_cache == nullptr) { + WorkerCacheInterface* worker_cache = env_->worker_cache; + if (s.ok() && worker_cache == nullptr) { s = errors::Internal("No remote worker cache available."); - done(s, Args(), recv_args, Tensor{}, false); - return; } - - // Prepare a RecvTensor call that can handle being aborted. - RpcRecvTensorCall* call = call_freelist_.New(); - - // key.src_device identifies a remote device. - if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &call->src_worker_, - &call->src_rel_device_)) { - s = errors::Internal(parsed.src_device, - " is invalid remote source device."); - } - WorkerInterface* rwi = cache_->CreateWorker(call->src_worker_); + WorkerInterface* rwi = + (worker_cache ? worker_cache->CreateWorker(src_worker) : nullptr); if (s.ok() && rwi == nullptr) { - s = errors::Internal("No worker known as ", call->src_worker_); + s = errors::Internal("No worker known as ", src_worker); } Device* dst_device; @@ -297,14 +228,16 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync( s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device); } if (!s.ok()) { - call_freelist_.Release(call); done(s, Args(), recv_args, Tensor{}, false); return; } Allocator* allocator = dst_device->GetAllocator(recv_args.alloc_attrs); - call->Init(rwi, step_id_, parsed.FullKey(), allocator, dst_device, recv_args, - std::move(done)); + // Prepare a RecvTensor call that can handle being aborted. + RpcRecvTensorCall* call = call_freelist_.New(); + + call->Init(worker_cache, rwi, step_id_, parsed.FullKey(), allocator, + dst_device, recv_args, std::move(done)); // Record "call" in active_ so that it can be aborted cleanly. RegisterCall(call); @@ -322,21 +255,15 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync( call->tensor_proto(), call->recv_args().alloc_attrs, &val); } call->done()(s, Args(), call->recv_args(), val, call->is_dead()); - cache_->ReleaseWorker(call->src_worker_, call->wi_); - call->wi_ = nullptr; call_freelist_.Release(call); }); } } // namespace -RpcRendezvousMgr::RpcRendezvousMgr(const WorkerEnv* env) - : BaseRendezvousMgr(env), - cache_(new WorkerFreeListCache(env->worker_cache)) {} - BaseRemoteRendezvous* RpcRendezvousMgr::Create(int64 step_id, const WorkerEnv* worker_env) { - return new RpcRemoteRendezvous(worker_env, cache_.get(), step_id); + return new RpcRemoteRendezvous(worker_env, step_id); } } // end namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h index 6a65d04ba4..7447c94c39 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h @@ -17,7 +17,6 @@ limitations under the License. #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_RPC_RENDEZVOUS_MGR_H_ #include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h" -#include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/platform/macros.h" @@ -43,16 +42,13 @@ namespace tensorflow { // RendezvousMgr must have keys generated by Rendezvous::CreateKey. class RpcRendezvousMgr : public BaseRendezvousMgr { public: - explicit RpcRendezvousMgr(const WorkerEnv* env); + explicit RpcRendezvousMgr(const WorkerEnv* env) : BaseRendezvousMgr(env) {} protected: BaseRemoteRendezvous* Create(int64 step_id, const WorkerEnv* worker_env) override; private: - // Private cache_ that allows us to reuse WorkerInterface objects. - std::unique_ptr cache_; - TF_DISALLOW_COPY_AND_ASSIGN(RpcRendezvousMgr); }; diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc index dce49d33d7..7e18278f30 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc @@ -46,28 +46,10 @@ Rendezvous::ParsedKey MakeKey(const string& s) { return key; } -namespace { -// Fake cache implementation for WorkerEnv. -class DummyWorkerCache : public WorkerCacheInterface { - void ListWorkers(std::vector* workers) override {} - WorkerInterface* CreateWorker(const string& target) override { - return nullptr; - } - bool GetDeviceBusNonBlocking(const string& device, - BusAdjacency* ba) override { - return false; - } - void GetDeviceBusAsync(const string& device, BusAdjacency* ba, - StatusCallback done) override {} -}; -} // namespace - TEST(RpcRendezvousMgrTest, LocalSendRecv) { - DummyWorkerCache cache; WorkerEnv env; env.env = Env::Default(); env.worker_name = "/job:mnist/replica:1/task:2"; - env.worker_cache = &cache; RpcRendezvousMgr rmgr(&env); const int64 step_id = 123; const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey( @@ -89,11 +71,9 @@ TEST(RpcRendezvousMgrTest, LocalSendRecv) { } TEST(RpcRendezvousMgrTest, LocalAbort) { - DummyWorkerCache cache; WorkerEnv env; env.env = Env::Default(); env.worker_name = "/job:mnist/replica:1/task:2"; - env.worker_cache = &cache; RpcRendezvousMgr rmgr(&env); const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey( "/job:mnist/replica:1/task:2/cpu:0", 7890, @@ -127,11 +107,9 @@ TEST(RpcRendezvousMgrTest, LocalAbort) { } TEST(RpcRendezvousMgrTest, CleanupAll) { - DummyWorkerCache cache; WorkerEnv env; env.env = Env::Default(); env.worker_name = "/job:mnist/replica:1/task:2"; - env.worker_cache = &cache; RpcRendezvousMgr rmgr(&env); const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey( "/job:mnist/replica:1/task:2/cpu:0", 7890, @@ -162,11 +140,9 @@ class DummyDeviceContext : public DeviceContext { TEST(RpcRendezvousMgrTest, TransferDummyDeviceContext) { DummyDeviceContext* dc = new DummyDeviceContext(123); - DummyWorkerCache cache; WorkerEnv env; env.env = Env::Default(); env.worker_name = "/job:mnist/replica:1/task:2"; - env.worker_cache = &cache; RpcRendezvousMgr rmgr(&env); const int64 step_id = 123; const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey( diff --git a/tensorflow/core/distributed_runtime/worker_cache.h b/tensorflow/core/distributed_runtime/worker_cache.h index c46c056136..3efe14998f 100644 --- a/tensorflow/core/distributed_runtime/worker_cache.h +++ b/tensorflow/core/distributed_runtime/worker_cache.h @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/core/distributed_runtime/worker_interface.h" +#include "tensorflow/core/distributed_runtime/worker_interface.h" // for CallOptions #include "tensorflow/core/framework/device_attributes.pb.h" // for BusAdjacency #include "tensorflow/core/lib/core/status.h" @@ -28,6 +28,7 @@ typedef std::function StatusCallback; class ChannelCache; class StepStats; +class WorkerInterface; class WorkerCacheInterface { public: @@ -45,17 +46,6 @@ class WorkerCacheInterface { // ownership, not a cache lookup. virtual WorkerInterface* CreateWorker(const string& target) = 0; - // Release a worker previously returned by this->CreateWorker(target). - // - // TODO(jeff,sanjay): Consider moving target into WorkerInterface. - // TODO(jeff,sanjay): Consider disallowing direct deletion of WorkerInterface. - // TODO(jeff,sanjay): Unify all worker-cache impls and factor out a - // per-rpc-subsystem WorkerInterface creator. - virtual void ReleaseWorker(const string& target, WorkerInterface* worker) { - // Subclasses may override to reuse worker objects. - delete worker; - } - // Set *ba with the BusAdjacency of the specified remote device // within its local environment. Returns true if the device bus // affinity was set, using only locally cached data. Returns false -- cgit v1.2.3