diff options
author | James Qin <jamesqin@google.com> | 2017-07-18 15:27:37 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-07-18 15:36:13 -0700 |
commit | ed03ee3fc74c09b64b6703337a5d0ed5d25b8b43 (patch) | |
tree | 4ab15d11b7c58642bc8b734f327e44a286083dad /tensorflow/stream_executor/stream.cc | |
parent | 79e85fa6efcaea3fc26fa914119883bcb7a99ecd (diff) |
Support float64 CuDNN RNN
PiperOrigin-RevId: 162412879
Diffstat (limited to 'tensorflow/stream_executor/stream.cc')
-rw-r--r-- | tensorflow/stream_executor/stream.cc | 75 |
1 files changed, 75 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index 9b4a4c4fb1..5996195173 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -4388,6 +4388,39 @@ Stream &Stream::ThenRnnForward( return *this; } +Stream &Stream::ThenRnnForward( + const dnn::RnnDescriptor &rnn_desc, + const dnn::RnnSequenceTensorDescriptor &input_desc, + const DeviceMemory<double> &input_data, + const dnn::RnnStateTensorDescriptor &input_h_desc, + const DeviceMemory<double> &input_h_data, + const dnn::RnnStateTensorDescriptor &input_c_desc, + const DeviceMemory<double> &input_c_data, + const DeviceMemory<double> ¶ms, + const dnn::RnnSequenceTensorDescriptor &output_desc, + DeviceMemory<double> *output_data, + const dnn::RnnStateTensorDescriptor &output_h_desc, + DeviceMemory<double> *output_h_data, + const dnn::RnnStateTensorDescriptor &output_c_desc, + DeviceMemory<double> *output_c_data, bool is_training, + ScratchAllocator *reserve_space_allocator, + ScratchAllocator *workspace_allocator) { + // TODO(zhengxq): add VLOG PARAM calls. + if (ok()) { + if (dnn::DnnSupport *dnn = parent_->AsDnn()) { + CheckError(dnn->DoRnnForward( + this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data, + input_c_desc, input_c_data, params, output_desc, output_data, + output_h_desc, output_h_data, output_c_desc, output_c_data, + is_training, reserve_space_allocator, workspace_allocator)); + } else { + SetError(); + LOG(WARNING) << "Attempting to call ThenRnnForward without DNN support"; + } + } + return *this; +} + Stream &Stream::ThenRnnBackward( const dnn::RnnDescriptor &rnn_desc, const dnn::RnnSequenceTensorDescriptor &input_desc, @@ -4429,6 +4462,48 @@ Stream &Stream::ThenRnnBackward( return *this; } +Stream &Stream::ThenRnnBackward( + const dnn::RnnDescriptor &rnn_desc, + const dnn::RnnSequenceTensorDescriptor &input_desc, + const DeviceMemory<double> &input_data, + const dnn::RnnStateTensorDescriptor &input_h_desc, + const DeviceMemory<double> &input_h_data, + const dnn::RnnStateTensorDescriptor &input_c_desc, + const DeviceMemory<double> &input_c_data, + const DeviceMemory<double> ¶ms, + const dnn::RnnSequenceTensorDescriptor &output_desc, + const DeviceMemory<double> &output_data, + const dnn::RnnStateTensorDescriptor &output_h_desc, + const DeviceMemory<double> &output_h_data, + const dnn::RnnStateTensorDescriptor &output_c_desc, + const DeviceMemory<double> &output_c_data, + const DeviceMemory<double> &output_backprop_data, + const DeviceMemory<double> &output_h_backprop_data, + const DeviceMemory<double> &output_c_backprop_data, + DeviceMemory<double> *input_backprop_data, + DeviceMemory<double> *input_h_backprop_data, + DeviceMemory<double> *input_c_backprop_data, + DeviceMemory<double> *params_backprop_data, + DeviceMemory<uint8> *reserve_space_data, + ScratchAllocator *workspace_allocator) { + // TODO(zhengxq): add VLOG PARAM calls. + if (ok()) { + if (dnn::DnnSupport *dnn = parent_->AsDnn()) { + CheckError(dnn->DoRnnBackward( + this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data, + input_c_desc, input_c_data, params, output_desc, output_data, + output_h_desc, output_h_data, output_c_desc, output_c_data, + output_backprop_data, output_h_backprop_data, output_c_backprop_data, + input_backprop_data, input_h_backprop_data, input_c_backprop_data, + params_backprop_data, reserve_space_data, workspace_allocator)); + } else { + SetError(); + LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support"; + } + } + return *this; +} + Stream &Stream::ThenTransformTensor(const dnn::BatchDescriptor &input_desc, dnn::DataType input_type, const DeviceMemoryBase &input_data, |