aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
diff options
context:
space:
mode:
authorGravatar Avijit <Avijit.Chakraborty@intel.com>2018-08-15 17:00:22 -0700
committerGravatar Avijit <Avijit.Chakraborty@intel.com>2018-08-15 17:00:22 -0700
commitbc6be507c71046dfc889a90e3949a903d5d1e6eb (patch)
tree84557e7bb7798e3d418a619c8452aa7baf78f255 /tensorflow/core/grappler/costs/op_level_cost_estimator.cc
parent9523a98466d16cf01fc76a67b489f1124cf626ac (diff)
parentd2875ea71373d05c645587a83dd870fa8a0ec070 (diff)
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.cc')
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.cc18
1 files changed, 14 insertions, 4 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index 6406a4bdbf..0341d7f8e1 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -175,14 +175,24 @@ int64 CwiseOutputElementCount(const TensorShapeProto& input_shape_1,
TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape,
int rank, bool* found_unknown_shapes) {
auto shape = original_shape;
- if (shape.unknown_rank() || shape.dim_size() < rank) {
+ bool is_scalar = !shape.unknown_rank() && shape.dim_size() == 0;
+
+ if (shape.unknown_rank() || (!is_scalar && shape.dim_size() < rank)) {
*found_unknown_shapes = true;
- TensorShapeProto::Dim dim;
VLOG(2) << "Use minimum shape because the rank is unknown.";
// The size of each dimension is at least 1, if unknown.
- dim.set_size(1);
+ for (int i = shape.dim_size(); i < rank; i++) {
+ shape.add_dim()->set_size(1);
+ }
+ } else if (is_scalar) {
+ for (int i = 0; i < rank; i++) {
+ shape.add_dim()->set_size(1);
+ }
+ } else if (shape.dim_size() > rank) {
+ *found_unknown_shapes = true;
+ shape.clear_dim();
for (int i = 0; i < rank; i++) {
- *shape.add_dim() = dim;
+ shape.add_dim()->set_size(original_shape.dim(i).size());
}
} else {
for (int i = 0; i < shape.dim_size(); i++) {