diff options
author | Allen Lavoie <allenl@google.com> | 2018-09-10 10:23:34 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-10 10:33:00 -0700 |
commit | 07c0f308ecce579ec69ad53541332ccf506ca280 (patch) | |
tree | 951b9367292f9db179e485588e821118abaa1517 /tensorflow/python/training | |
parent | 5f004516a3c104ed7632ff4a31b65c49f620d199 (diff) |
Make checkpointable list and dict wrappers copyable and deepcopyable
Also tests copying Checkpointable objects, which seems to just work.
PiperOrigin-RevId: 212289140
Diffstat (limited to 'tensorflow/python/training')
-rw-r--r-- | tensorflow/python/training/checkpointable/data_structures.py | 43 | ||||
-rw-r--r-- | tensorflow/python/training/checkpointable/data_structures_test.py | 99 |
2 files changed, 142 insertions, 0 deletions
diff --git a/tensorflow/python/training/checkpointable/data_structures.py b/tensorflow/python/training/checkpointable/data_structures.py index f06cbbfa15..c29e5db075 100644 --- a/tensorflow/python/training/checkpointable/data_structures.py +++ b/tensorflow/python/training/checkpointable/data_structures.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function import collections +import copy import six @@ -251,6 +252,12 @@ class List(CheckpointableDataStructure, collections.Sequence): self._storage[index] = self._track_value( element, name=self._name_element(index)) + def __copy__(self): + return type(self)(copy.copy(self._storage)) + + def __deepcopy__(self, memo): + return type(self)(copy.deepcopy(self._storage, memo)) + def _make_storage(self, *args, **kwargs): """Determines the backing storage (overridden in subclasses).""" return list(*args, **kwargs) @@ -325,6 +332,20 @@ class _ListWrapper(List, collections.MutableSequence, super(_ListWrapper, self).__init__(wrapped_list) self._last_wrapped_list_snapshot = list(self._storage) + # pylint: disable=protected-access + def __copy__(self): + copied = super(_ListWrapper, self).__copy__() + copied._non_append_mutation = self._non_append_mutation + copied._external_modification = self._external_modification + return copied + + def __deepcopy__(self, memo): + copied = super(_ListWrapper, self).__deepcopy__(memo) + copied._non_append_mutation = self._non_append_mutation + copied._external_modification = self._external_modification + return copied + # pylint: enable=protected-access + def _make_storage(self, wrapped_list): """Use the user's original list for storage.""" return wrapped_list @@ -449,6 +470,12 @@ class Mapping(CheckpointableDataStructure, collections.Mapping): value, name=self._name_element(key)) for key, value in self._storage.items()}) + def __copy__(self): + return type(self)(copy.copy(self._storage)) + + def __deepcopy__(self, memo): + return type(self)(copy.deepcopy(self._storage, memo)) + def _make_storage(self, *args, **kwargs): return dict(*args, **kwargs) @@ -525,6 +552,22 @@ class _DictWrapper(Mapping, collections.MutableMapping): super(_DictWrapper, self).__init__(wrapped_dict) self._update_snapshot() + # pylint: disable=protected-access + def __copy__(self): + copied = super(_DictWrapper, self).__copy__() + copied._non_append_mutation = self._non_append_mutation + copied._external_modification = self._external_modification + copied._non_string_key = self._non_string_key + return copied + + def __deepcopy__(self, memo): + copied = super(_DictWrapper, self).__deepcopy__(memo) + copied._non_append_mutation = self._non_append_mutation + copied._external_modification = self._external_modification + copied._non_string_key = self._non_string_key + return copied + # pylint: enable=protected-access + def _make_storage(self, wrapped_dict): """Re-use the wrapped dict for storage (to force them to be in sync).""" return wrapped_dict diff --git a/tensorflow/python/training/checkpointable/data_structures_test.py b/tensorflow/python/training/checkpointable/data_structures_test.py index 4638917b4c..5597c7c772 100644 --- a/tensorflow/python/training/checkpointable/data_structures_test.py +++ b/tensorflow/python/training/checkpointable/data_structures_test.py @@ -16,6 +16,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import copy import os import numpy @@ -424,6 +425,104 @@ class MappingTests(test.TestCase): new_dict.update(model.d) self.assertEqual({1: 3}, new_dict) + def testListShallowCopy(self): + root = tracking.Checkpointable() + orig_list = [[1.]] + root.a = orig_list + copied = copy.copy(root.a) + self.assertAllEqual([[1.]], copied) + self.assertIsNot(root.a, copied) + self.assertIs(root.a[0], copied[0]) + + # Dirtiness should be inherited + util.list_objects(root.a) + orig_list.append(1.) + with self.assertRaises(ValueError): + util.list_objects(root.a) + with self.assertRaises(ValueError): + util.list_objects(copy.copy(root.a)) + + def testListDeepCopy(self): + root = tracking.Checkpointable() + orig_list = [[1.]] + root.a = orig_list + copied = copy.deepcopy(root.a) + self.assertAllEqual([[1.]], copied) + self.assertIsNot(root.a, copied) + self.assertIsNot(root.a[0], copied[0]) + + # Dirtiness should be inherited + util.list_objects(root.a) + orig_list.append(1.) + with self.assertRaises(ValueError): + util.list_objects(root.a) + with self.assertRaises(ValueError): + util.list_objects(copy.deepcopy(root.a)) + + def testDictShallowCopy(self): + root = tracking.Checkpointable() + orig_dict = {"a": [1.]} + root.a = orig_dict + copied = copy.copy(root.a) + self.assertAllEqual([1.], copied["a"]) + self.assertIsNot(root.a, copied) + self.assertIs(root.a["a"], copied["a"]) + + # Dirtiness should be inherited + util.list_objects(root.a) + orig_dict["b"] = [] + with self.assertRaises(ValueError): + util.list_objects(root.a) + with self.assertRaises(ValueError): + util.list_objects(copy.copy(root.a)) + + def testDictDeepCopy(self): + root = tracking.Checkpointable() + orig_dict = {"a": [1.]} + root.a = orig_dict + copied = copy.deepcopy(root.a) + self.assertAllEqual([1.], copied["a"]) + self.assertIsNot(root.a, copied) + self.assertIsNot(root.a["a"], copied["a"]) + + # Dirtiness should be inherited + util.list_objects(root.a) + orig_dict["b"] = [] + with self.assertRaises(ValueError): + util.list_objects(root.a) + with self.assertRaises(ValueError): + util.list_objects(copy.deepcopy(root.a)) + + def testShallowCopyCheckpointable(self): + original = tracking.Checkpointable() + original_sub = tracking.Checkpointable() + original.a = [[1.]] + original.b = {"a": original_sub} + shallow_copied = copy.copy(original) + self.assertIs(original_sub, shallow_copied.b["a"]) + self.assertIsNot(original, shallow_copied) + self.assertEqual([[1.]], shallow_copied.a) + shallow_deps = util.list_objects(shallow_copied) + self.assertIn(shallow_copied.a, shallow_deps) + self.assertIn(shallow_copied.b, shallow_deps) + self.assertIn(shallow_copied.b["a"], shallow_deps) + + def testDeepCopyCheckpointable(self): + original = tracking.Checkpointable() + original_sub = tracking.Checkpointable() + original.a = [[1.]] + original.b = {"a": original_sub} + deep_copied = copy.deepcopy(original) + self.assertIsNot(original, deep_copied) + self.assertIsNot(original_sub, deep_copied.b["a"]) + self.assertEqual([[1.]], deep_copied.a) + self.assertIsInstance(deep_copied.b["a"], tracking.Checkpointable) + deps = util.list_objects(deep_copied) + self.assertIn(deep_copied.a, deps) + self.assertIn(deep_copied.b, deps) + self.assertIn(deep_copied.b["a"], deps) + self.assertNotIn(original_sub, deps) + def testConstructableFromSequence(self): result = data_structures._DictWrapper([(1, 2), (3, 4)]) self.assertIsInstance(result, dict) |