aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/checkpointable/base_test.py
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-07-25 09:56:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-25 10:00:27 -0700
commit21f139075de212ccaab69bb89bb96d8b98282523 (patch)
tree1678e8895fa776db543044925f077db71921da02 /tensorflow/python/training/checkpointable/base_test.py
parentfa69d6531cc6cfe865a3ffc63c58f3c2fe0ec4df (diff)
Fix dependency overwriting in _add_variable_with_custom_getter
Removes an exception which should have been removed in cl/203156155 (there is an equivalent exception slightly deeper which is more nuanced). The conditional as-is makes no sense. PiperOrigin-RevId: 206009242
Diffstat (limited to 'tensorflow/python/training/checkpointable/base_test.py')
-rw-r--r--tensorflow/python/training/checkpointable/base_test.py20
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/python/training/checkpointable/base_test.py b/tensorflow/python/training/checkpointable/base_test.py
index 950e9c5b53..fd935ac559 100644
--- a/tensorflow/python/training/checkpointable/base_test.py
+++ b/tensorflow/python/training/checkpointable/base_test.py
@@ -16,8 +16,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
from tensorflow.python.training.checkpointable import base
+from tensorflow.python.training.checkpointable import util
class InterfaceTests(test.TestCase):
@@ -37,5 +40,22 @@ class InterfaceTests(test.TestCase):
self.assertIs(duplicate_name_dep, current_dependency)
self.assertEqual("leaf", current_name)
+ def testAddVariableOverwrite(self):
+ root = base.CheckpointableBase()
+ a = root._add_variable_with_custom_getter(
+ name="v", shape=[], getter=variable_scope.get_variable)
+ self.assertEqual([root, a], util.list_objects(root))
+ with ops.Graph().as_default():
+ b = root._add_variable_with_custom_getter(
+ name="v", shape=[], overwrite=True,
+ getter=variable_scope.get_variable)
+ self.assertEqual([root, b], util.list_objects(root))
+ with ops.Graph().as_default():
+ with self.assertRaisesRegexp(
+ ValueError, "already declared as a dependency"):
+ root._add_variable_with_custom_getter(
+ name="v", shape=[], overwrite=False,
+ getter=variable_scope.get_variable)
+
if __name__ == "__main__":
test.main()