aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-13 19:47:23 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-13 19:50:14 -0700
commit007fc38f806c3405031dfef8076ca014bf0bcf7c (patch)
treea564f1145c85fa690aca37660a2f0987e93b57f6
parentdac4634dc8ad35115aabbc3ee054e08fea62fa50 (diff)
Makes cond_v2 pass in device, container, colocation stacks, and collections to the branches.
This brings cond_v2 functionality closer to tf.cond. PiperOrigin-RevId: 200495346
-rw-r--r--tensorflow/contrib/control_flow/python/cond_v2.py23
-rw-r--r--tensorflow/contrib/control_flow/python/cond_v2_test.py223
-rw-r--r--tensorflow/python/framework/function.py54
3 files changed, 296 insertions, 4 deletions
diff --git a/tensorflow/contrib/control_flow/python/cond_v2.py b/tensorflow/contrib/control_flow/python/cond_v2.py
index b364e34511..90371cd8d7 100644
--- a/tensorflow/contrib/control_flow/python/cond_v2.py
+++ b/tensorflow/contrib/control_flow/python/cond_v2.py
@@ -48,13 +48,30 @@ def cond_v2(pred, true_fn, false_fn, name="cond"):
name = "cond"
with ops.name_scope(name) as scope:
+ # Identify if there is a caller device, & get the innermost if possible.
+ device_stack = ops.get_default_graph()._device_function_stack
+ caller_device = device_stack[-1] if device_stack else None
+
+ caller_colocation_stack = ops.get_default_graph()._colocation_stack
+ caller_container = ops.get_default_graph()._container
+ caller_collection_ref = ops.get_default_graph()._collections
+
func_name_prefix = scope.replace("/", "_")
true_graph = function.func_graph_from_py_func(
- true_fn, [], [], name="%strue" % func_name_prefix)
+ true_fn, [], [],
+ name="%strue" % func_name_prefix,
+ device=caller_device,
+ colocation_stack=caller_colocation_stack,
+ collections_ref=caller_collection_ref,
+ container=caller_container)
false_graph = function.func_graph_from_py_func(
- false_fn, [], [], name="%sfalse" % func_name_prefix)
-
+ false_fn, [], [],
+ name="%sfalse" % func_name_prefix,
+ device=caller_device,
+ colocation_stack=caller_colocation_stack,
+ collections_ref=caller_collection_ref,
+ container=caller_container)
_check_same_outputs(true_graph, false_graph)
# Add inputs to true_graph and false_graph to make them match. Note that
diff --git a/tensorflow/contrib/control_flow/python/cond_v2_test.py b/tensorflow/contrib/control_flow/python/cond_v2_test.py
index b7d4c16df4..94ed3e130b 100644
--- a/tensorflow/contrib/control_flow/python/cond_v2_test.py
+++ b/tensorflow/contrib/control_flow/python/cond_v2_test.py
@@ -25,10 +25,13 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import saver
+from tensorflow.python.util import compat
class NewCondTest(test.TestCase):
@@ -198,5 +201,225 @@ class NewCondTest(test.TestCase):
self.assertEqual(false_val, [0.0])
+class CondV2CollectionTest(test.TestCase):
+
+ def testCollectionIntValueAccessInCond(self):
+ """Read values from graph collections inside of cond_v2."""
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g):
+ x = 2
+ y = 5
+ ops.add_to_collection("x", x)
+ ops.add_to_collection("y", y)
+ def fn():
+ x_const = constant_op.constant(ops.get_collection("x")[0])
+ y_const = constant_op.constant(ops.get_collection("y")[0])
+ return math_ops.add(x_const, y_const)
+
+ cnd = cond_v2.cond_v2(True, fn, fn)
+ self.assertEquals(cnd[0].eval(), 7)
+
+ def testCollectionTensorValueAccessInCond(self):
+ """Read tensors from collections inside of cond_v2 & use them."""
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g):
+ x = constant_op.constant(2)
+ y = constant_op.constant(5)
+ ops.add_to_collection("x", x)
+ ops.add_to_collection("y", y)
+
+ def fn():
+ x_read = ops.get_collection("x")[0]
+ y_read = ops.get_collection("y")[0]
+ return math_ops.add(x_read, y_read)
+
+ cnd = cond_v2.cond_v2(math_ops.less(x, y), fn, fn)
+ self.assertEquals(cnd[0].eval(), 7)
+
+ def testCollectionIntValueWriteInCond(self):
+ """Make sure Int writes to collections work inside of cond_v2."""
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g):
+ x = constant_op.constant(2)
+ y = constant_op.constant(5)
+ def true_fn():
+ z = math_ops.add(x, y)
+ ops.add_to_collection("z", 7)
+ return math_ops.mul(x, z)
+
+ def false_fn():
+ z = math_ops.add(x, y)
+ return math_ops.mul(x, z)
+
+ cnd = cond_v2.cond_v2(
+ True, true_fn,
+ false_fn)
+ self.assertEquals(cnd[0].eval(), 14)
+
+ read_z_collection = ops.get_collection("z")
+ self.assertEquals(read_z_collection, [7])
+
+
+class CondV2ContainerTest(test.TestCase):
+
+ def testContainer(self):
+ """Set containers outside & inside of cond_v2.
+
+ Make sure the containers are set correctly for both variable creation
+ (tested by variables.Variable) and for stateful ops (tested by FIFOQueue)
+ """
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g):
+
+ v0 = variables.Variable([0])
+ q0 = data_flow_ops.FIFOQueue(1, dtypes.float32)
+
+ def container(node):
+ return node.op.get_attr("container")
+
+ self.assertEqual(compat.as_bytes(""), container(v0))
+ self.assertEqual(compat.as_bytes(""), container(q0.queue_ref))
+
+ def true_fn():
+ # When this branch is created in cond below,
+ # the container should begin with 'l1'
+ v1 = variables.Variable([1])
+ q1 = data_flow_ops.FIFOQueue(1, dtypes.float32)
+
+ with ops.container("l2t"):
+ v2 = variables.Variable([2])
+ q2 = data_flow_ops.FIFOQueue(1, dtypes.float32)
+
+ v3 = variables.Variable([1])
+ q3 = data_flow_ops.FIFOQueue(1, dtypes.float32)
+
+ self.assertEqual(compat.as_bytes("l1"), container(v1))
+ self.assertEqual(compat.as_bytes("l1"), container(q1.queue_ref))
+ self.assertEqual(compat.as_bytes("l2t"), container(v2))
+ self.assertEqual(compat.as_bytes("l2t"), container(q2.queue_ref))
+ self.assertEqual(compat.as_bytes("l1"), container(v3))
+ self.assertEqual(compat.as_bytes("l1"), container(q3.queue_ref))
+
+ return constant_op.constant(2.0)
+
+ def false_fn():
+ # When this branch is created in cond below,
+ # the container should begin with 'l1'
+ v1 = variables.Variable([1])
+ q1 = data_flow_ops.FIFOQueue(1, dtypes.float32)
+
+ with ops.container("l2f"):
+ v2 = variables.Variable([2])
+ q2 = data_flow_ops.FIFOQueue(1, dtypes.float32)
+
+ v3 = variables.Variable([1])
+ q3 = data_flow_ops.FIFOQueue(1, dtypes.float32)
+
+ self.assertEqual(compat.as_bytes("l1"), container(v1))
+ self.assertEqual(compat.as_bytes("l1"), container(q1.queue_ref))
+ self.assertEqual(compat.as_bytes("l2f"), container(v2))
+ self.assertEqual(compat.as_bytes("l2f"), container(q2.queue_ref))
+ self.assertEqual(compat.as_bytes("l1"), container(v3))
+ self.assertEqual(compat.as_bytes("l1"), container(q3.queue_ref))
+
+ return constant_op.constant(6.0)
+
+ with ops.container("l1"):
+ cnd_true = cond_v2.cond_v2(True, true_fn, false_fn)
+ self.assertEquals(cnd_true[0].eval(), 2)
+
+ cnd_false = cond_v2.cond_v2(False, true_fn, false_fn)
+ self.assertEquals(cnd_false[0].eval(), 6)
+
+ v4 = variables.Variable([3])
+ q4 = data_flow_ops.FIFOQueue(1, dtypes.float32)
+ v5 = variables.Variable([4])
+ q5 = data_flow_ops.FIFOQueue(1, dtypes.float32)
+
+ self.assertEqual(compat.as_bytes("l1"), container(v4))
+ self.assertEqual(compat.as_bytes("l1"), container(q4.queue_ref))
+ self.assertEqual(compat.as_bytes(""), container(v5))
+ self.assertEqual(compat.as_bytes(""), container(q5.queue_ref))
+
+
+class CondV2ColocationGroupAndDeviceTest(test.TestCase):
+
+ def testColocateWithBeforeCond(self):
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g):
+
+ a = constant_op.constant([2.0], name="a")
+ b = constant_op.constant([2.0], name="b")
+
+ def fn():
+ c = constant_op.constant(3.0)
+ self.assertEqual([b"loc:@a"], c.op.colocation_groups())
+ return c
+
+ with ops.colocate_with(a.op):
+ self.assertEquals(cond_v2.cond_v2(True, fn, fn)[0].eval(), 3)
+
+ def fn2():
+ c = constant_op.constant(3.0)
+ self.assertEqual([b"loc:@a", b"loc:@b"], c.op.colocation_groups())
+ return c
+
+ with ops.colocate_with(a.op):
+ with ops.colocate_with(b.op):
+ self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)
+
+ def testColocateWithInAndOutOfCond(self):
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g):
+
+ a = constant_op.constant([2.0], name="a")
+ b = constant_op.constant([2.0], name="b")
+
+ def fn2():
+ with ops.colocate_with(b.op):
+ c = constant_op.constant(3.0)
+ self.assertEqual([b"loc:@a", b"loc:@b"], c.op.colocation_groups())
+ return c
+
+ with ops.colocate_with(a.op):
+ self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)
+
+ d = constant_op.constant([2.0], name="d")
+ self.assertEqual([b"loc:@a"], d.op.colocation_groups())
+
+ def testDeviceBeforeCond(self):
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g):
+ def fn():
+ c = constant_op.constant(3.0)
+ self.assertEqual("/device:CPU:0", c.op.device)
+ return c
+
+ with ops.device("/device:CPU:0"):
+ self.assertEquals(cond_v2.cond_v2(True, fn, fn)[0].eval(), 3)
+
+ def fn2():
+ c = constant_op.constant(3.0)
+ self.assertEqual("/device:GPU:0", c.op.device)
+ return c
+
+ with ops.device("/device:GPU:0"):
+ self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)
+
+ def testDeviceInAndOutOfCond(self):
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g):
+ def fn2():
+ with ops.device("/device:GPU:0"):
+ c = constant_op.constant(3.0)
+ self.assertEqual("/device:GPU:0", c.op.device)
+ return c
+
+ with ops.device("/device:CPU:0"):
+ self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)
+
+ d = constant_op.constant(4.0)
+ self.assertEqual("/device:CPU:0", d.op.device)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index 82ecba310b..002a3d3be5 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -36,6 +36,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import compat
+from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
@@ -650,6 +651,41 @@ class _FuncGraph(ops.Graph):
# TODO(skyewm): is this needed?
self.extra_vars = []
+ # pylint: disable=g-doc-return-or-yield
+
+ @tf_contextlib.contextmanager
+ def container(self, container_name):
+ """Returns a context manager that specifies the resource container to use.
+
+ Overridden from @{tf.Graph} to update both the init_scope container
+ and the present inner container. This is necessary to make sure setting
+ containers applies correctly both to created variables and to stateful
+ ops.
+
+ Args:
+ container_name: container name string.
+
+ Returns:
+ A context manager for defining resource containers for stateful ops,
+ yields the container name.
+ """
+ original_container = self._container
+ # pylint: disable=protected-access
+ with ops.init_scope():
+ original_init_container = ops.get_default_graph()._container
+ try:
+ self._container = container_name
+ with ops.init_scope():
+ ops.get_default_graph()._container = container_name
+ yield self._container
+ finally:
+ self._container = original_container
+ with ops.init_scope():
+ ops.get_default_graph()._container = original_init_container
+ # pylint: enable=protected-access
+
+ # pylint: enable=g-doc-return-or-yield
+
def getvar(
self,
getter,
@@ -773,7 +809,9 @@ class _FuncGraph(ops.Graph):
def func_graph_from_py_func(func, arg_names, arg_types, name=None,
- capture_by_value=False, device=None):
+ capture_by_value=False, device=None,
+ colocation_stack=None, container=None,
+ collections_ref=None):
"""Returns a _FuncGraph generated from `func`.
Args:
@@ -786,6 +824,10 @@ def func_graph_from_py_func(func, arg_names, arg_types, name=None,
capture_by_value: boolean. If True, captured values will be copied into the
function body.
device: device name or function.
+ colocation_stack: A colocation stack (list) the _FuncGraph should use.
+ container: A container name the _FuncGraph should start with.
+ collections_ref: A reference to a collections dict the _FuncGraph should
+ use internally.
Returns:
A _FuncGraph.
@@ -796,7 +838,17 @@ def func_graph_from_py_func(func, arg_names, arg_types, name=None,
if not name:
name = _get_func_name(func)
func_graph = _FuncGraph(name, capture_by_value)
+
with func_graph.as_default(), ops.device(device):
+ # pylint: disable=protected-access
+ if collections_ref is not None:
+ func_graph._collections = collections_ref
+ if container is not None:
+ func_graph._container = container
+ if colocation_stack is not None:
+ func_graph._colocation_stack = colocation_stack
+ # pylint: enable=protected-access
+
# Create placeholders for the function arguments.
for (argname, argtype) in zip(arg_names, arg_types):
argholder = array_ops.placeholder(argtype, name=argname)