aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar Tong Shen <endlessroad@google.com>2018-10-03 12:39:32 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-03 12:49:10 -0700
commit19833284cc8fa555115aacde350ad66652b250dc (patch)
treefdf9a943d73670c35c89187cadb32a16878aa78a /tensorflow/compiler
parent506ea0b8d3af1b54f42721584a414957e1525c8a (diff)
Automated rollback of commit 2af8fd975aaf5c70ebb396895fa15a8f034a8440
PiperOrigin-RevId: 215608349
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow.cc129
1 files changed, 39 insertions, 90 deletions
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
index 28e09d7b79..36c6f5d316 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
@@ -79,10 +79,7 @@ Status FunctionalizeControlFlowForFunction(
const string& func_name, const string& new_func_name,
const protobuf::Map<string, tensorflow::AttrValue>& attrs,
FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr,
- std::map<string, absl::optional<string>>* canonicalized_name_to_new_name,
- bool* modified) {
- *modified = false;
-
+ std::map<string, string>* canonicalized_name_to_new_name) {
// Convert the function to Graph.
FunctionLibraryRuntime::Handle handle;
TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle));
@@ -95,19 +92,6 @@ Status FunctionalizeControlFlowForFunction(
});
const FunctionBody* body = flr->GetFunctionBody(handle);
- // Check if the graph has Switch or Merge node before optimizing the graph.
- bool has_switch_or_merge = false;
- for (Node* n : body->graph->nodes()) {
- if (n->type_string() == "Switch" || n->type_string() == "Merge") {
- has_switch_or_merge = true;
- break;
- }
- }
- // We cannot return here directly if the graph has no Switch/Merge.
- // It might contain function call nodes, or If/While nodes with Switch/Merge
- // in function body. We still need to rewrite those functions and modify
- // corresponding nodes.
-
// Call graph optimizer. The most important optimization we need is constant
// folding, which will replace ops like Shape/BroadcastGradientArgs with
// constant shape input. Without this optimization, those ops might become
@@ -145,13 +129,6 @@ Status FunctionalizeControlFlowForFunction(
absl::StrCat("functionalize_control_flow_after_opt_", func_name),
*optimized_graph, fld);
}
- // Some inlined functions might have Switch/Merge nodes.
- for (Node* n : optimized_graph->nodes()) {
- if (n->type_string() == "Switch" || n->type_string() == "Merge") {
- has_switch_or_merge = true;
- break;
- }
- }
// If any node has associated functions, functionalize them first.
// Gather nodes with associated functions first, because rewriting those nodes
@@ -174,15 +151,10 @@ Status FunctionalizeControlFlowForFunction(
Canonicalize(name, AttrSlice(&associated_function.attrs()));
auto iter = canonicalized_name_to_new_name->find(canonicalized_name);
string new_name;
- bool function_modified;
if (iter != canonicalized_name_to_new_name->end()) {
- // If we already processed this function, check if it was rewritten. If
- // the function was rewritten, the entry will be non-empty. Otherwise
- // the entry will be empty.
- function_modified = iter->second.has_value();
- if (function_modified) {
- new_name = iter->second.value();
- }
+ // If we already functionalized this function, skip functionalization
+ // but still rewrite the node.
+ new_name = iter->second;
} else {
if (associated_function.type() ==
AssociatedFunctionInfo::AssociatedFunctionType::kSymbolicGradient) {
@@ -194,62 +166,42 @@ Status FunctionalizeControlFlowForFunction(
}
TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
name, new_name, associated_function.attrs(), fld, flr,
- canonicalized_name_to_new_name, &function_modified));
- if (function_modified) {
- // If the function was rewritten, add an non-empty entry. So later we
- // know we have processed this function, and it was rewritten into
- // another function.
- (*canonicalized_name_to_new_name)[canonicalized_name] = new_name;
- } else {
- // If the function was not rewritten, add an empty entry. So later
- // we know we have processed this function, and it does not need to be
- // rewritten.
- (*canonicalized_name_to_new_name)[canonicalized_name] = absl::nullopt;
- }
- }
- if (function_modified) {
- *modified = true;
-
- // Notice that if "n" is a function call, RewriteAssociatedFunction()
- // will delete it and create a new node instead, making "n" an invalid
- // pointer. That's fine because in that case, associated_functions will
- // only have one member and the loop will only run once.
- TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
- optimized_graph.get(), n, fld, associated_function, new_name));
+ canonicalized_name_to_new_name));
+ (*canonicalized_name_to_new_name)[canonicalized_name] = new_name;
}
+ // Notice that if "n" is a function call, RewriteAssociatedFunction() will
+ // delete it and create a new node instead, making "n" an invalid pointer.
+ // That's fine because in that case, associated_functions will only have
+ // one member and the loop will only run once.
+ TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
+ optimized_graph.get(), n, fld, associated_function, new_name));
}
}
- if (has_switch_or_merge) {
- *modified = true;
-
- // Functionalize the function body.
- if (VLOG_IS_ON(4)) {
- dump_graph::DumpGraphToFile(
- absl::StrCat("functionalize_control_flow_before_fdef_", func_name),
- *optimized_graph, fld);
- }
- TF_RETURN_IF_ERROR(FunctionalizeControlFlow(optimized_graph.get(), fld));
- if (VLOG_IS_ON(4)) {
- dump_graph::DumpGraphToFile(
- absl::StrCat("functionalize_control_flow_after_fdef_", func_name),
- *optimized_graph, fld);
- }
+ // Functionalize the function body.
+ if (VLOG_IS_ON(4)) {
+ dump_graph::DumpGraphToFile(
+ absl::StrCat("functionalize_control_flow_before_fdef_", func_name),
+ *optimized_graph, fld);
}
+ TF_RETURN_IF_ERROR(FunctionalizeControlFlow(optimized_graph.get(), fld));
+ if (VLOG_IS_ON(4)) {
+ dump_graph::DumpGraphToFile(
+ absl::StrCat("functionalize_control_flow_after_fdef_", func_name),
+ *optimized_graph, fld);
+ }
+ FunctionDef functionalized_fdef;
+ TF_RETURN_IF_ERROR(GraphToFunctionDef(*optimized_graph, new_func_name,
+ &functionalized_fdef));
- if (*modified) {
- // Add rewritten FunctionDef into library.
- FunctionDef functionalized_fdef;
- TF_RETURN_IF_ERROR(GraphToFunctionDef(*optimized_graph, new_func_name,
- &functionalized_fdef));
- if (func_name == new_func_name) {
- VLOG(2) << "Replacing function " << func_name;
- TF_RETURN_IF_ERROR(
- fld->ReplaceFunction(new_func_name, functionalized_fdef));
- } else {
- VLOG(2) << "Adding function " << new_func_name;
- TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef));
- }
+ // Add rewritten FunctionDef into library.
+ if (func_name == new_func_name) {
+ VLOG(2) << "Replacing function " << func_name;
+ TF_RETURN_IF_ERROR(
+ fld->ReplaceFunction(new_func_name, functionalized_fdef));
+ } else {
+ VLOG(2) << "Adding function " << new_func_name;
+ TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef));
}
return ret_status;
@@ -275,7 +227,7 @@ Status FunctionalizeControlFlowPass::Run(
{"TPUCompile", "function"},
{"XlaLaunch", "function"},
};
- std::map<string, absl::optional<string>> canonicalized_name_to_new_name;
+ std::map<string, string> canonicalized_name_to_new_name;
for (Node* n : graph->nodes()) {
auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string());
if (it == kNodeTypeToFunctionAttrMapping->end()) {
@@ -290,15 +242,12 @@ Status FunctionalizeControlFlowPass::Run(
<< ". Corresponding function: " << func.name();
string new_func_name = options.flib_def->UniqueFunctionName(
absl::StrCat(func.name(), "_f15n_"));
- bool modified;
TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
func.name(), new_func_name, func.attr(), options.flib_def, flr,
- &canonicalized_name_to_new_name, &modified));
- if (modified) {
- n->ClearAttr(func_attr);
- func.set_name(new_func_name);
- n->AddAttr(func_attr, func);
- }
+ &canonicalized_name_to_new_name));
+ n->ClearAttr(func_attr);
+ func.set_name(new_func_name);
+ n->AddAttr(func_attr, func);
}
}