diff options
author | 2017-01-14 07:35:00 -0800 | |
---|---|---|
committer | 2017-01-14 07:45:34 -0800 | |
commit | 8eb161e39185d364e04a6edf2beadc2bb5cb978c (patch) | |
tree | 1256e6088fdf998e76a430c791e022f52b7bd16c /tensorflow/core/common_runtime/direct_session.cc | |
parent | 7ab67d8b526d6f01c8c8347993d68f8f5074c184 (diff) |
Add support for passes that run post-partitioning to OptimizationRegistry.
To avoid another GraphDef -> Graph -> GraphDef conversion, change Device::MaybeRewriteGraph to take a Graph instead of a GraphDef.
Use std::unique_ptr<> in more places to avoid some awkward .release() magic.
Change: 144532446
Diffstat (limited to 'tensorflow/core/common_runtime/direct_session.cc')
-rw-r--r-- | tensorflow/core/common_runtime/direct_session.cc | 53 |
1 file changed, 31 insertions, 22 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 7dc6db682e..85ce9d772a 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/common_runtime/memory_types.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/common_runtime/simple_placer.h" #include "tensorflow/core/common_runtime/step_stats_collector.h" #include "tensorflow/core/framework/function.h" @@ -940,7 +941,7 @@ Status DirectSession::GetOrCreateExecutors( GraphOptimizer optimizer(optimizer_opts); for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) { const string& partition_name = iter->first; - Graph* partition_graph = iter->second.get(); + std::unique_ptr<Graph>& partition_graph = iter->second; const int graph_def_version = partition_graph->versions().producer(); Device* device; @@ -980,24 +981,23 @@ Status DirectSession::GetOrCreateExecutors( }; params.node_outputs_cb = node_outputs_callback_; - partition_graph = iter->second.release(); - optimizer.Optimize(lib, options_.env, device, &partition_graph); + optimizer.Optimize(lib, options_.env, device, &iter->second); // EXPERIMENTAL: tfdbg inserts debug nodes (i.e., probes) to the graph if (run_state_args->debugger_state) { TF_RETURN_IF_ERROR(run_state_args->debugger_state->DecorateGraphForDebug( - partition_graph, params.device)); + partition_graph.get(), params.device)); } - iter->second.reset(partition_graph); TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()), - device->name(), partition_graph)); + device->name(), + partition_graph.get())); // NewLocalExecutor takes ownership of partition_graph. 
- item->graph = partition_graph; + item->graph = partition_graph.get(); item->executor = nullptr; Executor* executor; TF_RETURN_IF_ERROR( - NewLocalExecutor(params, iter->second.release(), &executor)); + NewLocalExecutor(params, partition_graph.release(), &executor)); item->executor.reset(executor); } @@ -1118,12 +1118,31 @@ Status DirectSession::CreateGraphs( } } + for (const auto& partition : partitions) { + std::unique_ptr<Graph> device_graph( + new Graph(client_graph->flib_def.get())); + GraphConstructorOptions device_opts; + // There are internal operations (e.g., send/recv) that we now allow. + device_opts.allow_internal_ops = true; + device_opts.expect_device_spec = true; + TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(device_opts, partition.second, + device_graph.get())); + outputs->emplace(partition.first, std::move(device_graph)); + } + + GraphOptimizationPassOptions optimization_options; + optimization_options.session_options = &options_; + optimization_options.flib_def = client_graph->flib_def.get(); + optimization_options.partition_graphs = outputs; + TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping( + OptimizationPassRegistry::POST_PARTITIONING, optimization_options)); + Status s; - for (auto&& partition : partitions) { + for (auto& partition : *outputs) { const string& partition_name = partition.first; + std::unique_ptr<Graph>* graph = &partition.second; - GraphDef* graph_def = &partition.second; - VLOG(2) << "Created " << ProtoDebugString(*graph_def) << " for " + VLOG(2) << "Created " << DebugString(graph->get()) << " for " << partition_name; // Give the device an opportunity to rewrite its subgraph. @@ -1134,20 +1153,10 @@ Status DirectSession::CreateGraphs( // may be possible use cases where a device may want to modify // function definitions - in which case the library would need to be // replicated per device. 
- s = d->MaybeRewriteGraph(client_graph->flib_def->ToProto(), graph_def); + s = d->MaybeRewriteGraph(client_graph->flib_def->ToProto(), graph); if (!s.ok()) { break; } - std::unique_ptr<Graph> device_graph( - new Graph(client_graph->flib_def.get())); - GraphConstructorOptions device_opts; - // There are internal operations (e.g., send/recv) that we now - // allow. - device_opts.allow_internal_ops = true; - device_opts.expect_device_spec = true; - TF_RETURN_IF_ERROR( - ConvertGraphDefToGraph(device_opts, *graph_def, device_graph.get())); - outputs->emplace(partition_name, std::move(device_graph)); } *flib_def = std::move(client_graph->flib_def); return s; |