aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/eigen_support.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-03 13:17:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-03 13:20:38 -0700
commit3340c2da43f8b2313692aaad1a94da6c4a4e4106 (patch)
tree24f1f45d83c771c2324dec4f98b022faaca29849 /tensorflow/contrib/lite/kernels/eigen_support.cc
parent02ed358a986496e387d5f2e52865b10606e52c0a (diff)
Remove framework's dependency on eigen and gemmlowp.
PiperOrigin-RevId: 203172717
Diffstat (limited to 'tensorflow/contrib/lite/kernels/eigen_support.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/eigen_support.cc31
1 files changed, 20 insertions, 11 deletions
diff --git a/tensorflow/contrib/lite/kernels/eigen_support.cc b/tensorflow/contrib/lite/kernels/eigen_support.cc
index f1fdb42624..94927cb53d 100644
--- a/tensorflow/contrib/lite/kernels/eigen_support.cc
+++ b/tensorflow/contrib/lite/kernels/eigen_support.cc
@@ -19,26 +19,41 @@ limitations under the License.
namespace tflite {
namespace eigen_support {
+namespace {
-struct RefCountedEigenContext {
+struct RefCountedEigenContext : public TfLiteExternalContext {
int num_references = 0;
};
+RefCountedEigenContext* GetEigenContext(TfLiteContext* context) {
+ return reinterpret_cast<RefCountedEigenContext*>(
+ context->GetExternalContext(context, kTfLiteEigenContext));
+}
+
+TfLiteStatus Refresh(TfLiteContext* context) {
+ Eigen::setNbThreads(context->recommended_num_threads);
+ return kTfLiteOk;
+}
+
+} // namespace
+
void IncrementUsageCounter(TfLiteContext* context) {
- auto* ptr = reinterpret_cast<RefCountedEigenContext*>(context->eigen_context);
+ auto* ptr = GetEigenContext(context);
if (ptr == nullptr) {
if (context->recommended_num_threads != -1) {
Eigen::setNbThreads(context->recommended_num_threads);
}
ptr = new RefCountedEigenContext;
+ ptr->type = kTfLiteEigenContext;
+ ptr->Refresh = Refresh;
ptr->num_references = 0;
- context->eigen_context = ptr;
+ context->SetExternalContext(context, kTfLiteEigenContext, ptr);
}
ptr->num_references++;
}
void DecrementUsageCounter(TfLiteContext* context) {
- auto* ptr = reinterpret_cast<RefCountedEigenContext*>(context->eigen_context);
+ auto* ptr = GetEigenContext(context);
if (ptr == nullptr) {
TF_LITE_FATAL(
"Call to DecrementUsageCounter() not preceded by "
@@ -46,15 +61,9 @@ void DecrementUsageCounter(TfLiteContext* context) {
}
if (--ptr->num_references == 0) {
delete ptr;
- context->eigen_context = nullptr;
+ context->SetExternalContext(context, kTfLiteEigenContext, nullptr);
}
}
-void SetNumThreads(TfLiteContext* context, int num_threads) {
- IncrementUsageCounter(context);
- Eigen::setNbThreads(num_threads);
- DecrementUsageCounter(context);
-}
-
} // namespace eigen_support
} // namespace tflite