author     2016-08-19 10:01:32 -0800
committer  2016-08-19 11:18:09 -0700
commit     859e47fd8ebcdd7fb1411fd0090d0e95a801e7cb (patch)
tree       ae0c4e4e690d53b92720049fb75acc119e4ea5e0
parent     de8838042fb34e53a511f25b4613611fc368beeb (diff)
Add stream-executor changes to enable Cudnn fused LSTM/RNN support.
Change: 130770287
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_dnn.cc          | 1047
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_dnn.h           |  102
-rw-r--r--  tensorflow/stream_executor/dnn.h                     |  264
-rw-r--r--  tensorflow/stream_executor/platform/port.h           |    2
-rw-r--r--  tensorflow/stream_executor/stream.cc                 |   73
-rw-r--r--  tensorflow/stream_executor/stream.h                  |   45
-rw-r--r--  tensorflow/stream_executor/stream_executor_pimpl.cc  |   42
-rw-r--r--  tensorflow/stream_executor/stream_executor_pimpl.h   |   20
8 files changed, 1594 insertions(+), 1 deletion(-)
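
To orient the diff below, here is a minimal, hypothetical caller-side sketch (not part of this commit) of how the new RNN API is meant to be driven end to end: build the descriptors through DnnSupport, then enqueue the work with Stream::ThenRnnForward. The function name RunLstmForward, the concrete sizes, and the ConsumeValueOrDie() accessor are illustrative assumptions; buffer allocation and error handling are elided.

// Hypothetical usage sketch, not part of this commit. Assumes a Stream whose
// executor provides cuDNN >= 5 DnnSupport, caller-provided scratch
// allocators, and pre-filled device buffers of the right sizes.
#include <memory>
#include "tensorflow/stream_executor/dnn.h"
#include "tensorflow/stream_executor/stream.h"

namespace gpu = perftools::gputools;

bool RunLstmForward(gpu::Stream* stream, gpu::dnn::DnnSupport* dnn,
                    gpu::ScratchAllocator* state_allocator,
                    gpu::ScratchAllocator* workspace_allocator,
                    const gpu::DeviceMemory<float>& input,
                    const gpu::DeviceMemory<float>& input_h,
                    const gpu::DeviceMemory<float>& input_c,
                    const gpu::DeviceMemory<float>& params,
                    gpu::DeviceMemory<float>* output,
                    gpu::DeviceMemory<float>* output_h,
                    gpu::DeviceMemory<float>* output_c) {
  const int num_layers = 2, hidden_size = 128, input_size = 64;
  const int seq_length = 20, batch_size = 32;

  // One descriptor per model; it owns the dropout state and sizes the opaque
  // params buffer (RnnDescriptor::ParamsSizeInBytes).
  auto rnn_desc_s = dnn->createRnnDescriptor(
      num_layers, hidden_size, input_size,
      gpu::dnn::RnnInputMode::kRnnLinearSkip,
      gpu::dnn::RnnDirectionMode::kRnnUnidirectional,
      gpu::dnn::RnnMode::kRnnLstm, gpu::dnn::DataType::kFloat,
      /*dropout=*/0.f, /*seed=*/0, state_allocator);
  auto input_desc_s = dnn->createRnnSequenceTensorDescriptor(
      seq_length, batch_size, input_size, gpu::dnn::DataType::kFloat);
  auto output_desc_s = dnn->createRnnSequenceTensorDescriptor(
      seq_length, batch_size, hidden_size, gpu::dnn::DataType::kFloat);
  // Unidirectional model: the h/c state is [num_layers, batch, hidden].
  auto state_desc_s = dnn->createRnnStateTensorDescriptor(
      num_layers, batch_size, hidden_size, gpu::dnn::DataType::kFloat);
  if (!rnn_desc_s.ok() || !input_desc_s.ok() || !output_desc_s.ok() ||
      !state_desc_s.ok()) {
    return false;
  }
  // ConsumeValueOrDie() is an assumed StatusOr accessor.
  auto rnn_desc = rnn_desc_s.ConsumeValueOrDie();
  auto input_desc = input_desc_s.ConsumeValueOrDie();
  auto output_desc = output_desc_s.ConsumeValueOrDie();
  auto state_desc = state_desc_s.ConsumeValueOrDie();

  // Inference: no reserve space is produced, so that allocator may be null.
  stream->ThenRnnForward(*rnn_desc, *input_desc, input,
                         *state_desc, input_h, *state_desc, input_c, params,
                         *output_desc, output, *state_desc, output_h,
                         *state_desc, output_c, /*is_training=*/false,
                         /*reserve_space_allocator=*/nullptr,
                         workspace_allocator);
  return stream->ok();
}

For training, passing is_training=true plus a reserve_space_allocator produces the reserve space that ThenRnnBackward later consumes, as the DoRnnForwardImpl body in the diff below shows.
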
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index b042dda29f..7fbafa3d7e 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -238,7 +238,25 @@ CUDNN_DNN_ROUTINE_EACH_R3(PERFTOOLS_GPUTOOLS_CUDNN_WRAP) __macro(cudnnCreateActivationDescriptor) \ __macro(cudnnSetActivationDescriptor) \ __macro(cudnnGetActivationDescriptor) \ - __macro(cudnnDestroyActivationDescriptor) + __macro(cudnnDestroyActivationDescriptor) \ + __macro(cudnnCreateDropoutDescriptor) \ + __macro(cudnnDestroyDropoutDescriptor) \ + __macro(cudnnSetDropoutDescriptor) \ + __macro(cudnnDropoutGetStatesSize) \ + __macro(cudnnCreateRNNDescriptor) \ + __macro(cudnnDestroyRNNDescriptor) \ + __macro(cudnnGetRNNParamsSize) \ + __macro(cudnnGetRNNWorkspaceSize) \ + __macro(cudnnGetRNNTrainingReserveSize) \ + __macro(cudnnGetRNNLinLayerMatrixParams) \ + __macro(cudnnGetRNNLinLayerBiasParams) \ + __macro(cudnnRNNForwardInference) \ + __macro(cudnnRNNForwardTraining) \ + __macro(cudnnRNNBackwardData) \ + __macro(cudnnRNNBackwardWeights) \ + __macro(cudnnSetRNNDescriptor) \ + __macro(cudnnGetFilterNdDescriptor) + // clang-format on CUDNN_DNN_ROUTINE_EACH_R5(PERFTOOLS_GPUTOOLS_CUDNN_WRAP) @@ -759,6 +777,1033 @@ class ScopedActivationDescriptor { }; #endif +namespace { + +#if CUDNN_VERSION >= 5000 + +cudnnRNNInputMode_t ToCudnnRnnInputMode(dnn::RnnInputMode input_mode) { + switch (input_mode) { + case dnn::RnnInputMode::kRnnLinearSkip: + case dnn::RnnInputMode::kRnnSkipInput: + return static_cast<cudnnRNNInputMode_t>(input_mode); + default: + LOG(FATAL) << "Invalid RNN input mode: " << static_cast<int>(input_mode); + } +} + +cudnnDirectionMode_t ToCudnnRnnDirectionMode( + dnn::RnnDirectionMode direction_mode) { + switch (direction_mode) { + case dnn::RnnDirectionMode::kRnnUnidirectional: + case dnn::RnnDirectionMode::kRnnBidirectional: + return static_cast<cudnnDirectionMode_t>(direction_mode); + default: + LOG(FATAL) << "Invalid RNN direction mode: " + << static_cast<int>(direction_mode); + } +} + +cudnnRNNMode_t ToCudnnRnnMode(dnn::RnnMode rnn_mode) { + switch (rnn_mode) { + case dnn::RnnMode::kRnnRelu: + case dnn::RnnMode::kRnnTanh: + case dnn::RnnMode::kRnnLstm: + case dnn::RnnMode::kRnnGru: + return static_cast<cudnnRNNMode_t>(rnn_mode); + default: + LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode); + } +} + +cudnnDataType_t ToCudnnDataType(dnn::DataType data_type) { + switch (data_type) { + case dnn::DataType::kFloat: + case dnn::DataType::kDouble: + case dnn::DataType::kHalf: + return static_cast<cudnnDataType_t>(data_type); + default: + LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type); + } +} + +int CudnnDataTypeToByteSize(cudnnDataType_t data_type) { + switch (data_type) { + case CUDNN_DATA_FLOAT: + return sizeof(float); + case CUDNN_DATA_DOUBLE: + return sizeof(double); + case CUDNN_DATA_HALF: + return sizeof(Eigen::half); + default: + LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type); + } +} + +#endif // CUDNN_VERSION + +template <typename Base> +class MixinBase : public Base {}; +template <> +class MixinBase<void> {}; + +} // namespace + +#if CUDNN_VERSION >= 5000 + +#define CUDNN_RETURN_IF_FAIL(STATUS, ...) 
\ + if (!SE_PREDICT_TRUE((STATUS) == CUDNN_STATUS_SUCCESS)) { \ + string error_msg = port::StrCat(ToString(STATUS), " ", __VA_ARGS__); \ + SetFailure(port::Status(port::error::UNKNOWN, error_msg)); \ + LOG(ERROR) << error_msg; \ + return; \ + } + +template <typename Base> +class CudnnDescriptorCommon : public MixinBase<Base> { + public: + bool ok() const { return status_.ok(); } + port::Status Status() const { return status_; } + + protected: + void SetFailure(const port::Status& status) { status_.Update(status); } + port::Status status_; +}; + +class CudnnDropoutDescriptor : public CudnnDescriptorCommon<void> { + public: + CudnnDropoutDescriptor(CUDAExecutor* parent, cudnnHandle_t cudnn_handle, + float dropout, uint64 seed, + ScratchAllocator* state_allocator) + : parent_(parent), handle_(nullptr) { + cudnnStatus_t status; + status = dynload::cudnnCreateDropoutDescriptor(parent_, &handle_); + CUDNN_RETURN_IF_FAIL(status, "Failed to create dropout descriptor"); + + if (dropout == 0.f) { + return; + } + + DeviceMemory<uint8> state_memory; + if (state_allocator) { + size_t state_sizes_in_bytes = 0; + status = dynload::cudnnDropoutGetStatesSize(parent_, cudnn_handle, + &state_sizes_in_bytes); + CUDNN_RETURN_IF_FAIL(status, "Failed to query dropout state sizes"); + + auto allocated = + state_allocator->AllocateBytes(nullptr, state_sizes_in_bytes); + if (!allocated.ok() || + (state_memory = allocated.ValueOrDie()) == nullptr) { + string error_msg = + port::StrCat("Fail to allocate Cudnn dropout state memory"); + status_ = port::Status(port::error::UNKNOWN, error_msg); + LOG(ERROR) << error_msg; + return; + } + } + status = dynload::cudnnSetDropoutDescriptor(parent_, handle_, cudnn_handle, + dropout, state_memory.opaque(), + state_memory.size(), seed); + CUDNN_RETURN_IF_FAIL(status, "Failed to set dropout descriptor"); + } + + ~CudnnDropoutDescriptor() { + if (handle_) { + cudnnStatus_t status = + dynload::cudnnDestroyDropoutDescriptor(parent_, handle_); + CUDNN_RETURN_IF_FAIL(status, "Failed to destroy Cudnn dropout handle: "); + } + } + + cudnnDropoutDescriptor_t handle() const { + if (!ok()) return nullptr; + return handle_; + } + + private: + CUDAExecutor* parent_; + cudnnDropoutDescriptor_t handle_; + float dropout_; + uint64 seed_; + port::Status status_; + SE_DISALLOW_COPY_AND_ASSIGN(CudnnDropoutDescriptor); +}; + +class CudnnRnnParamsDescriptor : public CudnnDescriptorCommon<void> { + public: + typedef dnn::RnnDescriptor::ParamsRegion ParamsRegion; + typedef dnn::RnnDescriptor::ParamsRegions ParamsRegions; + CudnnRnnParamsDescriptor(CUDAExecutor* parent, cudnnHandle_t cudnn_handle, + const CudnnRnnDescriptor& rnn_desc); + ~CudnnRnnParamsDescriptor() { + cudnnStatus_t status = + dynload::cudnnDestroyFilterDescriptor(parent_, handle_); + CUDNN_RETURN_IF_FAIL(status, "Failed to destroy RNN filter desciptor"); + } + cudnnFilterDescriptor_t handle() const { + if (!ok()) return nullptr; + return handle_; + } + int64 params_size_in_bytes() const { return params_size_in_bytes_; } + ParamsRegions params_weights() const { + if (!ok()) return ParamsRegions(); + return weights_; + } + ParamsRegions params_biases() const { + if (!ok()) return ParamsRegions(); + return biases_; + } + + private: + int GetRegionCountPerLayer() const; + CUDAExecutor* parent_; + cudnnFilterDescriptor_t handle_; + const CudnnRnnDescriptor* rnn_desc_; + int64 params_size_in_bytes_; + ParamsRegions weights_; + ParamsRegions biases_; + port::Status status_; + SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnParamsDescriptor); +}; + +class 
CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> { + public: + CudnnRnnDescriptor(CUDAExecutor* parent, cudnnHandle_t cudnn_handle, + int num_layers, int hidden_size, int input_size, + cudnnRNNInputMode_t input_mode, + cudnnDirectionMode_t direction_mode, + cudnnRNNMode_t rnn_mode, cudnnDataType_t data_type, + float dropout, uint64 seed, + ScratchAllocator* state_allocator) + : parent_(parent), + rnn_desc_(nullptr), + num_layers_(num_layers), + hidden_size_(hidden_size), + input_size_(input_size), + input_mode_(input_mode), + direction_mode_(direction_mode), + rnn_mode_(rnn_mode), + data_type_(data_type) { + // Create the dropout handle. + cudnn_dropout_desc_.reset(new CudnnDropoutDescriptor( + parent, cudnn_handle, dropout, seed, state_allocator)); + if (!cudnn_dropout_desc_->ok()) { + SetFailure(cudnn_dropout_desc_->Status()); + return; + } + + // Create the RNN handle + cudnnStatus_t status = + dynload::cudnnCreateRNNDescriptor(parent_, &rnn_desc_); + CUDNN_RETURN_IF_FAIL(status, "Unable to create RNN descriptor"); + status = dynload::cudnnSetRNNDescriptor( + parent, rnn_desc_ /*rnnDesc*/, hidden_size /*hiddenSize*/, + num_layers /*numLayers*/, dropout_handle() /*dropoutDesc*/, + input_mode /*inputMode*/, direction_mode /*direction*/, + rnn_mode /*mode*/, data_type /*dataType*/); + CUDNN_RETURN_IF_FAIL(status, "Unable to update RNN descriptor"); + + // Create the params handle. + cudnn_params_desc_.reset( + new CudnnRnnParamsDescriptor(parent, cudnn_handle, *this)); + if (!cudnn_params_desc_->ok()) { + SetFailure(cudnn_params_desc_->Status()); + return; + } + } + ~CudnnRnnDescriptor() override { + if (rnn_desc_) { + cudnnStatus_t status = + dynload::cudnnDestroyRNNDescriptor(parent_, rnn_desc_); + CUDNN_RETURN_IF_FAIL(status, "Unable to destroy RNN descriptor"); + } + } + cudnnRNNDescriptor_t handle() const { + if (!ok()) return nullptr; + return rnn_desc_; + } + int num_layers() const { return num_layers_; } + int hidden_size() const { return hidden_size_; } + int input_size() const { return input_size_; } + cudnnRNNInputMode_t input_mode() const { return input_mode_; } + cudnnDirectionMode_t direction_mode() const { return direction_mode_; } + cudnnRNNMode_t rnn_mode() const { return rnn_mode_; } + cudnnDataType_t data_type() const { return data_type_; } + int64 ParamsSizeInBytes() const override { + return cudnn_params_desc_->params_size_in_bytes(); + } + cudnnDropoutDescriptor_t dropout_handle() const { + if (!cudnn_dropout_desc_) return nullptr; + return cudnn_dropout_desc_->handle(); + } + cudnnFilterDescriptor_t params_handle() const { + if (!cudnn_params_desc_) return nullptr; + return cudnn_params_desc_->handle(); + } + ParamsRegions ParamsWeightRegions() const override { + if (!ok()) return ParamsRegions(); + return cudnn_params_desc_->params_weights(); + } + ParamsRegions ParamsBiasRegions() const override { + if (!ok()) return ParamsRegions(); + return cudnn_params_desc_->params_biases(); + } + + private: + CUDAExecutor* parent_; + cudnnRNNDescriptor_t rnn_desc_; + int num_layers_; + int hidden_size_; + int input_size_; + cudnnRNNInputMode_t input_mode_; + cudnnDirectionMode_t direction_mode_; + cudnnRNNMode_t rnn_mode_; + cudnnDataType_t data_type_; + port::Status status_; + std::unique_ptr<CudnnDropoutDescriptor> cudnn_dropout_desc_; + std::unique_ptr<CudnnRnnParamsDescriptor> cudnn_params_desc_; + SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnDescriptor); +}; + +CudnnRnnParamsDescriptor::CudnnRnnParamsDescriptor( + CUDAExecutor* parent, cudnnHandle_t 
cudnn_handle, + const CudnnRnnDescriptor& rnn_desc) + : parent_(parent), + handle_(nullptr), + rnn_desc_(&rnn_desc), + params_size_in_bytes_(0) { + cudnnTensorDescriptor_t input_desc = nullptr; + { + // Query the params size. + auto status = dynload::cudnnCreateTensorDescriptor(parent, &input_desc); + CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to create tensor descriptor"); + int dims[] = {1, rnn_desc.input_size(), 1}; + int strides[] = {dims[1] * dims[2], dims[2], 1}; + status = dynload::cudnnSetTensorNdDescriptor( + parent, input_desc /*tensorDesc*/, rnn_desc.data_type() /*dataType*/, + sizeof(dims) / sizeof(dims[0]) /*nbDims*/, dims /*dimA*/, + strides /*strideA*/); + CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to set tensor descriptor"); + + size_t params_size = 0; + status = dynload::cudnnGetRNNParamsSize( + parent, cudnn_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/, + input_desc /*xDesc*/, ¶ms_size /*sizeInBytes*/, + rnn_desc.data_type() /*dataType*/); + CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to get RNN parameter size"); + params_size_in_bytes_ = static_cast<int64>(params_size); + } + + { + // Create the params descriptor. + auto status = dynload::cudnnCreateFilterDescriptor(parent, &handle_); + CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to create RNN filter descriptor"); + int dims[] = {static_cast<int>(params_size_in_bytes_), 1, 1}; + status = dynload::cudnnSetFilterNdDescriptor( + parent, handle_ /*filterDesc*/, rnn_desc.data_type() /*dataType*/, + CUDNN_TENSOR_NCHW /*format*/, sizeof(dims) / sizeof(dims[0]) /*nbDims*/, + dims /*filterDimA*/); + CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to update RNN filter descriptor"); + } + + { + // Create the weights and biases into the params buffer + int region_count_per_layer = GetRegionCountPerLayer(); + cudnnFilterDescriptor_t region_desc_handle = nullptr; + auto status = + dynload::cudnnCreateFilterDescriptor(parent, ®ion_desc_handle); + CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to create filter descriptor"); + for (int layer = 0; layer < rnn_desc.num_layers(); layer++) { + for (int region = 0; region < region_count_per_layer; region++) { + for (int type = 0; type < 2; type++) { + void* offset = nullptr; + if (type == 0) { + status = dynload::cudnnGetRNNLinLayerMatrixParams( + parent, cudnn_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/, + layer /*layer*/, input_desc /*xDesc*/, handle_ /*wDesc*/, + nullptr /*w*/, region /*linLayerID*/, + region_desc_handle /*linLayerMatDesc*/, + &offset /*linLayerMat*/); + CUDNN_RETURN_IF_FAIL( + status, "Cudnn fails to call cudnnGetRNNLinLayerMatrixParams"); + } else { + status = dynload::cudnnGetRNNLinLayerBiasParams( + parent, cudnn_handle /*rnnDesc*/, rnn_desc.handle() /*rnnDesc*/, + layer /*layer*/, input_desc /*xDesc*/, handle_ /*wDesc*/, + nullptr /*w*/, region /*linLayerID*/, + region_desc_handle /*linLayerBiasDesc*/, + &offset /*linLayerBias*/); + CUDNN_RETURN_IF_FAIL( + status, "Cudnn fails to call cudnnGetRNNLinLayerBiasParams"); + } + int dims[] = {1, 1, 1}; + cudnnDataType_t data_type; + cudnnTensorFormat_t tensor_format; + int n_dims; + status = dynload::cudnnGetFilterNdDescriptor( + parent, region_desc_handle /*filterDesc*/, + sizeof(dims) / sizeof(dims[0]) /*nbDimsRequested*/, + &data_type /*dataType*/, &tensor_format /*format*/, + &n_dims /*nbDims*/, dims /*filterDimA*/); + CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to get filter description"); + int64 size = dims[0] * dims[1] * dims[2] * + CudnnDataTypeToByteSize(rnn_desc.data_type()); + auto region = 
ParamsRegion{reinterpret_cast<int64>(offset), size}; + if (type == 0) { + weights_.push_back(region); + } else { + biases_.push_back(region); + } + } + } + } + status = dynload::cudnnDestroyFilterDescriptor(parent, region_desc_handle); + CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to destroy filter descriptor"); + } + + { + // Release the dummy input tensor descriptor. + auto status = dynload::cudnnDestroyTensorDescriptor(parent, input_desc); + CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to destroy tensor descriptor"); + } +} + +int CudnnRnnParamsDescriptor::GetRegionCountPerLayer() const { + auto rnn_mode = rnn_desc_->rnn_mode(); + switch (rnn_mode) { + case CUDNN_RNN_RELU: + case CUDNN_RNN_TANH: + return 2; + case CUDNN_LSTM: + return 8; + case CUDNN_GRU: + return 6; + default: + LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode); + } +} + +class CudnnRnnSequenceTensorDescriptor + : public CudnnDescriptorCommon<dnn::RnnSequenceTensorDescriptor> { + public: + CudnnRnnSequenceTensorDescriptor(CUDAExecutor* parent, int seq_length, + int batch_size, int data_size, + cudnnDataType_t data_type) + : parent_(parent), + seq_length_(seq_length), + batch_size_(batch_size), + data_size_(data_size), + data_type_(data_type) { + cudnnTensorDescriptor_t handle = nullptr; + if (seq_length <= 0) { + string error_msg = + port::StrCat("sequence length must be positive: ", seq_length); + LOG(ERROR) << error_msg; + SetFailure(port::Status(port::error::UNKNOWN, error_msg)); + return; + } + cudnnStatus_t status = + dynload::cudnnCreateTensorDescriptor(parent, &handle); + CUDNN_RETURN_IF_FAIL(status, "Failed to create tensor descriptor"); + int dims[] = {batch_size, data_size, 1}; + int strides[] = {dims[1] * dims[2], dims[2], 1}; + status = dynload::cudnnSetTensorNdDescriptor( + parent, handle /*tensorDesc*/, data_type /*dataType*/, + sizeof(dims) / sizeof(dims[0]) /*nbDims*/, dims /*dimA*/, + strides /*strideA*/); + CUDNN_RETURN_IF_FAIL(status, "Failed to update tensor descriptor"); + // Replicate handle across the number of steps. + handles_.assign(seq_length, handle); + } + + ~CudnnRnnSequenceTensorDescriptor() override { + // Only the first one needs to be destroyed. All others are the same. 
+ cudnnStatus_t status = + dynload::cudnnDestroyTensorDescriptor(parent_, handles_[0]); + CUDNN_RETURN_IF_FAIL(status, "Failed to destroy sequence tensor desciptor"); + } + + const cudnnTensorDescriptor_t* handles() const { + if (!ok()) return nullptr; + CHECK(!handles_.empty()) << "handles cannot be empty"; + return handles_.data(); + } + + int seq_length() const { return seq_length_; } + int batch_size() const { return batch_size_; } + int data_size() const { return data_size_; } + + private: + CUDAExecutor* parent_; + int seq_length_; + int batch_size_; + int data_size_; + cudnnDataType_t data_type_; + std::vector<cudnnTensorDescriptor_t> handles_; + port::Status status_; + SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnSequenceTensorDescriptor); +}; + +class CudnnRnnStateTensorDescriptor + : public CudnnDescriptorCommon<dnn::RnnStateTensorDescriptor> { + public: + CudnnRnnStateTensorDescriptor(CUDAExecutor* parent, int num_layers, + int batch_size, int data_size, + cudnnDataType_t data_type) + : parent_(parent), + handle_(nullptr), + num_layers_(num_layers), + batch_size_(batch_size), + data_size_(data_size), + data_type_(data_type) { + cudnnStatus_t status = + dynload::cudnnCreateTensorDescriptor(parent, &handle_); + CUDNN_RETURN_IF_FAIL(status, "Failed to create tensor descriptor"); + int dims[] = {num_layers, batch_size, data_size}; + int strides[] = {dims[1] * dims[2], dims[2], 1}; + status = dynload::cudnnSetTensorNdDescriptor( + parent, handle_ /*tensorDesc*/, data_type /*dataType*/, + sizeof(dims) / sizeof(dims[0]) /*nbDims*/, dims /*dimA*/, + strides /*strideA*/); + CUDNN_RETURN_IF_FAIL(status, "Failed to update tensor descriptor"); + } + + ~CudnnRnnStateTensorDescriptor() override { + if (!handle_) { + cudnnStatus_t status = + dynload::cudnnDestroyTensorDescriptor(parent_, handle_); + CUDNN_RETURN_IF_FAIL(status, "Unable to destroy RNN state tensor"); + } + } + + cudnnTensorDescriptor_t handle() const { + if (!ok()) return nullptr; + return handle_; + } + int num_layers() const { return num_layers_; } + int batch_size() const { return batch_size_; } + int data_size() const { return data_size_; } + + private: + CUDAExecutor* parent_; + cudnnTensorDescriptor_t handle_; + int num_layers_; + int batch_size_; + int data_size_; + port::Status status_; + cudnnDataType_t data_type_; + SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnStateTensorDescriptor); +}; + +namespace { + +struct RnnModelDims { + int num_layers = 0; + int batch_size = 0; + int seq_length = 0; + int hidden_size = 0; + int input_size = 0; + int dir_count = 0; +}; + +template <class T> +bool ExtractAndCheckRnnForward( + const CudnnRnnDescriptor& rnn_desc, + const CudnnRnnSequenceTensorDescriptor& input_desc, + const DeviceMemory<T>& input_data, + const CudnnRnnStateTensorDescriptor& input_h_desc, + const DeviceMemory<T>& input_h_data, + const CudnnRnnStateTensorDescriptor& input_c_desc, + const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params, + const CudnnRnnSequenceTensorDescriptor& output_desc, + const DeviceMemory<T>& output_data, + const CudnnRnnStateTensorDescriptor& output_h_desc, + const DeviceMemory<T>& output_h_data, + const CudnnRnnStateTensorDescriptor& output_c_desc, + const DeviceMemory<T>& output_c_data, RnnModelDims* model_dims) { + // extract model parameters + model_dims->num_layers = rnn_desc.num_layers(); + model_dims->batch_size = input_desc.batch_size(); + model_dims->seq_length = input_desc.seq_length(); + model_dims->hidden_size = rnn_desc.hidden_size(); + model_dims->input_size = input_desc.data_size(); 
+ model_dims->dir_count = + (rnn_desc.direction_mode() == CUDNN_BIDIRECTIONAL) ? 2 : 1; + + // check parameters + if (!(input_h_desc.num_layers() == + model_dims->num_layers * model_dims->dir_count && + input_h_desc.batch_size() == model_dims->batch_size && + input_h_desc.data_size() == model_dims->hidden_size)) { + LOG(ERROR) << "Invalid input_h shape"; + return false; + } + if (!(input_h_desc.num_layers() == input_c_desc.num_layers() && + input_h_desc.batch_size() == input_c_desc.batch_size() && + input_h_desc.data_size() == input_c_desc.data_size())) { + LOG(ERROR) << "Invalid input_c shape"; + return false; + } + if (!(output_desc.seq_length() == model_dims->seq_length && + output_desc.batch_size() == model_dims->batch_size && + output_desc.data_size() == + model_dims->hidden_size * model_dims->dir_count)) { + LOG(ERROR) << "Invalid output shape"; + return false; + } + if (!(input_h_desc.num_layers() == output_h_desc.num_layers() && + input_h_desc.batch_size() == output_h_desc.batch_size() && + input_h_desc.data_size() == output_h_desc.data_size())) { + LOG(ERROR) << "Invalid output_h shape"; + return false; + } + if (!(input_h_desc.num_layers() == output_c_desc.num_layers() && + input_h_desc.batch_size() == output_c_desc.batch_size() && + input_h_desc.data_size() == output_c_desc.data_size())) { + LOG(ERROR) << "Invalid output_h shape"; + return false; + } + + return true; +} + +bool CheckRNNParameterSize(CUDAExecutor* parent, cudnnHandle_t cudnn_handle, + const CudnnRnnDescriptor& rnn_desc, + const CudnnRnnSequenceTensorDescriptor& input_desc) { + size_t params_size_in_bytes = 0; + cudnnStatus_t status = dynload::cudnnGetRNNParamsSize( + parent, cudnn_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/, + input_desc.handles()[0] /*xDesc*/, ¶ms_size_in_bytes /*sizeInBytes*/, + rnn_desc.data_type() /*dataType*/); + if (status != CUDNN_STATUS_SUCCESS) { + LOG(ERROR) << "Unable to check RNN param size: " << ToString(status); + return false; + } + return static_cast<int64>(params_size_in_bytes) == + rnn_desc.ParamsSizeInBytes(); +} + +bool CreateRnnWorkspace(Stream* stream, CUDAExecutor* parent, + cudnnHandle_t cudnn_handle, + const CudnnRnnDescriptor& rnn_desc, + const CudnnRnnSequenceTensorDescriptor& input_desc, + ScratchAllocator* workspace_allocator, + DeviceMemory<uint8>* workspace) { + // Query the workspace size. + size_t workspace_size_in_bytes = 0; + cudnnStatus_t status = dynload::cudnnGetRNNWorkspaceSize( + parent, cudnn_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/, + input_desc.seq_length() /*seqLength*/, input_desc.handles() /*xDesc*/, + &workspace_size_in_bytes /*sizeInBytes*/); + if (status != CUDNN_STATUS_SUCCESS) { + LOG(ERROR) << "Unable to query workspace size: " << ToString(status); + return false; + } + // Allocate the workspace. 
+ if (workspace_size_in_bytes > 0) { + auto allocated = + workspace_allocator->AllocateBytes(stream, workspace_size_in_bytes); + if (!allocated.ok() || (*workspace = allocated.ValueOrDie()) == nullptr) { + LOG(ERROR) << "Failed to allocate RNN workspace"; + return false; + } + } else { + *workspace = DeviceMemory<uint8>(); + } + return true; +} + +} // namespace + +template <class T> +bool CudnnSupport::DoRnnForwardImpl( + Stream* stream, const CudnnRnnDescriptor& rnn_desc, + const CudnnRnnSequenceTensorDescriptor& input_desc, + const DeviceMemory<T>& input_data, + const CudnnRnnStateTensorDescriptor& input_h_desc, + const DeviceMemory<T>& input_h_data, + const CudnnRnnStateTensorDescriptor& input_c_desc, + const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params, + const CudnnRnnSequenceTensorDescriptor& output_desc, + DeviceMemory<T>* output_data, + const CudnnRnnStateTensorDescriptor& output_h_desc, + DeviceMemory<T>* output_h_data, + const CudnnRnnStateTensorDescriptor& output_c_desc, + DeviceMemory<T>* output_c_data, bool is_training, + ScratchAllocator* reserve_space_allocator, + ScratchAllocator* workspace_allocator) { + // extract model parameters + RnnModelDims model_dims; + bool res = ExtractAndCheckRnnForward( + rnn_desc, input_desc, input_data, input_h_desc, input_h_data, + input_c_desc, input_c_data, params, output_desc, *output_data, + output_h_desc, *output_h_data, output_c_desc, *output_c_data, + &model_dims); + if (!res) { + LOG(ERROR) << "Invalid parameters for RNN Model"; + return false; + } + + // check params size + mutex_lock lock{dnn_handle_mutex_}; + + if (!CheckRNNParameterSize(parent_, ToHandle(dnn_handle_), rnn_desc, + input_desc)) { + LOG(ERROR) << "Invalid parameters"; + return false; + } + + // create the workspace + DeviceMemory<uint8> workspace; + if (!CreateRnnWorkspace(stream, parent_, ToHandle(dnn_handle_), rnn_desc, + input_desc, workspace_allocator, &workspace)) { + LOG(ERROR) << "Unable to create rnn workspace"; + return false; + } + + // query the reserve space size + // allocate the reserve space + DeviceMemory<uint8> reserve_space; + if (is_training) { + size_t reserve_space_size_in_bytes = 0; + cudnnStatus_t status = dynload::cudnnGetRNNTrainingReserveSize( + parent_, ToHandle(dnn_handle_) /*handle*/, + rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/, + input_desc.handles() /*xDesc*/, + &reserve_space_size_in_bytes /*sizeInBytes*/); + if (status != CUDNN_STATUS_SUCCESS) { + LOG(ERROR) << "Unable to query reserve space size: " << ToString(status); + return false; + } + + if (reserve_space_size_in_bytes > 0) { + auto allocated = reserve_space_allocator->AllocateBytes( + stream, reserve_space_size_in_bytes); + if (!allocated.ok() || + (reserve_space = allocated.ValueOrDie()) == nullptr) { + LOG(ERROR) << "Fail to allocate RNN reserve space"; + return false; + } + } + } + + // make the forward call + if (!is_training) { + cudnnStatus_t status = dynload::cudnnRNNForwardInference( + parent_, ToHandle(dnn_handle_) /*handle*/, + rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/, + input_desc.handles() /*xDesc*/, input_data.opaque() /*x*/, + input_h_desc.handle() /*hxDesc*/, input_h_data.opaque() /*hx*/, + input_c_desc.handle() /*cxDesc*/, input_c_data.opaque() /*cx*/, + rnn_desc.params_handle() /*wDesc*/, params.opaque() /*w*/, + output_desc.handles() /*yDesc*/, output_data->opaque() /*y*/, + output_h_desc.handle() /*hyDesc*/, output_h_data->opaque() /*hy*/, + output_c_desc.handle() /*cyDesc*/, 
output_c_data->opaque() /*cy*/, + workspace.opaque() /*workspace*/, + workspace.size() /*workSpaceSizeInBytes*/); + if (status != CUDNN_STATUS_SUCCESS) { + LOG(ERROR) << "Failed to call cudnnRNNForwardInference: " + << ToString(status); + return false; + } + } else { + cudnnStatus_t status = dynload::cudnnRNNForwardTraining( + parent_, ToHandle(dnn_handle_) /*handle*/, + rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/, + input_desc.handles() /*xDesc*/, input_data.opaque() /*x*/, + input_h_desc.handle() /*hxDesc*/, input_h_data.opaque() /*hx*/, + input_c_desc.handle() /*cxDesc*/, input_c_data.opaque() /*cx*/, + rnn_desc.params_handle() /*wDesc*/, params.opaque() /*w*/, + output_desc.handles() /*yDesc*/, output_data->opaque() /*y*/, + output_h_desc.handle() /*hyDesc*/, output_h_data->opaque() /*hy*/, + output_c_desc.handle() /*cyDesc*/, output_c_data->opaque() /*cy*/, + workspace.opaque() /*workspace*/, + workspace.size() /*workSpaceSizeInBytes*/, + reserve_space.opaque() /*reserveSpace*/, + reserve_space.size() /*reserveSpaceSizeInBytes*/); + if (status != CUDNN_STATUS_SUCCESS) { + LOG(ERROR) << "Failed to call cudnnRNNForwardTraining" + << ToString(status); + return false; + } + } + + return true; +} + +template <class T> +bool CudnnSupport::DoRnnBackwardImpl( + Stream* stream, const CudnnRnnDescriptor& rnn_desc, + const CudnnRnnSequenceTensorDescriptor& input_desc, + const DeviceMemory<T>& input_data, + const CudnnRnnStateTensorDescriptor& input_h_desc, + const DeviceMemory<T>& input_h_data, + const CudnnRnnStateTensorDescriptor& input_c_desc, + const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params, + const CudnnRnnSequenceTensorDescriptor& output_desc, + const DeviceMemory<T>& output_data, + const CudnnRnnStateTensorDescriptor& output_h_desc, + const DeviceMemory<T>& output_h_data, + const CudnnRnnStateTensorDescriptor& output_c_desc, + const DeviceMemory<T>& output_c_data, + const DeviceMemory<float>& output_backprop_data, + const DeviceMemory<float>& output_h_backprop_data, + const DeviceMemory<float>& output_c_backprop_data, + DeviceMemory<float>* input_backprop_data, + DeviceMemory<float>* input_h_backprop_data, + DeviceMemory<float>* input_c_backprop_data, + DeviceMemory<float>* params_backprop_data, + DeviceMemory<uint8>* reserve_space_data, + ScratchAllocator* workspace_allocator) { + // extract model parameters + RnnModelDims model_dims; + bool res = ExtractAndCheckRnnForward( + rnn_desc, input_desc, input_data, input_h_desc, input_h_data, + input_c_desc, input_c_data, params, output_desc, output_data, + output_h_desc, output_h_data, output_c_desc, output_c_data, &model_dims); + if (!res) { + LOG(ERROR) << "Invalid parameters for RNN Model"; + return false; + } + + // check params size + mutex_lock lock{dnn_handle_mutex_}; + + if (!CheckRNNParameterSize(parent_, ToHandle(dnn_handle_), rnn_desc, + input_desc)) { + LOG(ERROR) << "Invalid parameters"; + return false; + } + + // create the workspace + DeviceMemory<uint8> workspace; + if (!CreateRnnWorkspace(stream, parent_, ToHandle(dnn_handle_), rnn_desc, + input_desc, workspace_allocator, &workspace)) { + LOG(ERROR) << "Unable to create rnn workspace"; + return false; + } + + // make the backward data call + cudnnStatus_t status = dynload::cudnnRNNBackwardData( + parent_, ToHandle(dnn_handle_) /*handle*/, rnn_desc.handle() /*rnnDesc*/, + model_dims.seq_length /*seqLength*/, output_desc.handles() /*yDesc*/, + output_data.opaque() /*y*/, output_desc.handles() /*dyDesc*/, + output_backprop_data.opaque() 
/*dy*/, output_h_desc.handle() /*dhyDesc*/, + output_h_backprop_data.opaque() /*dhy*/, + output_c_desc.handle() /*dcyDesc*/, + output_c_backprop_data.opaque() /*dcy*/, + rnn_desc.params_handle() /*wDesc*/, params.opaque() /*w*/, + input_h_desc.handle() /*hxDesc*/, input_h_data.opaque() /*hx*/, + input_c_desc.handle() /*cxDesc*/, input_c_data.opaque() /*cx*/, + input_desc.handles() /*dxDesc*/, input_backprop_data->opaque() /*dx*/, + input_h_desc.handle() /*dhxDesc*/, + input_h_backprop_data->opaque() /*dhx*/, + input_c_desc.handle() /*dcxDesc*/, + input_c_backprop_data->opaque() /*dcx*/, workspace.opaque() /*workspace*/, + workspace.size() /*workSpaceSizeInBytes*/, + reserve_space_data->opaque() /*reserveSpace*/, + reserve_space_data->size() /*reserveSpaceSizeInBytes*/); + if (status != CUDNN_STATUS_SUCCESS) { + LOG(ERROR) << "Failed to call cudnnRNNBackwardData: " << ToString(status); + return false; + } + + if (params_backprop_data != nullptr) { + // Clear the dw to zeros. + stream->ThenMemZero(params_backprop_data, params_backprop_data->size()); + // make the backward weight call + status = dynload::cudnnRNNBackwardWeights( + parent_, ToHandle(dnn_handle_) /*handle*/, + rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/, + input_desc.handles() /*xDesc*/, input_data.opaque() /*x*/, + input_h_desc.handle() /*hxDesc*/, input_h_data.opaque() /*hx*/, + output_desc.handles() /*yDesc*/, output_data.opaque() /*y*/, + workspace.opaque() /*workspace*/, + workspace.size() /*workSpaceSizeInBytes*/, + rnn_desc.params_handle() /*dwDesc*/, + params_backprop_data->opaque() /*dw*/, + reserve_space_data->opaque() /*reserveSpace*/, + reserve_space_data->size() /*reserveSpaceSizeInBytes*/); + if (status != CUDNN_STATUS_SUCCESS) { + LOG(ERROR) << "Failed to call cudnnRNNBackwardWeights: " + << ToString(status); + return false; + } + } + + return true; +} + +#endif // CUDNN_VERSION + +port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> +CudnnSupport::createRnnDescriptor(int num_layers, int hidden_size, + int input_size, dnn::RnnInputMode input_mode, + dnn::RnnDirectionMode direction_mode, + dnn::RnnMode rnn_mode, + dnn::DataType data_type, float dropout, + uint64 seed, + ScratchAllocator* state_allocator) { +#if CUDNN_VERSION >= 5000 + mutex_lock lock{dnn_handle_mutex_}; + std::unique_ptr<CudnnRnnDescriptor> rnn_desc(new CudnnRnnDescriptor( + parent_, ToHandle(dnn_handle_), num_layers, hidden_size, input_size, + ToCudnnRnnInputMode(input_mode), ToCudnnRnnDirectionMode(direction_mode), + ToCudnnRnnMode(rnn_mode), ToCudnnDataType(data_type), dropout, seed, + state_allocator)); + if (!rnn_desc->ok()) { + return rnn_desc->Status(); + } + return port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>( + std::move(rnn_desc)); +#else + string error_msg = + port::StrCat("createRnnDescriptor needs at least Cudnn 5.0 to work. ", + "Current Cudnn version: ", CUDNN_VERSION, ". 
"); + LOG(ERROR) << error_msg; + return port::Status{port::error::UNIMPLEMENTED, error_msg}; +#endif // CUDNN_VERSION +} + +port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>> +CudnnSupport::createRnnSequenceTensorDescriptor(int seq_length, int batch_size, + int data_size, + dnn::DataType data_type) { +#if CUDNN_VERSION >= 5000 + std::unique_ptr<CudnnRnnSequenceTensorDescriptor> seq_desc( + new CudnnRnnSequenceTensorDescriptor(parent_, seq_length, batch_size, + data_size, + ToCudnnDataType(data_type))); + if (!seq_desc->ok()) { + return seq_desc->Status(); + } + return port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>( + std::move(seq_desc)); +#else + string error_msg = port::StrCat( + "createRnnSequenceTensorDescriptor needs at least Cudnn 5.0 to work. ", + "Current Cudnn version: ", CUDNN_VERSION, ". "); + LOG(ERROR) << error_msg; + return port::Status{port::error::UNIMPLEMENTED, error_msg}; +#endif // CUDNN_VERSION +} + +port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>> +CudnnSupport::createRnnStateTensorDescriptor(int num_layer, int batch_size, + int data_size, + dnn::DataType data_type) { +#if CUDNN_VERSION >= 5000 + std::unique_ptr<CudnnRnnStateTensorDescriptor> state_desc( + new CudnnRnnStateTensorDescriptor(parent_, num_layer, batch_size, + data_size, ToCudnnDataType(data_type))); + if (!state_desc->ok()) { + return state_desc->Status(); + } + return port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>( + std::move(state_desc)); +#else + string error_msg = port::StrCat( + "createRnnStateTensorDescriptor needs at least Cudnn 5.0 to work. ", + "Current Cudnn version: ", CUDNN_VERSION, ". "); + LOG(ERROR) << error_msg; + return port::Status{port::error::UNIMPLEMENTED, error_msg}; +#endif // CUDNN_VERSION +} + +bool CudnnSupport::DoRnnForward( + Stream* stream, const dnn::RnnDescriptor& rnn_desc, + const dnn::RnnSequenceTensorDescriptor& input_desc, + const DeviceMemory<float>& input_data, + const dnn::RnnStateTensorDescriptor& input_h_desc, + const DeviceMemory<float>& input_h_data, + const dnn::RnnStateTensorDescriptor& input_c_desc, + const DeviceMemory<float>& input_c_data, const DeviceMemory<float>& params, + const dnn::RnnSequenceTensorDescriptor& output_desc, + DeviceMemory<float>* output_data, + const dnn::RnnStateTensorDescriptor& output_h_desc, + DeviceMemory<float>* output_h_data, + const dnn::RnnStateTensorDescriptor& output_c_desc, + DeviceMemory<float>* output_c_data, bool is_training, + ScratchAllocator* reserve_space_allocator, + ScratchAllocator* workspace_allocator) { +#if CUDNN_VERSION >= 5000 + const CudnnRnnDescriptor& cudnn_rnn_desc = + static_cast<const CudnnRnnDescriptor&>(rnn_desc); + const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc = + static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc); + const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc = + static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc); + const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc = + static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc); + const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc = + static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc); + const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc = + static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc); + const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc = + static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc); + + return DoRnnForwardImpl<float>( + stream, cudnn_rnn_desc, 
cudnn_input_desc, input_data, cudnn_input_h_desc, + input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc, + output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc, + output_c_data, is_training, reserve_space_allocator, workspace_allocator); +#else + return false; +#endif // CUDNN_VERSION +} + +bool CudnnSupport::DoRnnBackward( + Stream* stream, const dnn::RnnDescriptor& rnn_desc, + const dnn::RnnSequenceTensorDescriptor& input_desc, + const DeviceMemory<float>& input_data, + const dnn::RnnStateTensorDescriptor& input_h_desc, + const DeviceMemory<float>& input_h_data, + const dnn::RnnStateTensorDescriptor& input_c_desc, + const DeviceMemory<float>& input_c_data, const DeviceMemory<float>& params, + const dnn::RnnSequenceTensorDescriptor& output_desc, + const DeviceMemory<float>& output_data, + const dnn::RnnStateTensorDescriptor& output_h_desc, + const DeviceMemory<float>& output_h_data, + const dnn::RnnStateTensorDescriptor& output_c_desc, + const DeviceMemory<float>& output_c_data, + const DeviceMemory<float>& output_backprop_data, + const DeviceMemory<float>& output_h_backprop_data, + const DeviceMemory<float>& output_c_backprop_data, + DeviceMemory<float>* input_backprop_data, + DeviceMemory<float>* input_h_backprop_data, + DeviceMemory<float>* input_c_backprop_data, + DeviceMemory<float>* params_backprop_data, + DeviceMemory<uint8>* reserve_space_data, + ScratchAllocator* workspace_allocator) { +#if CUDNN_VERSION >= 5000 + const CudnnRnnDescriptor& cudnn_rnn_desc = + static_cast<const CudnnRnnDescriptor&>(rnn_desc); + const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc = + static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc); + const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc = + static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc); + const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc = + static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc); + const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc = + static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc); + const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc = + static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc); + const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc = + static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc); + + return DoRnnBackwardImpl<float>( + stream, cudnn_rnn_desc, cudnn_input_desc, input_data, cudnn_input_h_desc, + input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc, + output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc, + output_c_data, output_backprop_data, output_h_backprop_data, + output_c_backprop_data, input_backprop_data, input_h_backprop_data, + input_c_backprop_data, params_backprop_data, reserve_space_data, + workspace_allocator); +#else + return false; +#endif // CUDNN_VERSION +} + template <class T> bool CudnnSupport::DoConvolveImpl( Stream* stream, int cudnn_type, // Actually cudnnDataType_t. diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h index 9805781fcb..f8bc0c493f 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.h +++ b/tensorflow/stream_executor/cuda/cuda_dnn.h @@ -31,6 +31,9 @@ namespace gputools { namespace cuda { class CUDAExecutor; +class CudnnRnnDescriptor; +class CudnnRnnSequenceTensorDescriptor; +class CudnnRnnStateTensorDescriptor; // Opaque and unique identifier for the cuDNN plugin. 
extern const PluginId kCuDnnPlugin; @@ -44,6 +47,62 @@ class CudnnSupport : public dnn::DnnSupport { port::Status Init() override; + port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor( + int num_layers, int hidden_size, int input_size, + dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode, + dnn::RnnMode rnn_mode, dnn::DataType data_type, float dropout, + uint64 seed, ScratchAllocator* state_allocator) override; + + port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>> + createRnnSequenceTensorDescriptor(int seq_length, int batch_size, + int data_size, + dnn::DataType data_type) override; + + port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>> + createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size, + dnn::DataType data_type) override; + + bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, + const dnn::RnnSequenceTensorDescriptor& input_desc, + const DeviceMemory<float>& input_data, + const dnn::RnnStateTensorDescriptor& input_h_desc, + const DeviceMemory<float>& input_h_data, + const dnn::RnnStateTensorDescriptor& input_c_desc, + const DeviceMemory<float>& input_c_data, + const DeviceMemory<float>& params, + const dnn::RnnSequenceTensorDescriptor& output_desc, + DeviceMemory<float>* output_data, + const dnn::RnnStateTensorDescriptor& output_h_desc, + DeviceMemory<float>* output_h_data, + const dnn::RnnStateTensorDescriptor& output_c_desc, + DeviceMemory<float>* output_c_data, bool is_training, + ScratchAllocator* reserve_space_allocator, + ScratchAllocator* workspace_allocator) override; + + bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, + const dnn::RnnSequenceTensorDescriptor& input_desc, + const DeviceMemory<float>& input_data, + const dnn::RnnStateTensorDescriptor& input_h_desc, + const DeviceMemory<float>& input_h_data, + const dnn::RnnStateTensorDescriptor& input_c_desc, + const DeviceMemory<float>& input_c_data, + const DeviceMemory<float>& params, + const dnn::RnnSequenceTensorDescriptor& output_desc, + const DeviceMemory<float>& output_data, + const dnn::RnnStateTensorDescriptor& output_h_desc, + const DeviceMemory<float>& output_h_data, + const dnn::RnnStateTensorDescriptor& output_c_desc, + const DeviceMemory<float>& output_c_data, + const DeviceMemory<float>& output_backprop_data, + const DeviceMemory<float>& output_h_backprop_data, + const DeviceMemory<float>& output_c_backprop_data, + DeviceMemory<float>* input_backprop_data, + DeviceMemory<float>* input_h_backprop_data, + DeviceMemory<float>* input_c_backprop_data, + DeviceMemory<float>* params_backprop_data, + DeviceMemory<uint8>* reserve_space_data, + ScratchAllocator* workspace_allocator) override; + bool GetConvolveAlgorithms( std::vector<dnn::AlgorithmType>* out_algorithms) override; @@ -369,6 +428,49 @@ class CudnnSupport : public dnn::DnnSupport { const dnn::BatchDescriptor& bias_descriptor, DeviceMemory<T>* backward_bias_data); + template <class T> + bool DoRnnForwardImpl(Stream* stream, const CudnnRnnDescriptor& rnn_desc, + const CudnnRnnSequenceTensorDescriptor& input_desc, + const DeviceMemory<T>& input_data, + const CudnnRnnStateTensorDescriptor& input_h_desc, + const DeviceMemory<T>& input_h_data, + const CudnnRnnStateTensorDescriptor& input_c_desc, + const DeviceMemory<T>& input_c_data, + const DeviceMemory<T>& params, + const CudnnRnnSequenceTensorDescriptor& output_desc, + DeviceMemory<T>* output_data, + const CudnnRnnStateTensorDescriptor& output_h_desc, + DeviceMemory<T>* 
output_h_data, + const CudnnRnnStateTensorDescriptor& output_c_desc, + DeviceMemory<T>* output_c_data, bool is_training, + ScratchAllocator* reserve_space_allocator, + ScratchAllocator* workspace_allocator); + + template <class T> + bool DoRnnBackwardImpl(Stream* stream, const CudnnRnnDescriptor& rnn_desc, + const CudnnRnnSequenceTensorDescriptor& input_desc, + const DeviceMemory<T>& input_data, + const CudnnRnnStateTensorDescriptor& input_h_desc, + const DeviceMemory<T>& input_h_data, + const CudnnRnnStateTensorDescriptor& input_c_desc, + const DeviceMemory<T>& input_c_data, + const DeviceMemory<T>& params, + const CudnnRnnSequenceTensorDescriptor& output_desc, + const DeviceMemory<T>& output_data, + const CudnnRnnStateTensorDescriptor& output_h_desc, + const DeviceMemory<T>& output_h_data, + const CudnnRnnStateTensorDescriptor& output_c_desc, + const DeviceMemory<T>& output_c_data, + const DeviceMemory<float>& output_backprop_data, + const DeviceMemory<float>& output_h_backprop_data, + const DeviceMemory<float>& output_c_backprop_data, + DeviceMemory<float>* input_backprop_data, + DeviceMemory<float>* input_h_backprop_data, + DeviceMemory<float>* input_c_backprop_data, + DeviceMemory<float>* params_backprop_data, + DeviceMemory<uint8>* reserve_space_data, + ScratchAllocator* workspace_allocator); + SE_DISALLOW_COPY_AND_ASSIGN(CudnnSupport); }; diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h index c2310c8938..1c31178526 100644 --- a/tensorflow/stream_executor/dnn.h +++ b/tensorflow/stream_executor/dnn.h @@ -23,10 +23,12 @@ limitations under the License. #define TENSORFLOW_STREAM_EXECUTOR_DNN_H_ #include <limits> +#include <memory> #include "tensorflow/stream_executor/device_memory.h" #include "tensorflow/stream_executor/lib/array_slice.h" #include "tensorflow/stream_executor/lib/status.h" +#include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/stream_executor/platform/logging.h" #include "tensorflow/stream_executor/platform/port.h" @@ -76,6 +78,94 @@ enum class QuantizedActivationMode { k32Bit = 4, }; +// Specifies the data type used by an operation. +enum class DataType { + kFloat = 0, + kDouble = 1, + kHalf = 2, +}; + +// A helper class to convert C/C++ types to the proper enums. +template <typename T> +struct ToDataType; +template <> +struct ToDataType<float> { + static constexpr DataType value = DataType::kFloat; +}; +template <> +struct ToDataType<double> { + static constexpr DataType value = DataType::kDouble; +}; +template <> +struct ToDataType<Eigen::half> { + static constexpr DataType value = DataType::kHalf; +}; + +// Specifies the types of a RNN model. +enum class RnnMode { + kRnnRelu = 0, + kRnnTanh = 1, + kRnnLstm = 2, + kRnnGru = 3, +}; + +// Specifies the input model and whether there is a linear transformation +// between the input state and the first layer hidden state. +enum class RnnInputMode { + kRnnLinearSkip = 0, + kRnnSkipInput = 1, +}; + +// Specifies the number of directions used in a RNN model. When bidirection +// is used, the input states and output sequence contain data for both +// directions. +enum class RnnDirectionMode { + kRnnUnidirectional = 0, + kRnnBidirectional = 1, +}; + +// Specifies the descriptor for a RNN model. +// +// An example use case: +// * The user first creates a model through createRnnDescriptor. +// * The user queries the size of the underlying opaque parameter buffer. +// * The user creates and initializes a parameter buffer of the proper size. 
+// * The user runs forward and backward operations using this RNN descriptor. +// * Once a while, user queries maintainable weights and bias regions from +// the underlying parameter buffer. They are more likely to be forward +// compatible and should used in saving and restoring a model. +// * The user releases the RNN descriptor when the model is no longer in use. +class RnnDescriptor { + public: + struct ParamsRegion { + int64 offset; + int64 size; + }; + typedef std::vector<ParamsRegion> ParamsRegions; + virtual ~RnnDescriptor() {} + virtual int64 ParamsSizeInBytes() const { return -1; } + virtual ParamsRegions ParamsWeightRegions() const { return ParamsRegions(); } + virtual ParamsRegions ParamsBiasRegions() const { return ParamsRegions(); } +}; + +// Specifies the sequence in a RNN model. +// +// The user is responsible for releasing this descriptor when it is no longer +// in use. The destructor releases the underlying descriptors. +class RnnSequenceTensorDescriptor { + public: + virtual ~RnnSequenceTensorDescriptor() {} +}; + +// Specifies either the input and hidden state in a RNN model. +// +// The user is responsible for releasing this descriptor when it is no longer +// in use. The destructor releases the underlying descriptors. +class RnnStateTensorDescriptor { + public: + virtual ~RnnStateTensorDescriptor() {} +}; + // Returns a string representation of the given quantization mode. string QuantizedActivationModeString(QuantizedActivationMode mode); @@ -1260,6 +1350,179 @@ class DnnSupport { QuantizedActivationMode mode, DeviceMemory<float>* gpu_unquantized_dst) = 0; + + // Create an RNN descriptor based on model shapes and configurations. + // The caller retains the ownership of the descriptor. + // + // Arguments: + // num_layers: the number of layers for a RNN model. + // hidden_size: the size of the hidden state. + // input_size: the size of the input state. + // input_mode: an enum to specify whether a linear transformation is added + // after the input state. If input_size is different from hidden_size, this + // is required. + // direction_mode: an enum to specify whether this model is unidirectional or + // bidirectional. + // rnn_mode: an enum to specify the type of model to build. + // data_type: an enum to specify the data types used in this model. + // dropout: the dropout threshold between layers. When it is 0., no dropout + // is added. + // seed: a seed for initializing the dropout layers. + // state_allocator: an memory allocator that will be used to store the state + // for dropout layer. The user has to maintain the memory until the model + // is no longer in use. + virtual port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> + createRnnDescriptor(int num_layers, int hidden_size, int input_size, + dnn::RnnInputMode input_mode, + dnn::RnnDirectionMode direction_mode, + dnn::RnnMode rnn_mode, dnn::DataType data_type, + float dropout, uint64 seed, + ScratchAllocator* state_allocator) { + return port::Status{port::error::UNIMPLEMENTED, + "createRnnDescriptor is unimplemented"}; + } + + // Create a RNN sequence descriptor that specifies either the input or output + // sequence. The caller retains the ownership of the returned descriptor. + // + // Arguments: + // seq_length: the length of the sequence. + // batch_size: the size of a minibatch. + // data_size: the size of the state. + // data_type: an enum to specify the type for the underlying data. 
+ virtual port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>> + createRnnSequenceTensorDescriptor(int seq_length, int batch_size, + int data_size, dnn::DataType data_type) { + return port::Status{port::error::UNIMPLEMENTED, + "createRnnSequenceTensorDescriptor is unimplemented"}; + } + + // Create an RNN state descriptor that specifies the input or hidden state. + // The caller retains the ownership of the returned descriptor. + virtual port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>> + createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size, + dnn::DataType data_type) { + return port::Status{port::error::UNIMPLEMENTED, + "createRnnStateTensorDescriptor is unimplemented"}; + } + + // Enqueue a forward operation of the RNN model onto the stream. + // + // Arguments: + // stream: pointer to the stream where this operation should be enqueued to. + // rnn_desc: a RNN descriptor created by createRnnDescriptor. + // input_desc: descriptor for the input sequence. + // input_data: the device memory region that contains the input data. + // input_h_desc: descriptor for the input "h" state. + // input_h_data: the device memory region that contains the input "h" data. + // input_c_desc: descriptor for the input "c" state. + // input_c_data: the device memory region that contains the input "c" data. + // This must be specified for LSTM models. + // params: the device memory region that contains the parameters used in this + // model. + // output_desc: descriptor for the output sequence. + // output_data: the memory region that stores the output sequence data. + // output_h_desc: descriptor for the output "h" state. + // output_h_data: the memory region that stores the output "h" data. + // output_c_desc: descriptor for the output "c" state. + // output_c_data: the memory region that stores the outptu "c" data. This + // must be specified for LSTM models. + // is_training: whether this is used in training or inference. That decides + // whether respace_space data need to be produced. + // reserve_space_allocator: if "is_training" is true, an memory allocator + // to create memory that holds the produced reserve_space. The caller is + // retains the data and feed it to the backward pass. + // workspace_allocator: an allocator to create temporary workspace used in + // this kernel. The caller is responsible for retaining the memory long + // enough for the lifespan of this operation, and recycles aftewards. + virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, + const dnn::RnnSequenceTensorDescriptor& input_desc, + const DeviceMemory<float>& input_data, + const dnn::RnnStateTensorDescriptor& input_h_desc, + const DeviceMemory<float>& input_h_data, + const dnn::RnnStateTensorDescriptor& input_c_desc, + const DeviceMemory<float>& input_c_data, + const DeviceMemory<float>& params, + const dnn::RnnSequenceTensorDescriptor& output_desc, + DeviceMemory<float>* output_data, + const dnn::RnnStateTensorDescriptor& output_h_desc, + DeviceMemory<float>* output_h_data, + const dnn::RnnStateTensorDescriptor& output_c_desc, + DeviceMemory<float>* output_c_data, + bool is_training, + ScratchAllocator* reserve_space_allocator, + ScratchAllocator* workspace_allocator) { + return false; + } + + // Enqueue a backward operation of the RNN model onto the stream. + // + // Arguments: + // stream: pointer to the stream where this operation should be enqueued to. + // rnn_desc: a RNN descriptor created by createRnnDescriptor. 
+ // input_desc: descriptor for the input sequence. + // input_data: the device memory region that contains the input data. + // input_h_desc: descriptor for the input "h" state. + // input_h_data: the device memory region that contains the input "h" data. + // input_c_desc: descriptor for the input "c" state. + // input_c_data: the device memory region that contains the input "c" data. + // This must be specified for LSTM models. + // params: the device memory region that contains the parameters used in this + // model. + // output_desc: descriptor for the output sequence. + // output_data: the memory region that stores the output sequence data. + // output_h_desc: descriptor for the output "h" state. + // output_h_data: the memory region that stores the output "h" data. + // output_c_desc: descriptor for the output "c" state. + // output_c_data: the memory region that stores the outptu "c" data. This + // must be specified for LSTM models. + // output_backprop_data: the device memory region that contains the backprop + // to the output sequence. + // output_h_backprop_data: the device memory region that contains the + // backprop to the output "h" state. + // output_c_backprop_data: the device memory region that contains the + // backprop to the output "c" state. + // input_backprop_data: the device memory region that stores the backprop + // to the input sequence. + // input_h_backprop_data: the device memory region that stores the backprop + // to the input "h" state. + // input_c_backprop_data: the device memory region that stores the backprop + // to the input "c" state. + // params_backprop_data: the device memory region that stores the backprop + // to the parameters. + // reserve_space_data: the reserve_space data that is produced by the forward + // operation. This memory region could be modified by this operation. + // workspace_allocator: a memory allocator that creates the temporary + // workspace memory used by this operation. The caller is responsible for + // keeping the memory alive long enough for this operation, and recylces + // afterwards. 
+ virtual bool DoRnnBackward( + Stream* stream, const dnn::RnnDescriptor& rnn_desc, + const dnn::RnnSequenceTensorDescriptor& input_desc, + const DeviceMemory<float>& input_data, + const dnn::RnnStateTensorDescriptor& input_h_desc, + const DeviceMemory<float>& input_h_data, + const dnn::RnnStateTensorDescriptor& input_c_desc, + const DeviceMemory<float>& input_c_data, + const DeviceMemory<float>& params, + const dnn::RnnSequenceTensorDescriptor& output_desc, + const DeviceMemory<float>& output_data, + const dnn::RnnStateTensorDescriptor& output_h_desc, + const DeviceMemory<float>& output_h_data, + const dnn::RnnStateTensorDescriptor& output_c_desc, + const DeviceMemory<float>& output_c_data, + const DeviceMemory<float>& output_backprop_data, + const DeviceMemory<float>& output_h_backprop_data, + const DeviceMemory<float>& output_c_backprop_data, + DeviceMemory<float>* input_backprop_data, + DeviceMemory<float>* input_h_backprop_data, + DeviceMemory<float>* input_c_backprop_data, + DeviceMemory<float>* params_backprop_data, + DeviceMemory<uint8>* reserve_space_data, + ScratchAllocator* workspace_allocator) { + return false; + } + private: SE_DISALLOW_COPY_AND_ASSIGN(DnnSupport); }; @@ -1269,3 +1532,4 @@ class DnnSupport { } // namespace perftools #endif // TENSORFLOW_STREAM_EXECUTOR_DNN_H_ + diff --git a/tensorflow/stream_executor/platform/port.h b/tensorflow/stream_executor/platform/port.h index 3c1b40b016..6603df4878 100644 --- a/tensorflow/stream_executor/platform/port.h +++ b/tensorflow/stream_executor/platform/port.h @@ -55,5 +55,7 @@ using tensorflow::LINKER_INITIALIZED; #define SE_DISALLOW_COPY_AND_ASSIGN TF_DISALLOW_COPY_AND_ASSIGN #define SE_MUST_USE_RESULT TF_MUST_USE_RESULT +#define SE_PREDICT_TRUE TF_PREDICT_TRUE +#define SE_PREDICT_FALSE TF_PREDICT_FALSE #endif // TENSORFLOW_STREAM_EXECUTOR_PLATFORM_PORT_H_ diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index 92ceb5bbb2..8c0e45f1a6 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -3774,6 +3774,79 @@ Stream &Stream::ThenMemset32(DeviceMemoryBase *location, const uint32 &pattern, return *this; } +Stream &Stream::ThenRnnForward( + const dnn::RnnDescriptor &rnn_desc, + const dnn::RnnSequenceTensorDescriptor &input_desc, + const DeviceMemory<float> &input_data, + const dnn::RnnStateTensorDescriptor &input_h_desc, + const DeviceMemory<float> &input_h_data, + const dnn::RnnStateTensorDescriptor &input_c_desc, + const DeviceMemory<float> &input_c_data, const DeviceMemory<float> ¶ms, + const dnn::RnnSequenceTensorDescriptor &output_desc, + DeviceMemory<float> *output_data, + const dnn::RnnStateTensorDescriptor &output_h_desc, + DeviceMemory<float> *output_h_data, + const dnn::RnnStateTensorDescriptor &output_c_desc, + DeviceMemory<float> *output_c_data, bool is_training, + ScratchAllocator *reserve_space_allocator, + ScratchAllocator *workspace_allocator) { + // TODO(zhengxq): add VLOG PARAM calls. 
diff --git a/tensorflow/stream_executor/platform/port.h b/tensorflow/stream_executor/platform/port.h
index 3c1b40b016..6603df4878 100644
--- a/tensorflow/stream_executor/platform/port.h
+++ b/tensorflow/stream_executor/platform/port.h
@@ -55,5 +55,7 @@ using tensorflow::LINKER_INITIALIZED;
 
 #define SE_DISALLOW_COPY_AND_ASSIGN TF_DISALLOW_COPY_AND_ASSIGN
 #define SE_MUST_USE_RESULT TF_MUST_USE_RESULT
+#define SE_PREDICT_TRUE TF_PREDICT_TRUE
+#define SE_PREDICT_FALSE TF_PREDICT_FALSE
 
 #endif  // TENSORFLOW_STREAM_EXECUTOR_PLATFORM_PORT_H_
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index 92ceb5bbb2..8c0e45f1a6 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -3774,6 +3774,79 @@ Stream &Stream::ThenMemset32(DeviceMemoryBase *location, const uint32 &pattern,
   return *this;
 }
 
+Stream &Stream::ThenRnnForward(
+    const dnn::RnnDescriptor &rnn_desc,
+    const dnn::RnnSequenceTensorDescriptor &input_desc,
+    const DeviceMemory<float> &input_data,
+    const dnn::RnnStateTensorDescriptor &input_h_desc,
+    const DeviceMemory<float> &input_h_data,
+    const dnn::RnnStateTensorDescriptor &input_c_desc,
+    const DeviceMemory<float> &input_c_data, const DeviceMemory<float> &params,
+    const dnn::RnnSequenceTensorDescriptor &output_desc,
+    DeviceMemory<float> *output_data,
+    const dnn::RnnStateTensorDescriptor &output_h_desc,
+    DeviceMemory<float> *output_h_data,
+    const dnn::RnnStateTensorDescriptor &output_c_desc,
+    DeviceMemory<float> *output_c_data, bool is_training,
+    ScratchAllocator *reserve_space_allocator,
+    ScratchAllocator *workspace_allocator) {
+  // TODO(zhengxq): add VLOG PARAM calls.
+  if (ok()) {
+    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+      CheckError(dnn->DoRnnForward(
+          this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
+          input_c_desc, input_c_data, params, output_desc, output_data,
+          output_h_desc, output_h_data, output_c_desc, output_c_data,
+          is_training, reserve_space_allocator, workspace_allocator));
+    } else {
+      SetError();
+      LOG(WARNING) << "Attempting to call ThenRnnForward without DNN support";
+    }
+  }
+  return *this;
+}
+
+Stream &Stream::ThenRnnBackward(
+    const dnn::RnnDescriptor &rnn_desc,
+    const dnn::RnnSequenceTensorDescriptor &input_desc,
+    const DeviceMemory<float> &input_data,
+    const dnn::RnnStateTensorDescriptor &input_h_desc,
+    const DeviceMemory<float> &input_h_data,
+    const dnn::RnnStateTensorDescriptor &input_c_desc,
+    const DeviceMemory<float> &input_c_data, const DeviceMemory<float> &params,
+    const dnn::RnnSequenceTensorDescriptor &output_desc,
+    const DeviceMemory<float> &output_data,
+    const dnn::RnnStateTensorDescriptor &output_h_desc,
+    const DeviceMemory<float> &output_h_data,
+    const dnn::RnnStateTensorDescriptor &output_c_desc,
+    const DeviceMemory<float> &output_c_data,
+    const DeviceMemory<float> &output_backprop_data,
+    const DeviceMemory<float> &output_h_backprop_data,
+    const DeviceMemory<float> &output_c_backprop_data,
+    DeviceMemory<float> *input_backprop_data,
+    DeviceMemory<float> *input_h_backprop_data,
+    DeviceMemory<float> *input_c_backprop_data,
+    DeviceMemory<float> *params_backprop_data,
+    DeviceMemory<uint8> *reserve_space_data,
+    ScratchAllocator *workspace_allocator) {
+  // TODO(zhengxq): add VLOG PARAM calls.
+  if (ok()) {
+    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+      CheckError(dnn->DoRnnBackward(
+          this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
+          input_c_desc, input_c_data, params, output_desc, output_data,
+          output_h_desc, output_h_data, output_c_desc, output_c_data,
+          output_backprop_data, output_h_backprop_data, output_c_backprop_data,
+          input_backprop_data, input_h_backprop_data, input_c_backprop_data,
+          params_backprop_data, reserve_space_data, workspace_allocator));
+    } else {
+      SetError();
+      LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
+    }
+  }
+  return *this;
+}
+
 Stream &Stream::ThenDoHostCallbackForTest(std::function<void()> callback) {
   VLOG_CALL(PARAM(callback));
 
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index 4d514804e5..61058528c2 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -1383,6 +1383,51 @@ class Stream {
   Stream &ThenMemset32(DeviceMemoryBase *location, const uint32 &pattern,
                        uint64 size);
 
+  // Enqueue a forward operation of the RNN model onto the stream.
+  // See DnnSupport::DoRnnForward for more details.
+ Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc, + const dnn::RnnSequenceTensorDescriptor &input_desc, + const DeviceMemory<float> &input_data, + const dnn::RnnStateTensorDescriptor &input_h_desc, + const DeviceMemory<float> &input_h_data, + const dnn::RnnStateTensorDescriptor &input_c_desc, + const DeviceMemory<float> &input_c_data, + const DeviceMemory<float> ¶ms, + const dnn::RnnSequenceTensorDescriptor &output_desc, + DeviceMemory<float> *output_data, + const dnn::RnnStateTensorDescriptor &output_h_desc, + DeviceMemory<float> *output_h_data, + const dnn::RnnStateTensorDescriptor &output_c_desc, + DeviceMemory<float> *output_c_data, bool is_training, + ScratchAllocator *reserve_space_allocator, + ScratchAllocator *workspace_allocator); + + // Enqueue a backward operation of the RNN model onto the stream. + // See DnnSupport::DoRnnBackward for more details. + Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc, + const dnn::RnnSequenceTensorDescriptor &input_desc, + const DeviceMemory<float> &input_data, + const dnn::RnnStateTensorDescriptor &input_h_desc, + const DeviceMemory<float> &input_h_data, + const dnn::RnnStateTensorDescriptor &input_c_desc, + const DeviceMemory<float> &input_c_data, + const DeviceMemory<float> ¶ms, + const dnn::RnnSequenceTensorDescriptor &output_desc, + const DeviceMemory<float> &output_data, + const dnn::RnnStateTensorDescriptor &output_h_desc, + const DeviceMemory<float> &output_h_data, + const dnn::RnnStateTensorDescriptor &output_c_desc, + const DeviceMemory<float> &output_c_data, + const DeviceMemory<float> &output_backprop_data, + const DeviceMemory<float> &output_h_backprop_data, + const DeviceMemory<float> &output_c_backprop_data, + DeviceMemory<float> *input_backprop_data, + DeviceMemory<float> *input_h_backprop_data, + DeviceMemory<float> *input_c_backprop_data, + DeviceMemory<float> *params_backprop_data, + DeviceMemory<uint8> *reserve_space_data, + ScratchAllocator *workspace_allocator); + // (Synchronously) block the host code waiting for the operations // entrained on the stream (enqueued to this point in program // execution) to complete. 
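To make the intended call pattern concrete, here is a client-side sketch. All variable names are hypothetical, and the descriptors and buffers are assumed to have been created and sized beforehand; a training-mode forward pass produces reserve space through reserve_space_allocator, and the same bytes must be handed back to ThenRnnBackward:

// Sketch only: assumes rnn_desc and the tensor descriptors are held in
// std::unique_ptr, and the DeviceMemory<float>/ScratchAllocator objects
// below already exist with these names.
stream.ThenRnnForward(*rnn_desc, *input_desc, input_data, *input_h_desc,
                      input_h_data, *input_c_desc, input_c_data, params,
                      *output_desc, &output_data, *output_h_desc,
                      &output_h_data, *output_c_desc, &output_c_data,
                      /*is_training=*/true, &reserve_space_allocator,
                      &workspace_allocator);

// The backward pass consumes the activations saved during the forward pass;
// reserve_space_data is assumed to wrap the bytes that the forward pass
// obtained from reserve_space_allocator.
stream.ThenRnnBackward(*rnn_desc, *input_desc, input_data, *input_h_desc,
                       input_h_data, *input_c_desc, input_c_data, params,
                       *output_desc, output_data, *output_h_desc,
                       output_h_data, *output_c_desc, output_c_data,
                       output_backprop_data, output_h_backprop_data,
                       output_c_backprop_data, &input_backprop_data,
                       &input_h_backprop_data, &input_c_backprop_data,
                       &params_backprop_data, &reserve_space_data,
                       &workspace_allocator);

// Both calls only enqueue work on the stream; block before reading the
// results on the host.
stream.BlockHostUntilDone();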
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index 07dc375ef4..2fdd1e4b49 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -309,6 +309,48 @@ bool StreamExecutor::GetConvolveBackwardFilterAlgorithms(
   return dnn_support->GetConvolveBackwardFilterAlgorithms(out_algorithms);
 }
 
+port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
+StreamExecutor::createRnnDescriptor(
+    int num_layers, int hidden_size, int input_size,
+    dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
+    dnn::RnnMode rnn_mode, dnn::DataType data_type, float dropout, uint64 seed,
+    ScratchAllocator *state_allocator) {
+  dnn::DnnSupport *dnn_support = AsDnn();
+  if (!dnn_support) {
+    return port::Status(port::error::UNKNOWN,
+                        "Failed to find the dnn implementation.");
+  }
+  return dnn_support->createRnnDescriptor(
+      num_layers, hidden_size, input_size, input_mode, direction_mode,
+      rnn_mode, data_type, dropout, seed, state_allocator);
+}
+
+port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
+StreamExecutor::createRnnSequenceTensorDescriptor(int seq_length,
+                                                  int batch_size,
+                                                  int data_size,
+                                                  dnn::DataType data_type) {
+  dnn::DnnSupport *dnn_support = AsDnn();
+  if (!dnn_support) {
+    return port::Status(port::error::UNKNOWN,
+                        "Failed to find the dnn implementation.");
+  }
+  return dnn_support->createRnnSequenceTensorDescriptor(seq_length, batch_size,
+                                                        data_size, data_type);
+}
+
+port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
+StreamExecutor::createRnnStateTensorDescriptor(int num_layer, int batch_size,
+                                               int data_size,
+                                               dnn::DataType data_type) {
+  dnn::DnnSupport *dnn_support = AsDnn();
+  if (!dnn_support) {
+    return port::Status(port::error::UNKNOWN,
+                        "Failed to find the dnn implementation.");
+  }
+  return dnn_support->createRnnStateTensorDescriptor(num_layer, batch_size,
+                                                     data_size, data_type);
+}
+
 dnn::DnnSupport *StreamExecutor::AsDnn() {
   mutex_lock lock{mu_};
   if (dnn_ != nullptr) {
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index 5f1b8bda02..2b5a70f807 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -350,6 +350,26 @@ class StreamExecutor {
   bool GetConvolveBackwardFilterAlgorithms(
       std::vector<dnn::AlgorithmType> *out_algorithms);
 
+  // Create an RNN descriptor based on model shapes and configurations.
+  // The caller takes ownership of the returned descriptor.
+  port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
+      int num_layers, int hidden_size, int input_size,
+      dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
+      dnn::RnnMode rnn_mode, dnn::DataType data_type, float dropout,
+      uint64 seed, ScratchAllocator *state_allocator);
+
+  // Create an RNN sequence descriptor that specifies either the input or
+  // output sequence. The caller takes ownership of the returned descriptor.
+  port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
+  createRnnSequenceTensorDescriptor(int seq_length, int batch_size,
+                                    int data_size, dnn::DataType data_type);
+
+  // Create an RNN state descriptor that specifies the input or hidden state.
+  // The caller takes ownership of the returned descriptor.
+  port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
+  createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size,
+                                 dnn::DataType data_type);
+
   // Returns the device ordinal that this StreamExecutor was initialized with.
   // Meaningless before initialization.
   int device_ordinal() const { return device_ordinal_; }
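For completeness, a sketch of the descriptor-creation path these factories expose. The shape values and the surrounding function are invented, and the use of ConsumeValueOrDie assumes the usual port::StatusOr interface:

#include <memory>

#include "tensorflow/stream_executor/stream_executor_pimpl.h"

namespace gpu = perftools::gputools;  // alias assumed for brevity

// Sketch: build the descriptors for a hypothetical 2-layer float LSTM.
bool CreateLstmDescriptors(gpu::StreamExecutor* executor,
                           gpu::ScratchAllocator* state_allocator) {
  const int num_layers = 2, hidden_size = 128, input_size = 64;
  const int seq_length = 50, batch_size = 32;

  // Each factory returns a port::StatusOr; the status must be checked
  // before taking the wrapped descriptor.
  auto rnn_desc_s = executor->createRnnDescriptor(
      num_layers, hidden_size, input_size,
      gpu::dnn::RnnInputMode::kRnnLinearSkip,
      gpu::dnn::RnnDirectionMode::kRnnUnidirectional,
      gpu::dnn::RnnMode::kRnnLstm, gpu::dnn::DataType::kFloat,
      /*dropout=*/0.0f, /*seed=*/0, state_allocator);
  if (!rnn_desc_s.ok()) return false;  // e.g. no DNN implementation found.
  std::unique_ptr<gpu::dnn::RnnDescriptor> rnn_desc =
      rnn_desc_s.ConsumeValueOrDie();

  // Sequence descriptors describe the input/output sequences; state
  // descriptors describe the "h" and "c" states.
  auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
      seq_length, batch_size, input_size, gpu::dnn::DataType::kFloat);
  auto state_desc_s = executor->createRnnStateTensorDescriptor(
      num_layers, batch_size, hidden_size, gpu::dnn::DataType::kFloat);
  return input_desc_s.ok() && state_desc_s.ok();
}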