aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/nn_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/nn_ops.py')
-rw-r--r--tensorflow/python/ops/nn_ops.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 66ccedf546..ccce9402c7 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -1478,14 +1478,14 @@ def _softmax(logits, compute_op, dim=-1, name=None):
InvalidArgumentError: if `logits` is empty or `dim` is beyond the last
dimension of `logits`.
"""
- def _swap_axis(logits, dim_index, last_index):
+ def _swap_axis(logits, dim_index, last_index, name=None):
"""Swaps logits's dim_index and last_index."""
return array_ops.transpose(logits,
array_ops.concat([
math_ops.range(dim_index), [last_index],
math_ops.range(dim_index + 1, last_index),
[dim_index]
- ], 0))
+ ], 0), name=name)
logits = ops.convert_to_tensor(logits)
@@ -1501,8 +1501,8 @@ def _softmax(logits, compute_op, dim=-1, name=None):
if is_last_dim:
input_shape = array_ops.shape(logits)
logits = _flatten_outer_dims(logits)
- output = compute_op(logits, name=name)
- output = array_ops.reshape(output, input_shape)
+ output = compute_op(logits)
+ output = array_ops.reshape(output, input_shape, name=name)
return output
# If dim is not the last dimension, we have to do a reshape and transpose so
@@ -1517,11 +1517,11 @@ def _softmax(logits, compute_op, dim=-1, name=None):
logits = _flatten_outer_dims(logits)
# Do the actual softmax on its last dimension.
- output = compute_op(logits, name=name)
+ output = compute_op(logits)
# Transform back the output tensor.
output = array_ops.reshape(output, shape_after_swap)
- output = _swap_axis(output, dim, math_ops.subtract(input_rank, 1))
+ output = _swap_axis(output, dim, math_ops.subtract(input_rank, 1), name=name)
# Make shape inference work since reshape and transpose may erase its static
# shape.