diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-07-03 13:17:19 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-03 13:20:38 -0700 |
commit | 3340c2da43f8b2313692aaad1a94da6c4a4e4106 (patch) | |
tree | 24f1f45d83c771c2324dec4f98b022faaca29849 /tensorflow/contrib/lite/kernels/eigen_support.cc | |
parent | 02ed358a986496e387d5f2e52865b10606e52c0a (diff) |
Remove framework's dependency on eigen and gemmlowp.
PiperOrigin-RevId: 203172717
Diffstat (limited to 'tensorflow/contrib/lite/kernels/eigen_support.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/eigen_support.cc | 31 |
1 files changed, 20 insertions, 11 deletions
diff --git a/tensorflow/contrib/lite/kernels/eigen_support.cc b/tensorflow/contrib/lite/kernels/eigen_support.cc index f1fdb42624..94927cb53d 100644 --- a/tensorflow/contrib/lite/kernels/eigen_support.cc +++ b/tensorflow/contrib/lite/kernels/eigen_support.cc @@ -19,26 +19,41 @@ limitations under the License. namespace tflite { namespace eigen_support { +namespace { -struct RefCountedEigenContext { +struct RefCountedEigenContext : public TfLiteExternalContext { int num_references = 0; }; +RefCountedEigenContext* GetEigenContext(TfLiteContext* context) { + return reinterpret_cast<RefCountedEigenContext*>( + context->GetExternalContext(context, kTfLiteEigenContext)); +} + +TfLiteStatus Refresh(TfLiteContext* context) { + Eigen::setNbThreads(context->recommended_num_threads); + return kTfLiteOk; +} + +} // namespace + void IncrementUsageCounter(TfLiteContext* context) { - auto* ptr = reinterpret_cast<RefCountedEigenContext*>(context->eigen_context); + auto* ptr = GetEigenContext(context); if (ptr == nullptr) { if (context->recommended_num_threads != -1) { Eigen::setNbThreads(context->recommended_num_threads); } ptr = new RefCountedEigenContext; + ptr->type = kTfLiteEigenContext; + ptr->Refresh = Refresh; ptr->num_references = 0; - context->eigen_context = ptr; + context->SetExternalContext(context, kTfLiteEigenContext, ptr); } ptr->num_references++; } void DecrementUsageCounter(TfLiteContext* context) { - auto* ptr = reinterpret_cast<RefCountedEigenContext*>(context->eigen_context); + auto* ptr = GetEigenContext(context); if (ptr == nullptr) { TF_LITE_FATAL( "Call to DecrementUsageCounter() not preceded by " @@ -46,15 +61,9 @@ void DecrementUsageCounter(TfLiteContext* context) { } if (--ptr->num_references == 0) { delete ptr; - context->eigen_context = nullptr; + context->SetExternalContext(context, kTfLiteEigenContext, nullptr); } } -void SetNumThreads(TfLiteContext* context, int num_threads) { - IncrementUsageCounter(context); - Eigen::setNbThreads(num_threads); - DecrementUsageCounter(context); -} - } // namespace eigen_support } // namespace tflite |