aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-06 11:03:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-06 11:11:43 -0700
commitca5952670d98b568fa4ac671cf2310d78474c525 (patch)
tree91da2f08c3755039ff0f4dff33c360e62ce80e00
parent025277a1598fa227b53ddc4e316a7a953b2006c8 (diff)
Add StaticRegexFullMatch which can be used in place of RegexFullMatch when the regex pattern are fixed.
This allows the Op to perform the expensive regex compilation once upon creation instead of with each call to compute. RELNOTES: Performance improvements for regex full match operations. PiperOrigin-RevId: 211835278
-rw-r--r--tensorflow/core/api_def/base_api/api_def_StaticRegexFullMatch.pbtxt29
-rw-r--r--tensorflow/core/kernels/regex_full_match_op.cc33
-rw-r--r--tensorflow/core/ops/string_ops.cc6
-rw-r--r--tensorflow/python/kernel_tests/BUILD1
-rw-r--r--tensorflow/python/kernel_tests/regex_full_match_op_test.py60
-rw-r--r--tensorflow/python/ops/string_ops.py34
6 files changed, 151 insertions, 12 deletions
diff --git a/tensorflow/core/api_def/base_api/api_def_StaticRegexFullMatch.pbtxt b/tensorflow/core/api_def/base_api/api_def_StaticRegexFullMatch.pbtxt
new file mode 100644
index 0000000000..6d9d9908ca
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_StaticRegexFullMatch.pbtxt
@@ -0,0 +1,29 @@
+op {
+ graph_op_name: "StaticRegexFullMatch"
+ in_arg {
+ name: "input"
+ description: <<END
+A string tensor of the text to be processed.
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+A bool tensor with the same shape as `input`.
+END
+ }
+ attr {
+ name: "pattern"
+ description: "The regular expression to match the input."
+ }
+ summary: "Check if the input matches the regex pattern."
+ description: <<END
+The input is a string tensor of any shape. The pattern is the
+regular expression to be matched with every element of the input tensor.
+The boolean values (True or False) of the output tensor indicate
+if the input matches the regex pattern provided.
+
+The pattern follows the re2 syntax (https://github.com/google/re2/wiki/Syntax)
+END
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/kernels/regex_full_match_op.cc b/tensorflow/core/kernels/regex_full_match_op.cc
index 5863a2c8e4..7edaaad8f7 100644
--- a/tensorflow/core/kernels/regex_full_match_op.cc
+++ b/tensorflow/core/kernels/regex_full_match_op.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -56,4 +57,36 @@ class RegexFullMatchOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("RegexFullMatch").Device(DEVICE_CPU),
RegexFullMatchOp);
+class StaticRegexFullMatchOp : public OpKernel {
+ public:
+ explicit StaticRegexFullMatchOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ string pattern;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("pattern", &pattern));
+ re_ = MakeUnique<RE2>(pattern);
+ OP_REQUIRES(ctx, re_->ok(),
+ errors::InvalidArgument("Invalid pattern: ", pattern,
+ ", error: ", re_->error()));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* input_tensor;
+ OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
+ const auto& input_flat = input_tensor->flat<string>();
+
+ Tensor* output_tensor = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output("output", input_tensor->shape(),
+ &output_tensor));
+ auto output_flat = output_tensor->flat<bool>();
+ for (size_t i = 0; i < input_flat.size(); ++i) {
+ output_flat(i) = RE2::FullMatch(input_flat(i), *re_);
+ }
+ }
+
+ private:
+ std::unique_ptr<RE2> re_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("StaticRegexFullMatch").Device(DEVICE_CPU),
+ StaticRegexFullMatchOp);
+
} // namespace tensorflow
diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc
index 7aa1e71809..ef8b15dc8a 100644
--- a/tensorflow/core/ops/string_ops.cc
+++ b/tensorflow/core/ops/string_ops.cc
@@ -56,6 +56,12 @@ REGISTER_OP("RegexFullMatch")
return Status::OK();
});
+REGISTER_OP("StaticRegexFullMatch")
+ .Input("input: string")
+ .Attr("pattern: string")
+ .Output("output: bool")
+ .SetShapeFn(shape_inference::UnchangedShape);
+
REGISTER_OP("StringToHashBucketFast")
.Input("input: string")
.Output("output: int64")
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 58c8975daa..d4396bf3eb 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -779,6 +779,7 @@ tf_py_test(
size = "small",
srcs = ["regex_full_match_op_test.py"],
additional_deps = [
+ "@absl_py//absl/testing:parameterized",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
diff --git a/tensorflow/python/kernel_tests/regex_full_match_op_test.py b/tensorflow/python/kernel_tests/regex_full_match_op_test.py
index 5daae1b79b..7bd8c3ca27 100644
--- a/tensorflow/python/kernel_tests/regex_full_match_op_test.py
+++ b/tensorflow/python/kernel_tests/regex_full_match_op_test.py
@@ -18,37 +18,77 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
+
+from tensorflow.python.compat import compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import gen_string_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
-class RegexFullMatchOpTest(test.TestCase):
+@parameterized.parameters(
+ (gen_string_ops.regex_full_match),
+ (gen_string_ops.static_regex_full_match))
+class RegexFullMatchOpVariantsTest(test.TestCase, parameterized.TestCase):
- def testRegexFullMatch(self):
+ def testRegexFullMatch(self, op):
values = ["abaaba", "abcdabcde"]
with self.test_session():
- input_vector = constant_op.constant(values, dtypes.string)
- matched = string_ops.regex_full_match(input_vector, "a.*a").eval()
+ input_tensor = constant_op.constant(values, dtypes.string)
+ matched = op(input_tensor, "a.*a").eval()
self.assertAllEqual([True, False], matched)
- def testEmptyMatch(self):
+ def testRegexFullMatchTwoDims(self, op):
+ values = [["abaaba", "abcdabcde"], ["acdcba", "ebcda"]]
+ with self.test_session():
+ input_tensor = constant_op.constant(values, dtypes.string)
+ matched = op(input_tensor, "a.*a").eval()
+ self.assertAllEqual([[True, False], [True, False]], matched)
+
+ def testEmptyMatch(self, op):
values = ["abc", "1"]
with self.test_session():
- input_vector = constant_op.constant(values, dtypes.string)
- matched = string_ops.regex_full_match(input_vector, "").eval()
+ input_tensor = constant_op.constant(values, dtypes.string)
+ matched = op(input_tensor, "").eval()
self.assertAllEqual([False, False], matched)
- def testInvalidPattern(self):
+ def testInvalidPattern(self, op):
values = ["abc", "1"]
with self.test_session():
- input_vector = constant_op.constant(values, dtypes.string)
+ input_tensor = constant_op.constant(values, dtypes.string)
invalid_pattern = "A["
- matched = string_ops.regex_full_match(input_vector, invalid_pattern)
+ matched = op(input_tensor, invalid_pattern)
with self.assertRaisesOpError("Invalid pattern"):
matched.eval()
+class RegexFullMatchOpTest(test.TestCase):
+
+ def testRegexFullMatchDelegation(self):
+ with compat.forward_compatibility_horizon(2018, 11, 1):
+ with self.test_session():
+ input_tensor = constant_op.constant("foo", dtypes.string)
+ pattern = "[a-z]"
+ op = string_ops.regex_full_match(input_tensor, pattern)
+ self.assertTrue(op.name.startswith("RegexFullMatch"), op.name)
+
+ pattern_tensor = constant_op.constant("[a-z]*", dtypes.string)
+ op_tensor = string_ops.regex_full_match(input_tensor, pattern_tensor)
+ self.assertTrue(op_tensor.name.startswith("RegexFullMatch"), op.name)
+
+ def testStaticRegexFullMatchDelegation(self):
+ with compat.forward_compatibility_horizon(2018, 11, 20):
+ with self.test_session():
+ input_tensor = constant_op.constant("foo", dtypes.string)
+ pattern = "[a-z]*"
+ op = string_ops.regex_full_match(input_tensor, pattern)
+ self.assertTrue(op.name.startswith("StaticRegexFullMatch"), op.name)
+
+ pattern_tensor = constant_op.constant("[a-z]*", dtypes.string)
+ op_vec = string_ops.regex_full_match(input_tensor, pattern_tensor)
+ self.assertTrue(op_vec.name.startswith("RegexFullMatch"), op.name)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py
index c832ba4e2a..29fefbe3a5 100644
--- a/tensorflow/python/ops/string_ops.py
+++ b/tensorflow/python/ops/string_ops.py
@@ -41,12 +41,41 @@ from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
# pylint: enable=wildcard-import
+
+# pylint: disable=redefined-builtin
+def regex_full_match(input, pattern, name=None):
+ r"""Match elements of `input` with regex `pattern`.
+
+ Args:
+ input: string `Tensor`, the source strings to process.
+ pattern: string or scalar string `Tensor`, regular expression to use,
+ see more details at https://github.com/google/re2/wiki/Syntax
+ name: Name of the op.
+
+ Returns:
+ bool `Tensor` of the same shape as `input` with match results.
+ """
+ # TODO(b/112455102): Remove compat.forward_compatible once past the horizon.
+ if not compat.forward_compatible(2018, 11, 10):
+ return gen_string_ops.regex_full_match(
+ input=input, pattern=pattern, name=name)
+ if isinstance(pattern, util_compat.bytes_or_text_types):
+ # When `pattern` is static through the life of the op we can
+ # use a version which performs the expensive regex compilation once at
+ # creation time.
+ return gen_string_ops.static_regex_full_match(
+ input=input, pattern=pattern, name=name)
+ return gen_string_ops.regex_full_match(
+ input=input, pattern=pattern, name=name)
+
+regex_full_match.__doc__ = gen_string_ops.regex_full_match.__doc__
+
# Expose regex_full_match in strings namespace
tf_export("strings.regex_full_match")(regex_full_match)
def regex_replace(source, pattern, rewrite, replace_global=True):
- r"""Replace elements of `source` matching regex `pattern with `rewrite`.
+ r"""Replace elements of `source` matching regex `pattern` with `rewrite`.
Args:
source: string `Tensor`, the source strings to process.
@@ -128,6 +157,7 @@ def string_split(source, delimiter=" ", skip_empty=True): # pylint: disable=inv
shape.set_shape([2])
return sparse_tensor.SparseTensor(indices, values, shape)
+
@tf_export("strings.split")
def string_split_v2(source, sep=None, maxsplit=-1):
"""Split elements of `source` based on `sep` into a `SparseTensor`.
@@ -170,7 +200,7 @@ def string_split_v2(source, sep=None, maxsplit=-1):
second column corresponds to the index of the split component in this row.
"""
if sep is None:
- sep = ''
+ sep = ""
sep = ops.convert_to_tensor(sep, dtype=dtypes.string)
source = ops.convert_to_tensor(source, dtype=dtypes.string)