author Xiaoqiang Zheng <zhengxq@google.com> 2016-08-19 10:01:32 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2016-08-19 11:18:09 -0700
commit 859e47fd8ebcdd7fb1411fd0090d0e95a801e7cb (patch)
tree ae0c4e4e690d53b92720049fb75acc119e4ea5e0
parent de8838042fb34e53a511f25b4613611fc368beeb (diff)
Add stream-executor changes to enable Cudnn fused LSTM/RNN support.
Change: 130770287
-rw-r--r-- tensorflow/stream_executor/cuda/cuda_dnn.cc          | 1047
-rw-r--r-- tensorflow/stream_executor/cuda/cuda_dnn.h           |  102
-rw-r--r-- tensorflow/stream_executor/dnn.h                     |  264
-rw-r--r-- tensorflow/stream_executor/platform/port.h           |    2
-rw-r--r-- tensorflow/stream_executor/stream.cc                 |   73
-rw-r--r-- tensorflow/stream_executor/stream.h                  |   45
-rw-r--r-- tensorflow/stream_executor/stream_executor_pimpl.cc  |   42
-rw-r--r-- tensorflow/stream_executor/stream_executor_pimpl.h   |   20
8 files changed, 1594 insertions, 1 deletion
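
Before the per-file diffs, a quick orientation: the commit adds RNN descriptor factories to dnn::DnnSupport and forward/backward entry points to Stream. The following is a minimal sketch of how a caller might drive the new API; the concrete sizes, the dnn/stream pointers, the allocators, and the pre-allocated DeviceMemory buffers are illustrative assumptions, not part of the commit:

// Hedged sketch: a single-layer, unidirectional, float LSTM.
auto rnn_desc = dnn->createRnnDescriptor(
                       /*num_layers=*/1, /*hidden_size=*/128,
                       /*input_size=*/64, dnn::RnnInputMode::kRnnLinearSkip,
                       dnn::RnnDirectionMode::kRnnUnidirectional,
                       dnn::RnnMode::kRnnLstm, dnn::DataType::kFloat,
                       /*dropout=*/0.f, /*seed=*/0, state_allocator)
                    .ValueOrDie();
auto in_seq = dnn->createRnnSequenceTensorDescriptor(
                     /*seq_length=*/20, /*batch_size=*/32, /*data_size=*/64,
                     dnn::DataType::kFloat)
                  .ValueOrDie();
auto out_seq = dnn->createRnnSequenceTensorDescriptor(
                      /*seq_length=*/20, /*batch_size=*/32, /*data_size=*/128,
                      dnn::DataType::kFloat)
                   .ValueOrDie();
auto state = dnn->createRnnStateTensorDescriptor(
                    /*num_layer=*/1, /*batch_size=*/32, /*data_size=*/128,
                    dnn::DataType::kFloat)
                 .ValueOrDie();
// "params" must hold rnn_desc->ParamsSizeInBytes() bytes; all DeviceMemory
// buffers here are assumed to be allocated and initialized by the caller.
stream->ThenRnnForward(*rnn_desc, *in_seq, input_data, *state, input_h_data,
                       *state, input_c_data, params, *out_seq, &output_data,
                       *state, &output_h_data, *state, &output_c_data,
                       /*is_training=*/true, reserve_space_allocator,
                       workspace_allocator);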
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index b042dda29f..7fbafa3d7e 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -238,7 +238,25 @@ CUDNN_DNN_ROUTINE_EACH_R3(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
__macro(cudnnCreateActivationDescriptor) \
__macro(cudnnSetActivationDescriptor) \
__macro(cudnnGetActivationDescriptor) \
- __macro(cudnnDestroyActivationDescriptor)
+ __macro(cudnnDestroyActivationDescriptor) \
+ __macro(cudnnCreateDropoutDescriptor) \
+ __macro(cudnnDestroyDropoutDescriptor) \
+ __macro(cudnnSetDropoutDescriptor) \
+ __macro(cudnnDropoutGetStatesSize) \
+ __macro(cudnnCreateRNNDescriptor) \
+ __macro(cudnnDestroyRNNDescriptor) \
+ __macro(cudnnGetRNNParamsSize) \
+ __macro(cudnnGetRNNWorkspaceSize) \
+ __macro(cudnnGetRNNTrainingReserveSize) \
+ __macro(cudnnGetRNNLinLayerMatrixParams) \
+ __macro(cudnnGetRNNLinLayerBiasParams) \
+ __macro(cudnnRNNForwardInference) \
+ __macro(cudnnRNNForwardTraining) \
+ __macro(cudnnRNNBackwardData) \
+ __macro(cudnnRNNBackwardWeights) \
+ __macro(cudnnSetRNNDescriptor) \
+ __macro(cudnnGetFilterNdDescriptor)
+
// clang-format on
CUDNN_DNN_ROUTINE_EACH_R5(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
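
The routine list above is expanded through PERFTOOLS_GPUTOOLS_CUDNN_WRAP, defined earlier in this file: for each listed name it emits a shim that lazily resolves the symbol from the cuDNN shared library and activates the executor's CUDA context before calling through. A simplified sketch of the shape of that shim follows; the real macro also handles DSO lookup failures and symbol caching, and the helper name GetCudnnDsoHandle here is a schematic assumption:

/* Schematic only; not the literal macro body. */
#define PERFTOOLS_GPUTOOLS_CUDNN_WRAP(__name)                        \
  struct DynLoadShim__##__name {                                     \
    template <typename... Args>                                      \
    cudnnStatus_t operator()(CUDAExecutor* parent, Args... args) {   \
      /* Make the executor's CUDA context current for the call. */   \
      cuda::ScopedActivateExecutorContext sac{parent};               \
      /* Resolve the cuDNN symbol once, on first use. */             \
      static auto* fn = reinterpret_cast<decltype(::__name)*>(       \
          dlsym(GetCudnnDsoHandle(), #__name));                      \
      return fn(args...);                                            \
    }                                                                \
  } __name;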
@@ -759,6 +777,1033 @@ class ScopedActivationDescriptor {
};
#endif
+namespace {
+
+#if CUDNN_VERSION >= 5000
+
+cudnnRNNInputMode_t ToCudnnRnnInputMode(dnn::RnnInputMode input_mode) {
+ switch (input_mode) {
+ case dnn::RnnInputMode::kRnnLinearSkip:
+ case dnn::RnnInputMode::kRnnSkipInput:
+ return static_cast<cudnnRNNInputMode_t>(input_mode);
+ default:
+ LOG(FATAL) << "Invalid RNN input mode: " << static_cast<int>(input_mode);
+ }
+}
+
+cudnnDirectionMode_t ToCudnnRnnDirectionMode(
+ dnn::RnnDirectionMode direction_mode) {
+ switch (direction_mode) {
+ case dnn::RnnDirectionMode::kRnnUnidirectional:
+ case dnn::RnnDirectionMode::kRnnBidirectional:
+ return static_cast<cudnnDirectionMode_t>(direction_mode);
+ default:
+ LOG(FATAL) << "Invalid RNN direction mode: "
+ << static_cast<int>(direction_mode);
+ }
+}
+
+cudnnRNNMode_t ToCudnnRnnMode(dnn::RnnMode rnn_mode) {
+ switch (rnn_mode) {
+ case dnn::RnnMode::kRnnRelu:
+ case dnn::RnnMode::kRnnTanh:
+ case dnn::RnnMode::kRnnLstm:
+ case dnn::RnnMode::kRnnGru:
+ return static_cast<cudnnRNNMode_t>(rnn_mode);
+ default:
+ LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode);
+ }
+}
+
+cudnnDataType_t ToCudnnDataType(dnn::DataType data_type) {
+ switch (data_type) {
+ case dnn::DataType::kFloat:
+ case dnn::DataType::kDouble:
+ case dnn::DataType::kHalf:
+ return static_cast<cudnnDataType_t>(data_type);
+ default:
+ LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
+ }
+}
+
+int CudnnDataTypeToByteSize(cudnnDataType_t data_type) {
+ switch (data_type) {
+ case CUDNN_DATA_FLOAT:
+ return sizeof(float);
+ case CUDNN_DATA_DOUBLE:
+ return sizeof(double);
+ case CUDNN_DATA_HALF:
+ return sizeof(Eigen::half);
+ default:
+ LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
+ }
+}
+
+#endif // CUDNN_VERSION
+
+template <typename Base>
+class MixinBase : public Base {};
+template <>
+class MixinBase<void> {};
+
+} // namespace
+
+#if CUDNN_VERSION >= 5000
+
+#define CUDNN_RETURN_IF_FAIL(STATUS, ...) \
+ if (!SE_PREDICT_TRUE((STATUS) == CUDNN_STATUS_SUCCESS)) { \
+ string error_msg = port::StrCat(ToString(STATUS), " ", __VA_ARGS__); \
+ SetFailure(port::Status(port::error::UNKNOWN, error_msg)); \
+ LOG(ERROR) << error_msg; \
+ return; \
+ }
+
+template <typename Base>
+class CudnnDescriptorCommon : public MixinBase<Base> {
+ public:
+ bool ok() const { return status_.ok(); }
+ port::Status Status() const { return status_; }
+
+ protected:
+ void SetFailure(const port::Status& status) { status_.Update(status); }
+ port::Status status_;
+};
+
+class CudnnDropoutDescriptor : public CudnnDescriptorCommon<void> {
+ public:
+ CudnnDropoutDescriptor(CUDAExecutor* parent, cudnnHandle_t cudnn_handle,
+ float dropout, uint64 seed,
+ ScratchAllocator* state_allocator)
+ : parent_(parent), handle_(nullptr) {
+ cudnnStatus_t status;
+ status = dynload::cudnnCreateDropoutDescriptor(parent_, &handle_);
+ CUDNN_RETURN_IF_FAIL(status, "Failed to create dropout descriptor");
+
+ if (dropout == 0.f) {
+ return;
+ }
+
+ DeviceMemory<uint8> state_memory;
+ if (state_allocator) {
+ size_t state_sizes_in_bytes = 0;
+ status = dynload::cudnnDropoutGetStatesSize(parent_, cudnn_handle,
+ &state_sizes_in_bytes);
+ CUDNN_RETURN_IF_FAIL(status, "Failed to query dropout state sizes");
+
+ auto allocated =
+ state_allocator->AllocateBytes(nullptr, state_sizes_in_bytes);
+ if (!allocated.ok() ||
+ (state_memory = allocated.ValueOrDie()) == nullptr) {
+        string error_msg =
+            port::StrCat("Failed to allocate Cudnn dropout state memory");
+ status_ = port::Status(port::error::UNKNOWN, error_msg);
+ LOG(ERROR) << error_msg;
+ return;
+ }
+ }
+ status = dynload::cudnnSetDropoutDescriptor(parent_, handle_, cudnn_handle,
+ dropout, state_memory.opaque(),
+ state_memory.size(), seed);
+ CUDNN_RETURN_IF_FAIL(status, "Failed to set dropout descriptor");
+ }
+
+ ~CudnnDropoutDescriptor() {
+ if (handle_) {
+ cudnnStatus_t status =
+ dynload::cudnnDestroyDropoutDescriptor(parent_, handle_);
+      CUDNN_RETURN_IF_FAIL(status, "Failed to destroy Cudnn dropout handle");
+ }
+ }
+
+ cudnnDropoutDescriptor_t handle() const {
+ if (!ok()) return nullptr;
+ return handle_;
+ }
+
+ private:
+ CUDAExecutor* parent_;
+ cudnnDropoutDescriptor_t handle_;
+ float dropout_;
+ uint64 seed_;
+ port::Status status_;
+ SE_DISALLOW_COPY_AND_ASSIGN(CudnnDropoutDescriptor);
+};
+
+class CudnnRnnParamsDescriptor : public CudnnDescriptorCommon<void> {
+ public:
+ typedef dnn::RnnDescriptor::ParamsRegion ParamsRegion;
+ typedef dnn::RnnDescriptor::ParamsRegions ParamsRegions;
+ CudnnRnnParamsDescriptor(CUDAExecutor* parent, cudnnHandle_t cudnn_handle,
+ const CudnnRnnDescriptor& rnn_desc);
+ ~CudnnRnnParamsDescriptor() {
+ cudnnStatus_t status =
+ dynload::cudnnDestroyFilterDescriptor(parent_, handle_);
+    CUDNN_RETURN_IF_FAIL(status, "Failed to destroy RNN filter descriptor");
+ }
+ cudnnFilterDescriptor_t handle() const {
+ if (!ok()) return nullptr;
+ return handle_;
+ }
+ int64 params_size_in_bytes() const { return params_size_in_bytes_; }
+ ParamsRegions params_weights() const {
+ if (!ok()) return ParamsRegions();
+ return weights_;
+ }
+ ParamsRegions params_biases() const {
+ if (!ok()) return ParamsRegions();
+ return biases_;
+ }
+
+ private:
+ int GetRegionCountPerLayer() const;
+ CUDAExecutor* parent_;
+ cudnnFilterDescriptor_t handle_;
+ const CudnnRnnDescriptor* rnn_desc_;
+ int64 params_size_in_bytes_;
+ ParamsRegions weights_;
+ ParamsRegions biases_;
+ port::Status status_;
+ SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnParamsDescriptor);
+};
+
+class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
+ public:
+ CudnnRnnDescriptor(CUDAExecutor* parent, cudnnHandle_t cudnn_handle,
+ int num_layers, int hidden_size, int input_size,
+ cudnnRNNInputMode_t input_mode,
+ cudnnDirectionMode_t direction_mode,
+ cudnnRNNMode_t rnn_mode, cudnnDataType_t data_type,
+ float dropout, uint64 seed,
+ ScratchAllocator* state_allocator)
+ : parent_(parent),
+ rnn_desc_(nullptr),
+ num_layers_(num_layers),
+ hidden_size_(hidden_size),
+ input_size_(input_size),
+ input_mode_(input_mode),
+ direction_mode_(direction_mode),
+ rnn_mode_(rnn_mode),
+ data_type_(data_type) {
+ // Create the dropout handle.
+ cudnn_dropout_desc_.reset(new CudnnDropoutDescriptor(
+ parent, cudnn_handle, dropout, seed, state_allocator));
+ if (!cudnn_dropout_desc_->ok()) {
+ SetFailure(cudnn_dropout_desc_->Status());
+ return;
+ }
+
+ // Create the RNN handle
+ cudnnStatus_t status =
+ dynload::cudnnCreateRNNDescriptor(parent_, &rnn_desc_);
+ CUDNN_RETURN_IF_FAIL(status, "Unable to create RNN descriptor");
+ status = dynload::cudnnSetRNNDescriptor(
+ parent, rnn_desc_ /*rnnDesc*/, hidden_size /*hiddenSize*/,
+ num_layers /*numLayers*/, dropout_handle() /*dropoutDesc*/,
+ input_mode /*inputMode*/, direction_mode /*direction*/,
+ rnn_mode /*mode*/, data_type /*dataType*/);
+ CUDNN_RETURN_IF_FAIL(status, "Unable to update RNN descriptor");
+
+ // Create the params handle.
+ cudnn_params_desc_.reset(
+ new CudnnRnnParamsDescriptor(parent, cudnn_handle, *this));
+ if (!cudnn_params_desc_->ok()) {
+ SetFailure(cudnn_params_desc_->Status());
+ return;
+ }
+ }
+ ~CudnnRnnDescriptor() override {
+ if (rnn_desc_) {
+ cudnnStatus_t status =
+ dynload::cudnnDestroyRNNDescriptor(parent_, rnn_desc_);
+ CUDNN_RETURN_IF_FAIL(status, "Unable to destroy RNN descriptor");
+ }
+ }
+ cudnnRNNDescriptor_t handle() const {
+ if (!ok()) return nullptr;
+ return rnn_desc_;
+ }
+ int num_layers() const { return num_layers_; }
+ int hidden_size() const { return hidden_size_; }
+ int input_size() const { return input_size_; }
+ cudnnRNNInputMode_t input_mode() const { return input_mode_; }
+ cudnnDirectionMode_t direction_mode() const { return direction_mode_; }
+ cudnnRNNMode_t rnn_mode() const { return rnn_mode_; }
+ cudnnDataType_t data_type() const { return data_type_; }
+ int64 ParamsSizeInBytes() const override {
+ return cudnn_params_desc_->params_size_in_bytes();
+ }
+ cudnnDropoutDescriptor_t dropout_handle() const {
+ if (!cudnn_dropout_desc_) return nullptr;
+ return cudnn_dropout_desc_->handle();
+ }
+ cudnnFilterDescriptor_t params_handle() const {
+ if (!cudnn_params_desc_) return nullptr;
+ return cudnn_params_desc_->handle();
+ }
+ ParamsRegions ParamsWeightRegions() const override {
+ if (!ok()) return ParamsRegions();
+ return cudnn_params_desc_->params_weights();
+ }
+ ParamsRegions ParamsBiasRegions() const override {
+ if (!ok()) return ParamsRegions();
+ return cudnn_params_desc_->params_biases();
+ }
+
+ private:
+ CUDAExecutor* parent_;
+ cudnnRNNDescriptor_t rnn_desc_;
+ int num_layers_;
+ int hidden_size_;
+ int input_size_;
+ cudnnRNNInputMode_t input_mode_;
+ cudnnDirectionMode_t direction_mode_;
+ cudnnRNNMode_t rnn_mode_;
+ cudnnDataType_t data_type_;
+ port::Status status_;
+ std::unique_ptr<CudnnDropoutDescriptor> cudnn_dropout_desc_;
+ std::unique_ptr<CudnnRnnParamsDescriptor> cudnn_params_desc_;
+ SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnDescriptor);
+};
+
+CudnnRnnParamsDescriptor::CudnnRnnParamsDescriptor(
+ CUDAExecutor* parent, cudnnHandle_t cudnn_handle,
+ const CudnnRnnDescriptor& rnn_desc)
+ : parent_(parent),
+ handle_(nullptr),
+ rnn_desc_(&rnn_desc),
+ params_size_in_bytes_(0) {
+ cudnnTensorDescriptor_t input_desc = nullptr;
+ {
+ // Query the params size.
+ auto status = dynload::cudnnCreateTensorDescriptor(parent, &input_desc);
+ CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to create tensor descriptor");
+ int dims[] = {1, rnn_desc.input_size(), 1};
+ int strides[] = {dims[1] * dims[2], dims[2], 1};
+ status = dynload::cudnnSetTensorNdDescriptor(
+ parent, input_desc /*tensorDesc*/, rnn_desc.data_type() /*dataType*/,
+ sizeof(dims) / sizeof(dims[0]) /*nbDims*/, dims /*dimA*/,
+ strides /*strideA*/);
+ CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to set tensor descriptor");
+
+ size_t params_size = 0;
+ status = dynload::cudnnGetRNNParamsSize(
+ parent, cudnn_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/,
+ input_desc /*xDesc*/, &params_size /*sizeInBytes*/,
+ rnn_desc.data_type() /*dataType*/);
+ CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to get RNN parameter size");
+ params_size_in_bytes_ = static_cast<int64>(params_size);
+ }
+
+ {
+ // Create the params descriptor.
+ auto status = dynload::cudnnCreateFilterDescriptor(parent, &handle_);
+ CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to create RNN filter descriptor");
+ int dims[] = {static_cast<int>(params_size_in_bytes_), 1, 1};
+ status = dynload::cudnnSetFilterNdDescriptor(
+ parent, handle_ /*filterDesc*/, rnn_desc.data_type() /*dataType*/,
+ CUDNN_TENSOR_NCHW /*format*/, sizeof(dims) / sizeof(dims[0]) /*nbDims*/,
+ dims /*filterDimA*/);
+ CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to update RNN filter descriptor");
+ }
+
+ {
+    // Locate the weight and bias regions within the params buffer.
+ int region_count_per_layer = GetRegionCountPerLayer();
+ cudnnFilterDescriptor_t region_desc_handle = nullptr;
+ auto status =
+ dynload::cudnnCreateFilterDescriptor(parent, &region_desc_handle);
+ CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to create filter descriptor");
+ for (int layer = 0; layer < rnn_desc.num_layers(); layer++) {
+ for (int region = 0; region < region_count_per_layer; region++) {
+ for (int type = 0; type < 2; type++) {
+ void* offset = nullptr;
+ if (type == 0) {
+ status = dynload::cudnnGetRNNLinLayerMatrixParams(
+ parent, cudnn_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/,
+ layer /*layer*/, input_desc /*xDesc*/, handle_ /*wDesc*/,
+ nullptr /*w*/, region /*linLayerID*/,
+ region_desc_handle /*linLayerMatDesc*/,
+ &offset /*linLayerMat*/);
+ CUDNN_RETURN_IF_FAIL(
+ status, "Cudnn fails to call cudnnGetRNNLinLayerMatrixParams");
+ } else {
+ status = dynload::cudnnGetRNNLinLayerBiasParams(
+              parent, cudnn_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/,
+ layer /*layer*/, input_desc /*xDesc*/, handle_ /*wDesc*/,
+ nullptr /*w*/, region /*linLayerID*/,
+ region_desc_handle /*linLayerBiasDesc*/,
+ &offset /*linLayerBias*/);
+ CUDNN_RETURN_IF_FAIL(
+ status, "Cudnn fails to call cudnnGetRNNLinLayerBiasParams");
+ }
+ int dims[] = {1, 1, 1};
+ cudnnDataType_t data_type;
+ cudnnTensorFormat_t tensor_format;
+ int n_dims;
+ status = dynload::cudnnGetFilterNdDescriptor(
+ parent, region_desc_handle /*filterDesc*/,
+ sizeof(dims) / sizeof(dims[0]) /*nbDimsRequested*/,
+ &data_type /*dataType*/, &tensor_format /*format*/,
+ &n_dims /*nbDims*/, dims /*filterDimA*/);
+ CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to get filter description");
+ int64 size = dims[0] * dims[1] * dims[2] *
+ CudnnDataTypeToByteSize(rnn_desc.data_type());
+ auto region = ParamsRegion{reinterpret_cast<int64>(offset), size};
+ if (type == 0) {
+ weights_.push_back(region);
+ } else {
+ biases_.push_back(region);
+ }
+ }
+ }
+ }
+ status = dynload::cudnnDestroyFilterDescriptor(parent, region_desc_handle);
+ CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to destroy filter descriptor");
+ }
+
+ {
+ // Release the dummy input tensor descriptor.
+ auto status = dynload::cudnnDestroyTensorDescriptor(parent, input_desc);
+ CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to destroy tensor descriptor");
+ }
+}
+
+int CudnnRnnParamsDescriptor::GetRegionCountPerLayer() const {
+ auto rnn_mode = rnn_desc_->rnn_mode();
+ switch (rnn_mode) {
+ case CUDNN_RNN_RELU:
+ case CUDNN_RNN_TANH:
+ return 2;
+ case CUDNN_LSTM:
+ return 8;
+ case CUDNN_GRU:
+ return 6;
+ default:
+ LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode);
+ }
+}
+
+class CudnnRnnSequenceTensorDescriptor
+ : public CudnnDescriptorCommon<dnn::RnnSequenceTensorDescriptor> {
+ public:
+ CudnnRnnSequenceTensorDescriptor(CUDAExecutor* parent, int seq_length,
+ int batch_size, int data_size,
+ cudnnDataType_t data_type)
+ : parent_(parent),
+ seq_length_(seq_length),
+ batch_size_(batch_size),
+ data_size_(data_size),
+ data_type_(data_type) {
+ cudnnTensorDescriptor_t handle = nullptr;
+ if (seq_length <= 0) {
+ string error_msg =
+ port::StrCat("sequence length must be positive: ", seq_length);
+ LOG(ERROR) << error_msg;
+ SetFailure(port::Status(port::error::UNKNOWN, error_msg));
+ return;
+ }
+ cudnnStatus_t status =
+ dynload::cudnnCreateTensorDescriptor(parent, &handle);
+ CUDNN_RETURN_IF_FAIL(status, "Failed to create tensor descriptor");
+ int dims[] = {batch_size, data_size, 1};
+ int strides[] = {dims[1] * dims[2], dims[2], 1};
+ status = dynload::cudnnSetTensorNdDescriptor(
+ parent, handle /*tensorDesc*/, data_type /*dataType*/,
+ sizeof(dims) / sizeof(dims[0]) /*nbDims*/, dims /*dimA*/,
+ strides /*strideA*/);
+ CUDNN_RETURN_IF_FAIL(status, "Failed to update tensor descriptor");
+ // Replicate handle across the number of steps.
+ handles_.assign(seq_length, handle);
+ }
+
+ ~CudnnRnnSequenceTensorDescriptor() override {
+ // Only the first one needs to be destroyed. All others are the same.
+ cudnnStatus_t status =
+ dynload::cudnnDestroyTensorDescriptor(parent_, handles_[0]);
+    CUDNN_RETURN_IF_FAIL(status,
+                         "Failed to destroy sequence tensor descriptor");
+ }
+
+ const cudnnTensorDescriptor_t* handles() const {
+ if (!ok()) return nullptr;
+ CHECK(!handles_.empty()) << "handles cannot be empty";
+ return handles_.data();
+ }
+
+ int seq_length() const { return seq_length_; }
+ int batch_size() const { return batch_size_; }
+ int data_size() const { return data_size_; }
+
+ private:
+ CUDAExecutor* parent_;
+ int seq_length_;
+ int batch_size_;
+ int data_size_;
+ cudnnDataType_t data_type_;
+ std::vector<cudnnTensorDescriptor_t> handles_;
+ port::Status status_;
+ SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnSequenceTensorDescriptor);
+};
+
+class CudnnRnnStateTensorDescriptor
+ : public CudnnDescriptorCommon<dnn::RnnStateTensorDescriptor> {
+ public:
+ CudnnRnnStateTensorDescriptor(CUDAExecutor* parent, int num_layers,
+ int batch_size, int data_size,
+ cudnnDataType_t data_type)
+ : parent_(parent),
+ handle_(nullptr),
+ num_layers_(num_layers),
+ batch_size_(batch_size),
+ data_size_(data_size),
+ data_type_(data_type) {
+ cudnnStatus_t status =
+ dynload::cudnnCreateTensorDescriptor(parent, &handle_);
+ CUDNN_RETURN_IF_FAIL(status, "Failed to create tensor descriptor");
+ int dims[] = {num_layers, batch_size, data_size};
+ int strides[] = {dims[1] * dims[2], dims[2], 1};
+ status = dynload::cudnnSetTensorNdDescriptor(
+ parent, handle_ /*tensorDesc*/, data_type /*dataType*/,
+ sizeof(dims) / sizeof(dims[0]) /*nbDims*/, dims /*dimA*/,
+ strides /*strideA*/);
+ CUDNN_RETURN_IF_FAIL(status, "Failed to update tensor descriptor");
+ }
+
+ ~CudnnRnnStateTensorDescriptor() override {
+    if (handle_) {
+ cudnnStatus_t status =
+ dynload::cudnnDestroyTensorDescriptor(parent_, handle_);
+ CUDNN_RETURN_IF_FAIL(status, "Unable to destroy RNN state tensor");
+ }
+ }
+
+ cudnnTensorDescriptor_t handle() const {
+ if (!ok()) return nullptr;
+ return handle_;
+ }
+ int num_layers() const { return num_layers_; }
+ int batch_size() const { return batch_size_; }
+ int data_size() const { return data_size_; }
+
+ private:
+ CUDAExecutor* parent_;
+ cudnnTensorDescriptor_t handle_;
+ int num_layers_;
+ int batch_size_;
+ int data_size_;
+ port::Status status_;
+ cudnnDataType_t data_type_;
+ SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnStateTensorDescriptor);
+};
+
+namespace {
+
+struct RnnModelDims {
+ int num_layers = 0;
+ int batch_size = 0;
+ int seq_length = 0;
+ int hidden_size = 0;
+ int input_size = 0;
+ int dir_count = 0;
+};
+
+template <class T>
+bool ExtractAndCheckRnnForward(
+ const CudnnRnnDescriptor& rnn_desc,
+ const CudnnRnnSequenceTensorDescriptor& input_desc,
+ const DeviceMemory<T>& input_data,
+ const CudnnRnnStateTensorDescriptor& input_h_desc,
+ const DeviceMemory<T>& input_h_data,
+ const CudnnRnnStateTensorDescriptor& input_c_desc,
+ const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
+ const CudnnRnnSequenceTensorDescriptor& output_desc,
+ const DeviceMemory<T>& output_data,
+ const CudnnRnnStateTensorDescriptor& output_h_desc,
+ const DeviceMemory<T>& output_h_data,
+ const CudnnRnnStateTensorDescriptor& output_c_desc,
+ const DeviceMemory<T>& output_c_data, RnnModelDims* model_dims) {
+ // extract model parameters
+ model_dims->num_layers = rnn_desc.num_layers();
+ model_dims->batch_size = input_desc.batch_size();
+ model_dims->seq_length = input_desc.seq_length();
+ model_dims->hidden_size = rnn_desc.hidden_size();
+ model_dims->input_size = input_desc.data_size();
+ model_dims->dir_count =
+ (rnn_desc.direction_mode() == CUDNN_BIDIRECTIONAL) ? 2 : 1;
+
+ // check parameters
+ if (!(input_h_desc.num_layers() ==
+ model_dims->num_layers * model_dims->dir_count &&
+ input_h_desc.batch_size() == model_dims->batch_size &&
+ input_h_desc.data_size() == model_dims->hidden_size)) {
+ LOG(ERROR) << "Invalid input_h shape";
+ return false;
+ }
+ if (!(input_h_desc.num_layers() == input_c_desc.num_layers() &&
+ input_h_desc.batch_size() == input_c_desc.batch_size() &&
+ input_h_desc.data_size() == input_c_desc.data_size())) {
+ LOG(ERROR) << "Invalid input_c shape";
+ return false;
+ }
+ if (!(output_desc.seq_length() == model_dims->seq_length &&
+ output_desc.batch_size() == model_dims->batch_size &&
+ output_desc.data_size() ==
+ model_dims->hidden_size * model_dims->dir_count)) {
+ LOG(ERROR) << "Invalid output shape";
+ return false;
+ }
+ if (!(input_h_desc.num_layers() == output_h_desc.num_layers() &&
+ input_h_desc.batch_size() == output_h_desc.batch_size() &&
+ input_h_desc.data_size() == output_h_desc.data_size())) {
+ LOG(ERROR) << "Invalid output_h shape";
+ return false;
+ }
+ if (!(input_h_desc.num_layers() == output_c_desc.num_layers() &&
+ input_h_desc.batch_size() == output_c_desc.batch_size() &&
+ input_h_desc.data_size() == output_c_desc.data_size())) {
+    LOG(ERROR) << "Invalid output_c shape";
+ return false;
+ }
+
+ return true;
+}
+
+bool CheckRNNParameterSize(CUDAExecutor* parent, cudnnHandle_t cudnn_handle,
+ const CudnnRnnDescriptor& rnn_desc,
+ const CudnnRnnSequenceTensorDescriptor& input_desc) {
+ size_t params_size_in_bytes = 0;
+ cudnnStatus_t status = dynload::cudnnGetRNNParamsSize(
+ parent, cudnn_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/,
+ input_desc.handles()[0] /*xDesc*/, &params_size_in_bytes /*sizeInBytes*/,
+ rnn_desc.data_type() /*dataType*/);
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(ERROR) << "Unable to check RNN param size: " << ToString(status);
+ return false;
+ }
+ return static_cast<int64>(params_size_in_bytes) ==
+ rnn_desc.ParamsSizeInBytes();
+}
+
+bool CreateRnnWorkspace(Stream* stream, CUDAExecutor* parent,
+ cudnnHandle_t cudnn_handle,
+ const CudnnRnnDescriptor& rnn_desc,
+ const CudnnRnnSequenceTensorDescriptor& input_desc,
+ ScratchAllocator* workspace_allocator,
+ DeviceMemory<uint8>* workspace) {
+ // Query the workspace size.
+ size_t workspace_size_in_bytes = 0;
+ cudnnStatus_t status = dynload::cudnnGetRNNWorkspaceSize(
+ parent, cudnn_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/,
+ input_desc.seq_length() /*seqLength*/, input_desc.handles() /*xDesc*/,
+ &workspace_size_in_bytes /*sizeInBytes*/);
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(ERROR) << "Unable to query workspace size: " << ToString(status);
+ return false;
+ }
+ // Allocate the workspace.
+ if (workspace_size_in_bytes > 0) {
+ auto allocated =
+ workspace_allocator->AllocateBytes(stream, workspace_size_in_bytes);
+ if (!allocated.ok() || (*workspace = allocated.ValueOrDie()) == nullptr) {
+ LOG(ERROR) << "Failed to allocate RNN workspace";
+ return false;
+ }
+ } else {
+ *workspace = DeviceMemory<uint8>();
+ }
+ return true;
+}
+
+} // namespace
+
+template <class T>
+bool CudnnSupport::DoRnnForwardImpl(
+ Stream* stream, const CudnnRnnDescriptor& rnn_desc,
+ const CudnnRnnSequenceTensorDescriptor& input_desc,
+ const DeviceMemory<T>& input_data,
+ const CudnnRnnStateTensorDescriptor& input_h_desc,
+ const DeviceMemory<T>& input_h_data,
+ const CudnnRnnStateTensorDescriptor& input_c_desc,
+ const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
+ const CudnnRnnSequenceTensorDescriptor& output_desc,
+ DeviceMemory<T>* output_data,
+ const CudnnRnnStateTensorDescriptor& output_h_desc,
+ DeviceMemory<T>* output_h_data,
+ const CudnnRnnStateTensorDescriptor& output_c_desc,
+ DeviceMemory<T>* output_c_data, bool is_training,
+ ScratchAllocator* reserve_space_allocator,
+ ScratchAllocator* workspace_allocator) {
+ // extract model parameters
+ RnnModelDims model_dims;
+ bool res = ExtractAndCheckRnnForward(
+ rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
+ input_c_desc, input_c_data, params, output_desc, *output_data,
+ output_h_desc, *output_h_data, output_c_desc, *output_c_data,
+ &model_dims);
+ if (!res) {
+ LOG(ERROR) << "Invalid parameters for RNN Model";
+ return false;
+ }
+
+ // check params size
+ mutex_lock lock{dnn_handle_mutex_};
+
+ if (!CheckRNNParameterSize(parent_, ToHandle(dnn_handle_), rnn_desc,
+ input_desc)) {
+ LOG(ERROR) << "Invalid parameters";
+ return false;
+ }
+
+ // create the workspace
+ DeviceMemory<uint8> workspace;
+ if (!CreateRnnWorkspace(stream, parent_, ToHandle(dnn_handle_), rnn_desc,
+ input_desc, workspace_allocator, &workspace)) {
+ LOG(ERROR) << "Unable to create rnn workspace";
+ return false;
+ }
+
+ // query the reserve space size
+ // allocate the reserve space
+ DeviceMemory<uint8> reserve_space;
+ if (is_training) {
+ size_t reserve_space_size_in_bytes = 0;
+ cudnnStatus_t status = dynload::cudnnGetRNNTrainingReserveSize(
+ parent_, ToHandle(dnn_handle_) /*handle*/,
+ rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/,
+ input_desc.handles() /*xDesc*/,
+ &reserve_space_size_in_bytes /*sizeInBytes*/);
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(ERROR) << "Unable to query reserve space size: " << ToString(status);
+ return false;
+ }
+
+ if (reserve_space_size_in_bytes > 0) {
+ auto allocated = reserve_space_allocator->AllocateBytes(
+ stream, reserve_space_size_in_bytes);
+ if (!allocated.ok() ||
+ (reserve_space = allocated.ValueOrDie()) == nullptr) {
+        LOG(ERROR) << "Failed to allocate RNN reserve space";
+ return false;
+ }
+ }
+ }
+
+ // make the forward call
+ if (!is_training) {
+ cudnnStatus_t status = dynload::cudnnRNNForwardInference(
+ parent_, ToHandle(dnn_handle_) /*handle*/,
+ rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/,
+ input_desc.handles() /*xDesc*/, input_data.opaque() /*x*/,
+ input_h_desc.handle() /*hxDesc*/, input_h_data.opaque() /*hx*/,
+ input_c_desc.handle() /*cxDesc*/, input_c_data.opaque() /*cx*/,
+ rnn_desc.params_handle() /*wDesc*/, params.opaque() /*w*/,
+ output_desc.handles() /*yDesc*/, output_data->opaque() /*y*/,
+ output_h_desc.handle() /*hyDesc*/, output_h_data->opaque() /*hy*/,
+ output_c_desc.handle() /*cyDesc*/, output_c_data->opaque() /*cy*/,
+ workspace.opaque() /*workspace*/,
+ workspace.size() /*workSpaceSizeInBytes*/);
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(ERROR) << "Failed to call cudnnRNNForwardInference: "
+ << ToString(status);
+ return false;
+ }
+ } else {
+ cudnnStatus_t status = dynload::cudnnRNNForwardTraining(
+ parent_, ToHandle(dnn_handle_) /*handle*/,
+ rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/,
+ input_desc.handles() /*xDesc*/, input_data.opaque() /*x*/,
+ input_h_desc.handle() /*hxDesc*/, input_h_data.opaque() /*hx*/,
+ input_c_desc.handle() /*cxDesc*/, input_c_data.opaque() /*cx*/,
+ rnn_desc.params_handle() /*wDesc*/, params.opaque() /*w*/,
+ output_desc.handles() /*yDesc*/, output_data->opaque() /*y*/,
+ output_h_desc.handle() /*hyDesc*/, output_h_data->opaque() /*hy*/,
+ output_c_desc.handle() /*cyDesc*/, output_c_data->opaque() /*cy*/,
+ workspace.opaque() /*workspace*/,
+ workspace.size() /*workSpaceSizeInBytes*/,
+ reserve_space.opaque() /*reserveSpace*/,
+ reserve_space.size() /*reserveSpaceSizeInBytes*/);
+ if (status != CUDNN_STATUS_SUCCESS) {
+      LOG(ERROR) << "Failed to call cudnnRNNForwardTraining: "
+                 << ToString(status);
+ return false;
+ }
+ }
+
+ return true;
+}
+
+template <class T>
+bool CudnnSupport::DoRnnBackwardImpl(
+ Stream* stream, const CudnnRnnDescriptor& rnn_desc,
+ const CudnnRnnSequenceTensorDescriptor& input_desc,
+ const DeviceMemory<T>& input_data,
+ const CudnnRnnStateTensorDescriptor& input_h_desc,
+ const DeviceMemory<T>& input_h_data,
+ const CudnnRnnStateTensorDescriptor& input_c_desc,
+ const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
+ const CudnnRnnSequenceTensorDescriptor& output_desc,
+ const DeviceMemory<T>& output_data,
+ const CudnnRnnStateTensorDescriptor& output_h_desc,
+ const DeviceMemory<T>& output_h_data,
+ const CudnnRnnStateTensorDescriptor& output_c_desc,
+ const DeviceMemory<T>& output_c_data,
+ const DeviceMemory<float>& output_backprop_data,
+ const DeviceMemory<float>& output_h_backprop_data,
+ const DeviceMemory<float>& output_c_backprop_data,
+ DeviceMemory<float>* input_backprop_data,
+ DeviceMemory<float>* input_h_backprop_data,
+ DeviceMemory<float>* input_c_backprop_data,
+ DeviceMemory<float>* params_backprop_data,
+ DeviceMemory<uint8>* reserve_space_data,
+ ScratchAllocator* workspace_allocator) {
+ // extract model parameters
+ RnnModelDims model_dims;
+ bool res = ExtractAndCheckRnnForward(
+ rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
+ input_c_desc, input_c_data, params, output_desc, output_data,
+ output_h_desc, output_h_data, output_c_desc, output_c_data, &model_dims);
+ if (!res) {
+ LOG(ERROR) << "Invalid parameters for RNN Model";
+ return false;
+ }
+
+ // check params size
+ mutex_lock lock{dnn_handle_mutex_};
+
+ if (!CheckRNNParameterSize(parent_, ToHandle(dnn_handle_), rnn_desc,
+ input_desc)) {
+ LOG(ERROR) << "Invalid parameters";
+ return false;
+ }
+
+ // create the workspace
+ DeviceMemory<uint8> workspace;
+ if (!CreateRnnWorkspace(stream, parent_, ToHandle(dnn_handle_), rnn_desc,
+ input_desc, workspace_allocator, &workspace)) {
+ LOG(ERROR) << "Unable to create rnn workspace";
+ return false;
+ }
+
+ // make the backward data call
+ cudnnStatus_t status = dynload::cudnnRNNBackwardData(
+ parent_, ToHandle(dnn_handle_) /*handle*/, rnn_desc.handle() /*rnnDesc*/,
+ model_dims.seq_length /*seqLength*/, output_desc.handles() /*yDesc*/,
+ output_data.opaque() /*y*/, output_desc.handles() /*dyDesc*/,
+ output_backprop_data.opaque() /*dy*/, output_h_desc.handle() /*dhyDesc*/,
+ output_h_backprop_data.opaque() /*dhy*/,
+ output_c_desc.handle() /*dcyDesc*/,
+ output_c_backprop_data.opaque() /*dcy*/,
+ rnn_desc.params_handle() /*wDesc*/, params.opaque() /*w*/,
+ input_h_desc.handle() /*hxDesc*/, input_h_data.opaque() /*hx*/,
+ input_c_desc.handle() /*cxDesc*/, input_c_data.opaque() /*cx*/,
+ input_desc.handles() /*dxDesc*/, input_backprop_data->opaque() /*dx*/,
+ input_h_desc.handle() /*dhxDesc*/,
+ input_h_backprop_data->opaque() /*dhx*/,
+ input_c_desc.handle() /*dcxDesc*/,
+ input_c_backprop_data->opaque() /*dcx*/, workspace.opaque() /*workspace*/,
+ workspace.size() /*workSpaceSizeInBytes*/,
+ reserve_space_data->opaque() /*reserveSpace*/,
+ reserve_space_data->size() /*reserveSpaceSizeInBytes*/);
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(ERROR) << "Failed to call cudnnRNNBackwardData: " << ToString(status);
+ return false;
+ }
+
+ if (params_backprop_data != nullptr) {
+ // Clear the dw to zeros.
+ stream->ThenMemZero(params_backprop_data, params_backprop_data->size());
+ // make the backward weight call
+ status = dynload::cudnnRNNBackwardWeights(
+ parent_, ToHandle(dnn_handle_) /*handle*/,
+ rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/,
+ input_desc.handles() /*xDesc*/, input_data.opaque() /*x*/,
+ input_h_desc.handle() /*hxDesc*/, input_h_data.opaque() /*hx*/,
+ output_desc.handles() /*yDesc*/, output_data.opaque() /*y*/,
+ workspace.opaque() /*workspace*/,
+ workspace.size() /*workSpaceSizeInBytes*/,
+ rnn_desc.params_handle() /*dwDesc*/,
+ params_backprop_data->opaque() /*dw*/,
+ reserve_space_data->opaque() /*reserveSpace*/,
+ reserve_space_data->size() /*reserveSpaceSizeInBytes*/);
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(ERROR) << "Failed to call cudnnRNNBackwardWeights: "
+ << ToString(status);
+ return false;
+ }
+ }
+
+ return true;
+}
+
+#endif // CUDNN_VERSION
+
+port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
+CudnnSupport::createRnnDescriptor(int num_layers, int hidden_size,
+ int input_size, dnn::RnnInputMode input_mode,
+ dnn::RnnDirectionMode direction_mode,
+ dnn::RnnMode rnn_mode,
+ dnn::DataType data_type, float dropout,
+ uint64 seed,
+ ScratchAllocator* state_allocator) {
+#if CUDNN_VERSION >= 5000
+ mutex_lock lock{dnn_handle_mutex_};
+ std::unique_ptr<CudnnRnnDescriptor> rnn_desc(new CudnnRnnDescriptor(
+ parent_, ToHandle(dnn_handle_), num_layers, hidden_size, input_size,
+ ToCudnnRnnInputMode(input_mode), ToCudnnRnnDirectionMode(direction_mode),
+ ToCudnnRnnMode(rnn_mode), ToCudnnDataType(data_type), dropout, seed,
+ state_allocator));
+ if (!rnn_desc->ok()) {
+ return rnn_desc->Status();
+ }
+ return port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>(
+ std::move(rnn_desc));
+#else
+ string error_msg =
+ port::StrCat("createRnnDescriptor needs at least Cudnn 5.0 to work. ",
+ "Current Cudnn version: ", CUDNN_VERSION, ". ");
+ LOG(ERROR) << error_msg;
+ return port::Status{port::error::UNIMPLEMENTED, error_msg};
+#endif // CUDNN_VERSION
+}
+
+port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
+CudnnSupport::createRnnSequenceTensorDescriptor(int seq_length, int batch_size,
+ int data_size,
+ dnn::DataType data_type) {
+#if CUDNN_VERSION >= 5000
+ std::unique_ptr<CudnnRnnSequenceTensorDescriptor> seq_desc(
+ new CudnnRnnSequenceTensorDescriptor(parent_, seq_length, batch_size,
+ data_size,
+ ToCudnnDataType(data_type)));
+ if (!seq_desc->ok()) {
+ return seq_desc->Status();
+ }
+ return port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>(
+ std::move(seq_desc));
+#else
+ string error_msg = port::StrCat(
+ "createRnnSequenceTensorDescriptor needs at least Cudnn 5.0 to work. ",
+ "Current Cudnn version: ", CUDNN_VERSION, ". ");
+ LOG(ERROR) << error_msg;
+ return port::Status{port::error::UNIMPLEMENTED, error_msg};
+#endif // CUDNN_VERSION
+}
+
+port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
+CudnnSupport::createRnnStateTensorDescriptor(int num_layer, int batch_size,
+ int data_size,
+ dnn::DataType data_type) {
+#if CUDNN_VERSION >= 5000
+ std::unique_ptr<CudnnRnnStateTensorDescriptor> state_desc(
+ new CudnnRnnStateTensorDescriptor(parent_, num_layer, batch_size,
+ data_size, ToCudnnDataType(data_type)));
+ if (!state_desc->ok()) {
+ return state_desc->Status();
+ }
+ return port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>(
+ std::move(state_desc));
+#else
+ string error_msg = port::StrCat(
+ "createRnnStateTensorDescriptor needs at least Cudnn 5.0 to work. ",
+ "Current Cudnn version: ", CUDNN_VERSION, ". ");
+ LOG(ERROR) << error_msg;
+ return port::Status{port::error::UNIMPLEMENTED, error_msg};
+#endif // CUDNN_VERSION
+}
+
+bool CudnnSupport::DoRnnForward(
+ Stream* stream, const dnn::RnnDescriptor& rnn_desc,
+ const dnn::RnnSequenceTensorDescriptor& input_desc,
+ const DeviceMemory<float>& input_data,
+ const dnn::RnnStateTensorDescriptor& input_h_desc,
+ const DeviceMemory<float>& input_h_data,
+ const dnn::RnnStateTensorDescriptor& input_c_desc,
+ const DeviceMemory<float>& input_c_data, const DeviceMemory<float>& params,
+ const dnn::RnnSequenceTensorDescriptor& output_desc,
+ DeviceMemory<float>* output_data,
+ const dnn::RnnStateTensorDescriptor& output_h_desc,
+ DeviceMemory<float>* output_h_data,
+ const dnn::RnnStateTensorDescriptor& output_c_desc,
+ DeviceMemory<float>* output_c_data, bool is_training,
+ ScratchAllocator* reserve_space_allocator,
+ ScratchAllocator* workspace_allocator) {
+#if CUDNN_VERSION >= 5000
+ const CudnnRnnDescriptor& cudnn_rnn_desc =
+ static_cast<const CudnnRnnDescriptor&>(rnn_desc);
+ const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
+ static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
+ const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
+ static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
+ const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
+ static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
+ const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
+ static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
+ const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
+ static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
+ const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
+ static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
+
+ return DoRnnForwardImpl<float>(
+ stream, cudnn_rnn_desc, cudnn_input_desc, input_data, cudnn_input_h_desc,
+ input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
+ output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
+ output_c_data, is_training, reserve_space_allocator, workspace_allocator);
+#else
+ return false;
+#endif // CUDNN_VERSION
+}
+
+bool CudnnSupport::DoRnnBackward(
+ Stream* stream, const dnn::RnnDescriptor& rnn_desc,
+ const dnn::RnnSequenceTensorDescriptor& input_desc,
+ const DeviceMemory<float>& input_data,
+ const dnn::RnnStateTensorDescriptor& input_h_desc,
+ const DeviceMemory<float>& input_h_data,
+ const dnn::RnnStateTensorDescriptor& input_c_desc,
+ const DeviceMemory<float>& input_c_data, const DeviceMemory<float>& params,
+ const dnn::RnnSequenceTensorDescriptor& output_desc,
+ const DeviceMemory<float>& output_data,
+ const dnn::RnnStateTensorDescriptor& output_h_desc,
+ const DeviceMemory<float>& output_h_data,
+ const dnn::RnnStateTensorDescriptor& output_c_desc,
+ const DeviceMemory<float>& output_c_data,
+ const DeviceMemory<float>& output_backprop_data,
+ const DeviceMemory<float>& output_h_backprop_data,
+ const DeviceMemory<float>& output_c_backprop_data,
+ DeviceMemory<float>* input_backprop_data,
+ DeviceMemory<float>* input_h_backprop_data,
+ DeviceMemory<float>* input_c_backprop_data,
+ DeviceMemory<float>* params_backprop_data,
+ DeviceMemory<uint8>* reserve_space_data,
+ ScratchAllocator* workspace_allocator) {
+#if CUDNN_VERSION >= 5000
+ const CudnnRnnDescriptor& cudnn_rnn_desc =
+ static_cast<const CudnnRnnDescriptor&>(rnn_desc);
+ const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
+ static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
+ const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
+ static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
+ const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
+ static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
+ const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
+ static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
+ const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
+ static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
+ const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
+ static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
+
+ return DoRnnBackwardImpl<float>(
+ stream, cudnn_rnn_desc, cudnn_input_desc, input_data, cudnn_input_h_desc,
+ input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
+ output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
+ output_c_data, output_backprop_data, output_h_backprop_data,
+ output_c_backprop_data, input_backprop_data, input_h_backprop_data,
+ input_c_backprop_data, params_backprop_data, reserve_space_data,
+ workspace_allocator);
+#else
+ return false;
+#endif // CUDNN_VERSION
+}
+
template <class T>
bool CudnnSupport::DoConvolveImpl(
Stream* stream, int cudnn_type, // Actually cudnnDataType_t.
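
One property worth noting in the cuda_dnn.cc changes above: ToCudnnRnnInputMode, ToCudnnRnnDirectionMode, ToCudnnRnnMode, and ToCudnnDataType convert with a bare static_cast, which is only sound because the dnn:: enumerators are declared with the same numeric values as their cuDNN counterparts. A hedged sketch of compile-time checks that would pin that assumption down (illustrative only; these asserts are not part of the commit):

static_assert(static_cast<int>(dnn::RnnMode::kRnnRelu) == CUDNN_RNN_RELU, "");
static_assert(static_cast<int>(dnn::RnnMode::kRnnTanh) == CUDNN_RNN_TANH, "");
static_assert(static_cast<int>(dnn::RnnMode::kRnnLstm) == CUDNN_LSTM, "");
static_assert(static_cast<int>(dnn::RnnMode::kRnnGru) == CUDNN_GRU, "");
static_assert(static_cast<int>(dnn::RnnDirectionMode::kRnnUnidirectional) ==
                  CUDNN_UNIDIRECTIONAL, "");
static_assert(static_cast<int>(dnn::RnnDirectionMode::kRnnBidirectional) ==
                  CUDNN_BIDIRECTIONAL, "");
static_assert(static_cast<int>(dnn::RnnInputMode::kRnnLinearSkip) ==
                  CUDNN_LINEAR_INPUT, "");
static_assert(static_cast<int>(dnn::RnnInputMode::kRnnSkipInput) ==
                  CUDNN_SKIP_INPUT, "");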
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h
index 9805781fcb..f8bc0c493f 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.h
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.h
@@ -31,6 +31,9 @@ namespace gputools {
namespace cuda {
class CUDAExecutor;
+class CudnnRnnDescriptor;
+class CudnnRnnSequenceTensorDescriptor;
+class CudnnRnnStateTensorDescriptor;
// Opaque and unique identifier for the cuDNN plugin.
extern const PluginId kCuDnnPlugin;
@@ -44,6 +47,62 @@ class CudnnSupport : public dnn::DnnSupport {
port::Status Init() override;
+ port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
+ int num_layers, int hidden_size, int input_size,
+ dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
+ dnn::RnnMode rnn_mode, dnn::DataType data_type, float dropout,
+ uint64 seed, ScratchAllocator* state_allocator) override;
+
+ port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
+ createRnnSequenceTensorDescriptor(int seq_length, int batch_size,
+ int data_size,
+ dnn::DataType data_type) override;
+
+ port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
+ createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size,
+ dnn::DataType data_type) override;
+
+ bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
+ const dnn::RnnSequenceTensorDescriptor& input_desc,
+ const DeviceMemory<float>& input_data,
+ const dnn::RnnStateTensorDescriptor& input_h_desc,
+ const DeviceMemory<float>& input_h_data,
+ const dnn::RnnStateTensorDescriptor& input_c_desc,
+ const DeviceMemory<float>& input_c_data,
+ const DeviceMemory<float>& params,
+ const dnn::RnnSequenceTensorDescriptor& output_desc,
+ DeviceMemory<float>* output_data,
+ const dnn::RnnStateTensorDescriptor& output_h_desc,
+ DeviceMemory<float>* output_h_data,
+ const dnn::RnnStateTensorDescriptor& output_c_desc,
+ DeviceMemory<float>* output_c_data, bool is_training,
+ ScratchAllocator* reserve_space_allocator,
+ ScratchAllocator* workspace_allocator) override;
+
+ bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
+ const dnn::RnnSequenceTensorDescriptor& input_desc,
+ const DeviceMemory<float>& input_data,
+ const dnn::RnnStateTensorDescriptor& input_h_desc,
+ const DeviceMemory<float>& input_h_data,
+ const dnn::RnnStateTensorDescriptor& input_c_desc,
+ const DeviceMemory<float>& input_c_data,
+ const DeviceMemory<float>& params,
+ const dnn::RnnSequenceTensorDescriptor& output_desc,
+ const DeviceMemory<float>& output_data,
+ const dnn::RnnStateTensorDescriptor& output_h_desc,
+ const DeviceMemory<float>& output_h_data,
+ const dnn::RnnStateTensorDescriptor& output_c_desc,
+ const DeviceMemory<float>& output_c_data,
+ const DeviceMemory<float>& output_backprop_data,
+ const DeviceMemory<float>& output_h_backprop_data,
+ const DeviceMemory<float>& output_c_backprop_data,
+ DeviceMemory<float>* input_backprop_data,
+ DeviceMemory<float>* input_h_backprop_data,
+ DeviceMemory<float>* input_c_backprop_data,
+ DeviceMemory<float>* params_backprop_data,
+ DeviceMemory<uint8>* reserve_space_data,
+ ScratchAllocator* workspace_allocator) override;
+
bool GetConvolveAlgorithms(
std::vector<dnn::AlgorithmType>* out_algorithms) override;
@@ -369,6 +428,49 @@ class CudnnSupport : public dnn::DnnSupport {
const dnn::BatchDescriptor& bias_descriptor,
DeviceMemory<T>* backward_bias_data);
+ template <class T>
+ bool DoRnnForwardImpl(Stream* stream, const CudnnRnnDescriptor& rnn_desc,
+ const CudnnRnnSequenceTensorDescriptor& input_desc,
+ const DeviceMemory<T>& input_data,
+ const CudnnRnnStateTensorDescriptor& input_h_desc,
+ const DeviceMemory<T>& input_h_data,
+ const CudnnRnnStateTensorDescriptor& input_c_desc,
+ const DeviceMemory<T>& input_c_data,
+ const DeviceMemory<T>& params,
+ const CudnnRnnSequenceTensorDescriptor& output_desc,
+ DeviceMemory<T>* output_data,
+ const CudnnRnnStateTensorDescriptor& output_h_desc,
+ DeviceMemory<T>* output_h_data,
+ const CudnnRnnStateTensorDescriptor& output_c_desc,
+ DeviceMemory<T>* output_c_data, bool is_training,
+ ScratchAllocator* reserve_space_allocator,
+ ScratchAllocator* workspace_allocator);
+
+ template <class T>
+ bool DoRnnBackwardImpl(Stream* stream, const CudnnRnnDescriptor& rnn_desc,
+ const CudnnRnnSequenceTensorDescriptor& input_desc,
+ const DeviceMemory<T>& input_data,
+ const CudnnRnnStateTensorDescriptor& input_h_desc,
+ const DeviceMemory<T>& input_h_data,
+ const CudnnRnnStateTensorDescriptor& input_c_desc,
+ const DeviceMemory<T>& input_c_data,
+ const DeviceMemory<T>& params,
+ const CudnnRnnSequenceTensorDescriptor& output_desc,
+ const DeviceMemory<T>& output_data,
+ const CudnnRnnStateTensorDescriptor& output_h_desc,
+ const DeviceMemory<T>& output_h_data,
+ const CudnnRnnStateTensorDescriptor& output_c_desc,
+ const DeviceMemory<T>& output_c_data,
+ const DeviceMemory<float>& output_backprop_data,
+ const DeviceMemory<float>& output_h_backprop_data,
+ const DeviceMemory<float>& output_c_backprop_data,
+ DeviceMemory<float>* input_backprop_data,
+ DeviceMemory<float>* input_h_backprop_data,
+ DeviceMemory<float>* input_c_backprop_data,
+ DeviceMemory<float>* params_backprop_data,
+ DeviceMemory<uint8>* reserve_space_data,
+ ScratchAllocator* workspace_allocator);
+
SE_DISALLOW_COPY_AND_ASSIGN(CudnnSupport);
};
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index c2310c8938..1c31178526 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -23,10 +23,12 @@ limitations under the License.
#define TENSORFLOW_STREAM_EXECUTOR_DNN_H_
#include <limits>
+#include <memory>
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/lib/array_slice.h"
#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/platform/port.h"
@@ -76,6 +78,94 @@ enum class QuantizedActivationMode {
k32Bit = 4,
};
+// Specifies the data type used by an operation.
+enum class DataType {
+ kFloat = 0,
+ kDouble = 1,
+ kHalf = 2,
+};
+
+// A helper class to convert C/C++ types to the proper enums.
+template <typename T>
+struct ToDataType;
+template <>
+struct ToDataType<float> {
+ static constexpr DataType value = DataType::kFloat;
+};
+template <>
+struct ToDataType<double> {
+ static constexpr DataType value = DataType::kDouble;
+};
+template <>
+struct ToDataType<Eigen::half> {
+ static constexpr DataType value = DataType::kHalf;
+};
+
+// Specifies the type of an RNN model.
+enum class RnnMode {
+ kRnnRelu = 0,
+ kRnnTanh = 1,
+ kRnnLstm = 2,
+ kRnnGru = 3,
+};
+
+// Specifies the input mode and whether there is a linear transformation
+// between the input state and the first layer hidden state.
+enum class RnnInputMode {
+ kRnnLinearSkip = 0,
+ kRnnSkipInput = 1,
+};
+
+// Specifies the number of directions used in an RNN model. When the
+// bidirectional mode is used, the input states and output sequence contain
+// data for both directions.
+enum class RnnDirectionMode {
+ kRnnUnidirectional = 0,
+ kRnnBidirectional = 1,
+};
+
+// Specifies the descriptor for an RNN model.
+//
+// An example use case:
+//  * The user first creates a model through createRnnDescriptor.
+//  * The user queries the size of the underlying opaque parameter buffer.
+//  * The user creates and initializes a parameter buffer of the proper size.
+//  * The user runs forward and backward operations using this RNN descriptor.
+//  * Once in a while, the user queries the maintainable weight and bias
+//      regions from the underlying parameter buffer. They are more likely to
+//      be forward compatible and should be used in saving and restoring a
+//      model.
+//  * The user releases the RNN descriptor when the model is no longer in use.
+class RnnDescriptor {
+ public:
+ struct ParamsRegion {
+ int64 offset;
+ int64 size;
+ };
+ typedef std::vector<ParamsRegion> ParamsRegions;
+ virtual ~RnnDescriptor() {}
+ virtual int64 ParamsSizeInBytes() const { return -1; }
+ virtual ParamsRegions ParamsWeightRegions() const { return ParamsRegions(); }
+ virtual ParamsRegions ParamsBiasRegions() const { return ParamsRegions(); }
+};
+
+// Specifies the sequence in an RNN model.
+//
+// The user is responsible for releasing this descriptor when it is no longer
+// in use. The destructor releases the underlying descriptors.
+class RnnSequenceTensorDescriptor {
+ public:
+ virtual ~RnnSequenceTensorDescriptor() {}
+};
+
+// Specifies either the input or the hidden state in an RNN model.
+//
+// The user is responsible for releasing this descriptor when it is no longer
+// in use. The destructor releases the underlying descriptors.
+class RnnStateTensorDescriptor {
+ public:
+ virtual ~RnnStateTensorDescriptor() {}
+};
+
// Returns a string representation of the given quantization mode.
string QuantizedActivationModeString(QuantizedActivationMode mode);
@@ -1260,6 +1350,179 @@ class DnnSupport {
QuantizedActivationMode mode,
DeviceMemory<float>* gpu_unquantized_dst) = 0;
+
+ // Create an RNN descriptor based on model shapes and configurations.
+ // The caller retains the ownership of the descriptor.
+ //
+ // Arguments:
+  //  num_layers: the number of layers for an RNN model.
+ // hidden_size: the size of the hidden state.
+ // input_size: the size of the input state.
+ // input_mode: an enum to specify whether a linear transformation is added
+ // after the input state. If input_size is different from hidden_size, this
+ // is required.
+ // direction_mode: an enum to specify whether this model is unidirectional or
+ // bidirectional.
+ // rnn_mode: an enum to specify the type of model to build.
+ // data_type: an enum to specify the data types used in this model.
+  //  dropout: the dropout threshold between layers. When it is 0.0, no
+  //    dropout is added.
+  //  seed: a seed for initializing the dropout layers.
+  //  state_allocator: a memory allocator that will be used to store the
+  //    state for the dropout layer. The user has to maintain the memory
+  //    until the model is no longer in use.
+ virtual port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
+ createRnnDescriptor(int num_layers, int hidden_size, int input_size,
+ dnn::RnnInputMode input_mode,
+ dnn::RnnDirectionMode direction_mode,
+ dnn::RnnMode rnn_mode, dnn::DataType data_type,
+ float dropout, uint64 seed,
+ ScratchAllocator* state_allocator) {
+ return port::Status{port::error::UNIMPLEMENTED,
+ "createRnnDescriptor is unimplemented"};
+ }
+
+  // Create an RNN sequence descriptor that specifies either the input or
+  // output sequence. The caller retains the ownership of the returned
+  // descriptor.
+ //
+ // Arguments:
+ // seq_length: the length of the sequence.
+ // batch_size: the size of a minibatch.
+ // data_size: the size of the state.
+ // data_type: an enum to specify the type for the underlying data.
+ virtual port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
+ createRnnSequenceTensorDescriptor(int seq_length, int batch_size,
+ int data_size, dnn::DataType data_type) {
+ return port::Status{port::error::UNIMPLEMENTED,
+ "createRnnSequenceTensorDescriptor is unimplemented"};
+ }
+
+ // Create an RNN state descriptor that specifies the input or hidden state.
+ // The caller retains the ownership of the returned descriptor.
+ virtual port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
+ createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size,
+ dnn::DataType data_type) {
+ return port::Status{port::error::UNIMPLEMENTED,
+ "createRnnStateTensorDescriptor is unimplemented"};
+ }
+
+ // Enqueue a forward operation of the RNN model onto the stream.
+ //
+ // Arguments:
+ // stream: pointer to the stream where this operation should be enqueued to.
+  //  rnn_desc: an RNN descriptor created by createRnnDescriptor.
+ // input_desc: descriptor for the input sequence.
+ // input_data: the device memory region that contains the input data.
+ // input_h_desc: descriptor for the input "h" state.
+ // input_h_data: the device memory region that contains the input "h" data.
+ // input_c_desc: descriptor for the input "c" state.
+ // input_c_data: the device memory region that contains the input "c" data.
+ // This must be specified for LSTM models.
+ // params: the device memory region that contains the parameters used in this
+ // model.
+ // output_desc: descriptor for the output sequence.
+ // output_data: the memory region that stores the output sequence data.
+ // output_h_desc: descriptor for the output "h" state.
+ // output_h_data: the memory region that stores the output "h" data.
+ // output_c_desc: descriptor for the output "c" state.
+  //  output_c_data: the memory region that stores the output "c" data. This
+  //    must be specified for LSTM models.
+  //  is_training: whether this is used in training or inference. That decides
+  //    whether reserve_space data needs to be produced.
+  //  reserve_space_allocator: if "is_training" is true, a memory allocator
+  //    to create memory that holds the produced reserve_space. The caller
+  //    retains the data and feeds it to the backward pass.
+  //  workspace_allocator: an allocator to create temporary workspace used in
+  //    this kernel. The caller is responsible for retaining the memory long
+  //    enough for the lifespan of this operation, and recycles it afterwards.
+ virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
+ const dnn::RnnSequenceTensorDescriptor& input_desc,
+ const DeviceMemory<float>& input_data,
+ const dnn::RnnStateTensorDescriptor& input_h_desc,
+ const DeviceMemory<float>& input_h_data,
+ const dnn::RnnStateTensorDescriptor& input_c_desc,
+ const DeviceMemory<float>& input_c_data,
+ const DeviceMemory<float>& params,
+ const dnn::RnnSequenceTensorDescriptor& output_desc,
+ DeviceMemory<float>* output_data,
+ const dnn::RnnStateTensorDescriptor& output_h_desc,
+ DeviceMemory<float>* output_h_data,
+ const dnn::RnnStateTensorDescriptor& output_c_desc,
+ DeviceMemory<float>* output_c_data,
+ bool is_training,
+ ScratchAllocator* reserve_space_allocator,
+ ScratchAllocator* workspace_allocator) {
+ return false;
+ }
+
+ // Enqueue a backward operation of the RNN model onto the stream.
+ //
+ // Arguments:
+ // stream: pointer to the stream where this operation should be enqueued to.
+  //  rnn_desc: an RNN descriptor created by createRnnDescriptor.
+ // input_desc: descriptor for the input sequence.
+ // input_data: the device memory region that contains the input data.
+ // input_h_desc: descriptor for the input "h" state.
+ // input_h_data: the device memory region that contains the input "h" data.
+ // input_c_desc: descriptor for the input "c" state.
+ // input_c_data: the device memory region that contains the input "c" data.
+ // This must be specified for LSTM models.
+ // params: the device memory region that contains the parameters used in this
+ // model.
+ // output_desc: descriptor for the output sequence.
+ // output_data: the memory region that stores the output sequence data.
+ // output_h_desc: descriptor for the output "h" state.
+ // output_h_data: the memory region that stores the output "h" data.
+ // output_c_desc: descriptor for the output "c" state.
+  //  output_c_data: the memory region that stores the output "c" data. This
+ // must be specified for LSTM models.
+ // output_backprop_data: the device memory region that contains the backprop
+ // to the output sequence.
+ // output_h_backprop_data: the device memory region that contains the
+ // backprop to the output "h" state.
+ // output_c_backprop_data: the device memory region that contains the
+ // backprop to the output "c" state.
+ // input_backprop_data: the device memory region that stores the backprop
+ // to the input sequence.
+ // input_h_backprop_data: the device memory region that stores the backprop
+ // to the input "h" state.
+ // input_c_backprop_data: the device memory region that stores the backprop
+ // to the input "c" state.
+ // params_backprop_data: the device memory region that stores the backprop
+ // to the parameters.
+ // reserve_space_data: the reserve_space data that is produced by the forward
+ // operation. This memory region could be modified by this operation.
+  //  workspace_allocator: a memory allocator that creates the temporary
+  //    workspace memory used by this operation. The caller is responsible for
+  //    keeping the memory alive long enough for this operation, and recycles
+  //    it afterwards.
+ virtual bool DoRnnBackward(
+ Stream* stream, const dnn::RnnDescriptor& rnn_desc,
+ const dnn::RnnSequenceTensorDescriptor& input_desc,
+ const DeviceMemory<float>& input_data,
+ const dnn::RnnStateTensorDescriptor& input_h_desc,
+ const DeviceMemory<float>& input_h_data,
+ const dnn::RnnStateTensorDescriptor& input_c_desc,
+ const DeviceMemory<float>& input_c_data,
+ const DeviceMemory<float>& params,
+ const dnn::RnnSequenceTensorDescriptor& output_desc,
+ const DeviceMemory<float>& output_data,
+ const dnn::RnnStateTensorDescriptor& output_h_desc,
+ const DeviceMemory<float>& output_h_data,
+ const dnn::RnnStateTensorDescriptor& output_c_desc,
+ const DeviceMemory<float>& output_c_data,
+ const DeviceMemory<float>& output_backprop_data,
+ const DeviceMemory<float>& output_h_backprop_data,
+ const DeviceMemory<float>& output_c_backprop_data,
+ DeviceMemory<float>* input_backprop_data,
+ DeviceMemory<float>* input_h_backprop_data,
+ DeviceMemory<float>* input_c_backprop_data,
+ DeviceMemory<float>* params_backprop_data,
+ DeviceMemory<uint8>* reserve_space_data,
+ ScratchAllocator* workspace_allocator) {
+ return false;
+ }
+
private:
SE_DISALLOW_COPY_AND_ASSIGN(DnnSupport);
};
@@ -1269,3 +1532,4 @@ class DnnSupport {
} // namespace perftools
#endif // TENSORFLOW_STREAM_EXECUTOR_DNN_H_
+
diff --git a/tensorflow/stream_executor/platform/port.h b/tensorflow/stream_executor/platform/port.h
index 3c1b40b016..6603df4878 100644
--- a/tensorflow/stream_executor/platform/port.h
+++ b/tensorflow/stream_executor/platform/port.h
@@ -55,5 +55,7 @@ using tensorflow::LINKER_INITIALIZED;
#define SE_DISALLOW_COPY_AND_ASSIGN TF_DISALLOW_COPY_AND_ASSIGN
#define SE_MUST_USE_RESULT TF_MUST_USE_RESULT
+#define SE_PREDICT_TRUE TF_PREDICT_TRUE
+#define SE_PREDICT_FALSE TF_PREDICT_FALSE
#endif // TENSORFLOW_STREAM_EXECUTOR_PLATFORM_PORT_H_
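The two new aliases forward to tensorflow's TF_PREDICT_TRUE/TF_PREDICT_FALSE, which on GCC and Clang wrap the __builtin_expect branch-prediction hint. A made-up usage sketch (the function and message are illustrative, not code from this change):

    // Mark the failure branch as unlikely so the compiler lays out the hot
    // path fall-through first.
    bool CheckEnqueued(bool enqueue_ok) {
      if (SE_PREDICT_FALSE(!enqueue_ok)) {
        LOG(ERROR) << "enqueue failed";  // Cold path.
        return false;
      }
      return true;  // Hot path.
    }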
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index 92ceb5bbb2..8c0e45f1a6 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -3774,6 +3774,79 @@ Stream &Stream::ThenMemset32(DeviceMemoryBase *location, const uint32 &pattern,
return *this;
}
+Stream &Stream::ThenRnnForward(
+ const dnn::RnnDescriptor &rnn_desc,
+ const dnn::RnnSequenceTensorDescriptor &input_desc,
+ const DeviceMemory<float> &input_data,
+ const dnn::RnnStateTensorDescriptor &input_h_desc,
+ const DeviceMemory<float> &input_h_data,
+ const dnn::RnnStateTensorDescriptor &input_c_desc,
+ const DeviceMemory<float> &input_c_data, const DeviceMemory<float> &params,
+ const dnn::RnnSequenceTensorDescriptor &output_desc,
+ DeviceMemory<float> *output_data,
+ const dnn::RnnStateTensorDescriptor &output_h_desc,
+ DeviceMemory<float> *output_h_data,
+ const dnn::RnnStateTensorDescriptor &output_c_desc,
+ DeviceMemory<float> *output_c_data, bool is_training,
+ ScratchAllocator *reserve_space_allocator,
+ ScratchAllocator *workspace_allocator) {
+ // TODO(zhengxq): add VLOG PARAM calls.
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ CheckError(dnn->DoRnnForward(
+ this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
+ input_c_desc, input_c_data, params, output_desc, output_data,
+ output_h_desc, output_h_data, output_c_desc, output_c_data,
+ is_training, reserve_space_allocator, workspace_allocator));
+ } else {
+ SetError();
+ LOG(WARNING) << "Attempting to call ThenRnnForward without DNN support";
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenRnnBackward(
+ const dnn::RnnDescriptor &rnn_desc,
+ const dnn::RnnSequenceTensorDescriptor &input_desc,
+ const DeviceMemory<float> &input_data,
+ const dnn::RnnStateTensorDescriptor &input_h_desc,
+ const DeviceMemory<float> &input_h_data,
+ const dnn::RnnStateTensorDescriptor &input_c_desc,
+ const DeviceMemory<float> &input_c_data, const DeviceMemory<float> &params,
+ const dnn::RnnSequenceTensorDescriptor &output_desc,
+ const DeviceMemory<float> &output_data,
+ const dnn::RnnStateTensorDescriptor &output_h_desc,
+ const DeviceMemory<float> &output_h_data,
+ const dnn::RnnStateTensorDescriptor &output_c_desc,
+ const DeviceMemory<float> &output_c_data,
+ const DeviceMemory<float> &output_backprop_data,
+ const DeviceMemory<float> &output_h_backprop_data,
+ const DeviceMemory<float> &output_c_backprop_data,
+ DeviceMemory<float> *input_backprop_data,
+ DeviceMemory<float> *input_h_backprop_data,
+ DeviceMemory<float> *input_c_backprop_data,
+ DeviceMemory<float> *params_backprop_data,
+ DeviceMemory<uint8> *reserve_space_data,
+ ScratchAllocator *workspace_allocator) {
+ // TODO(zhengxq): add VLOG PARAM calls.
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ CheckError(dnn->DoRnnBackward(
+ this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
+ input_c_desc, input_c_data, params, output_desc, output_data,
+ output_h_desc, output_h_data, output_c_desc, output_c_data,
+ output_backprop_data, output_h_backprop_data, output_c_backprop_data,
+ input_backprop_data, input_h_backprop_data, input_c_backprop_data,
+ params_backprop_data, reserve_space_data, workspace_allocator));
+ } else {
+ SetError();
+ LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
+ }
+ }
+ return *this;
+}
+
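Because both entry points return *this and latch failures through CheckError()/SetError(), a caller can chain the two passes and test ok() once. A hedged sketch of one training step follows; every variable in it (descriptors, device buffers, allocators) is assumed to come from earlier setup, and the backward pass consumes the reserve space that the forward pass produced through reserve_space_allocator:

    stream
        .ThenRnnForward(rnn_desc, input_desc, input_data, input_h_desc,
                        input_h_data, input_c_desc, input_c_data, params,
                        output_desc, &output_data, output_h_desc,
                        &output_h_data, output_c_desc, &output_c_data,
                        /*is_training=*/true, &reserve_space_allocator,
                        &workspace_allocator)
        .ThenRnnBackward(rnn_desc, input_desc, input_data, input_h_desc,
                         input_h_data, input_c_desc, input_c_data, params,
                         output_desc, output_data, output_h_desc,
                         output_h_data, output_c_desc, output_c_data,
                         output_backprop_data, output_h_backprop_data,
                         output_c_backprop_data, &input_backprop_data,
                         &input_h_backprop_data, &input_c_backprop_data,
                         &params_backprop_data, &reserve_space_data,
                         &workspace_allocator);
    if (!stream.ok()) {
      LOG(ERROR) << "RNN training step failed to enqueue";
    }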
Stream &Stream::ThenDoHostCallbackForTest(std::function<void()> callback) {
VLOG_CALL(PARAM(callback));
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index 4d514804e5..61058528c2 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -1383,6 +1383,51 @@ class Stream {
Stream &ThenMemset32(DeviceMemoryBase *location, const uint32 &pattern,
uint64 size);
+ // Enqueue a forward operation of the RNN model onto the stream.
+ // See DnnSupport::DoRnnForward for more details.
+ Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
+ const dnn::RnnSequenceTensorDescriptor &input_desc,
+ const DeviceMemory<float> &input_data,
+ const dnn::RnnStateTensorDescriptor &input_h_desc,
+ const DeviceMemory<float> &input_h_data,
+ const dnn::RnnStateTensorDescriptor &input_c_desc,
+ const DeviceMemory<float> &input_c_data,
+ const DeviceMemory<float> &params,
+ const dnn::RnnSequenceTensorDescriptor &output_desc,
+ DeviceMemory<float> *output_data,
+ const dnn::RnnStateTensorDescriptor &output_h_desc,
+ DeviceMemory<float> *output_h_data,
+ const dnn::RnnStateTensorDescriptor &output_c_desc,
+ DeviceMemory<float> *output_c_data, bool is_training,
+ ScratchAllocator *reserve_space_allocator,
+ ScratchAllocator *workspace_allocator);
+
+ // Enqueue a backward operation of the RNN model onto the stream.
+ // See DnnSupport::DoRnnBackward for more details.
+ Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc,
+ const dnn::RnnSequenceTensorDescriptor &input_desc,
+ const DeviceMemory<float> &input_data,
+ const dnn::RnnStateTensorDescriptor &input_h_desc,
+ const DeviceMemory<float> &input_h_data,
+ const dnn::RnnStateTensorDescriptor &input_c_desc,
+ const DeviceMemory<float> &input_c_data,
+ const DeviceMemory<float> &params,
+ const dnn::RnnSequenceTensorDescriptor &output_desc,
+ const DeviceMemory<float> &output_data,
+ const dnn::RnnStateTensorDescriptor &output_h_desc,
+ const DeviceMemory<float> &output_h_data,
+ const dnn::RnnStateTensorDescriptor &output_c_desc,
+ const DeviceMemory<float> &output_c_data,
+ const DeviceMemory<float> &output_backprop_data,
+ const DeviceMemory<float> &output_h_backprop_data,
+ const DeviceMemory<float> &output_c_backprop_data,
+ DeviceMemory<float> *input_backprop_data,
+ DeviceMemory<float> *input_h_backprop_data,
+ DeviceMemory<float> *input_c_backprop_data,
+ DeviceMemory<float> *params_backprop_data,
+ DeviceMemory<uint8> *reserve_space_data,
+ ScratchAllocator *workspace_allocator);
+
// (Synchronously) block the host code waiting for the operations
// entrained on the stream (enqueued to this point in program
// execution) to complete.
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index 07dc375ef4..2fdd1e4b49 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -309,6 +309,48 @@ bool StreamExecutor::GetConvolveBackwardFilterAlgorithms(
return dnn_support->GetConvolveBackwardFilterAlgorithms(out_algorithms);
}
+port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
+StreamExecutor::createRnnDescriptor(
+ int num_layers, int hidden_size, int input_size,
+ dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
+ dnn::RnnMode rnn_mode, dnn::DataType data_type, float dropout, uint64 seed,
+ ScratchAllocator *state_allocator) {
+ dnn::DnnSupport *dnn_support = AsDnn();
+ if (!dnn_support) {
+ return port::Status(port::error::UNKNOWN,
+ "Fail to find the dnn implementation.");
+ }
+ return dnn_support->createRnnDescriptor(
+ num_layers, hidden_size, input_size, input_mode, direction_mode, rnn_mode,
+ data_type, dropout, seed, state_allocator);
+}
+
+port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
+StreamExecutor::createRnnSequenceTensorDescriptor(int seq_length,
+ int batch_size, int data_size,
+ dnn::DataType data_type) {
+ dnn::DnnSupport *dnn_support = AsDnn();
+ if (!dnn_support) {
+ return port::Status(port::error::UNKNOWN,
+ "Fail to find the dnn implementation.");
+ }
+ return dnn_support->createRnnSequenceTensorDescriptor(seq_length, batch_size,
+ data_size, data_type);
+}
+
+port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
+StreamExecutor::createRnnStateTensorDescriptor(int num_layer, int batch_size,
+ int data_size,
+ dnn::DataType data_type) {
+ dnn::DnnSupport *dnn_support = AsDnn();
+ if (!dnn_support) {
+ return port::Status(port::error::UNKNOWN,
+ "Fail to find the dnn implementation.");
+ }
+ return dnn_support->createRnnStateTensorDescriptor(num_layer, batch_size,
+ data_size, data_type);
+}
+
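All three factories follow the same pattern: look up the DnnSupport implementation, or surface its absence through port::Status. On the caller side the port::StatusOr results unwrap as in the following hedged sketch; the shapes are made up, and executor and state_allocator are assumed to exist from earlier setup:

    auto rnn_desc_status = executor->createRnnDescriptor(
        /*num_layers=*/2, /*hidden_size=*/512, /*input_size=*/256,
        dnn::RnnInputMode::kRnnLinearSkip,
        dnn::RnnDirectionMode::kRnnUnidirectional, dnn::RnnMode::kRnnLstm,
        dnn::DataType::kFloat, /*dropout=*/0.0f, /*seed=*/0,
        &state_allocator);
    auto input_desc_status = executor->createRnnSequenceTensorDescriptor(
        /*seq_length=*/20, /*batch_size=*/32, /*data_size=*/256,
        dnn::DataType::kFloat);
    if (!rnn_desc_status.ok() || !input_desc_status.ok()) {
      LOG(FATAL) << "failed to create RNN descriptors";
    }
    std::unique_ptr<dnn::RnnDescriptor> rnn_desc =
        rnn_desc_status.ConsumeValueOrDie();
    std::unique_ptr<dnn::RnnSequenceTensorDescriptor> input_desc =
        input_desc_status.ConsumeValueOrDie();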
dnn::DnnSupport *StreamExecutor::AsDnn() {
mutex_lock lock{mu_};
if (dnn_ != nullptr) {
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index 5f1b8bda02..2b5a70f807 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -350,6 +350,26 @@ class StreamExecutor {
bool GetConvolveBackwardFilterAlgorithms(
std::vector<dnn::AlgorithmType> *out_algorithms);
+ // Create an RNN descriptor based on model shapes and configurations.
+ // The caller takes ownership of the returned descriptor.
+ port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
+ int num_layers, int hidden_size, int input_size,
+ dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
+ dnn::RnnMode rnn_mode, dnn::DataType data_type, float dropout,
+ uint64 seed, ScratchAllocator *state_allocator);
+
+ // Create an RNN sequence descriptor that specifies either the input or
+ // output sequence. The caller takes ownership of the returned descriptor.
+ port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
+ createRnnSequenceTensorDescriptor(int seq_length, int batch_size,
+ int data_size, dnn::DataType data_type);
+
+ // Create an RNN state descriptor that specifies the input or hidden state.
+ // The caller takes ownership of the returned descriptor.
+ port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
+ createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size,
+ dnn::DataType data_type);
+
// Returns the device ordinal that this StreamExecutor was initialized with.
// Meaningless before initialization.
int device_ordinal() const { return device_ordinal_; }