diff options
author | Yao Zhang <yaozhang@google.com> | 2017-06-22 14:46:46 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-22 14:53:19 -0700 |
commit | 96e254fb32841112805c79bc8df767854e9a48f3 (patch) | |
tree | a4b7b21a8855c4312e48124d1b69ae95bb1aa7a2 | |
parent | bac55c51f59c2578ae19b33cbfe384c946e471b6 (diff) |
If rank is unknown, do not add output shapes to transpose nodes.
PiperOrigin-RevId: 159879840
-rw-r--r-- | tensorflow/core/grappler/optimizers/layout_optimizer.cc | 28 |
1 files changed, 15 insertions, 13 deletions
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc index 1bc7b6d44d..ded1e474ce 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc @@ -240,20 +240,22 @@ class NodeProcessor { AttrValue attr_data_type_perm; attr_data_type_perm.set_type(DT_INT32); node->mutable_attr()->insert({"Tperm", attr_data_type_perm}); - AttrValue attr_output_shape; - auto output_shape = attr_output_shape.mutable_list()->add_shape(); - if (NHWCToNCHW) { - output_shape->add_dim()->set_size(input_shape.dim(0).size()); - output_shape->add_dim()->set_size(input_shape.dim(3).size()); - output_shape->add_dim()->set_size(input_shape.dim(1).size()); - output_shape->add_dim()->set_size(input_shape.dim(2).size()); - } else { - output_shape->add_dim()->set_size(input_shape.dim(0).size()); - output_shape->add_dim()->set_size(input_shape.dim(2).size()); - output_shape->add_dim()->set_size(input_shape.dim(3).size()); - output_shape->add_dim()->set_size(input_shape.dim(1).size()); + if (!input_shape.unknown_rank()) { + AttrValue attr_output_shape; + auto output_shape = attr_output_shape.mutable_list()->add_shape(); + if (NHWCToNCHW) { + output_shape->add_dim()->set_size(input_shape.dim(0).size()); + output_shape->add_dim()->set_size(input_shape.dim(3).size()); + output_shape->add_dim()->set_size(input_shape.dim(1).size()); + output_shape->add_dim()->set_size(input_shape.dim(2).size()); + } else { + output_shape->add_dim()->set_size(input_shape.dim(0).size()); + output_shape->add_dim()->set_size(input_shape.dim(2).size()); + output_shape->add_dim()->set_size(input_shape.dim(3).size()); + output_shape->add_dim()->set_size(input_shape.dim(1).size()); + } + node->mutable_attr()->insert({"_output_shapes", attr_output_shape}); } - node->mutable_attr()->insert({"_output_shapes", attr_output_shape}); } virtual Status AddLayoutTransposeToInputs() { |