diff options
-rwxr-xr-x | tensorflow/c/eager/c_api.cc | 11 | ||||
-rwxr-xr-x | tensorflow/c/eager/c_api.h | 2 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/eager/tensor_handle.cc | 16 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/eager/tensor_handle.h | 1 | ||||
-rw-r--r-- | tensorflow/python/eager/pywrap_tensor.cc | 41 | ||||
-rw-r--r-- | tensorflow/python/eager/pywrap_tensor.h | 5 | ||||
-rw-r--r-- | tensorflow/python/eager/pywrap_tfe_src.cc | 9 |
7 files changed, 61 insertions, 24 deletions
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 6f86ea80e5..0bf3d9542b 100755 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -375,6 +375,17 @@ int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) { return result; } +int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) { + if (h == nullptr || h->handle == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "The passed in handle is a nullptr"); + return -1; + } + tensorflow::int64 result; + status->status = h->handle->NumElements(&result); + return result; +} + int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, TF_Status* status) { if (h == nullptr || h->handle == nullptr) { diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index a87d73ec8e..6323f8a053 100755 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -163,6 +163,8 @@ TF_CAPI_EXPORT extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h); // This function will block till the operation that produces `h` has completed. TF_CAPI_EXPORT extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status); +TF_CAPI_EXPORT extern int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, + TF_Status* status); // This function will block till the operation that produces `h` has completed. TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc index b912f7d37b..d58724cbfa 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc @@ -125,7 +125,6 @@ Status TensorHandle::Shape(tensorflow::TensorShape* shape) { Status TensorHandle::NumDims(int* num_dims) { if (IsRemote()) { TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false)); - CHECK(remote_shape_ != nullptr); *num_dims = remote_shape_->dims(); } else { TF_RETURN_IF_ERROR(WaitReady()); @@ -153,6 +152,21 @@ Status TensorHandle::Dim(int dim_index, int64* dim) { return Status::OK(); } +Status TensorHandle::NumElements(int64* num_elements) { + if (IsRemote()) { + TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false)); + *num_elements = remote_shape_->num_elements(); + } else { + TF_RETURN_IF_ERROR(WaitReady()); + DCHECK(IsReady()); + DCHECK(num_elements != nullptr); + + *num_elements = tensor_.NumElements(); + } + + return Status::OK(); +} + Status TensorHandle::RemoteAddress(int64* op_id, int32* output_num) { if (!IsRemote()) { return errors::FailedPrecondition( diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h index 1bc9c6531a..e55f1a0338 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.h +++ b/tensorflow/core/common_runtime/eager/tensor_handle.h @@ -113,6 +113,7 @@ class TensorHandle : public core::RefCounted { Status NumDims(int* num_dims); Status Dim(int dim_index, int64* dim); + Status NumElements(int64* num_elements); // Return the op_id and output num if the handle refers to a remote tensor. Status RemoteAddress(int64* op_id, int32* output_num); diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index f34ce6af79..5f44bd4fec 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -516,25 +516,13 @@ static PyObject* EagerTensor_rank(EagerTensor* self) { // Getter for `_num_elements`. static PyObject* EagerTensor_num_elements(EagerTensor* self) { auto handle = self->handle; - int n = TFE_TensorHandleNumDims(handle, self->status); + int n = TFE_TensorHandleNumElements(handle, self->status); if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) { // Cleanup self->status before returning. TF_SetStatus(self->status, TF_OK, ""); return nullptr; } - tensorflow::int64 value = 1; - if (PyErr_Occurred()) return nullptr; - for (int i = 0; i < n; ++i) { - int64_t dim = TFE_TensorHandleDim(handle, i, self->status); - if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) { - // Cleanup self->status before returning. - TF_SetStatus(self->status, TF_OK, ""); - PyErr_SetString(PyExc_RuntimeError, "Error while iterating dimensions"); - return nullptr; - } - value *= dim; - } - return PyLong_FromLongLong(value); + return PyLong_FromLongLong(n); } static PyObject* EagerTensor_tensor_handle(EagerTensor* self, void* unused) { @@ -777,17 +765,34 @@ PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) { return reinterpret_cast<PyObject*>(t); } -tensorflow::int64 EagerTensor_id(const PyObject* tensor) { - CHECK(EagerTensor_CheckExact(tensor)); +tensorflow::int64 PyEagerTensor_ID(const PyObject* tensor) { + DCHECK(EagerTensor_CheckExact(tensor)); return reinterpret_cast<const EagerTensor*>(tensor)->id; } -tensorflow::DataType EagerTensor_dtype(const PyObject* tensor) { - CHECK(EagerTensor_CheckExact(tensor)); +tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor) { + DCHECK(EagerTensor_CheckExact(tensor)); return static_cast<tensorflow::DataType>(TFE_TensorHandleDataType( reinterpret_cast<const EagerTensor*>(tensor)->handle)); } +tensorflow::int64 PyEagerTensor_NumElements(const PyObject* tensor) { + DCHECK(EagerTensor_CheckExact(tensor)); + const EagerTensor* as_c_eager_tensor = + reinterpret_cast<const EagerTensor*>(tensor); + tensorflow::int64 result = TFE_TensorHandleNumElements( + as_c_eager_tensor->handle, as_c_eager_tensor->status); + + if (MaybeRaiseExceptionFromTFStatus(as_c_eager_tensor->status, + PyExc_ValueError)) { + // Cleanup status before returning. + TF_SetStatus(as_c_eager_tensor->status, TF_OK, ""); + return -1; + } + + return result; +} + PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) { if (!PyType_Check(base_class)) { PyErr_SetString( diff --git a/tensorflow/python/eager/pywrap_tensor.h b/tensorflow/python/eager/pywrap_tensor.h index bc042eb19e..4eaa1ba536 100644 --- a/tensorflow/python/eager/pywrap_tensor.h +++ b/tensorflow/python/eager/pywrap_tensor.h @@ -21,8 +21,9 @@ limitations under the License. #include "tensorflow/python/lib/core/numpy.h" bool EagerTensor_CheckExact(const PyObject* o); -tensorflow::int64 EagerTensor_id(const PyObject* tensor); -tensorflow::DataType EagerTensor_dtype(const PyObject* tensor); +tensorflow::int64 PyEagerTensor_ID(const PyObject* tensor); +tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor); +tensorflow::int64 PyEagerTensor_NumElements(const PyObject* tensor); namespace tensorflow { TFE_TensorHandle* ConvertToEagerTensor(PyObject* value, PyObject* dtype); diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 99b46159a9..a0f6be459e 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -860,7 +860,7 @@ static tensorflow::int64 MakeInt(PyObject* integer) { static tensorflow::int64 FastTensorId(PyObject* tensor) { if (EagerTensor_CheckExact(tensor)) { - return EagerTensor_id(tensor); + return PyEagerTensor_ID(tensor); } PyObject* id_field = PyObject_GetAttrString(tensor, "_id"); if (id_field == nullptr) { @@ -873,7 +873,7 @@ static tensorflow::int64 FastTensorId(PyObject* tensor) { static tensorflow::DataType FastTensorDtype(PyObject* tensor) { if (EagerTensor_CheckExact(tensor)) { - return EagerTensor_dtype(tensor); + return PyEagerTensor_Dtype(tensor); } PyObject* dtype_field = PyObject_GetAttrString(tensor, "dtype"); if (dtype_field == nullptr) { @@ -1178,7 +1178,7 @@ void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) { static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) { if (EagerTensor_CheckExact(tensor)) { TFE_TensorHandle* t = EagerTensor_Handle(tensor); - tensorflow::int64 id = EagerTensor_id(tensor); + tensorflow::int64 id = PyEagerTensor_ID(tensor); tensorflow::TensorShape tensor_shape; const tensorflow::Status status = t->handle->Shape(&tensor_shape); @@ -1400,6 +1400,9 @@ class PyVSpace } tensorflow::int64 NumElements(PyObject* tensor) const final { + if (EagerTensor_CheckExact(tensor)) { + return PyEagerTensor_NumElements(tensor); + } PyObject* arglist = Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor)); PyObject* result = PyEval_CallObject(num_elements_, arglist); |