diff options
Diffstat (limited to 'tensorflow/core/ops/string_ops.cc')
-rw-r--r-- | tensorflow/core/ops/string_ops.cc | 109 |
1 files changed, 109 insertions, 0 deletions
diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc
index a112e1c879..c427d247b1 100644
--- a/tensorflow/core/ops/string_ops.cc
+++ b/tensorflow/core/ops/string_ops.cc
@@ -281,4 +281,127 @@ input: Base64 strings to decode.
 output: Decoded strings.
 )doc");
 
+// Registers the `Substr` op: extracts a substring from every string in
+// `input`, with the start index taken from `pos` and the length from `len`.
+// `T` selects the integer type (int32 or int64) shared by `pos` and `len`.
+REGISTER_OP("Substr")
+    .Input("input: string")
+    .Input("pos: T")
+    .Input("len: T")
+    .Output("output: string")
+    .Attr("T: {int32, int64}")
+    .SetShapeFn([](InferenceContext* c) {
+      // Shape function: validates `pos` against `len`, then defers to
+      // BroadcastBinaryOpShapeFn for the output shape.
+      ShapeHandle pos_shape = c->input(1);
+      ShapeHandle len_shape = c->input(2);
+      // The pos/len comparison needs both ranks to be statically known. With
+      // an unknown `len` rank, c->Rank(len_shape) is the unknown-rank
+      // sentinel, and handing that to WithRank would turn a `pos` of known
+      // rank into a spurious InvalidArgument error.
+      if (c->RankKnown(pos_shape) && c->RankKnown(len_shape)) {
+        ShapeHandle unused;
+        // Check that pos/len have same rank
+        TF_RETURN_IF_ERROR(
+            c->WithRank(pos_shape, c->Rank(len_shape), &unused));
+        // Check that dimensions are equal. A dimension whose size is not
+        // known yet cannot contradict anything, so it is skipped.
+        for (int32 i = 0; i < c->Rank(pos_shape); ++i) {
+          DimensionHandle pos_dim = c->Dim(pos_shape, i);
+          DimensionHandle len_dim = c->Dim(len_shape, i);
+          if (c->ValueKnown(pos_dim) && c->ValueKnown(len_dim) &&
+              c->Value(pos_dim) != c->Value(len_dim)) {
+            return errors::InvalidArgument(
+                "pos and len shapes must match: ", c->DebugString(pos_shape),
+                " vs. ", c->DebugString(len_shape));
+          }
+        }
+      }
+      // c->input(0) is the ShapeHandle to input strings
+      // BroadcastBinaryOpShapeFn infers shape from c->input(0) and c->input(1).
+      return shape_inference::BroadcastBinaryOpShapeFn(c);
+    })
+    .Doc(R"doc(
+Return substrings from `Tensor` of strings.
+
+For each string in the input `Tensor`, creates a substring starting at index
+`pos` with a total length of `len`.
+
+If `len` defines a substring that would extend beyond the length of the input
+string, then as many characters as possible are used.
+
+If `pos` is negative or specifies a character index larger than any of the input
+strings, then an `InvalidArgumentError` is thrown.
+
+`pos` and `len` must have the same shape, otherwise a `ValueError` is thrown on
+Op creation.
+
+*NOTE*: `Substr` supports broadcasting up to two dimensions. More about
+broadcasting
+[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+
+---
+
+Examples
+
+Using scalar `pos` and `len`:
+
+```
+input = [b'Hello', b'World']
+position = 1
+length = 3
+
+output = [b'ell', b'orl']
+```
+
+Using `pos` and `len` with same shape as `input`:
+
+```
+input = [[b'ten', b'eleven', b'twelve'],
+         [b'thirteen', b'fourteen', b'fifteen'],
+         [b'sixteen', b'seventeen', b'eighteen']]
+position = [[1, 2, 3],
+            [1, 2, 3],
+            [1, 2, 3]]
+length =   [[2, 3, 4],
+            [4, 3, 2],
+            [5, 5, 5]]
+
+output = [[b'en', b'eve', b'lve'],
+          [b'hirt', b'urt', b'te'],
+          [b'ixtee', b'vente', b'hteen']]
+```
+
+Broadcasting `pos` and `len` onto `input`:
+
+```
+input = [[b'ten', b'eleven', b'twelve'],
+         [b'thirteen', b'fourteen', b'fifteen'],
+         [b'sixteen', b'seventeen', b'eighteen'],
+         [b'nineteen', b'twenty', b'twentyone']]
+position = [1, 2, 3]
+length =   [1, 2, 3]
+
+output = [[b'e', b'ev', b'lve'],
+          [b'h', b'ur', b'tee'],
+          [b'i', b've', b'hte'],
+          [b'i', b'en', b'nty']]
+```
+
+Broadcasting `input` onto `pos` and `len`:
+
+```
+input = b'thirteen'
+position = [1, 5, 7]
+length =   [3, 2, 1]
+
+output = [b'hir', b'ee', b'n']
+```
+
+input: Tensor of strings
+pos: Scalar or tensor defining the position of first character in each substring
+len: Scalar or tensor defining the number of characters to include in each substring
+output: Tensor of substrings
+)doc");
+
 }  // namespace tensorflow