aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/lib
diff options
context:
space:
mode:
authorGravatar akindyakov <akindyakov@gmail.com>2018-04-20 11:23:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-20 11:25:53 -0700
commit570d90b9c7e6a19bc2606fdaf7ad0f85b8590c0e (patch)
treeae3617cdb70686297ddf6bab05d99fd44bb64224 /tensorflow/core/lib
parent49f3469d9533cb12d06ed3907b4ced975e2fcea4 (diff)
Speed up safe_strtod and safe_strtof functions by using double-conversion library
Closes #12102. PiperOrigin-RevId: 193696537
Diffstat (limited to 'tensorflow/core/lib')
-rw-r--r--tensorflow/core/lib/strings/numbers.cc51
-rw-r--r--tensorflow/core/lib/strings/numbers.h2
-rw-r--r--tensorflow/core/lib/strings/numbers_test.cc87
-rw-r--r--tensorflow/core/lib/strings/str_util.cc8
-rw-r--r--tensorflow/core/lib/strings/str_util.h5
-rw-r--r--tensorflow/core/lib/strings/str_util_test.cc56
6 files changed, 143 insertions, 66 deletions
diff --git a/tensorflow/core/lib/strings/numbers.cc b/tensorflow/core/lib/strings/numbers.cc
index c296daa95d..e4b909296e 100644
--- a/tensorflow/core/lib/strings/numbers.cc
+++ b/tensorflow/core/lib/strings/numbers.cc
@@ -23,6 +23,8 @@ limitations under the License.
#include <locale>
#include <unordered_map>
+#include "double-conversion/double-conversion.h"
+
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
@@ -110,6 +112,17 @@ T locale_independent_strtonum(const char* str, const char** endptr) {
return result;
}
+static inline const double_conversion::StringToDoubleConverter&
+StringToFloatConverter() {
+ static const double_conversion::StringToDoubleConverter converter(
+ double_conversion::StringToDoubleConverter::ALLOW_LEADING_SPACES |
+ double_conversion::StringToDoubleConverter::ALLOW_HEX |
+ double_conversion::StringToDoubleConverter::ALLOW_TRAILING_SPACES |
+ double_conversion::StringToDoubleConverter::ALLOW_CASE_INSENSIBILITY,
+ 0., 0., "inf", "nan");
+ return converter;
+}
+
} // namespace
namespace strings {
@@ -319,25 +332,31 @@ bool safe_strtou32(StringPiece str, uint32* value) {
}
bool safe_strtof(const char* str, float* value) {
- const char* endptr;
- *value = locale_independent_strtonum<float>(str, &endptr);
- while (isspace(*endptr)) ++endptr;
- // Ignore range errors from strtod/strtof.
- // The values it returns on underflow and
- // overflow are the right fallback in a
- // robust setting.
- return *str != '\0' && *endptr == '\0';
+ int processed_characters_count = -1;
+ auto len = str_util::Strnlen(str, kFastToBufferSize);
+
+ // If there is no zero-termination in str, fail.
+ if (len == kFastToBufferSize) return false;
+ // If string length exceeds int max, fail.
+ if (len > std::numeric_limits<int>::max()) return false;
+
+ *value = StringToFloatConverter().StringToFloat(str, static_cast<int>(len),
+ &processed_characters_count);
+ return processed_characters_count > 0;
}
bool safe_strtod(const char* str, double* value) {
- const char* endptr;
- *value = locale_independent_strtonum<double>(str, &endptr);
- while (isspace(*endptr)) ++endptr;
- // Ignore range errors from strtod/strtof.
- // The values it returns on underflow and
- // overflow are the right fallback in a
- // robust setting.
- return *str != '\0' && *endptr == '\0';
+ int processed_characters_count = -1;
+ auto len = str_util::Strnlen(str, kFastToBufferSize);
+
+ // If there is no zero-termination in str, fail.
+ if (len == kFastToBufferSize) return false;
+ // If string length exceeds int max, fail.
+ if (len > std::numeric_limits<int>::max()) return false;
+
+ *value = StringToFloatConverter().StringToDouble(str, static_cast<int>(len),
+ &processed_characters_count);
+ return processed_characters_count > 0;
}
size_t FloatToBuffer(float value, char* buffer) {
diff --git a/tensorflow/core/lib/strings/numbers.h b/tensorflow/core/lib/strings/numbers.h
index 6b7703be37..e9add42849 100644
--- a/tensorflow/core/lib/strings/numbers.h
+++ b/tensorflow/core/lib/strings/numbers.h
@@ -114,11 +114,13 @@ bool safe_strtou64(StringPiece str, uint64* value);
// Convert strings to floating point values.
// Leading and trailing spaces are allowed.
// Values may be rounded on over- and underflow.
+// Returns false on invalid input or if `strlen(value) >= kFastToBufferSize`.
bool safe_strtof(const char* str, float* value);
// Convert strings to double precision floating point values.
// Leading and trailing spaces are allowed.
// Values may be rounded on over- and underflow.
+// Returns false on invalid input or if `strlen(value) >= kFastToBufferSize`.
bool safe_strtod(const char* str, double* value);
inline bool ProtoParseNumeric(StringPiece s, int32* value) {
diff --git a/tensorflow/core/lib/strings/numbers_test.cc b/tensorflow/core/lib/strings/numbers_test.cc
index e15161de66..0f22dac262 100644
--- a/tensorflow/core/lib/strings/numbers_test.cc
+++ b/tensorflow/core/lib/strings/numbers_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/numbers.h"
+#include <cmath>
#include <string>
#include "tensorflow/core/platform/test.h"
@@ -277,7 +278,49 @@ TEST(safe_strtof, Float) {
EXPECT_TRUE(safe_strtof("-0x2A", &result));
EXPECT_EQ(-42.0f, result);
+ EXPECT_TRUE(safe_strtof(" -0x2", &result));
+ EXPECT_EQ(-2.0f, result);
+
+ EXPECT_TRUE(safe_strtof("8 \t", &result));
+ EXPECT_EQ(8.0f, result);
+
+ EXPECT_TRUE(safe_strtof("\t20.0\t ", &result));
+ EXPECT_EQ(20.0f, result);
+
EXPECT_FALSE(safe_strtof("-infinity is awesome", &result));
+
+ // Make sure we exit cleanly if the string is not terminated
+ char test_str[2 * kFastToBufferSize];
+ for (int i = 0; i < 2 * kFastToBufferSize; ++i) test_str[i] = 'a';
+ EXPECT_FALSE(safe_strtof(test_str, &result));
+
+ // Make sure we exit cleanly if the string is too long
+ test_str[kFastToBufferSize + 1] = '\0';
+ EXPECT_FALSE(safe_strtof(test_str, &result));
+
+ EXPECT_TRUE(safe_strtof("-inf", &result));
+ EXPECT_EQ(-std::numeric_limits<float>::infinity(), result);
+
+ EXPECT_TRUE(safe_strtof("+inf", &result));
+ EXPECT_EQ(std::numeric_limits<float>::infinity(), result);
+
+ EXPECT_TRUE(safe_strtof("InF", &result));
+ EXPECT_EQ(std::numeric_limits<float>::infinity(), result);
+
+ EXPECT_TRUE(safe_strtof("-INF", &result));
+ EXPECT_EQ(-std::numeric_limits<float>::infinity(), result);
+
+ EXPECT_TRUE(safe_strtof("nan", &result));
+ EXPECT_TRUE(std::isnan(result));
+
+ EXPECT_TRUE(safe_strtof("-nan", &result));
+ EXPECT_TRUE(std::isnan(result));
+
+ EXPECT_TRUE(safe_strtof("-NaN", &result));
+ EXPECT_TRUE(std::isnan(result));
+
+ EXPECT_TRUE(safe_strtof("+NAN", &result));
+ EXPECT_TRUE(std::isnan(result));
}
TEST(safe_strtod, Double) {
@@ -287,6 +330,15 @@ TEST(safe_strtod, Double) {
EXPECT_EQ(0.1234567890123, result);
EXPECT_FALSE(safe_strtod("0.1234567890123abc", &result));
+ // Make sure we exit cleanly if the string is not terminated
+ char test_str[2 * kFastToBufferSize];
+ for (int i = 0; i < 2 * kFastToBufferSize; ++i) test_str[i] = 'a';
+ EXPECT_FALSE(safe_strtod(test_str, &result));
+
+ // Make sure we exit cleanly if the string is too long
+ test_str[kFastToBufferSize + 1] = '\0';
+ EXPECT_FALSE(safe_strtod(test_str, &result));
+
// Overflow to infinity, underflow to 0.
EXPECT_TRUE(safe_strtod("1e310", &result));
EXPECT_EQ(std::numeric_limits<double>::infinity(), result);
@@ -296,6 +348,41 @@ TEST(safe_strtod, Double) {
EXPECT_TRUE(safe_strtod("1e-325", &result));
EXPECT_EQ(0, result);
+
+ EXPECT_TRUE(safe_strtod(" -0x1c", &result));
+ EXPECT_EQ(-28.0, result);
+
+ EXPECT_TRUE(safe_strtod("50 \t", &result));
+ EXPECT_EQ(50.0, result);
+
+ EXPECT_TRUE(safe_strtod("\t82.0\t ", &result));
+ EXPECT_EQ(82.0, result);
+
+ EXPECT_FALSE(safe_strtod("infinity", &result));
+
+ EXPECT_TRUE(safe_strtod("-inf", &result));
+ EXPECT_EQ(-std::numeric_limits<double>::infinity(), result);
+
+ EXPECT_TRUE(safe_strtod("+inf", &result));
+ EXPECT_EQ(std::numeric_limits<double>::infinity(), result);
+
+ EXPECT_TRUE(safe_strtod("InF", &result));
+ EXPECT_EQ(std::numeric_limits<double>::infinity(), result);
+
+ EXPECT_TRUE(safe_strtod("-INF", &result));
+ EXPECT_EQ(-std::numeric_limits<double>::infinity(), result);
+
+ EXPECT_TRUE(safe_strtod("nan", &result));
+ EXPECT_TRUE(std::isnan(result));
+
+ EXPECT_TRUE(safe_strtod("-nan", &result));
+ EXPECT_TRUE(std::isnan(result));
+
+ EXPECT_TRUE(safe_strtod("-NaN", &result));
+ EXPECT_TRUE(std::isnan(result));
+
+ EXPECT_TRUE(safe_strtod("+NAN", &result));
+ EXPECT_TRUE(std::isnan(result));
}
} // namespace strings
diff --git a/tensorflow/core/lib/strings/str_util.cc b/tensorflow/core/lib/strings/str_util.cc
index 2c9e98357a..4598b8ccc7 100644
--- a/tensorflow/core/lib/strings/str_util.cc
+++ b/tensorflow/core/lib/strings/str_util.cc
@@ -454,6 +454,14 @@ bool SplitAndParseAsFloats(StringPiece text, char delim,
result);
}
+size_t Strnlen(const char* str, const size_t string_max_len) {
+ size_t len = 0;
+ while (len < string_max_len && str[len] != '\0') {
+ ++len;
+ }
+ return len;
+}
+
bool StrContains(StringPiece haystack, StringPiece needle) {
return std::search(haystack.begin(), haystack.end(), needle.begin(),
needle.end()) != haystack.end();
diff --git a/tensorflow/core/lib/strings/str_util.h b/tensorflow/core/lib/strings/str_util.h
index 065871c1b4..e97d00b975 100644
--- a/tensorflow/core/lib/strings/str_util.h
+++ b/tensorflow/core/lib/strings/str_util.h
@@ -223,6 +223,11 @@ std::vector<string> Split(StringPiece text, char delims, Predicate p) {
return Split(text, StringPiece(&delims, 1), p);
}
+// Returns the length of the given null-terminated byte string 'str'.
+// Returns 'string_max_len' if the null character was not found in the first
+// 'string_max_len' bytes of 'str'.
+size_t Strnlen(const char* str, const size_t string_max_len);
+
} // namespace str_util
} // namespace tensorflow
diff --git a/tensorflow/core/lib/strings/str_util_test.cc b/tensorflow/core/lib/strings/str_util_test.cc
index 63643c3e8e..3bf3e99825 100644
--- a/tensorflow/core/lib/strings/str_util_test.cc
+++ b/tensorflow/core/lib/strings/str_util_test.cc
@@ -430,56 +430,12 @@ TEST(StringReplace, EmptyStringReplaceAll) {
EXPECT_EQ("", str_util::StringReplace("", "a", "X", /*replace_all=*/true));
}
-TEST(StartsWith, Basic) {
- const string s1(
- "123"
- "\0"
- "456",
- 7);
- const StringPiece a("foobar");
- const StringPiece b(s1);
- const StringPiece e;
- EXPECT_TRUE(str_util::StartsWith(a, a));
- EXPECT_TRUE(str_util::StartsWith(a, "foo"));
- EXPECT_TRUE(str_util::StartsWith(a, e));
- EXPECT_TRUE(str_util::StartsWith(b, s1));
- EXPECT_TRUE(str_util::StartsWith(b, b));
- EXPECT_TRUE(str_util::StartsWith(b, e));
- EXPECT_TRUE(str_util::StartsWith(e, ""));
- EXPECT_FALSE(str_util::StartsWith(a, b));
- EXPECT_FALSE(str_util::StartsWith(b, a));
- EXPECT_FALSE(str_util::StartsWith(e, a));
-}
-
-TEST(EndsWith, Basic) {
- const string s1(
- "123"
- "\0"
- "456",
- 7);
- const StringPiece a("foobar");
- const StringPiece b(s1);
- const StringPiece e;
- EXPECT_TRUE(str_util::EndsWith(a, a));
- EXPECT_TRUE(str_util::EndsWith(a, "bar"));
- EXPECT_TRUE(str_util::EndsWith(a, e));
- EXPECT_TRUE(str_util::EndsWith(b, s1));
- EXPECT_TRUE(str_util::EndsWith(b, b));
- EXPECT_TRUE(str_util::EndsWith(b, e));
- EXPECT_TRUE(str_util::EndsWith(e, ""));
- EXPECT_FALSE(str_util::EndsWith(a, b));
- EXPECT_FALSE(str_util::EndsWith(b, a));
- EXPECT_FALSE(str_util::EndsWith(e, a));
-}
-
-TEST(StrContains, Basic) {
- StringPiece a("abcdefg");
- StringPiece b("abcd");
- StringPiece c("efg");
- StringPiece d("gh");
- EXPECT_TRUE(str_util::StrContains(a, b));
- EXPECT_TRUE(str_util::StrContains(a, c));
- EXPECT_TRUE(!str_util::StrContains(a, d));
+TEST(Strnlen, Basic) {
+ EXPECT_EQ(0, str_util::Strnlen("ab", 0));
+ EXPECT_EQ(1, str_util::Strnlen("a", 1));
+ EXPECT_EQ(2, str_util::Strnlen("abcd", 2));
+ EXPECT_EQ(3, str_util::Strnlen("abc", 10));
+ EXPECT_EQ(4, str_util::Strnlen("a \t\n", 10));
}
} // namespace tensorflow