diff options
Diffstat (limited to 'tensorflow/contrib/lite/interpreter.h')
-rw-r--r-- | tensorflow/contrib/lite/interpreter.h | 33 |
1 files changed, 31 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index 033b8ee5fa..be149a8cc0 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -63,6 +63,10 @@ template <> constexpr TfLiteType typeToTfLiteType<std::complex<float>>() { return kTfLiteComplex64; } +template <> +constexpr TfLiteType typeToTfLiteType<string>() { + return kTfLiteString; +} // Forward declare since NNAPIDelegate uses Interpreter. class NNAPIDelegate; @@ -107,7 +111,7 @@ class Interpreter { // processing this model will be forwarded to the error_reporter object. // // Note, if error_reporter is nullptr, then a default StderrReporter is - // used. + // used. Ownership of 'error_reporter' remains with the caller. explicit Interpreter(ErrorReporter* error_reporter = DefaultErrorReporter()); ~Interpreter(); @@ -410,6 +414,15 @@ class Interpreter { } private: + friend class InterpreterTest; + + // Prevent 'context_' from accessing functions that are only available to + // delegated kernels. + void SwitchToKernelContext(); + + // Add delegate-only functions to 'context_'. + void SwitchToDelegateContext(); + // Give 'op_reg' a chance to initialize itself using the contents of // 'buffer'. void* OpInit(const TfLiteRegistration& op_reg, const char* buffer, @@ -496,6 +509,7 @@ class Interpreter { // Update the execution graph to replace some of the nodes with stub // nodes. Specifically any node index that has `nodes[index]==1` will be // slated for replacement with a delegate kernel specified by registration. + // Ownership of 'nodes_to_replace' and 'delegate' remains with the caller. // WARNING: This is an experimental interface that is subject to change. TfLiteStatus ReplaceSubgraphsWithDelegateKernels( TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace, @@ -522,6 +536,18 @@ class Interpreter { static TfLiteStatus GetExecutionPlan(struct TfLiteContext* context, TfLiteIntArray** execution_plan); + // Retrieve an existing external context by type. + TfLiteExternalContext* GetExternalContext(TfLiteExternalContextType type); + static TfLiteExternalContext* GetExternalContext( + struct TfLiteContext* context, TfLiteExternalContextType type); + + // Set the value of an external context. + void SetExternalContext(TfLiteExternalContextType type, + TfLiteExternalContext* ctx); + static void SetExternalContext(struct TfLiteContext* context, + TfLiteExternalContextType type, + TfLiteExternalContext* ctx); + // Ensures that `tensors_` has at least `kTensorsCapacityHeadroom` extra // capacity. Calling this function may invalidate existing pointers to // tensors. After calling this function, adding `kTensorsCapacityHeadroom` @@ -611,7 +637,10 @@ class Interpreter { bool tensor_resized_since_op_invoke_ = false; // Profiler for this interpreter instance. - profiling::Profiler* profiler_; + profiling::Profiler* profiler_ = nullptr; + + // List of active external contexts. + TfLiteExternalContext* external_contexts_[kTfLiteMaxExternalContexts]; }; } // namespace tflite |