Diffstat (limited to 'tensorflow/core/kernels/mkl_relu_op.cc')
-rw-r--r--  tensorflow/core/kernels/mkl_relu_op.cc  30
1 file changed, 11 insertions(+), 19 deletions(-)
diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc
index 10d2937584..fabecc39a8 100644
--- a/tensorflow/core/kernels/mkl_relu_op.cc
+++ b/tensorflow/core/kernels/mkl_relu_op.cc
@@ -184,38 +184,31 @@ class MklReluGradOp : public OpKernel {
dnnLayout_t lt_input, lt_grad;
void MklPrepareReluGradInputs(OpKernelContext* context,
- Tensor* mkl_tmp_grad_buf_tensor,
Tensor* mkl_tmp_input_buf_tensor) {
- dnnPrimitive_t cv_user_to_reluB_input, cv_user_to_reluB_grad;
- dnnLayout_t mkl_lt_internal_input, mkl_lt_internal_grad;
-
const Tensor& g = MklGetInput(context, 0);
const Tensor& a = MklGetInput(context, 1);
-
- void* user_i = static_cast<void*>(const_cast<T*>(a.flat<T>().data()));
- void* user_g = static_cast<void*>(const_cast<T*>(g.flat<T>().data()));
- dnnPrimitive_t cv_input_to_grad = NULL;
- Tensor mkl_tmp_buf_tensor;
+ void* buf_input = static_cast<void*>(const_cast<T*>(a.flat<T>().data()));
void* mkl_buffer_convert = nullptr;
+ dnnPrimitive_t cv_input_to_grad = nullptr;
// if input and grad are not in the same layout, do a conversion between
// them.
if (!dnnLayoutCompare_F32(lt_input, lt_grad)) {
- AllocTmpBuffer(context, &mkl_tmp_buf_tensor, lt_grad,
+ AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, lt_grad,
&mkl_buffer_convert);
- CHECK_EQ(dnnConversionCreate_F32(&cv_input_to_grad, lt_input, lt_grad),
- E_SUCCESS);
-
- CHECK_EQ(dnnConversionExecute_F32(cv_input_to_grad, user_i,
+ CHECK_EQ(dnnConversionCreate_F32(&cv_input_to_grad, lt_input,
+ lt_grad), E_SUCCESS);
+ CHECK_EQ(dnnConversionExecute_F32(cv_input_to_grad, buf_input,
mkl_buffer_convert),
E_SUCCESS);
relu_res[dnnResourceSrc] = mkl_buffer_convert;
dnnDelete_F32(cv_input_to_grad);
} else {
- relu_res[dnnResourceSrc] = user_i;
+ relu_res[dnnResourceSrc] = buf_input;
}
- relu_res[dnnResourceDiffDst] = user_g;
+ void* buf_grad = static_cast<void*>(const_cast<T*>(g.flat<T>().data()));
+ relu_res[dnnResourceDiffDst] = buf_grad;
}
void MklCreateInputLayouts(OpKernelContext* context) {
@@ -317,9 +310,8 @@ void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) {
mkl_context.lt_grad, mkl_context.lt_grad,
negative_slope),
E_SUCCESS);
- Tensor mkl_tmp_grad_buf_tensor, mkl_tmp_input_buf_tensor;
- mkl_context.MklPrepareReluGradInputs(context, &mkl_tmp_grad_buf_tensor,
- &mkl_tmp_input_buf_tensor);
+ Tensor mkl_tmp_input_buf_tensor;
+ mkl_context.MklPrepareReluGradInputs(context, &mkl_tmp_input_buf_tensor);
if (input_is_mkl ||
grad_is_mkl) { /*if grad or input are MKL leave it in MKL*/
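
For readability, here is roughly how MklPrepareReluGradInputs reads after this patch, assembled from the hunks above (relu_res, lt_input, and lt_grad are class members visible in the surrounding context; the comments are editorial). The likely motivation for the change is lifetime: the conversion scratch buffer is now backed by a Tensor owned by the caller in Compute(), so it stays alive until the ReLU-grad primitive executes, whereas the old function-local mkl_tmp_buf_tensor went out of scope as soon as this helper returned.

  // Editorial sketch of the helper after the patch; not verbatim source.
  void MklPrepareReluGradInputs(OpKernelContext* context,
                                Tensor* mkl_tmp_input_buf_tensor) {
    const Tensor& g = MklGetInput(context, 0);  // incoming gradient
    const Tensor& a = MklGetInput(context, 1);  // forward-pass input

    void* buf_input = static_cast<void*>(const_cast<T*>(a.flat<T>().data()));
    void* mkl_buffer_convert = nullptr;
    dnnPrimitive_t cv_input_to_grad = nullptr;

    // If the input and gradient layouts differ, convert the input into the
    // gradient's layout. The scratch buffer lives in the caller-owned
    // mkl_tmp_input_buf_tensor, so it remains valid after this helper returns.
    if (!dnnLayoutCompare_F32(lt_input, lt_grad)) {
      AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, lt_grad,
                     &mkl_buffer_convert);
      CHECK_EQ(dnnConversionCreate_F32(&cv_input_to_grad, lt_input, lt_grad),
               E_SUCCESS);
      CHECK_EQ(dnnConversionExecute_F32(cv_input_to_grad, buf_input,
                                        mkl_buffer_convert),
               E_SUCCESS);
      relu_res[dnnResourceSrc] = mkl_buffer_convert;
      dnnDelete_F32(cv_input_to_grad);
    } else {
      // Layouts already match; feed the input buffer directly.
      relu_res[dnnResourceSrc] = buf_input;
    }

    void* buf_grad = static_cast<void*>(const_cast<T*>(g.flat<T>().data()));
    relu_res[dnnResourceDiffDst] = buf_grad;
  }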