diff options
Diffstat (limited to 'tensorflow/core/kernels/mkl_relu_op.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_relu_op.cc | 23 |
1 files changed, 7 insertions, 16 deletions
diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc index 0a0f69522f..267f4f8d12 100644 --- a/tensorflow/core/kernels/mkl_relu_op.cc +++ b/tensorflow/core/kernels/mkl_relu_op.cc @@ -392,7 +392,7 @@ class MklReluOpBase : public OpKernel { Tensor* dst_tensor = nullptr; if (src_tensor.dims() == 0) { - Compute_Scalar(context); // scalar case doesn't use in-place operation + Compute_Scalar(context); return; } @@ -437,15 +437,11 @@ class MklReluOpBase : public OpKernel { dnn_shape_dst.SetMklTensor(false); tf_shape_dst = src_tensor.shape(); } - - // Allocate output and MklDnnShape tensors separately for possible - // in-place operation - OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( - {src_index}, dst_index, tf_shape_dst, &dst_tensor)); - AllocateOutputSetMklShape(context, dst_index, dnn_shape_dst); + AllocateOutputSetMklShape(context, dst_index, &dst_tensor, tf_shape_dst, + dnn_shape_dst); // Destination memory descriptor is same as source memory descriptor. - auto &dst_md = src_md; + auto dst_md = src_md; dst.SetUsrMem(dst_md, dst_tensor); // execute net @@ -496,7 +492,7 @@ class MklReluGradOpBase : public OpKernel { int src_dims_size = src_tensor.dims(); if (src_dims_size == 0) { - Compute_Scalar(context); // scalar case doesn't use in-place operation + Compute_Scalar(context); return; } @@ -607,13 +603,8 @@ class MklReluGradOpBase : public OpKernel { // so it is ok to get TensorFlow shape. tf_shape_diff_src = src_tensor.shape(); } - - // Allocate diff_src and MklDnnShape tensors separately for possible - // in-place operation - OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( - {diff_dst_index}, diff_src_index, tf_shape_diff_src, - &diff_src_tensor)); - AllocateOutputSetMklShape(context, diff_src_index, dnn_shape_diff_src); + AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor, + tf_shape_diff_src, dnn_shape_diff_src); // diff_src memory descriptor is same as memory descriptor for both // inputs. |