diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-10-10 08:01:45 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-10 08:06:09 -0700 |
commit | 0bb68afa38cf5c45232e85fb09186e01055e4d11 (patch) | |
tree | febe084b9d02491a2401b65ea067466fcabfbd24 /tensorflow/contrib/lite/toco/import_tensorflow_test.cc | |
parent | 93226f635c5c108b3b501d8bbcf27e64dec49fb9 (diff) |
Fix number of outputs when importing tensorflow GraphDef.
Sometimes the actual number of outputs is dictated by one of the attributes of the NodeDef.
PiperOrigin-RevId: 216530696
Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/import_tensorflow_test.cc | 31 |
1 files changed, 31 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc index cd9a144b52..0767221b83 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc @@ -55,6 +55,13 @@ Status ImportNode(const NodeDef& node, Model* model) { converter); } +Status ImportFlexNode(const NodeDef& node, Model* model) { + // Empty converter => all nodes are flex nodes. + const auto converter = internal::ConverterMapType(); + return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), model, + converter); +} + Status ImportNode(const NodeDef& node) { Model model; return ImportNode(node, &model); @@ -299,5 +306,29 @@ TEST(ImportTest, UnsupportedOpWithWildcardOutputShapes) { ASSERT_TRUE(op->output_shapes.empty()); } +TEST(ImportTest, UnsupportedOpWithMultipleOutputs) { + NodeDef node = BuildNode("Unpack", {}); + + // Unpack's OpDef has a single output which gets multiplied based on the + // "num" attribute of the NodeDef. + AttrValue value_attr; + SetAttrValue(3, &value_attr); // 3 outputs. + (*node.mutable_attr())["num"] = value_attr; + + Model model; + EXPECT_TRUE(ImportFlexNode(node, &model).ok()); + + ASSERT_THAT(model.operators.size(), ::testing::Ge(1)); + ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported); + const TensorFlowUnsupportedOperator* op = + static_cast<const TensorFlowUnsupportedOperator*>( + model.operators[0].get()); + + ASSERT_EQ(op->outputs.size(), 3); + ASSERT_EQ(op->outputs[0], "Node1"); + ASSERT_EQ(op->outputs[1], "Node1:1"); + ASSERT_EQ(op->outputs[2], "Node1:2"); +} + } // namespace } // namespace toco |