diff options
author | Jared Duke <jdduke@google.com> | 2018-10-04 12:59:38 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-04 13:04:08 -0700 |
commit | 158b6b8becb6afd08f9d6c87f0c7f144ba5f0584 (patch) | |
tree | 52bbbadbcce1ea40a170eae684cd7d662da25350 /tensorflow/contrib | |
parent | 2c75da86ffdb9d04b2b94ce89891f17a8656da22 (diff) |
Use weak symbols to inject flex delegates
PiperOrigin-RevId: 215788183
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r-- | tensorflow/contrib/lite/BUILD | 26 | ||||
-rw-r--r-- | tensorflow/contrib/lite/delegates/flex/BUILD | 4 | ||||
-rw-r--r-- | tensorflow/contrib/lite/delegates/flex/delegate.cc | 9 | ||||
-rw-r--r-- | tensorflow/contrib/lite/interpreter.h | 15 | ||||
-rw-r--r-- | tensorflow/contrib/lite/interpreter_test.cc | 6 | ||||
-rw-r--r-- | tensorflow/contrib/lite/model.cc | 35 | ||||
-rw-r--r-- | tensorflow/contrib/lite/model_flex_test.cc | 45 | ||||
-rw-r--r-- | tensorflow/contrib/lite/model_test.cc | 22 | ||||
-rw-r--r-- | tensorflow/contrib/lite/testdata/multi_add_flex.bin | bin | 0 -> 1052 bytes | |||
-rw-r--r-- | tensorflow/contrib/lite/tools/benchmark/BUILD | 24 | ||||
-rw-r--r-- | tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc | 12 | ||||
-rw-r--r-- | tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h | 6 |
12 files changed, 141 insertions, 63 deletions
diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index f3ebe3b245..787a85644c 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -4,6 +4,7 @@ package(default_visibility = [ licenses(["notice"]) # Apache 2.0 +load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "gen_selected_ops") exports_files(glob([ @@ -165,10 +166,6 @@ cc_library( "stderr_reporter.h", ], copts = tflite_copts(), - defines = select({ - ":with_tflite_flex": ["TFLITE_FLEX"], - "//conditions:default": [], - }), linkopts = [ ] + select({ "//tensorflow:android": [ @@ -276,6 +273,7 @@ cc_test( "testdata/0_subgraphs.bin", "testdata/2_subgraphs.bin", "testdata/empty_model.bin", + "testdata/multi_add_flex.bin", "testdata/test_model.bin", "testdata/test_model_broken.bin", ], @@ -283,6 +281,26 @@ cc_test( ":framework", "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/core/api", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + +# Test model framework with the flex library linked into the target. +tf_cc_test( + name = "model_flex_test", + size = "small", + srcs = ["model_flex_test.cc"], + data = [ + "testdata/multi_add_flex.bin", + ], + tags = ["no_windows"], # TODO(b/116667551): No weak symbols with MSVC. + deps = [ + ":framework", + "//tensorflow/contrib/lite/core/api", + "//tensorflow/contrib/lite/delegates/flex:delegate", + "//tensorflow/contrib/lite/kernels:builtin_ops", "//tensorflow/contrib/lite/testing:util", "@com_google_googletest//:gtest", ], diff --git a/tensorflow/contrib/lite/delegates/flex/BUILD b/tensorflow/contrib/lite/delegates/flex/BUILD index 9dd38958e5..9b89ed4f84 100644 --- a/tensorflow/contrib/lite/delegates/flex/BUILD +++ b/tensorflow/contrib/lite/delegates/flex/BUILD @@ -2,7 +2,7 @@ # This is a TF Lite delegate that is powered by TensorFlow's Eager. # package(default_visibility = [ - "//visibility:public", + "//visibility:private", ]) licenses(["notice"]) # Apache 2.0 @@ -50,6 +50,7 @@ cc_library( hdrs = [ "delegate.h", ], + visibility = ["//visibility:public"], deps = [ ":buffer_map", ":delegate_data", @@ -66,6 +67,7 @@ cc_library( "//tensorflow/core:lib", ], }), + alwayslink = 1, ) tf_cc_test( diff --git a/tensorflow/contrib/lite/delegates/flex/delegate.cc b/tensorflow/contrib/lite/delegates/flex/delegate.cc index ba065a8ff5..c72b0cf513 100644 --- a/tensorflow/contrib/lite/delegates/flex/delegate.cc +++ b/tensorflow/contrib/lite/delegates/flex/delegate.cc @@ -83,6 +83,15 @@ TfLiteStatus CopyFromBufferHandle(TfLiteContext* context, } // namespace delegate } // namespace flex +// Corresponding weak declaration found in lite/model.cc. +std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)> +AcquireFlexDelegate() { + return std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>( + tflite::FlexDelegate::Create().release(), [](TfLiteDelegate* delegate) { + delete reinterpret_cast<tflite::FlexDelegate*>(delegate); + }); +} + std::unique_ptr<FlexDelegate> FlexDelegate::Create() { std::unique_ptr<flex::DelegateData> delegate_data; if (!flex::DelegateData::Create(&delegate_data).ok()) { diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index 7ef736d01b..651a97e9dc 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -349,6 +349,10 @@ class Interpreter { return context_.allow_fp32_relax_to_fp16; } + // Owning handle to a TfLiteDelegate instance. + using TfLiteDelegatePtr = + std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>; + // Allow a delegate to look at the graph and modify the graph to handle // parts of the graph themselves. After this is called, the graph may // contain new nodes that replace 1 more nodes. @@ -574,19 +578,11 @@ class Interpreter { TfLiteExternalContextType type, TfLiteExternalContext* ctx); - using TfLiteDelegatePtr = - std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>; - // Variant of the public ModifyGraphWithDelegate method that additionally // Assumes ownership of the provided delegate. // WARNING: This is an experimental API and subject to change. - template <typename Delegate> - TfLiteStatus ModifyGraphWithDelegate(std::unique_ptr<Delegate> typed_delegate, + TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegatePtr delegate, bool allow_dynamic_tensors = false) { - TfLiteDelegatePtr delegate(typed_delegate.release(), - [](TfLiteDelegate* delegate) { - delete static_cast<Delegate*>(delegate); - }); // Note that we retain ownership of the delegate even if graph modification // fails, as delegate use will be in an indeterminate state at that point. owned_delegates_.push_back(std::move(delegate)); @@ -676,6 +672,7 @@ class Interpreter { // List of delegates that have been installed and are owned by this // interpreter instance. Useful if client delegate ownership is burdensome. // WARNING: This is an experimental API and subject to change. + // TODO(b/116667551): Use TfLiteExternalContext for storing state. std::vector<TfLiteDelegatePtr> owned_delegates_; std::unique_ptr<MemoryPlanner> memory_planner_; diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc index cdede430e2..6c71d5a8d7 100644 --- a/tensorflow/contrib/lite/interpreter_test.cc +++ b/tensorflow/contrib/lite/interpreter_test.cc @@ -30,7 +30,11 @@ class InterpreterTest : public ::testing::Test { template <typename Delegate> static TfLiteStatus ModifyGraphWithDelegate( Interpreter* interpreter, std::unique_ptr<Delegate> delegate) { - return interpreter->ModifyGraphWithDelegate(std::move(delegate)); + Interpreter::TfLiteDelegatePtr tflite_delegate( + delegate.release(), [](TfLiteDelegate* delegate) { + delete reinterpret_cast<Delegate*>(delegate); + }); + return interpreter->ModifyGraphWithDelegate(std::move(tflite_delegate)); } protected: diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index d50c345194..d7b109ac1a 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -27,9 +27,6 @@ limitations under the License. #ifndef TFLITE_MCU #include "tensorflow/contrib/lite/nnapi_delegate.h" #endif -#if defined(TFLITE_FLEX) -#include "tensorflow/contrib/lite/delegates/flex/delegate.h" -#endif #include "tensorflow/contrib/lite/version.h" namespace tflite { @@ -43,6 +40,25 @@ ErrorReporter* ValidateErrorReporter(ErrorReporter* e) { const char* kEmptyTensorName = ""; +// Normally we'd use ABSL_HAVE_ATTRIBUTE_WEAK and ABSL_ATTRIBUTE_WEAK, but +// we avoid the absl dependency for binary size reasons. +#ifdef __has_attribute +#define TFLITE_HAS_ATTRIBUTE(x) __has_attribute(x) +#else +#define TFLITE_HAS_ATTRIBUTE(x) 0 +#endif + +#if TFLITE_HAS_ATTRIBUTE(weak) || (defined(__GNUC__) && !defined(__clang__)) +// Using weak symbols for the flex delegate allows automatic injection of the +// delegate simply by adding it as a dependency. See also the strong override in +// lite/delegates/flex/delegate.cc. +__attribute__((weak)) Interpreter::TfLiteDelegatePtr AcquireFlexDelegate() { + return Interpreter::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {}); +} +#else +Interpreter::TfLiteDelegatePtr (*AcquireFlexDelegate)() = nullptr; +#endif + #ifndef TFLITE_MCU // Loads a model from `filename`. If `mmap_file` is true then use mmap, // otherwise make a copy of the model in a buffer. @@ -450,13 +466,14 @@ TfLiteStatus InterpreterBuilder::operator()( } (**interpreter).SetVariables(std::move(variables)); -#if defined(TFLITE_FLEX) - if (auto delegate = FlexDelegate::Create()) { - (**interpreter) - .ModifyGraphWithDelegate(std::move(delegate), - /*allow_dynamic_tensors=*/true); + // TODO(b/116667551): Only create the flex delegate if the model has flex ops. + if (AcquireFlexDelegate != nullptr) { + if (auto flex_delegate = AcquireFlexDelegate()) { + (**interpreter) + .ModifyGraphWithDelegate(std::move(flex_delegate), + /*allow_dynamic_tensors=*/true); + } } -#endif return kTfLiteOk; } diff --git a/tensorflow/contrib/lite/model_flex_test.cc b/tensorflow/contrib/lite/model_flex_test.cc new file mode 100644 index 0000000000..52e76bee49 --- /dev/null +++ b/tensorflow/contrib/lite/model_flex_test.cc @@ -0,0 +1,45 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/model.h" + +#include <gtest/gtest.h> +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/testing/util.h" + +namespace tflite { + +// Ensures that a model with TensorFlow ops can be imported as long as the +// appropriate delegate is linked into the client. +TEST(FlexModel, WithFlexDelegate) { + auto model = FlatBufferModel::BuildFromFile( + "tensorflow/contrib/lite/testdata/multi_add_flex.bin"); + ASSERT_TRUE(model); + + std::unique_ptr<Interpreter> interpreter; + ASSERT_EQ(InterpreterBuilder(*model, + ops::builtin::BuiltinOpResolver{})(&interpreter), + kTfLiteOk); + ASSERT_TRUE(interpreter); + + ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk); +} + +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/model_test.cc b/tensorflow/contrib/lite/model_test.cc index ec7d46af7c..b969bea5dc 100644 --- a/tensorflow/contrib/lite/model_test.cc +++ b/tensorflow/contrib/lite/model_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include <gtest/gtest.h> #include "tensorflow/contrib/lite/core/api/error_reporter.h" +#include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/testing/util.h" // Comparison for TfLiteRegistration. Since TfLiteRegistration is a C object, @@ -193,6 +194,27 @@ TEST(BasicFlatBufferModel, TestModelInInterpreter) { } } +// Test that loading a model with TensorFlow ops fails when the flex delegate is +// not linked into the target. +TEST(FlexModel, FailureWithoutFlexDelegate) { + auto model = FlatBufferModel::BuildFromFile( + "tensorflow/contrib/lite/testdata/multi_add_flex.bin"); + ASSERT_TRUE(model); + + // Note that creation will succeed when using the BuiltinOpResolver, but + // unless the appropriate delegate is linked into the target or the client + // explicitly installs the delegate, execution will fail. + std::unique_ptr<Interpreter> interpreter; + ASSERT_EQ(InterpreterBuilder(*model, + ops::builtin::BuiltinOpResolver{})(&interpreter), + kTfLiteOk); + ASSERT_TRUE(interpreter); + + // As the flex ops weren't resolved implicitly by the flex delegate, runtime + // allocation and execution will fail. + ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteError); +} + // This tests on a flatbuffer that defines a shape of 2 to be a memory mapped // buffer. But the buffer is provided to be only 1 element. TEST(BasicFlatBufferModel, TestBrokenMmap) { diff --git a/tensorflow/contrib/lite/testdata/multi_add_flex.bin b/tensorflow/contrib/lite/testdata/multi_add_flex.bin Binary files differnew file mode 100644 index 0000000000..9aac2155fe --- /dev/null +++ b/tensorflow/contrib/lite/testdata/multi_add_flex.bin diff --git a/tensorflow/contrib/lite/tools/benchmark/BUILD b/tensorflow/contrib/lite/tools/benchmark/BUILD index 502e181139..71bf61657e 100644 --- a/tensorflow/contrib/lite/tools/benchmark/BUILD +++ b/tensorflow/contrib/lite/tools/benchmark/BUILD @@ -40,7 +40,7 @@ cc_binary( srcs = [ "benchmark_main.cc", ], - copts = common_copts + ["-DTFLITE_FLEX"], + copts = common_copts, linkopts = tflite_linkopts() + select({ "//tensorflow:android": [ "-pie", # Android 5.0 and later supports only PIE @@ -49,8 +49,9 @@ cc_binary( "//conditions:default": [], }), deps = [ - ":benchmark_tflite_model_plus_flex_lib", + ":benchmark_tflite_model_lib", ":logging", + "//tensorflow/contrib/lite/delegates/flex:delegate", ], ) @@ -111,25 +112,6 @@ cc_library( ) cc_library( - name = "benchmark_tflite_model_plus_flex_lib", - srcs = [ - "benchmark_tflite_model.cc", - "logging.h", - ], - hdrs = ["benchmark_tflite_model.h"], - copts = common_copts + ["-DTFLITE_FLEX"], - deps = [ - ":benchmark_model_lib", - ":logging", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite:string_util", - "//tensorflow/contrib/lite/delegates/flex:delegate", - "//tensorflow/contrib/lite/kernels:builtin_ops", - "//tensorflow/contrib/lite/profiling:profile_summarizer", - ], -) - -cc_library( name = "benchmark_params", srcs = [ "benchmark_params.cc", diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc index 463d5993f4..2a3df7f289 100644 --- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc @@ -23,9 +23,6 @@ limitations under the License. #include <unordered_set> #include <vector> -#ifdef TFLITE_FLEX -#include "tensorflow/contrib/lite/delegates/flex/delegate.h" -#endif // TFLITE_FLEX #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/model.h" #include "tensorflow/contrib/lite/op_resolver.h" @@ -305,15 +302,6 @@ void BenchmarkTfLiteModel::Init() { interpreter->UseNNAPI(use_nnapi); -#ifdef TFLITE_FLEX - TFLITE_LOG(INFO) << "Instantiating Flex Delegate"; - delegate_ = FlexDelegate::Create(); - if (delegate_) { - interpreter->ModifyGraphWithDelegate(delegate_.get(), - /*allow_dynamic_tensors=*/true); - } -#endif // TFLITE_FLEX - auto interpreter_inputs = interpreter->inputs(); if (!inputs.empty()) { diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h index b091e18a29..25a302b2aa 100644 --- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h @@ -20,9 +20,6 @@ limitations under the License. #include <string> #include <vector> -#ifdef TFLITE_FLEX -#include "tensorflow/contrib/lite/delegates/flex/delegate.h" -#endif // TFLITE_FLEX #include "tensorflow/contrib/lite/model.h" #include "tensorflow/contrib/lite/profiling/profile_summarizer.h" #include "tensorflow/contrib/lite/tools/benchmark/benchmark_model.h" @@ -73,9 +70,6 @@ class BenchmarkTfLiteModel : public BenchmarkModel { void PrepareInputsAndOutputs() override; private: -#ifdef TFLITE_FLEX - std::unique_ptr<FlexDelegate> delegate_; -#endif // TFLITE_FLEX std::unique_ptr<tflite::FlatBufferModel> model; std::unique_ptr<tflite::Interpreter> interpreter; std::vector<InputLayerInfo> inputs; |