diff options
Diffstat (limited to 'tensorflow/python/ops/distributions/util.py')
-rw-r--r-- | tensorflow/python/ops/distributions/util.py | 11 |
1 files changed, 4 insertions, 7 deletions
diff --git a/tensorflow/python/ops/distributions/util.py b/tensorflow/python/ops/distributions/util.py index 5bc25128a8..0fe6aa30f9 100644 --- a/tensorflow/python/ops/distributions/util.py +++ b/tensorflow/python/ops/distributions/util.py @@ -1041,14 +1041,14 @@ def reduce_weighted_logsumexp( with ops.name_scope(name, "reduce_weighted_logsumexp", [logx, w]): logx = ops.convert_to_tensor(logx, name="logx") if w is None: - lswe = math_ops.reduce_logsumexp(logx, axis=axis, keep_dims=keep_dims) + lswe = math_ops.reduce_logsumexp(logx, axis=axis, keepdims=keep_dims) if return_sign: sgn = array_ops.ones_like(lswe) return lswe, sgn return lswe w = ops.convert_to_tensor(w, dtype=logx.dtype, name="w") log_absw_x = logx + math_ops.log(math_ops.abs(w)) - max_log_absw_x = math_ops.reduce_max(log_absw_x, axis=axis, keep_dims=True) + max_log_absw_x = math_ops.reduce_max(log_absw_x, axis=axis, keepdims=True) # If the largest element is `-inf` or `inf` then we don't bother subtracting # off the max. We do this because otherwise we'd get `inf - inf = NaN`. That # this is ok follows from the fact that we're actually free to subtract any @@ -1060,9 +1060,7 @@ def reduce_weighted_logsumexp( wx_over_max_absw_x = ( math_ops.sign(w) * math_ops.exp(log_absw_x - max_log_absw_x)) sum_wx_over_max_absw_x = math_ops.reduce_sum( - wx_over_max_absw_x, - axis=axis, - keep_dims=keep_dims) + wx_over_max_absw_x, axis=axis, keepdims=keep_dims) if not keep_dims: max_log_absw_x = array_ops.squeeze(max_log_absw_x, axis) sgn = math_ops.sign(sum_wx_over_max_absw_x) @@ -1180,8 +1178,7 @@ def process_quadrature_grid_and_probs( grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype) probs = ops.convert_to_tensor(probs, name="unnormalized_probs", dtype=dtype) - probs /= linalg_ops.norm(probs, ord=1, axis=-1, keep_dims=True, - name="probs") + probs /= linalg_ops.norm(probs, ord=1, axis=-1, keepdims=True, name="probs") def _static_event_size(x): """Returns the static size of a specific dimension or `None`.""" |