-rw-r--r--  tensorflow/examples/tutorials/mnist/mnist_with_summaries.py |   2
-rw-r--r--  tensorflow/python/kernel_tests/softmax_op_test.py            |  90
-rw-r--r--  tensorflow/python/ops/hidden_ops.txt                         |   2
-rw-r--r--  tensorflow/python/ops/nn_ops.py                              | 129
4 files changed, 194 insertions, 29 deletions
diff --git a/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py b/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py
index 23492e5122..868cfcc3e4 100644
--- a/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py
+++ b/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py
@@ -100,7 +100,7 @@ def train():
     with tf.name_scope('Wx_plus_b'):
       preactivate = tf.matmul(input_tensor, weights) + biases
       tf.histogram_summary(layer_name + '/pre_activations', preactivate)
-    activations = act(preactivate, 'activation')
+    activations = act(preactivate, name='activation')
     tf.histogram_summary(layer_name + '/activations', activations)
     return activations
 
diff --git a/tensorflow/python/kernel_tests/softmax_op_test.py b/tensorflow/python/kernel_tests/softmax_op_test.py
index 00370d98c3..d4525c730e 100644
--- a/tensorflow/python/kernel_tests/softmax_op_test.py
+++ b/tensorflow/python/kernel_tests/softmax_op_test.py
@@ -26,37 +26,39 @@ import tensorflow as tf
 
 class SoftmaxTest(tf.test.TestCase):
 
-  def _npSoftmax(self, features, log=False):
-    batch_dim = 0
-    class_dim = 1
-    batch_size = features.shape[batch_dim]
-    e = np.exp(features -
-               np.reshape(np.amax(features, axis=class_dim), [batch_size, 1]))
-    softmax = e / np.reshape(np.sum(e, axis=class_dim), [batch_size, 1])
+  def _npSoftmax(self, features, dim=-1, log=False):
+    if dim == -1:
+      dim = len(features.shape) - 1
+    one_only_on_dim = list(features.shape)
+    one_only_on_dim[dim] = 1
+    e = np.exp(features - np.reshape(
+        np.amax(features, axis=dim), one_only_on_dim))
+    softmax = e / np.reshape(np.sum(e, axis=dim), one_only_on_dim)
     if log:
       return np.log(softmax)
     else:
       return softmax
 
-  def _testSoftmax(self, np_features, log=False, use_gpu=False):
+  def _testSoftmax(self, np_features, dim=-1, log=False, use_gpu=False):
     # A previous version of the code checked the op name rather than the op type
     # to distinguish between log and non-log. Use an arbitrary name to catch
     # this bug in future.
     name = "arbitrary"
-    np_softmax = self._npSoftmax(np_features, log=log)
+    np_softmax = self._npSoftmax(np_features, dim=dim, log=log)
     with self.test_session(use_gpu=use_gpu):
       if log:
-        tf_softmax = tf.nn.log_softmax(np_features, name=name)
+        tf_softmax = tf.nn.log_softmax(np_features, dim=dim, name=name)
       else:
-        tf_softmax = tf.nn.softmax(np_features, name=name)
+        tf_softmax = tf.nn.softmax(np_features, dim=dim, name=name)
       out = tf_softmax.eval()
     self.assertAllCloseAccordingToType(np_softmax, out)
     self.assertShapeEqual(np_softmax, tf_softmax)
     if not log:
-      # Bonus check: the softmaxes should add to one in each
-      # batch element.
-      self.assertAllCloseAccordingToType(np.ones(out.shape[0]),
-                                         np.sum(out, axis=1))
+      # Bonus check: the softmaxes should add to one in dimension dim.
+      sum_along_dim = np.sum(out, axis=dim)
+      self.assertAllCloseAccordingToType(
+          np.ones(sum_along_dim.shape), sum_along_dim)
 
   def _testAll(self, features):
     self._testSoftmax(features, use_gpu=False)
@@ -90,17 +92,11 @@ class SoftmaxTest(tf.test.TestCase):
         np_lsm,
         rtol=1.e-5,
         atol=1.e-5)
 
-  def testShapeMismatch(self):
-    with self.assertRaises(ValueError):
-      tf.nn.softmax([0., 1., 2., 3.])
-    with self.assertRaises(ValueError):
-      tf.nn.log_softmax([0., 1., 2., 3.])
-
   def _testOverflow(self, use_gpu=False):
     if use_gpu:
-      type = np.float32
+      type = np.float32
     else:
-      type = np.float64
+      type = np.float64
     max = np.finfo(type).max
     features = np.array(
         [[1., 1., 1., 1.],
@@ -128,13 +124,55 @@ class SoftmaxTest(tf.test.TestCase):
         use_gpu=False)
     self._testOverflow(use_gpu=False)
 
+  def test1DTensorAsInput(self):
+    self._testSoftmax(
+        np.array([3., 2., 3., 9.]).astype(np.float64), use_gpu=False)
+    self._testOverflow(use_gpu=False)
+
+  def test3DTensorAsInput(self):
+    self._testSoftmax(
+        np.array([[[1., 1., 1., 1.], [1., 2., 3., 4.]],
+                  [[2., 3., 4., 5.], [6., 7., 8., 9.]],
+                  [[5., 4., 3., 2.], [1., 2., 3., 4.]]]).astype(np.float32),
+        use_gpu=False)
+    self._testOverflow(use_gpu=False)
+
+  def testAlongFirstDimension(self):
+    self._testSoftmax(
+        np.array([[[1., 1., 1., 1.], [1., 2., 3., 4.]],
+                  [[2., 3., 4., 5.], [6., 7., 8., 9.]],
+                  [[5., 4., 3., 2.], [1., 2., 3., 4.]]]).astype(np.float32),
+        dim=0,
+        use_gpu=False)
+    self._testOverflow(use_gpu=False)
+
+  def testAlongSecondDimension(self):
+    self._testSoftmax(
+        np.array([[[1., 1., 1., 1.], [1., 2., 3., 4.]],
+                  [[2., 3., 4., 5.], [6., 7., 8., 9.]],
+                  [[5., 4., 3., 2.], [1., 2., 3., 4.]]]).astype(np.float32),
+        dim=1,
+        use_gpu=False)
+    self._testOverflow(use_gpu=False)
+
+  def testShapeInference(self):
+    op = tf.nn.softmax([[[1., 1., 1., 1.], [1., 2., 3., 4.]],
+                        [[2., 3., 4., 5.], [6., 7., 8., 9.]],
+                        [[5., 4., 3., 2.], [1., 2., 3., 4.]]])
+    self.assertEqual([3, 2, 4], op.get_shape())
 
-  def testEmpty(self):
+  def testEmptyInput(self):
     with self.test_session():
       x = tf.constant([[]], shape=[0, 3])
       self.assertEqual(0, tf.size(x).eval())
-      expected_y = np.array([]).reshape(0, 3)
-      np.testing.assert_array_equal(expected_y, tf.nn.softmax(x).eval())
+      # reshape would raise if logits is empty
+      with self.assertRaises(tf.errors.InvalidArgumentError):
+        tf.nn.softmax(x, dim=0).eval()
+
+  def testDimTooLarge(self):
+    with self.test_session():
+      with self.assertRaises(tf.errors.InvalidArgumentError):
+        tf.nn.softmax([1., 2., 3., 4.], dim=100).eval()
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt
index c56db8f5d9..3a94b9f4e3 100644
--- a/tensorflow/python/ops/hidden_ops.txt
+++ b/tensorflow/python/ops/hidden_ops.txt
@@ -179,6 +179,8 @@ BiasAddV1
 Relu6
 AvgPool
 MaxPool
+Softmax
+LogSoftmax
 
 # parsing_ops
 ParseExample
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 11ca3ff0d7..f3805c3f2d 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -468,6 +468,129 @@ def relu6(features, name=None):
   features = ops.convert_to_tensor(features, name="features")
   return gen_nn_ops._relu6(features, name=name)
 
+
+def _softmax(logits, compute_op, dim=-1, name=None):
+  """Helper function for softmax and log_softmax.
+
+  It reshapes and transposes the input logits into a 2-D Tensor and then
+  invokes the tf.nn._softmax or tf.nn._log_softmax function. The output is
+  then transposed and reshaped back.
+
+  Args:
+    logits: A non-empty `Tensor`. Must be one of the following types: `half`,
+      `float32`, `float64`.
+    compute_op: Either gen_nn_ops._softmax or gen_nn_ops._log_softmax
+    dim: The dimension softmax would be performed on. The default is -1 which
+      indicates the last dimension.
+    name: A name for the operation (optional).
+
+  Returns:
+    A `Tensor`. Has the same type as `logits`. Same shape as `logits`.
+
+  Raises:
+    InvalidArgumentError: if `logits` is empty or `dim` is beyond the last
+      dimension of `logits`.
+  """
+  # Helper function to swap dim_index and last_index of logits. last_index
+  # must be logits' last dimension.
+  def _swap_axis(logits, dim_index, last_index):
+    return array_ops.transpose(logits, array_ops.concat(
+        0, [math_ops.range(dim_index), [last_index],
+            math_ops.range(dim_index + 1, last_index), [dim_index]]))
+
+  # Helper function to flatten logits' outer dimensions and keep its last
+  # dimension.
+  def _flatten_outer_dims(logits):
+    rank = array_ops.rank(logits)
+    last_dim_size = array_ops.slice(
+        array_ops.shape(logits), [math_ops.sub(rank, 1)], [1])
+    return array_ops.reshape(logits,
+                             array_ops.concat(0, [[-1], last_dim_size]))
+
+  logits = ops.convert_to_tensor(logits)
+  if logits.get_shape().ndims == 2 and dim == -1:
+    return compute_op(logits, name=name)
+
+  # We need its original shape for shape inference.
+  shape = logits.get_shape()
+
+  # If dim is the last dimension, simply reshape the logits to a matrix and
+  # apply the internal softmax.
+  if dim == -1:
+    input_shape = array_ops.shape(logits)
+    logits = _flatten_outer_dims(logits)
+    output = compute_op(logits, name=name)
+    output = array_ops.reshape(output, input_shape)
+    return output
+
+  # If dim is not the last dimension, we have to do a reshape and transpose
+  # so that we can still perform softmax on its last dimension.
+
+  # Swap logits' dimension of dim and its last dimension.
+  input_rank = array_ops.rank(logits)
+  logits = _swap_axis(logits, dim, math_ops.sub(input_rank, 1))
+  shape_after_swap = array_ops.shape(logits)
+
+  # Reshape logits into a matrix.
+  logits = _flatten_outer_dims(logits)
+
+  # Do the actual softmax on its last dimension.
+  output = compute_op(logits, name=name)
+
+  # Transform back the output tensor.
+  output = array_ops.reshape(output, shape_after_swap)
+  output = _swap_axis(output, dim, math_ops.sub(input_rank, 1))
+
+  # Make shape inference work since reshape and transpose may erase its
+  # static shape.
+  output.set_shape(shape)
+
+  return output
+
+
+def softmax(logits, dim=-1, name=None):
+  """Computes softmax activations.
+
+  For each batch `i` and class `j` we have
+
+      softmax = exp(logits) / reduce_sum(exp(logits), dim)
+
+  Args:
+    logits: A non-empty `Tensor`. Must be one of the following types: `half`,
+      `float32`, `float64`.
+    dim: The dimension softmax would be performed on. The default is -1 which
+      indicates the last dimension.
+    name: A name for the operation (optional).
+
+  Returns:
+    A `Tensor`. Has the same type as `logits`. Same shape as `logits`.
+
+  Raises:
+    InvalidArgumentError: if `logits` is empty or `dim` is beyond the last
+      dimension of `logits`.
+  """
+  return _softmax(logits, gen_nn_ops._softmax, dim, name)
+
+
+def log_softmax(logits, dim=-1, name=None):
+  """Computes log softmax activations.
+
+  For each batch `i` and class `j` we have
+
+      logsoftmax = logits - log(reduce_sum(exp(logits), dim))
+
+  Args:
+    logits: A non-empty `Tensor`. Must be one of the following types: `half`,
+      `float32`, `float64`.
+    dim: The dimension softmax would be performed on. The default is -1 which
+      indicates the last dimension.
+    name: A name for the operation (optional).
+
+  Returns:
+    A `Tensor`. Has the same type as `logits`. Same shape as `logits`.
+
+  Raises:
+    InvalidArgumentError: if `logits` is empty or `dim` is beyond the last
+      dimension of `logits`.
+  """
+  return _softmax(logits, gen_nn_ops._log_softmax, dim, name)
+
 
 def softmax_cross_entropy_with_logits(logits, labels, name=None):
   """Computes softmax cross entropy between `logits` and `labels`.
@@ -727,9 +850,11 @@ def _LRNGradShape(op):
   return [in_grads_shape.merge_with(in_image_shape).merge_with(out_image_shape)]
 
-ops.RegisterShape("Softmax")(common_shapes.unchanged_shape_with_rank(2))
+ops.RegisterShape("Softmax")(common_shapes.unchanged_shape_with_rank_at_least(
+    1))
 
-ops.RegisterShape("LogSoftmax")(common_shapes.unchanged_shape_with_rank(2))
+ops.RegisterShape("LogSoftmax")(
+    common_shapes.unchanged_shape_with_rank_at_least(1))
 
 
 @ops.RegisterShape("InTopK")
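
A note on the approach: `_softmax` reduces the general case to the existing rank-2 kernel by swapping `dim` with the last axis, flattening all outer dimensions into a single batch dimension, running the 2-D softmax, and then undoing the reshape and the transpose. A minimal NumPy sketch of the same trick (`softmax_along_dim` is a hypothetical name for illustration, not part of the patch):

```python
import numpy as np

def softmax_along_dim(x, dim=-1):
    # Hypothetical reference implementation of the swap/flatten trick
    # used by _softmax in this patch; not part of the patch itself.
    rank = x.ndim
    if dim < 0:
        dim += rank
    # Swap `dim` with the last axis (what _swap_axis does via transpose).
    x = np.swapaxes(x, dim, rank - 1)
    swapped_shape = x.shape
    # Flatten all outer dimensions into one (what _flatten_outer_dims does).
    x = x.reshape(-1, swapped_shape[-1])
    # Rank-2 softmax kernel, with the usual max-subtraction for stability.
    e = np.exp(x - x.max(axis=1, keepdims=True))
    out = e / e.sum(axis=1, keepdims=True)
    # Undo the reshape, then undo the transpose.
    return np.swapaxes(out.reshape(swapped_shape), dim, rank - 1)

x = np.random.randn(3, 2, 4).astype(np.float32)
# Softmax along dim sums to one along that dim, as the new tests assert.
assert np.allclose(softmax_along_dim(x, dim=1).sum(axis=1), 1.0, atol=1e-6)
```

NumPy's `swapaxes` is equivalent to the explicit permutation that `_swap_axis` builds; the patch constructs the permutation vector with `array_ops.concat` because the rank may only be known at graph run time.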
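And a quick usage sketch of the new `dim` argument, checking the documented formulas numerically (a sketch assuming the TF 0.x-era API in this tree, e.g. `tf.Session`):

```python
import numpy as np
import tensorflow as tf

logits = np.array([[1., 2., 3., 4.],
                   [2., 2., 2., 2.]], dtype=np.float32)

with tf.Session() as sess:
    # Default dim=-1: softmax over the class (last) dimension.
    sm = sess.run(tf.nn.softmax(logits))
    # New in this change: softmax down the batch dimension instead.
    sm0 = sess.run(tf.nn.softmax(logits, dim=0))
    lsm = sess.run(tf.nn.log_softmax(logits))

print(sm.sum(axis=1))                # ~[1. 1.]
print(sm0.sum(axis=0))               # ~[1. 1. 1. 1.]
print(np.allclose(lsm, np.log(sm)))  # True: log_softmax == log(softmax)
```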