diff options
author | Reed Wanderman-Milne <reedwm@google.com> | 2017-09-27 12:58:14 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-09-27 13:04:57 -0700 |
commit | 759690f026a1a08b3ac5cc84d8498c05c32b2a7d (patch) | |
tree | 9c7ba12fef51b97226f4e0a07b9aa0eff7fccff1 /tensorflow/stream_executor/stream.h | |
parent | 20370104cd8adf4c3f9068dfe95bde54cccadfa5 (diff) |
Add float16 support to tf.nn.fused_batch_norm on the GPU.
Scale, offset, mean, and variance must still be float32 if the input is float16.
PiperOrigin-RevId: 170239448
Diffstat (limited to 'tensorflow/stream_executor/stream.h')
-rw-r--r-- | tensorflow/stream_executor/stream.h | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h index 98484eb850..a72ee804c1 100644 --- a/tensorflow/stream_executor/stream.h +++ b/tensorflow/stream_executor/stream.h @@ -239,6 +239,29 @@ class Stream { DeviceMemory<float> *x_backprop, DeviceMemory<float> *scale_backprop, DeviceMemory<float> *offset_backprop); + Stream &ThenBatchNormalizationForward( + const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale, + const DeviceMemory<float> &offset, + const DeviceMemory<float> &estimated_mean, + const DeviceMemory<float> &estimated_variance, + const dnn::BatchDescriptor &x_desc, + const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, + DeviceMemory<Eigen::half> *y, DeviceMemory<float> *batch_mean, + DeviceMemory<float> *batch_var, DeviceMemory<float> *saved_mean, + DeviceMemory<float> *saved_inv_var, bool is_training, + std::function<const DeviceMemory<float> &()> var_to_inv_var, + std::function<void()> inv_var_to_var); + + Stream &ThenBatchNormalizationBackward( + const DeviceMemory<Eigen::half> &y_backprop, + const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale, + const DeviceMemory<float> &mean, const DeviceMemory<float> &variance, + const dnn::BatchDescriptor &x_desc, + const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, + DeviceMemory<Eigen::half> *x_backprop, + DeviceMemory<float> *scale_backprop, + DeviceMemory<float> *offset_backprop); + // TODO(leary) add double-precision version of this interface. Stream &ThenFusedConvolve( const dnn::BatchDescriptor &conv_input_descriptor, |