From e23d522e943309cefae368a11c21ae37b6986165 Mon Sep 17 00:00:00 2001 From: Akshay Modi Date: Wed, 5 Sep 2018 22:34:52 -0700 Subject: Allow creating a py EagerTensor that shares the underlying TensorHandle. This is so that gradients with respect to scalars pass (see the test added in backprop_test.py). A micro benchmark just calling constant_op.constant slows down a bit - this is inevitable as we are creating a new python object. After: walltime: ~2.1 Before: walltime: ~1.47 Linear regression benchmark is pretty much unchanged. PiperOrigin-RevId: 211753801 --- tensorflow/c/eager/c_api.cc | 13 +++++++++++++ tensorflow/c/eager/c_api.h | 6 ++++++ tensorflow/c/eager/c_api_test.cc | 25 +++++++++++++++++++++++++ 3 files changed, 44 insertions(+) (limited to 'tensorflow/c/eager') diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 77e3878a94..349d9bcd7c 100755 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -399,6 +399,19 @@ const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { : d->name().c_str(); } +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor( + TFE_TensorHandle* h, TF_Status* status) { + if (h == nullptr || h->handle == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "The passed in handle is a nullptr"); + return nullptr; + } + + h->handle->Ref(); + + return new TFE_TensorHandle(h->handle); +} + TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { if (h == nullptr || h->handle == nullptr) { status->status = tensorflow::errors::InvalidArgument( diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index eec2750d6e..337447eec9 100755 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -171,6 +171,12 @@ TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName( TFE_TensorHandle* h, TF_Status* status); +// Return a pointer to a new TFE_TensorHandle that shares the underlying tensor +// with `h`. On success, `status` is set to OK. On failure, `status` reflects +// the error and a nullptr is returned. +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor( + TFE_TensorHandle* h, TF_Status* status); + // This function will block till the operation that produces `h` has // completed. The memory returned might alias the internal memory used by // TensorFlow. Hence, callers should not mutate this memory (for example by diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 7126227cf5..55331022b9 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -1528,4 +1528,29 @@ TEST(CAPI, StringAttributes) { TFE_DeleteContext(ctx); TF_DeleteStatus(status); } + +TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) { + TFE_TensorHandle* h = TestMatrixTensorHandle(); + EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h)); + + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + + TFE_TensorHandle* h_shares_tensor = + TFE_TensorHandleCopySharingTensor(h, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TF_Tensor* t = TFE_TensorHandleResolve(h_shares_tensor, status.get()); + ASSERT_EQ(16, TF_TensorByteSize(t)); + float data[4] = {0}; + memcpy(&data[0], TF_TensorData(t), TF_TensorByteSize(t)); + EXPECT_EQ(1.0, data[0]); + EXPECT_EQ(2.0, data[1]); + EXPECT_EQ(3.0, data[2]); + EXPECT_EQ(4.0, data[3]); + TF_DeleteTensor(t); + + TFE_DeleteTensorHandle(h); + TFE_DeleteTensorHandle(h_shares_tensor); +} } // namespace -- cgit v1.2.3