diff options
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.cc')
-rw-r--r-- | tensorflow/core/grappler/costs/op_level_cost_estimator.cc | 18 |
1 files changed, 14 insertions, 4 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index 6406a4bdbf..0341d7f8e1 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -175,14 +175,24 @@ int64 CwiseOutputElementCount(const TensorShapeProto& input_shape_1, TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape, int rank, bool* found_unknown_shapes) { auto shape = original_shape; - if (shape.unknown_rank() || shape.dim_size() < rank) { + bool is_scalar = !shape.unknown_rank() && shape.dim_size() == 0; + + if (shape.unknown_rank() || (!is_scalar && shape.dim_size() < rank)) { *found_unknown_shapes = true; - TensorShapeProto::Dim dim; VLOG(2) << "Use minimum shape because the rank is unknown."; // The size of each dimension is at least 1, if unknown. - dim.set_size(1); + for (int i = shape.dim_size(); i < rank; i++) { + shape.add_dim()->set_size(1); + } + } else if (is_scalar) { + for (int i = 0; i < rank; i++) { + shape.add_dim()->set_size(1); + } + } else if (shape.dim_size() > rank) { + *found_unknown_shapes = true; + shape.clear_dim(); for (int i = 0; i < rank; i++) { - *shape.add_dim() = dim; + shape.add_dim()->set_size(original_shape.dim(i).size()); } } else { for (int i = 0; i < shape.dim_size(); i++) { |