diff options
author | 2016-09-15 20:44:49 -0800 | |
---|---|---|
committer | 2016-09-15 21:47:12 -0700 | |
commit | 04df8d868fab5df0002fa0ec2765dc2e0aeb68d6 (patch) | |
tree | ff4155ad9ced636cc1cca1451aa702805401a516 /tensorflow/stream_executor/stream.cc | |
parent | 4e96e274443805df8afad5cb48f654fbf1776a4a (diff) |
Add the interface in steam executor to call cuDNN batch normalization functions.
Change: 133345765
Diffstat (limited to 'tensorflow/stream_executor/stream.cc')
-rw-r--r-- | tensorflow/stream_executor/stream.cc | 56 |
1 files changed, 56 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index 8c0e45f1a6..512e882cad 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -277,6 +277,62 @@ Stream &Stream::ThenRecordEvent(Event *event) { return *this; } +Stream &Stream::ThenBatchNormalizationForward( + const DeviceMemory<float> &x, const DeviceMemory<float> &scale, + const DeviceMemory<float> &offset, + const DeviceMemory<float> &estimated_mean, + const DeviceMemory<float> &estimated_variance, + const dnn::BatchDescriptor &x_desc, + const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, + DeviceMemory<float> *y, DeviceMemory<float> *batch_mean, + DeviceMemory<float> *batch_var, DeviceMemory<float> *saved_mean, + DeviceMemory<float> *saved_inv_var, bool is_training, + std::function<const DeviceMemory<float> &()> var_to_inv_var, + std::function<void()> inv_var_to_var) { + VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc), + PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y)); + if (ok()) { + if (dnn::DnnSupport *dnn = parent_->AsDnn()) { + CheckError(dnn->DoBatchNormalizationForward( + this, x, scale, offset, estimated_mean, estimated_variance, x_desc, + scale_offset_desc, epsilon, y, batch_mean, batch_var, saved_mean, + saved_inv_var, is_training, std::move(var_to_inv_var), + std::move(inv_var_to_var))); + } else { + SetError(); + LOG(WARNING) + << "attempting to perform DNN operation using StreamExecutor " + "without DNN support"; + } + } + return *this; +} + +Stream &Stream::ThenBatchNormalizationBackward( + const DeviceMemory<float> &y_backprop, const DeviceMemory<float> &x, + const DeviceMemory<float> &scale, const DeviceMemory<float> &mean, + const DeviceMemory<float> &variance, const dnn::BatchDescriptor &x_desc, + const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, + DeviceMemory<float> *x_backprop, DeviceMemory<float> *scale_backprop, + DeviceMemory<float> *offset_backprop) { + VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc), + PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop), + PARAM(scale_backprop), PARAM(offset_backprop)); + if (ok()) { + if (dnn::DnnSupport *dnn = parent_->AsDnn()) { + CheckError(dnn->DoBatchNormalizationBackward( + this, y_backprop, x, scale, mean, variance, x_desc, scale_offset_desc, + epsilon, x_backprop, scale_backprop, offset_backprop)); + } else { + SetError(); + LOG(WARNING) + << "attempting to perform DNN operation using StreamExecutor " + "without DNN support"; + } + } + return *this; +} + Stream &Stream::ThenConvolveWithScratch( const dnn::BatchDescriptor &input_descriptor, const DeviceMemory<Eigen::half> &input_data, |