diff options
Diffstat (limited to 'tensorflow/contrib/lite/testing/tflite_diff_flags.h')
-rw-r--r-- | tensorflow/contrib/lite/testing/tflite_diff_flags.h | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/testing/tflite_diff_flags.h b/tensorflow/contrib/lite/testing/tflite_diff_flags.h index 706108ed73..695c2a3de6 100644 --- a/tensorflow/contrib/lite/testing/tflite_diff_flags.h +++ b/tensorflow/contrib/lite/testing/tflite_diff_flags.h @@ -15,6 +15,8 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DIFF_FLAGS_H_ #define TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DIFF_FLAGS_H_ +#include <cstring> + #include "tensorflow/contrib/lite/testing/split.h" #include "tensorflow/contrib/lite/testing/tflite_diff_util.h" #include "tensorflow/core/util/command_line_flags.h" @@ -30,6 +32,7 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) { string input_layer_type; string input_layer_shape; string output_layer; + int32_t num_runs_per_pass = 100; } values; std::vector<tensorflow::Flag> flags = { @@ -49,6 +52,8 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) { tensorflow::Flag("output_layer", &values.output_layer, "Names of output tensors, separated by comma. Example " "output_1,output_2"), + tensorflow::Flag("num_runs_per_pass", &values.num_runs_per_pass, + "Number of full runs in each pass."), }; bool no_inputs = *argc == 1; @@ -63,7 +68,8 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) { Split<string>(values.input_layer, ","), Split<string>(values.input_layer_type, ","), Split<string>(values.input_layer_shape, ":"), - Split<string>(values.output_layer, ",")}; + Split<string>(values.output_layer, ","), + values.num_runs_per_pass}; } } // namespace testing |