aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-26 15:34:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 15:45:34 -0700
commit5dbb021354e0acda667d823e856ec8be88960b35 (patch)
tree026b325db796c46d228fbc599697a628d4374bec /tensorflow/c
parent2511230c0a9b8e2ec652d00dcedbd75d644e5400 (diff)
Added a C utility to create a ServerDef proto from text representation.
PiperOrigin-RevId: 214681193
Diffstat (limited to 'tensorflow/c')
-rw-r--r--tensorflow/c/BUILD1
-rw-r--r--tensorflow/c/c_api_experimental.cc15
-rw-r--r--tensorflow/c/c_api_experimental.h2
-rw-r--r--tensorflow/c/c_api_experimental_test.cc46
4 files changed, 64 insertions, 0 deletions
diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD
index 43c279bd80..17e2e292eb 100644
--- a/tensorflow/c/BUILD
+++ b/tensorflow/c/BUILD
@@ -246,6 +246,7 @@ tf_cc_test(
":c_api_experimental",
":c_test_util",
"//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index 3bcc62cf2d..f316e4ba67 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/protobuf/config.pb.h"
+#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
using tensorflow::FunctionDef;
using tensorflow::Node;
@@ -8508,6 +8509,20 @@ void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id,
VLOG(1) << "Enqueuing is done.";
}
+TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status) {
+ tensorflow::ServerDef server_def;
+ if (!tensorflow::protobuf::TextFormat::ParseFromString(text_proto,
+ &server_def)) {
+ status->status = tensorflow::errors::Internal(
+ "Invalid text proto for ServerDef: ", text_proto);
+ return nullptr;
+ }
+ status->status = tensorflow::Status();
+ TF_Buffer* ret = TF_NewBuffer();
+ TF_CHECK_OK(MessageToBuffer(server_def, ret));
+ return ret;
+}
+
TFE_Context* TFE_CreateContextFromSession(TF_Session* session,
TF_Status* status) {
auto* opts = TFE_NewContextOptions();
diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h
index a3ca847d96..950ad9aeed 100644
--- a/tensorflow/c/c_api_experimental.h
+++ b/tensorflow/c/c_api_experimental.h
@@ -131,6 +131,8 @@ TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session,
int tensor_id,
TF_Tensor* tensor,
TF_Status* status);
+// Create a serialized tensorflow.ServerDef proto.
+TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status);
// TODO: remove this API in favor of the next one.
TF_CAPI_EXPORT extern TFE_Context* TFE_NewContextFromSession(
diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc
index 30fcfd401d..c6effd3969 100644
--- a/tensorflow/c/c_api_experimental_test.cc
+++ b/tensorflow/c/c_api_experimental_test.cc
@@ -16,8 +16,10 @@ limitations under the License.
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/c_test_util.h"
#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
namespace tensorflow {
namespace {
@@ -116,5 +118,49 @@ TEST(CAPI_EXPERIMENTAL, ImagenetIteratorGetNext) {
TF_DeleteStatus(s);
}
+TEST(CAPI_EXPERIMENTAL, GetServerDefTest) {
+ const string expected_text_proto(R"(cluster {
+ job {
+ name: "worker"
+ tasks {
+ key: 0
+ value: "tpuserver:0"
+ }
+ tasks {
+ key: 1
+ value: "localhost:1"
+ }
+ }
+}
+job_name: "worker"
+task_index: 1
+protocol: "grpc"
+)");
+
+ TF_Status* status = TF_NewStatus();
+ TF_Buffer* result = TFE_GetServerDef(expected_text_proto.c_str(), status);
+ EXPECT_EQ(TF_GetCode(status), TF_OK);
+
+ ServerDef actual;
+ ASSERT_TRUE(actual.ParseFromArray(result->data, result->length));
+ string actual_text_proto;
+ tensorflow::protobuf::TextFormat::PrintToString(actual, &actual_text_proto);
+ EXPECT_EQ(expected_text_proto, actual_text_proto);
+
+ const string malformed_text_proto(R"(cluster {
+ job {
+ name: "worker")");
+ TF_Buffer* null_result =
+ TFE_GetServerDef(malformed_text_proto.c_str(), status);
+ EXPECT_NE(TF_GetCode(status), TF_OK);
+ EXPECT_TRUE(tensorflow::str_util::StrContains(
+ TF_Message(status), "Invalid text proto for ServerDef"));
+ EXPECT_EQ(null_result, nullptr);
+
+ // Cleanup
+ TF_DeleteBuffer(result);
+ TF_DeleteStatus(status);
+}
+
} // namespace
} // namespace tensorflow