diff options
author | Doe Hyun Yoon <dyoon@google.com> | 2018-08-15 10:47:44 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-15 10:51:16 -0700 |
commit | df0c18e9fa922485a54dc0696dc43a2c784ead15 (patch) | |
tree | 37c9b4b44285d4d698a5653e8e79f4f9dffe2eb1 /tensorflow/core/grappler/costs/op_level_cost_estimator.cc | |
parent | 5f3650f3d3b2306ee8a5360384e052f6a76b778e (diff) |
Small fix to MaybeGetMinimumShape() in op_level_cost_estimator.
(1) previously, it set unknown shape flag for scalar input, but now it
returns TensorShapeProto with rank equal to the expected and all dims set to 1,
and unknown shape flag is not set.
(2) Also, fixed a bug; when a rank is known, but dim_size() < rank (note that
dim_size() may be non-zero), we previously called add_dim() with dim 1 rank
times, which then makes dim_size() is incremented by rank, but we expect
dim_size() equal to rank.
(3) Added test for MaybeGetMinimumShape().
PiperOrigin-RevId: 208845501
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.cc')
-rw-r--r-- | tensorflow/core/grappler/costs/op_level_cost_estimator.cc | 18 |
1 files changed, 14 insertions, 4 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index 6406a4bdbf..0341d7f8e1 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -175,14 +175,24 @@ int64 CwiseOutputElementCount(const TensorShapeProto& input_shape_1, TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape, int rank, bool* found_unknown_shapes) { auto shape = original_shape; - if (shape.unknown_rank() || shape.dim_size() < rank) { + bool is_scalar = !shape.unknown_rank() && shape.dim_size() == 0; + + if (shape.unknown_rank() || (!is_scalar && shape.dim_size() < rank)) { *found_unknown_shapes = true; - TensorShapeProto::Dim dim; VLOG(2) << "Use minimum shape because the rank is unknown."; // The size of each dimension is at least 1, if unknown. - dim.set_size(1); + for (int i = shape.dim_size(); i < rank; i++) { + shape.add_dim()->set_size(1); + } + } else if (is_scalar) { + for (int i = 0; i < rank; i++) { + shape.add_dim()->set_size(1); + } + } else if (shape.dim_size() > rank) { + *found_unknown_shapes = true; + shape.clear_dim(); for (int i = 0; i < rank; i++) { - *shape.add_dim() = dim; + shape.add_dim()->set_size(original_shape.dim(i).size()); } } else { for (int i = 0; i < shape.dim_size(); i++) { |