diff options
author | 2018-10-05 10:37:16 -0700 | |
---|---|---|
committer | 2018-10-05 10:40:59 -0700 | |
commit | e2f80439c5bfee56581875219ea83cc5307854f5 (patch) | |
tree | e3d8d3b7c8952225174cebcf1d430f8fa3ac588d /tensorflow/contrib/lite | |
parent | d493a7f2fdbbc29a292741135f4c1598352e876b (diff) |
Refactoring TFLite export code. Unify OperatorCode generation logic.
PiperOrigin-RevId: 215928419
Diffstat (limited to 'tensorflow/contrib/lite')
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/export.cc | 176 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/export.h | 19 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/export_test.cc | 77 |
3 files changed, 163 insertions, 109 deletions
diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc index 45ca7f7f0c..f6f76e48a4 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.cc +++ b/tensorflow/contrib/lite/toco/tflite/export.cc @@ -63,21 +63,21 @@ bool IsControlFlowOp(const string& tensorflow_op) { return false; } -details::OperatorKey GetOperatorKey( - const ::toco::Operator& op, - const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type, - bool allow_flex_ops) { - string custom_code; - if (op.type == OperatorType::kUnsupported) { - const TensorFlowUnsupportedOperator& unsupported_op = - static_cast<const TensorFlowUnsupportedOperator&>(op); - custom_code = unsupported_op.tensorflow_op; - } - int version = 1; - if (ops_by_type.count(op.type) != 0) { - version = ops_by_type.at(op.type)->GetVersion(op); +// Map from operator name to TF Lite enum value, for all builtins. +const std::map<string, BuiltinOperator>& GetBuiltinOpsMap() { + static std::map<string, BuiltinOperator>* builtin_ops = nullptr; + if (builtin_ops == nullptr) { + builtin_ops = new std::map<string, BuiltinOperator>(); + + for (int i = BuiltinOperator_MIN; i <= BuiltinOperator_MAX; ++i) { + BuiltinOperator op = static_cast<BuiltinOperator>(i); + string name = EnumNameBuiltinOperator(op); + if (op != BuiltinOperator_CUSTOM && !name.empty()) { + (*builtin_ops)[name] = op; + } + } } - return details::OperatorKey(op.type, custom_code, version, allow_flex_ops); + return *builtin_ops; } void WriteModelToString(const flatbuffers::FlatBufferBuilder& builder, @@ -91,27 +91,59 @@ void WriteModelToString(const flatbuffers::FlatBufferBuilder& builder, namespace details { -OperatorKey::OperatorKey(OperatorType type, const std::string& custom_code, - int version, bool allow_flex_ops) { - this->type = type; - this->custom_code = custom_code; - this->version = version; - - if (type == OperatorType::kUnsupported) { - // TODO(b/113715895): When `allow_flex_ops` is on, for now there's no way - // to populate a regular custom op. We need to find a way to fix this. - if (allow_flex_ops) { - // Memorize the original TensorFlow op name. - this->flex_tensorflow_op = custom_code; - // Prefix the custom code of the flex op. - this->custom_code = string(::tflite::kFlexCustomCodePrefix) + custom_code; - this->is_flex_op = true; - - if (IsControlFlowOp(this->flex_tensorflow_op)) { - is_unsupported_flex_op = true; +OperatorKey GetOperatorKey( + const ::toco::Operator& op, + const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type, + bool allow_flex_ops) { + string name = HelpfulOperatorTypeName(op); + const auto& builtin_ops = GetBuiltinOpsMap(); + + bool is_builtin = false; + OperatorKey key; + if (ops_by_type.count(op.type) != 0) { + key.version = ops_by_type.at(op.type)->GetVersion(op); + name = ops_by_type.at(op.type)->name(); + is_builtin = (builtin_ops.count(name) > 0); + } + + if (is_builtin) { + // For TFLite supported builtin ops, find out its BuiltinOperator enum used + // in FlatBuffer. + key.type = builtin_ops.at(name); + } else { + key.type = BuiltinOperator_CUSTOM; + + key.is_custom_op = true; + if (op.type == OperatorType::kUnsupported) { + const TensorFlowUnsupportedOperator& unsupported_op = + static_cast<const TensorFlowUnsupportedOperator&>(op); + const auto tensorflow_op = unsupported_op.tensorflow_op; + + // TODO(b/113715895): When `allow_flex_ops` is on, for now there's no way + // to populate a regular custom op. We need to find a way to fix this. + if (allow_flex_ops) { + // Memorize the original TensorFlow op name. + key.flex_tensorflow_op = tensorflow_op; + // Prefix the custom code of the flex op. + key.custom_code = + string(::tflite::kFlexCustomCodePrefix) + tensorflow_op; + key.is_flex_op = true; + + if (IsControlFlowOp(tensorflow_op)) { + key.is_unsupported_flex_op = true; + } + } else { + key.custom_code = tensorflow_op; } + } else { + // For Toco-supported/TFLite-unsupported ops, currently we produce a + // custom op. This gives developers a chance to implement custom ops. + // TODO(b/116800229): Also produce Toco-supported/TFLite-unsupported ops + // as Flex ops when Flex mode is enabled. + key.custom_code = name; } } + return key; } void LoadTensorsMap(const Model& model, TensorsMap* tensors_map) { @@ -145,6 +177,7 @@ void LoadOperatorsMap( ++index; } } + } // namespace details Offset<Vector<Offset<Tensor>>> ExportTensors( @@ -230,7 +263,7 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes( const Model& model, const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type, const details::OperatorsMap& operators_map, FlatBufferBuilder* builder, - std::set<string>* unsupported_ops, const ExportParams& params) { + const ExportParams& params) { // Map from operator name to TF Lite enum value, for all builtins. std::map<string, BuiltinOperator> builtin_ops; for (int i = BuiltinOperator_MIN; i <= BuiltinOperator_MAX; ++i) { @@ -247,37 +280,16 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes( for (const auto& op : model.operators) { const details::OperatorKey operator_key = - GetOperatorKey(*op, ops_by_type, params.allow_flex_ops); + details::GetOperatorKey(*op, ops_by_type, params.allow_flex_ops); int op_index = operators_map.at(operator_key); - int op_version = operator_key.version; - string name = HelpfulOperatorTypeName(*op); - bool is_builtin = false; - if (ops_by_type.count(op->type) != 0) { - name = ops_by_type.at(op->type)->name(); - is_builtin = (builtin_ops.count(name) > 0); + flatbuffers::Offset<flatbuffers::String> custom_code = 0; + if (!operator_key.custom_code.empty()) { + custom_code = builder->CreateString(operator_key.custom_code); } - if (is_builtin) { - ordered_opcodes[op_index] = - CreateOperatorCode(*builder, builtin_ops[name], 0, op_version); - } else { - // This could be a kUnsupported, in which case we should be - // able to retrieve the original Tensorflow name from the OperatorKey, or - // this could be a proper TOCO operator that is completely unknown to TF - // Lite. - if (!operator_key.custom_code.empty()) { - name = operator_key.custom_code; - } - // Either way, this is an operator that is not supported by TF Lite, - // so we output it as a custom op and add it to the error summary. - if (unsupported_ops) { - unsupported_ops->insert(name); - } - ordered_opcodes[op_index] = - CreateOperatorCode(*builder, BuiltinOperator_CUSTOM, - builder->CreateString(name), op_version); - } + ordered_opcodes[op_index] = CreateOperatorCode( + *builder, operator_key.type, custom_code, operator_key.version); } std::vector<Offset<OperatorCode>> opcode_vector; @@ -312,7 +324,7 @@ Offset<Vector<Offset<Operator>>> ExportOperators( } int op_index = operators_map.at( - GetOperatorKey(*op, ops_by_type, params.allow_flex_ops)); + details::GetOperatorKey(*op, ops_by_type, params.allow_flex_ops)); auto tflite_op_it = ops_by_type.find(op->type); BaseOperator* tflite_op = tflite_op_it == ops_by_type.end() @@ -386,9 +398,8 @@ void Export( Array empty_array; buffers_to_write.push_back(&empty_array); - std::set<string> unsupported_ops; - auto op_codes = ExportOperatorCodes(model, ops_by_type, operators_map, - &builder, &unsupported_ops, params); + auto op_codes = + ExportOperatorCodes(model, ops_by_type, operators_map, &builder, params); for (const auto& op : model.operators) { if (op->type == OperatorType::kFakeQuant) { @@ -398,7 +409,20 @@ void Export( "for --std_values and --mean_values."; } } - if (!unsupported_ops.empty()) { + + std::set<string> custom_ops; + std::set<string> unsupported_flex_ops; + for (const auto& it : operators_map) { + const details::OperatorKey& key = it.first; + if (key.is_custom_op) { + custom_ops.insert(key.custom_code); + } + if (key.is_unsupported_flex_op) { + unsupported_flex_ops.insert(key.flex_tensorflow_op); + } + } + + if (!custom_ops.empty()) { if (!params.allow_custom_ops) { // Remove ExpandDims and ReorderAxes from unimplemented list unless they // compose the list. Both ops are removed during graph transformations. @@ -406,14 +430,14 @@ void Export( // transformation is unable to run because the output shape is not // defined. This causes unnecessary confusion during model conversion // time. - std::set<string> unsupported_ops_final; - for (const auto& op_type : unsupported_ops) { + std::set<string> custom_ops_final; + for (const auto& op_type : custom_ops) { if (op_type != "ReorderAxes" && op_type != "ExpandDims") { - unsupported_ops_final.insert(op_type); + custom_ops_final.insert(op_type); } } - if (unsupported_ops_final.empty()) { - unsupported_ops_final = unsupported_ops; + if (custom_ops_final.empty()) { + custom_ops_final = custom_ops; } LOG(QFATAL) @@ -423,13 +447,13 @@ void Export( "--allow_custom_ops, or by setting allow_custom_ops=True " "when calling tf.contrib.lite.TFLiteConverter(). Here is a list " "of operators for which you will need custom implementations: " - << absl::StrJoin(unsupported_ops_final, ", ") << "."; + << absl::StrJoin(custom_ops_final, ", ") << "."; } std::set<string> unsupported_control_flow_ops; // Check if unsupported ops contains control flow ops. It's impossible // to implement these ops as custom ops at the moment. - for (const auto& op : unsupported_ops) { + for (const auto& op : custom_ops) { if (IsControlFlowOp(op)) { unsupported_control_flow_ops.insert(op); } @@ -441,14 +465,6 @@ void Export( } } - std::set<string> unsupported_flex_ops; - for (const auto& it : operators_map) { - const details::OperatorKey& key = it.first; - if (key.is_unsupported_flex_op) { - unsupported_flex_ops.insert(key.custom_code); - } - } - if (!unsupported_flex_ops.empty()) { LOG(QFATAL) << "Some of the operators in the model are not supported by " "TensorFlow Flex runtime: " diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/contrib/lite/toco/tflite/export.h index 9efb282c6c..c627f48086 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.h +++ b/tensorflow/contrib/lite/toco/tflite/export.h @@ -81,16 +81,20 @@ using TensorsMap = std::unordered_map<string, int>; // Only when `type` is `kUnsupported`, `custom_code` is filled to // identify which operation is used. struct OperatorKey { - OperatorKey(OperatorType type, const std::string& custom_code, int version, - bool allow_flex_ops = false); + OperatorKey() {} + OperatorKey(::tflite::BuiltinOperator type, const std::string& custom_code, + int version) + : type(type), custom_code(custom_code), version(version) {} // Only `type`, `custom_code` and `version` is used to compute hash and // identity. - OperatorType type; + ::tflite::BuiltinOperator type = ::tflite::BuiltinOperator_CUSTOM; std::string custom_code; - int version; + int version = 1; - // THe fields below are not used to compute hash and identity. + // The fields below are not used to compute hash and identity. + // TODO(ycling): Consider to change these fields to accessor functions. + bool is_custom_op = false; bool is_flex_op = false; bool is_unsupported_flex_op = false; // The original TensorFlow op name for the flex op. Filled only when @@ -124,6 +128,11 @@ struct OperatorKey { }; }; +OperatorKey GetOperatorKey( + const ::toco::Operator& op, + const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type, + bool allow_flex_ops); + // A maps from operator type to its final position in the TF Lite buffer. using OperatorsMap = std::unordered_map<OperatorKey, int, OperatorKey::Hash>; diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/contrib/lite/toco/tflite/export_test.cc index a71a64d56f..d48ab78285 100644 --- a/tensorflow/contrib/lite/toco/tflite/export_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/export_test.cc @@ -105,13 +105,15 @@ TEST_F(ExportTest, LoadOperatorsMap) { details::OperatorsMap operators; const auto ops_by_type = BuildOperatorByTypeMap(); - // TODO(ycling): Add a test for allow_flex_ops. details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false); - EXPECT_EQ(0, operators[details::OperatorKey(OperatorType::kAdd, "", 1)]); - EXPECT_EQ(1, operators[details::OperatorKey(OperatorType::kConv, "", 1)]); - EXPECT_EQ(2, operators[details::OperatorKey(OperatorType::kSub, "", 1)]); - EXPECT_EQ(3, operators[details::OperatorKey(OperatorType::kUnsupported, + EXPECT_EQ( + 0, operators[details::OperatorKey(::tflite::BuiltinOperator_ADD, "", 1)]); + EXPECT_EQ(1, operators[details::OperatorKey(::tflite::BuiltinOperator_CONV_2D, + "", 1)]); + EXPECT_EQ(2, operators[details::OperatorKey(::tflite::BuiltinOperator_CUSTOM, "MyCrazyOp", 1)]); + EXPECT_EQ( + 3, operators[details::OperatorKey(::tflite::BuiltinOperator_SUB, "", 1)]); } TEST_F(ExportTest, Export) { @@ -133,7 +135,7 @@ TEST_F(ExportTest, Export) { } EXPECT_THAT(names, ElementsAre("builtin:ADD", "builtin:CONV_2D", - "builtin:SUB", "custom:MyCrazyOp")); + "custom:MyCrazyOp", "builtin:SUB")); std::vector<uint32_t> indices; auto operators = (*model->subgraphs())[0]->operators(); @@ -142,7 +144,7 @@ TEST_F(ExportTest, Export) { indices.push_back(op->opcode_index()); } - EXPECT_THAT(indices, ElementsAre(1, 0, 3, 2)); + EXPECT_THAT(indices, ElementsAre(1, 0, 2, 3)); } TEST_F(ExportTest, QuantizeWeights) { @@ -257,7 +259,8 @@ TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV1) { details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false); EXPECT_EQ(1, operators.size()); - EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 1))); + EXPECT_EQ(0, operators.at(details::OperatorKey( + ::tflite::BuiltinOperator_CONV_2D, "", 1))); } TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV2) { @@ -268,7 +271,8 @@ TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV2) { details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false); EXPECT_EQ(1, operators.size()); - EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 2))); + EXPECT_EQ(0, operators.at(details::OperatorKey( + ::tflite::BuiltinOperator_CONV_2D, "", 2))); } TEST_F(VersionedOpExportTest, LoadOperatorsMapWithBothVersions) { @@ -280,8 +284,10 @@ TEST_F(VersionedOpExportTest, LoadOperatorsMapWithBothVersions) { details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false); EXPECT_EQ(2, operators.size()); - EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 1))); - EXPECT_EQ(1, operators.at(details::OperatorKey(OperatorType::kConv, "", 2))); + EXPECT_EQ(0, operators.at(details::OperatorKey( + ::tflite::BuiltinOperator_CONV_2D, "", 1))); + EXPECT_EQ(1, operators.at(details::OperatorKey( + ::tflite::BuiltinOperator_CONV_2D, "", 2))); } TEST_F(VersionedOpExportTest, Export) { @@ -314,38 +320,61 @@ TEST_F(VersionedOpExportTest, Export) { } TEST(OperatorKeyTest, TestBuiltinOp) { - details::OperatorKey key(OperatorType::kConv, "", 2); - EXPECT_EQ(key.type, OperatorType::kConv); + auto op = absl::make_unique<ConvOperator>(); + + const auto ops_by_type = BuildOperatorByTypeMap(); + const auto key = details::GetOperatorKey(*op, ops_by_type, false); + + EXPECT_EQ(key.type, ::tflite::BuiltinOperator_CONV_2D); EXPECT_EQ(key.custom_code, ""); - EXPECT_EQ(key.version, 2); + EXPECT_EQ(key.version, 1); +} + +TEST(OperatorKeyTest, TestCustomOp) { + auto op = absl::make_unique<TensorFlowUnsupportedOperator>(); + op->tensorflow_op = "MyCrazyCustomOp"; + + const auto ops_by_type = BuildOperatorByTypeMap(); + const auto key = details::GetOperatorKey(*op, ops_by_type, false); + + EXPECT_EQ(key.type, ::tflite::BuiltinOperator_CUSTOM); + EXPECT_EQ(key.custom_code, "MyCrazyCustomOp"); + EXPECT_EQ(key.version, 1); } TEST(OperatorKeyTest, TestFlexOp) { + auto op = absl::make_unique<TensorFlowUnsupportedOperator>(); + op->tensorflow_op = "BatchMatMul"; + + const auto ops_by_type = BuildOperatorByTypeMap(); { - details::OperatorKey key(OperatorType::kUnsupported, "SomeUnsupportedOp", 1, - false); - EXPECT_EQ(key.type, OperatorType::kUnsupported); + const auto key = details::GetOperatorKey(*op, ops_by_type, false); // It shouldn't be converted to Flex op if `allow_flex_op` is false. - EXPECT_EQ(key.custom_code, "SomeUnsupportedOp"); + EXPECT_EQ(key.type, ::tflite::BuiltinOperator_CUSTOM); + EXPECT_EQ(key.custom_code, "BatchMatMul"); EXPECT_EQ(key.version, 1); EXPECT_FALSE(key.is_flex_op); } { - details::OperatorKey key(OperatorType::kUnsupported, "SomeUnsupportedOp", 1, - true); - EXPECT_EQ(key.type, OperatorType::kUnsupported); // Verify that the custom op name is prefixed by "Flex" and `is_flex_op` // is true. - EXPECT_EQ(key.custom_code, "FlexSomeUnsupportedOp"); + const auto key = details::GetOperatorKey(*op, ops_by_type, true); + EXPECT_EQ(key.type, ::tflite::BuiltinOperator_CUSTOM); + EXPECT_EQ(key.custom_code, "FlexBatchMatMul"); EXPECT_EQ(key.version, 1); EXPECT_TRUE(key.is_flex_op); } } TEST(OperatorKeyTest, TestFlexWithControlFlowOp) { - details::OperatorKey key(OperatorType::kUnsupported, "Merge", 1, true); - EXPECT_EQ(key.type, OperatorType::kUnsupported); + auto op = absl::make_unique<TensorFlowUnsupportedOperator>(); + op->tensorflow_op = "Merge"; + + const auto ops_by_type = BuildOperatorByTypeMap(); + const auto key = details::GetOperatorKey(*op, ops_by_type, true); + + EXPECT_EQ(key.type, ::tflite::BuiltinOperator_CUSTOM); EXPECT_EQ(key.custom_code, "FlexMerge"); EXPECT_EQ(key.version, 1); EXPECT_TRUE(key.is_flex_op); |