diff options
author | Yu-Cheng Ling <ycling@google.com> | 2018-10-03 17:25:46 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-03 17:29:59 -0700 |
commit | 4da5b350e1c062b9d55896ee872e0e4790f30bcb (patch) | |
tree | 4057f0340517f75fead4053093d5b10aaacac9ba /tensorflow/contrib | |
parent | c842d38978a0babb373fe2acbb0231960aa1c1d0 (diff) |
TFLite Flex: Blacklist Control Flow Ops
PiperOrigin-RevId: 215658384
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/export.cc | 132 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/export.h | 20 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/export_test.cc | 40 |
3 files changed, 152 insertions, 40 deletions
diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc index 0c9fac249c..45ca7f7f0c 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.cc +++ b/tensorflow/contrib/lite/toco/tflite/export.cc @@ -47,6 +47,22 @@ using ::tflite::Tensor; namespace { +// Check if a TensorFlow Op is a control flow op by its name. +bool IsControlFlowOp(const string& tensorflow_op) { + // Technically this is equalivent to `::tensorflow::Node::IsControlFlow()`. + // It requires to construct a `::tensorflow::Graph` to use that helper + // function, so we simply hardcode the list of control flow ops here. + if (tensorflow_op == "Switch" || tensorflow_op == "RefSwitch" || + tensorflow_op == "Merge" || tensorflow_op == "RefMerge" || + tensorflow_op == "Enter" || tensorflow_op == "RefEnter" || + tensorflow_op == "Exit" || tensorflow_op == "RefExit" || + tensorflow_op == "NextIteration" || tensorflow_op == "RefNextIteration") { + return true; + } + // TODO(ycling): Also check how to handle Variable ops and Assign ops. + return false; +} + details::OperatorKey GetOperatorKey( const ::toco::Operator& op, const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type, @@ -55,21 +71,13 @@ details::OperatorKey GetOperatorKey( if (op.type == OperatorType::kUnsupported) { const TensorFlowUnsupportedOperator& unsupported_op = static_cast<const TensorFlowUnsupportedOperator&>(op); - - // TODO(b/113715895): When `allow_flex_ops` is on, for now there's no way - // to populate a regular custom op. We need to find a way to fix this. - if (allow_flex_ops) { - custom_code = string(::tflite::kFlexCustomCodePrefix) + - unsupported_op.tensorflow_op; - } else { - custom_code = unsupported_op.tensorflow_op; - } + custom_code = unsupported_op.tensorflow_op; } int version = 1; if (ops_by_type.count(op.type) != 0) { version = ops_by_type.at(op.type)->GetVersion(op); } - return details::OperatorKey(op.type, custom_code, version); + return details::OperatorKey(op.type, custom_code, version, allow_flex_ops); } void WriteModelToString(const flatbuffers::FlatBufferBuilder& builder, @@ -83,6 +91,29 @@ void WriteModelToString(const flatbuffers::FlatBufferBuilder& builder, namespace details { +OperatorKey::OperatorKey(OperatorType type, const std::string& custom_code, + int version, bool allow_flex_ops) { + this->type = type; + this->custom_code = custom_code; + this->version = version; + + if (type == OperatorType::kUnsupported) { + // TODO(b/113715895): When `allow_flex_ops` is on, for now there's no way + // to populate a regular custom op. We need to find a way to fix this. + if (allow_flex_ops) { + // Memorize the original TensorFlow op name. + this->flex_tensorflow_op = custom_code; + // Prefix the custom code of the flex op. + this->custom_code = string(::tflite::kFlexCustomCodePrefix) + custom_code; + this->is_flex_op = true; + + if (IsControlFlowOp(this->flex_tensorflow_op)) { + is_unsupported_flex_op = true; + } + } + } +} + void LoadTensorsMap(const Model& model, TensorsMap* tensors_map) { // First find a list of unique array names. std::set<string> names; @@ -199,7 +230,7 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes( const Model& model, const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type, const details::OperatorsMap& operators_map, FlatBufferBuilder* builder, - std::set<string>* error_summary, const ExportParams& params) { + std::set<string>* unsupported_ops, const ExportParams& params) { // Map from operator name to TF Lite enum value, for all builtins. std::map<string, BuiltinOperator> builtin_ops; for (int i = BuiltinOperator_MIN; i <= BuiltinOperator_MAX; ++i) { @@ -240,8 +271,8 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes( } // Either way, this is an operator that is not supported by TF Lite, // so we output it as a custom op and add it to the error summary. - if (error_summary) { - error_summary->insert(name); + if (unsupported_ops) { + unsupported_ops->insert(name); } ordered_opcodes[op_index] = CreateOperatorCode(*builder, BuiltinOperator_CUSTOM, @@ -355,9 +386,9 @@ void Export( Array empty_array; buffers_to_write.push_back(&empty_array); - std::set<string> error_summary; + std::set<string> unsupported_ops; auto op_codes = ExportOperatorCodes(model, ops_by_type, operators_map, - &builder, &error_summary, params); + &builder, &unsupported_ops, params); for (const auto& op : model.operators) { if (op->type == OperatorType::kFakeQuant) { @@ -367,30 +398,61 @@ void Export( "for --std_values and --mean_values."; } } - if (!params.allow_custom_ops && !error_summary.empty()) { - // Remove ExpandDims and ReorderAxes from unimplemented list unless they - // compose the list. Both ops are removed during graph transformations. - // However, if an op is unimplemented earlier in the model, the graph - // transformation is unable to run because the output shape is not defined. - // This causes unnecessary confusion during model conversion time. - std::set<string> error_summary_final; - for (const auto& op_type : error_summary) { - if (op_type != "ReorderAxes" && op_type != "ExpandDims") { - error_summary_final.insert(op_type); + if (!unsupported_ops.empty()) { + if (!params.allow_custom_ops) { + // Remove ExpandDims and ReorderAxes from unimplemented list unless they + // compose the list. Both ops are removed during graph transformations. + // However, if an op is unimplemented earlier in the model, the graph + // transformation is unable to run because the output shape is not + // defined. This causes unnecessary confusion during model conversion + // time. + std::set<string> unsupported_ops_final; + for (const auto& op_type : unsupported_ops) { + if (op_type != "ReorderAxes" && op_type != "ExpandDims") { + unsupported_ops_final.insert(op_type); + } + } + if (unsupported_ops_final.empty()) { + unsupported_ops_final = unsupported_ops; + } + + LOG(QFATAL) + << "Some of the operators in the model are not supported by " + "the standard TensorFlow Lite runtime. If you have a custom " + "implementation for them you can disable this error with " + "--allow_custom_ops, or by setting allow_custom_ops=True " + "when calling tf.contrib.lite.TFLiteConverter(). Here is a list " + "of operators for which you will need custom implementations: " + << absl::StrJoin(unsupported_ops_final, ", ") << "."; + } + + std::set<string> unsupported_control_flow_ops; + // Check if unsupported ops contains control flow ops. It's impossible + // to implement these ops as custom ops at the moment. + for (const auto& op : unsupported_ops) { + if (IsControlFlowOp(op)) { + unsupported_control_flow_ops.insert(op); } } - if (error_summary_final.empty()) { - error_summary_final = error_summary; + if (!unsupported_control_flow_ops.empty()) { + LOG(QFATAL) + << "TensorFlow Lite currently doesn't support control flow ops: " + << absl::StrJoin(unsupported_control_flow_ops, ", ") << "."; } + } + + std::set<string> unsupported_flex_ops; + for (const auto& it : operators_map) { + const details::OperatorKey& key = it.first; + if (key.is_unsupported_flex_op) { + unsupported_flex_ops.insert(key.custom_code); + } + } - LOG(QFATAL) - << "Some of the operators in the model are not supported by " - "the standard TensorFlow Lite runtime. If you have a custom " - "implementation for them you can disable this error with " - "--allow_custom_ops, or by setting allow_custom_ops=True " - "when calling tf.contrib.lite.TFLiteConverter(). Here is a list " - "of operators for which you will need custom implementations: " - << absl::StrJoin(error_summary_final, ", ") << "."; + if (!unsupported_flex_ops.empty()) { + LOG(QFATAL) << "Some of the operators in the model are not supported by " + "TensorFlow Flex runtime: " + << absl::StrJoin(unsupported_flex_ops, ", ") << "."; } std::set<int32_t> variable_tensor_indices; diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/contrib/lite/toco/tflite/export.h index 29d6de4049..9efb282c6c 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.h +++ b/tensorflow/contrib/lite/toco/tflite/export.h @@ -81,11 +81,21 @@ using TensorsMap = std::unordered_map<string, int>; // Only when `type` is `kUnsupported`, `custom_code` is filled to // identify which operation is used. struct OperatorKey { - OperatorKey(OperatorType type, const std::string& custom_code, int version) - : type(type), custom_code(custom_code), version(version) {} - const OperatorType type; - const std::string custom_code; - const int version; + OperatorKey(OperatorType type, const std::string& custom_code, int version, + bool allow_flex_ops = false); + + // Only `type`, `custom_code` and `version` is used to compute hash and + // identity. + OperatorType type; + std::string custom_code; + int version; + + // THe fields below are not used to compute hash and identity. + bool is_flex_op = false; + bool is_unsupported_flex_op = false; + // The original TensorFlow op name for the flex op. Filled only when + // `is_flex_op` is true. + std::string flex_tensorflow_op; bool operator<(const OperatorKey& other) const { if (type < other.type) return true; diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/contrib/lite/toco/tflite/export_test.cc index 93882a91a7..a71a64d56f 100644 --- a/tensorflow/contrib/lite/toco/tflite/export_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/export_test.cc @@ -313,6 +313,46 @@ TEST_F(VersionedOpExportTest, Export) { EXPECT_EQ(1, (*operators)[1]->opcode_index()); } +TEST(OperatorKeyTest, TestBuiltinOp) { + details::OperatorKey key(OperatorType::kConv, "", 2); + EXPECT_EQ(key.type, OperatorType::kConv); + EXPECT_EQ(key.custom_code, ""); + EXPECT_EQ(key.version, 2); +} + +TEST(OperatorKeyTest, TestFlexOp) { + { + details::OperatorKey key(OperatorType::kUnsupported, "SomeUnsupportedOp", 1, + false); + EXPECT_EQ(key.type, OperatorType::kUnsupported); + // It shouldn't be converted to Flex op if `allow_flex_op` is false. + EXPECT_EQ(key.custom_code, "SomeUnsupportedOp"); + EXPECT_EQ(key.version, 1); + EXPECT_FALSE(key.is_flex_op); + } + + { + details::OperatorKey key(OperatorType::kUnsupported, "SomeUnsupportedOp", 1, + true); + EXPECT_EQ(key.type, OperatorType::kUnsupported); + // Verify that the custom op name is prefixed by "Flex" and `is_flex_op` + // is true. + EXPECT_EQ(key.custom_code, "FlexSomeUnsupportedOp"); + EXPECT_EQ(key.version, 1); + EXPECT_TRUE(key.is_flex_op); + } +} + +TEST(OperatorKeyTest, TestFlexWithControlFlowOp) { + details::OperatorKey key(OperatorType::kUnsupported, "Merge", 1, true); + EXPECT_EQ(key.type, OperatorType::kUnsupported); + EXPECT_EQ(key.custom_code, "FlexMerge"); + EXPECT_EQ(key.version, 1); + EXPECT_TRUE(key.is_flex_op); + // The control flow ops should be marked as unsupported. + EXPECT_TRUE(key.is_unsupported_flex_op); +} + // TODO(ahentz): tests for tensors, inputs, outputs, opcodes and operators. } // namespace |