diff options
author | Benoit Steiner <bsteiner@google.com> | 2017-06-12 15:16:09 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-12 15:19:58 -0700 |
commit | cb94f36e65f0312ecd3631176e41c35e4824e227 (patch) | |
tree | 6e6dbff33e9c135b27b4a8707ec3f0c238b8467a /tensorflow/core/grappler/costs/op_level_cost_estimator.cc | |
parent | 5d5d9f707f0df1083d87c415f95c22ab3999bfde (diff) |
Don't mark costs estimates of scalar tensors as potentially inaccurate.
PiperOrigin-RevId: 158771775
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.cc')
-rw-r--r-- | tensorflow/core/grappler/costs/op_level_cost_estimator.cc | 19 |
1 files changed, 8 insertions, 11 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index 2c2549dcd3..a35f468c0b 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -78,20 +78,17 @@ int64 GetOutputSize(const int64 input, const int64 filter, const int64 stride, TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape, int rank, bool* found_unknown_shapes) { auto shape = original_shape; - if (shape.unknown_rank()) { + if (shape.unknown_rank() || shape.dim_size() < rank) { *found_unknown_shapes = true; - } - if (shape.unknown_rank() || shape.dim_size() == 0) { TensorShapeProto::Dim dim; - VLOG(1) << "WARNING: Use minimum shape because the shape is unknown."; + VLOG(1) << "WARNING: Use minimum shape because the rank is unknown."; // The size of each dimension is at least 1, if unknown. dim.set_size(1); for (int i = 0; i < rank; i++) { *shape.add_dim() = dim; } } else { - CHECK_EQ(shape.dim_size(), rank); - for (int i = 0; i < rank; i++) { + for (int i = 0; i < shape.dim_size(); i++) { if (shape.dim(i).size() == -1) { *found_unknown_shapes = true; VLOG(1) @@ -562,8 +559,8 @@ int64 OpLevelCostEstimator::CountBatchMatMulOperations( } double ops = 0; - auto& a_input = op_features.inputs(0); - auto& b_input = op_features.inputs(1); + const auto& a_input = op_features.inputs(0); + const auto& b_input = op_features.inputs(1); // BatchMatMul requires inputs of at least matrix shape (rank 2). // The two most minor dimensions of each input are matrices that @@ -633,7 +630,7 @@ int64 OpLevelCostEstimator::CountBatchMatMulOperations( OpInfo::TensorProperties* a_matrix = matmul_op_features.add_inputs(); a_matrix->set_dtype(a_input.dtype()); TensorShapeProto* a_matrix_shape = a_matrix->mutable_shape(); - for (int i = a_input_shape.dim_size() - matrix_rank; + for (int i = std::max(0, a_input_shape.dim_size() - matrix_rank); i < a_input_shape.dim_size(); ++i) { *(a_matrix_shape->add_dim()) = a_input_shape.dim(i); } @@ -641,7 +638,7 @@ int64 OpLevelCostEstimator::CountBatchMatMulOperations( OpInfo::TensorProperties* b_matrix = matmul_op_features.add_inputs(); b_matrix->set_dtype(b_input.dtype()); TensorShapeProto* b_matrix_shape = b_matrix->mutable_shape(); - for (int i = b_input_shape.dim_size() - matrix_rank; + for (int i = std::max(0, b_input_shape.dim_size() - matrix_rank); i < b_input_shape.dim_size(); ++i) { *(b_matrix_shape->add_dim()) = b_input_shape.dim(i); } @@ -649,7 +646,7 @@ int64 OpLevelCostEstimator::CountBatchMatMulOperations( for (int i = 0; i < num_matmuls; ++i) { bool matmul_unknown_shapes = false; ops += CountMatMulOperations(matmul_op_features, &matmul_unknown_shapes); - CHECK_EQ(false, matmul_unknown_shapes); + *found_unknown_shapes |= matmul_unknown_shapes; } return ops; } |