diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-07-25 07:50:20 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-25 07:53:39 -0700 |
commit | 10e0233d7f2f215577271d33a3b04506f93c13b1 (patch) | |
tree | 9147beacdeb93b1f7f1e3b97a61871188cef9339 /tensorflow/contrib/lite/interpreter.cc | |
parent | fa5e84e7a498eb1386f4bc7d7076b957484e0972 (diff) |
Small changes to interpreter.{h,cc}: refactoring plus improved error message.
PiperOrigin-RevId: 205992521
Diffstat (limited to 'tensorflow/contrib/lite/interpreter.cc')
-rw-r--r-- | tensorflow/contrib/lite/interpreter.cc | 37 |
1 files changed, 23 insertions, 14 deletions
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index 26fecceab0..5a5c907b6e 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -43,10 +43,13 @@ namespace { TfLiteStatus ReportOpError(TfLiteContext* context, const TfLiteNode& node, const TfLiteRegistration& registration, int node_index, const char* message) { - context->ReportError(context, "Node number %d (%s) %s.\n", node_index, - EnumNameBuiltinOperator(static_cast<BuiltinOperator>( - registration.builtin_code)), - message); + context->ReportError( + context, "Node number %d (%s) %s.\n", node_index, + registration.custom_name + ? registration.custom_name + : EnumNameBuiltinOperator( + static_cast<BuiltinOperator>(registration.builtin_code)), + message); return kTfLiteError; } @@ -131,9 +134,7 @@ Interpreter::Interpreter(ErrorReporter* error_reporter) context_.SetExternalContext = SetExternalContext; // Invalid to call these these except from TfLiteDelegate - SetForbiddenContextFunction(&context_.GetNodeAndRegistration); - SetForbiddenContextFunction(&context_.ReplaceSubgraphsWithDelegateKernels); - SetForbiddenContextFunction(&context_.GetExecutionPlan); + SwitchToKernelContext(); // Reserve some space for the tensors to avoid excessive resizing. tensors_.reserve(kTensorsReservedCapacity); @@ -924,6 +925,19 @@ void Interpreter::SetNumThreads(int num_threads) { } } +void Interpreter::SwitchToDelegateContext() { + context_.GetNodeAndRegistration = GetNodeAndRegistration; + context_.ReplaceSubgraphsWithDelegateKernels = + ReplaceSubgraphsWithDelegateKernels; + context_.GetExecutionPlan = GetExecutionPlan; +} + +void Interpreter::SwitchToKernelContext() { + SetForbiddenContextFunction(&context_.GetNodeAndRegistration); + SetForbiddenContextFunction(&context_.ReplaceSubgraphsWithDelegateKernels); + SetForbiddenContextFunction(&context_.GetExecutionPlan); +} + TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate, bool allow_dynamic_tensors) { if (!allow_dynamic_tensors) { @@ -950,17 +964,12 @@ TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate, // TODO(aselle): Consider if it is worth storing pointers to delegates. // Setup additional context interface. - context_.GetNodeAndRegistration = GetNodeAndRegistration; - context_.ReplaceSubgraphsWithDelegateKernels = - ReplaceSubgraphsWithDelegateKernels; - context_.GetExecutionPlan = GetExecutionPlan; + SwitchToDelegateContext(); TfLiteStatus status = delegate->Prepare(&context_, delegate); // Remove additional context info. - SetForbiddenContextFunction(&context_.GetNodeAndRegistration); - SetForbiddenContextFunction(&context_.ReplaceSubgraphsWithDelegateKernels); - SetForbiddenContextFunction(&context_.GetExecutionPlan); + SwitchToKernelContext(); TF_LITE_ENSURE_OK(&context_, status); |