aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/scoped_allocator_ops.cc
diff options
context:
space:
mode:
authorGravatar Ayush Dubey <ayushd@google.com>2018-05-02 16:13:06 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-02 16:56:08 -0700
commitc7a5787fef8daf3e44313cbd48591464f9643f56 (patch)
treef0ba7403e0dbda3fca7d3d78dee972e7f214ccad /tensorflow/core/kernels/scoped_allocator_ops.cc
parent1f4efb78320e1406c0cc9ce4b8753f3d2511048e (diff)
Enable reshape of _ScopedAllocatorConcat output.
The _ScopedAllocatorConcat kernel outputs the backing tensor after performing runtime bounds checks. However, the shape of the backing tensor may not match the desired output shape of the concat operation. This change adds a "reshape" boolean attribute to _ScopedAllocatorConcat kernel. When this attribute is set to true, the kernel outputs a reshaped backing tensor according to the "shape" attribute. PiperOrigin-RevId: 195169105
Diffstat (limited to 'tensorflow/core/kernels/scoped_allocator_ops.cc')
-rw-r--r--tensorflow/core/kernels/scoped_allocator_ops.cc39
1 files changed, 27 insertions, 12 deletions
diff --git a/tensorflow/core/kernels/scoped_allocator_ops.cc b/tensorflow/core/kernels/scoped_allocator_ops.cc
index d7b25ffad0..1800ee8c1f 100644
--- a/tensorflow/core/kernels/scoped_allocator_ops.cc
+++ b/tensorflow/core/kernels/scoped_allocator_ops.cc
@@ -94,7 +94,8 @@ class ScopedAllocatorConcatOp : public OpKernel {
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
OP_REQUIRES_OK(context, context->GetAttr("T", &dtype_));
- // This stuff is just for debugging
+ OP_REQUIRES_OK(context, context->GetAttr("reshape", &reshape_));
+ // These attributes are just for debugging.
OP_REQUIRES_OK(context, context->GetAttr("sa_name", &name_));
OP_REQUIRES_OK(context, context->GetAttr("id", &id_));
device_ = context->device();
@@ -114,11 +115,14 @@ class ScopedAllocatorConcatOp : public OpKernel {
backing_tensor.NumElements(),
" is not equal to expected ",
shape_.num_elements()));
- VLOG(1) << "_ScopedAllocatorConcatOp outputting backing tensor at "
- << DMAHelper::base(&backing_tensor);
- Tensor backing_copy(backing_tensor);
- context->set_output(0, backing_copy);
- const TensorBuffer* backing_buf = DMAHelper::buffer(&backing_copy);
+ Tensor output(dtype_);
+ if (reshape_) {
+ CHECK(output.CopyFrom(backing_tensor, shape_));
+ } else {
+ CHECK(output.CopyFrom(backing_tensor, backing_tensor.shape()));
+ }
+ context->set_output(0, output);
+ const TensorBuffer* backing_buf = DMAHelper::buffer(&output);
const void* backing_tensor_lb = backing_buf->data();
const void* backing_tensor_ub = static_cast<const void*>(
static_cast<const char*>(backing_tensor_lb) + backing_buf->size());
@@ -126,17 +130,27 @@ class ScopedAllocatorConcatOp : public OpKernel {
for (int i = 1; i < context->num_inputs(); ++i) {
const TensorBuffer* input_buf = DMAHelper::buffer(&context->input(i));
const void* input_lb = input_buf->data();
- OP_REQUIRES(
- context, input_lb >= backing_tensor_lb,
- errors::InvalidArgument("Lower bound check fail for input ", i,
- " to node ", context->op_kernel().name()));
const void* input_ub = static_cast<const void*>(
static_cast<const char*>(input_lb) + input_buf->size());
OP_REQUIRES(
+ context, input_lb >= backing_tensor_lb,
+ errors::InvalidArgument(
+ "Lower bound check fail for input ", i, " from node ",
+ context->op_kernel().requested_input(i), " to node ",
+ context->op_kernel().name(), " input bounds = [", input_lb, ", ",
+ input_ub, "]", " backing_tensor bounds = [", backing_tensor_lb,
+ ", ", backing_tensor_ub, "]"));
+ OP_REQUIRES(
context, input_ub <= backing_tensor_ub,
- errors::InvalidArgument("Upper bound check fail for input ", i,
- " to node ", context->op_kernel().name()));
+ errors::InvalidArgument(
+ "Upper bound check fail for input ", i, " from node ",
+ context->op_kernel().requested_input(i), " to node ",
+ context->op_kernel().name(), " input bounds = [", input_lb, ", ",
+ input_ub, "]", " backing_tensor bounds = [", backing_tensor_lb,
+ ", ", backing_tensor_ub, "]"));
}
+ VLOG(1) << "_ScopedAllocatorConcatOp outputting backing tensor at "
+ << backing_buf;
}
private:
@@ -144,6 +158,7 @@ class ScopedAllocatorConcatOp : public OpKernel {
DataType dtype_;
string name_;
int32 id_;
+ bool reshape_;
DeviceBase* device_;
};