aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/tflite/export.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/tflite/export.cc')
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export.cc176
1 files changed, 96 insertions, 80 deletions
diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc
index 45ca7f7f0c..f6f76e48a4 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export.cc
@@ -63,21 +63,21 @@ bool IsControlFlowOp(const string& tensorflow_op) {
return false;
}
-details::OperatorKey GetOperatorKey(
- const ::toco::Operator& op,
- const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
- bool allow_flex_ops) {
- string custom_code;
- if (op.type == OperatorType::kUnsupported) {
- const TensorFlowUnsupportedOperator& unsupported_op =
- static_cast<const TensorFlowUnsupportedOperator&>(op);
- custom_code = unsupported_op.tensorflow_op;
- }
- int version = 1;
- if (ops_by_type.count(op.type) != 0) {
- version = ops_by_type.at(op.type)->GetVersion(op);
+// Map from operator name to TF Lite enum value, for all builtins.
+const std::map<string, BuiltinOperator>& GetBuiltinOpsMap() {
+ static std::map<string, BuiltinOperator>* builtin_ops = nullptr;
+ if (builtin_ops == nullptr) {
+ builtin_ops = new std::map<string, BuiltinOperator>();
+
+ for (int i = BuiltinOperator_MIN; i <= BuiltinOperator_MAX; ++i) {
+ BuiltinOperator op = static_cast<BuiltinOperator>(i);
+ string name = EnumNameBuiltinOperator(op);
+ if (op != BuiltinOperator_CUSTOM && !name.empty()) {
+ (*builtin_ops)[name] = op;
+ }
+ }
}
- return details::OperatorKey(op.type, custom_code, version, allow_flex_ops);
+ return *builtin_ops;
}
void WriteModelToString(const flatbuffers::FlatBufferBuilder& builder,
@@ -91,27 +91,59 @@ void WriteModelToString(const flatbuffers::FlatBufferBuilder& builder,
namespace details {
-OperatorKey::OperatorKey(OperatorType type, const std::string& custom_code,
- int version, bool allow_flex_ops) {
- this->type = type;
- this->custom_code = custom_code;
- this->version = version;
-
- if (type == OperatorType::kUnsupported) {
- // TODO(b/113715895): When `allow_flex_ops` is on, for now there's no way
- // to populate a regular custom op. We need to find a way to fix this.
- if (allow_flex_ops) {
- // Memorize the original TensorFlow op name.
- this->flex_tensorflow_op = custom_code;
- // Prefix the custom code of the flex op.
- this->custom_code = string(::tflite::kFlexCustomCodePrefix) + custom_code;
- this->is_flex_op = true;
-
- if (IsControlFlowOp(this->flex_tensorflow_op)) {
- is_unsupported_flex_op = true;
+OperatorKey GetOperatorKey(
+ const ::toco::Operator& op,
+ const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
+ bool allow_flex_ops) {
+ string name = HelpfulOperatorTypeName(op);
+ const auto& builtin_ops = GetBuiltinOpsMap();
+
+ bool is_builtin = false;
+ OperatorKey key;
+ if (ops_by_type.count(op.type) != 0) {
+ key.version = ops_by_type.at(op.type)->GetVersion(op);
+ name = ops_by_type.at(op.type)->name();
+ is_builtin = (builtin_ops.count(name) > 0);
+ }
+
+ if (is_builtin) {
+ // For TFLite supported builtin ops, find out its BuiltinOperator enum used
+ // in FlatBuffer.
+ key.type = builtin_ops.at(name);
+ } else {
+ key.type = BuiltinOperator_CUSTOM;
+
+ key.is_custom_op = true;
+ if (op.type == OperatorType::kUnsupported) {
+ const TensorFlowUnsupportedOperator& unsupported_op =
+ static_cast<const TensorFlowUnsupportedOperator&>(op);
+ const auto tensorflow_op = unsupported_op.tensorflow_op;
+
+ // TODO(b/113715895): When `allow_flex_ops` is on, for now there's no way
+ // to populate a regular custom op. We need to find a way to fix this.
+ if (allow_flex_ops) {
+ // Memorize the original TensorFlow op name.
+ key.flex_tensorflow_op = tensorflow_op;
+ // Prefix the custom code of the flex op.
+ key.custom_code =
+ string(::tflite::kFlexCustomCodePrefix) + tensorflow_op;
+ key.is_flex_op = true;
+
+ if (IsControlFlowOp(tensorflow_op)) {
+ key.is_unsupported_flex_op = true;
+ }
+ } else {
+ key.custom_code = tensorflow_op;
}
+ } else {
+ // For Toco-supported/TFLite-unsupported ops, currently we produce a
+ // custom op. This gives developers a chance to implement custom ops.
+ // TODO(b/116800229): Also produce Toco-supported/TFLite-unsupported ops
+ // as Flex ops when Flex mode is enabled.
+ key.custom_code = name;
}
}
+ return key;
}
void LoadTensorsMap(const Model& model, TensorsMap* tensors_map) {
@@ -145,6 +177,7 @@ void LoadOperatorsMap(
++index;
}
}
+
} // namespace details
Offset<Vector<Offset<Tensor>>> ExportTensors(
@@ -230,7 +263,7 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes(
const Model& model,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
const details::OperatorsMap& operators_map, FlatBufferBuilder* builder,
- std::set<string>* unsupported_ops, const ExportParams& params) {
+ const ExportParams& params) {
// Map from operator name to TF Lite enum value, for all builtins.
std::map<string, BuiltinOperator> builtin_ops;
for (int i = BuiltinOperator_MIN; i <= BuiltinOperator_MAX; ++i) {
@@ -247,37 +280,16 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes(
for (const auto& op : model.operators) {
const details::OperatorKey operator_key =
- GetOperatorKey(*op, ops_by_type, params.allow_flex_ops);
+ details::GetOperatorKey(*op, ops_by_type, params.allow_flex_ops);
int op_index = operators_map.at(operator_key);
- int op_version = operator_key.version;
- string name = HelpfulOperatorTypeName(*op);
- bool is_builtin = false;
- if (ops_by_type.count(op->type) != 0) {
- name = ops_by_type.at(op->type)->name();
- is_builtin = (builtin_ops.count(name) > 0);
+ flatbuffers::Offset<flatbuffers::String> custom_code = 0;
+ if (!operator_key.custom_code.empty()) {
+ custom_code = builder->CreateString(operator_key.custom_code);
}
- if (is_builtin) {
- ordered_opcodes[op_index] =
- CreateOperatorCode(*builder, builtin_ops[name], 0, op_version);
- } else {
- // This could be a kUnsupported, in which case we should be
- // able to retrieve the original Tensorflow name from the OperatorKey, or
- // this could be a proper TOCO operator that is completely unknown to TF
- // Lite.
- if (!operator_key.custom_code.empty()) {
- name = operator_key.custom_code;
- }
- // Either way, this is an operator that is not supported by TF Lite,
- // so we output it as a custom op and add it to the error summary.
- if (unsupported_ops) {
- unsupported_ops->insert(name);
- }
- ordered_opcodes[op_index] =
- CreateOperatorCode(*builder, BuiltinOperator_CUSTOM,
- builder->CreateString(name), op_version);
- }
+ ordered_opcodes[op_index] = CreateOperatorCode(
+ *builder, operator_key.type, custom_code, operator_key.version);
}
std::vector<Offset<OperatorCode>> opcode_vector;
@@ -312,7 +324,7 @@ Offset<Vector<Offset<Operator>>> ExportOperators(
}
int op_index = operators_map.at(
- GetOperatorKey(*op, ops_by_type, params.allow_flex_ops));
+ details::GetOperatorKey(*op, ops_by_type, params.allow_flex_ops));
auto tflite_op_it = ops_by_type.find(op->type);
BaseOperator* tflite_op = tflite_op_it == ops_by_type.end()
@@ -386,9 +398,8 @@ void Export(
Array empty_array;
buffers_to_write.push_back(&empty_array);
- std::set<string> unsupported_ops;
- auto op_codes = ExportOperatorCodes(model, ops_by_type, operators_map,
- &builder, &unsupported_ops, params);
+ auto op_codes =
+ ExportOperatorCodes(model, ops_by_type, operators_map, &builder, params);
for (const auto& op : model.operators) {
if (op->type == OperatorType::kFakeQuant) {
@@ -398,7 +409,20 @@ void Export(
"for --std_values and --mean_values.";
}
}
- if (!unsupported_ops.empty()) {
+
+ std::set<string> custom_ops;
+ std::set<string> unsupported_flex_ops;
+ for (const auto& it : operators_map) {
+ const details::OperatorKey& key = it.first;
+ if (key.is_custom_op) {
+ custom_ops.insert(key.custom_code);
+ }
+ if (key.is_unsupported_flex_op) {
+ unsupported_flex_ops.insert(key.flex_tensorflow_op);
+ }
+ }
+
+ if (!custom_ops.empty()) {
if (!params.allow_custom_ops) {
// Remove ExpandDims and ReorderAxes from unimplemented list unless they
// compose the list. Both ops are removed during graph transformations.
@@ -406,14 +430,14 @@ void Export(
// transformation is unable to run because the output shape is not
// defined. This causes unnecessary confusion during model conversion
// time.
- std::set<string> unsupported_ops_final;
- for (const auto& op_type : unsupported_ops) {
+ std::set<string> custom_ops_final;
+ for (const auto& op_type : custom_ops) {
if (op_type != "ReorderAxes" && op_type != "ExpandDims") {
- unsupported_ops_final.insert(op_type);
+ custom_ops_final.insert(op_type);
}
}
- if (unsupported_ops_final.empty()) {
- unsupported_ops_final = unsupported_ops;
+ if (custom_ops_final.empty()) {
+ custom_ops_final = custom_ops;
}
LOG(QFATAL)
@@ -423,13 +447,13 @@ void Export(
"--allow_custom_ops, or by setting allow_custom_ops=True "
"when calling tf.contrib.lite.TFLiteConverter(). Here is a list "
"of operators for which you will need custom implementations: "
- << absl::StrJoin(unsupported_ops_final, ", ") << ".";
+ << absl::StrJoin(custom_ops_final, ", ") << ".";
}
std::set<string> unsupported_control_flow_ops;
// Check if unsupported ops contains control flow ops. It's impossible
// to implement these ops as custom ops at the moment.
- for (const auto& op : unsupported_ops) {
+ for (const auto& op : custom_ops) {
if (IsControlFlowOp(op)) {
unsupported_control_flow_ops.insert(op);
}
@@ -441,14 +465,6 @@ void Export(
}
}
- std::set<string> unsupported_flex_ops;
- for (const auto& it : operators_map) {
- const details::OperatorKey& key = it.first;
- if (key.is_unsupported_flex_op) {
- unsupported_flex_ops.insert(key.custom_code);
- }
- }
-
if (!unsupported_flex_ops.empty()) {
LOG(QFATAL) << "Some of the operators in the model are not supported by "
"TensorFlow Flex runtime: "