aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/stream_executor_pimpl.h
diff options
context:
space:
mode:
authorGravatar Xiaoqiang Zheng <zhengxq@google.com>2016-08-19 10:01:32 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-19 11:18:09 -0700
commit859e47fd8ebcdd7fb1411fd0090d0e95a801e7cb (patch)
treeae0c4e4e690d53b92720049fb75acc119e4ea5e0 /tensorflow/stream_executor/stream_executor_pimpl.h
parentde8838042fb34e53a511f25b4613611fc368beeb (diff)
Add stream-executor changes to enable Cudnn fused LSTM/RNN support.
Change: 130770287
Diffstat (limited to 'tensorflow/stream_executor/stream_executor_pimpl.h')
-rw-r--r--tensorflow/stream_executor/stream_executor_pimpl.h20
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index 5f1b8bda02..2b5a70f807 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -350,6 +350,26 @@ class StreamExecutor {
bool GetConvolveBackwardFilterAlgorithms(
std::vector<dnn::AlgorithmType> *out_algorithms);
+ // Create an RNN descriptor based on model shapes and configurations.
+ // The caller takes ownership of the returned descriptor.
+ port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
+ int num_layers, int hidden_size, int input_size,
+ dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
+ dnn::RnnMode rnn_mode, dnn::DataType data_type, float dropout,
+ uint64 seed, ScratchAllocator *state_allocator);
+
+ // Create an RNN sequence descriptor that specifies either the input or output
+ // sequence. The caller takes ownership of the returned descriptor.
+ port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
+ createRnnSequenceTensorDescriptor(int seq_length, int batch_size,
+ int data_size, dnn::DataType data_type);
+
+ // Create an RNN state descriptor that specifies the input or hidden state.
+ // The caller takes ownership of the returned descriptor.
+ port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
+ createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size,
+ dnn::DataType data_type);
+
// Returns the device ordinal that this StreamExecutor was initialized with.
// Meaningless before initialization.
int device_ordinal() const { return device_ordinal_; }