diff options
author | James Qin <jamesqin@google.com> | 2017-11-03 12:09:03 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-11-03 12:12:52 -0700 |
commit | 509d0f2ca7f988d294d7234d31fac6a1cedcc39b (patch) | |
tree | 6597b62469e2cdb61879af6b79bf7b4a0daac28d /tensorflow/stream_executor/stream.cc | |
parent | 7c7e04e9959b23aee6a194727eeaeb2d0d24e79a (diff) |
Support Cudnn RNN Fp16
Relax CudnnRNNTestCompatibleRNNCells test error tolerance a bit.
PiperOrigin-RevId: 174495089
Diffstat (limited to 'tensorflow/stream_executor/stream.cc')
-rw-r--r-- | tensorflow/stream_executor/stream.cc | 75 |
1 files changed, 75 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index 6d756ab191..22fd6bce78 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -4682,6 +4682,39 @@ Stream &Stream::ThenMemset32(DeviceMemoryBase *location, uint32 pattern, Stream &Stream::ThenRnnForward( const dnn::RnnDescriptor &rnn_desc, const dnn::RnnSequenceTensorDescriptor &input_desc, + const DeviceMemory<Eigen::half> &input_data, + const dnn::RnnStateTensorDescriptor &input_h_desc, + const DeviceMemory<Eigen::half> &input_h_data, + const dnn::RnnStateTensorDescriptor &input_c_desc, + const DeviceMemory<Eigen::half> &input_c_data, + const DeviceMemory<Eigen::half> ¶ms, + const dnn::RnnSequenceTensorDescriptor &output_desc, + DeviceMemory<Eigen::half> *output_data, + const dnn::RnnStateTensorDescriptor &output_h_desc, + DeviceMemory<Eigen::half> *output_h_data, + const dnn::RnnStateTensorDescriptor &output_c_desc, + DeviceMemory<Eigen::half> *output_c_data, bool is_training, + ScratchAllocator *reserve_space_allocator, + ScratchAllocator *workspace_allocator) { + // TODO(zhengxq): add VLOG PARAM calls. + if (ok()) { + if (dnn::DnnSupport *dnn = parent_->AsDnn()) { + CheckError(dnn->DoRnnForward( + this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data, + input_c_desc, input_c_data, params, output_desc, output_data, + output_h_desc, output_h_data, output_c_desc, output_c_data, + is_training, reserve_space_allocator, workspace_allocator)); + } else { + SetError(); + LOG(WARNING) << "Attempting to call ThenRnnForward without DNN support"; + } + } + return *this; +} + +Stream &Stream::ThenRnnForward( + const dnn::RnnDescriptor &rnn_desc, + const dnn::RnnSequenceTensorDescriptor &input_desc, const DeviceMemory<float> &input_data, const dnn::RnnStateTensorDescriptor &input_h_desc, const DeviceMemory<float> &input_h_data, @@ -4747,6 +4780,48 @@ Stream &Stream::ThenRnnForward( Stream &Stream::ThenRnnBackward( const dnn::RnnDescriptor &rnn_desc, const dnn::RnnSequenceTensorDescriptor &input_desc, + const DeviceMemory<Eigen::half> &input_data, + const dnn::RnnStateTensorDescriptor &input_h_desc, + const DeviceMemory<Eigen::half> &input_h_data, + const dnn::RnnStateTensorDescriptor &input_c_desc, + const DeviceMemory<Eigen::half> &input_c_data, + const DeviceMemory<Eigen::half> ¶ms, + const dnn::RnnSequenceTensorDescriptor &output_desc, + const DeviceMemory<Eigen::half> &output_data, + const dnn::RnnStateTensorDescriptor &output_h_desc, + const DeviceMemory<Eigen::half> &output_h_data, + const dnn::RnnStateTensorDescriptor &output_c_desc, + const DeviceMemory<Eigen::half> &output_c_data, + const DeviceMemory<Eigen::half> &output_backprop_data, + const DeviceMemory<Eigen::half> &output_h_backprop_data, + const DeviceMemory<Eigen::half> &output_c_backprop_data, + DeviceMemory<Eigen::half> *input_backprop_data, + DeviceMemory<Eigen::half> *input_h_backprop_data, + DeviceMemory<Eigen::half> *input_c_backprop_data, + DeviceMemory<Eigen::half> *params_backprop_data, + DeviceMemory<uint8> *reserve_space_data, + ScratchAllocator *workspace_allocator) { + // TODO(zhengxq): add VLOG PARAM calls. + if (ok()) { + if (dnn::DnnSupport *dnn = parent_->AsDnn()) { + CheckError(dnn->DoRnnBackward( + this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data, + input_c_desc, input_c_data, params, output_desc, output_data, + output_h_desc, output_h_data, output_c_desc, output_c_data, + output_backprop_data, output_h_backprop_data, output_c_backprop_data, + input_backprop_data, input_h_backprop_data, input_c_backprop_data, + params_backprop_data, reserve_space_data, workspace_allocator)); + } else { + SetError(); + LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support"; + } + } + return *this; +} + +Stream &Stream::ThenRnnBackward( + const dnn::RnnDescriptor &rnn_desc, + const dnn::RnnSequenceTensorDescriptor &input_desc, const DeviceMemory<float> &input_data, const dnn::RnnStateTensorDescriptor &input_h_desc, const DeviceMemory<float> &input_h_data, |