diff options
author | 2018-07-25 09:56:29 -0700 | |
---|---|---|
committer | 2018-07-25 10:00:27 -0700 | |
commit | 21f139075de212ccaab69bb89bb96d8b98282523 (patch) | |
tree | 1678e8895fa776db543044925f077db71921da02 /tensorflow/python/training/checkpointable/base_test.py | |
parent | fa69d6531cc6cfe865a3ffc63c58f3c2fe0ec4df (diff) |
Fix dependency overwriting in _add_variable_with_custom_getter
Removes an exception which should have been removed in cl/203156155 (there is an equivalent exception slightly deeper which is more nuanced). The conditional as-is makes no sense.
PiperOrigin-RevId: 206009242
Diffstat (limited to 'tensorflow/python/training/checkpointable/base_test.py')
-rw-r--r-- | tensorflow/python/training/checkpointable/base_test.py | 20 |
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/python/training/checkpointable/base_test.py b/tensorflow/python/training/checkpointable/base_test.py index 950e9c5b53..fd935ac559 100644 --- a/tensorflow/python/training/checkpointable/base_test.py +++ b/tensorflow/python/training/checkpointable/base_test.py @@ -16,8 +16,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import ops +from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test from tensorflow.python.training.checkpointable import base +from tensorflow.python.training.checkpointable import util class InterfaceTests(test.TestCase): @@ -37,5 +40,22 @@ class InterfaceTests(test.TestCase): self.assertIs(duplicate_name_dep, current_dependency) self.assertEqual("leaf", current_name) + def testAddVariableOverwrite(self): + root = base.CheckpointableBase() + a = root._add_variable_with_custom_getter( + name="v", shape=[], getter=variable_scope.get_variable) + self.assertEqual([root, a], util.list_objects(root)) + with ops.Graph().as_default(): + b = root._add_variable_with_custom_getter( + name="v", shape=[], overwrite=True, + getter=variable_scope.get_variable) + self.assertEqual([root, b], util.list_objects(root)) + with ops.Graph().as_default(): + with self.assertRaisesRegexp( + ValueError, "already declared as a dependency"): + root._add_variable_with_custom_getter( + name="v", shape=[], overwrite=False, + getter=variable_scope.get_variable) + if __name__ == "__main__": test.main() |