aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.cc')
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.cc18
1 files changed, 14 insertions, 4 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index 6406a4bdbf..0341d7f8e1 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -175,14 +175,24 @@ int64 CwiseOutputElementCount(const TensorShapeProto& input_shape_1,
TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape,
int rank, bool* found_unknown_shapes) {
auto shape = original_shape;
- if (shape.unknown_rank() || shape.dim_size() < rank) {
+ bool is_scalar = !shape.unknown_rank() && shape.dim_size() == 0;
+
+ if (shape.unknown_rank() || (!is_scalar && shape.dim_size() < rank)) {
*found_unknown_shapes = true;
- TensorShapeProto::Dim dim;
VLOG(2) << "Use minimum shape because the rank is unknown.";
// The size of each dimension is at least 1, if unknown.
- dim.set_size(1);
+ for (int i = shape.dim_size(); i < rank; i++) {
+ shape.add_dim()->set_size(1);
+ }
+ } else if (is_scalar) {
+ for (int i = 0; i < rank; i++) {
+ shape.add_dim()->set_size(1);
+ }
+ } else if (shape.dim_size() > rank) {
+ *found_unknown_shapes = true;
+ shape.clear_dim();
for (int i = 0; i < rank; i++) {
- *shape.add_dim() = dim;
+ shape.add_dim()->set_size(original_shape.dim(i).size());
}
} else {
for (int i = 0; i < shape.dim_size(); i++) {