aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/experimental/c/c_api.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/experimental/c/c_api.cc')
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api.cc64
1 files changed, 48 insertions, 16 deletions
diff --git a/tensorflow/contrib/lite/experimental/c/c_api.cc b/tensorflow/contrib/lite/experimental/c/c_api.cc
index 9d29e8b3e0..a4ab0e8c30 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api.cc
+++ b/tensorflow/contrib/lite/experimental/c/c_api.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/experimental/c/c_api.h"
#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/experimental/c/c_api_internal.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
@@ -23,28 +24,55 @@ limitations under the License.
extern "C" {
#endif // __cplusplus
-struct _TFL_Interpreter {
- std::unique_ptr<tflite::Interpreter> impl;
-};
-
// LINT.IfChange
-TFL_Interpreter* TFL_NewInterpreter(const void* model_data,
- int32_t model_size) {
+TFL_Model* TFL_NewModel(const void* model_data, size_t model_size) {
auto model = tflite::FlatBufferModel::BuildFromBuffer(
- static_cast<const char*>(model_data), static_cast<size_t>(model_size));
- if (!model) {
+ static_cast<const char*>(model_data), model_size);
+ return model ? new TFL_Model{std::move(model)} : nullptr;
+}
+
+TFL_Model* TFL_NewModelFromFile(const char* model_path) {
+ auto model = tflite::FlatBufferModel::BuildFromFile(model_path);
+ return model ? new TFL_Model{std::move(model)} : nullptr;
+}
+
+void TFL_DeleteModel(TFL_Model* model) { delete model; }
+
+TFL_InterpreterOptions* TFL_NewInterpreterOptions() {
+ return new TFL_InterpreterOptions{};
+}
+
+void TFL_DeleteInterpreterOptions(TFL_InterpreterOptions* options) {
+ delete options;
+}
+
+void TFL_InterpreterOptionsSetNumThreads(TFL_InterpreterOptions* options,
+ int32_t num_threads) {
+ options->num_threads = num_threads;
+}
+
+TFL_Interpreter* TFL_NewInterpreter(
+ const TFL_Model* model, const TFL_InterpreterOptions* optional_options) {
+ if (!model || !model->impl) {
return nullptr;
}
tflite::ops::builtin::BuiltinOpResolver resolver;
- tflite::InterpreterBuilder builder(*model, resolver);
- std::unique_ptr<tflite::Interpreter> interpreter_impl;
- if (builder(&interpreter_impl) != kTfLiteOk) {
+ tflite::InterpreterBuilder builder(*model->impl, resolver);
+ std::unique_ptr<tflite::Interpreter> interpreter;
+ if (builder(&interpreter) != kTfLiteOk) {
return nullptr;
}
- return new TFL_Interpreter{std::move(interpreter_impl)};
+ if (optional_options) {
+ if (optional_options->num_threads !=
+ TFL_InterpreterOptions::kDefaultNumThreads) {
+ interpreter->SetNumThreads(optional_options->num_threads);
+ }
+ }
+
+ return new TFL_Interpreter{std::move(interpreter)};
}
void TFL_DeleteInterpreter(TFL_Interpreter* interpreter) { delete interpreter; }
@@ -97,9 +125,13 @@ int32_t TFL_TensorDim(const TFL_Tensor* tensor, int32_t dim_index) {
size_t TFL_TensorByteSize(const TFL_Tensor* tensor) { return tensor->bytes; }
+void* TFL_TensorData(const TFL_Tensor* tensor) {
+ return static_cast<void*>(tensor->data.raw);
+}
+
TFL_Status TFL_TensorCopyFromBuffer(TFL_Tensor* tensor, const void* input_data,
- int32_t input_data_size) {
- if (tensor->bytes != static_cast<size_t>(input_data_size)) {
+ size_t input_data_size) {
+ if (tensor->bytes != input_data_size) {
return kTfLiteError;
}
memcpy(tensor->data.raw, input_data, input_data_size);
@@ -107,8 +139,8 @@ TFL_Status TFL_TensorCopyFromBuffer(TFL_Tensor* tensor, const void* input_data,
}
TFL_Status TFL_TensorCopyToBuffer(const TFL_Tensor* tensor, void* output_data,
- int32_t output_data_size) {
- if (tensor->bytes != static_cast<size_t>(output_data_size)) {
+ size_t output_data_size) {
+ if (tensor->bytes != output_data_size) {
return kTfLiteError;
}
memcpy(output_data, tensor->data.raw, output_data_size);