diff options
| author    | 2018-06-13 19:47:23 -0700 |
| --------- | ------------------------- |
| committer | 2018-06-13 19:50:14 -0700 |
commit | 007fc38f806c3405031dfef8076ca014bf0bcf7c (patch) | |
tree | a564f1145c85fa690aca37660a2f0987e93b57f6 | |
parent | dac4634dc8ad35115aabbc3ee054e08fea62fa50 (diff) |
Makes cond_v2 pass in device, container, colocation stacks, and collections to the branches.
This brings cond_v2 functionality closer to tf.cond.
PiperOrigin-RevId: 200495346
-rw-r--r-- | tensorflow/contrib/control_flow/python/cond_v2.py | 23 |
-rw-r--r-- | tensorflow/contrib/control_flow/python/cond_v2_test.py | 223 |
-rw-r--r-- | tensorflow/python/framework/function.py | 54 |
3 files changed, 296 insertions, 4 deletions
diff --git a/tensorflow/contrib/control_flow/python/cond_v2.py b/tensorflow/contrib/control_flow/python/cond_v2.py index b364e34511..90371cd8d7 100644 --- a/tensorflow/contrib/control_flow/python/cond_v2.py +++ b/tensorflow/contrib/control_flow/python/cond_v2.py @@ -48,13 +48,30 @@ def cond_v2(pred, true_fn, false_fn, name="cond"): name = "cond" with ops.name_scope(name) as scope: + # Identify if there is a caller device, & get the innermost if possible. + device_stack = ops.get_default_graph()._device_function_stack + caller_device = device_stack[-1] if device_stack else None + + caller_colocation_stack = ops.get_default_graph()._colocation_stack + caller_container = ops.get_default_graph()._container + caller_collection_ref = ops.get_default_graph()._collections + func_name_prefix = scope.replace("/", "_") true_graph = function.func_graph_from_py_func( - true_fn, [], [], name="%strue" % func_name_prefix) + true_fn, [], [], + name="%strue" % func_name_prefix, + device=caller_device, + colocation_stack=caller_colocation_stack, + collections_ref=caller_collection_ref, + container=caller_container) false_graph = function.func_graph_from_py_func( - false_fn, [], [], name="%sfalse" % func_name_prefix) - + false_fn, [], [], + name="%sfalse" % func_name_prefix, + device=caller_device, + colocation_stack=caller_colocation_stack, + collections_ref=caller_collection_ref, + container=caller_container) _check_same_outputs(true_graph, false_graph) # Add inputs to true_graph and false_graph to make them match. 
Note that diff --git a/tensorflow/contrib/control_flow/python/cond_v2_test.py b/tensorflow/contrib/control_flow/python/cond_v2_test.py index b7d4c16df4..94ed3e130b 100644 --- a/tensorflow/contrib/control_flow/python/cond_v2_test.py +++ b/tensorflow/contrib/control_flow/python/cond_v2_test.py @@ -25,10 +25,13 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import saver +from tensorflow.python.util import compat class NewCondTest(test.TestCase): @@ -198,5 +201,225 @@ class NewCondTest(test.TestCase): self.assertEqual(false_val, [0.0]) +class CondV2CollectionTest(test.TestCase): + + def testCollectionIntValueAccessInCond(self): + """Read values from graph collections inside of cond_v2.""" + with ops.Graph().as_default() as g: + with self.test_session(graph=g): + x = 2 + y = 5 + ops.add_to_collection("x", x) + ops.add_to_collection("y", y) + def fn(): + x_const = constant_op.constant(ops.get_collection("x")[0]) + y_const = constant_op.constant(ops.get_collection("y")[0]) + return math_ops.add(x_const, y_const) + + cnd = cond_v2.cond_v2(True, fn, fn) + self.assertEquals(cnd[0].eval(), 7) + + def testCollectionTensorValueAccessInCond(self): + """Read tensors from collections inside of cond_v2 & use them.""" + with ops.Graph().as_default() as g: + with self.test_session(graph=g): + x = constant_op.constant(2) + y = constant_op.constant(5) + ops.add_to_collection("x", x) + ops.add_to_collection("y", y) + + def fn(): + x_read = ops.get_collection("x")[0] + y_read = ops.get_collection("y")[0] + return math_ops.add(x_read, y_read) + + cnd = cond_v2.cond_v2(math_ops.less(x, y), fn, 
fn) + self.assertEquals(cnd[0].eval(), 7) + + def testCollectionIntValueWriteInCond(self): + """Make sure Int writes to collections work inside of cond_v2.""" + with ops.Graph().as_default() as g: + with self.test_session(graph=g): + x = constant_op.constant(2) + y = constant_op.constant(5) + def true_fn(): + z = math_ops.add(x, y) + ops.add_to_collection("z", 7) + return math_ops.mul(x, z) + + def false_fn(): + z = math_ops.add(x, y) + return math_ops.mul(x, z) + + cnd = cond_v2.cond_v2( + True, true_fn, + false_fn) + self.assertEquals(cnd[0].eval(), 14) + + read_z_collection = ops.get_collection("z") + self.assertEquals(read_z_collection, [7]) + + +class CondV2ContainerTest(test.TestCase): + + def testContainer(self): + """Set containers outside & inside of cond_v2. + + Make sure the containers are set correctly for both variable creation + (tested by variables.Variable) and for stateful ops (tested by FIFOQueue) + """ + with ops.Graph().as_default() as g: + with self.test_session(graph=g): + + v0 = variables.Variable([0]) + q0 = data_flow_ops.FIFOQueue(1, dtypes.float32) + + def container(node): + return node.op.get_attr("container") + + self.assertEqual(compat.as_bytes(""), container(v0)) + self.assertEqual(compat.as_bytes(""), container(q0.queue_ref)) + + def true_fn(): + # When this branch is created in cond below, + # the container should begin with 'l1' + v1 = variables.Variable([1]) + q1 = data_flow_ops.FIFOQueue(1, dtypes.float32) + + with ops.container("l2t"): + v2 = variables.Variable([2]) + q2 = data_flow_ops.FIFOQueue(1, dtypes.float32) + + v3 = variables.Variable([1]) + q3 = data_flow_ops.FIFOQueue(1, dtypes.float32) + + self.assertEqual(compat.as_bytes("l1"), container(v1)) + self.assertEqual(compat.as_bytes("l1"), container(q1.queue_ref)) + self.assertEqual(compat.as_bytes("l2t"), container(v2)) + self.assertEqual(compat.as_bytes("l2t"), container(q2.queue_ref)) + self.assertEqual(compat.as_bytes("l1"), container(v3)) + 
self.assertEqual(compat.as_bytes("l1"), container(q3.queue_ref)) + + return constant_op.constant(2.0) + + def false_fn(): + # When this branch is created in cond below, + # the container should begin with 'l1' + v1 = variables.Variable([1]) + q1 = data_flow_ops.FIFOQueue(1, dtypes.float32) + + with ops.container("l2f"): + v2 = variables.Variable([2]) + q2 = data_flow_ops.FIFOQueue(1, dtypes.float32) + + v3 = variables.Variable([1]) + q3 = data_flow_ops.FIFOQueue(1, dtypes.float32) + + self.assertEqual(compat.as_bytes("l1"), container(v1)) + self.assertEqual(compat.as_bytes("l1"), container(q1.queue_ref)) + self.assertEqual(compat.as_bytes("l2f"), container(v2)) + self.assertEqual(compat.as_bytes("l2f"), container(q2.queue_ref)) + self.assertEqual(compat.as_bytes("l1"), container(v3)) + self.assertEqual(compat.as_bytes("l1"), container(q3.queue_ref)) + + return constant_op.constant(6.0) + + with ops.container("l1"): + cnd_true = cond_v2.cond_v2(True, true_fn, false_fn) + self.assertEquals(cnd_true[0].eval(), 2) + + cnd_false = cond_v2.cond_v2(False, true_fn, false_fn) + self.assertEquals(cnd_false[0].eval(), 6) + + v4 = variables.Variable([3]) + q4 = data_flow_ops.FIFOQueue(1, dtypes.float32) + v5 = variables.Variable([4]) + q5 = data_flow_ops.FIFOQueue(1, dtypes.float32) + + self.assertEqual(compat.as_bytes("l1"), container(v4)) + self.assertEqual(compat.as_bytes("l1"), container(q4.queue_ref)) + self.assertEqual(compat.as_bytes(""), container(v5)) + self.assertEqual(compat.as_bytes(""), container(q5.queue_ref)) + + +class CondV2ColocationGroupAndDeviceTest(test.TestCase): + + def testColocateWithBeforeCond(self): + with ops.Graph().as_default() as g: + with self.test_session(graph=g): + + a = constant_op.constant([2.0], name="a") + b = constant_op.constant([2.0], name="b") + + def fn(): + c = constant_op.constant(3.0) + self.assertEqual([b"loc:@a"], c.op.colocation_groups()) + return c + + with ops.colocate_with(a.op): + self.assertEquals(cond_v2.cond_v2(True, fn, 
fn)[0].eval(), 3) + + def fn2(): + c = constant_op.constant(3.0) + self.assertEqual([b"loc:@a", b"loc:@b"], c.op.colocation_groups()) + return c + + with ops.colocate_with(a.op): + with ops.colocate_with(b.op): + self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3) + + def testColocateWithInAndOutOfCond(self): + with ops.Graph().as_default() as g: + with self.test_session(graph=g): + + a = constant_op.constant([2.0], name="a") + b = constant_op.constant([2.0], name="b") + + def fn2(): + with ops.colocate_with(b.op): + c = constant_op.constant(3.0) + self.assertEqual([b"loc:@a", b"loc:@b"], c.op.colocation_groups()) + return c + + with ops.colocate_with(a.op): + self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3) + + d = constant_op.constant([2.0], name="d") + self.assertEqual([b"loc:@a"], d.op.colocation_groups()) + + def testDeviceBeforeCond(self): + with ops.Graph().as_default() as g: + with self.test_session(graph=g): + def fn(): + c = constant_op.constant(3.0) + self.assertEqual("/device:CPU:0", c.op.device) + return c + + with ops.device("/device:CPU:0"): + self.assertEquals(cond_v2.cond_v2(True, fn, fn)[0].eval(), 3) + + def fn2(): + c = constant_op.constant(3.0) + self.assertEqual("/device:GPU:0", c.op.device) + return c + + with ops.device("/device:GPU:0"): + self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3) + + def testDeviceInAndOutOfCond(self): + with ops.Graph().as_default() as g: + with self.test_session(graph=g): + def fn2(): + with ops.device("/device:GPU:0"): + c = constant_op.constant(3.0) + self.assertEqual("/device:GPU:0", c.op.device) + return c + + with ops.device("/device:CPU:0"): + self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3) + + d = constant_op.constant(4.0) + self.assertEqual("/device:CPU:0", d.op.device) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index 82ecba310b..002a3d3be5 100644 --- 
a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -36,6 +36,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.util import compat +from tensorflow.python.util import tf_contextlib from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect @@ -650,6 +651,41 @@ class _FuncGraph(ops.Graph): # TODO(skyewm): is this needed? self.extra_vars = [] + # pylint: disable=g-doc-return-or-yield + + @tf_contextlib.contextmanager + def container(self, container_name): + """Returns a context manager that specifies the resource container to use. + + Overridden from @{tf.Graph} to update both the init_scope container + and the present inner container. This is necessary to make sure setting + containers applies correctly both to created variables and to stateful + ops. + + Args: + container_name: container name string. + + Returns: + A context manager for defining resource containers for stateful ops, + yields the container name. + """ + original_container = self._container + # pylint: disable=protected-access + with ops.init_scope(): + original_init_container = ops.get_default_graph()._container + try: + self._container = container_name + with ops.init_scope(): + ops.get_default_graph()._container = container_name + yield self._container + finally: + self._container = original_container + with ops.init_scope(): + ops.get_default_graph()._container = original_init_container + # pylint: enable=protected-access + + # pylint: enable=g-doc-return-or-yield + def getvar( self, getter, @@ -773,7 +809,9 @@ class _FuncGraph(ops.Graph): def func_graph_from_py_func(func, arg_names, arg_types, name=None, - capture_by_value=False, device=None): + capture_by_value=False, device=None, + colocation_stack=None, container=None, + collections_ref=None): """Returns a _FuncGraph generated from `func`. 
Args: @@ -786,6 +824,10 @@ def func_graph_from_py_func(func, arg_names, arg_types, name=None, capture_by_value: boolean. If True, captured values will be copied into the function body. device: device name or function. + colocation_stack: A colocation stack (list) the _FuncGraph should use. + container: A container name the _FuncGraph should start with. + collections_ref: A reference to a collections dict the _FuncGraph should + use internally. Returns: A _FuncGraph. @@ -796,7 +838,17 @@ def func_graph_from_py_func(func, arg_names, arg_types, name=None, if not name: name = _get_func_name(func) func_graph = _FuncGraph(name, capture_by_value) + with func_graph.as_default(), ops.device(device): + # pylint: disable=protected-access + if collections_ref is not None: + func_graph._collections = collections_ref + if container is not None: + func_graph._container = container + if colocation_stack is not None: + func_graph._colocation_stack = colocation_stack + # pylint: enable=protected-access + # Create placeholders for the function arguments. for (argname, argtype) in zip(arg_names, arg_types): argholder = array_ops.placeholder(argtype, name=argname) |