aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/string_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/string_ops.cc')
-rw-r--r--tensorflow/core/ops/string_ops.cc109
1 files changed, 109 insertions, 0 deletions
diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc
index a112e1c879..c427d247b1 100644
--- a/tensorflow/core/ops/string_ops.cc
+++ b/tensorflow/core/ops/string_ops.cc
@@ -281,4 +281,113 @@ input: Base64 strings to decode.
output: Decoded strings.
)doc");
+REGISTER_OP("Substr")
+ .Input("input: string")
+ .Input("pos: T")
+ .Input("len: T")
+ .Output("output: string")
+ .Attr("T: {int32, int64}")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle pos_shape = c->input(1);
+ ShapeHandle len_shape = c->input(2);
+ ShapeHandle unused;
+ // Check that pos/len have same rank
+ TF_RETURN_IF_ERROR(c->WithRank(pos_shape, c->Rank(len_shape), &unused));
+ // Check that dimensions are equal
+ for (int32 i = 0; i < c->Rank(pos_shape); ++i) {
+ DimensionHandle pos_dim = c->Dim(pos_shape, i);
+ DimensionHandle len_dim = c->Dim(len_shape, i);
+ if (c->Value(pos_dim) != c->Value(len_dim)) {
+ return errors::InvalidArgument("pos and len shapes must match: ",
+ c->DebugString(pos_shape), " vs. ",
+ c->DebugString(len_shape));
+ }
+ }
+ // c->input(0) is the ShapeHandle to input strings
+ // BroadcastBinaryOpShapeFn infers shape from c->input(0) and c->input(1).
+ return shape_inference::BroadcastBinaryOpShapeFn(c);
+ })
+ .Doc(R"doc(
+Return substrings from `Tensor` of strings.
+
+For each string in the input `Tensor`, creates a substring starting at index
+`pos` with a total length of `len`.
+
+If `len` defines a substring that would extend beyond the length of the input
+string, then as many characters as possible are used.
+
+If `pos` is negative or specifies a character index larger than any of the input
+strings, then an `InvalidArgumentError` is thrown.
+
+`pos` and `len` must have the same shape, otherwise a `ValueError` is thrown on
+Op creation.
+
+*NOTE*: `Substr` supports broadcasting up to two dimensions. More about
+broadcasting
+[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+
+---
+
+Examples
+
+Using scalar `pos` and `len`:
+
+```
+input = [b'Hello', b'World']
+position = 1
+length = 3
+
+output = [b'ell', b'orl']
+```
+
+Using `pos` and `len` with same shape as `input`:
+
+```
+input = [[b'ten', b'eleven', b'twelve'],
+ [b'thirteen', b'fourteen', b'fifteen'],
+ [b'sixteen', b'seventeen', b'eighteen']]
+position = [[1, 2, 3],
+ [1, 2, 3],
+ [1, 2, 3]]
+length = [[2, 3, 4],
+ [4, 3, 2],
+ [5, 5, 5]]
+
+output = [[b'en', b'eve', b'lve'],
+ [b'hirt', b'urt', b'te'],
+ [b'ixtee', b'vente', b'hteen']]
+```
+
+Broadcasting `pos` and `len` onto `input`:
+
+```
+input = [[b'ten', b'eleven', b'twelve'],
+ [b'thirteen', b'fourteen', b'fifteen'],
+ [b'sixteen', b'seventeen', b'eighteen'],
+ [b'nineteen', b'twenty', b'twentyone']]
+position = [1, 2, 3]
+length = [1, 2, 3]
+
+output = [[b'e', b'ev', b'lve'],
+ [b'h', b'ur', b'tee'],
+ [b'i', b've', b'hte'],
+ [b'i', b'en', b'nty']]
+```
+
+Broadcasting `input` onto `pos` and `len`:
+
+```
+input = b'thirteen'
+position = [1, 5, 7]
+length = [3, 2, 1]
+
+output = [b'hir', b'ee', b'n"]
+```
+
+input: Tensor of strings
+pos: Scalar defining the position of first character in each substring
+len: Scalar defining the number of characters to include in each substring
+output: Tensor of substrings
+)doc");
+
} // namespace tensorflow