aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/direct_session.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-22 13:49:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-22 13:51:22 -0700
commit9d2c6ff2a542b9bd89b42e3b88e6299eae9bdcc4 (patch)
tree301c1026d384945565a96226d66180e1e950b3b3 /tensorflow/core/common_runtime/direct_session.cc
parent4d134bad0403ebb5722144d8f859a04a5f21efc2 (diff)
Collective Ops Part 7
Complete just enough of the core implementation to run multi-device collectives locally within a single process. Interfaces are still private and not available for general use. PiperOrigin-RevId: 197617132
Diffstat (limited to 'tensorflow/core/common_runtime/direct_session.cc')
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc27
1 files changed, 26 insertions, 1 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 0afbd02e86..07c1eafedc 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -19,15 +19,19 @@ limitations under the License.
#include <string>
#include <vector>
+#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
+#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/debugger_state_interface.h"
#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/device_resolver_local.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb_text.h"
@@ -443,6 +447,18 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
// Create a run state and start execution.
RunState run_state(step_id, &devices_);
run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());
+ // Set up for collectives if RunOptions declares a positive collective_graph_key.
+ if (run_options.experimental().collective_graph_key() > 0) {
+ if (!collective_executor_mgr_) {
+ DeviceResolverLocal* drl = new DeviceResolverLocal(device_mgr_.get());
+ collective_executor_mgr_.reset(new CollectiveExecutorMgr(
+ options_.config, device_mgr_.get(), drl,
+ new CollectiveParamResolverLocal(device_mgr_.get(), drl,
+ "/job:localhost/replica:0/task:0")));
+ }
+ run_state.collective_executor.reset(new CollectiveExecutor::Handle(
+ collective_executor_mgr_->FindOrCreate(step_id), true /*inherit_ref*/));
+ }
// Start parallel Executors.
const size_t num_executors = executors_and_keys->items.size();
@@ -459,6 +475,9 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
args.step_id = step_id;
args.call_frame = call_frame;
args.rendezvous = run_state.rendez;
+ args.collective_executor =
+ (run_state.collective_executor ? run_state.collective_executor->get()
+ : nullptr);
CancellationManager step_cancellation_manager;
args.cancellation_manager = &step_cancellation_manager;
args.session_state = &session_state_;
@@ -768,6 +787,10 @@ Status DirectSession::PRunSetup(const std::vector<string>& input_names,
args.rendezvous = run_state->rendez;
args.cancellation_manager = cancellation_manager_;
+ // Note that Collectives are not supported in partial runs
+ // because RunOptions is not passed in so we can't know whether
+ // their use is intended.
+ args.collective_executor = nullptr;
args.runner = [this, pool](Executor::Args::Closure c) {
SchedClosure(pool, std::move(c));
};
@@ -1518,11 +1541,13 @@ DirectSession::RunState::RunState(
const std::vector<string>& pending_input_names,
const std::vector<string>& pending_output_names, int64 step_id,
const std::vector<Device*>* devices)
- : step_container(step_id, [devices](const string& name) {
+ : step_container(step_id, [devices, step_id](const string& name) {
for (auto d : *devices) {
if (!d->resource_manager()->Cleanup(name).ok()) {
// Do nothing...
}
+ ScopedAllocatorMgr* sam = d->GetScopedAllocatorMgr();
+ if (sam) sam->Cleanup(step_id);
}
}) {
// Initially all the feeds and fetches are pending.