aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/checkpointable/base_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/checkpointable/base_test.py')
-rw-r--r--tensorflow/python/training/checkpointable/base_test.py20
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/python/training/checkpointable/base_test.py b/tensorflow/python/training/checkpointable/base_test.py
index 950e9c5b53..fd935ac559 100644
--- a/tensorflow/python/training/checkpointable/base_test.py
+++ b/tensorflow/python/training/checkpointable/base_test.py
@@ -16,8 +16,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
from tensorflow.python.training.checkpointable import base
+from tensorflow.python.training.checkpointable import util
class InterfaceTests(test.TestCase):
@@ -37,5 +40,22 @@ class InterfaceTests(test.TestCase):
self.assertIs(duplicate_name_dep, current_dependency)
self.assertEqual("leaf", current_name)
+ def testAddVariableOverwrite(self):
+ root = base.CheckpointableBase()
+ a = root._add_variable_with_custom_getter(
+ name="v", shape=[], getter=variable_scope.get_variable)
+ self.assertEqual([root, a], util.list_objects(root))
+ with ops.Graph().as_default():
+ b = root._add_variable_with_custom_getter(
+ name="v", shape=[], overwrite=True,
+ getter=variable_scope.get_variable)
+ self.assertEqual([root, b], util.list_objects(root))
+ with ops.Graph().as_default():
+ with self.assertRaisesRegexp(
+ ValueError, "already declared as a dependency"):
+ root._add_variable_with_custom_getter(
+ name="v", shape=[], overwrite=False,
+ getter=variable_scope.get_variable)
+
if __name__ == "__main__":
test.main()