diff options
author | 2018-08-22 19:22:46 -0700 | |
---|---|---|
committer | 2018-08-22 19:25:48 -0700 | |
commit | 8db22dc063e6a6bb16b4676e53446987dac99a49 (patch) | |
tree | 14df3a9ca7614e61467058c2f9ab55d6ce9a4346 /tensorflow/contrib/lite/tools/accuracy | |
parent | b7b3f571728898c6d822aa1252d20bced15b989d (diff) |
Add an accuracy tool that can be used to evaluate model accuracy on device.
- Adds an accuracy tool that can be used to develop evaluation pipelines to
evaluate model accuracies.
- The binary can be compiled for mobile platforms and
the tool can be used to evaluate accuracy of models by running the binary on
device.
- Adds an example implementation for imagenet ILSVRC classification task
accuracy evaluation.
- More documentation and details coming soon.
PiperOrigin-RevId: 209869774
Diffstat (limited to 'tensorflow/contrib/lite/tools/accuracy')
31 files changed, 3268 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/tools/accuracy/BUILD b/tensorflow/contrib/lite/tools/accuracy/BUILD new file mode 100644 index 0000000000..db09de2909 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/BUILD @@ -0,0 +1,435 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") + +cc_library( + name = "inception_preprocessing", + srcs = ["inception_preprocessing.cc"], + hdrs = ["inception_preprocessing.h"], + copts = [ + "-D__ANDROID_TYPES_FULL__", + "-DSUPPORT_SELECTIVE_REGISTRATION", + ], + deps = [ + ":stage", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:scope", + ] + select( + { + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + "//tensorflow/core/kernels:android_tensorflow_image_op", + ], + "//conditions:default": [ + "//tensorflow/core:tensorflow", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:ops", + ], + }, + ), +) + +tf_cc_test( + name = "inception_preprocessing_test", + srcs = ["inception_preprocessing_test.cc"], + args = [ + "--test_image=$(location :testdata/grace_hopper.jpg)", + ], + data = [":testdata/grace_hopper.jpg"], + deps = [ + ":inception_preprocessing", + ":android_required_build_flags", + "@com_google_googletest//:gtest", + ] + select( + { + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + "//tensorflow/core:android_test_lib", + ], + "//conditions:default": [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], + }, + ), +) + +cc_library( + name = "utils", + srcs = ["utils.cc"], + hdrs = ["utils.h"], + deps = [ + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + ] + select( + { + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + ], + 
"//conditions:default": [ + "//tensorflow/core:framework", + ], + }, + ), +) + +tf_cc_test( + name = "utils_test", + srcs = ["utils_test.cc"], + args = [ + "--test_model_file=$(location //tensorflow/contrib/lite:testdata/multi_add.bin)", + ], + data = ["//tensorflow/contrib/lite:testdata/multi_add.bin"], + deps = [ + ":utils", + "@com_google_googletest//:gtest", + ] + select( + { + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + "//tensorflow/core:android_test_lib", + ], + "//conditions:default": [ + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], + }, + ), +) + +cc_library( + name = "run_tflite_model_op", + srcs = ["run_tflite_model_op.cc"], + deps = [ + ":utils", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + ] + select( + { + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + ], + "//conditions:default": [ + "//tensorflow/core:tensorflow", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:ops", + ], + }, + ), + alwayslink = 1, +) + +cc_library( + name = "android_required_build_flags", + srcs = ["android_required_build_flags.cc"], +) + +tf_cc_test( + name = "run_tflite_model_op_test", + srcs = ["run_tflite_model_op_test.cc"], + args = [ + "--test_model_file=$(location //tensorflow/contrib/lite:testdata/multi_add.bin)", + ], + data = ["//tensorflow/contrib/lite:testdata/multi_add.bin"], + deps = [ + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:scope", + ":run_tflite_model_op", + ":android_required_build_flags", + "@com_google_googletest//:gtest", + ] + select( + { + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + "//tensorflow/core:android_test_lib", + ], + "//conditions:default": [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + 
"//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + ], + }, + ), +) + +cc_library( + name = "stage", + hdrs = ["stage.h"], + deps = [ + "//tensorflow/cc:scope", + ], +) + +cc_library( + name = "file_reader_stage", + srcs = ["file_reader_stage.cc"], + hdrs = ["file_reader_stage.h"], + deps = [ + ":stage", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:scope", + ], +) + +tf_cc_test( + name = "file_reader_stage_test", + srcs = ["file_reader_stage_test.cc"], + deps = [ + ":file_reader_stage", + "@com_google_googletest//:gtest", + ] + select( + { + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + "//tensorflow/core/kernels:android_whole_file_read_ops", + "//tensorflow/core:android_test_lib", + ], + "//conditions:default": [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:tensorflow", + ], + }, + ), +) + +cc_library( + name = "run_tflite_model_stage", + srcs = ["run_tflite_model_stage.cc"], + hdrs = ["run_tflite_model_stage.h"], + deps = [ + ":run_tflite_model_op", + ":stage", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:scope", + ], +) + +cc_library( + name = "accuracy_eval_stage", + hdrs = ["accuracy_eval_stage.h"], + deps = [ + ] + select( + { + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + ], + "//conditions:default": [ + "//tensorflow/core:framework", + ], + }, + ), +) + +cc_library( + name = "imagenet_topk_eval", + srcs = ["imagenet_topk_eval.cc"], + hdrs = ["imagenet_topk_eval.h"], + deps = [ + ":accuracy_eval_stage", + ] + select( + { + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + ], + "//conditions:default": [ + "//tensorflow/core:framework", + ], + }, + ), +) + +tf_cc_test( + name = "imagenet_topk_eval_test", + srcs = ["imagenet_topk_eval_test.cc"], + deps = [ + ":imagenet_topk_eval", + "@com_google_googletest//:gtest", + ] + select( + { + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + 
"//tensorflow/core:android_test_lib", + ], + "//conditions:default": [ + "//tensorflow/core:framework", + ], + }, + ), +) + +cc_library( + name = "eval_pipeline", + srcs = ["eval_pipeline.cc"], + hdrs = ["eval_pipeline.h"], + deps = [ + ":accuracy_eval_stage", + ":stage", + ] + select( + { + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + ], + "//conditions:default": [ + "//tensorflow/core:framework", + "//tensorflow/core:core_cpu", + ], + }, + ), +) + +tf_cc_test( + name = "eval_pipeline_test", + srcs = ["eval_pipeline_test.cc"], + deps = [ + ":eval_pipeline", + "//tensorflow/cc:cc_ops", + "@com_google_googletest//:gtest", + ] + select( + { + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + "//tensorflow/core:android_test_lib", + ], + "//conditions:default": [ + "//tensorflow/core:framework", + "//tensorflow/core:core_cpu", + "//tensorflow/core:ops", + "//tensorflow/core:tensorflow", + ], + }, + ), +) + +cc_library( + name = "eval_pipeline_builder", + srcs = ["eval_pipeline_builder.cc"], + hdrs = ["eval_pipeline_builder.h"], + deps = [ + ":eval_pipeline", + ":accuracy_eval_stage", + ":stage", + "@com_google_absl//absl/memory", + "//tensorflow/cc:cc_ops", + ] + select( + { + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + ], + "//conditions:default": [ + "//tensorflow/core:framework", + "//tensorflow/core:core_cpu", + "//tensorflow/core:ops", + "//tensorflow/core:tensorflow", + ], + }, + ), +) + +tf_cc_test( + name = "eval_pipeline_builder_test", + srcs = ["eval_pipeline_builder_test.cc"], + deps = [ + ":eval_pipeline_builder", + "//tensorflow/cc:cc_ops", + "@com_google_googletest//:gtest", + ] + select( + { + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + "//tensorflow/core:android_test_lib", + ], + "//conditions:default": [ + "//tensorflow/core:framework", + "//tensorflow/core:core_cpu", + "//tensorflow/core:ops", + "//tensorflow/core:tensorflow", + ], + }, + 
), +) + +cc_library( + name = "csv_writer", + hdrs = ["csv_writer.h"], + deps = select( + { + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + ], + "//conditions:default": [ + "//tensorflow/core:lib", + ], + }, + ), +) + +cc_library( + name = "imagenet_model_evaluator", + srcs = ["imagenet_model_evaluator.cc"], + hdrs = ["imagenet_model_evaluator.h"], + deps = [ + ":android_required_build_flags", + ":eval_pipeline", + ":eval_pipeline_builder", + ":file_reader_stage", + ":imagenet_topk_eval", + ":inception_preprocessing", + ":run_tflite_model_stage", + ":utils", + "@com_google_absl//absl/memory", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:scope", + ] + select( + { + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + "//tensorflow/core/kernels:android_whole_file_read_ops", + "//tensorflow/core/kernels:android_tensorflow_image_op", + ], + "//conditions:default": [ + "//tensorflow/core:tensorflow", + "//tensorflow/core:framework_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:core_cpu", + ], + }, + ), +) + +tf_cc_binary( + name = "imagenet_accuracy_eval", + srcs = ["imagenet_accuracy_eval.cc"], + deps = [ + ":android_required_build_flags", + ":csv_writer", + ":imagenet_model_evaluator", + ":imagenet_topk_eval", + "@com_google_absl//absl/memory", + ] + select( + { + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + ], + "//conditions:default": [ + "//tensorflow/core:lib", + "//tensorflow/core:framework_internal", + ], + }, + ), +) diff --git a/tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h b/tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h new file mode 100644 index 0000000000..9cb843729a --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_ACCURACY_EVAL_STAGE_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_ACCURACY_EVAL_STAGE_H_ + +#include <vector> + +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { +namespace metrics { + +// Base class for evaluation stage that evaluates the accuracy of the model. +// This stage calculates the accuracy metrics given the model outputs and +// expected ground truth. +class AccuracyEval { + public: + AccuracyEval() = default; + AccuracyEval(const AccuracyEval&) = delete; + AccuracyEval& operator=(const AccuracyEval&) = delete; + + AccuracyEval(const AccuracyEval&&) = delete; + AccuracyEval& operator=(const AccuracyEval&&) = delete; + + virtual ~AccuracyEval() = default; + + // Evaluates the accuracy of the model for given `model_outputs` and the + // `ground truth`. + // Derived classes can do additional book keeping, calculate aggregrate + // statistics etc for the given model. 
+ virtual Status ComputeEval(const std::vector<Tensor>& model_outputs, + const Tensor& ground_truth) = 0; +}; +} // namespace metrics +} // namespace tensorflow +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_ACCURACY_EVAL_STAGE_H_ diff --git a/tensorflow/contrib/lite/tools/accuracy/android_required_build_flags.cc b/tensorflow/contrib/lite/tools/accuracy/android_required_build_flags.cc new file mode 100644 index 0000000000..7fa8986716 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/android_required_build_flags.cc @@ -0,0 +1,27 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Tensorflow on Android requires selective registration to be enabled in order +// for certain types (e.g. DT_UINT8) to work. +// Checks below ensure that for Android build, the right flags are passed to +// the compiler. + +#if defined(__ANDROID__) && (!defined(__ANDROID_TYPES_FULL__) || \ + !defined(SUPPORT_SELECTIVE_REGISTRATION)) +#error \ + "Binary needs custom kernel support. For enabling custom kernels on " \ + "Android, please pass -D__ANDROID_TYPES_FULL__ && " \ + "-DSUPPORT_SELECTIVE_REGISTRATION for including the kernel in the binary." 
+#endif diff --git a/tensorflow/contrib/lite/tools/accuracy/csv_writer.h b/tensorflow/contrib/lite/tools/accuracy/csv_writer.h new file mode 100644 index 0000000000..806b0d9418 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/csv_writer.h @@ -0,0 +1,79 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_CSV_WRITER_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_CSV_WRITER_H_ + +#include <fstream> +#include <vector> + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace metrics { +// A simple CSV writer that writes values of same type for fixed number of +// columns. This supports a very limited set of CSV spec and doesn't do any +// escaping. +// Usage: +// std::ofstream * output_stream = ... +// CSVWriter writer({"column1", "column2"}, output_stream); +// writer.WriteRow({4, 5}); +// writer.Flush(); // flush results immediately. 
+class CSVWriter { + public: + CSVWriter(const std::vector<string>& columns, std::ofstream* output_stream) + : num_columns_(columns.size()), output_stream_(output_stream) { + TF_CHECK_OK(WriteRow(columns, output_stream_)); + } + + template <typename T> + Status WriteRow(const std::vector<T>& values) { + if (values.size() != num_columns_) { + return errors::InvalidArgument("Invalid size for row:", values.size(), + " expected: ", num_columns_); + } + return WriteRow(values, output_stream_); + } + + void Flush() { output_stream_->flush(); } + + ~CSVWriter() { output_stream_->flush(); } + + private: + template <typename T> + static Status WriteRow(const std::vector<T>& values, + std::ofstream* output_stream) { + bool first = true; + for (const auto& v : values) { + if (!first) { + (*output_stream) << ", "; + } else { + first = false; + } + (*output_stream) << v; + } + (*output_stream) << "\n"; + if (!output_stream->good()) { + return errors::Internal("Writing to stream failed."); + } + return Status::OK(); + } + const size_t num_columns_; + std::ofstream* output_stream_; +}; +} // namespace metrics +} // namespace tensorflow +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_CSV_WRITER_H_ diff --git a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline.cc b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline.cc new file mode 100644 index 0000000000..a03aba6a26 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline.cc @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h" + +namespace tensorflow { +namespace metrics { + +Status EvalPipeline::AttachSession(std::unique_ptr<Session> session) { + session_ = std::move(session); + TF_RETURN_IF_ERROR(session_->Create(model_graph_)); + return Status::OK(); +} + +Status EvalPipeline::Run(const Tensor& input, const Tensor& ground_truth) { + if (session_ == nullptr) { + return errors::Internal("No session is associated with the graph."); + } + std::vector<Tensor> outputs; + TF_RETURN_IF_ERROR(session_->Run({{params_.model_input_node_name, input}}, + {params_.model_output_node_name}, {}, + &outputs)); + TF_RETURN_IF_ERROR(eval_->ComputeEval(outputs, ground_truth)); + return Status::OK(); +} +} // namespace metrics +} // namespace tensorflow diff --git a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h new file mode 100644 index 0000000000..c9cfc86613 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h @@ -0,0 +1,87 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_H_ + +#include <string> + +#include "tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h" +#include "tensorflow/contrib/lite/tools/accuracy/stage.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace metrics { + +// Pipeline for evaluating a model. +// Runs the graph and passes the output of graph to +// the provided instance of AccuracyEval. +// Example usage: +// AccuracyEval *eval; +// GraphDef graph_def; +// ... populate graph_def... +// +// EvalPipeline eval_pipeline(&graph_def, +// {.model_input_node_name = "model_input", +// .model_output_node_name = "model_output"}, +// eval); +// std::unique_ptr<Session> session(NewSession(SessionOptions())); +// TF_CHECK_OK(eval_pipeline.AttachSession(std::move(session))); +// Tensor input = ... read input for the model ... +// Tensor ground_truth = ... read ground truth for the model ... +// TF_CHECK_OK(eval_pipeline.Run(input, ground_truth)); +// +class EvalPipeline { + public: + struct Params { + string model_input_node_name; + string model_output_node_name; + }; + + // Creates a new `EvalPipeline` object. The ownership of the `accuracy_eval` + // is retained by the caller. Lifetime of `accuracy_eval` instance should + // be longer than the lifetime of this instance of pipeline. + EvalPipeline(const GraphDef& graph, const Params& params, + AccuracyEval* accuracy_eval) + : model_graph_(graph), + params_(params), + eval_(accuracy_eval), + session_(nullptr) {} + + EvalPipeline(const EvalPipeline&) = delete; + EvalPipeline& operator=(const EvalPipeline&) = delete; + + EvalPipeline(const EvalPipeline&&) = delete; + EvalPipeline& operator=(const EvalPipeline&&) = delete; + + // Attaches the given session to this instance of pipeline. 
+ // The provided session object will be reused for subsequent calls to + // EvalPipeline::Run. + Status AttachSession(std::unique_ptr<Session> session); + + // Runs the model by feeding `input` and then passes the output of the model + // along with provided `ground_truth` to the AccuracyEval instance by calling + // AccuracyEval::ComputeEval. + Status Run(const Tensor& input, const Tensor& ground_truth); + + private: + GraphDef model_graph_; + Params params_; + AccuracyEval* eval_; + std::unique_ptr<Session> session_; +}; +} // namespace metrics +} // namespace tensorflow +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_H_ diff --git a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.cc b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.cc new file mode 100644 index 0000000000..2e16437e15 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.cc @@ -0,0 +1,100 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h" + +#include "absl/memory/memory.h" +#include "tensorflow/cc/ops/standard_ops.h" + +namespace tensorflow { +namespace metrics { + +EvalPipelineBuilder& EvalPipelineBuilder::WithInputStage(Stage* input_stage) { + input_stage_ = input_stage; + return *this; +} + +EvalPipelineBuilder& EvalPipelineBuilder::WithPreprocessingStage( + Stage* preprocessing_stage) { + preprocessing_stage_ = preprocessing_stage; + return *this; +} + +EvalPipelineBuilder& EvalPipelineBuilder::WithRunModelStage( + Stage* run_model_stage) { + run_model_stage_ = run_model_stage; + return *this; +} + +EvalPipelineBuilder& EvalPipelineBuilder::WithAccuracyEval( + AccuracyEval* accuracy_eval) { + accuracy_eval_ = accuracy_eval; + return *this; +} + +EvalPipelineBuilder& EvalPipelineBuilder::WithInput(const string& input_name, + DataType input_type) { + input_name_ = input_name; + input_type_ = input_type; + return *this; +} + +Status EvalPipelineBuilder::Build( + const Scope& scope, std::unique_ptr<EvalPipeline>* eval_pipeline) { + if (input_stage_ == nullptr) { + return errors::InvalidArgument("Input stage is null."); + } + if (preprocessing_stage_ == nullptr) { + return errors::InvalidArgument("Preprocessing stage is null."); + } + if (run_model_stage_ == nullptr) { + return errors::InvalidArgument("Run model stage is null."); + } + if (accuracy_eval_ == nullptr) { + return errors::InvalidArgument("accuracy_eval is null."); + } + if (input_name_.empty()) { + return errors::InvalidArgument("input name is not set."); + } + if (input_type_ == DT_INVALID) { + return errors::InvalidArgument("input type is not set."); + } + + auto input_placeholder = + ops::Placeholder(scope.WithOpName(input_name_), input_type_); + TF_RETURN_IF_ERROR(scope.status()); + + input_stage_->AddToGraph(scope, input_placeholder); + TF_RETURN_IF_ERROR(scope.status()); + + 
preprocessing_stage_->AddToGraph(scope, input_stage_->Output()); + TF_RETURN_IF_ERROR(scope.status()); + + run_model_stage_->AddToGraph(scope, preprocessing_stage_->Output()); + TF_RETURN_IF_ERROR(scope.status()); + + GraphDef graph_def; + TF_RETURN_IF_ERROR(scope.ToGraphDef(&graph_def)); + EvalPipeline::Params params; + params.model_input_node_name = input_name_; + params.model_output_node_name = run_model_stage_->output_name(); + *eval_pipeline = + absl::make_unique<EvalPipeline>(graph_def, params, accuracy_eval_); + + return Status::OK(); +} + +} // namespace metrics +} // namespace tensorflow diff --git a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h new file mode 100644 index 0000000000..692db022f8 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h @@ -0,0 +1,99 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_BUILDER_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_BUILDER_H_ + +#include <memory> +#include <string> + +#include "tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h" +#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h" +#include "tensorflow/contrib/lite/tools/accuracy/stage.h" + +namespace tensorflow { +namespace metrics { + +// A builder to simplify construction of an `EvalPipeline` instance. +// The `Build` method creates an |EvalPipeline| with the following structure: +// |input| -> |input_stage| +// |--> |preprocessing_stage| +// |--> |run_model_stage| -> |accuracy_eval_stage|. +// The stages are chained in the order shown above. Any missing stage results in +// an error. The ownership of the stage object is retained by the caller. Stage +// objects need to exist until the |Build| method is called. +// +// Currently only single inputs are supported. +// +// Example Usage: +// EvalPipelineBuilder builder; +// std::unique_ptr<EvalPipeline> eval_pipeline; +// auto status = builder.WithInput("pipeline_input", DT_FLOAT) +// .WithInputStage(&input_stage) +// .WithRunModelStage(&run_model_stage) +// .WithPreprocessingStage(&preprocess_stage) +// .WithAccuracyEval(&eval) +// .Build(scope, &eval_pipeline); +// TF_CHECK_OK(status); +class EvalPipelineBuilder { + public: + EvalPipelineBuilder() = default; + EvalPipelineBuilder(const EvalPipelineBuilder&) = delete; + EvalPipeline& operator=(const EvalPipelineBuilder&) = delete; + + EvalPipelineBuilder(const EvalPipelineBuilder&&) = delete; + EvalPipeline& operator=(const EvalPipelineBuilder&&) = delete; + + // Sets the input stage for the pipeline. + // Input stage converts the input, say filename into appropriate format + // that can be consumed by the preprocessing stage. 
+ EvalPipelineBuilder& WithInputStage(Stage* input_stage); + + // Sets the preprocessing stage for the pipeline. + // Preprocessing stage converts the input into a format that can be used to + // run the model. + EvalPipelineBuilder& WithPreprocessingStage(Stage* preprocessing_stage); + + // Sets the run model stage for the pipeline. + // This stage receives the preprocessing input and output of this stage is + // fed to the accuracy eval stage. + EvalPipelineBuilder& WithRunModelStage(Stage* run_model_stage); + + // Sets the accuracy eval for the pipeline. + // Results of evaluating the pipeline are fed to the `accuracy_eval` instance. + EvalPipelineBuilder& WithAccuracyEval(AccuracyEval* accuracy_eval); + + // Sets the name and type of input for the pipeline. + // TODO(shashishekhar): Support multiple inputs for the pipeline, use a vector + // here. + EvalPipelineBuilder& WithInput(const string& input_name, DataType input_type); + + // Builds the pipeline and assigns the pipeline to `eval_pipeline`. + // If the pipeline creation fails `eval_pipeline` is untouched. + Status Build(const Scope& scope, + std::unique_ptr<EvalPipeline>* eval_pipeline); + + private: + Stage* input_stage_ = nullptr; + Stage* preprocessing_stage_ = nullptr; + Stage* run_model_stage_ = nullptr; + AccuracyEval* accuracy_eval_ = nullptr; + string input_name_; + DataType input_type_ = DT_INVALID; +}; + +} // namespace metrics +} // namespace tensorflow +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_BUILDER_H_ diff --git a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder_test.cc b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder_test.cc new file mode 100644 index 0000000000..2d41929b79 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder_test.cc @@ -0,0 +1,229 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h" +#include <gtest/gtest.h> +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace metrics { +namespace { + +class IdentityStage : public Stage { + public: + IdentityStage(const string& name, const string& output) + : name_(name), output_(output) {} + + void AddToGraph(const Scope& scope, const Input& input) override { + called_count_++; + inputs_.push_back(input.node()->name()); + stage_output_ = ops::Identity(scope.WithOpName(output_), input); + } + + string name() const override { return name_; } + string output_name() const override { return output_; } + + int times_called() const { return called_count_; } + + const std::vector<string> input_params() { return inputs_; } + + private: + string name_; + string output_; + int called_count_ = 0; + std::vector<string> inputs_; +}; + +class FailingStage : public Stage { + public: + FailingStage(const string& name, const string& output) + : name_(name), output_(output) {} + + void AddToGraph(const Scope& scope, const Input& input) override { + called_count_++; + scope.UpdateStatus(errors::Internal("Stage failed:", name_)); + } + + string name() const override { return name_; } + string output_name() const override { return output_; } + + int times_called() const { return 
called_count_; } + + private: + string name_; + string output_; + int called_count_ = 0; +}; + +class SimpleAccuracyEval : public AccuracyEval { + public: + SimpleAccuracyEval() {} + + Status ComputeEval(const std::vector<Tensor>& model_outputs, + const Tensor& ground_truth) override { + return Status::OK(); + } +}; + +TEST(EvalPipelineBuilder, MissingPipelineStages) { + IdentityStage input_stage("input_stage", "input_stage_out"); + IdentityStage run_model_stage("run_model", "run_model_out"); + IdentityStage preprocess_stage("preprocess_stage", "preprocess_stage_out"); + const string pipeline_input = "pipeline_input"; + + SimpleAccuracyEval eval; + + Scope scope = Scope::NewRootScope(); + std::unique_ptr<EvalPipeline> eval_pipeline; + EvalPipelineBuilder builder; + auto status = + builder.WithInputStage(&input_stage).Build(scope, &eval_pipeline); + EXPECT_FALSE(status.ok()); + EXPECT_FALSE(eval_pipeline); + + status = + builder.WithRunModelStage(&run_model_stage).Build(scope, &eval_pipeline); + EXPECT_FALSE(status.ok()); + EXPECT_FALSE(eval_pipeline); + + status = builder.WithPreprocessingStage(&preprocess_stage) + .Build(scope, &eval_pipeline); + EXPECT_FALSE(status.ok()); + EXPECT_FALSE(eval_pipeline); + + status = + builder.WithInput(pipeline_input, DT_FLOAT).Build(scope, &eval_pipeline); + EXPECT_FALSE(status.ok()); + EXPECT_FALSE(eval_pipeline); + + status = builder.WithAccuracyEval(&eval).Build(scope, &eval_pipeline); + TF_CHECK_OK(status); + EXPECT_TRUE(eval_pipeline); +} + +TEST(EvalPipeline, InputStageFailure) { + FailingStage input_stage("input_stage", "input_stage_out"); + IdentityStage run_model_stage("run_model", "run_model_out"); + IdentityStage preprocess_stage("preprocess_stage", "preprocess_stage_out"); + const string pipeline_input = "pipeline_input"; + + SimpleAccuracyEval eval; + + Scope scope = Scope::NewRootScope(); + std::unique_ptr<EvalPipeline> eval_pipeline; + EvalPipelineBuilder builder; + auto status = 
builder.WithInputStage(&input_stage)
+ .WithRunModelStage(&run_model_stage)
+ .WithPreprocessingStage(&preprocess_stage)
+ .WithInput(pipeline_input, DT_FLOAT)
+ .WithAccuracyEval(&eval)
+ .Build(scope, &eval_pipeline);
+
+ EXPECT_FALSE(scope.status().ok());
+ // None of the other stages would have been called.
+ EXPECT_EQ(1, input_stage.times_called());
+ EXPECT_EQ(0, preprocess_stage.times_called());
+ EXPECT_EQ(0, run_model_stage.times_called());
+}
+
+TEST(EvalPipeline, PreprocessingFailure) {
+ IdentityStage input_stage("input_stage", "input_stage_out");
+ FailingStage preprocess_stage("preprocess_stage", "preprocess_stage_out");
+ IdentityStage run_model_stage("run_model", "run_model_out");
+ const string pipeline_input = "pipeline_input";
+
+ SimpleAccuracyEval eval;
+
+ Scope scope = Scope::NewRootScope();
+ std::unique_ptr<EvalPipeline> eval_pipeline;
+ EvalPipelineBuilder builder;
+ auto status = builder.WithInputStage(&input_stage)
+ .WithRunModelStage(&run_model_stage)
+ .WithPreprocessingStage(&preprocess_stage)
+ .WithInput(pipeline_input, DT_FLOAT)
+ .WithAccuracyEval(&eval)
+ .Build(scope, &eval_pipeline);
+
+ EXPECT_FALSE(status.ok());
+ // The input and preprocessing stages ran; only run_model was never reached. 
+ EXPECT_EQ(1, input_stage.times_called());
+ EXPECT_EQ(1, preprocess_stage.times_called());
+ EXPECT_EQ(0, run_model_stage.times_called());
+}
+
+TEST(EvalPipeline, GraphEvalFailure) {
+ IdentityStage input_stage("input_stage", "input_stage_out");
+ IdentityStage preprocess_stage("preprocess_stage", "preprocess_stage_out");
+ FailingStage run_model_stage("run_model", "run_model_out");
+ const string pipeline_input = "pipeline_input";
+
+ SimpleAccuracyEval eval;
+
+ Scope scope = Scope::NewRootScope();
+ std::unique_ptr<EvalPipeline> eval_pipeline;
+ EvalPipelineBuilder builder;
+ auto status = builder.WithInputStage(&input_stage)
+ .WithRunModelStage(&run_model_stage)
+ .WithPreprocessingStage(&preprocess_stage)
+ .WithInput(pipeline_input, DT_FLOAT)
+ .WithAccuracyEval(&eval)
+ .Build(scope, &eval_pipeline);
+
+ EXPECT_FALSE(status.ok());
+ // All stages ran; the failure is injected by the final (run_model) stage.
+ EXPECT_EQ(1, input_stage.times_called());
+ EXPECT_EQ(1, preprocess_stage.times_called());
+ EXPECT_EQ(1, run_model_stage.times_called());
+}
+
+TEST(EvalPipeline, PipelineHasCorrectSequence) {
+ IdentityStage input_stage("input_stage", "input_stage_out");
+ IdentityStage preprocess_stage("preprocess_stage", "preprocess_stage_out");
+ IdentityStage run_model_stage("run_model", "run_model_out");
+ const string pipeline_input = "pipeline_input";
+
+ SimpleAccuracyEval eval;
+
+ Scope scope = Scope::NewRootScope();
+ std::unique_ptr<EvalPipeline> eval_pipeline;
+ EvalPipelineBuilder builder;
+ auto status = builder.WithInputStage(&input_stage)
+ .WithRunModelStage(&run_model_stage)
+ .WithPreprocessingStage(&preprocess_stage)
+ .WithInput(pipeline_input, DT_FLOAT)
+ .WithAccuracyEval(&eval)
+ .Build(scope, &eval_pipeline);
+ TF_CHECK_OK(status);
+
+ ASSERT_EQ(1, input_stage.times_called());
+ ASSERT_EQ(1, run_model_stage.times_called());
+ ASSERT_EQ(1, preprocess_stage.times_called());
+
+ EXPECT_EQ(pipeline_input, input_stage.input_params()[0]);
+ 
EXPECT_EQ(input_stage.output_name(), preprocess_stage.input_params()[0]); + EXPECT_EQ(preprocess_stage.output_name(), run_model_stage.input_params()[0]); +} + +} // namespace + +} // namespace metrics +} // namespace tensorflow + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_test.cc b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_test.cc new file mode 100644 index 0000000000..ea0f6e19df --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_test.cc @@ -0,0 +1,133 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h" +#include <gtest/gtest.h> +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace metrics { +namespace { + +Tensor CreateFloatTensor(float value) { + Tensor tensor(DT_FLOAT, TensorShape({})); + tensor.scalar<float>()() = value; + return tensor; +} + +class NoOpAccuracyEval : public AccuracyEval { + public: + explicit NoOpAccuracyEval(const Status& status_to_return) + : status_to_return_(status_to_return) {} + + Status ComputeEval(const std::vector<Tensor>& model_outputs, + const Tensor& ground_truth) override { + model_outputs_ = model_outputs; + ground_truth_ = ground_truth; + was_called_ = true; + return status_to_return_; + } + + bool WasCalled() { return was_called_; } + std::vector<Tensor> model_outputs() { return model_outputs_; } + Tensor ground_truth() { return ground_truth_; } + + private: + std::vector<Tensor> model_outputs_; + Tensor ground_truth_; + Status status_to_return_; + bool was_called_ = false; +}; + +TEST(EvalPipeline, AccuracyEvalIsCalled) { + Scope scope = Scope::NewRootScope(); + // A graph that adds 1 to input. 
+ auto input = ops::Placeholder(scope.WithOpName("input"), DT_FLOAT); + auto add_node = ops::Add(scope.WithOpName("output"), input, 1.0f); + GraphDef graph_def; + TF_CHECK_OK(scope.ToGraphDef(&graph_def)); + EvalPipeline::Params params; + params.model_input_node_name = "input"; + params.model_output_node_name = "output"; + NoOpAccuracyEval accuracy_eval(Status::OK()); + + EvalPipeline eval_pipeline(graph_def, params, &accuracy_eval); + std::unique_ptr<Session> session(NewSession(SessionOptions())); + TF_CHECK_OK(eval_pipeline.AttachSession(std::move(session))); + TF_CHECK_OK(eval_pipeline.Run(CreateFloatTensor(5), CreateFloatTensor(27))); + + EXPECT_TRUE(accuracy_eval.WasCalled()); + auto outputs = accuracy_eval.model_outputs(); + ASSERT_EQ(1, outputs.size()); + EXPECT_EQ(6.0f, outputs[0].scalar<float>()()); + // Ground truth is unchanged. + EXPECT_EQ(27, accuracy_eval.ground_truth().scalar<float>()()); +} + +TEST(EvalPipeline, EvalIsNotCalledOnGraphRunFailure) { + Scope scope = Scope::NewRootScope(); + // A graph that adds 1 to input. + auto input = ops::Placeholder(scope.WithOpName("input"), DT_FLOAT); + auto add_node = ops::Add(scope.WithOpName("output"), input, 1.0f); + GraphDef graph_def; + TF_CHECK_OK(scope.ToGraphDef(&graph_def)); + EvalPipeline::Params params; + params.model_input_node_name = "input"; + params.model_output_node_name = "output"; + NoOpAccuracyEval accuracy_eval(Status::OK()); + + EvalPipeline eval_pipeline(graph_def, params, &accuracy_eval); + std::unique_ptr<Session> session(NewSession(SessionOptions())); + TF_CHECK_OK(eval_pipeline.AttachSession(std::move(session))); + + // Pass a string tensor instead of a float tensor. 
+ Tensor string_tensor(DT_STRING, TensorShape{}); + auto status = eval_pipeline.Run(string_tensor, CreateFloatTensor(27)); + EXPECT_FALSE(accuracy_eval.WasCalled()); + EXPECT_FALSE(status.ok()); +} + +TEST(EvalPipeline, AccuracyEvalFailureResultsInFailure) { + Scope scope = Scope::NewRootScope(); + // A graph that adds 1 to input. + auto input = ops::Placeholder(scope.WithOpName("input"), DT_FLOAT); + auto add_node = ops::Add(scope.WithOpName("output"), input, 1.0f); + GraphDef graph_def; + TF_CHECK_OK(scope.ToGraphDef(&graph_def)); + EvalPipeline::Params params; + params.model_input_node_name = "input"; + params.model_output_node_name = "output"; + NoOpAccuracyEval accuracy_eval(errors::Internal("accuracy_fail")); + + EvalPipeline eval_pipeline(graph_def, params, &accuracy_eval); + std::unique_ptr<Session> session(NewSession(SessionOptions())); + TF_CHECK_OK(eval_pipeline.AttachSession(std::move(session))); + auto status = eval_pipeline.Run(CreateFloatTensor(5), CreateFloatTensor(27)); + + EXPECT_TRUE(accuracy_eval.WasCalled()); + EXPECT_FALSE(status.ok()); +} + +} // namespace + +} // namespace metrics +} // namespace tensorflow + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/tools/accuracy/file_reader_stage.cc b/tensorflow/contrib/lite/tools/accuracy/file_reader_stage.cc new file mode 100644 index 0000000000..61bed369f8 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/file_reader_stage.cc @@ -0,0 +1,29 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h" + +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/standard_ops.h" + +namespace tensorflow { +namespace metrics { +void FileReaderStage::AddToGraph(const Scope& scope, const Input& input) { + if (!scope.ok()) return; + Scope s = scope.WithOpName(name()); + this->stage_output_ = ops::ReadFile(s.WithOpName(output_name()), input); +} +} // namespace metrics +} // namespace tensorflow diff --git a/tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h b/tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h new file mode 100644 index 0000000000..18db5837c1 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_FILE_READER_STAGE_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_FILE_READER_STAGE_H_ + +#include <string> + +#include "tensorflow/contrib/lite/tools/accuracy/stage.h" + +namespace tensorflow { +namespace metrics { +// A stage for reading a file into |string|. +// Inputs: a string tensor: |file_name|. +// Outputs: a string tensor: contents of |file_name|. +class FileReaderStage : public Stage { + public: + string name() const override { return "stage_filereader"; } + string output_name() const override { return "stage_filereader_output"; } + + void AddToGraph(const Scope& scope, const Input& input) override; +}; +} // namespace metrics +} // namespace tensorflow +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_FILE_READER_STAGE_H_ diff --git a/tensorflow/contrib/lite/tools/accuracy/file_reader_stage_test.cc b/tensorflow/contrib/lite/tools/accuracy/file_reader_stage_test.cc new file mode 100644 index 0000000000..a75f99187d --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/file_reader_stage_test.cc @@ -0,0 +1,110 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include <cstdio> +#include <fstream> +#include <memory> + +#include <gtest/gtest.h> +#include "tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace metrics { +namespace { + +class TempFile { + public: + TempFile() { + string file_path; + if (Env::Default()->LocalTempFilename(&file_path)) { + file_path_ = file_path; + created_ = true; + } + } + + string filepath() { return file_path_; } + bool CreateFileWithContents(const std::string& contents) { + if (!created_) { + return false; + } + std::fstream file(file_path_, std::ios_base::out); + if (file) { + file << contents; + } + return file.good(); + } + + ~TempFile() { + if (created_) { + std::remove(file_path_.c_str()); + } + } + + private: + bool created_ = false; + string file_path_; +}; + +TEST(FileReaderStageTest, FileIsRead) { + TempFile file; + const string kFileContents = "Hello world."; + ASSERT_TRUE(file.CreateFileWithContents(kFileContents)); + Scope scope = Scope::NewRootScope(); + FileReaderStage reader_stage; + reader_stage.AddToGraph(scope, file.filepath()); + TF_CHECK_OK(scope.status()); + GraphDef graph_def; + TF_CHECK_OK(scope.ToGraphDef(&graph_def)); + std::unique_ptr<Session> session(NewSession(SessionOptions())); + TF_CHECK_OK(session->Create(graph_def)); + std::vector<Tensor> outputs; + auto run_status = + session->Run({}, /*inputs*/ + {reader_stage.output_name()}, {}, /*target node names */ + &outputs); + TF_CHECK_OK(run_status); + EXPECT_EQ(1, outputs.size()); + string contents = outputs[0].scalar<string>()(); + EXPECT_EQ(kFileContents, contents); +} + +TEST(FileReaderStageTest, InvalidFile) { + Scope scope = Scope::NewRootScope(); + FileReaderStage reader_stage; + reader_stage.AddToGraph(scope, string("non_existent_file")); + TF_CHECK_OK(scope.status()); + GraphDef graph_def; + TF_CHECK_OK(scope.ToGraphDef(&graph_def)); + 
std::unique_ptr<Session> session(NewSession(SessionOptions())); + TF_CHECK_OK(session->Create(graph_def)); + std::vector<Tensor> outputs; + auto run_status = + session->Run({}, /*inputs*/ + {reader_stage.output_name()}, {}, /*target node names */ + &outputs); + EXPECT_FALSE(run_status.ok()); +} + +} // namespace + +} // namespace metrics +} // namespace tensorflow + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/tools/accuracy/imagenet_accuracy_eval.cc b/tensorflow/contrib/lite/tools/accuracy/imagenet_accuracy_eval.cc new file mode 100644 index 0000000000..8103d6adb5 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/imagenet_accuracy_eval.cc @@ -0,0 +1,146 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include <iomanip>
+#include <memory>
+
+#include "absl/memory/memory.h"
+#include "tensorflow/contrib/lite/tools/accuracy/csv_writer.h"
+#include "tensorflow/contrib/lite/tools/accuracy/imagenet_model_evaluator.h"
+#include "tensorflow/contrib/lite/tools/accuracy/imagenet_topk_eval.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace tensorflow {
+namespace metrics {
+
+namespace {
+
+std::vector<double> GetAccuracies(
+ const ImagenetTopKAccuracy::AccuracyStats& accuracy_stats) {
+ std::vector<double> results;
+ results.reserve(accuracy_stats.number_of_images);
+ if (accuracy_stats.number_of_images > 0) {
+ for (int n : accuracy_stats.topk_counts) {
+ double accuracy = 0;
+ if (accuracy_stats.number_of_images > 0) {
+ accuracy = (n * 100.0) / accuracy_stats.number_of_images;
+ }
+ results.push_back(accuracy);
+ }
+ }
+ return results;
+}
+
+} // namespace
+
+// Writes results to a CSV file.
+class ResultsWriter : public ImagenetModelEvaluator::Observer {
+ public:
+ explicit ResultsWriter(std::unique_ptr<CSVWriter> writer)
+ : writer_(std::move(writer)) {}
+
+ void OnEvaluationStart(int total_number_of_images) override {}
+
+ void OnSingleImageEvaluationComplete(
+ const ImagenetTopKAccuracy::AccuracyStats& stats,
+ const string& image) override;
+
+ private:
+ std::unique_ptr<CSVWriter> writer_;
+};
+
+void ResultsWriter::OnSingleImageEvaluationComplete(
+ const ImagenetTopKAccuracy::AccuracyStats& stats, const string& image) {
+ TF_CHECK_OK(writer_->WriteRow(GetAccuracies(stats)));
+ writer_->Flush();
+}
+
+// Logs evaluation progress, throttled to at most one message every `kLogDelayUs` microseconds. 
+class ResultsLogger : public ImagenetModelEvaluator::Observer { + public: + void OnEvaluationStart(int total_number_of_images) override; + + void OnSingleImageEvaluationComplete( + const ImagenetTopKAccuracy::AccuracyStats& stats, + const string& image) override; + + private: + int total_num_images_ = 0; + uint64 last_logged_time_us_ = 0; + static constexpr int kLogDelayUs = 500 * 1000; +}; + +void ResultsLogger::OnEvaluationStart(int total_number_of_images) { + total_num_images_ = total_number_of_images; + LOG(ERROR) << "Starting model evaluation: " << total_num_images_; +} + +void ResultsLogger::OnSingleImageEvaluationComplete( + const ImagenetTopKAccuracy::AccuracyStats& stats, const string& image) { + int num_evaluated = stats.number_of_images; + + double current_percent = num_evaluated * 100.0 / total_num_images_; + auto now_us = Env::Default()->NowMicros(); + + if ((now_us - last_logged_time_us_) >= kLogDelayUs) { + last_logged_time_us_ = now_us; + + LOG(ERROR) << "Evaluated " << num_evaluated << "/" << total_num_images_ + << " images, " << std::setprecision(2) << std::fixed + << current_percent << "%"; + } +} + +int Main(int argc, char* argv[]) { + // TODO(shashishekhar): Make this binary configurable and model + // agnostic. 
+ string output_file_path; + std::vector<Flag> flag_list = { + Flag("output_file_path", &output_file_path, "Path to output file."), + }; + Flags::Parse(&argc, argv, flag_list); + + std::unique_ptr<ImagenetModelEvaluator> evaluator; + CHECK(!output_file_path.empty()) << "Invalid output file path."; + + TF_CHECK_OK(ImagenetModelEvaluator::Create(argc, argv, &evaluator)); + + std::ofstream output_stream(output_file_path, std::ios::out); + CHECK(output_stream) << "Unable to open output file path: '" + << output_file_path << "'"; + + output_stream << std::setprecision(3) << std::fixed; + std::vector<string> columns; + columns.reserve(evaluator->params().num_ranks); + for (int i = 0; i < evaluator->params().num_ranks; i++) { + columns.push_back("Top " + std::to_string(i + 1)); + } + + ResultsWriter results_writer( + absl::make_unique<CSVWriter>(columns, &output_stream)); + ResultsLogger logger; + evaluator->AddObserver(&results_writer); + evaluator->AddObserver(&logger); + TF_CHECK_OK(evaluator->EvaluateModel()); + return 0; +} + +} // namespace metrics +} // namespace tensorflow + +int main(int argc, char* argv[]) { + return tensorflow::metrics::Main(argc, argv); +} diff --git a/tensorflow/contrib/lite/tools/accuracy/imagenet_model_evaluator.cc b/tensorflow/contrib/lite/tools/accuracy/imagenet_model_evaluator.cc new file mode 100644 index 0000000000..6ddde8e7c0 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/imagenet_model_evaluator.cc @@ -0,0 +1,206 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/accuracy/imagenet_model_evaluator.h" + +#include <fstream> +#include <iomanip> +#include <string> +#include <vector> + +#include "absl/memory/memory.h" +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h" +#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h" +#include "tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h" +#include "tensorflow/contrib/lite/tools/accuracy/imagenet_topk_eval.h" +#include "tensorflow/contrib/lite/tools/accuracy/inception_preprocessing.h" +#include "tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.h" +#include "tensorflow/contrib/lite/tools/accuracy/utils.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace { +using tensorflow::string; + +string StripTrailingSlashes(const string& path) { + int end = path.size(); + while (end > 0 && path[end - 1] == '/') { + end--; + } + return path.substr(0, end); +} + +tensorflow::Tensor CreateStringTensor(const string& value) { + tensorflow::Tensor tensor(tensorflow::DT_STRING, tensorflow::TensorShape({})); + tensor.scalar<string>()() = value; + return tensor; +} + +template <typename T> +std::vector<T> GetFirstN(const std::vector<T>& v, int n) { + if (n >= v.size()) return v; + std::vector<T> result(v.begin(), v.begin() + n); + return result; +} + +// File pattern for imagenet files. 
+const char* const kImagenetFilePattern = "*.[jJ][pP][eE][gG]"; + +} // namespace + +namespace tensorflow { +namespace metrics { + +/*static*/ Status ImagenetModelEvaluator::Create( + int argc, char* argv[], + std::unique_ptr<ImagenetModelEvaluator>* model_evaluator) { + Params params; + const std::vector<Flag> flag_list = { + Flag("model_output_labels", ¶ms.model_output_labels_path, + "Path to labels that correspond to output of model." + " E.g. in case of mobilenet, this is the path to label " + "file where each label is in the same order as the output" + " of the model."), + Flag("ground_truth_images_path", ¶ms.ground_truth_images_path, + "Path to ground truth images."), + Flag("ground_truth_labels", ¶ms.ground_truth_labels_path, + "Path to ground truth labels."), + Flag("num_images", ¶ms.number_of_images, + "Number of examples to evaluate, pass 0 for all " + "examples. Default: 100"), + tensorflow::Flag("model_file", ¶ms.model_file_path, + "Path to test tflite model file."), + }; + const bool parse_result = Flags::Parse(&argc, argv, flag_list); + if (!parse_result) + return errors::InvalidArgument("Invalid command line flags"); + ::tensorflow::port::InitMain(argv[0], &argc, &argv); + + TF_RETURN_WITH_CONTEXT_IF_ERROR( + Env::Default()->IsDirectory(params.ground_truth_images_path), + "Invalid ground truth data path."); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + Env::Default()->FileExists(params.ground_truth_labels_path), + "Invalid ground truth labels path."); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + Env::Default()->FileExists(params.model_output_labels_path), + "Invalid model output labels path."); + + if (params.number_of_images < 0) { + return errors::InvalidArgument("Invalid: num_examples"); + } + + utils::ModelInfo model_info; + TF_RETURN_WITH_CONTEXT_IF_ERROR( + utils::GetTFliteModelInfo(params.model_file_path, &model_info), + "Invalid TFLite model."); + + *model_evaluator = + absl::make_unique<ImagenetModelEvaluator>(model_info, params); + return Status::OK(); +} + 
+Status ImagenetModelEvaluator::EvaluateModel() { + if (model_info_.input_shapes.size() != 1) { + return errors::InvalidArgument("Invalid input shape"); + } + + const TensorShape& input_shape = model_info_.input_shapes[0]; + // Input should be of the shape {1, height, width, 3} + if (input_shape.dims() != 4 || input_shape.dim_size(3) != 3) { + return errors::InvalidArgument("Invalid input shape for the model."); + } + + const int image_height = input_shape.dim_size(1); + const int image_width = input_shape.dim_size(2); + const bool is_quantized = (model_info_.input_types[0] == DT_UINT8); + + RunTFLiteModelStage::Params tfl_model_params; + tfl_model_params.model_file_path = params_.model_file_path; + if (is_quantized) { + tfl_model_params.input_type = {DT_UINT8}; + tfl_model_params.output_type = {DT_UINT8}; + } else { + tfl_model_params.input_type = {DT_FLOAT}; + tfl_model_params.output_type = {DT_FLOAT}; + } + + Scope root = Scope::NewRootScope(); + FileReaderStage reader; + InceptionPreprocessingStage inc(image_height, image_width, is_quantized); + RunTFLiteModelStage tfl_model_stage(tfl_model_params); + EvalPipelineBuilder builder; + std::vector<string> model_labels; + TF_RETURN_IF_ERROR( + utils::ReadFileLines(params_.model_output_labels_path, &model_labels)); + if (model_labels.size() != 1001) { + return errors::InvalidArgument("Invalid number of labels: ", + model_labels.size()); + } + + ImagenetTopKAccuracy eval(model_labels, params_.num_ranks); + std::unique_ptr<EvalPipeline> eval_pipeline; + + auto build_status = builder.WithInputStage(&reader) + .WithPreprocessingStage(&inc) + .WithRunModelStage(&tfl_model_stage) + .WithAccuracyEval(&eval) + .WithInput("input_file", DT_STRING) + .Build(root, &eval_pipeline); + TF_RETURN_WITH_CONTEXT_IF_ERROR(build_status, + "Failure while building eval pipeline."); + + std::unique_ptr<Session> session(NewSession(SessionOptions())); + + TF_RETURN_IF_ERROR(eval_pipeline->AttachSession(std::move(session))); + string data_path 
= + StripTrailingSlashes(params_.ground_truth_images_path) + "/"; + + const string imagenet_file_pattern = data_path + kImagenetFilePattern; + std::vector<string> image_files; + TF_CHECK_OK( + Env::Default()->GetMatchingPaths(imagenet_file_pattern, &image_files)); + std::vector<string> image_labels; + TF_CHECK_OK( + utils::ReadFileLines(params_.ground_truth_labels_path, &image_labels)); + CHECK_EQ(image_files.size(), image_labels.size()); + + // Process files in filename sorted order. + std::sort(image_files.begin(), image_files.end()); + if (params_.number_of_images > 0) { + image_files = GetFirstN(image_files, params_.number_of_images); + image_labels = GetFirstN(image_labels, params_.number_of_images); + } + + for (Observer* observer : observers_) { + observer->OnEvaluationStart(image_files.size()); + } + + for (int i = 0; i < image_files.size(); i++) { + TF_CHECK_OK(eval_pipeline->Run(CreateStringTensor(image_files[i]), + CreateStringTensor(image_labels[i]))); + auto stats = eval.GetTopKAccuracySoFar(); + + for (Observer* observer : observers_) { + observer->OnSingleImageEvaluationComplete(stats, image_files[i]); + } + } + return Status::OK(); +} + +} // namespace metrics +} // namespace tensorflow diff --git a/tensorflow/contrib/lite/tools/accuracy/imagenet_model_evaluator.h b/tensorflow/contrib/lite/tools/accuracy/imagenet_model_evaluator.h new file mode 100644 index 0000000000..0308ac95b6 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/imagenet_model_evaluator.h @@ -0,0 +1,113 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_IMAGENET_MODEL_EVALUATOR_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_IMAGENET_MODEL_EVALUATOR_H_ +#include <string> +#include <vector> + +#include "tensorflow/contrib/lite/tools/accuracy/imagenet_topk_eval.h" +#include "tensorflow/contrib/lite/tools/accuracy/utils.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace metrics { + +// Evaluates models accuracy for ILSVRC dataset. +// +// Generates the top-1, top-k accuracy counts where k is +// controlled by |num_ranks|. +// Usage: +// ModelInfo model_info = .. +// ImagenetModelEvaluator::Params params; +// .. set params to image, label, output label and model file path.. +// SomeObserver observer; +// ImagenetModelEvaluator evaluator(model_info, params); +// evaluator.AddObserver(&observer); +// TF_CHECK_OK(evaluator.EvaluateModel()); +class ImagenetModelEvaluator { + public: + struct Params { + // Path to ground truth images. + string ground_truth_images_path; + + // Path to labels file for ground truth image. + // This file should be generated with the scripts. + string ground_truth_labels_path; + + // This is word labels generated by the model. The category + // indices of output probabilities generated by the model maybe different + // from the indices in the imagenet dataset. + string model_output_labels_path; + + // Path to the model file. 
+ string model_file_path; + + // The maximum number of images to calculate accuracy. + // 0 means all images, a positive number means only the specified + // number of images. + int number_of_images = 0; + + // Number of ranks, top K. + int num_ranks = 10; + }; + + // An evaluation observer. + class Observer { + public: + Observer() = default; + Observer(const Observer&) = delete; + Observer& operator=(const Observer&) = delete; + + Observer(const Observer&&) = delete; + Observer& operator=(const Observer&&) = delete; + + // Called on start of evaluation. + virtual void OnEvaluationStart(int total_number_of_images) = 0; + + // Called when evaluation was complete for `image`. + virtual void OnSingleImageEvaluationComplete( + const ImagenetTopKAccuracy::AccuracyStats& stats, + const string& image) = 0; + + virtual ~Observer() = default; + }; + + ImagenetModelEvaluator(const utils::ModelInfo& model_info, + const Params& params) + : model_info_(model_info), params_(params) {} + + // Factory method to create the evaluator by parsing command line arguments. + static Status Create(int argc, char* argv[], + std::unique_ptr<ImagenetModelEvaluator>* evaluator); + + // Adds an observer that can observe evaluation events.. + void AddObserver(Observer* observer) { observers_.push_back(observer); } + + const Params& params() { return params_; } + + // Evaluates the provided model over the dataset. 
+ Status EvaluateModel(); + + private: + std::vector<Observer*> observers_; + const utils::ModelInfo model_info_; + const Params params_; +}; + +} // namespace metrics +} // namespace tensorflow +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_IMAGENET_MODEL_EVALUATOR_H_ diff --git a/tensorflow/contrib/lite/tools/accuracy/imagenet_topk_eval.cc b/tensorflow/contrib/lite/tools/accuracy/imagenet_topk_eval.cc new file mode 100644 index 0000000000..1595bbee2f --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/imagenet_topk_eval.cc @@ -0,0 +1,107 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/accuracy/imagenet_topk_eval.h" + +#include <numeric> + +namespace { +constexpr int kNumCategories = 1001; +std::vector<int> GetTopK(const std::vector<float>& values, int k) { + CHECK_LE(k, values.size()); + std::vector<int> indices(values.size()); + + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), + [&values](int a, int b) { return values[a] > values[b]; }); + + indices.resize(k); + return indices; +} +} // namespace + +namespace tensorflow { +namespace metrics { +ImagenetTopKAccuracy::ImagenetTopKAccuracy( + const std::vector<string>& ground_truth_labels, int k) + : ground_truth_labels_(ground_truth_labels), + k_(k), + accuracy_counts_(k_, 0), + num_samples_(0) { + CHECK_EQ(kNumCategories, ground_truth_labels.size()); +} + +Status ImagenetTopKAccuracy::ComputeEval( + const std::vector<Tensor>& model_outputs, const Tensor& ground_truth) { + if (model_outputs.size() != 1) { + return errors::InvalidArgument("Invalid model output: ", + model_outputs.size()); + } + const Tensor& output = model_outputs[0]; + if (!output.shape().IsSameSize({1, kNumCategories})) { + return errors::InvalidArgument("Invalid shape of model output: ", + output.shape().DebugString()); + } + if (ground_truth.dtype() != DT_STRING && ground_truth.dims() != 0) { + return errors::InvalidArgument("Invalid ground truth type: ", + ground_truth.DebugString()); + } + string ground_truth_label = ground_truth.scalar<string>()(); + + std::vector<float> probabilities; + probabilities.reserve(kNumCategories); + if (output.dtype() == DT_FLOAT) { + auto probs = output.flat<float>(); + for (size_t i = 0; i < probs.size(); i++) { + probabilities.push_back(probs(i)); + } + } else { + auto probs = output.flat<uint8>(); + for (size_t i = 0; i < probs.size(); i++) { + probabilities.push_back(probs(i)); + } + } + + CHECK_EQ(kNumCategories, 
probabilities.size()); + std::vector<int> topK = GetTopK(probabilities, k_); + int ground_truth_index = GroundTruthIndex(ground_truth_label); + for (size_t i = 0; i < topK.size(); ++i) { + if (ground_truth_index == topK[i]) { + for (size_t j = i; j < topK.size(); j++) { + accuracy_counts_[j] += 1; + } + break; + } + } + num_samples_++; + return Status::OK(); +} + +const ImagenetTopKAccuracy::AccuracyStats +ImagenetTopKAccuracy::GetTopKAccuracySoFar() const { + AccuracyStats stats; + stats.number_of_images = num_samples_; + stats.topk_counts = accuracy_counts_; + return stats; +} + +int ImagenetTopKAccuracy::GroundTruthIndex(const string& label) const { + auto index = std::find(ground_truth_labels_.cbegin(), + ground_truth_labels_.cend(), label); + CHECK(index != ground_truth_labels_.end()) << "Invalid label: " << label; + return std::distance(ground_truth_labels_.cbegin(), index); +} +} // namespace metrics +} // namespace tensorflow diff --git a/tensorflow/contrib/lite/tools/accuracy/imagenet_topk_eval.h b/tensorflow/contrib/lite/tools/accuracy/imagenet_topk_eval.h new file mode 100644 index 0000000000..5a575ff244 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/imagenet_topk_eval.h @@ -0,0 +1,80 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
==============================================================================*/

