 tensorflow/python/eager/backprop.py      | 35 +++++++++++++++++++++++++++++------
 tensorflow/python/framework/test_util.py |  3 +--
 2 files changed, 30 insertions(+), 8 deletions(-)
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index d76fecf126..d8e13d7231 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
import functools
import operator
import threading
@@ -42,6 +43,26 @@ from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
+class _TensorCache(object):
+ """Simple cache which evicts items based on length in a FIFO manner."""
+
+ def __init__(self, max_items=256):
+ self._data = collections.OrderedDict()
+ self._max_items = max_items if max_items else 256
+
+ def put(self, key, value):
+ self._data[key] = value
+
+ if len(self._data) > self._max_items:
+ self._data.popitem(last=False)
+
+ def get(self, key):
+ return self._data.get(key, None)
+
+ def flush(self):
+ self._data = collections.OrderedDict()  # keep an OrderedDict so put() can still evict FIFO
+
+
_op_attr_type_cache = {}
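
For reference, a small self-contained sketch (not part of the commit) of the OrderedDict FIFO eviction that put() above relies on; the max_items=2 limit and the string keys are only illustrative.

import collections

data = collections.OrderedDict()
max_items = 2
for key, value in [("a", 1), ("b", 2), ("c", 3)]:
    data[key] = value
    if len(data) > max_items:
        data.popitem(last=False)  # evict the oldest insertion, i.e. FIFO
assert list(data) == ["b", "c"]  # "a" went in first, so it was evicted first
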
@@ -734,8 +755,7 @@ def _num_elements(grad):
raise ValueError("`grad` not a Tensor or IndexedSlices.")
-_last_zero_shape_dtype = [None, None]
-_last_zero = [None]
+_zeros_cache = _TensorCache()
def _fast_fill(value, shape, dtype):
@@ -744,14 +764,17 @@ def _fast_fill(value, shape, dtype):
def _zeros(shape, dtype):
"""Wraps array_ops.zeros to cache last zero for a given shape and dtype."""
+ device = context.context().device_name
if dtype == dtypes.variant:
# TODO(apassos): need to save enough information about variant tensors to do
# a zeros
return None
- if [shape, dtype] != _last_zero_shape_dtype:
- _last_zero_shape_dtype[:] = [shape, dtype]
- _last_zero[0] = _fast_fill(0, shape, dtype)
- return _last_zero[0]
+ cache_key = shape, dtype, device
+ cached = _zeros_cache.get(cache_key)
+ if cached is None:
+ cached = _fast_fill(0, shape, dtype)
+ _zeros_cache.put(cache_key, cached)
+ return cached
def _ones(shape, dtype):
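
As a rough illustration of the lookup-or-fill pattern the new _zeros uses, here is a standalone sketch; assumptions: numpy stands in for _fast_fill/array_ops, a plain dict stands in for the FIFO _TensorCache, and the device string is passed in rather than read from context.context().device_name. Keying on device as well as shape and dtype keeps a zeros tensor cached for one device from being reused on another.

import numpy as np

_zeros_cache = {}  # stand-in for the FIFO _TensorCache defined above

def zeros_cached(shape, dtype, device):
    # Key on shape, dtype and device, mirroring the cache_key in the diff.
    cache_key = shape, dtype, device
    cached = _zeros_cache.get(cache_key)
    if cached is None:
        cached = np.zeros(shape, dtype)  # stand-in for _fast_fill(0, shape, dtype)
        _zeros_cache[cache_key] = cached
    return cached

a = zeros_cached((2, 3), "float32", "/device:CPU:0")
b = zeros_cached((2, 3), "float32", "/device:CPU:0")
assert a is b  # same key, so the cached array is reused
c = zeros_cached((2, 3), "float32", "/device:GPU:0")
assert c is not a  # different device key, so a fresh array is filled
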
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index bfdd98819e..c09e2d8084 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -463,8 +463,7 @@ def assert_no_new_tensors(f):
f(self, **kwargs)
# Make an effort to clear caches, which would otherwise look like leaked
# Tensors.
- backprop._last_zero = [None]
- backprop._shape_dtype = [None, None]
+ backprop._zeros_cache.flush()
context.get_default_context().scalar_cache().clear()
gc.collect()
tensors_after = [
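
A hedged sketch of the cleanup pattern the test_util change relies on: any module-level cache has to be emptied before counting what survives a test, otherwise cached zeros look like leaked Tensors, and with a single cache object that is one flush() call instead of resetting two module-level lists by name. The object-count probe below is a toy stand-in for assert_no_new_tensors, not its real implementation.

import collections
import gc

_zeros_cache = collections.OrderedDict()  # toy stand-in for backprop._zeros_cache

def surviving_object_delta(test_fn):
    """Run test_fn, clear the cache, and report how many gc-tracked objects remain."""
    gc.collect()
    before = len(gc.get_objects())
    test_fn()
    _zeros_cache.clear()  # mirrors backprop._zeros_cache.flush() in the decorator
    gc.collect()
    return len(gc.get_objects()) - before
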