aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_api_test.cc
diff options
context:
space:
mode:
authorGravatar Anna R <annarev@google.com>2017-12-19 17:28:06 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-20 10:57:07 -0800
commitd064a47543f51ff5a62927a76bb0fb0862d05558 (patch)
treede5429f36f01e084217c10370a0cff9c446b5e7a /tensorflow/c/c_api_test.cc
parent5d2e8e05c2ddca08e4fc7b17c88ac36a6036dd4b (diff)
Read ApiDef from TensorFlow Go API.
PiperOrigin-RevId: 179625412
Diffstat (limited to 'tensorflow/c/c_api_test.cc')
-rw-r--r--tensorflow/c/c_api_test.cc72
1 files changed, 72 insertions, 0 deletions
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc
index 4e89b4fc43..df697e16d3 100644
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/cc/saved_model/tag_constants.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature.pb.h"
+#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/graph.pb_text.h"
#include "tensorflow/core/framework/node_def.pb_text.h"
@@ -2027,6 +2028,77 @@ TEST_F(CApiAttributesTest, Errors) {
EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
}
+TEST(TestApiDef, TestCreateApiDef) {
+ TF_Status* status = TF_NewStatus();
+ TF_Library* lib =
+ TF_LoadLibrary("tensorflow/c/test_op.so", status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteStatus(status);
+
+ TF_Buffer op_list_buf = TF_GetOpList(lib);
+ status = TF_NewStatus();
+ auto* api_def_map = TF_NewApiDefMap(&op_list_buf, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteStatus(status);
+
+ string op_name = "TestCApi";
+ status = TF_NewStatus();
+ auto* api_def_buf =
+ TF_ApiDefMapGet(api_def_map, op_name.c_str(), op_name.size(), status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteStatus(status);
+
+ tensorflow::ApiDef api_def;
+ EXPECT_TRUE(api_def.ParseFromArray(api_def_buf->data, api_def_buf->length));
+ EXPECT_EQ(op_name, api_def.graph_op_name());
+ EXPECT_EQ(R"doc(Used to test C API)doc", api_def.summary());
+
+ TF_DeleteBuffer(api_def_buf);
+ TF_DeleteApiDefMap(api_def_map);
+ TF_DeleteLibraryHandle(lib);
+}
+
+TEST(TestApiDef, TestCreateApiDefWithOverwrites) {
+ TF_Status* status = TF_NewStatus();
+ TF_Library* lib =
+ TF_LoadLibrary("tensorflow/c/test_op.so", status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteStatus(status);
+
+ TF_Buffer op_list_buf = TF_GetOpList(lib);
+ status = TF_NewStatus();
+ auto* api_def_map = TF_NewApiDefMap(&op_list_buf, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteStatus(status);
+
+ string api_def_overwrites = R"(op: <
+ graph_op_name: "TestCApi"
+ summary: "New summary"
+>
+)";
+ status = TF_NewStatus();
+ TF_ApiDefMapPut(api_def_map, api_def_overwrites.c_str(),
+ api_def_overwrites.size(), status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteStatus(status);
+
+ string op_name = "TestCApi";
+ status = TF_NewStatus();
+ auto* api_def_buf =
+ TF_ApiDefMapGet(api_def_map, op_name.c_str(), op_name.size(), status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteStatus(status);
+
+ tensorflow::ApiDef api_def;
+ EXPECT_TRUE(api_def.ParseFromArray(api_def_buf->data, api_def_buf->length));
+ EXPECT_EQ(op_name, api_def.graph_op_name());
+ EXPECT_EQ("New summary", api_def.summary());
+
+ TF_DeleteBuffer(api_def_buf);
+ TF_DeleteApiDefMap(api_def_map);
+ TF_DeleteLibraryHandle(lib);
+}
+
#undef EXPECT_TF_META
} // namespace