diff options
author | 2017-09-04 03:19:53 -0700 | |
---|---|---|
committer | 2017-09-04 03:24:08 -0700 | |
commit | 07356b48e4b374efd406fd142faa77cfa4db05e9 (patch) | |
tree | f5049f7ef36486535e386934f3dfc48f72831f45 /tensorflow/core/kernels/conv_ops.h | |
parent | 0302320e11c7561cafac1cc279fea87de02b0cf9 (diff) |
Expose the launchpad for conv2d backprop, and unify the launchpads for conv2d and depthwise_conv to match the example in the documentation (see ./extend/adding_an_op.md).
PiperOrigin-RevId: 167480081
Diffstat (limited to 'tensorflow/core/kernels/conv_ops.h')
-rw-r--r-- | tensorflow/core/kernels/conv_ops.h | 32 |
1 files changed, 15 insertions, 17 deletions
diff --git a/tensorflow/core/kernels/conv_ops.h b/tensorflow/core/kernels/conv_ops.h index 60091fc27f..e29271dff2 100644 --- a/tensorflow/core/kernels/conv_ops.h +++ b/tensorflow/core/kernels/conv_ops.h @@ -32,14 +32,23 @@ namespace tensorflow { class OpKernelContext; template <typename Device, typename T> -class LaunchConv2DOp { - public: - void launch(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, - const Tensor& input, const Tensor& filter, int row_stride, - int col_stride, const Eigen::PaddingType& padding, Tensor* output, - TensorFormat data_format); +struct LaunchConv2DOp { + void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, + const Tensor& input, const Tensor& filter, int row_stride, + int col_stride, const Padding& padding, Tensor* output, + TensorFormat data_format); }; +#ifdef GOOGLE_CUDA +template <typename T> +struct LaunchConv2DOp<Eigen::GpuDevice, T> { + void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, + const Tensor& input, const Tensor& filter, int row_stride, + int col_stride, const Padding& padding, Tensor* output, + TensorFormat data_format); +}; +#endif // GOOGLE_CUDA + // Used to keep track of persistent memory buffers used within the op. // It uses malloc and free to avoid the time cost of initializing the memory. template <class T, size_t size> @@ -55,17 +64,6 @@ struct Im2ColBufferResource : public ResourceBase { string DebugString() { return "Im2ColBufferResource"; } }; -#ifdef GOOGLE_CUDA -template <typename T> -class LaunchConv2DOp<Eigen::GpuDevice, T> { - public: - void launch(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, - const Tensor& input, const Tensor& filter, int row_stride, - int col_stride, const Eigen::PaddingType& padding, Tensor* output, - TensorFormat data_format); -}; -#endif // GOOGLE_CUDA - } // namespace tensorflow #endif // TENSORFLOW_KERNELS_CONV_OPS_H |