aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/stream.h
diff options
context:
space:
mode:
authorGravatar Yao Zhang <yaozhang@google.com>2016-09-15 20:44:49 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-09-15 21:47:12 -0700
commit04df8d868fab5df0002fa0ec2765dc2e0aeb68d6 (patch)
treeff4155ad9ced636cc1cca1451aa702805401a516 /tensorflow/stream_executor/stream.h
parent4e96e274443805df8afad5cb48f654fbf1776a4a (diff)
Add the interface in steam executor to call cuDNN batch normalization functions.
Change: 133345765
Diffstat (limited to 'tensorflow/stream_executor/stream.h')
-rw-r--r--tensorflow/stream_executor/stream.h21
1 files changed, 21 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index 61058528c2..0d16495a1d 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -215,6 +215,27 @@ class Stream {
//
// See DnnSupport::* for comments on the following methods.
+ Stream &ThenBatchNormalizationForward(
+ const DeviceMemory<float> &x, const DeviceMemory<float> &scale,
+ const DeviceMemory<float> &offset,
+ const DeviceMemory<float> &estimated_mean,
+ const DeviceMemory<float> &estimated_variance,
+ const dnn::BatchDescriptor &x_desc,
+ const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
+ DeviceMemory<float> *y, DeviceMemory<float> *batch_mean,
+ DeviceMemory<float> *batch_var, DeviceMemory<float> *saved_mean,
+ DeviceMemory<float> *saved_inv_var, bool is_training,
+ std::function<const DeviceMemory<float> &()> var_to_inv_var,
+ std::function<void()> inv_var_to_var);
+
+ Stream &ThenBatchNormalizationBackward(
+ const DeviceMemory<float> &y_backprop, const DeviceMemory<float> &x,
+ const DeviceMemory<float> &scale, const DeviceMemory<float> &mean,
+ const DeviceMemory<float> &variance, const dnn::BatchDescriptor &x_desc,
+ const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
+ DeviceMemory<float> *x_backprop, DeviceMemory<float> *scale_backprop,
+ DeviceMemory<float> *offset_backprop);
+
// TODO(leary) add double-precision version of this interface.
Stream &ThenConvolve(const dnn::BatchDescriptor &input_descriptor,
const DeviceMemory<float> &input_data,