1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
|
"""Tests for control_flow_ops.py."""
import tensorflow.python.platform
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import ops
from tensorflow.python.framework.test_util import TensorFlowTestCase
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import standard_ops as tf
from tensorflow.python.platform import googletest
class GroupTestCase(TensorFlowTestCase):
  """Checks the NoOp nodes that `tf.group` adds to a graph.

  Each test builds a fresh graph, groups some constants, and compares a
  reduced GraphDef (name/op/input/device only) against an expected proto.
  """

  def _StripNode(self, nd):
    """Return a NodeDef carrying only the name, op and inputs of `nd`.

    The device field is copied as well, but only when `nd.device` is
    non-empty, so device-less nodes compare equal to protos without one.
    """
    fields = dict(name=nd.name, op=nd.op, input=nd.input)
    if nd.device:
      fields["device"] = nd.device
    return graph_pb2.NodeDef(**fields)

  def _StripGraph(self, gd):
    """Return a GraphDef whose nodes are `gd`'s nodes passed through _StripNode.

    Only node.name, node.op, node.input and node.device survive the copy.
    """
    stripped_nodes = list(map(self._StripNode, gd.node))
    return graph_pb2.GraphDef(node=stripped_nodes)

  def testGroup_NoDevices(self):
    with ops.Graph().as_default() as graph:
      consts = [tf.constant(0, name=n) for n in ("a", "b", "c")]
      tf.group(*[t.op for t in consts], name="root")
    graph_def = graph.as_graph_def()
    # Without devices, a single NoOp named "root" takes every op as a
    # control input.
    self.assertProtoEquals("""
node { name: "a" op: "Const"}
node { name: "b" op: "Const"}
node { name: "c" op: "Const"}
node { name: "root" op: "NoOp" input: "^a" input: "^b" input: "^c" }
""", self._StripGraph(graph_def))

  def testGroup_OneDevice(self):
    with ops.Graph().as_default() as graph:
      with graph.device("/task:0"):
        consts = [tf.constant(0, name=n) for n in ("a", "b")]
        tf.group(*[t.op for t in consts], name="root")
    graph_def = graph.as_graph_def()
    # Everything shares one device, so "root" is still a single NoOp placed
    # on that device.
    self.assertProtoEquals("""
node { name: "a" op: "Const" device: "/task:0" }
node { name: "b" op: "Const" device: "/task:0" }
node { name: "root" op: "NoOp" input: "^a" input: "^b" device: "/task:0" }
""", self._StripGraph(graph_def))

  def testGroup_MultiDevice(self):
    with ops.Graph().as_default() as graph:
      with graph.device("/task:0"):
        task0_consts = [tf.constant(0, name=n) for n in ("a", "b")]
      with graph.device("/task:1"):
        task1_consts = [tf.constant(0, name=n) for n in ("c", "d")]
      with graph.device("/task:2"):
        tf.group(*[t.op for t in task0_consts + task1_consts], name="root")
    graph_def = graph.as_graph_def()
    # Inputs spread over two devices: the expected graph has one
    # per-device NoOp under "root/", joined by a final "root" NoOp on the
    # device active at the tf.group call.
    self.assertProtoEquals("""
node { name: "a" op: "Const" device: "/task:0"}
node { name: "b" op: "Const" device: "/task:0"}
node { name: "c" op: "Const" device: "/task:1"}
node { name: "d" op: "Const" device: "/task:1"}
node { name: "root/NoOp" op: "NoOp" input: "^a" input: "^b"
device: "/task:0" }
node { name: "root/NoOp_1" op: "NoOp" input: "^c" input: "^d"
device: "/task:1" }
node { name: "root" op: "NoOp" input: "^root/NoOp" input: "^root/NoOp_1"
device: "/task:2" }
""", self._StripGraph(graph_def))
class ShapeTestCase(TensorFlowTestCase):
  """Checks static shapes of tensors built in this module's tests."""

  def testShape(self):
    """`with_dependencies` returns a tensor with the same static shape.

    Uses `assertEqual`; the old `assertEquals` alias was deprecated and then
    removed from unittest.TestCase in Python 3.12.
    """
    with ops.Graph().as_default():
      tensor = tf.constant([1.0, 2.0])
      self.assertEqual([2], tensor.get_shape())
      # Gating `tensor` behind a control dependency must not lose its shape.
      gated = control_flow_ops.with_dependencies([tf.constant(1.0)], tensor)
      self.assertEqual([2], gated.get_shape())
if __name__ == "__main__":
  # Run this module's test cases when the file is executed as a script.
  googletest.main()
|