diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-18 13:36:36 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-18 13:39:30 -0700 |
commit | ab251a0ec66a3c8b88ca467e49bfc68d18a2a8e9 (patch) | |
tree | a7c5b15fb417b2f66d52a51a55de622dcc221c73 /tensorflow/core/common_runtime/function.cc | |
parent | 3d3196f34173e5c6e1f9297e2fcd4c316fe903fd (diff) |
Enables `If` operator lowering in cond_v2 when XLA is disabled. Lowering allows cond_v2 to avoid some of the limitations of Functions, allowing users to specify devices & colocation inside of cond_v2 branches, and enabling non-strict evaluation & partial pruning of branches. This brings cond_v2 closer to feature parity with tf.cond.
However, we do not lower `If` in the XLA context because it is easier for XLA to apply its own optimizations when dealing with un-lowered `If` operators than with lowered switch/merge control flow.
Also adds a toggleable flag in for InlineFunctionBody in function.cc that prevents the function caller device from overriding the devices of function body nodes. This is necessary for cond_v2 branches to support explicitly-specified devices.
Adds several tests to make sure that:
- lowering is usually enabled
- lowering is disabled for XLA
- node colocation inside of cond_v2 branches works
- explicit device placement inside of cond_v2 branches works
PiperOrigin-RevId: 201049850
Diffstat (limited to 'tensorflow/core/common_runtime/function.cc')
-rw-r--r-- | tensorflow/core/common_runtime/function.cc | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 68d37ddbcd..1200dcc1fe 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -1188,11 +1188,13 @@ static bool ValidateInlining(const Node* node, const FunctionBody* fbody) { return true; } -// Given a "caller" in "graph", which is a function call of a function +// Given a "caller" in graph "g", which is a function call of a function // to "fbody". Replaces the "caller" with fbody->graph and connects -// edges properly. +// edges properly. "override_device" specifies whether inlining should replace +// explicitly specified devices inside fbody with the callee's device. void InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g, - Node* caller, const FunctionBody* fbody) { + Node* caller, const FunctionBody* fbody, + bool override_device) { if (!ValidateInlining(caller, fbody)) { LOG(WARNING) << "Inlining mismatch: " << caller->DebugString() << " vs. " << DebugString(fbody->graph); @@ -1227,7 +1229,9 @@ void InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g, for (Node* n : fbody->graph->op_nodes()) { NodeDef ndef = n->def(); ndef.set_name(strings::StrCat(caller->name(), "/", ndef.name())); - ndef.set_device(caller->def().device()); + if (override_device || ndef.device().empty()) { + ndef.set_device(caller->def().device()); + } Node* clone = g->AddNode(ndef, &s); TF_CHECK_OK(s); node_map[n->id()] = clone; |