path: root/tensorflow/stream_executor/stream.cc
author    Reed Wanderman-Milne <reedwm@google.com> 2017-09-27 12:58:14 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-09-27 13:04:57 -0700
commit    759690f026a1a08b3ac5cc84d8498c05c32b2a7d (patch)
tree      9c7ba12fef51b97226f4e0a07b9aa0eff7fccff1 /tensorflow/stream_executor/stream.cc
parent    20370104cd8adf4c3f9068dfe95bde54cccadfa5 (diff)
Add float16 support to tf.nn.fused_batch_norm on the GPU.
Scale, offset, mean, and variance must still be float32 if the input is float16.

PiperOrigin-RevId: 170239448
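A minimal sketch of what this change enables at the Python level, assuming a TensorFlow 1.x build that includes this commit and a CUDA GPU; the input shape, values, and session-style driver are illustrative, not part of the change:

import numpy as np
import tensorflow as tf

# With this change, fused batch norm accepts a float16 input on the GPU;
# scale, offset, mean, and variance stay float32, as the commit message notes.
x = tf.constant(np.random.rand(2, 4, 4, 3), dtype=tf.float16)  # NHWC input
scale = tf.ones([3], dtype=tf.float32)    # per-channel scale (gamma)
offset = tf.zeros([3], dtype=tf.float32)  # per-channel offset (beta)

# is_training=True computes batch statistics on the fly; y comes back as
# float16 while batch_mean and batch_var come back as float32.
y, batch_mean, batch_var = tf.nn.fused_batch_norm(
    x, scale, offset, epsilon=0.001, is_training=True)

with tf.Session() as sess:
  y_val, mean_val, var_val = sess.run([y, batch_mean, batch_var])
  print(y_val.dtype, mean_val.dtype, var_val.dtype)  # float16 float32 float32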
Diffstat (limited to 'tensorflow/stream_executor/stream.cc')
-rw-r--r--  tensorflow/stream_executor/stream.cc | 51
1 file changed, 51 insertions(+), 0 deletions(-)
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index dc768e0273..6d756ab191 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -361,6 +361,57 @@ Stream &Stream::ThenBatchNormalizationBackward(
return *this;
}
+Stream &Stream::ThenBatchNormalizationForward(
+ const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
+ const DeviceMemory<float> &offset,
+ const DeviceMemory<float> &estimated_mean,
+ const DeviceMemory<float> &estimated_variance,
+ const dnn::BatchDescriptor &x_desc,
+ const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
+ DeviceMemory<Eigen::half> *y, DeviceMemory<float> *batch_mean,
+ DeviceMemory<float> *batch_var, DeviceMemory<float> *saved_mean,
+ DeviceMemory<float> *saved_inv_var, bool is_training,
+ std::function<const DeviceMemory<float> &()> var_to_inv_var,
+ std::function<void()> inv_var_to_var) {
+ VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc),
+ PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y));
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ CheckError(dnn->DoBatchNormalizationForward(
+ this, x, scale, offset, estimated_mean, estimated_variance, x_desc,
+ scale_offset_desc, epsilon, y, batch_mean, batch_var, saved_mean,
+ saved_inv_var, is_training, std::move(var_to_inv_var),
+ std::move(inv_var_to_var)));
+ } else {
+ SetErrorAndLogNoDnnSupport();
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenBatchNormalizationBackward(
+ const DeviceMemory<Eigen::half> &y_backprop,
+ const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
+ const DeviceMemory<float> &mean, const DeviceMemory<float> &variance,
+ const dnn::BatchDescriptor &x_desc,
+ const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
+ DeviceMemory<Eigen::half> *x_backprop, DeviceMemory<float> *scale_backprop,
+ DeviceMemory<float> *offset_backprop) {
+ VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc),
+ PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop),
+ PARAM(scale_backprop), PARAM(offset_backprop));
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ CheckError(dnn->DoBatchNormalizationBackward(
+ this, y_backprop, x, scale, mean, variance, x_desc, scale_offset_desc,
+ epsilon, x_backprop, scale_backprop, offset_backprop));
+ } else {
+ SetErrorAndLogNoDnnSupport();
+ }
+ }
+ return *this;
+}
+
Stream &Stream::ThenFusedConvolveWithScratch(
const dnn::BatchDescriptor &conv_input_descriptor,
const DeviceMemory<int8> &conv_input_data, float conv_input_scale,