diff options
author | Xiaoqiang Zheng <zhengxq@google.com> | 2016-08-19 10:01:32 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-08-19 11:18:09 -0700 |
commit | 859e47fd8ebcdd7fb1411fd0090d0e95a801e7cb (patch) | |
tree | ae0c4e4e690d53b92720049fb75acc119e4ea5e0 /tensorflow/stream_executor/stream_executor_pimpl.h | |
parent | de8838042fb34e53a511f25b4613611fc368beeb (diff) |
Add stream-executor changes to enable Cudnn fused LSTM/RNN support.
Change: 130770287
Diffstat (limited to 'tensorflow/stream_executor/stream_executor_pimpl.h')
-rw-r--r-- | tensorflow/stream_executor/stream_executor_pimpl.h | 20 |
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h index 5f1b8bda02..2b5a70f807 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.h +++ b/tensorflow/stream_executor/stream_executor_pimpl.h @@ -350,6 +350,26 @@ class StreamExecutor { bool GetConvolveBackwardFilterAlgorithms( std::vector<dnn::AlgorithmType> *out_algorithms); + // Create an RNN descriptor based on model shapes and configurations. + // The caller retains the ownership of the descriptor. + port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor( + int num_layers, int hidden_size, int input_size, + dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode, + dnn::RnnMode rnn_mode, dnn::DataType data_type, float dropout, + uint64 seed, ScratchAllocator *state_allocator); + + // Create a RNN sequence descriptor that specifies either the input or output + // sequence. The caller retains the ownership of the returned descriptor. + port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>> + createRnnSequenceTensorDescriptor(int seq_length, int batch_size, + int data_size, dnn::DataType data_type); + + // Create an RNN state descriptor that specifies the input or hidden state. + // The caller retains the ownership of the returned descriptor. + port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>> + createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size, + dnn::DataType data_type); + // Returns the device ordinal that this StreamExecutor was initialized with. // Meaningless before initialization. int device_ordinal() const { return device_ordinal_; } |