diff options
Diffstat (limited to 'tensorflow/core/kernels/conv_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/conv_ops.cc | 10 |
1 files changed, 7 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc index facfe4467d..8076daf387 100644 --- a/tensorflow/core/kernels/conv_ops.cc +++ b/tensorflow/core/kernels/conv_ops.cc @@ -213,8 +213,8 @@ class LaunchXsmmConvOp<CPUDevice, float> { desc.v = stride_cols; desc.pad_h = pad_rows; desc.pad_w = pad_cols; - desc.pad_h_in = pad_rows; // libxsmm supports only physical padding for now - desc.pad_w_in = pad_cols; // libxsmm supports only physical padding for now + desc.pad_h_in = 0; + desc.pad_w_in = 0; desc.pad_h_out = 0; desc.pad_w_out = 0; desc.threads = num_threads; @@ -222,13 +222,17 @@ class LaunchXsmmConvOp<CPUDevice, float> { desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC; desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM; desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE; - desc.options = LIBXSMM_DNN_CONV_OPTION_NONE; + desc.options = LIBXSMM_DNN_CONV_OPTION_WU_EXT_FILTER_REDUCE; desc.datatype = LIBXSMM_DNN_DATATYPE_F32; if (!CanUseXsmmConv2D(desc, data_format)) { return false; } + if (!CanUseXsmmConv2D(desc, data_format)) { + return false; + } + auto input_ptr = input.template flat<float>().data(); auto filter_ptr = filter.template flat<float>().data(); auto output_ptr = output->template flat<float>().data(); |