aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/activations_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-02-20 13:57:01 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-20 14:00:54 -0800
commit0fad3428f4de84e10524a7ed5ed53b7e4b636edb (patch)
tree718b5b1fe2133a473493d80b5264827e1057891b /tensorflow/contrib/lite/kernels/activations_test.cc
parentfdeab946c0c8146c8040d7e125e5ca9e41b0336a (diff)
Basic LogSoftmax support
PiperOrigin-RevId: 186357933
Diffstat (limited to 'tensorflow/contrib/lite/kernels/activations_test.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/activations_test.cc41
1 files changed, 41 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc
index 68d49944e5..302e52b96d 100644
--- a/tensorflow/contrib/lite/kernels/activations_test.cc
+++ b/tensorflow/contrib/lite/kernels/activations_test.cc
@@ -313,6 +313,47 @@ TEST(QuantizedActivationsOpTest, Softmax2D) {
kQuantizedTolerance)));
}
+// This contains the same test values as the Softmax test, but reference answer
+// generated via the following snippet of python:
+// logits1 = tf.constant([[0, -6, 2, 4],[3, -2, 10, 1]], dtype=tf.float32)
+// logits2 = tf.constant([[0,-6],[2,4],[3,-2],[10,1]], dtype=tf.float32)
+// lsm1 = tf.nn.log_softmax(logits1)
+// lsm2 = tf.nn.log_softmax(logits2)
+// with tf.Session() as sess:
+// print('lsm1', sess.run(lsm1))
+// print('lsm2', sess.run(lsm2))
+
+TEST(FloatActivationsOpTest, LogSoftmax) {
+ FloatActivationsOpModel m(BuiltinOperator_LOG_SOFTMAX,
+ /*input=*/{TensorType_FLOAT32, {2, 4}});
+ m.SetInput({
+ 0, -6, 2, 4, //
+ 3, -2, 10, 1, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ -4.14297, -10.14297, -2.14297, -.142971, //
+ -7.00104, -12.00104, -.00104087, -9.00104, //
+ })));
+
+ // Same input, but a different shape.
+ FloatActivationsOpModel m2(BuiltinOperator_LOG_SOFTMAX,
+ /*input=*/{TensorType_FLOAT32, {4, 2}});
+ m2.SetInput({
+ 0, -6, //
+ 2, 4, //
+ 3, -2, //
+ 10, 1, //
+ });
+ m2.Invoke();
+ EXPECT_THAT(m2.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ -.00247565, -6.00247, //
+ -2.12692, -.126928, //
+ -.00671534, -5.00671, //
+ -.000123374, -9.00012, //
+ })));
+}
+
} // namespace
} // namespace tflite