aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/gather_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-02 15:58:54 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-02 16:02:26 -0700
commit0a75d72566b98ee0f91cdf4cb29efb918f46f4da (patch)
tree2f7c53aeb1de4c50b5d8c73dd36aad8eb990f5f4 /tensorflow/contrib/lite/kernels/gather_test.cc
parent6a948cb2f1ea06270b7fd223ba4c8501b831f838 (diff)
Remove reduntant checks in gather op.
PiperOrigin-RevId: 203027634
Diffstat (limited to 'tensorflow/contrib/lite/kernels/gather_test.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/gather_test.cc9
1 files changed, 9 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/gather_test.cc b/tensorflow/contrib/lite/kernels/gather_test.cc
index cdadbeda18..1d4292955c 100644
--- a/tensorflow/contrib/lite/kernels/gather_test.cc
+++ b/tensorflow/contrib/lite/kernels/gather_test.cc
@@ -96,6 +96,15 @@ TEST(GatherOpTest, Test0DIndexWith0DResult) {
EXPECT_TRUE(m.GetOutputShape().empty());
}
+TEST(GatherOpTest, Test2DIndexWith2DResult) {
+ GatherOpModel m({3}, TensorType_FLOAT32, {1, 2});
+ m.SetInputFloat({1.0, 2.0, 3.0});
+ m.SetPositions({1, 0});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputFloat(), ElementsAreArray(ArrayFloatNear({2.0, 1.0})));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
+}
+
TEST(FloatGatherOpTest, Duplicate) {
GatherOpModel m({1, 2, 2}, TensorType_FLOAT32, {2});
m.SetInputFloat({-2.0, 0.2, 0.7, 0.8});