aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-01 06:03:38 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-01 06:07:59 -0800
commit2b7a7ee30666d160929c9aa3e941fbc94c17cc52 (patch)
treeb93b13c47c0c69bb6534a6894e2253ac749174c0 /tensorflow
parent46355f9065967dd39cd340b17d91a91f70d2c0c1 (diff)
Add RegexReplace Op that internally calls RE2::Replace.
PiperOrigin-RevId: 187467840
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/core/api_def/base_api/api_def_RegexReplace.pbtxt25
-rw-r--r--tensorflow/core/kernels/BUILD8
-rw-r--r--tensorflow/core/kernels/regex_replace_op.cc76
-rw-r--r--tensorflow/core/ops/string_ops.cc14
-rw-r--r--tensorflow/python/kernel_tests/BUILD12
-rw-r--r--tensorflow/python/kernel_tests/regex_replace_op_test.py71
-rw-r--r--tensorflow/python/ops/string_ops.py2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.pbtxt4
8 files changed, 212 insertions, 0 deletions
diff --git a/tensorflow/core/api_def/base_api/api_def_RegexReplace.pbtxt b/tensorflow/core/api_def/base_api/api_def_RegexReplace.pbtxt
new file mode 100644
index 0000000000..70ad521926
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_RegexReplace.pbtxt
@@ -0,0 +1,25 @@
+op {
+ graph_op_name: "RegexReplace"
+ in_arg {
+ name: "input"
+ description: "The text to be processed."
+ }
+ in_arg {
+ name: "pattern"
+ description: "The regular expression to match the input."
+ }
+ in_arg {
+ name: "rewrite"
+ description: "The rewrite to be applied to the matched expression."
+ }
+ out_arg {
+ name: "output"
+ description: "The text after applying pattern and rewrite."
+ }
+ attr {
+ name: "replace_global"
+ description: "If True, the replacement is global, otherwise the replacement\nis done only on the first match."
+ }
+ summary: "Replaces the match of pattern in input with rewrite."
+ description: "It follows the re2 syntax (https://github.com/google/re2/wiki/Syntax)"
+}
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 3426cf6e40..feacee5d63 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -4155,6 +4155,7 @@ cc_library(
":as_string_op",
":base64_ops",
":reduce_join_op",
+ ":regex_replace_op",
":string_join_op",
":string_split_op",
":string_to_hash_bucket_op",
@@ -4190,6 +4191,12 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "regex_replace_op",
+ prefix = "regex_replace_op",
+ deps = STRING_DEPS + ["@com_googlesource_code_re2//:re2"],
+)
+
+tf_kernel_library(
name = "string_split_op",
prefix = "string_split_op",
deps = STRING_DEPS,
@@ -5063,6 +5070,7 @@ filegroup(
"scatter_nd_op*",
"mutex_ops.*",
"batch_kernels.*",
+ "regex_replace_op.cc",
],
),
visibility = ["//visibility:public"],
diff --git a/tensorflow/core/kernels/regex_replace_op.cc b/tensorflow/core/kernels/regex_replace_op.cc
new file mode 100644
index 0000000000..59ec854a79
--- /dev/null
+++ b/tensorflow/core/kernels/regex_replace_op.cc
@@ -0,0 +1,76 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+
+#include "re2/re2.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+// CPU kernel for the "RegexReplace" op.
+//
+// Inputs:
+//   input:   string tensor of any shape; every element is processed
+//            independently.
+//   pattern: scalar string holding an RE2 regular expression.
+//   rewrite: scalar string holding the RE2 rewrite applied to each match.
+// Output:
+//   output:  string tensor with the same shape as `input`.
+// Attr:
+//   replace_global: if true every match is rewritten (RE2::GlobalReplace),
+//                   otherwise only the first one (RE2::Replace).
+//
+// Fails with InvalidArgument when `pattern` or `rewrite` is not a scalar, when
+// `pattern` does not compile, or when `rewrite` is not a valid rewrite string
+// for `pattern`.
+class RegexReplaceOp : public OpKernel {
+ public:
+  explicit RegexReplaceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("replace_global", &replace_global_));
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor* input_tensor;
+    OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
+    const auto& input_flat = input_tensor->flat<string>();
+
+    const Tensor* pattern_tensor;
+    OP_REQUIRES_OK(ctx, ctx->input("pattern", &pattern_tensor));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(pattern_tensor->shape()),
+                errors::InvalidArgument("Pattern must be scalar, but received ",
+                                        pattern_tensor->shape().DebugString()));
+    const string pattern = pattern_tensor->flat<string>()(0);
+    // The pattern arrives as a tensor, so it is compiled on every Compute().
+    const RE2 match(pattern);
+    OP_REQUIRES(ctx, match.ok(),
+                errors::InvalidArgument("Invalid pattern: ", pattern,
+                                        ", error: ", match.error()));
+
+    const Tensor* rewrite_tensor;
+    OP_REQUIRES_OK(ctx, ctx->input("rewrite", &rewrite_tensor));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rewrite_tensor->shape()),
+                errors::InvalidArgument("Rewrite must be scalar, but received ",
+                                        rewrite_tensor->shape().DebugString()));
+    const string rewrite = rewrite_tensor->flat<string>()(0);
+    // Validate the rewrite once up front. The return values of
+    // RE2::Replace/GlobalReplace are deliberately ignored below (false also
+    // means "no match"), so a malformed rewrite would otherwise never be
+    // reported to the caller as an op error.
+    string rewrite_error;
+    OP_REQUIRES(ctx, match.CheckRewriteString(rewrite, &rewrite_error),
+                errors::InvalidArgument("Invalid rewrite: ", rewrite,
+                                        ", error: ", rewrite_error));
+
+    Tensor* output_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output("output", input_tensor->shape(),
+                                             &output_tensor));
+    auto output_flat = output_tensor->flat<string>();
+    // Use a signed index: the flat size is a signed quantity, so this avoids a
+    // signed/unsigned comparison.
+    const int64 num_elements = input_flat.size();
+    for (int64 i = 0; i < num_elements; ++i) {
+      // RE2 rewrites in place, so start from a copy of the input element.
+      output_flat(i) = input_flat(i);
+      if (replace_global_) {
+        RE2::GlobalReplace(&output_flat(i), match, rewrite);
+      } else {
+        RE2::Replace(&output_flat(i), match, rewrite);
+      }
+    }
+  }
+
+ private:
+  // Value of the `replace_global` attr, read once at kernel construction.
+  bool replace_global_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("RegexReplace").Device(DEVICE_CPU),
+ RegexReplaceOp);  // Binds the "RegexReplace" op to RegexReplaceOp on CPU devices.
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc
index e4c5bcfb54..05f216a83e 100644
--- a/tensorflow/core/ops/string_ops.cc
+++ b/tensorflow/core/ops/string_ops.cc
@@ -23,6 +23,20 @@ using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
+// RegexReplace: rewrites regex matches in every element of `input`.
+// `pattern` and `rewrite` are scalar strings; `output` mirrors `input`'s shape.
+REGISTER_OP("RegexReplace")
+    .Input("input: string")
+    .Input("pattern: string")
+    .Input("rewrite: string")
+    .Output("output: string")
+    .Attr("replace_global: bool = true")
+    .SetShapeFn([](InferenceContext* c) {
+      // Inputs 1 (`pattern`) and 2 (`rewrite`) must both have rank 0; they are
+      // checked in that order so the first failing input is the one reported.
+      ShapeHandle scalar_shape;
+      for (int arg = 1; arg <= 2; ++arg) {
+        TF_RETURN_IF_ERROR(c->WithRank(c->input(arg), 0, &scalar_shape));
+      }
+      // The op is element-wise over `input`, so the output shape is unchanged.
+      c->set_output(0, c->input(0));
+      return Status::OK();
+    });
+
REGISTER_OP("StringToHashBucketFast")
.Input("input: string")
.Output("output: int64")
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index c9aa4a252d..0f13e8bba5 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -713,6 +713,18 @@ cuda_py_test(
)
tf_py_test(
+ name = "regex_replace_op_test",
+ size = "small",
+ srcs = ["regex_replace_op_test.py"],
+ additional_deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:string_ops",
+ ],
+)
+
+tf_py_test(
name = "save_restore_ops_test",
size = "small",
srcs = ["save_restore_ops_test.py"],
diff --git a/tensorflow/python/kernel_tests/regex_replace_op_test.py b/tensorflow/python/kernel_tests/regex_replace_op_test.py
new file mode 100644
index 0000000000..6739ac3224
--- /dev/null
+++ b/tensorflow/python/kernel_tests/regex_replace_op_test.py
@@ -0,0 +1,71 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for RegexReplace op from string_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import string_ops
+from tensorflow.python.platform import test
+
+
+class RegexReplaceOpTest(test.TestCase):
+
+ def testRemovePrefix(self):
+ values = ["a:foo", "a:bar", "a:foo", "b:baz", "b:qux", "ca:b"]
+ with self.test_session():
+ input_vector = constant_op.constant(values, dtypes.string)
+ stripped = string_ops.regex_replace(
+ input_vector, "^(a:|b:)", "", replace_global=False).eval()
+ self.assertAllEqual([b"foo", b"bar", b"foo", b"baz", b"qux", b"ca:b"],
+ stripped)
+
+ def testRegexReplace(self):
+ values = ["aba\naba", "abcdabcde"]
+ with self.test_session():
+ input_vector = constant_op.constant(values, dtypes.string)
+ stripped = string_ops.regex_replace(input_vector, "a.*a", "(\\0)").eval()
+ self.assertAllEqual([b"(aba)\n(aba)", b"(abcda)bcde"], stripped)
+
+ def testEmptyMatch(self):
+ values = ["abc", "1"]
+ with self.test_session():
+ input_vector = constant_op.constant(values, dtypes.string)
+ stripped = string_ops.regex_replace(input_vector, "", "x").eval()
+ self.assertAllEqual([b"xaxbxcx", b"x1x"], stripped)
+
+ def testInvalidPattern(self):
+ values = ["abc", "1"]
+ with self.test_session():
+ input_vector = constant_op.constant(values, dtypes.string)
+ invalid_pattern = "A["
+ replace = string_ops.regex_replace(input_vector, invalid_pattern, "x")
+ with self.assertRaisesOpError("Invalid pattern"):
+ replace.eval()
+
+ def testGlobal(self):
+ values = ["ababababab", "abcabcabc", ""]
+ with self.test_session():
+ input_vector = constant_op.constant(values, dtypes.string)
+ stripped = string_ops.regex_replace(input_vector, "ab", "abc",
+ True).eval()
+ self.assertAllEqual([b"abcabcabcabcabc", b"abccabccabcc", b""], stripped)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py
index 0335d2456a..5bd75b9215 100644
--- a/tensorflow/python/ops/string_ops.py
+++ b/tensorflow/python/ops/string_ops.py
@@ -17,6 +17,7 @@
See the @{$python/string_ops} guide.
+@@regex_replace
@@string_to_hash_bucket_fast
@@string_to_hash_bucket_strong
@@string_to_hash_bucket
@@ -139,6 +140,7 @@ def reduce_join(inputs, axis=None,
reduce_join.__doc__ = deprecation.rewrite_argument_docstring(
gen_string_ops.reduce_join.__doc__, "reduction_indices", "axis")
+ops.NotDifferentiable("RegexReplace")
ops.NotDifferentiable("StringToHashBucket")
ops.NotDifferentiable("StringToHashBucketFast")
ops.NotDifferentiable("StringToHashBucketStrong")
diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt
index 2333736583..8c9e7af89b 100644
--- a/tensorflow/tools/api/golden/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.pbtxt
@@ -1601,6 +1601,10 @@ tf_module {
argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "regex_replace"
+ argspec: "args=[\'input\', \'pattern\', \'rewrite\', \'replace_global\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+ }
+ member_method {
name: "register_tensor_conversion_function"
argspec: "args=[\'base_type\', \'conversion_func\', \'priority\'], varargs=None, keywords=None, defaults=[\'100\'], "
}