diff options
Diffstat (limited to 'tensorflow/python/ops/control_flow_ops_test.py')
-rw-r--r-- | tensorflow/python/ops/control_flow_ops_test.py | 88 |
1 files changed, 88 insertions, 0 deletions
diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py new file mode 100644 index 0000000000..34b1ab0a25 --- /dev/null +++ b/tensorflow/python/ops/control_flow_ops_test.py @@ -0,0 +1,88 @@ +"""Tests for control_flow_ops.py.""" +import tensorflow.python.platform + +from tensorflow.core.framework import graph_pb2 +from tensorflow.python.framework import ops +from tensorflow.python.framework.test_util import TensorFlowTestCase +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import standard_ops as tf +from tensorflow.python.platform import googletest + + +class GroupTestCase(TensorFlowTestCase): + + def _StripNode(self, nd): + snode = graph_pb2.NodeDef(name=nd.name, op=nd.op, input=nd.input) + if nd.device: + snode.device = nd.device + return snode + + def _StripGraph(self, gd): + """Copy gd keeping only, node.name, node.op, node.input, and node.device.""" + return graph_pb2.GraphDef(node=[self._StripNode(nd) for nd in gd.node]) + + def testGroup_NoDevices(self): + with ops.Graph().as_default() as g: + a = tf.constant(0, name="a") + b = tf.constant(0, name="b") + c = tf.constant(0, name="c") + tf.group(a.op, b.op, c.op, name="root") + gd = g.as_graph_def() + self.assertProtoEquals(""" + node { name: "a" op: "Const"} + node { name: "b" op: "Const"} + node { name: "c" op: "Const"} + node { name: "root" op: "NoOp" input: "^a" input: "^b" input: "^c" } + """, self._StripGraph(gd)) + + def testGroup_OneDevice(self): + with ops.Graph().as_default() as g: + with g.device("/task:0"): + a = tf.constant(0, name="a") + b = tf.constant(0, name="b") + tf.group(a.op, b.op, name="root") + gd = g.as_graph_def() + self.assertProtoEquals(""" + node { name: "a" op: "Const" device: "/task:0" } + node { name: "b" op: "Const" device: "/task:0" } + node { name: "root" op: "NoOp" input: "^a" input: "^b" device: "/task:0" } + """, self._StripGraph(gd)) + + def testGroup_MultiDevice(self): + with ops.Graph().as_default() as g: + with g.device("/task:0"): + a = tf.constant(0, name="a") + b = tf.constant(0, name="b") + with g.device("/task:1"): + c = tf.constant(0, name="c") + d = tf.constant(0, name="d") + with g.device("/task:2"): + tf.group(a.op, b.op, c.op, d.op, name="root") + gd = g.as_graph_def() + self.assertProtoEquals(""" + node { name: "a" op: "Const" device: "/task:0"} + node { name: "b" op: "Const" device: "/task:0"} + node { name: "c" op: "Const" device: "/task:1"} + node { name: "d" op: "Const" device: "/task:1"} + node { name: "root/NoOp" op: "NoOp" input: "^a" input: "^b" + device: "/task:0" } + node { name: "root/NoOp_1" op: "NoOp" input: "^c" input: "^d" + device: "/task:1" } + node { name: "root" op: "NoOp" input: "^root/NoOp" input: "^root/NoOp_1" + device: "/task:2" } + """, self._StripGraph(gd)) + + +class ShapeTestCase(TensorFlowTestCase): + + def testShape(self): + with ops.Graph().as_default(): + tensor = tf.constant([1.0, 2.0]) + self.assertEquals([2], tensor.get_shape()) + self.assertEquals([2], + control_flow_ops.with_dependencies( + [tf.constant(1.0)], tensor).get_shape()) + + +if __name__ == "__main__": + googletest.main() |