aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/stream_executor_pimpl.cc
diff options
context:
space:
mode:
authorGravatar James Qin <jamesqin@google.com>2018-04-25 19:00:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-25 19:03:03 -0700
commit270a6e925493b6c2219b7a0152f6b81fbb88dfee (patch)
treef60074d1844c7bdcfbba029da834271c3c0d0b72 /tensorflow/stream_executor/stream_executor_pimpl.cc
parentca634912e9b121d2e6b2ea04084886c73993e6aa (diff)
Cudnn RNN v2 kernels with autotune capability
CudnnRNN V2 kernels run all applicable cudnn rnn algorithms and pick the best one for following runs. * To enable autotune, TF_CUDNN_RNN_USE_AUTOTUNE and TF_CUDNN_RNN_USE_V2 need to be set to {"1" or unset}. * TF_CUDNN_RNN_USE_AUTOTUNE does not work with existing CudnnRNN kernels. * V2 kernels work with existing cudnn checkpoints, since it doesn't change persistence format. This change * Introduces v2 kernels as templates inheriting the v1 kernels. * Profiles fwd and bak runs in v2 kernel (forward pass) * Exposes the chosen algorithm as fwd op output and bak op input. * Changes rnn descriptor cache key to include AlgorithmDesc (since cudnn rnn descriptor can't be reused across different algorithms) * Updates unittests s.t. it tests both v1 and v2 kernels. When testing v2 kernels, autotune is turned on. PiperOrigin-RevId: 194333948
Diffstat (limited to 'tensorflow/stream_executor/stream_executor_pimpl.cc')
-rw-r--r--tensorflow/stream_executor/stream_executor_pimpl.cc7
1 files changed, 4 insertions, 3 deletions
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index 2e1adeb31e..20579790ef 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -350,7 +350,7 @@ bool StreamExecutor::GetBlasGemmAlgorithms(
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
StreamExecutor::createRnnDescriptor(
- int num_layers, int hidden_size, int input_size,
+ int num_layers, int hidden_size, int input_size, int batch_size,
dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
dnn::RnnMode rnn_mode, dnn::DataType data_type,
const dnn::AlgorithmConfig &algorithm_config, float dropout, uint64 seed,
@@ -361,8 +361,9 @@ StreamExecutor::createRnnDescriptor(
"Fail to find the dnn implementation.");
}
return dnn_support->createRnnDescriptor(
- num_layers, hidden_size, input_size, input_mode, direction_mode, rnn_mode,
- data_type, algorithm_config, dropout, seed, state_allocator);
+ num_layers, hidden_size, input_size, batch_size, input_mode,
+ direction_mode, rnn_mode, data_type, algorithm_config, dropout, seed,
+ state_allocator);
}
port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>