From 0bb68afa38cf5c45232e85fb09186e01055e4d11 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 10 Oct 2018 08:01:45 -0700 Subject: Fix number of outputs when importing tensorflow GraphDef. Sometimes the actual number of outputs is dictated by one of the attributes of the NodeDef. PiperOrigin-RevId: 216530696 --- tensorflow/contrib/lite/toco/import_tensorflow.cc | 22 ++++++++++++--- .../contrib/lite/toco/import_tensorflow_test.cc | 31 ++++++++++++++++++++++ 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 6b195cc992..ff67b306e0 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -1122,13 +1122,27 @@ tensorflow::Status ConvertUnsupportedOperator( op->inputs.push_back(node.input(i)); } - // Parse outputs. - op->outputs.push_back(node.name()); // Implicit :0. + // Parse outputs. Name them after the node's name, plus an ordinal suffix. + // Note that some outputs are to be multipled by a named attribute. const tensorflow::OpDef* op_def = nullptr; if (tensorflow::OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) { - for (int i = 1; i < op_def->output_arg_size(); ++i) { - op->outputs.push_back(absl::StrCat(node.name(), ":", i)); + int next_output = 0; + for (int i = 0; i < op_def->output_arg_size(); ++i) { + string multiples = op_def->output_arg(i).number_attr(); + int num_outputs = multiples.empty() ? 1 : GetIntAttr(node, multiples); + LOG(INFO) << "dddddddd " << num_outputs; + for (int j = 0; j < num_outputs; ++j) { + if (next_output == 0) { + op->outputs.push_back(node.name()); // Implicit :0. + } else { + op->outputs.push_back(absl::StrCat(node.name(), ":", next_output)); + } + ++next_output; + } } + } else { + LOG(INFO) << "nodef!!!!!!!!!!! "; + op->outputs.push_back(node.name()); // Implicit :0. } // Parse if the op supports quantization diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc index cd9a144b52..0767221b83 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc @@ -55,6 +55,13 @@ Status ImportNode(const NodeDef& node, Model* model) { converter); } +Status ImportFlexNode(const NodeDef& node, Model* model) { + // Empty converter => all nodes are flex nodes. + const auto converter = internal::ConverterMapType(); + return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), model, + converter); +} + Status ImportNode(const NodeDef& node) { Model model; return ImportNode(node, &model); @@ -299,5 +306,29 @@ TEST(ImportTest, UnsupportedOpWithWildcardOutputShapes) { ASSERT_TRUE(op->output_shapes.empty()); } +TEST(ImportTest, UnsupportedOpWithMultipleOutputs) { + NodeDef node = BuildNode("Unpack", {}); + + // Unpack's OpDef has a single output which gets multiplied based on the + // "num" attribute of the NodeDef. + AttrValue value_attr; + SetAttrValue(3, &value_attr); // 3 outputs. + (*node.mutable_attr())["num"] = value_attr; + + Model model; + EXPECT_TRUE(ImportFlexNode(node, &model).ok()); + + ASSERT_THAT(model.operators.size(), ::testing::Ge(1)); + ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported); + const TensorFlowUnsupportedOperator* op = + static_cast( + model.operators[0].get()); + + ASSERT_EQ(op->outputs.size(), 3); + ASSERT_EQ(op->outputs[0], "Node1"); + ASSERT_EQ(op->outputs[1], "Node1:1"); + ASSERT_EQ(op->outputs[2], "Node1:2"); +} + } // namespace } // namespace toco -- cgit v1.2.3