#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_IMAGENET_TOPK_EVAL_H_
#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_IMAGENET_TOPK_EVAL_H_

#include <string>
#include <vector>

#include "tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h"
#include "tensorflow/core/framework/tensor.h"

namespace tensorflow {
namespace metrics {
// An |AccuracyEval| stage that calculates the top K error rate for model
// evaluations on imagenet like datasets.
// Inputs: A {1, 1001} shaped tensor that contains the probabilities for
// objects predicted by the model (1000 imagenet classes + 1 background class).
// Ground truth: A |string| label for the image, which must be one of the
// labels passed to the constructor.
// From the input object probabilities, the stage computes the predicted labels
// and finds the top K error rates by comparing the predictions with ground
// truths.
class ImagenetTopKAccuracy : public AccuracyEval {
 public:
  // Accuracy statistics accumulated over all images evaluated so far.
  struct AccuracyStats {
    // Number of images evaluated.
    int number_of_images;
    // A vector of size |k| that contains the number of images
    // that have correct labels in top K.
    // E.g. topk_counts[0] contains number of images for which
    // model returned the correct label as the first result.
    // Similarly topk_counts[4] contains the number of images for which
    // model returned the correct label in top 5 results.
    // Counts are cumulative: an image counted in topk_counts[i] is also
    // counted in every topk_counts[j] for j > i.
    // This can be used to compute the top K error-rate for the model.
    std::vector<int> topk_counts;
  };

