diff options
author | 2018-03-01 06:03:38 -0800 | |
---|---|---|
committer | 2018-03-01 06:07:59 -0800 | |
commit | 2b7a7ee30666d160929c9aa3e941fbc94c17cc52 (patch) | |
tree | b93b13c47c0c69bb6534a6894e2253ac749174c0 /tensorflow | |
parent | 46355f9065967dd39cd340b17d91a91f70d2c0c1 (diff) |
Add RegexReplace Op that internally calls RE2::Replace.
PiperOrigin-RevId: 187467840
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/core/api_def/base_api/api_def_RegexReplace.pbtxt | 25 | ||||
-rw-r--r-- | tensorflow/core/kernels/BUILD | 8 | ||||
-rw-r--r-- | tensorflow/core/kernels/regex_replace_op.cc | 76 | ||||
-rw-r--r-- | tensorflow/core/ops/string_ops.cc | 14 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/BUILD | 12 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/regex_replace_op_test.py | 71 | ||||
-rw-r--r-- | tensorflow/python/ops/string_ops.py | 2 | ||||
-rw-r--r-- | tensorflow/tools/api/golden/tensorflow.pbtxt | 4 |
8 files changed, 212 insertions, 0 deletions
diff --git a/tensorflow/core/api_def/base_api/api_def_RegexReplace.pbtxt b/tensorflow/core/api_def/base_api/api_def_RegexReplace.pbtxt new file mode 100644 index 0000000000..70ad521926 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_RegexReplace.pbtxt @@ -0,0 +1,25 @@ +op { + graph_op_name: "RegexReplace" + in_arg { + name: "input" + description: "The text to be processed." + } + in_arg { + name: "pattern" + description: "The regular expression to match the input." + } + in_arg { + name: "rewrite" + description: "The rewrite to be applied to the matched expresion." + } + out_arg { + name: "output" + description: "The text after applying pattern and rewrite." + } + attr { + name: "replace_global" + description: "If True, the replacement is global, otherwise the replacement\nis done only on the first match." + } + summary: "Replaces the match of pattern in input with rewrite." + description: "It follows the re2 syntax (https://github.com/google/re2/wiki/Syntax)" +} diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 3426cf6e40..feacee5d63 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -4155,6 +4155,7 @@ cc_library( ":as_string_op", ":base64_ops", ":reduce_join_op", + ":regex_replace_op", ":string_join_op", ":string_split_op", ":string_to_hash_bucket_op", @@ -4190,6 +4191,12 @@ tf_kernel_library( ) tf_kernel_library( + name = "regex_replace_op", + prefix = "regex_replace_op", + deps = STRING_DEPS + ["@com_googlesource_code_re2//:re2"], +) + +tf_kernel_library( name = "string_split_op", prefix = "string_split_op", deps = STRING_DEPS, @@ -5063,6 +5070,7 @@ filegroup( "scatter_nd_op*", "mutex_ops.*", "batch_kernels.*", + "regex_replace_op.cc", ], ), visibility = ["//visibility:public"], diff --git a/tensorflow/core/kernels/regex_replace_op.cc b/tensorflow/core/kernels/regex_replace_op.cc new file mode 100644 index 0000000000..59ec854a79 --- /dev/null +++ b/tensorflow/core/kernels/regex_replace_op.cc @@ -0,0 +1,76 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <string> + +#include "re2/re2.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +class RegexReplaceOp : public OpKernel { + public: + explicit RegexReplaceOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("replace_global", &replace_global_)); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor* input_tensor; + OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor)); + const auto& input_flat = input_tensor->flat<string>(); + + const Tensor* pattern_tensor; + OP_REQUIRES_OK(ctx, ctx->input("pattern", &pattern_tensor)); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(pattern_tensor->shape()), + errors::InvalidArgument("Pattern must be scalar, but received ", + pattern_tensor->shape().DebugString())); + const string pattern = pattern_tensor->flat<string>()(0); + const RE2 match(pattern); + OP_REQUIRES(ctx, match.ok(), + errors::InvalidArgument("Invalid pattern: ", pattern, + ", error: ", match.error())); + + const Tensor* rewrite_tensor; + OP_REQUIRES_OK(ctx, ctx->input("rewrite", &rewrite_tensor)); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rewrite_tensor->shape()), + errors::InvalidArgument("Rewrite must be scalar, but received ", + rewrite_tensor->shape().DebugString())); + const string rewrite = rewrite_tensor->flat<string>()(0); + + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output("output", input_tensor->shape(), + &output_tensor)); + auto output_flat = output_tensor->flat<string>(); + for (size_t i = 0; i < input_flat.size(); ++i) { + output_flat(i) = input_flat(i); + if (replace_global_) { + RE2::GlobalReplace(&output_flat(i), match, rewrite); + } else { + RE2::Replace(&output_flat(i), match, rewrite); + } + } + } + + private: + bool replace_global_; +}; + +REGISTER_KERNEL_BUILDER(Name("RegexReplace").Device(DEVICE_CPU), + RegexReplaceOp); + +} // namespace tensorflow diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc index e4c5bcfb54..05f216a83e 100644 --- a/tensorflow/core/ops/string_ops.cc +++ b/tensorflow/core/ops/string_ops.cc @@ -23,6 +23,20 @@ using shape_inference::DimensionHandle; using shape_inference::InferenceContext; using shape_inference::ShapeHandle; +REGISTER_OP("RegexReplace") + .Input("input: string") + .Input("pattern: string") + .Input("rewrite: string") + .Output("output: string") + .Attr("replace_global: bool = true") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + c->set_output(0, c->input(0)); + return Status::OK(); + }); + REGISTER_OP("StringToHashBucketFast") .Input("input: string") .Output("output: int64") diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index c9aa4a252d..0f13e8bba5 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -713,6 +713,18 @@ cuda_py_test( ) tf_py_test( + name = "regex_replace_op_test", + size = "small", + srcs = ["regex_replace_op_test.py"], + additional_deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:string_ops", + ], +) + +tf_py_test( name = "save_restore_ops_test", size = "small", srcs = ["save_restore_ops_test.py"], diff --git a/tensorflow/python/kernel_tests/regex_replace_op_test.py b/tensorflow/python/kernel_tests/regex_replace_op_test.py new file mode 100644 index 0000000000..6739ac3224 --- /dev/null +++ b/tensorflow/python/kernel_tests/regex_replace_op_test.py @@ -0,0 +1,71 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for RegexReplace op from string_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import string_ops +from tensorflow.python.platform import test + + +class RegexReplaceOpTest(test.TestCase): + + def testRemovePrefix(self): + values = ["a:foo", "a:bar", "a:foo", "b:baz", "b:qux", "ca:b"] + with self.test_session(): + input_vector = constant_op.constant(values, dtypes.string) + stripped = string_ops.regex_replace( + input_vector, "^(a:|b:)", "", replace_global=False).eval() + self.assertAllEqual([b"foo", b"bar", b"foo", b"baz", b"qux", b"ca:b"], + stripped) + + def testRegexReplace(self): + values = ["aba\naba", "abcdabcde"] + with self.test_session(): + input_vector = constant_op.constant(values, dtypes.string) + stripped = string_ops.regex_replace(input_vector, "a.*a", "(\\0)").eval() + self.assertAllEqual([b"(aba)\n(aba)", b"(abcda)bcde"], stripped) + + def testEmptyMatch(self): + values = ["abc", "1"] + with self.test_session(): + input_vector = constant_op.constant(values, dtypes.string) + stripped = string_ops.regex_replace(input_vector, "", "x").eval() + self.assertAllEqual([b"xaxbxcx", b"x1x"], stripped) + + def testInvalidPattern(self): + values = ["abc", "1"] + with self.test_session(): + input_vector = constant_op.constant(values, dtypes.string) + invalid_pattern = "A[" + replace = string_ops.regex_replace(input_vector, invalid_pattern, "x") + with self.assertRaisesOpError("Invalid pattern"): + replace.eval() + + def testGlobal(self): + values = ["ababababab", "abcabcabc", ""] + with self.test_session(): + input_vector = constant_op.constant(values, dtypes.string) + stripped = string_ops.regex_replace(input_vector, "ab", "abc", + True).eval() + self.assertAllEqual([b"abcabcabcabcabc", b"abccabccabcc", b""], stripped) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py index 0335d2456a..5bd75b9215 100644 --- a/tensorflow/python/ops/string_ops.py +++ b/tensorflow/python/ops/string_ops.py @@ -17,6 +17,7 @@ See the @{$python/string_ops} guide. +@@regex_replace @@string_to_hash_bucket_fast @@string_to_hash_bucket_strong @@string_to_hash_bucket @@ -139,6 +140,7 @@ def reduce_join(inputs, axis=None, reduce_join.__doc__ = deprecation.rewrite_argument_docstring( gen_string_ops.reduce_join.__doc__, "reduction_indices", "axis") +ops.NotDifferentiable("RegexReplace") ops.NotDifferentiable("StringToHashBucket") ops.NotDifferentiable("StringToHashBucketFast") ops.NotDifferentiable("StringToHashBucketStrong") diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt index 2333736583..8c9e7af89b 100644 --- a/tensorflow/tools/api/golden/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.pbtxt @@ -1601,6 +1601,10 @@ tf_module { argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { + name: "regex_replace" + argspec: "args=[\'input\', \'pattern\', \'rewrite\', \'replace_global\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + } + member_method { name: "register_tensor_conversion_function" argspec: "args=[\'base_type\', \'conversion_func\', \'priority\'], varargs=None, keywords=None, defaults=[\'100\'], " } |