diff options
-rw-r--r-- | tensorflow/contrib/lite/context.h | 2 | ||||
-rw-r--r-- | tensorflow/contrib/lite/interpreter.cc | 86 | ||||
-rw-r--r-- | tensorflow/contrib/lite/interpreter_test.cc | 18 |
3 files changed, 84 insertions, 22 deletions
diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h index 6491d8c86a..45184b05ec 100644 --- a/tensorflow/contrib/lite/context.h +++ b/tensorflow/contrib/lite/context.h @@ -415,6 +415,8 @@ typedef struct _TfLiteDelegate { typedef struct { TfLiteDelegate* delegate; TfLiteIntArray* nodes_to_replace; + TfLiteIntArray* input_tensors; + TfLiteIntArray* output_tensors; } TfLiteDelegateParams; #ifdef __cplusplus diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index f03c1c9fe9..cee57bba5e 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -139,31 +139,76 @@ TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels( namespace { +// Copy a std::vector<int> to an existing TfLiteIntArray. +// This is a low-level data manipulation function, and it's the caller's +// responsibility to ensure TfLiteIntArray has enough size. +void CopyVectorToTfLiteIntArray(const std::vector<int>& vec, + TfLiteIntArray* arr) { + arr->size = vec.size(); + memcpy(arr->data, vec.data(), sizeof(int) * arr->size); +} + // This function allocates a continuous memory space that contains a -// TfLiteDelegateParams followed by a TfLiteIntArray. The pointer will be -// deallocated by C `free` function later. -TfLiteDelegateParams* CreateDelegateParams( - TfLiteDelegate* delegate, const std::vector<int>& nodes_to_replace) { - int nodes_to_replace_size_in_bytes = - TfLiteIntArrayGetSizeInBytes(nodes_to_replace.size()); - void* allocation = - malloc(sizeof(TfLiteDelegateParams) + nodes_to_replace_size_in_bytes); +// TfLiteDelegateParams followed by several TfLiteIntArrays. +// When calling `free` at TfLiteDelegateParams*, all the allocated space +// will be freed together. 
+// +// +-----------------------------------+ +// | TfLiteDelegateParams | +// | TfLiteDelegate* delegate; | +// | TfLiteIntArray* nodes_to_replace; |--\ +// | TfLiteIntArray* input_tensors; |--+--\ +// | TfLiteIntArray* output_tensors; |--+--+--\ +// +-----------------------------------+ | | | +// | TfLiteIntArray (variable size) |<-/ | | +// +-----------------------------------+ | | +// | TfLiteIntArray (variable size) |<----/ | +// +-----------------------------------+ | +// | TfLiteIntArray (variable size) |<-------/ +// +-----------------------------------+ +TfLiteDelegateParams* CreateDelegateParams(TfLiteDelegate* delegate, + const Subgraph& subgraph) { + // Step 1: Calculate the allocation size. + int allocation_size = sizeof(TfLiteDelegateParams); + + int nodes_to_replace_size = + TfLiteIntArrayGetSizeInBytes(subgraph.nodes.size()); + allocation_size += nodes_to_replace_size; + + int input_tensors_size = + TfLiteIntArrayGetSizeInBytes(subgraph.input_tensors.size()); + allocation_size += input_tensors_size; + + int output_tensors_size = + TfLiteIntArrayGetSizeInBytes(subgraph.output_tensors.size()); + allocation_size += output_tensors_size; + + // Step 2: Allocate the memory. + // Use `char*` to conveniently step through the allocated space by bytes. + char* allocation = reinterpret_cast<char*>(malloc(allocation_size)); + + // Step 3: Fill all data structures. 
TfLiteDelegateParams* params = reinterpret_cast<TfLiteDelegateParams*>(allocation); - TfLiteIntArray* nodes_to_replace_arr = reinterpret_cast<TfLiteIntArray*>( - static_cast<char*>(allocation) + sizeof(TfLiteDelegateParams)); + params->delegate = delegate; + allocation += sizeof(TfLiteDelegateParams); - nodes_to_replace_arr->size = nodes_to_replace.size(); - for (int i = 0; i < nodes_to_replace.size(); ++i) { - nodes_to_replace_arr->data[i] = nodes_to_replace[i]; - } + params->nodes_to_replace = reinterpret_cast<TfLiteIntArray*>(allocation); + CopyVectorToTfLiteIntArray(subgraph.nodes, params->nodes_to_replace); + allocation += nodes_to_replace_size; + + params->input_tensors = reinterpret_cast<TfLiteIntArray*>(allocation); + CopyVectorToTfLiteIntArray(subgraph.input_tensors, params->input_tensors); + allocation += input_tensors_size; + + params->output_tensors = reinterpret_cast<TfLiteIntArray*>(allocation); + CopyVectorToTfLiteIntArray(subgraph.output_tensors, params->output_tensors); + allocation += output_tensors_size; - params->delegate = delegate; - params->nodes_to_replace = nodes_to_replace_arr; return params; } -} // Anonymous namespace +} // namespace TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels( TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace, @@ -192,8 +237,7 @@ TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels( case Subgraph::kTfPartition: { int node_index; - TfLiteDelegateParams* params = - CreateDelegateParams(delegate, subgraph.nodes); + TfLiteDelegateParams* params = CreateDelegateParams(delegate, subgraph); AddNodeWithParameters(subgraph.input_tensors, subgraph.output_tensors, nullptr, 0, params, &registration, &node_index); @@ -229,8 +273,8 @@ TfLiteStatus Interpreter::GetExecutionPlan(TfLiteIntArray** execution_plan) { *execution_plan = plan_cache_.get(); static_assert(sizeof(plan_cache_->data[0]) == sizeof(execution_plan_[0]), "TfLiteIntArray and execution_plan do not contain same type."); -
memcpy(plan_cache_->data, execution_plan_.data(), - sizeof(plan_cache_->data[0]) * execution_plan_.size()); + std::memcpy(plan_cache_->data, execution_plan_.data(), + sizeof(plan_cache_->data[0]) * execution_plan_.size()); return kTfLiteOk; } diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc index 17eb2f4b07..7a029c7df8 100644 --- a/tensorflow/contrib/lite/interpreter_test.cc +++ b/tensorflow/contrib/lite/interpreter_test.cc @@ -923,8 +923,24 @@ TEST_F(TestDelegate, BasicDelegate) { ASSERT_EQ(interpreter_->execution_plan().size(), 1); int node = interpreter_->execution_plan()[0]; const auto* node_and_reg = interpreter_->node_and_registration(node); - ASSERT_EQ(node_and_reg->second.custom_name, + EXPECT_EQ(node_and_reg->second.custom_name, SimpleDelegate::FakeFusedRegistration().custom_name); + + const TfLiteDelegateParams* params = + reinterpret_cast<const TfLiteDelegateParams*>( + node_and_reg->first.builtin_data); + ASSERT_EQ(params->nodes_to_replace->size, 3); + EXPECT_EQ(params->nodes_to_replace->data[0], 0); + EXPECT_EQ(params->nodes_to_replace->data[1], 1); + EXPECT_EQ(params->nodes_to_replace->data[2], 2); + + ASSERT_EQ(params->input_tensors->size, 2); + EXPECT_EQ(params->input_tensors->data[0], 0); + EXPECT_EQ(params->input_tensors->data[1], 1); + + ASSERT_EQ(params->output_tensors->size, 2); + EXPECT_EQ(params->output_tensors->data[0], 3); + EXPECT_EQ(params->output_tensors->data[1], 4); } TEST_F(TestDelegate, ComplexDeligate) { |