diff options
Diffstat (limited to 'tensorflow/core/kernels/conv_grad_filter_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/conv_grad_filter_ops.cc | 57 |
1 files changed, 31 insertions, 26 deletions
diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc index 8eb705b2e5..641077ca65 100644 --- a/tensorflow/core/kernels/conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc @@ -555,7 +555,7 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()( int col_stride, const Padding& padding, Tensor* filter_backprop, TensorFormat data_format) { using perftools::gputools::dnn::AlgorithmConfig; - using perftools::gputools::dnn::AlgorithmType; + using perftools::gputools::dnn::AlgorithmDesc; using perftools::gputools::dnn::ProfileResult; std::vector<int32> strides(4, 1); @@ -816,35 +816,40 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()( AlgorithmConfig algorithm_config; if (cudnn_use_autotune && !AutoTuneConvBwdFilter::GetInstance()->Find( conv_parameters, &algorithm_config)) { - std::vector<AlgorithmType> algorithms; + std::vector<AlgorithmDesc::Index> algorithms; CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms( conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms)); ProfileResult best_result; ProfileResult best_result_no_scratch; - for (auto profile_algorithm : algorithms) { - // TODO(zhengxq): profile each algorithm multiple times to better - // accuracy. - CudnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize, - ctx); - ProfileResult profile_result; - bool cudnn_launch_status = - stream - ->ThenConvolveBackwardFilterWithAlgorithm( - input_desc, input_ptr, output_desc, out_backprop_ptr, - conv_desc, filter_desc, &filter_backprop_ptr, - &scratch_allocator, AlgorithmConfig(profile_algorithm), - &profile_result) - .ok(); - if (cudnn_launch_status) { - if (profile_result.is_valid()) { - if (profile_result.elapsed_time_in_ms() < - best_result.elapsed_time_in_ms()) { - best_result = profile_result; - } - if (scratch_allocator.TotalByteSize() == 0 && - profile_result.elapsed_time_in_ms() < - best_result_no_scratch.elapsed_time_in_ms()) { - best_result_no_scratch = profile_result; + // TODO(benbarsdell): Ideally this should not attempt using tensor op math + // if it's not enabled. + for (bool use_tensor_ops : {false, true}) { + for (auto algo_index : algorithms) { + // TODO(zhengxq): profile each algorithm multiple times to better + // accuracy. + AlgorithmDesc profile_algorithm(algo_index, use_tensor_ops); + CudnnScratchAllocator scratch_allocator( + ConvolveBackwardFilterScratchSize, ctx); + ProfileResult profile_result; + bool cudnn_launch_status = + stream + ->ThenConvolveBackwardFilterWithAlgorithm( + input_desc, input_ptr, output_desc, out_backprop_ptr, + conv_desc, filter_desc, &filter_backprop_ptr, + &scratch_allocator, AlgorithmConfig(profile_algorithm), + &profile_result) + .ok(); + if (cudnn_launch_status) { + if (profile_result.is_valid()) { + if (profile_result.elapsed_time_in_ms() < + best_result.elapsed_time_in_ms()) { + best_result = profile_result; + } + if (scratch_allocator.TotalByteSize() == 0 && + profile_result.elapsed_time_in_ms() < + best_result_no_scratch.elapsed_time_in_ms()) { + best_result_no_scratch = profile_result; + } } } } |