author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-06-01 00:18:19 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-06-01 00:20:43 -0700
commit     961a39346d8be33cff473f1e81498b887c155070 (patch)
tree       d1175e89f82bd60137cf9fb2ecbee64d4ac5e59c /tensorflow/stream_executor
parent     54b20c4be0372fb14ec9a289e4d7de7f67c03ff6 (diff)
Unify error handling in CudnnSupport.
PiperOrigin-RevId: 198836479
Diffstat (limited to 'tensorflow/stream_executor')
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_dnn.cc   2874
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_dnn.h     128
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_timer.h     3
-rw-r--r--  tensorflow/stream_executor/dnn.cc                4
-rw-r--r--  tensorflow/stream_executor/dnn.h                 5
5 files changed, 1340 insertions(+), 1674 deletions(-)
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index c2c0c283b3..55c1083a61 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <functional>
#include <memory>
+#include <utility>
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/lib/core/errors.h"
@@ -55,6 +56,33 @@ namespace {
static_assert(CUDNN_VERSION >= 6000, "cuDNN needs to be version 6.0 or higher");
+// Exits the program if 'expr' doesn't return CUDNN_STATUS_SUCCESS.
+#define CHECK_CUDNN_OK(expr) CHECK_EQ(expr, CUDNN_STATUS_SUCCESS)
+
+// If 'expr' doesn't return CUDNN_STATUS_SUCCESS, returns from the current
+// function with a non-successful port::Status.
+#define RETURN_IF_CUDNN_ERROR(expr) \
+ do { \
+ cudnnStatus_t _status = expr; \
+ if (!SE_PREDICT_TRUE(_status == CUDNN_STATUS_SUCCESS)) { \
+ std::ostringstream oss; \
+ oss << ToString(_status) << "\nin " << __FILE__ << "(" << __LINE__ \
+ << "): '" << #expr << "'"; \
+ return port::Status(port::error::UNKNOWN, oss.str().c_str()); \
+ } \
+ } while (false)
+
+// Returns whether status is 'ok', and potentially logs the error.
+bool IsStatusOk(const port::Status& status, bool report_error) {
+ if (status.ok()) {
+ return true;
+ }
+ if (report_error) {
+ LOG(ERROR) << status.error_message();
+ }
+ return false;
+}
+
// Converts (via narrowing) a type T value to a type U, and checks that the
// value has no value change due to the conversion.
template <typename WideT, typename NarrowT>
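[Editorial aside: a minimal sketch, not part of the patch, of how the three new pieces compose. RETURN_IF_CUDNN_ERROR belongs in port::Status/StatusOr-returning helpers, CHECK_CUDNN_OK guards calls that must never fail (destruction, trivially valid setters), and IsStatusOk adapts a Status back to the bool convention of the legacy Do* entry points. The helper names below are hypothetical.]

port::Status SetExampleTensorDescriptor(cudnnTensorDescriptor_t desc) {
  int dims[] = {1, 8, 1};
  int strides[] = {8, 1, 1};
  // On failure, returns port::Status(UNKNOWN, ...) carrying the file/line
  // and the stringified expression; on success, falls through.
  RETURN_IF_CUDNN_ERROR(
      cudnnSetTensorNdDescriptor(desc, CUDNN_DATA_FLOAT, 3, dims, strides));
  return port::Status::OK();
}

bool DoExample(cudnnTensorDescriptor_t desc) {
  // Legacy bool-returning surface: log the error and translate the Status.
  return IsStatusOk(SetExampleTensorDescriptor(desc), /*report_error=*/true);
}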
@@ -89,26 +117,20 @@ string ToString(cudnnStatus_t status) {
return "CUDNN_STATUS_NOT_SUPPORTED";
case CUDNN_STATUS_LICENSE_ERROR:
return "CUDNN_STATUS_LICENSE_ERROR";
+ case CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING:
+ return "CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING";
+#if CUDNN_VERSION >= 7000
+ case CUDNN_STATUS_RUNTIME_IN_PROGRESS:
+ return "CUDNN_STATUS_RUNTIME_IN_PROGRESS";
+ case CUDNN_STATUS_RUNTIME_FP_OVERFLOW:
+ return "CUDNN_STATUS_RUNTIME_FP_OVERFLOW";
+#endif
default:
return port::StrCat("<unknown cudnn status: ", static_cast<int>(status),
">");
}
}
-string ToString(libraryPropertyType type) {
- switch (type) {
- case MAJOR_VERSION:
- return "MAJOR_VERSION";
- case MINOR_VERSION:
- return "MINOR_VERSION";
- case PATCH_LEVEL:
- return "PATCH_LEVEL";
- default:
- return port::StrCat(
- "<unknown libraryPropertyType: ", static_cast<int>(type), ">");
- }
-}
-
template <typename T>
cudnnDataType_t GetCudnnDataType();
@@ -150,9 +172,9 @@ class CudnnHandle {
} // namespace
-// Wraps a cuDNN handle and provides access to it through CudnnHandle instances,
-// which also locks a mutex, acquires the CUDA context, and sets the stream
-// that cuDNN should use to enqueue any work.
+// Wraps a cuDNN handle and provides access to it through CudnnHandle
+// instances, which also locks a mutex, acquires the CUDA context, and sets
+// the stream that cuDNN should use to enqueue any work.
//
// Note: CudnnSupport::cudnn_ should be the only instantiation of this class.
class CudnnAccess {
@@ -167,13 +189,13 @@ class CudnnAccess {
// Creates a CudnnHandle instance for stream.
//
- // cuDNN API calls using the same handle instance need to be serialized across
- // threads. This is guaranteed by CudnnHandle instances locking the mutex
- // owned by this class.
+ // cuDNN API calls using the same handle instance need to be serialized
+ // across threads. This is guaranteed by CudnnHandle instances locking the
+ // mutex owned by this class.
//
// Most cuDNN APIs taking a handle perform work on a CUDA stream. The
- // CudnnHandle instance acquires the executor's CUDA context and sets cuDNN to
- // use the provided stream.
+ // CudnnHandle instance acquires the executor's CUDA context and sets cuDNN
+ // to use the provided stream.
//
// The stream argument may be null, which translates to the legacy default
// stream. See
@@ -187,7 +209,6 @@ class CudnnAccess {
CUstream cu_stream = stream ? AsCUDAStreamValue(stream) : cudaStreamLegacy;
auto status = cudnnSetStream(handle_, cu_stream);
CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Failed to set cuDNN stream.";
- using my_mutex_lock = mutex_lock;
return CudnnHandle(std::move(context), std::move(lock), handle_);
}
@@ -201,6 +222,8 @@ class CudnnAccess {
namespace {
+// A helper function to return the internal compute type for
+// RNNs in cudnn.
cudnnDataType_t GetRnnComputeType(dnn::DataType data_type);
cudnnConvolutionFwdAlgo_t ToConvForwardAlgo(dnn::AlgorithmDesc algorithm) {
@@ -264,16 +287,10 @@ cudnnConvolutionBwdFilterAlgo_t ToConvBackwardFilterAlgo(
}
}
-port::Status GetCudnnProperty(libraryPropertyType type, int* value) {
- cudnnStatus_t status = cudnnGetProperty(type, value);
- if (status != CUDNN_STATUS_SUCCESS) {
- const string error =
- port::StrCat("cudnnGetProperty failed for type: ", ToString(type),
- " with status: ", ToString(status));
- LOG(ERROR) << error;
- return port::Status(port::error::INTERNAL, error);
- }
- return port::Status::OK();
+port::StatusOr<int> GetCudnnProperty(libraryPropertyType type) {
+ int value;
+ RETURN_IF_CUDNN_ERROR(cudnnGetProperty(type, &value));
+ return value;
}
cudnnRNNAlgo_t ToCudnnRNNAlgo(const dnn::AlgorithmDesc& algorithm) {
@@ -294,9 +311,9 @@ cudnnRNNAlgo_t ToCudnnRNNAlgo(const dnn::AlgorithmDesc& algorithm) {
}
port::Status GetLoadedCudnnVersion(CudnnVersion* version) {
- TF_RETURN_IF_ERROR(GetCudnnProperty(MAJOR_VERSION, &version->major_version));
- TF_RETURN_IF_ERROR(GetCudnnProperty(MINOR_VERSION, &version->minor_version));
- TF_RETURN_IF_ERROR(GetCudnnProperty(PATCH_LEVEL, &version->patch_level));
+ SE_ASSIGN_OR_RETURN(version->major_version, GetCudnnProperty(MAJOR_VERSION));
+ SE_ASSIGN_OR_RETURN(version->minor_version, GetCudnnProperty(MINOR_VERSION));
+ SE_ASSIGN_OR_RETURN(version->patch_level, GetCudnnProperty(PATCH_LEVEL));
return port::Status::OK();
}
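[Editorial aside: GetLoadedCudnnVersion is the canonical consumer of the new StatusOr style. Roughly, modulo the macro's uniquely named temporary, each SE_ASSIGN_OR_RETURN line above expands to:]

auto status_or = GetCudnnProperty(MAJOR_VERSION);
if (!status_or.ok()) {
  return status_or.status();  // propagate the cuDNN failure
}
version->major_version = status_or.ValueOrDie();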
@@ -319,9 +336,11 @@ port::Status CudnnSupport::Init() {
". CuDNN library major and minor version needs to match or have "
"higher minor version in case of CuDNN 7.0 or later version. If "
"using a binary install, upgrade your CuDNN library. If building "
- "from sources, make sure the library loaded at runtime is compatible "
+ "from sources, make sure the library loaded at runtime is "
+ "compatible "
"with the version specified during compile configuration.");
LOG(ERROR) << error;
+ cudnnDestroy(cudnn_handle);
return port::Status(port::error::INTERNAL, error);
}
@@ -329,23 +348,17 @@ port::Status CudnnSupport::Init() {
return port::Status::OK();
}
- LOG(ERROR) << "could not create cudnn handle: " << ToString(status);
+ CHECK_EQ(cudnn_handle, nullptr);
+ LOG(ERROR) << "Could not create cudnn handle: " << ToString(status);
if (status == CUDNN_STATUS_NOT_INITIALIZED) {
auto result = cuda::Diagnostician::FindKernelDriverVersion();
if (!result.ok()) {
- LOG(ERROR) << "error retrieving driver version: "
+ LOG(ERROR) << "Error retrieving driver version: "
<< DriverVersionStatusToString(result);
} else {
const auto& version = result.ValueOrDie();
- LOG(ERROR) << "possibly insufficient driver version: "
+ LOG(ERROR) << "Possibly insufficient driver version: "
<< DriverVersionToString(version);
- // OS X kernel driver does not report version accurately
-#if !defined(__APPLE__)
- if (std::get<0>(version) < 340) {
- LOG(ERROR)
- << "cudnn library is only supported on 340.XX+ driver versions";
- }
-#endif
}
}
@@ -364,18 +377,129 @@ CudnnSupport::GetVersion() {
namespace {
-// Turns a BatchDescriptor structure into a cudnn tensor handle within a scope.
+// Deleter functors for cuDNN types that need to be deleted.
+struct TensorDescriptorDeleter {
+ void operator()(cudnnTensorDescriptor_t descriptor) const {
+ CHECK_CUDNN_OK(cudnnDestroyTensorDescriptor(descriptor));
+ }
+};
+struct FilterDescriptorDeleter {
+ void operator()(cudnnFilterDescriptor_t descriptor) const {
+ CHECK_CUDNN_OK(cudnnDestroyFilterDescriptor(descriptor));
+ }
+};
+struct ConvolutionDescriptorDeleter {
+ void operator()(cudnnConvolutionDescriptor_t descriptor) const {
+ CHECK_CUDNN_OK(cudnnDestroyConvolutionDescriptor(descriptor));
+ }
+};
+struct PoolingDescriptorDeleter {
+ void operator()(cudnnPoolingDescriptor_t descriptor) const {
+ CHECK_CUDNN_OK(cudnnDestroyPoolingDescriptor(descriptor));
+ }
+};
+struct LrnDescriptorDeleter {
+ void operator()(cudnnLRNDescriptor_t descriptor) const {
+ CHECK_CUDNN_OK(cudnnDestroyLRNDescriptor(descriptor));
+ }
+};
+
+struct ActivationDescriptorDeleter {
+ void operator()(cudnnActivationDescriptor_t descriptor) const {
+ CHECK_CUDNN_OK(cudnnDestroyActivationDescriptor(descriptor));
+ }
+};
+struct DropoutDescriptorDeleter {
+ void operator()(cudnnDropoutDescriptor_t descriptor) const {
+ CHECK_CUDNN_OK(cudnnDestroyDropoutDescriptor(descriptor));
+ }
+};
+struct RnnDescriptorDeleter {
+ void operator()(cudnnRNNDescriptor_t descriptor) const {
+ CHECK_CUDNN_OK(cudnnDestroyRNNDescriptor(descriptor));
+ }
+};
+struct PersistentRnnPlanDeleter {
+ void operator()(cudnnPersistentRNNPlan_t plan) const {
+ CHECK_CUDNN_OK(cudnnDestroyPersistentRNNPlan(plan));
+ }
+};
+
+// RAII wrappers for cuDNN types.
+using TensorDescriptor =
+ std::unique_ptr<cudnnTensorStruct, TensorDescriptorDeleter>;
+using FilterDescriptor =
+ std::unique_ptr<cudnnFilterStruct, FilterDescriptorDeleter>;
+using ConvolutionDescriptor =
+ std::unique_ptr<cudnnConvolutionStruct, ConvolutionDescriptorDeleter>;
+using PoolingDescriptor =
+ std::unique_ptr<cudnnPoolingStruct, PoolingDescriptorDeleter>;
+using LrnDescriptor = std::unique_ptr<cudnnLRNStruct, LrnDescriptorDeleter>;
+using ActivationDescriptor =
+ std::unique_ptr<cudnnActivationStruct, ActivationDescriptorDeleter>;
+using DropoutDescriptor =
+ std::unique_ptr<cudnnDropoutStruct, DropoutDescriptorDeleter>;
+using RnnDescriptor = std::unique_ptr<cudnnRNNStruct, RnnDescriptorDeleter>;
+using PersistentRnnPlan =
+ std::unique_ptr<cudnnPersistentRNNPlan, PersistentRnnPlanDeleter>;
+
+// Factory methods for cuDNN types.
+TensorDescriptor CreateTensorDescriptor() {
+ cudnnTensorDescriptor_t result;
+ CHECK_CUDNN_OK(cudnnCreateTensorDescriptor(&result));
+ return TensorDescriptor(result);
+}
+FilterDescriptor CreateFilterDescriptor() {
+ cudnnFilterDescriptor_t result;
+ CHECK_CUDNN_OK(cudnnCreateFilterDescriptor(&result));
+ return FilterDescriptor(result);
+}
+ConvolutionDescriptor CreateConvolutionDescriptor() {
+ cudnnConvolutionDescriptor_t result;
+ CHECK_CUDNN_OK(cudnnCreateConvolutionDescriptor(&result));
+ return ConvolutionDescriptor(result);
+}
+PoolingDescriptor CreatePoolingDescriptor() {
+ cudnnPoolingDescriptor_t result;
+ CHECK_CUDNN_OK(cudnnCreatePoolingDescriptor(&result));
+ return PoolingDescriptor(result);
+}
+LrnDescriptor CreateLrnDescriptor() {
+ cudnnLRNDescriptor_t result;
+ CHECK_CUDNN_OK(cudnnCreateLRNDescriptor(&result));
+ return LrnDescriptor(result);
+}
+ActivationDescriptor CreateActivationDescriptor() {
+ cudnnActivationDescriptor_t result;
+ CHECK_CUDNN_OK(cudnnCreateActivationDescriptor(&result));
+ return ActivationDescriptor(result);
+}
+DropoutDescriptor CreateDropoutDescriptor() {
+ cudnnDropoutDescriptor_t result;
+ CHECK_CUDNN_OK(cudnnCreateDropoutDescriptor(&result));
+ return DropoutDescriptor(result);
+}
+RnnDescriptor CreateRnnDescriptor() {
+ cudnnRNNDescriptor_t result;
+ CHECK_CUDNN_OK(cudnnCreateRNNDescriptor(&result));
+ return RnnDescriptor(result);
+}
+PersistentRnnPlan CreatePersistentRnnPlan(cudnnRNNDescriptor_t rnn_desc,
+ int batch_size,
+ cudnnDataType_t data_type) {
+ cudnnPersistentRNNPlan_t result;
+ CHECK_CUDNN_OK(
+ cudnnCreatePersistentRNNPlan(rnn_desc, batch_size, data_type, &result));
+ return PersistentRnnPlan(result);
+}
+
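[Editorial aside: one property worth spelling out, since it is what makes these aliases typecheck: cuDNN's opaque handle types are pointers to incomplete structs (e.g. cudnnTensorDescriptor_t is cudnnTensorStruct*), so a std::unique_ptr parameterized on the struct type hands back exactly the handle type from get(). A small illustration, assuming <type_traits> is available:]

TensorDescriptor desc = CreateTensorDescriptor();
cudnnTensorDescriptor_t raw = desc.get();  // no cast required
static_assert(
    std::is_same<decltype(desc.get()), cudnnTensorDescriptor_t>::value,
    "unique_ptr<cudnnTensorStruct, D>::get() yields the raw handle type");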
+// Turns a BatchDescriptor structure into a cudnn tensor handle within a
+// scope.
class ScopedTensorDescriptor {
public:
ScopedTensorDescriptor(const dnn::BatchDescriptor& batch_descriptor,
cudnnDataType_t elem_type)
- : handle_(nullptr) {
- cudnnStatus_t status = cudnnCreateTensorDescriptor(&handle_);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "could not create cudnn tensor descriptor: "
- << ToString(status);
- }
-
+ : handle_(CreateTensorDescriptor()) {
switch (batch_descriptor.layout()) {
case dnn::DataLayout::kBatchYXDepth:
case dnn::DataLayout::kBatchDepthYX: {
@@ -393,25 +517,16 @@ class ScopedTensorDescriptor {
&CheckedNarrowing<int64, int>);
std::transform(dims64.cbegin(), dims64.cend(), dims.begin(),
&CheckedNarrowing<int64, int>);
- status = cudnnSetTensorNdDescriptor(handle_, elem_type, nd, dims.data(),
- strides.data());
-
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "could not convert BatchDescriptor "
- << batch_descriptor.ToString()
- << " to cudnn tensor descriptor: " << ToString(status);
- }
+ CHECK_CUDNN_OK(cudnnSetTensorNdDescriptor(handle_.get(), elem_type, nd,
+ dims.data(), strides.data()))
+ << "batch_descriptor: " << batch_descriptor.ToString();
} break;
case dnn::DataLayout::kBatchDepthYX4: {
- status = cudnnSetTensor4dDescriptor(
- handle_, CUDNN_TENSOR_NCHW_VECT_C, elem_type,
+ CHECK_CUDNN_OK(cudnnSetTensor4dDescriptor(
+ handle_.get(), CUDNN_TENSOR_NCHW_VECT_C, elem_type,
batch_descriptor.count(), batch_descriptor.feature_map_count(),
- batch_descriptor.height(), batch_descriptor.width());
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "could not convert BatchDescriptor "
- << batch_descriptor.ToString()
- << " to cudnn tensor descriptor: " << ToString(status);
- }
+ batch_descriptor.height(), batch_descriptor.width()))
+ << "batch_descriptor: " << batch_descriptor.ToString();
} break;
default:
LOG(FATAL) << "Unsupported tensor format "
@@ -420,37 +535,24 @@ class ScopedTensorDescriptor {
}
}
- ~ScopedTensorDescriptor() {
- cudnnStatus_t status = cudnnDestroyTensorDescriptor(handle_);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "could not destroy cudnn tensor descriptor: "
- << ToString(status);
- }
- }
-
- cudnnTensorDescriptor_t handle() const { return handle_; }
+ cudnnTensorDescriptor_t handle() const { return handle_.get(); }
private:
- cudnnTensorDescriptor_t handle_; // Owned.
+ TensorDescriptor handle_;
SE_DISALLOW_COPY_AND_ASSIGN(ScopedTensorDescriptor);
};
-// Turns a FilterDescriptor structure into a cudnn filter handle within a scope.
+// Turns a FilterDescriptor structure into a cudnn filter handle within a
+// scope.
class ScopedFilterDescriptor {
public:
ScopedFilterDescriptor(const dnn::FilterDescriptor& filter_descriptor,
cudnnDataType_t elem_type)
- : handle_(nullptr) {
- cudnnStatus_t status = cudnnCreateFilterDescriptor(&handle_);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "could not create cudnn filter descriptor: "
- << ToString(status);
- }
-
+ : handle_(CreateFilterDescriptor()) {
// TODO(b/23032134): Even if the filter layout is not supported,
- // cudnnSetFilter4DDescriptor_v4 will return CUDNN_STATUS_SUCCESS because it
- // does not take layout as an input. Maybe force cuDNN by giving wrong
+ // cudnnSetFilter4DDescriptor_v4 will return CUDNN_STATUS_SUCCESS because
+ // it does not take layout as an input. Maybe force cuDNN by giving wrong
// inputs intentionally?
cudnnTensorFormat_t format;
switch (filter_descriptor.layout()) {
@@ -475,32 +577,20 @@ class ScopedFilterDescriptor {
const auto& spatial_dims = filter_descriptor.input_filter_dims();
std::copy(spatial_dims.begin(), spatial_dims.end(), dims.begin() + 2);
- status = cudnnSetFilterNdDescriptor(handle_, elem_type, format, dims.size(),
- dims.data());
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "could not set cudnn filter descriptor: "
- << ToString(status);
- }
+ CHECK_CUDNN_OK(cudnnSetFilterNdDescriptor(handle_.get(), elem_type, format,
+ dims.size(), dims.data()));
}
- ~ScopedFilterDescriptor() {
- cudnnStatus_t status = cudnnDestroyFilterDescriptor(handle_);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "could not destroy cudnn filter descriptor: "
- << ToString(status);
- }
- }
-
- cudnnFilterDescriptor_t handle() const { return handle_; }
+ cudnnFilterDescriptor_t handle() const { return handle_.get(); }
private:
- cudnnFilterDescriptor_t handle_; // Owned.
+ FilterDescriptor handle_; // Owned.
SE_DISALLOW_COPY_AND_ASSIGN(ScopedFilterDescriptor);
};
// A helper function to decide whether to enable the TENSOR_OP_MATH math type
-static bool TensorOpMathEnabled() {
+bool TensorOpMathEnabled() {
static bool is_enabled = [] {
bool is_disabled = false;
TF_CHECK_OK(
@@ -513,7 +603,7 @@ static bool TensorOpMathEnabled() {
// A helper function to decide whether to enable the TENSOR_OP_MATH math type
// for RNNs.
-static bool RnnTensorOpMathEnabled() {
+bool RnnTensorOpMathEnabled() {
static bool is_enabled = [] {
bool is_disabled = false;
TF_CHECK_OK(
@@ -524,15 +614,16 @@ static bool RnnTensorOpMathEnabled() {
return is_enabled;
}
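[Editorial aside: dropping the redundant `static` keyword is cosmetic, since functions in the surrounding anonymous namespace already have internal linkage. The idiom both helpers share is a function-local static initialized by an immediately invoked lambda, so the TF_* environment variable is read exactly once per process. A generic sketch, with a hypothetical variable name:]

bool ExampleFeatureEnabled() {
  static bool is_enabled = [] {
    bool is_disabled = false;
    // "TF_DISABLE_EXAMPLE_FEATURE" is a placeholder; the real flags are
    // read the same way.
    TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DISABLE_EXAMPLE_FEATURE",
                                               /*default_val=*/false,
                                               &is_disabled));
    return !is_disabled;
  }();  // evaluated once, thread-safely, on first call
  return is_enabled;
}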
-// A helper function to decide whether to use CUDNN_BATCHNORM_SPATIAL_PERSISTENT
-// in batchnorm. This mode can be faster in some tasks because an optimized path
-// may be selected for CUDNN_DATA_FLOAT and CUDNN_DATA_HALF data types, compute
-// capability 6.0 or higher. The reason we set it to false by default is that
-// this mode may use scaled atomic integer reduction that may cause a numerical
-// overflow for certain input data range.
+// A helper function to decide whether to use
+// CUDNN_BATCHNORM_SPATIAL_PERSISTENT in batchnorm. This mode can be faster in
+// some tasks because an optimized path may be selected for CUDNN_DATA_FLOAT
+// and CUDNN_DATA_HALF data types, compute capability 6.0 or higher. The
+// reason we set it to false by default is that this mode may use scaled
+// atomic integer reduction that may cause a numerical overflow for certain
+// input data range.
// TODO(yangzihao): Use autotune to choose between this mode and
// CUDNN_BATCHNORM_SPATIAL mode.
-static bool BatchnormSpatialPersistentEnabled() {
+bool BatchnormSpatialPersistentEnabled() {
static bool is_enabled = [] {
bool is_enabled = false;
TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar(
@@ -550,19 +641,13 @@ class ScopedConvolutionDescriptor {
ScopedConvolutionDescriptor(
const dnn::ConvolutionDescriptor& convolution_descriptor,
cudnnDataType_t data_type)
- : handle_(nullptr) {
- cudnnStatus_t status = cudnnCreateConvolutionDescriptor(&handle_);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "could not create cudnn convolution descriptor: "
- << ToString(status);
- }
+ : handle_(CreateConvolutionDescriptor()) {
const auto& strides64 = convolution_descriptor.strides();
const auto& padding64 = convolution_descriptor.padding();
const auto& dilations64 = convolution_descriptor.dilations();
- if (convolution_descriptor.pad_alignment() ==
- dnn::PadAlignment::kTensorFlowPadding) {
- LOG(ERROR) << "TensorFlow padding alignment is not supported.";
- }
+ CHECK_NE(convolution_descriptor.pad_alignment(),
+ dnn::PadAlignment::kTensorFlowPadding)
+ << "TensorFlow padding alignment is not supported.";
// cuDNN requires arrays of ints.
std::vector<int> strides(convolution_descriptor.ndims());
@@ -577,18 +662,14 @@ class ScopedConvolutionDescriptor {
std::transform(dilations64.cbegin(), dilations64.cend(), dilations.begin(),
&CheckedNarrowing<int64, int>);
- status = cudnnSetConvolutionNdDescriptor(
- handle_, convolution_descriptor.ndims(), padding.data(), strides.data(),
- dilations.data(),
+ CHECK_CUDNN_OK(cudnnSetConvolutionNdDescriptor(
+ handle_.get(), convolution_descriptor.ndims(), padding.data(),
+ strides.data(), dilations.data(),
// NOTE(keveman): cuDNN supports convolution and cross correlation.
// However, almost all the use cases do cross correlation, so just
// hard coding it here.
- CUDNN_CROSS_CORRELATION, data_type);
+ CUDNN_CROSS_CORRELATION, data_type));
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "could not set cudnn convolution descriptor: "
- << ToString(status);
- }
// NOTE(benbarsdell): This only applies if tensor op math is enabled
// and algo selection is set to Default.
this->set_use_tensor_op_math(true);
@@ -596,44 +677,28 @@ class ScopedConvolutionDescriptor {
#if CUDNN_MAJOR >= 7
VLOG(2) << "Requesting grouped convolution: "
<< convolution_descriptor.group_count();
- status = cudnnSetConvolutionGroupCount(
- handle_, convolution_descriptor.group_count());
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "could not set cudnn convolution group count: "
- << ToString(status);
- }
+ CHECK_CUDNN_OK(cudnnSetConvolutionGroupCount(
+ handle_.get(), convolution_descriptor.group_count()));
#else
CHECK_EQ(convolution_descriptor.group_count(), 1)
<< "Requested grouped convolution for cuDNN version < 7";
#endif
}
- void set_use_tensor_op_math(bool use_tensor_op_math) {
+ void set_use_tensor_op_math(bool use_tensor_op_math) const {
#if CUDNN_VERSION >= 7000
cudnnMathType_t math_type =
(use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH);
if (TensorOpMathEnabled()) {
- cudnnStatus_t status = cudnnSetConvolutionMathType(handle_, math_type);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "could not set cudnn convolution math type: "
- << ToString(status);
- }
+ CHECK_CUDNN_OK(cudnnSetConvolutionMathType(handle_.get(), math_type));
}
#endif
}
- ~ScopedConvolutionDescriptor() {
- cudnnStatus_t status = cudnnDestroyConvolutionDescriptor(handle_);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "could not destroy cudnn convolution descriptor: "
- << ToString(status);
- }
- }
-
- cudnnConvolutionDescriptor_t handle() const { return handle_; }
+ cudnnConvolutionDescriptor_t handle() const { return handle_.get(); }
private:
- cudnnConvolutionDescriptor_t handle_; // Owned.
+ ConvolutionDescriptor handle_; // Owned.
SE_DISALLOW_COPY_AND_ASSIGN(ScopedConvolutionDescriptor);
};
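[Editorial aside on the NOTE(keveman) comment above: cross-correlation is convolution without the kernel flip, so for learned filters the two are equivalent up to a re-indexing of the weights, which is why hard-coding CUDNN_CROSS_CORRELATION is safe. In 1-D, with a kernel w of length K:]

// Cross-correlation: out[i] = sum_{k=0}^{K-1} in[i + k] * w[k]
// True convolution:  out[i] = sum_{k=0}^{K-1} in[i + k] * w[K - 1 - k]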
@@ -644,12 +709,7 @@ class ScopedPoolingDescriptor {
public:
explicit ScopedPoolingDescriptor(
const dnn::PoolingDescriptor& pooling_descriptor)
- : handle_(nullptr) {
- cudnnStatus_t status = cudnnCreatePoolingDescriptor(&handle_);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "could not create cudnn pooling descriptor: "
- << ToString(status);
- }
+ : handle_(CreatePoolingDescriptor()) {
const std::vector<int64> strides64 = pooling_descriptor.strides();
const std::vector<int64> padding64 = pooling_descriptor.padding();
const std::vector<int64> shape64 = pooling_descriptor.window();
@@ -665,30 +725,19 @@ class ScopedPoolingDescriptor {
std::transform(shape64.cbegin(), shape64.cend(), shape.begin(),
&CheckedNarrowing<int64, int>);
bool propagate_nans = pooling_descriptor.propagate_nans();
- status = cudnnSetPoolingNdDescriptor(
- handle_,
+ CHECK_CUDNN_OK(cudnnSetPoolingNdDescriptor(
+ handle_.get(),
(pooling_descriptor.mode() == dnn::PoolingMode::kMaximum
? CUDNN_POOLING_MAX
: CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING),
propagate_nans ? CUDNN_PROPAGATE_NAN : CUDNN_NOT_PROPAGATE_NAN, nd,
- shape.data(), padding.data(), strides.data());
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "could not set cudnn pooling descriptor: "
- << ToString(status);
- }
- }
- ~ScopedPoolingDescriptor() {
- cudnnStatus_t status = cudnnDestroyPoolingDescriptor(handle_);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "could not destroy cudnn pooling descriptor: "
- << ToString(status);
- }
+ shape.data(), padding.data(), strides.data()));
}
- cudnnPoolingDescriptor_t handle() const { return handle_; }
+ cudnnPoolingDescriptor_t handle() const { return handle_.get(); }
private:
- cudnnPoolingDescriptor_t handle_; // Owned.
+ PoolingDescriptor handle_; // Owned.
SE_DISALLOW_COPY_AND_ASSIGN(ScopedPoolingDescriptor);
};
@@ -698,13 +747,7 @@ class ScopedNormalizeDescriptor {
public:
explicit ScopedNormalizeDescriptor(
const dnn::NormalizeDescriptor& normalize_descriptor)
- : handle_(nullptr) {
- cudnnStatus_t status = cudnnCreateLRNDescriptor(&handle_);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "could not create cudnn LRN descriptor: "
- << ToString(status);
- }
-
+ : handle_(CreateLrnDescriptor()) {
// The range specifies that the indices in the closed range
// [i - range, i + range] should be included in the normalization for index
// i. The lrnN value is the total number of elements in the range, so
@@ -725,24 +768,14 @@ class ScopedNormalizeDescriptor {
double lrnBeta = normalize_descriptor.beta();
double lrnK = normalize_descriptor.bias();
- status = cudnnSetLRNDescriptor(handle_, lrnN, lrnAlpha, lrnBeta, lrnK);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "could not set cudnn LRN descriptor: " << ToString(status);
- }
- }
-
- ~ScopedNormalizeDescriptor() {
- cudnnStatus_t status = cudnnDestroyLRNDescriptor(handle_);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "could not destroy cudnn LRN descriptor: "
- << ToString(status);
- }
+ CHECK_CUDNN_OK(
+ cudnnSetLRNDescriptor(handle_.get(), lrnN, lrnAlpha, lrnBeta, lrnK));
}
- cudnnLRNDescriptor_t handle() const { return handle_; }
+ cudnnLRNDescriptor_t handle() const { return handle_.get(); }
private:
- cudnnLRNDescriptor_t handle_; // Owned.
+ LrnDescriptor handle_; // Owned.
SE_DISALLOW_COPY_AND_ASSIGN(ScopedNormalizeDescriptor);
};
@@ -754,13 +787,7 @@ class ScopedActivationDescriptor {
ScopedActivationDescriptor(dnn::ActivationMode activation_mode,
cudnnNanPropagation_t nan_propagation,
double value_max)
- : handle_(nullptr) {
- cudnnStatus_t status = cudnnCreateActivationDescriptor(&handle_);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "could not create cudnn activation descriptor: "
- << ToString(status);
- }
-
+ : handle_(CreateActivationDescriptor()) {
double relu_ceiling = 0.0;
cudnnActivationMode_t mode;
switch (activation_mode) {
@@ -786,26 +813,14 @@ class ScopedActivationDescriptor {
<< static_cast<int>(activation_mode);
}
- status = cudnnSetActivationDescriptor(handle_, mode, nan_propagation,
- relu_ceiling);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "could not set cudnn activation descriptor: "
- << ToString(status);
- }
+ CHECK_CUDNN_OK(cudnnSetActivationDescriptor(handle_.get(), mode,
+ nan_propagation, relu_ceiling));
}
- ~ScopedActivationDescriptor() {
- cudnnStatus_t status = cudnnDestroyActivationDescriptor(handle_);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "could not destroy cudnn activation descriptor: "
- << ToString(status);
- }
- }
-
- cudnnActivationDescriptor_t handle() const { return handle_; }
+ cudnnActivationDescriptor_t handle() const { return handle_.get(); }
private:
- cudnnActivationDescriptor_t handle_; // Owned.
+ ActivationDescriptor handle_; // Owned.
SE_DISALLOW_COPY_AND_ASSIGN(ScopedActivationDescriptor);
};
@@ -873,117 +888,74 @@ int CudnnDataTypeToByteSize(cudnnDataType_t data_type) {
}
}
-template <typename Base>
-class MixinBase : public Base {};
-template <>
-class MixinBase<void> {};
-
-#define CUDNN_RETURN_IF_FAIL(STATUS, ...) \
- if (!SE_PREDICT_TRUE((STATUS) == CUDNN_STATUS_SUCCESS)) { \
- string error_msg = port::StrCat(ToString(STATUS), " ", __VA_ARGS__); \
- SetFailure(port::Status(port::error::UNKNOWN, error_msg)); \
- LOG(ERROR) << error_msg; \
- return; \
- }
+class ScopedDropoutDescriptor {
+ explicit ScopedDropoutDescriptor(DropoutDescriptor handle)
+ : handle_(std::move(handle)) {}
-// TODO(csigg): Remove inheritance for code reuse.
-template <typename Base>
-class CudnnDescriptorCommon : public MixinBase<Base> {
public:
- bool ok() const { return status_.ok(); }
- port::Status Status() const { return status_; }
+ ScopedDropoutDescriptor(ScopedDropoutDescriptor&&) = default;
- protected:
- void SetFailure(const port::Status& status) { status_.Update(status); }
- port::Status status_;
-};
+ static port::StatusOr<ScopedDropoutDescriptor> Create(
+ const CudnnHandle& cudnn, float dropout, uint64 seed,
+ ScratchAllocator* state_allocator) {
+ DropoutDescriptor handle = CreateDropoutDescriptor();
-class CudnnDropoutDescriptor : public CudnnDescriptorCommon<void> {
- public:
- CudnnDropoutDescriptor(const CudnnHandle& cudnn, float dropout, uint64 seed,
- ScratchAllocator* state_allocator)
- : handle_(nullptr) {
- cudnnStatus_t status;
- status = cudnnCreateDropoutDescriptor(&handle_);
- CUDNN_RETURN_IF_FAIL(status, "Failed to create dropout descriptor");
-
- if (dropout == 0.f) {
- return;
+ if (dropout == 0.0f) {
+ // Return 'empty' dropout descriptor.
+ return ScopedDropoutDescriptor(std::move(handle));
}
DeviceMemory<uint8> state_memory;
if (state_allocator) {
size_t state_sizes_in_bytes = 0;
- status = cudnnDropoutGetStatesSize(cudnn.handle(), &state_sizes_in_bytes);
- CUDNN_RETURN_IF_FAIL(status, "Failed to query dropout state sizes");
-
- auto allocated =
- state_allocator->AllocateBytes(nullptr, state_sizes_in_bytes);
- if (!allocated.ok() ||
- (state_memory = allocated.ValueOrDie()) == nullptr) {
- string error_msg =
- port::StrCat("Failed to allocate Cudnn dropout state memory of ",
- state_sizes_in_bytes, " bytes.");
- status_ = port::Status(port::error::UNKNOWN, error_msg);
- LOG(ERROR) << error_msg;
- return;
- }
+ RETURN_IF_CUDNN_ERROR(
+ cudnnDropoutGetStatesSize(cudnn.handle(), &state_sizes_in_bytes));
+ SE_ASSIGN_OR_RETURN(state_memory, state_allocator->AllocateBytes(
+ nullptr, state_sizes_in_bytes));
}
- status = cudnnSetDropoutDescriptor(handle_, cudnn.handle(), dropout,
- state_memory.opaque(),
- state_memory.size(), seed);
- CUDNN_RETURN_IF_FAIL(
- status, port::StrCat(
- "Failed to set dropout descriptor with state memory size: ",
- state_memory.size(), " bytes."));
- }
+ RETURN_IF_CUDNN_ERROR(cudnnSetDropoutDescriptor(
+ handle.get(), cudnn.handle(), dropout, state_memory.opaque(),
+ state_memory.size(), seed));
- ~CudnnDropoutDescriptor() {
- cudnnStatus_t status = cudnnDestroyDropoutDescriptor(handle_);
- // TODO(csigg): This is a no-op (error is not reported). Same below.
- CUDNN_RETURN_IF_FAIL(status, "Failed to destroy Cudnn dropout handle: ");
+ return ScopedDropoutDescriptor(std::move(handle));
}
- cudnnDropoutDescriptor_t handle() const {
- if (!ok()) return nullptr;
- return handle_;
- }
+ cudnnDropoutDescriptor_t handle() const { return handle_.get(); }
private:
- cudnnDropoutDescriptor_t handle_; // Owned.
- float dropout_;
- uint64 seed_;
- SE_DISALLOW_COPY_AND_ASSIGN(CudnnDropoutDescriptor);
+ DropoutDescriptor handle_; // Owned.
+ SE_DISALLOW_COPY_AND_ASSIGN(ScopedDropoutDescriptor);
};
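[Editorial aside: ScopedDropoutDescriptor establishes the construction pattern used throughout the rest of this change: the constructor is private and infallible, all fallible work happens in a static Create returning port::StatusOr, and the type is move-only, so a half-constructed object can never escape. A caller then looks like this (the enclosing function is hypothetical):]

port::Status ConfigureDropout(const CudnnHandle& cudnn,
                              ScratchAllocator* state_allocator) {
  SE_ASSIGN_OR_RETURN(
      ScopedDropoutDescriptor dropout_desc,
      ScopedDropoutDescriptor::Create(cudnn, /*dropout=*/0.5f, /*seed=*/42,
                                      state_allocator));
  // dropout_desc.handle() is valid from here on; cleanup is automatic.
  return port::Status::OK();
}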
-class CudnnRnnParamsDescriptor : public CudnnDescriptorCommon<void> {
- public:
- typedef dnn::RnnDescriptor::ParamsRegion ParamsRegion;
+class CudnnRnnParamsDescriptor {
typedef dnn::RnnDescriptor::ParamsRegions ParamsRegions;
- CudnnRnnParamsDescriptor(const CudnnHandle& cudnn,
- const CudnnRnnDescriptor& rnn_desc);
- ~CudnnRnnParamsDescriptor() {
- cudnnStatus_t status = cudnnDestroyFilterDescriptor(handle_);
- CUDNN_RETURN_IF_FAIL(status, "Failed to destroy RNN filter descriptor");
- }
- cudnnFilterDescriptor_t handle() const {
- if (!ok()) return nullptr;
- return handle_;
- }
+
+ CudnnRnnParamsDescriptor(FilterDescriptor handle, int64 params_size_in_bytes,
+ ParamsRegions weights, ParamsRegions biases)
+ : handle_(std::move(handle)),
+ params_size_in_bytes_(params_size_in_bytes),
+ weights_(std::move(weights)),
+ biases_(std::move(biases)) {}
+
+ public:
+ CudnnRnnParamsDescriptor(CudnnRnnParamsDescriptor&&) = default;
+
+ static port::StatusOr<CudnnRnnParamsDescriptor> Create(
+ const CudnnHandle& cudnn, int input_size, cudnnDataType_t data_type,
+ cudnnRNNDescriptor_t rnn_desc, cudnnRNNMode_t rnn_mode,
+ cudnnDirectionMode_t direction_mode, int num_layers);
+
+ cudnnFilterDescriptor_t handle() const { return handle_.get(); }
int64 params_size_in_bytes() const { return params_size_in_bytes_; }
ParamsRegions params_weights() const {
- if (!ok()) return ParamsRegions();
return weights_;
}
ParamsRegions params_biases() const {
- if (!ok()) return ParamsRegions();
return biases_;
}
private:
- int GetRegionCountPerLayer() const;
- cudnnFilterDescriptor_t handle_;
- const CudnnRnnDescriptor* rnn_desc_;
+ FilterDescriptor handle_;
int64 params_size_in_bytes_;
ParamsRegions weights_;
ParamsRegions biases_;
@@ -992,97 +964,90 @@ class CudnnRnnParamsDescriptor : public CudnnDescriptorCommon<void> {
} // namespace
-class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
- public:
- CudnnRnnDescriptor(const CudnnHandle& cudnn, int num_layers, int hidden_size,
- int input_size, int batch_size,
+class CudnnRnnDescriptor : public dnn::RnnDescriptor {
+ CudnnRnnDescriptor(const CudnnHandle& cudnn, cuda::RnnDescriptor rnn_desc,
+ PersistentRnnPlan rnn_plan, int num_layers,
+ int hidden_size, int input_size, int batch_size,
cudnnRNNInputMode_t input_mode,
cudnnDirectionMode_t direction_mode,
cudnnRNNMode_t rnn_mode, cudnnDataType_t data_type,
cudnnDataType_t compute_type,
const dnn::AlgorithmConfig& algorithm_config,
- float dropout, uint64 seed,
- ScratchAllocator* state_allocator)
- : rnn_desc_(nullptr),
+ ScopedDropoutDescriptor dropout_desc,
+ CudnnRnnParamsDescriptor params_desc)
+ : rnn_desc_(std::move(rnn_desc)),
+ rnn_plan_(std::move(rnn_plan)),
num_layers_(num_layers),
hidden_size_(hidden_size),
input_size_(input_size),
batch_size_(batch_size),
- rnn_plan_(nullptr),
+ rnn_algo_(ToCudnnRNNAlgo(algorithm_config.algorithm())),
input_mode_(input_mode),
direction_mode_(direction_mode),
rnn_mode_(rnn_mode),
data_type_(data_type),
compute_type_(compute_type),
- algorithm_config_(algorithm_config) {
- // Create the dropout handle.
- cudnn_dropout_desc_.reset(
- new CudnnDropoutDescriptor(cudnn, dropout, seed, state_allocator));
- if (!cudnn_dropout_desc_->ok()) {
- SetFailure(cudnn_dropout_desc_->Status());
- return;
- }
+ algorithm_config_(algorithm_config),
+ dropout_desc_(std::move(dropout_desc)),
+ params_desc_(std::move(params_desc)) {}
+
+ public:
+ CudnnRnnDescriptor(CudnnRnnDescriptor&& other) = default;
+
+ static port::StatusOr<CudnnRnnDescriptor> Create(
+ const CudnnHandle& cudnn, int num_layers, int hidden_size, int input_size,
+ int batch_size, cudnnRNNInputMode_t input_mode,
+ cudnnDirectionMode_t direction_mode, cudnnRNNMode_t rnn_mode,
+ cudnnDataType_t data_type, cudnnDataType_t compute_type,
+ const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed,
+ ScratchAllocator* state_allocator) {
+ SE_ASSIGN_OR_RETURN(
+ ScopedDropoutDescriptor dropout_desc,
+ ScopedDropoutDescriptor::Create(cudnn, dropout, seed, state_allocator));
+
+ cuda::RnnDescriptor rnn_desc = CreateRnnDescriptor();
+ cudnnRNNAlgo_t rnn_algo = ToCudnnRNNAlgo(algorithm_config.algorithm());
- // Create the RNN handle
- cudnnStatus_t status = cudnnCreateRNNDescriptor(&rnn_desc_);
- CUDNN_RETURN_IF_FAIL(status, "Unable to create RNN descriptor");
// TODO: allow the user to choose an algorithm.
- rnn_algo_ = ToCudnnRNNAlgo(algorithm_config_.algorithm());
- status = cudnnSetRNNDescriptor_v6(
- cudnn.handle(), /*rnnDesc=*/rnn_desc_, /*hiddenSize=*/hidden_size,
- /*numLayers=*/num_layers, /*dropoutDesc=*/dropout_handle(),
+ RETURN_IF_CUDNN_ERROR(cudnnSetRNNDescriptor_v6(
+ cudnn.handle(), /*rnnDesc=*/rnn_desc.get(), /*hiddenSize=*/hidden_size,
+ /*numLayers=*/num_layers, /*dropoutDesc=*/dropout_desc.handle(),
/*inputMode=*/input_mode, /*direction=*/direction_mode,
- /*mode=*/rnn_mode, /*algo=*/rnn_algo_, /*dataType=*/compute_type);
- CUDNN_RETURN_IF_FAIL(status, ::tensorflow::strings::Printf(
- "Unable to update RNN descriptor with "
- "algo_id: %d and compute_type: %d",
- static_cast<int>(rnn_algo_),
- static_cast<int>(compute_type)));
-
- if (rnn_algo_ == CUDNN_RNN_ALGO_PERSIST_DYNAMIC) {
- CHECK_GE(batch_size_, 0);
- status = cudnnCreatePersistentRNNPlan(rnn_desc_, batch_size_, data_type_,
- &rnn_plan_);
- CUDNN_RETURN_IF_FAIL(status, "Unable to create persistent RNN plan.");
- status = cudnnSetPersistentRNNPlan(rnn_desc_, rnn_plan_);
- CUDNN_RETURN_IF_FAIL(status, "Unable to update persistent RNN plan.");
+ /*mode=*/rnn_mode, /*algo=*/rnn_algo,
+ /*dataType=*/compute_type));
+
+ PersistentRnnPlan rnn_plan;
+ if (rnn_algo == CUDNN_RNN_ALGO_PERSIST_DYNAMIC) {
+ CHECK_GE(batch_size, 0);
+ rnn_plan = CreatePersistentRnnPlan(rnn_desc.get(), batch_size, data_type);
+ RETURN_IF_CUDNN_ERROR(
+ cudnnSetPersistentRNNPlan(rnn_desc.get(), rnn_plan.get()));
}
// Create the params handle.
- cudnn_params_desc_.reset(new CudnnRnnParamsDescriptor(cudnn, *this));
- if (!cudnn_params_desc_->ok()) {
- SetFailure(cudnn_params_desc_->Status());
- return;
- }
- set_use_tensor_op_math(algorithm_config_.algorithm().tensor_ops_enabled());
- }
- ~CudnnRnnDescriptor() override {
- if (rnn_desc_) {
- cudnnStatus_t status;
- if (rnn_algo_ == CUDNN_RNN_ALGO_PERSIST_DYNAMIC && rnn_plan_) {
- status = cudnnDestroyPersistentRNNPlan(rnn_plan_);
- CUDNN_RETURN_IF_FAIL(status, "Unable to destroy persistent RNN plan.");
- }
- status = cudnnDestroyRNNDescriptor(rnn_desc_);
- CUDNN_RETURN_IF_FAIL(status, "Unable to destroy RNN descriptor");
- }
- }
- void set_use_tensor_op_math(bool use_tensor_op_math) {
+ SE_ASSIGN_OR_RETURN(auto params_desc,
+ CudnnRnnParamsDescriptor::Create(
+ cudnn, input_size, data_type, rnn_desc.get(),
+ rnn_mode, direction_mode, num_layers));
+
#if CUDNN_VERSION >= 7000
- cudnnMathType_t math_type =
- (use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH);
if (RnnTensorOpMathEnabled()) {
- cudnnStatus_t status = cudnnSetRNNMatrixMathType(rnn_desc_, math_type);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "could not set cudnn RNN math type: " << ToString(status);
- }
+ cudnnMathType_t math_type =
+ algorithm_config.algorithm().tensor_ops_enabled()
+ ? CUDNN_TENSOR_OP_MATH
+ : CUDNN_DEFAULT_MATH;
+ CHECK_CUDNN_OK(cudnnSetRNNMatrixMathType(rnn_desc.get(), math_type));
}
#endif
+
+ return CudnnRnnDescriptor(cudnn, std::move(rnn_desc), std::move(rnn_plan),
+ num_layers, hidden_size, input_size, batch_size,
+ input_mode, direction_mode, rnn_mode, data_type,
+ compute_type, algorithm_config,
+ std::move(dropout_desc), std::move(params_desc));
}
- cudnnRNNDescriptor_t handle() const {
- if (!ok()) return nullptr;
- return rnn_desc_;
- }
+
+ cudnnRNNDescriptor_t handle() const { return rnn_desc_.get(); }
int num_layers() const { return num_layers_; }
int hidden_size() const { return hidden_size_; }
int input_size() const { return input_size_; }
@@ -1096,27 +1061,21 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
return algorithm_config_;
}
int64 ParamsSizeInBytes() const override {
- return cudnn_params_desc_->params_size_in_bytes();
- }
- cudnnDropoutDescriptor_t dropout_handle() const {
- if (!cudnn_dropout_desc_) return nullptr;
- return cudnn_dropout_desc_->handle();
+ return params_desc_.params_size_in_bytes();
}
cudnnFilterDescriptor_t params_handle() const {
- if (!cudnn_params_desc_) return nullptr;
- return cudnn_params_desc_->handle();
+ return params_desc_.handle();
}
ParamsRegions ParamsWeightRegions() const override {
- if (!ok()) return ParamsRegions();
- return cudnn_params_desc_->params_weights();
+ return params_desc_.params_weights();
}
ParamsRegions ParamsBiasRegions() const override {
- if (!ok()) return ParamsRegions();
- return cudnn_params_desc_->params_biases();
+ return params_desc_.params_biases();
}
private:
- cudnnRNNDescriptor_t rnn_desc_;
+ cuda::RnnDescriptor rnn_desc_;
+ PersistentRnnPlan rnn_plan_;
int num_layers_;
int hidden_size_;
int input_size_;
@@ -1124,180 +1083,142 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
// algorithm.
int batch_size_;
cudnnRNNAlgo_t rnn_algo_;
- cudnnPersistentRNNPlan_t rnn_plan_;
cudnnRNNInputMode_t input_mode_;
cudnnDirectionMode_t direction_mode_;
cudnnRNNMode_t rnn_mode_;
cudnnDataType_t data_type_;
cudnnDataType_t compute_type_;
dnn::AlgorithmConfig algorithm_config_;
- std::unique_ptr<CudnnDropoutDescriptor> cudnn_dropout_desc_;
- std::unique_ptr<CudnnRnnParamsDescriptor> cudnn_params_desc_;
+ ScopedDropoutDescriptor dropout_desc_;
+ CudnnRnnParamsDescriptor params_desc_;
SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnDescriptor);
};
namespace {
-CudnnRnnParamsDescriptor::CudnnRnnParamsDescriptor(
- const CudnnHandle& cudnn, const CudnnRnnDescriptor& rnn_desc)
- : handle_(nullptr), rnn_desc_(&rnn_desc), params_size_in_bytes_(0) {
- cudnnTensorDescriptor_t input_desc = nullptr;
- {
- // Query the params size.
- auto status = cudnnCreateTensorDescriptor(&input_desc);
- CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to create tensor descriptor");
- int dims[] = {1, rnn_desc.input_size(), 1};
- int strides[] = {dims[1] * dims[2], dims[2], 1};
- status = cudnnSetTensorNdDescriptor(
- /*tensorDesc=*/input_desc, /*dataType=*/rnn_desc.data_type(),
- /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
- /*strideA=*/strides);
- CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to set tensor descriptor");
-
- size_t params_size = 0;
- status = cudnnGetRNNParamsSize(
- /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
- /*xDesc=*/input_desc, /*sizeInBytes=*/&params_size,
- /*dataType=*/rnn_desc.data_type());
- CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to get RNN parameter size");
- params_size_in_bytes_ = static_cast<int64>(params_size);
- }
-
- {
- // Create the params descriptor.
- auto status = cudnnCreateFilterDescriptor(&handle_);
- CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to create RNN filter descriptor");
- int dims[] = {static_cast<int>(params_size_in_bytes_), 1, 1};
- status = cudnnSetFilterNdDescriptor(
- /*filterDesc=*/handle_, /*dataType=*/rnn_desc.data_type(),
- /*format=*/CUDNN_TENSOR_NCHW, /*nbDims=*/sizeof(dims) / sizeof(dims[0]),
- /*filterDimA=*/dims);
- CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to update RNN filter descriptor");
- }
+port::StatusOr<CudnnRnnParamsDescriptor> CudnnRnnParamsDescriptor::Create(
+ const CudnnHandle& cudnn, int input_size, cudnnDataType_t data_type,
+ cudnnRNNDescriptor_t rnn_desc, cudnnRNNMode_t rnn_mode,
+ cudnnDirectionMode_t direction_mode, int num_layers) {
+ // Query the params size.
+ TensorDescriptor input_desc = CreateTensorDescriptor();
+ int tensor_dims[] = {1, input_size, 1};
+ int strides[] = {tensor_dims[1] * tensor_dims[2], tensor_dims[2], 1};
+ RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
+ /*tensorDesc=*/input_desc.get(), /*dataType=*/data_type,
+ /*nbDims=*/sizeof(tensor_dims) / sizeof(tensor_dims[0]),
+ /*dimA=*/tensor_dims,
+ /*strideA=*/strides));
+
+ size_t params_size = 0;
+ RETURN_IF_CUDNN_ERROR(cudnnGetRNNParamsSize(
+ /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
+ /*xDesc=*/input_desc.get(), /*sizeInBytes=*/&params_size,
+ /*dataType=*/data_type));
+ int64 params_size_in_bytes = static_cast<int64>(params_size);
+
+ FilterDescriptor filter_desc = CreateFilterDescriptor();
+ int filter_dims[] = {static_cast<int>(params_size_in_bytes), 1, 1};
+ RETURN_IF_CUDNN_ERROR(cudnnSetFilterNdDescriptor(
+ /*filterDesc=*/filter_desc.get(), /*dataType=*/data_type,
+ /*format=*/CUDNN_TENSOR_NCHW,
+ /*nbDims=*/sizeof(filter_dims) / sizeof(filter_dims[0]),
+ /*filterDimA=*/filter_dims));
+
+ // Create the weights and biases into the params buffer
+ int region_count_per_layer = [&] {
+ switch (rnn_mode) {
+ case CUDNN_RNN_RELU:
+ case CUDNN_RNN_TANH:
+ return 2;
+ case CUDNN_LSTM:
+ return 8;
+ case CUDNN_GRU:
+ return 6;
+ default:
+ LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode);
+ return 0;
+ }
+ }();
- {
- // Create the weights and biases into the params buffer
- int region_count_per_layer = GetRegionCountPerLayer();
- cudnnFilterDescriptor_t region_desc_handle = nullptr;
- auto status = cudnnCreateFilterDescriptor(&region_desc_handle);
- CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to create filter descriptor");
- const int layer_count = rnn_desc.direction_mode() == CUDNN_UNIDIRECTIONAL
- ? rnn_desc.num_layers()
- : 2 * rnn_desc.num_layers();
- for (int layer = 0; layer < layer_count; layer++) {
- for (int region = 0; region < region_count_per_layer; region++) {
- for (int type = 0; type < 2; type++) {
- void* offset = nullptr;
- if (type == 0) {
- status = cudnnGetRNNLinLayerMatrixParams(
- /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
- /*layer=*/layer, /*xDesc=*/input_desc, /*wDesc=*/handle_,
- /*w=*/nullptr, /*linLayerID=*/region,
- /*linLayerMatDesc=*/region_desc_handle,
- /*linLayerMat=*/&offset);
- CUDNN_RETURN_IF_FAIL(
- status, "Cudnn fails to call cudnnGetRNNLinLayerMatrixParams");
- } else {
- status = cudnnGetRNNLinLayerBiasParams(
- /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
- /*layer=*/layer, /*xDesc=*/input_desc, /*wDesc=*/handle_,
- /*w=*/nullptr, /*linLayerID=*/region,
- /*linLayerBiasDesc=*/region_desc_handle,
- /*linLayerBias=*/&offset);
- CUDNN_RETURN_IF_FAIL(
- status, "Cudnn fails to call cudnnGetRNNLinLayerBiasParams");
- }
- int dims[] = {1, 1, 1};
- cudnnDataType_t data_type;
- cudnnTensorFormat_t tensor_format;
- int n_dims;
- status = cudnnGetFilterNdDescriptor(
- /*filterDesc=*/region_desc_handle,
- /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]),
- /*dataType=*/&data_type, /*format=*/&tensor_format,
- /*nbDims=*/&n_dims, /*filterDimA=*/dims);
- CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to get filter description");
- int64 size = dims[0] * dims[1] * dims[2] *
- CudnnDataTypeToByteSize(rnn_desc.data_type());
- ParamsRegion region = {reinterpret_cast<int64>(offset), size};
- if (type == 0) {
- weights_.push_back(region);
- } else {
- biases_.push_back(region);
- }
- }
+ FilterDescriptor region_desc_handle = CreateFilterDescriptor();
+ const int layer_count =
+ direction_mode == CUDNN_UNIDIRECTIONAL ? num_layers : 2 * num_layers;
+
+ ParamsRegions weights;
+ ParamsRegions biases;
+
+ for (int layer = 0; layer < layer_count; layer++) {
+ for (int region = 0; region < region_count_per_layer; region++) {
+ for (int type = 0; type < 2; type++) {
+ void* offset = nullptr;
+ RETURN_IF_CUDNN_ERROR((type == 0 ? cudnnGetRNNLinLayerMatrixParams
+ : cudnnGetRNNLinLayerBiasParams)(
+ /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
+ /*layer=*/layer, /*xDesc=*/input_desc.get(),
+ /*wDesc=*/filter_desc.get(),
+ /*w=*/nullptr, /*linLayerID=*/region,
+ /*linLayerMatDesc=*/region_desc_handle.get(),
+ /*linLayerMat or linLayerBias=*/&offset));
+ int dims[] = {1, 1, 1};
+ cudnnDataType_t data_type;
+ cudnnTensorFormat_t tensor_format;
+ int n_dims;
+ RETURN_IF_CUDNN_ERROR(cudnnGetFilterNdDescriptor(
+ /*filterDesc=*/region_desc_handle.get(),
+ /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]),
+ /*dataType=*/&data_type, /*format=*/&tensor_format,
+ /*nbDims=*/&n_dims, /*filterDimA=*/dims));
+ int64 size =
+ dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type);
+ dnn::RnnDescriptor::ParamsRegion region = {
+ reinterpret_cast<int64>(offset), size};
+ (type == 0 ? weights : biases).push_back(region);
}
}
- status = cudnnDestroyFilterDescriptor(region_desc_handle);
- CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to destroy filter descriptor");
}
- {
- // Release the dummy input tensor descriptor.
- auto status = cudnnDestroyTensorDescriptor(input_desc);
- CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to destroy tensor descriptor");
- }
-}
-
-int CudnnRnnParamsDescriptor::GetRegionCountPerLayer() const {
- auto rnn_mode = rnn_desc_->rnn_mode();
- switch (rnn_mode) {
- case CUDNN_RNN_RELU:
- case CUDNN_RNN_TANH:
- return 2;
- case CUDNN_LSTM:
- return 8;
- case CUDNN_GRU:
- return 6;
- default:
- LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode);
- }
+ return CudnnRnnParamsDescriptor(std::move(filter_desc), params_size_in_bytes,
+ weights, biases);
}
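[Editorial aside: the region counts inside the lambda follow from cuDNN's parameter layout. Per layer, every gate of a cell contributes one input-projection matrix and one recurrent matrix (the inner `type` loop enumerates the matching biases separately), so:]

  CUDNN_RNN_RELU, CUDNN_RNN_TANH : 1 gate  -> 2 regions (W, R)
  CUDNN_GRU                      : 3 gates -> 6 regions
  CUDNN_LSTM                     : 4 gates -> 8 regions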
} // namespace
class CudnnRnnSequenceTensorDescriptor
- : public CudnnDescriptorCommon<dnn::RnnSequenceTensorDescriptor> {
- public:
+ : public dnn::RnnSequenceTensorDescriptor {
CudnnRnnSequenceTensorDescriptor(CUDAExecutor* parent, int seq_length,
int batch_size, int data_size,
- cudnnDataType_t data_type)
+ cudnnDataType_t data_type,
+ TensorDescriptor handle)
: parent_(parent),
seq_length_(seq_length),
batch_size_(batch_size),
data_size_(data_size),
- data_type_(data_type) {
- cudnnTensorDescriptor_t handle = nullptr;
- if (seq_length <= 0) {
- string error_msg =
- port::StrCat("sequence length must be positive: ", seq_length);
- LOG(ERROR) << error_msg;
- SetFailure(port::Status(port::error::UNKNOWN, error_msg));
- return;
- }
- cudnnStatus_t status = cudnnCreateTensorDescriptor(&handle);
- CUDNN_RETURN_IF_FAIL(status, "Failed to create tensor descriptor");
+ data_type_(data_type),
+ handle_(std::move(handle)),
+ handles_(seq_length, handle_.get()) {}
+
+ public:
+ CudnnRnnSequenceTensorDescriptor(CudnnRnnSequenceTensorDescriptor&&) =
+ default;
+
+ static port::StatusOr<CudnnRnnSequenceTensorDescriptor> Create(
+ CUDAExecutor* parent, int seq_length, int batch_size, int data_size,
+ cudnnDataType_t data_type) {
+ CHECK_GT(seq_length, 0);
int dims[] = {batch_size, data_size, 1};
int strides[] = {dims[1] * dims[2], dims[2], 1};
- status = cudnnSetTensorNdDescriptor(
- /*tensorDesc=*/handle, /*dataType=*/data_type,
+ TensorDescriptor tensor_desc = CreateTensorDescriptor();
+ RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
+ /*tensorDesc=*/tensor_desc.get(), /*dataType=*/data_type,
/*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
- /*strideA=*/strides);
- CUDNN_RETURN_IF_FAIL(status, "Failed to update tensor descriptor");
- // Replicate handle across the number of steps.
- handles_.assign(seq_length, handle);
- }
-
- ~CudnnRnnSequenceTensorDescriptor() override {
- // Only the first one needs to be destroyed. All others are the same.
- cudnnStatus_t status = cudnnDestroyTensorDescriptor(handles_[0]);
- CUDNN_RETURN_IF_FAIL(status,
- "Failed to destroy sequence tensor descriptor");
+ /*strideA=*/strides));
+ return CudnnRnnSequenceTensorDescriptor(parent, seq_length, batch_size,
+ data_size, data_type,
+ std::move(tensor_desc));
}
const cudnnTensorDescriptor_t* handles() const {
- if (!ok()) return nullptr;
- CHECK(!handles_.empty()) << "handles cannot be empty";
return handles_.data();
}
@@ -1311,51 +1232,39 @@ class CudnnRnnSequenceTensorDescriptor
int batch_size_;
int data_size_;
cudnnDataType_t data_type_;
- std::vector<cudnnTensorDescriptor_t> handles_;
+ TensorDescriptor handle_;
+ std::vector<cudnnTensorDescriptor_t> handles_; // Copies of handle_.
SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnSequenceTensorDescriptor);
};
-class CudnnRnnStateTensorDescriptor
- : public CudnnDescriptorCommon<dnn::RnnStateTensorDescriptor> {
+class CudnnRnnStateTensorDescriptor : public dnn::RnnStateTensorDescriptor {
public:
CudnnRnnStateTensorDescriptor(CUDAExecutor* parent, int num_layers,
int batch_size, int data_size,
cudnnDataType_t data_type)
: parent_(parent),
- handle_(nullptr),
+ handle_(CreateTensorDescriptor()),
num_layers_(num_layers),
batch_size_(batch_size),
data_size_(data_size),
data_type_(data_type) {
- cudnnStatus_t status = cudnnCreateTensorDescriptor(&handle_);
- CUDNN_RETURN_IF_FAIL(status, "Failed to create tensor descriptor");
int dims[] = {num_layers, batch_size, data_size};
int strides[] = {dims[1] * dims[2], dims[2], 1};
- status = cudnnSetTensorNdDescriptor(
- /*tensorDesc=*/handle_, /*dataType=*/data_type,
+ CHECK_CUDNN_OK(cudnnSetTensorNdDescriptor(
+ /*tensorDesc=*/handle_.get(), /*dataType=*/data_type,
/*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
- /*strideA=*/strides);
- CUDNN_RETURN_IF_FAIL(status, "Failed to update tensor descriptor");
+ /*strideA=*/strides));
}
- ~CudnnRnnStateTensorDescriptor() override {
- if (!handle_) {
- cudnnStatus_t status = cudnnDestroyTensorDescriptor(handle_);
- CUDNN_RETURN_IF_FAIL(status, "Unable to destroy RNN state tensor");
- }
- }
+ cudnnTensorDescriptor_t handle() const { return handle_.get(); }
- cudnnTensorDescriptor_t handle() const {
- if (!ok()) return nullptr;
- return handle_;
- }
int num_layers() const { return num_layers_; }
int batch_size() const { return batch_size_; }
int data_size() const { return data_size_; }
private:
CUDAExecutor* parent_;
- cudnnTensorDescriptor_t handle_;
+ TensorDescriptor handle_;
int num_layers_;
int batch_size_;
int data_size_;
@@ -1375,7 +1284,7 @@ struct RnnModelDims {
};
template <class T>
-bool ExtractAndCheckRnnForward(
+port::StatusOr<RnnModelDims> ExtractAndCheckRnnForward(
const CudnnRnnDescriptor& rnn_desc,
const CudnnRnnSequenceTensorDescriptor& input_desc,
const DeviceMemory<T>& input_data,
@@ -1388,103 +1297,89 @@ bool ExtractAndCheckRnnForward(
const CudnnRnnStateTensorDescriptor& output_h_desc,
const DeviceMemory<T>& output_h_data,
const CudnnRnnStateTensorDescriptor& output_c_desc,
- const DeviceMemory<T>& output_c_data, RnnModelDims* model_dims) {
+ const DeviceMemory<T>& output_c_data) {
// extract model parameters
- model_dims->num_layers = rnn_desc.num_layers();
- model_dims->batch_size = input_desc.batch_size();
- model_dims->seq_length = input_desc.seq_length();
- model_dims->hidden_size = rnn_desc.hidden_size();
- model_dims->input_size = input_desc.data_size();
- model_dims->dir_count =
+ RnnModelDims model_dims;
+ model_dims.num_layers = rnn_desc.num_layers();
+ model_dims.batch_size = input_desc.batch_size();
+ model_dims.seq_length = input_desc.seq_length();
+ model_dims.hidden_size = rnn_desc.hidden_size();
+ model_dims.input_size = input_desc.data_size();
+ model_dims.dir_count =
(rnn_desc.direction_mode() == CUDNN_BIDIRECTIONAL) ? 2 : 1;
// check parameters
if (!(input_h_desc.num_layers() ==
- model_dims->num_layers * model_dims->dir_count &&
- input_h_desc.batch_size() == model_dims->batch_size &&
- input_h_desc.data_size() == model_dims->hidden_size)) {
- LOG(ERROR) << "Invalid input_h shape";
- return false;
+ model_dims.num_layers * model_dims.dir_count &&
+ input_h_desc.batch_size() == model_dims.batch_size &&
+ input_h_desc.data_size() == model_dims.hidden_size)) {
+ return port::Status(port::error::INVALID_ARGUMENT, "Invalid input_h shape");
}
if (!(input_h_desc.num_layers() == input_c_desc.num_layers() &&
input_h_desc.batch_size() == input_c_desc.batch_size() &&
input_h_desc.data_size() == input_c_desc.data_size())) {
- LOG(ERROR) << "Invalid input_c shape";
- return false;
+ return port::Status(port::error::INVALID_ARGUMENT, "Invalid input_c shape");
}
- if (!(output_desc.seq_length() == model_dims->seq_length &&
- output_desc.batch_size() == model_dims->batch_size &&
+ if (!(output_desc.seq_length() == model_dims.seq_length &&
+ output_desc.batch_size() == model_dims.batch_size &&
output_desc.data_size() ==
- model_dims->hidden_size * model_dims->dir_count)) {
- LOG(ERROR) << "Invalid output shape";
- return false;
+ model_dims.hidden_size * model_dims.dir_count)) {
+ return port::Status(port::error::INVALID_ARGUMENT, "Invalid output shape");
}
if (!(input_h_desc.num_layers() == output_h_desc.num_layers() &&
input_h_desc.batch_size() == output_h_desc.batch_size() &&
input_h_desc.data_size() == output_h_desc.data_size())) {
- LOG(ERROR) << "Invalid output_h shape";
- return false;
+ return port::Status(port::error::INVALID_ARGUMENT,
+ "Invalid output_h shape");
}
if (!(input_h_desc.num_layers() == output_c_desc.num_layers() &&
input_h_desc.batch_size() == output_c_desc.batch_size() &&
input_h_desc.data_size() == output_c_desc.data_size())) {
- LOG(ERROR) << "Invalid output_h shape";
- return false;
+ return port::Status(port::error::INVALID_ARGUMENT,
+ "Invalid output_c shape");
}
- return true;
+ return model_dims;
}
-bool CheckRNNParameterSize(const CudnnHandle& cudnn,
- const CudnnRnnDescriptor& rnn_desc,
- const CudnnRnnSequenceTensorDescriptor& input_desc) {
+port::Status CheckRNNParameterSize(
+ const CudnnHandle& cudnn, const CudnnRnnDescriptor& rnn_desc,
+ const CudnnRnnSequenceTensorDescriptor& input_desc) {
size_t params_size_in_bytes = 0;
- cudnnStatus_t status = cudnnGetRNNParamsSize(
+ RETURN_IF_CUDNN_ERROR(cudnnGetRNNParamsSize(
/*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
/*xDesc=*/input_desc.handles()[0], /*sizeInBytes=*/&params_size_in_bytes,
- /*dataType=*/rnn_desc.data_type());
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "Unable to check RNN param size: " << ToString(status);
- return false;
+ /*dataType=*/rnn_desc.data_type()));
+ if (static_cast<int64>(params_size_in_bytes) !=
+ rnn_desc.ParamsSizeInBytes()) {
+ return port::Status(port::error::INVALID_ARGUMENT,
+ "Mismatching RNN parameter size");
}
- return static_cast<int64>(params_size_in_bytes) ==
- rnn_desc.ParamsSizeInBytes();
+ return port::Status::OK();
}
-bool CreateRnnWorkspace(Stream* stream, const CudnnHandle& cudnn,
- const CudnnRnnDescriptor& rnn_desc,
- const CudnnRnnSequenceTensorDescriptor& input_desc,
- ScratchAllocator* workspace_allocator,
- DeviceMemory<uint8>* workspace) {
+port::StatusOr<DeviceMemory<uint8>> CreateRnnWorkspace(
+ Stream* stream, const CudnnHandle& cudnn,
+ const CudnnRnnDescriptor& rnn_desc,
+ const CudnnRnnSequenceTensorDescriptor& input_desc,
+ ScratchAllocator* workspace_allocator) {
// Query the workspace size.
size_t workspace_size_in_bytes = 0;
- cudnnStatus_t status = cudnnGetRNNWorkspaceSize(
+ RETURN_IF_CUDNN_ERROR(cudnnGetRNNWorkspaceSize(
/*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
/*seqLength=*/input_desc.seq_length(), /*xDesc=*/input_desc.handles(),
- /*sizeInBytes=*/&workspace_size_in_bytes);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "Unable to query workspace size: " << ToString(status);
- return false;
- }
+ /*sizeInBytes=*/&workspace_size_in_bytes));
// Allocate the workspace.
- if (workspace_size_in_bytes > 0) {
- auto allocated =
- workspace_allocator->AllocateBytes(stream, workspace_size_in_bytes);
- if (!allocated.ok() || (*workspace = allocated.ValueOrDie()) == nullptr) {
- LOG(ERROR) << port::StrCat("Failed to allocate RNN workspace of ",
- workspace_size_in_bytes, " bytes.");
- return false;
- }
- } else {
- *workspace = DeviceMemory<uint8>();
+ if (workspace_size_in_bytes == 0) {
+ return DeviceMemory<uint8>();
}
- return true;
+ return workspace_allocator->AllocateBytes(stream, workspace_size_in_bytes);
}
} // namespace
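+
+// The error-handling convention used throughout this file, as a minimal
+// sketch (the names DoThing/DoThingImpl/MakeWorkspace/cudnnThing are
+// illustrative only): templated *Impl methods return port::Status, composing
+// cuDNN calls via RETURN_IF_CUDNN_ERROR and port::StatusOr values via
+// SE_ASSIGN_OR_RETURN, while the public bool overrides convert at the
+// boundary with IsStatusOk:
+//
+//   port::Status CudnnSupport::DoThingImpl(Stream* stream, ...) {
+//     SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace, MakeWorkspace(...));
+//     RETURN_IF_CUDNN_ERROR(cudnnThing(..., workspace.opaque()));
+//     return port::Status::OK();
+//   }
+//   bool CudnnSupport::DoThing(Stream* stream, ...) {
+//     return IsStatusOk(DoThingImpl(stream, ...), /*report_error=*/true);
+//   }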
template <class T>
-bool CudnnSupport::DoRnnForwardImpl(
+port::Status CudnnSupport::DoRnnForwardImpl(
Stream* stream, const CudnnRnnDescriptor& rnn_desc,
const CudnnRnnSequenceTensorDescriptor& input_desc,
const DeviceMemory<T>& input_data,
@@ -1501,57 +1396,34 @@ bool CudnnSupport::DoRnnForwardImpl(
ScratchAllocator* reserve_space_allocator,
ScratchAllocator* workspace_allocator,
dnn::ProfileResult* output_profile_result) {
- // extract model parameters
- RnnModelDims model_dims;
- bool res = ExtractAndCheckRnnForward(
- rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
- input_c_desc, input_c_data, params, output_desc, *output_data,
- output_h_desc, *output_h_data, output_c_desc, *output_c_data,
- &model_dims);
- if (!res) {
- LOG(ERROR) << "Invalid parameters for RNN Model";
- return false;
- }
+ SE_ASSIGN_OR_RETURN(
+ RnnModelDims model_dims,
+ ExtractAndCheckRnnForward(
+ rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
+ input_c_desc, input_c_data, params, output_desc, *output_data,
+ output_h_desc, *output_h_data, output_c_desc, *output_c_data));
auto cudnn = cudnn_->GetHandle(parent_, stream);
- // check params size
- if (!CheckRNNParameterSize(cudnn, rnn_desc, input_desc)) {
- LOG(ERROR) << "Invalid parameters";
- return false;
- }
-
- // create the workspace
- DeviceMemory<uint8> workspace;
- if (!CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc,
- workspace_allocator, &workspace)) {
- LOG(ERROR) << "Unable to create rnn workspace";
- return false;
- }
+ SE_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc));
+ SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace,
+ CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc,
+                                        workspace_allocator));
// query the reserve space size
// allocate the reserve space
DeviceMemory<uint8> reserve_space;
if (is_training) {
size_t reserve_space_size_in_bytes = 0;
- cudnnStatus_t status = cudnnGetRNNTrainingReserveSize(
+ RETURN_IF_CUDNN_ERROR(cudnnGetRNNTrainingReserveSize(
/*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
/*seqLength=*/model_dims.seq_length, /*xDesc=*/input_desc.handles(),
- /*sizeInBytes=*/&reserve_space_size_in_bytes);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "Unable to query reserve space size: " << ToString(status);
- return false;
- }
+ /*sizeInBytes=*/&reserve_space_size_in_bytes));
if (reserve_space_size_in_bytes > 0) {
- auto allocated = reserve_space_allocator->AllocateBytes(
- stream, reserve_space_size_in_bytes);
- if (!allocated.ok() ||
- (reserve_space = allocated.ValueOrDie()) == nullptr) {
- LOG(ERROR) << "Failed to allocate RNN reserve space of "
- << reserve_space_size_in_bytes << " bytes.";
- return false;
- }
+ SE_ASSIGN_OR_RETURN(reserve_space,
+ reserve_space_allocator->AllocateBytes(
+ stream, reserve_space_size_in_bytes));
}
}
@@ -1559,20 +1431,16 @@ bool CudnnSupport::DoRnnForwardImpl(
const bool is_profiling = output_profile_result != nullptr;
if (is_profiling) {
timer.reset(new CUDATimer(parent_));
- if (!timer->Init()) {
- return false;
- }
// The start and stop of the timer should be as close to the Cudnn call as
// possible. It is still possible for other threads to issue workload on
// to this stream. So it could take multiple profiling measurements.
- if (!timer->Start(AsCUDAStream(stream))) {
- return false;
+ if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) {
+ return port::Status(port::error::INTERNAL, "Failed to start timer");
}
}
- // make the forward call
- cudnnStatus_t status;
+
if (!is_training) {
- status = cudnnRNNForwardInference(
+ RETURN_IF_CUDNN_ERROR(cudnnRNNForwardInference(
/*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
/*seqLength=*/model_dims.seq_length, /*xDesc=*/input_desc.handles(),
/*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
@@ -1582,9 +1450,9 @@ bool CudnnSupport::DoRnnForwardImpl(
/*y=*/output_data->opaque(), /*hyDesc=*/output_h_desc.handle(),
/*hy=*/output_h_data->opaque(), /*cyDesc=*/output_c_desc.handle(),
/*cy=*/output_c_data->opaque(), /*workspace=*/workspace.opaque(),
- /*workSpaceSizeInBytes=*/workspace.size());
+ /*workSpaceSizeInBytes=*/workspace.size()));
} else {
- status = cudnnRNNForwardTraining(
+ RETURN_IF_CUDNN_ERROR(cudnnRNNForwardTraining(
/*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
/*seqLength=*/model_dims.seq_length, /*xDesc=*/input_desc.handles(),
/*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
@@ -1596,35 +1464,24 @@ bool CudnnSupport::DoRnnForwardImpl(
/*cy=*/output_c_data->opaque(), /*workspace=*/workspace.opaque(),
/*workSpaceSizeInBytes=*/workspace.size(),
/*reserveSpace=*/reserve_space.opaque(),
- /*reserveSpaceSizeInBytes=*/reserve_space.size());
+ /*reserveSpaceSizeInBytes=*/reserve_space.size()));
}
+
if (is_profiling) {
if (!timer->Stop(AsCUDAStream(stream))) {
- return false;
- }
- if (status == CUDNN_STATUS_SUCCESS) {
- auto algo_desc = rnn_desc.algorithm_config().algorithm();
- output_profile_result->set_algorithm(algo_desc);
- output_profile_result->set_elapsed_time_in_ms(
- timer->GetElapsedMilliseconds());
- }
- }
- if (status != CUDNN_STATUS_SUCCESS) {
- // Silently return when we are profiling.
- if (!is_profiling) {
- LOG(ERROR) << "Failed to call "
- << (is_training ? "cudnnRNNForwardTraining "
- : "cudnnRNNForwardInference ")
- << ToString(status);
- return false;
+ return port::Status(port::error::INTERNAL, "Failed to stop timer");
}
+ auto algo_desc = rnn_desc.algorithm_config().algorithm();
+ output_profile_result->set_algorithm(algo_desc);
+ output_profile_result->set_elapsed_time_in_ms(
+ timer->GetElapsedMilliseconds());
}
- return true;
+ return port::Status::OK();
}
template <class T>
-bool CudnnSupport::DoRnnBackwardImpl(
+port::Status CudnnSupport::DoRnnBackwardImpl(
Stream* stream, const CudnnRnnDescriptor& rnn_desc,
const CudnnRnnSequenceTensorDescriptor& input_desc,
const DeviceMemory<T>& input_data,
@@ -1648,53 +1505,38 @@ bool CudnnSupport::DoRnnBackwardImpl(
DeviceMemory<uint8>* reserve_space_data,
ScratchAllocator* workspace_allocator,
dnn::ProfileResult* output_profile_result) {
- // extract model parameters
- RnnModelDims model_dims;
- bool res = ExtractAndCheckRnnForward(
- rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
- input_c_desc, input_c_data, params, output_desc, output_data,
- output_h_desc, output_h_data, output_c_desc, output_c_data, &model_dims);
- if (!res) {
- LOG(ERROR) << "Invalid parameters for RNN Model";
- return false;
- }
+ SE_ASSIGN_OR_RETURN(
+ RnnModelDims model_dims,
+ ExtractAndCheckRnnForward(rnn_desc, input_desc, input_data, input_h_desc,
+ input_h_data, input_c_desc, input_c_data,
+ params, output_desc, output_data, output_h_desc,
+ output_h_data, output_c_desc, output_c_data));
auto cudnn = cudnn_->GetHandle(parent_, stream);
- // check params size
- if (!CheckRNNParameterSize(cudnn, rnn_desc, input_desc)) {
- LOG(ERROR) << "Invalid parameters";
- return false;
- }
-
- // create the workspace
- DeviceMemory<uint8> workspace;
- if (!CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc,
- workspace_allocator, &workspace)) {
- LOG(ERROR) << "Unable to create rnn workspace";
- return false;
- }
+ SE_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc));
+ SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace,
+ CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc,
+ workspace_allocator));
std::unique_ptr<CUDATimer, TimerDeleter> timer;
const bool is_profiling = output_profile_result != nullptr;
if (is_profiling) {
timer.reset(new CUDATimer(parent_));
- if (!timer->Init()) {
- return false;
- }
// The start and stop of the timer should be as close to the Cudnn call as
// possible. It is still possible for other threads to issue workload on
// to this stream. So it could take multiple profiling measurements.
- if (!timer->Start(AsCUDAStream(stream))) {
- return false;
+ if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) {
+ return port::Status(port::error::INTERNAL, "Failed to start timer");
}
}
- // make the backward data call
- cudnnStatus_t status = cudnnRNNBackwardData(
+
+ RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardData(
/*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
/*seqLength=*/model_dims.seq_length, /*yDesc=*/output_desc.handles(),
/*y=*/output_data.opaque(), /*dyDesc=*/output_desc.handles(),
- /*dy=*/output_backprop_data.opaque(), /*dhyDesc=*/output_h_desc.handle(),
+ /*dy=*/output_backprop_data.opaque(),
+ /*dhyDesc=*/output_h_desc.handle(),
/*dhy=*/output_h_backprop_data.opaque(),
/*dcyDesc=*/output_c_desc.handle(),
/*dcy=*/output_c_backprop_data.opaque(),
@@ -1705,24 +1547,17 @@ bool CudnnSupport::DoRnnBackwardImpl(
/*dhxDesc=*/input_h_desc.handle(),
/*dhx=*/input_h_backprop_data->opaque(),
/*dcxDesc=*/input_c_desc.handle(),
- /*dcx=*/input_c_backprop_data->opaque(), /*workspace=*/workspace.opaque(),
+ /*dcx=*/input_c_backprop_data->opaque(),
+ /*workspace=*/workspace.opaque(),
/*workSpaceSizeInBytes=*/workspace.size(),
/*reserveSpace=*/reserve_space_data->opaque(),
- /*reserveSpaceSizeInBytes=*/reserve_space_data->size());
-
- if (status != CUDNN_STATUS_SUCCESS) {
- if (is_profiling) {
- timer->Stop(AsCUDAStream(stream));
- }
- LOG(ERROR) << "Failed to call cudnnRNNBackwardData: " << ToString(status);
- return false;
- }
+ /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
if (params_backprop_data != nullptr) {
// Clear the dw to zeros.
stream->ThenMemZero(params_backprop_data, params_backprop_data->size());
// make the backward weight call
- status = cudnnRNNBackwardWeights(
+ RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeights(
/*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
/*seqLength=*/model_dims.seq_length, /*xDesc=*/input_desc.handles(),
/*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
@@ -1732,19 +1567,12 @@ bool CudnnSupport::DoRnnBackwardImpl(
/*dwDesc=*/rnn_desc.params_handle(),
/*dw=*/params_backprop_data->opaque(),
/*reserveSpace=*/reserve_space_data->opaque(),
- /*reserveSpaceSizeInBytes=*/reserve_space_data->size());
- if (status != CUDNN_STATUS_SUCCESS) {
- if (is_profiling) {
- timer->Stop(AsCUDAStream(stream));
- }
- LOG(ERROR) << "Failed to call cudnnRNNBackwardWeights: "
- << ToString(status);
- return false;
- }
+ /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
}
+
if (is_profiling) {
if (!timer->Stop(AsCUDAStream(stream))) {
- return false;
+ return port::Status(port::error::INTERNAL, "Failed to stop timer");
}
auto algo_desc = rnn_desc.algorithm_config().algorithm();
output_profile_result->set_algorithm(algo_desc);
@@ -1752,7 +1580,7 @@ bool CudnnSupport::DoRnnBackwardImpl(
timer->GetElapsedMilliseconds());
}
- return true;
+ return port::Status::OK();
}
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
@@ -1765,46 +1593,37 @@ CudnnSupport::createRnnDescriptor(
// Setting up a cudnnRNNDescriptor requires a cuDNN handle, but because it's
// not enqueueing anything into a stream, we pass in the null stream.
auto cudnn = cudnn_->GetHandle(parent_, /*stream=*/nullptr);
- std::unique_ptr<CudnnRnnDescriptor> rnn_desc(new CudnnRnnDescriptor(
- cudnn, num_layers, hidden_size, input_size, batch_size,
- ToCudnnRnnInputMode(input_mode), ToCudnnRnnDirectionMode(direction_mode),
- ToCudnnRnnMode(rnn_mode), ToCudnnDataType(data_type),
- GetRnnComputeType(data_type), algorithm_config, dropout, seed,
- state_allocator));
- if (!rnn_desc->ok()) {
- return rnn_desc->Status();
- }
- return port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>(
- std::move(rnn_desc));
+ SE_ASSIGN_OR_RETURN(
+ CudnnRnnDescriptor rnn_desc,
+ CudnnRnnDescriptor::Create(
+ cudnn, num_layers, hidden_size, input_size, batch_size,
+ ToCudnnRnnInputMode(input_mode),
+ ToCudnnRnnDirectionMode(direction_mode), ToCudnnRnnMode(rnn_mode),
+ ToCudnnDataType(data_type), GetRnnComputeType(data_type),
+ algorithm_config, dropout, seed, state_allocator));
+ return std::unique_ptr<dnn::RnnDescriptor>(
+ new CudnnRnnDescriptor(std::move(rnn_desc)));
}
port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
CudnnSupport::createRnnSequenceTensorDescriptor(int seq_length, int batch_size,
int data_size,
dnn::DataType data_type) {
- std::unique_ptr<CudnnRnnSequenceTensorDescriptor> seq_desc(
- new CudnnRnnSequenceTensorDescriptor(parent_, seq_length, batch_size,
- data_size,
- ToCudnnDataType(data_type)));
- if (!seq_desc->ok()) {
- return seq_desc->Status();
- }
- return port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>(
- std::move(seq_desc));
+ SE_ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor,
+ CudnnRnnSequenceTensorDescriptor::Create(
+ parent_, seq_length, batch_size, data_size,
+ ToCudnnDataType(data_type)));
+ return std::unique_ptr<dnn::RnnSequenceTensorDescriptor>(
+ new CudnnRnnSequenceTensorDescriptor(std::move(descriptor)));
}
port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
CudnnSupport::createRnnStateTensorDescriptor(int num_layer, int batch_size,
int data_size,
dnn::DataType data_type) {
- std::unique_ptr<CudnnRnnStateTensorDescriptor> state_desc(
+ return std::unique_ptr<dnn::RnnStateTensorDescriptor>(
new CudnnRnnStateTensorDescriptor(parent_, num_layer, batch_size,
data_size, ToCudnnDataType(data_type)));
- if (!state_desc->ok()) {
- return state_desc->Status();
- }
- return port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>(
- std::move(state_desc));
}
bool CudnnSupport::DoRnnForward(
@@ -1840,12 +1659,14 @@ bool CudnnSupport::DoRnnForward(
const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
- return DoRnnForwardImpl<Eigen::half>(
- stream, cudnn_rnn_desc, cudnn_input_desc, input_data, cudnn_input_h_desc,
- input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
- output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
- output_c_data, is_training, reserve_space_allocator, workspace_allocator,
- output_profile_result);
+ return IsStatusOk(
+ DoRnnForwardImpl<Eigen::half>(
+ stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
+ cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
+ params, cudnn_output_desc, output_data, cudnn_output_h_desc,
+ output_h_data, cudnn_output_c_desc, output_c_data, is_training,
+ reserve_space_allocator, workspace_allocator, output_profile_result),
+ /*report_error=*/!output_profile_result);
}
bool CudnnSupport::DoRnnForward(
@@ -1880,12 +1701,14 @@ bool CudnnSupport::DoRnnForward(
const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
- return DoRnnForwardImpl<float>(
- stream, cudnn_rnn_desc, cudnn_input_desc, input_data, cudnn_input_h_desc,
- input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
- output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
- output_c_data, is_training, reserve_space_allocator, workspace_allocator,
- output_profile_result);
+ return IsStatusOk(
+ DoRnnForwardImpl<float>(
+ stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
+ cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
+ params, cudnn_output_desc, output_data, cudnn_output_h_desc,
+ output_h_data, cudnn_output_c_desc, output_c_data, is_training,
+ reserve_space_allocator, workspace_allocator, output_profile_result),
+ /*report_error=*/!output_profile_result);
}
bool CudnnSupport::DoRnnForward(
@@ -1921,12 +1744,14 @@ bool CudnnSupport::DoRnnForward(
const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
- return DoRnnForwardImpl<double>(
- stream, cudnn_rnn_desc, cudnn_input_desc, input_data, cudnn_input_h_desc,
- input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
- output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
- output_c_data, is_training, reserve_space_allocator, workspace_allocator,
- output_profile_result);
+ return IsStatusOk(
+ DoRnnForwardImpl<double>(
+ stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
+ cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
+ params, cudnn_output_desc, output_data, cudnn_output_h_desc,
+ output_h_data, cudnn_output_c_desc, output_c_data, is_training,
+ reserve_space_allocator, workspace_allocator, output_profile_result),
+ /*report_error=*/!output_profile_result);
}
bool CudnnSupport::DoRnnBackward(
@@ -1969,14 +1794,17 @@ bool CudnnSupport::DoRnnBackward(
const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
- return DoRnnBackwardImpl<Eigen::half>(
- stream, cudnn_rnn_desc, cudnn_input_desc, input_data, cudnn_input_h_desc,
- input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
- output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
- output_c_data, output_backprop_data, output_h_backprop_data,
- output_c_backprop_data, input_backprop_data, input_h_backprop_data,
- input_c_backprop_data, params_backprop_data, reserve_space_data,
- workspace_allocator, output_profile_result);
+ return IsStatusOk(
+ DoRnnBackwardImpl<Eigen::half>(
+ stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
+ cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
+ params, cudnn_output_desc, output_data, cudnn_output_h_desc,
+ output_h_data, cudnn_output_c_desc, output_c_data,
+ output_backprop_data, output_h_backprop_data, output_c_backprop_data,
+ input_backprop_data, input_h_backprop_data, input_c_backprop_data,
+ params_backprop_data, reserve_space_data, workspace_allocator,
+ output_profile_result),
+ /*report_error=*/!output_profile_result);
}
bool CudnnSupport::DoRnnBackward(
@@ -2018,14 +1846,17 @@ bool CudnnSupport::DoRnnBackward(
const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
- return DoRnnBackwardImpl<float>(
- stream, cudnn_rnn_desc, cudnn_input_desc, input_data, cudnn_input_h_desc,
- input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
- output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
- output_c_data, output_backprop_data, output_h_backprop_data,
- output_c_backprop_data, input_backprop_data, input_h_backprop_data,
- input_c_backprop_data, params_backprop_data, reserve_space_data,
- workspace_allocator, output_profile_result);
+ return IsStatusOk(
+ DoRnnBackwardImpl<float>(
+ stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
+ cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
+ params, cudnn_output_desc, output_data, cudnn_output_h_desc,
+ output_h_data, cudnn_output_c_desc, output_c_data,
+ output_backprop_data, output_h_backprop_data, output_c_backprop_data,
+ input_backprop_data, input_h_backprop_data, input_c_backprop_data,
+ params_backprop_data, reserve_space_data, workspace_allocator,
+ output_profile_result),
+ /*report_error=*/!output_profile_result);
}
bool CudnnSupport::DoRnnBackward(
@@ -2068,19 +1899,25 @@ bool CudnnSupport::DoRnnBackward(
const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
- return DoRnnBackwardImpl<double>(
- stream, cudnn_rnn_desc, cudnn_input_desc, input_data, cudnn_input_h_desc,
- input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
- output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
- output_c_data, output_backprop_data, output_h_backprop_data,
- output_c_backprop_data, input_backprop_data, input_h_backprop_data,
- input_c_backprop_data, params_backprop_data, reserve_space_data,
- workspace_allocator, output_profile_result);
+ return IsStatusOk(
+ DoRnnBackwardImpl<double>(
+ stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
+ cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
+ params, cudnn_output_desc, output_data, cudnn_output_h_desc,
+ output_h_data, cudnn_output_c_desc, output_c_data,
+ output_backprop_data, output_h_backprop_data, output_c_backprop_data,
+ input_backprop_data, input_h_backprop_data, input_c_backprop_data,
+ params_backprop_data, reserve_space_data, workspace_allocator,
+ output_profile_result),
+ /*report_error=*/!output_profile_result);
}
namespace {
-inline cudnnConvolutionFwdAlgo_t GetCudnnConvolutionForwardAlgo(
+// TODO(csigg): Merge a lot of duplicate code below for forward, backward data,
+// and backward filter.
+
+port::StatusOr<cudnnConvolutionFwdAlgo_t> GetCudnnConvolutionForwardAlgo(
const CudnnHandle& cudnn, const ScopedTensorDescriptor& input_nd,
const ScopedFilterDescriptor& filter,
const ScopedConvolutionDescriptor& conv,
@@ -2089,100 +1926,331 @@ inline cudnnConvolutionFwdAlgo_t GetCudnnConvolutionForwardAlgo(
cudnnConvolutionFwdPreference_t preference =
specify_workspace_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
: CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
-
cudnnConvolutionFwdAlgo_t algo_to_use;
- auto status = cudnnGetConvolutionForwardAlgorithm(
+ RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardAlgorithm(
cudnn.handle(), input_nd.handle(), filter.handle(), conv.handle(),
- output_nd.handle(), preference, memory_limit_bytes, &algo_to_use);
- CHECK_EQ(status, CUDNN_STATUS_SUCCESS)
- << "Unable to find a suitable algorithm for doing forward convolution";
+ output_nd.handle(), preference, memory_limit_bytes, &algo_to_use));
+ return algo_to_use;
+}
+
+port::StatusOr<cudnnConvolutionBwdDataAlgo_t>
+GetCudnnConvolutionBackwardDataAlgo(const CudnnHandle& cudnn,
+ const ScopedTensorDescriptor& input_nd,
+ const ScopedFilterDescriptor& filter,
+ const ScopedConvolutionDescriptor& conv,
+ const ScopedTensorDescriptor& output_nd,
+ bool specify_workspace_limit,
+ size_t memory_limit_bytes) {
+ cudnnConvolutionBwdDataPreference_t preference =
+ specify_workspace_limit
+ ? CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT
+ : CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE;
+ cudnnConvolutionBwdDataAlgo_t algo_to_use;
+ RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataAlgorithm(
+ cudnn.handle(), filter.handle(), output_nd.handle(), conv.handle(),
+ input_nd.handle(), preference, memory_limit_bytes, &algo_to_use));
+ return algo_to_use;
+}
+
+port::StatusOr<cudnnConvolutionBwdFilterAlgo_t>
+GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn,
+ const ScopedTensorDescriptor& input_nd,
+ const ScopedFilterDescriptor& filter,
+ const ScopedConvolutionDescriptor& conv,
+ const ScopedTensorDescriptor& output_nd,
+ bool specify_workspace_limit,
+ size_t memory_limit_bytes) {
+ cudnnConvolutionBwdFilterPreference_t preference =
+ specify_workspace_limit
+ ? CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT
+ : CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE;
+ cudnnConvolutionBwdFilterAlgo_t algo_to_use;
+ RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterAlgorithm(
+ cudnn.handle(), input_nd.handle(), output_nd.handle(), conv.handle(),
+ filter.handle(), preference, memory_limit_bytes, &algo_to_use));
return algo_to_use;
}
-dnn::AlgorithmDesc GetCudnnConvolutionForwardAlgorithm(
+port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionForwardWorkspace(
Stream* stream, const CudnnHandle& cudnn,
- const dnn::AlgorithmConfig& algorithm_config, bool is_profiling,
+ const dnn::AlgorithmDesc& algorithm_desc,
const ScopedTensorDescriptor& input_nd,
const ScopedFilterDescriptor& filter,
const ScopedConvolutionDescriptor& conv,
const ScopedTensorDescriptor& output_nd,
- ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch) {
- cudnnConvolutionFwdAlgo_t algo;
- bool use_tensor_ops;
- if (algorithm_config.algorithm().is_default()) {
- use_tensor_ops = true;
+ ScratchAllocator* scratch_allocator) {
+ // TODO(csigg): This has side effects on the convolution descriptor. It is
+ // functionally correct because the convolution is run with the algorithm of
+ // the last call to this function, but should be fixed anyway.
+ conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled());
- auto memory_limit_bytes =
- scratch_allocator == nullptr
- ? 0
- : scratch_allocator->GetMemoryLimitInBytes(stream);
- if (memory_limit_bytes < 0) {
- memory_limit_bytes = 0;
- }
-
- algo = GetCudnnConvolutionForwardAlgo(
- cudnn, input_nd, filter, conv, output_nd,
- /*specify_workspace_limit=*/scratch_allocator != nullptr,
- memory_limit_bytes);
- } else {
- use_tensor_ops = algorithm_config.algorithm().tensor_ops_enabled();
- algo = ToConvForwardAlgo(algorithm_config.algorithm());
- }
+ // Query the size of the workspace and allocate it.
size_t size_in_bytes;
- auto status = cudnnGetConvolutionForwardWorkspaceSize(
+ RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardWorkspaceSize(
cudnn.handle(),
/*xDesc=*/input_nd.handle(),
/*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
- /*yDesc=*/output_nd.handle(), /*algo=*/algo,
- /*sizeInBytes=*/&size_in_bytes);
+ /*yDesc=*/output_nd.handle(), /*algo=*/ToConvForwardAlgo(algorithm_desc),
+ /*sizeInBytes=*/&size_in_bytes));
int64 size_in_bytes_int64 = size_in_bytes;
- if (TF_PREDICT_FALSE(status != CUDNN_STATUS_SUCCESS)) {
- CHECK(is_profiling) << "Cannot query the size of workspace needed "
- "for the specified algorithm: "
- << algorithm_config.algorithm().algo_id() << " "
- << ToString(status);
- // Silently return when we are profiling.
- return dnn::AlgorithmDesc();
+
+ if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
+ return port::Status(
+ port::error::INTERNAL,
+ "cudnnGetConvolutionForwardWorkspaceSize() returned "
+ "negative sizeInBytes value. This could be a cudnn bug.");
}
+
+ if (size_in_bytes_int64 == 0) {
+ return DeviceMemory<uint8>();
+ }
+
+ if (TF_PREDICT_FALSE(!scratch_allocator)) {
+ return port::Status(port::error::INVALID_ARGUMENT,
+ "No scratch allocator provided");
+ }
+
+ return scratch_allocator->AllocateBytes(stream, size_in_bytes);
+}
+
+port::StatusOr<DeviceMemory<uint8>>
+AllocateCudnnConvolutionBackwardDataWorkspace(
+ Stream* stream, const CudnnHandle& cudnn,
+ const dnn::AlgorithmDesc& algorithm_desc,
+ const ScopedTensorDescriptor& input_nd,
+ const ScopedFilterDescriptor& filter,
+ const ScopedConvolutionDescriptor& conv,
+ const ScopedTensorDescriptor& output_nd,
+ ScratchAllocator* scratch_allocator) {
+ // TODO(csigg): This has side effects on the convolution descriptor. It is
+ // functionally correct because the convolution is run with the algorithm of
+ // the last call to this function, but should be fixed anyway.
+ conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled());
+
+ // Query the size of the workspace and allocate it.
+ size_t size_in_bytes;
+ RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataWorkspaceSize(
+ cudnn.handle(),
+ /*wDesc=*/filter.handle(),
+ /*dyDesc=*/output_nd.handle(),
+ /*convDesc=*/conv.handle(),
+ /*dxDesc=*/input_nd.handle(),
+ /*algo=*/ToConvBackwardDataAlgo(algorithm_desc),
+ /*sizeInBytes=*/&size_in_bytes));
+ int64 size_in_bytes_int64 = size_in_bytes;
+
if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
- LOG(WARNING) << "cudnnGetConvolutionForwardWorkspaceSize() returned "
- "negative sizeInBytes value. This could be a cudnn bug.";
- if (TF_PREDICT_TRUE(is_profiling)) {
- return dnn::AlgorithmDesc();
- }
- } else if (size_in_bytes_int64 > 0) {
- port::StatusOr<DeviceMemory<uint8>> allocated;
- if (TF_PREDICT_TRUE(scratch_allocator)) {
- allocated = scratch_allocator->AllocateBytes(stream, size_in_bytes);
- if (TF_PREDICT_TRUE(allocated.ok())) {
- *scratch = allocated.ValueOrDie();
- } else {
- if (TF_PREDICT_TRUE(is_profiling)) {
- // Silently return when we are profiling.
- return dnn::AlgorithmDesc();
- }
- LOG(WARNING) << allocated.status().error_message();
- // For the int8 case, we fail at this point since the no_scratch
- // algorithm should be set to dnn::kDefaultAlgorithm.
- CHECK(!algorithm_config.algorithm_no_scratch().is_default())
- << "The primary convolution algorithm failed memory allocation, "
- "while a secondary algorithm is not provided.";
- }
- }
- if (TF_PREDICT_FALSE(!allocated.ok())) {
- if (algorithm_config.algorithm_no_scratch().is_default()) {
- use_tensor_ops = true;
- algo = GetCudnnConvolutionForwardAlgo(
- cudnn, input_nd, filter, conv, output_nd,
- /*specify_workspace_limit=*/false, 0);
- } else {
- use_tensor_ops = algorithm_config.algorithm().tensor_ops_enabled();
- algo = ToConvForwardAlgo(algorithm_config.algorithm_no_scratch());
- }
- }
+ return port::Status(
+ port::error::INTERNAL,
+ "cudnnGetConvolutionBackwardDataWorkspaceSize() returned "
+ "negative sizeInBytes value. This could be a cudnn bug.");
+ }
+
+ if (size_in_bytes_int64 == 0) {
+ return DeviceMemory<uint8>();
+ }
+
+ if (TF_PREDICT_FALSE(!scratch_allocator)) {
+ return port::Status(port::error::INVALID_ARGUMENT,
+ "No scratch allocator provided");
+ }
+
+ return scratch_allocator->AllocateBytes(stream, size_in_bytes);
+}
+
+port::StatusOr<DeviceMemory<uint8>>
+AllocateCudnnConvolutionBackwardFilterWorkspace(
+ Stream* stream, const CudnnHandle& cudnn,
+ const dnn::AlgorithmDesc& algorithm_desc,
+ const ScopedTensorDescriptor& input_nd,
+ const ScopedFilterDescriptor& filter,
+ const ScopedConvolutionDescriptor& conv,
+ const ScopedTensorDescriptor& output_nd,
+ ScratchAllocator* scratch_allocator) {
+ // TODO(csigg): This has side effects on the convolution descriptor. It is
+ // functionally correct because the convolution is run with the algorithm of
+ // the last call to this function, but should be fixed anyway.
+ conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled());
+
+ // Query the size of the workspace and allocate it.
+ size_t size_in_bytes;
+ RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterWorkspaceSize(
+ cudnn.handle(),
+ /*xDesc=*/input_nd.handle(),
+ /*dyDesc=*/output_nd.handle(),
+ /*convDesc=*/conv.handle(),
+ /*gradDesc=*/filter.handle(),
+ /*algo=*/ToConvBackwardFilterAlgo(algorithm_desc),
+ /*sizeInBytes=*/&size_in_bytes));
+ int64 size_in_bytes_int64 = size_in_bytes;
+
+ if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
+ return port::Status(
+ port::error::INTERNAL,
+ "cudnnGetConvolutionBackwardFilterWorkspaceSize() returned "
+ "negative sizeInBytes value. This could be a cudnn bug.");
+ }
+
+ if (size_in_bytes_int64 == 0) {
+ return DeviceMemory<uint8>();
+ }
+
+ if (TF_PREDICT_FALSE(!scratch_allocator)) {
+ return port::Status(port::error::INVALID_ARGUMENT,
+ "No scratch allocator provided");
+ }
+
+ return scratch_allocator->AllocateBytes(stream, size_in_bytes);
+}
+
+port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionForwardAlgorithm(
+ Stream* stream, const CudnnHandle& cudnn,
+ const dnn::AlgorithmConfig& algorithm_config,
+ const ScopedTensorDescriptor& input_nd,
+ const ScopedFilterDescriptor& filter,
+ const ScopedConvolutionDescriptor& conv,
+ const ScopedTensorDescriptor& output_nd,
+ ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch) {
+ dnn::AlgorithmDesc algo_desc = algorithm_config.algorithm();
+ if (algorithm_config.algorithm().is_default()) {
+ // Pick fastest algorithm within memory limit according to cuDNN's
+ // heuristics.
+ bool specify_workspace_limit = scratch_allocator != nullptr;
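+    // A negative limit reported by the allocator is treated as zero.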
+ auto memory_limit_bytes =
+ specify_workspace_limit
+ ? std::max(scratch_allocator->GetMemoryLimitInBytes(stream), 0ll)
+ : 0ll;
+ SE_ASSIGN_OR_RETURN(cudnnConvolutionFwdAlgo_t algo,
+ GetCudnnConvolutionForwardAlgo(
+ cudnn, input_nd, filter, conv, output_nd,
+ specify_workspace_limit, memory_limit_bytes));
+ algo_desc = dnn::AlgorithmDesc(
+ algo, algorithm_config.algorithm().tensor_ops_enabled());
+ }
+
+ auto scratch_or = AllocateCudnnConvolutionForwardWorkspace(
+ stream, cudnn, algo_desc, input_nd, filter, conv, output_nd,
+ scratch_allocator);
+
+ if (scratch_or.ok()) {
+ *scratch = scratch_or.ValueOrDie();
+ return algo_desc;
+ }
+
+  // Failed to allocate workspace for the first algorithm; fall back to the
+  // no_scratch algorithm.
+ if (algorithm_config.algorithm_no_scratch().is_default()) {
+ return port::Status(
+ port::error::INVALID_ARGUMENT,
+ "The primary convolution algorithm failed memory allocation, "
+ "while a secondary algorithm is not provided.");
+ }
+
+ SE_ASSIGN_OR_RETURN(
+ *scratch, AllocateCudnnConvolutionForwardWorkspace(
+ stream, cudnn, algorithm_config.algorithm_no_scratch(),
+ input_nd, filter, conv, output_nd, scratch_allocator));
+ return algorithm_config.algorithm_no_scratch();
+}
+
+port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm(
+ Stream* stream, const CudnnHandle& cudnn,
+ const dnn::AlgorithmConfig& algorithm_config,
+ const ScopedTensorDescriptor& input_nd,
+ const ScopedFilterDescriptor& filter,
+ const ScopedConvolutionDescriptor& conv,
+ const ScopedTensorDescriptor& output_nd,
+ ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch) {
+ dnn::AlgorithmDesc algo_desc = algorithm_config.algorithm();
+ if (algorithm_config.algorithm().is_default()) {
+ // Pick fastest algorithm within memory limit according to cuDNN's
+ // heuristics.
+ bool specify_workspace_limit = scratch_allocator != nullptr;
+ auto memory_limit_bytes =
+ specify_workspace_limit
+ ? std::max(scratch_allocator->GetMemoryLimitInBytes(stream), 0ll)
+ : 0ll;
+ SE_ASSIGN_OR_RETURN(cudnnConvolutionBwdDataAlgo_t algo,
+ GetCudnnConvolutionBackwardDataAlgo(
+ cudnn, input_nd, filter, conv, output_nd,
+ specify_workspace_limit, memory_limit_bytes));
+ algo_desc = dnn::AlgorithmDesc(
+ algo, algorithm_config.algorithm().tensor_ops_enabled());
+ }
+
+ auto scratch_or = AllocateCudnnConvolutionBackwardDataWorkspace(
+ stream, cudnn, algo_desc, input_nd, filter, conv, output_nd,
+ scratch_allocator);
+
+ if (scratch_or.ok()) {
+ *scratch = scratch_or.ValueOrDie();
+ return algo_desc;
}
- return dnn::AlgorithmDesc(algo, use_tensor_ops);
+  // Failed to allocate workspace for the first algorithm; fall back to the
+  // no_scratch algorithm.
+ if (algorithm_config.algorithm_no_scratch().is_default()) {
+ return port::Status(
+ port::error::INVALID_ARGUMENT,
+ "The primary convolution algorithm failed memory allocation, "
+ "while a secondary algorithm is not provided.");
+ }
+
+ SE_ASSIGN_OR_RETURN(
+ *scratch, AllocateCudnnConvolutionBackwardDataWorkspace(
+ stream, cudnn, algorithm_config.algorithm_no_scratch(),
+ input_nd, filter, conv, output_nd, scratch_allocator));
+ return algorithm_config.algorithm_no_scratch();
+}
+
+port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm(
+ Stream* stream, const CudnnHandle& cudnn,
+ const dnn::AlgorithmConfig& algorithm_config,
+ const ScopedTensorDescriptor& input_nd,
+ const ScopedFilterDescriptor& filter,
+ const ScopedConvolutionDescriptor& conv,
+ const ScopedTensorDescriptor& output_nd,
+ ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch) {
+ dnn::AlgorithmDesc algo_desc = algorithm_config.algorithm();
+ if (algorithm_config.algorithm().is_default()) {
+ // Pick fastest algorithm within memory limit according to cuDNN's
+ // heuristics.
+ bool specify_workspace_limit = scratch_allocator != nullptr;
+ auto memory_limit_bytes =
+ specify_workspace_limit
+ ? std::max(scratch_allocator->GetMemoryLimitInBytes(stream), 0ll)
+ : 0ll;
+ SE_ASSIGN_OR_RETURN(cudnnConvolutionBwdFilterAlgo_t algo,
+ GetCudnnConvolutionBackwardFilterAlgo(
+ cudnn, input_nd, filter, conv, output_nd,
+ specify_workspace_limit, memory_limit_bytes));
+ algo_desc = dnn::AlgorithmDesc(
+ algo, algorithm_config.algorithm().tensor_ops_enabled());
+ }
+
+ auto scratch_or = AllocateCudnnConvolutionBackwardFilterWorkspace(
+ stream, cudnn, algo_desc, input_nd, filter, conv, output_nd,
+ scratch_allocator);
+
+ if (scratch_or.ok()) {
+ *scratch = scratch_or.ValueOrDie();
+ return algo_desc;
+ }
+
+  // Failed to allocate workspace for the first algorithm; fall back to the
+  // no_scratch algorithm.
+ if (algorithm_config.algorithm_no_scratch().is_default()) {
+ return port::Status(
+ port::error::INVALID_ARGUMENT,
+ "The primary convolution algorithm failed memory allocation, "
+ "while a secondary algorithm is not provided.");
+ }
+
+  SE_ASSIGN_OR_RETURN(
+      *scratch, AllocateCudnnConvolutionBackwardFilterWorkspace(
+                    stream, cudnn, algorithm_config.algorithm_no_scratch(),
+                    input_nd, filter, conv, output_nd, scratch_allocator));
+ return algorithm_config.algorithm_no_scratch();
}
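+
+// One possible shape for the merge suggested by the TODO above (a sketch
+// only, not actual code in this file): factor the shared
+// query/validate/allocate sequence over the specific cuDNN size query.
+//
+//   template <typename SizeQuery>
+//   port::StatusOr<DeviceMemory<uint8>> AllocateWorkspace(
+//       Stream* stream, ScratchAllocator* scratch_allocator,
+//       SizeQuery&& query_size) {
+//     size_t size_in_bytes;
+//     RETURN_IF_CUDNN_ERROR(query_size(&size_in_bytes));
+//     if (static_cast<int64>(size_in_bytes) < 0) {
+//       return port::Status(port::error::INTERNAL,
+//                           "workspace query returned negative sizeInBytes");
+//     }
+//     if (size_in_bytes == 0) return DeviceMemory<uint8>();
+//     if (!scratch_allocator) {
+//       return port::Status(port::error::INVALID_ARGUMENT,
+//                           "No scratch allocator provided");
+//     }
+//     return scratch_allocator->AllocateBytes(stream, size_in_bytes);
+//   }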
// A helper class to set env-vars and choose options for cudnn-related
@@ -2282,8 +2350,6 @@ struct RnnDoFP32ComputationFP16Input {
static constexpr bool kDefaultFlag = false;
};
-// A helper function to return the internal compute type for
-// RNNs in cudnn.
cudnnDataType_t GetRnnComputeType(dnn::DataType data_type) {
switch (data_type) {
case dnn::DataType::kFloat:
@@ -2304,7 +2370,7 @@ cudnnDataType_t GetRnnComputeType(dnn::DataType data_type) {
} // namespace
template <class T>
-bool CudnnSupport::DoConvolveImpl(
+port::Status CudnnSupport::DoConvolveImpl(
Stream* stream, const dnn::BatchDescriptor& input_descriptor,
const DeviceMemory<T>& input_data,
const dnn::FilterDescriptor& filter_descriptor,
@@ -2334,177 +2400,48 @@ bool CudnnSupport::DoConvolveImpl(
: static_cast<void*>(&fbeta);
const bool is_profiling = output_profile_result != nullptr;
- cudnnConvolutionFwdAlgo_t algo;
- bool use_tensor_ops;
- DeviceMemory<uint8> scratch;
- // TODO(pauldonnelly): Replace the following code with a call to
- // GetCudnnConvolutionForwardAlgorithm().
- if (algorithm_config.algorithm().is_default()) {
- // With the default algorithm, use Cudnn's heuristics.
- auto get_algorithm = [&](bool specify_limit) {
- cudnnConvolutionFwdPreference_t preference =
- specify_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
- : CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
-
- auto memory_limit_bytes =
- scratch_allocator == nullptr
- ? 0
- : scratch_allocator->GetMemoryLimitInBytes(stream);
- if (memory_limit_bytes < 0) {
- memory_limit_bytes = 0;
- }
-
- cudnnConvolutionFwdAlgo_t algo_to_use;
- auto status = cudnnGetConvolutionForwardAlgorithm(
- cudnn.handle(), input_nd.handle(), filter.handle(), conv.handle(),
- output_nd.handle(),
- /*preference=*/preference,
- /*memoryLimitInBytes=*/memory_limit_bytes,
- /*algo=*/&algo_to_use);
- CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Unable to find a suitable "
- "algorithm for doing forward "
- "convolution";
- return algo_to_use;
- };
-
- algo = get_algorithm(/*specify_limit=*/scratch_allocator != nullptr);
- use_tensor_ops = true;
- if (scratch_allocator != nullptr) {
- size_t size_in_bytes;
- auto status = cudnnGetConvolutionForwardWorkspaceSize(
- cudnn.handle(),
- /*xDesc=*/input_nd.handle(),
- /*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
- /*yDesc=*/output_nd.handle(), /*algo=*/algo,
- /*sizeInBytes=*/&size_in_bytes);
- int64 size_in_bytes_int64 = size_in_bytes;
- if (status == CUDNN_STATUS_SUCCESS && size_in_bytes_int64 != 0) {
- if (size_in_bytes_int64 > 0) {
- auto allocated =
- scratch_allocator->AllocateBytes(stream, size_in_bytes);
- if (allocated.ok()) {
- scratch = allocated.ValueOrDie();
- } else {
- LOG(WARNING) << allocated.status().error_message();
- }
- } else {
- LOG(WARNING)
- << "cudnnGetConvolutionForwardWorkspaceSize() returned "
- "negative sizeInBytes value. This could be a cudnn bug.";
- }
- }
- }
+ DeviceMemory<uint8> scratch;
+ SE_ASSIGN_OR_RETURN(dnn::AlgorithmDesc algo_desc,
+ GetCudnnConvolutionForwardAlgorithm(
+ stream, cudnn, algorithm_config, input_nd, filter,
+ conv, output_nd, scratch_allocator, &scratch));
- // If we didn't allocate any scratch space (perhaps because of failed
- // allocation), we force a switch back to the "no workspace" algorithm.
- if (scratch == nullptr) {
- algo = get_algorithm(/*specify_limit=*/false);
- }
- } else {
- // An algorithm has been specified.
- dnn::AlgorithmDesc algotype = algorithm_config.algorithm();
- algo = ToConvForwardAlgo(algotype);
- use_tensor_ops = algotype.tensor_ops_enabled();
- conv.set_use_tensor_op_math(use_tensor_ops);
- size_t size_in_bytes;
- auto status = cudnnGetConvolutionForwardWorkspaceSize(
- cudnn.handle(),
- /*xDesc=*/input_nd.handle(),
- /*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
- /*yDesc=*/output_nd.handle(), /*algo=*/algo,
- /*sizeInBytes=*/&size_in_bytes);
- if (status != CUDNN_STATUS_SUCCESS) {
- if (is_profiling) {
- // Silently return when we are profiling.
- return false;
- }
- LOG(FATAL) << "Cannot query the size of workspace needed for the given "
- "algorithm: "
- << algorithm_config.algorithm().algo_id();
- }
- int64 size_in_bytes_int64 = size_in_bytes;
- if (size_in_bytes_int64 > 0) {
- if (scratch_allocator == nullptr) {
- LOG(FATAL) << "An allocator must be specified when scratch memory is "
- "needed";
- }
- auto allocated = scratch_allocator->AllocateBytes(stream, size_in_bytes);
- if (is_profiling && !allocated.ok()) {
- // Silently return when we are profiling.
- return false;
- }
- if (allocated.ok()) {
- scratch = allocated.ValueOrDie();
- } else {
- LOG(WARNING) << allocated.status().error_message();
- }
- if (scratch == nullptr) {
- CHECK(!algorithm_config.algorithm_no_scratch().is_default())
- << "The primary convolution algorithm failed memory allocation, "
- "while a secondary algorithm is not provided.";
- dnn::AlgorithmDesc algotype = algorithm_config.algorithm_no_scratch();
- algo = ToConvForwardAlgo(algotype);
- use_tensor_ops = algotype.tensor_ops_enabled();
- conv.set_use_tensor_op_math(use_tensor_ops);
- }
- } else if (size_in_bytes_int64 < 0) {
- LOG(WARNING) << "cudnnGetConvolutionForwardWorkspaceSize() returned "
- "negative sizeInBytes value. This could be a cudnn bug.";
- }
- }
- std::unique_ptr<CUDATimer> timer;
+ std::unique_ptr<CUDATimer, TimerDeleter> timer;
if (is_profiling) {
timer.reset(new CUDATimer(parent_)); // NOLINT
- if (!timer->Init()) {
- return false;
- }
// The start and stop of the timer should be as close to the Cudnn call as
// possible. It is still possible for other threads to issue workload on
// to this stream. So it could take multiple profiling measurements.
- if (!timer->Start(AsCUDAStream(stream))) {
- timer->Destroy();
- return false;
+ if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) {
+ return port::Status(port::error::INTERNAL, "Failed to start timer");
}
}
- auto status = cudnnConvolutionForward(
+
+ RETURN_IF_CUDNN_ERROR(cudnnConvolutionForward(
cudnn.handle(),
/*alpha=*/alpha, /*srcDesc=*/input_nd.handle(),
/*srcData=*/input_data.opaque(), /*filterDesc=*/filter.handle(),
/*filterData=*/filter_data.opaque(), /*convDesc=*/conv.handle(),
- /*algo=*/algo, /*workSpace=*/scratch.opaque(),
+ /*algo=*/ToConvForwardAlgo(algo_desc), /*workSpace=*/scratch.opaque(),
/*workSpaceSizeInBytes=*/scratch.size(), /*beta=*/beta,
- /*destDesc=*/output_nd.handle(), /*destData=*/output_data->opaque());
+ /*yDesc=*/output_nd.handle(), /*y=*/output_data->opaque()));
if (is_profiling) {
if (!timer->Stop(AsCUDAStream(stream))) {
- timer->Destroy();
- return false;
+ return port::Status(port::error::INTERNAL, "Failed to stop timer");
}
- if (status == CUDNN_STATUS_SUCCESS) {
- dnn::AlgorithmDesc algotype(algo, use_tensor_ops);
- output_profile_result->set_algorithm(algotype);
- output_profile_result->set_elapsed_time_in_ms(
- timer->GetElapsedMilliseconds());
- }
- timer->Destroy();
- }
-
- if (status != CUDNN_STATUS_SUCCESS) {
- // Silently return when we are profiling.
- if (!is_profiling) {
- LOG(ERROR) << "failed to enqueue convolution on stream: "
- << ToString(status);
- }
- return false;
+ output_profile_result->set_algorithm(algo_desc);
+ output_profile_result->set_elapsed_time_in_ms(
+ timer->GetElapsedMilliseconds());
}
- return true;
+ return port::Status::OK();
}
template <typename Type, typename BiasType, typename ScaleType,
int cudnn_data_type, int cudnn_compute_type>
-bool CudnnSupport::DoFusedConvolveImpl(
+port::Status CudnnSupport::DoFusedConvolveImpl(
Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
const DeviceMemory<Type>& conv_input_data, ScaleType conv_input_scale,
const dnn::FilterDescriptor& filter_descriptor,
@@ -2517,6 +2454,12 @@ bool CudnnSupport::DoFusedConvolveImpl(
DeviceMemory<Type>* output_data, ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
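+  // Only Relu activation is supported by this path, so reject other modes
+  // up front, before any descriptors or scratch memory are set up.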
+ if (activation_mode != dnn::ActivationMode::kRelu) {
+ return port::Status(port::error::INVALID_ARGUMENT,
+ "cudnnConvolutionBiasActivationForward() only supports "
+ "Relu activation.");
+ }
+
ScopedTensorDescriptor conv_input_nd(
conv_input_descriptor, static_cast<cudnnDataType_t>(cudnn_data_type));
ScopedTensorDescriptor output_nd(
@@ -2528,38 +2471,24 @@ bool CudnnSupport::DoFusedConvolveImpl(
convolution_descriptor, static_cast<cudnnDataType_t>(cudnn_compute_type));
auto cudnn = cudnn_->GetHandle(parent_, stream);
+
const bool is_profiling = output_profile_result != nullptr;
- DeviceMemory<uint8> scratch;
- dnn::AlgorithmDesc algotype = GetCudnnConvolutionForwardAlgorithm(
- stream, cudnn, algorithm_config, is_profiling, conv_input_nd, filter,
- conv, output_nd, scratch_allocator, &scratch);
- if (algotype.is_default()) {
- if (!is_profiling) {
- LOG(ERROR) << "No suitable algorithm found";
- }
- return false;
- }
- auto algo = static_cast<cudnnConvolutionFwdAlgo_t>(algotype.algo_id());
- conv.set_use_tensor_op_math(algotype.tensor_ops_enabled());
- if (activation_mode != dnn::ActivationMode::kRelu) {
- LOG(ERROR) << "cudnnConvolutionBiasActivationForward() only supports Relu "
- "activation.";
- return false;
- }
+ DeviceMemory<uint8> scratch;
+ SE_ASSIGN_OR_RETURN(
+ dnn::AlgorithmDesc algo_desc,
+ GetCudnnConvolutionForwardAlgorithm(
+ stream, cudnn, algorithm_config, conv_input_nd, filter, conv,
+ output_nd, scratch_allocator, &scratch));
- std::unique_ptr<CUDATimer> timer;
+ std::unique_ptr<CUDATimer, TimerDeleter> timer;
if (is_profiling) {
timer.reset(new CUDATimer(parent_)); // NOLINT
- if (!timer->Init()) {
- return false;
- }
// The start and stop of the timer should be as close to the Cudnn call as
// possible. It is still possible for other threads to issue workload on
// to this stream. So it could take multiple profiling measurements.
- if (!timer->Start(AsCUDAStream(stream))) {
- timer->Destroy();
- return false;
+ if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) {
+ return port::Status(port::error::INTERNAL, "Failed to start timer");
}
}
// CUDNN v6 only supports CUDNN_NOT_PROPAGATE_NAN as the reluNanOpt for
@@ -2576,7 +2505,8 @@ bool CudnnSupport::DoFusedConvolveImpl(
<< "\nconv_input_data.opaque() = " << conv_input_data.opaque()
<< "\nfilter.handle() = " << filter.handle()
<< "\nfilter_data.opaque() = " << filter_data.opaque()
- << "\nconv.handle() = " << conv.handle() << "\nalgo = " << algo
+ << "\nconv.handle() = " << conv.handle()
+ << "\nalgo = " << algo_desc.algo_id()
<< "\nscratch.opaque() = " << scratch.opaque()
<< "\nscratch.size() = " << scratch.size()
<< "\nside_input_scale = " << side_input_scale
@@ -2588,41 +2518,29 @@ bool CudnnSupport::DoFusedConvolveImpl(
<< "\noutput_nd.handle() = " << output_nd.handle()
<< "\noutput_data->opaque() = " << output_data->opaque();
- auto status = cudnnConvolutionBiasActivationForward(
+ RETURN_IF_CUDNN_ERROR(cudnnConvolutionBiasActivationForward(
cudnn.handle(),
/*alpha1=*/&conv_input_scale,
/*srcDesc=*/conv_input_nd.handle(), /*srcData=*/conv_input_data.opaque(),
/*filterDesc=*/filter.handle(), /*filterData=*/filter_data.opaque(),
- /*convDesc=*/conv.handle(), algo, /*workSpace=*/scratch.opaque(),
+ /*convDesc=*/conv.handle(), ToConvForwardAlgo(algo_desc),
+ /*workSpace=*/scratch.opaque(),
/*workSpaceSizeInBytes=*/scratch.size(), /*alpha2=*/&side_input_scale,
/*zDesc=*/output_nd.handle(), /*z=*/side_input_data_ptr,
/*biasDesc=*/bias_nd.handle(), /*bias=*/biases.opaque(),
/*activationDesc=*/activation_desc.handle(),
- /*destDesc=*/output_nd.handle(), /*destData=*/output_data->opaque());
+ /*yDesc=*/output_nd.handle(), /*y=*/output_data->opaque()));
if (is_profiling) {
if (!timer->Stop(AsCUDAStream(stream))) {
- timer->Destroy();
- return false;
- }
- if (status == CUDNN_STATUS_SUCCESS) {
- output_profile_result->set_algorithm(algotype);
- output_profile_result->set_elapsed_time_in_ms(
- timer->GetElapsedMilliseconds());
- }
- timer->Destroy();
- }
-
- if (status != CUDNN_STATUS_SUCCESS) {
- // Silently return when we are profiling.
- if (!is_profiling) {
- LOG(ERROR) << "failed to enqueue convolution on stream: "
- << ToString(status);
+ return port::Status(port::error::INTERNAL, "Failed to stop timer");
}
- return false;
+ output_profile_result->set_algorithm(algo_desc);
+ output_profile_result->set_elapsed_time_in_ms(
+ timer->GetElapsedMilliseconds());
}
- return true;
+ return port::Status::OK();
}
bool CudnnSupport::GetConvolveAlgorithms(
@@ -2746,11 +2664,13 @@ bool CudnnSupport::DoBatchNormalizationForward(
DeviceMemory<float>* saved_inv_var, bool is_training,
std::function<const DeviceMemory<float>&()> var_to_inv_var,
std::function<void()> inv_var_to_var) {
- return DoBatchNormalizationForwardImpl<float, float>(
- stream, dnn::DataType::kFloat, dnn::DataType::kFloat, x, scale, offset,
- estimated_mean, estimated_variance, x_desc, scale_offset_desc, epsilon, y,
- batch_mean, batch_var, saved_mean, saved_inv_var, is_training,
- std::move(var_to_inv_var), std::move(inv_var_to_var));
+ return IsStatusOk(
+ DoBatchNormalizationForwardImpl<float, float>(
+ stream, dnn::DataType::kFloat, dnn::DataType::kFloat, x, scale,
+ offset, estimated_mean, estimated_variance, x_desc, scale_offset_desc,
+ epsilon, y, batch_mean, batch_var, saved_mean, saved_inv_var,
+ is_training, std::move(var_to_inv_var), std::move(inv_var_to_var)),
+ /*report_error=*/true);
}
bool CudnnSupport::DoBatchNormalizationForward(
@@ -2765,15 +2685,17 @@ bool CudnnSupport::DoBatchNormalizationForward(
DeviceMemory<float>* saved_inv_var, bool is_training,
std::function<const DeviceMemory<float>&()> var_to_inv_var,
std::function<void()> inv_var_to_var) {
- return DoBatchNormalizationForwardImpl<Eigen::half, float>(
- stream, dnn::DataType::kHalf, dnn::DataType::kFloat, x, scale, offset,
- estimated_mean, estimated_variance, x_desc, scale_offset_desc, epsilon, y,
- batch_mean, batch_var, saved_mean, saved_inv_var, is_training,
- std::move(var_to_inv_var), std::move(inv_var_to_var));
+ return IsStatusOk(
+ DoBatchNormalizationForwardImpl<Eigen::half, float>(
+ stream, dnn::DataType::kHalf, dnn::DataType::kFloat, x, scale, offset,
+ estimated_mean, estimated_variance, x_desc, scale_offset_desc,
+ epsilon, y, batch_mean, batch_var, saved_mean, saved_inv_var,
+ is_training, std::move(var_to_inv_var), std::move(inv_var_to_var)),
+ /*report_error=*/true);
}
template <class T, class U>
-bool CudnnSupport::DoBatchNormalizationForwardImpl(
+port::Status CudnnSupport::DoBatchNormalizationForwardImpl(
Stream* stream, dnn::DataType input_data_type,
dnn::DataType scale_data_type, const DeviceMemory<T>& x,
const DeviceMemory<U>& scale, const DeviceMemory<U>& offset,
@@ -2798,7 +2720,6 @@ bool CudnnSupport::DoBatchNormalizationForwardImpl(
float zero = 0.0;
auto cudnn = cudnn_->GetHandle(parent_, stream);
- auto status = CUDNN_STATUS_SUCCESS;
if (is_training) {
CHECK_EQ(batch_mean->is_null(), batch_var->is_null())
<< "batch_mean and batch_var must both be null or both be non-null";
@@ -2815,26 +2736,21 @@ bool CudnnSupport::DoBatchNormalizationForwardImpl(
batch_var_opaque = nullptr;
}
- status = cudnnBatchNormalizationForwardTraining(
+ RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardTraining(
cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
scale.opaque(), offset.opaque(), 1.0, batch_mean_opaque,
batch_var_opaque, epsilon, saved_mean->opaque(),
- saved_inv_var->opaque());
+ saved_inv_var->opaque()));
} else {
const void* maybe_inv_var = estimated_variance.opaque();
- status = cudnnBatchNormalizationForwardInference(
+ RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardInference(
cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
scale.opaque(), offset.opaque(), estimated_mean.opaque(), maybe_inv_var,
- epsilon);
+ epsilon));
}
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "failed to enqueue forward batch normalization on stream: "
- << ToString(status);
- return false;
- }
- return true;
+ return port::Status::OK();
}
bool CudnnSupport::DoBatchNormalizationBackward(
@@ -2845,10 +2761,11 @@ bool CudnnSupport::DoBatchNormalizationBackward(
const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
DeviceMemory<float>* x_backprop, DeviceMemory<float>* scale_backprop,
DeviceMemory<float>* offset_backprop) {
- return DoBatchNormalizationBackwardImpl(
- stream, CUDNN_DATA_FLOAT, CUDNN_DATA_FLOAT, y_backprop, x, scale, mean,
- inv_var, x_desc, scale_offset_desc, epsilon, x_backprop, scale_backprop,
- offset_backprop);
+ return IsStatusOk(DoBatchNormalizationBackwardImpl(
+ stream, CUDNN_DATA_FLOAT, CUDNN_DATA_FLOAT, y_backprop,
+ x, scale, mean, inv_var, x_desc, scale_offset_desc,
+ epsilon, x_backprop, scale_backprop, offset_backprop),
+ /*report_error=*/true);
}
bool CudnnSupport::DoBatchNormalizationBackward(
@@ -2859,14 +2776,15 @@ bool CudnnSupport::DoBatchNormalizationBackward(
const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
DeviceMemory<Eigen::half>* x_backprop, DeviceMemory<float>* scale_backprop,
DeviceMemory<float>* offset_backprop) {
- return DoBatchNormalizationBackwardImpl(
- stream, CUDNN_DATA_HALF, CUDNN_DATA_FLOAT, y_backprop, x, scale, mean,
- inv_var, x_desc, scale_offset_desc, epsilon, x_backprop, scale_backprop,
- offset_backprop);
+ return IsStatusOk(DoBatchNormalizationBackwardImpl(
+ stream, CUDNN_DATA_HALF, CUDNN_DATA_FLOAT, y_backprop,
+ x, scale, mean, inv_var, x_desc, scale_offset_desc,
+ epsilon, x_backprop, scale_backprop, offset_backprop),
+ /*report_error=*/true);
}
template <class T, class U>
-bool CudnnSupport::DoBatchNormalizationBackwardImpl(
+port::Status CudnnSupport::DoBatchNormalizationBackwardImpl(
Stream* stream, int cudnn_input_type, int cudnn_scale_type,
const DeviceMemory<T>& y_backprop, const DeviceMemory<T>& x,
const DeviceMemory<U>& scale, const DeviceMemory<U>& mean,
@@ -2889,19 +2807,14 @@ bool CudnnSupport::DoBatchNormalizationBackwardImpl(
auto cudnn = cudnn_->GetHandle(parent_, stream);
- auto status = cudnnBatchNormalizationBackward(
+ RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationBackward(
cudnn.handle(), mode, &one, &zero, &one, &zero, x_descriptor.handle(),
x.opaque(), x_descriptor.handle(), y_backprop.opaque(),
x_descriptor.handle(), x_backprop->opaque(),
scale_offset_descriptor.handle(), scale.opaque(),
scale_backprop->opaque(), offset_backprop->opaque(), epsilon,
- mean.opaque(), inv_var.opaque());
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "failed to enqueue backward batch normalization on stream: "
- << ToString(status);
- return false;
- }
- return true;
+ mean.opaque(), inv_var.opaque()));
+ return port::Status::OK();
}
bool CudnnSupport::DoConvolve(
@@ -2914,10 +2827,12 @@ bool CudnnSupport::DoConvolve(
DeviceMemory<float>* output_data, ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
- return DoConvolveImpl<float>(
- stream, batch_descriptor, input_data, filter_descriptor, filter_data,
- convolution_descriptor, output_descriptor, output_data, scratch_allocator,
- algorithm_config, output_profile_result);
+ return IsStatusOk(
+ DoConvolveImpl<float>(
+ stream, batch_descriptor, input_data, filter_descriptor, filter_data,
+ convolution_descriptor, output_descriptor, output_data,
+ scratch_allocator, algorithm_config, output_profile_result),
+ /*report_error=*/!output_profile_result);
}
bool CudnnSupport::DoConvolve(
@@ -2930,10 +2845,12 @@ bool CudnnSupport::DoConvolve(
DeviceMemory<double>* output_data, ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
- return DoConvolveImpl<double>(
- stream, batch_descriptor, input_data, filter_descriptor, filter_data,
- convolution_descriptor, output_descriptor, output_data, scratch_allocator,
- algorithm_config, output_profile_result);
+ return IsStatusOk(
+ DoConvolveImpl<double>(
+ stream, batch_descriptor, input_data, filter_descriptor, filter_data,
+ convolution_descriptor, output_descriptor, output_data,
+ scratch_allocator, algorithm_config, output_profile_result),
+ /*report_error=*/!output_profile_result);
}
bool CudnnSupport::DoConvolve(
@@ -2946,10 +2863,12 @@ bool CudnnSupport::DoConvolve(
DeviceMemory<Eigen::half>* output_data, ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
- return DoConvolveImpl<Eigen::half>(
- stream, batch_descriptor, input_data, filter_descriptor, filter_data,
- convolution_descriptor, output_descriptor, output_data, scratch_allocator,
- algorithm_config, output_profile_result);
+ return IsStatusOk(
+ DoConvolveImpl<Eigen::half>(
+ stream, batch_descriptor, input_data, filter_descriptor, filter_data,
+ convolution_descriptor, output_descriptor, output_data,
+ scratch_allocator, algorithm_config, output_profile_result),
+ /*report_error=*/!output_profile_result);
}
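
The report_error argument deserves a note: for the convolution entry points it is !output_profile_result, so failures are logged during normal execution but stay silent while autotuning, where candidate algorithms are expected to fail. A sketch of the two call shapes (arguments elided):

    // Normal call: a cudnn failure is logged via LOG(ERROR).
    DoConvolve(stream, /*...*/, /*output_profile_result=*/nullptr);

    // Profiling probe: report_error is false, so unsupported algorithms fail
    // quietly and the caller simply tries the next candidate.
    dnn::ProfileResult profile;
    DoConvolve(stream, /*...*/, &profile);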
bool CudnnSupport::DoFusedConvolve(
@@ -2965,13 +2884,15 @@ bool CudnnSupport::DoFusedConvolve(
DeviceMemory<double>* output_data, ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
- return DoFusedConvolveImpl<double, double, double, CUDNN_DATA_DOUBLE,
- CUDNN_DATA_DOUBLE>(
- stream, conv_input_descriptor, conv_input_data, conv_input_scale,
- filter_descriptor, filter_data, convolution_descriptor, side_input_data,
- side_input_scale, bias_descriptor, biases, activation_mode,
- output_descriptor, output_data, scratch_allocator, algorithm_config,
- output_profile_result);
+ return IsStatusOk(
+ DoFusedConvolveImpl<double, double, double, CUDNN_DATA_DOUBLE,
+ CUDNN_DATA_DOUBLE>(
+ stream, conv_input_descriptor, conv_input_data, conv_input_scale,
+ filter_descriptor, filter_data, convolution_descriptor,
+ side_input_data, side_input_scale, bias_descriptor, biases,
+ activation_mode, output_descriptor, output_data, scratch_allocator,
+ algorithm_config, output_profile_result),
+ /*report_error=*/!output_profile_result);
}
bool CudnnSupport::DoFusedConvolve(
@@ -2987,13 +2908,15 @@ bool CudnnSupport::DoFusedConvolve(
DeviceMemory<float>* output_data, ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
- return DoFusedConvolveImpl<float, float, float, CUDNN_DATA_FLOAT,
- CUDNN_DATA_FLOAT>(
- stream, conv_input_descriptor, conv_input_data, conv_input_scale,
- filter_descriptor, filter_data, convolution_descriptor, side_input_data,
- side_input_scale, bias_descriptor, biases, activation_mode,
- output_descriptor, output_data, scratch_allocator, algorithm_config,
- output_profile_result);
+ return IsStatusOk(
+ DoFusedConvolveImpl<float, float, float, CUDNN_DATA_FLOAT,
+ CUDNN_DATA_FLOAT>(
+ stream, conv_input_descriptor, conv_input_data, conv_input_scale,
+ filter_descriptor, filter_data, convolution_descriptor,
+ side_input_data, side_input_scale, bias_descriptor, biases,
+ activation_mode, output_descriptor, output_data, scratch_allocator,
+ algorithm_config, output_profile_result),
+ /*report_error=*/!output_profile_result);
}
bool CudnnSupport::DoFusedConvolve(
@@ -3010,13 +2933,15 @@ bool CudnnSupport::DoFusedConvolve(
DeviceMemory<Eigen::half>* output_data, ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
- return DoFusedConvolveImpl<Eigen::half, Eigen::half, float, CUDNN_DATA_HALF,
- CUDNN_DATA_FLOAT>(
- stream, conv_input_descriptor, conv_input_data, conv_input_scale,
- filter_descriptor, filter_data, convolution_descriptor, side_input_data,
- side_input_scale, bias_descriptor, biases, activation_mode,
- output_descriptor, output_data, scratch_allocator, algorithm_config,
- output_profile_result);
+ return IsStatusOk(
+ DoFusedConvolveImpl<Eigen::half, Eigen::half, float, CUDNN_DATA_HALF,
+ CUDNN_DATA_FLOAT>(
+ stream, conv_input_descriptor, conv_input_data, conv_input_scale,
+ filter_descriptor, filter_data, convolution_descriptor,
+ side_input_data, side_input_scale, bias_descriptor, biases,
+ activation_mode, output_descriptor, output_data, scratch_allocator,
+ algorithm_config, output_profile_result),
+ /*report_error=*/!output_profile_result);
}
bool CudnnSupport::DoFusedConvolve(
@@ -3040,13 +2965,15 @@ bool CudnnSupport::DoFusedConvolve(
"supported on GPUs with compute capability 6.1 or later.";
return false;
}
- return DoFusedConvolveImpl<int8, float, float, CUDNN_DATA_INT8x4,
- CUDNN_DATA_INT32>(
- stream, conv_input_descriptor, conv_input_data, conv_input_scale,
- filter_descriptor, filter_data, convolution_descriptor, side_input_data,
- side_input_scale, bias_descriptor, biases, activation_mode,
- output_descriptor, output_data, scratch_allocator, algorithm_config,
- output_profile_result);
+ return IsStatusOk(
+ DoFusedConvolveImpl<int8, float, float, CUDNN_DATA_INT8x4,
+ CUDNN_DATA_INT32>(
+ stream, conv_input_descriptor, conv_input_data, conv_input_scale,
+ filter_descriptor, filter_data, convolution_descriptor,
+ side_input_data, side_input_scale, bias_descriptor, biases,
+ activation_mode, output_descriptor, output_data, scratch_allocator,
+ algorithm_config, output_profile_result),
+ /*report_error=*/!output_profile_result);
}
bool CudnnSupport::DoTransformTensor(Stream* stream,
@@ -3062,22 +2989,17 @@ bool CudnnSupport::DoTransformTensor(Stream* stream,
ScopedTensorDescriptor output_tensor_desc(
output_desc, ToCudnnDataType(output_type, output_desc.layout()));
auto cudnn = cudnn_->GetHandle(parent_, stream);
- auto status = cudnnTransformTensor(
- cudnn.handle(), &scale, input_tensor_desc.handle(), input_data.opaque(),
- &beta, output_tensor_desc.handle(), output_data->opaque());
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "Could not transform a tensor with layout "
- << input_desc.ToString() << " and data type "
- << static_cast<int>(input_type) << " to another with layout "
- << output_desc.ToString() << " and data type "
- << static_cast<int>(output_type) << ": " << ToString(status);
- return false;
- }
- return true;
+ auto status = [&] {
+ RETURN_IF_CUDNN_ERROR(cudnnTransformTensor(
+ cudnn.handle(), &scale, input_tensor_desc.handle(), input_data.opaque(),
+ &beta, output_tensor_desc.handle(), output_data->opaque()));
+ return port::Status::OK();
+ }();
+ return IsStatusOk(status, /*report_error=*/true);
}
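
Where a function keeps returning bool and has no separate *Impl, the conversion uses an immediately invoked lambda, since RETURN_IF_CUDNN_ERROR can only appear in a scope whose return type is port::Status. The shape, with a hypothetical cudnnSomeCall:

    auto status = [&] {
      RETURN_IF_CUDNN_ERROR(cudnnSomeCall(/*...*/));
      return port::Status::OK();
    }();  // trailing (): the lambda runs immediately, yielding a port::Status
    return IsStatusOk(status, /*report_error=*/true);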
template <class T>
-bool CudnnSupport::DoConvolveBackwardDataImpl(
+port::Status CudnnSupport::DoConvolveBackwardDataImpl(
Stream* stream, const dnn::FilterDescriptor& filter_descriptor,
const DeviceMemory<T>& filter_data,
const dnn::BatchDescriptor& output_descriptor,
@@ -3108,139 +3030,25 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
GetConvComputeType<T>());
const bool is_profiling = output_profile_result != nullptr;
- cudnnConvolutionBwdDataAlgo_t algo;
- DeviceMemory<uint8> scratch;
-
- if (algorithm_config.algorithm().is_default()) {
- // With the default algorithm, use Cudnn's heuristics.
- auto get_algorithm =
- [&](bool specify_limit) -> cudnnConvolutionBwdDataAlgo_t {
- cudnnConvolutionBwdDataPreference_t preference =
- specify_limit ? CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT
- : CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE;
-
- auto memory_limit_bytes =
- scratch_allocator == nullptr
- ? 0
- : scratch_allocator->GetMemoryLimitInBytes(stream);
- if (memory_limit_bytes < 0) {
- memory_limit_bytes = 0;
- }
- cudnnConvolutionBwdDataAlgo_t algo_to_use;
- cudnnStatus_t status = cudnnGetConvolutionBackwardDataAlgorithm(
- cudnn.handle(),
- /*filterDesc=*/filter.handle(),
- /*diffDesc=*/out_back_nd.handle(),
- /*convDesc=*/conv.handle(),
- /*gradDesc=*/in_back_nd.handle(),
- /*preference=*/preference,
- /*memoryLimitInBytes=*/memory_limit_bytes,
- /*algo=*/&algo_to_use);
- CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Unable to find a suitable "
- "algorithm for doing backward "
- "data convolution";
- return algo_to_use;
- };
-
- algo = get_algorithm(/*specify_limit=*/scratch_allocator != nullptr);
-
- if (scratch_allocator != nullptr) {
- size_t size_in_bytes;
- auto status = cudnnGetConvolutionBackwardDataWorkspaceSize(
- cudnn.handle(),
- /*filterDesc=*/filter.handle(),
- /*diffDesc=*/out_back_nd.handle(),
- /*convDesc=*/conv.handle(),
- /*gradDesc=*/in_back_nd.handle(),
- /*algo=*/algo,
- /*sizeInBytes=*/&size_in_bytes);
- int64 size_in_bytes_int64 = size_in_bytes;
- if (status == CUDNN_STATUS_SUCCESS && size_in_bytes_int64 != 0) {
- if (size_in_bytes_int64 > 0) {
- auto allocated =
- scratch_allocator->AllocateBytes(stream, size_in_bytes);
- if (allocated.ok()) {
- scratch = allocated.ValueOrDie();
- } else {
- LOG(WARNING) << allocated.status().error_message();
- }
- } else {
- LOG(WARNING)
- << "cudnnGetConvolutionBackwardDataWorkspaceSize() returned "
- "negative sizeInBytes value. This could be a cudnn bug.";
- }
- }
- }
- // If we didn't allocate any scratch space (perhaps because of failed
- // allocation), we force a switch back to the "no workspace" algorithm.
- if (scratch == nullptr) {
- algo = get_algorithm(/*specify_limit=*/false);
- }
- } else {
- // An algorithm has been specified.
- dnn::AlgorithmDesc algotype = algorithm_config.algorithm();
- algo = ToConvBackwardDataAlgo(algotype);
- conv.set_use_tensor_op_math(algotype.tensor_ops_enabled());
- size_t size_in_bytes;
- auto status = cudnnGetConvolutionBackwardDataWorkspaceSize(
- cudnn.handle(),
- /*filterDesc=*/filter.handle(),
- /*diffDesc=*/out_back_nd.handle(),
- /*convDesc=*/conv.handle(),
- /*gradDesc=*/in_back_nd.handle(),
- /*algo=*/algo,
- /*sizeInBytes=*/&size_in_bytes);
- if (status != CUDNN_STATUS_SUCCESS) {
- if (is_profiling) {
- // Silently return when we are profiling.
- return false;
- }
- LOG(FATAL) << "Cannot query the size of workspace needed for the given "
- "algorithm: "
- << algorithm_config.algorithm().algo_id();
- }
- int64 size_in_bytes_int64 = size_in_bytes;
- if (size_in_bytes_int64 > 0) {
- if (scratch_allocator == nullptr) {
- LOG(FATAL) << "An allocator must be specified when scratch memory is "
- "needed";
- }
- auto allocated = scratch_allocator->AllocateBytes(stream, size_in_bytes);
- if (is_profiling && !allocated.ok()) {
- // Silently return when we are profiling.
- return false;
- }
- if (allocated.ok()) {
- scratch = allocated.ValueOrDie();
- } else {
- LOG(WARNING) << allocated.status().error_message();
- }
- if (scratch == nullptr) {
- CHECK(!algorithm_config.algorithm_no_scratch().is_default())
- << "The primary convolution algorithm failed memory allocation, "
- "while a secondary algorithm is not provided.";
- dnn::AlgorithmDesc algotype = algorithm_config.algorithm_no_scratch();
- algo = ToConvBackwardDataAlgo(algotype);
- conv.set_use_tensor_op_math(algotype.tensor_ops_enabled());
- }
- } else if (size_in_bytes_int64 < 0) {
- LOG(WARNING) << "cudnnGetConvolutionBackwardDataWorkspaceSize() returned "
- "negative sizeInBytes value. This could be a cudnn bug.";
- }
- }
+ DeviceMemory<uint8> scratch;
+ SE_ASSIGN_OR_RETURN(dnn::AlgorithmDesc algo_desc,
+ GetCudnnConvolutionBackwardDataAlgorithm(
+ stream, cudnn, algorithm_config, in_back_nd, filter,
+ conv, out_back_nd, scratch_allocator, &scratch));
- std::unique_ptr<CUDATimer> timer;
+ std::unique_ptr<CUDATimer, TimerDeleter> timer;
if (is_profiling) {
timer.reset(new CUDATimer(parent_)); // NOLINT
- timer->Init();
    // The start and stop of the timer should be as close to the Cudnn call as
    // possible. It is still possible for other threads to issue workloads
    // onto this stream, so multiple profiling measurements may be needed.
- timer->Start(AsCUDAStream(stream));
+ if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) {
+ return port::Status(port::error::INTERNAL, "Failed to start timer");
+ }
}
- auto status =
+ RETURN_IF_CUDNN_ERROR(
cudnnConvolutionBackwardData(cudnn.handle(),
/*alpha=*/alpha,
/*wDesc=*/filter.handle(),
@@ -3248,32 +3056,22 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
/*dyDesc=*/out_back_nd.handle(),
/*dy=*/backward_output_data.opaque(),
/*convDesc=*/conv.handle(),
- /*algo=*/algo,
+ /*algo=*/ToConvBackwardDataAlgo(algo_desc),
/*workSpace=*/scratch.opaque(),
/*workSpaceSizeInBytes=*/scratch.size(),
/*beta=*/beta,
/*dxDesc=*/in_back_nd.handle(),
- /*dx=*/backward_input_data->opaque());
+ /*dx=*/backward_input_data->opaque()));
if (is_profiling) {
- timer->Stop(AsCUDAStream(stream));
- if (status == CUDNN_STATUS_SUCCESS) {
- bool use_tensor_ops = algorithm_config.algorithm().tensor_ops_enabled();
- dnn::AlgorithmDesc algotype(algo, use_tensor_ops);
- output_profile_result->set_algorithm(algotype);
- output_profile_result->set_elapsed_time_in_ms(
- timer->GetElapsedMilliseconds());
- }
- timer->Destroy();
- }
- if (status != CUDNN_STATUS_SUCCESS) {
- // Silently return when we are profiling.
- if (!is_profiling) {
- LOG(ERROR) << "failed to enqueue convolution on stream: "
- << ToString(status);
+ if (!timer->Stop(AsCUDAStream(stream))) {
+ return port::Status(port::error::INTERNAL, "Failed to stop timer");
}
- return false;
+ output_profile_result->set_algorithm(algo_desc);
+ output_profile_result->set_elapsed_time_in_ms(
+ timer->GetElapsedMilliseconds());
}
- return true;
+
+ return port::Status::OK();
}
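
SE_ASSIGN_OR_RETURN absorbs what used to be ~120 lines of inline algorithm selection; the heuristics now live in GetCudnnConvolutionBackwardDataAlgorithm, which returns a port::StatusOr<dnn::AlgorithmDesc>. Assuming the usual StatusOr-based expansion (this file's allocator code already uses ValueOrDie), the macro behaves roughly like the following, with GetAlgo as a stand-in name:

    // SE_ASSIGN_OR_RETURN(dnn::AlgorithmDesc algo_desc, GetAlgo(...));
    // expands approximately to:
    auto algo_or = GetAlgo(/*...*/);      // port::StatusOr<dnn::AlgorithmDesc>
    if (!algo_or.ok()) {
      return algo_or.status();            // propagate the failure to the caller
    }
    dnn::AlgorithmDesc algo_desc = algo_or.ValueOrDie();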
bool CudnnSupport::DoConvolveBackwardData(
@@ -3287,11 +3085,13 @@ bool CudnnSupport::DoConvolveBackwardData(
ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
- return DoConvolveBackwardDataImpl(stream, filter_descriptor, filter_data,
- output_descriptor, backward_output_data,
- convolution_descriptor, input_descriptor,
- backward_input_data, scratch_allocator,
- algorithm_config, output_profile_result);
+ return IsStatusOk(
+ DoConvolveBackwardDataImpl(stream, filter_descriptor, filter_data,
+ output_descriptor, backward_output_data,
+ convolution_descriptor, input_descriptor,
+ backward_input_data, scratch_allocator,
+ algorithm_config, output_profile_result),
+ /*report_error=*/!output_profile_result);
}
bool CudnnSupport::DoConvolveBackwardData(
@@ -3305,11 +3105,13 @@ bool CudnnSupport::DoConvolveBackwardData(
ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
- return DoConvolveBackwardDataImpl(stream, filter_descriptor, filter_data,
- output_descriptor, backward_output_data,
- convolution_descriptor, input_descriptor,
- backward_input_data, scratch_allocator,
- algorithm_config, output_profile_result);
+ return IsStatusOk(
+ DoConvolveBackwardDataImpl(stream, filter_descriptor, filter_data,
+ output_descriptor, backward_output_data,
+ convolution_descriptor, input_descriptor,
+ backward_input_data, scratch_allocator,
+ algorithm_config, output_profile_result),
+ /*report_error=*/!output_profile_result);
}
bool CudnnSupport::DoConvolveBackwardData(
@@ -3323,15 +3125,17 @@ bool CudnnSupport::DoConvolveBackwardData(
ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
- return DoConvolveBackwardDataImpl(stream, filter_descriptor, filter_data,
- output_descriptor, backward_output_data,
- convolution_descriptor, input_descriptor,
- backward_input_data, scratch_allocator,
- algorithm_config, output_profile_result);
+ return IsStatusOk(
+ DoConvolveBackwardDataImpl(stream, filter_descriptor, filter_data,
+ output_descriptor, backward_output_data,
+ convolution_descriptor, input_descriptor,
+ backward_input_data, scratch_allocator,
+ algorithm_config, output_profile_result),
+ /*report_error=*/!output_profile_result);
}
template <class T>
-bool CudnnSupport::DoConvolveBackwardFilterImpl(
+port::Status CudnnSupport::DoConvolveBackwardFilterImpl(
Stream* stream, const dnn::BatchDescriptor& input_descriptor,
const DeviceMemory<T>& input_data,
const dnn::BatchDescriptor& output_descriptor,
@@ -3362,141 +3166,25 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl(
GetConvComputeType<T>());
const bool is_profiling = output_profile_result != nullptr;
- cudnnConvolutionBwdFilterAlgo_t algo;
- DeviceMemory<uint8> scratch;
- if (algorithm_config.algorithm().is_default()) {
- // With the default algorithm, use Cudnn's heuristics.
-
- // Lambda that retrieves the algorithm.
- // specify_limit will occur when we have a scratch allocator and it succeeds
- // in allocating; otherwise, we'll fall back to the "no workspace" version.
- auto get_algorithm = [&](bool specify_limit) {
- cudnnConvolutionBwdFilterPreference_t preference =
- specify_limit ? CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT
- : CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE;
-
- auto memory_limit_bytes =
- scratch_allocator == nullptr
- ? 0
- : scratch_allocator->GetMemoryLimitInBytes(stream);
- if (memory_limit_bytes < 0) {
- memory_limit_bytes = 0;
- }
-
- cudnnConvolutionBwdFilterAlgo_t algo_to_use;
- cudnnStatus_t status = cudnnGetConvolutionBackwardFilterAlgorithm(
- cudnn.handle(),
- /*srcDesc=*/input_nd.handle(),
- /*diffDesc=*/out_back_nd.handle(),
- /*convDesc=*/conv.handle(),
- /*gradDesc=*/filter.handle(),
- /*preference=*/preference,
- /*memoryLimitInBytes=*/memory_limit_bytes,
- /*algo=*/&algo_to_use);
- CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Unable to find a suitable "
- "algorithm for doing backward "
- "filter convolution";
- return algo_to_use;
- };
-
- algo = get_algorithm(/*specify_limit=*/scratch_allocator != nullptr);
-
- if (scratch_allocator != nullptr) {
- size_t size_in_bytes;
- auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
- cudnn.handle(),
- /*xDesc=*/input_nd.handle(),
- /*dyDesc=*/out_back_nd.handle(), /*convDesc=*/conv.handle(),
- /*gradDesc=*/filter.handle(), /*algo=*/algo,
- /*sizeInBytes=*/&size_in_bytes);
- int64 size_in_bytes_int64 = size_in_bytes;
- if (status == CUDNN_STATUS_SUCCESS && size_in_bytes_int64 != 0) {
- if (size_in_bytes_int64 > 0) {
- auto allocated =
- scratch_allocator->AllocateBytes(stream, size_in_bytes);
- if (allocated.ok()) {
- scratch = allocated.ValueOrDie();
- } else {
- LOG(WARNING) << allocated.status().error_message();
- }
- } else {
- LOG(WARNING)
- << "cudnnGetConvolutionBackwardFilterWorkspaceSize() returned "
- "negative sizeInBytes value. This could be a cudnn bug.";
- }
- }
- }
-
- // If we didn't allocate any scratch space (perhaps because of failed
- // allocation), we force a switch back to the "no workspace" algorithm.
- if (scratch == nullptr) {
- algo = get_algorithm(/*specify_limit=*/false);
- }
- } else {
- // An algorithm has been specified.
- dnn::AlgorithmDesc algotype = algorithm_config.algorithm();
- algo = ToConvBackwardFilterAlgo(algotype);
- conv.set_use_tensor_op_math(algotype.tensor_ops_enabled());
-
- size_t size_in_bytes;
- auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
- cudnn.handle(),
- /*xDesc=*/input_nd.handle(),
- /*dyDesc=*/out_back_nd.handle(), /*convDesc=*/conv.handle(),
- /*gradDesc=*/filter.handle(), /*algo=*/algo,
- /*sizeInBytes=*/&size_in_bytes);
- if (status != CUDNN_STATUS_SUCCESS) {
- if (is_profiling) {
- // Silently return when we are profiling.
- return false;
- }
- LOG(FATAL) << "Cannot query the size of workspace needed for the given "
- "algorithm: "
- << algorithm_config.algorithm().algo_id();
- }
- int64 size_in_bytes_int64 = size_in_bytes;
- if (size_in_bytes_int64 > 0) {
- if (scratch_allocator == nullptr) {
- LOG(FATAL) << "An allocator must be specified when scratch memory is "
- "needed";
- }
- auto allocated = scratch_allocator->AllocateBytes(stream, size_in_bytes);
- if (is_profiling && !allocated.ok()) {
- // Silently return when we are profiling.
- return false;
- }
- if (allocated.ok()) {
- scratch = allocated.ValueOrDie();
- } else {
- LOG(WARNING) << allocated.status().error_message();
- }
- if (scratch == nullptr) {
- CHECK(!algorithm_config.algorithm_no_scratch().is_default())
- << "The primary convolution algorithm failed memory allocation, "
- "while a secondary algorithm is not provided.";
- dnn::AlgorithmDesc algotype = algorithm_config.algorithm_no_scratch();
- algo = ToConvBackwardFilterAlgo(algotype);
- conv.set_use_tensor_op_math(algotype.tensor_ops_enabled());
- }
- } else if (size_in_bytes_int64 < 0) {
- LOG(WARNING)
- << "cudnnGetConvolutionBackwardFilterWorkspaceSize() returned "
- "negative sizeInBytes value. This could be a cudnn bug.";
- }
- }
+ DeviceMemory<uint8> scratch;
+ SE_ASSIGN_OR_RETURN(dnn::AlgorithmDesc algo_desc,
+ GetCudnnConvolutionBackwardFilterAlgorithm(
+ stream, cudnn, algorithm_config, input_nd, filter,
+ conv, out_back_nd, scratch_allocator, &scratch));
- std::unique_ptr<CUDATimer> timer;
+ std::unique_ptr<CUDATimer, TimerDeleter> timer;
if (is_profiling) {
timer.reset(new CUDATimer(parent_)); // NOLINT
- timer->Init();
    // The start and stop of the timer should be as close to the Cudnn call as
    // possible. It is still possible for other threads to issue workloads
    // onto this stream, so multiple profiling measurements may be needed.
- timer->Start(AsCUDAStream(stream));
+ if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) {
+ return port::Status(port::error::INTERNAL, "Failed to start timer");
+ }
}
- auto status = cudnnConvolutionBackwardFilter(
+ RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardFilter(
cudnn.handle(),
/*alpha=*/alpha,
/*srcDesc=*/input_nd.handle(),
@@ -3504,33 +3192,22 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl(
/*diffDesc=*/out_back_nd.handle(),
/*diffData=*/backward_output_data.opaque(),
/*convDesc=*/conv.handle(),
- /*algo=*/algo,
+ /*algo=*/ToConvBackwardFilterAlgo(algo_desc),
/*workSpace=*/scratch.opaque(),
/*workSpaceSizeInBytes=*/scratch.size(),
/*beta=*/beta,
/*gradDesc=*/filter.handle(),
- /*gradData=*/backward_filter_data->opaque());
-
+ /*dw=*/backward_filter_data->opaque()));
if (is_profiling) {
- timer->Stop(AsCUDAStream(stream));
- if (status == CUDNN_STATUS_SUCCESS) {
- bool use_tensor_ops = algorithm_config.algorithm().tensor_ops_enabled();
- dnn::AlgorithmDesc algotype(algo, use_tensor_ops);
- output_profile_result->set_algorithm(algotype);
- output_profile_result->set_elapsed_time_in_ms(
- timer->GetElapsedMilliseconds());
- }
- timer->Destroy();
- }
- if (status != CUDNN_STATUS_SUCCESS) {
- // Silently return when we are profiling.
- if (!is_profiling) {
- LOG(ERROR) << "failed to enqueue convolution on stream: "
- << ToString(status);
+ if (!timer->Stop(AsCUDAStream(stream))) {
+ return port::Status(port::error::INTERNAL, "Failed to stop timer");
}
- return false;
+ output_profile_result->set_algorithm(algo_desc);
+ output_profile_result->set_elapsed_time_in_ms(
+ timer->GetElapsedMilliseconds());
}
- return true;
+
+ return port::Status::OK();
}
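
The timer likewise moves from manual Init()/Destroy() calls to a unique_ptr with a custom deleter, so cleanup runs on every exit path, including the new early returns. TimerDeleter is defined elsewhere in this file (outside the hunks shown); presumably it is something along these lines:

    struct TimerDeleter {
      void operator()(CUDATimer* t) const {
        t->Destroy();  // release the underlying CUDA events
        delete t;
      }
    };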
bool CudnnSupport::DoConvolveBackwardFilter(
@@ -3544,11 +3221,13 @@ bool CudnnSupport::DoConvolveBackwardFilter(
ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
- return DoConvolveBackwardFilterImpl(stream, input_descriptor, input_data,
- output_descriptor, backward_output_data,
- convolution_descriptor, filter_descriptor,
- backward_filter_data, scratch_allocator,
- algorithm_config, output_profile_result);
+ return IsStatusOk(
+ DoConvolveBackwardFilterImpl(stream, input_descriptor, input_data,
+ output_descriptor, backward_output_data,
+ convolution_descriptor, filter_descriptor,
+ backward_filter_data, scratch_allocator,
+ algorithm_config, output_profile_result),
+ /*report_error=*/!output_profile_result);
}
bool CudnnSupport::DoConvolveBackwardFilter(
@@ -3562,11 +3241,13 @@ bool CudnnSupport::DoConvolveBackwardFilter(
ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
- return DoConvolveBackwardFilterImpl(stream, input_descriptor, input_data,
- output_descriptor, backward_output_data,
- convolution_descriptor, filter_descriptor,
- backward_filter_data, scratch_allocator,
- algorithm_config, output_profile_result);
+ return IsStatusOk(
+ DoConvolveBackwardFilterImpl(stream, input_descriptor, input_data,
+ output_descriptor, backward_output_data,
+ convolution_descriptor, filter_descriptor,
+ backward_filter_data, scratch_allocator,
+ algorithm_config, output_profile_result),
+ /*report_error=*/!output_profile_result);
}
bool CudnnSupport::DoConvolveBackwardFilter(
@@ -3580,15 +3261,17 @@ bool CudnnSupport::DoConvolveBackwardFilter(
ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
- return DoConvolveBackwardFilterImpl(stream, input_descriptor, input_data,
- output_descriptor, backward_output_data,
- convolution_descriptor, filter_descriptor,
- backward_filter_data, scratch_allocator,
- algorithm_config, output_profile_result);
+ return IsStatusOk(
+ DoConvolveBackwardFilterImpl(stream, input_descriptor, input_data,
+ output_descriptor, backward_output_data,
+ convolution_descriptor, filter_descriptor,
+ backward_filter_data, scratch_allocator,
+ algorithm_config, output_profile_result),
+ /*report_error=*/!output_profile_result);
}
template <class T>
-bool CudnnSupport::DoConvolveBackwardBiasImpl(
+port::Status CudnnSupport::DoConvolveBackwardBiasImpl(
Stream* stream, const dnn::BatchDescriptor& input_descriptor,
const DeviceMemory<T>& input_data,
const dnn::BatchDescriptor& bias_descriptor,
@@ -3603,15 +3286,10 @@ bool CudnnSupport::DoConvolveBackwardBiasImpl(
float beta = 0.0;
auto cudnn = cudnn_->GetHandle(parent_, stream);
- auto status = cudnnConvolutionBackwardBias(
+ RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardBias(
cudnn.handle(), &alpha, input_nd.handle(), input_data.opaque(), &beta,
- bias_nd.handle(), backward_bias_data->opaque());
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "failed to enqueue backward convolution on stream: "
- << ToString(status);
- return false;
- }
- return true;
+ bias_nd.handle(), backward_bias_data->opaque()));
+ return port::Status::OK();
}
bool CudnnSupport::DoConvolveBackwardBias(
@@ -3619,8 +3297,10 @@ bool CudnnSupport::DoConvolveBackwardBias(
const DeviceMemory<double>& input_data,
const dnn::BatchDescriptor& bias_descriptor,
DeviceMemory<double>* backward_bias_data) {
- return DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data,
- bias_descriptor, backward_bias_data);
+ return IsStatusOk(
+ DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data,
+ bias_descriptor, backward_bias_data),
+ /*report_error=*/true);
}
bool CudnnSupport::DoConvolveBackwardBias(
@@ -3628,8 +3308,10 @@ bool CudnnSupport::DoConvolveBackwardBias(
const DeviceMemory<float>& input_data,
const dnn::BatchDescriptor& bias_descriptor,
DeviceMemory<float>* backward_bias_data) {
- return DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data,
- bias_descriptor, backward_bias_data);
+ return IsStatusOk(
+ DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data,
+ bias_descriptor, backward_bias_data),
+ /*report_error=*/true);
}
bool CudnnSupport::DoConvolveBackwardBias(
@@ -3637,8 +3319,10 @@ bool CudnnSupport::DoConvolveBackwardBias(
const DeviceMemory<Eigen::half>& input_data,
const dnn::BatchDescriptor& bias_descriptor,
DeviceMemory<Eigen::half>* backward_bias_data) {
- return DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data,
- bias_descriptor, backward_bias_data);
+ return IsStatusOk(
+ DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data,
+ bias_descriptor, backward_bias_data),
+ /*report_error=*/true);
}
bool CudnnSupport::DoMatMul(Stream* stream,
@@ -3810,16 +3494,13 @@ bool CudnnSupport::DoBiasAdd(Stream* stream,
auto cudnn = cudnn_->GetHandle(parent_, stream);
- auto status = cudnnAddTensor(
- cudnn.handle(), &alpha, bias_descriptor.handle(), biases.opaque(), &beta,
- input_descriptor.handle(), output_data->opaque());
-
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "stream " << stream << " could not enqueue bias addition.";
- return false;
- }
-
- return true;
+ auto status = [&] {
+ RETURN_IF_CUDNN_ERROR(cudnnAddTensor(
+ cudnn.handle(), &alpha, bias_descriptor.handle(), biases.opaque(),
+ &beta, input_descriptor.handle(), output_data->opaque()));
+ return port::Status::OK();
+ }();
+ return IsStatusOk(status, /*report_error=*/true);
}
bool CudnnSupport::DoActivate(Stream* stream,
@@ -3838,16 +3519,13 @@ bool CudnnSupport::DoActivate(Stream* stream,
float beta = 0.0;
auto cudnn = cudnn_->GetHandle(parent_, stream);
- auto status = cudnnActivationForward(
- cudnn.handle(), activation_desc.handle(), &alpha, input_nd.handle(),
- input_data.opaque(), &beta, input_nd.handle(), output_data->opaque());
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "stream " << stream
- << " could not enqueue activation: " << ToString(status);
- return false;
- }
-
- return true;
+ auto status = [&] {
+ RETURN_IF_CUDNN_ERROR(cudnnActivationForward(
+ cudnn.handle(), activation_desc.handle(), &alpha, input_nd.handle(),
+ input_data.opaque(), &beta, input_nd.handle(), output_data->opaque()));
+ return port::Status::OK();
+ }();
+ return IsStatusOk(status, /*report_error=*/true);
}
bool CudnnSupport::DoPoolForward(
@@ -3866,15 +3544,13 @@ bool CudnnSupport::DoPoolForward(
ScopedPoolingDescriptor pooling_desc(pooling_dimensions);
auto cudnn = cudnn_->GetHandle(parent_, stream);
- auto status = cudnnPoolingForward(
- cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
- input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque());
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "failed to enqueue forward pooling on stream: "
- << ToString(status);
- return false;
- }
- return true;
+ auto status = [&] {
+ RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
+ cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
+ input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
+ return port::Status::OK();
+ }();
+ return IsStatusOk(status, /*report_error=*/true);
}
bool CudnnSupport::DoPoolForward(
@@ -3893,15 +3569,13 @@ bool CudnnSupport::DoPoolForward(
ScopedPoolingDescriptor pooling_desc(pooling_dimensions);
auto cudnn = cudnn_->GetHandle(parent_, stream);
- auto status = cudnnPoolingForward(
- cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
- input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque());
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "failed to enqueue forward pooling on stream: "
- << ToString(status);
- return false;
- }
- return true;
+ auto status = [&] {
+ RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
+ cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
+ input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
+ return port::Status::OK();
+ }();
+ return IsStatusOk(status, /*report_error=*/true);
}
bool CudnnSupport::DoPoolForward(
@@ -3919,15 +3593,13 @@ bool CudnnSupport::DoPoolForward(
ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF);
ScopedPoolingDescriptor pooling_desc(pooling_dimensions);
auto cudnn = cudnn_->GetHandle(parent_, stream);
- auto status = cudnnPoolingForward(
- cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
- input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque());
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "failed to enqueue forward pooling on stream: "
- << ToString(status);
- return false;
- }
- return true;
+ auto status = [&] {
+ RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
+ cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
+ input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
+ return port::Status::OK();
+ }();
+ return IsStatusOk(status, /*report_error=*/true);
}
bool CudnnSupport::DoPoolBackward(
@@ -3948,17 +3620,15 @@ bool CudnnSupport::DoPoolBackward(
ScopedPoolingDescriptor pooling_desc(pooling_dimensions);
auto cudnn = cudnn_->GetHandle(parent_, stream);
- auto status = cudnnPoolingBackward(
- cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
- output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
- src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
- output_diff_data->opaque());
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "failed to enqueue backward pooling on stream: "
- << ToString(status);
- return false;
- }
- return true;
+ auto status = [&] {
+ RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward(
+ cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
+ output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
+ src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
+ output_diff_data->opaque()));
+ return port::Status::OK();
+ }();
+ return IsStatusOk(status, /*report_error=*/true);
}
bool CudnnSupport::DoPoolBackward(
@@ -3979,17 +3649,15 @@ bool CudnnSupport::DoPoolBackward(
ScopedPoolingDescriptor pooling_desc(pooling_dimensions);
auto cudnn = cudnn_->GetHandle(parent_, stream);
- auto status = cudnnPoolingBackward(
- cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
- output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
- src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
- output_diff_data->opaque());
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "failed to enqueue backward pooling on stream: "
- << ToString(status);
- return false;
- }
- return true;
+ auto status = [&] {
+ RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward(
+ cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
+ output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
+ src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
+ output_diff_data->opaque()));
+ return port::Status::OK();
+ }();
+ return IsStatusOk(status, /*report_error=*/true);
}
bool CudnnSupport::DoPoolBackward(
@@ -4010,17 +3678,15 @@ bool CudnnSupport::DoPoolBackward(
ScopedPoolingDescriptor pooling_desc(pooling_dimensions);
auto cudnn = cudnn_->GetHandle(parent_, stream);
- auto status = cudnnPoolingBackward(
- cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
- output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
- src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
- output_diff_data->opaque());
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "failed to enqueue backward pooling on stream: "
- << ToString(status);
- return false;
- }
- return true;
+ auto status = [&] {
+ RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward(
+ cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
+ output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
+ src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
+ output_diff_data->opaque()));
+ return port::Status::OK();
+ }();
+ return IsStatusOk(status, /*report_error=*/true);
}
bool CudnnSupport::DoNormalize(
@@ -4055,15 +3721,14 @@ bool CudnnSupport::DoNormalizeWithDimensions(
auto cudnn = cudnn_->GetHandle(parent_, stream);
// Launch the normalization.
- auto status = cudnnLRNCrossChannelForward(
- cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1, &alpha,
- dims.handle(), input_data.opaque(), &beta, dims.handle(),
- output_data->opaque());
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "failed to run cudnnLRNCrossChannelForward";
- return false;
- }
- return true;
+ auto status = [&] {
+ RETURN_IF_CUDNN_ERROR(cudnnLRNCrossChannelForward(
+ cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1,
+ &alpha, dims.handle(), input_data.opaque(), &beta, dims.handle(),
+ output_data->opaque()));
+ return port::Status::OK();
+ }();
+ return IsStatusOk(status, /*report_error=*/true);
}
bool CudnnSupport::DoNormalizeBackwardWithDimensions(
@@ -4089,16 +3754,15 @@ bool CudnnSupport::DoNormalizeBackwardWithDimensions(
float beta = 0.0f;
auto cudnn = cudnn_->GetHandle(parent_, stream);
- auto status = cudnnLRNCrossChannelBackward(
- cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1, &alpha,
- dims.handle(), normalized_data.opaque(), dims.handle(),
- normalized_variable_gradient.opaque(), dims.handle(), raw_data.opaque(),
- &beta, dims.handle(), raw_variable_gradient->opaque());
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "failed to run cudnnLRNCrossChannelBackward";
- return false;
- }
- return true;
+ auto status = [&] {
+ RETURN_IF_CUDNN_ERROR(cudnnLRNCrossChannelBackward(
+ cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1,
+ &alpha, dims.handle(), normalized_data.opaque(), dims.handle(),
+ normalized_variable_gradient.opaque(), dims.handle(), raw_data.opaque(),
+ &beta, dims.handle(), raw_variable_gradient->opaque()));
+ return port::Status::OK();
+ }();
+ return IsStatusOk(status, /*report_error=*/true);
}
bool CudnnSupport::DoDepthConcatenate(
@@ -4213,24 +3877,20 @@ bool CudnnSupport::DeriveOutputBatchDescriptor(
int dn = batch_descriptor.ndims() + 2;
std::vector<int> dims(dn); // in BDYX
- auto status = cudnnGetConvolutionNdForwardOutputDim(
- conv.handle(), input_nd.handle(), filter.handle(), dn, dims.data());
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "could not get output tensor for convolution: "
- << ToString(status);
- return false;
- }
-
- output_batch_descriptor->set_count(dims[0])
- .set_feature_map_count(dims[1])
- .set_layout(batch_descriptor.layout());
-
- for (int i = 0; i < batch_descriptor.ndims(); i++) {
- output_batch_descriptor->set_spatial_dim(static_cast<dnn::DimIndex>(i),
- dims.rbegin()[i]);
- }
+ auto status = [&] {
+ RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionNdForwardOutputDim(
+ conv.handle(), input_nd.handle(), filter.handle(), dn, dims.data()));
+ output_batch_descriptor->set_count(dims[0])
+ .set_feature_map_count(dims[1])
+ .set_layout(batch_descriptor.layout());
- return true;
+ for (int i = 0; i < batch_descriptor.ndims(); i++) {
+ output_batch_descriptor->set_spatial_dim(static_cast<dnn::DimIndex>(i),
+ dims.rbegin()[i]);
+ }
+ return port::Status::OK();
+ }();
+ return IsStatusOk(status, /*report_error=*/true);
}
} // namespace cuda
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h
index e2de3c62d8..c924d41cb5 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.h
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.h
@@ -631,7 +631,7 @@ class CudnnSupport : public dnn::DnnSupport {
std::unique_ptr<class CudnnAccess> cudnn_;
template <class T, class U>
- bool DoBatchNormalizationForwardImpl(
+ port::Status DoBatchNormalizationForwardImpl(
Stream* stream, dnn::DataType input_data_type,
dnn::DataType scale_data_type, const DeviceMemory<T>& x,
const DeviceMemory<U>& scale, const DeviceMemory<U>& offset,
@@ -646,7 +646,7 @@ class CudnnSupport : public dnn::DnnSupport {
std::function<void()> inv_var_to_var);
template <class T, class U>
- bool DoBatchNormalizationBackwardImpl(
+ port::Status DoBatchNormalizationBackwardImpl(
Stream* stream, int cudnn_input_type, int cudnn_scale_type,
const DeviceMemory<T>& y_backprop, const DeviceMemory<T>& x,
const DeviceMemory<U>& scale, const DeviceMemory<U>& mean,
@@ -656,21 +656,20 @@ class CudnnSupport : public dnn::DnnSupport {
DeviceMemory<U>* offset_backprop);
template <class T>
- bool DoConvolveImpl(Stream* stream,
- const dnn::BatchDescriptor& input_descriptor,
- const DeviceMemory<T>& input_data,
- const dnn::FilterDescriptor& filter_descriptor,
- const DeviceMemory<T>& filter_data,
- const dnn::ConvolutionDescriptor& convolution_descriptor,
- const dnn::BatchDescriptor& output_descriptor,
- DeviceMemory<T>* output_data,
- ScratchAllocator* scratch_allocator,
- const dnn::AlgorithmConfig& algorithm_config,
- dnn::ProfileResult* output_profile_result);
+ port::Status DoConvolveImpl(
+ Stream* stream, const dnn::BatchDescriptor& input_descriptor,
+ const DeviceMemory<T>& input_data,
+ const dnn::FilterDescriptor& filter_descriptor,
+ const DeviceMemory<T>& filter_data,
+ const dnn::ConvolutionDescriptor& convolution_descriptor,
+ const dnn::BatchDescriptor& output_descriptor,
+ DeviceMemory<T>* output_data, ScratchAllocator* scratch_allocator,
+ const dnn::AlgorithmConfig& algorithm_config,
+ dnn::ProfileResult* output_profile_result);
template <typename Type, typename BiasType, typename ScaleType,
int cudnn_data_type, int cudnn_compute_type>
- bool DoFusedConvolveImpl(
+ port::Status DoFusedConvolveImpl(
Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
const DeviceMemory<Type>& conv_input_data, ScaleType conv_input_scale,
const dnn::FilterDescriptor& filter_descriptor,
@@ -685,9 +684,8 @@ class CudnnSupport : public dnn::DnnSupport {
dnn::ProfileResult* output_profile_result);
template <class T>
- bool DoConvolveBackwardDataImpl(
- Stream* stream,
- const dnn::FilterDescriptor& filter_descriptor,
+ port::Status DoConvolveBackwardDataImpl(
+ Stream* stream, const dnn::FilterDescriptor& filter_descriptor,
const DeviceMemory<T>& filter_data,
const dnn::BatchDescriptor& output_descriptor,
DeviceMemory<T> backward_output_data,
@@ -698,10 +696,10 @@ class CudnnSupport : public dnn::DnnSupport {
dnn::ProfileResult* output_profile_result);
template <class T>
- bool DoConvolveBackwardFilterImpl(
+ port::Status DoConvolveBackwardFilterImpl(
Stream* stream, const dnn::BatchDescriptor& input_descriptor,
const DeviceMemory<T>& input_data,
- const dnn::BatchDescriptor& output_descriptor_in,
+ const dnn::BatchDescriptor& output_descriptor,
DeviceMemory<T> backward_output_data,
const dnn::ConvolutionDescriptor& convolution_descriptor,
const dnn::FilterDescriptor& filter_descriptor,
@@ -711,56 +709,56 @@ class CudnnSupport : public dnn::DnnSupport {
dnn::ProfileResult* output_profile_result);
template <class T>
- bool DoConvolveBackwardBiasImpl(Stream* stream,
- const dnn::BatchDescriptor& input_descriptor,
- const DeviceMemory<T>& input_data,
- const dnn::BatchDescriptor& bias_descriptor,
- DeviceMemory<T>* backward_bias_data);
+ port::Status DoConvolveBackwardBiasImpl(
+ Stream* stream, const dnn::BatchDescriptor& input_descriptor,
+ const DeviceMemory<T>& input_data,
+ const dnn::BatchDescriptor& bias_descriptor,
+ DeviceMemory<T>* backward_bias_data);
template <class T>
- bool DoRnnForwardImpl(Stream* stream, const CudnnRnnDescriptor& rnn_desc,
- const CudnnRnnSequenceTensorDescriptor& input_desc,
- const DeviceMemory<T>& input_data,
- const CudnnRnnStateTensorDescriptor& input_h_desc,
- const DeviceMemory<T>& input_h_data,
- const CudnnRnnStateTensorDescriptor& input_c_desc,
- const DeviceMemory<T>& input_c_data,
- const DeviceMemory<T>& params,
- const CudnnRnnSequenceTensorDescriptor& output_desc,
- DeviceMemory<T>* output_data,
- const CudnnRnnStateTensorDescriptor& output_h_desc,
- DeviceMemory<T>* output_h_data,
- const CudnnRnnStateTensorDescriptor& output_c_desc,
- DeviceMemory<T>* output_c_data, bool is_training,
- ScratchAllocator* reserve_space_allocator,
- ScratchAllocator* workspace_allocator,
- dnn::ProfileResult* output_profile_result);
+ port::Status DoRnnForwardImpl(
+ Stream* stream, const CudnnRnnDescriptor& rnn_desc,
+ const CudnnRnnSequenceTensorDescriptor& input_desc,
+ const DeviceMemory<T>& input_data,
+ const CudnnRnnStateTensorDescriptor& input_h_desc,
+ const DeviceMemory<T>& input_h_data,
+ const CudnnRnnStateTensorDescriptor& input_c_desc,
+ const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
+ const CudnnRnnSequenceTensorDescriptor& output_desc,
+ DeviceMemory<T>* output_data,
+ const CudnnRnnStateTensorDescriptor& output_h_desc,
+ DeviceMemory<T>* output_h_data,
+ const CudnnRnnStateTensorDescriptor& output_c_desc,
+ DeviceMemory<T>* output_c_data, bool is_training,
+ ScratchAllocator* reserve_space_allocator,
+ ScratchAllocator* workspace_allocator,
+ dnn::ProfileResult* output_profile_result);
template <class T>
- bool DoRnnBackwardImpl(Stream* stream, const CudnnRnnDescriptor& rnn_desc,
- const CudnnRnnSequenceTensorDescriptor& input_desc,
- const DeviceMemory<T>& input_data,
- const CudnnRnnStateTensorDescriptor& input_h_desc,
- const DeviceMemory<T>& input_h_data,
- const CudnnRnnStateTensorDescriptor& input_c_desc,
- const DeviceMemory<T>& input_c_data,
- const DeviceMemory<T>& params,
- const CudnnRnnSequenceTensorDescriptor& output_desc,
- const DeviceMemory<T>& output_data,
- const CudnnRnnStateTensorDescriptor& output_h_desc,
- const DeviceMemory<T>& output_h_data,
- const CudnnRnnStateTensorDescriptor& output_c_desc,
- const DeviceMemory<T>& output_c_data,
- const DeviceMemory<T>& output_backprop_data,
- const DeviceMemory<T>& output_h_backprop_data,
- const DeviceMemory<T>& output_c_backprop_data,
- DeviceMemory<T>* input_backprop_data,
- DeviceMemory<T>* input_h_backprop_data,
- DeviceMemory<T>* input_c_backprop_data,
- DeviceMemory<T>* params_backprop_data,
- DeviceMemory<uint8>* reserve_space_data,
- ScratchAllocator* workspace_allocator,
- dnn::ProfileResult* output_profile_result);
+ port::Status DoRnnBackwardImpl(
+ Stream* stream, const CudnnRnnDescriptor& rnn_desc,
+ const CudnnRnnSequenceTensorDescriptor& input_desc,
+ const DeviceMemory<T>& input_data,
+ const CudnnRnnStateTensorDescriptor& input_h_desc,
+ const DeviceMemory<T>& input_h_data,
+ const CudnnRnnStateTensorDescriptor& input_c_desc,
+ const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
+ const CudnnRnnSequenceTensorDescriptor& output_desc,
+ const DeviceMemory<T>& output_data,
+ const CudnnRnnStateTensorDescriptor& output_h_desc,
+ const DeviceMemory<T>& output_h_data,
+ const CudnnRnnStateTensorDescriptor& output_c_desc,
+ const DeviceMemory<T>& output_c_data,
+ const DeviceMemory<T>& output_backprop_data,
+ const DeviceMemory<T>& output_h_backprop_data,
+ const DeviceMemory<T>& output_c_backprop_data,
+ DeviceMemory<T>* input_backprop_data,
+ DeviceMemory<T>* input_h_backprop_data,
+ DeviceMemory<T>* input_c_backprop_data,
+ DeviceMemory<T>* params_backprop_data,
+ DeviceMemory<uint8>* reserve_space_data,
+ ScratchAllocator* workspace_allocator,
+ dnn::ProfileResult* output_profile_result);
SE_DISALLOW_COPY_AND_ASSIGN(CudnnSupport);
};
diff --git a/tensorflow/stream_executor/cuda/cuda_timer.h b/tensorflow/stream_executor/cuda/cuda_timer.h
index 70554ec931..e040cf86fa 100644
--- a/tensorflow/stream_executor/cuda/cuda_timer.h
+++ b/tensorflow/stream_executor/cuda/cuda_timer.h
@@ -37,8 +37,9 @@ class CUDATimer : public internal::TimerInterface {
explicit CUDATimer(CUDAExecutor *parent)
: parent_(parent), start_event_(nullptr), stop_event_(nullptr) {}
- // Note: teardown is explicitly handled in this API by a call to
+ // Note: teardown needs to be explicitly handled in this API by a call to
// StreamExecutor::DeallocateTimer(), which invokes Destroy().
+ // TODO(csigg): Change to RAII.
~CUDATimer() override {}
// Allocates the platform-specific pieces of the timer, called as part of
diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc
index 5315d1f3da..82aa8ceb32 100644
--- a/tensorflow/stream_executor/dnn.cc
+++ b/tensorflow/stream_executor/dnn.cc
@@ -141,6 +141,10 @@ string PadAlignmentString(PadAlignment alignment) {
return "unknown pad alignment";
}
+std::ostream& operator<<(std::ostream& str, dnn::PadAlignment alignment) {
+ return str << PadAlignmentString(alignment);
+}
+
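
With the streaming operator defined, CHECK_EQ between two PadAlignments compiles and prints the human-readable names on mismatch. A hypothetical use, with conv_a and conv_b standing in for two ConvolutionDescriptors:

    // Without operator<<, CHECK_EQ would fail to compile when trying to
    // stream the mismatching values into the error message.
    CHECK_EQ(conv_a.pad_alignment(), conv_b.pad_alignment());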
string ShortPoolingModeString(PoolingMode mode) {
switch (mode) {
case PoolingMode::kMaximum:
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index 3df5365c23..9eca5abe1a 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -469,6 +469,9 @@ enum class PadAlignment : int64 {
// Returns a string representation of the given padding alignment.
string PadAlignmentString(PadAlignment alignment);
+// Prints the alignment to str. Needed to use CHECK_EQ between two PadAlignments.
+std::ostream& operator<<(std::ostream& str, dnn::PadAlignment alignment);
+
// Describes a convolution.
//
// Uses the named argument construction form:
@@ -710,7 +713,7 @@ class PoolingDescriptor {
class AlgorithmDesc {
public:
typedef int64 Index;
- AlgorithmDesc() : algo_(kDefaultAlgorithm), tensor_ops_enabled_(false) {}
+ AlgorithmDesc() : algo_(kDefaultAlgorithm), tensor_ops_enabled_(true) {}
AlgorithmDesc(Index a, bool use_tensor_ops)
: algo_(a), tensor_ops_enabled_(use_tensor_ops) {}
bool is_default() const { return algo_ == kDefaultAlgorithm; }
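
The final hunk changes default-constructed AlgorithmDesc objects to report tensor ops as enabled. The observable effect, as a small check:

    dnn::AlgorithmDesc desc;           // no explicit algorithm chosen
    CHECK(desc.is_default());
    CHECK(desc.tensor_ops_enabled());  // was false before this change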