diff options
5 files changed, 122 insertions, 58 deletions
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 36c6f5d316..28e09d7b79 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -79,7 +79,10 @@ Status FunctionalizeControlFlowForFunction( const string& func_name, const string& new_func_name, const protobuf::Map<string, tensorflow::AttrValue>& attrs, FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr, - std::map<string, string>* canonicalized_name_to_new_name) { + std::map<string, absl::optional<string>>* canonicalized_name_to_new_name, + bool* modified) { + *modified = false; + // Convert the function to Graph. FunctionLibraryRuntime::Handle handle; TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle)); @@ -92,6 +95,19 @@ Status FunctionalizeControlFlowForFunction( }); const FunctionBody* body = flr->GetFunctionBody(handle); + // Check if the graph has Switch or Merge node before optimizing the graph. + bool has_switch_or_merge = false; + for (Node* n : body->graph->nodes()) { + if (n->type_string() == "Switch" || n->type_string() == "Merge") { + has_switch_or_merge = true; + break; + } + } + // We cannot return here directly if the graph has no Switch/Merge. + // It might contain function call nodes, or If/While nodes with Switch/Merge + // in function body. We still need to rewrite those functions and modify + // corresponding nodes. + // Call graph optimizer. The most important optimization we need is constant // folding, which will replace ops like Shape/BroadcastGradientArgs with // constant shape input. Without this optimization, those ops might become @@ -129,6 +145,13 @@ Status FunctionalizeControlFlowForFunction( absl::StrCat("functionalize_control_flow_after_opt_", func_name), *optimized_graph, fld); } + // Some inlined functions might have Switch/Merge nodes. 
+ for (Node* n : optimized_graph->nodes()) { + if (n->type_string() == "Switch" || n->type_string() == "Merge") { + has_switch_or_merge = true; + break; + } + } // If any node has associated functions, functionalize them first. // Gather nodes with associated functions first, because rewriting those nodes @@ -151,10 +174,15 @@ Status FunctionalizeControlFlowForFunction( Canonicalize(name, AttrSlice(&associated_function.attrs())); auto iter = canonicalized_name_to_new_name->find(canonicalized_name); string new_name; + bool function_modified; if (iter != canonicalized_name_to_new_name->end()) { - // If we already functionalized this function, skip functionalization - // but still rewrite the node. - new_name = iter->second; + // If we already processed this function, check if it was rewritten. If + // the function was rewritten, the entry will be non-empty. Otherwise + // the entry will be empty. + function_modified = iter->second.has_value(); + if (function_modified) { + new_name = iter->second.value(); + } } else { if (associated_function.type() == AssociatedFunctionInfo::AssociatedFunctionType::kSymbolicGradient) { @@ -166,42 +194,62 @@ Status FunctionalizeControlFlowForFunction( } TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( name, new_name, associated_function.attrs(), fld, flr, - canonicalized_name_to_new_name)); - (*canonicalized_name_to_new_name)[canonicalized_name] = new_name; + canonicalized_name_to_new_name, &function_modified)); + if (function_modified) { + // If the function was rewritten, add a non-empty entry. So later we + // know we have processed this function, and it was rewritten into + // another function. + (*canonicalized_name_to_new_name)[canonicalized_name] = new_name; + } else { + // If the function was not rewritten, add an empty entry. So later + // we know we have processed this function, and it does not need to be + // rewritten. 
+ (*canonicalized_name_to_new_name)[canonicalized_name] = absl::nullopt; + } + } + if (function_modified) { + *modified = true; + + // Notice that if "n" is a function call, RewriteAssociatedFunction() + // will delete it and create a new node instead, making "n" an invalid + // pointer. That's fine because in that case, associated_functions will + // only have one member and the loop will only run once. + TF_RETURN_IF_ERROR(RewriteAssociatedFunction( + optimized_graph.get(), n, fld, associated_function, new_name)); } - // Notice that if "n" is a function call, RewriteAssociatedFunction() will - // delete it and create a new node instead, making "n" an invalid pointer. - // That's fine because in that case, associated_functions will only have - // one member and the loop will only run once. - TF_RETURN_IF_ERROR(RewriteAssociatedFunction( - optimized_graph.get(), n, fld, associated_function, new_name)); } } - // Functionalize the function body. - if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile( - absl::StrCat("functionalize_control_flow_before_fdef_", func_name), - *optimized_graph, fld); - } - TF_RETURN_IF_ERROR(FunctionalizeControlFlow(optimized_graph.get(), fld)); - if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile( - absl::StrCat("functionalize_control_flow_after_fdef_", func_name), - *optimized_graph, fld); + if (has_switch_or_merge) { + *modified = true; + + // Functionalize the function body. + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile( + absl::StrCat("functionalize_control_flow_before_fdef_", func_name), + *optimized_graph, fld); + } + TF_RETURN_IF_ERROR(FunctionalizeControlFlow(optimized_graph.get(), fld)); + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile( + absl::StrCat("functionalize_control_flow_after_fdef_", func_name), + *optimized_graph, fld); + } } - FunctionDef functionalized_fdef; - TF_RETURN_IF_ERROR(GraphToFunctionDef(*optimized_graph, new_func_name, - &functionalized_fdef)); - // Add rewritten FunctionDef into library. 
- if (func_name == new_func_name) { - VLOG(2) << "Replacing function " << func_name; - TF_RETURN_IF_ERROR( - fld->ReplaceFunction(new_func_name, functionalized_fdef)); - } else { - VLOG(2) << "Adding function " << new_func_name; - TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef)); + if (*modified) { + // Add rewritten FunctionDef into library. + FunctionDef functionalized_fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(*optimized_graph, new_func_name, + &functionalized_fdef)); + if (func_name == new_func_name) { + VLOG(2) << "Replacing function " << func_name; + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(new_func_name, functionalized_fdef)); + } else { + VLOG(2) << "Adding function " << new_func_name; + TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef)); + } } return ret_status; @@ -227,7 +275,7 @@ Status FunctionalizeControlFlowPass::Run( {"TPUCompile", "function"}, {"XlaLaunch", "function"}, }; - std::map<string, string> canonicalized_name_to_new_name; + std::map<string, absl::optional<string>> canonicalized_name_to_new_name; for (Node* n : graph->nodes()) { auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string()); if (it == kNodeTypeToFunctionAttrMapping->end()) { @@ -242,12 +290,15 @@ Status FunctionalizeControlFlowPass::Run( << ". 
Corresponding function: " << func.name(); string new_func_name = options.flib_def->UniqueFunctionName( absl::StrCat(func.name(), "_f15n_")); + bool modified; TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( func.name(), new_func_name, func.attr(), options.flib_def, flr, - &canonicalized_name_to_new_name)); - n->ClearAttr(func_attr); - func.set_name(new_func_name); - n->AddAttr(func_attr, func); + &canonicalized_name_to_new_name, &modified)); + if (modified) { + n->ClearAttr(func_attr); + func.set_name(new_func_name); + n->AddAttr(func_attr, func); + } } } diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index 419867ff58..db137f1a19 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -466,7 +466,7 @@ Graph* GetConstantGraph( bool ReplaceTensorWithConstant( Graph* graph, Device* partition_device, NodeAndOutput tensor, const Tensor& constant, const gtl::FlatSet<Node*>& control_deps, - int64 max_constant_size_in_bytes, + int64 max_constant_size_in_bytes, bool disable_memory_output_type_check, const ConstantFoldNameGenerator& generate_new_name) { // Be conservative when replacing a tensor with a constant, when not // running on CPU. 
@@ -535,21 +535,23 @@ bool ReplaceTensorWithConstant( if (!NodeBuilder(builder).Finalize(graph, &constant_node).ok()) { return false; } - if (partition_device && device_type != DEVICE_CPU) { - MemoryType original_output_memory_type; - if (!MemoryTypeForOutput(device_type, graph, tensor.first, tensor.second, - &original_output_memory_type) - .ok()) { - return false; - } - MemoryType const_output_memory_type; - if (!MemoryTypeForOutput(device_type, graph, constant_node, 0, - &const_output_memory_type) - .ok()) { - return false; - } - if (original_output_memory_type != const_output_memory_type) { - return false; + if (!disable_memory_output_type_check) { + if (partition_device && device_type != DEVICE_CPU) { + MemoryType original_output_memory_type; + if (!MemoryTypeForOutput(device_type, graph, tensor.first, tensor.second, + &original_output_memory_type) + .ok()) { + return false; + } + MemoryType const_output_memory_type; + if (!MemoryTypeForOutput(device_type, graph, constant_node, 0, + &const_output_memory_type) + .ok()) { + return false; + } + if (original_output_memory_type != const_output_memory_type) { + return false; + } } } for (auto edge : edges_to_remove) { @@ -658,7 +660,8 @@ Status ConstantFold(const ConstantFoldingOptions& opts, constant_control_deps[tensors_to_replace[c].first]; if (ReplaceTensorWithConstant( graph, partition_device, tensors_to_replace[c], outputs[c], - control_deps, opts.max_constant_size_in_bytes, generate_new_name)) { + control_deps, opts.max_constant_size_in_bytes, + opts.disable_memory_output_type_check, generate_new_name)) { ++num_nodes_replaced; } } diff --git a/tensorflow/core/common_runtime/constant_folding.h b/tensorflow/core/common_runtime/constant_folding.h index a9a84f761b..4c71b7bd27 100644 --- a/tensorflow/core/common_runtime/constant_folding.h +++ b/tensorflow/core/common_runtime/constant_folding.h @@ -45,6 +45,10 @@ struct ConstantFoldingOptions { // optimization. 
int64 max_constant_size_in_bytes = 10 * 1024 * 1024; + // If disable_memory_output_type_check is true, we will disable output memory + // type check for constant node replacement. + bool disable_memory_output_type_check = false; + // A generator for the name suffix of constant folded nodes. A // default id generator that monotonically increases is used if nullptr is // passed. diff --git a/tensorflow/core/common_runtime/graph_optimizer.cc b/tensorflow/core/common_runtime/graph_optimizer.cc index 37a979a8f1..91194bc86f 100644 --- a/tensorflow/core/common_runtime/graph_optimizer.cc +++ b/tensorflow/core/common_runtime/graph_optimizer.cc @@ -39,7 +39,8 @@ void GraphOptimizer::Optimize( const std::unordered_map<string, std::vector<PartialTensorShape>>* shape_map, const std::function<bool(const Node*)>& cse_consider_fn, - const std::function<bool(const Node*)>& cf_consider_fn) { + const std::function<bool(const Node*)>& cf_consider_fn, + bool cf_disable_memory_output_type_check) { Graph* g = graph->get(); DumpGraph("Initial", g); @@ -64,6 +65,8 @@ void GraphOptimizer::Optimize( ConstantFoldingOptions cf_opts; cf_opts.shape_map = shape_map; cf_opts.consider = cf_consider_fn; + cf_opts.disable_memory_output_type_check = + cf_disable_memory_output_type_check; if (opts_.max_folded_constant_in_bytes() > 0) { cf_opts.max_constant_size_in_bytes = opts_.max_folded_constant_in_bytes(); diff --git a/tensorflow/core/common_runtime/graph_optimizer.h b/tensorflow/core/common_runtime/graph_optimizer.h index 789cc56942..8954e9612d 100644 --- a/tensorflow/core/common_runtime/graph_optimizer.h +++ b/tensorflow/core/common_runtime/graph_optimizer.h @@ -47,13 +47,16 @@ class GraphOptimizer { // returns true will be considered for CSE. // If cf_consider_fn is not null then only nodes for which cf_consider_fn // returns true will be considered for CF. + // If cf_disable_memory_output_type_check is true, CF will discard output + // memory type check for constant node replacement. 
void Optimize( FunctionLibraryRuntime* runtime, Env* env, Device* device, std::unique_ptr<Graph>* graph, const std::unordered_map<string, std::vector<PartialTensorShape>>* shape_map, const std::function<bool(const Node*)>& cse_consider_fn = nullptr, - const std::function<bool(const Node*)>& cf_consider_fn = nullptr); + const std::function<bool(const Node*)>& cf_consider_fn = nullptr, + bool cf_disable_memory_output_type_check = false); const OptimizerOptions& options() { return opts_; } |