diff options
Diffstat (limited to 'tensorflow/python/ops/nn_impl.py')
-rw-r--r-- | tensorflow/python/ops/nn_impl.py | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py index 783d485892..f47f38e29e 100644 --- a/tensorflow/python/ops/nn_impl.py +++ b/tensorflow/python/ops/nn_impl.py @@ -621,7 +621,7 @@ def normalize_moments(counts, mean_ss, variance_ss, shift, name=None): """Calculate the mean and variance of based on the sufficient statistics. Args: - counts: A `Tensor` containing a the total count of the data (one value). + counts: A `Tensor` containing the total count of the data (one value). mean_ss: A `Tensor` containing the mean sufficient statistics: the (possibly shifted) sum of the elements to average over. variance_ss: A `Tensor` containing the variance sufficient statistics: the @@ -689,6 +689,9 @@ def moments( # Compute true mean while keeping the dims for proper broadcasting. mean = math_ops.reduce_mean(y, axes, keepdims=True, name="mean") # sample variance, not unbiased variance + # Note: stop_gradient does not change the gradient that gets + # backpropagated to the mean from the variance calculation, + # because that gradient is zero variance = math_ops.reduce_mean( math_ops.squared_difference(y, array_ops.stop_gradient(mean)), axes, |