diff options
author | Akshay Modi <nareshmodi@google.com> | 2018-09-05 22:34:52 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-05 22:43:35 -0700 |
commit | e23d522e943309cefae368a11c21ae37b6986165 (patch) | |
tree | 959f83e6128e87734b837082164fb88dc3803ca9 /tensorflow/c/eager | |
parent | 5393c8f0dc57857c93482bff67f1134aae9af594 (diff) |
Allow creating a py EagerTensor that shares the underlying TensorHandle.
This is so that gradients with respect to scalars pass (see the test added in
backprop_test.py).
A micro benchmark just calling constant_op.constant slows down a bit - this is
inevitable as we are creating a new python object.
After: walltime: ~2.1
Before: walltime: ~1.47
Linear regression benchmark is pretty much unchanged.
PiperOrigin-RevId: 211753801
Diffstat (limited to 'tensorflow/c/eager')
-rwxr-xr-x | tensorflow/c/eager/c_api.cc | 13 | ||||
-rwxr-xr-x | tensorflow/c/eager/c_api.h | 6 | ||||
-rw-r--r-- | tensorflow/c/eager/c_api_test.cc | 25 |
3 files changed, 44 insertions, 0 deletions
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 77e3878a94..349d9bcd7c 100755 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -399,6 +399,19 @@ const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { : d->name().c_str(); } +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor( + TFE_TensorHandle* h, TF_Status* status) { + if (h == nullptr || h->handle == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "The passed in handle is a nullptr"); + return nullptr; + } + + h->handle->Ref(); + + return new TFE_TensorHandle(h->handle); +} + TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { if (h == nullptr || h->handle == nullptr) { status->status = tensorflow::errors::InvalidArgument( diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index eec2750d6e..337447eec9 100755 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -171,6 +171,12 @@ TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName( TFE_TensorHandle* h, TF_Status* status); +// Return a pointer to a new TFE_TensorHandle that shares the underlying tensor +// with `h`. On success, `status` is set to OK. On failure, `status` reflects +// the error and a nullptr is returned. +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor( + TFE_TensorHandle* h, TF_Status* status); + // This function will block till the operation that produces `h` has // completed. The memory returned might alias the internal memory used by // TensorFlow. Hence, callers should not mutate this memory (for example by diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 7126227cf5..55331022b9 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -1528,4 +1528,29 @@ TEST(CAPI, StringAttributes) { TFE_DeleteContext(ctx); TF_DeleteStatus(status); } + +TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) { + TFE_TensorHandle* h = TestMatrixTensorHandle(); + EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h)); + + std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( + TF_NewStatus(), TF_DeleteStatus); + + TFE_TensorHandle* h_shares_tensor = + TFE_TensorHandleCopySharingTensor(h, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TF_Tensor* t = TFE_TensorHandleResolve(h_shares_tensor, status.get()); + ASSERT_EQ(16, TF_TensorByteSize(t)); + float data[4] = {0}; + memcpy(&data[0], TF_TensorData(t), TF_TensorByteSize(t)); + EXPECT_EQ(1.0, data[0]); + EXPECT_EQ(2.0, data[1]); + EXPECT_EQ(3.0, data[2]); + EXPECT_EQ(4.0, data[3]); + TF_DeleteTensor(t); + + TFE_DeleteTensorHandle(h); + TFE_DeleteTensorHandle(h_shares_tensor); +} } // namespace |