aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
authorGravatar Joshua V. Dillon <jvdillon@google.com>2018-05-29 17:42:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-29 17:46:01 -0700
commitce88b47799caa472509a34c6c2e4265e2d16ceb9 (patch)
tree52b8b79d5895462f918424a1beb8c223a9f95248 /tensorflow/python
parent79755d82a02526950ee4bd3fbc11d515308e76fd (diff)
Use absolute indexing in `fill_triangular`.
PiperOrigin-RevId: 198485926
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/ops/distributions/util.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/python/ops/distributions/util.py b/tensorflow/python/ops/distributions/util.py
index 728fda28c2..1b2c8762a4 100644
--- a/tensorflow/python/ops/distributions/util.py
+++ b/tensorflow/python/ops/distributions/util.py
@@ -914,10 +914,11 @@ def fill_triangular(x, upper=False, name=None):
# = 2 (n**2 / 2 + n / 2) - n**2
# = n**2 + n - n**2
# = n
+ ndims = array_ops.rank(x) if x.shape.ndims is None else x.shape.ndims
if upper:
- x_list = [x, array_ops.reverse(x[..., n:], axis=[-1])]
+ x_list = [x, array_ops.reverse(x[..., n:], axis=[ndims - 1])]
else:
- x_list = [x[..., n:], array_ops.reverse(x, axis=[-1])]
+ x_list = [x[..., n:], array_ops.reverse(x, axis=[ndims - 1])]
new_shape = (
static_final_shape.as_list()
if static_final_shape.is_fully_defined()