diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-02-20 13:57:01 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-20 14:00:54 -0800 |
commit | 0fad3428f4de84e10524a7ed5ed53b7e4b636edb (patch) | |
tree | 718b5b1fe2133a473493d80b5264827e1057891b /tensorflow/contrib/lite/kernels/activations_test.cc | |
parent | fdeab946c0c8146c8040d7e125e5ca9e41b0336a (diff) |
Basic LogSoftmax support
PiperOrigin-RevId: 186357933
Diffstat (limited to 'tensorflow/contrib/lite/kernels/activations_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/activations_test.cc | 41 |
1 files changed, 41 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc index 68d49944e5..302e52b96d 100644 --- a/tensorflow/contrib/lite/kernels/activations_test.cc +++ b/tensorflow/contrib/lite/kernels/activations_test.cc @@ -313,6 +313,47 @@ TEST(QuantizedActivationsOpTest, Softmax2D) { kQuantizedTolerance))); } +// This contains the same test values as the Softmax test, but reference answer +// generated via the following snippet of python: +// logits1 = tf.constant([[0, -6, 2, 4],[3, -2, 10, 1]], dtype=tf.float32) +// logits2 = tf.constant([[0,-6],[2,4],[3,-2],[10,1]], dtype=tf.float32) +// lsm1 = tf.nn.log_softmax(logits1) +// lsm2 = tf.nn.log_softmax(logits2) +// with tf.Session() as sess: +// print('lsm1', sess.run(lsm1)) +// print('lsm2', sess.run(lsm2)) + +TEST(FloatActivationsOpTest, LogSoftmax) { + FloatActivationsOpModel m(BuiltinOperator_LOG_SOFTMAX, + /*input=*/{TensorType_FLOAT32, {2, 4}}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + -4.14297, -10.14297, -2.14297, -.142971, // + -7.00104, -12.00104, -.00104087, -9.00104, // + }))); + + // Same input, but a different shape. + FloatActivationsOpModel m2(BuiltinOperator_LOG_SOFTMAX, + /*input=*/{TensorType_FLOAT32, {4, 2}}); + m2.SetInput({ + 0, -6, // + 2, 4, // + 3, -2, // + 10, 1, // + }); + m2.Invoke(); + EXPECT_THAT(m2.GetOutput(), ElementsAreArray(ArrayFloatNear({ + -.00247565, -6.00247, // + -2.12692, -.126928, // + -.00671534, -5.00671, // + -.000123374, -9.00012, // + }))); +} + } // namespace } // namespace tflite |