diff options
author | 2018-05-29 17:42:37 -0700 | |
---|---|---|
committer | 2018-05-29 17:46:01 -0700 | |
commit | ce88b47799caa472509a34c6c2e4265e2d16ceb9 (patch) | |
tree | 52b8b79d5895462f918424a1beb8c223a9f95248 /tensorflow/python | |
parent | 79755d82a02526950ee4bd3fbc11d515308e76fd (diff) |
Use absolute indexing in `fill_triangular`.
PiperOrigin-RevId: 198485926
Diffstat (limited to 'tensorflow/python')
-rw-r--r-- | tensorflow/python/ops/distributions/util.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/python/ops/distributions/util.py b/tensorflow/python/ops/distributions/util.py index 728fda28c2..1b2c8762a4 100644 --- a/tensorflow/python/ops/distributions/util.py +++ b/tensorflow/python/ops/distributions/util.py @@ -914,10 +914,11 @@ def fill_triangular(x, upper=False, name=None): # = 2 (n**2 / 2 + n / 2) - n**2 # = n**2 + n - n**2 # = n + ndims = array_ops.rank(x) if x.shape.ndims is None else x.shape.ndims if upper: - x_list = [x, array_ops.reverse(x[..., n:], axis=[-1])] + x_list = [x, array_ops.reverse(x[..., n:], axis=[ndims - 1])] else: - x_list = [x[..., n:], array_ops.reverse(x, axis=[-1])] + x_list = [x[..., n:], array_ops.reverse(x, axis=[ndims - 1])] new_shape = ( static_final_shape.as_list() if static_final_shape.is_fully_defined() |