diff options
author | Xiaoqiang Zheng <zhengxq@google.com> | 2016-08-19 10:01:32 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-08-19 11:18:09 -0700 |
commit | 859e47fd8ebcdd7fb1411fd0090d0e95a801e7cb (patch) | |
tree | ae0c4e4e690d53b92720049fb75acc119e4ea5e0 /tensorflow/stream_executor/stream.h | |
parent | de8838042fb34e53a511f25b4613611fc368beeb (diff) |
Add stream-executor changes to enable Cudnn fused LSTM/RNN support.
Change: 130770287
Diffstat (limited to 'tensorflow/stream_executor/stream.h')
-rw-r--r-- | tensorflow/stream_executor/stream.h | 45 |
1 files changed, 45 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h index 4d514804e5..61058528c2 100644 --- a/tensorflow/stream_executor/stream.h +++ b/tensorflow/stream_executor/stream.h @@ -1383,6 +1383,51 @@ class Stream { Stream &ThenMemset32(DeviceMemoryBase *location, const uint32 &pattern, uint64 size); + // Enqueue a forward operation of the RNN model onto the stream. + // See DnnSupport::DoRnnForward for more details. + Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc, + const dnn::RnnSequenceTensorDescriptor &input_desc, + const DeviceMemory<float> &input_data, + const dnn::RnnStateTensorDescriptor &input_h_desc, + const DeviceMemory<float> &input_h_data, + const dnn::RnnStateTensorDescriptor &input_c_desc, + const DeviceMemory<float> &input_c_data, + const DeviceMemory<float> &params, + const dnn::RnnSequenceTensorDescriptor &output_desc, + DeviceMemory<float> *output_data, + const dnn::RnnStateTensorDescriptor &output_h_desc, + DeviceMemory<float> *output_h_data, + const dnn::RnnStateTensorDescriptor &output_c_desc, + DeviceMemory<float> *output_c_data, bool is_training, + ScratchAllocator *reserve_space_allocator, + ScratchAllocator *workspace_allocator); + + // Enqueue a backward operation of the RNN model onto the stream. + // See DnnSupport::DoRnnBackward for more details. 
+ Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc, + const dnn::RnnSequenceTensorDescriptor &input_desc, + const DeviceMemory<float> &input_data, + const dnn::RnnStateTensorDescriptor &input_h_desc, + const DeviceMemory<float> &input_h_data, + const dnn::RnnStateTensorDescriptor &input_c_desc, + const DeviceMemory<float> &input_c_data, + const DeviceMemory<float> &params, + const dnn::RnnSequenceTensorDescriptor &output_desc, + const DeviceMemory<float> &output_data, + const dnn::RnnStateTensorDescriptor &output_h_desc, + const DeviceMemory<float> &output_h_data, + const dnn::RnnStateTensorDescriptor &output_c_desc, + const DeviceMemory<float> &output_c_data, + const DeviceMemory<float> &output_backprop_data, + const DeviceMemory<float> &output_h_backprop_data, + const DeviceMemory<float> &output_c_backprop_data, + DeviceMemory<float> *input_backprop_data, + DeviceMemory<float> *input_h_backprop_data, + DeviceMemory<float> *input_c_backprop_data, + DeviceMemory<float> *params_backprop_data, + DeviceMemory<uint8> *reserve_space_data, + ScratchAllocator *workspace_allocator); + // (Synchronously) block the host code waiting for the operations // entrained on the stream (enqueued to this point in program // execution) to complete. |