aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/ops_test.py')
-rw-r--r--tensorflow/python/framework/ops_test.py108
1 files changed, 81 insertions, 27 deletions
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index 43044b1d39..5831ccd108 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -27,6 +27,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_kernel_label_op
from tensorflow.python.framework import test_util
from tensorflow.python.ops import common_shapes
+from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
@@ -356,19 +357,19 @@ class NameTest(test_util.TensorFlowTestCase):
self.assertEqual("my_op", op2.name)
self.assertEqual("my_op:0", op2.outputs[0].name)
- def testname_scope(self):
+ def testNameScope(self):
g = ops.Graph()
with g.name_scope("foo") as foo:
- self.assertEqual(foo, "foo/")
+ self.assertEqual("foo/", foo)
with g.name_scope("foo2") as foo2:
- self.assertEqual(foo2, "foo/foo2/")
+ self.assertEqual("foo/foo2/", foo2)
with g.name_scope(None) as empty1:
- self.assertEqual(empty1, "")
+ self.assertEqual("", empty1)
with g.name_scope("foo3") as foo3:
- self.assertEqual(foo3, "foo3/")
+ self.assertEqual("foo3/", foo3)
with g.name_scope("") as empty2:
- self.assertEqual(empty2, "")
+ self.assertEqual("", empty2)
self.assertEqual("const",
g.create_op("const", [], [dtypes.float32]).name)
@@ -792,6 +793,80 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase):
self.assertEqual(b.op.control_inputs, [])
+class OpScopeTest(test_util.TensorFlowTestCase):
+
+ def testNoScopeName(self):
+ g0 = ops.Graph()
+ values = [
+ g0.create_op("a", [], [dtypes.float32]),
+ g0.create_op("b", [], [dtypes.float32])]
+ with self.assertRaises(ValueError):
+ with ops.op_scope(values, None):
+ pass
+ with self.assertRaises(ValueError):
+ with ops.op_scope(values, None, None):
+ pass
+
+ def testEmptyScopeName(self):
+ g0 = ops.Graph()
+ a = g0.create_op("a", [], [dtypes.float32])
+ b = g0.create_op("b", [], [dtypes.float32])
+ with ops.op_scope([a, b], "") as scope:
+ self.assertEqual("", scope)
+ self.assertEqual(g0, ops.get_default_graph())
+ with ops.op_scope([a, b], "", "my_default_scope") as scope:
+ self.assertEqual("", scope)
+ self.assertEqual(g0, ops.get_default_graph())
+
+ def testDefaultScopeName(self):
+ g0 = ops.Graph()
+ a = g0.create_op("a", [], [dtypes.float32])
+ b = g0.create_op("b", [], [dtypes.float32])
+ scope_name = "my_scope"
+ default_scope_name = "my_default_scope"
+ with ops.op_scope([a, b], scope_name, default_scope_name) as scope:
+ self.assertEqual("%s/" % scope_name, scope)
+ self.assertEqual(g0, ops.get_default_graph())
+ with ops.op_scope([a, b], None, default_scope_name) as scope:
+ self.assertEqual("%s/" % default_scope_name, scope)
+ self.assertEqual(g0, ops.get_default_graph())
+
+ def _testGraphElements(self, graph_elements):
+ scope_name = "my_scope"
+ with ops.op_scope(graph_elements, scope_name) as scope:
+ self.assertEqual("%s/" % scope_name, scope)
+ self.assertEqual(graph_elements[0].graph, ops.get_default_graph())
+ g1 = ops.Graph()
+ c = g1.create_op("c", [], [dtypes.float32])
+ with self.assertRaises(ValueError):
+ with ops.op_scope(graph_elements + [c], scope_name):
+ pass
+
+ def testTensor(self):
+ g0 = ops.Graph()
+ a = g0.create_op("a", [], [dtypes.float32])
+ b = g0.create_op("b", [], [dtypes.float32])
+ self._testGraphElements([a, b])
+
+ def testSparseTensor(self):
+ g0 = ops.Graph()
+ a = g0.create_op("a", [], [dtypes.float32])
+ b = g0.create_op("b", [], [dtypes.float32])
+ sparse = ops.SparseTensor(
+ _apply_op(g0, "const", [], [dtypes.int64]),
+ _apply_op(g0, "const", [], [dtypes.float32]),
+ _apply_op(g0, "const", [], [dtypes.int64]))
+ self._testGraphElements([a, sparse, b])
+
+ def testVariable(self):
+ g0 = ops.Graph()
+ with g0.as_default():
+ variable = variables.Variable([1.0])
+ a = g0.create_op("a", [], [dtypes.float32])
+ b = g0.create_op("b", [], [dtypes.float32])
+ self._testGraphElements([a, variable, b])
+
+
class GraphTest(test_util.TensorFlowTestCase):
def setUp(self):
@@ -835,27 +910,6 @@ class GraphTest(test_util.TensorFlowTestCase):
with self.assertRaises(TypeError):
g.as_graph_element(NonConvertibleObj())
- def testAssertSameGraph(self):
- g0 = ops.Graph()
- a = g0.create_op("a", [], [dtypes.float32])
- b = g0.create_op("b", [], [dtypes.float32])
- ops.assert_same_graph([a, b])
- ops.assert_same_graph([a, b], g0)
- g1 = ops.Graph()
- c = g1.create_op("c", [], [dtypes.float32])
- self.assertRaises(ValueError, ops.assert_same_graph, [a, b, c])
- self.assertRaises(ValueError, ops.assert_same_graph, [c], g0)
- self.assertRaises(ValueError, ops.assert_same_graph, [a], g1)
-
- sparse = ops.SparseTensor(
- _apply_op(g0, "const", [], [dtypes.int64]),
- _apply_op(g0, "const", [], [dtypes.float32]),
- _apply_op(g0, "const", [], [dtypes.int64]))
- ops.assert_same_graph([sparse, a, b])
- ops.assert_same_graph([sparse, a, b], g0)
- self.assertRaises(ValueError, ops.assert_same_graph, [sparse, a, c])
- self.assertRaises(ValueError, ops.assert_same_graph, [sparse, a, c], g1)
-
ops.RegisterShape("KernelLabel")(common_shapes.scalar_shape)