aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/stream.cc
diff options
context:
space:
mode:
authorGravatar James Qin <jamesqin@google.com>2017-07-18 15:27:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-18 15:36:13 -0700
commited03ee3fc74c09b64b6703337a5d0ed5d25b8b43 (patch)
tree4ab15d11b7c58642bc8b734f327e44a286083dad /tensorflow/stream_executor/stream.cc
parent79e85fa6efcaea3fc26fa914119883bcb7a99ecd (diff)
Support float64 CuDNN RNN
PiperOrigin-RevId: 162412879
Diffstat (limited to 'tensorflow/stream_executor/stream.cc')
-rw-r--r--tensorflow/stream_executor/stream.cc75
1 files changed, 75 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index 9b4a4c4fb1..5996195173 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -4388,6 +4388,39 @@ Stream &Stream::ThenRnnForward(
return *this;
}
+Stream &Stream::ThenRnnForward(
+ const dnn::RnnDescriptor &rnn_desc,
+ const dnn::RnnSequenceTensorDescriptor &input_desc,
+ const DeviceMemory<double> &input_data,
+ const dnn::RnnStateTensorDescriptor &input_h_desc,
+ const DeviceMemory<double> &input_h_data,
+ const dnn::RnnStateTensorDescriptor &input_c_desc,
+ const DeviceMemory<double> &input_c_data,
+ const DeviceMemory<double> &params,
+ const dnn::RnnSequenceTensorDescriptor &output_desc,
+ DeviceMemory<double> *output_data,
+ const dnn::RnnStateTensorDescriptor &output_h_desc,
+ DeviceMemory<double> *output_h_data,
+ const dnn::RnnStateTensorDescriptor &output_c_desc,
+ DeviceMemory<double> *output_c_data, bool is_training,
+ ScratchAllocator *reserve_space_allocator,
+ ScratchAllocator *workspace_allocator) {
+ // TODO(zhengxq): add VLOG PARAM calls.
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ CheckError(dnn->DoRnnForward(
+ this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
+ input_c_desc, input_c_data, params, output_desc, output_data,
+ output_h_desc, output_h_data, output_c_desc, output_c_data,
+ is_training, reserve_space_allocator, workspace_allocator));
+ } else {
+ SetError();
+ LOG(WARNING) << "Attempting to call ThenRnnForward without DNN support";
+ }
+ }
+ return *this;
+}
+
Stream &Stream::ThenRnnBackward(
const dnn::RnnDescriptor &rnn_desc,
const dnn::RnnSequenceTensorDescriptor &input_desc,
@@ -4429,6 +4462,48 @@ Stream &Stream::ThenRnnBackward(
return *this;
}
+Stream &Stream::ThenRnnBackward(
+ const dnn::RnnDescriptor &rnn_desc,
+ const dnn::RnnSequenceTensorDescriptor &input_desc,
+ const DeviceMemory<double> &input_data,
+ const dnn::RnnStateTensorDescriptor &input_h_desc,
+ const DeviceMemory<double> &input_h_data,
+ const dnn::RnnStateTensorDescriptor &input_c_desc,
+ const DeviceMemory<double> &input_c_data,
+ const DeviceMemory<double> &params,
+ const dnn::RnnSequenceTensorDescriptor &output_desc,
+ const DeviceMemory<double> &output_data,
+ const dnn::RnnStateTensorDescriptor &output_h_desc,
+ const DeviceMemory<double> &output_h_data,
+ const dnn::RnnStateTensorDescriptor &output_c_desc,
+ const DeviceMemory<double> &output_c_data,
+ const DeviceMemory<double> &output_backprop_data,
+ const DeviceMemory<double> &output_h_backprop_data,
+ const DeviceMemory<double> &output_c_backprop_data,
+ DeviceMemory<double> *input_backprop_data,
+ DeviceMemory<double> *input_h_backprop_data,
+ DeviceMemory<double> *input_c_backprop_data,
+ DeviceMemory<double> *params_backprop_data,
+ DeviceMemory<uint8> *reserve_space_data,
+ ScratchAllocator *workspace_allocator) {
+ // TODO(zhengxq): add VLOG PARAM calls.
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ CheckError(dnn->DoRnnBackward(
+ this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
+ input_c_desc, input_c_data, params, output_desc, output_data,
+ output_h_desc, output_h_data, output_c_desc, output_c_data,
+ output_backprop_data, output_h_backprop_data, output_c_backprop_data,
+ input_backprop_data, input_h_backprop_data, input_c_backprop_data,
+ params_backprop_data, reserve_space_data, workspace_allocator));
+ } else {
+ SetError();
+ LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
+ }
+ }
+ return *this;
+}
+
Stream &Stream::ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
dnn::DataType input_type,
const DeviceMemoryBase &input_data,