aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py')
-rw-r--r--tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py
index f8b372546b..650eab97a3 100644
--- a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py
+++ b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py
@@ -64,11 +64,11 @@ def _statistics(x, axes):
y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
# Compute true mean while keeping the dims for proper broadcasting.
- shift = array_ops.stop_gradient(math_ops.reduce_mean(y, axes, keep_dims=True))
+ shift = array_ops.stop_gradient(math_ops.reduce_mean(y, axes, keepdims=True))
- shifted_mean = math_ops.reduce_mean(y - shift, axes, keep_dims=True)
+ shifted_mean = math_ops.reduce_mean(y - shift, axes, keepdims=True)
mean = shifted_mean + shift
- mean_squared = math_ops.reduce_mean(math_ops.square(y), axes, keep_dims=True)
+ mean_squared = math_ops.reduce_mean(math_ops.square(y), axes, keepdims=True)
mean = array_ops.squeeze(mean, axes)
mean_squared = array_ops.squeeze(mean_squared, axes)