diff options
Diffstat (limited to 'tensorflow/python/framework/ops_test.py')
-rw-r--r-- | tensorflow/python/framework/ops_test.py | 108 |
1 files changed, 81 insertions, 27 deletions
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index 43044b1d39..5831ccd108 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_kernel_label_op from tensorflow.python.framework import test_util from tensorflow.python.ops import common_shapes +from tensorflow.python.ops import variables from tensorflow.python.platform import googletest @@ -356,19 +357,19 @@ class NameTest(test_util.TensorFlowTestCase): self.assertEqual("my_op", op2.name) self.assertEqual("my_op:0", op2.outputs[0].name) - def testname_scope(self): + def testNameScope(self): g = ops.Graph() with g.name_scope("foo") as foo: - self.assertEqual(foo, "foo/") + self.assertEqual("foo/", foo) with g.name_scope("foo2") as foo2: - self.assertEqual(foo2, "foo/foo2/") + self.assertEqual("foo/foo2/", foo2) with g.name_scope(None) as empty1: - self.assertEqual(empty1, "") + self.assertEqual("", empty1) with g.name_scope("foo3") as foo3: - self.assertEqual(foo3, "foo3/") + self.assertEqual("foo3/", foo3) with g.name_scope("") as empty2: - self.assertEqual(empty2, "") + self.assertEqual("", empty2) self.assertEqual("const", g.create_op("const", [], [dtypes.float32]).name) @@ -792,6 +793,80 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase): self.assertEqual(b.op.control_inputs, []) +class OpScopeTest(test_util.TensorFlowTestCase): + + def testNoScopeName(self): + g0 = ops.Graph() + values = [ + g0.create_op("a", [], [dtypes.float32]), + g0.create_op("b", [], [dtypes.float32])] + with self.assertRaises(ValueError): + with ops.op_scope(values, None): + pass + with self.assertRaises(ValueError): + with ops.op_scope(values, None, None): + pass + + def testEmptyScopeName(self): + g0 = ops.Graph() + a = g0.create_op("a", [], [dtypes.float32]) + b = g0.create_op("b", [], [dtypes.float32]) + with ops.op_scope([a, b], "") as scope: + self.assertEqual("", scope) + self.assertEqual(g0, ops.get_default_graph()) + with ops.op_scope([a, b], "", "my_default_scope") as scope: + self.assertEqual("", scope) + self.assertEqual(g0, ops.get_default_graph()) + + def testDefaultScopeName(self): + g0 = ops.Graph() + a = g0.create_op("a", [], [dtypes.float32]) + b = g0.create_op("b", [], [dtypes.float32]) + scope_name = "my_scope" + default_scope_name = "my_default_scope" + with ops.op_scope([a, b], scope_name, default_scope_name) as scope: + self.assertEqual("%s/" % scope_name, scope) + self.assertEqual(g0, ops.get_default_graph()) + with ops.op_scope([a, b], None, default_scope_name) as scope: + self.assertEqual("%s/" % default_scope_name, scope) + self.assertEqual(g0, ops.get_default_graph()) + + def _testGraphElements(self, graph_elements): + scope_name = "my_scope" + with ops.op_scope(graph_elements, scope_name) as scope: + self.assertEqual("%s/" % scope_name, scope) + self.assertEqual(graph_elements[0].graph, ops.get_default_graph()) + g1 = ops.Graph() + c = g1.create_op("c", [], [dtypes.float32]) + with self.assertRaises(ValueError): + with ops.op_scope(graph_elements + [c], scope_name): + pass + + def testTensor(self): + g0 = ops.Graph() + a = g0.create_op("a", [], [dtypes.float32]) + b = g0.create_op("b", [], [dtypes.float32]) + self._testGraphElements([a, b]) + + def testSparseTensor(self): + g0 = ops.Graph() + a = g0.create_op("a", [], [dtypes.float32]) + b = g0.create_op("b", [], [dtypes.float32]) + sparse = ops.SparseTensor( + _apply_op(g0, "const", [], [dtypes.int64]), + _apply_op(g0, "const", [], [dtypes.float32]), + _apply_op(g0, "const", [], [dtypes.int64])) + self._testGraphElements([a, sparse, b]) + + def testVariable(self): + g0 = ops.Graph() + with g0.as_default(): + variable = variables.Variable([1.0]) + a = g0.create_op("a", [], [dtypes.float32]) + b = g0.create_op("b", [], [dtypes.float32]) + self._testGraphElements([a, variable, b]) + + class GraphTest(test_util.TensorFlowTestCase): def setUp(self): @@ -835,27 +910,6 @@ class GraphTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError): g.as_graph_element(NonConvertibleObj()) - def testAssertSameGraph(self): - g0 = ops.Graph() - a = g0.create_op("a", [], [dtypes.float32]) - b = g0.create_op("b", [], [dtypes.float32]) - ops.assert_same_graph([a, b]) - ops.assert_same_graph([a, b], g0) - g1 = ops.Graph() - c = g1.create_op("c", [], [dtypes.float32]) - self.assertRaises(ValueError, ops.assert_same_graph, [a, b, c]) - self.assertRaises(ValueError, ops.assert_same_graph, [c], g0) - self.assertRaises(ValueError, ops.assert_same_graph, [a], g1) - - sparse = ops.SparseTensor( - _apply_op(g0, "const", [], [dtypes.int64]), - _apply_op(g0, "const", [], [dtypes.float32]), - _apply_op(g0, "const", [], [dtypes.int64])) - ops.assert_same_graph([sparse, a, b]) - ops.assert_same_graph([sparse, a, b], g0) - self.assertRaises(ValueError, ops.assert_same_graph, [sparse, a, c]) - self.assertRaises(ValueError, ops.assert_same_graph, [sparse, a, c], g1) - ops.RegisterShape("KernelLabel")(common_shapes.scalar_shape) |