Diffstat (limited to 'tensorflow/contrib/lite/interpreter.cc')
-rw-r--r--  tensorflow/contrib/lite/interpreter.cc  183
1 file changed, 136 insertions(+), 47 deletions(-)
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index 62a0b1ff08..e38597495d 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -22,20 +22,37 @@ limitations under the License.
#include "tensorflow/contrib/lite/arena_planner.h"
#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/context_util.h"
#include "tensorflow/contrib/lite/error_reporter.h"
#include "tensorflow/contrib/lite/graph_info.h"
-#include "tensorflow/contrib/lite/kernels/eigen_support.h"
-#include "tensorflow/contrib/lite/kernels/gemm_support.h"
#include "tensorflow/contrib/lite/memory_planner.h"
+#ifndef TFLITE_MCU
#include "tensorflow/contrib/lite/nnapi_delegate.h"
+#endif
#include "tensorflow/contrib/lite/profiling/profiler.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"
#include "tensorflow/contrib/lite/util.h"
namespace tflite {
+#ifdef TFLITE_MCU
+class NNAPIDelegate {};
+#endif
namespace {
+TfLiteStatus ReportOpError(TfLiteContext* context, const TfLiteNode& node,
+ const TfLiteRegistration& registration,
+ int node_index, const char* message) {
+ context->ReportError(
+ context, "Node number %d (%s) %s.\n", node_index,
+ registration.custom_name
+ ? registration.custom_name
+ : EnumNameBuiltinOperator(
+ static_cast<BuiltinOperator>(registration.builtin_code)),
+ message);
+ return kTfLiteError;
+}
+
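A note on how these messages surface (not part of the patch): the formatted string is handed to whatever ErrorReporter the Interpreter was constructed with. A minimal sketch of a reporter that simply forwards it to stderr; the class name is hypothetical:

#include <cstdarg>
#include <cstdio>
#include "tensorflow/contrib/lite/error_reporter.h"

// Hypothetical reporter: receives the message ReportOpError formats through
// context->ReportError and writes it to stderr.
class LoggingErrorReporter : public tflite::ErrorReporter {
 public:
  int Report(const char* format, va_list args) override {
    return vfprintf(stderr, format, args);
  }
};

// An Interpreter constructed with this reporter prints, e.g.:
//   Node number 3 (CONV_2D) failed to prepare.
// or the op's custom_name when one is registered.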
// Stub method which returns kTfLiteError when the function is forbidden.
// We're registering this function to several different functions to save
// compiled binary size. Please note the restrictions:
@@ -53,6 +70,19 @@ void SetForbiddenContextFunction(FunctionType* func) {
*func = reinterpret_cast<FunctionType>(ForbiddenContextFunction);
}
+// Returns true if at least one tensor in the given list is kTfLiteDynamic.
+template <typename TensorIntArray>
+bool HasDynamicTensorImpl(const TfLiteContext& context,
+ const TensorIntArray& int_array) {
+ for (int i : int_array) {
+ const TfLiteTensor& tensor = context.tensors[i];
+ if (tensor.allocation_type == kTfLiteDynamic) {
+ return true;
+ }
+ }
+ return false;
+}
+
} // namespace
// A trivial implementation of GraphInfo around the Interpreter.
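The helper above is a template so the same loop can walk either of the interpreter's index containers. A rough sketch of both call shapes, assuming it stays in this translation unit; the wrapper function is hypothetical:

#include <vector>
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/context_util.h"

namespace tflite {
namespace {
// Hypothetical wrapper showing both call shapes: a std::vector<int> (such as
// the interpreter's inputs_) iterates directly, while a raw TfLiteIntArray is
// wrapped in TfLiteIntArrayView, which exposes begin()/end() over the indices
// without copying them.
bool AnyDynamic(const TfLiteContext& context, const std::vector<int>& vec,
                const TfLiteIntArray* list) {
  return HasDynamicTensorImpl(context, vec) ||
         HasDynamicTensorImpl(context, TfLiteIntArrayView{list});
}
}  // namespace
}  // namespace tflite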
@@ -99,19 +129,22 @@ Interpreter::Interpreter(ErrorReporter* error_reporter)
context_.AddTensors = AddTensors;
context_.tensors = nullptr;
context_.tensors_size = 0;
- context_.eigen_context = nullptr;
- context_.gemm_context = nullptr;
context_.recommended_num_threads = -1;
+ context_.GetExternalContext = GetExternalContext;
+ context_.SetExternalContext = SetExternalContext;
// Invalid to call these except from TfLiteDelegate
- SetForbiddenContextFunction(&context_.GetNodeAndRegistration);
- SetForbiddenContextFunction(&context_.ReplaceSubgraphsWithDelegateKernels);
- SetForbiddenContextFunction(&context_.GetExecutionPlan);
+ SwitchToKernelContext();
// Reserve some space for the tensors to avoid excessive resizing.
tensors_.reserve(kTensorsReservedCapacity);
nodes_and_registration_.reserve(kTensorsReservedCapacity);
next_execution_plan_index_to_prepare_ = 0;
+
+ for (int i = 0; i < kTfLiteMaxExternalContexts; ++i) {
+ external_contexts_[i] = nullptr;
+ }
+
UseNNAPI(false);
}
@@ -246,8 +279,9 @@ TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels(
int node_index;
TfLiteDelegateParams* params = CreateDelegateParams(delegate, subgraph);
- AddNodeWithParameters(subgraph.input_tensors, subgraph.output_tensors,
- nullptr, 0, params, &registration, &node_index);
+ TF_LITE_ENSURE_STATUS(AddNodeWithParameters(
+ subgraph.input_tensors, subgraph.output_tensors, nullptr, 0, params,
+ &registration, &node_index));
// Initialize the output tensors' delegate-related fields.
for (int tensor_index : subgraph.output_tensors) {
@@ -269,6 +303,33 @@ TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels(
return kTfLiteOk;
}
+TfLiteExternalContext* Interpreter::GetExternalContext(
+ TfLiteExternalContextType type) {
+ if (type >= 0 && type < kTfLiteMaxExternalContexts) {
+ return external_contexts_[type];
+ }
+ return nullptr;
+}
+
+TfLiteExternalContext* Interpreter::GetExternalContext(
+ struct TfLiteContext* context, TfLiteExternalContextType type) {
+ return static_cast<Interpreter*>(context->impl_)->GetExternalContext(type);
+}
+
+void Interpreter::SetExternalContext(TfLiteExternalContextType type,
+ TfLiteExternalContext* ctx) {
+ if (type >= 0 && type < kTfLiteMaxExternalContexts) {
+ external_contexts_[type] = ctx;
+ }
+}
+
+void Interpreter::SetExternalContext(struct TfLiteContext* context,
+ TfLiteExternalContextType type,
+ TfLiteExternalContext* ctx) {
+ return static_cast<Interpreter*>(context->impl_)
+ ->SetExternalContext(type, ctx);
+}
+
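A rough sketch (not part of the patch) of how a kernel support library could use the new external-context slots, assuming the TfLiteExternalContext layout from context.h (a type field plus a Refresh callback) and the kTfLiteGemmLowpContext slot; the GemmThreadContext wrapper and its refresh policy are purely illustrative:

#include "tensorflow/contrib/lite/context.h"

// Hypothetical per-interpreter state; the base struct comes first so the
// stored TfLiteExternalContext* can be cast back to the wrapper.
struct GemmThreadContext {
  TfLiteExternalContext base;
  int num_threads = 1;
};

// Called back by the interpreter, e.g. from SetNumThreads() further below.
TfLiteStatus RefreshGemmThreads(TfLiteContext* context) {
  auto* ctx = reinterpret_cast<GemmThreadContext*>(
      context->GetExternalContext(context, kTfLiteGemmLowpContext));
  if (ctx != nullptr) {
    ctx->num_threads = context->recommended_num_threads;
  }
  return kTfLiteOk;
}

// Installs the context into the interpreter's slot for this type.
void RegisterGemmThreadContext(TfLiteContext* context,
                               GemmThreadContext* ctx) {
  ctx->base.type = kTfLiteGemmLowpContext;
  ctx->base.Refresh = RefreshGemmThreads;
  context->SetExternalContext(context, kTfLiteGemmLowpContext, &ctx->base);
}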
+// Gets a TfLiteIntArray* representing the execution plan. The interpreter owns
// this memory and it is only guaranteed to exist during the invocation of the
// delegate prepare.
@@ -372,23 +433,33 @@ TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims,
}
TfLiteStatus Interpreter::AllocateTensors() {
- next_execution_plan_index_to_prepare_ = 0;
- if (memory_planner_) {
- TF_LITE_ENSURE_STATUS(memory_planner_->ResetAllocations());
- }
-
if (!consistent_) {
ReportError(&context_, "AllocateTensors() called on inconsistent model.");
return kTfLiteError;
}
- TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors());
+ // Explicit (re)allocation is necessary if nodes have been changed or tensors
+ // have been resized. For inputs marked as dynamic, we can't short-circuit the
+ // allocation as the client may have done the resize manually.
+ if (state_ != kStateUninvokable && !HasDynamicTensorImpl(context_, inputs_)) {
+ return kTfLiteOk;
+ }
- if (state_ == kStateUninvokable) {
- state_ = kStateInvokable;
+ next_execution_plan_index_to_prepare_ = 0;
+ if (memory_planner_) {
+ TF_LITE_ENSURE_STATUS(memory_planner_->ResetAllocations());
}
- TF_LITE_ENSURE(&context_, state_ == kStateInvokable ||
- state_ == kStateInvokableAndImmutable);
+
+ TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors());
+
+ state_ = kStateInvokable;
+
+ // Reset the variable tensors to zero after (re)allocating the tensors.
+ // Developers shouldn't rely on the side effect of this function to reset
+ // variable tensors. They should call `ResetVariableTensorsToZero` directly
+ // instead.
+ ResetVariableTensorsToZero();
+
return kTfLiteOk;
}
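For callers, the upshot is that a redundant AllocateTensors() is now cheap, while a resize, a graph edit, or a dynamic input still forces a re-plan. A sketch of the intended usage pattern, assuming the standard FlatBufferModel/InterpreterBuilder flow; the model path and input shape are placeholders:

#include <memory>
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"

TfLiteStatus RunTwice(const char* model_path /* placeholder */) {
  auto model = tflite::FlatBufferModel::BuildFromFile(model_path);
  if (!model) return kTfLiteError;
  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
    return kTfLiteError;
  }

  TF_LITE_ENSURE_STATUS(interpreter->AllocateTensors());  // plans and allocates
  TF_LITE_ENSURE_STATUS(interpreter->AllocateTensors());  // returns immediately
  TF_LITE_ENSURE_STATUS(interpreter->Invoke());

  // A resize marks the graph uninvokable again, so the next call re-plans.
  TF_LITE_ENSURE_STATUS(interpreter->ResizeInputTensor(
      interpreter->inputs()[0], {1, 224, 224, 3}));  // placeholder shape
  TF_LITE_ENSURE_STATUS(interpreter->AllocateTensors());
  return interpreter->Invoke();
}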
@@ -481,26 +552,26 @@ TfLiteStatus Interpreter::ResizeInputTensor(int tensor_index,
"ResizeInputTensor is disallowed when graph is immutable.");
return kTfLiteError;
}
- state_ = kStateUninvokable;
// TODO(aselle): All bounds checks can be implemented as one-sided bounds
// checks by casting to unsigned for efficiency. Profile before doing this.
TF_LITE_ENSURE(&context_,
tensor_index < context_.tensors_size && tensor_index >= 0);
- TfLiteIntArray* dims_lite = ConvertVectorToTfLiteIntArray(dims);
- return ResizeTensorImpl(&context_.tensors[tensor_index], dims_lite);
+ TfLiteTensor* tensor = &context_.tensors[tensor_index];
+
+ // Short-circuit the state change if the dimensions don't change, avoiding
+ // unnecessary (re)allocations.
+ if (EqualArrayAndTfLiteIntArray(tensor->dims, dims.size(), dims.data())) {
+ return kTfLiteOk;
+ }
+
+ state_ = kStateUninvokable;
+ return ResizeTensorImpl(tensor, ConvertVectorToTfLiteIntArray(dims));
}
-// Returns true if at least one tensor in the given list is kTfLiteDynamic.
bool HasDynamicTensor(const TfLiteContext& context,
- const TfLiteIntArray* tensors) {
- for (int i = 0; i < tensors->size; ++i) {
- const TfLiteTensor& tensor = context.tensors[tensors->data[i]];
- if (tensor.allocation_type == kTfLiteDynamic) {
- return true;
- }
- }
- return false;
+ const TfLiteIntArray* int_array) {
+ return HasDynamicTensorImpl(context, TfLiteIntArrayView{int_array});
}
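With the same-shape check above, re-applying an unchanged input shape no longer knocks the interpreter back into the uninvokable state. A minimal sketch of that calling pattern; the helper name is hypothetical:

#include <vector>
#include "tensorflow/contrib/lite/interpreter.h"

// Cheap when `shape` already matches input 0's dims; otherwise the graph
// becomes uninvokable and the following AllocateTensors() re-plans it.
TfLiteStatus FeedFrame(tflite::Interpreter* interpreter,
                       const std::vector<int>& shape) {
  TF_LITE_ENSURE_STATUS(
      interpreter->ResizeInputTensor(interpreter->inputs()[0], shape));
  TF_LITE_ENSURE_STATUS(interpreter->AllocateTensors());
  return interpreter->Invoke();
}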
TfLiteStatus Interpreter::PrepareOpsStartingAt(
@@ -513,7 +584,8 @@ TfLiteStatus Interpreter::PrepareOpsStartingAt(
nodes_and_registration_[node_index].second;
EnsureTensorsVectorCapacity();
if (OpPrepare(registration, &node) == kTfLiteError) {
- return kTfLiteError;
+ return ReportOpError(&context_, node, registration, node_index,
+ "failed to prepare");
}
*last_execution_plan_index_prepared = execution_plan_index;
@@ -531,7 +603,8 @@ TfLiteStatus Interpreter::PrepareOpsStartingAt(
TfLiteStatus Interpreter::PrepareOpsAndTensors() {
if (!memory_planner_) {
memory_planner_.reset(new ArenaPlanner(
- &context_, std::unique_ptr<GraphInfo>(new InterpreterInfo(this))));
+ &context_, std::unique_ptr<GraphInfo>(new InterpreterInfo(this)),
+ /*preserve_inputs=*/true, /*preserve_intermediates=*/false));
memory_planner_->PlanAllocations();
}
@@ -557,6 +630,7 @@ TfLiteStatus Interpreter::Invoke() {
}
TfLiteStatus status = kTfLiteOk;
+#ifndef TFLITE_MCU
if (nnapi_delegate_) {
if (next_execution_plan_index_to_prepare_ == execution_plan_.size()) {
TF_LITE_ENSURE_OK(&context_, nnapi_delegate_->Invoke(this));
@@ -570,6 +644,7 @@ TfLiteStatus Interpreter::Invoke() {
return kTfLiteError;
}
}
+#endif
// Invocations are always done in node order.
// Note that calling Invoke repeatedly will cause the original memory plan to
@@ -610,7 +685,8 @@ TfLiteStatus Interpreter::Invoke() {
EnsureTensorsVectorCapacity();
tensor_resized_since_op_invoke_ = false;
if (OpInvoke(registration, &node) == kTfLiteError) {
- status = kTfLiteError;
+ status = ReportOpError(&context_, node, registration, node_index,
+ "failed to invoke");
}
// Force execution prep for downstream ops if the latest op triggered the
@@ -826,6 +902,7 @@ TfLiteStatus Interpreter::ResizeTensorImpl(TfLiteTensor* tensor,
}
void Interpreter::UseNNAPI(bool enable) {
+#ifndef TFLITE_MCU
// TODO(aselle): This is a workaround for finding if NNAPI exists.
// We also need to make sure getLibraryHandle() is renamed to be NNAPI
// prefixed.
@@ -835,15 +912,31 @@ void Interpreter::UseNNAPI(bool enable) {
} else if (!nnapi_delegate_) {
nnapi_delegate_.reset(new NNAPIDelegate);
}
+#endif
}
void Interpreter::SetNumThreads(int num_threads) {
context_.recommended_num_threads = num_threads;
- // TODO(ahentz): find a way to avoid this. It causes gemmlowp and eigen to
- // be required in order to compile the framework.
- gemm_support::SetNumThreads(&context_, num_threads);
- eigen_support::SetNumThreads(&context_, num_threads);
+ for (int i = 0; i < kTfLiteMaxExternalContexts; ++i) {
+ auto* c = external_contexts_[i];
+ if (c && c->Refresh) {
+ c->Refresh(&context_);
+ }
+ }
+}
+
+void Interpreter::SwitchToDelegateContext() {
+ context_.GetNodeAndRegistration = GetNodeAndRegistration;
+ context_.ReplaceSubgraphsWithDelegateKernels =
+ ReplaceSubgraphsWithDelegateKernels;
+ context_.GetExecutionPlan = GetExecutionPlan;
+}
+
+void Interpreter::SwitchToKernelContext() {
+ SetForbiddenContextFunction(&context_.GetNodeAndRegistration);
+ SetForbiddenContextFunction(&context_.ReplaceSubgraphsWithDelegateKernels);
+ SetForbiddenContextFunction(&context_.GetExecutionPlan);
}
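These two helpers bracket the only window in which a delegate may use the graph-introspection hooks; outside it the hooks hit ForbiddenContextFunction. A sketch of a delegate Prepare that relies on that window, using only the TfLiteContext hooks wired up in this file; the delegate itself is hypothetical:

#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/context_util.h"

// Valid only while the interpreter has switched to the delegate context,
// i.e. inside ModifyGraphWithDelegate()'s call to delegate->Prepare().
TfLiteStatus MyDelegatePrepare(TfLiteContext* context,
                               TfLiteDelegate* delegate) {
  TfLiteIntArray* plan = nullptr;
  TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan));
  for (int node_index : tflite::TfLiteIntArrayView{plan}) {
    TfLiteNode* node = nullptr;
    TfLiteRegistration* registration = nullptr;
    TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration(
        context, node_index, &node, &registration));
    // Inspect registration->builtin_code here to pick nodes to take over,
    // then hand them back via ReplaceSubgraphsWithDelegateKernels().
  }
  return kTfLiteOk;
}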
TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate,
@@ -872,24 +965,20 @@ TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate,
// TODO(aselle): Consider if it is worth storing pointers to delegates.
// Setup additional context interface.
- context_.GetNodeAndRegistration = GetNodeAndRegistration;
- context_.ReplaceSubgraphsWithDelegateKernels =
- ReplaceSubgraphsWithDelegateKernels;
- context_.GetExecutionPlan = GetExecutionPlan;
+ SwitchToDelegateContext();
TfLiteStatus status = delegate->Prepare(&context_, delegate);
// Remove additional context info.
- SetForbiddenContextFunction(&context_.GetNodeAndRegistration);
- SetForbiddenContextFunction(&context_.ReplaceSubgraphsWithDelegateKernels);
- SetForbiddenContextFunction(&context_.GetExecutionPlan);
+ SwitchToKernelContext();
TF_LITE_ENSURE_OK(&context_, status);
if (!allow_dynamic_tensors) {
+ // Reset the state to force tensor/op reallocation.
+ state_ = kStateUninvokable;
TF_LITE_ENSURE_OK(&context_, AllocateTensors());
- TF_LITE_ENSURE(&context_, state_ == kStateInvokable ||
- state_ == kStateInvokableAndImmutable);
+ TF_LITE_ENSURE_EQ(&context_, state_, kStateInvokable);
// After using a delegate which doesn't support dynamic tensors, make the
// entire graph immutable.
state_ = kStateInvokableAndImmutable;
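Net effect for clients: applying a delegate with allow_dynamic_tensors=false now forces a fresh allocation and then locks the graph as kStateInvokableAndImmutable, so later resizes are rejected instead of silently invalidating the delegate's plan. A minimal sketch; the delegate pointer is a placeholder:

#include "tensorflow/contrib/lite/interpreter.h"

TfLiteStatus ApplyAndRun(tflite::Interpreter* interpreter,
                         TfLiteDelegate* delegate /* placeholder */) {
  TF_LITE_ENSURE_STATUS(interpreter->ModifyGraphWithDelegate(
      delegate, /*allow_dynamic_tensors=*/false));
  // The graph is now immutable: ResizeInputTensor() on it returns an error
  // rather than forcing a re-plan behind the delegate's back.
  return interpreter->Invoke();
}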