author    A. Unique TensorFlower <gardener@tensorflow.org> 2016-07-17 12:22:18 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2016-07-17 13:33:33 -0700
commit e61fb28461cc5f7b09fdd2504fb6ac6ec5bf2c3d (patch)
tree   2460ef8d0d79ba0af4ea7e027e0b8fcd17cab432
parent 462892fc48a9db857ee0e126a254441a9a8fb163 (diff)
Automated rollback of change 127435209
Change: 127668670
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc       | 133
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h        |   6
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc  |  24
-rw-r--r--  tensorflow/core/distributed_runtime/worker_cache.h                  |  14
4 files changed, 33 insertions, 144 deletions
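
For context, a minimal sketch (not part of the patch) of the worker ownership contract this rollback restores: with WorkerFreeListCache and WorkerCacheInterface::ReleaseWorker removed, a WorkerInterface returned by CreateWorker() is owned by the caller and must be deleted directly, which is what RpcRecvTensorCall::Reset() does below. Only WorkerCacheInterface, WorkerInterface, and CreateWorker come from this code; the function name and error handling are placeholders.

    // Sketch only: illustrates caller-owned workers after this rollback.
    // RecvOnce and its error handling are hypothetical placeholders.
    void RecvOnce(WorkerCacheInterface* worker_cache, const string& src_worker) {
      WorkerInterface* wi = worker_cache->CreateWorker(src_worker);  // caller takes ownership
      if (wi == nullptr) {
        // No worker known for this target; surface an error to the caller.
        return;
      }
      // ... issue the RecvTensor RPC through wi ...
      delete wi;  // no ReleaseWorker() to return it to after this change
    }
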
diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc
index 92ec9e77bc..96f7db2694 100644
--- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc
+++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc
@@ -37,9 +37,8 @@ namespace {
class RpcRemoteRendezvous : public BaseRemoteRendezvous {
public:
- RpcRemoteRendezvous(const WorkerEnv* env, WorkerCacheInterface* cache,
- int64 step_id)
- : BaseRemoteRendezvous(env, step_id, false), cache_(cache) {}
+ RpcRemoteRendezvous(const WorkerEnv* env, int64 step_id)
+ : BaseRemoteRendezvous(env, step_id, false) {}
protected:
void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
@@ -49,7 +48,6 @@ class RpcRemoteRendezvous : public BaseRemoteRendezvous {
private:
~RpcRemoteRendezvous() override {}
- WorkerCacheInterface* cache_; // Not owned.
TF_DISALLOW_COPY_AND_ASSIGN(RpcRemoteRendezvous);
};
@@ -57,12 +55,13 @@ class RpcRemoteRendezvous : public BaseRemoteRendezvous {
class RpcRecvTensorCall : public BaseRecvTensorCall {
public:
RpcRecvTensorCall()
- : wi_(nullptr), allocator_(nullptr), dst_device_(nullptr) {}
+ : wi_(nullptr), wc_(nullptr), allocator_(nullptr), dst_device_(nullptr) {}
- void Init(WorkerInterface* wi, int64 step_id, StringPiece key,
- Allocator* allocator, Device* dst_device,
+ void Init(WorkerCacheInterface* wc, WorkerInterface* wi, int64 step_id,
+ StringPiece key, Allocator* allocator, Device* dst_device,
const Rendezvous::Args& recv_args, Rendezvous::DoneCallback done) {
wi_ = wi;
+ wc_ = wc;
allocator_ = allocator;
dst_device_ = dst_device;
recv_args_ = recv_args;
@@ -74,6 +73,7 @@ class RpcRecvTensorCall : public BaseRecvTensorCall {
void Reset() {
delete wi_;
wi_ = nullptr;
+ wc_ = nullptr;
allocator_ = nullptr;
dst_device_ = nullptr;
// We don't clear opts_ and assume that Init will set up the state for
@@ -123,8 +123,6 @@ class RpcRecvTensorCall : public BaseRecvTensorCall {
const Rendezvous::DoneCallback& done() const { return done_; }
private:
- friend class RpcRemoteRendezvous;
-
// Start the main RecvTensor call, checking for an async abort.
void StartRTCall(std::function<void()> recv_done) {
wi_->RecvTensorAsync(&opts_, &req_, &resp_,
@@ -139,9 +137,8 @@ class RpcRecvTensorCall : public BaseRecvTensorCall {
});
}
- string src_worker_;
- string src_rel_device_;
- WorkerInterface* wi_;
+ WorkerInterface* wi_; // Owned.
+ WorkerCacheInterface* wc_; // Not owned.
Allocator* allocator_;
Device* dst_device_;
CallOptions opts_;
@@ -156,6 +153,7 @@ class RpcRecvTensorCall : public BaseRecvTensorCall {
TF_DISALLOW_COPY_AND_ASSIGN(RpcRecvTensorCall);
};
+namespace {
class RpcRecvTensorFreeList {
public:
RpcRecvTensorFreeList() {}
@@ -197,99 +195,32 @@ class RpcRecvTensorFreeList {
};
static RpcRecvTensorFreeList call_freelist_;
-
-// A private cache that wraps env->worker_cache and allows reuse of
-// WorkerInterface objects.
-class WorkerFreeListCache : public WorkerCacheInterface {
- public:
- explicit WorkerFreeListCache(WorkerCacheInterface* w) : wrapped_(w) {}
-
- ~WorkerFreeListCache() {
- for (auto p : workers_) {
- delete p.second.worker;
- }
- }
-
- void ListWorkers(std::vector<string>* workers) override {
- wrapped_->ListWorkers(workers);
- }
-
- WorkerInterface* CreateWorker(const string& target) override {
- mutex_lock l(mu_);
- auto p = workers_.find(target);
- if (p != workers_.end()) {
- return p->second.worker;
- }
- WorkerState state;
- state.worker = wrapped_->CreateWorker(target);
- if (state.worker != nullptr) {
- workers_.insert(make_pair(target, state));
- }
- return state.worker;
- }
-
- void ReleaseWorker(const string& target, WorkerInterface* worker) override {
- // TODO(jeff,sanjay): Should decrement ref-count when we implement eviction.
- }
-
- bool GetDeviceBusNonBlocking(const string& device,
- BusAdjacency* ba) override {
- return wrapped_->GetDeviceBusNonBlocking(device, ba);
- }
-
- void GetDeviceBusAsync(const string& device, BusAdjacency* ba,
- StatusCallback done) override {
- wrapped_->GetDeviceBusAsync(device, ba, done);
- }
-
- void SetLogging(bool active) override { wrapped_->SetLogging(active); }
-
- void ClearLogs() override { wrapped_->ClearLogs(); }
-
- bool RetrieveLogs(int64 step_id, StepStats* ss) override {
- return wrapped_->RetrieveLogs(step_id, ss);
- }
-
- private:
- WorkerCacheInterface* wrapped_;
-
- // Information kept per created WorkerInterface.
- struct WorkerState {
- WorkerInterface* worker;
- // TODO(jeff,sanjay): Add reference count if we support eviction.
- };
-
- // TODO(jeff,sanjay): Eviction when the map becomes too big.
- mutex mu_;
- std::unordered_map<string, WorkerState> workers_ GUARDED_BY(mu_);
-};
+}
void RpcRemoteRendezvous::RecvFromRemoteAsync(
const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args,
DoneCallback done) {
Status s;
+ // key.src_device identifies a remote device.
+ string src_worker;
+ string src_rel_device;
+ if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &src_worker,
+ &src_rel_device)) {
+ s = errors::Internal(parsed.src_device,
+ " is invalid remote source device.");
+ }
// TODO(jeff): Consider checking for a valid worker_cache during the
// constructor of RpcRemoteRendezvous, rather than here, to simplify
// the twisty logic below.
- if (env_->worker_cache == nullptr) {
+ WorkerCacheInterface* worker_cache = env_->worker_cache;
+ if (s.ok() && worker_cache == nullptr) {
s = errors::Internal("No remote worker cache available.");
- done(s, Args(), recv_args, Tensor{}, false);
- return;
}
-
- // Prepare a RecvTensor call that can handle being aborted.
- RpcRecvTensorCall* call = call_freelist_.New();
-
- // key.src_device identifies a remote device.
- if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &call->src_worker_,
- &call->src_rel_device_)) {
- s = errors::Internal(parsed.src_device,
- " is invalid remote source device.");
- }
- WorkerInterface* rwi = cache_->CreateWorker(call->src_worker_);
+ WorkerInterface* rwi =
+ (worker_cache ? worker_cache->CreateWorker(src_worker) : nullptr);
if (s.ok() && rwi == nullptr) {
- s = errors::Internal("No worker known as ", call->src_worker_);
+ s = errors::Internal("No worker known as ", src_worker);
}
Device* dst_device;
@@ -297,14 +228,16 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
}
if (!s.ok()) {
- call_freelist_.Release(call);
done(s, Args(), recv_args, Tensor{}, false);
return;
}
Allocator* allocator = dst_device->GetAllocator(recv_args.alloc_attrs);
- call->Init(rwi, step_id_, parsed.FullKey(), allocator, dst_device, recv_args,
- std::move(done));
+ // Prepare a RecvTensor call that can handle being aborted.
+ RpcRecvTensorCall* call = call_freelist_.New();
+
+ call->Init(worker_cache, rwi, step_id_, parsed.FullKey(), allocator,
+ dst_device, recv_args, std::move(done));
// Record "call" in active_ so that it can be aborted cleanly.
RegisterCall(call);
@@ -322,21 +255,15 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
call->tensor_proto(), call->recv_args().alloc_attrs, &val);
}
call->done()(s, Args(), call->recv_args(), val, call->is_dead());
- cache_->ReleaseWorker(call->src_worker_, call->wi_);
- call->wi_ = nullptr;
call_freelist_.Release(call);
});
}
} // namespace
-RpcRendezvousMgr::RpcRendezvousMgr(const WorkerEnv* env)
- : BaseRendezvousMgr(env),
- cache_(new WorkerFreeListCache(env->worker_cache)) {}
-
BaseRemoteRendezvous* RpcRendezvousMgr::Create(int64 step_id,
const WorkerEnv* worker_env) {
- return new RpcRemoteRendezvous(worker_env, cache_.get(), step_id);
+ return new RpcRemoteRendezvous(worker_env, step_id);
}
} // end namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h
index 6a65d04ba4..7447c94c39 100644
--- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h
+++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h
@@ -17,7 +17,6 @@ limitations under the License.
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_RPC_RENDEZVOUS_MGR_H_
#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
-#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/platform/macros.h"
@@ -43,16 +42,13 @@ namespace tensorflow {
// RendezvousMgr must have keys generated by Rendezvous::CreateKey.
class RpcRendezvousMgr : public BaseRendezvousMgr {
public:
- explicit RpcRendezvousMgr(const WorkerEnv* env);
+ explicit RpcRendezvousMgr(const WorkerEnv* env) : BaseRendezvousMgr(env) {}
protected:
BaseRemoteRendezvous* Create(int64 step_id,
const WorkerEnv* worker_env) override;
private:
- // Private cache_ that allows us to reuse WorkerInterface objects.
- std::unique_ptr<WorkerCacheInterface> cache_;
-
TF_DISALLOW_COPY_AND_ASSIGN(RpcRendezvousMgr);
};
diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc
index dce49d33d7..7e18278f30 100644
--- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc
@@ -46,28 +46,10 @@ Rendezvous::ParsedKey MakeKey(const string& s) {
return key;
}
-namespace {
-// Fake cache implementation for WorkerEnv.
-class DummyWorkerCache : public WorkerCacheInterface {
- void ListWorkers(std::vector<string>* workers) override {}
- WorkerInterface* CreateWorker(const string& target) override {
- return nullptr;
- }
- bool GetDeviceBusNonBlocking(const string& device,
- BusAdjacency* ba) override {
- return false;
- }
- void GetDeviceBusAsync(const string& device, BusAdjacency* ba,
- StatusCallback done) override {}
-};
-} // namespace
-
TEST(RpcRendezvousMgrTest, LocalSendRecv) {
- DummyWorkerCache cache;
WorkerEnv env;
env.env = Env::Default();
env.worker_name = "/job:mnist/replica:1/task:2";
- env.worker_cache = &cache;
RpcRendezvousMgr rmgr(&env);
const int64 step_id = 123;
const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey(
@@ -89,11 +71,9 @@ TEST(RpcRendezvousMgrTest, LocalSendRecv) {
}
TEST(RpcRendezvousMgrTest, LocalAbort) {
- DummyWorkerCache cache;
WorkerEnv env;
env.env = Env::Default();
env.worker_name = "/job:mnist/replica:1/task:2";
- env.worker_cache = &cache;
RpcRendezvousMgr rmgr(&env);
const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey(
"/job:mnist/replica:1/task:2/cpu:0", 7890,
@@ -127,11 +107,9 @@ TEST(RpcRendezvousMgrTest, LocalAbort) {
}
TEST(RpcRendezvousMgrTest, CleanupAll) {
- DummyWorkerCache cache;
WorkerEnv env;
env.env = Env::Default();
env.worker_name = "/job:mnist/replica:1/task:2";
- env.worker_cache = &cache;
RpcRendezvousMgr rmgr(&env);
const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey(
"/job:mnist/replica:1/task:2/cpu:0", 7890,
@@ -162,11 +140,9 @@ class DummyDeviceContext : public DeviceContext {
TEST(RpcRendezvousMgrTest, TransferDummyDeviceContext) {
DummyDeviceContext* dc = new DummyDeviceContext(123);
- DummyWorkerCache cache;
WorkerEnv env;
env.env = Env::Default();
env.worker_name = "/job:mnist/replica:1/task:2";
- env.worker_cache = &cache;
RpcRendezvousMgr rmgr(&env);
const int64 step_id = 123;
const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey(
diff --git a/tensorflow/core/distributed_runtime/worker_cache.h b/tensorflow/core/distributed_runtime/worker_cache.h
index c46c056136..3efe14998f 100644
--- a/tensorflow/core/distributed_runtime/worker_cache.h
+++ b/tensorflow/core/distributed_runtime/worker_cache.h
@@ -19,7 +19,7 @@ limitations under the License.
#include <string>
#include <vector>
-#include "tensorflow/core/distributed_runtime/worker_interface.h"
+#include "tensorflow/core/distributed_runtime/worker_interface.h" // for CallOptions
#include "tensorflow/core/framework/device_attributes.pb.h" // for BusAdjacency
#include "tensorflow/core/lib/core/status.h"
@@ -28,6 +28,7 @@ typedef std::function<void(const Status&)> StatusCallback;
class ChannelCache;
class StepStats;
+class WorkerInterface;
class WorkerCacheInterface {
public:
@@ -45,17 +46,6 @@ class WorkerCacheInterface {
// ownership, not a cache lookup.
virtual WorkerInterface* CreateWorker(const string& target) = 0;
- // Release a worker previously returned by this->CreateWorker(target).
- //
- // TODO(jeff,sanjay): Consider moving target into WorkerInterface.
- // TODO(jeff,sanjay): Consider disallowing direct deletion of WorkerInterface.
- // TODO(jeff,sanjay): Unify all worker-cache impls and factor out a
- // per-rpc-subsystem WorkerInterface creator.
- virtual void ReleaseWorker(const string& target, WorkerInterface* worker) {
- // Subclasses may override to reuse worker objects.
- delete worker;
- }
-
// Set *ba with the BusAdjacency of the specified remote device
// within its local environment. Returns true if the device bus
// affinity was set, using only locally cached data. Returns false