about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
authorGravatar Martin Wicke <wicke@google.com>2016-03-17 13:06:04 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-03-18 08:45:26 -0700
commit5cf9c5a5a6875b02b30166c13b1d5839fff18f43 (patch)
treec1eef926361e65e2b05ac923f793c2fe390bddb9
parentdba72e0338ad7a18ccc0ba2d99cfd73c6ec13fc0 (diff)
Added check for 0 length input before it can get to Eigen.
Change: 117482953
-rw-r--r--tensorflow/contrib/layers/python/layers/layers_test.py12
-rw-r--r--tensorflow/core/kernels/softmax_op.h8
-rw-r--r--tensorflow/python/kernel_tests/softmax_op_test.py7
3 files changed, 24 insertions, 3 deletions
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index 2c024b7bce..68200db076 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -211,6 +211,18 @@ class FullyConnectedTest(tf.test.TestCase):
tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
self.assertEqual(1, cnt[0])
+ def test_empty_x_results_in_empty_output(self):
+ # Empty x is common if someone masks their input with tf.boolean_mask in
+ # order to drop missing entries, and in a particular batch all entries are
+ # missing.
+ with self.test_session():
+ x = tf.constant([[]], shape=[0, 3])
+ self.assertEqual(0, tf.size(x).eval())
+ y = tf.contrib.layers.fully_connected(x, 2, activation_fn=tf.nn.softmax)
+ tf.initialize_all_variables().run()
+ expected_y = np.array([]).reshape(0,2)
+ np.testing.assert_array_equal(expected_y, y.eval())
+
class Convolution2dTest(tf.test.TestCase):
diff --git a/tensorflow/core/kernels/softmax_op.h b/tensorflow/core/kernels/softmax_op.h
index 480df0816b..e5e8c584fc 100644
--- a/tensorflow/core/kernels/softmax_op.h
+++ b/tensorflow/core/kernels/softmax_op.h
@@ -40,9 +40,11 @@ class SoftmaxOp : public OpKernel {
Tensor* softmax_out = nullptr;
OP_REQUIRES_OK(
context, context->allocate_output(0, logits_in.shape(), &softmax_out));
- functor::SoftmaxFunctor<Device, T> functor;
- functor(context->eigen_device<Device>(), logits_in.matrix<T>(),
- softmax_out->matrix<T>());
+ if (logits_in.NumElements()) {
+ functor::SoftmaxFunctor<Device, T> functor;
+ functor(context->eigen_device<Device>(), logits_in.matrix<T>(),
+ softmax_out->matrix<T>());
+ }
}
};
diff --git a/tensorflow/python/kernel_tests/softmax_op_test.py b/tensorflow/python/kernel_tests/softmax_op_test.py
index 00d22a0f8a..8b5447d8d3 100644
--- a/tensorflow/python/kernel_tests/softmax_op_test.py
+++ b/tensorflow/python/kernel_tests/softmax_op_test.py
@@ -77,6 +77,13 @@ class SoftmaxTest(tf.test.TestCase):
np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float64),
use_gpu=False)
+ def testEmpty(self):
+ with self.test_session():
+ x = tf.constant([[]], shape=[0, 3])
+ self.assertEqual(0, tf.size(x).eval())
+ expected_y = np.array([]).reshape(0, 3)
+ np.testing.assert_array_equal(expected_y, tf.nn.softmax(x).eval())
+
if __name__ == "__main__":
tf.test.main()