about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/common_runtime/direct_session.cc
diff options
context:
space:
mode:
author: Peter Hawkins <phawkins@google.com> 2017-01-14 07:35:00 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-01-14 07:45:34 -0800
commit: 8eb161e39185d364e04a6edf2beadc2bb5cb978c (patch)
tree: 1256e6088fdf998e76a430c791e022f52b7bd16c /tensorflow/core/common_runtime/direct_session.cc
parent: 7ab67d8b526d6f01c8c8347993d68f8f5074c184 (diff)
Add support for passes that run post-partitioning to OptimizationRegistry.
To avoid another GraphDef -> Graph -> GraphDef conversion, change Device::MaybeRewriteGraph to take a Graph instead of a GraphDef. Use std::unique_ptr<> in more places to avoid some awkward .release() magic. Change: 144532446
Diffstat (limited to 'tensorflow/core/common_runtime/direct_session.cc')
-rw-r--r-- tensorflow/core/common_runtime/direct_session.cc 53
1 files changed, 31 insertions, 22 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 7dc6db682e..85ce9d772a 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/memory_types.h"
+#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/simple_placer.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/function.h"
@@ -940,7 +941,7 @@ Status DirectSession::GetOrCreateExecutors(
GraphOptimizer optimizer(optimizer_opts);
for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
const string& partition_name = iter->first;
- Graph* partition_graph = iter->second.get();
+ std::unique_ptr<Graph>& partition_graph = iter->second;
const int graph_def_version = partition_graph->versions().producer();
Device* device;
@@ -980,24 +981,23 @@ Status DirectSession::GetOrCreateExecutors(
};
params.node_outputs_cb = node_outputs_callback_;
- partition_graph = iter->second.release();
- optimizer.Optimize(lib, options_.env, device, &partition_graph);
+ optimizer.Optimize(lib, options_.env, device, &iter->second);
// EXPERIMENTAL: tfdbg inserts debug nodes (i.e., probes) to the graph
if (run_state_args->debugger_state) {
TF_RETURN_IF_ERROR(run_state_args->debugger_state->DecorateGraphForDebug(
- partition_graph, params.device));
+ partition_graph.get(), params.device));
}
- iter->second.reset(partition_graph);
TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
- device->name(), partition_graph));
+ device->name(),
+ partition_graph.get()));
// NewLocalExecutor takes ownership of partition_graph.
- item->graph = partition_graph;
+ item->graph = partition_graph.get();
item->executor = nullptr;
Executor* executor;
TF_RETURN_IF_ERROR(
- NewLocalExecutor(params, iter->second.release(), &executor));
+ NewLocalExecutor(params, partition_graph.release(), &executor));
item->executor.reset(executor);
}
@@ -1118,12 +1118,31 @@ Status DirectSession::CreateGraphs(
}
}
+ for (const auto& partition : partitions) {
+ std::unique_ptr<Graph> device_graph(
+ new Graph(client_graph->flib_def.get()));
+ GraphConstructorOptions device_opts;
+ // There are internal operations (e.g., send/recv) that we now allow.
+ device_opts.allow_internal_ops = true;
+ device_opts.expect_device_spec = true;
+ TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(device_opts, partition.second,
+ device_graph.get()));
+ outputs->emplace(partition.first, std::move(device_graph));
+ }
+
+ GraphOptimizationPassOptions optimization_options;
+ optimization_options.session_options = &options_;
+ optimization_options.flib_def = client_graph->flib_def.get();
+ optimization_options.partition_graphs = outputs;
+ TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
+ OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
+
Status s;
- for (auto&& partition : partitions) {
+ for (auto& partition : *outputs) {
const string& partition_name = partition.first;
+ std::unique_ptr<Graph>* graph = &partition.second;
- GraphDef* graph_def = &partition.second;
- VLOG(2) << "Created " << ProtoDebugString(*graph_def) << " for "
+ VLOG(2) << "Created " << DebugString(graph->get()) << " for "
<< partition_name;
// Give the device an opportunity to rewrite its subgraph.
@@ -1134,20 +1153,10 @@ Status DirectSession::CreateGraphs(
// may be possible use cases where a device may want to modify
// function definitions - in which case the library would need to be
// replicated per device.
- s = d->MaybeRewriteGraph(client_graph->flib_def->ToProto(), graph_def);
+ s = d->MaybeRewriteGraph(client_graph->flib_def->ToProto(), graph);
if (!s.ok()) {
break;
}
- std::unique_ptr<Graph> device_graph(
- new Graph(client_graph->flib_def.get()));
- GraphConstructorOptions device_opts;
- // There are internal operations (e.g., send/recv) that we now
- // allow.
- device_opts.allow_internal_ops = true;
- device_opts.expect_device_spec = true;
- TF_RETURN_IF_ERROR(
- ConvertGraphDefToGraph(device_opts, *graph_def, device_graph.get()));
- outputs->emplace(partition_name, std::move(device_graph));
}
*flib_def = std::move(client_graph->flib_def);
return s;