aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/stream_executor_pimpl.cc
diff options
context:
space:
mode:
authorGravatar James Qin <jamesqin@google.com>2018-04-06 11:56:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-06 12:00:45 -0700
commit4f7943f7358fc69af62dc280c6f6ba549ebe2167 (patch)
tree19bdd3dddebeb7d26f9685328d0598bb58347bc0 /tensorflow/stream_executor/stream_executor_pimpl.cc
parentf15c117c4f4d51a6660bf14b6d6cf73c52692cfb (diff)
Support RNN profiling in StreamExecutor for CUDA GPUs.
This change hasn't applied autotune on TF Cudnn kernels, only provides lower level support. PiperOrigin-RevId: 191919566
Diffstat (limited to 'tensorflow/stream_executor/stream_executor_pimpl.cc')
-rw-r--r--tensorflow/stream_executor/stream_executor_pimpl.cc14
1 files changed, 12 insertions, 2 deletions
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index afca1c2e59..f55fa68402 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -305,6 +305,15 @@ bool StreamExecutor::GetConvolveAlgorithms(
cc_minor, out_algorithms);
}
+bool StreamExecutor::GetRnnAlgorithms(
+ std::vector<dnn::AlgorithmDesc> *out_algorithms) {
+ dnn::DnnSupport *dnn_support = AsDnn();
+ if (!dnn_support) {
+ return false;
+ }
+ return dnn_support->GetRnnAlgorithms(out_algorithms);
+}
+
bool StreamExecutor::GetConvolveBackwardDataAlgorithms(
bool with_winograd_nonfused,
std::vector<dnn::AlgorithmDesc> *out_algorithms) {
@@ -344,7 +353,8 @@ port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
StreamExecutor::createRnnDescriptor(
int num_layers, int hidden_size, int input_size,
dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
- dnn::RnnMode rnn_mode, dnn::DataType data_type, float dropout, uint64 seed,
+ dnn::RnnMode rnn_mode, dnn::DataType data_type,
+ const dnn::AlgorithmConfig &algorithm_config, float dropout, uint64 seed,
ScratchAllocator *state_allocator) {
dnn::DnnSupport *dnn_support = AsDnn();
if (!dnn_support) {
@@ -353,7 +363,7 @@ StreamExecutor::createRnnDescriptor(
}
return dnn_support->createRnnDescriptor(
num_layers, hidden_size, input_size, input_mode, direction_mode, rnn_mode,
- data_type, dropout, seed, state_allocator);
+ data_type, algorithm_config, dropout, seed, state_allocator);
}
port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>