aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/interpreter.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/interpreter.h')
-rw-r--r--tensorflow/contrib/lite/interpreter.h33
1 files changed, 31 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
index 033b8ee5fa..be149a8cc0 100644
--- a/tensorflow/contrib/lite/interpreter.h
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -63,6 +63,10 @@ template <>
constexpr TfLiteType typeToTfLiteType<std::complex<float>>() {
return kTfLiteComplex64;
}
+template <>
+constexpr TfLiteType typeToTfLiteType<string>() {
+ return kTfLiteString;
+}
// Forward declare since NNAPIDelegate uses Interpreter.
class NNAPIDelegate;
@@ -107,7 +111,7 @@ class Interpreter {
// processing this model will be forwarded to the error_reporter object.
//
// Note, if error_reporter is nullptr, then a default StderrReporter is
- // used.
+ // used. Ownership of 'error_reporter' remains with the caller.
explicit Interpreter(ErrorReporter* error_reporter = DefaultErrorReporter());
~Interpreter();
@@ -410,6 +414,15 @@ class Interpreter {
}
private:
+ friend class InterpreterTest;
+
+ // Prevent 'context_' from accessing functions that are only available to
+ // delegated kernels.
+ void SwitchToKernelContext();
+
+ // Add delegate-only functions to 'context_'.
+ void SwitchToDelegateContext();
+
// Give 'op_reg' a chance to initialize itself using the contents of
// 'buffer'.
void* OpInit(const TfLiteRegistration& op_reg, const char* buffer,
@@ -496,6 +509,7 @@ class Interpreter {
// Update the execution graph to replace some of the nodes with stub
// nodes. Specifically any node index that has `nodes[index]==1` will be
// slated for replacement with a delegate kernel specified by registration.
+ // Ownership of 'nodes_to_replace' and 'delegate' remains with the caller.
// WARNING: This is an experimental interface that is subject to change.
TfLiteStatus ReplaceSubgraphsWithDelegateKernels(
TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace,
@@ -522,6 +536,18 @@ class Interpreter {
static TfLiteStatus GetExecutionPlan(struct TfLiteContext* context,
TfLiteIntArray** execution_plan);
+ // Retrieve an existing external context by type.
+ TfLiteExternalContext* GetExternalContext(TfLiteExternalContextType type);
+ static TfLiteExternalContext* GetExternalContext(
+ struct TfLiteContext* context, TfLiteExternalContextType type);
+
+ // Set the value of an external context.
+ void SetExternalContext(TfLiteExternalContextType type,
+ TfLiteExternalContext* ctx);
+ static void SetExternalContext(struct TfLiteContext* context,
+ TfLiteExternalContextType type,
+ TfLiteExternalContext* ctx);
+
// Ensures that `tensors_` has at least `kTensorsCapacityHeadroom` extra
// capacity. Calling this function may invalidate existing pointers to
// tensors. After calling this function, adding `kTensorsCapacityHeadroom`
@@ -611,7 +637,10 @@ class Interpreter {
bool tensor_resized_since_op_invoke_ = false;
// Profiler for this interpreter instance.
- profiling::Profiler* profiler_;
+ profiling::Profiler* profiler_ = nullptr;
+
+ // List of active external contexts.
+ TfLiteExternalContext* external_contexts_[kTfLiteMaxExternalContexts];
};
} // namespace tflite