diff options
author | James Qin <jamesqin@google.com> | 2018-04-06 11:56:08 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-06 12:00:45 -0700 |
commit | 4f7943f7358fc69af62dc280c6f6ba549ebe2167 (patch) | |
tree | 19bdd3dddebeb7d26f9685328d0598bb58347bc0 /tensorflow/stream_executor/stream_executor_pimpl.cc | |
parent | f15c117c4f4d51a6660bf14b6d6cf73c52692cfb (diff) |
Support RNN profiling in StreamExecutor for CUDA GPUs.
This change does not yet apply autotuning to the TF cuDNN kernels; it only provides the lower-level support.
PiperOrigin-RevId: 191919566
Diffstat (limited to 'tensorflow/stream_executor/stream_executor_pimpl.cc')
-rw-r--r-- | tensorflow/stream_executor/stream_executor_pimpl.cc | 14 |
1 files changed, 12 insertions, 2 deletions
```diff
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index afca1c2e59..f55fa68402 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -305,6 +305,15 @@ bool StreamExecutor::GetConvolveAlgorithms(
       cc_minor, out_algorithms);
 }
 
+bool StreamExecutor::GetRnnAlgorithms(
+    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
+  dnn::DnnSupport *dnn_support = AsDnn();
+  if (!dnn_support) {
+    return false;
+  }
+  return dnn_support->GetRnnAlgorithms(out_algorithms);
+}
+
 bool StreamExecutor::GetConvolveBackwardDataAlgorithms(
     bool with_winograd_nonfused,
     std::vector<dnn::AlgorithmDesc> *out_algorithms) {
@@ -344,7 +353,8 @@ port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
 StreamExecutor::createRnnDescriptor(
     int num_layers, int hidden_size, int input_size,
     dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
-    dnn::RnnMode rnn_mode, dnn::DataType data_type, float dropout, uint64 seed,
+    dnn::RnnMode rnn_mode, dnn::DataType data_type,
+    const dnn::AlgorithmConfig &algorithm_config, float dropout, uint64 seed,
     ScratchAllocator *state_allocator) {
   dnn::DnnSupport *dnn_support = AsDnn();
   if (!dnn_support) {
@@ -353,7 +363,7 @@ StreamExecutor::createRnnDescriptor(
   }
   return dnn_support->createRnnDescriptor(
       num_layers, hidden_size, input_size, input_mode, direction_mode, rnn_mode,
-      data_type, dropout, seed, state_allocator);
+      data_type, algorithm_config, dropout, seed, state_allocator);
 }
 
 port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
```