aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2016-03-11 18:29:25 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-03-11 20:46:08 -0800
commit3b55e1f4f4be8fd4a6a5084edf9daf01e0990c3c (patch)
tree16934f8a8322cd47b54bc43110337dde23c80811
parent90b1700d0df12cd03a4bfb75743bcf60a3c90255 (diff)
Change safe_strto32 and safe_strto64 to accept StringPiece. Updates callers to
pass the StringPiece values. Change: 117027762
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device.cc4
-rw-r--r--tensorflow/core/framework/op_def_builder.cc2
-rw-r--r--tensorflow/core/kernels/decode_csv_op.cc4
-rw-r--r--tensorflow/core/kernels/string_to_number_op.cc11
-rw-r--r--tensorflow/core/lib/core/stringpiece.cc8
-rw-r--r--tensorflow/core/lib/core/stringpiece.h8
-rw-r--r--tensorflow/core/lib/strings/numbers.cc53
-rw-r--r--tensorflow/core/lib/strings/numbers.h4
-rw-r--r--tensorflow/core/lib/strings/numbers_test.cc14
9 files changed, 60 insertions, 48 deletions
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index c46c785744..53d1d7ffdd 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -667,9 +667,9 @@ struct CudaVersion {
size_t dot_pos = version_name.find('.');
CHECK(dot_pos != string::npos);
string major_str = version_name.substr(0, dot_pos);
- CHECK(strings::safe_strto32(major_str.c_str(), &major_part));
+ CHECK(strings::safe_strto32(major_str, &major_part));
string minor_str = version_name.substr(dot_pos + 1);
- CHECK(strings::safe_strto32(minor_str.c_str(), &minor_part));
+ CHECK(strings::safe_strto32(minor_str, &minor_part));
}
CudaVersion() {}
bool operator<(const CudaVersion& other) const {
diff --git a/tensorflow/core/framework/op_def_builder.cc b/tensorflow/core/framework/op_def_builder.cc
index ee80291759..5931b1ace1 100644
--- a/tensorflow/core/framework/op_def_builder.cc
+++ b/tensorflow/core/framework/op_def_builder.cc
@@ -92,7 +92,7 @@ bool ConsumeAttrNumber(StringPiece* sp, int64* out) {
return false;
}
int64 value = 0;
- if (!strings::safe_strto64(match.ToString().c_str(), &value)) {
+ if (!strings::safe_strto64(match, &value)) {
return false;
}
*out = value;
diff --git a/tensorflow/core/kernels/decode_csv_op.cc b/tensorflow/core/kernels/decode_csv_op.cc
index 18ee40e623..60f0474103 100644
--- a/tensorflow/core/kernels/decode_csv_op.cc
+++ b/tensorflow/core/kernels/decode_csv_op.cc
@@ -88,7 +88,7 @@ class DecodeCSVOp : public OpKernel {
output[f]->flat<int32>()(i) = record_defaults[f].flat<int32>()(0);
} else {
int32 value;
- OP_REQUIRES(ctx, strings::safe_strto32(fields[f].c_str(), &value),
+ OP_REQUIRES(ctx, strings::safe_strto32(fields[f], &value),
errors::InvalidArgument("Field ", f, " in record ", i,
" is not a valid int32: ",
fields[f]));
@@ -108,7 +108,7 @@ class DecodeCSVOp : public OpKernel {
output[f]->flat<int64>()(i) = record_defaults[f].flat<int64>()(0);
} else {
int64 value;
- OP_REQUIRES(ctx, strings::safe_strto64(fields[f].c_str(), &value),
+ OP_REQUIRES(ctx, strings::safe_strto64(fields[f], &value),
errors::InvalidArgument("Field ", f, " in record ", i,
" is not a valid int64: ",
fields[f]));
diff --git a/tensorflow/core/kernels/string_to_number_op.cc b/tensorflow/core/kernels/string_to_number_op.cc
index 34abae788f..bc8a781cf7 100644
--- a/tensorflow/core/kernels/string_to_number_op.cc
+++ b/tensorflow/core/kernels/string_to_number_op.cc
@@ -49,25 +49,24 @@ class StringToNumberOp : public OpKernel {
auto output_flat = output_tensor->flat<OutputType>();
for (int i = 0; i < input_flat.size(); ++i) {
- const char* s = input_flat(i).data();
- Convert(s, &output_flat(i), context);
+ Convert(input_flat(i), &output_flat(i), context);
}
}
private:
- void Convert(const char* s, OutputType* output_data,
+ void Convert(const string& s, OutputType* output_data,
OpKernelContext* context);
};
template <>
-void StringToNumberOp<float>::Convert(const char* s, float* output_data,
+void StringToNumberOp<float>::Convert(const string& s, float* output_data,
OpKernelContext* context) {
- OP_REQUIRES(context, strings::safe_strtof(s, output_data),
+ OP_REQUIRES(context, strings::safe_strtof(s.c_str(), output_data),
errors::InvalidArgument(kErrorMessage, s));
}
template <>
-void StringToNumberOp<int32>::Convert(const char* s, int32* output_data,
+void StringToNumberOp<int32>::Convert(const string& s, int32* output_data,
OpKernelContext* context) {
OP_REQUIRES(context, strings::safe_strto32(s, output_data),
errors::InvalidArgument(kErrorMessage, s));
diff --git a/tensorflow/core/lib/core/stringpiece.cc b/tensorflow/core/lib/core/stringpiece.cc
index 4a751881f7..17a143c550 100644
--- a/tensorflow/core/lib/core/stringpiece.cc
+++ b/tensorflow/core/lib/core/stringpiece.cc
@@ -54,14 +54,6 @@ size_t StringPiece::rfind(char c, size_t pos) const {
return npos;
}
-bool StringPiece::Consume(StringPiece x) {
- if (starts_with(x)) {
- remove_prefix(x.size_);
- return true;
- }
- return false;
-}
-
StringPiece StringPiece::substr(size_t pos, size_t n) const {
if (pos > size_) pos = size_;
if (n > size_ - pos) n = size_ - pos;
diff --git a/tensorflow/core/lib/core/stringpiece.h b/tensorflow/core/lib/core/stringpiece.h
index 7ac8df89a6..4740748f2d 100644
--- a/tensorflow/core/lib/core/stringpiece.h
+++ b/tensorflow/core/lib/core/stringpiece.h
@@ -104,7 +104,13 @@ class StringPiece {
// Checks whether StringPiece starts with x and if so advances the beginning
// of it to past the match. It's basically a shortcut for starts_with
// followed by remove_prefix.
- bool Consume(StringPiece x);
+ bool Consume(StringPiece x) {
+ if (starts_with(x)) {
+ remove_prefix(x.size_);
+ return true;
+ }
+ return false;
+ }
StringPiece substr(size_t pos, size_t n = npos) const;
diff --git a/tensorflow/core/lib/strings/numbers.cc b/tensorflow/core/lib/strings/numbers.cc
index 5e63eb9c01..9e3a6d5458 100644
--- a/tensorflow/core/lib/strings/numbers.cc
+++ b/tensorflow/core/lib/strings/numbers.cc
@@ -103,83 +103,84 @@ char* DoubleToBuffer(double value, char* buffer) {
return buffer;
}
-bool safe_strto64(const char* str, int64* value) {
- if (!str) return false;
+namespace {
+char SafeFirstChar(StringPiece str) {
+ if (str.empty()) return '\0';
+ return str[0];
+}
+} // namespace
+bool safe_strto64(StringPiece str, int64* value) {
// Skip leading space.
- while (isspace(*str)) ++str;
+ while (isspace(SafeFirstChar(str))) str.remove_prefix(1);
int64 vlimit = kint64max;
int sign = 1;
- if (*str == '-') {
+ if (str.Consume("-")) {
sign = -1;
- ++str;
// Different limit for positive and negative integers.
vlimit = kint64min;
}
- if (!isdigit(*str)) return false;
+ if (!isdigit(SafeFirstChar(str))) return false;
int64 result = 0;
if (sign == 1) {
do {
- int digit = *str - '0';
+ int digit = SafeFirstChar(str) - '0';
if ((vlimit - digit) / 10 < result) {
return false;
}
result = result * 10 + digit;
- ++str;
- } while (isdigit(*str));
+ str.remove_prefix(1);
+ } while (isdigit(SafeFirstChar(str)));
} else {
do {
- int digit = *str - '0';
+ int digit = SafeFirstChar(str) - '0';
if ((vlimit + digit) / 10 > result) {
return false;
}
result = result * 10 - digit;
- ++str;
- } while (isdigit(*str));
+ str.remove_prefix(1);
+ } while (isdigit(SafeFirstChar(str)));
}
// Skip trailing space.
- while (isspace(*str)) ++str;
+ while (isspace(SafeFirstChar(str))) str.remove_prefix(1);
- if (*str) return false;
+ if (!str.empty()) return false;
*value = result;
return true;
}
-bool safe_strto32(const char* str, int32* value) {
- if (!str) return false;
-
+bool safe_strto32(StringPiece str, int32* value) {
// Skip leading space.
- while (isspace(*str)) ++str;
+ while (isspace(SafeFirstChar(str))) str.remove_prefix(1);
int64 vmax = kint32max;
int sign = 1;
- if (*str == '-') {
+ if (str.Consume("-")) {
sign = -1;
- ++str;
// Different max for positive and negative integers.
++vmax;
}
- if (!isdigit(*str)) return false;
+ if (!isdigit(SafeFirstChar(str))) return false;
int64 result = 0;
do {
- result = result * 10 + *str - '0';
+ result = result * 10 + SafeFirstChar(str) - '0';
if (result > vmax) {
return false;
}
- ++str;
- } while (isdigit(*str));
+ str.remove_prefix(1);
+ } while (isdigit(SafeFirstChar(str)));
// Skip trailing space.
- while (isspace(*str)) ++str;
+ while (isspace(SafeFirstChar(str))) str.remove_prefix(1);
- if (*str) return false;
+ if (!str.empty()) return false;
*value = result * sign;
return true;
diff --git a/tensorflow/core/lib/strings/numbers.h b/tensorflow/core/lib/strings/numbers.h
index 02903547a7..68ddc68ebb 100644
--- a/tensorflow/core/lib/strings/numbers.h
+++ b/tensorflow/core/lib/strings/numbers.h
@@ -95,12 +95,12 @@ bool HexStringToUint64(const StringPiece& s, uint64* v);
// Convert strings to 32bit integer values.
// Leading and trailing spaces are allowed.
// Return false with overflow or invalid input.
-bool safe_strto32(const char* str, int32* value);
+bool safe_strto32(StringPiece str, int32* value);
// Convert strings to 64bit integer values.
// Leading and trailing spaces are allowed.
// Return false with overflow or invalid input.
-bool safe_strto64(const char* str, int64* value);
+bool safe_strto64(StringPiece str, int64* value);
// Convert strings to floating point values.
// Leading and trailing spaces are allowed.
diff --git a/tensorflow/core/lib/strings/numbers_test.cc b/tensorflow/core/lib/strings/numbers_test.cc
index 67fec856a1..88acd57852 100644
--- a/tensorflow/core/lib/strings/numbers_test.cc
+++ b/tensorflow/core/lib/strings/numbers_test.cc
@@ -110,6 +110,13 @@ TEST(safe_strto32, Int32s) {
// Overflow
EXPECT_EQ(false, safe_strto32("2147483648", &result));
EXPECT_EQ(false, safe_strto32("-2147483649", &result));
+
+ // Check that the StringPiece's length is respected.
+ EXPECT_EQ(true, safe_strto32(StringPiece("123", 1), &result));
+ EXPECT_EQ(1, result);
+ EXPECT_EQ(true, safe_strto32(StringPiece(" -123", 4), &result));
+ EXPECT_EQ(-12, result);
+ EXPECT_EQ(false, safe_strto32(StringPiece(nullptr, 0), &result));
}
TEST(safe_strto64, Int64s) {
@@ -139,6 +146,13 @@ TEST(safe_strto64, Int64s) {
// Overflow
EXPECT_EQ(false, safe_strto64("9223372036854775808", &result));
EXPECT_EQ(false, safe_strto64("-9223372036854775809", &result));
+
+ // Check that the StringPiece's length is respected.
+ EXPECT_EQ(true, safe_strto64(StringPiece("123", 1), &result));
+ EXPECT_EQ(1, result);
+ EXPECT_EQ(true, safe_strto64(StringPiece(" -123", 4), &result));
+ EXPECT_EQ(-12, result);
+ EXPECT_EQ(false, safe_strto64(StringPiece(nullptr, 0), &result));
}
} // namespace strings