aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/args.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/args.h')
-rw-r--r--tensorflow/contrib/lite/toco/args.h225
1 files changed, 225 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h
new file mode 100644
index 0000000000..28661d4ff0
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/args.h
@@ -0,0 +1,225 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// This abstracts command line arguments in toco.
+// Arg<T> is a parseable type that can register a default value, be able to
+// parse itself, and keep track of whether it was specified.
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ARGS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ARGS_H_
+
+#include <functional>
+#include <unordered_map>
+#include <vector>
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_split.h"
+#include "tensorflow/contrib/lite/toco/toco_port.h"
+#include "tensorflow/contrib/lite/toco/toco_types.h"
+
+namespace toco {
+
+// Since std::vector<int32> is in the std namespace, and we are not allowed
+// to add ParseFlag/UnparseFlag to std, we introduce a simple wrapper type
+// to use as the flag type:
+struct IntList {
+ std::vector<int32> elements;
+};
+struct StringMapList {
+ std::vector<std::unordered_map<string, string>> elements;
+};
+
+// command_line_flags.h don't track whether or not a flag is specified. Arg
+// contains the value (which will be default if not specified) and also
+// whether the flag is specified.
+// TODO(aselle): consider putting doc string and ability to construct the
+// tensorflow argument into this, so declaration of parameters can be less
+// distributed.
+// Every template specialization of Arg is required to implement
+// default_value(), specified(), value(), parse(), bind().
+template <class T>
+class Arg final {
+ public:
+ explicit Arg(T default_ = T()) : value_(default_) {}
+ virtual ~Arg() {}
+
+ // Provide default_value() to arg list
+ T default_value() const { return value_; }
+ // Return true if the command line argument was specified on the command line.
+ bool specified() const { return specified_; }
+ // Const reference to parsed value.
+ const T& value() const { return value_; }
+
+ // Parsing callback for the tensorflow::Flags code
+ bool parse(T value_in) {
+ value_ = value_in;
+ specified_ = true;
+ return true;
+ }
+
+ // Bind the parse member function so tensorflow::Flags can call it.
+ std::function<bool(T)> bind() {
+ return std::bind(&Arg::parse, this, std::placeholders::_1);
+ }
+
+ private:
+ // Becomes true after parsing if the value was specified
+ bool specified_ = false;
+ // Value of the argument (initialized to the default in the constructor).
+ T value_;
+};
+
+template <>
+class Arg<toco::IntList> final {
+ public:
+ // Provide default_value() to arg list
+ string default_value() const { return ""; }
+ // Return true if the command line argument was specified on the command line.
+ bool specified() const { return specified_; }
+ // Bind the parse member function so tensorflow::Flags can call it.
+ bool parse(string text) {
+ parsed_value_.elements.clear();
+ specified_ = true;
+ // strings::Split("") produces {""}, but we need {} on empty input.
+ // TODO(aselle): Moved this from elsewhere, but ahentz recommends we could
+ // use absl::SplitLeadingDec32Values(text.c_str(), &parsed_values_.elements)
+ if (!text.empty()) {
+ int32 element;
+ for (absl::string_view part : absl::StrSplit(text, ',')) {
+ if (!SimpleAtoi(part, &element)) return false;
+ parsed_value_.elements.push_back(element);
+ }
+ }
+ return true;
+ }
+
+ std::function<bool(string)> bind() {
+ return std::bind(&Arg::parse, this, std::placeholders::_1);
+ }
+
+ const toco::IntList& value() const { return parsed_value_; }
+
+ private:
+ toco::IntList parsed_value_;
+ bool specified_ = false;
+};
+
+template <>
+class Arg<toco::StringMapList> final {
+ public:
+ // Provide default_value() to StringMapList
+ string default_value() const { return ""; }
+ // Return true if the command line argument was specified on the command line.
+ bool specified() const { return specified_; }
+ // Bind the parse member function so tensorflow::Flags can call it.
+
+ bool parse(string text) {
+ parsed_value_.elements.clear();
+ specified_ = true;
+
+ if (text.empty()) {
+ return true;
+ }
+
+#if defined(PLATFORM_GOOGLE)
+ std::vector<absl::string_view> outer_vector;
+ absl::string_view text_disposable_copy = text;
+ SplitStructuredLine(text_disposable_copy, ',', "{}", &outer_vector);
+ for (const absl::string_view& outer_member_stringpiece : outer_vector) {
+ string outer_member(outer_member_stringpiece);
+ if (outer_member.empty()) {
+ continue;
+ }
+ string outer_member_copy = outer_member;
+ absl::StripAsciiWhitespace(&outer_member);
+ if (!TryStripPrefixString(outer_member, "{", &outer_member)) return false;
+ if (!TryStripSuffixString(outer_member, "}", &outer_member)) return false;
+ const std::vector<string> inner_fields_vector =
+ strings::Split(outer_member, ',');
+
+ std::unordered_map<string, string> element;
+ for (const string& member_field : inner_fields_vector) {
+ std::vector<string> outer_member_key_value =
+ strings::Split(member_field, ':');
+ if (outer_member_key_value.size() != 2) return false;
+ string& key = outer_member_key_value[0];
+ string& value = outer_member_key_value[1];
+ absl::StripAsciiWhitespace(&key);
+ absl::StripAsciiWhitespace(&value);
+ if (element.count(key) != 0) return false;
+ element[key] = value;
+ }
+ parsed_value_.elements.push_back(element);
+ }
+ return true;
+#else
+ // TODO(aselle): Fix argument parsing when absl supports structuredline
+ fprintf(stderr, "%s:%d StringMapList arguments not supported\n", __FILE__,
+ __LINE__);
+ abort();
+#endif
+ }
+
+ std::function<bool(string)> bind() {
+ return std::bind(&Arg::parse, this, std::placeholders::_1);
+ }
+
+ const toco::StringMapList& value() const { return parsed_value_; }
+
+ private:
+ toco::StringMapList parsed_value_;
+ bool specified_ = false;
+};
+
+// Flags that describe a model. See model_cmdline_flags.cc for details.
+struct ParsedModelFlags {
+ Arg<string> input_array;
+ Arg<string> input_arrays;
+ Arg<string> output_array;
+ Arg<string> output_arrays;
+ Arg<string> input_shapes;
+ Arg<float> mean_value = Arg<float>(0.f);
+ Arg<string> mean_values;
+ Arg<float> std_value = Arg<float>(1.f);
+ Arg<string> std_values;
+ Arg<bool> variable_batch = Arg<bool>(false);
+ Arg<bool> drop_control_dependency = Arg<bool>(false);
+ Arg<toco::IntList> input_shape;
+ Arg<toco::StringMapList> rnn_states;
+ Arg<toco::StringMapList> model_checks;
+ // Debugging output options
+ Arg<string> graphviz_first_array;
+ Arg<string> graphviz_last_array;
+ Arg<string> dump_graphviz;
+ Arg<bool> dump_graphviz_video = Arg<bool>(false);
+};
+
+// Flags that describe the operation you would like to do (what conversion
+// you want). See toco_cmdline_flags.cc for details.
+struct ParsedTocoFlags {
+ Arg<string> input_file;
+ Arg<string> output_file;
+ Arg<string> input_format;
+ Arg<string> output_format;
+ // TODO(aselle): command_line_flags doesn't support doubles
+ Arg<float> default_ranges_min = Arg<float>(0.);
+ Arg<float> default_ranges_max = Arg<float>(0.);
+ Arg<string> input_type;
+ Arg<string> input_types;
+ Arg<string> inference_type;
+ Arg<bool> drop_fake_quant = Arg<bool>(false);
+ Arg<bool> reorder_across_fake_quant = Arg<bool>(false);
+ Arg<bool> allow_custom_ops = Arg<bool>(false);
+};
+
+} // namespace toco
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ARGS_H_