diff options
Diffstat (limited to 'tensorflow/core/distributed_runtime')
-rw-r--r-- | tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc | 32 | ||||
-rw-r--r-- | tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc | 47 |
2 files changed, 49 insertions, 30 deletions
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc b/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc index 3416f99919..1c7bb4375c 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include <iostream> +#include <vector> #include "grpc++/grpc++.h" #include "grpc++/security/credentials.h" @@ -36,17 +37,10 @@ limitations under the License. namespace tensorflow { namespace { -Status ParseFlagsForTask(int argc, char* argv[], ServerDef* options) { +Status FillServerDef(const string& cluster_spec, const string& job_name, + int task_index, ServerDef* options) { options->set_protocol("grpc"); - string cluster_spec; - int task_index = 0; - const bool parse_result = ParseFlags( - &argc, argv, {Flag("cluster_spec", &cluster_spec), // - Flag("job_name", options->mutable_job_name()), // - Flag("task_id", &task_index)}); - if (!parse_result) { - return errors::InvalidArgument("Error parsing command-line flags"); - } + options->set_job_name(job_name); options->set_task_index(task_index); size_t my_num_tasks = 0; @@ -101,9 +95,25 @@ void Usage(char* const argv_0) { } int main(int argc, char* argv[]) { + tensorflow::string cluster_spec; + tensorflow::string job_name; + int task_index = 0; + std::vector<tensorflow::Flag> flag_list = { + tensorflow::Flag("cluster_spec", &cluster_spec, "cluster spec"), + tensorflow::Flag("job_name", &job_name, "job name"), + tensorflow::Flag("task_id", &task_index, "task id"), + }; + tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); tensorflow::port::InitMain(argv[0], &argc, &argv); + if (!parse_result || argc != 1) { + std::cerr << usage << std::endl; + Usage(argv[0]); + return -1; + } tensorflow::ServerDef server_def; - tensorflow::Status s = tensorflow::ParseFlagsForTask(argc, argv, &server_def); + tensorflow::Status s = tensorflow::FillServerDef(cluster_spec, job_name, + task_index, &server_def); if (!s.ok()) { std::cerr << "ERROR: " << s.error_message() << std::endl; Usage(argv[0]); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc b/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc index f31687068b..953cf933d5 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include <vector> + #include "grpc++/grpc++.h" #include "grpc++/security/credentials.h" #include "grpc++/server_builder.h" @@ -32,25 +34,13 @@ limitations under the License. namespace tensorflow { namespace { -Status ParseFlagsForTask(int argc, char* argv[], ServerDef* options) { +Status FillServerDef(const string& job_spec, const string& job_name, + int num_cpus, int num_gpus, int task_index, + ServerDef* options) { options->set_protocol("grpc"); - string job_spec; - int num_cpus = 1; - int num_gpus = 0; - int task_index = 0; - const bool parse_result = - ParseFlags(&argc, argv, {Flag("tf_jobs", &job_spec), // - Flag("tf_job", options->mutable_job_name()), // - Flag("tf_task", &task_index), // - Flag("num_cpus", &num_cpus), // - Flag("num_gpus", &num_gpus)}); - + options->set_job_name(job_name); options->set_task_index(task_index); - if (!parse_result) { - return errors::InvalidArgument("Error parsing command-line flags"); - } - uint32 my_tasks_per_replica = 0; for (const string& job_str : str_util::Split(job_spec, ',')) { JobDef* job_def = options->mutable_cluster()->add_job(); @@ -85,13 +75,32 @@ Status ParseFlagsForTask(int argc, char* argv[], ServerDef* options) { } // namespace tensorflow int main(int argc, char* argv[]) { + tensorflow::string job_spec; + tensorflow::string job_name; + int num_cpus = 1; + int num_gpus = 0; + int task_index = 0; + std::vector<tensorflow::Flag> flag_list = { + tensorflow::Flag("tf_jobs", &job_spec, "job specification"), + tensorflow::Flag("tf_job", &job_name, "job name"), + tensorflow::Flag("tf_task", &task_index, "task index"), + tensorflow::Flag("num_cpus", &num_cpus, "number of CPUs"), + tensorflow::Flag("num_gpus", &num_gpus, "number of GPUs"), + }; + tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); tensorflow::port::InitMain(argv[0], &argc, &argv); + if (!parse_result || argc != 1) { + LOG(ERROR) << usage; + return -1; + } tensorflow::ServerDef def; - tensorflow::Status s = tensorflow::ParseFlagsForTask(argc, argv, &def); - + tensorflow::Status s = tensorflow::FillServerDef(job_spec, job_name, num_cpus, + num_gpus, task_index, &def); if (!s.ok()) { - LOG(ERROR) << "Could not parse flags: " << s.error_message(); + LOG(ERROR) << "Could not parse job spec: " << s.error_message() << "\n" + << usage; return -1; } |