  // Creates a new instance of |ImagenetTopKAccuracy| with the given
  // |ground_truth_labels| and |k|.
  // Args:
  // |ground_truth_labels| : an ordered vector of labels for images. This is
  // used to compute the index for the predicted labels and ground_truth label.
  // Must contain exactly 1001 entries (checked in the constructor).
  ImagenetTopKAccuracy(const std::vector<string>& ground_truth_labels, int k);

  // Computes accuracy for a given image. The |model_outputs| should
  // be a vector containing exactly one Tensor of shape: {1, 1001} where each
  // item is a probability of the predicted object representing the image as
  // output by the model.
  // Uses |ground_truth_labels| to compute the index of |model_outputs| and
  // |ground_truth| and computes the top K error rate.
  // Returns InvalidArgument if the output tensor count/shape or the ground
  // truth tensor is not of the expected form.
  Status ComputeEval(const std::vector<Tensor>& model_outputs,
                     const Tensor& ground_truth) override;

  // Gets the topK accuracy for images that have been evaluated till now.
  // Returns a copy of the internal counters.
  const AccuracyStats GetTopKAccuracySoFar() const;

 private:
  // Maps a label string to its index in |ground_truth_labels_|.
  int GroundTruthIndex(const string& label) const;
  std::vector<string> ground_truth_labels_;
  const int k_;
  // accuracy_counts_[i] = number of images whose ground truth appeared in
  // the top (i + 1) predictions.
  std::vector<int> accuracy_counts_;
  int num_samples_;
};
}  // namespace metrics
}  // namespace tensorflow

