aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/mkl_relu_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/mkl_relu_op.cc')
-rw-r--r--tensorflow/core/kernels/mkl_relu_op.cc23
1 files changed, 7 insertions, 16 deletions
diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc
index 0a0f69522f..267f4f8d12 100644
--- a/tensorflow/core/kernels/mkl_relu_op.cc
+++ b/tensorflow/core/kernels/mkl_relu_op.cc
@@ -392,7 +392,7 @@ class MklReluOpBase : public OpKernel {
Tensor* dst_tensor = nullptr;
if (src_tensor.dims() == 0) {
- Compute_Scalar(context); // scalar case doesn't use in-place operation
+ Compute_Scalar(context);
return;
}
@@ -437,15 +437,11 @@ class MklReluOpBase : public OpKernel {
dnn_shape_dst.SetMklTensor(false);
tf_shape_dst = src_tensor.shape();
}
-
- // Allocate output and MklDnnShape tensors separately for possible
- // in-place operation
- OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
- {src_index}, dst_index, tf_shape_dst, &dst_tensor));
- AllocateOutputSetMklShape(context, dst_index, dnn_shape_dst);
+ AllocateOutputSetMklShape(context, dst_index, &dst_tensor, tf_shape_dst,
+ dnn_shape_dst);
// Destination memory descriptor is same as source memory descriptor.
- auto &dst_md = src_md;
+ auto dst_md = src_md;
dst.SetUsrMem(dst_md, dst_tensor);
// execute net
@@ -496,7 +492,7 @@ class MklReluGradOpBase : public OpKernel {
int src_dims_size = src_tensor.dims();
if (src_dims_size == 0) {
- Compute_Scalar(context); // scalar case doesn't use in-place operation
+ Compute_Scalar(context);
return;
}
@@ -607,13 +603,8 @@ class MklReluGradOpBase : public OpKernel {
// so it is ok to get TensorFlow shape.
tf_shape_diff_src = src_tensor.shape();
}
-
- // Allocate diff_src and MklDnnShape tensors separately for possible
- // in-place operation
- OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
- {diff_dst_index}, diff_src_index, tf_shape_diff_src,
- &diff_src_tensor));
- AllocateOutputSetMklShape(context, diff_src_index, dnn_shape_diff_src);
+ AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
+ tf_shape_diff_src, dnn_shape_diff_src);
// diff_src memory descriptor is same as memory descriptor for both
// inputs.