diff options
Diffstat (limited to 'tensorflow/contrib/lite/kernels/eigen_support.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/eigen_support.cc | 92 |
1 files changed, 82 insertions, 10 deletions
diff --git a/tensorflow/contrib/lite/kernels/eigen_support.cc b/tensorflow/contrib/lite/kernels/eigen_support.cc index f1fdb42624..e542ad0765 100644 --- a/tensorflow/contrib/lite/kernels/eigen_support.cc +++ b/tensorflow/contrib/lite/kernels/eigen_support.cc @@ -14,31 +14,100 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/kernels/eigen_support.h" -#include "third_party/eigen3/Eigen/Core" +#include <utility> + +#include "tensorflow/contrib/lite/arena_planner.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" namespace tflite { namespace eigen_support { +namespace { + +#ifndef EIGEN_DONT_ALIGN +// Eigen may require buffers to be aligned to 16, 32 or 64 bytes depending on +// hardware architecture and build configurations. +// If the static assertion fails, try to increase `kDefaultTensorAlignment` +// in `arena_planner.h` to 32 or 64. +static_assert( + kDefaultTensorAlignment % EIGEN_MAX_ALIGN_BYTES == 0, + "kDefaultArenaAlignment doesn't comply with Eigen alignment requirement."); +#endif // EIGEN_DONT_ALIGN + +// We have a single global threadpool for all convolution operations. This means +// that inferences started from different threads may block each other, but +// since the underlying resource of CPU cores should be consumed by the +// operations anyway, it shouldn't affect overall performance. 
+class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface { + public: + // Takes ownership of 'pool' + explicit EigenThreadPoolWrapper(Eigen::ThreadPool* pool) : pool_(pool) {} + ~EigenThreadPoolWrapper() override {} -struct RefCountedEigenContext { + void Schedule(std::function<void()> fn) override { + pool_->Schedule(std::move(fn)); + } + int NumThreads() const override { return pool_->NumThreads(); } + int CurrentThreadId() const override { return pool_->CurrentThreadId(); } + + private: + std::unique_ptr<Eigen::ThreadPool> pool_; +}; + +struct RefCountedEigenContext : public TfLiteExternalContext { + std::unique_ptr<Eigen::ThreadPoolInterface> thread_pool_wrapper; + std::unique_ptr<Eigen::ThreadPoolDevice> device; int num_references = 0; }; +RefCountedEigenContext* GetEigenContext(TfLiteContext* context) { + return reinterpret_cast<RefCountedEigenContext*>( + context->GetExternalContext(context, kTfLiteEigenContext)); +} + +void InitDevice(TfLiteContext* context, RefCountedEigenContext* ptr) { + int num_threads = 4; + if (context->recommended_num_threads != -1) { + num_threads = context->recommended_num_threads; + } + ptr->device.reset(); // destroy before we invalidate the thread pool + ptr->thread_pool_wrapper.reset( + new EigenThreadPoolWrapper(new Eigen::ThreadPool(num_threads))); + ptr->device.reset( + new Eigen::ThreadPoolDevice(ptr->thread_pool_wrapper.get(), num_threads)); +} + +TfLiteStatus Refresh(TfLiteContext* context) { + Eigen::setNbThreads(context->recommended_num_threads); + + auto* ptr = GetEigenContext(context); + if (ptr != nullptr) { + InitDevice(context, ptr); + } + + return kTfLiteOk; +} + +} // namespace + void IncrementUsageCounter(TfLiteContext* context) { - auto* ptr = reinterpret_cast<RefCountedEigenContext*>(context->eigen_context); + auto* ptr = GetEigenContext(context); if (ptr == nullptr) { if (context->recommended_num_threads != -1) { Eigen::setNbThreads(context->recommended_num_threads); } ptr = new 
RefCountedEigenContext; + ptr->type = kTfLiteEigenContext; + ptr->Refresh = Refresh; ptr->num_references = 0; - context->eigen_context = ptr; + InitDevice(context, ptr); + context->SetExternalContext(context, kTfLiteEigenContext, ptr); } ptr->num_references++; } void DecrementUsageCounter(TfLiteContext* context) { - auto* ptr = reinterpret_cast<RefCountedEigenContext*>(context->eigen_context); + auto* ptr = GetEigenContext(context); if (ptr == nullptr) { TF_LITE_FATAL( "Call to DecrementUsageCounter() not preceded by " @@ -46,14 +115,17 @@ void DecrementUsageCounter(TfLiteContext* context) { } if (--ptr->num_references == 0) { delete ptr; - context->eigen_context = nullptr; + context->SetExternalContext(context, kTfLiteEigenContext, nullptr); } } -void SetNumThreads(TfLiteContext* context, int num_threads) { - IncrementUsageCounter(context); - Eigen::setNbThreads(num_threads); - DecrementUsageCounter(context); +const Eigen::ThreadPoolDevice* GetThreadPoolDevice(TfLiteContext* context) { + auto* ptr = GetEigenContext(context); + if (ptr == nullptr) { + TF_LITE_FATAL( + "Call to GetFromContext() not preceded by IncrementUsageCounter()"); + } + return ptr->device.get(); } } // namespace eigen_support |