#endif  // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_IMAGENET_TOPK_EVAL_H_

/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/contrib/lite/tools/accuracy/imagenet_topk_eval.h"
#include <gtest/gtest.h>

namespace tensorflow {
namespace metrics {
namespace {

const int kNumCategories = 1001;

// Wraps |value| in a DT_STRING scalar tensor, as expected by ComputeEval's
// ground truth argument.
Tensor CreateStringTensor(const string& value) {
  Tensor tensor(DT_STRING, TensorShape({}));
  tensor.scalar<string>()() = value;
  return tensor;
}

// Creates a {1, 1001} float output tensor with all probabilities zeroed.
Tensor CreateOutputTensor() {
  Tensor tensor(DT_FLOAT, TensorShape({1, kNumCategories}));
  for (int i = 0; i < kNumCategories; i++) {
    tensor.flat<float>()(i) = 0;
  }
  return tensor;
}

// Labels are simply "0" .. "1000", so label string i maps to index i.
std::vector<string> CreateGroundTruth() {
  std::vector<string> ground_truth;
  ground_truth.reserve(kNumCategories);
  for (int i = 0; i < kNumCategories; i++) {
    ground_truth.push_back(std::to_string(i));
  }
  return ground_truth;
}

// When the top prediction always matches the ground truth, every top-K
// bucket is incremented for every image.
TEST(ImagenetTopKAccuracy, AllCorrect) {
  ImagenetTopKAccuracy acc_top_5(CreateGroundTruth(), 5);
  auto accuracies = acc_top_5.GetTopKAccuracySoFar();
  EXPECT_EQ(0, accuracies.number_of_images);
  EXPECT_EQ(5, accuracies.topk_counts.size());

  for (int i : accuracies.topk_counts) {
    EXPECT_EQ(0, i);
  }
  // First image was correctly identified as "0".
  Tensor tensor = CreateOutputTensor();
  tensor.flat<float>()(0) = 0.8;

  TF_CHECK_OK(acc_top_5.ComputeEval({tensor}, CreateStringTensor("0")));
  accuracies = acc_top_5.GetTopKAccuracySoFar();
  EXPECT_EQ(1, accuracies.number_of_images);

  for (int i : accuracies.topk_counts) {
    EXPECT_EQ(1, i);
  }
  // Second image: "1" now has the highest probability (0.9 > 0.8), so it is
  // again a rank-1 hit and all buckets advance to 2.
  tensor.flat<float>()(1) = 0.9;
  TF_CHECK_OK(acc_top_5.ComputeEval({tensor}, CreateStringTensor("1")));
  accuracies = acc_top_5.GetTopKAccuracySoFar();
  EXPECT_EQ(2, accuracies.number_of_images);

  for (int i : accuracies.topk_counts) {
    EXPECT_EQ(2, i);
  }
}

// Exercises hits at different ranks: only buckets at and beyond the hit rank
// are incremented, and a miss increments nothing.
TEST(ImagenetTopKAccuracy, Top5) {
  ImagenetTopKAccuracy acc_top_5(CreateGroundTruth(), 5);
  auto accuracies = acc_top_5.GetTopKAccuracySoFar();
  EXPECT_EQ(0, accuracies.number_of_images);
  EXPECT_EQ(5, accuracies.topk_counts.size());

  // For first image, with ground truth "0" probabilities were
  // 0.5 for "0",
  // "0.6" for 1,
  // "0.7" for 2,
  // "0.8" for 3,
  // "0.9" for 4.
  // remaining all zeroes.

  // First image: ground truth "0" is ranked 5th, so only the last bucket
  // is incremented.
  Tensor tensor = CreateOutputTensor();
  tensor.flat<float>()(0) = 0.5;
  tensor.flat<float>()(1) = 0.6;
  tensor.flat<float>()(2) = 0.7;
  tensor.flat<float>()(3) = 0.8;
  tensor.flat<float>()(4) = 0.9;

  TF_CHECK_OK(acc_top_5.ComputeEval({tensor}, CreateStringTensor("0")));
  accuracies = acc_top_5.GetTopKAccuracySoFar();
  EXPECT_EQ(1, accuracies.number_of_images);
  EXPECT_EQ(1, accuracies.topk_counts[4]);

  for (int i = 0; i < 4; i++) {
    EXPECT_EQ(0, accuracies.topk_counts[i]);
  }

  // Now for "1" (ranked 4th) only last two buckets are going to be affected.
  TF_CHECK_OK(acc_top_5.ComputeEval({tensor}, CreateStringTensor("1")));
  accuracies = acc_top_5.GetTopKAccuracySoFar();
  EXPECT_EQ(2, accuracies.number_of_images);
  EXPECT_EQ(1, accuracies.topk_counts[3]);
  EXPECT_EQ(2, accuracies.topk_counts[4]);
  for (int i = 0; i < 3; i++) {
    EXPECT_EQ(0, accuracies.topk_counts[i]);
  }

  // "4" has the highest probability, so all buckets will be affected.
  TF_CHECK_OK(acc_top_5.ComputeEval({tensor}, CreateStringTensor("4")));
  accuracies = acc_top_5.GetTopKAccuracySoFar();
  EXPECT_EQ(3, accuracies.number_of_images);
  EXPECT_EQ(1, accuracies.topk_counts[0]);
  EXPECT_EQ(1, accuracies.topk_counts[1]);
  EXPECT_EQ(1, accuracies.topk_counts[2]);
  EXPECT_EQ(2, accuracies.topk_counts[3]);
  EXPECT_EQ(3, accuracies.topk_counts[4]);

  // "10" has probability zero and is outside the top 5: no buckets affected.
  TF_CHECK_OK(acc_top_5.ComputeEval({tensor}, CreateStringTensor("10")));
  accuracies = acc_top_5.GetTopKAccuracySoFar();
  EXPECT_EQ(4, accuracies.number_of_images);
  EXPECT_EQ(1, accuracies.topk_counts[0]);
  EXPECT_EQ(1, accuracies.topk_counts[1]);
  EXPECT_EQ(1, accuracies.topk_counts[2]);
  EXPECT_EQ(2, accuracies.topk_counts[3]);
  EXPECT_EQ(3, accuracies.topk_counts[4]);
}

}  // namespace

}  // namespace metrics
}  // namespace tensorflow

