aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/lib/strings
diff options
context:
space:
mode:
authorGravatar Manjunath Kudlur <keveman@google.com>2016-06-21 20:53:46 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-06-21 22:02:58 -0700
commite57750551e4aecea2ecc7f9339f00ca84997514c (patch)
tree7c431648703d0c8fefcf7b781e7fd7579d0192bf /tensorflow/core/lib/strings
parent75919978d545473846b11eb1db0081dc2870b974 (diff)
Make floating point number parsing functions locale independent.
Fixes #2974 Change: 125530598
Diffstat (limited to 'tensorflow/core/lib/strings')
-rw-r--r--tensorflow/core/lib/strings/numbers.cc88
-rw-r--r--tensorflow/core/lib/strings/numbers_test.cc16
2 files changed, 99 insertions, 5 deletions
diff --git a/tensorflow/core/lib/strings/numbers.cc b/tensorflow/core/lib/strings/numbers.cc
index ad25ffbaf2..568c829759 100644
--- a/tensorflow/core/lib/strings/numbers.cc
+++ b/tensorflow/core/lib/strings/numbers.cc
@@ -21,12 +21,89 @@ limitations under the License.
#include <stdlib.h>
#include <algorithm>
#include <cmath>
+#include <unordered_map>
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
+
+namespace {
+
+template <typename T>
+T locale_independent_strtonum(const char* str, const char** endptr) {
+ static const std::unordered_map<string, T> special_nums = {
+ {"inf", std::numeric_limits<T>::infinity()},
+ {"+inf", std::numeric_limits<T>::infinity()},
+ {"-inf", -std::numeric_limits<T>::infinity()},
+ {"infinity", std::numeric_limits<T>::infinity()},
+ {"+infinity", std::numeric_limits<T>::infinity()},
+ {"-infinity", -std::numeric_limits<T>::infinity()},
+ {"nan", std::numeric_limits<T>::quiet_NaN()},
+ {"+nan", std::numeric_limits<T>::quiet_NaN()},
+ {"-nan", -std::numeric_limits<T>::quiet_NaN()},
+ };
+ std::stringstream s(str);
+
+ // Check if str is one of the special numbers.
+ string special_num_str;
+ s >> special_num_str;
+
+ for (int i = 0; i < special_num_str.length(); ++i) {
+ special_num_str[i] =
+ std::tolower(special_num_str[i], std::locale::classic());
+ }
+
+ auto entry = special_nums.find(special_num_str);
+ if (entry != special_nums.end()) {
+ *endptr = str + (s.eof() ? static_cast<std::iostream::pos_type>(strlen(str))
+ : s.tellg());
+ return entry->second;
+ } else {
+ // Perhaps it's a hex number
+ if (special_num_str.compare(0, 2, "0x") == 0 ||
+ special_num_str.compare(0, 3, "-0x") == 0) {
+ return strtol(str, const_cast<char**>(endptr), 16);
+ }
+ }
+ // Reset the stream
+ s.str(str);
+ s.clear();
+ // Use the "C" locale
+ s.imbue(std::locale::classic());
+
+ T result;
+ s >> result;
+
+ // Set to result to what strto{f,d} functions would have returned. If the
+ // number was outside the range, the stringstream sets the fail flag, but
+ // returns the +/-max() value, whereas strto{f,d} functions return +/-INF.
+ bool real_fail = false;
+ if (s.fail()) {
+ real_fail = true;
+ if (result == std::numeric_limits<T>::max()) {
+ result = std::numeric_limits<T>::infinity();
+ real_fail = false;
+ } else if (result == -std::numeric_limits<T>::max()) {
+ result = -std::numeric_limits<T>::infinity();
+ real_fail = false;
+ }
+ }
+
+ if (endptr) {
+ *endptr =
+ str +
+ (real_fail
+ ? static_cast<std::iostream::pos_type>(0)
+ : (s.eof() ? static_cast<std::iostream::pos_type>(strlen(str))
+ : s.tellg()));
+ }
+ return result;
+}
+
+} // namespace
+
namespace strings {
char* FastInt32ToBufferLeft(int32 i, char* buffer) {
@@ -90,7 +167,8 @@ char* DoubleToBuffer(double value, char* buffer) {
// larger than the precision we asked for.
DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize);
- full_precision_needed = strtod(buffer, NULL) != value;
+ full_precision_needed =
+ locale_independent_strtonum<double>(buffer, NULL) != value;
}
if (full_precision_needed) {
@@ -226,8 +304,8 @@ bool safe_strtou32(StringPiece str, uint32* value) {
}
bool safe_strtof(const char* str, float* value) {
- char* endptr;
- *value = strtof(str, &endptr);
+ const char* endptr;
+ *value = locale_independent_strtonum<float>(str, &endptr);
while (isspace(*endptr)) ++endptr;
// Ignore range errors from strtod/strtof.
// The values it returns on underflow and
@@ -237,8 +315,8 @@ bool safe_strtof(const char* str, float* value) {
}
bool safe_strtod(const char* str, double* value) {
- char* endptr;
- *value = strtod(str, &endptr);
+ const char* endptr;
+ *value = locale_independent_strtonum<double>(str, &endptr);
while (isspace(*endptr)) ++endptr;
// Ignore range errors from strtod/strtof.
// The values it returns on underflow and
diff --git a/tensorflow/core/lib/strings/numbers_test.cc b/tensorflow/core/lib/strings/numbers_test.cc
index 99c76f0121..66f820762f 100644
--- a/tensorflow/core/lib/strings/numbers_test.cc
+++ b/tensorflow/core/lib/strings/numbers_test.cc
@@ -233,8 +233,20 @@ TEST(safe_strtof, Float) {
// Overflow to infinity, underflow to 0.
EXPECT_TRUE(safe_strtof("1e39", &result));
EXPECT_EQ(std::numeric_limits<float>::infinity(), result);
+
+ EXPECT_TRUE(safe_strtof("-1e39", &result));
+ EXPECT_EQ(-std::numeric_limits<float>::infinity(), result);
+
EXPECT_TRUE(safe_strtof("1e-50", &result));
EXPECT_EQ(0, result);
+
+ EXPECT_TRUE(safe_strtof("0xF", &result));
+ EXPECT_EQ(0xF, result);
+
+ EXPECT_TRUE(safe_strtof("-0x2A", &result));
+ EXPECT_EQ(-42.0f, result);
+
+ EXPECT_FALSE(safe_strtof("-infinity is awesome", &result));
}
TEST(safe_strtod, Double) {
@@ -247,6 +259,10 @@ TEST(safe_strtod, Double) {
// Overflow to infinity, underflow to 0.
EXPECT_TRUE(safe_strtod("1e310", &result));
EXPECT_EQ(std::numeric_limits<double>::infinity(), result);
+
+ EXPECT_TRUE(safe_strtod("-1e310", &result));
+ EXPECT_EQ(-std::numeric_limits<double>::infinity(), result);
+
EXPECT_TRUE(safe_strtod("1e-325", &result));
EXPECT_EQ(0, result);
}