diff options
Diffstat (limited to 'tensorflow/core/kernels/conv_grad_filter_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/conv_grad_filter_ops.cc | 115 |
1 files changed, 115 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc index 2e385f2c55..f88862bfeb 100644 --- a/tensorflow/core/kernels/conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc @@ -30,6 +30,9 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" #include "tensorflow/core/kernels/conv_2d.h" +#ifdef TENSORFLOW_USE_LIBXSMM +#include "tensorflow/core/kernels/xsmm_conv2d.h" +#endif #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -88,6 +91,75 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_LIBXSMM +template <typename Device, class T> +struct LaunchXsmmBackwardFilter { + bool operator()(OpKernelContext* context, const Device& d, + typename TTypes<T, 4>::ConstTensor input_backward, + typename TTypes<T, 4>::Tensor kernel, + typename TTypes<T, 4>::ConstTensor output_backward, + int input_rows, int input_cols, int row_stride, + int col_stride, int pad_h, int pad_w, + TensorFormat data_format) const { + return false; + } +}; + +template <> +struct LaunchXsmmBackwardFilter<CPUDevice, float> { + bool operator()(OpKernelContext* context, const CPUDevice& d, + typename TTypes<float, 4>::ConstTensor input, + typename TTypes<float, 4>::Tensor filter, + typename TTypes<float, 4>::ConstTensor output, int input_rows, + int input_cols, int row_stride, int col_stride, int pad_h, + int pad_w, TensorFormat data_format) const { + auto batch = input.dimension(0); + auto in_depth = input.dimension(3); + auto out_depth = output.dimension(3); + auto filter_rows = filter.dimension(0); + auto filter_cols = filter.dimension(1); + + auto num_threads = + context->device()->tensorflow_cpu_worker_threads()->num_threads; + // See libxsmm_dnn.h for this struct definition. + libxsmm_dnn_conv_desc desc; + desc.N = batch; + desc.C = in_depth; + desc.H = input_rows; + desc.W = input_cols; + desc.K = out_depth; + desc.R = filter_rows; + desc.S = filter_cols; + desc.u = row_stride; + desc.v = col_stride; + desc.pad_h = pad_h; + desc.pad_w = pad_w; + desc.pad_h_in = 0; // pad_rows; // ignored by libxsmm for now. + desc.pad_w_in = 0; // pad_cols; // ignored by libxsmm for now. + desc.pad_h_out = 0; + desc.pad_w_out = 0; + desc.threads = num_threads; + desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT; + desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC; + desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_RSCK; + desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE; + desc.options = LIBXSMM_DNN_CONV_OPTION_NONE; + desc.datatype = LIBXSMM_DNN_DATATYPE_F32; + + if (!CanUseXsmmConv2D(desc, data_format)) { + return false; + } + + auto input_ptr = input.data(); + auto filter_ptr = filter.data(); + auto output_ptr = output.data(); + bool success = functor::XsmmBkwFilterConv2D<CPUDevice, float>()( + context, desc, input_ptr, filter_ptr, output_ptr); + return success; + } +}; +#endif + template <typename Device, class T> class Conv2DFastBackpropFilterOp : public OpKernel { public: @@ -135,6 +207,36 @@ class Conv2DFastBackpropFilterOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output(0, filter_shape, &filter_backprop)); +#if defined TENSORFLOW_USE_LIBXSMM && defined TENSORFLOW_USE_LIBXSMM_BACKWARD + + int64 pad_top, pad_bottom; + int64 pad_left, pad_right; + OP_REQUIRES_OK( + context, + GetWindowedOutputSizeVerbose( + dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size, + dims.spatial_dims[0].stride, padding_, + &dims.spatial_dims[0].output_size, &pad_top, &pad_bottom)); + OP_REQUIRES_OK( + context, + GetWindowedOutputSizeVerbose( + dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size, + dims.spatial_dims[1].stride, padding_, + &dims.spatial_dims[1].output_size, &pad_left, &pad_right)); + + if (pad_left == pad_right && pad_top == pad_bottom) { + if (LaunchXsmmBackwardFilter<Device, T>()( + context, context->eigen_device<Device>(), input.tensor<T, 4>(), + filter_backprop->tensor<T, 4>(), out_backprop.tensor<T, 4>(), + dims.spatial_dims[0].input_size, dims.spatial_dims[1].input_size, + (int)dims.spatial_dims[0].stride, + (int)dims.spatial_dims[1].stride, (int)pad_top, (int)pad_left, + data_format_)) { + return; + } + } +#endif + functor::SpatialConvolutionBackwardKernel<Device, T>()( context->eigen_device<Device>(), filter_backprop->tensor<T, 4>(), input.tensor<T, 4>(), out_backprop.tensor<T, 4>(), @@ -213,6 +315,19 @@ class Conv2DCustomBackpropFilterOp : public OpKernel { dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size, dims.spatial_dims[1].stride, padding_, &dims.spatial_dims[1].output_size, &pad_left, &pad_right)); +#if defined TENSORFLOW_USE_LIBXSMM && defined TENSORFLOW_USE_LIBXSMM_BACKWARD + if (pad_left == pad_right && pad_top == pad_bottom) { + if (LaunchXsmmBackwardFilter<Device, T>()( + context, context->eigen_device<Device>(), input.tensor<T, 4>(), + filter_backprop->tensor<T, 4>(), out_backprop.tensor<T, 4>(), + dims.spatial_dims[0].input_size, dims.spatial_dims[1].input_size, + (int)dims.spatial_dims[0].stride, + (int)dims.spatial_dims[1].stride, (int)pad_top, (int)pad_left, + data_format_)) { + return; + } + } +#endif // The total dimension size of each kernel. const int filter_total_size = dims.spatial_dims[0].filter_size * |