From 01365dbc2c257ff2ab409a2a5122a06739272737 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 24 Oct 2017 14:10:14 -0700 Subject: Allow lists to be passed to tf.group(). PiperOrigin-RevId: 173308794 --- tensorflow/python/ops/control_flow_ops.py | 7 +++++-- tensorflow/python/ops/control_flow_ops_test.py | 14 +++++++++++--- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index dcdbeefb70..10d8e01304 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -2910,7 +2910,7 @@ def _GroupControlDeps(dev, deps, name=None): def group(*inputs, **kwargs): """Create an op that groups multiple operations. - When this op finishes, all ops in `input` have finished. This op has no + When this op finishes, all ops in `inputs` have finished. This op has no output. See also @{tf.tuple$tuple} and @@ -2938,7 +2938,10 @@ def group(*inputs, **kwargs): # Sorts *inputs according to their devices. ops_on_device = {} # device -> operations specified on the device. - for inp in inputs: + for inp in nest.flatten(inputs): + if not hasattr(inp, "device"): + raise TypeError("Expected tf.group() expected Tensor arguments not " + "'%s' with type '%s'" % (inp, type(inp))) if not hasattr(inp, "device"): if isinstance(inp, list): raise TypeError("To call tf.group() with a list, use " diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py index 34c405f293..3e8f39dd24 100644 --- a/tensorflow/python/ops/control_flow_ops_test.py +++ b/tensorflow/python/ops/control_flow_ops_test.py @@ -115,11 +115,19 @@ class GroupTestCase(test_util.TensorFlowTestCase): """, self._StripGraph(gd)) def testPassingList(self): - with ops.Graph().as_default(): + with ops.Graph().as_default() as g: a = constant_op.constant(0, name="a") b = constant_op.constant(0, name="b") - with self.assertRaises(TypeError): - control_flow_ops.group([a.op, b.op]) + control_flow_ops.group([a.op, b.op], name="root") + gd = g.as_graph_def() + self.assertProtoEquals(""" + node { name: "a" op: "Const"} + node { name: "b" op: "Const"} + node { name: "root" op: "NoOp" input: "^a" input: "^b" } + """, self._StripGraph(gd)) + + def testPassingNonTensors(self): + with ops.Graph().as_default(): with self.assertRaises(TypeError): control_flow_ops.group(1, 2) -- cgit v1.2.3