-rw-r--r--  tensorflow/examples/tutorials/mnist/mnist_with_summaries.py |   2
-rw-r--r--  tensorflow/python/kernel_tests/softmax_op_test.py           |  90
-rw-r--r--  tensorflow/python/ops/hidden_ops.txt                        |   2
-rw-r--r--  tensorflow/python/ops/nn_ops.py                             | 129
4 files changed, 194 insertions, 29 deletions
diff --git a/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py b/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py
index 23492e5122..868cfcc3e4 100644
--- a/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py
+++ b/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py
@@ -100,7 +100,7 @@ def train():
with tf.name_scope('Wx_plus_b'):
preactivate = tf.matmul(input_tensor, weights) + biases
tf.histogram_summary(layer_name + '/pre_activations', preactivate)
- activations = act(preactivate, 'activation')
+ activations = act(preactivate, name='activation')
tf.histogram_summary(layer_name + '/activations', activations)
return activations
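
Note: this tutorial change follows from the new signatures introduced later in this diff. With tf.nn.softmax(logits, dim=-1, name=None), a second positional argument binds to dim rather than name, so the call site switches to the name= keyword. A minimal sketch of the keyword form (the logits constant below is a made-up example, not part of the tutorial):

import tensorflow as tf

# Hypothetical example input; any float tensor of logits works here.
logits = tf.constant([[1.0, 2.0, 3.0]])

# Keyword form: 'activation' is unambiguously the op name and never the new
# dim argument, whichever activation (tf.nn.relu, tf.nn.softmax, ...) the
# tutorial's `act` happens to be.
activations = tf.nn.softmax(logits, name='activation')
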
diff --git a/tensorflow/python/kernel_tests/softmax_op_test.py b/tensorflow/python/kernel_tests/softmax_op_test.py
index 00370d98c3..d4525c730e 100644
--- a/tensorflow/python/kernel_tests/softmax_op_test.py
+++ b/tensorflow/python/kernel_tests/softmax_op_test.py
@@ -26,37 +26,39 @@ import tensorflow as tf
class SoftmaxTest(tf.test.TestCase):
- def _npSoftmax(self, features, log=False):
- batch_dim = 0
- class_dim = 1
- batch_size = features.shape[batch_dim]
- e = np.exp(features -
- np.reshape(np.amax(features, axis=class_dim), [batch_size, 1]))
- softmax = e / np.reshape(np.sum(e, axis=class_dim), [batch_size, 1])
+ def _npSoftmax(self, features, dim=-1, log=False):
+ if dim == -1:
+ dim = len(features.shape) - 1
+ one_only_on_dim = list(features.shape)
+ one_only_on_dim[dim] = 1
+ e = np.exp(features - np.reshape(
+ np.amax(
+ features, axis=dim), one_only_on_dim))
+ softmax = e / np.reshape(np.sum(e, axis=dim), one_only_on_dim)
if log:
return np.log(softmax)
else:
return softmax
- def _testSoftmax(self, np_features, log=False, use_gpu=False):
+ def _testSoftmax(self, np_features, dim=-1, log=False, use_gpu=False):
# A previous version of the code checked the op name rather than the op type
# to distinguish between log and non-log. Use an arbitrary name to catch
# this bug in future.
name = "arbitrary"
- np_softmax = self._npSoftmax(np_features, log=log)
+ np_softmax = self._npSoftmax(np_features, dim=dim, log=log)
with self.test_session(use_gpu=use_gpu):
if log:
- tf_softmax = tf.nn.log_softmax(np_features, name=name)
+ tf_softmax = tf.nn.log_softmax(np_features, dim=dim, name=name)
else:
- tf_softmax = tf.nn.softmax(np_features, name=name)
+ tf_softmax = tf.nn.softmax(np_features, dim=dim, name=name)
out = tf_softmax.eval()
self.assertAllCloseAccordingToType(np_softmax, out)
self.assertShapeEqual(np_softmax, tf_softmax)
if not log:
- # Bonus check: the softmaxes should add to one in each
- # batch element.
- self.assertAllCloseAccordingToType(np.ones(out.shape[0]),
- np.sum(out, axis=1))
+ # Bonus check: the softmaxes should add to one in dimension dim.
+ sum_along_dim = np.sum(out, axis=dim)
+ self.assertAllCloseAccordingToType(
+ np.ones(sum_along_dim.shape), sum_along_dim)
def _testAll(self, features):
self._testSoftmax(features, use_gpu=False)
@@ -90,17 +92,11 @@ class SoftmaxTest(tf.test.TestCase):
np_lsm,
rtol=1.e-5, atol=1.e-5)
- def testShapeMismatch(self):
- with self.assertRaises(ValueError):
- tf.nn.softmax([0., 1., 2., 3.])
- with self.assertRaises(ValueError):
- tf.nn.log_softmax([0., 1., 2., 3.])
-
def _testOverflow(self, use_gpu=False):
if use_gpu:
- type = np.float32
+ type = np.float32
else:
- type = np.float64
+ type = np.float64
max = np.finfo(type).max
features = np.array(
[[1., 1., 1., 1.],
@@ -128,13 +124,55 @@ class SoftmaxTest(tf.test.TestCase):
use_gpu=False)
self._testOverflow(use_gpu=False)
+ def test1DTensorAsInput(self):
+ self._testSoftmax(
+ np.array([3., 2., 3., 9.]).astype(np.float64), use_gpu=False)
+ self._testOverflow(use_gpu=False)
+
+ def test3DTensorAsInput(self):
+ self._testSoftmax(
+ np.array([[[1., 1., 1., 1.], [1., 2., 3., 4.]],
+ [[2., 3., 4., 5.], [6., 7., 8., 9.]],
+ [[5., 4., 3., 2.], [1., 2., 3., 4.]]]).astype(np.float32),
+ use_gpu=False)
+ self._testOverflow(use_gpu=False)
+
+ def testAlongFirstDimension(self):
+ self._testSoftmax(
+ np.array([[[1., 1., 1., 1.], [1., 2., 3., 4.]],
+ [[2., 3., 4., 5.], [6., 7., 8., 9.]],
+ [[5., 4., 3., 2.], [1., 2., 3., 4.]]]).astype(np.float32),
+ dim=0,
+ use_gpu=False)
+ self._testOverflow(use_gpu=False)
+
+ def testAlongSecondDimension(self):
+ self._testSoftmax(
+ np.array([[[1., 1., 1., 1.], [1., 2., 3., 4.]],
+ [[2., 3., 4., 5.], [6., 7., 8., 9.]],
+ [[5., 4., 3., 2.], [1., 2., 3., 4.]]]).astype(np.float32),
+ dim=1,
+ use_gpu=False)
+ self._testOverflow(use_gpu=False)
+
+ def testShapeInference(self):
+ op = tf.nn.softmax([[[1., 1., 1., 1.], [1., 2., 3., 4.]],
+ [[2., 3., 4., 5.], [6., 7., 8., 9.]],
+ [[5., 4., 3., 2.], [1., 2., 3., 4.]]])
+ self.assertEqual([3, 2, 4], op.get_shape())
- def testEmpty(self):
+ def testEmptyInput(self):
with self.test_session():
x = tf.constant([[]], shape=[0, 3])
self.assertEqual(0, tf.size(x).eval())
- expected_y = np.array([]).reshape(0, 3)
- np.testing.assert_array_equal(expected_y, tf.nn.softmax(x).eval())
+ # reshape would raise if logits is empty
+ with self.assertRaises(tf.errors.InvalidArgumentError):
+ tf.nn.softmax(x, dim=0).eval()
+
+ def testDimTooLarge(self):
+ with self.test_session():
+ with self.assertRaises(tf.errors.InvalidArgumentError):
+ tf.nn.softmax([1., 2., 3., 4.], dim=100).eval()
if __name__ == "__main__":
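
Note: the NumPy reference _npSoftmax above is the clearest statement of what the new dim argument means: the max-subtraction, exponentiation, and normalization all happen along one chosen axis, and the outputs sum to one along that axis. Below is a standalone sketch of the same computation, using keepdims instead of the explicit reshape the test uses (NumPy only; the 3-D array is the fixture from the tests above):

import numpy as np

def np_softmax(features, dim=-1):
  # keepdims keeps the reduced axis as size 1 so the subtraction and the
  # division broadcast, matching the reshape trick in _npSoftmax.
  e = np.exp(features - np.amax(features, axis=dim, keepdims=True))
  return e / np.sum(e, axis=dim, keepdims=True)

x = np.array([[[1., 1., 1., 1.], [1., 2., 3., 4.]],
              [[2., 3., 4., 5.], [6., 7., 8., 9.]],
              [[5., 4., 3., 2.], [1., 2., 3., 4.]]], dtype=np.float32)

for dim in (0, 1, 2):
  s = np_softmax(x, dim=dim)
  # Softmax outputs sum to one along the chosen dimension.
  assert np.allclose(np.sum(s, axis=dim), 1.0)
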
diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt
index c56db8f5d9..3a94b9f4e3 100644
--- a/tensorflow/python/ops/hidden_ops.txt
+++ b/tensorflow/python/ops/hidden_ops.txt
@@ -179,6 +179,8 @@ BiasAddV1
Relu6
AvgPool
MaxPool
+Softmax
+LogSoftmax
# parsing_ops
ParseExample
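
Note: listing Softmax and LogSoftmax in hidden_ops.txt hides their auto-generated Python wrappers (they become gen_nn_ops._softmax and gen_nn_ops._log_softmax), so the hand-written functions added to nn_ops.py below are what tf.nn.softmax and tf.nn.log_softmax resolve to. A quick, hedged sanity check against a build that includes this change (the expected module name is an assumption about how tf.nn re-exports these symbols):

import tensorflow as tf

# Both names should now point at the Python wrappers that handle `dim`,
# not at the generated 2-D-only kernel bindings.
print(tf.nn.softmax.__module__)      # expected: tensorflow.python.ops.nn_ops
print(tf.nn.log_softmax.__module__)  # expected: tensorflow.python.ops.nn_ops
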
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 11ca3ff0d7..f3805c3f2d 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -468,6 +468,129 @@ def relu6(features, name=None):
features = ops.convert_to_tensor(features, name="features")
return gen_nn_ops._relu6(features, name=name)
+def _softmax(logits, compute_op, dim=-1, name=None):
+ """Helper function for softmax and log_softmax.
+
+ It reshapes and transposes the input logits into a 2-D Tensor and then invokes
+ the gen_nn_ops._softmax or gen_nn_ops._log_softmax function. The output is then
+ transposed and reshaped back.
+
+ Args:
+ logits: A non-empty `Tensor`. Must be one of the following types: `half`,
+ `float32`, `float64`.
+ compute_op: Either gen_nn_ops._softmax or gen_nn_ops._log_softmax
+ dim: The dimension along which the softmax is performed. The default is -1,
+ which indicates the last dimension.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor`. Has the same type as `logits`. Same shape as `logits`.
+ Raises:
+ InvalidArgumentError: if `logits` is empty or `dim` is beyond the last
+ dimension of `logits`.
+ """
+ # Helper function to swap dim_index and last_index of logits. last_index must
+ # be logits' last dimension.
+ def _swap_axis(logits, dim_index, last_index):
+ return array_ops.transpose(logits, array_ops.concat(
+ 0, [math_ops.range(dim_index), [last_index],
+ math_ops.range(dim_index + 1, last_index), [dim_index]]))
+
+ # Helper function to flatten logits' outer dimensions and keep its last
+ # dimension.
+ def _flatten_outer_dims(logits):
+ rank = array_ops.rank(logits)
+ last_dim_size = array_ops.slice(
+ array_ops.shape(logits), [math_ops.sub(rank, 1)], [1])
+ return array_ops.reshape(logits, array_ops.concat(0, [[-1], last_dim_size]))
+
+ logits = ops.convert_to_tensor(logits)
+ if logits.get_shape().ndims == 2 and dim == -1:
+ return compute_op(logits, name=name)
+
+ # We need its original shape for shape inference.
+ shape = logits.get_shape()
+
+ # If dim is the last dimension, simply reshape the logits to a matrix and
+ # apply the internal softmax.
+ if dim == -1:
+ input_shape = array_ops.shape(logits)
+ logits = _flatten_outer_dims(logits)
+ output = compute_op(logits, name=name)
+ output = array_ops.reshape(output, input_shape)
+ return output
+
+ # If dim is not the last dimension, we have to do a reshape and transpose so
+ # that we can still perform softmax on its last dimension.
+
+ # Swap logits' dimension of dim and its last dimension.
+ input_rank = array_ops.rank(logits)
+ logits = _swap_axis(logits, dim, math_ops.sub(input_rank, 1))
+ shape_after_swap = array_ops.shape(logits)
+
+ # Reshape logits into a matrix.
+ logits = _flatten_outer_dims(logits)
+
+ # Do the actual softmax on its last dimension.
+ output = compute_op(logits, name=name)
+
+ # Transform back the output tensor.
+ output = array_ops.reshape(output, shape_after_swap)
+ output = _swap_axis(output, dim, math_ops.sub(input_rank, 1))
+
+ # Make shape inference work since reshape and transpose may erase its static
+ # shape.
+ output.set_shape(shape)
+
+ return output
+
+
+def softmax(logits, dim=-1, name=None):
+ """Computes log softmax activations.
+
+ For each batch `i` and class `j` we have
+
+ softmax = exp(logits) / reduce_sum(exp(logits), dim)
+
+ Args:
+ logits: A non-empty `Tensor`. Must be one of the following types: `half`,
+ `float32`, `float64`.
+ dim: The dimension along which the softmax is performed. The default is -1,
+ which indicates the last dimension.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor`. Has the same type as `logits`. Same shape as `logits`.
+ Raises:
+ InvalidArgumentError: if `logits` is empty or `dim` is beyond the last
+ dimension of `logits`.
+ """
+ return _softmax(logits, gen_nn_ops._softmax, dim, name)
+
+
+def log_softmax(logits, dim=-1, name=None):
+ """Computes log softmax activations.
+
+ For each batch `i` and class `j` we have
+
+ logsoftmax = logits - log(reduce_sum(exp(logits), dim))
+
+ Args:
+ logits: A non-empty `Tensor`. Must be one of the following types: `half`,
+ `float32`, `float64`.
+ dim: The dimension along which the softmax is performed. The default is -1,
+ which indicates the last dimension.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor`. Has the same type as `logits`. Same shape as `logits`.
+
+ Raises:
+ InvalidArgumentError: if `logits` is empty or `dim` is beyond the last
+ dimension of `logits`.
+ """
+ return _softmax(logits, gen_nn_ops._log_softmax, dim, name)
+
def softmax_cross_entropy_with_logits(logits, labels, name=None):
"""Computes softmax cross entropy between `logits` and `labels`.
@@ -727,9 +850,11 @@ def _LRNGradShape(op):
return [in_grads_shape.merge_with(in_image_shape).merge_with(out_image_shape)]
-ops.RegisterShape("Softmax")(common_shapes.unchanged_shape_with_rank(2))
+ops.RegisterShape("Softmax")(common_shapes.unchanged_shape_with_rank_at_least(
+ 1))
-ops.RegisterShape("LogSoftmax")(common_shapes.unchanged_shape_with_rank(2))
+ops.RegisterShape("LogSoftmax")(
+ common_shapes.unchanged_shape_with_rank_at_least(1))
@ops.RegisterShape("InTopK")
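
Note: the transpose/reshape strategy documented in _softmax above can be checked outside TensorFlow. The sketch below is a plain-NumPy analogue of _swap_axis and _flatten_outer_dims (a minimal sketch of the idea, not the TensorFlow implementation): softmax along an arbitrary dim reduces to the existing 2-D, last-dimension softmax applied after swapping dim to the end, then undoing the reshape and the swap.

import numpy as np

def softmax_2d(mat):
  # Stand-in for the existing 2-D Softmax kernel: rows are batch entries,
  # columns are classes.
  e = np.exp(mat - np.amax(mat, axis=1, keepdims=True))
  return e / np.sum(e, axis=1, keepdims=True)

def softmax_via_swap(x, dim=-1):
  # Mirrors _softmax above: swap `dim` with the last axis, flatten the outer
  # dimensions into one, run the 2-D kernel, then undo both steps.
  if dim == -1:
    dim = x.ndim - 1
  swapped = np.swapaxes(x, dim, -1)
  flat = swapped.reshape(-1, swapped.shape[-1])
  out = softmax_2d(flat).reshape(swapped.shape)
  return np.swapaxes(out, dim, -1)

def softmax_direct(x, dim):
  # Reference: softmax computed directly along `dim`.
  e = np.exp(x - np.amax(x, axis=dim, keepdims=True))
  return e / np.sum(e, axis=dim, keepdims=True)

x = np.random.rand(3, 2, 4).astype(np.float32)
for dim in range(x.ndim):
  assert np.allclose(softmax_via_swap(x, dim), softmax_direct(x, dim))
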