diff options
author | 2016-06-21 20:53:46 -0800 | |
---|---|---|
committer | 2016-06-21 22:02:58 -0700 | |
commit | e57750551e4aecea2ecc7f9339f00ca84997514c (patch) | |
tree | 7c431648703d0c8fefcf7b781e7fd7579d0192bf /tensorflow/core/lib/strings | |
parent | 75919978d545473846b11eb1db0081dc2870b974 (diff) |
Make floating point number parsing functions locale independent.
Fixes #2974
Change: 125530598
Diffstat (limited to 'tensorflow/core/lib/strings')
-rw-r--r-- | tensorflow/core/lib/strings/numbers.cc | 88 | ||||
-rw-r--r-- | tensorflow/core/lib/strings/numbers_test.cc | 16 |
2 files changed, 99 insertions, 5 deletions
diff --git a/tensorflow/core/lib/strings/numbers.cc b/tensorflow/core/lib/strings/numbers.cc index ad25ffbaf2..568c829759 100644 --- a/tensorflow/core/lib/strings/numbers.cc +++ b/tensorflow/core/lib/strings/numbers.cc @@ -21,12 +21,89 @@ limitations under the License. #include <stdlib.h> #include <algorithm> #include <cmath> +#include <unordered_map> #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { + +namespace { + +template <typename T> +T locale_independent_strtonum(const char* str, const char** endptr) { + static const std::unordered_map<string, T> special_nums = { + {"inf", std::numeric_limits<T>::infinity()}, + {"+inf", std::numeric_limits<T>::infinity()}, + {"-inf", -std::numeric_limits<T>::infinity()}, + {"infinity", std::numeric_limits<T>::infinity()}, + {"+infinity", std::numeric_limits<T>::infinity()}, + {"-infinity", -std::numeric_limits<T>::infinity()}, + {"nan", std::numeric_limits<T>::quiet_NaN()}, + {"+nan", std::numeric_limits<T>::quiet_NaN()}, + {"-nan", -std::numeric_limits<T>::quiet_NaN()}, + }; + std::stringstream s(str); + + // Check if str is one of the special numbers. + string special_num_str; + s >> special_num_str; + + for (int i = 0; i < special_num_str.length(); ++i) { + special_num_str[i] = + std::tolower(special_num_str[i], std::locale::classic()); + } + + auto entry = special_nums.find(special_num_str); + if (entry != special_nums.end()) { + *endptr = str + (s.eof() ? 
static_cast<std::iostream::pos_type>(strlen(str)) + : s.tellg()); + return entry->second; + } else { + // Perhaps it's a hex number + if (special_num_str.compare(0, 2, "0x") == 0 || + special_num_str.compare(0, 3, "-0x") == 0) { + return strtol(str, const_cast<char**>(endptr), 16); + } + } + // Reset the stream + s.str(str); + s.clear(); + // Use the "C" locale + s.imbue(std::locale::classic()); + + T result; + s >> result; + + // Set result to what strto{f,d} functions would have returned. If the + // number was outside the range, the stringstream sets the fail flag, but + // returns the +/-max() value, whereas strto{f,d} functions return +/-INF. + bool real_fail = false; + if (s.fail()) { + real_fail = true; + if (result == std::numeric_limits<T>::max()) { + result = std::numeric_limits<T>::infinity(); + real_fail = false; + } else if (result == -std::numeric_limits<T>::max()) { + result = -std::numeric_limits<T>::infinity(); + real_fail = false; + } + } + + if (endptr) { + *endptr = + str + + (real_fail + ? static_cast<std::iostream::pos_type>(0) + : (s.eof() ? static_cast<std::iostream::pos_type>(strlen(str)) + : s.tellg())); + } + return result; +} + +} // namespace + namespace strings { char* FastInt32ToBufferLeft(int32 i, char* buffer) { @@ -90,7 +167,8 @@ char* DoubleToBuffer(double value, char* buffer) { // larger than the precision we asked for. DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize); - full_precision_needed = strtod(buffer, NULL) != value; + full_precision_needed = + locale_independent_strtonum<double>(buffer, NULL) != value; } if (full_precision_needed) { @@ -226,8 +304,8 @@ bool safe_strtou32(StringPiece str, uint32* value) { } bool safe_strtof(const char* str, float* value) { - char* endptr; - *value = strtof(str, &endptr); + const char* endptr; + *value = locale_independent_strtonum<float>(str, &endptr); while (isspace(*endptr)) ++endptr; // Ignore range errors from strtod/strtof. 
// The values it returns on underflow and @@ -237,8 +315,8 @@ bool safe_strtof(const char* str, float* value) { } bool safe_strtod(const char* str, double* value) { - char* endptr; - *value = strtod(str, &endptr); + const char* endptr; + *value = locale_independent_strtonum<double>(str, &endptr); while (isspace(*endptr)) ++endptr; // Ignore range errors from strtod/strtof. // The values it returns on underflow and diff --git a/tensorflow/core/lib/strings/numbers_test.cc b/tensorflow/core/lib/strings/numbers_test.cc index 99c76f0121..66f820762f 100644 --- a/tensorflow/core/lib/strings/numbers_test.cc +++ b/tensorflow/core/lib/strings/numbers_test.cc @@ -233,8 +233,20 @@ TEST(safe_strtof, Float) { // Overflow to infinity, underflow to 0. EXPECT_TRUE(safe_strtof("1e39", &result)); EXPECT_EQ(std::numeric_limits<float>::infinity(), result); + + EXPECT_TRUE(safe_strtof("-1e39", &result)); + EXPECT_EQ(-std::numeric_limits<float>::infinity(), result); + EXPECT_TRUE(safe_strtof("1e-50", &result)); EXPECT_EQ(0, result); + + EXPECT_TRUE(safe_strtof("0xF", &result)); + EXPECT_EQ(0xF, result); + + EXPECT_TRUE(safe_strtof("-0x2A", &result)); + EXPECT_EQ(-42.0f, result); + + EXPECT_FALSE(safe_strtof("-infinity is awesome", &result)); } TEST(safe_strtod, Double) { @@ -247,6 +259,10 @@ TEST(safe_strtod, Double) { // Overflow to infinity, underflow to 0. EXPECT_TRUE(safe_strtod("1e310", &result)); EXPECT_EQ(std::numeric_limits<double>::infinity(), result); + + EXPECT_TRUE(safe_strtod("-1e310", &result)); + EXPECT_EQ(-std::numeric_limits<double>::infinity(), result); + EXPECT_TRUE(safe_strtod("1e-325", &result)); EXPECT_EQ(0, result); } |