aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rwxr-xr-xtensorflow/c/eager/c_api.cc11
-rwxr-xr-xtensorflow/c/eager/c_api.h2
-rw-r--r--tensorflow/core/common_runtime/eager/tensor_handle.cc16
-rw-r--r--tensorflow/core/common_runtime/eager/tensor_handle.h1
-rw-r--r--tensorflow/python/eager/pywrap_tensor.cc41
-rw-r--r--tensorflow/python/eager/pywrap_tensor.h5
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc9
7 files changed, 61 insertions, 24 deletions
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 6f86ea80e5..0bf3d9542b 100755
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -375,6 +375,17 @@ int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
return result;
}
+int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
+ if (h == nullptr || h->handle == nullptr) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "The passed in handle is a nullptr");
+ return -1;
+ }
+ tensorflow::int64 result;
+ status->status = h->handle->NumElements(&result);
+ return result;
+}
+
int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index a87d73ec8e..6323f8a053 100755
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -163,6 +163,8 @@ TF_CAPI_EXPORT extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h);
// This function will block till the operation that produces `h` has completed.
TF_CAPI_EXPORT extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h,
TF_Status* status);
+TF_CAPI_EXPORT extern int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h,
+ TF_Status* status);
// This function will block till the operation that produces `h` has completed.
TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h,
int dim_index,
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc
index b912f7d37b..d58724cbfa 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.cc
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc
@@ -125,7 +125,6 @@ Status TensorHandle::Shape(tensorflow::TensorShape* shape) {
Status TensorHandle::NumDims(int* num_dims) {
if (IsRemote()) {
TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false));
- CHECK(remote_shape_ != nullptr);
*num_dims = remote_shape_->dims();
} else {
TF_RETURN_IF_ERROR(WaitReady());
@@ -153,6 +152,21 @@ Status TensorHandle::Dim(int dim_index, int64* dim) {
return Status::OK();
}
+Status TensorHandle::NumElements(int64* num_elements) {
+ if (IsRemote()) {
+ TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false));
+ *num_elements = remote_shape_->num_elements();
+ } else {
+ TF_RETURN_IF_ERROR(WaitReady());
+ DCHECK(IsReady());
+ DCHECK(num_elements != nullptr);
+
+ *num_elements = tensor_.NumElements();
+ }
+
+ return Status::OK();
+}
+
Status TensorHandle::RemoteAddress(int64* op_id, int32* output_num) {
if (!IsRemote()) {
return errors::FailedPrecondition(
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h
index 1bc9c6531a..e55f1a0338 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.h
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.h
@@ -113,6 +113,7 @@ class TensorHandle : public core::RefCounted {
Status NumDims(int* num_dims);
Status Dim(int dim_index, int64* dim);
+ Status NumElements(int64* num_elements);
// Return the op_id and output num if the handle refers to a remote tensor.
Status RemoteAddress(int64* op_id, int32* output_num);
diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc
index f34ce6af79..5f44bd4fec 100644
--- a/tensorflow/python/eager/pywrap_tensor.cc
+++ b/tensorflow/python/eager/pywrap_tensor.cc
@@ -516,25 +516,13 @@ static PyObject* EagerTensor_rank(EagerTensor* self) {
// Getter for `_num_elements`.
static PyObject* EagerTensor_num_elements(EagerTensor* self) {
auto handle = self->handle;
- int n = TFE_TensorHandleNumDims(handle, self->status);
+ int n = TFE_TensorHandleNumElements(handle, self->status);
if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
// Cleanup self->status before returning.
TF_SetStatus(self->status, TF_OK, "");
return nullptr;
}
- tensorflow::int64 value = 1;
- if (PyErr_Occurred()) return nullptr;
- for (int i = 0; i < n; ++i) {
- int64_t dim = TFE_TensorHandleDim(handle, i, self->status);
- if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
- // Cleanup self->status before returning.
- TF_SetStatus(self->status, TF_OK, "");
- PyErr_SetString(PyExc_RuntimeError, "Error while iterating dimensions");
- return nullptr;
- }
- value *= dim;
- }
- return PyLong_FromLongLong(value);
+ return PyLong_FromLongLong(n);
}
static PyObject* EagerTensor_tensor_handle(EagerTensor* self, void* unused) {
@@ -777,17 +765,34 @@ PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) {
return reinterpret_cast<PyObject*>(t);
}
-tensorflow::int64 EagerTensor_id(const PyObject* tensor) {
- CHECK(EagerTensor_CheckExact(tensor));
+tensorflow::int64 PyEagerTensor_ID(const PyObject* tensor) {
+ DCHECK(EagerTensor_CheckExact(tensor));
return reinterpret_cast<const EagerTensor*>(tensor)->id;
}
-tensorflow::DataType EagerTensor_dtype(const PyObject* tensor) {
- CHECK(EagerTensor_CheckExact(tensor));
+tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor) {
+ DCHECK(EagerTensor_CheckExact(tensor));
return static_cast<tensorflow::DataType>(TFE_TensorHandleDataType(
reinterpret_cast<const EagerTensor*>(tensor)->handle));
}
+tensorflow::int64 PyEagerTensor_NumElements(const PyObject* tensor) {
+ DCHECK(EagerTensor_CheckExact(tensor));
+ const EagerTensor* as_c_eager_tensor =
+ reinterpret_cast<const EagerTensor*>(tensor);
+ tensorflow::int64 result = TFE_TensorHandleNumElements(
+ as_c_eager_tensor->handle, as_c_eager_tensor->status);
+
+ if (MaybeRaiseExceptionFromTFStatus(as_c_eager_tensor->status,
+ PyExc_ValueError)) {
+ // Cleanup status before returning.
+ TF_SetStatus(as_c_eager_tensor->status, TF_OK, "");
+ return -1;
+ }
+
+ return result;
+}
+
PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
if (!PyType_Check(base_class)) {
PyErr_SetString(
diff --git a/tensorflow/python/eager/pywrap_tensor.h b/tensorflow/python/eager/pywrap_tensor.h
index bc042eb19e..4eaa1ba536 100644
--- a/tensorflow/python/eager/pywrap_tensor.h
+++ b/tensorflow/python/eager/pywrap_tensor.h
@@ -21,8 +21,9 @@ limitations under the License.
#include "tensorflow/python/lib/core/numpy.h"
bool EagerTensor_CheckExact(const PyObject* o);
-tensorflow::int64 EagerTensor_id(const PyObject* tensor);
-tensorflow::DataType EagerTensor_dtype(const PyObject* tensor);
+tensorflow::int64 PyEagerTensor_ID(const PyObject* tensor);
+tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor);
+tensorflow::int64 PyEagerTensor_NumElements(const PyObject* tensor);
namespace tensorflow {
TFE_TensorHandle* ConvertToEagerTensor(PyObject* value, PyObject* dtype);
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 99b46159a9..a0f6be459e 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -860,7 +860,7 @@ static tensorflow::int64 MakeInt(PyObject* integer) {
static tensorflow::int64 FastTensorId(PyObject* tensor) {
if (EagerTensor_CheckExact(tensor)) {
- return EagerTensor_id(tensor);
+ return PyEagerTensor_ID(tensor);
}
PyObject* id_field = PyObject_GetAttrString(tensor, "_id");
if (id_field == nullptr) {
@@ -873,7 +873,7 @@ static tensorflow::int64 FastTensorId(PyObject* tensor) {
static tensorflow::DataType FastTensorDtype(PyObject* tensor) {
if (EagerTensor_CheckExact(tensor)) {
- return EagerTensor_dtype(tensor);
+ return PyEagerTensor_Dtype(tensor);
}
PyObject* dtype_field = PyObject_GetAttrString(tensor, "dtype");
if (dtype_field == nullptr) {
@@ -1178,7 +1178,7 @@ void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) {
static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
if (EagerTensor_CheckExact(tensor)) {
TFE_TensorHandle* t = EagerTensor_Handle(tensor);
- tensorflow::int64 id = EagerTensor_id(tensor);
+ tensorflow::int64 id = PyEagerTensor_ID(tensor);
tensorflow::TensorShape tensor_shape;
const tensorflow::Status status = t->handle->Shape(&tensor_shape);
@@ -1400,6 +1400,9 @@ class PyVSpace
}
tensorflow::int64 NumElements(PyObject* tensor) const final {
+ if (EagerTensor_CheckExact(tensor)) {
+ return PyEagerTensor_NumElements(tensor);
+ }
PyObject* arglist =
Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
PyObject* result = PyEval_CallObject(num_elements_, arglist);