aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-04 11:30:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-04 11:34:46 -0700
commit700c3325311e16be9bb4856cbf944d1871ff35c1 (patch)
tree9ae88328889950abaa951a628de7212caec8c026
parentc8d5054e8c12800f0c3db0e51f3d5902e04eaa37 (diff)
Add "encoding" attribute to string substr op, which controls how each "character" is treated:
* BYTE: Position & length refer to bytes in the string. (Default) * UTF8: The string is interpreted as UTF-8 encoded Unicode code points, and position & length are treated relative to them. RELNOTES: Add option to get substring using Unicode characters PiperOrigin-RevId: 215773373
-rw-r--r--tensorflow/core/api_def/base_api/api_def_Substr.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Substr.pbtxt8
-rw-r--r--tensorflow/core/kernels/BUILD7
-rw-r--r--tensorflow/core/kernels/string_util.cc4
-rw-r--r--tensorflow/core/kernels/string_util.h44
-rw-r--r--tensorflow/core/kernels/substr_op.cc162
-rw-r--r--tensorflow/core/kernels/substr_op_test.cc100
-rw-r--r--tensorflow/core/ops/string_ops.cc1
-rw-r--r--tensorflow/python/kernel_tests/substr_op_test.py503
-rw-r--r--tensorflow/python/ops/string_ops.py16
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt2
14 files changed, 655 insertions, 208 deletions
diff --git a/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt b/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt
index 5246090ab3..fe0fcc9508 100644
--- a/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt
@@ -18,6 +18,16 @@ END
Scalar defining the number of characters to include in each substring
END
}
+ attr {
+ name: "unit"
+ description: <<END
+The unit that is used to create the substring. One of: `"BYTE"` (for
+defining position and length by bytes) or `"UTF8_CHAR"` (for the UTF-8
+encoded Unicode code points). The default is `"BYTE"`. Results are undefined if
+`unit=UTF8_CHAR` and the `input` strings do not contain structurally valid
+UTF-8.
+END
+ }
out_arg {
name: "output"
description: <<END
diff --git a/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt b/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt
index 4778d7927c..4fb9ee56e9 100644
--- a/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt
@@ -1,10 +1,4 @@
op {
graph_op_name: "Substr"
- endpoint {
- name: "strings.substr"
- }
- endpoint {
- name: "substr"
- deprecated: true
- }
+ visibility: HIDDEN
}
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 9439ab332c..3a920f26f3 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -4458,7 +4458,12 @@ cc_library(
name = "string_util",
srcs = ["string_util.cc"],
hdrs = ["string_util.h"],
- deps = ["//tensorflow/core:lib"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "@icu//:common",
+ ],
)
STRING_DEPS = [
diff --git a/tensorflow/core/kernels/string_util.cc b/tensorflow/core/kernels/string_util.cc
index 3a9803a052..92c73220d8 100644
--- a/tensorflow/core/kernels/string_util.cc
+++ b/tensorflow/core/kernels/string_util.cc
@@ -16,10 +16,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
-namespace {
-inline bool IsTrailByte(char x) { return static_cast<signed char>(x) < -0x40; }
-} // namespace
-
namespace tensorflow {
// Sets unit value based on str.
diff --git a/tensorflow/core/kernels/string_util.h b/tensorflow/core/kernels/string_util.h
index 390cf57702..d40e93ea33 100644
--- a/tensorflow/core/kernels/string_util.h
+++ b/tensorflow/core/kernels/string_util.h
@@ -30,6 +30,9 @@ enum class UnicodeEncoding { UTF8 };
// TODO(edloper): Add support for: UTF32_CHAR, etc.
enum class CharUnit { BYTE, UTF8_CHAR };
+// Whether or not the given byte is a UTF-8 trail (continuation) byte, i.e. a
+// byte of the form 10xxxxxx (0x80-0xBF).
+inline bool IsTrailByte(char x) { return static_cast<signed char>(x) < -0x40; }
+
// Sets `encoding` based on `str`.
Status ParseUnicodeEncoding(const string& str, UnicodeEncoding* encoding);
@@ -40,6 +43,47 @@ Status ParseCharUnit(const string& str, CharUnit* unit);
// Result may be incorrect if the input string is not valid UTF-8.
int32 UTF8StrLen(const string& string);
+// Advances `*pos` forward by up to `num_utf8_chars_to_shift` UTF-8 characters.
+// `*pos` is a byte offset into `in`, and `pos` must never be null. Returns
+// true if the requested number of characters was skipped. If the end of the
+// string is reached before that, `*pos` is left pointing at the end of the
+// string and the function returns false.
+template <typename T>
+bool ForwardNUTF8CharPositions(const StringPiece in,
+                               const T num_utf8_chars_to_shift, T* pos) {
+  const size_t size = in.size();
+  T utf8_chars_counted = 0;
+  while (utf8_chars_counted < num_utf8_chars_to_shift && *pos < size) {
+    // Move forward one UTF-8 character: step past the current byte, then past
+    // every trail byte that follows it.
+    do {
+      ++*pos;
+      // The bounds check must be evaluated before `in[*pos]`; otherwise the
+      // byte one past the end of `in` is read when the last character of the
+      // string is consumed.
+    } while (*pos < size && IsTrailByte(in[*pos]));
+    ++utf8_chars_counted;
+  }
+  return utf8_chars_counted == num_utf8_chars_to_shift;
+}
+
+// Moves `*pos` backward by up to `num_utf8_chars_to_shift` UTF-8 characters.
+// `*pos` is a positive byte offset measured from the beginning of `in`, and
+// `pos` must never be null. Returns true if the requested number of
+// characters was skipped. If the beginning of the string is reached before
+// that, `*pos` is left pointing at the beginning of the string and the
+// function returns false.
+template <typename T>
+bool BackNUTF8CharPositions(const StringPiece in,
+                            const T num_utf8_chars_to_shift, T* pos) {
+  const size_t begin = 0;
+  T chars_skipped = 0;
+  for (; chars_skipped < num_utf8_chars_to_shift && (*pos > begin);
+       ++chars_skipped) {
+    // Step back one byte at a time until we land on a byte that is not a
+    // trail byte, or until we hit the beginning of the string.
+    for (;;) {
+      --*pos;
+      if (!IsTrailByte(in[*pos]) || *pos <= begin) {
+        break;
+      }
+    }
+  }
+  return chars_skipped == num_utf8_chars_to_shift;
+}
+
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_
diff --git a/tensorflow/core/kernels/substr_op.cc b/tensorflow/core/kernels/substr_op.cc
index 07f1d6e767..93c427039d 100644
--- a/tensorflow/core/kernels/substr_op.cc
+++ b/tensorflow/core/kernels/substr_op.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/string_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/types.h"
@@ -37,7 +38,11 @@ namespace tensorflow {
template <typename T>
class SubstrOp : public OpKernel {
public:
- using OpKernel::OpKernel;
+  // Reads the "unit" attribute when the kernel is constructed and parses it
+  // into `unit_`, so Compute() only has to switch on the parsed enum.
+  explicit SubstrOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    string unit_attr;
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("unit", &unit_attr));
+    OP_REQUIRES_OK(ctx, ParseCharUnit(unit_attr, &unit_));
+  }
void Compute(OpKernelContext* context) override {
// Get inputs
@@ -69,11 +74,23 @@ class SubstrOp : public OpKernel {
tensorflow::internal::SubtleMustCopy(len_tensor.scalar<T>()());
for (size_t i = 0; i < input_tensor.NumElements(); ++i) {
StringPiece in(input(i));
- OP_REQUIRES(
- context, FastBoundsCheck(std::abs(pos), in.size() + 1),
- errors::InvalidArgument("pos ", pos, " out of range for string",
- "b'", in, "' at index ", i));
- StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+ T byte_pos = pos;
+ T byte_len = len;
+ switch (unit_) {
+ case CharUnit::UTF8_CHAR:
+ OP_REQUIRES(
+ context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string at index ", i));
+ break;
+ case CharUnit::BYTE:
+ byte_pos = AdjustedPosIndex(byte_pos, in);
+ OP_REQUIRES(
+ context, FastBoundsCheck(byte_pos, in.size() + 1),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string b'", in, "' at index ", i));
+ }
+ StringPiece sub_in = in.substr(byte_pos, byte_len);
output(i).assign(sub_in.data(), sub_in.size());
}
} else {
@@ -84,11 +101,23 @@ class SubstrOp : public OpKernel {
StringPiece in(input(i));
const T pos = tensorflow::internal::SubtleMustCopy(pos_flat(i));
const T len = tensorflow::internal::SubtleMustCopy(len_flat(i));
- OP_REQUIRES(
- context, FastBoundsCheck(std::abs(pos), in.size() + 1),
- errors::InvalidArgument("pos ", pos, " out of range for string",
- "b'", in, "' at index ", i));
- StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+ T byte_pos = pos;
+ T byte_len = len;
+ switch (unit_) {
+ case CharUnit::UTF8_CHAR:
+ OP_REQUIRES(
+ context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string at index ", i));
+ break;
+ case CharUnit::BYTE:
+ byte_pos = AdjustedPosIndex(byte_pos, in);
+ OP_REQUIRES(
+ context, FastBoundsCheck(byte_pos, in.size() + 1),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string b'", in, "' at index ", i));
+ }
+ StringPiece sub_in = in.substr(byte_pos, byte_len);
output(i).assign(sub_in.data(), sub_in.size());
}
}
@@ -151,12 +180,24 @@ class SubstrOp : public OpKernel {
StringPiece in(input_bcast(i));
const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i));
const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i));
- OP_REQUIRES(
- context,
- FastBoundsCheck(std::abs(pos), input_bcast(i).size() + 1),
- errors::InvalidArgument("pos ", pos, " out of range for string",
- "b'", in, "' at index ", i));
- StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+ T byte_pos = pos;
+ T byte_len = len;
+ switch (unit_) {
+ case CharUnit::UTF8_CHAR:
+ OP_REQUIRES(
+ context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string at index ", i));
+ break;
+ case CharUnit::BYTE:
+ byte_pos = AdjustedPosIndex(byte_pos, in);
+ OP_REQUIRES(
+ context,
+ FastBoundsCheck(byte_pos, input_bcast(i).size() + 1),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string b'", in, "' at index ", i));
+ }
+ StringPiece sub_in = in.substr(byte_pos, byte_len);
output(i).assign(sub_in.data(), sub_in.size());
}
break;
@@ -205,12 +246,24 @@ class SubstrOp : public OpKernel {
tensorflow::internal::SubtleMustCopy(pos_bcast(i, j));
const T len =
tensorflow::internal::SubtleMustCopy(len_bcast(i, j));
- OP_REQUIRES(
- context, FastBoundsCheck(std::abs(pos), in.size() + 1),
- errors::InvalidArgument("pos ", pos, " out of range for ",
- "string b'", in, "' at index (", i,
- ", ", j, ")"));
- StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+ T byte_pos = pos;
+ T byte_len = len;
+ switch (unit_) {
+ case CharUnit::UTF8_CHAR:
+ OP_REQUIRES(
+ context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string at index ", i));
+ break;
+ case CharUnit::BYTE:
+ byte_pos = AdjustedPosIndex(byte_pos, in);
+ OP_REQUIRES(
+ context, FastBoundsCheck(byte_pos, in.size() + 1),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string b'", in, "' at index (",
+ i, ", ", j, ")"));
+ }
+ StringPiece sub_in = in.substr(byte_pos, byte_len);
output(i, j).assign(sub_in.data(), sub_in.size());
}
}
@@ -227,12 +280,73 @@ class SubstrOp : public OpKernel {
private:
// This adjusts the requested position. Note it does not perform any bound
// checks.
- T AdjustedPosIndex(const T pos_requested, const StringPiece s) {
+ static inline T AdjustedPosIndex(const T pos_requested, const StringPiece s) {
if (pos_requested < 0) {
return s.size() + pos_requested;
}
return pos_requested;
}
+
+  // Rewrites `*pos` and `*len` in place for the string `in`, dispatching on the
+  // sign of the requested position: non-negative positions go to the
+  // "positive" helper and negative positions to the "negative" helper.
+  // Returns true on success, or false if the `pos` argument is out of range
+  // in the string.
+  static inline bool UpdatePosAndLenForUtf8(const StringPiece in, T* pos,
+                                            T* len) {
+    const T requested_pos = *pos;
+    const T requested_len = *len;
+    return requested_pos < 0
+               ? UpdateNegativePosAndLenForUtf8(in, requested_pos,
+                                                requested_len, pos, len)
+               : UpdatePositivePosAndLenForUtf8(in, requested_pos,
+                                                requested_len, pos, len);
+  }
+
+  // Handles a non-negative `pos`, counted in UTF-8 characters from the start
+  // of `in`. On success, `*char_pos` receives the byte offset where the
+  // substring starts and `*char_len` its length in bytes. Returns false if
+  // the string ends before `pos` characters have been skipped.
+  static bool UpdatePositivePosAndLenForUtf8(const StringPiece in, const T pos,
+                                             const T len, T* char_pos,
+                                             T* char_len) {
+    // Walk `pos` characters forward from the first byte to find the start.
+    T start_byte = 0;
+    const bool start_in_range =
+        ForwardNUTF8CharPositions(in, pos, &start_byte);
+    *char_pos = start_byte;
+    if (!start_in_range) {
+      return false;
+    }
+    // Walk a further `len` characters to find the end. The result is
+    // deliberately ignored: a length that runs past the end of the string is
+    // simply capped there.
+    T end_byte = start_byte;
+    ForwardNUTF8CharPositions(in, len, &end_byte);
+    // Byte length is the distance between the two byte offsets.
+    *char_len = end_byte - start_byte;
+    return true;
+  }
+
+  // Handles a negative `pos`, counted in UTF-8 characters from the end of
+  // `in`. On success, `*char_pos` receives a non-negative byte offset measured
+  // from the beginning of the string and `*char_len` the substring length in
+  // bytes. Returns false if the beginning of the string is reached before the
+  // requested characters have been skipped.
+  static bool UpdateNegativePosAndLenForUtf8(const StringPiece in, const T pos,
+                                             const T len, T* char_pos,
+                                             T* char_len) {
+    // Number of characters lying between the end of the substring and the end
+    // of the string; clamped at zero when `len` reaches (or passes) the end.
+    const T overshoot = -pos - len;
+    const T chars_after_substr = overshoot < 0 ? 0 : overshoot;
+    // Walk backward from the end of the string to the byte offset at which
+    // the substring ends.
+    T end_byte = in.size();
+    const bool end_in_range =
+        BackNUTF8CharPositions(in, chars_after_substr, &end_byte);
+    *char_len = end_byte;
+    if (!end_in_range) {
+      return false;
+    }
+    // Keep walking backward over the remaining characters (the requested
+    // position minus what was already skipped) to reach the substring start.
+    T start_byte = end_byte;
+    const bool start_in_range =
+        BackNUTF8CharPositions(in, -pos - chars_after_substr, &start_byte);
+    *char_pos = start_byte;
+    if (!start_in_range) {
+      return false;
+    }
+    // Byte length is the distance between the two byte offsets.
+    *char_len = end_byte - start_byte;
+    return true;
+  }
+
+ CharUnit unit_ = CharUnit::BYTE;
};
#define REGISTER_SUBSTR(type) \
diff --git a/tensorflow/core/kernels/substr_op_test.cc b/tensorflow/core/kernels/substr_op_test.cc
index 2e07050260..ea6b1ed500 100644
--- a/tensorflow/core/kernels/substr_op_test.cc
+++ b/tensorflow/core/kernels/substr_op_test.cc
@@ -42,7 +42,7 @@ limitations under the License.
namespace tensorflow {
// Test data from the TensorFlow README.md.
-const char* lines[] = {
+const char* ascii_lines[] = {
"**TensorFlow** is an open source software library for numerical "
"computation using data flow graphs.",
"The graph nodes represent mathematical operations, while the graph edges "
@@ -64,17 +64,76 @@ const char* lines[] = {
"backwards compatibility guarantee like C++, Go, Java, JavaScript and "
"Swift."};
+const char* unicode_lines[] = {
+ "TensorFlow\xe6\x98\xaf\xe4\xb8\x80\xe4\xb8\xaa\xe4\xbd\xbf\xe7\x94\xa8\xe6"
+ "\x95\xb0\xe6\x8d\xae\xe6\xb5\x81\xe5\x9b\xbe\xe8\xbf\x9b\xe8\xa1\x8c\xe6"
+ "\x95\xb0\xe5\x80\xbc\xe8\xae\xa1\xe7\xae\x97\xe7\x9a\x84\xe5\xbc\x80\xe6"
+ "\xba\x90\xe8\xbd\xaf\xe4\xbb\xb6\xe5\xba\x93\xe3\x80\x82",
+ "\xe5\x9b\xbe\xe5\xbd\xa2\xe8\x8a\x82\xe7\x82\xb9\xe8\xa1\xa8\xe7\xa4\xba"
+ "\xe6\x95\xb0\xe5\xad\xa6\xe8\xbf\x90\xe7\xae\x97\xef\xbc\x8c\xe8\x80\x8c"
+ "\xe5\x9b\xbe\xe5\xbd\xa2\xe8\xbe\xb9\xe7\xbc\x98\xe8\xa1\xa8\xe7\xa4\xba"
+ "\xe5\x9c\xa8\xe5\xae\x83\xe4\xbb\xac\xe4\xb9\x8b\xe9\x97\xb4\xe6\xb5\x81"
+ "\xe5\x8a\xa8\xe7\x9a\x84\xe5\xa4\x9a\xe7\xbb\xb4\xe6\x95\xb0\xe6\x8d\xae"
+ "\xe9\x98\xb5\xe5\x88\x97\xef\xbc\x88\xe5\xbc\xa0\xe9\x87\x8f\xef\xbc\x89"
+ "\xe3\x80\x82",
+ "\xe8\xbf\x99\xe7\xa7\x8d\xe7\x81\xb5\xe6\xb4\xbb\xe7\x9a\x84\xe4\xbd\x93"
+ "\xe7\xb3\xbb\xe7\xbb\x93\xe6\x9e\x84\xe4\xbd\xbf\xe6\x82\xa8\xe5\x8f\xaf"
+ "\xe4\xbb\xa5\xe5\xb0\x86\xe8\xae\xa1\xe7\xae\x97\xe9\x83\xa8\xe7\xbd\xb2"
+ "\xe5\x88\xb0\xe6\xa1\x8c\xe9\x9d\xa2\xef\xbc\x8c\xe6\x9c\x8d\xe5\x8a\xa1"
+ "\xe5\x99\xa8\xe6\x88\x96\xe7\xa7\xbb\xe5\x8a\xa8\xe8\xae\xbe\xe5\xa4\x87"
+ "\xe4\xb8\xad\xe7\x9a\x84\xe4\xb8\x80\xe4\xb8\xaa\xe6\x88\x96\xe5\xa4\x9a"
+ "\xe4\xb8\xaa CPU\xe6\x88\x96GPU\xef\xbc\x8c\xe8\x80\x8c\xe6\x97\xa0\xe9"
+ "\x9c\x80\xe9\x87\x8d\xe5\x86\x99\xe4\xbb\xa3\xe7\xa0\x81\xe3\x80\x82",
+ "TensorFlow\xe8\xbf\x98\xe5\x8c\x85\xe6\x8b\xac[TensorBoard]\xef\xbc\x88"
+ "https://www.tensorflow.org/guide/summaries_and_tensorboard\xef\xbc\x89\xef"
+ "\xbc\x8c\xe8\xbf\x99\xe6\x98\xaf\xe4\xb8\x80\xe4\xb8\xaa\xe6\x95\xb0\xe6"
+ "\x8d\xae\xe5\x8f\xaf\xe8\xa7\x86\xe5\x8c\x96\xe5\xb7\xa5\xe5\x85\xb7\xe5"
+ "\x8c\x85\xe3\x80\x82",
+ "TensorFlow\xe6\x9c\x80\xe5\x88\x9d\xe6\x98\xaf\xe7\x94\xb1\xe7\xa0\x94\xe7"
+ "\xa9\xb6\xe4\xba\xba\xe5\x91\x98\xe5\x92\x8c\xe5\xb7\xa5\xe7\xa8\x8b\xe5"
+ "\xb8\x88\xe5\x9c\xa8Google\xe6\x9c\xba\xe5\x99\xa8\xe6\x99\xba\xe8\x83\xbd"
+ "\xe7\xa0\x94\xe7\xa9\xb6\xe7\xbb\x84\xe7\xbb\x87\xe7\x9a\x84Google Brain"
+ "\xe5\x9b\xa2\xe9\x98\x9f\xe5\xbc\x80\xe5\x8f\x91\xe7\x9a\x84\xef\xbc\x8c"
+ "\xe7\x9b\xae\xe7\x9a\x84\xe6\x98\xaf\xe8\xbf\x9b\xe8\xa1\x8c\xe6\x9c\xba"
+ "\xe5\x99\xa8\xe5\xad\xa6\xe4\xb9\xa0\xe5\x92\x8c\xe6\xb7\xb1\xe5\xba\xa6"
+ "\xe7\xa5\x9e\xe7\xbb\x8f\xe7\xbd\x91\xe7\xbb\x9c\xe7\xa0\x94\xe7\xa9\xb6"
+ "\xe3\x80\x82",
+ "\xe8\xaf\xa5\xe7\xb3\xbb\xe7\xbb\x9f\xe8\xb6\xb3\xe4\xbb\xa5\xe9\x80\x82"
+ "\xe7\x94\xa8\xe4\xba\x8e\xe5\x90\x84\xe7\xa7\x8d\xe5\x85\xb6\xe4\xbb\x96"
+ "\xe9\xa2\x86\xe5\x9f\x9f\xe4\xb9\x9f\xe6\x98\xaf\xe5\xa6\x82\xe6\xad\xa4"
+ "\xe3\x80\x82",
+ "TensorFlow\xe6\x8f\x90\xe4\xbe\x9b\xe7\xa8\xb3\xe5\xae\x9a\xe7\x9a\x84"
+ "Python API\xe5\x92\x8c C API\xef\xbc\x8c\xe4\xbb\xa5\xe5\x8f\x8a\xe6\xb2"
+ "\xa1\xe6\x9c\x89 API\xe5\x90\x91\xe5\x90\x8e\xe5\x85\xbc\xe5\xae\xb9\xe6"
+ "\x80\xa7\xe4\xbf\x9d\xe8\xaf\x81\xef\xbc\x8c\xe5\xa6\x82 C ++\xef\xbc\x8c"
+ "Go\xef\xbc\x8cJava\xef\xbc\x8cJavaScript\xe5\x92\x8cSwift\xe3\x80\x82",
+};
+
+const char* const kByteUnit = "BYTE";
+const char* const kUTF8Unit = "UTF8_CHAR";
+
Tensor GetTestTensor(int batch) {
- const int sz = TF_ARRAYSIZE(lines);
+ const int sz = TF_ARRAYSIZE(ascii_lines);
+ Tensor t(DT_STRING, {batch});
+ auto s = t.flat<string>();
+ for (int i = 0; i < batch; ++i) {
+ s(i) = ascii_lines[i % sz];
+ }
+ return t;
+}
+
+Tensor GetTestUTF8Tensor(int batch) {
+ const int sz = TF_ARRAYSIZE(unicode_lines);
Tensor t(DT_STRING, {batch});
auto s = t.flat<string>();
for (int i = 0; i < batch; ++i) {
- s(i) = lines[i % sz];
+ s(i) = unicode_lines[i % sz];
}
return t;
}
-Graph* SetupSubstrGraph(const Tensor& input, const int32 pos, const int32 len) {
+Graph* SetupSubstrGraph(const Tensor& input, const int32 pos, const int32 len,
+ const char* const unit) {
Graph* g = new Graph(OpRegistry::Global());
Tensor position(DT_INT32, TensorShape({}));
position.flat<int32>().setConstant(pos);
@@ -85,21 +144,46 @@ Graph* SetupSubstrGraph(const Tensor& input, const int32 pos, const int32 len) {
.Input(test::graph::Constant(g, input))
.Input(test::graph::Constant(g, position))
.Input(test::graph::Constant(g, length))
+ .Attr("unit", unit)
.Finalize(g, nullptr /* node */));
return g;
}
-void BM_Substr(int iters, int batch_size) {
+void BM_SubstrByte(int iters, int batch_size) {
testing::StopTiming();
testing::ItemsProcessed(static_cast<int64>(iters));
testing::UseRealTime();
Tensor input = GetTestTensor(batch_size);
- Graph* g = SetupSubstrGraph(input, 3, 30);
+ Graph* g = SetupSubstrGraph(input, 3, 30, kByteUnit);
+ testing::StartTiming();
+ test::Benchmark("cpu", g).Run(iters);
+}
+
+void BM_SubstrUTF8(int iters, int batch_size) {
+ testing::StopTiming();
+ testing::ItemsProcessed(static_cast<int64>(iters));
+ testing::UseRealTime();
+ Tensor input = GetTestUTF8Tensor(batch_size);
+ Graph* g = SetupSubstrGraph(input, 3, 30, kUTF8Unit);
testing::StartTiming();
test::Benchmark("cpu", g).Run(iters);
}
-BENCHMARK(BM_Substr)->Arg(1)->Arg(8)->Arg(16)->Arg(32)->Arg(64)->Arg(128)->Arg(
- 256);
+BENCHMARK(BM_SubstrByte)
+ ->Arg(1)
+ ->Arg(8)
+ ->Arg(16)
+ ->Arg(32)
+ ->Arg(64)
+ ->Arg(128)
+ ->Arg(256);
+BENCHMARK(BM_SubstrUTF8)
+ ->Arg(1)
+ ->Arg(8)
+ ->Arg(16)
+ ->Arg(32)
+ ->Arg(64)
+ ->Arg(128)
+ ->Arg(256);
} // end namespace tensorflow
diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc
index b4fbde54d9..94d71a4113 100644
--- a/tensorflow/core/ops/string_ops.cc
+++ b/tensorflow/core/ops/string_ops.cc
@@ -223,6 +223,7 @@ REGISTER_OP("Substr")
.Input("len: T")
.Output("output: string")
.Attr("T: {int32, int64}")
+ .Attr("unit: {'BYTE', 'UTF8_CHAR'} = 'BYTE'")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle pos_shape = c->input(1);
ShapeHandle len_shape = c->input(2);
diff --git a/tensorflow/python/kernel_tests/substr_op_test.py b/tensorflow/python/kernel_tests/substr_op_test.py
index cd3fe14883..37aa624b07 100644
--- a/tensorflow/python/kernel_tests/substr_op_test.py
+++ b/tensorflow/python/kernel_tests/substr_op_test.py
@@ -28,270 +28,448 @@ from tensorflow.python.platform import test
class SubstrOpTest(test.TestCase, parameterized.TestCase):
- def _testScalarString(self, dtype):
- test_string = b"Hello"
- position = np.array(1, dtype)
+ @parameterized.parameters(
+ (np.int32, 1, "BYTE"),
+ (np.int64, 1, "BYTE"),
+ (np.int32, -4, "BYTE"),
+ (np.int64, -4, "BYTE"),
+ (np.int32, 1, "UTF8_CHAR"),
+ (np.int64, 1, "UTF8_CHAR"),
+ (np.int32, -4, "UTF8_CHAR"),
+ (np.int64, -4, "UTF8_CHAR"),
+ )
+ def testScalarString(self, dtype, pos, unit):
+ test_string = {
+ "BYTE": b"Hello",
+ "UTF8_CHAR": u"He\xc3\xc3\U0001f604".encode("utf-8"),
+ }[unit]
+ expected_value = {
+ "BYTE": b"ell",
+ "UTF8_CHAR": u"e\xc3\xc3".encode("utf-8"),
+ }[unit]
+ position = np.array(pos, dtype)
length = np.array(3, dtype)
- expected_value = b"ell"
-
- substr_op = string_ops.substr(test_string, position, length)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
- # Negative position.
- test_string = b"Hello"
- position = np.array(-4, dtype)
+ @parameterized.parameters(
+ (np.int32, "BYTE"),
+ (np.int64, "BYTE"),
+ (np.int32, "UTF8_CHAR"),
+ (np.int64, "UTF8_CHAR"),
+ )
+ def testScalarString_EdgeCases(self, dtype, unit):
+ # Empty string
+ test_string = {
+ "BYTE": b"",
+ "UTF8_CHAR": u"".encode("utf-8"),
+ }[unit]
+ expected_value = b""
+ position = np.array(0, dtype)
length = np.array(3, dtype)
- expected_value = b"ell"
-
- substr_op = string_ops.substr(test_string, position, length)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
- # Position is equal to the length of string.
- test_string = b""
+ # Full string
+ test_string = {
+ "BYTE": b"Hello",
+ "UTF8_CHAR": u"H\xc3ll\U0001f604".encode("utf-8"),
+ }[unit]
position = np.array(0, dtype)
- length = np.array(2, dtype)
- expected_value = b""
-
- substr_op = string_ops.substr(test_string, position, length)
+ length = np.array(5, dtype)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
- self.assertAllEqual(substr, expected_value)
-
- # Negative position magnitude is equal to the length of string.
- test_string = b"yo"
- position = np.array(-2, dtype)
- length = np.array(1, dtype)
- expected_value = b"y"
-
- substr_op = string_ops.substr(test_string, position, length)
+ self.assertAllEqual(substr, test_string)
+
+ # Full string (Negative)
+ test_string = {
+ "BYTE": b"Hello",
+ "UTF8_CHAR": u"H\xc3ll\U0001f604".encode("utf-8"),
+ }[unit]
+ position = np.array(-5, dtype)
+ length = np.array(5, dtype)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
- self.assertAllEqual(substr, expected_value)
-
- def _testVectorStrings(self, dtype):
- test_string = [b"Hello", b"World"]
- position = np.array(1, dtype)
- length = np.array(3, dtype)
- expected_value = [b"ell", b"orl"]
-
- substr_op = string_ops.substr(test_string, position, length)
+ self.assertAllEqual(substr, test_string)
+
+ # Length is larger in magnitude than a negative position
+ test_string = {
+ "BYTE": b"Hello",
+ "UTF8_CHAR": u"H\xc3ll\U0001f604".encode("utf-8"),
+ }[unit]
+ expected_string = {
+ "BYTE": b"ello",
+ "UTF8_CHAR": u"\xc3ll\U0001f604".encode("utf-8"),
+ }[unit]
+ position = np.array(-4, dtype)
+ length = np.array(5, dtype)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
- self.assertAllEqual(substr, expected_value)
-
- # Negative position.
- test_string = [b"Hello", b"World"]
- position = np.array(-4, dtype)
+ self.assertAllEqual(substr, expected_string)
+
+ @parameterized.parameters(
+ (np.int32, 1, "BYTE"),
+ (np.int64, 1, "BYTE"),
+ (np.int32, -4, "BYTE"),
+ (np.int64, -4, "BYTE"),
+ (np.int32, 1, "UTF8_CHAR"),
+ (np.int64, 1, "UTF8_CHAR"),
+ (np.int32, -4, "UTF8_CHAR"),
+ (np.int64, -4, "UTF8_CHAR"),
+ )
+ def testVectorStrings(self, dtype, pos, unit):
+ test_string = {
+ "BYTE": [b"Hello", b"World"],
+ "UTF8_CHAR": [x.encode("utf-8") for x in [u"H\xc3llo",
+ u"W\U0001f604rld"]],
+ }[unit]
+ expected_value = {
+ "BYTE": [b"ell", b"orl"],
+ "UTF8_CHAR": [x.encode("utf-8") for x in [u"\xc3ll", u"\U0001f604rl"]],
+ }[unit]
+ position = np.array(pos, dtype)
length = np.array(3, dtype)
- expected_value = [b"ell", b"orl"]
-
- substr_op = string_ops.substr(test_string, position, length)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
- def _testMatrixStrings(self, dtype):
- test_string = [[b"ten", b"eleven", b"twelve"],
- [b"thirteen", b"fourteen", b"fifteen"],
- [b"sixteen", b"seventeen", b"eighteen"]]
+ @parameterized.parameters(
+ (np.int32, "BYTE"),
+ (np.int64, "BYTE"),
+ (np.int32, "UTF8_CHAR"),
+ (np.int64, "UTF8_CHAR"),
+ )
+ def testMatrixStrings(self, dtype, unit):
+ test_string = {
+ "BYTE": [[b"ten", b"eleven", b"twelve"],
+ [b"thirteen", b"fourteen", b"fifteen"],
+ [b"sixteen", b"seventeen", b"eighteen"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227n",
+ u"\xc6\u053c\u025bv\u025bn",
+ u"tw\u0c1dlv\u025b"]],
+ [x.encode("utf-8") for x in [u"He\xc3\xc3o",
+ u"W\U0001f604rld",
+ u"d\xfcd\xea"]]],
+ }[unit]
position = np.array(1, dtype)
length = np.array(4, dtype)
- expected_value = [[b"en", b"leve", b"welv"], [b"hirt", b"ourt", b"ifte"],
- [b"ixte", b"even", b"ight"]]
-
- substr_op = string_ops.substr(test_string, position, length)
+ expected_value = {
+ "BYTE": [[b"en", b"leve", b"welv"], [b"hirt", b"ourt", b"ifte"],
+ [b"ixte", b"even", b"ight"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d227n",
+ u"\u053c\u025bv\u025b",
+ u"w\u0c1dlv"]],
+ [x.encode("utf-8") for x in [u"e\xc3\xc3o",
+ u"\U0001f604rld",
+ u"\xfcd\xea"]]],
+ }[unit]
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
- # Negative position
- test_string = [[b"ten", b"eleven", b"twelve"],
- [b"thirteen", b"fourteen", b"fifteen"],
- [b"sixteen", b"seventeen", b"eighteen"]]
- position = np.array(-2, dtype)
+ position = np.array(-3, dtype)
length = np.array(2, dtype)
- expected_value = [[b"en", b"en", b"ve"], [b"en", b"en", b"en"],
- [b"en", b"en", b"en"]]
-
- substr_op = string_ops.substr(test_string, position, length)
+ expected_value = {
+ "BYTE": [[b"te", b"ve", b"lv"], [b"ee", b"ee", b"ee"],
+ [b"ee", b"ee", b"ee"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227",
+ u"v\u025b", u"lv"]],
+ [x.encode("utf-8") for x in [u"\xc3\xc3", u"rl",
+ u"\xfcd"]]],
+ }[unit]
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
- def _testElementWisePosLen(self, dtype):
- test_string = [[b"ten", b"eleven", b"twelve"],
- [b"thirteen", b"fourteen", b"fifteen"],
- [b"sixteen", b"seventeen", b"eighteen"]]
+ @parameterized.parameters(
+ (np.int32, "BYTE"),
+ (np.int64, "BYTE"),
+ (np.int32, "UTF8_CHAR"),
+ (np.int64, "UTF8_CHAR"),
+ )
+ def testElementWisePosLen(self, dtype, unit):
+ test_string = {
+ "BYTE": [[b"ten", b"eleven", b"twelve"],
+ [b"thirteen", b"fourteen", b"fifteen"],
+ [b"sixteen", b"seventeen", b"eighteen"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227n",
+ u"\xc6\u053c\u025bv\u025bn",
+ u"tw\u0c1dlv\u025b"]],
+ [x.encode("utf-8") for x in [u"He\xc3\xc3o",
+ u"W\U0001f604rld",
+ u"d\xfcd\xea"]],
+ [x.encode("utf-8") for x in [u"sixt\xea\xean",
+ u"se\U00010299enteen",
+ u"ei\U0001e920h\x86een"]]],
+ }[unit]
position = np.array([[1, -4, 3], [1, 2, -4], [-5, 2, 3]], dtype)
length = np.array([[2, 2, 4], [4, 3, 2], [5, 5, 5]], dtype)
- expected_value = [[b"en", b"ev", b"lve"], [b"hirt", b"urt", b"te"],
- [b"xteen", b"vente", b"hteen"]]
-
- substr_op = string_ops.substr(test_string, position, length)
+ expected_value = {
+ "BYTE": [[b"en", b"ev", b"lve"], [b"hirt", b"urt", b"te"],
+ [b"xteen", b"vente", b"hteen"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d227n",
+ u"\u025bv",
+ u"lv\u025b"]],
+ [x.encode("utf-8") for x in [u"e\xc3\xc3o",
+ u"rld",
+ u"d\xfc"]],
+ [x.encode("utf-8") for x in [u"xt\xea\xean",
+ u"\U00010299ente",
+ u"h\x86een"]]],
+ }[unit]
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
- def _testBroadcast(self, dtype):
+ @parameterized.parameters(
+ (np.int32, "BYTE"),
+ (np.int64, "BYTE"),
+ (np.int32, "UTF8_CHAR"),
+ (np.int64, "UTF8_CHAR"),
+ )
+ def testBroadcast(self, dtype, unit):
# Broadcast pos/len onto input string
- test_string = [[b"ten", b"eleven", b"twelve"],
- [b"thirteen", b"fourteen", b"fifteen"],
- [b"sixteen", b"seventeen", b"eighteen"],
- [b"nineteen", b"twenty", b"twentyone"]]
+ test_string = {
+ "BYTE": [[b"ten", b"eleven", b"twelve"],
+ [b"thirteen", b"fourteen", b"fifteen"],
+ [b"sixteen", b"seventeen", b"eighteen"],
+ [b"nineteen", b"twenty", b"twentyone"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227n",
+ u"\xc6\u053c\u025bv\u025bn",
+ u"tw\u0c1dlv\u025b"]],
+ [x.encode("utf-8") for x in [u"th\xcdrt\xea\xean",
+ u"f\U0001f604urt\xea\xean",
+ u"f\xcd\ua09ctee\ua0e4"]],
+ [x.encode("utf-8") for x in [u"s\xcdxt\xea\xean",
+ u"se\U00010299enteen",
+ u"ei\U0001e920h\x86een"]],
+ [x.encode("utf-8") for x in [u"nineteen",
+ u"twenty",
+ u"twentyone"]]],
+ }[unit]
position = np.array([1, -4, 3], dtype)
length = np.array([1, 2, 3], dtype)
- expected_value = [[b"e", b"ev", b"lve"], [b"h", b"te", b"tee"],
- [b"i", b"te", b"hte"], [b"i", b"en", b"nty"]]
- substr_op = string_ops.substr(test_string, position, length)
+ expected_value = {
+ "BYTE": [[b"e", b"ev", b"lve"], [b"h", b"te", b"tee"],
+ [b"i", b"te", b"hte"], [b"i", b"en", b"nty"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d227",
+ u"\u025bv", u"lv\u025b"]],
+ [x.encode("utf-8") for x in [u"h", u"t\xea", u"tee"]],
+ [x.encode("utf-8") for x in [u"\xcd", u"te", u"h\x86e"]],
+ [x.encode("utf-8") for x in [u"i", u"en", u"nty"]]],
+ }[unit]
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
# Broadcast input string onto pos/len
- test_string = [b"thirteen", b"fourteen", b"fifteen"]
+ test_string = {
+ "BYTE": [b"thirteen", b"fourteen", b"fifteen"],
+ "UTF8_CHAR": [x.encode("utf-8") for x in [u"th\xcdrt\xea\xean",
+ u"f\U0001f604urt\xea\xean",
+ u"f\xcd\ua09ctee\ua0e4"]],
+ }[unit]
position = np.array([[1, -2, 3], [-3, 2, 1], [5, 5, -5]], dtype)
length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype)
- expected_value = [[b"hir", b"en", b"t"], [b"e", b"ur", b"ift"],
- [b"ee", b"ee", b"ft"]]
- substr_op = string_ops.substr(test_string, position, length)
+ expected_value = {
+ "BYTE": [[b"hir", b"en", b"t"], [b"e", b"ur", b"ift"],
+ [b"ee", b"ee", b"ft"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"h\xcdr", u"\xean", u"t"]],
+ [x.encode("utf-8") for x in [u"\xea", u"ur",
+ u"\xcd\ua09ct"]],
+ [x.encode("utf-8") for x in [u"\xea\xea", u"\xea\xea",
+ u"\ua09ct"]]],
+ }[unit]
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
# Test 1D broadcast
- test_string = b"thirteen"
- position = np.array([1, -5, 7], dtype)
+ test_string = {
+ "BYTE": b"thirteen",
+ "UTF8_CHAR": u"th\xcdrt\xea\xean".encode("utf-8"),
+ }[unit]
+ position = np.array([1, -4, 7], dtype)
length = np.array([3, 2, 1], dtype)
- expected_value = [b"hir", b"rt", b"n"]
- substr_op = string_ops.substr(test_string, position, length)
+ expected_value = {
+ "BYTE": [b"hir", b"te", b"n"],
+ "UTF8_CHAR": [x.encode("utf-8") for x in [u"h\xcdr", u"t\xea", u"n"]],
+ }[unit]
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
- def _testBadBroadcast(self, dtype):
+ @parameterized.parameters(
+ (np.int32, "BYTE"),
+ (np.int64, "BYTE"),
+ (np.int32, "UTF8_CHAR"),
+ (np.int64, "UTF8_CHAR"),
+ )
+ def testBadBroadcast(self, dtype, unit):
test_string = [[b"ten", b"eleven", b"twelve"],
[b"thirteen", b"fourteen", b"fifteen"],
[b"sixteen", b"seventeen", b"eighteen"]]
position = np.array([1, 2, -3, 4], dtype)
length = np.array([1, 2, 3, 4], dtype)
with self.assertRaises(ValueError):
- substr_op = string_ops.substr(test_string, position, length)
-
- def _testOutOfRangeError(self, dtype):
+ string_ops.substr(test_string, position, length, unit=unit)
+
+ @parameterized.parameters(
+ (np.int32, 6, "BYTE"),
+ (np.int64, 6, "BYTE"),
+ (np.int32, -6, "BYTE"),
+ (np.int64, -6, "BYTE"),
+ (np.int32, 6, "UTF8_CHAR"),
+ (np.int64, 6, "UTF8_CHAR"),
+ (np.int32, -6, "UTF8_CHAR"),
+ (np.int64, -6, "UTF8_CHAR"),
+ )
+ def testOutOfRangeError_Scalar(self, dtype, pos, unit):
# Scalar/Scalar
- test_string = b"Hello"
- position = np.array(7, dtype)
- length = np.array(3, dtype)
- substr_op = string_ops.substr(test_string, position, length)
- with self.cached_session():
- with self.assertRaises(errors_impl.InvalidArgumentError):
- substr = substr_op.eval()
-
- # Scalar/Scalar (with negative)
- test_string = b"Hello"
- position = np.array(-7, dtype)
+ test_string = {
+ "BYTE": b"Hello",
+ "UTF8_CHAR": u"H\xc3ll\U0001f604".encode("utf-8"),
+ }[unit]
+ position = np.array(pos, dtype)
length = np.array(3, dtype)
- substr_op = string_ops.substr(test_string, position, length)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
- substr = substr_op.eval()
-
+ substr_op.eval()
+
+ @parameterized.parameters(
+ (np.int32, 4, "BYTE"),
+ (np.int64, 4, "BYTE"),
+ (np.int32, -4, "BYTE"),
+ (np.int64, -4, "BYTE"),
+ (np.int32, 4, "UTF8_CHAR"),
+ (np.int64, 4, "UTF8_CHAR"),
+ (np.int32, -4, "UTF8_CHAR"),
+ (np.int64, -4, "UTF8_CHAR"),
+ )
+ def testOutOfRangeError_VectorScalar(self, dtype, pos, unit):
# Vector/Scalar
- test_string = [b"good", b"good", b"bad", b"good"]
- position = np.array(4, dtype)
- length = np.array(1, dtype)
- substr_op = string_ops.substr(test_string, position, length)
- with self.cached_session():
- with self.assertRaises(errors_impl.InvalidArgumentError):
- substr = substr_op.eval()
-
- # Vector/Scalar (with negative)
- test_string = [b"good", b"good", b"bad", b"good"]
- position = np.array(-4, dtype)
+ test_string = {
+ "BYTE": [b"good", b"good", b"bad", b"good"],
+ "UTF8_CHAR": [x.encode("utf-8") for x in [u"g\xc3\xc3d", u"b\xc3d",
+ u"g\xc3\xc3d"]],
+ }[unit]
+ position = np.array(pos, dtype)
length = np.array(1, dtype)
- substr_op = string_ops.substr(test_string, position, length)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
- substr = substr_op.eval()
-
+ substr_op.eval()
+
+ @parameterized.parameters(
+ (np.int32, "BYTE"),
+ (np.int64, "BYTE"),
+ (np.int32, "UTF8_CHAR"),
+ (np.int64, "UTF8_CHAR"),
+ )
+ def testOutOfRangeError_MatrixMatrix(self, dtype, unit):
# Matrix/Matrix
- test_string = [[b"good", b"good", b"good"], [b"good", b"good", b"bad"],
- [b"good", b"good", b"good"]]
+ test_string = {
+ "BYTE": [[b"good", b"good", b"good"], [b"good", b"good", b"bad"],
+ [b"good", b"good", b"good"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d",
+ u"g\xc3\xc3d"]],
+ [x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d",
+ u"b\xc3d"]],
+ [x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d",
+ u"g\xc3\xc3d"]]],
+ }[unit]
position = np.array([[1, 2, 3], [1, 2, 4], [1, 2, 3]], dtype)
length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype)
- substr_op = string_ops.substr(test_string, position, length)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
- substr = substr_op.eval()
+ substr_op.eval()
# Matrix/Matrix (with negative)
- test_string = [[b"good", b"good", b"good"], [b"good", b"good", b"bad"],
- [b"good", b"good", b"good"]]
position = np.array([[1, 2, -3], [1, 2, -4], [1, 2, -3]], dtype)
length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype)
- substr_op = string_ops.substr(test_string, position, length)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
- substr = substr_op.eval()
-
+ substr_op.eval()
+
+ @parameterized.parameters(
+ (np.int32, "BYTE"),
+ (np.int64, "BYTE"),
+ (np.int32, "UTF8_CHAR"),
+ (np.int64, "UTF8_CHAR"),
+ )
+ def testOutOfRangeError_Broadcast(self, dtype, unit):
# Broadcast
- test_string = [[b"good", b"good", b"good"], [b"good", b"good", b"bad"]]
+ test_string = {
+ "BYTE": [[b"good", b"good", b"good"], [b"good", b"good", b"bad"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d",
+ u"g\xc3\xc3d"]],
+ [x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d",
+ u"b\xc3d"]]],
+ }[unit]
position = np.array([1, 2, 4], dtype)
length = np.array([1, 2, 3], dtype)
- substr_op = string_ops.substr(test_string, position, length)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
- substr = substr_op.eval()
+ substr_op.eval()
# Broadcast (with negative)
- test_string = [[b"good", b"good", b"good"], [b"good", b"good", b"bad"]]
position = np.array([-1, -2, -4], dtype)
length = np.array([1, 2, 3], dtype)
- substr_op = string_ops.substr(test_string, position, length)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
- substr = substr_op.eval()
-
- def _testMismatchPosLenShapes(self, dtype):
- test_string = [[b"ten", b"eleven", b"twelve"],
- [b"thirteen", b"fourteen", b"fifteen"],
- [b"sixteen", b"seventeen", b"eighteen"]]
+ substr_op.eval()
+
+ @parameterized.parameters(
+ (np.int32, "BYTE"),
+ (np.int64, "BYTE"),
+ (np.int32, "UTF8_CHAR"),
+ (np.int64, "UTF8_CHAR"),
+ )
+ def testMismatchPosLenShapes(self, dtype, unit):
+ test_string = {
+ "BYTE": [[b"ten", b"eleven", b"twelve"],
+ [b"thirteen", b"fourteen", b"fifteen"],
+ [b"sixteen", b"seventeen", b"eighteen"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227n",
+ u"\xc6\u053c\u025bv\u025bn",
+ u"tw\u0c1dlv\u025b"]],
+ [x.encode("utf-8") for x in [u"th\xcdrt\xea\xean",
+ u"f\U0001f604urt\xea\xean",
+ u"f\xcd\ua09ctee\ua0e4"]],
+ [x.encode("utf-8") for x in [u"s\xcdxt\xea\xean",
+ u"se\U00010299enteen",
+ u"ei\U0001e920h\x86een"]]],
+ }[unit]
 position = np.array([[1, 2, 3]], dtype)
 length = np.array([2, 3, 4], dtype)
 # Should fail: position/length have different rank
 with self.assertRaises(ValueError):
