aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/control_flow/python/cond_v2.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/control_flow/python/cond_v2.py')
-rw-r--r--tensorflow/contrib/control_flow/python/cond_v2.py31
1 files changed, 27 insertions, 4 deletions
diff --git a/tensorflow/contrib/control_flow/python/cond_v2.py b/tensorflow/contrib/control_flow/python/cond_v2.py
index 9ffad9caa9..90371cd8d7 100644
--- a/tensorflow/contrib/control_flow/python/cond_v2.py
+++ b/tensorflow/contrib/control_flow/python/cond_v2.py
@@ -44,11 +44,34 @@ from tensorflow.python.util import compat
def cond_v2(pred, true_fn, false_fn, name="cond"):
"""Like tf.cond, except emits a single If op."""
+ if not name:
+ name = "cond"
+
with ops.name_scope(name) as scope:
- true_graph = function.func_graph_from_py_func(true_fn, [], [],
- name="%s_true" % scope)
- false_graph = function.func_graph_from_py_func(false_fn, [], [],
- name="%s_false" % scope)
+ # Identify if there is a caller device, & get the innermost if possible.
+ device_stack = ops.get_default_graph()._device_function_stack
+ caller_device = device_stack[-1] if device_stack else None
+
+ caller_colocation_stack = ops.get_default_graph()._colocation_stack
+ caller_container = ops.get_default_graph()._container
+ caller_collection_ref = ops.get_default_graph()._collections
+
+ func_name_prefix = scope.replace("/", "_")
+
+ true_graph = function.func_graph_from_py_func(
+ true_fn, [], [],
+ name="%strue" % func_name_prefix,
+ device=caller_device,
+ colocation_stack=caller_colocation_stack,
+ collections_ref=caller_collection_ref,
+ container=caller_container)
+ false_graph = function.func_graph_from_py_func(
+ false_fn, [], [],
+ name="%sfalse" % func_name_prefix,
+ device=caller_device,
+ colocation_stack=caller_colocation_stack,
+ collections_ref=caller_collection_ref,
+ container=caller_container)
_check_same_outputs(true_graph, false_graph)
# Add inputs to true_graph and false_graph to make them match. Note that