aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-10 08:01:45 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-10 08:06:09 -0700
commit0bb68afa38cf5c45232e85fb09186e01055e4d11 (patch)
treefebe084b9d02491a2401b65ea067466fcabfbd24
parent93226f635c5c108b3b501d8bbcf27e64dec49fb9 (diff)
Fix number of outputs when importing tensorflow GraphDef.
Sometimes the actual number of outputs is dictated by one of the attributes of the NodeDef. PiperOrigin-RevId: 216530696
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc22
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow_test.cc31
2 files changed, 49 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 6b195cc992..ff67b306e0 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -1122,13 +1122,27 @@ tensorflow::Status ConvertUnsupportedOperator(
op->inputs.push_back(node.input(i));
}
- // Parse outputs.
- op->outputs.push_back(node.name()); // Implicit :0.
+ // Parse outputs. Name them after the node's name, plus an ordinal suffix.
+ // Note that some outputs are to be multipled by a named attribute.
const tensorflow::OpDef* op_def = nullptr;
if (tensorflow::OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
- for (int i = 1; i < op_def->output_arg_size(); ++i) {
- op->outputs.push_back(absl::StrCat(node.name(), ":", i));
+ int next_output = 0;
+ for (int i = 0; i < op_def->output_arg_size(); ++i) {
+ string multiples = op_def->output_arg(i).number_attr();
+ int num_outputs = multiples.empty() ? 1 : GetIntAttr(node, multiples);
+ LOG(INFO) << "dddddddd " << num_outputs;
+ for (int j = 0; j < num_outputs; ++j) {
+ if (next_output == 0) {
+ op->outputs.push_back(node.name()); // Implicit :0.
+ } else {
+ op->outputs.push_back(absl::StrCat(node.name(), ":", next_output));
+ }
+ ++next_output;
+ }
}
+ } else {
+ LOG(INFO) << "nodef!!!!!!!!!!! ";
+ op->outputs.push_back(node.name()); // Implicit :0.
}
// Parse if the op supports quantization
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
index cd9a144b52..0767221b83 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
@@ -55,6 +55,13 @@ Status ImportNode(const NodeDef& node, Model* model) {
converter);
}
+Status ImportFlexNode(const NodeDef& node, Model* model) {
+ // Empty converter => all nodes are flex nodes.
+ const auto converter = internal::ConverterMapType();
+ return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), model,
+ converter);
+}
+
Status ImportNode(const NodeDef& node) {
Model model;
return ImportNode(node, &model);
@@ -299,5 +306,29 @@ TEST(ImportTest, UnsupportedOpWithWildcardOutputShapes) {
ASSERT_TRUE(op->output_shapes.empty());
}
+TEST(ImportTest, UnsupportedOpWithMultipleOutputs) {
+ NodeDef node = BuildNode("Unpack", {});
+
+ // Unpack's OpDef has a single output which gets multiplied based on the
+ // "num" attribute of the NodeDef.
+ AttrValue value_attr;
+ SetAttrValue(3, &value_attr); // 3 outputs.
+ (*node.mutable_attr())["num"] = value_attr;
+
+ Model model;
+ EXPECT_TRUE(ImportFlexNode(node, &model).ok());
+
+ ASSERT_THAT(model.operators.size(), ::testing::Ge(1));
+ ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported);
+ const TensorFlowUnsupportedOperator* op =
+ static_cast<const TensorFlowUnsupportedOperator*>(
+ model.operators[0].get());
+
+ ASSERT_EQ(op->outputs.size(), 3);
+ ASSERT_EQ(op->outputs[0], "Node1");
+ ASSERT_EQ(op->outputs[1], "Node1:1");
+ ASSERT_EQ(op->outputs[2], "Node1:2");
+}
+
} // namespace
} // namespace toco