author    Yu-Cheng Ling <ycling@google.com>  2018-10-03 17:25:46 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-10-03 17:29:59 -0700
commit    4da5b350e1c062b9d55896ee872e0e4790f30bcb (patch)
tree      4057f0340517f75fead4053093d5b10aaacac9ba /tensorflow/contrib
parent    c842d38978a0babb373fe2acbb0231960aa1c1d0 (diff)
TFLite Flex: Blacklist Control Flow Ops
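
Control flow ops (Switch/RefSwitch, Merge/RefMerge, Enter/RefEnter,
Exit/RefExit, NextIteration/RefNextIteration) cannot be implemented as
single custom or Flex ops, so the exporter now detects them and fails fast
with LOG(QFATAL) instead of emitting a model that cannot run.

A minimal sketch of the new `details::OperatorKey` behavior (this mirrors
the unit tests added below; it assumes the namespaces declared in
toco/tflite/export.h):

  #include "tensorflow/contrib/lite/toco/tflite/export.h"

  // With allow_flex_ops=true, an unsupported TF op becomes a Flex custom
  // op: the custom code is prefixed with "Flex" and the original op name
  // is recorded in `flex_tensorflow_op`.
  toco::tflite::details::OperatorKey key(
      toco::OperatorType::kUnsupported, "Merge",
      /*version=*/1, /*allow_flex_ops=*/true);
  // key.custom_code            == "FlexMerge"
  // key.is_flex_op             == true
  // key.is_unsupported_flex_op == true  (Merge is a control flow op)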
PiperOrigin-RevId: 215658384
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r-- tensorflow/contrib/lite/toco/tflite/export.cc      | 132
-rw-r--r-- tensorflow/contrib/lite/toco/tflite/export.h       |  20
-rw-r--r-- tensorflow/contrib/lite/toco/tflite/export_test.cc |  40
3 files changed, 152 insertions, 40 deletions
diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc
index 0c9fac249c..45ca7f7f0c 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export.cc
@@ -47,6 +47,22 @@ using ::tflite::Tensor;
namespace {
+// Check if a TensorFlow Op is a control flow op by its name.
+bool IsControlFlowOp(const string& tensorflow_op) {
+ // Technically this is equivalent to `::tensorflow::Node::IsControlFlow()`.
+ // However, using that helper requires constructing a `::tensorflow::Graph`,
+ // so we simply hardcode the list of control flow ops here.
+ if (tensorflow_op == "Switch" || tensorflow_op == "RefSwitch" ||
+ tensorflow_op == "Merge" || tensorflow_op == "RefMerge" ||
+ tensorflow_op == "Enter" || tensorflow_op == "RefEnter" ||
+ tensorflow_op == "Exit" || tensorflow_op == "RefExit" ||
+ tensorflow_op == "NextIteration" || tensorflow_op == "RefNextIteration") {
+ return true;
+ }
+ // TODO(ycling): Also check how to handle Variable ops and Assign ops.
+ return false;
+}
+
details::OperatorKey GetOperatorKey(
const ::toco::Operator& op,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
@@ -55,21 +71,13 @@ details::OperatorKey GetOperatorKey(
if (op.type == OperatorType::kUnsupported) {
const TensorFlowUnsupportedOperator& unsupported_op =
static_cast<const TensorFlowUnsupportedOperator&>(op);
-
- // TODO(b/113715895): When `allow_flex_ops` is on, for now there's no way
- // to populate a regular custom op. We need to find a way to fix this.
- if (allow_flex_ops) {
- custom_code = string(::tflite::kFlexCustomCodePrefix) +
- unsupported_op.tensorflow_op;
- } else {
- custom_code = unsupported_op.tensorflow_op;
- }
+ custom_code = unsupported_op.tensorflow_op;
}
int version = 1;
if (ops_by_type.count(op.type) != 0) {
version = ops_by_type.at(op.type)->GetVersion(op);
}
- return details::OperatorKey(op.type, custom_code, version);
+ return details::OperatorKey(op.type, custom_code, version, allow_flex_ops);
}
void WriteModelToString(const flatbuffers::FlatBufferBuilder& builder,
@@ -83,6 +91,29 @@ void WriteModelToString(const flatbuffers::FlatBufferBuilder& builder,
namespace details {
+OperatorKey::OperatorKey(OperatorType type, const std::string& custom_code,
+ int version, bool allow_flex_ops) {
+ this->type = type;
+ this->custom_code = custom_code;
+ this->version = version;
+
+ if (type == OperatorType::kUnsupported) {
+ // TODO(b/113715895): When `allow_flex_ops` is on, there is currently no
+ // way to populate a regular custom op. We need to find a way to fix this.
+ if (allow_flex_ops) {
+ // Record the original TensorFlow op name.
+ this->flex_tensorflow_op = custom_code;
+ // Prefix the custom code of the flex op.
+ this->custom_code = string(::tflite::kFlexCustomCodePrefix) + custom_code;
+ this->is_flex_op = true;
+
+ if (IsControlFlowOp(this->flex_tensorflow_op)) {
+ is_unsupported_flex_op = true;
+ }
+ }
+ }
+}
+
void LoadTensorsMap(const Model& model, TensorsMap* tensors_map) {
// First find a list of unique array names.
std::set<string> names;
@@ -199,7 +230,7 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes(
const Model& model,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
const details::OperatorsMap& operators_map, FlatBufferBuilder* builder,
- std::set<string>* error_summary, const ExportParams& params) {
+ std::set<string>* unsupported_ops, const ExportParams& params) {
// Map from operator name to TF Lite enum value, for all builtins.
std::map<string, BuiltinOperator> builtin_ops;
for (int i = BuiltinOperator_MIN; i <= BuiltinOperator_MAX; ++i) {
@@ -240,8 +271,8 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes(
}
// Either way, this is an operator that is not supported by TF Lite,
// so we output it as a custom op and add it to the error summary.
- if (error_summary) {
- error_summary->insert(name);
+ if (unsupported_ops) {
+ unsupported_ops->insert(name);
}
ordered_opcodes[op_index] =
CreateOperatorCode(*builder, BuiltinOperator_CUSTOM,
@@ -355,9 +386,9 @@ void Export(
Array empty_array;
buffers_to_write.push_back(&empty_array);
- std::set<string> error_summary;
+ std::set<string> unsupported_ops;
auto op_codes = ExportOperatorCodes(model, ops_by_type, operators_map,
- &builder, &error_summary, params);
+ &builder, &unsupported_ops, params);
for (const auto& op : model.operators) {
if (op->type == OperatorType::kFakeQuant) {
@@ -367,30 +398,61 @@ void Export(
"for --std_values and --mean_values.";
}
}
- if (!params.allow_custom_ops && !error_summary.empty()) {
- // Remove ExpandDims and ReorderAxes from unimplemented list unless they
- // compose the list. Both ops are removed during graph transformations.
- // However, if an op is unimplemented earlier in the model, the graph
- // transformation is unable to run because the output shape is not defined.
- // This causes unnecessary confusion during model conversion time.
- std::set<string> error_summary_final;
- for (const auto& op_type : error_summary) {
- if (op_type != "ReorderAxes" && op_type != "ExpandDims") {
- error_summary_final.insert(op_type);
+ if (!unsupported_ops.empty()) {
+ if (!params.allow_custom_ops) {
+ // Remove ExpandDims and ReorderAxes from the unimplemented list unless
+ // they make up the entire list. Both ops are removed during graph
+ // transformations. However, if an op is unimplemented earlier in the
+ // model, the graph transformation is unable to run because the output
+ // shape is not defined. This causes unnecessary confusion at model
+ // conversion time.
+ std::set<string> unsupported_ops_final;
+ for (const auto& op_type : unsupported_ops) {
+ if (op_type != "ReorderAxes" && op_type != "ExpandDims") {
+ unsupported_ops_final.insert(op_type);
+ }
+ }
+ if (unsupported_ops_final.empty()) {
+ unsupported_ops_final = unsupported_ops;
+ }
+
+ LOG(QFATAL)
+ << "Some of the operators in the model are not supported by "
+ "the standard TensorFlow Lite runtime. If you have a custom "
+ "implementation for them you can disable this error with "
+ "--allow_custom_ops, or by setting allow_custom_ops=True "
+ "when calling tf.contrib.lite.TFLiteConverter(). Here is a list "
+ "of operators for which you will need custom implementations: "
+ << absl::StrJoin(unsupported_ops_final, ", ") << ".";
+ }
+
+ std::set<string> unsupported_control_flow_ops;
+ // Check if the unsupported ops contain control flow ops. It's impossible
+ // to implement these ops as custom ops at the moment.
+ for (const auto& op : unsupported_ops) {
+ if (IsControlFlowOp(op)) {
+ unsupported_control_flow_ops.insert(op);
}
}
- if (error_summary_final.empty()) {
- error_summary_final = error_summary;
+ if (!unsupported_control_flow_ops.empty()) {
+ LOG(QFATAL)
+ << "TensorFlow Lite currently doesn't support control flow ops: "
+ << absl::StrJoin(unsupported_control_flow_ops, ", ") << ".";
}
+ }
+
+ std::set<string> unsupported_flex_ops;
+ for (const auto& it : operators_map) {
+ const details::OperatorKey& key = it.first;
+ if (key.is_unsupported_flex_op) {
+ unsupported_flex_ops.insert(key.custom_code);
+ }
+ }
- LOG(QFATAL)
- << "Some of the operators in the model are not supported by "
- "the standard TensorFlow Lite runtime. If you have a custom "
- "implementation for them you can disable this error with "
- "--allow_custom_ops, or by setting allow_custom_ops=True "
- "when calling tf.contrib.lite.TFLiteConverter(). Here is a list "
- "of operators for which you will need custom implementations: "
- << absl::StrJoin(error_summary_final, ", ") << ".";
+ if (!unsupported_flex_ops.empty()) {
+ LOG(QFATAL) << "Some of the operators in the model are not supported by "
+ "the TensorFlow Flex runtime: "
+ << absl::StrJoin(unsupported_flex_ops, ", ") << ".";
}
std::set<int32_t> variable_tensor_indices;
diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/contrib/lite/toco/tflite/export.h
index 29d6de4049..9efb282c6c 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.h
+++ b/tensorflow/contrib/lite/toco/tflite/export.h
@@ -81,11 +81,21 @@ using TensorsMap = std::unordered_map<string, int>;
// Only when `type` is `kUnsupported`, `custom_code` is filled to
// identify which operation is used.
struct OperatorKey {
- OperatorKey(OperatorType type, const std::string& custom_code, int version)
- : type(type), custom_code(custom_code), version(version) {}
- const OperatorType type;
- const std::string custom_code;
- const int version;
+ OperatorKey(OperatorType type, const std::string& custom_code, int version,
+ bool allow_flex_ops = false);
+
+ // Only `type`, `custom_code` and `version` are used to compute hash and
+ // identity.
+ OperatorType type;
+ std::string custom_code;
+ int version;
+
+ // The fields below are not used to compute hash and identity.
+ bool is_flex_op = false;
+ bool is_unsupported_flex_op = false;
+ // The original TensorFlow op name for the flex op. Filled only when
+ // `is_flex_op` is true.
+ std::string flex_tensorflow_op;
bool operator<(const OperatorKey& other) const {
if (type < other.type) return true;
diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/contrib/lite/toco/tflite/export_test.cc
index 93882a91a7..a71a64d56f 100644
--- a/tensorflow/contrib/lite/toco/tflite/export_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export_test.cc
@@ -313,6 +313,46 @@ TEST_F(VersionedOpExportTest, Export) {
EXPECT_EQ(1, (*operators)[1]->opcode_index());
}
+TEST(OperatorKeyTest, TestBuiltinOp) {
+ details::OperatorKey key(OperatorType::kConv, "", 2);
+ EXPECT_EQ(key.type, OperatorType::kConv);
+ EXPECT_EQ(key.custom_code, "");
+ EXPECT_EQ(key.version, 2);
+}
+
+TEST(OperatorKeyTest, TestFlexOp) {
+ {
+ details::OperatorKey key(OperatorType::kUnsupported, "SomeUnsupportedOp", 1,
+ false);
+ EXPECT_EQ(key.type, OperatorType::kUnsupported);
+ // It shouldn't be converted to a Flex op if `allow_flex_ops` is false.
+ EXPECT_EQ(key.custom_code, "SomeUnsupportedOp");
+ EXPECT_EQ(key.version, 1);
+ EXPECT_FALSE(key.is_flex_op);
+ }
+
+ {
+ details::OperatorKey key(OperatorType::kUnsupported, "SomeUnsupportedOp", 1,
+ true);
+ EXPECT_EQ(key.type, OperatorType::kUnsupported);
+ // Verify that the custom op name is prefixed by "Flex" and `is_flex_op`
+ // is true.
+ EXPECT_EQ(key.custom_code, "FlexSomeUnsupportedOp");
+ EXPECT_EQ(key.version, 1);
+ EXPECT_TRUE(key.is_flex_op);
+ }
+}
+
+TEST(OperatorKeyTest, TestFlexWithControlFlowOp) {
+ details::OperatorKey key(OperatorType::kUnsupported, "Merge", 1, true);
+ EXPECT_EQ(key.type, OperatorType::kUnsupported);
+ EXPECT_EQ(key.custom_code, "FlexMerge");
+ EXPECT_EQ(key.version, 1);
+ EXPECT_TRUE(key.is_flex_op);
+ // The control flow ops should be marked as unsupported.
+ EXPECT_TRUE(key.is_unsupported_flex_op);
+}
+
// TODO(ahentz): tests for tensors, inputs, outputs, opcodes and operators.
} // namespace