diff options
author | Benoit Steiner <bsteiner@google.com> | 2016-11-09 16:14:01 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-11-09 16:25:40 -0800 |
commit | 0f81e91137e305a545f070815099f24ae217640f (patch) | |
tree | 011c4ade29c52f7f64c455a5f82100b33965686e /tensorflow/core/kernels/substr_op.cc | |
parent | 58b4581d0ef07a3016d44bfae0cc28738bce2e90 (diff) |
Fixed formatting and lint issues introduced with the last pull from OSS (cl/138675832)
Change: 138699007
Diffstat (limited to 'tensorflow/core/kernels/substr_op.cc')
-rw-r--r-- | tensorflow/core/kernels/substr_op.cc | 361 |
1 files changed, 176 insertions, 185 deletions
diff --git a/tensorflow/core/kernels/substr_op.cc b/tensorflow/core/kernels/substr_op.cc index 020ad12c2d..5c72c9e1ae 100644 --- a/tensorflow/core/kernels/substr_op.cc +++ b/tensorflow/core/kernels/substr_op.cc @@ -32,202 +32,193 @@ namespace tensorflow { // Position/length can be 32 or 64-bit integers template <typename T> class SubstrOp : public OpKernel { - public: - using OpKernel::OpKernel; - - void Compute(OpKernelContext* context) override { - // Get inputs - const Tensor& input_tensor = context->input(0); - const Tensor& pos_tensor = context->input(1); - const Tensor& len_tensor = context->input(2); - const TensorShape input_shape = input_tensor.shape(); - const TensorShape pos_shape = pos_tensor.shape(); - const TensorShape len_shape = len_tensor.shape(); - - bool is_scalar = TensorShapeUtils::IsScalar(pos_shape); - - if (is_scalar || input_shape == pos_shape) { - // pos/len are either scalar or match the shape of input_tensor - // Do not need to do broadcasting - - // Reshape input - auto input = input_tensor.flat<string>(); - // Allocate output - Tensor* output_tensor = nullptr; - OP_REQUIRES_OK(context, - context->allocate_output("output", input_tensor.shape(), - &output_tensor)); - auto output = output_tensor->flat<string>(); - if (is_scalar) { - // Perform Op with scalar pos/len - const T pos = tensorflow::internal::SubtleMustCopy(pos_tensor.scalar<T>()()); - const T len = tensorflow::internal::SubtleMustCopy(len_tensor.scalar<T>()()); - for (size_t i = 0; i < input_tensor.NumElements(); ++i) { - string in = input(i); - OP_REQUIRES(context, FastBoundsCheck(pos, in.size()), - errors::InvalidArgument("pos ", pos, " out of range for string", - "b'", in, "' at index ", i)); - output(i) = in.substr(pos, len); - } - } else { - // Perform Op element-wise with tensor pos/len - auto pos_flat = pos_tensor.flat<T>(); - auto len_flat = len_tensor.flat<T>(); - for (size_t i = 0; i < input_tensor.NumElements(); ++i) { - string in = input(i); - const T pos = tensorflow::internal::SubtleMustCopy(pos_flat(i)); - const T len = tensorflow::internal::SubtleMustCopy(len_flat(i)); - OP_REQUIRES(context, FastBoundsCheck(pos, in.size()), - errors::InvalidArgument("pos ", pos, " out of range for string", + public: + using OpKernel::OpKernel; + + void Compute(OpKernelContext* context) override { + // Get inputs + const Tensor& input_tensor = context->input(0); + const Tensor& pos_tensor = context->input(1); + const Tensor& len_tensor = context->input(2); + const TensorShape& input_shape = input_tensor.shape(); + const TensorShape& pos_shape = pos_tensor.shape(); + + bool is_scalar = TensorShapeUtils::IsScalar(pos_shape); + + if (is_scalar || input_shape == pos_shape) { + // pos/len are either scalar or match the shape of input_tensor + // Do not need to do broadcasting + + // Reshape input + auto input = input_tensor.flat<string>(); + // Allocate output + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output("output", input_tensor.shape(), + &output_tensor)); + auto output = output_tensor->flat<string>(); + if (is_scalar) { + // Perform Op with scalar pos/len + const T pos = + tensorflow::internal::SubtleMustCopy(pos_tensor.scalar<T>()()); + const T len = + tensorflow::internal::SubtleMustCopy(len_tensor.scalar<T>()()); + for (size_t i = 0; i < input_tensor.NumElements(); ++i) { + string in = input(i); + OP_REQUIRES( + context, FastBoundsCheck(pos, in.size()), + errors::InvalidArgument("pos ", pos, " out of range for string", + "b'", in, "' at index ", i)); + output(i) = in.substr(pos, len); + } + } else { + // Perform Op element-wise with tensor pos/len + auto pos_flat = pos_tensor.flat<T>(); + auto len_flat = len_tensor.flat<T>(); + for (size_t i = 0; i < input_tensor.NumElements(); ++i) { + string in = input(i); + const T pos = tensorflow::internal::SubtleMustCopy(pos_flat(i)); + const T len = tensorflow::internal::SubtleMustCopy(len_flat(i)); + OP_REQUIRES( + context, FastBoundsCheck(pos, in.size()), + errors::InvalidArgument("pos ", pos, " out of range for string", + "b'", in, "' at index ", i)); + output(i) = in.substr(pos, len); + } + } + } else { + // Perform op with broadcasting + // TODO: Use ternary broadcasting for once available in Eigen. Current + // implementation iterates through broadcasted ops element-wise; + // this should be parallelized. + + // Create BCast helper with shape of input and pos/len + BCast bcast(BCast::FromShape(input_shape), BCast::FromShape(pos_shape)); + OP_REQUIRES(context, bcast.IsValid(), + errors::InvalidArgument("Incompatible shapes: ", + input_shape.DebugString(), " vs. ", + pos_shape.DebugString())); + TensorShape output_shape = BCast::ToShape(bcast.result_shape()); + int ndims = output_shape.dims(); + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK(context, context->allocate_output("output", output_shape, + &output_tensor)); + switch (ndims) { + case 1: { + // Reshape tensors according to BCast results + auto input = input_tensor.shaped<string, 1>(bcast.x_reshape()); + auto output = output_tensor->shaped<string, 1>(bcast.result_shape()); + auto pos_shaped = pos_tensor.shaped<T, 1>(bcast.y_reshape()); + auto len_shaped = len_tensor.shaped<T, 1>(bcast.y_reshape()); + + // Allocate temporary buffer for broadcasted input tensor + Tensor input_buffer; + OP_REQUIRES_OK(context, context->allocate_temp( + DT_STRING, output_shape, &input_buffer)); + typename TTypes<string, 1>::Tensor input_bcast = + input_buffer.shaped<string, 1>(bcast.result_shape()); + input_bcast = + input.broadcast(BCast::ToIndexArray<1>(bcast.x_bcast())); + + // Allocate temporary buffer for broadcasted position tensor + Tensor pos_buffer; + OP_REQUIRES_OK(context, + context->allocate_temp(DataTypeToEnum<T>::v(), + output_shape, &pos_buffer)); + typename TTypes<T, 1>::Tensor pos_bcast = + pos_buffer.shaped<T, 1>(bcast.result_shape()); + pos_bcast = + pos_shaped.broadcast(BCast::ToIndexArray<1>(bcast.y_bcast())); + + // Allocate temporary buffer for broadcasted length tensor + Tensor len_buffer; + OP_REQUIRES_OK(context, + context->allocate_temp(DataTypeToEnum<T>::v(), + output_shape, &len_buffer)); + typename TTypes<T, 1>::Tensor len_bcast = + len_buffer.shaped<T, 1>(bcast.result_shape()); + len_bcast = + len_shaped.broadcast(BCast::ToIndexArray<1>(bcast.y_bcast())); + + // Iterate through broadcasted tensors and perform substr + for (int i = 0; i < output_shape.dim_size(0); ++i) { + string in = input_bcast(i); + const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i)); + const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i)); + OP_REQUIRES( + context, FastBoundsCheck(pos, input_bcast(i).size()), + errors::InvalidArgument("pos ", pos, " out of range for string", "b'", in, "' at index ", i)); output(i) = in.substr(pos, len); } + break; } - } else { - // Perform op with broadcasting - // TODO: Use ternary broadcasting for once available in Eigen. Current - // implementation iterates through broadcasted ops element-wise; - // this should be parallelized. - - // Create BCast helper with shape of input and pos/len - BCast bcast(BCast::FromShape(input_shape), BCast::FromShape(pos_shape)); - OP_REQUIRES(context, bcast.IsValid(), - errors::InvalidArgument("Incompatible shapes: ", - input_shape.DebugString(), " vs. ", - pos_shape.DebugString())); - TensorShape output_shape = BCast::ToShape(bcast.result_shape()); - int ndims = output_shape.dims(); - Tensor* output_tensor = nullptr; - OP_REQUIRES_OK(context, - context->allocate_output("output", output_shape, - &output_tensor)); - switch (ndims) { - case 1: { - // Reshape tensors according to BCast results - auto input = input_tensor.shaped<string,1>(bcast.x_reshape()); - auto output = output_tensor->shaped<string,1>(bcast.result_shape()); - auto pos_shaped = pos_tensor.shaped<T,1>(bcast.y_reshape()); - auto len_shaped = len_tensor.shaped<T,1>(bcast.y_reshape()); - - // Allocate temporary buffer for broadcasted input tensor - Tensor input_buffer; - OP_REQUIRES_OK(context, - context->allocate_temp(DT_STRING, - output_shape, - &input_buffer)); - typename TTypes<string,1>::Tensor input_bcast = - input_buffer.shaped<string,1>(bcast.result_shape()); - input_bcast = input.broadcast( - BCast::ToIndexArray<1>(bcast.x_bcast())); - - // Allocate temporary buffer for broadcasted position tensor - Tensor pos_buffer; - OP_REQUIRES_OK(context, - context->allocate_temp(DataTypeToEnum<T>::v(), - output_shape, - &pos_buffer)); - typename TTypes<T,1>::Tensor pos_bcast = pos_buffer.shaped<T,1>( - bcast.result_shape()); - pos_bcast = pos_shaped.broadcast( - BCast::ToIndexArray<1>(bcast.y_bcast())); - - // Allocate temporary buffer for broadcasted length tensor - Tensor len_buffer; - OP_REQUIRES_OK(context, - context->allocate_temp(DataTypeToEnum<T>::v(), - output_shape, - &len_buffer)); - typename TTypes<T,1>::Tensor len_bcast = len_buffer.shaped<T,1>( - bcast.result_shape()); - len_bcast = len_shaped.broadcast( - BCast::ToIndexArray<1>(bcast.y_bcast())); - - // Iterate through broadcasted tensors and perform substr - for (int i = 0; i < output_shape.dim_size(0); ++i) { - string in = input_bcast(i); - const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i)); - const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i)); - OP_REQUIRES(context, FastBoundsCheck(pos, input_bcast(i).size()), - errors::InvalidArgument("pos ", pos, " out of range for string", - "b'", in, "' at index ", i)); - output(i) = in.substr(pos, len); - } - break; - } - case 2: { - // Reshape tensors according to BCast results - auto input = input_tensor.shaped<string,2>(bcast.x_reshape()); - auto output = output_tensor->shaped<string,2>(bcast.result_shape()); - auto pos_shaped = pos_tensor.shaped<T,2>(bcast.y_reshape()); - auto len_shaped = len_tensor.shaped<T,2>(bcast.y_reshape()); - - // Allocate temporary buffer for broadcasted input tensor - Tensor input_buffer; - OP_REQUIRES_OK(context, - context->allocate_temp(DT_STRING, - output_shape, - &input_buffer)); - typename TTypes<string,2>::Tensor input_bcast = - input_buffer.shaped<string,2>(bcast.result_shape()); - input_bcast = input.broadcast( - BCast::ToIndexArray<2>(bcast.x_bcast())); - - // Allocate temporary buffer for broadcasted position tensor - Tensor pos_buffer; - OP_REQUIRES_OK(context, - context->allocate_temp(DataTypeToEnum<T>::v(), - output_shape, - &pos_buffer)); - typename TTypes<T,2>::Tensor pos_bcast = pos_buffer.shaped<T,2>( - bcast.result_shape()); - pos_bcast = pos_shaped.broadcast( - BCast::ToIndexArray<2>(bcast.y_bcast())); - - // Allocate temporary buffer for broadcasted length tensor - Tensor len_buffer; - OP_REQUIRES_OK(context, - context->allocate_temp(DataTypeToEnum<T>::v(), - output_shape, - &len_buffer)); - typename TTypes<T,2>::Tensor len_bcast = len_buffer.shaped<T,2>( - bcast.result_shape()); - len_bcast = len_shaped.broadcast( - BCast::ToIndexArray<2>(bcast.y_bcast())); - - // Iterate through broadcasted tensors and perform substr - for (int i = 0; i < output_shape.dim_size(0); ++i) { - for (int j = 0; j < output_shape.dim_size(1); ++j) { - string in = input_bcast(i, j); - const T pos = tensorflow::internal::SubtleMustCopy( - pos_bcast(i, j)); - const T len = tensorflow::internal::SubtleMustCopy( - len_bcast(i, j)); - OP_REQUIRES( - context, - FastBoundsCheck(pos, in.size()), - errors::InvalidArgument("pos ", pos, " out of range for ", - "string b'", in, "' at index (" - , i, ", ", j, ")")); - output(i, j) = in.substr(pos, len); - - } + case 2: { + // Reshape tensors according to BCast results + auto input = input_tensor.shaped<string, 2>(bcast.x_reshape()); + auto output = output_tensor->shaped<string, 2>(bcast.result_shape()); + auto pos_shaped = pos_tensor.shaped<T, 2>(bcast.y_reshape()); + auto len_shaped = len_tensor.shaped<T, 2>(bcast.y_reshape()); + + // Allocate temporary buffer for broadcasted input tensor + Tensor input_buffer; + OP_REQUIRES_OK(context, context->allocate_temp( + DT_STRING, output_shape, &input_buffer)); + typename TTypes<string, 2>::Tensor input_bcast = + input_buffer.shaped<string, 2>(bcast.result_shape()); + input_bcast = + input.broadcast(BCast::ToIndexArray<2>(bcast.x_bcast())); + + // Allocate temporary buffer for broadcasted position tensor + Tensor pos_buffer; + OP_REQUIRES_OK(context, + context->allocate_temp(DataTypeToEnum<T>::v(), + output_shape, &pos_buffer)); + typename TTypes<T, 2>::Tensor pos_bcast = + pos_buffer.shaped<T, 2>(bcast.result_shape()); + pos_bcast = + pos_shaped.broadcast(BCast::ToIndexArray<2>(bcast.y_bcast())); + + // Allocate temporary buffer for broadcasted length tensor + Tensor len_buffer; + OP_REQUIRES_OK(context, + context->allocate_temp(DataTypeToEnum<T>::v(), + output_shape, &len_buffer)); + typename TTypes<T, 2>::Tensor len_bcast = + len_buffer.shaped<T, 2>(bcast.result_shape()); + len_bcast = + len_shaped.broadcast(BCast::ToIndexArray<2>(bcast.y_bcast())); + + // Iterate through broadcasted tensors and perform substr + for (int i = 0; i < output_shape.dim_size(0); ++i) { + for (int j = 0; j < output_shape.dim_size(1); ++j) { + string in = input_bcast(i, j); + const T pos = + tensorflow::internal::SubtleMustCopy(pos_bcast(i, j)); + const T len = + tensorflow::internal::SubtleMustCopy(len_bcast(i, j)); + OP_REQUIRES(context, FastBoundsCheck(pos, in.size()), + errors::InvalidArgument( + "pos ", pos, " out of range for ", "string b'", + in, "' at index (", i, ", ", j, ")")); + output(i, j) = in.substr(pos, len); } - break; - } - default: { - context->SetStatus(errors::Unimplemented( - "Substr broadcast not implemented for ", ndims, " dimensions")); } + break; + } + default: { + context->SetStatus(errors::Unimplemented( + "Substr broadcast not implemented for ", ndims, " dimensions")); } } } + } }; -#define REGISTER_SUBSTR(type) \ - REGISTER_KERNEL_BUILDER(Name("Substr") \ - .Device(DEVICE_CPU) \ - .TypeConstraint<type>("T"), \ - SubstrOp<type>); +#define REGISTER_SUBSTR(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Substr").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ + SubstrOp<type>); REGISTER_SUBSTR(int32); REGISTER_SUBSTR(int64); } // namespace tensorflow |