aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/xsmm_conv2d.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/xsmm_conv2d.cc')
-rw-r--r--tensorflow/core/kernels/xsmm_conv2d.cc131
1 files changed, 82 insertions, 49 deletions
diff --git a/tensorflow/core/kernels/xsmm_conv2d.cc b/tensorflow/core/kernels/xsmm_conv2d.cc
index 0301ad49e7..823cdf7e09 100644
--- a/tensorflow/core/kernels/xsmm_conv2d.cc
+++ b/tensorflow/core/kernels/xsmm_conv2d.cc
@@ -33,6 +33,7 @@ void dummy_xsmm_conv2d_ensure_file_is_not_empty(void);
#include "include/libxsmm_cpuid.h"
#include "libxsmm_dnn_handle.h"
+#include "libxsmm_malloc.h"
namespace tensorflow {
@@ -143,26 +144,28 @@ struct HashFunction{
S << w.d.S; u << w.d.u;
v << w.d.v; padh << w.d.pad_h_in;
padw << w.d.pad_w_in;
-
-
+
+
std::string out_ = N.str() + C.str()\
+ H.str() + W.str()\
+ K.str() + R.str()\
+ S.str() + u.str()\
+ v.str() + padh.str()\
+ padw.str();
-
+
return ( std::hash<std::string>()(out_));
}
};
-
+
class handles{
public:
- libxsmm_dnn_conv_handle* find( const libxsmm_dnn_conv_desc_wrap &w) {
- std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_conv_handle*, HashFunction>::iterator i = libxsmm_handles.find(w);
+ libxsmm_dnn_layer* find( const libxsmm_dnn_conv_desc_wrap &w) {
+ std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_layer*,
+ HashFunction>::iterator i = libxsmm_handles.find(w);
if (i == libxsmm_handles.end()){
libxsmm_dnn_err_t status;
- libxsmm_dnn_conv_handle* libxsmm_handle = libxsmm_dnn_create_conv_handle_check(w.d, &status);
+ libxsmm_dnn_layer* libxsmm_handle =
+ libxsmm_dnn_create_conv_layer(w.d, &status);
chk_libxsmm_err(status, "Create handle");
libxsmm_handles.insert(std::make_pair(w, libxsmm_handle));
return libxsmm_handle;
@@ -171,15 +174,14 @@ class handles{
return i->second;
}
~handles(){
- std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_conv_handle*, HashFunction>::iterator i;
+ std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_layer*,
+ HashFunction>::iterator i;
for (i= libxsmm_handles.begin(); i != libxsmm_handles.end(); i++)
- chk_libxsmm_err(libxsmm_dnn_destroy_conv_handle(i->second),
+ chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(i->second),
"Destroy handle");
}
private:
-
- std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_conv_handle*, HashFunction> libxsmm_handles;
-
+ std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_layer*, HashFunction> libxsmm_handles;
};
static handles libxsmm_handles;
@@ -187,22 +189,25 @@ static handles libxsmm_handles;
template <typename InputPtr, typename FilterPtr, typename OutputPtr>
static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
const libxsmm_dnn_conv_desc& desc,
- libxsmm_dnn_conv_kind kind, InputPtr input,
+ libxsmm_dnn_compute_kind kind, InputPtr input,
FilterPtr filter, OutputPtr output) {
+ // setup scoped allocator, which adopts the allocator from the context
+ const libxsmm_tf_allocator<libxsmm_scratch_allocator> tf_allocator(*ctx);
libxsmm_dnn_err_t status;
- libxsmm_dnn_conv_handle* libxsmm_handle;
+ libxsmm_dnn_layer* libxsmm_handle;
libxsmm_dnn_conv_desc_wrap w(desc);
+ void* scratch;
- if(kind == LIBXSMM_DNN_CONV_KIND_FWD)
+ if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD)
libxsmm_handle = libxsmm_handles.find(w);
- else{
- libxsmm_handle = libxsmm_dnn_create_conv_handle_check(desc, &status);
+ else {
+ libxsmm_handle = libxsmm_dnn_create_conv_layer(desc, &status);
chk_libxsmm_err(status, "Create handle");
}
status = libxsmm_dnn_get_codegen_success(libxsmm_handle, kind);
if (status == LIBXSMM_DNN_WARN_FALLBACK) {
- chk_libxsmm_err(libxsmm_dnn_destroy_conv_handle(libxsmm_handle),
+ chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(libxsmm_handle),
"Destroy handle");
return false; // Use non-libxsmm code
}
@@ -211,23 +216,23 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
libxsmm_dnn_buffer* libxsmm_input;
libxsmm_dnn_buffer* libxsmm_output;
libxsmm_dnn_filter* libxsmm_filter;
-
- /*
+
+ /*
const DeviceBase::CpuWorkerThreads* worker_threads =
ctx->device()->tensorflow_cpu_worker_threads();
-
+
int num_threads = worker_threads->num_threads;
*/
int ifmblock = (libxsmm_handle->ifmblock);
- int ofmblock = (libxsmm_handle->ofmblock);
+ int ofmblock = (libxsmm_handle->ofmblock);
- int blocksifm = desc.C%ifmblock ==0 ? desc.C/ifmblock :desc.C/ifmblock + 1;
+ int blocksifm = desc.C%ifmblock ==0 ? desc.C/ifmblock :desc.C/ifmblock + 1;
int blocksofm = desc.K%ofmblock ==0 ? desc.K/ofmblock :desc.K/ofmblock + 1;
- float *native_filter = (float*)libxsmm_aligned_malloc( blocksofm*blocksifm*desc.R*desc.S*ifmblock*ofmblock*sizeof(float), 2097152);
-
+ float *native_filter = (float*)libxsmm_aligned_scratch(
+ blocksofm*blocksifm*desc.R*desc.S*ifmblock*ofmblock*sizeof(float),
+ 2097152);
-
const DeviceBase::CpuWorkerThreads* worker_threads =
ctx->device()->tensorflow_cpu_worker_threads();
@@ -264,50 +269,78 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
count.Wait();
}
- libxsmm_input = libxsmm_dnn_link_input_buffer_check(
- libxsmm_handle, input, LIBXSMM_DNN_CONV_FORMAT_NHWC_PTR, &status);
+ libxsmm_input = libxsmm_dnn_link_buffer(
+ libxsmm_handle, LIBXSMM_DNN_INPUT, input, LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status);
chk_libxsmm_err(status, "Link input buffer");
- libxsmm_output = libxsmm_dnn_link_output_buffer_check(
- libxsmm_handle, output, LIBXSMM_DNN_CONV_FORMAT_NHWC_PTR, &status);
+ libxsmm_output = libxsmm_dnn_link_buffer(
+ libxsmm_handle, LIBXSMM_DNN_OUTPUT, output, LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status);
chk_libxsmm_err(status, "Link output buffer");
- libxsmm_filter = libxsmm_dnn_link_filter_check(
- libxsmm_handle, native_filter, LIBXSMM_DNN_CONV_FORMAT_LIBXSMM_PTR, &status);
+ libxsmm_filter = libxsmm_dnn_link_filter(
+ libxsmm_handle, LIBXSMM_DNN_FILTER, native_filter, LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM_PTR, &status);
chk_libxsmm_err(status, "Link filter");
chk_libxsmm_err(libxsmm_dnn_zero_buffer(libxsmm_output), "Zero output");
- chk_libxsmm_err(libxsmm_dnn_bind_input_buffer(libxsmm_handle, libxsmm_input),
- "Bind input");
- chk_libxsmm_err(
- libxsmm_dnn_bind_output_buffer(libxsmm_handle, libxsmm_output),
- "Bind output");
- chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter),
- "Bind filter");
- if (kind == LIBXSMM_DNN_CONV_KIND_BWD) {
- libxsmm_dnn_transpose_filter(libxsmm_handle);
+ if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) {
+ chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input, LIBXSMM_DNN_REGULAR_INPUT),
+ "Bind input forward");
+ chk_libxsmm_err(
+ libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output, LIBXSMM_DNN_REGULAR_OUTPUT),
+ "Bind output forward");
+ chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter, LIBXSMM_DNN_REGULAR_FILTER),
+ "Bind filter forward");
+ } else {
+ chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input, LIBXSMM_DNN_GRADIENT_INPUT),
+ "Bind input backward");
+ chk_libxsmm_err(
+ libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output, LIBXSMM_DNN_GRADIENT_OUTPUT),
+ "Bind output backward");
+ chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter, LIBXSMM_DNN_REGULAR_FILTER),
+ "Bind filter backward");
}
- BlockingCounter counter(num_threads);
-
+ /* bind scratch */
+ scratch = (void*)libxsmm_aligned_scratch( libxsmm_dnn_get_scratch_size( libxsmm_handle, kind, &status ), 2097152);
+ chk_libxsmm_err( status, "scratch allocation" );
+ chk_libxsmm_err( libxsmm_dnn_bind_scratch( libxsmm_handle, kind, scratch ), "binding scratch" );
+ if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
+ libxsmm_dnn_transpose_filter(libxsmm_handle, LIBXSMM_DNN_FILTER);
+ }
+ BlockingCounter counter(num_threads);
+
for (int i = 0; i < num_threads; ++i) {
worker_threads->workers->Schedule([=, &counter]() {
- chk_libxsmm_err(libxsmm_dnn_convolve_st(libxsmm_handle, kind, 0, i),
+ chk_libxsmm_err(libxsmm_dnn_execute_st(libxsmm_handle, kind, 0, i),
"Worker");
counter.DecrementCount();
});
}
counter.Wait();
+
+ /* clean up */
+ chk_libxsmm_err( libxsmm_dnn_release_scratch( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL ), "release scratch" );
+ if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) {
+ chk_libxsmm_err( libxsmm_dnn_release_buffer( libxsmm_handle, LIBXSMM_DNN_REGULAR_INPUT ), "release input" );
+ chk_libxsmm_err( libxsmm_dnn_release_buffer( libxsmm_handle, LIBXSMM_DNN_REGULAR_OUTPUT ), "release output" );
+ chk_libxsmm_err( libxsmm_dnn_release_filter( libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER ), "release filter" );
+ } else {
+ chk_libxsmm_err( libxsmm_dnn_release_buffer( libxsmm_handle, LIBXSMM_DNN_GRADIENT_INPUT ), "release input" );
+ chk_libxsmm_err( libxsmm_dnn_release_buffer( libxsmm_handle, LIBXSMM_DNN_GRADIENT_OUTPUT ), "release output" );
+ chk_libxsmm_err( libxsmm_dnn_release_filter( libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER ), "release filter" );
+ }
chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_input), "Destroy input");
chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_output), "Destroy output");
chk_libxsmm_err(libxsmm_dnn_destroy_filter(libxsmm_filter), "Destroy filter");
- if(kind != LIBXSMM_DNN_CONV_KIND_FWD)
- chk_libxsmm_err(libxsmm_dnn_destroy_conv_handle(libxsmm_handle),
+ if(kind != LIBXSMM_DNN_COMPUTE_KIND_FWD)
+ chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(libxsmm_handle),
"Destroy handle");
+
libxsmm_free(native_filter);
+ libxsmm_free(scratch);
return true; // Succeeded
}
@@ -315,7 +348,7 @@ template <typename T>
struct XsmmFwdConv2D<CPUDevice, T> {
bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc,
const T* input, const T* filter, T* output) {
- return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_CONV_KIND_FWD, input,
+ return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_FWD, input,
filter, output);
}
};
@@ -324,7 +357,7 @@ template <typename T>
struct XsmmBkwInputConv2D<CPUDevice, T> {
bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc,
T* input, const T* filter, const T* output) {
- return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_CONV_KIND_BWD, input,
+ return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_BWD, input,
filter, output);
}
};
@@ -333,7 +366,7 @@ template <typename T>
struct XsmmBkwFilterConv2D<CPUDevice, T> {
bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc,
const T* input, T* filter, const T* output) {
- return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_CONV_KIND_UPD, input,
+ return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_UPD, input,
filter, output);
}
};