diff options
author | 2017-05-19 01:41:18 -0700 | |
---|---|---|
committer | 2017-05-19 01:44:43 -0700 | |
commit | 92d13e7c8c88a092645d5b471ea9deffde147077 (patch) | |
tree | c6bddc5880e1102d7f9c9d6e9342b0f767036132 /tensorflow/core/distributed_runtime/partial_run_mgr.cc | |
parent | 21185383dfe0047522740045d063f126fb3aa74f (diff) |
Refactor partial run state handling into partial_run_mgr.
PiperOrigin-RevId: 156529141
Diffstat (limited to 'tensorflow/core/distributed_runtime/partial_run_mgr.cc')
-rw-r--r-- | tensorflow/core/distributed_runtime/partial_run_mgr.cc | 96 |
1 files changed, 96 insertions, 0 deletions
diff --git a/tensorflow/core/distributed_runtime/partial_run_mgr.cc b/tensorflow/core/distributed_runtime/partial_run_mgr.cc new file mode 100644 index 0000000000..c0dbabf9a2 --- /dev/null +++ b/tensorflow/core/distributed_runtime/partial_run_mgr.cc @@ -0,0 +1,96 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/distributed_runtime/partial_run_mgr.h" + +namespace tensorflow { + +namespace { +// TODO(suharshs): Move this to a common location to allow other part of the +// repo to use it. +template <typename T, typename... Args> +std::unique_ptr<T> MakeUnique(Args&&... args) { + return std::unique_ptr<T>(new T(std::forward<Args>(args)...)); +} +} // namespace + +bool PartialRunMgr::FindOrCreate(int step_id, + CancellationManager** cancellation_manager) { + mutex_lock l(mu_); + auto it = step_id_to_partial_run_.find(step_id); + if (it != step_id_to_partial_run_.end()) { + *cancellation_manager = it->second->cancellation_manager.get(); + return false; + } + + std::unique_ptr<PartialRunState> partial_run = MakeUnique<PartialRunState>(); + partial_run->cancellation_manager = MakeUnique<CancellationManager>(); + *cancellation_manager = partial_run->cancellation_manager.get(); + step_id_to_partial_run_[step_id] = std::move(partial_run); + return true; +} + +void PartialRunMgr::ExecutorDone(int step_id, const Status& executor_status) { + StatusCallback done; + Status callback_status; + { + mutex_lock l(mu_); + auto run_it = step_id_to_partial_run_.find(step_id); + if (run_it == step_id_to_partial_run_.end()) { + return; + } + // If we found the partial_run, we call the final callback, if it + // exists. + // It is guaranteed that run_it->second->final_callback is left empty + // after the std::move call. + done = std::move(run_it->second->final_callback); + if (!executor_status.ok()) { + run_it->second->final_status = executor_status; + } + callback_status = run_it->second->final_status; + run_it->second->executor_done = true; + } + if (done != nullptr) { + done(callback_status); + mutex_lock l(mu_); + step_id_to_partial_run_.erase(step_id); + } +} + +void PartialRunMgr::PartialRunDone(int step_id, StatusCallback done, + const Status& status) { + Status callback_status; + { + mutex_lock l(mu_); + auto run_it = step_id_to_partial_run_.find(step_id); + if (run_it == step_id_to_partial_run_.end()) { + return; + } + run_it->second->final_status.Update(status); + if (!run_it->second->executor_done) { + // If we found the partial_run, we set the final callback to call only + // when the executor is completely done. + run_it->second->final_callback = std::move(done); + return; + } + callback_status = run_it->second->final_status; + } + // Otherwise we call the callback immediately. + done(callback_status); + mutex_lock l(mu_); + step_id_to_partial_run_.erase(step_id); +} + +} // namespace tensorflow |