aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla
diff options
context:
space:
mode:
authorGravatar Tong Shen <endlessroad@google.com>2018-09-25 15:30:05 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-25 15:34:13 -0700
commitf0f475690cb728d7988328d32a5955c55ab1fb22 (patch)
tree518ada04c54defdf07d22e0daeb557d62bfeb68d /tensorflow/compiler/tf2xla
parentad27440a79c30a53f9fd2a3171a2c2da6ff37820 (diff)
Optimize function before functionalization.
PiperOrigin-RevId: 214515610
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow.cc43
1 files changed, 32 insertions, 11 deletions
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
index f792c52032..98b333a467 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
@@ -31,11 +31,13 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
+#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
@@ -89,7 +91,6 @@ Status FunctionalizeControlFlowForFunction(
}
});
const FunctionBody* body = flr->GetFunctionBody(handle);
- const FunctionDef& fdef = body->fdef;
// If any node has associated functions, functionalize them first.
// Gather nodes with associated functions first, because rewriting those nodes
@@ -130,26 +131,46 @@ Status FunctionalizeControlFlowForFunction(
}
}
+ // Call graph optimizer. The most important optimization we need is constant
+ // folding, which will replace ops like Shape/BroadcastGradientArgs with
+ // constant shape input. Without this optimization, those ops might become
+ // dynamic input for then/else body function and XLA will complain that input
+ // is not compile time constant. We enable function inlining as well, because
+ // otherwise we won't be able to infer shape for any node depending on
+ // function call nodes.
+ if (VLOG_IS_ON(4)) {
+ dump_graph::DumpGraphToFile(
+ absl::StrCat("functionalize_control_flow_before_opt_", func_name),
+ *body->graph, fld);
+ }
+ // Optimizer accepts std::unique_ptr<Graph>* as input and might change
+ // underlying pointer, thus we create a new Graph and copy from body->graph.
+ std::unique_ptr<Graph> optimized_graph(new Graph(fld));
+ CopyGraph(*body->graph, optimized_graph.get());
+ OptimizerOptions opts;
+ opts.set_opt_level(OptimizerOptions::L0);
+ opts.set_do_function_inlining(true);
+ opts.set_do_constant_folding(true);
+ GraphOptimizer optimizer(opts);
+ optimizer.Optimize(flr, flr->env(),
+ /*device=*/nullptr, &optimized_graph,
+ /*shape_map=*/nullptr);
+
// Functionalize the function body.
if (VLOG_IS_ON(4)) {
dump_graph::DumpGraphToFile(
absl::StrCat("functionalize_control_flow_before_fdef_", func_name),
- *body->graph, fld);
+ *optimized_graph, fld);
}
- TF_RETURN_IF_ERROR(FunctionalizeControlFlow(body->graph, fld));
+ TF_RETURN_IF_ERROR(FunctionalizeControlFlow(optimized_graph.get(), fld));
if (VLOG_IS_ON(4)) {
dump_graph::DumpGraphToFile(
absl::StrCat("functionalize_control_flow_after_fdef_", func_name),
- *body->graph, fld);
+ *optimized_graph, fld);
}
FunctionDef functionalized_fdef;
- TF_RETURN_IF_ERROR(
- GraphToFunctionDef(*body->graph, new_func_name, &functionalized_fdef));
-
- // Copy signature and ret from original FunctionDef.
- *functionalized_fdef.mutable_signature() = fdef.signature();
- *functionalized_fdef.mutable_ret() = fdef.ret();
- functionalized_fdef.mutable_signature()->set_name(new_func_name);
+ TF_RETURN_IF_ERROR(GraphToFunctionDef(*optimized_graph, new_func_name,
+ &functionalized_fdef));
// Add rewritten FunctionDef into library.
if (func_name == new_func_name) {