diff options
author | 2018-08-17 17:18:53 -0700 | |
---|---|---|
committer | 2018-08-17 17:21:58 -0700 | |
commit | adf858f11bfbadccacf9febb21f4cd64f4114446 (patch) | |
tree | 47fa71f73aa49125cf11252115b96521f4357f6a | |
parent | e6970f5560de8277fe2ca212d703bc08a3494860 (diff) |
Internal change
PiperOrigin-RevId: 209231049
3 files changed, 59 insertions, 3 deletions
diff --git a/tensorflow/contrib/lite/tools/benchmark/BUILD b/tensorflow/contrib/lite/tools/benchmark/BUILD index 2cb07eb6ec..dc97d22401 100644 --- a/tensorflow/contrib/lite/tools/benchmark/BUILD +++ b/tensorflow/contrib/lite/tools/benchmark/BUILD @@ -5,8 +5,8 @@ package(default_visibility = [ licenses(["notice"]) # Apache 2.0 load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite") -load("//tensorflow/contrib/lite:build_def.bzl", "tflite_linkopts") load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_linkopts") common_copts = ["-Wall"] + tflite_copts() @@ -35,6 +35,25 @@ cc_binary( ], ) +cc_binary( + name = "benchmark_model_plus_eager", + srcs = [ + "benchmark_main.cc", + ], + copts = common_copts + ["-DTFLITE_EXTENDED"], + linkopts = tflite_linkopts() + select({ + "//tensorflow:android": [ + "-pie", # Android 5.0 and later supports only PIE + "-lm", # some builtin ops, e.g., tanh, need -lm + ], + "//conditions:default": [], + }), + deps = [ + ":benchmark_tflite_model_plus_eager_lib", + ":logging", + ], +) + cc_test( name = "benchmark_test", srcs = ["benchmark_test.cc"], @@ -88,7 +107,25 @@ cc_library( "//tensorflow/contrib/lite:string_util", "//tensorflow/contrib/lite/kernels:builtin_ops", "//tensorflow/contrib/lite/profiling:profile_summarizer", - "//tensorflow/contrib/lite/profiling:profiler", + ], +) + +cc_library( + name = "benchmark_tflite_model_plus_eager_lib", + srcs = [ + "benchmark_tflite_model.cc", + "logging.h", + ], + hdrs = ["benchmark_tflite_model.h"], + copts = common_copts + ["-DTFLITE_EXTENDED"], + deps = [ + ":benchmark_model_lib", + ":logging", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/delegates/eager:delegate", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/profiling:profile_summarizer", ], ) diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc index 7f97f5d0cd..02039922b4 100644 --- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc @@ -23,6 +23,9 @@ limitations under the License. #include <unordered_set> #include <vector> +#ifdef TFLITE_EXTENDED +#include "tensorflow/contrib/lite/delegates/eager/delegate.h" +#endif // TFLITE_EXTENDED #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/model.h" #include "tensorflow/contrib/lite/op_resolver.h" @@ -261,6 +264,16 @@ void BenchmarkTfLiteModel::Init() { bool use_nnapi = params_.Get<bool>("use_nnapi"); interpreter->UseNNAPI(use_nnapi); + +#ifdef TFLITE_EXTENDED + TFLITE_LOG(INFO) << "Instantiating Eager Delegate"; + delegate_ = EagerDelegate::Create(); + if (delegate_) { + interpreter->ModifyGraphWithDelegate(delegate_.get(), + /*allow_dynamic_tensors=*/true); + } +#endif // TFLITE_EXTENDED + auto interpreter_inputs = interpreter->inputs(); if (!inputs.empty()) { diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h index 9931dcbafe..4b22d80cbb 100644 --- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h @@ -20,6 +20,9 @@ limitations under the License. #include <string> #include <vector> +#ifdef TFLITE_EXTENDED +#include "tensorflow/contrib/lite/delegates/eager/delegate.h" +#endif // TFLITE_EXTENDED #include "tensorflow/contrib/lite/model.h" #include "tensorflow/contrib/lite/profiling/profile_summarizer.h" #include "tensorflow/contrib/lite/tools/benchmark/benchmark_model.h" @@ -52,6 +55,7 @@ class BenchmarkTfLiteModel : public BenchmarkModel { public: BenchmarkTfLiteModel(); BenchmarkTfLiteModel(BenchmarkParams params); + virtual ~BenchmarkTfLiteModel() {} std::vector<Flag> GetFlags() override; void LogParams() override; @@ -59,7 +63,6 @@ class BenchmarkTfLiteModel : public BenchmarkModel { uint64_t ComputeInputBytes() override; void Init() override; void RunImpl() override; - virtual ~BenchmarkTfLiteModel() {} struct InputLayerInfo { std::string name; @@ -67,6 +70,9 @@ class BenchmarkTfLiteModel : public BenchmarkModel { }; private: +#ifdef TFLITE_EXTENDED + std::unique_ptr<EagerDelegate> delegate_; +#endif // TFLITE_EXTENDED std::unique_ptr<tflite::FlatBufferModel> model; std::unique_ptr<tflite::Interpreter> interpreter; std::vector<InputLayerInfo> inputs; |