diff options
Diffstat (limited to 'tensorflow/python/training/checkpointable/data_structures_test.py')
-rw-r--r-- | tensorflow/python/training/checkpointable/data_structures_test.py | 203 |
1 files changed, 196 insertions, 7 deletions
diff --git a/tensorflow/python/training/checkpointable/data_structures_test.py b/tensorflow/python/training/checkpointable/data_structures_test.py index ce5852dd6e..472b7c32b4 100644 --- a/tensorflow/python/training/checkpointable/data_structures_test.py +++ b/tensorflow/python/training/checkpointable/data_structures_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import os import numpy +import six from tensorflow.python.eager import context from tensorflow.python.eager import test @@ -31,6 +32,8 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.training.checkpointable import data_structures +from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.checkpointable import util class HasList(training.Model): @@ -71,11 +74,14 @@ class ListTests(test.TestCase): model = HasList() output = model(array_ops.ones([32, 2])) self.assertAllEqual([32, 12], output.shape) - self.assertEqual(2, len(model.layers)) - self.assertIs(model.layer_list, model.layers[0]) - self.assertEqual(10, len(model.layers[0].layers)) + self.assertEqual(11, len(model.layers)) + self.assertEqual(10, len(model.layer_list.layers)) + six.assertCountEqual( + self, + model.layers, + model.layer_list.layers + model.layers_with_updates) for index in range(10): - self.assertEqual(3 + index, model.layers[0].layers[index].units) + self.assertEqual(3 + index, model.layer_list.layers[index].units) self.assertEqual(2, len(model._checkpoint_dependencies)) self.assertIs(model.layer_list, model._checkpoint_dependencies[0].ref) self.assertIs(model.layers_with_updates, @@ -113,6 +119,21 @@ class ListTests(test.TestCase): model(model_input) self.assertEqual(2, len(model.losses)) + def testModelContainersCompareEqual(self): + class HasEqualContainers(training.Model): + + def __init__(self): + super(HasEqualContainers, self).__init__() + self.l1 = [] + self.l2 = [] + + model = HasEqualContainers() + first_layer = HasEqualContainers() + model.l1.append(first_layer) + second_layer = HasEqualContainers() + model.l2.append(second_layer) + self.assertEqual([first_layer, second_layer], model.layers) + def testNotCheckpointable(self): class NotCheckpointable(object): pass @@ -158,11 +179,62 @@ class ListTests(test.TestCase): self.assertEqual([v], l.trainable_weights) self.assertEqual([v2], l.non_trainable_weights) + def testListWrapperBasic(self): + # _ListWrapper, unlike List, compares like the built-in list type (since it + # is used to automatically replace lists). + a = tracking.Checkpointable() + b = tracking.Checkpointable() + self.assertEqual([a, a], + [a, a]) + self.assertEqual(data_structures._ListWrapper([a, a]), + data_structures._ListWrapper([a, a])) + self.assertEqual([a, a], + data_structures._ListWrapper([a, a])) + self.assertEqual(data_structures._ListWrapper([a, a]), + [a, a]) + self.assertNotEqual([a, a], + [b, a]) + self.assertNotEqual(data_structures._ListWrapper([a, a]), + data_structures._ListWrapper([b, a])) + self.assertNotEqual([a, a], + data_structures._ListWrapper([b, a])) + self.assertLess([a], [a, b]) + self.assertLess(data_structures._ListWrapper([a]), + data_structures._ListWrapper([a, b])) + self.assertLessEqual([a], [a, b]) + self.assertLessEqual(data_structures._ListWrapper([a]), + data_structures._ListWrapper([a, b])) + self.assertGreater([a, b], [a]) + self.assertGreater(data_structures._ListWrapper([a, b]), + data_structures._ListWrapper([a])) + self.assertGreaterEqual([a, b], [a]) + self.assertGreaterEqual(data_structures._ListWrapper([a, b]), + data_structures._ListWrapper([a])) + self.assertEqual([a], data_structures._ListWrapper([a])) + self.assertEqual([a], list(data_structures.List([a]))) + self.assertEqual([a, a], data_structures._ListWrapper([a]) + [a]) + self.assertEqual([a, a], [a] + data_structures._ListWrapper([a])) + self.assertIsInstance(data_structures._ListWrapper([a]), list) + + def testWrapperChangesList(self): + l = [] + l_wrapper = data_structures._ListWrapper(l) + l_wrapper.append(1) + self.assertEqual([1], l) + + def testListChangesWrapper(self): + l = [] + l_wrapper = data_structures._ListWrapper(l) + l.append(1) + self.assertEqual([1], l_wrapper) + def testHashing(self): has_sequences = set([data_structures.List(), data_structures.List()]) self.assertEqual(2, len(has_sequences)) self.assertNotIn(data_structures.List(), has_sequences) + with self.assertRaises(TypeError): + has_sequences.add(data_structures._ListWrapper([])) class HasMapping(training.Model): @@ -195,9 +267,8 @@ class MappingTests(test.TestCase): model = HasMapping() output = model(array_ops.ones([32, 2])) self.assertAllEqual([32, 7], output.shape) - self.assertEqual(1, len(model.layers)) - self.assertIs(model.layer_dict, model.layers[0]) - self.assertEqual(3, len(model.layers[0].layers)) + self.assertEqual(5, len(model.layers)) + six.assertCountEqual(self, model.layers, model.layer_dict.layers) self.assertEqual(1, len(model._checkpoint_dependencies)) self.assertIs(model.layer_dict, model._checkpoint_dependencies[0].ref) self.evaluate([v.initializer for v in model.variables]) @@ -233,6 +304,124 @@ class MappingTests(test.TestCase): data_structures.Mapping()]) self.assertEqual(2, len(has_mappings)) self.assertNotIn(data_structures.Mapping(), has_mappings) + # In contrast to Mapping, dict wrappers are not hashable + a = tracking.Checkpointable() + a.d = {} + self.assertEqual({}, a.d) + self.assertFalse({} != a.d) # pylint: disable=g-explicit-bool-comparison + self.assertNotEqual({1: 2}, a.d) + with self.assertRaisesRegexp(TypeError, "unhashable"): + set([a.d]) + + def testDictWrapperBadKeys(self): + a = tracking.Checkpointable() + a.d = {} + a.d[1] = data_structures.List() + model = training.Model() + model.sub = a + save_path = os.path.join(self.get_temp_dir(), "ckpt") + with self.assertRaisesRegexp(ValueError, "non-string key"): + model.save_weights(save_path) + + def testDictWrapperNoDependency(self): + a = tracking.Checkpointable() + a.d = data_structures.NoDependency({}) + a.d[1] = [3] + self.assertEqual([a], util.list_objects(a)) + model = training.Model() + model.sub = a + save_path = os.path.join(self.get_temp_dir(), "ckpt") + model.save_weights(save_path) + model.load_weights(save_path) + + def testNonStringKeyNotCheckpointableValue(self): + a = tracking.Checkpointable() + a.d = {} + a.d["a"] = [3] + a.d[1] = data_structures.NoDependency([3]) + self.assertEqual([a, a.d, a.d["a"]], util.list_objects(a)) + model = training.Model() + model.sub = a + save_path = os.path.join(self.get_temp_dir(), "ckpt") + model.save_weights(save_path) + model.load_weights(save_path) + + def testNonAppendNotCheckpointable(self): + # Non-append mutations (deleting or overwriting values) are OK when the + # values aren't tracked. + a = tracking.Checkpointable() + a.d = {} + a.d["a"] = [3] + a.d[1] = 3 + a.d[1] = 2 + self.assertEqual(2, a.d[1]) + del a.d[1] + a.d[2] = data_structures.NoDependency(tracking.Checkpointable()) + second = tracking.Checkpointable() + a.d[2] = data_structures.NoDependency(second) + self.assertIs(second, a.d[2]) + self.assertEqual([a, a.d, a.d["a"]], util.list_objects(a)) + model = training.Model() + model.sub = a + save_path = os.path.join(self.get_temp_dir(), "ckpt") + model.save_weights(save_path) + model.load_weights(save_path) + + def testDelNoSave(self): + model = training.Model() + model.d = {} + model.d["a"] = [] + del model.d["a"] + save_path = os.path.join(self.get_temp_dir(), "ckpt") + with self.assertRaisesRegexp(ValueError, "overwritten or deleted"): + model.save_weights(save_path) + + def testPopNoSave(self): + model = training.Model() + model.d = {} + model.d["a"] = [] + model.d.pop("a") + save_path = os.path.join(self.get_temp_dir(), "ckpt") + with self.assertRaisesRegexp(ValueError, "overwritten or deleted"): + model.save_weights(save_path) + + def testExternalModificationNoSave(self): + model = training.Model() + external_reference = {} + model.d = external_reference + external_reference["a"] = [] + save_path = os.path.join(self.get_temp_dir(), "ckpt") + with self.assertRaisesRegexp(ValueError, "modified outside the wrapper"): + model.save_weights(save_path) + + def testOverwriteNoSave(self): + model = training.Model() + model.d = {} + model.d["a"] = {} + model.d["a"] = {} + save_path = os.path.join(self.get_temp_dir(), "ckpt") + with self.assertRaisesRegexp(ValueError, "overwritten or deleted"): + model.save_weights(save_path) + + def testIter(self): + model = training.Model() + model.d = {1: 3} + model.d[1] = 3 + self.assertEqual([1], list(model.d)) + new_dict = {} + # This update() is super tricky. If the dict wrapper subclasses dict, + # CPython will access its storage directly instead of calling any + # methods/properties on the object. So the options are either not to + # subclass dict (in which case update will call normal iter methods, but the + # object won't pass isinstance checks) or to subclass dict and keep that + # storage updated (no shadowing all its methods like _ListWrapper). + new_dict.update(model.d) + self.assertEqual({1: 3}, new_dict) + + def testConstructableFromSequence(self): + result = data_structures._DictWrapper([(1, 2), (3, 4)]) + self.assertIsInstance(result, dict) + self.assertEqual({1: 2, 3: 4}, result) if __name__ == "__main__": test.main() |