aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/interpreter.cc
diff options
context:
space:
mode:
authorGravatar Andrew Selle <aselle@google.com>2018-02-09 15:45:00 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-09 15:48:34 -0800
commit2adb6bbb1b4d31bba7113a4213bf5e7f0e154c78 (patch)
tree309680694a800d7ff3605e3a387418a0f4ef59f3 /tensorflow/contrib/lite/interpreter.cc
parent40ec7202b63c32f2f5ed57116096e33677c4b5df (diff)
Add delegate API to tflite.
- Context gets GetNodes, num_nodes and PartitionNodesIntoSubgraphs. - TfLiteDelegate provides one function that needs to be implemented - Delegates choose nodes, and those nodes are all compacted into a new macro kernel. PiperOrigin-RevId: 185204338
Diffstat (limited to 'tensorflow/contrib/lite/interpreter.cc')
-rw-r--r--tensorflow/contrib/lite/interpreter.cc111
1 file changed, 111 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index 5aa0cbafd6..6dea4e5916 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -77,6 +77,12 @@ Interpreter::Interpreter(ErrorReporter* error_reporter)
context_.tensors = nullptr;
context_.tensors_size = 0;
context_.gemm_context = nullptr;
+
+ // Invalid to call these except from a TfLiteDelegate's Prepare; they are
+ // installed (and cleared again) by ModifyGraphWithDelegate.
+ context_.GetNodeAndRegistration = nullptr;
+ context_.ReplaceSubgraphsWithDelegateKernels = nullptr;
+ context_.GetExecutionPlan = nullptr;
+
// Reserve some space for the tensors to avoid excessive resizing.
tensors_.reserve(kSlotsToReserve);
nodes_and_registration_.reserve(kSlotsToReserve);
@@ -100,6 +106,78 @@ Interpreter::~Interpreter() {
}
}
+// C-API trampoline installed on TfLiteContext: recovers the Interpreter
+// instance from context->impl_ and forwards to the member overload.
+TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels(
+    TfLiteContext* context, TfLiteRegistration registration,
+    const TfLiteIntArray* nodes_to_replace) {
+  return static_cast<Interpreter*>(context->impl_)
+      ->ReplaceSubgraphsWithDelegateKernels(registration, nodes_to_replace);
+}
+
+// Replaces the nodes listed in nodes_to_replace with delegate "macro" kernels
+// using the given registration, rebuilding execution_plan_ in the process.
+TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels(
+    TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace) {
+  // Analyze the graph to find all independent subgraphs that are either
+  // fully not-this-delegate or this-delegate computation.
+  InterpreterInfo info(this);
+  std::vector<Subgraph> subgraphs;
+  PartitionGraphIntoIndependentSubgraphs(&info, nodes_to_replace, &subgraphs);
+
+  // The plan is rebuilt from scratch out of the partitioned subgraphs.
+  execution_plan_.clear();
+  for (auto& subgraph : subgraphs) {
+    // Turn subgraph.nodes into a TfLiteIntArray compatible data structure by
+    // prepending the element count, so the vector's storage matches the
+    // {size, data...} layout of a TfLiteIntArray.
+    // TODO(aselle): Avoid this copy by constructing subgraph.nodes that way
+    // in the first place
+    subgraph.nodes.insert(subgraph.nodes.begin(),
+                          static_cast<int>(subgraph.nodes.size()));
+    // Subgraphs claimed by the delegate should have a "macro" op created, the
+    // other subgraphs (kTfNonPartition) just have their nodes added back to
+    // the execution plan.
+    switch (subgraph.type) {
+      case Subgraph::kTfNonPartition:
+        // begin() + 1 skips the prepended size element, not a node index.
+        for (auto it = subgraph.nodes.begin() + 1; it != subgraph.nodes.end();
+             ++it) {
+          execution_plan_.push_back(*it);
+        }
+        break;
+      case Subgraph::kTfPartition: {
+        void* builtin_data = nullptr;
+        int node_index;
+        // Create a node that represents computation of this subgraph. The
+        // node's custom init data is the raw bytes of the TfLiteIntArray-
+        // shaped node list built above.
+        // NOTE(review): assumes AddNodeWithParameters appends the new node to
+        // execution_plan_ -- confirm against its implementation.
+        AddNodeWithParameters(
+            subgraph.input_tensors, subgraph.output_tensors,
+            reinterpret_cast<const char*>(subgraph.nodes.data()),
+            subgraph.nodes.size() * sizeof(subgraph.nodes[0]), builtin_data,
+            &registration, &node_index);
+      } break;
+      case Subgraph::kTfUnexplored:
+        // Partitioning must not leave any subgraph unexplored.
+        return kTfLiteError;
+        break;
+    }
+  }
+  return kTfLiteOk;
+}
+
+// Gets a TfLiteIntArray* representing the execution plan. The interpreter owns
+// this memory and it is only guaranteed to exist during the invocation of the
+// delegate prepare.
+TfLiteStatus Interpreter::GetExecutionPlan(TfLiteIntArray** execution_plan) {
+  // TODO(aselle): Do not make a copy here
+  plan_cache_.reset(TfLiteIntArrayCreate(execution_plan_.size()));
+  *execution_plan = plan_cache_.get();
+  static_assert(sizeof(plan_cache_->data[0]) == sizeof(execution_plan_[0]),
+                "TfLiteIntArray and execution_plan do not contain same type.");
+  // Copy the whole plan: the byte count must scale by the number of entries.
+  // (Previously only sizeof(data[0]) bytes were copied, truncating the
+  // returned plan to its first element.)
+  memcpy(plan_cache_->data, execution_plan_.data(),
+         sizeof(plan_cache_->data[0]) * execution_plan_.size());
+  return kTfLiteOk;
+}
+
+// WARNING: This is an experimental interface that is subject to change.
+// Entry point for C node plugin API to get the execution plan. Recovers the
+// Interpreter from context->impl_ and forwards to the member overload.
+TfLiteStatus Interpreter::GetExecutionPlan(struct TfLiteContext* context,
+                                           TfLiteIntArray** execution_plan) {
+  return static_cast<Interpreter*>(context->impl_)
+      ->GetExecutionPlan(execution_plan);
+}
+
TfLiteStatus Interpreter::SetInputs(std::vector<int> inputs) {
TF_LITE_ENSURE_OK(&context_,
CheckTensorIndices("inputs", inputs.data(), inputs.size()));
@@ -200,6 +278,7 @@ TfLiteStatus Interpreter::AddNodeWithParameters(
int new_node_index = nodes_and_registration_.size();
if (node_index) *node_index = new_node_index;
nodes_and_registration_.resize(nodes_and_registration_.size() + 1);
+
auto& node_and_reg = nodes_and_registration_.back();
TfLiteNode& node = node_and_reg.first;
if (node.inputs) TfLiteIntArrayFree(node.inputs);
@@ -388,6 +467,22 @@ TfLiteStatus Interpreter::AddTensors(TfLiteContext* context, int tensors_to_add,
->AddTensors(tensors_to_add, first_new_tensor_index);
}
+// Looks up the node and its registration at node_index, returned via the two
+// out-parameters. Fails (without writing the outputs) if node_index is out of
+// range or either output pointer is null.
+TfLiteStatus Interpreter::GetNodeAndRegistration(
+    int node_index, TfLiteNode** node, TfLiteRegistration** registration) {
+  TF_LITE_ENSURE(&context_, node_index < nodes_size() && node_index >= 0);
+  TF_LITE_ENSURE(&context_, node != nullptr && registration != nullptr);
+  *node = &nodes_and_registration_[node_index].first;
+  *registration = &nodes_and_registration_[node_index].second;
+  return kTfLiteOk;
+}
+
+// C-API trampoline installed on TfLiteContext: recovers the Interpreter
+// instance from context->impl_ and forwards to the member overload.
+TfLiteStatus Interpreter::GetNodeAndRegistration(
+    struct TfLiteContext* context, int node_index, TfLiteNode** node,
+    TfLiteRegistration** registration) {
+  return static_cast<Interpreter*>(context->impl_)
+      ->GetNodeAndRegistration(node_index, node, registration);
+}
+
TfLiteStatus Interpreter::SetTensorParametersReadOnly(
int tensor_index, TfLiteType type, const char* name,
const std::vector<int>& dims, TfLiteQuantizationParams quantization,
@@ -498,4 +593,20 @@ void Interpreter::SetNumThreads(int num_threads) {
tflite::gemm_support::SetMaxNumThreads(&context_, num_threads);
}
+// Lets the delegate restructure the graph via its Prepare callback. The
+// delegate-only context callbacks are installed just for the duration of
+// Prepare and cleared again before returning.
+TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) {
+  // TODO(aselle): Consider if it is worth storing pointers to delegates.
+  // Setup additional context interface.
+  context_.GetNodeAndRegistration = GetNodeAndRegistration;
+  context_.ReplaceSubgraphsWithDelegateKernels =
+      ReplaceSubgraphsWithDelegateKernels;
+  context_.GetExecutionPlan = GetExecutionPlan;
+
+  // The delegate may claim nodes through the callbacks installed above.
+  TfLiteStatus status = delegate->Prepare(&context_, delegate->data_);
+  // Remove additional context info so the callbacks cannot be invoked
+  // outside of a delegate's Prepare.
+  context_.GetNodeAndRegistration = nullptr;
+  context_.ReplaceSubgraphsWithDelegateKernels = nullptr;
+  context_.GetExecutionPlan = nullptr;
+  return status;
+}
+
} // namespace tflite