diff options
author | 2017-09-19 18:54:03 -0700 | |
---|---|---|
committer | 2017-09-19 18:57:41 -0700 | |
commit | 7ad8e25495a2793ea14189359af736d2c662a694 (patch) | |
tree | a45d248a4eaff33d65b48864f06b59acf884f905 /tensorflow/c | |
parent | ed89a2b31f775db8ae6adf894fee27cc963ba030 (diff) |
Add attribute setting and getting support to TF_Function
PiperOrigin-RevId: 169337159
Diffstat (limited to 'tensorflow/c')
-rw-r--r-- | tensorflow/c/c_api.h | 18 | ||||
-rw-r--r-- | tensorflow/c/c_api_function.cc | 27 | ||||
-rw-r--r-- | tensorflow/c/c_api_function_test.cc | 39 |
3 files changed, 84 insertions, 0 deletions
diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index ccaaa30041..719374f2a4 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -1136,6 +1136,24 @@ TF_CAPI_EXPORT extern void TF_FunctionToFunctionDef(TF_Function* func, TF_CAPI_EXPORT extern TF_Function* TF_FunctionImportFunctionDef( const TF_Buffer* func_def, TF_Status* status); +// Sets function attribute named `attr_name` to value stored in `proto`. +// If this attribute is already set to another value, it is overriden. +// `proto` should point to a sequence of bytes of length `proto_len` +// representing a binary serialization of an AttrValue protocol +// buffer. +TF_CAPI_EXPORT extern void TF_FunctionSetAttrValueProto(TF_Function* func, + const char* attr_name, + const void* proto, + size_t proto_len, + TF_Status* status); + +// Sets `output_attr_value` to the binary-serialized AttrValue proto +// representation of the value of the `attr_name` attr of `func`. +// If `attr_name` attribute is not present, status is set to an error. +TF_CAPI_EXPORT extern void TF_FunctionGetAttrValueProto( + TF_Function* func, const char* attr_name, TF_Buffer* output_attr_value, + TF_Status* status); + // Frees the memory used by the `func` struct. // TF_DeleteFunction is a noop if `func` is null. // Deleting a function does not remove it from any graphs it was copied to. diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index 7848883e3e..92ee77935e 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -545,4 +545,31 @@ TF_Function* TF_FunctionImportFunctionDef(const TF_Buffer* func_def, return func; } +void TF_FunctionSetAttrValueProto(TF_Function* func, const char* attr_name, + const void* proto, size_t proto_len, + TF_Status* status) { + tensorflow::AttrValue attr_value; + if (!attr_value.ParseFromArray(proto, proto_len)) { + status->status = InvalidArgument( + "Unparseable AttrValue proto passed to " + "TF_FunctionSetAttrValueProto"); + return; + } + (*func->fdef.mutable_attr())[string(attr_name)] = attr_value; + status->status = tensorflow::Status::OK(); +} + +void TF_FunctionGetAttrValueProto(TF_Function* func, const char* attr_name, + TF_Buffer* output_attr_value, + TF_Status* status) { + const auto& it = func->fdef.attr().find(attr_name); + if (it == func->fdef.attr().end()) { + status->status = + InvalidArgument("Function '", func->fdef.signature().name(), + "' has no attr named '", attr_name, "'."); + return; + } + status->status = MessageToBuffer(it->second, output_attr_value); +} + void TF_DeleteFunction(TF_Function* func) { delete func; } diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index 9b0279dc17..82d0dc531e 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -372,6 +372,13 @@ class CApiFunctionTest : public ::testing::Test { TF_DeleteBuffer(buf); } + void GetAttr(const char* attr_name, AttrValue* out_attr) { + TF_Buffer* attr_buf = TF_NewBuffer(); + TF_FunctionGetAttrValueProto(func_, attr_name, attr_buf, s_); + ASSERT_TRUE(out_attr->ParseFromArray(attr_buf->data, attr_buf->length)); + TF_DeleteBuffer(attr_buf); + } + const char* func_name_ = "MyFunc"; const char* func_node_name_ = "MyFunc_0"; TF_Status* s_; @@ -1406,5 +1413,37 @@ TEST_F(CApiFunctionTest, ImportFunctionDef_InvalidProto) { string(TF_Message(s_))); } +TEST_F(CApiFunctionTest, Attribute) { + DefineFunction(func_name_, &func_); + + // Get non existent attribute + TF_Buffer* attr_buf = TF_NewBuffer(); + TF_FunctionGetAttrValueProto(func_, "foo_attr", attr_buf, s_); + EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); + EXPECT_EQ(string("Function 'MyFunc' has no attr named 'foo_attr'."), + string(TF_Message(s_))); + TF_DeleteBuffer(attr_buf); + + // Set attr + tensorflow::AttrValue attr; + attr.set_s("test_attr_value"); + string bytes; + attr.SerializeToString(&bytes); + TF_FunctionSetAttrValueProto(func_, "test_attr_name", bytes.data(), + bytes.size(), s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Get attr + AttrValue read_attr; + GetAttr("test_attr_name", &read_attr); + ASSERT_EQ(attr.DebugString(), read_attr.DebugString()); + + // Retrieve the same attr after save/restore + Reincarnate(); + AttrValue read_attr2; + GetAttr("test_attr_name", &read_attr2); + ASSERT_EQ(attr.DebugString(), read_attr2.DebugString()); +} + } // namespace } // namespace tensorflow |