aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/dnn.h
diff options
context:
space:
mode:
authorGravatar James Qin <jamesqin@google.com>2017-11-03 12:09:03 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-03 12:12:52 -0700
commit509d0f2ca7f988d294d7234d31fac6a1cedcc39b (patch)
tree6597b62469e2cdb61879af6b79bf7b4a0daac28d /tensorflow/stream_executor/dnn.h
parent7c7e04e9959b23aee6a194727eeaeb2d0d24e79a (diff)
Support Cudnn RNN Fp16
Relax CudnnRNNTestCompatibleRNNCells test error tolerance a bit. PiperOrigin-RevId: 174495089
Diffstat (limited to 'tensorflow/stream_executor/dnn.h')
-rw-r--r--tensorflow/stream_executor/dnn.h47
1 files changed, 47 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index 624357b82f..49235167ab 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -2029,6 +2029,26 @@ class DnnSupport {
// enough for the lifespan of this operation, and recycles aftewards.
virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
const dnn::RnnSequenceTensorDescriptor& input_desc,
+ const DeviceMemory<Eigen::half>& input_data,
+ const dnn::RnnStateTensorDescriptor& input_h_desc,
+ const DeviceMemory<Eigen::half>& input_h_data,
+ const dnn::RnnStateTensorDescriptor& input_c_desc,
+ const DeviceMemory<Eigen::half>& input_c_data,
+ const DeviceMemory<Eigen::half>& params,
+ const dnn::RnnSequenceTensorDescriptor& output_desc,
+ DeviceMemory<Eigen::half>* output_data,
+ const dnn::RnnStateTensorDescriptor& output_h_desc,
+ DeviceMemory<Eigen::half>* output_h_data,
+ const dnn::RnnStateTensorDescriptor& output_c_desc,
+ DeviceMemory<Eigen::half>* output_c_data,
+ bool is_training,
+ ScratchAllocator* reserve_space_allocator,
+ ScratchAllocator* workspace_allocator) {
+ return false;
+ }
+
+ virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
+ const dnn::RnnSequenceTensorDescriptor& input_desc,
const DeviceMemory<float>& input_data,
const dnn::RnnStateTensorDescriptor& input_h_desc,
const DeviceMemory<float>& input_h_data,
@@ -2110,6 +2130,33 @@ class DnnSupport {
virtual bool DoRnnBackward(
Stream* stream, const dnn::RnnDescriptor& rnn_desc,
const dnn::RnnSequenceTensorDescriptor& input_desc,
+ const DeviceMemory<Eigen::half>& input_data,
+ const dnn::RnnStateTensorDescriptor& input_h_desc,
+ const DeviceMemory<Eigen::half>& input_h_data,
+ const dnn::RnnStateTensorDescriptor& input_c_desc,
+ const DeviceMemory<Eigen::half>& input_c_data,
+ const DeviceMemory<Eigen::half>& params,
+ const dnn::RnnSequenceTensorDescriptor& output_desc,
+ const DeviceMemory<Eigen::half>& output_data,
+ const dnn::RnnStateTensorDescriptor& output_h_desc,
+ const DeviceMemory<Eigen::half>& output_h_data,
+ const dnn::RnnStateTensorDescriptor& output_c_desc,
+ const DeviceMemory<Eigen::half>& output_c_data,
+ const DeviceMemory<Eigen::half>& output_backprop_data,
+ const DeviceMemory<Eigen::half>& output_h_backprop_data,
+ const DeviceMemory<Eigen::half>& output_c_backprop_data,
+ DeviceMemory<Eigen::half>* input_backprop_data,
+ DeviceMemory<Eigen::half>* input_h_backprop_data,
+ DeviceMemory<Eigen::half>* input_c_backprop_data,
+ DeviceMemory<Eigen::half>* params_backprop_data,
+ DeviceMemory<uint8>* reserve_space_data,
+ ScratchAllocator* workspace_allocator) {
+ return false;
+ }
+
+ virtual bool DoRnnBackward(
+ Stream* stream, const dnn::RnnDescriptor& rnn_desc,
+ const dnn::RnnSequenceTensorDescriptor& input_desc,
const DeviceMemory<float>& input_data,
const dnn::RnnStateTensorDescriptor& input_h_desc,
const DeviceMemory<float>& input_h_data,