diff options
author | Tong Shen <endlessroad@google.com> | 2018-09-26 18:03:34 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-26 18:13:16 -0700 |
commit | 0c573ae93b3013b91d8d2493a1daed56d11ccc98 (patch) | |
tree | 5289b1b2960319c2c9e63c1ee522914c9038789d /tensorflow/compiler | |
parent | 5d61748f4e9998c9d2017bd01864b8fcb6d2127a (diff) |
Skip SymbolicGradientOp when doing constant folding in control flow functionalization.
If we want to evaluate SymbolicGradient op in constant folding, we need to construct Device object and attach it to FunctionLibraryRuntime. In graph rewriting pass, we do not have Device object created yet; it will only be created in XlaCompiler.
PiperOrigin-RevId: 214702943
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r-- | tensorflow/compiler/tf2xla/functionalize_control_flow.cc | 16 |
1 files changed, 13 insertions, 3 deletions
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 98b333a467..2d45507796 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -109,7 +109,8 @@ Status FunctionalizeControlFlowForFunction( auto associated_functions = iter.second; for (auto& associated_function : associated_functions) { string name = associated_function.func_name(); - string canonicalized_name = Canonicalize(name, AttrSlice(&attrs)); + string canonicalized_name = + Canonicalize(name, AttrSlice(&associated_function.attrs())); auto iter = canonicalized_name_to_new_name->find(canonicalized_name); string new_name; if (iter != canonicalized_name_to_new_name->end()) { @@ -119,7 +120,8 @@ Status FunctionalizeControlFlowForFunction( } else { new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_")); TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( - name, new_name, attrs, fld, flr, canonicalized_name_to_new_name)); + name, new_name, associated_function.attrs(), fld, flr, + canonicalized_name_to_new_name)); (*canonicalized_name_to_new_name)[canonicalized_name] = new_name; } // Notice that if "n" is a function call, RewriteAssociatedFunction() will @@ -152,9 +154,17 @@ Status FunctionalizeControlFlowForFunction( opts.set_do_function_inlining(true); opts.set_do_constant_folding(true); GraphOptimizer optimizer(opts); + auto cf_consider_fn = [](const Node* n) { + // Skip SymbolicGradient op when doing constant folding. + // Enabling SymbolicGradient op in constant folding requires + // flr->device() to be non-null, and here we have not constructed + // proper Device object yet (it will be constructed in XlaCompiler). + return n->type_string() != FunctionLibraryDefinition::kGradientOp; + }; optimizer.Optimize(flr, flr->env(), /*device=*/nullptr, &optimized_graph, - /*shape_map=*/nullptr); + /*shape_map=*/nullptr, /*cse_consider_fn=*/nullptr, + cf_consider_fn); // Functionalize the function body. if (VLOG_IS_ON(4)) { |