aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/substr_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/substr_op.cc')
-rw-r--r--tensorflow/core/kernels/substr_op.cc50
1 files changed, 35 insertions, 15 deletions
diff --git a/tensorflow/core/kernels/substr_op.cc b/tensorflow/core/kernels/substr_op.cc
index 22e45918a0..07f1d6e767 100644
--- a/tensorflow/core/kernels/substr_op.cc
+++ b/tensorflow/core/kernels/substr_op.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <cstddef>
+#include <cstdlib>
#include <string>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -25,6 +27,8 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/bcast.h"
namespace tensorflow {
@@ -64,26 +68,28 @@ class SubstrOp : public OpKernel {
const T len =
tensorflow::internal::SubtleMustCopy(len_tensor.scalar<T>()());
for (size_t i = 0; i < input_tensor.NumElements(); ++i) {
- string in = input(i);
+ StringPiece in(input(i));
OP_REQUIRES(
- context, FastBoundsCheck(pos, in.size() + 1),
+ context, FastBoundsCheck(std::abs(pos), in.size() + 1),
errors::InvalidArgument("pos ", pos, " out of range for string",
"b'", in, "' at index ", i));
- output(i) = in.substr(pos, len);
+ StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+ output(i).assign(sub_in.data(), sub_in.size());
}
} else {
// Perform Op element-wise with tensor pos/len
auto pos_flat = pos_tensor.flat<T>();
auto len_flat = len_tensor.flat<T>();
for (size_t i = 0; i < input_tensor.NumElements(); ++i) {
- string in = input(i);
+ StringPiece in(input(i));
const T pos = tensorflow::internal::SubtleMustCopy(pos_flat(i));
const T len = tensorflow::internal::SubtleMustCopy(len_flat(i));
OP_REQUIRES(
- context, FastBoundsCheck(pos, in.size() + 1),
+ context, FastBoundsCheck(std::abs(pos), in.size() + 1),
errors::InvalidArgument("pos ", pos, " out of range for string",
"b'", in, "' at index ", i));
- output(i) = in.substr(pos, len);
+ StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+ output(i).assign(sub_in.data(), sub_in.size());
}
}
} else {
@@ -142,14 +148,16 @@ class SubstrOp : public OpKernel {
// Iterate through broadcasted tensors and perform substr
for (int i = 0; i < output_shape.dim_size(0); ++i) {
- string in = input_bcast(i);
+ StringPiece in(input_bcast(i));
const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i));
const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i));
OP_REQUIRES(
- context, FastBoundsCheck(pos, input_bcast(i).size() + 1),
+ context,
+ FastBoundsCheck(std::abs(pos), input_bcast(i).size() + 1),
errors::InvalidArgument("pos ", pos, " out of range for string",
"b'", in, "' at index ", i));
- output(i) = in.substr(pos, len);
+ StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+ output(i).assign(sub_in.data(), sub_in.size());
}
break;
}
@@ -192,16 +200,18 @@ class SubstrOp : public OpKernel {
// Iterate through broadcasted tensors and perform substr
for (int i = 0; i < output_shape.dim_size(0); ++i) {
for (int j = 0; j < output_shape.dim_size(1); ++j) {
- string in = input_bcast(i, j);
+ StringPiece in(input_bcast(i, j));
const T pos =
tensorflow::internal::SubtleMustCopy(pos_bcast(i, j));
const T len =
tensorflow::internal::SubtleMustCopy(len_bcast(i, j));
- OP_REQUIRES(context, FastBoundsCheck(pos, in.size() + 1),
- errors::InvalidArgument(
- "pos ", pos, " out of range for ", "string b'",
- in, "' at index (", i, ", ", j, ")"));
- output(i, j) = in.substr(pos, len);
+ OP_REQUIRES(
+ context, FastBoundsCheck(std::abs(pos), in.size() + 1),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string b'", in, "' at index (", i,
+ ", ", j, ")"));
+ StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+ output(i, j).assign(sub_in.data(), sub_in.size());
}
}
break;
@@ -213,6 +223,16 @@ class SubstrOp : public OpKernel {
}
}
}
+
+ private:
+ // This adjusts the requested position. Note it does not perform any bound
+ // checks.
+ T AdjustedPosIndex(const T pos_requested, const StringPiece s) {
+ if (pos_requested < 0) {
+ return s.size() + pos_requested;
+ }
+ return pos_requested;
+ }
};
#define REGISTER_SUBSTR(type) \