Diffstat (limited to 'tensorflow/core/kernels/conv_grad_ops_3d.cc')
-rw-r--r--   tensorflow/core/kernels/conv_grad_ops_3d.cc   1324
1 file changed, 976 insertions(+), 348 deletions(-)
diff --git a/tensorflow/core/kernels/conv_grad_ops_3d.cc b/tensorflow/core/kernels/conv_grad_ops_3d.cc
index 15f1bf9aba..d26b86c712 100644
--- a/tensorflow/core/kernels/conv_grad_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_grad_ops_3d.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_2d.h"
+#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -32,111 +33,130 @@ limitations under the License.
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
+#include "tensorflow/core/util/work_sharder.h"
#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
using stream_executor::dnn::DimIndex;
#endif
+namespace {
+
+// TODO(ezhulenev): Split this file into conv_grad_filter_ops_3d.cc and
+// conv_grad_input_ops_3d.cc.
+
+// TODO(ezhulenev): Generalize Col2im and Im2col for 2-d and 3-d kernels.
+
+// "Depth" is already used for the channel dimension, so for the third spatial
+// dimension in this file we use "plane", although in NDHWC layout it's
+// indicated with a "D".
+
+// Returns in 'im_data' (assumed to be zero-initialized) image patch in storage
+// order (planes, height, width, depth), constructed from patches in 'col_data',
+// which is required to be in storage order (out_planes * out_height *
+// out_width, filter_planes, filter_height, filter_width, in_depth).
+//
+// Based on 2-dimensional implementation written by Yangqing Jia (jiayq).
+template <typename T>
+void Col2im(const T* col_data, const int depth, const int planes,
+ const int height, const int width, const int filter_p,
+ const int filter_h, const int filter_w, const int pad_pt,
+ const int pad_t, const int pad_l, const int pad_pb, const int pad_b,
+ const int pad_r, const int stride_p, const int stride_h,
+ const int stride_w, T* im_data) {
+ const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1;
+ const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
+ const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
+ int p_pad = -pad_pt;
+ for (int p = 0; p < planes_col; ++p) {
+ int h_pad = -pad_t;
+ for (int h = 0; h < height_col; ++h) {
+ int w_pad = -pad_l;
+ for (int w = 0; w < width_col; ++w) {
+ T* im_patch_data =
+ im_data + (p_pad * height * width + h_pad * width + w_pad) * depth;
+ for (int ip = p_pad; ip < p_pad + filter_p; ++ip) {
+ for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
+ for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
+ if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 &&
+ iw < width) {
+ for (int i = 0; i < depth; ++i) {
+ im_patch_data[i] += col_data[i];
+ }
+ }
+ im_patch_data += depth;
+ col_data += depth;
+ }
+          // Skip the rest of this row: (width - filter_w) columns of depth values.
+ im_patch_data += depth * (width - filter_w);
+ }
+        // Skip the rest of this plane: (height - filter_h) rows of (depth * width) values.
+ im_patch_data += (depth * width) * (height - filter_h);
+ }
+ w_pad += stride_w;
+ }
+ h_pad += stride_h;
+ }
+ p_pad += stride_p;
+ }
+}
+
+// Returns in 'col_data', image patches in storage order (planes, height, width,
+// depth) extracted from image at 'input_data', which is required to be in
+// storage order (batch, planes, height, width, depth).
+//
+// Based on 2-dimensional implementation written by Yangqing Jia (jiayq).
+template <typename T>
+void Im2col(const T* input_data, const int depth, const int planes,
+ const int height, const int width, const int filter_p,
+ const int filter_h, const int filter_w, const int pad_pt,
+ const int pad_t, const int pad_l, const int pad_pb, const int pad_b,
+ const int pad_r, const int stride_p, const int stride_h,
+ const int stride_w, T* col_data) {
+ const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1;
+ const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
+ const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
+
+ int p_pad = -pad_pt;
+ for (int p = 0; p < planes_col; ++p) {
+ int h_pad = -pad_t;
+ for (int h = 0; h < height_col; ++h) {
+ int w_pad = -pad_l;
+ for (int w = 0; w < width_col; ++w) {
+ for (int ip = p_pad; ip < p_pad + filter_p; ++ip) {
+ for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
+ for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
+ if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 &&
+ iw < width) {
+ memcpy(col_data,
+ input_data +
+ (ip * height * width + ih * width + iw) * depth,
+ sizeof(T) * depth);
+ } else {
+ // This should be simply padded with zero.
+ memset(col_data, 0, sizeof(T) * depth);
+ }
+ col_data += depth;
+ }
+ }
+ }
+ w_pad += stride_w;
+ }
+ h_pad += stride_h;
+ }
+ p_pad += stride_p;
+ }
+}
+
+} // namespace
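
A minimal sketch of how the two helpers above fit together (not part of the patch; sizes are illustrative and assume the Im2col/Col2im templates are in scope). Im2col lays out one row of length filter_p * filter_h * filter_w * depth per output position and is used for the filter gradient; Col2im scatters such rows back into a zero-initialized volume and is used for the input gradient. Overlapping patches are summed, so the round trip is an accumulation, not an exact inverse.

#include <vector>

void Im2colCol2imSketch() {
  const int depth = 2, planes = 4, height = 4, width = 4;  // in_depth, D, H, W
  const int filter_p = 2, filter_h = 2, filter_w = 2;      // filter D, H, W
  const int stride = 1;                                    // VALID, no padding

  const int out_p = (planes - filter_p) / stride + 1;      // 3
  const int out_h = (height - filter_h) / stride + 1;      // 3
  const int out_w = (width - filter_w) / stride + 1;       // 3

  std::vector<float> image(planes * height * width * depth, 1.0f);
  std::vector<float> col(out_p * out_h * out_w *
                         filter_p * filter_h * filter_w * depth);

  // Gather image patches into the column buffer (filter backprop path).
  Im2col<float>(image.data(), depth, planes, height, width, filter_p, filter_h,
                filter_w, /*pad_pt=*/0, /*pad_t=*/0, /*pad_l=*/0, /*pad_pb=*/0,
                /*pad_b=*/0, /*pad_r=*/0, stride, stride, stride, col.data());

  // Accumulate patches back into a zero-initialized volume (input backprop path).
  std::vector<float> reconstructed(image.size(), 0.0f);
  Col2im<float>(col.data(), depth, planes, height, width, filter_p, filter_h,
                filter_w, 0, 0, 0, 0, 0, 0, stride, stride, stride,
                reconstructed.data());
}
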
+
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-// TODO(mjanusz): Get rid of the macro and return shapes directly.
-#define EXTRACT_AND_VERIFY_DIMENSIONS(label) \
- const Tensor& out_backprop = context->input(2); \
- OP_REQUIRES( \
- context, input_shape.dims() == 5, \
- errors::InvalidArgument(label, ": input must be 5-dimensional")); \
- OP_REQUIRES( \
- context, filter_shape.dims() == 5, \
- errors::InvalidArgument(label, ": filter must be 5-dimensional")); \
- OP_REQUIRES( \
- context, out_backprop.dims() == 5, \
- errors::InvalidArgument(label, ": out_backprop must be 5-dimensional")); \
- const int64 batch = input_shape.dim_size(0); \
- OP_REQUIRES( \
- context, batch == out_backprop.dim_size(0), \
- errors::InvalidArgument( \
- label, ": input and out_backprop must have the same batch size")); \
- const std::array<int64, 3> input_size = { \
- {GetTensorDim(input_shape, data_format_, '0'), \
- GetTensorDim(input_shape, data_format_, '1'), \
- GetTensorDim(input_shape, data_format_, '2')}}; \
- const int64 in_depth = GetTensorDim(input_shape, data_format_, 'C'); \
- const std::array<int64, 3> filter_size = {{filter_shape.dim_size(0), \
- filter_shape.dim_size(1), \
- filter_shape.dim_size(2)}}; \
- const int64 output_cols = GetTensorDim(out_backprop, data_format_, '2'); \
- const int64 output_rows = GetTensorDim(out_backprop, data_format_, '1'); \
- const int64 output_planes = GetTensorDim(out_backprop, data_format_, '0'); \
- OP_REQUIRES(context, in_depth == filter_shape.dim_size(3), \
- errors::InvalidArgument( \
- label, ": input and filter must have the same depth")); \
- const int64 out_depth = filter_shape.dim_size(4); \
- OP_REQUIRES( \
- context, out_depth == GetTensorDim(out_backprop, data_format_, 'C'), \
- errors::InvalidArgument( \
- label, ": filter and out_backprop must have the same out_depth")); \
- const std::array<int64, 3> dilations = { \
- {GetTensorDim(dilation_, data_format_, '0'), \
- GetTensorDim(dilation_, data_format_, '1'), \
- GetTensorDim(dilation_, data_format_, '2')}}; \
- const std::array<int64, 3> strides = { \
- {GetTensorDim(stride_, data_format_, '0'), \
- GetTensorDim(stride_, data_format_, '1'), \
- GetTensorDim(stride_, data_format_, '2')}}; \
- std::array<int64, 3> out, padding; \
- OP_REQUIRES_OK( \
- context, Get3dOutputSizeV2(input_size, filter_size, dilations, strides, \
- padding_, &out, &padding)); \
- OP_REQUIRES(context, output_planes == out[0], \
- errors::InvalidArgument( \
- label, \
- ": Number of planes of out_backprop doesn't match " \
- "computed: actual = ", \
- output_planes, ", computed = ", out[0])); \
- OP_REQUIRES( \
- context, output_rows == out[1], \
- errors::InvalidArgument( \
- label, ": Number of rows of out_backprop doesn't match computed: ", \
- "actual = ", output_rows, ", computed = ", out[1])); \
- OP_REQUIRES( \
- context, output_cols == out[2], \
- errors::InvalidArgument( \
- label, ": Number of cols of out_backprop doesn't match computed: ", \
- "actual = ", output_cols, ", computed = ", out[2])); \
- const auto expanded_out_planes = (output_planes - 1) * strides[0] + 1; \
- const auto expanded_out_rows = (output_rows - 1) * strides[1] + 1; \
- const auto expanded_out_cols = (output_cols - 1) * strides[2] + 1; \
- const auto padded_out_planes = input_size[0] + filter_size[0] - 1; \
- const auto padded_out_rows = input_size[1] + filter_size[1] - 1; \
- const auto padded_out_cols = input_size[2] + filter_size[2] - 1; \
- const auto top_pad_planes = filter_size[0] - 1 - padding[0]; \
- const auto top_pad_rows = filter_size[1] - 1 - padding[1]; \
- const auto left_pad_cols = filter_size[2] - 1 - padding[2]; \
- const auto bottom_pad_planes = \
- padded_out_planes - expanded_out_planes - top_pad_planes; \
- const auto bottom_pad_rows = \
- padded_out_rows - expanded_out_rows - top_pad_rows; \
- const auto right_pad_cols = \
- padded_out_cols - expanded_out_cols - left_pad_cols; \
- VLOG(2) << "Conv3d: " << label \
- << ": expanded_out_planes = " << expanded_out_planes \
- << ": expanded_out_rows = " << expanded_out_rows \
- << ", expanded_out_cols = " << expanded_out_cols \
- << ", padded_out_planes = " << padded_out_planes \
- << ", padded_out_rows = " << padded_out_rows \
- << ", padded_out_cols = " << padded_out_cols \
- << ", top_pad_planes = " << top_pad_planes \
- << ", top_pad_rows = " << top_pad_rows \
- << ", left_pad_cols = " << left_pad_cols \
- << ", bottom_pad_planes = " << bottom_pad_planes \
- << ", bottom_pad_rows = " << bottom_pad_rows \
- << ", right_pad_cols = " << right_pad_cols
-
-// Backprop for input.
+// Backprop for input that offloads computation to
+// Eigen::CuboidConvolutionBackwardInput.
template <typename Device, class T>
class Conv3DBackpropInputOp : public OpKernel {
public:
@@ -192,6 +212,10 @@ class Conv3DBackpropInputOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& filter = context->input(1);
const TensorShape& filter_shape = filter.shape();
+
+ const Tensor& out_backprop = context->input(2);
+ const TensorShape& out_backprop_shape = out_backprop.shape();
+
TensorShape input_shape;
if (takes_shape_) {
const Tensor& input_sizes = context->input(0);
@@ -200,51 +224,345 @@ class Conv3DBackpropInputOp : public OpKernel {
} else {
input_shape = context->input(0).shape();
}
- EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropInput");
- Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 5> pad_dims{
- {0, 0},
- {top_pad_planes, bottom_pad_planes},
- {top_pad_rows, bottom_pad_rows},
- {left_pad_cols, right_pad_cols},
- {0, 0}};
+
+ ConvBackpropDimensions dims;
+ OP_REQUIRES_OK(context, ConvBackpropComputeDimensions(
+ "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
+ input_shape, filter_shape, out_backprop_shape,
+ stride_, padding_, data_format_, &dims));
+
Tensor* in_backprop;
OP_REQUIRES_OK(context,
context->allocate_output(0, input_shape, &in_backprop));
- // Fill out a padded out_backprop.
- TensorShape padded_out_shape({batch, padded_out_planes, padded_out_rows,
- padded_out_cols, out_depth});
- Tensor padded_output;
+ functor::CuboidConvolutionBackwardInput<Device, T>()(
+ context->eigen_device<Device>(),
+ in_backprop->tensor<T, 5>(), // input_backward
+ filter.tensor<T, 5>(), // filter
+ out_backprop.tensor<T, 5>(), // output_backward
+ static_cast<int>(dims.spatial_dims[0].stride), // stride_planes
+ static_cast<int>(dims.spatial_dims[1].stride), // stride_rows
+ static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols
+ }
+
+ private:
+ std::vector<int32> dilation_;
+ std::vector<int32> stride_;
+ Padding padding_;
+ TensorFormat data_format_;
+ bool takes_shape_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Conv3DBackpropInputOp);
+};
+
+// Custom backprop for input that explicitly does the work sharding and calls
+// Eigen only to multiply matrices.
+template <typename Device, class T>
+class Conv3DCustomBackpropInputOp : public OpKernel {
+ // Limit the maximum size of allocated temporary buffer to
+ // kMaxTempAllocationOverhead times the size of the input tensors (input,
+ // filter, out_backprop). If the size of the temporary buffer exceeds this
+ // limit, fallback on Eigen implementation.
+ static constexpr int kMaxTempAllocationOverhead = 25;
+
+ public:
+ explicit Conv3DCustomBackpropInputOp(OpKernelConstruction* context)
+ : OpKernel(context),
+ data_format_(FORMAT_NHWC),
+ takes_shape_(type_string().find("V2") != std::string::npos) {
+ // data_format is only available in V2.
+ if (takes_shape_) {
+ string data_format;
+ OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+ OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES(
+ context, data_format_ == FORMAT_NHWC,
+ errors::InvalidArgument(
+ "Conv3DBackpropInputOpV2 only supports NDHWC on the CPU."));
+ }
+
+ OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
+ OP_REQUIRES(context, dilation_.size() == 5,
+ errors::InvalidArgument("Dilation rates field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES(context,
+ (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
+ GetTensorDim(dilation_, data_format_, 'N') == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilation rates in the batch and depth dimensions."));
+
+ // TODO(yangzihao): Add CPU version of dilated conv 3D.
+ OP_REQUIRES(context,
+ (GetTensorDim(dilation_, data_format_, '0') == 1 &&
+ GetTensorDim(dilation_, data_format_, '1') == 1 &&
+ GetTensorDim(dilation_, data_format_, '2') == 1),
+ errors::InvalidArgument(
+ "Current CPU implementation does not yet support "
+ "dilation rates larger than 1."));
+
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 5,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES(
+ context,
+ (GetTensorDim(stride_, data_format_, 'C') == 1 &&
+ GetTensorDim(stride_, data_format_, 'N') == 1),
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& filter = context->input(1);
+ const TensorShape& filter_shape = filter.shape();
+
+ const Tensor& out_backprop = context->input(2);
+ const TensorShape& out_backprop_shape = out_backprop.shape();
+
+ TensorShape input_shape;
+ if (takes_shape_) {
+ const Tensor& input_sizes = context->input(0);
+ // MakeShape is able to handle both DT_INT32 and DT_INT64 for input_sizes.
+ OP_REQUIRES_OK(context, MakeShape(input_sizes, &input_shape));
+ } else {
+ input_shape = context->input(0).shape();
+ }
+
+ ConvBackpropDimensions dims;
+ OP_REQUIRES_OK(context, ConvBackpropComputeDimensions(
+ "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
+ input_shape, filter_shape, out_backprop_shape,
+ stride_, padding_, data_format_, &dims));
+
+ Tensor* in_backprop;
OP_REQUIRES_OK(context,
- context->allocate_temp(DataTypeToEnum<T>::v(),
- padded_out_shape, &padded_output));
- Eigen::DSizes<Eigen::DenseIndex, 5> no_op_shuffle{0, 1, 2, 3, 4};
- Eigen::DSizes<Eigen::DenseIndex, 5> eigen_strides{1, strides[0], strides[1],
- strides[2], 1};
- functor::InflatePadAndShuffle<Device, T, 5, Eigen::DenseIndex>()(
- context->eigen_device<Device>(), out_backprop.tensor<T, 5>(),
- eigen_strides, pad_dims, no_op_shuffle, padded_output.tensor<T, 5>());
- const Tensor& padded_output_cref = padded_output;
-
- // Fill a new "reverted" filter. We need to transpose the in_depth and
- // out_depth for the filter and reverse the planes, rows and cols.
- TensorShape r_filter_shape(
- {filter_size[0], filter_size[1], filter_size[2], out_depth, in_depth});
- Tensor r_filter;
- OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
- r_filter_shape, &r_filter));
- Eigen::DSizes<Eigen::DenseIndex, 5> filter_order{0, 1, 2, 4, 3};
- Eigen::array<bool, 5> filter_rev_dims{true, true, true, false, false};
- functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()(
- context->eigen_device<Device>(), filter.tensor<T, 5>(), filter_order,
- filter_rev_dims, r_filter.tensor<T, 5>());
- const Tensor& r_filter_cref = r_filter;
-
- // Now we can call conv_3d directly.
- functor::CuboidConvolution<Device, T>()(
- context->eigen_device<Device>(), in_backprop->tensor<T, 5>(),
- padded_output_cref.tensor<T, 5>(), r_filter_cref.tensor<T, 5>(), 1, 1,
- 1, BrainPadding2EigenPadding(VALID));
+ context->allocate_output(0, input_shape, &in_backprop));
+
+ int64 top_pad_planes, bottom_pad_planes;
+ int64 top_pad_rows, bottom_pad_rows;
+ int64 left_pad_cols, right_pad_cols;
+
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[0].input_size,
+ dims.spatial_dims[0].filter_size,
+ dims.spatial_dims[0].stride, padding_,
+ &dims.spatial_dims[0].output_size,
+ &top_pad_planes, &bottom_pad_planes));
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[1].input_size,
+ dims.spatial_dims[1].filter_size,
+ dims.spatial_dims[1].stride, padding_,
+ &dims.spatial_dims[1].output_size,
+ &top_pad_rows, &bottom_pad_rows));
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[2].input_size,
+ dims.spatial_dims[2].filter_size,
+ dims.spatial_dims[2].stride, padding_,
+ &dims.spatial_dims[2].output_size,
+ &left_pad_cols, &right_pad_cols));
+
+ // TODO(ezhulenev): Extract work size and shard estimation to shared
+ // functions in conv_grad_ops, and update 2d convolution backprop.
+
+ // The total dimension size of each kernel.
+ const int64 filter_total_size =
+ dims.spatial_dims[0].filter_size * dims.spatial_dims[1].filter_size *
+ dims.spatial_dims[2].filter_size * dims.in_depth;
+
+ // The output image size is the spatial size of the output.
+ const int64 output_image_size = dims.spatial_dims[0].output_size *
+ dims.spatial_dims[1].output_size *
+ dims.spatial_dims[2].output_size;
+
+ const auto cache_sizes = Eigen::internal::CacheSizes();
+ const ptrdiff_t l3_cache_size = cache_sizes.m_l3;
+
+ // Use L3 cache size as target working set size.
+ const size_t target_working_set_size = l3_cache_size / sizeof(T);
+
+ // Calculate size of matrices involved in MatMul: C = A x B.
+ const int64 size_A = output_image_size * dims.out_depth;
+
+ const int64 size_B = filter_total_size * dims.out_depth;
+
+ const int64 size_C = output_image_size * filter_total_size;
+
+ const int64 work_unit_size = size_A + size_B + size_C;
+
+ auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
+
+ // Use parallel tensor contractions if there is no batching.
+ //
+ // Compared to Conv2D code, this version is missing work size estimation. In
+ // benchmarks I didn't find a case when it's beneficial to run parallel
+ // contraction compared to sharding and matmuls.
+ const bool use_parallel_contraction = dims.batch_size == 1;
+
+ const size_t shard_size =
+ use_parallel_contraction
+ ? 1
+ : (target_working_set_size + work_unit_size - 1) / work_unit_size;
+
+ // Total number of elements in all the tensors used by this kernel.
+ int64 total_tensor_elements = input_shape.num_elements() +
+ filter_shape.num_elements() +
+ out_backprop_shape.num_elements();
+
+ // Shape of the temporary workspace buffer.
+ TensorShape col_buffer_shape = {static_cast<int64>(shard_size),
+ static_cast<int64>(output_image_size),
+ static_cast<int64>(filter_total_size)};
+ int64 col_buffer_elements = col_buffer_shape.num_elements();
+
+ // If the temporary allocation overhead is too large, fallback on Eigen
+ // implementation which requires much less memory.
+ int64 col_buffer_overhead = col_buffer_elements / total_tensor_elements;
+ if (col_buffer_overhead > kMaxTempAllocationOverhead) {
+ VLOG(2) << "Fallback on Eigen implementation of Conv3DBackpropInputOp: "
+ "col_buffer_overhead="
+ << col_buffer_overhead;
+
+ functor::CuboidConvolutionBackwardInput<Device, T>()(
+ context->eigen_device<Device>(),
+ in_backprop->tensor<T, 5>(), // input_backward
+ filter.tensor<T, 5>(), // filter
+ out_backprop.tensor<T, 5>(), // output_backward
+ static_cast<int>(dims.spatial_dims[0].stride), // stride_planes
+ static_cast<int>(dims.spatial_dims[1].stride), // stride_rows
+ static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols
+
+ return;
+ }
+
+ Tensor col_buffer;
+ OP_REQUIRES_OK(context,
+ context->allocate_temp(DataTypeToEnum<T>::value,
+ col_buffer_shape, &col_buffer));
+
+ // The input offset corresponding to a single input image.
+ const int64 input_offset = dims.spatial_dims[0].input_size *
+ dims.spatial_dims[1].input_size *
+ dims.spatial_dims[2].input_size * dims.in_depth;
+
+ // The output offset corresponding to a single output image.
+ const int64 output_offset =
+ dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size *
+ dims.spatial_dims[2].output_size * dims.out_depth;
+
+ const T* filter_data = filter.template flat<T>().data();
+ T* col_buffer_data = col_buffer.template flat<T>().data();
+ const T* out_backprop_data = out_backprop.template flat<T>().data();
+
+ auto in_backprop_flat = in_backprop->template flat<T>();
+ T* input_backprop_data = in_backprop_flat.data();
+ in_backprop_flat.device(context->eigen_device<Device>()) =
+ in_backprop_flat.constant(T(0));
+
+ if (use_parallel_contraction) {
+ typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
+ Eigen::Unaligned>
+ TensorMap;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
+ Eigen::Unaligned>
+ ConstTensorMap;
+
+ // Initialize contraction dims (we need to transpose 'B' below).
+ Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
+ contract_dims[0].first = 1;
+ contract_dims[0].second = 1;
+
+ for (int image_id = 0; image_id < dims.batch_size; ++image_id) {
+ // Compute gradient into col_buffer.
+ TensorMap C(col_buffer_data, output_image_size, filter_total_size);
+
+ ConstTensorMap A(out_backprop_data + output_offset * image_id,
+ output_image_size, dims.out_depth);
+ ConstTensorMap B(filter_data, filter_total_size, dims.out_depth);
+
+ C.device(context->eigen_cpu_device()) = A.contract(B, contract_dims);
+
+ Col2im<T>(col_buffer_data, dims.in_depth,
+ // Input spatial dimensions.
+ dims.spatial_dims[0].input_size, // input planes
+ dims.spatial_dims[1].input_size, // input rows
+ dims.spatial_dims[2].input_size, // input cols
+ // Filter spatial dimensions.
+ dims.spatial_dims[0].filter_size, // filter planes
+ dims.spatial_dims[1].filter_size, // filter rows
+ dims.spatial_dims[2].filter_size, // filter cols
+ // Spatial padding.
+ top_pad_planes, top_pad_rows, left_pad_cols,
+ bottom_pad_planes, bottom_pad_rows, right_pad_cols,
+ // Spatial striding.
+ dims.spatial_dims[0].stride, // stride planes
+ dims.spatial_dims[1].stride, // stride rows
+ dims.spatial_dims[2].stride, // stride cols
+ input_backprop_data);
+
+ input_backprop_data += input_offset;
+ }
+ } else {
+ typedef Eigen::Map<
+ Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
+ MatrixMap;
+ typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic,
+ Eigen::RowMajor>>
+ ConstMatrixMap;
+
+ for (int image_id = 0; image_id < dims.batch_size;
+ image_id += shard_size) {
+ const int shard_limit =
+ std::min(static_cast<int>(shard_size),
+ static_cast<int>(dims.batch_size) - image_id);
+
+ auto shard = [&dims, &top_pad_planes, &top_pad_rows, &left_pad_cols,
+ &bottom_pad_planes, &bottom_pad_rows, &right_pad_cols,
+ &output_image_size, &filter_total_size,
+ &input_backprop_data, &col_buffer_data,
+ &out_backprop_data, &filter_data, &input_offset,
+ &output_offset, &size_C](int64 start, int64 limit) {
+ for (int shard_id = start; shard_id < limit; ++shard_id) {
+ T* im2col_buf = col_buffer_data + shard_id * size_C;
+ T* input_data = input_backprop_data + shard_id * input_offset;
+ const T* out_data = out_backprop_data + shard_id * output_offset;
+
+ // Compute gradient into 'im2col_buf'.
+ MatrixMap C(im2col_buf, output_image_size, filter_total_size);
+
+ ConstMatrixMap A(out_data, output_image_size, dims.out_depth);
+ ConstMatrixMap B(filter_data, filter_total_size, dims.out_depth);
+
+ C.noalias() = A * B.transpose();
+
+ Col2im<T>(im2col_buf, dims.in_depth,
+ // Input spatial dimensions.
+ dims.spatial_dims[0].input_size, // input planes
+ dims.spatial_dims[1].input_size, // input rows
+ dims.spatial_dims[2].input_size, // input cols
+ // Filter spatial dimensions.
+ dims.spatial_dims[0].filter_size, // filter planes
+ dims.spatial_dims[1].filter_size, // filter rows
+ dims.spatial_dims[2].filter_size, // filter cols
+ // Spatial padding.
+ top_pad_planes, top_pad_rows, left_pad_cols,
+ bottom_pad_planes, bottom_pad_rows, right_pad_cols,
+ // Spatial striding.
+ dims.spatial_dims[0].stride, // stride planes
+ dims.spatial_dims[1].stride, // stride rows
+ dims.spatial_dims[2].stride, // stride cols
+ input_data);
+ }
+ };
+ Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
+ work_unit_size, shard);
+
+ input_backprop_data += input_offset * shard_limit;
+ out_backprop_data += output_offset * shard_limit;
+ }
+ }
}
private:
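
A back-of-the-envelope sketch of the sizing heuristic used by the custom input-backprop kernel above (not part of the patch; all sizes and the L3 capacity are assumed, purely for illustration). Per image the kernel computes C = A * B^T with A: [output_image_size x out_depth] (out_backprop slice), B: [filter_total_size x out_depth] (filter), C: [output_image_size x filter_total_size] (col_buffer, fed to Col2im), and shards the batch so one work unit roughly fits in L3.

#include <cstddef>
#include <cstdint>

constexpr int64_t kOutputImageSize = 8 * 8 * 8;        // out D*H*W (assumed)
constexpr int64_t kFilterTotalSize = 3 * 3 * 3 * 16;   // fD*fH*fW*in_depth (assumed)
constexpr int64_t kOutDepth = 32;                      // assumed

constexpr int64_t kSizeA = kOutputImageSize * kOutDepth;         // 16384
constexpr int64_t kSizeB = kFilterTotalSize * kOutDepth;         // 13824
constexpr int64_t kSizeC = kOutputImageSize * kFilterTotalSize;  // 221184
constexpr int64_t kWorkUnitSize = kSizeA + kSizeB + kSizeC;      // 251392

constexpr std::size_t kL3Bytes = 8 << 20;                            // assumed 8 MiB L3
constexpr std::size_t kTargetWorkingSet = kL3Bytes / sizeof(float);  // 2097152 elements

// Images per shard: ceil(target / work_unit). For these numbers nine images
// share one col_buffer allocation of shard_size * size_C floats.
constexpr std::size_t kShardSize =
    (kTargetWorkingSet + kWorkUnitSize - 1) / kWorkUnitSize;
static_assert(kShardSize == 9, "illustrative arithmetic");
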
@@ -253,21 +571,48 @@ class Conv3DBackpropInputOp : public OpKernel {
Padding padding_;
TensorFormat data_format_;
bool takes_shape_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Conv3DCustomBackpropInputOp);
};
+// The custom backprop input kernel is 30% - 4x faster when compiled with AVX2
+// than the default Eigen implementation (at the cost of ~2x-8x peak memory usage).
+
#define REGISTER_CPU_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("Conv3DBackpropInput").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
- Conv3DBackpropInputOp<CPUDevice, T>); \
+ Conv3DCustomBackpropInputOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("Conv3DBackpropInputV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
- Conv3DBackpropInputOp<CPUDevice, T>);
+ Conv3DCustomBackpropInputOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInput") \
+ .Device(DEVICE_CPU) \
+ .Label("custom") \
+ .TypeConstraint<T>("T"), \
+ Conv3DCustomBackpropInputOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \
+ .Device(DEVICE_CPU) \
+ .Label("custom") \
+ .TypeConstraint<T>("T"), \
+ Conv3DCustomBackpropInputOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInput") \
+ .Device(DEVICE_CPU) \
+ .Label("eigen_tensor") \
+ .TypeConstraint<T>("T"), \
+ Conv3DBackpropInputOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \
+ .Device(DEVICE_CPU) \
+ .Label("eigen_tensor") \
+ .TypeConstraint<T>("T"), \
+ Conv3DBackpropInputOp<CPUDevice, T>);
+
TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL
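
A sketch of how the labeled registrations above would be selected (not part of the patch; the node name, inputs and the "_kernel" attribute used for label matching are assumptions made for illustration). An unlabeled node now resolves to the custom kernel; setting the label to "eigen_tensor" keeps the Eigen-based implementation.

#include <vector>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/lib/core/status.h"

tensorflow::Status MakeEigenTensorInputGradNode(tensorflow::NodeDef* node_def) {
  std::vector<tensorflow::int32> strides = {1, 1, 1, 1, 1};
  return tensorflow::NodeDefBuilder("conv3d_grad_input", "Conv3DBackpropInputV2")
      .Input("input_sizes", 0, tensorflow::DT_INT32)
      .Input("filter", 0, tensorflow::DT_FLOAT)
      .Input("out_backprop", 0, tensorflow::DT_FLOAT)
      .Attr("strides", strides)
      .Attr("padding", "SAME")
      .Attr("_kernel", "eigen_tensor")  // assumed label-selection attribute
      .Finalize(node_def);
}
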
-// Backprop for filter.
+// Backprop for filter that offloads computation to
+// Eigen::CuboidConvolutionBackwardFilter.
template <typename Device, class T>
class Conv3DBackpropFilterOp : public OpKernel {
public:
@@ -323,8 +668,11 @@ class Conv3DBackpropFilterOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
const TensorShape& input_shape = input.shape();
- TensorShape filter_shape;
+ const Tensor& out_backprop = context->input(2);
+ const TensorShape& out_backprop_shape = out_backprop.shape();
+
+ TensorShape filter_shape;
if (takes_shape_) {
const Tensor& filter_sizes = context->input(1);
OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
@@ -333,13 +681,13 @@ class Conv3DBackpropFilterOp : public OpKernel {
filter_shape = context->input(1).shape();
}
- EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropFilter");
- Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 5> pad_dims{
- {0, 0},
- {top_pad_planes, bottom_pad_planes},
- {top_pad_rows, bottom_pad_rows},
- {left_pad_cols, right_pad_cols},
- {0, 0}};
+ ConvBackpropDimensions dims;
+ OP_REQUIRES_OK(context,
+ ConvBackpropComputeDimensions(
+ "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3,
+ input_shape, filter_shape, out_backprop_shape, stride_,
+ padding_, data_format_, &dims));
+
Tensor* filter_backprop;
OP_REQUIRES_OK(context,
context->allocate_output(0, filter_shape, &filter_backprop));
@@ -349,70 +697,292 @@ class Conv3DBackpropFilterOp : public OpKernel {
return;
}
- // For the backprop of the filter, we need to also transpose the
- // out_backprop.
- // The shape of backprop is
- // [batch, out_z, out_y, out_x, out_depth]
- // And we need to change it to
- // [out_depth, out_x, out_y, out_z, batch]
- Eigen::DSizes<Eigen::DenseIndex, 5> out_order{4, 1, 2, 3, 0};
- TensorShape padded_out_shape({out_depth, padded_out_planes, padded_out_rows,
- padded_out_cols, batch});
- Tensor padded_output;
+ functor::CuboidConvolutionBackwardFilter<Device, T>()(
+ context->eigen_device<Device>(),
+ filter_backprop->tensor<T, 5>(), // filter_backward
+ input.tensor<T, 5>(), // input
+ out_backprop.tensor<T, 5>(), // output_backward
+ static_cast<int>(dims.spatial_dims[0].stride), // stride_planes
+ static_cast<int>(dims.spatial_dims[1].stride), // stride_rows
+ static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols
+ }
+
+ private:
+ std::vector<int32> dilation_;
+ std::vector<int32> stride_;
+ Padding padding_;
+ TensorFormat data_format_;
+ bool takes_shape_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Conv3DBackpropFilterOp);
+};
+
+// Custom backprop for filter that explicitly does the work sharding and calls
+// Eigen only to multiply matrices.
+template <typename Device, class T>
+class Conv3DCustomBackpropFilterOp : public OpKernel {
+ // Limit the maximum size of allocated temporary buffer to
+ // kMaxTempAllocationOverhead times the size of the input tensors (input,
+ // filter, out_backprop). If the size of the temporary buffer exceeds this
+ // limit, fallback on Eigen implementation.
+ static constexpr int kMaxTempAllocationOverhead = 25;
+
+ public:
+ explicit Conv3DCustomBackpropFilterOp(OpKernelConstruction* context)
+ : OpKernel(context),
+ data_format_(FORMAT_NHWC),
+ takes_shape_(type_string().find("V2") != std::string::npos) {
+ // data_format is only available in V2.
+ if (takes_shape_) {
+ string data_format;
+ OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+ OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES(
+ context, data_format_ == FORMAT_NHWC,
+ errors::InvalidArgument(
+ "Conv3DBackpropFilterOpV2 only supports NDHWC on the CPU."));
+ }
+
+ OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
+ OP_REQUIRES(context, dilation_.size() == 5,
+ errors::InvalidArgument("Dilation rates field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES(context,
+ (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
+ GetTensorDim(dilation_, data_format_, 'N') == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilation rates in the batch and depth dimensions."));
+
+ // TODO(yangzihao): Add CPU version of dilated conv 3D.
+ OP_REQUIRES(context,
+ (GetTensorDim(dilation_, data_format_, '0') == 1 &&
+ GetTensorDim(dilation_, data_format_, '1') == 1 &&
+ GetTensorDim(dilation_, data_format_, '2') == 1),
+ errors::InvalidArgument(
+ "Current CPU implementation does not yet support "
+ "dilation rates larger than 1."));
+
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 5,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES(
+ context,
+ (GetTensorDim(stride_, data_format_, 'C') == 1 &&
+ GetTensorDim(stride_, data_format_, 'N') == 1),
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+ const TensorShape& input_shape = input.shape();
+
+ const Tensor& out_backprop = context->input(2);
+ const TensorShape& out_backprop_shape = out_backprop.shape();
+
+ TensorShape filter_shape;
+ if (takes_shape_) {
+ const Tensor& filter_sizes = context->input(1);
+ OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
+ filter_sizes.vec<int32>(), &filter_shape));
+ } else {
+ filter_shape = context->input(1).shape();
+ }
+
+ ConvBackpropDimensions dims;
OP_REQUIRES_OK(context,
- context->allocate_temp(DataTypeToEnum<T>::v(),
- padded_out_shape, &padded_output));
- Eigen::DSizes<Eigen::DenseIndex, 5> eigen_strides{1, strides[0], strides[1],
- strides[2], 1};
- functor::InflatePadAndShuffle<Device, T, 5, Eigen::DenseIndex>()(
- context->eigen_device<Device>(), out_backprop.tensor<T, 5>(),
- eigen_strides, pad_dims, out_order, padded_output.tensor<T, 5>());
- const Tensor& padded_output_cref = padded_output;
-
- // For the backprop of the filter, we need to transpose the input.
- // The shape of input is
- // [batch, in_z, in_y, in_x, in_depth]
- // And we need to change it to
- // [in_z, in_y, in_x, batch, in_depth]
- Eigen::DSizes<Eigen::DenseIndex, 5> in_order{1, 2, 3, 0, 4};
- TensorShape in_shuffle_shape(
- {input_size[0], input_size[1], input_size[2], batch, in_depth});
- Tensor in_shuffle;
+ ConvBackpropComputeDimensions(
+ "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3,
+ input_shape, filter_shape, out_backprop_shape, stride_,
+ padding_, data_format_, &dims));
+
+ Tensor* filter_backprop;
OP_REQUIRES_OK(context,
- context->allocate_temp(DataTypeToEnum<T>::v(),
- in_shuffle_shape, &in_shuffle));
- // No need for reversing this time.
- Eigen::array<bool, 5> no_reverse{false, false, false, false, false};
- functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()(
- context->eigen_device<Device>(), input.tensor<T, 5>(), in_order,
- no_reverse, in_shuffle.tensor<T, 5>());
- const Tensor& in_shuffle_cref = in_shuffle;
-
- // The output of the conv_3d would be
- // [out_depth, filter_size[2], filter_size[1], filter_size[0], in_depth]
- // and we need to shuffle it back to
- // [filter_size[2], filter_size[1], filter_size[0], in_depth, out_depth];
- // And we need to reverse the filter backprops.
- // So we need to allocate (sigh) yet another piece of memory to hold the
- // output.
- TensorShape filter_shuffle_shape(
- {out_depth, filter_size[0], filter_size[1], filter_size[2], in_depth});
- Tensor filter_shuffle;
- OP_REQUIRES_OK(
- context, context->allocate_temp(DataTypeToEnum<T>::v(),
- filter_shuffle_shape, &filter_shuffle));
- functor::CuboidConvolution<Device, T>()(
- context->eigen_device<Device>(), filter_shuffle.tensor<T, 5>(),
- padded_output_cref.tensor<T, 5>(), in_shuffle_cref.tensor<T, 5>(), 1, 1,
- 1, BrainPadding2EigenPadding(VALID));
-
- // Now copy the filter_backprop back to the destination.
- Eigen::DSizes<Eigen::DenseIndex, 5> filter_order{1, 2, 3, 4, 0};
- Eigen::array<bool, 5> filter_rev_dims{true, true, true, false, false};
- const Tensor& filter_shuffle_cref = filter_shuffle;
- functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()(
- context->eigen_device<Device>(), filter_shuffle_cref.tensor<T, 5>(),
- filter_order, filter_rev_dims, filter_backprop->tensor<T, 5>());
+ context->allocate_output(0, filter_shape, &filter_backprop));
+
+ if (input_shape.num_elements() == 0) {
+ filter_backprop->template flat<T>().setZero();
+ return;
+ }
+
+ int64 top_pad_planes, bottom_pad_planes;
+ int64 top_pad_rows, bottom_pad_rows;
+ int64 left_pad_cols, right_pad_cols;
+
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[0].input_size,
+ dims.spatial_dims[0].filter_size,
+ dims.spatial_dims[0].stride, padding_,
+ &dims.spatial_dims[0].output_size,
+ &top_pad_planes, &bottom_pad_planes));
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[1].input_size,
+ dims.spatial_dims[1].filter_size,
+ dims.spatial_dims[1].stride, padding_,
+ &dims.spatial_dims[1].output_size,
+ &top_pad_rows, &bottom_pad_rows));
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[2].input_size,
+ dims.spatial_dims[2].filter_size,
+ dims.spatial_dims[2].stride, padding_,
+ &dims.spatial_dims[2].output_size,
+ &left_pad_cols, &right_pad_cols));
+
+ // TODO(ezhulenev): Extract work size and shard estimation to shared
+ // functions in conv_grad_ops, and update 2d convolution backprop.
+
+ // The total dimension size of each kernel.
+ const int64 filter_total_size =
+ dims.spatial_dims[0].filter_size * dims.spatial_dims[1].filter_size *
+ dims.spatial_dims[2].filter_size * dims.in_depth;
+ // The output image size is the spatial size of the output.
+ const int64 output_image_size = dims.spatial_dims[0].output_size *
+ dims.spatial_dims[1].output_size *
+ dims.spatial_dims[2].output_size;
+
+ // Shard 'batch' images (volumes) into 'shard_size' groups of images
+ // (volumes) to be fed into the parallel matmul. Calculate 'shard_size' by
+ // dividing the L3 cache size ('target_working_set_size') by the matmul size
+ // of an individual image ('work_unit_size').
+
+ const auto cache_sizes = Eigen::internal::CacheSizes();
+ const ptrdiff_t l3_cache_size = cache_sizes.m_l3;
+
+ // TODO(andydavis)
+ // *) Consider reducing 'target_working_set_size' if L3 is shared by
+ // other concurrently running tensorflow ops.
+ const size_t target_working_set_size = l3_cache_size / sizeof(T);
+
+ const int64 size_A = output_image_size * filter_total_size;
+
+ const int64 size_B = output_image_size * dims.out_depth;
+
+ const int64 size_C = filter_total_size * dims.out_depth;
+
+ const int64 work_unit_size = size_A + size_B + size_C;
+
+ const size_t shard_size =
+ (target_working_set_size + work_unit_size - 1) / work_unit_size;
+
+ // Total number of elements in all the tensors used by this kernel.
+ int64 total_tensor_elements = input_shape.num_elements() +
+ filter_shape.num_elements() +
+ out_backprop_shape.num_elements();
+
+ // Shape of the temporary workspace buffer.
+ TensorShape col_buffer_shape = {static_cast<int64>(shard_size),
+ static_cast<int64>(output_image_size),
+ static_cast<int64>(filter_total_size)};
+ int64 col_buffer_elements = col_buffer_shape.num_elements();
+
+ // If the temporary allocation overhead is too large, fallback on Eigen
+ // implementation which requires much less memory.
+ int64 col_buffer_overhead = col_buffer_elements / total_tensor_elements;
+ if (col_buffer_overhead > kMaxTempAllocationOverhead) {
+ VLOG(2) << "Fallback on Eigen implementation of Conv3DBackpropFilterOp: "
+ "col_buffer_overhead="
+ << col_buffer_overhead;
+
+ functor::CuboidConvolutionBackwardFilter<Device, T>()(
+ context->eigen_device<Device>(),
+ filter_backprop->tensor<T, 5>(), // filter_backward
+ input.tensor<T, 5>(), // input
+ out_backprop.tensor<T, 5>(), // output_backward
+ static_cast<int>(dims.spatial_dims[0].stride), // stride_planes
+ static_cast<int>(dims.spatial_dims[1].stride), // stride_rows
+ static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols
+
+ return;
+ }
+
+ Tensor col_buffer;
+ OP_REQUIRES_OK(context,
+ context->allocate_temp(DataTypeToEnum<T>::value,
+ col_buffer_shape, &col_buffer));
+
+ // The input offset corresponding to a single input image.
+ const int64 input_offset = dims.spatial_dims[0].input_size *
+ dims.spatial_dims[1].input_size *
+ dims.spatial_dims[2].input_size * dims.in_depth;
+ // The output offset corresponding to a single output image.
+ const int64 output_offset =
+ dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size *
+ dims.spatial_dims[2].output_size * dims.out_depth;
+
+ const T* input_data = input.template flat<T>().data();
+ T* col_buffer_data = col_buffer.template flat<T>().data();
+ const T* out_backprop_data = out_backprop.template flat<T>().data();
+ T* filter_backprop_data = filter_backprop->template flat<T>().data();
+
+ typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
+ Eigen::Unaligned>
+ TensorMap;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
+ Eigen::Unaligned>
+ ConstTensorMap;
+
+ TensorMap C(filter_backprop_data, filter_total_size, dims.out_depth);
+ C.setZero();
+
+ // Initialize contraction dims (we need to transpose 'A' below).
+ Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
+ contract_dims[0].first = 0;
+ contract_dims[0].second = 0;
+
+ auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
+
+ for (int image_id = 0; image_id < dims.batch_size; image_id += shard_size) {
+ const int shard_limit =
+ std::min(static_cast<int>(shard_size),
+ static_cast<int>(dims.batch_size) - image_id);
+
+ auto shard = [&input_data, &col_buffer_data, &dims, &top_pad_planes,
+ &top_pad_rows, &left_pad_cols, &bottom_pad_planes,
+ &bottom_pad_rows, &right_pad_cols, &input_offset,
+ &size_A](int64 start, int64 limit) {
+ for (int shard_id = start; shard_id < limit; ++shard_id) {
+ const T* input_data_shard = input_data + shard_id * input_offset;
+ T* col_data_shard = col_buffer_data + shard_id * size_A;
+
+ // When we compute the gradient with respect to the filters, we need
+ // to do im2col to allow gemm-type computation.
+ Im2col<T>(input_data_shard, dims.in_depth,
+ // Input spatial dimensions.
+ dims.spatial_dims[0].input_size, // input planes
+ dims.spatial_dims[1].input_size, // input rows
+ dims.spatial_dims[2].input_size, // input cols
+ // Filter spatial dimensions.
+ dims.spatial_dims[0].filter_size, // filter planes
+ dims.spatial_dims[1].filter_size, // filter rows
+ dims.spatial_dims[2].filter_size, // filter cols
+ // Spatial padding.
+ top_pad_planes, top_pad_rows, left_pad_cols,
+ bottom_pad_planes, bottom_pad_rows, right_pad_cols,
+ // Spatial striding.
+ dims.spatial_dims[0].stride, // stride planes
+ dims.spatial_dims[1].stride, // stride rows
+ dims.spatial_dims[2].stride, // stride cols
+ col_data_shard);
+ }
+ };
+ Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
+ size_A, shard);
+
+ ConstTensorMap A(col_buffer_data, output_image_size * shard_limit,
+ filter_total_size);
+ ConstTensorMap B(out_backprop_data, output_image_size * shard_limit,
+ dims.out_depth);
+
+ // Gradient with respect to filter.
+ C.device(context->eigen_cpu_device()) += A.contract(B, contract_dims);
+
+ input_data += input_offset * shard_limit;
+ out_backprop_data += output_offset * shard_limit;
+ }
}
private:
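
A minimal sketch of the contraction used above for the filter gradient (not part of the patch; shapes are synthetic). With contract_dims = {(0, 0)} the shared first dimension, i.e. the output positions across the shard, is summed away, so the update is C[filter_total_size, out_depth] += A^T * B, accumulated shard by shard.

#include <unsupported/Eigen/CXX11/Tensor>

void FilterGradContractionSketch() {
  const int positions = 6;     // output_image_size * shard_limit (assumed)
  const int filter_total = 4;  // filter volume * in_depth (assumed)
  const int out_depth = 3;     // assumed

  Eigen::Tensor<float, 2, Eigen::RowMajor> A(positions, filter_total);  // im2col buffer
  Eigen::Tensor<float, 2, Eigen::RowMajor> B(positions, out_depth);     // out_backprop
  Eigen::Tensor<float, 2, Eigen::RowMajor> C(filter_total, out_depth);  // filter_backprop
  A.setRandom();
  B.setRandom();
  C.setZero();

  Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
  contract_dims[0].first = 0;
  contract_dims[0].second = 0;

  C += A.contract(B, contract_dims);  // the kernel accumulates this across shards
}
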
@@ -421,21 +991,60 @@ class Conv3DBackpropFilterOp : public OpKernel {
Padding padding_;
TensorFormat data_format_;
bool takes_shape_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Conv3DCustomBackpropFilterOp);
};
+// The custom backprop filter kernel is 30% - 4x faster when compiled with AVX2
+// than the default Eigen implementation (at the cost of ~2x-8x peak memory usage).
+
#define REGISTER_CPU_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
- Conv3DBackpropFilterOp<CPUDevice, T>); \
+ Conv3DCustomBackpropFilterOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T"), \
+ Conv3DCustomBackpropFilterOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilter") \
+ .Device(DEVICE_CPU) \
+ .Label("custom") \
+ .TypeConstraint<T>("T"), \
+ Conv3DCustomBackpropFilterOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
+ .Device(DEVICE_CPU) \
+ .Label("custom") \
+ .TypeConstraint<T>("T"), \
+ Conv3DCustomBackpropFilterOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilter") \
+ .Device(DEVICE_CPU) \
+ .Label("eigen_tensor") \
+ .TypeConstraint<T>("T"), \
+ Conv3DBackpropFilterOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
.Device(DEVICE_CPU) \
+ .Label("eigen_tensor") \
.TypeConstraint<T>("T"), \
Conv3DBackpropFilterOp<CPUDevice, T>);
-TF_CALL_half(REGISTER_CPU_KERNEL);
+
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL
+// WARNING: Eigen::half is not trivially copyable and can't be used in
+// custom backprop filter kernel because of memcpy and memset in Im2col.
+#define REGISTER_CPU_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+ Conv3DBackpropFilterOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T"), \
+ Conv3DBackpropFilterOp<CPUDevice, T>);
+
+TF_CALL_half(REGISTER_CPU_KERNEL);
+#undef REGISTER_CPU_KERNEL
+
// GPU definitions of both ops.
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
@@ -523,6 +1132,10 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& filter = context->input(1);
const TensorShape& filter_shape = filter.shape();
+
+ const Tensor& out_backprop = context->input(2);
+ const TensorShape& out_backprop_shape = out_backprop.shape();
+
TensorShape input_shape;
if (takes_shape_) {
const Tensor& input_sizes = context->input(0);
@@ -531,7 +1144,14 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
} else {
input_shape = context->input(0).shape();
}
- EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropInput");
+
+ ConvBackpropDimensions dims;
+ OP_REQUIRES_OK(context,
+ ConvBackpropComputeDimensionsV2(
+ "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
+ input_shape, filter_shape, out_backprop_shape, dilation_,
+ stride_, padding_, data_format_, &dims));
+
Tensor* in_backprop;
OP_REQUIRES_OK(context,
context->allocate_output(0, input_shape, &in_backprop));
@@ -539,13 +1159,15 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
auto* stream = context->op_device_context()->stream();
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
- if (filter_size[0] == 1 && filter_size[1] == 1 && filter_size[2] == 1 &&
- dilation_[0] == 1 && dilation_[1] == 1 && dilation_[2] == 1 &&
- stride_[0] == 1 && stride_[1] == 1 && stride_[2] == 1 &&
+ if (dims.filter_size(0) == 1 && dims.filter_size(1) == 1 &&
+ dims.filter_size(2) == 1 && dims.dilation(0) == 1 &&
+ dims.dilation(1) == 1 && dims.dilation(2) == 1 && dims.stride(0) == 1 &&
+ dims.stride(1) == 1 && dims.stride(2) == 1 &&
data_format_ == FORMAT_NHWC) {
- const uint64 m = batch * input_size[0] * input_size[1] * input_size[2];
- const uint64 k = out_depth;
- const uint64 n = in_depth;
+ const uint64 m = dims.batch_size * dims.input_size(0) *
+ dims.input_size(1) * dims.input_size(2);
+ const uint64 k = dims.out_depth;
+ const uint64 n = dims.in_depth;
auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
out_backprop.template flat<T>().size());
@@ -567,13 +1189,14 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
", n=", n, ", k=", k));
}
return;
- } else if (filter_size[0] == input_size[0] &&
- filter_size[1] == input_size[1] &&
- filter_size[2] == input_size[2] && padding_ == Padding::VALID &&
- data_format_ == FORMAT_NHWC) {
- const uint64 m = batch;
- const uint64 k = out_depth;
- const uint64 n = input_size[0] * input_size[1] * input_size[2] * in_depth;
+ } else if (dims.filter_size(0) == dims.input_size(0) &&
+ dims.filter_size(1) == dims.input_size(1) &&
+ dims.filter_size(2) == dims.input_size(2) &&
+ padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
+ const uint64 m = dims.batch_size;
+ const uint64 k = dims.out_depth;
+ const uint64 n = dims.input_size(0) * dims.input_size(1) *
+ dims.input_size(2) * dims.in_depth;
auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
out_backprop.template flat<T>().size());
@@ -597,65 +1220,59 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
return;
}
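
Both GPU fast paths above collapse the convolution backprop into a single GEMM. A minimal sketch of the first one (synthetic sizes, not from the patch): with a 1x1x1 filter, unit stride and dilation, and NDHWC layout, the convolution is pointwise channel mixing, so the input gradient over m = batch * planes * rows * cols positions is in_backprop[m, in_depth] = out_backprop[m, out_depth] * filter[in_depth, out_depth]^T.

#include <Eigen/Dense>

Eigen::MatrixXf PointwiseConv3DInputGradSketch() {
  const int m = 2 * 4 * 4 * 4;            // batch * D * H * W (assumed)
  const int in_depth = 8, out_depth = 16;  // assumed

  Eigen::MatrixXf out_backprop = Eigen::MatrixXf::Random(m, out_depth);
  Eigen::MatrixXf filter = Eigen::MatrixXf::Random(in_depth, out_depth);

  return out_backprop * filter.transpose();  // [m, in_depth]
}
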
- int padding_rows = 0, padding_cols = 0, padding_planes = 0;
-
- if (padding_ == Padding::SAME) {
- padding_planes = std::max<int>(
- 0, (output_planes - 1) * strides[0] + filter_size[0] - input_size[0]);
- padding_cols = std::max<int>(
- 0, (output_cols - 1) * strides[2] + filter_size[2] - input_size[2]);
- padding_rows = std::max<int>(
- 0, (output_rows - 1) * strides[1] + filter_size[1] - input_size[1]);
- }
+ int padding_planes = dims.SpatialPadding(padding_, 0);
+ int padding_rows = dims.SpatialPadding(padding_, 1);
+ int padding_cols = dims.SpatialPadding(padding_, 2);
+ const bool planes_odd = (padding_planes % 2 != 0);
const bool rows_odd = (padding_rows % 2 != 0);
const bool cols_odd = (padding_cols % 2 != 0);
- const bool planes_odd = (padding_planes % 2 != 0);
TensorShape compatible_input_shape;
if (rows_odd || cols_odd || planes_odd) {
// cuDNN only supports the same amount of padding on both sides.
compatible_input_shape = {
- batch,
- in_depth,
- input_size[0] + planes_odd,
- input_size[1] + rows_odd,
- input_size[2] + cols_odd,
+ dims.batch_size,
+ dims.in_depth,
+ dims.input_size(0) + planes_odd,
+ dims.input_size(1) + rows_odd,
+ dims.input_size(2) + cols_odd,
};
} else {
- compatible_input_shape = {batch, in_depth, input_size[0], input_size[1],
- input_size[2]};
+ compatible_input_shape = {dims.batch_size, dims.in_depth,
+ dims.input_size(0), dims.input_size(1),
+ dims.input_size(2)};
}
CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0)
<< "Negative paddings: (" << padding_rows << ", " << padding_cols
<< ", " << padding_planes << ")";
se::dnn::BatchDescriptor input_desc(3);
- input_desc.set_count(batch)
+ input_desc.set_count(dims.batch_size)
.set_spatial_dim(DimIndex::X, compatible_input_shape.dim_size(4))
.set_spatial_dim(DimIndex::Y, compatible_input_shape.dim_size(3))
.set_spatial_dim(DimIndex::Z, compatible_input_shape.dim_size(2))
- .set_feature_map_count(in_depth)
+ .set_feature_map_count(dims.in_depth)
.set_layout(se::dnn::DataLayout::kBatchDepthYX);
se::dnn::BatchDescriptor output_desc(3);
- output_desc.set_count(batch)
- .set_spatial_dim(DimIndex::X, output_cols)
- .set_spatial_dim(DimIndex::Y, output_rows)
- .set_spatial_dim(DimIndex::Z, output_planes)
- .set_feature_map_count(out_depth)
+ output_desc.set_count(dims.batch_size)
+ .set_spatial_dim(DimIndex::X, dims.output_size(2))
+ .set_spatial_dim(DimIndex::Y, dims.output_size(1))
+ .set_spatial_dim(DimIndex::Z, dims.output_size(0))
+ .set_feature_map_count(dims.out_depth)
.set_layout(se::dnn::DataLayout::kBatchDepthYX);
se::dnn::FilterDescriptor filter_desc(3);
- filter_desc.set_spatial_dim(DimIndex::X, filter_size[2])
- .set_spatial_dim(DimIndex::Y, filter_size[1])
- .set_spatial_dim(DimIndex::Z, filter_size[0])
- .set_input_feature_map_count(in_depth)
- .set_output_feature_map_count(out_depth);
+ filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
+ .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
+ .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
+ .set_input_feature_map_count(dims.in_depth)
+ .set_output_feature_map_count(dims.out_depth);
se::dnn::ConvolutionDescriptor conv_desc(3);
- conv_desc.set_dilation_rate(DimIndex::X, dilations[2])
- .set_dilation_rate(DimIndex::Y, dilations[1])
- .set_dilation_rate(DimIndex::Z, dilations[0])
- .set_filter_stride(DimIndex::X, strides[2])
- .set_filter_stride(DimIndex::Y, strides[1])
- .set_filter_stride(DimIndex::Z, strides[0])
+ conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
+ .set_dilation_rate(DimIndex::Y, dims.dilation(1))
+ .set_dilation_rate(DimIndex::Z, dims.dilation(0))
+ .set_filter_stride(DimIndex::X, dims.stride(2))
+ .set_filter_stride(DimIndex::Y, dims.stride(1))
+ .set_filter_stride(DimIndex::Z, dims.stride(0))
.set_zero_padding(DimIndex::X, padding_cols / 2)
.set_zero_padding(DimIndex::Y, padding_rows / 2)
.set_zero_padding(DimIndex::Z, padding_planes / 2);
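
A worked sketch of the SAME-padding arithmetic behind padding_planes/rows/cols and the odd-padding fixup above (illustrative sizes, not from the patch). cuDNN takes one symmetric padding value per dimension, so when the total SAME padding is odd the input is first grown by one on that dimension (compatible_input_shape) and the per-side value handed to set_zero_padding() is total / 2.

#include <algorithm>
#include <cstdint>

struct SamePadding {
  int64_t total;      // max(0, (out - 1) * stride + filter - in)
  bool odd;           // needs the compatible_input +1 adjustment
  int64_t symmetric;  // per-side value for set_zero_padding()
};

SamePadding ComputeSamePadding(int64_t in, int64_t filter, int64_t stride) {
  const int64_t out = (in + stride - 1) / stride;  // SAME output size
  const int64_t total = std::max<int64_t>(0, (out - 1) * stride + filter - in);
  return SamePadding{total, (total % 2) != 0, total / 2};
}

// Example: in = 7, filter = 3, stride = 2 -> out = 4, total = 2, symmetric = 1.
// Example: in = 8, filter = 3, stride = 2 -> out = 4, total = 1 (odd), so the
// compatible input dimension becomes 9 and each side gets padding 0.
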
@@ -664,10 +1281,11 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
Tensor transformed_filter;
OP_REQUIRES_OK(
context,
- context->allocate_temp(DataTypeToEnum<T>::value,
- TensorShape({out_depth, in_depth, filter_size[0],
- filter_size[1], filter_size[2]}),
- &transformed_filter));
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape({dims.out_depth, dims.in_depth, dims.filter_size(0),
+ dims.filter_size(1), dims.filter_size(2)}),
+ &transformed_filter));
functor::TransformFilter<GPUDevice, T, int, 5>()(
context->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 5>()),
To32Bit(transformed_filter.tensor<T, 5>()));
@@ -675,9 +1293,10 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
// Shape: batch, filters, z, y, x.
Tensor transformed_out_backprop;
if (data_format_ == FORMAT_NHWC) {
- TensorShape nchw_shape = {batch, out_depth, output_planes, output_rows,
- output_cols};
- if (out_depth > 1) {
+ TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
+ dims.output_size(0), dims.output_size(1),
+ dims.output_size(2)};
+ if (dims.out_depth > 1) {
OP_REQUIRES_OK(context, context->allocate_temp(
DataTypeToEnum<T>::value, nchw_shape,
&transformed_out_backprop));
@@ -713,14 +1332,14 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
const int device_id = stream->parent()->device_ordinal();
DataType dtype = context->input(0).dtype();
const ConvParameters conv_parameters = {
- batch,
- in_depth,
- {{input_size[0], input_size[1], input_size[2]}},
+ dims.batch_size,
+ dims.in_depth,
+ {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
FORMAT_NCHW,
- out_depth,
- {{filter_size[0], filter_size[1], filter_size[2]}},
- {{dilations[0], dilations[1], dilations[2]}},
- {{strides[0], strides[1], strides[2]}},
+ dims.out_depth,
+ {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
+ {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
+ {{dims.stride(0), dims.stride(1), dims.stride(2)}},
{{padding_planes, padding_rows, padding_cols}},
dtype,
device_id,
@@ -799,10 +1418,11 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
if (rows_odd || cols_odd || planes_odd) {
Tensor in_backprop_remove_padding;
OP_REQUIRES_OK(context,
- context->allocate_temp(DataTypeToEnum<T>::value,
- {batch, in_depth, input_size[0],
- input_size[1], input_size[2]},
- &in_backprop_remove_padding));
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ {dims.batch_size, dims.in_depth, dims.input_size(0),
+ dims.input_size(1), dims.input_size(2)},
+ &in_backprop_remove_padding));
// Remove the padding for odd spatial dimensions.
functor::PadInput<GPUDevice, T, int, 5>()(
@@ -896,6 +1516,10 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
const TensorShape& input_shape = input.shape();
+
+ const Tensor& out_backprop = context->input(2);
+ const TensorShape& out_backprop_shape = out_backprop.shape();
+
TensorShape filter_shape;
if (takes_shape_) {
const Tensor& filter_sizes = context->input(1);
@@ -905,7 +1529,12 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
filter_shape = context->input(1).shape();
}
- EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropFilter");
+ ConvBackpropDimensions dims;
+ OP_REQUIRES_OK(context,
+ ConvBackpropComputeDimensionsV2(
+ "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3,
+ input_shape, filter_shape, out_backprop_shape, dilation_,
+ stride_, padding_, data_format_, &dims));
Tensor* filter_backprop;
OP_REQUIRES_OK(context,
@@ -914,13 +1543,15 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
auto* stream = context->op_device_context()->stream();
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
- if (filter_size[1] == 1 && filter_size[2] == 1 && filter_size[0] == 1 &&
- dilations[2] == 1 && dilations[1] == 1 && dilations[0] == 1 &&
- strides[2] == 1 && strides[1] == 1 && strides[0] == 1 &&
+ if (dims.filter_size(1) == 1 && dims.filter_size(2) == 1 &&
+ dims.filter_size(0) == 1 && dims.dilation(2) == 1 &&
+ dims.dilation(1) == 1 && dims.dilation(0) == 1 && dims.stride(2) == 1 &&
+ dims.stride(1) == 1 && dims.stride(0) == 1 &&
data_format_ == FORMAT_NHWC) {
- const uint64 m = in_depth;
- const uint64 k = batch * input_size[1] * input_size[2] * input_size[0];
- const uint64 n = out_depth;
+ const uint64 m = dims.in_depth;
+ const uint64 k = dims.batch_size * dims.input_size(1) *
+ dims.input_size(2) * dims.input_size(0);
+ const uint64 n = dims.out_depth;
// The shape of output backprop is
// [batch, out_z, out_y, out_x, out_depth]
@@ -951,13 +1582,14 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
", n=", n, ", k=", k));
}
return;
- } else if (filter_size[0] == input_size[0] &&
- filter_size[1] == input_size[1] &&
- filter_size[2] == input_size[2] && padding_ == Padding::VALID &&
- data_format_ == FORMAT_NHWC) {
- const uint64 m = input_size[0] * input_size[1] * input_size[2] * in_depth;
- const uint64 k = batch;
- const uint64 n = out_depth;
+ } else if (dims.filter_size(0) == dims.input_size(0) &&
+ dims.filter_size(1) == dims.input_size(1) &&
+ dims.filter_size(2) == dims.input_size(2) &&
+ padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
+ const uint64 m = dims.input_size(0) * dims.input_size(1) *
+ dims.input_size(2) * dims.in_depth;
+ const uint64 k = dims.batch_size;
+ const uint64 n = dims.out_depth;
auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
input.template flat<T>().size());
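
The second fast path covers a filter that spans the entire VALID-padded input in NDHWC: each example then yields a single 1x1x1 output position, so out_backprop is effectively a [batch, out_depth] matrix and the same reduction applies with different shapes. Using the hypothetical helper sketched above, this would be PointwiseFilterGrad(input_flat, out_backprop_flat, &filter_grad, m, n, k) with:

    // m = dims.input_size(0) * dims.input_size(1) * dims.input_size(2) * dims.in_depth
    // k = dims.batch_size
    // n = dims.out_depth
    // filter_grad [m x n] = input^T [m x k] * out_backprop [k x n]
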
@@ -979,30 +1611,24 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
return;
}
- int padding_rows = 0, padding_cols = 0, padding_planes = 0;
-
- if (padding_ == Padding::SAME) {
- padding_planes = std::max<int>(
- 0, (output_planes - 1) * strides[0] + filter_size[0] - input_size[0]);
- padding_cols = std::max<int>(
- 0, (output_cols - 1) * strides[2] + filter_size[2] - input_size[2]);
- padding_rows = std::max<int>(
- 0, (output_rows - 1) * strides[1] + filter_size[1] - input_size[1]);
- }
- bool rows_odd = (padding_rows % 2 != 0);
- bool cols_odd = (padding_cols % 2 != 0);
- bool planes_odd = (padding_planes % 2 != 0);
+ int padding_planes = dims.SpatialPadding(padding_, 0);
+ int padding_rows = dims.SpatialPadding(padding_, 1);
+ int padding_cols = dims.SpatialPadding(padding_, 2);
+ const bool planes_odd = (padding_planes % 2 != 0);
+ const bool rows_odd = (padding_rows % 2 != 0);
+ const bool cols_odd = (padding_cols % 2 != 0);
Tensor compatible_input;
if (rows_odd || cols_odd || planes_odd) {
- OP_REQUIRES_OK(context, context->allocate_temp(
- DataTypeToEnum<T>::value,
- ShapeFromFormat(data_format_, batch,
- {{input_size[0] + planes_odd,
- input_size[1] + rows_odd,
- input_size[2] + cols_odd}},
- in_depth),
- &compatible_input));
+ OP_REQUIRES_OK(context,
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ ShapeFromFormat(data_format_, dims.batch_size,
+ {{dims.input_size(0) + planes_odd,
+ dims.input_size(1) + rows_odd,
+ dims.input_size(2) + cols_odd}},
+ dims.in_depth),
+ &compatible_input));
functor::PadInput<GPUDevice, T, int, 5>()(
context->template eigen_device<GPUDevice>(),
To32Bit(input.tensor<T, 5>()), {{0, 0, 0}},
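
For reference, the deleted block computed the total SAME padding per spatial dimension explicitly; the new ConvBackpropDimensions::SpatialPadding call is assumed to return the same quantity (the removed formula did not fold dilation into the filter extent). A sketch of the arithmetic and of why the odd-padding branch exists:

    // Total SAME padding along one spatial axis, as in the removed lines:
    //   pad_total = max(0, (output - 1) * stride + filter - input);
    // The cuDNN descriptors take a single symmetric value (pad_total / 2), so
    // an odd pad_total cannot be expressed directly; the kernel instead grows
    // the input by one element on that axis (compatible_input) and keeps the
    // even remainder for set_zero_padding().
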
@@ -1016,35 +1642,35 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
<< "Negative paddings: (" << padding_rows << ", " << padding_cols
<< ", " << padding_planes << ")";
se::dnn::BatchDescriptor input_desc(3);
- input_desc.set_count(batch)
+ input_desc.set_count(dims.batch_size)
.set_spatial_dim(DimIndex::X,
GetTensorDim(compatible_input, data_format_, '2'))
.set_spatial_dim(DimIndex::Y,
GetTensorDim(compatible_input, data_format_, '1'))
.set_spatial_dim(DimIndex::Z,
GetTensorDim(compatible_input, data_format_, '0'))
- .set_feature_map_count(in_depth)
+ .set_feature_map_count(dims.in_depth)
.set_layout(se::dnn::DataLayout::kBatchDepthYX);
se::dnn::BatchDescriptor output_desc(3);
- output_desc.set_count(batch)
- .set_spatial_dim(DimIndex::X, output_cols)
- .set_spatial_dim(DimIndex::Y, output_rows)
- .set_spatial_dim(DimIndex::Z, output_planes)
- .set_feature_map_count(out_depth)
+ output_desc.set_count(dims.batch_size)
+ .set_spatial_dim(DimIndex::X, dims.output_size(2))
+ .set_spatial_dim(DimIndex::Y, dims.output_size(1))
+ .set_spatial_dim(DimIndex::Z, dims.output_size(0))
+ .set_feature_map_count(dims.out_depth)
.set_layout(se::dnn::DataLayout::kBatchDepthYX);
se::dnn::FilterDescriptor filter_desc(3);
- filter_desc.set_spatial_dim(DimIndex::X, filter_size[2])
- .set_spatial_dim(DimIndex::Y, filter_size[1])
- .set_spatial_dim(DimIndex::Z, filter_size[0])
- .set_input_feature_map_count(in_depth)
- .set_output_feature_map_count(out_depth);
+ filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
+ .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
+ .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
+ .set_input_feature_map_count(dims.in_depth)
+ .set_output_feature_map_count(dims.out_depth);
se::dnn::ConvolutionDescriptor conv_desc(3);
- conv_desc.set_dilation_rate(DimIndex::X, dilations[2])
- .set_dilation_rate(DimIndex::Y, dilations[1])
- .set_dilation_rate(DimIndex::Z, dilations[0])
- .set_filter_stride(DimIndex::X, strides[2])
- .set_filter_stride(DimIndex::Y, strides[1])
- .set_filter_stride(DimIndex::Z, strides[0])
+ conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
+ .set_dilation_rate(DimIndex::Y, dims.dilation(1))
+ .set_dilation_rate(DimIndex::Z, dims.dilation(0))
+ .set_filter_stride(DimIndex::X, dims.stride(2))
+ .set_filter_stride(DimIndex::Y, dims.stride(1))
+ .set_filter_stride(DimIndex::Z, dims.stride(0))
.set_zero_padding(DimIndex::X, padding_cols / 2)
.set_zero_padding(DimIndex::Y, padding_rows / 2)
.set_zero_padding(DimIndex::Z, padding_planes / 2);
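
All three cuDNN descriptors above use the same mapping between the logical spatial order of ConvBackpropDimensions and StreamExecutor's DimIndex:

    // dims index 0 (planes) -> DimIndex::Z
    // dims index 1 (rows)   -> DimIndex::Y
    // dims index 2 (cols)   -> DimIndex::X
    // i.e. the innermost spatial dimension maps to X.
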
@@ -1052,19 +1678,21 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
Tensor pre_transformed_filter_backprop;
OP_REQUIRES_OK(
context,
- context->allocate_temp(DataTypeToEnum<T>::value,
- TensorShape({out_depth, in_depth, filter_size[0],
- filter_size[1], filter_size[2]}),
- &pre_transformed_filter_backprop));
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape({dims.out_depth, dims.in_depth, dims.filter_size(0),
+ dims.filter_size(1), dims.filter_size(2)}),
+ &pre_transformed_filter_backprop));
Tensor transformed_out_backprop;
if (data_format_ == FORMAT_NHWC) {
- TensorShape nchw_shape = {batch, out_depth, output_planes, output_rows,
- output_cols};
+ TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
+ dims.output_size(0), dims.output_size(1),
+ dims.output_size(2)};
OP_REQUIRES_OK(
context, context->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
&transformed_out_backprop));
- if (out_depth > 1) {
+ if (dims.out_depth > 1) {
functor::NHWCToNCHW<GPUDevice, T, 5>()(
context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(),
transformed_out_backprop.tensor<T, 5>());
@@ -1076,10 +1704,10 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
}
Tensor transformed_input;
if (data_format_ == FORMAT_NHWC) {
- TensorShape nchw_shape = {batch, in_depth, compatible_input.dim_size(1),
- compatible_input.dim_size(2),
- compatible_input.dim_size(3)};
- if (in_depth > 1) {
+ TensorShape nchw_shape = {
+ dims.batch_size, dims.in_depth, compatible_input.dim_size(1),
+ compatible_input.dim_size(2), compatible_input.dim_size(3)};
+ if (dims.in_depth > 1) {
OP_REQUIRES_OK(context,
context->allocate_temp(DataTypeToEnum<T>::value,
nchw_shape, &transformed_input));
@@ -1110,14 +1738,14 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
const int device_id = stream->parent()->device_ordinal();
DataType dtype = input.dtype();
const ConvParameters conv_parameters = {
- batch,
- in_depth,
- {{input_size[0], input_size[1], input_size[2]}},
+ dims.batch_size,
+ dims.in_depth,
+ {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
FORMAT_NCHW,
- out_depth,
- {{filter_size[0], filter_size[1], filter_size[2]}},
- {{dilations[0], dilations[1], dilations[2]}},
- {{strides[0], strides[1], strides[2]}},
+ dims.out_depth,
+ {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
+ {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
+ {{dims.stride(0), dims.stride(1), dims.stride(2)}},
{{padding_planes, padding_rows, padding_cols}},
dtype,
device_id,