Diffstat (limited to 'tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc')
-rw-r--r--  tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc  370
1 file changed, 370 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc b/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc
new file mode 100644
index 0000000000..e6d5a776b3
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc
@@ -0,0 +1,370 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <ctype.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include <iostream>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "flatbuffers/minireflect.h"  // flatbuffers
+#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h"
+
+namespace tflite {
+namespace {
+// This list was generated by grepping the param struct names out of
+// builtin_op_data.h, roughly:
+//   cat third_party/tensorflow/contrib/lite/builtin_op_data.h \
+//     | grep "^} TfLite" | sed 's/^} \(TfLite.*Params\);/\1/'
+static const char* param_structs[] = {"TfLiteConvParams",
+ "TfLitePoolParams",
+ "TfLiteDepthwiseConvParams",
+ "TfLiteSVDFParams",
+ "TfLiteRNNParams",
+ "TfLiteSequenceRNNParams",
+ "TfLiteFullyConnectedParams",
+ "TfLiteLSHProjectionParams",
+ "TfLiteSoftmaxParams",
+ "TfLiteConcatenationParams",
+ "TfLiteAddParams",
+ "TfLiteSpaceToBatchNDParams",
+ "TfLiteBatchToSpaceNDParams",
+ "TfLiteMulParams",
+ "TfLiteSubParams",
+ "TfLiteDivParams",
+ "TfLiteL2NormParams",
+ "TfLiteLocalResponseNormParams",
+ "TfLiteLSTMParams",
+ "TfLiteResizeBilinearParams",
+ "TfLitePadParams",
+ "TfLitePadV2Params",
+ "TfLiteReshapeParams",
+ "TfLiteSkipGramParams",
+ "TfLiteSpaceToDepthParams",
+ "TfLiteCastParams",
+ "TfLiteEmbeddingLookupSparseParams",
+ "TfLiteGatherParams",
+ "TfLiteTransposeParams",
+ "TfLiteReducerParams",
+ "TfLiteSplitParams",
+ "TfLiteSqueezeParams",
+ "TfLiteStridedSliceParams",
+ "TfLiteArgMaxParams",
+ "TfLiteArgMinParams",
+ "TfLiteTransposeConvParams",
+ "TfLiteSparseToDenseParams",
+ "TfLiteShapeParams",
+ "TfLiteFakeQuantParams",
+ "TfLitePackParams",
+ "TfLiteOneHotParams",
+ nullptr};
+} // namespace
+
+// Get rid of all underscores and make everything lower case to make name
+// matching work for stuff like 3D vs 3d or RNN vs Rnn.
+std::string ToCollapsed(const std::string& in) {
+  std::string out;
+  for (char c : in) {
+    // Drop underscores; lower-case everything else.
+    if (c != '_') out.push_back(tolower(c));
+  }
+  return out;
+}
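+// e.g. ToCollapsed("SPACE_TO_BATCH_ND") == "spacetobatchnd"; appending
+// "options" to that matches ToCollapsed("SpaceToBatchNDOptions"), which is
+// how BuildOpToOptionMap() below pairs ops with their options.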
+
+// A collection of information about builtin ops.
+class OpOptionData {
+ public:
+ OpOptionData() {
+ BuildOpList();
+ BuildOptionToTypeFunctionMap();
+ BuildOpToOptionMap();
+ }
+
+ // A list of builtin operations
+ const std::vector<std::string>& ops() const { return ops_; }
+  // Maps from operation name to option name (e.g. 'ADD' -> 'AddOptions')
+ const std::unordered_map<std::string, std::string>& op_to_option() {
+ return op_to_option_;
+ }
+  // Maps from option name to the corresponding C struct
+  // (e.g. 'AddOptions' -> 'TfLiteAddParams').
+ const std::unordered_map<std::string, std::string>& option_to_struct() {
+ return option_to_struct_;
+ }
+ // Maps from option to a flatbuffer type function that describes that option.
+ const std::unordered_map<std::string, flatbuffers::TypeFunction>&
+ option_to_type_function() {
+ return option_to_type_function_;
+ }
+
+ private:
+ void BuildOpList() {
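+    // EnumNamesBuiltinOperator() is generated by flatbuffers and may contain
+    // empty strings for gaps in the enum; skip those.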
+ for (const char* const* curr = EnumNamesBuiltinOperator(); *curr != nullptr;
+ ++curr) {
+ if (strlen(*curr) != 0) ops_.push_back(*curr);
+ }
+ }
+
+ void BuildOptionToTypeFunctionMap() {
+ auto d = tflite::BuiltinOptionsTypeTable();
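+    // In flatbuffers minireflection, type_codes[i].sequence_ref indexes into
+    // type_refs when the i-th member refers to another type table, and is -1
+    // otherwise; record a TypeFunction for each table-valued option.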
+ for (int i = 0; i < d->num_elems; i++) {
+ flatbuffers::TypeCode code = d->type_codes[i];
+ if (code.sequence_ref != -1) {
+ option_to_type_function_.insert(
+ std::make_pair(d->names[i], d->type_refs[code.sequence_ref]));
+ }
+ }
+ }
+
+ void BuildOpToOptionMap() {
+    // Manually specified mappings between ops and options
+    op_to_option_["REDUCE_MAX"] = "ReducerOptions";
+    op_to_option_["REDUCE_MIN"] = "ReducerOptions";
+    op_to_option_["REDUCE_ANY"] = "ReducerOptions";
+    op_to_option_["REDUCE_PROD"] = "ReducerOptions";
+    op_to_option_["SUM"] = "ReducerOptions";
+    op_to_option_["MEAN"] = "ReducerOptions";
+    op_to_option_["L2_POOL_2D"] = "Pool2DOptions";
+    op_to_option_["AVERAGE_POOL_2D"] = "Pool2DOptions";
+    op_to_option_["MAX_POOL_2D"] = "Pool2DOptions";
+    op_to_option_["L2_NORMALIZATION"] = "L2NormOptions";
+    op_to_option_["BIDIRECTIONAL_SEQUENCE_LSTM"] = "LSTMOptions";
+    op_to_option_["UNIDIRECTIONAL_SEQUENCE_LSTM"] = "LSTMOptions";
+    op_to_option_["BIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions";
+    op_to_option_["UNIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions";
+    // Ops that take no options
+    op_to_option_["UNPACK"] = "";
+    op_to_option_["EMBEDDING_LOOKUP"] =
+        "";  // TODO(aselle): maybe something else.
+ op_to_option_["FLOOR"] = "";
+ op_to_option_["HASHTABLE_LOOKUP"] =
+ ""; // TODO(aselle): maybe something else.
+ op_to_option_["LOGISTIC"] = "";
+ op_to_option_["RELU"] = "";
+ op_to_option_["RELU_N1_TO_1"] = "";
+ op_to_option_["RELU6"] = "";
+ op_to_option_["TANH"] = "";
+ op_to_option_["CUSTOM"] = ""; // TODO(aselle): maybe something else.
+ op_to_option_["DELEGATE"] = ""; // TODO(aselle): maybe something else.
+ op_to_option_["PRELU"] = "";
+ op_to_option_["MAXIMUM"] = ""; // TODO(aselle): MaximumMinimumOptions
+ op_to_option_["MINIMUM"] = ""; // TODO(aselle): MaximumMinimumOptions
+ op_to_option_["SIN"] = "";
+ op_to_option_["LOG"] = "";
+ op_to_option_["SQRT"] = "";
+ op_to_option_["RSQRT"] = "";
+
+ // TODO(aselle): These are undesirable hacks. Consider changing C structs
+ option_to_struct_["Pool2DOptions"] = "TfLitePoolParams";
+ option_to_struct_["Conv2DOptions"] = "TfLiteConvParams";
+ option_to_struct_["DepthwiseConv2DOptions"] = "TfLiteDepthwiseConvParams";
+ option_to_struct_["LocalResponseNormalizationOptions"] =
+ "TfLiteLocalResponseNormParams";
+ // Now for every op, try to find an option.
+ bool fatal = false;
+ for (auto op_name : ops_) {
+ bool found_option = false;
+ auto d = tflite::BuiltinOptionsTypeTable();
+ std::string collapsed_option_name_guess =
+ ToCollapsed(op_name) + "options";
+      // O(n^2), but n is small here.
+ for (int i = 0; i < d->num_elems; i++) {
+ std::string option_name = d->names[i];
+ std::string collapsed_option_name = ToCollapsed(option_name);
+ if (collapsed_option_name_guess == collapsed_option_name) {
+ op_to_option_.insert(std::make_pair(op_name, option_name));
+ found_option = true;
+ break;
+ }
+ }
+ auto it = op_to_option_.find(op_name);
+ if (it == op_to_option_.end()) {
+ std::cerr << "Didn't find option for " << op_name << std::endl;
+ fatal = true;
+ } else if (!it->second.empty()) {
+ std::string option_name = it->second;
+
+ if (option_to_struct_.find(option_name) == option_to_struct_.end()) {
+ bool param_struct_found = false;
+ std::string params_guess = std::string("TfLite") + option_name;
+ size_t start = params_guess.find("Options");
+ size_t len = strlen("Options");
+ params_guess.replace(start, len, "Params");
+ for (auto* param = param_structs; *param != nullptr; param++) {
+ if (*param == params_guess) {
+ param_struct_found = true;
+ break;
+ }
+ }
+ if (!param_struct_found) {
+ std::cerr << "Failed to get param struct for option " << option_name
+ << std::endl;
+ fatal = true;
+ } else {
+ option_to_struct_.insert(std::make_pair(option_name, params_guess));
+ }
+ }
+ }
+    }
+    // Abort generation if any op could not be resolved; the emitted code
+    // would otherwise silently miss cases.
+    if (fatal) exit(1);
+  }
+
+ private:
+ std::vector<std::string> ops_;
+ std::unordered_map<std::string, std::string> op_to_option_;
+ std::unordered_map<std::string, std::string> option_to_struct_;
+ std::unordered_map<std::string, flatbuffers::TypeFunction>
+ option_to_type_function_;
+};
+
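+// Emits a single 'case BuiltinOperator_<op_name>' that reinterprets
+// builtin_op_data as 'struct_name', converts each field listed in the
+// option's minireflection table, and returns the assembled flatbuffer union.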
+void GenerateImportForOp(FILE* fp, const std::string& op_name,
+ const std::string& option_name,
+ const std::string& option_type,
+ const flatbuffers::TypeTable* options,
+ const std::string& struct_name) {
+ // Skip tricky ones for now
+ if (struct_name == "TfLiteResizeBilinearParams") return;
+ if (struct_name == "TfLiteSqueezeParams") return;
+ if (struct_name == "TfLiteEmbeddingLookupSparseParams") return;
+ if (struct_name == "TfLiteReshapeParams") return;
+
+ fprintf(fp, " case BuiltinOperator_%s: {\n", op_name.c_str());
+ fprintf(fp,
+ " const auto* params = reinterpret_cast<const "
+ "%s*>(builtin_op_data);\n",
+ struct_name.c_str());
+
+ for (size_t i = 0; i < options->num_elems; i++) {
+ std::string elem_name = options->names[i];
+ // TODO(aselle): Irregular naming in builtins
+ if (elem_name == "fused_activation_function")
+ elem_name = "activation";
+ else if (elem_name == "stride_w")
+ elem_name = "stride_width";
+ else if (elem_name == "stride_h")
+ elem_name = "stride_height";
+ else if (elem_name == "dilation_h_factor")
+ elem_name = "dilation_height_factor";
+ else if (elem_name == "dilation_w_factor")
+ elem_name = "dilation_width_factor";
+ else if (elem_name == "new_shape")
+ elem_name = "shape";
+
+ flatbuffers::TypeCode code = options->type_codes[i];
+ auto contained_type = code.sequence_ref != -1
+ ? options->type_refs[code.sequence_ref]
+ : nullptr;
+ std::string mapper = "";
+ if (contained_type == TensorTypeTypeTable) {
+ mapper = "TfLiteTypeToSchemaType";
+ } else if (contained_type == ActivationFunctionTypeTypeTable) {
+ mapper = "TfLiteActivationToSchemaActivation";
+ } else if (contained_type == PaddingTypeTable) {
+ mapper = "TfLitePaddingToSchemaPadding";
+ } else if (contained_type == FullyConnectedOptionsWeightsFormatTypeTable) {
+ mapper = "FullyConnectedOptionsWeightsFormatToSchema";
+ } else if (contained_type == LSTMKernelTypeTypeTable) {
+ mapper = "LSTMKernelTypeToSchema";
+ } else if (contained_type == LSHProjectionTypeTypeTable) {
+ mapper = "LSHProjectionTypeToSchema";
+ }
+
+ fprintf(fp,
+ " auto val%zu = "
+ "%s(params->%s);\n",
+ i, mapper.c_str(), elem_name.c_str());
+ }
+ fprintf(fp, " auto union_type = Create%s(*fbb", option_name.c_str());
+ for (size_t i = 0; i < options->num_elems; i++) {
+ fprintf(fp, ", val%zu", i);
+ }
+ fprintf(fp, ").Union();\n");
+ fprintf(fp, " return std::make_pair(%s, union_type);\n",
+ option_type.c_str());
+ fprintf(fp, " }\n break;\n");
+}
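+// As a rough illustration (assuming the 2018-era schema), the case emitted
+// for ADD looks approximately like:
+//   case BuiltinOperator_ADD: {
+//     const auto* params =
+//         reinterpret_cast<const TfLiteAddParams*>(builtin_op_data);
+//     auto val0 = TfLiteActivationToSchemaActivation(params->activation);
+//     auto union_type = CreateAddOptions(*fbb, val0).Union();
+//     return std::make_pair(BuiltinOptions_ADD, union_type);
+//   }
+//   break;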
+
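+// Writes the body of a switch over BuiltinOperator: ops with no options fall
+// through to a shared BuiltinOptions_NONE case, and every other op gets a
+// case generated from its option's minireflection data.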
+void GenerateImport(OpOptionData* option, FILE* fp) {
+ std::unordered_set<std::string> ignores;
+ ignores.insert("CONCAT_EMBEDDINGS");
+ ignores.insert("CALL");
+
+  // Group together (fall through to a single case) every op that has no
+  // options struct, plus the explicitly ignored ops.
+ for (const auto& op_name : option->ops()) {
+ auto option_it = option->op_to_option().find(op_name);
+ if (!option_it->second.empty() && ignores.find(op_name) == ignores.end())
+ continue;
+ fprintf(fp, " case BuiltinOperator_%s:\n", op_name.c_str());
+ }
+ fprintf(fp,
+ " return std::make_pair(BuiltinOptions_NONE, "
+ "flatbuffers::Offset<void>());\n break;\n");
+
+  // Now emit a full case for each remaining op.
+ for (const auto& op_name : option->ops()) {
+ if (ignores.find(op_name) != ignores.end()) continue;
+    // Look up the option and struct names, continuing if not found.
+ auto option_it = option->op_to_option().find(op_name);
+ if (option_it->second.empty()) continue;
+ std::string option_name = option_it->second;
+ std::string option_type = "BuiltinOptions_" + option_name;
+ auto option_func_it = option->option_to_type_function().find(option_name);
+ if (option_func_it == option->option_to_type_function().end()) continue;
+ auto struct_name_it = option->option_to_struct().find(option_name);
+ if (struct_name_it == option->option_to_struct().end()) {
+ // If no C struct, then it better have no arguments.
+ auto type_info = option_func_it->second();
+ if (type_info->num_elems != 0) {
+ // We have non-zero arguments in the schema, this means there
+ // should be a struct.
+ fprintf(stderr,
+ "Op %s uses option struct %s which has no builtin struct\n",
+ op_name.c_str(), option_name.c_str());
+ exit(1);
+ }
+ fprintf(fp, " case BuiltinOperator_%s:\n", op_name.c_str());
+ fprintf(fp, " return std::make_pair(%s, Create%s(*fbb).Union());",
+ option_type.c_str(), option_name.c_str());
+ } else {
+ // If C struct, then we need to assign all properties
+ auto struct_name = struct_name_it->second;
+ GenerateImportForOp(fp, op_name, option_name, option_type,
+ option_func_it->second(), struct_name);
+ }
+ }
+ // TODO(aselle): Handle unhandled cases more gracefully.
+ fprintf(fp,
+ "default: return std::make_pair(BuiltinOptions_NONE, "
+ "flatbuffers::Offset<void>());\n break;\n");
+}
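+// Note: the emitted fragment is case labels plus a default, so it is
+// presumably meant to be #included inside a switch over the builtin
+// operator, with 'fbb' and 'builtin_op_data' in scope at the include site.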
+
+} // namespace tflite
+
+int main(int argc, char* argv[]) {
+ tflite::OpOptionData option;
+ if (argc != 2) {
+ fprintf(stderr, "Usage: %s <fname out>\n", argv[0]);
+ return 1;
+ }
+  FILE* fp = fopen(argv[1], "w");
+  if (!fp) {
+    fprintf(stderr, "Could not open '%s' for writing\n", argv[1]);
+    return 1;
+  }
+  tflite::GenerateImport(&option, fp);
+  fclose(fp);
+  return 0;
+}