"""Tests for control_flow_ops.py."""
# NOTE(review): this module is not referenced by name below; presumably it is
# imported for its side effects -- confirm before removing.
import tensorflow.python.platform

from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import ops
from tensorflow.python.framework.test_util import TensorFlowTestCase
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import standard_ops as tf
from tensorflow.python.platform import googletest


class GroupTestCase(TensorFlowTestCase):
    """Checks the graph nodes emitted by tf.group() under device scopes."""

    def _StripNode(self, nd):
        """Return a NodeDef holding only nd's name, op, input and device.

        Args:
            nd: the NodeDef to copy from.

        Returns:
            A new graph_pb2.NodeDef. `device` is only assigned when `nd.device`
            is non-empty, so it stays unset on nodes without a device.
        """
        snode = graph_pb2.NodeDef(name=nd.name, op=nd.op, input=nd.input)
        if nd.device:
            snode.device = nd.device
        return snode

    def _StripGraph(self, gd):
        """Copy gd keeping only node.name, node.op, node.input and node.device.

        Args:
            gd: the GraphDef whose nodes are stripped.

        Returns:
            A new graph_pb2.GraphDef built from the stripped nodes, in the
            same order as `gd.node`.
        """
        return graph_pb2.GraphDef(node=[self._StripNode(nd) for nd in gd.node])

    def testGroup_NoDevices(self):
        """Grouping three device-less constants yields one NoOp named root."""
        with ops.Graph().as_default() as g:
            a = tf.constant(0, name="a")
            b = tf.constant(0, name="b")
            c = tf.constant(0, name="c")
            tf.group(a.op, b.op, c.op, name="root")
        gd = g.as_graph_def()
        # The expected text is split into adjacent literals (one node each);
        # the concatenated value is unchanged.
        self.assertProtoEquals(
            ' node { name: "a" op: "Const"}'
            ' node { name: "b" op: "Const"}'
            ' node { name: "c" op: "Const"}'
            ' node { name: "root" op: "NoOp"'
            ' input: "^a" input: "^b" input: "^c" } ',
            self._StripGraph(gd))

    def testGroup_OneDevice(self):
        """Inputs that all live on /task:0 yield a root NoOp on /task:0."""
        with ops.Graph().as_default() as g:
            with g.device("/task:0"):
                a = tf.constant(0, name="a")
                b = tf.constant(0, name="b")
            tf.group(a.op, b.op, name="root")
        gd = g.as_graph_def()
        self.assertProtoEquals(
            ' node { name: "a" op: "Const" device: "/task:0" }'
            ' node { name: "b" op: "Const" device: "/task:0" }'
            ' node { name: "root" op: "NoOp"'
            ' input: "^a" input: "^b" device: "/task:0" } ',
            self._StripGraph(gd))

    def testGroup_MultiDevice(self):
        """Inputs on two devices yield one NoOp per device plus a root NoOp.

        The expectation is: root/NoOp on /task:0 depending on a and b,
        root/NoOp_1 on /task:1 depending on c and d, and root on /task:2
        depending on both intermediate NoOps.
        """
        with ops.Graph().as_default() as g:
            with g.device("/task:0"):
                a = tf.constant(0, name="a")
                b = tf.constant(0, name="b")
            with g.device("/task:1"):
                c = tf.constant(0, name="c")
                d = tf.constant(0, name="d")
            with g.device("/task:2"):
                tf.group(a.op, b.op, c.op, d.op, name="root")
        gd = g.as_graph_def()
        # The "\n" inside the first node is carried over verbatim from the
        # original literal.
        self.assertProtoEquals(
            ' node { name: "a" op: "Const" \ndevice: "/task:0"}'
            ' node { name: "b" op: "Const" device: "/task:0"}'
            ' node { name: "c" op: "Const" device: "/task:1"}'
            ' node { name: "d" op: "Const" device: "/task:1"}'
            ' node { name: "root/NoOp" op: "NoOp"'
            ' input: "^a" input: "^b" device: "/task:0" }'
            ' node { name: "root/NoOp_1" op: "NoOp"'
            ' input: "^c" input: "^d" device: "/task:1" }'
            ' node { name: "root" op: "NoOp"'
            ' input: "^root/NoOp" input: "^root/NoOp_1"'
            ' device: "/task:2" } ',
            self._StripGraph(gd))


class ShapeTestCase(TensorFlowTestCase):
    """Checks static shapes reported through control_flow_ops helpers."""

    def testShape(self):
        """with_dependencies() output reports the same [2] shape as its input."""
        with ops.Graph().as_default():
            tensor = tf.constant([1.0, 2.0])
            # assertEqual replaces the deprecated assertEquals alias, which
            # newer unittest versions no longer provide.
            self.assertEqual([2], tensor.get_shape())
            guarded = control_flow_ops.with_dependencies(
                [tf.constant(1.0)], tensor)
            self.assertEqual([2], guarded.get_shape())


if __name__ == "__main__":
    googletest.main()