From d064a47543f51ff5a62927a76bb0fb0862d05558 Mon Sep 17 00:00:00 2001 From: Anna R Date: Tue, 19 Dec 2017 17:28:06 -0800 Subject: Read ApiDef from TensorFlow Go API. PiperOrigin-RevId: 179625412 --- tensorflow/c/c_api_test.cc | 72 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) (limited to 'tensorflow/c/c_api_test.cc') diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 4e89b4fc43..df697e16d3 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/cc/saved_model/tag_constants.h" #include "tensorflow/core/example/example.pb.h" #include "tensorflow/core/example/feature.pb.h" +#include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/graph.pb_text.h" #include "tensorflow/core/framework/node_def.pb_text.h" @@ -2027,6 +2028,77 @@ TEST_F(CApiAttributesTest, Errors) { EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); } +TEST(TestApiDef, TestCreateApiDef) { + TF_Status* status = TF_NewStatus(); + TF_Library* lib = + TF_LoadLibrary("tensorflow/c/test_op.so", status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + + TF_Buffer op_list_buf = TF_GetOpList(lib); + status = TF_NewStatus(); + auto* api_def_map = TF_NewApiDefMap(&op_list_buf, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + + string op_name = "TestCApi"; + status = TF_NewStatus(); + auto* api_def_buf = + TF_ApiDefMapGet(api_def_map, op_name.c_str(), op_name.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + + tensorflow::ApiDef api_def; + EXPECT_TRUE(api_def.ParseFromArray(api_def_buf->data, api_def_buf->length)); + EXPECT_EQ(op_name, api_def.graph_op_name()); + EXPECT_EQ(R"doc(Used to test C API)doc", api_def.summary()); + + TF_DeleteBuffer(api_def_buf); + TF_DeleteApiDefMap(api_def_map); + TF_DeleteLibraryHandle(lib); +} + +TEST(TestApiDef, TestCreateApiDefWithOverwrites) { + TF_Status* status = TF_NewStatus(); + TF_Library* lib = + TF_LoadLibrary("tensorflow/c/test_op.so", status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + + TF_Buffer op_list_buf = TF_GetOpList(lib); + status = TF_NewStatus(); + auto* api_def_map = TF_NewApiDefMap(&op_list_buf, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + + string api_def_overwrites = R"(op: < + graph_op_name: "TestCApi" + summary: "New summary" +> +)"; + status = TF_NewStatus(); + TF_ApiDefMapPut(api_def_map, api_def_overwrites.c_str(), + api_def_overwrites.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + + string op_name = "TestCApi"; + status = TF_NewStatus(); + auto* api_def_buf = + TF_ApiDefMapGet(api_def_map, op_name.c_str(), op_name.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + + tensorflow::ApiDef api_def; + EXPECT_TRUE(api_def.ParseFromArray(api_def_buf->data, api_def_buf->length)); + EXPECT_EQ(op_name, api_def.graph_op_name()); + EXPECT_EQ("New summary", api_def.summary()); + + TF_DeleteBuffer(api_def_buf); + TF_DeleteApiDefMap(api_def_map); + TF_DeleteLibraryHandle(lib); +} + #undef EXPECT_TF_META } // namespace -- cgit v1.2.3