From 2bfc7957dada57c1eb8491e04dac70d16b4369ac Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 6 Jul 2018 14:50:08 -0700 Subject: Make multithreaded conv respect setNumThreads() PiperOrigin-RevId: 203527657 --- tensorflow/contrib/lite/kernels/eigen_support.cc | 54 +++++++++++++++++++++++- 1 file changed, 53 insertions(+), 1 deletion(-) (limited to 'tensorflow/contrib/lite/kernels/eigen_support.cc') diff --git a/tensorflow/contrib/lite/kernels/eigen_support.cc b/tensorflow/contrib/lite/kernels/eigen_support.cc index 94927cb53d..4f0d020793 100644 --- a/tensorflow/contrib/lite/kernels/eigen_support.cc +++ b/tensorflow/contrib/lite/kernels/eigen_support.cc @@ -14,14 +14,38 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/kernels/eigen_support.h" -#include "third_party/eigen3/Eigen/Core" +#include + +#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" namespace tflite { namespace eigen_support { namespace { +// We have a single global threadpool for all convolution operations. This means +// that inferences started from different threads may block each other, but +// since the underlying resource of CPU cores should be consumed by the +// operations anyway, it shouldn't affect overall performance. 
+class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
+ public:
+  // Takes ownership of 'pool'
+  explicit EigenThreadPoolWrapper(Eigen::ThreadPool* pool) : pool_(pool) {}
+  ~EigenThreadPoolWrapper() override {}
+
+  void Schedule(std::function<void()> fn) override {
+    pool_->Schedule(std::move(fn));  // forwards the task to the owned pool
+  }
+  int NumThreads() const override { return pool_->NumThreads(); }
+  int CurrentThreadId() const override { return pool_->CurrentThreadId(); }
+
+ private:
+  std::unique_ptr<Eigen::ThreadPool> pool_;  // owned; see constructor
+};
+
 struct RefCountedEigenContext : public TfLiteExternalContext {
+  std::unique_ptr<EigenThreadPoolWrapper> thread_pool_wrapper;
+  std::unique_ptr<Eigen::ThreadPoolDevice> device;  // destroyed before wrapper
   int num_references = 0;
 };
 
@@ -30,8 +54,26 @@ RefCountedEigenContext* GetEigenContext(TfLiteContext* context) {
       context->GetExternalContext(context, kTfLiteEigenContext));
 }
 
+void InitDevice(TfLiteContext* context, RefCountedEigenContext* ptr) {
+  int num_threads = 4;  // default when no thread count is recommended
+  if (context->recommended_num_threads != -1) {
+    num_threads = context->recommended_num_threads;
+  }
+  ptr->device.reset();  // destroy before we invalidate the thread pool
+  ptr->thread_pool_wrapper.reset(
+      new EigenThreadPoolWrapper(new Eigen::ThreadPool(num_threads)));
+  ptr->device.reset(
+      new Eigen::ThreadPoolDevice(ptr->thread_pool_wrapper.get(), num_threads));
+}
+
 TfLiteStatus Refresh(TfLiteContext* context) {
   Eigen::setNbThreads(context->recommended_num_threads);
+
+  auto* ptr = GetEigenContext(context);
+  if (ptr != nullptr) {
+    InitDevice(context, ptr);  // rebuild pool and device for the new setting
+  }
+
   return kTfLiteOk;
 }
 
@@ -47,6 +89,7 @@ void IncrementUsageCounter(TfLiteContext* context) {
     ptr->type = kTfLiteEigenContext;
     ptr->Refresh = Refresh;
     ptr->num_references = 0;
+    InitDevice(context, ptr);  // create pool and device before publishing ptr
     context->SetExternalContext(context, kTfLiteEigenContext, ptr);
   }
   ptr->num_references++;
@@ -65,5 +108,14 @@ void DecrementUsageCounter(TfLiteContext* context) {
   }
 }
 
+const Eigen::ThreadPoolDevice* GetThreadPoolDevice(TfLiteContext* context) {
+  auto* ptr = GetEigenContext(context);
+  if (ptr == nullptr) {
+    TF_LITE_FATAL(
+        "Call to GetThreadPoolDevice() not preceded by IncrementUsageCounter()");
+  }
+  return ptr->device.get();  // still owned by the context; do not delete
+}
+
 }  // namespace eigen_support
 }  // namespace tflite
-- 
cgit v1.2.3