diff options
author | Yao Zhang <yaozhang@google.com> | 2016-09-15 20:44:49 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-09-15 21:47:12 -0700 |
commit | 04df8d868fab5df0002fa0ec2765dc2e0aeb68d6 (patch) | |
tree | ff4155ad9ced636cc1cca1451aa702805401a516 /tensorflow/stream_executor/stream.h | |
parent | 4e96e274443805df8afad5cb48f654fbf1776a4a (diff) |
Add the interface in stream executor to call cuDNN batch normalization functions.
Change: 133345765
Diffstat (limited to 'tensorflow/stream_executor/stream.h')
-rw-r--r-- | tensorflow/stream_executor/stream.h | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h index 61058528c2..0d16495a1d 100644 --- a/tensorflow/stream_executor/stream.h +++ b/tensorflow/stream_executor/stream.h @@ -215,6 +215,27 @@ class Stream { // // See DnnSupport::* for comments on the following methods. + Stream &ThenBatchNormalizationForward( + const DeviceMemory<float> &x, const DeviceMemory<float> &scale, + const DeviceMemory<float> &offset, + const DeviceMemory<float> &estimated_mean, + const DeviceMemory<float> &estimated_variance, + const dnn::BatchDescriptor &x_desc, + const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, + DeviceMemory<float> *y, DeviceMemory<float> *batch_mean, + DeviceMemory<float> *batch_var, DeviceMemory<float> *saved_mean, + DeviceMemory<float> *saved_inv_var, bool is_training, + std::function<const DeviceMemory<float> &()> var_to_inv_var, + std::function<void()> inv_var_to_var); + + Stream &ThenBatchNormalizationBackward( + const DeviceMemory<float> &y_backprop, const DeviceMemory<float> &x, + const DeviceMemory<float> &scale, const DeviceMemory<float> &mean, + const DeviceMemory<float> &variance, const dnn::BatchDescriptor &x_desc, + const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, + DeviceMemory<float> *x_backprop, DeviceMemory<float> *scale_backprop, + DeviceMemory<float> *offset_backprop); + // TODO(leary) add double-precision version of this interface. Stream &ThenConvolve(const dnn::BatchDescriptor &input_descriptor, const DeviceMemory<float> &input_data, |