diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers/layout_optimizer.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/layout_optimizer.cc | 28 |
1 files changed, 15 insertions, 13 deletions
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc index 1bc7b6d44d..ded1e474ce 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc @@ -240,20 +240,22 @@ class NodeProcessor { AttrValue attr_data_type_perm; attr_data_type_perm.set_type(DT_INT32); node->mutable_attr()->insert({"Tperm", attr_data_type_perm}); - AttrValue attr_output_shape; - auto output_shape = attr_output_shape.mutable_list()->add_shape(); - if (NHWCToNCHW) { - output_shape->add_dim()->set_size(input_shape.dim(0).size()); - output_shape->add_dim()->set_size(input_shape.dim(3).size()); - output_shape->add_dim()->set_size(input_shape.dim(1).size()); - output_shape->add_dim()->set_size(input_shape.dim(2).size()); - } else { - output_shape->add_dim()->set_size(input_shape.dim(0).size()); - output_shape->add_dim()->set_size(input_shape.dim(2).size()); - output_shape->add_dim()->set_size(input_shape.dim(3).size()); - output_shape->add_dim()->set_size(input_shape.dim(1).size()); + if (!input_shape.unknown_rank()) { + AttrValue attr_output_shape; + auto output_shape = attr_output_shape.mutable_list()->add_shape(); + if (NHWCToNCHW) { + output_shape->add_dim()->set_size(input_shape.dim(0).size()); + output_shape->add_dim()->set_size(input_shape.dim(3).size()); + output_shape->add_dim()->set_size(input_shape.dim(1).size()); + output_shape->add_dim()->set_size(input_shape.dim(2).size()); + } else { + output_shape->add_dim()->set_size(input_shape.dim(0).size()); + output_shape->add_dim()->set_size(input_shape.dim(2).size()); + output_shape->add_dim()->set_size(input_shape.dim(3).size()); + output_shape->add_dim()->set_size(input_shape.dim(1).size()); + } + node->mutable_attr()->insert({"_output_shapes", attr_output_shape}); } - node->mutable_attr()->insert({"_output_shapes", attr_output_shape}); } virtual Status AddLayoutTransposeToInputs() { |