aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/training/checkpointable/base.py6
-rw-r--r--tensorflow/python/training/checkpointable/base_test.py20
2 files changed, 20 insertions, 6 deletions
diff --git a/tensorflow/python/training/checkpointable/base.py b/tensorflow/python/training/checkpointable/base.py
index ee35b01328..f0703c8af4 100644
--- a/tensorflow/python/training/checkpointable/base.py
+++ b/tensorflow/python/training/checkpointable/base.py
@@ -501,12 +501,6 @@ class CheckpointableBase(object):
ValueError: If the variable name is not unique.
"""
self._maybe_initialize_checkpointable()
- if overwrite and self._lookup_dependency(name) is not None:
- raise ValueError(
- ("A variable named '%s' already exists in this Checkpointable, but "
- "Checkpointable._add_variable called to create another with "
- "that name. Variable names must be unique within a Checkpointable "
- "object.") % (name,))
with ops.init_scope():
if context.executing_eagerly():
# If this is a variable with a single Tensor stored in the checkpoint,
diff --git a/tensorflow/python/training/checkpointable/base_test.py b/tensorflow/python/training/checkpointable/base_test.py
index 950e9c5b53..fd935ac559 100644
--- a/tensorflow/python/training/checkpointable/base_test.py
+++ b/tensorflow/python/training/checkpointable/base_test.py
@@ -16,8 +16,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
from tensorflow.python.training.checkpointable import base
+from tensorflow.python.training.checkpointable import util
class InterfaceTests(test.TestCase):
@@ -37,5 +40,22 @@ class InterfaceTests(test.TestCase):
self.assertIs(duplicate_name_dep, current_dependency)
self.assertEqual("leaf", current_name)
+ def testAddVariableOverwrite(self):
+ root = base.CheckpointableBase()
+ a = root._add_variable_with_custom_getter(
+ name="v", shape=[], getter=variable_scope.get_variable)
+ self.assertEqual([root, a], util.list_objects(root))
+ with ops.Graph().as_default():
+ b = root._add_variable_with_custom_getter(
+ name="v", shape=[], overwrite=True,
+ getter=variable_scope.get_variable)
+ self.assertEqual([root, b], util.list_objects(root))
+ with ops.Graph().as_default():
+ with self.assertRaisesRegexp(
+ ValueError, "already declared as a dependency"):
+ root._add_variable_with_custom_getter(
+ name="v", shape=[], overwrite=False,
+ getter=variable_scope.get_variable)
+
if __name__ == "__main__":
test.main()