aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/conv_grad_input_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/conv_grad_input_ops.cc')
-rw-r--r--tensorflow/core/kernels/conv_grad_input_ops.cc55
1 files changed, 30 insertions, 25 deletions
diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc
index ce561aa99c..0732bf4046 100644
--- a/tensorflow/core/kernels/conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_input_ops.cc
@@ -630,7 +630,7 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
int col_stride, const Padding& padding, Tensor* in_backprop,
TensorFormat data_format) {
using perftools::gputools::dnn::AlgorithmConfig;
- using perftools::gputools::dnn::AlgorithmType;
+ using perftools::gputools::dnn::AlgorithmDesc;
using perftools::gputools::dnn::ProfileResult;
std::vector<int32> strides(4, 1);
@@ -870,34 +870,39 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
AlgorithmConfig algorithm_config;
if (cudnn_use_autotune && !AutoTuneConvBwdData::GetInstance()->Find(
conv_parameters, &algorithm_config)) {
- std::vector<AlgorithmType> algorithms;
+ std::vector<AlgorithmDesc::Index> algorithms;
CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(
conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
ProfileResult best_result;
ProfileResult best_result_no_scratch;
- for (auto profile_algorithm : algorithms) {
- // TODO(zhengxq): profile each algorithm multiple times to better
- // accuracy.
- CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
- ctx);
- ProfileResult profile_result;
- bool cudnn_launch_status =
- stream
- ->ThenConvolveBackwardDataWithAlgorithm(
- filter_desc, filter_ptr, output_desc, out_backprop_ptr,
- conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
- AlgorithmConfig(profile_algorithm), &profile_result)
- .ok();
- if (cudnn_launch_status) {
- if (profile_result.is_valid()) {
- if (profile_result.elapsed_time_in_ms() <
- best_result.elapsed_time_in_ms()) {
- best_result = profile_result;
- }
- if (scratch_allocator.TotalByteSize() == 0 &&
- profile_result.elapsed_time_in_ms() <
- best_result_no_scratch.elapsed_time_in_ms()) {
- best_result_no_scratch = profile_result;
+ // TODO(benbarsdell): Ideally this should not attempt using tensor op math
+ // if it's not enabled.
+ for (bool use_tensor_ops : {false, true}) {
+ for (auto algo_index : algorithms) {
+ // TODO(zhengxq): profile each algorithm multiple times to better
+ // accuracy.
+ AlgorithmDesc profile_algorithm(algo_index, use_tensor_ops);
+ CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
+ ctx);
+ ProfileResult profile_result;
+ bool cudnn_launch_status =
+ stream
+ ->ThenConvolveBackwardDataWithAlgorithm(
+ filter_desc, filter_ptr, output_desc, out_backprop_ptr,
+ conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
+ AlgorithmConfig(profile_algorithm), &profile_result)
+ .ok();
+ if (cudnn_launch_status) {
+ if (profile_result.is_valid()) {
+ if (profile_result.elapsed_time_in_ms() <
+ best_result.elapsed_time_in_ms()) {
+ best_result = profile_result;
+ }
+ if (scratch_allocator.TotalByteSize() == 0 &&
+ profile_result.elapsed_time_in_ms() <
+ best_result_no_scratch.elapsed_time_in_ms()) {
+ best_result_no_scratch = profile_result;
+ }
}
}
}