diff options
author | Tong Shen <endlessroad@google.com> | 2018-10-03 12:39:32 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-03 12:49:10 -0700 |
commit | 19833284cc8fa555115aacde350ad66652b250dc (patch) | |
tree | fdf9a943d73670c35c89187cadb32a16878aa78a /tensorflow/compiler/tf2xla | |
parent | 506ea0b8d3af1b54f42721584a414957e1525c8a (diff) |
Automated rollback of commit 2af8fd975aaf5c70ebb396895fa15a8f034a8440
PiperOrigin-RevId: 215608349
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r-- | tensorflow/compiler/tf2xla/functionalize_control_flow.cc | 129 |
1 files changed, 39 insertions, 90 deletions
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 28e09d7b79..36c6f5d316 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -79,10 +79,7 @@ Status FunctionalizeControlFlowForFunction( const string& func_name, const string& new_func_name, const protobuf::Map<string, tensorflow::AttrValue>& attrs, FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr, - std::map<string, absl::optional<string>>* canonicalized_name_to_new_name, - bool* modified) { - *modified = false; - + std::map<string, string>* canonicalized_name_to_new_name) { // Convert the function to Graph. FunctionLibraryRuntime::Handle handle; TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle)); @@ -95,19 +92,6 @@ Status FunctionalizeControlFlowForFunction( }); const FunctionBody* body = flr->GetFunctionBody(handle); - // Check if the graph has Switch or Merge node before optimizing the graph. - bool has_switch_or_merge = false; - for (Node* n : body->graph->nodes()) { - if (n->type_string() == "Switch" || n->type_string() == "Merge") { - has_switch_or_merge = true; - break; - } - } - // We cannot return here directly if the graph has no Switch/Merge. - // It might contain function call nodes, or If/While nodes with Switch/Merge - // in function body. We still need to rewrite those functions and modify - // corresponding nodes. - // Call graph optimizer. The most important optimization we need is constant // folding, which will replace ops like Shape/BroadcastGradientArgs with // constant shape input. Without this optimization, those ops might become @@ -145,13 +129,6 @@ Status FunctionalizeControlFlowForFunction( absl::StrCat("functionalize_control_flow_after_opt_", func_name), *optimized_graph, fld); } - // Some inlined functions might have Switch/Merge nodes. - for (Node* n : optimized_graph->nodes()) { - if (n->type_string() == "Switch" || n->type_string() == "Merge") { - has_switch_or_merge = true; - break; - } - } // If any node has associated functions, functionalize them first. // Gather nodes with associated functions first, because rewriting those nodes @@ -174,15 +151,10 @@ Status FunctionalizeControlFlowForFunction( Canonicalize(name, AttrSlice(&associated_function.attrs())); auto iter = canonicalized_name_to_new_name->find(canonicalized_name); string new_name; - bool function_modified; if (iter != canonicalized_name_to_new_name->end()) { - // If we already processed this function, check if it was rewritten. If - // the function was rewritten, the entry will be non-empty. Otherwise - // the entry will be empty. - function_modified = iter->second.has_value(); - if (function_modified) { - new_name = iter->second.value(); - } + // If we already functionalized this function, skip functionalization + // but still rewrite the node. + new_name = iter->second; } else { if (associated_function.type() == AssociatedFunctionInfo::AssociatedFunctionType::kSymbolicGradient) { @@ -194,62 +166,42 @@ Status FunctionalizeControlFlowForFunction( } TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( name, new_name, associated_function.attrs(), fld, flr, - canonicalized_name_to_new_name, &function_modified)); - if (function_modified) { - // If the function was rewritten, add an non-empty entry. So later we - // know we have processed this function, and it was rewritten into - // another function. - (*canonicalized_name_to_new_name)[canonicalized_name] = new_name; - } else { - // If the function was not rewritten, add an empty entry. So later - // we know we have processed this function, and it does not need to be - // rewritten. - (*canonicalized_name_to_new_name)[canonicalized_name] = absl::nullopt; - } - } - if (function_modified) { - *modified = true; - - // Notice that if "n" is a function call, RewriteAssociatedFunction() - // will delete it and create a new node instead, making "n" an invalid - // pointer. That's fine because in that case, associated_functions will - // only have one member and the loop will only run once. - TF_RETURN_IF_ERROR(RewriteAssociatedFunction( - optimized_graph.get(), n, fld, associated_function, new_name)); + canonicalized_name_to_new_name)); + (*canonicalized_name_to_new_name)[canonicalized_name] = new_name; } + // Notice that if "n" is a function call, RewriteAssociatedFunction() will + // delete it and create a new node instead, making "n" an invalid pointer. + // That's fine because in that case, associated_functions will only have + // one member and the loop will only run once. + TF_RETURN_IF_ERROR(RewriteAssociatedFunction( + optimized_graph.get(), n, fld, associated_function, new_name)); } } - if (has_switch_or_merge) { - *modified = true; - - // Functionalize the function body. - if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile( - absl::StrCat("functionalize_control_flow_before_fdef_", func_name), - *optimized_graph, fld); - } - TF_RETURN_IF_ERROR(FunctionalizeControlFlow(optimized_graph.get(), fld)); - if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile( - absl::StrCat("functionalize_control_flow_after_fdef_", func_name), - *optimized_graph, fld); - } + // Functionalize the function body. + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile( + absl::StrCat("functionalize_control_flow_before_fdef_", func_name), + *optimized_graph, fld); } + TF_RETURN_IF_ERROR(FunctionalizeControlFlow(optimized_graph.get(), fld)); + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile( + absl::StrCat("functionalize_control_flow_after_fdef_", func_name), + *optimized_graph, fld); + } + FunctionDef functionalized_fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(*optimized_graph, new_func_name, + &functionalized_fdef)); - if (*modified) { - // Add rewritten FunctionDef into library. - FunctionDef functionalized_fdef; - TF_RETURN_IF_ERROR(GraphToFunctionDef(*optimized_graph, new_func_name, - &functionalized_fdef)); - if (func_name == new_func_name) { - VLOG(2) << "Replacing function " << func_name; - TF_RETURN_IF_ERROR( - fld->ReplaceFunction(new_func_name, functionalized_fdef)); - } else { - VLOG(2) << "Adding function " << new_func_name; - TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef)); - } + // Add rewritten FunctionDef into library. + if (func_name == new_func_name) { + VLOG(2) << "Replacing function " << func_name; + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(new_func_name, functionalized_fdef)); + } else { + VLOG(2) << "Adding function " << new_func_name; + TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef)); } return ret_status; @@ -275,7 +227,7 @@ Status FunctionalizeControlFlowPass::Run( {"TPUCompile", "function"}, {"XlaLaunch", "function"}, }; - std::map<string, absl::optional<string>> canonicalized_name_to_new_name; + std::map<string, string> canonicalized_name_to_new_name; for (Node* n : graph->nodes()) { auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string()); if (it == kNodeTypeToFunctionAttrMapping->end()) { @@ -290,15 +242,12 @@ Status FunctionalizeControlFlowPass::Run( << ". Corresponding function: " << func.name(); string new_func_name = options.flib_def->UniqueFunctionName( absl::StrCat(func.name(), "_f15n_")); - bool modified; TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( func.name(), new_func_name, func.attr(), options.flib_def, flr, - &canonicalized_name_to_new_name, &modified)); - if (modified) { - n->ClearAttr(func_attr); - func.set_name(new_func_name); - n->AddAttr(func_attr, func); - } + &canonicalized_name_to_new_name)); + n->ClearAttr(func_attr); + func.set_name(new_func_name); + n->AddAttr(func_attr, func); } } |