diff options
-rw-r--r-- | tensorflow/core/kernels/mkl_relu_op.cc | 23 |
1 files changed, 15 insertions, 8 deletions
diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc index b78b763fd6..f4cfc48af5 100644 --- a/tensorflow/core/kernels/mkl_relu_op.cc +++ b/tensorflow/core/kernels/mkl_relu_op.cc @@ -35,6 +35,7 @@ using mkldnn::prop_kind; using mkldnn::relu_backward; using mkldnn::relu_forward; using mkldnn::stream; +using mkldnn::memory; #else #include "mkl_dnn.h" #include "mkl_dnn_types.h" @@ -867,11 +868,12 @@ class MklReluOpBase : public OpKernel { eltwise_fwd->Execute(src_data, dst_data); } catch (mkldnn::error& e) { string error_msg = "Status: " + std::to_string(e.status) + - ", message: " + string(e.message) + ", in file " + - string(__FILE__) + ":" + std::to_string(__LINE__); - OP_REQUIRES_OK( - context, - errors::Aborted("Operation received an exception:", error_msg)); + ", message: " + string(e.message) + + ", in file " + string(__FILE__) + ":" + + std::to_string(__LINE__); + OP_REQUIRES_OK(context, + errors::Aborted("Operation received an exception:", + error_msg)); } } @@ -886,7 +888,8 @@ class MklReluGradOpBase : public OpKernel { ~MklReluGradOpBase() {} explicit MklReluGradOpBase(OpKernelConstruction* context) - : OpKernel(context) {} + : OpKernel(context) { + } virtual void Compute_Scalar(OpKernelContext* context) = 0; @@ -942,8 +945,12 @@ class MklReluGradOpBase : public OpKernel { dnn_shape_diff_dst.GetTfDataFormat(); auto diff_dst_tf_data_format = MklDnnDataFormatToTFDataFormat(diff_dst_mkl_data_format); - src_dims = TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), - diff_dst_tf_data_format); + + src_dims = (src_tensor.dims() == 4) + ? TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), + diff_dst_tf_data_format) + : TFShapeToMklDnnDimsInNCDHW(src_tensor.shape(), + diff_dst_tf_data_format); src_md = memory::desc(src_dims, MklDnnType<T>(), diff_dst_mkl_data_format); } else { |