author     Xiaoqiang Zheng <zhengxq@google.com>              2016-05-11 09:28:36 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2016-05-11 10:32:20 -0700
commit     939ede027be73ecafcc422371afe27dceccc720d (patch)
tree       2ac4b3f35f4e0744d5993271a01185e6bcc8905e
parent     1f8fe742e11de53ccbb34d9fa540302156cb1655 (diff)
Add the autotune version for the backward passes. It is currently
disabled by default, and can be enabled through the env-var
"TF_CUDNN_USE_AUTOTUNE=1". It will eventually be turned on by default.

The following are the benchmarks with large enough changes:

Benchmark                        Base (ns)   New (ns)  Improvement
------------------------------------------------------------------
BM_ConvFloatFwdGPU_conv13          3810933    2167784       +43.1%
BM_ConvFloatFwdGPU_conv23          4173607    2450503       +41.3%
BM_ConvFloatFwdGPU_conv54         26731131    7098361       +73.4%
BM_ConvFloatBkInGPU_conv1          1496407    1039979       +30.5%
BM_ConvFloatBkInGPU_conv2          1501744     999774       +33.4%
BM_ConvFloatBkInGPU_conv12         6826426     968258       +85.8%
BM_ConvFloatBkFilterGPU_conv13     3852185    2110649       +45.2%
BM_ConvFloatBkInGPU_conv15         7011109     910837       +87.0%
BM_ConvFloatBkInGPU_conv17         2724054    1930013       +29.1%
BM_ConvFloatBkInGPU_conv18         2940634    1846089       +37.2%
BM_ConvFloatBkInGPU_conv19         2995599    1853970       +38.1%
BM_ConvFloatBkInGPU_conv22         2685772    1940984       +27.7%
BM_ConvFloatBkInGPU_conv24         2343034    1519468       +35.1%
BM_ConvFloatBkInGPU_conv27         2339471    1516779       +35.2%
BM_ConvFloatBkFilterGPU_conv28     3091452    1880773       +39.2%
BM_ConvFloatBkInGPU_conv31         1265237    1120846       +11.4%
BM_ConvFloatBkInGPU_conv46         3346414    2070659       +38.1%
BM_ConvFloatBkFilterGPU_conv52    20677347   14342254       +30.6%
BM_ConvFloatBkInGPU_conv54        13291278   10495521       +21.0%

Change: 122067373
-rw-r--r--  tensorflow/core/kernels/conv_2d.h                      6
-rw-r--r--  tensorflow/core/kernels/conv_grad_ops.cc             116
-rw-r--r--  tensorflow/core/kernels/conv_ops.cc                   81
-rw-r--r--  tensorflow/core/kernels/conv_ops_gpu.h                73
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_dnn.cc          371
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_dnn.h            12
-rw-r--r--  tensorflow/stream_executor/dnn.cc                     10
-rw-r--r--  tensorflow/stream_executor/dnn.h                      28
-rw-r--r--  tensorflow/stream_executor/stream.cc                  76
-rw-r--r--  tensorflow/stream_executor/stream.h                   22
-rw-r--r--  tensorflow/stream_executor/stream_executor_pimpl.cc   18
-rw-r--r--  tensorflow/stream_executor/stream_executor_pimpl.h     9
12 files changed, 629 insertions(+), 193 deletions(-)
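
The env-var named above is consumed in the kernels below through CudnnUseAutotune(). Its implementation is not part of this diff; a minimal sketch of such a gate, assuming a plain getenv check, might look like:

// Hypothetical sketch only: the real CudnnUseAutotune() is defined elsewhere
// in the tree and may differ.
#include <cstdlib>
#include <cstring>

bool CudnnUseAutotuneSketch() {
  // Autotune stays off unless TF_CUDNN_USE_AUTOTUNE=1 is set in the
  // environment, matching the commit description above.
  const char* flag = std::getenv("TF_CUDNN_USE_AUTOTUNE");
  return flag != nullptr && std::strcmp(flag, "1") == 0;
}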
diff --git a/tensorflow/core/kernels/conv_2d.h b/tensorflow/core/kernels/conv_2d.h
index 9bbc67520f..40ee4420bb 100644
--- a/tensorflow/core/kernels/conv_2d.h
+++ b/tensorflow/core/kernels/conv_2d.h
@@ -249,6 +249,12 @@ struct ReverseTransformFilter {
};
} // namespace functor
+
+template <class T>
+class ConvAlgorithmMap;
+
+template <>
+class ConvAlgorithmMap<Eigen::ThreadPoolDevice> {};
} // namespace tensorflow
#endif // TENSORFLOW_KERNELS_CONV_2D_H_
diff --git a/tensorflow/core/kernels/conv_grad_ops.cc b/tensorflow/core/kernels/conv_grad_ops.cc
index 84cc7017c4..0057db6967 100644
--- a/tensorflow/core/kernels/conv_grad_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_ops.cc
@@ -816,6 +816,7 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
"strides in the batch and depth dimensions."));
OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
use_cudnn_ &= CanUseCudnn();
+ cudnn_use_autotune_ = CudnnUseAutotune();
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
}
@@ -994,11 +995,62 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
);
CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
context);
+ int device_id = stream->parent()->device_ordinal();
+ ConvParameters conv_parameters = {
+ batch, // batch
+ in_depth, // in_depths
+ input_desc.height(), // in_rows
+ input_desc.width(), // in_cols
+ out_depth, // out_depths
+ filter_rows, // filter_rows
+ filter_cols, // filter_cols
+ stride_rows, // stride_rows
+ stride_cols, // stride_cols
+ padding_rows, // padding_rows
+ padding_cols, // padding_cols
+ device_id, // device_id
+ };
+ using namespace perftools::gputools::dnn;
+ AlgorithmType algorithm = kDefaultAlgorithm;
+ if (cudnn_use_autotune_ &&
+ !conv_algorithm_map_.Find(conv_parameters, &algorithm)) {
+ std::vector<AlgorithmType> algorithms;
+ CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(&algorithms));
+ ProfileResult best_result;
+ best_result.set_elapsed_time_in_ms(std::numeric_limits<float>::max());
+ for (auto profile_algorithm : algorithms) {
+ // TODO(zhengxq): profile each algorithm multiple times for better
+ // accuracy.
+ CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
+ context);
+ ProfileResult profile_result;
+ bool cudnn_launch_status =
+ stream
+ ->ThenConvolveBackwardDataWithAlgorithm(
+ filter_desc, filter_ptr, output_desc, out_backprop_ptr,
+ conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
+ profile_algorithm, &profile_result)
+ .ok();
+ if (cudnn_launch_status) {
+ if (profile_result.is_valid() &&
+ profile_result.elapsed_time_in_ms() <
+ best_result.elapsed_time_in_ms()) {
+ best_result = profile_result;
+ }
+ }
+ }
+ CHECK(best_result.is_valid() &&
+ best_result.algorithm() != kDefaultAlgorithm)
+ << "No algorithm worked!";
+ algorithm = best_result.algorithm();
+ conv_algorithm_map_.Insert(conv_parameters, algorithm);
+ }
bool cudnn_launch_status =
stream
- ->ThenConvolveBackwardDataWithScratch(
+ ->ThenConvolveBackwardDataWithAlgorithm(
filter_desc, filter_ptr, output_desc, out_backprop_ptr,
- conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator)
+ conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
+ algorithm, nullptr)
.ok();
if (!cudnn_launch_status) {
@@ -1048,6 +1100,8 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
Padding padding_;
bool use_cudnn_;
TensorFormat data_format_;
+ ConvAlgorithmMap<Device> conv_algorithm_map_;
+ bool cudnn_use_autotune_;
TF_DISALLOW_COPY_AND_ASSIGN(Conv2DSlowBackpropInputOp);
};
@@ -1071,6 +1125,7 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
"strides in the batch and depth dimensions."));
OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
use_cudnn_ &= CanUseCudnn();
+ cudnn_use_autotune_ = CudnnUseAutotune();
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
}
@@ -1267,13 +1322,64 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
static int64 ConvolveBackwardFilterScratchSize = GetCudnnWorkspaceLimit(
"TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB by default
);
+ int device_id = stream->parent()->device_ordinal();
+ ConvParameters conv_parameters = {
+ batch, // batch
+ in_depth, // in_depths
+ input_desc.height(), // in_rows
+ input_desc.width(), // in_cols
+ out_depth, // out_depths
+ filter_rows, // filter_rows
+ filter_cols, // filter_cols
+ stride_rows, // stride_rows
+ stride_cols, // stride_cols
+ padding_rows, // padding_rows
+ padding_cols, // padding_cols
+ device_id, // device_id
+ };
+ using namespace perftools::gputools::dnn;
+ AlgorithmType algorithm = kDefaultAlgorithm;
+ if (cudnn_use_autotune_ &&
+ !conv_algorithm_map_.Find(conv_parameters, &algorithm)) {
+ std::vector<AlgorithmType> algorithms;
+ CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms(&algorithms));
+ ProfileResult best_result;
+ best_result.set_elapsed_time_in_ms(std::numeric_limits<float>::max());
+ for (auto profile_algorithm : algorithms) {
+ // TODO(zhengxq): profile each algorithm multiple times for better
+ // accuracy.
+ CudnnScratchAllocator scratch_allocator(
+ ConvolveBackwardFilterScratchSize, context);
+ ProfileResult profile_result;
+ bool cudnn_launch_status =
+ stream
+ ->ThenConvolveBackwardFilterWithAlgorithm(
+ input_desc, input_ptr, output_desc, out_backprop_ptr,
+ conv_desc, filter_desc, &filter_backprop_ptr,
+ &scratch_allocator, profile_algorithm, &profile_result)
+ .ok();
+ if (cudnn_launch_status) {
+ if (profile_result.is_valid() &&
+ profile_result.elapsed_time_in_ms() <
+ best_result.elapsed_time_in_ms()) {
+ best_result = profile_result;
+ }
+ }
+ }
+ CHECK(best_result.is_valid() &&
+ best_result.algorithm() != kDefaultAlgorithm)
+ << "No algorithm worked!";
+ algorithm = best_result.algorithm();
+ conv_algorithm_map_.Insert(conv_parameters, algorithm);
+ }
CudnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
context);
bool cudnn_launch_status =
stream
- ->ThenConvolveBackwardFilterWithScratch(
+ ->ThenConvolveBackwardFilterWithAlgorithm(
input_desc, input_ptr, output_desc, out_backprop_ptr, conv_desc,
- filter_desc, &filter_backprop_ptr, &scratch_allocator)
+ filter_desc, &filter_backprop_ptr, &scratch_allocator,
+ algorithm, nullptr)
.ok();
if (!cudnn_launch_status) {
@@ -1295,6 +1401,8 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
Padding padding_;
bool use_cudnn_;
TensorFormat data_format_;
+ ConvAlgorithmMap<Device> conv_algorithm_map_;
+ bool cudnn_use_autotune_;
TF_DISALLOW_COPY_AND_ASSIGN(Conv2DSlowBackpropFilterOp);
};
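
Both backward kernels above follow the same profile-then-cache pattern: enumerate the candidate algorithms, time one launch of each, keep the fastest, and memoize it under the ConvParameters key. A distilled, self-contained sketch of that pattern (the profile callback is a stand-in for the ThenConvolve*WithAlgorithm calls; all names here are illustrative):

#include <cassert>
#include <limits>
#include <map>
#include <vector>

// Sketch of the autotune loop added above, under the assumption that
// profile() times one launch and returns false when the launch fails.
template <typename Key, typename Algo, typename ProfileFn>
Algo AutotuneOnce(const Key& key, const std::vector<Algo>& candidates,
                  std::map<Key, Algo>* cache, ProfileFn profile) {
  auto iter = cache->find(key);
  if (iter != cache->end()) {
    return iter->second;  // Reuse the result of a previous autotune run.
  }
  float best_ms = std::numeric_limits<float>::max();
  const Algo* best = nullptr;
  for (const Algo& algo : candidates) {
    float elapsed_ms = 0;
    // Failed launches are skipped silently, mirroring the "silently return
    // when profiling" behavior in the diff.
    if (profile(algo, &elapsed_ms) && elapsed_ms < best_ms) {
      best_ms = elapsed_ms;
      best = &algo;
    }
  }
  assert(best != nullptr && "No algorithm worked!");
  (*cache)[key] = *best;
  return *best;
}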
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index ccb83fff74..de8d8e784d 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -87,87 +87,6 @@ struct LaunchGeneric {
}
};
-struct ConvParameters {
- int64 batch;
- int64 in_depths;
- int64 in_rows;
- int64 in_cols;
- int64 out_depths;
- int64 filter_rows;
- int64 filter_cols;
- int64 stride_rows;
- int64 stride_cols;
- int64 padding_rows;
- int64 padding_cols;
- int device_id;
-
- bool operator==(const ConvParameters& other) const {
- return memcmp(this, &other, sizeof(ConvParameters)) == 0;
- }
-
- bool operator!=(const ConvParameters& other) const {
- return !(*this == other);
- }
-
- bool operator<(const ConvParameters& other) const {
- return memcmp(this, &other, sizeof(ConvParameters)) < 0;
- }
-};
-
-template <class T>
-class ConvAlgorithmMap;
-
-template <>
-class ConvAlgorithmMap<CPUDevice> {};
-
-#if GOOGLE_CUDA
-
-// A helper class that looks up algorithm from conv-parameters. It is heavily
-// biased toward the last-seen parameter.
-template <>
-class ConvAlgorithmMap<GPUDevice> {
- public:
- typedef perftools::gputools::dnn::AlgorithmType AlgorithmType;
-
- ConvAlgorithmMap() {}
-
- bool Find(const ConvParameters& parameters, AlgorithmType* algorithm) const {
- mutex_lock lock(mu_);
- if (algorithm_map_.empty()) {
- return false;
- }
- if (parameters != last_conv_parameters_) {
- auto iter = algorithm_map_.find(parameters);
- if (iter == algorithm_map_.end()) {
- return false;
- }
- last_conv_parameters_ = parameters;
- last_algorithm_ = iter->second;
- }
- *algorithm = last_algorithm_;
- return true;
- }
-
- void Insert(const ConvParameters& parameters, AlgorithmType algorithm) {
- mutex_lock lock(mu_);
- last_conv_parameters_ = parameters;
- last_algorithm_ = algorithm;
- algorithm_map_[parameters] = algorithm;
- }
-
- private:
- AlgorithmType FindAlgorithm(const ConvParameters& parameters);
-
- mutable mutex mu_;
- std::map<ConvParameters, AlgorithmType> algorithm_map_ GUARDED_BY(mu_);
- mutable ConvParameters last_conv_parameters_ GUARDED_BY(mu_);
- mutable AlgorithmType last_algorithm_ GUARDED_BY(mu_);
-
- TF_DISALLOW_COPY_AND_ASSIGN(ConvAlgorithmMap);
-};
-
-#endif // GOOGLE_CUDA
-
template <typename Device, typename T>
struct LaunchConvOp;
diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h
index a2c8f980d9..419ba4dfc6 100644
--- a/tensorflow/core/kernels/conv_ops_gpu.h
+++ b/tensorflow/core/kernels/conv_ops_gpu.h
@@ -81,6 +81,79 @@ class CudnnScratchAllocator : public perftools::gputools::ScratchAllocator {
std::vector<Tensor> allocated_tensors_;
};
+struct ConvParameters {
+ int64 batch;
+ int64 in_depths;
+ int64 in_rows;
+ int64 in_cols;
+ int64 out_depths;
+ int64 filter_rows;
+ int64 filter_cols;
+ int64 stride_rows;
+ int64 stride_cols;
+ int64 padding_rows;
+ int64 padding_cols;
+ int device_id;
+
+ bool operator==(const ConvParameters& other) const {
+ return memcmp(this, &other, sizeof(ConvParameters)) == 0;
+ }
+
+ bool operator!=(const ConvParameters& other) const {
+ return !(*this == other);
+ }
+
+ bool operator<(const ConvParameters& other) const {
+ return memcmp(this, &other, sizeof(ConvParameters)) < 0;
+ }
+};
+
+typedef Eigen::GpuDevice GPUDevice;
+
+// A helper class that looks up algorithm from conv-parameters. It is heavily
+// biased toward the last-seen parameter.
+template <>
+class ConvAlgorithmMap<GPUDevice> {
+ public:
+ typedef perftools::gputools::dnn::AlgorithmType AlgorithmType;
+
+ ConvAlgorithmMap() {}
+
+ bool Find(const ConvParameters& parameters, AlgorithmType* algorithm) const {
+ mutex_lock lock(mu_);
+ if (algorithm_map_.empty()) {
+ return false;
+ }
+ if (parameters != last_conv_parameters_) {
+ auto iter = algorithm_map_.find(parameters);
+ if (iter == algorithm_map_.end()) {
+ return false;
+ }
+ last_conv_parameters_ = parameters;
+ last_algorithm_ = iter->second;
+ }
+ *algorithm = last_algorithm_;
+ return true;
+ }
+
+ void Insert(const ConvParameters& parameters, AlgorithmType algorithm) {
+ mutex_lock lock(mu_);
+ last_conv_parameters_ = parameters;
+ last_algorithm_ = algorithm;
+ algorithm_map_[parameters] = algorithm;
+ }
+
+ private:
+ AlgorithmType FindAlgorithm(const ConvParameters& parameters);
+
+ mutable mutex mu_;
+ std::map<ConvParameters, AlgorithmType> algorithm_map_ GUARDED_BY(mu_);
+ mutable ConvParameters last_conv_parameters_ GUARDED_BY(mu_);
+ mutable AlgorithmType last_algorithm_ GUARDED_BY(mu_);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ConvAlgorithmMap);
+};
+
} // namespace tensorflow
#endif // GOOGLE_CUDA
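
The lookup/insert protocol for this cache, as driven by the kernels in conv_grad_ops.cc above, is roughly the following (the ConvParameters values are invented for illustration):

// Illustrative snippet; the parameter values are made up.
ConvAlgorithmMap<GPUDevice> algorithm_map;
ConvParameters params = {/*batch=*/32, /*in_depths=*/64, /*in_rows=*/56,
                         /*in_cols=*/56, /*out_depths=*/128,
                         /*filter_rows=*/3, /*filter_cols=*/3,
                         /*stride_rows=*/1, /*stride_cols=*/1,
                         /*padding_rows=*/1, /*padding_cols=*/1,
                         /*device_id=*/0};
perftools::gputools::dnn::AlgorithmType algorithm =
    perftools::gputools::dnn::kDefaultAlgorithm;
if (!algorithm_map.Find(params, &algorithm)) {
  // ... autotune as in conv_grad_ops.cc, then cache the winner:
  algorithm_map.Insert(params, algorithm);
}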
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index f35c59a82a..84d5399022 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -275,6 +275,39 @@ cudnnConvolutionFwdAlgo_t ToConvForwardAlgo(dnn::AlgorithmType algorithm) {
}
}
+cudnnConvolutionBwdDataAlgo_t ToConvBackwardDataAlgo(
+ dnn::AlgorithmType algorithm) {
+ cudnnConvolutionBwdDataAlgo_t algo = cudnnConvolutionBwdDataAlgo_t(algorithm);
+ switch (algo) {
+ case CUDNN_CONVOLUTION_BWD_DATA_ALGO_0:
+ case CUDNN_CONVOLUTION_BWD_DATA_ALGO_1:
+ case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT:
+ case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING:
+ return algo;
+ default:
+ LOG(FATAL)
+ << "Unsupported Cudnn convolution backward algorithm for data: "
+ << algorithm;
+ }
+}
+
+cudnnConvolutionBwdFilterAlgo_t ToConvBackwardFilterAlgo(
+ dnn::AlgorithmType algorithm) {
+ cudnnConvolutionBwdFilterAlgo_t algo =
+ cudnnConvolutionBwdFilterAlgo_t(algorithm);
+ switch (algo) {
+ case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0:
+ case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1:
+ case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT:
+ case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3:
+ return algo;
+ default:
+ LOG(FATAL)
+ << "Unsupported Cudnn convolution backward algorithm for filter: "
+ << algorithm;
+ }
+}
+
} // namespace
CudnnSupport::CudnnSupport(CUDAExecutor* parent)
@@ -784,7 +817,6 @@ bool CudnnSupport::DoConvolve(
// to this stream. So it could take multiple profiling measurements.
timer->Start(AsCUDAStream(stream));
}
-
status = dynload::cudnnConvolutionForward(
parent_, ToHandle(dnn_handle_),
/*alpha=*/&alpha, /*srcDesc=*/input_nd.handle(),
@@ -793,16 +825,6 @@ bool CudnnSupport::DoConvolve(
/*algo=*/algo, /*workSpace=*/scratch.opaque(),
/*workSpaceSizeInBytes=*/scratch.size(), /*beta=*/&beta,
/*destDesc=*/output_nd.handle(), /*destData=*/output_data->opaque());
-
- if (status != CUDNN_STATUS_SUCCESS) {
- if (is_profiling) {
- // Silently return when we are profiling.
- return false;
- }
- LOG(FATAL) << "failed to enqueue convolution on stream: "
- << ToString(status);
- }
-
if (is_profiling) {
timer->Stop(AsCUDAStream(stream));
output_profile_result->set_is_valid(true);
@@ -812,6 +834,15 @@ bool CudnnSupport::DoConvolve(
timer->Destroy();
}
+ if (status != CUDNN_STATUS_SUCCESS) {
+ // Silently return when we are profiling.
+ if (!is_profiling) {
+ LOG(FATAL) << "failed to enqueue convolution on stream: "
+ << ToString(status);
+ }
+ return false;
+ }
+
return true;
}
@@ -830,6 +861,32 @@ bool CudnnSupport::GetConvolveAlgorithms(
return true;
}
+bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
+ std::vector<dnn::AlgorithmType>* out_algorithms) {
+ out_algorithms->assign({
+ // clang-format off
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_0,
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
+ // clang-format on
+ });
+ return true;
+}
+
+bool CudnnSupport::GetConvolveBackwardFilterAlgorithms(
+ std::vector<dnn::AlgorithmType>* out_algorithms) {
+ out_algorithms->assign({
+ // clang-format off
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0,
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT,
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3,
+ // clang-format on
+ });
+ return true;
+}
+
bool CudnnSupport::DoConvolve(
Stream* stream, const BatchDescriptor& batch_descriptor,
const DeviceMemory<double>& input_data,
@@ -883,7 +940,8 @@ bool CudnnSupport::DoConvolveBackwardData(
const ConvolutionDescriptor& convolution_descriptor,
const BatchDescriptor& input_descriptor,
DeviceMemory<float>* backward_input_data,
- ScratchAllocator* scratch_allocator) {
+ ScratchAllocator* scratch_allocator, dnn::AlgorithmType algorithm,
+ dnn::ProfileResult* output_profile_result) {
mutex_lock lock{dnn_handle_mutex_};
auto status = dynload::cudnnSetStream(parent_, ToHandle(dnn_handle_),
AsCUDAStreamValue(stream));
@@ -937,40 +995,68 @@ bool CudnnSupport::DoConvolveBackwardData(
#endif
#if CUDNN_VERSION >= 3000
- auto get_algorithm = [&](bool specify_limit) SHARED_LOCKS_REQUIRED(
- dnn_handle_mutex_) -> cudnnConvolutionBwdDataAlgo_t {
- cudnnConvolutionBwdDataPreference_t preference =
- specify_limit ? CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT
- : CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE;
-
- auto memory_limit_bytes =
- scratch_allocator == nullptr
- ? 0
- : scratch_allocator->GetMemoryLimitInBytes(stream);
- if (memory_limit_bytes < 0) {
- memory_limit_bytes = 0;
- }
+ const bool is_profiling = output_profile_result != nullptr;
+ cudnnConvolutionBwdDataAlgo_t algo;
+ DeviceMemory<uint8> scratch;
- cudnnConvolutionBwdDataAlgo_t algo;
- cudnnStatus_t status = dynload::cudnnGetConvolutionBackwardDataAlgorithm(
- parent_, ToHandle(dnn_handle_),
- /*filterDesc=*/filter.handle(),
- /*diffDesc=*/out_back_nd.handle(),
- /*convDesc=*/conv.handle(),
- /*gradDesc=*/in_back_nd.handle(),
- /*preference=*/preference,
- /*memoryLimitInBytes=*/memory_limit_bytes,
- /*algo=*/&algo);
- CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Unable to find a suitable "
- "algorithm for doing backward "
- "filter convolution";
- return algo;
- };
+ if (algorithm == dnn::kDefaultAlgorithm) {
+ // With the default algorithm, use Cudnn's heuristics.
+ auto get_algorithm = [&](bool specify_limit) SHARED_LOCKS_REQUIRED(
+ dnn_handle_mutex_) -> cudnnConvolutionBwdDataAlgo_t {
+ cudnnConvolutionBwdDataPreference_t preference =
+ specify_limit ? CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT
+ : CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE;
+
+ auto memory_limit_bytes =
+ scratch_allocator == nullptr
+ ? 0
+ : scratch_allocator->GetMemoryLimitInBytes(stream);
+ if (memory_limit_bytes < 0) {
+ memory_limit_bytes = 0;
+ }
- auto algo = get_algorithm(/*specify_limit=*/scratch_allocator != nullptr);
+ cudnnConvolutionBwdDataAlgo_t algo_to_use;
+ cudnnStatus_t status = dynload::cudnnGetConvolutionBackwardDataAlgorithm(
+ parent_, ToHandle(dnn_handle_),
+ /*filterDesc=*/filter.handle(),
+ /*diffDesc=*/out_back_nd.handle(),
+ /*convDesc=*/conv.handle(),
+ /*gradDesc=*/in_back_nd.handle(),
+ /*preference=*/preference,
+ /*memoryLimitInBytes=*/memory_limit_bytes,
+ /*algo=*/&algo_to_use);
+ CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Unable to find a suitable "
+ "algorithm for doing backward "
+ "filter convolution";
+ return algo_to_use;
+ };
- DeviceMemory<uint8> scratch;
- if (scratch_allocator != nullptr) {
+ algo = get_algorithm(/*specify_limit=*/scratch_allocator != nullptr);
+
+ if (scratch_allocator != nullptr) {
+ size_t size_in_bytes;
+ status = dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
+ parent_, ToHandle(dnn_handle_),
+ /*filterDesc=*/filter.handle(),
+ /*diffDesc=*/out_back_nd.handle(),
+ /*convDesc=*/conv.handle(),
+ /*gradDesc=*/in_back_nd.handle(),
+ /*algo=*/algo,
+ /*sizeInBytes=*/&size_in_bytes);
+ if (status == CUDNN_STATUS_SUCCESS && size_in_bytes != 0) {
+ scratch = scratch_allocator->AllocateBytes(stream, size_in_bytes)
+ .ValueOrDie();
+ }
+ }
+
+ // If we didn't allocate any scratch space (perhaps because of failed
+ // allocation), we force a switch back to the "no workspace" algorithm.
+ if (scratch == nullptr) {
+ algo = get_algorithm(/*specify_limit=*/false);
+ }
+ } else {
+ // An algorithm has been specified.
+ algo = ToConvBackwardDataAlgo(algorithm);
size_t size_in_bytes;
status = dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
parent_, ToHandle(dnn_handle_),
@@ -980,16 +1066,37 @@ bool CudnnSupport::DoConvolveBackwardData(
/*gradDesc=*/in_back_nd.handle(),
/*algo=*/algo,
/*sizeInBytes=*/&size_in_bytes);
- if (status == CUDNN_STATUS_SUCCESS && size_in_bytes != 0) {
- scratch =
- scratch_allocator->AllocateBytes(stream, size_in_bytes).ValueOrDie();
+ if (status != CUDNN_STATUS_SUCCESS) {
+ if (is_profiling) {
+ // Silently return when we are profiling.
+ return false;
+ }
+ LOG(FATAL) << "Cannot query the size of workspace needed for the given "
+ "algorithm: "
+ << algorithm;
+ }
+ if (size_in_bytes != 0) {
+ if (scratch_allocator == nullptr) {
+ LOG(FATAL) << "An allocator must be specified when scratch memory is "
+ "needed";
+ }
+ auto allocated = scratch_allocator->AllocateBytes(stream, size_in_bytes);
+ if (is_profiling && !allocated.ok()) {
+ // Silently return when we are profiling.
+ return false;
+ }
+ scratch = allocated.ValueOrDie();
}
}
- // If we didn't allocate any scratch space (perhaps because of failed
- // allocation), we force a switch back to the "no workspace" algorithm.
- if (scratch == nullptr) {
- algo = get_algorithm(/*specify_limit=*/false);
+ std::unique_ptr<CUDATimer> timer;
+ if (is_profiling) {
+ timer.reset(new CUDATimer(parent_));
+ timer->Init();
+ // The start and stop of the timer should be as close to the Cudnn call as
+ // possible. It is still possible for other threads to issue work onto
+ // this stream, so it could take multiple profiling measurements.
+ timer->Start(AsCUDAStream(stream));
}
#if CUDNN_VERSION >= 5000
@@ -1010,9 +1117,20 @@ bool CudnnSupport::DoConvolveBackwardData(
/*beta=*/&beta,
/*gradDesc=*/in_back_nd.handle(),
/*gradData=*/backward_input_data->opaque());
+ if (is_profiling) {
+ timer->Stop(AsCUDAStream(stream));
+ output_profile_result->set_is_valid(true);
+ output_profile_result->set_algorithm(algo);
+ output_profile_result->set_elapsed_time_in_ms(
+ timer->GetElapsedMilliseconds());
+ timer->Destroy();
+ }
if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "failed to enqueue convolution on stream: "
- << ToString(status);
+ // Silently return when we are profiling.
+ if (!is_profiling) {
+ LOG(FATAL) << "failed to enqueue convolution on stream: "
+ << ToString(status);
+ }
return false;
}
return true;
@@ -1027,7 +1145,8 @@ bool CudnnSupport::DoConvolveBackwardFilter(
const dnn::ConvolutionDescriptor& convolution_descriptor,
const dnn::FilterDescriptor& filter_descriptor,
DeviceMemory<float>* backward_filter_data,
- ScratchAllocator* scratch_allocator) {
+ ScratchAllocator* scratch_allocator, dnn::AlgorithmType algorithm,
+ dnn::ProfileResult* output_profile_result) {
mutex_lock lock{dnn_handle_mutex_};
auto status = dynload::cudnnSetStream(parent_, ToHandle(dnn_handle_),
AsCUDAStreamValue(stream));
@@ -1080,59 +1199,108 @@ bool CudnnSupport::DoConvolveBackwardFilter(
#endif
#if CUDNN_VERSION >= 3000
- // Lambda that retrieves the algorithm.
- // specify_limit will occur when we have a scratch allocator and it succeeds
- // in allocating; otherwise, we'll fall back to the "no workspace" version.
- auto get_algorithm = [&](bool specify_limit) SHARED_LOCKS_REQUIRED(
- dnn_handle_mutex_) {
- cudnnConvolutionBwdFilterPreference_t preference =
- specify_limit ? CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT
- : CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE;
-
- auto memory_limit_bytes =
- scratch_allocator == nullptr
- ? 0
- : scratch_allocator->GetMemoryLimitInBytes(stream);
- if (memory_limit_bytes < 0) {
- memory_limit_bytes = 0;
- }
+ const bool is_profiling = output_profile_result != nullptr;
+ cudnnConvolutionBwdFilterAlgo_t algo;
+ DeviceMemory<uint8> scratch;
- cudnnConvolutionBwdFilterAlgo_t algo;
- cudnnStatus_t status = dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
- parent_, ToHandle(dnn_handle_),
- /*srcDesc=*/input_nd.handle(),
- /*diffDesc=*/out_back_nd.handle(),
- /*convDesc=*/conv.handle(),
- /*gradDesc=*/filter.handle(),
- /*preference=*/preference,
- /*memoryLimitInBytes=*/memory_limit_bytes,
- /*algo=*/&algo);
- CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Unable to find a suitable "
- "algorithm for doing backward "
- "filter convolution";
- return algo;
- };
+ if (algorithm == dnn::kDefaultAlgorithm) {
+ // With the default algorithm, use Cudnn's heuristics.
+
+ // Lambda that retrieves the algorithm.
+ // specify_limit will occur when we have a scratch allocator and it succeeds
+ // in allocating; otherwise, we'll fall back to the "no workspace" version.
+ auto get_algorithm = [&](bool specify_limit) SHARED_LOCKS_REQUIRED(
+ dnn_handle_mutex_) {
+ cudnnConvolutionBwdFilterPreference_t preference =
+ specify_limit ? CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT
+ : CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE;
+
+ auto memory_limit_bytes =
+ scratch_allocator == nullptr
+ ? 0
+ : scratch_allocator->GetMemoryLimitInBytes(stream);
+ if (memory_limit_bytes < 0) {
+ memory_limit_bytes = 0;
+ }
- auto algo = get_algorithm(/*specify_limit=*/scratch_allocator != nullptr);
+ cudnnConvolutionBwdFilterAlgo_t algo_to_use;
+ cudnnStatus_t status =
+ dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
+ parent_, ToHandle(dnn_handle_),
+ /*srcDesc=*/input_nd.handle(),
+ /*diffDesc=*/out_back_nd.handle(),
+ /*convDesc=*/conv.handle(),
+ /*gradDesc=*/filter.handle(),
+ /*preference=*/preference,
+ /*memoryLimitInBytes=*/memory_limit_bytes,
+ /*algo=*/&algo_to_use);
+ CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Unable to find a suitable "
+ "algorithm for doing backward "
+ "filter convolution";
+ return algo_to_use;
+ };
+
+ algo = get_algorithm(/*specify_limit=*/scratch_allocator != nullptr);
+
+ if (scratch_allocator != nullptr) {
+ size_t size_in_bytes;
+ status = dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
+ parent_, ToHandle(dnn_handle_), /*srcDesc=*/input_nd.handle(),
+ /*diffDesc=*/out_back_nd.handle(), /*convDesc=*/conv.handle(),
+ /*gradDesc=*/filter.handle(), /*algo=*/algo,
+ /*sizeInBytes=*/&size_in_bytes);
+ if (status == CUDNN_STATUS_SUCCESS && size_in_bytes != 0) {
+ scratch = scratch_allocator->AllocateBytes(stream, size_in_bytes)
+ .ValueOrDie();
+ }
+ }
+
+ // If we didn't allocate any scratch space (perhaps because of failed
+ // allocation), we force a switch back to the "no workspace" algorithm.
+ if (scratch == nullptr) {
+ algo = get_algorithm(/*specify_limit=*/false);
+ }
+ } else {
+ // An algorithm has been specified.
+ algo = ToConvBackwardFilterAlgo(algorithm);
- DeviceMemory<uint8> scratch;
- if (scratch_allocator != nullptr) {
size_t size_in_bytes;
status = dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
parent_, ToHandle(dnn_handle_), /*srcDesc=*/input_nd.handle(),
/*diffDesc=*/out_back_nd.handle(), /*convDesc=*/conv.handle(),
/*gradDesc=*/filter.handle(), /*algo=*/algo,
/*sizeInBytes=*/&size_in_bytes);
- if (status == CUDNN_STATUS_SUCCESS && size_in_bytes != 0) {
- scratch =
- scratch_allocator->AllocateBytes(stream, size_in_bytes).ValueOrDie();
+ if (status != CUDNN_STATUS_SUCCESS) {
+ if (is_profiling) {
+ // Silently return when we are profiling.
+ return false;
+ }
+ LOG(FATAL) << "Cannot query the size of workspace needed for the given "
+ "algorithm: "
+ << algorithm;
+ }
+ if (size_in_bytes != 0) {
+ if (scratch_allocator == nullptr) {
+ LOG(FATAL) << "An allocator must be specified when scratch memory is "
+ "needed";
+ }
+ auto allocated = scratch_allocator->AllocateBytes(stream, size_in_bytes);
+ if (is_profiling && !allocated.ok()) {
+ // Silently return when we are profiling.
+ return false;
+ }
+ scratch = allocated.ValueOrDie();
}
}
- // If we didn't allocate any scratch space (perhaps because of failed
- // allocation), we force a switch back to the "no workspace" algorithm.
- if (scratch == nullptr) {
- algo = get_algorithm(/*specify_limit=*/false);
+ std::unique_ptr<CUDATimer> timer;
+ if (is_profiling) {
+ timer.reset(new CUDATimer(parent_));
+ timer->Init();
+ // The start and stop of the timer should be as close to the Cudnn call as
+ // possible. It is still possible for other threads to issue work onto
+ // this stream, so it could take multiple profiling measurements.
+ timer->Start(AsCUDAStream(stream));
}
#if CUDNN_VERSION >= 5000
@@ -1152,9 +1320,20 @@ bool CudnnSupport::DoConvolveBackwardFilter(
/*beta=*/&beta,
/*gradDesc=*/filter.handle(),
/*gradData=*/backward_filter_data->opaque());
+ if (is_profiling) {
+ timer->Stop(AsCUDAStream(stream));
+ output_profile_result->set_is_valid(true);
+ output_profile_result->set_algorithm(algo);
+ output_profile_result->set_elapsed_time_in_ms(
+ timer->GetElapsedMilliseconds());
+ timer->Destroy();
+ }
if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "failed to enqueue convolution on stream: "
- << ToString(status);
+ // Silently return when we are profiling.
+ if (!is_profiling) {
+ LOG(FATAL) << "failed to enqueue convolution on stream: "
+ << ToString(status);
+ }
return false;
}
return true;
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h
index 76af118962..f034ca3f48 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.h
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.h
@@ -47,6 +47,12 @@ class CudnnSupport : public dnn::DnnSupport {
bool GetConvolveAlgorithms(
std::vector<dnn::AlgorithmType>* out_algorithms) override;
+ bool GetConvolveBackwardDataAlgorithms(
+ std::vector<dnn::AlgorithmType>* out_algorithms) override;
+
+ bool GetConvolveBackwardFilterAlgorithms(
+ std::vector<dnn::AlgorithmType>* out_algorithms) override;
+
bool DoConvolve(Stream* stream, const dnn::BatchDescriptor& input_descriptor,
const DeviceMemory<float>& input_data,
const dnn::FilterDescriptor& filter_descriptor,
@@ -87,7 +93,8 @@ class CudnnSupport : public dnn::DnnSupport {
const dnn::ConvolutionDescriptor& convolution_descriptor,
const dnn::BatchDescriptor& input_descriptor,
DeviceMemory<float>* backward_input_data,
- ScratchAllocator* scratch_allocator) override;
+ ScratchAllocator* scratch_allocator, dnn::AlgorithmType algorithm,
+ dnn::ProfileResult* output_profile_result) override;
bool DoConvolveBackwardFilter(
Stream* stream, const dnn::BatchDescriptor& input_descriptor,
@@ -97,7 +104,8 @@ class CudnnSupport : public dnn::DnnSupport {
const dnn::ConvolutionDescriptor& convolution_descriptor,
const dnn::FilterDescriptor& filter_descriptor,
DeviceMemory<float>* backward_filter_data,
- ScratchAllocator* scratch_allocator) override;
+ ScratchAllocator* scratch_allocator, dnn::AlgorithmType algorithm,
+ dnn::ProfileResult* output_profile_result) override;
bool DoMatMul(Stream* stream, const DeviceMemory<float>& input_data,
const DeviceMemory<float>& weights,
diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc
index 7e98db723f..3865f85e7e 100644
--- a/tensorflow/stream_executor/dnn.cc
+++ b/tensorflow/stream_executor/dnn.cc
@@ -27,6 +27,16 @@ bool DnnSupport::GetConvolveAlgorithms(
return false;
}
+bool DnnSupport::GetConvolveBackwardDataAlgorithms(
+ std::vector<AlgorithmType>* out_algorithms) {
+ return false;
+}
+
+bool DnnSupport::GetConvolveBackwardFilterAlgorithms(
+ std::vector<AlgorithmType>* out_algorithms) {
+ return false;
+}
+
string QuantizedActivationModeString(QuantizedActivationMode mode) {
switch (mode) {
case dnn::QuantizedActivationMode::k8Bit:
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index 8cba8295db..8db80544de 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -559,14 +559,14 @@ class ProfileResult {
// input, across all coordinates (batch, y, x), by mapping each V to
// another vector U of the same size using the formula
//
-// V_i = U_i / ((bias + alpha * (sum_j U_j^2)) ^ beta)
+// U_i = V_i / ((bias + alpha * (sum_j V_j^2)) ^ beta)
//
-// where the sum is taken for j in the inclusive range [i - range, i + range].
+// where the sum is taken over j in the closed range [i - range, i + range].
//
-// When calculating V_i the j in the sum can extend beyond the bounds
-// of U. If wrap_around is true, then U_j = U_{j mod F} where F is the
-// size of U, which is the number of feature maps. If wrap_around is
-// false, then U_j = 0 for j outside [0, F-1].
+// When calculating U_i the j in the sum can extend beyond the bounds
+// of V. If wrap_around is true, then V_j = V_{j mod F} where F is the
+// size of V, which is the number of feature maps. If wrap_around is
+// false, then V_j = 0 for j outside [0, F-1].
//
// If segment_size <= F, where F is the number of feature_maps, then
// segment_size has no effect. Otherwise, each consecutive segment of
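
Rendered in LaTeX, the corrected normalization in the comment above (the removed lines had U and V transposed) reads:

U_i = \frac{V_i}{\left(\mathrm{bias} + \alpha \sum_{j=i-\mathrm{range}}^{i+\mathrm{range}} V_j^2\right)^{\beta}}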
@@ -769,7 +769,13 @@ class DnnSupport {
const ConvolutionDescriptor& convolution_descriptor,
const BatchDescriptor& input_descriptor,
DeviceMemory<float>* backward_input_data,
- ScratchAllocator* scratch_allocator) = 0;
+ ScratchAllocator* scratch_allocator, AlgorithmType algorithm,
+ ProfileResult* output_profile_result) = 0;
+
+ // Return a list of algorithms supported by the backward convolution pass for
+ // data.
+ virtual bool GetConvolveBackwardDataAlgorithms(
+ std::vector<AlgorithmType>* out_algorithms);
// Enqueues a single-precision backward convolution (for filter) operation
// onto the stream.
@@ -798,7 +804,13 @@ class DnnSupport {
const ConvolutionDescriptor& convolution_descriptor,
const FilterDescriptor& filter_descriptor,
DeviceMemory<float>* backward_filter_data,
- ScratchAllocator* scratch_allocator) = 0;
+ ScratchAllocator* scratch_allocator, AlgorithmType algorithm,
+ ProfileResult* output_profile_result) = 0;
+
+ // Return a list of algorithms supported by the backward convolution pass for
+ // filters.
+ virtual bool GetConvolveBackwardFilterAlgorithms(
+ std::vector<AlgorithmType>* out_algorithms);
// Fully connects the "nodes" (float values) in input_data with
// shape input_dimensions to output_data with output_dimensions
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index b02df02c90..5b07f13037 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -403,7 +403,43 @@ Stream &Stream::ThenConvolveBackwardDataWithScratch(
CheckError(dnn->DoConvolveBackwardData(
this, filter_descriptor, filter_data, output_descriptor,
backward_output_data, convolution_descriptor, input_descriptor,
- backward_input_data, scratch_allocator));
+ backward_input_data, scratch_allocator, dnn::kDefaultAlgorithm,
+ nullptr));
+ } else {
+ SetError();
+ LOG(WARNING)
+ << "attempting to perform DNN operation using StreamExecutor "
+ "without DNN support";
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenConvolveBackwardDataWithAlgorithm(
+ const dnn::FilterDescriptor &filter_descriptor,
+ const DeviceMemory<float> &filter_data,
+ const dnn::BatchDescriptor &output_descriptor,
+ DeviceMemory<float> backward_output_data,
+ const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const dnn::BatchDescriptor &input_descriptor,
+ DeviceMemory<float> *backward_input_data,
+ ScratchAllocator *scratch_allocator, dnn::AlgorithmType algorithm,
+ dnn::ProfileResult *output_profile_result) {
+ VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
+ PARAM(output_descriptor), PARAM(backward_output_data),
+ PARAM(convolution_descriptor), PARAM(input_descriptor),
+ PARAM(backward_input_data));
+
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ auto status = dnn->DoConvolveBackwardData(
+ this, filter_descriptor, filter_data, output_descriptor,
+ backward_output_data, convolution_descriptor, input_descriptor,
+ backward_input_data, scratch_allocator, algorithm,
+ output_profile_result);
+ if (!status && !output_profile_result) {
+ SetError();
+ }
} else {
SetError();
LOG(WARNING)
@@ -447,7 +483,43 @@ Stream &Stream::ThenConvolveBackwardFilterWithScratch(
CheckError(dnn->DoConvolveBackwardFilter(
this, input_descriptor, input_data, output_descriptor,
backward_output_data, convolution_descriptor, filter_descriptor,
- backward_filter_data, scratch_allocator));
+ backward_filter_data, scratch_allocator, dnn::kDefaultAlgorithm,
+ nullptr));
+ } else {
+ SetError();
+ LOG(WARNING)
+ << "attempting to perform DNN operation using StreamExecutor "
+ "without DNN support";
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm(
+ const dnn::BatchDescriptor &input_descriptor,
+ const DeviceMemory<float> &input_data,
+ const dnn::BatchDescriptor &output_descriptor,
+ DeviceMemory<float> backward_output_data,
+ const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const dnn::FilterDescriptor &filter_descriptor,
+ DeviceMemory<float> *backward_filter_data,
+ ScratchAllocator *scratch_allocator, dnn::AlgorithmType algorithm,
+ dnn::ProfileResult *output_profile_result) {
+ VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
+ PARAM(output_descriptor), PARAM(backward_output_data),
+ PARAM(convolution_descriptor), PARAM(filter_descriptor),
+ PARAM(backward_filter_data));
+
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ auto status = dnn->DoConvolveBackwardFilter(
+ this, input_descriptor, input_data, output_descriptor,
+ backward_output_data, convolution_descriptor, filter_descriptor,
+ backward_filter_data, scratch_allocator, algorithm,
+ output_profile_result);
+ if (!status && !output_profile_result) {
+ SetError();
+ }
} else {
SetError();
LOG(WARNING)
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index b800e03ae7..518228b0f2 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -269,6 +269,17 @@ class Stream {
DeviceMemory<float> *backward_input_data,
ScratchAllocator *scratch_allocator);
+ Stream &ThenConvolveBackwardDataWithAlgorithm(
+ const dnn::FilterDescriptor &filter_descriptor,
+ const DeviceMemory<float> &filter_data,
+ const dnn::BatchDescriptor &output_descriptor,
+ DeviceMemory<float> backward_output_data,
+ const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const dnn::BatchDescriptor &input_descriptor,
+ DeviceMemory<float> *backward_input_data,
+ ScratchAllocator *scratch_allocator, dnn::AlgorithmType algorithm,
+ dnn::ProfileResult *output_profile_result);
+
Stream &ThenConvolveBackwardFilter(
const dnn::BatchDescriptor &input_descriptor,
const DeviceMemory<float> &input_data,
@@ -288,6 +299,17 @@ class Stream {
DeviceMemory<float> *backward_filter_data,
ScratchAllocator *scratch_allocator);
+ Stream &ThenConvolveBackwardFilterWithAlgorithm(
+ const dnn::BatchDescriptor &input_descriptor,
+ const DeviceMemory<float> &input_data,
+ const dnn::BatchDescriptor &output_descriptor,
+ DeviceMemory<float> backward_output_data,
+ const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const dnn::FilterDescriptor &filter_descriptor,
+ DeviceMemory<float> *backward_filter_data,
+ ScratchAllocator *scratch_allocator, dnn::AlgorithmType algorithm,
+ dnn::ProfileResult *output_profile_result);
+
Stream &ThenMatMul(const DeviceMemory<float> &input_data,
const DeviceMemory<float> &weights,
const dnn::BatchDescriptor &input_dimensions,
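
A caller-side sketch of the new profiling entry points, modeled on the kernel changes above (the stream, descriptors, device memory, and scratch allocator are assumed to already exist):

// Sketch: profile one candidate algorithm. Passing a ProfileResult puts the
// call in profiling mode; passing nullptr requests a normal launch.
perftools::gputools::dnn::ProfileResult profile_result;
bool launched =
    stream
        ->ThenConvolveBackwardDataWithAlgorithm(
            filter_desc, filter_ptr, output_desc, out_backprop_ptr,
            conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
            candidate_algorithm, &profile_result)
        .ok();
if (launched && profile_result.is_valid()) {
  // Compare against the best time seen so far, as in conv_grad_ops.cc.
  float elapsed_ms = profile_result.elapsed_time_in_ms();
  (void)elapsed_ms;
}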
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index fe32039d71..5e55169613 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -295,6 +295,24 @@ bool StreamExecutor::GetConvolveAlgorithms(
return dnn_support->GetConvolveAlgorithms(out_algorithms);
}
+bool StreamExecutor::GetConvolveBackwardDataAlgorithms(
+ std::vector<dnn::AlgorithmType> *out_algorithms) {
+ dnn::DnnSupport *dnn_support = AsDnn();
+ if (!dnn_support) {
+ return false;
+ }
+ return dnn_support->GetConvolveBackwardDataAlgorithms(out_algorithms);
+}
+
+bool StreamExecutor::GetConvolveBackwardFilterAlgorithms(
+ std::vector<dnn::AlgorithmType> *out_algorithms) {
+ dnn::DnnSupport *dnn_support = AsDnn();
+ if (!dnn_support) {
+ return false;
+ }
+ return dnn_support->GetConvolveBackwardFilterAlgorithms(out_algorithms);
+}
+
dnn::DnnSupport *StreamExecutor::AsDnn() {
mutex_lock lock{mu_};
if (dnn_ != nullptr) {
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index 31b110a8e0..e424411143 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -341,6 +341,15 @@ class StreamExecutor {
// Get the list of supported algorithms for the forward convolution operation.
bool GetConvolveAlgorithms(std::vector<dnn::AlgorithmType> *out_algorithms);
+ // Get the list of supported algorithms for the backward convolution on data.
+ bool GetConvolveBackwardDataAlgorithms(
+ std::vector<dnn::AlgorithmType> *out_algorithms);
+
+ // Get the list of supported algorithms for the backward convolution on the
+ // filter.
+ bool GetConvolveBackwardFilterAlgorithms(
+ std::vector<dnn::AlgorithmType> *out_algorithms);
+
// Returns the device ordinal that this StreamExecutor was initialized with.
// Meaningless before initialization.
int device_ordinal() const { return device_ordinal_; }
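
Querying the candidate lists from an executor, as the kernels above do (the call returns false when no DNN support, e.g. cuDNN, is available):

std::vector<perftools::gputools::dnn::AlgorithmType> algorithms;
if (stream->parent()->GetConvolveBackwardDataAlgorithms(&algorithms)) {
  for (auto algorithm : algorithms) {
    // ... profile each candidate as in conv_grad_ops.cc above.
  }
}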