diff options
author | 2017-03-15 16:20:51 -0800
---|---
committer | 2017-03-15 17:44:33 -0700
commit | 7b03171d6e50216fc7fdff9a2502a6af660291dd (patch)
tree | d7ff375e9d84c02ddc29de18f4c9affcb2c37966
parent | 90d964f3382faf30d291aa3fbdb509844e1f042a (diff)
Ensure that partial run doesn't block any threads on the worker compute_pool.
Change: 150265300
-rw-r--r-- | tensorflow/core/distributed_runtime/graph_mgr.cc | 90
-rw-r--r-- | tensorflow/core/distributed_runtime/graph_mgr.h | 4
-rw-r--r-- | tensorflow/core/distributed_runtime/worker.cc | 111
-rw-r--r-- | tensorflow/core/distributed_runtime/worker.h | 11
4 files changed, 172 insertions(+), 44 deletions(-)
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc index 8afbd18a82..f07d7fcc14 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -326,6 +326,69 @@ Status GraphMgr::RecvOutputsFromRendezvous(Rendezvous* rendezvous, return Status::OK(); } +void GraphMgr::RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous, + NamedTensors* out, + const StatusCallback& done) { + if (out->empty()) { + done(Status::OK()); + return; + } + // We compute the args before calling RecvAsync because we need to ensure that + // out isn't being iterated over after done is called, since done deletes out. + std::vector<std::tuple<string, Tensor*, Rendezvous::ParsedKey>> args; + for (auto& p : *out) { + Rendezvous::ParsedKey parsed; + Status s = Rendezvous::ParseKey(p.first, &parsed); + if (!s.ok()) { + done(s); + return; + } + args.push_back(std::make_tuple(p.first, &p.second, parsed)); + } + + typedef struct { + mutex mu; + int done_counter; + Status shared_status = Status::OK(); + } CallState; + CallState* call_state = new CallState; + call_state->done_counter = out->size(); + for (auto& p : args) { + const string& key = std::get<0>(p); + Tensor* val = std::get<1>(p); + Rendezvous::ParsedKey parsed = std::get<2>(p); + rendezvous->RecvAsync( + parsed, Rendezvous::Args(), + [val, done, key, call_state](const Status& s, + const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, + const Tensor& v, const bool is_dead) { + Status status = s; + if (status.ok()) { + *val = v; + if (is_dead) { + status = errors::InvalidArgument("The tensor returned for ", key, + " was not valid."); + } + } + call_state->mu.lock(); + if (status.ok()) { + call_state->shared_status = status; + } + call_state->done_counter--; + // If we are the last async call to return, call the done callback. 
+ if (call_state->done_counter == 0) { + const Status& final_status = call_state->shared_status; + call_state->mu.unlock(); + done(final_status); + delete call_state; + return; + } + call_state->mu.unlock(); + }); + } +} + Status GraphMgr::SendInputs(const int64 step_id, const NamedTensors& in) { Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id); Status s = SendInputsToRendezvous(rendezvous, in); @@ -340,6 +403,16 @@ Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) { return s; } +void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out, + StatusCallback done) { + Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id); + RecvOutputsFromRendezvousAsync(rendezvous, out, + [done, rendezvous](const Status s) { + rendezvous->Unref(); + done(s); + }); +} + void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id, const ExecutorOpts& opts, StepStatsCollector* collector, @@ -395,13 +468,14 @@ void GraphMgr::StartParallelExecutors(const string& handle, int64 step_id, worker_env_->device_mgr->ClearContainers({name}); }); // NOTE: Transfer one ref of rendezvous and item. 
- ExecutorBarrier* barrier = new ExecutorBarrier( - num_units, rendezvous, [this, item, collector, cost_graph, step_container, - done](const Status& s) { - BuildCostModel(item, collector, cost_graph); - done(s); - delete step_container; - }); + ExecutorBarrier* barrier = + new ExecutorBarrier(num_units, rendezvous, + [this, item, collector, cost_graph, step_container, + done](const Status& s) { + BuildCostModel(item, collector, cost_graph); + done(s); + delete step_container; + }); Executor::Args args; { mutex_lock l(mu_); @@ -416,7 +490,7 @@ void GraphMgr::StartParallelExecutors(const string& handle, int64 step_id, LogMemory::RecordStep(args.step_id, handle); } thread::ThreadPool* pool = worker_env_->compute_pool; - using namespace std::placeholders; + using std::placeholders::_1; // Line below is equivalent to this code, but does one less indirect call: // args.runner = [pool](std::function<void()> fn) { pool->Schedule(fn); }; args.runner = std::bind(&thread::ThreadPool::Schedule, pool, _1); diff --git a/tensorflow/core/distributed_runtime/graph_mgr.h b/tensorflow/core/distributed_runtime/graph_mgr.h index e9b8d415ed..18013aa91e 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.h +++ b/tensorflow/core/distributed_runtime/graph_mgr.h @@ -81,6 +81,8 @@ class GraphMgr { Status SendInputs(const int64 step_id, const NamedTensors& in); Status RecvOutputs(const int64 step_id, NamedTensors* out); + void RecvOutputsAsync(const int64 step_id, NamedTensors* out, + StatusCallback done); // Deregisters a graph. 
Status Deregister(const string& handle); @@ -156,6 +158,8 @@ class GraphMgr { Status SendInputsToRendezvous(Rendezvous* rendezvous, const NamedTensors& in); Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out); + void RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous, NamedTensors* out, + const StatusCallback& done); Status InitItem(const string& session, const GraphDef& gdef, const GraphOptions& graph_options, Item* item); diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc index 0d6ccceef0..654bd4d93f 100644 --- a/tensorflow/core/distributed_runtime/worker.cc +++ b/tensorflow/core/distributed_runtime/worker.cc @@ -57,7 +57,7 @@ void Worker::DeregisterGraphAsync(const DeregisterGraphRequest* request, Worker::PartialRunState* Worker::FindPartialRun(const string& graph_handle, int step_id) { - std::pair<string, int> k(graph_handle, step_id); + const std::pair<string, int> k(graph_handle, step_id); Worker::PartialRunState* prun_state = nullptr; mutex_lock l(mu_); auto it = partial_runs_.find(k); @@ -70,17 +70,59 @@ Worker::PartialRunState* Worker::FindPartialRun(const string& graph_handle, void Worker::InsertPartialRunLocked(const string& graph_handle, int step_id, Worker::PartialRunState* partial_run_state) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - std::pair<string, int> k(graph_handle, step_id); + const std::pair<string, int> k(graph_handle, step_id); partial_runs_.emplace(std::make_pair( k, std::unique_ptr<Worker::PartialRunState>(partial_run_state))); } void Worker::RemovePartialRun(const string& graph_handle, int step_id) { - std::pair<string, int> k(graph_handle, step_id); + const std::pair<string, int> k(graph_handle, step_id); mutex_lock l(mu_); partial_runs_.erase(partial_runs_.find(k)); } +void Worker::MaybeCallFinalCallback(const string& graph_handle, int step_id, + const Status& executor_status) { + const std::pair<string, int> k(graph_handle, step_id); + StatusCallback done; + 
Status s; + { + mutex_lock l(mu_); + auto it = partial_runs_.find(k); + if (it != partial_runs_.end()) { + // If we found the partial_run, we call the final callback, if it + // exists. + std::swap(done, it->second->final_callback); + s = it->second->final_status; + it->second->executor_done = true; + } + } + if (done != nullptr) { + if (s.ok()) { + s = executor_status; + } + done(s); + } +} + +void Worker::SetOrCallFinalCallback(const string& graph_handle, int step_id, + StatusCallback done, const Status& s) { + const std::pair<string, int> k(graph_handle, step_id); + { + mutex_lock l(mu_); + auto it = partial_runs_.find(k); + if (!it->second->executor_done) { + // If we found the partial_run, we set the final callback to call only + // when the executor is completely done. + it->second->final_callback = std::move(done); + it->second->final_status = s; + return; + } + } + // Otherwise we call the callback immediately. + done(s); +} + void Worker::AbortStep(int64 step_id) { Rendezvous* rendez = env_->rendezvous_mgr->Find(step_id); SchedNonBlockingClosureAfter(1000000, [rendez, step_id]() { @@ -205,7 +247,8 @@ void Worker::DoPartialRunGraph(CallOptions* opts, GraphMgr::NamedTensors in; GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors; Status s = PrepareRunGraph(request, &in, out); - auto finish = [this, done, out](const Status& s) { + auto finish = [this, done, out, opts](const Status& s) { + opts->ClearCancelCallback(); delete out; done(s); }; @@ -249,14 +292,13 @@ void Worker::DoPartialRunGraph(CallOptions* opts, env_->graph_mgr->ExecuteAsync( graph_handle, step_id, request->exec_opts(), nullptr /* collector */, nullptr /* cost_graph */, cm, in, - [this, step_id, graph_handle, token, partial_run_state](Status s) { + [this, token, graph_handle, step_id, cm](Status s) { { mutex_lock l(mu_); cancellation_manager_->DeregisterCallback(token); } - partial_run_state->executor_done.Notify(); - // TODO(suharshs): Propagate the status once we keep state for - // 
each partial run call. + MaybeCallFinalCallback(graph_handle, step_id, s); + delete cm; }); } else { // Send the partial run's new inputs. @@ -267,33 +309,32 @@ void Worker::DoPartialRunGraph(CallOptions* opts, } } - // Receive the partial run's outputs. - s = env_->graph_mgr->RecvOutputs(step_id, out); - if (!s.ok()) { - finish(s); - return; - } - - // Construct and return the resp. - for (const auto& p : *out) { - const string& key = p.first; - const Tensor& val = p.second; - response->AddRecv(key, val); - } - - // If this is the last partial run request we must also wait for the entire - // graph execution to be completed. - if (request->is_last_partial_run()) { - partial_run_state->executor_done.WaitForNotification(); - RemovePartialRun(graph_handle, step_id); - // Before deleting the cancellation manager on the final call, ensure - // that we clear the RPC cancel callback, which has a reference to the - // cancellation manager. - opts->ClearCancelCallback(); - delete cm; - } - - finish(s); + env_->graph_mgr->RecvOutputsAsync( + step_id, out, + [this, out, request, response, graph_handle, step_id, + finish](Status s) { + if (s.ok()) { + // Construct and return the resp. + for (const auto& p : *out) { + const string& key = p.first; + const Tensor& val = p.second; + response->AddRecv(key, val); + } + } + if (request->is_last_partial_run()) { + SetOrCallFinalCallback( + graph_handle, step_id, + [this, graph_handle, step_id, finish](const Status& s) { + finish(s); + // We must wait to remove the partial_run_state until both the + // executor and the RecvAsync are complete. 
+ RemovePartialRun(graph_handle, step_id); + }, + s); + } else { + finish(s); + } + }); } void Worker::CleanupGraphAsync(const CleanupGraphRequest* request, diff --git a/tensorflow/core/distributed_runtime/worker.h b/tensorflow/core/distributed_runtime/worker.h index b52a809a0e..6d1c8e3b00 100644 --- a/tensorflow/core/distributed_runtime/worker.h +++ b/tensorflow/core/distributed_runtime/worker.h @@ -92,7 +92,10 @@ class Worker : public WorkerInterface { struct PartialRunState { CancellationManager* cancellation_manager; - Notification executor_done; + + bool executor_done = false; + StatusCallback final_callback = nullptr; + Status final_status; explicit PartialRunState(CancellationManager* cm) : cancellation_manager(cm) {} @@ -115,6 +118,12 @@ class Worker : public WorkerInterface { void RemovePartialRun(const string& graph_handle, int step_id); + void MaybeCallFinalCallback(const string& graph_handle, int step_id, + const Status& executor_status); + + void SetOrCallFinalCallback(const string& graph_handle, int step_id, + StatusCallback done, const Status& s); + Status PrepareRunGraph(RunGraphRequestWrapper* req, GraphMgr::NamedTensors* in, GraphMgr::NamedTensors* out); |