about summary refs log tree commit diff homepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
author    Tong Shen <endlessroad@google.com>  2018-10-05 12:17:31 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-10-05 12:22:40 -0700
commit    d016650ca7636c96c6664bed2cf3a2fa8a3c674b (patch)
tree      f7ea332e174ea135ab07fa2536224cf6bc908b6e /tensorflow/compiler
parent    0541a277d5c74cf8e99c9f5a7a015926d1a05214 (diff)
Revert constant folding to previous state.
PiperOrigin-RevId: 215946205
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--  tensorflow/compiler/tf2xla/functionalize_control_flow.cc | 64
1 file changed, 10 insertions(+), 54 deletions(-)
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
index 28e09d7b79..0362682bd6 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
@@ -94,8 +94,9 @@ Status FunctionalizeControlFlowForFunction(
}
});
const FunctionBody* body = flr->GetFunctionBody(handle);
+ Graph* g = body->graph;
- // Check if the graph has Switch or Merge node before optimizing the graph.
+ // Check if the graph has Switch or Merge node.
bool has_switch_or_merge = false;
for (Node* n : body->graph->nodes()) {
if (n->type_string() == "Switch" || n->type_string() == "Merge") {
@@ -108,58 +109,13 @@ Status FunctionalizeControlFlowForFunction(
// in function body. We still need to rewrite those functions and modify
// corresponding nodes.
- // Call graph optimizer. The most important optimization we need is constant
- // folding, which will replace ops like Shape/BroadcastGradientArgs with
- // constant shape input. Without this optimization, those ops might become
- // dynamic input for then/else body function and XLA will complain that input
- // is not compile time constant. We enable function inlining as well, because
- // otherwise we won't be able to infer shape for any node depending on
- // function call nodes.
- if (VLOG_IS_ON(4)) {
- dump_graph::DumpGraphToFile(
- absl::StrCat("functionalize_control_flow_before_opt_", func_name),
- *body->graph, fld);
- }
- // Optimizer accepts std::unique_ptr<Graph>* as input and might change
- // underlying pointer, thus we create a new Graph and copy from body->graph.
- std::unique_ptr<Graph> optimized_graph(new Graph(fld));
- CopyGraph(*body->graph, optimized_graph.get());
- OptimizerOptions opts;
- opts.set_opt_level(OptimizerOptions::L0);
- opts.set_do_function_inlining(true);
- opts.set_do_constant_folding(true);
- GraphOptimizer optimizer(opts);
- auto cf_consider_fn = [](const Node* n) {
- // Skip SymbolicGradient op when doing constant folding.
- // Enabling SymbolicGradient op in constant folding requires
- // flr->device() to be non-null, and here we have not constructed
- // proper Device object yet (it will be constructed in XlaCompiler).
- return n->type_string() != FunctionLibraryDefinition::kGradientOp;
- };
- optimizer.Optimize(flr, flr->env(),
- /*device=*/nullptr, &optimized_graph,
- /*shape_map=*/nullptr, /*cse_consider_fn=*/nullptr,
- cf_consider_fn);
- if (VLOG_IS_ON(4)) {
- dump_graph::DumpGraphToFile(
- absl::StrCat("functionalize_control_flow_after_opt_", func_name),
- *optimized_graph, fld);
- }
- // Some inlined functions might have Switch/Merge nodes.
- for (Node* n : optimized_graph->nodes()) {
- if (n->type_string() == "Switch" || n->type_string() == "Merge") {
- has_switch_or_merge = true;
- break;
- }
- }
-
// If any node has associated functions, functionalize them first.
// Gather nodes with associated functions first, because rewriting those nodes
// might involve node deletion/addition. Avoid modifying nodes while iterating
// it.
std::vector<std::pair<Node*, std::vector<AssociatedFunctionInfo>>>
nodes_to_associated_functions;
- for (auto* n : optimized_graph->nodes()) {
+ for (auto* n : g->nodes()) {
auto associated_functions = GetAssociatedFunctions(*n, flr);
if (!associated_functions.empty()) {
nodes_to_associated_functions.push_back({n, associated_functions});
@@ -215,7 +171,7 @@ Status FunctionalizeControlFlowForFunction(
// pointer. That's fine because in that case, associated_functions will
// only have one member and the loop will only run once.
TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
- optimized_graph.get(), n, fld, associated_function, new_name));
+ g, n, fld, associated_function, new_name));
}
}
}
@@ -227,21 +183,21 @@ Status FunctionalizeControlFlowForFunction(
if (VLOG_IS_ON(4)) {
dump_graph::DumpGraphToFile(
absl::StrCat("functionalize_control_flow_before_fdef_", func_name),
- *optimized_graph, fld);
+ *g, fld);
}
- TF_RETURN_IF_ERROR(FunctionalizeControlFlow(optimized_graph.get(), fld));
+ TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g, fld));
if (VLOG_IS_ON(4)) {
dump_graph::DumpGraphToFile(
- absl::StrCat("functionalize_control_flow_after_fdef_", func_name),
- *optimized_graph, fld);
+ absl::StrCat("functionalize_control_flow_after_fdef_", func_name), *g,
+ fld);
}
}
if (*modified) {
// Add rewritten FunctionDef into library.
FunctionDef functionalized_fdef;
- TF_RETURN_IF_ERROR(GraphToFunctionDef(*optimized_graph, new_func_name,
- &functionalized_fdef));
+ TF_RETURN_IF_ERROR(
+ GraphToFunctionDef(*g, new_func_name, &functionalized_fdef));
if (func_name == new_func_name) {
VLOG(2) << "Replacing function " << func_name;
TF_RETURN_IF_ERROR(