diff options
author | 2018-05-22 13:49:08 -0700 | |
---|---|---|
committer | 2018-05-22 13:51:22 -0700 | |
commit | 9d2c6ff2a542b9bd89b42e3b88e6299eae9bdcc4 (patch) | |
tree | 301c1026d384945565a96226d66180e1e950b3b3 /tensorflow/core/common_runtime/direct_session.cc | |
parent | 4d134bad0403ebb5722144d8f859a04a5f21efc2 (diff) |
Collective Ops Part 7
Complete just enough of the core implementation to run
multi-device collectives locally within a single process.
Interfaces are still private and not available for general use.
PiperOrigin-RevId: 197617132
Diffstat (limited to 'tensorflow/core/common_runtime/direct_session.cc')
-rw-r--r-- | tensorflow/core/common_runtime/direct_session.cc | 27 |
1 files changed, 26 insertions, 1 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 0afbd02e86..07c1eafedc 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -19,15 +19,19 @@ limitations under the License. #include <string> #include <vector> +#include "tensorflow/core/common_runtime/collective_executor_mgr.h" +#include "tensorflow/core/common_runtime/collective_param_resolver_local.h" #include "tensorflow/core/common_runtime/constant_folding.h" #include "tensorflow/core/common_runtime/debugger_state_interface.h" #include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/device_resolver_local.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/common_runtime/memory_types.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/common_runtime/process_util.h" +#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h" #include "tensorflow/core/common_runtime/step_stats_collector.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb_text.h" @@ -443,6 +447,18 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options, // Create a run state and start execution. RunState run_state(step_id, &devices_); run_state.rendez = new IntraProcessRendezvous(device_mgr_.get()); + // Set up for collectives if the RunOption declares a key. 
+ if (run_options.experimental().collective_graph_key() > 0) { + if (!collective_executor_mgr_) { + DeviceResolverLocal* drl = new DeviceResolverLocal(device_mgr_.get()); + collective_executor_mgr_.reset(new CollectiveExecutorMgr( + options_.config, device_mgr_.get(), drl, + new CollectiveParamResolverLocal(device_mgr_.get(), drl, + "/job:localhost/replica:0/task:0"))); + } + run_state.collective_executor.reset(new CollectiveExecutor::Handle( + collective_executor_mgr_->FindOrCreate(step_id), true /*inherit_ref*/)); + } // Start parallel Executors. const size_t num_executors = executors_and_keys->items.size(); @@ -459,6 +475,9 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options, args.step_id = step_id; args.call_frame = call_frame; args.rendezvous = run_state.rendez; + args.collective_executor = + (run_state.collective_executor ? run_state.collective_executor->get() + : nullptr); CancellationManager step_cancellation_manager; args.cancellation_manager = &step_cancellation_manager; args.session_state = &session_state_; @@ -768,6 +787,10 @@ Status DirectSession::PRunSetup(const std::vector<string>& input_names, args.rendezvous = run_state->rendez; args.cancellation_manager = cancellation_manager_; + // Note that Collectives are not supported in partial runs + // because RunOptions is not passed in so we can't know whether + // their use is intended. + args.collective_executor = nullptr; args.runner = [this, pool](Executor::Args::Closure c) { SchedClosure(pool, std::move(c)); }; @@ -1518,11 +1541,13 @@ DirectSession::RunState::RunState( const std::vector<string>& pending_input_names, const std::vector<string>& pending_output_names, int64 step_id, const std::vector<Device*>* devices) - : step_container(step_id, [devices](const string& name) { + : step_container(step_id, [devices, step_id](const string& name) { for (auto d : *devices) { if (!d->resource_manager()->Cleanup(name).ok()) { // Do nothing... 
} + ScopedAllocatorMgr* sam = d->GetScopedAllocatorMgr(); + if (sam) sam->Cleanup(step_id); } }) { // Initially all the feeds and fetches are pending. |