- substr_op = string_ops.substr(test_string, position, length)
+ string_ops.substr(test_string, position, length, unit=unit)
 position = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]], dtype)
 length = np.array([[2, 3, 4]], dtype)
 # Should fail: position/length have different dimensionality
 with self.assertRaises(ValueError):
- substr_op = string_ops.substr(test_string, position, length)
-
- # Negative position.
- test_string = [[b"ten", b"eleven", b"twelve"],
- [b"thirteen", b"fourteen", b"fifteen"],
- [b"sixteen", b"seventeen", b"eighteen"]]
- position = np.array([[-1, -2, -3]], dtype)
- length = np.array([1, 2, 3], dtype)
- # Should fail: position/length have different rank
- with self.assertRaises(ValueError):
- substr_op = string_ops.substr(test_string, position, length)
-
- @parameterized.parameters(np.int32, np.int64)
- def testAll(self, dtype):
- self._testScalarString(dtype)
- self._testVectorStrings(dtype)
- self._testMatrixStrings(dtype)
- self._testElementWisePosLen(dtype)
- self._testBroadcast(dtype)
- self._testBadBroadcast(dtype)
- self._testOutOfRangeError(dtype)
- self._testMismatchPosLenShapes(dtype)
+ string_ops.substr(test_string, position, length, unit=unit)
def testWrongDtype(self):
with self.cached_session():
@@ -300,6 +478,11 @@ class SubstrOpTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(TypeError):
string_ops.substr(b"test", 3, 1.0)
+ def testInvalidUnit(self):
+ with self.cached_session():
+ with self.assertRaises(ValueError):
+ string_ops.substr(b"test", 3, 1, unit="UTF8")
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py
index 0812f901a2..f26388efea 100644
--- a/tensorflow/python/ops/string_ops.py
+++ b/tensorflow/python/ops/string_ops.py
@@ -347,6 +347,22 @@ def string_length(input, name=None, unit="BYTE"):
string_length.__doc__ = gen_string_ops.string_length.__doc__
+@tf_export("substr")
+@deprecation.deprecated(None, "Use `tf.strings.substr` instead of `tf.substr`.")
+def substr_deprecated(input, pos, len, name=None, unit="BYTE"):
+ return substr(input, pos, len, name=name, unit=unit)
+
+substr_deprecated.__doc__ = gen_string_ops.substr.__doc__
+
+
+@tf_export("strings.substr")
+def substr(input, pos, len, name=None, unit="BYTE"):
+ return gen_string_ops.substr(input, pos, len, unit=unit, name=name)
+
+
+substr.__doc__ = gen_string_ops.substr.__doc__
+
+
ops.NotDifferentiable("RegexReplace")
ops.NotDifferentiable("StringToHashBucket")
ops.NotDifferentiable("StringToHashBucketFast")
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
index c1cc7322f0..247dfcc1ca 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
@@ -2094,7 +2094,7 @@ tf_module {
}
member_method {
name: "substr"
- argspec: "args=[\'input\', \'pos\', \'len\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'input\', \'pos\', \'len\', \'name\', \'unit\'], varargs=None, keywords=None, defaults=[\'None\', \'BYTE\'], "
}
member_method {
name: "subtract"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
index ebdaf57231..5ba48e7f57 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
@@ -34,7 +34,7 @@ tf_module {
}
member_method {
name: "substr"
- argspec: "args=[\'input\', \'pos\', \'len\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'input\', \'pos\', \'len\', \'name\', \'unit\'], varargs=None, keywords=None, defaults=[\'None\', \'BYTE\'], "
}
member_method {
name: "to_hash_bucket"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
index 571abc3b19..978afcf985 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
@@ -1934,7 +1934,7 @@ tf_module {
}
member_method {
name: "substr"
- argspec: "args=[\'input\', \'pos\', \'len\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'input\', \'pos\', \'len\', \'name\', \'unit\'], varargs=None, keywords=None, defaults=[\'None\', \'BYTE\'], "
}
member_method {
name: "subtract"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
index ebdaf57231..5ba48e7f57 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
@@ -34,7 +34,7 @@ tf_module {
}
member_method {
name: "substr"
- argspec: "args=[\'input\', \'pos\', \'len\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'input\', \'pos\', \'len\', \'name\', \'unit\'], varargs=None, keywords=None, defaults=[\'None\', \'BYTE\'], "
}
member_method {
name: "to_hash_bucket"