diff options
author | Dan Smilkov <dsmilkov@gmail.com> | 2016-02-16 11:30:12 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-02-16 11:42:48 -0800 |
commit | 0907b351103a2fed8c9b74a5ddaaf47870d9945e (patch) | |
tree | 6c52705016e5f1ecaffb6ffdd532e775659b9f10 /tensorflow/core/util | |
parent | f66d5034121940533f7f29c155505017da9cb7f4 (diff) |
Update versions of bower components to reflect those inside Google. This also fixes the problem where users are asked which version of polymer to install when they run `bower install`.
Change: 114774859
Diffstat (limited to 'tensorflow/core/util')
-rw-r--r-- | tensorflow/core/util/command_line_flags.cc | 136 | ||||
-rw-r--r-- | tensorflow/core/util/command_line_flags.h | 69 | ||||
-rw-r--r-- | tensorflow/core/util/command_line_flags_test.cc | 82 |
3 files changed, 287 insertions, 0 deletions
diff --git a/tensorflow/core/util/command_line_flags.cc b/tensorflow/core/util/command_line_flags.cc new file mode 100644 index 0000000000..d55fd568dd --- /dev/null +++ b/tensorflow/core/util/command_line_flags.cc @@ -0,0 +1,136 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/util/command_line_flags.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace { + +bool ParseStringFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, + string* dst, bool* value_parsing_ok) { + *value_parsing_ok = true; + if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) { + *dst = arg.ToString(); + return true; + } + + return false; +} + +bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, + tensorflow::int32* dst, bool* value_parsing_ok) { + *value_parsing_ok = true; + if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) { + char extra; + if (sscanf(arg.data(), "%d%c", dst, &extra) != 1) { + LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag + << "."; + *value_parsing_ok = false; + } + return true; + } + + return false; +} + +bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, + bool* dst, bool* value_parsing_ok) { + *value_parsing_ok = true; + if (arg.Consume("--") && arg.Consume(flag)) { + if (arg.empty()) { + *dst = true; + return true; + } + + if (arg == "=true") { + *dst = true; + return true; + } else if (arg == "=false") { + *dst = false; + return true; + } else { + LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag + << "."; + *value_parsing_ok = false; + return true; + } + } + + return false; +} + +} // namespace + +Flag::Flag(const char* name, tensorflow::int32* dst) + : name_(name), type_(TYPE_INT), int_value_(dst) {} + +Flag::Flag(const char* name, bool* dst) + : name_(name), type_(TYPE_BOOL), bool_value_(dst) {} + +Flag::Flag(const char* name, string* dst) + : name_(name), type_(TYPE_STRING), string_value_(dst) {} + +bool Flag::Parse(string arg, bool* value_parsing_ok) const { + bool result = false; + if (type_ == TYPE_INT) { + result = ParseInt32Flag(arg, name_, int_value_, value_parsing_ok); + } else if (type_ == TYPE_BOOL) { + result = ParseBoolFlag(arg, name_, bool_value_, value_parsing_ok); + } else if (type_ == TYPE_STRING) { + result = ParseStringFlag(arg, name_, string_value_, value_parsing_ok); + } + return result; +} + +bool ParseFlags(int* argc, char** argv, const std::vector<Flag>& flag_list) { + bool result = true; + std::vector<char*> unknown_flags; + for (int i = 1; i < *argc; ++i) { + if (string(argv[i]) == "--") { + while (i < *argc) { + unknown_flags.push_back(argv[i]); + ++i; + } + break; + } + + bool was_found = false; + for (const Flag& flag : flag_list) { + bool value_parsing_ok; + was_found = flag.Parse(argv[i], &value_parsing_ok); + if (!value_parsing_ok) { + result = false; + } + if (was_found) { + break; + } + } + if (!was_found) { + unknown_flags.push_back(argv[i]); + } + } + // Passthrough any extra flags. + int dst = 1; // Skip argv[0] + for (char* f : unknown_flags) { + argv[dst++] = f; + } + argv[dst++] = nullptr; + *argc = unknown_flags.size() + 1; + return result; +} + +} // namespace tensorflow diff --git a/tensorflow/core/util/command_line_flags.h b/tensorflow/core/util/command_line_flags.h new file mode 100644 index 0000000000..124756cdde --- /dev/null +++ b/tensorflow/core/util/command_line_flags.h @@ -0,0 +1,69 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_UTIL_COMMAND_LINE_FLAGS_H +#define THIRD_PARTY_TENSORFLOW_CORE_UTIL_COMMAND_LINE_FLAGS_H + +#include <vector> +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// N.B. This library is for INTERNAL use only. +// +// This is a simple command-line argument parsing module to help us handle +// parameters for C++ binaries. The recommended way of using it is with local +// variables and an initializer list of Flag objects, for example: +// +// int some_int = 10; +// bool some_switch = false; +// string some_name = "something"; +// bool parsed_values_ok = ParseFlags(&argc, argv, { +// Flag("some_int", &some_int), +// Flag("some_switch", &some_switch), +// Flag("some_name", &some_name)}); +// +// The argc and argv values are adjusted by the Parse function so all that +// remains is the program name (at argv[0]) and any unknown arguments fill the +// rest of the array. This means you can check for flags that weren't understood +// by seeing if argv is greater than 1. +// The result indicates if there were any errors parsing the values that were +// passed to the command-line switches. For example, --some_int=foo would return +// false because the argument is expected to be an integer. +// +// NOTE: Unlike gflags-style libraries, this library is intended to be +// used in the `main()` function of your binary. It does not handle +// flag definitions that are scattered around the source code. +class Flag { + public: + Flag(const char* name, int32* dst1); + Flag(const char* name, bool* dst); + Flag(const char* name, string* dst); + + bool Parse(string arg, bool* value_parsing_ok) const; + + private: + string name_; + enum { TYPE_INT, TYPE_BOOL, TYPE_STRING } type_; + int* int_value_; + bool* bool_value_; + string* string_value_; +}; + +bool ParseFlags(int* argc, char** argv, const std::vector<Flag>& flag_list); + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_UTIL_COMMAND_LINE_FLAGS_H diff --git a/tensorflow/core/util/command_line_flags_test.cc b/tensorflow/core/util/command_line_flags_test.cc new file mode 100644 index 0000000000..00dde42a2f --- /dev/null +++ b/tensorflow/core/util/command_line_flags_test.cc @@ -0,0 +1,82 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/util/command_line_flags.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { +// The returned array is only valid for the lifetime of the input vector. +// We're using const casting because we need to pass in an argv-style array of +// char* pointers for the API, even though we know they won't be altered. +std::vector<char*> CharPointerVectorFromStrings( + const std::vector<tensorflow::string>& strings) { + std::vector<char*> result; + for (const tensorflow::string& string : strings) { + result.push_back(const_cast<char*>(string.c_str())); + } + return result; +} +} + +TEST(CommandLineFlagsTest, BasicUsage) { + int some_int = 10; + bool some_switch = false; + tensorflow::string some_name = "something"; + int argc = 4; + std::vector<tensorflow::string> argv_strings = { + "program_name", "--some_int=20", "--some_switch", + "--some_name=somethingelse"}; + std::vector<char*> argv_array = CharPointerVectorFromStrings(argv_strings); + bool parsed_ok = + ParseFlags(&argc, argv_array.data(), {Flag("some_int", &some_int), + Flag("some_switch", &some_switch), + Flag("some_name", &some_name)}); + EXPECT_EQ(true, parsed_ok); + EXPECT_EQ(20, some_int); + EXPECT_EQ(true, some_switch); + EXPECT_EQ("somethingelse", some_name); + EXPECT_EQ(argc, 1); +} + +TEST(CommandLineFlagsTest, BadIntValue) { + int some_int = 10; + int argc = 2; + std::vector<tensorflow::string> argv_strings = {"program_name", + "--some_int=notanumber"}; + std::vector<char*> argv_array = CharPointerVectorFromStrings(argv_strings); + bool parsed_ok = + ParseFlags(&argc, argv_array.data(), {Flag("some_int", &some_int)}); + + EXPECT_EQ(false, parsed_ok); + EXPECT_EQ(10, some_int); + EXPECT_EQ(argc, 1); +} + +TEST(CommandLineFlagsTest, BadBoolValue) { + bool some_switch = false; + int argc = 2; + std::vector<tensorflow::string> argv_strings = {"program_name", + "--some_switch=notabool"}; + std::vector<char*> argv_array = CharPointerVectorFromStrings(argv_strings); + bool parsed_ok = + ParseFlags(&argc, argv_array.data(), {Flag("some_switch", &some_switch)}); + + EXPECT_EQ(false, parsed_ok); + EXPECT_EQ(false, some_switch); + EXPECT_EQ(argc, 1); +} + +} // namespace tensorflow |