diff options
Diffstat (limited to 'tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc')
-rw-r--r-- | tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc | 32 |
1 files changed, 21 insertions, 11 deletions
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc b/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc index 3416f99919..1c7bb4375c 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include <iostream> +#include <vector> #include "grpc++/grpc++.h" #include "grpc++/security/credentials.h" @@ -36,17 +37,10 @@ limitations under the License. namespace tensorflow { namespace { -Status ParseFlagsForTask(int argc, char* argv[], ServerDef* options) { +Status FillServerDef(const string& cluster_spec, const string& job_name, + int task_index, ServerDef* options) { options->set_protocol("grpc"); - string cluster_spec; - int task_index = 0; - const bool parse_result = ParseFlags( - &argc, argv, {Flag("cluster_spec", &cluster_spec), // - Flag("job_name", options->mutable_job_name()), // - Flag("task_id", &task_index)}); - if (!parse_result) { - return errors::InvalidArgument("Error parsing command-line flags"); - } + options->set_job_name(job_name); options->set_task_index(task_index); size_t my_num_tasks = 0; @@ -101,9 +95,25 @@ void Usage(char* const argv_0) { } int main(int argc, char* argv[]) { + tensorflow::string cluster_spec; + tensorflow::string job_name; + int task_index = 0; + std::vector<tensorflow::Flag> flag_list = { + tensorflow::Flag("cluster_spec", &cluster_spec, "cluster spec"), + tensorflow::Flag("job_name", &job_name, "job name"), + tensorflow::Flag("task_id", &task_index, "task id"), + }; + tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); tensorflow::port::InitMain(argv[0], &argc, &argv); + if (!parse_result || argc != 1) { + std::cerr << usage << std::endl; + Usage(argv[0]); + return -1; + } tensorflow::ServerDef server_def; - tensorflow::Status s = tensorflow::ParseFlagsForTask(argc, argv, &server_def); + tensorflow::Status s = tensorflow::FillServerDef(cluster_spec, job_name, + task_index, &server_def); if (!s.ok()) { std::cerr << "ERROR: " << s.error_message() << std::endl; Usage(argv[0]); |