aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/testing/tflite_diff_flags.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/testing/tflite_diff_flags.h')
-rw-r--r--tensorflow/contrib/lite/testing/tflite_diff_flags.h8
1 files changed, 7 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/testing/tflite_diff_flags.h b/tensorflow/contrib/lite/testing/tflite_diff_flags.h
index 706108ed73..695c2a3de6 100644
--- a/tensorflow/contrib/lite/testing/tflite_diff_flags.h
+++ b/tensorflow/contrib/lite/testing/tflite_diff_flags.h
@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DIFF_FLAGS_H_
#define TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DIFF_FLAGS_H_
+#include <cstring>
+
#include "tensorflow/contrib/lite/testing/split.h"
#include "tensorflow/contrib/lite/testing/tflite_diff_util.h"
#include "tensorflow/core/util/command_line_flags.h"
@@ -30,6 +32,7 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
string input_layer_type;
string input_layer_shape;
string output_layer;
+ int32_t num_runs_per_pass = 100;
} values;
std::vector<tensorflow::Flag> flags = {
@@ -49,6 +52,8 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
tensorflow::Flag("output_layer", &values.output_layer,
"Names of output tensors, separated by comma. Example "
"output_1,output_2"),
+ tensorflow::Flag("num_runs_per_pass", &values.num_runs_per_pass,
+ "Number of full runs in each pass."),
};
bool no_inputs = *argc == 1;
@@ -63,7 +68,8 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
Split<string>(values.input_layer, ","),
Split<string>(values.input_layer_type, ","),
Split<string>(values.input_layer_shape, ":"),
- Split<string>(values.output_layer, ",")};
+ Split<string>(values.output_layer, ","),
+ values.num_runs_per_pass};
}
} // namespace testing