aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/dnn.h
diff options
context:
space:
mode:
authorGravatar Reed Wanderman-Milne <reedwm@google.com>2017-09-27 12:58:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-27 13:04:57 -0700
commit759690f026a1a08b3ac5cc84d8498c05c32b2a7d (patch)
tree9c7ba12fef51b97226f4e0a07b9aa0eff7fccff1 /tensorflow/stream_executor/dnn.h
parent20370104cd8adf4c3f9068dfe95bde54cccadfa5 (diff)
Add float16 support to tf.nn.fused_batch_norm on the GPU.
Scale, offset, mean, and variance must still be float32 if the input is float16. PiperOrigin-RevId: 170239448
Diffstat (limited to 'tensorflow/stream_executor/dnn.h')
-rw-r--r--tensorflow/stream_executor/dnn.h32
1 files changed, 32 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index b11c6417be..4beb46090c 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -900,6 +900,23 @@ class DnnSupport {
return false;
}
+ // Performs a half-precision forwards batch normalization operation onto the
+ // stream. See DoBatchNormalizationForward above for argument details.
+ virtual bool DoBatchNormalizationForward(
+ Stream* stream, const DeviceMemory<Eigen::half>& x,
+ const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
+ const DeviceMemory<float>& estimated_mean,
+ const DeviceMemory<float>& estimated_variance,
+ const dnn::BatchDescriptor& x_desc,
+ const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
+ DeviceMemory<Eigen::half>* y, DeviceMemory<float>* batch_mean,
+ DeviceMemory<float>* batch_var, DeviceMemory<float>* reserve_space_1,
+ DeviceMemory<float>* reserve_space_2, bool is_training,
+ std::function<const DeviceMemory<float>&()> var_to_inv_var,
+ std::function<void()> inv_var_to_var) {
+ return false;
+ }
+
// Performs a single-precision backward batch normalization gradient
// computation operation onto the stream.
//
@@ -927,6 +944,21 @@ class DnnSupport {
return false;
}
+ // Performs a half-precision backward batch normalization gradient computation
+ // operation onto the stream. See DoBatchNormalizationBackward above for
+ // argument details.
+ virtual bool DoBatchNormalizationBackward(
+ Stream* stream, const DeviceMemory<Eigen::half>& y_backprop,
+ const DeviceMemory<Eigen::half>& x, const DeviceMemory<float>& scale,
+ const DeviceMemory<float>& mean, const DeviceMemory<float>& variance,
+ const dnn::BatchDescriptor& x_desc,
+ const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
+ DeviceMemory<Eigen::half>* x_backprop,
+ DeviceMemory<float>* scale_backprop,
+ DeviceMemory<float>* offset_backprop) {
+ return false;
+ }
+
// Enqueues a fused convolution operation onto the stream.
// We provide several variants with different types for inputs, biases and
// scaling parameters.