aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc22
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow_test.cc31
2 files changed, 49 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 6b195cc992..ff67b306e0 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -1122,13 +1122,27 @@ tensorflow::Status ConvertUnsupportedOperator(
op->inputs.push_back(node.input(i));
}
- // Parse outputs.
- op->outputs.push_back(node.name()); // Implicit :0.
+ // Parse outputs. Name them after the node's name, plus an ordinal suffix.
+ // Note that some outputs are to be multipled by a named attribute.
const tensorflow::OpDef* op_def = nullptr;
if (tensorflow::OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
- for (int i = 1; i < op_def->output_arg_size(); ++i) {
- op->outputs.push_back(absl::StrCat(node.name(), ":", i));
+ int next_output = 0;
+ for (int i = 0; i < op_def->output_arg_size(); ++i) {
+ string multiples = op_def->output_arg(i).number_attr();
+ int num_outputs = multiples.empty() ? 1 : GetIntAttr(node, multiples);
+ LOG(INFO) << "dddddddd " << num_outputs;
+ for (int j = 0; j < num_outputs; ++j) {
+ if (next_output == 0) {
+ op->outputs.push_back(node.name()); // Implicit :0.
+ } else {
+ op->outputs.push_back(absl::StrCat(node.name(), ":", next_output));
+ }
+ ++next_output;
+ }
}
+ } else {
+ LOG(INFO) << "nodef!!!!!!!!!!! ";
+ op->outputs.push_back(node.name()); // Implicit :0.
}
// Parse if the op supports quantization
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
index cd9a144b52..0767221b83 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
@@ -55,6 +55,13 @@ Status ImportNode(const NodeDef& node, Model* model) {
converter);
}
+Status ImportFlexNode(const NodeDef& node, Model* model) {
+ // Empty converter => all nodes are flex nodes.
+ const auto converter = internal::ConverterMapType();
+ return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), model,
+ converter);
+}
+
Status ImportNode(const NodeDef& node) {
Model model;
return ImportNode(node, &model);
@@ -299,5 +306,29 @@ TEST(ImportTest, UnsupportedOpWithWildcardOutputShapes) {
ASSERT_TRUE(op->output_shapes.empty());
}
+TEST(ImportTest, UnsupportedOpWithMultipleOutputs) {
+ NodeDef node = BuildNode("Unpack", {});
+
+ // Unpack's OpDef has a single output which gets multiplied based on the
+ // "num" attribute of the NodeDef.
+ AttrValue value_attr;
+ SetAttrValue(3, &value_attr); // 3 outputs.
+ (*node.mutable_attr())["num"] = value_attr;
+
+ Model model;
+ EXPECT_TRUE(ImportFlexNode(node, &model).ok());
+
+ ASSERT_THAT(model.operators.size(), ::testing::Ge(1));
+ ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported);
+ const TensorFlowUnsupportedOperator* op =
+ static_cast<const TensorFlowUnsupportedOperator*>(
+ model.operators[0].get());
+
+ ASSERT_EQ(op->outputs.size(), 3);
+ ASSERT_EQ(op->outputs[0], "Node1");
+ ASSERT_EQ(op->outputs[1], "Node1:1");
+ ASSERT_EQ(op->outputs[2], "Node1:2");
+}
+
} // namespace
} // namespace toco