aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/dnn.h
diff options
context:
space:
mode:
authorGravatar James Qin <jamesqin@google.com>2017-07-18 15:27:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-18 15:36:13 -0700
commited03ee3fc74c09b64b6703337a5d0ed5d25b8b43 (patch)
tree4ab15d11b7c58642bc8b734f327e44a286083dad /tensorflow/stream_executor/dnn.h
parent79e85fa6efcaea3fc26fa914119883bcb7a99ecd (diff)
Support float64 CuDNN RNN
PiperOrigin-RevId: 162412879
Diffstat (limited to 'tensorflow/stream_executor/dnn.h')
-rw-r--r--tensorflow/stream_executor/dnn.h46
1 files changed, 46 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index f12aa6d38b..f97deb7222 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -1902,6 +1902,25 @@ class DnnSupport {
return false;
}
+ virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
+ const dnn::RnnSequenceTensorDescriptor& input_desc,
+ const DeviceMemory<double>& input_data,
+ const dnn::RnnStateTensorDescriptor& input_h_desc,
+ const DeviceMemory<double>& input_h_data,
+ const dnn::RnnStateTensorDescriptor& input_c_desc,
+ const DeviceMemory<double>& input_c_data,
+ const DeviceMemory<double>& params,
+ const dnn::RnnSequenceTensorDescriptor& output_desc,
+ DeviceMemory<double>* output_data,
+ const dnn::RnnStateTensorDescriptor& output_h_desc,
+ DeviceMemory<double>* output_h_data,
+ const dnn::RnnStateTensorDescriptor& output_c_desc,
+ DeviceMemory<double>* output_c_data,
+ bool is_training,
+ ScratchAllocator* reserve_space_allocator,
+ ScratchAllocator* workspace_allocator) {
+ return false;
+ }
// Enqueue a backward operation of the RNN model onto the stream.
//
// Arguments:
@@ -1970,6 +1989,33 @@ class DnnSupport {
return false;
}
+ virtual bool DoRnnBackward(
+ Stream* stream, const dnn::RnnDescriptor& rnn_desc,
+ const dnn::RnnSequenceTensorDescriptor& input_desc,
+ const DeviceMemory<double>& input_data,
+ const dnn::RnnStateTensorDescriptor& input_h_desc,
+ const DeviceMemory<double>& input_h_data,
+ const dnn::RnnStateTensorDescriptor& input_c_desc,
+ const DeviceMemory<double>& input_c_data,
+ const DeviceMemory<double>& params,
+ const dnn::RnnSequenceTensorDescriptor& output_desc,
+ const DeviceMemory<double>& output_data,
+ const dnn::RnnStateTensorDescriptor& output_h_desc,
+ const DeviceMemory<double>& output_h_data,
+ const dnn::RnnStateTensorDescriptor& output_c_desc,
+ const DeviceMemory<double>& output_c_data,
+ const DeviceMemory<double>& output_backprop_data,
+ const DeviceMemory<double>& output_h_backprop_data,
+ const DeviceMemory<double>& output_c_backprop_data,
+ DeviceMemory<double>* input_backprop_data,
+ DeviceMemory<double>* input_h_backprop_data,
+ DeviceMemory<double>* input_c_backprop_data,
+ DeviceMemory<double>* params_backprop_data,
+ DeviceMemory<uint8>* reserve_space_data,
+ ScratchAllocator* workspace_allocator) {
+ return false;
+ }
+
// Transforms a tensor into another tensor with a different layout and/or data
// type.
//