diff options
-rw-r--r-- | tensorflow/contrib/lite/toco/BUILD | 37 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/args.h | 7 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/model_cmdline_flags.cc | 6 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/toco.cc | 97 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/toco_cmdline_flags.cc | 98 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/toco_saved_model.cc | 186 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/toco_saved_model.h | 53 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/toco_saved_model_test.cc | 274 |
8 files changed, 690 insertions, 68 deletions
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index 486ff1edcd..102740ee47 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -124,6 +124,7 @@ cc_library( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) @@ -168,6 +169,41 @@ cc_library( ) cc_library( + name = "toco_saved_model", + srcs = [ + "toco_saved_model.cc", + ], + hdrs = [ + "toco_saved_model.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":model_cmdline_flags", + ":model_flags_proto_cc", + ":toco_flags_proto_cc", + ":types_proto_cc", + "//tensorflow/cc/tools:freeze_saved_model", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "toco_saved_model_test", + srcs = ["toco_saved_model_test.cc"], + deps = [ + ":model_cmdline_flags", + ":toco_cmdline_flags", + ":toco_saved_model", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:scope", + "//tensorflow/core:test", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( name = "graph_transformations", srcs = [ "graph_transformations/convert_expanddims_to_reshape.cc", @@ -363,6 +399,7 @@ tf_cc_binary( ":toco_cmdline_flags", ":toco_flags_proto_cc", ":toco_port", + ":toco_saved_model", ":toco_tooling", ":types_proto_cc", "//tensorflow/core:lib", diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h index 59a6115920..7b71792ff7 100644 --- a/tensorflow/contrib/lite/toco/args.h +++ b/tensorflow/contrib/lite/toco/args.h @@ -190,6 +190,7 @@ struct ParsedModelFlags { Arg<string> output_array; Arg<string> output_arrays; Arg<string> input_shapes; + Arg<int> batch_size = Arg<int>(1); Arg<float> mean_value = Arg<float>(0.f); Arg<string> mean_values; Arg<float> std_value = Arg<float>(1.f); @@ -215,9 +216,11 @@ struct ParsedModelFlags { // you want). See toco_cmdline_flags.cc for details. struct ParsedTocoFlags { Arg<string> input_file; + Arg<string> savedmodel_directory; Arg<string> output_file; - Arg<string> input_format; - Arg<string> output_format; + Arg<string> input_format = Arg<string>("TENSORFLOW_GRAPHDEF"); + Arg<string> output_format = Arg<string>("TFLITE"); + Arg<string> savedmodel_tagset; // TODO(aselle): command_line_flags doesn't support doubles Arg<float> default_ranges_min = Arg<float>(0.); Arg<float> default_ranges_max = Arg<float>(0.); diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc index 4e2dec15a5..4264f21c76 100644 --- a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc +++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc @@ -72,6 +72,12 @@ bool ParseModelFlagsFromCommandLineFlags( "Shapes corresponding to --input_arrays, colon-separated. For " "many models each shape takes the form batch size, input array " "height, input array width, input array depth."), + Flag("batch_size", parsed_flags.batch_size.bind(), + parsed_flags.batch_size.default_value(), + "Batch size for the model. Replaces the first dimension of an " + "input size array if undefined. Use only with SavedModels when " + "--input_shapes flag is not specified. Always use --input_shapes " + "flag with frozen graphs."), Flag("input_data_type", parsed_flags.input_data_type.bind(), parsed_flags.input_data_type.default_value(), "Deprecated: use --input_data_types instead. Input array type, if " diff --git a/tensorflow/contrib/lite/toco/toco.cc b/tensorflow/contrib/lite/toco/toco.cc index f01ec0ec61..8041aa9e7f 100644 --- a/tensorflow/contrib/lite/toco/toco.cc +++ b/tensorflow/contrib/lite/toco/toco.cc @@ -23,40 +23,70 @@ limitations under the License. #include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h" #include "tensorflow/contrib/lite/toco/toco_flags.pb.h" #include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/contrib/lite/toco/toco_saved_model.h" #include "tensorflow/contrib/lite/toco/toco_tooling.h" #include "tensorflow/contrib/lite/toco/toco_types.h" #include "tensorflow/core/platform/logging.h" -#ifndef CHECK_OK -#define CHECK_OK(val) CHECK_EQ((val).ok(), true) -#define QCHECK_OK(val) QCHECK_EQ((val).ok(), true) -#endif - namespace toco { namespace { -#define QCHECK_REQUIRE_TOCO_FLAG(arg) \ - QCHECK(parsed_toco_flags.arg.specified()) << "Missing required flag: " #arg; - -void CheckFilePermissions(const ParsedTocoFlags& parsed_toco_flags, - const ParsedModelFlags& parsed_model_flags, - const TocoFlags& toco_flags) { - port::CheckInitGoogleIsDone("InitGoogle is not done yet"); - - QCHECK_REQUIRE_TOCO_FLAG(input_file) - QCHECK_OK(port::file::Exists(parsed_toco_flags.input_file.value(), - port::file::Defaults())) - << "Specified input_file does not exist: " - << parsed_toco_flags.input_file.value(); - QCHECK_OK(port::file::Readable(parsed_toco_flags.input_file.value(), - port::file::Defaults())) +// Checks the permissions of the output file to ensure it is writeable. +void CheckOutputFilePermissions(const Arg<string>& output_file) { + QCHECK(output_file.specified()) << "Missing required flag --output_file.\n"; + QCHECK(port::file::Writable(output_file.value()).ok()) + << "Specified output_file is not writable: " << output_file.value() + << ".\n"; +} + +// Checks the permissions of the frozen model file. +void CheckFrozenModelPermissions(const Arg<string>& input_file) { + QCHECK(input_file.specified()) << "Missing required flag --input_file.\n"; + QCHECK(port::file::Exists(input_file.value(), port::file::Defaults()).ok()) + << "Specified input_file does not exist: " << input_file.value() << ".\n"; + QCHECK(port::file::Readable(input_file.value(), port::file::Defaults()).ok()) << "Specified input_file exists, but is not readable: " - << parsed_toco_flags.input_file.value(); + << input_file.value() << ".\n"; +} - QCHECK_REQUIRE_TOCO_FLAG(output_file); - QCHECK_OK(port::file::Writable(parsed_toco_flags.output_file.value())) - << "parsed_toco_flags.input_file.value() output_file is not writable: " - << parsed_toco_flags.output_file.value(); +// Checks the permissions of the SavedModel directory. +void CheckSavedModelPermissions(const Arg<string>& savedmodel_directory) { + QCHECK(savedmodel_directory.specified()) + << "Missing required flag --savedmodel_directory.\n"; + QCHECK( + port::file::Exists(savedmodel_directory.value(), port::file::Defaults()) + .ok()) + << "Specified savedmodel_directory does not exist: " + << savedmodel_directory.value() << ".\n"; +} + +// Reads the contents of the GraphDef from either the frozen graph file or the +// SavedModel directory. If it reads the SavedModel directory, it updates the +// ModelFlags and TocoFlags accordingly. +void ReadInputData(const ParsedTocoFlags& parsed_toco_flags, + const ParsedModelFlags& parsed_model_flags, + TocoFlags* toco_flags, ModelFlags* model_flags, + string* graph_def_contents) { + port::CheckInitGoogleIsDone("InitGoogle is not done yet.\n"); + + bool has_input_file = parsed_toco_flags.input_file.specified(); + bool has_savedmodel_dir = parsed_toco_flags.savedmodel_directory.specified(); + + // Ensure either input_file or savedmodel_directory flag has been set. + QCHECK_NE(has_input_file, has_savedmodel_dir) + << "Specify either input_file or savedmodel_directory flag.\n"; + + // Checks the input file permissions and reads the contents. + if (has_input_file) { + CheckFrozenModelPermissions(parsed_toco_flags.input_file); + CHECK(port::file::GetContents(parsed_toco_flags.input_file.value(), + graph_def_contents, port::file::Defaults()) + .ok()); + } else { + CheckSavedModelPermissions(parsed_toco_flags.savedmodel_directory); + GetSavedModelContents(parsed_toco_flags, parsed_model_flags, toco_flags, + model_flags, graph_def_contents); + } } void ToolMain(const ParsedTocoFlags& parsed_toco_flags, @@ -67,21 +97,20 @@ void ToolMain(const ParsedTocoFlags& parsed_toco_flags, TocoFlags toco_flags; ReadTocoFlagsFromCommandLineFlags(parsed_toco_flags, &toco_flags); - CheckFilePermissions(parsed_toco_flags, parsed_model_flags, toco_flags); + string graph_def_contents; + ReadInputData(parsed_toco_flags, parsed_model_flags, &toco_flags, + &model_flags, &graph_def_contents); + CheckOutputFilePermissions(parsed_toco_flags.output_file); - string input_file_contents; - CHECK_OK(port::file::GetContents(parsed_toco_flags.input_file.value(), - &input_file_contents, - port::file::Defaults())); std::unique_ptr<Model> model = - Import(toco_flags, model_flags, input_file_contents); + Import(toco_flags, model_flags, graph_def_contents); Transform(toco_flags, model.get()); string output_file_contents; Export(toco_flags, *model, toco_flags.allow_custom_ops(), &output_file_contents); - CHECK_OK(port::file::SetContents(parsed_toco_flags.output_file.value(), - output_file_contents, - port::file::Defaults())); + CHECK(port::file::SetContents(parsed_toco_flags.output_file.value(), + output_file_contents, port::file::Defaults()) + .ok()); } } // namespace diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc index 0f67c2de72..cc7803dd86 100644 --- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc +++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" #include "absl/strings/strip.h" +#include "absl/types/optional.h" #include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h" #include "tensorflow/contrib/lite/toco/toco_port.h" #include "tensorflow/core/platform/logging.h" @@ -38,6 +39,9 @@ bool ParseTocoFlagsFromCommandLineFlags( "Input file (model of any supported format). For Protobuf " "formats, both text and binary are supported regardless of file " "extension."), + Flag("savedmodel_directory", parsed_flags.savedmodel_directory.bind(), + parsed_flags.savedmodel_directory.default_value(), + "Full path to the directory containing the SavedModel."), Flag("output_file", parsed_flags.output_file.bind(), parsed_flags.output_file.default_value(), "Output file. " @@ -49,6 +53,11 @@ bool ParseTocoFlagsFromCommandLineFlags( parsed_flags.output_format.default_value(), "Output file format. " "One of TENSORFLOW_GRAPHDEF, TFLITE, GRAPHVIZ_DOT."), + Flag("savedmodel_tagset", parsed_flags.savedmodel_tagset.bind(), + parsed_flags.savedmodel_tagset.default_value(), + "Comma-separated set of tags identifying the MetaGraphDef within " + "the SavedModel to analyze. All tags in the tag set must be " + "specified."), Flag("default_ranges_min", parsed_flags.default_ranges_min.bind(), parsed_flags.default_ranges_min.default_value(), "If defined, will be used as the default value for the min bound " @@ -128,47 +137,72 @@ bool ParseTocoFlagsFromCommandLineFlags( } } +namespace { + +// Defines the requirements for a given flag. kUseDefault means the default +// should be used in cases where the value isn't specified by the user. +enum class FlagRequirement { + kNone, + kMustBeSpecified, + kMustNotBeSpecified, + kUseDefault, +}; + +// Enforces the FlagRequirements are met for a given flag. +template <typename T> +void EnforceFlagRequirement(const T& flag, const string& flag_name, + FlagRequirement requirement) { + if (requirement == FlagRequirement::kMustBeSpecified) { + QCHECK(flag.specified()) << "Missing required flag " << flag_name; + } + if (requirement == FlagRequirement::kMustNotBeSpecified) { + QCHECK(!flag.specified()) + << "Given other flags, this flag should not have been specified: " + << flag_name; + } +} + +// Gets the value from the flag if specified. Returns default if the +// FlagRequirement is kUseDefault. +template <typename T> +absl::optional<T> GetFlagValue(const Arg<T>& flag, + FlagRequirement requirement) { + if (flag.specified()) return flag.value(); + if (requirement == FlagRequirement::kUseDefault) return flag.default_value(); + return absl::optional<T>(); +} + +} // namespace + void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags, TocoFlags* toco_flags) { namespace port = toco::port; port::CheckInitGoogleIsDone("InitGoogle is not done yet"); - enum class FlagRequirement { kNone, kMustBeSpecified, kMustNotBeSpecified }; - -#define ENFORCE_FLAG_REQUIREMENT(name, requirement) \ - do { \ - if (requirement == FlagRequirement::kMustBeSpecified) { \ - QCHECK(parsed_toco_flags.name.specified()) \ - << "Missing required flag: " << #name; \ - } \ - if (requirement == FlagRequirement::kMustNotBeSpecified) { \ - QCHECK(!parsed_toco_flags.name.specified()) \ - << "Given other flags, this flag should not have been specified: " \ - << #name; \ - } \ - } while (false) -#define READ_TOCO_FLAG(name, requirement) \ - ENFORCE_FLAG_REQUIREMENT(name, requirement); \ - do { \ - if (parsed_toco_flags.name.specified()) { \ - toco_flags->set_##name(parsed_toco_flags.name.value()); \ - } \ +#define READ_TOCO_FLAG(name, requirement) \ + do { \ + EnforceFlagRequirement(parsed_toco_flags.name, #name, requirement); \ + auto flag_value = GetFlagValue(parsed_toco_flags.name, requirement); \ + if (flag_value.has_value()) { \ + toco_flags->set_##name(flag_value.value()); \ + } \ } while (false) -#define PARSE_TOCO_FLAG(Type, name, requirement) \ - ENFORCE_FLAG_REQUIREMENT(name, requirement); \ - do { \ - if (parsed_toco_flags.name.specified()) { \ - Type x; \ - QCHECK(Type##_Parse(parsed_toco_flags.name.value(), &x)) \ - << "Unrecognized " << #Type << " value " \ - << parsed_toco_flags.name.value(); \ - toco_flags->set_##name(x); \ - } \ +#define PARSE_TOCO_FLAG(Type, name, requirement) \ + do { \ + EnforceFlagRequirement(parsed_toco_flags.name, #name, requirement); \ + auto flag_value = GetFlagValue(parsed_toco_flags.name, requirement); \ + if (flag_value.has_value()) { \ + Type x; \ + QCHECK(Type##_Parse(flag_value.value(), &x)) \ + << "Unrecognized " << #Type << " value " \ + << parsed_toco_flags.name.value(); \ + toco_flags->set_##name(x); \ + } \ } while (false) - PARSE_TOCO_FLAG(FileFormat, input_format, FlagRequirement::kMustBeSpecified); - PARSE_TOCO_FLAG(FileFormat, output_format, FlagRequirement::kMustBeSpecified); + PARSE_TOCO_FLAG(FileFormat, input_format, FlagRequirement::kUseDefault); + PARSE_TOCO_FLAG(FileFormat, output_format, FlagRequirement::kUseDefault); PARSE_TOCO_FLAG(IODataType, inference_type, FlagRequirement::kNone); PARSE_TOCO_FLAG(IODataType, inference_input_type, FlagRequirement::kNone); READ_TOCO_FLAG(default_ranges_min, FlagRequirement::kNone); diff --git a/tensorflow/contrib/lite/toco/toco_saved_model.cc b/tensorflow/contrib/lite/toco/toco_saved_model.cc new file mode 100644 index 0000000000..91a742b9e0 --- /dev/null +++ b/tensorflow/contrib/lite/toco/toco_saved_model.cc @@ -0,0 +1,186 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <string> +#include <vector> + +#include "absl/strings/numbers.h" +#include "tensorflow/contrib/lite/toco/model_cmdline_flags.h" +#include "tensorflow/contrib/lite/toco/toco_saved_model.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" + +namespace toco { +namespace { + +// Loads a SavedModel from the directory specified in parsed_toco_flags. +// Returns a SavedModelBundle with the requested MetaGraphDef. +const tensorflow::SavedModelBundle* LoadSavedModel( + const ParsedTocoFlags& parsed_toco_flags) { + const string model_path = parsed_toco_flags.savedmodel_directory.value(); + QCHECK(tensorflow::MaybeSavedModelDirectory(model_path)) + << "Model is not saved in the supported SavedModel format.\n"; + + // Gets the tags identifying the MetaGraphDef from the command line arguments. + QCHECK(parsed_toco_flags.savedmodel_tagset.specified()) + << "Missing required flag --savedmodel_tagset.\n"; + const string tags_str = parsed_toco_flags.savedmodel_tagset.value(); + auto tags = absl::StrSplit(tags_str, ','); + + // Loads MetaGraphDef. + auto* bundle = new tensorflow::SavedModelBundle; + TF_CHECK_OK(tensorflow::LoadSavedModel(tensorflow::SessionOptions(), + tensorflow::RunOptions(), model_path, + tags, bundle)) + << "Failed to load exported model from " << model_path + << ". Ensure the model contains the required tags '" << tags_str + << "'.\n"; + return bundle; +} + +// Returns the array name without the postfix. +// +// e.g. reduces "input:0" to "input". +string GetArrayName(const string& name) { + const std::vector<string>& names = absl::StrSplit(name, ':'); + return names[0]; +} + +// Returns the list of array names without the postfix sorted alphabetically. +std::set<string> GetSortedNames(const std::unordered_set<string>& names) { + std::vector<string> final_names; + final_names.reserve(names.size()); + for (const auto& name : names) { + final_names.push_back(GetArrayName(name)); + } + return std::set<string>(final_names.begin(), final_names.end()); +} + +// Gets the final shape after replacing the first dimension with batch size, if +// it is undefined (containing the value -1). Returns whether the shape is +// valid. +bool ReplaceShapeBatchSize(const tensorflow::TensorShapeProto& shape, + int batch_size, + tensorflow::TensorShapeProto* final_shape) { + for (int idx = 0; idx < shape.dim().size(); ++idx) { + int64 final_dim = shape.dim()[idx].size(); + if (final_dim == -1) { + if (idx > 0) return false; + final_dim = batch_size; + } + final_shape->add_dim()->set_size(final_dim); + } + return true; +} + +// Updates the input arrays in ModelFlags to contain the shape of the array. +void ProcessInputShapes(const tensorflow::GraphDef& graph_def, int batch_size, + ModelFlags* model_flags) { + // Build map of input array names to input arrays. + std::unordered_map<string, InputArray*> input_data_map; + for (auto& input : *model_flags->mutable_input_arrays()) { + input_data_map[input.name()] = &input; + } + + // Adds shapes to the input arrays if the shape is valid. + for (const tensorflow::NodeDef& node_def : graph_def.node()) { + if (input_data_map.find(node_def.name()) != input_data_map.end()) { + const auto shape_it = node_def.attr().find("shape"); + if (shape_it != node_def.attr().end()) { + tensorflow::TensorShapeProto final_shape; + bool is_valid = ReplaceShapeBatchSize(shape_it->second.shape(), + batch_size, &final_shape); + + if (is_valid) { + auto* shape = input_data_map.at(node_def.name())->mutable_shape(); + QCHECK_EQ(shape->dims_size(), 0) + << "The shape for the input '" << node_def.name() + << "' was previously defined. For clarity please define inputs " + << "via --input_arrays and input_shapes flags.\n"; + for (const auto& dim : final_shape.dim()) { + shape->add_dims(dim.size()); + } + } + } + } + } + + // Checks all input arrays have a shape. + for (auto const& input : model_flags->input_arrays()) { + QCHECK(input.shape().dims_size() > 0) + << "A valid input shape was not found for input '" << input.name() + << "'. Please define via --input_arrays and --input_shapes flags.\n"; + } +} + +} // namespace + +void ParseMetaData(const tensorflow::GraphDef& graph_def, + const std::unordered_set<string>& inputs, + const std::unordered_set<string>& outputs, + const ParsedTocoFlags& parsed_toco_flags, + const ParsedModelFlags& parsed_model_flags, + TocoFlags* toco_flags, ModelFlags* model_flags) { + if (!parsed_model_flags.input_arrays.specified()) { + const std::set<string> sorted_inputs = GetSortedNames(inputs); + for (const auto& input_name : sorted_inputs) { + model_flags->add_input_arrays()->set_name(input_name); + } + } + + if (!parsed_model_flags.output_arrays.specified()) { + const std::set<string> sorted_outputs = GetSortedNames(outputs); + for (const auto& output_name : sorted_outputs) { + model_flags->add_output_arrays(GetArrayName(output_name)); + } + } + + if (!parsed_model_flags.input_shapes.specified()) { + int batch_size = parsed_model_flags.batch_size.value(); + ProcessInputShapes(graph_def, batch_size, model_flags); + } + + if (!parsed_toco_flags.inference_type.specified()) { + toco_flags->set_inference_type(IODataType::FLOAT); + } +} + +// TODO(nupurgarg): Add top level tests. +void GetSavedModelContents(const ParsedTocoFlags& parsed_toco_flags, + const ParsedModelFlags& parsed_model_flags, + TocoFlags* toco_flags, ModelFlags* model_flags, + string* graph_def_contents) { + // Loads the MetaGraphDef within a SavedModelBundle. + auto bundle = LoadSavedModel(parsed_toco_flags); + + // Converts the MetaGraphDef to frozen GraphDef. + tensorflow::GraphDef frozen_graph_def; + std::unordered_set<string> inputs; + std::unordered_set<string> outputs; + TF_CHECK_OK(tensorflow::FreezeSavedModel(*bundle, &frozen_graph_def, &inputs, + &outputs)); + + // Reads the frozen GraphDef into a string. + QCHECK(frozen_graph_def.SerializeToString(graph_def_contents)) + << "Unable to generate serialized GraphDef.\n"; + + // Process inputs and outputs and metadata within GraphDef. + const tensorflow::GraphDef graph_def = bundle->meta_graph_def.graph_def(); + ParseMetaData(graph_def, inputs, outputs, parsed_toco_flags, + parsed_model_flags, toco_flags, model_flags); +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/toco_saved_model.h b/tensorflow/contrib/lite/toco/toco_saved_model.h new file mode 100644 index 0000000000..7a0fabd82d --- /dev/null +++ b/tensorflow/contrib/lite/toco/toco_saved_model.h @@ -0,0 +1,53 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_SAVED_MODEL_H_ +#define TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_SAVED_MODEL_H_ + +#include <string> +#include <vector> + +#include "tensorflow/cc/tools/freeze_saved_model.h" +#include "tensorflow/contrib/lite/toco/args.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/toco_flags.pb.h" +#include "tensorflow/contrib/lite/toco/types.pb.h" + +namespace toco { + +// Parses metadata into `toco_flags` and `model_flags`. +// +// Stores `inputs` as input_arrays and `outputs` as output_arrays in +// `model_flags`. Infers input_shapes from the GraphDef and stores it in +// `model_flags` as part of the input_arrays. Assumes inference_type is FLOAT +// and stores it in `toco_flags`. +void ParseMetaData(const tensorflow::GraphDef& graph_def, + const std::unordered_set<string>& inputs, + const std::unordered_set<string>& outputs, + const ParsedTocoFlags& parsed_toco_flags, + const ParsedModelFlags& parsed_model_flags, + TocoFlags* toco_flags, ModelFlags* model_flags); + +// Generates a frozen graph from the SavedModel in the directory specified in +// `toco_flags`. Reads frozen graph contents into `graph_def_contents`. Parses +// metadata relating to the GraphDef into `toco_flags` and `model_flags`. +void GetSavedModelContents(const ParsedTocoFlags& parsed_toco_flags, + const ParsedModelFlags& parsed_model_flags, + TocoFlags* toco_flags, ModelFlags* model_flags, + string* graph_def_contents); + +} // namespace toco + +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_SAVED_MODEL_H_ diff --git a/tensorflow/contrib/lite/toco/toco_saved_model_test.cc b/tensorflow/contrib/lite/toco/toco_saved_model_test.cc new file mode 100644 index 0000000000..5e122afe65 --- /dev/null +++ b/tensorflow/contrib/lite/toco/toco_saved_model_test.cc @@ -0,0 +1,274 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/toco/toco_saved_model.h" +#include "absl/strings/str_join.h" +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/contrib/lite/toco/model_cmdline_flags.h" +#include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +namespace toco { +namespace { + +using tensorflow::ops::Add; +using tensorflow::ops::Const; +using tensorflow::ops::FakeQuantWithMinMaxArgs; +using tensorflow::ops::Placeholder; + +class TocoSavedModelTest : public ::testing::Test { + protected: + // Calls functions to process cmdline arguments and calls ParseMetaData. + // ParseMetaData parses input_arrays, output_arrays, and gets metadata from + // SavedModel it is not defined in the cmdline arguments. + void ProcessGraphDefMetadata(const std::unordered_set<string>& inputs, + const std::unordered_set<string>& outputs, + const tensorflow::GraphDef& graph_def) { + ReadTocoFlagsFromCommandLineFlags(parsed_toco_flags_, &toco_flags_); + ReadModelFlagsFromCommandLineFlags(parsed_model_flags_, &model_flags_); + ParseMetaData(graph_def, inputs, outputs, parsed_toco_flags_, + parsed_model_flags_, &toco_flags_, &model_flags_); + } + + // Gets the GraphDef from the SavedModelBundle and processes metadata. + void ProcessSavedModelMetadata(const std::unordered_set<string>& inputs, + const std::unordered_set<string>& outputs) { + const tensorflow::GraphDef graph_def = bundle_.meta_graph_def.graph_def(); + ProcessGraphDefMetadata(inputs, outputs, graph_def); + } + + // Returns a GraphDef representing a simple float model with a single input. + tensorflow::GraphDef GetFloatGraphDef(const std::vector<int64>& shape) { + tensorflow::GraphDef graph_def; + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + + tensorflow::Output input = + Placeholder(scope.WithOpName("input"), tensorflow::DT_FLOAT, + Placeholder::Shape(tensorflow::PartialTensorShape(shape))); + tensorflow::Output zero = Const(scope.WithOpName("zero"), 0.0f, {}); + tensorflow::Output add = Add(scope.WithOpName("add"), input, zero); + + TF_EXPECT_OK(scope.ToGraphDef(&graph_def)); + return graph_def; + } + + // Returns a GraphDef representing a simple float model with two inputs. + tensorflow::GraphDef GetComplexFloatGraphDef() { + tensorflow::GraphDef graph_def; + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + + tensorflow::Output inputA = + Placeholder(scope.WithOpName("inputA"), tensorflow::DT_FLOAT, + Placeholder::Shape(tensorflow::TensorShape({1, 3, 3, 1}))); + tensorflow::Output inputB = + Placeholder(scope.WithOpName("inputB"), tensorflow::DT_FLOAT, + Placeholder::Shape(tensorflow::TensorShape({1, 3, 3, 1}))); + tensorflow::Output add = Add(scope.WithOpName("add"), inputB, inputA); + + TF_EXPECT_OK(scope.ToGraphDef(&graph_def)); + return graph_def; + } + + // Returns a GraphDef representing a simple quantized model. + tensorflow::GraphDef GetQuantizedGraphDef() { + tensorflow::GraphDef graph_def; + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + + tensorflow::Output input = + Placeholder(scope.WithOpName("input"), tensorflow::DT_FLOAT, + Placeholder::Shape(tensorflow::TensorShape({1, 3, 3, 1}))); + tensorflow::Output zero = Const(scope.WithOpName("zero"), 0.0f, {}); + tensorflow::Output fake_quant = + FakeQuantWithMinMaxArgs(scope.WithOpName("quant"), zero); + tensorflow::Output add = Add(scope.WithOpName("add"), input, fake_quant); + + TF_EXPECT_OK(scope.ToGraphDef(&graph_def)); + return graph_def; + } + + // Gets the values in the input_arrays flag. + std::vector<string> GetInputArrays() { + std::vector<string> actual; + for (const auto& input : model_flags_.input_arrays()) { + actual.push_back(input.name()); + } + return actual; + } + + // Gets the values in the output_arrays flag. + std::vector<string> GetOutputArrays() { + std::vector<string> actual(model_flags_.output_arrays().begin(), + model_flags_.output_arrays().end()); + return actual; + } + + // Gets the shape of the given input array. + string GetInputShape(const string& input_array) { + for (const auto& input : model_flags_.input_arrays()) { + if (input.name() == input_array) { + std::vector<string> dims; + for (int idx = 0; idx < input.shape().dims_size(); ++idx) { + dims.push_back(std::to_string(input.shape().dims(idx))); + } + return absl::StrJoin(dims, ","); + } + } + return ""; + } + + tensorflow::SavedModelBundle bundle_; + ParsedTocoFlags parsed_toco_flags_; + ParsedModelFlags parsed_model_flags_; + TocoFlags toco_flags_; + ModelFlags model_flags_; +}; + +// Tests if input_arrays, output_arrays, inference_type, and output_arrays are +// added to ModelFlags if they are not specified in cmdline arguments. +// Tests if the default batch size replaces a -1 in the first dimension. +TEST_F(TocoSavedModelTest, NoCmdLine) { + tensorflow::GraphDef graph_def = GetFloatGraphDef({-1, 3, 3, 1}); + + ProcessGraphDefMetadata({"input"}, {"add"}, graph_def); + EXPECT_EQ(GetInputArrays(), std::vector<string>({"input"})); + EXPECT_EQ(GetOutputArrays(), std::vector<string>({"add"})); + EXPECT_EQ(GetInputShape("input"), "1,3,3,1"); + EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); +} + +// Tests if the order of input_arrays and output_arrays is deterministic when +// they are taken from the SavedModel. +TEST_F(TocoSavedModelTest, NoCmdLineMultipleArrays) { + tensorflow::GraphDef graph_def = GetComplexFloatGraphDef(); + + // Note: The model does not have two outputs. However, the function does not + // need an accurate output_array list. This is only meant to test order. + ProcessGraphDefMetadata({"inputB", "inputA"}, {"add", "invalid"}, graph_def); + EXPECT_EQ(GetInputArrays(), std::vector<string>({"inputA", "inputB"})); + EXPECT_EQ(GetOutputArrays(), std::vector<string>({"add", "invalid"})); + EXPECT_EQ(GetInputShape("inputA"), "1,3,3,1"); + EXPECT_EQ(GetInputShape("inputB"), "1,3,3,1"); + EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); +} + +// Tests if input_shapes is inferred when input_arrays is passed in via cmdline +// arguments. +TEST_F(TocoSavedModelTest, InputNameWithoutInputShape) { + parsed_model_flags_.input_arrays.bind()("input"); + tensorflow::GraphDef graph_def = GetFloatGraphDef({2, 3, 3, 1}); + + ProcessGraphDefMetadata({"not_used_input"}, {"add"}, graph_def); + EXPECT_EQ(GetInputArrays(), std::vector<string>({"input"})); + EXPECT_EQ(GetOutputArrays(), std::vector<string>({"add"})); + EXPECT_EQ(GetInputShape("input"), "2,3,3,1"); + EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); +} + +// Ensures a failure occurs when input_shapes is defined without input_arrays. +TEST_F(TocoSavedModelTest, InputShapeWithoutInputName) { + parsed_model_flags_.input_shapes.bind()("1,224,224,1:9,12"); + tensorflow::GraphDef graph_def = GetFloatGraphDef({1, 3, 3, 1}); + + EXPECT_DEATH(ProcessGraphDefMetadata({"input"}, {"add"}, graph_def), + "failed: input_shapes.size\\(\\) == " + "model_flags->input_arrays_size\\(\\)"); +} + +// Tests if the cmdline values of input_arrays, input_shapes are used when +// specified with an empty GraphDef. +TEST_F(TocoSavedModelTest, InputArraysCmdLine) { + parsed_model_flags_.input_arrays.bind()("inputA,inputB"); + parsed_model_flags_.input_shapes.bind()("1,224,224,1:9,12"); + + ProcessSavedModelMetadata({"input0", "input1"}, {"output0", "output1"}); + EXPECT_EQ(GetInputArrays(), std::vector<string>({"inputA", "inputB"})); + EXPECT_EQ(GetOutputArrays(), std::vector<string>({"output0", "output1"})); + EXPECT_EQ(GetInputShape("inputA"), "1,224,224,1"); + EXPECT_EQ(GetInputShape("inputB"), "9,12"); + EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); +} + +// Tests if the cmdline values of input_arrays, input_shapes are used when +// specified even if values exist within the GraphDef. +TEST_F(TocoSavedModelTest, InputArraysCmdLineWithGraphDef) { + parsed_model_flags_.input_arrays.bind()("inputA"); + parsed_model_flags_.input_shapes.bind()("1,224,224,1"); + tensorflow::GraphDef graph_def = GetFloatGraphDef({1, 3, 3, 1}); + + ProcessGraphDefMetadata({"inputA"}, {"add"}, graph_def); + EXPECT_EQ(GetInputArrays(), std::vector<string>({"inputA"})); + EXPECT_EQ(GetOutputArrays(), std::vector<string>({"add"})); + EXPECT_EQ(GetInputShape("inputA"), "1,224,224,1"); + EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); +} + +// Tests if the cmdline values of input_arrays, input_shapes, inference_type, +// and output_arrays are used when specified with an empty GraphDef. +TEST_F(TocoSavedModelTest, AllParamsCmdLine) { + parsed_model_flags_.input_arrays.bind()("inputA,inputB"); + parsed_model_flags_.output_arrays.bind()("outputA,outputB"); + parsed_model_flags_.input_shapes.bind()("1,224,224,1:9,12"); + parsed_toco_flags_.inference_type.bind()("FLOAT"); + + ProcessSavedModelMetadata({"input0", "input1"}, {"output0", "output1"}); + EXPECT_EQ(GetInputArrays(), std::vector<string>({"inputA", "inputB"})); + EXPECT_EQ(GetOutputArrays(), std::vector<string>({"outputA", "outputB"})); + EXPECT_EQ(GetInputShape("inputA"), "1,224,224,1"); + EXPECT_EQ(GetInputShape("inputB"), "9,12"); + EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); +} + +// Tests if a quantized graph gives the correct values assuming type is passed +// in via command line. +TEST_F(TocoSavedModelTest, QuantizedNoCmdLine) { + parsed_toco_flags_.inference_type.bind()("QUANTIZED_UINT8"); + tensorflow::GraphDef graph_def = GetQuantizedGraphDef(); + + ProcessGraphDefMetadata({"input"}, {"add"}, graph_def); + EXPECT_EQ(GetInputArrays(), std::vector<string>({"input"})); + EXPECT_EQ(GetOutputArrays(), std::vector<string>({"add"})); + EXPECT_EQ(GetInputShape("input"), "1,3,3,1"); + EXPECT_EQ(toco_flags_.inference_type(), IODataType::QUANTIZED_UINT8); +} + +// Tests if the provided batch size replaces a -1 in the first dimension of +// input shape. +TEST_F(TocoSavedModelTest, MissingShapeParameterValid) { + parsed_model_flags_.batch_size.bind()(3); + tensorflow::GraphDef graph_def = GetFloatGraphDef({-1, 3, 3, 1}); + + ProcessGraphDefMetadata({"input"}, {"add"}, graph_def); + EXPECT_EQ(GetInputArrays(), std::vector<string>({"input"})); + EXPECT_EQ(GetOutputArrays(), std::vector<string>({"add"})); + EXPECT_EQ(GetInputShape("input"), "3,3,3,1"); + EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); +} + +// Ensures a failure occurs if there is a -1 in a dimension aside from the first +// position of input shape. +TEST_F(TocoSavedModelTest, MissingShapeParameterInvalid) { + parsed_model_flags_.batch_size.bind()(3); + tensorflow::GraphDef graph_def = GetFloatGraphDef({1, -1, 3, 1}); + + EXPECT_DEATH(ProcessGraphDefMetadata({"input"}, {"add"}, graph_def), + "A valid input shape was not found for input 'input'."); +} + +} // namespace +} // namespace toco |