diff options
Diffstat (limited to 'tensorflow/core/kernels/string_split_op.cc')
-rw-r--r-- | tensorflow/core/kernels/string_split_op.cc | 153 |
1 files changed, 153 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/string_split_op.cc b/tensorflow/core/kernels/string_split_op.cc
index 4c2b312c34..26ab72f12e 100644
--- a/tensorflow/core/kernels/string_split_op.cc
+++ b/tensorflow/core/kernels/string_split_op.cc
@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 
 namespace tensorflow {
@@ -43,6 +44,72 @@ std::vector<string> Split(const string& str, const string& delimiter,
   return char_vector;
 }
 
+// Splits `str` following the rules of Python's str.split.
+//
+// With a non-empty `sep`, consecutive delimiters are not grouped together and
+// are deemed to delimit empty strings (for example, '1,,2'.split(',') returns
+// ['1', '', '2']). `sep` may consist of multiple characters (for example,
+// '1<>2<>3'.split('<>') returns ['1', '2', '3']). Splitting an empty string
+// with a specified separator returns [''].
+//
+// With an empty `sep` (the "None" separator), runs of consecutive whitespace
+// are regarded as a single separator, and the result contains no empty strings
+// at the start or end if the string has leading or trailing whitespace.
+// Consequently, splitting an empty string or a string consisting of just
+// whitespace returns [].
+//
+// `maxsplit` > 0 caps the number of splits, so at most maxsplit + 1 tokens are
+// produced; `maxsplit` == 0 returns the whole input as a single token; a
+// negative `maxsplit` places no limit on the number of splits.
+std::vector<string> SplitV2(const string& str, StringPiece sep, int maxsplit) {
+  std::vector<string> result;
+
+  StringPiece text(str);
+  if (maxsplit == 0) {
+    result.emplace_back(std::string(text));
+    return result;
+  }
+
+  if (sep.empty()) {
+    StringPiece token;
+    // Remove leading whitespaces.
+    str_util::RemoveLeadingWhitespace(&text);
+    int split = 0;
+    while (str_util::ConsumeNonWhitespace(&text, &token)) {
+      result.emplace_back(std::string(token));
+      str_util::RemoveLeadingWhitespace(&text);
+      ++split;
+      if (maxsplit > 0 && split == maxsplit) {
+        // The split budget is spent, so whatever is left becomes the final
+        // token. An empty remainder is dropped: whitespace splitting never
+        // yields an empty string at the end of the result.
+        if (!text.empty()) {
+          result.emplace_back(std::string(text));
+        }
+        return result;
+      }
+    }
+    return result;
+  }
+
+  auto p = std::search(text.begin(), text.end(), sep.begin(), sep.end());
+  int split = 0;
+  while (p != text.end()) {
+    StringPiece token = text.substr(0, p - text.begin());
+    result.emplace_back(std::string(token));
+    // Drop the token and the separator that follows it.
+    text.remove_prefix(token.size() + sep.size());
+    ++split;
+    if (maxsplit > 0 && split == maxsplit) {
+      result.emplace_back(std::string(text));
+      return result;
+    }
+    p = std::search(text.begin(), text.end(), sep.begin(), sep.end());
+  }
+  result.emplace_back(std::string(text));
+  return result;
+}
+
 }  // namespace
 
 class StringSplitOp : public OpKernel {
@@ -122,6 +189,92 @@ class StringSplitOp : public OpKernel {
   bool skip_empty_;
 };
 
+// Kernel for the "StringSplitV2" op. Splits every element of a string vector
+// with SplitV2 and emits three outputs: an [N, 2] matrix of (row, position)
+// indices, a length-N vector holding the tokens, and a 2-element shape
+// {batch_size, largest token count of any row}, where N is the total number
+// of tokens.
+class StringSplitV2Op : public OpKernel {
+ public:
+  explicit StringSplitV2Op(OpKernelConstruction* context)
+      : OpKernel(context), maxsplit_(-1) {
+    OP_REQUIRES_OK(context, context->GetAttr("maxsplit", &maxsplit_));
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor* input_tensor;
+    OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(input_tensor->shape()),
+                errors::InvalidArgument("input must be a vector, got shape: ",
+                                        input_tensor->shape().DebugString()));
+
+    const auto input_vec = input_tensor->vec<string>();
+    const int64 batch_size = input_vec.dimension(0);
+
+    const Tensor* sep_tensor;
+    OP_REQUIRES_OK(ctx, ctx->input("sep", &sep_tensor));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(sep_tensor->shape()),
+                errors::InvalidArgument("sep must be a scalar, got shape: ",
+                                        sep_tensor->shape().DebugString()));
+    const auto sep_vec = sep_tensor->flat<string>();
+    StringPiece sep(sep_vec(0));
+
+    std::vector<string> tokens;
+    // Guess that we'll be unpacking a handful of tokens per example.
+    static constexpr int kReserveSize = 4;
+    tokens.reserve(batch_size * kReserveSize);
+
+    // First pass: split every row and record its token count, so the outputs
+    // can be sized before they are filled.
+    int64 output_size = 0;
+    int64 max_num_entries = 0;
+    std::vector<int64> num_indices(batch_size);
+    for (int64 i = 0; i < batch_size; ++i) {
+      std::vector<string> parts = SplitV2(input_vec(i), sep, maxsplit_);
+      const int64 n_entries = parts.size();
+      num_indices[i] = n_entries;
+      output_size += n_entries;
+      max_num_entries = std::max(max_num_entries, n_entries);
+      tokens.insert(tokens.end(), parts.begin(), parts.end());
+    }
+
+    Tensor* sp_indices_t;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({output_size, 2}),
+                                             &sp_indices_t));
+    Tensor* sp_tokens_t;
+    OP_REQUIRES_OK(
+        ctx, ctx->allocate_output(1, TensorShape({output_size}), &sp_tokens_t));
+    Tensor* sp_shape_t;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(2, TensorShape({2}), &sp_shape_t));
+
+    auto sp_indices = sp_indices_t->matrix<int64>();
+    auto sp_tokens = sp_tokens_t->vec<string>();
+    auto sp_shape = sp_shape_t->vec<int64>();
+    sp_shape(0) = batch_size;
+    sp_shape(1) = max_num_entries;
+
+    // Second pass: `tokens` is stored row by row, so a single running index
+    // visits the tokens in the same order as the (row, position) pairs.
+    // All counters are int64 to match batch_size and num_indices, avoiding
+    // signed/unsigned comparisons.
+    int64 c = 0;
+    for (int64 i = 0; i < batch_size; ++i) {
+      for (int64 j = 0; j < num_indices[i]; ++j) {
+        sp_indices(c, 0) = i;
+        sp_indices(c, 1) = j;
+        sp_tokens(c) = tokens[c];
+        ++c;
+      }
+    }
+  }
+
+ private:
+  // Value of the "maxsplit" attribute, forwarded to SplitV2.
+  int maxsplit_;
+};
+
 REGISTER_KERNEL_BUILDER(Name("StringSplit").Device(DEVICE_CPU), StringSplitOp);
+REGISTER_KERNEL_BUILDER(Name("StringSplitV2").Device(DEVICE_CPU),
+                        StringSplitV2Op);
 
 }  // namespace tensorflow