aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/gemm_support.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/kernels/gemm_support.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/gemm_support.cc55
1 files changed, 34 insertions, 21 deletions
diff --git a/tensorflow/contrib/lite/kernels/gemm_support.cc b/tensorflow/contrib/lite/kernels/gemm_support.cc
index 95f45ea768..ed334af2da 100644
--- a/tensorflow/contrib/lite/kernels/gemm_support.cc
+++ b/tensorflow/contrib/lite/kernels/gemm_support.cc
@@ -14,57 +14,70 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/kernels/gemm_support.h"
+#include <memory>
+
#include "tensorflow/contrib/lite/kernels/op_macros.h"
namespace tflite {
namespace gemm_support {
+namespace {
-struct RefCountedGemmContext {
- gemmlowp::GemmContext* gemm_context_ = nullptr;
- int num_references_ = 0;
+struct RefCountedGemmContext : public TfLiteExternalContext {
+ std::unique_ptr<gemmlowp::GemmContext> gemm_context;
+ int num_references = 0;
};
+RefCountedGemmContext* GetGemmLowpContext(TfLiteContext* context) {
+ return reinterpret_cast<RefCountedGemmContext*>(
+ context->GetExternalContext(context, kTfLiteGemmLowpContext));
+}
+
+TfLiteStatus Refresh(TfLiteContext* context) {
+ auto* ptr = GetGemmLowpContext(context);
+ if (ptr != nullptr) {
+ ptr->gemm_context->set_max_num_threads(context->recommended_num_threads);
+ }
+ return kTfLiteOk;
+}
+
+} // namespace
+
void IncrementUsageCounter(TfLiteContext* context) {
- auto* ptr = reinterpret_cast<RefCountedGemmContext*>(context->gemm_context);
+ auto* ptr = GetGemmLowpContext(context);
if (ptr == nullptr) {
ptr = new RefCountedGemmContext;
- ptr->gemm_context_ = new gemmlowp::GemmContext();
+ ptr->type = kTfLiteGemmLowpContext;
+ ptr->Refresh = Refresh;
+ ptr->gemm_context.reset(new gemmlowp::GemmContext());
if (context->recommended_num_threads != -1) {
- ptr->gemm_context_->set_max_num_threads(context->recommended_num_threads);
+ ptr->gemm_context->set_max_num_threads(context->recommended_num_threads);
}
- ptr->num_references_ = 0;
- context->gemm_context = ptr;
+ ptr->num_references = 0;
+ context->SetExternalContext(context, kTfLiteGemmLowpContext, ptr);
}
- ptr->num_references_++;
+ ptr->num_references++;
}
void DecrementUsageCounter(TfLiteContext* context) {
- auto* ptr = reinterpret_cast<RefCountedGemmContext*>(context->gemm_context);
+ auto* ptr = GetGemmLowpContext(context);
if (ptr == nullptr) {
TF_LITE_FATAL(
"Call to DecrementUsageCounter() not preceded by "
"IncrementUsageCounter()");
}
- if (--ptr->num_references_ == 0) {
- delete ptr->gemm_context_;
+ if (--ptr->num_references == 0) {
delete ptr;
- context->gemm_context = nullptr;
+ context->SetExternalContext(context, kTfLiteGemmLowpContext, nullptr);
}
}
gemmlowp::GemmContext* GetFromContext(TfLiteContext* context) {
- auto* ptr = reinterpret_cast<RefCountedGemmContext*>(context->gemm_context);
+ auto* ptr = GetGemmLowpContext(context);
if (ptr == nullptr) {
TF_LITE_FATAL(
"Call to GetFromContext() not preceded by IncrementUsageCounter()");
}
- return ptr->gemm_context_;
-}
-
-void SetNumThreads(TfLiteContext* context, int num_threads) {
- IncrementUsageCounter(context);
- GetFromContext(context)->set_max_num_threads(num_threads);
- DecrementUsageCounter(context);
+ return ptr->gemm_context.get();
}
} // namespace gemm_support