path: root/tensorflow/core/kernels/maxpooling_op_gpu.h
Diffstat (limited to 'tensorflow/core/kernels/maxpooling_op_gpu.h')
-rw-r--r--  tensorflow/core/kernels/maxpooling_op_gpu.h  92
1 file changed, 50 insertions, 42 deletions
diff --git a/tensorflow/core/kernels/maxpooling_op_gpu.h b/tensorflow/core/kernels/maxpooling_op_gpu.h
index d1c73a372e..d2029f5719 100644
--- a/tensorflow/core/kernels/maxpooling_op_gpu.h
+++ b/tensorflow/core/kernels/maxpooling_op_gpu.h
@@ -24,54 +24,62 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
+namespace functor {
// Run the forward pass of max pooling. The argmax indices are written to the
// mask array only when mask is not nullptr; a nullptr mask skips the argmax
// output.
-bool MaxPoolForwardWithOptionalArgmax(
- const float* bottom_data, const int batch, const int height,
- const int width, const int channels, const int pooled_height,
- const int pooled_width, const int kernel_h, const int kernel_w,
- const int stride_h, const int stride_w, const int pad_t, const int pad_l,
- float* top_data, int64* mask, const Eigen::GpuDevice& d);
-
-bool MaxPoolForwardWithOptionalArgmax(
- const Eigen::half* bottom_data, const int batch, const int height,
- const int width, const int channels, const int pooled_height,
- const int pooled_width, const int kernel_h, const int kernel_w,
- const int stride_h, const int stride_w, const int pad_t, const int pad_l,
- Eigen::half* top_data, int64* mask, const Eigen::GpuDevice& d);
-
-bool MaxPoolBackwardWithArgmax(const int output_size, const int input_size,
- const float* top_diff, const int64* mask,
- const int top_offset, const int bottom_offset,
- float* bottom_diff, const Eigen::GpuDevice& d);
-
-bool MaxPoolBackwardWithArgmax(const int output_size, const int input_size,
- const Eigen::half* top_diff, const int64* mask,
- const int top_offset, const int bottom_offset,
- Eigen::half* bottom_diff,
- const Eigen::GpuDevice& d);
-
-bool MaxPoolBackwardNoMask(const float* bottom_data, const int batch,
- const int height, const int width,
- const int channels, const int pooled_height,
- const int pooled_width, const int kernel_h,
- const int kernel_w, const int stride_h,
- const int stride_w, const int pad_t, const int pad_l,
- const float* top_diff, float* bottom_diff,
- const Eigen::GpuDevice& d);
-
-bool MaxPoolBackwardNoMask(const Eigen::half* bottom_data, const int batch,
- const int height, const int width,
- const int channels, const int pooled_height,
- const int pooled_width, const int kernel_h,
- const int kernel_w, const int stride_h,
- const int stride_w, const int pad_t, const int pad_l,
- const Eigen::half* top_diff, Eigen::half* bottom_diff,
- const Eigen::GpuDevice& d);
+template <typename T>
+struct MaxPoolForwardWithOptionalArgmax {
+ bool operator()(const T* bottom_data, const int batch, const int height,
+ const int width, const int channels, const int pooled_height,
+ const int pooled_width, const int kernel_h,
+ const int kernel_w, const int stride_h, const int stride_w,
+ const int pad_t, const int pad_l, T* top_data, int64* mask,
+ const Eigen::GpuDevice& d);
+};
+
+template <typename T>
+struct MaxPoolBackwardWithArgmax {
+ bool operator()(const int output_size, const int input_size,
+ const T* top_diff, const int64* mask, const int top_offset,
+ const int bottom_offset, T* bottom_diff,
+ const Eigen::GpuDevice& d);
+};
+
+template <typename T>
+struct MaxPoolBackwardNoMask {
+ bool operator()(const T* bottom_data, const int batch, const int height,
+ const int width, const int channels, const int pooled_height,
+ const int pooled_width, const int kernel_h,
+ const int kernel_w, const int stride_h, const int stride_w,
+ const int pad_t, const int pad_l, const T* top_diff,
+ T* bottom_diff, const Eigen::GpuDevice& d);
+};
+
+template <typename T>
+struct MaxPoolGradBackwardWithArgmax {
+ bool operator()(const int output_size, const int input_size,
+ const T* top_diff, const int64* mask, const int top_offset,
+ const int bottom_offset, T* bottom_diff,
+ const Eigen::GpuDevice& d);
+};
+
+template <typename T>
+struct MaxPoolGradBackwardNoMask {
+ bool operator()(TensorFormat data_format, const T* bottom_data,
+ const T* output_data, const int batch,
+ const int pooled_height, const int pooled_width,
+ const int channels, const int height, const int width,
+ const int kernel_h, const int kernel_w, const int stride_h,
+ const int stride_w, const int pad_t, const int pad_l,
+ const T* top_diff, T* bottom_diff, const Eigen::GpuDevice& d);
+};
+
+} // namespace functor
} // namespace tensorflow
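
For context, a caller of the new interface instantiates the functor on the element type and invokes its operator(). The sketch below is a hypothetical illustration of that call pattern on the GPU device; it is not code from this change. The wrapper name LaunchMaxPoolForward, its explicit geometry parameters, and the surrounding includes are assumptions chosen only to make the example self-contained.

#define EIGEN_USE_GPU  // so Eigen::GpuDevice is fully defined
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/maxpooling_op_gpu.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

// Hypothetical call site for the templated functor interface declared above.
// LaunchMaxPoolForward and its parameter list are illustrative only.
template <typename T>
void LaunchMaxPoolForward(OpKernelContext* context, const Tensor& input,
                          Tensor* output, int64* mask,  // mask may be nullptr
                          int batch, int height, int width, int channels,
                          int pooled_height, int pooled_width, int kernel_h,
                          int kernel_w, int stride_h, int stride_w, int pad_t,
                          int pad_l) {
  const Eigen::GpuDevice& d = context->eigen_device<Eigen::GpuDevice>();
  // A single templated functor covers both float and Eigen::half, replacing
  // the per-type overload pairs removed by this change.
  const bool status = functor::MaxPoolForwardWithOptionalArgmax<T>()(
      input.flat<T>().data(), batch, height, width, channels, pooled_height,
      pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l,
      output->flat<T>().data(), mask, d);
  if (!status) {
    context->SetStatus(
        errors::Internal("MaxPoolForwardWithOptionalArgmax launch failed"));
  }
}

}  // namespace tensorflow

Templating on T is what lets one declaration per functor stand in for the separate float and Eigen::half prototypes, and placing the declarations in namespace functor appears to follow the convention used by other GPU functors in tensorflow/core/kernels.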