From 859e47fd8ebcdd7fb1411fd0090d0e95a801e7cb Mon Sep 17 00:00:00 2001 From: Xiaoqiang Zheng Date: Fri, 19 Aug 2016 10:01:32 -0800 Subject: Add stream-executor changes to enable Cudnn fused LSTM/RNN support. Change: 130770287 --- .../stream_executor/stream_executor_pimpl.cc | 42 ++++++++++++++++++++++ 1 file changed, 42 insertions(+) (limited to 'tensorflow/stream_executor/stream_executor_pimpl.cc') diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc index 07dc375ef4..2fdd1e4b49 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.cc +++ b/tensorflow/stream_executor/stream_executor_pimpl.cc @@ -309,6 +309,48 @@ bool StreamExecutor::GetConvolveBackwardFilterAlgorithms( return dnn_support->GetConvolveBackwardFilterAlgorithms(out_algorithms); } +port::StatusOr> +StreamExecutor::createRnnDescriptor( + int num_layers, int hidden_size, int input_size, + dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode, + dnn::RnnMode rnn_mode, dnn::DataType data_type, float dropout, uint64 seed, + ScratchAllocator *state_allocator) { + dnn::DnnSupport *dnn_support = AsDnn(); + if (!dnn_support) { + return port::Status(port::error::UNKNOWN, + "Fail to find the dnn implementation."); + } + return dnn_support->createRnnDescriptor( + num_layers, hidden_size, input_size, input_mode, direction_mode, rnn_mode, + data_type, dropout, seed, state_allocator); +} + +port::StatusOr> +StreamExecutor::createRnnSequenceTensorDescriptor(int seq_length, + int batch_size, int data_size, + dnn::DataType data_type) { + dnn::DnnSupport *dnn_support = AsDnn(); + if (!dnn_support) { + return port::Status(port::error::UNKNOWN, + "Fail to find the dnn implementation."); + } + return dnn_support->createRnnSequenceTensorDescriptor(seq_length, batch_size, + data_size, data_type); +} + +port::StatusOr> +StreamExecutor::createRnnStateTensorDescriptor(int num_layer, int batch_size, + int data_size, + dnn::DataType data_type) { + dnn::DnnSupport *dnn_support = AsDnn(); + if (!dnn_support) { + return port::Status(port::error::UNKNOWN, + "Fail to find the dnn implementation."); + } + return dnn_support->createRnnStateTensorDescriptor(num_layer, batch_size, + data_size, data_type); +} + dnn::DnnSupport *StreamExecutor::AsDnn() { mutex_lock lock{mu_}; if (dnn_ != nullptr) { -- cgit v1.2.3