about summary refs log tree commit diff homepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
author: Tong Shen <endlessroad@google.com> 2018-09-26 18:03:34 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-09-26 18:13:16 -0700
commit0c573ae93b3013b91d8d2493a1daed56d11ccc98 (patch)
tree5289b1b2960319c2c9e63c1ee522914c9038789d /tensorflow/compiler
parent5d61748f4e9998c9d2017bd01864b8fcb6d2127a (diff)
Skip SymbolicGradientOp when doing constant folding in control flow functionalization.
If we want to evaluate SymbolicGradient op in constant folding, we need to construct Device object and attach it to FunctionLibraryRuntime. In graph rewriting pass, we do not have Device object created yet; it will only be created in XlaCompiler.

PiperOrigin-RevId: 214702943
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r-- tensorflow/compiler/tf2xla/functionalize_control_flow.cc | 16
1 file changed, 13 insertions(+), 3 deletions(-)
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
index 98b333a467..2d45507796 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
@@ -109,7 +109,8 @@ Status FunctionalizeControlFlowForFunction(
auto associated_functions = iter.second;
for (auto& associated_function : associated_functions) {
string name = associated_function.func_name();
- string canonicalized_name = Canonicalize(name, AttrSlice(&attrs));
+ string canonicalized_name =
+ Canonicalize(name, AttrSlice(&associated_function.attrs()));
auto iter = canonicalized_name_to_new_name->find(canonicalized_name);
string new_name;
if (iter != canonicalized_name_to_new_name->end()) {
@@ -119,7 +120,8 @@ Status FunctionalizeControlFlowForFunction(
} else {
new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_"));
TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
- name, new_name, attrs, fld, flr, canonicalized_name_to_new_name));
+ name, new_name, associated_function.attrs(), fld, flr,
+ canonicalized_name_to_new_name));
(*canonicalized_name_to_new_name)[canonicalized_name] = new_name;
}
// Notice that if "n" is a function call, RewriteAssociatedFunction() will
@@ -152,9 +154,17 @@ Status FunctionalizeControlFlowForFunction(
opts.set_do_function_inlining(true);
opts.set_do_constant_folding(true);
GraphOptimizer optimizer(opts);
+ auto cf_consider_fn = [](const Node* n) {
+ // Skip SymbolicGradient op when doing constant folding.
+ // Enabling SymbolicGradient op in constant folding requires
+ // flr->device() to be non-null, and here we have not constructed
+ // proper Device object yet (it will be constructed in XlaCompiler).
+ return n->type_string() != FunctionLibraryDefinition::kGradientOp;
+ };
optimizer.Optimize(flr, flr->env(),
/*device=*/nullptr, &optimized_graph,
- /*shape_map=*/nullptr);
+ /*shape_map=*/nullptr, /*cse_consider_fn=*/nullptr,
+ cf_consider_fn);
// Functionalize the function body.
if (VLOG_IS_ON(4)) {