diff options
author | 2018-06-12 15:54:37 -0700 | |
---|---|---|
committer | 2018-06-12 15:54:37 -0700 | |
commit | f369de2bb9f28c36b8b654db3dbd4dd187482c22 (patch) | |
tree | c8f54bae3bb2e47ceb5729c5f03e4f35ce5c6259 /tensorflow/core/kernels/mkl_relu_op.cc | |
parent | 2bcd873e839c66b2405226508286da371dd8afbe (diff) |
code refactoring per Rasmus's suggestions on PR 19754
Diffstat (limited to 'tensorflow/core/kernels/mkl_relu_op.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_relu_op.cc | 20 |
1 files changed, 10 insertions, 10 deletions
diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc index 048d4883b2..a52c879721 100644 --- a/tensorflow/core/kernels/mkl_relu_op.cc +++ b/tensorflow/core/kernels/mkl_relu_op.cc @@ -65,7 +65,8 @@ class MklEltwiseFwdParams { template <typename T> class MklEltwiseFwdPrimitive : public MklPrimitive { public: - explicit MklEltwiseFwdPrimitive(const MklEltwiseFwdParams<T>& fwdParams) { + explicit MklEltwiseFwdPrimitive(const MklEltwiseFwdParams<T>& fwdParams) : + cpu_engine_(engine::cpu, 0) { // store expected format context_.src_fmt = static_cast<mkldnn::memory::format>( fwdParams.src_md.data.format); @@ -90,7 +91,6 @@ class MklEltwiseFwdPrimitive : public MklPrimitive { // after execution, set data handle back context_.src_mem->set_data_handle(DummyData); context_.dst_mem->set_data_handle(DummyData); - return; } std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> GetEltwiseFwdPd() { @@ -133,7 +133,7 @@ class MklEltwiseFwdPrimitive : public MklPrimitive { fwd_desc(nullptr), fwd_pd(nullptr), src_md(nullptr), dst_md(nullptr), src_mpd(nullptr), eltwise_fwd(nullptr), fwd_stream(nullptr) { } - } context_; + }; // Eltwise forward primitive setup void Setup(const MklEltwiseFwdParams<T>& fwdParams) { @@ -159,10 +159,10 @@ class MklEltwiseFwdPrimitive : public MklPrimitive { *context_.src_mem, *context_.dst_mem)); context_.fwd_primitives.push_back(*context_.eltwise_fwd); - return; } - engine cpu_engine_ = engine(engine::cpu, 0); + struct EltwiseFwdContext context_; + engine cpu_engine_; }; template <typename T> @@ -242,7 +242,8 @@ class MklEltwiseBwdParams { template <typename T> class MklEltwiseBwdPrimitive : public MklPrimitive { public: - explicit MklEltwiseBwdPrimitive(const MklEltwiseBwdParams<T>& bwdParams) { + explicit MklEltwiseBwdPrimitive(const MklEltwiseBwdParams<T>& bwdParams) : + cpu_engine_(engine::cpu, 0) { context_.src_fmt = static_cast<mkldnn::memory::format>( bwdParams.common_md.data.format); context_.diff_dst_fmt = static_cast<mkldnn::memory::format>( @@ -271,7 +272,6 @@ class MklEltwiseBwdPrimitive : public MklPrimitive { context_.src_mem->set_data_handle(DummyData); context_.diff_dst_mem->set_data_handle(DummyData); context_.diff_src_mem->set_data_handle(DummyData); - return; } std::shared_ptr<mkldnn::eltwise_backward::primitive_desc> GetEltwiseBwdPd() { @@ -329,7 +329,7 @@ class MklEltwiseBwdPrimitive : public MklPrimitive { fwd_desc(nullptr), fwd_pd(nullptr), bwd_pd(nullptr), eltwise_bwd(nullptr), bwd_stream(nullptr) { } - } context_; + }; // Eltwise backward primitive setup void Setup(const MklEltwiseBwdParams<T>& bwdParams) { @@ -365,10 +365,10 @@ class MklEltwiseBwdPrimitive : public MklPrimitive { *context_.src_mem, *context_.diff_dst_mem, *context_.diff_src_mem)); context_.bwd_primitives.push_back(*context_.eltwise_bwd); - return; } - engine cpu_engine_ = engine(engine::cpu, 0); + struct EltwiseBwdContext context_; + engine cpu_engine_; }; |