diff options
author | 2017-09-27 12:58:14 -0700 | |
---|---|---|
committer | 2017-09-27 13:04:57 -0700 | |
commit | 759690f026a1a08b3ac5cc84d8498c05c32b2a7d (patch) | |
tree | 9c7ba12fef51b97226f4e0a07b9aa0eff7fccff1 /tensorflow/stream_executor/dnn.h | |
parent | 20370104cd8adf4c3f9068dfe95bde54cccadfa5 (diff) |
Add float16 support to tf.nn.fused_batch_norm on the GPU.
Scale, offset, mean, and variance must still be float32 if the input is float16.
PiperOrigin-RevId: 170239448
Diffstat (limited to 'tensorflow/stream_executor/dnn.h')
-rw-r--r-- | tensorflow/stream_executor/dnn.h | 32 |
1 files changed, 32 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h index b11c6417be..4beb46090c 100644 --- a/tensorflow/stream_executor/dnn.h +++ b/tensorflow/stream_executor/dnn.h @@ -900,6 +900,23 @@ class DnnSupport { return false; } + // Performs a half-precision forwards batch normalization operation onto the + // stream. See DoBatchNormalizationForward above for argument details. + virtual bool DoBatchNormalizationForward( + Stream* stream, const DeviceMemory<Eigen::half>& x, + const DeviceMemory<float>& scale, const DeviceMemory<float>& offset, + const DeviceMemory<float>& estimated_mean, + const DeviceMemory<float>& estimated_variance, + const dnn::BatchDescriptor& x_desc, + const dnn::BatchDescriptor& scale_offset_desc, const double epsilon, + DeviceMemory<Eigen::half>* y, DeviceMemory<float>* batch_mean, + DeviceMemory<float>* batch_var, DeviceMemory<float>* reserve_space_1, + DeviceMemory<float>* reserve_space_2, bool is_training, + std::function<const DeviceMemory<float>&()> var_to_inv_var, + std::function<void()> inv_var_to_var) { + return false; + } + // Performs a single-precision backward batch normalization gradient // computation operation onto the stream. // @@ -927,6 +944,21 @@ class DnnSupport { return false; } + // Performs a half-precision backward batch normalization gradient computation + // operation onto the stream. See DoBatchNormalizationBackward above for + // argument details. + virtual bool DoBatchNormalizationBackward( + Stream* stream, const DeviceMemory<Eigen::half>& y_backprop, + const DeviceMemory<Eigen::half>& x, const DeviceMemory<float>& scale, + const DeviceMemory<float>& mean, const DeviceMemory<float>& variance, + const dnn::BatchDescriptor& x_desc, + const dnn::BatchDescriptor& scale_offset_desc, const double epsilon, + DeviceMemory<Eigen::half>* x_backprop, + DeviceMemory<float>* scale_backprop, + DeviceMemory<float>* offset_backprop) { + return false; + } + // Enqueues a fused convolution operation onto the stream. // We provide several variants with different types for inputs, biases and // scaling parameters. |