/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <fcntl.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>

#include "tensorflow/contrib/lite/model.h"

#include <gtest/gtest.h>
#include "tensorflow/contrib/lite/core/api/error_reporter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/testing/util.h"

// Comparison for TfLiteRegistration. Since TfLiteRegistration is a C object,
// we must declare this in global namespace, so argument-dependent operator
// lookup works.
inline bool operator==(const TfLiteRegistration& a,
                       const TfLiteRegistration& b) {
  return a.invoke == b.invoke && a.init == b.init && a.prepare == b.prepare &&
         a.free == b.free;
}

namespace tflite {

// Provide a dummy operation that does nothing.
namespace {
void* dummy_init(TfLiteContext*, const char*, size_t) { return nullptr; }
void dummy_free(TfLiteContext*, void*) {}
TfLiteStatus dummy_resize(TfLiteContext*, TfLiteNode*) { return kTfLiteOk; }
TfLiteStatus dummy_invoke(TfLiteContext*, TfLiteNode*) { return kTfLiteOk; }
TfLiteRegistration dummy_reg = {dummy_init, dummy_free, dummy_resize,
                                dummy_invoke};
}  // namespace

// Provide a trivial resolver that returns a constant value no matter what
// op is asked for.
class TrivialResolver : public OpResolver {
 public:
  explicit TrivialResolver(TfLiteRegistration* constant_return = nullptr)
      : constant_return_(constant_return) {}
  // Find the op registration of a builtin operator by its enum code.
  const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
                                   int version) const override {
    return constant_return_;
  }
  // Find the op registration of a custom operator by op name.
  const TfLiteRegistration* FindOp(const char* op, int version) const override {
    return constant_return_;
  }

 private:
  TfLiteRegistration* constant_return_;
};

TEST(BasicFlatBufferModel, TestNonExistantFiles) {
  ASSERT_TRUE(!FlatBufferModel::BuildFromFile("/tmp/tflite_model_1234"));
}

// Make sure a model with nothing in it loads properly.
TEST(BasicFlatBufferModel, TestEmptyModelsAndNullDestination) {
  auto model = FlatBufferModel::BuildFromFile(
      "tensorflow/contrib/lite/testdata/empty_model.bin");
  ASSERT_TRUE(model);
  // Now try to build it into an interpreter.
  std::unique_ptr<Interpreter> interpreter;
  ASSERT_EQ(InterpreterBuilder(*model, TrivialResolver())(&interpreter),
            kTfLiteOk);
  ASSERT_NE(interpreter, nullptr);
  ASSERT_NE(InterpreterBuilder(*model, TrivialResolver())(nullptr), kTfLiteOk);
}
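
// For orientation, every test below exercises some slice of the same client
// flow. A minimal sketch of the happy path (assuming a valid model file and a
// resolver that knows all of its ops; "model.bin" is a placeholder path):
//
//   auto model = FlatBufferModel::BuildFromFile("model.bin");
//   std::unique_ptr<Interpreter> interpreter;
//   InterpreterBuilder(*model,
//                      ops::builtin::BuiltinOpResolver{})(&interpreter);
//   interpreter->AllocateTensors();
//   interpreter->Invoke();
//
// Each step reports failure through a TfLiteStatus and the installed
// ErrorReporter rather than by crashing; the tests below poke at each
// failure mode in turn.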
// Make sure currently unsupported numbers of subgraphs are rejected.
// TODO(aselle): Replace this test when multiple subgraphs are supported.
TEST(BasicFlatBufferModel, TestZeroAndMultipleSubgraphs) {
  auto m1 = FlatBufferModel::BuildFromFile(
      "tensorflow/contrib/lite/testdata/0_subgraphs.bin");
  ASSERT_TRUE(m1);
  std::unique_ptr<Interpreter> interpreter1;
  ASSERT_NE(InterpreterBuilder(*m1, TrivialResolver())(&interpreter1),
            kTfLiteOk);

  auto m2 = FlatBufferModel::BuildFromFile(
      "tensorflow/contrib/lite/testdata/2_subgraphs.bin");
  ASSERT_TRUE(m2);
  std::unique_ptr<Interpreter> interpreter2;
  ASSERT_NE(InterpreterBuilder(*m2, TrivialResolver())(&interpreter2),
            kTfLiteOk);
}

// Test what happens if we cannot bind any of the ops.
TEST(BasicFlatBufferModel, TestModelWithoutNullRegistrations) {
  auto model = FlatBufferModel::BuildFromFile(
      "tensorflow/contrib/lite/testdata/test_model.bin");
  ASSERT_TRUE(model);
  // Check that we get an error code and that the interpreter pointer is
  // reset.
  std::unique_ptr<Interpreter> interpreter(new Interpreter);
  ASSERT_NE(InterpreterBuilder(*model, TrivialResolver(nullptr))(&interpreter),
            kTfLiteOk);
  ASSERT_EQ(interpreter, nullptr);
}

// Make sure the model is read into the interpreter properly.
TEST(BasicFlatBufferModel, TestModelInInterpreter) {
  auto model = FlatBufferModel::BuildFromFile(
      "tensorflow/contrib/lite/testdata/test_model.bin");
  ASSERT_TRUE(model);
  // Build the model into a fresh interpreter; this should succeed.
  std::unique_ptr<Interpreter> interpreter(new Interpreter);
  ASSERT_EQ(
      InterpreterBuilder(*model, TrivialResolver(&dummy_reg))(&interpreter),
      kTfLiteOk);
  ASSERT_NE(interpreter, nullptr);
  ASSERT_EQ(interpreter->tensors_size(), 4);
  ASSERT_EQ(interpreter->nodes_size(), 2);
  std::vector<int> inputs = {0, 1};
  std::vector<int> outputs = {2, 3};
  ASSERT_EQ(interpreter->inputs(), inputs);
  ASSERT_EQ(interpreter->outputs(), outputs);

  EXPECT_EQ(std::string(interpreter->GetInputName(0)), "input0");
  EXPECT_EQ(std::string(interpreter->GetInputName(1)), "input1");
  EXPECT_EQ(std::string(interpreter->GetOutputName(0)), "out1");
  EXPECT_EQ(std::string(interpreter->GetOutputName(1)), "out2");

  // Make sure all input tensors are correct.
  TfLiteTensor* i0 = interpreter->tensor(0);
  ASSERT_EQ(i0->type, kTfLiteFloat32);
  ASSERT_NE(i0->data.raw, nullptr);  // mmapped
  ASSERT_EQ(i0->allocation_type, kTfLiteMmapRo);
  TfLiteTensor* i1 = interpreter->tensor(1);
  ASSERT_EQ(i1->type, kTfLiteFloat32);
  ASSERT_EQ(i1->data.raw, nullptr);
  ASSERT_EQ(i1->allocation_type, kTfLiteArenaRw);
  TfLiteTensor* o0 = interpreter->tensor(2);
  ASSERT_EQ(o0->type, kTfLiteFloat32);
  ASSERT_EQ(o0->data.raw, nullptr);
  ASSERT_EQ(o0->allocation_type, kTfLiteArenaRw);
  TfLiteTensor* o1 = interpreter->tensor(3);
  ASSERT_EQ(o1->type, kTfLiteFloat32);
  ASSERT_EQ(o1->data.raw, nullptr);
  ASSERT_EQ(o1->allocation_type, kTfLiteArenaRw);

  // Check op 0, which has inputs {0, 1} and outputs {2}.
  {
    const std::pair<TfLiteNode, TfLiteRegistration>* node_and_reg0 =
        interpreter->node_and_registration(0);
    ASSERT_NE(node_and_reg0, nullptr);
    const TfLiteNode& node0 = node_and_reg0->first;
    const TfLiteRegistration& reg0 = node_and_reg0->second;
    TfLiteIntArray* desired_inputs = TfLiteIntArrayCreate(2);
    desired_inputs->data[0] = 0;
    desired_inputs->data[1] = 1;
    TfLiteIntArray* desired_outputs = TfLiteIntArrayCreate(1);
    desired_outputs->data[0] = 2;
    ASSERT_TRUE(TfLiteIntArrayEqual(node0.inputs, desired_inputs));
    ASSERT_TRUE(TfLiteIntArrayEqual(node0.outputs, desired_outputs));
    TfLiteIntArrayFree(desired_inputs);
    TfLiteIntArrayFree(desired_outputs);
    ASSERT_EQ(reg0, dummy_reg);
  }

  // Check op 1, which has inputs {2} and outputs {3}.
  {
    const std::pair<TfLiteNode, TfLiteRegistration>* node_and_reg1 =
        interpreter->node_and_registration(1);
    ASSERT_NE(node_and_reg1, nullptr);
    const TfLiteNode& node1 = node_and_reg1->first;
    const TfLiteRegistration& reg1 = node_and_reg1->second;
    TfLiteIntArray* desired_inputs = TfLiteIntArrayCreate(1);
    TfLiteIntArray* desired_outputs = TfLiteIntArrayCreate(1);
    desired_inputs->data[0] = 2;
    desired_outputs->data[0] = 3;
    ASSERT_TRUE(TfLiteIntArrayEqual(node1.inputs, desired_inputs));
    ASSERT_TRUE(TfLiteIntArrayEqual(node1.outputs, desired_outputs));
    TfLiteIntArrayFree(desired_inputs);
    TfLiteIntArrayFree(desired_outputs);
    ASSERT_EQ(reg1, dummy_reg);
  }
}
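
// For orientation, the assertions above pin down the topology of
// test_model.bin. Reconstructed from the checked indices (a sketch, not the
// testdata generator's own description):
//
//   tensor 0 "input0" (mmap-ro) ---\
//                                   op 0 --> tensor 2 "out1"
//   tensor 1 "input1" (arena-rw) --/             |
//                                                v
//                                              op 1 --> tensor 3 "out2"
//
// Both tensors 2 and 3 are graph outputs, and only tensor 0 is backed by
// memory-mapped constant data.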
// Test that loading a model with TensorFlow ops fails when the flex delegate
// is not linked into the target.
TEST(FlexModel, FailureWithoutFlexDelegate) {
  auto model = FlatBufferModel::BuildFromFile(
      "tensorflow/contrib/lite/testdata/multi_add_flex.bin");
  ASSERT_TRUE(model);

  // Note that creation will succeed when using the BuiltinOpResolver, but
  // unless the appropriate delegate is linked into the target or the client
  // explicitly installs the delegate, execution will fail.
  std::unique_ptr<Interpreter> interpreter;
  ASSERT_EQ(InterpreterBuilder(
                *model, ops::builtin::BuiltinOpResolver{})(&interpreter),
            kTfLiteOk);
  ASSERT_TRUE(interpreter);

  // As the flex ops weren't resolved implicitly by the flex delegate, runtime
  // allocation and execution will fail.
  ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteError);
}

// This tests a flatbuffer that declares a memory-mapped buffer for a tensor
// of shape {2}, but whose buffer holds only one element.
TEST(BasicFlatBufferModel, TestBrokenMmap) {
  ASSERT_FALSE(FlatBufferModel::BuildFromFile(
      "tensorflow/contrib/lite/testdata/test_model_broken.bin"));
}

TEST(BasicFlatBufferModel, TestNullModel) {
  // Check that we get an error code and that the interpreter pointer is
  // reset.
  std::unique_ptr<Interpreter> interpreter(new Interpreter);
  ASSERT_NE(
      InterpreterBuilder(nullptr, TrivialResolver(&dummy_reg))(&interpreter),
      kTfLiteOk);
  ASSERT_EQ(interpreter.get(), nullptr);
}

// Mocks the verifier by setting the result in the constructor.
class FakeVerifier : public tflite::TfLiteVerifier {
 public:
  explicit FakeVerifier(bool result) : result_(result) {}
  bool Verify(const char* data, int length,
              tflite::ErrorReporter* reporter) override {
    return result_;
  }

 private:
  bool result_;
};

TEST(BasicFlatBufferModel, TestWithTrueVerifier) {
  FakeVerifier verifier(true);
  ASSERT_TRUE(FlatBufferModel::VerifyAndBuildFromFile(
      "tensorflow/contrib/lite/testdata/test_model.bin", &verifier));
}

TEST(BasicFlatBufferModel, TestWithFalseVerifier) {
  FakeVerifier verifier(false);
  ASSERT_FALSE(FlatBufferModel::VerifyAndBuildFromFile(
      "tensorflow/contrib/lite/testdata/test_model.bin", &verifier));
}

TEST(BasicFlatBufferModel, TestWithNullVerifier) {
  ASSERT_TRUE(FlatBufferModel::VerifyAndBuildFromFile(
      "tensorflow/contrib/lite/testdata/test_model.bin", nullptr));
}

// This makes sure the ErrorReporter is marshalled from FlatBufferModel to
// the Interpreter.
TEST(BasicFlatBufferModel, TestCustomErrorReporter) {
  TestErrorReporter reporter;
  auto model = FlatBufferModel::BuildFromFile(
      "tensorflow/contrib/lite/testdata/empty_model.bin", &reporter);
  ASSERT_TRUE(model);

  std::unique_ptr<Interpreter> interpreter;
  TrivialResolver resolver;
  InterpreterBuilder(*model, resolver)(&interpreter);
  ASSERT_NE(interpreter->Invoke(), kTfLiteOk);
  ASSERT_EQ(reporter.num_calls(), 1);
}
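
// For reference, a production verifier would be wired in the same way as
// FakeVerifier above. A minimal sketch (FlatbufferVerifier is a hypothetical
// name; its body mirrors what TestBuildFromModel below does by hand):
//
//   class FlatbufferVerifier : public tflite::TfLiteVerifier {
//    public:
//     bool Verify(const char* data, int length,
//                 tflite::ErrorReporter* reporter) override {
//       ::flatbuffers::Verifier v(reinterpret_cast<const uint8_t*>(data),
//                                 length);
//       return VerifyModelBuffer(v);
//     }
//   };
//
//   FlatbufferVerifier verifier;
//   auto model = FlatBufferModel::VerifyAndBuildFromFile(
//       "tensorflow/contrib/lite/testdata/test_model.bin", &verifier);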
// This makes sure the default ErrorReporter is used when none is provided.
TEST(BasicFlatBufferModel, TestNullErrorReporter) {
  auto model = FlatBufferModel::BuildFromFile(
      "tensorflow/contrib/lite/testdata/empty_model.bin", nullptr);
  ASSERT_TRUE(model);

  std::unique_ptr<Interpreter> interpreter;
  TrivialResolver resolver;
  InterpreterBuilder(*model, resolver)(&interpreter);
  ASSERT_NE(interpreter->Invoke(), kTfLiteOk);
}

// Test that loading a model directly from a Model flatbuffer works.
TEST(BasicFlatBufferModel, TestBuildFromModel) {
  TestErrorReporter reporter;
  FileCopyAllocation model_allocation(
      "tensorflow/contrib/lite/testdata/test_model.bin", &reporter);
  ASSERT_TRUE(model_allocation.valid());
  // Verify the raw buffer before interpreting it as a Model flatbuffer.
  ::flatbuffers::Verifier verifier(
      reinterpret_cast<const uint8_t*>(model_allocation.base()),
      model_allocation.bytes());
  ASSERT_TRUE(VerifyModelBuffer(verifier));
  const Model* model_fb = ::tflite::GetModel(model_allocation.base());

  auto model = FlatBufferModel::BuildFromModel(model_fb);
  ASSERT_TRUE(model);

  std::unique_ptr<Interpreter> interpreter;
  ASSERT_EQ(
      InterpreterBuilder(*model, TrivialResolver(&dummy_reg))(&interpreter),
      kTfLiteOk);
  ASSERT_NE(interpreter, nullptr);
}

// TODO(aselle): Add tests for serialization of builtin op data types.
// These tests will occur with the evaluation tests of individual operators,
// not here.

}  // namespace tflite

int main(int argc, char** argv) {
  ::tflite::LogToStderr();
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}