aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/function.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-18 13:36:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-18 13:39:30 -0700
commitab251a0ec66a3c8b88ca467e49bfc68d18a2a8e9 (patch)
treea7c5b15fb417b2f66d52a51a55de622dcc221c73 /tensorflow/core/common_runtime/function.cc
parent3d3196f34173e5c6e1f9297e2fcd4c316fe903fd (diff)
Enables `If` operator lowering in cond_v2 when XLA is disabled. Lowering allows cond_v2 to avoid some of the limitations of Functions, allowing users to specify devices & colocation inside of cond_v2 branches, and enabling non-strict evaluation & partial pruning of branches. This brings cond_v2 closer to feature parity with tf.cond.
However, we do not lower `If` in the XLA context because it is easier for XLA to apply its own optimizations when dealing with un-lowered `If` operators than with lowered switch/merge control flow. Also adds a toggleable flag in for InlineFunctionBody in function.cc that prevents the function caller device from overriding the devices of function body nodes. This is necessary for cond_v2 branches to support explicitly-specified devices. Adds several tests to make sure that: - lowering is usually enabled - lowering is disabled for XLA - node colocation inside of cond_v2 branches works - explicit device placement inside of cond_v2 branches works PiperOrigin-RevId: 201049850
Diffstat (limited to 'tensorflow/core/common_runtime/function.cc')
-rw-r--r--tensorflow/core/common_runtime/function.cc12
1 files changed, 8 insertions, 4 deletions
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 68d37ddbcd..1200dcc1fe 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -1188,11 +1188,13 @@ static bool ValidateInlining(const Node* node, const FunctionBody* fbody) {
return true;
}
-// Given a "caller" in "graph", which is a function call of a function
+// Given a "caller" in graph "g", which is a function call of a function
// to "fbody". Replaces the "caller" with fbody->graph and connects
-// edges properly.
+// edges properly. "override_device" specifies whether inlining should replace
+// explicitly specified devices inside fbody with the callee's device.
void InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
- Node* caller, const FunctionBody* fbody) {
+ Node* caller, const FunctionBody* fbody,
+ bool override_device) {
if (!ValidateInlining(caller, fbody)) {
LOG(WARNING) << "Inlining mismatch: " << caller->DebugString() << " vs. "
<< DebugString(fbody->graph);
@@ -1227,7 +1229,9 @@ void InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
for (Node* n : fbody->graph->op_nodes()) {
NodeDef ndef = n->def();
ndef.set_name(strings::StrCat(caller->name(), "/", ndef.name()));
- ndef.set_device(caller->def().device());
+ if (override_device || ndef.device().empty()) {
+ ndef.set_device(caller->def().device());
+ }
Node* clone = g->AddNode(ndef, &s);
TF_CHECK_OK(s);
node_map[n->id()] = clone;