diff options
Diffstat (limited to 'tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc')
-rw-r--r-- | tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc | 370 |
1 files changed, 0 insertions, 370 deletions
diff --git a/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc b/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc deleted file mode 100644 index e6d5a776b3..0000000000 --- a/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc +++ /dev/null @@ -1,370 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include <ctype.h> -#include <iostream> -#include <unordered_map> -#include <unordered_set> -#include "flatbuffers/minireflect.h" // flatbuffers -#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h" - -namespace tflite { -namespace { -// This is generated by grepping -// cat third_party/tensorflow/contrib/lite/builtin_op_data.h -//| grep "^} TfLite" | sed 's/^} TfLite\(.*\)Params;/\1Params/g' | grep -v "^}" -static const char* param_structs[] = {"TfLiteConvParams", - "TfLitePoolParams", - "TfLiteDepthwiseConvParams", - "TfLiteSVDFParams", - "TfLiteRNNParams", - "TfLiteSequenceRNNParams", - "TfLiteFullyConnectedParams", - "TfLiteLSHProjectionParams", - "TfLiteSoftmaxParams", - "TfLiteConcatenationParams", - "TfLiteAddParams", - "TfLiteSpaceToBatchNDParams", - "TfLiteBatchToSpaceNDParams", - "TfLiteMulParams", - "TfLiteSubParams", - "TfLiteDivParams", - "TfLiteL2NormParams", - "TfLiteLocalResponseNormParams", - "TfLiteLSTMParams", - "TfLiteResizeBilinearParams", - "TfLitePadParams", - "TfLitePadV2Params", - "TfLiteReshapeParams", - "TfLiteSkipGramParams", - "TfLiteSpaceToDepthParams", - "TfLiteCastParams", - "TfLiteEmbeddingLookupSparseParams", - "TfLiteGatherParams", - "TfLiteTransposeParams", - "TfLiteReducerParams", - "TfLiteSplitParams", - "TfLiteSqueezeParams", - "TfLiteStridedSliceParams", - "TfLiteArgMaxParams", - "TfLiteArgMinParams", - "TfLiteTransposeConvParams", - "TfLiteSparseToDenseParams", - "TfLiteShapeParams", - "TfLiteFakeQuantParams", - "TfLitePackParams", - "TfLiteOneHotParams", - nullptr}; -} // namespace - -// Get rid of all underscores and make everything lower case to make name -// matching work for stuff like 3D vs 3d or RNN vs Rnn. -std::string ToCollapsed(const std::string& in) { - const char* s = in.c_str(); - bool first = true; - std::string out; - while (*s != '\0') { - if (*s == '_') { - first = true; - } else if (first) { - out.push_back(tolower(*s)); - first = false; - } else { - out.push_back(tolower(*s)); - } - s++; - } - return out; -} - -// A collection of information about builtin ops. -class OpOptionData { - public: - OpOptionData() { - BuildOpList(); - BuildOptionToTypeFunctionMap(); - BuildOpToOptionMap(); - } - - // A list of builtin operations - const std::vector<std::string>& ops() const { return ops_; } - // Maps from operation name to option name (i.e. 'ADD' to 'AddOptions') - const std::unordered_map<std::string, std::string>& op_to_option() { - return op_to_option_; - } - // Maps from option to to C struct i.e. 'AddOptions' -> 'TfLiteAddOptions' - const std::unordered_map<std::string, std::string>& option_to_struct() { - return option_to_struct_; - } - // Maps from option to a flatbuffer type function that describes that option. - const std::unordered_map<std::string, flatbuffers::TypeFunction>& - option_to_type_function() { - return option_to_type_function_; - } - - private: - void BuildOpList() { - for (const char* const* curr = EnumNamesBuiltinOperator(); *curr != nullptr; - ++curr) { - if (strlen(*curr) != 0) ops_.push_back(*curr); - } - } - - void BuildOptionToTypeFunctionMap() { - auto d = tflite::BuiltinOptionsTypeTable(); - for (int i = 0; i < d->num_elems; i++) { - flatbuffers::TypeCode code = d->type_codes[i]; - if (code.sequence_ref != -1) { - option_to_type_function_.insert( - std::make_pair(d->names[i], d->type_refs[code.sequence_ref])); - } - } - } - - void BuildOpToOptionMap() { - // Manually specified mappings between ops and options - op_to_option_["REDUCE_MAX"] = "ReducerOptions"; - op_to_option_["REDUCE_MIN"] = "ReducerOptions"; - op_to_option_["REDUCE_ANY"] = "ReducerOptions"; - op_to_option_["UNPACK"] = ""; - op_to_option_["SUM"] = "ReducerOptions"; - op_to_option_["REDUCE_MAX"] = "ReducerOptions"; - op_to_option_["REDUCE_PROD"] = "ReducerOptions"; - op_to_option_["MEAN"] = "ReducerOptions"; - op_to_option_["L2_POOL_2D"] = "Pool2DOptions"; - op_to_option_["AVERAGE_POOL_2D"] = "Pool2DOptions"; - op_to_option_["MAX_POOL_2D"] = "Pool2DOptions"; - op_to_option_["L2_NORMALIZATION"] = "L2NormOptions"; - op_to_option_["BIDIRECTIONAL_SEQUENCE_LSTM"] = "LSTMOptions"; - op_to_option_["UNIDIRECTIONAL_SEQUENCE_LSTM"] = "LSTMOptions"; - op_to_option_["BIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions"; - op_to_option_["UNIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions"; - op_to_option_["UNIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions"; - // Manually specified mappings between ops and options (none) - op_to_option_["EMBEDDING_LOOKUP"] = - ""; // TODO(aselle): maybe something else. - op_to_option_["FLOOR"] = ""; - op_to_option_["HASHTABLE_LOOKUP"] = - ""; // TODO(aselle): maybe something else. - op_to_option_["LOGISTIC"] = ""; - op_to_option_["RELU"] = ""; - op_to_option_["RELU_N1_TO_1"] = ""; - op_to_option_["RELU6"] = ""; - op_to_option_["TANH"] = ""; - op_to_option_["CUSTOM"] = ""; // TODO(aselle): maybe something else. - op_to_option_["DELEGATE"] = ""; // TODO(aselle): maybe something else. - op_to_option_["PRELU"] = ""; - op_to_option_["MAXIMUM"] = ""; // TODO(aselle): MaximumMinimumOptions - op_to_option_["MINIMUM"] = ""; // TODO(aselle): MaximumMinimumOptions - op_to_option_["SIN"] = ""; - op_to_option_["LOG"] = ""; - op_to_option_["SQRT"] = ""; - op_to_option_["RSQRT"] = ""; - - // TODO(aselle): These are undesirable hacks. Consider changing C structs - option_to_struct_["Pool2DOptions"] = "TfLitePoolParams"; - option_to_struct_["Conv2DOptions"] = "TfLiteConvParams"; - option_to_struct_["DepthwiseConv2DOptions"] = "TfLiteDepthwiseConvParams"; - option_to_struct_["LocalResponseNormalizationOptions"] = - "TfLiteLocalResponseNormParams"; - // Now for every op, try to find an option. - bool fatal = false; - for (auto op_name : ops_) { - bool found_option = false; - auto d = tflite::BuiltinOptionsTypeTable(); - std::string collapsed_option_name_guess = - ToCollapsed(op_name) + "options"; - // O(n^2) but not that big of n. - for (int i = 0; i < d->num_elems; i++) { - std::string option_name = d->names[i]; - std::string collapsed_option_name = ToCollapsed(option_name); - if (collapsed_option_name_guess == collapsed_option_name) { - op_to_option_.insert(std::make_pair(op_name, option_name)); - found_option = true; - break; - } - } - auto it = op_to_option_.find(op_name); - if (it == op_to_option_.end()) { - std::cerr << "Didn't find option for " << op_name << std::endl; - fatal = true; - } else if (!it->second.empty()) { - std::string option_name = it->second; - - if (option_to_struct_.find(option_name) == option_to_struct_.end()) { - bool param_struct_found = false; - std::string params_guess = std::string("TfLite") + option_name; - size_t start = params_guess.find("Options"); - size_t len = strlen("Options"); - params_guess.replace(start, len, "Params"); - for (auto* param = param_structs; *param != nullptr; param++) { - if (*param == params_guess) { - param_struct_found = true; - break; - } - } - if (!param_struct_found) { - std::cerr << "Failed to get param struct for option " << option_name - << std::endl; - fatal = true; - } else { - option_to_struct_.insert(std::make_pair(option_name, params_guess)); - } - } - } - } - } - - private: - std::vector<std::string> ops_; - std::unordered_map<std::string, std::string> op_to_option_; - std::unordered_map<std::string, std::string> option_to_struct_; - std::unordered_map<std::string, flatbuffers::TypeFunction> - option_to_type_function_; -}; - -void GenerateImportForOp(FILE* fp, const std::string& op_name, - const std::string& option_name, - const std::string& option_type, - const flatbuffers::TypeTable* options, - const std::string& struct_name) { - // Skip tricky ones for now - if (struct_name == "TfLiteResizeBilinearParams") return; - if (struct_name == "TfLiteSqueezeParams") return; - if (struct_name == "TfLiteEmbeddingLookupSparseParams") return; - if (struct_name == "TfLiteReshapeParams") return; - - fprintf(fp, " case BuiltinOperator_%s: {\n", op_name.c_str()); - fprintf(fp, - " const auto* params = reinterpret_cast<const " - "%s*>(builtin_op_data);\n", - struct_name.c_str()); - - for (size_t i = 0; i < options->num_elems; i++) { - std::string elem_name = options->names[i]; - // TODO(aselle): Irregular naming in builtins - if (elem_name == "fused_activation_function") - elem_name = "activation"; - else if (elem_name == "stride_w") - elem_name = "stride_width"; - else if (elem_name == "stride_h") - elem_name = "stride_height"; - else if (elem_name == "dilation_h_factor") - elem_name = "dilation_height_factor"; - else if (elem_name == "dilation_w_factor") - elem_name = "dilation_width_factor"; - else if (elem_name == "new_shape") - elem_name = "shape"; - - flatbuffers::TypeCode code = options->type_codes[i]; - auto contained_type = code.sequence_ref != -1 - ? options->type_refs[code.sequence_ref] - : nullptr; - std::string mapper = ""; - if (contained_type == TensorTypeTypeTable) { - mapper = "TfLiteTypeToSchemaType"; - } else if (contained_type == ActivationFunctionTypeTypeTable) { - mapper = "TfLiteActivationToSchemaActivation"; - } else if (contained_type == PaddingTypeTable) { - mapper = "TfLitePaddingToSchemaPadding"; - } else if (contained_type == FullyConnectedOptionsWeightsFormatTypeTable) { - mapper = "FullyConnectedOptionsWeightsFormatToSchema"; - } else if (contained_type == LSTMKernelTypeTypeTable) { - mapper = "LSTMKernelTypeToSchema"; - } else if (contained_type == LSHProjectionTypeTypeTable) { - mapper = "LSHProjectionTypeToSchema"; - } - - fprintf(fp, - " auto val%zu = " - "%s(params->%s);\n", - i, mapper.c_str(), elem_name.c_str()); - } - fprintf(fp, " auto union_type = Create%s(*fbb", option_name.c_str()); - for (size_t i = 0; i < options->num_elems; i++) { - fprintf(fp, ", val%zu", i); - } - fprintf(fp, ").Union();\n"); - fprintf(fp, " return std::make_pair(%s, union_type);\n", - option_type.c_str()); - fprintf(fp, " }\n break;\n"); -} - -void GenerateImport(OpOptionData* option, FILE* fp) { - std::unordered_set<std::string> ignores; - ignores.insert("CONCAT_EMBEDDINGS"); - ignores.insert("CALL"); - - // Allow any op that doesn't have an options struct to be blocked - // together - for (const auto& op_name : option->ops()) { - auto option_it = option->op_to_option().find(op_name); - if (!option_it->second.empty() && ignores.find(op_name) == ignores.end()) - continue; - fprintf(fp, " case BuiltinOperator_%s:\n", op_name.c_str()); - } - fprintf(fp, - " return std::make_pair(BuiltinOptions_NONE, " - "flatbuffers::Offset<void>());\n break;\n"); - - // Iterate over each ops - for (const auto& op_name : option->ops()) { - if (ignores.find(op_name) != ignores.end()) continue; - // Get to the option and struct names, continuing if not found. - auto option_it = option->op_to_option().find(op_name); - if (option_it->second.empty()) continue; - std::string option_name = option_it->second; - std::string option_type = "BuiltinOptions_" + option_name; - auto option_func_it = option->option_to_type_function().find(option_name); - if (option_func_it == option->option_to_type_function().end()) continue; - auto struct_name_it = option->option_to_struct().find(option_name); - if (struct_name_it == option->option_to_struct().end()) { - // If no C struct, then it better have no arguments. - auto type_info = option_func_it->second(); - if (type_info->num_elems != 0) { - // We have non-zero arguments in the schema, this means there - // should be a struct. - fprintf(stderr, - "Op %s uses option struct %s which has no builtin struct\n", - op_name.c_str(), option_name.c_str()); - exit(1); - } - fprintf(fp, " case BuiltinOperator_%s:\n", op_name.c_str()); - fprintf(fp, " return std::make_pair(%s, Create%s(*fbb).Union());", - option_type.c_str(), option_name.c_str()); - } else { - // If C struct, then we need to assign all properties - auto struct_name = struct_name_it->second; - GenerateImportForOp(fp, op_name, option_name, option_type, - option_func_it->second(), struct_name); - } - } - // TODO(aselle): Handle unhandled cases more gracefully. - fprintf(fp, - "default: return std::make_pair(BuiltinOptions_NONE, " - "flatbuffers::Offset<void>());\n break;\n"); -} - -} // namespace tflite - -int main(int argc, char* argv[]) { - tflite::OpOptionData option; - if (argc != 2) { - fprintf(stderr, "Usage: %s <fname out>\n", argv[0]); - return 1; - } - FILE* fp = fopen(argv[1], "w"); - tflite::GenerateImport(&option, fp); - fclose(fp); -} |