aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/tools/accuracy
diff options
context:
space:
mode:
authorGravatar Shashi Shekhar <shashishekhar@google.com>2018-08-22 19:22:46 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-22 19:25:48 -0700
commit8db22dc063e6a6bb16b4676e53446987dac99a49 (patch)
tree14df3a9ca7614e61467058c2f9ab55d6ce9a4346 /tensorflow/contrib/lite/tools/accuracy
parentb7b3f571728898c6d822aa1252d20bced15b989d (diff)
Add an accuracy tool that can be used to evaluate model accuracy on device.
- Adds an accuracy tool that can be used to develop evaluation pipelines to evaluate model accuracies. - The binary can be compiled for mobile platforms and the tool can be used to evaluate accuracy of models by running the binary on device. - Adds an example implementation for imagenet ILSVRC classification task accuracy evaluation. - More documentations and details coming soon. PiperOrigin-RevId: 209869774
Diffstat (limited to 'tensorflow/contrib/lite/tools/accuracy')
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/BUILD435
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h49
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/android_required_build_flags.cc27
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/csv_writer.h79
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/eval_pipeline.cc39
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h87
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.cc100
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h99
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder_test.cc229
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/eval_pipeline_test.cc133
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/file_reader_stage.cc29
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h37
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/file_reader_stage_test.cc110
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/imagenet_accuracy_eval.cc146
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/imagenet_model_evaluator.cc206
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/imagenet_model_evaluator.h113
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/imagenet_topk_eval.cc107
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/imagenet_topk_eval.h80
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/imagenet_topk_eval_test.cc149
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/inception_preprocessing.cc80
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/inception_preprocessing.h75
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/inception_preprocessing_test.cc123
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/run_tflite_model_op.cc158
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/run_tflite_model_op_test.cc200
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.cc45
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.h53
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/stage.h56
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/testdata/grace_hopper.jpgbin0 -> 73746 bytes
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/utils.cc102
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/utils.h46
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/utils_test.cc76
31 files changed, 3268 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/tools/accuracy/BUILD b/tensorflow/contrib/lite/tools/accuracy/BUILD
new file mode 100644
index 0000000000..db09de2909
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/BUILD
@@ -0,0 +1,435 @@
+package(default_visibility = [
+ "//visibility:public",
+])
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test")
+
+cc_library(
+ name = "inception_preprocessing",
+ srcs = ["inception_preprocessing.cc"],
+ hdrs = ["inception_preprocessing.h"],
+ copts = [
+ "-D__ANDROID_TYPES_FULL__",
+ "-DSUPPORT_SELECTIVE_REGISTRATION",
+ ],
+ deps = [
+ ":stage",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:scope",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core/kernels:android_tensorflow_image_op",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:tensorflow",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:ops",
+ ],
+ },
+ ),
+)
+
+tf_cc_test(
+ name = "inception_preprocessing_test",
+ srcs = ["inception_preprocessing_test.cc"],
+ args = [
+ "--test_image=$(location :testdata/grace_hopper.jpg)",
+ ],
+ data = [":testdata/grace_hopper.jpg"],
+ deps = [
+ ":inception_preprocessing",
+ ":android_required_build_flags",
+ "@com_google_googletest//:gtest",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core:android_test_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ ],
+ },
+ ),
+)
+
+cc_library(
+ name = "utils",
+ srcs = ["utils.cc"],
+ hdrs = ["utils.h"],
+ deps = [
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:framework",
+ ],
+ },
+ ),
+)
+
+tf_cc_test(
+ name = "utils_test",
+ srcs = ["utils_test.cc"],
+ args = [
+ "--test_model_file=$(location //tensorflow/contrib/lite:testdata/multi_add.bin)",
+ ],
+ data = ["//tensorflow/contrib/lite:testdata/multi_add.bin"],
+ deps = [
+ ":utils",
+ "@com_google_googletest//:gtest",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core:android_test_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ ],
+ },
+ ),
+)
+
+cc_library(
+ name = "run_tflite_model_op",
+ srcs = ["run_tflite_model_op.cc"],
+ deps = [
+ ":utils",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:tensorflow",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:ops",
+ ],
+ },
+ ),
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "android_required_build_flags",
+ srcs = ["android_required_build_flags.cc"],
+)
+
+tf_cc_test(
+ name = "run_tflite_model_op_test",
+ srcs = ["run_tflite_model_op_test.cc"],
+ args = [
+ "--test_model_file=$(location //tensorflow/contrib/lite:testdata/multi_add.bin)",
+ ],
+ data = ["//tensorflow/contrib/lite:testdata/multi_add.bin"],
+ deps = [
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:scope",
+ ":run_tflite_model_op",
+ ":android_required_build_flags",
+ "@com_google_googletest//:gtest",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core:android_test_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:ops",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:tensorflow",
+ ],
+ },
+ ),
+)
+
+cc_library(
+ name = "stage",
+ hdrs = ["stage.h"],
+ deps = [
+ "//tensorflow/cc:scope",
+ ],
+)
+
+cc_library(
+ name = "file_reader_stage",
+ srcs = ["file_reader_stage.cc"],
+ hdrs = ["file_reader_stage.h"],
+ deps = [
+ ":stage",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:scope",
+ ],
+)
+
+tf_cc_test(
+ name = "file_reader_stage_test",
+ srcs = ["file_reader_stage_test.cc"],
+ deps = [
+ ":file_reader_stage",
+ "@com_google_googletest//:gtest",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core/kernels:android_whole_file_read_ops",
+ "//tensorflow/core:android_test_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:tensorflow",
+ ],
+ },
+ ),
+)
+
+cc_library(
+ name = "run_tflite_model_stage",
+ srcs = ["run_tflite_model_stage.cc"],
+ hdrs = ["run_tflite_model_stage.h"],
+ deps = [
+ ":run_tflite_model_op",
+ ":stage",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:scope",
+ ],
+)
+
+cc_library(
+ name = "accuracy_eval_stage",
+ hdrs = ["accuracy_eval_stage.h"],
+ deps = [
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:framework",
+ ],
+ },
+ ),
+)
+
+cc_library(
+ name = "imagenet_topk_eval",
+ srcs = ["imagenet_topk_eval.cc"],
+ hdrs = ["imagenet_topk_eval.h"],
+ deps = [
+ ":accuracy_eval_stage",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:framework",
+ ],
+ },
+ ),
+)
+
+tf_cc_test(
+ name = "imagenet_topk_eval_test",
+ srcs = ["imagenet_topk_eval_test.cc"],
+ deps = [
+ ":imagenet_topk_eval",
+ "@com_google_googletest//:gtest",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core:android_test_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:framework",
+ ],
+ },
+ ),
+)
+
+cc_library(
+ name = "eval_pipeline",
+ srcs = ["eval_pipeline.cc"],
+ hdrs = ["eval_pipeline.h"],
+ deps = [
+ ":accuracy_eval_stage",
+ ":stage",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:core_cpu",
+ ],
+ },
+ ),
+)
+
+tf_cc_test(
+ name = "eval_pipeline_test",
+ srcs = ["eval_pipeline_test.cc"],
+ deps = [
+ ":eval_pipeline",
+ "//tensorflow/cc:cc_ops",
+ "@com_google_googletest//:gtest",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core:android_test_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:ops",
+ "//tensorflow/core:tensorflow",
+ ],
+ },
+ ),
+)
+
+cc_library(
+ name = "eval_pipeline_builder",
+ srcs = ["eval_pipeline_builder.cc"],
+ hdrs = ["eval_pipeline_builder.h"],
+ deps = [
+ ":eval_pipeline",
+ ":accuracy_eval_stage",
+ ":stage",
+ "@com_google_absl//absl/memory",
+ "//tensorflow/cc:cc_ops",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:ops",
+ "//tensorflow/core:tensorflow",
+ ],
+ },
+ ),
+)
+
+tf_cc_test(
+ name = "eval_pipeline_builder_test",
+ srcs = ["eval_pipeline_builder_test.cc"],
+ deps = [
+ ":eval_pipeline_builder",
+ "//tensorflow/cc:cc_ops",
+ "@com_google_googletest//:gtest",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core:android_test_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:ops",
+ "//tensorflow/core:tensorflow",
+ ],
+ },
+ ),
+)
+
+cc_library(
+ name = "csv_writer",
+ hdrs = ["csv_writer.h"],
+ deps = select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:lib",
+ ],
+ },
+ ),
+)
+
+cc_library(
+ name = "imagenet_model_evaluator",
+ srcs = ["imagenet_model_evaluator.cc"],
+ hdrs = ["imagenet_model_evaluator.h"],
+ deps = [
+ ":android_required_build_flags",
+ ":eval_pipeline",
+ ":eval_pipeline_builder",
+ ":file_reader_stage",
+ ":imagenet_topk_eval",
+ ":inception_preprocessing",
+ ":run_tflite_model_stage",
+ ":utils",
+ "@com_google_absl//absl/memory",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:scope",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core/kernels:android_whole_file_read_ops",
+ "//tensorflow/core/kernels:android_tensorflow_image_op",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:tensorflow",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:core_cpu",
+ ],
+ },
+ ),
+)
+
+tf_cc_binary(
+ name = "imagenet_accuracy_eval",
+ srcs = ["imagenet_accuracy_eval.cc"],
+ deps = [
+ ":android_required_build_flags",
+ ":csv_writer",
+ ":imagenet_model_evaluator",
+ ":imagenet_topk_eval",
+ "@com_google_absl//absl/memory",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:lib",
+ "//tensorflow/core:framework_internal",
+ ],
+ },
+ ),
+)
diff --git a/tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h b/tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h
new file mode 100644
index 0000000000..9cb843729a
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h
@@ -0,0 +1,49 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_ACCURACY_EVAL_STAGE_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_ACCURACY_EVAL_STAGE_H_
+
+#include <vector>
+
+#include "tensorflow/core/framework/tensor.h"
+
+namespace tensorflow {
+namespace metrics {
+
+// Base class for evaluation stage that evaluates the accuracy of the model.
+// This stage calculates the accuracy metrics given the model outputs and
+// expected ground truth.
+class AccuracyEval {
+ public:
+ AccuracyEval() = default;
+ AccuracyEval(const AccuracyEval&) = delete;
+ AccuracyEval& operator=(const AccuracyEval&) = delete;
+
+ AccuracyEval(const AccuracyEval&&) = delete;
+ AccuracyEval& operator=(const AccuracyEval&&) = delete;
+
+ virtual ~AccuracyEval() = default;
+
+ // Evaluates the accuracy of the model for given `model_outputs` and the
+ // `ground truth`.
+ // Derived classes can do additional book keeping, calculate aggregrate
+ // statistics etc for the given model.
+ virtual Status ComputeEval(const std::vector<Tensor>& model_outputs,
+ const Tensor& ground_truth) = 0;
+};
+} // namespace metrics
+} // namespace tensorflow
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_ACCURACY_EVAL_STAGE_H_
diff --git a/tensorflow/contrib/lite/tools/accuracy/android_required_build_flags.cc b/tensorflow/contrib/lite/tools/accuracy/android_required_build_flags.cc
new file mode 100644
index 0000000000..7fa8986716
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/android_required_build_flags.cc
@@ -0,0 +1,27 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Tensorflow on Android requires selective registration to be enabled in order
+// for certain types (e.g. DT_UINT8) to work.
+// Checks below ensure that for Android build, the right flags are passed to
+// the compiler.
+
+#if defined(__ANDROID__) && (!defined(__ANDROID_TYPES_FULL__) || \
+ !defined(SUPPORT_SELECTIVE_REGISTRATION))
+#error \
+ "Binary needs custom kernel support. For enabling custom kernels on " \
+ "Android, please pass -D__ANDROID_TYPES_FULL__ && " \
+ "-DSUPPORT_SELECTIVE_REGISTRATION for including the kernel in the binary."
+#endif
diff --git a/tensorflow/contrib/lite/tools/accuracy/csv_writer.h b/tensorflow/contrib/lite/tools/accuracy/csv_writer.h
new file mode 100644
index 0000000000..806b0d9418
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/csv_writer.h
@@ -0,0 +1,79 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_CSV_WRITER_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_CSV_WRITER_H_
+
+#include <fstream>
+#include <vector>
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace metrics {
+// A simple CSV writer that writes values of same type for fixed number of
+// columns. This supports a very limited set of CSV spec and doesn't do any
+// escaping.
+// Usage:
+// std::ofstream * output_stream = ...
+// CSVWriter writer({"column1", "column2"}, output_stream);
+// writer.WriteRow({4, 5});
+// writer.Flush(); // flush results immediately.
+class CSVWriter {
+ public:
+ CSVWriter(const std::vector<string>& columns, std::ofstream* output_stream)
+ : num_columns_(columns.size()), output_stream_(output_stream) {
+ TF_CHECK_OK(WriteRow(columns, output_stream_));
+ }
+
+ template <typename T>
+ Status WriteRow(const std::vector<T>& values) {
+ if (values.size() != num_columns_) {
+ return errors::InvalidArgument("Invalid size for row:", values.size(),
+ " expected: ", num_columns_);
+ }
+ return WriteRow(values, output_stream_);
+ }
+
+ void Flush() { output_stream_->flush(); }
+
+ ~CSVWriter() { output_stream_->flush(); }
+
+ private:
+ template <typename T>
+ static Status WriteRow(const std::vector<T>& values,
+ std::ofstream* output_stream) {
+ bool first = true;
+ for (const auto& v : values) {
+ if (!first) {
+ (*output_stream) << ", ";
+ } else {
+ first = false;
+ }
+ (*output_stream) << v;
+ }
+ (*output_stream) << "\n";
+ if (!output_stream->good()) {
+ return errors::Internal("Writing to stream failed.");
+ }
+ return Status::OK();
+ }
+ const size_t num_columns_;
+ std::ofstream* output_stream_;
+};
+} // namespace metrics
+} // namespace tensorflow
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_CSV_WRITER_H_
diff --git a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline.cc b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline.cc
new file mode 100644
index 0000000000..a03aba6a26
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline.cc
@@ -0,0 +1,39 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h"
+
+namespace tensorflow {
+namespace metrics {
+
+Status EvalPipeline::AttachSession(std::unique_ptr<Session> session) {
+ session_ = std::move(session);
+ TF_RETURN_IF_ERROR(session_->Create(model_graph_));
+ return Status::OK();
+}
+
+Status EvalPipeline::Run(const Tensor& input, const Tensor& ground_truth) {
+ if (session_ == nullptr) {
+ return errors::Internal("No session is associated with the graph.");
+ }
+ std::vector<Tensor> outputs;
+ TF_RETURN_IF_ERROR(session_->Run({{params_.model_input_node_name, input}},
+ {params_.model_output_node_name}, {},
+ &outputs));
+ TF_RETURN_IF_ERROR(eval_->ComputeEval(outputs, ground_truth));
+ return Status::OK();
+}
+} // namespace metrics
+} // namespace tensorflow
diff --git a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h
new file mode 100644
index 0000000000..c9cfc86613
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h
@@ -0,0 +1,87 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_H_
+
+#include <string>
+
+#include "tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h"
+#include "tensorflow/contrib/lite/tools/accuracy/stage.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tensorflow {
+namespace metrics {
+
+// Pipeline for evaluating a model.
+// Runs the graph and passes the output of graph to
+// the provided instance of AccuracyEval.
+// Example usage:
+// AccuracyEval *eval;
+// GraphDef graph_def;
+// ... populate graph_def...
+//
+// EvalPipeline eval_pipeline(&graph_def,
+// {.model_input_node_name = "model_input",
+// .model_output_node_name = "model_output"},
+// eval);
+// std::unique_ptr<Session> session(NewSession(SessionOptions()));
+// TF_CHECK_OK(eval_pipeline.AttachSession(std::move(session)));
+// Tensor input = ... read input for the model ...
+// Tensor ground_truth = ... read ground truth for the model ...
+// TF_CHECK_OK(eval_pipeline.Run(input, ground_truth));
+//
+class EvalPipeline {
+ public:
+ struct Params {
+ string model_input_node_name;
+ string model_output_node_name;
+ };
+
+ // Creates a new `EvalPipeline` object. The ownership of the `accuracy_eval`
+ // is retained by the caller. Lifetime of `accuracy_eval` instance should
+ // be longer than the lifetime of this instance of pipeline.
+ EvalPipeline(const GraphDef& graph, const Params& params,
+ AccuracyEval* accuracy_eval)
+ : model_graph_(graph),
+ params_(params),
+ eval_(accuracy_eval),
+ session_(nullptr) {}
+
+ EvalPipeline(const EvalPipeline&) = delete;
+ EvalPipeline& operator=(const EvalPipeline&) = delete;
+
+ EvalPipeline(const EvalPipeline&&) = delete;
+ EvalPipeline& operator=(const EvalPipeline&&) = delete;
+
+ // Attaches the given session to this instance of pipeline.
+ // The provided session object will be reused for subsequent calls to
+ // EvalPipeline::Run.
+ Status AttachSession(std::unique_ptr<Session> session);
+
+ // Runs the model by feeding `input` and then passes the output of the model
+ // along with provided `ground_truth` to the AccuracyEval instance by calling
+ // AccuracyEval::ComputeEval.
+ Status Run(const Tensor& input, const Tensor& ground_truth);
+
+ private:
+ GraphDef model_graph_;
+ Params params_;
+ AccuracyEval* eval_;
+ std::unique_ptr<Session> session_;
+};
+} // namespace metrics
+} // namespace tensorflow
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_H_
diff --git a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.cc b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.cc
new file mode 100644
index 0000000000..2e16437e15
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.cc
@@ -0,0 +1,100 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h"
+
+#include "absl/memory/memory.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+
+namespace tensorflow {
+namespace metrics {
+
+EvalPipelineBuilder& EvalPipelineBuilder::WithInputStage(Stage* input_stage) {
+ input_stage_ = input_stage;
+ return *this;
+}
+
+EvalPipelineBuilder& EvalPipelineBuilder::WithPreprocessingStage(
+ Stage* preprocessing_stage) {
+ preprocessing_stage_ = preprocessing_stage;
+ return *this;
+}
+
+EvalPipelineBuilder& EvalPipelineBuilder::WithRunModelStage(
+ Stage* run_model_stage) {
+ run_model_stage_ = run_model_stage;
+ return *this;
+}
+
+EvalPipelineBuilder& EvalPipelineBuilder::WithAccuracyEval(
+ AccuracyEval* accuracy_eval) {
+ accuracy_eval_ = accuracy_eval;
+ return *this;
+}
+
+EvalPipelineBuilder& EvalPipelineBuilder::WithInput(const string& input_name,
+ DataType input_type) {
+ input_name_ = input_name;
+ input_type_ = input_type;
+ return *this;
+}
+
+Status EvalPipelineBuilder::Build(
+ const Scope& scope, std::unique_ptr<EvalPipeline>* eval_pipeline) {
+ if (input_stage_ == nullptr) {
+ return errors::InvalidArgument("Input stage is null.");
+ }
+ if (preprocessing_stage_ == nullptr) {
+ return errors::InvalidArgument("Preprocessing stage is null.");
+ }
+ if (run_model_stage_ == nullptr) {
+ return errors::InvalidArgument("Run model stage is null.");
+ }
+ if (accuracy_eval_ == nullptr) {
+ return errors::InvalidArgument("accuracy_eval is null.");
+ }
+ if (input_name_.empty()) {
+ return errors::InvalidArgument("input name is not set.");
+ }
+ if (input_type_ == DT_INVALID) {
+ return errors::InvalidArgument("input type is not set.");
+ }
+
+ auto input_placeholder =
+ ops::Placeholder(scope.WithOpName(input_name_), input_type_);
+ TF_RETURN_IF_ERROR(scope.status());
+
+ input_stage_->AddToGraph(scope, input_placeholder);
+ TF_RETURN_IF_ERROR(scope.status());
+
+ preprocessing_stage_->AddToGraph(scope, input_stage_->Output());
+ TF_RETURN_IF_ERROR(scope.status());
+
+ run_model_stage_->AddToGraph(scope, preprocessing_stage_->Output());
+ TF_RETURN_IF_ERROR(scope.status());
+
+ GraphDef graph_def;
+ TF_RETURN_IF_ERROR(scope.ToGraphDef(&graph_def));
+ EvalPipeline::Params params;
+ params.model_input_node_name = input_name_;
+ params.model_output_node_name = run_model_stage_->output_name();
+ *eval_pipeline =
+ absl::make_unique<EvalPipeline>(graph_def, params, accuracy_eval_);
+
+ return Status::OK();
+}
+
+} // namespace metrics
+} // namespace tensorflow
diff --git a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h
new file mode 100644
index 0000000000..692db022f8
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h
@@ -0,0 +1,99 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_BUILDER_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_BUILDER_H_
+
+#include <memory>
+#include <string>
+
+#include "tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h"
+#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h"
+#include "tensorflow/contrib/lite/tools/accuracy/stage.h"
+
+namespace tensorflow {
+namespace metrics {
+
+// A builder to simplify construction of an `EvalPipeline` instance.
+// The `Build` method creates an |EvalPipeline| with the following structure:
+// |input| -> |input_stage|
+// |--> |preprocessing_stage|
+// |--> |run_model_stage| -> |accuracy_eval_stage|.
+// The stages are chained in the order shown above. Any missing stage results in
+// an error. The ownership of the stage object is retained by the caller. Stage
+// objects need to exist until the |Build| method is called.
+//
+// Currently only single inputs are supported.
+//
+// Example Usage:
+// EvalPipelineBuilder builder;
+// std::unique_ptr<EvalPipeline> eval_pipeline;
+// auto status = builder.WithInput("pipeline_input", DT_FLOAT)
+// .WithInputStage(&input_stage)
+// .WithRunModelStage(&run_model_stage)
+// .WithPreprocessingStage(&preprocess_stage)
+// .WithAccuracyEval(&eval)
+// .Build(scope, &eval_pipeline);
+// TF_CHECK_OK(status);
+class EvalPipelineBuilder {
+ public:
+ EvalPipelineBuilder() = default;
+ EvalPipelineBuilder(const EvalPipelineBuilder&) = delete;
+ EvalPipeline& operator=(const EvalPipelineBuilder&) = delete;
+
+ EvalPipelineBuilder(const EvalPipelineBuilder&&) = delete;
+ EvalPipeline& operator=(const EvalPipelineBuilder&&) = delete;
+
+ // Sets the input stage for the pipeline.
+ // Input stage converts the input, say filename into appropriate format
+ // that can be consumed by the preprocessing stage.
+ EvalPipelineBuilder& WithInputStage(Stage* input_stage);
+
+ // Sets the preprocessing stage for the pipeline.
+ // Preprocessing stage converts the input into a format that can be used to
+ // run the model.
+ EvalPipelineBuilder& WithPreprocessingStage(Stage* preprocessing_stage);
+
+ // Sets the run model stage for the pipeline.
+ // This stage receives the preprocessing input and output of this stage is
+ // fed to the accuracy eval stage.
+ EvalPipelineBuilder& WithRunModelStage(Stage* run_model_stage);
+
+ // Sets the accuracy eval for the pipeline.
+ // Results of evaluating the pipeline are fed to the `accuracy_eval` instance.
+ EvalPipelineBuilder& WithAccuracyEval(AccuracyEval* accuracy_eval);
+
+ // Sets the name and type of input for the pipeline.
+ // TODO(shashishekhar): Support multiple inputs for the pipeline, use a vector
+ // here.
+ EvalPipelineBuilder& WithInput(const string& input_name, DataType input_type);
+
+ // Builds the pipeline and assigns the pipeline to `eval_pipeline`.
+ // If the pipeline creation fails `eval_pipeline` is untouched.
+ Status Build(const Scope& scope,
+ std::unique_ptr<EvalPipeline>* eval_pipeline);
+
+ private:
+ Stage* input_stage_ = nullptr;
+ Stage* preprocessing_stage_ = nullptr;
+ Stage* run_model_stage_ = nullptr;
+ AccuracyEval* accuracy_eval_ = nullptr;
+ string input_name_;
+ DataType input_type_ = DT_INVALID;
+};
+
+} // namespace metrics
+} // namespace tensorflow
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_BUILDER_H_
diff --git a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder_test.cc b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder_test.cc
new file mode 100644
index 0000000000..2d41929b79
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder_test.cc
@@ -0,0 +1,229 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h"
+#include <gtest/gtest.h>
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tensorflow {
+namespace metrics {
+namespace {
+
+class IdentityStage : public Stage {
+ public:
+ IdentityStage(const string& name, const string& output)
+ : name_(name), output_(output) {}
+
+ void AddToGraph(const Scope& scope, const Input& input) override {
+ called_count_++;
+ inputs_.push_back(input.node()->name());
+ stage_output_ = ops::Identity(scope.WithOpName(output_), input);
+ }
+
+ string name() const override { return name_; }
+ string output_name() const override { return output_; }
+
+ int times_called() const { return called_count_; }
+
+ const std::vector<string> input_params() { return inputs_; }
+
+ private:
+ string name_;
+ string output_;
+ int called_count_ = 0;
+ std::vector<string> inputs_;
+};
+
+class FailingStage : public Stage {
+ public:
+ FailingStage(const string& name, const string& output)
+ : name_(name), output_(output) {}
+
+ void AddToGraph(const Scope& scope, const Input& input) override {
+ called_count_++;
+ scope.UpdateStatus(errors::Internal("Stage failed:", name_));
+ }
+
+ string name() const override { return name_; }
+ string output_name() const override { return output_; }
+
+ int times_called() const { return called_count_; }
+
+ private:
+ string name_;
+ string output_;
+ int called_count_ = 0;
+};
+
+class SimpleAccuracyEval : public AccuracyEval {
+ public:
+ SimpleAccuracyEval() {}
+
+ Status ComputeEval(const std::vector<Tensor>& model_outputs,
+ const Tensor& ground_truth) override {
+ return Status::OK();
+ }
+};
+
+TEST(EvalPipelineBuilder, MissingPipelineStages) {
+ IdentityStage input_stage("input_stage", "input_stage_out");
+ IdentityStage run_model_stage("run_model", "run_model_out");
+ IdentityStage preprocess_stage("preprocess_stage", "preprocess_stage_out");
+ const string pipeline_input = "pipeline_input";
+
+ SimpleAccuracyEval eval;
+
+ Scope scope = Scope::NewRootScope();
+ std::unique_ptr<EvalPipeline> eval_pipeline;
+ EvalPipelineBuilder builder;
+ auto status =
+ builder.WithInputStage(&input_stage).Build(scope, &eval_pipeline);
+ EXPECT_FALSE(status.ok());
+ EXPECT_FALSE(eval_pipeline);
+
+ status =
+ builder.WithRunModelStage(&run_model_stage).Build(scope, &eval_pipeline);
+ EXPECT_FALSE(status.ok());
+ EXPECT_FALSE(eval_pipeline);
+
+ status = builder.WithPreprocessingStage(&preprocess_stage)
+ .Build(scope, &eval_pipeline);
+ EXPECT_FALSE(status.ok());
+ EXPECT_FALSE(eval_pipeline);
+
+ status =
+ builder.WithInput(pipeline_input, DT_FLOAT).Build(scope, &eval_pipeline);
+ EXPECT_FALSE(status.ok());
+ EXPECT_FALSE(eval_pipeline);
+
+ status = builder.WithAccuracyEval(&eval).Build(scope, &eval_pipeline);
+ TF_CHECK_OK(status);
+ EXPECT_TRUE(eval_pipeline);
+}
+
+TEST(EvalPipeline, InputStageFailure) {
+ FailingStage input_stage("input_stage", "input_stage_out");
+ IdentityStage run_model_stage("run_model", "run_model_out");
+ IdentityStage preprocess_stage("preprocess_stage", "preprocess_stage_out");
+ const string pipeline_input = "pipeline_input";
+
+ SimpleAccuracyEval eval;
+
+ Scope scope = Scope::NewRootScope();
+ std::unique_ptr<EvalPipeline> eval_pipeline;
+ EvalPipelineBuilder builder;
+ auto status = builder.WithInputStage(&input_stage)
+ .WithRunModelStage(&run_model_stage)
+ .WithPreprocessingStage(&preprocess_stage)
+ .WithInput(pipeline_input, DT_FLOAT)
+ .WithAccuracyEval(&eval)
+ .Build(scope, &eval_pipeline);
+
+ EXPECT_FALSE(scope.status().ok());
+ // None of the other stages would have been called.
+ EXPECT_EQ(1, input_stage.times_called());
+ EXPECT_EQ(0, preprocess_stage.times_called());
+ EXPECT_EQ(0, run_model_stage.times_called());
+}
+
+TEST(EvalPipeline, PreprocessingFailure) {
+ IdentityStage input_stage("input_stage", "input_stage_out");
+ FailingStage preprocess_stage("preprocess_stage", "preprocess_stage_out");
+ IdentityStage run_model_stage("run_model", "run_model_out");
+ const string pipeline_input = "pipeline_input";
+
+ SimpleAccuracyEval eval;
+
+ Scope scope = Scope::NewRootScope();
+ std::unique_ptr<EvalPipeline> eval_pipeline;
+ EvalPipelineBuilder builder;
+ auto status = builder.WithInputStage(&input_stage)
+ .WithRunModelStage(&run_model_stage)
+ .WithPreprocessingStage(&preprocess_stage)
+ .WithInput(pipeline_input, DT_FLOAT)
+ .WithAccuracyEval(&eval)
+ .Build(scope, &eval_pipeline);
+
+ EXPECT_FALSE(status.ok());
+ // None of the other stages would have been called.
+ EXPECT_EQ(1, input_stage.times_called());
+ EXPECT_EQ(1, preprocess_stage.times_called());
+ EXPECT_EQ(0, run_model_stage.times_called());
+}
+
+TEST(EvalPipeline, GraphEvalFailure) {
+ IdentityStage input_stage("input_stage", "input_stage_out");
+ IdentityStage preprocess_stage("preprocess_stage", "preprocess_stage_out");
+ FailingStage run_model_stage("run_model", "run_model_out");
+ const string pipeline_input = "pipeline_input";
+
+ SimpleAccuracyEval eval;
+
+ Scope scope = Scope::NewRootScope();
+ std::unique_ptr<EvalPipeline> eval_pipeline;
+ EvalPipelineBuilder builder;
+ auto status = builder.WithInputStage(&input_stage)
+ .WithRunModelStage(&run_model_stage)
+ .WithPreprocessingStage(&preprocess_stage)
+ .WithInput(pipeline_input, DT_FLOAT)
+ .WithAccuracyEval(&eval)
+ .Build(scope, &eval_pipeline);
+
+ EXPECT_FALSE(status.ok());
+ // None of the other stages would have been called.
+ EXPECT_EQ(1, input_stage.times_called());
+ EXPECT_EQ(1, preprocess_stage.times_called());
+ EXPECT_EQ(1, run_model_stage.times_called());
+}
+
+TEST(EvalPipeline, PipelineHasCorrectSequence) {
+ IdentityStage input_stage("input_stage", "input_stage_out");
+ IdentityStage preprocess_stage("preprocess_stage", "preprocess_stage_out");
+ IdentityStage run_model_stage("run_model", "run_model_out");
+ const string pipeline_input = "pipeline_input";
+
+ SimpleAccuracyEval eval;
+
+ Scope scope = Scope::NewRootScope();
+ std::unique_ptr<EvalPipeline> eval_pipeline;
+ EvalPipelineBuilder builder;
+ auto status = builder.WithInputStage(&input_stage)
+ .WithRunModelStage(&run_model_stage)
+ .WithPreprocessingStage(&preprocess_stage)
+ .WithInput(pipeline_input, DT_FLOAT)
+ .WithAccuracyEval(&eval)
+ .Build(scope, &eval_pipeline);
+ TF_CHECK_OK(status);
+
+ ASSERT_EQ(1, input_stage.times_called());
+ ASSERT_EQ(1, run_model_stage.times_called());
+ ASSERT_EQ(1, preprocess_stage.times_called());
+
+ EXPECT_EQ(pipeline_input, input_stage.input_params()[0]);
+ EXPECT_EQ(input_stage.output_name(), preprocess_stage.input_params()[0]);
+ EXPECT_EQ(preprocess_stage.output_name(), run_model_stage.input_params()[0]);
+}
+
+} // namespace
+
+} // namespace metrics
+} // namespace tensorflow
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_test.cc b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_test.cc
new file mode 100644
index 0000000000..ea0f6e19df
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_test.cc
@@ -0,0 +1,133 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h"
+#include <gtest/gtest.h>
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tensorflow {
+namespace metrics {
+namespace {
+
+Tensor CreateFloatTensor(float value) {
+ Tensor tensor(DT_FLOAT, TensorShape({}));
+ tensor.scalar<float>()() = value;
+ return tensor;
+}
+
+class NoOpAccuracyEval : public AccuracyEval {
+ public:
+ explicit NoOpAccuracyEval(const Status& status_to_return)
+ : status_to_return_(status_to_return) {}
+
+ Status ComputeEval(const std::vector<Tensor>& model_outputs,
+ const Tensor& ground_truth) override {
+ model_outputs_ = model_outputs;
+ ground_truth_ = ground_truth;
+ was_called_ = true;
+ return status_to_return_;
+ }
+
+ bool WasCalled() { return was_called_; }
+ std::vector<Tensor> model_outputs() { return model_outputs_; }
+ Tensor ground_truth() { return ground_truth_; }
+
+ private:
+ std::vector<Tensor> model_outputs_;
+ Tensor ground_truth_;
+ Status status_to_return_;
+ bool was_called_ = false;
+};
+
+TEST(EvalPipeline, AccuracyEvalIsCalled) {
+ Scope scope = Scope::NewRootScope();
+ // A graph that adds 1 to input.
+ auto input = ops::Placeholder(scope.WithOpName("input"), DT_FLOAT);
+ auto add_node = ops::Add(scope.WithOpName("output"), input, 1.0f);
+ GraphDef graph_def;
+ TF_CHECK_OK(scope.ToGraphDef(&graph_def));
+ EvalPipeline::Params params;
+ params.model_input_node_name = "input";
+ params.model_output_node_name = "output";
+ NoOpAccuracyEval accuracy_eval(Status::OK());
+
+ EvalPipeline eval_pipeline(graph_def, params, &accuracy_eval);
+ std::unique_ptr<Session> session(NewSession(SessionOptions()));
+ TF_CHECK_OK(eval_pipeline.AttachSession(std::move(session)));
+ TF_CHECK_OK(eval_pipeline.Run(CreateFloatTensor(5), CreateFloatTensor(27)));
+
+ EXPECT_TRUE(accuracy_eval.WasCalled());
+ auto outputs = accuracy_eval.model_outputs();
+ ASSERT_EQ(1, outputs.size());
+ EXPECT_EQ(6.0f, outputs[0].scalar<float>()());
+ // Ground truth is unchanged.
+ EXPECT_EQ(27, accuracy_eval.ground_truth().scalar<float>()());
+}
+
+TEST(EvalPipeline, EvalIsNotCalledOnGraphRunFailure) {
+ Scope scope = Scope::NewRootScope();
+ // A graph that adds 1 to input.
+ auto input = ops::Placeholder(scope.WithOpName("input"), DT_FLOAT);
+ auto add_node = ops::Add(scope.WithOpName("output"), input, 1.0f);
+ GraphDef graph_def;
+ TF_CHECK_OK(scope.ToGraphDef(&graph_def));
+ EvalPipeline::Params params;
+ params.model_input_node_name = "input";
+ params.model_output_node_name = "output";
+ NoOpAccuracyEval accuracy_eval(Status::OK());
+
+ EvalPipeline eval_pipeline(graph_def, params, &accuracy_eval);
+ std::unique_ptr<Session> session(NewSession(SessionOptions()));
+ TF_CHECK_OK(eval_pipeline.AttachSession(std::move(session)));
+
+ // Pass a string tensor instead of a float tensor.
+ Tensor string_tensor(DT_STRING, TensorShape{});
+ auto status = eval_pipeline.Run(string_tensor, CreateFloatTensor(27));
+ EXPECT_FALSE(accuracy_eval.WasCalled());
+ EXPECT_FALSE(status.ok());
+}
+
+TEST(EvalPipeline, AccuracyEvalFailureResultsInFailure) {
+ Scope scope = Scope::NewRootScope();
+ // A graph that adds 1 to input.
+ auto input = ops::Placeholder(scope.WithOpName("input"), DT_FLOAT);
+ auto add_node = ops::Add(scope.WithOpName("output"), input, 1.0f);
+ GraphDef graph_def;
+ TF_CHECK_OK(scope.ToGraphDef(&graph_def));
+ EvalPipeline::Params params;
+ params.model_input_node_name = "input";
+ params.model_output_node_name = "output";
+ NoOpAccuracyEval accuracy_eval(errors::Internal("accuracy_fail"));
+
+ EvalPipeline eval_pipeline(graph_def, params, &accuracy_eval);
+ std::unique_ptr<Session> session(NewSession(SessionOptions()));
+ TF_CHECK_OK(eval_pipeline.AttachSession(std::move(session)));
+ auto status = eval_pipeline.Run(CreateFloatTensor(5), CreateFloatTensor(27));
+
+ EXPECT_TRUE(accuracy_eval.WasCalled());
+ EXPECT_FALSE(status.ok());
+}
+
+} // namespace
+
+} // namespace metrics
+} // namespace tensorflow
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/tools/accuracy/file_reader_stage.cc b/tensorflow/contrib/lite/tools/accuracy/file_reader_stage.cc
new file mode 100644
index 0000000000..61bed369f8
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/file_reader_stage.cc
@@ -0,0 +1,29 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h"
+
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+
+namespace tensorflow {
+namespace metrics {
+void FileReaderStage::AddToGraph(const Scope& scope, const Input& input) {
+ if (!scope.ok()) return;
+ Scope s = scope.WithOpName(name());
+ this->stage_output_ = ops::ReadFile(s.WithOpName(output_name()), input);
+}
+} // namespace metrics
+} // namespace tensorflow
diff --git a/tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h b/tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h
new file mode 100644
index 0000000000..18db5837c1
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h
@@ -0,0 +1,37 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_FILE_READER_STAGE_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_FILE_READER_STAGE_H_
+
+#include <string>
+
+#include "tensorflow/contrib/lite/tools/accuracy/stage.h"
+
+namespace tensorflow {
+namespace metrics {
+// A stage for reading a file into |string|.
+// Inputs: a string tensor: |file_name|.
+// Outputs: a string tensor: contents of |file_name|.
+class FileReaderStage : public Stage {
+ public:
+ string name() const override { return "stage_filereader"; }
+ string output_name() const override { return "stage_filereader_output"; }
+
+ void AddToGraph(const Scope& scope, const Input& input) override;
+};
+} // namespace metrics
+} // namespace tensorflow
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_FILE_READER_STAGE_H_
diff --git a/tensorflow/contrib/lite/tools/accuracy/file_reader_stage_test.cc b/tensorflow/contrib/lite/tools/accuracy/file_reader_stage_test.cc
new file mode 100644
index 0000000000..a75f99187d
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/file_reader_stage_test.cc
@@ -0,0 +1,110 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdio>
+#include <fstream>
+#include <memory>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tensorflow {
+namespace metrics {
+namespace {
+
+class TempFile {
+ public:
+ TempFile() {
+ string file_path;
+ if (Env::Default()->LocalTempFilename(&file_path)) {
+ file_path_ = file_path;
+ created_ = true;
+ }
+ }
+
+ string filepath() { return file_path_; }
+ bool CreateFileWithContents(const std::string& contents) {
+ if (!created_) {
+ return false;
+ }
+ std::fstream file(file_path_, std::ios_base::out);
+ if (file) {
+ file << contents;
+ }
+ return file.good();
+ }
+
+ ~TempFile() {
+ if (created_) {
+ std::remove(file_path_.c_str());
+ }
+ }
+
+ private:
+ bool created_ = false;
+ string file_path_;
+};
+
+TEST(FileReaderStageTest, FileIsRead) {
+ TempFile file;
+ const string kFileContents = "Hello world.";
+ ASSERT_TRUE(file.CreateFileWithContents(kFileContents));
+ Scope scope = Scope::NewRootScope();
+ FileReaderStage reader_stage;
+ reader_stage.AddToGraph(scope, file.filepath());
+ TF_CHECK_OK(scope.status());
+ GraphDef graph_def;
+ TF_CHECK_OK(scope.ToGraphDef(&graph_def));
+ std::unique_ptr<Session> session(NewSession(SessionOptions()));
+ TF_CHECK_OK(session->Create(graph_def));
+ std::vector<Tensor> outputs;
+ auto run_status =
+ session->Run({}, /*inputs*/
+ {reader_stage.output_name()}, {}, /*target node names */
+ &outputs);
+ TF_CHECK_OK(run_status);
+ EXPECT_EQ(1, outputs.size());
+ string contents = outputs[0].scalar<string>()();
+ EXPECT_EQ(kFileContents, contents);
+}
+
+TEST(FileReaderStageTest, InvalidFile) {
+ Scope scope = Scope::NewRootScope();
+ FileReaderStage reader_stage;
+ reader_stage.AddToGraph(scope, string("non_existent_file"));
+ TF_CHECK_OK(scope.status());
+ GraphDef graph_def;
+ TF_CHECK_OK(scope.ToGraphDef(&graph_def));
+ std::unique_ptr<Session> session(NewSession(SessionOptions()));
+ TF_CHECK_OK(session->Create(graph_def));
+ std::vector<Tensor> outputs;
+ auto run_status =
+ session->Run({}, /*inputs*/
+ {reader_stage.output_name()}, {}, /*target node names */
+ &outputs);
+ EXPECT_FALSE(run_status.ok());
+}
+
+} // namespace
+
+} // namespace metrics
+} // namespace tensorflow
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/tools/accuracy/imagenet_accuracy_eval.cc b/tensorflow/contrib/lite/tools/accuracy/imagenet_accuracy_eval.cc
new file mode 100644
index 0000000000..8103d6adb5
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/imagenet_accuracy_eval.cc
@@ -0,0 +1,146 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <iomanip>
+#include <memory>
+
+#include "absl/memory/memory.h"
+#include "tensorflow/contrib/lite/tools/accuracy/csv_writer.h"
+#include "tensorflow/contrib/lite/tools/accuracy/imagenet_model_evaluator.h"
+#include "tensorflow/contrib/lite/tools/accuracy/imagenet_topk_eval.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace tensorflow {
+namespace metrics {
+
+namespace {
+
+std::vector<double> GetAccuracies(
+ const ImagenetTopKAccuracy::AccuracyStats& accuracy_stats) {
+ std::vector<double> results;
+ results.reserve(accuracy_stats.number_of_images);
+ if (accuracy_stats.number_of_images > 0) {
+ for (int n : accuracy_stats.topk_counts) {
+ double accuracy = 0;
+ if (accuracy_stats.number_of_images > 0) {
+ accuracy = (n * 100.0) / accuracy_stats.number_of_images;
+ }
+ results.push_back(accuracy);
+ }
+ }
+ return results;
+}
+
+} // namespace
+
+// Writes results to a CSV file.
+class ResultsWriter : public ImagenetModelEvaluator::Observer {
+ public:
+ explicit ResultsWriter(std::unique_ptr<CSVWriter> writer)
+ : writer_(std::move(writer)) {}
+
+ void OnEvaluationStart(int total_number_of_images) override {}
+
+ void OnSingleImageEvaluationComplete(
+ const ImagenetTopKAccuracy::AccuracyStats& stats,
+ const string& image) override;
+
+ private:
+ std::unique_ptr<CSVWriter> writer_;
+};
+
+void ResultsWriter::OnSingleImageEvaluationComplete(
+ const ImagenetTopKAccuracy::AccuracyStats& stats, const string& image) {
+ TF_CHECK_OK(writer_->WriteRow(GetAccuracies(stats)));
+ writer_->Flush();
+}
+
+// Logs results to standard output with `kLogDelayUs` microseconds.
+class ResultsLogger : public ImagenetModelEvaluator::Observer {
+ public:
+ void OnEvaluationStart(int total_number_of_images) override;
+
+ void OnSingleImageEvaluationComplete(
+ const ImagenetTopKAccuracy::AccuracyStats& stats,
+ const string& image) override;
+
+ private:
+ int total_num_images_ = 0;
+ uint64 last_logged_time_us_ = 0;
+ static constexpr int kLogDelayUs = 500 * 1000;
+};
+
+void ResultsLogger::OnEvaluationStart(int total_number_of_images) {
+ total_num_images_ = total_number_of_images;
+ LOG(ERROR) << "Starting model evaluation: " << total_num_images_;
+}
+
+void ResultsLogger::OnSingleImageEvaluationComplete(
+ const ImagenetTopKAccuracy::AccuracyStats& stats, const string& image) {
+ int num_evaluated = stats.number_of_images;
+
+ double current_percent = num_evaluated * 100.0 / total_num_images_;
+ auto now_us = Env::Default()->NowMicros();
+
+ if ((now_us - last_logged_time_us_) >= kLogDelayUs) {
+ last_logged_time_us_ = now_us;
+
+ LOG(ERROR) << "Evaluated " << num_evaluated << "/" << total_num_images_
+ << " images, " << std::setprecision(2) << std::fixed
+ << current_percent << "%";
+ }
+}
+
+int Main(int argc, char* argv[]) {
+ // TODO(shashishekhar): Make this binary configurable and model
+ // agnostic.
+ string output_file_path;
+ std::vector<Flag> flag_list = {
+ Flag("output_file_path", &output_file_path, "Path to output file."),
+ };
+ Flags::Parse(&argc, argv, flag_list);
+
+ std::unique_ptr<ImagenetModelEvaluator> evaluator;
+ CHECK(!output_file_path.empty()) << "Invalid output file path.";
+
+ TF_CHECK_OK(ImagenetModelEvaluator::Create(argc, argv, &evaluator));
+
+ std::ofstream output_stream(output_file_path, std::ios::out);
+ CHECK(output_stream) << "Unable to open output file path: '"
+ << output_file_path << "'";
+
+ output_stream << std::setprecision(3) << std::fixed;
+ std::vector<string> columns;
+ columns.reserve(evaluator->params().num_ranks);
+ for (int i = 0; i < evaluator->params().num_ranks; i++) {
+ columns.push_back("Top " + std::to_string(i + 1));
+ }
+
+ ResultsWriter results_writer(
+ absl::make_unique<CSVWriter>(columns, &output_stream));
+ ResultsLogger logger;
+ evaluator->AddObserver(&results_writer);
+ evaluator->AddObserver(&logger);
+ TF_CHECK_OK(evaluator->EvaluateModel());
+ return 0;
+}
+
+} // namespace metrics
+} // namespace tensorflow
+
+int main(int argc, char* argv[]) {
+ return tensorflow::metrics::Main(argc, argv);
+}
diff --git a/tensorflow/contrib/lite/tools/accuracy/imagenet_model_evaluator.cc b/tensorflow/contrib/lite/tools/accuracy/imagenet_model_evaluator.cc
new file mode 100644
index 0000000000..6ddde8e7c0
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/imagenet_model_evaluator.cc
@@ -0,0 +1,206 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/accuracy/imagenet_model_evaluator.h"
+
+#include <fstream>
+#include <iomanip>
+#include <string>
+#include <vector>
+
+#include "absl/memory/memory.h"
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h"
+#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h"
+#include "tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h"
+#include "tensorflow/contrib/lite/tools/accuracy/imagenet_topk_eval.h"
+#include "tensorflow/contrib/lite/tools/accuracy/inception_preprocessing.h"
+#include "tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.h"
+#include "tensorflow/contrib/lite/tools/accuracy/utils.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace {
+using tensorflow::string;
+
+string StripTrailingSlashes(const string& path) {
+ int end = path.size();
+ while (end > 0 && path[end - 1] == '/') {
+ end--;
+ }
+ return path.substr(0, end);
+}
+
+tensorflow::Tensor CreateStringTensor(const string& value) {
+ tensorflow::Tensor tensor(tensorflow::DT_STRING, tensorflow::TensorShape({}));
+ tensor.scalar<string>()() = value;
+ return tensor;
+}
+
+template <typename T>
+std::vector<T> GetFirstN(const std::vector<T>& v, int n) {
+ if (n >= v.size()) return v;
+ std::vector<T> result(v.begin(), v.begin() + n);
+ return result;
+}
+
+// File pattern for imagenet files.
+const char* const kImagenetFilePattern = "*.[jJ][pP][eE][gG]";
+
+} // namespace
+
+namespace tensorflow {
+namespace metrics {
+
+/*static*/ Status ImagenetModelEvaluator::Create(
+ int argc, char* argv[],
+ std::unique_ptr<ImagenetModelEvaluator>* model_evaluator) {
+ Params params;
+ const std::vector<Flag> flag_list = {
+ Flag("model_output_labels", &params.model_output_labels_path,
+ "Path to labels that correspond to output of model."
+ " E.g. in case of mobilenet, this is the path to label "
+ "file where each label is in the same order as the output"
+ " of the model."),
+ Flag("ground_truth_images_path", &params.ground_truth_images_path,
+ "Path to ground truth images."),
+ Flag("ground_truth_labels", &params.ground_truth_labels_path,
+ "Path to ground truth labels."),
+ Flag("num_images", &params.number_of_images,
+ "Number of examples to evaluate, pass 0 for all "
+ "examples. Default: 100"),
+ tensorflow::Flag("model_file", &params.model_file_path,
+ "Path to test tflite model file."),
+ };
+ const bool parse_result = Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result)
+ return errors::InvalidArgument("Invalid command line flags");
+ ::tensorflow::port::InitMain(argv[0], &argc, &argv);
+
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ Env::Default()->IsDirectory(params.ground_truth_images_path),
+ "Invalid ground truth data path.");
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ Env::Default()->FileExists(params.ground_truth_labels_path),
+ "Invalid ground truth labels path.");
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ Env::Default()->FileExists(params.model_output_labels_path),
+ "Invalid model output labels path.");
+
+ if (params.number_of_images < 0) {
+ return errors::InvalidArgument("Invalid: num_examples");
+ }
+
+ utils::ModelInfo model_info;
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ utils::GetTFliteModelInfo(params.model_file_path, &model_info),
+ "Invalid TFLite model.");
+
+ *model_evaluator =
+ absl::make_unique<ImagenetModelEvaluator>(model_info, params);
+ return Status::OK();
+}
+
+Status ImagenetModelEvaluator::EvaluateModel() {
+ if (model_info_.input_shapes.size() != 1) {
+ return errors::InvalidArgument("Invalid input shape");
+ }
+
+ const TensorShape& input_shape = model_info_.input_shapes[0];
+ // Input should be of the shape {1, height, width, 3}
+ if (input_shape.dims() != 4 || input_shape.dim_size(3) != 3) {
+ return errors::InvalidArgument("Invalid input shape for the model.");
+ }
+
+ const int image_height = input_shape.dim_size(1);
+ const int image_width = input_shape.dim_size(2);
+ const bool is_quantized = (model_info_.input_types[0] == DT_UINT8);
+
+ RunTFLiteModelStage::Params tfl_model_params;
+ tfl_model_params.model_file_path = params_.model_file_path;
+ if (is_quantized) {
+ tfl_model_params.input_type = {DT_UINT8};
+ tfl_model_params.output_type = {DT_UINT8};
+ } else {
+ tfl_model_params.input_type = {DT_FLOAT};
+ tfl_model_params.output_type = {DT_FLOAT};
+ }
+
+ Scope root = Scope::NewRootScope();
+ FileReaderStage reader;
+ InceptionPreprocessingStage inc(image_height, image_width, is_quantized);
+ RunTFLiteModelStage tfl_model_stage(tfl_model_params);
+ EvalPipelineBuilder builder;
+ std::vector<string> model_labels;
+ TF_RETURN_IF_ERROR(
+ utils::ReadFileLines(params_.model_output_labels_path, &model_labels));
+ if (model_labels.size() != 1001) {
+ return errors::InvalidArgument("Invalid number of labels: ",
+ model_labels.size());
+ }
+
+ ImagenetTopKAccuracy eval(model_labels, params_.num_ranks);
+ std::unique_ptr<EvalPipeline> eval_pipeline;
+
+ auto build_status = builder.WithInputStage(&reader)
+ .WithPreprocessingStage(&inc)
+ .WithRunModelStage(&tfl_model_stage)
+ .WithAccuracyEval(&eval)
+ .WithInput("input_file", DT_STRING)
+ .Build(root, &eval_pipeline);
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(build_status,
+ "Failure while building eval pipeline.");
+
+ std::unique_ptr<Session> session(NewSession(SessionOptions()));
+
+ TF_RETURN_IF_ERROR(eval_pipeline->AttachSession(std::move(session)));
+ string data_path =
+ StripTrailingSlashes(params_.ground_truth_images_path) + "/";
+
+ const string imagenet_file_pattern = data_path + kImagenetFilePattern;
+ std::vector<string> image_files;
+ TF_CHECK_OK(
+ Env::Default()->GetMatchingPaths(imagenet_file_pattern, &image_files));
+ std::vector<string> image_labels;
+ TF_CHECK_OK(
+ utils::ReadFileLines(params_.ground_truth_labels_path, &image_labels));
+ CHECK_EQ(image_files.size(), image_labels.size());
+
+ // Process files in filename sorted order.
+ std::sort(image_files.begin(), image_files.end());
+ if (params_.number_of_images > 0) {
+ image_files = GetFirstN(image_files, params_.number_of_images);
+ image_labels = GetFirstN(image_labels, params_.number_of_images);
+ }
+
+ for (Observer* observer : observers_) {
+ observer->OnEvaluationStart(image_files.size());
+ }
+
+ for (int i = 0; i < image_files.size(); i++) {
+ TF_CHECK_OK(eval_pipeline->Run(CreateStringTensor(image_files[i]),
+ CreateStringTensor(image_labels[i])));
+ auto stats = eval.GetTopKAccuracySoFar();
+
+ for (Observer* observer : observers_) {
+ observer->OnSingleImageEvaluationComplete(stats, image_files[i]);
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace metrics
+} // namespace tensorflow
diff --git a/tensorflow/contrib/lite/tools/accuracy/imagenet_model_evaluator.h b/tensorflow/contrib/lite/tools/accuracy/imagenet_model_evaluator.h
new file mode 100644
index 0000000000..0308ac95b6
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/imagenet_model_evaluator.h
@@ -0,0 +1,113 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_IMAGENET_MODEL_EVALUATOR_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_IMAGENET_MODEL_EVALUATOR_H_
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/tools/accuracy/imagenet_topk_eval.h"
+#include "tensorflow/contrib/lite/tools/accuracy/utils.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace metrics {
+
+// Evaluates models accuracy for ILSVRC dataset.
+//
+// Generates the top-1, top-k accuracy counts where k is
+// controlled by |num_ranks|.
+// Usage:
+// ModelInfo model_info = ..
+// ImagenetModelEvaluator::Params params;
+// .. set params to image, label, output label and model file path..
+// SomeObserver observer;
+// ImagenetModelEvaluator evaluator(model_info, params);
+// evaluator.AddObserver(&observer);
+// TF_CHECK_OK(evaluator.EvaluateModel());
+class ImagenetModelEvaluator {
+ public:
+ struct Params {
+ // Path to ground truth images.
+ string ground_truth_images_path;
+
+ // Path to labels file for ground truth image.
+ // This file should be generated with the scripts.
+ string ground_truth_labels_path;
+
+ // This is word labels generated by the model. The category
+ // indices of output probabilities generated by the model maybe different
+ // from the indices in the imagenet dataset.
+ string model_output_labels_path;
+
+ // Path to the model file.
+ string model_file_path;
+
+ // The maximum number of images to calculate accuracy.
+ // 0 means all images, a positive number means only the specified
+ // number of images.
+ int number_of_images = 0;
+
+ // Number of ranks, top K.
+ int num_ranks = 10;
+ };
+
+ // An evaluation observer.
+ class Observer {
+ public:
+ Observer() = default;
+ Observer(const Observer&) = delete;
+ Observer& operator=(const Observer&) = delete;
+
+ Observer(const Observer&&) = delete;
+ Observer& operator=(const Observer&&) = delete;
+
+ // Called on start of evaluation.
+ virtual void OnEvaluationStart(int total_number_of_images) = 0;
+
+ // Called when evaluation was complete for `image`.
+ virtual void OnSingleImageEvaluationComplete(
+ const ImagenetTopKAccuracy::AccuracyStats& stats,
+ const string& image) = 0;
+
+ virtual ~Observer() = default;
+ };
+
+ ImagenetModelEvaluator(const utils::ModelInfo& model_info,
+ const Params& params)
+ : model_info_(model_info), params_(params) {}
+
+ // Factory method to create the evaluator by parsing command line arguments.
+ static Status Create(int argc, char* argv[],
+ std::unique_ptr<ImagenetModelEvaluator>* evaluator);
+
+ // Adds an observer that can observe evaluation events..
+ void AddObserver(Observer* observer) { observers_.push_back(observer); }
+
+ const Params& params() { return params_; }
+
+ // Evaluates the provided model over the dataset.
+ Status EvaluateModel();
+
+ private:
+ std::vector<Observer*> observers_;
+ const utils::ModelInfo model_info_;
+ const Params params_;
+};
+
+} // namespace metrics
+} // namespace tensorflow
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_IMAGENET_MODEL_EVALUATOR_H_
diff --git a/tensorflow/contrib/lite/tools/accuracy/imagenet_topk_eval.cc b/tensorflow/contrib/lite/tools/accuracy/imagenet_topk_eval.cc
new file mode 100644
index 0000000000..1595bbee2f
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/imagenet_topk_eval.cc
@@ -0,0 +1,107 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/accuracy/imagenet_topk_eval.h"
+
+#include <numeric>
+
+namespace {
+constexpr int kNumCategories = 1001;
+std::vector<int> GetTopK(const std::vector<float>& values, int k) {
+ CHECK_LE(k, values.size());
+ std::vector<int> indices(values.size());
+
+ std::iota(indices.begin(), indices.end(), 0);
+ std::sort(indices.begin(), indices.end(),
+ [&values](int a, int b) { return values[a] > values[b]; });
+
+ indices.resize(k);
+ return indices;
+}
+} // namespace
+
+namespace tensorflow {
+namespace metrics {
+ImagenetTopKAccuracy::ImagenetTopKAccuracy(
+ const std::vector<string>& ground_truth_labels, int k)
+ : ground_truth_labels_(ground_truth_labels),
+ k_(k),
+ accuracy_counts_(k_, 0),
+ num_samples_(0) {
+ CHECK_EQ(kNumCategories, ground_truth_labels.size());
+}
+
+Status ImagenetTopKAccuracy::ComputeEval(
+ const std::vector<Tensor>& model_outputs, const Tensor& ground_truth) {
+ if (model_outputs.size() != 1) {
+ return errors::InvalidArgument("Invalid model output: ",
+ model_outputs.size());
+ }
+ const Tensor& output = model_outputs[0];
+ if (!output.shape().IsSameSize({1, kNumCategories})) {
+ return errors::InvalidArgument("Invalid shape of model output: ",
+ output.shape().DebugString());
+ }
+ if (ground_truth.dtype() != DT_STRING && ground_truth.dims() != 0) {
+ return errors::InvalidArgument("Invalid ground truth type: ",
+ ground_truth.DebugString());
+ }
+ string ground_truth_label = ground_truth.scalar<string>()();
+
+ std::vector<float> probabilities;
+ probabilities.reserve(kNumCategories);
+ if (output.dtype() == DT_FLOAT) {
+ auto probs = output.flat<float>();
+ for (size_t i = 0; i < probs.size(); i++) {
+ probabilities.push_back(probs(i));
+ }
+ } else {
+ auto probs = output.flat<uint8>();
+ for (size_t i = 0; i < probs.size(); i++) {
+ probabilities.push_back(probs(i));
+ }
+ }
+
+ CHECK_EQ(kNumCategories, probabilities.size());
+ std::vector<int> topK = GetTopK(probabilities, k_);
+ int ground_truth_index = GroundTruthIndex(ground_truth_label);
+ for (size_t i = 0; i < topK.size(); ++i) {
+ if (ground_truth_index == topK[i]) {
+ for (size_t j = i; j < topK.size(); j++) {
+ accuracy_counts_[j] += 1;
+ }
+ break;
+ }
+ }
+ num_samples_++;
+ return Status::OK();
+}
+
+const ImagenetTopKAccuracy::AccuracyStats
+ImagenetTopKAccuracy::GetTopKAccuracySoFar() const {
+ AccuracyStats stats;
+ stats.number_of_images = num_samples_;
+ stats.topk_counts = accuracy_counts_;
+ return stats;
+}
+
+int ImagenetTopKAccuracy::GroundTruthIndex(const string& label) const {
+ auto index = std::find(ground_truth_labels_.cbegin(),
+ ground_truth_labels_.cend(), label);
+ CHECK(index != ground_truth_labels_.end()) << "Invalid label: " << label;
+ return std::distance(ground_truth_labels_.cbegin(), index);
+}
+} // namespace metrics
+} // namespace tensorflow
diff --git a/tensorflow/contrib/lite/tools/accuracy/imagenet_topk_eval.h b/tensorflow/contrib/lite/tools/accuracy/imagenet_topk_eval.h
new file mode 100644
index 0000000000..5a575ff244
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/imagenet_topk_eval.h
@@ -0,0 +1,80 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_IMAGENET_TOPK_EVAL_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_IMAGENET_TOPK_EVAL_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h"
+#include "tensorflow/core/framework/tensor.h"
+
+namespace tensorflow {
+namespace metrics {
+// An |AccuracyEval| stage that calculates the top K error rate for model
+// evaluations on imagenet like datasets.
+// Inputs: A {1, 1001} shaped tensor that contains the probabilities for objects
+// predicted by the model.
+// Ground truth: A |string| label for the image.
+// From the input object probabilities, the stage computes the predicted labels
+// and finds the top K error rates by comparing the predictions with ground
+// truths.
+class ImagenetTopKAccuracy : public AccuracyEval {
+ public:
+ // Accuracy statistics.
+ struct AccuracyStats {
+ // Number of images evaluated.
+ int number_of_images;
+ // A vector of size |k| that contains the number of images
+ // that have correct labels in top K.
+ // E.g. topk_counts[0] contains number of images for which
+ // model returned the correct label as the first result.
+ // Similarly topk_counts[4] contains the number of images for which
+ // model returned the correct label in top 5 results.
+ // This can be used to compute the top K error-rate for the model.
+ std::vector<int> topk_counts;
+ };
+
+ // Creates a new instance of |ImagenetTopKAccuracy| with the given
+ // |ground_truth_labels| and |k|.
+ // Args:
+ // |ground_truth_labels| : an ordered vector of labels for images. This is
+ // used to compute the index for the predicted labels and ground_truth label.
+ ImagenetTopKAccuracy(const std::vector<string>& ground_truth_labels, int k);
+
+ // Computes accuracy for a given image. The |model_outputs| should
+ // be a vector containing exactly one Tensor of shape: {1, 1001} where each
+ // item is a probability of the predicted object representing the image as
+ // output by the model.
+ // Uses |ground_truth_labels| to compute the index of |model_outputs| and
+ // |ground_truth| and computes the top K error rate.
+ Status ComputeEval(const std::vector<Tensor>& model_outputs,
+ const Tensor& ground_truth) override;
+
+ // Gets the topK accuracy for images that have been evaluated till now.
+ const AccuracyStats GetTopKAccuracySoFar() const;
+
+ private:
+ int GroundTruthIndex(const string& label) const;
+ std::vector<string> ground_truth_labels_;
+ const int k_;
+ std::vector<int> accuracy_counts_;
+ int num_samples_;
+};
+} // namespace metrics
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_IMAGENET_TOPK_EVAL_H_
diff --git a/tensorflow/contrib/lite/tools/accuracy/imagenet_topk_eval_test.cc b/tensorflow/contrib/lite/tools/accuracy/imagenet_topk_eval_test.cc
new file mode 100644
index 0000000000..256cd1d529
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/imagenet_topk_eval_test.cc
@@ -0,0 +1,149 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/accuracy/imagenet_topk_eval.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace metrics {
+namespace {
+
+const int kNumCategories = 1001;
+
+Tensor CreateStringTensor(const string& value) {
+ Tensor tensor(DT_STRING, TensorShape({}));
+ tensor.scalar<string>()() = value;
+ return tensor;
+}
+
+Tensor CreateOutputTensor() {
+ Tensor tensor(DT_FLOAT, TensorShape({1, kNumCategories}));
+ for (int i = 0; i < kNumCategories; i++) {
+ tensor.flat<float>()(i) = 0;
+ }
+ return tensor;
+}
+
+std::vector<string> CreateGroundTruth() {
+ std::vector<string> ground_truth;
+ ground_truth.reserve(kNumCategories);
+ for (int i = 0; i < kNumCategories; i++) {
+ ground_truth.push_back(std::to_string(i));
+ }
+ return ground_truth;
+}
+
+TEST(ImagenetTopKAccuracy, AllCorrect) {
+ ImagenetTopKAccuracy acc_top_5(CreateGroundTruth(), 5);
+ auto accuracies = acc_top_5.GetTopKAccuracySoFar();
+ EXPECT_EQ(0, accuracies.number_of_images);
+ EXPECT_EQ(5, accuracies.topk_counts.size());
+
+ for (int i : accuracies.topk_counts) {
+ EXPECT_EQ(0, i);
+ }
+ // First image was correctly identified as "0".
+ Tensor tensor = CreateOutputTensor();
+ tensor.flat<float>()(0) = 0.8;
+
+ TF_CHECK_OK(acc_top_5.ComputeEval({tensor}, CreateStringTensor("0")));
+ accuracies = acc_top_5.GetTopKAccuracySoFar();
+ EXPECT_EQ(1, accuracies.number_of_images);
+
+ for (int i : accuracies.topk_counts) {
+ EXPECT_EQ(1, i);
+ }
+ tensor.flat<float>()(1) = 0.9;
+ TF_CHECK_OK(acc_top_5.ComputeEval({tensor}, CreateStringTensor("1")));
+ accuracies = acc_top_5.GetTopKAccuracySoFar();
+ EXPECT_EQ(2, accuracies.number_of_images);
+
+ for (int i : accuracies.topk_counts) {
+ EXPECT_EQ(2, i);
+ }
+}
+
+TEST(ImagenetTopKAccuracy, Top5) {
+ ImagenetTopKAccuracy acc_top_5(CreateGroundTruth(), 5);
+ auto accuracies = acc_top_5.GetTopKAccuracySoFar();
+ EXPECT_EQ(0, accuracies.number_of_images);
+ EXPECT_EQ(5, accuracies.topk_counts.size());
+
+ // For first image, with ground truth "0" probabilities were
+ // 0.5 for "0",
+ // "0.6" for 1,
+ // "0.7" for 2,
+ // "0.8" for 3,
+ // "0.9" for 4.
+ // remaining all zeroes.
+
+ // First image was correctly identified as "0".
+ Tensor tensor = CreateOutputTensor();
+ tensor.flat<float>()(0) = 0.5;
+ tensor.flat<float>()(1) = 0.6;
+ tensor.flat<float>()(2) = 0.7;
+ tensor.flat<float>()(3) = 0.8;
+ tensor.flat<float>()(4) = 0.9;
+
+ TF_CHECK_OK(acc_top_5.ComputeEval({tensor}, CreateStringTensor("0")));
+ accuracies = acc_top_5.GetTopKAccuracySoFar();
+ EXPECT_EQ(1, accuracies.number_of_images);
+ EXPECT_EQ(1, accuracies.topk_counts[4]);
+
+ for (int i = 0; i < 4; i++) {
+ EXPECT_EQ(0, accuracies.topk_counts[i]);
+ }
+
+ // Now for "1" only last two buckets are going to be affected.
+ TF_CHECK_OK(acc_top_5.ComputeEval({tensor}, CreateStringTensor("1")));
+ accuracies = acc_top_5.GetTopKAccuracySoFar();
+ EXPECT_EQ(2, accuracies.number_of_images);
+ EXPECT_EQ(1, accuracies.topk_counts[3]);
+ EXPECT_EQ(2, accuracies.topk_counts[4]);
+ for (int i = 0; i < 3; i++) {
+ EXPECT_EQ(0, accuracies.topk_counts[i]);
+ }
+
+ // All buckets will be affected.
+ TF_CHECK_OK(acc_top_5.ComputeEval({tensor}, CreateStringTensor("4")));
+ accuracies = acc_top_5.GetTopKAccuracySoFar();
+ EXPECT_EQ(3, accuracies.number_of_images);
+ EXPECT_EQ(1, accuracies.topk_counts[0]);
+ EXPECT_EQ(1, accuracies.topk_counts[1]);
+ EXPECT_EQ(1, accuracies.topk_counts[2]);
+ EXPECT_EQ(2, accuracies.topk_counts[3]);
+ EXPECT_EQ(3, accuracies.topk_counts[4]);
+
+ // No buckets will be affected
+ TF_CHECK_OK(acc_top_5.ComputeEval({tensor}, CreateStringTensor("10")));
+ accuracies = acc_top_5.GetTopKAccuracySoFar();
+ EXPECT_EQ(4, accuracies.number_of_images);
+ EXPECT_EQ(1, accuracies.topk_counts[0]);
+ EXPECT_EQ(1, accuracies.topk_counts[1]);
+ EXPECT_EQ(1, accuracies.topk_counts[2]);
+ EXPECT_EQ(2, accuracies.topk_counts[3]);
+ EXPECT_EQ(3, accuracies.topk_counts[4]);
+}
+
+} // namespace
+
+} // namespace metrics
+} // namespace tensorflow
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/tools/accuracy/inception_preprocessing.cc b/tensorflow/contrib/lite/tools/accuracy/inception_preprocessing.cc
new file mode 100644
index 0000000000..7afef88637
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/inception_preprocessing.cc
@@ -0,0 +1,80 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/accuracy/inception_preprocessing.h"
+
+#include <memory>
+
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tensorflow {
+namespace metrics {
+
+namespace {
+void CentralCropImage(const Scope& s, const tensorflow::Output& decoded_image,
+ double crop_fraction, tensorflow::Output* cropped_image) {
+ auto image_dims = ops::Slice(s, ops::Shape(s, decoded_image), {0}, {2});
+ auto height_width = ops::Cast(s, image_dims, DT_DOUBLE);
+ auto cropped_begin = ops::Div(
+ s, ops::Sub(s, height_width, ops::Mul(s, height_width, crop_fraction)),
+ 2.0);
+ auto bbox_begin = ops::Cast(s, cropped_begin, DT_INT32);
+ auto bbox_size = ops::Sub(s, image_dims, ops::Mul(s, bbox_begin, 2));
+ auto slice_begin = ops::Concat(s, {bbox_begin, Input({0})}, 0);
+ auto slice_size = ops::Concat(s, {bbox_size, {-1}}, 0);
+ *cropped_image = ops::Slice(s, decoded_image, slice_begin, slice_size);
+}
+
+} // namespace
+
+void InceptionPreprocessingStage::AddToGraph(const Scope& scope,
+ const Input& input) {
+ if (!scope.ok()) return;
+ Scope s = scope.WithOpName(name());
+ ops::DecodeJpeg::Attrs attrs;
+ attrs.channels_ = 3;
+ auto decoded_jpeg = ops::DecodeJpeg(s, input, attrs);
+ tensorflow::Output cropped_image;
+ CentralCropImage(s, decoded_jpeg, params_.cropping_fraction, &cropped_image);
+ auto dims_expander = ops::ExpandDims(s, cropped_image, 0);
+ auto resized_image = ops::ResizeBilinear(
+ s, dims_expander,
+ ops::Const(s.WithOpName("size"), {image_height_, image_width_}));
+ if (is_quantized_) {
+ this->stage_output_ =
+ ops::Cast(s.WithOpName(output_name()), resized_image, DT_UINT8);
+ } else {
+ auto squeezed_image = ops::Squeeze(s, resized_image);
+ auto normalized_image =
+ ops::Div(s,
+ ops::Sub(s, squeezed_image,
+ {params_.input_means[0], params_.input_means[1],
+ params_.input_means[2]}),
+ {params_.scale});
+ this->stage_output_ =
+ ops::ExpandDims(s.WithOpName(output_name()), normalized_image, {0});
+ }
+}
+
+} // namespace metrics
+} // namespace tensorflow
diff --git a/tensorflow/contrib/lite/tools/accuracy/inception_preprocessing.h b/tensorflow/contrib/lite/tools/accuracy/inception_preprocessing.h
new file mode 100644
index 0000000000..15df719817
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/inception_preprocessing.h
@@ -0,0 +1,75 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_INCEPTION_PREPROCESSING_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_INCEPTION_PREPROCESSING_H_
+
+#include <utility>
+
+#include "tensorflow/contrib/lite/tools/accuracy/stage.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace metrics {
+
+// A stage that does inception preprocessing.
+// Inputs: A tensor containing bytes of a JPEG image.
+// Outputs: A tensor containing rescaled and preprocessed image that has
+// shape {1, image_height, image_width, 3}, where 3 is the number of channels.
+class InceptionPreprocessingStage : public Stage {
+ public:
+ struct Params {
+ std::vector<float> input_means;
+ float scale;
+ double cropping_fraction;
+ };
+
+ static Params DefaultParams() {
+ return {.input_means = {127.5, 127.5, 127.5},
+ .scale = 127.5,
+ .cropping_fraction = 0.875};
+ }
+
+ // Creates a new preprocessing stage object with provided |image_width|
+ // |image_height| as the size of output image.
+ // If |is_quantized| is set to true then |params| is ignored since quantized
+ // images don't go through any preprocessing.
+ InceptionPreprocessingStage(int image_width, int image_height,
+ bool is_quantized,
+ Params params = DefaultParams())
+ : image_width_(image_width),
+ image_height_(image_height),
+ is_quantized_(is_quantized),
+ params_(std::move(params)) {}
+
+ string name() const override { return "stage_inception_preprocess"; }
+ string output_name() const override {
+ return "stage_inception_preprocess_output";
+ }
+
+ void AddToGraph(const Scope& scope, const Input& input) override;
+
+ private:
+ int image_width_;
+ int image_height_;
+ bool is_quantized_;
+ Params params_;
+};
+
+} // namespace metrics
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_INCEPTION_PREPROCESSING_H_
diff --git a/tensorflow/contrib/lite/tools/accuracy/inception_preprocessing_test.cc b/tensorflow/contrib/lite/tools/accuracy/inception_preprocessing_test.cc
new file mode 100644
index 0000000000..db574476f6
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/inception_preprocessing_test.cc
@@ -0,0 +1,123 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <fstream>
+#include <string>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/tools/accuracy/inception_preprocessing.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace {
+tensorflow::string* g_test_image_file = nullptr;
+} // namespace
+
+namespace tensorflow {
+namespace metrics {
+
+namespace {
+
+using tensorflow::Status;
+using tensorflow::Tensor;
+
+Status GetContents(const string& filename, string* output) {
+ std::ifstream input(filename, std::ios::binary);
+ const int kBufferSize = 2048;
+ char buffer[kBufferSize];
+ while (true) {
+ input.read(buffer, kBufferSize);
+ output->append(buffer, input.gcount());
+ if (!input.good()) {
+ if (input.eof()) return Status::OK();
+ return Status(tensorflow::error::ABORTED, "Failed to read file.");
+ }
+ }
+}
+
+TEST(InceptionPreprocessingTest, TestImagePreprocessQuantized) {
+ ASSERT_TRUE(g_test_image_file != nullptr);
+ string image_contents;
+ string image_path = *g_test_image_file;
+ auto status = GetContents(image_path, &image_contents);
+ ASSERT_TRUE(status.ok()) << status.error_message();
+ const int width = 224;
+ const int height = 224;
+ const bool is_quantized = true;
+ InceptionPreprocessingStage preprocess_stage(width, height, is_quantized);
+ Scope scope = Scope::NewRootScope();
+ preprocess_stage.AddToGraph(scope, image_contents);
+ TF_CHECK_OK(scope.status());
+
+ GraphDef graph_def;
+ TF_CHECK_OK(scope.ToGraphDef(&graph_def));
+ std::unique_ptr<Session> session(NewSession(SessionOptions()));
+ TF_CHECK_OK(session->Create(graph_def));
+ std::vector<Tensor> outputs;
+ auto run_status =
+ session->Run({}, /*inputs*/
+ {preprocess_stage.output_name()}, {}, /*target node names */
+ &outputs);
+ TF_CHECK_OK(run_status);
+ EXPECT_EQ(1, outputs.size());
+ EXPECT_EQ(DT_UINT8, outputs[0].dtype());
+ EXPECT_TRUE(outputs[0].shape().IsSameSize({1, 224, 224, 3}));
+}
+
+TEST(InceptionPreprocessingTest, TestImagePreprocessFloat) {
+ ASSERT_TRUE(g_test_image_file != nullptr);
+ string image_contents;
+ string image_path = *g_test_image_file;
+ auto status = GetContents(image_path, &image_contents);
+ ASSERT_TRUE(status.ok()) << status.error_message();
+ const int width = 224;
+ const int height = 224;
+ const bool is_quantized = false;
+ InceptionPreprocessingStage preprocess_stage(width, height, is_quantized);
+ Scope scope = Scope::NewRootScope();
+ preprocess_stage.AddToGraph(scope, image_contents);
+ TF_CHECK_OK(scope.status());
+
+ GraphDef graph_def;
+ TF_CHECK_OK(scope.ToGraphDef(&graph_def));
+ std::unique_ptr<Session> session(NewSession(SessionOptions()));
+ TF_CHECK_OK(session->Create(graph_def));
+ std::vector<Tensor> outputs;
+ auto run_status =
+ session->Run({}, /*inputs*/
+ {preprocess_stage.output_name()}, {}, /*target node names */
+ &outputs);
+ TF_CHECK_OK(run_status);
+ EXPECT_EQ(1, outputs.size());
+ EXPECT_EQ(DT_FLOAT, outputs[0].dtype());
+ EXPECT_TRUE(outputs[0].shape().IsSameSize({1, 224, 224, 3}));
+}
+
+} // namespace
+} // namespace metrics
+} // namespace tensorflow
+
+int main(int argc, char** argv) {
+ g_test_image_file = new tensorflow::string();
+ const std::vector<tensorflow::Flag> flag_list = {
+ tensorflow::Flag("test_image", g_test_image_file,
+ "Path to image file for test."),
+ };
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ CHECK(parse_result) << "Required test_model_file";
+ ::tensorflow::port::InitMain(argv[0], &argc, &argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_op.cc b/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_op.cc
new file mode 100644
index 0000000000..da4258f1c1
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_op.cc
@@ -0,0 +1,158 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/op_resolver.h"
+#include "tensorflow/contrib/lite/tools/accuracy/utils.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+namespace {
+Status ValidateInputsMatch(const OpInputList& input_tensors,
+ const tflite::Interpreter& interpreter) {
+ std::vector<int> tflite_tensor_indices = interpreter.inputs();
+ if (tflite_tensor_indices.size() != input_tensors.size()) {
+ return errors::InvalidArgument(
+ "size mismatch, interpreter size: ", tflite_tensor_indices.size(),
+ " actual: ", input_tensors.size());
+ }
+
+ for (int i = 0; i < input_tensors.size(); i++) {
+ const TfLiteTensor* tflite_tensor =
+ interpreter.tensor(tflite_tensor_indices[i]);
+ if (tflite_tensor == nullptr) {
+ return errors::InvalidArgument("Tensor is null at index: ", i);
+ }
+
+ const Tensor& tensor = input_tensors[i];
+ auto i_type = metrics::utils::GetTFDataType(tflite_tensor->type);
+ auto i_shape = metrics::utils::GetTFLiteTensorShape(*tflite_tensor);
+ if (i_type != tensor.dtype()) {
+ return errors::InvalidArgument("Data types mismatch for tensors: ", i,
+ " expected: ", i_type,
+ " got: ", tensor.dtype());
+ }
+
+ if (i_shape != tensor.shape()) {
+ return errors::InvalidArgument("Data shapes mismatch for tensors: ", i,
+ " expected: ", i_shape,
+ " got: ", tensor.shape());
+ }
+ }
+
+ return Status::OK();
+}
+
+} // namespace
+
+class RunTFLiteModelOp : public OpKernel {
+ public:
+ explicit RunTFLiteModelOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ string model_file_path;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("model_file_path", &model_file_path));
+ model_ = tflite::FlatBufferModel::BuildFromFile(model_file_path.data());
+ OP_REQUIRES(ctx, model_,
+ errors::InvalidArgument(
+ "Model loading failed. Invalid model file path: ",
+ model_file_path));
+ tflite::ops::builtin::BuiltinOpResolver resolver;
+
+ tflite::InterpreterBuilder(*model_, resolver)(&interpreter_);
+ OP_REQUIRES(ctx, interpreter_,
+ errors::Internal("Interpreter creation failed."));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ OpInputList input_tensors;
+ OP_REQUIRES_OK(context, context->input_list("model_input", &input_tensors));
+
+ OP_REQUIRES_OK(context, ValidateInputsMatch(input_tensors, *interpreter_));
+ OpOutputList output_tensors;
+ OP_REQUIRES_OK(context,
+ context->output_list("model_output", &output_tensors));
+ auto tfl_outputs = interpreter_->outputs();
+ OP_REQUIRES(context, output_tensors.size() == tfl_outputs.size(),
+ errors::InvalidArgument(
+ "Invalid output size, expected: ", tfl_outputs.size(),
+ " got: ", output_tensors.size()));
+ for (int i = 0; i < output_tensors.size(); i++) {
+ DataType tfl_type = metrics::utils::GetTFDataType(
+ interpreter_->tensor(tfl_outputs[i])->type);
+ DataType otype = output_tensors.expected_output_dtype(i);
+ OP_REQUIRES(
+ context, tfl_type == otype,
+ errors::InvalidArgument("Invalid data type for output at index: ", i,
+ " expected: ", tfl_type, " got: ", otype));
+ }
+
+ auto allocation_status = interpreter_->AllocateTensors();
+ OP_REQUIRES(context, allocation_status == kTfLiteOk,
+ errors::Internal("Unable to allocate tensors."));
+ for (int i = 0; i < input_tensors.size(); i++) {
+ const int tfl_index = interpreter_->inputs()[i];
+ TfLiteTensor* tflite_tensor = interpreter_->tensor(tfl_index);
+ auto tensor_bytes = input_tensors[i].tensor_data();
+ OP_REQUIRES(context, tflite_tensor->bytes == tensor_bytes.size(),
+ errors::InvalidArgument(
+ "Size mismatch, expected: ", tflite_tensor->bytes,
+ " got: ", tensor_bytes.size()));
+ std::memcpy(tflite_tensor->data.raw, tensor_bytes.data(),
+ tensor_bytes.size());
+ }
+ auto invocation_status = interpreter_->Invoke();
+ OP_REQUIRES(context, invocation_status == kTfLiteOk,
+ errors::Internal("Interpreter invocation failed."));
+ for (int i = 0; i < output_tensors.size(); i++) {
+ auto tfl_tensor = interpreter_->tensor(tfl_outputs[i]);
+ TensorShape shape = metrics::utils::GetTFLiteTensorShape(*tfl_tensor);
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, output_tensors.allocate(i, shape, &output));
+ auto tensor_bytes = output->tensor_data();
+ OP_REQUIRES(context, tensor_bytes.size() == tfl_tensor->bytes,
+ errors::Internal("Invalid size"));
+ std::memcpy(const_cast<char*>(tensor_bytes.data()), tfl_tensor->data.raw,
+ tfl_tensor->bytes);
+ }
+ }
+
+ private:
+ std::unique_ptr<tflite::FlatBufferModel> model_;
+ std::unique_ptr<tflite::Interpreter> interpreter_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("RunTFLiteModel").Device(DEVICE_CPU),
+ RunTFLiteModelOp);
+
+REGISTER_OP("RunTFLiteModel")
+ .Input("model_input: input_type")
+ .Output("model_output: output_type")
+ .Attr("model_file_path: string")
+ .Attr("input_type : list(type)")
+ .Attr("output_type: list(type)")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ // TODO(shashishekhar): Infer the correct shape based on output_type and
+ // maybe another attribute.
+ return shape_inference::UnknownShape(c);
+ });
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_op_test.cc b/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_op_test.cc
new file mode 100644
index 0000000000..88175984a0
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_op_test.cc
@@ -0,0 +1,200 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace {
+tensorflow::string* g_test_model_file = nullptr;
+}
+
+namespace tensorflow {
+namespace {
+
+TEST(RunTfliteModelOpTest, ModelIsRun) {
+ ASSERT_TRUE(g_test_model_file != nullptr);
+ string test_model_file = *g_test_model_file;
+ ASSERT_FALSE(test_model_file.empty());
+
+ Scope scope = Scope::NewRootScope();
+ TF_CHECK_OK(scope.status());
+ // Passed graph has 4 inputs : a,b,c,d and 2 outputs x,y
+ // x = a+b+c, y=b+c+d
+
+ std::vector<Input> graph_inputs = {
+ ops::Const(scope, 1.0f, {1, 8, 8, 3}), // a
+ ops::Const(scope, 2.1f, {1, 8, 8, 3}), // b
+ ops::Const(scope, 3.2f, {1, 8, 8, 3}), // c
+ ops::Const(scope, 4.3f, {1, 8, 8, 3}), // d
+ };
+
+ std::vector<NodeBuilder::NodeOut> input_data;
+ std::transform(graph_inputs.begin(), graph_inputs.end(),
+ std::back_inserter(input_data), [&scope](Input model_input) {
+ return ops::AsNodeOut(scope, model_input);
+ });
+
+ std::vector<DataType> model_input_type = {DT_FLOAT, DT_FLOAT, DT_FLOAT,
+ DT_FLOAT};
+ ::tensorflow::Node* ret;
+ auto builder = ::tensorflow::NodeBuilder("run_model_op", "RunTFLiteModel")
+ .Input(input_data)
+ .Attr("model_file_path", test_model_file)
+ .Attr("input_type", model_input_type)
+ .Attr("output_type", {DT_FLOAT, DT_FLOAT});
+
+ scope.UpdateBuilder(&builder);
+ scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
+ TF_CHECK_OK(scope.status());
+
+ GraphDef graph_def;
+ TF_CHECK_OK(scope.ToGraphDef(&graph_def));
+
+ std::unique_ptr<Session> session(NewSession(SessionOptions()));
+ TF_CHECK_OK(session->Create(graph_def));
+
+ std::vector<Tensor> outputs;
+ TF_CHECK_OK(
+ session->Run({}, {"run_model_op:0", "run_model_op:1"}, {}, &outputs));
+ EXPECT_EQ(2, outputs.size());
+
+ for (const auto& tensor : outputs) {
+ EXPECT_TRUE(tensor.shape().IsSameSize({1, 8, 8, 3}));
+ }
+ auto output_x = outputs[0].flat<float>();
+ auto output_y = outputs[1].flat<float>();
+ EXPECT_EQ(1 * 8 * 8 * 3, output_x.size());
+ EXPECT_EQ(1 * 8 * 8 * 3, output_y.size());
+ for (int i = 0; i < output_x.size(); i++) {
+ EXPECT_NEAR(6.3f, output_x(i), 1e-6f); // a+b+c
+ EXPECT_NEAR(9.6f, output_y(i), 1e-6f); // b+c+d
+ }
+}
+
+TEST(RunTfliteModelOpTest, NumInputsMismatch) {
+ ASSERT_TRUE(g_test_model_file != nullptr);
+ string test_model_file = *g_test_model_file;
+ ASSERT_FALSE(test_model_file.empty());
+
+ Scope scope = Scope::NewRootScope();
+ TF_CHECK_OK(scope.status());
+ // Passed graph has 4 inputs : a,b,c,d and 2 outputs x,y
+ // x = a+b+c, y=b+c+d
+ // Remove a from input.
+
+ std::vector<Input> graph_inputs = {
+ ops::Const(scope, 2.1f, {1, 8, 8, 3}), // b
+ ops::Const(scope, 3.2f, {1, 8, 8, 3}), // c
+ ops::Const(scope, 4.3f, {1, 8, 8, 3}), // d
+ };
+
+ std::vector<NodeBuilder::NodeOut> input_data;
+ std::transform(graph_inputs.begin(), graph_inputs.end(),
+ std::back_inserter(input_data), [&scope](Input model_input) {
+ return ops::AsNodeOut(scope, model_input);
+ });
+
+ std::vector<DataType> model_input_type = {DT_FLOAT, DT_FLOAT, DT_FLOAT};
+
+ ::tensorflow::Node* ret;
+ auto builder = ::tensorflow::NodeBuilder("run_model_op", "RunTFLiteModel")
+ .Input(input_data)
+ .Attr("model_file_path", test_model_file)
+ .Attr("input_type", model_input_type)
+ .Attr("output_type", {DT_FLOAT, DT_FLOAT});
+
+ scope.UpdateBuilder(&builder);
+ scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
+ TF_CHECK_OK(scope.status());
+
+ GraphDef graph_def;
+ TF_CHECK_OK(scope.ToGraphDef(&graph_def));
+ std::unique_ptr<Session> session(NewSession(SessionOptions()));
+ TF_CHECK_OK(session->Create(graph_def));
+
+ std::vector<Tensor> outputs;
+ auto status =
+ (session->Run({}, {"run_model_op:0", "run_model_op:1"}, {}, &outputs));
+ EXPECT_FALSE(status.ok());
+}
+
+TEST(RunTfliteModelOpTest, InputSizesMismatch) {
+ ASSERT_TRUE(g_test_model_file != nullptr);
+ string test_model_file = *g_test_model_file;
+ ASSERT_FALSE(test_model_file.empty());
+
+ Scope scope = Scope::NewRootScope();
+ TF_CHECK_OK(scope.status());
+ // Passed graph has 4 inputs : a,b,c,d and 2 outputs x,y
+ // x = a+b+c, y=b+c+d
+ // Set a to be invalid size.
+ std::vector<Input> graph_inputs = {
+ ops::Const(scope, 1.0f, {1, 8, 8, 4}), // a invalid size,
+ ops::Const(scope, 2.1f, {1, 8, 8, 3}), // b
+ ops::Const(scope, 3.2f, {1, 8, 8, 3}), // c
+ ops::Const(scope, 4.3f, {1, 8, 8, 3}), // d
+ };
+
+ std::vector<NodeBuilder::NodeOut> input_data;
+ std::transform(graph_inputs.begin(), graph_inputs.end(),
+ std::back_inserter(input_data), [&scope](Input model_input) {
+ return ops::AsNodeOut(scope, model_input);
+ });
+
+ std::vector<DataType> model_input_type = {DT_FLOAT, DT_FLOAT, DT_FLOAT,
+ DT_FLOAT};
+ ::tensorflow::Node* ret;
+ auto builder = ::tensorflow::NodeBuilder("run_model_op", "RunTFLiteModel")
+ .Input(input_data)
+ .Attr("model_file_path", test_model_file)
+ .Attr("input_type", model_input_type)
+ .Attr("output_type", {DT_FLOAT, DT_FLOAT});
+
+ scope.UpdateBuilder(&builder);
+ scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
+ TF_CHECK_OK(scope.status());
+
+ GraphDef graph_def;
+ TF_CHECK_OK(scope.ToGraphDef(&graph_def));
+ std::unique_ptr<Session> session(NewSession(SessionOptions()));
+ TF_CHECK_OK(session->Create(graph_def));
+
+ std::vector<Tensor> outputs;
+ auto status =
+ (session->Run({}, {"run_model_op:0", "run_model_op:1"}, {}, &outputs));
+ EXPECT_FALSE(status.ok());
+}
+
+} // namespace
+} // namespace tensorflow
+
+int main(int argc, char** argv) {
+ g_test_model_file = new tensorflow::string();
+ const std::vector<tensorflow::Flag> flag_list = {
+ tensorflow::Flag("test_model_file", g_test_model_file,
+ "Path to test tflite model file."),
+ };
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ CHECK(parse_result) << "Required test_model_file";
+ ::tensorflow::port::InitMain(argv[0], &argc, &argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.cc b/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.cc
new file mode 100644
index 0000000000..c96795d499
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.cc
@@ -0,0 +1,45 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.h"
+
+#include <vector>
+
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+
+namespace tensorflow {
+namespace metrics {
+void RunTFLiteModelStage::AddToGraph(const Scope& scope, const Input& input) {
+ if (!scope.ok()) return;
+ Scope s = scope.WithOpName(name());
+
+ std::vector<NodeBuilder::NodeOut> _data = {ops::AsNodeOut(s, input)};
+ ::tensorflow::Node* ret;
+ auto builder = NodeBuilder(output_name(), "RunTFLiteModel")
+ .Input(_data)
+ .Attr("model_file_path", params_.model_file_path)
+ .Attr("input_type", params_.input_type)
+ .Attr("output_type", params_.output_type);
+
+ s.UpdateBuilder(&builder);
+ s.UpdateStatus(builder.Finalize(s.graph(), &ret));
+ if (!s.ok()) return;
+ s.UpdateStatus(s.DoShapeInference(ret));
+ this->stage_output_ = ::tensorflow::Output(ret, 0);
+}
+
+} // namespace metrics
+} // namespace tensorflow
diff --git a/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.h b/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.h
new file mode 100644
index 0000000000..90d12d6f42
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.h
@@ -0,0 +1,53 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_RUN_TFLITE_MODEL_STAGE_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_RUN_TFLITE_MODEL_STAGE_H_
+
+#include <string>
+
+#include "tensorflow/contrib/lite/tools/accuracy/stage.h"
+
+namespace tensorflow {
+namespace metrics {
+// Stage that loads and runs a TFLite model.
+// Inputs: The input to TFLite model.
+// Outputs: The output of running the TFLite model.
+class RunTFLiteModelStage : public Stage {
+ public:
+ // The parameters for the stage.
+ struct Params {
+ string model_file_path;
+ std::vector<TensorShape> output_shape;
+ std::vector<DataType> input_type;
+ std::vector<DataType> output_type;
+ };
+
+ explicit RunTFLiteModelStage(const Params& params) : params_(params) {}
+
+ string name() const override { return "stage_run_tfl_model"; }
+ // TODO(shashishekhar): This stage can have multiple inputs and
+ // outputs, perhaps change the definition of stage.
+ string output_name() const override { return "stage_run_tfl_model_output"; }
+
+ void AddToGraph(const Scope& scope, const Input& input) override;
+
+ private:
+ Params params_;
+};
+
+} // namespace metrics
+} // namespace tensorflow
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_RUN_TFLITE_MODEL_STAGE_H_
diff --git a/tensorflow/contrib/lite/tools/accuracy/stage.h b/tensorflow/contrib/lite/tools/accuracy/stage.h
new file mode 100644
index 0000000000..8292ea2ec7
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/stage.h
@@ -0,0 +1,56 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_STAGE_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_STAGE_H_
+
+#include "tensorflow/cc/framework/scope.h"
+
+namespace tensorflow {
+namespace metrics {
+
+// A stage in an evaluation pipeline.
+// Each stage adds a subgraph to the pipeline. Stages can be chained
+// together.
+class Stage {
+ public:
+ Stage() = default;
+ Stage(const Stage&) = delete;
+ Stage& operator=(const Stage&) = delete;
+
+ Stage(const Stage&&) = delete;
+ Stage& operator=(const Stage&&) = delete;
+
+ // Adds a subgraph to given scope that takes in `input` as a parameter.
+ virtual void AddToGraph(const Scope& scope, const Input& input) = 0;
+ virtual ~Stage() {}
+
+ // The name of the stage.
+ // Can be used by derived classes for naming the subscope for the stage
+ // graph.
+ virtual string name() const = 0;
+
+ // The name of the output for the stage.
+ virtual string output_name() const = 0;
+
+ const ::tensorflow::Output& Output() const { return stage_output_; }
+
+ protected:
+ ::tensorflow::Output stage_output_;
+};
+} // namespace metrics
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_STAGE_H_
diff --git a/tensorflow/contrib/lite/tools/accuracy/testdata/grace_hopper.jpg b/tensorflow/contrib/lite/tools/accuracy/testdata/grace_hopper.jpg
new file mode 100644
index 0000000000..d2a427810f
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/testdata/grace_hopper.jpg
Binary files differ
diff --git a/tensorflow/contrib/lite/tools/accuracy/utils.cc b/tensorflow/contrib/lite/tools/accuracy/utils.cc
new file mode 100644
index 0000000000..f5493301fc
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/utils.cc
@@ -0,0 +1,102 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/accuracy/utils.h"
+
+#include <sys/stat.h>
+
+#include <cstring>
+#include <fstream>
+#include <memory>
+#include <string>
+
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/op_resolver.h"
+
+namespace tensorflow {
+namespace metrics {
+
+namespace utils {
+
+DataType GetTFDataType(TfLiteType tflite_type) {
+ switch (tflite_type) {
+ case kTfLiteFloat32:
+ return DT_FLOAT;
+ case kTfLiteUInt8:
+ return DT_UINT8;
+ default:
+ return DT_INVALID;
+ }
+}
+
+TensorShape GetTFLiteTensorShape(const TfLiteTensor& tflite_tensor) {
+ TensorShape shape;
+ for (int i = 0; i < tflite_tensor.dims->size; i++) {
+ shape.AddDim(tflite_tensor.dims->data[i]);
+ }
+ return shape;
+}
+
+Status ReadFileLines(const string& file_path,
+ std::vector<string>* lines_output) {
+ if (!lines_output) {
+ return errors::InvalidArgument("Invalid output");
+ }
+ std::vector<string> lines;
+ std::ifstream stream(file_path, std::ios_base::in);
+ if (!stream) {
+ return errors::InvalidArgument("Unable to open file: ", file_path);
+ }
+ std::string line;
+ while (std::getline(stream, line)) {
+ lines_output->push_back(line);
+ }
+ return Status::OK();
+}
+
+Status GetTFliteModelInfo(const string& model_file_path,
+ ModelInfo* model_info) {
+ if (model_file_path.empty()) {
+ return errors::InvalidArgument("Invalid model file.");
+ }
+ struct stat stat_buf;
+ if (stat(model_file_path.c_str(), &stat_buf) != 0) {
+ int error_num = errno;
+ return errors::InvalidArgument("Invalid model file: ", model_file_path,
+ std::strerror(error_num));
+ }
+
+ std::unique_ptr<tflite::FlatBufferModel> model;
+ std::unique_ptr<tflite::Interpreter> interpreter;
+ model = tflite::FlatBufferModel::BuildFromFile(model_file_path.data());
+ tflite::ops::builtin::BuiltinOpResolver resolver;
+
+ tflite::InterpreterBuilder(*model, resolver)(&interpreter);
+ if (!interpreter) {
+ return errors::InvalidArgument("Invalid model", model_file_path);
+ }
+ for (int i : interpreter->inputs()) {
+ TfLiteTensor* tensor = interpreter->tensor(i);
+ model_info->input_shapes.push_back(utils::GetTFLiteTensorShape(*tensor));
+ model_info->input_types.push_back(utils::GetTFDataType(tensor->type));
+ }
+ return Status::OK();
+}
+
+} // namespace utils
+} // namespace metrics
+} // namespace tensorflow
diff --git a/tensorflow/contrib/lite/tools/accuracy/utils.h b/tensorflow/contrib/lite/tools/accuracy/utils.h
new file mode 100644
index 0000000000..37cbad4d51
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/utils.h
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_UTILS_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_UTILS_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+
+namespace tensorflow {
+namespace metrics {
+
+namespace utils {
+
+struct ModelInfo {
+ std::vector<TensorShape> input_shapes;
+ std::vector<DataType> input_types;
+};
+
+Status GetTFliteModelInfo(const string& model_file_path, ModelInfo* model_info);
+
+DataType GetTFDataType(TfLiteType tflite_type);
+
+TensorShape GetTFLiteTensorShape(const TfLiteTensor& tflite_tensor);
+
+Status ReadFileLines(const string& file_path,
+ std::vector<string>* lines_output);
+} // namespace utils
+} // namespace metrics
+} // namespace tensorflow
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_UTILS_H_
diff --git a/tensorflow/contrib/lite/tools/accuracy/utils_test.cc b/tensorflow/contrib/lite/tools/accuracy/utils_test.cc
new file mode 100644
index 0000000000..727eba21b6
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/utils_test.cc
@@ -0,0 +1,76 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/tools/accuracy/utils.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace {
+tensorflow::string* g_test_model_file = nullptr;
+}
+
+namespace tensorflow {
+namespace metrics {
+namespace utils {
+namespace {
+
+TEST(UtilsTest, GetTFLiteModelInfoReturnsCorrectly) {
+ ASSERT_TRUE(g_test_model_file != nullptr);
+ string test_model_file = *g_test_model_file;
+ ASSERT_FALSE(test_model_file.empty());
+ // Passed graph has 4 inputs : a,b,c,d and 2 outputs x,y
+ // x = a+b+c, y=b+c+d
+ // Input and outputs have shape : {1,8,8,3}
+ ModelInfo model_info;
+ auto status = GetTFliteModelInfo(test_model_file, &model_info);
+ TF_CHECK_OK(status);
+ ASSERT_EQ(4, model_info.input_shapes.size());
+ ASSERT_EQ(4, model_info.input_types.size());
+
+ for (int i = 0; i < 4; i++) {
+ const TensorShape& shape = model_info.input_shapes[i];
+ DataType dataType = model_info.input_types[i];
+ EXPECT_TRUE(shape.IsSameSize({1, 8, 8, 3}));
+ EXPECT_EQ(DT_FLOAT, dataType);
+ }
+}
+
+TEST(UtilsTest, GetTFliteModelInfoIncorrectFile) {
+ ModelInfo model_info;
+ auto status = GetTFliteModelInfo("non_existent_file", &model_info);
+ EXPECT_FALSE(status.ok());
+}
+
+} // namespace
+} // namespace utils
+} // namespace metrics
+} // namespace tensorflow
+
+int main(int argc, char** argv) {
+ g_test_model_file = new tensorflow::string();
+ const std::vector<tensorflow::Flag> flag_list = {
+ tensorflow::Flag("test_model_file", g_test_model_file,
+ "Path to test tflite model file."),
+ };
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ CHECK(parse_result) << "Required test_model_file";
+ ::tensorflow::port::InitMain(argv[0], &argc, &argv);
+ return RUN_ALL_TESTS();
+}