diff options
Diffstat (limited to 'tensorflow/core/kernels/conv_grad_input_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/conv_grad_input_ops.cc | 55 |
1 files changed, 30 insertions, 25 deletions
diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc index ce561aa99c..0732bf4046 100644 --- a/tensorflow/core/kernels/conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/conv_grad_input_ops.cc @@ -630,7 +630,7 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()( int col_stride, const Padding& padding, Tensor* in_backprop, TensorFormat data_format) { using perftools::gputools::dnn::AlgorithmConfig; - using perftools::gputools::dnn::AlgorithmType; + using perftools::gputools::dnn::AlgorithmDesc; using perftools::gputools::dnn::ProfileResult; std::vector<int32> strides(4, 1); @@ -870,34 +870,39 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()( AlgorithmConfig algorithm_config; if (cudnn_use_autotune && !AutoTuneConvBwdData::GetInstance()->Find( conv_parameters, &algorithm_config)) { - std::vector<AlgorithmType> algorithms; + std::vector<AlgorithmDesc::Index> algorithms; CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms( conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms)); ProfileResult best_result; ProfileResult best_result_no_scratch; - for (auto profile_algorithm : algorithms) { - // TODO(zhengxq): profile each algorithm multiple times to better - // accuracy. - CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize, - ctx); - ProfileResult profile_result; - bool cudnn_launch_status = - stream - ->ThenConvolveBackwardDataWithAlgorithm( - filter_desc, filter_ptr, output_desc, out_backprop_ptr, - conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator, - AlgorithmConfig(profile_algorithm), &profile_result) - .ok(); - if (cudnn_launch_status) { - if (profile_result.is_valid()) { - if (profile_result.elapsed_time_in_ms() < - best_result.elapsed_time_in_ms()) { - best_result = profile_result; - } - if (scratch_allocator.TotalByteSize() == 0 && - profile_result.elapsed_time_in_ms() < - best_result_no_scratch.elapsed_time_in_ms()) { - best_result_no_scratch = profile_result; + // TODO(benbarsdell): Ideally this should not attempt using tensor op math + // if it's not enabled. + for (bool use_tensor_ops : {false, true}) { + for (auto algo_index : algorithms) { + // TODO(zhengxq): profile each algorithm multiple times to better + // accuracy. + AlgorithmDesc profile_algorithm(algo_index, use_tensor_ops); + CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize, + ctx); + ProfileResult profile_result; + bool cudnn_launch_status = + stream + ->ThenConvolveBackwardDataWithAlgorithm( + filter_desc, filter_ptr, output_desc, out_backprop_ptr, + conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator, + AlgorithmConfig(profile_algorithm), &profile_result) + .ok(); + if (cudnn_launch_status) { + if (profile_result.is_valid()) { + if (profile_result.elapsed_time_in_ms() < + best_result.elapsed_time_in_ms()) { + best_result = profile_result; + } + if (scratch_allocator.TotalByteSize() == 0 && + profile_result.elapsed_time_in_ms() < + best_result_no_scratch.elapsed_time_in_ms()) { + best_result_no_scratch = profile_result; + } } } } |