diff options
author | 2018-02-10 12:45:12 -0800 | |
---|---|---|
committer | 2018-02-10 12:49:22 -0800 | |
commit | 45fae93d626e41c17fc988b88de0e2721771d222 (patch) | |
tree | eb885f2de3a7ddd30f9813d42fe4d8ec3bb85d28 /tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc | |
parent | f669885e1c135b1e5ad7b2e936083860a84b0aea (diff) |
Getting rid of the unnecessary GPUDevice typedef.
Passing DepthwiseArgs by const reference in host code.
PiperOrigin-RevId: 185263307
Diffstat (limited to 'tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc')
-rw-r--r-- | tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc | 70 |
1 files changed, 36 insertions, 34 deletions
diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc index 126b64f73d..1e9345828a 100644 --- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc +++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc @@ -34,13 +34,12 @@ limitations under the License. namespace tensorflow { -typedef Eigen::GpuDevice GPUDevice; using Eigen::GpuDevice; // Returns whether depthwise convolution forward or backward input pass can be // performed using the faster ('Small') variant of the kernel. EIGEN_DEVICE_FUNC bool CanLaunchDepthwiseConv2dGPUSmall( - const DepthwiseArgs args) { + const DepthwiseArgs& args) { return args.depth_multiplier == 1 && args.stride == 1 && args.in_rows <= 32 && args.in_cols <= 32 && args.in_rows == args.out_rows && args.in_cols == args.out_cols && args.pad_rows >= 0 && @@ -53,7 +52,7 @@ EIGEN_DEVICE_FUNC bool CanLaunchDepthwiseConv2dGPUSmall( // Returns whether depthwise convolution backward filter pass can be performed // using the faster ('Small') variant of the kernel. 
EIGEN_DEVICE_FUNC bool CanLaunchDepthwiseConv2dBackpropFilterGPUSmall( - const DepthwiseArgs args, const int block_rows) { + const DepthwiseArgs& args, const int block_rows) { return args.depth_multiplier == 1 && args.stride == 1 && args.in_rows <= 32 && args.in_cols <= 32 && args.in_rows == args.out_rows && args.in_cols == args.out_cols && args.pad_rows >= 0 && @@ -565,8 +564,9 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall( template <typename T, DepthwiseConv2dDirection kDirection, int kKnownFilterWidth, int kKnownFilterHeight, int kBlockSlices, bool kKnownEvenRows> -void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& d, const DepthwiseArgs args, - const T* input, const T* filter, T* output, +void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& d, + const DepthwiseArgs& args, const T* input, + const T* filter, T* output, TensorFormat data_format) { const int block_rows = (args.in_rows + 1) / 2; dim3 block_dim; @@ -602,8 +602,9 @@ void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& d, const DepthwiseArgs args, template <typename T, DepthwiseConv2dDirection kDirection, int kKnownFilterWidth, int kKnownFilterHeight, int kBlockSlices> -void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& d, const DepthwiseArgs args, - const T* input, const T* filter, T* output, +void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& d, + const DepthwiseArgs& args, const T* input, + const T* filter, T* output, TensorFormat data_format) { if (args.in_rows & 1) { LaunchDepthwiseConv2dGPUSmall<T, kDirection, kKnownFilterWidth, @@ -618,8 +619,9 @@ void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& d, const DepthwiseArgs args, template <typename T, DepthwiseConv2dDirection kDirection, int kKnownFilterWidth, int kKnownFilterHeight> -void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& d, const DepthwiseArgs args, - const T* input, const T* filter, T* output, +void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& d, + const DepthwiseArgs& args, const T* input, + 
const T* filter, T* output, TensorFormat data_format) { // Maximize (power of two) kBlockSlices while keeping a block within 1024 // threads (2 pixels per thread). @@ -641,7 +643,7 @@ void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& d, const DepthwiseArgs args, template <typename T, int kKnownFilterWidth, int kKnownFilterHeight, int kKnownDepthMultiplier> -void LaunchDepthwiseConv2dGPU(const GpuDevice& d, const DepthwiseArgs args, +void LaunchDepthwiseConv2dGPU(const GpuDevice& d, const DepthwiseArgs& args, const T* input, const T* filter, T* output, TensorFormat data_format) { void (*kernel)(const DepthwiseArgs, const T*, const T*, T*, int); @@ -671,7 +673,7 @@ void LaunchDepthwiseConv2dGPU(const GpuDevice& d, const DepthwiseArgs args, } template <typename T, int kKnownFilterWidth, int kKnownFilterHeight> -void LaunchDepthwiseConv2dGPU(const GpuDevice& d, const DepthwiseArgs args, +void LaunchDepthwiseConv2dGPU(const GpuDevice& d, const DepthwiseArgs& args, const T* input, const T* filter, T* output, TensorFormat data_format) { if (args.depth_multiplier == 1) { @@ -692,12 +694,12 @@ void LaunchDepthwiseConv2dGPU(const GpuDevice& d, const DepthwiseArgs args, // A simple launch pad to launch the Cuda kernel for depthwise convolution. 
template <typename T> -void LaunchDepthwiseConvOp<GPUDevice, T>::operator()(OpKernelContext* ctx, - const DepthwiseArgs args, +void LaunchDepthwiseConvOp<GpuDevice, T>::operator()(OpKernelContext* ctx, + const DepthwiseArgs& args, const T* input, const T* filter, T* output, TensorFormat data_format) { - const GPUDevice& d = ctx->eigen_device<GPUDevice>(); + const GpuDevice& d = ctx->eigen_device<GpuDevice>(); if (args.filter_rows == 3 && args.filter_cols == 3) { LaunchDepthwiseConv2dGPU<T, 3, 3>(d, args, input, filter, output, data_format); @@ -711,9 +713,9 @@ void LaunchDepthwiseConvOp<GPUDevice, T>::operator()(OpKernelContext* ctx, "Launch of gpu kernel for DepthwiseConv2dGPULaunch failed")); } -template struct LaunchDepthwiseConvOp<GPUDevice, Eigen::half>; -template struct LaunchDepthwiseConvOp<GPUDevice, float>; -template struct LaunchDepthwiseConvOp<GPUDevice, double>; +template struct LaunchDepthwiseConvOp<GpuDevice, Eigen::half>; +template struct LaunchDepthwiseConvOp<GpuDevice, float>; +template struct LaunchDepthwiseConvOp<GpuDevice, double>; // A Cuda kernel to compute the depthwise convolution backprop w.r.t. input. 
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight, @@ -854,7 +856,7 @@ __global__ void __launch_bounds__(640, 2) template <typename T, int kKnownFilterWidth, int kKnownFilterHeight, int kKnownDepthMultiplier> void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& d, - const DepthwiseArgs args, + const DepthwiseArgs& args, const T* out_backprop, const T* filter, T* in_backprop, TensorFormat data_format) { @@ -879,7 +881,7 @@ void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& d, template <typename T, int kKnownFilterWidth, int kKnownFilterHeight> void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& d, - const DepthwiseArgs args, + const DepthwiseArgs& args, const T* out_backprop, const T* filter, T* in_backprop, TensorFormat data_format) { @@ -903,10 +905,10 @@ void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& d, // A simple launch pad to launch the Cuda kernel for depthwise convolution. template <typename T> -void LaunchDepthwiseConvBackpropInputOp<GPUDevice, T>::operator()( +void LaunchDepthwiseConvBackpropInputOp<GpuDevice, T>::operator()( OpKernelContext* ctx, const DepthwiseArgs& args, const T* out_backprop, const T* filter, T* in_backprop, TensorFormat data_format) { - const GPUDevice& d = ctx->eigen_device<GPUDevice>(); + const GpuDevice& d = ctx->eigen_device<GpuDevice>(); if (args.filter_rows == 3 && args.filter_cols == 3) { LaunchDepthwiseConv2dBackpropInputGPU<T, 3, 3>( d, args, out_backprop, filter, in_backprop, data_format); @@ -921,9 +923,9 @@ void LaunchDepthwiseConvBackpropInputOp<GPUDevice, T>::operator()( "utGPULaunch failed")); } -template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice, Eigen::half>; -template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice, float>; -template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice, double>; +template struct LaunchDepthwiseConvBackpropInputOp<GpuDevice, Eigen::half>; +template struct LaunchDepthwiseConvBackpropInputOp<GpuDevice, float>; +template 
struct LaunchDepthwiseConvBackpropInputOp<GpuDevice, double>; // A Cuda kernel to compute the depthwise convolution backprop w.r.t. filter. template <typename T, int kKnownFilterWidth, int kKnownFilterHeight, @@ -1450,7 +1452,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall( template <typename T, int kKnownFilterWidth, int kKnownFilterHeight, int kBlockSlices, int kAccumPixels> bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall( - const GpuDevice& d, const DepthwiseArgs args, const int block_rows, + const GpuDevice& d, const DepthwiseArgs& args, const int block_rows, const T* out_backprop, const T* input, T* filter_backprop, TensorFormat data_format) { const int tile_cols = args.in_cols + args.filter_cols - 1; @@ -1490,7 +1492,7 @@ bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall( template <typename T, int kKnownFilterWidth, int kKnownFilterHeight, int kBlockSlices> bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall( - const GpuDevice& d, const DepthwiseArgs args, const int block_rows, + const GpuDevice& d, const DepthwiseArgs& args, const int block_rows, const T* out_backprop, const T* input, T* filter_backprop, TensorFormat data_format) { // Minimize (power of two) kAccumPixels, while satisfying @@ -1513,7 +1515,7 @@ bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall( template <typename T, int kKnownFilterWidth, int kKnownFilterHeight> bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall( - const GpuDevice& d, const DepthwiseArgs args, const T* out_backprop, + const GpuDevice& d, const DepthwiseArgs& args, const T* out_backprop, const T* input, T* filter_backprop, TensorFormat data_format) { // Maximize (power of two) kBlockSlices while keeping a block within 1024 // threads (2 pixels per thread). 
@@ -1560,7 +1562,7 @@ bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall( template <typename T, int kKnownFilterWidth, int kKnownFilterHeight, int kKnownDepthMultiplier> void LaunchDepthwiseConv2dBackpropFilterGPU(const GpuDevice& d, - const DepthwiseArgs args, + const DepthwiseArgs& args, const T* out_backprop, const T* input, T* filter_backprop, TensorFormat data_format) { @@ -1585,7 +1587,7 @@ void LaunchDepthwiseConv2dBackpropFilterGPU(const GpuDevice& d, template <typename T, int kKnownFilterWidth, int kKnownFilterHeight> void LaunchDepthwiseConv2dBackpropFilterGPU(const GpuDevice& d, - const DepthwiseArgs args, + const DepthwiseArgs& args, const T* out_backprop, const T* input, T* filter_backprop, TensorFormat data_format) { @@ -1608,10 +1610,10 @@ void LaunchDepthwiseConv2dBackpropFilterGPU(const GpuDevice& d, // A simple launch pad to launch the Cuda kernel for depthwise convolution. template <typename T> -void LaunchDepthwiseConvBackpropFilterOp<GPUDevice, T>::operator()( +void LaunchDepthwiseConvBackpropFilterOp<GpuDevice, T>::operator()( OpKernelContext* ctx, const DepthwiseArgs& args, const T* out_backprop, const T* input, T* filter_backprop, TensorFormat data_format) { - const GPUDevice& d = ctx->eigen_device<GPUDevice>(); + const GpuDevice& d = ctx->eigen_device<GpuDevice>(); auto stream = ctx->op_device_context()->stream(); // Initialize the results to 0. 
@@ -1634,8 +1636,8 @@ void LaunchDepthwiseConvBackpropFilterOp<GPUDevice, T>::operator()( "terGPULaunch failed")); } -template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, Eigen::half>; -template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, float>; -template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, double>; +template struct LaunchDepthwiseConvBackpropFilterOp<GpuDevice, Eigen::half>; +template struct LaunchDepthwiseConvBackpropFilterOp<GpuDevice, float>; +template struct LaunchDepthwiseConvBackpropFilterOp<GpuDevice, double>; } // namespace tensorflow #endif // GOOGLE_CUDA |