aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/eager
diff options
context:
space:
mode:
authorGravatar Akshay Modi <nareshmodi@google.com>2018-09-05 22:34:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-05 22:43:35 -0700
commite23d522e943309cefae368a11c21ae37b6986165 (patch)
tree959f83e6128e87734b837082164fb88dc3803ca9 /tensorflow/c/eager
parent5393c8f0dc57857c93482bff67f1134aae9af594 (diff)
Allow creating a py EagerTensor that shares the underlying TensorHandle.
This is so that gradients with respect to scalars pass (see the test added in backprop_test.py). A micro benchmark that just calls constant_op.constant slows down a bit; this is inevitable, as we are creating a new Python object. After: walltime ~2.1. Before: walltime ~1.47. The linear regression benchmark is pretty much unchanged. PiperOrigin-RevId: 211753801
Diffstat (limited to 'tensorflow/c/eager')
-rwxr-xr-xtensorflow/c/eager/c_api.cc13
-rwxr-xr-xtensorflow/c/eager/c_api.h6
-rw-r--r--tensorflow/c/eager/c_api_test.cc25
3 files changed, 44 insertions, 0 deletions
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 77e3878a94..349d9bcd7c 100755
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -399,6 +399,19 @@ const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
: d->name().c_str();
}
+// Creates a new TFE_TensorHandle wrapper that shares the underlying handle of
+// `h`. The underlying handle's reference count is incremented (via Ref())
+// before the new wrapper is constructed around the same pointer.
+//
+// Returns nullptr and sets `status` to InvalidArgument when `h` or its
+// underlying handle is null. On success `status` is set to OK, matching the
+// contract documented on the declaration in c_api.h.
+TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
+    TFE_TensorHandle* h, TF_Status* status) {
+  if (h == nullptr || h->handle == nullptr) {
+    status->status = tensorflow::errors::InvalidArgument(
+        "The passed in handle is a nullptr");
+    return nullptr;
+  }
+
+  // Take the extra reference before handing the pointer to the new wrapper.
+  h->handle->Ref();
+
+  // Report success explicitly so that an error left in `status` by an earlier
+  // call is not mistaken for a failure of this one.
+  status->status = tensorflow::Status::OK();
+  return new TFE_TensorHandle(h->handle);
+}
+
TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index eec2750d6e..337447eec9 100755
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -171,6 +171,12 @@ TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h,
TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName(
TFE_TensorHandle* h, TF_Status* status);
+// Return a pointer to a new TFE_TensorHandle that shares the underlying tensor
+// with `h`. On success, `status` is set to OK. On failure, `status` reflects
+// the error and a nullptr is returned.
+//
+// Passing a null `h` is reported as an InvalidArgument error. The returned
+// handle is a distinct object from `h`; callers delete each of them with
+// TFE_DeleteTensorHandle.
+TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
+ TFE_TensorHandle* h, TF_Status* status);
+
// This function will block till the operation that produces `h` has
// completed. The memory returned might alias the internal memory used by
// TensorFlow. Hence, callers should not mutate this memory (for example by
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 7126227cf5..55331022b9 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -1528,4 +1528,29 @@ TEST(CAPI, StringAttributes) {
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}
+
+// Creates a handle that shares the tensor of a test matrix handle and checks
+// that resolving the new handle yields the four floats 1, 2, 3, 4.
+TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
+  TFE_TensorHandle* h = TestMatrixTensorHandle();
+  EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
+
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+
+  TFE_TensorHandle* h_shares_tensor =
+      TFE_TensorHandleCopySharingTensor(h, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+  TF_Tensor* t = TFE_TensorHandleResolve(h_shares_tensor, status.get());
+  // Resolve reports failures through `status`; check it before using `t`.
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  // Guards the memcpy below: `data` only has room for four floats.
+  ASSERT_EQ(16, TF_TensorByteSize(t));
+  float data[4] = {0};
+  memcpy(&data[0], TF_TensorData(t), TF_TensorByteSize(t));
+  EXPECT_EQ(1.0, data[0]);
+  EXPECT_EQ(2.0, data[1]);
+  EXPECT_EQ(3.0, data[2]);
+  EXPECT_EQ(4.0, data[3]);
+  TF_DeleteTensor(t);
+
+  // Both wrappers are deleted independently.
+  TFE_DeleteTensorHandle(h);
+  TFE_DeleteTensorHandle(h_shares_tensor);
+}
} // namespace