aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_api_function_test.cc
diff options
context:
space:
mode:
authorGravatar Igor Ganichev <iga@google.com>2017-09-19 18:54:03 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-19 18:57:41 -0700
commit7ad8e25495a2793ea14189359af736d2c662a694 (patch)
treea45d248a4eaff33d65b48864f06b59acf884f905 /tensorflow/c/c_api_function_test.cc
parented89a2b31f775db8ae6adf894fee27cc963ba030 (diff)
Add attribute setting and getting support to TF_Function
PiperOrigin-RevId: 169337159
Diffstat (limited to 'tensorflow/c/c_api_function_test.cc')
-rw-r--r--tensorflow/c/c_api_function_test.cc39
1 files changed, 39 insertions, 0 deletions
diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc
index 9b0279dc17..82d0dc531e 100644
--- a/tensorflow/c/c_api_function_test.cc
+++ b/tensorflow/c/c_api_function_test.cc
@@ -372,6 +372,13 @@ class CApiFunctionTest : public ::testing::Test {
TF_DeleteBuffer(buf);
}
+ void GetAttr(const char* attr_name, AttrValue* out_attr) {
+ TF_Buffer* attr_buf = TF_NewBuffer();
+ TF_FunctionGetAttrValueProto(func_, attr_name, attr_buf, s_);
+ ASSERT_TRUE(out_attr->ParseFromArray(attr_buf->data, attr_buf->length));
+ TF_DeleteBuffer(attr_buf);
+ }
+
const char* func_name_ = "MyFunc";
const char* func_node_name_ = "MyFunc_0";
TF_Status* s_;
@@ -1406,5 +1413,37 @@ TEST_F(CApiFunctionTest, ImportFunctionDef_InvalidProto) {
string(TF_Message(s_)));
}
+TEST_F(CApiFunctionTest, Attribute) {
+ DefineFunction(func_name_, &func_);
+
+ // Get non existent attribute
+ TF_Buffer* attr_buf = TF_NewBuffer();
+ TF_FunctionGetAttrValueProto(func_, "foo_attr", attr_buf, s_);
+ EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
+ EXPECT_EQ(string("Function 'MyFunc' has no attr named 'foo_attr'."),
+ string(TF_Message(s_)));
+ TF_DeleteBuffer(attr_buf);
+
+ // Set attr
+ tensorflow::AttrValue attr;
+ attr.set_s("test_attr_value");
+ string bytes;
+ attr.SerializeToString(&bytes);
+ TF_FunctionSetAttrValueProto(func_, "test_attr_name", bytes.data(),
+ bytes.size(), s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ // Get attr
+ AttrValue read_attr;
+ GetAttr("test_attr_name", &read_attr);
+ ASSERT_EQ(attr.DebugString(), read_attr.DebugString());
+
+ // Retrieve the same attr after save/restore
+ Reincarnate();
+ AttrValue read_attr2;
+ GetAttr("test_attr_name", &read_attr2);
+ ASSERT_EQ(attr.DebugString(), read_attr2.DebugString());
+}
+
} // namespace
} // namespace tensorflow