aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/checkpointable/data_structures_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/checkpointable/data_structures_test.py')
-rw-r--r--tensorflow/python/training/checkpointable/data_structures_test.py203
1 files changed, 196 insertions, 7 deletions
diff --git a/tensorflow/python/training/checkpointable/data_structures_test.py b/tensorflow/python/training/checkpointable/data_structures_test.py
index ce5852dd6e..472b7c32b4 100644
--- a/tensorflow/python/training/checkpointable/data_structures_test.py
+++ b/tensorflow/python/training/checkpointable/data_structures_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import os
import numpy
+import six
from tensorflow.python.eager import context
from tensorflow.python.eager import test
@@ -31,6 +32,8 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training.checkpointable import data_structures
+from tensorflow.python.training.checkpointable import tracking
+from tensorflow.python.training.checkpointable import util
class HasList(training.Model):
@@ -71,11 +74,14 @@ class ListTests(test.TestCase):
model = HasList()
output = model(array_ops.ones([32, 2]))
self.assertAllEqual([32, 12], output.shape)
- self.assertEqual(2, len(model.layers))
- self.assertIs(model.layer_list, model.layers[0])
- self.assertEqual(10, len(model.layers[0].layers))
+ self.assertEqual(11, len(model.layers))
+ self.assertEqual(10, len(model.layer_list.layers))
+ six.assertCountEqual(
+ self,
+ model.layers,
+ model.layer_list.layers + model.layers_with_updates)
for index in range(10):
- self.assertEqual(3 + index, model.layers[0].layers[index].units)
+ self.assertEqual(3 + index, model.layer_list.layers[index].units)
self.assertEqual(2, len(model._checkpoint_dependencies))
self.assertIs(model.layer_list, model._checkpoint_dependencies[0].ref)
self.assertIs(model.layers_with_updates,
@@ -113,6 +119,21 @@ class ListTests(test.TestCase):
model(model_input)
self.assertEqual(2, len(model.losses))
+ def testModelContainersCompareEqual(self):
+ class HasEqualContainers(training.Model):
+
+ def __init__(self):
+ super(HasEqualContainers, self).__init__()
+ self.l1 = []
+ self.l2 = []
+
+ model = HasEqualContainers()
+ first_layer = HasEqualContainers()
+ model.l1.append(first_layer)
+ second_layer = HasEqualContainers()
+ model.l2.append(second_layer)
+ self.assertEqual([first_layer, second_layer], model.layers)
+
def testNotCheckpointable(self):
class NotCheckpointable(object):
pass
@@ -158,11 +179,62 @@ class ListTests(test.TestCase):
self.assertEqual([v], l.trainable_weights)
self.assertEqual([v2], l.non_trainable_weights)
+ def testListWrapperBasic(self):
+ # _ListWrapper, unlike List, compares like the built-in list type (since it
+ # is used to automatically replace lists).
+ a = tracking.Checkpointable()
+ b = tracking.Checkpointable()
+ self.assertEqual([a, a],
+ [a, a])
+ self.assertEqual(data_structures._ListWrapper([a, a]),
+ data_structures._ListWrapper([a, a]))
+ self.assertEqual([a, a],
+ data_structures._ListWrapper([a, a]))
+ self.assertEqual(data_structures._ListWrapper([a, a]),
+ [a, a])
+ self.assertNotEqual([a, a],
+ [b, a])
+ self.assertNotEqual(data_structures._ListWrapper([a, a]),
+ data_structures._ListWrapper([b, a]))
+ self.assertNotEqual([a, a],
+ data_structures._ListWrapper([b, a]))
+ self.assertLess([a], [a, b])
+ self.assertLess(data_structures._ListWrapper([a]),
+ data_structures._ListWrapper([a, b]))
+ self.assertLessEqual([a], [a, b])
+ self.assertLessEqual(data_structures._ListWrapper([a]),
+ data_structures._ListWrapper([a, b]))
+ self.assertGreater([a, b], [a])
+ self.assertGreater(data_structures._ListWrapper([a, b]),
+ data_structures._ListWrapper([a]))
+ self.assertGreaterEqual([a, b], [a])
+ self.assertGreaterEqual(data_structures._ListWrapper([a, b]),
+ data_structures._ListWrapper([a]))
+ self.assertEqual([a], data_structures._ListWrapper([a]))
+ self.assertEqual([a], list(data_structures.List([a])))
+ self.assertEqual([a, a], data_structures._ListWrapper([a]) + [a])
+ self.assertEqual([a, a], [a] + data_structures._ListWrapper([a]))
+ self.assertIsInstance(data_structures._ListWrapper([a]), list)
+
+ def testWrapperChangesList(self):
+ l = []
+ l_wrapper = data_structures._ListWrapper(l)
+ l_wrapper.append(1)
+ self.assertEqual([1], l)
+
+ def testListChangesWrapper(self):
+ l = []
+ l_wrapper = data_structures._ListWrapper(l)
+ l.append(1)
+ self.assertEqual([1], l_wrapper)
+
def testHashing(self):
has_sequences = set([data_structures.List(),
data_structures.List()])
self.assertEqual(2, len(has_sequences))
self.assertNotIn(data_structures.List(), has_sequences)
+ with self.assertRaises(TypeError):
+ has_sequences.add(data_structures._ListWrapper([]))
class HasMapping(training.Model):
@@ -195,9 +267,8 @@ class MappingTests(test.TestCase):
model = HasMapping()
output = model(array_ops.ones([32, 2]))
self.assertAllEqual([32, 7], output.shape)
- self.assertEqual(1, len(model.layers))
- self.assertIs(model.layer_dict, model.layers[0])
- self.assertEqual(3, len(model.layers[0].layers))
+ self.assertEqual(5, len(model.layers))
+ six.assertCountEqual(self, model.layers, model.layer_dict.layers)
self.assertEqual(1, len(model._checkpoint_dependencies))
self.assertIs(model.layer_dict, model._checkpoint_dependencies[0].ref)
self.evaluate([v.initializer for v in model.variables])
@@ -233,6 +304,124 @@ class MappingTests(test.TestCase):
data_structures.Mapping()])
self.assertEqual(2, len(has_mappings))
self.assertNotIn(data_structures.Mapping(), has_mappings)
+ # In contrast to Mapping, dict wrappers are not hashable
+ a = tracking.Checkpointable()
+ a.d = {}
+ self.assertEqual({}, a.d)
+ self.assertFalse({} != a.d) # pylint: disable=g-explicit-bool-comparison
+ self.assertNotEqual({1: 2}, a.d)
+ with self.assertRaisesRegexp(TypeError, "unhashable"):
+ set([a.d])
+
+ def testDictWrapperBadKeys(self):
+ a = tracking.Checkpointable()
+ a.d = {}
+ a.d[1] = data_structures.List()
+ model = training.Model()
+ model.sub = a
+ save_path = os.path.join(self.get_temp_dir(), "ckpt")
+ with self.assertRaisesRegexp(ValueError, "non-string key"):
+ model.save_weights(save_path)
+
+ def testDictWrapperNoDependency(self):
+ a = tracking.Checkpointable()
+ a.d = data_structures.NoDependency({})
+ a.d[1] = [3]
+ self.assertEqual([a], util.list_objects(a))
+ model = training.Model()
+ model.sub = a
+ save_path = os.path.join(self.get_temp_dir(), "ckpt")
+ model.save_weights(save_path)
+ model.load_weights(save_path)
+
+ def testNonStringKeyNotCheckpointableValue(self):
+ a = tracking.Checkpointable()
+ a.d = {}
+ a.d["a"] = [3]
+ a.d[1] = data_structures.NoDependency([3])
+ self.assertEqual([a, a.d, a.d["a"]], util.list_objects(a))
+ model = training.Model()
+ model.sub = a
+ save_path = os.path.join(self.get_temp_dir(), "ckpt")
+ model.save_weights(save_path)
+ model.load_weights(save_path)
+
+ def testNonAppendNotCheckpointable(self):
+ # Non-append mutations (deleting or overwriting values) are OK when the
+ # values aren't tracked.
+ a = tracking.Checkpointable()
+ a.d = {}
+ a.d["a"] = [3]
+ a.d[1] = 3
+ a.d[1] = 2
+ self.assertEqual(2, a.d[1])
+ del a.d[1]
+ a.d[2] = data_structures.NoDependency(tracking.Checkpointable())
+ second = tracking.Checkpointable()
+ a.d[2] = data_structures.NoDependency(second)
+ self.assertIs(second, a.d[2])
+ self.assertEqual([a, a.d, a.d["a"]], util.list_objects(a))
+ model = training.Model()
+ model.sub = a
+ save_path = os.path.join(self.get_temp_dir(), "ckpt")
+ model.save_weights(save_path)
+ model.load_weights(save_path)
+
+ def testDelNoSave(self):
+ model = training.Model()
+ model.d = {}
+ model.d["a"] = []
+ del model.d["a"]
+ save_path = os.path.join(self.get_temp_dir(), "ckpt")
+ with self.assertRaisesRegexp(ValueError, "overwritten or deleted"):
+ model.save_weights(save_path)
+
+ def testPopNoSave(self):
+ model = training.Model()
+ model.d = {}
+ model.d["a"] = []
+ model.d.pop("a")
+ save_path = os.path.join(self.get_temp_dir(), "ckpt")
+ with self.assertRaisesRegexp(ValueError, "overwritten or deleted"):
+ model.save_weights(save_path)
+
+ def testExternalModificationNoSave(self):
+ model = training.Model()
+ external_reference = {}
+ model.d = external_reference
+ external_reference["a"] = []
+ save_path = os.path.join(self.get_temp_dir(), "ckpt")
+ with self.assertRaisesRegexp(ValueError, "modified outside the wrapper"):
+ model.save_weights(save_path)
+
+ def testOverwriteNoSave(self):
+ model = training.Model()
+ model.d = {}
+ model.d["a"] = {}
+ model.d["a"] = {}
+ save_path = os.path.join(self.get_temp_dir(), "ckpt")
+ with self.assertRaisesRegexp(ValueError, "overwritten or deleted"):
+ model.save_weights(save_path)
+
+ def testIter(self):
+ model = training.Model()
+ model.d = {1: 3}
+ model.d[1] = 3
+ self.assertEqual([1], list(model.d))
+ new_dict = {}
+ # This update() is super tricky. If the dict wrapper subclasses dict,
+ # CPython will access its storage directly instead of calling any
+ # methods/properties on the object. So the options are either not to
+ # subclass dict (in which case update will call normal iter methods, but the
+ # object won't pass isinstance checks) or to subclass dict and keep that
+ # storage updated (no shadowing all its methods like _ListWrapper).
+ new_dict.update(model.d)
+ self.assertEqual({1: 3}, new_dict)
+
+ def testConstructableFromSequence(self):
+ result = data_structures._DictWrapper([(1, 2), (3, 4)])
+ self.assertIsInstance(result, dict)
+ self.assertEqual({1: 2, 3: 4}, result)
if __name__ == "__main__":
test.main()