int main(int argc, char** argv) {
  ::testing::InitGoogleTest(&argc, argv);

  return RUN_ALL_TESTS();
}

/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/contrib/lite/tools/accuracy/inception_preprocessing.h"

#include <memory>

#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"

namespace tensorflow {
namespace metrics {

namespace {
// Adds graph ops that crop |decoded_image| to its central |crop_fraction|
// (the same fraction in each spatial dimension), writing the resulting op
// output to |*cropped_image|. The channel dimension is kept whole via the
// -1 slice size.
void CentralCropImage(const Scope& s, const tensorflow::Output& decoded_image,
                      double crop_fraction, tensorflow::Output* cropped_image) {
  // image_dims = {height, width} of the decoded image.
  auto image_dims = ops::Slice(s, ops::Shape(s, decoded_image), {0}, {2});
  auto height_width = ops::Cast(s, image_dims, DT_DOUBLE);
  // Margin on each side: (dim - dim * crop_fraction) / 2.
  auto cropped_begin = ops::Div(
      s, ops::Sub(s, height_width, ops::Mul(s, height_width, crop_fraction)),
      2.0);
  auto bbox_begin = ops::Cast(s, cropped_begin, DT_INT32);
  // Crop size = dim - 2 * margin (recomputed from the integer begin so the
  // crop stays in bounds after rounding).
  auto bbox_size = ops::Sub(s, image_dims, ops::Mul(s, bbox_begin, 2));
  // Extend begin/size to rank 3: start at channel 0, take all channels (-1).
  auto slice_begin = ops::Concat(s, {bbox_begin, Input({0})}, 0);
  auto slice_size = ops::Concat(s, {bbox_size, {-1}}, 0);
  *cropped_image = ops::Slice(s, decoded_image, slice_begin, slice_size);
}

}  // namespace

// Builds the preprocessing subgraph: decode JPEG -> central crop -> resize.
// Float models additionally get mean subtraction and scaling; quantized
// models get a plain cast to uint8 instead.
void InceptionPreprocessingStage::AddToGraph(const Scope& scope,
                                             const Input& input) {
  if (!scope.ok()) return;
  Scope s = scope.WithOpName(name());
  ops::DecodeJpeg::Attrs attrs;
  attrs.channels_ = 3;
  auto decoded_jpeg = ops::DecodeJpeg(s, input, attrs);
  tensorflow::Output cropped_image;
  CentralCropImage(s, decoded_jpeg, params_.cropping_fraction, &cropped_image);
  // ResizeBilinear expects a batch dimension, hence the ExpandDims to
  // {1, h, w, 3}.
  auto dims_expander = ops::ExpandDims(s, cropped_image, 0);
  auto resized_image = ops::ResizeBilinear(
      s, dims_expander,
      ops::Const(s.WithOpName("size"), {image_height_, image_width_}));
  if (is_quantized_) {
    // Quantized models take raw uint8 pixel values; no normalization.
    this->stage_output_ =
        ops::Cast(s.WithOpName(output_name()), resized_image, DT_UINT8);
  } else {
    // Normalize: (pixel - mean) / scale, then restore the batch dimension
    // (Squeeze/ExpandDims so the per-channel Sub broadcasts over {h, w, 3}).
    auto squeezed_image = ops::Squeeze(s, resized_image);
    auto normalized_image =
        ops::Div(s,
                 ops::Sub(s, squeezed_image,
                          {params_.input_means[0], params_.input_means[1],
                           params_.input_means[2]}),
                 {params_.scale});
    this->stage_output_ =
        ops::ExpandDims(s.WithOpName(output_name()), normalized_image, {0});
  }
}

}  // namespace metrics
}  // namespace tensorflow

