From 0bb68afa38cf5c45232e85fb09186e01055e4d11 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 10 Oct 2018 08:01:45 -0700 Subject: Fix number of outputs when importing tensorflow GraphDef. Sometimes the actual number of outputs is dictated by one of the attributes of the NodeDef. PiperOrigin-RevId: 216530696 --- .../contrib/lite/toco/import_tensorflow_test.cc | 31 ++++++++++++++++++++++ 1 file changed, 31 insertions(+) (limited to 'tensorflow/contrib/lite/toco/import_tensorflow_test.cc') diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc index cd9a144b52..0767221b83 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc @@ -55,6 +55,13 @@ Status ImportNode(const NodeDef& node, Model* model) { converter); } +Status ImportFlexNode(const NodeDef& node, Model* model) { + // Empty converter => all nodes are flex nodes. + const auto converter = internal::ConverterMapType(); + return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), model, + converter); +} + Status ImportNode(const NodeDef& node) { Model model; return ImportNode(node, &model); @@ -299,5 +306,29 @@ TEST(ImportTest, UnsupportedOpWithWildcardOutputShapes) { ASSERT_TRUE(op->output_shapes.empty()); } +TEST(ImportTest, UnsupportedOpWithMultipleOutputs) { + NodeDef node = BuildNode("Unpack", {}); + + // Unpack's OpDef has a single output which gets multiplied based on the + // "num" attribute of the NodeDef. + AttrValue value_attr; + SetAttrValue(3, &value_attr); // 3 outputs. + (*node.mutable_attr())["num"] = value_attr; + + Model model; + EXPECT_TRUE(ImportFlexNode(node, &model).ok()); + + ASSERT_THAT(model.operators.size(), ::testing::Ge(1)); + ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported); + const TensorFlowUnsupportedOperator* op = + static_cast( + model.operators[0].get()); + + ASSERT_EQ(op->outputs.size(), 3); + ASSERT_EQ(op->outputs[0], "Node1"); + ASSERT_EQ(op->outputs[1], "Node1:1"); + ASSERT_EQ(op->outputs[2], "Node1:2"); +} + } // namespace } // namespace toco -- cgit v1.2.3