diff options
Diffstat (limited to 'tensorflow/contrib/lite/testing/tflite_diff_flags.h')
-rw-r--r-- | tensorflow/contrib/lite/testing/tflite_diff_flags.h | 27 |
1 files changed, 20 insertions, 7 deletions
diff --git a/tensorflow/contrib/lite/testing/tflite_diff_flags.h b/tensorflow/contrib/lite/testing/tflite_diff_flags.h index 695c2a3de6..3874bc31d7 100644 --- a/tensorflow/contrib/lite/testing/tflite_diff_flags.h +++ b/tensorflow/contrib/lite/testing/tflite_diff_flags.h @@ -33,6 +33,7 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) { string input_layer_shape; string output_layer; int32_t num_runs_per_pass = 100; + string delegate; } values; std::vector<tensorflow::Flag> flags = { @@ -42,18 +43,21 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) { "Path of tensorflow lite model."), tensorflow::Flag("input_layer", &values.input_layer, "Names of input tensors, separated by comma. Example: " - "input_1,input_2"), + "input_1,input_2."), tensorflow::Flag("input_layer_type", &values.input_layer_type, "Data types of input tensors, separated by comma. " - "Example: float,int"), + "Example: float,int."), tensorflow::Flag( "input_layer_shape", &values.input_layer_shape, - "Shapes of input tensors, separated by colon. Example: 1,3,4,1:2"), + "Shapes of input tensors, separated by colon. Example: 1,3,4,1:2."), tensorflow::Flag("output_layer", &values.output_layer, - "Names of output tensors, separated by comma. Example " - "output_1,output_2"), + "Names of output tensors, separated by comma. Example: " + "output_1,output_2."), tensorflow::Flag("num_runs_per_pass", &values.num_runs_per_pass, - "Number of full runs in each pass."), + "[optional] Number of full runs in each pass."), + tensorflow::Flag("delegate", &values.delegate, + "[optional] Delegate to use for executing ops. Must be " + "`{\"\", EAGER}`"), }; bool no_inputs = *argc == 1; @@ -61,6 +65,14 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) { if (!success || no_inputs || (*argc == 2 && !strcmp(argv[1], "--helpfull"))) { fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str()); return {}; + } else if (values.tensorflow_model.empty() || values.tflite_model.empty() || + values.input_layer.empty() || values.input_layer_type.empty() || + values.input_layer_shape.empty() || values.output_layer.empty()) { + fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str()); + return {}; + } else if (!(values.delegate == "" || values.delegate == "EAGER")) { + fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str()); + return {}; } return {values.tensorflow_model, @@ -69,7 +81,8 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) { Split<string>(values.input_layer_type, ","), Split<string>(values.input_layer_shape, ":"), Split<string>(values.output_layer, ","), - values.num_runs_per_pass}; + values.num_runs_per_pass, + values.delegate}; } } // namespace testing |