aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/checkpointable/data_structures_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/checkpointable/data_structures_test.py')
-rw-r--r--tensorflow/python/training/checkpointable/data_structures_test.py119
1 files changed, 119 insertions, 0 deletions
diff --git a/tensorflow/python/training/checkpointable/data_structures_test.py b/tensorflow/python/training/checkpointable/data_structures_test.py
index 7bee00a927..472b7c32b4 100644
--- a/tensorflow/python/training/checkpointable/data_structures_test.py
+++ b/tensorflow/python/training/checkpointable/data_structures_test.py
@@ -33,6 +33,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training.checkpointable import data_structures
from tensorflow.python.training.checkpointable import tracking
+from tensorflow.python.training.checkpointable import util
class HasList(training.Model):
@@ -303,6 +304,124 @@ class MappingTests(test.TestCase):
data_structures.Mapping()])
self.assertEqual(2, len(has_mappings))
self.assertNotIn(data_structures.Mapping(), has_mappings)
+ # In contrast to Mapping, dict wrappers are not hashable
+ a = tracking.Checkpointable()
+ a.d = {}
+ self.assertEqual({}, a.d)
+ self.assertFalse({} != a.d) # pylint: disable=g-explicit-bool-comparison
+ self.assertNotEqual({1: 2}, a.d)
+ with self.assertRaisesRegexp(TypeError, "unhashable"):
+ set([a.d])
+
+ def testDictWrapperBadKeys(self):
+ a = tracking.Checkpointable()
+ a.d = {}
+ a.d[1] = data_structures.List()
+ model = training.Model()
+ model.sub = a
+ save_path = os.path.join(self.get_temp_dir(), "ckpt")
+ with self.assertRaisesRegexp(ValueError, "non-string key"):
+ model.save_weights(save_path)
+
+ def testDictWrapperNoDependency(self):
+ a = tracking.Checkpointable()
+ a.d = data_structures.NoDependency({})
+ a.d[1] = [3]
+ self.assertEqual([a], util.list_objects(a))
+ model = training.Model()
+ model.sub = a
+ save_path = os.path.join(self.get_temp_dir(), "ckpt")
+ model.save_weights(save_path)
+ model.load_weights(save_path)
+
+ def testNonStringKeyNotCheckpointableValue(self):
+ a = tracking.Checkpointable()
+ a.d = {}
+ a.d["a"] = [3]
+ a.d[1] = data_structures.NoDependency([3])
+ self.assertEqual([a, a.d, a.d["a"]], util.list_objects(a))
+ model = training.Model()
+ model.sub = a
+ save_path = os.path.join(self.get_temp_dir(), "ckpt")
+ model.save_weights(save_path)
+ model.load_weights(save_path)
+
+ def testNonAppendNotCheckpointable(self):
+ # Non-append mutations (deleting or overwriting values) are OK when the
+ # values aren't tracked.
+ a = tracking.Checkpointable()
+ a.d = {}
+ a.d["a"] = [3]
+ a.d[1] = 3
+ a.d[1] = 2
+ self.assertEqual(2, a.d[1])
+ del a.d[1]
+ a.d[2] = data_structures.NoDependency(tracking.Checkpointable())
+ second = tracking.Checkpointable()
+ a.d[2] = data_structures.NoDependency(second)
+ self.assertIs(second, a.d[2])
+ self.assertEqual([a, a.d, a.d["a"]], util.list_objects(a))
+ model = training.Model()
+ model.sub = a
+ save_path = os.path.join(self.get_temp_dir(), "ckpt")
+ model.save_weights(save_path)
+ model.load_weights(save_path)
+
+ def testDelNoSave(self):
+ model = training.Model()
+ model.d = {}
+ model.d["a"] = []
+ del model.d["a"]
+ save_path = os.path.join(self.get_temp_dir(), "ckpt")
+ with self.assertRaisesRegexp(ValueError, "overwritten or deleted"):
+ model.save_weights(save_path)
+
+ def testPopNoSave(self):
+ model = training.Model()
+ model.d = {}
+ model.d["a"] = []
+ model.d.pop("a")
+ save_path = os.path.join(self.get_temp_dir(), "ckpt")
+ with self.assertRaisesRegexp(ValueError, "overwritten or deleted"):
+ model.save_weights(save_path)
+
+ def testExternalModificationNoSave(self):
+ model = training.Model()
+ external_reference = {}
+ model.d = external_reference
+ external_reference["a"] = []
+ save_path = os.path.join(self.get_temp_dir(), "ckpt")
+ with self.assertRaisesRegexp(ValueError, "modified outside the wrapper"):
+ model.save_weights(save_path)
+
+ def testOverwriteNoSave(self):
+ model = training.Model()
+ model.d = {}
+ model.d["a"] = {}
+ model.d["a"] = {}
+ save_path = os.path.join(self.get_temp_dir(), "ckpt")
+ with self.assertRaisesRegexp(ValueError, "overwritten or deleted"):
+ model.save_weights(save_path)
+
+ def testIter(self):
+ model = training.Model()
+ model.d = {1: 3}
+ model.d[1] = 3
+ self.assertEqual([1], list(model.d))
+ new_dict = {}
+ # This update() is super tricky. If the dict wrapper subclasses dict,
+ # CPython will access its storage directly instead of calling any
+ # methods/properties on the object. So the options are either not to
+ # subclass dict (in which case update will call normal iter methods, but the
+ # object won't pass isinstance checks) or to subclass dict and keep that
+ # storage updated (no shadowing all its methods like _ListWrapper).
+ new_dict.update(model.d)
+ self.assertEqual({1: 3}, new_dict)
+
+ def testConstructableFromSequence(self):
+ result = data_structures._DictWrapper([(1, 2), (3, 4)])
+ self.assertIsInstance(result, dict)
+ self.assertEqual({1: 2, 3: 4}, result)
if __name__ == "__main__":
test.main()