aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/mutable_op_resolver_test.cc
diff options
context:
space:
mode:
authorGravatar Jared Duke <jdduke@google.com>2018-09-17 16:32:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-17 16:39:01 -0700
commit0b80d098704c72f627f37bfeee0ae19788c06fa8 (patch)
tree1012464e6154492010c121b38aa52ac66054b935 /tensorflow/contrib/lite/mutable_op_resolver_test.cc
parent8ef1ece7d0ecdec633a22a8100fdae05cfbacb3e (diff)
Add basic op resolver registration to TFLite C API
PiperOrigin-RevId: 213360279
Diffstat (limited to 'tensorflow/contrib/lite/mutable_op_resolver_test.cc')
-rw-r--r--tensorflow/contrib/lite/mutable_op_resolver_test.cc34
1 files changed, 34 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/mutable_op_resolver_test.cc b/tensorflow/contrib/lite/mutable_op_resolver_test.cc
index db690eaab9..b70c703839 100644
--- a/tensorflow/contrib/lite/mutable_op_resolver_test.cc
+++ b/tensorflow/contrib/lite/mutable_op_resolver_test.cc
@@ -36,6 +36,20 @@ TfLiteRegistration* GetDummyRegistration() {
return &registration;
}
+TfLiteStatus Dummy2Invoke(TfLiteContext* context, TfLiteNode* node) {
+ return kTfLiteOk;
+}
+
+TfLiteRegistration* GetDummy2Registration() {
+ static TfLiteRegistration registration = {
+ .init = nullptr,
+ .free = nullptr,
+ .prepare = nullptr,
+ .invoke = Dummy2Invoke,
+ };
+ return &registration;
+}
+
TEST(MutableOpResolverTest, FinOp) {
MutableOpResolver resolver;
resolver.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration());
@@ -119,6 +133,26 @@ TEST(MutableOpResolverTest, FindCustomOpWithUnsupportedVersion) {
EXPECT_EQ(found_registration, nullptr);
}
+TEST(MutableOpResolverTest, AddAll) {
+ MutableOpResolver resolver1;
+ resolver1.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration());
+ resolver1.AddBuiltin(BuiltinOperator_MUL, GetDummy2Registration());
+
+ MutableOpResolver resolver2;
+ resolver2.AddBuiltin(BuiltinOperator_SUB, GetDummyRegistration());
+ resolver2.AddBuiltin(BuiltinOperator_ADD, GetDummy2Registration());
+
+ // resolver2's ADD op should replace resolver1's ADD op, while augmenting
+ // non-overlapping ops.
+ resolver1.AddAll(resolver2);
+ ASSERT_EQ(resolver1.FindOp(BuiltinOperator_ADD, 1)->invoke,
+ GetDummy2Registration()->invoke);
+ ASSERT_EQ(resolver1.FindOp(BuiltinOperator_MUL, 1)->invoke,
+ GetDummy2Registration()->invoke);
+ ASSERT_EQ(resolver1.FindOp(BuiltinOperator_SUB, 1)->invoke,
+ GetDummyRegistration()->invoke);
+}
+
} // namespace
} // namespace tflite