diff options
Diffstat (limited to 'tensorflow/core/kernels/mkl_relu_op.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_relu_op.cc | 30 |
1 files changed, 11 insertions, 19 deletions
diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc index 10d2937584..fabecc39a8 100644 --- a/tensorflow/core/kernels/mkl_relu_op.cc +++ b/tensorflow/core/kernels/mkl_relu_op.cc @@ -184,38 +184,31 @@ class MklReluGradOp : public OpKernel { dnnLayout_t lt_input, lt_grad; void MklPrepareReluGradInputs(OpKernelContext* context, - Tensor* mkl_tmp_grad_buf_tensor, Tensor* mkl_tmp_input_buf_tensor) { - dnnPrimitive_t cv_user_to_reluB_input, cv_user_to_reluB_grad; - dnnLayout_t mkl_lt_internal_input, mkl_lt_internal_grad; - const Tensor& g = MklGetInput(context, 0); const Tensor& a = MklGetInput(context, 1); - - void* user_i = static_cast<void*>(const_cast<T*>(a.flat<T>().data())); - void* user_g = static_cast<void*>(const_cast<T*>(g.flat<T>().data())); - dnnPrimitive_t cv_input_to_grad = NULL; - Tensor mkl_tmp_buf_tensor; + void* buf_input = static_cast<void*>(const_cast<T*>(a.flat<T>().data())); void* mkl_buffer_convert = nullptr; + dnnPrimitive_t cv_input_to_grad = nullptr; // if input and grad are not in the same layout, do a conversion between // them. if (!dnnLayoutCompare_F32(lt_input, lt_grad)) { - AllocTmpBuffer(context, &mkl_tmp_buf_tensor, lt_grad, + AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, lt_grad, &mkl_buffer_convert); - CHECK_EQ(dnnConversionCreate_F32(&cv_input_to_grad, lt_input, lt_grad), - E_SUCCESS); - - CHECK_EQ(dnnConversionExecute_F32(cv_input_to_grad, user_i, + CHECK_EQ(dnnConversionCreate_F32(&cv_input_to_grad, lt_input, + lt_grad), E_SUCCESS); + CHECK_EQ(dnnConversionExecute_F32(cv_input_to_grad, buf_input, mkl_buffer_convert), E_SUCCESS); relu_res[dnnResourceSrc] = mkl_buffer_convert; dnnDelete_F32(cv_input_to_grad); } else { - relu_res[dnnResourceSrc] = user_i; + relu_res[dnnResourceSrc] = buf_input; } - relu_res[dnnResourceDiffDst] = user_g; + void* buf_grad = static_cast<void*>(const_cast<T*>(g.flat<T>().data())); + relu_res[dnnResourceDiffDst] = buf_grad; } void MklCreateInputLayouts(OpKernelContext* context) { @@ -317,9 +310,8 @@ void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) { mkl_context.lt_grad, mkl_context.lt_grad, negative_slope), E_SUCCESS); - Tensor mkl_tmp_grad_buf_tensor, mkl_tmp_input_buf_tensor; - mkl_context.MklPrepareReluGradInputs(context, &mkl_tmp_grad_buf_tensor, - &mkl_tmp_input_buf_tensor); + Tensor mkl_tmp_input_buf_tensor; + mkl_context.MklPrepareReluGradInputs(context, &mkl_tmp_input_buf_tensor); if (input_is_mkl || grad_is_mkl) { /*if grad or input are MKL leave it in MKL*/ |