/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_INCEPTION_PREPROCESSING_H_
#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_INCEPTION_PREPROCESSING_H_

#include <utility>

#include "tensorflow/contrib/lite/tools/accuracy/stage.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
namespace metrics {

// A stage that does inception preprocessing.
// Inputs: A tensor containing bytes of a JPEG image.
// Outputs: A tensor containing rescaled and preprocessed image that has
// shape {1, image_height, image_width, 3}, where 3 is the number of channels.
class InceptionPreprocessingStage : public Stage {
 public:
  struct Params {
    // Per-channel means subtracted from the image in the float path.
    std::vector<float> input_means;
    // Divisor applied after mean subtraction in the float path.
    float scale;
    // Fraction of each spatial dimension kept by the central crop.
    double cropping_fraction;
  };

  static Params DefaultParams() {
    return {.input_means = {127.5, 127.5, 127.5},
            .scale = 127.5,
            .cropping_fraction = 0.875};
  }

  // Creates a new preprocessing stage object with provided |image_width|
  // |image_height| as the size of output image.
  // If |is_quantized| is set to true then |params| is ignored since quantized
  // images don't go through any preprocessing.
  // NOTE(review): per AddToGraph, quantized images are still cropped and
  // resized; only the mean/scale normalization is skipped — confirm whether
  // the comment above should be narrowed.
  InceptionPreprocessingStage(int image_width, int image_height,
                              bool is_quantized,
                              Params params = DefaultParams())
      : image_width_(image_width),
        image_height_(image_height),
        is_quantized_(is_quantized),
        params_(std::move(params)) {}

  string name() const override { return "stage_inception_preprocess"; }
  string output_name() const override {
    return "stage_inception_preprocess_output";
  }

  // Appends the preprocessing subgraph (decode/crop/resize/normalize) to
  // |scope|, consuming |input| as the JPEG bytes tensor.
  void AddToGraph(const Scope& scope, const Input& input) override;

 private:
  int image_width_;
  int image_height_;
  bool is_quantized_;
  Params params_;
};

}  // namespace metrics
}  // namespace tensorflow

#endif  // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_INCEPTION_PREPROCESSING_H_

/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <fstream> +#include <string> + +#include <gtest/gtest.h> +#include "tensorflow/contrib/lite/tools/accuracy/inception_preprocessing.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace { +tensorflow::string* g_test_image_file = nullptr; +} // namespace + +namespace tensorflow { +namespace metrics { + +namespace { + +using tensorflow::Status; +using tensorflow::Tensor; + +Status GetContents(const string& filename, string* output) { + std::ifstream input(filename, std::ios::binary); + const int kBufferSize = 2048; + char buffer[kBufferSize]; + while (true) { + input.read(buffer, kBufferSize); + output->append(buffer, input.gcount()); + if (!input.good()) { + if (input.eof()) return Status::OK(); + return Status(tensorflow::error::ABORTED, "Failed to read file."); + } + } +} + +TEST(InceptionPreprocessingTest, TestImagePreprocessQuantized) { + ASSERT_TRUE(g_test_image_file != nullptr); + string image_contents; + string image_path = *g_test_image_file; + auto status = GetContents(image_path, &image_contents); + ASSERT_TRUE(status.ok()) << status.error_message(); + const int width = 224; + const int height = 224; + const bool is_quantized = true; + InceptionPreprocessingStage preprocess_stage(width, height, is_quantized); + Scope scope = Scope::NewRootScope(); + preprocess_stage.AddToGraph(scope, image_contents); + TF_CHECK_OK(scope.status()); + + GraphDef graph_def; + 
TF_CHECK_OK(scope.ToGraphDef(&graph_def)); + std::unique_ptr<Session> session(NewSession(SessionOptions())); + TF_CHECK_OK(session->Create(graph_def)); + std::vector<Tensor> outputs; + auto run_status = + session->Run({}, /*inputs*/ + {preprocess_stage.output_name()}, {}, /*target node names */ + &outputs); + TF_CHECK_OK(run_status); + EXPECT_EQ(1, outputs.size()); + EXPECT_EQ(DT_UINT8, outputs[0].dtype()); + EXPECT_TRUE(outputs[0].shape().IsSameSize({1, 224, 224, 3})); +} + +TEST(InceptionPreprocessingTest, TestImagePreprocessFloat) { + ASSERT_TRUE(g_test_image_file != nullptr); + string image_contents; + string image_path = *g_test_image_file; + auto status = GetContents(image_path, &image_contents); + ASSERT_TRUE(status.ok()) << status.error_message(); + const int width = 224; + const int height = 224; + const bool is_quantized = false; + InceptionPreprocessingStage preprocess_stage(width, height, is_quantized); + Scope scope = Scope::NewRootScope(); + preprocess_stage.AddToGraph(scope, image_contents); + TF_CHECK_OK(scope.status()); + + GraphDef graph_def; + TF_CHECK_OK(scope.ToGraphDef(&graph_def)); + std::unique_ptr<Session> session(NewSession(SessionOptions())); + TF_CHECK_OK(session->Create(graph_def)); + std::vector<Tensor> outputs; + auto run_status = + session->Run({}, /*inputs*/ + {preprocess_stage.output_name()}, {}, /*target node names */ + &outputs); + TF_CHECK_OK(run_status); + EXPECT_EQ(1, outputs.size()); + EXPECT_EQ(DT_FLOAT, outputs[0].dtype()); + EXPECT_TRUE(outputs[0].shape().IsSameSize({1, 224, 224, 3})); +} + +} // namespace +} // namespace metrics +} // namespace tensorflow + +int main(int argc, char** argv) { + g_test_image_file = new tensorflow::string(); + const std::vector<tensorflow::Flag> flag_list = { + tensorflow::Flag("test_image", g_test_image_file, + "Path to image file for test."), + }; + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + CHECK(parse_result) << "Required test_model_file"; + 
  ::tensorflow::port::InitMain(argv[0], &argc, &argv);
  return RUN_ALL_TESTS();
}

