diff options
-rw-r--r-- | tensorflow/core/common_runtime/gpu/gpu_device.cc | 4 | ||||
-rw-r--r-- | tensorflow/core/framework/op_def_builder.cc | 2 | ||||
-rw-r--r-- | tensorflow/core/kernels/decode_csv_op.cc | 4 | ||||
-rw-r--r-- | tensorflow/core/kernels/string_to_number_op.cc | 11 | ||||
-rw-r--r-- | tensorflow/core/lib/core/stringpiece.cc | 8 | ||||
-rw-r--r-- | tensorflow/core/lib/core/stringpiece.h | 8 | ||||
-rw-r--r-- | tensorflow/core/lib/strings/numbers.cc | 53 | ||||
-rw-r--r-- | tensorflow/core/lib/strings/numbers.h | 4 | ||||
-rw-r--r-- | tensorflow/core/lib/strings/numbers_test.cc | 14 |
9 files changed, 60 insertions, 48 deletions
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc index c46c785744..53d1d7ffdd 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc @@ -667,9 +667,9 @@ struct CudaVersion { size_t dot_pos = version_name.find('.'); CHECK(dot_pos != string::npos); string major_str = version_name.substr(0, dot_pos); - CHECK(strings::safe_strto32(major_str.c_str(), &major_part)); + CHECK(strings::safe_strto32(major_str, &major_part)); string minor_str = version_name.substr(dot_pos + 1); - CHECK(strings::safe_strto32(minor_str.c_str(), &minor_part)); + CHECK(strings::safe_strto32(minor_str, &minor_part)); } CudaVersion() {} bool operator<(const CudaVersion& other) const { diff --git a/tensorflow/core/framework/op_def_builder.cc b/tensorflow/core/framework/op_def_builder.cc index ee80291759..5931b1ace1 100644 --- a/tensorflow/core/framework/op_def_builder.cc +++ b/tensorflow/core/framework/op_def_builder.cc @@ -92,7 +92,7 @@ bool ConsumeAttrNumber(StringPiece* sp, int64* out) { return false; } int64 value = 0; - if (!strings::safe_strto64(match.ToString().c_str(), &value)) { + if (!strings::safe_strto64(match, &value)) { return false; } *out = value; diff --git a/tensorflow/core/kernels/decode_csv_op.cc b/tensorflow/core/kernels/decode_csv_op.cc index 18ee40e623..60f0474103 100644 --- a/tensorflow/core/kernels/decode_csv_op.cc +++ b/tensorflow/core/kernels/decode_csv_op.cc @@ -88,7 +88,7 @@ class DecodeCSVOp : public OpKernel { output[f]->flat<int32>()(i) = record_defaults[f].flat<int32>()(0); } else { int32 value; - OP_REQUIRES(ctx, strings::safe_strto32(fields[f].c_str(), &value), + OP_REQUIRES(ctx, strings::safe_strto32(fields[f], &value), errors::InvalidArgument("Field ", f, " in record ", i, " is not a valid int32: ", fields[f])); @@ -108,7 +108,7 @@ class DecodeCSVOp : public OpKernel { output[f]->flat<int64>()(i) = record_defaults[f].flat<int64>()(0); } else { int64 value; - OP_REQUIRES(ctx, strings::safe_strto64(fields[f].c_str(), &value), + OP_REQUIRES(ctx, strings::safe_strto64(fields[f], &value), errors::InvalidArgument("Field ", f, " in record ", i, " is not a valid int64: ", fields[f])); diff --git a/tensorflow/core/kernels/string_to_number_op.cc b/tensorflow/core/kernels/string_to_number_op.cc index 34abae788f..bc8a781cf7 100644 --- a/tensorflow/core/kernels/string_to_number_op.cc +++ b/tensorflow/core/kernels/string_to_number_op.cc @@ -49,25 +49,24 @@ class StringToNumberOp : public OpKernel { auto output_flat = output_tensor->flat<OutputType>(); for (int i = 0; i < input_flat.size(); ++i) { - const char* s = input_flat(i).data(); - Convert(s, &output_flat(i), context); + Convert(input_flat(i), &output_flat(i), context); } } private: - void Convert(const char* s, OutputType* output_data, + void Convert(const string& s, OutputType* output_data, OpKernelContext* context); }; template <> -void StringToNumberOp<float>::Convert(const char* s, float* output_data, +void StringToNumberOp<float>::Convert(const string& s, float* output_data, OpKernelContext* context) { - OP_REQUIRES(context, strings::safe_strtof(s, output_data), + OP_REQUIRES(context, strings::safe_strtof(s.c_str(), output_data), errors::InvalidArgument(kErrorMessage, s)); } template <> -void StringToNumberOp<int32>::Convert(const char* s, int32* output_data, +void StringToNumberOp<int32>::Convert(const string& s, int32* output_data, OpKernelContext* context) { OP_REQUIRES(context, strings::safe_strto32(s, output_data), errors::InvalidArgument(kErrorMessage, s)); diff --git a/tensorflow/core/lib/core/stringpiece.cc b/tensorflow/core/lib/core/stringpiece.cc index 4a751881f7..17a143c550 100644 --- a/tensorflow/core/lib/core/stringpiece.cc +++ b/tensorflow/core/lib/core/stringpiece.cc @@ -54,14 +54,6 @@ size_t StringPiece::rfind(char c, size_t pos) const { return npos; } -bool StringPiece::Consume(StringPiece x) { - if (starts_with(x)) { - remove_prefix(x.size_); - return true; - } - return false; -} - StringPiece StringPiece::substr(size_t pos, size_t n) const { if (pos > size_) pos = size_; if (n > size_ - pos) n = size_ - pos; diff --git a/tensorflow/core/lib/core/stringpiece.h b/tensorflow/core/lib/core/stringpiece.h index 7ac8df89a6..4740748f2d 100644 --- a/tensorflow/core/lib/core/stringpiece.h +++ b/tensorflow/core/lib/core/stringpiece.h @@ -104,7 +104,13 @@ class StringPiece { // Checks whether StringPiece starts with x and if so advances the beginning // of it to past the match. It's basically a shortcut for starts_with // followed by remove_prefix. - bool Consume(StringPiece x); + bool Consume(StringPiece x) { + if (starts_with(x)) { + remove_prefix(x.size_); + return true; + } + return false; + } StringPiece substr(size_t pos, size_t n = npos) const; diff --git a/tensorflow/core/lib/strings/numbers.cc b/tensorflow/core/lib/strings/numbers.cc index 5e63eb9c01..9e3a6d5458 100644 --- a/tensorflow/core/lib/strings/numbers.cc +++ b/tensorflow/core/lib/strings/numbers.cc @@ -103,83 +103,84 @@ char* DoubleToBuffer(double value, char* buffer) { return buffer; } -bool safe_strto64(const char* str, int64* value) { - if (!str) return false; +namespace { +char SafeFirstChar(StringPiece str) { + if (str.empty()) return '\0'; + return str[0]; +} +} // namespace +bool safe_strto64(StringPiece str, int64* value) { // Skip leading space. - while (isspace(*str)) ++str; + while (isspace(SafeFirstChar(str))) str.remove_prefix(1); int64 vlimit = kint64max; int sign = 1; - if (*str == '-') { + if (str.Consume("-")) { sign = -1; - ++str; // Different limit for positive and negative integers. vlimit = kint64min; } - if (!isdigit(*str)) return false; + if (!isdigit(SafeFirstChar(str))) return false; int64 result = 0; if (sign == 1) { do { - int digit = *str - '0'; + int digit = SafeFirstChar(str) - '0'; if ((vlimit - digit) / 10 < result) { return false; } result = result * 10 + digit; - ++str; - } while (isdigit(*str)); + str.remove_prefix(1); + } while (isdigit(SafeFirstChar(str))); } else { do { - int digit = *str - '0'; + int digit = SafeFirstChar(str) - '0'; if ((vlimit + digit) / 10 > result) { return false; } result = result * 10 - digit; - ++str; - } while (isdigit(*str)); + str.remove_prefix(1); + } while (isdigit(SafeFirstChar(str))); } // Skip trailing space. - while (isspace(*str)) ++str; + while (isspace(SafeFirstChar(str))) str.remove_prefix(1); - if (*str) return false; + if (!str.empty()) return false; *value = result; return true; } -bool safe_strto32(const char* str, int32* value) { - if (!str) return false; - +bool safe_strto32(StringPiece str, int32* value) { // Skip leading space. - while (isspace(*str)) ++str; + while (isspace(SafeFirstChar(str))) str.remove_prefix(1); int64 vmax = kint32max; int sign = 1; - if (*str == '-') { + if (str.Consume("-")) { sign = -1; - ++str; // Different max for positive and negative integers. ++vmax; } - if (!isdigit(*str)) return false; + if (!isdigit(SafeFirstChar(str))) return false; int64 result = 0; do { - result = result * 10 + *str - '0'; + result = result * 10 + SafeFirstChar(str) - '0'; if (result > vmax) { return false; } - ++str; - } while (isdigit(*str)); + str.remove_prefix(1); + } while (isdigit(SafeFirstChar(str))); // Skip trailing space. - while (isspace(*str)) ++str; + while (isspace(SafeFirstChar(str))) str.remove_prefix(1); - if (*str) return false; + if (!str.empty()) return false; *value = result * sign; return true; diff --git a/tensorflow/core/lib/strings/numbers.h b/tensorflow/core/lib/strings/numbers.h index 02903547a7..68ddc68ebb 100644 --- a/tensorflow/core/lib/strings/numbers.h +++ b/tensorflow/core/lib/strings/numbers.h @@ -95,12 +95,12 @@ bool HexStringToUint64(const StringPiece& s, uint64* v); // Convert strings to 32bit integer values. // Leading and trailing spaces are allowed. // Return false with overflow or invalid input. -bool safe_strto32(const char* str, int32* value); +bool safe_strto32(StringPiece str, int32* value); // Convert strings to 64bit integer values. // Leading and trailing spaces are allowed. // Return false with overflow or invalid input. -bool safe_strto64(const char* str, int64* value); +bool safe_strto64(StringPiece str, int64* value); // Convert strings to floating point values. // Leading and trailing spaces are allowed. diff --git a/tensorflow/core/lib/strings/numbers_test.cc b/tensorflow/core/lib/strings/numbers_test.cc index 67fec856a1..88acd57852 100644 --- a/tensorflow/core/lib/strings/numbers_test.cc +++ b/tensorflow/core/lib/strings/numbers_test.cc @@ -110,6 +110,13 @@ TEST(safe_strto32, Int32s) { // Overflow EXPECT_EQ(false, safe_strto32("2147483648", &result)); EXPECT_EQ(false, safe_strto32("-2147483649", &result)); + + // Check that the StringPiece's length is respected. + EXPECT_EQ(true, safe_strto32(StringPiece("123", 1), &result)); + EXPECT_EQ(1, result); + EXPECT_EQ(true, safe_strto32(StringPiece(" -123", 4), &result)); + EXPECT_EQ(-12, result); + EXPECT_EQ(false, safe_strto32(StringPiece(nullptr, 0), &result)); } TEST(safe_strto64, Int64s) { @@ -139,6 +146,13 @@ TEST(safe_strto64, Int64s) { // Overflow EXPECT_EQ(false, safe_strto64("9223372036854775808", &result)); EXPECT_EQ(false, safe_strto64("-9223372036854775809", &result)); + + // Check that the StringPiece's length is respected. + EXPECT_EQ(true, safe_strto64(StringPiece("123", 1), &result)); + EXPECT_EQ(1, result); + EXPECT_EQ(true, safe_strto64(StringPiece(" -123", 4), &result)); + EXPECT_EQ(-12, result); + EXPECT_EQ(false, safe_strto64(StringPiece(nullptr, 0), &result)); } } // namespace strings |