aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-16 15:28:11 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-16 15:35:39 -0700
commit8f216f537704d0077a0e8befe322e8293b1ed321 (patch)
treec07d13b45d1939fbefce07b94b0aa8b937eef2b1
parentee21903c9d15f4ab2d1ca5ba9b569b202e6f923c (diff)
Fixing test for Topk kernel in TFlite
PiperOrigin-RevId: 196899232
-rw-r--r--tensorflow/contrib/lite/kernels/topk_v2.cc3
-rw-r--r--tensorflow/contrib/lite/testing/generated_examples_zip_test.cc1
2 files changed, 2 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/kernels/topk_v2.cc b/tensorflow/contrib/lite/kernels/topk_v2.cc
index b331fc8482..0feb42b85b 100644
--- a/tensorflow/contrib/lite/kernels/topk_v2.cc
+++ b/tensorflow/contrib/lite/kernels/topk_v2.cc
@@ -34,9 +34,8 @@ TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
// INT32 number of top results is supported.
TF_LITE_ENSURE_EQ(context, top_k->type, kTfLiteInt32);
// Check that the tensor contains only one value.
- TF_LITE_ENSURE_EQ(context, NumDimensions(top_k), 1);
TF_LITE_ENSURE_EQ(context, NumElements(top_k), 1);
- const int32 k = top_k->data.i32[0];
+ const int32 k = *GetTensorData<int32_t>(top_k);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const int num_dimensions = NumDimensions(input);
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
index 821f4de93b..c085ea28ea 100644
--- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -307,6 +307,7 @@ INSTANTIATE_TESTS(split)
INSTANTIATE_TESTS(squeeze)
INSTANTIATE_TESTS(strided_slice)
INSTANTIATE_TESTS(sub)
+INSTANTIATE_TESTS(topk)
INSTANTIATE_TESTS(transpose)
INSTANTIATE_TESTS(transpose_conv)
INSTANTIATE_TESTS(where)