diff options
author | 2018-05-16 15:28:11 -0700 | |
---|---|---|
committer | 2018-05-16 15:35:39 -0700 | |
commit | 8f216f537704d0077a0e8befe322e8293b1ed321 (patch) | |
tree | c07d13b45d1939fbefce07b94b0aa8b937eef2b1 | |
parent | ee21903c9d15f4ab2d1ca5ba9b569b202e6f923c (diff) |
Fixing test for Topk kernel in TFlite
PiperOrigin-RevId: 196899232
-rw-r--r-- | tensorflow/contrib/lite/kernels/topk_v2.cc | 3 | ||||
-rw-r--r-- | tensorflow/contrib/lite/testing/generated_examples_zip_test.cc | 1 |
2 files changed, 2 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/kernels/topk_v2.cc b/tensorflow/contrib/lite/kernels/topk_v2.cc index b331fc8482..0feb42b85b 100644 --- a/tensorflow/contrib/lite/kernels/topk_v2.cc +++ b/tensorflow/contrib/lite/kernels/topk_v2.cc @@ -34,9 +34,8 @@ TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) { // INT32 number of top results is supported. TF_LITE_ENSURE_EQ(context, top_k->type, kTfLiteInt32); // Check that the tensor contains only one value. - TF_LITE_ENSURE_EQ(context, NumDimensions(top_k), 1); TF_LITE_ENSURE_EQ(context, NumElements(top_k), 1); - const int32 k = top_k->data.i32[0]; + const int32 k = *GetTensorData<int32_t>(top_k); const TfLiteTensor* input = GetInput(context, node, kInputTensor); const int num_dimensions = NumDimensions(input); diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc index 821f4de93b..c085ea28ea 100644 --- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -307,6 +307,7 @@ INSTANTIATE_TESTS(split) INSTANTIATE_TESTS(squeeze) INSTANTIATE_TESTS(strided_slice) INSTANTIATE_TESTS(sub) +INSTANTIATE_TESTS(topk) INSTANTIATE_TESTS(transpose) INSTANTIATE_TESTS(transpose_conv) INSTANTIATE_TESTS(where) |