diff options
Diffstat (limited to 'tensorflow/python/training/checkpointable/data_structures.py')
-rw-r--r-- | tensorflow/python/training/checkpointable/data_structures.py | 43 |
1 files changed, 43 insertions, 0 deletions
diff --git a/tensorflow/python/training/checkpointable/data_structures.py b/tensorflow/python/training/checkpointable/data_structures.py index f06cbbfa15..c29e5db075 100644 --- a/tensorflow/python/training/checkpointable/data_structures.py +++ b/tensorflow/python/training/checkpointable/data_structures.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function import collections +import copy import six @@ -251,6 +252,12 @@ class List(CheckpointableDataStructure, collections.Sequence): self._storage[index] = self._track_value( element, name=self._name_element(index)) + def __copy__(self): + return type(self)(copy.copy(self._storage)) + + def __deepcopy__(self, memo): + return type(self)(copy.deepcopy(self._storage, memo)) + def _make_storage(self, *args, **kwargs): """Determines the backing storage (overridden in subclasses).""" return list(*args, **kwargs) @@ -325,6 +332,20 @@ class _ListWrapper(List, collections.MutableSequence, super(_ListWrapper, self).__init__(wrapped_list) self._last_wrapped_list_snapshot = list(self._storage) + # pylint: disable=protected-access + def __copy__(self): + copied = super(_ListWrapper, self).__copy__() + copied._non_append_mutation = self._non_append_mutation + copied._external_modification = self._external_modification + return copied + + def __deepcopy__(self, memo): + copied = super(_ListWrapper, self).__deepcopy__(memo) + copied._non_append_mutation = self._non_append_mutation + copied._external_modification = self._external_modification + return copied + # pylint: enable=protected-access + def _make_storage(self, wrapped_list): """Use the user's original list for storage.""" return wrapped_list @@ -449,6 +470,12 @@ class Mapping(CheckpointableDataStructure, collections.Mapping): value, name=self._name_element(key)) for key, value in self._storage.items()}) + def __copy__(self): + return type(self)(copy.copy(self._storage)) + + def __deepcopy__(self, memo): + return type(self)(copy.deepcopy(self._storage, memo)) + def _make_storage(self, *args, **kwargs): return dict(*args, **kwargs) @@ -525,6 +552,22 @@ class _DictWrapper(Mapping, collections.MutableMapping): super(_DictWrapper, self).__init__(wrapped_dict) self._update_snapshot() + # pylint: disable=protected-access + def __copy__(self): + copied = super(_DictWrapper, self).__copy__() + copied._non_append_mutation = self._non_append_mutation + copied._external_modification = self._external_modification + copied._non_string_key = self._non_string_key + return copied + + def __deepcopy__(self, memo): + copied = super(_DictWrapper, self).__deepcopy__(memo) + copied._non_append_mutation = self._non_append_mutation + copied._external_modification = self._external_modification + copied._non_string_key = self._non_string_key + return copied + # pylint: enable=protected-access + def _make_storage(self, wrapped_dict): """Re-use the wrapped dict for storage (to force them to be in sync).""" return wrapped_dict |