aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/model_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-15 15:29:39 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-15 15:33:32 -0700
commit6c62e650252ab32f83637a8de6720e73ffeca226 (patch)
treea4133a93cada7b18238d607bc4d5e551f9e685e6 /tensorflow/contrib/lite/model_test.cc
parent239eb8b652f94b43d51f7c7ffdbbfc02ad094a9c (diff)
Pass error reporter to file copy allocation,
and avoid loading model from file twice PiperOrigin-RevId: 189256489
Diffstat (limited to 'tensorflow/contrib/lite/model_test.cc')
-rw-r--r--tensorflow/contrib/lite/model_test.cc32
1 files changed, 32 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/model_test.cc b/tensorflow/contrib/lite/model_test.cc
index 66f22fd66a..ae6c1ece18 100644
--- a/tensorflow/contrib/lite/model_test.cc
+++ b/tensorflow/contrib/lite/model_test.cc
@@ -209,6 +209,38 @@ TEST(BasicFlatBufferModel, TestNullModel) {
ASSERT_EQ(interpreter.get(), nullptr);
}
+// Mocks the verifier by setting the result in ctor.
+class FakeVerifier : public tflite::TfLiteVerifier {
+ public:
+ explicit FakeVerifier(bool result) : result_(result) {}
+ bool Verify(const char* data, int length,
+ tflite::ErrorReporter* reporter) override {
+ return result_;
+ }
+
+ private:
+ bool result_;
+};
+
+TEST(BasicFlatBufferModel, TestWithTrueVerifier) {
+ FakeVerifier verifier(true);
+ ASSERT_TRUE(FlatBufferModel::VerifyAndBuildFromFile(
+ "tensorflow/contrib/lite/testdata/test_model.bin",
+ &verifier));
+}
+
+TEST(BasicFlatBufferModel, TestWithFalseVerifier) {
+ FakeVerifier verifier(false);
+ ASSERT_FALSE(FlatBufferModel::VerifyAndBuildFromFile(
+ "tensorflow/contrib/lite/testdata/test_model.bin",
+ &verifier));
+}
+
+TEST(BasicFlatBufferModel, TestWithNullVerifier) {
+ ASSERT_TRUE(FlatBufferModel::VerifyAndBuildFromFile(
+ "tensorflow/contrib/lite/testdata/test_model.bin", nullptr));
+}
+
struct TestErrorReporter : public ErrorReporter {
int Report(const char* format, va_list args) override {
calls++;