aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-25 14:35:40 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-25 14:40:51 -0700
commit038d15d8e2037d4a45e60e076429d67ec7d5ace1 (patch)
tree713a1658c6a4e0137471c897e54cab5e2dce7191 /tensorflow/core/grappler
parentdea456f341acc9b44bfc6e115a95a99c5ebac58b (diff)
Bug fix for OpOutputPortIdToArgId, include type_list_attr.
PiperOrigin-RevId: 214505566
Diffstat (limited to 'tensorflow/core/grappler')
-rw-r--r--tensorflow/core/grappler/graph_view.cc27
-rw-r--r--tensorflow/core/grappler/graph_view_test.cc29
2 files changed, 44 insertions, 12 deletions
diff --git a/tensorflow/core/grappler/graph_view.cc b/tensorflow/core/grappler/graph_view.cc
index b8d8243174..2619a9a8f3 100644
--- a/tensorflow/core/grappler/graph_view.cc
+++ b/tensorflow/core/grappler/graph_view.cc
@@ -29,21 +29,24 @@ int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) {
return output_arg_id;
}
+ // Default is 1 port per output arg.
+ int n = 1;
+
const auto& output_arg = op.output_arg(output_arg_id);
if (!output_arg.number_attr().empty()) {
- const int n = node.attr().at(output_arg.number_attr()).i();
- if (n < 0) {
- // This should never happen.
- DCHECK_GE(n, 0);
- return -1;
- }
- if (port_id < n) {
- return output_arg_id;
- }
- port_id -= n;
- } else {
- --port_id;
+ n = node.attr().at(output_arg.number_attr()).i();
+ } else if (!output_arg.type_list_attr().empty()) {
+ n = node.attr().at(output_arg.type_list_attr()).list().type_size();
+ }
+
+ if (n < 0) {
+ // This should never happen.
+ DCHECK_GE(n, 0);
+ return -1;
+ } else if (port_id < n) {
+ return output_arg_id;
}
+ port_id -= n;
}
return -1;
diff --git a/tensorflow/core/grappler/graph_view_test.cc b/tensorflow/core/grappler/graph_view_test.cc
index 30512d9d47..3d7d2faf7c 100644
--- a/tensorflow/core/grappler/graph_view_test.cc
+++ b/tensorflow/core/grappler/graph_view_test.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/cc/ops/parsing_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
@@ -79,6 +80,34 @@ TEST_F(GraphViewTest, OpOutputPortIdToArgIdSparseSplit) {
}
}
+TEST_F(GraphViewTest, ParseSingleExample) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const<string>(s.WithOpName("a"), "", {});
+ Output b = ops::Const<int64>(s.WithOpName("b"), 1, {1, 1});
+ ops::ParseSingleExample c(s.WithOpName("c"), a, {b, b}, 2, {"w", "x"},
+ {"y", "z"}, {DT_INT64, DT_INT64}, {{1}, {1}});
+
+ GraphDef graph_def;
+ TF_CHECK_OK(s.ToGraphDef(&graph_def));
+ GraphView graph_view(&graph_def);
+
+ const NodeDef& c_node_def = *graph_view.GetNode("c");
+
+ const OpDef* c_op_def = nullptr;
+ EXPECT_TRUE(
+ OpRegistry::Global()->LookUpOpDef(c_node_def.op(), &c_op_def).ok());
+
+ EXPECT_EQ(0, OpOutputPortIdToArgId(c_node_def, *c_op_def, 0));
+ EXPECT_EQ(0, OpOutputPortIdToArgId(c_node_def, *c_op_def, 1));
+ EXPECT_EQ(1, OpOutputPortIdToArgId(c_node_def, *c_op_def, 2));
+ EXPECT_EQ(1, OpOutputPortIdToArgId(c_node_def, *c_op_def, 3));
+ EXPECT_EQ(2, OpOutputPortIdToArgId(c_node_def, *c_op_def, 4));
+ EXPECT_EQ(2, OpOutputPortIdToArgId(c_node_def, *c_op_def, 5));
+ EXPECT_EQ(3, OpOutputPortIdToArgId(c_node_def, *c_op_def, 6));
+ EXPECT_EQ(3, OpOutputPortIdToArgId(c_node_def, *c_op_def, 7));
+ EXPECT_EQ(-1, OpOutputPortIdToArgId(c_node_def, *c_op_def, 8));
+}
+
TEST_F(GraphViewTest, BasicGraph) {
TrivialTestGraphInputYielder fake_input(4, 2, 2, false, {"/CPU:0", "/GPU:0"});
GrapplerItem item;