diff options
Diffstat (limited to 'tensorflow/stream_executor/stream_executor_pimpl.h')
-rw-r--r-- | tensorflow/stream_executor/stream_executor_pimpl.h | 11 |
1 files changed, 8 insertions, 3 deletions
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h index a2a77218cb..69d0374d73 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.h +++ b/tensorflow/stream_executor/stream_executor_pimpl.h @@ -349,10 +349,14 @@ class StreamExecutor { // platform that underlies this interface. bool SupportsDnn() const; - // Get the list of supported algorithms for the forward convolution opeartion. + // Returns the list of supported algorithms for the forward convolution + // operation. bool GetConvolveAlgorithms(bool with_winograd_nonfused, std::vector<dnn::AlgorithmDesc> *out_algorithms); + // Returns the list of supported algorithms for rnn operation. + bool GetRnnAlgorithms(std::vector<dnn::AlgorithmDesc> *out_algorithms); + // Get the list of supported algorithms for the backward convolution on data. bool GetConvolveBackwardDataAlgorithms( bool with_winograd_nonfused, @@ -372,8 +376,9 @@ class StreamExecutor { port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor( int num_layers, int hidden_size, int input_size, dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode, - dnn::RnnMode rnn_mode, dnn::DataType data_type, float dropout, - uint64 seed, ScratchAllocator *state_allocator); + dnn::RnnMode rnn_mode, dnn::DataType data_type, + const dnn::AlgorithmConfig &algorithm_config, float dropout, uint64 seed, + ScratchAllocator *state_allocator); // Create a RNN sequence descriptor that specifies either the input or output // sequence. The caller retains the ownership of the returned descriptor. |