Diffstat (limited to 'tensorflow/core/kernels/mkl_relu_op.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_relu_op.cc | 60
1 file changed, 22 insertions, 38 deletions
diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc
index 25c8359cc5..0c66f73141 100644
--- a/tensorflow/core/kernels/mkl_relu_op.cc
+++ b/tensorflow/core/kernels/mkl_relu_op.cc
@@ -16,17 +16,17 @@ limitations under the License.
 // See docs in ../ops/nn_ops.cc.
 
 #ifdef INTEL_MKL
 
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/numeric_op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
-#include "third_party/mkl/include/mkl_dnn.h"
-#include "third_party/mkl/include/mkl_dnn_types.h"
 #include "tensorflow/core/platform/default/logging.h"
 #include "tensorflow/core/util/mkl_util.h"
+#include "third_party/mkl/include/mkl_dnn.h"
+#include "third_party/mkl/include/mkl_dnn_types.h"
 
 namespace tensorflow {
@@ -194,45 +194,29 @@ class MklReluGradOp : public OpKernel {
 
       void* user_i = static_cast<void*>(const_cast<T*>(a.flat<T>().data()));
       void* user_g = static_cast<void*>(const_cast<T*>(g.flat<T>().data()));
-
-      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
-                   &mkl_lt_internal_grad, prim_relu_bwd, dnnResourceDiffDst),
-               E_SUCCESS);
-
-      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_input,
-                                                prim_relu_bwd, dnnResourceSrc),
-               E_SUCCESS);
-
-      if (!dnnLayoutCompare_F32(mkl_lt_internal_grad, lt_grad)) {
-        AllocTmpBuffer(context, mkl_tmp_grad_buf_tensor, mkl_lt_internal_grad,
-                       &relu_res[dnnResourceDiffDst]);
-        CHECK_EQ(dnnConversionCreate_F32(&cv_user_to_reluB_grad, lt_grad,
-                                         mkl_lt_internal_grad),
+      dnnPrimitive_t cv_input_to_grad = NULL;
+      Tensor mkl_tmp_buf_tensor;
+      void* mkl_buffer_convert = nullptr;
+
+      // if input and grad are not in the same layout, do a conversion between
+      // them.
+      if (!dnnLayoutCompare_F32(lt_input, lt_grad)) {
+        AllocTmpBuffer(context, &mkl_tmp_buf_tensor, lt_grad,
+                       &mkl_buffer_convert);
+        CHECK_EQ(dnnConversionCreate_F32(&cv_input_to_grad, lt_input, lt_grad),
                  E_SUCCESS);
-        CHECK_EQ(dnnConversionExecute_F32(cv_user_to_reluB_grad, user_g,
-                                          relu_res[dnnResourceDiffDst]),
-                 E_SUCCESS);
-        dnnDelete_F32(cv_user_to_reluB_grad);
-      } else {
-        relu_res[dnnResourceDiffDst] = user_g;
-      }
-      if (!dnnLayoutCompare_F32(mkl_lt_internal_input, lt_input)) {
-        AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, mkl_lt_internal_input,
-                       &relu_res[dnnResourceSrc]);
-        CHECK_EQ(dnnConversionCreate_F32(&cv_user_to_reluB_input, lt_input,
-                                         mkl_lt_internal_input),
-                 E_SUCCESS);
-        CHECK_EQ(dnnConversionExecute_F32(cv_user_to_reluB_input, user_i,
-                                          relu_res[dnnResourceSrc]),
+        CHECK_EQ(dnnConversionExecute_F32(cv_input_to_grad, user_i,
+                                          mkl_buffer_convert),
                  E_SUCCESS);
-        dnnDelete_F32(cv_user_to_reluB_input);
+        relu_res[dnnResourceSrc] = mkl_buffer_convert;
+        dnnDelete_F32(cv_input_to_grad);
       } else {
         relu_res[dnnResourceSrc] = user_i;
       }
 
-      dnnLayoutDelete_F32(mkl_lt_internal_input);
-      dnnLayoutDelete_F32(mkl_lt_internal_grad);
+      relu_res[dnnResourceDiffDst] = user_g;
     }
 
     void MklCreateInputLayouts(OpKernelContext* context) {
@@ -331,7 +315,7 @@ void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) {
   mkl_context.MklCreateInputLayouts(context);
   float negative_slope = 0.0;
   CHECK_EQ(dnnReLUCreateBackward_F32(&mkl_context.prim_relu_bwd, NULL,
-                                     mkl_context.lt_grad, mkl_context.lt_input,
+                                     mkl_context.lt_grad, mkl_context.lt_grad,
                                      negative_slope),
            E_SUCCESS);
   Tensor mkl_tmp_grad_buf_tensor, mkl_tmp_input_buf_tensor;
@@ -380,12 +364,12 @@ void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) {
 /* Register DNN kernels for supported operations and supported types - right now
  * it is only Relu and f32*/
 #define REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES(type)      \
-  REGISTER_KERNEL_BUILDER(Name("_MklRelu")                    \
+  REGISTER_KERNEL_BUILDER(Name("_MklRelu")                   \
                               .Device(DEVICE_CPU)            \
                               .TypeConstraint<type>("T")     \
                               .Label(mkl_op_registry::kMklOpLabel), \
                           MklReluOp<CPUDevice, type>);       \
-  REGISTER_KERNEL_BUILDER(Name("_MklReluGrad")                \
+  REGISTER_KERNEL_BUILDER(Name("_MklReluGrad")               \
                               .Device(DEVICE_CPU)            \
                               .TypeConstraint<type>("T")     \
                               .Label(mkl_op_registry::kMklOpLabel), \
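For readers unfamiliar with the MKL DNN C API used above, the standalone sketch below replays the patched control flow outside TensorFlow: compare the input and gradient layouts with dnnLayoutCompare_F32, convert the input into the gradient's layout only when they differ, then create the ReLU backward primitive with lt_grad for both layout arguments, as the new Compute() does. This is illustrative only, not kernel code: the sizes, strides, buffers, and the CHECK_MKL macro are inventions of this sketch, and the kernel's AllocTmpBuffer is replaced by a plain dnnAllocateBuffer_F32.

// layout_convert_sketch.cc -- illustrative only, not TensorFlow code.
#include <cstdio>

#include "mkl_dnn.h"
#include "mkl_dnn_types.h"

// Mirrors the kernel's CHECK_EQ(..., E_SUCCESS) pattern without the TF macro.
#define CHECK_MKL(call)                                    \
  do {                                                     \
    dnnError_t e = (call);                                 \
    if (e != E_SUCCESS) {                                  \
      std::fprintf(stderr, "MKL DNN error %d\n", (int)e);  \
      return 1;                                            \
    }                                                      \
  } while (0)

int main() {
  // Two layouts for the same logical 4-D tensor (8x8x2x1 = 128 floats).
  // The differing strides stand in for the input vs. gradient layouts the
  // kernel compares; the values themselves are made up.
  size_t sizes[4] = {8, 8, 2, 1};
  size_t strides_input[4] = {1, 8, 64, 128};  // plain row-major
  size_t strides_grad[4] = {8, 1, 64, 128};   // same data, inner dims swapped

  dnnLayout_t lt_input = nullptr, lt_grad = nullptr;
  CHECK_MKL(dnnLayoutCreate_F32(&lt_input, 4, sizes, strides_input));
  CHECK_MKL(dnnLayoutCreate_F32(&lt_grad, 4, sizes, strides_grad));

  float user_i[128] = {}, user_g[128] = {}, diff_src[128] = {};
  void* relu_res[dnnResourceNumber] = {};
  void* mkl_buffer_convert = nullptr;

  // Same test as the patched kernel: convert only when the layouts differ.
  if (!dnnLayoutCompare_F32(lt_input, lt_grad)) {
    dnnPrimitive_t cv_input_to_grad = nullptr;
    CHECK_MKL(dnnAllocateBuffer_F32(&mkl_buffer_convert, lt_grad));
    CHECK_MKL(dnnConversionCreate_F32(&cv_input_to_grad, lt_input, lt_grad));
    CHECK_MKL(dnnConversionExecute_F32(cv_input_to_grad, user_i,
                                       mkl_buffer_convert));
    relu_res[dnnResourceSrc] = mkl_buffer_convert;
    dnnDelete_F32(cv_input_to_grad);
  } else {
    relu_res[dnnResourceSrc] = user_i;
  }
  relu_res[dnnResourceDiffDst] = user_g;
  relu_res[dnnResourceDiffSrc] = diff_src;

  // As in the patched Compute(): src now lives in the grad layout, so
  // lt_grad is passed for both the diff and data layout arguments.
  dnnPrimitive_t prim_relu_bwd = nullptr;
  CHECK_MKL(dnnReLUCreateBackward_F32(&prim_relu_bwd, nullptr, lt_grad,
                                      lt_grad, 0.0f));
  CHECK_MKL(dnnExecute_F32(prim_relu_bwd, relu_res));
  std::printf("relu backward ran; input converted: %s\n",
              mkl_buffer_convert ? "yes" : "no");

  dnnDelete_F32(prim_relu_bwd);
  if (mkl_buffer_convert) dnnReleaseBuffer_F32(mkl_buffer_convert);
  dnnLayoutDelete_F32(lt_input);
  dnnLayoutDelete_F32(lt_grad);
  return 0;
}

Converting the input into the gradient's existing layout, rather than converting both tensors into layouts derived from the primitive, drops one conversion plus the dnnLayoutCreateFromPrimitive_F32/dnnLayoutDelete_F32 bookkeeping, which appears to account for the 22-insertion/38-deletion diffstat above.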