diff options
author | 2018-10-04 11:30:52 -0700 | |
---|---|---|
committer | 2018-10-04 11:34:46 -0700 | |
commit | 700c3325311e16be9bb4856cbf944d1871ff35c1 (patch) | |
tree | 9ae88328889950abaa951a628de7212caec8c026 /tensorflow/core | |
parent | c8d5054e8c12800f0c3db0e51f3d5902e04eaa37 (diff) |
Add "unit" attribute to string substr op, which controls how each "character" is treated:
* BYTE: Position & length refer to bytes in the string. (Default)
* UTF8_CHAR: The string is interpreted as UTF-8 encoded Unicode code points, and position & length are treated relative to them.
RELNOTES: Add option to get substring using Unicode characters
PiperOrigin-RevId: 215773373
Diffstat (limited to 'tensorflow/core')
-rw-r--r-- | tensorflow/core/api_def/base_api/api_def_Substr.pbtxt | 10 | ||||
-rw-r--r-- | tensorflow/core/api_def/python_api/api_def_Substr.pbtxt | 8 | ||||
-rw-r--r-- | tensorflow/core/kernels/BUILD | 7 | ||||
-rw-r--r-- | tensorflow/core/kernels/string_util.cc | 4 | ||||
-rw-r--r-- | tensorflow/core/kernels/string_util.h | 44 | ||||
-rw-r--r-- | tensorflow/core/kernels/substr_op.cc | 162 | ||||
-rw-r--r-- | tensorflow/core/kernels/substr_op_test.cc | 100 | ||||
-rw-r--r-- | tensorflow/core/ops/string_ops.cc | 1 |
8 files changed, 292 insertions, 44 deletions
diff --git a/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt b/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt index 5246090ab3..fe0fcc9508 100644 --- a/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt @@ -18,6 +18,16 @@ END Scalar defining the number of characters to include in each substring END } + attr { + name: "unit" + description: <<END +The unit that is used to create the substring. One of: `"BYTE"` (for +defining position and length by bytes) or `"UTF8_CHAR"` (for the UTF-8 +encoded Unicode code points). The default is `"BYTE"`. Results are undefined if +`unit=UTF8_CHAR` and the `input` strings do not contain structurally valid +UTF-8. +END + } out_arg { name: "output" description: <<END diff --git a/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt b/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt index 4778d7927c..4fb9ee56e9 100644 --- a/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt @@ -1,10 +1,4 @@ op { graph_op_name: "Substr" - endpoint { - name: "strings.substr" - } - endpoint { - name: "substr" - deprecated: true - } + visibility: HIDDEN } diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 9439ab332c..3a920f26f3 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -4458,7 +4458,12 @@ cc_library( name = "string_util", srcs = ["string_util.cc"], hdrs = ["string_util.h"], - deps = ["//tensorflow/core:lib"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@icu//:common", + ], ) STRING_DEPS = [ diff --git a/tensorflow/core/kernels/string_util.cc b/tensorflow/core/kernels/string_util.cc index 3a9803a052..92c73220d8 100644 --- a/tensorflow/core/kernels/string_util.cc +++ b/tensorflow/core/kernels/string_util.cc @@ -16,10 +16,6 @@ limitations under the License. 
#include "tensorflow/core/lib/core/errors.h" -namespace { -inline bool IsTrailByte(char x) { return static_cast<signed char>(x) < -0x40; } -} // namespace - namespace tensorflow { // Sets unit value based on str. diff --git a/tensorflow/core/kernels/string_util.h b/tensorflow/core/kernels/string_util.h index 390cf57702..d40e93ea33 100644 --- a/tensorflow/core/kernels/string_util.h +++ b/tensorflow/core/kernels/string_util.h @@ -30,6 +30,9 @@ enum class UnicodeEncoding { UTF8 }; // TODO(edloper): Add support for: UTF32_CHAR, etc. enum class CharUnit { BYTE, UTF8_CHAR }; +// Whether or not the given byte is a trailing (continuation) byte of UTF-8. +inline bool IsTrailByte(char x) { return static_cast<signed char>(x) < -0x40; } + // Sets `encoding` based on `str`. Status ParseUnicodeEncoding(const string& str, UnicodeEncoding* encoding); @@ -40,6 +43,47 @@ Status ParseCharUnit(const string& str, CharUnit* unit); // Result may be incorrect if the input string is not valid UTF-8. int32 UTF8StrLen(const string& string); +// Get the next UTF8 character position starting at the given position and +// skipping the given number of characters. Position is a byte offset, and +// should never be `null`. The function returns true if successful. However, if +// the end of the string is reached before the requested characters, then the +// position will point to the end of string and this function will return false. +template <typename T> +bool ForwardNUTF8CharPositions(const StringPiece in, + const T num_utf8_chars_to_shift, T* pos) { + const size_t size = in.size(); + T utf8_chars_counted = 0; + while (utf8_chars_counted < num_utf8_chars_to_shift && *pos < size) { + // move forward one utf-8 character; check bounds before reading in[*pos] + do { + ++*pos; + } while (*pos < size && IsTrailByte(in[*pos])); + ++utf8_chars_counted; + } + return utf8_chars_counted == num_utf8_chars_to_shift; +} + +// Get the previous UTF8 character position starting at the given position and +// skipping the given number of characters. 
Position is a byte offset with a +// positive value, relative to the beginning of the string, and should never be +// `null`. The function return true if successful. However, if the beginning of +// the string is reached before the requested character, then the position will +// point to the beginning of the string and this function will return false. +template <typename T> +bool BackNUTF8CharPositions(const StringPiece in, + const T num_utf8_chars_to_shift, T* pos) { + const size_t start = 0; + T utf8_chars_counted = 0; + while (utf8_chars_counted < num_utf8_chars_to_shift && (*pos > start)) { + // move back one utf-8 character + do { + --*pos; + } while (IsTrailByte(in[*pos]) && *pos > start); + ++utf8_chars_counted; + } + return utf8_chars_counted == num_utf8_chars_to_shift; +} + } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_ diff --git a/tensorflow/core/kernels/substr_op.cc b/tensorflow/core/kernels/substr_op.cc index 07f1d6e767..93c427039d 100644 --- a/tensorflow/core/kernels/substr_op.cc +++ b/tensorflow/core/kernels/substr_op.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/string_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/types.h" @@ -37,7 +38,11 @@ namespace tensorflow { template <typename T> class SubstrOp : public OpKernel { public: - using OpKernel::OpKernel; + explicit SubstrOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + string unit; + OP_REQUIRES_OK(ctx, ctx->GetAttr("unit", &unit)); + OP_REQUIRES_OK(ctx, ParseCharUnit(unit, &unit_)); + } void Compute(OpKernelContext* context) override { // Get inputs @@ -69,11 +74,23 @@ class SubstrOp : public OpKernel { tensorflow::internal::SubtleMustCopy(len_tensor.scalar<T>()()); for (size_t i = 0; i < input_tensor.NumElements(); ++i) { StringPiece in(input(i)); - OP_REQUIRES( - context, FastBoundsCheck(std::abs(pos), in.size() + 1), - errors::InvalidArgument("pos ", pos, " out of range for string", - "b'", in, "' at index ", i)); - StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len); + T byte_pos = pos; + T byte_len = len; + switch (unit_) { + case CharUnit::UTF8_CHAR: + OP_REQUIRES( + context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len), + errors::InvalidArgument("pos ", pos, " out of range for ", + "string at index ", i)); + break; + case CharUnit::BYTE: + byte_pos = AdjustedPosIndex(byte_pos, in); + OP_REQUIRES( + context, FastBoundsCheck(byte_pos, in.size() + 1), + errors::InvalidArgument("pos ", pos, " out of range for ", + "string b'", in, "' at index ", i)); + } + StringPiece sub_in = in.substr(byte_pos, byte_len); output(i).assign(sub_in.data(), sub_in.size()); } } else { @@ -84,11 +101,23 @@ class SubstrOp : public OpKernel { StringPiece in(input(i)); const T pos = tensorflow::internal::SubtleMustCopy(pos_flat(i)); const T len = tensorflow::internal::SubtleMustCopy(len_flat(i)); - 
OP_REQUIRES( - context, FastBoundsCheck(std::abs(pos), in.size() + 1), - errors::InvalidArgument("pos ", pos, " out of range for string", - "b'", in, "' at index ", i)); - StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len); + T byte_pos = pos; + T byte_len = len; + switch (unit_) { + case CharUnit::UTF8_CHAR: + OP_REQUIRES( + context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len), + errors::InvalidArgument("pos ", pos, " out of range for ", + "string at index ", i)); + break; + case CharUnit::BYTE: + byte_pos = AdjustedPosIndex(byte_pos, in); + OP_REQUIRES( + context, FastBoundsCheck(byte_pos, in.size() + 1), + errors::InvalidArgument("pos ", pos, " out of range for ", + "string b'", in, "' at index ", i)); + } + StringPiece sub_in = in.substr(byte_pos, byte_len); output(i).assign(sub_in.data(), sub_in.size()); } } @@ -151,12 +180,24 @@ class SubstrOp : public OpKernel { StringPiece in(input_bcast(i)); const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i)); const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i)); - OP_REQUIRES( - context, - FastBoundsCheck(std::abs(pos), input_bcast(i).size() + 1), - errors::InvalidArgument("pos ", pos, " out of range for string", - "b'", in, "' at index ", i)); - StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len); + T byte_pos = pos; + T byte_len = len; + switch (unit_) { + case CharUnit::UTF8_CHAR: + OP_REQUIRES( + context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len), + errors::InvalidArgument("pos ", pos, " out of range for ", + "string at index ", i)); + break; + case CharUnit::BYTE: + byte_pos = AdjustedPosIndex(byte_pos, in); + OP_REQUIRES( + context, + FastBoundsCheck(byte_pos, input_bcast(i).size() + 1), + errors::InvalidArgument("pos ", pos, " out of range for ", + "string b'", in, "' at index ", i)); + } + StringPiece sub_in = in.substr(byte_pos, byte_len); output(i).assign(sub_in.data(), sub_in.size()); } break; @@ -205,12 +246,24 @@ class SubstrOp : public OpKernel { 
tensorflow::internal::SubtleMustCopy(pos_bcast(i, j)); const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i, j)); - OP_REQUIRES( - context, FastBoundsCheck(std::abs(pos), in.size() + 1), - errors::InvalidArgument("pos ", pos, " out of range for ", - "string b'", in, "' at index (", i, - ", ", j, ")")); - StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len); + T byte_pos = pos; + T byte_len = len; + switch (unit_) { + case CharUnit::UTF8_CHAR: + OP_REQUIRES( + context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len), + errors::InvalidArgument("pos ", pos, " out of range for ", + "string at index ", i)); + break; + case CharUnit::BYTE: + byte_pos = AdjustedPosIndex(byte_pos, in); + OP_REQUIRES( + context, FastBoundsCheck(byte_pos, in.size() + 1), + errors::InvalidArgument("pos ", pos, " out of range for ", + "string b'", in, "' at index (", + i, ", ", j, ")")); + } + StringPiece sub_in = in.substr(byte_pos, byte_len); output(i, j).assign(sub_in.data(), sub_in.size()); } } @@ -227,12 +280,73 @@ class SubstrOp : public OpKernel { private: // This adjusts the requested position. Note it does not perform any bound // checks. - T AdjustedPosIndex(const T pos_requested, const StringPiece s) { + static inline T AdjustedPosIndex(const T pos_requested, const StringPiece s) { if (pos_requested < 0) { return s.size() + pos_requested; } return pos_requested; } + + // Return true if successful; otherwise, return false if the `pos` argument + // is out of range in the string. + static inline bool UpdatePosAndLenForUtf8(const StringPiece in, T* pos, + T* len) { + if (*pos >= 0) { + return UpdatePositivePosAndLenForUtf8(in, *pos, *len, pos, len); + } else { + return UpdateNegativePosAndLenForUtf8(in, *pos, *len, pos, len); + } + } + + static bool UpdatePositivePosAndLenForUtf8(const StringPiece in, const T pos, + const T len, T* char_pos, + T* char_len) { + *char_pos = 0; + // Determine byte position of the substring start. 
+ if (!ForwardNUTF8CharPositions(in, pos, char_pos)) { + return false; + } + // Determine position of the end of the substring. + // The length will be capped at the end of the string, and we ignore whether + // the string had enough characters to handle it or not. + *char_len = *char_pos; + ForwardNUTF8CharPositions(in, len, char_len); + // The length in bytes is the position end of the substring less the start. + *char_len = *char_len - *char_pos; + return true; + } + + // This function expects a negative position relative to the end of the + // string, but will update the character position to a positive number + // relative to the beginning of the string. + static bool UpdateNegativePosAndLenForUtf8(const StringPiece in, const T pos, + const T len, T* char_pos, + T* char_len) { + // Initially treat the length as position of the end of the substring. + *char_len = in.size(); + // This is the number of character to skip from the end of the string to + // arrive at the position where the substring should end. + T utf8_chars_to_skip = -pos - len; + if (utf8_chars_to_skip < 0) { + utf8_chars_to_skip = 0; + } + // Find the byte position where the substring should end using the computed + // number of characters to skip. + if (!BackNUTF8CharPositions(in, utf8_chars_to_skip, char_len)) { + return false; + } + // Next, determine where the substring should begin. The number of chars to + // skip is the requested position minus the chars we've previously skipped. + *char_pos = *char_len; + if (!BackNUTF8CharPositions(in, -pos - utf8_chars_to_skip, char_pos)) { + return false; + } + // The length in bytes is the position end of the substring less the start. 
+ *char_len = *char_len - *char_pos; + return true; + } + + CharUnit unit_ = CharUnit::BYTE; }; #define REGISTER_SUBSTR(type) \ diff --git a/tensorflow/core/kernels/substr_op_test.cc b/tensorflow/core/kernels/substr_op_test.cc index 2e07050260..ea6b1ed500 100644 --- a/tensorflow/core/kernels/substr_op_test.cc +++ b/tensorflow/core/kernels/substr_op_test.cc @@ -42,7 +42,7 @@ limitations under the License. namespace tensorflow { // Test data from the TensorFlow README.md. -const char* lines[] = { +const char* ascii_lines[] = { "**TensorFlow** is an open source software library for numerical " "computation using data flow graphs.", "The graph nodes represent mathematical operations, while the graph edges " @@ -64,17 +64,76 @@ const char* lines[] = { "backwards compatibility guarantee like C++, Go, Java, JavaScript and " "Swift."}; +const char* unicode_lines[] = { + "TensorFlow\xe6\x98\xaf\xe4\xb8\x80\xe4\xb8\xaa\xe4\xbd\xbf\xe7\x94\xa8\xe6" + "\x95\xb0\xe6\x8d\xae\xe6\xb5\x81\xe5\x9b\xbe\xe8\xbf\x9b\xe8\xa1\x8c\xe6" + "\x95\xb0\xe5\x80\xbc\xe8\xae\xa1\xe7\xae\x97\xe7\x9a\x84\xe5\xbc\x80\xe6" + "\xba\x90\xe8\xbd\xaf\xe4\xbb\xb6\xe5\xba\x93\xe3\x80\x82", + "\xe5\x9b\xbe\xe5\xbd\xa2\xe8\x8a\x82\xe7\x82\xb9\xe8\xa1\xa8\xe7\xa4\xba" + "\xe6\x95\xb0\xe5\xad\xa6\xe8\xbf\x90\xe7\xae\x97\xef\xbc\x8c\xe8\x80\x8c" + "\xe5\x9b\xbe\xe5\xbd\xa2\xe8\xbe\xb9\xe7\xbc\x98\xe8\xa1\xa8\xe7\xa4\xba" + "\xe5\x9c\xa8\xe5\xae\x83\xe4\xbb\xac\xe4\xb9\x8b\xe9\x97\xb4\xe6\xb5\x81" + "\xe5\x8a\xa8\xe7\x9a\x84\xe5\xa4\x9a\xe7\xbb\xb4\xe6\x95\xb0\xe6\x8d\xae" + "\xe9\x98\xb5\xe5\x88\x97\xef\xbc\x88\xe5\xbc\xa0\xe9\x87\x8f\xef\xbc\x89" + "\xe3\x80\x82", + "\xe8\xbf\x99\xe7\xa7\x8d\xe7\x81\xb5\xe6\xb4\xbb\xe7\x9a\x84\xe4\xbd\x93" + "\xe7\xb3\xbb\xe7\xbb\x93\xe6\x9e\x84\xe4\xbd\xbf\xe6\x82\xa8\xe5\x8f\xaf" + "\xe4\xbb\xa5\xe5\xb0\x86\xe8\xae\xa1\xe7\xae\x97\xe9\x83\xa8\xe7\xbd\xb2" + "\xe5\x88\xb0\xe6\xa1\x8c\xe9\x9d\xa2\xef\xbc\x8c\xe6\x9c\x8d\xe5\x8a\xa1" + 
"\xe5\x99\xa8\xe6\x88\x96\xe7\xa7\xbb\xe5\x8a\xa8\xe8\xae\xbe\xe5\xa4\x87" + "\xe4\xb8\xad\xe7\x9a\x84\xe4\xb8\x80\xe4\xb8\xaa\xe6\x88\x96\xe5\xa4\x9a" + "\xe4\xb8\xaa CPU\xe6\x88\x96GPU\xef\xbc\x8c\xe8\x80\x8c\xe6\x97\xa0\xe9" + "\x9c\x80\xe9\x87\x8d\xe5\x86\x99\xe4\xbb\xa3\xe7\xa0\x81\xe3\x80\x82", + "TensorFlow\xe8\xbf\x98\xe5\x8c\x85\xe6\x8b\xac[TensorBoard]\xef\xbc\x88" + "https://www.tensorflow.org/guide/summaries_and_tensorboard\xef\xbc\x89\xef" + "\xbc\x8c\xe8\xbf\x99\xe6\x98\xaf\xe4\xb8\x80\xe4\xb8\xaa\xe6\x95\xb0\xe6" + "\x8d\xae\xe5\x8f\xaf\xe8\xa7\x86\xe5\x8c\x96\xe5\xb7\xa5\xe5\x85\xb7\xe5" + "\x8c\x85\xe3\x80\x82", + "TensorFlow\xe6\x9c\x80\xe5\x88\x9d\xe6\x98\xaf\xe7\x94\xb1\xe7\xa0\x94\xe7" + "\xa9\xb6\xe4\xba\xba\xe5\x91\x98\xe5\x92\x8c\xe5\xb7\xa5\xe7\xa8\x8b\xe5" + "\xb8\x88\xe5\x9c\xa8Google\xe6\x9c\xba\xe5\x99\xa8\xe6\x99\xba\xe8\x83\xbd" + "\xe7\xa0\x94\xe7\xa9\xb6\xe7\xbb\x84\xe7\xbb\x87\xe7\x9a\x84Google Brain" + "\xe5\x9b\xa2\xe9\x98\x9f\xe5\xbc\x80\xe5\x8f\x91\xe7\x9a\x84\xef\xbc\x8c" + "\xe7\x9b\xae\xe7\x9a\x84\xe6\x98\xaf\xe8\xbf\x9b\xe8\xa1\x8c\xe6\x9c\xba" + "\xe5\x99\xa8\xe5\xad\xa6\xe4\xb9\xa0\xe5\x92\x8c\xe6\xb7\xb1\xe5\xba\xa6" + "\xe7\xa5\x9e\xe7\xbb\x8f\xe7\xbd\x91\xe7\xbb\x9c\xe7\xa0\x94\xe7\xa9\xb6" + "\xe3\x80\x82", + "\xe8\xaf\xa5\xe7\xb3\xbb\xe7\xbb\x9f\xe8\xb6\xb3\xe4\xbb\xa5\xe9\x80\x82" + "\xe7\x94\xa8\xe4\xba\x8e\xe5\x90\x84\xe7\xa7\x8d\xe5\x85\xb6\xe4\xbb\x96" + "\xe9\xa2\x86\xe5\x9f\x9f\xe4\xb9\x9f\xe6\x98\xaf\xe5\xa6\x82\xe6\xad\xa4" + "\xe3\x80\x82", + "TensorFlow\xe6\x8f\x90\xe4\xbe\x9b\xe7\xa8\xb3\xe5\xae\x9a\xe7\x9a\x84" + "Python API\xe5\x92\x8c C API\xef\xbc\x8c\xe4\xbb\xa5\xe5\x8f\x8a\xe6\xb2" + "\xa1\xe6\x9c\x89 API\xe5\x90\x91\xe5\x90\x8e\xe5\x85\xbc\xe5\xae\xb9\xe6" + "\x80\xa7\xe4\xbf\x9d\xe8\xaf\x81\xef\xbc\x8c\xe5\xa6\x82 C ++\xef\xbc\x8c" + "Go\xef\xbc\x8cJava\xef\xbc\x8cJavaScript\xe5\x92\x8cSwift\xe3\x80\x82", +}; + +const char* const kByteUnit = "BYTE"; +const char* const kUTF8Unit = "UTF8_CHAR"; + 
Tensor GetTestTensor(int batch) { - const int sz = TF_ARRAYSIZE(lines); + const int sz = TF_ARRAYSIZE(ascii_lines); + Tensor t(DT_STRING, {batch}); + auto s = t.flat<string>(); + for (int i = 0; i < batch; ++i) { + s(i) = ascii_lines[i % sz]; + } + return t; +} + +Tensor GetTestUTF8Tensor(int batch) { + const int sz = TF_ARRAYSIZE(unicode_lines); Tensor t(DT_STRING, {batch}); auto s = t.flat<string>(); for (int i = 0; i < batch; ++i) { - s(i) = lines[i % sz]; + s(i) = unicode_lines[i % sz]; } return t; } -Graph* SetupSubstrGraph(const Tensor& input, const int32 pos, const int32 len) { +Graph* SetupSubstrGraph(const Tensor& input, const int32 pos, const int32 len, + const char* const unit) { Graph* g = new Graph(OpRegistry::Global()); Tensor position(DT_INT32, TensorShape({})); position.flat<int32>().setConstant(pos); @@ -85,21 +144,46 @@ Graph* SetupSubstrGraph(const Tensor& input, const int32 pos, const int32 len) { .Input(test::graph::Constant(g, input)) .Input(test::graph::Constant(g, position)) .Input(test::graph::Constant(g, length)) + .Attr("unit", unit) .Finalize(g, nullptr /* node */)); return g; } -void BM_Substr(int iters, int batch_size) { +void BM_SubstrByte(int iters, int batch_size) { testing::StopTiming(); testing::ItemsProcessed(static_cast<int64>(iters)); testing::UseRealTime(); Tensor input = GetTestTensor(batch_size); - Graph* g = SetupSubstrGraph(input, 3, 30); + Graph* g = SetupSubstrGraph(input, 3, 30, kByteUnit); + testing::StartTiming(); + test::Benchmark("cpu", g).Run(iters); +} + +void BM_SubstrUTF8(int iters, int batch_size) { + testing::StopTiming(); + testing::ItemsProcessed(static_cast<int64>(iters)); + testing::UseRealTime(); + Tensor input = GetTestUTF8Tensor(batch_size); + Graph* g = SetupSubstrGraph(input, 3, 30, kUTF8Unit); testing::StartTiming(); test::Benchmark("cpu", g).Run(iters); } -BENCHMARK(BM_Substr)->Arg(1)->Arg(8)->Arg(16)->Arg(32)->Arg(64)->Arg(128)->Arg( - 256); +BENCHMARK(BM_SubstrByte) + ->Arg(1) + ->Arg(8) + 
->Arg(16) + ->Arg(32) + ->Arg(64) + ->Arg(128) + ->Arg(256); +BENCHMARK(BM_SubstrUTF8) + ->Arg(1) + ->Arg(8) + ->Arg(16) + ->Arg(32) + ->Arg(64) + ->Arg(128) + ->Arg(256); } // end namespace tensorflow diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc index b4fbde54d9..94d71a4113 100644 --- a/tensorflow/core/ops/string_ops.cc +++ b/tensorflow/core/ops/string_ops.cc @@ -223,6 +223,7 @@ REGISTER_OP("Substr") .Input("len: T") .Output("output: string") .Attr("T: {int32, int64}") + .Attr("unit: {'BYTE', 'UTF8_CHAR'} = 'BYTE'") .SetShapeFn([](InferenceContext* c) { ShapeHandle pos_shape = c->input(1); ShapeHandle len_shape = c->input(2); |