about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
authorGravatar Suharsh Sivakumar <suharshs@google.com>2017-03-15 16:20:51 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-15 17:44:33 -0700
commit7b03171d6e50216fc7fdff9a2502a6af660291dd (patch)
treed7ff375e9d84c02ddc29de18f4c9affcb2c37966
parent90d964f3382faf30d291aa3fbdb509844e1f042a (diff)
Ensure that partial run doesn't block any threads on the worker compute_pool.
Change: 150265300
-rw-r--r--tensorflow/core/distributed_runtime/graph_mgr.cc90
-rw-r--r--tensorflow/core/distributed_runtime/graph_mgr.h4
-rw-r--r--tensorflow/core/distributed_runtime/worker.cc111
-rw-r--r--tensorflow/core/distributed_runtime/worker.h11
4 files changed, 172 insertions(+), 44 deletions(-)
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc
index 8afbd18a82..f07d7fcc14 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.cc
+++ b/tensorflow/core/distributed_runtime/graph_mgr.cc
@@ -326,6 +326,69 @@ Status GraphMgr::RecvOutputsFromRendezvous(Rendezvous* rendezvous,
return Status::OK();
}
+void GraphMgr::RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous,
+ NamedTensors* out,
+ const StatusCallback& done) {
+ if (out->empty()) {
+ done(Status::OK());
+ return;
+ }
+ // We compute the args before calling RecvAsync because we need to ensure that
+ // out isn't being iterated over after done is called, since done deletes out.
+ std::vector<std::tuple<string, Tensor*, Rendezvous::ParsedKey>> args;
+ for (auto& p : *out) {
+ Rendezvous::ParsedKey parsed;
+ Status s = Rendezvous::ParseKey(p.first, &parsed);
+ if (!s.ok()) {
+ done(s);
+ return;
+ }
+ args.push_back(std::make_tuple(p.first, &p.second, parsed));
+ }
+
+ typedef struct {
+ mutex mu;
+ int done_counter;
+ Status shared_status = Status::OK();
+ } CallState;
+ CallState* call_state = new CallState;
+ call_state->done_counter = out->size();
+ for (auto& p : args) {
+ const string& key = std::get<0>(p);
+ Tensor* val = std::get<1>(p);
+ Rendezvous::ParsedKey parsed = std::get<2>(p);
+ rendezvous->RecvAsync(
+ parsed, Rendezvous::Args(),
+ [val, done, key, call_state](const Status& s,
+ const Rendezvous::Args& send_args,
+ const Rendezvous::Args& recv_args,
+ const Tensor& v, const bool is_dead) {
+ Status status = s;
+ if (status.ok()) {
+ *val = v;
+ if (is_dead) {
+ status = errors::InvalidArgument("The tensor returned for ", key,
+ " was not valid.");
+ }
+ }
+ call_state->mu.lock();
+ if (status.ok()) {
+ call_state->shared_status = status;
+ }
+ call_state->done_counter--;
+ // If we are the last async call to return, call the done callback.
+ if (call_state->done_counter == 0) {
+ const Status& final_status = call_state->shared_status;
+ call_state->mu.unlock();
+ done(final_status);
+ delete call_state;
+ return;
+ }
+ call_state->mu.unlock();
+ });
+ }
+}
+
Status GraphMgr::SendInputs(const int64 step_id, const NamedTensors& in) {
Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
Status s = SendInputsToRendezvous(rendezvous, in);
@@ -340,6 +403,16 @@ Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) {
return s;
}
+void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out,
+ StatusCallback done) {
+ Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
+ RecvOutputsFromRendezvousAsync(rendezvous, out,
+ [done, rendezvous](const Status s) {
+ rendezvous->Unref();
+ done(s);
+ });
+}
+
void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
const ExecutorOpts& opts,
StepStatsCollector* collector,
@@ -395,13 +468,14 @@ void GraphMgr::StartParallelExecutors(const string& handle, int64 step_id,
worker_env_->device_mgr->ClearContainers({name});
});
// NOTE: Transfer one ref of rendezvous and item.
- ExecutorBarrier* barrier = new ExecutorBarrier(
- num_units, rendezvous, [this, item, collector, cost_graph, step_container,
- done](const Status& s) {
- BuildCostModel(item, collector, cost_graph);
- done(s);
- delete step_container;
- });
+ ExecutorBarrier* barrier =
+ new ExecutorBarrier(num_units, rendezvous,
+ [this, item, collector, cost_graph, step_container,
+ done](const Status& s) {
+ BuildCostModel(item, collector, cost_graph);
+ done(s);
+ delete step_container;
+ });
Executor::Args args;
{
mutex_lock l(mu_);
@@ -416,7 +490,7 @@ void GraphMgr::StartParallelExecutors(const string& handle, int64 step_id,
LogMemory::RecordStep(args.step_id, handle);
}
thread::ThreadPool* pool = worker_env_->compute_pool;
- using namespace std::placeholders;
+ using std::placeholders::_1;
// Line below is equivalent to this code, but does one less indirect call:
// args.runner = [pool](std::function<void()> fn) { pool->Schedule(fn); };
args.runner = std::bind(&thread::ThreadPool::Schedule, pool, _1);
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.h b/tensorflow/core/distributed_runtime/graph_mgr.h
index e9b8d415ed..18013aa91e 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.h
+++ b/tensorflow/core/distributed_runtime/graph_mgr.h
@@ -81,6 +81,8 @@ class GraphMgr {
Status SendInputs(const int64 step_id, const NamedTensors& in);
Status RecvOutputs(const int64 step_id, NamedTensors* out);
+ void RecvOutputsAsync(const int64 step_id, NamedTensors* out,
+ StatusCallback done);
// Deregisters a graph.
Status Deregister(const string& handle);
@@ -156,6 +158,8 @@ class GraphMgr {
Status SendInputsToRendezvous(Rendezvous* rendezvous, const NamedTensors& in);
Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out);
+ void RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous, NamedTensors* out,
+ const StatusCallback& done);
Status InitItem(const string& session, const GraphDef& gdef,
const GraphOptions& graph_options, Item* item);
diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc
index 0d6ccceef0..654bd4d93f 100644
--- a/tensorflow/core/distributed_runtime/worker.cc
+++ b/tensorflow/core/distributed_runtime/worker.cc
@@ -57,7 +57,7 @@ void Worker::DeregisterGraphAsync(const DeregisterGraphRequest* request,
Worker::PartialRunState* Worker::FindPartialRun(const string& graph_handle,
int step_id) {
- std::pair<string, int> k(graph_handle, step_id);
+ const std::pair<string, int> k(graph_handle, step_id);
Worker::PartialRunState* prun_state = nullptr;
mutex_lock l(mu_);
auto it = partial_runs_.find(k);
@@ -70,17 +70,59 @@ Worker::PartialRunState* Worker::FindPartialRun(const string& graph_handle,
void Worker::InsertPartialRunLocked(const string& graph_handle, int step_id,
Worker::PartialRunState* partial_run_state)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- std::pair<string, int> k(graph_handle, step_id);
+ const std::pair<string, int> k(graph_handle, step_id);
partial_runs_.emplace(std::make_pair(
k, std::unique_ptr<Worker::PartialRunState>(partial_run_state)));
}
void Worker::RemovePartialRun(const string& graph_handle, int step_id) {
- std::pair<string, int> k(graph_handle, step_id);
+ const std::pair<string, int> k(graph_handle, step_id);
mutex_lock l(mu_);
partial_runs_.erase(partial_runs_.find(k));
}
+void Worker::MaybeCallFinalCallback(const string& graph_handle, int step_id,
+ const Status& executor_status) {
+ const std::pair<string, int> k(graph_handle, step_id);
+ StatusCallback done;
+ Status s;
+ {
+ mutex_lock l(mu_);
+ auto it = partial_runs_.find(k);
+ if (it != partial_runs_.end()) {
+ // If we found the partial_run, we call the final callback, if it
+ // exists.
+ std::swap(done, it->second->final_callback);
+ s = it->second->final_status;
+ it->second->executor_done = true;
+ }
+ }
+ if (done != nullptr) {
+ if (s.ok()) {
+ s = executor_status;
+ }
+ done(s);
+ }
+}
+
+void Worker::SetOrCallFinalCallback(const string& graph_handle, int step_id,
+ StatusCallback done, const Status& s) {
+ const std::pair<string, int> k(graph_handle, step_id);
+ {
+ mutex_lock l(mu_);
+ auto it = partial_runs_.find(k);
+ if (!it->second->executor_done) {
+ // If we found the partial_run, we set the final callback to call only
+ // when the executor is completely done.
+ it->second->final_callback = std::move(done);
+ it->second->final_status = s;
+ return;
+ }
+ }
+ // Otherwise we call the callback immediately.
+ done(s);
+}
+
void Worker::AbortStep(int64 step_id) {
Rendezvous* rendez = env_->rendezvous_mgr->Find(step_id);
SchedNonBlockingClosureAfter(1000000, [rendez, step_id]() {
@@ -205,7 +247,8 @@ void Worker::DoPartialRunGraph(CallOptions* opts,
GraphMgr::NamedTensors in;
GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
Status s = PrepareRunGraph(request, &in, out);
- auto finish = [this, done, out](const Status& s) {
+ auto finish = [this, done, out, opts](const Status& s) {
+ opts->ClearCancelCallback();
delete out;
done(s);
};
@@ -249,14 +292,13 @@ void Worker::DoPartialRunGraph(CallOptions* opts,
env_->graph_mgr->ExecuteAsync(
graph_handle, step_id, request->exec_opts(), nullptr /* collector */,
nullptr /* cost_graph */, cm, in,
- [this, step_id, graph_handle, token, partial_run_state](Status s) {
+ [this, token, graph_handle, step_id, cm](Status s) {
{
mutex_lock l(mu_);
cancellation_manager_->DeregisterCallback(token);
}
- partial_run_state->executor_done.Notify();
- // TODO(suharshs): Propagate the status once we keep state for
- // each partial run call.
+ MaybeCallFinalCallback(graph_handle, step_id, s);
+ delete cm;
});
} else {
// Send the partial run's new inputs.
@@ -267,33 +309,32 @@ void Worker::DoPartialRunGraph(CallOptions* opts,
}
}
- // Receive the partial run's outputs.
- s = env_->graph_mgr->RecvOutputs(step_id, out);
- if (!s.ok()) {
- finish(s);
- return;
- }
-
- // Construct and return the resp.
- for (const auto& p : *out) {
- const string& key = p.first;
- const Tensor& val = p.second;
- response->AddRecv(key, val);
- }
-
- // If this is the last partial run request we must also wait for the entire
- // graph execution to be completed.
- if (request->is_last_partial_run()) {
- partial_run_state->executor_done.WaitForNotification();
- RemovePartialRun(graph_handle, step_id);
- // Before deleting the cancellation manager on the final call, ensure
- // that we clear the RPC cancel callback, which has a reference to the
- // cancellation manager.
- opts->ClearCancelCallback();
- delete cm;
- }
-
- finish(s);
+ env_->graph_mgr->RecvOutputsAsync(
+ step_id, out,
+ [this, out, request, response, graph_handle, step_id,
+ finish](Status s) {
+ if (s.ok()) {
+ // Construct and return the resp.
+ for (const auto& p : *out) {
+ const string& key = p.first;
+ const Tensor& val = p.second;
+ response->AddRecv(key, val);
+ }
+ }
+ if (request->is_last_partial_run()) {
+ SetOrCallFinalCallback(
+ graph_handle, step_id,
+ [this, graph_handle, step_id, finish](const Status& s) {
+ finish(s);
+ // We must wait to remove the partial_run_state until both the
+ // executor and the RecvAsync are complete.
+ RemovePartialRun(graph_handle, step_id);
+ },
+ s);
+ } else {
+ finish(s);
+ }
+ });
}
void Worker::CleanupGraphAsync(const CleanupGraphRequest* request,
diff --git a/tensorflow/core/distributed_runtime/worker.h b/tensorflow/core/distributed_runtime/worker.h
index b52a809a0e..6d1c8e3b00 100644
--- a/tensorflow/core/distributed_runtime/worker.h
+++ b/tensorflow/core/distributed_runtime/worker.h
@@ -92,7 +92,10 @@ class Worker : public WorkerInterface {
struct PartialRunState {
CancellationManager* cancellation_manager;
- Notification executor_done;
+
+ bool executor_done = false;
+ StatusCallback final_callback = nullptr;
+ Status final_status;
explicit PartialRunState(CancellationManager* cm)
: cancellation_manager(cm) {}
@@ -115,6 +118,12 @@ class Worker : public WorkerInterface {
void RemovePartialRun(const string& graph_handle, int step_id);
+ void MaybeCallFinalCallback(const string& graph_handle, int step_id,
+ const Status& executor_status);
+
+ void SetOrCallFinalCallback(const string& graph_handle, int step_id,
+ StatusCallback done, const Status& s);
+
Status PrepareRunGraph(RunGraphRequestWrapper* req,
GraphMgr::NamedTensors* in,
GraphMgr::NamedTensors* out);