diff options
author | 2018-04-20 11:23:15 -0700 | |
---|---|---|
committer | 2018-04-20 11:25:53 -0700 | |
commit | 570d90b9c7e6a19bc2606fdaf7ad0f85b8590c0e (patch) | |
tree | ae3617cdb70686297ddf6bab05d99fd44bb64224 /tensorflow/core/lib/strings | |
parent | 49f3469d9533cb12d06ed3907b4ced975e2fcea4 (diff) |
Speed up safe_strtod and safe_strtof functions by using double-conversion library
Closes #12102.
PiperOrigin-RevId: 193696537
Diffstat (limited to 'tensorflow/core/lib/strings')
-rw-r--r-- | tensorflow/core/lib/strings/numbers.cc | 51 | ||||
-rw-r--r-- | tensorflow/core/lib/strings/numbers.h | 2 | ||||
-rw-r--r-- | tensorflow/core/lib/strings/numbers_test.cc | 87 | ||||
-rw-r--r-- | tensorflow/core/lib/strings/str_util.cc | 8 | ||||
-rw-r--r-- | tensorflow/core/lib/strings/str_util.h | 5 | ||||
-rw-r--r-- | tensorflow/core/lib/strings/str_util_test.cc | 56 |
6 files changed, 143 insertions, 66 deletions
diff --git a/tensorflow/core/lib/strings/numbers.cc b/tensorflow/core/lib/strings/numbers.cc index c296daa95d..e4b909296e 100644 --- a/tensorflow/core/lib/strings/numbers.cc +++ b/tensorflow/core/lib/strings/numbers.cc @@ -23,6 +23,8 @@ limitations under the License. #include <locale> #include <unordered_map> +#include "double-conversion/double-conversion.h" + #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" @@ -110,6 +112,17 @@ T locale_independent_strtonum(const char* str, const char** endptr) { return result; } +static inline const double_conversion::StringToDoubleConverter& +StringToFloatConverter() { + static const double_conversion::StringToDoubleConverter converter( + double_conversion::StringToDoubleConverter::ALLOW_LEADING_SPACES | + double_conversion::StringToDoubleConverter::ALLOW_HEX | + double_conversion::StringToDoubleConverter::ALLOW_TRAILING_SPACES | + double_conversion::StringToDoubleConverter::ALLOW_CASE_INSENSIBILITY, + 0., 0., "inf", "nan"); + return converter; +} + } // namespace namespace strings { @@ -319,25 +332,31 @@ bool safe_strtou32(StringPiece str, uint32* value) { } bool safe_strtof(const char* str, float* value) { - const char* endptr; - *value = locale_independent_strtonum<float>(str, &endptr); - while (isspace(*endptr)) ++endptr; - // Ignore range errors from strtod/strtof. - // The values it returns on underflow and - // overflow are the right fallback in a - // robust setting. - return *str != '\0' && *endptr == '\0'; + int processed_characters_count = -1; + auto len = str_util::Strnlen(str, kFastToBufferSize); + + // If there is no zero-termination in str, fail. + if (len == kFastToBufferSize) return false; + // If string length exceeds int max, fail. + if (len > std::numeric_limits<int>::max()) return false; + + *value = StringToFloatConverter().StringToFloat(str, static_cast<int>(len), + &processed_characters_count); + return processed_characters_count > 0; } bool safe_strtod(const char* str, double* value) { - const char* endptr; - *value = locale_independent_strtonum<double>(str, &endptr); - while (isspace(*endptr)) ++endptr; - // Ignore range errors from strtod/strtof. - // The values it returns on underflow and - // overflow are the right fallback in a - // robust setting. - return *str != '\0' && *endptr == '\0'; + int processed_characters_count = -1; + auto len = str_util::Strnlen(str, kFastToBufferSize); + + // If there is no zero-termination in str, fail. + if (len == kFastToBufferSize) return false; + // If string length exceeds int max, fail. + if (len > std::numeric_limits<int>::max()) return false; + + *value = StringToFloatConverter().StringToDouble(str, static_cast<int>(len), + &processed_characters_count); + return processed_characters_count > 0; } size_t FloatToBuffer(float value, char* buffer) { diff --git a/tensorflow/core/lib/strings/numbers.h b/tensorflow/core/lib/strings/numbers.h index 6b7703be37..e9add42849 100644 --- a/tensorflow/core/lib/strings/numbers.h +++ b/tensorflow/core/lib/strings/numbers.h @@ -114,11 +114,13 @@ bool safe_strtou64(StringPiece str, uint64* value); // Convert strings to floating point values. // Leading and trailing spaces are allowed. // Values may be rounded on over- and underflow. +// Returns false on invalid input or if `strlen(value) >= kFastToBufferSize`. bool safe_strtof(const char* str, float* value); // Convert strings to double precision floating point values. // Leading and trailing spaces are allowed. // Values may be rounded on over- and underflow. +// Returns false on invalid input or if `strlen(value) >= kFastToBufferSize`. bool safe_strtod(const char* str, double* value); inline bool ProtoParseNumeric(StringPiece s, int32* value) { diff --git a/tensorflow/core/lib/strings/numbers_test.cc b/tensorflow/core/lib/strings/numbers_test.cc index e15161de66..0f22dac262 100644 --- a/tensorflow/core/lib/strings/numbers_test.cc +++ b/tensorflow/core/lib/strings/numbers_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/numbers.h" +#include <cmath> #include <string> #include "tensorflow/core/platform/test.h" @@ -277,7 +278,49 @@ TEST(safe_strtof, Float) { EXPECT_TRUE(safe_strtof("-0x2A", &result)); EXPECT_EQ(-42.0f, result); + EXPECT_TRUE(safe_strtof(" -0x2", &result)); + EXPECT_EQ(-2.0f, result); + + EXPECT_TRUE(safe_strtof("8 \t", &result)); + EXPECT_EQ(8.0f, result); + + EXPECT_TRUE(safe_strtof("\t20.0\t ", &result)); + EXPECT_EQ(20.0f, result); + EXPECT_FALSE(safe_strtof("-infinity is awesome", &result)); + + // Make sure we exit cleanly if the string is not terminated + char test_str[2 * kFastToBufferSize]; + for (int i = 0; i < 2 * kFastToBufferSize; ++i) test_str[i] = 'a'; + EXPECT_FALSE(safe_strtof(test_str, &result)); + + // Make sure we exit cleanly if the string is too long + test_str[kFastToBufferSize + 1] = '\0'; + EXPECT_FALSE(safe_strtof(test_str, &result)); + + EXPECT_TRUE(safe_strtof("-inf", &result)); + EXPECT_EQ(-std::numeric_limits<float>::infinity(), result); + + EXPECT_TRUE(safe_strtof("+inf", &result)); + EXPECT_EQ(std::numeric_limits<float>::infinity(), result); + + EXPECT_TRUE(safe_strtof("InF", &result)); + EXPECT_EQ(std::numeric_limits<float>::infinity(), result); + + EXPECT_TRUE(safe_strtof("-INF", &result)); + EXPECT_EQ(-std::numeric_limits<float>::infinity(), result); + + EXPECT_TRUE(safe_strtof("nan", &result)); + EXPECT_TRUE(std::isnan(result)); + + EXPECT_TRUE(safe_strtof("-nan", &result)); + EXPECT_TRUE(std::isnan(result)); + + EXPECT_TRUE(safe_strtof("-NaN", &result)); + EXPECT_TRUE(std::isnan(result)); + + EXPECT_TRUE(safe_strtof("+NAN", &result)); + EXPECT_TRUE(std::isnan(result)); } TEST(safe_strtod, Double) { @@ -287,6 +330,15 @@ TEST(safe_strtod, Double) { EXPECT_EQ(0.1234567890123, result); EXPECT_FALSE(safe_strtod("0.1234567890123abc", &result)); + // Make sure we exit cleanly if the string is not terminated + char test_str[2 * kFastToBufferSize]; + for (int i = 0; i < 2 * kFastToBufferSize; ++i) test_str[i] = 'a'; + EXPECT_FALSE(safe_strtod(test_str, &result)); + + // Make sure we exit cleanly if the string is too long + test_str[kFastToBufferSize + 1] = '\0'; + EXPECT_FALSE(safe_strtod(test_str, &result)); + // Overflow to infinity, underflow to 0. EXPECT_TRUE(safe_strtod("1e310", &result)); EXPECT_EQ(std::numeric_limits<double>::infinity(), result); @@ -296,6 +348,41 @@ TEST(safe_strtod, Double) { EXPECT_TRUE(safe_strtod("1e-325", &result)); EXPECT_EQ(0, result); + + EXPECT_TRUE(safe_strtod(" -0x1c", &result)); + EXPECT_EQ(-28.0, result); + + EXPECT_TRUE(safe_strtod("50 \t", &result)); + EXPECT_EQ(50.0, result); + + EXPECT_TRUE(safe_strtod("\t82.0\t ", &result)); + EXPECT_EQ(82.0, result); + + EXPECT_FALSE(safe_strtod("infinity", &result)); + + EXPECT_TRUE(safe_strtod("-inf", &result)); + EXPECT_EQ(-std::numeric_limits<double>::infinity(), result); + + EXPECT_TRUE(safe_strtod("+inf", &result)); + EXPECT_EQ(std::numeric_limits<double>::infinity(), result); + + EXPECT_TRUE(safe_strtod("InF", &result)); + EXPECT_EQ(std::numeric_limits<double>::infinity(), result); + + EXPECT_TRUE(safe_strtod("-INF", &result)); + EXPECT_EQ(-std::numeric_limits<double>::infinity(), result); + + EXPECT_TRUE(safe_strtod("nan", &result)); + EXPECT_TRUE(std::isnan(result)); + + EXPECT_TRUE(safe_strtod("-nan", &result)); + EXPECT_TRUE(std::isnan(result)); + + EXPECT_TRUE(safe_strtod("-NaN", &result)); + EXPECT_TRUE(std::isnan(result)); + + EXPECT_TRUE(safe_strtod("+NAN", &result)); + EXPECT_TRUE(std::isnan(result)); } } // namespace strings diff --git a/tensorflow/core/lib/strings/str_util.cc b/tensorflow/core/lib/strings/str_util.cc index 2c9e98357a..4598b8ccc7 100644 --- a/tensorflow/core/lib/strings/str_util.cc +++ b/tensorflow/core/lib/strings/str_util.cc @@ -454,6 +454,14 @@ bool SplitAndParseAsFloats(StringPiece text, char delim, result); } +size_t Strnlen(const char* str, const size_t string_max_len) { + size_t len = 0; + while (len < string_max_len && str[len] != '\0') { + ++len; + } + return len; +} + bool StrContains(StringPiece haystack, StringPiece needle) { return std::search(haystack.begin(), haystack.end(), needle.begin(), needle.end()) != haystack.end(); diff --git a/tensorflow/core/lib/strings/str_util.h b/tensorflow/core/lib/strings/str_util.h index 065871c1b4..e97d00b975 100644 --- a/tensorflow/core/lib/strings/str_util.h +++ b/tensorflow/core/lib/strings/str_util.h @@ -223,6 +223,11 @@ std::vector<string> Split(StringPiece text, char delims, Predicate p) { return Split(text, StringPiece(&delims, 1), p); } +// Returns the length of the given null-terminated byte string 'str'. +// Returns 'string_max_len' if the null character was not found in the first +// 'string_max_len' bytes of 'str'. +size_t Strnlen(const char* str, const size_t string_max_len); + } // namespace str_util } // namespace tensorflow diff --git a/tensorflow/core/lib/strings/str_util_test.cc b/tensorflow/core/lib/strings/str_util_test.cc index 63643c3e8e..3bf3e99825 100644 --- a/tensorflow/core/lib/strings/str_util_test.cc +++ b/tensorflow/core/lib/strings/str_util_test.cc @@ -430,56 +430,12 @@ TEST(StringReplace, EmptyStringReplaceAll) { EXPECT_EQ("", str_util::StringReplace("", "a", "X", /*replace_all=*/true)); } -TEST(StartsWith, Basic) { - const string s1( - "123" - "\0" - "456", - 7); - const StringPiece a("foobar"); - const StringPiece b(s1); - const StringPiece e; - EXPECT_TRUE(str_util::StartsWith(a, a)); - EXPECT_TRUE(str_util::StartsWith(a, "foo")); - EXPECT_TRUE(str_util::StartsWith(a, e)); - EXPECT_TRUE(str_util::StartsWith(b, s1)); - EXPECT_TRUE(str_util::StartsWith(b, b)); - EXPECT_TRUE(str_util::StartsWith(b, e)); - EXPECT_TRUE(str_util::StartsWith(e, "")); - EXPECT_FALSE(str_util::StartsWith(a, b)); - EXPECT_FALSE(str_util::StartsWith(b, a)); - EXPECT_FALSE(str_util::StartsWith(e, a)); -} - -TEST(EndsWith, Basic) { - const string s1( - "123" - "\0" - "456", - 7); - const StringPiece a("foobar"); - const StringPiece b(s1); - const StringPiece e; - EXPECT_TRUE(str_util::EndsWith(a, a)); - EXPECT_TRUE(str_util::EndsWith(a, "bar")); - EXPECT_TRUE(str_util::EndsWith(a, e)); - EXPECT_TRUE(str_util::EndsWith(b, s1)); - EXPECT_TRUE(str_util::EndsWith(b, b)); - EXPECT_TRUE(str_util::EndsWith(b, e)); - EXPECT_TRUE(str_util::EndsWith(e, "")); - EXPECT_FALSE(str_util::EndsWith(a, b)); - EXPECT_FALSE(str_util::EndsWith(b, a)); - EXPECT_FALSE(str_util::EndsWith(e, a)); -} - -TEST(StrContains, Basic) { - StringPiece a("abcdefg"); - StringPiece b("abcd"); - StringPiece c("efg"); - StringPiece d("gh"); - EXPECT_TRUE(str_util::StrContains(a, b)); - EXPECT_TRUE(str_util::StrContains(a, c)); - EXPECT_TRUE(!str_util::StrContains(a, d)); +TEST(Strnlen, Basic) { + EXPECT_EQ(0, str_util::Strnlen("ab", 0)); + EXPECT_EQ(1, str_util::Strnlen("a", 1)); + EXPECT_EQ(2, str_util::Strnlen("abcd", 2)); + EXPECT_EQ(3, str_util::Strnlen("abc", 10)); + EXPECT_EQ(4, str_util::Strnlen("a \t\n", 10)); } } // namespace tensorflow |