aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/eigen_support.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-06 14:50:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-07 21:07:31 -0700
commit2bfc7957dada57c1eb8491e04dac70d16b4369ac (patch)
tree2a967b9e5efbdb62f3e782fc9c4d8471fd78adc6 /tensorflow/contrib/lite/kernels/eigen_support.cc
parent7c7cfde8fd6aa4ad7160b1fe6380e6007613d0d0 (diff)
Make multithreaded conv respect setNumThreads()
PiperOrigin-RevId: 203527657
Diffstat (limited to 'tensorflow/contrib/lite/kernels/eigen_support.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/eigen_support.cc54
1 files changed, 53 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/kernels/eigen_support.cc b/tensorflow/contrib/lite/kernels/eigen_support.cc
index 94927cb53d..4f0d020793 100644
--- a/tensorflow/contrib/lite/kernels/eigen_support.cc
+++ b/tensorflow/contrib/lite/kernels/eigen_support.cc
@@ -14,14 +14,38 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/kernels/eigen_support.h"
-#include "third_party/eigen3/Eigen/Core"
+#include <utility>
+
+#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
namespace tflite {
namespace eigen_support {
namespace {
+// We have a single global threadpool for all convolution operations. This means
+// that inferences started from different threads may block each other, but
+// since the underlying resource of CPU cores should be consumed by the
+// operations anyway, it shouldn't affect overall performance.
+class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
+ public:
+ // Takes ownership of 'pool'
+ explicit EigenThreadPoolWrapper(Eigen::ThreadPool* pool) : pool_(pool) {}
+ ~EigenThreadPoolWrapper() override {}
+
+ void Schedule(std::function<void()> fn) override {
+ pool_->Schedule(std::move(fn));
+ }
+ int NumThreads() const override { return pool_->NumThreads(); }
+ int CurrentThreadId() const override { return pool_->CurrentThreadId(); }
+
+ private:
+ std::unique_ptr<Eigen::ThreadPool> pool_;
+};
+
struct RefCountedEigenContext : public TfLiteExternalContext {
+ std::unique_ptr<Eigen::ThreadPoolInterface> thread_pool_wrapper;
+ std::unique_ptr<Eigen::ThreadPoolDevice> device;
int num_references = 0;
};
@@ -30,8 +54,26 @@ RefCountedEigenContext* GetEigenContext(TfLiteContext* context) {
context->GetExternalContext(context, kTfLiteEigenContext));
}
+void InitDevice(TfLiteContext* context, RefCountedEigenContext* ptr) {
+ int num_threads = 4;
+ if (context->recommended_num_threads != -1) {
+ num_threads = context->recommended_num_threads;
+ }
+ ptr->device.reset(); // destroy before we invalidate the thread pool
+ ptr->thread_pool_wrapper.reset(
+ new EigenThreadPoolWrapper(new Eigen::ThreadPool(num_threads)));
+ ptr->device.reset(
+ new Eigen::ThreadPoolDevice(ptr->thread_pool_wrapper.get(), num_threads));
+}
+
TfLiteStatus Refresh(TfLiteContext* context) {
Eigen::setNbThreads(context->recommended_num_threads);
+
+ auto* ptr = GetEigenContext(context);
+ if (ptr != nullptr) {
+ InitDevice(context, ptr);
+ }
+
return kTfLiteOk;
}
@@ -47,6 +89,7 @@ void IncrementUsageCounter(TfLiteContext* context) {
ptr->type = kTfLiteEigenContext;
ptr->Refresh = Refresh;
ptr->num_references = 0;
+ InitDevice(context, ptr);
context->SetExternalContext(context, kTfLiteEigenContext, ptr);
}
ptr->num_references++;
@@ -65,5 +108,14 @@ void DecrementUsageCounter(TfLiteContext* context) {
}
}
+const Eigen::ThreadPoolDevice* GetThreadPoolDevice(TfLiteContext* context) {
+ auto* ptr = GetEigenContext(context);
+ if (ptr == nullptr) {
+ TF_LITE_FATAL(
+ "Call to GetFromContext() not preceded by IncrementUsageCounter()");
+ }
+ return ptr->device.get();
+}
+
} // namespace eigen_support
} // namespace tflite