aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc')
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc32
1 files changed, 21 insertions, 11 deletions
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc b/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc
index 3416f99919..1c7bb4375c 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include <iostream>
+#include <vector>
#include "grpc++/grpc++.h"
#include "grpc++/security/credentials.h"
@@ -36,17 +37,10 @@ limitations under the License.
namespace tensorflow {
namespace {
-Status ParseFlagsForTask(int argc, char* argv[], ServerDef* options) {
+Status FillServerDef(const string& cluster_spec, const string& job_name,
+ int task_index, ServerDef* options) {
options->set_protocol("grpc");
- string cluster_spec;
- int task_index = 0;
- const bool parse_result = ParseFlags(
- &argc, argv, {Flag("cluster_spec", &cluster_spec), //
- Flag("job_name", options->mutable_job_name()), //
- Flag("task_id", &task_index)});
- if (!parse_result) {
- return errors::InvalidArgument("Error parsing command-line flags");
- }
+ options->set_job_name(job_name);
options->set_task_index(task_index);
size_t my_num_tasks = 0;
@@ -101,9 +95,25 @@ void Usage(char* const argv_0) {
}
int main(int argc, char* argv[]) {
+ tensorflow::string cluster_spec;
+ tensorflow::string job_name;
+ int task_index = 0;
+ std::vector<tensorflow::Flag> flag_list = {
+ tensorflow::Flag("cluster_spec", &cluster_spec, "cluster spec"),
+ tensorflow::Flag("job_name", &job_name, "job name"),
+ tensorflow::Flag("task_id", &task_index, "task id"),
+ };
+ tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
tensorflow::port::InitMain(argv[0], &argc, &argv);
+ if (!parse_result || argc != 1) {
+ std::cerr << usage << std::endl;
+ Usage(argv[0]);
+ return -1;
+ }
tensorflow::ServerDef server_def;
- tensorflow::Status s = tensorflow::ParseFlagsForTask(argc, argv, &server_def);
+ tensorflow::Status s = tensorflow::FillServerDef(cluster_spec, job_name,
+ task_index, &server_def);
if (!s.ok()) {
std::cerr << "ERROR: " << s.error_message() << std::endl;
Usage(argv[0]);