diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/model_cmdline_flags.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/model_cmdline_flags.cc | 374 |
1 files changed, 374 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc new file mode 100644 index 0000000000..699c95753f --- /dev/null +++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc @@ -0,0 +1,374 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/model_cmdline_flags.h" + +#include <string> +#include <vector> + +#include "absl/strings/ascii.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/strings/strip.h" +#include "tensorflow/contrib/lite/toco/args.h" +#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h" +#include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/command_line_flags.h" +// "batch" flag only exists internally +#ifdef PLATFORM_GOOGLE +#include "base/commandlineflags.h" +#endif + +namespace toco { + +bool ParseModelFlagsFromCommandLineFlags( + int* argc, char* argv[], string* msg, + ParsedModelFlags* parsed_model_flags_ptr) { + ParsedModelFlags& parsed_flags = *parsed_model_flags_ptr; + using tensorflow::Flag; + std::vector<tensorflow::Flag> flags = { + Flag("input_array", parsed_flags.input_array.bind(), + parsed_flags.input_array.default_value(), + "Name of the input array. If not specified, will try to read " + "that information from the input file."), + Flag("input_arrays", parsed_flags.input_arrays.bind(), + parsed_flags.input_arrays.default_value(), + "Names of the output arrays, comma-separated. If not specified, " + "will try to read that information from the input file."), + Flag("output_array", parsed_flags.output_array.bind(), + parsed_flags.output_array.default_value(), + "Name of the output array, when specifying a unique output array. " + "If not specified, will try to read that information from the " + "input file."), + Flag("output_arrays", parsed_flags.output_arrays.bind(), + parsed_flags.output_arrays.default_value(), + "Names of the output arrays, comma-separated. " + "If not specified, will try to read " + "that information from the input file."), + Flag("input_shape", parsed_flags.input_shape.bind(), + parsed_flags.output_arrays.default_value(), + "Input array shape. For many models the shape takes the form " + "batch size, input array height, input array width, input array " + "depth."), + Flag("input_shapes", parsed_flags.input_shapes.bind(), + parsed_flags.input_shapes.default_value(), + "Shapes corresponding to --input_arrays, colon-separated. For " + "many models each shape takes the form batch size, input array " + "height, input array width, input array depth."), + Flag("mean_value", parsed_flags.mean_value.bind(), + parsed_flags.mean_value.default_value(), + "mean_value parameter for image models, used to compute input " + "activations from input pixel data."), + Flag("mean_values", parsed_flags.mean_values.bind(), + parsed_flags.mean_values.default_value(), + "mean_values parameter for image models, comma-separated list of " + "doubles, used to compute input activations from input pixel " + "data. Each entry in the list should match an entry in " + "--input_arrays."), + Flag("std_value", parsed_flags.std_value.bind(), + parsed_flags.std_value.default_value(), + "std_value parameter for image models, used to compute input " + "activations from input pixel data."), + Flag("std_values", parsed_flags.std_values.bind(), + parsed_flags.std_values.default_value(), + "std_value parameter for image models, comma-separated list of " + "doubles, used to compute input activations from input pixel " + "data. Each entry in the list should match an entry in " + "--input_arrays."), + Flag("variable_batch", parsed_flags.variable_batch.bind(), + parsed_flags.variable_batch.default_value(), + "If true, the model accepts an arbitrary batch size. Mutually " + "exclusive " + "with the 'batch' field: at most one of these two fields can be " + "set."), + Flag( + "drop_control_dependency", + parsed_flags.drop_control_dependency.bind(), + parsed_flags.drop_control_dependency.default_value(), + "If true, ignore control dependency requirements in input TensorFlow " + "GraphDef. Otherwise an error will be raised upon control dependency " + "inputs."), + Flag("rnn_states", parsed_flags.rnn_states.bind(), + parsed_flags.rnn_states.default_value(), ""), + Flag("model_checks", parsed_flags.model_checks.bind(), + parsed_flags.model_checks.default_value(), + "A list of model checks to be applied to verify the form of the " + "model. Applied after the graph transformations after import."), + Flag("graphviz_first_array", parsed_flags.graphviz_first_array.bind(), + parsed_flags.graphviz_first_array.default_value(), + "If set, defines the start of the sub-graph to be dumped to " + "GraphViz."), + Flag( + "graphviz_last_array", parsed_flags.graphviz_last_array.bind(), + parsed_flags.graphviz_last_array.default_value(), + "If set, defines the end of the sub-graph to be dumped to GraphViz."), + Flag("dump_graphviz", parsed_flags.dump_graphviz.bind(), + parsed_flags.dump_graphviz.default_value(), + "Dump graphviz during LogDump call. If string is non-empty then " + "it defines path to dump, otherwise will skip dumping."), + Flag("dump_graphviz_video", parsed_flags.dump_graphviz_video.bind(), + parsed_flags.dump_graphviz_video.default_value(), + "If true, will dump graphviz at each " + "graph transformation, which may be used to generate a video."), + }; + bool asked_for_help = + *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help")); + if (asked_for_help) { + *msg += tensorflow::Flags::Usage(argv[0], flags); + return false; + } else { + if (!tensorflow::Flags::Parse(argc, argv, flags)) return false; + } + auto& dump_options = *GraphVizDumpOptions::singleton(); + dump_options.graphviz_first_array = parsed_flags.graphviz_first_array.value(); + dump_options.graphviz_last_array = parsed_flags.graphviz_last_array.value(); + dump_options.dump_graphviz_video = parsed_flags.dump_graphviz_video.value(); + dump_options.dump_graphviz = parsed_flags.dump_graphviz.value(); + + return true; +} + +void ReadModelFlagsFromCommandLineFlags( + const ParsedModelFlags& parsed_model_flags, ModelFlags* model_flags) { + toco::port::CheckInitGoogleIsDone("InitGoogle is not done yet"); + +// "batch" flag only exists internally +#ifdef PLATFORM_GOOGLE + CHECK(!((base::SpecifiedOnCommandLine("batch") && + parsed_model_flags.variable_batch.specified()))) + << "The --batch and --variable_batch flags are mutually exclusive."; +#endif + CHECK(!(parsed_model_flags.output_array.specified() && + parsed_model_flags.output_arrays.specified())) + << "The --output_array and --vs flags are mutually exclusive."; + + if (parsed_model_flags.output_array.specified()) { + model_flags->add_output_arrays(parsed_model_flags.output_array.value()); + } + + if (parsed_model_flags.output_arrays.specified()) { + std::vector<string> output_arrays = + absl::StrSplit(parsed_model_flags.output_arrays.value(), ','); + for (const string& output_array : output_arrays) { + model_flags->add_output_arrays(output_array); + } + } + + const bool uses_single_input_flags = + parsed_model_flags.input_array.specified() || + parsed_model_flags.mean_value.specified() || + parsed_model_flags.std_value.specified() || + parsed_model_flags.input_shape.specified(); + + const bool uses_multi_input_flags = + parsed_model_flags.input_arrays.specified() || + parsed_model_flags.mean_values.specified() || + parsed_model_flags.std_values.specified() || + parsed_model_flags.input_shapes.specified(); + + QCHECK(!(uses_single_input_flags && uses_multi_input_flags)) + << "Use either the singular-form input flags (--input_array, " + "--input_shape, --mean_value, --std_value) or the plural form input " + "flags (--input_arrays, --input_shapes, --mean_values, --std_values), " + "but not both forms within the same command line."; + + if (parsed_model_flags.input_array.specified()) { + QCHECK(uses_single_input_flags); + model_flags->add_input_arrays()->set_name( + parsed_model_flags.input_array.value()); + } + if (parsed_model_flags.input_arrays.specified()) { + QCHECK(uses_multi_input_flags); + for (const auto& input_array : + absl::StrSplit(parsed_model_flags.input_arrays.value(), ',')) { + model_flags->add_input_arrays()->set_name(string(input_array)); + } + } + if (parsed_model_flags.mean_value.specified()) { + QCHECK(uses_single_input_flags); + model_flags->mutable_input_arrays(0)->set_mean_value( + parsed_model_flags.mean_value.value()); + } + if (parsed_model_flags.mean_values.specified()) { + QCHECK(uses_multi_input_flags); + std::vector<string> mean_values = + absl::StrSplit(parsed_model_flags.mean_values.value(), ','); + QCHECK(mean_values.size() == model_flags->input_arrays_size()); + for (int i = 0; i < mean_values.size(); ++i) { + char* last = nullptr; + model_flags->mutable_input_arrays(i)->set_mean_value( + strtod(mean_values[i].data(), &last)); + CHECK(last != mean_values[i].data()); + } + } + if (parsed_model_flags.std_value.specified()) { + QCHECK(uses_single_input_flags); + model_flags->mutable_input_arrays(0)->set_std_value( + parsed_model_flags.std_value.value()); + } + if (parsed_model_flags.std_values.specified()) { + QCHECK(uses_multi_input_flags); + std::vector<string> std_values = + absl::StrSplit(parsed_model_flags.std_values.value(), ','); + QCHECK(std_values.size() == model_flags->input_arrays_size()); + for (int i = 0; i < std_values.size(); ++i) { + char* last = nullptr; + model_flags->mutable_input_arrays(i)->set_std_value( + strtod(std_values[i].data(), &last)); + CHECK(last != std_values[i].data()); + } + } + if (parsed_model_flags.input_shape.specified()) { + QCHECK(uses_single_input_flags); + if (model_flags->input_arrays().empty()) { + model_flags->add_input_arrays(); + } + auto* shape = model_flags->mutable_input_arrays(0)->mutable_shape(); + shape->Clear(); + const IntList& list = parsed_model_flags.input_shape.value(); + for (auto& dim : list.elements) { + shape->Add(dim); + } + } + if (parsed_model_flags.input_shapes.specified()) { + QCHECK(uses_multi_input_flags); + std::vector<string> input_shapes = + absl::StrSplit(parsed_model_flags.input_shapes.value(), ':'); + QCHECK(input_shapes.size() == model_flags->input_arrays_size()); + for (int i = 0; i < input_shapes.size(); ++i) { + auto* shape = model_flags->mutable_input_arrays(i)->mutable_shape(); + shape->Clear(); + if (input_shapes[i].empty()) { + // empty i.e. 0-dimensional input shape. + // Unfortunately, the current toco::InputArray + // proto does not allow to distinguish between a known 0-D shape, + // and an unknown shape. Indeed, shape is currently a plain array, + // and it being empty means unknown shape. So here, we import a + // 0-D shape as a 1-D shape of size. + // TODO(benoitjacob): fix toco::InputArray to allow 0-D shape, + // probably by making shape an optional message, + // encapsulating the array. + shape->Add(1); + } else { + for (const auto& dim_str : absl::StrSplit(input_shapes[i], ',')) { + int size; + CHECK(absl::SimpleAtoi(dim_str, &size)) + << "Failed to parse input_shape: " << input_shapes[i]; + shape->Add(size); + } + } + } + } + +#define READ_MODEL_FLAG(name) \ + do { \ + if (parsed_model_flags.name.specified()) { \ + model_flags->set_##name(parsed_model_flags.name.value()); \ + } \ + } while (false) + + READ_MODEL_FLAG(variable_batch); + READ_MODEL_FLAG(drop_control_dependency); + +#undef READ_MODEL_FLAG + + for (const auto& element : parsed_model_flags.rnn_states.value().elements) { + auto* rnn_state_proto = model_flags->add_rnn_states(); + for (const auto& kv_pair : element) { + const string& key = kv_pair.first; + const string& value = kv_pair.second; + if (key == "state_array") { + rnn_state_proto->set_state_array(value); + } else if (key == "back_edge_source_array") { + rnn_state_proto->set_back_edge_source_array(value); + } else if (key == "size") { + int32 size = 0; + CHECK(absl::SimpleAtoi(value, &size)); + CHECK_GT(size, 0); + rnn_state_proto->set_size(size); + } else if (key == "manually_create") { + CHECK_EQ(absl::AsciiStrToLower(value), "true"); + rnn_state_proto->set_manually_create(true); + } else { + LOG(FATAL) << "Unknown key '" << key << "' in --rnn_states"; + } + } + CHECK(rnn_state_proto->has_state_array() && + rnn_state_proto->has_back_edge_source_array() && + rnn_state_proto->has_size()) + << "--rnn_states must include state_array, back_edge_source_array and " + "size."; + } + + for (const auto& element : parsed_model_flags.model_checks.value().elements) { + auto* model_check_proto = model_flags->add_model_checks(); + for (const auto& kv_pair : element) { + const string& key = kv_pair.first; + const string& value = kv_pair.second; + if (key == "count_type") { + model_check_proto->set_count_type(value); + } else if (key == "count_min") { + int32 count = 0; + CHECK(absl::SimpleAtoi(value, &count)); + CHECK_GE(count, -1); + model_check_proto->set_count_min(count); + } else if (key == "count_max") { + int32 count = 0; + CHECK(absl::SimpleAtoi(value, &count)); + CHECK_GE(count, -1); + model_check_proto->set_count_max(count); + } else { + LOG(FATAL) << "Unknown key '" << key << "' in --model_checks"; + } + } + } +} + +ParsedModelFlags* UncheckedGlobalParsedModelFlags(bool must_already_exist) { + static auto* flags = [must_already_exist]() { + if (must_already_exist) { + fprintf(stderr, __FILE__ + ":" + "GlobalParsedModelFlags() used without initialization\n"); + fflush(stderr); + abort(); + } + return new toco::ParsedModelFlags; + }(); + return flags; +} + +ParsedModelFlags* GlobalParsedModelFlags() { + return UncheckedGlobalParsedModelFlags(true); +} + +void ParseModelFlagsOrDie(int* argc, char* argv[]) { + // TODO(aselle): in the future allow Google version to use + // flags, and only use this mechanism for open source + auto* flags = UncheckedGlobalParsedModelFlags(false); + string msg; + bool model_success = + toco::ParseModelFlagsFromCommandLineFlags(argc, argv, &msg, flags); + if (!model_success || !msg.empty()) { + // Log in non-standard way since this happens pre InitGoogle. + fprintf(stderr, "%s", msg.c_str()); + fflush(stderr); + abort(); + } +} + +} // namespace toco |