aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/kernels/mkl_relu_op.cc23
1 files changed, 15 insertions, 8 deletions
diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc
index b78b763fd6..f4cfc48af5 100644
--- a/tensorflow/core/kernels/mkl_relu_op.cc
+++ b/tensorflow/core/kernels/mkl_relu_op.cc
@@ -35,6 +35,7 @@ using mkldnn::prop_kind;
using mkldnn::relu_backward;
using mkldnn::relu_forward;
using mkldnn::stream;
+using mkldnn::memory;
#else
#include "mkl_dnn.h"
#include "mkl_dnn_types.h"
@@ -867,11 +868,12 @@ class MklReluOpBase : public OpKernel {
eltwise_fwd->Execute(src_data, dst_data);
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
- ", message: " + string(e.message) + ", in file " +
- string(__FILE__) + ":" + std::to_string(__LINE__);
- OP_REQUIRES_OK(
- context,
- errors::Aborted("Operation received an exception:", error_msg));
+ ", message: " + string(e.message) +
+ ", in file " + string(__FILE__) + ":" +
+ std::to_string(__LINE__);
+ OP_REQUIRES_OK(context,
+ errors::Aborted("Operation received an exception:",
+ error_msg));
}
}
@@ -886,7 +888,8 @@ class MklReluGradOpBase : public OpKernel {
~MklReluGradOpBase() {}
explicit MklReluGradOpBase(OpKernelConstruction* context)
- : OpKernel(context) {}
+ : OpKernel(context) {
+ }
virtual void Compute_Scalar(OpKernelContext* context) = 0;
@@ -942,8 +945,12 @@ class MklReluGradOpBase : public OpKernel {
dnn_shape_diff_dst.GetTfDataFormat();
auto diff_dst_tf_data_format =
MklDnnDataFormatToTFDataFormat(diff_dst_mkl_data_format);
- src_dims = TFShapeToMklDnnDimsInNCHW(src_tensor.shape(),
- diff_dst_tf_data_format);
+
+ src_dims = (src_tensor.dims() == 4)
+ ? TFShapeToMklDnnDimsInNCHW(src_tensor.shape(),
+ diff_dst_tf_data_format)
+ : TFShapeToMklDnnDimsInNCDHW(src_tensor.shape(),
+ diff_dst_tf_data_format);
src_md =
memory::desc(src_dims, MklDnnType<T>(), diff_dst_mkl_data_format);
} else {