aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-15 15:29:39 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-15 15:33:32 -0700
commit6c62e650252ab32f83637a8de6720e73ffeca226 (patch)
treea4133a93cada7b18238d607bc4d5e551f9e685e6
parent239eb8b652f94b43d51f7c7ffdbbfc02ad094a9c (diff)
Pass error reporter to file copy allocation,
and avoid loading model from file twice PiperOrigin-RevId: 189256489
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc27
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java16
-rw-r--r--tensorflow/contrib/lite/model.cc74
-rw-r--r--tensorflow/contrib/lite/model.h37
-rw-r--r--tensorflow/contrib/lite/model_test.cc32
5 files changed, 131 insertions, 55 deletions
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
index 21bcff40bd..cc448b03c3 100644
--- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
@@ -334,6 +334,19 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createErrorReporter(
return reinterpret_cast<jlong>(error_reporter);
}
+// Verifies whether the model is a flatbuffer file.
+class JNIFlatBufferVerifier : public tflite::TfLiteVerifier {
+ public:
+ bool Verify(const char* data, int length,
+ tflite::ErrorReporter* reporter) override {
+ if (!VerifyModel(data, length)) {
+ reporter->Report("The model is not a valid Flatbuffer file");
+ return false;
+ }
+ return true;
+ }
+};
+
JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_createModel(
JNIEnv* env, jclass clazz, jstring model_file, jlong error_handle) {
@@ -342,17 +355,11 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModel(
if (error_reporter == nullptr) return 0;
const char* path = env->GetStringUTFChars(model_file, nullptr);
- {
- tflite::FileCopyAllocation allocation(path, nullptr);
- if (!VerifyModel(allocation.base(), allocation.bytes())) {
- throwException(env, kIllegalArgumentException,
- "Contents of %s is not a valid flatbuffer model", path);
- env->ReleaseStringUTFChars(model_file, path);
- return 0;
- }
- }
+ std::unique_ptr<tflite::TfLiteVerifier> verifier;
+ verifier.reset(new JNIFlatBufferVerifier());
- auto model = tflite::FlatBufferModel::BuildFromFile(path, error_reporter);
+ auto model = tflite::FlatBufferModel::VerifyAndBuildFromFile(
+ path, verifier.get(), error_reporter);
if (!model) {
throwException(env, kIllegalArgumentException,
"Contents of %s does not encode a valid TensorFlowLite "
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
index d6b4e9f438..dbe45e5a05 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
@@ -47,6 +47,9 @@ public final class NativeInterpreterWrapperTest {
private static final String MODEL_WITH_CUSTOM_OP_PATH =
"tensorflow/contrib/lite/java/src/testdata/with_custom_op.lite";
+ private static final String NONEXISTING_MODEL_PATH =
+ "tensorflow/contrib/lite/java/src/testdata/nonexisting_model.bin";
+
@Test
public void testConstructor() {
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
@@ -60,7 +63,18 @@ public final class NativeInterpreterWrapperTest {
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(INVALID_MODEL_PATH);
fail();
} catch (IllegalArgumentException e) {
- assertThat(e).hasMessageThat().contains("is not a valid flatbuffer model");
+ assertThat(e).hasMessageThat().contains("The model is not a valid Flatbuffer file");
+ }
+ }
+
+ @Test
+ public void testConstructorWithNonexistingModel() {
+ try {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(NONEXISTING_MODEL_PATH);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e).hasMessageThat().contains("The model is not a valid Flatbuffer file");
+ assertThat(e).hasMessageThat().contains("Could not open");
}
}
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 3cf6bcbfcd..f28d56af67 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -32,11 +32,46 @@ namespace tflite {
const char* kEmptyTensorName = "";
+// Loads a model from `filename`. If `mmap_file` is true then use mmap,
+// otherwise make a copy of the model in a buffer.
+std::unique_ptr<Allocation> GetAllocationFromFile(const char* filename,
+ bool mmap_file,
+ ErrorReporter* error_reporter,
+ bool use_nnapi) {
+ std::unique_ptr<Allocation> allocation;
+ if (mmap_file) {
+ if (use_nnapi && NNAPIExists())
+ allocation.reset(new NNAPIAllocation(filename, error_reporter));
+ else
+ allocation.reset(new MMAPAllocation(filename, error_reporter));
+ } else {
+ allocation.reset(new FileCopyAllocation(filename, error_reporter));
+ }
+ return allocation;
+}
+
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile(
const char* filename, ErrorReporter* error_reporter) {
std::unique_ptr<FlatBufferModel> model;
- model.reset(new FlatBufferModel(filename, /*mmap_file=*/true, error_reporter,
- /*use_nnapi=*/true));
+ auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true,
+ error_reporter, /*use_nnapi=*/true);
+ model.reset(new FlatBufferModel(allocation.release(), error_reporter));
+ if (!model->initialized()) model.reset();
+ return model;
+}
+
+std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromFile(
+ const char* filename, TfLiteVerifier* verifier,
+ ErrorReporter* error_reporter) {
+ std::unique_ptr<FlatBufferModel> model;
+ auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true,
+ error_reporter, /*use_nnapi=*/true);
+ if (verifier &&
+ !verifier->Verify(static_cast<const char*>(allocation->base()),
+ allocation->bytes(), error_reporter)) {
+ return model;
+ }
+ model.reset(new FlatBufferModel(allocation.release(), error_reporter));
if (!model->initialized()) model.reset();
return model;
}
@@ -44,7 +79,9 @@ std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile(
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromBuffer(
const char* buffer, size_t buffer_size, ErrorReporter* error_reporter) {
std::unique_ptr<FlatBufferModel> model;
- model.reset(new FlatBufferModel(buffer, buffer_size, error_reporter));
+ Allocation* allocation =
+ new MemoryAllocation(buffer, buffer_size, error_reporter);
+ model.reset(new FlatBufferModel(allocation, error_reporter));
if (!model->initialized()) model.reset();
return model;
}
@@ -57,23 +94,6 @@ std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromModel(
return model;
}
-FlatBufferModel::FlatBufferModel(const char* filename, bool mmap_file,
- ErrorReporter* error_reporter, bool use_nnapi)
- : error_reporter_(error_reporter ? error_reporter
- : DefaultErrorReporter()) {
- if (mmap_file) {
- if (use_nnapi && NNAPIExists())
- allocation_ = new NNAPIAllocation(filename, error_reporter);
- else
- allocation_ = new MMAPAllocation(filename, error_reporter);
- } else {
- allocation_ = new FileCopyAllocation(filename, error_reporter);
- }
- if (!allocation_->valid() || !CheckModelIdentifier()) return;
-
- model_ = ::tflite::GetModel(allocation_->base());
-}
-
bool FlatBufferModel::CheckModelIdentifier() const {
if (!tflite::ModelBufferHasIdentifier(allocation_->base())) {
const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base());
@@ -85,21 +105,21 @@ bool FlatBufferModel::CheckModelIdentifier() const {
return true;
}
-FlatBufferModel::FlatBufferModel(const char* ptr, size_t num_bytes,
+FlatBufferModel::FlatBufferModel(const Model* model,
ErrorReporter* error_reporter)
: error_reporter_(error_reporter ? error_reporter
: DefaultErrorReporter()) {
- allocation_ = new MemoryAllocation(ptr, num_bytes, error_reporter);
- if (!allocation_->valid()) return;
-
- model_ = ::tflite::GetModel(allocation_->base());
+ model_ = model;
}
-FlatBufferModel::FlatBufferModel(const Model* model,
+FlatBufferModel::FlatBufferModel(Allocation* allocation,
ErrorReporter* error_reporter)
: error_reporter_(error_reporter ? error_reporter
: DefaultErrorReporter()) {
- model_ = model;
+ allocation_ = allocation;
+ if (!allocation_->valid() || !CheckModelIdentifier()) return;
+
+ model_ = ::tflite::GetModel(allocation_->base());
}
FlatBufferModel::~FlatBufferModel() { delete allocation_; }
diff --git a/tensorflow/contrib/lite/model.h b/tensorflow/contrib/lite/model.h
index 8dc1c794dc..38eea0e26b 100644
--- a/tensorflow/contrib/lite/model.h
+++ b/tensorflow/contrib/lite/model.h
@@ -41,6 +41,17 @@ limitations under the License.
namespace tflite {
+// Abstract interface that verifies whether a given model is legit.
+// It facilitates the use-case to verify and build a model without loading it
+// twice.
+class TfLiteVerifier {
+ public:
+ // Returns true if the model is legit.
+ virtual bool Verify(const char* data, int length,
+ ErrorReporter* reporter) = 0;
+ virtual ~TfLiteVerifier() {}
+};
+
// An RAII object that represents a read-only tflite model, copied from disk,
// or mmapped. This uses flatbuffers as the serialization format.
class FlatBufferModel {
@@ -50,6 +61,12 @@ class FlatBufferModel {
const char* filename,
ErrorReporter* error_reporter = DefaultErrorReporter());
+ // Verifies whether the content of the file is legit, then builds a model
+ // based on the file. Returns a nullptr in case of failure.
+ static std::unique_ptr<FlatBufferModel> VerifyAndBuildFromFile(
+ const char* filename, TfLiteVerifier* verifier = nullptr,
+ ErrorReporter* error_reporter = DefaultErrorReporter());
+
// Builds a model based on a pre-loaded flatbuffer. The caller retains
// ownership of the buffer and should keep it alive until the returned object
// is destroyed. Returns a nullptr in case of failure.
@@ -82,23 +99,9 @@ class FlatBufferModel {
bool CheckModelIdentifier() const;
private:
- // Loads a model from `filename`. If `mmap_file` is true then use mmap,
- // otherwise make a copy of the model in a buffer.
- //
- // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be
- // used.
- explicit FlatBufferModel(
- const char* filename, bool mmap_file = true,
- ErrorReporter* error_reporter = DefaultErrorReporter(),
- bool use_nnapi = false);
-
- // Loads a model from `ptr` and `num_bytes` of the model file. The `ptr` has
- // to remain alive and unchanged until the end of this flatbuffermodel's
- // lifetime.
- //
- // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be
- // used.
- FlatBufferModel(const char* ptr, size_t num_bytes,
+ // Loads a model from a given allocation. FlatBufferModel will take over the
+ // ownership of `allocation`, and delete it in desctructor.
+ FlatBufferModel(Allocation* allocation,
ErrorReporter* error_reporter = DefaultErrorReporter());
// Loads a model from Model flatbuffer. The `model` has to remain alive and
diff --git a/tensorflow/contrib/lite/model_test.cc b/tensorflow/contrib/lite/model_test.cc
index 66f22fd66a..ae6c1ece18 100644
--- a/tensorflow/contrib/lite/model_test.cc
+++ b/tensorflow/contrib/lite/model_test.cc
@@ -209,6 +209,38 @@ TEST(BasicFlatBufferModel, TestNullModel) {
ASSERT_EQ(interpreter.get(), nullptr);
}
+// Mocks the verifier by setting the result in ctor.
+class FakeVerifier : public tflite::TfLiteVerifier {
+ public:
+ explicit FakeVerifier(bool result) : result_(result) {}
+ bool Verify(const char* data, int length,
+ tflite::ErrorReporter* reporter) override {
+ return result_;
+ }
+
+ private:
+ bool result_;
+};
+
+TEST(BasicFlatBufferModel, TestWithTrueVerifier) {
+ FakeVerifier verifier(true);
+ ASSERT_TRUE(FlatBufferModel::VerifyAndBuildFromFile(
+ "tensorflow/contrib/lite/testdata/test_model.bin",
+ &verifier));
+}
+
+TEST(BasicFlatBufferModel, TestWithFalseVerifier) {
+ FakeVerifier verifier(false);
+ ASSERT_FALSE(FlatBufferModel::VerifyAndBuildFromFile(
+ "tensorflow/contrib/lite/testdata/test_model.bin",
+ &verifier));
+}
+
+TEST(BasicFlatBufferModel, TestWithNullVerifier) {
+ ASSERT_TRUE(FlatBufferModel::VerifyAndBuildFromFile(
+ "tensorflow/contrib/lite/testdata/test_model.bin", nullptr));
+}
+
struct TestErrorReporter : public ErrorReporter {
int Report(const char* format, va_list args) override {
calls++;