aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-24 14:10:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-24 14:14:11 -0700
commit01365dbc2c257ff2ab409a2a5122a06739272737 (patch)
tree4a9f7f29a92d91ff4f230dfe984411efe46675a3
parentde1b4a8a75ae3a50f4fa7480efb1177d79abf553 (diff)
Allow lists to be passed to tf.group().
PiperOrigin-RevId: 173308794
-rw-r--r--tensorflow/python/ops/control_flow_ops.py7
-rw-r--r--tensorflow/python/ops/control_flow_ops_test.py14
2 files changed, 16 insertions, 5 deletions
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index dcdbeefb70..10d8e01304 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -2910,7 +2910,7 @@ def _GroupControlDeps(dev, deps, name=None):
def group(*inputs, **kwargs):
"""Create an op that groups multiple operations.
- When this op finishes, all ops in `input` have finished. This op has no
+ When this op finishes, all ops in `inputs` have finished. This op has no
output.
See also @{tf.tuple$tuple} and
@@ -2938,7 +2938,10 @@ def group(*inputs, **kwargs):
# Sorts *inputs according to their devices.
ops_on_device = {} # device -> operations specified on the device.
- for inp in inputs:
+ for inp in nest.flatten(inputs):
+ if not hasattr(inp, "device"):
+ raise TypeError("Expected tf.group() expected Tensor arguments not "
+ "'%s' with type '%s'" % (inp, type(inp)))
if not hasattr(inp, "device"):
if isinstance(inp, list):
raise TypeError("To call tf.group() with a list, use "
diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py
index 34c405f293..3e8f39dd24 100644
--- a/tensorflow/python/ops/control_flow_ops_test.py
+++ b/tensorflow/python/ops/control_flow_ops_test.py
@@ -115,11 +115,19 @@ class GroupTestCase(test_util.TensorFlowTestCase):
""", self._StripGraph(gd))
def testPassingList(self):
- with ops.Graph().as_default():
+ with ops.Graph().as_default() as g:
a = constant_op.constant(0, name="a")
b = constant_op.constant(0, name="b")
- with self.assertRaises(TypeError):
- control_flow_ops.group([a.op, b.op])
+ control_flow_ops.group([a.op, b.op], name="root")
+ gd = g.as_graph_def()
+ self.assertProtoEquals("""
+ node { name: "a" op: "Const"}
+ node { name: "b" op: "Const"}
+ node { name: "root" op: "NoOp" input: "^a" input: "^b" }
+ """, self._StripGraph(gd))
+
+ def testPassingNonTensors(self):
+ with ops.Graph().as_default():
with self.assertRaises(TypeError):
control_flow_ops.group(1, 2)