diff options
author | Akshay Modi <nareshmodi@google.com> | 2018-09-17 18:42:45 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-17 18:46:51 -0700 |
commit | 3b7ca4b86416f6b6153de90bc1df6e6e5b41934c (patch) | |
tree | e1eb70719e1b8b4edc572d02e18bcedbd8c00d9b /tensorflow/python/eager | |
parent | 71fab28dc4741dedf13fea732f6b134608719bc7 (diff) |
Num elements fastpath for eager tensors.
PiperOrigin-RevId: 213377426
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r-- | tensorflow/python/eager/pywrap_tensor.cc | 41 | ||||
-rw-r--r-- | tensorflow/python/eager/pywrap_tensor.h | 5 | ||||
-rw-r--r-- | tensorflow/python/eager/pywrap_tfe_src.cc | 9 |
3 files changed, 32 insertions, 23 deletions
diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index f34ce6af79..5f44bd4fec 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -516,25 +516,13 @@ static PyObject* EagerTensor_rank(EagerTensor* self) { // Getter for `_num_elements`. static PyObject* EagerTensor_num_elements(EagerTensor* self) { auto handle = self->handle; - int n = TFE_TensorHandleNumDims(handle, self->status); + int n = TFE_TensorHandleNumElements(handle, self->status); if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) { // Cleanup self->status before returning. TF_SetStatus(self->status, TF_OK, ""); return nullptr; } - tensorflow::int64 value = 1; - if (PyErr_Occurred()) return nullptr; - for (int i = 0; i < n; ++i) { - int64_t dim = TFE_TensorHandleDim(handle, i, self->status); - if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) { - // Cleanup self->status before returning. - TF_SetStatus(self->status, TF_OK, ""); - PyErr_SetString(PyExc_RuntimeError, "Error while iterating dimensions"); - return nullptr; - } - value *= dim; - } - return PyLong_FromLongLong(value); + return PyLong_FromLongLong(n); } static PyObject* EagerTensor_tensor_handle(EagerTensor* self, void* unused) { @@ -777,17 +765,34 @@ PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) { return reinterpret_cast<PyObject*>(t); } -tensorflow::int64 EagerTensor_id(const PyObject* tensor) { - CHECK(EagerTensor_CheckExact(tensor)); +tensorflow::int64 PyEagerTensor_ID(const PyObject* tensor) { + DCHECK(EagerTensor_CheckExact(tensor)); return reinterpret_cast<const EagerTensor*>(tensor)->id; } -tensorflow::DataType EagerTensor_dtype(const PyObject* tensor) { - CHECK(EagerTensor_CheckExact(tensor)); +tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor) { + DCHECK(EagerTensor_CheckExact(tensor)); return static_cast<tensorflow::DataType>(TFE_TensorHandleDataType( reinterpret_cast<const EagerTensor*>(tensor)->handle)); } +tensorflow::int64 PyEagerTensor_NumElements(const PyObject* tensor) { + DCHECK(EagerTensor_CheckExact(tensor)); + const EagerTensor* as_c_eager_tensor = + reinterpret_cast<const EagerTensor*>(tensor); + tensorflow::int64 result = TFE_TensorHandleNumElements( + as_c_eager_tensor->handle, as_c_eager_tensor->status); + + if (MaybeRaiseExceptionFromTFStatus(as_c_eager_tensor->status, + PyExc_ValueError)) { + // Cleanup status before returning. + TF_SetStatus(as_c_eager_tensor->status, TF_OK, ""); + return -1; + } + + return result; +} + PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) { if (!PyType_Check(base_class)) { PyErr_SetString( diff --git a/tensorflow/python/eager/pywrap_tensor.h b/tensorflow/python/eager/pywrap_tensor.h index bc042eb19e..4eaa1ba536 100644 --- a/tensorflow/python/eager/pywrap_tensor.h +++ b/tensorflow/python/eager/pywrap_tensor.h @@ -21,8 +21,9 @@ limitations under the License. #include "tensorflow/python/lib/core/numpy.h" bool EagerTensor_CheckExact(const PyObject* o); -tensorflow::int64 EagerTensor_id(const PyObject* tensor); -tensorflow::DataType EagerTensor_dtype(const PyObject* tensor); +tensorflow::int64 PyEagerTensor_ID(const PyObject* tensor); +tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor); +tensorflow::int64 PyEagerTensor_NumElements(const PyObject* tensor); namespace tensorflow { TFE_TensorHandle* ConvertToEagerTensor(PyObject* value, PyObject* dtype); diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 99b46159a9..a0f6be459e 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -860,7 +860,7 @@ static tensorflow::int64 MakeInt(PyObject* integer) { static tensorflow::int64 FastTensorId(PyObject* tensor) { if (EagerTensor_CheckExact(tensor)) { - return EagerTensor_id(tensor); + return PyEagerTensor_ID(tensor); } PyObject* id_field = PyObject_GetAttrString(tensor, "_id"); if (id_field == nullptr) { @@ -873,7 +873,7 @@ static tensorflow::int64 FastTensorId(PyObject* tensor) { static tensorflow::DataType FastTensorDtype(PyObject* tensor) { if (EagerTensor_CheckExact(tensor)) { - return EagerTensor_dtype(tensor); + return PyEagerTensor_Dtype(tensor); } PyObject* dtype_field = PyObject_GetAttrString(tensor, "dtype"); if (dtype_field == nullptr) { @@ -1178,7 +1178,7 @@ void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) { static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) { if (EagerTensor_CheckExact(tensor)) { TFE_TensorHandle* t = EagerTensor_Handle(tensor); - tensorflow::int64 id = EagerTensor_id(tensor); + tensorflow::int64 id = PyEagerTensor_ID(tensor); tensorflow::TensorShape tensor_shape; const tensorflow::Status status = t->handle->Shape(&tensor_shape); @@ -1400,6 +1400,9 @@ class PyVSpace } tensorflow::int64 NumElements(PyObject* tensor) const final { + if (EagerTensor_CheckExact(tensor)) { + return PyEagerTensor_NumElements(tensor); + } PyObject* arglist = Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor)); PyObject* result = PyEval_CallObject(num_elements_, arglist); |