aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/stream.h
diff options
context:
space:
mode:
authorGravatar Xiaoqiang Zheng <zhengxq@google.com>2016-08-19 10:01:32 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-19 11:18:09 -0700
commit859e47fd8ebcdd7fb1411fd0090d0e95a801e7cb (patch)
treeae0c4e4e690d53b92720049fb75acc119e4ea5e0 /tensorflow/stream_executor/stream.h
parentde8838042fb34e53a511f25b4613611fc368beeb (diff)
Add stream-executor changes to enable Cudnn fused LSTM/RNN support.
Change: 130770287
Diffstat (limited to 'tensorflow/stream_executor/stream.h')
-rw-r--r--tensorflow/stream_executor/stream.h45
1 files changed, 45 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index 4d514804e5..61058528c2 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -1383,6 +1383,51 @@ class Stream {
Stream &ThenMemset32(DeviceMemoryBase *location, const uint32 &pattern,
uint64 size);
+ // Enqueue a forward operation of the RNN model onto the stream.
+ // See DnnSupport::DoRnnForward for more details.
+ Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
+ const dnn::RnnSequenceTensorDescriptor &input_desc,
+ const DeviceMemory<float> &input_data,
+ const dnn::RnnStateTensorDescriptor &input_h_desc,
+ const DeviceMemory<float> &input_h_data,
+ const dnn::RnnStateTensorDescriptor &input_c_desc,
+ const DeviceMemory<float> &input_c_data,
+ const DeviceMemory<float> &params,
+ const dnn::RnnSequenceTensorDescriptor &output_desc,
+ DeviceMemory<float> *output_data,
+ const dnn::RnnStateTensorDescriptor &output_h_desc,
+ DeviceMemory<float> *output_h_data,
+ const dnn::RnnStateTensorDescriptor &output_c_desc,
+ DeviceMemory<float> *output_c_data, bool is_training,
+ ScratchAllocator *reserve_space_allocator,
+ ScratchAllocator *workspace_allocator);
+
+ // Enqueue a backward operation of the RNN model onto the stream.
+ // See DnnSupport::DoRnnBackward for more details.
+ Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc,
+ const dnn::RnnSequenceTensorDescriptor &input_desc,
+ const DeviceMemory<float> &input_data,
+ const dnn::RnnStateTensorDescriptor &input_h_desc,
+ const DeviceMemory<float> &input_h_data,
+ const dnn::RnnStateTensorDescriptor &input_c_desc,
+ const DeviceMemory<float> &input_c_data,
+ const DeviceMemory<float> &params,
+ const dnn::RnnSequenceTensorDescriptor &output_desc,
+ const DeviceMemory<float> &output_data,
+ const dnn::RnnStateTensorDescriptor &output_h_desc,
+ const DeviceMemory<float> &output_h_data,
+ const dnn::RnnStateTensorDescriptor &output_c_desc,
+ const DeviceMemory<float> &output_c_data,
+ const DeviceMemory<float> &output_backprop_data,
+ const DeviceMemory<float> &output_h_backprop_data,
+ const DeviceMemory<float> &output_c_backprop_data,
+ DeviceMemory<float> *input_backprop_data,
+ DeviceMemory<float> *input_h_backprop_data,
+ DeviceMemory<float> *input_c_backprop_data,
+ DeviceMemory<float> *params_backprop_data,
+ DeviceMemory<uint8> *reserve_space_data,
+ ScratchAllocator *workspace_allocator);
+
// (Synchronously) block the host code waiting for the operations
// entrained on the stream (enqueued to this point in program
// execution) to complete.