diff options
author | Jared Duke <jdduke@google.com> | 2018-09-17 16:32:12 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-17 16:39:01 -0700 |
commit | 0b80d098704c72f627f37bfeee0ae19788c06fa8 (patch) | |
tree | 1012464e6154492010c121b38aa52ac66054b935 /tensorflow/contrib/lite/experimental | |
parent | 8ef1ece7d0ecdec633a22a8100fdae05cfbacb3e (diff) |
Add basic op resolver registration to TFLite C API
PiperOrigin-RevId: 213360279
Diffstat (limited to 'tensorflow/contrib/lite/experimental')
7 files changed, 70 insertions, 5 deletions
diff --git a/tensorflow/contrib/lite/experimental/c/BUILD b/tensorflow/contrib/lite/experimental/c/BUILD index ea4a543252..835fc2595e 100644 --- a/tensorflow/contrib/lite/experimental/c/BUILD +++ b/tensorflow/contrib/lite/experimental/c/BUILD @@ -68,6 +68,7 @@ cc_library( deps = [ ":c_api", ":c_api_internal", + "//tensorflow/contrib/lite:kernel_api", ], ) @@ -93,6 +94,7 @@ cc_test( deps = [ ":c_api", ":c_api_experimental", + "//tensorflow/contrib/lite:kernel_api", "//tensorflow/contrib/lite/testing:util", "@com_google_googletest//:gtest", ], diff --git a/tensorflow/contrib/lite/experimental/c/c_api.cc b/tensorflow/contrib/lite/experimental/c/c_api.cc index c589cf71ea..1c3996fb87 100644 --- a/tensorflow/contrib/lite/experimental/c/c_api.cc +++ b/tensorflow/contrib/lite/experimental/c/c_api.cc @@ -62,7 +62,11 @@ TFL_Interpreter* TFL_NewInterpreter( return nullptr; } + // TODO(b/111881878): Allow use of C API without pulling in all builtin ops. tflite::ops::builtin::BuiltinOpResolver resolver; + if (optional_options) { + resolver.AddAll(optional_options->op_resolver); + } tflite::InterpreterBuilder builder(*model->impl, resolver); std::unique_ptr<tflite::Interpreter> interpreter; if (builder(&interpreter) != kTfLiteOk) { diff --git a/tensorflow/contrib/lite/experimental/c/c_api.h b/tensorflow/contrib/lite/experimental/c/c_api.h index b429e76870..44b936aa87 100644 --- a/tensorflow/contrib/lite/experimental/c/c_api.h +++ b/tensorflow/contrib/lite/experimental/c/c_api.h @@ -52,8 +52,9 @@ limitations under the License. extern "C" { #endif // __cplusplus -typedef TfLiteTensor TFL_Tensor; +typedef TfLiteRegistration TFL_Registration; typedef TfLiteStatus TFL_Status; +typedef TfLiteTensor TFL_Tensor; typedef TfLiteType TFL_Type; // -------------------------------------------------------------------------- diff --git a/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc b/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc index c4dbc55cbf..0f16595811 100644 --- a/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc +++ b/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc @@ -26,6 +26,22 @@ TFL_Status TFL_InterpreterResetVariableTensorsToZero( return interpreter->impl->ResetVariableTensorsToZero(); } +void TFL_InterpreterOptionsAddBuiltinOp(TFL_InterpreterOptions* options, + TFL_BuiltinOperator op, + const TFL_Registration* registration, + int32_t min_version, + int32_t max_version) { + options->op_resolver.AddBuiltin(static_cast<tflite::BuiltinOperator>(op), + registration, min_version, max_version); +} + +void TFL_InterpreterOptionsAddCustomOp(TFL_InterpreterOptions* options, + const char* name, + const TFL_Registration* registration, + int min_version, int max_version) { + options->op_resolver.AddCustom(name, registration, min_version, max_version); +} + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/contrib/lite/experimental/c/c_api_experimental.h b/tensorflow/contrib/lite/experimental/c/c_api_experimental.h index b0ac258dcf..b8de7b9964 100644 --- a/tensorflow/contrib/lite/experimental/c/c_api_experimental.h +++ b/tensorflow/contrib/lite/experimental/c/c_api_experimental.h @@ -15,16 +15,41 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_EXPERIMENTAL_H_ #define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_EXPERIMENTAL_H_ +#include "tensorflow/contrib/lite/builtin_ops.h" #include "tensorflow/contrib/lite/experimental/c/c_api.h" #ifdef __cplusplus extern "C" { #endif // __cplusplus +typedef TfLiteBuiltinOperator TFL_BuiltinOperator; + // Resets all variable tensors to zero. TFL_CAPI_EXPORT extern TFL_Status TFL_InterpreterResetVariableTensorsToZero( TFL_Interpreter* interpreter); +// Adds an op registration for a builtin operator. +// +// NOTE: The interpreter will make a copy of `registration` internally, so the +// caller should ensure that its contents (function pointers, etc...) remain +// valid for the duration of the interpreter's lifetime. A common practice is +// making the provided TFL_Registration instance static. +void TFL_InterpreterOptionsAddBuiltinOp(TFL_InterpreterOptions* options, + TFL_BuiltinOperator op, + const TFL_Registration* registration, + int min_version, int max_version); + +// Adds an op registration for a custom operator. +// +// NOTE: The interpreter will make a copy of `registration` internally, so the +// caller should ensure that its contents (function pointers, etc...) remain +// valid for the duration of the interpreter's lifetime. A common practice is +// making the provided TFL_Registration instance static. +void TFL_InterpreterOptionsAddCustomOp(TFL_InterpreterOptions* options, + const char* name, + const TFL_Registration* registration, + int min_version, int max_version); + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc b/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc index db6e5251de..d86ad00d6d 100644 --- a/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc +++ b/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc @@ -16,25 +16,40 @@ limitations under the License. #include "tensorflow/contrib/lite/experimental/c/c_api_experimental.h" #include <gtest/gtest.h> +#include "tensorflow/contrib/lite/builtin_ops.h" #include "tensorflow/contrib/lite/experimental/c/c_api.h" #include "tensorflow/contrib/lite/testing/util.h" namespace { +TfLiteRegistration* GetDummyRegistration() { + static TfLiteRegistration registration = { + .init = nullptr, + .free = nullptr, + .prepare = nullptr, + .invoke = [](TfLiteContext*, TfLiteNode*) { return kTfLiteOk; }, + }; + return ®istration; +} + TEST(CApiExperimentalSimple, Smoke) { TFL_Model* model = TFL_NewModelFromFile( "tensorflow/contrib/lite/testdata/add.bin"); ASSERT_NE(model, nullptr); - TFL_Interpreter* interpreter = - TFL_NewInterpreter(model, /*optional_options=*/nullptr); + TFL_InterpreterOptions* options = TFL_NewInterpreterOptions(); + TFL_InterpreterOptionsAddBuiltinOp(options, kTfLiteBuiltinAdd, + GetDummyRegistration(), 1, 1); + + TFL_Interpreter* interpreter = TFL_NewInterpreter(model, options); ASSERT_NE(interpreter, nullptr); ASSERT_EQ(TFL_InterpreterAllocateTensors(interpreter), kTfLiteOk); - EXPECT_EQ(TFL_InterpreterResetVariableTensorsToZero(interpreter), kTfLiteOk); + EXPECT_EQ(TFL_InterpreterInvoke(interpreter), kTfLiteOk); - TFL_DeleteModel(model); TFL_DeleteInterpreter(interpreter); + TFL_DeleteInterpreterOptions(options); + TFL_DeleteModel(model); } } // namespace diff --git a/tensorflow/contrib/lite/experimental/c/c_api_internal.h b/tensorflow/contrib/lite/experimental/c/c_api_internal.h index 60c2e4e2cd..af675ac98a 100644 --- a/tensorflow/contrib/lite/experimental/c/c_api_internal.h +++ b/tensorflow/contrib/lite/experimental/c/c_api_internal.h @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/op_resolver.h" // Internal structures used by the C API. These are likely to change and should // not be depended on. @@ -33,6 +34,7 @@ struct TFL_InterpreterOptions { kDefaultNumThreads = -1, }; int num_threads = kDefaultNumThreads; + tflite::MutableOpResolver op_resolver; }; struct TFL_Interpreter { |