From 9982fd6c8831cbd2f58954f79ea71f26660393bc Mon Sep 17 00:00:00 2001 From: Pete Warden Date: Fri, 7 Sep 2018 17:36:59 -0700 Subject: Modularize TF Lite interface definitions and reorganize file structure PiperOrigin-RevId: 212064501 --- tensorflow/contrib/lite/c/BUILD | 39 ++ tensorflow/contrib/lite/c/builtin_op_data.h | 298 +++++++++++++ tensorflow/contrib/lite/c/builtin_op_data_test.cc | 83 ++++ tensorflow/contrib/lite/c/c_api_internal.c | 104 +++++ tensorflow/contrib/lite/c/c_api_internal.h | 491 ++++++++++++++++++++++ tensorflow/contrib/lite/c/c_api_internal_test.cc | 73 ++++ 6 files changed, 1088 insertions(+) create mode 100644 tensorflow/contrib/lite/c/BUILD create mode 100644 tensorflow/contrib/lite/c/builtin_op_data.h create mode 100644 tensorflow/contrib/lite/c/builtin_op_data_test.cc create mode 100644 tensorflow/contrib/lite/c/c_api_internal.c create mode 100644 tensorflow/contrib/lite/c/c_api_internal.h create mode 100644 tensorflow/contrib/lite/c/c_api_internal_test.cc (limited to 'tensorflow/contrib/lite/c') diff --git a/tensorflow/contrib/lite/c/BUILD b/tensorflow/contrib/lite/c/BUILD new file mode 100644 index 0000000000..663eb63cad --- /dev/null +++ b/tensorflow/contrib/lite/c/BUILD @@ -0,0 +1,39 @@ +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +cc_library( + name = "c_api_internal", + srcs = ["c_api_internal.c"], + hdrs = [ + "builtin_op_data.h", + "c_api_internal.h", + ], + visibility = [ + "//tensorflow/contrib/lite:__subpackages__", + ], +) + +# Test the C extension API code. +cc_test( + name = "c_api_internal_test", + size = "small", + srcs = ["c_api_internal_test.cc"], + deps = [ + ":c_api_internal", + "@com_google_googletest//:gtest", + ], +) + +cc_test( + name = "builtin_op_data_test", + size = "small", + srcs = ["builtin_op_data_test.cc"], + copts = ["-Wno-unused-variable"], + deps = [ + ":c_api_internal", + "@com_google_googletest//:gtest", + ], +) diff --git a/tensorflow/contrib/lite/c/builtin_op_data.h b/tensorflow/contrib/lite/c/builtin_op_data.h new file mode 100644 index 0000000000..fa43e6a024 --- /dev/null +++ b/tensorflow/contrib/lite/c/builtin_op_data.h @@ -0,0 +1,298 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_C_BUILTIN_OP_DATA_H_ +#define TENSORFLOW_CONTRIB_LITE_C_BUILTIN_OP_DATA_H_ + +#include + +#include "tensorflow/contrib/lite/c/c_api_internal.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// TODO(aselle): Consider using "if this then that" for testing. + +// Possible padding types (for convolutions) +typedef enum { + kTfLitePaddingUnknown = 0, + kTfLitePaddingSame, + kTfLitePaddingValid, +} TfLitePadding; + +typedef struct { + int width; + int height; +} TfLitePaddingValues; + +// Possible fused activation functions. +// TODO(aselle): rename to TfLiteActivation +typedef enum { + kTfLiteActNone = 0, + kTfLiteActRelu, + kTfLiteActRelu1, + kTfLiteActRelu6, + kTfLiteActTanh, + kTfLiteActSignBit, + kTfLiteActSigmoid, +} TfLiteFusedActivation; + +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; + int dilation_width_factor; + int dilation_height_factor; + TfLiteFusedActivation activation; +} TfLiteConvParams; + +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; + int filter_width; + int filter_height; + TfLiteFusedActivation activation; + struct { + TfLitePaddingValues padding; + } computed; +} TfLitePoolParams; + +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; + int depth_multiplier; + TfLiteFusedActivation activation; +} TfLiteDepthwiseConvParams; + +typedef struct { + int rank; + TfLiteFusedActivation activation; +} TfLiteSVDFParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteRNNParams; + +typedef struct { + bool time_major; + TfLiteFusedActivation activation; +} TfLiteSequenceRNNParams; + +typedef enum { + kTfLiteFullyConnectedWeightsFormatDefault = 0, + kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8 = 1, +} TfLiteFullyConnectedWeightsFormat; + +typedef struct { + // Parameters for FullyConnected version 1 or above. + TfLiteFusedActivation activation; + + // Parameters for FullyConnected version 2 or above. + TfLiteFullyConnectedWeightsFormat weights_format; +} TfLiteFullyConnectedParams; + +typedef enum { + kTfLiteLshProjectionUnknown = 0, + kTfLiteLshProjectionSparse = 1, + kTfLiteLshProjectionDense = 2, +} TfLiteLSHProjectionType; + +typedef struct { + TfLiteLSHProjectionType type; +} TfLiteLSHProjectionParams; + +typedef struct { + float beta; +} TfLiteSoftmaxParams; + +typedef struct { + int axis; + TfLiteFusedActivation activation; +} TfLiteConcatenationParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteAddParams; + +typedef struct { +} TfLiteSpaceToBatchNDParams; + +typedef struct { +} TfLiteBatchToSpaceNDParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteMulParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteSubParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteDivParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteL2NormParams; + +typedef struct { + int radius; + float bias; + float alpha; + float beta; +} TfLiteLocalResponseNormParams; + +typedef enum { + kTfLiteLSTMFullKernel = 0, + kTfLiteLSTMBasicKernel +} TfLiteLSTMKernelType; + +typedef struct { + // Parameters for LSTM version 1. + TfLiteFusedActivation activation; + float cell_clip; + float proj_clip; + + // Parameters for LSTM version 2. + // kTfLiteLSTMBasicKernel is only supported in version 2 or above. + TfLiteLSTMKernelType kernel_type; +} TfLiteLSTMParams; + +typedef struct { + bool align_corners; +} TfLiteResizeBilinearParams; + +typedef struct { +} TfLitePadParams; + +typedef struct { +} TfLitePadV2Params; + +typedef struct { + // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. + // For now we will fix the maximum possible number of dimensions. + int shape[8]; + int num_dimensions; +} TfLiteReshapeParams; + +typedef struct { + int ngram_size; + int max_skip_size; + bool include_all_ngrams; +} TfLiteSkipGramParams; + +typedef struct { + int block_size; +} TfLiteSpaceToDepthParams; + +typedef struct { + TfLiteType in_data_type; + TfLiteType out_data_type; +} TfLiteCastParams; + +typedef enum { + kTfLiteCombinerTypeSum = 0, + kTfLiteCombinerTypeMean = 1, + kTfLiteCombinerTypeSqrtn = 2, +} TfLiteCombinerType; + +typedef struct { + TfLiteCombinerType combiner; +} TfLiteEmbeddingLookupSparseParams; + +typedef struct { + int axis; +} TfLiteGatherParams; + +typedef struct { +} TfLiteTransposeParams; + +typedef struct { + bool keep_dims; +} TfLiteReducerParams; + +typedef struct { + int num_splits; +} TfLiteSplitParams; + +typedef struct { + // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. + // For now we will fix the maximum possible number of dimensions. + int squeeze_dims[8]; + int num_squeeze_dims; +} TfLiteSqueezeParams; + +typedef struct { + int begin_mask; + int end_mask; + int ellipsis_mask; + int new_axis_mask; + int shrink_axis_mask; +} TfLiteStridedSliceParams; + +typedef struct { + TfLiteType output_type; +} TfLiteArgMaxParams; + +typedef struct { + TfLiteType output_type; +} TfLiteArgMinParams; + +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; +} TfLiteTransposeConvParams; + +typedef struct { + bool validate_indices; +} TfLiteSparseToDenseParams; + +typedef struct { + TfLiteType out_type; +} TfLiteShapeParams; + +typedef struct { + // Parameters supported by version 1: + float min; + float max; + int num_bits; + + // Parameters supported by version 2: + bool narrow_range; +} TfLiteFakeQuantParams; + +typedef struct { + int values_count; + int axis; +} TfLitePackParams; + +typedef struct { + int axis; +} TfLiteOneHotParams; + +typedef struct { + int num; + int axis; +} TfLiteUnpackParams; + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_CONTRIB_LITE_C_BUILTIN_OP_DATA_H_ diff --git a/tensorflow/contrib/lite/c/builtin_op_data_test.cc b/tensorflow/contrib/lite/c/builtin_op_data_test.cc new file mode 100644 index 0000000000..4d0ba75e68 --- /dev/null +++ b/tensorflow/contrib/lite/c/builtin_op_data_test.cc @@ -0,0 +1,83 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include + +namespace tflite { + +// Builtin op data is just a set of data definitions, so the only meaningful +// test we can run is whether we can create the structs we expect to find. +// Testing each struct's members might be possible, but it seems unnecessary +// until we've locked down the API. The build rule has copts set to ignore the +// unused variable warning, since this is just a compilation test. +TEST(IntArray, CanCompileStructs) { + TfLitePadding padding = kTfLitePaddingSame; + TfLitePaddingValues padding_values; + TfLiteFusedActivation fused_activation = kTfLiteActRelu; + TfLiteConvParams conv_params; + TfLitePoolParams pool_params; + TfLiteDepthwiseConvParams depthwise_conv_params; + TfLiteSVDFParams svdf_params; + TfLiteRNNParams rnn_params; + TfLiteSequenceRNNParams sequence_rnn_params; + TfLiteFullyConnectedWeightsFormat fully_connected_weights_format = + kTfLiteFullyConnectedWeightsFormatDefault; + TfLiteFullyConnectedParams fully_connected_params; + TfLiteLSHProjectionType projection_type = kTfLiteLshProjectionDense; + TfLiteLSHProjectionParams projection_params; + TfLiteSoftmaxParams softmax_params; + TfLiteConcatenationParams concatenation_params; + TfLiteAddParams add_params; + TfLiteSpaceToBatchNDParams space_to_batch_nd_params; + TfLiteBatchToSpaceNDParams batch_to_space_nd_params; + TfLiteMulParams mul_params; + TfLiteSubParams sub_params; + TfLiteDivParams div_params; + TfLiteL2NormParams l2_norm_params; + TfLiteLocalResponseNormParams local_response_norm_params; + TfLiteLSTMKernelType lstm_kernel_type = kTfLiteLSTMBasicKernel; + TfLiteLSTMParams lstm_params; + TfLiteResizeBilinearParams resize_bilinear_params; + TfLitePadParams pad_params; + TfLitePadV2Params pad_v2_params; + TfLiteReshapeParams reshape_params; + TfLiteSkipGramParams skip_gram_params; + TfLiteSpaceToDepthParams space_to_depth_params; + TfLiteCastParams cast_params; + TfLiteCombinerType combiner_type = kTfLiteCombinerTypeSqrtn; + TfLiteEmbeddingLookupSparseParams lookup_sparse_params; + TfLiteGatherParams gather_params; + TfLiteTransposeParams transpose_params; + TfLiteReducerParams reducer_params; + TfLiteSplitParams split_params; + TfLiteSqueezeParams squeeze_params; + TfLiteStridedSliceParams strided_slice_params; + TfLiteArgMaxParams arg_max_params; + TfLiteArgMinParams arg_min_params; + TfLiteTransposeConvParams transpose_conv_params; + TfLiteSparseToDenseParams sparse_to_dense_params; + TfLiteShapeParams shape_params; + TfLiteFakeQuantParams fake_quant_params; + TfLitePackParams pack_params; + TfLiteOneHotParams one_hot_params; +} + +} // namespace tflite + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/c/c_api_internal.c b/tensorflow/contrib/lite/c/c_api_internal.c new file mode 100644 index 0000000000..1846bad4b7 --- /dev/null +++ b/tensorflow/contrib/lite/c/c_api_internal.c @@ -0,0 +1,104 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include +#include +#include + +int TfLiteIntArrayGetSizeInBytes(int size) { + static TfLiteIntArray dummy; + return sizeof(dummy) + sizeof(dummy.data[0]) * size; +} + +TfLiteIntArray* TfLiteIntArrayCreate(int size) { + TfLiteIntArray* ret = + (TfLiteIntArray*)malloc(TfLiteIntArrayGetSizeInBytes(size)); + ret->size = size; + return ret; +} + +void TfLiteIntArrayPrint(const char* s, TfLiteIntArray* a) { + printf("%s: length=%d [", s, a->size); + if (a->size) printf("%d", a->data[0]); + int i = 1; + for (; i < a->size; i++) { + printf(" %d", a->data[i]); + } + printf("]\n"); +} + +int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b) { + if (a == b) return 1; + if (a == NULL || b == NULL) return 0; + if (a->size != b->size) return 0; + int i = 0; + for (; i < a->size; i++) + if (a->data[i] != b->data[i]) return 0; + return 1; +} + +TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src) { + if (!src) return NULL; + TfLiteIntArray* ret = TfLiteIntArrayCreate(src->size); + if (ret) { + memcpy(ret->data, src->data, src->size * sizeof(int)); + } + return ret; +} + +void TfLiteIntArrayFree(TfLiteIntArray* a) { free(a); } + +void TfLiteTensorDataFree(TfLiteTensor* t) { + if (t->allocation_type == kTfLiteDynamic && t->data.raw) { + free(t->data.raw); + } + t->data.raw = NULL; +} + +void TfLiteTensorFree(TfLiteTensor* t) { + TfLiteTensorDataFree(t); + if (t->dims) TfLiteIntArrayFree(t->dims); + t->dims = NULL; +} + +void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims, + TfLiteQuantizationParams quantization, char* buffer, + size_t size, TfLiteAllocationType allocation_type, + const void* allocation, bool is_variable, + TfLiteTensor* tensor) { + TfLiteTensorFree(tensor); + tensor->type = type; + tensor->name = name; + tensor->dims = dims; + tensor->params = quantization; + tensor->data.raw = buffer; + tensor->bytes = size; + tensor->allocation_type = allocation_type; + tensor->allocation = allocation; + tensor->is_variable = is_variable; +} + +void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor) { + if (tensor->allocation_type != kTfLiteDynamic) { + return; + } + if (!tensor->data.raw) { + tensor->data.raw = malloc(num_bytes); + } else if (num_bytes > tensor->bytes) { + tensor->data.raw = realloc(tensor->data.raw, num_bytes); + } + tensor->bytes = num_bytes; +} diff --git a/tensorflow/contrib/lite/c/c_api_internal.h b/tensorflow/contrib/lite/c/c_api_internal.h new file mode 100644 index 0000000000..48df68a654 --- /dev/null +++ b/tensorflow/contrib/lite/c/c_api_internal.h @@ -0,0 +1,491 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This file defines a C API for implementing operations in tflite. +// These operations can be defined using c++ but the interface between +// the interpreter and the operations are C. +// +// Summary of abstractions +// TF_LITE_ENSURE - Self-sufficient error checking +// TfLiteStatus - Status reporting +// TfLiteIntArray - stores tensor shapes (dims), +// TfLiteContext - allows an op to access the tensors +// TfLiteTensor - tensor (a multidimensional array) +// TfLiteNode - a single node or operation +// TfLiteRegistration - the implementation of a conceptual operation. +// +// Some abstractions in this file are created and managed by Interpreter. +#ifndef TENSORFLOW_CONTRIB_LITE_C_C_API_INTERNAL_H_ +#define TENSORFLOW_CONTRIB_LITE_C_C_API_INTERNAL_H_ + +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus; + +// The list of external context types known to TF Lite. This list exists solely +// to avoid conflicts and to ensure ops can share the external contexts they +// need. Access to the external contexts is controled by one of the +// corresponding support files. +typedef enum { + kTfLiteEigenContext = 0, // include eigen_support.h to use. + kTfLiteGemmLowpContext = 1, // include gemm_support.h to use. + kTfLiteEdgeTpuContext = 2, // Placeholder for Edge TPU support. + kTfLiteMaxExternalContexts = 3 +} TfLiteExternalContextType; + +// An external context is a collection of information unrelated to the TF Lite +// framework, but useful to a subset of the ops. TF Lite knows very little +// about about the actual contexts, but it keeps a list of them, and is able to +// refresh them if configurations like the number of recommended threads +// change. +typedef struct { + TfLiteExternalContextType type; + TfLiteStatus (*Refresh)(struct TfLiteContext* context); +} TfLiteExternalContext; + +// Forward declare so GetNode can use this is in Context. +typedef struct _TfLiteRegistration TfLiteRegistration; +typedef struct _TfLiteDelegate TfLiteDelegate; + +#define kOptionalTensor (-1) + +// Fixed size list of integers. Used for dimensions and inputs/outputs tensor +// indices +typedef struct { + int size; +// gcc 6.1+ have a bug where flexible members aren't properly handled +// https://github.com/google/re2/commit/b94b7cd42e9f02673cd748c1ac1d16db4052514c +#if !defined(__clang__) && defined(__GNUC__) && __GNUC__ == 6 && \ + __GNUC_MINOR__ >= 1 + int data[0]; +#else + int data[]; +#endif +} TfLiteIntArray; + +// Given the size (number of elements) in a TfLiteIntArray, calculate its size +// in bytes. +int TfLiteIntArrayGetSizeInBytes(int size); + +// Create a array of a given `size` (uninitialized entries). +// This returns a pointer, that you must free using TfLiteIntArrayFree(). +TfLiteIntArray* TfLiteIntArrayCreate(int size); + +// Check if two tensors are equal. Returns 1 if they are equal, 0 otherwise. +int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b); + +// Create a copy of an array passed as `src`. +// You are expected to free memory with TfLiteIntArrayFree +TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src); + +// Free memory of array `v`. +void TfLiteIntArrayFree(TfLiteIntArray* v); + +// Since we must not depend on any libraries, define a minimal subset of +// error macros while avoiding names that have pre-conceived meanings like +// assert and check. + +// Check whether value is true, and if not return kTfLiteError from +// the current function (and report the error string msg). +#define TF_LITE_ENSURE_MSG(context, value, msg) \ + do { \ + if (!(value)) { \ + (context)->ReportError((context), __FILE__ " " msg); \ + return kTfLiteError; \ + } \ + } while (0) + +// Check whether the value `a` is true, and if not return kTfLiteError from +// the current function, while also reporting the location of the error. +#define TF_LITE_ENSURE(context, a) \ + do { \ + if (!(a)) { \ + (context)->ReportError((context), "%s:%d %s was not true.", __FILE__, \ + __LINE__, #a); \ + return kTfLiteError; \ + } \ + } while (0) + +#define TF_LITE_ENSURE_STATUS(a) \ + do { \ + if ((a) != kTfLiteOk) { \ + return kTfLiteError; \ + } \ + } while (0) + +// Check whether the value `a == b` is true, and if not return kTfLiteError from +// the current function, while also reporting the location of the error. +// `a` and `b` may be evaluated more than once, so no side effects or +// extremely expensive computations should be done. +#define TF_LITE_ENSURE_EQ(context, a, b) \ + do { \ + if ((a) != (b)) { \ + (context)->ReportError((context), "%s:%d %s != %s (%d != %d)", __FILE__, \ + __LINE__, #a, #b, (a), (b)); \ + return kTfLiteError; \ + } \ + } while (0) + +#define TF_LITE_ENSURE_OK(context, status) \ + do { \ + if ((status) != kTfLiteOk) { \ + return status; \ + } \ + } while (0) + +// Single-precision complex data type compatible with the C99 definition. +typedef struct { + float re, im; // real and imaginary parts, respectively. +} TfLiteComplex64; + +// Types supported by tensor +typedef enum { + kTfLiteNoType = 0, + kTfLiteFloat32 = 1, + kTfLiteInt32 = 2, + kTfLiteUInt8 = 3, + kTfLiteInt64 = 4, + kTfLiteString = 5, + kTfLiteBool = 6, + kTfLiteInt16 = 7, + kTfLiteComplex64 = 8, +} TfLiteType; + +// Parameters for asymmetric quantization. Quantized values can be converted +// back to float using: +// real_value = scale * (quantized_value - zero_point); +typedef struct { + float scale; + int32_t zero_point; +} TfLiteQuantizationParams; + +// A union of pointers that points to memory for a given tensor. +typedef union { + int* i32; + int64_t* i64; + float* f; + char* raw; + const char* raw_const; + uint8_t* uint8; + bool* b; + int16_t* i16; + TfLiteComplex64* c64; +} TfLitePtrUnion; + +// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped +// data (or data externally allocated). kTfLiteArenaRw is arena allocated +// data. kTfLiteDynamic is for tensors that are allocated during evaluation. +typedef enum { + kTfLiteMemNone = 0, + kTfLiteMmapRo, + kTfLiteArenaRw, + kTfLiteArenaRwPersistent, + kTfLiteDynamic, +} TfLiteAllocationType; + +// The delegates should use zero or positive integers to represent handles. +// -1 is reserved from unallocated status. +typedef int TfLiteBufferHandle; +const TfLiteBufferHandle kTfLiteNullBufferHandle = -1; + +// An tensor in the interpreter system which is a wrapper around a buffer of +// data including a dimensionality (or NULL if not currently defined). +typedef struct { + // The data type specification for data stored in `data`. This affects + // what member of `data` union should be used. + TfLiteType type; + // A union of data pointers. The appropriate type should be used for a typed + // tensor based on `type`. + TfLitePtrUnion data; + // A pointer to a structure representing the dimensionality interpretation + // that the buffer should have. NOTE: the product of elements of `dims` + // and the element datatype size should be equal to `bytes` below. + TfLiteIntArray* dims; + // Quantization information. + TfLiteQuantizationParams params; + // How memory is mapped + // kTfLiteMmapRo: Memory mapped read only. + // i.e. weights + // kTfLiteArenaRw: Arena allocated read write memory + // (i.e. temporaries, outputs). + TfLiteAllocationType allocation_type; + // The number of bytes required to store the data of this Tensor. I.e. + // (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if + // type is kTfLiteFloat32 and dims = {3, 2} then + // bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24. + size_t bytes; + + // An opaque pointer to a tflite::MMapAllocation + const void* allocation; + + // Null-terminated name of this tensor. + const char* name; + + // The delegate which knows how to handle `buffer_handle`. + // WARNING: This is an experimental interface that is subject to change. + TfLiteDelegate* delegate; + + // An integer buffer handle that can be handled by `delegate`. + // The value is valid only when delegate is not null. + // WARNING: This is an experimental interface that is subject to change. + TfLiteBufferHandle buffer_handle; + + // If the delegate uses its own buffer (e.g. GPU memory), the delegate is + // responsible to set data_is_stale to true. + // `delegate->CopyFromBufferHandle` can be called to copy the data from + // delegate buffer. + // WARNING: This is an // experimental interface that is subject to change. + bool data_is_stale; + + // True if the tensor is a variable. + bool is_variable; +} TfLiteTensor; + +// Free data memory of tensor `t`; +void TfLiteTensorDataFree(TfLiteTensor* t); + +// Free memory of tensor `t`; +void TfLiteTensorFree(TfLiteTensor* t); + +// Set all of a tensor's fields (and free any previously allocated data). +void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims, + TfLiteQuantizationParams quantization, char* buffer, + size_t size, TfLiteAllocationType allocation_type, + const void* allocation, bool is_variable, + TfLiteTensor* tensor); + +// Resize the allocated data of a (dynamic) tensor. Tensors with allocation +// types other than kTfLiteDynamic will be ignored. +void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor); + +// A structure representing an instance of a node. +// This structure only exhibits the inputs, outputs and user defined data, not +// other features like the type. +typedef struct { + // Inputs to this node expressed as indices into the simulator's tensors. + TfLiteIntArray* inputs; + + // Outputs to this node expressed as indices into the simulator's tensors. + TfLiteIntArray* outputs; + + // Temporary tensors uses during the computations. This usually contains no + // tensors, but ops are allowed to change that if they need scratch space of + // any sort. + TfLiteIntArray* temporaries; + + // Opaque data provided by the node implementer through `Registration.init`. + void* user_data; + + // Opaque data provided to the node if the node is a builtin. This is usually + // a structure defined in builtin_op_data.h + void* builtin_data; + + // Custom initial data. This is the opaque data provided in the flatbuffer. + // WARNING: This is an experimental interface that is subject to change. + const void* custom_initial_data; + int custom_initial_data_size; + + // The pointer to the delegate. This is non-null only when the node is + // created by calling `interpreter.ModifyGraphWithDelegate`. + // WARNING: This is an experimental interface that is subject to change. + TfLiteDelegate* delegate; +} TfLiteNode; + +typedef struct TfLiteContext { + // Number of tensors in the context. + size_t tensors_size; + + // The execution plan contains a list of the node indices in execution + // order. execution_plan->size is the current number of nodes. And, + // execution_plan->data[0] is the first node that needs to be run. + // TfLiteDelegates can traverse the current execution plan by iterating + // through each member of this array and using GetNodeAndRegistration() to + // access details about a node. i.e. + // TfLiteIntArray* execution_plan; + // TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &execution_plan)); + // for (int exec_index = 0; exec_index < execution_plan->size; exec_index++) { + // int node_index = execution_plan->data[exec_index]; + // TfLiteNode* node; + // TfLiteRegistration* reg; + // context->GetNodeAndRegistration(context, node_index, &node, ®); + // } + // WARNING: This is an experimental interface that is subject to change. + TfLiteStatus (*GetExecutionPlan)(struct TfLiteContext* context, + TfLiteIntArray** execution_plan); + + // An array of tensors in the interpreter context (of length `tensors_size`) + TfLiteTensor* tensors; + + // opaque full context ptr (an opaque c++ data structure) + void* impl_; + + // Request memory pointer be resized. Updates dimensions on the tensor. + // NOTE: ResizeTensor takes ownership of newSize. + TfLiteStatus (*ResizeTensor)(struct TfLiteContext*, TfLiteTensor* tensor, + TfLiteIntArray* new_size); + // Request that a error be reported with format string msg. + void (*ReportError)(struct TfLiteContext*, const char* msg, ...); + + // Add `tensors_to_add` tensors, preserving pre-existing Tensor entries. If + // non-null, the value pointed to by `first_new_tensor_index` will be set to + // the index of the first new tensor. + TfLiteStatus (*AddTensors)(struct TfLiteContext*, int tensors_to_add, + int* first_new_tensor_index); + + // Get a Tensor node by node_index. + // WARNING: This is an experimental interface that is subject to change. + TfLiteStatus (*GetNodeAndRegistration)(struct TfLiteContext*, int node_index, + TfLiteNode** node, + TfLiteRegistration** registration); + + // Replace ops with one or more stub delegate operations. This function + // does not take ownership of `nodes_to_replace`. + TfLiteStatus (*ReplaceSubgraphsWithDelegateKernels)( + struct TfLiteContext*, TfLiteRegistration registration, + const TfLiteIntArray* nodes_to_replace, TfLiteDelegate* delegate); + + // Number of threads that are recommended to subsystems like gemmlowp and + // eigen. + int recommended_num_threads; + + // Access external contexts by type. + // WARNING: This is an experimental interface that is subject to change. + TfLiteExternalContext* (*GetExternalContext)(struct TfLiteContext*, + TfLiteExternalContextType); + // Set the value of a external context. Does not take ownership of the + // pointer. + // WARNING: This is an experimental interface that is subject to change. + void (*SetExternalContext)(struct TfLiteContext*, TfLiteExternalContextType, + TfLiteExternalContext*); +} TfLiteContext; + +typedef struct _TfLiteRegistration { + // Initializes the op from serialized data. + // If a built-in op: + // `buffer` is the op's params data (TfLiteLSTMParams*). + // `length` is zero. + // If custom op: + // `buffer` is the op's `custom_options`. + // `length` is the size of the buffer. + // + // Returns a type-punned (i.e. void*) opaque data (e.g. a primitive pointer + // or an instance of a struct). + // + // The returned pointer will be stored with the node in the `user_data` field, + // accessible within prepare and invoke functions below. + // NOTE: if the data is already in the desired format, simply implement this + // function to return `nullptr` and implement the free function to be a no-op. + void* (*init)(TfLiteContext* context, const char* buffer, size_t length); + + // The pointer `buffer` is the data previously returned by an init invocation. + void (*free)(TfLiteContext* context, void* buffer); + + // prepare is called when the inputs this node depends on have been resized. + // context->ResizeTensor() can be called to request output tensors to be + // resized. + // + // Returns kTfLiteOk on success. + TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node); + + // Execute the node (should read node->inputs and output to node->outputs). + // Returns kTfLiteOk on success. + TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node); + + // profiling_string is called during summarization of profiling information + // in order to group executions together. Providing a value here will cause a + // given op to appear multiple times is the profiling report. This is + // particularly useful for custom ops that can perform significantly + // different calculations depending on their `user-data`. + const char* (*profiling_string)(const TfLiteContext* context, + const TfLiteNode* node); + + // Builtin codes. If this kernel refers to a builtin this is the code + // of the builtin. This is so we can do marshaling to other frameworks like + // NN API. + // Note: It is the responsibility of the registration binder to set this + // properly. + int32_t builtin_code; + + // Custom op name. If the op is a builtin, this will be null. + // Note: It is the responsibility of the registration binder to set this + // properly. + // WARNING: This is an experimental interface that is subject to change. + const char* custom_name; + + // The version of the op. + // Note: It is the responsibility of the registration binder to set this + // properly. + int version; +} TfLiteRegistration; + +// WARNING: This is an experimental interface that is subject to change. +typedef struct _TfLiteDelegate { + // Data that delegate needs to identify itself. This data is owned by the + // delegate. The delegate is owned in the user code, so the delegate is + // responsible for doing this when it is destroyed. + void* data_; + + // Invoked by ModifyGraphWithDelegate. This prepare is called, giving the + // delegate a view of the current graph through TfLiteContext*. It typically + // will look at the nodes and call ReplaceSubgraphsWithDelegateKernels() + // to ask the TensorFlow lite runtime to create macro-nodes to represent + // delegated subgraphs of the original graph. + TfLiteStatus (*Prepare)(TfLiteContext* context, TfLiteDelegate* delegate); + + // Copy the data from delegate buffer handle to raw memory. + // This can be null if the delegate doesn't use its own buffer. + TfLiteStatus (*CopyFromBufferHandle)(TfLiteContext* context, + TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, + void* data, size_t size); + + // Copy the data from raw memory to delegate buffer handle. + // This can be null if the delegate doesn't use its own buffer. + TfLiteStatus (*CopyToBufferHandle)(TfLiteContext* context, + TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, + void* data, size_t size); + + // Free the Delegate Buffer Handle. Note: This only frees the handle, but + // this doesn't release the underlying resource (e.g. textures). The + // resources are either owned by application layer or the delegate. + // This can be null if the delegate doesn't use its own buffer. + void (*FreeBufferHandle)(TfLiteContext* context, TfLiteDelegate* delegate, + TfLiteBufferHandle* handle); +} TfLiteDelegate; + +// WARNING: This is an experimental interface that is subject to change. +// +// Currently, TfLiteDelegateParams has to be allocated in a way that it's +// trivially destructable. It will be stored as `builtin_data` field in +// `TfLiteNode` of the delegate node. +// +// See also the `CreateDelegateParams` function in `interpreter.cc` details. +typedef struct { + TfLiteDelegate* delegate; + TfLiteIntArray* nodes_to_replace; + TfLiteIntArray* input_tensors; + TfLiteIntArray* output_tensors; +} TfLiteDelegateParams; + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // TENSORFLOW_CONTRIB_LITE_C_C_API_INTERNAL_H_ diff --git a/tensorflow/contrib/lite/c/c_api_internal_test.cc b/tensorflow/contrib/lite/c/c_api_internal_test.cc new file mode 100644 index 0000000000..af398f3207 --- /dev/null +++ b/tensorflow/contrib/lite/c/c_api_internal_test.cc @@ -0,0 +1,73 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include + +namespace tflite { + +// NOTE: this tests only the TfLiteIntArray part of context. +// most of c_api_internal.h is provided in the context of using it with +// interpreter.h and interpreter.cc, so interpreter_test.cc tests context +// structures more thoroughly. + +TEST(IntArray, TestIntArrayCreate) { + TfLiteIntArray* a = TfLiteIntArrayCreate(0); + TfLiteIntArray* b = TfLiteIntArrayCreate(3); + TfLiteIntArrayFree(a); + TfLiteIntArrayFree(b); +} + +TEST(IntArray, TestIntArrayCopy) { + TfLiteIntArray* a = TfLiteIntArrayCreate(2); + a->data[0] = 22; + a->data[1] = 24; + TfLiteIntArray* b = TfLiteIntArrayCopy(a); + ASSERT_NE(a, b); + ASSERT_EQ(a->size, b->size); + ASSERT_EQ(a->data[0], b->data[0]); + ASSERT_EQ(a->data[1], b->data[1]); + TfLiteIntArrayFree(a); + TfLiteIntArrayFree(b); +} + +TEST(IntArray, TestIntArrayEqual) { + TfLiteIntArray* a = TfLiteIntArrayCreate(1); + a->data[0] = 1; + TfLiteIntArray* b = TfLiteIntArrayCreate(2); + b->data[0] = 5; + b->data[1] = 6; + TfLiteIntArray* c = TfLiteIntArrayCreate(2); + c->data[0] = 5; + c->data[1] = 6; + TfLiteIntArray* d = TfLiteIntArrayCreate(2); + d->data[0] = 6; + d->data[1] = 6; + ASSERT_FALSE(TfLiteIntArrayEqual(a, b)); + ASSERT_TRUE(TfLiteIntArrayEqual(b, c)); + ASSERT_TRUE(TfLiteIntArrayEqual(b, b)); + ASSERT_FALSE(TfLiteIntArrayEqual(c, d)); + TfLiteIntArrayFree(a); + TfLiteIntArrayFree(b); + TfLiteIntArrayFree(c); + TfLiteIntArrayFree(d); +} + +} // namespace tflite + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} -- cgit v1.2.3