Diffstat (limited to 'tensorflow/core/kernels/conv_grad_ops_3d.cc')
-rw-r--r-- | tensorflow/core/kernels/conv_grad_ops_3d.cc | 1324
1 file changed, 976 insertions(+), 348 deletions(-)
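The hunks below replace the macro-based shape verification and the pad-and-shuffle backprop with an im2col/col2im + GEMM formulation on the CPU. As a primer, here is a minimal 1-D sketch (hypothetical helper names, not part of the patch) of the gather/scatter pair that the new Im2col and Col2im helpers generalize to three spatial dimensions: Im2col gathers sliding-window patches into a column buffer, and Col2im is its adjoint, accumulating columns back into the image.

#include <cstdio>
#include <vector>

// Gather: col[w * filter + k] = in[w * stride + k - pad], zero outside image.
void Im2col1D(const std::vector<float>& in, int filter, int pad, int stride,
              std::vector<float>* col) {
  const int size = static_cast<int>(in.size());
  const int out = (size + 2 * pad - filter) / stride + 1;
  col->assign(out * filter, 0.0f);
  for (int w = 0; w < out; ++w) {
    for (int k = 0; k < filter; ++k) {
      const int iw = w * stride + k - pad;
      if (iw >= 0 && iw < size) (*col)[w * filter + k] = in[iw];
    }
  }
}

// Scatter-add: the adjoint of the gather above; overlapping windows sum.
void Col2im1D(const std::vector<float>& col, int size, int filter, int pad,
              int stride, std::vector<float>* im) {
  const int out = (size + 2 * pad - filter) / stride + 1;
  im->assign(size, 0.0f);
  for (int w = 0; w < out; ++w) {
    for (int k = 0; k < filter; ++k) {
      const int iw = w * stride + k - pad;
      if (iw >= 0 && iw < size) (*im)[iw] += col[w * filter + k];
    }
  }
}

int main() {
  std::vector<float> in = {1, 2, 3, 4, 5}, col, back;
  Im2col1D(in, /*filter=*/3, /*pad=*/1, /*stride=*/1, &col);
  Col2im1D(col, static_cast<int>(in.size()), 3, 1, 1, &back);
  // Round-tripping scales each element by the number of windows covering it.
  for (float v : back) std::printf("%g ", v);
  std::printf("\n");
  return 0;
}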
diff --git a/tensorflow/core/kernels/conv_grad_ops_3d.cc b/tensorflow/core/kernels/conv_grad_ops_3d.cc index 15f1bf9aba..d26b86c712 100644 --- a/tensorflow/core/kernels/conv_grad_ops_3d.cc +++ b/tensorflow/core/kernels/conv_grad_ops_3d.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" #include "tensorflow/core/kernels/conv_2d.h" +#include "tensorflow/core/kernels/conv_grad_ops.h" #include "tensorflow/core/kernels/conv_ops_gpu.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/errors.h" @@ -32,111 +33,130 @@ limitations under the License. #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" #include "tensorflow/core/util/use_cudnn.h" +#include "tensorflow/core/util/work_sharder.h" #if GOOGLE_CUDA #include "tensorflow/core/platform/stream_executor.h" using stream_executor::dnn::DimIndex; #endif +namespace { + +// TODO(ezhulenev): Split this file into conv_grad_filter_ops_3d.cc and +// conv_grad_input_ops_3d.cc. + +// TODO(ezhulenev): Generalize Col2im and Im2col for 2-d and 3-d kernels. + +// "Depth" is already used for the channel dimension, so for the third spatial +// dimension in this file we use "plane", although in NDHWC layout it's +// indicated with a "D". + +// Returns in 'im_data' (assumed to be zero-initialized) image patch in storage +// order (planes, height, width, depth), constructed from patches in 'col_data', +// which is required to be in storage order (out_planes * out_height * +// out_width, filter_planes, filter_height, filter_width, in_depth). +// +// Based on 2-dimensional implementation written by Yangqing Jia (jiayq). +template <typename T> +void Col2im(const T* col_data, const int depth, const int planes, + const int height, const int width, const int filter_p, + const int filter_h, const int filter_w, const int pad_pt, + const int pad_t, const int pad_l, const int pad_pb, const int pad_b, + const int pad_r, const int stride_p, const int stride_h, + const int stride_w, T* im_data) { + const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1; + const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1; + const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1; + int p_pad = -pad_pt; + for (int p = 0; p < planes_col; ++p) { + int h_pad = -pad_t; + for (int h = 0; h < height_col; ++h) { + int w_pad = -pad_l; + for (int w = 0; w < width_col; ++w) { + T* im_patch_data = + im_data + (p_pad * height * width + h_pad * width + w_pad) * depth; + for (int ip = p_pad; ip < p_pad + filter_p; ++ip) { + for (int ih = h_pad; ih < h_pad + filter_h; ++ih) { + for (int iw = w_pad; iw < w_pad + filter_w; ++iw) { + if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 && + iw < width) { + for (int i = 0; i < depth; ++i) { + im_patch_data[i] += col_data[i]; + } + } + im_patch_data += depth; + col_data += depth; + } + // Jump over remaining number of depth. + im_patch_data += depth * (width - filter_w); + } + // Jump over remaining number of (depth * width). + im_patch_data += (depth * width) * (height - filter_h); + } + w_pad += stride_w; + } + h_pad += stride_h; + } + p_pad += stride_p; + } +} + +// Returns in 'col_data', image patches in storage order (planes, height, width, +// depth) extracted from image at 'input_data', which is required to be in +// storage order (batch, planes, height, width, depth). 
+// +// Based on 2-dimensional implementation written by Yangqing Jia (jiayq). +template <typename T> +void Im2col(const T* input_data, const int depth, const int planes, + const int height, const int width, const int filter_p, + const int filter_h, const int filter_w, const int pad_pt, + const int pad_t, const int pad_l, const int pad_pb, const int pad_b, + const int pad_r, const int stride_p, const int stride_h, + const int stride_w, T* col_data) { + const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1; + const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1; + const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1; + + int p_pad = -pad_pt; + for (int p = 0; p < planes_col; ++p) { + int h_pad = -pad_t; + for (int h = 0; h < height_col; ++h) { + int w_pad = -pad_l; + for (int w = 0; w < width_col; ++w) { + for (int ip = p_pad; ip < p_pad + filter_p; ++ip) { + for (int ih = h_pad; ih < h_pad + filter_h; ++ih) { + for (int iw = w_pad; iw < w_pad + filter_w; ++iw) { + if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 && + iw < width) { + memcpy(col_data, + input_data + + (ip * height * width + ih * width + iw) * depth, + sizeof(T) * depth); + } else { + // This should be simply padded with zero. + memset(col_data, 0, sizeof(T) * depth); + } + col_data += depth; + } + } + } + w_pad += stride_w; + } + h_pad += stride_h; + } + p_pad += stride_p; + } +} + +} // namespace + namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -// TODO(mjanusz): Get rid of the macro and return shapes directly. -#define EXTRACT_AND_VERIFY_DIMENSIONS(label) \ - const Tensor& out_backprop = context->input(2); \ - OP_REQUIRES( \ - context, input_shape.dims() == 5, \ - errors::InvalidArgument(label, ": input must be 5-dimensional")); \ - OP_REQUIRES( \ - context, filter_shape.dims() == 5, \ - errors::InvalidArgument(label, ": filter must be 5-dimensional")); \ - OP_REQUIRES( \ - context, out_backprop.dims() == 5, \ - errors::InvalidArgument(label, ": out_backprop must be 5-dimensional")); \ - const int64 batch = input_shape.dim_size(0); \ - OP_REQUIRES( \ - context, batch == out_backprop.dim_size(0), \ - errors::InvalidArgument( \ - label, ": input and out_backprop must have the same batch size")); \ - const std::array<int64, 3> input_size = { \ - {GetTensorDim(input_shape, data_format_, '0'), \ - GetTensorDim(input_shape, data_format_, '1'), \ - GetTensorDim(input_shape, data_format_, '2')}}; \ - const int64 in_depth = GetTensorDim(input_shape, data_format_, 'C'); \ - const std::array<int64, 3> filter_size = {{filter_shape.dim_size(0), \ - filter_shape.dim_size(1), \ - filter_shape.dim_size(2)}}; \ - const int64 output_cols = GetTensorDim(out_backprop, data_format_, '2'); \ - const int64 output_rows = GetTensorDim(out_backprop, data_format_, '1'); \ - const int64 output_planes = GetTensorDim(out_backprop, data_format_, '0'); \ - OP_REQUIRES(context, in_depth == filter_shape.dim_size(3), \ - errors::InvalidArgument( \ - label, ": input and filter must have the same depth")); \ - const int64 out_depth = filter_shape.dim_size(4); \ - OP_REQUIRES( \ - context, out_depth == GetTensorDim(out_backprop, data_format_, 'C'), \ - errors::InvalidArgument( \ - label, ": filter and out_backprop must have the same out_depth")); \ - const std::array<int64, 3> dilations = { \ - {GetTensorDim(dilation_, data_format_, '0'), \ - GetTensorDim(dilation_, data_format_, '1'), \ - GetTensorDim(dilation_, data_format_, 
'2')}}; \ - const std::array<int64, 3> strides = { \ - {GetTensorDim(stride_, data_format_, '0'), \ - GetTensorDim(stride_, data_format_, '1'), \ - GetTensorDim(stride_, data_format_, '2')}}; \ - std::array<int64, 3> out, padding; \ - OP_REQUIRES_OK( \ - context, Get3dOutputSizeV2(input_size, filter_size, dilations, strides, \ - padding_, &out, &padding)); \ - OP_REQUIRES(context, output_planes == out[0], \ - errors::InvalidArgument( \ - label, \ - ": Number of planes of out_backprop doesn't match " \ - "computed: actual = ", \ - output_planes, ", computed = ", out[0])); \ - OP_REQUIRES( \ - context, output_rows == out[1], \ - errors::InvalidArgument( \ - label, ": Number of rows of out_backprop doesn't match computed: ", \ - "actual = ", output_rows, ", computed = ", out[1])); \ - OP_REQUIRES( \ - context, output_cols == out[2], \ - errors::InvalidArgument( \ - label, ": Number of cols of out_backprop doesn't match computed: ", \ - "actual = ", output_cols, ", computed = ", out[2])); \ - const auto expanded_out_planes = (output_planes - 1) * strides[0] + 1; \ - const auto expanded_out_rows = (output_rows - 1) * strides[1] + 1; \ - const auto expanded_out_cols = (output_cols - 1) * strides[2] + 1; \ - const auto padded_out_planes = input_size[0] + filter_size[0] - 1; \ - const auto padded_out_rows = input_size[1] + filter_size[1] - 1; \ - const auto padded_out_cols = input_size[2] + filter_size[2] - 1; \ - const auto top_pad_planes = filter_size[0] - 1 - padding[0]; \ - const auto top_pad_rows = filter_size[1] - 1 - padding[1]; \ - const auto left_pad_cols = filter_size[2] - 1 - padding[2]; \ - const auto bottom_pad_planes = \ - padded_out_planes - expanded_out_planes - top_pad_planes; \ - const auto bottom_pad_rows = \ - padded_out_rows - expanded_out_rows - top_pad_rows; \ - const auto right_pad_cols = \ - padded_out_cols - expanded_out_cols - left_pad_cols; \ - VLOG(2) << "Conv3d: " << label \ - << ": expanded_out_planes = " << expanded_out_planes \ - << ": expanded_out_rows = " << expanded_out_rows \ - << ", expanded_out_cols = " << expanded_out_cols \ - << ", padded_out_planes = " << padded_out_planes \ - << ", padded_out_rows = " << padded_out_rows \ - << ", padded_out_cols = " << padded_out_cols \ - << ", top_pad_planes = " << top_pad_planes \ - << ", top_pad_rows = " << top_pad_rows \ - << ", left_pad_cols = " << left_pad_cols \ - << ", bottom_pad_planes = " << bottom_pad_planes \ - << ", bottom_pad_rows = " << bottom_pad_rows \ - << ", right_pad_cols = " << right_pad_cols - -// Backprop for input. +// Backprop for input that offloads computation to +// Eigen::CuboidConvolutionBackwardInput. 
template <typename Device, class T> class Conv3DBackpropInputOp : public OpKernel { public: @@ -192,6 +212,10 @@ class Conv3DBackpropInputOp : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& filter = context->input(1); const TensorShape& filter_shape = filter.shape(); + + const Tensor& out_backprop = context->input(2); + const TensorShape& out_backprop_shape = out_backprop.shape(); + TensorShape input_shape; if (takes_shape_) { const Tensor& input_sizes = context->input(0); @@ -200,51 +224,345 @@ class Conv3DBackpropInputOp : public OpKernel { } else { input_shape = context->input(0).shape(); } - EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropInput"); - Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 5> pad_dims{ - {0, 0}, - {top_pad_planes, bottom_pad_planes}, - {top_pad_rows, bottom_pad_rows}, - {left_pad_cols, right_pad_cols}, - {0, 0}}; + + ConvBackpropDimensions dims; + OP_REQUIRES_OK(context, ConvBackpropComputeDimensions( + "Conv3DBackpropInputOp", /*num_spatial_dims=*/3, + input_shape, filter_shape, out_backprop_shape, + stride_, padding_, data_format_, &dims)); + Tensor* in_backprop; OP_REQUIRES_OK(context, context->allocate_output(0, input_shape, &in_backprop)); - // Fill out a padded out_backprop. - TensorShape padded_out_shape({batch, padded_out_planes, padded_out_rows, - padded_out_cols, out_depth}); - Tensor padded_output; + functor::CuboidConvolutionBackwardInput<Device, T>()( + context->eigen_device<Device>(), + in_backprop->tensor<T, 5>(), // input_backward + filter.tensor<T, 5>(), // filter + out_backprop.tensor<T, 5>(), // output_backward + static_cast<int>(dims.spatial_dims[0].stride), // stride_planes + static_cast<int>(dims.spatial_dims[1].stride), // stride_rows + static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols + } + + private: + std::vector<int32> dilation_; + std::vector<int32> stride_; + Padding padding_; + TensorFormat data_format_; + bool takes_shape_; + + TF_DISALLOW_COPY_AND_ASSIGN(Conv3DBackpropInputOp); +}; + +// Custom backprop for input that explicitly does the work sharding and calls +// Eigen only to multiply matrices. +template <typename Device, class T> +class Conv3DCustomBackpropInputOp : public OpKernel { + // Limit the maximum size of allocated temporary buffer to + // kMaxTempAllocationOverhead times the size of the input tensors (input, + // filter, out_backprop). If the size of the temporary buffer exceeds this + // limit, fallback on Eigen implementation. + static constexpr int kMaxTempAllocationOverhead = 25; + + public: + explicit Conv3DCustomBackpropInputOp(OpKernelConstruction* context) + : OpKernel(context), + data_format_(FORMAT_NHWC), + takes_shape_(type_string().find("V2") != std::string::npos) { + // data_format is only available in V2. 
+ if (takes_shape_) { + string data_format; + OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); + OP_REQUIRES(context, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + OP_REQUIRES( + context, data_format_ == FORMAT_NHWC, + errors::InvalidArgument( + "Conv3DBackpropInputOpV2 only supports NDHWC on the CPU.")); + } + + OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_)); + OP_REQUIRES(context, dilation_.size() == 5, + errors::InvalidArgument("Dilation rates field must " + "specify 5 dimensions")); + OP_REQUIRES(context, + (GetTensorDim(dilation_, data_format_, 'C') == 1 && + GetTensorDim(dilation_, data_format_, 'N') == 1), + errors::InvalidArgument( + "Current implementation does not yet support " + "dilation rates in the batch and depth dimensions.")); + + // TODO(yangzihao): Add CPU version of dilated conv 3D. + OP_REQUIRES(context, + (GetTensorDim(dilation_, data_format_, '0') == 1 && + GetTensorDim(dilation_, data_format_, '1') == 1 && + GetTensorDim(dilation_, data_format_, '2') == 1), + errors::InvalidArgument( + "Current CPU implementation does not yet support " + "dilation rates larger than 1.")); + + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); + OP_REQUIRES(context, stride_.size() == 5, + errors::InvalidArgument("Sliding window strides field must " + "specify 5 dimensions")); + OP_REQUIRES( + context, + (GetTensorDim(stride_, data_format_, 'C') == 1 && + GetTensorDim(stride_, data_format_, 'N') == 1), + errors::InvalidArgument("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& filter = context->input(1); + const TensorShape& filter_shape = filter.shape(); + + const Tensor& out_backprop = context->input(2); + const TensorShape& out_backprop_shape = out_backprop.shape(); + + TensorShape input_shape; + if (takes_shape_) { + const Tensor& input_sizes = context->input(0); + // MakeShape is able to handle both DT_INT32 and DT_INT64 for input_sizes. + OP_REQUIRES_OK(context, MakeShape(input_sizes, &input_shape)); + } else { + input_shape = context->input(0).shape(); + } + + ConvBackpropDimensions dims; + OP_REQUIRES_OK(context, ConvBackpropComputeDimensions( + "Conv3DBackpropInputOp", /*num_spatial_dims=*/3, + input_shape, filter_shape, out_backprop_shape, + stride_, padding_, data_format_, &dims)); + + Tensor* in_backprop; OP_REQUIRES_OK(context, - context->allocate_temp(DataTypeToEnum<T>::v(), - padded_out_shape, &padded_output)); - Eigen::DSizes<Eigen::DenseIndex, 5> no_op_shuffle{0, 1, 2, 3, 4}; - Eigen::DSizes<Eigen::DenseIndex, 5> eigen_strides{1, strides[0], strides[1], - strides[2], 1}; - functor::InflatePadAndShuffle<Device, T, 5, Eigen::DenseIndex>()( - context->eigen_device<Device>(), out_backprop.tensor<T, 5>(), - eigen_strides, pad_dims, no_op_shuffle, padded_output.tensor<T, 5>()); - const Tensor& padded_output_cref = padded_output; - - // Fill a new "reverted" filter. We need to transpose the in_depth and - // out_depth for the filter and reverse the planes, rows and cols. 
- TensorShape r_filter_shape( - {filter_size[0], filter_size[1], filter_size[2], out_depth, in_depth}); - Tensor r_filter; - OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(), - r_filter_shape, &r_filter)); - Eigen::DSizes<Eigen::DenseIndex, 5> filter_order{0, 1, 2, 4, 3}; - Eigen::array<bool, 5> filter_rev_dims{true, true, true, false, false}; - functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()( - context->eigen_device<Device>(), filter.tensor<T, 5>(), filter_order, - filter_rev_dims, r_filter.tensor<T, 5>()); - const Tensor& r_filter_cref = r_filter; - - // Now we can call conv_3d directly. - functor::CuboidConvolution<Device, T>()( - context->eigen_device<Device>(), in_backprop->tensor<T, 5>(), - padded_output_cref.tensor<T, 5>(), r_filter_cref.tensor<T, 5>(), 1, 1, - 1, BrainPadding2EigenPadding(VALID)); + context->allocate_output(0, input_shape, &in_backprop)); + + int64 top_pad_planes, bottom_pad_planes; + int64 top_pad_rows, bottom_pad_rows; + int64 left_pad_cols, right_pad_cols; + + OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( + dims.spatial_dims[0].input_size, + dims.spatial_dims[0].filter_size, + dims.spatial_dims[0].stride, padding_, + &dims.spatial_dims[0].output_size, + &top_pad_planes, &bottom_pad_planes)); + OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( + dims.spatial_dims[1].input_size, + dims.spatial_dims[1].filter_size, + dims.spatial_dims[1].stride, padding_, + &dims.spatial_dims[1].output_size, + &top_pad_rows, &bottom_pad_rows)); + OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( + dims.spatial_dims[2].input_size, + dims.spatial_dims[2].filter_size, + dims.spatial_dims[2].stride, padding_, + &dims.spatial_dims[2].output_size, + &left_pad_cols, &right_pad_cols)); + + // TODO(ezhulenev): Extract work size and shard estimation to shared + // functions in conv_grad_ops, and update 2d convolution backprop. + + // The total dimension size of each kernel. + const int64 filter_total_size = + dims.spatial_dims[0].filter_size * dims.spatial_dims[1].filter_size * + dims.spatial_dims[2].filter_size * dims.in_depth; + + // The output image size is the spatial size of the output. + const int64 output_image_size = dims.spatial_dims[0].output_size * + dims.spatial_dims[1].output_size * + dims.spatial_dims[2].output_size; + + const auto cache_sizes = Eigen::internal::CacheSizes(); + const ptrdiff_t l3_cache_size = cache_sizes.m_l3; + + // Use L3 cache size as target working set size. + const size_t target_working_set_size = l3_cache_size / sizeof(T); + + // Calculate size of matrices involved in MatMul: C = A x B. + const int64 size_A = output_image_size * dims.out_depth; + + const int64 size_B = filter_total_size * dims.out_depth; + + const int64 size_C = output_image_size * filter_total_size; + + const int64 work_unit_size = size_A + size_B + size_C; + + auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); + + // Use parallel tensor contractions if there is no batching. + // + // Compared to Conv2D code, this version is missing work size estimation. In + // benchmarks I didn't find a case when it's beneficial to run parallel + // contraction compared to sharding and matmuls. + const bool use_parallel_contraction = dims.batch_size == 1; + + const size_t shard_size = + use_parallel_contraction + ? 1 + : (target_working_set_size + work_unit_size - 1) / work_unit_size; + + // Total number of elements in all the tensors used by this kernel. 
+ int64 total_tensor_elements = input_shape.num_elements() + + filter_shape.num_elements() + + out_backprop_shape.num_elements(); + + // Shape of the temporary workspace buffer. + TensorShape col_buffer_shape = {static_cast<int64>(shard_size), + static_cast<int64>(output_image_size), + static_cast<int64>(filter_total_size)}; + int64 col_buffer_elements = col_buffer_shape.num_elements(); + + // If the temporary allocation overhead is too large, fallback on Eigen + // implementation which requires much less memory. + int64 col_buffer_overhead = col_buffer_elements / total_tensor_elements; + if (col_buffer_overhead > kMaxTempAllocationOverhead) { + VLOG(2) << "Fallback on Eigen implementation of Conv3DBackpropInputOp: " + "col_buffer_overhead=" + << col_buffer_overhead; + + functor::CuboidConvolutionBackwardInput<Device, T>()( + context->eigen_device<Device>(), + in_backprop->tensor<T, 5>(), // input_backward + filter.tensor<T, 5>(), // filter + out_backprop.tensor<T, 5>(), // output_backward + static_cast<int>(dims.spatial_dims[0].stride), // stride_planes + static_cast<int>(dims.spatial_dims[1].stride), // stride_rows + static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols + + return; + } + + Tensor col_buffer; + OP_REQUIRES_OK(context, + context->allocate_temp(DataTypeToEnum<T>::value, + col_buffer_shape, &col_buffer)); + + // The input offset corresponding to a single input image. + const int64 input_offset = dims.spatial_dims[0].input_size * + dims.spatial_dims[1].input_size * + dims.spatial_dims[2].input_size * dims.in_depth; + + // The output offset corresponding to a single output image. + const int64 output_offset = + dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size * + dims.spatial_dims[2].output_size * dims.out_depth; + + const T* filter_data = filter.template flat<T>().data(); + T* col_buffer_data = col_buffer.template flat<T>().data(); + const T* out_backprop_data = out_backprop.template flat<T>().data(); + + auto in_backprop_flat = in_backprop->template flat<T>(); + T* input_backprop_data = in_backprop_flat.data(); + in_backprop_flat.device(context->eigen_device<Device>()) = + in_backprop_flat.constant(T(0)); + + if (use_parallel_contraction) { + typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, + Eigen::Unaligned> + TensorMap; + typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>, + Eigen::Unaligned> + ConstTensorMap; + + // Initialize contraction dims (we need to transpose 'B' below). + Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims; + contract_dims[0].first = 1; + contract_dims[0].second = 1; + + for (int image_id = 0; image_id < dims.batch_size; ++image_id) { + // Compute gradient into col_buffer. + TensorMap C(col_buffer_data, output_image_size, filter_total_size); + + ConstTensorMap A(out_backprop_data + output_offset * image_id, + output_image_size, dims.out_depth); + ConstTensorMap B(filter_data, filter_total_size, dims.out_depth); + + C.device(context->eigen_cpu_device()) = A.contract(B, contract_dims); + + Col2im<T>(col_buffer_data, dims.in_depth, + // Input spatial dimensions. + dims.spatial_dims[0].input_size, // input planes + dims.spatial_dims[1].input_size, // input rows + dims.spatial_dims[2].input_size, // input cols + // Filter spatial dimensions. + dims.spatial_dims[0].filter_size, // filter planes + dims.spatial_dims[1].filter_size, // filter rows + dims.spatial_dims[2].filter_size, // filter cols + // Spatial padding. 
+ top_pad_planes, top_pad_rows, left_pad_cols, + bottom_pad_planes, bottom_pad_rows, right_pad_cols, + // Spatial striding. + dims.spatial_dims[0].stride, // stride planes + dims.spatial_dims[1].stride, // stride rows + dims.spatial_dims[2].stride, // stride cols + input_backprop_data); + + input_backprop_data += input_offset; + } + } else { + typedef Eigen::Map< + Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> + MatrixMap; + typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, + Eigen::RowMajor>> + ConstMatrixMap; + + for (int image_id = 0; image_id < dims.batch_size; + image_id += shard_size) { + const int shard_limit = + std::min(static_cast<int>(shard_size), + static_cast<int>(dims.batch_size) - image_id); + + auto shard = [&dims, &top_pad_planes, &top_pad_rows, &left_pad_cols, + &bottom_pad_planes, &bottom_pad_rows, &right_pad_cols, + &output_image_size, &filter_total_size, + &input_backprop_data, &col_buffer_data, + &out_backprop_data, &filter_data, &input_offset, + &output_offset, &size_C](int64 start, int64 limit) { + for (int shard_id = start; shard_id < limit; ++shard_id) { + T* im2col_buf = col_buffer_data + shard_id * size_C; + T* input_data = input_backprop_data + shard_id * input_offset; + const T* out_data = out_backprop_data + shard_id * output_offset; + + // Compute gradient into 'im2col_buf'. + MatrixMap C(im2col_buf, output_image_size, filter_total_size); + + ConstMatrixMap A(out_data, output_image_size, dims.out_depth); + ConstMatrixMap B(filter_data, filter_total_size, dims.out_depth); + + C.noalias() = A * B.transpose(); + + Col2im<T>(im2col_buf, dims.in_depth, + // Input spatial dimensions. + dims.spatial_dims[0].input_size, // input planes + dims.spatial_dims[1].input_size, // input rows + dims.spatial_dims[2].input_size, // input cols + // Filter spatial dimensions. + dims.spatial_dims[0].filter_size, // filter planes + dims.spatial_dims[1].filter_size, // filter rows + dims.spatial_dims[2].filter_size, // filter cols + // Spatial padding. + top_pad_planes, top_pad_rows, left_pad_cols, + bottom_pad_planes, bottom_pad_rows, right_pad_cols, + // Spatial striding. + dims.spatial_dims[0].stride, // stride planes + dims.spatial_dims[1].stride, // stride rows + dims.spatial_dims[2].stride, // stride cols + input_data); + } + }; + Shard(worker_threads.num_threads, worker_threads.workers, shard_limit, + work_unit_size, shard); + + input_backprop_data += input_offset * shard_limit; + out_backprop_data += output_offset * shard_limit; + } + } } private: @@ -253,21 +571,48 @@ class Conv3DBackpropInputOp : public OpKernel { Padding padding_; TensorFormat data_format_; bool takes_shape_; + + TF_DISALLOW_COPY_AND_ASSIGN(Conv3DCustomBackpropInputOp); }; +// Custom backprop input kernel is 30% - 4x faster when compiled with AVX2 than +// default Eigen implementation (at the cost of ~2x-8x peak memory usage).
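Before the registrations that follow: the sharded path above reduces the per-image input gradient to a single GEMM, C = A * B^T, where A is out_backprop viewed as [output_image_size x out_depth] and B is the filter viewed as [filter_total_size x out_depth]; Col2im then scatters C into in_backprop. A standalone sketch of that contraction with synthetic sizes and assumed buffer names (illustrative only, not kernel code):

#include <Eigen/Dense>
#include <iostream>
#include <vector>

using RowMajorMatrix =
    Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

int main() {
  const int output_image_size = 4;  // out_planes * out_rows * out_cols
  const int filter_total_size = 6;  // fp * fh * fw * in_depth
  const int out_depth = 3;

  std::vector<float> out_backprop(output_image_size * out_depth, 1.0f);
  std::vector<float> filter(filter_total_size * out_depth, 0.5f);
  std::vector<float> col_buffer(output_image_size * filter_total_size);

  Eigen::Map<const RowMajorMatrix> A(out_backprop.data(), output_image_size,
                                     out_depth);
  Eigen::Map<const RowMajorMatrix> B(filter.data(), filter_total_size,
                                     out_depth);
  Eigen::Map<RowMajorMatrix> C(col_buffer.data(), output_image_size,
                               filter_total_size);

  // Same contraction as the kernel's "C.noalias() = A * B.transpose()";
  // each row of C is the gradient of one output location's input patch.
  C.noalias() = A * B.transpose();
  std::cout << "C(0,0) = " << C(0, 0) << "\n";  // 3 * 1.0 * 0.5 = 1.5
  return 0;
}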
+ #define REGISTER_CPU_KERNEL(T) \ REGISTER_KERNEL_BUILDER( \ Name("Conv3DBackpropInput").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ - Conv3DBackpropInputOp<CPUDevice, T>); \ + Conv3DCustomBackpropInputOp<CPUDevice, T>); \ REGISTER_KERNEL_BUILDER( \ Name("Conv3DBackpropInputV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ - Conv3DBackpropInputOp<CPUDevice, T>); + Conv3DCustomBackpropInputOp<CPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInput") \ + .Device(DEVICE_CPU) \ + .Label("custom") \ + .TypeConstraint<T>("T"), \ + Conv3DCustomBackpropInputOp<CPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \ + .Device(DEVICE_CPU) \ + .Label("custom") \ + .TypeConstraint<T>("T"), \ + Conv3DCustomBackpropInputOp<CPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInput") \ + .Device(DEVICE_CPU) \ + .Label("eigen_tensor") \ + .TypeConstraint<T>("T"), \ + Conv3DBackpropInputOp<CPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \ + .Device(DEVICE_CPU) \ + .Label("eigen_tensor") \ + .TypeConstraint<T>("T"), \ + Conv3DBackpropInputOp<CPUDevice, T>); + TF_CALL_half(REGISTER_CPU_KERNEL); TF_CALL_float(REGISTER_CPU_KERNEL); TF_CALL_double(REGISTER_CPU_KERNEL); #undef REGISTER_CPU_KERNEL -// Backprop for filter. +// Backprop for filter that offloads computation to +// Eigen::CuboidConvolutionBackwardFilter. template <typename Device, class T> class Conv3DBackpropFilterOp : public OpKernel { public: @@ -323,8 +668,11 @@ class Conv3DBackpropFilterOp : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& input = context->input(0); const TensorShape& input_shape = input.shape(); - TensorShape filter_shape; + const Tensor& out_backprop = context->input(2); + const TensorShape& out_backprop_shape = out_backprop.shape(); + + TensorShape filter_shape; if (takes_shape_) { const Tensor& filter_sizes = context->input(1); OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape( @@ -333,13 +681,13 @@ class Conv3DBackpropFilterOp : public OpKernel { filter_shape = context->input(1).shape(); } - EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropFilter"); - Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 5> pad_dims{ - {0, 0}, - {top_pad_planes, bottom_pad_planes}, - {top_pad_rows, bottom_pad_rows}, - {left_pad_cols, right_pad_cols}, - {0, 0}}; + ConvBackpropDimensions dims; + OP_REQUIRES_OK(context, + ConvBackpropComputeDimensions( + "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3, + input_shape, filter_shape, out_backprop_shape, stride_, + padding_, data_format_, &dims)); + Tensor* filter_backprop; OP_REQUIRES_OK(context, context->allocate_output(0, filter_shape, &filter_backprop)); @@ -349,70 +697,292 @@ class Conv3DBackpropFilterOp : public OpKernel { return; } - // For the backprop of the filter, we need to also transpose the - // out_backprop. 
- // The shape of backprop is - // [batch, out_z, out_y, out_x, out_depth] - // And we need to change it to - // [out_depth, out_x, out_y, out_z, batch] - Eigen::DSizes<Eigen::DenseIndex, 5> out_order{4, 1, 2, 3, 0}; - TensorShape padded_out_shape({out_depth, padded_out_planes, padded_out_rows, - padded_out_cols, batch}); - Tensor padded_output; + functor::CuboidConvolutionBackwardFilter<Device, T>()( + context->eigen_device<Device>(), + filter_backprop->tensor<T, 5>(), // filter_backward + input.tensor<T, 5>(), // input + out_backprop.tensor<T, 5>(), // output_backward + static_cast<int>(dims.spatial_dims[0].stride), // stride_planes + static_cast<int>(dims.spatial_dims[1].stride), // stride_rows + static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols + } + + private: + std::vector<int32> dilation_; + std::vector<int32> stride_; + Padding padding_; + TensorFormat data_format_; + bool takes_shape_; + + TF_DISALLOW_COPY_AND_ASSIGN(Conv3DBackpropFilterOp); +}; + +// Custom backprop for filter that explicitly does the work sharding and calls +// Eigen only to multiply matrices. +template <typename Device, class T> +class Conv3DCustomBackpropFilterOp : public OpKernel { + // Limit the maximum size of allocated temporary buffer to + // kMaxTempAllocationOverhead times the size of the input tensors (input, + // filter, out_backprop). If the size of the temporary buffer exceeds this + // limit, fallback on Eigen implementation. + static constexpr int kMaxTempAllocationOverhead = 25; + + public: + explicit Conv3DCustomBackpropFilterOp(OpKernelConstruction* context) + : OpKernel(context), + data_format_(FORMAT_NHWC), + takes_shape_(type_string().find("V2") != std::string::npos) { + // data_format is only available in V2. + if (takes_shape_) { + string data_format; + OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); + OP_REQUIRES(context, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + OP_REQUIRES( + context, data_format_ == FORMAT_NHWC, + errors::InvalidArgument( + "Conv3DBackpropFilterOpV2 only supports NDHWC on the CPU.")); + } + + OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_)); + OP_REQUIRES(context, dilation_.size() == 5, + errors::InvalidArgument("Dilation rates field must " + "specify 5 dimensions")); + OP_REQUIRES(context, + (GetTensorDim(dilation_, data_format_, 'C') == 1 && + GetTensorDim(dilation_, data_format_, 'N') == 1), + errors::InvalidArgument( + "Current implementation does not yet support " + "dilation rates in the batch and depth dimensions.")); + + // TODO(yangzihao): Add CPU version of dilated conv 3D. 
+ OP_REQUIRES(context, + (GetTensorDim(dilation_, data_format_, '0') == 1 && + GetTensorDim(dilation_, data_format_, '1') == 1 && + GetTensorDim(dilation_, data_format_, '2') == 1), + errors::InvalidArgument( + "Current CPU implementation does not yet support " + "dilation rates larger than 1.")); + + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); + OP_REQUIRES(context, stride_.size() == 5, + errors::InvalidArgument("Sliding window strides field must " + "specify 5 dimensions")); + OP_REQUIRES( + context, + (GetTensorDim(stride_, data_format_, 'C') == 1 && + GetTensorDim(stride_, data_format_, 'N') == 1), + errors::InvalidArgument("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const TensorShape& input_shape = input.shape(); + + const Tensor& out_backprop = context->input(2); + const TensorShape& out_backprop_shape = out_backprop.shape(); + + TensorShape filter_shape; + if (takes_shape_) { + const Tensor& filter_sizes = context->input(1); + OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape( + filter_sizes.vec<int32>(), &filter_shape)); + } else { + filter_shape = context->input(1).shape(); + } + + ConvBackpropDimensions dims; OP_REQUIRES_OK(context, - context->allocate_temp(DataTypeToEnum<T>::v(), - padded_out_shape, &padded_output)); - Eigen::DSizes<Eigen::DenseIndex, 5> eigen_strides{1, strides[0], strides[1], - strides[2], 1}; - functor::InflatePadAndShuffle<Device, T, 5, Eigen::DenseIndex>()( - context->eigen_device<Device>(), out_backprop.tensor<T, 5>(), - eigen_strides, pad_dims, out_order, padded_output.tensor<T, 5>()); - const Tensor& padded_output_cref = padded_output; - - // For the backprop of the filter, we need to transpose the input. - // The shape of input is - // [batch, in_z, in_y, in_x, in_depth] - // And we need to change it to - // [in_z, in_y, in_x, batch, in_depth] - Eigen::DSizes<Eigen::DenseIndex, 5> in_order{1, 2, 3, 0, 4}; - TensorShape in_shuffle_shape( - {input_size[0], input_size[1], input_size[2], batch, in_depth}); - Tensor in_shuffle; + ConvBackpropComputeDimensions( + "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3, + input_shape, filter_shape, out_backprop_shape, stride_, + padding_, data_format_, &dims)); + + Tensor* filter_backprop; OP_REQUIRES_OK(context, - context->allocate_temp(DataTypeToEnum<T>::v(), - in_shuffle_shape, &in_shuffle)); - // No need for reversing this time. - Eigen::array<bool, 5> no_reverse{false, false, false, false, false}; - functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()( - context->eigen_device<Device>(), input.tensor<T, 5>(), in_order, - no_reverse, in_shuffle.tensor<T, 5>()); - const Tensor& in_shuffle_cref = in_shuffle; - - // The output of the conv_3d would be - // [out_depth, filter_size[2], filter_size[1], filter_size[0], in_depth] - // and we need to shuffle it back to - // [filter_size[2], filter_size[1], filter_size[0], in_depth, out_depth]; - // And we need to reverse the filter backprops. - // So we need to allocate (sigh) yet another piece of memory to hold the - // output. 
- TensorShape filter_shuffle_shape( - {out_depth, filter_size[0], filter_size[1], filter_size[2], in_depth}); - Tensor filter_shuffle; - OP_REQUIRES_OK( - context, context->allocate_temp(DataTypeToEnum<T>::v(), - filter_shuffle_shape, &filter_shuffle)); - functor::CuboidConvolution<Device, T>()( - context->eigen_device<Device>(), filter_shuffle.tensor<T, 5>(), - padded_output_cref.tensor<T, 5>(), in_shuffle_cref.tensor<T, 5>(), 1, 1, - 1, BrainPadding2EigenPadding(VALID)); - - // Now copy the filter_backprop back to the destination. - Eigen::DSizes<Eigen::DenseIndex, 5> filter_order{1, 2, 3, 4, 0}; - Eigen::array<bool, 5> filter_rev_dims{true, true, true, false, false}; - const Tensor& filter_shuffle_cref = filter_shuffle; - functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()( - context->eigen_device<Device>(), filter_shuffle_cref.tensor<T, 5>(), - filter_order, filter_rev_dims, filter_backprop->tensor<T, 5>()); + context->allocate_output(0, filter_shape, &filter_backprop)); + + if (input_shape.num_elements() == 0) { + filter_backprop->template flat<T>().setZero(); + return; + } + + int64 top_pad_planes, bottom_pad_planes; + int64 top_pad_rows, bottom_pad_rows; + int64 left_pad_cols, right_pad_cols; + + OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( + dims.spatial_dims[0].input_size, + dims.spatial_dims[0].filter_size, + dims.spatial_dims[0].stride, padding_, + &dims.spatial_dims[0].output_size, + &top_pad_planes, &bottom_pad_planes)); + OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( + dims.spatial_dims[1].input_size, + dims.spatial_dims[1].filter_size, + dims.spatial_dims[1].stride, padding_, + &dims.spatial_dims[1].output_size, + &top_pad_rows, &bottom_pad_rows)); + OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( + dims.spatial_dims[2].input_size, + dims.spatial_dims[2].filter_size, + dims.spatial_dims[2].stride, padding_, + &dims.spatial_dims[2].output_size, + &left_pad_cols, &right_pad_cols)); + + // TODO(ezhulenev): Extract work size and shard estimation to shared + // functions in conv_grad_ops, and update 2d convolution backprop. + + // The total dimension size of each kernel. + const int64 filter_total_size = + dims.spatial_dims[0].filter_size * dims.spatial_dims[1].filter_size * + dims.spatial_dims[2].filter_size * dims.in_depth; + // The output image size is the spatial size of the output. + const int64 output_image_size = dims.spatial_dims[0].output_size * + dims.spatial_dims[1].output_size * + dims.spatial_dims[2].output_size; + + // Shard 'batch' images (volumes) into 'shard_size' groups of images + // (volumes) to be fed into the parallel matmul. Calculate 'shard_size' by + // dividing the L3 cache size ('target_working_set_size') by the matmul size + // of an individual image ('work_unit_size'). + + const auto cache_sizes = Eigen::internal::CacheSizes(); + const ptrdiff_t l3_cache_size = cache_sizes.m_l3; + + // TODO(andydavis) + // *) Consider reducing 'target_working_set_size' if L3 is shared by + // other concurrently running tensorflow ops. + const size_t target_working_set_size = l3_cache_size / sizeof(T); + + const int64 size_A = output_image_size * filter_total_size; + + const int64 size_B = output_image_size * dims.out_depth; + + const int64 size_C = filter_total_size * dims.out_depth; + + const int64 work_unit_size = size_A + size_B + size_C; + + const size_t shard_size = + (target_working_set_size + work_unit_size - 1) / work_unit_size; + + // Total number of elements in all the tensors used by this kernel. 
+ int64 total_tensor_elements = input_shape.num_elements() + + filter_shape.num_elements() + + out_backprop_shape.num_elements(); + + // Shape of the temporary workspace buffer. + TensorShape col_buffer_shape = {static_cast<int64>(shard_size), + static_cast<int64>(output_image_size), + static_cast<int64>(filter_total_size)}; + int64 col_buffer_elements = col_buffer_shape.num_elements(); + + // If the temporary allocation overhead is too large, fallback on Eigen + // implementation which requires much less memory. + int64 col_buffer_overhead = col_buffer_elements / total_tensor_elements; + if (col_buffer_overhead > kMaxTempAllocationOverhead) { + VLOG(2) << "Fallback on Eigen implementation of Conv3DBackpropFilterOp: " + "col_buffer_overhead=" + << col_buffer_overhead; + + functor::CuboidConvolutionBackwardFilter<Device, T>()( + context->eigen_device<Device>(), + filter_backprop->tensor<T, 5>(), // filter_backward + input.tensor<T, 5>(), // input + out_backprop.tensor<T, 5>(), // output_backward + static_cast<int>(dims.spatial_dims[0].stride), // stride_planes + static_cast<int>(dims.spatial_dims[1].stride), // stride_rows + static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols + + return; + } + + Tensor col_buffer; + OP_REQUIRES_OK(context, + context->allocate_temp(DataTypeToEnum<T>::value, + col_buffer_shape, &col_buffer)); + + // The input offset corresponding to a single input image. + const int64 input_offset = dims.spatial_dims[0].input_size * + dims.spatial_dims[1].input_size * + dims.spatial_dims[2].input_size * dims.in_depth; + // The output offset corresponding to a single output image. + const int64 output_offset = + dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size * + dims.spatial_dims[2].output_size * dims.out_depth; + + const T* input_data = input.template flat<T>().data(); + T* col_buffer_data = col_buffer.template flat<T>().data(); + const T* out_backprop_data = out_backprop.template flat<T>().data(); + T* filter_backprop_data = filter_backprop->template flat<T>().data(); + + typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, + Eigen::Unaligned> + TensorMap; + typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>, + Eigen::Unaligned> + ConstTensorMap; + + TensorMap C(filter_backprop_data, filter_total_size, dims.out_depth); + C.setZero(); + + // Initialize contraction dims (we need to transpose 'A' below). + Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims; + contract_dims[0].first = 0; + contract_dims[0].second = 0; + + auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); + + for (int image_id = 0; image_id < dims.batch_size; image_id += shard_size) { + const int shard_limit = + std::min(static_cast<int>(shard_size), + static_cast<int>(dims.batch_size) - image_id); + + auto shard = [&input_data, &col_buffer_data, &dims, &top_pad_planes, + &top_pad_rows, &left_pad_cols, &bottom_pad_planes, + &bottom_pad_rows, &right_pad_cols, &input_offset, + &size_A](int64 start, int64 limit) { + for (int shard_id = start; shard_id < limit; ++shard_id) { + const T* input_data_shard = input_data + shard_id * input_offset; + T* col_data_shard = col_buffer_data + shard_id * size_A; + + // When we compute the gradient with respect to the filters, we need + // to do im2col to allow gemm-type computation. + Im2col<T>(input_data_shard, dims.in_depth, + // Input spatial dimensions. 
+ dims.spatial_dims[0].input_size, // input planes + dims.spatial_dims[1].input_size, // input rows + dims.spatial_dims[2].input_size, // input cols + // Filter spatial dimensions. + dims.spatial_dims[0].filter_size, // filter planes + dims.spatial_dims[1].filter_size, // filter rows + dims.spatial_dims[2].filter_size, // filter cols + // Spatial padding. + top_pad_planes, top_pad_rows, left_pad_cols, + bottom_pad_planes, bottom_pad_rows, right_pad_cols, + // Spatial striding. + dims.spatial_dims[0].stride, // stride planes + dims.spatial_dims[1].stride, // stride rows + dims.spatial_dims[2].stride, // stride cols + col_data_shard); + } + }; + Shard(worker_threads.num_threads, worker_threads.workers, shard_limit, + size_A, shard); + + ConstTensorMap A(col_buffer_data, output_image_size * shard_limit, + filter_total_size); + ConstTensorMap B(out_backprop_data, output_image_size * shard_limit, + dims.out_depth); + + // Gradient with respect to filter. + C.device(context->eigen_cpu_device()) += A.contract(B, contract_dims); + + input_data += input_offset * shard_limit; + out_backprop_data += output_offset * shard_limit; + } } private: @@ -421,21 +991,60 @@ class Conv3DBackpropFilterOp : public OpKernel { Padding padding_; TensorFormat data_format_; bool takes_shape_; + + TF_DISALLOW_COPY_AND_ASSIGN(Conv3DCustomBackpropFilterOp); }; +// Custom backprop filter kernel is 30% - 4x faster when compiled with AVX2 than +// default Eigen implementation (at the cost of ~2x-8x peak memory usage). + #define REGISTER_CPU_KERNEL(T) \ REGISTER_KERNEL_BUILDER( \ Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ - Conv3DBackpropFilterOp<CPUDevice, T>); \ + Conv3DCustomBackpropFilterOp<CPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<T>("T"), \ + Conv3DCustomBackpropFilterOp<CPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilter") \ + .Device(DEVICE_CPU) \ + .Label("custom") \ + .TypeConstraint<T>("T"), \ + Conv3DCustomBackpropFilterOp<CPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \ + .Device(DEVICE_CPU) \ + .Label("custom") \ + .TypeConstraint<T>("T"), \ + Conv3DCustomBackpropFilterOp<CPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilter") \ + .Device(DEVICE_CPU) \ + .Label("eigen_tensor") \ + .TypeConstraint<T>("T"), \ + Conv3DBackpropFilterOp<CPUDevice, T>); \ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \ .Device(DEVICE_CPU) \ + .Label("eigen_tensor") \ .TypeConstraint<T>("T"), \ Conv3DBackpropFilterOp<CPUDevice, T>); -TF_CALL_half(REGISTER_CPU_KERNEL); + TF_CALL_float(REGISTER_CPU_KERNEL); TF_CALL_double(REGISTER_CPU_KERNEL); #undef REGISTER_CPU_KERNEL +// WARNING: Eigen::half is not trivially copyable and can't be used in the +// custom backprop filter kernel because of memcpy and memset in Im2col. +#define REGISTER_CPU_KERNEL(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ + Conv3DBackpropFilterOp<CPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<T>("T"), \ + Conv3DBackpropFilterOp<CPUDevice, T>); + +TF_CALL_half(REGISTER_CPU_KERNEL); +#undef REGISTER_CPU_KERNEL + // GPU definitions of both ops. #if GOOGLE_CUDA // Forward declarations of the functor specializations for GPU.
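One note on the WARNING above before the GPU hunks: Im2col moves patch data with memcpy/memset, which is only well-defined for trivially copyable element types, and the Eigen::half in this tree does not qualify, so half stays on the Eigen-tensor kernel. A hedged sketch of a compile-time guard one could attach to such a copy helper (hypothetical, not in the patch):

#include <cstring>
#include <type_traits>

// Hypothetical helper: reject memcpy-based patch copies for element types
// that are not trivially copyable (e.g. Eigen::half in this Eigen version).
template <typename T>
void CopyPatchRow(T* dst, const T* src, int depth) {
  static_assert(std::is_trivially_copyable<T>::value,
                "memcpy-based Im2col requires a trivially copyable T");
  std::memcpy(dst, src, sizeof(T) * depth);
}

int main() {
  float src[4] = {1, 2, 3, 4}, dst[4];
  CopyPatchRow(dst, src, 4);  // OK: float is trivially copyable.
  // Instantiating CopyPatchRow with a non-trivially-copyable type would
  // fail the static_assert instead of silently invoking undefined behavior.
  return 0;
}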
@@ -523,6 +1132,10 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& filter = context->input(1); const TensorShape& filter_shape = filter.shape(); + + const Tensor& out_backprop = context->input(2); + const TensorShape& out_backprop_shape = out_backprop.shape(); + TensorShape input_shape; if (takes_shape_) { const Tensor& input_sizes = context->input(0); @@ -531,7 +1144,14 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel { } else { input_shape = context->input(0).shape(); } - EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropInput"); + + ConvBackpropDimensions dims; + OP_REQUIRES_OK(context, + ConvBackpropComputeDimensionsV2( + "Conv3DBackpropInputOp", /*num_spatial_dims=*/3, + input_shape, filter_shape, out_backprop_shape, dilation_, + stride_, padding_, data_format_, &dims)); + Tensor* in_backprop; OP_REQUIRES_OK(context, context->allocate_output(0, input_shape, &in_backprop)); @@ -539,13 +1159,15 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel { auto* stream = context->op_device_context()->stream(); OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); - if (filter_size[0] == 1 && filter_size[1] == 1 && filter_size[2] == 1 && - dilation_[0] == 1 && dilation_[1] == 1 && dilation_[2] == 1 && - stride_[0] == 1 && stride_[1] == 1 && stride_[2] == 1 && + if (dims.filter_size(0) == 1 && dims.filter_size(1) == 1 && + dims.filter_size(2) == 1 && dims.dilation(0) == 1 && + dims.dilation(1) == 1 && dims.dilation(2) == 1 && dims.stride(0) == 1 && + dims.stride(1) == 1 && dims.stride(2) == 1 && data_format_ == FORMAT_NHWC) { - const uint64 m = batch * input_size[0] * input_size[1] * input_size[2]; - const uint64 k = out_depth; - const uint64 n = in_depth; + const uint64 m = dims.batch_size * dims.input_size(0) * + dims.input_size(1) * dims.input_size(2); + const uint64 k = dims.out_depth; + const uint64 n = dims.in_depth; auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(), out_backprop.template flat<T>().size()); @@ -567,13 +1189,14 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel { ", n=", n, ", k=", k)); } return; - } else if (filter_size[0] == input_size[0] && - filter_size[1] == input_size[1] && - filter_size[2] == input_size[2] && padding_ == Padding::VALID && - data_format_ == FORMAT_NHWC) { - const uint64 m = batch; - const uint64 k = out_depth; - const uint64 n = input_size[0] * input_size[1] * input_size[2] * in_depth; + } else if (dims.filter_size(0) == dims.input_size(0) && + dims.filter_size(1) == dims.input_size(1) && + dims.filter_size(2) == dims.input_size(2) && + padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) { + const uint64 m = dims.batch_size; + const uint64 k = dims.out_depth; + const uint64 n = dims.input_size(0) * dims.input_size(1) * + dims.input_size(2) * dims.in_depth; auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(), out_backprop.template flat<T>().size()); @@ -597,65 +1220,59 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel { return; } - int padding_rows = 0, padding_cols = 0, padding_planes = 0; - - if (padding_ == Padding::SAME) { - padding_planes = std::max<int>( - 0, (output_planes - 1) * strides[0] + filter_size[0] - input_size[0]); - padding_cols = std::max<int>( - 0, (output_cols - 1) * strides[2] + filter_size[2] - input_size[2]); - padding_rows = std::max<int>( - 0, (output_rows - 1) * strides[1] + filter_size[1] - input_size[1]); - } + int padding_planes = 
dims.SpatialPadding(padding_, 0); + int padding_rows = dims.SpatialPadding(padding_, 1); + int padding_cols = dims.SpatialPadding(padding_, 2); + const bool planes_odd = (padding_planes % 2 != 0); const bool rows_odd = (padding_rows % 2 != 0); const bool cols_odd = (padding_cols % 2 != 0); - const bool planes_odd = (padding_planes % 2 != 0); TensorShape compatible_input_shape; if (rows_odd || cols_odd || planes_odd) { // cuDNN only supports the same amount of padding on both sides. compatible_input_shape = { - batch, - in_depth, - input_size[0] + planes_odd, - input_size[1] + rows_odd, - input_size[2] + cols_odd, + dims.batch_size, + dims.in_depth, + dims.input_size(0) + planes_odd, + dims.input_size(1) + rows_odd, + dims.input_size(2) + cols_odd, }; } else { - compatible_input_shape = {batch, in_depth, input_size[0], input_size[1], - input_size[2]}; + compatible_input_shape = {dims.batch_size, dims.in_depth, + dims.input_size(0), dims.input_size(1), + dims.input_size(2)}; } CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0) << "Negative paddings: (" << padding_rows << ", " << padding_cols << ", " << padding_planes << ")"; se::dnn::BatchDescriptor input_desc(3); - input_desc.set_count(batch) + input_desc.set_count(dims.batch_size) .set_spatial_dim(DimIndex::X, compatible_input_shape.dim_size(4)) .set_spatial_dim(DimIndex::Y, compatible_input_shape.dim_size(3)) .set_spatial_dim(DimIndex::Z, compatible_input_shape.dim_size(2)) - .set_feature_map_count(in_depth) + .set_feature_map_count(dims.in_depth) .set_layout(se::dnn::DataLayout::kBatchDepthYX); se::dnn::BatchDescriptor output_desc(3); - output_desc.set_count(batch) - .set_spatial_dim(DimIndex::X, output_cols) - .set_spatial_dim(DimIndex::Y, output_rows) - .set_spatial_dim(DimIndex::Z, output_planes) - .set_feature_map_count(out_depth) + output_desc.set_count(dims.batch_size) + .set_spatial_dim(DimIndex::X, dims.output_size(2)) + .set_spatial_dim(DimIndex::Y, dims.output_size(1)) + .set_spatial_dim(DimIndex::Z, dims.output_size(0)) + .set_feature_map_count(dims.out_depth) .set_layout(se::dnn::DataLayout::kBatchDepthYX); se::dnn::FilterDescriptor filter_desc(3); - filter_desc.set_spatial_dim(DimIndex::X, filter_size[2]) - .set_spatial_dim(DimIndex::Y, filter_size[1]) - .set_spatial_dim(DimIndex::Z, filter_size[0]) - .set_input_feature_map_count(in_depth) - .set_output_feature_map_count(out_depth); + filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2)) + .set_spatial_dim(DimIndex::Y, dims.filter_size(1)) + .set_spatial_dim(DimIndex::Z, dims.filter_size(0)) + .set_input_feature_map_count(dims.in_depth) + .set_output_feature_map_count(dims.out_depth); se::dnn::ConvolutionDescriptor conv_desc(3); - conv_desc.set_dilation_rate(DimIndex::X, dilations[2]) - .set_dilation_rate(DimIndex::Y, dilations[1]) - .set_dilation_rate(DimIndex::Z, dilations[0]) - .set_filter_stride(DimIndex::X, strides[2]) - .set_filter_stride(DimIndex::Y, strides[1]) - .set_filter_stride(DimIndex::Z, strides[0]) + conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2)) + .set_dilation_rate(DimIndex::Y, dims.dilation(1)) + .set_dilation_rate(DimIndex::Z, dims.dilation(0)) + .set_filter_stride(DimIndex::X, dims.stride(2)) + .set_filter_stride(DimIndex::Y, dims.stride(1)) + .set_filter_stride(DimIndex::Z, dims.stride(0)) .set_zero_padding(DimIndex::X, padding_cols / 2) .set_zero_padding(DimIndex::Y, padding_rows / 2) .set_zero_padding(DimIndex::Z, padding_planes / 2); @@ -664,10 +1281,11 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public 
OpKernel {
     Tensor transformed_filter;
     OP_REQUIRES_OK(
         context,
-        context->allocate_temp(DataTypeToEnum<T>::value,
-                               TensorShape({out_depth, in_depth, filter_size[0],
-                                            filter_size[1], filter_size[2]}),
-                               &transformed_filter));
+        context->allocate_temp(
+            DataTypeToEnum<T>::value,
+            TensorShape({dims.out_depth, dims.in_depth, dims.filter_size(0),
+                         dims.filter_size(1), dims.filter_size(2)}),
+            &transformed_filter));
     functor::TransformFilter<GPUDevice, T, int, 5>()(
         context->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 5>()),
         To32Bit(transformed_filter.tensor<T, 5>()));
@@ -675,9 +1293,10 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
     // Shape: batch, filters, z, y, x.
     Tensor transformed_out_backprop;
     if (data_format_ == FORMAT_NHWC) {
-      TensorShape nchw_shape = {batch, out_depth, output_planes, output_rows,
-                                output_cols};
-      if (out_depth > 1) {
+      TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
+                                dims.output_size(0), dims.output_size(1),
+                                dims.output_size(2)};
+      if (dims.out_depth > 1) {
         OP_REQUIRES_OK(context, context->allocate_temp(
                                     DataTypeToEnum<T>::value, nchw_shape,
                                     &transformed_out_backprop));
@@ -713,14 +1332,14 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
     const int device_id = stream->parent()->device_ordinal();
     DataType dtype = context->input(0).dtype();
     const ConvParameters conv_parameters = {
-        batch,
-        in_depth,
-        {{input_size[0], input_size[1], input_size[2]}},
+        dims.batch_size,
+        dims.in_depth,
+        {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
         FORMAT_NCHW,
-        out_depth,
-        {{filter_size[0], filter_size[1], filter_size[2]}},
-        {{dilations[0], dilations[1], dilations[2]}},
-        {{strides[0], strides[1], strides[2]}},
+        dims.out_depth,
+        {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
+        {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
+        {{dims.stride(0), dims.stride(1), dims.stride(2)}},
         {{padding_planes, padding_rows, padding_cols}},
         dtype,
         device_id,
@@ -799,10 +1418,11 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
     if (rows_odd || cols_odd || planes_odd) {
       Tensor in_backprop_remove_padding;
       OP_REQUIRES_OK(context,
-                     context->allocate_temp(DataTypeToEnum<T>::value,
-                                            {batch, in_depth, input_size[0],
-                                             input_size[1], input_size[2]},
-                                            &in_backprop_remove_padding));
+                     context->allocate_temp(
+                         DataTypeToEnum<T>::value,
+                         {dims.batch_size, dims.in_depth, dims.input_size(0),
+                          dims.input_size(1), dims.input_size(2)},
+                         &in_backprop_remove_padding));
 
       // Remove the padding for odd spatial dimensions.
       functor::PadInput<GPUDevice, T, int, 5>()(
@@ -896,6 +1516,10 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
   void Compute(OpKernelContext* context) override {
     const Tensor& input = context->input(0);
     const TensorShape& input_shape = input.shape();
+
+    const Tensor& out_backprop = context->input(2);
+    const TensorShape& out_backprop_shape = out_backprop.shape();
+
     TensorShape filter_shape;
     if (takes_shape_) {
       const Tensor& filter_sizes = context->input(1);
@@ -905,7 +1529,12 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
       filter_shape = context->input(1).shape();
     }
 
-    EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropFilter");
+    ConvBackpropDimensions dims;
+    OP_REQUIRES_OK(context,
+                   ConvBackpropComputeDimensionsV2(
+                       "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3,
+                       input_shape, filter_shape, out_backprop_shape, dilation_,
+                       stride_, padding_, data_format_, &dims));
 
     Tensor* filter_backprop;
     OP_REQUIRES_OK(context,
@@ -914,13 +1543,15 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
     auto* stream = context->op_device_context()->stream();
     OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
 
-    if (filter_size[1] == 1 && filter_size[2] == 1 && filter_size[0] == 1 &&
-        dilations[2] == 1 && dilations[1] == 1 && dilations[0] == 1 &&
-        strides[2] == 1 && strides[1] == 1 && strides[0] == 1 &&
+    if (dims.filter_size(1) == 1 && dims.filter_size(2) == 1 &&
+        dims.filter_size(0) == 1 && dims.dilation(2) == 1 &&
+        dims.dilation(1) == 1 && dims.dilation(0) == 1 && dims.stride(2) == 1 &&
+        dims.stride(1) == 1 && dims.stride(0) == 1 &&
         data_format_ == FORMAT_NHWC) {
-      const uint64 m = in_depth;
-      const uint64 k = batch * input_size[1] * input_size[2] * input_size[0];
-      const uint64 n = out_depth;
+      const uint64 m = dims.in_depth;
+      const uint64 k = dims.batch_size * dims.input_size(1) *
+                       dims.input_size(2) * dims.input_size(0);
+      const uint64 n = dims.out_depth;
 
       // The shape of output backprop is
       // [batch, out_z, out_y, out_x, out_depth]
@@ -951,13 +1582,14 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
                                     ", n=", n, ", k=", k));
       }
       return;
-    } else if (filter_size[0] == input_size[0] &&
-               filter_size[1] == input_size[1] &&
-               filter_size[2] == input_size[2] && padding_ == Padding::VALID &&
-               data_format_ == FORMAT_NHWC) {
-      const uint64 m = input_size[0] * input_size[1] * input_size[2] * in_depth;
-      const uint64 k = batch;
-      const uint64 n = out_depth;
+    } else if (dims.filter_size(0) == dims.input_size(0) &&
+               dims.filter_size(1) == dims.input_size(1) &&
+               dims.filter_size(2) == dims.input_size(2) &&
+               padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
+      const uint64 m = dims.input_size(0) * dims.input_size(1) *
+                       dims.input_size(2) * dims.in_depth;
+      const uint64 k = dims.batch_size;
+      const uint64 n = dims.out_depth;
 
       auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                   input.template flat<T>().size());
@@ -979,30 +1611,24 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
       return;
     }
 
-    int padding_rows = 0, padding_cols = 0, padding_planes = 0;
-
-    if (padding_ == Padding::SAME) {
-      padding_planes = std::max<int>(
-          0, (output_planes - 1) * strides[0] + filter_size[0] - input_size[0]);
-      padding_cols = std::max<int>(
-          0, (output_cols - 1) * strides[2] + filter_size[2] - input_size[2]);
-      padding_rows = std::max<int>(
-          0, (output_rows - 1) * strides[1] + filter_size[1] - input_size[1]);
-    }
-    bool rows_odd = (padding_rows % 2 != 0);
-    bool cols_odd = (padding_cols % 2 != 0);
-    bool planes_odd = (padding_planes % 2 != 0);
+    int padding_planes = dims.SpatialPadding(padding_, 0);
+    int padding_rows = dims.SpatialPadding(padding_, 1);
+    int padding_cols = dims.SpatialPadding(padding_, 2);
+    const bool planes_odd = (padding_planes % 2 != 0);
+    const bool rows_odd = (padding_rows % 2 != 0);
+    const bool cols_odd = (padding_cols % 2 != 0);
 
     Tensor compatible_input;
     if (rows_odd || cols_odd || planes_odd) {
-      OP_REQUIRES_OK(context, context->allocate_temp(
-                                  DataTypeToEnum<T>::value,
-                                  ShapeFromFormat(data_format_, batch,
-                                                  {{input_size[0] + planes_odd,
-                                                    input_size[1] + rows_odd,
-                                                    input_size[2] + cols_odd}},
-                                                  in_depth),
-                                  &compatible_input));
+      OP_REQUIRES_OK(context,
+                     context->allocate_temp(
+                         DataTypeToEnum<T>::value,
+                         ShapeFromFormat(data_format_, dims.batch_size,
+                                         {{dims.input_size(0) + planes_odd,
+                                           dims.input_size(1) + rows_odd,
+                                           dims.input_size(2) + cols_odd}},
+                                         dims.in_depth),
+                         &compatible_input));
       functor::PadInput<GPUDevice, T, int, 5>()(
           context->template eigen_device<GPUDevice>(),
          To32Bit(input.tensor<T, 5>()), {{0, 0, 0}},
@@ -1016,35 +1642,35 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
         << "Negative paddings: (" << padding_rows << ", " << padding_cols
        << ", " << padding_planes << ")";
     se::dnn::BatchDescriptor input_desc(3);
-    input_desc.set_count(batch)
+    input_desc.set_count(dims.batch_size)
         .set_spatial_dim(DimIndex::X,
                          GetTensorDim(compatible_input, data_format_, '2'))
         .set_spatial_dim(DimIndex::Y,
                          GetTensorDim(compatible_input, data_format_, '1'))
         .set_spatial_dim(DimIndex::Z,
                          GetTensorDim(compatible_input, data_format_, '0'))
-        .set_feature_map_count(in_depth)
+        .set_feature_map_count(dims.in_depth)
         .set_layout(se::dnn::DataLayout::kBatchDepthYX);
     se::dnn::BatchDescriptor output_desc(3);
-    output_desc.set_count(batch)
-        .set_spatial_dim(DimIndex::X, output_cols)
-        .set_spatial_dim(DimIndex::Y, output_rows)
-        .set_spatial_dim(DimIndex::Z, output_planes)
-        .set_feature_map_count(out_depth)
+    output_desc.set_count(dims.batch_size)
+        .set_spatial_dim(DimIndex::X, dims.output_size(2))
+        .set_spatial_dim(DimIndex::Y, dims.output_size(1))
+        .set_spatial_dim(DimIndex::Z, dims.output_size(0))
+        .set_feature_map_count(dims.out_depth)
         .set_layout(se::dnn::DataLayout::kBatchDepthYX);
     se::dnn::FilterDescriptor filter_desc(3);
-    filter_desc.set_spatial_dim(DimIndex::X, filter_size[2])
-        .set_spatial_dim(DimIndex::Y, filter_size[1])
-        .set_spatial_dim(DimIndex::Z, filter_size[0])
-        .set_input_feature_map_count(in_depth)
-        .set_output_feature_map_count(out_depth);
+    filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
+        .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
+        .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
+        .set_input_feature_map_count(dims.in_depth)
+        .set_output_feature_map_count(dims.out_depth);
     se::dnn::ConvolutionDescriptor conv_desc(3);
-    conv_desc.set_dilation_rate(DimIndex::X, dilations[2])
-        .set_dilation_rate(DimIndex::Y, dilations[1])
-        .set_dilation_rate(DimIndex::Z, dilations[0])
-        .set_filter_stride(DimIndex::X, strides[2])
-        .set_filter_stride(DimIndex::Y, strides[1])
-        .set_filter_stride(DimIndex::Z, strides[0])
+    conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
+        .set_dilation_rate(DimIndex::Y, dims.dilation(1))
+        .set_dilation_rate(DimIndex::Z, dims.dilation(0))
+        .set_filter_stride(DimIndex::X, dims.stride(2))
+        .set_filter_stride(DimIndex::Y, dims.stride(1))
+        .set_filter_stride(DimIndex::Z, dims.stride(0))
         .set_zero_padding(DimIndex::X, padding_cols / 2)
         .set_zero_padding(DimIndex::Y, padding_rows / 2)
         .set_zero_padding(DimIndex::Z, padding_planes / 2);
@@ -1052,19 +1678,21 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
     Tensor pre_transformed_filter_backprop;
     OP_REQUIRES_OK(
         context,
-        context->allocate_temp(DataTypeToEnum<T>::value,
-                               TensorShape({out_depth, in_depth, filter_size[0],
-                                            filter_size[1], filter_size[2]}),
-                               &pre_transformed_filter_backprop));
+        context->allocate_temp(
+            DataTypeToEnum<T>::value,
+            TensorShape({dims.out_depth, dims.in_depth, dims.filter_size(0),
+                         dims.filter_size(1), dims.filter_size(2)}),
+            &pre_transformed_filter_backprop));
 
     Tensor transformed_out_backprop;
     if (data_format_ == FORMAT_NHWC) {
-      TensorShape nchw_shape = {batch, out_depth, output_planes, output_rows,
-                                output_cols};
+      TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
+                                dims.output_size(0), dims.output_size(1),
+                                dims.output_size(2)};
       OP_REQUIRES_OK(
           context, context->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
                                           &transformed_out_backprop));
-      if (out_depth > 1) {
+      if (dims.out_depth > 1) {
         functor::NHWCToNCHW<GPUDevice, T, 5>()(
             context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(),
            transformed_out_backprop.tensor<T, 5>());
@@ -1076,10 +1704,10 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
     }
     Tensor transformed_input;
     if (data_format_ == FORMAT_NHWC) {
-      TensorShape nchw_shape = {batch, in_depth, compatible_input.dim_size(1),
-                                compatible_input.dim_size(2),
-                                compatible_input.dim_size(3)};
-      if (in_depth > 1) {
+      TensorShape nchw_shape = {
+          dims.batch_size, dims.in_depth, compatible_input.dim_size(1),
+          compatible_input.dim_size(2), compatible_input.dim_size(3)};
+      if (dims.in_depth > 1) {
         OP_REQUIRES_OK(context,
                        context->allocate_temp(DataTypeToEnum<T>::value,
                                               nchw_shape, &transformed_input));
@@ -1110,14 +1738,14 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
     const int device_id = stream->parent()->device_ordinal();
     DataType dtype = input.dtype();
     const ConvParameters conv_parameters = {
-        batch,
-        in_depth,
-        {{input_size[0], input_size[1], input_size[2]}},
+        dims.batch_size,
+        dims.in_depth,
+        {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
         FORMAT_NCHW,
-        out_depth,
-        {{filter_size[0], filter_size[1], filter_size[2]}},
-        {{dilations[0], dilations[1], dilations[2]}},
-        {{strides[0], strides[1], strides[2]}},
+        dims.out_depth,
+        {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
+        {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
+        {{dims.stride(0), dims.stride(1), dims.stride(2)}},
         {{padding_planes, padding_rows, padding_cols}},
         dtype,
        device_id,
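Note on the dims.* rewrites above: every hunk now reads its sizes through the ConvBackpropDimensions struct filled in by ConvBackpropComputeDimensionsV2 (defined in tensorflow/core/kernels/conv_grad_ops.h). As a reading aid, here is a rough mock of just the accessor surface these hunks rely on; the per-dimension layout and the *Mock names are assumptions for illustration, not the real definition.

#include <cstdint>
#include <vector>

// Rough mock of the ConvBackpropDimensions accessors used in this diff.
// Only the members the hunks reference are sketched; field layout beyond
// what is visible in the diff is an assumption.
struct SpatialDimMock {
  int64_t input_size, filter_size, output_size, stride, dilation;
};

struct ConvBackpropDimensionsMock {
  int64_t batch_size, in_depth, out_depth;
  // Index 0 = planes (Z), 1 = rows (Y), 2 = cols (X), matching the
  // SpatialPadding(padding_, 0/1/2) calls in the filter-backprop hunk.
  std::vector<SpatialDimMock> spatial_dims;

  int64_t input_size(int dim) const { return spatial_dims[dim].input_size; }
  int64_t filter_size(int dim) const { return spatial_dims[dim].filter_size; }
  int64_t output_size(int dim) const { return spatial_dims[dim].output_size; }
  int64_t stride(int dim) const { return spatial_dims[dim].stride; }
  int64_t dilation(int dim) const { return spatial_dims[dim].dilation; }
};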
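The two early-return branches in Conv3DBackpropFilterOp treat special cases as a single matrix multiply. A shape-only sketch of the first branch (1x1x1 filter, unit strides and dilations, NHWC), using a hypothetical FilterBackpropGemmShape helper; the kernel itself issues the multiply through the StreamExecutor stream, which is not reproduced here.

#include <cstdint>

// Shapes for the 1x1x1-filter fast path: filter backprop reduces to a GEMM
// C(m, n) = A(m, k) * B(k, n) with m = in_depth, n = out_depth, and
// k = batch * input_planes * input_rows * input_cols, exactly the m/k/n
// assignments visible in the hunk.
struct GemmShape {
  uint64_t m, n, k;
};

// Hypothetical helper mirroring the dims.* reads in the kernel.
GemmShape FilterBackpropGemmShape(uint64_t batch, uint64_t planes,
                                  uint64_t rows, uint64_t cols,
                                  uint64_t in_depth, uint64_t out_depth) {
  return {in_depth, out_depth, batch * planes * rows * cols};
}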
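The dims.SpatialPadding(padding_, dim) calls replace the inline SAME-padding arithmetic that the @@ -979 hunk deletes. A minimal standalone sketch of that rule, assuming a hypothetical SamePadding name and plain int parameters (the real computation lives behind ConvBackpropDimensions):

#include <algorithm>

// Total SAME padding for one spatial dimension, per the removed lines:
// enough so that `output_size` strided filter windows cover `input_size`,
// clamped at zero.
int SamePadding(int output_size, int stride, int filter_size, int input_size) {
  return std::max(0, (output_size - 1) * stride + filter_size - input_size);
}

An odd total cannot be split evenly between the two sides, which is why the kernel tracks planes_odd/rows_odd/cols_odd, pads the input by one extra element along each odd dimension (the compatible_input allocation), and then hands cuDNN the symmetric padding / 2 via set_zero_padding.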