diff options
author | 2017-07-18 15:27:37 -0700 | |
---|---|---|
committer | 2017-07-18 15:36:13 -0700 | |
commit | ed03ee3fc74c09b64b6703337a5d0ed5d25b8b43 (patch) | |
tree | 4ab15d11b7c58642bc8b734f327e44a286083dad /tensorflow/stream_executor/dnn.h | |
parent | 79e85fa6efcaea3fc26fa914119883bcb7a99ecd (diff) |
Support float64 CuDNN RNN
PiperOrigin-RevId: 162412879
Diffstat (limited to 'tensorflow/stream_executor/dnn.h')
-rw-r--r-- | tensorflow/stream_executor/dnn.h | 46 |
1 files changed, 46 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h index f12aa6d38b..f97deb7222 100644 --- a/tensorflow/stream_executor/dnn.h +++ b/tensorflow/stream_executor/dnn.h @@ -1902,6 +1902,25 @@ class DnnSupport { return false; } + virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, + const dnn::RnnSequenceTensorDescriptor& input_desc, + const DeviceMemory<double>& input_data, + const dnn::RnnStateTensorDescriptor& input_h_desc, + const DeviceMemory<double>& input_h_data, + const dnn::RnnStateTensorDescriptor& input_c_desc, + const DeviceMemory<double>& input_c_data, + const DeviceMemory<double>& params, + const dnn::RnnSequenceTensorDescriptor& output_desc, + DeviceMemory<double>* output_data, + const dnn::RnnStateTensorDescriptor& output_h_desc, + DeviceMemory<double>* output_h_data, + const dnn::RnnStateTensorDescriptor& output_c_desc, + DeviceMemory<double>* output_c_data, + bool is_training, + ScratchAllocator* reserve_space_allocator, + ScratchAllocator* workspace_allocator) { + return false; + } // Enqueue a backward operation of the RNN model onto the stream. // // Arguments: @@ -1970,6 +1989,33 @@ class DnnSupport { return false; } + virtual bool DoRnnBackward( + Stream* stream, const dnn::RnnDescriptor& rnn_desc, + const dnn::RnnSequenceTensorDescriptor& input_desc, + const DeviceMemory<double>& input_data, + const dnn::RnnStateTensorDescriptor& input_h_desc, + const DeviceMemory<double>& input_h_data, + const dnn::RnnStateTensorDescriptor& input_c_desc, + const DeviceMemory<double>& input_c_data, + const DeviceMemory<double>& params, + const dnn::RnnSequenceTensorDescriptor& output_desc, + const DeviceMemory<double>& output_data, + const dnn::RnnStateTensorDescriptor& output_h_desc, + const DeviceMemory<double>& output_h_data, + const dnn::RnnStateTensorDescriptor& output_c_desc, + const DeviceMemory<double>& output_c_data, + const DeviceMemory<double>& output_backprop_data, + const DeviceMemory<double>& output_h_backprop_data, + const DeviceMemory<double>& output_c_backprop_data, + DeviceMemory<double>* input_backprop_data, + DeviceMemory<double>* input_h_backprop_data, + DeviceMemory<double>* input_c_backprop_data, + DeviceMemory<double>* params_backprop_data, + DeviceMemory<uint8>* reserve_space_data, + ScratchAllocator* workspace_allocator) { + return false; + } + // Transforms a tensor into another tensor with a different layout and/or data // type. // |