about summary refs log tree commit diff homepage
path: root/tensorflow/core/kernels/mkl_relu_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/mkl_relu_op.cc')
-rw-r--r--  tensorflow/core/kernels/mkl_relu_op.cc | 60
1 file changed, 22 insertions, 38 deletions
diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc
index 25c8359cc5..0c66f73141 100644
--- a/tensorflow/core/kernels/mkl_relu_op.cc
+++ b/tensorflow/core/kernels/mkl_relu_op.cc
@@ -16,17 +16,17 @@ limitations under the License.
// See docs in ../ops/nn_ops.cc.
#ifdef INTEL_MKL
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "third_party/mkl/include/mkl_dnn.h"
-#include "third_party/mkl/include/mkl_dnn_types.h"
#include "tensorflow/core/platform/default/logging.h"
#include "tensorflow/core/util/mkl_util.h"
+#include "third_party/mkl/include/mkl_dnn.h"
+#include "third_party/mkl/include/mkl_dnn_types.h"
namespace tensorflow {
@@ -194,45 +194,29 @@ class MklReluGradOp : public OpKernel {
void* user_i = static_cast<void*>(const_cast<T*>(a.flat<T>().data()));
void* user_g = static_cast<void*>(const_cast<T*>(g.flat<T>().data()));
-
- CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
- &mkl_lt_internal_grad, prim_relu_bwd, dnnResourceDiffDst),
- E_SUCCESS);
-
- CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_input,
- prim_relu_bwd, dnnResourceSrc),
- E_SUCCESS);
-
- if (!dnnLayoutCompare_F32(mkl_lt_internal_grad, lt_grad)) {
- AllocTmpBuffer(context, mkl_tmp_grad_buf_tensor, mkl_lt_internal_grad,
- &relu_res[dnnResourceDiffDst]);
- CHECK_EQ(dnnConversionCreate_F32(&cv_user_to_reluB_grad, lt_grad,
- mkl_lt_internal_grad),
+ dnnPrimitive_t cv_input_to_grad = NULL;
+ Tensor mkl_tmp_buf_tensor;
+ void* mkl_buffer_convert = nullptr;
+
+ // if input and grad are not in the same layout, do a conversion between
+ // them.
+ if (!dnnLayoutCompare_F32(lt_input, lt_grad)) {
+ AllocTmpBuffer(context, &mkl_tmp_buf_tensor, lt_grad,
+ &mkl_buffer_convert);
+ CHECK_EQ(dnnConversionCreate_F32(&cv_input_to_grad, lt_input, lt_grad),
E_SUCCESS);
- CHECK_EQ(dnnConversionExecute_F32(cv_user_to_reluB_grad, user_g,
- relu_res[dnnResourceDiffDst]),
- E_SUCCESS);
- dnnDelete_F32(cv_user_to_reluB_grad);
- } else {
- relu_res[dnnResourceDiffDst] = user_g;
- }
- if (!dnnLayoutCompare_F32(mkl_lt_internal_input, lt_input)) {
- AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, mkl_lt_internal_input,
- &relu_res[dnnResourceSrc]);
- CHECK_EQ(dnnConversionCreate_F32(&cv_user_to_reluB_input, lt_input,
- mkl_lt_internal_input),
- E_SUCCESS);
- CHECK_EQ(dnnConversionExecute_F32(cv_user_to_reluB_input, user_i,
- relu_res[dnnResourceSrc]),
+ CHECK_EQ(dnnConversionExecute_F32(cv_input_to_grad, user_i,
+ mkl_buffer_convert),
E_SUCCESS);
- dnnDelete_F32(cv_user_to_reluB_input);
+ relu_res[dnnResourceSrc] = mkl_buffer_convert;
+ dnnDelete_F32(cv_input_to_grad);
} else {
relu_res[dnnResourceSrc] = user_i;
}
- dnnLayoutDelete_F32(mkl_lt_internal_input);
- dnnLayoutDelete_F32(mkl_lt_internal_grad);
+ relu_res[dnnResourceDiffDst] = user_g;
+
}
void MklCreateInputLayouts(OpKernelContext* context) {
@@ -331,7 +315,7 @@ void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) {
mkl_context.MklCreateInputLayouts(context);
float negative_slope = 0.0;
CHECK_EQ(dnnReLUCreateBackward_F32(&mkl_context.prim_relu_bwd, NULL,
- mkl_context.lt_grad, mkl_context.lt_input,
+ mkl_context.lt_grad, mkl_context.lt_grad,
negative_slope),
E_SUCCESS);
Tensor mkl_tmp_grad_buf_tensor, mkl_tmp_input_buf_tensor;
@@ -380,12 +364,12 @@ void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) {
/* Register DNN kernels for supported operations and supported types - right now
* it is only Relu and f32*/
#define REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES(type) \
- REGISTER_KERNEL_BUILDER(Name("_MklRelu") \
+ REGISTER_KERNEL_BUILDER(Name("_MklRelu") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.Label(mkl_op_registry::kMklOpLabel), \
MklReluOp<CPUDevice, type>); \
- REGISTER_KERNEL_BUILDER(Name("_MklReluGrad") \
+ REGISTER_KERNEL_BUILDER(Name("_MklReluGrad") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.Label(mkl_op_registry::kMklOpLabel), \