author:     2016-08-17 14:22:14 -0800
committer:  2016-08-17 15:34:07 -0700
commit:     204eac513d875a4fb437d4985b2ad520f79c262c (patch)
tree:       fadcebb02c4c78363449f28ea56fe67aa2bcd63b
parent:     699154cedccef7a8a44de4be2d11dc737513f941 (diff)
Relax static shape information requirements for depthwise and separable conv2d.
The current implementation requires that, if the rank of a tensor is
statically known, certain dimensions need to be statically known as well.
Change: 130570253
-rw-r--r--  tensorflow/core/kernels/BUILD                               2
-rw-r--r--  tensorflow/core/kernels/conv_ops.cc                       519
-rw-r--r--  tensorflow/core/kernels/conv_ops.h                         58
-rw-r--r--  tensorflow/core/kernels/conv_ops_gpu.h                      5
-rw-r--r--  tensorflow/core/kernels/depthwise_conv_op.cc               56
-rw-r--r--  tensorflow/python/kernel_tests/depthwise_conv_op_test.py   10
-rw-r--r--  tensorflow/python/ops/nn.py                                61
7 files changed, 389 insertions(+), 322 deletions(-)
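
In user-facing terms, the Python wrappers in this patch stop demanding fully
defined static shapes: partially known shapes (for example, a known rank with
unknown height, width, or channel counts) are now accepted. A minimal usage
sketch with illustrative shapes, using the 2016-era placeholder API (not part
of the patch itself):

    import tensorflow as tf

    # Rank 4 is known, but height, width, and channel counts are dynamic;
    # shapes like these used to trip the static-shape asserts in nn.py.
    images = tf.placeholder(tf.float32, shape=[4, None, None, None])
    filters = tf.placeholder(tf.float32, shape=[3, 3, None, 1])
    out = tf.nn.depthwise_conv2d(images, filters,
                                 strides=[1, 1, 1, 1], padding="SAME")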
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 04316fb085..fa494f117e 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1328,6 +1328,7 @@ tf_kernel_library(
     name = "depthwise_conv_op",
     prefix = "depthwise_conv_op",
     deps = [
+        ":conv_ops",
         ":ops_util",
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
@@ -1936,6 +1937,7 @@ filegroup(
         "batch_norm_op.h",
         "control_flow_ops.h",
         "conv_2d.h",
+        "conv_ops.h",
         "image_resizer_state.h",
         "maxpooling_op.h",
         "pad_op.h",
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index d49d859e30..e2182d0ec0 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #define USE_EIGEN_TENSOR
 #define EIGEN_USE_THREADS
 
+#include "tensorflow/core/kernels/conv_ops.h"
 #include <string.h>
 #include <map>
 #include <vector>
@@ -50,6 +51,7 @@ namespace tensorflow {
 typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
 
+namespace {
 template <typename Device, typename T>
 struct LaunchGeneric {
   static void launch(OpKernelContext* ctx, const Tensor& input,
@@ -87,12 +89,10 @@ struct LaunchGeneric {
     }
   }
 };
-
-template <typename Device, typename T>
-class LaunchConvOp;
+}  // namespace
 
 template <typename T>
-class LaunchConvOp<CPUDevice, T> {
+class LaunchConv2DOp<CPUDevice, T> {
  public:
   void launch(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
               const Tensor& input, const Tensor& filter, int row_stride,
@@ -231,7 +231,7 @@ class Conv2DOp : public BinaryOp<T> {
   bool use_cudnn_;
   Padding padding_;
   TensorFormat data_format_;
-  LaunchConvOp<Device, T> launcher_;
+  LaunchConv2DOp<Device, T> launcher_;
   bool cudnn_use_autotune_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(Conv2DOp);
@@ -244,8 +244,11 @@ class Conv2DOp : public BinaryOp<T> {
 TF_CALL_half(REGISTER_CPU);
 TF_CALL_float(REGISTER_CPU);
 
-#if GOOGLE_CUDA
+// To be used inside depthwise_conv_op.cc.
+template class LaunchConv2DOp<CPUDevice, float>;
 
+#if GOOGLE_CUDA
 int64 GetCudnnWorkspaceLimit(const string& envvar_in_mb,
                              int64 default_value_in_bytes) {
   const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
@@ -264,276 +267,265 @@ int64 GetCudnnWorkspaceLimit(const string& envvar_in_mb,
 }
 
 template <typename T>
-class LaunchConvOp<GPUDevice, T> {
- public:
-  void launch(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
-              const Tensor& input_param, const Tensor& filter, int row_stride,
-              int col_stride, const Eigen::PaddingType& padding, Tensor* output,
-              TensorFormat data_format) {
-    using perftools::gputools::dnn::AlgorithmConfig;
-    using perftools::gputools::dnn::AlgorithmType;
-    using perftools::gputools::dnn::ProfileResult;
-    using perftools::gputools::dnn::kDefaultAlgorithm;
-    auto* stream = ctx->op_device_context()->stream();
-    OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));
-
-    if (!use_cudnn) {
-      ctx->SetStatus(
-          errors::Unimplemented("Conv2D for GPU is not currently supported "
-                                "without cudnn"));
-      return;
-    }
-
-    Tensor input = input_param;
-    if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_stride == 1 &&
-        col_stride == 1 && data_format == FORMAT_NHWC) {
-      // 1x1 filter, so call cublas directly.
-      const uint64 m =
-          input.dim_size(0) * input.dim_size(1) * input.dim_size(2);
-      const uint64 k = filter.dim_size(2);
-      const uint64 n = filter.dim_size(3);
-
-      auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
-                                  input.template flat<T>().size());
-      auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
-                                  filter.template flat<T>().size());
-      auto c_ptr = AsDeviceMemory(output->template flat<T>().data(),
-                                  output->template flat<T>().size());
-
-      auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;
-      bool blas_launch_status =
-          stream
-              ->ThenBlasGemm(no_transpose, no_transpose, n, m, k, 1.0f, b_ptr,
-                             n, a_ptr, k, 0.0f, &c_ptr, n)
-              .ok();
-      if (!blas_launch_status) {
-        ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
-                                        ", n=", n, ", k=", k));
-      }
-
-      return;
-    }
-    int padding_rows = 0;
-    int padding_cols = 0;
-    const int64 in_batch = GetTensorDim(input, data_format, 'N');
-    int64 in_rows = GetTensorDim(input, data_format, 'H');
-    int64 in_cols = GetTensorDim(input, data_format, 'W');
-    const int64 in_depths = GetTensorDim(input, data_format, 'C');
-    const int64 out_batch = GetTensorDim(*output, data_format, 'N');
-    const int64 out_rows = GetTensorDim(*output, data_format, 'H');
-    const int64 out_cols = GetTensorDim(*output, data_format, 'W');
-    const int64 out_depths = GetTensorDim(*output, data_format, 'C');
-    const int64 patch_rows = filter.dim_size(0);
-    const int64 patch_cols = filter.dim_size(1);
-    if (padding == Eigen::PADDING_SAME) {
-      // Total padding on rows and cols is
-      // Pr = (R' - 1) * S + Kr - R
-      // Pc = (C' - 1) * S + Kc - C
-      // where (R', C') are output dimensions, (R, C) are input dimensions, S
-      // is stride, (Kr, Kc) are filter dimensions.
-      // We pad Pr/2 on the left and Pr - Pr/2 on the right, Pc/2 on the top
-      // and Pc - Pc/2 on the bottom. When Pr or Pc is odd, this means
-      // we pad more on the right and bottom than on the top and left.
-      padding_rows =
-          std::max<int>(0, (out_rows - 1) * row_stride + patch_rows - in_rows);
-      padding_cols =
-          std::max<int>(0, (out_cols - 1) * col_stride + patch_cols - in_cols);
-      const bool rows_odd = (padding_rows % 2 != 0);
-      const bool cols_odd = (padding_cols % 2 != 0);
-      if (rows_odd || cols_odd) {
-        Tensor transformed_input;
-        int64 new_in_rows = in_rows + rows_odd;
-        int64 new_in_cols = in_cols + cols_odd;
-        OP_REQUIRES_OK(ctx,
-                       ctx->allocate_temp(
-                           DataTypeToEnum<T>::value,
-                           ShapeFromFormat(data_format, in_batch, new_in_rows,
-                                           new_in_cols, in_depths),
-                           &transformed_input));
-
-        functor::PadInput<GPUDevice, T, int, 4>()(
-            ctx->eigen_device<GPUDevice>(), To32Bit(input_param.tensor<T, 4>()),
-            {{0, 0}}, {{rows_odd, cols_odd}},
-            To32Bit(transformed_input.tensor<T, 4>()), data_format);
-
-        input = transformed_input;
-        in_rows = new_in_rows;
-        in_cols = new_in_cols;
-      }
-    }
-
-    if (data_format == FORMAT_NHWC) {
-      // Convert the input tensor from NHWC to NCHW.
-      TensorShape nchw_shape =
-          ShapeFromFormat(FORMAT_NCHW, in_batch, in_rows, in_cols, in_depths);
-      if (in_depths > 1) {
-        Tensor transformed_input;
-        OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
-                                               nchw_shape, &transformed_input));
-        functor::NHWCToNCHW<GPUDevice, T, 4>()(
-            ctx->eigen_device<GPUDevice>(),
-            const_cast<const Tensor&>(input).tensor<T, 4>(),
-            transformed_input.tensor<T, 4>());
-        input = transformed_input;
-      } else {
-        // If depth <= 1, then just reshape.
-        CHECK(input.CopyFrom(input, nchw_shape));
-      }
-    }
-
-    CHECK(padding_rows >= 0 && padding_cols >= 0)
-        << "Negative row or col paddings: (" << padding_rows << ", "
-        << padding_cols << ")";
-    perftools::gputools::dnn::BatchDescriptor input_desc;
-    input_desc.set_count(in_batch)
-        .set_feature_map_count(in_depths)
-        .set_height(in_rows)
-        .set_width(in_cols)
-        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
-    perftools::gputools::dnn::BatchDescriptor output_desc;
-    output_desc.set_count(out_batch)
-        .set_height(out_rows)
-        .set_width(out_cols)
-        .set_feature_map_count(out_depths)
-        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
-    perftools::gputools::dnn::FilterDescriptor filter_desc;
-    filter_desc.set_input_filter_height(filter.dim_size(0))
-        .set_input_filter_width(filter.dim_size(1))
-        .set_input_feature_map_count(filter.dim_size(2))
-        .set_output_feature_map_count(filter.dim_size(3));
-    perftools::gputools::dnn::ConvolutionDescriptor conv_desc;
-    conv_desc.set_vertical_filter_stride(row_stride)
-        .set_horizontal_filter_stride(col_stride)
-        .set_zero_padding_height(padding_rows / 2)
-        .set_zero_padding_width(padding_cols / 2);
-
-    Tensor transformed_filter;
-    OP_REQUIRES_OK(ctx,
-                   ctx->allocate_temp(
-                       DataTypeToEnum<T>::value,
-                       TensorShape({filter.dim_size(3), filter.dim_size(2),
-                                    filter.dim_size(0), filter.dim_size(1)}),
-                       &transformed_filter));
-
-    functor::TransformFilter<GPUDevice, T, int, 4>()(
-        ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 4>()),
-        To32Bit(transformed_filter.tensor<T, 4>()));
-
-    Tensor transformed_output;
-    OP_REQUIRES_OK(
-        ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
-                                ShapeFromFormat(FORMAT_NCHW, out_batch,
-                                                out_rows, out_cols, out_depths),
-                                &transformed_output));
-
-    auto input_ptr = AsDeviceMemory(input.template flat<T>().data(),
-                                    input.template flat<T>().size());
-    auto filter_ptr =
-        AsDeviceMemory(transformed_filter.template flat<T>().data(),
-                       transformed_filter.template flat<T>().size());
-    auto output_ptr =
-        AsDeviceMemory(transformed_output.template flat<T>().data(),
-                       transformed_output.template flat<T>().size());
-
-    static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit(
-        "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB by default
-    );
-
-    int device_id = stream->parent()->device_ordinal();
-    ConvParameters conv_parameters = {
-        in_batch,      // batch
-        in_depths,     // in_depths
-        in_rows,       // in_rows
-        in_cols,       // in_cols
-        out_depths,    // out_depths
-        patch_rows,    // filter_rows
-        patch_cols,    // filter_cols
-        row_stride,    // stride_rows
-        col_stride,    // stride_cols
-        padding_rows,  // padding_rows
-        padding_cols,  // padding_cols
-        device_id,     // device_id
-    };
-    AlgorithmConfig algorithm_config;
-    if (cudnn_use_autotune &&
-        !autotune_results_.Find(conv_parameters, &algorithm_config)) {
-      std::vector<AlgorithmType> algorithms;
-      CHECK(stream->parent()->GetConvolveAlgorithms(&algorithms));
-      ProfileResult best_result;
-      ProfileResult best_result_no_scratch;
-      for (auto profile_algorithm : algorithms) {
-        // TODO(zhengxq): profile each algorithm multiple times to better
-        // accuracy.
-        CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
-        ProfileResult profile_result;
-        bool cudnn_launch_status =
-            stream
-                ->ThenConvolveWithAlgorithm(
-                    input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
-                    output_desc, &output_ptr, &scratch_allocator,
-                    AlgorithmConfig(profile_algorithm), &profile_result)
-                .ok();
-        if (cudnn_launch_status) {
-          if (profile_result.is_valid()) {
-            if (profile_result.elapsed_time_in_ms() <
-                best_result.elapsed_time_in_ms()) {
-              best_result = profile_result;
-            }
-            if (scratch_allocator.TotalByteSize() == 0 &&
-                profile_result.elapsed_time_in_ms() <
-                    best_result_no_scratch.elapsed_time_in_ms()) {
-              best_result_no_scratch = profile_result;
-            }
-          }
-        }
-      }
-      OP_REQUIRES(ctx, best_result.is_valid() &&
-                           best_result.algorithm() != kDefaultAlgorithm,
-                  errors::NotFound("No algorithm worked!"));
-      OP_REQUIRES(ctx,
-                  best_result_no_scratch.is_valid() &&
-                      best_result_no_scratch.algorithm() != kDefaultAlgorithm,
-                  errors::NotFound("No algorithm without scratch worked!"));
-      algorithm_config.set_algorithm(best_result.algorithm());
-      algorithm_config.set_algorithm_no_scratch(
-          best_result_no_scratch.algorithm());
-      autotune_results_.Insert(conv_parameters, algorithm_config);
-    }
-
-    CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
-    bool cudnn_launch_status =
-        stream
-            ->ThenConvolveWithAlgorithm(input_desc, input_ptr, filter_desc,
-                                        filter_ptr, conv_desc, output_desc,
-                                        &output_ptr, &scratch_allocator,
-                                        algorithm_config, nullptr)
-            .ok();
-
-    if (!cudnn_launch_status) {
-      ctx->SetStatus(errors::Internal(
-          "cuDNN launch failure : input shape(", input.shape().DebugString(),
-          ") filter shape(", filter.shape().DebugString(), ")"));
-    }
-
-    // Convert the output tensor back from NHWC to NCHW.
-    if (data_format == FORMAT_NHWC) {
-      functor::NCHWToNHWC<GPUDevice, T, 4>()(
-          ctx->eigen_device<GPUDevice>(),
-          const_cast<const Tensor&>(transformed_output).tensor<T, 4>(),
-          output->tensor<T, 4>());
-    } else {
-      *output = transformed_output;
-    }
-  }
-
- private:
-  AutoTuneMap<ConvParameters, perftools::gputools::dnn::AlgorithmConfig>
-      autotune_results_;
-};
-
-#endif  // GOOGLE_CUDA
-
-#if GOOGLE_CUDA
+void LaunchConv2DOp<GPUDevice, T>::launch(
+    OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
+    const Tensor& input_param, const Tensor& filter, int row_stride,
+    int col_stride, const Eigen::PaddingType& padding, Tensor* output,
+    TensorFormat data_format) {
+  using perftools::gputools::dnn::AlgorithmConfig;
+  using perftools::gputools::dnn::AlgorithmType;
+  using perftools::gputools::dnn::ProfileResult;
+  using perftools::gputools::dnn::kDefaultAlgorithm;
+  auto* stream = ctx->op_device_context()->stream();
+  OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));
+
+  if (!use_cudnn) {
+    ctx->SetStatus(
+        errors::Unimplemented("Conv2D for GPU is not currently supported "
+                              "without cudnn"));
+    return;
+  }
+
+  Tensor input = input_param;
+  if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_stride == 1 &&
+      col_stride == 1 && data_format == FORMAT_NHWC) {
+    // 1x1 filter, so call cublas directly.
+    const uint64 m = input.dim_size(0) * input.dim_size(1) * input.dim_size(2);
+    const uint64 k = filter.dim_size(2);
+    const uint64 n = filter.dim_size(3);
+
+    auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
+                                input.template flat<T>().size());
+    auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
+                                filter.template flat<T>().size());
+    auto c_ptr = AsDeviceMemory(output->template flat<T>().data(),
+                                output->template flat<T>().size());
+
+    auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;
+    bool blas_launch_status =
+        stream
+            ->ThenBlasGemm(no_transpose, no_transpose, n, m, k, 1.0f, b_ptr, n,
+                           a_ptr, k, 0.0f, &c_ptr, n)
+            .ok();
+    if (!blas_launch_status) {
+      ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
+                                      ", n=", n, ", k=", k));
+    }
+
+    return;
+  }
+  int padding_rows = 0;
+  int padding_cols = 0;
+  const int64 in_batch = GetTensorDim(input, data_format, 'N');
+  int64 in_rows = GetTensorDim(input, data_format, 'H');
+  int64 in_cols = GetTensorDim(input, data_format, 'W');
+  const int64 in_depths = GetTensorDim(input, data_format, 'C');
+  const int64 out_batch = GetTensorDim(*output, data_format, 'N');
+  const int64 out_rows = GetTensorDim(*output, data_format, 'H');
+  const int64 out_cols = GetTensorDim(*output, data_format, 'W');
+  const int64 out_depths = GetTensorDim(*output, data_format, 'C');
+  const int64 patch_rows = filter.dim_size(0);
+  const int64 patch_cols = filter.dim_size(1);
+  if (padding == Eigen::PADDING_SAME) {
+    // Total padding on rows and cols is
+    // Pr = (R' - 1) * S + Kr - R
+    // Pc = (C' - 1) * S + Kc - C
+    // where (R', C') are output dimensions, (R, C) are input dimensions, S
+    // is stride, (Kr, Kc) are filter dimensions.
+    // We pad Pr/2 on the left and Pr - Pr/2 on the right, Pc/2 on the top
+    // and Pc - Pc/2 on the bottom. When Pr or Pc is odd, this means
+    // we pad more on the right and bottom than on the top and left.
+    padding_rows =
+        std::max<int>(0, (out_rows - 1) * row_stride + patch_rows - in_rows);
+    padding_cols =
+        std::max<int>(0, (out_cols - 1) * col_stride + patch_cols - in_cols);
+    const bool rows_odd = (padding_rows % 2 != 0);
+    const bool cols_odd = (padding_cols % 2 != 0);
+    if (rows_odd || cols_odd) {
+      Tensor transformed_input;
+      int64 new_in_rows = in_rows + rows_odd;
+      int64 new_in_cols = in_cols + cols_odd;
+      OP_REQUIRES_OK(
+          ctx,
+          ctx->allocate_temp(DataTypeToEnum<T>::value,
+                             ShapeFromFormat(data_format, in_batch, new_in_rows,
+                                             new_in_cols, in_depths),
+                             &transformed_input));
+
+      functor::PadInput<GPUDevice, T, int, 4>()(
+          ctx->eigen_device<GPUDevice>(), To32Bit(input_param.tensor<T, 4>()),
+          {{0, 0}}, {{rows_odd, cols_odd}},
+          To32Bit(transformed_input.tensor<T, 4>()), data_format);
+
+      input = transformed_input;
+      in_rows = new_in_rows;
+      in_cols = new_in_cols;
+    }
+  }
+
+  if (data_format == FORMAT_NHWC) {
+    // Convert the input tensor from NHWC to NCHW.
+    TensorShape nchw_shape =
+        ShapeFromFormat(FORMAT_NCHW, in_batch, in_rows, in_cols, in_depths);
+    if (in_depths > 1) {
+      Tensor transformed_input;
+      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
+                                             nchw_shape, &transformed_input));
+      functor::NHWCToNCHW<GPUDevice, T, 4>()(
+          ctx->eigen_device<GPUDevice>(),
+          const_cast<const Tensor&>(input).tensor<T, 4>(),
+          transformed_input.tensor<T, 4>());
+      input = transformed_input;
+    } else {
+      // If depth <= 1, then just reshape.
+      CHECK(input.CopyFrom(input, nchw_shape));
+    }
+  }
+
+  CHECK(padding_rows >= 0 && padding_cols >= 0)
+      << "Negative row or col paddings: (" << padding_rows << ", "
+      << padding_cols << ")";
+  perftools::gputools::dnn::BatchDescriptor input_desc;
+  input_desc.set_count(in_batch)
+      .set_feature_map_count(in_depths)
+      .set_height(in_rows)
+      .set_width(in_cols)
+      .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+  perftools::gputools::dnn::BatchDescriptor output_desc;
+  output_desc.set_count(out_batch)
+      .set_height(out_rows)
+      .set_width(out_cols)
+      .set_feature_map_count(out_depths)
+      .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+  perftools::gputools::dnn::FilterDescriptor filter_desc;
+  filter_desc.set_input_filter_height(filter.dim_size(0))
+      .set_input_filter_width(filter.dim_size(1))
+      .set_input_feature_map_count(filter.dim_size(2))
+      .set_output_feature_map_count(filter.dim_size(3));
+  perftools::gputools::dnn::ConvolutionDescriptor conv_desc;
+  conv_desc.set_vertical_filter_stride(row_stride)
+      .set_horizontal_filter_stride(col_stride)
+      .set_zero_padding_height(padding_rows / 2)
+      .set_zero_padding_width(padding_cols / 2);
+
+  Tensor transformed_filter;
+  OP_REQUIRES_OK(ctx, ctx->allocate_temp(
+                          DataTypeToEnum<T>::value,
+                          TensorShape({filter.dim_size(3), filter.dim_size(2),
+                                       filter.dim_size(0), filter.dim_size(1)}),
+                          &transformed_filter));
+
+  functor::TransformFilter<GPUDevice, T, int, 4>()(
+      ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 4>()),
+      To32Bit(transformed_filter.tensor<T, 4>()));
+
+  Tensor transformed_output;
+  OP_REQUIRES_OK(
+      ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
+                              ShapeFromFormat(FORMAT_NCHW, out_batch, out_rows,
+                                              out_cols, out_depths),
+                              &transformed_output));
+
+  auto input_ptr = AsDeviceMemory(input.template flat<T>().data(),
+                                  input.template flat<T>().size());
+  auto filter_ptr =
+      AsDeviceMemory(transformed_filter.template flat<T>().data(),
+                     transformed_filter.template flat<T>().size());
+  auto output_ptr =
+      AsDeviceMemory(transformed_output.template flat<T>().data(),
+                     transformed_output.template flat<T>().size());
+
+  static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit(
+      "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB by default
+  );
+
+  int device_id = stream->parent()->device_ordinal();
+  ConvParameters conv_parameters = {
+      in_batch,      // batch
+      in_depths,     // in_depths
+      in_rows,       // in_rows
+      in_cols,       // in_cols
+      out_depths,    // out_depths
+      patch_rows,    // filter_rows
+      patch_cols,    // filter_cols
+      row_stride,    // stride_rows
+      col_stride,    // stride_cols
+      padding_rows,  // padding_rows
+      padding_cols,  // padding_cols
+      device_id,     // device_id
+  };
+  AlgorithmConfig algorithm_config;
+  if (cudnn_use_autotune &&
+      !autotune_results_.Find(conv_parameters, &algorithm_config)) {
+    std::vector<AlgorithmType> algorithms;
+    CHECK(stream->parent()->GetConvolveAlgorithms(&algorithms));
+    ProfileResult best_result;
+    ProfileResult best_result_no_scratch;
+    for (auto profile_algorithm : algorithms) {
+      // TODO(zhengxq): profile each algorithm multiple times to better
+      // accuracy.
+      CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
+      ProfileResult profile_result;
+      bool cudnn_launch_status =
+          stream
+              ->ThenConvolveWithAlgorithm(
+                  input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
+                  output_desc, &output_ptr, &scratch_allocator,
+                  AlgorithmConfig(profile_algorithm), &profile_result)
+              .ok();
+      if (cudnn_launch_status) {
+        if (profile_result.is_valid()) {
+          if (profile_result.elapsed_time_in_ms() <
+              best_result.elapsed_time_in_ms()) {
+            best_result = profile_result;
+          }
+          if (scratch_allocator.TotalByteSize() == 0 &&
+              profile_result.elapsed_time_in_ms() <
+                  best_result_no_scratch.elapsed_time_in_ms()) {
+            best_result_no_scratch = profile_result;
+          }
+        }
+      }
+    }
+    OP_REQUIRES(ctx, best_result.is_valid() &&
+                         best_result.algorithm() != kDefaultAlgorithm,
+                errors::NotFound("No algorithm worked!"));
+    OP_REQUIRES(ctx,
+                best_result_no_scratch.is_valid() &&
+                    best_result_no_scratch.algorithm() != kDefaultAlgorithm,
+                errors::NotFound("No algorithm without scratch worked!"));
+    algorithm_config.set_algorithm(best_result.algorithm());
+    algorithm_config.set_algorithm_no_scratch(
+        best_result_no_scratch.algorithm());
+    autotune_results_.Insert(conv_parameters, algorithm_config);
+  }
+
+  CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
+  bool cudnn_launch_status =
+      stream
+          ->ThenConvolveWithAlgorithm(input_desc, input_ptr, filter_desc,
+                                      filter_ptr, conv_desc, output_desc,
+                                      &output_ptr, &scratch_allocator,
+                                      algorithm_config, nullptr)
+          .ok();
+
+  if (!cudnn_launch_status) {
+    ctx->SetStatus(errors::Internal(
+        "cuDNN launch failure : input shape(", input.shape().DebugString(),
+        ") filter shape(", filter.shape().DebugString(), ")"));
+  }
+
+  // Convert the output tensor back from NHWC to NCHW.
+  if (data_format == FORMAT_NHWC) {
+    functor::NCHWToNHWC<GPUDevice, T, 4>()(
+        ctx->eigen_device<GPUDevice>(),
+        const_cast<const Tensor&>(transformed_output).tensor<T, 4>(),
+        output->tensor<T, 4>());
+  } else {
+    *output = transformed_output;
+  }
+}
 
 // Forward declarations of the functor specializations for GPU.
 namespace functor {
 #define DECLARE_GPU_SPEC(T)                                             \
@@ -577,6 +569,9 @@ REGISTER_KERNEL_BUILDER(
     Name("Conv2D").Device(DEVICE_GPU).TypeConstraint<float>("T"),
     Conv2DOp<GPUDevice, float>);
 
+// To be used inside depthwise_conv_op.cc.
+template class LaunchConv2DOp<GPUDevice, float>;
+
 #endif  // GOOGLE_CUDA
 
 }  // namespace tensorflow
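
The SAME-padding arithmetic in the hunk above can be checked by hand. Below is
a small sketch of the comment's formula, Pr = (R' - 1) * S + Kr - R with
R' = ceil(R / S); same_padding is a hypothetical helper for illustration, not
part of the patch:

    def same_padding(in_size, filter_size, stride):
        # Output size under SAME padding: ceil(in_size / stride).
        out_size = (in_size + stride - 1) // stride
        # Total padding Pr = (R' - 1) * S + Kr - R, clamped at zero.
        total = max(0, (out_size - 1) * stride + filter_size - in_size)
        # The odd pixel goes to the bottom/right, as the comment describes;
        # this is why the kernel pre-pads the input when total is odd.
        return total // 2, total - total // 2

    print(same_padding(5, 3, 2))  # (1, 1): total padding 2, split evenly
    print(same_padding(5, 2, 2))  # (0, 1): odd total, extra pixel on the right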
diff --git a/tensorflow/core/kernels/conv_ops.h b/tensorflow/core/kernels/conv_ops.h
new file mode 100644
index 0000000000..d09db3dc15
--- /dev/null
+++ b/tensorflow/core/kernels/conv_ops.h
@@ -0,0 +1,58 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_KERNELS_CONV_OPS_H_
+#define TENSORFLOW_KERNELS_CONV_OPS_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/util/tensor_format.h"
+
+#if GOOGLE_CUDA
+#include "tensorflow/core/kernels/conv_ops_gpu.h"
+#include "tensorflow/core/platform/stream_executor.h"
+#endif  // GOOGLE_CUDA
+
+namespace tensorflow {
+
+// Forward declaration.
+class OpKernelContext;
+
+template <typename Device, typename T>
+class LaunchConv2DOp {
+ public:
+  void launch(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
+              const Tensor& input, const Tensor& filter, int row_stride,
+              int col_stride, const Eigen::PaddingType& padding, Tensor* output,
+              TensorFormat data_format);
+};
+
+#ifdef GOOGLE_CUDA
+template <typename T>
+class LaunchConv2DOp<Eigen::GpuDevice, T> {
+ public:
+  void launch(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
+              const Tensor& input, const Tensor& filter, int row_stride,
+              int col_stride, const Eigen::PaddingType& padding, Tensor* output,
+              TensorFormat data_format);
+
+ private:
+  AutoTuneMap<ConvParameters, perftools::gputools::dnn::AlgorithmConfig>
+      autotune_results_;
+};
+#endif  // GOOGLE_CUDA
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_KERNELS_CONV_OPS_H
diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h
index 6cfa541d06..26d03908d3 100644
--- a/tensorflow/core/kernels/conv_ops_gpu.h
+++ b/tensorflow/core/kernels/conv_ops_gpu.h
@@ -19,6 +19,7 @@ limitations under the License.
 #if GOOGLE_CUDA
 
 #include <tuple>
+#include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/platform/stream_executor.h"
 
 namespace tensorflow {
@@ -26,8 +27,8 @@ namespace tensorflow {
 // TODO(zhengxq): move this to gpu_util.h. The use of such wrappers is wide
 // spread.
 template <typename T>
-perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory,
-                                                    uint64 size) {
+inline perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory,
+                                                           uint64 size) {
   perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory),
                                                 size * sizeof(T));
   perftools::gputools::DeviceMemory<T> typed(wrapped);
diff --git a/tensorflow/core/kernels/depthwise_conv_op.cc b/tensorflow/core/kernels/depthwise_conv_op.cc
index 88d2651938..172bc1bcba 100644
--- a/tensorflow/core/kernels/depthwise_conv_op.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include <algorithm>
 #include <cmath>
+#include <type_traits>
 
 #include "tensorflow/core/framework/numeric_op.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -25,12 +26,14 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/conv_ops.h"
 #include "tensorflow/core/kernels/depthwise_conv_op.h"
 #include "tensorflow/core/kernels/ops_util.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/util/use_cudnn.h"
 #include "tensorflow/core/util/work_sharder.h"
 
 #if GOOGLE_CUDA
@@ -225,6 +228,9 @@ struct LaunchDepthwiseConvOp<CPUDevice, T> {
   }
 };
 
+// Extern template instantiated in conv_ops.cc.
+extern template class LaunchConv2DOp<CPUDevice, float>;
+
 #if GOOGLE_CUDA
 
 template <typename T>
@@ -247,6 +253,9 @@ struct LaunchDepthwiseConvOp<GPUDevice, T> {
   }
 };
 
+// Extern template instantiated in conv_ops.cc.
+extern template class LaunchConv2DOp<GPUDevice, float>;
+
 #endif
 
 template <typename Device, typename T>
@@ -267,18 +276,20 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
         errors::InvalidArgument("Current implementation does not yet support "
                                 "strides in the batch and depth dimensions."));
     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+
+    // For special case when in_depth == 1.
+    use_cudnn_ = CanUseCudnn();
+    cudnn_use_autotune_ = CudnnUseAutotune();
   }
 
   void Compute(OpKernelContext* context) override {
     // Input tensor is of the following dimensions:
     // [ batch, in_rows, in_cols, in_depth ]
     const Tensor& input = context->input(0);
-    auto input_ptr = input.template flat<T>().data();
 
     // Input filter is of the following dimensions:
     // [ filter_rows, filter_cols, in_depth, depth_multiplier]
     const Tensor& filter = context->input(1);
-    auto filter_ptr = filter.template flat<T>().data();
 
     // For 2D convolution, there should be 4 dimensions.
     OP_REQUIRES(context, input.dims() == 4,
@@ -338,7 +349,26 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
     // [ in_batch, out_rows, out_cols, out_depth ]
     Tensor* output = nullptr;
     OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
-    auto output_ptr = output->template flat<T>().data();
+
+    VLOG(2) << "DepthwiseConv2dNative: "
+            << " Input: [" << batch << ", " << input_rows << ", " << input_cols
+            << ", " << in_depth << "]; Filter: [" << filter_rows << ", "
+            << filter_cols << ", " << in_depth << ", " << depth_multiplier
+            << "]; stride = " << stride << ", pad_rows = " << pad_rows
+            << ", pad_cols = " << pad_cols << ", output: [" << batch << ", "
+            << out_rows << ", " << out_cols << ", " << out_depth << "]";
+
+    // If there is nothing to compute, return.
+    if (out_shape.num_elements() == 0) {
+      return;
+    }
+
+    if (std::is_same<T, float>::value && in_depth == 1) {
+      launcher_.launch(context, use_cudnn_, cudnn_use_autotune_, input, filter,
+                       stride, stride, BrainPadding2EigenPadding(padding_),
+                       output, FORMAT_NHWC);
+      return;
+    }
 
     DepthwiseArgs args;
     args.batch = batch;
@@ -355,18 +385,9 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
     args.out_cols = out_cols;
     args.out_depth = out_depth;
 
-    VLOG(2) << "DepthwiseConv2dNative: "
-            << " Input: [" << batch << ", " << input_rows << ", " << input_cols
-            << ", " << in_depth << "]; Filter: [" << filter_rows << ", "
-            << filter_cols << ", " << in_depth << ", " << depth_multiplier
-            << "]; stride = " << stride << ", pad_rows = " << pad_rows
-            << ", pad_cols = " << pad_cols << ", output: [" << batch << ", "
-            << out_rows << ", " << out_cols << ", " << out_depth << "]";
-
-    // If there is nothing to compute, return.
-    if (out_shape.num_elements() == 0) {
-      return;
-    }
+    auto input_ptr = input.template flat<T>().data();
+    auto filter_ptr = filter.template flat<T>().data();
+    auto output_ptr = output->template flat<T>().data();
 
     LaunchDepthwiseConvOp<Device, T>::launch(context, args, input_ptr,
                                              filter_ptr, output_ptr);
   }
@@ -375,6 +396,11 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
   std::vector<int32> strides_;
   Padding padding_;
 
+  // For the case in_depth == 1.
+  LaunchConv2DOp<Device, T> launcher_;
+  bool use_cudnn_;
+  bool cudnn_use_autotune_;
+
   TF_DISALLOW_COPY_AND_ASSIGN(DepthwiseConv2dNativeOp);
 };
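
The new fast path above is valid because a depthwise filter of shape
[rows, cols, 1, multiplier] is byte-for-byte a regular convolution filter with
one input channel and `multiplier` output channels, so the cuDNN-backed Conv2D
launcher can serve it directly. A quick equivalence check, assuming the
graph-and-session API of this era (illustrative only):

    import numpy as np
    import tensorflow as tf

    x = np.random.rand(2, 9, 9, 1).astype(np.float32)  # in_depth == 1
    f = np.random.rand(3, 3, 1, 4).astype(np.float32)  # depth_multiplier == 4

    # The same filter tensor is valid for both ops when in_depth == 1.
    dw = tf.nn.depthwise_conv2d_native(x, f, strides=[1, 1, 1, 1], padding="SAME")
    cv = tf.nn.conv2d(x, f, strides=[1, 1, 1, 1], padding="SAME")
    with tf.Session() as sess:
        a, b = sess.run([dw, cv])
    print(np.allclose(a, b))  # True: identical arithmetic when in_depth == 1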
diff --git a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
index 96c00ae01b..7a1545daca 100644
--- a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
+++ b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
@@ -29,17 +29,17 @@ def ConfigsToTest():
     convolution parameters.
   """
   input_sizes = [[4, 5, 5, 48], [4, 8, 8, 84], [4, 17, 17, 48], [4, 35, 35, 2],
-                 [4, 147, 147, 2], [3, 299, 299, 3]]
+                 [4, 147, 147, 2], [3, 299, 299, 3], [5, 183, 183, 1]]
   filter_sizes = [[1, 1, 48, 2], [1, 3, 84, 1], [3, 1, 48, 4], [5, 5, 2, 1],
-                  [3, 3, 2, 8], [2, 2, 3, 8]]
+                  [3, 3, 2, 8], [2, 2, 3, 8], [5, 5, 1, 2]]
   out_sizes = [[4, 5, 5, 96], [4, 8, 8, 84], [4, 17, 17, 192], [4, 35, 35, 2],
-               [4, 49, 49, 16], [3, 150, 150, 24]]
-  strides = [1, 1, 1, 1, 3, 2]
+               [4, 49, 49, 16], [3, 150, 150, 24], [5, 92, 92, 2]]
+  strides = [1, 1, 1, 1, 3, 2, 2]
   # pylint: disable=invalid-name
   VALID = "VALID"
   SAME = "SAME"
   # pylint: enable=invalid-name
-  paddings = [SAME, SAME, SAME, SAME, VALID, SAME, SAME]
+  paddings = [SAME, SAME, SAME, SAME, VALID, SAME, SAME, SAME]
   for i, f, o, s, p in zip(input_sizes, filter_sizes, out_sizes, strides,
                            paddings):
     yield i, f, o, s, p
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index 303932f9dd..fa3fba3375 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -586,7 +586,7 @@ def zero_fraction(value, name=None):
       math_ops.cast(math_ops.equal(value, zero), dtypes.float32))
 
 
-# pylint: disable=redefined-builtin,line-too-long
+# pylint: disable=redefined-builtin
 def depthwise_conv2d(input, filter, strides, padding, name=None):
   """Depthwise 2-D convolution.
 
@@ -625,28 +625,10 @@ def depthwise_conv2d(input, filter, strides, padding, name=None):
   with ops.name_scope(name, "depthwise", [input, filter]) as name:
     input = ops.convert_to_tensor(input, name="tensor_in")
     filter = ops.convert_to_tensor(filter, name="filter_in")
-    # A shape is required to statically compute the number of separable filters.
-    if filter.get_shape().ndims is not None:
-      assert len(filter.get_shape()) == 4
-      in_channels = filter.get_shape()[2]
-      # Sanity checks, if shape information is available for the inputs.
-      if input.get_shape().ndims is not None:
-        assert len(input.get_shape()) == 4
-        assert input.get_shape()[3] == in_channels, (
-            "Mismatched input depth %d and number of depthwise filters %d." %
-            (input.get_shape()[3].value, in_channels))
-    else:
-      assert input.get_shape().ndims is not None, (
-          "Either tensor must provide static shape information.")
-      assert input.get_shape().ndims == 4
-      in_channels = input.get_shape()[3]
-
-    if in_channels == 1:
-      return nn_ops.conv2d(input, filter, strides, padding, name=name)
-    else:
-      return nn_ops.depthwise_conv2d_native(
-          input, filter, strides, padding, name=name)
-# pylint: enable=redefined-builtin,line-too-long
+
+    return nn_ops.depthwise_conv2d_native(
+        input, filter, strides, padding, name=name)
+# pylint: enable=redefined-builtin
 
 
 # pylint: disable=redefined-builtin,line-too-long
@@ -702,21 +684,24 @@ def separable_conv2d(input, depthwise_filter, pointwise_filter, strides,
     pointwise_filter = ops.convert_to_tensor(
         pointwise_filter, name="pointwise_filter")
 
-    if pointwise_filter.get_shape().ndims is not None:
-      assert len(pointwise_filter.get_shape()) == 4
-      assert pointwise_filter.get_shape()[0] == 1
-      assert pointwise_filter.get_shape()[1] == 1
-      if depthwise_filter.get_shape().ndims and input.get_shape().ndims:
-        channel_multiplier = depthwise_filter.get_shape()[3]
-        in_channels = input.get_shape()[3]
-        out_channels = pointwise_filter.get_shape()[3]
-        if channel_multiplier * in_channels > out_channels:
-          raise ValueError(
-              ("Refusing to perform an overparameterized separable "
-               "convolution: channel_multiplier * in_channels = "
-               "%d * %d = %d > %d = out_channels" %
-               (channel_multiplier, in_channels,
-                channel_multiplier * in_channels, out_channels)))
+    pointwise_filter_shape = pointwise_filter.get_shape().with_rank(4)
+    pointwise_filter_shape[0].assert_is_compatible_with(1)
+    pointwise_filter_shape[1].assert_is_compatible_with(1)
+
+    channel_multiplier = depthwise_filter.get_shape().with_rank(4)[3]
+    in_channels = input.get_shape().with_rank(4)[3]
+    out_channels = pointwise_filter_shape[3]
+
+    # If any of channel numbers is unknown, then the comparison below returns
+    # None. See TensorShape.__gt__().
+    if channel_multiplier * in_channels > out_channels:
+      raise ValueError(
+          "Refusing to perform an overparameterized separable "
+          "convolution: channel_multiplier * in_channels = "
+          "%d * %d = %d > %d = out_channels" %
+          (channel_multiplier, in_channels,
+           channel_multiplier * in_channels, out_channels))
+
     # The layout of the ops in the graph are expected to be as follows:
     # depthwise_conv2d  // Conv2D op corresponding to native deptwise conv.
     # separable_conv2d  // Conv2D op corresponding to the pointwise conv.
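
The rewritten check degrades gracefully when shapes are only partially known:
Dimension arithmetic propagates unknowns, and comparing against an unknown
yields None, which is falsy, so the overparameterization error is simply
skipped. A quick illustration of that behavior, poking at the internal
tensor_shape module the comment refers to (illustrative, not part of the
patch):

    from tensorflow.python.framework import tensor_shape

    mult = tensor_shape.Dimension(3)
    out_channels = tensor_shape.Dimension(16)

    # All channel counts known: the comparison is a real boolean.
    print(tensor_shape.Dimension(8) * mult > out_channels)     # True -> would raise

    # Unknown in_channels: multiplication yields Dimension(None), and the
    # comparison returns None, so the `if` branch is not taken.
    print(tensor_shape.Dimension(None) * mult > out_channels)  # None -> check skipped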