aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow_test.cc')
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow_test.cc31
1 files changed, 31 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
index cd9a144b52..0767221b83 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
@@ -55,6 +55,13 @@ Status ImportNode(const NodeDef& node, Model* model) {
converter);
}
+Status ImportFlexNode(const NodeDef& node, Model* model) {
+ // Empty converter => all nodes are flex nodes.
+ const auto converter = internal::ConverterMapType();
+ return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), model,
+ converter);
+}
+
Status ImportNode(const NodeDef& node) {
Model model;
return ImportNode(node, &model);
@@ -299,5 +306,29 @@ TEST(ImportTest, UnsupportedOpWithWildcardOutputShapes) {
ASSERT_TRUE(op->output_shapes.empty());
}
+TEST(ImportTest, UnsupportedOpWithMultipleOutputs) {
+ NodeDef node = BuildNode("Unpack", {});
+
+ // Unpack's OpDef has a single output which gets multiplied based on the
+ // "num" attribute of the NodeDef.
+ AttrValue value_attr;
+ SetAttrValue(3, &value_attr); // 3 outputs.
+ (*node.mutable_attr())["num"] = value_attr;
+
+ Model model;
+ EXPECT_TRUE(ImportFlexNode(node, &model).ok());
+
+ ASSERT_THAT(model.operators.size(), ::testing::Ge(1));
+ ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported);
+ const TensorFlowUnsupportedOperator* op =
+ static_cast<const TensorFlowUnsupportedOperator*>(
+ model.operators[0].get());
+
+ ASSERT_EQ(op->outputs.size(), 3);
+ ASSERT_EQ(op->outputs[0], "Node1");
+ ASSERT_EQ(op->outputs[1], "Node1:1");
+ ASSERT_EQ(op->outputs[2], "Node1:2");
+}
+
} // namespace
} // namespace toco