diff options
Diffstat (limited to 'tensorflow/contrib/lite/kernels/gather_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/gather_test.cc | 27 |
1 files changed, 25 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/kernels/gather_test.cc b/tensorflow/contrib/lite/kernels/gather_test.cc index 6343d3b4ef..658d977b8d 100644 --- a/tensorflow/contrib/lite/kernels/gather_test.cc +++ b/tensorflow/contrib/lite/kernels/gather_test.cc @@ -48,8 +48,8 @@ class GatherOpModel : public SingleOpModel { PopulateStringTensor(input_, data); } - void SetPositions(std::initializer_list<int32> data) { - PopulateTensor<int32>(positions_, data); + void SetPositions(std::initializer_list<int> data) { + PopulateTensor<int>(positions_, data); } std::vector<float> GetOutputFloat() { return ExtractVector<float>(output_); } @@ -76,6 +76,29 @@ TEST(GatherOpTest, Shuffle) { ElementsAreArray(ArrayFloatNear({0.7, 0.8, -2, 0.2}))); } +TEST(GatherOpTest, Test0DIndex) { + GatherOpModel m({2, 2}, TensorType_FLOAT32, {}); + m.SetInputFloat({-2.0, 0.2, 0.7, 0.8}); + m.SetPositions({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputFloat(), + ElementsAreArray(ArrayFloatNear({0.7, 0.8}))); + EXPECT_THAT(m.GetOutputShape(), + ElementsAreArray({2})); +} + +TEST(GatherOpTest, Test0DIndexWith0DResult) { + // 0D tensor is special case in current TFLite. Test it once to make sure + // existing workarounds are fine with it. + GatherOpModel m({3}, TensorType_FLOAT32, {}); + m.SetInputFloat({1.0, 2.0, 3.0}); + m.SetPositions({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputFloat(), + ElementsAreArray(ArrayFloatNear({2.0}))); + EXPECT_TRUE(m.GetOutputShape().empty()); +} + TEST(FloatGatherOpTest, Duplicate) { GatherOpModel m({1, 2, 2}, TensorType_FLOAT32, {2}); m.SetInputFloat({-2.0, 0.2, 0.7, 0.8}); |