aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/eager/context.py14
-rw-r--r--tensorflow/python/eager/core_test.py11
-rw-r--r--tensorflow/python/eager/pywrap_tensor.cc26
-rw-r--r--tensorflow/python/framework/ops.py13
4 files changed, 59 insertions, 5 deletions
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 1b8542ca3a..778ff85342 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -56,14 +56,18 @@ SYNC = 0
ASYNC = 1
-class _TensorCache(object):
+class _EagerTensorCache(object):
"""Simple cache which evicts items based on length in a FIFO manner."""
- def __init__(self, max_items=256):
+ def __init__(self, max_items=256, max_tensor_size=10000):
self._data = collections.OrderedDict()
- self._max_items = max_items if max_items else 256
+ self._max_items = max_items
+ self._max_tensor_size = max_tensor_size
def put(self, key, value):
+ if value._num_elements() > self._max_tensor_size: # pylint: disable=protected-access
+ return
+
self._data[key] = value
if len(self._data) > self._max_items:
@@ -90,8 +94,8 @@ class _EagerContext(threading.local):
self.recording_summaries = False
self.summary_writer_resource = None
self.scalar_cache = {}
- self.ones_rank_cache = _TensorCache()
- self.zeros_cache = _TensorCache()
+ self.ones_rank_cache = _EagerTensorCache()
+ self.zeros_cache = _EagerTensorCache()
self.execution_mode = None
diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py
index cbd6f4cb75..fb5442b646 100644
--- a/tensorflow/python/eager/core_test.py
+++ b/tensorflow/python/eager/core_test.py
@@ -689,5 +689,16 @@ class SendRecvTest(test_util.TensorFlowTestCase):
2.0)
+class EagerTensorCacheTest(test_util.TensorFlowTestCase):
+
+ def testCacheSkipsTensorsTooLarge(self):
+ cache = context._EagerTensorCache(max_items=100, max_tensor_size=3)
+ cache.put('1', array_ops.zeros((2, 2)))
+ self.assertEqual(cache.get('1'), None)
+
+ cache.put('2', array_ops.zeros((2)))
+ self.assertNotEqual(cache.get('2'), None)
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc
index c12bf89f8f..9e66ae19fa 100644
--- a/tensorflow/python/eager/pywrap_tensor.cc
+++ b/tensorflow/python/eager/pywrap_tensor.cc
@@ -474,6 +474,30 @@ static PyObject* EagerTensor_rank(EagerTensor* self) {
#endif
}
+// Getter for `_num_elements`.
+static PyObject* EagerTensor_num_elements(EagerTensor* self) {
+ auto handle = self->handle;
+ int n = TFE_TensorHandleNumDims(handle, self->status);
+ if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
+ // Cleanup self->status before returning.
+ TF_SetStatus(self->status, TF_OK, "");
+ return nullptr;
+ }
+ tensorflow::int64 value = 1;
+ if (PyErr_Occurred()) return nullptr;
+ for (int i = 0; i < n; ++i) {
+ int64_t dim = TFE_TensorHandleDim(handle, i, self->status);
+ if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
+ // Cleanup self->status before returning.
+ TF_SetStatus(self->status, TF_OK, "");
+ PyErr_SetString(PyExc_RuntimeError, "Error while iterating dimensions");
+ return nullptr;
+ }
+ value *= dim;
+ }
+ return PyLong_FromLongLong(value);
+}
+
static PyObject* EagerTensor_tensor_handle(EagerTensor* self, void* unused) {
Py_INCREF(self->handle_data);
return self->handle_data;
@@ -592,6 +616,8 @@ static PyMethodDef EagerTensor_methods[] = {
{"_rank", (PyCFunction)EagerTensor_rank, METH_NOARGS, PyDoc_STR("_rank")},
{"_copy_to_device", (PyCFunction)EagerTensor_copy_to_device,
METH_VARARGS | METH_KEYWORDS, PyDoc_STR("_copy_to_device")},
+ {"_num_elements", (PyCFunction)EagerTensor_num_elements, METH_NOARGS,
+ PyDoc_STR("_num_elements")},
{nullptr, nullptr},
};
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index c134536bd3..4cfd639bf9 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -802,6 +802,19 @@ class _EagerTensorBase(Tensor):
"""
raise NotImplementedError()
+ def _num_elements(self):
+ """Number of elements of this Tensor.
+
+ Unlike regular Tensors, the number of elements is always known for
+ EagerTensors.
+
+ This is more performant than tensor.shape.num_elements
+
+ Returns:
+ Long - num elements in the tensor
+ """
+ raise NotImplementedError()
+
def _copy_to_device(self, context, device): # pylint: disable=redefined-outer-name
raise NotImplementedError()