aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/mkl_relu_op.cc
diff options
context:
space:
mode:
authorGravatar Guozhong Zhuang <guozhong.zhuang@intel.com>2018-08-01 09:35:31 -0700
committerGravatar Guozhong Zhuang <guozhong.zhuang@intel.com>2018-08-01 09:35:31 -0700
commit478c4161f2524f9e9a6b78f7de297dc7d194d37a (patch)
treed6a8f1085faa2c860e29a9cf5e84b6445e323725 /tensorflow/core/kernels/mkl_relu_op.cc
parentf0f9a6119136fd48ae0c5eec4169e5e2feac563b (diff)
Code changes based on Rasmus's code review suggestions on PR19403 and enhancing MklInputConversion for MKL-DNN v0.15 integration
Diffstat (limited to 'tensorflow/core/kernels/mkl_relu_op.cc')
-rw-r--r--tensorflow/core/kernels/mkl_relu_op.cc73
1 file changed, 37 insertions, 36 deletions
diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc
index 3d5a05be73..69f2e37b61 100644
--- a/tensorflow/core/kernels/mkl_relu_op.cc
+++ b/tensorflow/core/kernels/mkl_relu_op.cc
@@ -83,8 +83,9 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
// Eltwise forward execute
// src_data: input data buffer of src
// dst_data: output data buffer of dst
- void Execute(T* src_data, T* dst_data) {
- context_.src_mem->set_data_handle(static_cast<void*>(src_data));
+ void Execute(const T* src_data, T* dst_data) {
+ context_.src_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(src_data)));
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
context_.fwd_stream->submit(context_.fwd_primitives);
@@ -261,10 +262,11 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
// src_data: input data buffer of src
// diff_dst_data: input data buffer of diff_dst
// diff_src_data: output data buffer of diff_src
-
- void Execute(T* src_data, T* diff_dst_data, T* diff_src_data) {
- context_.src_mem->set_data_handle(static_cast<void*>(src_data));
- context_.diff_dst_mem->set_data_handle(static_cast<void*>(diff_dst_data));
+ void Execute(const T* src_data, const T* diff_dst_data, T* diff_src_data) {
+ context_.src_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(src_data)));
+ context_.diff_dst_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(diff_dst_data)));
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
context_.bwd_stream->submit(context_.bwd_primitives);
@@ -810,17 +812,15 @@ class MklReluOpBase : public OpKernel {
MklEltwiseFwdPrimitiveFactory<T>::Get(fwdParams);
// prepare for execution
- T* src_data = nullptr;
+ const T* src_data = src_tensor.flat<T>().data();
// check whether src needs to reorder
if (src_md.data.format != eltwise_fwd->GetSrcMemoryFormat()) {
src.SetUsrMem(src_md, &src_tensor);
auto src_target_pd = memory::primitive_desc({{src_dims},
MklDnnType<T>(), eltwise_fwd->GetSrcMemoryFormat()}, cpu_engine);
src.CheckReorderToOpMem(src_target_pd);
- src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
- } else {
- src_data = static_cast<T*>(
- const_cast<T*>(src_tensor.flat<T>().data()));
+ src_data = const_cast<T*>(
+ reinterpret_cast<T*>(src.GetOpMem().get_data_handle()));
}
// allocate dst tensor, always set it as MKL-DNN layout
@@ -836,20 +836,20 @@ class MklReluOpBase : public OpKernel {
dnn_shape_dst.SetTfLayout(dnn_shape_src.GetDimension(),
dnn_shape_src.GetSizesAsMklDnnDims(),
dnn_shape_src.GetTfDataFormat());
- tf_shape_dst.AddDim(dst_pd.get_size()/sizeof(T));
+ tf_shape_dst.AddDim(dst_pd.get_size() / sizeof(T));
} else {
- // TODO(yli135): why relu's input is TF tensor in VGG16??
dnn_shape_dst.SetMklTensor(false);
tf_shape_dst = src_tensor.shape();
}
Tensor* dst_tensor = nullptr;
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
- {src_index}, dst_index, tf_shape_dst, &dst_tensor));
+ {static_cast<const int>(src_index)},
+ static_cast<const int>(dst_index),
+ tf_shape_dst, &dst_tensor));
AllocateOutputSetMklShape(context, dst_index, dnn_shape_dst);
- T* dst_data = static_cast<T*>(const_cast<T*>(
- dst_tensor->flat<T>().data()));
+ T* dst_data = dst_tensor->flat<T>().data();
// execute eltwise
eltwise_fwd->Execute(src_data, dst_data);
@@ -874,8 +874,8 @@ class MklReluGradOpBase : public OpKernel {
public:
~MklReluGradOpBase() {}
- explicit MklReluGradOpBase(OpKernelConstruction* context) :
- OpKernel(context) {
+ explicit MklReluGradOpBase(OpKernelConstruction* context)
+ : OpKernel(context) {
}
virtual void Compute_Scalar(OpKernelContext* context) = 0;
@@ -964,41 +964,43 @@ class MklReluGradOpBase : public OpKernel {
auto eltwise_bwd_pd = eltwise_bwd->GetEltwiseBwdPd();
// check whether need reorder for src / diff_dst
- T* src_data;
- T* diff_dst_data;
+ const T* src_data = src_tensor.flat<T>().data();
if (src_md.data.format != eltwise_bwd->GetSrcMemoryFormat()) {
src.SetUsrMem(src_md, &src_tensor);
src.CheckReorderToOpMem(
eltwise_bwd_pd.get()->diff_src_primitive_desc());
- src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
- } else {
- src_data = static_cast<T*>(
- const_cast<T*>(src_tensor.flat<T>().data()));
+ src_data = const_cast<T*>(
+ reinterpret_cast<T*>(src.GetOpMem().get_data_handle()));
}
+ const T* diff_dst_data = diff_dst_tensor.flat<T>().data();
if (diff_dst_md.data.format != eltwise_bwd->GetDiffDstMemoryFormat()) {
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
diff_dst.CheckReorderToOpMem(
eltwise_bwd_pd.get()->diff_src_primitive_desc());
- diff_dst_data = static_cast<T*>(
- diff_dst.GetOpMem().get_data_handle());
- } else {
- diff_dst_data = static_cast<T*>(const_cast<T*>(
- diff_dst_tensor.flat<T>().data()));
+ diff_dst_data = const_cast<T*>(
+ reinterpret_cast<T*>(diff_dst.GetOpMem().get_data_handle()));
}
// allocate diff_src tensor
MklDnnShape dnn_shape_diff_src;
TensorShape tf_shape_diff_src;
- if (dnn_shape_src.IsMklTensor()) {
+ if (dnn_shape_src.IsMklTensor() ||
+ dnn_shape_diff_dst.IsMklTensor()) {
auto diff_src_pd = eltwise_bwd_pd->diff_src_primitive_desc();
dnn_shape_diff_src.SetMklTensor(true);
dnn_shape_diff_src.SetMklLayout(&diff_src_pd);
dnn_shape_diff_src.SetElemType(MklDnnType<T>());
- dnn_shape_diff_src.SetTfLayout(dnn_shape_src.GetDimension(),
- dnn_shape_src.GetSizesAsMklDnnDims(),
- dnn_shape_src.GetTfDataFormat());
- tf_shape_diff_src.AddDim(diff_src_pd.get_size()/sizeof(T));
+ if (dnn_shape_src.IsMklTensor()) {
+ dnn_shape_diff_src.SetTfLayout(dnn_shape_src.GetDimension(),
+ dnn_shape_src.GetSizesAsMklDnnDims(),
+ dnn_shape_src.GetTfDataFormat());
+ } else {
+ dnn_shape_diff_src.SetTfLayout(dnn_shape_diff_dst.GetDimension(),
+ dnn_shape_diff_dst.GetSizesAsMklDnnDims(),
+ dnn_shape_diff_dst.GetTfDataFormat());
+ }
+ tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T));
} else {
dnn_shape_diff_src.SetMklTensor(false);
tf_shape_diff_src = src_tensor.shape();
@@ -1009,8 +1011,7 @@ class MklReluGradOpBase : public OpKernel {
&diff_src_tensor));
AllocateOutputSetMklShape(context, diff_src_index, dnn_shape_diff_src);
- T* diff_src_data = static_cast<T*>(const_cast<T*>(
- diff_src_tensor->flat<T>().data()));
+ T* diff_src_data = diff_src_tensor->flat<T>().data();
// execute eltwise bwd
eltwise_bwd->Execute(src_data, diff_dst_data, diff_src_data);