diff options
author | 2017-05-19 01:41:18 -0700 | |
---|---|---|
committer | 2017-05-19 01:44:43 -0700 | |
commit | 92d13e7c8c88a092645d5b471ea9deffde147077 (patch) | |
tree | c6bddc5880e1102d7f9c9d6e9342b0f767036132 /tensorflow | |
parent | 21185383dfe0047522740045d063f126fb3aa74f (diff) |
Refactor partial run state handling into partial_run_mgr.
PiperOrigin-RevId: 156529141
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/core/distributed_runtime/BUILD | 23 | ||||
-rw-r--r-- | tensorflow/core/distributed_runtime/partial_run_mgr.cc | 96 | ||||
-rw-r--r-- | tensorflow/core/distributed_runtime/partial_run_mgr.h | 87 | ||||
-rw-r--r-- | tensorflow/core/distributed_runtime/partial_run_mgr_test.cc | 151 | ||||
-rw-r--r-- | tensorflow/core/distributed_runtime/worker.cc | 101 | ||||
-rw-r--r-- | tensorflow/core/distributed_runtime/worker.h | 37 |
6 files changed, 366 insertions, 129 deletions
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD index d2a828f39f..e8aabf72dc 100644 --- a/tensorflow/core/distributed_runtime/BUILD +++ b/tensorflow/core/distributed_runtime/BUILD @@ -45,6 +45,28 @@ package(default_visibility = [ ]) cc_library( + name = "partial_run_mgr", + srcs = ["partial_run_mgr.cc"], + hdrs = ["partial_run_mgr.h"], + deps = [ + ":worker_interface", + "//tensorflow/core:framework", + ], +) + +cc_test( + name = "partial_run_mgr_test", + size = "small", + srcs = ["partial_run_mgr_test.cc"], + deps = [ + ":partial_run_mgr", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( name = "message_wrappers", srcs = ["message_wrappers.cc"], hdrs = ["message_wrappers.h"], @@ -141,6 +163,7 @@ cc_library( ], deps = [ ":graph_mgr", + ":partial_run_mgr", ":rendezvous_mgr_interface", ":session_mgr", ":worker_interface", diff --git a/tensorflow/core/distributed_runtime/partial_run_mgr.cc b/tensorflow/core/distributed_runtime/partial_run_mgr.cc new file mode 100644 index 0000000000..c0dbabf9a2 --- /dev/null +++ b/tensorflow/core/distributed_runtime/partial_run_mgr.cc @@ -0,0 +1,96 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/distributed_runtime/partial_run_mgr.h" + +namespace tensorflow { + +namespace { +// TODO(suharshs): Move this to a common location to allow other parts of the +// repo to use it. +template <typename T, typename... Args> +std::unique_ptr<T> MakeUnique(Args&&... args) { + return std::unique_ptr<T>(new T(std::forward<Args>(args)...)); +} +} // namespace + +bool PartialRunMgr::FindOrCreate(int step_id, + CancellationManager** cancellation_manager) { + mutex_lock l(mu_); + auto it = step_id_to_partial_run_.find(step_id); + if (it != step_id_to_partial_run_.end()) { + *cancellation_manager = it->second->cancellation_manager.get(); + return false; + } + + std::unique_ptr<PartialRunState> partial_run = MakeUnique<PartialRunState>(); + partial_run->cancellation_manager = MakeUnique<CancellationManager>(); + *cancellation_manager = partial_run->cancellation_manager.get(); + step_id_to_partial_run_[step_id] = std::move(partial_run); + return true; +} + +void PartialRunMgr::ExecutorDone(int step_id, const Status& executor_status) { + StatusCallback done; + Status callback_status; + { + mutex_lock l(mu_); + auto run_it = step_id_to_partial_run_.find(step_id); + if (run_it == step_id_to_partial_run_.end()) { + return; + } + // If we found the partial_run, we call the final callback, if it + // exists. + // It is guaranteed that run_it->second->final_callback is left empty + // after the std::move call. 
+ done = std::move(run_it->second->final_callback); + if (!executor_status.ok()) { + run_it->second->final_status = executor_status; + } + callback_status = run_it->second->final_status; + run_it->second->executor_done = true; + } + if (done != nullptr) { + done(callback_status); + mutex_lock l(mu_); + step_id_to_partial_run_.erase(step_id); + } +} + +void PartialRunMgr::PartialRunDone(int step_id, StatusCallback done, + const Status& status) { + Status callback_status; + { + mutex_lock l(mu_); + auto run_it = step_id_to_partial_run_.find(step_id); + if (run_it == step_id_to_partial_run_.end()) { + return; + } + run_it->second->final_status.Update(status); + if (!run_it->second->executor_done) { + // If we found the partial_run, we set the final callback to call only + // when the executor is completely done. + run_it->second->final_callback = std::move(done); + return; + } + callback_status = run_it->second->final_status; + } + // Otherwise we call the callback immediately. + done(callback_status); + mutex_lock l(mu_); + step_id_to_partial_run_.erase(step_id); +} + +} // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/partial_run_mgr.h b/tensorflow/core/distributed_runtime/partial_run_mgr.h new file mode 100644 index 0000000000..af56e723a9 --- /dev/null +++ b/tensorflow/core/distributed_runtime/partial_run_mgr.h @@ -0,0 +1,87 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PARTIAL_RUN_MGR_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PARTIAL_RUN_MGR_H_ + +#include <unordered_map> + +#include "tensorflow/core/distributed_runtime/worker_interface.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// PartialRunMgr keeps track of pending partial run requests, and ensures that +// the partial run is only marked complete when the corresponding executor is +// run to completion. +// +// In tensorflow workers, the executor runs operations asynchronously until +// specified fetches (operations that return tensors) or targets (operations +// that don't return tensors) are reached. A PartialRun has two components: a +// setup which specifies all desired fetches and targets, and run calls that +// specify fetch values (from the setup calls) to retrieve. +// On the last partial run call, it is possible to satisfy the +// required fetches before the executor has completed running the graph to all +// the desired targets. +// PartialRunMgr is used to ensure that we don't complete and return the final +// partial run call to the user until both the partial run and executor have +// completed. +// +// PartialRunMgr is thread-safe. +class PartialRunMgr { + public: + // Find or create the CancellationManager associated with step_id. + // The PartialRunMgr owns the cancellation_manager. + // Returns true if a new CancellationManager was created + // (i.e., this is a new partial run). + bool FindOrCreate(int step_id, CancellationManager** cancellation_manager); + + // Calls the final callback if the PartialRunRequest has already completed. 
+ // Otherwise stores the executor_status to be propagated when the + // PartialRunRequest completes (PartialRunDone has been called). + void ExecutorDone(int step_id, const Status& executor_status); + + // Calls done if the executor has already completed (ExecutorDone has been + // called). Otherwise, stores the status and done callback, calling them when + // ExecutorDone is called. The callback will be called by the calling + // thread of either PartialRunDone or ExecutorDone. + // If executor_status in ExecutorDone is not OK, it takes precedence over + // status and is passed to the done callback. + void PartialRunDone(int step_id, StatusCallback done, const Status& status); + + private: + // PartialRunState stores state associated with a pending partial run request. + // This is protected by the mutex in PartialRunMgr. + struct PartialRunState { + std::unique_ptr<CancellationManager> cancellation_manager; + + bool executor_done = false; + StatusCallback final_callback = nullptr; + Status final_status; + }; + + mutex mu_; + + std::unordered_map<int, std::unique_ptr<PartialRunState>> + step_id_to_partial_run_ GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PARTIAL_RUN_MGR_H_ diff --git a/tensorflow/core/distributed_runtime/partial_run_mgr_test.cc b/tensorflow/core/distributed_runtime/partial_run_mgr_test.cc new file mode 100644 index 0000000000..5f7c0cb3ca --- /dev/null +++ b/tensorflow/core/distributed_runtime/partial_run_mgr_test.cc @@ -0,0 +1,151 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/distributed_runtime/partial_run_mgr.h" + +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +TEST(PartialRunMgrFindOrCreate, Create) { + // Basic test of PartialRunMgr CancellationManager creation. + PartialRunMgr partial_run_mgr; + int step_id = 1; + CancellationManager* cancellation_manager; + partial_run_mgr.FindOrCreate(step_id, &cancellation_manager); + EXPECT_TRUE(cancellation_manager != nullptr); +} + +TEST(PartialRunMgrFindOrCreate, Find) { + // Basic test of PartialRunMgr CancellationManager find. + PartialRunMgr partial_run_mgr; + int step_id = 1; + CancellationManager* cancellation_manager; + partial_run_mgr.FindOrCreate(step_id, &cancellation_manager); + // Looking for the same step should return the same cancellation_manager. + CancellationManager* found_cancellation_manager; + partial_run_mgr.FindOrCreate(step_id, &found_cancellation_manager); + EXPECT_EQ(cancellation_manager, found_cancellation_manager); +} + +TEST(PartialRunMgrFindOrCreate, NewCreate) { + // Test that PartialRunMgr creates a new CancellationManager for new steps. + PartialRunMgr partial_run_mgr; + int step_id = 1; + CancellationManager* cancellation_manager; + partial_run_mgr.FindOrCreate(step_id, &cancellation_manager); + // FindOrCreate on a new step should return a new cancellation_manager. 
+ int new_step_id = 2; + CancellationManager* new_cancellation_manager; + partial_run_mgr.FindOrCreate(new_step_id, &new_cancellation_manager); + EXPECT_NE(cancellation_manager, new_cancellation_manager); +} + +TEST(PartialRunMgr, PartialRunRemoved) { + // Test that PartialRunMgr ensures that the PartialRun is deleted after + // ExecutorDone and PartialRunDone are called. + PartialRunMgr partial_run_mgr; + int step_id = 1; + CancellationManager* cancellation_manager; + partial_run_mgr.FindOrCreate(step_id, &cancellation_manager); + + int called = 0; + partial_run_mgr.PartialRunDone( + step_id, [&called](Status status) { called++; }, Status::OK()); + partial_run_mgr.ExecutorDone(step_id, Status::OK()); + + // Calling ExecutorDone and PartialRunDone on the step_id should still only + // result in the callback being called once. + // This proves that the original PartialRun has been removed. + partial_run_mgr.PartialRunDone( + step_id, [&called](Status status) { called++; }, Status::OK()); + partial_run_mgr.ExecutorDone(step_id, Status::OK()); + EXPECT_EQ(1, called); +} + +struct StatusTestParam { + Status executor_status; + Status partial_run_status; + Status expected_status; +}; + +class StatusPropagationTest : public ::testing::TestWithParam<StatusTestParam> { + protected: + PartialRunMgr partial_run_mgr_; + + // State to help keep track of when the callback is called. + Notification invoked_; + Status status_; + + void set_status(const Status& status) { + status_ = status; + invoked_.Notify(); + } + + // Blocks until status is set. + Status status() { + invoked_.WaitForNotification(); + return status_; + } +}; + +TEST_P(StatusPropagationTest, ExecutorDoneFirst) { + // Tests error propagation when ExecutorDone is called first. 
+ StatusTestParam param = GetParam(); + int step_id = 1; + + CancellationManager* cancellation_manager; + partial_run_mgr_.FindOrCreate(step_id, &cancellation_manager); + + partial_run_mgr_.ExecutorDone(step_id, param.executor_status); + partial_run_mgr_.PartialRunDone(step_id, + [this](Status status) { set_status(status); }, + param.partial_run_status); + + EXPECT_EQ(status(), param.expected_status); +} + +TEST_P(StatusPropagationTest, PartialRunDoneFirst) { + // Tests error propagation when PartialRunDone is called first. + StatusTestParam param = GetParam(); + int step_id = 1; + + CancellationManager* cancellation_manager; + partial_run_mgr_.FindOrCreate(step_id, &cancellation_manager); + + partial_run_mgr_.PartialRunDone(step_id, + [this](Status status) { set_status(status); }, + param.partial_run_status); + partial_run_mgr_.ExecutorDone(step_id, param.executor_status); + + EXPECT_EQ(status(), param.expected_status); +} + +// Instantiate tests for all error orderings, for both call orders of +// ExecutorDone and PartialRunDone. 
+Status ExecutorError() { return errors::Internal("executor error"); } +Status PartialRunError() { return errors::Internal("partial run error"); } +INSTANTIATE_TEST_CASE_P( + PartialRunMgr, StatusPropagationTest, + ::testing::Values( + StatusTestParam{Status::OK(), Status::OK(), Status::OK()}, + StatusTestParam{ExecutorError(), Status::OK(), ExecutorError()}, + StatusTestParam{Status::OK(), PartialRunError(), PartialRunError()}, + StatusTestParam{ExecutorError(), PartialRunError(), ExecutorError()})); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc index 07bb17981d..32ea0cfaa4 100644 --- a/tensorflow/core/distributed_runtime/worker.cc +++ b/tensorflow/core/distributed_runtime/worker.cc @@ -69,72 +69,6 @@ void Worker::DeregisterGraphAsync(const DeregisterGraphRequest* request, done(s); } -Worker::PartialRunState* Worker::FindPartialRun(const string& graph_handle, - int step_id) { - const std::pair<string, int> k(graph_handle, step_id); - Worker::PartialRunState* prun_state = nullptr; - mutex_lock l(mu_); - auto it = partial_runs_.find(k); - if (it != partial_runs_.end()) { - prun_state = it->second.get(); - } - return prun_state; -} - -void Worker::InsertPartialRunLocked(const string& graph_handle, int step_id, - Worker::PartialRunState* partial_run_state) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - const std::pair<string, int> k(graph_handle, step_id); - partial_runs_.emplace(std::make_pair( - k, std::unique_ptr<Worker::PartialRunState>(partial_run_state))); -} - -void Worker::RemovePartialRun(const string& graph_handle, int step_id) { - const std::pair<string, int> k(graph_handle, step_id); - mutex_lock l(mu_); - partial_runs_.erase(partial_runs_.find(k)); -} - -void Worker::MaybeCallFinalCallback(const string& graph_handle, int step_id, - const Status& executor_status) { - const std::pair<string, int> k(graph_handle, step_id); - StatusCallback done; - Status s; - { - 
mutex_lock l(mu_); - auto it = partial_runs_.find(k); - if (it != partial_runs_.end()) { - // If we found the partial_run, we call the final callback, if it - // exists. - std::swap(done, it->second->final_callback); - s = it->second->final_status; - it->second->executor_done = true; - } - } - if (done != nullptr) { - s.Update(executor_status); - done(s); - } -} - -void Worker::SetOrCallFinalCallback(const string& graph_handle, int step_id, - StatusCallback done, const Status& s) { - const std::pair<string, int> k(graph_handle, step_id); - { - mutex_lock l(mu_); - auto it = partial_runs_.find(k); - if (!it->second->executor_done) { - // If we found the partial_run, we set the final callback to call only - // when the executor is completely done. - it->second->final_callback = std::move(done); - it->second->final_status = s; - return; - } - } - // Otherwise we call the callback immediately. - done(s); -} - void Worker::AbortStep(int64 step_id) { Rendezvous* rendez = env_->rendezvous_mgr->Find(step_id); SchedNonBlockingClosureAfter(1000000, [rendez, step_id]() { @@ -275,18 +209,8 @@ void Worker::DoPartialRunGraph(CallOptions* opts, return; } - PartialRunState* partial_run_state = FindPartialRun(graph_handle, step_id); - CancellationManager* cm = nullptr; - // If this is a new partial run call we need to create a new cancellation - // manager. - // Otherwise we use the cancellation manager stored in the found partial - // run state. - if (partial_run_state == nullptr) { - cm = new CancellationManager; - } else { - cm = partial_run_state->cancellation_manager; - } + bool is_new_partial_run = partial_run_mgr_.FindOrCreate(step_id, &cm); // Before we start doing anything, we set the RPC cancellation. opts->SetCancelCallback([this, cm, step_id]() { @@ -296,13 +220,10 @@ void Worker::DoPartialRunGraph(CallOptions* opts, // If this is a new partial run request, the request will need to start the // executors. 
- if (partial_run_state == nullptr) { + if (is_new_partial_run) { CancellationToken token; { mutex_lock l(mu_); - // Insert the new partial run into the partial_runs_ map. - partial_run_state = new PartialRunState(cm); - InsertPartialRunLocked(graph_handle, step_id, partial_run_state); token = cancellation_manager_->get_cancellation_token(); cancellation_manager_->RegisterCallback(token, [cm]() { cm->StartCancel(); }); @@ -310,13 +231,12 @@ void Worker::DoPartialRunGraph(CallOptions* opts, session->graph_mgr->ExecuteAsync( graph_handle, step_id, session, request->exec_opts(), nullptr /* collector */, nullptr /* cost_graph */, cm, in, - [this, token, graph_handle, step_id, cm](Status s) { + [this, token, step_id, cm](Status s) { { mutex_lock l(mu_); cancellation_manager_->DeregisterCallback(token); } - MaybeCallFinalCallback(graph_handle, step_id, s); - delete cm; + partial_run_mgr_.ExecutorDone(step_id, s); }); } else { // Send the partial run's new inputs. @@ -328,8 +248,7 @@ void Worker::DoPartialRunGraph(CallOptions* opts, } session->graph_mgr->RecvOutputsAsync( - step_id, out, - [this, out, request, response, graph_handle, step_id, finish](Status s) { + step_id, out, [this, out, request, response, step_id, finish](Status s) { if (s.ok()) { // Construct and return the resp. for (const auto& p : *out) { @@ -339,15 +258,7 @@ void Worker::DoPartialRunGraph(CallOptions* opts, } } if (request->is_last_partial_run()) { - SetOrCallFinalCallback( - graph_handle, step_id, - [this, graph_handle, step_id, finish](const Status& s) { - finish(s); - // We must wait to remove the partial_run_state until both the - // executor and the RecvAsync are complete. 
- RemovePartialRun(graph_handle, step_id); - }, - s); + partial_run_mgr_.PartialRunDone(step_id, std::move(finish), s); } else { finish(s); } diff --git a/tensorflow/core/distributed_runtime/worker.h b/tensorflow/core/distributed_runtime/worker.h index 290fc6de95..07300338c3 100644 --- a/tensorflow/core/distributed_runtime/worker.h +++ b/tensorflow/core/distributed_runtime/worker.h @@ -19,6 +19,7 @@ limitations under the License. #include <unordered_map> #include "tensorflow/core/distributed_runtime/graph_mgr.h" +#include "tensorflow/core/distributed_runtime/partial_run_mgr.h" #include "tensorflow/core/distributed_runtime/session_mgr.h" #include "tensorflow/core/distributed_runtime/worker_interface.h" @@ -93,43 +94,11 @@ class Worker : public WorkerInterface { void AbortStep(int64); private: + PartialRunMgr partial_run_mgr_; + mutex mu_; CancellationManager* cancellation_manager_ GUARDED_BY(mu_); - struct PartialRunState { - CancellationManager* cancellation_manager; - - bool executor_done = false; - StatusCallback final_callback = nullptr; - Status final_status; - - explicit PartialRunState(CancellationManager* cm) - : cancellation_manager(cm) {} - }; - struct PairHash { - std::size_t operator()(std::pair<string, int> const& p) const { - return Hash64Combine(std::hash<string>()(p.first), - std::hash<int>()(p.second)); - } - }; - std::unordered_map<std::pair<string, int>, std::unique_ptr<PartialRunState>, - PairHash> - partial_runs_ GUARDED_BY(mu_); - - PartialRunState* FindPartialRun(const string& graph_handle, int step_id); - - void InsertPartialRunLocked(const string& graph_handle, int step_id, - PartialRunState* partial_run_state) - EXCLUSIVE_LOCKS_REQUIRED(mu_); - - void RemovePartialRun(const string& graph_handle, int step_id); - - void MaybeCallFinalCallback(const string& graph_handle, int step_id, - const Status& executor_status); - - void SetOrCallFinalCallback(const string& graph_handle, int step_id, - StatusCallback done, const Status& s); - Status 
PrepareRunGraph(RunGraphRequestWrapper* req, GraphMgr::NamedTensors* in, GraphMgr::NamedTensors* out); |