aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Joshua V. Dillon <jvdillon@google.com>2018-01-26 08:57:10 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-26 09:01:21 -0800
commit7fc61bfb50aac4e2d0ff9dab9d99a6001aa5cccf (patch)
tree92988447bbc035f7bb6608df1b78ec7a63db49d9
parentabdc62aee1eeba32be56d761a2f9988306356084 (diff)
Change `reduce_logsumexp` to internally use `reshape` rather than `squeeze`
since the latter requires the `axis` arg to be a Python `list`. PiperOrigin-RevId: 183396533
-rw-r--r--tensorflow/python/ops/math_ops.py9
1 files changed, 4 insertions, 5 deletions
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 9ad1031354..827e3caa36 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -1841,12 +1841,11 @@ def reduce_logsumexp(input_tensor,
reduce_sum(
gen_math_ops.exp(input_tensor - my_max),
axis,
- keepdims=True,
- reduction_indices=reduction_indices)) + my_max
+ keepdims=keepdims,
+ reduction_indices=reduction_indices))
if not keepdims:
- if isinstance(axis, int):
- axis = [axis]
- result = array_ops.squeeze(result, axis)
+ my_max = array_ops.reshape(my_max, array_ops.shape(result))
+ result += my_max
return _may_reduce_to_scalar(keepdims, axis, reduction_indices, result)