From 5dbb021354e0acda667d823e856ec8be88960b35 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 26 Sep 2018 15:34:43 -0700 Subject: Added a C utility to create a ServerDef proto from text representation. PiperOrigin-RevId: 214681193 --- tensorflow/c/BUILD | 1 + tensorflow/c/c_api_experimental.cc | 15 +++++++++++ tensorflow/c/c_api_experimental.h | 2 ++ tensorflow/c/c_api_experimental_test.cc | 46 +++++++++++++++++++++++++++++++++ 4 files changed, 64 insertions(+) (limited to 'tensorflow/c') diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 43c279bd80..17e2e292eb 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -246,6 +246,7 @@ tf_cc_test( ":c_api_experimental", ":c_test_util", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", ], diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index 3bcc62cf2d..f316e4ba67 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/platform.h" #include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/tensorflow_server.pb.h" using tensorflow::FunctionDef; using tensorflow::Node; @@ -8508,6 +8509,20 @@ void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id, VLOG(1) << "Enqueuing is done."; } +TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status) { + tensorflow::ServerDef server_def; + if (!tensorflow::protobuf::TextFormat::ParseFromString(text_proto, + &server_def)) { + status->status = tensorflow::errors::Internal( + "Invalid text proto for ServerDef: ", text_proto); + return nullptr; + } + status->status = tensorflow::Status(); + TF_Buffer* ret = TF_NewBuffer(); + TF_CHECK_OK(MessageToBuffer(server_def, ret)); + return ret; +} + TFE_Context* TFE_CreateContextFromSession(TF_Session* session, TF_Status* status) { auto* opts = TFE_NewContextOptions(); diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index a3ca847d96..950ad9aeed 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -131,6 +131,8 @@ TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id, TF_Tensor* tensor, TF_Status* status); +// Create a serialized tensorflow.ServerDef proto. +TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status); // TODO: remove this API in favor of the next one. TF_CAPI_EXPORT extern TFE_Context* TFE_NewContextFromSession( diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc index 30fcfd401d..c6effd3969 100644 --- a/tensorflow/c/c_api_experimental_test.cc +++ b/tensorflow/c/c_api_experimental_test.cc @@ -16,8 +16,10 @@ limitations under the License. #include "tensorflow/c/c_api_experimental.h" #include "tensorflow/c/c_test_util.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/tensorflow_server.pb.h" namespace tensorflow { namespace { @@ -116,5 +118,49 @@ TEST(CAPI_EXPERIMENTAL, ImagenetIteratorGetNext) { TF_DeleteStatus(s); } +TEST(CAPI_EXPERIMENTAL, GetServerDefTest) { + const string expected_text_proto(R"(cluster { + job { + name: "worker" + tasks { + key: 0 + value: "tpuserver:0" + } + tasks { + key: 1 + value: "localhost:1" + } + } +} +job_name: "worker" +task_index: 1 +protocol: "grpc" +)"); + + TF_Status* status = TF_NewStatus(); + TF_Buffer* result = TFE_GetServerDef(expected_text_proto.c_str(), status); + EXPECT_EQ(TF_GetCode(status), TF_OK); + + ServerDef actual; + ASSERT_TRUE(actual.ParseFromArray(result->data, result->length)); + string actual_text_proto; + tensorflow::protobuf::TextFormat::PrintToString(actual, &actual_text_proto); + EXPECT_EQ(expected_text_proto, actual_text_proto); + + const string malformed_text_proto(R"(cluster { + job { + name: "worker")"); + TF_Buffer* null_result = + TFE_GetServerDef(malformed_text_proto.c_str(), status); + EXPECT_NE(TF_GetCode(status), TF_OK); + EXPECT_TRUE(tensorflow::str_util::StrContains( + TF_Message(status), "Invalid text proto for ServerDef")); + EXPECT_EQ(null_result, nullptr); + + // Cleanup + TF_DeleteBuffer(result); + TF_DeleteStatus(status); +} + } // namespace } // namespace tensorflow -- cgit v1.2.3