diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-04 14:28:37 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-04 14:39:10 -0700 |
commit | 9bea7a8aa991b63f7349514a5a2dc0d04d261f8f (patch) | |
tree | f5fd4c370554cbd7d6b8989ebd96a322a2010691 /tensorflow/contrib/lite/kernels/activations_test.cc | |
parent | ed643f5522774d8dcb98530cf241e94a86ae88c2 (diff) |
Add support for Softmax of 3D tensors
PiperOrigin-RevId: 211524810
Diffstat (limited to 'tensorflow/contrib/lite/kernels/activations_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/activations_test.cc | 70 |
1 files changed, 70 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc index e577e3a762..9fa47e190a 100644 --- a/tensorflow/contrib/lite/kernels/activations_test.cc +++ b/tensorflow/contrib/lite/kernels/activations_test.cc @@ -339,6 +339,76 @@ TEST(QuantizedActivationsOpTest, Softmax4D) { kQuantizedTolerance))); } +TEST(FloatActivationsOpTest, Softmax3D) { + FloatActivationsOpModel m(0.1, + /*input=*/{TensorType_FLOAT32, {1, 2, 4}}); + m.SetInput({ + 0, -6, 2, 4, // depth = 0 + 3, -2, 10, 1, // depth = 1 + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + .23463, .12877, .28658, .35003, // + .22528, .13664, .45365, .18443, // + }))); + + // Same input, but a different shape. + FloatActivationsOpModel m2(0.1, + /*input=*/{TensorType_FLOAT32, {4, 1, 2}}); + m2.SetInput({ + 0, -6, // + 2, 4, // + 3, -2, // + 10, 1, // + }); + m2.Invoke(); + EXPECT_THAT(m2.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }))); +} + +TEST(QuantizedActivationsOpTest, Softmax3D) { + QuantizedActivationsOpModel m( + 0.1, + /*input=*/{TensorType_UINT8, {1, 2, 4}, -10, 10}); + m.SetInput<uint8_t>({ + 0, -6, 2, 4, // depth = 0 + 3, -2, 10, 1, // depth = 1 + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(), + ElementsAreArray(ArrayFloatNear( + { + .23463, .12877, .28658, .35003, // + .22528, .13664, .45365, .18443, // + }, + kQuantizedTolerance))); + + // Same input, but a different shape. + QuantizedActivationsOpModel m2( + 0.1, + /*input=*/{TensorType_UINT8, {4, 1, 2}, -10, 10}); + m2.SetInput<uint8_t>({ + 0, -6, // + 2, 4, // + 3, -2, // + 10, 1, // + }); + m2.Invoke(); + EXPECT_THAT(m2.GetDequantizedOutput<uint8_t>(), + ElementsAreArray(ArrayFloatNear( + { + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }, + kQuantizedTolerance))); +} + TEST(FloatActivationsOpTest, Softmax1D) { FloatActivationsOpModel m(0.1, /*input=*/{TensorType_FLOAT32, {8}}); |