diff options
author | Guozhong Zhuang <guozhong.zhuang@intel.com> | 2018-08-01 09:35:31 -0700 |
---|---|---|
committer | Guozhong Zhuang <guozhong.zhuang@intel.com> | 2018-08-01 09:35:31 -0700 |
commit | 478c4161f2524f9e9a6b78f7de297dc7d194d37a (patch) | |
tree | d6a8f1085faa2c860e29a9cf5e84b6445e323725 /tensorflow/core/kernels/mkl_relu_op.cc | |
parent | f0f9a6119136fd48ae0c5eec4169e5e2feac563b (diff) |
Code changes based on Rasmus's code review suggestions on PR19403 and enhancing MklInputConversion for MKL-DNN v0.15 integration
Diffstat (limited to 'tensorflow/core/kernels/mkl_relu_op.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_relu_op.cc | 73 |
1 files changed, 37 insertions, 36 deletions
diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc index 3d5a05be73..69f2e37b61 100644 --- a/tensorflow/core/kernels/mkl_relu_op.cc +++ b/tensorflow/core/kernels/mkl_relu_op.cc @@ -83,8 +83,9 @@ class MklEltwiseFwdPrimitive : public MklPrimitive { // Eltwise forward execute // src_data: input data buffer of src // dst_data: output data buffer of dst - void Execute(T* src_data, T* dst_data) { - context_.src_mem->set_data_handle(static_cast<void*>(src_data)); + void Execute(const T* src_data, T* dst_data) { + context_.src_mem->set_data_handle( + static_cast<void*>(const_cast<T*>(src_data))); context_.dst_mem->set_data_handle(static_cast<void*>(dst_data)); context_.fwd_stream->submit(context_.fwd_primitives); @@ -261,10 +262,11 @@ class MklEltwiseBwdPrimitive : public MklPrimitive { // src_data: input data buffer of src // diff_dst_data: input data buffer of diff_dst // diff_src_data: output data buffer of diff_src - - void Execute(T* src_data, T* diff_dst_data, T* diff_src_data) { - context_.src_mem->set_data_handle(static_cast<void*>(src_data)); - context_.diff_dst_mem->set_data_handle(static_cast<void*>(diff_dst_data)); + void Execute(const T* src_data, const T* diff_dst_data, T* diff_src_data) { + context_.src_mem->set_data_handle( + static_cast<void*>(const_cast<T*>(src_data))); + context_.diff_dst_mem->set_data_handle( + static_cast<void*>(const_cast<T*>(diff_dst_data))); context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data)); context_.bwd_stream->submit(context_.bwd_primitives); @@ -810,17 +812,15 @@ class MklReluOpBase : public OpKernel { MklEltwiseFwdPrimitiveFactory<T>::Get(fwdParams); // prepare for execuation - T* src_data = nullptr; + const T* src_data = src_tensor.flat<T>().data(); // check wehther src need to reorder if (src_md.data.format != eltwise_fwd->GetSrcMemoryFormat()) { src.SetUsrMem(src_md, &src_tensor); auto src_target_pd = memory::primitive_desc({{src_dims}, MklDnnType<T>(), eltwise_fwd->GetSrcMemoryFormat()}, cpu_engine); src.CheckReorderToOpMem(src_target_pd); - src_data = static_cast<T*>(src.GetOpMem().get_data_handle()); - } else { - src_data = static_cast<T*>( - const_cast<T*>(src_tensor.flat<T>().data())); + src_data = const_cast<T*>( + reinterpret_cast<T*>(src.GetOpMem().get_data_handle())); } // allocate dst tensor, always set it as MKL-DNN layout @@ -836,20 +836,20 @@ class MklReluOpBase : public OpKernel { dnn_shape_dst.SetTfLayout(dnn_shape_src.GetDimension(), dnn_shape_src.GetSizesAsMklDnnDims(), dnn_shape_src.GetTfDataFormat()); - tf_shape_dst.AddDim(dst_pd.get_size()/sizeof(T)); + tf_shape_dst.AddDim(dst_pd.get_size() / sizeof(T)); } else { - // TODO(yli135): why relu's input is TF tensor in VGG16?? dnn_shape_dst.SetMklTensor(false); tf_shape_dst = src_tensor.shape(); } Tensor* dst_tensor = nullptr; OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( - {src_index}, dst_index, tf_shape_dst, &dst_tensor)); + {static_cast<const int>(src_index)}, + static_cast<const int>(dst_index), + tf_shape_dst, &dst_tensor)); AllocateOutputSetMklShape(context, dst_index, dnn_shape_dst); - T* dst_data = static_cast<T*>(const_cast<T*>( - dst_tensor->flat<T>().data())); + T* dst_data = dst_tensor->flat<T>().data(); // execute eltwise eltwise_fwd->Execute(src_data, dst_data); @@ -874,8 +874,8 @@ class MklReluGradOpBase : public OpKernel { public: ~MklReluGradOpBase() {} - explicit MklReluGradOpBase(OpKernelConstruction* context) : - OpKernel(context) { + explicit MklReluGradOpBase(OpKernelConstruction* context) + : OpKernel(context) { } virtual void Compute_Scalar(OpKernelContext* context) = 0; @@ -964,41 +964,43 @@ class MklReluGradOpBase : public OpKernel { auto eltwise_bwd_pd = eltwise_bwd->GetEltwiseBwdPd(); // check whether need reorder for src / diff_dst - T* src_data; - T* diff_dst_data; + const T* src_data = src_tensor.flat<T>().data(); if (src_md.data.format != eltwise_bwd->GetSrcMemoryFormat()) { src.SetUsrMem(src_md, &src_tensor); src.CheckReorderToOpMem( eltwise_bwd_pd.get()->diff_src_primitive_desc()); - src_data = static_cast<T*>(src.GetOpMem().get_data_handle()); - } else { - src_data = static_cast<T*>( - const_cast<T*>(src_tensor.flat<T>().data())); + src_data = const_cast<T*>( + reinterpret_cast<T*>(src.GetOpMem().get_data_handle())); } + const T* diff_dst_data = diff_dst_tensor.flat<T>().data(); if (diff_dst_md.data.format != eltwise_bwd->GetDiffDstMemoryFormat()) { diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor); diff_dst.CheckReorderToOpMem( eltwise_bwd_pd.get()->diff_src_primitive_desc()); - diff_dst_data = static_cast<T*>( - diff_dst.GetOpMem().get_data_handle()); - } else { - diff_dst_data = static_cast<T*>(const_cast<T*>( - diff_dst_tensor.flat<T>().data())); + diff_dst_data = const_cast<T*>( + reinterpret_cast<T*>(diff_dst.GetOpMem().get_data_handle())); } // allocate diff_src tensor MklDnnShape dnn_shape_diff_src; TensorShape tf_shape_diff_src; - if (dnn_shape_src.IsMklTensor()) { + if (dnn_shape_src.IsMklTensor() || + dnn_shape_diff_dst.IsMklTensor()) { auto diff_src_pd = eltwise_bwd_pd->diff_src_primitive_desc(); dnn_shape_diff_src.SetMklTensor(true); dnn_shape_diff_src.SetMklLayout(&diff_src_pd); dnn_shape_diff_src.SetElemType(MklDnnType<T>()); - dnn_shape_diff_src.SetTfLayout(dnn_shape_src.GetDimension(), - dnn_shape_src.GetSizesAsMklDnnDims(), - dnn_shape_src.GetTfDataFormat()); - tf_shape_diff_src.AddDim(diff_src_pd.get_size()/sizeof(T)); + if (dnn_shape_src.IsMklTensor()) { + dnn_shape_diff_src.SetTfLayout(dnn_shape_src.GetDimension(), + dnn_shape_src.GetSizesAsMklDnnDims(), + dnn_shape_src.GetTfDataFormat()); + } else { + dnn_shape_diff_src.SetTfLayout(dnn_shape_diff_dst.GetDimension(), + dnn_shape_diff_dst.GetSizesAsMklDnnDims(), + dnn_shape_diff_dst.GetTfDataFormat()); + } + tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T)); } else { dnn_shape_diff_src.SetMklTensor(false); tf_shape_diff_src = src_tensor.shape(); @@ -1009,8 +1011,7 @@ class MklReluGradOpBase : public OpKernel { &diff_src_tensor)); AllocateOutputSetMklShape(context, diff_src_index, dnn_shape_diff_src); - T* diff_src_data = static_cast<T*>(const_cast<T*>( - diff_src_tensor->flat<T>().data())); + T* diff_src_data = diff_src_tensor->flat<T>().data(); // execute eltwise bwd eltwise_bwd->Execute(src_data, diff_dst_data, diff_src_data); |