diff options
author | 2018-01-22 16:41:37 -0800 | |
---|---|---|
committer | 2018-01-22 16:45:18 -0800 | |
commit | 2e1dfe4a288fe1258f2f497ae6a5874eff127f82 (patch) | |
tree | 0bb9d295d081191af91cccbd33b8d3b2219e10a3 /tensorflow/contrib/lite/kernels/squeeze_test.cc | |
parent | 9fc2c8ebc8adcbca10c9850a54f913f5e731429f (diff) |
Adds unit tests for squeeze, to test the case when all dims are 1.
PiperOrigin-RevId: 182857374
Diffstat (limited to 'tensorflow/contrib/lite/kernels/squeeze_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/squeeze_test.cc | 11 |
1 files changed, 11 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/squeeze_test.cc b/tensorflow/contrib/lite/kernels/squeeze_test.cc index 409227b626..a8aab88357 100644 --- a/tensorflow/contrib/lite/kernels/squeeze_test.cc +++ b/tensorflow/contrib/lite/kernels/squeeze_test.cc @@ -22,6 +22,7 @@ namespace tflite { namespace { using ::testing::ElementsAreArray; +using ::testing::IsEmpty; class BaseSqueezeOpModel : public SingleOpModel { public: @@ -103,6 +104,16 @@ TEST(FloatSqueezeOpTest, SqueezeNegativeAxis) { 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0})); } +TEST(FloatSqueezeOpTest, SqueezeAllDims) { + std::initializer_list<float> data = {3.85}; + FloatSqueezeOpModel m({TensorType_FLOAT32, {1, 1, 1, 1, 1, 1, 1}}, + {TensorType_FLOAT32, {1}}, {}); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), IsEmpty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3.85})); +} + } // namespace } // namespace tflite |