path: root/tensorflow/core/kernels/maxpooling_op.cc
Diffstat (limited to 'tensorflow/core/kernels/maxpooling_op.cc')
-rw-r--r--  tensorflow/core/kernels/maxpooling_op.cc  491
1 file changed, 406 insertions(+), 85 deletions(-)
diff --git a/tensorflow/core/kernels/maxpooling_op.cc b/tensorflow/core/kernels/maxpooling_op.cc
index 41c6251ac7..eb590280c9 100644
--- a/tensorflow/core/kernels/maxpooling_op.cc
+++ b/tensorflow/core/kernels/maxpooling_op.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
@@ -46,6 +47,7 @@ limitations under the License.
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
const int kInvalidMaxPoolingIndex = -1;
@@ -187,40 +189,6 @@ static void SpatialMaxPoolWithArgMaxHelper(
params.tensor_in_batch, shard_cost, shard);
}
-REGISTER_KERNEL_BUILDER(
- Name("MaxPool").Device(DEVICE_CPU).TypeConstraint<float>("T"),
- MaxPoolingOp<CPUDevice, float>);
-REGISTER_KERNEL_BUILDER(
- Name("MaxPool").Device(DEVICE_CPU).TypeConstraint<Eigen::half>("T"),
- MaxPoolingOp<CPUDevice, Eigen::half>);
-
-#if GOOGLE_CUDA
-// Forward declarations for the functor specializations for GPU.
-namespace functor {
-#define DECLARE_GPU_SPEC(T) \
- template <> \
- void SpatialMaxPooling<Eigen::GpuDevice, T>::operator()( \
- const Eigen::GpuDevice& d, typename TTypes<T, 4>::Tensor output, \
- typename TTypes<T, 4>::ConstTensor input, int window_rows, \
- int window_cols, int row_stride, int col_stride, \
- const Eigen::PaddingType& padding); \
- extern template struct SpatialMaxPooling<Eigen::GpuDevice, T>;
-
-DECLARE_GPU_SPEC(float);
-#undef DECLARE_GPU_SPEC
-} // namespace functor
-
-// Note(jiayq): Currently, the Caffe custom implementation is faster than the
-// default Eigen implementation so we are using the custom kernel as the
-// default. However, you can explicitly invoke the eigen version using
-// kernel_label_map.
-REGISTER_KERNEL_BUILDER(Name("MaxPool")
- .Device(DEVICE_GPU)
- .TypeConstraint<float>("T")
- .Label("eigen_tensor"),
- MaxPoolingOp<Eigen::GpuDevice, float>);
-#endif // GOOGLE_CUDA
-
// The operation to compute MaxPool gradients.
// It takes three inputs:
// - The original input tensor
@@ -237,7 +205,7 @@ class MaxPoolingGradOp : public OpKernel {
errors::InvalidArgument("Invalid data format"));
OP_REQUIRES(
context, data_format_ == FORMAT_NHWC,
- errors::InvalidArgument("Default MaxPoolinGradOp only supports NHWC ",
+ errors::InvalidArgument("Default MaxPoolingGradOp only supports NHWC ",
"on device type ",
DeviceTypeString(context->device_type())));
OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
@@ -305,13 +273,6 @@ class MaxPoolingGradOp : public OpKernel {
TensorFormat data_format_;
};
-REGISTER_KERNEL_BUILDER(
- Name("MaxPoolGrad").Device(DEVICE_CPU).TypeConstraint<float>("T"),
- MaxPoolingGradOp<CPUDevice, float>);
-REGISTER_KERNEL_BUILDER(
- Name("MaxPoolGrad").Device(DEVICE_CPU).TypeConstraint<Eigen::half>("T"),
- MaxPoolingGradOp<CPUDevice, Eigen::half>);
-
#ifdef GOOGLE_CUDA
template <typename T>
@@ -329,13 +290,13 @@ static void MaxPoolingBackwardCustomKernel(
return;
}
- MaxPoolBackwardNoMask(
+ functor::MaxPoolBackwardNoMask<T>()(
tensor_in->flat<T>().data(), params.tensor_in_batch,
params.tensor_in_rows, params.tensor_in_cols, params.depth,
params.out_height, params.out_width, params.window_rows,
params.window_cols, params.row_stride, params.col_stride, params.pad_rows,
- params.pad_cols, out_backprop.flat<T>().data(),
- output->flat<T>().data(), context->eigen_device<Eigen::GpuDevice>());
+ params.pad_cols, out_backprop.flat<T>().data(), output->flat<T>().data(),
+ context->eigen_device<Eigen::GpuDevice>());
}
template <class T>
@@ -403,12 +364,252 @@ class MaxPoolingGradOp<Eigen::GpuDevice, T> : public OpKernel {
bool use_dnn_;
};
-REGISTER_KERNEL_BUILDER(
- Name("MaxPoolGrad").Device(DEVICE_GPU).TypeConstraint<float>("T"),
- MaxPoolingGradOp<Eigen::GpuDevice, float>);
-REGISTER_KERNEL_BUILDER(
- Name("MaxPoolGrad").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
- MaxPoolingGradOp<Eigen::GpuDevice, Eigen::half>);
+#endif // GOOGLE_CUDA
+
+// The operation to compute gradient of MaxPool gradients.
+// It takes three inputs:
+// - The original input tensor
+// - The original output tensor
+// - Backprop tensor for output gradients
+// It produces one output: backprop tensor for output gradient.
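+//
+// For each output position, the op forwards the value of out_grad_backprop
+// at the input position that attained the maximum, i.e. it routes the
+// incoming gradient through the same argmax selection that the forward
+// MaxPool used.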
+template <class Device, class T>
+class MaxPoolingGradGradOp : public OpKernel {
+ public:
+ explicit MaxPoolingGradGradOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ string data_format;
+ OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+ OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES(
+ context, data_format_ == FORMAT_NHWC,
+ errors::InvalidArgument(
+ "Default MaxPoolingGradGradOp only supports NHWC ",
+ "on device type ", DeviceTypeString(context->device_type())));
+ OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
+ OP_REQUIRES(context, ksize_.size() == 4,
+ errors::InvalidArgument("Sliding window ksize field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 4,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
+ errors::Unimplemented(
+ "Pooling is not yet supported on the batch dimension."));
+ OP_REQUIRES(
+ context, ksize_[3] == 1 && stride_[3] == 1,
+ errors::Unimplemented(
+ "MaxPoolingGradGrad is not yet supported on the depth dimension."));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& tensor_in = context->input(0);
+ const Tensor& tensor_out = context->input(1);
+ const Tensor& out_grad_backprop = context->input(2);
+
+ // For maxpooling, tensor_in should have 4 dimensions.
+ OP_REQUIRES(context, tensor_in.dims() == 4,
+ errors::InvalidArgument("tensor_in must be 4-dimensional"));
+ OP_REQUIRES(context, tensor_out.dims() == 4,
+ errors::InvalidArgument("tensor_out must be 4-dimensional"));
+ // For maxpooling, out_grad_backprop should have 4 dimensions.
+ OP_REQUIRES(
+ context, out_grad_backprop.dims() == 4,
+ errors::InvalidArgument("out_grad_backprop must be 4-dimensional"));
+
+ PoolParameters params{context, ksize_, stride_,
+ padding_, FORMAT_NHWC, tensor_in.shape()};
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
+ {2}, 0, tensor_out.shape(), &output));
+
+ SpatialMaxPoolGradGrad(context, output, tensor_in, tensor_out,
+ out_grad_backprop, params, padding_);
+ }
+
+ private:
+ void SpatialMaxPoolGradGrad(OpKernelContext* context, Tensor* bottom_diff,
+ const Tensor& tensor_in, const Tensor& tensor_out,
+ const Tensor& top_diff,
+ const PoolParameters& params,
+ const Padding& padding) {
+ typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
+ ConstEigenMatrixMap;
+ typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
+ EigenMatrixMap;
+
+ ConstEigenMatrixMap in_mat(
+ tensor_in.flat<T>().data(), params.depth,
+ params.tensor_in_cols * params.tensor_in_rows * params.tensor_in_batch);
+ ConstEigenMatrixMap out_mat(
+ tensor_out.flat<T>().data(), params.depth,
+ params.out_width * params.out_height * params.tensor_in_batch);
+ ConstEigenMatrixMap top_diff_mat(
+ top_diff.flat<T>().data(), params.depth,
+ params.tensor_in_cols * params.tensor_in_rows * params.tensor_in_batch);
+ EigenMatrixMap bottom_diff_mat(
+ bottom_diff->flat<T>().data(), params.depth,
+ params.out_width * params.out_height * params.tensor_in_batch);
+
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *(context->device()->tensorflow_cpu_worker_threads());
+
+ // The following code basically does the following:
+ // 1. Flattens the input, output, top_diff and bottom_diff tensors into
+ // two dimensional arrays.
+ // tensor_in_as_matrix:
+ // depth by (tensor_in_cols * tensor_in_rows * tensor_in_batch)
+ // tensor_out_as_matrix:
+ // depth by (out_width * out_height * tensor_in_batch)
+ // top_diff_as_matrix:
+ // depth by (tensor_in_cols * tensor_in_rows * tensor_in_batch)
+ // bottom_diff_as_matrix:
+ // depth by (out_width * out_height * tensor_in_batch)
+ //
+ // 2. Walks through the set of columns in the flattened
+ // tensor_in_as_matrix, tensor_out_as_matrix, top_diff_as_matrix
+ // and updates the column(s) corresponding to the maximum values in
+ // tensor_out_as_matrix with the corresponding values in
+ // top_diff_as_matrix.
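+    //
+    // Note that the roles of the shapes are transposed relative to the
+    // first-order gradient: top_diff is flattened to the original *input*
+    // geometry and bottom_diff to the original *output* geometry, because
+    // the second-order gradient flows forward through the argmax selection.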
+ auto shard = [&params, &in_mat, &out_mat, &top_diff_mat, &bottom_diff_mat](
+ int64 start, int64 limit) {
+ const int32 depth = params.depth;
+ const int32 in_rows = params.tensor_in_rows;
+ const int32 in_cols = params.tensor_in_cols;
+ const int32 pad_rows = params.pad_rows;
+ const int32 pad_cols = params.pad_cols;
+ const int32 window_rows = params.window_rows;
+ const int32 window_cols = params.window_cols;
+ const int32 row_stride = params.row_stride;
+ const int32 col_stride = params.col_stride;
+ const int32 out_height = params.out_height;
+ const int32 out_width = params.out_width;
+
+ {
+ // Initializes the output grad backprop tensor with 0.
+ const int32 output_image_size = out_height * out_width * params.depth;
+ EigenMatrixMap bottom_diff_shard(
+ bottom_diff_mat.data() + start * output_image_size, 1,
+ (limit - start) * output_image_size);
+ bottom_diff_shard.setZero();
+ }
+
+ for (int b = start; b < limit; ++b) {
+ for (int ph = 0; ph < out_height; ++ph) {
+ for (int pw = 0; pw < out_width; ++pw) {
+ // (h_start, h_end) * (w_start, w_end) is the range that the input
+ // vector projects to.
+ int h_start = ph * row_stride - pad_rows;
+ const int h_end = std::min(h_start + window_rows, in_rows);
+ int w_start = pw * col_stride - pad_cols;
+ const int w_end = std::min(w_start + window_cols, in_cols);
+ h_start = std::max(h_start, 0);
+ w_start = std::max(w_start, 0);
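+          // For example (hypothetical values): with window_rows = 3,
+          // row_stride = 2, pad_rows = 1 and in_rows = 5, ph = 0 gives
+          // h_start = max(0 * 2 - 1, 0) = 0 and h_end = min(-1 + 3, 5) = 2,
+          // so the window covers input rows [0, 2).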
+ const int out_index = (b * out_height + ph) * out_width + pw;
+ // Find value corresponding to the input maximum in top_diff.
+ for (int d = 0; d < depth; ++d) {
+ const T& output_ref = out_mat.coeffRef(d, out_index);
+ bool should_stop = false;
+ for (int h = h_start; h < h_end && !should_stop; ++h) {
+ for (int w = w_start; w < w_end && !should_stop; ++w) {
+ const int in_index = (b * in_rows + h) * in_cols + w;
+ const T& input_ref = in_mat.coeffRef(d, in_index);
+ if (output_ref == input_ref) {
+ T& bottom_diff_ref = bottom_diff_mat.coeffRef(d, out_index);
+ bottom_diff_ref = top_diff_mat.coeffRef(d, in_index);
+ should_stop = true;
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ };
+
+ const int64 shard_cost = params.out_width * params.out_height *
+ params.depth * params.window_rows *
+ params.window_cols;
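+    // The cost estimate reflects the work per batch image: each of the
+    // out_width * out_height * depth output cells scans at most
+    // window_rows * window_cols candidate inputs.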
+ Shard(worker_threads.num_threads, worker_threads.workers,
+ params.tensor_in_batch, shard_cost, shard);
+ }
+
+ std::vector<int32> ksize_;
+ std::vector<int32> stride_;
+ Padding padding_;
+ TensorFormat data_format_;
+};
+
+#ifdef GOOGLE_CUDA
+
+template <class T>
+class MaxPoolingGradGradOp<Eigen::GpuDevice, T> : public OpKernel {
+ public:
+ typedef Eigen::GpuDevice Device;
+
+ explicit MaxPoolingGradGradOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ string data_format;
+ OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+ OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
+ OP_REQUIRES(context, ksize_.size() == 4,
+ errors::InvalidArgument("Sliding window ksize field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 4,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ const int32 ksize_n = GetTensorDim(ksize_, data_format_, 'N');
+ const int32 stride_n = GetTensorDim(stride_, data_format_, 'N');
+ OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
+ errors::Unimplemented(
+ "Pooling is not yet supported on the batch dimension."));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& tensor_in = context->input(0);
+ const Tensor& tensor_out = context->input(1);
+ const Tensor& out_grad_backprop = context->input(2);
+
+ // For maxpooling, tensor_in should have 4 dimensions.
+ OP_REQUIRES(context, tensor_in.dims() == 4,
+                errors::InvalidArgument("tensor_in must be 4-dimensional"));
+ OP_REQUIRES(context, tensor_out.dims() == 4,
+ errors::InvalidArgument("tensor_out must be 4-dimensional"));
+ // For maxpooling, out_grad_backprop should have 4 dimensions.
+ OP_REQUIRES(
+ context, out_grad_backprop.dims() == 4,
+ errors::InvalidArgument("out_grad_backprop must be 4-dimensional"));
+
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
+ {2}, 0, tensor_out.shape(), &output));
+
+ PoolParameters params{context, ksize_, stride_,
+ padding_, data_format_, tensor_in.shape()};
+
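+    // Unlike the CPU specialization, which requires NHWC, data_format_ is
+    // forwarded to the kernel here, so the GPU path handles both NHWC and
+    // NCHW layouts.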
+ functor::MaxPoolGradBackwardNoMask<T>()(
+ data_format_, tensor_in.flat<T>().data(), tensor_out.flat<T>().data(),
+ params.tensor_in_batch, params.out_height, params.out_width,
+ params.depth, params.tensor_in_rows, params.tensor_in_cols,
+ params.window_rows, params.window_cols, params.row_stride,
+ params.col_stride, params.pad_rows, params.pad_cols,
+ out_grad_backprop.flat<T>().data(), output->flat<T>().data(),
+ context->eigen_device<Eigen::GpuDevice>());
+ }
+
+ private:
+ std::vector<int32> ksize_;
+ std::vector<int32> stride_;
+ Padding padding_;
+ TensorFormat data_format_;
+};
#endif // GOOGLE_CUDA
@@ -565,6 +766,56 @@ class MaxPoolingGradWithArgmaxOp : public OpKernel {
Padding padding_;
};
+template <typename Device, typename T>
+struct LaunchMaxPoolingGradGradWithArgmax;
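+// LaunchMaxPoolingGradGradWithArgmax is specialized per device; only a GPU
+// specialization is provided (inside the GOOGLE_CUDA block below), so
+// MaxPoolGradGradWithArgmax is registered for DEVICE_GPU only.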
+
+template <typename Device, typename T>
+class MaxPoolingGradGradWithArgmaxOp : public OpKernel {
+ public:
+ explicit MaxPoolingGradGradWithArgmaxOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
+ OP_REQUIRES(context, ksize_.size() == 4,
+ errors::InvalidArgument("Sliding window ksize field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 4,
+ errors::InvalidArgument("Sliding window stride field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
+ errors::Unimplemented(
+ "Pooling is not yet supported on the batch dimension."));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& tensor_in = context->input(0);
+ const Tensor& grad_in = context->input(1);
+ const Tensor& argmax = context->input(2);
+
+ PoolParameters params{context, ksize_, stride_,
+ padding_, FORMAT_NHWC, tensor_in.shape()};
+ if (!context->status().ok()) {
+ return;
+ }
+
+ TensorShape out_shape({params.tensor_in_batch, params.out_height,
+ params.out_width, params.depth});
+
+ Tensor* grad_out = nullptr;
+ OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
+ {1}, 0, out_shape, &grad_out));
+
+ LaunchMaxPoolingGradGradWithArgmax<Device, T>::launch(
+ context, params, grad_in, argmax, grad_out);
+ }
+
+ private:
+ std::vector<int32> ksize_;
+ std::vector<int32> stride_;
+ Padding padding_;
+};
+
#if GOOGLE_CUDA
template <typename T>
class MaxPoolingNoMaskOp<GPUDevice, T> : public OpKernel {
@@ -631,7 +882,7 @@ template <typename T>
struct LaunchMaxPoolingNoMask<Eigen::GpuDevice, T> {
static void launch(OpKernelContext* context, const PoolParameters& params,
const Tensor& input, Tensor* output) {
- bool status = MaxPoolForwardWithOptionalArgmax(
+ bool status = functor::MaxPoolForwardWithOptionalArgmax<T>()(
input.flat<T>().data(), params.tensor_in_batch, params.tensor_in_rows,
params.tensor_in_cols, params.depth, params.out_height,
params.out_width, params.window_rows, params.window_cols,
@@ -644,18 +895,11 @@ struct LaunchMaxPoolingNoMask<Eigen::GpuDevice, T> {
}
};
-REGISTER_KERNEL_BUILDER(
- Name("MaxPool").Device(DEVICE_GPU).TypeConstraint<float>("T"),
- MaxPoolingNoMaskOp<Eigen::GpuDevice, float>);
-REGISTER_KERNEL_BUILDER(
- Name("MaxPool").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
- MaxPoolingNoMaskOp<Eigen::GpuDevice, Eigen::half>);
-
template <typename T>
struct LaunchMaxPoolingWithArgmax<Eigen::GpuDevice, T> {
static void launch(OpKernelContext* context, const PoolParameters& params,
const Tensor& input, Tensor* output, Tensor* argmax) {
- bool status = MaxPoolForwardWithOptionalArgmax(
+ bool status = functor::MaxPoolForwardWithOptionalArgmax<T>()(
input.flat<T>().data(), params.tensor_in_batch, params.tensor_in_rows,
params.tensor_in_cols, params.depth, params.out_height,
params.out_width, params.window_rows, params.window_cols,
@@ -670,17 +914,6 @@ struct LaunchMaxPoolingWithArgmax<Eigen::GpuDevice, T> {
}
};
-REGISTER_KERNEL_BUILDER(Name("MaxPoolWithArgmax")
- .Device(DEVICE_GPU)
- .TypeConstraint<int64>("Targmax")
- .TypeConstraint<float>("T"),
- MaxPoolingWithArgmaxOp<Eigen::GpuDevice, float>);
-REGISTER_KERNEL_BUILDER(Name("MaxPoolWithArgmax")
- .Device(DEVICE_GPU)
- .TypeConstraint<int64>("Targmax")
- .TypeConstraint<Eigen::half>("T"),
- MaxPoolingWithArgmaxOp<Eigen::GpuDevice, Eigen::half>);
-
template <typename T>
struct LaunchMaxPoolingGradWithArgmax<Eigen::GpuDevice, T> {
static void launch(OpKernelContext* context, const PoolParameters& params,
@@ -693,30 +926,118 @@ struct LaunchMaxPoolingGradWithArgmax<Eigen::GpuDevice, T> {
const int top_offset = params.out_height * params.out_width * params.depth;
const int bottom_offset =
params.tensor_in_rows * params.tensor_in_cols * params.depth;
- bool status = MaxPoolBackwardWithArgmax(
+ bool status = functor::MaxPoolBackwardWithArgmax<T>()(
output_size, input_size, grad_in.flat<T>().data(),
reinterpret_cast<const int64*>(argmax.flat<int64>().data()), top_offset,
bottom_offset, grad_out->flat<T>().data(), context->eigen_gpu_device());
if (!status) {
context->SetStatus(
- errors::Internal("Failed launching MaxPoolForwardWithArgmax"));
+ errors::Internal("Failed launching MaxPoolBackwardWithArgmax"));
}
}
};
-REGISTER_KERNEL_BUILDER(
- Name("MaxPoolGradWithArgmax")
- .Device(DEVICE_GPU)
- .TypeConstraint<float>("T")
- .TypeConstraint<int64>("Targmax"),
- MaxPoolingGradWithArgmaxOp<Eigen::GpuDevice, float>);
-REGISTER_KERNEL_BUILDER(
- Name("MaxPoolGradWithArgmax")
- .Device(DEVICE_GPU)
- .TypeConstraint<Eigen::half>("T")
- .TypeConstraint<int64>("Targmax"),
- MaxPoolingGradWithArgmaxOp<Eigen::GpuDevice, Eigen::half>);
+template <typename T>
+struct LaunchMaxPoolingGradGradWithArgmax<Eigen::GpuDevice, T> {
+ static void launch(OpKernelContext* context, const PoolParameters& params,
+ const Tensor& grad_in, const Tensor& argmax,
+ Tensor* grad_out) {
+ const int input_size = params.tensor_in_batch * params.tensor_in_rows *
+ params.tensor_in_cols * params.depth;
+ const int output_size = params.tensor_in_batch * params.out_height *
+ params.out_width * params.depth;
+ const int top_offset =
+ params.tensor_in_rows * params.tensor_in_cols * params.depth;
+ const int bottom_offset =
+ params.out_width * params.out_height * params.depth;
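+    // Note the offsets are swapped relative to LaunchMaxPoolingGradWithArgmax:
+    // top_offset strides over input-sized images and bottom_offset over
+    // output-sized images, since the second-order gradient maps values from
+    // the input geometry back to the output geometry.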
+ bool status = functor::MaxPoolGradBackwardWithArgmax<T>()(
+ output_size, input_size, grad_in.flat<T>().data(),
+ reinterpret_cast<const int64*>(argmax.flat<int64>().data()), top_offset,
+ bottom_offset, grad_out->flat<T>().data(), context->eigen_gpu_device());
+ if (!status) {
+ context->SetStatus(
+ errors::Internal("Failed launching MaxPoolGradBackwardWithArgmax"));
+ }
+ }
+};
#endif // GOOGLE_CUDA
+#define REGISTER_MAX_POOL_KERNELS(D, T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("MaxPoolGrad").Device(DEVICE_##D).TypeConstraint<T>("T"), \
+ MaxPoolingGradOp<D##Device, T>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("MaxPoolGradGrad").Device(DEVICE_##D).TypeConstraint<T>("T"), \
+ MaxPoolingGradGradOp<D##Device, T>);
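+
+// For example, REGISTER_MAX_POOL_KERNELS(CPU, float) expands to
+//   REGISTER_KERNEL_BUILDER(
+//       Name("MaxPoolGrad").Device(DEVICE_CPU).TypeConstraint<float>("T"),
+//       MaxPoolingGradOp<CPUDevice, float>);
+//   REGISTER_KERNEL_BUILDER(
+//       Name("MaxPoolGradGrad").Device(DEVICE_CPU).TypeConstraint<float>("T"),
+//       MaxPoolingGradGradOp<CPUDevice, float>);
+// replacing the per-type registrations removed above.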
+
+// Below kernels implemented only for CPU device.
+#define REGISTER_CPU_ONLY_POOL_KERNELS(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("MaxPool").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+ MaxPoolingOp<CPUDevice, T>);
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_ONLY_POOL_KERNELS);
+#undef REGISTER_CPU_ONLY_POOL_KERNELS
+
+#define REGISTER_CPU_MAX_POOL_KERNELS(T) REGISTER_MAX_POOL_KERNELS(CPU, T);
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_MAX_POOL_KERNELS);
+#undef REGISTER_CPU_MAX_POOL_KERNELS
+
+#if GOOGLE_CUDA
+
+// Forward declarations for the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void SpatialMaxPooling<Eigen::GpuDevice, T>::operator()( \
+ const Eigen::GpuDevice& d, typename TTypes<T, 4>::Tensor output, \
+ typename TTypes<T, 4>::ConstTensor input, int window_rows, \
+ int window_cols, int row_stride, int col_stride, \
+ const Eigen::PaddingType& padding); \
+ extern template struct SpatialMaxPooling<Eigen::GpuDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
+#undef DECLARE_GPU_SPEC
+} // namespace functor
+
+#define REGISTER_GPU_MAX_POOL_KERNELS(T) REGISTER_MAX_POOL_KERNELS(GPU, T)
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_MAX_POOL_KERNELS);
+#undef REGISTER_GPU_MAX_POOL_KERNELS
+
+// Below kernels currently implemented only for GPU device.
+// Note(jiayq): Currently, the Caffe custom implementation is faster than the
+// default Eigen implementation so we are using the custom kernel as the
+// default. However, you can explicitly invoke the eigen version using
+// kernel_label_map.
+#define REGISTER_GPU_ONLY_POOL_KERNELS(T) \
+ REGISTER_KERNEL_BUILDER(Name("MaxPool") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<T>("T") \
+ .Label("eigen_tensor"), \
+ MaxPoolingOp<GPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("MaxPool").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
+ MaxPoolingNoMaskOp<GPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("MaxPoolWithArgmax") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<int64>("Targmax") \
+ .TypeConstraint<T>("T"), \
+ MaxPoolingWithArgmaxOp<GPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("MaxPoolGradWithArgmax") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<int64>("Targmax"), \
+ MaxPoolingGradWithArgmaxOp<GPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("MaxPoolGradGradWithArgmax") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<int64>("Targmax"), \
+ MaxPoolingGradGradWithArgmaxOp<GPUDevice, T>);
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_ONLY_POOL_KERNELS);
+#undef REGISTER_GPU_ONLY_POOL_KERNELS
+
+#endif // GOOGLE_CUDA
+
+#undef REGISTER_MAX_POOL_KERNELS
+
} // namespace tensorflow