/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <vector>

#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/op_resolver.h"
#include "tensorflow/contrib/lite/tools/accuracy/utils.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

namespace {
// Checks that the TF input tensors match the TFLite interpreter's inputs in
// count, dtype and shape. Returns InvalidArgument on any mismatch.
Status ValidateInputsMatch(const OpInputList& input_tensors,
                           const tflite::Interpreter& interpreter) {
  std::vector<int> tflite_tensor_indices = interpreter.inputs();
  if (tflite_tensor_indices.size() != input_tensors.size()) {
    return errors::InvalidArgument(
        "size mismatch, interpreter size: ", tflite_tensor_indices.size(),
        " actual: ", input_tensors.size());
  }

  for (int i = 0; i < input_tensors.size(); i++) {
    const TfLiteTensor* tflite_tensor =
        interpreter.tensor(tflite_tensor_indices[i]);
    if (tflite_tensor == nullptr) {
      return errors::InvalidArgument("Tensor is null at index: ", i);
    }

    const Tensor& tensor = input_tensors[i];
    // Compare via TF-equivalent dtype/shape of the TFLite tensor.
    auto i_type = metrics::utils::GetTFDataType(tflite_tensor->type);
    auto i_shape = metrics::utils::GetTFLiteTensorShape(*tflite_tensor);
    if (i_type != tensor.dtype()) {
      return errors::InvalidArgument("Data types mismatch for tensors: ", i,
                                     " expected: ", i_type,
                                     " got: ", tensor.dtype());
    }

    if (i_shape != tensor.shape()) {
      return errors::InvalidArgument("Data shapes mismatch for tensors: ", i,
                                     " expected: ", i_shape,
                                     " got: ", tensor.shape());
    }
  }

  return Status::OK();
}

}  // namespace

// A TF op kernel that loads a TFLite model at construction time and, on each
// Compute() call, copies the op inputs into the interpreter, invokes it, and
// copies the interpreter outputs back out.
// NOTE(review): interpreter_ is shared mutable state reused across Compute()
// calls with no locking — confirm this kernel is never invoked concurrently.
class RunTFLiteModelOp : public OpKernel {
 public:
  explicit RunTFLiteModelOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    string model_file_path;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("model_file_path", &model_file_path));
    model_ = tflite::FlatBufferModel::BuildFromFile(model_file_path.data());
    OP_REQUIRES(ctx, model_,
                errors::InvalidArgument(
                    "Model loading failed. Invalid model file path: ",
                    model_file_path));
    tflite::ops::builtin::BuiltinOpResolver resolver;

    tflite::InterpreterBuilder(*model_, resolver)(&interpreter_);
    OP_REQUIRES(ctx, interpreter_,
                errors::Internal("Interpreter creation failed."));
  }

  void Compute(OpKernelContext* context) override {
    OpInputList input_tensors;
    OP_REQUIRES_OK(context, context->input_list("model_input", &input_tensors));

    OP_REQUIRES_OK(context, ValidateInputsMatch(input_tensors, *interpreter_));
    OpOutputList output_tensors;
    OP_REQUIRES_OK(context,
                   context->output_list("model_output", &output_tensors));
    auto tfl_outputs = interpreter_->outputs();
    // Output arity and dtypes must match the interpreter before running.
    OP_REQUIRES(context, output_tensors.size() == tfl_outputs.size(),
                errors::InvalidArgument(
                    "Invalid output size, expected: ", tfl_outputs.size(),
                    " got: ", output_tensors.size()));
    for (int i = 0; i < output_tensors.size(); i++) {
      DataType tfl_type = metrics::utils::GetTFDataType(
          interpreter_->tensor(tfl_outputs[i])->type);
      DataType otype = output_tensors.expected_output_dtype(i);
      OP_REQUIRES(
          context, tfl_type == otype,
          errors::InvalidArgument("Invalid data type for output at index: ", i,
                                  " expected: ", tfl_type, " got: ", otype));
    }

    auto allocation_status = interpreter_->AllocateTensors();
    OP_REQUIRES(context, allocation_status == kTfLiteOk,
                errors::Internal("Unable to allocate tensors."));
    // Copy each TF input tensor's raw bytes into the corresponding TFLite
    // input buffer (byte counts were validated above via shape/dtype).
    for (int i = 0; i < input_tensors.size(); i++) {
      const int tfl_index = interpreter_->inputs()[i];
      TfLiteTensor* tflite_tensor = interpreter_->tensor(tfl_index);
      auto tensor_bytes = input_tensors[i].tensor_data();
      OP_REQUIRES(context, tflite_tensor->bytes == tensor_bytes.size(),
                  errors::InvalidArgument(
                      "Size mismatch, expected: ", tflite_tensor->bytes,
                      " got: ", tensor_bytes.size()));
      std::memcpy(tflite_tensor->data.raw, tensor_bytes.data(),
                  tensor_bytes.size());
    }
    auto invocation_status = interpreter_->Invoke();
    OP_REQUIRES(context, invocation_status == kTfLiteOk,
                errors::Internal("Interpreter invocation failed."));
    // Copy the TFLite output buffers back into freshly allocated TF outputs.
    for (int i = 0; i < output_tensors.size(); i++) {
      auto tfl_tensor = interpreter_->tensor(tfl_outputs[i]);
      TensorShape shape = metrics::utils::GetTFLiteTensorShape(*tfl_tensor);
      Tensor* output = nullptr;
      OP_REQUIRES_OK(context, output_tensors.allocate(i, shape, &output));
      auto tensor_bytes = output->tensor_data();
      OP_REQUIRES(context, tensor_bytes.size() == tfl_tensor->bytes,
                  errors::Internal("Invalid size"));
      // tensor_data() returns a const view; const_cast is used here to write
      // into the tensor this kernel just allocated and exclusively owns.
      std::memcpy(const_cast<char*>(tensor_bytes.data()), tfl_tensor->data.raw,
                  tfl_tensor->bytes);
    }
  }

 private:
  std::unique_ptr<tflite::FlatBufferModel> model_;
  std::unique_ptr<tflite::Interpreter> interpreter_;
};

REGISTER_KERNEL_BUILDER(Name("RunTFLiteModel").Device(DEVICE_CPU),
                        RunTFLiteModelOp);

REGISTER_OP("RunTFLiteModel")
    .Input("model_input: input_type")
    .Output("model_output: output_type")
    .Attr("model_file_path: string")
    .Attr("input_type : list(type)")
    .Attr("output_type: list(type)")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // TODO(shashishekhar): Infer the correct shape based on output_type and
      // maybe another attribute.
      return shape_inference::UnknownShape(c);
    });

}  // namespace tensorflow

/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"

namespace {
// Path to the TFLite model under test, set from a command line flag in main()
// (main is outside this excerpt).
tensorflow::string* g_test_model_file = nullptr;
}

