diff options
author | Martin Wicke <wicke@google.com> | 2016-03-17 13:06:04 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-03-18 08:45:26 -0700 |
commit | 5cf9c5a5a6875b02b30166c13b1d5839fff18f43 (patch) | |
tree | c1eef926361e65e2b05ac923f793c2fe390bddb9 | |
parent | dba72e0338ad7a18ccc0ba2d99cfd73c6ec13fc0 (diff) |
Added check for 0 length input before it can get to Eigen.
Change: 117482953
-rw-r--r-- | tensorflow/contrib/layers/python/layers/layers_test.py | 12 | ||||
-rw-r--r-- | tensorflow/core/kernels/softmax_op.h | 8 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/softmax_op_test.py | 7 |
3 files changed, 24 insertions, 3 deletions
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index 2c024b7bce..68200db076 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -211,6 +211,18 @@ class FullyConnectedTest(tf.test.TestCase): tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) self.assertEqual(1, cnt[0]) + def test_empty_x_results_in_empty_output(self): + # Empty x is common if someone masks their input with tf.boolean_mask in + # order to drop missing entries, and in a particular batch all entries are + # missing. + with self.test_session(): + x = tf.constant([[]], shape=[0, 3]) + self.assertEqual(0, tf.size(x).eval()) + y = tf.contrib.layers.fully_connected(x, 2, activation_fn=tf.nn.softmax) + tf.initialize_all_variables().run() + expected_y = np.array([]).reshape(0,2) + np.testing.assert_array_equal(expected_y, y.eval()) + class Convolution2dTest(tf.test.TestCase): diff --git a/tensorflow/core/kernels/softmax_op.h b/tensorflow/core/kernels/softmax_op.h index 480df0816b..e5e8c584fc 100644 --- a/tensorflow/core/kernels/softmax_op.h +++ b/tensorflow/core/kernels/softmax_op.h @@ -40,9 +40,11 @@ class SoftmaxOp : public OpKernel { Tensor* softmax_out = nullptr; OP_REQUIRES_OK( context, context->allocate_output(0, logits_in.shape(), &softmax_out)); - functor::SoftmaxFunctor<Device, T> functor; - functor(context->eigen_device<Device>(), logits_in.matrix<T>(), - softmax_out->matrix<T>()); + if (logits_in.NumElements()) { + functor::SoftmaxFunctor<Device, T> functor; + functor(context->eigen_device<Device>(), logits_in.matrix<T>(), + softmax_out->matrix<T>()); + } } }; diff --git a/tensorflow/python/kernel_tests/softmax_op_test.py b/tensorflow/python/kernel_tests/softmax_op_test.py index 00d22a0f8a..8b5447d8d3 100644 --- a/tensorflow/python/kernel_tests/softmax_op_test.py +++ b/tensorflow/python/kernel_tests/softmax_op_test.py @@ -77,6 +77,13 @@ class SoftmaxTest(tf.test.TestCase): np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float64), use_gpu=False) + def testEmpty(self): + with self.test_session(): + x = tf.constant([[]], shape=[0, 3]) + self.assertEqual(0, tf.size(x).eval()) + expected_y = np.array([]).reshape(0, 3) + np.testing.assert_array_equal(expected_y, tf.nn.softmax(x).eval()) + if __name__ == "__main__": tf.test.main() |