aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/xsmm_conv2d.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/xsmm_conv2d.cc')
-rw-r--r--tensorflow/core/kernels/xsmm_conv2d.cc16
1 files changed, 8 insertions, 8 deletions
diff --git a/tensorflow/core/kernels/xsmm_conv2d.cc b/tensorflow/core/kernels/xsmm_conv2d.cc
index 878abe9712..7936cbcd46 100644
--- a/tensorflow/core/kernels/xsmm_conv2d.cc
+++ b/tensorflow/core/kernels/xsmm_conv2d.cc
@@ -145,8 +145,8 @@ struct HashFunction {
S << w.d.S; u << w.d.u;
v << w.d.v; padh << w.d.pad_h_in;
padw << w.d.pad_w_in;
-
-
+
+
std::string out_ = N.str() + C.str()\
+ H.str() + W.str()\
+ K.str() + R.str()\
@@ -172,8 +172,9 @@ class handles {
chk_libxsmm_err(status, "Create handle");
libxsmm_handles.insert(std::make_pair(w, libxsmm_handle));
return libxsmm_handle;
- } else
+ } else {
return i->second;
+ }
}
~handles() {
std::unordered_map<libxsmm_dnn_conv_desc_wrap, libxsmm_dnn_layer*,
@@ -191,7 +192,7 @@ class handles {
static handles libxsmm_handles;
-//#define LIBXSMM_DETAILED_TIMING
+// #define LIBXSMM_DETAILED_TIMING
template <typename InputPtr, typename FilterPtr, typename OutputPtr>
static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
@@ -287,9 +288,8 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
}
count.Wait();
}
- }
- // Added: for weight update
- else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
+ } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
+ // Added: for weight update
libxsmm_filter =
libxsmm_dnn_link_filter(libxsmm_handle, LIBXSMM_DNN_FILTER, filter,
LIBXSMM_DNN_TENSOR_FORMAT_RSCK_PTR, &status);
@@ -352,7 +352,7 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input,
LIBXSMM_DNN_REGULAR_INPUT),
- "Bind input weight udpate");
+ "Bind input weight update");
chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output,
LIBXSMM_DNN_GRADIENT_OUTPUT),
"Bind output weight update");