diff options
Diffstat (limited to 'tensorflow/python/training/checkpointable/base_test.py')
-rw-r--r-- | tensorflow/python/training/checkpointable/base_test.py | 20 |
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/python/training/checkpointable/base_test.py b/tensorflow/python/training/checkpointable/base_test.py index 950e9c5b53..fd935ac559 100644 --- a/tensorflow/python/training/checkpointable/base_test.py +++ b/tensorflow/python/training/checkpointable/base_test.py @@ -16,8 +16,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import ops +from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test from tensorflow.python.training.checkpointable import base +from tensorflow.python.training.checkpointable import util class InterfaceTests(test.TestCase): @@ -37,5 +40,22 @@ class InterfaceTests(test.TestCase): self.assertIs(duplicate_name_dep, current_dependency) self.assertEqual("leaf", current_name) + def testAddVariableOverwrite(self): + root = base.CheckpointableBase() + a = root._add_variable_with_custom_getter( + name="v", shape=[], getter=variable_scope.get_variable) + self.assertEqual([root, a], util.list_objects(root)) + with ops.Graph().as_default(): + b = root._add_variable_with_custom_getter( + name="v", shape=[], overwrite=True, + getter=variable_scope.get_variable) + self.assertEqual([root, b], util.list_objects(root)) + with ops.Graph().as_default(): + with self.assertRaisesRegexp( + ValueError, "already declared as a dependency"): + root._add_variable_with_custom_getter( + name="v", shape=[], overwrite=False, + getter=variable_scope.get_variable) + if __name__ == "__main__": test.main() |