diff options
Diffstat (limited to 'tensorflow/core/kernels/xsmm_conv2d.cc')
-rw-r--r-- | tensorflow/core/kernels/xsmm_conv2d.cc | 131 |
1 files changed, 82 insertions, 49 deletions
diff --git a/tensorflow/core/kernels/xsmm_conv2d.cc b/tensorflow/core/kernels/xsmm_conv2d.cc index 0301ad49e7..823cdf7e09 100644 --- a/tensorflow/core/kernels/xsmm_conv2d.cc +++ b/tensorflow/core/kernels/xsmm_conv2d.cc @@ -33,6 +33,7 @@ void dummy_xsmm_conv2d_ensure_file_is_not_empty(void); #include "include/libxsmm_cpuid.h" #include "libxsmm_dnn_handle.h" +#include "libxsmm_malloc.h" namespace tensorflow { @@ -143,26 +144,28 @@ struct HashFunction{ S << w.d.S; u << w.d.u; v << w.d.v; padh << w.d.pad_h_in; padw << w.d.pad_w_in; - - + + std::string out_ = N.str() + C.str()\ + H.str() + W.str()\ + K.str() + R.str()\ + S.str() + u.str()\ + v.str() + padh.str()\ + padw.str(); - + return ( std::hash<std::string>()(out_)); } }; - + class handles{ public: - libxsmm_dnn_conv_handle* find( const libxsmm_dnn_conv_desc_wrap &w) { - std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_conv_handle*, HashFunction>::iterator i = libxsmm_handles.find(w); + libxsmm_dnn_layer* find( const libxsmm_dnn_conv_desc_wrap &w) { + std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_layer*, + HashFunction>::iterator i = libxsmm_handles.find(w); if (i == libxsmm_handles.end()){ libxsmm_dnn_err_t status; - libxsmm_dnn_conv_handle* libxsmm_handle = libxsmm_dnn_create_conv_handle_check(w.d, &status); + libxsmm_dnn_layer* libxsmm_handle = + libxsmm_dnn_create_conv_layer(w.d, &status); chk_libxsmm_err(status, "Create handle"); libxsmm_handles.insert(std::make_pair(w, libxsmm_handle)); return libxsmm_handle; @@ -171,15 +174,14 @@ class handles{ return i->second; } ~handles(){ - std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_conv_handle*, HashFunction>::iterator i; + std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_layer*, + HashFunction>::iterator i; for (i= libxsmm_handles.begin(); i != libxsmm_handles.end(); i++) - chk_libxsmm_err(libxsmm_dnn_destroy_conv_handle(i->second), + chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(i->second), "Destroy handle"); } private: - - std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_conv_handle*, HashFunction> libxsmm_handles; - + std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_layer*, HashFunction> libxsmm_handles; }; static handles libxsmm_handles; @@ -187,22 +189,25 @@ static handles libxsmm_handles; template <typename InputPtr, typename FilterPtr, typename OutputPtr> static bool CallLibxsmmConvGeneric(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc, - libxsmm_dnn_conv_kind kind, InputPtr input, + libxsmm_dnn_compute_kind kind, InputPtr input, FilterPtr filter, OutputPtr output) { + // setup scoped allocator, which adopts the allocator from the context + const libxsmm_tf_allocator<libxsmm_scratch_allocator> tf_allocator(*ctx); libxsmm_dnn_err_t status; - libxsmm_dnn_conv_handle* libxsmm_handle; + libxsmm_dnn_layer* libxsmm_handle; libxsmm_dnn_conv_desc_wrap w(desc); + void* scratch; - if(kind == LIBXSMM_DNN_CONV_KIND_FWD) + if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) libxsmm_handle = libxsmm_handles.find(w); - else{ - libxsmm_handle = libxsmm_dnn_create_conv_handle_check(desc, &status); + else { + libxsmm_handle = libxsmm_dnn_create_conv_layer(desc, &status); chk_libxsmm_err(status, "Create handle"); } status = libxsmm_dnn_get_codegen_success(libxsmm_handle, kind); if (status == LIBXSMM_DNN_WARN_FALLBACK) { - chk_libxsmm_err(libxsmm_dnn_destroy_conv_handle(libxsmm_handle), + chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(libxsmm_handle), "Destroy handle"); return false; // Use non-libxsmm code } @@ -211,23 +216,23 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx, libxsmm_dnn_buffer* libxsmm_input; libxsmm_dnn_buffer* libxsmm_output; libxsmm_dnn_filter* libxsmm_filter; - - /* + + /* const DeviceBase::CpuWorkerThreads* worker_threads = ctx->device()->tensorflow_cpu_worker_threads(); - + int num_threads = worker_threads->num_threads; */ int ifmblock = (libxsmm_handle->ifmblock); - int ofmblock = (libxsmm_handle->ofmblock); + int ofmblock = (libxsmm_handle->ofmblock); - int blocksifm = desc.C%ifmblock ==0 ? desc.C/ifmblock :desc.C/ifmblock + 1; + int blocksifm = desc.C%ifmblock ==0 ? desc.C/ifmblock :desc.C/ifmblock + 1; int blocksofm = desc.K%ofmblock ==0 ? desc.K/ofmblock :desc.K/ofmblock + 1; - float *native_filter = (float*)libxsmm_aligned_malloc( blocksofm*blocksifm*desc.R*desc.S*ifmblock*ofmblock*sizeof(float), 2097152); - + float *native_filter = (float*)libxsmm_aligned_scratch( + blocksofm*blocksifm*desc.R*desc.S*ifmblock*ofmblock*sizeof(float), + 2097152); - const DeviceBase::CpuWorkerThreads* worker_threads = ctx->device()->tensorflow_cpu_worker_threads(); @@ -264,50 +269,78 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx, count.Wait(); } - libxsmm_input = libxsmm_dnn_link_input_buffer_check( - libxsmm_handle, input, LIBXSMM_DNN_CONV_FORMAT_NHWC_PTR, &status); + libxsmm_input = libxsmm_dnn_link_buffer( + libxsmm_handle, LIBXSMM_DNN_INPUT, input, LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status); chk_libxsmm_err(status, "Link input buffer"); - libxsmm_output = libxsmm_dnn_link_output_buffer_check( - libxsmm_handle, output, LIBXSMM_DNN_CONV_FORMAT_NHWC_PTR, &status); + libxsmm_output = libxsmm_dnn_link_buffer( + libxsmm_handle, LIBXSMM_DNN_OUTPUT, output, LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status); chk_libxsmm_err(status, "Link output buffer"); - libxsmm_filter = libxsmm_dnn_link_filter_check( - libxsmm_handle, native_filter, LIBXSMM_DNN_CONV_FORMAT_LIBXSMM_PTR, &status); + libxsmm_filter = libxsmm_dnn_link_filter( + libxsmm_handle, LIBXSMM_DNN_FILTER, native_filter, LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM_PTR, &status); chk_libxsmm_err(status, "Link filter"); chk_libxsmm_err(libxsmm_dnn_zero_buffer(libxsmm_output), "Zero output"); - chk_libxsmm_err(libxsmm_dnn_bind_input_buffer(libxsmm_handle, libxsmm_input), - "Bind input"); - chk_libxsmm_err( - libxsmm_dnn_bind_output_buffer(libxsmm_handle, libxsmm_output), - "Bind output"); - chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter), - "Bind filter"); - if (kind == LIBXSMM_DNN_CONV_KIND_BWD) { - libxsmm_dnn_transpose_filter(libxsmm_handle); + if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) { + chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input, LIBXSMM_DNN_REGULAR_INPUT), + "Bind input forward"); + chk_libxsmm_err( + libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output, LIBXSMM_DNN_REGULAR_OUTPUT), + "Bind output forward"); + chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter, LIBXSMM_DNN_REGULAR_FILTER), + "Bind filter forward"); + } else { + chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input, LIBXSMM_DNN_GRADIENT_INPUT), + "Bind input backward"); + chk_libxsmm_err( + libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output, LIBXSMM_DNN_GRADIENT_OUTPUT), + "Bind output backward"); + chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter, LIBXSMM_DNN_REGULAR_FILTER), + "Bind filter backward"); } - BlockingCounter counter(num_threads); - + /* bind scratch */ + scratch = (void*)libxsmm_aligned_scratch( libxsmm_dnn_get_scratch_size( libxsmm_handle, kind, &status ), 2097152); + chk_libxsmm_err( status, "scratch allocation" ); + chk_libxsmm_err( libxsmm_dnn_bind_scratch( libxsmm_handle, kind, scratch ), "binding scratch" ); + if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) { + libxsmm_dnn_transpose_filter(libxsmm_handle, LIBXSMM_DNN_FILTER); + } + BlockingCounter counter(num_threads); + for (int i = 0; i < num_threads; ++i) { worker_threads->workers->Schedule([=, &counter]() { - chk_libxsmm_err(libxsmm_dnn_convolve_st(libxsmm_handle, kind, 0, i), + chk_libxsmm_err(libxsmm_dnn_execute_st(libxsmm_handle, kind, 0, i), "Worker"); counter.DecrementCount(); }); } counter.Wait(); + + /* clean up */ + chk_libxsmm_err( libxsmm_dnn_release_scratch( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL ), "release scratch" ); + if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) { + chk_libxsmm_err( libxsmm_dnn_release_buffer( libxsmm_handle, LIBXSMM_DNN_REGULAR_INPUT ), "release input" ); + chk_libxsmm_err( libxsmm_dnn_release_buffer( libxsmm_handle, LIBXSMM_DNN_REGULAR_OUTPUT ), "release output" ); + chk_libxsmm_err( libxsmm_dnn_release_filter( libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER ), "release filter" ); + } else { + chk_libxsmm_err( libxsmm_dnn_release_buffer( libxsmm_handle, LIBXSMM_DNN_GRADIENT_INPUT ), "release input" ); + chk_libxsmm_err( libxsmm_dnn_release_buffer( libxsmm_handle, LIBXSMM_DNN_GRADIENT_OUTPUT ), "release output" ); + chk_libxsmm_err( libxsmm_dnn_release_filter( libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER ), "release filter" ); + } chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_input), "Destroy input"); chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_output), "Destroy output"); chk_libxsmm_err(libxsmm_dnn_destroy_filter(libxsmm_filter), "Destroy filter"); - if(kind != LIBXSMM_DNN_CONV_KIND_FWD) - chk_libxsmm_err(libxsmm_dnn_destroy_conv_handle(libxsmm_handle), + if(kind != LIBXSMM_DNN_COMPUTE_KIND_FWD) + chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(libxsmm_handle), "Destroy handle"); + libxsmm_free(native_filter); + libxsmm_free(scratch); return true; // Succeeded } @@ -315,7 +348,7 @@ template <typename T> struct XsmmFwdConv2D<CPUDevice, T> { bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc, const T* input, const T* filter, T* output) { - return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_CONV_KIND_FWD, input, + return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_FWD, input, filter, output); } }; @@ -324,7 +357,7 @@ template <typename T> struct XsmmBkwInputConv2D<CPUDevice, T> { bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc, T* input, const T* filter, const T* output) { - return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_CONV_KIND_BWD, input, + return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_BWD, input, filter, output); } }; @@ -333,7 +366,7 @@ template <typename T> struct XsmmBkwFilterConv2D<CPUDevice, T> { bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc, const T* input, T* filter, const T* output) { - return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_CONV_KIND_UPD, input, + return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_UPD, input, filter, output); } }; |