aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/variable_scope_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/variable_scope_test.py')
-rw-r--r--tensorflow/python/kernel_tests/variable_scope_test.py148
1 files changed, 75 insertions, 73 deletions
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py
index cdac12f05a..27c3fe6375 100644
--- a/tensorflow/python/kernel_tests/variable_scope_test.py
+++ b/tensorflow/python/kernel_tests/variable_scope_test.py
@@ -737,11 +737,12 @@ class VariableScopeTest(test.TestCase):
# Since variable is local, it should be in the local variable collection
# but not the trainable collection.
- self.assertIn(local_var,
- ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
- self.assertIn(local_var, ops.get_collection("foo"))
- self.assertNotIn(local_var,
- ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
+ if context.in_graph_mode():
+ self.assertIn(local_var,
+ ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
+ self.assertIn(local_var, ops.get_collection("foo"))
+ self.assertNotIn(local_var,
+ ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
# Check that local variable respects `reuse`.
with variable_scope.variable_scope(outer, "default", reuse=True):
@@ -765,93 +766,94 @@ class VariableScopeTest(test.TestCase):
self.assertEqual(varname_type[0], ("x", dtypes.float32))
self.assertEqual(varname_type[1], ("y", dtypes.int64))
- @test_util.run_in_graph_and_eager_modes()
def testGetCollection(self):
- _ = variable_scope.get_variable("testGetCollection_a", [])
- _ = variable_scope.get_variable("testGetCollection_b", [], trainable=False)
- with variable_scope.variable_scope("testGetCollection_foo_") as scope1:
+ with self.test_session():
_ = variable_scope.get_variable("testGetCollection_a", [])
_ = variable_scope.get_variable("testGetCollection_b", [],
trainable=False)
+ with variable_scope.variable_scope("testGetCollection_foo_") as scope1:
+ _ = variable_scope.get_variable("testGetCollection_a", [])
+ _ = variable_scope.get_variable("testGetCollection_b", [],
+ trainable=False)
+ self.assertEqual([
+ v.name
+ for v in scope1.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ ], ["testGetCollection_foo_/testGetCollection_a:0"])
+ self.assertEqual([
+ v.name
+ for v in scope1.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ ], [
+ "testGetCollection_foo_/testGetCollection_a:0",
+ "testGetCollection_foo_/testGetCollection_b:0"
+ ])
+ with variable_scope.variable_scope("testGetCollection_foo") as scope2:
+ _ = variable_scope.get_variable("testGetCollection_a", [])
+ _ = variable_scope.get_variable("testGetCollection_b", [],
+ trainable=False)
+ self.assertEqual([
+ v.name
+ for v in scope2.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ ], ["testGetCollection_foo/testGetCollection_a:0"])
+ self.assertEqual([
+ v.name
+ for v in scope2.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ ], [
+ "testGetCollection_foo/testGetCollection_a:0",
+ "testGetCollection_foo/testGetCollection_b:0"
+ ])
+ scope = variable_scope.get_variable_scope()
self.assertEqual([
- v.name
- for v in scope1.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- ], ["testGetCollection_foo_/testGetCollection_a:0"])
- self.assertEqual([
- v.name
- for v in scope1.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ v.name for v in scope.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
], [
+ "testGetCollection_a:0", "testGetCollection_b:0",
"testGetCollection_foo_/testGetCollection_a:0",
- "testGetCollection_foo_/testGetCollection_b:0"
+ "testGetCollection_foo_/testGetCollection_b:0",
+ "testGetCollection_foo/testGetCollection_a:0",
+ "testGetCollection_foo/testGetCollection_b:0"
])
- with variable_scope.variable_scope("testGetCollection_foo") as scope2:
- _ = variable_scope.get_variable("testGetCollection_a", [])
- _ = variable_scope.get_variable("testGetCollection_b", [],
- trainable=False)
self.assertEqual([
v.name
- for v in scope2.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- ], ["testGetCollection_foo/testGetCollection_a:0"])
- self.assertEqual([
- v.name
- for v in scope2.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ for v in scope.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
], [
- "testGetCollection_foo/testGetCollection_a:0",
- "testGetCollection_foo/testGetCollection_b:0"
+ "testGetCollection_a:0",
+ "testGetCollection_foo_/testGetCollection_a:0",
+ "testGetCollection_foo/testGetCollection_a:0"
])
- scope = variable_scope.get_variable_scope()
- self.assertEqual([
- v.name for v in scope.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- ], [
- "testGetCollection_a:0", "testGetCollection_b:0",
- "testGetCollection_foo_/testGetCollection_a:0",
- "testGetCollection_foo_/testGetCollection_b:0",
- "testGetCollection_foo/testGetCollection_a:0",
- "testGetCollection_foo/testGetCollection_b:0"
- ])
- self.assertEqual([
- v.name
- for v in scope.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- ], [
- "testGetCollection_a:0",
- "testGetCollection_foo_/testGetCollection_a:0",
- "testGetCollection_foo/testGetCollection_a:0"
- ])
- @test_util.run_in_graph_and_eager_modes()
def testGetTrainableVariables(self):
- _ = variable_scope.get_variable("testGetTrainableVariables_a", [])
- with variable_scope.variable_scope(
- "testGetTrainableVariables_foo") as scope:
- _ = variable_scope.get_variable("testGetTrainableVariables_b", [])
- _ = variable_scope.get_variable("testGetTrainableVariables_c", [],
- trainable=False)
- self.assertEqual([v.name
- for v in scope.trainable_variables()],
- ["testGetTrainableVariables_foo/"
- "testGetTrainableVariables_b:0"])
+ with self.test_session():
+ _ = variable_scope.get_variable("testGetTrainableVariables_a", [])
+ with variable_scope.variable_scope(
+ "testGetTrainableVariables_foo") as scope:
+ _ = variable_scope.get_variable("testGetTrainableVariables_b", [])
+ _ = variable_scope.get_variable("testGetTrainableVariables_c", [],
+ trainable=False)
+ self.assertEqual([v.name
+ for v in scope.trainable_variables()],
+ ["testGetTrainableVariables_foo/"
+ "testGetTrainableVariables_b:0"])
- @test_util.run_in_graph_and_eager_modes()
def testGetGlobalVariables(self):
- _ = variable_scope.get_variable("testGetGlobalVariables_a", [])
- with variable_scope.variable_scope("testGetGlobalVariables_foo") as scope:
- _ = variable_scope.get_variable("testGetGlobalVariables_b", [])
- self.assertEqual([v.name
- for v in scope.global_variables()],
- ["testGetGlobalVariables_foo/"
- "testGetGlobalVariables_b:0"])
+ with self.test_session():
+ _ = variable_scope.get_variable("testGetGlobalVariables_a", [])
+ with variable_scope.variable_scope("testGetGlobalVariables_foo") as scope:
+ _ = variable_scope.get_variable("testGetGlobalVariables_b", [])
+ self.assertEqual([v.name
+ for v in scope.global_variables()],
+ ["testGetGlobalVariables_foo/"
+ "testGetGlobalVariables_b:0"])
- @test_util.run_in_graph_and_eager_modes()
def testGetLocalVariables(self):
- _ = variable_scope.get_variable(
- "a", [], collections=[ops.GraphKeys.LOCAL_VARIABLES])
- with variable_scope.variable_scope("foo") as scope:
- _ = variable_scope.get_variable(
- "b", [], collections=[ops.GraphKeys.LOCAL_VARIABLES])
+ with self.test_session():
_ = variable_scope.get_variable(
- "c", [])
- self.assertEqual([v.name
- for v in scope.local_variables()], ["foo/b:0"])
+ "a", [], collections=[ops.GraphKeys.LOCAL_VARIABLES])
+ with variable_scope.variable_scope("foo") as scope:
+ _ = variable_scope.get_variable(
+ "b", [], collections=[ops.GraphKeys.LOCAL_VARIABLES])
+ _ = variable_scope.get_variable(
+ "c", [])
+ self.assertEqual([v.name
+ for v in scope.local_variables()], ["foo/b:0"])
def testGetVariableWithRefDtype(self):
v = variable_scope.get_variable("v", shape=[3, 4], dtype=dtypes.float32)