aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/checkpointable/data_structures.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/checkpointable/data_structures.py')
-rw-r--r--tensorflow/python/training/checkpointable/data_structures.py43
1 files changed, 43 insertions, 0 deletions
diff --git a/tensorflow/python/training/checkpointable/data_structures.py b/tensorflow/python/training/checkpointable/data_structures.py
index f06cbbfa15..c29e5db075 100644
--- a/tensorflow/python/training/checkpointable/data_structures.py
+++ b/tensorflow/python/training/checkpointable/data_structures.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import copy
import six
@@ -251,6 +252,12 @@ class List(CheckpointableDataStructure, collections.Sequence):
self._storage[index] = self._track_value(
element, name=self._name_element(index))
+ def __copy__(self):
+ return type(self)(copy.copy(self._storage))
+
+ def __deepcopy__(self, memo):
+ return type(self)(copy.deepcopy(self._storage, memo))
+
def _make_storage(self, *args, **kwargs):
"""Determines the backing storage (overridden in subclasses)."""
return list(*args, **kwargs)
@@ -325,6 +332,20 @@ class _ListWrapper(List, collections.MutableSequence,
super(_ListWrapper, self).__init__(wrapped_list)
self._last_wrapped_list_snapshot = list(self._storage)
+ # pylint: disable=protected-access
+ def __copy__(self):
+ copied = super(_ListWrapper, self).__copy__()
+ copied._non_append_mutation = self._non_append_mutation
+ copied._external_modification = self._external_modification
+ return copied
+
+ def __deepcopy__(self, memo):
+ copied = super(_ListWrapper, self).__deepcopy__(memo)
+ copied._non_append_mutation = self._non_append_mutation
+ copied._external_modification = self._external_modification
+ return copied
+ # pylint: enable=protected-access
+
def _make_storage(self, wrapped_list):
"""Use the user's original list for storage."""
return wrapped_list
@@ -449,6 +470,12 @@ class Mapping(CheckpointableDataStructure, collections.Mapping):
value, name=self._name_element(key))
for key, value in self._storage.items()})
+ def __copy__(self):
+ return type(self)(copy.copy(self._storage))
+
+ def __deepcopy__(self, memo):
+ return type(self)(copy.deepcopy(self._storage, memo))
+
def _make_storage(self, *args, **kwargs):
return dict(*args, **kwargs)
@@ -525,6 +552,22 @@ class _DictWrapper(Mapping, collections.MutableMapping):
super(_DictWrapper, self).__init__(wrapped_dict)
self._update_snapshot()
+ # pylint: disable=protected-access
+ def __copy__(self):
+ copied = super(_DictWrapper, self).__copy__()
+ copied._non_append_mutation = self._non_append_mutation
+ copied._external_modification = self._external_modification
+ copied._non_string_key = self._non_string_key
+ return copied
+
+ def __deepcopy__(self, memo):
+ copied = super(_DictWrapper, self).__deepcopy__(memo)
+ copied._non_append_mutation = self._non_append_mutation
+ copied._external_modification = self._external_modification
+ copied._non_string_key = self._non_string_key
+ return copied
+ # pylint: enable=protected-access
+
def _make_storage(self, wrapped_dict):
"""Re-use the wrapped dict for storage (to force them to be in sync)."""
return wrapped_dict