aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-11 15:07:48 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-11 15:19:32 -0700
commit06ff12d06e85888701a2dba441e982e34a7db6ec (patch)
treea35e5efcb96e37c10b37dd8f74e1dda61a3566e3
parent640e0baf6e69b037ecc8c3044a11441f18afd180 (diff)
Expose MaybeGetMinimumShape for use in cost estimators other than OpLevelCostEstimator.
PiperOrigin-RevId: 196315239
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.cc54
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.h2
2 files changed, 29 insertions, 27 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index fbdd311311..b8e337582c 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -129,33 +129,6 @@ int64 GetOutputSize(const int64 input, const int64 filter, const int64 stride,
}
}
-// Return a minimum shape if the shape is unknown. If known, return the original
-// shape.
-TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape,
- int rank, bool* found_unknown_shapes) {
- auto shape = original_shape;
- if (shape.unknown_rank() || shape.dim_size() < rank) {
- *found_unknown_shapes = true;
- TensorShapeProto::Dim dim;
- VLOG(2) << "Use minimum shape because the rank is unknown.";
- // The size of each dimension is at least 1, if unknown.
- dim.set_size(1);
- for (int i = 0; i < rank; i++) {
- *shape.add_dim() = dim;
- }
- } else {
- for (int i = 0; i < shape.dim_size(); i++) {
- if (shape.dim(i).size() < 0) {
- *found_unknown_shapes = true;
- VLOG(2) << "Use minimum dim size 1 because the shape is unknown.";
- // The size of each dimension is at least 1, if unknown.
- shape.mutable_dim(i)->set_size(1);
- }
- }
- }
- return shape;
-}
-
// Return the output element count of a binary element-wise op considering
// broadcasting.
int64 CwiseOutputElementCount(const TensorShapeProto& input_shape_1,
@@ -187,6 +160,33 @@ int64 CwiseOutputElementCount(const TensorShapeProto& input_shape_1,
} // namespace
+// Return a minimum shape if the shape is unknown. If known, return the original
+// shape.
+TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape,
+ int rank, bool* found_unknown_shapes) {
+ auto shape = original_shape;
+ if (shape.unknown_rank() || shape.dim_size() < rank) {
+ *found_unknown_shapes = true;
+ TensorShapeProto::Dim dim;
+ VLOG(2) << "Use minimum shape because the rank is unknown.";
+ // The size of each dimension is at least 1, if unknown.
+ dim.set_size(1);
+ for (int i = 0; i < rank; i++) {
+ *shape.add_dim() = dim;
+ }
+ } else {
+ for (int i = 0; i < shape.dim_size(); i++) {
+ if (shape.dim(i).size() < 0) {
+ *found_unknown_shapes = true;
+ VLOG(2) << "Use minimum dim size 1 because the shape is unknown.";
+ // The size of each dimension is at least 1, if unknown.
+ shape.mutable_dim(i)->set_size(1);
+ }
+ }
+ }
+ return shape;
+}
+
OpLevelCostEstimator::OpLevelCostEstimator() {
// Syntactic sugar to build and return a lambda that takes an OpInfo and
// returns a cost.
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
index 35649f7ee9..d384f57279 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
@@ -30,6 +30,8 @@ namespace grappler {
bool GetTensorShapeProtoFromTensorProto(const TensorProto& tensor_proto,
TensorShapeProto* tensor_shape_proto);
+TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape,
+ int rank, bool* found_unknown_shapes);
class OpLevelCostEstimator {
public: