aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib
diff options
context:
space:
mode:
authorGravatar Jared Duke <jdduke@google.com>2018-10-04 12:59:38 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-04 13:04:08 -0700
commit158b6b8becb6afd08f9d6c87f0c7f144ba5f0584 (patch)
tree52bbbadbcce1ea40a170eae684cd7d662da25350 /tensorflow/contrib
parent2c75da86ffdb9d04b2b94ce89891f17a8656da22 (diff)
Use weak symbols to inject flex delegates
PiperOrigin-RevId: 215788183
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--tensorflow/contrib/lite/BUILD26
-rw-r--r--tensorflow/contrib/lite/delegates/flex/BUILD4
-rw-r--r--tensorflow/contrib/lite/delegates/flex/delegate.cc9
-rw-r--r--tensorflow/contrib/lite/interpreter.h15
-rw-r--r--tensorflow/contrib/lite/interpreter_test.cc6
-rw-r--r--tensorflow/contrib/lite/model.cc35
-rw-r--r--tensorflow/contrib/lite/model_flex_test.cc45
-rw-r--r--tensorflow/contrib/lite/model_test.cc22
-rw-r--r--tensorflow/contrib/lite/testdata/multi_add_flex.binbin0 -> 1052 bytes
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/BUILD24
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc12
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h6
12 files changed, 141 insertions, 63 deletions
diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD
index f3ebe3b245..787a85644c 100644
--- a/tensorflow/contrib/lite/BUILD
+++ b/tensorflow/contrib/lite/BUILD
@@ -4,6 +4,7 @@ package(default_visibility = [
licenses(["notice"]) # Apache 2.0
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "gen_selected_ops")
exports_files(glob([
@@ -165,10 +166,6 @@ cc_library(
"stderr_reporter.h",
],
copts = tflite_copts(),
- defines = select({
- ":with_tflite_flex": ["TFLITE_FLEX"],
- "//conditions:default": [],
- }),
linkopts = [
] + select({
"//tensorflow:android": [
@@ -276,6 +273,7 @@ cc_test(
"testdata/0_subgraphs.bin",
"testdata/2_subgraphs.bin",
"testdata/empty_model.bin",
+ "testdata/multi_add_flex.bin",
"testdata/test_model.bin",
"testdata/test_model_broken.bin",
],
@@ -283,6 +281,26 @@ cc_test(
":framework",
"//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/core/api",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "//tensorflow/contrib/lite/testing:util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+# Test model framework with the flex library linked into the target.
+tf_cc_test(
+ name = "model_flex_test",
+ size = "small",
+ srcs = ["model_flex_test.cc"],
+ data = [
+ "testdata/multi_add_flex.bin",
+ ],
+ tags = ["no_windows"], # TODO(b/116667551): No weak symbols with MSVC.
+ deps = [
+ ":framework",
+ "//tensorflow/contrib/lite/core/api",
+ "//tensorflow/contrib/lite/delegates/flex:delegate",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
"//tensorflow/contrib/lite/testing:util",
"@com_google_googletest//:gtest",
],
diff --git a/tensorflow/contrib/lite/delegates/flex/BUILD b/tensorflow/contrib/lite/delegates/flex/BUILD
index 9dd38958e5..9b89ed4f84 100644
--- a/tensorflow/contrib/lite/delegates/flex/BUILD
+++ b/tensorflow/contrib/lite/delegates/flex/BUILD
@@ -2,7 +2,7 @@
# This is a TF Lite delegate that is powered by TensorFlow's Eager.
#
package(default_visibility = [
- "//visibility:public",
+ "//visibility:private",
])
licenses(["notice"]) # Apache 2.0
@@ -50,6 +50,7 @@ cc_library(
hdrs = [
"delegate.h",
],
+ visibility = ["//visibility:public"],
deps = [
":buffer_map",
":delegate_data",
@@ -66,6 +67,7 @@ cc_library(
"//tensorflow/core:lib",
],
}),
+ alwayslink = 1,
)
tf_cc_test(
diff --git a/tensorflow/contrib/lite/delegates/flex/delegate.cc b/tensorflow/contrib/lite/delegates/flex/delegate.cc
index ba065a8ff5..c72b0cf513 100644
--- a/tensorflow/contrib/lite/delegates/flex/delegate.cc
+++ b/tensorflow/contrib/lite/delegates/flex/delegate.cc
@@ -83,6 +83,15 @@ TfLiteStatus CopyFromBufferHandle(TfLiteContext* context,
} // namespace delegate
} // namespace flex
+// Corresponding weak declaration found in lite/model.cc.
+std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>
+AcquireFlexDelegate() {
+ return std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>(
+ tflite::FlexDelegate::Create().release(), [](TfLiteDelegate* delegate) {
+ delete reinterpret_cast<tflite::FlexDelegate*>(delegate);
+ });
+}
+
std::unique_ptr<FlexDelegate> FlexDelegate::Create() {
std::unique_ptr<flex::DelegateData> delegate_data;
if (!flex::DelegateData::Create(&delegate_data).ok()) {
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
index 7ef736d01b..651a97e9dc 100644
--- a/tensorflow/contrib/lite/interpreter.h
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -349,6 +349,10 @@ class Interpreter {
return context_.allow_fp32_relax_to_fp16;
}
+ // Owning handle to a TfLiteDelegate instance.
+ using TfLiteDelegatePtr =
+ std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>;
+
// Allow a delegate to look at the graph and modify the graph to handle
// parts of the graph themselves. After this is called, the graph may
// contain new nodes that replace 1 more nodes.
@@ -574,19 +578,11 @@ class Interpreter {
TfLiteExternalContextType type,
TfLiteExternalContext* ctx);
- using TfLiteDelegatePtr =
- std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>;
-
// Variant of the public ModifyGraphWithDelegate method that additionally
// Assumes ownership of the provided delegate.
// WARNING: This is an experimental API and subject to change.
- template <typename Delegate>
- TfLiteStatus ModifyGraphWithDelegate(std::unique_ptr<Delegate> typed_delegate,
+ TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegatePtr delegate,
bool allow_dynamic_tensors = false) {
- TfLiteDelegatePtr delegate(typed_delegate.release(),
- [](TfLiteDelegate* delegate) {
- delete static_cast<Delegate*>(delegate);
- });
// Note that we retain ownership of the delegate even if graph modification
// fails, as delegate use will be in an indeterminate state at that point.
owned_delegates_.push_back(std::move(delegate));
@@ -676,6 +672,7 @@ class Interpreter {
// List of delegates that have been installed and are owned by this
// interpreter instance. Useful if client delegate ownership is burdensome.
// WARNING: This is an experimental API and subject to change.
+ // TODO(b/116667551): Use TfLiteExternalContext for storing state.
std::vector<TfLiteDelegatePtr> owned_delegates_;
std::unique_ptr<MemoryPlanner> memory_planner_;
diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc
index cdede430e2..6c71d5a8d7 100644
--- a/tensorflow/contrib/lite/interpreter_test.cc
+++ b/tensorflow/contrib/lite/interpreter_test.cc
@@ -30,7 +30,11 @@ class InterpreterTest : public ::testing::Test {
template <typename Delegate>
static TfLiteStatus ModifyGraphWithDelegate(
Interpreter* interpreter, std::unique_ptr<Delegate> delegate) {
- return interpreter->ModifyGraphWithDelegate(std::move(delegate));
+ Interpreter::TfLiteDelegatePtr tflite_delegate(
+ delegate.release(), [](TfLiteDelegate* delegate) {
+ delete reinterpret_cast<Delegate*>(delegate);
+ });
+ return interpreter->ModifyGraphWithDelegate(std::move(tflite_delegate));
}
protected:
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index d50c345194..d7b109ac1a 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -27,9 +27,6 @@ limitations under the License.
#ifndef TFLITE_MCU
#include "tensorflow/contrib/lite/nnapi_delegate.h"
#endif
-#if defined(TFLITE_FLEX)
-#include "tensorflow/contrib/lite/delegates/flex/delegate.h"
-#endif
#include "tensorflow/contrib/lite/version.h"
namespace tflite {
@@ -43,6 +40,25 @@ ErrorReporter* ValidateErrorReporter(ErrorReporter* e) {
const char* kEmptyTensorName = "";
+// Normally we'd use ABSL_HAVE_ATTRIBUTE_WEAK and ABSL_ATTRIBUTE_WEAK, but
+// we avoid the absl dependency for binary size reasons.
+#ifdef __has_attribute
+#define TFLITE_HAS_ATTRIBUTE(x) __has_attribute(x)
+#else
+#define TFLITE_HAS_ATTRIBUTE(x) 0
+#endif
+
+#if TFLITE_HAS_ATTRIBUTE(weak) || (defined(__GNUC__) && !defined(__clang__))
+// Using weak symbols for the flex delegate allows automatic injection of the
+// delegate simply by adding it as a dependency. See also the strong override in
+// lite/delegates/flex/delegate.cc.
+__attribute__((weak)) Interpreter::TfLiteDelegatePtr AcquireFlexDelegate() {
+ return Interpreter::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {});
+}
+#else
+Interpreter::TfLiteDelegatePtr (*AcquireFlexDelegate)() = nullptr;
+#endif
+
#ifndef TFLITE_MCU
// Loads a model from `filename`. If `mmap_file` is true then use mmap,
// otherwise make a copy of the model in a buffer.
@@ -450,13 +466,14 @@ TfLiteStatus InterpreterBuilder::operator()(
}
(**interpreter).SetVariables(std::move(variables));
-#if defined(TFLITE_FLEX)
- if (auto delegate = FlexDelegate::Create()) {
- (**interpreter)
- .ModifyGraphWithDelegate(std::move(delegate),
- /*allow_dynamic_tensors=*/true);
+ // TODO(b/116667551): Only create the flex delegate if the model has flex ops.
+ if (AcquireFlexDelegate != nullptr) {
+ if (auto flex_delegate = AcquireFlexDelegate()) {
+ (**interpreter)
+ .ModifyGraphWithDelegate(std::move(flex_delegate),
+ /*allow_dynamic_tensors=*/true);
+ }
}
-#endif
return kTfLiteOk;
}
diff --git a/tensorflow/contrib/lite/model_flex_test.cc b/tensorflow/contrib/lite/model_flex_test.cc
new file mode 100644
index 0000000000..52e76bee49
--- /dev/null
+++ b/tensorflow/contrib/lite/model_flex_test.cc
@@ -0,0 +1,45 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/model.h"
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/testing/util.h"
+
+namespace tflite {
+
+// Ensures that a model with TensorFlow ops can be imported as long as the
+// appropriate delegate is linked into the client.
+TEST(FlexModel, WithFlexDelegate) {
+ auto model = FlatBufferModel::BuildFromFile(
+ "tensorflow/contrib/lite/testdata/multi_add_flex.bin");
+ ASSERT_TRUE(model);
+
+ std::unique_ptr<Interpreter> interpreter;
+ ASSERT_EQ(InterpreterBuilder(*model,
+ ops::builtin::BuiltinOpResolver{})(&interpreter),
+ kTfLiteOk);
+ ASSERT_TRUE(interpreter);
+
+ ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
+}
+
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/model_test.cc b/tensorflow/contrib/lite/model_test.cc
index ec7d46af7c..b969bea5dc 100644
--- a/tensorflow/contrib/lite/model_test.cc
+++ b/tensorflow/contrib/lite/model_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include <gtest/gtest.h>
#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/testing/util.h"
// Comparison for TfLiteRegistration. Since TfLiteRegistration is a C object,
@@ -193,6 +194,27 @@ TEST(BasicFlatBufferModel, TestModelInInterpreter) {
}
}
+// Test that loading a model with TensorFlow ops fails when the flex delegate is
+// not linked into the target.
+TEST(FlexModel, FailureWithoutFlexDelegate) {
+ auto model = FlatBufferModel::BuildFromFile(
+ "tensorflow/contrib/lite/testdata/multi_add_flex.bin");
+ ASSERT_TRUE(model);
+
+ // Note that creation will succeed when using the BuiltinOpResolver, but
+ // unless the appropriate delegate is linked into the target or the client
+ // explicitly installs the delegate, execution will fail.
+ std::unique_ptr<Interpreter> interpreter;
+ ASSERT_EQ(InterpreterBuilder(*model,
+ ops::builtin::BuiltinOpResolver{})(&interpreter),
+ kTfLiteOk);
+ ASSERT_TRUE(interpreter);
+
+ // As the flex ops weren't resolved implicitly by the flex delegate, runtime
+ // allocation and execution will fail.
+ ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteError);
+}
+
// This tests on a flatbuffer that defines a shape of 2 to be a memory mapped
// buffer. But the buffer is provided to be only 1 element.
TEST(BasicFlatBufferModel, TestBrokenMmap) {
diff --git a/tensorflow/contrib/lite/testdata/multi_add_flex.bin b/tensorflow/contrib/lite/testdata/multi_add_flex.bin
new file mode 100644
index 0000000000..9aac2155fe
--- /dev/null
+++ b/tensorflow/contrib/lite/testdata/multi_add_flex.bin
Binary files differ
diff --git a/tensorflow/contrib/lite/tools/benchmark/BUILD b/tensorflow/contrib/lite/tools/benchmark/BUILD
index 502e181139..71bf61657e 100644
--- a/tensorflow/contrib/lite/tools/benchmark/BUILD
+++ b/tensorflow/contrib/lite/tools/benchmark/BUILD
@@ -40,7 +40,7 @@ cc_binary(
srcs = [
"benchmark_main.cc",
],
- copts = common_copts + ["-DTFLITE_FLEX"],
+ copts = common_copts,
linkopts = tflite_linkopts() + select({
"//tensorflow:android": [
"-pie", # Android 5.0 and later supports only PIE
@@ -49,8 +49,9 @@ cc_binary(
"//conditions:default": [],
}),
deps = [
- ":benchmark_tflite_model_plus_flex_lib",
+ ":benchmark_tflite_model_lib",
":logging",
+ "//tensorflow/contrib/lite/delegates/flex:delegate",
],
)
@@ -111,25 +112,6 @@ cc_library(
)
cc_library(
- name = "benchmark_tflite_model_plus_flex_lib",
- srcs = [
- "benchmark_tflite_model.cc",
- "logging.h",
- ],
- hdrs = ["benchmark_tflite_model.h"],
- copts = common_copts + ["-DTFLITE_FLEX"],
- deps = [
- ":benchmark_model_lib",
- ":logging",
- "//tensorflow/contrib/lite:framework",
- "//tensorflow/contrib/lite:string_util",
- "//tensorflow/contrib/lite/delegates/flex:delegate",
- "//tensorflow/contrib/lite/kernels:builtin_ops",
- "//tensorflow/contrib/lite/profiling:profile_summarizer",
- ],
-)
-
-cc_library(
name = "benchmark_params",
srcs = [
"benchmark_params.cc",
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
index 463d5993f4..2a3df7f289 100644
--- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
@@ -23,9 +23,6 @@ limitations under the License.
#include <unordered_set>
#include <vector>
-#ifdef TFLITE_FLEX
-#include "tensorflow/contrib/lite/delegates/flex/delegate.h"
-#endif // TFLITE_FLEX
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/op_resolver.h"
@@ -305,15 +302,6 @@ void BenchmarkTfLiteModel::Init() {
interpreter->UseNNAPI(use_nnapi);
-#ifdef TFLITE_FLEX
- TFLITE_LOG(INFO) << "Instantiating Flex Delegate";
- delegate_ = FlexDelegate::Create();
- if (delegate_) {
- interpreter->ModifyGraphWithDelegate(delegate_.get(),
- /*allow_dynamic_tensors=*/true);
- }
-#endif // TFLITE_FLEX
-
auto interpreter_inputs = interpreter->inputs();
if (!inputs.empty()) {
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
index b091e18a29..25a302b2aa 100644
--- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
@@ -20,9 +20,6 @@ limitations under the License.
#include <string>
#include <vector>
-#ifdef TFLITE_FLEX
-#include "tensorflow/contrib/lite/delegates/flex/delegate.h"
-#endif // TFLITE_FLEX
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/profiling/profile_summarizer.h"
#include "tensorflow/contrib/lite/tools/benchmark/benchmark_model.h"
@@ -73,9 +70,6 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
void PrepareInputsAndOutputs() override;
private:
-#ifdef TFLITE_FLEX
- std::unique_ptr<FlexDelegate> delegate_;
-#endif // TFLITE_FLEX
std::unique_ptr<tflite::FlatBufferModel> model;
std::unique_ptr<tflite::Interpreter> interpreter;
std::vector<InputLayerInfo> inputs;