diff options
Diffstat (limited to 'tensorflow/core/kernels/substr_op.cc')
-rw-r--r-- | tensorflow/core/kernels/substr_op.cc | 50 |
1 files changed, 35 insertions, 15 deletions
diff --git a/tensorflow/core/kernels/substr_op.cc b/tensorflow/core/kernels/substr_op.cc index 22e45918a0..07f1d6e767 100644 --- a/tensorflow/core/kernels/substr_op.cc +++ b/tensorflow/core/kernels/substr_op.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include <cstddef> +#include <cstdlib> #include <string> #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" @@ -25,6 +27,8 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/bcast.h" namespace tensorflow { @@ -64,26 +68,28 @@ class SubstrOp : public OpKernel { const T len = tensorflow::internal::SubtleMustCopy(len_tensor.scalar<T>()()); for (size_t i = 0; i < input_tensor.NumElements(); ++i) { - string in = input(i); + StringPiece in(input(i)); OP_REQUIRES( - context, FastBoundsCheck(pos, in.size() + 1), + context, FastBoundsCheck(std::abs(pos), in.size() + 1), errors::InvalidArgument("pos ", pos, " out of range for string", "b'", in, "' at index ", i)); - output(i) = in.substr(pos, len); + StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len); + output(i).assign(sub_in.data(), sub_in.size()); } } else { // Perform Op element-wise with tensor pos/len auto pos_flat = pos_tensor.flat<T>(); auto len_flat = len_tensor.flat<T>(); for (size_t i = 0; i < input_tensor.NumElements(); ++i) { - string in = input(i); + StringPiece in(input(i)); const T pos = tensorflow::internal::SubtleMustCopy(pos_flat(i)); const T len = tensorflow::internal::SubtleMustCopy(len_flat(i)); OP_REQUIRES( - context, FastBoundsCheck(pos, in.size() + 1), + context, FastBoundsCheck(std::abs(pos), in.size() + 1), errors::InvalidArgument("pos ", pos, " out of range for string", "b'", in, "' at index ", i)); - output(i) = in.substr(pos, len); + StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len); + output(i).assign(sub_in.data(), sub_in.size()); } } } else { @@ -142,14 +148,16 @@ class SubstrOp : public OpKernel { // Iterate through broadcasted tensors and perform substr for (int i = 0; i < output_shape.dim_size(0); ++i) { - string in = input_bcast(i); + StringPiece in(input_bcast(i)); const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i)); const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i)); OP_REQUIRES( - context, FastBoundsCheck(pos, input_bcast(i).size() + 1), + context, + FastBoundsCheck(std::abs(pos), input_bcast(i).size() + 1), errors::InvalidArgument("pos ", pos, " out of range for string", "b'", in, "' at index ", i)); - output(i) = in.substr(pos, len); + StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len); + output(i).assign(sub_in.data(), sub_in.size()); } break; } @@ -192,16 +200,18 @@ class SubstrOp : public OpKernel { // Iterate through broadcasted tensors and perform substr for (int i = 0; i < output_shape.dim_size(0); ++i) { for (int j = 0; j < output_shape.dim_size(1); ++j) { - string in = input_bcast(i, j); + StringPiece in(input_bcast(i, j)); const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i, j)); const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i, j)); - OP_REQUIRES(context, FastBoundsCheck(pos, in.size() + 1), - errors::InvalidArgument( - "pos ", pos, " out of range for ", "string b'", - in, "' at index (", i, ", ", j, ")")); - output(i, j) = in.substr(pos, len); + OP_REQUIRES( + context, FastBoundsCheck(std::abs(pos), in.size() + 1), + errors::InvalidArgument("pos ", pos, " out of range for ", + "string b'", in, "' at index (", i, + ", ", j, ")")); + StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len); + output(i, j).assign(sub_in.data(), sub_in.size()); } } break; @@ -213,6 +223,16 @@ class SubstrOp : public OpKernel { } } } + + private: + // This adjusts the requested position. Note it does not perform any bound + // checks. + T AdjustedPosIndex(const T pos_requested, const StringPiece s) { + if (pos_requested < 0) { + return s.size() + pos_requested; + } + return pos_requested; + } }; #define REGISTER_SUBSTR(type) \ |