diff options
author | Skye Wanderman-Milne <skyewm@google.com> | 2018-09-17 15:50:12 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-17 16:03:03 -0700 |
commit | 8ea4ea95ad1e85948019daee7a9e70e97082f6d0 (patch) | |
tree | 45b9059b5d67abb4f4b12896f87f88f8fa56a85a /tensorflow/core/graph | |
parent | d5f4c3aa59aebc88f42a186a30ef6200857194ca (diff) |
Fix GraphConstructor and import_graph_def bug with variadic ops.
Prior to this change,
GraphConstructor::PopulateMissingUnusedInputMapKey() didn't correctly
compute the number of outputs for ops with variadic outputs. This
meant that missing_unused_input_map_keys could contain spurious
entries for unused variadic outputs, which could trigger a ValueError
in import_graph_def.
This also adds a new util method in node_def_util.h, NumOutputsForNode().
PiperOrigin-RevId: 213353158
Diffstat (limited to 'tensorflow/core/graph')
-rw-r--r-- | tensorflow/core/graph/graph_constructor.cc | 4 | ||||
-rw-r--r-- | tensorflow/core/graph/graph_constructor_test.cc | 9 |
2 files changed, 11 insertions, 2 deletions
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 7399613f6a..eeb5c14eaa 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -1162,7 +1162,9 @@ Status GraphConstructor::PopulateMissingUnusedInputMapKeys() { const NodeDef* node_def = node_defs_[pair->second.gdef_index]; const OpDef* op_def; TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def)); - if (key.second >= op_def->output_arg_size()) { + int num_outputs; + TF_RETURN_IF_ERROR(NumOutputsForNode(*node_def, *op_def, &num_outputs)); + if (key.second >= num_outputs) { // key's index out of bounds missing_unused_input_map_keys_->push_back(key); } diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc index 73142ebde7..3eef6bd2bd 100644 --- a/tensorflow/core/graph/graph_constructor_test.cc +++ b/tensorflow/core/graph/graph_constructor_test.cc @@ -199,6 +199,10 @@ REGISTER_OP("TestOneInputOneOutput") .Output("y: T") .Attr("T: {float, int64}") .SetShapeFn(shape_inference::UnchangedShape); +REGISTER_OP("TestVariadicOutput") + .Output("outputs: N * int32") + .Attr("N: int >= 0") + .SetShapeFn(shape_inference::UnknownShape); REGISTER_OP("TestDefaultAttr") .Attr("default_int: int=31415") .SetShapeFn(shape_inference::NoOutputs); @@ -1463,12 +1467,15 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapMissingUnusedKeys) { opts.input_map[TensorId("DNE", 0)] = TensorId("input", 0); // Unused but not missing opts.input_map[TensorId("t1", 0)] = TensorId("W1", 0); + // Unused but not missing + opts.input_map[TensorId("variadic", 4)] = TensorId("input", 0); ExpectOK( R"EOF( node { name: 'W2' op: 'TestParams' } node { name: 'new_input' op: 'TestInput' input: [ '^W2' ] } node { name: 't1' op: 'TestMul' input: [ 'new_input:0', 'new_input:1' ] } - node { name: 't2' op: 'TestMul' input: [ 't1:0', 't1:0' ] } + node { name: 'variadic' op: 'TestVariadicOutput' + attr { key: "N" value { i: 5 } } } )EOF", opts, &refiner, &results); |