path: root/tensorflow/contrib/lite/c
author    Pete Warden <petewarden@google.com>  2018-09-07 17:36:59 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-09-07 17:40:10 -0700
commit    9982fd6c8831cbd2f58954f79ea71f26660393bc
tree      108907bde953d0d70ee5d3b8323a99bb9b681563  /tensorflow/contrib/lite/c
parent    edda5e39e4e93ba60e4d31b6ecb1c295dead29c8
Modularize TF Lite interface definitions and reorganize file structure
PiperOrigin-RevId: 212064501
Diffstat (limited to 'tensorflow/contrib/lite/c')
 tensorflow/contrib/lite/c/BUILD                   |  39 +
 tensorflow/contrib/lite/c/builtin_op_data.h       | 298 +
 tensorflow/contrib/lite/c/builtin_op_data_test.cc |  83 +
 tensorflow/contrib/lite/c/c_api_internal.c        | 104 +
 tensorflow/contrib/lite/c/c_api_internal.h        | 491 +
 tensorflow/contrib/lite/c/c_api_internal_test.cc  |  73 +
 6 files changed, 1088 insertions(+), 0 deletions(-)
diff --git a/tensorflow/contrib/lite/c/BUILD b/tensorflow/contrib/lite/c/BUILD
new file mode 100644
index 0000000000..663eb63cad
--- /dev/null
+++ b/tensorflow/contrib/lite/c/BUILD
@@ -0,0 +1,39 @@
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+cc_library(
+ name = "c_api_internal",
+ srcs = ["c_api_internal.c"],
+ hdrs = [
+ "builtin_op_data.h",
+ "c_api_internal.h",
+ ],
+ visibility = [
+ "//tensorflow/contrib/lite:__subpackages__",
+ ],
+)
+
+# Test the C extension API code.
+cc_test(
+ name = "c_api_internal_test",
+ size = "small",
+ srcs = ["c_api_internal_test.cc"],
+ deps = [
+ ":c_api_internal",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_test(
+ name = "builtin_op_data_test",
+ size = "small",
+ srcs = ["builtin_op_data_test.cc"],
+ copts = ["-Wno-unused-variable"],
+ deps = [
+ ":c_api_internal",
+ "@com_google_googletest//:gtest",
+ ],
+)
diff --git a/tensorflow/contrib/lite/c/builtin_op_data.h b/tensorflow/contrib/lite/c/builtin_op_data.h
new file mode 100644
index 0000000000..fa43e6a024
--- /dev/null
+++ b/tensorflow/contrib/lite/c/builtin_op_data.h
@@ -0,0 +1,298 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_C_BUILTIN_OP_DATA_H_
+#define TENSORFLOW_CONTRIB_LITE_C_BUILTIN_OP_DATA_H_
+
+#include <stdint.h>
+
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// TODO(aselle): Consider using "if this then that" for testing.
+
+// Possible padding types (for convolutions)
+typedef enum {
+ kTfLitePaddingUnknown = 0,
+ kTfLitePaddingSame,
+ kTfLitePaddingValid,
+} TfLitePadding;
+
+typedef struct {
+ int width;
+ int height;
+} TfLitePaddingValues;
+
+// Possible fused activation functions.
+// TODO(aselle): rename to TfLiteActivation
+typedef enum {
+ kTfLiteActNone = 0,
+ kTfLiteActRelu,
+ kTfLiteActRelu1,
+ kTfLiteActRelu6,
+ kTfLiteActTanh,
+ kTfLiteActSignBit,
+ kTfLiteActSigmoid,
+} TfLiteFusedActivation;
+
+typedef struct {
+ TfLitePadding padding;
+ int stride_width;
+ int stride_height;
+ int dilation_width_factor;
+ int dilation_height_factor;
+ TfLiteFusedActivation activation;
+} TfLiteConvParams;
+
+typedef struct {
+ TfLitePadding padding;
+ int stride_width;
+ int stride_height;
+ int filter_width;
+ int filter_height;
+ TfLiteFusedActivation activation;
+ struct {
+ TfLitePaddingValues padding;
+ } computed;
+} TfLitePoolParams;
+
+typedef struct {
+ TfLitePadding padding;
+ int stride_width;
+ int stride_height;
+ int depth_multiplier;
+ TfLiteFusedActivation activation;
+} TfLiteDepthwiseConvParams;
+
+typedef struct {
+ int rank;
+ TfLiteFusedActivation activation;
+} TfLiteSVDFParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteRNNParams;
+
+typedef struct {
+ bool time_major;
+ TfLiteFusedActivation activation;
+} TfLiteSequenceRNNParams;
+
+typedef enum {
+ kTfLiteFullyConnectedWeightsFormatDefault = 0,
+ kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8 = 1,
+} TfLiteFullyConnectedWeightsFormat;
+
+typedef struct {
+ // Parameters for FullyConnected version 1 or above.
+ TfLiteFusedActivation activation;
+
+ // Parameters for FullyConnected version 2 or above.
+ TfLiteFullyConnectedWeightsFormat weights_format;
+} TfLiteFullyConnectedParams;
+
+typedef enum {
+ kTfLiteLshProjectionUnknown = 0,
+ kTfLiteLshProjectionSparse = 1,
+ kTfLiteLshProjectionDense = 2,
+} TfLiteLSHProjectionType;
+
+typedef struct {
+ TfLiteLSHProjectionType type;
+} TfLiteLSHProjectionParams;
+
+typedef struct {
+ float beta;
+} TfLiteSoftmaxParams;
+
+typedef struct {
+ int axis;
+ TfLiteFusedActivation activation;
+} TfLiteConcatenationParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteAddParams;
+
+typedef struct {
+} TfLiteSpaceToBatchNDParams;
+
+typedef struct {
+} TfLiteBatchToSpaceNDParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteMulParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteSubParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteDivParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteL2NormParams;
+
+typedef struct {
+ int radius;
+ float bias;
+ float alpha;
+ float beta;
+} TfLiteLocalResponseNormParams;
+
+typedef enum {
+ kTfLiteLSTMFullKernel = 0,
+ kTfLiteLSTMBasicKernel
+} TfLiteLSTMKernelType;
+
+typedef struct {
+ // Parameters for LSTM version 1.
+ TfLiteFusedActivation activation;
+ float cell_clip;
+ float proj_clip;
+
+ // Parameters for LSTM version 2.
+ // kTfLiteLSTMBasicKernel is only supported in version 2 or above.
+ TfLiteLSTMKernelType kernel_type;
+} TfLiteLSTMParams;
+
+typedef struct {
+ bool align_corners;
+} TfLiteResizeBilinearParams;
+
+typedef struct {
+} TfLitePadParams;
+
+typedef struct {
+} TfLitePadV2Params;
+
+typedef struct {
+ // TODO(ahentz): We can't have dynamic data in this struct, at least not yet.
+ // For now we will fix the maximum possible number of dimensions.
+ int shape[8];
+ int num_dimensions;
+} TfLiteReshapeParams;
+
+typedef struct {
+ int ngram_size;
+ int max_skip_size;
+ bool include_all_ngrams;
+} TfLiteSkipGramParams;
+
+typedef struct {
+ int block_size;
+} TfLiteSpaceToDepthParams;
+
+typedef struct {
+ TfLiteType in_data_type;
+ TfLiteType out_data_type;
+} TfLiteCastParams;
+
+typedef enum {
+ kTfLiteCombinerTypeSum = 0,
+ kTfLiteCombinerTypeMean = 1,
+ kTfLiteCombinerTypeSqrtn = 2,
+} TfLiteCombinerType;
+
+typedef struct {
+ TfLiteCombinerType combiner;
+} TfLiteEmbeddingLookupSparseParams;
+
+typedef struct {
+ int axis;
+} TfLiteGatherParams;
+
+typedef struct {
+} TfLiteTransposeParams;
+
+typedef struct {
+ bool keep_dims;
+} TfLiteReducerParams;
+
+typedef struct {
+ int num_splits;
+} TfLiteSplitParams;
+
+typedef struct {
+ // TODO(ahentz): We can't have dynamic data in this struct, at least not yet.
+ // For now we will fix the maximum possible number of dimensions.
+ int squeeze_dims[8];
+ int num_squeeze_dims;
+} TfLiteSqueezeParams;
+
+typedef struct {
+ int begin_mask;
+ int end_mask;
+ int ellipsis_mask;
+ int new_axis_mask;
+ int shrink_axis_mask;
+} TfLiteStridedSliceParams;
+
+typedef struct {
+ TfLiteType output_type;
+} TfLiteArgMaxParams;
+
+typedef struct {
+ TfLiteType output_type;
+} TfLiteArgMinParams;
+
+typedef struct {
+ TfLitePadding padding;
+ int stride_width;
+ int stride_height;
+} TfLiteTransposeConvParams;
+
+typedef struct {
+ bool validate_indices;
+} TfLiteSparseToDenseParams;
+
+typedef struct {
+ TfLiteType out_type;
+} TfLiteShapeParams;
+
+typedef struct {
+ // Parameters supported by version 1:
+ float min;
+ float max;
+ int num_bits;
+
+ // Parameters supported by version 2:
+ bool narrow_range;
+} TfLiteFakeQuantParams;
+
+typedef struct {
+ int values_count;
+ int axis;
+} TfLitePackParams;
+
+typedef struct {
+ int axis;
+} TfLiteOneHotParams;
+
+typedef struct {
+ int num;
+ int axis;
+} TfLiteUnpackParams;
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // TENSORFLOW_CONTRIB_LITE_C_BUILTIN_OP_DATA_H_
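The structs above reach a kernel through the `builtin_data` field of `TfLiteNode`, which is declared in c_api_internal.h further below. As a rough orientation only (not part of this commit), a hypothetical convolution kernel's prepare step might read its parameters like this; `ConvPrepare` and the specific checks are illustrative assumptions, and a real kernel would expose such a function through its `TfLiteRegistration` (see the sketch after the c_api_internal.h diff):

    #include "tensorflow/contrib/lite/c/builtin_op_data.h"

    // Hypothetical example only: read the builtin params attached to a node.
    static TfLiteStatus ConvPrepare(TfLiteContext* context, TfLiteNode* node) {
      const TfLiteConvParams* params =
          (const TfLiteConvParams*)(node->builtin_data);
      // Reject configurations this sketch does not handle.
      TF_LITE_ENSURE(context, params->stride_width > 0);
      TF_LITE_ENSURE(context, params->stride_height > 0);
      TF_LITE_ENSURE_EQ(context, params->dilation_width_factor, 1);
      return kTfLiteOk;
    }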
diff --git a/tensorflow/contrib/lite/c/builtin_op_data_test.cc b/tensorflow/contrib/lite/c/builtin_op_data_test.cc
new file mode 100644
index 0000000000..4d0ba75e68
--- /dev/null
+++ b/tensorflow/contrib/lite/c/builtin_op_data_test.cc
@@ -0,0 +1,83 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include <gtest/gtest.h>
+
+namespace tflite {
+
+// Builtin op data is just a set of data definitions, so the only meaningful
+// test we can run is whether we can create the structs we expect to find.
+// Testing each struct's members might be possible, but it seems unnecessary
+// until we've locked down the API. The build rule has copts set to ignore the
+// unused variable warning, since this is just a compilation test.
+TEST(IntArray, CanCompileStructs) {
+ TfLitePadding padding = kTfLitePaddingSame;
+ TfLitePaddingValues padding_values;
+ TfLiteFusedActivation fused_activation = kTfLiteActRelu;
+ TfLiteConvParams conv_params;
+ TfLitePoolParams pool_params;
+ TfLiteDepthwiseConvParams depthwise_conv_params;
+ TfLiteSVDFParams svdf_params;
+ TfLiteRNNParams rnn_params;
+ TfLiteSequenceRNNParams sequence_rnn_params;
+ TfLiteFullyConnectedWeightsFormat fully_connected_weights_format =
+ kTfLiteFullyConnectedWeightsFormatDefault;
+ TfLiteFullyConnectedParams fully_connected_params;
+ TfLiteLSHProjectionType projection_type = kTfLiteLshProjectionDense;
+ TfLiteLSHProjectionParams projection_params;
+ TfLiteSoftmaxParams softmax_params;
+ TfLiteConcatenationParams concatenation_params;
+ TfLiteAddParams add_params;
+ TfLiteSpaceToBatchNDParams space_to_batch_nd_params;
+ TfLiteBatchToSpaceNDParams batch_to_space_nd_params;
+ TfLiteMulParams mul_params;
+ TfLiteSubParams sub_params;
+ TfLiteDivParams div_params;
+ TfLiteL2NormParams l2_norm_params;
+ TfLiteLocalResponseNormParams local_response_norm_params;
+ TfLiteLSTMKernelType lstm_kernel_type = kTfLiteLSTMBasicKernel;
+ TfLiteLSTMParams lstm_params;
+ TfLiteResizeBilinearParams resize_bilinear_params;
+ TfLitePadParams pad_params;
+ TfLitePadV2Params pad_v2_params;
+ TfLiteReshapeParams reshape_params;
+ TfLiteSkipGramParams skip_gram_params;
+ TfLiteSpaceToDepthParams space_to_depth_params;
+ TfLiteCastParams cast_params;
+ TfLiteCombinerType combiner_type = kTfLiteCombinerTypeSqrtn;
+ TfLiteEmbeddingLookupSparseParams lookup_sparse_params;
+ TfLiteGatherParams gather_params;
+ TfLiteTransposeParams transpose_params;
+ TfLiteReducerParams reducer_params;
+ TfLiteSplitParams split_params;
+ TfLiteSqueezeParams squeeze_params;
+ TfLiteStridedSliceParams strided_slice_params;
+ TfLiteArgMaxParams arg_max_params;
+ TfLiteArgMinParams arg_min_params;
+ TfLiteTransposeConvParams transpose_conv_params;
+ TfLiteSparseToDenseParams sparse_to_dense_params;
+ TfLiteShapeParams shape_params;
+ TfLiteFakeQuantParams fake_quant_params;
+ TfLitePackParams pack_params;
+ TfLiteOneHotParams one_hot_params;
+}
+
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/c/c_api_internal.c b/tensorflow/contrib/lite/c/c_api_internal.c
new file mode 100644
index 0000000000..1846bad4b7
--- /dev/null
+++ b/tensorflow/contrib/lite/c/c_api_internal.c
@@ -0,0 +1,104 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+int TfLiteIntArrayGetSizeInBytes(int size) {
+ static TfLiteIntArray dummy;
+ return sizeof(dummy) + sizeof(dummy.data[0]) * size;
+}
+
+TfLiteIntArray* TfLiteIntArrayCreate(int size) {
+ TfLiteIntArray* ret =
+ (TfLiteIntArray*)malloc(TfLiteIntArrayGetSizeInBytes(size));
+ ret->size = size;
+ return ret;
+}
+
+void TfLiteIntArrayPrint(const char* s, TfLiteIntArray* a) {
+ printf("%s: length=%d [", s, a->size);
+ if (a->size) printf("%d", a->data[0]);
+ int i = 1;
+ for (; i < a->size; i++) {
+ printf(" %d", a->data[i]);
+ }
+ printf("]\n");
+}
+
+int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b) {
+ if (a == b) return 1;
+ if (a == NULL || b == NULL) return 0;
+ if (a->size != b->size) return 0;
+ int i = 0;
+ for (; i < a->size; i++)
+ if (a->data[i] != b->data[i]) return 0;
+ return 1;
+}
+
+TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src) {
+ if (!src) return NULL;
+ TfLiteIntArray* ret = TfLiteIntArrayCreate(src->size);
+ if (ret) {
+ memcpy(ret->data, src->data, src->size * sizeof(int));
+ }
+ return ret;
+}
+
+void TfLiteIntArrayFree(TfLiteIntArray* a) { free(a); }
+
+void TfLiteTensorDataFree(TfLiteTensor* t) {
+ if (t->allocation_type == kTfLiteDynamic && t->data.raw) {
+ free(t->data.raw);
+ }
+ t->data.raw = NULL;
+}
+
+void TfLiteTensorFree(TfLiteTensor* t) {
+ TfLiteTensorDataFree(t);
+ if (t->dims) TfLiteIntArrayFree(t->dims);
+ t->dims = NULL;
+}
+
+void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
+ TfLiteQuantizationParams quantization, char* buffer,
+ size_t size, TfLiteAllocationType allocation_type,
+ const void* allocation, bool is_variable,
+ TfLiteTensor* tensor) {
+ TfLiteTensorFree(tensor);
+ tensor->type = type;
+ tensor->name = name;
+ tensor->dims = dims;
+ tensor->params = quantization;
+ tensor->data.raw = buffer;
+ tensor->bytes = size;
+ tensor->allocation_type = allocation_type;
+ tensor->allocation = allocation;
+ tensor->is_variable = is_variable;
+}
+
+void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor) {
+ if (tensor->allocation_type != kTfLiteDynamic) {
+ return;
+ }
+ if (!tensor->data.raw) {
+ tensor->data.raw = malloc(num_bytes);
+ } else if (num_bytes > tensor->bytes) {
+ tensor->data.raw = realloc(tensor->data.raw, num_bytes);
+ }
+ tensor->bytes = num_bytes;
+}
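For orientation (again, not part of the commit), a minimal sketch of how these helpers combine, assuming the declarations from c_api_internal.h shown in the next file; the function name `IntArrayAndTensorExample` is made up:

    #include <string.h>
    #include "tensorflow/contrib/lite/c/c_api_internal.h"

    // Illustrative only: exercise the int-array and dynamic-tensor helpers.
    static void IntArrayAndTensorExample(void) {
      TfLiteIntArray* dims = TfLiteIntArrayCreate(2);   // entries uninitialized
      dims->data[0] = 3;
      dims->data[1] = 2;
      TfLiteIntArray* copy = TfLiteIntArrayCopy(dims);  // deep copy of the data
      int equal = TfLiteIntArrayEqual(dims, copy);      // equal contents -> 1
      (void)equal;

      TfLiteTensor t;
      memset(&t, 0, sizeof(t));
      t.allocation_type = kTfLiteDynamic;          // only dynamic tensors realloc
      TfLiteTensorRealloc(3 * 2 * sizeof(float), &t);  // room for 3*2 floats
      TfLiteTensorDataFree(&t);                    // frees t.data.raw

      TfLiteIntArrayFree(copy);
      TfLiteIntArrayFree(dims);
    }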
diff --git a/tensorflow/contrib/lite/c/c_api_internal.h b/tensorflow/contrib/lite/c/c_api_internal.h
new file mode 100644
index 0000000000..48df68a654
--- /dev/null
+++ b/tensorflow/contrib/lite/c/c_api_internal.h
@@ -0,0 +1,491 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// This file defines a C API for implementing operations in tflite.
+// These operations can be defined using C++, but the interface between
+// the interpreter and the operations is C.
+//
+// Summary of abstractions
+// TF_LITE_ENSURE - Self-sufficient error checking
+// TfLiteStatus - Status reporting
+// TfLiteIntArray - stores tensor shapes (dims)
+// TfLiteContext - allows an op to access the tensors
+// TfLiteTensor - tensor (a multidimensional array)
+// TfLiteNode - a single node or operation
+// TfLiteRegistration - the implementation of a conceptual operation.
+//
+// Some abstractions in this file are created and managed by Interpreter.
+#ifndef TENSORFLOW_CONTRIB_LITE_C_C_API_INTERNAL_H_
+#define TENSORFLOW_CONTRIB_LITE_C_C_API_INTERNAL_H_
+
+#include <stdbool.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus;
+
+// The list of external context types known to TF Lite. This list exists solely
+// to avoid conflicts and to ensure ops can share the external contexts they
+// need. Access to the external contexts is controlled by one of the
+// corresponding support files.
+typedef enum {
+ kTfLiteEigenContext = 0, // include eigen_support.h to use.
+ kTfLiteGemmLowpContext = 1, // include gemm_support.h to use.
+ kTfLiteEdgeTpuContext = 2, // Placeholder for Edge TPU support.
+ kTfLiteMaxExternalContexts = 3
+} TfLiteExternalContextType;
+
+// An external context is a collection of information unrelated to the TF Lite
+// framework, but useful to a subset of the ops. TF Lite knows very little
+// about the actual contexts, but it keeps a list of them, and is able to
+// refresh them if configurations like the number of recommended threads
+// change.
+typedef struct {
+ TfLiteExternalContextType type;
+ TfLiteStatus (*Refresh)(struct TfLiteContext* context);
+} TfLiteExternalContext;
+
+// Forward declare so GetNode can use this in Context.
+typedef struct _TfLiteRegistration TfLiteRegistration;
+typedef struct _TfLiteDelegate TfLiteDelegate;
+
+#define kOptionalTensor (-1)
+
+// Fixed size list of integers. Used for dimensions and inputs/outputs tensor
+// indices
+typedef struct {
+ int size;
+// gcc 6.1+ have a bug where flexible members aren't properly handled
+// https://github.com/google/re2/commit/b94b7cd42e9f02673cd748c1ac1d16db4052514c
+#if !defined(__clang__) && defined(__GNUC__) && __GNUC__ == 6 && \
+ __GNUC_MINOR__ >= 1
+ int data[0];
+#else
+ int data[];
+#endif
+} TfLiteIntArray;
+
+// Given the size (number of elements) in a TfLiteIntArray, calculate its size
+// in bytes.
+int TfLiteIntArrayGetSizeInBytes(int size);
+
+// Create an array of a given `size` (uninitialized entries).
+// This returns a pointer that you must free using TfLiteIntArrayFree().
+TfLiteIntArray* TfLiteIntArrayCreate(int size);
+
+// Check if two arrays are equal. Returns 1 if they are equal, 0 otherwise.
+int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b);
+
+// Create a copy of an array passed as `src`.
+// You are expected to free the memory with TfLiteIntArrayFree().
+TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src);
+
+// Free memory of array `v`.
+void TfLiteIntArrayFree(TfLiteIntArray* v);
+
+// Since we must not depend on any libraries, define a minimal subset of
+// error macros while avoiding names that have pre-conceived meanings like
+// assert and check.
+
+// Check whether value is true, and if not return kTfLiteError from
+// the current function (and report the error string msg).
+#define TF_LITE_ENSURE_MSG(context, value, msg) \
+ do { \
+ if (!(value)) { \
+ (context)->ReportError((context), __FILE__ " " msg); \
+ return kTfLiteError; \
+ } \
+ } while (0)
+
+// Check whether the value `a` is true, and if not return kTfLiteError from
+// the current function, while also reporting the location of the error.
+#define TF_LITE_ENSURE(context, a) \
+ do { \
+ if (!(a)) { \
+ (context)->ReportError((context), "%s:%d %s was not true.", __FILE__, \
+ __LINE__, #a); \
+ return kTfLiteError; \
+ } \
+ } while (0)
+
+#define TF_LITE_ENSURE_STATUS(a) \
+ do { \
+ if ((a) != kTfLiteOk) { \
+ return kTfLiteError; \
+ } \
+ } while (0)
+
+// Check whether the value `a == b` is true, and if not return kTfLiteError from
+// the current function, while also reporting the location of the error.
+// `a` and `b` may be evaluated more than once, so no side effects or
+// extremely expensive computations should be done.
+#define TF_LITE_ENSURE_EQ(context, a, b) \
+ do { \
+ if ((a) != (b)) { \
+ (context)->ReportError((context), "%s:%d %s != %s (%d != %d)", __FILE__, \
+ __LINE__, #a, #b, (a), (b)); \
+ return kTfLiteError; \
+ } \
+ } while (0)
+
+#define TF_LITE_ENSURE_OK(context, status) \
+ do { \
+ if ((status) != kTfLiteOk) { \
+ return status; \
+ } \
+ } while (0)
+
+// Single-precision complex data type compatible with the C99 definition.
+typedef struct {
+ float re, im; // real and imaginary parts, respectively.
+} TfLiteComplex64;
+
+// Types supported by tensor
+typedef enum {
+ kTfLiteNoType = 0,
+ kTfLiteFloat32 = 1,
+ kTfLiteInt32 = 2,
+ kTfLiteUInt8 = 3,
+ kTfLiteInt64 = 4,
+ kTfLiteString = 5,
+ kTfLiteBool = 6,
+ kTfLiteInt16 = 7,
+ kTfLiteComplex64 = 8,
+} TfLiteType;
+
+// Parameters for asymmetric quantization. Quantized values can be converted
+// back to float using:
+// real_value = scale * (quantized_value - zero_point);
+typedef struct {
+ float scale;
+ int32_t zero_point;
+} TfLiteQuantizationParams;
+
+// A union of pointers that points to memory for a given tensor.
+typedef union {
+ int* i32;
+ int64_t* i64;
+ float* f;
+ char* raw;
+ const char* raw_const;
+ uint8_t* uint8;
+ bool* b;
+ int16_t* i16;
+ TfLiteComplex64* c64;
+} TfLitePtrUnion;
+
+// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped
+// data (or data externally allocated). kTfLiteArenaRw is arena allocated
+// data. kTfLiteDynamic is for tensors that are allocated during evaluation.
+typedef enum {
+ kTfLiteMemNone = 0,
+ kTfLiteMmapRo,
+ kTfLiteArenaRw,
+ kTfLiteArenaRwPersistent,
+ kTfLiteDynamic,
+} TfLiteAllocationType;
+
+// The delegates should use zero or positive integers to represent handles.
+// -1 is reserved for unallocated status.
+typedef int TfLiteBufferHandle;
+const TfLiteBufferHandle kTfLiteNullBufferHandle = -1;
+
+// A tensor in the interpreter system which is a wrapper around a buffer of
+// data including a dimensionality (or NULL if not currently defined).
+typedef struct {
+ // The data type specification for data stored in `data`. This affects
+ // what member of `data` union should be used.
+ TfLiteType type;
+ // A union of data pointers. The appropriate type should be used for a typed
+ // tensor based on `type`.
+ TfLitePtrUnion data;
+ // A pointer to a structure representing the dimensionality interpretation
+ // that the buffer should have. NOTE: the product of elements of `dims`
+ // and the element datatype size should be equal to `bytes` below.
+ TfLiteIntArray* dims;
+ // Quantization information.
+ TfLiteQuantizationParams params;
+ // How memory is mapped
+ // kTfLiteMmapRo: Memory mapped read only.
+ // i.e. weights
+ // kTfLiteArenaRw: Arena allocated read write memory
+ // (i.e. temporaries, outputs).
+ TfLiteAllocationType allocation_type;
+ // The number of bytes required to store the data of this Tensor. I.e.
+ // (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if
+ // type is kTfLiteFloat32 and dims = {3, 2} then
+ // bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24.
+ size_t bytes;
+
+ // An opaque pointer to a tflite::MMapAllocation
+ const void* allocation;
+
+ // Null-terminated name of this tensor.
+ const char* name;
+
+ // The delegate which knows how to handle `buffer_handle`.
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteDelegate* delegate;
+
+ // An integer buffer handle that can be handled by `delegate`.
+ // The value is valid only when delegate is not null.
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteBufferHandle buffer_handle;
+
+ // If the delegate uses its own buffer (e.g. GPU memory), the delegate is
+ // responsible for setting data_is_stale to true.
+ // `delegate->CopyFromBufferHandle` can be called to copy the data from
+ // delegate buffer.
+ // WARNING: This is an experimental interface that is subject to change.
+ bool data_is_stale;
+
+ // True if the tensor is a variable.
+ bool is_variable;
+} TfLiteTensor;
+
+// Free data memory of tensor `t`;
+void TfLiteTensorDataFree(TfLiteTensor* t);
+
+// Free memory of tensor `t`;
+void TfLiteTensorFree(TfLiteTensor* t);
+
+// Set all of a tensor's fields (and free any previously allocated data).
+void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
+ TfLiteQuantizationParams quantization, char* buffer,
+ size_t size, TfLiteAllocationType allocation_type,
+ const void* allocation, bool is_variable,
+ TfLiteTensor* tensor);
+
+// Resize the allocated data of a (dynamic) tensor. Tensors with allocation
+// types other than kTfLiteDynamic will be ignored.
+void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor);
+
+// A structure representing an instance of a node.
+// This structure only exposes the inputs, outputs, and user-defined data, not
+// other features like the type.
+typedef struct {
+ // Inputs to this node expressed as indices into the simulator's tensors.
+ TfLiteIntArray* inputs;
+
+ // Outputs to this node expressed as indices into the simulator's tensors.
+ TfLiteIntArray* outputs;
+
+ // Temporary tensors used during the computation. This usually contains no
+ // tensors, but ops are allowed to change that if they need scratch space of
+ // any sort.
+ TfLiteIntArray* temporaries;
+
+ // Opaque data provided by the node implementer through `Registration.init`.
+ void* user_data;
+
+ // Opaque data provided to the node if the node is a builtin. This is usually
+ // a structure defined in builtin_op_data.h
+ void* builtin_data;
+
+ // Custom initial data. This is the opaque data provided in the flatbuffer.
+ // WARNING: This is an experimental interface that is subject to change.
+ const void* custom_initial_data;
+ int custom_initial_data_size;
+
+ // The pointer to the delegate. This is non-null only when the node is
+ // created by calling `interpreter.ModifyGraphWithDelegate`.
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteDelegate* delegate;
+} TfLiteNode;
+
+typedef struct TfLiteContext {
+ // Number of tensors in the context.
+ size_t tensors_size;
+
+ // The execution plan contains a list of the node indices in execution
+ // order. execution_plan->size is the current number of nodes. And,
+ // execution_plan->data[0] is the first node that needs to be run.
+ // TfLiteDelegates can traverse the current execution plan by iterating
+ // through each member of this array and using GetNodeAndRegistration() to
+ // access details about a node. i.e.
+ // TfLiteIntArray* execution_plan;
+ // TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &execution_plan));
+ // for (int exec_index = 0; exec_index < execution_plan->size; exec_index++) {
+ // int node_index = execution_plan->data[exec_index];
+ // TfLiteNode* node;
+ // TfLiteRegistration* reg;
+ // context->GetNodeAndRegistration(context, node_index, &node, &reg);
+ // }
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteStatus (*GetExecutionPlan)(struct TfLiteContext* context,
+ TfLiteIntArray** execution_plan);
+
+ // An array of tensors in the interpreter context (of length `tensors_size`)
+ TfLiteTensor* tensors;
+
+ // Opaque full context pointer (an opaque C++ data structure).
+ void* impl_;
+
+ // Request that the tensor's memory be resized. Updates dimensions on the
+ // tensor. NOTE: ResizeTensor takes ownership of `new_size`.
+ TfLiteStatus (*ResizeTensor)(struct TfLiteContext*, TfLiteTensor* tensor,
+ TfLiteIntArray* new_size);
+ // Request that an error be reported with format string msg.
+ void (*ReportError)(struct TfLiteContext*, const char* msg, ...);
+
+ // Add `tensors_to_add` tensors, preserving pre-existing Tensor entries. If
+ // non-null, the value pointed to by `first_new_tensor_index` will be set to
+ // the index of the first new tensor.
+ TfLiteStatus (*AddTensors)(struct TfLiteContext*, int tensors_to_add,
+ int* first_new_tensor_index);
+
+ // Get a node and its registration by node_index.
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteStatus (*GetNodeAndRegistration)(struct TfLiteContext*, int node_index,
+ TfLiteNode** node,
+ TfLiteRegistration** registration);
+
+ // Replace ops with one or more stub delegate operations. This function
+ // does not take ownership of `nodes_to_replace`.
+ TfLiteStatus (*ReplaceSubgraphsWithDelegateKernels)(
+ struct TfLiteContext*, TfLiteRegistration registration,
+ const TfLiteIntArray* nodes_to_replace, TfLiteDelegate* delegate);
+
+ // Number of threads that are recommended to subsystems like gemmlowp and
+ // eigen.
+ int recommended_num_threads;
+
+ // Access external contexts by type.
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteExternalContext* (*GetExternalContext)(struct TfLiteContext*,
+ TfLiteExternalContextType);
+ // Set the value of an external context. Does not take ownership of the
+ // pointer.
+ // WARNING: This is an experimental interface that is subject to change.
+ void (*SetExternalContext)(struct TfLiteContext*, TfLiteExternalContextType,
+ TfLiteExternalContext*);
+} TfLiteContext;
+
+typedef struct _TfLiteRegistration {
+ // Initializes the op from serialized data.
+ // If a built-in op:
+ // `buffer` is the op's params data (TfLiteLSTMParams*).
+ // `length` is zero.
+ // If a custom op:
+ // `buffer` is the op's `custom_options`.
+ // `length` is the size of the buffer.
+ //
+ // Returns a type-punned (i.e. void*) opaque data (e.g. a primitive pointer
+ // or an instance of a struct).
+ //
+ // The returned pointer will be stored with the node in the `user_data` field,
+ // accessible within prepare and invoke functions below.
+ // NOTE: if the data is already in the desired format, simply implement this
+ // function to return `nullptr` and implement the free function to be a no-op.
+ void* (*init)(TfLiteContext* context, const char* buffer, size_t length);
+
+ // The pointer `buffer` is the data previously returned by an init invocation.
+ void (*free)(TfLiteContext* context, void* buffer);
+
+ // prepare is called when the inputs this node depends on have been resized.
+ // context->ResizeTensor() can be called to request output tensors to be
+ // resized.
+ //
+ // Returns kTfLiteOk on success.
+ TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node);
+
+ // Execute the node (should read node->inputs and output to node->outputs).
+ // Returns kTfLiteOk on success.
+ TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node);
+
+ // profiling_string is called during summarization of profiling information
+ // in order to group executions together. Providing a value here will cause a
+ // given op to appear multiple times in the profiling report. This is
+ // particularly useful for custom ops that can perform significantly
+ // different calculations depending on their `user_data`.
+ const char* (*profiling_string)(const TfLiteContext* context,
+ const TfLiteNode* node);
+
+ // Builtin codes. If this kernel refers to a builtin this is the code
+ // of the builtin. This is so we can do marshaling to other frameworks like
+ // NN API.
+ // Note: It is the responsibility of the registration binder to set this
+ // properly.
+ int32_t builtin_code;
+
+ // Custom op name. If the op is a builtin, this will be null.
+ // Note: It is the responsibility of the registration binder to set this
+ // properly.
+ // WARNING: This is an experimental interface that is subject to change.
+ const char* custom_name;
+
+ // The version of the op.
+ // Note: It is the responsibility of the registration binder to set this
+ // properly.
+ int version;
+} TfLiteRegistration;
+
+// WARNING: This is an experimental interface that is subject to change.
+typedef struct _TfLiteDelegate {
+ // Data that delegate needs to identify itself. This data is owned by the
+ // delegate. The delegate is owned by the user code, so the delegate is
+ // responsible for cleaning up this data when it is destroyed.
+ void* data_;
+
+ // Invoked by ModifyGraphWithDelegate. This prepare is called, giving the
+ // delegate a view of the current graph through TfLiteContext*. It typically
+ // will look at the nodes and call ReplaceSubgraphsWithDelegateKernels()
+ // to ask the TensorFlow Lite runtime to create macro-nodes to represent
+ // delegated subgraphs of the original graph.
+ TfLiteStatus (*Prepare)(TfLiteContext* context, TfLiteDelegate* delegate);
+
+ // Copy the data from delegate buffer handle to raw memory.
+ // This can be null if the delegate doesn't use its own buffer.
+ TfLiteStatus (*CopyFromBufferHandle)(TfLiteContext* context,
+ TfLiteDelegate* delegate,
+ TfLiteBufferHandle buffer_handle,
+ void* data, size_t size);
+
+ // Copy the data from raw memory to delegate buffer handle.
+ // This can be null if the delegate doesn't use its own buffer.
+ TfLiteStatus (*CopyToBufferHandle)(TfLiteContext* context,
+ TfLiteDelegate* delegate,
+ TfLiteBufferHandle buffer_handle,
+ void* data, size_t size);
+
+ // Free the delegate buffer handle. Note: This only frees the handle, but
+ // it doesn't release the underlying resource (e.g. textures). The
+ // resources are owned either by the application layer or the delegate.
+ // This can be null if the delegate doesn't use its own buffer.
+ void (*FreeBufferHandle)(TfLiteContext* context, TfLiteDelegate* delegate,
+ TfLiteBufferHandle* handle);
+} TfLiteDelegate;
+
+// WARNING: This is an experimental interface that is subject to change.
+//
+// Currently, TfLiteDelegateParams has to be allocated in a way that it's
+// trivially destructible. It will be stored as the `builtin_data` field in
+// `TfLiteNode` of the delegate node.
+//
+// See also the `CreateDelegateParams` function in `interpreter.cc` for details.
+typedef struct {
+ TfLiteDelegate* delegate;
+ TfLiteIntArray* nodes_to_replace;
+ TfLiteIntArray* input_tensors;
+ TfLiteIntArray* output_tensors;
+} TfLiteDelegateParams;
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+#endif // TENSORFLOW_CONTRIB_LITE_C_C_API_INTERNAL_H_
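As a rough usage sketch (not part of this commit), the pieces above compose into a kernel registration roughly as follows; the `Noop*` functions and the `MY_NOOP` name are invented for illustration, and a real kernel would be registered through the interpreter's op resolver rather than defined in isolation:

    #include "tensorflow/contrib/lite/c/c_api_internal.h"

    static void* NoopInit(TfLiteContext* context, const char* buffer,
                          size_t length) {
      return NULL;  // no per-node state to allocate
    }

    static void NoopFree(TfLiteContext* context, void* buffer) {}

    static TfLiteStatus NoopPrepare(TfLiteContext* context, TfLiteNode* node) {
      // Example check: expect a single input and a single output tensor.
      TF_LITE_ENSURE_EQ(context, node->inputs->size, 1);
      TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
      return kTfLiteOk;
    }

    static TfLiteStatus NoopInvoke(TfLiteContext* context, TfLiteNode* node) {
      return kTfLiteOk;  // a real op would read inputs and write outputs here
    }

    // Fields follow the declaration order of TfLiteRegistration above.
    static const TfLiteRegistration kNoopRegistration = {
        NoopInit, NoopFree, NoopPrepare, NoopInvoke,
        /*profiling_string=*/NULL,
        /*builtin_code=*/0,
        /*custom_name=*/"MY_NOOP",
        /*version=*/1,
    };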
diff --git a/tensorflow/contrib/lite/c/c_api_internal_test.cc b/tensorflow/contrib/lite/c/c_api_internal_test.cc
new file mode 100644
index 0000000000..af398f3207
--- /dev/null
+++ b/tensorflow/contrib/lite/c/c_api_internal_test.cc
@@ -0,0 +1,73 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include <gtest/gtest.h>
+
+namespace tflite {
+
+// NOTE: this tests only the TfLiteIntArray part of context.
+// Most of c_api_internal.h is provided in the context of using it with
+// interpreter.h and interpreter.cc, so interpreter_test.cc tests context
+// structures more thoroughly.
+
+TEST(IntArray, TestIntArrayCreate) {
+ TfLiteIntArray* a = TfLiteIntArrayCreate(0);
+ TfLiteIntArray* b = TfLiteIntArrayCreate(3);
+ TfLiteIntArrayFree(a);
+ TfLiteIntArrayFree(b);
+}
+
+TEST(IntArray, TestIntArrayCopy) {
+ TfLiteIntArray* a = TfLiteIntArrayCreate(2);
+ a->data[0] = 22;
+ a->data[1] = 24;
+ TfLiteIntArray* b = TfLiteIntArrayCopy(a);
+ ASSERT_NE(a, b);
+ ASSERT_EQ(a->size, b->size);
+ ASSERT_EQ(a->data[0], b->data[0]);
+ ASSERT_EQ(a->data[1], b->data[1]);
+ TfLiteIntArrayFree(a);
+ TfLiteIntArrayFree(b);
+}
+
+TEST(IntArray, TestIntArrayEqual) {
+ TfLiteIntArray* a = TfLiteIntArrayCreate(1);
+ a->data[0] = 1;
+ TfLiteIntArray* b = TfLiteIntArrayCreate(2);
+ b->data[0] = 5;
+ b->data[1] = 6;
+ TfLiteIntArray* c = TfLiteIntArrayCreate(2);
+ c->data[0] = 5;
+ c->data[1] = 6;
+ TfLiteIntArray* d = TfLiteIntArrayCreate(2);
+ d->data[0] = 6;
+ d->data[1] = 6;
+ ASSERT_FALSE(TfLiteIntArrayEqual(a, b));
+ ASSERT_TRUE(TfLiteIntArrayEqual(b, c));
+ ASSERT_TRUE(TfLiteIntArrayEqual(b, b));
+ ASSERT_FALSE(TfLiteIntArrayEqual(c, d));
+ TfLiteIntArrayFree(a);
+ TfLiteIntArrayFree(b);
+ TfLiteIntArrayFree(c);
+ TfLiteIntArrayFree(d);
+}
+
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}