diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-25 11:56:33 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-25 12:07:36 -0700 |
commit | d5c5df164cedcd8ae43fff41256592818bc6c2de (patch) | |
tree | 0b07685ff599ad9bfe93a58386c67adb94e57c8e /tensorflow/core/kernels | |
parent | df930015230c1195065e2fd01c61f527b8662efb (diff) |
Add "encoding" attribute to string length op, which controls how "string length" is defined:
* BYTE: The number of bytes in each string. (Default)
* UTF8: The number of UTF-8 encoded Unicode code points in each string.
RELNOTES: Add option to calculate string length in Unicode characters
PiperOrigin-RevId: 214478470
Diffstat (limited to 'tensorflow/core/kernels')
-rw-r--r-- | tensorflow/core/kernels/BUILD | 10 | ||||
-rw-r--r-- | tensorflow/core/kernels/string_length_op.cc | 23 | ||||
-rw-r--r-- | tensorflow/core/kernels/string_util.cc | 63 | ||||
-rw-r--r-- | tensorflow/core/kernels/string_util.h | 45 |
4 files changed, 138 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index ab69925d04..1a3db2c7cd 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -4434,8 +4434,16 @@ cc_library( ], ) +cc_library( + name = "string_util", + srcs = ["string_util.cc"], + hdrs = ["string_util.h"], + deps = ["//tensorflow/core:lib"], +) + STRING_DEPS = [ ":bounds_check", + ":string_util", "//third_party/eigen3", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -5166,6 +5174,7 @@ filegroup( "spacetobatch_functor.h", "spacetodepth_op.h", "spectrogram.h", + "string_util.h", "tensor_array.h", "tile_functor.h", "tile_ops_cpu_impl.h", @@ -5334,6 +5343,7 @@ filegroup( "spectrogram_op.cc", "stack_ops.cc", "string_join_op.cc", + "string_util.cc", "summary_op.cc", "tensor_array.cc", "tensor_array_ops.cc", diff --git a/tensorflow/core/kernels/string_length_op.cc b/tensorflow/core/kernels/string_length_op.cc index a6829b29d9..435a7abdca 100644 --- a/tensorflow/core/kernels/string_length_op.cc +++ b/tensorflow/core/kernels/string_length_op.cc @@ -14,13 +14,18 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/string_util.h" namespace tensorflow { namespace { class StringLengthOp : public OpKernel { public: - using OpKernel::OpKernel; + explicit StringLengthOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + string unit; + OP_REQUIRES_OK(ctx, ctx->GetAttr("unit", &unit)); + OP_REQUIRES_OK(ctx, ParseCharUnit(unit, &unit_)); + } void Compute(OpKernelContext* context) override { const Tensor& input = context->input(0); @@ -32,10 +37,22 @@ class StringLengthOp : public OpKernel { auto src = input.flat<string>(); auto dst = output->flat<int32>(); - for (int n = 0; n < src.size(); ++n) { - dst(n) = src(n).size(); + switch (unit_) { + case CharUnit::BYTE: + for (int n = 0; n < src.size(); ++n) { + dst(n) = src(n).size(); + } + break; + case CharUnit::UTF8_CHAR: + for (int n = 0; n < src.size(); ++n) { + dst(n) = UTF8StrLen(src(n)); + } + break; } } + + private: + CharUnit unit_ = CharUnit::BYTE; }; REGISTER_KERNEL_BUILDER(Name("StringLength").Device(DEVICE_CPU), diff --git a/tensorflow/core/kernels/string_util.cc b/tensorflow/core/kernels/string_util.cc new file mode 100644 index 0000000000..3a9803a052 --- /dev/null +++ b/tensorflow/core/kernels/string_util.cc @@ -0,0 +1,63 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/kernels/string_util.h" + +#include "tensorflow/core/lib/core/errors.h" + +namespace { +inline bool IsTrailByte(char x) { return static_cast<signed char>(x) < -0x40; } +} // namespace + +namespace tensorflow { + +// Sets unit value based on str. +Status ParseUnicodeEncoding(const string& str, UnicodeEncoding* encoding) { + if (str == "UTF8") { + *encoding = UnicodeEncoding::UTF8; + } else { + return errors::InvalidArgument(strings::StrCat( + "Invalid encoding \"", str, "\": Should be one of: BYTE")); + } + return Status::OK(); +} + +// Sets unit value based on str. +Status ParseCharUnit(const string& str, CharUnit* unit) { + if (str == "BYTE") { + *unit = CharUnit::BYTE; + } else if (str == "UTF8_CHAR") { + *unit = CharUnit::UTF8_CHAR; + } else { + return errors::InvalidArgument(strings::StrCat( + "Invalid unit \"", str, "\": Should be one of: BYTE, UTF8_CHAR")); + } + return Status::OK(); +} + +// Return the number of Unicode characters in a UTF-8 string. +// Result may be incorrect if the input string is not valid UTF-8. +int32 UTF8StrLen(const string& string) { + const int32 byte_size = string.size(); + const char* const end = string.data() + byte_size; + const char* ptr = string.data(); + int32 skipped_count = 0; + while (ptr < end) { + skipped_count += IsTrailByte(*ptr++) ? 1 : 0; + } + const int32 result = byte_size - skipped_count; + return result; +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/string_util.h b/tensorflow/core/kernels/string_util.h new file mode 100644 index 0000000000..390cf57702 --- /dev/null +++ b/tensorflow/core/kernels/string_util.h @@ -0,0 +1,45 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_ +#define TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_ + +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Enumeration for unicode encodings. Used by ops such as +// tf.strings.unicode_encode and tf.strings.unicode_decode. +// TODO(edloper): Add support for: +// UTF16, UTF32, UTF16BE, UTF32BE, UTF16LE, UTF32LE +enum class UnicodeEncoding { UTF8 }; + +// Enumeration for character units. Used by string such as +// tf.strings.length and tf.substr. +// TODO(edloper): Add support for: UTF32_CHAR, etc. +enum class CharUnit { BYTE, UTF8_CHAR }; + +// Sets `encoding` based on `str`. +Status ParseUnicodeEncoding(const string& str, UnicodeEncoding* encoding); + +// Sets `unit` value based on `str`. +Status ParseCharUnit(const string& str, CharUnit* unit); + +// Returns the number of Unicode characters in a UTF-8 string. +// Result may be incorrect if the input string is not valid UTF-8. +int32 UTF8StrLen(const string& string); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_ |