aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/nn_impl.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/nn_impl.py')
-rw-r--r--tensorflow/python/ops/nn_impl.py5
1 files changed, 4 insertions, 1 deletions
diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py
index 783d485892..f47f38e29e 100644
--- a/tensorflow/python/ops/nn_impl.py
+++ b/tensorflow/python/ops/nn_impl.py
@@ -621,7 +621,7 @@ def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
"""Calculate the mean and variance of based on the sufficient statistics.
Args:
- counts: A `Tensor` containing a the total count of the data (one value).
+ counts: A `Tensor` containing the total count of the data (one value).
mean_ss: A `Tensor` containing the mean sufficient statistics: the (possibly
shifted) sum of the elements to average over.
variance_ss: A `Tensor` containing the variance sufficient statistics: the
@@ -689,6 +689,9 @@ def moments(
# Compute true mean while keeping the dims for proper broadcasting.
mean = math_ops.reduce_mean(y, axes, keepdims=True, name="mean")
# sample variance, not unbiased variance
+ # Note: stop_gradient does not change the gradient that gets
+ # backpropagated to the mean from the variance calculation,
+ # because that gradient is zero
variance = math_ops.reduce_mean(
math_ops.squared_difference(y, array_ops.stop_gradient(mean)),
axes,