aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/conv_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/conv_ops.cc')
-rw-r--r--tensorflow/core/kernels/conv_ops.cc10
1 files changed, 7 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index facfe4467d..8076daf387 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -213,8 +213,8 @@ class LaunchXsmmConvOp<CPUDevice, float> {
desc.v = stride_cols;
desc.pad_h = pad_rows;
desc.pad_w = pad_cols;
- desc.pad_h_in = pad_rows; // libxsmm supports only physical padding for now
- desc.pad_w_in = pad_cols; // libxsmm supports only physical padding for now
+ desc.pad_h_in = 0;
+ desc.pad_w_in = 0;
desc.pad_h_out = 0;
desc.pad_w_out = 0;
desc.threads = num_threads;
@@ -222,13 +222,17 @@ class LaunchXsmmConvOp<CPUDevice, float> {
desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC;
desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;
desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
- desc.options = LIBXSMM_DNN_CONV_OPTION_NONE;
+ desc.options = LIBXSMM_DNN_CONV_OPTION_WU_EXT_FILTER_REDUCE;
desc.datatype = LIBXSMM_DNN_DATATYPE_F32;
if (!CanUseXsmmConv2D(desc, data_format)) {
return false;
}
+ if (!CanUseXsmmConv2D(desc, data_format)) {
+ return false;
+ }
+
auto input_ptr = input.template flat<float>().data();
auto filter_ptr = filter.template flat<float>().data();
auto output_ptr = output->template flat<float>().data();