about summary refs log tree commit diff homepage
path: root/tensorflow/stream_executor/cuda/cuda_dnn.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/stream_executor/cuda/cuda_dnn.cc')
-rw-r--r-- tensorflow/stream_executor/cuda/cuda_dnn.cc 192
1 file changed, 138 insertions, 54 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 08faeefe74..087ae556e7 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include "third_party/eigen3/Eigen/Core"
+#include "tensorflow/core/util/env_var.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
#include "tensorflow/stream_executor/cuda/cuda_driver.h"
@@ -231,6 +232,7 @@ CUDNN_DNN_ROUTINE_EACH_R3(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
__macro(cudnnRNNBackwardData) \
__macro(cudnnRNNBackwardWeights) \
__macro(cudnnSetRNNDescriptor) \
+ __macro(cudnnSetRNNDescriptor_v6) \
__macro(cudnnGetFilterNdDescriptor)
// clang-format on
@@ -250,6 +252,17 @@ CUDNN_DNN_ROUTINE_EACH_R6(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
#undef CUDNN_DNN_ROUTINE_EACH_R6
#endif
+// APIs in R7
+// clang-format off
+#if CUDNN_VERSION >= 7000
+#define CUDNN_DNN_ROUTINE_EACH_R7(__macro) \
+ __macro(cudnnSetConvolutionMathType)
+
+// clang-format on
+CUDNN_DNN_ROUTINE_EACH_R7(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
+#undef CUDNN_DNN_ROUTINE_EACH_R7
+#endif
+
#undef CUDNN_DNN_ROUTINE_EACH
} // namespace wrap
@@ -260,8 +273,9 @@ cudnnHandle_t ToHandle(void* opaque_handle) {
return static_cast<cudnnHandle_t>(opaque_handle);
}
-cudnnConvolutionFwdAlgo_t ToConvForwardAlgo(dnn::AlgorithmType algorithm) {
- cudnnConvolutionFwdAlgo_t algo = cudnnConvolutionFwdAlgo_t(algorithm);
+cudnnConvolutionFwdAlgo_t ToConvForwardAlgo(dnn::AlgorithmDesc algorithm) {
+ cudnnConvolutionFwdAlgo_t algo =
+ cudnnConvolutionFwdAlgo_t(algorithm.algo_id());
switch (algo) {
case CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM:
case CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM:
@@ -278,13 +292,14 @@ cudnnConvolutionFwdAlgo_t ToConvForwardAlgo(dnn::AlgorithmType algorithm) {
return algo;
default:
LOG(FATAL) << "Unsupported Cudnn convolution forward algorithm: "
- << algorithm;
+ << algorithm.algo_id();
}
}
cudnnConvolutionBwdDataAlgo_t ToConvBackwardDataAlgo(
- dnn::AlgorithmType algorithm) {
- cudnnConvolutionBwdDataAlgo_t algo = cudnnConvolutionBwdDataAlgo_t(algorithm);
+ dnn::AlgorithmDesc algorithm) {
+ cudnnConvolutionBwdDataAlgo_t algo =
+ cudnnConvolutionBwdDataAlgo_t(algorithm.algo_id());
switch (algo) {
case CUDNN_CONVOLUTION_BWD_DATA_ALGO_0:
case CUDNN_CONVOLUTION_BWD_DATA_ALGO_1:
@@ -300,14 +315,14 @@ cudnnConvolutionBwdDataAlgo_t ToConvBackwardDataAlgo(
default:
LOG(FATAL)
<< "Unsupported Cudnn convolution backward algorithm for data: "
- << algorithm;
+ << algorithm.algo_id();
}
}
cudnnConvolutionBwdFilterAlgo_t ToConvBackwardFilterAlgo(
- dnn::AlgorithmType algorithm) {
+ dnn::AlgorithmDesc algorithm) {
cudnnConvolutionBwdFilterAlgo_t algo =
- cudnnConvolutionBwdFilterAlgo_t(algorithm);
+ cudnnConvolutionBwdFilterAlgo_t(algorithm.algo_id());
switch (algo) {
case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0:
case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1:
@@ -322,7 +337,7 @@ cudnnConvolutionBwdFilterAlgo_t ToConvBackwardFilterAlgo(
default:
LOG(FATAL)
<< "Unsupported Cudnn convolution backward algorithm for filter: "
- << algorithm;
+ << algorithm.algo_id();
}
}
@@ -541,6 +556,17 @@ class ScopedFilterDescriptor {
SE_DISALLOW_COPY_AND_ASSIGN(ScopedFilterDescriptor);
};
+// A helper function to decide whether to enable the TENSOR_OP_MATH math type
+static bool TensorOpMathEnabled() {
+ static bool is_enabled = [] {
+ bool ret;
+ TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DISABLE_TENSOR_OP_MATH",
+ /*default=*/false, &ret));
+ return ret;
+ }();
+ return is_enabled;
+}
+
// Turns a ConvolutionDescriptor structure into a cudnn convolution handle
// within a scope.
class ScopedConvolutionDescriptor {
@@ -583,6 +609,24 @@ class ScopedConvolutionDescriptor {
LOG(FATAL) << "could not set cudnn convolution descriptor: "
<< ToString(status);
}
+ // NOTE(benbarsdell): This only applies if tensor op math is enabled
+ // and algo selection is set to Default.
+ this->set_use_tensor_op_math(true);
+ }
+
+ void set_use_tensor_op_math(bool use_tensor_op_math) {
+#if CUDNN_VERSION >= 7000
+ cudnnMathType_t math_type =
+ (use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH);
+ if (TensorOpMathEnabled()) {
+ cudnnStatus_t status =
+ wrap::cudnnSetConvolutionMathType(parent_, handle_, math_type);
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(FATAL) << "could not set cudnn convolution math type: "
+ << ToString(status);
+ }
+ }
+#endif
}
~ScopedConvolutionDescriptor() {
@@ -1010,11 +1054,21 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
// Create the RNN handle
cudnnStatus_t status = wrap::cudnnCreateRNNDescriptor(parent_, &rnn_desc_);
CUDNN_RETURN_IF_FAIL(status, "Unable to create RNN descriptor");
+#if CUDNN_VERSION >= 6000
+ // TODO: allow the user to choose an algorithm.
+ cudnnRNNAlgo_t rnn_algo = CUDNN_RNN_ALGO_STANDARD;
+ status = wrap::cudnnSetRNNDescriptor_v6(
+ parent, cudnn_handle, rnn_desc_ /*rnnDesc*/, hidden_size /*hiddenSize*/,
+ num_layers /*numLayers*/, dropout_handle() /*dropoutDesc*/,
+ input_mode /*inputMode*/, direction_mode /*direction*/,
+ rnn_mode /*mode*/, rnn_algo /*algo*/, data_type /*dataType*/);
+#else
status = wrap::cudnnSetRNNDescriptor(
parent, rnn_desc_ /*rnnDesc*/, hidden_size /*hiddenSize*/,
num_layers /*numLayers*/, dropout_handle() /*dropoutDesc*/,
input_mode /*inputMode*/, direction_mode /*direction*/,
rnn_mode /*mode*/, data_type /*dataType*/);
+#endif
CUDNN_RETURN_IF_FAIL(status, "Unable to update RNN descriptor");
// Create the params handle.
@@ -1943,7 +1997,7 @@ inline cudnnConvolutionFwdAlgo_t GetCudnnConvolutionForwardAlgo(
return algo_to_use;
}
-dnn::AlgorithmType GetCudnnConvolutionForwardAlgorithm(
+dnn::AlgorithmDesc GetCudnnConvolutionForwardAlgorithm(
Stream* stream, CUDAExecutor* parent, void* dnn_handle,
int cudnn_type, // Actually cudnnDataType_t.
const dnn::AlgorithmConfig& algorithm_config, bool is_profiling,
@@ -1952,13 +2006,18 @@ dnn::AlgorithmType GetCudnnConvolutionForwardAlgorithm(
const ScopedConvolutionDescriptor& conv,
const ScopedTensorDescriptor& output_nd,
ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch) {
- cudnnConvolutionFwdAlgo_t algo =
- (algorithm_config.algorithm() == dnn::kDefaultAlgorithm)
- ? GetCudnnConvolutionForwardAlgo(
- stream, parent, dnn_handle, input_nd, filter, conv, output_nd,
- /*specify_workspace_limit=*/scratch_allocator != nullptr,
- scratch_allocator)
- : ToConvForwardAlgo(algorithm_config.algorithm());
+ cudnnConvolutionFwdAlgo_t algo;
+ bool use_tensor_ops;
+ if (algorithm_config.algorithm().is_default()) {
+ use_tensor_ops = true;
+ algo = GetCudnnConvolutionForwardAlgo(
+ stream, parent, dnn_handle, input_nd, filter, conv, output_nd,
+ /*specify_workspace_limit=*/scratch_allocator != nullptr,
+ scratch_allocator);
+ } else {
+ use_tensor_ops = algorithm_config.algorithm().tensor_ops_enabled();
+ algo = ToConvForwardAlgo(algorithm_config.algorithm());
+ }
size_t size_in_bytes;
auto status = wrap::cudnnGetConvolutionForwardWorkspaceSize(
parent, ToHandle(dnn_handle), /*srcDesc=*/input_nd.handle(),
@@ -1969,16 +2028,16 @@ dnn::AlgorithmType GetCudnnConvolutionForwardAlgorithm(
if (TF_PREDICT_FALSE(status != CUDNN_STATUS_SUCCESS)) {
CHECK(is_profiling) << "Cannot query the size of workspace needed "
"for the specified algorithm: "
- << algorithm_config.algorithm() << " "
+ << algorithm_config.algorithm().algo_id() << " "
<< ToString(status);
// Silently return when we are profiling.
- return dnn::kNoSuitableAlgorithmFound;
+ return dnn::AlgorithmDesc();
}
if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
LOG(WARNING) << "cudnnGetConvolutionForwardWorkspaceSize() returned "
"negative sizeInBytes value. This could be a cudnn bug.";
if (TF_PREDICT_TRUE(is_profiling)) {
- return dnn::kNoSuitableAlgorithmFound;
+ return dnn::AlgorithmDesc();
}
} else if (size_in_bytes_int64 > 0) {
port::StatusOr<DeviceMemory<uint8>> allocated;
@@ -1989,26 +2048,30 @@ dnn::AlgorithmType GetCudnnConvolutionForwardAlgorithm(
} else {
if (TF_PREDICT_TRUE(is_profiling)) {
// Silently return when we are profiling.
- return dnn::kNoSuitableAlgorithmFound;
+ return dnn::AlgorithmDesc();
}
LOG(WARNING) << allocated.status().error_message();
// For the int8 case, we fail at this point since the no_scratch
// algorithm should be set to dnn::kDefaultAlgorithm.
- CHECK(algorithm_config.algorithm_no_scratch() != dnn::kDefaultAlgorithm)
+ CHECK(!algorithm_config.algorithm_no_scratch().is_default())
<< "The primary convolution algorithm failed memory allocation, "
"while a secondary algorithm is not provided.";
}
}
if (TF_PREDICT_FALSE(!allocated.ok())) {
- algo = (algorithm_config.algorithm_no_scratch() == dnn::kDefaultAlgorithm)
- ? GetCudnnConvolutionForwardAlgo(
- stream, parent, dnn_handle, input_nd, filter, conv,
- output_nd, /*specify_workspace_limit=*/false, nullptr)
- : ToConvForwardAlgo(algorithm_config.algorithm_no_scratch());
+ if (algorithm_config.algorithm_no_scratch().is_default()) {
+ use_tensor_ops = true;
+ algo = GetCudnnConvolutionForwardAlgo(
+ stream, parent, dnn_handle, input_nd, filter, conv, output_nd,
+ /*specify_workspace_limit=*/false, nullptr);
+ } else {
+ use_tensor_ops = algorithm_config.algorithm().tensor_ops_enabled();
+ algo = ToConvForwardAlgo(algorithm_config.algorithm_no_scratch());
+ }
}
}
- return algo;
+ return dnn::AlgorithmDesc(algo, use_tensor_ops);
}
} // namespace
@@ -2050,11 +2113,12 @@ bool CudnnSupport::DoConvolveImpl(
const bool is_profiling = output_profile_result != nullptr;
cudnnConvolutionFwdAlgo_t algo;
+ bool use_tensor_ops;
DeviceMemory<uint8> scratch;
// TODO(pauldonnelly): Replace the following code with a call to
// GetCudnnConvolutionForwardAlgorithm().
- if (algorithm_config.algorithm() == dnn::kDefaultAlgorithm) {
+ if (algorithm_config.algorithm().is_default()) {
// With the default algorithm, use Cudnn's heuristics.
auto get_algorithm =
[&](bool specify_limit) SHARED_LOCKS_REQUIRED(dnn_handle_mutex_) {
@@ -2085,6 +2149,7 @@ bool CudnnSupport::DoConvolveImpl(
};
algo = get_algorithm(/*specify_limit=*/scratch_allocator != nullptr);
+ use_tensor_ops = true;
if (scratch_allocator != nullptr) {
size_t size_in_bytes;
status = wrap::cudnnGetConvolutionForwardWorkspaceSize(
@@ -2117,7 +2182,10 @@ bool CudnnSupport::DoConvolveImpl(
}
} else {
// An algorithm has been specified.
- algo = ToConvForwardAlgo(algorithm_config.algorithm());
+ dnn::AlgorithmDesc algotype = algorithm_config.algorithm();
+ algo = ToConvForwardAlgo(algotype);
+ use_tensor_ops = algotype.tensor_ops_enabled();
+ conv.set_use_tensor_op_math(use_tensor_ops);
size_t size_in_bytes;
status = wrap::cudnnGetConvolutionForwardWorkspaceSize(
parent_, ToHandle(dnn_handle_), /*srcDesc=*/input_nd.handle(),
@@ -2131,7 +2199,7 @@ bool CudnnSupport::DoConvolveImpl(
}
LOG(FATAL) << "Cannot query the size of workspace needed for the given "
"algorithm: "
- << algorithm_config.algorithm();
+ << algorithm_config.algorithm().algo_id();
}
int64 size_in_bytes_int64 = size_in_bytes;
if (size_in_bytes_int64 > 0) {
@@ -2150,10 +2218,13 @@ bool CudnnSupport::DoConvolveImpl(
LOG(WARNING) << allocated.status().error_message();
}
if (scratch == nullptr) {
- CHECK(algorithm_config.algorithm_no_scratch() != dnn::kDefaultAlgorithm)
+ CHECK(!algorithm_config.algorithm_no_scratch().is_default())
<< "The primary convolution algorithm failed memory allocation, "
"while a secondary algorithm is not provided.";
- algo = ToConvForwardAlgo(algorithm_config.algorithm_no_scratch());
+ dnn::AlgorithmDesc algotype = algorithm_config.algorithm_no_scratch();
+ algo = ToConvForwardAlgo(algotype);
+ use_tensor_ops = algotype.tensor_ops_enabled();
+ conv.set_use_tensor_op_math(use_tensor_ops);
}
} else if (size_in_bytes_int64 < 0) {
LOG(WARNING) << "cudnnGetConvolutionForwardWorkspaceSize() returned "
@@ -2189,7 +2260,8 @@ bool CudnnSupport::DoConvolveImpl(
return false;
}
if (status == CUDNN_STATUS_SUCCESS) {
- output_profile_result->set_algorithm(algo);
+ dnn::AlgorithmDesc algotype(algo, use_tensor_ops);
+ output_profile_result->set_algorithm(algotype);
output_profile_result->set_elapsed_time_in_ms(
timer->GetElapsedMilliseconds());
}
@@ -2250,17 +2322,18 @@ bool CudnnSupport::DoFusedConvolveImpl(
const bool is_profiling = output_profile_result != nullptr;
DeviceMemory<uint8> scratch;
- dnn::AlgorithmType algorithm_type = GetCudnnConvolutionForwardAlgorithm(
+ dnn::AlgorithmDesc algotype = GetCudnnConvolutionForwardAlgorithm(
stream, parent_, dnn_handle_, cudnn_data_type, algorithm_config,
is_profiling, conv_input_nd, filter, conv, output_nd, scratch_allocator,
&scratch);
- if (algorithm_type == dnn::kNoSuitableAlgorithmFound) {
+ if (algotype.is_default()) {
if (!is_profiling) {
LOG(ERROR) << "No suitable algorithm found";
}
return false;
}
- auto algo = static_cast<cudnnConvolutionFwdAlgo_t>(algorithm_type);
+ auto algo = static_cast<cudnnConvolutionFwdAlgo_t>(algotype.algo_id());
+ conv.set_use_tensor_op_math(algotype.tensor_ops_enabled());
if (activation_mode != dnn::ActivationMode::kRelu) {
LOG(ERROR) << "cudnnConvolutionBiasActivationForward() only supports Relu "
@@ -2326,7 +2399,7 @@ bool CudnnSupport::DoFusedConvolveImpl(
return false;
}
if (status == CUDNN_STATUS_SUCCESS) {
- output_profile_result->set_algorithm(algo);
+ output_profile_result->set_algorithm(algotype);
output_profile_result->set_elapsed_time_in_ms(
timer->GetElapsedMilliseconds());
}
@@ -2397,7 +2470,7 @@ struct WinogradNonfused {
bool CudnnSupport::GetConvolveAlgorithms(
bool with_winograd_nonfused,
- std::vector<dnn::AlgorithmType>* out_algorithms) {
+ std::vector<dnn::AlgorithmDesc::Index>* out_algorithms) {
out_algorithms->assign({
// clang-format off
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
@@ -2423,7 +2496,7 @@ bool CudnnSupport::GetConvolveAlgorithms(
bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
bool with_winograd_nonfused,
- std::vector<dnn::AlgorithmType>* out_algorithms) {
+ std::vector<dnn::AlgorithmDesc::Index>* out_algorithms) {
out_algorithms->assign({
// clang-format off
CUDNN_CONVOLUTION_BWD_DATA_ALGO_0,
@@ -2446,7 +2519,7 @@ bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
bool CudnnSupport::GetConvolveBackwardFilterAlgorithms(
bool with_winograd_nonfused,
- std::vector<dnn::AlgorithmType>* out_algorithms) {
+ std::vector<dnn::AlgorithmDesc::Index>* out_algorithms) {
out_algorithms->assign({
// clang-format off
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0,
@@ -2858,7 +2931,7 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
cudnnConvolutionBwdDataAlgo_t algo;
DeviceMemory<uint8> scratch;
- if (algorithm_config.algorithm() == dnn::kDefaultAlgorithm) {
+ if (algorithm_config.algorithm().is_default()) {
// With the default algorithm, use Cudnn's heuristics.
auto get_algorithm = [&](bool specify_limit) SHARED_LOCKS_REQUIRED(
dnn_handle_mutex_) -> cudnnConvolutionBwdDataAlgo_t {
@@ -2927,7 +3000,9 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
}
} else {
// An algorithm has been specified.
- algo = ToConvBackwardDataAlgo(algorithm_config.algorithm());
+ dnn::AlgorithmDesc algotype = algorithm_config.algorithm();
+ algo = ToConvBackwardDataAlgo(algotype);
+ conv.set_use_tensor_op_math(algotype.tensor_ops_enabled());
size_t size_in_bytes;
status = wrap::cudnnGetConvolutionBackwardDataWorkspaceSize(
parent_, ToHandle(dnn_handle_),
@@ -2944,7 +3019,7 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
}
LOG(FATAL) << "Cannot query the size of workspace needed for the given "
"algorithm: "
- << algorithm_config.algorithm();
+ << algorithm_config.algorithm().algo_id();
}
int64 size_in_bytes_int64 = size_in_bytes;
if (size_in_bytes_int64 > 0) {
@@ -2963,10 +3038,12 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
LOG(WARNING) << allocated.status().error_message();
}
if (scratch == nullptr) {
- CHECK(algorithm_config.algorithm_no_scratch() != dnn::kDefaultAlgorithm)
+ CHECK(!algorithm_config.algorithm_no_scratch().is_default())
<< "The primary convolution algorithm failed memory allocation, "
"while a secondary algorithm is not provided.";
- algo = ToConvBackwardDataAlgo(algorithm_config.algorithm_no_scratch());
+ dnn::AlgorithmDesc algotype = algorithm_config.algorithm_no_scratch();
+ algo = ToConvBackwardDataAlgo(algotype);
+ conv.set_use_tensor_op_math(algotype.tensor_ops_enabled());
}
} else if (size_in_bytes_int64 < 0) {
LOG(WARNING) << "cudnnGetConvolutionBackwardDataWorkspaceSize() returned "
@@ -3005,7 +3082,9 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
if (is_profiling) {
timer->Stop(AsCUDAStream(stream));
if (status == CUDNN_STATUS_SUCCESS) {
- output_profile_result->set_algorithm(algo);
+ bool use_tensor_ops = algorithm_config.algorithm().tensor_ops_enabled();
+ dnn::AlgorithmDesc algotype(algo, use_tensor_ops);
+ output_profile_result->set_algorithm(algotype);
output_profile_result->set_elapsed_time_in_ms(
timer->GetElapsedMilliseconds());
}
@@ -3108,7 +3187,7 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl(
cudnnConvolutionBwdFilterAlgo_t algo;
DeviceMemory<uint8> scratch;
- if (algorithm_config.algorithm() == dnn::kDefaultAlgorithm) {
+ if (algorithm_config.algorithm().is_default()) {
// With the default algorithm, use Cudnn's heuristics.
// Lambda that retrieves the algorithm.
@@ -3178,7 +3257,9 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl(
}
} else {
// An algorithm has been specified.
- algo = ToConvBackwardFilterAlgo(algorithm_config.algorithm());
+ dnn::AlgorithmDesc algotype = algorithm_config.algorithm();
+ algo = ToConvBackwardFilterAlgo(algotype);
+ conv.set_use_tensor_op_math(algotype.tensor_ops_enabled());
size_t size_in_bytes;
status = wrap::cudnnGetConvolutionBackwardFilterWorkspaceSize(
@@ -3193,7 +3274,7 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl(
}
LOG(FATAL) << "Cannot query the size of workspace needed for the given "
"algorithm: "
- << algorithm_config.algorithm();
+ << algorithm_config.algorithm().algo_id();
}
int64 size_in_bytes_int64 = size_in_bytes;
if (size_in_bytes_int64 > 0) {
@@ -3212,11 +3293,12 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl(
LOG(WARNING) << allocated.status().error_message();
}
if (scratch == nullptr) {
- CHECK(algorithm_config.algorithm_no_scratch() != dnn::kDefaultAlgorithm)
+ CHECK(!algorithm_config.algorithm_no_scratch().is_default())
<< "The primary convolution algorithm failed memory allocation, "
"while a secondary algorithm is not provided.";
- algo =
- ToConvBackwardFilterAlgo(algorithm_config.algorithm_no_scratch());
+ dnn::AlgorithmDesc algotype = algorithm_config.algorithm_no_scratch();
+ algo = ToConvBackwardFilterAlgo(algotype);
+ conv.set_use_tensor_op_math(algotype.tensor_ops_enabled());
}
} else if (size_in_bytes_int64 < 0) {
LOG(WARNING)
@@ -3255,7 +3337,9 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl(
if (is_profiling) {
timer->Stop(AsCUDAStream(stream));
if (status == CUDNN_STATUS_SUCCESS) {
- output_profile_result->set_algorithm(algo);
+ bool use_tensor_ops = algorithm_config.algorithm().tensor_ops_enabled();
+ dnn::AlgorithmDesc algotype(algo, use_tensor_ops);
+ output_profile_result->set_algorithm(algotype);
output_profile_result->set_elapsed_time_in_ms(
timer->GetElapsedMilliseconds());
}