diff options
author | 2017-06-26 14:00:17 -0700 | |
---|---|---|
committer | 2017-06-26 14:04:35 -0700 | |
commit | 1fa73c53ab95693f070ce70e6be0c644d83c163a (patch) | |
tree | ffbedf825daf1f3453c695a433c8a9cdf93f6019 /tensorflow/core/kernels/mkl_relu_op.cc | |
parent | b13e96e21c1229a905a623111dd89d2bd0cba53b (diff) |
Automated g4 rollback of changelist 160182040
PiperOrigin-RevId: 160190881
Diffstat (limited to 'tensorflow/core/kernels/mkl_relu_op.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_relu_op.cc | 30 |
1 file changed, 19 insertions, 11 deletions
diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc index fabecc39a8..10d2937584 100644 --- a/tensorflow/core/kernels/mkl_relu_op.cc +++ b/tensorflow/core/kernels/mkl_relu_op.cc @@ -184,31 +184,38 @@ class MklReluGradOp : public OpKernel { dnnLayout_t lt_input, lt_grad; void MklPrepareReluGradInputs(OpKernelContext* context, + Tensor* mkl_tmp_grad_buf_tensor, Tensor* mkl_tmp_input_buf_tensor) { + dnnPrimitive_t cv_user_to_reluB_input, cv_user_to_reluB_grad; + dnnLayout_t mkl_lt_internal_input, mkl_lt_internal_grad; + const Tensor& g = MklGetInput(context, 0); const Tensor& a = MklGetInput(context, 1); - void* buf_input = static_cast<void*>(const_cast<T*>(a.flat<T>().data())); + + void* user_i = static_cast<void*>(const_cast<T*>(a.flat<T>().data())); + void* user_g = static_cast<void*>(const_cast<T*>(g.flat<T>().data())); + dnnPrimitive_t cv_input_to_grad = NULL; + Tensor mkl_tmp_buf_tensor; void* mkl_buffer_convert = nullptr; - dnnPrimitive_t cv_input_to_grad = nullptr; // if input and grad are not in the same layout, do a conversion between // them. 
if (!dnnLayoutCompare_F32(lt_input, lt_grad)) { - AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, lt_grad, + AllocTmpBuffer(context, &mkl_tmp_buf_tensor, lt_grad, &mkl_buffer_convert); - CHECK_EQ(dnnConversionCreate_F32(&cv_input_to_grad, lt_input, - lt_grad), E_SUCCESS); - CHECK_EQ(dnnConversionExecute_F32(cv_input_to_grad, buf_input, + CHECK_EQ(dnnConversionCreate_F32(&cv_input_to_grad, lt_input, lt_grad), + E_SUCCESS); + + CHECK_EQ(dnnConversionExecute_F32(cv_input_to_grad, user_i, mkl_buffer_convert), E_SUCCESS); relu_res[dnnResourceSrc] = mkl_buffer_convert; dnnDelete_F32(cv_input_to_grad); } else { - relu_res[dnnResourceSrc] = buf_input; + relu_res[dnnResourceSrc] = user_i; } - void* buf_grad = static_cast<void*>(const_cast<T*>(g.flat<T>().data())); - relu_res[dnnResourceDiffDst] = buf_grad; + relu_res[dnnResourceDiffDst] = user_g; } void MklCreateInputLayouts(OpKernelContext* context) { @@ -310,8 +317,9 @@ void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) { mkl_context.lt_grad, mkl_context.lt_grad, negative_slope), E_SUCCESS); - Tensor mkl_tmp_input_buf_tensor; - mkl_context.MklPrepareReluGradInputs(context, &mkl_tmp_input_buf_tensor); + Tensor mkl_tmp_grad_buf_tensor, mkl_tmp_input_buf_tensor; + mkl_context.MklPrepareReluGradInputs(context, &mkl_tmp_grad_buf_tensor, + &mkl_tmp_input_buf_tensor); if (input_is_mkl || grad_is_mkl) { /*if grad or input are MKL leave it in MKL*/ |