diff options
author | James Qin <jamesqin@google.com> | 2018-04-25 19:00:21 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-25 19:03:03 -0700 |
commit | 270a6e925493b6c2219b7a0152f6b81fbb88dfee (patch) | |
tree | f60074d1844c7bdcfbba029da834271c3c0d0b72 /tensorflow/stream_executor/stream_executor_pimpl.h | |
parent | ca634912e9b121d2e6b2ea04084886c73993e6aa (diff) |
Cudnn RNN v2 kernels with autotune capability
CudnnRNN V2 kernels run all applicable cudnn rnn algorithms and pick the best one for following runs.
* To enable autotune, TF_CUDNN_RNN_USE_AUTOTUNE and TF_CUDNN_RNN_USE_V2 need to be set to {"1" or unset}.
* TF_CUDNN_RNN_USE_AUTOTUNE does not work with existing CudnnRNN kernels.
* V2 kernels work with existing cudnn checkpoints, since it doesn't change persistence format.
This change
* Introduces v2 kernels as templates inheriting the v1 kernels.
* Profiles fwd and bak runs in v2 kernel (forward pass)
* Exposes the chosen algorithm as fwd op output and bak op input.
* Changes rnn descriptor cache key to include AlgorithmDesc (since cudnn rnn descriptor can't be reused across different algorithms)
* Updates unittests s.t. it tests both v1 and v2 kernels. When testing v2 kernels, autotune is turned on.
PiperOrigin-RevId: 194333948
Diffstat (limited to 'tensorflow/stream_executor/stream_executor_pimpl.h')
-rw-r--r-- | tensorflow/stream_executor/stream_executor_pimpl.h | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h index 39af7115d8..ab6b00f660 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.h +++ b/tensorflow/stream_executor/stream_executor_pimpl.h @@ -373,7 +373,7 @@ class StreamExecutor { // Create an RNN descriptor based on model shapes and configurations. // The caller retains the ownership of the descriptor. port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor( - int num_layers, int hidden_size, int input_size, + int num_layers, int hidden_size, int input_size, int batch_size, dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode, dnn::DataType data_type, const dnn::AlgorithmConfig &algorithm_config, float dropout, uint64 seed, |