aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/activations_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/kernels/activations_test.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/activations_test.cc23
1 files changed, 23 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc
index 587e1303da..083cdf78d7 100644
--- a/tensorflow/contrib/lite/kernels/activations_test.cc
+++ b/tensorflow/contrib/lite/kernels/activations_test.cc
@@ -339,6 +339,29 @@ TEST(QuantizedActivationsOpTest, Softmax4D) {
kQuantizedTolerance)));
}
+TEST(FloatActivationsOpTest, Softmax1D) {
+ FloatActivationsOpModel m(0.1,
+ /*input=*/{TensorType_FLOAT32, {8}});
+ m.SetInput({0, -6, 2, 4, 3, -2, 10, 1});
+ m.Invoke();
+ EXPECT_THAT(
+ m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {.09752, .05352, .11911, .14548, .13164, .07984, .26509, .10778})));
+}
+
+TEST(QuantizedActivationsOpTest, Softmax1D) {
+ QuantizedActivationsOpModel m(0.1,
+ /*input=*/{TensorType_UINT8, {8}, -10, 10});
+ m.SetInput<uint8_t>({0, -6, 2, 4, 3, -2, 10, 1});
+ m.Invoke();
+ EXPECT_THAT(
+ m.GetDequantizedOutput<uint8_t>(),
+ ElementsAreArray(ArrayFloatNear({0.09766, 0.05469, 0.12109, 0.14453,
+ 0.13281, 0.07813, 0.26563, 0.10938},
+ kQuantizedTolerance)));
+}
+
TEST(FloatActivationsOpTest, Softmax2D) {
FloatActivationsOpModel m(0.1,
/*input=*/{TensorType_FLOAT32, {2, 4}});