aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-21 18:25:02 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-21 18:28:50 -0700
commit086183579a59e07fc9b1ebbfa6516258da0a215b (patch)
treec6b7987336cade01751b8d0bea4fabdba459af74 /tensorflow
parentd125fb8a39bb4fca1be5421130ed66d673ee590f (diff)
Create a GRPC service library to enable reuse in other parts of the code base.
PiperOrigin-RevId: 214074684
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/compiler/xla/rpc/BUILD12
-rw-r--r--tensorflow/compiler/xla/rpc/grpc_service_main.cc11
2 files changed, 18 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD
index 97fcd37f6b..aa8da04489 100644
--- a/tensorflow/compiler/xla/rpc/BUILD
+++ b/tensorflow/compiler/xla/rpc/BUILD
@@ -34,8 +34,8 @@ cc_library(
],
)
-tf_cc_binary(
- name = "grpc_service_main_cpu",
+cc_library(
+ name = "grpc_service_main_library",
srcs = ["grpc_service_main.cc"],
deps = [
":grpc_service",
@@ -47,6 +47,14 @@ tf_cc_binary(
],
)
+tf_cc_binary(
+ name = "grpc_service_main_cpu",
+ deps = [
+ ":grpc_service_main_library",
+ "//tensorflow/compiler/xla/service:cpu_plugin",
+ ],
+)
+
tf_cc_test(
name = "grpc_client_test",
srcs = ["grpc_client_test.cc"],
diff --git a/tensorflow/compiler/xla/rpc/grpc_service_main.cc b/tensorflow/compiler/xla/rpc/grpc_service_main.cc
index d6b5149a24..fb54d39a2a 100644
--- a/tensorflow/compiler/xla/rpc/grpc_service_main.cc
+++ b/tensorflow/compiler/xla/rpc/grpc_service_main.cc
@@ -29,8 +29,12 @@ namespace {
int RealMain(int argc, char** argv) {
int32 port = 1685;
+ bool any_address = false;
std::vector<tensorflow::Flag> flag_list = {
- tensorflow::Flag("port", &port, "port to listen on"),
+ tensorflow::Flag("port", &port, "The TCP port to listen on"),
+ tensorflow::Flag(
+ "any", &any_address,
+ "Whether to listen to any host address or simply localhost"),
};
string usage = tensorflow::Flags::Usage(argv[0], flag_list);
bool parsed_values_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
@@ -44,15 +48,16 @@ int RealMain(int argc, char** argv) {
xla::GRPCService::NewService().ConsumeValueOrDie();
::grpc::ServerBuilder builder;
- string server_address(absl::StrFormat("localhost:%d", port));
+ string server_address(
+ absl::StrFormat("%s:%d", any_address ? "[::]" : "localhost", port));
+ builder.SetMaxReceiveMessageSize(INT_MAX);
builder.AddListeningPort(server_address, ::grpc::InsecureServerCredentials());
builder.RegisterService(service.get());
std::unique_ptr<::grpc::Server> server(builder.BuildAndStart());
LOG(INFO) << "Server listening on " << server_address;
server->Wait();
-
return 0;
}