diff options
Diffstat (limited to 'tensorflow/python/ops/nn_ops.py')
-rw-r--r-- | tensorflow/python/ops/nn_ops.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index 66ccedf546..ccce9402c7 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -1478,14 +1478,14 @@ def _softmax(logits, compute_op, dim=-1, name=None): InvalidArgumentError: if `logits` is empty or `dim` is beyond the last dimension of `logits`. """ - def _swap_axis(logits, dim_index, last_index): + def _swap_axis(logits, dim_index, last_index, name=None): """Swaps logits's dim_index and last_index.""" return array_ops.transpose(logits, array_ops.concat([ math_ops.range(dim_index), [last_index], math_ops.range(dim_index + 1, last_index), [dim_index] - ], 0)) + ], 0), name=name) logits = ops.convert_to_tensor(logits) @@ -1501,8 +1501,8 @@ def _softmax(logits, compute_op, dim=-1, name=None): if is_last_dim: input_shape = array_ops.shape(logits) logits = _flatten_outer_dims(logits) - output = compute_op(logits, name=name) - output = array_ops.reshape(output, input_shape) + output = compute_op(logits) + output = array_ops.reshape(output, input_shape, name=name) return output # If dim is not the last dimension, we have to do a reshape and transpose so @@ -1517,11 +1517,11 @@ def _softmax(logits, compute_op, dim=-1, name=None): logits = _flatten_outer_dims(logits) # Do the actual softmax on its last dimension. - output = compute_op(logits, name=name) + output = compute_op(logits) # Transform back the output tensor. output = array_ops.reshape(output, shape_after_swap) - output = _swap_axis(output, dim, math_ops.subtract(input_rank, 1)) + output = _swap_axis(output, dim, math_ops.subtract(input_rank, 1), name=name) # Make shape inference work since reshape and transpose may erase its static # shape. |