aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-09-10 10:23:34 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-10 10:33:00 -0700
commit07c0f308ecce579ec69ad53541332ccf506ca280 (patch)
tree951b9367292f9db179e485588e821118abaa1517 /tensorflow/python/training
parent5f004516a3c104ed7632ff4a31b65c49f620d199 (diff)
Make checkpointable list and dict wrappers copyable and deepcopyable
Also tests copying Checkpointable objects, which seems to just work. PiperOrigin-RevId: 212289140
Diffstat (limited to 'tensorflow/python/training')
-rw-r--r--tensorflow/python/training/checkpointable/data_structures.py43
-rw-r--r--tensorflow/python/training/checkpointable/data_structures_test.py99
2 files changed, 142 insertions, 0 deletions
diff --git a/tensorflow/python/training/checkpointable/data_structures.py b/tensorflow/python/training/checkpointable/data_structures.py
index f06cbbfa15..c29e5db075 100644
--- a/tensorflow/python/training/checkpointable/data_structures.py
+++ b/tensorflow/python/training/checkpointable/data_structures.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import copy
import six
@@ -251,6 +252,12 @@ class List(CheckpointableDataStructure, collections.Sequence):
self._storage[index] = self._track_value(
element, name=self._name_element(index))
+ def __copy__(self):
+ return type(self)(copy.copy(self._storage))
+
+ def __deepcopy__(self, memo):
+ return type(self)(copy.deepcopy(self._storage, memo))
+
def _make_storage(self, *args, **kwargs):
"""Determines the backing storage (overridden in subclasses)."""
return list(*args, **kwargs)
@@ -325,6 +332,20 @@ class _ListWrapper(List, collections.MutableSequence,
super(_ListWrapper, self).__init__(wrapped_list)
self._last_wrapped_list_snapshot = list(self._storage)
+ # pylint: disable=protected-access
+ def __copy__(self):
+ copied = super(_ListWrapper, self).__copy__()
+ copied._non_append_mutation = self._non_append_mutation
+ copied._external_modification = self._external_modification
+ return copied
+
+ def __deepcopy__(self, memo):
+ copied = super(_ListWrapper, self).__deepcopy__(memo)
+ copied._non_append_mutation = self._non_append_mutation
+ copied._external_modification = self._external_modification
+ return copied
+ # pylint: enable=protected-access
+
def _make_storage(self, wrapped_list):
"""Use the user's original list for storage."""
return wrapped_list
@@ -449,6 +470,12 @@ class Mapping(CheckpointableDataStructure, collections.Mapping):
value, name=self._name_element(key))
for key, value in self._storage.items()})
+ def __copy__(self):
+ return type(self)(copy.copy(self._storage))
+
+ def __deepcopy__(self, memo):
+ return type(self)(copy.deepcopy(self._storage, memo))
+
def _make_storage(self, *args, **kwargs):
return dict(*args, **kwargs)
@@ -525,6 +552,22 @@ class _DictWrapper(Mapping, collections.MutableMapping):
super(_DictWrapper, self).__init__(wrapped_dict)
self._update_snapshot()
+ # pylint: disable=protected-access
+ def __copy__(self):
+ copied = super(_DictWrapper, self).__copy__()
+ copied._non_append_mutation = self._non_append_mutation
+ copied._external_modification = self._external_modification
+ copied._non_string_key = self._non_string_key
+ return copied
+
+ def __deepcopy__(self, memo):
+ copied = super(_DictWrapper, self).__deepcopy__(memo)
+ copied._non_append_mutation = self._non_append_mutation
+ copied._external_modification = self._external_modification
+ copied._non_string_key = self._non_string_key
+ return copied
+ # pylint: enable=protected-access
+
def _make_storage(self, wrapped_dict):
"""Re-use the wrapped dict for storage (to force them to be in sync)."""
return wrapped_dict
diff --git a/tensorflow/python/training/checkpointable/data_structures_test.py b/tensorflow/python/training/checkpointable/data_structures_test.py
index 4638917b4c..5597c7c772 100644
--- a/tensorflow/python/training/checkpointable/data_structures_test.py
+++ b/tensorflow/python/training/checkpointable/data_structures_test.py
@@ -16,6 +16,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import copy
import os
import numpy
@@ -424,6 +425,104 @@ class MappingTests(test.TestCase):
new_dict.update(model.d)
self.assertEqual({1: 3}, new_dict)
+ def testListShallowCopy(self):
+ root = tracking.Checkpointable()
+ orig_list = [[1.]]
+ root.a = orig_list
+ copied = copy.copy(root.a)
+ self.assertAllEqual([[1.]], copied)
+ self.assertIsNot(root.a, copied)
+ self.assertIs(root.a[0], copied[0])
+
+ # Dirtiness should be inherited
+ util.list_objects(root.a)
+ orig_list.append(1.)
+ with self.assertRaises(ValueError):
+ util.list_objects(root.a)
+ with self.assertRaises(ValueError):
+ util.list_objects(copy.copy(root.a))
+
+ def testListDeepCopy(self):
+ root = tracking.Checkpointable()
+ orig_list = [[1.]]
+ root.a = orig_list
+ copied = copy.deepcopy(root.a)
+ self.assertAllEqual([[1.]], copied)
+ self.assertIsNot(root.a, copied)
+ self.assertIsNot(root.a[0], copied[0])
+
+ # Dirtiness should be inherited
+ util.list_objects(root.a)
+ orig_list.append(1.)
+ with self.assertRaises(ValueError):
+ util.list_objects(root.a)
+ with self.assertRaises(ValueError):
+ util.list_objects(copy.deepcopy(root.a))
+
+ def testDictShallowCopy(self):
+ root = tracking.Checkpointable()
+ orig_dict = {"a": [1.]}
+ root.a = orig_dict
+ copied = copy.copy(root.a)
+ self.assertAllEqual([1.], copied["a"])
+ self.assertIsNot(root.a, copied)
+ self.assertIs(root.a["a"], copied["a"])
+
+ # Dirtiness should be inherited
+ util.list_objects(root.a)
+ orig_dict["b"] = []
+ with self.assertRaises(ValueError):
+ util.list_objects(root.a)
+ with self.assertRaises(ValueError):
+ util.list_objects(copy.copy(root.a))
+
+ def testDictDeepCopy(self):
+ root = tracking.Checkpointable()
+ orig_dict = {"a": [1.]}
+ root.a = orig_dict
+ copied = copy.deepcopy(root.a)
+ self.assertAllEqual([1.], copied["a"])
+ self.assertIsNot(root.a, copied)
+ self.assertIsNot(root.a["a"], copied["a"])
+
+ # Dirtiness should be inherited
+ util.list_objects(root.a)
+ orig_dict["b"] = []
+ with self.assertRaises(ValueError):
+ util.list_objects(root.a)
+ with self.assertRaises(ValueError):
+ util.list_objects(copy.deepcopy(root.a))
+
+ def testShallowCopyCheckpointable(self):
+ original = tracking.Checkpointable()
+ original_sub = tracking.Checkpointable()
+ original.a = [[1.]]
+ original.b = {"a": original_sub}
+ shallow_copied = copy.copy(original)
+ self.assertIs(original_sub, shallow_copied.b["a"])
+ self.assertIsNot(original, shallow_copied)
+ self.assertEqual([[1.]], shallow_copied.a)
+ shallow_deps = util.list_objects(shallow_copied)
+ self.assertIn(shallow_copied.a, shallow_deps)
+ self.assertIn(shallow_copied.b, shallow_deps)
+ self.assertIn(shallow_copied.b["a"], shallow_deps)
+
+ def testDeepCopyCheckpointable(self):
+ original = tracking.Checkpointable()
+ original_sub = tracking.Checkpointable()
+ original.a = [[1.]]
+ original.b = {"a": original_sub}
+ deep_copied = copy.deepcopy(original)
+ self.assertIsNot(original, deep_copied)
+ self.assertIsNot(original_sub, deep_copied.b["a"])
+ self.assertEqual([[1.]], deep_copied.a)
+ self.assertIsInstance(deep_copied.b["a"], tracking.Checkpointable)
+ deps = util.list_objects(deep_copied)
+ self.assertIn(deep_copied.a, deps)
+ self.assertIn(deep_copied.b, deps)
+ self.assertIn(deep_copied.b["a"], deps)
+ self.assertNotIn(original_sub, deps)
+
def testConstructableFromSequence(self):
result = data_structures._DictWrapper([(1, 2), (3, 4)])
self.assertIsInstance(result, dict)