aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/mkl_relu_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/mkl_relu_op.cc')
-rw-r--r--tensorflow/core/kernels/mkl_relu_op.cc23
1 files changed, 16 insertions, 7 deletions
diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc
index 267f4f8d12..0a0f69522f 100644
--- a/tensorflow/core/kernels/mkl_relu_op.cc
+++ b/tensorflow/core/kernels/mkl_relu_op.cc
@@ -392,7 +392,7 @@ class MklReluOpBase : public OpKernel {
Tensor* dst_tensor = nullptr;
if (src_tensor.dims() == 0) {
- Compute_Scalar(context);
+ Compute_Scalar(context); // scalar case doesn't use in-place operation
return;
}
@@ -437,11 +437,15 @@ class MklReluOpBase : public OpKernel {
dnn_shape_dst.SetMklTensor(false);
tf_shape_dst = src_tensor.shape();
}
- AllocateOutputSetMklShape(context, dst_index, &dst_tensor, tf_shape_dst,
- dnn_shape_dst);
+
+ // Allocate output and MklDnnShape tensors separately for possible
+ // in-place operation
+ OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
+ {src_index}, dst_index, tf_shape_dst, &dst_tensor));
+ AllocateOutputSetMklShape(context, dst_index, dnn_shape_dst);
// Destination memory descriptor is same as source memory descriptor.
- auto dst_md = src_md;
+ auto &dst_md = src_md;
dst.SetUsrMem(dst_md, dst_tensor);
// execute net
@@ -492,7 +496,7 @@ class MklReluGradOpBase : public OpKernel {
int src_dims_size = src_tensor.dims();
if (src_dims_size == 0) {
- Compute_Scalar(context);
+ Compute_Scalar(context); // scalar case doesn't use in-place operation
return;
}
@@ -603,8 +607,13 @@ class MklReluGradOpBase : public OpKernel {
// so it is ok to get TensorFlow shape.
tf_shape_diff_src = src_tensor.shape();
}
- AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
- tf_shape_diff_src, dnn_shape_diff_src);
+
+ // Allocate diff_src and MklDnnShape tensors separately for possible
+ // in-place operation
+ OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
+ {diff_dst_index}, diff_src_index, tf_shape_diff_src,
+ &diff_src_tensor));
+ AllocateOutputSetMklShape(context, diff_src_index, dnn_shape_diff_src);
// diff_src memory descriptor is same as memory descriptor for both
// inputs.