-rw-r--r--   tensorflow/contrib/lite/context.h            |  2
-rw-r--r--   tensorflow/contrib/lite/interpreter.cc       | 86
-rw-r--r--   tensorflow/contrib/lite/interpreter_test.cc  | 18
3 files changed, 84 insertions, 22 deletions
diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h
index 6491d8c86a..45184b05ec 100644
--- a/tensorflow/contrib/lite/context.h
+++ b/tensorflow/contrib/lite/context.h
@@ -415,6 +415,8 @@ typedef struct _TfLiteDelegate {
typedef struct {
TfLiteDelegate* delegate;
TfLiteIntArray* nodes_to_replace;
+ TfLiteIntArray* input_tensors;
+ TfLiteIntArray* output_tensors;
} TfLiteDelegateParams;
#ifdef __cplusplus
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index f03c1c9fe9..cee57bba5e 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -139,31 +139,76 @@ TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels(
namespace {
+// Copy a std::vector<int> to an existing TfLiteIntArray.
+// This is a low-level data manipulation function, and it's the caller's
+// responsibility to ensure the TfLiteIntArray is large enough.
+void CopyVectorToTfLiteIntArray(const std::vector<int>& vec,
+ TfLiteIntArray* arr) {
+ arr->size = vec.size();
+ memcpy(arr->data, vec.data(), sizeof(int) * arr->size);
+}
+
// This function allocates a continuous memory space that contains a
-// TfLiteDelegateParams followed by a TfLiteIntArray. The pointer will be
-// deallocated by C `free` function later.
-TfLiteDelegateParams* CreateDelegateParams(
- TfLiteDelegate* delegate, const std::vector<int>& nodes_to_replace) {
- int nodes_to_replace_size_in_bytes =
- TfLiteIntArrayGetSizeInBytes(nodes_to_replace.size());
- void* allocation =
- malloc(sizeof(TfLiteDelegateParams) + nodes_to_replace_size_in_bytes);
+// TfLiteDelegateParams followed by several TfLiteIntArrays.
+// Calling `free` on the returned TfLiteDelegateParams* releases the whole
+// allocation, including the trailing arrays.
+//
+// +-----------------------------------+
+// | TfLiteDelegateParams |
+// | TfLiteDelegate* delegate; |
+// | TfLiteIntArray* nodes_to_replace; |--\
+// | TfLiteIntArray* input_tensors; |--+--\
+// | TfLiteIntArray* output_tensors; |--+--+--\
+// +-----------------------------------+ | | |
+// | TfLiteIntArray (variable size) |<-/ | |
+// +-----------------------------------+ | |
+// | TfLiteIntArray (variable size) |<----/ |
+// +-----------------------------------+ |
+// | TfLiteIntArray (variable size) |<-------/
+// +-----------------------------------+
+TfLiteDelegateParams* CreateDelegateParams(TfLiteDelegate* delegate,
+ const Subgraph& subgraph) {
+ // Step 1: Calculate the allocation size.
+ int allocation_size = sizeof(TfLiteDelegateParams);
+
+ int nodes_to_replace_size =
+ TfLiteIntArrayGetSizeInBytes(subgraph.nodes.size());
+ allocation_size += nodes_to_replace_size;
+
+ int input_tensors_size =
+ TfLiteIntArrayGetSizeInBytes(subgraph.input_tensors.size());
+ allocation_size += input_tensors_size;
+
+ int output_tensors_size =
+ TfLiteIntArrayGetSizeInBytes(subgraph.output_tensors.size());
+ allocation_size += output_tensors_size;
+
+ // Step 2: Allocate the memory.
+ // Use `char*` to conveniently step through the allocated space byte by byte.
+ char* allocation = reinterpret_cast<char*>(malloc(allocation_size));
+
+ // Step 3: Fill in all the data structures.
TfLiteDelegateParams* params =
reinterpret_cast<TfLiteDelegateParams*>(allocation);
- TfLiteIntArray* nodes_to_replace_arr = reinterpret_cast<TfLiteIntArray*>(
- static_cast<char*>(allocation) + sizeof(TfLiteDelegateParams));
+ params->delegate = delegate;
+ allocation += sizeof(TfLiteDelegateParams);
- nodes_to_replace_arr->size = nodes_to_replace.size();
- for (int i = 0; i < nodes_to_replace.size(); ++i) {
- nodes_to_replace_arr->data[i] = nodes_to_replace[i];
- }
+ params->nodes_to_replace = reinterpret_cast<TfLiteIntArray*>(allocation);
+ CopyVectorToTfLiteIntArray(subgraph.nodes, params->nodes_to_replace);
+ allocation += nodes_to_replace_size;
+
+ params->input_tensors = reinterpret_cast<TfLiteIntArray*>(allocation);
+ CopyVectorToTfLiteIntArray(subgraph.input_tensors, params->input_tensors);
+ allocation += input_tensors_size;
+
+ params->output_tensors = reinterpret_cast<TfLiteIntArray*>(allocation);
+ CopyVectorToTfLiteIntArray(subgraph.output_tensors, params->output_tensors);
+ allocation += output_tensors_size;
- params->delegate = delegate;
- params->nodes_to_replace = nodes_to_replace_arr;
return params;
}
-} // Anonymous namespace
+} // namespace
TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels(
TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace,
@@ -192,8 +237,7 @@ TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels(
case Subgraph::kTfPartition: {
int node_index;
- TfLiteDelegateParams* params =
- CreateDelegateParams(delegate, subgraph.nodes);
+ TfLiteDelegateParams* params = CreateDelegateParams(delegate, subgraph);
AddNodeWithParameters(subgraph.input_tensors, subgraph.output_tensors,
nullptr, 0, params, &registration, &node_index);
@@ -229,8 +273,8 @@ TfLiteStatus Interpreter::GetExecutionPlan(TfLiteIntArray** execution_plan) {
*execution_plan = plan_cache_.get();
static_assert(sizeof(plan_cache_->data[0]) == sizeof(execution_plan_[0]),
"TfLiteIntArray and execution_plan do not contain same type.");
- memcpy(plan_cache_->data, execution_plan_.data(),
- sizeof(plan_cache_->data[0]) * execution_plan_.size());
+ std::memcpy(plan_cache_->data, execution_plan_.data(),
+ sizeof(plan_cache_->data[0]) * execution_plan_.size());
return kTfLiteOk;
}
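
For context, a minimal sketch of how a delegate kernel could read these params back: the interpreter stores the TfLiteDelegateParams* built above as the replacement node's builtin_data (the test change below relies on exactly that), so a Prepare function for the fused kernel could recover the replaced nodes and the boundary tensors as shown here. MyDelegateKernelPrepare is an illustrative name and not part of this change; only the builtin_data layout comes from the patch.

// Hypothetical delegate-kernel Prepare; name and surrounding kernel are
// illustrative, the TfLiteDelegateParams layout is the one built above.
TfLiteStatus MyDelegateKernelPrepare(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteDelegateParams* params =
      reinterpret_cast<const TfLiteDelegateParams*>(node->builtin_data);
  // Indices of the original nodes this fused kernel replaces.
  for (int i = 0; i < params->nodes_to_replace->size; ++i) {
    const int original_node_index = params->nodes_to_replace->data[i];
    (void)original_node_index;  // ... look up the node and build the fused op.
  }
  // params->input_tensors / params->output_tensors list the tensors that
  // cross the delegate boundary (the fused node's inputs and outputs).
  return kTfLiteOk;
}
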
diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc
index 17eb2f4b07..7a029c7df8 100644
--- a/tensorflow/contrib/lite/interpreter_test.cc
+++ b/tensorflow/contrib/lite/interpreter_test.cc
@@ -923,8 +923,24 @@ TEST_F(TestDelegate, BasicDelegate) {
ASSERT_EQ(interpreter_->execution_plan().size(), 1);
int node = interpreter_->execution_plan()[0];
const auto* node_and_reg = interpreter_->node_and_registration(node);
- ASSERT_EQ(node_and_reg->second.custom_name,
+ EXPECT_EQ(node_and_reg->second.custom_name,
SimpleDelegate::FakeFusedRegistration().custom_name);
+
+ const TfLiteDelegateParams* params =
+ reinterpret_cast<const TfLiteDelegateParams*>(
+ node_and_reg->first.builtin_data);
+ ASSERT_EQ(params->nodes_to_replace->size, 3);
+ EXPECT_EQ(params->nodes_to_replace->data[0], 0);
+ EXPECT_EQ(params->nodes_to_replace->data[1], 1);
+ EXPECT_EQ(params->nodes_to_replace->data[2], 2);
+
+ ASSERT_EQ(params->input_tensors->size, 2);
+ EXPECT_EQ(params->input_tensors->data[0], 0);
+ EXPECT_EQ(params->input_tensors->data[1], 1);
+
+ ASSERT_EQ(params->output_tensors->size, 2);
+ EXPECT_EQ(params->output_tensors->data[0], 3);
+ EXPECT_EQ(params->output_tensors->data[1], 4);
}
TEST_F(TestDelegate, ComplexDeligate) {
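
As a worked illustration (not part of the change) of the allocation-size arithmetic in CreateDelegateParams for the subgraph this test exercises (3 replaced nodes {0, 1, 2}, boundary inputs {0, 1}, boundary outputs {3, 4}); the concrete byte count depends on sizeof on the target platform, so this only spells out the formula:

int allocation_size = sizeof(TfLiteDelegateParams)  // the params header
    + TfLiteIntArrayGetSizeInBytes(3)   // nodes_to_replace: {0, 1, 2}
    + TfLiteIntArrayGetSizeInBytes(2)   // input_tensors:    {0, 1}
    + TfLiteIntArrayGetSizeInBytes(2);  // output_tensors:   {3, 4}
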