aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-03 13:17:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-03 13:20:38 -0700
commit3340c2da43f8b2313692aaad1a94da6c4a4e4106 (patch)
tree24f1f45d83c771c2324dec4f98b022faaca29849 /tensorflow
parent02ed358a986496e387d5f2e52865b10606e52c0a (diff)
Remove framework's dependency on eigen and gemmlowp.
PiperOrigin-RevId: 203172717
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/contrib/lite/context.h33
-rw-r--r--tensorflow/contrib/lite/interpreter.cc52
-rw-r--r--tensorflow/contrib/lite/interpreter.h17
-rw-r--r--tensorflow/contrib/lite/interpreter_test.cc50
-rw-r--r--tensorflow/contrib/lite/kernels/eigen_support.cc31
-rw-r--r--tensorflow/contrib/lite/kernels/eigen_support.h3
-rw-r--r--tensorflow/contrib/lite/kernels/gemm_support.cc55
-rw-r--r--tensorflow/contrib/lite/kernels/gemm_support.h3
8 files changed, 190 insertions, 54 deletions
diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h
index 4f260ad40a..1ff8843fa7 100644
--- a/tensorflow/contrib/lite/context.h
+++ b/tensorflow/contrib/lite/context.h
@@ -39,6 +39,26 @@ extern "C" {
typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus;
+// The list of external context types known to TF Lite. This list exists solely
+// to avoid conflicts and to ensure ops can share the external contexts they
+// need. Access to the external contexts is controled by one of the
+// corresponding support files.
+typedef enum {
+ kTfLiteEigenContext = 0, // include eigen_support.h to use.
+ kTfLiteGemmLowpContext = 1, // include gemm_support.h to use.
+ kTfLiteMaxExternalContexts = 2
+} TfLiteExternalContextType;
+
+// An external context is a collection of information unrelated to the TF Lite
+// framework, but useful to a subset of the ops. TF Lite knows very little
+// about about the actual contexts, but it keeps a list of them, and is able to
+// refresh them if configurations like the number of recommended threads
+// change.
+typedef struct {
+ TfLiteExternalContextType type;
+ TfLiteStatus (*Refresh)(struct TfLiteContext* context);
+} TfLiteExternalContext;
+
// Forward declare so GetNode can use this is in Context.
typedef struct _TfLiteRegistration TfLiteRegistration;
typedef struct _TfLiteDelegate TfLiteDelegate;
@@ -339,10 +359,15 @@ typedef struct TfLiteContext {
// eigen.
int recommended_num_threads;
- // TODO(ahentz): we should create a more general mechanism for this sort of
- // library-global objects.
- void* gemm_context;
- void* eigen_context;
+ // Access external contexts by type.
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteExternalContext* (*GetExternalContext)(struct TfLiteContext*,
+ TfLiteExternalContextType);
+ // Set the value of a external context. Does not take ownership of the
+ // pointer.
+ // WARNING: This is an experimental interface that is subject to change.
+ void (*SetExternalContext)(struct TfLiteContext*, TfLiteExternalContextType,
+ TfLiteExternalContext*);
} TfLiteContext;
typedef struct _TfLiteRegistration {
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index 3089a4c568..521216a4f1 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -25,10 +25,6 @@ limitations under the License.
#include "tensorflow/contrib/lite/context_util.h"
#include "tensorflow/contrib/lite/error_reporter.h"
#include "tensorflow/contrib/lite/graph_info.h"
-#ifndef TFLITE_MCU
-#include "tensorflow/contrib/lite/kernels/eigen_support.h"
-#include "tensorflow/contrib/lite/kernels/gemm_support.h"
-#endif
#include "tensorflow/contrib/lite/memory_planner.h"
#ifndef TFLITE_MCU
#include "tensorflow/contrib/lite/nnapi_delegate.h"
@@ -120,9 +116,9 @@ Interpreter::Interpreter(ErrorReporter* error_reporter)
context_.AddTensors = AddTensors;
context_.tensors = nullptr;
context_.tensors_size = 0;
- context_.eigen_context = nullptr;
- context_.gemm_context = nullptr;
context_.recommended_num_threads = -1;
+ context_.GetExternalContext = GetExternalContext;
+ context_.SetExternalContext = SetExternalContext;
// Invalid to call these these except from TfLiteDelegate
SetForbiddenContextFunction(&context_.GetNodeAndRegistration);
@@ -133,6 +129,11 @@ Interpreter::Interpreter(ErrorReporter* error_reporter)
tensors_.reserve(kTensorsReservedCapacity);
nodes_and_registration_.reserve(kTensorsReservedCapacity);
next_execution_plan_index_to_prepare_ = 0;
+
+ for (int i = 0; i < kTfLiteMaxExternalContexts; ++i) {
+ external_contexts_[i] = nullptr;
+ }
+
UseNNAPI(false);
}
@@ -290,6 +291,33 @@ TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels(
return kTfLiteOk;
}
+TfLiteExternalContext* Interpreter::GetExternalContext(
+ TfLiteExternalContextType type) {
+ if (type >= 0 && type < kTfLiteMaxExternalContexts) {
+ return external_contexts_[type];
+ }
+ return nullptr;
+}
+
+TfLiteExternalContext* Interpreter::GetExternalContext(
+ struct TfLiteContext* context, TfLiteExternalContextType type) {
+ return static_cast<Interpreter*>(context->impl_)->GetExternalContext(type);
+}
+
+void Interpreter::SetExternalContext(TfLiteExternalContextType type,
+ TfLiteExternalContext* ctx) {
+ if (type >= 0 && type < kTfLiteMaxExternalContexts) {
+ external_contexts_[type] = ctx;
+ }
+}
+
+void Interpreter::SetExternalContext(struct TfLiteContext* context,
+ TfLiteExternalContextType type,
+ TfLiteExternalContext* ctx) {
+ return static_cast<Interpreter*>(context->impl_)
+ ->SetExternalContext(type, ctx);
+}
+
// Gets an TfLiteIntArray* representing the execution plan. The interpreter owns
// this memory and it is only guaranteed to exist during the invocation of the
// delegate prepare.
@@ -869,12 +897,12 @@ void Interpreter::UseNNAPI(bool enable) {
void Interpreter::SetNumThreads(int num_threads) {
context_.recommended_num_threads = num_threads;
- // TODO(ahentz): find a way to avoid this. It causes gemmlowp and eigen to
- // be required in order to compile the framework.
-#ifndef TFLITE_MCU
- gemm_support::SetNumThreads(&context_, num_threads);
- eigen_support::SetNumThreads(&context_, num_threads);
-#endif
+ for (int i = 0; i < kTfLiteMaxExternalContexts; ++i) {
+ auto* c = external_contexts_[i];
+ if (c && c->Refresh) {
+ c->Refresh(&context_);
+ }
+ }
}
TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate,
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
index 033b8ee5fa..b69c50fbfc 100644
--- a/tensorflow/contrib/lite/interpreter.h
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -410,6 +410,8 @@ class Interpreter {
}
private:
+ friend class InterpreterTest;
+
// Give 'op_reg' a chance to initialize itself using the contents of
// 'buffer'.
void* OpInit(const TfLiteRegistration& op_reg, const char* buffer,
@@ -522,6 +524,18 @@ class Interpreter {
static TfLiteStatus GetExecutionPlan(struct TfLiteContext* context,
TfLiteIntArray** execution_plan);
+ // Retrieve an existing external context by type.
+ TfLiteExternalContext* GetExternalContext(TfLiteExternalContextType type);
+ static TfLiteExternalContext* GetExternalContext(
+ struct TfLiteContext* context, TfLiteExternalContextType type);
+
+ // Set the value of an external context.
+ void SetExternalContext(TfLiteExternalContextType type,
+ TfLiteExternalContext* ctx);
+ static void SetExternalContext(struct TfLiteContext* context,
+ TfLiteExternalContextType type,
+ TfLiteExternalContext* ctx);
+
// Ensures that `tensors_` has at least `kTensorsCapacityHeadroom` extra
// capacity. Calling this function may invalidate existing pointers to
// tensors. After calling this function, adding `kTensorsCapacityHeadroom`
@@ -612,6 +626,9 @@ class Interpreter {
// Profiler for this interpreter instance.
profiling::Profiler* profiler_;
+
+ // List of active external contexts.
+ TfLiteExternalContext* external_contexts_[kTfLiteMaxExternalContexts];
};
} // namespace tflite
diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc
index 4f7fb36696..4fa97512fc 100644
--- a/tensorflow/contrib/lite/interpreter_test.cc
+++ b/tensorflow/contrib/lite/interpreter_test.cc
@@ -23,6 +23,15 @@ limitations under the License.
#include "tensorflow/contrib/lite/testing/util.h"
namespace tflite {
+
+// InterpreterTest is a friend of Interpreter, so it can access context_.
+class InterpreterTest : public ::testing::Test {
+ protected:
+ TfLiteContext* GetInterpreterContext() { return &interpreter_.context_; }
+
+ Interpreter interpreter_;
+};
+
namespace ops {
namespace builtin {
TfLiteRegistration* Register_PADV2();
@@ -780,6 +789,47 @@ TEST(InterpreterTensorsCapacityTest, TestExceedHeadroom) {
ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
}
+struct TestExternalContext : public TfLiteExternalContext {
+ static const TfLiteExternalContextType kType = kTfLiteGemmLowpContext;
+
+ static TestExternalContext* Get(TfLiteContext* context) {
+ return reinterpret_cast<TestExternalContext*>(
+ context->GetExternalContext(context, kType));
+ }
+
+ static void Set(TfLiteContext* context, TestExternalContext* value) {
+ context->SetExternalContext(context, kType, value);
+ }
+
+ int num_refreshes = 0;
+};
+
+TEST_F(InterpreterTest, GetSetResetExternalContexts) {
+ auto* context = GetInterpreterContext();
+
+ TestExternalContext external_context;
+ external_context.Refresh = [](TfLiteContext* context) {
+ auto* ptr = TestExternalContext::Get(context);
+ if (ptr != nullptr) {
+ ++ptr->num_refreshes;
+ }
+ return kTfLiteOk;
+ };
+
+ EXPECT_EQ(TestExternalContext::Get(context), nullptr);
+ interpreter_.SetNumThreads(4);
+
+ TestExternalContext::Set(context, &external_context);
+ EXPECT_EQ(TestExternalContext::Get(context), &external_context);
+ interpreter_.SetNumThreads(4);
+ interpreter_.SetNumThreads(5);
+ EXPECT_EQ(external_context.num_refreshes, 2);
+
+ TestExternalContext::Set(context, nullptr);
+ EXPECT_EQ(TestExternalContext::Get(context), nullptr);
+ interpreter_.SetNumThreads(4);
+}
+
// Test fixture that allows playing with execution plans. It creates a two
// node graph that can be executed in either [0,1] order or [1,0] order.
// The CopyOp records when it is invoked in the class member run_order_
diff --git a/tensorflow/contrib/lite/kernels/eigen_support.cc b/tensorflow/contrib/lite/kernels/eigen_support.cc
index f1fdb42624..94927cb53d 100644
--- a/tensorflow/contrib/lite/kernels/eigen_support.cc
+++ b/tensorflow/contrib/lite/kernels/eigen_support.cc
@@ -19,26 +19,41 @@ limitations under the License.
namespace tflite {
namespace eigen_support {
+namespace {
-struct RefCountedEigenContext {
+struct RefCountedEigenContext : public TfLiteExternalContext {
int num_references = 0;
};
+RefCountedEigenContext* GetEigenContext(TfLiteContext* context) {
+ return reinterpret_cast<RefCountedEigenContext*>(
+ context->GetExternalContext(context, kTfLiteEigenContext));
+}
+
+TfLiteStatus Refresh(TfLiteContext* context) {
+ Eigen::setNbThreads(context->recommended_num_threads);
+ return kTfLiteOk;
+}
+
+} // namespace
+
void IncrementUsageCounter(TfLiteContext* context) {
- auto* ptr = reinterpret_cast<RefCountedEigenContext*>(context->eigen_context);
+ auto* ptr = GetEigenContext(context);
if (ptr == nullptr) {
if (context->recommended_num_threads != -1) {
Eigen::setNbThreads(context->recommended_num_threads);
}
ptr = new RefCountedEigenContext;
+ ptr->type = kTfLiteEigenContext;
+ ptr->Refresh = Refresh;
ptr->num_references = 0;
- context->eigen_context = ptr;
+ context->SetExternalContext(context, kTfLiteEigenContext, ptr);
}
ptr->num_references++;
}
void DecrementUsageCounter(TfLiteContext* context) {
- auto* ptr = reinterpret_cast<RefCountedEigenContext*>(context->eigen_context);
+ auto* ptr = GetEigenContext(context);
if (ptr == nullptr) {
TF_LITE_FATAL(
"Call to DecrementUsageCounter() not preceded by "
@@ -46,15 +61,9 @@ void DecrementUsageCounter(TfLiteContext* context) {
}
if (--ptr->num_references == 0) {
delete ptr;
- context->eigen_context = nullptr;
+ context->SetExternalContext(context, kTfLiteEigenContext, nullptr);
}
}
-void SetNumThreads(TfLiteContext* context, int num_threads) {
- IncrementUsageCounter(context);
- Eigen::setNbThreads(num_threads);
- DecrementUsageCounter(context);
-}
-
} // namespace eigen_support
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/eigen_support.h b/tensorflow/contrib/lite/kernels/eigen_support.h
index aa8c351fd8..d47e691123 100644
--- a/tensorflow/contrib/lite/kernels/eigen_support.h
+++ b/tensorflow/contrib/lite/kernels/eigen_support.h
@@ -28,9 +28,6 @@ void IncrementUsageCounter(TfLiteContext* context);
// usages all temporary Eigen objects will be deleted.
void DecrementUsageCounter(TfLiteContext* context);
-// Set the number of threads that can be used by Eigen.
-void SetNumThreads(TfLiteContext* context, int num_threads);
-
} // namespace eigen_support
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/gemm_support.cc b/tensorflow/contrib/lite/kernels/gemm_support.cc
index 95f45ea768..ed334af2da 100644
--- a/tensorflow/contrib/lite/kernels/gemm_support.cc
+++ b/tensorflow/contrib/lite/kernels/gemm_support.cc
@@ -14,57 +14,70 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/kernels/gemm_support.h"
+#include <memory>
+
#include "tensorflow/contrib/lite/kernels/op_macros.h"
namespace tflite {
namespace gemm_support {
+namespace {
-struct RefCountedGemmContext {
- gemmlowp::GemmContext* gemm_context_ = nullptr;
- int num_references_ = 0;
+struct RefCountedGemmContext : public TfLiteExternalContext {
+ std::unique_ptr<gemmlowp::GemmContext> gemm_context;
+ int num_references = 0;
};
+RefCountedGemmContext* GetGemmLowpContext(TfLiteContext* context) {
+ return reinterpret_cast<RefCountedGemmContext*>(
+ context->GetExternalContext(context, kTfLiteGemmLowpContext));
+}
+
+TfLiteStatus Refresh(TfLiteContext* context) {
+ auto* ptr = GetGemmLowpContext(context);
+ if (ptr != nullptr) {
+ ptr->gemm_context->set_max_num_threads(context->recommended_num_threads);
+ }
+ return kTfLiteOk;
+}
+
+} // namespace
+
void IncrementUsageCounter(TfLiteContext* context) {
- auto* ptr = reinterpret_cast<RefCountedGemmContext*>(context->gemm_context);
+ auto* ptr = GetGemmLowpContext(context);
if (ptr == nullptr) {
ptr = new RefCountedGemmContext;
- ptr->gemm_context_ = new gemmlowp::GemmContext();
+ ptr->type = kTfLiteGemmLowpContext;
+ ptr->Refresh = Refresh;
+ ptr->gemm_context.reset(new gemmlowp::GemmContext());
if (context->recommended_num_threads != -1) {
- ptr->gemm_context_->set_max_num_threads(context->recommended_num_threads);
+ ptr->gemm_context->set_max_num_threads(context->recommended_num_threads);
}
- ptr->num_references_ = 0;
- context->gemm_context = ptr;
+ ptr->num_references = 0;
+ context->SetExternalContext(context, kTfLiteGemmLowpContext, ptr);
}
- ptr->num_references_++;
+ ptr->num_references++;
}
void DecrementUsageCounter(TfLiteContext* context) {
- auto* ptr = reinterpret_cast<RefCountedGemmContext*>(context->gemm_context);
+ auto* ptr = GetGemmLowpContext(context);
if (ptr == nullptr) {
TF_LITE_FATAL(
"Call to DecrementUsageCounter() not preceded by "
"IncrementUsageCounter()");
}
- if (--ptr->num_references_ == 0) {
- delete ptr->gemm_context_;
+ if (--ptr->num_references == 0) {
delete ptr;
- context->gemm_context = nullptr;
+ context->SetExternalContext(context, kTfLiteGemmLowpContext, nullptr);
}
}
gemmlowp::GemmContext* GetFromContext(TfLiteContext* context) {
- auto* ptr = reinterpret_cast<RefCountedGemmContext*>(context->gemm_context);
+ auto* ptr = GetGemmLowpContext(context);
if (ptr == nullptr) {
TF_LITE_FATAL(
"Call to GetFromContext() not preceded by IncrementUsageCounter()");
}
- return ptr->gemm_context_;
-}
-
-void SetNumThreads(TfLiteContext* context, int num_threads) {
- IncrementUsageCounter(context);
- GetFromContext(context)->set_max_num_threads(num_threads);
- DecrementUsageCounter(context);
+ return ptr->gemm_context.get();
}
} // namespace gemm_support
diff --git a/tensorflow/contrib/lite/kernels/gemm_support.h b/tensorflow/contrib/lite/kernels/gemm_support.h
index f033501cb6..37af772c68 100644
--- a/tensorflow/contrib/lite/kernels/gemm_support.h
+++ b/tensorflow/contrib/lite/kernels/gemm_support.h
@@ -45,9 +45,6 @@ void IncrementUsageCounter(TfLiteContext* context);
// 'context'. If there are no more usages the GemmContext will be deleted.
void DecrementUsageCounter(TfLiteContext* context);
-// Set the number of threads that can be used by gemmlowp.
-void SetNumThreads(TfLiteContext* context, int num_threads);
-
} // namespace gemm_support
} // namespace tflite