diff options
-rw-r--r-- | tensorflow/compiler/xla/rpc/BUILD | 12 | ||||
-rw-r--r-- | tensorflow/compiler/xla/rpc/grpc_service_main.cc | 11 |
2 files changed, 18 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD index 97fcd37f6b..aa8da04489 100644 --- a/tensorflow/compiler/xla/rpc/BUILD +++ b/tensorflow/compiler/xla/rpc/BUILD @@ -34,8 +34,8 @@ cc_library( ], ) -tf_cc_binary( - name = "grpc_service_main_cpu", +cc_library( + name = "grpc_service_main_library", srcs = ["grpc_service_main.cc"], deps = [ ":grpc_service", @@ -47,6 +47,14 @@ tf_cc_binary( ], ) +tf_cc_binary( + name = "grpc_service_main_cpu", + deps = [ + ":grpc_service_main_library", + "//tensorflow/compiler/xla/service:cpu_plugin", + ], +) + tf_cc_test( name = "grpc_client_test", srcs = ["grpc_client_test.cc"], diff --git a/tensorflow/compiler/xla/rpc/grpc_service_main.cc b/tensorflow/compiler/xla/rpc/grpc_service_main.cc index d6b5149a24..fb54d39a2a 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service_main.cc +++ b/tensorflow/compiler/xla/rpc/grpc_service_main.cc @@ -29,8 +29,12 @@ namespace { int RealMain(int argc, char** argv) { int32 port = 1685; + bool any_address = false; std::vector<tensorflow::Flag> flag_list = { - tensorflow::Flag("port", &port, "port to listen on"), + tensorflow::Flag("port", &port, "The TCP port to listen on"), + tensorflow::Flag( + "any", &any_address, + "Whether to listen to any host address or simply localhost"), }; string usage = tensorflow::Flags::Usage(argv[0], flag_list); bool parsed_values_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); @@ -44,15 +48,16 @@ int RealMain(int argc, char** argv) { xla::GRPCService::NewService().ConsumeValueOrDie(); ::grpc::ServerBuilder builder; - string server_address(absl::StrFormat("localhost:%d", port)); + string server_address( + absl::StrFormat("%s:%d", any_address ? "[::]" : "localhost", port)); + builder.SetMaxReceiveMessageSize(INT_MAX); builder.AddListeningPort(server_address, ::grpc::InsecureServerCredentials()); builder.RegisterService(service.get()); std::unique_ptr<::grpc::Server> server(builder.BuildAndStart()); LOG(INFO) << "Server listening on " << server_address; server->Wait(); - return 0; } |