diff options
Diffstat (limited to 'tensorflow/contrib/lite/testing/tflite_diff_example_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/testing/tflite_diff_example_test.cc | 23 |
1 files changed, 20 insertions, 3 deletions
diff --git a/tensorflow/contrib/lite/testing/tflite_diff_example_test.cc b/tensorflow/contrib/lite/testing/tflite_diff_example_test.cc index 5afa0f800c..f2c49fe389 100644 --- a/tensorflow/contrib/lite/testing/tflite_diff_example_test.cc +++ b/tensorflow/contrib/lite/testing/tflite_diff_example_test.cc @@ -20,12 +20,29 @@ int main(int argc, char** argv) { ::tflite::testing::DiffOptions options = ::tflite::testing::ParseTfliteDiffFlags(&argc, argv); if (options.tensorflow_model.empty()) return 1; + int failure_count = 0; - for (int i = 0; i < 100; i++) { - if (!tflite::testing::RunDiffTest(options)) { + for (int i = 0; i < options.num_runs_per_pass; i++) { + if (!tflite::testing::RunDiffTest(options, /*num_invocations=*/1)) { ++failure_count; } } - fprintf(stderr, "Num errors: %d\n", failure_count); + int failures_in_first_pass = failure_count; + + if (failure_count == 0) { + // Let's try again with num_invocations > 1 to make sure we can do multiple + // invocations without resetting the interpreter. + for (int i = 0; i < options.num_runs_per_pass; i++) { + if (!tflite::testing::RunDiffTest(options, /*num_invocations=*/2)) { + ++failure_count; + } + } + } + + fprintf(stderr, "Num errors in single-inference pass: %d\n", + failures_in_first_pass); + fprintf(stderr, "Num errors in multi-inference pass : %d\n", + failure_count - failures_in_first_pass); + return failure_count != 0 ? 1 : 0; } |