aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/testing/tflite_diff_flags.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/testing/tflite_diff_flags.h')
-rw-r--r--tensorflow/contrib/lite/testing/tflite_diff_flags.h27
1 files changed, 20 insertions, 7 deletions
diff --git a/tensorflow/contrib/lite/testing/tflite_diff_flags.h b/tensorflow/contrib/lite/testing/tflite_diff_flags.h
index 695c2a3de6..3874bc31d7 100644
--- a/tensorflow/contrib/lite/testing/tflite_diff_flags.h
+++ b/tensorflow/contrib/lite/testing/tflite_diff_flags.h
@@ -33,6 +33,7 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
string input_layer_shape;
string output_layer;
int32_t num_runs_per_pass = 100;
+ string delegate;
} values;
std::vector<tensorflow::Flag> flags = {
@@ -42,18 +43,21 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
"Path of tensorflow lite model."),
tensorflow::Flag("input_layer", &values.input_layer,
"Names of input tensors, separated by comma. Example: "
- "input_1,input_2"),
+ "input_1,input_2."),
tensorflow::Flag("input_layer_type", &values.input_layer_type,
"Data types of input tensors, separated by comma. "
- "Example: float,int"),
+ "Example: float,int."),
tensorflow::Flag(
"input_layer_shape", &values.input_layer_shape,
- "Shapes of input tensors, separated by colon. Example: 1,3,4,1:2"),
+ "Shapes of input tensors, separated by colon. Example: 1,3,4,1:2."),
tensorflow::Flag("output_layer", &values.output_layer,
- "Names of output tensors, separated by comma. Example "
- "output_1,output_2"),
+ "Names of output tensors, separated by comma. Example: "
+ "output_1,output_2."),
tensorflow::Flag("num_runs_per_pass", &values.num_runs_per_pass,
- "Number of full runs in each pass."),
+ "[optional] Number of full runs in each pass."),
+ tensorflow::Flag("delegate", &values.delegate,
+ "[optional] Delegate to use for executing ops. Must be "
+ "`{\"\", EAGER}`"),
};
bool no_inputs = *argc == 1;
@@ -61,6 +65,14 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
if (!success || no_inputs || (*argc == 2 && !strcmp(argv[1], "--helpfull"))) {
fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str());
return {};
+ } else if (values.tensorflow_model.empty() || values.tflite_model.empty() ||
+ values.input_layer.empty() || values.input_layer_type.empty() ||
+ values.input_layer_shape.empty() || values.output_layer.empty()) {
+ fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str());
+ return {};
+ } else if (!(values.delegate == "" || values.delegate == "EAGER")) {
+ fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str());
+ return {};
}
return {values.tensorflow_model,
@@ -69,7 +81,8 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
Split<string>(values.input_layer_type, ","),
Split<string>(values.input_layer_shape, ":"),
Split<string>(values.output_layer, ","),
- values.num_runs_per_pass};
+ values.num_runs_per_pass,
+ values.delegate};
}
} // namespace testing