From d44142b807bba47464d2a873e2dfcd641236591e Mon Sep 17 00:00:00 2001
From: Jeremiah Harmsen
Date: Mon, 20 Aug 2018 09:42:49 -0700
Subject: RELNOTES: Performance enhancements for StringSplitOp & StringSplitV2Op.

PiperOrigin-RevId: 209432936
---
 tensorflow/core/kernels/BUILD                   |  18 +++
 tensorflow/core/kernels/string_split_op.cc      | 111 +++++++++++++-----
 tensorflow/core/kernels/string_split_op_test.cc | 129 +++++++++++++++++++++
 .../python/kernel_tests/string_split_op_test.py |  22 +++-
 4 files changed, 250 insertions(+), 30 deletions(-)
 create mode 100644 tensorflow/core/kernels/string_split_op_test.cc

diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index afe366a194..e07d292629 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -4475,6 +4475,24 @@ tf_kernel_library(
     deps = STRING_DEPS,
 )
 
+tf_cc_test(
+    name = "string_split_op_test",
+    size = "small",
+    srcs = ["string_split_op_test.cc"],
+    deps = [
+        ":string_split_op",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+        "//tensorflow/core/kernels:ops_testutil",
+        "//tensorflow/core/kernels:ops_util",
+    ],
+)
+
 tf_kernel_library(
     name = "string_strip_op",
     prefix = "string_strip_op",
diff --git a/tensorflow/core/kernels/string_split_op.cc b/tensorflow/core/kernels/string_split_op.cc
index 26ab72f12e..3884370a6c 100644
--- a/tensorflow/core/kernels/string_split_op.cc
+++ b/tensorflow/core/kernels/string_split_op.cc
@@ -26,25 +26,81 @@ limitations under the License.
 #include "tensorflow/core/lib/strings/str_util.h"
 
 namespace tensorflow {
-
 namespace {
+// Split input string `str` based on a character delimiter.
+// Returns a vector of StringPieces which are valid as long as input `str`
+// is valid.
+// Note: The single character delimiter is a common case and is implemented as
+// a series of finds in the input string, making it much more efficient than
+// SplitOnCharSet.
+template <typename Predicate>
+std::vector<StringPiece> SplitOnChar(const string& str, const char delim,
+                                     Predicate p) {
+  std::vector<StringPiece> result;
+  StringPiece text(str);
+  auto f = text.find(delim);
+  while (f != StringPiece::npos) {
+    StringPiece token = text.substr(0, f);
+    if (p(token)) {
+      result.emplace_back(token);
+    }
+    text.remove_prefix(f + 1);
+    f = text.find(delim);
+  }
+  if (p(text)) {
+    result.push_back(text);
+  }
+  return result;
+}
 
-std::vector<string> Split(const string& str, const string& delimiter,
-                          const bool skipEmpty) {
-  if (!delimiter.empty()) {
-    if (skipEmpty) {
-      return str_util::Split(str, delimiter, str_util::SkipEmpty());
+// Split input string `str` based on a set of character delimiters.
+// Returns a vector of StringPieces which are valid as long as input `str`
+// is valid.
+// Based on str_util::Split.
+template <typename Predicate>
+std::vector<StringPiece> SplitOnCharSet(const string& str,
+                                        const string& delim_set, Predicate p) {
+  std::vector<StringPiece> result;
+  StringPiece text(str);
+  StringPiece delims(delim_set);
+  size_t token_start = 0;
+  for (size_t i = 0; i < text.size() + 1; i++) {
+    if ((i == text.size()) || (delims.find(text[i]) != StringPiece::npos)) {
+      StringPiece token(text.data() + token_start, i - token_start);
+      if (p(token)) {
+        result.emplace_back(token);
+      }
+      token_start = i + 1;
     }
-    return str_util::Split(str, delimiter);
   }
-  std::vector<string> char_vector(str.size());
-  for (size_t i = 0; i < str.size(); ++i) {
-    char_vector[i] = str[i];
+  return result;
+}
+
+// Split input string `str` based on given delimiter.
+// Returns a vector of StringPieces which are valid as long as input `str`
+// is valid.
+template <typename Predicate>
+std::vector<StringPiece> Split(const string& str, const string& delimiter,
+                               Predicate predicate) {
+  if (str.empty()) {
+    return std::vector<StringPiece>();
+  }
+  if (delimiter.empty()) {
+    std::vector<StringPiece> result;
+    result.resize(str.size());
+    for (size_t i = 0; i < str.size(); ++i) {
+      result[i] = StringPiece(str.data() + i, 1);
+    }
+    return result;
   }
-  return char_vector;
+  if (delimiter.size() == 1) {
+    return SplitOnChar(str, delimiter[0], predicate);
+  }
+  return SplitOnCharSet(str, delimiter, predicate);
 }
 
-std::vector<string> SplitV2(const string& str, StringPiece sep, int maxsplit) {
+std::vector<StringPiece> SplitV2(const string& str, StringPiece sep,
+                                 int maxsplit) {
   // This SplitV2 method matches the behavior of python's str.split:
   //   If sep is given, consecutive delimiters are not grouped together
   //   and are deemed to delimit empty strings (for example, '1,,2'.split(',')
@@ -59,11 +115,11 @@ std::vector<string> SplitV2(const string& str, StringPiece sep, int maxsplit) {
   //   splitting an empty string or a string consisting of just whitespace
   //   with a None separator returns [].
 
-  std::vector<string> result;
+  std::vector<StringPiece> result;
 
   StringPiece text(str);
   if (maxsplit == 0) {
-    result.emplace_back(std::string(text));
+    result.emplace_back(text);
     return result;
   }
 
@@ -73,11 +129,11 @@ std::vector<string> SplitV2(const string& str, StringPiece sep, int maxsplit) {
     str_util::RemoveLeadingWhitespace(&text);
     int split = 0;
     while (str_util::ConsumeNonWhitespace(&text, &token)) {
-      result.emplace_back(std::string(token));
+      result.push_back(token);
       str_util::RemoveLeadingWhitespace(&text);
       ++split;
       if (maxsplit > 0 && split == maxsplit) {
-        result.emplace_back(std::string(text));
+        result.push_back(text);
         return result;
       }
     }
@@ -87,17 +143,17 @@ std::vector<string> SplitV2(const string& str, StringPiece sep, int maxsplit) {
   int split = 0;
   while (p != text.end()) {
     StringPiece token = text.substr(0, p - text.begin());
-    result.emplace_back(std::string(token));
+    result.push_back(token);
     text.remove_prefix(token.size());
     text.remove_prefix(sep.size());
     ++split;
     if (maxsplit > 0 && split == maxsplit) {
-      result.emplace_back(std::string(text));
+      result.push_back(StringPiece(text));
       return result;
     }
     p = std::search(text.begin(), text.end(), sep.begin(), sep.end());
   }
-  result.emplace_back(std::string(text));
+  result.push_back(text);
   return result;
 }
 
@@ -134,7 +190,7 @@ class StringSplitOp : public OpKernel {
     const auto delimiter_vec = delimiter_tensor->flat<string>();
     const string& delimiter = delimiter_vec(0);
     // Empty delimiter means split the input character by character.
-    std::vector<string> tokens;
+    std::vector<StringPiece> tokens;
     // Guess that we'll be unpacking a handful of tokens per example.
     static constexpr int kReserveSize = 4;
     tokens.reserve(batch_size * kReserveSize);
 
@@ -143,12 +199,15 @@ class StringSplitOp : public OpKernel {
     int64 max_num_entries = 0;
     std::vector<int64> num_indices(batch_size);
     for (int64 i = 0; i < batch_size; ++i) {
-      std::vector<string> parts = Split(input_vec(i), delimiter, skip_empty_);
+      std::vector<StringPiece> parts =
+          skip_empty_ ? Split(input_vec(i), delimiter, str_util::SkipEmpty())
+                      : Split(input_vec(i), delimiter, str_util::AllowEmpty());
       int64 n_entries = parts.size();
       num_indices[i] = n_entries;
       output_size += n_entries;
       max_num_entries = std::max(max_num_entries, n_entries);
-      tokens.insert(tokens.end(), parts.begin(), parts.end());
+      tokens.insert(tokens.end(), std::make_move_iterator(parts.begin()),
+                    std::make_move_iterator(parts.end()));
     }
 
     Tensor* sp_indices_t;
@@ -170,7 +229,7 @@ class StringSplitOp : public OpKernel {
       for (size_t j = 0; j < num_indices[i]; ++j) {
         sp_indices(c, 0) = i;
        sp_indices(c, 1) = j;
-        sp_tokens(c) = tokens[c];
+        sp_tokens(c).assign(tokens[c].data(), tokens[c].size());
         ++c;
       }
     }
@@ -204,7 +263,7 @@ class StringSplitV2Op : public OpKernel {
                     sep_tensor->shape().DebugString()));
     const auto sep_vec = sep_tensor->flat<string>();
     StringPiece sep(sep_vec(0));
-    std::vector<string> tokens;
+    std::vector<StringPiece> tokens;
     // Guess that we'll be unpacking a handful of tokens per example.
     static constexpr int kReserveSize = 4;
     tokens.reserve(batch_size * kReserveSize);
@@ -213,7 +272,7 @@ class StringSplitV2Op : public OpKernel {
     int64 max_num_entries = 0;
     std::vector<int64> num_indices(batch_size);
     for (int64 i = 0; i < batch_size; ++i) {
-      std::vector<string> parts = SplitV2(input_vec(i), sep, maxsplit_);
+      std::vector<StringPiece> parts = SplitV2(input_vec(i), sep, maxsplit_);
       int64 n_entries = parts.size();
       num_indices[i] = n_entries;
       output_size += n_entries;
@@ -240,7 +299,7 @@ class StringSplitV2Op : public OpKernel {
       for (size_t j = 0; j < num_indices[i]; ++j) {
         sp_indices(c, 0) = i;
         sp_indices(c, 1) = j;
-        sp_tokens(c) = tokens[c];
+        sp_tokens(c).assign(tokens[c].data(), tokens[c].size());
         ++c;
       }
     }
diff --git a/tensorflow/core/kernels/string_split_op_test.cc b/tensorflow/core/kernels/string_split_op_test.cc
new file mode 100644
index 0000000000..58ad61adc8
--- /dev/null
+++ b/tensorflow/core/kernels/string_split_op_test.cc
@@ -0,0 +1,129 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+
+// Test data from the TensorFlow README.md.
+const char* lines[] = {
+    "**TensorFlow** is an open source software library for numerical "
+    "computation using data flow graphs.",
+    "The graph nodes represent mathematical operations, while the graph edges "
+    "represent the multidimensional data arrays (tensors) that flow between "
+    "them.",
+    "This flexible architecture enables you to deploy computation to one or "
+    "more CPUs or GPUs in a desktop, server, or mobile device without "
+    "rewriting code.",
+    "TensorFlow also includes "
+    "[TensorBoard](https://www.tensorflow.org/guide/"
+    "summaries_and_tensorboard), a data visualization toolkit.",
+    "TensorFlow was originally developed by researchers and engineers working "
+    "on the Google Brain team within Google's Machine Intelligence Research "
+    "organization for the purposes of conducting machine learning and deep "
+    "neural networks research.",
+    "The system is general enough to be applicable in a wide variety of other "
+    "domains, as well.",
+    "TensorFlow provides stable Python API and C APIs as well as without API "
+    "backwards compatibility guarantee like C++, Go, Java, JavaScript and "
+    "Swift."};
+
+Tensor GetTestTensor(int batch) {
+  const int sz = TF_ARRAYSIZE(lines);
+  Tensor t(DT_STRING, {batch});
+  auto s = t.flat<string>();
+  for (int i = 0; i < batch; ++i) {
+    s(i) = lines[i % sz];
+  }
+  return t;
+}
+
+Graph* SetupStringSplitGraph(const Tensor& input) {
+  Graph* g = new Graph(OpRegistry::Global());
+  Tensor delim(DT_STRING, TensorShape({}));
+  delim.flat<string>().setConstant(" ");
+
+  TF_CHECK_OK(NodeBuilder("string_split_op", "StringSplit")
+                  .Input(test::graph::Constant(g, input))
+                  .Input(test::graph::Constant(g, delim))
+                  .Finalize(g, nullptr /* node */));
+  return g;
+}
+
+void BM_StringSplit(int iters, int batch_size) {
+  testing::StopTiming();
+  testing::ItemsProcessed(static_cast<int64>(iters));
+  testing::UseRealTime();
+  Tensor input = GetTestTensor(batch_size);
+  Graph* g = SetupStringSplitGraph(input);
+  testing::StartTiming();
+  test::Benchmark("cpu", g).Run(iters);
+}
+
+BENCHMARK(BM_StringSplit)
+    ->Arg(1)
+    ->Arg(8)
+    ->Arg(16)
+    ->Arg(32)
+    ->Arg(64)
+    ->Arg(128)
+    ->Arg(256);
+
+Graph* SetupStringSplitV2Graph(const Tensor& input) {
+  Graph* g = new Graph(OpRegistry::Global());
+  Tensor sep(DT_STRING, TensorShape({}));
+  sep.flat<string>().setConstant(" ");
+
+  TF_CHECK_OK(NodeBuilder("string_split_op", "StringSplitV2")
+                  .Input(test::graph::Constant(g, input))
+                  .Input(test::graph::Constant(g, sep))
+                  .Finalize(g, nullptr /* node */));
+  return g;
+}
+
+void BM_StringSplitV2(int iters, int batch_size) {
+  testing::StopTiming();
+  testing::ItemsProcessed(static_cast<int64>(iters));
+  testing::UseRealTime();
+  Tensor input = GetTestTensor(batch_size);
+  Graph* g = SetupStringSplitV2Graph(input);
+  testing::StartTiming();
+  test::Benchmark("cpu", g).Run(iters);
+}
+
+BENCHMARK(BM_StringSplitV2)
+    ->Arg(1)
+    ->Arg(8)
+    ->Arg(16)
+    ->Arg(32)
+    ->Arg(64)
+    ->Arg(128)
+    ->Arg(256);
+
+}  // end namespace tensorflow
diff --git a/tensorflow/python/kernel_tests/string_split_op_test.py b/tensorflow/python/kernel_tests/string_split_op_test.py
index e20daccb28..b6a0f45adc 100644
--- a/tensorflow/python/kernel_tests/string_split_op_test.py
+++ b/tensorflow/python/kernel_tests/string_split_op_test.py
@@ -58,14 +58,28 @@ class StringSplitOpTest(test.TestCase):
     self.assertAllEqual(shape, [3, 5])
 
   def testStringSplitEmptyToken(self):
-    strings = [" hello ", "", "world "]
+    strings = ["", " a", "b ", " c", " ", " d ", " e", "f ", " g ", " "]
 
     with self.test_session() as sess:
       tokens = string_ops.string_split(strings)
       indices, values, shape = sess.run(tokens)
-      self.assertAllEqual(indices, [[0, 0], [2, 0]])
-      self.assertAllEqual(values, [b"hello", b"world"])
-      self.assertAllEqual(shape, [3, 1])
+      self.assertAllEqual(
+          indices,
+          [[1, 0], [2, 0], [3, 0], [5, 0], [6, 0], [7, 0], [8, 0]])
+      self.assertAllEqual(values, [b"a", b"b", b"c", b"d", b"e", b"f", b"g"])
+      self.assertAllEqual(shape, [10, 1])
+
+  def testStringSplitOnSetEmptyToken(self):
+    strings = ["", " a", "b ", " c", " ", " d ", ". e", "f .", " .g. ", " ."]
+
+    with self.test_session() as sess:
+      tokens = string_ops.string_split(strings, delimiter=" .")
+      indices, values, shape = sess.run(tokens)
+      self.assertAllEqual(
+          indices,
+          [[1, 0], [2, 0], [3, 0], [5, 0], [6, 0], [7, 0], [8, 0]])
+      self.assertAllEqual(values, [b"a", b"b", b"c", b"d", b"e", b"f", b"g"])
+      self.assertAllEqual(shape, [10, 1])
 
   def testStringSplitWithDelimiter(self):
     strings = ["hello|world", "hello world"]
-- 
cgit v1.2.3
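
Note on the splitting strategy (not part of the patch): the single-character fast path that `SplitOnChar` introduces is easy to sketch outside of TensorFlow. The standalone program below is a minimal illustration under assumed substitutions -- `std::string_view` stands in for `StringPiece`, and plain lambdas stand in for `str_util::SkipEmpty()`/`AllowEmpty()` -- so the names here are illustrative, not TensorFlow API.

#include <iostream>
#include <string>
#include <string_view>
#include <vector>

// Zero-copy split on a single character: each token is a view into `str`,
// valid only while `str` is alive. The predicate decides which tokens are
// kept (for example, dropping empty ones).
template <typename Predicate>
std::vector<std::string_view> SplitOnChar(const std::string& str, char delim,
                                          Predicate keep) {
  std::vector<std::string_view> result;
  std::string_view text(str);
  size_t pos = text.find(delim);
  while (pos != std::string_view::npos) {
    std::string_view token = text.substr(0, pos);
    if (keep(token)) result.push_back(token);
    text.remove_prefix(pos + 1);
    pos = text.find(delim);
  }
  if (keep(text)) result.push_back(text);
  return result;
}

int main() {
  const std::string input = "a,,b,c";
  auto skip_empty = [](std::string_view t) { return !t.empty(); };
  auto allow_empty = [](std::string_view) { return true; };
  std::cout << SplitOnChar(input, ',', skip_empty).size() << "\n";   // prints 3
  std::cout << SplitOnChar(input, ',', allow_empty).size() << "\n";  // prints 4
  return 0;
}

Because each token is only a view into the input, the kernels in the patch copy token bytes exactly once, when writing into the output tensor via sp_tokens(c).assign(...), instead of allocating an intermediate std::string per token as the old code did.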