aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/variable_scope_test.py
diff options
context:
space:
mode:
authorGravatar Ali Yahya <alive@google.com>2017-09-14 15:21:28 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-14 15:25:21 -0700
commitf3f4f3fbcddf4b7d65ffe85c4e38e7a07d5f74fb (patch)
tree5f9417708cc4976353d19f8424df6f966e6df874 /tensorflow/python/kernel_tests/variable_scope_test.py
parent40d4a31545b6cb61e52c9b07827c6ca92c09ce75 (diff)
When Eager Execution is enabled, TensorFlow now no longer relies on global collections to keep track of ResourceVariables. Instead, they are tracked by the user as normal Python objects. In a subsequent CL, we'll make the lifetime of a variable's underlying resource match the lifetime of the corresponding Python object. For this to happen, there must be no everlasting global Python references to said variables.
More specifically, this change forces the `collections` flag in ResourceVariable's constructor to be None when Eager is enabled. It also raises an error on calls to get_collection() for variable collections. PiperOrigin-RevId: 168754146
Diffstat (limited to 'tensorflow/python/kernel_tests/variable_scope_test.py')
-rw-r--r--tensorflow/python/kernel_tests/variable_scope_test.py148
1 files changed, 75 insertions, 73 deletions
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py
index cdac12f05a..27c3fe6375 100644
--- a/tensorflow/python/kernel_tests/variable_scope_test.py
+++ b/tensorflow/python/kernel_tests/variable_scope_test.py
@@ -737,11 +737,12 @@ class VariableScopeTest(test.TestCase):
# Since variable is local, it should be in the local variable collection
# but not the trainable collection.
- self.assertIn(local_var,
- ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
- self.assertIn(local_var, ops.get_collection("foo"))
- self.assertNotIn(local_var,
- ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
+ if context.in_graph_mode():
+ self.assertIn(local_var,
+ ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
+ self.assertIn(local_var, ops.get_collection("foo"))
+ self.assertNotIn(local_var,
+ ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
# Check that local variable respects `reuse`.
with variable_scope.variable_scope(outer, "default", reuse=True):
@@ -765,93 +766,94 @@ class VariableScopeTest(test.TestCase):
self.assertEqual(varname_type[0], ("x", dtypes.float32))
self.assertEqual(varname_type[1], ("y", dtypes.int64))
- @test_util.run_in_graph_and_eager_modes()
def testGetCollection(self):
- _ = variable_scope.get_variable("testGetCollection_a", [])
- _ = variable_scope.get_variable("testGetCollection_b", [], trainable=False)
- with variable_scope.variable_scope("testGetCollection_foo_") as scope1:
+ with self.test_session():
_ = variable_scope.get_variable("testGetCollection_a", [])
_ = variable_scope.get_variable("testGetCollection_b", [],
trainable=False)
+ with variable_scope.variable_scope("testGetCollection_foo_") as scope1:
+ _ = variable_scope.get_variable("testGetCollection_a", [])
+ _ = variable_scope.get_variable("testGetCollection_b", [],
+ trainable=False)
+ self.assertEqual([
+ v.name
+ for v in scope1.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ ], ["testGetCollection_foo_/testGetCollection_a:0"])
+ self.assertEqual([
+ v.name
+ for v in scope1.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ ], [
+ "testGetCollection_foo_/testGetCollection_a:0",
+ "testGetCollection_foo_/testGetCollection_b:0"
+ ])
+ with variable_scope.variable_scope("testGetCollection_foo") as scope2:
+ _ = variable_scope.get_variable("testGetCollection_a", [])
+ _ = variable_scope.get_variable("testGetCollection_b", [],
+ trainable=False)
+ self.assertEqual([
+ v.name
+ for v in scope2.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ ], ["testGetCollection_foo/testGetCollection_a:0"])
+ self.assertEqual([
+ v.name
+ for v in scope2.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ ], [
+ "testGetCollection_foo/testGetCollection_a:0",
+ "testGetCollection_foo/testGetCollection_b:0"
+ ])
+ scope = variable_scope.get_variable_scope()
self.assertEqual([
- v.name
- for v in scope1.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- ], ["testGetCollection_foo_/testGetCollection_a:0"])
- self.assertEqual([
- v.name
- for v in scope1.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ v.name for v in scope.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
], [
+ "testGetCollection_a:0", "testGetCollection_b:0",
"testGetCollection_foo_/testGetCollection_a:0",
- "testGetCollection_foo_/testGetCollection_b:0"
+ "testGetCollection_foo_/testGetCollection_b:0",
+ "testGetCollection_foo/testGetCollection_a:0",
+ "testGetCollection_foo/testGetCollection_b:0"
])
- with variable_scope.variable_scope("testGetCollection_foo") as scope2:
- _ = variable_scope.get_variable("testGetCollection_a", [])
- _ = variable_scope.get_variable("testGetCollection_b", [],
- trainable=False)
self.assertEqual([
v.name
- for v in scope2.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- ], ["testGetCollection_foo/testGetCollection_a:0"])
- self.assertEqual([
- v.name
- for v in scope2.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ for v in scope.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
], [
- "testGetCollection_foo/testGetCollection_a:0",
- "testGetCollection_foo/testGetCollection_b:0"
+ "testGetCollection_a:0",
+ "testGetCollection_foo_/testGetCollection_a:0",
+ "testGetCollection_foo/testGetCollection_a:0"
])
- scope = variable_scope.get_variable_scope()
- self.assertEqual([
- v.name for v in scope.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- ], [
- "testGetCollection_a:0", "testGetCollection_b:0",
- "testGetCollection_foo_/testGetCollection_a:0",
- "testGetCollection_foo_/testGetCollection_b:0",
- "testGetCollection_foo/testGetCollection_a:0",
- "testGetCollection_foo/testGetCollection_b:0"
- ])
- self.assertEqual([
- v.name
- for v in scope.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- ], [
- "testGetCollection_a:0",
- "testGetCollection_foo_/testGetCollection_a:0",
- "testGetCollection_foo/testGetCollection_a:0"
- ])
- @test_util.run_in_graph_and_eager_modes()
def testGetTrainableVariables(self):
- _ = variable_scope.get_variable("testGetTrainableVariables_a", [])
- with variable_scope.variable_scope(
- "testGetTrainableVariables_foo") as scope:
- _ = variable_scope.get_variable("testGetTrainableVariables_b", [])
- _ = variable_scope.get_variable("testGetTrainableVariables_c", [],
- trainable=False)
- self.assertEqual([v.name
- for v in scope.trainable_variables()],
- ["testGetTrainableVariables_foo/"
- "testGetTrainableVariables_b:0"])
+ with self.test_session():
+ _ = variable_scope.get_variable("testGetTrainableVariables_a", [])
+ with variable_scope.variable_scope(
+ "testGetTrainableVariables_foo") as scope:
+ _ = variable_scope.get_variable("testGetTrainableVariables_b", [])
+ _ = variable_scope.get_variable("testGetTrainableVariables_c", [],
+ trainable=False)
+ self.assertEqual([v.name
+ for v in scope.trainable_variables()],
+ ["testGetTrainableVariables_foo/"
+ "testGetTrainableVariables_b:0"])
- @test_util.run_in_graph_and_eager_modes()
def testGetGlobalVariables(self):
- _ = variable_scope.get_variable("testGetGlobalVariables_a", [])
- with variable_scope.variable_scope("testGetGlobalVariables_foo") as scope:
- _ = variable_scope.get_variable("testGetGlobalVariables_b", [])
- self.assertEqual([v.name
- for v in scope.global_variables()],
- ["testGetGlobalVariables_foo/"
- "testGetGlobalVariables_b:0"])
+ with self.test_session():
+ _ = variable_scope.get_variable("testGetGlobalVariables_a", [])
+ with variable_scope.variable_scope("testGetGlobalVariables_foo") as scope:
+ _ = variable_scope.get_variable("testGetGlobalVariables_b", [])
+ self.assertEqual([v.name
+ for v in scope.global_variables()],
+ ["testGetGlobalVariables_foo/"
+ "testGetGlobalVariables_b:0"])
- @test_util.run_in_graph_and_eager_modes()
def testGetLocalVariables(self):
- _ = variable_scope.get_variable(
- "a", [], collections=[ops.GraphKeys.LOCAL_VARIABLES])
- with variable_scope.variable_scope("foo") as scope:
- _ = variable_scope.get_variable(
- "b", [], collections=[ops.GraphKeys.LOCAL_VARIABLES])
+ with self.test_session():
_ = variable_scope.get_variable(
- "c", [])
- self.assertEqual([v.name
- for v in scope.local_variables()], ["foo/b:0"])
+ "a", [], collections=[ops.GraphKeys.LOCAL_VARIABLES])
+ with variable_scope.variable_scope("foo") as scope:
+ _ = variable_scope.get_variable(
+ "b", [], collections=[ops.GraphKeys.LOCAL_VARIABLES])
+ _ = variable_scope.get_variable(
+ "c", [])
+ self.assertEqual([v.name
+ for v in scope.local_variables()], ["foo/b:0"])
def testGetVariableWithRefDtype(self):
v = variable_scope.get_variable("v", shape=[3, 4], dtype=dtypes.float32)