diff options
author | 2018-07-06 14:50:08 -0700 | |
---|---|---|
committer | 2018-07-07 21:07:31 -0700 | |
commit | 2bfc7957dada57c1eb8491e04dac70d16b4369ac (patch) | |
tree | 2a967b9e5efbdb62f3e782fc9c4d8471fd78adc6 /tensorflow/contrib/lite/kernels/eigen_support.cc | |
parent | 7c7cfde8fd6aa4ad7160b1fe6380e6007613d0d0 (diff) |
Make multithreaded conv respect setNumThreads()
PiperOrigin-RevId: 203527657
Diffstat (limited to 'tensorflow/contrib/lite/kernels/eigen_support.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/eigen_support.cc | 54 |
1 file changed, 53 insertions, 1 deletion
```diff
diff --git a/tensorflow/contrib/lite/kernels/eigen_support.cc b/tensorflow/contrib/lite/kernels/eigen_support.cc
index 94927cb53d..4f0d020793 100644
--- a/tensorflow/contrib/lite/kernels/eigen_support.cc
+++ b/tensorflow/contrib/lite/kernels/eigen_support.cc
@@ -14,14 +14,38 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/contrib/lite/kernels/eigen_support.h"
 
-#include "third_party/eigen3/Eigen/Core"
+#include <utility>
+
+#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h"
 #include "tensorflow/contrib/lite/kernels/op_macros.h"
 
 namespace tflite {
 namespace eigen_support {
 namespace {
 
+// We have a single global threadpool for all convolution operations. This means
+// that inferences started from different threads may block each other, but
+// since the underlying resource of CPU cores should be consumed by the
+// operations anyway, it shouldn't affect overall performance.
+class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
+ public:
+  // Takes ownership of 'pool'
+  explicit EigenThreadPoolWrapper(Eigen::ThreadPool* pool) : pool_(pool) {}
+  ~EigenThreadPoolWrapper() override {}
+
+  void Schedule(std::function<void()> fn) override {
+    pool_->Schedule(std::move(fn));
+  }
+  int NumThreads() const override { return pool_->NumThreads(); }
+  int CurrentThreadId() const override { return pool_->CurrentThreadId(); }
+
+ private:
+  std::unique_ptr<Eigen::ThreadPool> pool_;
+};
+
 struct RefCountedEigenContext : public TfLiteExternalContext {
+  std::unique_ptr<Eigen::ThreadPoolInterface> thread_pool_wrapper;
+  std::unique_ptr<Eigen::ThreadPoolDevice> device;
   int num_references = 0;
 };
 
@@ -30,8 +54,26 @@ RefCountedEigenContext* GetEigenContext(TfLiteContext* context) {
       context->GetExternalContext(context, kTfLiteEigenContext));
 }
 
+void InitDevice(TfLiteContext* context, RefCountedEigenContext* ptr) {
+  int num_threads = 4;
+  if (context->recommended_num_threads != -1) {
+    num_threads = context->recommended_num_threads;
+  }
+  ptr->device.reset();  // destroy before we invalidate the thread pool
+  ptr->thread_pool_wrapper.reset(
+      new EigenThreadPoolWrapper(new Eigen::ThreadPool(num_threads)));
+  ptr->device.reset(
+      new Eigen::ThreadPoolDevice(ptr->thread_pool_wrapper.get(), num_threads));
+}
+
 TfLiteStatus Refresh(TfLiteContext* context) {
   Eigen::setNbThreads(context->recommended_num_threads);
+
+  auto* ptr = GetEigenContext(context);
+  if (ptr != nullptr) {
+    InitDevice(context, ptr);
+  }
+
   return kTfLiteOk;
 }
 
@@ -47,6 +89,7 @@ void IncrementUsageCounter(TfLiteContext* context) {
     ptr->type = kTfLiteEigenContext;
     ptr->Refresh = Refresh;
     ptr->num_references = 0;
+    InitDevice(context, ptr);
     context->SetExternalContext(context, kTfLiteEigenContext, ptr);
   }
   ptr->num_references++;
@@ -65,5 +108,14 @@ void DecrementUsageCounter(TfLiteContext* context) {
   }
 }
 
+const Eigen::ThreadPoolDevice* GetThreadPoolDevice(TfLiteContext* context) {
+  auto* ptr = GetEigenContext(context);
+  if (ptr == nullptr) {
+    TF_LITE_FATAL(
+        "Call to GetFromContext() not preceded by IncrementUsageCounter()");
+  }
+  return ptr->device.get();
+}
+
 }  // namespace eigen_support
 }  // namespace tflite
```