aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/eigen_support.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/kernels/eigen_support.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/eigen_support.cc92
1 files changed, 82 insertions, 10 deletions
diff --git a/tensorflow/contrib/lite/kernels/eigen_support.cc b/tensorflow/contrib/lite/kernels/eigen_support.cc
index f1fdb42624..e542ad0765 100644
--- a/tensorflow/contrib/lite/kernels/eigen_support.cc
+++ b/tensorflow/contrib/lite/kernels/eigen_support.cc
@@ -14,31 +14,100 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/kernels/eigen_support.h"
-#include "third_party/eigen3/Eigen/Core"
+#include <utility>
+
+#include "tensorflow/contrib/lite/arena_planner.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
namespace tflite {
namespace eigen_support {
+namespace {
+
+#ifndef EIGEN_DONT_ALIGN
+// Eigen may require buffers to be aligned to 16, 32 or 64 bytes depending on
+// hardware architecture and build configurations.
+// If the static assertion fails, try to increase `kDefaultTensorAlignment`
+// in `arena_planner.h` to 32 or 64.
+static_assert(
+    kDefaultTensorAlignment % EIGEN_MAX_ALIGN_BYTES == 0,  // must be a whole multiple
+    "kDefaultTensorAlignment doesn't comply with Eigen alignment requirement.");
+#endif // EIGEN_DONT_ALIGN
+
+// We have a single global threadpool for all convolution operations. This means
+// that inferences started from different threads may block each other, but
+// since the underlying resource of CPU cores should be consumed by the
+// operations anyway, it shouldn't affect overall performance.
+class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {  // Adapts an owned Eigen::ThreadPool to the ThreadPoolInterface API.
+ public:
+ // Takes ownership of 'pool'; the unique_ptr member deletes it with this wrapper.
+ explicit EigenThreadPoolWrapper(Eigen::ThreadPool* pool) : pool_(pool) {}
+ ~EigenThreadPoolWrapper() override {}
-struct RefCountedEigenContext {
+ void Schedule(std::function<void()> fn) override {  // Hands `fn` to the owned pool.
+ pool_->Schedule(std::move(fn));
+ }
+ int NumThreads() const override { return pool_->NumThreads(); }  // Forwarded to the owned pool.
+ int CurrentThreadId() const override { return pool_->CurrentThreadId(); }  // Forwarded to the owned pool.
+
+ private:
+ std::unique_ptr<Eigen::ThreadPool> pool_;  // Underlying pool, owned by this wrapper.
+};
+
+struct RefCountedEigenContext : public TfLiteExternalContext {  // Eigen state stored on a TfLiteContext, kept alive by a usage count.
+ std::unique_ptr<Eigen::ThreadPoolInterface> thread_pool_wrapper;  // Pool interface that `device` is built on.
+ std::unique_ptr<Eigen::ThreadPoolDevice> device;  // Declared after thread_pool_wrapper, so it is destroyed before the pool.
 int num_references = 0;  // Incremented by IncrementUsageCounter(), decremented by DecrementUsageCounter().
};
+// Looks up the external context stored on `context` under kTfLiteEigenContext
+// and returns it viewed as a RefCountedEigenContext. Callers in this file
+// treat a nullptr result as "no Eigen context registered yet".
+RefCountedEigenContext* GetEigenContext(TfLiteContext* context) {
+  auto* external = context->GetExternalContext(context, kTfLiteEigenContext);
+  return reinterpret_cast<RefCountedEigenContext*>(external);
+}
+
+// (Re)creates the thread pool and the ThreadPoolDevice held by `ptr`.
+// The pool gets context->recommended_num_threads threads, or 4 when that
+// field is -1.
+void InitDevice(TfLiteContext* context, RefCountedEigenContext* ptr) {
+  const int recommended = context->recommended_num_threads;
+  const int num_threads = (recommended != -1) ? recommended : 4;
+  // Drop the old device first: it refers to the pool that is replaced below.
+  ptr->device.reset();
+  ptr->thread_pool_wrapper.reset(
+      new EigenThreadPoolWrapper(new Eigen::ThreadPool(num_threads)));
+  Eigen::ThreadPoolInterface* pool = ptr->thread_pool_wrapper.get();
+  ptr->device.reset(new Eigen::ThreadPoolDevice(pool, num_threads));
+}
+
+// Refresh callback installed on the Eigen external context. Pushes
+// context->recommended_num_threads to Eigen's global thread setting and, if an
+// Eigen context is registered on `context`, rebuilds its pool and device.
+// Always returns kTfLiteOk.
+TfLiteStatus Refresh(TfLiteContext* context) {
+  Eigen::setNbThreads(context->recommended_num_threads);
+
+  if (RefCountedEigenContext* eigen_context = GetEigenContext(context)) {
+    InitDevice(context, eigen_context);
+  }
+  return kTfLiteOk;
+}
+
+} // namespace
+
void IncrementUsageCounter(TfLiteContext* context) {  // Registers one more user of the Eigen context, creating it on first use.
- auto* ptr = reinterpret_cast<RefCountedEigenContext*>(context->eigen_context);
+ auto* ptr = GetEigenContext(context);
 if (ptr == nullptr) {  // Nothing registered on this context yet: build and register it.
 if (context->recommended_num_threads != -1) {  // With -1, Eigen's thread setting is left untouched here.
 Eigen::setNbThreads(context->recommended_num_threads);
 }
 ptr = new RefCountedEigenContext;  // Deleted by DecrementUsageCounter() when the count drops to 0.
+ ptr->type = kTfLiteEigenContext;
+ ptr->Refresh = Refresh;  // Callback that re-runs InitDevice() for this context.
 ptr->num_references = 0;
- context->eigen_context = ptr;
+ InitDevice(context, ptr);
+ context->SetExternalContext(context, kTfLiteEigenContext, ptr);
 }
 ptr->num_references++;
}
void DecrementUsageCounter(TfLiteContext* context) {
- auto* ptr = reinterpret_cast<RefCountedEigenContext*>(context->eigen_context);
+ auto* ptr = GetEigenContext(context);
if (ptr == nullptr) {
TF_LITE_FATAL(
"Call to DecrementUsageCounter() not preceded by "
@@ -46,14 +115,17 @@ void DecrementUsageCounter(TfLiteContext* context) {
}
if (--ptr->num_references == 0) {
delete ptr;
- context->eigen_context = nullptr;
+ context->SetExternalContext(context, kTfLiteEigenContext, nullptr);
}
}
-void SetNumThreads(TfLiteContext* context, int num_threads) {
- IncrementUsageCounter(context);
- Eigen::setNbThreads(num_threads);
- DecrementUsageCounter(context);
+// Returns the Eigen ThreadPoolDevice held by the Eigen context registered on
+// `context`. Reports a fatal error through TF_LITE_FATAL when no Eigen context
+// is registered, i.e. IncrementUsageCounter() has not been called first.
+// The device stays owned by that context and is replaced whenever
+// InitDevice() runs (for example from Refresh()), so callers should not hold
+// on to the pointer across a refresh.
+const Eigen::ThreadPoolDevice* GetThreadPoolDevice(TfLiteContext* context) {
+  auto* ptr = GetEigenContext(context);
+  if (ptr == nullptr) {
+    // Name this function in the message so the failing call site is clear.
+    TF_LITE_FATAL(
+        "Call to GetThreadPoolDevice() not preceded by "
+        "IncrementUsageCounter()");
+  }
+  return ptr->device.get();
}
} // namespace eigen_support