aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/ops_test.py')
-rw-r--r--tensorflow/python/framework/ops_test.py825
1 files changed, 825 insertions, 0 deletions
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
new file mode 100644
index 0000000000..a406c5e56e
--- /dev/null
+++ b/tensorflow/python/framework/ops_test.py
@@ -0,0 +1,825 @@
+"""Tests for tensorflow.python.framework.ops."""
+import tensorflow.python.platform
+
+from tensorflow.python.framework import device as pydev
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_kernel_label_op
+from tensorflow.python.framework import test_util
+from tensorflow.python.framework import types
+from tensorflow.python.ops import common_shapes
+from tensorflow.python.platform import googletest
+
+
+class TensorTest(test_util.TensorFlowTestCase):
+
+ def testShape(self):
+ op = ops.Operation(ops._NodeDef("noop", "myop"), ops.Graph(),
+ [], [types.float32])
+ t = op.outputs[0]
+ self.assertEquals(tensor_shape.unknown_shape(), t.get_shape())
+ t.set_shape([1, 2, 3])
+ self.assertEquals([1, 2, 3], t.get_shape())
+
+
+class NodeDefConstructorTest(test_util.TensorFlowTestCase):
+
+ def testNoArgs(self):
+ nodedef = ops._NodeDef("noop", "bar")
+ self.assertProtoEquals("op: 'noop' name: 'bar'", nodedef)
+
+ def testArgs(self):
+ nodedef = ops._NodeDef("foo", "bar", device="/device:baz:*")
+ self.assertProtoEquals("op:'foo' name:'bar' device:'/device:baz:*'",
+ nodedef)
+ nodedef = ops._NodeDef("foo", "bar", device=pydev.Device(job="j"))
+ self.assertProtoEquals("op:'foo' name:'bar' device:'/job:j'", nodedef)
+
+
+# NOTE(mrry): Dummy shape registrations for ops used in the tests.
+ops.RegisterShape("a")(None)
+ops.RegisterShape("b")(None)
+ops.RegisterShape("c")(None)
+ops.RegisterShape("add")(None)
+ops.RegisterShape("an_op")(None)
+ops.RegisterShape("const")(None)
+ops.RegisterShape("copy")(None)
+ops.RegisterShape("foo")(None)
+ops.RegisterShape("identity")(None)
+ops.RegisterShape("mul")(None)
+ops.RegisterShape("nonrefop")(None)
+ops.RegisterShape("noop")(None)
+ops.RegisterShape("refop")(None)
+
+
+def _apply_op(g, *args, **kwargs):
+ op = g.create_op(*args, **kwargs)
+ if len(op.outputs) == 1:
+ return op.outputs[0]
+ else:
+ return op.outputs
+
+
+class OperationTest(test_util.TensorFlowTestCase):
+
+ def testNoInputs(self):
+ op = ops.Operation(ops._NodeDef("noop", "myop"), ops.Graph(),
+ [],
+ [types.float32, types.string])
+ self.assertEquals(2, len(op.values()))
+ self.assertEquals(0, len(op.inputs))
+ self.assertEquals("myop", op.name)
+
+ float_t, label_str_t = op.values()
+ self.assertEquals(types.float32, float_t.dtype)
+ self.assertEquals(op, float_t.op)
+ self.assertEquals(0, float_t._value_index)
+ self.assertEquals(0, len(float_t._consumers))
+ self.assertEquals("myop", float_t._as_node_def_input())
+
+ self.assertEquals(types.string, label_str_t.dtype)
+ self.assertEquals(op, label_str_t.op)
+ self.assertEquals(1, label_str_t._value_index)
+ self.assertEquals(0, len(label_str_t._consumers))
+ self.assertEquals("myop:1", label_str_t._as_node_def_input())
+
+ self.assertProtoEquals("op:'noop' name:'myop'", op.node_def)
+
+ def testNoOutputs(self):
+ g = ops.Graph()
+ op1 = ops.Operation(
+ ops._NodeDef("noop", "myop1"), g, [], [types.float32])
+ float_t, = op1.values()
+ op2 = ops.Operation(ops._NodeDef("reop", "myop2"), g, [float_t], [])
+ self.assertEquals(0, len(op2.values()))
+ self.assertEquals(1, len(op2.inputs))
+ self.assertIs(float_t, op2.inputs[0])
+
+ self.assertEquals(1, len(float_t._consumers))
+ self.assertEquals(op2, float_t._consumers[0])
+
+ self.assertProtoEquals("op:'noop' name:'myop1'", op1.node_def)
+ self.assertProtoEquals("op:'reop' name:'myop2' input:'myop1'",
+ op2.node_def)
+
+ def testInputsAndOutputs(self):
+ g = ops.Graph()
+ op1 = ops.Operation(
+ ops._NodeDef("noop", "myop1"), g, [], [types.float32])
+ self.assertEquals(1, len(op1.values()))
+ float1_t, = op1.values()
+
+ op2 = ops.Operation(ops._NodeDef("reop", "myop2"), g,
+ [], [types.float32, types.string])
+ self.assertEquals(2, len(op2.values()))
+ float2_t, label2_str_t = op2.values()
+
+ # Note that we consume label2_str_t twice here.
+ op3 = ops.Operation(ops._NodeDef("add", "myop3"), g,
+ [float1_t, label2_str_t, label2_str_t],
+ [types.float32, types.int32])
+ self.assertEquals(2, len(op3.values()))
+
+ self.assertEquals(1, len(float1_t._consumers))
+ self.assertEquals(op3, float1_t._consumers[0])
+
+ self.assertEquals(0, len(float2_t._consumers))
+
+ self.assertEquals(2, len(label2_str_t._consumers))
+ self.assertEquals(op3, label2_str_t._consumers[0])
+ self.assertEquals(op3, label2_str_t._consumers[1])
+
+ self.assertProtoEquals("""
+ op:'add' name:'myop3'
+ input:'myop1' input:'myop2:1' input:'myop2:1'
+ """, op3.node_def)
+
+ def testDeviceObject(self):
+ op = ops.Operation(ops._NodeDef("noop", "myop"), ops.Graph(), [], [])
+ op._set_device("/job:goo/device:GPU:0")
+ self.assertProtoEquals(
+ "op:'noop' name:'myop' device:'/job:goo/device:GPU:0' ",
+ op.node_def)
+ op = ops.Operation(ops._NodeDef("noop", "op2"), ops.Graph(), [], [])
+ op._set_device(pydev.Device(job="muu", device_type="CPU", device_index=0))
+ self.assertProtoEquals(
+ "op:'noop' name:'op2' device:'/job:muu/device:CPU:0'",
+ op.node_def)
+
+ def testReferenceInput(self):
+ g = ops.Graph()
+ op1 = ops.Operation(ops._NodeDef("noop", "op1"), g, [],
+ [types.float32_ref, types.float32])
+ self.assertProtoEquals("op:'noop' name:'op1'",
+ op1.node_def)
+ ref_t, nonref_t = op1.values()
+ # NOTE(mrry): Must specify input_types to preserve ref-typed input.
+ op2 = ops.Operation(
+ ops._NodeDef("refop", "op2"), g, [ref_t, nonref_t], [],
+ input_types=[types.float32_ref, types.float32])
+ self.assertProtoEquals("op:'refop' name:'op2' input:'op1' input:'op1:1'",
+ op2.node_def)
+ op3 = ops.Operation(
+ ops._NodeDef("nonrefop", "op3"), g, [ref_t, nonref_t], [])
+ self.assertProtoEquals("op:'nonrefop' name:'op3' input:'op1' input:'op1:1'",
+ op3.node_def)
+
+ def testInvalidNames(self):
+ g = ops.Graph()
+ with self.assertRaises(ValueError):
+ ops.Operation(ops._NodeDef("op", ""), g)
+ with self.assertRaises(ValueError):
+ ops.Operation(ops._NodeDef("op", "_invalid"), g)
+ with self.assertRaises(ValueError):
+ ops.Operation(ops._NodeDef("op", "-invalid"), g)
+ with self.assertRaises(ValueError):
+ ops.Operation(ops._NodeDef("op", "/invalid"), g)
+
+ def testShapeFunctionAbsence(self):
+ def _test():
+ pass
+ g = ops.Graph()
+ with self.assertRaises(RuntimeError):
+ g.create_op("shapeless_op", [], [types.float32])
+
+ def testNoShapeFunction(self):
+ g = ops.Graph()
+ op = ops.Operation(ops._NodeDef("op", "an_op"), g,
+ output_types = [types.float32])
+ self.assertEquals(tensor_shape.unknown_shape(),
+ _apply_op(g, "an_op", [], [types.float32]).get_shape())
+
+class CreateOpTest(test_util.TensorFlowTestCase):
+
+ def testNodeDefArgs(self):
+ g = ops.Graph()
+ op1 = g.create_op("const", [], [types.float32], None, name="myop1")
+ with g.device("/device:GPU"):
+ op2 = g.create_op("add",
+ [],
+ [types.float32, types.string], None,
+ name="myop2")
+ op3 = g.create_op(
+ "foo",
+ [op1.values()[0], op2.values()[1], op2.values()[0]],
+ [types.float32, types.int32], None,
+ name="myop3")
+ self.assertEquals(None, op1.device)
+ self.assertEquals("/device:GPU", op2.device)
+ self.assertEquals(None, op3.device)
+ self.assertProtoEquals("name:'myop1' op:'const'", op1.node_def)
+ self.assertProtoEquals("name:'myop2' op:'add' device:'/device:GPU'",
+ op2.node_def)
+ self.assertProtoEquals(
+ "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'foo'",
+ op3.node_def)
+
+ def testReferenceInput(self):
+ g = ops.Graph()
+ op1 = g.create_op("noop", [],
+ [types.float32_ref, types.float32], name="op1")
+ self.assertProtoEquals("op:'noop' name:'op1'", op1.node_def)
+ ref_t, nonref_t = op1.values()
+ # NOTE(mrry): Must specify input_types to preserve ref-typed input.
+ op2 = g.create_op("refop", [ref_t, nonref_t], [],
+ input_types=[types.float32_ref, types.float32],
+ name="op2")
+ self.assertProtoEquals("op:'refop' name:'op2' input:'op1' input:'op1:1'",
+ op2.node_def)
+ op3 = g.create_op("nonrefop", [ref_t, nonref_t], [], name="op3")
+ self.assertProtoEquals("op:'nonrefop' name:'op3' input:'op1' input:'op1:1'",
+ op3.node_def)
+
+ def testFinalized(self):
+ g = ops.Graph()
+ g.finalize()
+ with self.assertRaises(RuntimeError):
+ g.create_op("const", [], [types.float32], None, name="myop1")
+
+
+class ApplyOpTest(test_util.TensorFlowTestCase):
+
+ def testNodeDefArgs(self):
+ g = ops.Graph()
+ t1 = _apply_op(g, "const", [], [types.float32], name="myop1")
+ with g.device("/device:GPU"):
+ t2 = _apply_op(g, "add",
+ [],
+ [types.float32, types.string],
+ name="myop2")
+ t3 = _apply_op(g, "foo", [t1, t2[1], t2[0]],
+ [types.float32, types.int32], name="myop3")
+ self.assertTrue(isinstance(t1, ops.Tensor))
+ self.assertTrue(isinstance(t2, list))
+ self.assertTrue(isinstance(t3, list))
+ self.assertTrue(isinstance(t3[0], ops.Tensor))
+ self.assertEquals("myop1", t1._as_node_def_input())
+ self.assertEquals("myop2", t2[0]._as_node_def_input())
+ self.assertEquals("myop2:1", t2[1]._as_node_def_input())
+ self.assertEquals("myop3", t3[0]._as_node_def_input())
+ # Validate that we got the right ops as well
+ self.assertProtoEquals("name:'myop1' op:'const'", t1.op.node_def)
+ self.assertProtoEquals("name:'myop2' op:'add' device:'/device:GPU'",
+ t2[0].op.node_def)
+ self.assertProtoEquals(
+ "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'foo'",
+ t3[0].op.node_def)
+
+ def testReferenceInput(self):
+ g = ops.Graph()
+ ref_t, nonref_t = _apply_op(
+ g, "noop", [], [types.float32_ref, types.float32], name="op1")
+ self.assertProtoEquals("op:'noop' name:'op1'", ref_t.op.node_def)
+ # NOTE(mrry): Must specify input_types to preserve ref-typed input.
+ out_2 = _apply_op(g, "refop", [ref_t, nonref_t], [types.int32],
+ input_types=[types.float32_ref, types.float32],
+ name="op2")
+ self.assertProtoEquals("op:'refop' name:'op2' input:'op1' input:'op1:1'",
+ out_2.op.node_def)
+ out_3 = _apply_op(g, "nonrefop", [ref_t, nonref_t], [types.int32],
+ name="op3")
+ self.assertProtoEquals("op:'nonrefop' name:'op3' input:'op1' input:'op1:1'",
+ out_3.op.node_def)
+
+
+class NameStackTest(test_util.TensorFlowTestCase):
+
+ def testBasics(self):
+ g = ops.Graph()
+ self.assertEquals("foo", g.unique_name("foo"))
+ self.assertEquals("foo_1", g.unique_name("foo"))
+ self.assertEquals("foo_2", g.unique_name("foo"))
+ self.assertEquals("foo_1_1", g.unique_name("foo_1"))
+ self.assertEquals("foo_1_2", g.unique_name("foo_1"))
+ self.assertEquals("foo_1_2_1", g.unique_name("foo_1_2"))
+ with g.name_scope("bar"):
+ self.assertEquals("bar/foo", g.unique_name("foo"))
+ self.assertEquals("bar/foo_1", g.unique_name("foo"))
+ with g.name_scope(None):
+ self.assertEquals("foo_3", g.unique_name("foo"))
+ with g.name_scope("baz"):
+ self.assertEquals("bar/baz/foo", g.unique_name("foo"))
+ self.assertEquals("bar/baz/foo_1", g.unique_name("foo"))
+ with g.name_scope("baz"):
+ self.assertEquals("bar/baz_1/foo", g.unique_name("foo"))
+ self.assertEquals("bar/baz_1/foo_1", g.unique_name("foo"))
+ with g.name_scope("quux"):
+ self.assertEquals("quux/foo", g.unique_name("foo"))
+ with g.name_scope("bar"):
+ with g.name_scope("baz"):
+ self.assertEquals("bar_1/baz/foo", g.unique_name("foo"))
+ self.assertEquals("foo_4", g.unique_name("foo"))
+ self.assertEquals("bar_2", g.unique_name("bar"))
+
+ def testOutOfOrderUniqueName(self):
+ g = ops.Graph()
+ self.assertEquals("foo_2", g.unique_name("foo_2"))
+ self.assertEquals("foo", g.unique_name("foo"))
+ self.assertEquals("foo_1", g.unique_name("foo"))
+ self.assertEquals("foo_3", g.unique_name("foo"))
+
+
+class NameTest(test_util.TensorFlowTestCase):
+
+ def testGenerateName(self):
+ g = ops.Graph()
+ op0 = g.create_op("const", [], [types.float32, types.float32])
+ self.assertEquals("const", op0.name)
+ self.assertEquals("const:0", op0.outputs[0].name)
+ self.assertEquals("const:1", op0.outputs[1].name)
+
+ op1 = g.create_op("const", [], [types.float32])
+ self.assertEquals("const_1", op1.name)
+ self.assertEquals("const_1:0", op1.outputs[0].name)
+
+ op2 = g.create_op("const", [], [types.float32], name="my_op")
+ self.assertEquals("my_op", op2.name)
+ self.assertEquals("my_op:0", op2.outputs[0].name)
+
+ def testname_scope(self):
+ g = ops.Graph()
+
+ with g.name_scope("foo") as foo:
+ self.assertEquals(foo, "foo/")
+ with g.name_scope("foo2") as foo2:
+ self.assertEquals(foo2, "foo/foo2/")
+ with g.name_scope(None) as empty1:
+ self.assertEquals(empty1, "")
+ with g.name_scope("foo3") as foo3:
+ self.assertEquals(foo3, "foo3/")
+ with g.name_scope("") as empty2:
+ self.assertEquals(empty2, "")
+
+ self.assertEquals("const",
+ g.create_op("const", [], [types.float32]).name)
+ with g.name_scope("bar") as scope:
+ self.assertEquals("bar/const",
+ g.create_op("const", [], [types.float32]).name)
+ self.assertEquals("bar/const_1",
+ g.create_op("const", [], [types.float32]).name)
+ # If you use the value from "with .. as", that values is used as-is.
+ self.assertEquals(
+ "bar",
+ g.create_op("const", [], [types.float32], name=scope).name)
+ with g.name_scope("baz") as scope:
+ with g.name_scope("quux"):
+ self.assertEquals("baz/quux/const",
+ g.create_op("const", [], [types.float32]).name)
+ # If you use the value from the enclosing "with .. as", nothing is pushed.
+ with g.name_scope(scope):
+ self.assertEquals("baz/const",
+ g.create_op("const", [], [types.float32]).name)
+ self.assertEquals("baz",
+ g.create_op("const", [], [types.float32],
+ name=scope).name)
+ self.assertEquals("trailing",
+ g.create_op("const", [], [types.float32],
+ name="trailing/").name)
+ with g.name_scope("bar"):
+ self.assertEquals("bar_1/const",
+ g.create_op("const", [], [types.float32]).name)
+ with g.name_scope("bar/"):
+ self.assertEquals("bar/const_2",
+ g.create_op("const", [], [types.float32]).name)
+
+
+class DeviceTest(test_util.TensorFlowTestCase):
+
+ def testNoDevice(self):
+ g = ops.Graph()
+ op = g.create_op("an_op", [], [types.float32])
+ self.assertEqual(None, op.device)
+ gd = g.as_graph_def()
+ self.assertProtoEquals("""
+ node { name: "an_op" op: "an_op" }
+ """, gd)
+
+ def testDevicePartialString(self):
+ g = ops.Graph()
+ with g.device("/job:worker/replica:2"):
+ g.create_op("an_op", [], [types.float32])
+ gd = g.as_graph_def()
+ self.assertProtoEquals("""
+ node { name: "an_op" op: "an_op" device: "/job:worker/replica:2" }
+ """, gd)
+
+ def testDeviceFull(self):
+ g = ops.Graph()
+ with g.device(pydev.Device(job="worker", replica=2, task=0,
+ device_type="CPU",
+ device_index=3)):
+ g.create_op("an_op", [], [types.float32])
+ gd = g.as_graph_def()
+ self.assertProtoEquals("""
+ node { name: "an_op" op: "an_op"
+ device: "/job:worker/replica:2/task:0/device:CPU:3" }
+ """, gd)
+
+ def testNesting(self):
+ g = ops.Graph()
+ with g.device("/job:worker/replica:2"):
+ g.create_op("an_op", [], [types.float32])
+ with g.device("/job:worker/replica:3/task:0"):
+ g.create_op("an_op", [], [types.float32])
+ g.create_op("an_op", [], [types.float32])
+ gd = g.as_graph_def()
+ self.assertProtoEquals("""
+ node { name: "an_op" op: "an_op"
+ device: "/job:worker/replica:2" }
+ node { name: "an_op_1" op: "an_op"
+ device: "/job:worker/replica:3/task:0" }
+ node { name: "an_op_2" op: "an_op"
+ device: "/job:worker/replica:2" }
+ """, gd)
+
+ def testNestingString(self):
+ g = ops.Graph()
+ with g.device("/job:worker/replica:2"):
+ g.create_op("an_op", [], [types.float32])
+ with g.device("/job:worker/replica:3/task:0"):
+ g.create_op("an_op", [], [types.float32])
+ g.create_op("an_op", [], [types.float32])
+ gd = g.as_graph_def()
+ self.assertProtoEquals("""
+ node { name: "an_op" op: "an_op"
+ device: "/job:worker/replica:2" }
+ node { name: "an_op_1" op: "an_op"
+ device: "/job:worker/replica:3/task:0" }
+ node { name: "an_op_2" op: "an_op"
+ device: "/job:worker/replica:2" }
+ """, gd)
+
+ def testNestingOverrideGpuCpu(self):
+ g = ops.Graph()
+ with g.device("/job:worker/replica:2/device:CPU:1"):
+ g.create_op("an_op", [], [types.float32])
+ with g.device("/job:worker/replica:2/device:GPU:2"):
+ g.create_op("an_op", [], [types.float32])
+ g.create_op("an_op", [], [types.float32])
+ gd = g.as_graph_def()
+ self.assertProtoEquals("""
+ node { name: "an_op" op: "an_op"
+ device: "/job:worker/replica:2/device:CPU:1" }
+ node { name: "an_op_1" op: "an_op"
+ device: "/job:worker/replica:2/device:GPU:2" }
+ node { name: "an_op_2" op: "an_op"
+ device: "/job:worker/replica:2/device:CPU:1" }
+ """, gd)
+
+ def testNestingWithMergeDeviceFunction(self):
+ g = ops.Graph()
+
+ with g.device(pydev.merge_device("/device:GPU:0")):
+ g.create_op("an_op", [], [types.float32])
+ with g.device(pydev.merge_device("/job:worker")):
+ g.create_op("an_op", [], [types.float32])
+ with g.device(pydev.merge_device("/device:CPU:0")):
+ g.create_op("an_op", [], [types.float32])
+ with g.device(pydev.merge_device("/job:ps")):
+ g.create_op("an_op", [], [types.float32])
+ with g.device(pydev.merge_device(None)):
+ g.create_op("an_op", [], [types.float32])
+
+ gd = g.as_graph_def()
+ self.assertProtoEquals("""
+ node { name: "an_op" op: "an_op"
+ device: "/device:GPU:0" }
+ node { name: "an_op_1" op: "an_op"
+ device: "/job:worker/device:GPU:0" }
+ node { name: "an_op_2" op: "an_op"
+ device: "/job:worker/device:CPU:0" }
+ node { name: "an_op_3" op: "an_op"
+ device: "/job:ps/device:CPU:0" }
+ node { name: "an_op_4" op: "an_op"
+ device: "/job:ps/device:CPU:0" }
+ """, gd)
+
+ def testNoneClearsDefault(self):
+ g = ops.Graph()
+ with g.device("/job:worker/replica:2/device:CPU:1"):
+ g.create_op("an_op", [], [types.float32])
+ with g.device(None):
+ g.create_op("an_op", [], [types.float32])
+ g.create_op("an_op", [], [types.float32])
+ gd = g.as_graph_def()
+ self.assertProtoEquals("""
+ node { name: "an_op" op: "an_op"
+ device: "/job:worker/replica:2/device:CPU:1" }
+ node { name: "an_op_1" op: "an_op" }
+ node { name: "an_op_2" op: "an_op"
+ device: "/job:worker/replica:2/device:CPU:1" }
+ """, gd)
+
+
+class ObjectWithName(object):
+
+ def __init__(self, name):
+ self._name = name
+
+ @property
+ def name(self):
+ return self._name
+
+
+class CollectionTest(test_util.TensorFlowTestCase):
+
+ def testadd_to_collection(self):
+ g = ops.Graph()
+ g.add_to_collection("key", 12)
+ g.add_to_collection("other", "foo")
+ g.add_to_collection("key", 34)
+
+ # Note that only blank1 is returned.
+ g.add_to_collection("blah", 27)
+ blank1 = ObjectWithName("prefix/foo")
+ g.add_to_collection("blah", blank1)
+ blank2 = ObjectWithName("junk/foo")
+ g.add_to_collection("blah", blank2)
+
+ self.assertEquals(["foo"], g.get_collection("other"))
+ self.assertEquals([12, 34], g.get_collection("key"))
+ self.assertEquals([], g.get_collection("nothing"))
+ self.assertEquals([27, blank1, blank2], g.get_collection("blah"))
+ self.assertEquals([blank1], g.get_collection("blah", "prefix"))
+
+ def testDefaulGraph(self):
+ with ops.Graph().as_default():
+ ops.add_to_collection("key", 90)
+ ops.add_to_collection("key", 100)
+ # Collections are ordered.
+ self.assertEquals([90, 100], ops.get_collection("key"))
+
+
+def an_op(g):
+ return _apply_op(g, "an_op", [], [types.float32])
+
+
+ops.NoGradient("an_op")
+
+
+def copy_op(x):
+ return _apply_op(x.graph, "copy", [x], [x.dtype])
+
+
+@ops.RegisterGradient("copy")
+def _CopyGrad(op, x_grad):
+ _ = op
+ return x_grad
+
+
+@ops.RegisterGradient("copy_override")
+def _CopyOverrideGrad(op, x_grad):
+ _ = op
+ return x_grad
+
+
+class RegistrationTest(test_util.TensorFlowTestCase):
+
+ def testRegisterGradients(self):
+ g = ops.Graph()
+ x = an_op(g)
+ y = copy_op(x)
+ fn = ops.get_gradient_function(y.op)
+ self.assertEquals(_CopyGrad, fn)
+
+ def testOverrideGradients(self):
+ g = ops.Graph()
+ x = an_op(g)
+ with g.gradient_override_map({"copy": "copy_override"}):
+ y = copy_op(x)
+ fn = ops.get_gradient_function(y.op)
+ self.assertEquals(_CopyOverrideGrad, fn)
+
+ def testNonExistentOverride(self):
+ g = ops.Graph()
+ x = an_op(g)
+ with g.gradient_override_map({"copy": "unknown_override"}):
+ y = copy_op(x)
+ with self.assertRaisesRegexp(LookupError, "unknown_override"):
+ fn = ops.get_gradient_function(y.op)
+
+
+class ComparisonTest(test_util.TensorFlowTestCase):
+
+ def testMembershipAllowed(self):
+ g = ops.Graph()
+ t1 = _apply_op(g, "const", [], [types.float32], name="myop1")
+ t2 = _apply_op(g, "const", [], [types.float32], name="myop2")
+ self.assertTrue(isinstance(t1, ops.Tensor))
+ self.assertTrue(isinstance(t2, ops.Tensor))
+ self.assertTrue(t1 in [t1])
+ self.assertTrue(t1 not in [t2])
+
+
+class ControlDependenciesTest(test_util.TensorFlowTestCase):
+
+ def testBasic(self):
+ g = ops.Graph()
+ a = _apply_op(g, "const", [], [types.float32])
+ b = _apply_op(g, "const", [], [types.float32])
+ with g.control_dependencies([a]):
+ c = _apply_op(g, "const", [], [types.float32])
+ d = _apply_op(g, "identity", [b], [types.float32])
+ e = _apply_op(g, "identity", [c], [types.float32])
+
+ self.assertEqual(c.op.control_inputs, [a.op])
+ self.assertEqual(d.op.control_inputs, [a.op])
+ # e should be dominated by c.
+ self.assertEqual(e.op.control_inputs, [])
+
+ def testNested(self):
+ g = ops.Graph()
+ a_1 = _apply_op(g, "const", [], [types.float32])
+ a_2 = _apply_op(g, "const", [], [types.float32])
+ a_3 = _apply_op(g, "const", [], [types.float32])
+ a_4 = _apply_op(g, "const", [], [types.float32])
+
+ with g.control_dependencies([a_1, a_2, a_3, a_4]):
+ b_1 = _apply_op(g, "const", [], [types.float32])
+
+ with g.control_dependencies([a_1]):
+ with g.control_dependencies([a_2]):
+ with g.control_dependencies([a_3]):
+ with g.control_dependencies([a_4]):
+ b_2 = _apply_op(g, "const", [], [types.float32])
+
+ self.assertItemsEqual(
+ [a_1.op, a_2.op, a_3.op, a_4.op], b_1.op.control_inputs)
+ self.assertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs)
+
+ def testComplex(self):
+ g = ops.Graph()
+
+ # Usage pattern:
+ # * Nodes a_i are constants defined at the outermost scope, and are used
+ # as control inputs for the ith nested scope.
+ # * Nodes b_i are defined as Mul(a_3, a_4) at each scope.
+ # * Nodes c_i are defined as Mul(a_1, b_1) at each scope.
+ # * Nodes d_i are defined as Mul(b_i, c_i) at each scope.
+ # * Nodes e_i are defined as Mul(e_i-1, e_i-1) at each scope i > 1.
+
+ a_1 = _apply_op(g, "const", [], [types.float32])
+ a_2 = _apply_op(g, "const", [], [types.float32])
+ a_3 = _apply_op(g, "const", [], [types.float32])
+ a_4 = _apply_op(g, "const", [], [types.float32])
+
+ with g.control_dependencies([a_1]):
+ b_1 = _apply_op(g, "mul", [a_3, a_4], [types.float32])
+ c_1 = _apply_op(g, "mul", [a_1, b_1], [types.float32])
+ d_1 = _apply_op(g, "mul", [b_1, c_1], [types.float32])
+ e_1 = _apply_op(g, "const", [], [types.float32])
+ with g.control_dependencies([a_2]):
+ b_2 = _apply_op(g, "mul", [a_3, a_4], [types.float32])
+ c_2 = _apply_op(g, "mul", [a_1, b_1], [types.float32])
+ d_2 = _apply_op(g, "mul", [b_2, c_2], [types.float32])
+ e_2 = _apply_op(g, "mul", [e_1, e_1], [types.float32])
+ with g.control_dependencies([a_3]):
+ b_3 = _apply_op(g, "mul", [a_3, a_4], [types.float32])
+ c_3 = _apply_op(g, "mul", [a_1, b_1], [types.float32])
+ d_3 = _apply_op(g, "mul", [b_3, c_3], [types.float32])
+ e_3 = _apply_op(g, "mul", [e_2, e_2], [types.float32])
+ with g.control_dependencies([a_4]):
+ b_4 = _apply_op(g, "mul", [a_3, a_4], [types.float32])
+ c_4 = _apply_op(g, "mul", [a_1, b_1], [types.float32])
+ d_4 = _apply_op(g, "mul", [b_4, c_4], [types.float32])
+ e_4 = _apply_op(g, "mul", [e_3, e_3], [types.float32])
+
+ self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
+ self.assertItemsEqual([a_1.op, a_2.op], b_2.op.control_inputs)
+ self.assertItemsEqual([a_1.op, a_2.op], b_3.op.control_inputs)
+ self.assertItemsEqual([a_1.op, a_2.op], b_4.op.control_inputs)
+
+ self.assertItemsEqual([], c_1.op.control_inputs)
+ self.assertItemsEqual([a_2.op], c_2.op.control_inputs)
+ self.assertItemsEqual([a_2.op, a_3.op], c_3.op.control_inputs)
+ self.assertItemsEqual([a_2.op, a_3.op, a_4.op], c_4.op.control_inputs)
+
+ self.assertItemsEqual([], d_1.op.control_inputs)
+ self.assertItemsEqual([], d_2.op.control_inputs)
+ self.assertItemsEqual([], d_3.op.control_inputs)
+ self.assertItemsEqual([], d_4.op.control_inputs)
+
+ self.assertItemsEqual([a_1.op], e_1.op.control_inputs)
+ self.assertItemsEqual([a_2.op], e_2.op.control_inputs)
+ self.assertItemsEqual([a_3.op], e_3.op.control_inputs)
+ self.assertItemsEqual([a_4.op], e_4.op.control_inputs)
+
+ def testRepeatedDependency(self):
+ g = ops.Graph()
+ a = g.create_op("foo", [], [types.float32, types.float32])
+ a_0, a_1 = a.outputs
+ with g.control_dependencies([a_0]):
+ b = _apply_op(g, "const", [], [types.float32])
+ with g.control_dependencies([a_1]):
+ c = _apply_op(g, "const", [], [types.float32])
+
+ self.assertEqual(b.op.control_inputs, [a])
+ self.assertEqual(c.op.control_inputs, [a])
+
+ def testNoControlDependencyWithDataDependency(self):
+ g = ops.Graph()
+ a = _apply_op(g, "const", [], [types.float32])
+ with g.control_dependencies([a]):
+ b = _apply_op(g, "identity", [a], [types.float32])
+
+ self.assertEqual(b.op.control_inputs, [])
+
+
+class GraphTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ ops.reset_default_graph()
+
+ def _AssertDefault(self, expected):
+ self.assertIs(expected, ops.get_default_graph())
+
+ def testGraphContextManager(self):
+ g0 = ops.Graph()
+ with g0.as_default() as g1:
+ self.assertIs(g0, g1)
+
+ def testDefaultGraph(self):
+ orig = ops.get_default_graph()
+ self._AssertDefault(orig)
+ g0 = ops.Graph()
+ self._AssertDefault(orig)
+ context_manager_0 = g0.as_default()
+ self._AssertDefault(orig)
+ with context_manager_0 as g0:
+ self._AssertDefault(g0)
+ with ops.Graph().as_default() as g1:
+ self._AssertDefault(g1)
+ self._AssertDefault(g0)
+ self._AssertDefault(orig)
+
+ def testAsGraphElementConversions(self):
+ class ConvertibleObj(object):
+
+ def _as_graph_element(self):
+ return "const:0"
+
+ class NonConvertibleObj(object):
+
+ pass
+
+ g = ops.Graph()
+ a = _apply_op(g, "const", [], [types.float32])
+ self.assertEqual(a, g.as_graph_element(ConvertibleObj()))
+ with self.assertRaises(TypeError):
+ g.as_graph_element(NonConvertibleObj())
+
+ def testAssertSameGraph(self):
+ g0 = ops.Graph()
+ a = g0.create_op("a", [], [types.float32])
+ b = g0.create_op("b", [], [types.float32])
+ ops.assert_same_graph([a, b])
+ ops.assert_same_graph([a, b], g0)
+ g1 = ops.Graph()
+ c = g1.create_op("c", [], [types.float32])
+ self.assertRaises(ValueError, ops.assert_same_graph, [a, b, c])
+ self.assertRaises(ValueError, ops.assert_same_graph, [c], g0)
+ self.assertRaises(ValueError, ops.assert_same_graph, [a], g1)
+
+ sparse = ops.SparseTensor(
+ _apply_op(g0, "const", [], [types.int64]),
+ _apply_op(g0, "const", [], [types.float32]),
+ _apply_op(g0, "const", [], [types.int64]))
+ ops.assert_same_graph([sparse, a, b])
+ ops.assert_same_graph([sparse, a, b], g0)
+ self.assertRaises(ValueError, ops.assert_same_graph, [sparse, a, c])
+ self.assertRaises(ValueError, ops.assert_same_graph, [sparse, a, c], g1)
+
+ops.RegisterShape("KernelLabel")(common_shapes.scalar_shape)
+
+
+class KernelLabelTest(test_util.TensorFlowTestCase):
+
+ def testNoLabel(self):
+ with self.test_session():
+ self.assertAllEqual("My label is: default",
+ test_kernel_label_op.kernel_label().eval())
+
+ def testLabelMap(self):
+ with self.test_session() as sess:
+ default_1 = test_kernel_label_op.kernel_label()
+ # pylint: disable=protected-access
+ with sess.graph._kernel_label_map({"KernelLabel": "overload_1"}):
+ overload_1_1 = test_kernel_label_op.kernel_label()
+ with sess.graph._kernel_label_map({"KernelLabel": "overload_2"}):
+ overload_2 = test_kernel_label_op.kernel_label()
+ with sess.graph._kernel_label_map({"KernelLabel": ""}):
+ default_2 = test_kernel_label_op.kernel_label()
+ overload_1_2 = test_kernel_label_op.kernel_label()
+ # pylint: enable=protected-access
+ default_3 = test_kernel_label_op.kernel_label()
+
+ self.assertAllEqual("My label is: default", default_1.eval())
+ self.assertAllEqual("My label is: default", default_2.eval())
+ self.assertAllEqual("My label is: default", default_3.eval())
+ self.assertAllEqual("My label is: overload_1", overload_1_1.eval())
+ self.assertAllEqual("My label is: overload_1", overload_1_2.eval())
+ self.assertAllEqual("My label is: overload_2", overload_2.eval())
+
+
+if __name__ == "__main__":
+ googletest.main()