diff options
Diffstat (limited to 'tensorflow/contrib/control_flow/python/cond_v2.py')
-rw-r--r-- | tensorflow/contrib/control_flow/python/cond_v2.py | 31 |
1 files changed, 27 insertions, 4 deletions
diff --git a/tensorflow/contrib/control_flow/python/cond_v2.py b/tensorflow/contrib/control_flow/python/cond_v2.py index 9ffad9caa9..90371cd8d7 100644 --- a/tensorflow/contrib/control_flow/python/cond_v2.py +++ b/tensorflow/contrib/control_flow/python/cond_v2.py @@ -44,11 +44,34 @@ from tensorflow.python.util import compat def cond_v2(pred, true_fn, false_fn, name="cond"): """Like tf.cond, except emits a single If op.""" + if not name: + name = "cond" + with ops.name_scope(name) as scope: - true_graph = function.func_graph_from_py_func(true_fn, [], [], - name="%s_true" % scope) - false_graph = function.func_graph_from_py_func(false_fn, [], [], - name="%s_false" % scope) + # Identify if there is a caller device, & get the innermost if possible. + device_stack = ops.get_default_graph()._device_function_stack + caller_device = device_stack[-1] if device_stack else None + + caller_colocation_stack = ops.get_default_graph()._colocation_stack + caller_container = ops.get_default_graph()._container + caller_collection_ref = ops.get_default_graph()._collections + + func_name_prefix = scope.replace("/", "_") + + true_graph = function.func_graph_from_py_func( + true_fn, [], [], + name="%strue" % func_name_prefix, + device=caller_device, + colocation_stack=caller_colocation_stack, + collections_ref=caller_collection_ref, + container=caller_container) + false_graph = function.func_graph_from_py_func( + false_fn, [], [], + name="%sfalse" % func_name_prefix, + device=caller_device, + colocation_stack=caller_colocation_stack, + collections_ref=caller_collection_ref, + container=caller_container) _check_same_outputs(true_graph, false_graph) # Add inputs to true_graph and false_graph to make them match. Note that |