aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc32
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc47
-rw-r--r--tensorflow/core/util/command_line_flags.cc70
-rw-r--r--tensorflow/core/util/command_line_flags.h49
-rw-r--r--tensorflow/core/util/command_line_flags_test.cc94
5 files changed, 219 insertions, 73 deletions
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc b/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc
index 3416f99919..1c7bb4375c 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include <iostream>
+#include <vector>
#include "grpc++/grpc++.h"
#include "grpc++/security/credentials.h"
@@ -36,17 +37,10 @@ limitations under the License.
namespace tensorflow {
namespace {
-Status ParseFlagsForTask(int argc, char* argv[], ServerDef* options) {
+Status FillServerDef(const string& cluster_spec, const string& job_name,
+ int task_index, ServerDef* options) {
options->set_protocol("grpc");
- string cluster_spec;
- int task_index = 0;
- const bool parse_result = ParseFlags(
- &argc, argv, {Flag("cluster_spec", &cluster_spec), //
- Flag("job_name", options->mutable_job_name()), //
- Flag("task_id", &task_index)});
- if (!parse_result) {
- return errors::InvalidArgument("Error parsing command-line flags");
- }
+ options->set_job_name(job_name);
options->set_task_index(task_index);
size_t my_num_tasks = 0;
@@ -101,9 +95,25 @@ void Usage(char* const argv_0) {
}
int main(int argc, char* argv[]) {
+ tensorflow::string cluster_spec;
+ tensorflow::string job_name;
+ int task_index = 0;
+ std::vector<tensorflow::Flag> flag_list = {
+ tensorflow::Flag("cluster_spec", &cluster_spec, "cluster spec"),
+ tensorflow::Flag("job_name", &job_name, "job name"),
+ tensorflow::Flag("task_id", &task_index, "task id"),
+ };
+ tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
tensorflow::port::InitMain(argv[0], &argc, &argv);
+ if (!parse_result || argc != 1) {
+ std::cerr << usage << std::endl;
+ Usage(argv[0]);
+ return -1;
+ }
tensorflow::ServerDef server_def;
- tensorflow::Status s = tensorflow::ParseFlagsForTask(argc, argv, &server_def);
+ tensorflow::Status s = tensorflow::FillServerDef(cluster_spec, job_name,
+ task_index, &server_def);
if (!s.ok()) {
std::cerr << "ERROR: " << s.error_message() << std::endl;
Usage(argv[0]);
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc b/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc
index f31687068b..953cf933d5 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <vector>
+
#include "grpc++/grpc++.h"
#include "grpc++/security/credentials.h"
#include "grpc++/server_builder.h"
@@ -32,25 +34,13 @@ limitations under the License.
namespace tensorflow {
namespace {
-Status ParseFlagsForTask(int argc, char* argv[], ServerDef* options) {
+Status FillServerDef(const string& job_spec, const string& job_name,
+ int num_cpus, int num_gpus, int task_index,
+ ServerDef* options) {
options->set_protocol("grpc");
- string job_spec;
- int num_cpus = 1;
- int num_gpus = 0;
- int task_index = 0;
- const bool parse_result =
- ParseFlags(&argc, argv, {Flag("tf_jobs", &job_spec), //
- Flag("tf_job", options->mutable_job_name()), //
- Flag("tf_task", &task_index), //
- Flag("num_cpus", &num_cpus), //
- Flag("num_gpus", &num_gpus)});
-
+ options->set_job_name(job_name);
options->set_task_index(task_index);
- if (!parse_result) {
- return errors::InvalidArgument("Error parsing command-line flags");
- }
-
uint32 my_tasks_per_replica = 0;
for (const string& job_str : str_util::Split(job_spec, ',')) {
JobDef* job_def = options->mutable_cluster()->add_job();
@@ -85,13 +75,32 @@ Status ParseFlagsForTask(int argc, char* argv[], ServerDef* options) {
} // namespace tensorflow
int main(int argc, char* argv[]) {
+ tensorflow::string job_spec;
+ tensorflow::string job_name;
+ int num_cpus = 1;
+ int num_gpus = 0;
+ int task_index = 0;
+ std::vector<tensorflow::Flag> flag_list = {
+ tensorflow::Flag("tf_jobs", &job_spec, "job specification"),
+ tensorflow::Flag("tf_job", &job_name, "job name"),
+ tensorflow::Flag("tf_task", &task_index, "task index"),
+ tensorflow::Flag("num_cpus", &num_cpus, "number of CPUs"),
+ tensorflow::Flag("num_gpus", &num_gpus, "number of GPUs"),
+ };
+ tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
tensorflow::port::InitMain(argv[0], &argc, &argv);
+ if (!parse_result || argc != 1) {
+ LOG(ERROR) << usage;
+ return -1;
+ }
tensorflow::ServerDef def;
- tensorflow::Status s = tensorflow::ParseFlagsForTask(argc, argv, &def);
-
+ tensorflow::Status s = tensorflow::FillServerDef(job_spec, job_name, num_cpus,
+ num_gpus, task_index, &def);
if (!s.ok()) {
- LOG(ERROR) << "Could not parse flags: " << s.error_message();
+ LOG(ERROR) << "Could not parse job spec: " << s.error_message() << "\n"
+ << usage;
return -1;
}
diff --git a/tensorflow/core/util/command_line_flags.cc b/tensorflow/core/util/command_line_flags.cc
index 2048126338..03eb076f30 100644
--- a/tensorflow/core/util/command_line_flags.cc
+++ b/tensorflow/core/util/command_line_flags.cc
@@ -13,9 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/util/command_line_flags.h"
+#include <string>
+#include <vector>
+
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow {
namespace {
@@ -91,17 +95,26 @@ bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
} // namespace
-Flag::Flag(const char* name, tensorflow::int32* dst)
- : name_(name), type_(TYPE_INT), int_value_(dst) {}
+Flag::Flag(const char* name, tensorflow::int32* dst, const string& usage_text)
+ : name_(name), type_(TYPE_INT), int_value_(dst), usage_text_(usage_text) {}
-Flag::Flag(const char* name, tensorflow::int64* dst)
- : name_(name), type_(TYPE_INT64), int64_value_(dst) {}
+Flag::Flag(const char* name, tensorflow::int64* dst, const string& usage_text)
+ : name_(name),
+ type_(TYPE_INT64),
+ int64_value_(dst),
+ usage_text_(usage_text) {}
-Flag::Flag(const char* name, bool* dst)
- : name_(name), type_(TYPE_BOOL), bool_value_(dst) {}
+Flag::Flag(const char* name, bool* dst, const string& usage_text)
+ : name_(name),
+ type_(TYPE_BOOL),
+ bool_value_(dst),
+ usage_text_(usage_text) {}
-Flag::Flag(const char* name, string* dst)
- : name_(name), type_(TYPE_STRING), string_value_(dst) {}
+Flag::Flag(const char* name, string* dst, const string& usage_text)
+ : name_(name),
+ type_(TYPE_STRING),
+ string_value_(dst),
+ usage_text_(usage_text) {}
bool Flag::Parse(string arg, bool* value_parsing_ok) const {
bool result = false;
@@ -117,7 +130,8 @@ bool Flag::Parse(string arg, bool* value_parsing_ok) const {
return result;
}
-bool ParseFlags(int* argc, char** argv, const std::vector<Flag>& flag_list) {
+/*static*/ bool Flags::Parse(int* argc, char** argv,
+ const std::vector<Flag>& flag_list) {
bool result = true;
std::vector<char*> unknown_flags;
for (int i = 1; i < *argc; ++i) {
@@ -151,7 +165,41 @@ bool ParseFlags(int* argc, char** argv, const std::vector<Flag>& flag_list) {
}
argv[dst++] = nullptr;
*argc = unknown_flags.size() + 1;
- return result;
+ return result && (*argc < 2 || strcmp(argv[1], "--help") != 0);
+}
+
+/*static*/ string Flags::Usage(const string& cmdline,
+ const std::vector<Flag>& flag_list) {
+ string usage_text;
+ if (!flag_list.empty()) {
+ strings::Appendf(&usage_text, "usage: %s\nFlags:\n", cmdline.c_str());
+ } else {
+ strings::Appendf(&usage_text, "usage: %s\n", cmdline.c_str());
+ }
+ for (const Flag& flag : flag_list) {
+ const char* type_name = "";
+ string flag_string;
+ if (flag.type_ == Flag::TYPE_INT) {
+ type_name = "int32";
+ flag_string =
+ strings::Printf("--%s=%d", flag.name_.c_str(), *flag.int_value_);
+ } else if (flag.type_ == Flag::TYPE_INT64) {
+ type_name = "int64";
+ flag_string = strings::Printf("--%s=%lld", flag.name_.c_str(),
+ static_cast<long long>(*flag.int64_value_));
+ } else if (flag.type_ == Flag::TYPE_BOOL) {
+ type_name = "bool";
+ flag_string = strings::Printf("--%s=%s", flag.name_.c_str(),
+ *flag.bool_value_ ? "true" : "false");
+ } else if (flag.type_ == Flag::TYPE_STRING) {
+ type_name = "string";
+ flag_string = strings::Printf("--%s=\"%s\"", flag.name_.c_str(),
+ flag.string_value_->c_str());
+ }
+ strings::Appendf(&usage_text, "\t%-33s\t%s\t%s\n", flag_string.c_str(),
+ type_name, flag.usage_text_.c_str());
+ }
+ return usage_text;
}
} // namespace tensorflow
diff --git a/tensorflow/core/util/command_line_flags.h b/tensorflow/core/util/command_line_flags.h
index 9297fb066d..2c77d7874f 100644
--- a/tensorflow/core/util/command_line_flags.h
+++ b/tensorflow/core/util/command_line_flags.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef THIRD_PARTY_TENSORFLOW_CORE_UTIL_COMMAND_LINE_FLAGS_H
#define THIRD_PARTY_TENSORFLOW_CORE_UTIL_COMMAND_LINE_FLAGS_H
+#include <string>
#include <vector>
#include "tensorflow/core/platform/types.h"
@@ -30,10 +31,19 @@ namespace tensorflow {
// int some_int = 10;
// bool some_switch = false;
// string some_name = "something";
-// bool parsed_values_ok = ParseFlags(&argc, argv, {
-// Flag("some_int", &some_int),
-// Flag("some_switch", &some_switch),
-// Flag("some_name", &some_name)});
+// std::vector<tensorFlow::Flag> flag_list = {
+// Flag("some_int", &some_int, "an integer that affects X"),
+// Flag("some_switch", &some_switch, "a bool that affects Y"),
+// Flag("some_name", &some_name, "a string that affects Z")
+// };
+// // Get usage message before ParseFlags() to capture default values.
+// string usage = Flag::Usage(argv[0], flag_list);
+// bool parsed_values_ok = Flags::Parse(&argc, argv, flag_list);
+//
+// tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
+// if (argc != 1 || !parsed_values_ok) {
+// ...output usage and error message...
+// }
//
// The argc and argv values are adjusted by the Parse function so all that
// remains is the program name (at argv[0]) and any unknown arguments fill the
@@ -46,25 +56,44 @@ namespace tensorflow {
// NOTE: Unlike gflags-style libraries, this library is intended to be
// used in the `main()` function of your binary. It does not handle
// flag definitions that are scattered around the source code.
+
+// A description of a single command line flag, holding its name, type, usage
+// text, and a pointer to the corresponding variable.
class Flag {
public:
- Flag(const char* name, int32* dst1);
- Flag(const char* name, int64* dst1);
- Flag(const char* name, bool* dst);
- Flag(const char* name, string* dst);
+ Flag(const char* name, int32* dst1, const string& usage_text);
+ Flag(const char* name, int64* dst1, const string& usage_text);
+ Flag(const char* name, bool* dst, const string& usage_text);
+ Flag(const char* name, string* dst, const string& usage_text);
+
+ private:
+ friend class Flags;
bool Parse(string arg, bool* value_parsing_ok) const;
- private:
string name_;
enum { TYPE_INT, TYPE_INT64, TYPE_BOOL, TYPE_STRING } type_;
int* int_value_;
int64* int64_value_;
bool* bool_value_;
string* string_value_;
+ string usage_text_;
};
-bool ParseFlags(int* argc, char** argv, const std::vector<Flag>& flag_list);
+class Flags {
+ public:
+ // Parse the command line represented by argv[0, ..., (*argc)-1] to find flag
+ // instances matching flags in flaglist[]. Update the variables associated
+ // with matching flags, and remove the matching arguments from (*argc, argv).
+ // Return true iff all recognized flag values were parsed correctly, and the
+ // first remaining argument is not "--help".
+ static bool Parse(int* argc, char** argv, const std::vector<Flag>& flag_list);
+
+ // Return a usage message with command line cmdline, and the
+ // usage_text strings in flag_list[].
+ static string Usage(const string& cmdline,
+ const std::vector<Flag>& flag_list);
+};
} // namespace tensorflow
diff --git a/tensorflow/core/util/command_line_flags_test.cc b/tensorflow/core/util/command_line_flags_test.cc
index bc38fff8fd..b002e35899 100644
--- a/tensorflow/core/util/command_line_flags_test.cc
+++ b/tensorflow/core/util/command_line_flags_test.cc
@@ -13,19 +13,22 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/util/command_line_flags.h"
+#include <ctype.h>
+#include <vector>
+
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow {
namespace {
// The returned array is only valid for the lifetime of the input vector.
// We're using const casting because we need to pass in an argv-style array of
// char* pointers for the API, even though we know they won't be altered.
-std::vector<char*> CharPointerVectorFromStrings(
- const std::vector<tensorflow::string>& strings) {
- std::vector<char*> result;
- for (const tensorflow::string& string : strings) {
- result.push_back(const_cast<char*>(string.c_str()));
+std::vector<char *> CharPointerVectorFromStrings(
+ const std::vector<string> &strings) {
+ std::vector<char *> result;
+ for (const string &string : strings) {
+ result.push_back(const_cast<char *>(string.c_str()));
}
return result;
}
@@ -35,16 +38,18 @@ TEST(CommandLineFlagsTest, BasicUsage) {
int some_int = 10;
int64 some_int64 = 21474836470; // max int32 is 2147483647
bool some_switch = false;
- tensorflow::string some_name = "something";
+ string some_name = "something";
int argc = 5;
- std::vector<tensorflow::string> argv_strings = {
+ std::vector<string> argv_strings = {
"program_name", "--some_int=20", "--some_int64=214748364700",
"--some_switch", "--some_name=somethingelse"};
- std::vector<char*> argv_array = CharPointerVectorFromStrings(argv_strings);
- bool parsed_ok = ParseFlags(
- &argc, argv_array.data(),
- {Flag("some_int", &some_int), Flag("some_int64", &some_int64),
- Flag("some_switch", &some_switch), Flag("some_name", &some_name)});
+ std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings);
+ bool parsed_ok =
+ Flags::Parse(&argc, argv_array.data(),
+ {Flag("some_int", &some_int, "some int"),
+ Flag("some_int64", &some_int64, "some int64"),
+ Flag("some_switch", &some_switch, "some switch"),
+ Flag("some_name", &some_name, "some name")});
EXPECT_EQ(true, parsed_ok);
EXPECT_EQ(20, some_int);
EXPECT_EQ(214748364700, some_int64);
@@ -56,11 +61,10 @@ TEST(CommandLineFlagsTest, BasicUsage) {
TEST(CommandLineFlagsTest, BadIntValue) {
int some_int = 10;
int argc = 2;
- std::vector<tensorflow::string> argv_strings = {"program_name",
- "--some_int=notanumber"};
- std::vector<char*> argv_array = CharPointerVectorFromStrings(argv_strings);
- bool parsed_ok =
- ParseFlags(&argc, argv_array.data(), {Flag("some_int", &some_int)});
+ std::vector<string> argv_strings = {"program_name", "--some_int=notanumber"};
+ std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings);
+ bool parsed_ok = Flags::Parse(&argc, argv_array.data(),
+ {Flag("some_int", &some_int, "some int")});
EXPECT_EQ(false, parsed_ok);
EXPECT_EQ(10, some_int);
@@ -70,15 +74,61 @@ TEST(CommandLineFlagsTest, BadIntValue) {
TEST(CommandLineFlagsTest, BadBoolValue) {
bool some_switch = false;
int argc = 2;
- std::vector<tensorflow::string> argv_strings = {"program_name",
- "--some_switch=notabool"};
- std::vector<char*> argv_array = CharPointerVectorFromStrings(argv_strings);
+ std::vector<string> argv_strings = {"program_name", "--some_switch=notabool"};
+ std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings);
bool parsed_ok =
- ParseFlags(&argc, argv_array.data(), {Flag("some_switch", &some_switch)});
+ Flags::Parse(&argc, argv_array.data(),
+ {Flag("some_switch", &some_switch, "some switch")});
EXPECT_EQ(false, parsed_ok);
EXPECT_EQ(false, some_switch);
EXPECT_EQ(argc, 1);
}
+// Return whether str==pat, but allowing any whitespace in pat
+// to match zero or more whitespace characters in str.
+static bool MatchWithAnyWhitespace(const string &str, const string &pat) {
+ bool matching = true;
+ int pat_i = 0;
+ for (int str_i = 0; str_i != str.size() && matching; str_i++) {
+ if (isspace(str[str_i])) {
+ matching = (pat_i != pat.size() && isspace(pat[pat_i]));
+ } else {
+ while (pat_i != pat.size() && isspace(pat[pat_i])) {
+ pat_i++;
+ }
+ matching = (pat_i != pat.size() && str[str_i] == pat[pat_i++]);
+ }
+ }
+ while (pat_i != pat.size() && isspace(pat[pat_i])) {
+ pat_i++;
+ }
+ return (matching && pat_i == pat.size());
+}
+
+TEST(CommandLineFlagsTest, UsageString) {
+ int some_int = 10;
+ int64 some_int64 = 21474836470; // max int32 is 2147483647
+ bool some_switch = false;
+ string some_name = "something";
+ const string tool_name = "some_tool_name";
+ string usage = Flags::Usage(tool_name + "<flags>",
+ {Flag("some_int", &some_int, "some int"),
+ Flag("some_int64", &some_int64, "some int64"),
+ Flag("some_switch", &some_switch, "some switch"),
+ Flag("some_name", &some_name, "some name")});
+ // Match the usage message, being sloppy about whitespace.
+ const char *expected_usage =
+ " usage: some_tool_name <flags>\n"
+ "Flags:\n"
+ "--some_int=10 int32 some int\n"
+ "--some_int64=21474836470 int64 some int64\n"
+ "--some_switch=false bool some switch\n"
+ "--some_name=\"something\" string some name\n";
+ ASSERT_EQ(MatchWithAnyWhitespace(usage, expected_usage), true);
+
+ // Again but with no flags.
+ usage = Flags::Usage(tool_name, {});
+ ASSERT_EQ(MatchWithAnyWhitespace(usage, " usage: some_tool_name\n"), true);
+}
} // namespace tensorflow