diff options
author | Pete Warden <petewarden@google.com> | 2017-04-11 14:53:41 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-04-11 16:08:23 -0700 |
commit | 7c9d2a458ee6cb925a0b3d23793d0e356a6eac12 (patch) | |
tree | cd00f2a857be90824707333b666c0dc6621c5d9a /tensorflow/core/util | |
parent | b6d47b5e56b19394c2fa55e55b36e2ef77fbc69e (diff) |
Add AudioSpectrogram op to TensorFlow for audio feature generation
Change: 152872386
Diffstat (limited to 'tensorflow/core/util')
-rw-r--r-- | tensorflow/core/util/command_line_flags.cc | 28 | ||||
-rw-r--r-- | tensorflow/core/util/command_line_flags.h | 4 | ||||
-rw-r--r-- | tensorflow/core/util/command_line_flags_test.cc | 35 |
3 files changed, 60 insertions, 7 deletions
diff --git a/tensorflow/core/util/command_line_flags.cc b/tensorflow/core/util/command_line_flags.cc index 03eb076f30..8373eb1f9e 100644 --- a/tensorflow/core/util/command_line_flags.cc +++ b/tensorflow/core/util/command_line_flags.cc @@ -93,6 +93,22 @@ bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, return false; } +bool ParseFloatFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, + float* dst, bool* value_parsing_ok) { + *value_parsing_ok = true; + if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) { + char extra; + if (sscanf(arg.data(), "%f%c", dst, &extra) != 1) { + LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag + << "."; + *value_parsing_ok = false; + } + return true; + } + + return false; +} + } // namespace Flag::Flag(const char* name, tensorflow::int32* dst, const string& usage_text) @@ -116,6 +132,12 @@ Flag::Flag(const char* name, string* dst, const string& usage_text) string_value_(dst), usage_text_(usage_text) {} +Flag::Flag(const char* name, float* dst, const string& usage_text) + : name_(name), + type_(TYPE_FLOAT), + float_value_(dst), + usage_text_(usage_text) {} + bool Flag::Parse(string arg, bool* value_parsing_ok) const { bool result = false; if (type_ == TYPE_INT) { @@ -126,6 +148,8 @@ bool Flag::Parse(string arg, bool* value_parsing_ok) const { result = ParseBoolFlag(arg, name_, bool_value_, value_parsing_ok); } else if (type_ == TYPE_STRING) { result = ParseStringFlag(arg, name_, string_value_, value_parsing_ok); + } else if (type_ == TYPE_FLOAT) { + result = ParseFloatFlag(arg, name_, float_value_, value_parsing_ok); } return result; } @@ -195,6 +219,10 @@ bool Flag::Parse(string arg, bool* value_parsing_ok) const { type_name = "string"; flag_string = strings::Printf("--%s=\"%s\"", flag.name_.c_str(), flag.string_value_->c_str()); + } else if (flag.type_ == Flag::TYPE_FLOAT) { + type_name = "float"; + flag_string = + strings::Printf("--%s=%f", flag.name_.c_str(), *flag.float_value_); } strings::Appendf(&usage_text, "\t%-33s\t%s\t%s\n", flag_string.c_str(), type_name, flag.usage_text_.c_str()); diff --git a/tensorflow/core/util/command_line_flags.h b/tensorflow/core/util/command_line_flags.h index 2c77d7874f..f349df16fd 100644 --- a/tensorflow/core/util/command_line_flags.h +++ b/tensorflow/core/util/command_line_flags.h @@ -65,6 +65,7 @@ class Flag { Flag(const char* name, int64* dst1, const string& usage_text); Flag(const char* name, bool* dst, const string& usage_text); Flag(const char* name, string* dst, const string& usage_text); + Flag(const char* name, float* dst, const string& usage_text); private: friend class Flags; @@ -72,11 +73,12 @@ class Flag { bool Parse(string arg, bool* value_parsing_ok) const; string name_; - enum { TYPE_INT, TYPE_INT64, TYPE_BOOL, TYPE_STRING } type_; + enum { TYPE_INT, TYPE_INT64, TYPE_BOOL, TYPE_STRING, TYPE_FLOAT } type_; int* int_value_; int64* int64_value_; bool* bool_value_; string* string_value_; + float* float_value_; string usage_text_; }; diff --git a/tensorflow/core/util/command_line_flags_test.cc b/tensorflow/core/util/command_line_flags_test.cc index b002e35899..62025463af 100644 --- a/tensorflow/core/util/command_line_flags_test.cc +++ b/tensorflow/core/util/command_line_flags_test.cc @@ -32,29 +32,35 @@ std::vector<char *> CharPointerVectorFromStrings( } return result; } -} +} // namespace TEST(CommandLineFlagsTest, BasicUsage) { int some_int = 10; int64 some_int64 = 21474836470; // max int32 is 2147483647 bool some_switch = false; string some_name = "something"; - int argc = 5; - std::vector<string> argv_strings = { - "program_name", "--some_int=20", "--some_int64=214748364700", - "--some_switch", "--some_name=somethingelse"}; + float some_float = -23.23f; + int argc = 6; + std::vector<string> argv_strings = {"program_name", + "--some_int=20", + "--some_int64=214748364700", + "--some_switch", + "--some_name=somethingelse", + "--some_float=42.0"}; std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings); bool parsed_ok = Flags::Parse(&argc, argv_array.data(), {Flag("some_int", &some_int, "some int"), Flag("some_int64", &some_int64, "some int64"), Flag("some_switch", &some_switch, "some switch"), - Flag("some_name", &some_name, "some name")}); + Flag("some_name", &some_name, "some name"), + Flag("some_float", &some_float, "some float")}); EXPECT_EQ(true, parsed_ok); EXPECT_EQ(20, some_int); EXPECT_EQ(214748364700, some_int64); EXPECT_EQ(true, some_switch); EXPECT_EQ("somethingelse", some_name); + EXPECT_NEAR(42.0f, some_float, 1e-5f); EXPECT_EQ(argc, 1); } @@ -85,6 +91,21 @@ TEST(CommandLineFlagsTest, BadBoolValue) { EXPECT_EQ(argc, 1); } +TEST(CommandLineFlagsTest, BadFloatValue) { + float some_float = -23.23f; + int argc = 2; + std::vector<string> argv_strings = {"program_name", + "--some_float=notanumber"}; + std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings); + bool parsed_ok = + Flags::Parse(&argc, argv_array.data(), + {Flag("some_float", &some_float, "some float")}); + + EXPECT_EQ(false, parsed_ok); + EXPECT_NEAR(-23.23f, some_float, 1e-5f); + EXPECT_EQ(argc, 1); +} + // Return whether str==pat, but allowing any whitespace in pat // to match zero or more whitespace characters in str. static bool MatchWithAnyWhitespace(const string &str, const string &pat) { @@ -111,6 +132,8 @@ TEST(CommandLineFlagsTest, UsageString) { int64 some_int64 = 21474836470; // max int32 is 2147483647 bool some_switch = false; string some_name = "something"; + // Don't test float in this case, because precision is hard to predict and + // match against, and we don't want a flakey test. const string tool_name = "some_tool_name"; string usage = Flags::Usage(tool_name + "<flags>", {Flag("some_int", &some_int, "some int"), |