aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2017-06-12 15:16:09 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-12 15:19:58 -0700
commitcb94f36e65f0312ecd3631176e41c35e4824e227 (patch)
tree6e6dbff33e9c135b27b4a8707ec3f0c238b8467a /tensorflow/core/grappler/costs/op_level_cost_estimator.cc
parent5d5d9f707f0df1083d87c415f95c22ab3999bfde (diff)
Don't mark costs estimates of scalar tensors as potentially inaccurate.
PiperOrigin-RevId: 158771775
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.cc')
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.cc19
1 files changed, 8 insertions, 11 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index 2c2549dcd3..a35f468c0b 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -78,20 +78,17 @@ int64 GetOutputSize(const int64 input, const int64 filter, const int64 stride,
TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape,
int rank, bool* found_unknown_shapes) {
auto shape = original_shape;
- if (shape.unknown_rank()) {
+ if (shape.unknown_rank() || shape.dim_size() < rank) {
*found_unknown_shapes = true;
- }
- if (shape.unknown_rank() || shape.dim_size() == 0) {
TensorShapeProto::Dim dim;
- VLOG(1) << "WARNING: Use minimum shape because the shape is unknown.";
+ VLOG(1) << "WARNING: Use minimum shape because the rank is unknown.";
// The size of each dimension is at least 1, if unknown.
dim.set_size(1);
for (int i = 0; i < rank; i++) {
*shape.add_dim() = dim;
}
} else {
- CHECK_EQ(shape.dim_size(), rank);
- for (int i = 0; i < rank; i++) {
+ for (int i = 0; i < shape.dim_size(); i++) {
if (shape.dim(i).size() == -1) {
*found_unknown_shapes = true;
VLOG(1)
@@ -562,8 +559,8 @@ int64 OpLevelCostEstimator::CountBatchMatMulOperations(
}
double ops = 0;
- auto& a_input = op_features.inputs(0);
- auto& b_input = op_features.inputs(1);
+ const auto& a_input = op_features.inputs(0);
+ const auto& b_input = op_features.inputs(1);
// BatchMatMul requires inputs of at least matrix shape (rank 2).
// The two most minor dimensions of each input are matrices that
@@ -633,7 +630,7 @@ int64 OpLevelCostEstimator::CountBatchMatMulOperations(
OpInfo::TensorProperties* a_matrix = matmul_op_features.add_inputs();
a_matrix->set_dtype(a_input.dtype());
TensorShapeProto* a_matrix_shape = a_matrix->mutable_shape();
- for (int i = a_input_shape.dim_size() - matrix_rank;
+ for (int i = std::max(0, a_input_shape.dim_size() - matrix_rank);
i < a_input_shape.dim_size(); ++i) {
*(a_matrix_shape->add_dim()) = a_input_shape.dim(i);
}
@@ -641,7 +638,7 @@ int64 OpLevelCostEstimator::CountBatchMatMulOperations(
OpInfo::TensorProperties* b_matrix = matmul_op_features.add_inputs();
b_matrix->set_dtype(b_input.dtype());
TensorShapeProto* b_matrix_shape = b_matrix->mutable_shape();
- for (int i = b_input_shape.dim_size() - matrix_rank;
+ for (int i = std::max(0, b_input_shape.dim_size() - matrix_rank);
i < b_input_shape.dim_size(); ++i) {
*(b_matrix_shape->add_dim()) = b_input_shape.dim(i);
}
@@ -649,7 +646,7 @@ int64 OpLevelCostEstimator::CountBatchMatMulOperations(
for (int i = 0; i < num_matmuls; ++i) {
bool matmul_unknown_shapes = false;
ops += CountMatMulOperations(matmul_op_features, &matmul_unknown_shapes);
- CHECK_EQ(false, matmul_unknown_shapes);
+ *found_unknown_shapes |= matmul_unknown_shapes;
}
return ops;
}