diff options
author | Jianwei Xie <xiejw@google.com> | 2018-01-24 10:02:35 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-24 10:06:06 -0800 |
commit | d9f93c42a50b1f1401d9c186eac0ae8dc9093c3b (patch) | |
tree | 178d1a692f56580c266139642b5a1d0d155c477e /tensorflow/contrib/lite/kernels/gather_test.cc | |
parent | 7b62a71e2d46c148df7d5704972f4592bc5e0f1b (diff) |
Merge changes from github.
PiperOrigin-RevId: 183100142
Diffstat (limited to 'tensorflow/contrib/lite/kernels/gather_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/gather_test.cc | 27 |
1 files changed, 25 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/kernels/gather_test.cc b/tensorflow/contrib/lite/kernels/gather_test.cc index 6343d3b4ef..658d977b8d 100644 --- a/tensorflow/contrib/lite/kernels/gather_test.cc +++ b/tensorflow/contrib/lite/kernels/gather_test.cc @@ -48,8 +48,8 @@ class GatherOpModel : public SingleOpModel { PopulateStringTensor(input_, data); } - void SetPositions(std::initializer_list<int32> data) { - PopulateTensor<int32>(positions_, data); + void SetPositions(std::initializer_list<int> data) { + PopulateTensor<int>(positions_, data); } std::vector<float> GetOutputFloat() { return ExtractVector<float>(output_); } @@ -76,6 +76,29 @@ TEST(GatherOpTest, Shuffle) { ElementsAreArray(ArrayFloatNear({0.7, 0.8, -2, 0.2}))); } +TEST(GatherOpTest, Test0DIndex) { + GatherOpModel m({2, 2}, TensorType_FLOAT32, {}); + m.SetInputFloat({-2.0, 0.2, 0.7, 0.8}); + m.SetPositions({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputFloat(), + ElementsAreArray(ArrayFloatNear({0.7, 0.8}))); + EXPECT_THAT(m.GetOutputShape(), + ElementsAreArray({2})); +} + +TEST(GatherOpTest, Test0DIndexWith0DResult) { + // 0D tensor is special case in current TFLite. Test it once to make sure + // existing workarounds are fine with it. + GatherOpModel m({3}, TensorType_FLOAT32, {}); + m.SetInputFloat({1.0, 2.0, 3.0}); + m.SetPositions({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputFloat(), + ElementsAreArray(ArrayFloatNear({2.0}))); + EXPECT_TRUE(m.GetOutputShape().empty()); +} + TEST(FloatGatherOpTest, Duplicate) { GatherOpModel m({1, 2, 2}, TensorType_FLOAT32, {2}); m.SetInputFloat({-2.0, 0.2, 0.7, 0.8}); |