diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-05-11 15:07:48 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-11 15:19:32 -0700 |
commit | 06ff12d06e85888701a2dba441e982e34a7db6ec (patch) | |
tree | a35e5efcb96e37c10b37dd8f74e1dda61a3566e3 | |
parent | 640e0baf6e69b037ecc8c3044a11441f18afd180 (diff) |
Expose MaybeGetMinimumShape for use in cost estimators other than OpLevelCostEstimator.
PiperOrigin-RevId: 196315239
-rw-r--r-- | tensorflow/core/grappler/costs/op_level_cost_estimator.cc | 54 | ||||
-rw-r--r-- | tensorflow/core/grappler/costs/op_level_cost_estimator.h | 2 |
2 files changed, 29 insertions, 27 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index fbdd311311..b8e337582c 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -129,33 +129,6 @@ int64 GetOutputSize(const int64 input, const int64 filter, const int64 stride, } } -// Return a minimum shape if the shape is unknown. If known, return the original -// shape. -TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape, - int rank, bool* found_unknown_shapes) { - auto shape = original_shape; - if (shape.unknown_rank() || shape.dim_size() < rank) { - *found_unknown_shapes = true; - TensorShapeProto::Dim dim; - VLOG(2) << "Use minimum shape because the rank is unknown."; - // The size of each dimension is at least 1, if unknown. - dim.set_size(1); - for (int i = 0; i < rank; i++) { - *shape.add_dim() = dim; - } - } else { - for (int i = 0; i < shape.dim_size(); i++) { - if (shape.dim(i).size() < 0) { - *found_unknown_shapes = true; - VLOG(2) << "Use minimum dim size 1 because the shape is unknown."; - // The size of each dimension is at least 1, if unknown. - shape.mutable_dim(i)->set_size(1); - } - } - } - return shape; -} - // Return the output element count of a binary element-wise op considering // broadcasting. int64 CwiseOutputElementCount(const TensorShapeProto& input_shape_1, @@ -187,6 +160,33 @@ int64 CwiseOutputElementCount(const TensorShapeProto& input_shape_1, } // namespace +// Return a minimum shape if the shape is unknown. If known, return the original +// shape. +TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape, + int rank, bool* found_unknown_shapes) { + auto shape = original_shape; + if (shape.unknown_rank() || shape.dim_size() < rank) { + *found_unknown_shapes = true; + TensorShapeProto::Dim dim; + VLOG(2) << "Use minimum shape because the rank is unknown."; + // The size of each dimension is at least 1, if unknown. + dim.set_size(1); + for (int i = 0; i < rank; i++) { + *shape.add_dim() = dim; + } + } else { + for (int i = 0; i < shape.dim_size(); i++) { + if (shape.dim(i).size() < 0) { + *found_unknown_shapes = true; + VLOG(2) << "Use minimum dim size 1 because the shape is unknown."; + // The size of each dimension is at least 1, if unknown. + shape.mutable_dim(i)->set_size(1); + } + } + } + return shape; +} + OpLevelCostEstimator::OpLevelCostEstimator() { // Syntactic sugar to build and return a lambda that takes an OpInfo and // returns a cost. diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h index 35649f7ee9..d384f57279 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h @@ -30,6 +30,8 @@ namespace grappler { bool GetTensorShapeProtoFromTensorProto(const TensorProto& tensor_proto, TensorShapeProto* tensor_shape_proto); +TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape, + int rank, bool* found_unknown_shapes); class OpLevelCostEstimator { public: |