aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/c/eager/c_api_test.cc57
-rw-r--r--tensorflow/core/common_runtime/eager/attr_builder.cc8
-rw-r--r--tensorflow/core/common_runtime/eager/attr_builder.h3
3 files changed, 57 insertions, 11 deletions
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 71d5f3613c..7126227cf5 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -1471,4 +1471,61 @@ void BM_ReadVariable(int iters) {
}
BENCHMARK(BM_ReadVariable);
+TEST(CAPI, StringAttributes) {
+ // Test that TFE_OpSetAttrString doesn't hold on to the value after it
+ // returns.
+ TF_Status* status = TF_NewStatus();
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteContextOptions(opts);
+
+ std::vector<int64_t> dims(4, 1);
+ TFE_Op* op = TFE_NewOp(ctx, "AvgPool", status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TF_Tensor* tensor =
+ TF_AllocateTensor(TF_FLOAT, dims.data(), dims.size(), sizeof(float));
+ float tensor_data[] = {1};
+ memcpy(TF_TensorData(tensor), tensor_data, TF_TensorByteSize(tensor));
+ TFE_TensorHandle* tensor_handle = TFE_NewTensorHandle(tensor, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_OpAddInput(op, tensor_handle, status);
+ TF_DeleteTensor(tensor);
+ TFE_DeleteTensorHandle(tensor_handle);
+
+ std::vector<int64_t> values(4, 1);
+ TFE_OpSetAttrIntList(op, "ksize", values.data(), values.size());
+ TFE_OpSetAttrIntList(op, "strides", values.data(), values.size());
+
+ const int BUFFER_SIZE = 10;
+ char buffer[BUFFER_SIZE];
+ std::strncpy(buffer, "VALID", BUFFER_SIZE);
+ TFE_OpSetAttrString(op, "padding", buffer, std::strlen(buffer));
+ // Overwriting value in "buffer", should be fine since TFE_Op
+ // shouldn't be holding on to it.
+ std::strncpy(buffer, "NHWC", BUFFER_SIZE);
+ TFE_OpSetAttrString(op, "data_format", buffer, std::strlen(buffer));
+
+ TFE_OpSetAttrType(op, "T", TF_FLOAT);
+
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TFE_TensorHandle* retvals[1];
+ int num_retvals = 1;
+ TFE_Execute(op, &retvals[0], &num_retvals, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ ASSERT_EQ(1, num_retvals);
+
+ tensor = TFE_TensorHandleResolve(retvals[0], status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ EXPECT_EQ(4, TF_TensorByteSize(tensor));
+ TF_DeleteTensor(tensor);
+ TFE_DeleteTensorHandle(retvals[0]);
+
+ TFE_DeleteOp(op);
+
+ TFE_DeleteContext(ctx);
+ TF_DeleteStatus(status);
+}
} // namespace
diff --git a/tensorflow/core/common_runtime/eager/attr_builder.cc b/tensorflow/core/common_runtime/eager/attr_builder.cc
index 92307d78f2..cf1cd4134e 100644
--- a/tensorflow/core/common_runtime/eager/attr_builder.cc
+++ b/tensorflow/core/common_runtime/eager/attr_builder.cc
@@ -103,7 +103,6 @@ Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out) {
return *this; \
}
-DEFINE_SET_ATTR(StringPiece, string_attrs_);
DEFINE_SET_ATTR(float, float_attrs_);
DEFINE_SET_ATTR(int, int_attrs_);
DEFINE_SET_ATTR(bool, bool_attrs_);
@@ -119,9 +118,6 @@ AttrBuilder& AttrBuilder::NumInputs(int n) {
void AttrBuilder::FillAttrValueMap(AttrValueMap* m,
bool include_those_in_node_def) const {
- for (const auto& p : string_attrs_) {
- SetInAttrValueMap(m, p.first, p.second);
- }
for (const auto& p : int_attrs_) {
SetInAttrValueMap(m, p.first, p.second);
}
@@ -211,10 +207,6 @@ tensorflow::Fprint128 AttrBuilder::CacheKey(const string& device) const {
// not been called.
if (node_def_finalized_) return f;
}
- for (const auto& p : string_attrs_) {
- CombineUnordered(
- CacheKeyHelper(p.first, tensorflow::Fingerprint128(p.second)), &f);
- }
for (const auto& p : int_attrs_) {
CombineUnordered(CacheKeyHelper(p.first, static_cast<uint64>(p.second)),
&f);
diff --git a/tensorflow/core/common_runtime/eager/attr_builder.h b/tensorflow/core/common_runtime/eager/attr_builder.h
index 929b1b8296..fc50bed3c0 100644
--- a/tensorflow/core/common_runtime/eager/attr_builder.h
+++ b/tensorflow/core/common_runtime/eager/attr_builder.h
@@ -131,7 +131,6 @@ class AttrBuilder {
}
}
- AttrVec<StringPiece> string_attrs_;
AttrVec<int> int_attrs_;
AttrVec<float> float_attrs_;
AttrVec<bool> bool_attrs_;
@@ -143,8 +142,6 @@ class AttrBuilder {
}; // namespace tensorflow
template <>
-AttrBuilder& AttrBuilder::Set(StringPiece attr_name, StringPiece&& value);
-template <>
AttrBuilder& AttrBuilder::Set(StringPiece attr_name, int&& value);
template <>
AttrBuilder& AttrBuilder::Set(StringPiece attr_name, float&& value);