aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/testing/generate_testspec.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/testing/generate_testspec.h')
-rw-r--r--tensorflow/contrib/lite/testing/generate_testspec.h6
1 files changed, 5 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/testing/generate_testspec.h b/tensorflow/contrib/lite/testing/generate_testspec.h
index 6e31a853c3..b3d0db31c0 100644
--- a/tensorflow/contrib/lite/testing/generate_testspec.h
+++ b/tensorflow/contrib/lite/testing/generate_testspec.h
@@ -19,6 +19,8 @@ limitations under the License.
#include <iostream>
#include <vector>
+#include "tensorflow/contrib/lite/string.h"
+
namespace tflite {
namespace testing {
@@ -30,13 +32,15 @@ namespace testing {
// stream: mutable iostream that contains the contents of test spec.
// tensorflow_model_path: path to TensorFlow model.
// tflite_model_path: path to tflite_model_path that the test spec runs
+// num_invocations: how many pairs of inputs and outputs will be generated.
// against. input_layer: names of input tensors. Example: input1
// input_layer_type: datatypes of input tensors. Example: float
// input_layer_shape: shapes of input tensors, separated by comma. example:
// 1,3,4 output_layer: names of output tensors. Example: output
bool GenerateTestSpecFromTensorflowModel(
std::iostream& stream, const string& tensorflow_model_path,
- const string& tflite_model_path, const std::vector<string>& input_layer,
+ const string& tflite_model_path, int num_invocations,
+ const std::vector<string>& input_layer,
const std::vector<string>& input_layer_type,
const std::vector<string>& input_layer_shape,
const std::vector<string>& output_layer);