diff options
Diffstat (limited to 'tensorflow/core/kernels/xsmm_conv2d.cc')
-rw-r--r-- | tensorflow/core/kernels/xsmm_conv2d.cc | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/tensorflow/core/kernels/xsmm_conv2d.cc b/tensorflow/core/kernels/xsmm_conv2d.cc index 878abe9712..7936cbcd46 100644 --- a/tensorflow/core/kernels/xsmm_conv2d.cc +++ b/tensorflow/core/kernels/xsmm_conv2d.cc @@ -145,8 +145,8 @@ struct HashFunction { S << w.d.S; u << w.d.u; v << w.d.v; padh << w.d.pad_h_in; padw << w.d.pad_w_in; - - + + std::string out_ = N.str() + C.str()\ + H.str() + W.str()\ + K.str() + R.str()\ @@ -172,8 +172,9 @@ class handles { chk_libxsmm_err(status, "Create handle"); libxsmm_handles.insert(std::make_pair(w, libxsmm_handle)); return libxsmm_handle; - } else + } else { return i->second; + } } ~handles() { std::unordered_map<libxsmm_dnn_conv_desc_wrap, libxsmm_dnn_layer*, @@ -191,7 +192,7 @@ class handles { static handles libxsmm_handles; -//#define LIBXSMM_DETAILED_TIMING +// #define LIBXSMM_DETAILED_TIMING template <typename InputPtr, typename FilterPtr, typename OutputPtr> static bool CallLibxsmmConvGeneric(OpKernelContext* ctx, @@ -287,9 +288,8 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx, } count.Wait(); } - } - // Added: for weight update - else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) { + } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) { + // Added: for weight update libxsmm_filter = libxsmm_dnn_link_filter(libxsmm_handle, LIBXSMM_DNN_FILTER, filter, LIBXSMM_DNN_TENSOR_FORMAT_RSCK_PTR, &status); @@ -352,7 +352,7 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx, chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input, LIBXSMM_DNN_REGULAR_INPUT), - "Bind input weight udpate"); + "Bind input weight update"); chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output, LIBXSMM_DNN_GRADIENT_OUTPUT), "Bind output weight update"); |