about summary refs log tree commit diff homepage
path: root/tensorflow/core/kernels/conv_grad_filter_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/conv_grad_filter_ops.cc')
-rw-r--r--  tensorflow/core/kernels/conv_grad_filter_ops.cc | 115
1 file changed, 115 insertions(+), 0 deletions(-)
diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc
index 2e385f2c55..f88862bfeb 100644
--- a/tensorflow/core/kernels/conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc
@@ -30,6 +30,9 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_2d.h"
+#ifdef TENSORFLOW_USE_LIBXSMM
+#include "tensorflow/core/kernels/xsmm_conv2d.h"
+#endif
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@@ -88,6 +91,75 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_LIBXSMM
+template <typename Device, class T>  // Generic fallback: no libxsmm path for this Device/T pair.
+struct LaunchXsmmBackwardFilter {
+ bool operator()(OpKernelContext* context, const Device& d,
+ typename TTypes<T, 4>::ConstTensor input_backward,
+ typename TTypes<T, 4>::Tensor kernel,
+ typename TTypes<T, 4>::ConstTensor output_backward,
+ int input_rows, int input_cols, int row_stride,
+ int col_stride, int pad_h, int pad_w,
+ TensorFormat data_format) const {
+ return false;  // Always declines, so the caller falls through to the default Eigen implementation.
+ }
+};
+
+template <>  // CPU/float specialization: dispatches the filter-gradient computation to libxsmm.
+struct LaunchXsmmBackwardFilter<CPUDevice, float> {
+ bool operator()(OpKernelContext* context, const CPUDevice& d,
+ typename TTypes<float, 4>::ConstTensor input,
+ typename TTypes<float, 4>::Tensor filter,
+ typename TTypes<float, 4>::ConstTensor output, int input_rows,
+ int input_cols, int row_stride, int col_stride, int pad_h,
+ int pad_w, TensorFormat data_format) const {
+ auto batch = input.dimension(0);  // N of the NHWC input tensor.
+ auto in_depth = input.dimension(3);  // input channels (C); assumes NHWC layout — dimension 3 is channels.
+ auto out_depth = output.dimension(3);  // output channels (K).
+ auto filter_rows = filter.dimension(0);  // R
+ auto filter_cols = filter.dimension(1);  // S
+
+ auto num_threads =
+ context->device()->tensorflow_cpu_worker_threads()->num_threads;
+ // See libxsmm_dnn.h for this struct definition.
+ libxsmm_dnn_conv_desc desc;
+ desc.N = batch;  // minibatch size
+ desc.C = in_depth;  // input feature maps
+ desc.H = input_rows;  // input height
+ desc.W = input_cols;  // input width
+ desc.K = out_depth;  // output feature maps
+ desc.R = filter_rows;  // filter height
+ desc.S = filter_cols;  // filter width
+ desc.u = row_stride;  // vertical stride
+ desc.v = col_stride;  // horizontal stride
+ desc.pad_h = pad_h;  // symmetric padding only — callers verify pad_top == pad_bottom before invoking.
+ desc.pad_w = pad_w;  // likewise pad_left == pad_right.
+ desc.pad_h_in = 0; // pad_rows; // ignored by libxsmm for now.
+ desc.pad_w_in = 0; // pad_cols; // ignored by libxsmm for now.
+ desc.pad_h_out = 0;
+ desc.pad_w_out = 0;
+ desc.threads = num_threads;  // reuse the TF CPU worker-thread count for libxsmm parallelism.
+ desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT;
+ desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC;
+ desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_RSCK;
+ desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
+ desc.options = LIBXSMM_DNN_CONV_OPTION_NONE;
+ desc.datatype = LIBXSMM_DNN_DATATYPE_F32;
+
+ if (!CanUseXsmmConv2D(desc, data_format)) {
+ return false;  // Descriptor/format unsupported by the xsmm path; caller uses the default kernel.
+ }
+
+ auto input_ptr = input.data();
+ auto filter_ptr = filter.data();
+ auto output_ptr = output.data();
+ bool success = functor::XsmmBkwFilterConv2D<CPUDevice, float>()(
+ context, desc, input_ptr, filter_ptr, output_ptr);
+ return success;  // true means libxsmm produced the gradient; false means fall back.
+ }
+};
+#endif
+
template <typename Device, class T>
class Conv2DFastBackpropFilterOp : public OpKernel {
public:
@@ -135,6 +207,36 @@ class Conv2DFastBackpropFilterOp : public OpKernel {
OP_REQUIRES_OK(context,
context->allocate_output(0, filter_shape, &filter_backprop));
+#if defined TENSORFLOW_USE_LIBXSMM && defined TENSORFLOW_USE_LIBXSMM_BACKWARD
+
+ int64 pad_top, pad_bottom;
+ int64 pad_left, pad_right;
+ OP_REQUIRES_OK(
+ context,
+ GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
+ dims.spatial_dims[0].stride, padding_,
+ &dims.spatial_dims[0].output_size, &pad_top, &pad_bottom));
+ OP_REQUIRES_OK(
+ context,
+ GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
+ dims.spatial_dims[1].stride, padding_,
+ &dims.spatial_dims[1].output_size, &pad_left, &pad_right));
+
+ if (pad_left == pad_right && pad_top == pad_bottom) {
+ if (LaunchXsmmBackwardFilter<Device, T>()(
+ context, context->eigen_device<Device>(), input.tensor<T, 4>(),
+ filter_backprop->tensor<T, 4>(), out_backprop.tensor<T, 4>(),
+ dims.spatial_dims[0].input_size, dims.spatial_dims[1].input_size,
+ (int)dims.spatial_dims[0].stride,
+ (int)dims.spatial_dims[1].stride, (int)pad_top, (int)pad_left,
+ data_format_)) {
+ return;
+ }
+ }
+#endif
+
functor::SpatialConvolutionBackwardKernel<Device, T>()(
context->eigen_device<Device>(), filter_backprop->tensor<T, 4>(),
input.tensor<T, 4>(), out_backprop.tensor<T, 4>(),
@@ -213,6 +315,19 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
dims.spatial_dims[1].stride, padding_,
&dims.spatial_dims[1].output_size, &pad_left, &pad_right));
+#if defined TENSORFLOW_USE_LIBXSMM && defined TENSORFLOW_USE_LIBXSMM_BACKWARD
+ if (pad_left == pad_right && pad_top == pad_bottom) {
+ if (LaunchXsmmBackwardFilter<Device, T>()(
+ context, context->eigen_device<Device>(), input.tensor<T, 4>(),
+ filter_backprop->tensor<T, 4>(), out_backprop.tensor<T, 4>(),
+ dims.spatial_dims[0].input_size, dims.spatial_dims[1].input_size,
+ (int)dims.spatial_dims[0].stride,
+ (int)dims.spatial_dims[1].stride, (int)pad_top, (int)pad_left,
+ data_format_)) {
+ return;
+ }
+ }
+#endif
// The total dimension size of each kernel.
const int filter_total_size = dims.spatial_dims[0].filter_size *