aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2018-08-16 11:26:26 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-16 12:03:20 -0700
commit72c5efff887256ab50bc5f9e586cea17c71a50b2 (patch)
treef1c39345473b7642030fdc5f9405c266c42e03bc /tensorflow/c
parente77dbdb2050de8cd0504b8484904858dfcd64c75 (diff)
[C API/Eager]: Fix bug in TFE_OpSetAttrString.
TFE_OpSetAttrString was holding on to the 'value' pointer after it returned. This bug was introduced in commit 2b0805301e4531dd7c2ed677d932f6408675460e which caused TFE_OpSetAttrString to invoke AttrBuilder& AttrBuilder::Set(StringPiece attr_name, StringPiece&& value); instead of: AttrBuilder& AttrBuilder::Set(StringPiece attr_name, T&& value) (where the latter copies 'value' when T is a StringPiece or const char* and the former aliases the memory pointed to by StringPiece). In this process, I realized that AttrBuilder::Set(StringPiece attr_name, StringPiece&& value) was never being invoked (other than in this buggy situation), so I removed it altogether. Without the changes to attr_builder.{h,cc}, the newly added test fails - complaining that "NHWC" is not a valid value for the "padding" attribute. PiperOrigin-RevId: 209017110
Diffstat (limited to 'tensorflow/c')
-rw-r--r--tensorflow/c/eager/c_api_test.cc57
1 files changed, 57 insertions, 0 deletions
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 71d5f3613c..7126227cf5 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -1471,4 +1471,61 @@ void BM_ReadVariable(int iters) {
}
BENCHMARK(BM_ReadVariable);
+TEST(CAPI, StringAttributes) {
+ // Test that TFE_OpSetAttrString doesn't hold on to the value after it
+ // returns.
+ TF_Status* status = TF_NewStatus();
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteContextOptions(opts);
+
+ std::vector<int64_t> dims(4, 1);
+ TFE_Op* op = TFE_NewOp(ctx, "AvgPool", status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TF_Tensor* tensor =
+ TF_AllocateTensor(TF_FLOAT, dims.data(), dims.size(), sizeof(float));
+ float tensor_data[] = {1};
+ memcpy(TF_TensorData(tensor), tensor_data, TF_TensorByteSize(tensor));
+ TFE_TensorHandle* tensor_handle = TFE_NewTensorHandle(tensor, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_OpAddInput(op, tensor_handle, status);
+ TF_DeleteTensor(tensor);
+ TFE_DeleteTensorHandle(tensor_handle);
+
+ std::vector<int64_t> values(4, 1);
+ TFE_OpSetAttrIntList(op, "ksize", values.data(), values.size());
+ TFE_OpSetAttrIntList(op, "strides", values.data(), values.size());
+
+ const int BUFFER_SIZE = 10;
+ char buffer[BUFFER_SIZE];
+ std::strncpy(buffer, "VALID", BUFFER_SIZE);
+ TFE_OpSetAttrString(op, "padding", buffer, std::strlen(buffer));
+ // Overwriting value in "buffer", should be fine since TFE_Op
+ // shouldn't be holding on to it.
+ std::strncpy(buffer, "NHWC", BUFFER_SIZE);
+ TFE_OpSetAttrString(op, "data_format", buffer, std::strlen(buffer));
+
+ TFE_OpSetAttrType(op, "T", TF_FLOAT);
+
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TFE_TensorHandle* retvals[1];
+ int num_retvals = 1;
+ TFE_Execute(op, &retvals[0], &num_retvals, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ ASSERT_EQ(1, num_retvals);
+
+ tensor = TFE_TensorHandleResolve(retvals[0], status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ EXPECT_EQ(4, TF_TensorByteSize(tensor));
+ TF_DeleteTensor(tensor);
+ TFE_DeleteTensorHandle(retvals[0]);
+
+ TFE_DeleteOp(op);
+
+ TFE_DeleteContext(ctx);
+ TF_DeleteStatus(status);
+}
} // namespace