diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/tflite/export.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/export.cc | 176 |
1 files changed, 96 insertions, 80 deletions
diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc index 45ca7f7f0c..f6f76e48a4 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.cc +++ b/tensorflow/contrib/lite/toco/tflite/export.cc @@ -63,21 +63,21 @@ bool IsControlFlowOp(const string& tensorflow_op) { return false; } -details::OperatorKey GetOperatorKey( - const ::toco::Operator& op, - const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type, - bool allow_flex_ops) { - string custom_code; - if (op.type == OperatorType::kUnsupported) { - const TensorFlowUnsupportedOperator& unsupported_op = - static_cast<const TensorFlowUnsupportedOperator&>(op); - custom_code = unsupported_op.tensorflow_op; - } - int version = 1; - if (ops_by_type.count(op.type) != 0) { - version = ops_by_type.at(op.type)->GetVersion(op); +// Map from operator name to TF Lite enum value, for all builtins. +const std::map<string, BuiltinOperator>& GetBuiltinOpsMap() { + static std::map<string, BuiltinOperator>* builtin_ops = nullptr; + if (builtin_ops == nullptr) { + builtin_ops = new std::map<string, BuiltinOperator>(); + + for (int i = BuiltinOperator_MIN; i <= BuiltinOperator_MAX; ++i) { + BuiltinOperator op = static_cast<BuiltinOperator>(i); + string name = EnumNameBuiltinOperator(op); + if (op != BuiltinOperator_CUSTOM && !name.empty()) { + (*builtin_ops)[name] = op; + } + } } - return details::OperatorKey(op.type, custom_code, version, allow_flex_ops); + return *builtin_ops; } void WriteModelToString(const flatbuffers::FlatBufferBuilder& builder, @@ -91,27 +91,59 @@ void WriteModelToString(const flatbuffers::FlatBufferBuilder& builder, namespace details { -OperatorKey::OperatorKey(OperatorType type, const std::string& custom_code, - int version, bool allow_flex_ops) { - this->type = type; - this->custom_code = custom_code; - this->version = version; - - if (type == OperatorType::kUnsupported) { - // TODO(b/113715895): When `allow_flex_ops` is on, for now there's no way - // to populate a regular custom op. We need to find a way to fix this. - if (allow_flex_ops) { - // Memorize the original TensorFlow op name. - this->flex_tensorflow_op = custom_code; - // Prefix the custom code of the flex op. - this->custom_code = string(::tflite::kFlexCustomCodePrefix) + custom_code; - this->is_flex_op = true; - - if (IsControlFlowOp(this->flex_tensorflow_op)) { - is_unsupported_flex_op = true; +OperatorKey GetOperatorKey( + const ::toco::Operator& op, + const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type, + bool allow_flex_ops) { + string name = HelpfulOperatorTypeName(op); + const auto& builtin_ops = GetBuiltinOpsMap(); + + bool is_builtin = false; + OperatorKey key; + if (ops_by_type.count(op.type) != 0) { + key.version = ops_by_type.at(op.type)->GetVersion(op); + name = ops_by_type.at(op.type)->name(); + is_builtin = (builtin_ops.count(name) > 0); + } + + if (is_builtin) { + // For TFLite supported builtin ops, find out its BuiltinOperator enum used + // in FlatBuffer. + key.type = builtin_ops.at(name); + } else { + key.type = BuiltinOperator_CUSTOM; + + key.is_custom_op = true; + if (op.type == OperatorType::kUnsupported) { + const TensorFlowUnsupportedOperator& unsupported_op = + static_cast<const TensorFlowUnsupportedOperator&>(op); + const auto tensorflow_op = unsupported_op.tensorflow_op; + + // TODO(b/113715895): When `allow_flex_ops` is on, for now there's no way + // to populate a regular custom op. We need to find a way to fix this. + if (allow_flex_ops) { + // Memorize the original TensorFlow op name. + key.flex_tensorflow_op = tensorflow_op; + // Prefix the custom code of the flex op. + key.custom_code = + string(::tflite::kFlexCustomCodePrefix) + tensorflow_op; + key.is_flex_op = true; + + if (IsControlFlowOp(tensorflow_op)) { + key.is_unsupported_flex_op = true; + } + } else { + key.custom_code = tensorflow_op; } + } else { + // For Toco-supported/TFLite-unsupported ops, currently we produce a + // custom op. This gives developers a chance to implement custom ops. + // TODO(b/116800229): Also produce Toco-supported/TFLite-unsupported ops + // as Flex ops when Flex mode is enabled. + key.custom_code = name; } } + return key; } void LoadTensorsMap(const Model& model, TensorsMap* tensors_map) { @@ -145,6 +177,7 @@ void LoadOperatorsMap( ++index; } } + } // namespace details Offset<Vector<Offset<Tensor>>> ExportTensors( @@ -230,7 +263,7 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes( const Model& model, const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type, const details::OperatorsMap& operators_map, FlatBufferBuilder* builder, - std::set<string>* unsupported_ops, const ExportParams& params) { + const ExportParams& params) { // Map from operator name to TF Lite enum value, for all builtins. std::map<string, BuiltinOperator> builtin_ops; for (int i = BuiltinOperator_MIN; i <= BuiltinOperator_MAX; ++i) { @@ -247,37 +280,16 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes( for (const auto& op : model.operators) { const details::OperatorKey operator_key = - GetOperatorKey(*op, ops_by_type, params.allow_flex_ops); + details::GetOperatorKey(*op, ops_by_type, params.allow_flex_ops); int op_index = operators_map.at(operator_key); - int op_version = operator_key.version; - string name = HelpfulOperatorTypeName(*op); - bool is_builtin = false; - if (ops_by_type.count(op->type) != 0) { - name = ops_by_type.at(op->type)->name(); - is_builtin = (builtin_ops.count(name) > 0); + flatbuffers::Offset<flatbuffers::String> custom_code = 0; + if (!operator_key.custom_code.empty()) { + custom_code = builder->CreateString(operator_key.custom_code); } - if (is_builtin) { - ordered_opcodes[op_index] = - CreateOperatorCode(*builder, builtin_ops[name], 0, op_version); - } else { - // This could be a kUnsupported, in which case we should be - // able to retrieve the original Tensorflow name from the OperatorKey, or - // this could be a proper TOCO operator that is completely unknown to TF - // Lite. - if (!operator_key.custom_code.empty()) { - name = operator_key.custom_code; - } - // Either way, this is an operator that is not supported by TF Lite, - // so we output it as a custom op and add it to the error summary. - if (unsupported_ops) { - unsupported_ops->insert(name); - } - ordered_opcodes[op_index] = - CreateOperatorCode(*builder, BuiltinOperator_CUSTOM, - builder->CreateString(name), op_version); - } + ordered_opcodes[op_index] = CreateOperatorCode( + *builder, operator_key.type, custom_code, operator_key.version); } std::vector<Offset<OperatorCode>> opcode_vector; @@ -312,7 +324,7 @@ Offset<Vector<Offset<Operator>>> ExportOperators( } int op_index = operators_map.at( - GetOperatorKey(*op, ops_by_type, params.allow_flex_ops)); + details::GetOperatorKey(*op, ops_by_type, params.allow_flex_ops)); auto tflite_op_it = ops_by_type.find(op->type); BaseOperator* tflite_op = tflite_op_it == ops_by_type.end() @@ -386,9 +398,8 @@ void Export( Array empty_array; buffers_to_write.push_back(&empty_array); - std::set<string> unsupported_ops; - auto op_codes = ExportOperatorCodes(model, ops_by_type, operators_map, - &builder, &unsupported_ops, params); + auto op_codes = + ExportOperatorCodes(model, ops_by_type, operators_map, &builder, params); for (const auto& op : model.operators) { if (op->type == OperatorType::kFakeQuant) { @@ -398,7 +409,20 @@ void Export( "for --std_values and --mean_values."; } } - if (!unsupported_ops.empty()) { + + std::set<string> custom_ops; + std::set<string> unsupported_flex_ops; + for (const auto& it : operators_map) { + const details::OperatorKey& key = it.first; + if (key.is_custom_op) { + custom_ops.insert(key.custom_code); + } + if (key.is_unsupported_flex_op) { + unsupported_flex_ops.insert(key.flex_tensorflow_op); + } + } + + if (!custom_ops.empty()) { if (!params.allow_custom_ops) { // Remove ExpandDims and ReorderAxes from unimplemented list unless they // compose the list. Both ops are removed during graph transformations. @@ -406,14 +430,14 @@ void Export( // transformation is unable to run because the output shape is not // defined. This causes unnecessary confusion during model conversion // time. - std::set<string> unsupported_ops_final; - for (const auto& op_type : unsupported_ops) { + std::set<string> custom_ops_final; + for (const auto& op_type : custom_ops) { if (op_type != "ReorderAxes" && op_type != "ExpandDims") { - unsupported_ops_final.insert(op_type); + custom_ops_final.insert(op_type); } } - if (unsupported_ops_final.empty()) { - unsupported_ops_final = unsupported_ops; + if (custom_ops_final.empty()) { + custom_ops_final = custom_ops; } LOG(QFATAL) @@ -423,13 +447,13 @@ void Export( "--allow_custom_ops, or by setting allow_custom_ops=True " "when calling tf.contrib.lite.TFLiteConverter(). Here is a list " "of operators for which you will need custom implementations: " - << absl::StrJoin(unsupported_ops_final, ", ") << "."; + << absl::StrJoin(custom_ops_final, ", ") << "."; } std::set<string> unsupported_control_flow_ops; // Check if unsupported ops contains control flow ops. It's impossible // to implement these ops as custom ops at the moment. - for (const auto& op : unsupported_ops) { + for (const auto& op : custom_ops) { if (IsControlFlowOp(op)) { unsupported_control_flow_ops.insert(op); } @@ -441,14 +465,6 @@ void Export( } } - std::set<string> unsupported_flex_ops; - for (const auto& it : operators_map) { - const details::OperatorKey& key = it.first; - if (key.is_unsupported_flex_op) { - unsupported_flex_ops.insert(key.custom_code); - } - } - if (!unsupported_flex_ops.empty()) { LOG(QFATAL) << "Some of the operators in the model are not supported by " "TensorFlow Flex runtime: " |