aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util
diff options
context:
space:
mode:
authorGravatar Pete Warden <petewarden@google.com>2017-04-11 14:53:41 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-11 16:08:23 -0700
commit7c9d2a458ee6cb925a0b3d23793d0e356a6eac12 (patch)
treecd00f2a857be90824707333b666c0dc6621c5d9a /tensorflow/core/util
parentb6d47b5e56b19394c2fa55e55b36e2ef77fbc69e (diff)
Add AudioSpectrogram op to TensorFlow for audio feature generation
Change: 152872386
Diffstat (limited to 'tensorflow/core/util')
-rw-r--r--tensorflow/core/util/command_line_flags.cc28
-rw-r--r--tensorflow/core/util/command_line_flags.h4
-rw-r--r--tensorflow/core/util/command_line_flags_test.cc35
3 files changed, 60 insertions, 7 deletions
diff --git a/tensorflow/core/util/command_line_flags.cc b/tensorflow/core/util/command_line_flags.cc
index 03eb076f30..8373eb1f9e 100644
--- a/tensorflow/core/util/command_line_flags.cc
+++ b/tensorflow/core/util/command_line_flags.cc
@@ -93,6 +93,22 @@ bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
return false;
}
+bool ParseFloatFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
+ float* dst, bool* value_parsing_ok) {
+ *value_parsing_ok = true;
+ if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) {
+ char extra;
+ if (sscanf(arg.data(), "%f%c", dst, &extra) != 1) {
+ LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
+ << ".";
+ *value_parsing_ok = false;
+ }
+ return true;
+ }
+
+ return false;
+}
+
} // namespace
Flag::Flag(const char* name, tensorflow::int32* dst, const string& usage_text)
@@ -116,6 +132,12 @@ Flag::Flag(const char* name, string* dst, const string& usage_text)
string_value_(dst),
usage_text_(usage_text) {}
+Flag::Flag(const char* name, float* dst, const string& usage_text)
+ : name_(name),
+ type_(TYPE_FLOAT),
+ float_value_(dst),
+ usage_text_(usage_text) {}
+
bool Flag::Parse(string arg, bool* value_parsing_ok) const {
bool result = false;
if (type_ == TYPE_INT) {
@@ -126,6 +148,8 @@ bool Flag::Parse(string arg, bool* value_parsing_ok) const {
result = ParseBoolFlag(arg, name_, bool_value_, value_parsing_ok);
} else if (type_ == TYPE_STRING) {
result = ParseStringFlag(arg, name_, string_value_, value_parsing_ok);
+ } else if (type_ == TYPE_FLOAT) {
+ result = ParseFloatFlag(arg, name_, float_value_, value_parsing_ok);
}
return result;
}
@@ -195,6 +219,10 @@ bool Flag::Parse(string arg, bool* value_parsing_ok) const {
type_name = "string";
flag_string = strings::Printf("--%s=\"%s\"", flag.name_.c_str(),
flag.string_value_->c_str());
+ } else if (flag.type_ == Flag::TYPE_FLOAT) {
+ type_name = "float";
+ flag_string =
+ strings::Printf("--%s=%f", flag.name_.c_str(), *flag.float_value_);
}
strings::Appendf(&usage_text, "\t%-33s\t%s\t%s\n", flag_string.c_str(),
type_name, flag.usage_text_.c_str());
diff --git a/tensorflow/core/util/command_line_flags.h b/tensorflow/core/util/command_line_flags.h
index 2c77d7874f..f349df16fd 100644
--- a/tensorflow/core/util/command_line_flags.h
+++ b/tensorflow/core/util/command_line_flags.h
@@ -65,6 +65,7 @@ class Flag {
Flag(const char* name, int64* dst1, const string& usage_text);
Flag(const char* name, bool* dst, const string& usage_text);
Flag(const char* name, string* dst, const string& usage_text);
+ Flag(const char* name, float* dst, const string& usage_text);
private:
friend class Flags;
@@ -72,11 +73,12 @@ class Flag {
bool Parse(string arg, bool* value_parsing_ok) const;
string name_;
- enum { TYPE_INT, TYPE_INT64, TYPE_BOOL, TYPE_STRING } type_;
+ enum { TYPE_INT, TYPE_INT64, TYPE_BOOL, TYPE_STRING, TYPE_FLOAT } type_;
int* int_value_;
int64* int64_value_;
bool* bool_value_;
string* string_value_;
+ float* float_value_;
string usage_text_;
};
diff --git a/tensorflow/core/util/command_line_flags_test.cc b/tensorflow/core/util/command_line_flags_test.cc
index b002e35899..62025463af 100644
--- a/tensorflow/core/util/command_line_flags_test.cc
+++ b/tensorflow/core/util/command_line_flags_test.cc
@@ -32,29 +32,35 @@ std::vector<char *> CharPointerVectorFromStrings(
}
return result;
}
-}
+} // namespace
TEST(CommandLineFlagsTest, BasicUsage) {
int some_int = 10;
int64 some_int64 = 21474836470; // max int32 is 2147483647
bool some_switch = false;
string some_name = "something";
- int argc = 5;
- std::vector<string> argv_strings = {
- "program_name", "--some_int=20", "--some_int64=214748364700",
- "--some_switch", "--some_name=somethingelse"};
+ float some_float = -23.23f;
+ int argc = 6;
+ std::vector<string> argv_strings = {"program_name",
+ "--some_int=20",
+ "--some_int64=214748364700",
+ "--some_switch",
+ "--some_name=somethingelse",
+ "--some_float=42.0"};
std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings);
bool parsed_ok =
Flags::Parse(&argc, argv_array.data(),
{Flag("some_int", &some_int, "some int"),
Flag("some_int64", &some_int64, "some int64"),
Flag("some_switch", &some_switch, "some switch"),
- Flag("some_name", &some_name, "some name")});
+ Flag("some_name", &some_name, "some name"),
+ Flag("some_float", &some_float, "some float")});
EXPECT_EQ(true, parsed_ok);
EXPECT_EQ(20, some_int);
EXPECT_EQ(214748364700, some_int64);
EXPECT_EQ(true, some_switch);
EXPECT_EQ("somethingelse", some_name);
+ EXPECT_NEAR(42.0f, some_float, 1e-5f);
EXPECT_EQ(argc, 1);
}
@@ -85,6 +91,21 @@ TEST(CommandLineFlagsTest, BadBoolValue) {
EXPECT_EQ(argc, 1);
}
+TEST(CommandLineFlagsTest, BadFloatValue) {
+ float some_float = -23.23f;
+ int argc = 2;
+ std::vector<string> argv_strings = {"program_name",
+ "--some_float=notanumber"};
+ std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings);
+ bool parsed_ok =
+ Flags::Parse(&argc, argv_array.data(),
+ {Flag("some_float", &some_float, "some float")});
+
+ EXPECT_EQ(false, parsed_ok);
+ EXPECT_NEAR(-23.23f, some_float, 1e-5f);
+ EXPECT_EQ(argc, 1);
+}
+
// Return whether str==pat, but allowing any whitespace in pat
// to match zero or more whitespace characters in str.
static bool MatchWithAnyWhitespace(const string &str, const string &pat) {
@@ -111,6 +132,8 @@ TEST(CommandLineFlagsTest, UsageString) {
int64 some_int64 = 21474836470; // max int32 is 2147483647
bool some_switch = false;
string some_name = "something";
+ // Don't test float in this case, because precision is hard to predict and
+ // match against, and we don't want a flakey test.
const string tool_name = "some_tool_name";
string usage = Flags::Usage(tool_name + "<flags>",
{Flag("some_int", &some_int, "some int"),