aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2018-09-17 15:50:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-17 16:03:03 -0700
commit8ea4ea95ad1e85948019daee7a9e70e97082f6d0 (patch)
tree45b9059b5d67abb4f4b12896f87f88f8fa56a85a /tensorflow/core/graph
parentd5f4c3aa59aebc88f42a186a30ef6200857194ca (diff)
Fix GraphConstructor and import_graph_def bug with variadic ops.
Prior to this change, GraphConstructor::PopulateMissingUnusedInputMapKey() didn't correctly compute the number of outputs for ops with variadic outputs. This meant that missing_unused_input_map_keys could contain spurious entries for unused variadic outputs, which could trigger a ValueError in import_graph_def. This also adds a new util method in node_def_util.h, NumOutputsForNode(). PiperOrigin-RevId: 213353158
Diffstat (limited to 'tensorflow/core/graph')
-rw-r--r--tensorflow/core/graph/graph_constructor.cc4
-rw-r--r--tensorflow/core/graph/graph_constructor_test.cc9
2 files changed, 11 insertions, 2 deletions
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index 7399613f6a..eeb5c14eaa 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -1162,7 +1162,9 @@ Status GraphConstructor::PopulateMissingUnusedInputMapKeys() {
const NodeDef* node_def = node_defs_[pair->second.gdef_index];
const OpDef* op_def;
TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def));
- if (key.second >= op_def->output_arg_size()) {
+ int num_outputs;
+ TF_RETURN_IF_ERROR(NumOutputsForNode(*node_def, *op_def, &num_outputs));
+ if (key.second >= num_outputs) {
// key's index out of bounds
missing_unused_input_map_keys_->push_back(key);
}
diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc
index 73142ebde7..3eef6bd2bd 100644
--- a/tensorflow/core/graph/graph_constructor_test.cc
+++ b/tensorflow/core/graph/graph_constructor_test.cc
@@ -199,6 +199,10 @@ REGISTER_OP("TestOneInputOneOutput")
.Output("y: T")
.Attr("T: {float, int64}")
.SetShapeFn(shape_inference::UnchangedShape);
+REGISTER_OP("TestVariadicOutput")
+ .Output("outputs: N * int32")
+ .Attr("N: int >= 0")
+ .SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("TestDefaultAttr")
.Attr("default_int: int=31415")
.SetShapeFn(shape_inference::NoOutputs);
@@ -1463,12 +1467,15 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapMissingUnusedKeys) {
opts.input_map[TensorId("DNE", 0)] = TensorId("input", 0);
// Unused but not missing
opts.input_map[TensorId("t1", 0)] = TensorId("W1", 0);
+ // Unused but not missing
+ opts.input_map[TensorId("variadic", 4)] = TensorId("input", 0);
ExpectOK(
R"EOF(
node { name: 'W2' op: 'TestParams' }
node { name: 'new_input' op: 'TestInput' input: [ '^W2' ] }
node { name: 't1' op: 'TestMul' input: [ 'new_input:0', 'new_input:1' ] }
- node { name: 't2' op: 'TestMul' input: [ 't1:0', 't1:0' ] }
+ node { name: 'variadic' op: 'TestVariadicOutput'
+ attr { key: "N" value { i: 5 } } }
)EOF",
opts, &refiner, &results);