about summary refs log tree commit diff homepage
path: root/tensorflow/core/kernels/xsmm_conv2d.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/xsmm_conv2d.cc')
-rw-r--r-- tensorflow/core/kernels/xsmm_conv2d.cc 482
1 files changed, 316 insertions, 166 deletions
diff --git a/tensorflow/core/kernels/xsmm_conv2d.cc b/tensorflow/core/kernels/xsmm_conv2d.cc
index 823cdf7e09..878abe9712 100644
--- a/tensorflow/core/kernels/xsmm_conv2d.cc
+++ b/tensorflow/core/kernels/xsmm_conv2d.cc
@@ -26,14 +26,18 @@ void dummy_xsmm_conv2d_ensure_file_is_not_empty(void);
#include "tensorflow/core/kernels/xsmm_conv2d.h"
#include <stdlib.h>
+#include <cstring>
+#if 0
+#include <omp.h>
+#endif
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/threadpool.h"
+#include "libxsmm_main.h" // TODO(bsteiner): API to avoid incl. header from src/
#include "include/libxsmm_cpuid.h"
-#include "libxsmm_dnn_handle.h"
-#include "libxsmm_malloc.h"
+#include "include/libxsmm_malloc.h"
namespace tensorflow {
@@ -59,10 +63,6 @@ bool CanUseXsmmConv2D(const libxsmm_dnn_conv_desc& desc,
VLOG(1) << "Cannot use XSMM convolutions: unsupported format!";
return false;
}
- if (desc.pad_h_in != 0 || desc.pad_w_in != 0) {
- VLOG(1) << "Cannot use XSMM convolutions: unsupported padding!";
- return false;
- }
if (desc.K % VECTOR_SIZE != 0) {
VLOG(1) << "Cannot use XSMM convolutions: output features count not"
" divisible by vector size!";
@@ -72,7 +72,6 @@ bool CanUseXsmmConv2D(const libxsmm_dnn_conv_desc& desc,
return true;
}
-
typedef Eigen::ThreadPoolDevice CPUDevice;
namespace functor {
@@ -83,25 +82,34 @@ static void chk_libxsmm_err(libxsmm_dnn_err_t status, string msg) {
}
}
-LIBXSMM_INLINE void copy_RSCK_to_custom(const float* rsck, float *kcrs, int R, int S, int C, int K,int blocksifm, int blocksofm, int ifmblock,int ofmblock, int start, int end)
-{
- LIBXSMM_VLA_DECL(4, const float, input, rsck, S, C,K);
- LIBXSMM_VLA_DECL(6, float, output, kcrs, blocksifm,R,S,ifmblock, ofmblock);
- int r, s, k,c, v1,v2;
-
- for (k = start; k < end ; k++ ) {
- for(c = 0; c < blocksifm;c++){
- for ( r = 0; r < R; r++ ) {
- for ( s = 0; s < S; s++ ){
- for ( v1 = c*ifmblock; v1 < std::min(C,(c+1)*ifmblock) ; v1++ ) {
- for ( v2 = k*ofmblock; v2 < std::min(K, (k+1)*ofmblock); v2++ )
- LIBXSMM_VLA_ACCESS(6, output, k,c, r, s,v1- c*ifmblock,v2-k*ofmblock, blocksifm, R, S,ifmblock,ofmblock) = LIBXSMM_VLA_ACCESS(4, input, r, s, v1, v2, S, C, K);
- for ( v2 = K; v2 < (k+1)*ofmblock ; v2++ )
- LIBXSMM_VLA_ACCESS(6, output, k,c, r, s,v1- c*ifmblock,v2-k*ofmblock, blocksifm, R, S,ifmblock,ofmblock) = 0.0f;
- }
- for ( v1 = C; v1 < (c+1)*ifmblock ; v1++ ) {
- for ( v2 = k*ofmblock; v2 < (k+1)*ofmblock; v2++ )
- LIBXSMM_VLA_ACCESS(6, output, k,c, r, s,v1- c*ifmblock,v2-k*ofmblock, blocksifm, R, S,ifmblock,ofmblock) = 0.0f;
+LIBXSMM_INLINE void copy_RSCK_to_custom(const float* rsck, float* kcrs, int R,  // Copies filter values from rsck into the blocked buffer kcrs.
+ int S, int C, int K, int blocksifm,  // R, S, C, K: extents of the four input dimensions, in that order.
+ int blocksofm, int ifmblock,  // blocksofm is not read anywhere in the visible body.
+ int ofmblock, int start, int end) {  // Only output-feature blocks k in [start, end) are written by this call.
+ LIBXSMM_VLA_DECL(4, const float, input, rsck, S, C, K);  // input is indexed as [r][s][c][k].
+ LIBXSMM_VLA_DECL(6, float, output, kcrs, blocksifm, R, S, ifmblock, ofmblock);  // output is indexed as [k-block][c-block][r][s][c within block][k within block].
+ int r, s, k, c, v1, v2;
+
+ for (k = start; k < end; k++) {
+ for (c = 0; c < blocksifm; c++) {
+ for (r = 0; r < R; r++) {
+ for (s = 0; s < S; s++) {
+ for (v1 = c * ifmblock; v1 < std::min(C, (c + 1) * ifmblock); v1++) {  // v1: absolute input-feature index, clamped to C.
+ for (v2 = k * ofmblock; v2 < std::min(K, (k + 1) * ofmblock); v2++)  // v2: absolute output-feature index, clamped to K.
+ LIBXSMM_VLA_ACCESS(6, output, k, c, r, s, v1 - c * ifmblock,
+ v2 - k * ofmblock, blocksifm, R, S, ifmblock,
+ ofmblock) =
+ LIBXSMM_VLA_ACCESS(4, input, r, s, v1, v2, S, C, K);
+ for (v2 = K; v2 < (k + 1) * ofmblock; v2++)  // Zero the output-feature slots at or beyond K (runs only when the block extends past K).
+ LIBXSMM_VLA_ACCESS(6, output, k, c, r, s, v1 - c * ifmblock,
+ v2 - k * ofmblock, blocksifm, R, S, ifmblock,
+ ofmblock) = 0.0f;
+ }
+ for (v1 = C; v1 < (c + 1) * ifmblock; v1++) {  // Zero whole rows for input-feature slots at or beyond C.
+ for (v2 = k * ofmblock; v2 < (k + 1) * ofmblock; v2++)
+ LIBXSMM_VLA_ACCESS(6, output, k, c, r, s, v1 - c * ifmblock,
+ v2 - k * ofmblock, blocksifm, R, S, ifmblock,
+ ofmblock) = 0.0f;
}
}
}
@@ -109,35 +117,28 @@ LIBXSMM_INLINE void copy_RSCK_to_custom(const float* rsck, float *kcrs, int R, i
}
}
-
+class libxsmm_dnn_conv_desc_wrap {  // Wraps a libxsmm_dnn_conv_desc so it can serve as an unordered_map key.
+ public:
+ const libxsmm_dnn_conv_desc d;  // Copy of the descriptor taken at construction; const afterwards.
-class libxsmm_dnn_conv_desc_wrap{
- public:
- const libxsmm_dnn_conv_desc d;
-
- libxsmm_dnn_conv_desc_wrap(const libxsmm_dnn_conv_desc &d_) : d(d_){
- }
- bool operator==(const libxsmm_dnn_conv_desc_wrap &w) const{
- return( d.N == w.d.N &&
- d.C == w.d.C &&
- d.H == w.d.H &&
- d.W == w.d.W &&
- d.K == w.d.K &&
- d.R == w.d.R &&
- d.S == w.d.S &&
- d.u == w.d.u &&
- d.v == w.d.v &&
- d.pad_h_in == w.d.pad_h_in &&
- d.pad_w_in == w.d.pad_w_in
- );
- }
+ libxsmm_dnn_conv_desc_wrap(const libxsmm_dnn_conv_desc& d_) : d(d_) {}
+ bool operator==(const libxsmm_dnn_conv_desc_wrap& w) const {  // Equal when the eleven fields listed below all match.
+ return (d.N == w.d.N && d.C == w.d.C && d.H == w.d.H && d.W == w.d.W &&
+ d.K == w.d.K && d.R == w.d.R && d.S == w.d.S && d.u == w.d.u &&
+ d.v == w.d.v && d.pad_h == w.d.pad_h && d.pad_w == w.d.pad_w);  // NOTE(review): any other descriptor fields are ignored here — confirm that is intended.
+ }
};
-
-
-struct HashFunction{
- std::size_t operator()(const libxsmm_dnn_conv_desc_wrap & w) const{
+
+// Hash functor that lets libxsmm_dnn_conv_desc_wrap be used as an
+// unordered_map key. It combines exactly the descriptor fields that
+// libxsmm_dnn_conv_desc_wrap::operator== compares, so two wrappers that
+// compare equal always produce the same hash value.
+struct HashFunction {
+ std::size_t operator()(const libxsmm_dnn_conv_desc_wrap& w) const {
+ // unsigned char ptr[sizeof(&w.d)];
+
+ // memcpy(ptr, (unsigned char *)&w.d, sizeof(&w.d))
+
+ //
+ /*
std::ostringstream N,C,H,W,K,R,S,u,v,padh,padw;
-
+
N << w.d.N; C << w.d.C;
H << w.d.H; W << w.d.W;
K << w.d.K; R << w.d.R;
@@ -152,59 +153,71 @@ struct HashFunction{
+ S.str() + u.str()\
+ v.str() + padh.str()\
+ padw.str();
-
- return ( std::hash<std::string>()(out_));
+ //
+ //
+ */
+    // The hash has to depend on the descriptor's contents, not on the address
+    // of w.d: a lookup key and a stored key with identical fields live at
+    // different addresses, and unordered_map requires equal keys to hash
+    // equally or lookups miss and duplicate entries get inserted.
+    const std::size_t fields[] = {
+        static_cast<std::size_t>(w.d.N),     static_cast<std::size_t>(w.d.C),
+        static_cast<std::size_t>(w.d.H),     static_cast<std::size_t>(w.d.W),
+        static_cast<std::size_t>(w.d.K),     static_cast<std::size_t>(w.d.R),
+        static_cast<std::size_t>(w.d.S),     static_cast<std::size_t>(w.d.u),
+        static_cast<std::size_t>(w.d.v),     static_cast<std::size_t>(w.d.pad_h),
+        static_cast<std::size_t>(w.d.pad_w)};
+    std::size_t seed = 0;
+    for (std::size_t field : fields) {
+      // Order-dependent mixing step (same shape as boost::hash_combine).
+      seed ^= field + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+    }
+    return seed;
}
};
-class handles{
- public:
- libxsmm_dnn_layer* find( const libxsmm_dnn_conv_desc_wrap &w) {
- std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_layer*,
- HashFunction>::iterator i = libxsmm_handles.find(w);
- if (i == libxsmm_handles.end()){
- libxsmm_dnn_err_t status;
- libxsmm_dnn_layer* libxsmm_handle =
- libxsmm_dnn_create_conv_layer(w.d, &status);
- chk_libxsmm_err(status, "Create handle");
- libxsmm_handles.insert(std::make_pair(w, libxsmm_handle));
- return libxsmm_handle;
- }
- else
- return i->second;
- }
- ~handles(){
- std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_layer*,
- HashFunction>::iterator i;
- for (i= libxsmm_handles.begin(); i != libxsmm_handles.end(); i++)
+class handles {  // Cache of libxsmm convolution layer handles keyed by wrapped descriptor.
+ public:
+ libxsmm_dnn_layer* find(const libxsmm_dnn_conv_desc_wrap& w) {  // Returns the handle stored for w, creating and storing one on a miss.
+ std::unordered_map<libxsmm_dnn_conv_desc_wrap, libxsmm_dnn_layer*,
+ HashFunction>::iterator i = libxsmm_handles.find(w);  // NOTE(review): no locking is visible around this map — confirm callers serialize access.
+ if (i == libxsmm_handles.end()) {
+ libxsmm_dnn_err_t status;
+ libxsmm_dnn_layer* libxsmm_handle =
+ libxsmm_dnn_create_conv_layer(w.d, &status);
+ chk_libxsmm_err(status, "Create handle");
+ libxsmm_handles.insert(std::make_pair(w, libxsmm_handle));  // The map stores a copy of w alongside the raw handle pointer.
+ return libxsmm_handle;
+ } else
+ return i->second;
+ }
+ ~handles() {  // Destroys every handle still stored in the map.
+ std::unordered_map<libxsmm_dnn_conv_desc_wrap, libxsmm_dnn_layer*,
+ HashFunction>::iterator i;
+ for (i = libxsmm_handles.begin(); i != libxsmm_handles.end(); i++)
chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(i->second),
- "Destroy handle");
- }
- private:
- std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_layer*, HashFunction> libxsmm_handles;
+ "Destroy handle");
+ }
+
+ private:
+ std::unordered_map<libxsmm_dnn_conv_desc_wrap, libxsmm_dnn_layer*,
+ HashFunction>
+ libxsmm_handles;  // Descriptor -> layer handle; the handles are released in ~handles().
};
static handles libxsmm_handles;
+//#define LIBXSMM_DETAILED_TIMING
+
template <typename InputPtr, typename FilterPtr, typename OutputPtr>
static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
const libxsmm_dnn_conv_desc& desc,
- libxsmm_dnn_compute_kind kind, InputPtr input,
- FilterPtr filter, OutputPtr output) {
+ libxsmm_dnn_compute_kind kind,
+ InputPtr input, FilterPtr filter,
+ OutputPtr output) {
+#if defined(LIBXSMM_DETAILED_TIMING)
+ unsigned long long l_tick1, l_tick2, l_tick3, l_tick4, l_tick5, l_tick6,
+ l_tick7, l_tick8, l_tick9, l_tick10;
+ l_tick1 = libxsmm_timer_tick();
+#endif
// setup scoped allocator, which adopts the allocator from the context
const libxsmm_tf_allocator<libxsmm_scratch_allocator> tf_allocator(*ctx);
libxsmm_dnn_err_t status;
libxsmm_dnn_layer* libxsmm_handle;
libxsmm_dnn_conv_desc_wrap w(desc);
void* scratch;
-
- if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD)
- libxsmm_handle = libxsmm_handles.find(w);
- else {
- libxsmm_handle = libxsmm_dnn_create_conv_layer(desc, &status);
- chk_libxsmm_err(status, "Create handle");
- }
-
+
+ // if(kind == LIBXSMM_DNN_COMPUTE_KIND_FWD)
+ libxsmm_handle = libxsmm_handles.find(w);
+ // else{
+ // libxsmm_handle = libxsmm_dnn_create_conv_layer(desc, &status);
+ // chk_libxsmm_err(status, "Create handle");
+ //}
+
status = libxsmm_dnn_get_codegen_success(libxsmm_handle, kind);
if (status == LIBXSMM_DNN_WARN_FALLBACK) {
chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(libxsmm_handle),
@@ -217,100 +230,168 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
libxsmm_dnn_buffer* libxsmm_output;
libxsmm_dnn_filter* libxsmm_filter;
- /*
- const DeviceBase::CpuWorkerThreads* worker_threads =
- ctx->device()->tensorflow_cpu_worker_threads();
-
- int num_threads = worker_threads->num_threads;
-*/
+#if defined(LIBXSMM_DETAILED_TIMING)
+ l_tick2 = libxsmm_timer_tick();
+#endif
int ifmblock = (libxsmm_handle->ifmblock);
int ofmblock = (libxsmm_handle->ofmblock);
- int blocksifm = desc.C%ifmblock ==0 ? desc.C/ifmblock :desc.C/ifmblock + 1;
- int blocksofm = desc.K%ofmblock ==0 ? desc.K/ofmblock :desc.K/ofmblock + 1;
- float *native_filter = (float*)libxsmm_aligned_scratch(
- blocksofm*blocksifm*desc.R*desc.S*ifmblock*ofmblock*sizeof(float),
- 2097152);
+ int blocksifm =
+ desc.C % ifmblock == 0 ? desc.C / ifmblock : desc.C / ifmblock + 1;
+ int blocksofm =
+ desc.K % ofmblock == 0 ? desc.K / ofmblock : desc.K / ofmblock + 1;
+ float* native_filter =
+ (float*)libxsmm_aligned_scratch(blocksofm * blocksifm * desc.R * desc.S *
+ ifmblock * ofmblock * sizeof(float),
+ 2097152);
const DeviceBase::CpuWorkerThreads* worker_threads =
ctx->device()->tensorflow_cpu_worker_threads();
int num_threads = worker_threads->num_threads;
-
- if(blocksofm > num_threads){
- int work = blocksofm;
- BlockingCounter count(num_threads);
- for (int i = 0; i < num_threads; ++i) {
+#if 1
+ if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD ||
+ kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
+ if (blocksofm > num_threads) {
+ int work = blocksofm;
+ BlockingCounter count(num_threads);
+ for (int i = 0; i < num_threads; ++i) {
worker_threads->workers->Schedule([=, &count]() {
- int start = work/num_threads*i;
- int end = (start + work/num_threads) > work ? work: start + work/num_threads;
- copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S,desc.C, desc.K,blocksifm,blocksofm,ifmblock,ofmblock,start, end);
- count.DecrementCount();
+ int start = work / num_threads * i;
+ int end = (start + work / num_threads) > work
+ ? work
+ : start + work / num_threads;
+ copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S, desc.C,
+ desc.K, blocksifm, blocksofm, ifmblock, ofmblock,
+ start, end);
+ count.DecrementCount();
});
- }
- count.Wait();
- }
- else{
+ }
+ count.Wait();
+ } else {
+ int work = blocksofm;
+ int num_threads = work;
- int work = blocksofm;
- int num_threads = work;
-
- BlockingCounter count(num_threads);
- for (int i = 0; i < num_threads; ++i) {
+ BlockingCounter count(num_threads);
+ for (int i = 0; i < num_threads; ++i) {
worker_threads->workers->Schedule([=, &count]() {
- int start = i;
- int end = i+1;
- copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S,desc.C, desc.K,blocksifm,blocksofm,ifmblock,ofmblock, start, end);
- count.DecrementCount();
+ int start = i;
+ int end = i + 1;
+ copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S, desc.C,
+ desc.K, blocksifm, blocksofm, ifmblock, ofmblock,
+ start, end);
+ count.DecrementCount();
});
+ }
+ count.Wait();
}
- count.Wait();
}
-
- libxsmm_input = libxsmm_dnn_link_buffer(
- libxsmm_handle, LIBXSMM_DNN_INPUT, input, LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status);
+ // Added: for weight update
+ else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
+ libxsmm_filter =
+ libxsmm_dnn_link_filter(libxsmm_handle, LIBXSMM_DNN_FILTER, filter,
+ LIBXSMM_DNN_TENSOR_FORMAT_RSCK_PTR, &status);
+ chk_libxsmm_err(status,
+ "Link filter"); // weight update is in RSCK as
+ // filter should be returned in RSCK
+ // format
+ }
+#else
+ memset(native_filter, 0,
+ blocksofm * blocksifm * desc.R * desc.S * ifmblock * ofmblock *
+ sizeof(float));
+#endif
+
+#if defined(LIBXSMM_DETAILED_TIMING)
+ l_tick3 = libxsmm_timer_tick();
+#endif
+
+ libxsmm_input =
+ libxsmm_dnn_link_buffer(libxsmm_handle, LIBXSMM_DNN_INPUT, input,
+ LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status);
chk_libxsmm_err(status, "Link input buffer");
- libxsmm_output = libxsmm_dnn_link_buffer(
- libxsmm_handle, LIBXSMM_DNN_OUTPUT, output, LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status);
+ libxsmm_output =
+ libxsmm_dnn_link_buffer(libxsmm_handle, LIBXSMM_DNN_OUTPUT, output,
+ LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status);
chk_libxsmm_err(status, "Link output buffer");
- libxsmm_filter = libxsmm_dnn_link_filter(
- libxsmm_handle, LIBXSMM_DNN_FILTER, native_filter, LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM_PTR, &status);
- chk_libxsmm_err(status, "Link filter");
-
- chk_libxsmm_err(libxsmm_dnn_zero_buffer(libxsmm_output), "Zero output");
-
-
+ if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD ||
+ kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
+ libxsmm_filter = libxsmm_dnn_link_filter(
+ libxsmm_handle, LIBXSMM_DNN_FILTER, native_filter,
+ LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM_PTR, &status);
+ chk_libxsmm_err(status, "Link filter");
+ }
if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) {
- chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input, LIBXSMM_DNN_REGULAR_INPUT),
+ chk_libxsmm_err(libxsmm_dnn_zero_buffer(libxsmm_output), "Zero output");
+
+ chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input,
+ LIBXSMM_DNN_REGULAR_INPUT),
"Bind input forward");
- chk_libxsmm_err(
- libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output, LIBXSMM_DNN_REGULAR_OUTPUT),
- "Bind output forward");
- chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter, LIBXSMM_DNN_REGULAR_FILTER),
+ chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output,
+ LIBXSMM_DNN_REGULAR_OUTPUT),
+ "Bind output forward");
+ chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter,
+ LIBXSMM_DNN_REGULAR_FILTER),
"Bind filter forward");
- } else {
- chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input, LIBXSMM_DNN_GRADIENT_INPUT),
+ } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
+ chk_libxsmm_err(libxsmm_dnn_zero_buffer(libxsmm_input), "Zero input");
+
+ chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input,
+ LIBXSMM_DNN_GRADIENT_INPUT),
"Bind input backward");
- chk_libxsmm_err(
- libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output, LIBXSMM_DNN_GRADIENT_OUTPUT),
- "Bind output backward");
- chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter, LIBXSMM_DNN_REGULAR_FILTER),
+ chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output,
+ LIBXSMM_DNN_GRADIENT_OUTPUT),
+ "Bind output backward");
+ chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter,
+ LIBXSMM_DNN_REGULAR_FILTER),
"Bind filter backward");
+ } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
+ chk_libxsmm_err(libxsmm_dnn_zero_filter(libxsmm_filter), "Zero filter");
+
+ chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input,
+ LIBXSMM_DNN_REGULAR_INPUT),
+ "Bind input weight udpate");
+ chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output,
+ LIBXSMM_DNN_GRADIENT_OUTPUT),
+ "Bind output weight update");
+ chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter,
+ LIBXSMM_DNN_GRADIENT_FILTER),
+ "Bind filter weight update");
+ } else {
+ /* shouldn't happen */
}
+#if defined(LIBXSMM_DETAILED_TIMING)
+ l_tick4 = libxsmm_timer_tick();
+#endif
+
/* bind scratch */
- scratch = (void*)libxsmm_aligned_scratch( libxsmm_dnn_get_scratch_size( libxsmm_handle, kind, &status ), 2097152);
- chk_libxsmm_err( status, "scratch allocation" );
- chk_libxsmm_err( libxsmm_dnn_bind_scratch( libxsmm_handle, kind, scratch ), "binding scratch" );
+ scratch = (void*)libxsmm_aligned_scratch(
+ libxsmm_dnn_get_scratch_size(libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL,
+ &status),
+ 2097152);
+ chk_libxsmm_err(status, "scratch allocation");
+ chk_libxsmm_err(libxsmm_dnn_bind_scratch(
+ libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, scratch),
+ "binding scratch");
+
+#if defined(LIBXSMM_DETAILED_TIMING)
+ l_tick5 = libxsmm_timer_tick();
+#endif
if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
libxsmm_dnn_transpose_filter(libxsmm_handle, LIBXSMM_DNN_FILTER);
}
+#if defined(LIBXSMM_DETAILED_TIMING)
+ l_tick6 = libxsmm_timer_tick();
+#endif
+
+#if 1
BlockingCounter counter(num_threads);
-
+
for (int i = 0; i < num_threads; ++i) {
worker_threads->workers->Schedule([=, &counter]() {
chk_libxsmm_err(libxsmm_dnn_execute_st(libxsmm_handle, kind, 0, i),
@@ -319,28 +400,97 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
});
}
counter.Wait();
+#else
+#pragma omp parallel
+ {
+ chk_libxsmm_err(
+ libxsmm_dnn_execute_st(libxsmm_handle, kind, 0, omp_get_thread_num()),
+ "Worker");
+ }
+#endif
+
+#if defined(LIBXSMM_DETAILED_TIMING)
+ l_tick7 = libxsmm_timer_tick();
+#endif
+
+ if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
+ libxsmm_dnn_reduce_wu_filters(libxsmm_handle, LIBXSMM_DNN_GRADIENT_FILTER);
+ }
+
+#if defined(LIBXSMM_DETAILED_TIMING)
+ l_tick8 = libxsmm_timer_tick();
+#endif
/* clean up */
- chk_libxsmm_err( libxsmm_dnn_release_scratch( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL ), "release scratch" );
+ chk_libxsmm_err(
+ libxsmm_dnn_release_scratch(libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL),
+ "release scratch");
if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) {
- chk_libxsmm_err( libxsmm_dnn_release_buffer( libxsmm_handle, LIBXSMM_DNN_REGULAR_INPUT ), "release input" );
- chk_libxsmm_err( libxsmm_dnn_release_buffer( libxsmm_handle, LIBXSMM_DNN_REGULAR_OUTPUT ), "release output" );
- chk_libxsmm_err( libxsmm_dnn_release_filter( libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER ), "release filter" );
+ chk_libxsmm_err(
+ libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_REGULAR_INPUT),
+ "release input");
+ chk_libxsmm_err(
+ libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_REGULAR_OUTPUT),
+ "release output");
+ chk_libxsmm_err(
+ libxsmm_dnn_release_filter(libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER),
+ "release filter");
+ } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
+ chk_libxsmm_err(
+ libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_GRADIENT_INPUT),
+ "release input");
+ chk_libxsmm_err(
+ libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_GRADIENT_OUTPUT),
+ "release output");
+ chk_libxsmm_err(
+ libxsmm_dnn_release_filter(libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER),
+ "release filter");
+ } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
+ chk_libxsmm_err(
+ libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_REGULAR_INPUT),
+ "release input");
+ chk_libxsmm_err(
+ libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_GRADIENT_OUTPUT),
+ "release output");
+ chk_libxsmm_err(
+ libxsmm_dnn_release_filter(libxsmm_handle, LIBXSMM_DNN_GRADIENT_FILTER),
+ "release filter");
} else {
- chk_libxsmm_err( libxsmm_dnn_release_buffer( libxsmm_handle, LIBXSMM_DNN_GRADIENT_INPUT ), "release input" );
- chk_libxsmm_err( libxsmm_dnn_release_buffer( libxsmm_handle, LIBXSMM_DNN_GRADIENT_OUTPUT ), "release output" );
- chk_libxsmm_err( libxsmm_dnn_release_filter( libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER ), "release filter" );
+ /* shouldn't happen */
}
chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_input), "Destroy input");
chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_output), "Destroy output");
chk_libxsmm_err(libxsmm_dnn_destroy_filter(libxsmm_filter), "Destroy filter");
-
- if(kind != LIBXSMM_DNN_COMPUTE_KIND_FWD)
- chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(libxsmm_handle),
- "Destroy handle");
+
+#if defined(LIBXSMM_DETAILED_TIMING)
+ l_tick9 = libxsmm_timer_tick();
+#endif
+
+ // if(kind != LIBXSMM_DNN_COMPUTE_KIND_FWD)
+ // chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(libxsmm_handle),
+ // "Destroy handle");
libxsmm_free(native_filter);
libxsmm_free(scratch);
+
+#if defined(LIBXSMM_DETAILED_TIMING)
+ l_tick10 = libxsmm_timer_tick();
+ printf(
+ "time for convolution (%i, %i, %i, %i, %i): %f, %f, %f, %f, %f, %f, %f, "
+ "%f, %f, %f\n",
+ desc.N, desc.C, desc.K, desc.R, desc.S,
+ libxsmm_timer_duration(l_tick1, l_tick2),
+ libxsmm_timer_duration(l_tick2, l_tick3),
+ libxsmm_timer_duration(l_tick3, l_tick4),
+ libxsmm_timer_duration(l_tick4, l_tick5),
+ libxsmm_timer_duration(l_tick5, l_tick6),
+ libxsmm_timer_duration(l_tick6, l_tick7),
+ libxsmm_timer_duration(l_tick7, l_tick8),
+ libxsmm_timer_duration(l_tick8, l_tick9),
+ libxsmm_timer_duration(l_tick9, l_tick10),
+ libxsmm_timer_duration(l_tick1, l_tick10));
+#endif
+
return true; // Succeeded
}
@@ -348,8 +498,8 @@ template <typename T>
struct XsmmFwdConv2D<CPUDevice, T> {
bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc,
const T* input, const T* filter, T* output) {
- return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_FWD, input,
- filter, output);
+ return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_FWD,  // Delegates to the generic driver with the FWD compute kind.
+ input, filter, output);  // output is the only non-const buffer in this signature.
}
};
@@ -357,8 +507,8 @@ template <typename T>
struct XsmmBkwInputConv2D<CPUDevice, T> {
bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc,
T* input, const T* filter, const T* output) {
- return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_BWD, input,
- filter, output);
+ return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_BWD,  // Delegates to the generic driver with the BWD compute kind.
+ input, filter, output);  // input is the only non-const buffer in this signature.
}
};
@@ -366,8 +516,8 @@ template <typename T>
struct XsmmBkwFilterConv2D<CPUDevice, T> {
bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc,
const T* input, T* filter, const T* output) {
- return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_UPD, input,
- filter, output);
+ return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_UPD,  // Delegates to the generic driver with the UPD (weight update) compute kind.
+ input, filter, output);  // filter is the only non-const buffer in this signature.
}
};