diff options
author | Igor Ganichev <iga@google.com> | 2017-09-19 18:54:03 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-09-19 18:57:41 -0700 |
commit | 7ad8e25495a2793ea14189359af736d2c662a694 (patch) | |
tree | a45d248a4eaff33d65b48864f06b59acf884f905 /tensorflow/c/c_api_function_test.cc | |
parent | ed89a2b31f775db8ae6adf894fee27cc963ba030 (diff) |
Add attribute setting and getting support to TF_Function
PiperOrigin-RevId: 169337159
Diffstat (limited to 'tensorflow/c/c_api_function_test.cc')
-rw-r--r-- | tensorflow/c/c_api_function_test.cc | 39 |
1 files changed, 39 insertions, 0 deletions
diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index 9b0279dc17..82d0dc531e 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -372,6 +372,13 @@ class CApiFunctionTest : public ::testing::Test { TF_DeleteBuffer(buf); } + void GetAttr(const char* attr_name, AttrValue* out_attr) { + TF_Buffer* attr_buf = TF_NewBuffer(); + TF_FunctionGetAttrValueProto(func_, attr_name, attr_buf, s_); + ASSERT_TRUE(out_attr->ParseFromArray(attr_buf->data, attr_buf->length)); + TF_DeleteBuffer(attr_buf); + } + const char* func_name_ = "MyFunc"; const char* func_node_name_ = "MyFunc_0"; TF_Status* s_; @@ -1406,5 +1413,37 @@ TEST_F(CApiFunctionTest, ImportFunctionDef_InvalidProto) { string(TF_Message(s_))); } +TEST_F(CApiFunctionTest, Attribute) { + DefineFunction(func_name_, &func_); + + // Get non existent attribute + TF_Buffer* attr_buf = TF_NewBuffer(); + TF_FunctionGetAttrValueProto(func_, "foo_attr", attr_buf, s_); + EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); + EXPECT_EQ(string("Function 'MyFunc' has no attr named 'foo_attr'."), + string(TF_Message(s_))); + TF_DeleteBuffer(attr_buf); + + // Set attr + tensorflow::AttrValue attr; + attr.set_s("test_attr_value"); + string bytes; + attr.SerializeToString(&bytes); + TF_FunctionSetAttrValueProto(func_, "test_attr_name", bytes.data(), + bytes.size(), s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Get attr + AttrValue read_attr; + GetAttr("test_attr_name", &read_attr); + ASSERT_EQ(attr.DebugString(), read_attr.DebugString()); + + // Retrieve the same attr after save/restore + Reincarnate(); + AttrValue read_attr2; + GetAttr("test_attr_name", &read_attr2); + ASSERT_EQ(attr.DebugString(), read_attr2.DebugString()); +} + } // namespace } // namespace tensorflow |