aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan/python
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/gan/python')
-rw-r--r--tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py2
-rw-r--r--tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py6
2 files changed, 4 insertions, 4 deletions
diff --git a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py
index 4b10bc0f8e..4b1105f6bd 100644
--- a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py
+++ b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py
@@ -161,7 +161,7 @@ def _sliced_wasserstein(a, b, random_sampling_count, random_projection_dim):
proj = random_ops.random_normal(
[array_ops.shape(a)[1], random_projection_dim])
proj *= math_ops.rsqrt(
- math_ops.reduce_sum(math_ops.square(proj), 0, keep_dims=True))
+ math_ops.reduce_sum(math_ops.square(proj), 0, keepdims=True))
# Project both distributions and sort them.
proj_a = math_ops.matmul(a, proj)
proj_b = math_ops.matmul(b, proj)
diff --git a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py
index f8b372546b..650eab97a3 100644
--- a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py
+++ b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py
@@ -64,11 +64,11 @@ def _statistics(x, axes):
y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
# Compute true mean while keeping the dims for proper broadcasting.
- shift = array_ops.stop_gradient(math_ops.reduce_mean(y, axes, keep_dims=True))
+ shift = array_ops.stop_gradient(math_ops.reduce_mean(y, axes, keepdims=True))
- shifted_mean = math_ops.reduce_mean(y - shift, axes, keep_dims=True)
+ shifted_mean = math_ops.reduce_mean(y - shift, axes, keepdims=True)
mean = shifted_mean + shift
- mean_squared = math_ops.reduce_mean(math_ops.square(y), axes, keep_dims=True)
+ mean_squared = math_ops.reduce_mean(math_ops.square(y), axes, keepdims=True)
mean = array_ops.squeeze(mean, axes)
mean_squared = array_ops.squeeze(mean_squared, axes)