about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/kernels/mkl_relu_op.cc
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2017-06-26 14:00:17 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-06-26 14:04:35 -0700
commit: 1fa73c53ab95693f070ce70e6be0c644d83c163a (patch)
tree: ffbedf825daf1f3453c695a433c8a9cdf93f6019 /tensorflow/core/kernels/mkl_relu_op.cc
parent: b13e96e21c1229a905a623111dd89d2bd0cba53b (diff)
Automated g4 rollback of changelist 160182040
PiperOrigin-RevId: 160190881
Diffstat (limited to 'tensorflow/core/kernels/mkl_relu_op.cc')
-rw-r--r-- tensorflow/core/kernels/mkl_relu_op.cc | 30
1 file changed, 19 insertions, 11 deletions
diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc
index fabecc39a8..10d2937584 100644
--- a/tensorflow/core/kernels/mkl_relu_op.cc
+++ b/tensorflow/core/kernels/mkl_relu_op.cc
@@ -184,31 +184,38 @@ class MklReluGradOp : public OpKernel {
dnnLayout_t lt_input, lt_grad;
void MklPrepareReluGradInputs(OpKernelContext* context,
+ Tensor* mkl_tmp_grad_buf_tensor,
Tensor* mkl_tmp_input_buf_tensor) {
+ dnnPrimitive_t cv_user_to_reluB_input, cv_user_to_reluB_grad;
+ dnnLayout_t mkl_lt_internal_input, mkl_lt_internal_grad;
+
const Tensor& g = MklGetInput(context, 0);
const Tensor& a = MklGetInput(context, 1);
- void* buf_input = static_cast<void*>(const_cast<T*>(a.flat<T>().data()));
+
+ void* user_i = static_cast<void*>(const_cast<T*>(a.flat<T>().data()));
+ void* user_g = static_cast<void*>(const_cast<T*>(g.flat<T>().data()));
+ dnnPrimitive_t cv_input_to_grad = NULL;
+ Tensor mkl_tmp_buf_tensor;
void* mkl_buffer_convert = nullptr;
- dnnPrimitive_t cv_input_to_grad = nullptr;
// if input and grad are not in the same layout, do a conversion between
// them.
if (!dnnLayoutCompare_F32(lt_input, lt_grad)) {
- AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, lt_grad,
+ AllocTmpBuffer(context, &mkl_tmp_buf_tensor, lt_grad,
&mkl_buffer_convert);
- CHECK_EQ(dnnConversionCreate_F32(&cv_input_to_grad, lt_input,
- lt_grad), E_SUCCESS);
- CHECK_EQ(dnnConversionExecute_F32(cv_input_to_grad, buf_input,
+ CHECK_EQ(dnnConversionCreate_F32(&cv_input_to_grad, lt_input, lt_grad),
+ E_SUCCESS);
+
+ CHECK_EQ(dnnConversionExecute_F32(cv_input_to_grad, user_i,
mkl_buffer_convert),
E_SUCCESS);
relu_res[dnnResourceSrc] = mkl_buffer_convert;
dnnDelete_F32(cv_input_to_grad);
} else {
- relu_res[dnnResourceSrc] = buf_input;
+ relu_res[dnnResourceSrc] = user_i;
}
- void* buf_grad = static_cast<void*>(const_cast<T*>(g.flat<T>().data()));
- relu_res[dnnResourceDiffDst] = buf_grad;
+ relu_res[dnnResourceDiffDst] = user_g;
}
void MklCreateInputLayouts(OpKernelContext* context) {
@@ -310,8 +317,9 @@ void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) {
mkl_context.lt_grad, mkl_context.lt_grad,
negative_slope),
E_SUCCESS);
- Tensor mkl_tmp_input_buf_tensor;
- mkl_context.MklPrepareReluGradInputs(context, &mkl_tmp_input_buf_tensor);
+ Tensor mkl_tmp_grad_buf_tensor, mkl_tmp_input_buf_tensor;
+ mkl_context.MklPrepareReluGradInputs(context, &mkl_tmp_grad_buf_tensor,
+ &mkl_tmp_input_buf_tensor);
if (input_is_mkl ||
grad_is_mkl) { /*if grad or input are MKL leave it in MKL*/