aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c
diff options
context:
space:
mode:
authorGravatar Igor Ganichev <iga@google.com>2017-09-19 18:54:03 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-19 18:57:41 -0700
commit7ad8e25495a2793ea14189359af736d2c662a694 (patch)
treea45d248a4eaff33d65b48864f06b59acf884f905 /tensorflow/c
parented89a2b31f775db8ae6adf894fee27cc963ba030 (diff)
Add attribute setting and getting support to TF_Function
PiperOrigin-RevId: 169337159
Diffstat (limited to 'tensorflow/c')
-rw-r--r--tensorflow/c/c_api.h18
-rw-r--r--tensorflow/c/c_api_function.cc27
-rw-r--r--tensorflow/c/c_api_function_test.cc39
3 files changed, 84 insertions, 0 deletions
diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h
index ccaaa30041..719374f2a4 100644
--- a/tensorflow/c/c_api.h
+++ b/tensorflow/c/c_api.h
@@ -1136,6 +1136,24 @@ TF_CAPI_EXPORT extern void TF_FunctionToFunctionDef(TF_Function* func,
TF_CAPI_EXPORT extern TF_Function* TF_FunctionImportFunctionDef(
const TF_Buffer* func_def, TF_Status* status);
+// Sets function attribute named `attr_name` to value stored in `proto`.
+// If this attribute is already set to another value, it is overriden.
+// `proto` should point to a sequence of bytes of length `proto_len`
+// representing a binary serialization of an AttrValue protocol
+// buffer.
+TF_CAPI_EXPORT extern void TF_FunctionSetAttrValueProto(TF_Function* func,
+ const char* attr_name,
+ const void* proto,
+ size_t proto_len,
+ TF_Status* status);
+
+// Sets `output_attr_value` to the binary-serialized AttrValue proto
+// representation of the value of the `attr_name` attr of `func`.
+// If `attr_name` attribute is not present, status is set to an error.
+TF_CAPI_EXPORT extern void TF_FunctionGetAttrValueProto(
+ TF_Function* func, const char* attr_name, TF_Buffer* output_attr_value,
+ TF_Status* status);
+
// Frees the memory used by the `func` struct.
// TF_DeleteFunction is a noop if `func` is null.
// Deleting a function does not remove it from any graphs it was copied to.
diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc
index 7848883e3e..92ee77935e 100644
--- a/tensorflow/c/c_api_function.cc
+++ b/tensorflow/c/c_api_function.cc
@@ -545,4 +545,31 @@ TF_Function* TF_FunctionImportFunctionDef(const TF_Buffer* func_def,
return func;
}
+void TF_FunctionSetAttrValueProto(TF_Function* func, const char* attr_name,
+ const void* proto, size_t proto_len,
+ TF_Status* status) {
+ tensorflow::AttrValue attr_value;
+ if (!attr_value.ParseFromArray(proto, proto_len)) {
+ status->status = InvalidArgument(
+ "Unparseable AttrValue proto passed to "
+ "TF_FunctionSetAttrValueProto");
+ return;
+ }
+ (*func->fdef.mutable_attr())[string(attr_name)] = attr_value;
+ status->status = tensorflow::Status::OK();
+}
+
+void TF_FunctionGetAttrValueProto(TF_Function* func, const char* attr_name,
+ TF_Buffer* output_attr_value,
+ TF_Status* status) {
+ const auto& it = func->fdef.attr().find(attr_name);
+ if (it == func->fdef.attr().end()) {
+ status->status =
+ InvalidArgument("Function '", func->fdef.signature().name(),
+ "' has no attr named '", attr_name, "'.");
+ return;
+ }
+ status->status = MessageToBuffer(it->second, output_attr_value);
+}
+
void TF_DeleteFunction(TF_Function* func) { delete func; }
diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc
index 9b0279dc17..82d0dc531e 100644
--- a/tensorflow/c/c_api_function_test.cc
+++ b/tensorflow/c/c_api_function_test.cc
@@ -372,6 +372,13 @@ class CApiFunctionTest : public ::testing::Test {
TF_DeleteBuffer(buf);
}
+ void GetAttr(const char* attr_name, AttrValue* out_attr) {
+ TF_Buffer* attr_buf = TF_NewBuffer();
+ TF_FunctionGetAttrValueProto(func_, attr_name, attr_buf, s_);
+ ASSERT_TRUE(out_attr->ParseFromArray(attr_buf->data, attr_buf->length));
+ TF_DeleteBuffer(attr_buf);
+ }
+
const char* func_name_ = "MyFunc";
const char* func_node_name_ = "MyFunc_0";
TF_Status* s_;
@@ -1406,5 +1413,37 @@ TEST_F(CApiFunctionTest, ImportFunctionDef_InvalidProto) {
string(TF_Message(s_)));
}
+TEST_F(CApiFunctionTest, Attribute) {
+ DefineFunction(func_name_, &func_);
+
+ // Get non existent attribute
+ TF_Buffer* attr_buf = TF_NewBuffer();
+ TF_FunctionGetAttrValueProto(func_, "foo_attr", attr_buf, s_);
+ EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
+ EXPECT_EQ(string("Function 'MyFunc' has no attr named 'foo_attr'."),
+ string(TF_Message(s_)));
+ TF_DeleteBuffer(attr_buf);
+
+ // Set attr
+ tensorflow::AttrValue attr;
+ attr.set_s("test_attr_value");
+ string bytes;
+ attr.SerializeToString(&bytes);
+ TF_FunctionSetAttrValueProto(func_, "test_attr_name", bytes.data(),
+ bytes.size(), s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ // Get attr
+ AttrValue read_attr;
+ GetAttr("test_attr_name", &read_attr);
+ ASSERT_EQ(attr.DebugString(), read_attr.DebugString());
+
+ // Retrieve the same attr after save/restore
+ Reincarnate();
+ AttrValue read_attr2;
+ GetAttr("test_attr_name", &read_attr2);
+ ASSERT_EQ(attr.DebugString(), read_attr2.DebugString());
+}
+
} // namespace
} // namespace tensorflow