namespace tensorflow {
namespace {

// Runs the RunTFLiteModel op with correctly shaped inputs and checks the
// model's two outputs element-wise.
TEST(RunTfliteModelOpTest, ModelIsRun) {
  ASSERT_TRUE(g_test_model_file != nullptr);
  string test_model_file = *g_test_model_file;
  ASSERT_FALSE(test_model_file.empty());

  Scope scope = Scope::NewRootScope();
  TF_CHECK_OK(scope.status());
  // Passed graph has 4 inputs : a,b,c,d and 2 outputs x,y
  // x = a+b+c, y=b+c+d

  std::vector<Input> graph_inputs = {
      ops::Const(scope, 1.0f, {1, 8, 8, 3}),  // a
      ops::Const(scope, 2.1f, {1, 8, 8, 3}),  // b
      ops::Const(scope, 3.2f, {1, 8, 8, 3}),  // c
      ops::Const(scope, 4.3f, {1, 8, 8, 3}),  // d
  };

  std::vector<NodeBuilder::NodeOut> input_data;
  std::transform(graph_inputs.begin(), graph_inputs.end(),
                 std::back_inserter(input_data), [&scope](Input model_input) {
                   return ops::AsNodeOut(scope, model_input);
                 });

  std::vector<DataType> model_input_type = {DT_FLOAT, DT_FLOAT, DT_FLOAT,
                                            DT_FLOAT};
  // Build the RunTFLiteModel node by hand: there is no generated C++ op
  // wrapper for this custom op.
  ::tensorflow::Node* ret;
  auto builder = ::tensorflow::NodeBuilder("run_model_op", "RunTFLiteModel")
                     .Input(input_data)
                     .Attr("model_file_path", test_model_file)
                     .Attr("input_type", model_input_type)
                     .Attr("output_type", {DT_FLOAT, DT_FLOAT});

  scope.UpdateBuilder(&builder);
  scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
  TF_CHECK_OK(scope.status());

  GraphDef graph_def;
  TF_CHECK_OK(scope.ToGraphDef(&graph_def));

  std::unique_ptr<Session> session(NewSession(SessionOptions()));
  TF_CHECK_OK(session->Create(graph_def));

  std::vector<Tensor> outputs;
  TF_CHECK_OK(
      session->Run({}, {"run_model_op:0", "run_model_op:1"}, {}, &outputs));
  EXPECT_EQ(2, outputs.size());

  for (const auto& tensor : outputs) {
    EXPECT_TRUE(tensor.shape().IsSameSize({1, 8, 8, 3}));
  }
  auto output_x = outputs[0].flat<float>();
  auto output_y = outputs[1].flat<float>();
  EXPECT_EQ(1 * 8 * 8 * 3, output_x.size());
  EXPECT_EQ(1 * 8 * 8 * 3, output_y.size());
  for (int i = 0; i < output_x.size(); i++) {
    EXPECT_NEAR(6.3f, output_x(i), 1e-6f);  // a+b+c
    EXPECT_NEAR(9.6f, output_y(i), 1e-6f);  // b+c+d
  }
}

// Feeding fewer inputs than the model expects must fail at run time
// (the op's input-count validation rejects it).
TEST(RunTfliteModelOpTest, NumInputsMismatch) {
  ASSERT_TRUE(g_test_model_file != nullptr);
  string test_model_file = *g_test_model_file;
  ASSERT_FALSE(test_model_file.empty());

  Scope scope = Scope::NewRootScope();
  TF_CHECK_OK(scope.status());
  // Passed graph has 4 inputs : a,b,c,d and 2 outputs x,y
  // x = a+b+c, y=b+c+d
  // Remove a from input.

  std::vector<Input> graph_inputs = {
      ops::Const(scope, 2.1f, {1, 8, 8, 3}),  // b
      ops::Const(scope, 3.2f, {1, 8, 8, 3}),  // c
      ops::Const(scope, 4.3f, {1, 8, 8, 3}),  // d
  };

  std::vector<NodeBuilder::NodeOut> input_data;
  std::transform(graph_inputs.begin(), graph_inputs.end(),
                 std::back_inserter(input_data), [&scope](Input model_input) {
                   return ops::AsNodeOut(scope, model_input);
                 });

  std::vector<DataType> model_input_type = {DT_FLOAT, DT_FLOAT, DT_FLOAT};

  ::tensorflow::Node* ret;
  auto builder = ::tensorflow::NodeBuilder("run_model_op", "RunTFLiteModel")
                     .Input(input_data)
                     .Attr("model_file_path", test_model_file)
                     .Attr("input_type", model_input_type)
                     .Attr("output_type", {DT_FLOAT, DT_FLOAT});

  scope.UpdateBuilder(&builder);
  scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
  TF_CHECK_OK(scope.status());

  GraphDef graph_def;
  TF_CHECK_OK(scope.ToGraphDef(&graph_def));
  std::unique_ptr<Session> session(NewSession(SessionOptions()));
  TF_CHECK_OK(session->Create(graph_def));

  std::vector<Tensor> outputs;
  auto status =
      (session->Run({}, {"run_model_op:0", "run_model_op:1"}, {}, &outputs));
  EXPECT_FALSE(status.ok());
}

// Feeding an input with the wrong shape must also fail at run time.
// (Test body continues beyond this excerpt.)
TEST(RunTfliteModelOpTest, InputSizesMismatch) {
  ASSERT_TRUE(g_test_model_file != nullptr);
  string test_model_file = *g_test_model_file;
  ASSERT_FALSE(test_model_file.empty());

  Scope scope = Scope::NewRootScope();
  TF_CHECK_OK(scope.status());
  // Passed graph has 4 inputs : a,b,c,d and 2 outputs x,y
  // x = a+b+c, y=b+c+d
  // Set a to be invalid size.
+ std::vector<Input> graph_inputs = { + ops::Const(scope, 1.0f, {1, 8, 8, 4}), // a invalid size, + ops::Const(scope, 2.1f, {1, 8, 8, 3}), // b + ops::Const(scope, 3.2f, {1, 8, 8, 3}), // c + ops::Const(scope, 4.3f, {1, 8, 8, 3}), // d + }; + + std::vector<NodeBuilder::NodeOut> input_data; + std::transform(graph_inputs.begin(), graph_inputs.end(), + std::back_inserter(input_data), [&scope](Input model_input) { + return ops::AsNodeOut(scope, model_input); + }); + + std::vector<DataType> model_input_type = {DT_FLOAT, DT_FLOAT, DT_FLOAT, + DT_FLOAT}; + ::tensorflow::Node* ret; + auto builder = ::tensorflow::NodeBuilder("run_model_op", "RunTFLiteModel") + .Input(input_data) + .Attr("model_file_path", test_model_file) + .Attr("input_type", model_input_type) + .Attr("output_type", {DT_FLOAT, DT_FLOAT}); + + scope.UpdateBuilder(&builder); + scope.UpdateStatus(builder.Finalize(scope.graph(), &ret)); + TF_CHECK_OK(scope.status()); + + GraphDef graph_def; + TF_CHECK_OK(scope.ToGraphDef(&graph_def)); + std::unique_ptr<Session> session(NewSession(SessionOptions())); + TF_CHECK_OK(session->Create(graph_def)); + + std::vector<Tensor> outputs; + auto status = + (session->Run({}, {"run_model_op:0", "run_model_op:1"}, {}, &outputs)); + EXPECT_FALSE(status.ok()); +} + +} // namespace +} // namespace tensorflow + +int main(int argc, char** argv) { + g_test_model_file = new tensorflow::string(); + const std::vector<tensorflow::Flag> flag_list = { + tensorflow::Flag("test_model_file", g_test_model_file, + "Path to test tflite model file."), + }; + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + CHECK(parse_result) << "Required test_model_file"; + ::tensorflow::port::InitMain(argv[0], &argc, &argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.cc b/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.cc new file mode 100644 index 0000000000..c96795d499 --- /dev/null +++ 
b/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.cc @@ -0,0 +1,45 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.h" + +#include <vector> + +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/standard_ops.h" + +namespace tensorflow { +namespace metrics { +void RunTFLiteModelStage::AddToGraph(const Scope& scope, const Input& input) { + if (!scope.ok()) return; + Scope s = scope.WithOpName(name()); + + std::vector<NodeBuilder::NodeOut> _data = {ops::AsNodeOut(s, input)}; + ::tensorflow::Node* ret; + auto builder = NodeBuilder(output_name(), "RunTFLiteModel") + .Input(_data) + .Attr("model_file_path", params_.model_file_path) + .Attr("input_type", params_.input_type) + .Attr("output_type", params_.output_type); + + s.UpdateBuilder(&builder); + s.UpdateStatus(builder.Finalize(s.graph(), &ret)); + if (!s.ok()) return; + s.UpdateStatus(s.DoShapeInference(ret)); + this->stage_output_ = ::tensorflow::Output(ret, 0); +} + +} // namespace metrics +} // namespace tensorflow diff --git a/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.h b/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.h new file mode 100644 index 0000000000..90d12d6f42 --- /dev/null +++ 
b/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.h @@ -0,0 +1,53 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_RUN_TFLITE_MODEL_STAGE_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_RUN_TFLITE_MODEL_STAGE_H_ + +#include <string> + +#include "tensorflow/contrib/lite/tools/accuracy/stage.h" + +namespace tensorflow { +namespace metrics { +// Stage that loads and runs a TFLite model. +// Inputs: The input to TFLite model. +// Outputs: The output of running the TFLite model. +class RunTFLiteModelStage : public Stage { + public: + // The parameters for the stage. + struct Params { + string model_file_path; + std::vector<TensorShape> output_shape; + std::vector<DataType> input_type; + std::vector<DataType> output_type; + }; + + explicit RunTFLiteModelStage(const Params& params) : params_(params) {} + + string name() const override { return "stage_run_tfl_model"; } + // TODO(shashishekhar): This stage can have multiple inputs and + // outputs, perhaps change the definition of stage. 
+ string output_name() const override { return "stage_run_tfl_model_output"; } + + void AddToGraph(const Scope& scope, const Input& input) override; + + private: + Params params_; +}; + +} // namespace metrics +} // namespace tensorflow +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_RUN_TFLITE_MODEL_STAGE_H_ diff --git a/tensorflow/contrib/lite/tools/accuracy/stage.h b/tensorflow/contrib/lite/tools/accuracy/stage.h new file mode 100644 index 0000000000..8292ea2ec7 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/stage.h @@ -0,0 +1,56 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_STAGE_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_STAGE_H_ + +#include "tensorflow/cc/framework/scope.h" + +namespace tensorflow { +namespace metrics { + +// A stage in an evaluation pipeline. +// Each stage adds a subgraph to the pipeline. Stages can be chained +// together. +class Stage { + public: + Stage() = default; + Stage(const Stage&) = delete; + Stage& operator=(const Stage&) = delete; + + Stage(const Stage&&) = delete; + Stage& operator=(const Stage&&) = delete; + + // Adds a subgraph to given scope that takes in `input` as a parameter. + virtual void AddToGraph(const Scope& scope, const Input& input) = 0; + virtual ~Stage() {} + + // The name of the stage. 
+ // Can be used by derived classes for naming the subscope for the stage + // graph. + virtual string name() const = 0; + + // The name of the output for the stage. + virtual string output_name() const = 0; + + const ::tensorflow::Output& Output() const { return stage_output_; } + + protected: + ::tensorflow::Output stage_output_; +}; +} // namespace metrics +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_STAGE_H_ diff --git a/tensorflow/contrib/lite/tools/accuracy/testdata/grace_hopper.jpg b/tensorflow/contrib/lite/tools/accuracy/testdata/grace_hopper.jpg Binary files differnew file mode 100644 index 0000000000..d2a427810f --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/testdata/grace_hopper.jpg diff --git a/tensorflow/contrib/lite/tools/accuracy/utils.cc b/tensorflow/contrib/lite/tools/accuracy/utils.cc new file mode 100644 index 0000000000..f5493301fc --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/utils.cc @@ -0,0 +1,102 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/accuracy/utils.h" + +#include <sys/stat.h> + +#include <cstring> +#include <fstream> +#include <memory> +#include <string> + +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/op_resolver.h" + +namespace tensorflow { +namespace metrics { + +namespace utils { + +DataType GetTFDataType(TfLiteType tflite_type) { + switch (tflite_type) { + case kTfLiteFloat32: + return DT_FLOAT; + case kTfLiteUInt8: + return DT_UINT8; + default: + return DT_INVALID; + } +} + +TensorShape GetTFLiteTensorShape(const TfLiteTensor& tflite_tensor) { + TensorShape shape; + for (int i = 0; i < tflite_tensor.dims->size; i++) { + shape.AddDim(tflite_tensor.dims->data[i]); + } + return shape; +} + +Status ReadFileLines(const string& file_path, + std::vector<string>* lines_output) { + if (!lines_output) { + return errors::InvalidArgument("Invalid output"); + } + std::vector<string> lines; + std::ifstream stream(file_path, std::ios_base::in); + if (!stream) { + return errors::InvalidArgument("Unable to open file: ", file_path); + } + std::string line; + while (std::getline(stream, line)) { + lines_output->push_back(line); + } + return Status::OK(); +} + +Status GetTFliteModelInfo(const string& model_file_path, + ModelInfo* model_info) { + if (model_file_path.empty()) { + return errors::InvalidArgument("Invalid model file."); + } + struct stat stat_buf; + if (stat(model_file_path.c_str(), &stat_buf) != 0) { + int error_num = errno; + return errors::InvalidArgument("Invalid model file: ", model_file_path, + std::strerror(error_num)); + } + + std::unique_ptr<tflite::FlatBufferModel> model; + std::unique_ptr<tflite::Interpreter> interpreter; + model = tflite::FlatBufferModel::BuildFromFile(model_file_path.data()); + 
tflite::ops::builtin::BuiltinOpResolver resolver; + + tflite::InterpreterBuilder(*model, resolver)(&interpreter); + if (!interpreter) { + return errors::InvalidArgument("Invalid model", model_file_path); + } + for (int i : interpreter->inputs()) { + TfLiteTensor* tensor = interpreter->tensor(i); + model_info->input_shapes.push_back(utils::GetTFLiteTensorShape(*tensor)); + model_info->input_types.push_back(utils::GetTFDataType(tensor->type)); + } + return Status::OK(); +} + +} // namespace utils +} // namespace metrics +} // namespace tensorflow diff --git a/tensorflow/contrib/lite/tools/accuracy/utils.h b/tensorflow/contrib/lite/tools/accuracy/utils.h new file mode 100644 index 0000000000..37cbad4d51 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/utils.h @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_UTILS_H_
#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_UTILS_H_

#include <string>
#include <vector>

#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/core/framework/tensor_shape.h"

namespace tensorflow {
namespace metrics {

namespace utils {

// Shapes and TF dtypes of a TFLite model's inputs, one entry per input
// tensor, in the interpreter's input order.
struct ModelInfo {
  std::vector<TensorShape> input_shapes;
  std::vector<DataType> input_types;
};

// Loads the TFLite model at `model_file_path` and fills `model_info` with
// the model's input shapes and types. Returns a non-OK status if the file
// is missing or not a valid model.
Status GetTFliteModelInfo(const string& model_file_path, ModelInfo* model_info);

// Maps a TFLite tensor type to the corresponding TensorFlow DataType
// (DT_INVALID for types the tool does not handle).
DataType GetTFDataType(TfLiteType tflite_type);

// Converts a TFLite tensor's dims into a TensorShape.
TensorShape GetTFLiteTensorShape(const TfLiteTensor& tflite_tensor);

// Reads `file_path` line by line, appending each line to `lines_output`.
Status ReadFileLines(const string& file_path,
                     std::vector<string>* lines_output);
}  // namespace utils
}  // namespace metrics
}  // namespace tensorflow
#endif  // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_UTILS_H_
+==============================================================================*/ + +#include <string> +#include <vector> + +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include "tensorflow/contrib/lite/tools/accuracy/utils.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace { +tensorflow::string* g_test_model_file = nullptr; +} + +namespace tensorflow { +namespace metrics { +namespace utils { +namespace { + +TEST(UtilsTest, GetTFLiteModelInfoReturnsCorrectly) { + ASSERT_TRUE(g_test_model_file != nullptr); + string test_model_file = *g_test_model_file; + ASSERT_FALSE(test_model_file.empty()); + // Passed graph has 4 inputs : a,b,c,d and 2 outputs x,y + // x = a+b+c, y=b+c+d + // Input and outputs have shape : {1,8,8,3} + ModelInfo model_info; + auto status = GetTFliteModelInfo(test_model_file, &model_info); + TF_CHECK_OK(status); + ASSERT_EQ(4, model_info.input_shapes.size()); + ASSERT_EQ(4, model_info.input_types.size()); + + for (int i = 0; i < 4; i++) { + const TensorShape& shape = model_info.input_shapes[i]; + DataType dataType = model_info.input_types[i]; + EXPECT_TRUE(shape.IsSameSize({1, 8, 8, 3})); + EXPECT_EQ(DT_FLOAT, dataType); + } +} + +TEST(UtilsTest, GetTFliteModelInfoIncorrectFile) { + ModelInfo model_info; + auto status = GetTFliteModelInfo("non_existent_file", &model_info); + EXPECT_FALSE(status.ok()); +} + +} // namespace +} // namespace utils +} // namespace metrics +} // namespace tensorflow + +int main(int argc, char** argv) { + g_test_model_file = new tensorflow::string(); + const std::vector<tensorflow::Flag> flag_list = { + tensorflow::Flag("test_model_file", g_test_model_file, + "Path to test tflite model file."), + }; + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + CHECK(parse_result) << "Required test_model_file"; + ::tensorflow::port::InitMain(argv[0], &argc, &argv); + return RUN_ALL_TESTS(); +} |