diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/variable_scope_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/variable_scope_test.py | 148 |
1 files changed, 75 insertions, 73 deletions
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py index cdac12f05a..27c3fe6375 100644 --- a/tensorflow/python/kernel_tests/variable_scope_test.py +++ b/tensorflow/python/kernel_tests/variable_scope_test.py @@ -737,11 +737,12 @@ class VariableScopeTest(test.TestCase): # Since variable is local, it should be in the local variable collection # but not the trainable collection. - self.assertIn(local_var, - ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES)) - self.assertIn(local_var, ops.get_collection("foo")) - self.assertNotIn(local_var, - ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)) + if context.in_graph_mode(): + self.assertIn(local_var, + ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES)) + self.assertIn(local_var, ops.get_collection("foo")) + self.assertNotIn(local_var, + ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)) # Check that local variable respects `reuse`. with variable_scope.variable_scope(outer, "default", reuse=True): @@ -765,93 +766,94 @@ class VariableScopeTest(test.TestCase): self.assertEqual(varname_type[0], ("x", dtypes.float32)) self.assertEqual(varname_type[1], ("y", dtypes.int64)) - @test_util.run_in_graph_and_eager_modes() def testGetCollection(self): - _ = variable_scope.get_variable("testGetCollection_a", []) - _ = variable_scope.get_variable("testGetCollection_b", [], trainable=False) - with variable_scope.variable_scope("testGetCollection_foo_") as scope1: + with self.test_session(): _ = variable_scope.get_variable("testGetCollection_a", []) _ = variable_scope.get_variable("testGetCollection_b", [], trainable=False) + with variable_scope.variable_scope("testGetCollection_foo_") as scope1: + _ = variable_scope.get_variable("testGetCollection_a", []) + _ = variable_scope.get_variable("testGetCollection_b", [], + trainable=False) + self.assertEqual([ + v.name + for v in scope1.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + ], ["testGetCollection_foo_/testGetCollection_a:0"]) + self.assertEqual([ + v.name + for v in scope1.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + ], [ + "testGetCollection_foo_/testGetCollection_a:0", + "testGetCollection_foo_/testGetCollection_b:0" + ]) + with variable_scope.variable_scope("testGetCollection_foo") as scope2: + _ = variable_scope.get_variable("testGetCollection_a", []) + _ = variable_scope.get_variable("testGetCollection_b", [], + trainable=False) + self.assertEqual([ + v.name + for v in scope2.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + ], ["testGetCollection_foo/testGetCollection_a:0"]) + self.assertEqual([ + v.name + for v in scope2.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + ], [ + "testGetCollection_foo/testGetCollection_a:0", + "testGetCollection_foo/testGetCollection_b:0" + ]) + scope = variable_scope.get_variable_scope() self.assertEqual([ - v.name - for v in scope1.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) - ], ["testGetCollection_foo_/testGetCollection_a:0"]) - self.assertEqual([ - v.name - for v in scope1.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + v.name for v in scope.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) ], [ + "testGetCollection_a:0", "testGetCollection_b:0", "testGetCollection_foo_/testGetCollection_a:0", - "testGetCollection_foo_/testGetCollection_b:0" + "testGetCollection_foo_/testGetCollection_b:0", + "testGetCollection_foo/testGetCollection_a:0", + "testGetCollection_foo/testGetCollection_b:0" ]) - with variable_scope.variable_scope("testGetCollection_foo") as scope2: - _ = variable_scope.get_variable("testGetCollection_a", []) - _ = variable_scope.get_variable("testGetCollection_b", [], - trainable=False) self.assertEqual([ v.name - for v in scope2.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) - ], ["testGetCollection_foo/testGetCollection_a:0"]) - self.assertEqual([ - v.name - for v in scope2.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + for v in scope.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) ], [ - "testGetCollection_foo/testGetCollection_a:0", - "testGetCollection_foo/testGetCollection_b:0" + "testGetCollection_a:0", + "testGetCollection_foo_/testGetCollection_a:0", + "testGetCollection_foo/testGetCollection_a:0" ]) - scope = variable_scope.get_variable_scope() - self.assertEqual([ - v.name for v in scope.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - ], [ - "testGetCollection_a:0", "testGetCollection_b:0", - "testGetCollection_foo_/testGetCollection_a:0", - "testGetCollection_foo_/testGetCollection_b:0", - "testGetCollection_foo/testGetCollection_a:0", - "testGetCollection_foo/testGetCollection_b:0" - ]) - self.assertEqual([ - v.name - for v in scope.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) - ], [ - "testGetCollection_a:0", - "testGetCollection_foo_/testGetCollection_a:0", - "testGetCollection_foo/testGetCollection_a:0" - ]) - @test_util.run_in_graph_and_eager_modes() def testGetTrainableVariables(self): - _ = variable_scope.get_variable("testGetTrainableVariables_a", []) - with variable_scope.variable_scope( - "testGetTrainableVariables_foo") as scope: - _ = variable_scope.get_variable("testGetTrainableVariables_b", []) - _ = variable_scope.get_variable("testGetTrainableVariables_c", [], - trainable=False) - self.assertEqual([v.name - for v in scope.trainable_variables()], - ["testGetTrainableVariables_foo/" - "testGetTrainableVariables_b:0"]) + with self.test_session(): + _ = variable_scope.get_variable("testGetTrainableVariables_a", []) + with variable_scope.variable_scope( + "testGetTrainableVariables_foo") as scope: + _ = variable_scope.get_variable("testGetTrainableVariables_b", []) + _ = variable_scope.get_variable("testGetTrainableVariables_c", [], + trainable=False) + self.assertEqual([v.name + for v in scope.trainable_variables()], + ["testGetTrainableVariables_foo/" + "testGetTrainableVariables_b:0"]) - @test_util.run_in_graph_and_eager_modes() def testGetGlobalVariables(self): - _ = variable_scope.get_variable("testGetGlobalVariables_a", []) - with variable_scope.variable_scope("testGetGlobalVariables_foo") as scope: - _ = variable_scope.get_variable("testGetGlobalVariables_b", []) - self.assertEqual([v.name - for v in scope.global_variables()], - ["testGetGlobalVariables_foo/" - "testGetGlobalVariables_b:0"]) + with self.test_session(): + _ = variable_scope.get_variable("testGetGlobalVariables_a", []) + with variable_scope.variable_scope("testGetGlobalVariables_foo") as scope: + _ = variable_scope.get_variable("testGetGlobalVariables_b", []) + self.assertEqual([v.name + for v in scope.global_variables()], + ["testGetGlobalVariables_foo/" + "testGetGlobalVariables_b:0"]) - @test_util.run_in_graph_and_eager_modes() def testGetLocalVariables(self): - _ = variable_scope.get_variable( - "a", [], collections=[ops.GraphKeys.LOCAL_VARIABLES]) - with variable_scope.variable_scope("foo") as scope: - _ = variable_scope.get_variable( - "b", [], collections=[ops.GraphKeys.LOCAL_VARIABLES]) + with self.test_session(): _ = variable_scope.get_variable( - "c", []) - self.assertEqual([v.name - for v in scope.local_variables()], ["foo/b:0"]) + "a", [], collections=[ops.GraphKeys.LOCAL_VARIABLES]) + with variable_scope.variable_scope("foo") as scope: + _ = variable_scope.get_variable( + "b", [], collections=[ops.GraphKeys.LOCAL_VARIABLES]) + _ = variable_scope.get_variable( + "c", []) + self.assertEqual([v.name + for v in scope.local_variables()], ["foo/b:0"]) def testGetVariableWithRefDtype(self): v = variable_scope.get_variable("v", shape=[3, 4], dtype=dtypes.float32) |