about summary refs log tree commit diff homepage
path: root/tensorflow/stream_executor/stream.cc
diff options
context:
space:
mode:
author: James Qin <jamesqin@google.com> 2017-11-03 12:09:03 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-11-03 12:12:52 -0700
commit: 509d0f2ca7f988d294d7234d31fac6a1cedcc39b (patch)
tree: 6597b62469e2cdb61879af6b79bf7b4a0daac28d /tensorflow/stream_executor/stream.cc
parent: 7c7e04e9959b23aee6a194727eeaeb2d0d24e79a (diff)
Support Cudnn RNN Fp16
Relax CudnnRNNTestCompatibleRNNCells test error tolerance a bit. PiperOrigin-RevId: 174495089
Diffstat (limited to 'tensorflow/stream_executor/stream.cc')
-rw-r--r-- tensorflow/stream_executor/stream.cc 75
1 files changed, 75 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index 6d756ab191..22fd6bce78 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -4682,6 +4682,39 @@ Stream &Stream::ThenMemset32(DeviceMemoryBase *location, uint32 pattern,
Stream &Stream::ThenRnnForward(
const dnn::RnnDescriptor &rnn_desc,
const dnn::RnnSequenceTensorDescriptor &input_desc,
+ const DeviceMemory<Eigen::half> &input_data,
+ const dnn::RnnStateTensorDescriptor &input_h_desc,
+ const DeviceMemory<Eigen::half> &input_h_data,
+ const dnn::RnnStateTensorDescriptor &input_c_desc,
+ const DeviceMemory<Eigen::half> &input_c_data,
+ const DeviceMemory<Eigen::half> &params,
+ const dnn::RnnSequenceTensorDescriptor &output_desc,
+ DeviceMemory<Eigen::half> *output_data,
+ const dnn::RnnStateTensorDescriptor &output_h_desc,
+ DeviceMemory<Eigen::half> *output_h_data,
+ const dnn::RnnStateTensorDescriptor &output_c_desc,
+ DeviceMemory<Eigen::half> *output_c_data, bool is_training,
+ ScratchAllocator *reserve_space_allocator,
+ ScratchAllocator *workspace_allocator) {
+ // Eigen::half overload of ThenRnnForward; it only dispatches to DnnSupport::DoRnnForward. TODO(zhengxq): add VLOG PARAM calls.
+ if (ok()) {  // When ok() is false nothing below runs and *this is returned untouched.
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {  // A null AsDnn() result falls through to the else branch.
+ CheckError(dnn->DoRnnForward(  // NOTE(review): CheckError presumably records a failed call on this stream - confirm.
+ this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,  // Arguments are forwarded unchanged, in declaration order.
+ input_c_desc, input_c_data, params, output_desc, output_data,
+ output_h_desc, output_h_data, output_c_desc, output_c_data,
+ is_training, reserve_space_allocator, workspace_allocator));
+ } else {
+ SetError();  // parent_ offers no DnnSupport: call SetError() and log a warning.
+ LOG(WARNING) << "Attempting to call ThenRnnForward without DNN support";
+ }
+ }
+ return *this;  // Every path returns this same stream.
+}
+
+Stream &Stream::ThenRnnForward(
+ const dnn::RnnDescriptor &rnn_desc,
+ const dnn::RnnSequenceTensorDescriptor &input_desc,
const DeviceMemory<float> &input_data,
const dnn::RnnStateTensorDescriptor &input_h_desc,
const DeviceMemory<float> &input_h_data,
@@ -4747,6 +4780,48 @@ Stream &Stream::ThenRnnForward(
Stream &Stream::ThenRnnBackward(
const dnn::RnnDescriptor &rnn_desc,
const dnn::RnnSequenceTensorDescriptor &input_desc,
+ const DeviceMemory<Eigen::half> &input_data,
+ const dnn::RnnStateTensorDescriptor &input_h_desc,
+ const DeviceMemory<Eigen::half> &input_h_data,
+ const dnn::RnnStateTensorDescriptor &input_c_desc,
+ const DeviceMemory<Eigen::half> &input_c_data,
+ const DeviceMemory<Eigen::half> &params,
+ const dnn::RnnSequenceTensorDescriptor &output_desc,
+ const DeviceMemory<Eigen::half> &output_data,
+ const dnn::RnnStateTensorDescriptor &output_h_desc,
+ const DeviceMemory<Eigen::half> &output_h_data,
+ const dnn::RnnStateTensorDescriptor &output_c_desc,
+ const DeviceMemory<Eigen::half> &output_c_data,
+ const DeviceMemory<Eigen::half> &output_backprop_data,
+ const DeviceMemory<Eigen::half> &output_h_backprop_data,
+ const DeviceMemory<Eigen::half> &output_c_backprop_data,
+ DeviceMemory<Eigen::half> *input_backprop_data,
+ DeviceMemory<Eigen::half> *input_h_backprop_data,
+ DeviceMemory<Eigen::half> *input_c_backprop_data,
+ DeviceMemory<Eigen::half> *params_backprop_data,
+ DeviceMemory<uint8> *reserve_space_data,
+ ScratchAllocator *workspace_allocator) {
+ // Eigen::half overload of ThenRnnBackward; it only dispatches to DnnSupport::DoRnnBackward. TODO(zhengxq): add VLOG PARAM calls.
+ if (ok()) {  // When ok() is false nothing below runs and *this is returned untouched.
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {  // A null AsDnn() result falls through to the else branch.
+ CheckError(dnn->DoRnnBackward(  // NOTE(review): CheckError presumably records a failed call on this stream - confirm.
+ this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,  // Arguments are forwarded unchanged, in declaration order.
+ input_c_desc, input_c_data, params, output_desc, output_data,
+ output_h_desc, output_h_data, output_c_desc, output_c_data,
+ output_backprop_data, output_h_backprop_data, output_c_backprop_data,
+ input_backprop_data, input_h_backprop_data, input_c_backprop_data,
+ params_backprop_data, reserve_space_data, workspace_allocator));
+ } else {
+ SetError();  // parent_ offers no DnnSupport: call SetError() and log a warning.
+ LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
+ }
+ }
+ return *this;  // Every path returns this same stream.
+}
+
+Stream &Stream::ThenRnnBackward(
+ const dnn::RnnDescriptor &rnn_desc,
+ const dnn::RnnSequenceTensorDescriptor &input_desc,
const DeviceMemory<float> &input_data,
const dnn::RnnStateTensorDescriptor &input_h_desc,
const DeviceMemory<float> &input_h_data,