diff options
author | Jared Duke <jdduke@google.com> | 2018-09-17 16:32:12 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-17 16:39:01 -0700 |
commit | 0b80d098704c72f627f37bfeee0ae19788c06fa8 (patch) | |
tree | 1012464e6154492010c121b38aa52ac66054b935 /tensorflow/contrib/lite/mutable_op_resolver_test.cc | |
parent | 8ef1ece7d0ecdec633a22a8100fdae05cfbacb3e (diff) |
Add basic op resolver registration to TFLite C API
PiperOrigin-RevId: 213360279
Diffstat (limited to 'tensorflow/contrib/lite/mutable_op_resolver_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/mutable_op_resolver_test.cc | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/mutable_op_resolver_test.cc b/tensorflow/contrib/lite/mutable_op_resolver_test.cc index db690eaab9..b70c703839 100644 --- a/tensorflow/contrib/lite/mutable_op_resolver_test.cc +++ b/tensorflow/contrib/lite/mutable_op_resolver_test.cc @@ -36,6 +36,20 @@ TfLiteRegistration* GetDummyRegistration() { return ®istration; } +TfLiteStatus Dummy2Invoke(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; +} + +TfLiteRegistration* GetDummy2Registration() { + static TfLiteRegistration registration = { + .init = nullptr, + .free = nullptr, + .prepare = nullptr, + .invoke = Dummy2Invoke, + }; + return ®istration; +} + TEST(MutableOpResolverTest, FinOp) { MutableOpResolver resolver; resolver.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration()); @@ -119,6 +133,26 @@ TEST(MutableOpResolverTest, FindCustomOpWithUnsupportedVersion) { EXPECT_EQ(found_registration, nullptr); } +TEST(MutableOpResolverTest, AddAll) { + MutableOpResolver resolver1; + resolver1.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration()); + resolver1.AddBuiltin(BuiltinOperator_MUL, GetDummy2Registration()); + + MutableOpResolver resolver2; + resolver2.AddBuiltin(BuiltinOperator_SUB, GetDummyRegistration()); + resolver2.AddBuiltin(BuiltinOperator_ADD, GetDummy2Registration()); + + // resolver2's ADD op should replace resolver1's ADD op, while augmenting + // non-overlapping ops. + resolver1.AddAll(resolver2); + ASSERT_EQ(resolver1.FindOp(BuiltinOperator_ADD, 1)->invoke, + GetDummy2Registration()->invoke); + ASSERT_EQ(resolver1.FindOp(BuiltinOperator_MUL, 1)->invoke, + GetDummy2Registration()->invoke); + ASSERT_EQ(resolver1.FindOp(BuiltinOperator_SUB, 1)->invoke, + GetDummyRegistration()->invoke); +} + } // namespace } // namespace tflite |