aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/pooling_ops_3d.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/pooling_ops_3d.cc')
-rw-r--r--tensorflow/core/kernels/pooling_ops_3d.cc314
1 files changed, 283 insertions, 31 deletions
diff --git a/tensorflow/core/kernels/pooling_ops_3d.cc b/tensorflow/core/kernels/pooling_ops_3d.cc
index f12c18eaa8..538dca24ae 100644
--- a/tensorflow/core/kernels/pooling_ops_3d.cc
+++ b/tensorflow/core/kernels/pooling_ops_3d.cc
@@ -14,12 +14,15 @@ limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
+#include "tensorflow/core/kernels/pooling_ops_3d.h"
+
#include <array>
#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
@@ -28,15 +31,64 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
+#include "tensorflow/core/util/work_sharder.h"
#if GOOGLE_CUDA
#include "tensorflow/core/kernels/cudnn_pooling_gpu.h"
+#include "tensorflow/core/kernels/pooling_ops_3d_gpu.h"
#endif
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+// Derives the geometry of a 3D pooling operation from the kernel attributes
+// and the shape of the input tensor.
+//
+// Reads the batch, depth and the three spatial extents from
+// `tensor_in_shape`, the window sizes from `ksize` and the strides from
+// `stride` (all looked up according to `data_format`), then computes the
+// output extent and padding of each spatial dimension with
+// GetWindowedOutputSize().
+//
+// Errors are reported through `context`. OP_REQUIRES / OP_REQUIRES_OK return
+// from the constructor as soon as a check fails, so the fields that follow
+// the failing check are not assigned; callers should inspect
+// context->status() before using the object.
+Pool3dParameters::Pool3dParameters(OpKernelContext* context,
+                                   const std::vector<int32>& ksize,
+                                   const std::vector<int32>& stride,
+                                   Padding padding, TensorFormat data_format,
+                                   const TensorShape& tensor_in_shape) {
+  // For 3D pooling, tensor_in must have 5 dimensions.
+  OP_REQUIRES(context, tensor_in_shape.dims() == 5,
+              errors::InvalidArgument("tensor_in must be 5-dimensional"));
+
+  this->data_format = data_format;
+  depth = GetTensorDim(tensor_in_shape, data_format, 'C');
+  tensor_in_planes = GetTensorDim(tensor_in_shape, data_format, '0');
+  tensor_in_rows = GetTensorDim(tensor_in_shape, data_format, '1');
+  tensor_in_cols = GetTensorDim(tensor_in_shape, data_format, '2');
+  tensor_in_batch = GetTensorDim(tensor_in_shape, data_format, 'N');
+  window_planes = GetTensorDim(ksize, data_format, '0');
+  window_rows = GetTensorDim(ksize, data_format, '1');
+  window_cols = GetTensorDim(ksize, data_format, '2');
+  depth_window = GetTensorDim(ksize, data_format, 'C');
+  plane_stride = GetTensorDim(stride, data_format, '0');
+  row_stride = GetTensorDim(stride, data_format, '1');
+  col_stride = GetTensorDim(stride, data_format, '2');
+  depth_stride = GetTensorDim(stride, data_format, 'C');
+
+  // We only support 3D pooling across plane/width/height. Depthwise
+  // pooling is not supported.
+  OP_REQUIRES(
+      context, depth_window == 1 && depth_stride == 1,
+      errors::Unimplemented(
+          "Pooling3d only supports pooling across plane/width/height."));
+
+  OP_REQUIRES_OK(context, GetWindowedOutputSize(tensor_in_planes, window_planes,
+                                                plane_stride, padding,
+                                                &out_plane, &pad_planes));
+  OP_REQUIRES_OK(context,
+                 GetWindowedOutputSize(tensor_in_rows, window_rows, row_stride,
+                                       padding, &out_height, &pad_rows));
+  OP_REQUIRES_OK(context,
+                 GetWindowedOutputSize(tensor_in_cols, window_cols, col_stride,
+                                       padding, &out_width, &pad_cols));
+}
+
+// Returns the shape of the forward pooling output, assembled in `data_format`
+// layout from the batch size, the three pooled spatial extents and the depth.
+TensorShape Pool3dParameters::forward_output_shape() {
+  TensorShape pooled_shape =
+      ShapeFromFormat(data_format, tensor_in_batch,
+                      {{out_plane, out_height, out_width}}, depth);
+  return pooled_shape;
+}
+
enum PoolingType { MAX, AVG };
template <typename Device, typename T, PoolingType Type>
@@ -147,12 +199,6 @@ class Pooling3DOp : public UnaryOp<T> {
Padding padding_;
TensorFormat data_format_;
};
-REGISTER_KERNEL_BUILDER(
- Name("AvgPool3D").Device(DEVICE_CPU).TypeConstraint<float>("T"),
- Pooling3DOp<CPUDevice, float, AVG>);
-REGISTER_KERNEL_BUILDER(
- Name("MaxPool3D").Device(DEVICE_CPU).TypeConstraint<float>("T"),
- Pooling3DOp<CPUDevice, float, MAX>);
template <typename Device, typename T>
struct LaunchMaxPooling3dGradOp;
@@ -331,10 +377,6 @@ class MaxPooling3dGradOp : public OpKernel {
TensorFormat data_format_;
};
-REGISTER_KERNEL_BUILDER(
- Name("MaxPool3DGrad").Device(DEVICE_CPU).TypeConstraint<float>("T"),
- MaxPooling3dGradOp<CPUDevice, float>);
-
template <typename Device, typename T>
struct LaunchAvgPooling3dGradOp;
@@ -499,11 +541,208 @@ class AvgPooling3dGradOp : public OpKernel {
TensorFormat data_format_;
};
-REGISTER_KERNEL_BUILDER(Name("AvgPool3DGrad")
- .Device(DEVICE_CPU)
- .TypeConstraint<float>("T")
- .HostMemory("orig_input_shape"),
- AvgPooling3dGradOp<CPUDevice, float>);
+template <typename Device, typename T>
+struct LaunchMaxPooling3dGradGradOp;
+
+// CPU implementation of the second-order gradient of 3D max pooling.
+//
+// For every output position and channel, the pooling window of `tensor_in`
+// is scanned for the first element equal to the pooled value stored in
+// `tensor_out`; the element of `tensor_top_diff` at that input position is
+// copied to the same output position of `tensor_bottom_diff`. Output
+// positions for which no input element compares equal keep the value 0.
+//
+// `tensor_in` and `tensor_top_diff` are indexed with the input geometry in
+// `params`, `tensor_out` and `tensor_bottom_diff` with the output geometry.
+// The work is sharded over the batch dimension.
+template <typename T>
+struct LaunchMaxPooling3dGradGradOp<CPUDevice, T> {
+  static void launch(OpKernelContext* context, const Pool3dParameters& params,
+                     const Tensor& tensor_in, const Tensor& tensor_out,
+                     const Tensor& tensor_top_diff,
+                     Tensor* tensor_bottom_diff) {
+    // The error arguments are concatenated without a separator, so the first
+    // fragment must end with a space.
+    OP_REQUIRES(
+        context, params.data_format == FORMAT_NHWC,
+        errors::InvalidArgument("Default MaxPooling3dGradGradOp only supports ",
+                                "NDHWC on CPU device type"));
+
+    typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
+        ConstEigenMatrixMap;
+    typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
+        EigenMatrixMap;
+
+    // Each tensor is viewed as a (depth x positions) matrix, so that one
+    // column holds all channels of a single spatial position.
+    ConstEigenMatrixMap in_mat(tensor_in.flat<T>().data(), params.depth,
+                               params.tensor_in_planes * params.tensor_in_cols *
+                                   params.tensor_in_rows *
+                                   params.tensor_in_batch);
+    ConstEigenMatrixMap out_mat(tensor_out.flat<T>().data(), params.depth,
+                                params.out_plane * params.out_width *
+                                    params.out_height * params.tensor_in_batch);
+    ConstEigenMatrixMap top_diff_mat(
+        tensor_top_diff.flat<T>().data(), params.depth,
+        params.tensor_in_planes * params.tensor_in_cols *
+            params.tensor_in_rows * params.tensor_in_batch);
+    EigenMatrixMap bottom_diff_mat(
+        tensor_bottom_diff->flat<T>().data(), params.depth,
+        params.out_plane * params.out_width * params.out_height *
+            params.tensor_in_batch);
+
+    const DeviceBase::CpuWorkerThreads& worker_threads =
+        *(context->device()->tensorflow_cpu_worker_threads());
+
+    // Processes the batch entries in [start, limit).
+    auto shard = [&params, &in_mat, &out_mat, &top_diff_mat, &bottom_diff_mat](
+                     int64 start, int64 limit) {
+      const int32 depth = params.depth;
+      const int32 in_planes = params.tensor_in_planes;
+      const int32 in_rows = params.tensor_in_rows;
+      const int32 in_cols = params.tensor_in_cols;
+      const int32 pad_planes = params.pad_planes;
+      const int32 pad_rows = params.pad_rows;
+      const int32 pad_cols = params.pad_cols;
+      const int32 window_planes = params.window_planes;
+      const int32 window_rows = params.window_rows;
+      const int32 window_cols = params.window_cols;
+      const int32 plane_stride = params.plane_stride;
+      const int32 row_stride = params.row_stride;
+      const int32 col_stride = params.col_stride;
+      const int32 out_plane = params.out_plane;
+      const int32 out_height = params.out_height;
+      const int32 out_width = params.out_width;
+
+      {
+        // Initializes the output grad backprop tensor with 0. The element
+        // count is computed in 64 bits so that it cannot overflow for large
+        // outputs.
+        const int64 output_image_size =
+            static_cast<int64>(out_plane) * out_height * out_width * depth;
+        EigenMatrixMap bottom_diff_shard(
+            bottom_diff_mat.data() + start * output_image_size, 1,
+            (limit - start) * output_image_size);
+        bottom_diff_shard.setZero();
+      }
+
+      // `b` stays 64-bit: it avoids narrowing the shard bounds and makes the
+      // flat index computations below 64-bit as well.
+      for (int64 b = start; b < limit; ++b) {
+        for (int pp = 0; pp < out_plane; ++pp) {
+          for (int ph = 0; ph < out_height; ++ph) {
+            for (int pw = 0; pw < out_width; ++pw) {
+              // (p_start, p_end) * (h_start, h_end) * (w_start, w_end) is the
+              // range that the input vector projects to.
+              int p_start = pp * plane_stride - pad_planes;
+              const int p_end = std::min(p_start + window_planes, in_planes);
+              int h_start = ph * row_stride - pad_rows;
+              const int h_end = std::min(h_start + window_rows, in_rows);
+              int w_start = pw * col_stride - pad_cols;
+              const int w_end = std::min(w_start + window_cols, in_cols);
+              // The window end is computed from the unclamped start; only
+              // afterwards is the start clamped into the input.
+              p_start = std::max(p_start, 0);
+              h_start = std::max(h_start, 0);
+              w_start = std::max(w_start, 0);
+              const int64 out_index =
+                  ((b * out_plane + pp) * out_height + ph) * out_width + pw;
+              // Find value corresponding to the input maximum in top_diff.
+              for (int d = 0; d < depth; ++d) {
+                const T& output_ref = out_mat.coeffRef(d, out_index);
+                // Only the first matching element of the window contributes.
+                bool should_stop = false;
+                for (int p = p_start; p < p_end && !should_stop; ++p) {
+                  for (int h = h_start; h < h_end && !should_stop; ++h) {
+                    for (int w = w_start; w < w_end && !should_stop; ++w) {
+                      const int64 in_index =
+                          ((b * in_planes + p) * in_rows + h) * in_cols + w;
+                      const T& input_ref = in_mat.coeffRef(d, in_index);
+                      if (output_ref == input_ref) {
+                        T& bottom_diff_ref =
+                            bottom_diff_mat.coeffRef(d, out_index);
+                        bottom_diff_ref = top_diff_mat.coeffRef(d, in_index);
+                        should_stop = true;
+                      }
+                    }
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+    };
+    // Estimated cost of one batch entry: every output element scans one
+    // window. Widened to 64 bits before multiplying to avoid overflow.
+    const int64 shard_cost = static_cast<int64>(params.out_plane) *
+                             params.out_height * params.out_width *
+                             params.depth * params.window_planes *
+                             params.window_rows * params.window_cols;
+    Shard(worker_threads.num_threads, worker_threads.workers,
+          params.tensor_in_batch, shard_cost, shard);
+  }
+};
+
+// Kernel computing the second-order gradient of 3D max pooling.
+//
+// Inputs: (0) the original pooling input, (1) the original pooling output,
+// (2) the gradient to be propagated, which is indexed with the geometry of
+// the original input. The single output has the shape of input (1).
+template <class Device, class T>
+class MaxPooling3dGradGradOp : public OpKernel {
+ public:
+  // Reads and validates the "data_format", "ksize", "strides" and "padding"
+  // attributes. Pooling over the batch or depth dimension is rejected.
+  explicit MaxPooling3dGradGradOp(OpKernelConstruction* context)
+      : OpKernel(context) {
+    string data_format;
+    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+                errors::InvalidArgument("Invalid data format"));
+    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
+    OP_REQUIRES(context, ksize_.size() == 5,
+                errors::InvalidArgument("Sliding window ksize field must "
+                                        "specify 5 dimensions"));
+    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+    OP_REQUIRES(context, stride_.size() == 5,
+                errors::InvalidArgument("Sliding window strides field must "
+                                        "specify 5 dimensions"));
+    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+    OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
+                errors::Unimplemented(
+                    "Pooling is not yet supported on the batch dimension."));
+    const int32 ksize_c = GetTensorDim(ksize_, data_format_, 'C');
+    const int32 stride_c = GetTensorDim(stride_, data_format_, 'C');
+    OP_REQUIRES(context, ksize_c == 1 && stride_c == 1,
+                errors::Unimplemented("MaxPooling3dGradGrad is not yet "
+                                      "supported on the depth dimension."));
+  }
+
+  // Validates the three inputs, allocates the output and dispatches to the
+  // device-specific launcher.
+  void Compute(OpKernelContext* context) override {
+    const Tensor& tensor_in = context->input(0);
+    const Tensor& tensor_out = context->input(1);
+    const Tensor& out_grad_backprop = context->input(2);
+
+    // For maxpooling3d, tensor_in should have 5 dimensions.
+    OP_REQUIRES(context, tensor_in.dims() == 5,
+                errors::InvalidArgument("tensor_in must be 5-dimensional"));
+    OP_REQUIRES(context, tensor_out.dims() == 5,
+                errors::InvalidArgument("tensor_out must be 5-dimensional"));
+    // For maxpooling3d, out_grad_backprop should have 5 dimensions.
+    OP_REQUIRES(
+        context, out_grad_backprop.dims() == 5,
+        errors::InvalidArgument("out_grad_backprop must be 5-dimensional"));
+
+    Pool3dParameters params{context,  ksize_,       stride_,
+                            padding_, data_format_, tensor_in.shape()};
+    // The Pool3dParameters constructor reports failures through `context`
+    // and returns early, leaving `params` only partially assigned.
+    if (!context->status().ok()) return;
+
+    // The launchers index tensor_out with the output geometry and
+    // out_grad_backprop with the input geometry derived from tensor_in.
+    // Reject mismatching shapes instead of reading out of bounds.
+    const TensorShape expected_out_shape = params.forward_output_shape();
+    OP_REQUIRES(context, tensor_out.shape() == expected_out_shape,
+                errors::InvalidArgument(
+                    "Expected orig_output shape to be ",
+                    expected_out_shape.DebugString(), ", but got ",
+                    tensor_out.shape().DebugString()));
+    OP_REQUIRES(context, out_grad_backprop.shape() == tensor_in.shape(),
+                errors::InvalidArgument(
+                    "Expected grad shape to be ",
+                    tensor_in.shape().DebugString(), ", but got ",
+                    out_grad_backprop.shape().DebugString()));
+
+    // The output gets its own buffer. Reusing the buffer of input 2 would
+    // alias the output with out_grad_backprop whenever the element counts
+    // match, and the CPU launcher zeroes the output before it reads the
+    // gradient values.
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, tensor_out.shape(), &output));
+
+    LaunchMaxPooling3dGradGradOp<Device, T>::launch(
+        context, params, tensor_in, tensor_out, out_grad_backprop, output);
+  }
+
+ private:
+  std::vector<int32> ksize_;
+  std::vector<int32> stride_;
+  Padding padding_;
+  TensorFormat data_format_;
+};
+
+// Registers the MaxPool3D, MaxPool3DGrad, MaxPool3DGradGrad, AvgPool3D and
+// AvgPool3DGrad kernels for device D (pasted into DEVICE_##D and D##Device)
+// and element type T.
+#define REGISTER_KERNELS(D, T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("MaxPool3D").Device(DEVICE_##D).TypeConstraint<T>("T"), \
+ Pooling3DOp<D##Device, T, MAX>); \
+ REGISTER_KERNEL_BUILDER(Name("MaxPool3DGrad") \
+ .Device(DEVICE_##D) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<T>("TInput"), \
+ MaxPooling3dGradOp<D##Device, T>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("MaxPool3DGradGrad").Device(DEVICE_##D).TypeConstraint<T>("T"), \
+ MaxPooling3dGradGradOp<D##Device, T>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("AvgPool3D").Device(DEVICE_##D).TypeConstraint<T>("T"), \
+ Pooling3DOp<D##Device, T, AVG>); \
+ REGISTER_KERNEL_BUILDER(Name("AvgPool3DGrad") \
+ .Device(DEVICE_##D) \
+ .TypeConstraint<T>("T") \
+ .HostMemory("orig_input_shape"), \
+ AvgPooling3dGradOp<D##Device, T>);
+
+// CPU kernels are registered for float only.
+#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T)
+TF_CALL_float(REGISTER_CPU_KERNELS);
+#undef REGISTER_CPU_KERNELS
#if GOOGLE_CUDA
@@ -535,13 +774,6 @@ struct LaunchPoolingOp<GPUDevice, T, MAX> {
}
};
-REGISTER_KERNEL_BUILDER(
- Name("AvgPool3D").Device(DEVICE_GPU).TypeConstraint<float>("T"),
- Pooling3DOp<GPUDevice, float, AVG>);
-REGISTER_KERNEL_BUILDER(
- Name("MaxPool3D").Device(DEVICE_GPU).TypeConstraint<float>("T"),
- Pooling3DOp<GPUDevice, float, MAX>);
-
template <typename T>
struct LaunchMaxPooling3dGradOp<GPUDevice, T> {
static void launch(OpKernelContext* context, const Tensor& tensor_in,
@@ -559,10 +791,6 @@ struct LaunchMaxPooling3dGradOp<GPUDevice, T> {
}
};
-REGISTER_KERNEL_BUILDER(
- Name("MaxPool3DGrad").Device(DEVICE_GPU).TypeConstraint<float>("T"),
- MaxPooling3dGradOp<GPUDevice, float>);
-
template <typename T>
struct LaunchAvgPooling3dGradOp<GPUDevice, T> {
static void launch(OpKernelContext* context,
@@ -579,12 +807,36 @@ struct LaunchAvgPooling3dGradOp<GPUDevice, T> {
nullptr, nullptr, output);
}
};
-REGISTER_KERNEL_BUILDER(Name("AvgPool3DGrad")
- .Device(DEVICE_GPU)
- .TypeConstraint<float>("T")
- .HostMemory("orig_input_shape"),
- AvgPooling3dGradOp<GPUDevice, float>);
+
+// GPU implementation of the second-order gradient of 3D max pooling. Hands
+// the pooling geometry and the raw tensor buffers to
+// functor::MaxPool3dGradBackward and records an Internal error on `context`
+// when the functor reports failure.
+template <typename T>
+struct LaunchMaxPooling3dGradGradOp<GPUDevice, T> {
+  static void launch(OpKernelContext* context, const Pool3dParameters& params,
+                     const Tensor& tensor_in, const Tensor& tensor_out,
+                     const Tensor& tensor_top_diff,
+                     Tensor* tensor_bottom_diff) {
+    // Raw buffers handed to the functor.
+    const T* const input_data = tensor_in.flat<T>().data();
+    const T* const output_data = tensor_out.flat<T>().data();
+    const T* const top_diff_data = tensor_top_diff.flat<T>().data();
+    T* const bottom_diff_data = tensor_bottom_diff->flat<T>().data();
+
+    const bool launched = functor::MaxPool3dGradBackward<T>()(
+        params.data_format, input_data, output_data, params.tensor_in_batch,
+        params.out_plane, params.out_height, params.out_width, params.depth,
+        params.tensor_in_planes, params.tensor_in_rows, params.tensor_in_cols,
+        params.window_planes, params.window_rows, params.window_cols,
+        params.plane_stride, params.row_stride, params.col_stride,
+        params.pad_planes, params.pad_rows, params.pad_cols, top_diff_data,
+        bottom_diff_data, context->eigen_gpu_device());
+    if (launched) return;
+
+    context->SetStatus(
+        errors::Internal("Failed launching MaxPool3dGradBackward"));
+  }
+};
+
+// GPU kernels are registered for float and half.
+#define REGISTER_GPU_KERNELS(T) REGISTER_KERNELS(GPU, T)
+TF_CALL_float(REGISTER_GPU_KERNELS) TF_CALL_half(REGISTER_GPU_KERNELS)
+#undef REGISTER_GPU_KERNELS
#endif // GOOGLE_CUDA
+#undef REGISTER_KERNELS
+
} // namespace tensorflow