diff options
Diffstat (limited to 'tensorflow/python/framework/ops_test.py')
-rw-r--r-- | tensorflow/python/framework/ops_test.py | 825 |
1 files changed, 825 insertions, 0 deletions
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py new file mode 100644 index 0000000000..a406c5e56e --- /dev/null +++ b/tensorflow/python/framework/ops_test.py @@ -0,0 +1,825 @@ +"""Tests for tensorflow.python.framework.ops.""" +import tensorflow.python.platform + +from tensorflow.python.framework import device as pydev +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_kernel_label_op +from tensorflow.python.framework import test_util +from tensorflow.python.framework import types +from tensorflow.python.ops import common_shapes +from tensorflow.python.platform import googletest + + +class TensorTest(test_util.TensorFlowTestCase): + + def testShape(self): + op = ops.Operation(ops._NodeDef("noop", "myop"), ops.Graph(), + [], [types.float32]) + t = op.outputs[0] + self.assertEquals(tensor_shape.unknown_shape(), t.get_shape()) + t.set_shape([1, 2, 3]) + self.assertEquals([1, 2, 3], t.get_shape()) + + +class NodeDefConstructorTest(test_util.TensorFlowTestCase): + + def testNoArgs(self): + nodedef = ops._NodeDef("noop", "bar") + self.assertProtoEquals("op: 'noop' name: 'bar'", nodedef) + + def testArgs(self): + nodedef = ops._NodeDef("foo", "bar", device="/device:baz:*") + self.assertProtoEquals("op:'foo' name:'bar' device:'/device:baz:*'", + nodedef) + nodedef = ops._NodeDef("foo", "bar", device=pydev.Device(job="j")) + self.assertProtoEquals("op:'foo' name:'bar' device:'/job:j'", nodedef) + + +# NOTE(mrry): Dummy shape registrations for ops used in the tests. +ops.RegisterShape("a")(None) +ops.RegisterShape("b")(None) +ops.RegisterShape("c")(None) +ops.RegisterShape("add")(None) +ops.RegisterShape("an_op")(None) +ops.RegisterShape("const")(None) +ops.RegisterShape("copy")(None) +ops.RegisterShape("foo")(None) +ops.RegisterShape("identity")(None) +ops.RegisterShape("mul")(None) +ops.RegisterShape("nonrefop")(None) +ops.RegisterShape("noop")(None) +ops.RegisterShape("refop")(None) + + +def _apply_op(g, *args, **kwargs): + op = g.create_op(*args, **kwargs) + if len(op.outputs) == 1: + return op.outputs[0] + else: + return op.outputs + + +class OperationTest(test_util.TensorFlowTestCase): + + def testNoInputs(self): + op = ops.Operation(ops._NodeDef("noop", "myop"), ops.Graph(), + [], + [types.float32, types.string]) + self.assertEquals(2, len(op.values())) + self.assertEquals(0, len(op.inputs)) + self.assertEquals("myop", op.name) + + float_t, label_str_t = op.values() + self.assertEquals(types.float32, float_t.dtype) + self.assertEquals(op, float_t.op) + self.assertEquals(0, float_t._value_index) + self.assertEquals(0, len(float_t._consumers)) + self.assertEquals("myop", float_t._as_node_def_input()) + + self.assertEquals(types.string, label_str_t.dtype) + self.assertEquals(op, label_str_t.op) + self.assertEquals(1, label_str_t._value_index) + self.assertEquals(0, len(label_str_t._consumers)) + self.assertEquals("myop:1", label_str_t._as_node_def_input()) + + self.assertProtoEquals("op:'noop' name:'myop'", op.node_def) + + def testNoOutputs(self): + g = ops.Graph() + op1 = ops.Operation( + ops._NodeDef("noop", "myop1"), g, [], [types.float32]) + float_t, = op1.values() + op2 = ops.Operation(ops._NodeDef("reop", "myop2"), g, [float_t], []) + self.assertEquals(0, len(op2.values())) + self.assertEquals(1, len(op2.inputs)) + self.assertIs(float_t, op2.inputs[0]) + + self.assertEquals(1, len(float_t._consumers)) + self.assertEquals(op2, float_t._consumers[0]) + + self.assertProtoEquals("op:'noop' name:'myop1'", op1.node_def) + self.assertProtoEquals("op:'reop' name:'myop2' input:'myop1'", + op2.node_def) + + def testInputsAndOutputs(self): + g = ops.Graph() + op1 = ops.Operation( + ops._NodeDef("noop", "myop1"), g, [], [types.float32]) + self.assertEquals(1, len(op1.values())) + float1_t, = op1.values() + + op2 = ops.Operation(ops._NodeDef("reop", "myop2"), g, + [], [types.float32, types.string]) + self.assertEquals(2, len(op2.values())) + float2_t, label2_str_t = op2.values() + + # Note that we consume label2_str_t twice here. + op3 = ops.Operation(ops._NodeDef("add", "myop3"), g, + [float1_t, label2_str_t, label2_str_t], + [types.float32, types.int32]) + self.assertEquals(2, len(op3.values())) + + self.assertEquals(1, len(float1_t._consumers)) + self.assertEquals(op3, float1_t._consumers[0]) + + self.assertEquals(0, len(float2_t._consumers)) + + self.assertEquals(2, len(label2_str_t._consumers)) + self.assertEquals(op3, label2_str_t._consumers[0]) + self.assertEquals(op3, label2_str_t._consumers[1]) + + self.assertProtoEquals(""" + op:'add' name:'myop3' + input:'myop1' input:'myop2:1' input:'myop2:1' + """, op3.node_def) + + def testDeviceObject(self): + op = ops.Operation(ops._NodeDef("noop", "myop"), ops.Graph(), [], []) + op._set_device("/job:goo/device:GPU:0") + self.assertProtoEquals( + "op:'noop' name:'myop' device:'/job:goo/device:GPU:0' ", + op.node_def) + op = ops.Operation(ops._NodeDef("noop", "op2"), ops.Graph(), [], []) + op._set_device(pydev.Device(job="muu", device_type="CPU", device_index=0)) + self.assertProtoEquals( + "op:'noop' name:'op2' device:'/job:muu/device:CPU:0'", + op.node_def) + + def testReferenceInput(self): + g = ops.Graph() + op1 = ops.Operation(ops._NodeDef("noop", "op1"), g, [], + [types.float32_ref, types.float32]) + self.assertProtoEquals("op:'noop' name:'op1'", + op1.node_def) + ref_t, nonref_t = op1.values() + # NOTE(mrry): Must specify input_types to preserve ref-typed input. + op2 = ops.Operation( + ops._NodeDef("refop", "op2"), g, [ref_t, nonref_t], [], + input_types=[types.float32_ref, types.float32]) + self.assertProtoEquals("op:'refop' name:'op2' input:'op1' input:'op1:1'", + op2.node_def) + op3 = ops.Operation( + ops._NodeDef("nonrefop", "op3"), g, [ref_t, nonref_t], []) + self.assertProtoEquals("op:'nonrefop' name:'op3' input:'op1' input:'op1:1'", + op3.node_def) + + def testInvalidNames(self): + g = ops.Graph() + with self.assertRaises(ValueError): + ops.Operation(ops._NodeDef("op", ""), g) + with self.assertRaises(ValueError): + ops.Operation(ops._NodeDef("op", "_invalid"), g) + with self.assertRaises(ValueError): + ops.Operation(ops._NodeDef("op", "-invalid"), g) + with self.assertRaises(ValueError): + ops.Operation(ops._NodeDef("op", "/invalid"), g) + + def testShapeFunctionAbsence(self): + def _test(): + pass + g = ops.Graph() + with self.assertRaises(RuntimeError): + g.create_op("shapeless_op", [], [types.float32]) + + def testNoShapeFunction(self): + g = ops.Graph() + op = ops.Operation(ops._NodeDef("op", "an_op"), g, + output_types = [types.float32]) + self.assertEquals(tensor_shape.unknown_shape(), + _apply_op(g, "an_op", [], [types.float32]).get_shape()) + +class CreateOpTest(test_util.TensorFlowTestCase): + + def testNodeDefArgs(self): + g = ops.Graph() + op1 = g.create_op("const", [], [types.float32], None, name="myop1") + with g.device("/device:GPU"): + op2 = g.create_op("add", + [], + [types.float32, types.string], None, + name="myop2") + op3 = g.create_op( + "foo", + [op1.values()[0], op2.values()[1], op2.values()[0]], + [types.float32, types.int32], None, + name="myop3") + self.assertEquals(None, op1.device) + self.assertEquals("/device:GPU", op2.device) + self.assertEquals(None, op3.device) + self.assertProtoEquals("name:'myop1' op:'const'", op1.node_def) + self.assertProtoEquals("name:'myop2' op:'add' device:'/device:GPU'", + op2.node_def) + self.assertProtoEquals( + "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'foo'", + op3.node_def) + + def testReferenceInput(self): + g = ops.Graph() + op1 = g.create_op("noop", [], + [types.float32_ref, types.float32], name="op1") + self.assertProtoEquals("op:'noop' name:'op1'", op1.node_def) + ref_t, nonref_t = op1.values() + # NOTE(mrry): Must specify input_types to preserve ref-typed input. + op2 = g.create_op("refop", [ref_t, nonref_t], [], + input_types=[types.float32_ref, types.float32], + name="op2") + self.assertProtoEquals("op:'refop' name:'op2' input:'op1' input:'op1:1'", + op2.node_def) + op3 = g.create_op("nonrefop", [ref_t, nonref_t], [], name="op3") + self.assertProtoEquals("op:'nonrefop' name:'op3' input:'op1' input:'op1:1'", + op3.node_def) + + def testFinalized(self): + g = ops.Graph() + g.finalize() + with self.assertRaises(RuntimeError): + g.create_op("const", [], [types.float32], None, name="myop1") + + +class ApplyOpTest(test_util.TensorFlowTestCase): + + def testNodeDefArgs(self): + g = ops.Graph() + t1 = _apply_op(g, "const", [], [types.float32], name="myop1") + with g.device("/device:GPU"): + t2 = _apply_op(g, "add", + [], + [types.float32, types.string], + name="myop2") + t3 = _apply_op(g, "foo", [t1, t2[1], t2[0]], + [types.float32, types.int32], name="myop3") + self.assertTrue(isinstance(t1, ops.Tensor)) + self.assertTrue(isinstance(t2, list)) + self.assertTrue(isinstance(t3, list)) + self.assertTrue(isinstance(t3[0], ops.Tensor)) + self.assertEquals("myop1", t1._as_node_def_input()) + self.assertEquals("myop2", t2[0]._as_node_def_input()) + self.assertEquals("myop2:1", t2[1]._as_node_def_input()) + self.assertEquals("myop3", t3[0]._as_node_def_input()) + # Validate that we got the right ops as well + self.assertProtoEquals("name:'myop1' op:'const'", t1.op.node_def) + self.assertProtoEquals("name:'myop2' op:'add' device:'/device:GPU'", + t2[0].op.node_def) + self.assertProtoEquals( + "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'foo'", + t3[0].op.node_def) + + def testReferenceInput(self): + g = ops.Graph() + ref_t, nonref_t = _apply_op( + g, "noop", [], [types.float32_ref, types.float32], name="op1") + self.assertProtoEquals("op:'noop' name:'op1'", ref_t.op.node_def) + # NOTE(mrry): Must specify input_types to preserve ref-typed input. + out_2 = _apply_op(g, "refop", [ref_t, nonref_t], [types.int32], + input_types=[types.float32_ref, types.float32], + name="op2") + self.assertProtoEquals("op:'refop' name:'op2' input:'op1' input:'op1:1'", + out_2.op.node_def) + out_3 = _apply_op(g, "nonrefop", [ref_t, nonref_t], [types.int32], + name="op3") + self.assertProtoEquals("op:'nonrefop' name:'op3' input:'op1' input:'op1:1'", + out_3.op.node_def) + + +class NameStackTest(test_util.TensorFlowTestCase): + + def testBasics(self): + g = ops.Graph() + self.assertEquals("foo", g.unique_name("foo")) + self.assertEquals("foo_1", g.unique_name("foo")) + self.assertEquals("foo_2", g.unique_name("foo")) + self.assertEquals("foo_1_1", g.unique_name("foo_1")) + self.assertEquals("foo_1_2", g.unique_name("foo_1")) + self.assertEquals("foo_1_2_1", g.unique_name("foo_1_2")) + with g.name_scope("bar"): + self.assertEquals("bar/foo", g.unique_name("foo")) + self.assertEquals("bar/foo_1", g.unique_name("foo")) + with g.name_scope(None): + self.assertEquals("foo_3", g.unique_name("foo")) + with g.name_scope("baz"): + self.assertEquals("bar/baz/foo", g.unique_name("foo")) + self.assertEquals("bar/baz/foo_1", g.unique_name("foo")) + with g.name_scope("baz"): + self.assertEquals("bar/baz_1/foo", g.unique_name("foo")) + self.assertEquals("bar/baz_1/foo_1", g.unique_name("foo")) + with g.name_scope("quux"): + self.assertEquals("quux/foo", g.unique_name("foo")) + with g.name_scope("bar"): + with g.name_scope("baz"): + self.assertEquals("bar_1/baz/foo", g.unique_name("foo")) + self.assertEquals("foo_4", g.unique_name("foo")) + self.assertEquals("bar_2", g.unique_name("bar")) + + def testOutOfOrderUniqueName(self): + g = ops.Graph() + self.assertEquals("foo_2", g.unique_name("foo_2")) + self.assertEquals("foo", g.unique_name("foo")) + self.assertEquals("foo_1", g.unique_name("foo")) + self.assertEquals("foo_3", g.unique_name("foo")) + + +class NameTest(test_util.TensorFlowTestCase): + + def testGenerateName(self): + g = ops.Graph() + op0 = g.create_op("const", [], [types.float32, types.float32]) + self.assertEquals("const", op0.name) + self.assertEquals("const:0", op0.outputs[0].name) + self.assertEquals("const:1", op0.outputs[1].name) + + op1 = g.create_op("const", [], [types.float32]) + self.assertEquals("const_1", op1.name) + self.assertEquals("const_1:0", op1.outputs[0].name) + + op2 = g.create_op("const", [], [types.float32], name="my_op") + self.assertEquals("my_op", op2.name) + self.assertEquals("my_op:0", op2.outputs[0].name) + + def testname_scope(self): + g = ops.Graph() + + with g.name_scope("foo") as foo: + self.assertEquals(foo, "foo/") + with g.name_scope("foo2") as foo2: + self.assertEquals(foo2, "foo/foo2/") + with g.name_scope(None) as empty1: + self.assertEquals(empty1, "") + with g.name_scope("foo3") as foo3: + self.assertEquals(foo3, "foo3/") + with g.name_scope("") as empty2: + self.assertEquals(empty2, "") + + self.assertEquals("const", + g.create_op("const", [], [types.float32]).name) + with g.name_scope("bar") as scope: + self.assertEquals("bar/const", + g.create_op("const", [], [types.float32]).name) + self.assertEquals("bar/const_1", + g.create_op("const", [], [types.float32]).name) + # If you use the value from "with .. as", that values is used as-is. + self.assertEquals( + "bar", + g.create_op("const", [], [types.float32], name=scope).name) + with g.name_scope("baz") as scope: + with g.name_scope("quux"): + self.assertEquals("baz/quux/const", + g.create_op("const", [], [types.float32]).name) + # If you use the value from the enclosing "with .. as", nothing is pushed. + with g.name_scope(scope): + self.assertEquals("baz/const", + g.create_op("const", [], [types.float32]).name) + self.assertEquals("baz", + g.create_op("const", [], [types.float32], + name=scope).name) + self.assertEquals("trailing", + g.create_op("const", [], [types.float32], + name="trailing/").name) + with g.name_scope("bar"): + self.assertEquals("bar_1/const", + g.create_op("const", [], [types.float32]).name) + with g.name_scope("bar/"): + self.assertEquals("bar/const_2", + g.create_op("const", [], [types.float32]).name) + + +class DeviceTest(test_util.TensorFlowTestCase): + + def testNoDevice(self): + g = ops.Graph() + op = g.create_op("an_op", [], [types.float32]) + self.assertEqual(None, op.device) + gd = g.as_graph_def() + self.assertProtoEquals(""" + node { name: "an_op" op: "an_op" } + """, gd) + + def testDevicePartialString(self): + g = ops.Graph() + with g.device("/job:worker/replica:2"): + g.create_op("an_op", [], [types.float32]) + gd = g.as_graph_def() + self.assertProtoEquals(""" + node { name: "an_op" op: "an_op" device: "/job:worker/replica:2" } + """, gd) + + def testDeviceFull(self): + g = ops.Graph() + with g.device(pydev.Device(job="worker", replica=2, task=0, + device_type="CPU", + device_index=3)): + g.create_op("an_op", [], [types.float32]) + gd = g.as_graph_def() + self.assertProtoEquals(""" + node { name: "an_op" op: "an_op" + device: "/job:worker/replica:2/task:0/device:CPU:3" } + """, gd) + + def testNesting(self): + g = ops.Graph() + with g.device("/job:worker/replica:2"): + g.create_op("an_op", [], [types.float32]) + with g.device("/job:worker/replica:3/task:0"): + g.create_op("an_op", [], [types.float32]) + g.create_op("an_op", [], [types.float32]) + gd = g.as_graph_def() + self.assertProtoEquals(""" + node { name: "an_op" op: "an_op" + device: "/job:worker/replica:2" } + node { name: "an_op_1" op: "an_op" + device: "/job:worker/replica:3/task:0" } + node { name: "an_op_2" op: "an_op" + device: "/job:worker/replica:2" } + """, gd) + + def testNestingString(self): + g = ops.Graph() + with g.device("/job:worker/replica:2"): + g.create_op("an_op", [], [types.float32]) + with g.device("/job:worker/replica:3/task:0"): + g.create_op("an_op", [], [types.float32]) + g.create_op("an_op", [], [types.float32]) + gd = g.as_graph_def() + self.assertProtoEquals(""" + node { name: "an_op" op: "an_op" + device: "/job:worker/replica:2" } + node { name: "an_op_1" op: "an_op" + device: "/job:worker/replica:3/task:0" } + node { name: "an_op_2" op: "an_op" + device: "/job:worker/replica:2" } + """, gd) + + def testNestingOverrideGpuCpu(self): + g = ops.Graph() + with g.device("/job:worker/replica:2/device:CPU:1"): + g.create_op("an_op", [], [types.float32]) + with g.device("/job:worker/replica:2/device:GPU:2"): + g.create_op("an_op", [], [types.float32]) + g.create_op("an_op", [], [types.float32]) + gd = g.as_graph_def() + self.assertProtoEquals(""" + node { name: "an_op" op: "an_op" + device: "/job:worker/replica:2/device:CPU:1" } + node { name: "an_op_1" op: "an_op" + device: "/job:worker/replica:2/device:GPU:2" } + node { name: "an_op_2" op: "an_op" + device: "/job:worker/replica:2/device:CPU:1" } + """, gd) + + def testNestingWithMergeDeviceFunction(self): + g = ops.Graph() + + with g.device(pydev.merge_device("/device:GPU:0")): + g.create_op("an_op", [], [types.float32]) + with g.device(pydev.merge_device("/job:worker")): + g.create_op("an_op", [], [types.float32]) + with g.device(pydev.merge_device("/device:CPU:0")): + g.create_op("an_op", [], [types.float32]) + with g.device(pydev.merge_device("/job:ps")): + g.create_op("an_op", [], [types.float32]) + with g.device(pydev.merge_device(None)): + g.create_op("an_op", [], [types.float32]) + + gd = g.as_graph_def() + self.assertProtoEquals(""" + node { name: "an_op" op: "an_op" + device: "/device:GPU:0" } + node { name: "an_op_1" op: "an_op" + device: "/job:worker/device:GPU:0" } + node { name: "an_op_2" op: "an_op" + device: "/job:worker/device:CPU:0" } + node { name: "an_op_3" op: "an_op" + device: "/job:ps/device:CPU:0" } + node { name: "an_op_4" op: "an_op" + device: "/job:ps/device:CPU:0" } + """, gd) + + def testNoneClearsDefault(self): + g = ops.Graph() + with g.device("/job:worker/replica:2/device:CPU:1"): + g.create_op("an_op", [], [types.float32]) + with g.device(None): + g.create_op("an_op", [], [types.float32]) + g.create_op("an_op", [], [types.float32]) + gd = g.as_graph_def() + self.assertProtoEquals(""" + node { name: "an_op" op: "an_op" + device: "/job:worker/replica:2/device:CPU:1" } + node { name: "an_op_1" op: "an_op" } + node { name: "an_op_2" op: "an_op" + device: "/job:worker/replica:2/device:CPU:1" } + """, gd) + + +class ObjectWithName(object): + + def __init__(self, name): + self._name = name + + @property + def name(self): + return self._name + + +class CollectionTest(test_util.TensorFlowTestCase): + + def testadd_to_collection(self): + g = ops.Graph() + g.add_to_collection("key", 12) + g.add_to_collection("other", "foo") + g.add_to_collection("key", 34) + + # Note that only blank1 is returned. + g.add_to_collection("blah", 27) + blank1 = ObjectWithName("prefix/foo") + g.add_to_collection("blah", blank1) + blank2 = ObjectWithName("junk/foo") + g.add_to_collection("blah", blank2) + + self.assertEquals(["foo"], g.get_collection("other")) + self.assertEquals([12, 34], g.get_collection("key")) + self.assertEquals([], g.get_collection("nothing")) + self.assertEquals([27, blank1, blank2], g.get_collection("blah")) + self.assertEquals([blank1], g.get_collection("blah", "prefix")) + + def testDefaulGraph(self): + with ops.Graph().as_default(): + ops.add_to_collection("key", 90) + ops.add_to_collection("key", 100) + # Collections are ordered. + self.assertEquals([90, 100], ops.get_collection("key")) + + +def an_op(g): + return _apply_op(g, "an_op", [], [types.float32]) + + +ops.NoGradient("an_op") + + +def copy_op(x): + return _apply_op(x.graph, "copy", [x], [x.dtype]) + + +@ops.RegisterGradient("copy") +def _CopyGrad(op, x_grad): + _ = op + return x_grad + + +@ops.RegisterGradient("copy_override") +def _CopyOverrideGrad(op, x_grad): + _ = op + return x_grad + + +class RegistrationTest(test_util.TensorFlowTestCase): + + def testRegisterGradients(self): + g = ops.Graph() + x = an_op(g) + y = copy_op(x) + fn = ops.get_gradient_function(y.op) + self.assertEquals(_CopyGrad, fn) + + def testOverrideGradients(self): + g = ops.Graph() + x = an_op(g) + with g.gradient_override_map({"copy": "copy_override"}): + y = copy_op(x) + fn = ops.get_gradient_function(y.op) + self.assertEquals(_CopyOverrideGrad, fn) + + def testNonExistentOverride(self): + g = ops.Graph() + x = an_op(g) + with g.gradient_override_map({"copy": "unknown_override"}): + y = copy_op(x) + with self.assertRaisesRegexp(LookupError, "unknown_override"): + fn = ops.get_gradient_function(y.op) + + +class ComparisonTest(test_util.TensorFlowTestCase): + + def testMembershipAllowed(self): + g = ops.Graph() + t1 = _apply_op(g, "const", [], [types.float32], name="myop1") + t2 = _apply_op(g, "const", [], [types.float32], name="myop2") + self.assertTrue(isinstance(t1, ops.Tensor)) + self.assertTrue(isinstance(t2, ops.Tensor)) + self.assertTrue(t1 in [t1]) + self.assertTrue(t1 not in [t2]) + + +class ControlDependenciesTest(test_util.TensorFlowTestCase): + + def testBasic(self): + g = ops.Graph() + a = _apply_op(g, "const", [], [types.float32]) + b = _apply_op(g, "const", [], [types.float32]) + with g.control_dependencies([a]): + c = _apply_op(g, "const", [], [types.float32]) + d = _apply_op(g, "identity", [b], [types.float32]) + e = _apply_op(g, "identity", [c], [types.float32]) + + self.assertEqual(c.op.control_inputs, [a.op]) + self.assertEqual(d.op.control_inputs, [a.op]) + # e should be dominated by c. + self.assertEqual(e.op.control_inputs, []) + + def testNested(self): + g = ops.Graph() + a_1 = _apply_op(g, "const", [], [types.float32]) + a_2 = _apply_op(g, "const", [], [types.float32]) + a_3 = _apply_op(g, "const", [], [types.float32]) + a_4 = _apply_op(g, "const", [], [types.float32]) + + with g.control_dependencies([a_1, a_2, a_3, a_4]): + b_1 = _apply_op(g, "const", [], [types.float32]) + + with g.control_dependencies([a_1]): + with g.control_dependencies([a_2]): + with g.control_dependencies([a_3]): + with g.control_dependencies([a_4]): + b_2 = _apply_op(g, "const", [], [types.float32]) + + self.assertItemsEqual( + [a_1.op, a_2.op, a_3.op, a_4.op], b_1.op.control_inputs) + self.assertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs) + + def testComplex(self): + g = ops.Graph() + + # Usage pattern: + # * Nodes a_i are constants defined at the outermost scope, and are used + # as control inputs for the ith nested scope. + # * Nodes b_i are defined as Mul(a_3, a_4) at each scope. + # * Nodes c_i are defined as Mul(a_1, b_1) at each scope. + # * Nodes d_i are defined as Mul(b_i, c_i) at each scope. + # * Nodes e_i are defined as Mul(e_i-1, e_i-1) at each scope i > 1. + + a_1 = _apply_op(g, "const", [], [types.float32]) + a_2 = _apply_op(g, "const", [], [types.float32]) + a_3 = _apply_op(g, "const", [], [types.float32]) + a_4 = _apply_op(g, "const", [], [types.float32]) + + with g.control_dependencies([a_1]): + b_1 = _apply_op(g, "mul", [a_3, a_4], [types.float32]) + c_1 = _apply_op(g, "mul", [a_1, b_1], [types.float32]) + d_1 = _apply_op(g, "mul", [b_1, c_1], [types.float32]) + e_1 = _apply_op(g, "const", [], [types.float32]) + with g.control_dependencies([a_2]): + b_2 = _apply_op(g, "mul", [a_3, a_4], [types.float32]) + c_2 = _apply_op(g, "mul", [a_1, b_1], [types.float32]) + d_2 = _apply_op(g, "mul", [b_2, c_2], [types.float32]) + e_2 = _apply_op(g, "mul", [e_1, e_1], [types.float32]) + with g.control_dependencies([a_3]): + b_3 = _apply_op(g, "mul", [a_3, a_4], [types.float32]) + c_3 = _apply_op(g, "mul", [a_1, b_1], [types.float32]) + d_3 = _apply_op(g, "mul", [b_3, c_3], [types.float32]) + e_3 = _apply_op(g, "mul", [e_2, e_2], [types.float32]) + with g.control_dependencies([a_4]): + b_4 = _apply_op(g, "mul", [a_3, a_4], [types.float32]) + c_4 = _apply_op(g, "mul", [a_1, b_1], [types.float32]) + d_4 = _apply_op(g, "mul", [b_4, c_4], [types.float32]) + e_4 = _apply_op(g, "mul", [e_3, e_3], [types.float32]) + + self.assertItemsEqual([a_1.op], b_1.op.control_inputs) + self.assertItemsEqual([a_1.op, a_2.op], b_2.op.control_inputs) + self.assertItemsEqual([a_1.op, a_2.op], b_3.op.control_inputs) + self.assertItemsEqual([a_1.op, a_2.op], b_4.op.control_inputs) + + self.assertItemsEqual([], c_1.op.control_inputs) + self.assertItemsEqual([a_2.op], c_2.op.control_inputs) + self.assertItemsEqual([a_2.op, a_3.op], c_3.op.control_inputs) + self.assertItemsEqual([a_2.op, a_3.op, a_4.op], c_4.op.control_inputs) + + self.assertItemsEqual([], d_1.op.control_inputs) + self.assertItemsEqual([], d_2.op.control_inputs) + self.assertItemsEqual([], d_3.op.control_inputs) + self.assertItemsEqual([], d_4.op.control_inputs) + + self.assertItemsEqual([a_1.op], e_1.op.control_inputs) + self.assertItemsEqual([a_2.op], e_2.op.control_inputs) + self.assertItemsEqual([a_3.op], e_3.op.control_inputs) + self.assertItemsEqual([a_4.op], e_4.op.control_inputs) + + def testRepeatedDependency(self): + g = ops.Graph() + a = g.create_op("foo", [], [types.float32, types.float32]) + a_0, a_1 = a.outputs + with g.control_dependencies([a_0]): + b = _apply_op(g, "const", [], [types.float32]) + with g.control_dependencies([a_1]): + c = _apply_op(g, "const", [], [types.float32]) + + self.assertEqual(b.op.control_inputs, [a]) + self.assertEqual(c.op.control_inputs, [a]) + + def testNoControlDependencyWithDataDependency(self): + g = ops.Graph() + a = _apply_op(g, "const", [], [types.float32]) + with g.control_dependencies([a]): + b = _apply_op(g, "identity", [a], [types.float32]) + + self.assertEqual(b.op.control_inputs, []) + + +class GraphTest(test_util.TensorFlowTestCase): + + def setUp(self): + ops.reset_default_graph() + + def _AssertDefault(self, expected): + self.assertIs(expected, ops.get_default_graph()) + + def testGraphContextManager(self): + g0 = ops.Graph() + with g0.as_default() as g1: + self.assertIs(g0, g1) + + def testDefaultGraph(self): + orig = ops.get_default_graph() + self._AssertDefault(orig) + g0 = ops.Graph() + self._AssertDefault(orig) + context_manager_0 = g0.as_default() + self._AssertDefault(orig) + with context_manager_0 as g0: + self._AssertDefault(g0) + with ops.Graph().as_default() as g1: + self._AssertDefault(g1) + self._AssertDefault(g0) + self._AssertDefault(orig) + + def testAsGraphElementConversions(self): + class ConvertibleObj(object): + + def _as_graph_element(self): + return "const:0" + + class NonConvertibleObj(object): + + pass + + g = ops.Graph() + a = _apply_op(g, "const", [], [types.float32]) + self.assertEqual(a, g.as_graph_element(ConvertibleObj())) + with self.assertRaises(TypeError): + g.as_graph_element(NonConvertibleObj()) + + def testAssertSameGraph(self): + g0 = ops.Graph() + a = g0.create_op("a", [], [types.float32]) + b = g0.create_op("b", [], [types.float32]) + ops.assert_same_graph([a, b]) + ops.assert_same_graph([a, b], g0) + g1 = ops.Graph() + c = g1.create_op("c", [], [types.float32]) + self.assertRaises(ValueError, ops.assert_same_graph, [a, b, c]) + self.assertRaises(ValueError, ops.assert_same_graph, [c], g0) + self.assertRaises(ValueError, ops.assert_same_graph, [a], g1) + + sparse = ops.SparseTensor( + _apply_op(g0, "const", [], [types.int64]), + _apply_op(g0, "const", [], [types.float32]), + _apply_op(g0, "const", [], [types.int64])) + ops.assert_same_graph([sparse, a, b]) + ops.assert_same_graph([sparse, a, b], g0) + self.assertRaises(ValueError, ops.assert_same_graph, [sparse, a, c]) + self.assertRaises(ValueError, ops.assert_same_graph, [sparse, a, c], g1) + +ops.RegisterShape("KernelLabel")(common_shapes.scalar_shape) + + +class KernelLabelTest(test_util.TensorFlowTestCase): + + def testNoLabel(self): + with self.test_session(): + self.assertAllEqual("My label is: default", + test_kernel_label_op.kernel_label().eval()) + + def testLabelMap(self): + with self.test_session() as sess: + default_1 = test_kernel_label_op.kernel_label() + # pylint: disable=protected-access + with sess.graph._kernel_label_map({"KernelLabel": "overload_1"}): + overload_1_1 = test_kernel_label_op.kernel_label() + with sess.graph._kernel_label_map({"KernelLabel": "overload_2"}): + overload_2 = test_kernel_label_op.kernel_label() + with sess.graph._kernel_label_map({"KernelLabel": ""}): + default_2 = test_kernel_label_op.kernel_label() + overload_1_2 = test_kernel_label_op.kernel_label() + # pylint: enable=protected-access + default_3 = test_kernel_label_op.kernel_label() + + self.assertAllEqual("My label is: default", default_1.eval()) + self.assertAllEqual("My label is: default", default_2.eval()) + self.assertAllEqual("My label is: default", default_3.eval()) + self.assertAllEqual("My label is: overload_1", overload_1_1.eval()) + self.assertAllEqual("My label is: overload_1", overload_1_2.eval()) + self.assertAllEqual("My label is: overload_2", overload_2.eval()) + + +if __name__ == "__main__": + googletest.main() |