aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/layout_optimizer.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers/layout_optimizer.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/layout_optimizer.cc28
1 files changed, 15 insertions, 13 deletions
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
index 1bc7b6d44d..ded1e474ce 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -240,20 +240,22 @@ class NodeProcessor {
AttrValue attr_data_type_perm;
attr_data_type_perm.set_type(DT_INT32);
node->mutable_attr()->insert({"Tperm", attr_data_type_perm});
- AttrValue attr_output_shape;
- auto output_shape = attr_output_shape.mutable_list()->add_shape();
- if (NHWCToNCHW) {
- output_shape->add_dim()->set_size(input_shape.dim(0).size());
- output_shape->add_dim()->set_size(input_shape.dim(3).size());
- output_shape->add_dim()->set_size(input_shape.dim(1).size());
- output_shape->add_dim()->set_size(input_shape.dim(2).size());
- } else {
- output_shape->add_dim()->set_size(input_shape.dim(0).size());
- output_shape->add_dim()->set_size(input_shape.dim(2).size());
- output_shape->add_dim()->set_size(input_shape.dim(3).size());
- output_shape->add_dim()->set_size(input_shape.dim(1).size());
+ if (!input_shape.unknown_rank()) {
+ AttrValue attr_output_shape;
+ auto output_shape = attr_output_shape.mutable_list()->add_shape();
+ if (NHWCToNCHW) {
+ output_shape->add_dim()->set_size(input_shape.dim(0).size());
+ output_shape->add_dim()->set_size(input_shape.dim(3).size());
+ output_shape->add_dim()->set_size(input_shape.dim(1).size());
+ output_shape->add_dim()->set_size(input_shape.dim(2).size());
+ } else {
+ output_shape->add_dim()->set_size(input_shape.dim(0).size());
+ output_shape->add_dim()->set_size(input_shape.dim(2).size());
+ output_shape->add_dim()->set_size(input_shape.dim(3).size());
+ output_shape->add_dim()->set_size(input_shape.dim(1).size());
+ }
+ node->mutable_attr()->insert({"_output_shapes", attr_output_shape});
}
- node->mutable_attr()->insert({"_output_shapes", attr_output_shape});
}
virtual Status AddLayoutTransposeToInputs() {