aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/experimental
diff options
context:
space:
mode:
authorGravatar Jared Duke <jdduke@google.com>2018-09-17 16:32:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-17 16:39:01 -0700
commit0b80d098704c72f627f37bfeee0ae19788c06fa8 (patch)
tree1012464e6154492010c121b38aa52ac66054b935 /tensorflow/contrib/lite/experimental
parent8ef1ece7d0ecdec633a22a8100fdae05cfbacb3e (diff)
Add basic op resolver registration to TFLite C API
PiperOrigin-RevId: 213360279
Diffstat (limited to 'tensorflow/contrib/lite/experimental')
-rw-r--r--tensorflow/contrib/lite/experimental/c/BUILD2
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api.cc4
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api.h3
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api_experimental.cc16
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api_experimental.h25
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc23
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api_internal.h2
7 files changed, 70 insertions, 5 deletions
diff --git a/tensorflow/contrib/lite/experimental/c/BUILD b/tensorflow/contrib/lite/experimental/c/BUILD
index ea4a543252..835fc2595e 100644
--- a/tensorflow/contrib/lite/experimental/c/BUILD
+++ b/tensorflow/contrib/lite/experimental/c/BUILD
@@ -68,6 +68,7 @@ cc_library(
deps = [
":c_api",
":c_api_internal",
+ "//tensorflow/contrib/lite:kernel_api",
],
)
@@ -93,6 +94,7 @@ cc_test(
deps = [
":c_api",
":c_api_experimental",
+ "//tensorflow/contrib/lite:kernel_api",
"//tensorflow/contrib/lite/testing:util",
"@com_google_googletest//:gtest",
],
diff --git a/tensorflow/contrib/lite/experimental/c/c_api.cc b/tensorflow/contrib/lite/experimental/c/c_api.cc
index c589cf71ea..1c3996fb87 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api.cc
+++ b/tensorflow/contrib/lite/experimental/c/c_api.cc
@@ -62,7 +62,11 @@ TFL_Interpreter* TFL_NewInterpreter(
return nullptr;
}
+ // TODO(b/111881878): Allow use of C API without pulling in all builtin ops.
tflite::ops::builtin::BuiltinOpResolver resolver;
+ if (optional_options) {
+ resolver.AddAll(optional_options->op_resolver);
+ }
tflite::InterpreterBuilder builder(*model->impl, resolver);
std::unique_ptr<tflite::Interpreter> interpreter;
if (builder(&interpreter) != kTfLiteOk) {
diff --git a/tensorflow/contrib/lite/experimental/c/c_api.h b/tensorflow/contrib/lite/experimental/c/c_api.h
index b429e76870..44b936aa87 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api.h
+++ b/tensorflow/contrib/lite/experimental/c/c_api.h
@@ -52,8 +52,9 @@ limitations under the License.
extern "C" {
#endif // __cplusplus
-typedef TfLiteTensor TFL_Tensor;
+typedef TfLiteRegistration TFL_Registration;
typedef TfLiteStatus TFL_Status;
+typedef TfLiteTensor TFL_Tensor;
typedef TfLiteType TFL_Type;
// --------------------------------------------------------------------------
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc b/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc
index c4dbc55cbf..0f16595811 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc
+++ b/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc
@@ -26,6 +26,22 @@ TFL_Status TFL_InterpreterResetVariableTensorsToZero(
return interpreter->impl->ResetVariableTensorsToZero();
}
+void TFL_InterpreterOptionsAddBuiltinOp(TFL_InterpreterOptions* options,
+ TFL_BuiltinOperator op,
+ const TFL_Registration* registration,
+ int32_t min_version,
+ int32_t max_version) {
+ options->op_resolver.AddBuiltin(static_cast<tflite::BuiltinOperator>(op),
+ registration, min_version, max_version);
+}
+
+void TFL_InterpreterOptionsAddCustomOp(TFL_InterpreterOptions* options,
+ const char* name,
+ const TFL_Registration* registration,
+ int min_version, int max_version) {
+ options->op_resolver.AddCustom(name, registration, min_version, max_version);
+}
+
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_experimental.h b/tensorflow/contrib/lite/experimental/c/c_api_experimental.h
index b0ac258dcf..b8de7b9964 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_experimental.h
+++ b/tensorflow/contrib/lite/experimental/c/c_api_experimental.h
@@ -15,16 +15,41 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_EXPERIMENTAL_H_
#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_EXPERIMENTAL_H_
+#include "tensorflow/contrib/lite/builtin_ops.h"
#include "tensorflow/contrib/lite/experimental/c/c_api.h"
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
+typedef TfLiteBuiltinOperator TFL_BuiltinOperator;
+
// Resets all variable tensors to zero.
TFL_CAPI_EXPORT extern TFL_Status TFL_InterpreterResetVariableTensorsToZero(
TFL_Interpreter* interpreter);
+// Adds an op registration for a builtin operator.
+//
+// NOTE: The interpreter will make a copy of `registration` internally, so the
+// caller should ensure that its contents (function pointers, etc...) remain
+// valid for the duration of the interpreter's lifetime. A common practice is
+// making the provided TFL_Registration instance static.
+void TFL_InterpreterOptionsAddBuiltinOp(TFL_InterpreterOptions* options,
+ TFL_BuiltinOperator op,
+ const TFL_Registration* registration,
+ int min_version, int max_version);
+
+// Adds an op registration for a custom operator.
+//
+// NOTE: The interpreter will make a copy of `registration` internally, so the
+// caller should ensure that its contents (function pointers, etc...) remain
+// valid for the duration of the interpreter's lifetime. A common practice is
+// making the provided TFL_Registration instance static.
+void TFL_InterpreterOptionsAddCustomOp(TFL_InterpreterOptions* options,
+ const char* name,
+ const TFL_Registration* registration,
+ int min_version, int max_version);
+
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc b/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc
index db6e5251de..d86ad00d6d 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc
+++ b/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc
@@ -16,25 +16,40 @@ limitations under the License.
#include "tensorflow/contrib/lite/experimental/c/c_api_experimental.h"
#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/builtin_ops.h"
#include "tensorflow/contrib/lite/experimental/c/c_api.h"
#include "tensorflow/contrib/lite/testing/util.h"
namespace {
+TfLiteRegistration* GetDummyRegistration() {
+ static TfLiteRegistration registration = {
+ .init = nullptr,
+ .free = nullptr,
+ .prepare = nullptr,
+ .invoke = [](TfLiteContext*, TfLiteNode*) { return kTfLiteOk; },
+ };
+ return &registration;
+}
+
TEST(CApiExperimentalSimple, Smoke) {
TFL_Model* model = TFL_NewModelFromFile(
"tensorflow/contrib/lite/testdata/add.bin");
ASSERT_NE(model, nullptr);
- TFL_Interpreter* interpreter =
- TFL_NewInterpreter(model, /*optional_options=*/nullptr);
+ TFL_InterpreterOptions* options = TFL_NewInterpreterOptions();
+ TFL_InterpreterOptionsAddBuiltinOp(options, kTfLiteBuiltinAdd,
+ GetDummyRegistration(), 1, 1);
+
+ TFL_Interpreter* interpreter = TFL_NewInterpreter(model, options);
ASSERT_NE(interpreter, nullptr);
ASSERT_EQ(TFL_InterpreterAllocateTensors(interpreter), kTfLiteOk);
-
EXPECT_EQ(TFL_InterpreterResetVariableTensorsToZero(interpreter), kTfLiteOk);
+ EXPECT_EQ(TFL_InterpreterInvoke(interpreter), kTfLiteOk);
- TFL_DeleteModel(model);
TFL_DeleteInterpreter(interpreter);
+ TFL_DeleteInterpreterOptions(options);
+ TFL_DeleteModel(model);
}
} // namespace
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_internal.h b/tensorflow/contrib/lite/experimental/c/c_api_internal.h
index 60c2e4e2cd..af675ac98a 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_internal.h
+++ b/tensorflow/contrib/lite/experimental/c/c_api_internal.h
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/op_resolver.h"
// Internal structures used by the C API. These are likely to change and should
// not be depended on.
@@ -33,6 +34,7 @@ struct TFL_InterpreterOptions {
kDefaultNumThreads = -1,
};
int num_threads = kDefaultNumThreads;
+ tflite::MutableOpResolver op_resolver;
};
struct TFL_Interpreter {