aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/stream_executor_pimpl.cc
diff options
context:
space:
mode:
authorGravatar Xiaoqiang Zheng <zhengxq@google.com>2016-08-19 10:01:32 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-19 11:18:09 -0700
commit859e47fd8ebcdd7fb1411fd0090d0e95a801e7cb (patch)
treeae0c4e4e690d53b92720049fb75acc119e4ea5e0 /tensorflow/stream_executor/stream_executor_pimpl.cc
parentde8838042fb34e53a511f25b4613611fc368beeb (diff)
Add stream-executor changes to enable Cudnn fused LSTM/RNN support.
Change: 130770287
Diffstat (limited to 'tensorflow/stream_executor/stream_executor_pimpl.cc')
-rw-r--r--tensorflow/stream_executor/stream_executor_pimpl.cc42
1 files changed, 42 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index 07dc375ef4..2fdd1e4b49 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -309,6 +309,48 @@ bool StreamExecutor::GetConvolveBackwardFilterAlgorithms(
return dnn_support->GetConvolveBackwardFilterAlgorithms(out_algorithms);
}
+port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
+StreamExecutor::createRnnDescriptor(
+ int num_layers, int hidden_size, int input_size,
+ dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
+ dnn::RnnMode rnn_mode, dnn::DataType data_type, float dropout, uint64 seed,
+ ScratchAllocator *state_allocator) {
+ dnn::DnnSupport *dnn_support = AsDnn();
+ if (!dnn_support) {
+ return port::Status(port::error::UNKNOWN,
+ "Fail to find the dnn implementation.");
+ }
+ return dnn_support->createRnnDescriptor(
+ num_layers, hidden_size, input_size, input_mode, direction_mode, rnn_mode,
+ data_type, dropout, seed, state_allocator);
+}
+
+port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
+StreamExecutor::createRnnSequenceTensorDescriptor(int seq_length,
+ int batch_size, int data_size,
+ dnn::DataType data_type) {
+ dnn::DnnSupport *dnn_support = AsDnn();
+ if (!dnn_support) {
+ return port::Status(port::error::UNKNOWN,
+ "Fail to find the dnn implementation.");
+ }
+ return dnn_support->createRnnSequenceTensorDescriptor(seq_length, batch_size,
+ data_size, data_type);
+}
+
+port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
+StreamExecutor::createRnnStateTensorDescriptor(int num_layer, int batch_size,
+ int data_size,
+ dnn::DataType data_type) {
+ dnn::DnnSupport *dnn_support = AsDnn();
+ if (!dnn_support) {
+ return port::Status(port::error::UNKNOWN,
+ "Fail to find the dnn implementation.");
+ }
+ return dnn_support->createRnnStateTensorDescriptor(num_layer, batch_size,
+ data_size, data_type);
+}
+
dnn::DnnSupport *StreamExecutor::AsDnn() {
mutex_lock lock{mu_};
if (dnn_ != nullptr) {