aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager
diff options
context:
space:
mode:
authorGravatar Akshay Modi <nareshmodi@google.com>2018-09-17 18:42:45 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-17 18:46:51 -0700
commit3b7ca4b86416f6b6153de90bc1df6e6e5b41934c (patch)
treee1eb70719e1b8b4edc572d02e18bcedbd8c00d9b /tensorflow/python/eager
parent71fab28dc4741dedf13fea732f6b134608719bc7 (diff)
Num elements fastpath for eager tensors.
PiperOrigin-RevId: 213377426
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r--tensorflow/python/eager/pywrap_tensor.cc41
-rw-r--r--tensorflow/python/eager/pywrap_tensor.h5
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc9
3 files changed, 32 insertions, 23 deletions
diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc
index f34ce6af79..5f44bd4fec 100644
--- a/tensorflow/python/eager/pywrap_tensor.cc
+++ b/tensorflow/python/eager/pywrap_tensor.cc
@@ -516,25 +516,13 @@ static PyObject* EagerTensor_rank(EagerTensor* self) {
// Getter for `_num_elements`.
static PyObject* EagerTensor_num_elements(EagerTensor* self) {
auto handle = self->handle;
- int n = TFE_TensorHandleNumDims(handle, self->status);
+ int n = TFE_TensorHandleNumElements(handle, self->status);
if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
// Cleanup self->status before returning.
TF_SetStatus(self->status, TF_OK, "");
return nullptr;
}
- tensorflow::int64 value = 1;
- if (PyErr_Occurred()) return nullptr;
- for (int i = 0; i < n; ++i) {
- int64_t dim = TFE_TensorHandleDim(handle, i, self->status);
- if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
- // Cleanup self->status before returning.
- TF_SetStatus(self->status, TF_OK, "");
- PyErr_SetString(PyExc_RuntimeError, "Error while iterating dimensions");
- return nullptr;
- }
- value *= dim;
- }
- return PyLong_FromLongLong(value);
+ return PyLong_FromLongLong(n);
}
static PyObject* EagerTensor_tensor_handle(EagerTensor* self, void* unused) {
@@ -777,17 +765,34 @@ PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) {
return reinterpret_cast<PyObject*>(t);
}
-tensorflow::int64 EagerTensor_id(const PyObject* tensor) {
- CHECK(EagerTensor_CheckExact(tensor));
+tensorflow::int64 PyEagerTensor_ID(const PyObject* tensor) {
+ DCHECK(EagerTensor_CheckExact(tensor));
return reinterpret_cast<const EagerTensor*>(tensor)->id;
}
-tensorflow::DataType EagerTensor_dtype(const PyObject* tensor) {
- CHECK(EagerTensor_CheckExact(tensor));
+tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor) {
+ DCHECK(EagerTensor_CheckExact(tensor));
return static_cast<tensorflow::DataType>(TFE_TensorHandleDataType(
reinterpret_cast<const EagerTensor*>(tensor)->handle));
}
+tensorflow::int64 PyEagerTensor_NumElements(const PyObject* tensor) {
+ DCHECK(EagerTensor_CheckExact(tensor));
+ const EagerTensor* as_c_eager_tensor =
+ reinterpret_cast<const EagerTensor*>(tensor);
+ tensorflow::int64 result = TFE_TensorHandleNumElements(
+ as_c_eager_tensor->handle, as_c_eager_tensor->status);
+
+ if (MaybeRaiseExceptionFromTFStatus(as_c_eager_tensor->status,
+ PyExc_ValueError)) {
+ // Cleanup status before returning.
+ TF_SetStatus(as_c_eager_tensor->status, TF_OK, "");
+ return -1;
+ }
+
+ return result;
+}
+
PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
if (!PyType_Check(base_class)) {
PyErr_SetString(
diff --git a/tensorflow/python/eager/pywrap_tensor.h b/tensorflow/python/eager/pywrap_tensor.h
index bc042eb19e..4eaa1ba536 100644
--- a/tensorflow/python/eager/pywrap_tensor.h
+++ b/tensorflow/python/eager/pywrap_tensor.h
@@ -21,8 +21,9 @@ limitations under the License.
#include "tensorflow/python/lib/core/numpy.h"
bool EagerTensor_CheckExact(const PyObject* o);
-tensorflow::int64 EagerTensor_id(const PyObject* tensor);
-tensorflow::DataType EagerTensor_dtype(const PyObject* tensor);
+tensorflow::int64 PyEagerTensor_ID(const PyObject* tensor);
+tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor);
+tensorflow::int64 PyEagerTensor_NumElements(const PyObject* tensor);
namespace tensorflow {
TFE_TensorHandle* ConvertToEagerTensor(PyObject* value, PyObject* dtype);
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 99b46159a9..a0f6be459e 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -860,7 +860,7 @@ static tensorflow::int64 MakeInt(PyObject* integer) {
static tensorflow::int64 FastTensorId(PyObject* tensor) {
if (EagerTensor_CheckExact(tensor)) {
- return EagerTensor_id(tensor);
+ return PyEagerTensor_ID(tensor);
}
PyObject* id_field = PyObject_GetAttrString(tensor, "_id");
if (id_field == nullptr) {
@@ -873,7 +873,7 @@ static tensorflow::int64 FastTensorId(PyObject* tensor) {
static tensorflow::DataType FastTensorDtype(PyObject* tensor) {
if (EagerTensor_CheckExact(tensor)) {
- return EagerTensor_dtype(tensor);
+ return PyEagerTensor_Dtype(tensor);
}
PyObject* dtype_field = PyObject_GetAttrString(tensor, "dtype");
if (dtype_field == nullptr) {
@@ -1178,7 +1178,7 @@ void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) {
static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
if (EagerTensor_CheckExact(tensor)) {
TFE_TensorHandle* t = EagerTensor_Handle(tensor);
- tensorflow::int64 id = EagerTensor_id(tensor);
+ tensorflow::int64 id = PyEagerTensor_ID(tensor);
tensorflow::TensorShape tensor_shape;
const tensorflow::Status status = t->handle->Shape(&tensor_shape);
@@ -1400,6 +1400,9 @@ class PyVSpace
}
tensorflow::int64 NumElements(PyObject* tensor) const final {
+ if (EagerTensor_CheckExact(tensor)) {
+ return PyEagerTensor_NumElements(tensor);
+ }
PyObject* arglist =
Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
PyObject* result = PyEval_CallObject(num_elements_, arglist);