diff options
author | Jared Duke <jdduke@google.com> | 2018-08-10 15:02:09 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-10 15:08:36 -0700 |
commit | 83a1435684149e381521de528c3af40daa784570 (patch) | |
tree | f0d4464d7f0566c143dd2a1f5c5457fecc0de462 /tensorflow/contrib/lite/experimental | |
parent | 3d0edd130c137088d7815e1f3e67719d05fd9ab1 (diff) |
Incremental update to the TFLite C API
PiperOrigin-RevId: 208273960
Diffstat (limited to 'tensorflow/contrib/lite/experimental')
9 files changed, 317 insertions, 33 deletions
diff --git a/tensorflow/contrib/lite/experimental/c/BUILD b/tensorflow/contrib/lite/experimental/c/BUILD index 50f8da66d0..8fc07e8eb7 100644 --- a/tensorflow/contrib/lite/experimental/c/BUILD +++ b/tensorflow/contrib/lite/experimental/c/BUILD @@ -26,17 +26,33 @@ tflite_cc_shared_object( }), deps = [ ":c_api", + ":c_api_experimental", ":exported_symbols.lds", ":version_script.lds", ], ) cc_library( + name = "c_api_internal", + srcs = ["c_api.h"], + hdrs = ["c_api_internal.h"], + copts = tflite_copts(), + visibility = [ + "//tensorflow/contrib/lite/experimental/c:__subpackages__", + ], + deps = [ + "//tensorflow/contrib/lite:context", + "//tensorflow/contrib/lite:framework", + ], +) + +cc_library( name = "c_api", srcs = ["c_api.cc"], hdrs = ["c_api.h"], copts = tflite_copts(), deps = [ + ":c_api_internal", "//tensorflow/contrib/lite:context", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:schema_fbs_version", @@ -44,6 +60,17 @@ cc_library( ], ) +cc_library( + name = "c_api_experimental", + srcs = ["c_api_experimental.cc"], + hdrs = ["c_api_experimental.h"], + copts = tflite_copts(), + deps = [ + ":c_api", + ":c_api_internal", + ], +) + cc_test( name = "c_api_test", size = "small", @@ -51,9 +78,21 @@ cc_test( data = ["//tensorflow/contrib/lite:testdata/add.bin"], deps = [ ":c_api", - "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:kernel_api", "//tensorflow/contrib/lite/testing:util", "@com_google_googletest//:gtest", ], ) + +cc_test( + name = "c_api_experimental_test", + size = "small", + srcs = ["c_api_experimental_test.cc"], + data = ["//tensorflow/contrib/lite:testdata/add.bin"], + deps = [ + ":c_api", + ":c_api_experimental", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) diff --git a/tensorflow/contrib/lite/experimental/c/c_api.cc b/tensorflow/contrib/lite/experimental/c/c_api.cc index 9d29e8b3e0..a4ab0e8c30 100644 --- a/tensorflow/contrib/lite/experimental/c/c_api.cc +++ b/tensorflow/contrib/lite/experimental/c/c_api.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/contrib/lite/experimental/c/c_api.h" #include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/experimental/c/c_api_internal.h" #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/model.h" @@ -23,28 +24,55 @@ limitations under the License. extern "C" { #endif // __cplusplus -struct _TFL_Interpreter { - std::unique_ptr<tflite::Interpreter> impl; -}; - // LINT.IfChange -TFL_Interpreter* TFL_NewInterpreter(const void* model_data, - int32_t model_size) { +TFL_Model* TFL_NewModel(const void* model_data, size_t model_size) { auto model = tflite::FlatBufferModel::BuildFromBuffer( - static_cast<const char*>(model_data), static_cast<size_t>(model_size)); - if (!model) { + static_cast<const char*>(model_data), model_size); + return model ? new TFL_Model{std::move(model)} : nullptr; +} + +TFL_Model* TFL_NewModelFromFile(const char* model_path) { + auto model = tflite::FlatBufferModel::BuildFromFile(model_path); + return model ? new TFL_Model{std::move(model)} : nullptr; +} + +void TFL_DeleteModel(TFL_Model* model) { delete model; } + +TFL_InterpreterOptions* TFL_NewInterpreterOptions() { + return new TFL_InterpreterOptions{}; +} + +void TFL_DeleteInterpreterOptions(TFL_InterpreterOptions* options) { + delete options; +} + +void TFL_InterpreterOptionsSetNumThreads(TFL_InterpreterOptions* options, + int32_t num_threads) { + options->num_threads = num_threads; +} + +TFL_Interpreter* TFL_NewInterpreter( + const TFL_Model* model, const TFL_InterpreterOptions* optional_options) { + if (!model || !model->impl) { return nullptr; } tflite::ops::builtin::BuiltinOpResolver resolver; - tflite::InterpreterBuilder builder(*model, resolver); - std::unique_ptr<tflite::Interpreter> interpreter_impl; - if (builder(&interpreter_impl) != kTfLiteOk) { + tflite::InterpreterBuilder builder(*model->impl, resolver); + std::unique_ptr<tflite::Interpreter> interpreter; + if (builder(&interpreter) != kTfLiteOk) { return nullptr; } - return new TFL_Interpreter{std::move(interpreter_impl)}; + if (optional_options) { + if (optional_options->num_threads != + TFL_InterpreterOptions::kDefaultNumThreads) { + interpreter->SetNumThreads(optional_options->num_threads); + } + } + + return new TFL_Interpreter{std::move(interpreter)}; } void TFL_DeleteInterpreter(TFL_Interpreter* interpreter) { delete interpreter; } @@ -97,9 +125,13 @@ int32_t TFL_TensorDim(const TFL_Tensor* tensor, int32_t dim_index) { size_t TFL_TensorByteSize(const TFL_Tensor* tensor) { return tensor->bytes; } +void* TFL_TensorData(const TFL_Tensor* tensor) { + return static_cast<void*>(tensor->data.raw); +} + TFL_Status TFL_TensorCopyFromBuffer(TFL_Tensor* tensor, const void* input_data, - int32_t input_data_size) { - if (tensor->bytes != static_cast<size_t>(input_data_size)) { + size_t input_data_size) { + if (tensor->bytes != input_data_size) { return kTfLiteError; } memcpy(tensor->data.raw, input_data, input_data_size); @@ -107,8 +139,8 @@ TFL_Status TFL_TensorCopyFromBuffer(TFL_Tensor* tensor, const void* input_data, } TFL_Status TFL_TensorCopyToBuffer(const TFL_Tensor* tensor, void* output_data, - int32_t output_data_size) { - if (tensor->bytes != static_cast<size_t>(output_data_size)) { + size_t output_data_size) { + if (tensor->bytes != output_data_size) { return kTfLiteError; } memcpy(output_data, tensor->data.raw, output_data_size); diff --git a/tensorflow/contrib/lite/experimental/c/c_api.h b/tensorflow/contrib/lite/experimental/c/c_api.h index 070f1add13..3757349b55 100644 --- a/tensorflow/contrib/lite/experimental/c/c_api.h +++ b/tensorflow/contrib/lite/experimental/c/c_api.h @@ -30,6 +30,9 @@ limitations under the License. // // Conventions: // * We use the prefix TFL_ for everything in the API. +// * size_t is used to represent byte sizes of objects that are +// materialized in the address space of the calling process. +// * int is used as an index into arrays. #ifdef SWIG #define TFL_CAPI_EXPORT @@ -54,15 +57,50 @@ typedef TfLiteStatus TFL_Status; typedef TfLiteType TFL_Type; // -------------------------------------------------------------------------- +// TFL_Model wraps a loaded TensorFlow Lite model. +typedef struct TFL_Model TFL_Model; + +// Returns a model from the provided buffer, or null on failure. +TFL_CAPI_EXPORT extern TFL_Model* TFL_NewModel(const void* model_data, + size_t model_size); + +// Returns a model from the provided file, or null on failure. +TFL_CAPI_EXPORT extern TFL_Model* TFL_NewModelFromFile(const char* model_path); + +// Destroys the model instance. +TFL_CAPI_EXPORT extern void TFL_DeleteModel(TFL_Model* model); + +// -------------------------------------------------------------------------- +// TFL_InterpreterOptions allows customized interpreter configuration. +typedef struct TFL_InterpreterOptions TFL_InterpreterOptions; + +// Returns a new interpreter options instances. +TFL_CAPI_EXPORT extern TFL_InterpreterOptions* TFL_NewInterpreterOptions(); + +// Destroys the interpreter options instance. +TFL_CAPI_EXPORT extern void TFL_DeleteInterpreterOptions( + TFL_InterpreterOptions* options); + +// Sets the number of CPU threads to use for the interpreter. +TFL_CAPI_EXPORT extern void TFL_InterpreterOptionsSetNumThreads( + TFL_InterpreterOptions* options, int32_t num_threads); + +// -------------------------------------------------------------------------- // TFL_Interpreter provides inference from a provided model. -typedef struct _TFL_Interpreter TFL_Interpreter; +typedef struct TFL_Interpreter TFL_Interpreter; -// Returns an interpreter for the provided model, or null on failure. +// Returns a new interpreter using the provided model and options, or null on +// failure. +// +// * `model` must be a valid model instance. The caller retains ownership of the +// object, and can destroy it immediately after creating the interpreter. +// * `optional_options` may be null. The caller retains ownership of the object, +// and can safely destroy it immediately after creating the interpreter. // // NOTE: The client *must* explicitly allocate tensors before attempting to // access input tensor data or invoke the interpreter. TFL_CAPI_EXPORT extern TFL_Interpreter* TFL_NewInterpreter( - const void* model_data, int32_t model_size); + const TFL_Model* model, const TFL_InterpreterOptions* optional_options); // Destroys the interpreter. TFL_CAPI_EXPORT extern void TFL_DeleteInterpreter(TFL_Interpreter* interpreter); @@ -76,7 +114,8 @@ TFL_CAPI_EXPORT extern int TFL_InterpreterGetInputTensorCount( TFL_CAPI_EXPORT extern TFL_Tensor* TFL_InterpreterGetInputTensor( const TFL_Interpreter* interpreter, int32_t input_index); -// Attempts to resize the specified input tensor. +// Resizes the specified input tensor. +// // NOTE: After a resize, the client *must* explicitly allocate tensors before // attempting to access the resized tensor data or invoke the interpreter. // REQUIRES: 0 <= input_index < TFL_InterpreterGetInputTensorCount(tensor) @@ -131,16 +170,24 @@ TFL_CAPI_EXPORT extern int32_t TFL_TensorDim(const TFL_Tensor* tensor, // Returns the size of the underlying data in bytes. TFL_CAPI_EXPORT extern size_t TFL_TensorByteSize(const TFL_Tensor* tensor); +// Returns a pointer to the underlying data buffer. +// +// Note: The result may be null if tensors have not yet been allocated, e.g., +// if the Tensor has just been created or resized and `TFL_AllocateTensors()` +// has yet to be called, or if the output tensor is dynamically sized and the +// interpreter hasn't been invoked. +TFL_CAPI_EXPORT extern void* TFL_TensorData(const TFL_Tensor* tensor); + // Copies from the provided input buffer into the tensor's buffer. // REQUIRES: input_data_size == TFL_TensorByteSize(tensor) TFL_CAPI_EXPORT extern TFL_Status TFL_TensorCopyFromBuffer( - TFL_Tensor* tensor, const void* input_data, int32_t input_data_size); + TFL_Tensor* tensor, const void* input_data, size_t input_data_size); // Copies to the provided output buffer from the tensor's buffer. // REQUIRES: output_data_size == TFL_TensorByteSize(tensor) TFL_CAPI_EXPORT extern TFL_Status TFL_TensorCopyToBuffer( const TFL_Tensor* output_tensor, void* output_data, - int32_t output_data_size); + size_t output_data_size); #ifdef __cplusplus } // extern "C" diff --git a/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc b/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc new file mode 100644 index 0000000000..c4dbc55cbf --- /dev/null +++ b/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc @@ -0,0 +1,31 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/experimental/c/c_api_experimental.h" + +#include "tensorflow/contrib/lite/experimental/c/c_api_internal.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +TFL_Status TFL_InterpreterResetVariableTensorsToZero( + TFL_Interpreter* interpreter) { + return interpreter->impl->ResetVariableTensorsToZero(); +} + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/tensorflow/contrib/lite/experimental/c/c_api_experimental.h b/tensorflow/contrib/lite/experimental/c/c_api_experimental.h new file mode 100644 index 0000000000..b0ac258dcf --- /dev/null +++ b/tensorflow/contrib/lite/experimental/c/c_api_experimental.h @@ -0,0 +1,32 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_EXPERIMENTAL_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_EXPERIMENTAL_H_ + +#include "tensorflow/contrib/lite/experimental/c/c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Resets all variable tensors to zero. +TFL_CAPI_EXPORT extern TFL_Status TFL_InterpreterResetVariableTensorsToZero( + TFL_Interpreter* interpreter); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_EXPERIMENTAL_H_ diff --git a/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc b/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc new file mode 100644 index 0000000000..db6e5251de --- /dev/null +++ b/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/experimental/c/c_api_experimental.h" + +#include <gtest/gtest.h> +#include "tensorflow/contrib/lite/experimental/c/c_api.h" +#include "tensorflow/contrib/lite/testing/util.h" + +namespace { + +TEST(CApiExperimentalSimple, Smoke) { + TFL_Model* model = TFL_NewModelFromFile( + "tensorflow/contrib/lite/testdata/add.bin"); + ASSERT_NE(model, nullptr); + + TFL_Interpreter* interpreter = + TFL_NewInterpreter(model, /*optional_options=*/nullptr); + ASSERT_NE(interpreter, nullptr); + ASSERT_EQ(TFL_InterpreterAllocateTensors(interpreter), kTfLiteOk); + + EXPECT_EQ(TFL_InterpreterResetVariableTensorsToZero(interpreter), kTfLiteOk); + + TFL_DeleteModel(model); + TFL_DeleteInterpreter(interpreter); +} + +} // namespace + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/experimental/c/c_api_internal.h b/tensorflow/contrib/lite/experimental/c/c_api_internal.h new file mode 100644 index 0000000000..c5c612a4c6 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/c/c_api_internal.h @@ -0,0 +1,41 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_INTERNAL_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_INTERNAL_H_ + +#include "tensorflow/contrib/lite/experimental/c/c_api.h" + +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/model.h" + +// Internal structures used by the C API. These are likely to change and should +// not be depended on. + +struct TFL_Model { + std::unique_ptr<tflite::FlatBufferModel> impl; +}; + +struct TFL_InterpreterOptions { + enum { + kDefaultNumThreads = -1, + }; + int num_threads = kDefaultNumThreads; +}; + +struct TFL_Interpreter { + std::unique_ptr<tflite::Interpreter> impl; +}; + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_INTERNAL_H_ diff --git a/tensorflow/contrib/lite/experimental/c/c_api_test.cc b/tensorflow/contrib/lite/experimental/c/c_api_test.cc index bc925e00a6..a631dae890 100644 --- a/tensorflow/contrib/lite/experimental/c/c_api_test.cc +++ b/tensorflow/contrib/lite/experimental/c/c_api_test.cc @@ -18,22 +18,28 @@ limitations under the License. #include "tensorflow/contrib/lite/experimental/c/c_api.h" #include <gtest/gtest.h> -#include "tensorflow/contrib/lite/allocation.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/testing/util.h" namespace { TEST(CApiSimple, Smoke) { - tflite::FileCopyAllocation model_file( - "tensorflow/contrib/lite/testdata/add.bin", - tflite::DefaultErrorReporter()); + TFL_Model* model = TFL_NewModelFromFile( + "tensorflow/contrib/lite/testdata/add.bin"); + ASSERT_NE(model, nullptr); - TFL_Interpreter* interpreter = - TFL_NewInterpreter(model_file.base(), model_file.bytes()); + TFL_InterpreterOptions* options = TFL_NewInterpreterOptions(); + ASSERT_NE(options, nullptr); + TFL_InterpreterOptionsSetNumThreads(options, 2); + + TFL_Interpreter* interpreter = TFL_NewInterpreter(model, options); ASSERT_NE(interpreter, nullptr); - ASSERT_EQ(TFL_InterpreterAllocateTensors(interpreter), kTfLiteOk); + // The options/model can be deleted immediately after interpreter creation. + TFL_DeleteInterpreterOptions(options); + TFL_DeleteModel(model); + + ASSERT_EQ(TFL_InterpreterAllocateTensors(interpreter), kTfLiteOk); ASSERT_EQ(TFL_InterpreterGetInputTensorCount(interpreter), 1); ASSERT_EQ(TFL_InterpreterGetOutputTensorCount(interpreter), 1); diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs index ab966bae2e..b6905b5fbf 100644 --- a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs @@ -16,6 +16,8 @@ using System; using System.Runtime.InteropServices; using TFL_Interpreter = System.IntPtr; +using TFL_InterpreterOptions = System.IntPtr; +using TFL_Model = System.IntPtr; using TFL_Tensor = System.IntPtr; namespace TensorFlowLite @@ -32,7 +34,9 @@ namespace TensorFlowLite public Interpreter(byte[] modelData) { GCHandle modelDataHandle = GCHandle.Alloc(modelData, GCHandleType.Pinned); IntPtr modelDataPtr = modelDataHandle.AddrOfPinnedObject(); - handle = TFL_NewInterpreter(modelDataPtr, modelData.Length); + TFL_Model model = TFL_NewModel(modelDataPtr, modelData.Length); + handle = TFL_NewInterpreter(model, /*options=*/IntPtr.Zero); + TFL_DeleteModel(model); if (handle == IntPtr.Zero) throw new Exception("Failed to create TensorFlowLite Interpreter"); } @@ -89,9 +93,15 @@ namespace TensorFlowLite #region Externs [DllImport (TensorFlowLibrary)] + private static extern unsafe TFL_Interpreter TFL_NewModel(IntPtr model_data, int model_size); + + [DllImport (TensorFlowLibrary)] + private static extern unsafe TFL_Interpreter TFL_DeleteModel(TFL_Model model); + + [DllImport (TensorFlowLibrary)] private static extern unsafe TFL_Interpreter TFL_NewInterpreter( - IntPtr model_data, - int model_size); + TFL_Model model, + TFL_InterpreterOptions optional_options); [DllImport (TensorFlowLibrary)] private static extern unsafe void TFL_DeleteInterpreter(TFL_Interpreter interpreter); |