diff options
-rw-r--r-- | tensorflow/python/eager/backprop.py | 35 |
-rw-r--r-- | tensorflow/python/framework/test_util.py | 3 |
2 files changed, 30 insertions, 8 deletions
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index d76fecf126..d8e13d7231 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import functools import operator import threading @@ -42,6 +43,26 @@ from tensorflow.python.util import nest from tensorflow.python.util import tf_inspect +class _TensorCache(object): + """Simple FIFO cache which evicts the oldest item once max_items is exceeded.""" + + def __init__(self, max_items=256): + self._data = collections.OrderedDict() + self._max_items = max_items if max_items else 256 + + def put(self, key, value): + self._data[key] = value + + if len(self._data) > self._max_items: + self._data.popitem(last=False) + + def get(self, key): + return self._data.get(key, None) + + def flush(self): + self._data = collections.OrderedDict() + + _op_attr_type_cache = {} @@ -734,8 +755,7 @@ def _num_elements(grad): raise ValueError("`grad` not a Tensor or IndexedSlices.") -_last_zero_shape_dtype = [None, None] -_last_zero = [None] +_zeros_cache = _TensorCache() def _fast_fill(value, shape, dtype): @@ -744,14 +764,17 @@ def _fast_fill(value, shape, dtype): def _zeros(shape, dtype): """Wraps array_ops.zeros to cache last zero for a given shape and dtype.""" + device = context.context().device_name if dtype == dtypes.variant: # TODO(apassos): need to save enough information about variant tensors to do # a zeros return None - if [shape, dtype] != _last_zero_shape_dtype: - _last_zero_shape_dtype[:] = [shape, dtype] - _last_zero[0] = _fast_fill(0, shape, dtype) - return _last_zero[0] + cache_key = shape, dtype, device + cached = _zeros_cache.get(cache_key) + if cached is None: + cached = _fast_fill(0, shape, dtype) + _zeros_cache.put(cache_key, cached) + return cached def _ones(shape, dtype): diff --git a/tensorflow/python/framework/test_util.py 
b/tensorflow/python/framework/test_util.py index bfdd98819e..c09e2d8084 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -463,8 +463,7 @@ def assert_no_new_tensors(f): f(self, **kwargs) # Make an effort to clear caches, which would otherwise look like leaked # Tensors. - backprop._last_zero = [None] - backprop._shape_dtype = [None, None] + backprop._zeros_cache.flush() context.get_default_context().scalar_cache().clear() gc.collect() tensors_after = [ |