diff options
author | Tong Shen <endlessroad@google.com> | 2018-09-25 15:30:05 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-25 15:34:13 -0700 |
commit | f0f475690cb728d7988328d32a5955c55ab1fb22 (patch) | |
tree | 518ada04c54defdf07d22e0daeb557d62bfeb68d /tensorflow/compiler/tf2xla | |
parent | ad27440a79c30a53f9fd2a3171a2c2da6ff37820 (diff) |
Optimize function before functionalization.
PiperOrigin-RevId: 214515610
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r-- | tensorflow/compiler/tf2xla/functionalize_control_flow.cc | 43 |
1 files changed, 32 insertions, 11 deletions
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index f792c52032..98b333a467 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -31,11 +31,13 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" +#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/cleanup.h" @@ -89,7 +91,6 @@ Status FunctionalizeControlFlowForFunction( } }); const FunctionBody* body = flr->GetFunctionBody(handle); - const FunctionDef& fdef = body->fdef; // If any node has associated functions, functionalize them first. // Gather nodes with associated functions first, because rewriting those nodes @@ -130,26 +131,46 @@ Status FunctionalizeControlFlowForFunction( } } + // Call graph optimizer. The most important optimization we need is constant + // folding, which will replace ops like Shape/BroadcastGradientArgs with + // constant shape input. Without this optimization, those ops might become + // dynamic input for then/else body function and XLA will complain that input + // is not compile time constant. We enable function inlining as well, because + // otherwise we won't be able to infer shape for any node depending on + // function call nodes. + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile( + absl::StrCat("functionalize_control_flow_before_opt_", func_name), + *body->graph, fld); + } + // Optimizer accepts std::unique_ptr<Graph>* as input and might change + // underlying pointer, thus we create a new Graph and copy from body->graph. + std::unique_ptr<Graph> optimized_graph(new Graph(fld)); + CopyGraph(*body->graph, optimized_graph.get()); + OptimizerOptions opts; + opts.set_opt_level(OptimizerOptions::L0); + opts.set_do_function_inlining(true); + opts.set_do_constant_folding(true); + GraphOptimizer optimizer(opts); + optimizer.Optimize(flr, flr->env(), + /*device=*/nullptr, &optimized_graph, + /*shape_map=*/nullptr); + // Functionalize the function body. if (VLOG_IS_ON(4)) { dump_graph::DumpGraphToFile( absl::StrCat("functionalize_control_flow_before_fdef_", func_name), - *body->graph, fld); + *optimized_graph, fld); } - TF_RETURN_IF_ERROR(FunctionalizeControlFlow(body->graph, fld)); + TF_RETURN_IF_ERROR(FunctionalizeControlFlow(optimized_graph.get(), fld)); if (VLOG_IS_ON(4)) { dump_graph::DumpGraphToFile( absl::StrCat("functionalize_control_flow_after_fdef_", func_name), - *body->graph, fld); + *optimized_graph, fld); } FunctionDef functionalized_fdef; - TF_RETURN_IF_ERROR( - GraphToFunctionDef(*body->graph, new_func_name, &functionalized_fdef)); - - // Copy signature and ret from original FunctionDef. - *functionalized_fdef.mutable_signature() = fdef.signature(); - *functionalized_fdef.mutable_ret() = fdef.ret(); - functionalized_fdef.mutable_signature()->set_name(new_func_name); + TF_RETURN_IF_ERROR(GraphToFunctionDef(*optimized_graph, new_func_name, + &functionalized_fdef)); // Add rewritten FunctionDef into library. if (func_name == new_func_name) { |