diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-25 11:56:33 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-25 12:07:36 -0700 |
commit | d5c5df164cedcd8ae43fff41256592818bc6c2de (patch) | |
tree | 0b07685ff599ad9bfe93a58386c67adb94e57c8e /tensorflow | |
parent | df930015230c1195065e2fd01c61f527b8662efb (diff) |
Add "encoding" attribute to string length op, which controls how "string length" is defined:
* BYTE: The number of bytes in each string. (Default)
* UTF8: The number of UTF-8 encoded Unicode code points in each string.
RELNOTES: Add option to calculate string length in Unicode characters
PiperOrigin-RevId: 214478470
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/contrib/makefile/tf_op_files.txt | 1 | ||||
-rw-r--r-- | tensorflow/core/api_def/base_api/api_def_StringLength.pbtxt | 10 | ||||
-rw-r--r-- | tensorflow/core/api_def/python_api/api_def_StringLength.pbtxt | 4 | ||||
-rw-r--r-- | tensorflow/core/kernels/BUILD | 10 | ||||
-rw-r--r-- | tensorflow/core/kernels/string_length_op.cc | 23 | ||||
-rw-r--r-- | tensorflow/core/kernels/string_util.cc | 63 | ||||
-rw-r--r-- | tensorflow/core/kernels/string_util.h | 45 | ||||
-rw-r--r-- | tensorflow/core/ops/string_ops.cc | 1 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/string_length_op_test.py | 27 | ||||
-rw-r--r-- | tensorflow/python/ops/string_ops.py | 13 | ||||
-rw-r--r-- | tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt | 2 | ||||
-rw-r--r-- | tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt | 2 |
12 files changed, 193 insertions, 8 deletions
diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index 08de54b8e1..f81a90809a 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -253,6 +253,7 @@ tensorflow/core/kernels/strided_slice_op_inst_5.cc tensorflow/core/kernels/strided_slice_op_inst_6.cc tensorflow/core/kernels/strided_slice_op_inst_7.cc tensorflow/core/kernels/string_join_op.cc +tensorflow/core/kernels/string_util.cc tensorflow/core/kernels/tensor_array.cc tensorflow/core/kernels/tensor_array_ops.cc tensorflow/core/kernels/tile_functor_cpu.cc diff --git a/tensorflow/core/api_def/base_api/api_def_StringLength.pbtxt b/tensorflow/core/api_def/base_api/api_def_StringLength.pbtxt index cc21ddc815..7d2fbcd00b 100644 --- a/tensorflow/core/api_def/base_api/api_def_StringLength.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_StringLength.pbtxt @@ -1,5 +1,15 @@ op { graph_op_name: "StringLength" + attr { + name: "unit" + description: <<END +The unit that is counted to compute string length. One of: `"BYTE"` (for +the number of bytes in each string) or `"UTF8_CHAR"` (for the number of UTF-8 +encoded Unicode code points in each string). Results are undefined +if `unit=UTF8_CHAR` and the `input` strings do not contain structurally +valid UTF-8. +END + } in_arg { name: "input" description: <<END diff --git a/tensorflow/core/api_def/python_api/api_def_StringLength.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringLength.pbtxt index 01c02e1f70..df012414e3 100644 --- a/tensorflow/core/api_def/python_api/api_def_StringLength.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_StringLength.pbtxt @@ -1,6 +1,4 @@ op { graph_op_name: "StringLength" - endpoint { - name: "strings.length" - } + visibility: HIDDEN } diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index ab69925d04..1a3db2c7cd 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -4434,8 +4434,16 @@ cc_library( ], ) +cc_library( + name = "string_util", + srcs = ["string_util.cc"], + hdrs = ["string_util.h"], + deps = ["//tensorflow/core:lib"], +) + STRING_DEPS = [ ":bounds_check", + ":string_util", "//third_party/eigen3", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -5166,6 +5174,7 @@ filegroup( "spacetobatch_functor.h", "spacetodepth_op.h", "spectrogram.h", + "string_util.h", "tensor_array.h", "tile_functor.h", "tile_ops_cpu_impl.h", @@ -5334,6 +5343,7 @@ filegroup( "spectrogram_op.cc", "stack_ops.cc", "string_join_op.cc", + "string_util.cc", "summary_op.cc", "tensor_array.cc", "tensor_array_ops.cc", diff --git a/tensorflow/core/kernels/string_length_op.cc b/tensorflow/core/kernels/string_length_op.cc index a6829b29d9..435a7abdca 100644 --- a/tensorflow/core/kernels/string_length_op.cc +++ b/tensorflow/core/kernels/string_length_op.cc @@ -14,13 +14,18 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/string_util.h" namespace tensorflow { namespace { class StringLengthOp : public OpKernel { public: - using OpKernel::OpKernel; + explicit StringLengthOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + string unit; + OP_REQUIRES_OK(ctx, ctx->GetAttr("unit", &unit)); + OP_REQUIRES_OK(ctx, ParseCharUnit(unit, &unit_)); + } void Compute(OpKernelContext* context) override { const Tensor& input = context->input(0); @@ -32,10 +37,22 @@ class StringLengthOp : public OpKernel { auto src = input.flat<string>(); auto dst = output->flat<int32>(); - for (int n = 0; n < src.size(); ++n) { - dst(n) = src(n).size(); + switch (unit_) { + case CharUnit::BYTE: + for (int n = 0; n < src.size(); ++n) { + dst(n) = src(n).size(); + } + break; + case CharUnit::UTF8_CHAR: + for (int n = 0; n < src.size(); ++n) { + dst(n) = UTF8StrLen(src(n)); + } + break; } } + + private: + CharUnit unit_ = CharUnit::BYTE; }; REGISTER_KERNEL_BUILDER(Name("StringLength").Device(DEVICE_CPU), diff --git a/tensorflow/core/kernels/string_util.cc b/tensorflow/core/kernels/string_util.cc new file mode 100644 index 0000000000..3a9803a052 --- /dev/null +++ b/tensorflow/core/kernels/string_util.cc @@ -0,0 +1,63 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/kernels/string_util.h" + +#include "tensorflow/core/lib/core/errors.h" + +namespace { +inline bool IsTrailByte(char x) { return static_cast<signed char>(x) < -0x40; } +} // namespace + +namespace tensorflow { + +// Sets unit value based on str. +Status ParseUnicodeEncoding(const string& str, UnicodeEncoding* encoding) { + if (str == "UTF8") { + *encoding = UnicodeEncoding::UTF8; + } else { + return errors::InvalidArgument(strings::StrCat( + "Invalid encoding \"", str, "\": Should be one of: BYTE")); + } + return Status::OK(); +} + +// Sets unit value based on str. +Status ParseCharUnit(const string& str, CharUnit* unit) { + if (str == "BYTE") { + *unit = CharUnit::BYTE; + } else if (str == "UTF8_CHAR") { + *unit = CharUnit::UTF8_CHAR; + } else { + return errors::InvalidArgument(strings::StrCat( + "Invalid unit \"", str, "\": Should be one of: BYTE, UTF8_CHAR")); + } + return Status::OK(); +} + +// Return the number of Unicode characters in a UTF-8 string. +// Result may be incorrect if the input string is not valid UTF-8. +int32 UTF8StrLen(const string& string) { + const int32 byte_size = string.size(); + const char* const end = string.data() + byte_size; + const char* ptr = string.data(); + int32 skipped_count = 0; + while (ptr < end) { + skipped_count += IsTrailByte(*ptr++) ? 1 : 0; + } + const int32 result = byte_size - skipped_count; + return result; +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/string_util.h b/tensorflow/core/kernels/string_util.h new file mode 100644 index 0000000000..390cf57702 --- /dev/null +++ b/tensorflow/core/kernels/string_util.h @@ -0,0 +1,45 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_ +#define TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_ + +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Enumeration for unicode encodings. Used by ops such as +// tf.strings.unicode_encode and tf.strings.unicode_decode. +// TODO(edloper): Add support for: +// UTF16, UTF32, UTF16BE, UTF32BE, UTF16LE, UTF32LE +enum class UnicodeEncoding { UTF8 }; + +// Enumeration for character units. Used by string such as +// tf.strings.length and tf.substr. +// TODO(edloper): Add support for: UTF32_CHAR, etc. +enum class CharUnit { BYTE, UTF8_CHAR }; + +// Sets `encoding` based on `str`. +Status ParseUnicodeEncoding(const string& str, UnicodeEncoding* encoding); + +// Sets `unit` value based on `str`. +Status ParseCharUnit(const string& str, CharUnit* unit); + +// Returns the number of Unicode characters in a UTF-8 string. +// Result may be incorrect if the input string is not valid UTF-8. +int32 UTF8StrLen(const string& string); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_ diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc index 99159839d0..da1d2a6432 100644 --- a/tensorflow/core/ops/string_ops.cc +++ b/tensorflow/core/ops/string_ops.cc @@ -203,6 +203,7 @@ REGISTER_OP("StringStrip") REGISTER_OP("StringLength") .Input("input: string") .Output("output: int32") + .Attr("unit: {'BYTE', 'UTF8_CHAR'} = 'BYTE'") .SetShapeFn(shape_inference::UnchangedShape); REGISTER_OP("EncodeBase64") diff --git a/tensorflow/python/kernel_tests/string_length_op_test.py b/tensorflow/python/kernel_tests/string_length_op_test.py index 9f013c2c7e..4afe3ad3f4 100644 --- a/tensorflow/python/kernel_tests/string_length_op_test.py +++ b/tensorflow/python/kernel_tests/string_length_op_test.py @@ -32,6 +32,33 @@ class StringLengthOpTest(test.TestCase): values = sess.run(lengths) self.assertAllEqual(values, [[[1, 2], [3, 4], [5, 6]]]) + def testUnit(self): + unicode_strings = [u"H\xc3llo", u"\U0001f604"] + utf8_strings = [s.encode("utf-8") for s in unicode_strings] + expected_utf8_byte_lengths = [6, 4] + expected_utf8_char_lengths = [5, 1] + + with self.test_session() as sess: + utf8_byte_lengths = string_ops.string_length(utf8_strings, unit="BYTE") + utf8_char_lengths = string_ops.string_length( + utf8_strings, unit="UTF8_CHAR") + self.assertAllEqual( + sess.run(utf8_byte_lengths), expected_utf8_byte_lengths) + self.assertAllEqual( + sess.run(utf8_char_lengths), expected_utf8_char_lengths) + with self.assertRaisesRegexp( + ValueError, "Attr 'unit' of 'StringLength' Op passed string 'XYZ' " + 'not in: "BYTE", "UTF8_CHAR"'): + string_ops.string_length(utf8_strings, unit="XYZ") + + def testLegacyPositionalName(self): + # Code that predates the 'unit' parameter may have used a positional + # argument for the 'name' parameter. Check that we don't break such code. + strings = [[["1", "12"], ["123", "1234"], ["12345", "123456"]]] + lengths = string_ops.string_length(strings, "some_name") + with self.test_session(): + self.assertAllEqual(lengths.eval(), [[[1, 2], [3, 4], [5, 6]]]) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py index 5d949467fd..046a48d192 100644 --- a/tensorflow/python/ops/string_ops.py +++ b/tensorflow/python/ops/string_ops.py @@ -36,10 +36,12 @@ from tensorflow.python.ops import math_ops # go/tf-wildcard-import # pylint: disable=wildcard-import +# pylint: disable=g-bad-import-order from tensorflow.python.ops.gen_string_ops import * from tensorflow.python.util import compat as util_compat from tensorflow.python.util import deprecation from tensorflow.python.util.tf_export import tf_export +# pylint: enable=g-bad-import-order # pylint: enable=wildcard-import @@ -328,6 +330,17 @@ def reduce_join(inputs, axis=None, reduce_join.__doc__ = deprecation.rewrite_argument_docstring( gen_string_ops.reduce_join.__doc__, "reduction_indices", "axis") + +# This wrapper provides backwards compatibility for code that predates the +# unit argument and that passed 'name' as a positional argument. +@tf_export("strings.length") +def string_length(input, name=None, unit="BYTE"): + return gen_string_ops.string_length(input, unit=unit, name=name) + + +string_length.__doc__ = gen_string_ops.string_length.__doc__ + + ops.NotDifferentiable("RegexReplace") ops.NotDifferentiable("StringToHashBucket") ops.NotDifferentiable("StringToHashBucketFast") diff --git a/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt index c81c156518..c52581dec1 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt @@ -10,7 +10,7 @@ tf_module { } member_method { name: "length" - argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'input\', \'name\', \'unit\'], varargs=None, keywords=None, defaults=[\'None\', \'BYTE\'], " } member_method { name: "regex_full_match" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt index c81c156518..c52581dec1 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt @@ -10,7 +10,7 @@ tf_module { } member_method { name: "length" - argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'input\', \'name\', \'unit\'], varargs=None, keywords=None, defaults=[\'None\', \'BYTE\'], " } member_method { name: "regex_full_match" |