diff options
Diffstat (limited to 'tensorflow/core/kernels/pooling_ops_3d.cc')
-rw-r--r-- | tensorflow/core/kernels/pooling_ops_3d.cc | 314 |
1 files changed, 283 insertions, 31 deletions
diff --git a/tensorflow/core/kernels/pooling_ops_3d.cc b/tensorflow/core/kernels/pooling_ops_3d.cc index f12c18eaa8..538dca24ae 100644 --- a/tensorflow/core/kernels/pooling_ops_3d.cc +++ b/tensorflow/core/kernels/pooling_ops_3d.cc @@ -14,12 +14,15 @@ limitations under the License. ==============================================================================*/ #define EIGEN_USE_THREADS +#include "tensorflow/core/kernels/pooling_ops_3d.h" + #include <array> #include "third_party/eigen3/Eigen/Core" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" @@ -28,15 +31,64 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" +#include "tensorflow/core/util/work_sharder.h" #if GOOGLE_CUDA #include "tensorflow/core/kernels/cudnn_pooling_gpu.h" +#include "tensorflow/core/kernels/pooling_ops_3d_gpu.h" #endif namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +Pool3dParameters::Pool3dParameters(OpKernelContext* context, + const std::vector<int32>& ksize, + const std::vector<int32>& stride, + Padding padding, TensorFormat data_format, + const TensorShape& tensor_in_shape) { + // For maxpooling, tensor_in should have 4 dimensions. 
+ OP_REQUIRES(context, tensor_in_shape.dims() == 5, + errors::InvalidArgument("tensor_in must be 5-dimensional")); + + this->data_format = data_format; + depth = GetTensorDim(tensor_in_shape, data_format, 'C'); + tensor_in_planes = GetTensorDim(tensor_in_shape, data_format, '0'); + tensor_in_rows = GetTensorDim(tensor_in_shape, data_format, '1'); + tensor_in_cols = GetTensorDim(tensor_in_shape, data_format, '2'); + tensor_in_batch = GetTensorDim(tensor_in_shape, data_format, 'N'); + window_planes = GetTensorDim(ksize, data_format, '0'); + window_rows = GetTensorDim(ksize, data_format, '1'); + window_cols = GetTensorDim(ksize, data_format, '2'); + depth_window = GetTensorDim(ksize, data_format, 'C'); + plane_stride = GetTensorDim(stride, data_format, '0'); + row_stride = GetTensorDim(stride, data_format, '1'); + col_stride = GetTensorDim(stride, data_format, '2'); + depth_stride = GetTensorDim(stride, data_format, 'C'); + + // We only support 3D pooling across plane/width/height. Depthwise + // pooling is not supported. 
+ OP_REQUIRES( + context, depth_window == 1 && depth_stride == 1, + errors::Unimplemented( + "Pooling3d only supports pooling across plane/width/height.")); + + OP_REQUIRES_OK(context, GetWindowedOutputSize(tensor_in_planes, window_planes, + plane_stride, padding, + &out_plane, &pad_planes)); + OP_REQUIRES_OK(context, + GetWindowedOutputSize(tensor_in_rows, window_rows, row_stride, + padding, &out_height, &pad_rows)); + OP_REQUIRES_OK(context, + GetWindowedOutputSize(tensor_in_cols, window_cols, col_stride, + padding, &out_width, &pad_cols)); +} + +TensorShape Pool3dParameters::forward_output_shape() { + return ShapeFromFormat(data_format, tensor_in_batch, + {{out_plane, out_height, out_width}}, depth); +} + enum PoolingType { MAX, AVG }; template <typename Device, typename T, PoolingType Type> @@ -147,12 +199,6 @@ class Pooling3DOp : public UnaryOp<T> { Padding padding_; TensorFormat data_format_; }; -REGISTER_KERNEL_BUILDER( - Name("AvgPool3D").Device(DEVICE_CPU).TypeConstraint<float>("T"), - Pooling3DOp<CPUDevice, float, AVG>); -REGISTER_KERNEL_BUILDER( - Name("MaxPool3D").Device(DEVICE_CPU).TypeConstraint<float>("T"), - Pooling3DOp<CPUDevice, float, MAX>); template <typename Device, typename T> struct LaunchMaxPooling3dGradOp; @@ -331,10 +377,6 @@ class MaxPooling3dGradOp : public OpKernel { TensorFormat data_format_; }; -REGISTER_KERNEL_BUILDER( - Name("MaxPool3DGrad").Device(DEVICE_CPU).TypeConstraint<float>("T"), - MaxPooling3dGradOp<CPUDevice, float>); - template <typename Device, typename T> struct LaunchAvgPooling3dGradOp; @@ -499,11 +541,208 @@ class AvgPooling3dGradOp : public OpKernel { TensorFormat data_format_; }; -REGISTER_KERNEL_BUILDER(Name("AvgPool3DGrad") - .Device(DEVICE_CPU) - .TypeConstraint<float>("T") - .HostMemory("orig_input_shape"), - AvgPooling3dGradOp<CPUDevice, float>); +template <typename Device, typename T> +struct LaunchMaxPooling3dGradGradOp; + +template <typename T> +struct LaunchMaxPooling3dGradGradOp<CPUDevice, T> { + static 
void launch(OpKernelContext* context, const Pool3dParameters& params, + const Tensor& tensor_in, const Tensor& tensor_out, + const Tensor& tensor_top_diff, + Tensor* tensor_bottom_diff) { + OP_REQUIRES( + context, params.data_format == FORMAT_NHWC, + errors::InvalidArgument("Default MaxPooling3dGradGradOp only supports", + "NDHWC on CPU device type")); + + typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> + ConstEigenMatrixMap; + typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> + EigenMatrixMap; + + ConstEigenMatrixMap in_mat(tensor_in.flat<T>().data(), params.depth, + params.tensor_in_planes * params.tensor_in_cols * + params.tensor_in_rows * + params.tensor_in_batch); + ConstEigenMatrixMap out_mat(tensor_out.flat<T>().data(), params.depth, + params.out_plane * params.out_width * + params.out_height * params.tensor_in_batch); + ConstEigenMatrixMap top_diff_mat( + tensor_top_diff.flat<T>().data(), params.depth, + params.tensor_in_planes * params.tensor_in_cols * + params.tensor_in_rows * params.tensor_in_batch); + EigenMatrixMap bottom_diff_mat( + tensor_bottom_diff->flat<T>().data(), params.depth, + params.out_plane * params.out_width * params.out_height * + params.tensor_in_batch); + + const DeviceBase::CpuWorkerThreads& worker_threads = + *(context->device()->tensorflow_cpu_worker_threads()); + + auto shard = [&params, &in_mat, &out_mat, &top_diff_mat, &bottom_diff_mat]( + int64 start, int64 limit) { + const int32 depth = params.depth; + const int32 in_planes = params.tensor_in_planes; + const int32 in_rows = params.tensor_in_rows; + const int32 in_cols = params.tensor_in_cols; + const int32 pad_planes = params.pad_planes; + const int32 pad_rows = params.pad_rows; + const int32 pad_cols = params.pad_cols; + const int32 window_planes = params.window_planes; + const int32 window_rows = params.window_rows; + const int32 window_cols = params.window_cols; + const int32 plane_stride = params.plane_stride; + const int32 
row_stride = params.row_stride; + const int32 col_stride = params.col_stride; + const int32 out_plane = params.out_plane; + const int32 out_height = params.out_height; + const int32 out_width = params.out_width; + + { + // Initializes the output grad backprop tensor with 0. + const int32 output_image_size = + out_plane * out_height * out_width * params.depth; + EigenMatrixMap bottom_diff_shard( + bottom_diff_mat.data() + start * output_image_size, 1, + (limit - start) * output_image_size); + bottom_diff_shard.setZero(); + } + + for (int b = start; b < limit; ++b) { + for (int pp = 0; pp < out_plane; ++pp) { + for (int ph = 0; ph < out_height; ++ph) { + for (int pw = 0; pw < out_width; ++pw) { + // (p_start, p_end) * (h_start, h_end) * (w_start, w_end) is the + // range that the input vector projects to. + int p_start = pp * plane_stride - pad_planes; + const int p_end = std::min(p_start + window_planes, in_planes); + int h_start = ph * row_stride - pad_rows; + const int h_end = std::min(h_start + window_rows, in_rows); + int w_start = pw * col_stride - pad_cols; + const int w_end = std::min(w_start + window_cols, in_cols); + p_start = std::max(p_start, 0); + h_start = std::max(h_start, 0); + w_start = std::max(w_start, 0); + const int out_index = + ((b * out_plane + pp) * out_height + ph) * out_width + pw; + // Find value corresponding to the input maximum in top_diff. 
+ for (int d = 0; d < depth; ++d) { + const T& output_ref = out_mat.coeffRef(d, out_index); + bool should_stop = false; + for (int p = p_start; p < p_end && !should_stop; ++p) { + for (int h = h_start; h < h_end && !should_stop; ++h) { + for (int w = w_start; w < w_end && !should_stop; ++w) { + const int in_index = + ((b * in_planes + p) * in_rows + h) * in_cols + w; + const T& input_ref = in_mat.coeffRef(d, in_index); + if (output_ref == input_ref) { + T& bottom_diff_ref = + bottom_diff_mat.coeffRef(d, out_index); + bottom_diff_ref = top_diff_mat.coeffRef(d, in_index); + should_stop = true; + } + } + } + } + } + } + } + } + } + }; + const int64 shard_cost = + params.out_plane * params.out_height * params.out_width * params.depth * + params.window_planes * params.window_rows * params.window_cols; + Shard(worker_threads.num_threads, worker_threads.workers, + params.tensor_in_batch, shard_cost, shard); + } +}; + +template <class Device, class T> +class MaxPooling3dGradGradOp : public OpKernel { + public: + explicit MaxPooling3dGradGradOp(OpKernelConstruction* context) + : OpKernel(context) { + string data_format; + OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); + OP_REQUIRES(context, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); + OP_REQUIRES(context, ksize_.size() == 5, + errors::InvalidArgument("Sliding window ksize field must " + "specify 5 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); + OP_REQUIRES(context, stride_.size() == 5, + errors::InvalidArgument("Sliding window strides field must " + "specify 5 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1, + errors::Unimplemented( + "Pooling is not yet supported on the batch dimension.")); + const int32 ksize_c = GetTensorDim(ksize_, data_format_, 'C'); + 
const int32 stride_c = GetTensorDim(stride_, data_format_, 'C'); + OP_REQUIRES(context, ksize_c == 1 && stride_c == 1, + errors::Unimplemented("MaxPooling3dGradGrad is not yet " + "supported on the depth dimension.")); + } + + void Compute(OpKernelContext* context) override { + const Tensor& tensor_in = context->input(0); + const Tensor& tensor_out = context->input(1); + const Tensor& out_grad_backprop = context->input(2); + + // For maxpooling3d, tensor_in should have 5 dimensions. + OP_REQUIRES(context, tensor_in.dims() == 5, + errors::InvalidArgument("tensor_in must be 5-dimensional")); + OP_REQUIRES(context, tensor_out.dims() == 5, + errors::InvalidArgument("tensor_out must be 5-dimensional")); + // For maxpooling3d, out_grad_backprop should have 5 dimensions. + OP_REQUIRES( + context, out_grad_backprop.dims() == 5, + errors::InvalidArgument("out_grad_backprop must be 5-dimensional")); + + Pool3dParameters params{context, ksize_, stride_, + padding_, data_format_, tensor_in.shape()}; + + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( + {2}, 0, tensor_out.shape(), &output)); + + LaunchMaxPooling3dGradGradOp<Device, T>::launch( + context, params, tensor_in, tensor_out, out_grad_backprop, output); + } + + private: + std::vector<int32> ksize_; + std::vector<int32> stride_; + Padding padding_; + TensorFormat data_format_; +}; + +#define REGISTER_KERNELS(D, T) \ + REGISTER_KERNEL_BUILDER( \ + Name("MaxPool3D").Device(DEVICE_##D).TypeConstraint<T>("T"), \ + Pooling3DOp<D##Device, T, MAX>); \ + REGISTER_KERNEL_BUILDER(Name("MaxPool3DGrad") \ + .Device(DEVICE_##D) \ + .TypeConstraint<T>("T") \ + .TypeConstraint<T>("TInput"), \ + MaxPooling3dGradOp<D##Device, T>); \ + REGISTER_KERNEL_BUILDER( \ + Name("MaxPool3DGradGrad").Device(DEVICE_##D).TypeConstraint<T>("T"), \ + MaxPooling3dGradGradOp<D##Device, T>); \ + REGISTER_KERNEL_BUILDER( \ + Name("AvgPool3D").Device(DEVICE_##D).TypeConstraint<T>("T"), \ + 
Pooling3DOp<D##Device, T, AVG>); \ + REGISTER_KERNEL_BUILDER(Name("AvgPool3DGrad") \ + .Device(DEVICE_##D) \ + .TypeConstraint<T>("T") \ + .HostMemory("orig_input_shape"), \ + AvgPooling3dGradOp<D##Device, T>); + +#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T) +TF_CALL_float(REGISTER_CPU_KERNELS); +#undef REGISTER_CPU_KERNELS #if GOOGLE_CUDA @@ -535,13 +774,6 @@ struct LaunchPoolingOp<GPUDevice, T, MAX> { } }; -REGISTER_KERNEL_BUILDER( - Name("AvgPool3D").Device(DEVICE_GPU).TypeConstraint<float>("T"), - Pooling3DOp<GPUDevice, float, AVG>); -REGISTER_KERNEL_BUILDER( - Name("MaxPool3D").Device(DEVICE_GPU).TypeConstraint<float>("T"), - Pooling3DOp<GPUDevice, float, MAX>); - template <typename T> struct LaunchMaxPooling3dGradOp<GPUDevice, T> { static void launch(OpKernelContext* context, const Tensor& tensor_in, @@ -559,10 +791,6 @@ struct LaunchMaxPooling3dGradOp<GPUDevice, T> { } }; -REGISTER_KERNEL_BUILDER( - Name("MaxPool3DGrad").Device(DEVICE_GPU).TypeConstraint<float>("T"), - MaxPooling3dGradOp<GPUDevice, float>); - template <typename T> struct LaunchAvgPooling3dGradOp<GPUDevice, T> { static void launch(OpKernelContext* context, @@ -579,12 +807,36 @@ struct LaunchAvgPooling3dGradOp<GPUDevice, T> { nullptr, nullptr, output); } }; -REGISTER_KERNEL_BUILDER(Name("AvgPool3DGrad") - .Device(DEVICE_GPU) - .TypeConstraint<float>("T") - .HostMemory("orig_input_shape"), - AvgPooling3dGradOp<GPUDevice, float>); + +template <typename T> +struct LaunchMaxPooling3dGradGradOp<GPUDevice, T> { + static void launch(OpKernelContext* context, const Pool3dParameters& params, + const Tensor& tensor_in, const Tensor& tensor_out, + const Tensor& tensor_top_diff, + Tensor* tensor_bottom_diff) { + bool status = functor::MaxPool3dGradBackward<T>()( + params.data_format, tensor_in.flat<T>().data(), + tensor_out.flat<T>().data(), params.tensor_in_batch, params.out_plane, + params.out_height, params.out_width, params.depth, + params.tensor_in_planes, params.tensor_in_rows, 
params.tensor_in_cols, + params.window_planes, params.window_rows, params.window_cols, + params.plane_stride, params.row_stride, params.col_stride, + params.pad_planes, params.pad_rows, params.pad_cols, + tensor_top_diff.flat<T>().data(), tensor_bottom_diff->flat<T>().data(), + context->eigen_gpu_device()); + if (!status) { + context->SetStatus( + errors::Internal("Failed launching MaxPool3dGradBackward")); + } + } +}; + +#define REGISTER_GPU_KERNELS(T) REGISTER_KERNELS(GPU, T) +TF_CALL_float(REGISTER_GPU_KERNELS) TF_CALL_half(REGISTER_GPU_KERNELS) +#undef REGISTER_GPU_KERNELS #endif // GOOGLE_CUDA +#undef REGISTER_KERNELS + } // namespace tensorflow |