author     Yu-Cheng Ling <ycling@google.com>          2018-10-05 10:37:16 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-10-05 10:40:59 -0700
commit     e2f80439c5bfee56581875219ea83cc5307854f5 (patch)
tree       e3d8d3b7c8952225174cebcf1d430f8fa3ac588d /tensorflow/contrib/lite
parent     d493a7f2fdbbc29a292741135f4c1598352e876b (diff)
Refactor TFLite export code: unify OperatorCode generation logic.
PiperOrigin-RevId: 215928419
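
In short, the patch replaces the two divergent copies of OperatorCode resolution (one in the old GetOperatorKey, one inside ExportOperatorCodes) with a single details::GetOperatorKey() that fills in everything the OperatorCode table needs: the ::tflite::BuiltinOperator enum (rather than toco's OperatorType), the custom code, and the version. A minimal sketch of the unified call pattern, assuming the toco::tflite headers and names shown in the diff below (the Demo wrapper is illustrative, not part of the patch):

    #include "tensorflow/contrib/lite/toco/tflite/export.h"
    #include "tensorflow/contrib/lite/toco/tflite/operator.h"

    namespace toco {
    namespace tflite {

    void Demo() {
      ConvOperator op;  // a TF Lite builtin, as in the TestBuiltinOp test below
      const auto ops_by_type = BuildOperatorByTypeMap();
      const details::OperatorKey key =
          details::GetOperatorKey(op, ops_by_type, /*allow_flex_ops=*/false);
      // key.type        == ::tflite::BuiltinOperator_CONV_2D
      // key.custom_code == ""   (non-empty only for custom/flex ops)
      // key.version     == 1    (from BaseOperator::GetVersion)
      // ExportOperatorCodes() now consumes these fields directly.
    }

    }  // namespace tflite
    }  // namespace toco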
Diffstat (limited to 'tensorflow/contrib/lite')
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/export.cc      | 176
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/export.h       |  19
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/export_test.cc |  77
3 files changed, 163 insertions(+), 109 deletions(-)
diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc
index 45ca7f7f0c..f6f76e48a4 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export.cc
@@ -63,21 +63,21 @@ bool IsControlFlowOp(const string& tensorflow_op) {
return false;
}
-details::OperatorKey GetOperatorKey(
- const ::toco::Operator& op,
- const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
- bool allow_flex_ops) {
- string custom_code;
- if (op.type == OperatorType::kUnsupported) {
- const TensorFlowUnsupportedOperator& unsupported_op =
- static_cast<const TensorFlowUnsupportedOperator&>(op);
- custom_code = unsupported_op.tensorflow_op;
- }
- int version = 1;
- if (ops_by_type.count(op.type) != 0) {
- version = ops_by_type.at(op.type)->GetVersion(op);
+// Map from operator name to TF Lite enum value, for all builtins.
+const std::map<string, BuiltinOperator>& GetBuiltinOpsMap() {
+ static std::map<string, BuiltinOperator>* builtin_ops = nullptr;
+ if (builtin_ops == nullptr) {
+ builtin_ops = new std::map<string, BuiltinOperator>();
+
+ for (int i = BuiltinOperator_MIN; i <= BuiltinOperator_MAX; ++i) {
+ BuiltinOperator op = static_cast<BuiltinOperator>(i);
+ string name = EnumNameBuiltinOperator(op);
+ if (op != BuiltinOperator_CUSTOM && !name.empty()) {
+ (*builtin_ops)[name] = op;
+ }
+ }
}
- return details::OperatorKey(op.type, custom_code, version, allow_flex_ops);
+ return *builtin_ops;
}
void WriteModelToString(const flatbuffers::FlatBufferBuilder& builder,
@@ -91,27 +91,59 @@ void WriteModelToString(const flatbuffers::FlatBufferBuilder& builder,
namespace details {
-OperatorKey::OperatorKey(OperatorType type, const std::string& custom_code,
- int version, bool allow_flex_ops) {
- this->type = type;
- this->custom_code = custom_code;
- this->version = version;
-
- if (type == OperatorType::kUnsupported) {
- // TODO(b/113715895): When `allow_flex_ops` is on, for now there's no way
- // to populate a regular custom op. We need to find a way to fix this.
- if (allow_flex_ops) {
- // Memorize the original TensorFlow op name.
- this->flex_tensorflow_op = custom_code;
- // Prefix the custom code of the flex op.
- this->custom_code = string(::tflite::kFlexCustomCodePrefix) + custom_code;
- this->is_flex_op = true;
-
- if (IsControlFlowOp(this->flex_tensorflow_op)) {
- is_unsupported_flex_op = true;
+OperatorKey GetOperatorKey(
+ const ::toco::Operator& op,
+ const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
+ bool allow_flex_ops) {
+ string name = HelpfulOperatorTypeName(op);
+ const auto& builtin_ops = GetBuiltinOpsMap();
+
+ bool is_builtin = false;
+ OperatorKey key;
+ if (ops_by_type.count(op.type) != 0) {
+ key.version = ops_by_type.at(op.type)->GetVersion(op);
+ name = ops_by_type.at(op.type)->name();
+ is_builtin = (builtin_ops.count(name) > 0);
+ }
+
+ if (is_builtin) {
+ // For TFLite supported builtin ops, find out its BuiltinOperator enum used
+ // in FlatBuffer.
+ key.type = builtin_ops.at(name);
+ } else {
+ key.type = BuiltinOperator_CUSTOM;
+
+ key.is_custom_op = true;
+ if (op.type == OperatorType::kUnsupported) {
+ const TensorFlowUnsupportedOperator& unsupported_op =
+ static_cast<const TensorFlowUnsupportedOperator&>(op);
+ const auto tensorflow_op = unsupported_op.tensorflow_op;
+
+ // TODO(b/113715895): When `allow_flex_ops` is on, for now there's no way
+ // to populate a regular custom op. We need to find a way to fix this.
+ if (allow_flex_ops) {
+ // Memorize the original TensorFlow op name.
+ key.flex_tensorflow_op = tensorflow_op;
+ // Prefix the custom code of the flex op.
+ key.custom_code =
+ string(::tflite::kFlexCustomCodePrefix) + tensorflow_op;
+ key.is_flex_op = true;
+
+ if (IsControlFlowOp(tensorflow_op)) {
+ key.is_unsupported_flex_op = true;
+ }
+ } else {
+ key.custom_code = tensorflow_op;
}
+ } else {
+ // For Toco-supported/TFLite-unsupported ops, currently we produce a
+ // custom op. This gives developers a chance to implement custom ops.
+ // TODO(b/116800229): Also produce Toco-supported/TFLite-unsupported ops
+ // as Flex ops when Flex mode is enabled.
+ key.custom_code = name;
}
}
+ return key;
}
void LoadTensorsMap(const Model& model, TensorsMap* tensors_map) {
@@ -145,6 +177,7 @@ void LoadOperatorsMap(
++index;
}
}
+
} // namespace details
Offset<Vector<Offset<Tensor>>> ExportTensors(
@@ -230,7 +263,7 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes(
const Model& model,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
const details::OperatorsMap& operators_map, FlatBufferBuilder* builder,
- std::set<string>* unsupported_ops, const ExportParams& params) {
+ const ExportParams& params) {
// Map from operator name to TF Lite enum value, for all builtins.
std::map<string, BuiltinOperator> builtin_ops;
for (int i = BuiltinOperator_MIN; i <= BuiltinOperator_MAX; ++i) {
@@ -247,37 +280,16 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes(
for (const auto& op : model.operators) {
const details::OperatorKey operator_key =
- GetOperatorKey(*op, ops_by_type, params.allow_flex_ops);
+ details::GetOperatorKey(*op, ops_by_type, params.allow_flex_ops);
int op_index = operators_map.at(operator_key);
- int op_version = operator_key.version;
- string name = HelpfulOperatorTypeName(*op);
- bool is_builtin = false;
- if (ops_by_type.count(op->type) != 0) {
- name = ops_by_type.at(op->type)->name();
- is_builtin = (builtin_ops.count(name) > 0);
+ flatbuffers::Offset<flatbuffers::String> custom_code = 0;
+ if (!operator_key.custom_code.empty()) {
+ custom_code = builder->CreateString(operator_key.custom_code);
}
- if (is_builtin) {
- ordered_opcodes[op_index] =
- CreateOperatorCode(*builder, builtin_ops[name], 0, op_version);
- } else {
- // This could be a kUnsupported, in which case we should be
- // able to retrieve the original Tensorflow name from the OperatorKey, or
- // this could be a proper TOCO operator that is completely unknown to TF
- // Lite.
- if (!operator_key.custom_code.empty()) {
- name = operator_key.custom_code;
- }
- // Either way, this is an operator that is not supported by TF Lite,
- // so we output it as a custom op and add it to the error summary.
- if (unsupported_ops) {
- unsupported_ops->insert(name);
- }
- ordered_opcodes[op_index] =
- CreateOperatorCode(*builder, BuiltinOperator_CUSTOM,
- builder->CreateString(name), op_version);
- }
+ ordered_opcodes[op_index] = CreateOperatorCode(
+ *builder, operator_key.type, custom_code, operator_key.version);
}
std::vector<Offset<OperatorCode>> opcode_vector;
@@ -312,7 +324,7 @@ Offset<Vector<Offset<Operator>>> ExportOperators(
}
int op_index = operators_map.at(
- GetOperatorKey(*op, ops_by_type, params.allow_flex_ops));
+ details::GetOperatorKey(*op, ops_by_type, params.allow_flex_ops));
auto tflite_op_it = ops_by_type.find(op->type);
BaseOperator* tflite_op = tflite_op_it == ops_by_type.end()
@@ -386,9 +398,8 @@ void Export(
Array empty_array;
buffers_to_write.push_back(&empty_array);
- std::set<string> unsupported_ops;
- auto op_codes = ExportOperatorCodes(model, ops_by_type, operators_map,
- &builder, &unsupported_ops, params);
+ auto op_codes =
+ ExportOperatorCodes(model, ops_by_type, operators_map, &builder, params);
for (const auto& op : model.operators) {
if (op->type == OperatorType::kFakeQuant) {
@@ -398,7 +409,20 @@ void Export(
"for --std_values and --mean_values.";
}
}
- if (!unsupported_ops.empty()) {
+
+ std::set<string> custom_ops;
+ std::set<string> unsupported_flex_ops;
+ for (const auto& it : operators_map) {
+ const details::OperatorKey& key = it.first;
+ if (key.is_custom_op) {
+ custom_ops.insert(key.custom_code);
+ }
+ if (key.is_unsupported_flex_op) {
+ unsupported_flex_ops.insert(key.flex_tensorflow_op);
+ }
+ }
+
+ if (!custom_ops.empty()) {
if (!params.allow_custom_ops) {
// Remove ExpandDims and ReorderAxes from unimplemented list unless they
// compose the list. Both ops are removed during graph transformations.
@@ -406,14 +430,14 @@ void Export(
// transformation is unable to run because the output shape is not
// defined. This causes unnecessary confusion during model conversion
// time.
- std::set<string> unsupported_ops_final;
- for (const auto& op_type : unsupported_ops) {
+ std::set<string> custom_ops_final;
+ for (const auto& op_type : custom_ops) {
if (op_type != "ReorderAxes" && op_type != "ExpandDims") {
- unsupported_ops_final.insert(op_type);
+ custom_ops_final.insert(op_type);
}
}
- if (unsupported_ops_final.empty()) {
- unsupported_ops_final = unsupported_ops;
+ if (custom_ops_final.empty()) {
+ custom_ops_final = custom_ops;
}
LOG(QFATAL)
@@ -423,13 +447,13 @@ void Export(
"--allow_custom_ops, or by setting allow_custom_ops=True "
"when calling tf.contrib.lite.TFLiteConverter(). Here is a list "
"of operators for which you will need custom implementations: "
- << absl::StrJoin(unsupported_ops_final, ", ") << ".";
+ << absl::StrJoin(custom_ops_final, ", ") << ".";
}
std::set<string> unsupported_control_flow_ops;
// Check if unsupported ops contains control flow ops. It's impossible
// to implement these ops as custom ops at the moment.
- for (const auto& op : unsupported_ops) {
+ for (const auto& op : custom_ops) {
if (IsControlFlowOp(op)) {
unsupported_control_flow_ops.insert(op);
}
@@ -441,14 +465,6 @@ void Export(
}
}
- std::set<string> unsupported_flex_ops;
- for (const auto& it : operators_map) {
- const details::OperatorKey& key = it.first;
- if (key.is_unsupported_flex_op) {
- unsupported_flex_ops.insert(key.custom_code);
- }
- }
-
if (!unsupported_flex_ops.empty()) {
LOG(QFATAL) << "Some of the operators in the model are not supported by "
"TensorFlow Flex runtime: "
diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/contrib/lite/toco/tflite/export.h
index 9efb282c6c..c627f48086 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.h
+++ b/tensorflow/contrib/lite/toco/tflite/export.h
@@ -81,16 +81,20 @@ using TensorsMap = std::unordered_map<string, int>;
// Only when `type` is `kUnsupported`, `custom_code` is filled to
// identify which operation is used.
struct OperatorKey {
- OperatorKey(OperatorType type, const std::string& custom_code, int version,
- bool allow_flex_ops = false);
+ OperatorKey() {}
+ OperatorKey(::tflite::BuiltinOperator type, const std::string& custom_code,
+ int version)
+ : type(type), custom_code(custom_code), version(version) {}
// Only `type`, `custom_code` and `version` is used to compute hash and
// identity.
- OperatorType type;
+ ::tflite::BuiltinOperator type = ::tflite::BuiltinOperator_CUSTOM;
std::string custom_code;
- int version;
+ int version = 1;
- // THe fields below are not used to compute hash and identity.
+ // The fields below are not used to compute hash and identity.
+ // TODO(ycling): Consider to change these fields to accessor functions.
+ bool is_custom_op = false;
bool is_flex_op = false;
bool is_unsupported_flex_op = false;
// The original TensorFlow op name for the flex op. Filled only when
@@ -124,6 +128,11 @@ struct OperatorKey {
};
};
+OperatorKey GetOperatorKey(
+ const ::toco::Operator& op,
+ const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
+ bool allow_flex_ops);
+
// A maps from operator type to its final position in the TF Lite buffer.
using OperatorsMap = std::unordered_map<OperatorKey, int, OperatorKey::Hash>;
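
Since only (type, custom_code, version) feed OperatorKey::Hash and equality, the new bookkeeping flags never split map entries. A minimal illustration of what that means for OperatorsMap (values illustrative; equality semantics per the comment in the struct above):

    details::OperatorsMap operators;
    operators[details::OperatorKey(::tflite::BuiltinOperator_CONV_2D, "", 2)] = 0;

    details::OperatorKey same(::tflite::BuiltinOperator_CONV_2D, "", 2);
    same.is_flex_op = true;        // ignored by hash and equality
    // operators.count(same) == 1  -- still the same OperatorCode slot.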
diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/contrib/lite/toco/tflite/export_test.cc
index a71a64d56f..d48ab78285 100644
--- a/tensorflow/contrib/lite/toco/tflite/export_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export_test.cc
@@ -105,13 +105,15 @@ TEST_F(ExportTest, LoadOperatorsMap) {
details::OperatorsMap operators;
const auto ops_by_type = BuildOperatorByTypeMap();
- // TODO(ycling): Add a test for allow_flex_ops.
details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);
- EXPECT_EQ(0, operators[details::OperatorKey(OperatorType::kAdd, "", 1)]);
- EXPECT_EQ(1, operators[details::OperatorKey(OperatorType::kConv, "", 1)]);
- EXPECT_EQ(2, operators[details::OperatorKey(OperatorType::kSub, "", 1)]);
- EXPECT_EQ(3, operators[details::OperatorKey(OperatorType::kUnsupported,
+ EXPECT_EQ(
+ 0, operators[details::OperatorKey(::tflite::BuiltinOperator_ADD, "", 1)]);
+ EXPECT_EQ(1, operators[details::OperatorKey(::tflite::BuiltinOperator_CONV_2D,
+ "", 1)]);
+ EXPECT_EQ(2, operators[details::OperatorKey(::tflite::BuiltinOperator_CUSTOM,
"MyCrazyOp", 1)]);
+ EXPECT_EQ(
+ 3, operators[details::OperatorKey(::tflite::BuiltinOperator_SUB, "", 1)]);
}
TEST_F(ExportTest, Export) {
@@ -133,7 +135,7 @@ TEST_F(ExportTest, Export) {
}
EXPECT_THAT(names, ElementsAre("builtin:ADD", "builtin:CONV_2D",
- "builtin:SUB", "custom:MyCrazyOp"));
+ "custom:MyCrazyOp", "builtin:SUB"));
std::vector<uint32_t> indices;
auto operators = (*model->subgraphs())[0]->operators();
@@ -142,7 +144,7 @@ TEST_F(ExportTest, Export) {
indices.push_back(op->opcode_index());
}
- EXPECT_THAT(indices, ElementsAre(1, 0, 3, 2));
+ EXPECT_THAT(indices, ElementsAre(1, 0, 2, 3));
}
TEST_F(ExportTest, QuantizeWeights) {
@@ -257,7 +259,8 @@ TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV1) {
details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);
EXPECT_EQ(1, operators.size());
- EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 1)));
+ EXPECT_EQ(0, operators.at(details::OperatorKey(
+ ::tflite::BuiltinOperator_CONV_2D, "", 1)));
}
TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV2) {
@@ -268,7 +271,8 @@ TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV2) {
details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);
EXPECT_EQ(1, operators.size());
- EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 2)));
+ EXPECT_EQ(0, operators.at(details::OperatorKey(
+ ::tflite::BuiltinOperator_CONV_2D, "", 2)));
}
TEST_F(VersionedOpExportTest, LoadOperatorsMapWithBothVersions) {
@@ -280,8 +284,10 @@ TEST_F(VersionedOpExportTest, LoadOperatorsMapWithBothVersions) {
details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);
EXPECT_EQ(2, operators.size());
- EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 1)));
- EXPECT_EQ(1, operators.at(details::OperatorKey(OperatorType::kConv, "", 2)));
+ EXPECT_EQ(0, operators.at(details::OperatorKey(
+ ::tflite::BuiltinOperator_CONV_2D, "", 1)));
+ EXPECT_EQ(1, operators.at(details::OperatorKey(
+ ::tflite::BuiltinOperator_CONV_2D, "", 2)));
}
TEST_F(VersionedOpExportTest, Export) {
@@ -314,38 +320,61 @@ TEST_F(VersionedOpExportTest, Export) {
}
TEST(OperatorKeyTest, TestBuiltinOp) {
- details::OperatorKey key(OperatorType::kConv, "", 2);
- EXPECT_EQ(key.type, OperatorType::kConv);
+ auto op = absl::make_unique<ConvOperator>();
+
+ const auto ops_by_type = BuildOperatorByTypeMap();
+ const auto key = details::GetOperatorKey(*op, ops_by_type, false);
+
+ EXPECT_EQ(key.type, ::tflite::BuiltinOperator_CONV_2D);
EXPECT_EQ(key.custom_code, "");
- EXPECT_EQ(key.version, 2);
+ EXPECT_EQ(key.version, 1);
+}
+
+TEST(OperatorKeyTest, TestCustomOp) {
+ auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
+ op->tensorflow_op = "MyCrazyCustomOp";
+
+ const auto ops_by_type = BuildOperatorByTypeMap();
+ const auto key = details::GetOperatorKey(*op, ops_by_type, false);
+
+ EXPECT_EQ(key.type, ::tflite::BuiltinOperator_CUSTOM);
+ EXPECT_EQ(key.custom_code, "MyCrazyCustomOp");
+ EXPECT_EQ(key.version, 1);
}
TEST(OperatorKeyTest, TestFlexOp) {
+ auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
+ op->tensorflow_op = "BatchMatMul";
+
+ const auto ops_by_type = BuildOperatorByTypeMap();
{
- details::OperatorKey key(OperatorType::kUnsupported, "SomeUnsupportedOp", 1,
- false);
- EXPECT_EQ(key.type, OperatorType::kUnsupported);
+ const auto key = details::GetOperatorKey(*op, ops_by_type, false);
// It shouldn't be converted to Flex op if `allow_flex_op` is false.
- EXPECT_EQ(key.custom_code, "SomeUnsupportedOp");
+ EXPECT_EQ(key.type, ::tflite::BuiltinOperator_CUSTOM);
+ EXPECT_EQ(key.custom_code, "BatchMatMul");
EXPECT_EQ(key.version, 1);
EXPECT_FALSE(key.is_flex_op);
}
{
- details::OperatorKey key(OperatorType::kUnsupported, "SomeUnsupportedOp", 1,
- true);
- EXPECT_EQ(key.type, OperatorType::kUnsupported);
// Verify that the custom op name is prefixed by "Flex" and `is_flex_op`
// is true.
- EXPECT_EQ(key.custom_code, "FlexSomeUnsupportedOp");
+ const auto key = details::GetOperatorKey(*op, ops_by_type, true);
+ EXPECT_EQ(key.type, ::tflite::BuiltinOperator_CUSTOM);
+ EXPECT_EQ(key.custom_code, "FlexBatchMatMul");
EXPECT_EQ(key.version, 1);
EXPECT_TRUE(key.is_flex_op);
}
}
TEST(OperatorKeyTest, TestFlexWithControlFlowOp) {
- details::OperatorKey key(OperatorType::kUnsupported, "Merge", 1, true);
- EXPECT_EQ(key.type, OperatorType::kUnsupported);
+ auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
+ op->tensorflow_op = "Merge";
+
+ const auto ops_by_type = BuildOperatorByTypeMap();
+ const auto key = details::GetOperatorKey(*op, ops_by_type, true);
+
+ EXPECT_EQ(key.type, ::tflite::BuiltinOperator_CUSTOM);
EXPECT_EQ(key.custom_code, "FlexMerge");
EXPECT_EQ(key.version, 1);
EXPECT_TRUE(key.is_flex_op);