diff options
author | 2018-08-15 17:00:22 -0700 | |
---|---|---|
committer | 2018-08-15 17:00:22 -0700 | |
commit | bc6be507c71046dfc889a90e3949a903d5d1e6eb (patch) | |
tree | 84557e7bb7798e3d418a619c8452aa7baf78f255 /tensorflow/core/grappler/costs/op_level_cost_estimator.cc | |
parent | 9523a98466d16cf01fc76a67b489f1124cf626ac (diff) | |
parent | d2875ea71373d05c645587a83dd870fa8a0ec070 (diff) |
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.cc')
-rw-r--r-- | tensorflow/core/grappler/costs/op_level_cost_estimator.cc | 18 |
1 files changed, 14 insertions, 4 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index 6406a4bdbf..0341d7f8e1 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -175,14 +175,24 @@ int64 CwiseOutputElementCount(const TensorShapeProto& input_shape_1, TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape, int rank, bool* found_unknown_shapes) { auto shape = original_shape; - if (shape.unknown_rank() || shape.dim_size() < rank) { + bool is_scalar = !shape.unknown_rank() && shape.dim_size() == 0; + + if (shape.unknown_rank() || (!is_scalar && shape.dim_size() < rank)) { *found_unknown_shapes = true; - TensorShapeProto::Dim dim; VLOG(2) << "Use minimum shape because the rank is unknown."; // The size of each dimension is at least 1, if unknown. - dim.set_size(1); + for (int i = shape.dim_size(); i < rank; i++) { + shape.add_dim()->set_size(1); + } + } else if (is_scalar) { + for (int i = 0; i < rank; i++) { + shape.add_dim()->set_size(1); + } + } else if (shape.dim_size() > rank) { + *found_unknown_shapes = true; + shape.clear_dim(); for (int i = 0; i < rank; i++) { - *shape.add_dim() = dim; + shape.add_dim()->set_size(original_shape.dim(i).size()); } } else { for (int i = 0; i < shape.dim_size(); i++) { |