aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/gather_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/kernels/gather_test.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/gather_test.cc27
1 files changed, 25 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/kernels/gather_test.cc b/tensorflow/contrib/lite/kernels/gather_test.cc
index 6343d3b4ef..658d977b8d 100644
--- a/tensorflow/contrib/lite/kernels/gather_test.cc
+++ b/tensorflow/contrib/lite/kernels/gather_test.cc
@@ -48,8 +48,8 @@ class GatherOpModel : public SingleOpModel {
PopulateStringTensor(input_, data);
}
- void SetPositions(std::initializer_list<int32> data) {
- PopulateTensor<int32>(positions_, data);
+ void SetPositions(std::initializer_list<int> data) {
+ PopulateTensor<int>(positions_, data);
}
std::vector<float> GetOutputFloat() { return ExtractVector<float>(output_); }
@@ -76,6 +76,29 @@ TEST(GatherOpTest, Shuffle) {
ElementsAreArray(ArrayFloatNear({0.7, 0.8, -2, 0.2})));
}
+TEST(GatherOpTest, Test0DIndex) {
+ GatherOpModel m({2, 2}, TensorType_FLOAT32, {});
+ m.SetInputFloat({-2.0, 0.2, 0.7, 0.8});
+ m.SetPositions({1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputFloat(),
+ ElementsAreArray(ArrayFloatNear({0.7, 0.8})));
+ EXPECT_THAT(m.GetOutputShape(),
+ ElementsAreArray({2}));
+}
+
+TEST(GatherOpTest, Test0DIndexWith0DResult) {
+ // 0D tensor is special case in current TFLite. Test it once to make sure
+ // existing workarounds are fine with it.
+ GatherOpModel m({3}, TensorType_FLOAT32, {});
+ m.SetInputFloat({1.0, 2.0, 3.0});
+ m.SetPositions({1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputFloat(),
+ ElementsAreArray(ArrayFloatNear({2.0})));
+ EXPECT_TRUE(m.GetOutputShape().empty());
+}
+
TEST(FloatGatherOpTest, Duplicate) {
GatherOpModel m({1, 2, 2}, TensorType_FLOAT32, {2});
m.SetInputFloat({-2.0, 0.2, 0.7, 0.8});