aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/interpreter.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-25 07:50:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-25 07:53:39 -0700
commit10e0233d7f2f215577271d33a3b04506f93c13b1 (patch)
tree9147beacdeb93b1f7f1e3b97a61871188cef9339 /tensorflow/contrib/lite/interpreter.cc
parentfa5e84e7a498eb1386f4bc7d7076b957484e0972 (diff)
Small changes to interpreter.{h,cc}: refactoring plus improved error message.
PiperOrigin-RevId: 205992521
Diffstat (limited to 'tensorflow/contrib/lite/interpreter.cc')
-rw-r--r--tensorflow/contrib/lite/interpreter.cc37
1 files changed, 23 insertions, 14 deletions
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index 26fecceab0..5a5c907b6e 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -43,10 +43,13 @@ namespace {
TfLiteStatus ReportOpError(TfLiteContext* context, const TfLiteNode& node,
const TfLiteRegistration& registration,
int node_index, const char* message) {
- context->ReportError(context, "Node number %d (%s) %s.\n", node_index,
- EnumNameBuiltinOperator(static_cast<BuiltinOperator>(
- registration.builtin_code)),
- message);
+ context->ReportError(
+ context, "Node number %d (%s) %s.\n", node_index,
+ registration.custom_name
+ ? registration.custom_name
+ : EnumNameBuiltinOperator(
+ static_cast<BuiltinOperator>(registration.builtin_code)),
+ message);
return kTfLiteError;
}
@@ -131,9 +134,7 @@ Interpreter::Interpreter(ErrorReporter* error_reporter)
context_.SetExternalContext = SetExternalContext;
// Invalid to call these these except from TfLiteDelegate
- SetForbiddenContextFunction(&context_.GetNodeAndRegistration);
- SetForbiddenContextFunction(&context_.ReplaceSubgraphsWithDelegateKernels);
- SetForbiddenContextFunction(&context_.GetExecutionPlan);
+ SwitchToKernelContext();
// Reserve some space for the tensors to avoid excessive resizing.
tensors_.reserve(kTensorsReservedCapacity);
@@ -924,6 +925,19 @@ void Interpreter::SetNumThreads(int num_threads) {
}
}
+void Interpreter::SwitchToDelegateContext() {
+ context_.GetNodeAndRegistration = GetNodeAndRegistration;
+ context_.ReplaceSubgraphsWithDelegateKernels =
+ ReplaceSubgraphsWithDelegateKernels;
+ context_.GetExecutionPlan = GetExecutionPlan;
+}
+
+void Interpreter::SwitchToKernelContext() {
+ SetForbiddenContextFunction(&context_.GetNodeAndRegistration);
+ SetForbiddenContextFunction(&context_.ReplaceSubgraphsWithDelegateKernels);
+ SetForbiddenContextFunction(&context_.GetExecutionPlan);
+}
+
TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate,
bool allow_dynamic_tensors) {
if (!allow_dynamic_tensors) {
@@ -950,17 +964,12 @@ TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate,
// TODO(aselle): Consider if it is worth storing pointers to delegates.
// Setup additional context interface.
- context_.GetNodeAndRegistration = GetNodeAndRegistration;
- context_.ReplaceSubgraphsWithDelegateKernels =
- ReplaceSubgraphsWithDelegateKernels;
- context_.GetExecutionPlan = GetExecutionPlan;
+ SwitchToDelegateContext();
TfLiteStatus status = delegate->Prepare(&context_, delegate);
// Remove additional context info.
- SetForbiddenContextFunction(&context_.GetNodeAndRegistration);
- SetForbiddenContextFunction(&context_.ReplaceSubgraphsWithDelegateKernels);
- SetForbiddenContextFunction(&context_.GetExecutionPlan);
+ SwitchToKernelContext();
TF_LITE_ENSURE_OK(&context_, status);