diff options
author | Tong Shen <endlessroad@google.com> | 2018-10-04 11:24:41 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-04 11:28:53 -0700 |
commit | c8d5054e8c12800f0c3db0e51f3d5902e04eaa37 (patch) | |
tree | e9e2ee47e91ecb831faa7666e2213967b650c921 /tensorflow/compiler | |
parent | 6850dafeeaaa48efa748134688844bd079ef3949 (diff) |
Roll forward change "Skip control flow functionalization if there is no Switch or Merge node.".
PiperOrigin-RevId: 215772272
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r-- | tensorflow/compiler/tf2xla/functionalize_control_flow.cc | 129 |
1 files changed, 90 insertions, 39 deletions
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 36c6f5d316..28e09d7b79 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -79,7 +79,10 @@ Status FunctionalizeControlFlowForFunction( const string& func_name, const string& new_func_name, const protobuf::Map<string, tensorflow::AttrValue>& attrs, FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr, - std::map<string, string>* canonicalized_name_to_new_name) { + std::map<string, absl::optional<string>>* canonicalized_name_to_new_name, + bool* modified) { + *modified = false; + // Convert the function to Graph. FunctionLibraryRuntime::Handle handle; TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle)); @@ -92,6 +95,19 @@ Status FunctionalizeControlFlowForFunction( }); const FunctionBody* body = flr->GetFunctionBody(handle); + // Check if the graph has Switch or Merge node before optimizing the graph. + bool has_switch_or_merge = false; + for (Node* n : body->graph->nodes()) { + if (n->type_string() == "Switch" || n->type_string() == "Merge") { + has_switch_or_merge = true; + break; + } + } + // We cannot return here directly if the graph has no Switch/Merge. + // It might contain function call nodes, or If/While nodes with Switch/Merge + // in function body. We still need to rewrite those functions and modify + // corresponding nodes. + // Call graph optimizer. The most important optimization we need is constant // folding, which will replace ops like Shape/BroadcastGradientArgs with // constant shape input. Without this optimization, those ops might become @@ -129,6 +145,13 @@ Status FunctionalizeControlFlowForFunction( absl::StrCat("functionalize_control_flow_after_opt_", func_name), *optimized_graph, fld); } + // Some inlined functions might have Switch/Merge nodes. + for (Node* n : optimized_graph->nodes()) { + if (n->type_string() == "Switch" || n->type_string() == "Merge") { + has_switch_or_merge = true; + break; + } + } // If any node has associated functions, functionalize them first. // Gather nodes with associated functions first, because rewriting those nodes @@ -151,10 +174,15 @@ Status FunctionalizeControlFlowForFunction( Canonicalize(name, AttrSlice(&associated_function.attrs())); auto iter = canonicalized_name_to_new_name->find(canonicalized_name); string new_name; + bool function_modified; if (iter != canonicalized_name_to_new_name->end()) { - // If we already functionalized this function, skip functionalization - // but still rewrite the node. - new_name = iter->second; + // If we already processed this function, check if it was rewritten. If + // the function was rewritten, the entry will be non-empty. Otherwise + // the entry will be empty. + function_modified = iter->second.has_value(); + if (function_modified) { + new_name = iter->second.value(); + } } else { if (associated_function.type() == AssociatedFunctionInfo::AssociatedFunctionType::kSymbolicGradient) { @@ -166,42 +194,62 @@ Status FunctionalizeControlFlowForFunction( } TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( name, new_name, associated_function.attrs(), fld, flr, - canonicalized_name_to_new_name)); - (*canonicalized_name_to_new_name)[canonicalized_name] = new_name; + canonicalized_name_to_new_name, &function_modified)); + if (function_modified) { + // If the function was rewritten, add an non-empty entry. So later we + // know we have processed this function, and it was rewritten into + // another function. + (*canonicalized_name_to_new_name)[canonicalized_name] = new_name; + } else { + // If the function was not rewritten, add an empty entry. So later + // we know we have processed this function, and it does not need to be + // rewritten. + (*canonicalized_name_to_new_name)[canonicalized_name] = absl::nullopt; + } + } + if (function_modified) { + *modified = true; + + // Notice that if "n" is a function call, RewriteAssociatedFunction() + // will delete it and create a new node instead, making "n" an invalid + // pointer. That's fine because in that case, associated_functions will + // only have one member and the loop will only run once. + TF_RETURN_IF_ERROR(RewriteAssociatedFunction( + optimized_graph.get(), n, fld, associated_function, new_name)); } - // Notice that if "n" is a function call, RewriteAssociatedFunction() will - // delete it and create a new node instead, making "n" an invalid pointer. - // That's fine because in that case, associated_functions will only have - // one member and the loop will only run once. - TF_RETURN_IF_ERROR(RewriteAssociatedFunction( - optimized_graph.get(), n, fld, associated_function, new_name)); } } - // Functionalize the function body. - if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile( - absl::StrCat("functionalize_control_flow_before_fdef_", func_name), - *optimized_graph, fld); - } - TF_RETURN_IF_ERROR(FunctionalizeControlFlow(optimized_graph.get(), fld)); - if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile( - absl::StrCat("functionalize_control_flow_after_fdef_", func_name), - *optimized_graph, fld); + if (has_switch_or_merge) { + *modified = true; + + // Functionalize the function body. + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile( + absl::StrCat("functionalize_control_flow_before_fdef_", func_name), + *optimized_graph, fld); + } + TF_RETURN_IF_ERROR(FunctionalizeControlFlow(optimized_graph.get(), fld)); + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile( + absl::StrCat("functionalize_control_flow_after_fdef_", func_name), + *optimized_graph, fld); + } } - FunctionDef functionalized_fdef; - TF_RETURN_IF_ERROR(GraphToFunctionDef(*optimized_graph, new_func_name, - &functionalized_fdef)); - // Add rewritten FunctionDef into library. - if (func_name == new_func_name) { - VLOG(2) << "Replacing function " << func_name; - TF_RETURN_IF_ERROR( - fld->ReplaceFunction(new_func_name, functionalized_fdef)); - } else { - VLOG(2) << "Adding function " << new_func_name; - TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef)); + if (*modified) { + // Add rewritten FunctionDef into library. + FunctionDef functionalized_fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(*optimized_graph, new_func_name, + &functionalized_fdef)); + if (func_name == new_func_name) { + VLOG(2) << "Replacing function " << func_name; + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(new_func_name, functionalized_fdef)); + } else { + VLOG(2) << "Adding function " << new_func_name; + TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef)); + } } return ret_status; @@ -227,7 +275,7 @@ Status FunctionalizeControlFlowPass::Run( {"TPUCompile", "function"}, {"XlaLaunch", "function"}, }; - std::map<string, string> canonicalized_name_to_new_name; + std::map<string, absl::optional<string>> canonicalized_name_to_new_name; for (Node* n : graph->nodes()) { auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string()); if (it == kNodeTypeToFunctionAttrMapping->end()) { @@ -242,12 +290,15 @@ Status FunctionalizeControlFlowPass::Run( << ". Corresponding function: " << func.name(); string new_func_name = options.flib_def->UniqueFunctionName( absl::StrCat(func.name(), "_f15n_")); + bool modified; TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( func.name(), new_func_name, func.attr(), options.flib_def, flr, - &canonicalized_name_to_new_name)); - n->ClearAttr(func_attr); - func.set_name(new_func_name); - n->AddAttr(func_attr, func); + &canonicalized_name_to_new_name, &modified)); + if (modified) { + n->ClearAttr(func_attr); + func.set_name(new_func_name); + n->AddAttr(func_attr, func); + } } } |