author     Tong Shen <endlessroad@google.com>               2018-09-28 15:14:43 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-09-28 15:18:31 -0700
commit     2f559f2d5f75cf80183ae0d855110809404019f7 (patch)
tree       d8101da5ee6b7008fa48cee7e5000c3bb183e6ee /tensorflow/compiler/tf2xla
parent     dee0481c07ed952d01b12704c89e50869a383c68 (diff)
Handle noinline gradient function in control flow functionalization.
PiperOrigin-RevId: 215003704
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r--  tensorflow/compiler/tf2xla/functionalize_control_flow.cc  84
-rw-r--r--  tensorflow/compiler/tf2xla/tf2xla_util.cc                  30
-rw-r--r--  tensorflow/compiler/tf2xla/tf2xla_util.h                   51
3 files changed, 108 insertions(+), 57 deletions(-)
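
Background (illustrative, not part of the commit): a SymbolicGradient node (FunctionLibraryDefinition::kGradientOp) carries an "f" attr naming the function being differentiated. When that function is marked noinline, its gradient body is never inlined, so control flow functionalization has to visit it explicitly. A minimal sketch of such a node, using the hypothetical function name "MyNoinlineFn":

#include "tensorflow/core/framework/node_def.pb.h"

// What a SymbolicGradient call looks like: the "f" attr names the (noinline)
// function being differentiated. All names here are hypothetical.
tensorflow::NodeDef MakeSymbolicGradientNode() {
  tensorflow::NodeDef node;
  node.set_name("gradients/MyNoinlineFn_grad");  // hypothetical node name
  node.set_op("SymbolicGradient");  // FunctionLibraryDefinition::kGradientOp
  (*node.mutable_attr())["f"].mutable_func()->set_name("MyNoinlineFn");
  return node;
}
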
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
index 2d45507796..36c6f5d316 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
@@ -92,13 +92,51 @@ Status FunctionalizeControlFlowForFunction(
});
const FunctionBody* body = flr->GetFunctionBody(handle);
+ // Call graph optimizer. The most important optimization we need is constant
+ // folding, which will replace ops like Shape/BroadcastGradientArgs with
+ // constant shape input. Without this optimization, those ops might become
+ // dynamic input for then/else body function and XLA will complain that input
+ // is not compile time constant. We enable function inlining as well, because
+ // otherwise we won't be able to infer shape for any node depending on
+ // function call nodes.
+ if (VLOG_IS_ON(4)) {
+ dump_graph::DumpGraphToFile(
+ absl::StrCat("functionalize_control_flow_before_opt_", func_name),
+ *body->graph, fld);
+ }
+ // Optimizer accepts std::unique_ptr<Graph>* as input and might change
+ // underlying pointer, thus we create a new Graph and copy from body->graph.
+ std::unique_ptr<Graph> optimized_graph(new Graph(fld));
+ CopyGraph(*body->graph, optimized_graph.get());
+ OptimizerOptions opts;
+ opts.set_opt_level(OptimizerOptions::L0);
+ opts.set_do_function_inlining(true);
+ opts.set_do_constant_folding(true);
+ GraphOptimizer optimizer(opts);
+ auto cf_consider_fn = [](const Node* n) {
+ // Skip SymbolicGradient op when doing constant folding.
+ // Enabling SymbolicGradient op in constant folding requires
+ // flr->device() to be non-null, and here we have not constructed
+ // proper Device object yet (it will be constructed in XlaCompiler).
+ return n->type_string() != FunctionLibraryDefinition::kGradientOp;
+ };
+ optimizer.Optimize(flr, flr->env(),
+ /*device=*/nullptr, &optimized_graph,
+ /*shape_map=*/nullptr, /*cse_consider_fn=*/nullptr,
+ cf_consider_fn);
+ if (VLOG_IS_ON(4)) {
+ dump_graph::DumpGraphToFile(
+ absl::StrCat("functionalize_control_flow_after_opt_", func_name),
+ *optimized_graph, fld);
+ }
+
// If any node has associated functions, functionalize them first.
// Gather nodes with associated functions first, because rewriting those nodes
// might involve node deletion/addition. Avoid modifying nodes while iterating
// it.
std::vector<std::pair<Node*, std::vector<AssociatedFunctionInfo>>>
nodes_to_associated_functions;
- for (auto* n : body->graph->nodes()) {
+ for (auto* n : optimized_graph->nodes()) {
auto associated_functions = GetAssociatedFunctions(*n, flr);
if (!associated_functions.empty()) {
nodes_to_associated_functions.push_back({n, associated_functions});
@@ -118,7 +156,14 @@ Status FunctionalizeControlFlowForFunction(
// but still rewrite the node.
new_name = iter->second;
} else {
- new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_"));
+ if (associated_function.type() ==
+ AssociatedFunctionInfo::AssociatedFunctionType::kSymbolicGradient) {
+ // For SymbolicGradient, `name` is always "SymbolicGradient",
+ // which is not very informative. Use node name instead.
+ new_name = fld->UniqueFunctionName(absl::StrCat(n->name(), "_f15n_"));
+ } else {
+ new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_"));
+ }
TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
name, new_name, associated_function.attrs(), fld, flr,
canonicalized_name_to_new_name));
@@ -129,43 +174,10 @@ Status FunctionalizeControlFlowForFunction(
// That's fine because in that case, associated_functions will only have
// one member and the loop will only run once.
TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
- body->graph, n, fld, associated_function, new_name));
+ optimized_graph.get(), n, fld, associated_function, new_name));
}
}
- // Call graph optimizer. The most important optimization we need is constant
- // folding, which will replace ops like Shape/BroadcastGradientArgs with
- // constant shape input. Without this optimization, those ops might become
- // dynamic input for then/else body function and XLA will complain that input
- // is not compile time constant. We enable function inlining as well, because
- // otherwise we won't be able to infer shape for any node depending on
- // function call nodes.
- if (VLOG_IS_ON(4)) {
- dump_graph::DumpGraphToFile(
- absl::StrCat("functionalize_control_flow_before_opt_", func_name),
- *body->graph, fld);
- }
- // Optimizer accepts std::unique_ptr<Graph>* as input and might change
- // underlying pointer, thus we create a new Graph and copy from body->graph.
- std::unique_ptr<Graph> optimized_graph(new Graph(fld));
- CopyGraph(*body->graph, optimized_graph.get());
- OptimizerOptions opts;
- opts.set_opt_level(OptimizerOptions::L0);
- opts.set_do_function_inlining(true);
- opts.set_do_constant_folding(true);
- GraphOptimizer optimizer(opts);
- auto cf_consider_fn = [](const Node* n) {
- // Skip SymbolicGradient op when doing constant folding.
- // Enabling SymbolicGradient op in constant folding requires
- // flr->device() to be non-null, and here we have not constructed
- // proper Device object yet (it will be constructed in XlaCompiler).
- return n->type_string() != FunctionLibraryDefinition::kGradientOp;
- };
- optimizer.Optimize(flr, flr->env(),
- /*device=*/nullptr, &optimized_graph,
- /*shape_map=*/nullptr, /*cse_consider_fn=*/nullptr,
- cf_consider_fn);
-
// Functionalize the function body.
if (VLOG_IS_ON(4)) {
dump_graph::DumpGraphToFile(
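
The comment block moved to the top of FunctionalizeControlFlowForFunction above explains why constant folding and function inlining now run before functionalization, and why SymbolicGradient nodes are excluded from constant folding. A self-contained sketch of that optimizer setup, mirroring the calls in the hunk above and assuming an existing FunctionLibraryRuntime* flr and a copied body graph:

#include <memory>

#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/protobuf/config.pb.h"

// Run constant folding + function inlining on a function body graph, skipping
// SymbolicGradient nodes in constant folding because no Device is available
// at this point.
void OptimizeBodyGraph(tensorflow::FunctionLibraryRuntime* flr,
                       std::unique_ptr<tensorflow::Graph>* graph) {
  tensorflow::OptimizerOptions opts;
  opts.set_opt_level(tensorflow::OptimizerOptions::L0);
  opts.set_do_function_inlining(true);
  opts.set_do_constant_folding(true);
  tensorflow::GraphOptimizer optimizer(opts);
  auto cf_consider_fn = [](const tensorflow::Node* n) {
    return n->type_string() !=
           tensorflow::FunctionLibraryDefinition::kGradientOp;
  };
  optimizer.Optimize(flr, flr->env(), /*device=*/nullptr, graph,
                     /*shape_map=*/nullptr, /*cse_consider_fn=*/nullptr,
                     cf_consider_fn);
}
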
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc
index d6f42bac86..01dd3ba10f 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc
@@ -336,9 +336,9 @@ bool HasAssociatedFunction(const NodeDef& node_def,
}
if (node_def.op() == FunctionLibraryDefinition::kGradientOp) {
- // Skip gradient op. Gradient op has "f" attr, which is set to the function
- // we are getting gradient for. That function is not associated with the op.
- return false;
+ // Gradient op has "f" attr, which is set to the function we are getting
+ // gradient for. We need to functionalize the gradient function.
+ return true;
}
for (const auto& iter : node_def.attr()) {
@@ -357,17 +357,18 @@ std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
if (flr->GetFunctionLibraryDefinition()->Contains(op)) {
// This is a function call node.
AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
- results.emplace_back(AssociatedFunctionInfo(op, attrs));
+ results.emplace_back(AssociatedFunctionInfo::FunctionCall(op, attrs));
} else if (node.type_string() == FunctionLibraryDefinition::kGradientOp) {
- // Skip gradient op. Gradient op has "f" attr, which is set to the function
- // we are getting gradient for. That function is not associated with the op.
+ // This is a SymbolicGradient op.
+ AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
+ results.emplace_back(AssociatedFunctionInfo::SymbolicGradient(op, attrs));
} else {
// Collect all function attrs for the node.
for (auto& iter : node.attrs()) {
if (iter.second.has_func()) {
VLOG(2) << "Found function attr for node " << node.name() << ": "
<< iter.first << " = " << iter.second.func().name();
- results.emplace_back(AssociatedFunctionInfo(
+ results.emplace_back(AssociatedFunctionInfo::FunctionAttr(
iter.second.func().name(), iter.second.func().attr(), iter.first));
}
}
@@ -410,6 +411,21 @@ Status RewriteAssociatedFunction(
graph->RemoveNode(node);
break;
}
+ case AssociatedFunctionInfo::kSymbolicGradient: {
+ NameAttrList func;
+ TF_RETURN_IF_ERROR(GetNodeAttr(
+ node->attrs(), FunctionLibraryDefinition::kFuncAttr, &func));
+ GradientDef gradient_def;
+ gradient_def.set_function_name(func.name());
+ gradient_def.set_gradient_func(rewritten_function_name);
+ string original_grad_func = fld->FindGradient(func.name());
+ if (original_grad_func.empty()) {
+ TF_RETURN_IF_ERROR(fld->AddGradientDef(gradient_def));
+ } else if (original_grad_func != rewritten_function_name) {
+ TF_RETURN_IF_ERROR(fld->ReplaceGradient(gradient_def));
+ }
+ break;
+ }
case AssociatedFunctionInfo::kFunctionAttr: {
// Change function attr to rewritten functions.
NameAttrList func;
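
The new kSymbolicGradient case above records the functionalized gradient through the function library's gradient mapping rather than by rewriting the node itself. A minimal sketch of that registration pattern, assuming a FunctionLibraryDefinition* fld and hypothetical function names:

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/errors.h"

// Map a function to its functionalized gradient in the function library.
// Both function names here are hypothetical.
tensorflow::Status RegisterFunctionalizedGradient(
    tensorflow::FunctionLibraryDefinition* fld) {
  tensorflow::GradientDef gradient_def;
  gradient_def.set_function_name("MyNoinlineFn");
  gradient_def.set_gradient_func("MyNoinlineFn_grad_f15n");
  const std::string existing = fld->FindGradient("MyNoinlineFn");
  if (existing.empty()) {
    // No gradient registered yet for this function: add one.
    TF_RETURN_IF_ERROR(fld->AddGradientDef(gradient_def));
  } else if (existing != gradient_def.gradient_func()) {
    // A different gradient is registered: point it at the new function.
    TF_RETURN_IF_ERROR(fld->ReplaceGradient(gradient_def));
  }
  return tensorflow::Status::OK();
}
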
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h
index 6065d0bb9a..53eab8b63e 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util.h
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.h
@@ -65,21 +65,33 @@ uint32 GetXLARandomSeed();
class AssociatedFunctionInfo {
public:
enum AssociatedFunctionType {
- kFunctionCallNode = 0,
- kFunctionAttr = 1,
+ kFunctionAttr = 0,
+ kFunctionCallNode = 1,
+ kSymbolicGradient = 2,
};
- // The node is a function call.
- AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs)
- : type_(kFunctionCallNode), func_name_(func_name), attrs_(attrs) {}
-
// The function is an attr of the node.
- AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs,
- const string& attr_name)
- : type_(kFunctionAttr),
- func_name_(func_name),
- attrs_(attrs),
- attr_name_(attr_name) {}
+ static AssociatedFunctionInfo FunctionAttr(const string& func_name,
+ const AttrValueMap& attrs,
+ const string& attr_name) {
+ return AssociatedFunctionInfo(kFunctionAttr, func_name, attrs, attr_name);
+ }
+
+ // The node is a function call.
+ static AssociatedFunctionInfo FunctionCall(const string& func_name,
+ const AttrValueMap& attrs) {
+ // attr_name will not be used in this case.
+ return AssociatedFunctionInfo(kFunctionCallNode, func_name, attrs,
+ /*attr_name=*/"");
+ }
+
+ // The node is a SymbolicGradient op.
+ static AssociatedFunctionInfo SymbolicGradient(const string& func_name,
+ const AttrValueMap& attrs) {
+ // attr_name will not be used in this case.
+ return AssociatedFunctionInfo(kSymbolicGradient, func_name, attrs,
+ /*attr_name=*/"");
+ }
AssociatedFunctionType type() const { return type_; }
@@ -90,6 +102,13 @@ class AssociatedFunctionInfo {
const AttrValueMap& attrs() const { return attrs_; }
private:
+ AssociatedFunctionInfo(AssociatedFunctionType type, const string& func_name,
+ const AttrValueMap& attrs, const string& attr_name)
+ : type_(type),
+ func_name_(func_name),
+ attrs_(attrs),
+ attr_name_(attr_name) {}
+
// Available for all instances.
AssociatedFunctionType type_;
string func_name_;
@@ -105,14 +124,18 @@ bool HasAssociatedFunction(const NodeDef& node_def,
// Gets functions associated with the node. Current cases:
// 1. For function call node, its function name;
-// 2. For nodes like XlaWhile/XlaIf, all their function attributes.
+// 2. For SymbolicGradient op, returned func_name will be "SymbolicGradient",
+// and returned attrs will be this node's attributes;
+// 3. For nodes like XlaWhile/XlaIf, all their function attributes.
std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
const Node& node, FunctionLibraryRuntime* flr);
// Changes associated functions for the node. Current cases:
// 1. For function call node, creates a new node with the new function name and
// remove the old node;
-// 2. For nodes like XlaWhile/XlaIf, modify their function attributes.
+// 2. For SymbolicGradient op, add or replace GradientDef in
+// FunctionLibraryDefinition;
+// 3. For nodes like XlaWhile/XlaIf, modify their function attributes.
Status RewriteAssociatedFunction(
Graph* graph, Node* node, FunctionLibraryDefinition* fld,
const AssociatedFunctionInfo& associated_function,
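
Taken together, the factory methods above give GetAssociatedFunctions three cases for RewriteAssociatedFunction to dispatch on. A small usage sketch with hypothetical names (not from this commit):

#include "tensorflow/compiler/tf2xla/tf2xla_util.h"

// The three kinds of associated functions a node can carry. The names
// "MyFn", "MyThenBranch", and "then_branch" are hypothetical.
void ClassifyExamples(const tensorflow::AttrValueMap& attrs) {
  // 1. A plain function call node.
  auto call = tensorflow::AssociatedFunctionInfo::FunctionCall("MyFn", attrs);

  // 2. A SymbolicGradient node; rewriting it adds or replaces a GradientDef
  //    in the FunctionLibraryDefinition instead of editing the node itself.
  auto grad = tensorflow::AssociatedFunctionInfo::SymbolicGradient(
      "SymbolicGradient", attrs);

  // 3. A function-valued attribute, e.g. a branch function of XlaIf.
  auto branch = tensorflow::AssociatedFunctionInfo::FunctionAttr(
      "MyThenBranch", attrs, /*attr_name=*/"then_branch");

  // type() tells RewriteAssociatedFunction which of the three rewrites to do.
  if (grad.type() == tensorflow::AssociatedFunctionInfo::kSymbolicGradient) {
    // ... register the functionalized gradient, as in tf2xla_util.cc above.
  }
}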