aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/activations_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-04 14:28:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-04 14:39:10 -0700
commit9bea7a8aa991b63f7349514a5a2dc0d04d261f8f (patch)
treef5fd4c370554cbd7d6b8989ebd96a322a2010691 /tensorflow/contrib/lite/kernels/activations_test.cc
parented643f5522774d8dcb98530cf241e94a86ae88c2 (diff)
Add support for Softmax of 3D tensors
PiperOrigin-RevId: 211524810
Diffstat (limited to 'tensorflow/contrib/lite/kernels/activations_test.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/activations_test.cc70
1 files changed, 70 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc
index e577e3a762..9fa47e190a 100644
--- a/tensorflow/contrib/lite/kernels/activations_test.cc
+++ b/tensorflow/contrib/lite/kernels/activations_test.cc
@@ -339,6 +339,76 @@ TEST(QuantizedActivationsOpTest, Softmax4D) {
kQuantizedTolerance)));
}
+TEST(FloatActivationsOpTest, Softmax3D) {
+ FloatActivationsOpModel m(0.1,
+ /*input=*/{TensorType_FLOAT32, {1, 2, 4}});
+ m.SetInput({
+ 0, -6, 2, 4, // depth = 0
+ 3, -2, 10, 1, // depth = 1
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ .23463, .12877, .28658, .35003, //
+ .22528, .13664, .45365, .18443, //
+ })));
+
+ // Same input, but a different shape.
+ FloatActivationsOpModel m2(0.1,
+ /*input=*/{TensorType_FLOAT32, {4, 1, 2}});
+ m2.SetInput({
+ 0, -6, //
+ 2, 4, //
+ 3, -2, //
+ 10, 1, //
+ });
+ m2.Invoke();
+ EXPECT_THAT(m2.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ 0.645656, 0.354344, //
+ 0.450166, 0.549834, //
+ 0.622459, 0.377541, //
+ 0.710949, 0.28905, //
+ })));
+}
+
+TEST(QuantizedActivationsOpTest, Softmax3D) {
+ QuantizedActivationsOpModel m(
+ 0.1,
+ /*input=*/{TensorType_UINT8, {1, 2, 4}, -10, 10});
+ m.SetInput<uint8_t>({
+ 0, -6, 2, 4, // depth = 0
+ 3, -2, 10, 1, // depth = 1
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ .23463, .12877, .28658, .35003, //
+ .22528, .13664, .45365, .18443, //
+ },
+ kQuantizedTolerance)));
+
+ // Same input, but a different shape.
+ QuantizedActivationsOpModel m2(
+ 0.1,
+ /*input=*/{TensorType_UINT8, {4, 1, 2}, -10, 10});
+ m2.SetInput<uint8_t>({
+ 0, -6, //
+ 2, 4, //
+ 3, -2, //
+ 10, 1, //
+ });
+ m2.Invoke();
+ EXPECT_THAT(m2.GetDequantizedOutput<uint8_t>(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 0.645656, 0.354344, //
+ 0.450166, 0.549834, //
+ 0.622459, 0.377541, //
+ 0.710949, 0.28905, //
+ },
+ kQuantizedTolerance)));
+}
+
TEST(FloatActivationsOpTest, Softmax1D) {
FloatActivationsOpModel m(0.1,
/*input=*/{TensorType_FLOAT32, {8}});