diff options
Diffstat (limited to 'tensorflow/contrib/lite/experimental/c/c_api.cc')
-rw-r--r-- | tensorflow/contrib/lite/experimental/c/c_api.cc | 64 |
1 file changed, 48 insertions, 16 deletions
diff --git a/tensorflow/contrib/lite/experimental/c/c_api.cc b/tensorflow/contrib/lite/experimental/c/c_api.cc
index 9d29e8b3e0..a4ab0e8c30 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api.cc
+++ b/tensorflow/contrib/lite/experimental/c/c_api.cc
@@ -15,6 +15,7 @@ limitations under the License.
 #include "tensorflow/contrib/lite/experimental/c/c_api.h"
 
 #include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/experimental/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/interpreter.h"
 #include "tensorflow/contrib/lite/kernels/register.h"
 #include "tensorflow/contrib/lite/model.h"
@@ -23,28 +24,55 @@ limitations under the License.
 extern "C" {
 #endif  // __cplusplus
 
-struct _TFL_Interpreter {
-  std::unique_ptr<tflite::Interpreter> impl;
-};
-
 // LINT.IfChange
 
-TFL_Interpreter* TFL_NewInterpreter(const void* model_data,
-                                    int32_t model_size) {
+TFL_Model* TFL_NewModel(const void* model_data, size_t model_size) {
   auto model = tflite::FlatBufferModel::BuildFromBuffer(
-      static_cast<const char*>(model_data), static_cast<size_t>(model_size));
-  if (!model) {
+      static_cast<const char*>(model_data), model_size);
+  return model ? new TFL_Model{std::move(model)} : nullptr;
+}
+
+TFL_Model* TFL_NewModelFromFile(const char* model_path) {
+  auto model = tflite::FlatBufferModel::BuildFromFile(model_path);
+  return model ? new TFL_Model{std::move(model)} : nullptr;
+}
+
+void TFL_DeleteModel(TFL_Model* model) { delete model; }
+
+TFL_InterpreterOptions* TFL_NewInterpreterOptions() {
+  return new TFL_InterpreterOptions{};
+}
+
+void TFL_DeleteInterpreterOptions(TFL_InterpreterOptions* options) {
+  delete options;
+}
+
+void TFL_InterpreterOptionsSetNumThreads(TFL_InterpreterOptions* options,
+                                         int32_t num_threads) {
+  options->num_threads = num_threads;
+}
+
+TFL_Interpreter* TFL_NewInterpreter(
+    const TFL_Model* model, const TFL_InterpreterOptions* optional_options) {
+  if (!model || !model->impl) {
     return nullptr;
   }
 
   tflite::ops::builtin::BuiltinOpResolver resolver;
-  tflite::InterpreterBuilder builder(*model, resolver);
-  std::unique_ptr<tflite::Interpreter> interpreter_impl;
-  if (builder(&interpreter_impl) != kTfLiteOk) {
+  tflite::InterpreterBuilder builder(*model->impl, resolver);
+  std::unique_ptr<tflite::Interpreter> interpreter;
+  if (builder(&interpreter) != kTfLiteOk) {
     return nullptr;
   }
 
-  return new TFL_Interpreter{std::move(interpreter_impl)};
+  if (optional_options) {
+    if (optional_options->num_threads !=
+        TFL_InterpreterOptions::kDefaultNumThreads) {
+      interpreter->SetNumThreads(optional_options->num_threads);
+    }
+  }
+
+  return new TFL_Interpreter{std::move(interpreter)};
 }
 
 void TFL_DeleteInterpreter(TFL_Interpreter* interpreter) { delete interpreter; }
@@ -97,9 +125,13 @@ int32_t TFL_TensorDim(const TFL_Tensor* tensor, int32_t dim_index) {
 
 size_t TFL_TensorByteSize(const TFL_Tensor* tensor) { return tensor->bytes; }
 
+void* TFL_TensorData(const TFL_Tensor* tensor) {
+  return static_cast<void*>(tensor->data.raw);
+}
+
 TFL_Status TFL_TensorCopyFromBuffer(TFL_Tensor* tensor, const void* input_data,
-                                    int32_t input_data_size) {
-  if (tensor->bytes != static_cast<size_t>(input_data_size)) {
+                                    size_t input_data_size) {
+  if (tensor->bytes != input_data_size) {
     return kTfLiteError;
   }
   memcpy(tensor->data.raw, input_data, input_data_size);
@@ -107,8 +139,8 @@ TFL_Status TFL_TensorCopyFromBuffer(TFL_Tensor* tensor, const void* input_data,
 }
 
 TFL_Status TFL_TensorCopyToBuffer(const TFL_Tensor* tensor, void* output_data,
-                                    int32_t output_data_size) {
-  if (tensor->bytes != static_cast<size_t>(output_data_size)) {
+                                  size_t output_data_size) {
+  if (tensor->bytes != output_data_size) {
     return kTfLiteError;
   }
   memcpy(output_data, tensor->data.raw, output_data_size);