aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/squeeze_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-22 16:41:37 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-22 16:45:18 -0800
commit2e1dfe4a288fe1258f2f497ae6a5874eff127f82 (patch)
tree0bb9d295d081191af91cccbd33b8d3b2219e10a3 /tensorflow/contrib/lite/kernels/squeeze_test.cc
parent9fc2c8ebc8adcbca10c9850a54f913f5e731429f (diff)
Adds unit tests for squeeze, to test the case when all dims are 1.
PiperOrigin-RevId: 182857374
Diffstat (limited to 'tensorflow/contrib/lite/kernels/squeeze_test.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/squeeze_test.cc11
1 files changed, 11 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/squeeze_test.cc b/tensorflow/contrib/lite/kernels/squeeze_test.cc
index 409227b626..a8aab88357 100644
--- a/tensorflow/contrib/lite/kernels/squeeze_test.cc
+++ b/tensorflow/contrib/lite/kernels/squeeze_test.cc
@@ -22,6 +22,7 @@ namespace tflite {
namespace {
using ::testing::ElementsAreArray;
+using ::testing::IsEmpty;
class BaseSqueezeOpModel : public SingleOpModel {
public:
@@ -103,6 +104,16 @@ TEST(FloatSqueezeOpTest, SqueezeNegativeAxis) {
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}));
}
+TEST(FloatSqueezeOpTest, SqueezeAllDims) {
+ std::initializer_list<float> data = {3.85};
+ FloatSqueezeOpModel m({TensorType_FLOAT32, {1, 1, 1, 1, 1, 1, 1}},
+ {TensorType_FLOAT32, {1}}, {});
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), IsEmpty());
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({3.85}));
+}
+
} // namespace
} // namespace tflite