aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/mkl_relu_op.cc
diff options
context:
space:
mode:
authorGravatar Guozhong Zhuang <guozhong.zhuang@intel.com>2018-06-12 15:54:37 -0700
committerGravatar Guozhong Zhuang <guozhong.zhuang@intel.com>2018-06-12 15:54:37 -0700
commitf369de2bb9f28c36b8b654db3dbd4dd187482c22 (patch)
treec8f54bae3bb2e47ceb5729c5f03e4f35ce5c6259 /tensorflow/core/kernels/mkl_relu_op.cc
parent2bcd873e839c66b2405226508286da371dd8afbe (diff)
code refactoring per Rasmus's suggestions on PR 19754
Diffstat (limited to 'tensorflow/core/kernels/mkl_relu_op.cc')
-rw-r--r--tensorflow/core/kernels/mkl_relu_op.cc20
1 files changed, 10 insertions, 10 deletions
diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc
index 048d4883b2..a52c879721 100644
--- a/tensorflow/core/kernels/mkl_relu_op.cc
+++ b/tensorflow/core/kernels/mkl_relu_op.cc
@@ -65,7 +65,8 @@ class MklEltwiseFwdParams {
template <typename T>
class MklEltwiseFwdPrimitive : public MklPrimitive {
public:
- explicit MklEltwiseFwdPrimitive(const MklEltwiseFwdParams<T>& fwdParams) {
+ explicit MklEltwiseFwdPrimitive(const MklEltwiseFwdParams<T>& fwdParams) :
+ cpu_engine_(engine::cpu, 0) {
// store expected format
context_.src_fmt = static_cast<mkldnn::memory::format>(
fwdParams.src_md.data.format);
@@ -90,7 +91,6 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
// after execution, set data handle back
context_.src_mem->set_data_handle(DummyData);
context_.dst_mem->set_data_handle(DummyData);
- return;
}
std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> GetEltwiseFwdPd() {
@@ -133,7 +133,7 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
fwd_desc(nullptr), fwd_pd(nullptr), src_md(nullptr), dst_md(nullptr),
src_mpd(nullptr), eltwise_fwd(nullptr), fwd_stream(nullptr) {
}
- } context_;
+ };
// Eltwise forward primitive setup
void Setup(const MklEltwiseFwdParams<T>& fwdParams) {
@@ -159,10 +159,10 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
*context_.src_mem, *context_.dst_mem));
context_.fwd_primitives.push_back(*context_.eltwise_fwd);
- return;
}
- engine cpu_engine_ = engine(engine::cpu, 0);
+ struct EltwiseFwdContext context_;
+ engine cpu_engine_;
};
template <typename T>
@@ -242,7 +242,8 @@ class MklEltwiseBwdParams {
template <typename T>
class MklEltwiseBwdPrimitive : public MklPrimitive {
public:
- explicit MklEltwiseBwdPrimitive(const MklEltwiseBwdParams<T>& bwdParams) {
+ explicit MklEltwiseBwdPrimitive(const MklEltwiseBwdParams<T>& bwdParams) :
+ cpu_engine_(engine::cpu, 0) {
context_.src_fmt = static_cast<mkldnn::memory::format>(
bwdParams.common_md.data.format);
context_.diff_dst_fmt = static_cast<mkldnn::memory::format>(
@@ -271,7 +272,6 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
context_.src_mem->set_data_handle(DummyData);
context_.diff_dst_mem->set_data_handle(DummyData);
context_.diff_src_mem->set_data_handle(DummyData);
- return;
}
std::shared_ptr<mkldnn::eltwise_backward::primitive_desc> GetEltwiseBwdPd() {
@@ -329,7 +329,7 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
fwd_desc(nullptr), fwd_pd(nullptr), bwd_pd(nullptr),
eltwise_bwd(nullptr), bwd_stream(nullptr) {
}
- } context_;
+ };
// Eltwise backward primitive setup
void Setup(const MklEltwiseBwdParams<T>& bwdParams) {
@@ -365,10 +365,10 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
*context_.src_mem, *context_.diff_dst_mem, *context_.diff_src_mem));
context_.bwd_primitives.push_back(*context_.eltwise_bwd);
- return;
}
- engine cpu_engine_ = engine(engine::cpu, 0);
+ struct EltwiseBwdContext context_;
+ engine cpu_engine_;
};