aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/gather.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/kernels/gather.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/gather.cc7
1 files changed, 4 insertions, 3 deletions
diff --git a/tensorflow/contrib/lite/kernels/gather.cc b/tensorflow/contrib/lite/kernels/gather.cc
index f8df797daf..0e4187d1ea 100644
--- a/tensorflow/contrib/lite/kernels/gather.cc
+++ b/tensorflow/contrib/lite/kernels/gather.cc
@@ -42,9 +42,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, positions->type, kTfLiteInt32);
// Check that input and output types match.
TF_LITE_ENSURE_EQ(context, input->type, output->type);
- // TODO(mgubin): only 1D positions are currently supported.
- TF_LITE_ENSURE_EQ(context, NumDimensions(positions), 1);
+ // TODO(mgubin): only 0D or 1D positions are currently supported.
+ TF_LITE_ENSURE(context, NumDimensions(positions) <= 1);
// TODO(mgubin): Only default axis == 0 is supported.
+ TF_LITE_ENSURE_EQ(context, params->axis, 0);
// Check conditions for different types.
switch (input->type) {
case kTfLiteFloat32:
@@ -64,7 +65,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
const int num_dimensions =
NumDimensions(input) + NumDimensions(positions) - 1;
- TF_LITE_ENSURE(context, params->axis < num_dimensions);
+ TF_LITE_ENSURE(context, params->axis <= num_dimensions);
TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions);
int output_index = 0;
for (int i = 0; i < params->axis; ++i) {