diff options
-rw-r--r-- | tensorflow/python/eager/context.py | 14 | ||||
-rw-r--r-- | tensorflow/python/eager/core_test.py | 11 | ||||
-rw-r--r-- | tensorflow/python/eager/pywrap_tensor.cc | 26 | ||||
-rw-r--r-- | tensorflow/python/framework/ops.py | 13 |
4 files changed, 59 insertions, 5 deletions
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 1b8542ca3a..778ff85342 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -56,14 +56,18 @@ SYNC = 0 ASYNC = 1 -class _TensorCache(object): +class _EagerTensorCache(object): """Simple cache which evicts items based on length in a FIFO manner.""" - def __init__(self, max_items=256): + def __init__(self, max_items=256, max_tensor_size=10000): self._data = collections.OrderedDict() - self._max_items = max_items if max_items else 256 + self._max_items = max_items + self._max_tensor_size = max_tensor_size def put(self, key, value): + if value._num_elements() > self._max_tensor_size: # pylint: disable=protected-access + return + self._data[key] = value if len(self._data) > self._max_items: @@ -90,8 +94,8 @@ class _EagerContext(threading.local): self.recording_summaries = False self.summary_writer_resource = None self.scalar_cache = {} - self.ones_rank_cache = _TensorCache() - self.zeros_cache = _TensorCache() + self.ones_rank_cache = _EagerTensorCache() + self.zeros_cache = _EagerTensorCache() self.execution_mode = None diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py index cbd6f4cb75..fb5442b646 100644 --- a/tensorflow/python/eager/core_test.py +++ b/tensorflow/python/eager/core_test.py @@ -689,5 +689,16 @@ class SendRecvTest(test_util.TensorFlowTestCase): 2.0) +class EagerTensorCacheTest(test_util.TensorFlowTestCase): + + def testCacheSkipsTensorsTooLarge(self): + cache = context._EagerTensorCache(max_items=100, max_tensor_size=3) + cache.put('1', array_ops.zeros((2, 2))) + self.assertEqual(cache.get('1'), None) + + cache.put('2', array_ops.zeros((2))) + self.assertNotEqual(cache.get('2'), None) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index c12bf89f8f..9e66ae19fa 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -474,6 +474,30 @@ static PyObject* EagerTensor_rank(EagerTensor* self) { #endif } +// Getter for `_num_elements`. +static PyObject* EagerTensor_num_elements(EagerTensor* self) { + auto handle = self->handle; + int n = TFE_TensorHandleNumDims(handle, self->status); + if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) { + // Cleanup self->status before returning. + TF_SetStatus(self->status, TF_OK, ""); + return nullptr; + } + tensorflow::int64 value = 1; + if (PyErr_Occurred()) return nullptr; + for (int i = 0; i < n; ++i) { + int64_t dim = TFE_TensorHandleDim(handle, i, self->status); + if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) { + // Cleanup self->status before returning. + TF_SetStatus(self->status, TF_OK, ""); + PyErr_SetString(PyExc_RuntimeError, "Error while iterating dimensions"); + return nullptr; + } + value *= dim; + } + return PyLong_FromLongLong(value); +} + static PyObject* EagerTensor_tensor_handle(EagerTensor* self, void* unused) { Py_INCREF(self->handle_data); return self->handle_data; @@ -592,6 +616,8 @@ static PyMethodDef EagerTensor_methods[] = { {"_rank", (PyCFunction)EagerTensor_rank, METH_NOARGS, PyDoc_STR("_rank")}, {"_copy_to_device", (PyCFunction)EagerTensor_copy_to_device, METH_VARARGS | METH_KEYWORDS, PyDoc_STR("_copy_to_device")}, + {"_num_elements", (PyCFunction)EagerTensor_num_elements, METH_NOARGS, + PyDoc_STR("_num_elements")}, {nullptr, nullptr}, }; diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index c134536bd3..4cfd639bf9 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -802,6 +802,19 @@ class _EagerTensorBase(Tensor): """ raise NotImplementedError() + def _num_elements(self): + """Number of elements of this Tensor. + + Unlike regular Tensors, the number of elements is always known for + EagerTensors. + + This is more performant than tensor.shape.num_elements + + Returns: + Long - num elements in the tensor + """ + raise NotImplementedError() + def _copy_to_device(self, context, device): # pylint: disable=redefined-outer-name raise NotImplementedError() |