aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/lite/toco/BUILD37
-rw-r--r--tensorflow/contrib/lite/toco/args.h7
-rw-r--r--tensorflow/contrib/lite/toco/model_cmdline_flags.cc6
-rw-r--r--tensorflow/contrib/lite/toco/toco.cc97
-rw-r--r--tensorflow/contrib/lite/toco/toco_cmdline_flags.cc98
-rw-r--r--tensorflow/contrib/lite/toco/toco_saved_model.cc186
-rw-r--r--tensorflow/contrib/lite/toco/toco_saved_model.h53
-rw-r--r--tensorflow/contrib/lite/toco/toco_saved_model_test.cc274
8 files changed, 690 insertions, 68 deletions
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index 486ff1edcd..102740ee47 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -124,6 +124,7 @@ cc_library(
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -168,6 +169,41 @@ cc_library(
)
cc_library(
+ name = "toco_saved_model",
+ srcs = [
+ "toco_saved_model.cc",
+ ],
+ hdrs = [
+ "toco_saved_model.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":model_cmdline_flags",
+ ":model_flags_proto_cc",
+ ":toco_flags_proto_cc",
+ ":types_proto_cc",
+ "//tensorflow/cc/tools:freeze_saved_model",
+ "//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+tf_cc_test(
+ name = "toco_saved_model_test",
+ srcs = ["toco_saved_model_test.cc"],
+ deps = [
+ ":model_cmdline_flags",
+ ":toco_cmdline_flags",
+ ":toco_saved_model",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:scope",
+ "//tensorflow/core:test",
+ "@com_google_absl//absl/strings",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
name = "graph_transformations",
srcs = [
"graph_transformations/convert_expanddims_to_reshape.cc",
@@ -363,6 +399,7 @@ tf_cc_binary(
":toco_cmdline_flags",
":toco_flags_proto_cc",
":toco_port",
+ ":toco_saved_model",
":toco_tooling",
":types_proto_cc",
"//tensorflow/core:lib",
diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h
index 59a6115920..7b71792ff7 100644
--- a/tensorflow/contrib/lite/toco/args.h
+++ b/tensorflow/contrib/lite/toco/args.h
@@ -190,6 +190,7 @@ struct ParsedModelFlags {
Arg<string> output_array;
Arg<string> output_arrays;
Arg<string> input_shapes;
+ Arg<int> batch_size = Arg<int>(1);
Arg<float> mean_value = Arg<float>(0.f);
Arg<string> mean_values;
Arg<float> std_value = Arg<float>(1.f);
@@ -215,9 +216,11 @@ struct ParsedModelFlags {
// you want). See toco_cmdline_flags.cc for details.
struct ParsedTocoFlags {
Arg<string> input_file;
+ Arg<string> savedmodel_directory;
Arg<string> output_file;
- Arg<string> input_format;
- Arg<string> output_format;
+ Arg<string> input_format = Arg<string>("TENSORFLOW_GRAPHDEF");
+ Arg<string> output_format = Arg<string>("TFLITE");
+ Arg<string> savedmodel_tagset;
// TODO(aselle): command_line_flags doesn't support doubles
Arg<float> default_ranges_min = Arg<float>(0.);
Arg<float> default_ranges_max = Arg<float>(0.);
diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
index 4e2dec15a5..4264f21c76 100644
--- a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
+++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
@@ -72,6 +72,12 @@ bool ParseModelFlagsFromCommandLineFlags(
"Shapes corresponding to --input_arrays, colon-separated. For "
"many models each shape takes the form batch size, input array "
"height, input array width, input array depth."),
+ Flag("batch_size", parsed_flags.batch_size.bind(),
+ parsed_flags.batch_size.default_value(),
+ "Batch size for the model. Replaces the first dimension of an "
+ "input size array if undefined. Use only with SavedModels when "
+ "--input_shapes flag is not specified. Always use --input_shapes "
+ "flag with frozen graphs."),
Flag("input_data_type", parsed_flags.input_data_type.bind(),
parsed_flags.input_data_type.default_value(),
"Deprecated: use --input_data_types instead. Input array type, if "
diff --git a/tensorflow/contrib/lite/toco/toco.cc b/tensorflow/contrib/lite/toco/toco.cc
index f01ec0ec61..8041aa9e7f 100644
--- a/tensorflow/contrib/lite/toco/toco.cc
+++ b/tensorflow/contrib/lite/toco/toco.cc
@@ -23,40 +23,70 @@ limitations under the License.
#include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h"
#include "tensorflow/contrib/lite/toco/toco_flags.pb.h"
#include "tensorflow/contrib/lite/toco/toco_port.h"
+#include "tensorflow/contrib/lite/toco/toco_saved_model.h"
#include "tensorflow/contrib/lite/toco/toco_tooling.h"
#include "tensorflow/contrib/lite/toco/toco_types.h"
#include "tensorflow/core/platform/logging.h"
-#ifndef CHECK_OK
-#define CHECK_OK(val) CHECK_EQ((val).ok(), true)
-#define QCHECK_OK(val) QCHECK_EQ((val).ok(), true)
-#endif
-
namespace toco {
namespace {
-#define QCHECK_REQUIRE_TOCO_FLAG(arg) \
- QCHECK(parsed_toco_flags.arg.specified()) << "Missing required flag: " #arg;
-
-void CheckFilePermissions(const ParsedTocoFlags& parsed_toco_flags,
- const ParsedModelFlags& parsed_model_flags,
- const TocoFlags& toco_flags) {
- port::CheckInitGoogleIsDone("InitGoogle is not done yet");
-
- QCHECK_REQUIRE_TOCO_FLAG(input_file)
- QCHECK_OK(port::file::Exists(parsed_toco_flags.input_file.value(),
- port::file::Defaults()))
- << "Specified input_file does not exist: "
- << parsed_toco_flags.input_file.value();
- QCHECK_OK(port::file::Readable(parsed_toco_flags.input_file.value(),
- port::file::Defaults()))
+// Checks the permissions of the output file to ensure it is writable.
+void CheckOutputFilePermissions(const Arg<string>& output_file) {
+ QCHECK(output_file.specified()) << "Missing required flag --output_file.\n";
+ QCHECK(port::file::Writable(output_file.value()).ok())
+ << "Specified output_file is not writable: " << output_file.value()
+ << ".\n";
+}
+
+// Checks the permissions of the frozen model file.
+void CheckFrozenModelPermissions(const Arg<string>& input_file) {
+ QCHECK(input_file.specified()) << "Missing required flag --input_file.\n";
+ QCHECK(port::file::Exists(input_file.value(), port::file::Defaults()).ok())
+ << "Specified input_file does not exist: " << input_file.value() << ".\n";
+ QCHECK(port::file::Readable(input_file.value(), port::file::Defaults()).ok())
<< "Specified input_file exists, but is not readable: "
- << parsed_toco_flags.input_file.value();
+ << input_file.value() << ".\n";
+}
- QCHECK_REQUIRE_TOCO_FLAG(output_file);
- QCHECK_OK(port::file::Writable(parsed_toco_flags.output_file.value()))
- << "parsed_toco_flags.input_file.value() output_file is not writable: "
- << parsed_toco_flags.output_file.value();
+// Checks the permissions of the SavedModel directory.
+void CheckSavedModelPermissions(const Arg<string>& savedmodel_directory) {
+ QCHECK(savedmodel_directory.specified())
+ << "Missing required flag --savedmodel_directory.\n";
+ QCHECK(
+ port::file::Exists(savedmodel_directory.value(), port::file::Defaults())
+ .ok())
+ << "Specified savedmodel_directory does not exist: "
+ << savedmodel_directory.value() << ".\n";
+}
+
+// Reads the contents of the GraphDef from either the frozen graph file or the
+// SavedModel directory. If it reads the SavedModel directory, it updates the
+// ModelFlags and TocoFlags accordingly.
+void ReadInputData(const ParsedTocoFlags& parsed_toco_flags,
+ const ParsedModelFlags& parsed_model_flags,
+ TocoFlags* toco_flags, ModelFlags* model_flags,
+ string* graph_def_contents) {
+ port::CheckInitGoogleIsDone("InitGoogle is not done yet.\n");
+
+ bool has_input_file = parsed_toco_flags.input_file.specified();
+ bool has_savedmodel_dir = parsed_toco_flags.savedmodel_directory.specified();
+
+ // Ensure either input_file or savedmodel_directory flag has been set.
+ QCHECK_NE(has_input_file, has_savedmodel_dir)
+ << "Specify either input_file or savedmodel_directory flag.\n";
+
+ // Checks the input file permissions and reads the contents.
+ if (has_input_file) {
+ CheckFrozenModelPermissions(parsed_toco_flags.input_file);
+ CHECK(port::file::GetContents(parsed_toco_flags.input_file.value(),
+ graph_def_contents, port::file::Defaults())
+ .ok());
+ } else {
+ CheckSavedModelPermissions(parsed_toco_flags.savedmodel_directory);
+ GetSavedModelContents(parsed_toco_flags, parsed_model_flags, toco_flags,
+ model_flags, graph_def_contents);
+ }
}
void ToolMain(const ParsedTocoFlags& parsed_toco_flags,
@@ -67,21 +97,20 @@ void ToolMain(const ParsedTocoFlags& parsed_toco_flags,
TocoFlags toco_flags;
ReadTocoFlagsFromCommandLineFlags(parsed_toco_flags, &toco_flags);
- CheckFilePermissions(parsed_toco_flags, parsed_model_flags, toco_flags);
+ string graph_def_contents;
+ ReadInputData(parsed_toco_flags, parsed_model_flags, &toco_flags,
+ &model_flags, &graph_def_contents);
+ CheckOutputFilePermissions(parsed_toco_flags.output_file);
- string input_file_contents;
- CHECK_OK(port::file::GetContents(parsed_toco_flags.input_file.value(),
- &input_file_contents,
- port::file::Defaults()));
std::unique_ptr<Model> model =
- Import(toco_flags, model_flags, input_file_contents);
+ Import(toco_flags, model_flags, graph_def_contents);
Transform(toco_flags, model.get());
string output_file_contents;
Export(toco_flags, *model, toco_flags.allow_custom_ops(),
&output_file_contents);
- CHECK_OK(port::file::SetContents(parsed_toco_flags.output_file.value(),
- output_file_contents,
- port::file::Defaults()));
+ CHECK(port::file::SetContents(parsed_toco_flags.output_file.value(),
+ output_file_contents, port::file::Defaults())
+ .ok());
}
} // namespace
diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
index 0f67c2de72..cc7803dd86 100644
--- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
+++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/strip.h"
+#include "absl/types/optional.h"
#include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h"
#include "tensorflow/contrib/lite/toco/toco_port.h"
#include "tensorflow/core/platform/logging.h"
@@ -38,6 +39,9 @@ bool ParseTocoFlagsFromCommandLineFlags(
"Input file (model of any supported format). For Protobuf "
"formats, both text and binary are supported regardless of file "
"extension."),
+ Flag("savedmodel_directory", parsed_flags.savedmodel_directory.bind(),
+ parsed_flags.savedmodel_directory.default_value(),
+ "Full path to the directory containing the SavedModel."),
Flag("output_file", parsed_flags.output_file.bind(),
parsed_flags.output_file.default_value(),
"Output file. "
@@ -49,6 +53,11 @@ bool ParseTocoFlagsFromCommandLineFlags(
parsed_flags.output_format.default_value(),
"Output file format. "
"One of TENSORFLOW_GRAPHDEF, TFLITE, GRAPHVIZ_DOT."),
+ Flag("savedmodel_tagset", parsed_flags.savedmodel_tagset.bind(),
+ parsed_flags.savedmodel_tagset.default_value(),
+ "Comma-separated set of tags identifying the MetaGraphDef within "
+ "the SavedModel to analyze. All tags in the tag set must be "
+ "specified."),
Flag("default_ranges_min", parsed_flags.default_ranges_min.bind(),
parsed_flags.default_ranges_min.default_value(),
"If defined, will be used as the default value for the min bound "
@@ -128,47 +137,72 @@ bool ParseTocoFlagsFromCommandLineFlags(
}
}
+namespace {
+
+// Defines the requirements for a given flag. kUseDefault means the default
+// should be used in cases where the value isn't specified by the user.
+enum class FlagRequirement {
+ kNone,
+ kMustBeSpecified,
+ kMustNotBeSpecified,
+ kUseDefault,
+};
+
+// Enforces the FlagRequirements are met for a given flag.
+template <typename T>
+void EnforceFlagRequirement(const T& flag, const string& flag_name,
+ FlagRequirement requirement) {
+ if (requirement == FlagRequirement::kMustBeSpecified) {
+ QCHECK(flag.specified()) << "Missing required flag " << flag_name;
+ }
+ if (requirement == FlagRequirement::kMustNotBeSpecified) {
+ QCHECK(!flag.specified())
+ << "Given other flags, this flag should not have been specified: "
+ << flag_name;
+ }
+}
+
+// Gets the value from the flag if specified. Returns default if the
+// FlagRequirement is kUseDefault.
+template <typename T>
+absl::optional<T> GetFlagValue(const Arg<T>& flag,
+ FlagRequirement requirement) {
+ if (flag.specified()) return flag.value();
+ if (requirement == FlagRequirement::kUseDefault) return flag.default_value();
+ return absl::optional<T>();
+}
+
+} // namespace
+
void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
TocoFlags* toco_flags) {
namespace port = toco::port;
port::CheckInitGoogleIsDone("InitGoogle is not done yet");
- enum class FlagRequirement { kNone, kMustBeSpecified, kMustNotBeSpecified };
-
-#define ENFORCE_FLAG_REQUIREMENT(name, requirement) \
- do { \
- if (requirement == FlagRequirement::kMustBeSpecified) { \
- QCHECK(parsed_toco_flags.name.specified()) \
- << "Missing required flag: " << #name; \
- } \
- if (requirement == FlagRequirement::kMustNotBeSpecified) { \
- QCHECK(!parsed_toco_flags.name.specified()) \
- << "Given other flags, this flag should not have been specified: " \
- << #name; \
- } \
- } while (false)
-#define READ_TOCO_FLAG(name, requirement) \
- ENFORCE_FLAG_REQUIREMENT(name, requirement); \
- do { \
- if (parsed_toco_flags.name.specified()) { \
- toco_flags->set_##name(parsed_toco_flags.name.value()); \
- } \
+#define READ_TOCO_FLAG(name, requirement) \
+ do { \
+ EnforceFlagRequirement(parsed_toco_flags.name, #name, requirement); \
+ auto flag_value = GetFlagValue(parsed_toco_flags.name, requirement); \
+ if (flag_value.has_value()) { \
+ toco_flags->set_##name(flag_value.value()); \
+ } \
} while (false)
-#define PARSE_TOCO_FLAG(Type, name, requirement) \
- ENFORCE_FLAG_REQUIREMENT(name, requirement); \
- do { \
- if (parsed_toco_flags.name.specified()) { \
- Type x; \
- QCHECK(Type##_Parse(parsed_toco_flags.name.value(), &x)) \
- << "Unrecognized " << #Type << " value " \
- << parsed_toco_flags.name.value(); \
- toco_flags->set_##name(x); \
- } \
+#define PARSE_TOCO_FLAG(Type, name, requirement) \
+ do { \
+ EnforceFlagRequirement(parsed_toco_flags.name, #name, requirement); \
+ auto flag_value = GetFlagValue(parsed_toco_flags.name, requirement); \
+ if (flag_value.has_value()) { \
+ Type x; \
+ QCHECK(Type##_Parse(flag_value.value(), &x)) \
+ << "Unrecognized " << #Type << " value " \
+ << parsed_toco_flags.name.value(); \
+ toco_flags->set_##name(x); \
+ } \
} while (false)
- PARSE_TOCO_FLAG(FileFormat, input_format, FlagRequirement::kMustBeSpecified);
- PARSE_TOCO_FLAG(FileFormat, output_format, FlagRequirement::kMustBeSpecified);
+ PARSE_TOCO_FLAG(FileFormat, input_format, FlagRequirement::kUseDefault);
+ PARSE_TOCO_FLAG(FileFormat, output_format, FlagRequirement::kUseDefault);
PARSE_TOCO_FLAG(IODataType, inference_type, FlagRequirement::kNone);
PARSE_TOCO_FLAG(IODataType, inference_input_type, FlagRequirement::kNone);
READ_TOCO_FLAG(default_ranges_min, FlagRequirement::kNone);
diff --git a/tensorflow/contrib/lite/toco/toco_saved_model.cc b/tensorflow/contrib/lite/toco/toco_saved_model.cc
new file mode 100644
index 0000000000..91a742b9e0
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/toco_saved_model.cc
@@ -0,0 +1,186 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+#include <vector>
+
+#include "absl/strings/numbers.h"
+#include "tensorflow/contrib/lite/toco/model_cmdline_flags.h"
+#include "tensorflow/contrib/lite/toco/toco_saved_model.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+
+namespace toco {
+namespace {
+
+// Loads a SavedModel from the directory specified in parsed_toco_flags.
+// Returns a SavedModelBundle with the requested MetaGraphDef.
+const tensorflow::SavedModelBundle* LoadSavedModel(
+ const ParsedTocoFlags& parsed_toco_flags) {
+ const string model_path = parsed_toco_flags.savedmodel_directory.value();
+ QCHECK(tensorflow::MaybeSavedModelDirectory(model_path))
+ << "Model is not saved in the supported SavedModel format.\n";
+
+ // Gets the tags identifying the MetaGraphDef from the command line arguments.
+ QCHECK(parsed_toco_flags.savedmodel_tagset.specified())
+ << "Missing required flag --savedmodel_tagset.\n";
+ const string tags_str = parsed_toco_flags.savedmodel_tagset.value();
+ auto tags = absl::StrSplit(tags_str, ',');
+
+ // Loads MetaGraphDef.
+ auto* bundle = new tensorflow::SavedModelBundle;
+ TF_CHECK_OK(tensorflow::LoadSavedModel(tensorflow::SessionOptions(),
+ tensorflow::RunOptions(), model_path,
+ tags, bundle))
+ << "Failed to load exported model from " << model_path
+ << ". Ensure the model contains the required tags '" << tags_str
+ << "'.\n";
+ return bundle;
+}
+
+// Returns the array name without the postfix.
+//
+// e.g. reduces "input:0" to "input".
+string GetArrayName(const string& name) {
+ const std::vector<string>& names = absl::StrSplit(name, ':');
+ return names[0];
+}
+
+// Returns the list of array names without the postfix sorted alphabetically.
+std::set<string> GetSortedNames(const std::unordered_set<string>& names) {
+ std::vector<string> final_names;
+ final_names.reserve(names.size());
+ for (const auto& name : names) {
+ final_names.push_back(GetArrayName(name));
+ }
+ return std::set<string>(final_names.begin(), final_names.end());
+}
+
+// Gets the final shape after replacing the first dimension with batch size, if
+// it is undefined (containing the value -1). Returns whether the shape is
+// valid.
+bool ReplaceShapeBatchSize(const tensorflow::TensorShapeProto& shape,
+ int batch_size,
+ tensorflow::TensorShapeProto* final_shape) {
+ for (int idx = 0; idx < shape.dim().size(); ++idx) {
+ int64 final_dim = shape.dim()[idx].size();
+ if (final_dim == -1) {
+ if (idx > 0) return false;
+ final_dim = batch_size;
+ }
+ final_shape->add_dim()->set_size(final_dim);
+ }
+ return true;
+}
+
+// Updates the input arrays in ModelFlags to contain the shape of the array.
+void ProcessInputShapes(const tensorflow::GraphDef& graph_def, int batch_size,
+ ModelFlags* model_flags) {
+ // Build map of input array names to input arrays.
+ std::unordered_map<string, InputArray*> input_data_map;
+ for (auto& input : *model_flags->mutable_input_arrays()) {
+ input_data_map[input.name()] = &input;
+ }
+
+ // Adds shapes to the input arrays if the shape is valid.
+ for (const tensorflow::NodeDef& node_def : graph_def.node()) {
+ if (input_data_map.find(node_def.name()) != input_data_map.end()) {
+ const auto shape_it = node_def.attr().find("shape");
+ if (shape_it != node_def.attr().end()) {
+ tensorflow::TensorShapeProto final_shape;
+ bool is_valid = ReplaceShapeBatchSize(shape_it->second.shape(),
+ batch_size, &final_shape);
+
+ if (is_valid) {
+ auto* shape = input_data_map.at(node_def.name())->mutable_shape();
+ QCHECK_EQ(shape->dims_size(), 0)
+ << "The shape for the input '" << node_def.name()
+ << "' was previously defined. For clarity please define inputs "
+ << "via --input_arrays and input_shapes flags.\n";
+ for (const auto& dim : final_shape.dim()) {
+ shape->add_dims(dim.size());
+ }
+ }
+ }
+ }
+ }
+
+ // Checks all input arrays have a shape.
+ for (auto const& input : model_flags->input_arrays()) {
+ QCHECK(input.shape().dims_size() > 0)
+ << "A valid input shape was not found for input '" << input.name()
+ << "'. Please define via --input_arrays and --input_shapes flags.\n";
+ }
+}
+
+} // namespace
+
+void ParseMetaData(const tensorflow::GraphDef& graph_def,
+ const std::unordered_set<string>& inputs,
+ const std::unordered_set<string>& outputs,
+ const ParsedTocoFlags& parsed_toco_flags,
+ const ParsedModelFlags& parsed_model_flags,
+ TocoFlags* toco_flags, ModelFlags* model_flags) {
+ if (!parsed_model_flags.input_arrays.specified()) {
+ const std::set<string> sorted_inputs = GetSortedNames(inputs);
+ for (const auto& input_name : sorted_inputs) {
+ model_flags->add_input_arrays()->set_name(input_name);
+ }
+ }
+
+ if (!parsed_model_flags.output_arrays.specified()) {
+ const std::set<string> sorted_outputs = GetSortedNames(outputs);
+ for (const auto& output_name : sorted_outputs) {
+ model_flags->add_output_arrays(GetArrayName(output_name));
+ }
+ }
+
+ if (!parsed_model_flags.input_shapes.specified()) {
+ int batch_size = parsed_model_flags.batch_size.value();
+ ProcessInputShapes(graph_def, batch_size, model_flags);
+ }
+
+ if (!parsed_toco_flags.inference_type.specified()) {
+ toco_flags->set_inference_type(IODataType::FLOAT);
+ }
+}
+
+// TODO(nupurgarg): Add top level tests.
+void GetSavedModelContents(const ParsedTocoFlags& parsed_toco_flags,
+ const ParsedModelFlags& parsed_model_flags,
+ TocoFlags* toco_flags, ModelFlags* model_flags,
+ string* graph_def_contents) {
+ // Loads the MetaGraphDef within a SavedModelBundle.
+ auto bundle = LoadSavedModel(parsed_toco_flags);
+
+ // Converts the MetaGraphDef to frozen GraphDef.
+ tensorflow::GraphDef frozen_graph_def;
+ std::unordered_set<string> inputs;
+ std::unordered_set<string> outputs;
+ TF_CHECK_OK(tensorflow::FreezeSavedModel(*bundle, &frozen_graph_def, &inputs,
+ &outputs));
+
+ // Reads the frozen GraphDef into a string.
+ QCHECK(frozen_graph_def.SerializeToString(graph_def_contents))
+ << "Unable to generate serialized GraphDef.\n";
+
+ // Process inputs and outputs and metadata within GraphDef.
+ const tensorflow::GraphDef graph_def = bundle->meta_graph_def.graph_def();
+ ParseMetaData(graph_def, inputs, outputs, parsed_toco_flags,
+ parsed_model_flags, toco_flags, model_flags);
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/toco_saved_model.h b/tensorflow/contrib/lite/toco/toco_saved_model.h
new file mode 100644
index 0000000000..7a0fabd82d
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/toco_saved_model.h
@@ -0,0 +1,53 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_SAVED_MODEL_H_
+#define TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_SAVED_MODEL_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/cc/tools/freeze_saved_model.h"
+#include "tensorflow/contrib/lite/toco/args.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/toco_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/types.pb.h"
+
+namespace toco {
+
+// Parses metadata into `toco_flags` and `model_flags`.
+//
+// Stores `inputs` as input_arrays and `outputs` as output_arrays in
+// `model_flags`. Infers input_shapes from the GraphDef and stores it in
+// `model_flags` as part of the input_arrays. Assumes inference_type is FLOAT
+// and stores it in `toco_flags`.
+void ParseMetaData(const tensorflow::GraphDef& graph_def,
+ const std::unordered_set<string>& inputs,
+ const std::unordered_set<string>& outputs,
+ const ParsedTocoFlags& parsed_toco_flags,
+ const ParsedModelFlags& parsed_model_flags,
+ TocoFlags* toco_flags, ModelFlags* model_flags);
+
+// Generates a frozen graph from the SavedModel in the directory specified in
+// `toco_flags`. Reads frozen graph contents into `graph_def_contents`. Parses
+// metadata relating to the GraphDef into `toco_flags` and `model_flags`.
+void GetSavedModelContents(const ParsedTocoFlags& parsed_toco_flags,
+ const ParsedModelFlags& parsed_model_flags,
+ TocoFlags* toco_flags, ModelFlags* model_flags,
+ string* graph_def_contents);
+
+} // namespace toco
+
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_SAVED_MODEL_H_
diff --git a/tensorflow/contrib/lite/toco/toco_saved_model_test.cc b/tensorflow/contrib/lite/toco/toco_saved_model_test.cc
new file mode 100644
index 0000000000..5e122afe65
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/toco_saved_model_test.cc
@@ -0,0 +1,274 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/toco/toco_saved_model.h"
+#include "absl/strings/str_join.h"
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/contrib/lite/toco/model_cmdline_flags.h"
+#include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace toco {
+namespace {
+
+using tensorflow::ops::Add;
+using tensorflow::ops::Const;
+using tensorflow::ops::FakeQuantWithMinMaxArgs;
+using tensorflow::ops::Placeholder;
+
+class TocoSavedModelTest : public ::testing::Test {
+ protected:
+ // Calls functions to process cmdline arguments and calls ParseMetaData.
+ // ParseMetaData parses input_arrays, output_arrays, and gets metadata from
+ // the SavedModel if it is not defined in the cmdline arguments.
+ void ProcessGraphDefMetadata(const std::unordered_set<string>& inputs,
+ const std::unordered_set<string>& outputs,
+ const tensorflow::GraphDef& graph_def) {
+ ReadTocoFlagsFromCommandLineFlags(parsed_toco_flags_, &toco_flags_);
+ ReadModelFlagsFromCommandLineFlags(parsed_model_flags_, &model_flags_);
+ ParseMetaData(graph_def, inputs, outputs, parsed_toco_flags_,
+ parsed_model_flags_, &toco_flags_, &model_flags_);
+ }
+
+ // Gets the GraphDef from the SavedModelBundle and processes metadata.
+ void ProcessSavedModelMetadata(const std::unordered_set<string>& inputs,
+ const std::unordered_set<string>& outputs) {
+ const tensorflow::GraphDef graph_def = bundle_.meta_graph_def.graph_def();
+ ProcessGraphDefMetadata(inputs, outputs, graph_def);
+ }
+
+ // Returns a GraphDef representing a simple float model with a single input.
+ tensorflow::GraphDef GetFloatGraphDef(const std::vector<int64>& shape) {
+ tensorflow::GraphDef graph_def;
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+ tensorflow::Output input =
+ Placeholder(scope.WithOpName("input"), tensorflow::DT_FLOAT,
+ Placeholder::Shape(tensorflow::PartialTensorShape(shape)));
+ tensorflow::Output zero = Const(scope.WithOpName("zero"), 0.0f, {});
+ tensorflow::Output add = Add(scope.WithOpName("add"), input, zero);
+
+ TF_EXPECT_OK(scope.ToGraphDef(&graph_def));
+ return graph_def;
+ }
+
+ // Returns a GraphDef representing a simple float model with two inputs.
+ tensorflow::GraphDef GetComplexFloatGraphDef() {
+ tensorflow::GraphDef graph_def;
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+ tensorflow::Output inputA =
+ Placeholder(scope.WithOpName("inputA"), tensorflow::DT_FLOAT,
+ Placeholder::Shape(tensorflow::TensorShape({1, 3, 3, 1})));
+ tensorflow::Output inputB =
+ Placeholder(scope.WithOpName("inputB"), tensorflow::DT_FLOAT,
+ Placeholder::Shape(tensorflow::TensorShape({1, 3, 3, 1})));
+ tensorflow::Output add = Add(scope.WithOpName("add"), inputB, inputA);
+
+ TF_EXPECT_OK(scope.ToGraphDef(&graph_def));
+ return graph_def;
+ }
+
+ // Returns a GraphDef representing a simple quantized model.
+ tensorflow::GraphDef GetQuantizedGraphDef() {
+ tensorflow::GraphDef graph_def;
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+ tensorflow::Output input =
+ Placeholder(scope.WithOpName("input"), tensorflow::DT_FLOAT,
+ Placeholder::Shape(tensorflow::TensorShape({1, 3, 3, 1})));
+ tensorflow::Output zero = Const(scope.WithOpName("zero"), 0.0f, {});
+ tensorflow::Output fake_quant =
+ FakeQuantWithMinMaxArgs(scope.WithOpName("quant"), zero);
+ tensorflow::Output add = Add(scope.WithOpName("add"), input, fake_quant);
+
+ TF_EXPECT_OK(scope.ToGraphDef(&graph_def));
+ return graph_def;
+ }
+
+ // Gets the values in the input_arrays flag.
+ std::vector<string> GetInputArrays() {
+ std::vector<string> actual;
+ for (const auto& input : model_flags_.input_arrays()) {
+ actual.push_back(input.name());
+ }
+ return actual;
+ }
+
+ // Gets the values in the output_arrays flag.
+ std::vector<string> GetOutputArrays() {
+ std::vector<string> actual(model_flags_.output_arrays().begin(),
+ model_flags_.output_arrays().end());
+ return actual;
+ }
+
+ // Gets the shape of the given input array.
+ string GetInputShape(const string& input_array) {
+ for (const auto& input : model_flags_.input_arrays()) {
+ if (input.name() == input_array) {
+ std::vector<string> dims;
+ for (int idx = 0; idx < input.shape().dims_size(); ++idx) {
+ dims.push_back(std::to_string(input.shape().dims(idx)));
+ }
+ return absl::StrJoin(dims, ",");
+ }
+ }
+ return "";
+ }
+
+ tensorflow::SavedModelBundle bundle_;
+ ParsedTocoFlags parsed_toco_flags_;
+ ParsedModelFlags parsed_model_flags_;
+ TocoFlags toco_flags_;
+ ModelFlags model_flags_;
+};
+
+// Tests if input_arrays, output_arrays, inference_type, and input_shapes are
+// added to ModelFlags if they are not specified in cmdline arguments.
+// Tests if the default batch size replaces a -1 in the first dimension.
+TEST_F(TocoSavedModelTest, NoCmdLine) {
+ tensorflow::GraphDef graph_def = GetFloatGraphDef({-1, 3, 3, 1});
+
+ ProcessGraphDefMetadata({"input"}, {"add"}, graph_def);
+ EXPECT_EQ(GetInputArrays(), std::vector<string>({"input"}));
+ EXPECT_EQ(GetOutputArrays(), std::vector<string>({"add"}));
+ EXPECT_EQ(GetInputShape("input"), "1,3,3,1");
+ EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT);
+}
+
+// Tests if the order of input_arrays and output_arrays is deterministic when
+// they are taken from the SavedModel.
+TEST_F(TocoSavedModelTest, NoCmdLineMultipleArrays) {
+ tensorflow::GraphDef graph_def = GetComplexFloatGraphDef();
+
+ // Note: The model does not have two outputs. However, the function does not
+ // need an accurate output_array list. This is only meant to test order.
+ ProcessGraphDefMetadata({"inputB", "inputA"}, {"add", "invalid"}, graph_def);
+ EXPECT_EQ(GetInputArrays(), std::vector<string>({"inputA", "inputB"}));
+ EXPECT_EQ(GetOutputArrays(), std::vector<string>({"add", "invalid"}));
+ EXPECT_EQ(GetInputShape("inputA"), "1,3,3,1");
+ EXPECT_EQ(GetInputShape("inputB"), "1,3,3,1");
+ EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT);
+}
+
+// Tests if input_shapes is inferred when input_arrays is passed in via cmdline
+// arguments.
+TEST_F(TocoSavedModelTest, InputNameWithoutInputShape) {
+ parsed_model_flags_.input_arrays.bind()("input");
+ tensorflow::GraphDef graph_def = GetFloatGraphDef({2, 3, 3, 1});
+
+ ProcessGraphDefMetadata({"not_used_input"}, {"add"}, graph_def);
+ EXPECT_EQ(GetInputArrays(), std::vector<string>({"input"}));
+ EXPECT_EQ(GetOutputArrays(), std::vector<string>({"add"}));
+ EXPECT_EQ(GetInputShape("input"), "2,3,3,1");
+ EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT);
+}
+
+// Ensures a failure occurs when input_shapes is defined without input_arrays.
+TEST_F(TocoSavedModelTest, InputShapeWithoutInputName) {
+ parsed_model_flags_.input_shapes.bind()("1,224,224,1:9,12");
+ tensorflow::GraphDef graph_def = GetFloatGraphDef({1, 3, 3, 1});
+
+ EXPECT_DEATH(ProcessGraphDefMetadata({"input"}, {"add"}, graph_def),
+ "failed: input_shapes.size\\(\\) == "
+ "model_flags->input_arrays_size\\(\\)");
+}
+
+// Tests if the cmdline values of input_arrays, input_shapes are used when
+// specified with an empty GraphDef.
+TEST_F(TocoSavedModelTest, InputArraysCmdLine) {
+ parsed_model_flags_.input_arrays.bind()("inputA,inputB");
+ parsed_model_flags_.input_shapes.bind()("1,224,224,1:9,12");
+
+ ProcessSavedModelMetadata({"input0", "input1"}, {"output0", "output1"});
+ EXPECT_EQ(GetInputArrays(), std::vector<string>({"inputA", "inputB"}));
+ EXPECT_EQ(GetOutputArrays(), std::vector<string>({"output0", "output1"}));
+ EXPECT_EQ(GetInputShape("inputA"), "1,224,224,1");
+ EXPECT_EQ(GetInputShape("inputB"), "9,12");
+ EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT);
+}
+
+// Tests if the cmdline values of input_arrays, input_shapes are used when
+// specified even if values exist within the GraphDef.
+TEST_F(TocoSavedModelTest, InputArraysCmdLineWithGraphDef) {
+ parsed_model_flags_.input_arrays.bind()("inputA");
+ parsed_model_flags_.input_shapes.bind()("1,224,224,1");
+ tensorflow::GraphDef graph_def = GetFloatGraphDef({1, 3, 3, 1});
+
+ ProcessGraphDefMetadata({"inputA"}, {"add"}, graph_def);
+ EXPECT_EQ(GetInputArrays(), std::vector<string>({"inputA"}));
+ EXPECT_EQ(GetOutputArrays(), std::vector<string>({"add"}));
+ EXPECT_EQ(GetInputShape("inputA"), "1,224,224,1");
+ EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT);
+}
+
+// Tests if the cmdline values of input_arrays, input_shapes, inference_type,
+// and output_arrays are used when specified with an empty GraphDef.
+TEST_F(TocoSavedModelTest, AllParamsCmdLine) {
+ parsed_model_flags_.input_arrays.bind()("inputA,inputB");
+ parsed_model_flags_.output_arrays.bind()("outputA,outputB");
+ parsed_model_flags_.input_shapes.bind()("1,224,224,1:9,12");
+ parsed_toco_flags_.inference_type.bind()("FLOAT");
+
+ ProcessSavedModelMetadata({"input0", "input1"}, {"output0", "output1"});
+ EXPECT_EQ(GetInputArrays(), std::vector<string>({"inputA", "inputB"}));
+ EXPECT_EQ(GetOutputArrays(), std::vector<string>({"outputA", "outputB"}));
+ EXPECT_EQ(GetInputShape("inputA"), "1,224,224,1");
+ EXPECT_EQ(GetInputShape("inputB"), "9,12");
+ EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT);
+}
+
+// Tests if a quantized graph gives the correct values assuming type is passed
+// in via command line.
+TEST_F(TocoSavedModelTest, QuantizedNoCmdLine) {
+ parsed_toco_flags_.inference_type.bind()("QUANTIZED_UINT8");
+ tensorflow::GraphDef graph_def = GetQuantizedGraphDef();
+
+ ProcessGraphDefMetadata({"input"}, {"add"}, graph_def);
+ EXPECT_EQ(GetInputArrays(), std::vector<string>({"input"}));
+ EXPECT_EQ(GetOutputArrays(), std::vector<string>({"add"}));
+ EXPECT_EQ(GetInputShape("input"), "1,3,3,1");
+ EXPECT_EQ(toco_flags_.inference_type(), IODataType::QUANTIZED_UINT8);
+}
+
+// Tests if the provided batch size replaces a -1 in the first dimension of
+// input shape.
+TEST_F(TocoSavedModelTest, MissingShapeParameterValid) {
+ parsed_model_flags_.batch_size.bind()(3);
+ tensorflow::GraphDef graph_def = GetFloatGraphDef({-1, 3, 3, 1});
+
+ ProcessGraphDefMetadata({"input"}, {"add"}, graph_def);
+ EXPECT_EQ(GetInputArrays(), std::vector<string>({"input"}));
+ EXPECT_EQ(GetOutputArrays(), std::vector<string>({"add"}));
+ EXPECT_EQ(GetInputShape("input"), "3,3,3,1");
+ EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT);
+}
+
+// Ensures a failure occurs if there is a -1 in a dimension aside from the first
+// position of input shape.
+TEST_F(TocoSavedModelTest, MissingShapeParameterInvalid) {
+ parsed_model_flags_.batch_size.bind()(3);
+ tensorflow::GraphDef graph_def = GetFloatGraphDef({1, -1, 3, 1});
+
+ EXPECT_DEATH(ProcessGraphDefMetadata({"input"}, {"add"}, graph_def),
+ "A valid input shape was not found for input 'input'.");
+}
+
+} // namespace
+} // namespace toco