diff options
Diffstat (limited to 'tensorflow/contrib/lite/kernels/gather.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/gather.cc | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/tensorflow/contrib/lite/kernels/gather.cc b/tensorflow/contrib/lite/kernels/gather.cc index f8df797daf..0e4187d1ea 100644 --- a/tensorflow/contrib/lite/kernels/gather.cc +++ b/tensorflow/contrib/lite/kernels/gather.cc @@ -42,9 +42,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, positions->type, kTfLiteInt32); // Check that input and output types match. TF_LITE_ENSURE_EQ(context, input->type, output->type); - // TODO(mgubin): only 1D positions are currently supported. - TF_LITE_ENSURE_EQ(context, NumDimensions(positions), 1); + // TODO(mgubin): only 0D or 1D positions are currently supported. + TF_LITE_ENSURE(context, NumDimensions(positions) <= 1); // TODO(mgubin): Only default axis == 0 is supported. + TF_LITE_ENSURE_EQ(context, params->axis, 0); // Check conditions for different types. switch (input->type) { case kTfLiteFloat32: @@ -64,7 +65,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } const int num_dimensions = NumDimensions(input) + NumDimensions(positions) - 1; - TF_LITE_ENSURE(context, params->axis < num_dimensions); + TF_LITE_ENSURE(context, params->axis <= num_dimensions); TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions); int output_index = 0; for (int i = 0; i < params->axis; ++i) { |