about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/distributed_runtime/partial_run_mgr.cc
diff options
context:
space:
mode:
author: Suharsh Sivakumar <suharshs@google.com> 2017-05-19 01:41:18 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-05-19 01:44:43 -0700
commit: 92d13e7c8c88a092645d5b471ea9deffde147077 (patch)
tree: c6bddc5880e1102d7f9c9d6e9342b0f767036132 /tensorflow/core/distributed_runtime/partial_run_mgr.cc
parent: 21185383dfe0047522740045d063f126fb3aa74f (diff)
Refactor partial run state handling into partial_run_mgr.
PiperOrigin-RevId: 156529141
Diffstat (limited to 'tensorflow/core/distributed_runtime/partial_run_mgr.cc')
-rw-r--r-- tensorflow/core/distributed_runtime/partial_run_mgr.cc 96
1 files changed, 96 insertions, 0 deletions
diff --git a/tensorflow/core/distributed_runtime/partial_run_mgr.cc b/tensorflow/core/distributed_runtime/partial_run_mgr.cc
new file mode 100644
index 0000000000..c0dbabf9a2
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/partial_run_mgr.cc
@@ -0,0 +1,96 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/partial_run_mgr.h"
+
+namespace tensorflow {
+
+namespace {
+// Heap-allocates a T, perfectly forwarding `args` to T's constructor, and
+// returns the new object wrapped in a std::unique_ptr<T> that owns it.
+//
+// TODO(suharshs): Move this to a common location to allow other parts of the
+// repo to use it.
+template <typename T, typename... Args>
+std::unique_ptr<T> MakeUnique(Args&&... args) {
+  std::unique_ptr<T> owned(new T(std::forward<Args>(args)...));
+  return owned;
+}
+}  // namespace
+
+// Looks up the partial run registered under `step_id`; when there is none, a
+// new PartialRunState with its own CancellationManager is created and stored.
+//
+// In both cases `*cancellation_manager` is set to a raw (non-owning) pointer
+// to the CancellationManager held by the stored state.
+//
+// Returns true if a new entry was created, false if one already existed.
+// The whole operation runs while holding `mu_`.
+bool PartialRunMgr::FindOrCreate(int step_id,
+                                 CancellationManager** cancellation_manager) {
+  mutex_lock lock(mu_);
+  const auto existing = step_id_to_partial_run_.find(step_id);
+  const bool is_new = (existing == step_id_to_partial_run_.end());
+  if (is_new) {
+    std::unique_ptr<PartialRunState> state = MakeUnique<PartialRunState>();
+    state->cancellation_manager = MakeUnique<CancellationManager>();
+    // Hand out the raw pointer before ownership of the state moves into the
+    // map; `state` must not be dereferenced after the std::move below.
+    *cancellation_manager = state->cancellation_manager.get();
+    step_id_to_partial_run_[step_id] = std::move(state);
+  } else {
+    *cancellation_manager = existing->second->cancellation_manager.get();
+  }
+  return is_new;
+}
+
+// Records that the executor for `step_id` has finished with `executor_status`.
+//
+// If no partial run is registered under `step_id`, this is a no-op.
+// Otherwise, under `mu_`:
+//   * a non-OK `executor_status` overwrites the stored final status,
+//   * the run is flagged as executor-done,
+//   * any final callback previously stored by PartialRunDone is taken out.
+// If such a callback was present, it is invoked with the final status after
+// the lock has been released, and the entry for `step_id` is then erased
+// (re-acquiring `mu_`). If no callback was stored yet, the entry is kept.
+void PartialRunMgr::ExecutorDone(int step_id, const Status& executor_status) {
+  StatusCallback pending_callback;
+  Status status_for_callback;
+  {
+    mutex_lock lock(mu_);
+    const auto entry = step_id_to_partial_run_.find(step_id);
+    if (entry == step_id_to_partial_run_.end()) {
+      return;
+    }
+    auto& state = *entry->second;
+    // Take the final callback, if one exists. Per the original author's note,
+    // state.final_callback is left empty after the std::move call.
+    pending_callback = std::move(state.final_callback);
+    if (!executor_status.ok()) {
+      state.final_status = executor_status;
+    }
+    status_for_callback = state.final_status;
+    state.executor_done = true;
+  }
+  if (pending_callback == nullptr) {
+    // No callback has been stored yet; leave the entry in place.
+    return;
+  }
+  // Run the callback without holding mu_, then drop the finished run.
+  pending_callback(status_for_callback);
+  mutex_lock lock(mu_);
+  step_id_to_partial_run_.erase(step_id);
+}
+
+// Reports that the partial run for `step_id` has finished with `status`, and
+// supplies the callback `done` to be run once the whole step is complete.
+//
+// If no partial run is registered under `step_id`, returns immediately and
+// `done` is not invoked. Otherwise `status` is folded into the stored final
+// status via Status::Update, and then:
+//   * if the executor has not finished yet, `done` is stored as the final
+//     callback (to be invoked from ExecutorDone) and this call returns;
+//   * if the executor has already finished, `done` is invoked right away with
+//     the final status (outside the lock) and the entry is erased.
+void PartialRunMgr::PartialRunDone(int step_id, StatusCallback done,
+                                   const Status& status) {
+  Status status_for_callback;
+  {
+    mutex_lock lock(mu_);
+    const auto entry = step_id_to_partial_run_.find(step_id);
+    if (entry == step_id_to_partial_run_.end()) {
+      return;
+    }
+    auto& state = *entry->second;
+    state.final_status.Update(status);
+    if (state.executor_done) {
+      status_for_callback = state.final_status;
+    } else {
+      // The executor is still running: park the callback so that it is only
+      // called once the executor is completely done.
+      state.final_callback = std::move(done);
+      return;
+    }
+  }
+  // The executor already finished, so the callback runs immediately (without
+  // holding mu_) and the finished run is removed afterwards.
+  done(status_for_callback);
+  mutex_lock lock(mu_);
+  step_id_to_partial_run_.erase(step_id);
+}
+
+} // namespace tensorflow