aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/activations_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-20 16:23:13 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-20 16:25:58 -0700
commit41781bad97698c29cd74203cef465d2adb2f04e8 (patch)
tree334a945f0b3f1bc11330fe7c32ec2e9b38afa750 /tensorflow/contrib/lite/kernels/activations_test.cc
parentbf020cb3160345a30f0551ffbd6c507e33753a1e (diff)
Add support for computing Softmax activation over tensors of rank 1.
PiperOrigin-RevId: 205470922
Diffstat (limited to 'tensorflow/contrib/lite/kernels/activations_test.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/activations_test.cc23
1 files changed, 23 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc
index 587e1303da..083cdf78d7 100644
--- a/tensorflow/contrib/lite/kernels/activations_test.cc
+++ b/tensorflow/contrib/lite/kernels/activations_test.cc
@@ -339,6 +339,29 @@ TEST(QuantizedActivationsOpTest, Softmax4D) {
kQuantizedTolerance)));
}
+TEST(FloatActivationsOpTest, Softmax1D) {
+ FloatActivationsOpModel m(0.1,
+ /*input=*/{TensorType_FLOAT32, {8}});
+ m.SetInput({0, -6, 2, 4, 3, -2, 10, 1});
+ m.Invoke();
+ EXPECT_THAT(
+ m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {.09752, .05352, .11911, .14548, .13164, .07984, .26509, .10778})));
+}
+
+TEST(QuantizedActivationsOpTest, Softmax1D) {
+ QuantizedActivationsOpModel m(0.1,
+ /*input=*/{TensorType_UINT8, {8}, -10, 10});
+ m.SetInput<uint8_t>({0, -6, 2, 4, 3, -2, 10, 1});
+ m.Invoke();
+ EXPECT_THAT(
+ m.GetDequantizedOutput<uint8_t>(),
+ ElementsAreArray(ArrayFloatNear({0.09766, 0.05469, 0.12109, 0.14453,
+ 0.13281, 0.07813, 0.26563, 0.10938},
+ kQuantizedTolerance)));
+}
+
TEST(FloatActivationsOpTest, Softmax2D) {
FloatActivationsOpModel m(0.1,
/*input=*/{TensorType_FLOAT32, {2, 4}});