diff options
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r-- | tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py | 6 |
2 files changed, 4 insertions, 4 deletions
diff --git a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py index 4b10bc0f8e..4b1105f6bd 100644 --- a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py @@ -161,7 +161,7 @@ def _sliced_wasserstein(a, b, random_sampling_count, random_projection_dim): proj = random_ops.random_normal( [array_ops.shape(a)[1], random_projection_dim]) proj *= math_ops.rsqrt( - math_ops.reduce_sum(math_ops.square(proj), 0, keep_dims=True)) + math_ops.reduce_sum(math_ops.square(proj), 0, keepdims=True)) # Project both distributions and sort them. proj_a = math_ops.matmul(a, proj) proj_b = math_ops.matmul(b, proj) diff --git a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py index f8b372546b..650eab97a3 100644 --- a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py +++ b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py @@ -64,11 +64,11 @@ def _statistics(x, axes): y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x # Compute true mean while keeping the dims for proper broadcasting. - shift = array_ops.stop_gradient(math_ops.reduce_mean(y, axes, keep_dims=True)) + shift = array_ops.stop_gradient(math_ops.reduce_mean(y, axes, keepdims=True)) - shifted_mean = math_ops.reduce_mean(y - shift, axes, keep_dims=True) + shifted_mean = math_ops.reduce_mean(y - shift, axes, keepdims=True) mean = shifted_mean + shift - mean_squared = math_ops.reduce_mean(math_ops.square(y), axes, keep_dims=True) + mean_squared = math_ops.reduce_mean(math_ops.square(y), axes, keepdims=True) mean = array_ops.squeeze(mean, axes) mean_squared = array_ops.squeeze(mean_squared, axes) |