diff options
author | 2018-05-02 16:13:06 -0700 | |
---|---|---|
committer | 2018-05-02 16:56:08 -0700 | |
commit | c7a5787fef8daf3e44313cbd48591464f9643f56 (patch) | |
tree | f0ba7403e0dbda3fca7d3d78dee972e7f214ccad /tensorflow/core/kernels/scoped_allocator_ops.cc | |
parent | 1f4efb78320e1406c0cc9ce4b8753f3d2511048e (diff) |
Enable reshape of _ScopedAllocatorConcat output.
The _ScopedAllocatorConcat kernel outputs the backing tensor after performing
runtime bounds checks. However, the shape of the backing tensor may not match
the desired output shape of the concat operation.
This change adds a "reshape" boolean attribute to _ScopedAllocatorConcat kernel.
When this attribute is set to true, the kernel outputs a reshaped backing tensor
according to the "shape" attribute.
PiperOrigin-RevId: 195169105
Diffstat (limited to 'tensorflow/core/kernels/scoped_allocator_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/scoped_allocator_ops.cc | 39 |
1 files changed, 27 insertions, 12 deletions
diff --git a/tensorflow/core/kernels/scoped_allocator_ops.cc b/tensorflow/core/kernels/scoped_allocator_ops.cc index d7b25ffad0..1800ee8c1f 100644 --- a/tensorflow/core/kernels/scoped_allocator_ops.cc +++ b/tensorflow/core/kernels/scoped_allocator_ops.cc @@ -94,7 +94,8 @@ class ScopedAllocatorConcatOp : public OpKernel { : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_)); OP_REQUIRES_OK(context, context->GetAttr("T", &dtype_)); - // This stuff is just for debugging + OP_REQUIRES_OK(context, context->GetAttr("reshape", &reshape_)); + // These attributes are just for debugging. OP_REQUIRES_OK(context, context->GetAttr("sa_name", &name_)); OP_REQUIRES_OK(context, context->GetAttr("id", &id_)); device_ = context->device(); @@ -114,11 +115,14 @@ class ScopedAllocatorConcatOp : public OpKernel { backing_tensor.NumElements(), " is not equal to expected ", shape_.num_elements())); - VLOG(1) << "_ScopedAllocatorConcatOp outputting backing tensor at " - << DMAHelper::base(&backing_tensor); - Tensor backing_copy(backing_tensor); - context->set_output(0, backing_copy); - const TensorBuffer* backing_buf = DMAHelper::buffer(&backing_copy); + Tensor output(dtype_); + if (reshape_) { + CHECK(output.CopyFrom(backing_tensor, shape_)); + } else { + CHECK(output.CopyFrom(backing_tensor, backing_tensor.shape())); + } + context->set_output(0, output); + const TensorBuffer* backing_buf = DMAHelper::buffer(&output); const void* backing_tensor_lb = backing_buf->data(); const void* backing_tensor_ub = static_cast<const void*>( static_cast<const char*>(backing_tensor_lb) + backing_buf->size()); @@ -126,17 +130,27 @@ class ScopedAllocatorConcatOp : public OpKernel { for (int i = 1; i < context->num_inputs(); ++i) { const TensorBuffer* input_buf = DMAHelper::buffer(&context->input(i)); const void* input_lb = input_buf->data(); - OP_REQUIRES( - context, input_lb >= backing_tensor_lb, - errors::InvalidArgument("Lower bound check fail for input ", i, - " to node ", context->op_kernel().name())); const void* input_ub = static_cast<const void*>( static_cast<const char*>(input_lb) + input_buf->size()); OP_REQUIRES( + context, input_lb >= backing_tensor_lb, + errors::InvalidArgument( + "Lower bound check fail for input ", i, " from node ", + context->op_kernel().requested_input(i), " to node ", + context->op_kernel().name(), " input bounds = [", input_lb, ", ", + input_ub, "]", " backing_tensor bounds = [", backing_tensor_lb, + ", ", backing_tensor_ub, "]")); + OP_REQUIRES( context, input_ub <= backing_tensor_ub, - errors::InvalidArgument("Upper bound check fail for input ", i, - " to node ", context->op_kernel().name())); + errors::InvalidArgument( + "Upper bound check fail for input ", i, " from node ", + context->op_kernel().requested_input(i), " to node ", + context->op_kernel().name(), " input bounds = [", input_lb, ", ", + input_ub, "]", " backing_tensor bounds = [", backing_tensor_lb, + ", ", backing_tensor_ub, "]")); } + VLOG(1) << "_ScopedAllocatorConcatOp outputting backing tensor at " + << backing_buf; } private: @@ -144,6 +158,7 @@ class ScopedAllocatorConcatOp : public OpKernel { DataType dtype_; string name_; int32 id_; + bool reshape_; DeviceBase* device_; }; |