diff options
-rw-r--r-- | tensorflow/core/api_def/base_api/api_def_Substr.pbtxt | 10 | ||||
-rw-r--r-- | tensorflow/core/api_def/python_api/api_def_Substr.pbtxt | 8 | ||||
-rw-r--r-- | tensorflow/core/kernels/BUILD | 7 | ||||
-rw-r--r-- | tensorflow/core/kernels/string_util.cc | 4 | ||||
-rw-r--r-- | tensorflow/core/kernels/string_util.h | 44 | ||||
-rw-r--r-- | tensorflow/core/kernels/substr_op.cc | 162 | ||||
-rw-r--r-- | tensorflow/core/kernels/substr_op_test.cc | 100 | ||||
-rw-r--r-- | tensorflow/core/ops/string_ops.cc | 1 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/substr_op_test.py | 503 | ||||
-rw-r--r-- | tensorflow/python/ops/string_ops.py | 16 | ||||
-rw-r--r-- | tensorflow/tools/api/golden/v1/tensorflow.pbtxt | 2 | ||||
-rw-r--r-- | tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt | 2 | ||||
-rw-r--r-- | tensorflow/tools/api/golden/v2/tensorflow.pbtxt | 2 | ||||
-rw-r--r-- | tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt | 2 |
14 files changed, 655 insertions, 208 deletions
diff --git a/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt b/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt index 5246090ab3..fe0fcc9508 100644 --- a/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt @@ -18,6 +18,16 @@ END Scalar defining the number of characters to include in each substring END } + attr { + name: "unit" + description: <<END +The unit that is used to create the substring. One of: `"BYTE"` (for +defining position and length by bytes) or `"UTF8_CHAR"` (for the UTF-8 +encoded Unicode code points). The default is `"BYTE"`. Results are undefined if +`unit=UTF8_CHAR` and the `input` strings do not contain structurally valid +UTF-8. +END + } out_arg { name: "output" description: <<END diff --git a/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt b/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt index 4778d7927c..4fb9ee56e9 100644 --- a/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt @@ -1,10 +1,4 @@ op { graph_op_name: "Substr" - endpoint { - name: "strings.substr" - } - endpoint { - name: "substr" - deprecated: true - } + visibility: HIDDEN } diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 9439ab332c..3a920f26f3 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -4458,7 +4458,12 @@ cc_library( name = "string_util", srcs = ["string_util.cc"], hdrs = ["string_util.h"], - deps = ["//tensorflow/core:lib"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@icu//:common", + ], ) STRING_DEPS = [ diff --git a/tensorflow/core/kernels/string_util.cc b/tensorflow/core/kernels/string_util.cc index 3a9803a052..92c73220d8 100644 --- a/tensorflow/core/kernels/string_util.cc +++ b/tensorflow/core/kernels/string_util.cc @@ -16,10 +16,6 @@ limitations under the License. 
#include "tensorflow/core/lib/core/errors.h" -namespace { -inline bool IsTrailByte(char x) { return static_cast<signed char>(x) < -0x40; } -} // namespace - namespace tensorflow { // Sets unit value based on str. diff --git a/tensorflow/core/kernels/string_util.h b/tensorflow/core/kernels/string_util.h index 390cf57702..d40e93ea33 100644 --- a/tensorflow/core/kernels/string_util.h +++ b/tensorflow/core/kernels/string_util.h @@ -30,6 +30,9 @@ enum class UnicodeEncoding { UTF8 }; // TODO(edloper): Add support for: UTF32_CHAR, etc. enum class CharUnit { BYTE, UTF8_CHAR }; +// Whether the given byte is a UTF-8 trail (continuation) byte, i.e. 0x80-0xBF. +inline bool IsTrailByte(char x) { return static_cast<signed char>(x) < -0x40; } + // Sets `encoding` based on `str`. Status ParseUnicodeEncoding(const string& str, UnicodeEncoding* encoding); @@ -40,6 +43,47 @@ Status ParseCharUnit(const string& str, CharUnit* unit); // Result may be incorrect if the input string is not valid UTF-8. int32 UTF8StrLen(const string& string); +// Get the next UTF8 character position starting at the given position and +// skipping the given number of characters. Position is a byte offset, and +// should never be `null`. The function returns true if successful. However, if +// the end of the string is reached before the requested characters, then the +// position will point to the end of string and this function will return false. +template <typename T> +bool ForwardNUTF8CharPositions(const StringPiece in, + const T num_utf8_chars_to_shift, T* pos) { + const size_t size = in.size(); + T utf8_chars_counted = 0; + while (utf8_chars_counted < num_utf8_chars_to_shift && *pos < size) { + // Move forward one UTF-8 character; check bounds before reading in[*pos]. + do { + ++*pos; + } while (*pos < size && IsTrailByte(in[*pos])); + ++utf8_chars_counted; + } + return utf8_chars_counted == num_utf8_chars_to_shift; +} + +// Get the previous UTF8 character position starting at the given position and +// skipping the given number of characters.
Position is a byte offset with a +// positive value, relative to the beginning of the string, and should never be +// `null`. The function returns true if successful. However, if the beginning of +// the string is reached before the requested character, then the position will +// point to the beginning of the string and this function will return false. +template <typename T> +bool BackNUTF8CharPositions(const StringPiece in, + const T num_utf8_chars_to_shift, T* pos) { + const size_t start = 0; + T utf8_chars_counted = 0; + while (utf8_chars_counted < num_utf8_chars_to_shift && (*pos > start)) { + // move back one utf-8 character + do { + --*pos; + } while (IsTrailByte(in[*pos]) && *pos > start); + ++utf8_chars_counted; + } + return utf8_chars_counted == num_utf8_chars_to_shift; +} + } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_ diff --git a/tensorflow/core/kernels/substr_op.cc b/tensorflow/core/kernels/substr_op.cc index 07f1d6e767..93c427039d 100644 --- a/tensorflow/core/kernels/substr_op.cc +++ b/tensorflow/core/kernels/substr_op.cc @@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/string_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/types.h" @@ -37,7 +38,11 @@ namespace tensorflow { template <typename T> class SubstrOp : public OpKernel { public: - using OpKernel::OpKernel; + explicit SubstrOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + string unit; + OP_REQUIRES_OK(ctx, ctx->GetAttr("unit", &unit)); + OP_REQUIRES_OK(ctx, ParseCharUnit(unit, &unit_)); + } void Compute(OpKernelContext* context) override { // Get inputs @@ -69,11 +74,23 @@ class SubstrOp : public OpKernel { tensorflow::internal::SubtleMustCopy(len_tensor.scalar<T>()()); for (size_t i = 0; i < input_tensor.NumElements(); ++i) { StringPiece in(input(i)); - OP_REQUIRES( - context, FastBoundsCheck(std::abs(pos), in.size() + 1), - errors::InvalidArgument("pos ", pos, " out of range for string", - "b'", in, "' at index ", i)); - StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len); + T byte_pos = pos; + T byte_len = len; + switch (unit_) { + case CharUnit::UTF8_CHAR: + OP_REQUIRES( + context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len), + errors::InvalidArgument("pos ", pos, " out of range for ", + "string at index ", i)); + break; + case CharUnit::BYTE: + byte_pos = AdjustedPosIndex(byte_pos, in); + OP_REQUIRES( + context, FastBoundsCheck(byte_pos, in.size() + 1), + errors::InvalidArgument("pos ", pos, " out of range for ", + "string b'", in, "' at index ", i)); + } + StringPiece sub_in = in.substr(byte_pos, byte_len); output(i).assign(sub_in.data(), sub_in.size()); } } else { @@ -84,11 +101,23 @@ class SubstrOp : public OpKernel { StringPiece in(input(i)); const T pos = tensorflow::internal::SubtleMustCopy(pos_flat(i)); const T len = tensorflow::internal::SubtleMustCopy(len_flat(i)); - 
OP_REQUIRES( - context, FastBoundsCheck(std::abs(pos), in.size() + 1), - errors::InvalidArgument("pos ", pos, " out of range for string", - "b'", in, "' at index ", i)); - StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len); + T byte_pos = pos; + T byte_len = len; + switch (unit_) { + case CharUnit::UTF8_CHAR: + OP_REQUIRES( + context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len), + errors::InvalidArgument("pos ", pos, " out of range for ", + "string at index ", i)); + break; + case CharUnit::BYTE: + byte_pos = AdjustedPosIndex(byte_pos, in); + OP_REQUIRES( + context, FastBoundsCheck(byte_pos, in.size() + 1), + errors::InvalidArgument("pos ", pos, " out of range for ", + "string b'", in, "' at index ", i)); + } + StringPiece sub_in = in.substr(byte_pos, byte_len); output(i).assign(sub_in.data(), sub_in.size()); } } @@ -151,12 +180,24 @@ class SubstrOp : public OpKernel { StringPiece in(input_bcast(i)); const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i)); const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i)); - OP_REQUIRES( - context, - FastBoundsCheck(std::abs(pos), input_bcast(i).size() + 1), - errors::InvalidArgument("pos ", pos, " out of range for string", - "b'", in, "' at index ", i)); - StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len); + T byte_pos = pos; + T byte_len = len; + switch (unit_) { + case CharUnit::UTF8_CHAR: + OP_REQUIRES( + context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len), + errors::InvalidArgument("pos ", pos, " out of range for ", + "string at index ", i)); + break; + case CharUnit::BYTE: + byte_pos = AdjustedPosIndex(byte_pos, in); + OP_REQUIRES( + context, + FastBoundsCheck(byte_pos, input_bcast(i).size() + 1), + errors::InvalidArgument("pos ", pos, " out of range for ", + "string b'", in, "' at index ", i)); + } + StringPiece sub_in = in.substr(byte_pos, byte_len); output(i).assign(sub_in.data(), sub_in.size()); } break; @@ -205,12 +246,24 @@ class SubstrOp : public OpKernel { 
tensorflow::internal::SubtleMustCopy(pos_bcast(i, j)); const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i, j)); - OP_REQUIRES( - context, FastBoundsCheck(std::abs(pos), in.size() + 1), - errors::InvalidArgument("pos ", pos, " out of range for ", - "string b'", in, "' at index (", i, - ", ", j, ")")); - StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len); + T byte_pos = pos; + T byte_len = len; + switch (unit_) { + case CharUnit::UTF8_CHAR: + OP_REQUIRES( + context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len), + errors::InvalidArgument("pos ", pos, " out of range for ", + "string at index (", i, ", ", j, ")")); + break; + case CharUnit::BYTE: + byte_pos = AdjustedPosIndex(byte_pos, in); + OP_REQUIRES( + context, FastBoundsCheck(byte_pos, in.size() + 1), + errors::InvalidArgument("pos ", pos, " out of range for ", + "string b'", in, "' at index (", + i, ", ", j, ")")); + } + StringPiece sub_in = in.substr(byte_pos, byte_len); output(i, j).assign(sub_in.data(), sub_in.size()); } } @@ -227,12 +280,73 @@ class SubstrOp : public OpKernel { private: // This adjusts the requested position. Note it does not perform any bound // checks. - T AdjustedPosIndex(const T pos_requested, const StringPiece s) { + static inline T AdjustedPosIndex(const T pos_requested, const StringPiece s) { if (pos_requested < 0) { return s.size() + pos_requested; } return pos_requested; } + + // Return true if successful; otherwise, return false if the `pos` argument + // is out of range in the string. + static inline bool UpdatePosAndLenForUtf8(const StringPiece in, T* pos, + T* len) { + if (*pos >= 0) { + return UpdatePositivePosAndLenForUtf8(in, *pos, *len, pos, len); + } else { + return UpdateNegativePosAndLenForUtf8(in, *pos, *len, pos, len); + } + } + + static bool UpdatePositivePosAndLenForUtf8(const StringPiece in, const T pos, + const T len, T* char_pos, + T* char_len) { + *char_pos = 0; + // Determine byte position of the substring start.
+ if (!ForwardNUTF8CharPositions(in, pos, char_pos)) { + return false; + } + // Determine position of the end of the substring. + // The length will be capped at the end of the string, and we ignore whether + // the string had enough characters to handle it or not. + *char_len = *char_pos; + ForwardNUTF8CharPositions(in, len, char_len); + // The length in bytes is the position end of the substring less the start. + *char_len = *char_len - *char_pos; + return true; + } + + // This function expects a negative position relative to the end of the + // string, but will update the character position to a positive number + // relative to the beginning of the string. + static bool UpdateNegativePosAndLenForUtf8(const StringPiece in, const T pos, + const T len, T* char_pos, + T* char_len) { + // Initially treat the length as position of the end of the substring. + *char_len = in.size(); + // This is the number of character to skip from the end of the string to + // arrive at the position where the substring should end. + T utf8_chars_to_skip = -pos - len; + if (utf8_chars_to_skip < 0) { + utf8_chars_to_skip = 0; + } + // Find the byte position where the substring should end using the computed + // number of characters to skip. + if (!BackNUTF8CharPositions(in, utf8_chars_to_skip, char_len)) { + return false; + } + // Next, determine where the substring should begin. The number of chars to + // skip is the requested position minus the chars we've previously skipped. + *char_pos = *char_len; + if (!BackNUTF8CharPositions(in, -pos - utf8_chars_to_skip, char_pos)) { + return false; + } + // The length in bytes is the position end of the substring less the start. 
+ *char_len = *char_len - *char_pos; + return true; + } + + CharUnit unit_ = CharUnit::BYTE; }; #define REGISTER_SUBSTR(type) \ diff --git a/tensorflow/core/kernels/substr_op_test.cc b/tensorflow/core/kernels/substr_op_test.cc index 2e07050260..ea6b1ed500 100644 --- a/tensorflow/core/kernels/substr_op_test.cc +++ b/tensorflow/core/kernels/substr_op_test.cc @@ -42,7 +42,7 @@ limitations under the License. namespace tensorflow { // Test data from the TensorFlow README.md. -const char* lines[] = { +const char* ascii_lines[] = { "**TensorFlow** is an open source software library for numerical " "computation using data flow graphs.", "The graph nodes represent mathematical operations, while the graph edges " @@ -64,17 +64,76 @@ const char* lines[] = { "backwards compatibility guarantee like C++, Go, Java, JavaScript and " "Swift."}; +const char* unicode_lines[] = { + "TensorFlow\xe6\x98\xaf\xe4\xb8\x80\xe4\xb8\xaa\xe4\xbd\xbf\xe7\x94\xa8\xe6" + "\x95\xb0\xe6\x8d\xae\xe6\xb5\x81\xe5\x9b\xbe\xe8\xbf\x9b\xe8\xa1\x8c\xe6" + "\x95\xb0\xe5\x80\xbc\xe8\xae\xa1\xe7\xae\x97\xe7\x9a\x84\xe5\xbc\x80\xe6" + "\xba\x90\xe8\xbd\xaf\xe4\xbb\xb6\xe5\xba\x93\xe3\x80\x82", + "\xe5\x9b\xbe\xe5\xbd\xa2\xe8\x8a\x82\xe7\x82\xb9\xe8\xa1\xa8\xe7\xa4\xba" + "\xe6\x95\xb0\xe5\xad\xa6\xe8\xbf\x90\xe7\xae\x97\xef\xbc\x8c\xe8\x80\x8c" + "\xe5\x9b\xbe\xe5\xbd\xa2\xe8\xbe\xb9\xe7\xbc\x98\xe8\xa1\xa8\xe7\xa4\xba" + "\xe5\x9c\xa8\xe5\xae\x83\xe4\xbb\xac\xe4\xb9\x8b\xe9\x97\xb4\xe6\xb5\x81" + "\xe5\x8a\xa8\xe7\x9a\x84\xe5\xa4\x9a\xe7\xbb\xb4\xe6\x95\xb0\xe6\x8d\xae" + "\xe9\x98\xb5\xe5\x88\x97\xef\xbc\x88\xe5\xbc\xa0\xe9\x87\x8f\xef\xbc\x89" + "\xe3\x80\x82", + "\xe8\xbf\x99\xe7\xa7\x8d\xe7\x81\xb5\xe6\xb4\xbb\xe7\x9a\x84\xe4\xbd\x93" + "\xe7\xb3\xbb\xe7\xbb\x93\xe6\x9e\x84\xe4\xbd\xbf\xe6\x82\xa8\xe5\x8f\xaf" + "\xe4\xbb\xa5\xe5\xb0\x86\xe8\xae\xa1\xe7\xae\x97\xe9\x83\xa8\xe7\xbd\xb2" + "\xe5\x88\xb0\xe6\xa1\x8c\xe9\x9d\xa2\xef\xbc\x8c\xe6\x9c\x8d\xe5\x8a\xa1" + 
"\xe5\x99\xa8\xe6\x88\x96\xe7\xa7\xbb\xe5\x8a\xa8\xe8\xae\xbe\xe5\xa4\x87" + "\xe4\xb8\xad\xe7\x9a\x84\xe4\xb8\x80\xe4\xb8\xaa\xe6\x88\x96\xe5\xa4\x9a" + "\xe4\xb8\xaa CPU\xe6\x88\x96GPU\xef\xbc\x8c\xe8\x80\x8c\xe6\x97\xa0\xe9" + "\x9c\x80\xe9\x87\x8d\xe5\x86\x99\xe4\xbb\xa3\xe7\xa0\x81\xe3\x80\x82", + "TensorFlow\xe8\xbf\x98\xe5\x8c\x85\xe6\x8b\xac[TensorBoard]\xef\xbc\x88" + "https://www.tensorflow.org/guide/summaries_and_tensorboard\xef\xbc\x89\xef" + "\xbc\x8c\xe8\xbf\x99\xe6\x98\xaf\xe4\xb8\x80\xe4\xb8\xaa\xe6\x95\xb0\xe6" + "\x8d\xae\xe5\x8f\xaf\xe8\xa7\x86\xe5\x8c\x96\xe5\xb7\xa5\xe5\x85\xb7\xe5" + "\x8c\x85\xe3\x80\x82", + "TensorFlow\xe6\x9c\x80\xe5\x88\x9d\xe6\x98\xaf\xe7\x94\xb1\xe7\xa0\x94\xe7" + "\xa9\xb6\xe4\xba\xba\xe5\x91\x98\xe5\x92\x8c\xe5\xb7\xa5\xe7\xa8\x8b\xe5" + "\xb8\x88\xe5\x9c\xa8Google\xe6\x9c\xba\xe5\x99\xa8\xe6\x99\xba\xe8\x83\xbd" + "\xe7\xa0\x94\xe7\xa9\xb6\xe7\xbb\x84\xe7\xbb\x87\xe7\x9a\x84Google Brain" + "\xe5\x9b\xa2\xe9\x98\x9f\xe5\xbc\x80\xe5\x8f\x91\xe7\x9a\x84\xef\xbc\x8c" + "\xe7\x9b\xae\xe7\x9a\x84\xe6\x98\xaf\xe8\xbf\x9b\xe8\xa1\x8c\xe6\x9c\xba" + "\xe5\x99\xa8\xe5\xad\xa6\xe4\xb9\xa0\xe5\x92\x8c\xe6\xb7\xb1\xe5\xba\xa6" + "\xe7\xa5\x9e\xe7\xbb\x8f\xe7\xbd\x91\xe7\xbb\x9c\xe7\xa0\x94\xe7\xa9\xb6" + "\xe3\x80\x82", + "\xe8\xaf\xa5\xe7\xb3\xbb\xe7\xbb\x9f\xe8\xb6\xb3\xe4\xbb\xa5\xe9\x80\x82" + "\xe7\x94\xa8\xe4\xba\x8e\xe5\x90\x84\xe7\xa7\x8d\xe5\x85\xb6\xe4\xbb\x96" + "\xe9\xa2\x86\xe5\x9f\x9f\xe4\xb9\x9f\xe6\x98\xaf\xe5\xa6\x82\xe6\xad\xa4" + "\xe3\x80\x82", + "TensorFlow\xe6\x8f\x90\xe4\xbe\x9b\xe7\xa8\xb3\xe5\xae\x9a\xe7\x9a\x84" + "Python API\xe5\x92\x8c C API\xef\xbc\x8c\xe4\xbb\xa5\xe5\x8f\x8a\xe6\xb2" + "\xa1\xe6\x9c\x89 API\xe5\x90\x91\xe5\x90\x8e\xe5\x85\xbc\xe5\xae\xb9\xe6" + "\x80\xa7\xe4\xbf\x9d\xe8\xaf\x81\xef\xbc\x8c\xe5\xa6\x82 C ++\xef\xbc\x8c" + "Go\xef\xbc\x8cJava\xef\xbc\x8cJavaScript\xe5\x92\x8cSwift\xe3\x80\x82", +}; + +const char* const kByteUnit = "BYTE"; +const char* const kUTF8Unit = "UTF8_CHAR"; + 
Tensor GetTestTensor(int batch) { - const int sz = TF_ARRAYSIZE(lines); + const int sz = TF_ARRAYSIZE(ascii_lines); + Tensor t(DT_STRING, {batch}); + auto s = t.flat<string>(); + for (int i = 0; i < batch; ++i) { + s(i) = ascii_lines[i % sz]; + } + return t; +} + +Tensor GetTestUTF8Tensor(int batch) { + const int sz = TF_ARRAYSIZE(unicode_lines); Tensor t(DT_STRING, {batch}); auto s = t.flat<string>(); for (int i = 0; i < batch; ++i) { - s(i) = lines[i % sz]; + s(i) = unicode_lines[i % sz]; } return t; } -Graph* SetupSubstrGraph(const Tensor& input, const int32 pos, const int32 len) { +Graph* SetupSubstrGraph(const Tensor& input, const int32 pos, const int32 len, + const char* const unit) { Graph* g = new Graph(OpRegistry::Global()); Tensor position(DT_INT32, TensorShape({})); position.flat<int32>().setConstant(pos); @@ -85,21 +144,46 @@ Graph* SetupSubstrGraph(const Tensor& input, const int32 pos, const int32 len) { .Input(test::graph::Constant(g, input)) .Input(test::graph::Constant(g, position)) .Input(test::graph::Constant(g, length)) + .Attr("unit", unit) .Finalize(g, nullptr /* node */)); return g; } -void BM_Substr(int iters, int batch_size) { +void BM_SubstrByte(int iters, int batch_size) { testing::StopTiming(); testing::ItemsProcessed(static_cast<int64>(iters)); testing::UseRealTime(); Tensor input = GetTestTensor(batch_size); - Graph* g = SetupSubstrGraph(input, 3, 30); + Graph* g = SetupSubstrGraph(input, 3, 30, kByteUnit); + testing::StartTiming(); + test::Benchmark("cpu", g).Run(iters); +} + +void BM_SubstrUTF8(int iters, int batch_size) { + testing::StopTiming(); + testing::ItemsProcessed(static_cast<int64>(iters)); + testing::UseRealTime(); + Tensor input = GetTestUTF8Tensor(batch_size); + Graph* g = SetupSubstrGraph(input, 3, 30, kUTF8Unit); testing::StartTiming(); test::Benchmark("cpu", g).Run(iters); } -BENCHMARK(BM_Substr)->Arg(1)->Arg(8)->Arg(16)->Arg(32)->Arg(64)->Arg(128)->Arg( - 256); +BENCHMARK(BM_SubstrByte) + ->Arg(1) + ->Arg(8) + 
->Arg(16) + ->Arg(32) + ->Arg(64) + ->Arg(128) + ->Arg(256); +BENCHMARK(BM_SubstrUTF8) + ->Arg(1) + ->Arg(8) + ->Arg(16) + ->Arg(32) + ->Arg(64) + ->Arg(128) + ->Arg(256); } // end namespace tensorflow diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc index b4fbde54d9..94d71a4113 100644 --- a/tensorflow/core/ops/string_ops.cc +++ b/tensorflow/core/ops/string_ops.cc @@ -223,6 +223,7 @@ REGISTER_OP("Substr") .Input("len: T") .Output("output: string") .Attr("T: {int32, int64}") + .Attr("unit: {'BYTE', 'UTF8_CHAR'} = 'BYTE'") .SetShapeFn([](InferenceContext* c) { ShapeHandle pos_shape = c->input(1); ShapeHandle len_shape = c->input(2); diff --git a/tensorflow/python/kernel_tests/substr_op_test.py b/tensorflow/python/kernel_tests/substr_op_test.py index cd3fe14883..37aa624b07 100644 --- a/tensorflow/python/kernel_tests/substr_op_test.py +++ b/tensorflow/python/kernel_tests/substr_op_test.py @@ -28,270 +28,448 @@ from tensorflow.python.platform import test class SubstrOpTest(test.TestCase, parameterized.TestCase): - def _testScalarString(self, dtype): - test_string = b"Hello" - position = np.array(1, dtype) + @parameterized.parameters( + (np.int32, 1, "BYTE"), + (np.int64, 1, "BYTE"), + (np.int32, -4, "BYTE"), + (np.int64, -4, "BYTE"), + (np.int32, 1, "UTF8_CHAR"), + (np.int64, 1, "UTF8_CHAR"), + (np.int32, -4, "UTF8_CHAR"), + (np.int64, -4, "UTF8_CHAR"), + ) + def testScalarString(self, dtype, pos, unit): + test_string = { + "BYTE": b"Hello", + "UTF8_CHAR": u"He\xc3\xc3\U0001f604".encode("utf-8"), + }[unit] + expected_value = { + "BYTE": b"ell", + "UTF8_CHAR": u"e\xc3\xc3".encode("utf-8"), + }[unit] + position = np.array(pos, dtype) length = np.array(3, dtype) - expected_value = b"ell" - - substr_op = string_ops.substr(test_string, position, length) + substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): substr = substr_op.eval() self.assertAllEqual(substr, expected_value) - # Negative 
position. - test_string = b"Hello" - position = np.array(-4, dtype) + @parameterized.parameters( + (np.int32, "BYTE"), + (np.int64, "BYTE"), + (np.int32, "UTF8_CHAR"), + (np.int64, "UTF8_CHAR"), + ) + def testScalarString_EdgeCases(self, dtype, unit): + # Empty string + test_string = { + "BYTE": b"", + "UTF8_CHAR": u"".encode("utf-8"), + }[unit] + expected_value = b"" + position = np.array(0, dtype) length = np.array(3, dtype) - expected_value = b"ell" - - substr_op = string_ops.substr(test_string, position, length) + substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): substr = substr_op.eval() self.assertAllEqual(substr, expected_value) - # Position is equal to the length of string. - test_string = b"" + # Full string + test_string = { + "BYTE": b"Hello", + "UTF8_CHAR": u"H\xc3ll\U0001f604".encode("utf-8"), + }[unit] position = np.array(0, dtype) - length = np.array(2, dtype) - expected_value = b"" - - substr_op = string_ops.substr(test_string, position, length) + length = np.array(5, dtype) + substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): substr = substr_op.eval() - self.assertAllEqual(substr, expected_value) - - # Negative position magnitude is equal to the length of string. 
- test_string = b"yo" - position = np.array(-2, dtype) - length = np.array(1, dtype) - expected_value = b"y" - - substr_op = string_ops.substr(test_string, position, length) + self.assertAllEqual(substr, test_string) + + # Full string (Negative) + test_string = { + "BYTE": b"Hello", + "UTF8_CHAR": u"H\xc3ll\U0001f604".encode("utf-8"), + }[unit] + position = np.array(-5, dtype) + length = np.array(5, dtype) + substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): substr = substr_op.eval() - self.assertAllEqual(substr, expected_value) - - def _testVectorStrings(self, dtype): - test_string = [b"Hello", b"World"] - position = np.array(1, dtype) - length = np.array(3, dtype) - expected_value = [b"ell", b"orl"] - - substr_op = string_ops.substr(test_string, position, length) + self.assertAllEqual(substr, test_string) + + # Length is larger in magnitude than a negative position + test_string = { + "BYTE": b"Hello", + "UTF8_CHAR": u"H\xc3ll\U0001f604".encode("utf-8"), + }[unit] + expected_string = { + "BYTE": b"ello", + "UTF8_CHAR": u"\xc3ll\U0001f604".encode("utf-8"), + }[unit] + position = np.array(-4, dtype) + length = np.array(5, dtype) + substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): substr = substr_op.eval() - self.assertAllEqual(substr, expected_value) - - # Negative position. 
- test_string = [b"Hello", b"World"] - position = np.array(-4, dtype) + self.assertAllEqual(substr, expected_string) + + @parameterized.parameters( + (np.int32, 1, "BYTE"), + (np.int64, 1, "BYTE"), + (np.int32, -4, "BYTE"), + (np.int64, -4, "BYTE"), + (np.int32, 1, "UTF8_CHAR"), + (np.int64, 1, "UTF8_CHAR"), + (np.int32, -4, "UTF8_CHAR"), + (np.int64, -4, "UTF8_CHAR"), + ) + def testVectorStrings(self, dtype, pos, unit): + test_string = { + "BYTE": [b"Hello", b"World"], + "UTF8_CHAR": [x.encode("utf-8") for x in [u"H\xc3llo", + u"W\U0001f604rld"]], + }[unit] + expected_value = { + "BYTE": [b"ell", b"orl"], + "UTF8_CHAR": [x.encode("utf-8") for x in [u"\xc3ll", u"\U0001f604rl"]], + }[unit] + position = np.array(pos, dtype) length = np.array(3, dtype) - expected_value = [b"ell", b"orl"] - - substr_op = string_ops.substr(test_string, position, length) + substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): substr = substr_op.eval() self.assertAllEqual(substr, expected_value) - def _testMatrixStrings(self, dtype): - test_string = [[b"ten", b"eleven", b"twelve"], - [b"thirteen", b"fourteen", b"fifteen"], - [b"sixteen", b"seventeen", b"eighteen"]] + @parameterized.parameters( + (np.int32, "BYTE"), + (np.int64, "BYTE"), + (np.int32, "UTF8_CHAR"), + (np.int64, "UTF8_CHAR"), + ) + def testMatrixStrings(self, dtype, unit): + test_string = { + "BYTE": [[b"ten", b"eleven", b"twelve"], + [b"thirteen", b"fourteen", b"fifteen"], + [b"sixteen", b"seventeen", b"eighteen"]], + "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227n", + u"\xc6\u053c\u025bv\u025bn", + u"tw\u0c1dlv\u025b"]], + [x.encode("utf-8") for x in [u"He\xc3\xc3o", + u"W\U0001f604rld", + u"d\xfcd\xea"]]], + }[unit] position = np.array(1, dtype) length = np.array(4, dtype) - expected_value = [[b"en", b"leve", b"welv"], [b"hirt", b"ourt", b"ifte"], - [b"ixte", b"even", b"ight"]] - - substr_op = string_ops.substr(test_string, position, length) + expected_value 
= { + "BYTE": [[b"en", b"leve", b"welv"], [b"hirt", b"ourt", b"ifte"], + [b"ixte", b"even", b"ight"]], + "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d227n", + u"\u053c\u025bv\u025b", + u"w\u0c1dlv"]], + [x.encode("utf-8") for x in [u"e\xc3\xc3o", + u"\U0001f604rld", + u"\xfcd\xea"]]], + }[unit] + substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): substr = substr_op.eval() self.assertAllEqual(substr, expected_value) - # Negative position - test_string = [[b"ten", b"eleven", b"twelve"], - [b"thirteen", b"fourteen", b"fifteen"], - [b"sixteen", b"seventeen", b"eighteen"]] - position = np.array(-2, dtype) + position = np.array(-3, dtype) length = np.array(2, dtype) - expected_value = [[b"en", b"en", b"ve"], [b"en", b"en", b"en"], - [b"en", b"en", b"en"]] - - substr_op = string_ops.substr(test_string, position, length) + expected_value = { + "BYTE": [[b"te", b"ve", b"lv"], [b"ee", b"ee", b"ee"], + [b"ee", b"ee", b"ee"]], + "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227", + u"v\u025b", u"lv"]], + [x.encode("utf-8") for x in [u"\xc3\xc3", u"rl", + u"\xfcd"]]], + }[unit] + substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): substr = substr_op.eval() self.assertAllEqual(substr, expected_value) - def _testElementWisePosLen(self, dtype): - test_string = [[b"ten", b"eleven", b"twelve"], - [b"thirteen", b"fourteen", b"fifteen"], - [b"sixteen", b"seventeen", b"eighteen"]] + @parameterized.parameters( + (np.int32, "BYTE"), + (np.int64, "BYTE"), + (np.int32, "UTF8_CHAR"), + (np.int64, "UTF8_CHAR"), + ) + def testElementWisePosLen(self, dtype, unit): + test_string = { + "BYTE": [[b"ten", b"eleven", b"twelve"], + [b"thirteen", b"fourteen", b"fifteen"], + [b"sixteen", b"seventeen", b"eighteen"]], + "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227n", + u"\xc6\u053c\u025bv\u025bn", + u"tw\u0c1dlv\u025b"]], + [x.encode("utf-8") for x in 
[u"He\xc3\xc3o", + u"W\U0001f604rld", + u"d\xfcd\xea"]], + [x.encode("utf-8") for x in [u"sixt\xea\xean", + u"se\U00010299enteen", + u"ei\U0001e920h\x86een"]]], + }[unit] position = np.array([[1, -4, 3], [1, 2, -4], [-5, 2, 3]], dtype) length = np.array([[2, 2, 4], [4, 3, 2], [5, 5, 5]], dtype) - expected_value = [[b"en", b"ev", b"lve"], [b"hirt", b"urt", b"te"], - [b"xteen", b"vente", b"hteen"]] - - substr_op = string_ops.substr(test_string, position, length) + expected_value = { + "BYTE": [[b"en", b"ev", b"lve"], [b"hirt", b"urt", b"te"], + [b"xteen", b"vente", b"hteen"]], + "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d227n", + u"\u025bv", + u"lv\u025b"]], + [x.encode("utf-8") for x in [u"e\xc3\xc3o", + u"rld", + u"d\xfc"]], + [x.encode("utf-8") for x in [u"xt\xea\xean", + u"\U00010299ente", + u"h\x86een"]]], + }[unit] + substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): substr = substr_op.eval() self.assertAllEqual(substr, expected_value) - def _testBroadcast(self, dtype): + @parameterized.parameters( + (np.int32, "BYTE"), + (np.int64, "BYTE"), + (np.int32, "UTF8_CHAR"), + (np.int64, "UTF8_CHAR"), + ) + def testBroadcast(self, dtype, unit): # Broadcast pos/len onto input string - test_string = [[b"ten", b"eleven", b"twelve"], - [b"thirteen", b"fourteen", b"fifteen"], - [b"sixteen", b"seventeen", b"eighteen"], - [b"nineteen", b"twenty", b"twentyone"]] + test_string = { + "BYTE": [[b"ten", b"eleven", b"twelve"], + [b"thirteen", b"fourteen", b"fifteen"], + [b"sixteen", b"seventeen", b"eighteen"], + [b"nineteen", b"twenty", b"twentyone"]], + "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227n", + u"\xc6\u053c\u025bv\u025bn", + u"tw\u0c1dlv\u025b"]], + [x.encode("utf-8") for x in [u"th\xcdrt\xea\xean", + u"f\U0001f604urt\xea\xean", + u"f\xcd\ua09ctee\ua0e4"]], + [x.encode("utf-8") for x in [u"s\xcdxt\xea\xean", + u"se\U00010299enteen", + u"ei\U0001e920h\x86een"]], + [x.encode("utf-8") for x in 
[u"nineteen", + u"twenty", + u"twentyone"]]], + }[unit] position = np.array([1, -4, 3], dtype) length = np.array([1, 2, 3], dtype) - expected_value = [[b"e", b"ev", b"lve"], [b"h", b"te", b"tee"], - [b"i", b"te", b"hte"], [b"i", b"en", b"nty"]] - substr_op = string_ops.substr(test_string, position, length) + expected_value = { + "BYTE": [[b"e", b"ev", b"lve"], [b"h", b"te", b"tee"], + [b"i", b"te", b"hte"], [b"i", b"en", b"nty"]], + "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d227", + u"\u025bv", u"lv\u025b"]], + [x.encode("utf-8") for x in [u"h", u"t\xea", u"tee"]], + [x.encode("utf-8") for x in [u"\xcd", u"te", u"h\x86e"]], + [x.encode("utf-8") for x in [u"i", u"en", u"nty"]]], + }[unit] + substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): substr = substr_op.eval() self.assertAllEqual(substr, expected_value) # Broadcast input string onto pos/len - test_string = [b"thirteen", b"fourteen", b"fifteen"] + test_string = { + "BYTE": [b"thirteen", b"fourteen", b"fifteen"], + "UTF8_CHAR": [x.encode("utf-8") for x in [u"th\xcdrt\xea\xean", + u"f\U0001f604urt\xea\xean", + u"f\xcd\ua09ctee\ua0e4"]], + }[unit] position = np.array([[1, -2, 3], [-3, 2, 1], [5, 5, -5]], dtype) length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype) - expected_value = [[b"hir", b"en", b"t"], [b"e", b"ur", b"ift"], - [b"ee", b"ee", b"ft"]] - substr_op = string_ops.substr(test_string, position, length) + expected_value = { + "BYTE": [[b"hir", b"en", b"t"], [b"e", b"ur", b"ift"], + [b"ee", b"ee", b"ft"]], + "UTF8_CHAR": [[x.encode("utf-8") for x in [u"h\xcdr", u"\xean", u"t"]], + [x.encode("utf-8") for x in [u"\xea", u"ur", + u"\xcd\ua09ct"]], + [x.encode("utf-8") for x in [u"\xea\xea", u"\xea\xea", + u"\ua09ct"]]], + }[unit] + substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): substr = substr_op.eval() self.assertAllEqual(substr, expected_value) # Test 1D broadcast - test_string = 
b"thirteen" - position = np.array([1, -5, 7], dtype) + test_string = { + "BYTE": b"thirteen", + "UTF8_CHAR": u"th\xcdrt\xea\xean".encode("utf-8"), + }[unit] + position = np.array([1, -4, 7], dtype) length = np.array([3, 2, 1], dtype) - expected_value = [b"hir", b"rt", b"n"] - substr_op = string_ops.substr(test_string, position, length) + expected_value = { + "BYTE": [b"hir", b"te", b"n"], + "UTF8_CHAR": [x.encode("utf-8") for x in [u"h\xcdr", u"t\xea", u"n"]], + }[unit] + substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): substr = substr_op.eval() self.assertAllEqual(substr, expected_value) - def _testBadBroadcast(self, dtype): + @parameterized.parameters( + (np.int32, "BYTE"), + (np.int64, "BYTE"), + (np.int32, "UTF8_CHAR"), + (np.int64, "UTF8_CHAR"), + ) + def testBadBroadcast(self, dtype, unit): test_string = [[b"ten", b"eleven", b"twelve"], [b"thirteen", b"fourteen", b"fifteen"], [b"sixteen", b"seventeen", b"eighteen"]] position = np.array([1, 2, -3, 4], dtype) length = np.array([1, 2, 3, 4], dtype) with self.assertRaises(ValueError): - substr_op = string_ops.substr(test_string, position, length) - - def _testOutOfRangeError(self, dtype): + string_ops.substr(test_string, position, length, unit=unit) + + @parameterized.parameters( + (np.int32, 6, "BYTE"), + (np.int64, 6, "BYTE"), + (np.int32, -6, "BYTE"), + (np.int64, -6, "BYTE"), + (np.int32, 6, "UTF8_CHAR"), + (np.int64, 6, "UTF8_CHAR"), + (np.int32, -6, "UTF8_CHAR"), + (np.int64, -6, "UTF8_CHAR"), + ) + def testOutOfRangeError_Scalar(self, dtype, pos, unit): # Scalar/Scalar - test_string = b"Hello" - position = np.array(7, dtype) - length = np.array(3, dtype) - substr_op = string_ops.substr(test_string, position, length) - with self.cached_session(): - with self.assertRaises(errors_impl.InvalidArgumentError): - substr = substr_op.eval() - - # Scalar/Scalar (with negative) - test_string = b"Hello" - position = np.array(-7, dtype) + test_string = { + "BYTE": 
b"Hello", + "UTF8_CHAR": u"H\xc3ll\U0001f604".encode("utf-8"), + }[unit] + position = np.array(pos, dtype) length = np.array(3, dtype) - substr_op = string_ops.substr(test_string, position, length) + substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): with self.assertRaises(errors_impl.InvalidArgumentError): - substr = substr_op.eval() - + substr_op.eval() + + @parameterized.parameters( + (np.int32, 4, "BYTE"), + (np.int64, 4, "BYTE"), + (np.int32, -4, "BYTE"), + (np.int64, -4, "BYTE"), + (np.int32, 4, "UTF8_CHAR"), + (np.int64, 4, "UTF8_CHAR"), + (np.int32, -4, "UTF8_CHAR"), + (np.int64, -4, "UTF8_CHAR"), + ) + def testOutOfRangeError_VectorScalar(self, dtype, pos, unit): # Vector/Scalar - test_string = [b"good", b"good", b"bad", b"good"] - position = np.array(4, dtype) - length = np.array(1, dtype) - substr_op = string_ops.substr(test_string, position, length) - with self.cached_session(): - with self.assertRaises(errors_impl.InvalidArgumentError): - substr = substr_op.eval() - - # Vector/Scalar (with negative) - test_string = [b"good", b"good", b"bad", b"good"] - position = np.array(-4, dtype) + test_string = { + "BYTE": [b"good", b"good", b"bad", b"good"], + "UTF8_CHAR": [x.encode("utf-8") for x in [u"g\xc3\xc3d", u"b\xc3d", + u"g\xc3\xc3d"]], + }[unit] + position = np.array(pos, dtype) length = np.array(1, dtype) - substr_op = string_ops.substr(test_string, position, length) + substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): with self.assertRaises(errors_impl.InvalidArgumentError): - substr = substr_op.eval() - + substr_op.eval() + + @parameterized.parameters( + (np.int32, "BYTE"), + (np.int64, "BYTE"), + (np.int32, "UTF8_CHAR"), + (np.int64, "UTF8_CHAR"), + ) + def testOutOfRangeError_MatrixMatrix(self, dtype, unit): # Matrix/Matrix - test_string = [[b"good", b"good", b"good"], [b"good", b"good", b"bad"], - [b"good", b"good", b"good"]] + test_string = { + 
"BYTE": [[b"good", b"good", b"good"], [b"good", b"good", b"bad"], + [b"good", b"good", b"good"]], + "UTF8_CHAR": [[x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d", + u"g\xc3\xc3d"]], + [x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d", + u"b\xc3d"]], + [x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d", + u"g\xc3\xc3d"]]], + }[unit] position = np.array([[1, 2, 3], [1, 2, 4], [1, 2, 3]], dtype) length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype) - substr_op = string_ops.substr(test_string, position, length) + substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): with self.assertRaises(errors_impl.InvalidArgumentError): - substr = substr_op.eval() + substr_op.eval() # Matrix/Matrix (with negative) - test_string = [[b"good", b"good", b"good"], [b"good", b"good", b"bad"], - [b"good", b"good", b"good"]] position = np.array([[1, 2, -3], [1, 2, -4], [1, 2, -3]], dtype) length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype) - substr_op = string_ops.substr(test_string, position, length) + substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): with self.assertRaises(errors_impl.InvalidArgumentError): - substr = substr_op.eval() - + substr_op.eval() + + @parameterized.parameters( + (np.int32, "BYTE"), + (np.int64, "BYTE"), + (np.int32, "UTF8_CHAR"), + (np.int64, "UTF8_CHAR"), + ) + def testOutOfRangeError_Broadcast(self, dtype, unit): # Broadcast - test_string = [[b"good", b"good", b"good"], [b"good", b"good", b"bad"]] + test_string = { + "BYTE": [[b"good", b"good", b"good"], [b"good", b"good", b"bad"]], + "UTF8_CHAR": [[x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d", + u"g\xc3\xc3d"]], + [x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d", + u"b\xc3d"]]], + }[unit] position = np.array([1, 2, 4], dtype) length = np.array([1, 2, 3], dtype) - substr_op = string_ops.substr(test_string, position, length) + substr_op = 
string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): with self.assertRaises(errors_impl.InvalidArgumentError): - substr = substr_op.eval() + substr_op.eval() # Broadcast (with negative) - test_string = [[b"good", b"good", b"good"], [b"good", b"good", b"bad"]] position = np.array([-1, -2, -4], dtype) length = np.array([1, 2, 3], dtype) - substr_op = string_ops.substr(test_string, position, length) + substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): with self.assertRaises(errors_impl.InvalidArgumentError): - substr = substr_op.eval() - - def _testMismatchPosLenShapes(self, dtype): - test_string = [[b"ten", b"eleven", b"twelve"], - [b"thirteen", b"fourteen", b"fifteen"], - [b"sixteen", b"seventeen", b"eighteen"]] + substr_op.eval() + + @parameterized.parameters( + (np.int32, "BYTE"), + (np.int64, "BYTE"), + (np.int32, "UTF8_CHAR"), + (np.int64, "UTF8_CHAR"), + ) + def testMismatchPosLenShapes(self, dtype, unit): + test_string = { + "BYTE": [[b"ten", b"eleven", b"twelve"], + [b"thirteen", b"fourteen", b"fifteen"], + [b"sixteen", b"seventeen", b"eighteen"]], + "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227n", + u"\xc6\u053c\u025bv\u025bn", + u"tw\u0c1dlv\u025b"]], + [x.encode("utf-8") for x in [u"th\xcdrt\xea\xean", + u"f\U0001f604urt\xea\xean", + u"f\xcd\ua09ctee\ua0e4"]], + [x.encode("utf-8") for x in [u"s\xcdxt\xea\xean", + u"se\U00010299enteen", + u"ei\U0001e920h\x86een"]]], + }[unit] position = np.array([[1, 2, 3]], dtype) length = np.array([2, 3, 4], dtype) # Should fail: position/length have different rank with self.assertRaises(ValueError): - substr_op = string_ops.substr(test_string, position, length) + string_ops.substr(test_string, position, length) position = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]], dtype) length = np.array([[2, 3, 4]], dtype) # Should fail: position/length have different dimensionality with self.assertRaises(ValueError): - substr_op = 
string_ops.substr(test_string, position, length) - - # Negative position. - test_string = [[b"ten", b"eleven", b"twelve"], - [b"thirteen", b"fourteen", b"fifteen"], - [b"sixteen", b"seventeen", b"eighteen"]] - position = np.array([[-1, -2, -3]], dtype) - length = np.array([1, 2, 3], dtype) - # Should fail: position/length have different rank - with self.assertRaises(ValueError): - substr_op = string_ops.substr(test_string, position, length) - - @parameterized.parameters(np.int32, np.int64) - def testAll(self, dtype): - self._testScalarString(dtype) - self._testVectorStrings(dtype) - self._testMatrixStrings(dtype) - self._testElementWisePosLen(dtype) - self._testBroadcast(dtype) - self._testBadBroadcast(dtype) - self._testOutOfRangeError(dtype) - self._testMismatchPosLenShapes(dtype) + string_ops.substr(test_string, position, length) def testWrongDtype(self): with self.cached_session(): @@ -300,6 +478,11 @@ class SubstrOpTest(test.TestCase, parameterized.TestCase): with self.assertRaises(TypeError): string_ops.substr(b"test", 3, 1.0) + def testInvalidUnit(self): + with self.cached_session(): + with self.assertRaises(ValueError): + string_ops.substr(b"test", 3, 1, unit="UTF8") + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py index 0812f901a2..f26388efea 100644 --- a/tensorflow/python/ops/string_ops.py +++ b/tensorflow/python/ops/string_ops.py @@ -347,6 +347,22 @@ def string_length(input, name=None, unit="BYTE"): string_length.__doc__ = gen_string_ops.string_length.__doc__ +@tf_export("substr") +@deprecation.deprecated(None, "Use `tf.strings.substr` instead of `tf.substr`.") +def substr_deprecated(input, pos, len, name=None, unit="BYTE"): + return substr(input, pos, len, name=name, unit=unit) + +substr_deprecated.__doc__ = gen_string_ops.substr.__doc__ + + +@tf_export("strings.substr") +def substr(input, pos, len, name=None, unit="BYTE"): + return gen_string_ops.substr(input, pos, len, 
unit=unit, name=name) + + +substr.__doc__ = gen_string_ops.substr.__doc__ + + ops.NotDifferentiable("RegexReplace") ops.NotDifferentiable("StringToHashBucket") ops.NotDifferentiable("StringToHashBucketFast") diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index c1cc7322f0..247dfcc1ca 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -2094,7 +2094,7 @@ tf_module { } member_method { name: "substr" - argspec: "args=[\'input\', \'pos\', \'len\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'input\', \'pos\', \'len\', \'name\', \'unit\'], varargs=None, keywords=None, defaults=[\'None\', \'BYTE\'], " } member_method { name: "subtract" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt index ebdaf57231..5ba48e7f57 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt @@ -34,7 +34,7 @@ tf_module { } member_method { name: "substr" - argspec: "args=[\'input\', \'pos\', \'len\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'input\', \'pos\', \'len\', \'name\', \'unit\'], varargs=None, keywords=None, defaults=[\'None\', \'BYTE\'], " } member_method { name: "to_hash_bucket" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt index 571abc3b19..978afcf985 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt @@ -1934,7 +1934,7 @@ tf_module { } member_method { name: "substr" - argspec: "args=[\'input\', \'pos\', \'len\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'input\', \'pos\', \'len\', \'name\', \'unit\'], varargs=None, keywords=None, defaults=[\'None\', \'BYTE\'], " } 
member_method { name: "subtract" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt index ebdaf57231..5ba48e7f57 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt @@ -34,7 +34,7 @@ tf_module { } member_method { name: "substr" - argspec: "args=[\'input\', \'pos\', \'len\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'input\', \'pos\', \'len\', \'name\', \'unit\'], varargs=None, keywords=None, defaults=[\'None\', \'BYTE\'], " } member_method { name: "to_hash_bucket" |