aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/stream_executor_pimpl.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/stream_executor/stream_executor_pimpl.h')
-rw-r--r--tensorflow/stream_executor/stream_executor_pimpl.h11
1 files changed, 8 insertions, 3 deletions
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index a2a77218cb..69d0374d73 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -349,10 +349,14 @@ class StreamExecutor {
// platform that underlies this interface.
bool SupportsDnn() const;
- // Get the list of supported algorithms for the forward convolution opeartion.
+ // Returns the list of supported algorithms for the forward convolution
+ // operation.
bool GetConvolveAlgorithms(bool with_winograd_nonfused,
std::vector<dnn::AlgorithmDesc> *out_algorithms);
+ // Returns the list of supported algorithms for rnn operation.
+ bool GetRnnAlgorithms(std::vector<dnn::AlgorithmDesc> *out_algorithms);
+
// Get the list of supported algorithms for the backward convolution on data.
bool GetConvolveBackwardDataAlgorithms(
bool with_winograd_nonfused,
@@ -372,8 +376,9 @@ class StreamExecutor {
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
int num_layers, int hidden_size, int input_size,
dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
- dnn::RnnMode rnn_mode, dnn::DataType data_type, float dropout,
- uint64 seed, ScratchAllocator *state_allocator);
+ dnn::RnnMode rnn_mode, dnn::DataType data_type,
+ const dnn::AlgorithmConfig &algorithm_config, float dropout, uint64 seed,
+ ScratchAllocator *state_allocator);
// Create a RNN sequence descriptor that specifies either the input or output
// sequence. The caller retains the ownership of the returned descriptor.