author    Benoit Steiner <bsteiner@google.com>  2016-11-09 16:14:01 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2016-11-09 16:25:40 -0800
commit    0f81e91137e305a545f070815099f24ae217640f (patch)
tree      011c4ade29c52f7f64c455a5f82100b33965686e /tensorflow/core/kernels/substr_op.cc
parent    58b4581d0ef07a3016d44bfae0cc28738bce2e90 (diff)
Fixed formatting and lint issues introduced with the last pull from OSS (cl/138675832)
Change: 138699007
Diffstat (limited to 'tensorflow/core/kernels/substr_op.cc')
-rw-r--r--  tensorflow/core/kernels/substr_op.cc  361
1 file changed, 176 insertions(+), 185 deletions(-)
diff --git a/tensorflow/core/kernels/substr_op.cc b/tensorflow/core/kernels/substr_op.cc
index 020ad12c2d..5c72c9e1ae 100644
--- a/tensorflow/core/kernels/substr_op.cc
+++ b/tensorflow/core/kernels/substr_op.cc
@@ -32,202 +32,193 @@ namespace tensorflow {
// Position/length can be 32 or 64-bit integers
template <typename T>
class SubstrOp : public OpKernel {
- public:
- using OpKernel::OpKernel;
-
- void Compute(OpKernelContext* context) override {
- // Get inputs
- const Tensor& input_tensor = context->input(0);
- const Tensor& pos_tensor = context->input(1);
- const Tensor& len_tensor = context->input(2);
- const TensorShape input_shape = input_tensor.shape();
- const TensorShape pos_shape = pos_tensor.shape();
- const TensorShape len_shape = len_tensor.shape();
-
- bool is_scalar = TensorShapeUtils::IsScalar(pos_shape);
-
- if (is_scalar || input_shape == pos_shape) {
- // pos/len are either scalar or match the shape of input_tensor
- // Do not need to do broadcasting
-
- // Reshape input
- auto input = input_tensor.flat<string>();
- // Allocate output
- Tensor* output_tensor = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output("output", input_tensor.shape(),
- &output_tensor));
- auto output = output_tensor->flat<string>();
- if (is_scalar) {
- // Perform Op with scalar pos/len
- const T pos = tensorflow::internal::SubtleMustCopy(pos_tensor.scalar<T>()());
- const T len = tensorflow::internal::SubtleMustCopy(len_tensor.scalar<T>()());
- for (size_t i = 0; i < input_tensor.NumElements(); ++i) {
- string in = input(i);
- OP_REQUIRES(context, FastBoundsCheck(pos, in.size()),
- errors::InvalidArgument("pos ", pos, " out of range for string",
- "b'", in, "' at index ", i));
- output(i) = in.substr(pos, len);
- }
- } else {
- // Perform Op element-wise with tensor pos/len
- auto pos_flat = pos_tensor.flat<T>();
- auto len_flat = len_tensor.flat<T>();
- for (size_t i = 0; i < input_tensor.NumElements(); ++i) {
- string in = input(i);
- const T pos = tensorflow::internal::SubtleMustCopy(pos_flat(i));
- const T len = tensorflow::internal::SubtleMustCopy(len_flat(i));
- OP_REQUIRES(context, FastBoundsCheck(pos, in.size()),
- errors::InvalidArgument("pos ", pos, " out of range for string",
+ public:
+ using OpKernel::OpKernel;
+
+ void Compute(OpKernelContext* context) override {
+ // Get inputs
+ const Tensor& input_tensor = context->input(0);
+ const Tensor& pos_tensor = context->input(1);
+ const Tensor& len_tensor = context->input(2);
+ const TensorShape& input_shape = input_tensor.shape();
+ const TensorShape& pos_shape = pos_tensor.shape();
+
+ bool is_scalar = TensorShapeUtils::IsScalar(pos_shape);
+
+ if (is_scalar || input_shape == pos_shape) {
+ // pos/len are either scalar or match the shape of input_tensor
+ // Do not need to do broadcasting
+
+ // Reshape input
+ auto input = input_tensor.flat<string>();
+ // Allocate output
+ Tensor* output_tensor = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output("output", input_tensor.shape(),
+ &output_tensor));
+ auto output = output_tensor->flat<string>();
+ if (is_scalar) {
+ // Perform Op with scalar pos/len
+ const T pos =
+ tensorflow::internal::SubtleMustCopy(pos_tensor.scalar<T>()());
+ const T len =
+ tensorflow::internal::SubtleMustCopy(len_tensor.scalar<T>()());
+ for (size_t i = 0; i < input_tensor.NumElements(); ++i) {
+ string in = input(i);
+ OP_REQUIRES(
+ context, FastBoundsCheck(pos, in.size()),
+ errors::InvalidArgument("pos ", pos, " out of range for string",
+ "b'", in, "' at index ", i));
+ output(i) = in.substr(pos, len);
+ }
+ } else {
+ // Perform Op element-wise with tensor pos/len
+ auto pos_flat = pos_tensor.flat<T>();
+ auto len_flat = len_tensor.flat<T>();
+ for (size_t i = 0; i < input_tensor.NumElements(); ++i) {
+ string in = input(i);
+ const T pos = tensorflow::internal::SubtleMustCopy(pos_flat(i));
+ const T len = tensorflow::internal::SubtleMustCopy(len_flat(i));
+ OP_REQUIRES(
+ context, FastBoundsCheck(pos, in.size()),
+ errors::InvalidArgument("pos ", pos, " out of range for string",
+ "b'", in, "' at index ", i));
+ output(i) = in.substr(pos, len);
+ }
+ }
+ } else {
+ // Perform op with broadcasting
+ // TODO: Use ternary broadcasting for once available in Eigen. Current
+ // implementation iterates through broadcasted ops element-wise;
+ // this should be parallelized.
+
+ // Create BCast helper with shape of input and pos/len
+ BCast bcast(BCast::FromShape(input_shape), BCast::FromShape(pos_shape));
+ OP_REQUIRES(context, bcast.IsValid(),
+ errors::InvalidArgument("Incompatible shapes: ",
+ input_shape.DebugString(), " vs. ",
+ pos_shape.DebugString()));
+ TensorShape output_shape = BCast::ToShape(bcast.result_shape());
+ int ndims = output_shape.dims();
+ Tensor* output_tensor = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output("output", output_shape,
+ &output_tensor));
+ switch (ndims) {
+ case 1: {
+ // Reshape tensors according to BCast results
+ auto input = input_tensor.shaped<string, 1>(bcast.x_reshape());
+ auto output = output_tensor->shaped<string, 1>(bcast.result_shape());
+ auto pos_shaped = pos_tensor.shaped<T, 1>(bcast.y_reshape());
+ auto len_shaped = len_tensor.shaped<T, 1>(bcast.y_reshape());
+
+ // Allocate temporary buffer for broadcasted input tensor
+ Tensor input_buffer;
+ OP_REQUIRES_OK(context, context->allocate_temp(
+ DT_STRING, output_shape, &input_buffer));
+ typename TTypes<string, 1>::Tensor input_bcast =
+ input_buffer.shaped<string, 1>(bcast.result_shape());
+ input_bcast =
+ input.broadcast(BCast::ToIndexArray<1>(bcast.x_bcast()));
+
+ // Allocate temporary buffer for broadcasted position tensor
+ Tensor pos_buffer;
+ OP_REQUIRES_OK(context,
+ context->allocate_temp(DataTypeToEnum<T>::v(),
+ output_shape, &pos_buffer));
+ typename TTypes<T, 1>::Tensor pos_bcast =
+ pos_buffer.shaped<T, 1>(bcast.result_shape());
+ pos_bcast =
+ pos_shaped.broadcast(BCast::ToIndexArray<1>(bcast.y_bcast()));
+
+ // Allocate temporary buffer for broadcasted length tensor
+ Tensor len_buffer;
+ OP_REQUIRES_OK(context,
+ context->allocate_temp(DataTypeToEnum<T>::v(),
+ output_shape, &len_buffer));
+ typename TTypes<T, 1>::Tensor len_bcast =
+ len_buffer.shaped<T, 1>(bcast.result_shape());
+ len_bcast =
+ len_shaped.broadcast(BCast::ToIndexArray<1>(bcast.y_bcast()));
+
+ // Iterate through broadcasted tensors and perform substr
+ for (int i = 0; i < output_shape.dim_size(0); ++i) {
+ string in = input_bcast(i);
+ const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i));
+ const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i));
+ OP_REQUIRES(
+ context, FastBoundsCheck(pos, input_bcast(i).size()),
+ errors::InvalidArgument("pos ", pos, " out of range for string",
"b'", in, "' at index ", i));
output(i) = in.substr(pos, len);
}
+ break;
}
- } else {
- // Perform op with broadcasting
- // TODO: Use ternary broadcasting for once available in Eigen. Current
- // implementation iterates through broadcasted ops element-wise;
- // this should be parallelized.
-
- // Create BCast helper with shape of input and pos/len
- BCast bcast(BCast::FromShape(input_shape), BCast::FromShape(pos_shape));
- OP_REQUIRES(context, bcast.IsValid(),
- errors::InvalidArgument("Incompatible shapes: ",
- input_shape.DebugString(), " vs. ",
- pos_shape.DebugString()));
- TensorShape output_shape = BCast::ToShape(bcast.result_shape());
- int ndims = output_shape.dims();
- Tensor* output_tensor = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output("output", output_shape,
- &output_tensor));
- switch (ndims) {
- case 1: {
- // Reshape tensors according to BCast results
- auto input = input_tensor.shaped<string,1>(bcast.x_reshape());
- auto output = output_tensor->shaped<string,1>(bcast.result_shape());
- auto pos_shaped = pos_tensor.shaped<T,1>(bcast.y_reshape());
- auto len_shaped = len_tensor.shaped<T,1>(bcast.y_reshape());
-
- // Allocate temporary buffer for broadcasted input tensor
- Tensor input_buffer;
- OP_REQUIRES_OK(context,
- context->allocate_temp(DT_STRING,
- output_shape,
- &input_buffer));
- typename TTypes<string,1>::Tensor input_bcast =
- input_buffer.shaped<string,1>(bcast.result_shape());
- input_bcast = input.broadcast(
- BCast::ToIndexArray<1>(bcast.x_bcast()));
-
- // Allocate temporary buffer for broadcasted position tensor
- Tensor pos_buffer;
- OP_REQUIRES_OK(context,
- context->allocate_temp(DataTypeToEnum<T>::v(),
- output_shape,
- &pos_buffer));
- typename TTypes<T,1>::Tensor pos_bcast = pos_buffer.shaped<T,1>(
- bcast.result_shape());
- pos_bcast = pos_shaped.broadcast(
- BCast::ToIndexArray<1>(bcast.y_bcast()));
-
- // Allocate temporary buffer for broadcasted length tensor
- Tensor len_buffer;
- OP_REQUIRES_OK(context,
- context->allocate_temp(DataTypeToEnum<T>::v(),
- output_shape,
- &len_buffer));
- typename TTypes<T,1>::Tensor len_bcast = len_buffer.shaped<T,1>(
- bcast.result_shape());
- len_bcast = len_shaped.broadcast(
- BCast::ToIndexArray<1>(bcast.y_bcast()));
-
- // Iterate through broadcasted tensors and perform substr
- for (int i = 0; i < output_shape.dim_size(0); ++i) {
- string in = input_bcast(i);
- const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i));
- const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i));
- OP_REQUIRES(context, FastBoundsCheck(pos, input_bcast(i).size()),
- errors::InvalidArgument("pos ", pos, " out of range for string",
- "b'", in, "' at index ", i));
- output(i) = in.substr(pos, len);
- }
- break;
- }
- case 2: {
- // Reshape tensors according to BCast results
- auto input = input_tensor.shaped<string,2>(bcast.x_reshape());
- auto output = output_tensor->shaped<string,2>(bcast.result_shape());
- auto pos_shaped = pos_tensor.shaped<T,2>(bcast.y_reshape());
- auto len_shaped = len_tensor.shaped<T,2>(bcast.y_reshape());
-
- // Allocate temporary buffer for broadcasted input tensor
- Tensor input_buffer;
- OP_REQUIRES_OK(context,
- context->allocate_temp(DT_STRING,
- output_shape,
- &input_buffer));
- typename TTypes<string,2>::Tensor input_bcast =
- input_buffer.shaped<string,2>(bcast.result_shape());
- input_bcast = input.broadcast(
- BCast::ToIndexArray<2>(bcast.x_bcast()));
-
- // Allocate temporary buffer for broadcasted position tensor
- Tensor pos_buffer;
- OP_REQUIRES_OK(context,
- context->allocate_temp(DataTypeToEnum<T>::v(),
- output_shape,
- &pos_buffer));
- typename TTypes<T,2>::Tensor pos_bcast = pos_buffer.shaped<T,2>(
- bcast.result_shape());
- pos_bcast = pos_shaped.broadcast(
- BCast::ToIndexArray<2>(bcast.y_bcast()));
-
- // Allocate temporary buffer for broadcasted length tensor
- Tensor len_buffer;
- OP_REQUIRES_OK(context,
- context->allocate_temp(DataTypeToEnum<T>::v(),
- output_shape,
- &len_buffer));
- typename TTypes<T,2>::Tensor len_bcast = len_buffer.shaped<T,2>(
- bcast.result_shape());
- len_bcast = len_shaped.broadcast(
- BCast::ToIndexArray<2>(bcast.y_bcast()));
-
- // Iterate through broadcasted tensors and perform substr
- for (int i = 0; i < output_shape.dim_size(0); ++i) {
- for (int j = 0; j < output_shape.dim_size(1); ++j) {
- string in = input_bcast(i, j);
- const T pos = tensorflow::internal::SubtleMustCopy(
- pos_bcast(i, j));
- const T len = tensorflow::internal::SubtleMustCopy(
- len_bcast(i, j));
- OP_REQUIRES(
- context,
- FastBoundsCheck(pos, in.size()),
- errors::InvalidArgument("pos ", pos, " out of range for ",
- "string b'", in, "' at index ("
- , i, ", ", j, ")"));
- output(i, j) = in.substr(pos, len);
-
- }
+ case 2: {
+ // Reshape tensors according to BCast results
+ auto input = input_tensor.shaped<string, 2>(bcast.x_reshape());
+ auto output = output_tensor->shaped<string, 2>(bcast.result_shape());
+ auto pos_shaped = pos_tensor.shaped<T, 2>(bcast.y_reshape());
+ auto len_shaped = len_tensor.shaped<T, 2>(bcast.y_reshape());
+
+ // Allocate temporary buffer for broadcasted input tensor
+ Tensor input_buffer;
+ OP_REQUIRES_OK(context, context->allocate_temp(
+ DT_STRING, output_shape, &input_buffer));
+ typename TTypes<string, 2>::Tensor input_bcast =
+ input_buffer.shaped<string, 2>(bcast.result_shape());
+ input_bcast =
+ input.broadcast(BCast::ToIndexArray<2>(bcast.x_bcast()));
+
+ // Allocate temporary buffer for broadcasted position tensor
+ Tensor pos_buffer;
+ OP_REQUIRES_OK(context,
+ context->allocate_temp(DataTypeToEnum<T>::v(),
+ output_shape, &pos_buffer));
+ typename TTypes<T, 2>::Tensor pos_bcast =
+ pos_buffer.shaped<T, 2>(bcast.result_shape());
+ pos_bcast =
+ pos_shaped.broadcast(BCast::ToIndexArray<2>(bcast.y_bcast()));
+
+ // Allocate temporary buffer for broadcasted length tensor
+ Tensor len_buffer;
+ OP_REQUIRES_OK(context,
+ context->allocate_temp(DataTypeToEnum<T>::v(),
+ output_shape, &len_buffer));
+ typename TTypes<T, 2>::Tensor len_bcast =
+ len_buffer.shaped<T, 2>(bcast.result_shape());
+ len_bcast =
+ len_shaped.broadcast(BCast::ToIndexArray<2>(bcast.y_bcast()));
+
+ // Iterate through broadcasted tensors and perform substr
+ for (int i = 0; i < output_shape.dim_size(0); ++i) {
+ for (int j = 0; j < output_shape.dim_size(1); ++j) {
+ string in = input_bcast(i, j);
+ const T pos =
+ tensorflow::internal::SubtleMustCopy(pos_bcast(i, j));
+ const T len =
+ tensorflow::internal::SubtleMustCopy(len_bcast(i, j));
+ OP_REQUIRES(context, FastBoundsCheck(pos, in.size()),
+ errors::InvalidArgument(
+ "pos ", pos, " out of range for ", "string b'",
+ in, "' at index (", i, ", ", j, ")"));
+ output(i, j) = in.substr(pos, len);
}
- break;
- }
- default: {
- context->SetStatus(errors::Unimplemented(
- "Substr broadcast not implemented for ", ndims, " dimensions"));
}
+ break;
+ }
+ default: {
+ context->SetStatus(errors::Unimplemented(
+ "Substr broadcast not implemented for ", ndims, " dimensions"));
}
}
}
+ }
};
-#define REGISTER_SUBSTR(type) \
- REGISTER_KERNEL_BUILDER(Name("Substr") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<type>("T"), \
- SubstrOp<type>);
+#define REGISTER_SUBSTR(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Substr").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ SubstrOp<type>);
REGISTER_SUBSTR(int32);
REGISTER_SUBSTR(int64);
} // namespace tensorflow
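
For readers skimming the diff: the kernel's per-element behavior is to bounds-check the start position against the input string (FastBoundsCheck, raising InvalidArgument on failure) and then take at most `len` characters with std::string::substr, which clamps the length to the remaining characters. The sketch below is a minimal standalone C++ approximation of that logic, not TensorFlow code; the names SubstrElement and the sample data are illustrative only, it assumes C++17 for std::optional, and it substitutes a std::nullopt return for the kernel's OP_REQUIRES/InvalidArgument path.

// Minimal standalone sketch (not the TensorFlow kernel) of SubstrOp's
// per-element logic: validate `pos`, then extract up to `len` characters.
#include <cstdint>
#include <iostream>
#include <optional>
#include <string>
#include <vector>

// Mirrors FastBoundsCheck(pos, in.size()): `pos` must satisfy 0 <= pos < size.
// On success returns in.substr(pos, len); std::string::substr already clamps
// `len` to the characters that remain, so only `pos` needs a bounds check.
std::optional<std::string> SubstrElement(const std::string& in, int64_t pos,
                                         int64_t len) {
  if (pos < 0 || static_cast<uint64_t>(pos) >= in.size()) {
    return std::nullopt;  // out-of-range position
  }
  return in.substr(static_cast<size_t>(pos), static_cast<size_t>(len));
}

int main() {
  std::vector<std::string> input = {"Hello", "TensorFlow", "Substr"};

  // Element-wise case: pos/len have the same shape as the input tensor.
  std::vector<int64_t> pos = {1, 0, 2};
  std::vector<int64_t> len = {3, 6, 10};
  for (size_t i = 0; i < input.size(); ++i) {
    if (auto out = SubstrElement(input[i], pos[i], len[i])) {
      std::cout << *out << "\n";  // "ell", "Tensor", "bstr" (len clamped)
    } else {
      std::cout << "pos " << pos[i] << " out of range for string b'"
                << input[i] << "' at index " << i << "\n";
    }
  }

  // Scalar case: a single pos/len pair is applied to every element.
  for (size_t i = 0; i < input.size(); ++i) {
    if (auto out = SubstrElement(input[i], /*pos=*/0, /*len=*/3)) {
      std::cout << *out << "\n";  // "Hel", "Ten", "Sub"
    }
  }
  return 0;
}

The broadcasting branch in the diff reduces to the same element-wise loop after the BCast helper has expanded the input and pos/len tensors to a common shape; only the index arithmetic differs.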