diff options
Diffstat (limited to 'tensorflow/core/kernels/xsmm_conv2d.cc')
-rw-r--r-- | tensorflow/core/kernels/xsmm_conv2d.cc | 482 |
1 files changed, 316 insertions, 166 deletions
diff --git a/tensorflow/core/kernels/xsmm_conv2d.cc b/tensorflow/core/kernels/xsmm_conv2d.cc index 823cdf7e09..878abe9712 100644 --- a/tensorflow/core/kernels/xsmm_conv2d.cc +++ b/tensorflow/core/kernels/xsmm_conv2d.cc @@ -26,14 +26,18 @@ void dummy_xsmm_conv2d_ensure_file_is_not_empty(void); #include "tensorflow/core/kernels/xsmm_conv2d.h" #include <stdlib.h> +#include <cstring> +#if 0 +#include <omp.h> +#endif #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/core/blocking_counter.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "libxsmm_main.h" // TODO(bsteiner): API to avoid incl. header from src/ #include "include/libxsmm_cpuid.h" -#include "libxsmm_dnn_handle.h" -#include "libxsmm_malloc.h" +#include "include/libxsmm_malloc.h" namespace tensorflow { @@ -59,10 +63,6 @@ bool CanUseXsmmConv2D(const libxsmm_dnn_conv_desc& desc, VLOG(1) << "Cannot use XSMM convolutions: unsupported format!"; return false; } - if (desc.pad_h_in != 0 || desc.pad_w_in != 0) { - VLOG(1) << "Cannot use XSMM convolutions: unsupported padding!"; - return false; - } if (desc.K % VECTOR_SIZE != 0) { VLOG(1) << "Cannot use XSMM convolutions: output features count not" " divisible by vector size!"; @@ -72,7 +72,6 @@ bool CanUseXsmmConv2D(const libxsmm_dnn_conv_desc& desc, return true; } - typedef Eigen::ThreadPoolDevice CPUDevice; namespace functor { @@ -83,25 +82,34 @@ static void chk_libxsmm_err(libxsmm_dnn_err_t status, string msg) { } } -LIBXSMM_INLINE void copy_RSCK_to_custom(const float* rsck, float *kcrs, int R, int S, int C, int K,int blocksifm, int blocksofm, int ifmblock,int ofmblock, int start, int end) -{ - LIBXSMM_VLA_DECL(4, const float, input, rsck, S, C,K); - LIBXSMM_VLA_DECL(6, float, output, kcrs, blocksifm,R,S,ifmblock, ofmblock); - int r, s, k,c, v1,v2; - - for (k = start; k < end ; k++ ) { - for(c = 0; c < blocksifm;c++){ - for ( r = 0; r < R; r++ ) { - for ( s = 0; s < S; s++ ){ - for ( v1 = c*ifmblock; v1 < std::min(C,(c+1)*ifmblock) ; v1++ ) { - for ( v2 = k*ofmblock; v2 < std::min(K, (k+1)*ofmblock); v2++ ) - LIBXSMM_VLA_ACCESS(6, output, k,c, r, s,v1- c*ifmblock,v2-k*ofmblock, blocksifm, R, S,ifmblock,ofmblock) = LIBXSMM_VLA_ACCESS(4, input, r, s, v1, v2, S, C, K); - for ( v2 = K; v2 < (k+1)*ofmblock ; v2++ ) - LIBXSMM_VLA_ACCESS(6, output, k,c, r, s,v1- c*ifmblock,v2-k*ofmblock, blocksifm, R, S,ifmblock,ofmblock) = 0.0f; - } - for ( v1 = C; v1 < (c+1)*ifmblock ; v1++ ) { - for ( v2 = k*ofmblock; v2 < (k+1)*ofmblock; v2++ ) - LIBXSMM_VLA_ACCESS(6, output, k,c, r, s,v1- c*ifmblock,v2-k*ofmblock, blocksifm, R, S,ifmblock,ofmblock) = 0.0f; +LIBXSMM_INLINE void copy_RSCK_to_custom(const float* rsck, float* kcrs, int R, + int S, int C, int K, int blocksifm, + int blocksofm, int ifmblock, + int ofmblock, int start, int end) { + LIBXSMM_VLA_DECL(4, const float, input, rsck, S, C, K); + LIBXSMM_VLA_DECL(6, float, output, kcrs, blocksifm, R, S, ifmblock, ofmblock); + int r, s, k, c, v1, v2; + + for (k = start; k < end; k++) { + for (c = 0; c < blocksifm; c++) { + for (r = 0; r < R; r++) { + for (s = 0; s < S; s++) { + for (v1 = c * ifmblock; v1 < std::min(C, (c + 1) * ifmblock); v1++) { + for (v2 = k * ofmblock; v2 < std::min(K, (k + 1) * ofmblock); v2++) + LIBXSMM_VLA_ACCESS(6, output, k, c, r, s, v1 - c * ifmblock, + v2 - k * ofmblock, blocksifm, R, S, ifmblock, + ofmblock) = + LIBXSMM_VLA_ACCESS(4, input, r, s, v1, v2, S, C, K); + for (v2 = K; v2 < (k + 1) * ofmblock; v2++) + LIBXSMM_VLA_ACCESS(6, output, k, c, r, s, v1 - c * ifmblock, + v2 - k * ofmblock, blocksifm, R, S, ifmblock, + ofmblock) = 0.0f; + } + for (v1 = C; v1 < (c + 1) * ifmblock; v1++) { + for (v2 = k * ofmblock; v2 < (k + 1) * ofmblock; v2++) + LIBXSMM_VLA_ACCESS(6, output, k, c, r, s, v1 - c * ifmblock, + v2 - k * ofmblock, blocksifm, R, S, ifmblock, + ofmblock) = 0.0f; } } } @@ -109,35 +117,28 @@ LIBXSMM_INLINE void copy_RSCK_to_custom(const float* rsck, float *kcrs, int R, i } } - +class libxsmm_dnn_conv_desc_wrap { + public: + const libxsmm_dnn_conv_desc d; -class libxsmm_dnn_conv_desc_wrap{ - public: - const libxsmm_dnn_conv_desc d; - - libxsmm_dnn_conv_desc_wrap(const libxsmm_dnn_conv_desc &d_) : d(d_){ - } - bool operator==(const libxsmm_dnn_conv_desc_wrap &w) const{ - return( d.N == w.d.N && - d.C == w.d.C && - d.H == w.d.H && - d.W == w.d.W && - d.K == w.d.K && - d.R == w.d.R && - d.S == w.d.S && - d.u == w.d.u && - d.v == w.d.v && - d.pad_h_in == w.d.pad_h_in && - d.pad_w_in == w.d.pad_w_in - ); - } + libxsmm_dnn_conv_desc_wrap(const libxsmm_dnn_conv_desc& d_) : d(d_) {} + bool operator==(const libxsmm_dnn_conv_desc_wrap& w) const { + return (d.N == w.d.N && d.C == w.d.C && d.H == w.d.H && d.W == w.d.W && + d.K == w.d.K && d.R == w.d.R && d.S == w.d.S && d.u == w.d.u && + d.v == w.d.v && d.pad_h == w.d.pad_h && d.pad_w == w.d.pad_w); + } }; - - -struct HashFunction{ - std::size_t operator()(const libxsmm_dnn_conv_desc_wrap & w) const{ + +struct HashFunction { + std::size_t operator()(const libxsmm_dnn_conv_desc_wrap& w) const { + // unsigned char ptr[sizeof(&w.d)]; + + // memcpy(ptr, (unsigned char *)&w.d, sizeof(&w.d)) + + // + /* std::ostringstream N,C,H,W,K,R,S,u,v,padh,padw; - + N << w.d.N; C << w.d.C; H << w.d.H; W << w.d.W; K << w.d.K; R << w.d.R; @@ -152,59 +153,71 @@ struct HashFunction{ + S.str() + u.str()\ + v.str() + padh.str()\ + padw.str(); - - return ( std::hash<std::string>()(out_)); + // + // + */ + return (std::hash<unsigned long long>()((unsigned long long)&(w.d))); } }; -class handles{ - public: - libxsmm_dnn_layer* find( const libxsmm_dnn_conv_desc_wrap &w) { - std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_layer*, - HashFunction>::iterator i = libxsmm_handles.find(w); - if (i == libxsmm_handles.end()){ - libxsmm_dnn_err_t status; - libxsmm_dnn_layer* libxsmm_handle = - libxsmm_dnn_create_conv_layer(w.d, &status); - chk_libxsmm_err(status, "Create handle"); - libxsmm_handles.insert(std::make_pair(w, libxsmm_handle)); - return libxsmm_handle; - } - else - return i->second; - } - ~handles(){ - std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_layer*, - HashFunction>::iterator i; - for (i= libxsmm_handles.begin(); i != libxsmm_handles.end(); i++) +class handles { + public: + libxsmm_dnn_layer* find(const libxsmm_dnn_conv_desc_wrap& w) { + std::unordered_map<libxsmm_dnn_conv_desc_wrap, libxsmm_dnn_layer*, + HashFunction>::iterator i = libxsmm_handles.find(w); + if (i == libxsmm_handles.end()) { + libxsmm_dnn_err_t status; + libxsmm_dnn_layer* libxsmm_handle = + libxsmm_dnn_create_conv_layer(w.d, &status); + chk_libxsmm_err(status, "Create handle"); + libxsmm_handles.insert(std::make_pair(w, libxsmm_handle)); + return libxsmm_handle; + } else + return i->second; + } + ~handles() { + std::unordered_map<libxsmm_dnn_conv_desc_wrap, libxsmm_dnn_layer*, + HashFunction>::iterator i; + for (i = libxsmm_handles.begin(); i != libxsmm_handles.end(); i++) chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(i->second), - "Destroy handle"); - } - private: - std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_layer*, HashFunction> libxsmm_handles; + "Destroy handle"); + } + + private: + std::unordered_map<libxsmm_dnn_conv_desc_wrap, libxsmm_dnn_layer*, + HashFunction> + libxsmm_handles; }; static handles libxsmm_handles; +//#define LIBXSMM_DETAILED_TIMING + template <typename InputPtr, typename FilterPtr, typename OutputPtr> static bool CallLibxsmmConvGeneric(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc, - libxsmm_dnn_compute_kind kind, InputPtr input, - FilterPtr filter, OutputPtr output) { + libxsmm_dnn_compute_kind kind, + InputPtr input, FilterPtr filter, + OutputPtr output) { +#if defined(LIBXSMM_DETAILED_TIMING) + unsigned long long l_tick1, l_tick2, l_tick3, l_tick4, l_tick5, l_tick6, + l_tick7, l_tick8, l_tick9, l_tick10; + l_tick1 = libxsmm_timer_tick(); +#endif // setup scoped allocator, which adopts the allocator from the context const libxsmm_tf_allocator<libxsmm_scratch_allocator> tf_allocator(*ctx); libxsmm_dnn_err_t status; libxsmm_dnn_layer* libxsmm_handle; libxsmm_dnn_conv_desc_wrap w(desc); void* scratch; - - if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) - libxsmm_handle = libxsmm_handles.find(w); - else { - libxsmm_handle = libxsmm_dnn_create_conv_layer(desc, &status); - chk_libxsmm_err(status, "Create handle"); - } - + + // if(kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) + libxsmm_handle = libxsmm_handles.find(w); + // else{ + // libxsmm_handle = libxsmm_dnn_create_conv_layer(desc, &status); + // chk_libxsmm_err(status, "Create handle"); + //} + status = libxsmm_dnn_get_codegen_success(libxsmm_handle, kind); if (status == LIBXSMM_DNN_WARN_FALLBACK) { chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(libxsmm_handle), @@ -217,100 +230,168 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx, libxsmm_dnn_buffer* libxsmm_output; libxsmm_dnn_filter* libxsmm_filter; - /* - const DeviceBase::CpuWorkerThreads* worker_threads = - ctx->device()->tensorflow_cpu_worker_threads(); - - int num_threads = worker_threads->num_threads; -*/ +#if defined(LIBXSMM_DETAILED_TIMING) + l_tick2 = libxsmm_timer_tick(); +#endif int ifmblock = (libxsmm_handle->ifmblock); int ofmblock = (libxsmm_handle->ofmblock); - int blocksifm = desc.C%ifmblock ==0 ? desc.C/ifmblock :desc.C/ifmblock + 1; - int blocksofm = desc.K%ofmblock ==0 ? desc.K/ofmblock :desc.K/ofmblock + 1; - float *native_filter = (float*)libxsmm_aligned_scratch( - blocksofm*blocksifm*desc.R*desc.S*ifmblock*ofmblock*sizeof(float), - 2097152); + int blocksifm = + desc.C % ifmblock == 0 ? desc.C / ifmblock : desc.C / ifmblock + 1; + int blocksofm = + desc.K % ofmblock == 0 ? desc.K / ofmblock : desc.K / ofmblock + 1; + float* native_filter = + (float*)libxsmm_aligned_scratch(blocksofm * blocksifm * desc.R * desc.S * + ifmblock * ofmblock * sizeof(float), + 2097152); const DeviceBase::CpuWorkerThreads* worker_threads = ctx->device()->tensorflow_cpu_worker_threads(); int num_threads = worker_threads->num_threads; - - if(blocksofm > num_threads){ - int work = blocksofm; - BlockingCounter count(num_threads); - for (int i = 0; i < num_threads; ++i) { +#if 1 + if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD || + kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) { + if (blocksofm > num_threads) { + int work = blocksofm; + BlockingCounter count(num_threads); + for (int i = 0; i < num_threads; ++i) { worker_threads->workers->Schedule([=, &count]() { - int start = work/num_threads*i; - int end = (start + work/num_threads) > work ? work: start + work/num_threads; - copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S,desc.C, desc.K,blocksifm,blocksofm,ifmblock,ofmblock,start, end); - count.DecrementCount(); + int start = work / num_threads * i; + int end = (start + work / num_threads) > work + ? work + : start + work / num_threads; + copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S, desc.C, + desc.K, blocksifm, blocksofm, ifmblock, ofmblock, + start, end); + count.DecrementCount(); }); - } - count.Wait(); - } - else{ + } + count.Wait(); + } else { + int work = blocksofm; + int num_threads = work; - int work = blocksofm; - int num_threads = work; - - BlockingCounter count(num_threads); - for (int i = 0; i < num_threads; ++i) { + BlockingCounter count(num_threads); + for (int i = 0; i < num_threads; ++i) { worker_threads->workers->Schedule([=, &count]() { - int start = i; - int end = i+1; - copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S,desc.C, desc.K,blocksifm,blocksofm,ifmblock,ofmblock, start, end); - count.DecrementCount(); + int start = i; + int end = i + 1; + copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S, desc.C, + desc.K, blocksifm, blocksofm, ifmblock, ofmblock, + start, end); + count.DecrementCount(); }); + } + count.Wait(); } - count.Wait(); } - - libxsmm_input = libxsmm_dnn_link_buffer( - libxsmm_handle, LIBXSMM_DNN_INPUT, input, LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status); + // Added: for weight update + else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) { + libxsmm_filter = + libxsmm_dnn_link_filter(libxsmm_handle, LIBXSMM_DNN_FILTER, filter, + LIBXSMM_DNN_TENSOR_FORMAT_RSCK_PTR, &status); + chk_libxsmm_err(status, + "Link filter"); // weight update is in RSCK as + // filter should be returned in RSCK + // format + } +#else + memset(native_filter, 0, + blocksofm * blocksifm * desc.R * desc.S * ifmblock * ofmblock * + sizeof(float)); +#endif + +#if defined(LIBXSMM_DETAILED_TIMING) + l_tick3 = libxsmm_timer_tick(); +#endif + + libxsmm_input = + libxsmm_dnn_link_buffer(libxsmm_handle, LIBXSMM_DNN_INPUT, input, + LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status); chk_libxsmm_err(status, "Link input buffer"); - libxsmm_output = libxsmm_dnn_link_buffer( - libxsmm_handle, LIBXSMM_DNN_OUTPUT, output, LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status); + libxsmm_output = + libxsmm_dnn_link_buffer(libxsmm_handle, LIBXSMM_DNN_OUTPUT, output, + LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status); chk_libxsmm_err(status, "Link output buffer"); - libxsmm_filter = libxsmm_dnn_link_filter( - libxsmm_handle, LIBXSMM_DNN_FILTER, native_filter, LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM_PTR, &status); - chk_libxsmm_err(status, "Link filter"); - - chk_libxsmm_err(libxsmm_dnn_zero_buffer(libxsmm_output), "Zero output"); - - + if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD || + kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) { + libxsmm_filter = libxsmm_dnn_link_filter( + libxsmm_handle, LIBXSMM_DNN_FILTER, native_filter, + LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM_PTR, &status); + chk_libxsmm_err(status, "Link filter"); + } if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) { - chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input, LIBXSMM_DNN_REGULAR_INPUT), + chk_libxsmm_err(libxsmm_dnn_zero_buffer(libxsmm_output), "Zero output"); + + chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input, + LIBXSMM_DNN_REGULAR_INPUT), "Bind input forward"); - chk_libxsmm_err( - libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output, LIBXSMM_DNN_REGULAR_OUTPUT), - "Bind output forward"); - chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter, LIBXSMM_DNN_REGULAR_FILTER), + chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output, + LIBXSMM_DNN_REGULAR_OUTPUT), + "Bind output forward"); + chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter, + LIBXSMM_DNN_REGULAR_FILTER), "Bind filter forward"); - } else { - chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input, LIBXSMM_DNN_GRADIENT_INPUT), + } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) { + chk_libxsmm_err(libxsmm_dnn_zero_buffer(libxsmm_input), "Zero input"); + + chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input, + LIBXSMM_DNN_GRADIENT_INPUT), "Bind input backward"); - chk_libxsmm_err( - libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output, LIBXSMM_DNN_GRADIENT_OUTPUT), - "Bind output backward"); - chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter, LIBXSMM_DNN_REGULAR_FILTER), + chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output, + LIBXSMM_DNN_GRADIENT_OUTPUT), + "Bind output backward"); + chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter, + LIBXSMM_DNN_REGULAR_FILTER), "Bind filter backward"); + } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) { + chk_libxsmm_err(libxsmm_dnn_zero_filter(libxsmm_filter), "Zero filter"); + + chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input, + LIBXSMM_DNN_REGULAR_INPUT), + "Bind input weight udpate"); + chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output, + LIBXSMM_DNN_GRADIENT_OUTPUT), + "Bind output weight update"); + chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter, + LIBXSMM_DNN_GRADIENT_FILTER), + "Bind filter weight update"); + } else { + /* shouldn't happen */ } +#if defined(LIBXSMM_DETAILED_TIMING) + l_tick4 = libxsmm_timer_tick(); +#endif + /* bind scratch */ - scratch = (void*)libxsmm_aligned_scratch( libxsmm_dnn_get_scratch_size( libxsmm_handle, kind, &status ), 2097152); - chk_libxsmm_err( status, "scratch allocation" ); - chk_libxsmm_err( libxsmm_dnn_bind_scratch( libxsmm_handle, kind, scratch ), "binding scratch" ); + scratch = (void*)libxsmm_aligned_scratch( + libxsmm_dnn_get_scratch_size(libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, + &status), + 2097152); + chk_libxsmm_err(status, "scratch allocation"); + chk_libxsmm_err(libxsmm_dnn_bind_scratch( + libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, scratch), + "binding scratch"); + +#if defined(LIBXSMM_DETAILED_TIMING) + l_tick5 = libxsmm_timer_tick(); +#endif if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) { libxsmm_dnn_transpose_filter(libxsmm_handle, LIBXSMM_DNN_FILTER); } +#if defined(LIBXSMM_DETAILED_TIMING) + l_tick6 = libxsmm_timer_tick(); +#endif + +#if 1 BlockingCounter counter(num_threads); - + for (int i = 0; i < num_threads; ++i) { worker_threads->workers->Schedule([=, &counter]() { chk_libxsmm_err(libxsmm_dnn_execute_st(libxsmm_handle, kind, 0, i), @@ -319,28 +400,97 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx, }); } counter.Wait(); +#else +#pragma omp parallel + { + chk_libxsmm_err( + libxsmm_dnn_execute_st(libxsmm_handle, kind, 0, omp_get_thread_num()), + "Worker"); + } +#endif + +#if defined(LIBXSMM_DETAILED_TIMING) + l_tick7 = libxsmm_timer_tick(); +#endif + + if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) { + libxsmm_dnn_reduce_wu_filters(libxsmm_handle, LIBXSMM_DNN_GRADIENT_FILTER); + } + +#if defined(LIBXSMM_DETAILED_TIMING) + l_tick8 = libxsmm_timer_tick(); +#endif /* clean up */ - chk_libxsmm_err( libxsmm_dnn_release_scratch( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL ), "release scratch" ); + chk_libxsmm_err( + libxsmm_dnn_release_scratch(libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL), + "release scratch"); if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) { - chk_libxsmm_err( libxsmm_dnn_release_buffer( libxsmm_handle, LIBXSMM_DNN_REGULAR_INPUT ), "release input" ); - chk_libxsmm_err( libxsmm_dnn_release_buffer( libxsmm_handle, LIBXSMM_DNN_REGULAR_OUTPUT ), "release output" ); - chk_libxsmm_err( libxsmm_dnn_release_filter( libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER ), "release filter" ); + chk_libxsmm_err( + libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_REGULAR_INPUT), + "release input"); + chk_libxsmm_err( + libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_REGULAR_OUTPUT), + "release output"); + chk_libxsmm_err( + libxsmm_dnn_release_filter(libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER), + "release filter"); + } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) { + chk_libxsmm_err( + libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_GRADIENT_INPUT), + "release input"); + chk_libxsmm_err( + libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_GRADIENT_OUTPUT), + "release output"); + chk_libxsmm_err( + libxsmm_dnn_release_filter(libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER), + "release filter"); + } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) { + chk_libxsmm_err( + libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_REGULAR_INPUT), + "release input"); + chk_libxsmm_err( + libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_GRADIENT_OUTPUT), + "release output"); + chk_libxsmm_err( + libxsmm_dnn_release_filter(libxsmm_handle, LIBXSMM_DNN_GRADIENT_FILTER), + "release filter"); } else { - chk_libxsmm_err( libxsmm_dnn_release_buffer( libxsmm_handle, LIBXSMM_DNN_GRADIENT_INPUT ), "release input" ); - chk_libxsmm_err( libxsmm_dnn_release_buffer( libxsmm_handle, LIBXSMM_DNN_GRADIENT_OUTPUT ), "release output" ); - chk_libxsmm_err( libxsmm_dnn_release_filter( libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER ), "release filter" ); + /* shouldn't happen */ } chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_input), "Destroy input"); chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_output), "Destroy output"); chk_libxsmm_err(libxsmm_dnn_destroy_filter(libxsmm_filter), "Destroy filter"); - - if(kind != LIBXSMM_DNN_COMPUTE_KIND_FWD) - chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(libxsmm_handle), - "Destroy handle"); + +#if defined(LIBXSMM_DETAILED_TIMING) + l_tick9 = libxsmm_timer_tick(); +#endif + + // if(kind != LIBXSMM_DNN_COMPUTE_KIND_FWD) + // chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(libxsmm_handle), + // "Destroy handle"); libxsmm_free(native_filter); libxsmm_free(scratch); + +#if defined(LIBXSMM_DETAILED_TIMING) + l_tick10 = libxsmm_timer_tick(); + printf( + "time for convolution (%i, %i, %i, %i, %i): %f, %f, %f, %f, %f, %f, %f, " + "%f, %f, %f\n", + desc.N, desc.C, desc.K, desc.R, desc.S, + libxsmm_timer_duration(l_tick1, l_tick2), + libxsmm_timer_duration(l_tick2, l_tick3), + libxsmm_timer_duration(l_tick3, l_tick4), + libxsmm_timer_duration(l_tick4, l_tick5), + libxsmm_timer_duration(l_tick5, l_tick6), + libxsmm_timer_duration(l_tick6, l_tick7), + libxsmm_timer_duration(l_tick7, l_tick8), + libxsmm_timer_duration(l_tick8, l_tick9), + libxsmm_timer_duration(l_tick9, l_tick10), + libxsmm_timer_duration(l_tick1, l_tick10)); +#endif + return true; // Succeeded } @@ -348,8 +498,8 @@ template <typename T> struct XsmmFwdConv2D<CPUDevice, T> { bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc, const T* input, const T* filter, T* output) { - return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_FWD, input, - filter, output); + return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_FWD, + input, filter, output); } }; @@ -357,8 +507,8 @@ template <typename T> struct XsmmBkwInputConv2D<CPUDevice, T> { bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc, T* input, const T* filter, const T* output) { - return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_BWD, input, - filter, output); + return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_BWD, + input, filter, output); } }; @@ -366,8 +516,8 @@ template <typename T> struct XsmmBkwFilterConv2D<CPUDevice, T> { bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc, const T* input, T* filter, const T* output) { - return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_UPD, input, - filter, output); + return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_UPD, + input, filter, output); } }; |