author     Tong Shen <endlessroad@google.com>               2018-09-28 15:14:43 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-09-28 15:18:31 -0700
commit     2f559f2d5f75cf80183ae0d855110809404019f7 (patch)
tree       d8101da5ee6b7008fa48cee7e5000c3bb183e6ee /tensorflow/compiler/tf2xla
parent     dee0481c07ed952d01b12704c89e50869a383c68 (diff)
Handle noinline gradient function in control flow functionalization.
PiperOrigin-RevId: 215003704
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r--  tensorflow/compiler/tf2xla/functionalize_control_flow.cc  | 84
-rw-r--r--  tensorflow/compiler/tf2xla/tf2xla_util.cc                  | 30
-rw-r--r--  tensorflow/compiler/tf2xla/tf2xla_util.h                   | 51
3 files changed, 108 insertions(+), 57 deletions(-)
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
index 2d45507796..36c6f5d316 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
@@ -92,13 +92,51 @@ Status FunctionalizeControlFlowForFunction(
   });
   const FunctionBody* body = flr->GetFunctionBody(handle);
 
+  // Call graph optimizer. The most important optimization we need is constant
+  // folding, which will replace ops like Shape/BroadcastGradientArgs with
+  // constant shape input. Without this optimization, those ops might become
+  // dynamic input for then/else body function and XLA will complain that input
+  // is not compile time constant. We enable function inlining as well, because
+  // otherwise we won't be able to infer shape for any node depending on
+  // function call nodes.
+  if (VLOG_IS_ON(4)) {
+    dump_graph::DumpGraphToFile(
+        absl::StrCat("functionalize_control_flow_before_opt_", func_name),
+        *body->graph, fld);
+  }
+  // Optimizer accepts std::unique_ptr<Graph>* as input and might change
+  // underlying pointer, thus we create a new Graph and copy from body->graph.
+  std::unique_ptr<Graph> optimized_graph(new Graph(fld));
+  CopyGraph(*body->graph, optimized_graph.get());
+  OptimizerOptions opts;
+  opts.set_opt_level(OptimizerOptions::L0);
+  opts.set_do_function_inlining(true);
+  opts.set_do_constant_folding(true);
+  GraphOptimizer optimizer(opts);
+  auto cf_consider_fn = [](const Node* n) {
+    // Skip SymbolicGradient op when doing constant folding.
+    // Enabling SymbolicGradient op in constant folding requires
+    // flr->device() to be non-null, and here we have not constructed
+    // proper Device object yet (it will be constructed in XlaCompiler).
+    return n->type_string() != FunctionLibraryDefinition::kGradientOp;
+  };
+  optimizer.Optimize(flr, flr->env(),
+                     /*device=*/nullptr, &optimized_graph,
+                     /*shape_map=*/nullptr, /*cse_consider_fn=*/nullptr,
+                     cf_consider_fn);
+  if (VLOG_IS_ON(4)) {
+    dump_graph::DumpGraphToFile(
+        absl::StrCat("functionalize_control_flow_after_opt_", func_name),
+        *optimized_graph, fld);
+  }
+
   // If any node has associated functions, functionalize them first.
   // Gather nodes with associated functions first, because rewriting those nodes
   // might involve node deletion/addition. Avoid modifying nodes while iterating
   // it.
   std::vector<std::pair<Node*, std::vector<AssociatedFunctionInfo>>>
       nodes_to_associated_functions;
-  for (auto* n : body->graph->nodes()) {
+  for (auto* n : optimized_graph->nodes()) {
     auto associated_functions = GetAssociatedFunctions(*n, flr);
     if (!associated_functions.empty()) {
       nodes_to_associated_functions.push_back({n, associated_functions});
@@ -118,7 +156,14 @@ Status FunctionalizeControlFlowForFunction(
       // but still rewrite the node.
       new_name = iter->second;
     } else {
-      new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_"));
+      if (associated_function.type() ==
+          AssociatedFunctionInfo::AssociatedFunctionType::kSymbolicGradient) {
+        // For SymbolicGradient, `name` is always "SymbolicGradient",
+        // which is not very informative. Use node name instead.
+        new_name = fld->UniqueFunctionName(absl::StrCat(n->name(), "_f15n_"));
+      } else {
+        new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_"));
+      }
       TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
           name, new_name, associated_function.attrs(), fld, flr,
           canonicalized_name_to_new_name));
@@ -129,43 +174,10 @@ Status FunctionalizeControlFlowForFunction(
       // That's fine because in that case, associated_functions will only have
       // one member and the loop will only run once.
       TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
-          body->graph, n, fld, associated_function, new_name));
+          optimized_graph.get(), n, fld, associated_function, new_name));
     }
   }
 
-  // Call graph optimizer. The most important optimization we need is constant
-  // folding, which will replace ops like Shape/BroadcastGradientArgs with
-  // constant shape input. Without this optimization, those ops might become
-  // dynamic input for then/else body function and XLA will complain that input
-  // is not compile time constant. We enable function inlining as well, because
-  // otherwise we won't be able to infer shape for any node depending on
-  // function call nodes.
-  if (VLOG_IS_ON(4)) {
-    dump_graph::DumpGraphToFile(
-        absl::StrCat("functionalize_control_flow_before_opt_", func_name),
-        *body->graph, fld);
-  }
-  // Optimizer accepts std::unique_ptr<Graph>* as input and might change
-  // underlying pointer, thus we create a new Graph and copy from body->graph.
-  std::unique_ptr<Graph> optimized_graph(new Graph(fld));
-  CopyGraph(*body->graph, optimized_graph.get());
-  OptimizerOptions opts;
-  opts.set_opt_level(OptimizerOptions::L0);
-  opts.set_do_function_inlining(true);
-  opts.set_do_constant_folding(true);
-  GraphOptimizer optimizer(opts);
-  auto cf_consider_fn = [](const Node* n) {
-    // Skip SymbolicGradient op when doing constant folding.
-    // Enabling SymbolicGradient op in constant folding requires
-    // flr->device() to be non-null, and here we have not constructed
-    // proper Device object yet (it will be constructed in XlaCompiler).
-    return n->type_string() != FunctionLibraryDefinition::kGradientOp;
-  };
-  optimizer.Optimize(flr, flr->env(),
-                     /*device=*/nullptr, &optimized_graph,
-                     /*shape_map=*/nullptr, /*cse_consider_fn=*/nullptr,
-                     cf_consider_fn);
-
   // Functionalize the function body.
   if (VLOG_IS_ON(4)) {
     dump_graph::DumpGraphToFile(
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc
index d6f42bac86..01dd3ba10f 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc
@@ -336,9 +336,9 @@ bool HasAssociatedFunction(const NodeDef& node_def,
   }
 
   if (node_def.op() == FunctionLibraryDefinition::kGradientOp) {
-    // Skip gradient op. Gradient op has "f" attr, which is set to the function
-    // we are getting gradient for. That function is not associated with the op.
-    return false;
+    // Gradient op has "f" attr, which is set to the function we are getting
+    // gradient for. We need to functionalize the gradient function.
+    return true;
   }
 
   for (const auto& iter : node_def.attr()) {
@@ -357,17 +357,18 @@ std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
   if (flr->GetFunctionLibraryDefinition()->Contains(op)) {
     // This is a function call node.
     AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
-    results.emplace_back(AssociatedFunctionInfo(op, attrs));
+    results.emplace_back(AssociatedFunctionInfo::FunctionCall(op, attrs));
   } else if (node.type_string() == FunctionLibraryDefinition::kGradientOp) {
-    // Skip gradient op. Gradient op has "f" attr, which is set to the function
-    // we are getting gradient for. That function is not associated with the op.
+    // This is a SymbolicGradient op.
+    AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
+    results.emplace_back(AssociatedFunctionInfo::SymbolicGradient(op, attrs));
   } else {
     // Collect all function attrs for the node.
     for (auto& iter : node.attrs()) {
      if (iter.second.has_func()) {
        VLOG(2) << "Found function attr for node " << node.name() << ": "
                << iter.first << " = " << iter.second.func().name();
-        results.emplace_back(AssociatedFunctionInfo(
+        results.emplace_back(AssociatedFunctionInfo::FunctionAttr(
            iter.second.func().name(), iter.second.func().attr(), iter.first));
      }
    }
@@ -410,6 +411,21 @@ Status RewriteAssociatedFunction(
       graph->RemoveNode(node);
       break;
     }
+    case AssociatedFunctionInfo::kSymbolicGradient: {
+      NameAttrList func;
+      TF_RETURN_IF_ERROR(GetNodeAttr(
+          node->attrs(), FunctionLibraryDefinition::kFuncAttr, &func));
+      GradientDef gradient_def;
+      gradient_def.set_function_name(func.name());
+      gradient_def.set_gradient_func(rewritten_function_name);
+      string original_grad_func = fld->FindGradient(func.name());
+      if (original_grad_func.empty()) {
+        TF_RETURN_IF_ERROR(fld->AddGradientDef(gradient_def));
+      } else if (original_grad_func != rewritten_function_name) {
+        TF_RETURN_IF_ERROR(fld->ReplaceGradient(gradient_def));
+      }
+      break;
+    }
     case AssociatedFunctionInfo::kFunctionAttr: {
       // Change function attr to rewritten functions.
       NameAttrList func;
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h
index 6065d0bb9a..53eab8b63e 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util.h
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.h
@@ -65,21 +65,33 @@ uint32 GetXLARandomSeed();
 class AssociatedFunctionInfo {
  public:
   enum AssociatedFunctionType {
-    kFunctionCallNode = 0,
-    kFunctionAttr = 1,
+    kFunctionAttr = 0,
+    kFunctionCallNode = 1,
+    kSymbolicGradient = 2,
   };
 
-  // The node is a function call.
-  AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs)
-      : type_(kFunctionCallNode), func_name_(func_name), attrs_(attrs) {}
-
-  // The function is an attr of the node.
-  AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs,
-                         const string& attr_name)
-      : type_(kFunctionAttr),
-        func_name_(func_name),
-        attrs_(attrs),
-        attr_name_(attr_name) {}
+  // The function is an attr of the node.
+  static AssociatedFunctionInfo FunctionAttr(const string& func_name,
+                                             const AttrValueMap& attrs,
+                                             const string& attr_name) {
+    return AssociatedFunctionInfo(kFunctionAttr, func_name, attrs, attr_name);
+  }
+
+  // The node is a function call.
+  static AssociatedFunctionInfo FunctionCall(const string& func_name,
+                                             const AttrValueMap& attrs) {
+    // attr_name will not be used in this case.
+    return AssociatedFunctionInfo(kFunctionCallNode, func_name, attrs,
+                                  /*attr_name=*/"");
+  }
+
+  // The node is a SymbolicGradient op.
+  static AssociatedFunctionInfo SymbolicGradient(const string& func_name,
+                                                 const AttrValueMap& attrs) {
+    // attr_name will not be used in this case.
+    return AssociatedFunctionInfo(kSymbolicGradient, func_name, attrs,
+                                  /*attr_name=*/"");
+  }
 
   AssociatedFunctionType type() const { return type_; }
 
@@ -90,6 +102,13 @@ class AssociatedFunctionInfo {
   const AttrValueMap& attrs() const { return attrs_; }
 
  private:
+  AssociatedFunctionInfo(AssociatedFunctionType type, const string& func_name,
+                         const AttrValueMap& attrs, const string& attr_name)
+      : type_(type),
+        func_name_(func_name),
+        attrs_(attrs),
+        attr_name_(attr_name) {}
+
   // Available for all instances.
   AssociatedFunctionType type_;
   string func_name_;
@@ -105,14 +124,18 @@ bool HasAssociatedFunction(const NodeDef& node_def,
 
 // Gets functions associated with the node. Current cases:
 // 1. For function call node, its function name;
-// 2. For nodes like XlaWhile/XlaIf, all their function attributes.
+// 2. For SymbolicGradient op, returned func_name will be "SymbolicGradient",
+//    and returned attrs will be this node's attributes;
+// 3. For nodes like XlaWhile/XlaIf, all their function attributes.
 std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
     const Node& node, FunctionLibraryRuntime* flr);
 
 // Changes associated functions for the node. Current cases:
 // 1. For function call node, creates a new node with the new function name and
 // remove the old node;
-// 2. For nodes like XlaWhile/XlaIf, modify their function attributes.
+// 2. For SymbolicGradient op, add or replace GradientDef in
+//    FunctionLibraryDefinition;
+// 3. For nodes like XlaWhile/XlaIf, modify their function attributes.
 Status RewriteAssociatedFunction(
     Graph* graph, Node* node, FunctionLibraryDefinition* fld,
     const AssociatedFunctionInfo& associated_function,
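
Note: the graph-optimizer setup that this change moves ahead of the associated-function rewriting can be read on its own as the pattern below. This is a minimal sketch; the helper name OptimizeFunctionBody and its signature are illustrative only, while the calls themselves are exactly the ones used in the diff.

// Sketch of the optimization pass that now runs before function rewriting:
// copy the function body, then run constant folding and function inlining so
// ops like Shape/BroadcastGradientArgs become compile-time constants for the
// then/else bodies. SymbolicGradient nodes are excluded from constant folding
// because flr->device() is still null at this stage.
std::unique_ptr<Graph> OptimizeFunctionBody(const FunctionBody* body,
                                            FunctionLibraryDefinition* fld,
                                            FunctionLibraryRuntime* flr) {
  std::unique_ptr<Graph> optimized_graph(new Graph(fld));
  CopyGraph(*body->graph, optimized_graph.get());
  OptimizerOptions opts;
  opts.set_opt_level(OptimizerOptions::L0);
  opts.set_do_function_inlining(true);
  opts.set_do_constant_folding(true);
  GraphOptimizer optimizer(opts);
  auto cf_consider_fn = [](const Node* n) {
    // Constant folding of SymbolicGradient would require a non-null device.
    return n->type_string() != FunctionLibraryDefinition::kGradientOp;
  };
  optimizer.Optimize(flr, flr->env(), /*device=*/nullptr, &optimized_graph,
                     /*shape_map=*/nullptr, /*cse_consider_fn=*/nullptr,
                     cf_consider_fn);
  return optimized_graph;
}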
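
Note: the new kSymbolicGradient case in RewriteAssociatedFunction amounts to an add-or-replace registration of the rewritten gradient in the FunctionLibraryDefinition. A standalone sketch of that logic follows; the helper name RegisterRewrittenGradient and its parameters are hypothetical, but it uses only FunctionLibraryDefinition calls that appear in the diff.

// Illustrative helper (not part of the commit): register
// `rewritten_function_name` as the gradient of `forward_func_name`, mirroring
// the kSymbolicGradient case above.
Status RegisterRewrittenGradient(FunctionLibraryDefinition* fld,
                                 const string& forward_func_name,
                                 const string& rewritten_function_name) {
  GradientDef gradient_def;
  gradient_def.set_function_name(forward_func_name);
  gradient_def.set_gradient_func(rewritten_function_name);
  const string existing = fld->FindGradient(forward_func_name);
  if (existing.empty()) {
    // No gradient registered yet for the forward function: add a new mapping.
    return fld->AddGradientDef(gradient_def);
  }
  if (existing != rewritten_function_name) {
    // A stale mapping exists (e.g. the original, non-functionalized gradient):
    // point it at the functionalized gradient instead.
    return fld->ReplaceGradient(gradient_def);
  }
  return Status::OK();
}

Because ReplaceGradient is only invoked when the existing mapping points somewhere else, re-running the rewrite on an already-functionalized library leaves the GradientDef untouched.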