aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
diff options
context:
space:
mode:
authorGravatar Doe Hyun Yoon <dyoon@google.com>2018-08-15 10:47:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-15 10:51:16 -0700
commitdf0c18e9fa922485a54dc0696dc43a2c784ead15 (patch)
tree37c9b4b44285d4d698a5653e8e79f4f9dffe2eb1 /tensorflow/core/grappler/costs/op_level_cost_estimator.cc
parent5f3650f3d3b2306ee8a5360384e052f6a76b778e (diff)
Small fix to MaybeGetMinimumShape() in op_level_cost_estimator.
(1) previously, it set unknown shape flag for scalar input, but now it returns TensorShapeProto with rank equal to the expected and all dims set to 1, and unknown shape flag is not set. (2) Also, fixed a bug; when a rank is known, but dim_size() < rank (note that dim_size() may be non-zero), we previously called add_dim() with dim 1 rank times, which then makes dim_size() is incremented by rank, but we expect dim_size() equal to rank. (3) Added test for MaybeGetMinimumShape(). PiperOrigin-RevId: 208845501
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.cc')
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.cc18
1 files changed, 14 insertions, 4 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index 6406a4bdbf..0341d7f8e1 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -175,14 +175,24 @@ int64 CwiseOutputElementCount(const TensorShapeProto& input_shape_1,
TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape,
int rank, bool* found_unknown_shapes) {
auto shape = original_shape;
- if (shape.unknown_rank() || shape.dim_size() < rank) {
+ bool is_scalar = !shape.unknown_rank() && shape.dim_size() == 0;
+
+ if (shape.unknown_rank() || (!is_scalar && shape.dim_size() < rank)) {
*found_unknown_shapes = true;
- TensorShapeProto::Dim dim;
VLOG(2) << "Use minimum shape because the rank is unknown.";
// The size of each dimension is at least 1, if unknown.
- dim.set_size(1);
+ for (int i = shape.dim_size(); i < rank; i++) {
+ shape.add_dim()->set_size(1);
+ }
+ } else if (is_scalar) {
+ for (int i = 0; i < rank; i++) {
+ shape.add_dim()->set_size(1);
+ }
+ } else if (shape.dim_size() > rank) {
+ *found_unknown_shapes = true;
+ shape.clear_dim();
for (int i = 0; i < rank; i++) {
- *shape.add_dim() = dim;
+ shape.add_dim()->set_size(original_shape.dim(i).size());
}
} else {
for (int i = 0; i < shape.dim_size(); i++) {