diff options
-rw-r--r-- | tensorflow/python/training/checkpointable/base.py | 6 | ||||
-rw-r--r-- | tensorflow/python/training/checkpointable/base_test.py | 20 |
2 files changed, 20 insertions, 6 deletions
diff --git a/tensorflow/python/training/checkpointable/base.py b/tensorflow/python/training/checkpointable/base.py index ee35b01328..f0703c8af4 100644 --- a/tensorflow/python/training/checkpointable/base.py +++ b/tensorflow/python/training/checkpointable/base.py @@ -501,12 +501,6 @@ class CheckpointableBase(object): ValueError: If the variable name is not unique. """ self._maybe_initialize_checkpointable() - if overwrite and self._lookup_dependency(name) is not None: - raise ValueError( - ("A variable named '%s' already exists in this Checkpointable, but " - "Checkpointable._add_variable called to create another with " - "that name. Variable names must be unique within a Checkpointable " - "object.") % (name,)) with ops.init_scope(): if context.executing_eagerly(): # If this is a variable with a single Tensor stored in the checkpoint, diff --git a/tensorflow/python/training/checkpointable/base_test.py b/tensorflow/python/training/checkpointable/base_test.py index 950e9c5b53..fd935ac559 100644 --- a/tensorflow/python/training/checkpointable/base_test.py +++ b/tensorflow/python/training/checkpointable/base_test.py @@ -16,8 +16,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import ops +from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test from tensorflow.python.training.checkpointable import base +from tensorflow.python.training.checkpointable import util class InterfaceTests(test.TestCase): @@ -37,5 +40,22 @@ class InterfaceTests(test.TestCase): self.assertIs(duplicate_name_dep, current_dependency) self.assertEqual("leaf", current_name) + def testAddVariableOverwrite(self): + root = base.CheckpointableBase() + a = root._add_variable_with_custom_getter( + name="v", shape=[], getter=variable_scope.get_variable) + self.assertEqual([root, a], util.list_objects(root)) + with ops.Graph().as_default(): + b = root._add_variable_with_custom_getter( + name="v", shape=[], overwrite=True, + getter=variable_scope.get_variable) + self.assertEqual([root, b], util.list_objects(root)) + with ops.Graph().as_default(): + with self.assertRaisesRegexp( + ValueError, "already declared as a dependency"): + root._add_variable_with_custom_getter( + name="v", shape=[], overwrite=False, + getter=variable_scope.get_variable) + if __name__ == "__main__": test.main() |