aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/distributions/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/distributions/util.py')
-rw-r--r--tensorflow/python/ops/distributions/util.py11
1 files changed, 4 insertions, 7 deletions
diff --git a/tensorflow/python/ops/distributions/util.py b/tensorflow/python/ops/distributions/util.py
index 5bc25128a8..0fe6aa30f9 100644
--- a/tensorflow/python/ops/distributions/util.py
+++ b/tensorflow/python/ops/distributions/util.py
@@ -1041,14 +1041,14 @@ def reduce_weighted_logsumexp(
with ops.name_scope(name, "reduce_weighted_logsumexp", [logx, w]):
logx = ops.convert_to_tensor(logx, name="logx")
if w is None:
- lswe = math_ops.reduce_logsumexp(logx, axis=axis, keep_dims=keep_dims)
+ lswe = math_ops.reduce_logsumexp(logx, axis=axis, keepdims=keep_dims)
if return_sign:
sgn = array_ops.ones_like(lswe)
return lswe, sgn
return lswe
w = ops.convert_to_tensor(w, dtype=logx.dtype, name="w")
log_absw_x = logx + math_ops.log(math_ops.abs(w))
- max_log_absw_x = math_ops.reduce_max(log_absw_x, axis=axis, keep_dims=True)
+ max_log_absw_x = math_ops.reduce_max(log_absw_x, axis=axis, keepdims=True)
# If the largest element is `-inf` or `inf` then we don't bother subtracting
# off the max. We do this because otherwise we'd get `inf - inf = NaN`. That
# this is ok follows from the fact that we're actually free to subtract any
@@ -1060,9 +1060,7 @@ def reduce_weighted_logsumexp(
wx_over_max_absw_x = (
math_ops.sign(w) * math_ops.exp(log_absw_x - max_log_absw_x))
sum_wx_over_max_absw_x = math_ops.reduce_sum(
- wx_over_max_absw_x,
- axis=axis,
- keep_dims=keep_dims)
+ wx_over_max_absw_x, axis=axis, keepdims=keep_dims)
if not keep_dims:
max_log_absw_x = array_ops.squeeze(max_log_absw_x, axis)
sgn = math_ops.sign(sum_wx_over_max_absw_x)
@@ -1180,8 +1178,7 @@ def process_quadrature_grid_and_probs(
grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype)
probs = ops.convert_to_tensor(probs, name="unnormalized_probs",
dtype=dtype)
- probs /= linalg_ops.norm(probs, ord=1, axis=-1, keep_dims=True,
- name="probs")
+ probs /= linalg_ops.norm(probs, ord=1, axis=-1, keepdims=True, name="probs")
def _static_event_size(x):
"""Returns the static size of a specific dimension or `None`."""