about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/stream_executor/stream.h
diff options
context:
space:
mode:
author: James Qin <jamesqin@google.com> 2017-11-03 12:09:03 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-11-03 12:12:52 -0700
commit: 509d0f2ca7f988d294d7234d31fac6a1cedcc39b (patch)
tree: 6597b62469e2cdb61879af6b79bf7b4a0daac28d /tensorflow/stream_executor/stream.h
parent: 7c7e04e9959b23aee6a194727eeaeb2d0d24e79a (diff)
Support Cudnn RNN Fp16
Relax CudnnRNNTestCompatibleRNNCells test error tolerance a bit. PiperOrigin-RevId: 174495089
Diffstat (limited to 'tensorflow/stream_executor/stream.h')
-rw-r--r-- tensorflow/stream_executor/stream.h | 43
1 file changed, 43 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index 21172d5a16..023cffb965 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -1753,6 +1753,24 @@ class Stream {
// See DnnSupport::DoRnnForward for more details.
Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
const dnn::RnnSequenceTensorDescriptor &input_desc,
+ const DeviceMemory<Eigen::half> &input_data,
+ const dnn::RnnStateTensorDescriptor &input_h_desc,
+ const DeviceMemory<Eigen::half> &input_h_data,
+ const dnn::RnnStateTensorDescriptor &input_c_desc,
+ const DeviceMemory<Eigen::half> &input_c_data,
+ const DeviceMemory<Eigen::half> &params,
+ const dnn::RnnSequenceTensorDescriptor &output_desc,
+ DeviceMemory<Eigen::half> *output_data,
+ const dnn::RnnStateTensorDescriptor &output_h_desc,
+ DeviceMemory<Eigen::half> *output_h_data,
+ const dnn::RnnStateTensorDescriptor &output_c_desc,
+ DeviceMemory<Eigen::half> *output_c_data,
+ bool is_training,
+ ScratchAllocator *reserve_space_allocator,
+ ScratchAllocator *workspace_allocator);
+
+ Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
+ const dnn::RnnSequenceTensorDescriptor &input_desc,
const DeviceMemory<float> &input_data,
const dnn::RnnStateTensorDescriptor &input_h_desc,
const DeviceMemory<float> &input_h_data,
@@ -1787,6 +1805,31 @@ class Stream {
// Enqueue a backward operation of the RNN model onto the stream.
// See DnnSupport::DoRnnBackward for more details.
+ Stream &ThenRnnBackward(
+ const dnn::RnnDescriptor &rnn_desc,
+ const dnn::RnnSequenceTensorDescriptor &input_desc,
+ const DeviceMemory<Eigen::half> &input_data,
+ const dnn::RnnStateTensorDescriptor &input_h_desc,
+ const DeviceMemory<Eigen::half> &input_h_data,
+ const dnn::RnnStateTensorDescriptor &input_c_desc,
+ const DeviceMemory<Eigen::half> &input_c_data,
+ const DeviceMemory<Eigen::half> &params,
+ const dnn::RnnSequenceTensorDescriptor &output_desc,
+ const DeviceMemory<Eigen::half> &output_data,
+ const dnn::RnnStateTensorDescriptor &output_h_desc,
+ const DeviceMemory<Eigen::half> &output_h_data,
+ const dnn::RnnStateTensorDescriptor &output_c_desc,
+ const DeviceMemory<Eigen::half> &output_c_data,
+ const DeviceMemory<Eigen::half> &output_backprop_data,
+ const DeviceMemory<Eigen::half> &output_h_backprop_data,
+ const DeviceMemory<Eigen::half> &output_c_backprop_data,
+ DeviceMemory<Eigen::half> *input_backprop_data,
+ DeviceMemory<Eigen::half> *input_h_backprop_data,
+ DeviceMemory<Eigen::half> *input_c_backprop_data,
+ DeviceMemory<Eigen::half> *params_backprop_data,
+ DeviceMemory<uint8> *reserve_space_data,
+ ScratchAllocator *workspace_allocator);
+
Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc,
const dnn::RnnSequenceTensorDescriptor &input_desc,
const DeviceMemory<float> &input_data,