diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-07-03 13:17:19 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-03 13:20:38 -0700 |
commit | 3340c2da43f8b2313692aaad1a94da6c4a4e4106 (patch) | |
tree | 24f1f45d83c771c2324dec4f98b022faaca29849 /tensorflow | |
parent | 02ed358a986496e387d5f2e52865b10606e52c0a (diff) |
Remove framework's dependency on eigen and gemmlowp.
PiperOrigin-RevId: 203172717
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/contrib/lite/context.h | 33 | ||||
-rw-r--r-- | tensorflow/contrib/lite/interpreter.cc | 52 | ||||
-rw-r--r-- | tensorflow/contrib/lite/interpreter.h | 17 | ||||
-rw-r--r-- | tensorflow/contrib/lite/interpreter_test.cc | 50 | ||||
-rw-r--r-- | tensorflow/contrib/lite/kernels/eigen_support.cc | 31 | ||||
-rw-r--r-- | tensorflow/contrib/lite/kernels/eigen_support.h | 3 | ||||
-rw-r--r-- | tensorflow/contrib/lite/kernels/gemm_support.cc | 55 | ||||
-rw-r--r-- | tensorflow/contrib/lite/kernels/gemm_support.h | 3 |
8 files changed, 190 insertions, 54 deletions
diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h index 4f260ad40a..1ff8843fa7 100644 --- a/tensorflow/contrib/lite/context.h +++ b/tensorflow/contrib/lite/context.h @@ -39,6 +39,26 @@ extern "C" { typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus; +// The list of external context types known to TF Lite. This list exists solely +// to avoid conflicts and to ensure ops can share the external contexts they +// need. Access to the external contexts is controled by one of the +// corresponding support files. +typedef enum { + kTfLiteEigenContext = 0, // include eigen_support.h to use. + kTfLiteGemmLowpContext = 1, // include gemm_support.h to use. + kTfLiteMaxExternalContexts = 2 +} TfLiteExternalContextType; + +// An external context is a collection of information unrelated to the TF Lite +// framework, but useful to a subset of the ops. TF Lite knows very little +// about about the actual contexts, but it keeps a list of them, and is able to +// refresh them if configurations like the number of recommended threads +// change. +typedef struct { + TfLiteExternalContextType type; + TfLiteStatus (*Refresh)(struct TfLiteContext* context); +} TfLiteExternalContext; + // Forward declare so GetNode can use this is in Context. typedef struct _TfLiteRegistration TfLiteRegistration; typedef struct _TfLiteDelegate TfLiteDelegate; @@ -339,10 +359,15 @@ typedef struct TfLiteContext { // eigen. int recommended_num_threads; - // TODO(ahentz): we should create a more general mechanism for this sort of - // library-global objects. - void* gemm_context; - void* eigen_context; + // Access external contexts by type. + // WARNING: This is an experimental interface that is subject to change. + TfLiteExternalContext* (*GetExternalContext)(struct TfLiteContext*, + TfLiteExternalContextType); + // Set the value of a external context. Does not take ownership of the + // pointer. + // WARNING: This is an experimental interface that is subject to change. + void (*SetExternalContext)(struct TfLiteContext*, TfLiteExternalContextType, + TfLiteExternalContext*); } TfLiteContext; typedef struct _TfLiteRegistration { diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index 3089a4c568..521216a4f1 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -25,10 +25,6 @@ limitations under the License. #include "tensorflow/contrib/lite/context_util.h" #include "tensorflow/contrib/lite/error_reporter.h" #include "tensorflow/contrib/lite/graph_info.h" -#ifndef TFLITE_MCU -#include "tensorflow/contrib/lite/kernels/eigen_support.h" -#include "tensorflow/contrib/lite/kernels/gemm_support.h" -#endif #include "tensorflow/contrib/lite/memory_planner.h" #ifndef TFLITE_MCU #include "tensorflow/contrib/lite/nnapi_delegate.h" @@ -120,9 +116,9 @@ Interpreter::Interpreter(ErrorReporter* error_reporter) context_.AddTensors = AddTensors; context_.tensors = nullptr; context_.tensors_size = 0; - context_.eigen_context = nullptr; - context_.gemm_context = nullptr; context_.recommended_num_threads = -1; + context_.GetExternalContext = GetExternalContext; + context_.SetExternalContext = SetExternalContext; // Invalid to call these these except from TfLiteDelegate SetForbiddenContextFunction(&context_.GetNodeAndRegistration); @@ -133,6 +129,11 @@ Interpreter::Interpreter(ErrorReporter* error_reporter) tensors_.reserve(kTensorsReservedCapacity); nodes_and_registration_.reserve(kTensorsReservedCapacity); next_execution_plan_index_to_prepare_ = 0; + + for (int i = 0; i < kTfLiteMaxExternalContexts; ++i) { + external_contexts_[i] = nullptr; + } + UseNNAPI(false); } @@ -290,6 +291,33 @@ TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels( return kTfLiteOk; } +TfLiteExternalContext* Interpreter::GetExternalContext( + TfLiteExternalContextType type) { + if (type >= 0 && type < kTfLiteMaxExternalContexts) { + return external_contexts_[type]; + } + return nullptr; +} + +TfLiteExternalContext* Interpreter::GetExternalContext( + struct TfLiteContext* context, TfLiteExternalContextType type) { + return static_cast<Interpreter*>(context->impl_)->GetExternalContext(type); +} + +void Interpreter::SetExternalContext(TfLiteExternalContextType type, + TfLiteExternalContext* ctx) { + if (type >= 0 && type < kTfLiteMaxExternalContexts) { + external_contexts_[type] = ctx; + } +} + +void Interpreter::SetExternalContext(struct TfLiteContext* context, + TfLiteExternalContextType type, + TfLiteExternalContext* ctx) { + return static_cast<Interpreter*>(context->impl_) + ->SetExternalContext(type, ctx); +} + // Gets an TfLiteIntArray* representing the execution plan. The interpreter owns // this memory and it is only guaranteed to exist during the invocation of the // delegate prepare. @@ -869,12 +897,12 @@ void Interpreter::UseNNAPI(bool enable) { void Interpreter::SetNumThreads(int num_threads) { context_.recommended_num_threads = num_threads; - // TODO(ahentz): find a way to avoid this. It causes gemmlowp and eigen to - // be required in order to compile the framework. -#ifndef TFLITE_MCU - gemm_support::SetNumThreads(&context_, num_threads); - eigen_support::SetNumThreads(&context_, num_threads); -#endif + for (int i = 0; i < kTfLiteMaxExternalContexts; ++i) { + auto* c = external_contexts_[i]; + if (c && c->Refresh) { + c->Refresh(&context_); + } + } } TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate, diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index 033b8ee5fa..b69c50fbfc 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -410,6 +410,8 @@ class Interpreter { } private: + friend class InterpreterTest; + // Give 'op_reg' a chance to initialize itself using the contents of // 'buffer'. void* OpInit(const TfLiteRegistration& op_reg, const char* buffer, @@ -522,6 +524,18 @@ class Interpreter { static TfLiteStatus GetExecutionPlan(struct TfLiteContext* context, TfLiteIntArray** execution_plan); + // Retrieve an existing external context by type. + TfLiteExternalContext* GetExternalContext(TfLiteExternalContextType type); + static TfLiteExternalContext* GetExternalContext( + struct TfLiteContext* context, TfLiteExternalContextType type); + + // Set the value of an external context. + void SetExternalContext(TfLiteExternalContextType type, + TfLiteExternalContext* ctx); + static void SetExternalContext(struct TfLiteContext* context, + TfLiteExternalContextType type, + TfLiteExternalContext* ctx); + // Ensures that `tensors_` has at least `kTensorsCapacityHeadroom` extra // capacity. Calling this function may invalidate existing pointers to // tensors. After calling this function, adding `kTensorsCapacityHeadroom` @@ -612,6 +626,9 @@ class Interpreter { // Profiler for this interpreter instance. profiling::Profiler* profiler_; + + // List of active external contexts. + TfLiteExternalContext* external_contexts_[kTfLiteMaxExternalContexts]; }; } // namespace tflite diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc index 4f7fb36696..4fa97512fc 100644 --- a/tensorflow/contrib/lite/interpreter_test.cc +++ b/tensorflow/contrib/lite/interpreter_test.cc @@ -23,6 +23,15 @@ limitations under the License. #include "tensorflow/contrib/lite/testing/util.h" namespace tflite { + +// InterpreterTest is a friend of Interpreter, so it can access context_. +class InterpreterTest : public ::testing::Test { + protected: + TfLiteContext* GetInterpreterContext() { return &interpreter_.context_; } + + Interpreter interpreter_; +}; + namespace ops { namespace builtin { TfLiteRegistration* Register_PADV2(); @@ -780,6 +789,47 @@ TEST(InterpreterTensorsCapacityTest, TestExceedHeadroom) { ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); } +struct TestExternalContext : public TfLiteExternalContext { + static const TfLiteExternalContextType kType = kTfLiteGemmLowpContext; + + static TestExternalContext* Get(TfLiteContext* context) { + return reinterpret_cast<TestExternalContext*>( + context->GetExternalContext(context, kType)); + } + + static void Set(TfLiteContext* context, TestExternalContext* value) { + context->SetExternalContext(context, kType, value); + } + + int num_refreshes = 0; +}; + +TEST_F(InterpreterTest, GetSetResetExternalContexts) { + auto* context = GetInterpreterContext(); + + TestExternalContext external_context; + external_context.Refresh = [](TfLiteContext* context) { + auto* ptr = TestExternalContext::Get(context); + if (ptr != nullptr) { + ++ptr->num_refreshes; + } + return kTfLiteOk; + }; + + EXPECT_EQ(TestExternalContext::Get(context), nullptr); + interpreter_.SetNumThreads(4); + + TestExternalContext::Set(context, &external_context); + EXPECT_EQ(TestExternalContext::Get(context), &external_context); + interpreter_.SetNumThreads(4); + interpreter_.SetNumThreads(5); + EXPECT_EQ(external_context.num_refreshes, 2); + + TestExternalContext::Set(context, nullptr); + EXPECT_EQ(TestExternalContext::Get(context), nullptr); + interpreter_.SetNumThreads(4); +} + // Test fixture that allows playing with execution plans. It creates a two // node graph that can be executed in either [0,1] order or [1,0] order. // The CopyOp records when it is invoked in the class member run_order_ diff --git a/tensorflow/contrib/lite/kernels/eigen_support.cc b/tensorflow/contrib/lite/kernels/eigen_support.cc index f1fdb42624..94927cb53d 100644 --- a/tensorflow/contrib/lite/kernels/eigen_support.cc +++ b/tensorflow/contrib/lite/kernels/eigen_support.cc @@ -19,26 +19,41 @@ limitations under the License. namespace tflite { namespace eigen_support { +namespace { -struct RefCountedEigenContext { +struct RefCountedEigenContext : public TfLiteExternalContext { int num_references = 0; }; +RefCountedEigenContext* GetEigenContext(TfLiteContext* context) { + return reinterpret_cast<RefCountedEigenContext*>( + context->GetExternalContext(context, kTfLiteEigenContext)); +} + +TfLiteStatus Refresh(TfLiteContext* context) { + Eigen::setNbThreads(context->recommended_num_threads); + return kTfLiteOk; +} + +} // namespace + void IncrementUsageCounter(TfLiteContext* context) { - auto* ptr = reinterpret_cast<RefCountedEigenContext*>(context->eigen_context); + auto* ptr = GetEigenContext(context); if (ptr == nullptr) { if (context->recommended_num_threads != -1) { Eigen::setNbThreads(context->recommended_num_threads); } ptr = new RefCountedEigenContext; + ptr->type = kTfLiteEigenContext; + ptr->Refresh = Refresh; ptr->num_references = 0; - context->eigen_context = ptr; + context->SetExternalContext(context, kTfLiteEigenContext, ptr); } ptr->num_references++; } void DecrementUsageCounter(TfLiteContext* context) { - auto* ptr = reinterpret_cast<RefCountedEigenContext*>(context->eigen_context); + auto* ptr = GetEigenContext(context); if (ptr == nullptr) { TF_LITE_FATAL( "Call to DecrementUsageCounter() not preceded by " @@ -46,15 +61,9 @@ void DecrementUsageCounter(TfLiteContext* context) { } if (--ptr->num_references == 0) { delete ptr; - context->eigen_context = nullptr; + context->SetExternalContext(context, kTfLiteEigenContext, nullptr); } } -void SetNumThreads(TfLiteContext* context, int num_threads) { - IncrementUsageCounter(context); - Eigen::setNbThreads(num_threads); - DecrementUsageCounter(context); -} - } // namespace eigen_support } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/eigen_support.h b/tensorflow/contrib/lite/kernels/eigen_support.h index aa8c351fd8..d47e691123 100644 --- a/tensorflow/contrib/lite/kernels/eigen_support.h +++ b/tensorflow/contrib/lite/kernels/eigen_support.h @@ -28,9 +28,6 @@ void IncrementUsageCounter(TfLiteContext* context); // usages all temporary Eigen objects will be deleted. void DecrementUsageCounter(TfLiteContext* context); -// Set the number of threads that can be used by Eigen. -void SetNumThreads(TfLiteContext* context, int num_threads); - } // namespace eigen_support } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/gemm_support.cc b/tensorflow/contrib/lite/kernels/gemm_support.cc index 95f45ea768..ed334af2da 100644 --- a/tensorflow/contrib/lite/kernels/gemm_support.cc +++ b/tensorflow/contrib/lite/kernels/gemm_support.cc @@ -14,57 +14,70 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/kernels/gemm_support.h" +#include <memory> + #include "tensorflow/contrib/lite/kernels/op_macros.h" namespace tflite { namespace gemm_support { +namespace { -struct RefCountedGemmContext { - gemmlowp::GemmContext* gemm_context_ = nullptr; - int num_references_ = 0; +struct RefCountedGemmContext : public TfLiteExternalContext { + std::unique_ptr<gemmlowp::GemmContext> gemm_context; + int num_references = 0; }; +RefCountedGemmContext* GetGemmLowpContext(TfLiteContext* context) { + return reinterpret_cast<RefCountedGemmContext*>( + context->GetExternalContext(context, kTfLiteGemmLowpContext)); +} + +TfLiteStatus Refresh(TfLiteContext* context) { + auto* ptr = GetGemmLowpContext(context); + if (ptr != nullptr) { + ptr->gemm_context->set_max_num_threads(context->recommended_num_threads); + } + return kTfLiteOk; +} + +} // namespace + void IncrementUsageCounter(TfLiteContext* context) { - auto* ptr = reinterpret_cast<RefCountedGemmContext*>(context->gemm_context); + auto* ptr = GetGemmLowpContext(context); if (ptr == nullptr) { ptr = new RefCountedGemmContext; - ptr->gemm_context_ = new gemmlowp::GemmContext(); + ptr->type = kTfLiteGemmLowpContext; + ptr->Refresh = Refresh; + ptr->gemm_context.reset(new gemmlowp::GemmContext()); if (context->recommended_num_threads != -1) { - ptr->gemm_context_->set_max_num_threads(context->recommended_num_threads); + ptr->gemm_context->set_max_num_threads(context->recommended_num_threads); } - ptr->num_references_ = 0; - context->gemm_context = ptr; + ptr->num_references = 0; + context->SetExternalContext(context, kTfLiteGemmLowpContext, ptr); } - ptr->num_references_++; + ptr->num_references++; } void DecrementUsageCounter(TfLiteContext* context) { - auto* ptr = reinterpret_cast<RefCountedGemmContext*>(context->gemm_context); + auto* ptr = GetGemmLowpContext(context); if (ptr == nullptr) { TF_LITE_FATAL( "Call to DecrementUsageCounter() not preceded by " "IncrementUsageCounter()"); } - if (--ptr->num_references_ == 0) { - delete ptr->gemm_context_; + if (--ptr->num_references == 0) { delete ptr; - context->gemm_context = nullptr; + context->SetExternalContext(context, kTfLiteGemmLowpContext, nullptr); } } gemmlowp::GemmContext* GetFromContext(TfLiteContext* context) { - auto* ptr = reinterpret_cast<RefCountedGemmContext*>(context->gemm_context); + auto* ptr = GetGemmLowpContext(context); if (ptr == nullptr) { TF_LITE_FATAL( "Call to GetFromContext() not preceded by IncrementUsageCounter()"); } - return ptr->gemm_context_; -} - -void SetNumThreads(TfLiteContext* context, int num_threads) { - IncrementUsageCounter(context); - GetFromContext(context)->set_max_num_threads(num_threads); - DecrementUsageCounter(context); + return ptr->gemm_context.get(); } } // namespace gemm_support diff --git a/tensorflow/contrib/lite/kernels/gemm_support.h b/tensorflow/contrib/lite/kernels/gemm_support.h index f033501cb6..37af772c68 100644 --- a/tensorflow/contrib/lite/kernels/gemm_support.h +++ b/tensorflow/contrib/lite/kernels/gemm_support.h @@ -45,9 +45,6 @@ void IncrementUsageCounter(TfLiteContext* context); // 'context'. If there are no more usages the GemmContext will be deleted. void DecrementUsageCounter(TfLiteContext* context); -// Set the number of threads that can be used by gemmlowp. -void SetNumThreads(TfLiteContext* context, int num_threads); - } // namespace gemm_support } // namespace tflite |