diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-25 14:35:40 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-25 14:40:51 -0700 |
commit | 038d15d8e2037d4a45e60e076429d67ec7d5ace1 (patch) | |
tree | 713a1658c6a4e0137471c897e54cab5e2dce7191 /tensorflow/core/grappler | |
parent | dea456f341acc9b44bfc6e115a95a99c5ebac58b (diff) |
Bug fix for OpOutputPortIdToArgId, include type_list_attr.
PiperOrigin-RevId: 214505566
Diffstat (limited to 'tensorflow/core/grappler')
-rw-r--r-- | tensorflow/core/grappler/graph_view.cc | 27 | ||||
-rw-r--r-- | tensorflow/core/grappler/graph_view_test.cc | 29 |
2 files changed, 44 insertions, 12 deletions
diff --git a/tensorflow/core/grappler/graph_view.cc b/tensorflow/core/grappler/graph_view.cc index b8d8243174..2619a9a8f3 100644 --- a/tensorflow/core/grappler/graph_view.cc +++ b/tensorflow/core/grappler/graph_view.cc @@ -29,21 +29,24 @@ int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) { return output_arg_id; } + // Default is 1 port per output arg. + int n = 1; + const auto& output_arg = op.output_arg(output_arg_id); if (!output_arg.number_attr().empty()) { - const int n = node.attr().at(output_arg.number_attr()).i(); - if (n < 0) { - // This should never happen. - DCHECK_GE(n, 0); - return -1; - } - if (port_id < n) { - return output_arg_id; - } - port_id -= n; - } else { - --port_id; + n = node.attr().at(output_arg.number_attr()).i(); + } else if (!output_arg.type_list_attr().empty()) { + n = node.attr().at(output_arg.type_list_attr()).list().type_size(); + } + + if (n < 0) { + // This should never happen. + DCHECK_GE(n, 0); + return -1; + } else if (port_id < n) { + return output_arg_id; } + port_id -= n; } return -1; diff --git a/tensorflow/core/grappler/graph_view_test.cc b/tensorflow/core/grappler/graph_view_test.cc index 30512d9d47..3d7d2faf7c 100644 --- a/tensorflow/core/grappler/graph_view_test.cc +++ b/tensorflow/core/grappler/graph_view_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/graph_view.h" +#include "tensorflow/cc/ops/parsing_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" @@ -79,6 +80,34 @@ TEST_F(GraphViewTest, OpOutputPortIdToArgIdSparseSplit) { } } +TEST_F(GraphViewTest, ParseSingleExample) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output a = ops::Const<string>(s.WithOpName("a"), "", {}); + Output b = ops::Const<int64>(s.WithOpName("b"), 1, {1, 1}); + ops::ParseSingleExample c(s.WithOpName("c"), a, {b, b}, 2, {"w", "x"}, + {"y", "z"}, {DT_INT64, DT_INT64}, {{1}, {1}}); + + GraphDef graph_def; + TF_CHECK_OK(s.ToGraphDef(&graph_def)); + GraphView graph_view(&graph_def); + + const NodeDef& c_node_def = *graph_view.GetNode("c"); + + const OpDef* c_op_def = nullptr; + EXPECT_TRUE( + OpRegistry::Global()->LookUpOpDef(c_node_def.op(), &c_op_def).ok()); + + EXPECT_EQ(0, OpOutputPortIdToArgId(c_node_def, *c_op_def, 0)); + EXPECT_EQ(0, OpOutputPortIdToArgId(c_node_def, *c_op_def, 1)); + EXPECT_EQ(1, OpOutputPortIdToArgId(c_node_def, *c_op_def, 2)); + EXPECT_EQ(1, OpOutputPortIdToArgId(c_node_def, *c_op_def, 3)); + EXPECT_EQ(2, OpOutputPortIdToArgId(c_node_def, *c_op_def, 4)); + EXPECT_EQ(2, OpOutputPortIdToArgId(c_node_def, *c_op_def, 5)); + EXPECT_EQ(3, OpOutputPortIdToArgId(c_node_def, *c_op_def, 6)); + EXPECT_EQ(3, OpOutputPortIdToArgId(c_node_def, *c_op_def, 7)); + EXPECT_EQ(-1, OpOutputPortIdToArgId(c_node_def, *c_op_def, 8)); +} + TEST_F(GraphViewTest, BasicGraph) { TrivialTestGraphInputYielder fake_input(4, 2, 2, false, {"/CPU:0", "/GPU:0"}); GrapplerItem item; |