author     Pete Warden <petewarden@google.com>              2018-09-24 15:54:32 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-09-24 16:02:13 -0700
commit     1ff157d82dac29f5a3a3197b2664208f6ed6ba06
tree       c751f5a665a27c660809c4884eb07f69b4983244  /tensorflow/contrib/lite/core
parent     9c58005ec86297a1d0a17dc4f7ad7cbae9c47e4b
Portability preparation for more cross-platform prototyping.
PiperOrigin-RevId: 214346240
Diffstat (limited to 'tensorflow/contrib/lite/core')
-rw-r--r--  tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc      | 92
-rw-r--r--  tensorflow/contrib/lite/core/api/flatbuffer_conversions.h       | 22
-rw-r--r--  tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc | 26
3 files changed, 87 insertions(+), 53 deletions(-)
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
index 03af538073..e6900e0950 100644
--- a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
@@ -44,16 +44,6 @@ void FlatBufferIntVectorToArray(int max_size_of_buffer,
}
}
-// Allocate a structure using malloc, but make sure the structure is a POD
-// structure that doesn't require constructors to run. The reason we do this,
-// is that Interpreter's C extension part will take ownership so destructors
-// will not be run during deallocation.
-template <class T>
-T* MallocPOD() {
- static_assert(std::is_pod<T>::value, "Builtin data structure must be POD.");
- return static_cast<T*>(malloc(sizeof(T)));
-}
-
} // namespace
TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
@@ -98,7 +88,8 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
// need to be released by calling `free`.`
// If it returns kTfLiteError, `builtin_data` will be `nullptr`.
TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
- ErrorReporter* error_reporter, void** builtin_data) {
+ ErrorReporter* error_reporter,
+ BuiltinDataAllocator* allocator, void** builtin_data) {
auto parse_padding = [](Padding padding) {
switch (padding) {
case Padding_SAME:
@@ -150,7 +141,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = nullptr;
switch (op_type) {
case BuiltinOperator_CONV_2D: {
- TfLiteConvParams* params = MallocPOD<TfLiteConvParams>();
+ TfLiteConvParams* params = allocator->AllocatePOD<TfLiteConvParams>();
if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) {
params->padding = parse_padding(conv_params->padding());
params->stride_width = conv_params->stride_w();
@@ -165,7 +156,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_CAST: {
- TfLiteCastParams* params = MallocPOD<TfLiteCastParams>();
+ TfLiteCastParams* params = allocator->AllocatePOD<TfLiteCastParams>();
if (auto* schema_params = op->builtin_options_as_CastOptions()) {
auto in_status =
ConvertTensorType(schema_params->in_data_type(),
@@ -174,7 +165,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
ConvertTensorType(schema_params->out_data_type(),
&params->out_data_type, error_reporter);
if (in_status != kTfLiteOk || out_status != kTfLiteOk) {
- free(params);
+ allocator->Deallocate(params);
return kTfLiteError;
}
}
@@ -183,7 +174,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
}
case BuiltinOperator_LSH_PROJECTION: {
TfLiteLSHProjectionParams* params =
- MallocPOD<TfLiteLSHProjectionParams>();
+ allocator->AllocatePOD<TfLiteLSHProjectionParams>();
if (auto* lshParams = op->builtin_options_as_LSHProjectionOptions()) {
params->type = parseLSHProjectionType(lshParams->type());
}
@@ -193,7 +184,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_AVERAGE_POOL_2D:
case BuiltinOperator_MAX_POOL_2D:
case BuiltinOperator_L2_POOL_2D: {
- TfLitePoolParams* params = MallocPOD<TfLitePoolParams>();
+ TfLitePoolParams* params = allocator->AllocatePOD<TfLitePoolParams>();
if (auto* pool_params = op->builtin_options_as_Pool2DOptions()) {
params->padding = parse_padding(pool_params->padding());
params->stride_width = pool_params->stride_w();
@@ -208,7 +199,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
}
case BuiltinOperator_DEPTHWISE_CONV_2D: {
TfLiteDepthwiseConvParams* params =
- MallocPOD<TfLiteDepthwiseConvParams>();
+ allocator->AllocatePOD<TfLiteDepthwiseConvParams>();
if (auto* conv_params = op->builtin_options_as_DepthwiseConv2DOptions()) {
params->padding = parse_padding(conv_params->padding());
params->stride_width = conv_params->stride_w();
@@ -224,7 +215,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_SVDF: {
- TfLiteSVDFParams* params = MallocPOD<TfLiteSVDFParams>();
+ TfLiteSVDFParams* params = allocator->AllocatePOD<TfLiteSVDFParams>();
if (auto* svdf_params = op->builtin_options_as_SVDFOptions()) {
params->rank = svdf_params->rank();
params->activation =
@@ -235,7 +226,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
}
case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN:
case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: {
- TfLiteSequenceRNNParams* params = MallocPOD<TfLiteSequenceRNNParams>();
+ TfLiteSequenceRNNParams* params =
+ allocator->AllocatePOD<TfLiteSequenceRNNParams>();
if (auto* sequence_rnn_params =
op->builtin_options_as_SequenceRNNOptions()) {
params->activation =
@@ -246,7 +238,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_RNN: {
- TfLiteRNNParams* params = MallocPOD<TfLiteRNNParams>();
+ TfLiteRNNParams* params = allocator->AllocatePOD<TfLiteRNNParams>();
if (auto* rnn_params = op->builtin_options_as_RNNOptions()) {
params->activation =
parse_activation(rnn_params->fused_activation_function());
@@ -256,7 +248,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
}
case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: {
TfLiteEmbeddingLookupSparseParams* params =
- MallocPOD<TfLiteEmbeddingLookupSparseParams>();
+ allocator->AllocatePOD<TfLiteEmbeddingLookupSparseParams>();
if (auto* embedding_params =
op->builtin_options_as_EmbeddingLookupSparseOptions()) {
params->combiner = parseCombinerType(embedding_params->combiner());
@@ -266,7 +258,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
}
case BuiltinOperator_FULLY_CONNECTED: {
TfLiteFullyConnectedParams* params =
- MallocPOD<TfLiteFullyConnectedParams>();
+ allocator->AllocatePOD<TfLiteFullyConnectedParams>();
if (auto* fully_connected_params =
op->builtin_options_as_FullyConnectedOptions()) {
params->activation = parse_activation(
@@ -291,7 +283,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
// no-op.
break;
case BuiltinOperator_SOFTMAX: {
- TfLiteSoftmaxParams* params = MallocPOD<TfLiteSoftmaxParams>();
+ TfLiteSoftmaxParams* params =
+ allocator->AllocatePOD<TfLiteSoftmaxParams>();
if (auto* softmax_params = op->builtin_options_as_SoftmaxOptions()) {
params->beta = softmax_params->beta();
}
@@ -300,7 +293,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
}
case BuiltinOperator_CONCATENATION: {
TfLiteConcatenationParams* params =
- MallocPOD<TfLiteConcatenationParams>();
+ allocator->AllocatePOD<TfLiteConcatenationParams>();
if (auto* concatenation_params =
op->builtin_options_as_ConcatenationOptions()) {
params->activation =
@@ -311,7 +304,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_MUL: {
- auto* params = MallocPOD<TfLiteMulParams>();
+ auto* params = allocator->AllocatePOD<TfLiteMulParams>();
if (auto* schema_params = op->builtin_options_as_MulOptions()) {
params->activation =
parse_activation(schema_params->fused_activation_function());
@@ -320,7 +313,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_ADD: {
- auto* params = MallocPOD<TfLiteAddParams>();
+ auto* params = allocator->AllocatePOD<TfLiteAddParams>();
if (auto* schema_params = op->builtin_options_as_AddOptions()) {
params->activation =
parse_activation(schema_params->fused_activation_function());
@@ -329,7 +322,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_DIV: {
- auto* params = MallocPOD<TfLiteDivParams>();
+ auto* params = allocator->AllocatePOD<TfLiteDivParams>();
if (auto* schema_params = op->builtin_options_as_DivOptions()) {
params->activation =
parse_activation(schema_params->fused_activation_function());
@@ -338,7 +331,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_SUB: {
- auto* params = MallocPOD<TfLiteSubParams>();
+ auto* params = allocator->AllocatePOD<TfLiteSubParams>();
if (auto* schema_params = op->builtin_options_as_SubOptions()) {
params->activation =
parse_activation(schema_params->fused_activation_function());
@@ -347,7 +340,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_L2_NORMALIZATION: {
- auto* params = MallocPOD<TfLiteL2NormParams>();
+ auto* params = allocator->AllocatePOD<TfLiteL2NormParams>();
if (auto* schema_params = op->builtin_options_as_L2NormOptions()) {
params->activation =
parse_activation(schema_params->fused_activation_function());
@@ -356,7 +349,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: {
- auto* params = MallocPOD<TfLiteLocalResponseNormParams>();
+ auto* params = allocator->AllocatePOD<TfLiteLocalResponseNormParams>();
if (auto* schema_params =
op->builtin_options_as_LocalResponseNormalizationOptions()) {
params->radius = schema_params->radius();
@@ -370,7 +363,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM:
case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
case BuiltinOperator_LSTM: {
- TfLiteLSTMParams* params = MallocPOD<TfLiteLSTMParams>();
+ TfLiteLSTMParams* params = allocator->AllocatePOD<TfLiteLSTMParams>();
if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) {
params->activation =
parse_activation(lstm_params->fused_activation_function());
@@ -389,7 +382,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_RESIZE_BILINEAR: {
- auto* params = MallocPOD<TfLiteResizeBilinearParams>();
+ auto* params = allocator->AllocatePOD<TfLiteResizeBilinearParams>();
if (auto* schema_params =
op->builtin_options_as_ResizeBilinearOptions()) {
params->align_corners = schema_params->align_corners();
@@ -398,7 +391,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_RESHAPE: {
- auto* params = MallocPOD<TfLiteReshapeParams>();
+ auto* params = allocator->AllocatePOD<TfLiteReshapeParams>();
if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) {
auto* new_shape = schema_params->new_shape();
FlatBufferIntVectorToArray(sizeof(params->shape), new_shape,
@@ -409,7 +402,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_SKIP_GRAM: {
- TfLiteSkipGramParams* params = MallocPOD<TfLiteSkipGramParams>();
+ TfLiteSkipGramParams* params =
+ allocator->AllocatePOD<TfLiteSkipGramParams>();
if (auto* skip_gram_params = op->builtin_options_as_SkipGramOptions()) {
params->ngram_size = skip_gram_params->ngram_size();
params->max_skip_size = skip_gram_params->max_skip_size();
@@ -419,7 +413,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_SPACE_TO_DEPTH: {
- auto* params = MallocPOD<TfLiteSpaceToDepthParams>();
+ auto* params = allocator->AllocatePOD<TfLiteSpaceToDepthParams>();
if (auto* schema_params = op->builtin_options_as_SpaceToDepthOptions()) {
params->block_size = schema_params->block_size();
}
@@ -427,7 +421,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_GATHER: {
- TfLiteGatherParams* params = MallocPOD<TfLiteGatherParams>();
+ TfLiteGatherParams* params = allocator->AllocatePOD<TfLiteGatherParams>();
params->axis = 0;
if (auto* gather_params = op->builtin_options_as_GatherOptions()) {
params->axis = gather_params->axis();
@@ -442,7 +436,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_REDUCE_PROD:
case BuiltinOperator_REDUCE_ANY:
case BuiltinOperator_SUM: {
- auto* params = MallocPOD<TfLiteReducerParams>();
+ auto* params = allocator->AllocatePOD<TfLiteReducerParams>();
if (auto* schema_params = op->builtin_options_as_ReducerOptions()) {
params->keep_dims = schema_params->keep_dims();
}
@@ -450,7 +444,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_SPLIT: {
- auto* params = MallocPOD<TfLiteSplitParams>();
+ auto* params = allocator->AllocatePOD<TfLiteSplitParams>();
if (auto* schema_params = op->builtin_options_as_SplitOptions()) {
params->num_splits = schema_params->num_splits();
}
@@ -458,7 +452,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_SQUEEZE: {
- auto* params = MallocPOD<TfLiteSqueezeParams>();
+ auto* params = allocator->AllocatePOD<TfLiteSqueezeParams>();
if (auto* schema_params = op->builtin_options_as_SqueezeOptions()) {
const auto& squeeze_dims = schema_params->squeeze_dims();
FlatBufferIntVectorToArray(sizeof(params->squeeze_dims), squeeze_dims,
@@ -469,7 +463,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_STRIDED_SLICE: {
- auto* params = MallocPOD<TfLiteStridedSliceParams>();
+ auto* params = allocator->AllocatePOD<TfLiteStridedSliceParams>();
if (auto* schema_params = op->builtin_options_as_StridedSliceOptions()) {
params->begin_mask = schema_params->begin_mask();
params->end_mask = schema_params->end_mask();
@@ -481,7 +475,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_ARG_MAX: {
- auto* params = MallocPOD<TfLiteArgMaxParams>();
+ auto* params = allocator->AllocatePOD<TfLiteArgMaxParams>();
if (auto* schema_params = op->builtin_options_as_ArgMaxOptions()) {
ConvertTensorType(schema_params->output_type(), &params->output_type,
error_reporter);
@@ -490,7 +484,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_ARG_MIN: {
- auto* params = MallocPOD<TfLiteArgMinParams>();
+ auto* params = allocator->AllocatePOD<TfLiteArgMinParams>();
if (const auto* schema_params = op->builtin_options_as_ArgMinOptions()) {
ConvertTensorType(schema_params->output_type(), &params->output_type,
error_reporter);
@@ -500,7 +494,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
}
case BuiltinOperator_TRANSPOSE_CONV: {
TfLiteTransposeConvParams* params =
- MallocPOD<TfLiteTransposeConvParams>();
+ allocator->AllocatePOD<TfLiteTransposeConvParams>();
if (auto* transpose_conv_params =
op->builtin_options_as_TransposeConvOptions()) {
params->padding = parse_padding(transpose_conv_params->padding());
@@ -512,7 +506,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
}
case BuiltinOperator_SPARSE_TO_DENSE: {
TfLiteSparseToDenseParams* params =
- MallocPOD<TfLiteSparseToDenseParams>();
+ allocator->AllocatePOD<TfLiteSparseToDenseParams>();
if (auto* sparse_to_dense_params =
op->builtin_options_as_SparseToDenseOptions()) {
params->validate_indices = sparse_to_dense_params->validate_indices();
@@ -521,7 +515,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_SHAPE: {
- auto* params = MallocPOD<TfLiteShapeParams>();
+ auto* params = allocator->AllocatePOD<TfLiteShapeParams>();
if (auto* schema_params = op->builtin_options_as_ShapeOptions()) {
ConvertTensorType(schema_params->out_type(), &params->out_type,
error_reporter);
@@ -530,7 +524,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_PACK: {
- TfLitePackParams* params = MallocPOD<TfLitePackParams>();
+ TfLitePackParams* params = allocator->AllocatePOD<TfLitePackParams>();
if (auto* pack_params = op->builtin_options_as_PackOptions()) {
params->values_count = pack_params->values_count();
params->axis = pack_params->axis();
@@ -544,7 +538,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
return kTfLiteError;
}
case BuiltinOperator_FAKE_QUANT: {
- auto* params = MallocPOD<TfLiteFakeQuantParams>();
+ auto* params = allocator->AllocatePOD<TfLiteFakeQuantParams>();
if (auto* schema_params = op->builtin_options_as_FakeQuantOptions()) {
params->min = schema_params->min();
params->max = schema_params->max();
@@ -555,7 +549,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_ONE_HOT: {
- auto* params = MallocPOD<TfLiteOneHotParams>();
+ auto* params = allocator->AllocatePOD<TfLiteOneHotParams>();
if (auto* schema_params = op->builtin_options_as_OneHotOptions()) {
params->axis = schema_params->axis();
}
@@ -563,7 +557,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_UNPACK: {
- TfLiteUnpackParams* params = MallocPOD<TfLiteUnpackParams>();
+ TfLiteUnpackParams* params = allocator->AllocatePOD<TfLiteUnpackParams>();
if (auto* unpack_params = op->builtin_options_as_UnpackOptions()) {
params->num = unpack_params->num();
params->axis = unpack_params->axis();
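
The hunks above replace every file-local MallocPOD<T>() / free() pair with calls through the injected BuiltinDataAllocator. A heap-backed implementation that reproduces the old behavior could be as small as the sketch below; the class name MallocDataAllocator is an assumption for illustration and is not part of this diff.

#include <cstdlib>

#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h"

namespace tflite {

// Hypothetical malloc-backed allocator; the name is illustrative only and
// does not appear in this change. It simply restores the behavior of the
// removed MallocPOD<T>() helper and the matching free() calls.
class MallocDataAllocator : public BuiltinDataAllocator {
 public:
  void* Allocate(size_t size) override { return std::malloc(size); }
  void Deallocate(void* data) override { std::free(data); }
};

}  // namespace tflite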
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h
index 4dec6f9cfc..c770e627fd 100644
--- a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h
@@ -26,6 +26,25 @@ limitations under the License.
namespace tflite {
+// Interface class for builtin data allocations.
+class BuiltinDataAllocator {
+ public:
+ virtual void* Allocate(size_t size) = 0;
+ virtual void Deallocate(void* data) = 0;
+
+ // Allocate a structure, but make sure it is a POD structure that doesn't
+ // require constructors to run. The reason we do this, is that Interpreter's C
+ // extension part will take ownership so destructors will not be run during
+ // deallocation.
+ template <typename T>
+ T* AllocatePOD() {
+ static_assert(std::is_pod<T>::value, "Builtin data structure must be POD.");
+ return static_cast<T*>(this->Allocate(sizeof(T)));
+ }
+
+ virtual ~BuiltinDataAllocator() {}
+};
+
// Parse the appropriate data out of the op.
//
// This handles builtin data explicitly as there are flatbuffer schemas.
@@ -36,7 +55,8 @@ namespace tflite {
// function's responsibility to free it.
// If it returns kTfLiteError, `builtin_data` will be `nullptr`.
TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
- ErrorReporter* error_reporter, void** builtin_data);
+ ErrorReporter* error_reporter,
+ BuiltinDataAllocator* allocator, void** builtin_data);
// Converts the tensor data type used in the flat buffer to the representation
// used by the runtime.
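
With allocation behind this interface, a platform without a usable heap can back ParseOpData with a fixed buffer instead of malloc, which is presumably the cross-platform portability this change prepares for. A minimal sketch of such an allocator follows; the name StackDataAllocator, the 1024-byte size, and the one-allocation-at-a-time policy are assumptions, not something this commit defines. The MockDataAllocator in the test change below uses the same idea.

#include <cstddef>

#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h"

namespace tflite {

// Hypothetical fixed-buffer allocator for heap-less targets; name, size and
// policy are illustrative assumptions.
class StackDataAllocator : public BuiltinDataAllocator {
 public:
  void* Allocate(size_t size) override {
    // ParseOpData only needs room for one builtin params struct at a time,
    // so a single reusable buffer is enough for this sketch.
    if (size > kBufferSize) {
      return nullptr;
    }
    return buffer_;
  }
  void Deallocate(void* data) override {
    // Nothing to release: the buffer is owned by the allocator object itself.
  }

 private:
  static constexpr size_t kBufferSize = 1024;
  alignas(alignof(std::max_align_t)) char buffer_[kBufferSize];
};

}  // namespace tflite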
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc
index b12bdf43b2..8ae94e1d33 100644
--- a/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc
@@ -39,11 +39,31 @@ class MockErrorReporter : public ErrorReporter {
int buffer_size_;
};
+// Used to determine how the op data parsing function creates its working space.
+class MockDataAllocator : public BuiltinDataAllocator {
+ public:
+ MockDataAllocator() : is_allocated_(false) {}
+ void* Allocate(size_t size) override {
+ EXPECT_FALSE(is_allocated_);
+ const int max_size = kBufferSize;
+ EXPECT_LE(size, max_size);
+ is_allocated_ = true;
+ return buffer_;
+ }
+ void Deallocate(void* data) override { is_allocated_ = false; }
+
+ private:
+ static constexpr int kBufferSize = 1024;
+ char buffer_[kBufferSize];
+ bool is_allocated_;
+};
+
} // namespace
TEST(FlatbufferConversions, TestParseOpDataConv) {
MockErrorReporter mock_reporter;
ErrorReporter* reporter = &mock_reporter;
+ MockDataAllocator mock_allocator;
flatbuffers::FlatBufferBuilder builder;
flatbuffers::Offset<void> conv_options =
@@ -58,7 +78,7 @@ TEST(FlatbufferConversions, TestParseOpDataConv) {
const Operator* conv_op = flatbuffers::GetRoot<Operator>(conv_pointer);
void* output_data = nullptr;
EXPECT_EQ(kTfLiteOk, ParseOpData(conv_op, BuiltinOperator_CONV_2D, reporter,
- &output_data));
+ &mock_allocator, &output_data));
EXPECT_NE(nullptr, output_data);
TfLiteConvParams* params = reinterpret_cast<TfLiteConvParams*>(output_data);
EXPECT_EQ(kTfLitePaddingSame, params->padding);
@@ -67,12 +87,12 @@ TEST(FlatbufferConversions, TestParseOpDataConv) {
EXPECT_EQ(kTfLiteActRelu, params->activation);
EXPECT_EQ(3, params->dilation_width_factor);
EXPECT_EQ(4, params->dilation_height_factor);
- free(output_data);
}
TEST(FlatbufferConversions, TestParseOpDataCustom) {
MockErrorReporter mock_reporter;
ErrorReporter* reporter = &mock_reporter;
+ MockDataAllocator mock_allocator;
flatbuffers::FlatBufferBuilder builder;
flatbuffers::Offset<void> null_options;
@@ -84,7 +104,7 @@ TEST(FlatbufferConversions, TestParseOpDataCustom) {
const Operator* custom_op = flatbuffers::GetRoot<Operator>(custom_pointer);
void* output_data = nullptr;
EXPECT_EQ(kTfLiteOk, ParseOpData(custom_op, BuiltinOperator_CUSTOM, reporter,
- &output_data));
+ &mock_allocator, &output_data));
EXPECT_EQ(nullptr, output_data);
}
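
For completeness, a rough sketch of how a caller might adapt to the new ParseOpData signature, using the hypothetical MallocDataAllocator from the earlier sketch; the wrapper name ParseBuiltinData is likewise illustrative and not part of this commit.

#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h"

namespace tflite {

// Hypothetical wrapper around the new signature.
TfLiteStatus ParseBuiltinData(const Operator* op, BuiltinOperator op_type,
                              ErrorReporter* error_reporter,
                              void** builtin_data) {
  MallocDataAllocator allocator;
  // On success the struct behind *builtin_data came from malloc, so existing
  // callers that release it with free() keep working; on failure it is
  // guaranteed to be nullptr.
  return ParseOpData(op, op_type, error_reporter, &allocator, builtin_data);
}

}  // namespace tflite

Routing the old heap behavior through the interface keeps the existing free()-based ownership intact, while other platforms can swap in an allocator of their own without touching the parsing code.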