aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2016-07-19 08:37:23 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-19 09:47:38 -0700
commit9223380e473b1b77791c66272556a441f2d19a40 (patch)
tree819d24a3504a1c50380dba3636c1fb92c5d446f7
parente728b00327934effec0bbae93caaf363165d0408 (diff)
Improved support for fp16
Change: 127840021
-rw-r--r--tensorflow/python/ops/nn.py2
1 files changed, 2 insertions, 0 deletions
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index 79030645f2..4036a3cd54 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -776,6 +776,8 @@ def moments(x, axes, shift=None, name=None, keep_dims=False):
# sufficient statistics. As a workaround we simply perform the operations
# on 32-bit floats before converting the mean and variance back to fp16
y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
+ shift = math_ops.cast(shift, dtypes.float32) if (
+ shift and x.dtype == dtypes.float16) else shift
counts, m_ss, v_ss, shift = sufficient_statistics(y,
axes,
shift=shift,